diff --git a/dmff/api.py b/dmff/api.py index 3219c8f70..0ef52393e 100644 --- a/dmff/api.py +++ b/dmff/api.py @@ -40,7 +40,7 @@ LennardJonesLongRangeFreeEnergyForce, CoulombPMEFreeEnergyForce ) -from dmff.utils import jit_condition, isinstance_jnp +from dmff.utils import jit_condition, isinstance_jnp, DMFFException class XMLNodeInfo: @@ -152,6 +152,7 @@ def __init__(self, hamiltonian): self.types = [] self.ethresh = 5e-4 self.pmax = 10 + self.name = "ADMPDisp" def registerAtomType(self, atom): self.types.append(atom["type"]) @@ -281,6 +282,7 @@ def __init__(self, hamiltonian): self.types = [] self.ethresh = 5e-4 self.pmax = 10 + self.name = "ADMPDispPme" def registerAtomType(self, atom): self.types.append(atom["type"]) @@ -382,6 +384,7 @@ def __init__(self, hamiltonian): } self._jaxPotential = None self.types = [] + self.name = "QqTtDamping" def registerAtomType(self, atom): self.types.append(atom["type"]) @@ -462,6 +465,7 @@ def __init__(self, hamiltonian): } self._jaxPotential = None self.types = [] + self.name = "SlaterDamping" def registerAtomType(self, atom): self.types.append(atom["type"]) @@ -541,6 +545,7 @@ def __init__(self, hamiltonian): } self._jaxPotential = None self.types = [] + self.name = "SlaterEx" def registerAtomType(self, atom): self.types.append(atom["type"]) @@ -606,15 +611,19 @@ def renderXML(self): class SlaterSrEsGenerator(SlaterExGenerator): def __init__(self): super().__init__(self) + self.name = "SlaterSrEs" class SlaterSrPolGenerator(SlaterExGenerator): def __init__(self): super().__init__(self) + self.name = "SlaterSrPol" class SlaterSrDispGenerator(SlaterExGenerator): def __init__(self): super().__init__(self) + self.name = "SlaterSrDisp" class SlaterDhfGenerator(SlaterExGenerator): def __init__(self): super().__init__(self) + self.name = "SlaterDhf" # register all parsers app.forcefield.parsers["SlaterSrEsForce"] = SlaterSrEsGenerator.parseElement @@ -670,6 +679,7 @@ def __init__(self, hamiltonian): self.step_pol = None self.lpol = False self.ref_dip = "" + self.name = "ADMPPme" def registerAtomType(self, atom: dict): @@ -1149,6 +1159,7 @@ def __init__(self, hamiltonian): self._jaxPotential = None self.types = [] self.typetexts = [] + self.name = "HarmonicBond" def registerBondType(self, bond): typetxt = findAtomTypeTexts(bond, 2) @@ -1247,6 +1258,7 @@ def __init__(self, hamiltonian): self.params = {"k": [], "angle": []} self._jaxPotential = None self.types = [] + self.name = "HarmonicAngle" def registerAngleType(self, angle): types = self.ff._findAtomTypes(angle, 3) @@ -1277,13 +1289,14 @@ def createForce(self, system, data, nonbondedMethod, nonbondedCutoff, args): self.params[k] = jnp.array(self.params[k]) self.types = np.array(self.types) - n_angles = len(data.angles) + max_angles = len(data.angles) + n_angles = 0 # build map - map_atom1 = np.zeros(n_angles, dtype=int) - map_atom2 = np.zeros(n_angles, dtype=int) - map_atom3 = np.zeros(n_angles, dtype=int) - map_param = np.zeros(n_angles, dtype=int) - for i in range(n_angles): + map_atom1 = np.zeros(max_angles, dtype=int) + map_atom2 = np.zeros(max_angles, dtype=int) + map_atom3 = np.zeros(max_angles, dtype=int) + map_param = np.zeros(max_angles, dtype=int) + for i in range(max_angles): idx1 = data.angles[i][0] idx2 = data.angles[i][1] idx3 = data.angles[i][2] @@ -1296,17 +1309,23 @@ def createForce(self, system, data, nonbondedMethod, nonbondedCutoff, args): if (type1 in self.types[ii][0] and type3 in self.types[ii][2]) or ( type1 in self.types[ii][2] and type3 in self.types[ii][0] ): - map_atom1[i] = idx1 - map_atom2[i] = idx2 - map_atom3[i] = idx3 - map_param[i] = ii + map_atom1[n_angles] = idx1 + map_atom2[n_angles] = idx2 + map_atom3[n_angles] = idx3 + map_param[n_angles] = ii ifFound = True + n_angles += 1 break if not ifFound: - raise BaseException( + print( "No parameter for angle %i - %i - %i" % (idx1, idx2, idx3) ) + map_atom1 = map_atom1[:n_angles] + map_atom2 = map_atom2[:n_angles] + map_atom3 = map_atom3[:n_angles] + map_param = map_param[:n_angles] + aforce = HarmonicAngleJaxForce(map_atom1, map_atom2, map_atom3, map_param) def potential_fn(positions, box, pairs, params): @@ -1513,6 +1532,7 @@ def __init__(self, hamiltonian): self.propersForAtomType = defaultdict(set) self.n_proper = 0 self.n_improper = 0 + self.name = "PeriodicTorsion" def registerProperTorsion(self, parameters): torsion = _parseTorsion(self.ff, parameters) @@ -1955,6 +1975,7 @@ def __init__(self, hamiltionian, coulomb14scale, lj14scale): } self.types = [] self.useAttributeFromResidue = [] + self.name = "Nonbond" def registerAtom(self, atom): @@ -2397,3 +2418,25 @@ def render(self, filename): tree = ET.ElementTree(root) tree.write(filename) + + def getPotentialFunc(self): + if len(self._potentials) == 0: + raise DMFFException("Hamiltonian need to be initialized.") + efuncs = {} + for gen in self.getGenerators(): + efuncs[gen.name] = gen._jaxPotential + + def totalPE(positions, box, pairs, params): + totale = sum([ + efuncs[k](positions, box, pairs, params[k]) + for k in efuncs.keys() + ]) + return totale + + return totalPE + + def getParameters(self): + params = {} + for gen in self.getGenerators(): + params[gen.name] = gen.params + return params \ No newline at end of file diff --git a/dmff/utils.py b/dmff/utils.py index 9a5761589..be00d9c64 100644 --- a/dmff/utils.py +++ b/dmff/utils.py @@ -3,6 +3,9 @@ from dmff.settings import DO_JIT +class DMFFException(BaseException): + pass + def jit_condition(*args, **kwargs): def jit_deco(func): if DO_JIT: diff --git a/tests/test_classical/test_gaff2.py b/tests/test_classical/test_gaff2.py index 4062eadd9..8916f5488 100644 --- a/tests/test_classical/test_gaff2.py +++ b/tests/test_classical/test_gaff2.py @@ -65,10 +65,45 @@ def test_gaff2_force(self, pdb, prm, values): for jj in range(ii + 1, pos.shape[0]): pairs.append((ii, jj)) pairs = np.array(pairs, dtype=int) - for ne, energy in enumerate(h._potentials): E = energy(pos, box, pairs, h.getGenerators()[ne].params) npt.assert_almost_equal(E, values[ne], decimal=3) E = jax.jit(energy)(pos, box, pairs, h.getGenerators()[ne].params) - npt.assert_almost_equal(E, values[ne], decimal=3) \ No newline at end of file + npt.assert_almost_equal(E, values[ne], decimal=3) + + @pytest.mark.parametrize( + "pdb, prm, values", + [ + ( + "tests/data/lig.pdb", + ["tests/data/gaff-2.11.xml", "tests/data/lig-prm-lj.xml"], + [ + 174.16702270507812, 99.81585693359375, + 99.0631103515625, 22.778038024902344 + ] + ), + ] + ) + def test_gaff2_total(self, pdb, prm, values): + app.Topology.loadBondDefinitions("tests/data/lig-top.xml") + pdb = app.PDBFile(pdb) + h = Hamiltonian(*prm) + system = h.createPotential( + pdb.topology, + nonbondedMethod=app.NoCutoff, + constraints=None, + removeCMMotion=False + ) + pos = jnp.asarray(pdb.getPositions(asNumpy=True).value_in_unit(unit.nanometer)) + box = np.array([[20.0, 0.0, 0.0], [0.0, 20.0, 0.0], [0.0, 0.0, 20.0]]) + pairs = [] + for ii in range(pos.shape[0]): + for jj in range(ii + 1, pos.shape[0]): + pairs.append((ii, jj)) + pairs = np.array(pairs, dtype=int) + efunc = h.getPotentialFunc() + params = h.getParameters() + Eref = sum(values) + Ecalc = efunc(pos, box, pairs, params) + npt.assert_almost_equal(Ecalc, Eref, decimal=3) \ No newline at end of file