perf(pt_expt): use inductor+dynamic for torch.compile training#5393
perf(pt_expt): use inductor+dynamic for torch.compile training#5393wanghan-iapcm wants to merge 11 commits intodeepmodeling:masterfrom
Conversation
Add max_autotune, epilogue_fusion, triton.cudagraphs, max_fusion_size options to match the reference implementation.
|
Note Reviews pausedIt looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the Use the following commands to manage reviews:
Use the checkboxes below for quick actions:
📝 WalkthroughWalkthroughTraining compilation now uses shape‑polymorphic symbolic FX tracing and fully dynamic Changes
Sequence Diagram(s)sequenceDiagram
participant Loader as DataLoader
participant Trainer as TrainingLoop
participant Tracer as make_fx (symbolic)
participant Compiler as torch.compile (Inductor, dynamic)
participant Runtime as CompiledModel
Loader->>Trainer: provide sample batch (varying nall)
Trainer->>Tracer: trace forward_lower (symbolic)
Tracer-->>Trainer: symbolic FX graph
Trainer->>Compiler: torch.compile(graph, backend=inductor, dynamic=True)
Compiler-->>Trainer: compiled callable
Trainer->>Runtime: wrap compiled callable
Loader->>Trainer: subsequent batches (varying nall)
Trainer->>Runtime: execute compiled forward per batch
Runtime-->>Trainer: outputs (loss, gradients handled by dynamic runtime)
Estimated code review effort🎯 3 (Moderate) | ⏱️ ~20 minutes 🚥 Pre-merge checks | ✅ 2 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
source/tests/pt_expt/test_training.py (1)
167-177:⚠️ Potential issue | 🟡 MinorAdd the required 60s timeout to this training test.
This new training-path test does not set the repository’s required timeout, so a compile regression can hang CI instead of failing fast.
As per coding guidelines,
**/tests/**/*training*.py: Set training test timeouts to 60 seconds maximum for validation purposes, as real training takes hours or days.🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@source/tests/pt_expt/test_training.py` around lines 167 - 177, This training test lacks the repository-required 60s timeout; add a 60-second timeout to test_compiled_handles_varying_nall (in class TestCompiledDynamicShapes) by decorating the test method with a timeout decorator (e.g., `@pytest.mark.timeout`(60)) and import pytest at the top if not present, so the test will fail fast instead of hanging CI.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@source/tests/pt_expt/test_training.py`:
- Around line 199-211: The loop needs to record and assert that the
dynamic-shape path ran by capturing distinct nall values returned by the model
and use the trainer's stepping helper instead of calling optimizer.step()
directly: inside the loop keep a set (e.g. observed_nall) and extract nall from
the wrapper's extra output (the _more_loss/aux dict returned by
trainer.wrapper(**inp, cur_lr=lr, label=lab)), add it to the set, call the
trainer's stepping helper (use Trainer._optimizer_step or the public
trainer.step helper instead of trainer.optimizer.step()), and after the loop
assert that len(observed_nall) >= 2 to prove at least two different nall values
were seen; also retain the existing finite-loss assertions.
---
Outside diff comments:
In `@source/tests/pt_expt/test_training.py`:
- Around line 167-177: This training test lacks the repository-required 60s
timeout; add a 60-second timeout to test_compiled_handles_varying_nall (in class
TestCompiledDynamicShapes) by decorating the test method with a timeout
decorator (e.g., `@pytest.mark.timeout`(60)) and import pytest at the top if not
present, so the test will fail fast instead of hanging CI.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
Run ID: 2f7348ce-2f4d-4c7f-b068-a1d65cf52d0b
📒 Files selected for processing (2)
deepmd/pt_expt/train/training.pysource/tests/pt_expt/test_training.py
There was a problem hiding this comment.
♻️ Duplicate comments (1)
deepmd/pt_expt/train/training.py (1)
312-312:⚠️ Potential issue | 🟠 MajorRemove or use the unused local at Line 312
actual_nallis assigned but never read. This has already been flagged by prior scanning and may fail lint/static-analysis gates.Suggested change
- actual_nall = ext_coord.shape[1] out: dict[str, torch.Tensor] = {}As per coding guidelines:
**/*.py: Install linter and runruff check .before committing changes or the CI will fail.🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@deepmd/pt_expt/train/training.py` at line 312, Remove the unused local variable assignment actual_nall = ext_coord.shape[1] (it is assigned but never read); either delete this line or replace its usage where intended—search for actual_nall or ext_coord in the surrounding function in training.py and, if a column count is needed, use ext_coord.shape[1] directly or assign to a used variable name so the value is consumed.
🧹 Nitpick comments (1)
deepmd/pt_expt/train/training.py (1)
221-225: Avoid mutating caller-ownedcompile_optsin place
pop()/setdefault()currently mutate the dict from training config. A local copy is safer and avoids side effects if the same config is reused.Suggested change
def _trace_and_compile( @@ - # Override backend and dynamic — the inductor backend with + # Work on a local copy to avoid mutating caller-owned config. + compile_opts = deepcopy(compile_opts) + + # Override backend and dynamic — the inductor backend with # dynamic=True handles varying shapes automatically. compile_opts.pop("dynamic", None) compile_opts.pop("backend", None)🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@deepmd/pt_expt/train/training.py` around lines 221 - 225, The code mutates the caller-owned dict compile_opts using pop() and by adding "options"; instead create a shallow copy (e.g., local_compile_opts = compile_opts.copy()) and operate on that copy, then use local_compile_opts.pop("dynamic", None), local_compile_opts.pop("backend", None) and ensure "options" exists on local_compile_opts before setting opts = local_compile_opts["options"]; leave the original compile_opts untouched so callers can reuse it.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Duplicate comments:
In `@deepmd/pt_expt/train/training.py`:
- Line 312: Remove the unused local variable assignment actual_nall =
ext_coord.shape[1] (it is assigned but never read); either delete this line or
replace its usage where intended—search for actual_nall or ext_coord in the
surrounding function in training.py and, if a column count is needed, use
ext_coord.shape[1] directly or assign to a used variable name so the value is
consumed.
---
Nitpick comments:
In `@deepmd/pt_expt/train/training.py`:
- Around line 221-225: The code mutates the caller-owned dict compile_opts using
pop() and by adding "options"; instead create a shallow copy (e.g.,
local_compile_opts = compile_opts.copy()) and operate on that copy, then use
local_compile_opts.pop("dynamic", None), local_compile_opts.pop("backend", None)
and ensure "options" exists on local_compile_opts before setting opts =
local_compile_opts["options"]; leave the original compile_opts untouched so
callers can reuse it.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
Run ID: 547d2b4a-1e06-47cc-9b70-259db8391305
📒 Files selected for processing (1)
deepmd/pt_expt/train/training.py
Codecov Report❌ Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## master #5393 +/- ##
==========================================
+ Coverage 80.33% 80.35% +0.01%
==========================================
Files 819 819
Lines 85356 85425 +69
Branches 4139 4139
==========================================
+ Hits 68571 68643 +72
+ Misses 15509 15506 -3
Partials 1276 1276 ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
e332640 to
d5fcd61
Compare
The env_mat_stat code calls descriptor.get_rcut_smth() and descriptor.get_env_protection() during compute_input_stats, but DescrptBlockRepformers only had the attributes without accessor methods, causing AttributeError when training DPA2 models.
d5fcd61 to
e48497c
Compare
Test get_rcut_smth() and get_env_protection() (default and custom values) on DescrptBlockRepformers.
Add silut/custom_silu activation to _torch_activation using native torch ops (torch.sigmoid, torch.tanh, torch.where) for make_fx compatibility. Add pt_expt unit tests and cross-backend consistency tests covering multiple thresholds with inputs spanning both silu and tanh branches.
tests/pt/__init__.py sets torch.set_default_device("cuda:9999999"),
which causes bare torch.tensor() calls to attempt CUDA init on
CPU-only CI. Use the pt_expt DEVICE (same pattern as the pt tests
use their own DEVICE via to_torch_tensor).
Remove unused `actual_nall` in training.py and unused `threshold` in test_activation.py.
Remove unused variables (actual_nall, threshold) flagged by CodeQL. Use trainer._optimizer_step() instead of trainer.optimizer.step() in the dynamic-shapes test to match the real training path.
…s test When make_fx traces with nframes=1, the symbolic tracer specialises the batch dimension to the concrete value 1, causing failures when later batches have nframes>1 (e.g. with batch_size: "auto" across systems of different atom counts). Fix by duplicating the trace sample to nframes=2 before tracing, forcing a symbolic int for the batch dimension. Add TestCompiledVaryingNatoms: compiled and uncompiled training with a 192-atom and a synthetic 6-atom system using batch_size: "auto", with dp_random.choice mocked to deterministically alternate between systems.
OutisLi
left a comment
There was a problem hiding this comment.
please add a test for common silu activation training perhaps
| fparam: torch.Tensor | None, | ||
| aparam: torch.Tensor | None, | ||
| ) -> dict[str, torch.Tensor]: | ||
| extended_coord = extended_coord.detach().requires_grad_(True) |
There was a problem hiding this comment.
should not detach the grad inside the function which will be traced, detach to a leaf node before tracing and used as compiled function
Avoid mutating caller-owned compile_opts dict in _trace_and_compile. Add compiled training test with silu activation.
solved by 7670e6e |
OutisLi
left a comment
There was a problem hiding this comment.
A few issues with the current implementation:
1. Indiscriminate detach removal is unsafe
_remove_detach_nodes removes every aten.detach.default node in the graph. Not all of them come from the autograd saved-tensor mechanism — user-explicit .detach() calls (e.g. gradient boundary management, virial computation) also appear as aten.detach.default and are semantically meaningful.
Although empirically removing all of them may not cause incorrect results in the current graph (because make_fx has already materialized the first-order backward, and the user-explicit detach only affects the already-frozen forward ops), this is fragile and does not preserve the original semantics. If the model structure changes in the future, blind removal could silently break user-intended gradient boundaries.
The two categories can be distinguished by graph topology alone, without hard-coding any op names. Autograd-inserted detach always forms a double-detach chain (forward_op → detach_A → detach_B → backward_use), while user-explicit detach is structurally isolated. Three rules suffice:
- Chain inner: input is another detach node.
- Dead node: no downstream users.
- Chain head: all users are detach nodes.
Anything else is user-explicit and should be preserved. Classification must happen on the original graph before any mutation (two-pass: collect, then remove), because removing chain-head nodes changes the input of chain-inner nodes, causing misclassification in a single pass.
Suggested implementation:
def _strip_saved_tensor_detach(gm: torch.fx.GraphModule) -> None:
_DETACH = torch.ops.aten.detach.default
def _is_detach(n: torch.fx.Node) -> bool:
return n.op == "call_function" and n.target == _DETACH
# Pass 1: classify on the unmodified graph
to_remove = []
for node in gm.graph.nodes:
if not _is_detach(node):
continue
inp = node.args[0]
users = list(node.users.keys())
is_chain_inner = _is_detach(inp)
is_dead = len(users) == 0
is_chain_head = len(users) > 0 and all(_is_detach(u) for u in users)
if is_chain_inner or is_dead or is_chain_head:
to_remove.append(node)
# Pass 2: remove after classification is complete
for node in to_remove:
node.replace_all_uses_with(node.args[0])
gm.graph.erase_node(node)
gm.graph.lint()
gm.recompile()2. Must guard with training mode — eval-time stripping triggers silu_backward derivative error
The stripping should only be applied when tracing in training mode (create_graph=True). In eval mode (create_graph=False), the saved-tensor detach nodes are correct and harmless. Removing them in eval mode causes inductor to attempt higher-order differentiation of ops like aten.silu_backward, which do not have registered higher-order derivatives in PyTorch. This is a pre-existing PyTorch limitation, not something the PR can fix, so the correct approach is:
traced = make_fx(compute_fn, ...)(*trace_args)
if self.training:
_strip_saved_tensor_detach(traced)
compiled = torch.compile(traced, ...)3. A unit test is needed, and it must use random parameter initialization
This bug is nearly invisible under default (zero or near-zero) initialization. At zero init, tanh(0) = 0 and sigmoid(0) = 0.5 produce saved tensors at special points where the second-order gradient contribution vanishes or is masked by numerical noise. In my testing, zero-init showed 0 parameter gradient mismatches; switching to random init immediately exposed ~80% relative error in 5 out of 54 parameters.
The test should:
- Build paired models (eager vs compiled) with shared random weights
- Compare force-loss (
(force**2).sum()) backward gradients (second-order) across all parameters - Use tight tolerance — the differences from this bug are O(1) relative, not floating-point noise
Minimal structure:
def test_compile_second_order_grad_matches_eager(self):
model_eager = build_model(use_compile=False)
# Random init is critical — zero init hides the bug
torch.manual_seed(42)
with torch.no_grad():
for p in model_eager.parameters():
p.copy_(torch.randn_like(p) * 0.1)
model_compiled = build_model(use_compile=True)
model_compiled.load_state_dict(model_eager.state_dict())
model_eager.train(); model_compiled.train()
# Forward match
out_e = model_eager(coord, atype, box=box)
out_c = model_compiled(coord, atype, box=box)
torch.testing.assert_close(out_e["force"], out_c["force"], ...)
# Second-order: force-loss backward
model_eager.zero_grad(); model_compiled.zero_grad()
(out_e["force"] ** 2).sum().backward()
(out_c["force"] ** 2).sum().backward()
for (ne, pe), (_, pc) in zip(
model_eager.named_parameters(),
model_compiled.named_parameters(),
):
ge = pe.grad if pe.grad is not None else torch.zeros_like(pe)
gc = pc.grad if pc.grad is not None else torch.zeros_like(pc)
torch.testing.assert_close(ge, gc, atol=1e-5, rtol=1e-5,
msg=f"Grad mismatch: {ne}")|
Solution for # Decompose silu_backward into primitive ops (sigmoid + mul + ...)
# so that inductor can compile the eval graph without requiring a
# higher-order derivative that PyTorch does not register for silu.
from torch._decomp import get_decompositions
decomp_table = get_decompositions([torch.ops.aten.silu_backward.default])
traced = make_fx(
compute_fn,
tracing_mode="symbolic",
_allow_non_fake_inputs=True,
decomposition_table=decomp_table,
)(
trace_coord,
trace_atype,
trace_edge_index,
trace_edge_vec,
trace_edge_mask,
trace_fp,
trace_ap,
) |
…modeling#5393 Add silut/custom_silu support to _torch_activation using native torch ops (sigmoid, tanh, where) so the custom silu stays traceable by make_fx / torch.export. Cross-backend consistency tests cover multiple thresholds across the silu/tanh branches, and a pt_expt unit file exercises default/custom threshold, gradient flow, make_fx, and torch.export. Also port DescrptBlockRepformers accessor tests (get_rcut_smth, get_env_protection). The underlying accessor methods already exist on this branch; these tests guard against regressions.
…deling#5393 Adds the remaining tests from PR deepmodeling#5393 that were not yet on this branch: ``test_training_loop_compiled_silu`` (silu activation under torch.compile) and ``TestCompiledVaryingNatoms`` (compiled training across systems with different atom counts). Also drops a stray unused ``threshold`` variable in ``test_silut_below_threshold_is_silu`` to match the upstream PR.
Summary
aot_eagerbackend + manual nall padding withinductorbackend +dynamic=Truefor training compilationmake_fx(tracing_mode="symbolic")instead oftracing_mode="real"to capture shape-polymorphic opsshape_padding=True,max_autotune=False,epilogue_fusion=False,triton.cudagraphs=False,max_fusion_size=8_CompiledModel._recompile, max_nall estimation from 20 sampled batches, etc.)silut/custom_siluactivation support to pt_expt_torch_activation(needed for DPA3)batch_size: "auto"with mixed-size systems)Speed Benchmark (V100 GPU)
DPA1 (se_atten_compressible: rcut=6, sel=120, fitting=[240,240,240], float64)
DPA2 (input_torch_small, float32, bs=1)
DPA3 (silut:10.0, static sel, float32, bs=1)
Convergence Benchmark (1000 steps, V100)
DPA1 (se_atten_compressible) — rmse_f_val
DPA2 — val loss / e_rmse / f_rmse
DPA3 (silut:10.0) — val loss / e_rmse / f_rmse
All models converge comparably. Variation is within normal run-to-run noise from random seeds/batch ordering.
Varying natoms + batch size
Compiled training with
batch_size: "auto"across systems of different atom counts (192-atom + 6-atom) works correctly. Bothnframesandnatomsvary across steps. Tested locally (se_e2_a) and on V100 (DPA3 with 192+64 atoms).Known limitations
use_dynamic_sel: truecannot be compiled (data-dependentint()inget_graph_index). Only static sel configs are supported.Test plan