Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 55 additions & 12 deletions dmff/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
LennardJonesLongRangeFreeEnergyForce,
CoulombPMEFreeEnergyForce
)
from dmff.utils import jit_condition, isinstance_jnp
from dmff.utils import jit_condition, isinstance_jnp, DMFFException


class XMLNodeInfo:
Expand Down Expand Up @@ -152,6 +152,7 @@ def __init__(self, hamiltonian):
self.types = []
self.ethresh = 5e-4
self.pmax = 10
self.name = "ADMPDisp"

def registerAtomType(self, atom):
self.types.append(atom["type"])
Expand Down Expand Up @@ -281,6 +282,7 @@ def __init__(self, hamiltonian):
self.types = []
self.ethresh = 5e-4
self.pmax = 10
self.name = "ADMPDispPme"

def registerAtomType(self, atom):
self.types.append(atom["type"])
Expand Down Expand Up @@ -382,6 +384,7 @@ def __init__(self, hamiltonian):
}
self._jaxPotential = None
self.types = []
self.name = "QqTtDamping"

def registerAtomType(self, atom):
self.types.append(atom["type"])
Expand Down Expand Up @@ -462,6 +465,7 @@ def __init__(self, hamiltonian):
}
self._jaxPotential = None
self.types = []
self.name = "SlaterDamping"

def registerAtomType(self, atom):
self.types.append(atom["type"])
Expand Down Expand Up @@ -541,6 +545,7 @@ def __init__(self, hamiltonian):
}
self._jaxPotential = None
self.types = []
self.name = "SlaterEx"

def registerAtomType(self, atom):
self.types.append(atom["type"])
Expand Down Expand Up @@ -606,15 +611,19 @@ def renderXML(self):
class SlaterSrEsGenerator(SlaterExGenerator):
def __init__(self):
super().__init__(self)
self.name = "SlaterSrEs"
class SlaterSrPolGenerator(SlaterExGenerator):
def __init__(self):
super().__init__(self)
self.name = "SlaterSrPol"
class SlaterSrDispGenerator(SlaterExGenerator):
def __init__(self):
super().__init__(self)
self.name = "SlaterSrDisp"
class SlaterDhfGenerator(SlaterExGenerator):
def __init__(self):
super().__init__(self)
self.name = "SlaterDhf"

# register all parsers
app.forcefield.parsers["SlaterSrEsForce"] = SlaterSrEsGenerator.parseElement
Expand Down Expand Up @@ -670,6 +679,7 @@ def __init__(self, hamiltonian):
self.step_pol = None
self.lpol = False
self.ref_dip = ""
self.name = "ADMPPme"

def registerAtomType(self, atom: dict):

Expand Down Expand Up @@ -1149,6 +1159,7 @@ def __init__(self, hamiltonian):
self._jaxPotential = None
self.types = []
self.typetexts = []
self.name = "HarmonicBond"

def registerBondType(self, bond):
typetxt = findAtomTypeTexts(bond, 2)
Expand Down Expand Up @@ -1247,6 +1258,7 @@ def __init__(self, hamiltonian):
self.params = {"k": [], "angle": []}
self._jaxPotential = None
self.types = []
self.name = "HarmonicAngle"

def registerAngleType(self, angle):
types = self.ff._findAtomTypes(angle, 3)
Expand Down Expand Up @@ -1277,13 +1289,14 @@ def createForce(self, system, data, nonbondedMethod, nonbondedCutoff, args):
self.params[k] = jnp.array(self.params[k])
self.types = np.array(self.types)

n_angles = len(data.angles)
max_angles = len(data.angles)
n_angles = 0
# build map
map_atom1 = np.zeros(n_angles, dtype=int)
map_atom2 = np.zeros(n_angles, dtype=int)
map_atom3 = np.zeros(n_angles, dtype=int)
map_param = np.zeros(n_angles, dtype=int)
for i in range(n_angles):
map_atom1 = np.zeros(max_angles, dtype=int)
map_atom2 = np.zeros(max_angles, dtype=int)
map_atom3 = np.zeros(max_angles, dtype=int)
map_param = np.zeros(max_angles, dtype=int)
for i in range(max_angles):
idx1 = data.angles[i][0]
idx2 = data.angles[i][1]
idx3 = data.angles[i][2]
Expand All @@ -1296,17 +1309,23 @@ def createForce(self, system, data, nonbondedMethod, nonbondedCutoff, args):
if (type1 in self.types[ii][0] and type3 in self.types[ii][2]) or (
type1 in self.types[ii][2] and type3 in self.types[ii][0]
):
map_atom1[i] = idx1
map_atom2[i] = idx2
map_atom3[i] = idx3
map_param[i] = ii
map_atom1[n_angles] = idx1
map_atom2[n_angles] = idx2
map_atom3[n_angles] = idx3
map_param[n_angles] = ii
ifFound = True
n_angles += 1
break
if not ifFound:
raise BaseException(
print(
"No parameter for angle %i - %i - %i" % (idx1, idx2, idx3)
)

map_atom1 = map_atom1[:n_angles]
map_atom2 = map_atom2[:n_angles]
map_atom3 = map_atom3[:n_angles]
map_param = map_param[:n_angles]

aforce = HarmonicAngleJaxForce(map_atom1, map_atom2, map_atom3, map_param)

def potential_fn(positions, box, pairs, params):
Expand Down Expand Up @@ -1513,6 +1532,7 @@ def __init__(self, hamiltonian):
self.propersForAtomType = defaultdict(set)
self.n_proper = 0
self.n_improper = 0
self.name = "PeriodicTorsion"

def registerProperTorsion(self, parameters):
torsion = _parseTorsion(self.ff, parameters)
Expand Down Expand Up @@ -1955,6 +1975,7 @@ def __init__(self, hamiltionian, coulomb14scale, lj14scale):
}
self.types = []
self.useAttributeFromResidue = []
self.name = "Nonbond"


def registerAtom(self, atom):
Expand Down Expand Up @@ -2397,3 +2418,25 @@ def render(self, filename):

tree = ET.ElementTree(root)
tree.write(filename)

def getPotentialFunc(self):
if len(self._potentials) == 0:
raise DMFFException("Hamiltonian need to be initialized.")
efuncs = {}
for gen in self.getGenerators():
efuncs[gen.name] = gen._jaxPotential

def totalPE(positions, box, pairs, params):
totale = sum([
efuncs[k](positions, box, pairs, params[k])
for k in efuncs.keys()
])
return totale

return totalPE

def getParameters(self):
params = {}
for gen in self.getGenerators():
params[gen.name] = gen.params
return params
3 changes: 3 additions & 0 deletions dmff/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
from dmff.settings import DO_JIT


class DMFFException(BaseException):
pass

def jit_condition(*args, **kwargs):
def jit_deco(func):
if DO_JIT:
Expand Down
39 changes: 37 additions & 2 deletions tests/test_classical/test_gaff2.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,10 +65,45 @@ def test_gaff2_force(self, pdb, prm, values):
for jj in range(ii + 1, pos.shape[0]):
pairs.append((ii, jj))
pairs = np.array(pairs, dtype=int)

for ne, energy in enumerate(h._potentials):
E = energy(pos, box, pairs, h.getGenerators()[ne].params)
npt.assert_almost_equal(E, values[ne], decimal=3)

E = jax.jit(energy)(pos, box, pairs, h.getGenerators()[ne].params)
npt.assert_almost_equal(E, values[ne], decimal=3)
npt.assert_almost_equal(E, values[ne], decimal=3)

@pytest.mark.parametrize(
"pdb, prm, values",
[
(
"tests/data/lig.pdb",
["tests/data/gaff-2.11.xml", "tests/data/lig-prm-lj.xml"],
[
174.16702270507812, 99.81585693359375,
99.0631103515625, 22.778038024902344
]
),
]
)
def test_gaff2_total(self, pdb, prm, values):
app.Topology.loadBondDefinitions("tests/data/lig-top.xml")
pdb = app.PDBFile(pdb)
h = Hamiltonian(*prm)
system = h.createPotential(
pdb.topology,
nonbondedMethod=app.NoCutoff,
constraints=None,
removeCMMotion=False
)
pos = jnp.asarray(pdb.getPositions(asNumpy=True).value_in_unit(unit.nanometer))
box = np.array([[20.0, 0.0, 0.0], [0.0, 20.0, 0.0], [0.0, 0.0, 20.0]])
pairs = []
for ii in range(pos.shape[0]):
for jj in range(ii + 1, pos.shape[0]):
pairs.append((ii, jj))
pairs = np.array(pairs, dtype=int)
efunc = h.getPotentialFunc()
params = h.getParameters()
Eref = sum(values)
Ecalc = efunc(pos, box, pairs, params)
npt.assert_almost_equal(Ecalc, Eref, decimal=3)