diff --git a/dmff/admp/recip.py b/dmff/admp/recip.py index 3987f23d8..f00d647ab 100755 --- a/dmff/admp/recip.py +++ b/dmff/admp/recip.py @@ -430,7 +430,7 @@ def spread_Q(positions, box, Q): return jnp.sum(E_k) * DIELECTRIC else: return jnp.sum(E_k) - + if DO_JIT: return jit(pme_recip, static_argnums=()) else: @@ -441,6 +441,8 @@ def Ck_1(ksq, kappa, V): return 2*jnp.pi/V/ksq * jnp.exp(-ksq/4/kappa**2) def Ck_6(ksq, kappa, V): + thresh = 1e-16 + ksq = jnp.piecewise(ksq, [ksq=thresh], [lambda x: jnp.array(thresh), lambda x: x]) x2 = ksq / 4 / kappa**2 x = jnp.sqrt(x2) x3 = x2 * x @@ -449,6 +451,8 @@ def Ck_6(ksq, kappa, V): return sqrt_pi*jnp.pi/2/V*kappa**3 * f / 3 def Ck_8(ksq, kappa, V): + thresh = 1e-16 + ksq = jnp.piecewise(ksq, [ksq=thresh], [lambda x: jnp.array(thresh), lambda x: x]) x2 = ksq / 4 / kappa**2 x = jnp.sqrt(x2) x4 = x2 * x2 @@ -458,6 +462,8 @@ def Ck_8(ksq, kappa, V): return sqrt_pi*jnp.pi/2/V*kappa**5 * f / 45 def Ck_10(ksq, kappa, V): + thresh = 1e-16 + ksq = jnp.piecewise(ksq, [ksq=thresh], [lambda x: jnp.array(thresh), lambda x: x]) x2 = ksq / 4 / kappa**2 x = jnp.sqrt(x2) x4 = x2 * x2 diff --git a/dmff/generators/admp.py b/dmff/generators/admp.py index ce7c4edec..8e9d5f216 100644 --- a/dmff/generators/admp.py +++ b/dmff/generators/admp.py @@ -599,26 +599,63 @@ def getMetaData(self): # Here are all the short range "charge penetration" terms -# They all have the exchange form +# They all have the exchange form with minus sign class SlaterSrEsGenerator(SlaterExGenerator): def __init__(self, ff): super().__init__(ff) self.name = "SlaterSrEsForce" + def createForce(self, system, data, nonbondedMethod, nonbondedCutoff, + args): + + n_atoms = len(data.atoms) + # build index map + map_atomtype = np.zeros(n_atoms, dtype=int) + for i in range(n_atoms): + atype = data.atomType[data.atoms[i]] + map_atomtype[i] = np.where(self.atomTypes == atype)[0][0] + self.map_atomtype = map_atomtype + # build covalent map + self.covalent_map = build_covalent_map(data, 6) + + self._meta["cov_map"] = self.covalent_map + self._meta[self.name+"_map_atomtype"] = self.map_atomtype + + pot_fn_sr = generate_pairwise_interaction(slater_sr_kernel, + static_args={}) + + def potential_fn(positions, box, pairs, params): + params = params[self.name] + mScales = params["mScales"] + a_list = params["A"][map_atomtype] + b_list = params["B"][map_atomtype] / 10 # nm^-1 to A^-1 + + # add minus sign + return - pot_fn_sr(positions, box, pairs, mScales, a_list, b_list) + + self._jaxPotential = potential_fn + # self._top_data = data + + def getJaxPotential(self): + return self._jaxPotential + + def getMetaData(self): + return self._meta + -class SlaterSrPolGenerator(SlaterExGenerator): +class SlaterSrPolGenerator(SlaterSrEsGenerator): def __init__(self, ff): super().__init__(ff) self.name = "SlaterSrPolForce" -class SlaterSrDispGenerator(SlaterExGenerator): +class SlaterSrDispGenerator(SlaterSrEsGenerator): def __init__(self, ff): super().__init__(ff) self.name = "SlaterSrDispForce" -class SlaterDhfGenerator(SlaterExGenerator): +class SlaterDhfGenerator(SlaterSrEsGenerator): def __init__(self, ff): super().__init__(ff) self.name = "SlaterDhfForce" diff --git a/docs/user_guide/installation.md b/docs/user_guide/installation.md index d4cd73550..19f0f31b7 100644 --- a/docs/user_guide/installation.md +++ b/docs/user_guide/installation.md @@ -13,6 +13,12 @@ pip install jax==0.3.17 ```bash pip install jax-md==0.2.0 ``` ++ Install [mdtraj](https://github.com/mdtraj/mdtraj), [optax](https://github.com/deepmind/optax) and [pymbar](https://github.com/choderalab/pymbar): +```bash +conda install -c conda-forge mdtraj==1.9.7 +pip install optax==0.1.3 +pip install pymbar==4.0.1 +``` + Install [OpenMM](https://openmm.org/): ```bash conda install -c conda-forge openmm==7.7.0 diff --git a/examples/mbar/ben-prm.xml b/examples/mbar/ben-prm.xml index bac99356d..51ca52004 100644 --- a/examples/mbar/ben-prm.xml +++ b/examples/mbar/ben-prm.xml @@ -18,11 +18,6 @@ - - - - - diff --git a/examples/peg_slater_isa/check_calc.py b/examples/peg_slater_isa/check_calc.py new file mode 100755 index 000000000..2b68838d8 --- /dev/null +++ b/examples/peg_slater_isa/check_calc.py @@ -0,0 +1,58 @@ +#!/usr/bin/env python +import sys +import numpy as np +import openmm +from openmm import * +from openmm.app import * +from openmm.unit import * +import jax +import jax_md +import jax.numpy as jnp +import dmff +from dmff.api import Hamiltonian +from dmff.common import nblist +import pickle +import time +from jax import value_and_grad, jit + +if __name__ == '__main__': + ff = 'peg.xml' + pdb_AB = PDBFile('peg2.pdb') + H_AB = Hamiltonian(ff) + rc = 15 + # get potential functions + pots_AB = H_AB.createPotential(pdb_AB.topology, nonbondedCutoff=rc*angstrom, nonbondedMethod=CutoffPeriodic, ethresh=1e-4) + pot_pme_AB = pots_AB.dmff_potentials['ADMPPmeForce'] + pot_disp_AB = pots_AB.dmff_potentials['ADMPDispPmeForce'] + pot_ex_AB = pots_AB.dmff_potentials['SlaterExForce'] + pot_sr_es_AB = pots_AB.dmff_potentials['SlaterSrEsForce'] + pot_sr_pol_AB = pots_AB.dmff_potentials['SlaterSrPolForce'] + pot_sr_disp_AB = pots_AB.dmff_potentials['SlaterSrDispForce'] + pot_dhf_AB = pots_AB.dmff_potentials['SlaterDhfForce'] + pot_dmp_es_AB = pots_AB.dmff_potentials['QqTtDampingForce'] + pot_dmp_disp_AB = pots_AB.dmff_potentials['SlaterDampingForce'] + + paramtree = H_AB.getParameters() + + # init positions used to set up neighbor list + pos_AB0 = jnp.array(pdb_AB.positions._value) * 10 + n_atoms = len(pos_AB0) + box = jnp.array(pdb_AB.topology.getPeriodicBoxVectors()._value) * 10 + + # nn list initial allocation + nbl_AB = nblist.NeighborList(box, rc, H_AB.getGenerators()[0].covalent_map) + nbl_AB.allocate(pos_AB0) + pairs_AB = nbl_AB.pairs + pairs_AB = pairs_AB[pairs_AB[:, 0] < pairs_AB[:, 1]] + + pos_AB = jnp.array(pos_AB0) + E_es = pot_pme_AB(pos_AB, box, pairs_AB, paramtree) + E_disp = pot_disp_AB(pos_AB, box, pairs_AB, paramtree) + E_ex_AB = pot_ex_AB(pos_AB, box, pairs_AB, paramtree) + E_sr_es = pot_sr_es_AB(pos_AB, box, pairs_AB, paramtree) + E_sr_pol = pot_sr_pol_AB(pos_AB, box, pairs_AB, paramtree) + E_sr_disp = pot_sr_disp_AB(pos_AB, box, pairs_AB, paramtree) + E_dhf = pot_dhf_AB(pos_AB, box, pairs_AB, paramtree) + E_dmp_es = pot_dmp_es_AB(pos_AB, box, pairs_AB, paramtree) + E_dmp_disp = pot_dmp_disp_AB(pos_AB, box, pairs_AB, paramtree) + print(E_es, E_disp, E_ex_AB, E_sr_es, E_sr_pol, E_sr_disp, E_dhf, E_dmp_es, E_dmp_disp) diff --git a/examples/peg_slater_isa/peg.xml b/examples/peg_slater_isa/peg.xml new file mode 100644 index 000000000..25d90db50 --- /dev/null +++ b/examples/peg_slater_isa/peg.xml @@ -0,0 +1,145 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/tests/data/peg.xml b/tests/data/peg.xml new file mode 100644 index 000000000..25d90db50 --- /dev/null +++ b/tests/data/peg.xml @@ -0,0 +1,145 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/tests/data/peg2.pdb b/tests/data/peg2.pdb new file mode 100644 index 000000000..e44d9f14b --- /dev/null +++ b/tests/data/peg2.pdb @@ -0,0 +1,35 @@ +TITLE MDANALYSIS FRAME 0: Created by PDBWriter +CRYST1 30.000 30.000 30.000 90.00 90.00 90.00 P 1 1 +ATOM 1 C00 TER M 1 14.058 13.500 16.329 0.00 1.00 SYST C +ATOM 2 H01 TER M 1 14.410 14.387 16.870 0.00 1.00 SYST H +ATOM 3 H02 TER M 1 14.412 12.614 16.871 0.00 1.00 SYST H +ATOM 4 O03 TER M 1 14.578 13.500 15.000 0.00 1.00 SYST O +ATOM 5 C04 TER M 1 16.000 13.500 15.000 0.00 1.00 SYST C +ATOM 6 H05 TER M 1 16.344 13.499 13.962 0.00 1.00 SYST H +ATOM 7 H06 TER M 1 16.382 12.602 15.496 0.00 1.00 SYST H +ATOM 8 H07 TER M 1 16.382 14.399 15.493 0.00 1.00 SYST H +ATOM 9 C00 TER M 2 12.535 13.498 16.276 0.00 1.00 SYST C +ATOM 10 H01 TER M 2 12.184 12.612 15.734 0.00 1.00 SYST H +ATOM 11 H02 TER M 2 12.182 14.385 15.735 0.00 1.00 SYST H +ATOM 12 O03 TER M 2 12.015 13.496 17.605 0.00 1.00 SYST O +ATOM 13 C04 TER M 2 10.593 13.493 17.605 0.00 1.00 SYST C +ATOM 14 H05 TER M 2 10.250 13.491 18.643 0.00 1.00 SYST H +ATOM 15 H06 TER M 2 10.213 12.595 17.109 0.00 1.00 SYST H +ATOM 16 H07 TER M 2 10.209 14.391 17.112 0.00 1.00 SYST H +CONECT 1 2 3 4 9 +CONECT 2 1 +CONECT 3 1 +CONECT 4 1 5 +CONECT 5 4 6 7 8 +CONECT 6 5 +CONECT 7 5 +CONECT 8 5 +CONECT 9 1 10 11 12 +CONECT 10 9 +CONECT 11 9 +CONECT 12 9 13 +CONECT 13 12 14 15 16 +CONECT 14 13 +CONECT 15 13 +CONECT 16 13 +END diff --git a/tests/test_admp/test_gradDisp.py b/tests/test_admp/test_gradDisp.py new file mode 100644 index 000000000..157b7a3db --- /dev/null +++ b/tests/test_admp/test_gradDisp.py @@ -0,0 +1,56 @@ +import openmm.app as app +import openmm.unit as unit +import numpy as np +import jax.numpy as jnp +import numpy.testing as npt +import pytest +from dmff import Hamiltonian, NeighborList +from jax import jit, value_and_grad +from jax.config import config + +config.update("jax_debug_nans", True) + +class TestGradDispersion: + + @pytest.mark.parametrize( + "pdb, prm, values", + [ + ( + "tests/data/peg2.pdb", + "tests/data/peg.xml", + # "peg2.pdb", + # "peg.xml", + jnp.array( + [[-9.3132257e-10, 4.0017767e-11, 4.6566129e-10], + [ 3.7289283e-11, 1.4551915e-11, 9.0949470e-13], + [ 6.9849193e-10, -1.0913936e-11, -1.1641532e-10]] + ) + ), + ] + ) + def test_admp_slater(self, pdb, prm, values): + pdb = app.PDBFile(pdb) + H = Hamiltonian(prm) + rc = 15 + pots = H.createPotential( + pdb.topology, + nonbondedCutoff=rc*unit.angstrom, + nonbondedMethod=app.CutoffPeriodic, + ethresh=1e-4) + + pot_disp = pots.dmff_potentials['ADMPDispPmeForce'] + + params = H.getParameters() + + # init positions used to set up neighbor list + pos = jnp.array(pdb.positions._value) * 10 + n_atoms = len(pos) + box = jnp.array(pdb.topology.getPeriodicBoxVectors()._value) * 10 + + # nn list initial allocation + nbl = NeighborList(box, rc, H.getGenerators()[0].covalent_map) + nbl.allocate(pos) + pairs = nbl.pairs + + calc_disp = value_and_grad(pot_disp,argnums=(0,1)) + E, (F, V) = calc_disp(pos, box, pairs, params) diff --git a/tests/test_admp/test_slater.py b/tests/test_admp/test_slater.py new file mode 100644 index 000000000..a239534cf --- /dev/null +++ b/tests/test_admp/test_slater.py @@ -0,0 +1,79 @@ +import openmm.app as app +import openmm.unit as unit +import numpy as np +import jax.numpy as jnp +import numpy.testing as npt +import pytest +from dmff import Hamiltonian, NeighborList +from jax import jit, value_and_grad + +class TestADMPSlaterTypeFunction: + + @pytest.mark.parametrize( + "pdb, prm, values", + [ + ( + "tests/data/peg2.pdb", + "tests/data/peg.xml", + # "peg2.pdb", + # "peg.xml", + jnp.array([ + 5.731787, -0.053504415, 4.510135e-09, -4.510135e-09, -4.510135e-09, + -4.510135e-09, -4.510135e-09, -5.505022e-11, 1.0166141e-06 + ]) + ), + ] + ) + def test_admp_slater(self, pdb, prm, values): + pdb_AB = app.PDBFile(pdb) + H_AB = Hamiltonian(prm) + rc = 15 + pots_AB = H_AB.createPotential( + pdb_AB.topology, + nonbondedCutoff=rc*unit.angstrom, + nonbondedMethod=app.CutoffPeriodic, + ethresh=1e-4) + + pot_pme_AB = pots_AB.dmff_potentials['ADMPPmeForce'] + pot_disp_AB = pots_AB.dmff_potentials['ADMPDispPmeForce'] + pot_ex_AB = pots_AB.dmff_potentials['SlaterExForce'] + pot_sr_es_AB = pots_AB.dmff_potentials['SlaterSrEsForce'] + pot_sr_pol_AB = pots_AB.dmff_potentials['SlaterSrPolForce'] + pot_sr_disp_AB = pots_AB.dmff_potentials['SlaterSrDispForce'] + pot_dhf_AB = pots_AB.dmff_potentials['SlaterDhfForce'] + pot_dmp_es_AB = pots_AB.dmff_potentials['QqTtDampingForce'] + pot_dmp_disp_AB = pots_AB.dmff_potentials['SlaterDampingForce'] + + paramtree = H_AB.getParameters() + + # init positions used to set up neighbor list + pos_AB0 = jnp.array(pdb_AB.positions._value) * 10 + n_atoms = len(pos_AB0) + box = jnp.array(pdb_AB.topology.getPeriodicBoxVectors()._value) * 10 + + # nn list initial allocation + nbl_AB = NeighborList(box, rc, H_AB.getGenerators()[0].covalent_map) + nbl_AB.allocate(pos_AB0) + pairs_AB = nbl_AB.pairs + pairs_AB = pairs_AB[pairs_AB[:, 0] < pairs_AB[:, 1]] + + pos_AB = jnp.array(pos_AB0) + E_es = pot_pme_AB(pos_AB, box, pairs_AB, paramtree) + E_disp = pot_disp_AB(pos_AB, box, pairs_AB, paramtree) + E_ex = pot_ex_AB(pos_AB, box, pairs_AB, paramtree) + E_sr_es = pot_sr_es_AB(pos_AB, box, pairs_AB, paramtree) + E_sr_pol = pot_sr_pol_AB(pos_AB, box, pairs_AB, paramtree) + E_sr_disp = pot_sr_disp_AB(pos_AB, box, pairs_AB, paramtree) + E_dhf = pot_dhf_AB(pos_AB, box, pairs_AB, paramtree) + E_dmp_es = pot_dmp_es_AB(pos_AB, box, pairs_AB, paramtree) + E_dmp_disp = pot_dmp_disp_AB(pos_AB, box, pairs_AB, paramtree) + + npt.assert_almost_equal(E_es, values[0]) + npt.assert_almost_equal(E_disp, values[1]) + npt.assert_almost_equal(E_ex, values[2]) + npt.assert_almost_equal(E_sr_es, values[3]) + npt.assert_almost_equal(E_sr_pol, values[4]) + npt.assert_almost_equal(E_sr_disp, values[5]) + npt.assert_almost_equal(E_dhf, values[6]) + npt.assert_almost_equal(E_dmp_es, values[7]) + npt.assert_almost_equal(E_dmp_disp, values[8]) \ No newline at end of file