From 2690192df36d393346be219e74b02bc4aa385d01 Mon Sep 17 00:00:00 2001 From: WangXinyan940 <44082181+WangXinyan940@users.noreply.github.com> Date: Sat, 3 Dec 2022 19:51:48 +0800 Subject: [PATCH 1/9] Fix mbar example (#76) * Add requirement of mdtraj, optax and pymbar to the doc * Fix benzen demo (#75) * Add Gitee_mirror * Fix mirror CI/CD * Update ben-prm.xml Co-authored-by: Yingze Wang Co-authored-by: Roy-Kid Co-authored-by: KuangYu Co-authored-by: Jichen Li <42854324+Roy-Kid@users.noreply.github.com> Co-authored-by: crone <2223469329@qq.com> Co-authored-by: Yuzhi Zhang <529133328@qq.com> --- docs/user_guide/installation.md | 6 ++++++ examples/mbar/ben-prm.xml | 5 ----- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/docs/user_guide/installation.md b/docs/user_guide/installation.md index d4cd73550..19f0f31b7 100644 --- a/docs/user_guide/installation.md +++ b/docs/user_guide/installation.md @@ -13,6 +13,12 @@ pip install jax==0.3.17 ```bash pip install jax-md==0.2.0 ``` ++ Install [mdtraj](https://github.com/mdtraj/mdtraj), [optax](https://github.com/deepmind/optax) and [pymbar](https://github.com/choderalab/pymbar): +```bash +conda install -c conda-forge mdtraj==1.9.7 +pip install optax==0.1.3 +pip install pymbar==4.0.1 +``` + Install [OpenMM](https://openmm.org/): ```bash conda install -c conda-forge openmm==7.7.0 diff --git a/examples/mbar/ben-prm.xml b/examples/mbar/ben-prm.xml index bac99356d..51ca52004 100644 --- a/examples/mbar/ben-prm.xml +++ b/examples/mbar/ben-prm.xml @@ -18,11 +18,6 @@ - - - - - From db60c61562519b784b44801aa12d2511ad36b2b1 Mon Sep 17 00:00:00 2001 From: JeremyDream Date: Wed, 7 Dec 2022 04:49:22 +0800 Subject: [PATCH 2/9] add minus sign to all "charge penetration" terms --- dmff/generators/admp.py | 45 +++++++++++++++++++++++++++++++++++++---- 1 file changed, 41 insertions(+), 4 deletions(-) diff --git a/dmff/generators/admp.py b/dmff/generators/admp.py index ce7c4edec..8e9d5f216 100644 --- a/dmff/generators/admp.py +++ b/dmff/generators/admp.py @@ -599,26 +599,63 @@ def getMetaData(self): # Here are all the short range "charge penetration" terms -# They all have the exchange form +# They all have the exchange form with minus sign class SlaterSrEsGenerator(SlaterExGenerator): def __init__(self, ff): super().__init__(ff) self.name = "SlaterSrEsForce" + def createForce(self, system, data, nonbondedMethod, nonbondedCutoff, + args): + + n_atoms = len(data.atoms) + # build index map + map_atomtype = np.zeros(n_atoms, dtype=int) + for i in range(n_atoms): + atype = data.atomType[data.atoms[i]] + map_atomtype[i] = np.where(self.atomTypes == atype)[0][0] + self.map_atomtype = map_atomtype + # build covalent map + self.covalent_map = build_covalent_map(data, 6) + + self._meta["cov_map"] = self.covalent_map + self._meta[self.name+"_map_atomtype"] = self.map_atomtype + + pot_fn_sr = generate_pairwise_interaction(slater_sr_kernel, + static_args={}) + + def potential_fn(positions, box, pairs, params): + params = params[self.name] + mScales = params["mScales"] + a_list = params["A"][map_atomtype] + b_list = params["B"][map_atomtype] / 10 # nm^-1 to A^-1 + + # add minus sign + return - pot_fn_sr(positions, box, pairs, mScales, a_list, b_list) + + self._jaxPotential = potential_fn + # self._top_data = data + + def getJaxPotential(self): + return self._jaxPotential + + def getMetaData(self): + return self._meta + -class SlaterSrPolGenerator(SlaterExGenerator): +class SlaterSrPolGenerator(SlaterSrEsGenerator): def __init__(self, ff): super().__init__(ff) self.name = "SlaterSrPolForce" -class SlaterSrDispGenerator(SlaterExGenerator): +class SlaterSrDispGenerator(SlaterSrEsGenerator): def __init__(self, ff): super().__init__(ff) self.name = "SlaterSrDispForce" -class SlaterDhfGenerator(SlaterExGenerator): +class SlaterDhfGenerator(SlaterSrEsGenerator): def __init__(self, ff): super().__init__(ff) self.name = "SlaterDhfForce" From 2e180d1306785117d376557ca65d469425e92399 Mon Sep 17 00:00:00 2001 From: JeremyDream Date: Fri, 30 Dec 2022 02:14:38 +0800 Subject: [PATCH 3/9] Adding ADMP Slater-type Decompositions Test --- examples/peg_slater_isa/check_calc.py | 58 +++++++++++ examples/peg_slater_isa/peg.xml | 145 ++++++++++++++++++++++++++ 2 files changed, 203 insertions(+) create mode 100755 examples/peg_slater_isa/check_calc.py create mode 100644 examples/peg_slater_isa/peg.xml diff --git a/examples/peg_slater_isa/check_calc.py b/examples/peg_slater_isa/check_calc.py new file mode 100755 index 000000000..2b68838d8 --- /dev/null +++ b/examples/peg_slater_isa/check_calc.py @@ -0,0 +1,58 @@ +#!/usr/bin/env python +import sys +import numpy as np +import openmm +from openmm import * +from openmm.app import * +from openmm.unit import * +import jax +import jax_md +import jax.numpy as jnp +import dmff +from dmff.api import Hamiltonian +from dmff.common import nblist +import pickle +import time +from jax import value_and_grad, jit + +if __name__ == '__main__': + ff = 'peg.xml' + pdb_AB = PDBFile('peg2.pdb') + H_AB = Hamiltonian(ff) + rc = 15 + # get potential functions + pots_AB = H_AB.createPotential(pdb_AB.topology, nonbondedCutoff=rc*angstrom, nonbondedMethod=CutoffPeriodic, ethresh=1e-4) + pot_pme_AB = pots_AB.dmff_potentials['ADMPPmeForce'] + pot_disp_AB = pots_AB.dmff_potentials['ADMPDispPmeForce'] + pot_ex_AB = pots_AB.dmff_potentials['SlaterExForce'] + pot_sr_es_AB = pots_AB.dmff_potentials['SlaterSrEsForce'] + pot_sr_pol_AB = pots_AB.dmff_potentials['SlaterSrPolForce'] + pot_sr_disp_AB = pots_AB.dmff_potentials['SlaterSrDispForce'] + pot_dhf_AB = pots_AB.dmff_potentials['SlaterDhfForce'] + pot_dmp_es_AB = pots_AB.dmff_potentials['QqTtDampingForce'] + pot_dmp_disp_AB = pots_AB.dmff_potentials['SlaterDampingForce'] + + paramtree = H_AB.getParameters() + + # init positions used to set up neighbor list + pos_AB0 = jnp.array(pdb_AB.positions._value) * 10 + n_atoms = len(pos_AB0) + box = jnp.array(pdb_AB.topology.getPeriodicBoxVectors()._value) * 10 + + # nn list initial allocation + nbl_AB = nblist.NeighborList(box, rc, H_AB.getGenerators()[0].covalent_map) + nbl_AB.allocate(pos_AB0) + pairs_AB = nbl_AB.pairs + pairs_AB = pairs_AB[pairs_AB[:, 0] < pairs_AB[:, 1]] + + pos_AB = jnp.array(pos_AB0) + E_es = pot_pme_AB(pos_AB, box, pairs_AB, paramtree) + E_disp = pot_disp_AB(pos_AB, box, pairs_AB, paramtree) + E_ex_AB = pot_ex_AB(pos_AB, box, pairs_AB, paramtree) + E_sr_es = pot_sr_es_AB(pos_AB, box, pairs_AB, paramtree) + E_sr_pol = pot_sr_pol_AB(pos_AB, box, pairs_AB, paramtree) + E_sr_disp = pot_sr_disp_AB(pos_AB, box, pairs_AB, paramtree) + E_dhf = pot_dhf_AB(pos_AB, box, pairs_AB, paramtree) + E_dmp_es = pot_dmp_es_AB(pos_AB, box, pairs_AB, paramtree) + E_dmp_disp = pot_dmp_disp_AB(pos_AB, box, pairs_AB, paramtree) + print(E_es, E_disp, E_ex_AB, E_sr_es, E_sr_pol, E_sr_disp, E_dhf, E_dmp_es, E_dmp_disp) diff --git a/examples/peg_slater_isa/peg.xml b/examples/peg_slater_isa/peg.xml new file mode 100644 index 000000000..25d90db50 --- /dev/null +++ b/examples/peg_slater_isa/peg.xml @@ -0,0 +1,145 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + From 29e01cc432728ef15395f0bba65625cfc0f50988 Mon Sep 17 00:00:00 2001 From: JeremyDream Date: Fri, 30 Dec 2022 02:21:34 +0800 Subject: [PATCH 4/9] Adding slater type function decomposition test --- tests/data/peg.xml | 145 +++++++++++++++++++++++++++++++++ tests/data/peg2.pdb | 35 ++++++++ tests/test_admp/test_slater.py | 79 ++++++++++++++++++ 3 files changed, 259 insertions(+) create mode 100644 tests/data/peg.xml create mode 100644 tests/data/peg2.pdb create mode 100644 tests/test_admp/test_slater.py diff --git a/tests/data/peg.xml b/tests/data/peg.xml new file mode 100644 index 000000000..25d90db50 --- /dev/null +++ b/tests/data/peg.xml @@ -0,0 +1,145 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/tests/data/peg2.pdb b/tests/data/peg2.pdb new file mode 100644 index 000000000..e44d9f14b --- /dev/null +++ b/tests/data/peg2.pdb @@ -0,0 +1,35 @@ +TITLE MDANALYSIS FRAME 0: Created by PDBWriter +CRYST1 30.000 30.000 30.000 90.00 90.00 90.00 P 1 1 +ATOM 1 C00 TER M 1 14.058 13.500 16.329 0.00 1.00 SYST C +ATOM 2 H01 TER M 1 14.410 14.387 16.870 0.00 1.00 SYST H +ATOM 3 H02 TER M 1 14.412 12.614 16.871 0.00 1.00 SYST H +ATOM 4 O03 TER M 1 14.578 13.500 15.000 0.00 1.00 SYST O +ATOM 5 C04 TER M 1 16.000 13.500 15.000 0.00 1.00 SYST C +ATOM 6 H05 TER M 1 16.344 13.499 13.962 0.00 1.00 SYST H +ATOM 7 H06 TER M 1 16.382 12.602 15.496 0.00 1.00 SYST H +ATOM 8 H07 TER M 1 16.382 14.399 15.493 0.00 1.00 SYST H +ATOM 9 C00 TER M 2 12.535 13.498 16.276 0.00 1.00 SYST C +ATOM 10 H01 TER M 2 12.184 12.612 15.734 0.00 1.00 SYST H +ATOM 11 H02 TER M 2 12.182 14.385 15.735 0.00 1.00 SYST H +ATOM 12 O03 TER M 2 12.015 13.496 17.605 0.00 1.00 SYST O +ATOM 13 C04 TER M 2 10.593 13.493 17.605 0.00 1.00 SYST C +ATOM 14 H05 TER M 2 10.250 13.491 18.643 0.00 1.00 SYST H +ATOM 15 H06 TER M 2 10.213 12.595 17.109 0.00 1.00 SYST H +ATOM 16 H07 TER M 2 10.209 14.391 17.112 0.00 1.00 SYST H +CONECT 1 2 3 4 9 +CONECT 2 1 +CONECT 3 1 +CONECT 4 1 5 +CONECT 5 4 6 7 8 +CONECT 6 5 +CONECT 7 5 +CONECT 8 5 +CONECT 9 1 10 11 12 +CONECT 10 9 +CONECT 11 9 +CONECT 12 9 13 +CONECT 13 12 14 15 16 +CONECT 14 13 +CONECT 15 13 +CONECT 16 13 +END diff --git a/tests/test_admp/test_slater.py b/tests/test_admp/test_slater.py new file mode 100644 index 000000000..a239534cf --- /dev/null +++ b/tests/test_admp/test_slater.py @@ -0,0 +1,79 @@ +import openmm.app as app +import openmm.unit as unit +import numpy as np +import jax.numpy as jnp +import numpy.testing as npt +import pytest +from dmff import Hamiltonian, NeighborList +from jax import jit, value_and_grad + +class TestADMPSlaterTypeFunction: + + @pytest.mark.parametrize( + "pdb, prm, values", + [ + ( + "tests/data/peg2.pdb", + "tests/data/peg.xml", + # "peg2.pdb", + # "peg.xml", + jnp.array([ + 5.731787, -0.053504415, 4.510135e-09, -4.510135e-09, -4.510135e-09, + -4.510135e-09, -4.510135e-09, -5.505022e-11, 1.0166141e-06 + ]) + ), + ] + ) + def test_admp_slater(self, pdb, prm, values): + pdb_AB = app.PDBFile(pdb) + H_AB = Hamiltonian(prm) + rc = 15 + pots_AB = H_AB.createPotential( + pdb_AB.topology, + nonbondedCutoff=rc*unit.angstrom, + nonbondedMethod=app.CutoffPeriodic, + ethresh=1e-4) + + pot_pme_AB = pots_AB.dmff_potentials['ADMPPmeForce'] + pot_disp_AB = pots_AB.dmff_potentials['ADMPDispPmeForce'] + pot_ex_AB = pots_AB.dmff_potentials['SlaterExForce'] + pot_sr_es_AB = pots_AB.dmff_potentials['SlaterSrEsForce'] + pot_sr_pol_AB = pots_AB.dmff_potentials['SlaterSrPolForce'] + pot_sr_disp_AB = pots_AB.dmff_potentials['SlaterSrDispForce'] + pot_dhf_AB = pots_AB.dmff_potentials['SlaterDhfForce'] + pot_dmp_es_AB = pots_AB.dmff_potentials['QqTtDampingForce'] + pot_dmp_disp_AB = pots_AB.dmff_potentials['SlaterDampingForce'] + + paramtree = H_AB.getParameters() + + # init positions used to set up neighbor list + pos_AB0 = jnp.array(pdb_AB.positions._value) * 10 + n_atoms = len(pos_AB0) + box = jnp.array(pdb_AB.topology.getPeriodicBoxVectors()._value) * 10 + + # nn list initial allocation + nbl_AB = NeighborList(box, rc, H_AB.getGenerators()[0].covalent_map) + nbl_AB.allocate(pos_AB0) + pairs_AB = nbl_AB.pairs + pairs_AB = pairs_AB[pairs_AB[:, 0] < pairs_AB[:, 1]] + + pos_AB = jnp.array(pos_AB0) + E_es = pot_pme_AB(pos_AB, box, pairs_AB, paramtree) + E_disp = pot_disp_AB(pos_AB, box, pairs_AB, paramtree) + E_ex = pot_ex_AB(pos_AB, box, pairs_AB, paramtree) + E_sr_es = pot_sr_es_AB(pos_AB, box, pairs_AB, paramtree) + E_sr_pol = pot_sr_pol_AB(pos_AB, box, pairs_AB, paramtree) + E_sr_disp = pot_sr_disp_AB(pos_AB, box, pairs_AB, paramtree) + E_dhf = pot_dhf_AB(pos_AB, box, pairs_AB, paramtree) + E_dmp_es = pot_dmp_es_AB(pos_AB, box, pairs_AB, paramtree) + E_dmp_disp = pot_dmp_disp_AB(pos_AB, box, pairs_AB, paramtree) + + npt.assert_almost_equal(E_es, values[0]) + npt.assert_almost_equal(E_disp, values[1]) + npt.assert_almost_equal(E_ex, values[2]) + npt.assert_almost_equal(E_sr_es, values[3]) + npt.assert_almost_equal(E_sr_pol, values[4]) + npt.assert_almost_equal(E_sr_disp, values[5]) + npt.assert_almost_equal(E_dhf, values[6]) + npt.assert_almost_equal(E_dmp_es, values[7]) + npt.assert_almost_equal(E_dmp_disp, values[8]) \ No newline at end of file From e92d18ffb7ec8c73b4838987d9e5eec562412231 Mon Sep 17 00:00:00 2001 From: JeremyDream Date: Wed, 1 Mar 2023 17:30:49 +0800 Subject: [PATCH 5/9] avoid nan when calculate the potential gradient of box vector --- dmff/admp/recip.py | 8 ++++- tests/test_admp/test_gradDisp.py | 55 ++++++++++++++++++++++++++++++++ 2 files changed, 62 insertions(+), 1 deletion(-) create mode 100644 tests/test_admp/test_gradDisp.py diff --git a/dmff/admp/recip.py b/dmff/admp/recip.py index 3987f23d8..f00d647ab 100755 --- a/dmff/admp/recip.py +++ b/dmff/admp/recip.py @@ -430,7 +430,7 @@ def spread_Q(positions, box, Q): return jnp.sum(E_k) * DIELECTRIC else: return jnp.sum(E_k) - + if DO_JIT: return jit(pme_recip, static_argnums=()) else: @@ -441,6 +441,8 @@ def Ck_1(ksq, kappa, V): return 2*jnp.pi/V/ksq * jnp.exp(-ksq/4/kappa**2) def Ck_6(ksq, kappa, V): + thresh = 1e-16 + ksq = jnp.piecewise(ksq, [ksq=thresh], [lambda x: jnp.array(thresh), lambda x: x]) x2 = ksq / 4 / kappa**2 x = jnp.sqrt(x2) x3 = x2 * x @@ -449,6 +451,8 @@ def Ck_6(ksq, kappa, V): return sqrt_pi*jnp.pi/2/V*kappa**3 * f / 3 def Ck_8(ksq, kappa, V): + thresh = 1e-16 + ksq = jnp.piecewise(ksq, [ksq=thresh], [lambda x: jnp.array(thresh), lambda x: x]) x2 = ksq / 4 / kappa**2 x = jnp.sqrt(x2) x4 = x2 * x2 @@ -458,6 +462,8 @@ def Ck_8(ksq, kappa, V): return sqrt_pi*jnp.pi/2/V*kappa**5 * f / 45 def Ck_10(ksq, kappa, V): + thresh = 1e-16 + ksq = jnp.piecewise(ksq, [ksq=thresh], [lambda x: jnp.array(thresh), lambda x: x]) x2 = ksq / 4 / kappa**2 x = jnp.sqrt(x2) x4 = x2 * x2 diff --git a/tests/test_admp/test_gradDisp.py b/tests/test_admp/test_gradDisp.py new file mode 100644 index 000000000..22051e4d4 --- /dev/null +++ b/tests/test_admp/test_gradDisp.py @@ -0,0 +1,55 @@ +import openmm.app as app +import openmm.unit as unit +import numpy as np +import jax.numpy as jnp +import numpy.testing as npt +import pytest +from dmff import Hamiltonian, NeighborList +from jax import jit, value_and_grad + +class TestGradDispersion: + + @pytest.mark.parametrize( + "pdb, prm, values", + [ + ( + "tests/data/peg2.pdb", + "tests/data/peg.xml", + # "peg2.pdb", + # "peg.xml", + jnp.array( + [[-9.3132257e-10, 4.0017767e-11, 4.6566129e-10], + [ 3.7289283e-11, 1.4551915e-11, 9.0949470e-13], + [ 6.9849193e-10, -1.0913936e-11, -1.1641532e-10]] + ) + ), + ] + ) + def test_admp_slater(self, pdb, prm, values): + pdb = app.PDBFile(pdb) + H = Hamiltonian(prm) + rc = 15 + pots = H.createPotential( + pdb.topology, + nonbondedCutoff=rc*unit.angstrom, + nonbondedMethod=app.CutoffPeriodic, + ethresh=1e-4) + + pot_disp = pots.dmff_potentials['ADMPDispPmeForce'] + + params = H.getParameters() + + # init positions used to set up neighbor list + pos = jnp.array(pdb.positions._value) * 10 + n_atoms = len(pos) + box = jnp.array(pdb.topology.getPeriodicBoxVectors()._value) * 10 + + # nn list initial allocation + nbl = NeighborList(box, rc, H.getGenerators()[0].covalent_map) + nbl.allocate(pos0) + pairs = nbl.pairs + + calc_disp = value_and_grad(pot_disp,argnums=(0,1)) + E, (F, V) = calc_disp(pos, box, pairs, params) + + npt.assert_almost_equal(V, values) \ No newline at end of file From 360c720aea5187fc9738ca9ed3fc07316c70f17a Mon Sep 17 00:00:00 2001 From: JeremyDream Date: Wed, 19 Apr 2023 22:25:58 +0800 Subject: [PATCH 6/9] address the nan problem of gradient of dispersion interaction to box vector --- tests/test_admp/test_gradDisp.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_admp/test_gradDisp.py b/tests/test_admp/test_gradDisp.py index 22051e4d4..83315ab1b 100644 --- a/tests/test_admp/test_gradDisp.py +++ b/tests/test_admp/test_gradDisp.py @@ -27,7 +27,7 @@ class TestGradDispersion: ) def test_admp_slater(self, pdb, prm, values): pdb = app.PDBFile(pdb) - H = Hamiltonian(prm) + H = Hamiltonian(prm) rc = 15 pots = H.createPotential( pdb.topology, @@ -52,4 +52,4 @@ def test_admp_slater(self, pdb, prm, values): calc_disp = value_and_grad(pot_disp,argnums=(0,1)) E, (F, V) = calc_disp(pos, box, pairs, params) - npt.assert_almost_equal(V, values) \ No newline at end of file + npt.assert_almost_equal(V, values) From 40ac3ecc0ffd8d826f6a43a6b0ed861f162793ab Mon Sep 17 00:00:00 2001 From: JeremyDream Date: Wed, 19 Apr 2023 22:39:03 +0800 Subject: [PATCH 7/9] modification of test_gradDisp.py --- tests/test_admp/test_gradDisp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_admp/test_gradDisp.py b/tests/test_admp/test_gradDisp.py index 83315ab1b..41b484422 100644 --- a/tests/test_admp/test_gradDisp.py +++ b/tests/test_admp/test_gradDisp.py @@ -46,7 +46,7 @@ def test_admp_slater(self, pdb, prm, values): # nn list initial allocation nbl = NeighborList(box, rc, H.getGenerators()[0].covalent_map) - nbl.allocate(pos0) + nbl.allocate(pos) pairs = nbl.pairs calc_disp = value_and_grad(pot_disp,argnums=(0,1)) From 88261c66480e4fba3d2f37092d809029aeaf3d01 Mon Sep 17 00:00:00 2001 From: JeremyDream Date: Wed, 19 Apr 2023 22:50:32 +0800 Subject: [PATCH 8/9] test file error correction --- tests/test_admp/test_gradDisp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_admp/test_gradDisp.py b/tests/test_admp/test_gradDisp.py index 41b484422..2f921a1ae 100644 --- a/tests/test_admp/test_gradDisp.py +++ b/tests/test_admp/test_gradDisp.py @@ -52,4 +52,4 @@ def test_admp_slater(self, pdb, prm, values): calc_disp = value_and_grad(pot_disp,argnums=(0,1)) E, (F, V) = calc_disp(pos, box, pairs, params) - npt.assert_almost_equal(V, values) + npt.assert_almost_equal(V.all(), values.all()) From b7e1b599c47c1281614eaa550afa07f3714c44a4 Mon Sep 17 00:00:00 2001 From: JeremyDream Date: Wed, 19 Apr 2023 23:00:34 +0800 Subject: [PATCH 9/9] test nan --- tests/test_admp/test_gradDisp.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/test_admp/test_gradDisp.py b/tests/test_admp/test_gradDisp.py index 2f921a1ae..157b7a3db 100644 --- a/tests/test_admp/test_gradDisp.py +++ b/tests/test_admp/test_gradDisp.py @@ -6,6 +6,9 @@ import pytest from dmff import Hamiltonian, NeighborList from jax import jit, value_and_grad +from jax.config import config + +config.update("jax_debug_nans", True) class TestGradDispersion: @@ -51,5 +54,3 @@ def test_admp_slater(self, pdb, prm, values): calc_disp = value_and_grad(pot_disp,argnums=(0,1)) E, (F, V) = calc_disp(pos, box, pairs, params) - - npt.assert_almost_equal(V.all(), values.all())