
refact(dpmodel): model output made the same as pt backend#5250

Merged
wanghan-iapcm merged 6 commits into deepmodeling:master from wanghan-iapcm:refact-model-output
Feb 20, 2026

Conversation

@wanghan-iapcm (Collaborator) commented Feb 19, 2026

Summary by CodeRabbit

  • New Features

    • Models expose unified high-level prediction APIs with additional outputs (atom-level, reduced, extended) and conditional mask propagation; spin-energy prediction tests added.
  • Refactor

    • Standardized output key naming and backend return shapes; consolidated backend dispatch and adjusted tensor squeeze/shape behavior for gradients, virials, and Hessians.
  • Tests

    • Updated and reorganized tests to align with new keys, shapes, and backend groupings.

Comment on lines +56 to +64
model_predict["dipole"] = model_ret["dipole"]
model_predict["global_dipole"] = model_ret["dipole_redu"]
if self.do_grad_r("dipole") and model_ret["dipole_derv_r"] is not None:
model_predict["force"] = model_ret["dipole_derv_r"]
if self.do_grad_c("dipole") and model_ret["dipole_derv_c_redu"] is not None:
model_predict["virial"] = model_ret["dipole_derv_c_redu"]
if do_atomic_virial and model_ret["dipole_derv_c"] is not None:
model_predict["atom_virial"] = model_ret["dipole_derv_c"]
if "mask" in model_ret:

Check warning

Code scanning / CodeQL

Signature mismatch in overriding method (Warning)

  • This method does not accept arbitrary keyword arguments, which the overridden NativeOP.call does.
  • This method requires at most 7 positional arguments, whereas the overridden NativeOP.call may be called with arbitrarily many.
  • This method requires at least 3 positional arguments, whereas the overridden NativeOP.call may be called with 1.

(In each case the call correctly calls the base method, but does not match the signature of the overriding method.)
Comment on lines +48 to +56
coord,
atype,
box,
fparam=fparam,
aparam=aparam,
do_atomic_virial=do_atomic_virial,
)
model_predict = {}
model_predict["atom_dos"] = model_ret["dos"]

Check warning: Code scanning / CodeQL reports the same "Signature mismatch in overriding method" findings as on the first snippet above.
Comment on lines +58 to +66
model_predict["atom_energy"] = model_ret["energy"]
model_predict["energy"] = model_ret["energy_redu"]
if self.do_grad_r("energy") and model_ret["energy_derv_r"] is not None:
model_predict["force"] = model_ret["energy_derv_r"].squeeze(-2)
if self.do_grad_c("energy") and model_ret["energy_derv_c_redu"] is not None:
model_predict["virial"] = model_ret["energy_derv_c_redu"].squeeze(-2)
if do_atomic_virial and model_ret["energy_derv_c"] is not None:
model_predict["atom_virial"] = model_ret["energy_derv_c"].squeeze(-2)
if "mask" in model_ret:

Check warning: Code scanning / CodeQL reports the same "Signature mismatch in overriding method" findings as on the first snippet above.
Comment on lines +101 to +109
fparam=fparam,
aparam=aparam,
do_atomic_virial=do_atomic_virial,
)
model_predict = {}
model_predict["atom_energy"] = model_ret["energy"]
model_predict["energy"] = model_ret["energy_redu"]
if self.do_grad_r("energy") and model_ret.get("energy_derv_r") is not None:
model_predict["extended_force"] = model_ret["energy_derv_r"].squeeze(-2)

Check warning: Code scanning / CodeQL reports the same "Signature mismatch in overriding method" findings as on the first snippet above.
Comment on lines +48 to +56
coord,
atype,
box,
fparam=fparam,
aparam=aparam,
do_atomic_virial=do_atomic_virial,
)
model_predict = {}
model_predict["polar"] = model_ret["polarizability"]

Check warning: Code scanning / CodeQL reports the same "Signature mismatch in overriding method" findings as on the first snippet above.
Comment on lines +51 to +59
) -> dict[str, Array]:
model_ret = self.call_common(
coord,
atype,
box,
fparam=fparam,
aparam=aparam,
do_atomic_virial=do_atomic_virial,
)

Check warning: Code scanning / CodeQL reports the same "Signature mismatch in overriding method" findings as on the first snippet above.
@chatgpt-codex-connector bot left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: bea33efe80


@coderabbitai (Contributor) bot commented Feb 19, 2026

📝 Walkthrough

Walkthrough

Adds public high-level call APIs and translated output definitions across many DPModel classes, standardizes backend key resolution in evaluators/serializers, tweaks tensor squeeze axes and mask propagation in PyTorch models, and updates numerous tests to new key/shape conventions.
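The key-translation pattern the walkthrough describes can be sketched with a toy example. All names below (ToyModel, _KEY_MAP, the placeholder values) are hypothetical stand-ins, not the actual deepmd classes; the sketch only illustrates how a public call() delegates to an internal call_common() and renames internal output keys to the public names.

```python
# Hypothetical internal-to-public key mapping, modeled on the names
# mentioned in the walkthrough (energy_derv_r -> force, etc.).
_KEY_MAP = {
    "energy": "atom_energy",
    "energy_redu": "energy",
    "energy_derv_r": "force",
    "energy_derv_c_redu": "virial",
}


class ToyModel:
    def call_common(self, coord):
        # Stand-in for the real backend computation: returns internal keys.
        return {
            "energy": [0.1, 0.2],
            "energy_redu": 0.3,
            "energy_derv_r": [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
            "energy_derv_c_redu": [0.0] * 9,
        }

    def call(self, coord):
        # Public API: delegate, then translate internal keys to public names.
        model_ret = self.call_common(coord)
        return {_KEY_MAP.get(k, k): v for k, v in model_ret.items()}


out = ToyModel().call(None)
assert set(out) == {"atom_energy", "energy", "force", "virial"}
```

A translated_output_def() in the real code plays the analogous role for the output metadata, so evaluators can look outputs up by the public names.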

Changes

Cohort / File(s) / Summary

  • DPModel public API additions
    Files: deepmd/dpmodel/model/ener_model.py, deepmd/dpmodel/model/dipole_model.py, deepmd/dpmodel/model/dos_model.py, deepmd/dpmodel/model/polar_model.py, deepmd/dpmodel/model/property_model.py, deepmd/dpmodel/model/dp_zbl_model.py
    Added typed public methods call, call_lower, and translated_output_def that delegate to call_common/call_common_lower and translate internal output keys to public names (e.g., atom_energy, energy, force, virial).
  • Model factory / spin & routing changes
    Files: deepmd/dpmodel/model/make_model.py, deepmd/dpmodel/model/spin_model.py, deepmd/pt_expt/model/make_model.py, deepmd/pt_expt/model/ener_model.py
    Introduced call_common / call_common_lower as internal implementations and updated aliases/forwarding; spin_model was refactored to wrapper-style call paths and gained translated output handling.
  • Backend evaluation & serialization
    Files: deepmd/dpmodel/infer/deep_eval.py, deepmd/jax/infer/deep_eval.py, deepmd/jax/jax2tf/serialization.py, deepmd/jax/utils/serialization.py
    Changed backend output-key resolution: the dpmodel evaluator now maps DP names via _OUTDEF_DP2BACKEND when building backend keys; JAX/TF export paths were adjusted to reference call_common_lower (and one evaluator switched its odef name usage).
  • PyTorch model shape & mask propagation
    Files: deepmd/pt/model/model/dipole_model.py, deepmd/pt/model/model/dp_zbl_model.py, deepmd/pt/model/model/ener_model.py, deepmd/pt/model/model/dos_model.py, deepmd/pt/model/model/polar_model.py
    Adjusted squeeze axes for gradient/Hessian outputs (removed or changed squeezes) and added conditional propagation of mask from the internal model_ret into model_predict when fitting networks provide it.
  • Test updates: key naming, backend dispatch, call sites
    Files: source/tests/*, source/tests/consistent/model/*, source/tests/jax/*, source/tests/pd/*, source/tests/pt/*, source/tests/pt_expt/*, source/tests/universal/*
    Updated many tests to the new public key names (e.g., atom_energy, energy, force, virial, hessian), removed {var}_redu usages, consolidated backend extract_ret branches (TF positional outputs vs DP/PT/JAX dictionaries), and replaced call/call_lower usages with call_common/call_common_lower where applicable.
  • DP test conventions & padding changes
    Files: source/tests/common/dpmodel/test_dp_model.py, source/tests/common/dpmodel/test_padding_atoms.py, source/tests/jax/test_padding_atoms.py
    Changed reduced-key detection to explicit set membership (energy, virial) and updated assertions to use the atom_<var> and <var> conventions for atomic vs reduced values.
  • New spin-energy tests
    Files: source/tests/consistent/model/test_spin_ener.py
    Added comprehensive DP/PT spin-energy test modules exercising the new call/call_lower and translated-output behaviors for spin-augmented models.
  • Minor wiring & alias updates
    Files: deepmd/jax/infer/deep_eval.py, deepmd/dpmodel/infer/deep_eval.py, deepmd/jax/jax2tf/serialization.py, deepmd/jax/utils/serialization.py, source/tests/universal/dpmodel/backend.py
    Small changes to which internal methods are referenced (e.g., call_lower → call_common_lower) and to odef name mapping behavior in evaluators/serializers; a compatibility alias for forward_lower was added when only call_lower exists.

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Possibly related PRs

Suggested reviewers

  • iProzd
  • njzjz
🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage ⚠️ Warning: docstring coverage is 25.23%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.

✅ Passed checks (2 passed)

  • Description Check ✅ Passed: check skipped because CodeRabbit's high-level summary is enabled.
  • Title Check ✅ Passed: the PR title accurately describes the main refactoring objective, aligning dpmodel output behavior with the PT (PyTorch) backend implementation.


@coderabbitai (Contributor) bot left a comment

Actionable comments posted: 2

🧹 Nitpick comments (5)
source/tests/jax/test_dp_hessian_model.py (1)

109-114: Minor coverage gap: hessian is not re-checked in the do_atomic_virial=True block.

The second invocation (lines 109-110) exercises do_atomic_virial=True, but only asserts "atom_virial". Since enable_hessian() remains active, the hessian output is still produced. Asserting it here would confirm no interaction between do_atomic_virial and hessian computation.

💡 Suggested addition
         np.testing.assert_allclose(
             to_numpy_array(ret0["atom_virial"]),
             to_numpy_array(ret1["atom_virial"]),
             atol=self.atol,
         )
+        np.testing.assert_allclose(
+            to_numpy_array(ret0["hessian"]),
+            to_numpy_array(ret1["hessian"]),
+            atol=self.atol,
+        )
source/tests/jax/test_make_hessian_model.py (2)

169-176: test_output_def has no coverage for the new public "hessian" key.

The test still asserts only model_output_def()["energy_derv_r_derv_r"] (the internal key). Now that call() returns "hessian" as the public name, consider adding an assertion that the translated/public output def also exposes this key (e.g., via translated_output_def()["hessian"]), ensuring the public API contract is verified alongside the internal one.


74-74: nv is now dead code.

After the line 130 reshape no longer references nv, the assignment nv = self.nv at line 74 is unreferenced throughout HessianTest.test(). Consider removing both this assignment and self.nv = 1 in TestDPModel.setUp().

♻️ Proposed cleanup
-        nv = self.nv
-        self.nv = 1
source/tests/consistent/model/test_dpa1.py (1)

232-239: DP backend skips force/virial cross-validation.

Three SKIP_FLAGs at positions [2], [3], [4] mean force, virial, and atom_virial are never compared against the DP backend. If this is intentional (e.g., DP backend doesn't support force output in this context), a brief inline comment would make the intent clear for future contributors.

💡 Suggested comment to document intent
         elif backend is self.RefBackend.DP:
             return (
                 ret["energy"].ravel(),
                 ret["atom_energy"].ravel(),
-                SKIP_FLAG,
-                SKIP_FLAG,
-                SKIP_FLAG,
+                SKIP_FLAG,  # DP backend does not compute force
+                SKIP_FLAG,  # DP backend does not compute virial
+                SKIP_FLAG,  # DP backend does not compute atom_virial
             )
deepmd/dpmodel/model/dp_zbl_model.py (1)

70-102: Near-duplicate of EnergyModel.call_lower() — consider extracting shared logic.

The call(), call_lower(), and translated_output_def() bodies in DPZBLModel are nearly identical to EnergyModel (minus hessian). A shared helper or mixin would reduce the maintenance surface, but this can be deferred.

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@deepmd/dpmodel/model/dipole_model.py`:
- Around line 87-96: call_lower in dipole_model.py builds model_predict but
never adds the "mask" key, causing inconsistency with call and
translated_output_def and breaking lower-level callers; update the call_lower
implementation to check for a mask (e.g., model_ret.get("mask") or similar
presence check used in call) and, when present, set model_predict["mask"] =
model_ret["mask"] (mirror the behavior in polar_model.py's call_lower) so mask
is propagated out of call_lower.

In `@deepmd/jax/infer/deep_eval.py`:
- Line 391: Validate that odef.name exists in the mapping before indexing
self._OUTDEF_DP2BACKEND: in deepmd/jax/infer/deep_eval.py around the line
assigning dp_name = self._OUTDEF_DP2BACKEND[odef.name], replace the direct
lookup with a defensive check (e.g., if odef.name not in
self._OUTDEF_DP2BACKEND: raise a clear ValueError listing the missing name and
available keys, or use self._OUTDEF_DP2BACKEND.get(odef.name, <fallback>) as
appropriate). Alternatively, perform this validation during initialization of
the mapping (where _OUTDEF_DP2BACKEND is populated or extended, see
deepmd/infer/deep_property.py registration logic) to ensure all model output
definitions are registered before inference; include the symbol
_OUTDEF_DP2BACKEND and odef.name in the error message to aid debugging.
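A minimal sketch of the defensive lookup this comment suggests. The function name and example mapping contents are hypothetical; only the idea (a descriptive error instead of a bare KeyError) comes from the comment.

```python
def resolve_backend_name(outdef_map: dict, name: str) -> str:
    # Fail loudly and descriptively when an output definition has no
    # registered backend counterpart, instead of raising a bare KeyError.
    if name not in outdef_map:
        raise ValueError(
            f"output '{name}' is not registered in the DP-to-backend "
            f"mapping; known outputs: {sorted(outdef_map)}"
        )
    return outdef_map[name]


# Example: a present key resolves, a missing key names itself in the error.
assert resolve_backend_name({"energy": "energy_redu"}, "energy") == "energy_redu"
```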

---

Duplicate comments:
In `@deepmd/dpmodel/infer/deep_eval.py`:
- Line 361: The lookup self._OUTDEF_DP2BACKEND[odef.name] can raise
AttributeError/KeyError; update DeepEvalBackend to ensure the mapping attribute
exists and is exhaustive for all names in request_defs: either add a class-level
_OUTDEF_DP2BACKEND with all expected keys or validate at runtime before use
(e.g., check hasattr(self, "_OUTDEF_DP2BACKEND") and that odef.name in
self._OUTDEF_DP2BACKEND) and raise a clear, descriptive error if missing;
reference the symbols _OUTDEF_DP2BACKEND, odef.name, DeepEvalBackend and
request_defs when adding the verification or filling missing entries.

In `@deepmd/dpmodel/model/dp_zbl_model.py`:
- Around line 104-120: In translated_output_def the calls like
output_def["force"].squeeze(-2) on OutputVariableDef objects are no-ops because
the squeezed result isn't used; change them to assign the squeezed value back
(e.g., output_def["force"] = output_def["force"].squeeze(-2)) for each
occurrence (force, virial, atom_virial) so the squeezed OutputVariableDef
replaces the original; update the blocks that call self.do_grad_r("energy") and
self.do_grad_c("energy") to perform these assignments.

In `@deepmd/pt/model/model/dp_zbl_model.py`:
- Around line 44-52: The output_def is being mutated in-place by
OutputVariableDef.squeeze() for keys "force", "virial", and "atom_virial", which
can corrupt the shared cached ModelOutputDef; fix this by deepcopying the
assigned OutputVariableDef before calling squeeze (i.e., when setting
output_def["force"] = out_def_data["energy_derv_r"], output_def["virial"] =
out_def_data["energy_derv_c_redu"], and output_def["atom_virial"] =
out_def_data["energy_derv_c"]) so you call squeeze on the copy; mirror the same
deepcopy approach used in EnergyModel.translated_output_def to avoid
shared-state mutation.


@wanghan-iapcm wanghan-iapcm requested a review from iProzd February 19, 2026 15:51
@coderabbitai (Contributor) bot left a comment

Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (3)
deepmd/pt/model/model/polar_model.py (1)

73-75: ⚠️ Potential issue | 🟡 Minor

Remove the unreachable else branch in forward and forward_lower (lines 73–75 and 107–108).

The else branch at line 73–75 is unreachable dead code. DPPolarAtomicModel.__init__ enforces that the fitting parameter must be a PolarFittingNet instance (raises TypeError otherwise), so get_fitting_net() can never return None. The same dead code exists in ener_model.py, dos_model.py, and dipole_model.py. Additionally, the two branches are asymmetric: forward line 75 attempts to mutate updated_coord (which doesn't belong in a polar model), whereas forward_lower line 108 correctly does not. Remove the entire else block, or if conditional logic is needed for safety, align both methods.

deepmd/pt/model/model/dipole_model.py (1)

72-87: ⚠️ Potential issue | 🔴 Critical

Add .squeeze(-2) to force, virial, and atom_virial outputs in DipoleModel.forward().

DipoleModel is missing the squeeze operations applied in all other derivative models (EnergyModel, DPZBLModel, DPLinearModel). Lines 77, 79, and 81 should squeeze the last dimension to match output shapes from other models:

Expected pattern (from EnergyModel)
if self.do_grad_r("dipole"):
    model_predict["force"] = model_ret["dipole_derv_r"].squeeze(-2)
if self.do_grad_c("dipole"):
    model_predict["virial"] = model_ret["dipole_derv_c_redu"].squeeze(-2)
    if do_atomic_virial:
        model_predict["atom_virial"] = model_ret["dipole_derv_c"].squeeze(-2)

Without this fix, DipoleModel outputs will have shape (N, 1, 3) and (N, 1, 9) instead of (N, 3) and (N, 9), causing downstream shape mismatches in tests and evaluators.
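The shape effect can be checked with a small NumPy sketch (the array sizes below are illustrative, following the shapes named in the comment):

```python
import numpy as np

# Internal derivative tensors carry a length-1 axis for the output
# dimension of the reduced quantity, e.g. (natoms, 1, 3) for forces.
derv_r = np.zeros((4, 1, 3))

# Squeezing axis -2 drops that length-1 axis, yielding the public
# (natoms, 3) shape the other derivative models return.
force = derv_r.squeeze(-2)

assert derv_r.shape == (4, 1, 3)
assert force.shape == (4, 3)
```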

source/tests/universal/dpmodel/model/test_model.py (1)

274-274: ⚠️ Potential issue | 🟠 Major

dpmodel's TestSpinEnergyModelDP uses raw model_output_def() while PT's equivalent uses translated keys, missing parity.

Line 167 (TestEnergyModelDP) uses cls.module.translated_output_def(), which returns human-friendly keys. Line 274 (TestSpinEnergyModelDP) still calls cls.module.model_output_def().get_data(), returning raw internal keys. PT's TestSpinEnergyModelDP uses SpinEnergyModel with translated_output_def(), returning spin-specific keys (atom_energy, energy, mask_mag, force). dpmodel's SpinModel lacks a translated_output_def() override—it would need one that maps spin-specific output keys (like PT's SpinEnergyModel.translated_output_def()), or dpmodel needs a new SpinEnergyModel subclass mirroring PT's design.

🧹 Nitpick comments (3)
deepmd/dpmodel/model/dipole_model.py (1)

58-62: Improve consistency: use direct key access or .get() uniformly across call and call_lower.

Lines 58, 60, 62 in call use direct dict access (model_ret["dipole_derv_r"], etc.), while lines 90, 92, 94 in call_lower use .get() for the same keys. Both approaches are safe since call_common and call_common_lower always populate these keys (even as None via fit_output_to_model_output and communicate_extended_output), but the inconsistency should be reconciled for clarity.

deepmd/dpmodel/model/make_model.py (1)

368-369: Consider explicit forwarding methods to prevent potential MRO bypass in future subclasses.

Lines 368-369 use static class-body aliasing (call = call_common) that binds to the function object at definition time. If a future subclass were to override call_common without also overriding call, dispatch would silently use the base-class version instead of the override. While the current concrete models (ener_model, property_model, dipole_model, dp_zbl_model, dos_model) all explicitly redefine both call and call_lower—avoiding active breakage—the pattern is a maintenance trap for future subclasses.

Replace with explicit forwarding methods to preserve dynamic dispatch through self:

♻️ Suggested refactor
-        call = call_common
-        call_lower = call_common_lower
+        def call(self, *args: Any, **kwargs: Any) -> dict[str, Array]:
+            """Alias for call_common; may be overridden in subclasses."""
+            return self.call_common(*args, **kwargs)
+
+        def call_lower(self, *args: Any, **kwargs: Any) -> dict[str, Array]:
+            """Alias for call_common_lower; may be overridden in subclasses."""
+            return self.call_common_lower(*args, **kwargs)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@deepmd/dpmodel/model/make_model.py` around lines 368 - 369, The class
currently aliases call = call_common and call_lower = call_common_lower which
binds static function objects and can bypass future overrides; replace those
aliases with explicit forwarding instance methods (e.g., def call(self, *args,
**kwargs): return self.call_common(*args, **kwargs) and similarly for
call_lower) so dispatch goes through self and respects MRO and overrides of
call_common/call_common_lower; ensure the forwarding method signatures match the
original call_common/call_common_lower signatures and preserve
docstrings/annotations if present.
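The MRO pitfall described above can be demonstrated with a minimal, self-contained example. The class names below are hypothetical stand-ins, not the actual deepmd classes:

```python
# Class-body aliasing binds the function OBJECT at definition time:
class Base:
    def call_common(self):
        return "base"
    call = call_common  # alias frozen to Base.call_common

class Child(Base):
    def call_common(self):  # override is silently bypassed by Base.call
        return "child"

# Explicit forwarding dispatches dynamically through self:
class Forwarding:
    def call_common(self):
        return "base"
    def call(self):
        return self.call_common()

class ForwardingChild(Forwarding):
    def call_common(self):
        return "child"

print(Child().call())            # base  -- the override is ignored
print(ForwardingChild().call())  # child -- the override is respected
```

This is exactly why the comment calls the static alias a maintenance trap: the breakage only appears once a subclass overrides `call_common` without also re-aliasing `call`.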
source/tests/consistent/model/test_ener.py (1)

280-287: DP backend now skips force/virial comparison — ensure this is intentional and documented.

The DP backend returns SKIP_FLAG for force, virial, and atom_virial in both TestEner (lines 284–286) and TestEnerLower (lines 522–524). This means these quantities are never cross-validated against DP.

If the dpmodel backend is capable of computing force/virial via its call path (e.g., through finite-difference or analytical gradient), skipping the comparison could mask regressions. If it genuinely cannot produce these, this is correct. The PR title suggests this is intentional alignment with the PT backend's output contract.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@source/tests/consistent/model/test_ener.py` around lines 280 - 287, The DP
backend branch in TestEner and TestEnerLower returns SKIP_FLAG for force,
virial, and atom_virial (see RefBackend.DP, ret["energy"], ret["atom_energy"])
which prevents cross-validation; confirm intent and either restore comparisons
or document/guard the skip: if dpmodel can compute forces/virial via its call
path (finite-difference or analytic gradients) remove the SKIP_FLAG returns and
return the actual arrays so tests validate them, otherwise add a clear inline
comment and/or a test-level assertion explaining that DP does not provide
forces/virial (and reference the dpmodel call path) so future reviewers know
this is intentional rather than a regression.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@source/tests/universal/dpmodel/model/test_model.py`:
- Line 167: The dpmodel refactor missed implementing translated_output_def on
SpinModel and updating the test to use it; implement a
translated_output_def(self) method on the dpmodel SpinModel class (mirror
EnergyModel.translated_output_def and the PT backend SpinModel implementation)
so it returns the same translated structure currently produced by
model_output_def().get_data() (or delegates to model_output_def() then converts
its data to the translated shape), and change the test to call
module.translated_output_def() (instead of model_output_def().get_data()) so
both EnergyModel and SpinModel use the unified translated_output_def API; ensure
the method signature and return format match the PT backend's
translated_output_def.

---

Outside diff comments:
In `@deepmd/pt/model/model/dipole_model.py`:
- Around line 72-87: In DipoleModel.forward(), the tensor dims for force, virial
and atom_virial are not squeezed and thus keep an extra length-1 axis; update
the assignments so that model_predict["force"] uses
model_ret["dipole_derv_r"].squeeze(-2), model_predict["virial"] uses
model_ret["dipole_derv_c_redu"].squeeze(-2), and model_predict["atom_virial"]
uses model_ret["dipole_derv_c"].squeeze(-2) (only when do_atomic_virial is true)
so shapes match the other models (see EnergyModel/DPZBLModel pattern).
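The `squeeze(-2)` shape fix described above can be illustrated with NumPy in place of torch; the shapes here are toy values, not taken from the real model:

```python
import numpy as np

# A derivative tensor carrying an extra length-1 axis, e.g. (nframes, nall, 1, 3).
nframes, nall = 2, 5
derv_r = np.zeros((nframes, nall, 1, 3))

# squeeze(-2) removes the length-1 axis so the force matches (nframes, nall, 3),
# the shape convention used by the other models.
force = np.squeeze(derv_r, axis=-2)
print(force.shape)  # (2, 5, 3)
```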

In `@deepmd/pt/model/model/polar_model.py`:
- Around line 73-75: The else branch in DPPolarAtomicModel.forward and
forward_lower is dead/unreachable because DPPolarAtomicModel.__init__ enforces
get_fitting_net() returns a PolarFittingNet (TypeError otherwise); remove the
entire else block that mutates model_predict["updated_coord"] (the
model_ret-to-model_predict fallback) so both forward and forward_lower only use
the primary model_ret path, or if you prefer defensive coding keep a single
guarded check that does not mutate fields inappropriate for polar models (do not
add updated_coord). Update references in forward, forward_lower, and any similar
patterns in ener_model.py, dos_model.py, and dipole_model.py to remove the
unreachable else branch and ensure consistency.

In `@source/tests/universal/dpmodel/model/test_model.py`:
- Line 274: TestSpinEnergyModelDP is calling
cls.module.model_output_def().get_data() which returns raw internal keys while
TestEnergyModelDP uses cls.module.translated_output_def() (human-friendly keys);
fix by either adding a translated_output_def() override on the dpmodel SpinModel
that maps spin-specific outputs (e.g. atom_energy, energy, mask_mag, force) to
the internal keys, or create a dpmodel SpinEnergyModel subclass mirroring PT's
SpinEnergyModel with translated_output_def(), and update TestSpinEnergyModelDP
to use cls.module.translated_output_def() instead of model_output_def().

---

Nitpick comments:
In `@deepmd/dpmodel/model/dipole_model.py`:
- Around line 58-62: The code mixes direct dict indexing and .get() for
model_ret keys between call and call_lower; make them consistent by switching
the direct accesses in call (e.g., model_ret["dipole_derv_r"],
model_ret["dipole_derv_c_redu"], model_ret["dipole_derv_c"]) to use
model_ret.get(...) like call_lower does, and keep the conditional logic that
assigns model_predict["force"] and model_predict["virial"] unchanged so behavior
remains identical.

In `@deepmd/dpmodel/model/make_model.py`:
- Around line 368-369: The class currently aliases call = call_common and
call_lower = call_common_lower which binds static function objects and can
bypass future overrides; replace those aliases with explicit forwarding instance
methods (e.g., def call(self, *args, **kwargs): return self.call_common(*args,
**kwargs) and similarly for call_lower) so dispatch goes through self and
respects MRO and overrides of call_common/call_common_lower; ensure the
forwarding method signatures match the original call_common/call_common_lower
signatures and preserve docstrings/annotations if present.

In `@source/tests/consistent/model/test_ener.py`:
- Around line 280-287: The DP backend branch in TestEner and TestEnerLower
returns SKIP_FLAG for force, virial, and atom_virial (see RefBackend.DP,
ret["energy"], ret["atom_energy"]) which prevents cross-validation; confirm
intent and either restore comparisons or document/guard the skip: if dpmodel can
compute forces/virial via its call path (finite-difference or analytic
gradients) remove the SKIP_FLAG returns and return the actual arrays so tests
validate them, otherwise add a clear inline comment and/or a test-level
assertion explaining that DP does not provide forces/virial (and reference the
dpmodel call path) so future reviewers know this is intentional rather than a
regression.

@codecov


codecov bot commented Feb 19, 2026

Codecov Report

❌ Patch coverage is 65.05190% with 101 lines in your changes missing coverage. Please review.
✅ Project coverage is 82.16%. Comparing base (4ddc37d) to head (a3dfaa6).
⚠️ Report is 2 commits behind head on master.

Files with missing lines                    Patch %   Missing lines
deepmd/dpmodel/model/dipole_model.py        39.53%    26 ⚠️
deepmd/dpmodel/model/dp_zbl_model.py        39.53%    26 ⚠️
deepmd/dpmodel/model/property_model.py      48.14%    14 ⚠️
deepmd/dpmodel/model/dos_model.py           47.82%    12 ⚠️
deepmd/dpmodel/model/polar_model.py         47.82%    12 ⚠️
deepmd/dpmodel/model/spin_model.py          86.95%     9 ⚠️
deepmd/pt_expt/model/make_model.py          50.00%     2 ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##           master    #5250      +/-   ##
==========================================
+ Coverage   82.12%   82.16%   +0.04%     
==========================================
  Files         740      745       +5     
  Lines       74473    74825     +352     
  Branches     3616     3616              
==========================================
+ Hits        61162    61483     +321     
- Misses      12149    12180      +31     
  Partials     1162     1162              

☔ View full report in Codecov by Sentry.


@wanghan-iapcm wanghan-iapcm added the Test CUDA Trigger test CUDA workflow label Feb 20, 2026
@github-actions github-actions bot removed the Test CUDA Trigger test CUDA workflow label Feb 20, 2026
@wanghan-iapcm wanghan-iapcm added the Test CUDA Trigger test CUDA workflow label Feb 20, 2026
@github-actions github-actions bot removed the Test CUDA Trigger test CUDA workflow label Feb 20, 2026
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🧹 Nitpick comments (4)
source/tests/consistent/model/test_spin_ener.py (1)

94-112: get_reference_backend and pass_data_to_cls are identical in both classes.

Both methods are copy-pasted verbatim. Extracting them into a shared base class (e.g., _SpinEnerTestBase) would eliminate the duplication.

♻️ Proposed refactor
+class _SpinEnerTestBase:
+    """Shared helpers for spin energy model tests."""
+
+    def get_reference_backend(self):
+        if not self.skip_pt:
+            return self.RefBackend.PT
+        if not self.skip_dp:
+            return self.RefBackend.DP
+        raise ValueError("No available reference")
+
+    def pass_data_to_cls(self, cls, data) -> Any:
+        data = copy.deepcopy(data)
+        if cls is SpinModelDP:
+            return get_model_dp(data)
+        elif cls is SpinEnergyModelPT:
+            return get_model_pt(data)
+        return cls(**data, **self.additional_data)
+

-class TestSpinEner(CommonTest, ModelTest, unittest.TestCase):
+class TestSpinEner(_SpinEnerTestBase, CommonTest, ModelTest, unittest.TestCase):
     ...
-    def get_reference_backend(self): ...
-    def pass_data_to_cls(self, cls, data): ...

-class TestSpinEnerLower(CommonTest, ModelTest, unittest.TestCase):
+class TestSpinEnerLower(_SpinEnerTestBase, CommonTest, ModelTest, unittest.TestCase):
     ...
-    def get_reference_backend(self): ...
-    def pass_data_to_cls(self, cls, data): ...

Also applies to: 236-254

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@source/tests/consistent/model/test_spin_ener.py` around lines 94 - 112, The
two methods get_reference_backend and pass_data_to_cls are duplicated across
test classes; extract them into a shared base class (e.g., create class
_SpinEnerTestBase) that defines get_reference_backend and pass_data_to_cls, move
the implementations there, and have the existing test classes (the ones
currently defining those methods) inherit from _SpinEnerTestBase and remove
their local copies; ensure the base class provides any attributes used (skip_pt,
skip_dp, RefBackend, additional_data) or document they must exist on subclasses
so tests still run.
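The mixin extraction proposed above can be sketched stand-alone. Class names and the `skip_pt` / `skip_dp` attributes mirror the review comment but everything here is an illustrative stub, not the real test harness:

```python
class _SpinEnerTestBase:
    """Shared helpers; subclasses are expected to provide skip_pt / skip_dp."""

    def get_reference_backend(self):
        if not self.skip_pt:
            return "PT"
        if not self.skip_dp:
            return "DP"
        raise ValueError("No available reference")

class TestSpinEner(_SpinEnerTestBase):
    skip_pt = True   # PT unavailable -> fall through to DP
    skip_dp = False

class TestSpinEnerLower(_SpinEnerTestBase):
    skip_pt = False  # PT available -> preferred reference
    skip_dp = False

print(TestSpinEner().get_reference_backend())       # DP
print(TestSpinEnerLower().get_reference_backend())  # PT
```

In the real suite the mixin would sit before `CommonTest` in each class's base list so its methods resolve ahead of any defaults.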
source/tests/pt/model/test_ener_spin_model.py (1)

310-383: The new public DPSpinModel.call() and call_lower() methods are not tested against their PT equivalents.

test_dp_consistency correctly verifies internal parity between the DP and PT backends using call_common() / call_common_lower() paired with PT's forward_common() / forward_common_lower(). Both return the same internal keys (energy, energy_redu).

However, the public API methods DPSpinModel.call() and call_lower() translate these internal keys to user-facing keys (e.g., atom_energy, energy, force, force_mag), matching what PT's public forward() / forward_lower() methods expose. Currently, no test verifies that dp_model.call(...) and self.model.forward(...) return consistent translated outputs.

Consider adding a test section within test_dp_consistency (or as a separate method) that calls dp_model.call() with translated output keys and compares against self.model.forward(...) results, similar to how test_output_shape verifies the PT public interface.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@source/tests/pt/model/test_ener_spin_model.py` around lines 310 - 383, Add
assertions that DPSpinModel.call() and call_lower() produce the same user-facing
outputs as PT's forward() and forward_lower(): after creating dp_model
(DPSpinModel.deserialize(...)), call dp_model.call(...) and compare its returned
keys (e.g., "atom_energy", "energy", "force", "force_mag") to
self.model.forward(...) outputs converted via to_numpy_array using
np.testing.assert_allclose with rtol/atol=self.prec; likewise call
dp_model.call_lower(...) and compare to self.model.forward_lower(...). Use the
same inputs already prepared in test_dp_consistency (coord, atype, spin, cell
and extended versions) and mirror the existing assert_allclose pattern for
energy/energy_redu so the public-key translation is validated.
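The `assert_allclose` comparison pattern referenced above looks roughly like this; the arrays are synthetic stand-ins for the DP and PT model outputs:

```python
import numpy as np

prec = 1e-10  # mirrors the rtol/atol=self.prec usage in the test
energy_dp = np.array([[1.0], [2.0]])
energy_pt = np.array([[1.0], [2.0]]) + 1e-12  # tiny cross-backend difference

# Raises AssertionError if any element differs beyond rtol/atol.
np.testing.assert_allclose(energy_dp, energy_pt, rtol=prec, atol=prec)
print("consistent")
```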
deepmd/dpmodel/model/spin_model.py (2)

390-393: Pervasive code duplication: extract a _get_var_name() helper.

The identical 3-line pattern — model_output_type(), pop("mask"), var_name = model_output_type[0] — is repeated six times across call_common, call, call_common_lower, call_lower, translated_output_def, and model_output_def. Extracting it into a private helper eliminates the duplication and the risk of the implementations diverging.

♻️ Suggested helper
+    def _get_var_name(self) -> str:
+        """Return the primary output variable name, stripping the 'mask' sentinel."""
+        model_output_type = self.backbone_model.model_output_type()
+        if "mask" in model_output_type:
+            model_output_type.pop(model_output_type.index("mask"))
+        return model_output_type[0]

Then replace every repeated block with var_name = self._get_var_name().

Also applies to: 473-476, 552-555, 642-645, 672-675

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@deepmd/dpmodel/model/spin_model.py` around lines 390 - 393, Extract the
repeated 3-line pattern into a private helper named _get_var_name(self) that
calls self.backbone_model.model_output_type(), removes "mask" if present, and
returns the first element; then replace the duplicated blocks in the methods
call_common, call, call_common_lower, call_lower, translated_output_def, and
model_output_def with var_name = self._get_var_name(); ensure the helper is used
everywhere the pattern appears (previously at the locations around lines 390,
473, 552, 642, 672) to remove duplication and keep behavior identical.

419-423: Nit: reuse already-computed nframes / nloc instead of re-destructuring atype.shape.

nframes_m and nloc_m are identical to nframes and nloc from line 376.

✨ Suggested simplification
-        if "mask_mag" not in model_ret:
-            nframes_m, nloc_m = atype.shape[:2]
-            atomic_mask = self.virtual_scale_mask[atype].reshape([nframes_m, nloc_m, 1])
-            model_ret["mask_mag"] = atomic_mask > 0.0
+        if "mask_mag" not in model_ret:
+            atomic_mask = self.virtual_scale_mask[atype].reshape([nframes, nloc, 1])
+            model_ret["mask_mag"] = atomic_mask > 0.0
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@deepmd/dpmodel/model/spin_model.py` around lines 419 - 423, The code
recomputes nframes_m and nloc_m from atype.shape to build mask_mag; instead
reuse the already-computed nframes and nloc variables (from earlier in this
scope) when creating atomic_mask via
self.virtual_scale_mask[atype].reshape([nframes, nloc, 1]) and set
model_ret["mask_mag"] = atomic_mask > 0.0, avoiding the redundant destructuring
of atype.shape.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@source/tests/consistent/model/test_spin_ener.py`:
- Around line 336-341: Remove the unused variable nall which is assigned from
extended_coord.shape[1] but never used; in the block around the assignment to
nall and the subsequent np.take_along_axis call that sets self.extended_spin,
delete the nall line so only the np.take_along_axis call using mapping remains
(keep extended_coord, mapping and self.extended_spin intact).

---

Nitpick comments:
In `@deepmd/dpmodel/model/spin_model.py`:
- Around line 390-393: Extract the repeated 3-line pattern into a private helper
named _get_var_name(self) that calls self.backbone_model.model_output_type(),
removes "mask" if present, and returns the first element; then replace the
duplicated blocks in the methods call_common, call, call_common_lower,
call_lower, translated_output_def, and model_output_def with var_name =
self._get_var_name(); ensure the helper is used everywhere the pattern appears
(previously at the locations around lines 390, 473, 552, 642, 672) to remove
duplication and keep behavior identical.
- Around line 419-423: The code recomputes nframes_m and nloc_m from atype.shape
to build mask_mag; instead reuse the already-computed nframes and nloc variables
(from earlier in this scope) when creating atomic_mask via
self.virtual_scale_mask[atype].reshape([nframes, nloc, 1]) and set
model_ret["mask_mag"] = atomic_mask > 0.0, avoiding the redundant destructuring
of atype.shape.

In `@source/tests/consistent/model/test_spin_ener.py`:
- Around line 94-112: The two methods get_reference_backend and pass_data_to_cls
are duplicated across test classes; extract them into a shared base class (e.g.,
create class _SpinEnerTestBase) that defines get_reference_backend and
pass_data_to_cls, move the implementations there, and have the existing test
classes (the ones currently defining those methods) inherit from
_SpinEnerTestBase and remove their local copies; ensure the base class provides
any attributes used (skip_pt, skip_dp, RefBackend, additional_data) or document
they must exist on subclasses so tests still run.

In `@source/tests/pt/model/test_ener_spin_model.py`:
- Around line 310-383: Add assertions that DPSpinModel.call() and call_lower()
produce the same user-facing outputs as PT's forward() and forward_lower():
after creating dp_model (DPSpinModel.deserialize(...)), call dp_model.call(...)
and compare its returned keys (e.g., "atom_energy", "energy", "force",
"force_mag") to self.model.forward(...) outputs converted via to_numpy_array
using np.testing.assert_allclose with rtol/atol=self.prec; likewise call
dp_model.call_lower(...) and compare to self.model.forward_lower(...). Use the
same inputs already prepared in test_dp_consistency (coord, atype, spin, cell
and extended versions) and mirror the existing assert_allclose pattern for
energy/energy_redu so the public-key translation is validated.

@wanghan-iapcm wanghan-iapcm added the Test CUDA Trigger test CUDA workflow label Feb 20, 2026
@github-actions github-actions bot removed the Test CUDA Trigger test CUDA workflow label Feb 20, 2026
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🧹 Nitpick comments (1)
source/tests/consistent/model/test_spin_ener.py (1)

105-112: pass_data_to_cls and get_reference_backend are duplicated verbatim.

Both TestSpinEner (lines 94–112) and TestSpinEnerLower (lines 236–254) define identical pass_data_to_cls, get_reference_backend, and build_tf bodies. Extracting them into a shared mixin reduces maintenance surface:

♻️ Suggested refactor
+class _SpinEnerTestBase:
+    def get_reference_backend(self):
+        if not self.skip_pt:
+            return self.RefBackend.PT
+        if not self.skip_dp:
+            return self.RefBackend.DP
+        raise ValueError("No available reference")
+
+    def pass_data_to_cls(self, cls, data):
+        data = copy.deepcopy(data)
+        if cls is SpinModelDP:
+            return get_model_dp(data)
+        elif cls is SpinEnergyModelPT:
+            return get_model_pt(data)
+        return cls(**data, **self.additional_data)
+
+    def build_tf(self, obj, suffix):
+        raise NotImplementedError("no TF in this test")
+
 class TestSpinEner(_SpinEnerTestBase, CommonTest, ModelTest, unittest.TestCase):
     ...
-    def get_reference_backend(self): ...
-    def pass_data_to_cls(self, cls, data): ...
-    def build_tf(self, obj, suffix): ...

 class TestSpinEnerLower(_SpinEnerTestBase, CommonTest, ModelTest, unittest.TestCase):
     ...
-    def get_reference_backend(self): ...
-    def pass_data_to_cls(self, cls, data): ...
-    def build_tf(self, obj, suffix): ...

Also applies to: 247-254

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@source/tests/consistent/model/test_spin_ener.py` around lines 105 - 112, Two
test classes (TestSpinEner and TestSpinEnerLower) duplicate pass_data_to_cls,
get_reference_backend, and build_tf; extract these shared methods into a single
mixin (e.g., SpinEnerTestMixin) and have both test classes inherit from it. Move
the implementations of pass_data_to_cls, get_reference_backend, and build_tf out
of TestSpinEner and TestSpinEnerLower into the mixin and update both classes to
remove their local definitions and subclass the new mixin so existing references
to get_model_dp/get_model_pt, SpinModelDP, SpinEnergyModelPT, and
self.additional_data continue to resolve.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@source/tests/consistent/model/test_spin_ener.py`:
- Around line 199-215: The DP branch in the test's extract_ret currently returns
SKIP_FLAG for force/force_mag (and extended force fields elsewhere) despite
SpinModelDP.call() and SpinModelDP.call_lower() computing and returning these
when the backbone's do_grad_r(var_name) is True; update both extract_ret
overrides to assert/validate the force outputs instead of skipping them (replace
SKIP_FLAG with the corresponding ret["force"].ravel() and
ret["force_mag"].ravel() or their extended counterparts) so the DP path mirrors
the PT path, or alternatively ensure the test sets the backbone model to return
do_grad_r() == False so forces are legitimately absent (reference extract_ret,
SpinModelDP.call, SpinModelDP.call_lower, and translated_output_def to locate
related logic).

---

Nitpick comments:
In `@source/tests/consistent/model/test_spin_ener.py`:
- Around line 105-112: Two test classes (TestSpinEner and TestSpinEnerLower)
duplicate pass_data_to_cls, get_reference_backend, and build_tf; extract these
shared methods into a single mixin (e.g., SpinEnerTestMixin) and have both test
classes inherit from it. Move the implementations of pass_data_to_cls,
get_reference_backend, and build_tf out of TestSpinEner and TestSpinEnerLower
into the mixin and update both classes to remove their local definitions and
subclass the new mixin so existing references to get_model_dp/get_model_pt,
SpinModelDP, SpinEnergyModelPT, and self.additional_data continue to resolve.

@wanghan-iapcm wanghan-iapcm added this pull request to the merge queue Feb 20, 2026
Merged via the queue into deepmodeling:master with commit 367e626 Feb 20, 2026
73 checks passed
@wanghan-iapcm wanghan-iapcm deleted the refact-model-output branch February 20, 2026 15:39

3 participants