Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion dmff/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -667,6 +667,7 @@ def __init__(self, hamiltonian):
self._jaxPotential = None
self.types = []
self.ethresh = 5e-4
self.step_pol = None
self.lpol = False
self.ref_dip = ""

Expand Down Expand Up @@ -1053,6 +1054,8 @@ def createForce(self, system, data, nonbondedMethod, nonbondedCutoff, args):

if "ethresh" in args:
self.ethresh = args["ethresh"]
if "step_pol" in args:
self.step_pol = args["step_pol"]

pme_force = ADMPPmeForce(
box,
Expand All @@ -1063,7 +1066,8 @@ def createForce(self, system, data, nonbondedMethod, nonbondedCutoff, args):
self.ethresh,
self.lmax,
self.lpol,
lpme=self.lpme
self.lpme,
self.step_pol
)
self.pme_force = pme_force

Expand Down
48 changes: 48 additions & 0 deletions tests/test_admp/test_compute.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import openmm.app as app
import openmm.unit as unit
import numpy as np
import jax.numpy as jnp
import numpy.testing as npt
import pytest
from dmff import Hamiltonian, NeighborList
from jax import jit, value_and_grad

class TestADMPAPI:

""" Test ADMP related generators
"""

@pytest.fixture(scope='class', name='generators')
def test_init(self):
"""load generators from XML file

Yields:
Tuple: (
ADMPDispForce,
ADMPPmeForce, # polarized
)
"""
rc = 4.0
H = Hamiltonian('tests/data/admp.xml')
pdb = app.PDBFile('tests/data/water_dimer.pdb')
H.createPotential(pdb.topology, nonbondedCutoff=rc*unit.angstrom, ethresh=5e-4, step_pol=5)

yield H.getGenerators()

def test_ADMPPmeForce_jit(self, generators):

gen = generators[1]
rc = 4.0
pdb = app.PDBFile('tests/data/water_dimer.pdb')
positions = jnp.array(pdb.positions._value) * 10
a, b, c = pdb.topology.getPeriodicBoxVectors()
box = jnp.array([a._value, b._value, c._value]) * 10
# neighbor list
nblist = NeighborList(box, rc)
nblist.allocate(positions)
pairs = nblist.pairs

pot_pme = gen.getJaxPotential()
j_pot_pme = jit(value_and_grad(pot_pme))

E_pme, F_pme = j_pot_pme(positions, box, pairs, gen.params)