perf(pt_expt): use inductor+dynamic for torch.compile training #5393

Draft
wanghan-iapcm wants to merge 11 commits into deepmodeling:master from wanghan-iapcm:feat-pt-expt-compile-dynamic

Conversation

@wanghan-iapcm
Collaborator

@wanghan-iapcm wanghan-iapcm commented Apr 11, 2026

Summary

  • Replace aot_eager backend + manual nall padding with inductor backend + dynamic=True for training compilation
  • Use make_fx(tracing_mode="symbolic") instead of tracing_mode="real" to capture shape-polymorphic ops
  • Inductor options: shape_padding=True, max_autotune=False, epilogue_fusion=False, triton.cudagraphs=False, max_fusion_size=8
  • Remove ~120 lines of manual padding/recompilation infrastructure (_CompiledModel._recompile, max_nall estimation from 20 sampled batches, etc.)
  • Add silut/custom_silu activation support to pt_expt _torch_activation (needed for DPA3)
  • Fix: duplicate trace sample to nframes=2 when nframes=1, so symbolic tracer creates a dynamic batch dimension (supports batch_size: "auto" with mixed-size systems)
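
Put together, the options above correspond to a torch.compile configuration of roughly this shape (a sketch only; the helper name is hypothetical and the actual wiring lives in deepmd/pt_expt/train/training.py):

```python
# Hypothetical helper mirroring the compile configuration listed above.
def inductor_compile_kwargs():
    return {
        "backend": "inductor",
        "dynamic": True,  # let inductor handle varying nframes/nall
        "options": {
            "shape_padding": True,
            "max_autotune": False,
            "epilogue_fusion": False,
            "triton.cudagraphs": False,
            "max_fusion_size": 8,
        },
    }

# compiled = torch.compile(traced_graph, **inductor_compile_kwargs())
```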

Speed Benchmark (V100 GPU)

DPA1 (se_atten_compressible: rcut=6, sel=120, fitting=[240,240,240], float64)

| Mode | bs=1 | bs=4 |
|---|---|---|
| Uncompiled | 21.8 ms | 42.9 ms |
| Old compiled (aot_eager) | 18.1 ms (1.20x) | 38.3 ms (1.12x) |
| New compiled (inductor) | 9.8 ms (2.22x) | 20.4 ms (2.10x) |

DPA2 (input_torch_small, float32, bs=1)

| Mode | time/step |
|---|---|
| Uncompiled | 63.1 ms |
| Compiled (inductor) | 22.1 ms (2.85x) |

DPA3 (silut:10.0, static sel, float32, bs=1)

| Mode | time/step |
|---|---|
| Uncompiled | 164.5 ms |
| Compiled (inductor) | 45.7 ms (3.60x) |

Convergence Benchmark (1000 steps, V100)

DPA1 (se_atten_compressible) — rmse_f_val

| step | Uncompiled | Compiled |
|---|---|---|
| 1 | 1.37 | 1.36 |
| 500 | 0.412 | 0.497 |
| 1000 | 0.291 | 0.316 |

DPA2 — val loss / e_rmse / f_rmse

| step | Uncompiled | Compiled |
|---|---|---|
| 1 | 24.5 / 1.96e-01 / 0.776 | 27.2 / 1.98e-01 / 0.861 |
| 500 | 20.8 / 2.45e-02 / 0.659 | 17.2 / 1.87e-01 / 0.544 |
| 1000 | 7.32 / 3.63e-02 / 0.411 | 10.3 / 3.65e-02 / 0.576 |

DPA3 (silut:10.0) — val loss / e_rmse / f_rmse

| step | Uncompiled | Compiled |
|---|---|---|
| 1 | 25.3 / 9.17e-02 / 0.800 | 22.4 / 9.45e-02 / 0.707 |
| 500 | 24.5 / 2.14e-02 / 0.776 | 24.0 / 1.66e-02 / 0.759 |
| 1000 | 13.1 / 1.13e-02 / 0.692 | 7.27 / 4.14e-03 / 0.384 |

All models converge comparably. Variation is within normal run-to-run noise from random seeds/batch ordering.

Varying natoms + batch size

Compiled training with batch_size: "auto" across systems of different atom counts (192-atom + 6-atom) works correctly. Both nframes and natoms vary across steps. Tested locally (se_e2_a) and on V100 (DPA3 with 192+64 atoms).
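
The mocked system selection used in these tests can be sketched with unittest.mock; the FakeRandom class below is a stand-in for illustration, not the actual dp_random API:

```python
from itertools import cycle
from unittest import mock

# FakeRandom is a stand-in for the real dp_random module; the PR's test
# patches dp_random.choice the same way to alternate between systems.
class FakeRandom:
    def choice(self, seq):
        return seq[0]

rng = FakeRandom()
picks = cycle([0, 1])  # deterministically alternate: 192-atom, 6-atom, ...
with mock.patch.object(FakeRandom, "choice",
                       side_effect=lambda seq: seq[next(picks)]):
    chosen = [rng.choice(["sys192", "sys6"]) for _ in range(4)]
# chosen == ["sys192", "sys6", "sys192", "sys6"]
```

Deterministic alternation guarantees that both nframes and natoms actually vary across compiled training steps, rather than depending on the random sampler happening to pick both systems.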

Known limitations

  • DPA3 with use_dynamic_sel: true cannot be compiled (data-dependent int() in get_graph_index). Only static sel configs are supported.
  • fparam/aparam with varying nframes is untested (the code handles it but no test exercises it).

Test plan

  • All 13 training tests pass locally (including 2 new varying-natoms tests)
  • silut unit tests (make_fx, torch.export, gradient, branch coverage)
  • Cross-backend consistency tests (dpmodel, pt, pt_expt) for all activations including silut variants
  • Varying natoms compiled training (192-atom + 6-atom, batch_size: "auto", mocked system selection)
  • CI passes

Comment thread deepmd/pt_expt/train/training.py Fixed
Add max_autotune, epilogue_fusion, triton.cudagraphs, max_fusion_size
options to match the reference implementation.
@coderabbitai
Contributor

coderabbitai bot commented Apr 11, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.


Walkthrough

Training compilation now uses shape‑polymorphic symbolic FX tracing and fully dynamic torch.compile(backend="inductor", dynamic=True). Manual max_nall padding, re-tracing/re-compilation logic, and padded inputs were removed; the compiled region accepts varying nall at runtime.

Changes

Cohort / File(s) Summary
Compilation Strategy Overhaul
deepmd/pt_expt/train/training.py
Replaced concrete-shape FX tracing and manual max_nall padding/recompilation with symbolic FX tracing via make_fx(..., tracing_mode="symbolic", _allow_non_fake_inputs=True) and torch.compile(..., backend="inductor", dynamic=True). Removed padded nlist/tracing and simplified compiled-wrapper and force scatter-add logic.
Tests — dynamic shape behavior
source/tests/pt_expt/test_training.py
Renamed TestCompiledRecompile → TestCompiledDynamicShapes and test_nall_growth_triggers_recompile() → test_compiled_handles_varying_nall(). Test now runs multiple training steps asserting finite losses for varying nall; removed explicit recompilation/state assertions and changed temp dir prefix.
Descriptor accessors added
deepmd/dpmodel/descriptor/repformers.py
Added get_rcut_smth(self) -> float and get_env_protection(self) -> float accessors to DescrptBlockRepformers.

Sequence Diagram(s)

sequenceDiagram
    participant Loader as DataLoader
    participant Trainer as TrainingLoop
    participant Tracer as make_fx (symbolic)
    participant Compiler as torch.compile (Inductor, dynamic)
    participant Runtime as CompiledModel

    Loader->>Trainer: provide sample batch (varying nall)
    Trainer->>Tracer: trace forward_lower (symbolic)
    Tracer-->>Trainer: symbolic FX graph
    Trainer->>Compiler: torch.compile(graph, backend=inductor, dynamic=True)
    Compiler-->>Trainer: compiled callable
    Trainer->>Runtime: wrap compiled callable
    Loader->>Trainer: subsequent batches (varying nall)
    Trainer->>Runtime: execute compiled forward per batch
    Runtime-->>Trainer: outputs (loss, gradients handled by dynamic runtime)

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 66.67% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (2 passed)
Check name Status Explanation
Title check ✅ Passed The pull request title accurately describes the main change: replacing aot_eager with inductor+dynamic for torch.compile training in the pt_expt module.
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.


Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
source/tests/pt_expt/test_training.py (1)

167-177: ⚠️ Potential issue | 🟡 Minor

Add the required 60s timeout to this training test.

This new training-path test does not set the repository’s required timeout, so a compile regression can hang CI instead of failing fast.

As per coding guidelines, **/tests/**/*training*.py: Set training test timeouts to 60 seconds maximum for validation purposes, as real training takes hours or days.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In source/tests/pt_expt/test_training.py around lines 167-177: this training test lacks the repository-required 60s timeout; add a 60-second timeout to test_compiled_handles_varying_nall (in class TestCompiledDynamicShapes) by decorating the test method with @pytest.mark.timeout(60) and importing pytest at the top if not present, so the test will fail fast instead of hanging CI.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In source/tests/pt_expt/test_training.py:
- Around line 199-211: The loop needs to record and assert that the
dynamic-shape path ran by capturing distinct nall values returned by the model
and use the trainer's stepping helper instead of calling optimizer.step()
directly: inside the loop keep a set (e.g. observed_nall) and extract nall from
the wrapper's extra output (the _more_loss/aux dict returned by
trainer.wrapper(**inp, cur_lr=lr, label=lab)), add it to the set, call the
trainer's stepping helper (use Trainer._optimizer_step or the public
trainer.step helper instead of trainer.optimizer.step()), and after the loop
assert that len(observed_nall) >= 2 to prove at least two different nall values
were seen; also retain the existing finite-loss assertions.


ℹ️ Review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: 2f7348ce-2f4d-4c7f-b068-a1d65cf52d0b

📥 Commits

Reviewing files that changed from the base of the PR and between baab3e8 and a183f95.

📒 Files selected for processing (2)
  • deepmd/pt_expt/train/training.py
  • source/tests/pt_expt/test_training.py

Comment thread source/tests/pt_expt/test_training.py
Contributor

@coderabbitai coderabbitai bot left a comment


♻️ Duplicate comments (1)
deepmd/pt_expt/train/training.py (1)

312-312: ⚠️ Potential issue | 🟠 Major

Remove or use the unused local at Line 312

actual_nall is assigned but never read. This has already been flagged by prior scanning and may fail lint/static-analysis gates.

Suggested change
-        actual_nall = ext_coord.shape[1]
         out: dict[str, torch.Tensor] = {}

As per coding guidelines: **/*.py: Install linter and run ruff check . before committing changes or the CI will fail.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In deepmd/pt_expt/train/training.py at line 312: remove the unused local variable assignment actual_nall = ext_coord.shape[1] (it is assigned but never read); either delete the line or, if a column count is needed there, use ext_coord.shape[1] directly or assign it to a name that is actually consumed.
🧹 Nitpick comments (1)
deepmd/pt_expt/train/training.py (1)

221-225: Avoid mutating caller-owned compile_opts in place

pop()/setdefault() currently mutate the dict from training config. A local copy is safer and avoids side effects if the same config is reused.

Suggested change
 def _trace_and_compile(
@@
-    # Override backend and dynamic — the inductor backend with
+    # Work on a local copy to avoid mutating caller-owned config.
+    compile_opts = deepcopy(compile_opts)
+
+    # Override backend and dynamic — the inductor backend with
     # dynamic=True handles varying shapes automatically.
     compile_opts.pop("dynamic", None)
     compile_opts.pop("backend", None)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In deepmd/pt_expt/train/training.py around lines 221-225: the code mutates
the caller-owned dict compile_opts using pop() and by adding "options"; instead
create a shallow copy (e.g., local_compile_opts = compile_opts.copy()) and
operate on that copy, then use local_compile_opts.pop("dynamic", None),
local_compile_opts.pop("backend", None) and ensure "options" exists on
local_compile_opts before setting opts = local_compile_opts["options"]; leave
the original compile_opts untouched so callers can reuse it.


📥 Commits

Reviewing files that changed from the base of the PR and between a183f95 and 4ebce58.

📒 Files selected for processing (1)
  • deepmd/pt_expt/train/training.py

@codecov

codecov bot commented Apr 11, 2026

Codecov Report

❌ Patch coverage is 94.54545% with 3 lines in your changes missing coverage. Please review.
✅ Project coverage is 80.35%. Comparing base (baab3e8) to head (7670e6e).
⚠️ Report is 4 commits behind head on master.

Files with missing lines Patch % Lines
deepmd/pt_expt/train/training.py 92.85% 3 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##           master    #5393      +/-   ##
==========================================
+ Coverage   80.33%   80.35%   +0.01%     
==========================================
  Files         819      819              
  Lines       85356    85425      +69     
  Branches     4139     4139              
==========================================
+ Hits        68571    68643      +72     
+ Misses      15509    15506       -3     
  Partials     1276     1276              


@wanghan-iapcm wanghan-iapcm force-pushed the feat-pt-expt-compile-dynamic branch from e332640 to d5fcd61 on April 12, 2026 07:21
The env_mat_stat code calls descriptor.get_rcut_smth() and
descriptor.get_env_protection() during compute_input_stats, but
DescrptBlockRepformers only had the attributes without accessor
methods, causing AttributeError when training DPA2 models.
@wanghan-iapcm wanghan-iapcm force-pushed the feat-pt-expt-compile-dynamic branch from d5fcd61 to e48497c on April 12, 2026 07:46
Han Wang added 3 commits April 12, 2026 15:49
Test get_rcut_smth() and get_env_protection() (default and custom
values) on DescrptBlockRepformers.
Add silut/custom_silu activation to _torch_activation using native torch
ops (torch.sigmoid, torch.tanh, torch.where) for make_fx compatibility.
Add pt_expt unit tests and cross-backend consistency tests covering
multiple thresholds with inputs spanning both silu and tanh branches.
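
The traceable two-branch pattern can be sketched as follows, with numpy standing in for torch; the tanh tail below only illustrates the branch structure and continuity at the crossover, not deepmd's actual silut coefficients:

```python
import numpy as np

def silu(x):
    return x / (1.0 + np.exp(-x))

def silut_sketch(x, threshold=3.0):
    # A branchless select keeps the op traceable by make_fx/torch.export:
    # no Python `if` on tensor values. Below the threshold this is plain
    # silu; above it, a tanh-based tail chosen here only so the function
    # stays continuous at the crossover (illustrative, not deepmd's math).
    tail = silu(threshold) + np.tanh(x - threshold)
    return np.where(x < threshold, silu(x), tail)
```

In the real _torch_activation the same idea is expressed with torch.sigmoid, torch.tanh, and torch.where, which is what makes the activation survive symbolic tracing.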
Comment thread source/tests/pt_expt/utils/test_activation.py Fixed
Han Wang added 3 commits April 12, 2026 23:29
tests/pt/__init__.py sets torch.set_default_device("cuda:9999999"),
which causes bare torch.tensor() calls to attempt CUDA init on
CPU-only CI. Use the pt_expt DEVICE (same pattern as the pt tests
use their own DEVICE via to_torch_tensor).
Remove unused `actual_nall` in training.py and unused `threshold` in
test_activation.py.
Remove unused variables (actual_nall, threshold) flagged by CodeQL.
Use trainer._optimizer_step() instead of trainer.optimizer.step() in
the dynamic-shapes test to match the real training path.
@wanghan-iapcm wanghan-iapcm requested review from OutisLi and iProzd April 12, 2026 16:16
…s test

When make_fx traces with nframes=1, the symbolic tracer specialises the
batch dimension to the concrete value 1, causing failures when later
batches have nframes>1 (e.g. with batch_size: "auto" across systems of
different atom counts). Fix by duplicating the trace sample to nframes=2
before tracing, forcing a symbolic int for the batch dimension.

Add TestCompiledVaryingNatoms: compiled and uncompiled training with a
192-atom and a synthetic 6-atom system using batch_size: "auto", with
dp_random.choice mocked to deterministically alternate between systems.
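
The duplication fix can be sketched like this (numpy stands in for the torch.cat call; in the real code the duplication happens on the trace tensors before make_fx):

```python
import numpy as np

def ensure_dynamic_batch(sample):
    # make_fx(tracing_mode="symbolic") specialises a size-1 batch dim to
    # the literal 1; duplicating the frame forces a genuine symbolic int,
    # so later batches with nframes > 1 reuse the same compiled graph.
    if sample.shape[0] == 1:
        sample = np.concatenate([sample, sample], axis=0)
    return sample
```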
Collaborator

@OutisLi OutisLi left a comment


please add a test for common silu activation training perhaps

fparam: torch.Tensor | None,
aparam: torch.Tensor | None,
) -> dict[str, torch.Tensor]:
extended_coord = extended_coord.detach().requires_grad_(True)
Collaborator


The grad should not be detached inside the function that will be traced; detach to a leaf node before tracing and pass that to the compiled function.

Collaborator Author


addressed by #5397

Avoid mutating caller-owned compile_opts dict in _trace_and_compile.
Add compiled training test with silu activation.
@wanghan-iapcm
Collaborator Author

please add a test for common silu activation training perhaps

solved by 7670e6e

@wanghan-iapcm wanghan-iapcm requested a review from OutisLi April 15, 2026 17:09
Collaborator

@OutisLi OutisLi left a comment


A few issues with the current implementation:

1. Indiscriminate detach removal is unsafe

_remove_detach_nodes removes every aten.detach.default node in the graph. Not all of them come from the autograd saved-tensor mechanism — user-explicit .detach() calls (e.g. gradient boundary management, virial computation) also appear as aten.detach.default and are semantically meaningful.

Although empirically removing all of them may not cause incorrect results in the current graph (because make_fx has already materialized the first-order backward, and the user-explicit detach only affects the already-frozen forward ops), this is fragile and does not preserve the original semantics. If the model structure changes in the future, blind removal could silently break user-intended gradient boundaries.

The two categories can be distinguished by graph topology alone, without hard-coding any op names. Autograd-inserted detach always forms a double-detach chain (forward_op → detach_A → detach_B → backward_use), while user-explicit detach is structurally isolated. Three rules suffice:

  • Chain inner: input is another detach node.
  • Dead node: no downstream users.
  • Chain head: all users are detach nodes.

Anything else is user-explicit and should be preserved. Classification must happen on the original graph before any mutation (two-pass: collect, then remove), because removing chain-head nodes changes the input of chain-inner nodes, causing misclassification in a single pass.

Suggested implementation:

def _strip_saved_tensor_detach(gm: torch.fx.GraphModule) -> None:
    _DETACH = torch.ops.aten.detach.default

    def _is_detach(n: torch.fx.Node) -> bool:
        return n.op == "call_function" and n.target == _DETACH

    # Pass 1: classify on the unmodified graph
    to_remove = []
    for node in gm.graph.nodes:
        if not _is_detach(node):
            continue
        inp = node.args[0]
        users = list(node.users.keys())
        is_chain_inner = _is_detach(inp)
        is_dead = len(users) == 0
        is_chain_head = len(users) > 0 and all(_is_detach(u) for u in users)
        if is_chain_inner or is_dead or is_chain_head:
            to_remove.append(node)

    # Pass 2: remove after classification is complete
    for node in to_remove:
        node.replace_all_uses_with(node.args[0])
        gm.graph.erase_node(node)

    gm.graph.lint()
    gm.recompile()

2. Must guard with training mode — eval-time stripping triggers silu_backward derivative error

The stripping should only be applied when tracing in training mode (create_graph=True). In eval mode (create_graph=False), the saved-tensor detach nodes are correct and harmless. Removing them in eval mode causes inductor to attempt higher-order differentiation of ops like aten.silu_backward, which do not have registered higher-order derivatives in PyTorch. This is a pre-existing PyTorch limitation, not something the PR can fix, so the correct approach is:

traced = make_fx(compute_fn, ...)(*trace_args)

if self.training:
    _strip_saved_tensor_detach(traced)

compiled = torch.compile(traced, ...)

3. A unit test is needed, and it must use random parameter initialization

This bug is nearly invisible under default (zero or near-zero) initialization. At zero init, tanh(0) = 0 and sigmoid(0) = 0.5 produce saved tensors at special points where the second-order gradient contribution vanishes or is masked by numerical noise. In my testing, zero-init showed 0 parameter gradient mismatches; switching to random init immediately exposed ~80% relative error in 5 out of 54 parameters.
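
The special-point claim is easy to verify numerically: the second derivatives of both tanh and sigmoid vanish at 0, so zero-init saved tensors contribute nothing at second order. A quick finite-difference check, independent of deepmd:

```python
import math

def second_derivative(f, x, h=1e-4):
    # Central finite-difference approximation of f''(x)
    return (f(x + h) - 2.0 * f(x) + f(x - h)) / (h * h)

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

# Both second derivatives vanish at 0, so a zero-init network sits exactly
# where a second-order gradient bug produces no visible error.
assert abs(second_derivative(math.tanh, 0.0)) < 1e-6
assert abs(second_derivative(sigmoid, 0.0)) < 1e-6
# Away from 0 the curvature is O(1), which is why random init exposes it.
assert abs(second_derivative(math.tanh, 0.7)) > 0.1
```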

The test should:

  • Build paired models (eager vs compiled) with shared random weights
  • Compare force-loss ((force**2).sum()) backward gradients (second-order) across all parameters
  • Use tight tolerance — the differences from this bug are O(1) relative, not floating-point noise

Minimal structure:

def test_compile_second_order_grad_matches_eager(self):
    model_eager = build_model(use_compile=False)
    # Random init is critical — zero init hides the bug
    torch.manual_seed(42)
    with torch.no_grad():
        for p in model_eager.parameters():
            p.copy_(torch.randn_like(p) * 0.1)

    model_compiled = build_model(use_compile=True)
    model_compiled.load_state_dict(model_eager.state_dict())
    model_eager.train(); model_compiled.train()

    # Forward match
    out_e = model_eager(coord, atype, box=box)
    out_c = model_compiled(coord, atype, box=box)
    torch.testing.assert_close(out_e["force"], out_c["force"], ...)

    # Second-order: force-loss backward
    model_eager.zero_grad(); model_compiled.zero_grad()
    (out_e["force"] ** 2).sum().backward()
    (out_c["force"] ** 2).sum().backward()

    for (ne, pe), (_, pc) in zip(
        model_eager.named_parameters(),
        model_compiled.named_parameters(),
    ):
        ge = pe.grad if pe.grad is not None else torch.zeros_like(pe)
        gc = pc.grad if pc.grad is not None else torch.zeros_like(pc)
        torch.testing.assert_close(ge, gc, atol=1e-5, rtol=1e-5,
                                   msg=f"Grad mismatch: {ne}")

@OutisLi
Collaborator

OutisLi commented Apr 16, 2026

Solution for silu_backward derivative not implemented:

# Decompose silu_backward into primitive ops (sigmoid + mul + ...)
# so that inductor can compile the eval graph without requiring a
# higher-order derivative that PyTorch does not register for silu.
from torch._decomp import get_decompositions

decomp_table = get_decompositions([torch.ops.aten.silu_backward.default])

traced = make_fx(
    compute_fn,
    tracing_mode="symbolic",
    _allow_non_fake_inputs=True,
    decomposition_table=decomp_table,
)(
    trace_coord,
    trace_atype,
    trace_edge_index,
    trace_edge_vec,
    trace_edge_mask,
    trace_fp,
    trace_ap,
)

@wanghan-iapcm wanghan-iapcm marked this pull request as draft April 16, 2026 05:16
wanghan-iapcm pushed a commit to wanghan-iapcm/deepmd-kit that referenced this pull request Apr 16, 2026
…modeling#5393

Add silut/custom_silu support to _torch_activation using native torch
ops (sigmoid, tanh, where) so the custom silu stays traceable by
make_fx / torch.export. Cross-backend consistency tests cover multiple
thresholds across the silu/tanh branches, and a pt_expt unit file
exercises default/custom threshold, gradient flow, make_fx, and
torch.export.

Also port DescrptBlockRepformers accessor tests (get_rcut_smth,
get_env_protection). The underlying accessor methods already exist on
this branch; these tests guard against regressions.
wanghan-iapcm pushed a commit to wanghan-iapcm/deepmd-kit that referenced this pull request Apr 16, 2026
…deling#5393

Adds the remaining tests from PR deepmodeling#5393 that were not yet on this branch:
``test_training_loop_compiled_silu`` (silu activation under torch.compile)
and ``TestCompiledVaryingNatoms`` (compiled training across systems with
different atom counts). Also drops a stray unused ``threshold`` variable
in ``test_silut_below_threshold_is_silu`` to match the upstream PR.