Skip to content

fix(jax_engine): correct gradients for MM3 angle term at near-collinear geometries#288

Merged
ericchansen merged 1 commit into
masterfrom
fix/mm3-non-smooth-gradient
May 28, 2026
Merged

fix(jax_engine): correct gradients for MM3 angle term at near-collinear geometries#288
ericchansen merged 1 commit into
masterfrom
fix/mm3-non-smooth-gradient

Conversation

@ericchansen
Copy link
Copy Markdown
Owner

Summary

Fixes the MM3 angle gradient correctness bug from #284 §1. The previously "no improvement available" verdicts for rh-conjugate and heck-relay were wrong — they were spurious stationary points caused by this bug. Post-fix, both systems show real improvements that the optimizer can now find.

Companion data PR: ericchansen/q2mm-data#10

Headline results

System Pre-fix Δ% (q2mm-data#9) Post-fix Δ% Status
ch3f 99.83 % (det.) 99.83 % (det.) unchanged ✅
rh-enamide 44.66 % ± 0.29 % 44.73 % ± 0.29 % unchanged ✅
pd-allyl −0.029 % ± 0.34 % −0.010 % ± 0.40 % still NOT SIG ❌
rh-conjugate −0.080 % ± 1.18 % 18.00 % ± 4.17 % 🚀 NEWLY UNLOCKED
heck-relay −0.59 % ± 3.26 % 52.82 % ± 1.54 % 🚀 NEWLY UNLOCKED

Root cause

The MM3 bond-angle energy used arccos(clip(cos θ, −1+ε, 1−ε)) to compute the angle from atomic positions. When a geometry minimization drives an angle to collinear (cos θ → ±1), the clip's gradient becomes zero in the boundary region — and the autodiff chain ∂E/∂atom = ∂E/∂θ · ∂θ/∂cos · ∂cos/∂atom underestimates the real gradient. The optimizer sees a spurious stationary point with angles wedged at ≈ 180°.

For rh-conjugate at the BFGS-converged geometry, the harmonic-only angle term produced JAX gradient norm 982 vs FD norm 575 (1.7× discrepancy). The discrepancy localized to four angles where cos θ saturated at the clip value −1+1e-7. Bond and torsion terms agreed with FD to ~1e-7.

Fix

Replace _safe_arccos(cos θ) with _angle_from_vectors(a, b):

  • Forward: atan2(|a×b|, a·b) — well-conditioned at collinear
  • Custom VJP: analytic gradient ∂θ/∂a = (cos θ · a/|a|² − b/(|a||b|)) / sin θ, with sin θ floored at 1e-12 to keep direction stable at exact collinearity

The decorator is applied inside an lru_cache(maxsize=1) factory because jax is None at module import time (populated lazily by _ensure_jax()).

Touches three call sites in q2mm/backends/mm/jax_engine.py:

  • _harmonic_angle_energy (OPLSAA harmonic angle)
  • _mm3_angle_energy (MM3 sextic anharmonic angle)
  • _mm3_stretch_bend_energy (stretch-bend cross-term)

Deletes _safe_arccos and _safe_norm_keepdims (no remaining callers; AGENTS.md §2 alpha discipline — no shims).

Tests

  • New regression: test/test_mm3_jax.py::TestMM3AngleEnergy::test_near_collinear_gradient_matches_fd
    Places three atoms at near-antiparallel geometry, checks ∂E/∂atom_k matches FD on the perpendicular axis to 1e-3 relative. Fails on pre-fix (JAX returns 0), passes on this fix.
  • All 41 MM3-jax tests pass
  • All 115 JAX-marked tests pass
  • 680 unit tests pass
  • ruff check + format clean

Doc updates (in same commit)

  • docs/systems/rh-conjugate.md — verdict flipped from "NOT SIGNIFICANT, JaxLoss local min" to "SIGNIFICANT, 18.00 %", explanation of how the gradient bug had been hiding the real descent direction
  • docs/systems/heck-relay.md — verdict flipped from "NOT SIGNIFICANT, --ratio-tol none doesn't help" to "SIGNIFICANT, 52.82 %"; the system no longer needs --ratio-tol none (ratio now 1.085, within default gate)
  • docs/systems/pd-allyl.md — numbers refreshed (still NOT SIGNIFICANT, the FF truly is at a local min here)

Closes

…ar geometries

Closes q2mm#284 §1 (the wrong-gradient bug; the noise side is separately
mitigated by q2mm#286's median+CI reporting).

Root cause

The MM3 angle energy uses arccos(clip(cos θ, −1+ε, 1−ε)) to compute
the bond angle from atomic positions.  When a geometry minimization
drives an angle to collinear (cos θ → ±1), the clip's gradient
becomes zero in the boundary region — but cos θ remains formally
inside the clip range during typical perturbations, so the autodiff
chain `∂E/∂atom = ∂E/∂θ · ∂θ/∂cos · ∂cos/∂atom` underestimates the
real gradient.  The optimizer sees a spurious stationary point and
gets stuck with angles wedged at ≈ 180°.

For rh-conjugate at the BFGS-converged geometry, the harmonic-only
angle term produces JAX gradient norm = 982 vs FD norm = 575 (factor
1.7×); for the full MM3 angle term, 1782 vs 871 (factor 2.0×).  The
discrepancy was localized to four angles where cos θ saturated at
the clip value −1+1e-7.  Bond and torsion terms agreed with FD to
~1e-7.

Fix

Replace `_safe_arccos(cos θ)` with `_angle_from_vectors(a, b)` that
uses `atan2(|a×b|, a·b)` (well-conditioned at collinear in the
forward pass) wrapped in a `jax.custom_vjp` that returns the
analytic gradient `∂θ/∂a = (cos θ · a/|a|² − b/(|a||b|)) / sin θ`
(and symmetric for `b`).  The VJP floors `sin θ` at 1e-12 so the
gradient direction stays finite at exact collinearity (pointing
perpendicular to the bond axis, which is the physically correct
restoring direction).

The decorator is applied inside an `lru_cache(maxsize=1)` factory
so we can apply `@jax.custom_vjp` after `_ensure_jax()` populates
the module-level `jax` symbol (it's `None` at import time).

The fix touches three call sites in q2mm/backends/mm/jax_engine.py:
- `_harmonic_angle_energy` (OPLSAA-style harmonic angle)
- `_mm3_angle_energy` (sextic anharmonic angle)
- `_mm3_stretch_bend_energy` (stretch-bend cross term; kept its
  near-linear smoothstep suppression, but uses atan2 for θ now)

Dead code cleanup (no shims, per AGENTS.md §2):
- Deleted `_safe_arccos` (no remaining callers).
- Deleted `_safe_norm_keepdims` (no remaining callers after the
  three angle call sites switched to bare-vector form).

Validation

New regression test:
- `test/test_mm3_jax.py::TestMM3AngleEnergy::test_near_collinear_gradient_matches_fd`
  Places three atoms at near-antiparallel geometry, checks that
  ∂E/∂atom_k matches FD on the perpendicular axis to 1e-3 relative.
  Fails on the pre-fix implementation (JAX returns 0), passes on
  this fix.

End-to-end on rh-conjugate:
- Pre-fix at BFGS-converged geom: JAX grad norm = 1668, FD = 713
  (ratio 2.34); minimizer cannot escape spurious stationary point.
- Post-fix: optimizer reaches a real minimum (e drops from 166.67
  → 166.12), and the real ObjectiveFunction improves by 18.00 % ±
  4.17 % CI₉₅ (SIGNIFICANT) — previously "−0.080 % ± 1.18 %, NOT
  SIGNIFICANT" in q2mm#287.

End-to-end on heck-relay (--ratio-tol none):
- Pre-fix: ratio 1.378 (gate fails), surrogate breakdown, no real-OF
  improvement (−0.59 % ± 3.26 %).
- Post-fix: ratio drops to 1.085 (would pass default gate), surrogate
  reduction 53 %, real-OF improvement **52.82 % ± 1.54 % CI₉₅**
  (SIGNIFICANT).

Regression checks on previously-verified systems:
- ch3f: 99.83 % (unchanged, deterministic).
- rh-enamide: 44.73 % (was 44.66 %; difference within CI; no
  regression).
- pd-allyl: −0.01 % ± 0.40 % (still NOT SIGNIFICANT — the FF really
  is at a JaxLoss local minimum here, distinct from rh-conjugate
  and heck-relay where the clip bug was preventing the optimizer
  from finding the real descent direction).

Run validations:
  - ruff check + format clean
  - 680 unit tests pass
  - 115 JAX-marked tests pass (incl. the new regression)

Companion data PR (regenerated artifacts for all 5 systems):
ericchansen/q2mm-data#TBD.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Copy link
Copy Markdown

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Fixes a JAX gradient-correctness bug in MM3/harmonic angle energies where arccos(clip(cos θ, …)) zeroed gradients at near-collinear geometries, causing the optimizer to land on spurious stationary points. Replaces the clip-based angle with an atan2-based formulation backed by a custom VJP using analytic ∂θ/∂a expressions, with sin θ floored at 1e-12 to keep direction stable at collinearity. Unlocks real >18% (rh-conjugate) and >52% (heck-relay) ObjectiveFunction reductions previously hidden behind the bug.

Changes:

  • New _angle_from_vectors(a, b) with jax.custom_vjp analytic gradients; applied to harmonic angle, MM3 angle, and (via inlined atan2) MM3 stretch-bend.
  • Deletes superseded _safe_arccos and _safe_norm_keepdims per alpha-discipline.
  • Adds near-collinear FD-vs-JAX regression test; updates rh-conjugate / heck-relay / pd-allyl docs with post-fix verdicts.

Reviewed changes

Copilot reviewed 5 out of 5 changed files in this pull request and generated no comments.

Show a summary per file
File Description
q2mm/backends/mm/jax_engine.py New atan2-based angle helper with custom VJP; replaces clip-arccos at three call sites; removes dead helpers.
test/test_mm3_jax.py Adds test_near_collinear_gradient_matches_fd regression test for the gradient correctness fix.
docs/systems/rh-conjugate.md Flips verdict to SIGNIFICANT (18.00 % ± 4.17 %) with explanation linking to #284.
docs/systems/heck-relay.md Flips verdict to SIGNIFICANT (52.82 % ± 1.54 %); ratio improves 1.378 → 1.085.
docs/systems/pd-allyl.md Refreshes numbers; verdict unchanged (NOT SIGNIFICANT, true local min).

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

@ericchansen ericchansen merged commit dc8983e into master May 28, 2026
12 checks passed
@ericchansen ericchansen deleted the fix/mm3-non-smooth-gradient branch May 28, 2026 13:09
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants