feat(jax): add training #5460
📝 Walkthrough
Adds JAX training support: backend wiring to return a CLI main, a CLI dispatcher, a train entrypoint with SummaryPrinter and update_sel, a DPTrainer implementing the JAX/Optax training loop with checkpointing and helpers, and an end-to-end test that runs one-step training and checks outputs.
Changes: JAX Training Implementation
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
🚥 Pre-merge checks: ✅ 4 passed | ❌ 1 failed (1 warning)
Actionable comments posted: 4
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@deepmd/jax/entrypoints/train.py`:
- Around line 136-137: The train() function currently calls update_sel(jdata)
unconditionally, ignoring the skip_neighbor_stat flag; modify train() to check
skip_neighbor_stat before calling update_sel so neighbor-stat updates only run
when skip_neighbor_stat is false (or when a corresponding flag like
skip_neighbor_stat is not set), e.g., wrap the update_sel(jdata) call in an
if/else that respects the skip_neighbor_stat parameter/attribute and ensure any
downstream use of jdata remains correct when the update is skipped.
- Around line 207-217: The code builds train_data with get_data then returns
jdata_cpy without applying BaseModel.update_sel, causing an expensive no-op
path; short-circuit this by avoiding the expensive get_data/update sequence and
immediately return jdata_cpy (or jdata) until update_sel is restored: remove or
skip the get_data call and the commented update_sel invocation in train.py, add
a clear TODO comment referencing BaseModel.update_sel and the OOM, and ensure no
unused variables (train_data) remain to prevent wasted computation.
In `@deepmd/jax/train/trainer.py`:
- Around line 94-95: The checkpoint currently reads current_step from
model_dict["@variables"] (used to set self.start_step) but the save path writes
it under model_def_script, causing restarts to reset; change the save or load to
use the same key unconditionally: either always write current_step into
model_dict["@variables"]["current_step"] when saving (the writer around
model_def_script in the saving block at/near the code that sets
model_def_script) or always read from model_def_script when loading (the code
that sets self.start_step); pick one canonical location (prefer @variables for
consistency) and update the corresponding save/read logic so both
serialize_from_file/load and the saver use the identical key name.
In `@source/tests/jax/test_training.py`:
- Around line 80-82: The test test_train_entrypoint_runs_one_step_from_scratch
needs a 60s timeout: add a pytest timeout marker (e.g. `@pytest.mark.timeout(60)`)
directly above the test function (keeping the existing
`@patch("deepmd.jax.entrypoints.train.SummaryPrinter.__call__")` decorator), and
ensure pytest is imported in the module (add "import pytest" if missing); this
enforces a 60-second cap for the train entrypoint test.
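The first finding for `deepmd/jax/entrypoints/train.py` comes down to a simple guard. The following is a minimal sketch only: `update_sel` here is a stand-in that merely marks the config, and `prepare_config` is a hypothetical helper name, not a function from the PR.

```python
import copy


def update_sel(jdata: dict) -> dict:
    """Stand-in for the real neighbor-stat update (hypothetical body)."""
    updated = copy.deepcopy(jdata)
    updated["sel_updated"] = True  # marker so the effect is observable
    return updated


def prepare_config(jdata: dict, skip_neighbor_stat: bool = False) -> dict:
    """Run the neighbor-stat update only when it has not been skipped."""
    if skip_neighbor_stat:
        # Leave the user-provided selection untouched.
        return copy.deepcopy(jdata)
    return update_sel(jdata)
```

With this shape, downstream code always receives a config dict, whether or not the update ran.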
📒 Files selected for processing (7)
- deepmd/backend/jax.py
- deepmd/jax/entrypoints/__init__.py
- deepmd/jax/entrypoints/main.py
- deepmd/jax/entrypoints/train.py
- deepmd/jax/train/__init__.py
- deepmd/jax/train/trainer.py
- source/tests/jax/test_training.py
Codecov Report
❌ Patch coverage is
Additional details and impacted files
Coverage Diff
| | master | #5460 | +/- |
|---|---|---|---|
| Coverage | 82.46% | 82.48% | +0.01% |
| Files | 829 | 833 | +4 |
| Lines | 88763 | 89095 | +332 |
| Branches | 4225 | 4226 | +1 |
| Hits | 73197 | 73488 | +291 |
| Misses | 14274 | 14315 | +41 |
| Partials | 1292 | 1292 | |
59611c8 to 2528142
Actionable comments posted: 2
♻️ Duplicate comments (1)
deepmd/jax/train/trainer.py (1)
89-92: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win
Keep `current_step` in a single checkpoint location.
Resume reads `current_step` from `@variables`, but save writes it into `model_def_script`. That makes restarted runs fall back to step 0, which also resets the LR schedule and checkpoint numbering.
Also applies to: 382-392
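The mismatch described above is easy to see with plain dictionaries. The sketch below is an illustration under assumptions, not the trainer's actual code: `save_step` and `load_step` are hypothetical helpers, and `@variables` is picked as the canonical location only because the earlier review comment preferred it.

```python
CANONICAL_KEY = "current_step"


def save_step(model_dict: dict, step: int) -> dict:
    """Write the step counter into the single canonical location."""
    model_dict.setdefault("@variables", {})[CANONICAL_KEY] = step
    return model_dict


def load_step(model_dict: dict) -> int:
    """Read the step counter from the same location; 0 means a fresh run."""
    return int(model_dict.get("@variables", {}).get(CANONICAL_KEY, 0))


# Saving under a different key reproduces the reported reset to step 0.
mismatched = {"model_def_script": {CANONICAL_KEY: 100}}
restored = load_step(save_step({}, 100))
```

Whichever location is chosen, the point is that the writer and the reader go through the same pair of helpers.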
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@deepmd/jax/train/trainer.py` around lines 89 - 92, The resume logic reads current_step from model_dict["@variables"] but checkpoint saving stores it under model_dict["model_def_script"], causing restarts to reset to 0; update the restart handling (the branch using serialize_from_file and BaseModel.deserialize) to obtain current_step from the same place the saver writes (check model_dict["model_def_script"].get("current_step") first, then fall back to model_dict["@variables"].get("current_step", 0)) and assign that value to self.start_step; apply the same change to the other restart/resume block around the 382-392 region so both resume paths use a single unified source for current_step.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@deepmd/jax/train/trainer.py`:
- Around line 285-286: The code currently opens self.disp_file with mode "w"
which truncates the learning-curve on resume; change the open to use "a"
(append) when self.start_step > 0 (keep "w" only for fresh runs), and ensure the
header is written only when starting from step 0 or when the file is newly
created/empty; update the block around disp_file_fp = open(self.disp_file, "w")
and the similar block around the later lines (the routine that writes the disp
header and per-step rows) to check self.start_step and/or file emptiness before
printing the header so resumed runs append new rows without losing prior data.
- Around line 142-145: The validation batch count from configuration isn't
honored: when tr_data contains "validation_data" the code sets
self.valid_numb_batch only in the if-branch but the else-branch (and other
validation code paths around valid_numb_batch usage) ends up forcing one batch;
update the logic that reads tr_data["validation_data"].get("numb_btch", 1) so
that self.valid_numb_batch is assigned whenever validation_data is present and
ensure all validation paths (including the block around lines ~332-358) use
self.valid_numb_batch rather than hardcoding 1; specifically, inspect and update
the assignments and usages of self.valid_numb_batch in trainer.py so validation
honors validation_data["numb_btch"] uniformly.
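The learning-curve finding above can be sketched as a small helper. This is a hedged illustration: `open_lcurve` and its `header` argument are hypothetical names, and the real trainer may structure the header logic differently.

```python
import os


def open_lcurve(path: str, start_step: int, header: str):
    """Open the learning-curve file without truncating it on resume."""
    mode = "a" if start_step > 0 else "w"
    fp = open(path, mode)
    # Write the header for fresh runs, or when a resumed run finds no prior data.
    if start_step == 0 or os.path.getsize(path) == 0:
        fp.write(header + "\n")
    return fp
```

The returned file object can be used in a `with` statement, which also guarantees that it is closed if training raises.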
✅ Files skipped from review due to trivial changes (2)
- deepmd/jax/entrypoints/__init__.py
- deepmd/jax/train/__init__.py
🚧 Files skipped from review as they are similar to previous changes (4)
- deepmd/backend/jax.py
- deepmd/jax/entrypoints/main.py
- source/tests/jax/test_training.py
- deepmd/jax/entrypoints/train.py
Pull request overview
This PR adds a local-only JAX training entrypoint to DeePMD-Kit, wiring it into the backend entry-point hook so dp --jax train ... can run end-to-end training (including checkpointing and export) without the parallel/distributed features from the prior parallel branch.
Changes:
- Add a JAX backend CLI entry module and hook it up via `deepmd/backend/jax.py` so JAX can act as an ENTRY_POINT backend.
- Introduce a local JAX trainer implementing initialization, stat computation, training/validation loop, learning-curve output, and Orbax checkpoint saving.
- Add an end-to-end regression test that runs a 1-step JAX training workflow and verifies expected artifacts.
Reviewed changes
Copilot reviewed 7 out of 7 changed files in this pull request and generated 7 comments.
| File | Description |
|---|---|
| source/tests/jax/test_training.py | New end-to-end regression test for local JAX training entrypoint outputs. |
| deepmd/jax/train/trainer.py | New local JAX trainer with loss wiring, input prep (ghosts/nlist), training loop, lcurve logging, and Orbax checkpoints. |
| deepmd/jax/train/__init__.py | Package init for JAX training module. |
| deepmd/jax/entrypoints/train.py | New JAX train entrypoint (input loading/normalization, data init, training invocation). |
| deepmd/jax/entrypoints/main.py | JAX backend command dispatcher for train/freeze. |
| deepmd/jax/entrypoints/__init__.py | Package init for JAX entrypoints. |
| deepmd/backend/jax.py | Implements entry_point_hook to route CLI execution to the new JAX entrypoints. |
2528142 to 12e5a77
Actionable comments posted: 1
🧹 Nitpick comments (1)
deepmd/jax/train/trainer.py (1)
289-289: 💤 Low value
File handle not protected against exceptions.
The display file is opened at line 289 and closed at line 415, but if an exception occurs during training, the file won't be closed. Consider using a context manager or try/finally.
💡 Suggested approach
    - disp_file_fp = open(disp_path, disp_mode)
    - for step in range(self.start_step, self.num_steps):
    + with open(disp_path, disp_mode) as disp_file_fp:
    +     for step in range(self.start_step, self.num_steps):
          ...
    -
    - disp_file_fp.close()
Alternatively, wrap the training loop in try/finally to ensure `disp_file_fp.close()` is called.
Also applies to: 415-415
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@deepmd/jax/train/trainer.py` at line 289, The file handle opened with disp_file_fp = open(disp_path, disp_mode) must be guaranteed to close on exceptions; wrap the use of disp_file_fp around the training loop in a context manager (with open(disp_path, disp_mode) as disp_file_fp:) or surround the training loop with try/finally and call disp_file_fp.close() in finally; update any references to disp_file_fp inside the loop accordingly so the handle is scoped correctly and ensure the close is executed even if an exception is raised.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@deepmd/jax/entrypoints/train.py`:
- Line 82: The parameter init_frz_model is annotated as str but the code (and
tests) pass None; update the signature where init_frz_model is declared (the
function in deepmd/jax/entrypoints/train.py) to accept None by changing its type
to Optional[str] and ensure Optional is imported from typing (or use Union[str,
None]); keep the default value None if present and leave the runtime check if
init_frz_model: unchanged.
✅ Files skipped from review due to trivial changes (2)
- deepmd/jax/entrypoints/__init__.py
- deepmd/jax/train/__init__.py
🚧 Files skipped from review as they are similar to previous changes (2)
- deepmd/backend/jax.py
- deepmd/jax/entrypoints/main.py
12e5a77 to fce455e
Actionable comments posted: 1
🧹 Nitpick comments (1)
deepmd/jax/entrypoints/train.py (1)
103-104: 💤 Low value
Docstring does not reflect the `None` option for `init_frz_model`.
The function signature correctly types `init_frz_model` as `str | None`, but the docstring says `init_frz_model : str`. Update for consistency.
📝 Suggested fix
    - init_frz_model : str
    + init_frz_model : str | None
          path to frozen model or None
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@deepmd/jax/entrypoints/train.py` around lines 103 - 104, The docstring for the parameter init_frz_model is incorrect: change the parameter description for init_frz_model in the train() (or the top-level function where init_frz_model is declared) docstring from "init_frz_model : str" to reflect the optional type (e.g., "init_frz_model : str | None" or "init_frz_model : Optional[str]") and update the description text to mention that None means no frozen model will be used; keep the parameter name init_frz_model to locate the entry and ensure the docstring formatting matches the surrounding style.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@deepmd/jax/train/trainer.py`:
- Around line 336-373: The validation aggregation can divide by zero when
self.valid_numb_batch is 0; clamp the number of validation iterations before the
loop (e.g., num_valid_batches = max(1, self.valid_numb_batch)) and iterate that
instead of self.valid_numb_batch, then compute valid_more_loss as before; update
the loop that builds valid_more_loss_list and the subsequent averaging so it
uses num_valid_batches, referencing the existing names valid_data,
valid_more_loss_list, prepare_input, loss_fn_more_loss, and valid_more_loss.
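The clamp suggested above can be sketched as follows. This is an illustration under assumptions: `batch_losses` is a hypothetical callable standing in for the `prepare_input` plus `loss_fn_more_loss` path, and the loss keys in the usage lines are made up.

```python
def average_validation_loss(batch_losses, valid_numb_batch: int) -> dict:
    """Average per-batch loss dicts over at least one validation batch."""
    # Clamp so a configured value of 0 cannot cause a division by zero.
    num_valid_batches = max(1, valid_numb_batch)
    totals: dict = {}
    for batch_index in range(num_valid_batches):
        for key, value in batch_losses(batch_index).items():
            totals[key] = totals.get(key, 0.0) + value
    return {key: total / num_valid_batches for key, total in totals.items()}


def fake_batch_losses(batch_index: int) -> dict:
    """Hypothetical per-batch losses used only for this illustration."""
    return {"rmse_e": 1.0 + batch_index, "rmse_f": 2.0}


clamped = average_validation_loss(fake_batch_losses, 0)
two_batches = average_validation_loss(fake_batch_losses, 2)
```

Clamping is one option; rejecting a non-positive batch count when the config is read would be an equally reasonable design.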
✅ Files skipped from review due to trivial changes (2)
- deepmd/jax/entrypoints/__init__.py
- deepmd/jax/train/__init__.py
🚧 Files skipped from review as they are similar to previous changes (3)
- deepmd/backend/jax.py
- source/tests/jax/test_training.py
- deepmd/jax/entrypoints/main.py
Port the JAX training entrypoint from the parallel branch onto current master, but keep it local-only by removing distributed, sharding, and Hessian hooks. Use the current dpmodel compute_or_load_stat data-stat path and add regression coverage for the cleanup constraints. Authored by OpenClaw (model: custom-chat-jinzhezeng-group/gpt-5.5)
fce455e to c4b12bc
Actionable comments posted: 2
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@deepmd/jax/train/trainer.py`:
- Around line 534-543: The non-periodic branch leaves cc in its original
(possibly flattened) layout while the periodic branch reshapes to (nframes,
nloc, 3), causing extend_coord_with_ghosts(...) to receive the wrong shape;
update the else branch so coord_normalized = cc.copy() is reshaped to
cc.reshape(nframes, nloc, 3) (or cc.copy().reshape(...)) before calling
extend_coord_with_ghosts, keeping the same variables (cc, bb, normalize_coord,
coord_normalized, extend_coord_with_ghosts, atype, rcut, mapping).
- Around line 393-418: The training loop only saves when (step+1) %
self.save_freq == 0 so runs with total steps < save_freq produce no checkpoint;
extract the existing save logic (the block using nnx.split(model), creating
ckpt_path, ocp.Checkpointer with ocp.args.Composite/state/model_def_script,
log.info, _link_checkpoint and self._cleanup_old_checkpoints plus writing
"checkpoint") into a helper method (e.g., _save_checkpoint(state, step)) and
call it both inside the periodic save branch and once after training completes
(ensuring model_def_script["current_step"]=step+1 is set); this guarantees the
final checkpoint and checkpoint pointer are always written even if the final
step is not on a save_freq boundary.
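The checkpoint finding above is about scheduling rather than Orbax itself, so it can be sketched without any checkpoint library. `checkpoint_steps` is a hypothetical helper that only records when a save would happen; in the trainer, each recorded entry would correspond to a call to the extracted save helper.

```python
def checkpoint_steps(start_step: int, num_steps: int, save_freq: int) -> list:
    """Return the step counts at which a checkpoint would be written."""
    saved = []
    for step in range(start_step, num_steps):
        # ... one training step would run here ...
        if (step + 1) % save_freq == 0:
            saved.append(step + 1)  # periodic save
    # Always write a final checkpoint unless the last step already saved one.
    if num_steps > start_step and (not saved or saved[-1] != num_steps):
        saved.append(num_steps)
    return saved


short_run = checkpoint_steps(0, 1, 1000)
aligned_run = checkpoint_steps(0, 10, 5)
unaligned_run = checkpoint_steps(0, 12, 5)
```

A one-step run with a large `save_freq`, as in the end-to-end test scenario, still ends with exactly one checkpoint under this schedule.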
✅ Files skipped from review due to trivial changes (2)
- deepmd/jax/entrypoints/__init__.py
- deepmd/jax/train/__init__.py
🚧 Files skipped from review as they are similar to previous changes (3)
- deepmd/backend/jax.py
- deepmd/jax/entrypoints/main.py
- source/tests/jax/test_training.py
Save the final checkpoint even when the last step is not on a save interval and normalize non-periodic coordinates to the expected 3D layout before ghost extension. Authored by OpenClaw (model: custom-chat-jinzhezeng-group/gpt-5.5)
wanghan-iapcm
left a comment
Code review: 2 issues, posted as inline comments below.
Add the missing JAX freeze module without Hessian support and cover the CLI dispatch path so importing the JAX backend main entry point no longer fails. Also pass the true atom count to the energy loss instead of the flattened coordinate width. Authored by OpenClaw (model: custom-chat-jinzhezeng-group/gpt-5.5)
Apply the import ordering change required by pre-commit.ci. Authored by OpenClaw (model: custom-chat-jinzhezeng-group/gpt-5.5)
Summary
This PR ports the JAX training entrypoint from the `parallel` branch onto the current `deepmodeling/deepmd-kit` master as a local-only training path.
The change keeps the useful JAX trainer/CLI pieces while deliberately removing the parallel/distributed parts requested for cleanup:
- … `train` entrypoint and wire it into the JAX backend command path
- … `compute_or_load_stat` data-stat practice from master
- … `EnergyLoss`
- … `communicate_extended_output` so extended/ghost atom force contributions are scattered back to local atoms correctly
Tests
- `/tmp/deepmd-jax-venv/bin/python -m pytest -q source/tests/jax/test_training.py`
- Test Python: https://github.com/njzjz-bothub/deepmd-kit/actions/runs/26464854510
Authored by OpenClaw (model: custom-chat-jinzhezeng-group/gpt-5.5)
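The PR summary mentions scattering extended/ghost atom force contributions back to local atoms. The sketch below illustrates that idea in pure Python; it is not the library's `communicate_extended_output` implementation, and `scatter_extended_forces` plus the sample data are hypothetical.

```python
def scatter_extended_forces(extended_force, mapping, nloc):
    """Sum per-atom forces of extended (local + ghost) atoms onto local atoms.

    mapping[i] is the index of the local atom that extended atom i is an
    image of; local atoms map to themselves.
    """
    local_force = [[0.0, 0.0, 0.0] for _ in range(nloc)]
    for ext_index, force in enumerate(extended_force):
        owner = mapping[ext_index]
        for axis in range(3):
            local_force[owner][axis] += force[axis]
    return local_force


# Two local atoms plus one ghost image of atom 0.
extended = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.5, 0.0, 0.0]]
local = scatter_extended_forces(extended, mapping=[0, 1, 0], nloc=2)
```

In an array library the same reduction is typically written as a scatter-add over the mapping indices; the loop form is used here only to keep the example dependency-free.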