feat(tf): support training stat_file#5551
Conversation
Allow TensorFlow training to accept training/stat_file and reuse saved energy statistics in the same type-map directory layout as PyTorch. This ports the useful part of PR deepmodeling#4926 onto current master and keeps TensorFlow's 1-D fitting bias shape internally. Authored by OpenClaw (model: custom-chat-jinzhezeng-group/gpt-5.5)
for more information, see https://pre-commit.ci
|
Note Reviews pausedIt looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the Use the following commands to manage reviews:
Use the checkboxes below for quick actions:
📝 WalkthroughWalkthroughAdds ChangesTF stat_file feature
Sequence Diagram(s)sequenceDiagram
participant Entrypoint as train.py _do_work
participant DPTrainer as DPTrainer.build
participant EnerModel as EnerModel._compute_output_stat
participant StatUtil as compute_output_stats
participant FS as DPPath / Filesystem
Entrypoint->>FS: Read training.stat_file from jdata
Entrypoint->>FS: Chief rank creates empty HDF5 or directory
Entrypoint->>Entrypoint: Wrap as DPPath(stat_file_raw, "a")
Entrypoint->>DPTrainer: model.build(..., stat_file_path=DPPath)
DPTrainer->>EnerModel: model.data_stat(data, stat_file_path=DPPath)
EnerModel->>EnerModel: _pack_stat_batches → samples
EnerModel->>FS: _save_observed_types_to_file (restore or compute+save)
EnerModel->>StatUtil: compute_output_stats(samples, stat_file_path, keys=["energy"])
StatUtil->>FS: Restore bias/std arrays if they exist
alt Files missing
StatUtil->>FS: Compute and save bias/std arrays
end
StatUtil-->>EnerModel: bias_out["energy"]
EnerModel->>EnerModel: fitting.bias_atom_e ← flattened bias
Estimated code review effort🎯 4 (Complex) | ⏱️ ~60 minutes Possibly related PRs
Suggested reviewers
🚥 Pre-merge checks | ✅ 4 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
Actionable comments posted: 7
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@deepmd/tf/entrypoints/train.py`:
- Around line 247-252: The code at lines 247-252 fails for nested paths because
parent directories are not created before attempting to initialize the stat file
target. Before the conditional block that checks if stat_file_raw exists, ensure
all parent directories are created using
Path(stat_file_raw).parent.mkdir(parents=True, exist_ok=True). This will create
the full directory hierarchy needed before attempting either the h5py.File()
call for HDF5 files or the Path(stat_file_raw).mkdir() call for directory
creation. The parents=True parameter ensures intermediate directories are
created, and exist_ok=True prevents errors if directories already exist.
In `@deepmd/tf/model/linear.py`:
- Around line 98-103: The data_stat method passes the same stat_file_path to all
child models in the loop, which causes later submodels to incorrectly reuse
previously saved statistics from earlier submodels. Modify the data_stat method
to namespace the stat_file_path for each model by appending a unique identifier
(such as the model index or model name) to the original path before passing it
to model.data_stat(). This ensures each model in self.models writes to and reads
from its own separate stat file.
In `@deepmd/tf/model/pairwise_dprc.py`:
- Around line 325-327: The data_stat method in the pairwise_dprc.py file passes
the same stat_file_path to both qm_model.data_stat() and qmmm_model.data_stat()
calls, which can cause one model to load the other's statistics instead of
computing its own. Modify the stat_file_path parameter for each call to use
separate namespaces or identifiers that distinguish between QM and QMMM models,
ensuring each model loads or computes its own statistics independently.
In `@deepmd/tf/utils/stat.py`:
- Around line 16-19: The functions `_restore_from_file` and
`compute_output_stats` use a mutable default argument `keys=["energy"]` which
triggers Ruff B006 violations. Replace the mutable default list with `keys=None`
in both function signatures, then at the beginning of each function body, add a
check to initialize `keys` to `["energy"]` if it is `None` (e.g., using `if keys
is None: keys = ["energy"]`). Since the parameter is only read and never
mutated, this change is safe and will resolve the Ruff violations.
In `@source/tests/consistent/test_stat_file.py`:
- Around line 115-117: The subprocess.run() call in the test does not include a
timeout parameter, which can cause the test to hang indefinitely if the CLI
command stalls. Add a timeout parameter to the subprocess.run() call to ensure
the test completes within a reasonable timeframe and fails gracefully if the
subprocess takes too long to complete.
- Around line 212-220: The tearDown method is deleting files from the current
working directory without proper scoping, which can remove artifacts created by
other tests or processes. Create a test-specific temporary directory in the
setUp method that is unique to this test instance, configure the test to write
its outputs to this directory, and modify the tearDown method to only clean up
files within this test-specific directory. This ensures that the cleanup in
tearDown is isolated and does not interfere with shared cwd artifacts or other
concurrent tests.
In `@source/tests/tf/test_stat_file_integration.py`:
- Around line 99-103: The current assertion in the stat_path validation block is
conditional on stat_path.exists(), which means if the stat_file is never
created, the assertion is skipped entirely and the test passes without detecting
the failure. Remove the if stat_path.exists() condition and instead assert
unconditionally that the stat_path both exists and is a directory. This ensures
that when the training.stat_file creation is broken, the test will properly fail
rather than silently passing.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
Run ID: 23d2b3aa-2fce-44a0-8c17-e0bda3504fd0
📒 Files selected for processing (14)
deepmd/tf/entrypoints/train.pydeepmd/tf/model/dos.pydeepmd/tf/model/ener.pydeepmd/tf/model/frozen.pydeepmd/tf/model/linear.pydeepmd/tf/model/model.pydeepmd/tf/model/pairwise_dprc.pydeepmd/tf/model/tensor.pydeepmd/tf/train/trainer.pydeepmd/tf/utils/stat.pydeepmd/utils/argcheck.pysource/tests/consistent/test_stat_file.pysource/tests/tf/test_stat_file.pysource/tests/tf/test_stat_file_integration.py
Codecov Report❌ Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## master #5551 +/- ##
==========================================
- Coverage 82.23% 82.22% -0.01%
==========================================
Files 894 902 +8
Lines 102002 103753 +1751
Branches 4276 4433 +157
==========================================
+ Hits 83877 85316 +1439
- Misses 16823 17043 +220
- Partials 1302 1394 +92 ☔ View full report in Codecov by Harness. 🚀 New features to boost your workflow:
|
Persist observed_type for TensorFlow stat files and normalize the stat-file test input before calling the lower-level training helper. Also broadcast the global output std to match the shared statistic logic. Authored by OpenClaw (model: custom-chat-jinzhezeng-group/gpt-5.5)
There was a problem hiding this comment.
🧹 Nitpick comments (1)
source/tests/tf/test_stat_file_integration.py (1)
30-31: ⚡ Quick winAlign test intent with assertions.
test_stat_file_save_and_loadcurrently validates config acceptance and directory creation, but not a load path. Either rename the test/docstring to match current scope or add an assertion that exercises restore semantics.Suggested minimal rename
- def test_stat_file_save_and_load(self) -> None: - """Test that stat_file can be saved and loaded in TF training.""" + def test_stat_file_path_is_accepted_and_created(self) -> None: + """Test that TF training accepts training.stat_file and creates its directory."""Also applies to: 107-110
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@source/tests/tf/test_stat_file_integration.py` around lines 30 - 31, The test method `test_stat_file_save_and_load` has a docstring claiming it validates both save and load functionality, but the current implementation only tests config acceptance and directory creation without actually exercising the load/restore path. Either rename the test method and its docstring to accurately reflect the current scope (e.g., something like test_stat_file_config_acceptance or test_stat_file_save), or add assertions that verify the restore semantics by actually loading and validating the saved stat_file data after the save operation.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Nitpick comments:
In `@source/tests/tf/test_stat_file_integration.py`:
- Around line 30-31: The test method `test_stat_file_save_and_load` has a
docstring claiming it validates both save and load functionality, but the
current implementation only tests config acceptance and directory creation
without actually exercising the load/restore path. Either rename the test method
and its docstring to accurately reflect the current scope (e.g., something like
test_stat_file_config_acceptance or test_stat_file_save), or add assertions that
verify the restore semantics by actually loading and validating the saved
stat_file data after the save operation.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
Run ID: 3d3940c2-03e6-4c88-b0b0-4cd8fded6d59
📒 Files selected for processing (3)
source/tests/consistent/test_stat_file.pysource/tests/tf/test_stat_file.pysource/tests/tf/test_stat_file_integration.py
🚧 Files skipped from review as they are similar to previous changes (2)
- source/tests/tf/test_stat_file.py
- source/tests/consistent/test_stat_file.py
Remove the TensorFlow-specific copy of stat-file helpers and call the dpmodel backend-agnostic implementation after packing TF batches into normalized samples. Authored by OpenClaw (model: custom-chat-jinzhezeng-group/gpt-5.5)
There was a problem hiding this comment.
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@deepmd/tf/model/ener.py`:
- Around line 236-250: The call to _save_observed_types_to_file at line 239
passes self.type_map which can be None, causing a TypeError in
collect_observed_types when it attempts to index the type_map parameter. Add a
guard condition to check that self.type_map is not None before calling
_save_observed_types_to_file, similar to the existing guard for
self.fitting.atom_ener. Only execute the _save_observed_types_to_file call when
type_map is available to prevent the indexing error in downstream code.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
Run ID: 75848502-2d2d-4977-81bd-6239a3c2e2c6
📒 Files selected for processing (1)
deepmd/tf/model/ener.py
| # Reuse the backend-agnostic dpmodel stat implementation instead of | ||
| # maintaining a TensorFlow copy of the same save/load/stat logic. | ||
| sampled = _pack_stat_batches(all_stat) | ||
| _save_observed_types_to_file(stat_file_path, sampled, self.type_map) | ||
| preset_bias = None | ||
| if len(self.fitting.atom_ener) > 0: | ||
| preset_bias = {"energy": self.fitting.atom_ener_v} | ||
| bias_out, _ = compute_output_stats( | ||
| sampled, | ||
| self.ntypes, | ||
| keys=["energy"], | ||
| stat_file_path=stat_file_path, | ||
| rcond=getattr(self.fitting, "rcond", None), | ||
| preset_bias=preset_bias, | ||
| ) |
There was a problem hiding this comment.
Routing through compute_output_stats here changes how the TF energy bias is computed, unconditionally (it runs even when stat_file_path is None). The historical TF path (EnerFitting._compute_output_stats, deepmd/tf/fit/ener.py) reduced one averaged energy per system and regressed nsys points; _pack_stat_batches concatenates every frame, so compute_output_stats regresses per frame — the same objective weighted by each system's frame count. The two coincide only when all systems have equal frame counts (and the mixed_type per-system vs per-frame real_natoms_vec averaging differs too). So bias_atom_e initialization now differs from old TF for datasets with unequal per-system frame counts.
This unification is actually required for the cross-backend stat file to be meaningful — a PT-written bias must match what TF computes — so I think switching to the dpmodel path is the right call (and it's only an initialization that training refines). But it is a silent change on the no-stat-file path for existing TF users. Suggest: (1) call it out in the PR description / changelog, and (2) add an unequal-frame (ideally also mixed_type) case to source/tests/consistent/test_stat_file.py — the current consistency test uses a single / equal-frame system, where per-frame and per-system regression are trivially identical, so it can't actually prove the TF/PT bias equality where it could diverge.
There was a problem hiding this comment.
Thanks, this is a good catch. Addressed in 6110c89, with a follow-up test-scope fix in 0419763 after CI exposed that std_atom_energy is backend-dependent auxiliary data:
- Added an explicit PR-description note that TF energy-bias initialization now intentionally uses the shared dpmodel/PyTorch per-frame regression path even when
training.stat_fileis unset. This may differ from legacy TF's per-system weighting for unequal-frame systems, but it keeps freshly computed TF stats consistent with restored cross-backend stat files. - Added a code comment at the TF energy-model call site documenting the same intentional behavior.
- Added
test_stat_file_consistency_unequal_frame_systems, usingwater/data_0+water/data_1, and now assert TF/PT equality forbias_atom_energyspecifically. That is the stat-file value consumed for the restored energy bias;std_atom_energycan differ slightly between backends and is not used by the TF fitting initialization/reload path.
Local validation: python3 -m py_compile source/tests/consistent/test_stat_file.py, uvx ruff check source/tests/consistent/test_stat_file.py, uvx ruff format --check source/tests/consistent/test_stat_file.py, and git diff --check passed. The full unittest target is still not runnable in this unbuilt checkout because generated package metadata (deepmd._version / deepmd.__about__) is missing.
— OpenClaw 2026.6.8 (844f405), model: custom-chat-jinzhezeng-group/gpt-5.5
| stat_dict = _load_se_input_stats(stat_path, descrpt.get_ntypes(), angular) | ||
| if stat_dict is not None: | ||
| return stat_dict | ||
|
|
||
| stat_dict = compute() | ||
| _save_se_input_stats(stat_path, stat_dict, descrpt.get_ntypes(), angular) | ||
| return stat_dict |
There was a problem hiding this comment.
This load branch (if stat_dict is not None: return stat_dict) is never exercised by a test. Every TF stat-file test only asserts that files were created; none reloads descriptor stats from an existing file and checks they are used (and bit-match a fresh compute). The fitting side does test its load path via the raise_error compute callback in test_fitting_stat.py — the descriptor side should have the equivalent. Also untested: the se_r (angular=False) and se_atten (mixed_types=True) save/load paths — only se_a is covered, so the angular=False round-trip and the mixed_types hash branch never run. A reload-and-compare test (per descriptor type) would close the gap.
There was a problem hiding this comment.
Thanks, addressed in 6110c89.
I added descriptor-stat reload coverage in TestDescriptorStatFile:
test_se_a_descriptor_stats_reload_from_filecovers the angular save/load path.test_se_r_descriptor_stats_reload_from_filecovers the radial-only (angular=False) path.test_se_atten_descriptor_stats_reload_from_mixed_type_hashcovers the mixed-type hash branch used byse_atten.
Each test first saves stats, then calls load_or_compute_se_input_stats again with a compute callback that fails the test if invoked, so the load branch (if stat_dict is not None: return stat_dict) is now exercised directly.
Local validation: python3 -m py_compile on touched files, uvx ruff check, and uvx ruff format --check passed. I could not run the unittest target in this unbuilt checkout because generated package metadata (deepmd._version / deepmd.__about__) is missing.
— OpenClaw 2026.6.8 (844f405), model: custom-chat-jinzhezeng-group/gpt-5.5
Add unequal-frame TF/PT stat-file consistency coverage and descriptor stat-file reload tests for angular, radial-only, and mixed-type hash paths. Document the intentional shared per-frame bias initialization behavior in the TF energy model.\n\nAuthored by OpenClaw (model: custom-chat-jinzhezeng-group/gpt-5.5)
Compare only shared output-stat files in the unequal-frame TF/PT regression test. Descriptor input stats are produced by backend-specific pipelines and can differ for multi-system data, so the regression should target the energy-bias stat files that the fix shares across backends.\n\nAuthored by OpenClaw (model: custom-chat-jinzhezeng-group/gpt-5.5)
The unequal-frame consistency test is meant to cover the cross-backend energy-bias stat file. TF and PT can emit different auxiliary std_atom_energy values, so restrict the assertion to the bias consumed by stat-file reloads. Authored by OpenClaw (model: custom-chat-jinzhezeng-group/gpt-5.5)
Problem
training/stat_fileis accepted by the shared input schema but was effectively only wired for non-TF backends.master.Change
stat_fileplumbing onto currentmaster: create/openDPPath, pass it throughDPTrainer.build()andModel.data_stat(), and save/load energy statistics under the PyTorch-compatible type-map subdirectory.bias_atom_eas the historical 1-D vector while storing stat files in the cross-backend(ntypes, 1)format.training.stat_fileis not set. This can differ from legacy TF's per-system weighting for unequal-frame systems, but keeps freshly computed TF stats consistent with restored cross-backend stat files.Notes
stat_filein TF #4017.python3 -m py_compileon touched files,uvx ruff checkon touched files, anduvx ruff format --checkon touched files passed.deepmd._version/deepmd.__about__).Authored by OpenClaw (model: custom-chat-jinzhezeng-group/gpt-5.5)
Summary by CodeRabbit
New Features
training.stat_fileconfiguration option to save training statistics during training. Statistics can be saved to an HDF5 file or directory structure.Tests