fix(jax_engine): correct gradients for MM3 angle term at near-collinear geometries#288
Merged
Merged
Conversation
…ar geometries Closes q2mm#284 §1 (the wrong-gradient bug; the noise side is separately mitigated by q2mm#286's median+CI reporting). Root cause The MM3 angle energy uses arccos(clip(cos θ, −1+ε, 1−ε)) to compute the bond angle from atomic positions. When a geometry minimization drives an angle to collinear (cos θ → ±1), the clip's gradient becomes zero in the boundary region — but cos θ remains formally inside the clip range during typical perturbations, so the autodiff chain `∂E/∂atom = ∂E/∂θ · ∂θ/∂cos · ∂cos/∂atom` underestimates the real gradient. The optimizer sees a spurious stationary point and gets stuck with angles wedged at ≈ 180°. For rh-conjugate at the BFGS-converged geometry, the harmonic-only angle term produces JAX gradient norm = 982 vs FD norm = 575 (factor 1.7×); for the full MM3 angle term, 1782 vs 871 (factor 2.0×). The discrepancy was localized to four angles where cos θ saturated at the clip value −1+1e-7. Bond and torsion terms agreed with FD to ~1e-7. Fix Replace `_safe_arccos(cos θ)` with `_angle_from_vectors(a, b)` that uses `atan2(|a×b|, a·b)` (well-conditioned at collinear in the forward pass) wrapped in a `jax.custom_vjp` that returns the analytic gradient `∂θ/∂a = (cos θ · a/|a|² − b/(|a||b|)) / sin θ` (and symmetric for `b`). The VJP floors `sin θ` at 1e-12 so the gradient direction stays finite at exact collinearity (pointing perpendicular to the bond axis, which is the physically correct restoring direction). The decorator is applied inside an `lru_cache(maxsize=1)` factory so we can apply `@jax.custom_vjp` after `_ensure_jax()` populates the module-level `jax` symbol (it's `None` at import time). The fix touches three call sites in q2mm/backends/mm/jax_engine.py: - `_harmonic_angle_energy` (OPLSAA-style harmonic angle) - `_mm3_angle_energy` (sextic anharmonic angle) - `_mm3_stretch_bend_energy` (stretch-bend cross term; kept its near-linear smoothstep suppression, but uses atan2 for θ now) Dead code cleanup (no shims, per AGENTS.md §2): - Deleted `_safe_arccos` (no remaining callers). - Deleted `_safe_norm_keepdims` (no remaining callers after the three angle call sites switched to bare-vector form). Validation New regression test: - `test/test_mm3_jax.py::TestMM3AngleEnergy::test_near_collinear_gradient_matches_fd` Places three atoms at near-antiparallel geometry, checks that ∂E/∂atom_k matches FD on the perpendicular axis to 1e-3 relative. Fails on the pre-fix implementation (JAX returns 0), passes on this fix. End-to-end on rh-conjugate: - Pre-fix at BFGS-converged geom: JAX grad norm = 1668, FD = 713 (ratio 2.34); minimizer cannot escape spurious stationary point. - Post-fix: optimizer reaches a real minimum (e drops from 166.67 → 166.12), and the real ObjectiveFunction improves by 18.00 % ± 4.17 % CI₉₅ (SIGNIFICANT) — previously "−0.080 % ± 1.18 %, NOT SIGNIFICANT" in q2mm#287. End-to-end on heck-relay (--ratio-tol none): - Pre-fix: ratio 1.378 (gate fails), surrogate breakdown, no real-OF improvement (−0.59 % ± 3.26 %). - Post-fix: ratio drops to 1.085 (would pass default gate), surrogate reduction 53 %, real-OF improvement **52.82 % ± 1.54 % CI₉₅** (SIGNIFICANT). Regression checks on previously-verified systems: - ch3f: 99.83 % (unchanged, deterministic). - rh-enamide: 44.73 % (was 44.66 %; difference within CI; no regression). - pd-allyl: −0.01 % ± 0.40 % (still NOT SIGNIFICANT — the FF really is at a JaxLoss local minimum here, distinct from rh-conjugate and heck-relay where the clip bug was preventing the optimizer from finding the real descent direction). Run validations: - ruff check + format clean - 680 unit tests pass - 115 JAX-marked tests pass (incl. the new regression) Companion data PR (regenerated artifacts for all 5 systems): ericchansen/q2mm-data#TBD. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
There was a problem hiding this comment.
Pull request overview
Fixes a JAX gradient-correctness bug in MM3/harmonic angle energies where arccos(clip(cos θ, …)) zeroed gradients at near-collinear geometries, causing the optimizer to land on spurious stationary points. Replaces the clip-based angle with an atan2-based formulation backed by a custom VJP using analytic ∂θ/∂a expressions, with sin θ floored at 1e-12 to keep direction stable at collinearity. Unlocks real >18% (rh-conjugate) and >52% (heck-relay) ObjectiveFunction reductions previously hidden behind the bug.
Changes:
- New
_angle_from_vectors(a, b)withjax.custom_vjpanalytic gradients; applied to harmonic angle, MM3 angle, and (via inlinedatan2) MM3 stretch-bend. - Deletes superseded
_safe_arccosand_safe_norm_keepdimsper alpha-discipline. - Adds near-collinear FD-vs-JAX regression test; updates rh-conjugate / heck-relay / pd-allyl docs with post-fix verdicts.
Reviewed changes
Copilot reviewed 5 out of 5 changed files in this pull request and generated no comments.
Show a summary per file
| File | Description |
|---|---|
q2mm/backends/mm/jax_engine.py |
New atan2-based angle helper with custom VJP; replaces clip-arccos at three call sites; removes dead helpers. |
test/test_mm3_jax.py |
Adds test_near_collinear_gradient_matches_fd regression test for the gradient correctness fix. |
docs/systems/rh-conjugate.md |
Flips verdict to SIGNIFICANT (18.00 % ± 4.17 %) with explanation linking to #284. |
docs/systems/heck-relay.md |
Flips verdict to SIGNIFICANT (52.82 % ± 1.54 %); ratio improves 1.378 → 1.085. |
docs/systems/pd-allyl.md |
Refreshes numbers; verdict unchanged (NOT SIGNIFICANT, true local min). |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Fixes the MM3 angle gradient correctness bug from #284 §1. The previously "no improvement available" verdicts for rh-conjugate and heck-relay were wrong — they were spurious stationary points caused by this bug. Post-fix, both systems show real improvements that the optimizer can now find.
Companion data PR: ericchansen/q2mm-data#10
Headline results
Root cause
The MM3 bond-angle energy used
arccos(clip(cos θ, −1+ε, 1−ε))to compute the angle from atomic positions. When a geometry minimization drives an angle to collinear (cos θ → ±1), the clip's gradient becomes zero in the boundary region — and the autodiff chain∂E/∂atom = ∂E/∂θ · ∂θ/∂cos · ∂cos/∂atomunderestimates the real gradient. The optimizer sees a spurious stationary point with angles wedged at ≈ 180°.For rh-conjugate at the BFGS-converged geometry, the harmonic-only angle term produced JAX gradient norm 982 vs FD norm 575 (1.7× discrepancy). The discrepancy localized to four angles where cos θ saturated at the clip value −1+1e-7. Bond and torsion terms agreed with FD to ~1e-7.
Fix
Replace
_safe_arccos(cos θ)with_angle_from_vectors(a, b):atan2(|a×b|, a·b)— well-conditioned at collinear∂θ/∂a = (cos θ · a/|a|² − b/(|a||b|)) / sin θ, withsin θfloored at 1e-12 to keep direction stable at exact collinearityThe decorator is applied inside an
lru_cache(maxsize=1)factory becausejaxisNoneat module import time (populated lazily by_ensure_jax()).Touches three call sites in
q2mm/backends/mm/jax_engine.py:_harmonic_angle_energy(OPLSAA harmonic angle)_mm3_angle_energy(MM3 sextic anharmonic angle)_mm3_stretch_bend_energy(stretch-bend cross-term)Deletes
_safe_arccosand_safe_norm_keepdims(no remaining callers; AGENTS.md §2 alpha discipline — no shims).Tests
test/test_mm3_jax.py::TestMM3AngleEnergy::test_near_collinear_gradient_matches_fdPlaces three atoms at near-antiparallel geometry, checks ∂E/∂atom_k matches FD on the perpendicular axis to 1e-3 relative. Fails on pre-fix (JAX returns 0), passes on this fix.
Doc updates (in same commit)
docs/systems/rh-conjugate.md— verdict flipped from "NOT SIGNIFICANT, JaxLoss local min" to "SIGNIFICANT, 18.00 %", explanation of how the gradient bug had been hiding the real descent directiondocs/systems/heck-relay.md— verdict flipped from "NOT SIGNIFICANT, --ratio-tol none doesn't help" to "SIGNIFICANT, 52.82 %"; the system no longer needs --ratio-tol none (ratio now 1.085, within default gate)docs/systems/pd-allyl.md— numbers refreshed (still NOT SIGNIFICANT, the FF truly is at a local min here)Closes