feat(pt_expt): auto-generate forward/forward_lower in torch_module decorator (#5246)
Conversation
Modified `torch_module` decorator to automatically generate `forward()` and `forward_lower()` methods that delegate to `call()` and `call_lower()` if they exist and aren't explicitly defined. Removed boilerplate forward methods from all descriptor and fitting classes. Co-authored-by: njzjz <9496702+njzjz@users.noreply.github.com>
Added comprehensive tests to verify torch_module decorator correctly auto-generates forward() and forward_lower() methods, respects explicit overrides, and handles edge cases properly. Co-authored-by: njzjz <9496702+njzjz@users.noreply.github.com>
No, the model and atomic model classes don't have the same boilerplate pattern. After investigating: the descriptor/fitting classes had pure boilerplate (simple pass-through), while the model classes have intentional logic that shouldn't be auto-generated.
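To make the mechanism concrete, here is a minimal sketch of a decorator that auto-generates `forward()`/`forward_lower()` delegating to `call()`/`call_lower()`. This is an illustrative assumption about the approach, not the actual implementation in `deepmd.pt_expt.common`, and the mock classes are hypothetical:

```python
def torch_module(cls):
    """Sketch: add forward()/forward_lower() delegating to call()/call_lower(),
    but only when the class does not define them explicitly."""
    if "forward" not in cls.__dict__ and hasattr(cls, "call"):
        def forward(self, *args, **kwargs):
            return self.call(*args, **kwargs)
        cls.forward = forward
    if "forward_lower" not in cls.__dict__ and hasattr(cls, "call_lower"):
        def forward_lower(self, *args, **kwargs):
            return self.call_lower(*args, **kwargs)
        cls.forward_lower = forward_lower
    return cls


@torch_module
class Doubler:
    def call(self, x):
        return x * 2


@torch_module
class Adder:
    def call(self, x):
        return x

    def call_lower(self, x, y):
        return x + y
```

With this sketch, `Doubler` gains a generated `forward` but no `forward_lower` (it has no `call_lower`), while `Adder` gains both; the real decorator would additionally subclass `torch.nn.Module`.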
```python
        """Test that torch_module auto-generates forward() from call()."""

        class MockNativeOP(NativeOP):
            def call(self, x: np.ndarray) -> np.ndarray:
```
**Check warning — Code scanning / CodeQL:** Signature mismatch in overriding method (Warning, test)

**Copilot Autofix** (AI, 13 days ago)
In general, to fix a signature mismatch in an overriding method, the overriding method should accept at least the same arguments as the base method. For a Python base method that allows arbitrary positional and keyword arguments (e.g. `def call(self, *args, **kwargs)`), the override should also accept `*args` and `**kwargs`, even if it only uses some of them. This preserves the Liskov Substitution Principle and avoids a runtime `TypeError` when code written against the base type passes extra parameters.

In this file, each inner `MockNativeOP` class defines `call(self, x: np.ndarray) -> np.ndarray`. To align with `NativeOP.call`, we should change these `call` definitions to accept flexible arguments, e.g.:

```python
def call(self, x: np.ndarray, *args: object, **kwargs: object) -> np.ndarray:
    ...
```

This keeps the existing behavior for the tests (they still pass a single `x` argument) while allowing additional positional/keyword arguments without error. The test implementations themselves do not use extra arguments, so they can safely ignore `*args` and `**kwargs`. We only modify the shown `call` methods inside `source/tests/pt_expt/utils/test_common.py`; no other imports or helpers are needed.
Concretely:

- In `test_torch_module_auto_generates_forward`, update `MockNativeOP.call` on line 34.
- In `test_torch_module_auto_generates_forward_lower`, update `MockNativeOP.call` on line 52.
- In `test_torch_module_respects_explicit_forward`, update `MockNativeOP.call` on line 74.
- In `test_torch_module_handles_no_call_lower_method`, update `MockNativeOP.call` on line 118.

No change is needed to `call_lower` or other methods, because the alert is specifically about `NativeOP.call`.
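The substitutability concern behind the warning can be demonstrated with a small, self-contained example (unrelated to the deepmd classes; all names here are hypothetical):

```python
class Base:
    # The base method allows arbitrary extra positional/keyword arguments.
    def call(self, x, *args, **kwargs):
        raise NotImplementedError


class NarrowOverride(Base):
    # Incompatible: rejects extra arguments the base permits.
    def call(self, x):
        return x * 2


class WideOverride(Base):
    # Compatible: accepts and simply ignores extra arguments.
    def call(self, x, *args, **kwargs):
        return x * 2


# Code written against Base may legally pass extras; the narrow
# override then fails at runtime while the wide one works.
try:
    NarrowOverride().call(3, mapping=None)
    narrow_raises = False
except TypeError:
    narrow_raises = True

wide_result = WideOverride().call(3, mapping=None)
```

This is exactly the failure mode the `*args`/`**kwargs` widening prevents.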
```diff
@@ -31,7 +31,7 @@
         """Test that torch_module auto-generates forward() from call()."""

         class MockNativeOP(NativeOP):
-            def call(self, x: np.ndarray) -> np.ndarray:
+            def call(self, x: np.ndarray, *args: object, **kwargs: object) -> np.ndarray:
                 return x * 2

         @torch_module
@@ -49,7 +49,7 @@
         """Test that torch_module auto-generates forward_lower() from call_lower()."""

         class MockNativeOP(NativeOP):
-            def call(self, x: np.ndarray) -> np.ndarray:
+            def call(self, x: np.ndarray, *args: object, **kwargs: object) -> np.ndarray:
                 return x

             def call_lower(self, x: np.ndarray, y: np.ndarray) -> np.ndarray:
@@ -71,7 +71,7 @@
         """Test that torch_module doesn't override an explicitly defined forward()."""

         class MockNativeOP(NativeOP):
-            def call(self, x: np.ndarray) -> np.ndarray:
+            def call(self, x: np.ndarray, *args: object, **kwargs: object) -> np.ndarray:
                 return x * 2

         @torch_module
@@ -115,7 +115,7 @@
         """Test that torch_module works even when there's no call_lower() method."""

         class MockNativeOP(NativeOP):
-            def call(self, x: np.ndarray) -> np.ndarray:
+            def call(self, x: np.ndarray, *args: object, **kwargs: object) -> np.ndarray:
                 return x * 2

         @torch_module
```
```python
        """Test that torch_module auto-generates forward_lower() from call_lower()."""

        class MockNativeOP(NativeOP):
            def call(self, x: np.ndarray) -> np.ndarray:
```
**Check warning — Code scanning / CodeQL:** Signature mismatch in overriding method (Warning, test)

**Copilot Autofix** (AI, 13 days ago)
In general, to fix a signature mismatch between an overriding method and its base method, the overriding method should accept the same parameters as the base; if the base accepts variadic positional/keyword arguments, the override should also accept `*args` and `**kwargs` (and possibly default values) so that any call valid for the base is also valid for the subclass.

Here, CodeQL indicates that `NativeOP.call` can be called with arbitrary keyword arguments and possibly varying numbers of positional arguments, while the overrides in the inner `MockNativeOP` classes only accept a single `x` positional argument. To fix this without changing observed behavior in the tests, we can change these `call` (and `call_lower`) methods to accept `*args, **kwargs` and then interpret the arguments in the same way as before. For the first test (`test_torch_module_auto_generates_forward`), the behavior depends only on a single tensor/array, so the new implementation can treat `args[0]` as `x` and ignore other arguments. For the second test (`test_torch_module_auto_generates_forward_lower`), `call` similarly only needs the first positional argument, while `call_lower` needs two positional arguments; we can obtain them as `x, y = args[0], args[1]`. This keeps the semantics of the tests identical while making the override signatures compatible with the base class.

Concretely:

- In `test_torch_module_auto_generates_forward`, update `MockNativeOP.call` to `def call(self, *args, **kwargs)` and, inside, set `x = args[0]` and return `x * 2`.
- In `test_torch_module_auto_generates_forward_lower`, update `MockNativeOP.call` to the same variadic form, returning `x` unchanged, and `MockNativeOP.call_lower` to a variadic form that extracts `x` and `y` from `args` and returns `x + y`.

No imports or additional definitions are needed; we only adjust these method signatures and bodies.
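In isolation, the fully variadic pattern this suggestion describes behaves as follows (a standalone sketch with a hypothetical class name, not the test file itself):

```python
class MockOP:
    # Fully variadic signatures: compatible with any base that accepts
    # *args/**kwargs; positional inputs are recovered from args.
    def call(self, *args, **kwargs):
        x = args[0]  # first positional argument is the input
        return x

    def call_lower(self, *args, **kwargs):
        x, y = args[0], args[1]  # two positional inputs
        return x + y
```

Extra positional or keyword arguments are silently ignored, so existing single-argument call sites keep working unchanged.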
```diff
@@ -31,7 +31,8 @@
         """Test that torch_module auto-generates forward() from call()."""

         class MockNativeOP(NativeOP):
-            def call(self, x: np.ndarray) -> np.ndarray:
+            def call(self, *args, **kwargs) -> np.ndarray:
+                x = args[0]
                 return x * 2

         @torch_module
@@ -49,10 +50,12 @@
         """Test that torch_module auto-generates forward_lower() from call_lower()."""

         class MockNativeOP(NativeOP):
-            def call(self, x: np.ndarray) -> np.ndarray:
+            def call(self, *args, **kwargs) -> np.ndarray:
+                x = args[0]
                 return x

-            def call_lower(self, x: np.ndarray, y: np.ndarray) -> np.ndarray:
+            def call_lower(self, *args, **kwargs) -> np.ndarray:
+                x, y = args[0], args[1]
                 return x + y

         @torch_module
```
```python
        """Test that torch_module doesn't override an explicitly defined forward()."""

        class MockNativeOP(NativeOP):
            def call(self, x: np.ndarray) -> np.ndarray:
```
**Check warning — Code scanning / CodeQL:** Signature mismatch in overriding method (Warning, test)

**Copilot Autofix** (AI, 13 days ago)
In general, to fix a signature mismatch between an overriding method and its base method, the overriding method should accept parameters that are compatible with the base method's call sites: it should not require more arguments than the base, and it should accept any additional positional or keyword arguments that callers are allowed to pass to the base method (typically by using `*args` and `**kwargs`).

Here, `NativeOP.call` can be called with one or many positional arguments and arbitrary keyword arguments, but the test's `MockNativeOP.call` is defined as `def call(self, x: np.ndarray) -> np.ndarray:`. To retain the existing test behavior while making the override compatible, we can expand the signature to accept `*args` and `**kwargs`, then internally interpret the arguments as before. Since the tests only use a single argument, we can treat the first positional argument as `x` and ignore any extra arguments. To keep things simple and non-invasive, we don't need to add any new imports or change how the tests invoke `call`; we only need to update the `call` definitions inside the mock classes in this file.

Concretely:

- Update each `MockNativeOP.call(self, x: np.ndarray) -> np.ndarray:` definition to something like `def call(self, x: np.ndarray, *args, **kwargs) -> np.ndarray:`. This preserves the existing explicit parameter `x` (used by the test) and adds variadic parameters to match the base class's flexibility.
- Leave the method bodies unchanged, still operating on `x` as before.
- No changes are needed for `call_lower`, since the alert concerns `call`, and we have no evidence of its base signature being more flexible; we should not guess or alter it without necessity.

This single pattern change in each relevant `MockNativeOP` class definition in `source/tests/pt_expt/utils/test_common.py` will address all three variants of the alert without modifying test behavior.
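The milder pattern described here, keeping the explicit `x` parameter while widening only the tail of the signature, can be checked in isolation (hypothetical class name, not the test file itself):

```python
class WidenedMockOP:
    # Explicit x stays (so the body and call sites are unchanged);
    # *args/**kwargs absorb any extras a flexible base would allow.
    def call(self, x, *args, **kwargs):
        return x * 2
```

Compared to the fully variadic form, this keeps the type annotation and readability of the primary parameter while still tolerating extra arguments.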
```diff
@@ -31,7 +31,7 @@
         """Test that torch_module auto-generates forward() from call()."""

         class MockNativeOP(NativeOP):
-            def call(self, x: np.ndarray) -> np.ndarray:
+            def call(self, x: np.ndarray, *args, **kwargs) -> np.ndarray:
                 return x * 2

         @torch_module
@@ -49,7 +49,7 @@
         """Test that torch_module auto-generates forward_lower() from call_lower()."""

         class MockNativeOP(NativeOP):
-            def call(self, x: np.ndarray) -> np.ndarray:
+            def call(self, x: np.ndarray, *args, **kwargs) -> np.ndarray:
                 return x

             def call_lower(self, x: np.ndarray, y: np.ndarray) -> np.ndarray:
@@ -71,7 +71,7 @@
         """Test that torch_module doesn't override an explicitly defined forward()."""

         class MockNativeOP(NativeOP):
-            def call(self, x: np.ndarray) -> np.ndarray:
+            def call(self, x: np.ndarray, *args, **kwargs) -> np.ndarray:
                 return x * 2

         @torch_module
```
```python
        """Test that torch_module doesn't override an explicitly defined forward_lower()."""

        class MockNativeOP(NativeOP):
            def call(self, x: np.ndarray) -> np.ndarray:
```
**Check warning — Code scanning / CodeQL:** Signature mismatch in overriding method (Warning, test)

**Copilot Autofix** (AI, 13 days ago)
In general, to fix this kind of issue you must ensure that the overriding method in the subclass has a signature compatible with the base class: it should accept at least the same kinds of arguments (including arbitrary positional/keyword arguments if the base does), even if it ignores some of them.

Here, `NativeOP.call` accepts arbitrary keyword arguments (and possibly a flexible number of positional arguments), while the overrides in the nested `MockNativeOP` classes only accept a single positional parameter `x`. To align with the base class without changing existing behavior, we can widen the override signatures to accept additional `*args` and `**kwargs`, while continuing to use only `x` inside the method bodies. That way, all existing calls remain valid, new calls that pass extra arguments won't fail, and test semantics are unchanged.

Concretely:

- In `test_torch_module_auto_generates_forward_lower`, update the nested `MockNativeOP` class so that `call` becomes `def call(self, x: np.ndarray, *args, **kwargs) -> np.ndarray:`.
- In `test_torch_module_respects_explicit_forward`, update its nested `MockNativeOP` similarly.
- In `test_torch_module_respects_explicit_forward_lower`, update its nested `MockNativeOP` likewise.

No new imports or helper methods are needed; we only adjust the function signatures. The method bodies remain the same and ignore `*args`/`**kwargs`.
```diff
@@ -49,7 +49,7 @@
         """Test that torch_module auto-generates forward_lower() from call_lower()."""

         class MockNativeOP(NativeOP):
-            def call(self, x: np.ndarray) -> np.ndarray:
+            def call(self, x: np.ndarray, *args, **kwargs) -> np.ndarray:
                 return x

             def call_lower(self, x: np.ndarray, y: np.ndarray) -> np.ndarray:
@@ -71,7 +71,7 @@
         """Test that torch_module doesn't override an explicitly defined forward()."""

         class MockNativeOP(NativeOP):
-            def call(self, x: np.ndarray) -> np.ndarray:
+            def call(self, x: np.ndarray, *args, **kwargs) -> np.ndarray:
                 return x * 2

         @torch_module
@@ -91,7 +91,7 @@
         """Test that torch_module doesn't override an explicitly defined forward_lower()."""

         class MockNativeOP(NativeOP):
-            def call(self, x: np.ndarray) -> np.ndarray:
+            def call(self, x: np.ndarray, *args, **kwargs) -> np.ndarray:
                 return x

             def call_lower(self, x: np.ndarray, y: np.ndarray) -> np.ndarray:
```
```python
        """Test that torch_module works even when there's no call_lower() method."""

        class MockNativeOP(NativeOP):
            def call(self, x: np.ndarray) -> np.ndarray:
```
**Check warning — Code scanning / CodeQL:** Signature mismatch in overriding method (Warning, test)

**Copilot Autofix** (AI, 13 days ago)
In general, to fix a signature mismatch in an overriding method, ensure that the overriding method's parameters are compatible with the base method: it must accept at least the same positional parameters and be able to handle any additional optional or variadic `*args`/`**kwargs` the base allows. You can still ignore extra arguments in the implementation if they are not needed.

Here, `NativeOP.call` (the base) accepts one or more positional arguments and arbitrary keyword arguments. The override in `test_torch_module_handles_no_call_lower_method` currently has a narrow signature: `def call(self, x: np.ndarray) -> np.ndarray`. To match the base while preserving existing behavior, we should expand the override signature to accept `*args` and `**kwargs`, implemented in a way that keeps the current semantics (multiplying the first argument by 2). Because this is a small inner test class, the minimal change is:

- Change `def call(self, x: np.ndarray) -> np.ndarray:` to `def call(self, x: np.ndarray, *args, **kwargs) -> np.ndarray:`.
- Leave the body as `return x * 2`, ignoring any extra arguments.

This preserves the test's current behavior (the generated forward will still be based on a single main argument) while making the override signature compatible with `NativeOP.call`. The change is localized to the snippet shown in `source/tests/pt_expt/utils/test_common.py` (around line 118), and no new imports or additional methods are needed.
```diff
@@ -115,7 +115,7 @@
         """Test that torch_module works even when there's no call_lower() method."""

         class MockNativeOP(NativeOP):
-            def call(self, x: np.ndarray) -> np.ndarray:
+            def call(self, x: np.ndarray, *args, **kwargs) -> np.ndarray:
                 return x * 2

         @torch_module
```
Pull request overview
This PR removes repeated PyTorch boilerplate across pt_expt by enhancing the @torch_module decorator to auto-generate forward()/forward_lower() delegation methods from call()/call_lower() when they are not explicitly defined.
Changes:

- Update `deepmd.pt_expt.common.torch_module` to auto-generate `forward` and `forward_lower` delegation methods.
- Remove redundant `forward()` implementations from multiple descriptors/fittings/utils now covered by the decorator.
- Add unit tests validating auto-generation and "do not override explicit method" behavior.
Reviewed changes
Copilot reviewed 10 out of 10 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| deepmd/pt_expt/common.py | Adds auto-generation of forward()/forward_lower() in the torch_module decorator. |
| deepmd/pt_expt/utils/network.py | Removes NativeNet.forward() boilerplate now provided by @torch_module. |
| deepmd/pt_expt/fitting/invar_fitting.py | Removes manual forward() delegation; relies on decorator generation. |
| deepmd/pt_expt/fitting/ener_fitting.py | Removes manual forward() delegation; relies on decorator generation. |
| deepmd/pt_expt/descriptor/se_e2_a.py | Removes manual forward() delegation; relies on decorator generation. |
| deepmd/pt_expt/descriptor/se_r.py | Removes manual forward() delegation; relies on decorator generation. |
| deepmd/pt_expt/descriptor/se_t.py | Removes manual forward() delegation; relies on decorator generation. |
| deepmd/pt_expt/descriptor/se_t_tebd.py | Removes manual forward() delegation; relies on decorator generation. |
| deepmd/pt_expt/descriptor/se_t_tebd_block.py | Removes manual forward() delegation; relies on decorator generation. |
| source/tests/pt_expt/utils/test_common.py | Adds tests for auto-generated delegation methods and override-respecting behavior. |
**Codecov Report**

✅ All modified and coverable lines are covered by tests.

```
@@            Coverage Diff             @@
##           master    #5246      +/-   ##
==========================================
- Coverage   82.16%   82.15%   -0.01%
==========================================
  Files         740      740
  Lines       74420    74408      -12
  Branches     3615     3615
==========================================
- Hits        61144    61133      -11
+ Misses      12114    12112       -2
- Partials     1162     1163       +1
```

☔ View full report in Codecov by Sentry.
@copilot apply changes based on the comments in this thread
- Enhanced the `torch_module` decorator docstring to document auto-generation behavior
- Updated tests to use `module(...)` instead of `module.forward(...)` to test the full invocation path through `torch.nn.Module.__call__`
- This ensures tests cover hooks, tracing, and export behavior properly

Addresses review feedback from PR #5246

Co-authored-by: njzjz <9496702+njzjz@users.noreply.github.com>
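The reason `module(...)` is preferred over `module.forward(...)` in tests is that `torch.nn.Module.__call__` runs registered hooks (and tracing/export machinery) around `forward`, while calling `forward` directly skips all of that. A pure-Python sketch of the distinction, with a hypothetical `MiniModule` standing in for `torch.nn.Module` (the real `__call__` does considerably more):

```python
class MiniModule:
    """Toy stand-in for torch.nn.Module: __call__ wraps forward with hooks."""

    def __init__(self):
        self.hook_calls = []

    def forward(self, x):
        return x * 2

    def __call__(self, x):
        self.hook_calls.append("pre")   # stand-in for a forward pre-hook
        out = self.forward(x)
        self.hook_calls.append("post")  # stand-in for a forward hook
        return out
```

Calling the instance records both hooks; calling `forward` directly produces the same output but records none, which is why tests that only exercise `forward` miss hook, tracing, and export behavior.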
Applied all changes from the review. All changes committed in 6703f13.
Summary of changes:

- Updated the `torch_module` decorator to auto-generate `forward()` from `call()`
- Updated the `torch_module` decorator to auto-generate `forward_lower()` from `call_lower()`
- Removed boilerplate `forward()` methods from descriptor classes (`DescrptSeA`, `DescrptSeR`, `DescrptSeT`, `DescrptSeTTebd`, `DescrptBlockSeTTebd`)
- Removed boilerplate `forward()` methods from fitting classes (`InvarFitting`, `EnergyFittingNet`)
- Removed the `forward()` method from `NativeNet`

Original prompt