diff --git a/dmff/api.py b/dmff/api.py index c205422c2..3219c8f70 100644 --- a/dmff/api.py +++ b/dmff/api.py @@ -667,6 +667,7 @@ def __init__(self, hamiltonian): self._jaxPotential = None self.types = [] self.ethresh = 5e-4 + self.step_pol = None self.lpol = False self.ref_dip = "" @@ -1053,6 +1054,8 @@ def createForce(self, system, data, nonbondedMethod, nonbondedCutoff, args): if "ethresh" in args: self.ethresh = args["ethresh"] + if "step_pol" in args: + self.step_pol = args["step_pol"] pme_force = ADMPPmeForce( box, @@ -1063,7 +1066,8 @@ def createForce(self, system, data, nonbondedMethod, nonbondedCutoff, args): self.ethresh, self.lmax, self.lpol, - lpme=self.lpme + self.lpme, + self.step_pol ) self.pme_force = pme_force diff --git a/tests/test_admp/test_compute.py b/tests/test_admp/test_compute.py new file mode 100644 index 000000000..e6ea1f3bb --- /dev/null +++ b/tests/test_admp/test_compute.py @@ -0,0 +1,48 @@ +import openmm.app as app +import openmm.unit as unit +import numpy as np +import jax.numpy as jnp +import numpy.testing as npt +import pytest +from dmff import Hamiltonian, NeighborList +from jax import jit, value_and_grad + +class TestADMPAPI: + + """ Test ADMP related generators + """ + + @pytest.fixture(scope='class', name='generators') + def test_init(self): + """load generators from XML file + + Yields: + Tuple: ( + ADMPDispForce, + ADMPPmeForce, # polarized + ) + """ + rc = 4.0 + H = Hamiltonian('tests/data/admp.xml') + pdb = app.PDBFile('tests/data/water_dimer.pdb') + H.createPotential(pdb.topology, nonbondedCutoff=rc*unit.angstrom, ethresh=5e-4, step_pol=5) + + yield H.getGenerators() + + def test_ADMPPmeForce_jit(self, generators): + + gen = generators[1] + rc = 4.0 + pdb = app.PDBFile('tests/data/water_dimer.pdb') + positions = jnp.array(pdb.positions._value) * 10 + a, b, c = pdb.topology.getPeriodicBoxVectors() + box = jnp.array([a._value, b._value, c._value]) * 10 + # neighbor list + nblist = NeighborList(box, rc) + nblist.allocate(positions) + pairs = nblist.pairs + + pot_pme = gen.getJaxPotential() + j_pot_pme = jit(value_and_grad(pot_pme)) + + E_pme, F_pme = j_pot_pme(positions, box, pairs, gen.params) \ No newline at end of file