
feat(pt_expt): auto-generate forward/forward_lower in torch_module decorator#5246

Merged
wanghan-iapcm merged 4 commits into master from copilot/auto-generate-forward-methods
Feb 18, 2026

Conversation


Copilot AI commented Feb 16, 2026

  • Understand the current implementation of torch_module decorator
  • Identify all classes using @torch_module with manual forward() methods
  • Modify torch_module decorator to auto-generate forward() from call()
  • Modify torch_module decorator to auto-generate forward_lower() from call_lower()
  • Remove boilerplate forward() methods from descriptor classes (DescrptSeA, DescrptSeR, DescrptSeT, DescrptSeTTebd, DescrptBlockSeTTebd)
  • Remove boilerplate forward() methods from fitting classes (InvarFitting, EnergyFittingNet)
  • Remove boilerplate forward() method from NativeNet
  • Run linting and formatting
  • Create tests to verify auto-generated methods work correctly
  • Run targeted tests to verify changes (63 pt_expt tests pass)
  • Manually verify forward methods work with real data
  • Request code review and address feedback
  • Update docstring to document auto-generation behavior
  • Update tests to use module(...) invocation path
  • Ready for final approval
Original prompt

This section details the original issue you should resolve

<issue_title>pt_expt: auto-generate forward/forward_lower in torch_module decorator to avoid boilerplate</issue_title>
<issue_description>## Summary

The torch_module decorator in deepmd/pt_expt/common.py currently handles __init__, __call__, and __setattr__ automatically, but each wrapped class still needs to manually define forward() (and potentially forward_lower()) methods that simply redirect to call() (and call_lower()).

Current situation

Every descriptor class in deepmd/pt_expt/descriptor/ repeats the same boilerplate pattern:

@torch_module
class DescrptSeA(DescrptSeADP):
    def forward(
        self,
        extended_coord: torch.Tensor,
        extended_atype: torch.Tensor,
        nlist: torch.Tensor,
        mapping: torch.Tensor | None = None,
    ) -> tuple[...]:
        descrpt, rot_mat, g2, h2, sw = self.call(
            extended_coord,
            extended_atype,
            nlist,
            mapping=mapping,
        )
        return descrpt, rot_mat, g2, h2, sw

This identical forward → call redirect is duplicated across DescrptSeA, DescrptSeR, DescrptSeT, DescrptSeTTebd, DescrptBlockSeTTebd, and similarly for fitting classes like InvarFitting (though InvarFitting does not use torch_module).

Proposal

Modify the torch_module decorator to automatically generate forward and forward_lower methods that delegate to call and call_lower respectively, if:

  1. The wrapped class has a call method (from the dpmodel base) but does not define its own forward.
  2. Similarly for call_lower → forward_lower.

This could be implemented by adding something like the following inside torch_module:

def torch_module(module: type[NativeOP]) -> type[torch.nn.Module]:
    @wraps(module, updated=())
    class TorchModule(module, torch.nn.Module):
        def __init__(self, *args, **kwargs):
            torch.nn.Module.__init__(self)
            module.__init__(self, *args, **kwargs)

        def __call__(self, *args, **kwargs):
            return torch.nn.Module.__call__(self, *args, **kwargs)

        def __setattr__(self, name, value):
            handled, value = dpmodel_setattr(self, name, value)
            if not handled:
                super().__setattr__(name, value)

    # Auto-generate forward -> call redirect
    if hasattr(module, "call") and "forward" not in module.__dict__:
        def forward(self, *args, **kwargs):
            return self.call(*args, **kwargs)
        TorchModule.forward = forward

    # Auto-generate forward_lower -> call_lower redirect
    if hasattr(module, "call_lower") and "forward_lower" not in module.__dict__:
        def forward_lower(self, *args, **kwargs):
            return self.call_lower(*args, **kwargs)
        TorchModule.forward_lower = forward_lower

    return TorchModule
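The conditional-attachment idea above can be sketched torch-free as well; `attach_forward`, `Doubler`, and `Custom` below are hypothetical stand-ins for illustration, not part of the deepmd codebase:

```python
# Minimal sketch of the conditional attachment proposed above (no torch).
# Methods are attached only when the class inherits/defines call/call_lower
# but does not define forward/forward_lower in its own namespace.
def attach_forward(cls: type) -> type:
    if hasattr(cls, "call") and "forward" not in cls.__dict__:
        def forward(self, *args, **kwargs):
            return self.call(*args, **kwargs)
        cls.forward = forward
    if hasattr(cls, "call_lower") and "forward_lower" not in cls.__dict__:
        def forward_lower(self, *args, **kwargs):
            return self.call_lower(*args, **kwargs)
        cls.forward_lower = forward_lower
    return cls

@attach_forward
class Doubler:
    def call(self, x):
        return x * 2
    # forward() auto-generated; no call_lower, so no forward_lower

@attach_forward
class Custom:
    def call(self, x):
        return x

    def forward(self, x):  # explicit forward must win over auto-generation
        return x + 1
```

Note the `"forward" not in cls.__dict__` check looks only at the class's own namespace, so an explicitly defined `forward` is never overridden.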

Benefits

  • Eliminates boilerplate: Descriptor/fitting wrapper classes become minimal (just the @torch_module decorator + @BaseDescriptor.register + class declaration).
  • Reduces maintenance burden: Adding new descriptors to pt_expt requires fewer lines and less copy-paste.
  • Consistent behavior: The forward ↔ call contract is enforced in one place rather than scattered across many files.
  • Less error-prone: No risk of forgetting to update forward when call signature changes.

Example after this change

@BaseDescriptor.register("se_e2_a")
@BaseDescriptor.register("se_a")
@torch_module
class DescrptSeA(DescrptSeADP):
    pass  # forward() auto-generated from call()

Submitted by OpenClaw on behalf of @njzjz</issue_description>

Comments on the Issue (you are @copilot in this section)



Copilot AI and others added 2 commits February 16, 2026 19:32
…corator

Modified torch_module decorator to automatically generate forward() and forward_lower()
methods that delegate to call() and call_lower() if they exist and aren't explicitly
defined. Removed boilerplate forward methods from all descriptor and fitting classes.

Co-authored-by: njzjz <9496702+njzjz@users.noreply.github.com>
Added comprehensive tests to verify torch_module decorator correctly auto-generates
forward() and forward_lower() methods, respects explicit overrides, and handles
edge cases properly.

Co-authored-by: njzjz <9496702+njzjz@users.noreply.github.com>
Copilot AI changed the title [WIP] Add auto-generation for forward methods in torch_module feat(pt_expt): auto-generate forward/forward_lower in torch_module decorator Feb 16, 2026
Copilot AI requested a review from njzjz February 16, 2026 19:39

@njzjz njzjz left a comment


@copilot are there the same issues in model and atomic model?


Copilot AI commented Feb 16, 2026

@copilot are there the same issues in model and atomic model?

No, the model and atomic model classes don't have the same boilerplate pattern.

After investigating:

DPAtomicModel (atomic_model/dp_atomic_model.py):

  • forward() calls forward_atomic(), not call()
  • Different pattern, not applicable for auto-generation

CM class in make_model (model/make_model.py):

  • Does have forward() → call() delegation
  • But it's intentionally a default implementation for subclasses to override
  • EnergyModel overrides it with custom output translation logic

EnergyModel (model/ener_model.py):

  • Doesn't use @torch_module
  • forward() calls call() but adds significant processing (dict transformation)
  • forward_lower() calls _forward_lower(), not call_lower()

The descriptor/fitting classes had pure boilerplate (simple pass-through), while model classes have intentional logic that shouldn't be auto-generated.

Copilot AI requested a review from njzjz February 16, 2026 19:59
@njzjz njzjz marked this pull request as ready for review February 16, 2026 20:29
Copilot AI review requested due to automatic review settings February 16, 2026 20:29
@dosubot dosubot bot added the new feature label Feb 16, 2026
    """Test that torch_module auto-generates forward() from call()."""

    class MockNativeOP(NativeOP):
        def call(self, x: np.ndarray) -> np.ndarray:

Check warning

Code scanning / CodeQL

Signature mismatch in overriding method (warning, test)

This method does not accept arbitrary keyword arguments, which the overridden NativeOP.call does. It requires 2 positional arguments, whereas the overridden NativeOP.call may be called with 1, or with arbitrarily many. In each case the flagged call correctly invokes the base method but does not match the signature of the overriding method.

Copilot Autofix

AI 13 days ago

In general, to fix a signature mismatch in an overriding method, the overriding method should accept at least the same arguments as the base method. For a Python base method that allows arbitrary positional and keyword arguments (e.g. def call(self, *args, **kwargs)), the override should also accept *args and **kwargs, even if it only uses some of them. This preserves Liskov Substitution Principle and avoids runtime TypeError when code written against the base type passes extra parameters.

In this file, each inner MockNativeOP class defines call(self, x: np.ndarray) -> np.ndarray. To align with NativeOP.call, we should change these call definitions to accept flexible arguments, e.g.:

def call(self, x: np.ndarray, *args: object, **kwargs: object) -> np.ndarray:
    ...

This keeps the existing behavior for the tests (they still pass a single x argument) while allowing additional positional/keyword arguments without error. The test implementations themselves do not use extra arguments, so they can safely ignore *args and **kwargs. We only modify the shown call methods inside source/tests/pt_expt/utils/test_common.py; no other imports or helpers are needed.

Concretely:

  • In test_torch_module_auto_generates_forward, update MockNativeOP.call on line 34.
  • In test_torch_module_auto_generates_forward_lower, update MockNativeOP.call on line 52.
  • In test_torch_module_respects_explicit_forward, update MockNativeOP.call on line 74.
  • In test_torch_module_handles_no_call_lower_method, update MockNativeOP.call on line 118.

No change is needed to call_lower or other methods, because the alert is specifically about NativeOP.call.


Suggested changeset 1
source/tests/pt_expt/utils/test_common.py

Autofix patch

Autofix patch
Run the following command in your local git repository to apply this patch
cat << 'EOF' | git apply
diff --git a/source/tests/pt_expt/utils/test_common.py b/source/tests/pt_expt/utils/test_common.py
--- a/source/tests/pt_expt/utils/test_common.py
+++ b/source/tests/pt_expt/utils/test_common.py
@@ -31,7 +31,7 @@
     """Test that torch_module auto-generates forward() from call()."""
 
     class MockNativeOP(NativeOP):
-        def call(self, x: np.ndarray) -> np.ndarray:
+        def call(self, x: np.ndarray, *args: object, **kwargs: object) -> np.ndarray:
             return x * 2
 
     @torch_module
@@ -49,7 +49,7 @@
     """Test that torch_module auto-generates forward_lower() from call_lower()."""
 
     class MockNativeOP(NativeOP):
-        def call(self, x: np.ndarray) -> np.ndarray:
+        def call(self, x: np.ndarray, *args: object, **kwargs: object) -> np.ndarray:
             return x
 
         def call_lower(self, x: np.ndarray, y: np.ndarray) -> np.ndarray:
@@ -71,7 +71,7 @@
     """Test that torch_module doesn't override an explicitly defined forward()."""
 
     class MockNativeOP(NativeOP):
-        def call(self, x: np.ndarray) -> np.ndarray:
+        def call(self, x: np.ndarray, *args: object, **kwargs: object) -> np.ndarray:
             return x * 2
 
     @torch_module
@@ -115,7 +115,7 @@
     """Test that torch_module works even when there's no call_lower() method."""
 
     class MockNativeOP(NativeOP):
-        def call(self, x: np.ndarray) -> np.ndarray:
+        def call(self, x: np.ndarray, *args: object, **kwargs: object) -> np.ndarray:
             return x * 2
 
     @torch_module
EOF
Copilot is powered by AI and may make mistakes. Always verify output.
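The signature-widening pattern the autofix recommends can be illustrated in isolation; `Base`, `Narrow`, and `Widened` below are hypothetical stand-ins (the real NativeOP is not used here):

```python
# Sketch of why the override should accept *args/**kwargs when the base does:
# a narrower override breaks calls that are valid against the base signature.
class Base:
    def call(self, *args, **kwargs):
        raise NotImplementedError

class Narrow(Base):
    # Incompatible: rejects calls like obj.call(x, extra=1) that the
    # base signature permits (Liskov substitution violation).
    def call(self, x):
        return x * 2

class Widened(Base):
    # Compatible: keeps the explicit x but tolerates extra arguments,
    # which it simply ignores.
    def call(self, x, *args, **kwargs):
        return x * 2
```

Code written against `Base` may pass extra arguments; `Widened` accepts them while `Narrow` raises `TypeError`, which is exactly the mismatch CodeQL flags.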
    """Test that torch_module auto-generates forward_lower() from call_lower()."""

    class MockNativeOP(NativeOP):
        def call(self, x: np.ndarray) -> np.ndarray:

Check warning

Code scanning / CodeQL

Signature mismatch in overriding method Warning test


Copilot Autofix

AI 13 days ago

In general, to fix a signature mismatch between an overriding method and its base method, the overriding method should accept the same parameters as the base; if the base accepts variadic positional/keyword arguments, the override should also accept *args and **kwargs (and possibly default values) so that any call valid for the base is also valid for the subclass.

Here, CodeQL indicates that NativeOP.call can be called with arbitrary keyword arguments and possibly varying numbers of positional arguments, while the overrides in the inner MockNativeOP classes only accept a single x positional argument. To fix this without changing observed behavior in the tests, we can change these call (and call_lower) methods to accept *args, **kwargs and then interpret the arguments in the same way as before. For the first test (test_torch_module_auto_generates_forward), the behavior depends only on a single tensor/array, so the new implementation can treat args[0] as x and ignore other arguments. For the second test (test_torch_module_auto_generates_forward_lower), call similarly only needs the first positional argument, while call_lower needs two positional arguments; we can obtain them as x, y = args[0], args[1]. This keeps the semantics of the tests identical while making the override signatures compatible with the base class.

Concretely:

  • In test_torch_module_auto_generates_forward, update MockNativeOP.call to def call(self, *args, **kwargs) and inside, set x = args[0] and return x * 2.
  • In test_torch_module_auto_generates_forward_lower, update:
    • MockNativeOP.call to the same variadic form, returning x unchanged.
    • MockNativeOP.call_lower to a variadic form that extracts x and y from args and returns x + y.
      No imports or additional definitions are needed; we only adjust these method signatures and bodies.
Suggested changeset 1
source/tests/pt_expt/utils/test_common.py

Autofix patch

Autofix patch
Run the following command in your local git repository to apply this patch
cat << 'EOF' | git apply
diff --git a/source/tests/pt_expt/utils/test_common.py b/source/tests/pt_expt/utils/test_common.py
--- a/source/tests/pt_expt/utils/test_common.py
+++ b/source/tests/pt_expt/utils/test_common.py
@@ -31,7 +31,8 @@
     """Test that torch_module auto-generates forward() from call()."""
 
     class MockNativeOP(NativeOP):
-        def call(self, x: np.ndarray) -> np.ndarray:
+        def call(self, *args, **kwargs) -> np.ndarray:
+            x = args[0]
             return x * 2
 
     @torch_module
@@ -49,10 +50,12 @@
     """Test that torch_module auto-generates forward_lower() from call_lower()."""
 
     class MockNativeOP(NativeOP):
-        def call(self, x: np.ndarray) -> np.ndarray:
+        def call(self, *args, **kwargs) -> np.ndarray:
+            x = args[0]
             return x
 
-        def call_lower(self, x: np.ndarray, y: np.ndarray) -> np.ndarray:
+        def call_lower(self, *args, **kwargs) -> np.ndarray:
+            x, y = args[0], args[1]
             return x + y
 
     @torch_module
EOF
    """Test that torch_module doesn't override an explicitly defined forward()."""

    class MockNativeOP(NativeOP):
        def call(self, x: np.ndarray) -> np.ndarray:

Check warning

Code scanning / CodeQL

Signature mismatch in overriding method Warning test


Copilot Autofix

AI 13 days ago

In general, to fix a signature mismatch between an overriding method and its base method, the overriding method should accept parameters that are compatible with the base method’s call sites: it should not require more arguments than the base, and it should accept any additional positional or keyword arguments that callers are allowed to pass to the base method (typically by using *args and **kwargs).

Here, NativeOP.call can be called with one or many positional arguments and arbitrary keyword arguments, but the test’s MockNativeOP.call is defined as def call(self, x: np.ndarray) -> np.ndarray:. To retain the existing test behavior while making the override compatible, we can expand the signature to accept *args and **kwargs, then internally interpret the arguments as before. Since the tests only use a single argument, we can treat the first positional argument as x and ignore or simply not use any extra arguments. To keep things simple and non-invasive, we don’t need to add any new imports or change how the tests invoke call; we only need to update the call definitions inside the mock classes in this file.

Concretely:

  • Update each MockNativeOP.call(self, x: np.ndarray) -> np.ndarray: definition to something like def call(self, x: np.ndarray, *args, **kwargs) -> np.ndarray:. This preserves the existing explicit parameter x (used by the test) and adds variadic parameters to match the base class’s flexibility.
  • Leave the method bodies unchanged, still operating on x as before.
  • No changes are needed for call_lower, since the alert concerns call, and we have no evidence of its base signature being more flexible; we should not guess or alter it without necessity.

This single pattern change in each relevant MockNativeOP class definition in source/tests/pt_expt/utils/test_common.py will address all three variants of the alert without modifying test behavior.

Suggested changeset 1
source/tests/pt_expt/utils/test_common.py

Autofix patch

Autofix patch
Run the following command in your local git repository to apply this patch
cat << 'EOF' | git apply
diff --git a/source/tests/pt_expt/utils/test_common.py b/source/tests/pt_expt/utils/test_common.py
--- a/source/tests/pt_expt/utils/test_common.py
+++ b/source/tests/pt_expt/utils/test_common.py
@@ -31,7 +31,7 @@
     """Test that torch_module auto-generates forward() from call()."""
 
     class MockNativeOP(NativeOP):
-        def call(self, x: np.ndarray) -> np.ndarray:
+        def call(self, x: np.ndarray, *args, **kwargs) -> np.ndarray:
             return x * 2
 
     @torch_module
@@ -49,7 +49,7 @@
     """Test that torch_module auto-generates forward_lower() from call_lower()."""
 
     class MockNativeOP(NativeOP):
-        def call(self, x: np.ndarray) -> np.ndarray:
+        def call(self, x: np.ndarray, *args, **kwargs) -> np.ndarray:
             return x
 
         def call_lower(self, x: np.ndarray, y: np.ndarray) -> np.ndarray:
@@ -71,7 +71,7 @@
     """Test that torch_module doesn't override an explicitly defined forward()."""
 
     class MockNativeOP(NativeOP):
-        def call(self, x: np.ndarray) -> np.ndarray:
+        def call(self, x: np.ndarray, *args, **kwargs) -> np.ndarray:
             return x * 2
 
     @torch_module
EOF
    """Test that torch_module doesn't override an explicitly defined forward_lower()."""

    class MockNativeOP(NativeOP):
        def call(self, x: np.ndarray) -> np.ndarray:

Check warning

Code scanning / CodeQL

Signature mismatch in overriding method Warning test


Copilot Autofix

AI 13 days ago

In general, to fix this kind of issue you must ensure that the overriding method in the subclass has a signature compatible with the base class: it should accept at least the same kinds of arguments (including arbitrary positional/keyword arguments if the base does), even if it ignores some of them.

Here, NativeOP.call accepts arbitrary keyword arguments (and possibly a flexible number of positional arguments), while the overrides in the nested MockNativeOP classes only accept a single positional parameter x. To align with the base class without changing existing behavior, we can widen the override signatures to accept additional *args and **kwargs, while continuing to use only x inside the method bodies. That way, all existing calls remain valid, new calls that pass extra arguments won’t fail, and test semantics are unchanged.

Concretely:

  • In test_torch_module_auto_generates_forward_lower, update the nested MockNativeOP class so that call becomes def call(self, x: np.ndarray, *args, **kwargs) -> np.ndarray:.
  • In test_torch_module_respects_explicit_forward, update its nested MockNativeOP similarly.
  • In test_torch_module_respects_explicit_forward_lower, update its nested MockNativeOP likewise.

No new imports or helper methods are needed; we only adjust the function signatures. The method bodies remain the same and ignore *args/**kwargs.


Suggested changeset 1
source/tests/pt_expt/utils/test_common.py

Autofix patch

Autofix patch
Run the following command in your local git repository to apply this patch
cat << 'EOF' | git apply
diff --git a/source/tests/pt_expt/utils/test_common.py b/source/tests/pt_expt/utils/test_common.py
--- a/source/tests/pt_expt/utils/test_common.py
+++ b/source/tests/pt_expt/utils/test_common.py
@@ -49,7 +49,7 @@
     """Test that torch_module auto-generates forward_lower() from call_lower()."""
 
     class MockNativeOP(NativeOP):
-        def call(self, x: np.ndarray) -> np.ndarray:
+        def call(self, x: np.ndarray, *args, **kwargs) -> np.ndarray:
             return x
 
         def call_lower(self, x: np.ndarray, y: np.ndarray) -> np.ndarray:
@@ -71,7 +71,7 @@
     """Test that torch_module doesn't override an explicitly defined forward()."""
 
     class MockNativeOP(NativeOP):
-        def call(self, x: np.ndarray) -> np.ndarray:
+        def call(self, x: np.ndarray, *args, **kwargs) -> np.ndarray:
             return x * 2
 
     @torch_module
@@ -91,7 +91,7 @@
     """Test that torch_module doesn't override an explicitly defined forward_lower()."""
 
     class MockNativeOP(NativeOP):
-        def call(self, x: np.ndarray) -> np.ndarray:
+        def call(self, x: np.ndarray, *args, **kwargs) -> np.ndarray:
             return x
 
         def call_lower(self, x: np.ndarray, y: np.ndarray) -> np.ndarray:
EOF
    """Test that torch_module works even when there's no call_lower() method."""

    class MockNativeOP(NativeOP):
        def call(self, x: np.ndarray) -> np.ndarray:

Check warning

Code scanning / CodeQL

Signature mismatch in overriding method Warning test


Copilot Autofix

AI 13 days ago

In general, to fix a signature mismatch in an overriding method, ensure that the overriding method’s parameters are compatible with the base method: it must accept at least the same positional parameters and be able to handle any additional optional or variadic *args/**kwargs the base allows. You can still ignore extra arguments in the implementation if they are not needed.

Here, NativeOP.call (the base) accepts one or more positional arguments and arbitrary keyword arguments. The override in test_torch_module_handles_no_call_lower_method currently has a narrow signature: def call(self, x: np.ndarray) -> np.ndarray. To match the base while preserving existing behavior, we should expand the override signature to accept *args and **kwargs, and implement it in a way that keeps the current semantics (multiplying the first argument by 2). Because this is a small inner test class, the minimal change is:

  • Change def call(self, x: np.ndarray) -> np.ndarray: to def call(self, x: np.ndarray, *args, **kwargs) -> np.ndarray:.
  • Leave the body as return x * 2, ignoring any extra arguments.

This preserves the test’s current behavior (the generated forward will still be based on a single main argument) while making the override signature compatible with NativeOP.call. The change is localized to the snippet shown in source/tests/pt_expt/utils/test_common.py (around line 118), and no new imports or additional methods are needed.

Suggested changeset 1
source/tests/pt_expt/utils/test_common.py

Autofix patch

Autofix patch
Run the following command in your local git repository to apply this patch
cat << 'EOF' | git apply
diff --git a/source/tests/pt_expt/utils/test_common.py b/source/tests/pt_expt/utils/test_common.py
--- a/source/tests/pt_expt/utils/test_common.py
+++ b/source/tests/pt_expt/utils/test_common.py
@@ -115,7 +115,7 @@
     """Test that torch_module works even when there's no call_lower() method."""
 
     class MockNativeOP(NativeOP):
-        def call(self, x: np.ndarray) -> np.ndarray:
+        def call(self, x: np.ndarray, *args, **kwargs) -> np.ndarray:
             return x * 2
 
     @torch_module
EOF

Copilot AI left a comment


Pull request overview

This PR removes repeated PyTorch boilerplate across pt_expt by enhancing the @torch_module decorator to auto-generate forward()/forward_lower() delegation methods from call()/call_lower() when they are not explicitly defined.

Changes:

  • Update deepmd.pt_expt.common.torch_module to auto-generate forward and forward_lower delegation methods.
  • Remove redundant forward() implementations from multiple descriptors/fittings/utils now covered by the decorator.
  • Add unit tests validating auto-generation and “do not override explicit method” behavior.

Reviewed changes

Copilot reviewed 10 out of 10 changed files in this pull request and generated 3 comments.

File — Description
deepmd/pt_expt/common.py — Adds auto-generation of forward()/forward_lower() in the torch_module decorator.
deepmd/pt_expt/utils/network.py — Removes NativeNet.forward() boilerplate now provided by @torch_module.
deepmd/pt_expt/fitting/invar_fitting.py — Removes manual forward() delegation; relies on decorator generation.
deepmd/pt_expt/fitting/ener_fitting.py — Removes manual forward() delegation; relies on decorator generation.
deepmd/pt_expt/descriptor/se_e2_a.py — Removes manual forward() delegation; relies on decorator generation.
deepmd/pt_expt/descriptor/se_r.py — Removes manual forward() delegation; relies on decorator generation.
deepmd/pt_expt/descriptor/se_t.py — Removes manual forward() delegation; relies on decorator generation.
deepmd/pt_expt/descriptor/se_t_tebd.py — Removes manual forward() delegation; relies on decorator generation.
deepmd/pt_expt/descriptor/se_t_tebd_block.py — Removes manual forward() delegation; relies on decorator generation.
source/tests/pt_expt/utils/test_common.py — Adds tests for auto-generated delegation methods and override-respecting behavior.



codecov bot commented Feb 16, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 82.15%. Comparing base (c337dea) to head (6703f13).
⚠️ Report is 3 commits behind head on master.

Additional details and impacted files
@@            Coverage Diff             @@
##           master    #5246      +/-   ##
==========================================
- Coverage   82.16%   82.15%   -0.01%     
==========================================
  Files         740      740              
  Lines       74420    74408      -12     
  Branches     3615     3615              
==========================================
- Hits        61144    61133      -11     
+ Misses      12114    12112       -2     
- Partials     1162     1163       +1     



njzjz commented Feb 16, 2026

@copilot apply changes based on the comments in this thread

- Enhanced torch_module decorator docstring to document auto-generation behavior
- Updated tests to use module(...) instead of module.forward(...) to test the full
  invocation path through torch.nn.Module.__call__
- This ensures tests cover hooks, tracing, and export behavior properly

Addresses review feedback from PR #5246

Co-authored-by: njzjz <9496702+njzjz@users.noreply.github.com>
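The rationale for invoking module(...) rather than module.forward(...) in the tests can be sketched without torch; `ModuleLike` below is a hypothetical stand-in for torch.nn.Module's `__call__` dispatch, not the real class:

```python
# Torch-free sketch: instance(x) dispatches through type(instance).__call__,
# which is where torch.nn.Module runs registered hooks (and tracing/export
# machinery) before delegating to forward().
class ModuleLike:
    def __init__(self):
        self.hook_ran = False

    def __call__(self, *args, **kwargs):
        self.hook_ran = True  # stand-in for pre-forward hooks
        return self.forward(*args, **kwargs)

    def forward(self, x):
        return x * 2
```

Calling `m.forward(3)` bypasses `__call__` entirely, so the hook never fires; `m(3)` exercises the full invocation path, which is why the tests were updated to use `module(...)`.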

Copilot AI commented Feb 16, 2026

@copilot apply changes based on the comments in this thread

Applied all changes from the review:

  1. Enhanced torch_module docstring to document auto-generation behavior
  2. Updated tests to use module(...) instead of module.forward(...) to test the full invocation path

All changes committed in 6703f13.

@wanghan-iapcm wanghan-iapcm added this pull request to the merge queue Feb 18, 2026
Merged via the queue into master with commit 4ddc37d Feb 18, 2026
69 checks passed
@wanghan-iapcm wanghan-iapcm deleted the copilot/auto-generate-forward-methods branch February 18, 2026 15:01


Development

Successfully merging this pull request may close these issues.

pt_expt: auto-generate forward/forward_lower in torch_module decorator to avoid boilerplate
