
feat(jax): add training#5460

Open
njzjz-bot wants to merge 4 commits into deepmodeling:master from njzjz-bothub:cleanup/parallel-no-hessian

Conversation

@njzjz-bot
Contributor

@njzjz-bot njzjz-bot commented May 26, 2026

Summary

This PR ports the JAX training entrypoint from the parallel branch onto the current deepmodeling/deepmd-kit master as a local-only training path.

The change keeps the useful JAX trainer/CLI pieces while deliberately removing the parallel/distributed parts requested for cleanup:

  • add a JAX train entrypoint and wire it into the JAX backend command path
  • add local JAX trainer infrastructure for model initialization, data statistics, loss setup, training, validation, checkpointing, and model export
  • use the current dpmodel compute_or_load_stat data-stat practice from master
  • remove parallel/sharding-specific behavior from the training path
  • remove Hessian-specific behavior from the training path
  • map the lower-interface model outputs into the keys expected by EnergyLoss
  • use communicate_extended_output so extended/ghost atom force contributions are scattered back to local atoms correctly
  • add regression coverage for the local JAX training entrypoint and cleanup constraints
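
The ghost-atom bullet above can be illustrated with a small sketch. It is not the `communicate_extended_output` implementation; the function name `scatter_extended_forces`, the list-based layout, and the toy data are assumptions made for illustration. It only shows the idea of summing each extended (local plus ghost) atom's force contribution onto the local atom it is an image of.

```python
def scatter_extended_forces(ext_force, mapping, nloc):
    """Sum forces of extended atoms back onto their local owners.

    ext_force: list of [fx, fy, fz], one entry per extended atom (local + ghost)
    mapping:   mapping[i] is the local index that extended atom i is an image of
    nloc:      number of local atoms
    (Hypothetical helper; names and layout are illustrative assumptions.)
    """
    local_force = [[0.0, 0.0, 0.0] for _ in range(nloc)]
    for force, owner in zip(ext_force, mapping):
        for axis in range(3):
            # Ghost contributions accumulate on the atom they mirror.
            local_force[owner][axis] += force[axis]
    return local_force


# Two local atoms plus one ghost image of atom 0.
ext_force = [[1.0, 0.0, 0.0], [0.0, 2.0, 0.0], [0.5, 0.0, 0.0]]
mapping = [0, 1, 0]
local_force = scatter_extended_forces(ext_force, mapping, nloc=2)
```

In an array-based trainer the same accumulation would normally be a vectorized scatter-add rather than a Python loop.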

Tests

Authored by OpenClaw (model: custom-chat-jinzhezeng-group/gpt-5.5)

Summary by CodeRabbit

  • New Features

    • Added a CLI entrypoint to run JAX train/freeze commands.
    • Backend hook now returns the JAX entrypoint, enabling invocation as the JAX backend.
    • New JAX trainer providing training/validation flow, checkpointing, learning-rate scheduling, input preparation and data conversion.
    • Enhanced runtime summary and logging (device/backend info, GPU counts, JAX version).
  • Tests

    • Added an end-to-end test that runs a single-step JAX training workflow and verifies produced artifacts.


Comment thread deepmd/jax/entrypoints/train.py Fixed
Comment thread deepmd/jax/entrypoints/train.py Fixed
Comment thread deepmd/jax/entrypoints/train.py Fixed
Comment thread deepmd/jax/train/trainer.py Fixed
Comment thread deepmd/jax/train/trainer.py Fixed
@coderabbitai
Contributor

coderabbitai Bot commented May 26, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

Walkthrough

Adds JAX training support: backend wiring to return a CLI main, a CLI dispatcher, a train entrypoint with SummaryPrinter and update_sel, a DPTrainer implementing the JAX/Optax training loop with checkpointing and helpers, and an end-to-end test that runs one-step training and checks outputs.

Changes

JAX Training Implementation

| Layer | File(s) | Summary |
| --- | --- | --- |
| Backend Integration and CLI Entry Point | deepmd/backend/jax.py, deepmd/jax/entrypoints/__init__.py, deepmd/jax/entrypoints/main.py | JAX backend entry_point_hook now returns the CLI main. main.py provides argument parsing, logging setup, and dispatch for train and freeze commands. |
| Training Entrypoint Orchestration | deepmd/jax/entrypoints/train.py | Adds train() orchestration, SummaryPrinter, and update_sel; loads and normalizes control JSON, writes resolved JSON to output, instantiates DPTrainer, optionally seeds per-process RNG, builds datasets, and invokes model.train(...). Unsupported options raise NotImplementedError. |
| DPTrainer Training Loop and Utilities | deepmd/jax/train/__init__.py, deepmd/jax/train/trainer.py | Implements DPTrainer (init from scratch/init/restart, LR scheduling, loss), JIT-compiled loss/grad updates with Optax, periodic logging and Orbax checkpointing (stable pointer + cleanup), learning-curve output, prepare_input() for neighbor lists/ghosts/normalization, and convert_numpy_data_to_jax_data(). |
| End-to-End Training Test | source/tests/jax/test_training.py | New test runs one-step training with an injected model config, asserts out.json, lcurve.out, checkpoint, and model-1.jax/ exist and that lcurve.out includes step 1. |

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
| --- | --- | --- | --- |
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 68.18% which is insufficient. The required threshold is 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |

✅ Passed checks (4 passed)

| Check name | Status | Explanation |
| --- | --- | --- |
| Description Check | ✅ Passed | Check skipped - CodeRabbit’s high-level summary is enabled. |
| Title check | ✅ Passed | The title 'feat(jax): add training' clearly and concisely summarizes the main change: adding JAX training functionality to the codebase, which aligns with the PR objectives of porting the JAX training entrypoint and implementing local JAX trainer infrastructure. |
| Linked Issues check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |
| Out of Scope Changes check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |


Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 4

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@deepmd/jax/entrypoints/train.py`:
- Around line 136-137: The train() function currently calls update_sel(jdata)
unconditionally, ignoring the skip_neighbor_stat flag; modify train() to check
skip_neighbor_stat before calling update_sel so neighbor-stat updates only run
when skip_neighbor_stat is false (or when a corresponding flag like
skip_neighbor_stat is not set), e.g., wrap the update_sel(jdata) call in an
if/else that respects the skip_neighbor_stat parameter/attribute and ensure any
downstream use of jdata remains correct when the update is skipped.
- Around line 207-217: The code builds train_data with get_data then returns
jdata_cpy without applying BaseModel.update_sel, causing an expensive no-op
path; short-circuit this by avoiding the expensive get_data/update sequence and
immediately return jdata_cpy (or jdata) until update_sel is restored: remove or
skip the get_data call and the commented update_sel invocation in train.py, add
a clear TODO comment referencing BaseModel.update_sel and the OOM, and ensure no
unused variables (train_data) remain to prevent wasted computation.
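
The first comment above (honoring `skip_neighbor_stat`) can be sketched as follows. This is a minimal illustration, not the code in `train.py`: the `update_sel` body here is a placeholder stand-in, and `prepare_config` is a hypothetical wrapper name.

```python
def update_sel(jdata):
    """Hypothetical stand-in for the neighbor-statistics pass."""
    updated = dict(jdata)
    updated["sel_updated"] = True  # marker so the effect is visible
    return updated


def prepare_config(jdata, skip_neighbor_stat=False):
    """Run the neighbor-stat update only when the caller has not opted out."""
    if skip_neighbor_stat:
        # Leave the user-provided selection untouched.
        return jdata
    return update_sel(jdata)
```

With this shape, downstream code receives the same kind of dictionary in both branches, so nothing else needs to know whether the update ran.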

In `@deepmd/jax/train/trainer.py`:
- Around line 94-95: The checkpoint currently reads current_step from
model_dict["`@variables`"] (used to set self.start_step) but the save path writes
it under model_def_script, causing restarts to reset; change the save or load to
use the same key unconditionally — either always write current_step into
model_dict["`@variables`"]["current_step"] when saving (the writer around
model_def_script in the saving block at/near the code that sets
model_def_script) or always read from model_def_script when loading (the code
that sets self.start_step); pick one canonical location (prefer `@variables` for
consistency) and update the corresponding save/read logic so both
serialize_from_file/load and the saver use the identical key name.

In `@source/tests/jax/test_training.py`:
- Around line 80-82: The test test_train_entrypoint_runs_one_step_from_scratch
needs a 60s timeout: add a pytest timeout marker (e.g. `@pytest.mark.timeout`(60))
directly above the test function (keeping the existing
`@patch`("deepmd.jax.entrypoints.train.SummaryPrinter.__call__") decorator), and
ensure pytest is imported in the module (add "import pytest" if missing); this
enforces a 60-second cap for the train entrypoint test.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: a39dade5-5b8f-47b2-8eb2-58ea8d7233c4

📥 Commits

Reviewing files that changed from the base of the PR and between f39a081 and 59611c8.

📒 Files selected for processing (7)
  • deepmd/backend/jax.py
  • deepmd/jax/entrypoints/__init__.py
  • deepmd/jax/entrypoints/main.py
  • deepmd/jax/entrypoints/train.py
  • deepmd/jax/train/__init__.py
  • deepmd/jax/train/trainer.py
  • source/tests/jax/test_training.py

Comment thread deepmd/jax/entrypoints/train.py Outdated
Comment thread deepmd/jax/entrypoints/train.py Outdated
Comment thread deepmd/jax/train/trainer.py Outdated
Comment thread source/tests/jax/test_training.py
@codecov

codecov Bot commented May 26, 2026

Codecov Report

❌ Patch coverage is 86.18619% with 46 lines in your changes missing coverage. Please review.
✅ Project coverage is 82.48%. Comparing base (f39a081) to head (a0d66c0).
⚠️ Report is 1 commit behind head on master.

| Files with missing lines | Patch % | Lines |
| --- | --- | --- |
| deepmd/jax/train/trainer.py | 88.98% | 25 Missing ⚠️ |
| deepmd/jax/entrypoints/train.py | 80.82% | 14 Missing ⚠️ |
| deepmd/jax/entrypoints/main.py | 73.68% | 5 Missing ⚠️ |
| deepmd/backend/jax.py | 0.00% | 2 Missing ⚠️ |
Additional details and impacted files
@@            Coverage Diff             @@
##           master    #5460      +/-   ##
==========================================
+ Coverage   82.46%   82.48%   +0.01%     
==========================================
  Files         829      833       +4     
  Lines       88763    89095     +332     
  Branches     4225     4226       +1     
==========================================
+ Hits        73197    73488     +291     
- Misses      14274    14315      +41     
  Partials     1292     1292              


@njzjz njzjz force-pushed the cleanup/parallel-no-hessian branch from 59611c8 to 2528142 Compare May 27, 2026 00:38
Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 2

♻️ Duplicate comments (1)
deepmd/jax/train/trainer.py (1)

89-92: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Keep current_step in a single checkpoint location.

Resume reads current_step from @variables, but save writes it into model_def_script. That makes restarted runs fall back to step 0, which also resets the LR schedule and checkpoint numbering.

Also applies to: 382-392
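
A minimal sketch of the single-location rule, assuming a simplified dictionary layout: the saver and the loader share one key under one parent, so a save followed by a load returns the same step. The function names and the `STEP_KEY` constant are illustrative, not names from `trainer.py`.

```python
STEP_KEY = "current_step"  # one canonical key, shared by writer and reader


def save_checkpoint_meta(model_def_script, step):
    """Return the metadata dict a saver would write (simplified layout)."""
    meta = dict(model_def_script)
    meta[STEP_KEY] = step
    return {"model_def_script": meta, "@variables": {}}


def load_start_step(model_dict):
    """Read the step from the same place the saver wrote it; default to 0."""
    return model_dict.get("model_def_script", {}).get(STEP_KEY, 0)


restored = load_start_step(save_checkpoint_meta({"type": "ener"}, 100))
```

A round-trip check of this kind is cheap to add to a test and would catch the reader and writer drifting apart again.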

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@deepmd/jax/train/trainer.py` around lines 89 - 92, The resume logic reads
current_step from model_dict["`@variables`"] but checkpoint saving stores it under
model_dict["model_def_script"], causing restarts to reset to 0; update the
restart handling (the branch using serialize_from_file and
BaseModel.deserialize) to obtain current_step from the same place the saver
writes (check model_dict["model_def_script"].get("current_step") first, then
fall back to model_dict["`@variables`"].get("current_step", 0)) and assign that
value to self.start_step; apply the same change to the other restart/resume
block around the 382-392 region so both resume paths use a single unified source
for current_step.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@deepmd/jax/train/trainer.py`:
- Around line 285-286: The code currently opens self.disp_file with mode "w"
which truncates the learning-curve on resume; change the open to use "a"
(append) when self.start_step > 0 (keep "w" only for fresh runs), and ensure the
header is written only when starting from step 0 or when the file is newly
created/empty; update the block around disp_file_fp = open(self.disp_file, "w")
and the similar block around the later lines (the routine that writes the disp
header and per-step rows) to check self.start_step and/or file emptiness before
printing the header so resumed runs append new rows without losing prior data.
- Around line 142-145: The validation batch count from configuration isn't
honored: when tr_data contains "validation_data" the code sets
self.valid_numb_batch only in the if-branch but the else-branch (and other
validation code paths around valid_numb_batch usage) ends up forcing one batch;
update the logic that reads tr_data["validation_data"].get("numb_btch", 1) so
that self.valid_numb_batch is assigned whenever validation_data is present and
ensure all validation paths (including the block around lines ~332-358) use
self.valid_numb_batch rather than hardcoding 1; specifically, inspect and update
the assignments and usages of self.valid_numb_batch in trainer.py so validation
honors validation_data["numb_btch"] uniformly.
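
The learning-curve comment above (append on resume, header only for a new or empty file) could look roughly like this. It is a sketch under assumptions: `open_lcurve` is a hypothetical helper and the header columns are placeholders rather than the real `lcurve.out` layout.

```python
import os


def open_lcurve(path, start_step):
    """Open the learning-curve file without discarding rows from earlier runs."""
    resuming = start_step > 0
    mode = "a" if resuming else "w"
    # Decide about the header before opening, since "a" creates the file.
    need_header = (
        not resuming or not os.path.exists(path) or os.path.getsize(path) == 0
    )
    fp = open(path, mode)
    if need_header:
        fp.write("# step loss lr\n")  # placeholder columns
    return fp
```

The caller still owns the returned handle and is responsible for closing it.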


ℹ️ Review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: d2c042f3-0e5a-4549-b6cc-5cc4852485e1

📥 Commits

Reviewing files that changed from the base of the PR and between 59611c8 and 2528142.

📒 Files selected for processing (7)
  • deepmd/backend/jax.py
  • deepmd/jax/entrypoints/__init__.py
  • deepmd/jax/entrypoints/main.py
  • deepmd/jax/entrypoints/train.py
  • deepmd/jax/train/__init__.py
  • deepmd/jax/train/trainer.py
  • source/tests/jax/test_training.py
✅ Files skipped from review due to trivial changes (2)
  • deepmd/jax/entrypoints/__init__.py
  • deepmd/jax/train/__init__.py
🚧 Files skipped from review as they are similar to previous changes (4)
  • deepmd/backend/jax.py
  • deepmd/jax/entrypoints/main.py
  • source/tests/jax/test_training.py
  • deepmd/jax/entrypoints/train.py

Comment thread deepmd/jax/train/trainer.py
Comment thread deepmd/jax/train/trainer.py Outdated
Contributor

Copilot AI left a comment


Pull request overview

This PR adds a local-only JAX training entrypoint to DeePMD-Kit, wiring it into the backend entry-point hook so dp --jax train ... can run end-to-end training (including checkpointing and export) without the parallel/distributed features from the prior parallel branch.

Changes:

  • Add a JAX backend CLI entry module and hook it up via deepmd/backend/jax.py so JAX can act as an ENTRY_POINT backend.
  • Introduce a local JAX trainer implementing initialization, stat computation, training/validation loop, learning-curve output, and Orbax checkpoint saving.
  • Add an end-to-end regression test that runs a 1-step JAX training workflow and verifies expected artifacts.

Reviewed changes

Copilot reviewed 7 out of 7 changed files in this pull request and generated 7 comments.

| File | Description |
| --- | --- |
| source/tests/jax/test_training.py | New end-to-end regression test for local JAX training entrypoint outputs. |
| deepmd/jax/train/trainer.py | New local JAX trainer with loss wiring, input prep (ghosts/nlist), training loop, lcurve logging, and Orbax checkpoints. |
| deepmd/jax/train/__init__.py | Package init for JAX training module. |
| deepmd/jax/entrypoints/train.py | New JAX train entrypoint (input loading/normalization, data init, training invocation). |
| deepmd/jax/entrypoints/main.py | JAX backend command dispatcher for train/freeze. |
| deepmd/jax/entrypoints/__init__.py | Package init for JAX entrypoints. |
| deepmd/backend/jax.py | Implements entry_point_hook to route CLI execution to the new JAX entrypoints. |


Comment thread deepmd/jax/entrypoints/train.py
Comment thread deepmd/jax/entrypoints/train.py Outdated
Comment thread deepmd/jax/entrypoints/train.py Outdated
Comment thread deepmd/jax/train/trainer.py Outdated
Comment thread deepmd/jax/train/trainer.py Outdated
Comment thread deepmd/jax/train/trainer.py Outdated
Comment thread deepmd/jax/train/trainer.py
@njzjz njzjz force-pushed the cleanup/parallel-no-hessian branch from 2528142 to 12e5a77 Compare May 27, 2026 01:50
Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 1

🧹 Nitpick comments (1)
deepmd/jax/train/trainer.py (1)

289-289: 💤 Low value

File handle not protected against exceptions.

The display file is opened at line 289 and closed at line 415, but if an exception occurs during training, the file won't be closed. Consider using a context manager or try/finally.

💡 Suggested approach
-        disp_file_fp = open(disp_path, disp_mode)
-        for step in range(self.start_step, self.num_steps):
-            ...
-
-        disp_file_fp.close()
+        with open(disp_path, disp_mode) as disp_file_fp:
+            for step in range(self.start_step, self.num_steps):
+                ...

Alternatively, wrap the training loop in try/finally to ensure disp_file_fp.close() is called.
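
A minimal sketch of the try/finally alternative, assuming a hypothetical `run_training` function and a `train_step` callable that returns a scalar loss (neither is taken from `trainer.py`). The point is only that the handle is closed, and buffered rows are flushed, even when a step raises.

```python
def run_training(disp_path, disp_mode, start_step, num_steps, train_step):
    """Write one row per step; always close the file, even on failure."""
    disp_file_fp = open(disp_path, disp_mode)
    try:
        for step in range(start_step, num_steps):
            loss = train_step(step)  # may raise
            disp_file_fp.write(f"{step + 1} {loss:.6e}\n")
    finally:
        # Runs on normal completion and on exceptions alike.
        disp_file_fp.close()
```

The `with` form from the suggested diff is equivalent; try/finally just avoids re-indenting the whole loop body.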

Also applies to: 415-415

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@deepmd/jax/train/trainer.py` at line 289, The file handle opened with
disp_file_fp = open(disp_path, disp_mode) must be guaranteed to close on
exceptions; wrap the use of disp_file_fp around the training loop in a context
manager (with open(disp_path, disp_mode) as disp_file_fp:) or surround the
training loop with try/finally and call disp_file_fp.close() in finally; update
any references to disp_file_fp inside the loop accordingly so the handle is
scoped correctly and ensure the close is executed even if an exception is
raised.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@deepmd/jax/entrypoints/train.py`:
- Line 82: The parameter init_frz_model is annotated as str but the code (and
tests) pass None; update the signature where init_frz_model is declared (the
function in deepmd/jax/entrypoints/train.py) to accept None by changing its type
to Optional[str] and ensure Optional is imported from typing (or use Union[str,
None]); keep the default value None if present and leave the runtime check if
init_frz_model: unchanged.


ℹ️ Review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: 9ca529ed-c612-4b32-9c2b-47064226d987

📥 Commits

Reviewing files that changed from the base of the PR and between 2528142 and 12e5a77.

📒 Files selected for processing (7)
  • deepmd/backend/jax.py
  • deepmd/jax/entrypoints/__init__.py
  • deepmd/jax/entrypoints/main.py
  • deepmd/jax/entrypoints/train.py
  • deepmd/jax/train/__init__.py
  • deepmd/jax/train/trainer.py
  • source/tests/jax/test_training.py
✅ Files skipped from review due to trivial changes (2)
  • deepmd/jax/entrypoints/__init__.py
  • deepmd/jax/train/__init__.py
🚧 Files skipped from review as they are similar to previous changes (2)
  • deepmd/backend/jax.py
  • deepmd/jax/entrypoints/main.py

Comment thread deepmd/jax/entrypoints/train.py Outdated
Comment thread deepmd/jax/train/trainer.py Fixed
@njzjz njzjz force-pushed the cleanup/parallel-no-hessian branch from 12e5a77 to fce455e Compare May 27, 2026 02:19
@njzjz njzjz changed the title from "feat(jax): add local training entrypoint" to "feat(jax): add training" May 27, 2026
@njzjz njzjz mentioned this pull request May 27, 2026
Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 1

🧹 Nitpick comments (1)
deepmd/jax/entrypoints/train.py (1)

103-104: 💤 Low value

Docstring does not reflect the None option for init_frz_model.

The function signature correctly types init_frz_model as str | None, but the docstring says init_frz_model : str. Update for consistency.

📝 Suggested fix
-    init_frz_model : str
+    init_frz_model : str | None
         path to frozen model or None
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@deepmd/jax/entrypoints/train.py` around lines 103 - 104, The docstring for
the parameter init_frz_model is incorrect: change the parameter description for
init_frz_model in the train() (or the top-level function where init_frz_model is
declared) docstring from "init_frz_model : str" to reflect the optional type
(e.g., "init_frz_model : str | None" or "init_frz_model : Optional[str]") and
update the description text to mention that None means no frozen model will be
used; keep the parameter name init_frz_model to locate the entry and ensure the
docstring formatting matches the surrounding style.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@deepmd/jax/train/trainer.py`:
- Around line 336-373: The validation aggregation can divide by zero when
self.valid_numb_batch is 0; clamp the number of validation iterations before the
loop (e.g., num_valid_batches = max(1, self.valid_numb_batch)) and iterate that
instead of self.valid_numb_batch, then compute valid_more_loss as before; update
the loop that builds valid_more_loss_list and the subsequent averaging so it
uses num_valid_batches, referencing the existing names valid_data,
valid_more_loss_list, prepare_input, loss_fn_more_loss, and valid_more_loss.
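
The clamp described above can be sketched as below. `average_validation_loss` and `get_batch_loss` are hypothetical names; the real loop also prepares inputs and calls the loss function, which is omitted here.

```python
def average_validation_loss(valid_numb_batch, get_batch_loss):
    """Average per-batch loss dicts, evaluating at least one batch."""
    # Clamp so a configured value of 0 cannot cause a division by zero.
    num_valid_batches = max(1, valid_numb_batch)
    losses = [get_batch_loss() for _ in range(num_valid_batches)]
    return {
        key: sum(batch[key] for batch in losses) / num_valid_batches
        for key in losses[0]
    }
```

An alternative design would be to reject a non-positive batch count when the configuration is parsed, which surfaces the mistake earlier instead of silently evaluating one batch.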


ℹ️ Review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: ce9bf732-55aa-4d56-874e-d257b75e0a5b

📥 Commits

Reviewing files that changed from the base of the PR and between 12e5a77 and fce455e.

📒 Files selected for processing (7)
  • deepmd/backend/jax.py
  • deepmd/jax/entrypoints/__init__.py
  • deepmd/jax/entrypoints/main.py
  • deepmd/jax/entrypoints/train.py
  • deepmd/jax/train/__init__.py
  • deepmd/jax/train/trainer.py
  • source/tests/jax/test_training.py
✅ Files skipped from review due to trivial changes (2)
  • deepmd/jax/entrypoints/__init__.py
  • deepmd/jax/train/__init__.py
🚧 Files skipped from review as they are similar to previous changes (3)
  • deepmd/backend/jax.py
  • source/tests/jax/test_training.py
  • deepmd/jax/entrypoints/main.py

Comment thread deepmd/jax/train/trainer.py
Port the JAX training entrypoint from the parallel branch onto current master,
but keep it local-only by removing distributed, sharding, and Hessian hooks.
Use the current dpmodel compute_or_load_stat data-stat path and add regression
coverage for the cleanup constraints.

Authored by OpenClaw (model: custom-chat-jinzhezeng-group/gpt-5.5)
@njzjz njzjz force-pushed the cleanup/parallel-no-hessian branch from fce455e to c4b12bc Compare May 27, 2026 02:28
Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 2

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@deepmd/jax/train/trainer.py`:
- Around line 534-543: The non-periodic branch leaves cc in its original
(possibly flattened) layout while the periodic branch reshapes to (nframes,
nloc, 3), causing extend_coord_with_ghosts(...) to receive the wrong shape;
update the else branch so coord_normalized = cc.copy() is reshaped to
cc.reshape(nframes, nloc, 3) (or cc.copy().reshape(...)) before calling
extend_coord_with_ghosts, keeping the same variables (cc, bb, normalize_coord,
coord_normalized, extend_coord_with_ghosts, atype, rcut, mapping).
- Around line 393-418: The training loop only saves when (step+1) %
self.save_freq == 0 so runs with total steps < save_freq produce no checkpoint;
extract the existing save logic (the block using nnx.split(model), creating
ckpt_path, ocp.Checkpointer with ocp.args.Composite/state/model_def_script,
log.info, _link_checkpoint and self._cleanup_old_checkpoints plus writing
"checkpoint") into a helper method (e.g., _save_checkpoint(state, step)) and
call it both inside the periodic save branch and once after training completes
(ensuring model_def_script["current_step"]=step+1 is set); this guarantees the
final checkpoint and checkpoint pointer are always written even if the final
step is not on a save_freq boundary.
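
The second comment above (always writing a final checkpoint) can be sketched with the save logic reduced to a callback. `train_loop` and `save_checkpoint` are hypothetical names; the Orbax calls and the checkpoint pointer handling are deliberately left out.

```python
def train_loop(num_steps, save_freq, save_checkpoint, start_step=0):
    """Save periodically, then once more if the last step was not saved."""
    last_saved = None
    for step in range(start_step, num_steps):
        # ... one optimizer update would run here ...
        if (step + 1) % save_freq == 0:
            save_checkpoint(step + 1)
            last_saved = step + 1
    # Cover runs that end between save intervals (e.g. num_steps < save_freq).
    if num_steps > start_step and last_saved != num_steps:
        save_checkpoint(num_steps)
```

Tracking `last_saved` avoids writing the same step twice when the final step already falls on a `save_freq` boundary.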

ℹ️ Review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: 62a9d9f7-d097-4375-9a54-8ee8d5e8809e

📥 Commits

Reviewing files that changed from the base of the PR and between fce455e and c4b12bc.

📒 Files selected for processing (7)
  • deepmd/backend/jax.py
  • deepmd/jax/entrypoints/__init__.py
  • deepmd/jax/entrypoints/main.py
  • deepmd/jax/entrypoints/train.py
  • deepmd/jax/train/__init__.py
  • deepmd/jax/train/trainer.py
  • source/tests/jax/test_training.py
✅ Files skipped from review due to trivial changes (2)
  • deepmd/jax/entrypoints/__init__.py
  • deepmd/jax/train/__init__.py
🚧 Files skipped from review as they are similar to previous changes (3)
  • deepmd/backend/jax.py
  • deepmd/jax/entrypoints/main.py
  • source/tests/jax/test_training.py

Comment thread deepmd/jax/train/trainer.py Outdated
Comment thread deepmd/jax/train/trainer.py
Save the final checkpoint even when the last step is not on a save interval and normalize non-periodic coordinates to the expected 3D layout before ghost extension.

Authored by OpenClaw (model: custom-chat-jinzhezeng-group/gpt-5.5)
@njzjz njzjz requested a review from wanghan-iapcm May 27, 2026 04:14
Collaborator

@wanghan-iapcm wanghan-iapcm left a comment


Superseded by inline review comments below.

Collaborator

@wanghan-iapcm wanghan-iapcm left a comment


Code review: 2 issues, posted as inline comments below.

Comment thread deepmd/jax/entrypoints/main.py
Comment thread deepmd/jax/train/trainer.py Outdated
njzjz-bot added 2 commits May 27, 2026 12:03
Add the missing JAX freeze module without Hessian support and cover the CLI dispatch path so importing the JAX backend main entry point no longer fails. Also pass the true atom count to the energy loss instead of the flattened coordinate width.

Authored by OpenClaw (model: custom-chat-jinzhezeng-group/gpt-5.5)
Apply the import ordering change required by pre-commit.ci.

Authored by OpenClaw (model: custom-chat-jinzhezeng-group/gpt-5.5)
@njzjz njzjz requested a review from wanghan-iapcm May 27, 2026 14:23
@wanghan-iapcm wanghan-iapcm added this pull request to the merge queue May 27, 2026
@github-merge-queue github-merge-queue Bot removed this pull request from the merge queue due to failed status checks May 27, 2026