diff --git a/dmff/api.py b/dmff/api.py index fc8788231..d8c7c24a5 100644 --- a/dmff/api.py +++ b/dmff/api.py @@ -7,21 +7,20 @@ import jax.numpy as jnp from collections import defaultdict import xml.etree.ElementTree as ET + +from dmff.utils import isinstance_jnp from .admp.disp_pme import ADMPDispPmeForce -from .admp.multipole import convert_cart2harm, rot_local2global +from .admp.multipole import convert_cart2harm from .admp.pairwise import TT_damping_qq_c6_kernel, generate_pairwise_interaction from .admp.pairwise import slater_disp_damping_kernel, slater_sr_kernel, TT_damping_qq_kernel from .admp.pme import ADMPPmeForce -from .admp.spatial import generate_construct_local_frames -from .admp.recip import Ck_1, generate_pme_recip -from .utils import jit_condition from .classical.intra import ( HarmonicBondJaxForce, HarmonicAngleJaxForce, PeriodicTorsionJaxForce, ) from jax_md import space, partition -from jax import grad, vmap +from jax import grad import linecache import itertools from .classical.inter import ( @@ -30,7 +29,6 @@ CoulNoCutoffForce, CoulReactionFieldForce, ) - import sys @@ -211,7 +209,7 @@ def renderXML(self): class ADMPDispPmeGenerator: - ''' + r''' This one computes the undamped C6/C8/C10 interactions u = \sum_{ij} c6/r^6 + c8/r^8 + c10/r^10 ''' @@ -316,7 +314,7 @@ def renderXML(self): app.forcefield.parsers["ADMPDispPmeForce"] = ADMPDispPmeGenerator.parseElement class QqTtDampingGenerator: - ''' + r''' This one calculates the tang-tonnies damping of charge-charge interaction E = \sum_ij exp(-B*r)*(1+B*r)*q_i*q_j/r ''' @@ -392,7 +390,7 @@ def renderXML(self): class SlaterDampingGenerator: - ''' + r''' This one computes the slater-type damping function for c6/c8/c10 dispersion E = \sum_ij (f6-1)*c6/r6 + (f8-1)*c8/r8 + (f10-1)*c10/r10 fn = f_tt(x, n) @@ -474,7 +472,7 @@ def renderXML(self): class SlaterExGenerator: - ''' + r''' This one computes the Slater-ISA type exchange interaction u = \sum_ij A * (1/3*(Br)^2 + Br + 1) ''' @@ -633,7 +631,7 @@ def registerAtomType(self, atom: dict): @staticmethod def parseElement(element, hamiltonian): - """ parse admp related parameters in XML file + r""" parse admp related parameters in XML file example: @@ -1068,7 +1066,7 @@ def registerBondType(self, bond): @staticmethod def parseElement(element, hamiltonian): - """parse section in XML file + r"""parse section in XML file example: @@ -1160,7 +1158,7 @@ def registerAngleType(self, angle): @staticmethod def parseElement(element, hamiltonian): - """ parse section in XML file + r""" parse section in XML file example: @@ -1988,7 +1986,12 @@ def createForce(self, system, data, nonbondedMethod, nonbondedCutoff, args): coulenergy = coulforce.generate_get_energy() def potential_fn(positions, box, pairs, params): - + + # check whether args passed into potential_fn are jnp.array and differentiable + # note this check will be optimized away by jit + # it is jit-compatiable + isinstance_jnp(positions, box, params) + ljE = ljenergy( positions, box, diff --git a/dmff/classical/inter.py b/dmff/classical/inter.py index 7051c3903..849b781df 100644 --- a/dmff/classical/inter.py +++ b/dmff/classical/inter.py @@ -1,4 +1,4 @@ -from dmff.admp.pairwise import distribute_scalar +from dmff.utils import pair_buffer_scales, regularize_pairs import jax.numpy as jnp from dmff.admp.pme import energy_pme, setup_ewald_parameters from dmff.admp.recip import generate_pme_recip @@ -28,7 +28,7 @@ def __init__( self.r_switch = r_switch self.r_cut = r_cut - self.map_prm = map_prm + self.map_prm = jnp.array(map_prm) self.map_nbfix = map_nbfix self.ifPBC = isPBC self.ifNoCut = isNoCut @@ -59,6 +59,10 @@ def get_LJ_energy(dr_vec, sig, eps, box): return E def get_energy(positions, box, pairs, epsilon, sigma, epsfix, sigfix, mscales): + + pairs = regularize_pairs(pairs) + mask = pair_buffer_scales(pairs) + map_prm = self.map_prm eps_m1 = jnp.repeat(epsilon.reshape((-1, 1)), epsilon.shape[0], axis=1) eps_m2 = eps_m1.T @@ -77,8 +81,8 @@ def get_energy(positions, box, pairs, epsilon, sigma, epsfix, sigfix, mscales): dr_vec = positions[pairs[:, 0]] - positions[pairs[:, 1]] - prm_pair0 = self.map_prm[pairs[:, 0]] - prm_pair1 = self.map_prm[pairs[:, 1]] + prm_pair0 = map_prm[pairs[:, 0]] + prm_pair1 = map_prm[pairs[:, 1]] eps = eps_mat[prm_pair0, prm_pair1] sig = sig_mat[prm_pair0, prm_pair1] @@ -86,7 +90,7 @@ def get_energy(positions, box, pairs, epsilon, sigma, epsfix, sigfix, mscales): E_inter = get_LJ_energy(dr_vec, sig, eps_scale, box) - return jnp.sum(E_inter) + return jnp.sum(E_inter * mask) return get_energy @@ -110,12 +114,16 @@ def get_coul_energy(dr_vec, chrgprod, box): return E def get_energy(positions, box, pairs, charges, mscales): + + pairs = regularize_pairs(pairs) + mask = pair_buffer_scales(pairs) + map_prm = jnp.array(self.map_prm) colv_pair = self.colvmap[pairs[:,0],pairs[:,1]] mscale_pair = mscales[colv_pair-1] - chrg_map0 = self.map_prm[pairs[:, 0]] - chrg_map1 = self.map_prm[pairs[:, 1]] + chrg_map0 = map_prm[pairs[:, 0]] + chrg_map1 = map_prm[pairs[:, 1]] charge0 = charges[chrg_map0] charge1 = charges[chrg_map1] chrgprod = charge0 * charge1 @@ -124,7 +132,7 @@ def get_energy(positions, box, pairs, charges, mscales): E_inter = get_coul_energy(dr_vec, chrgprod_scale, box) - return jnp.sum(E_inter) + return jnp.sum(E_inter * mask) return get_energy @@ -169,6 +177,9 @@ def get_rf_energy(dr_vec, chrgprod, box): return E def get_energy(positions, box, pairs, charges, mscales): + + pairs = regularize_pairs(pairs) + mask = pair_buffer_scales(pairs) colv_pair = self.colvmap[pairs[:,0],pairs[:,1]] mscale_pair = mscales[colv_pair-1] @@ -183,7 +194,7 @@ def get_energy(positions, box, pairs, charges, mscales): E_inter = get_rf_energy(dr_vec, chrgprod_scale, box) - return jnp.sum(E_inter) + return jnp.sum(E_inter * mask) return get_energy @@ -198,7 +209,8 @@ def __init__(self, box, rc, ethresh, covalent_map): def generate_get_energy(self): def get_energy(positions, box, pairs, Q, mScales): - + # Not required regularize_pairs + # already done in the pme code return energy_pme( positions, box, diff --git a/dmff/utils.py b/dmff/utils.py index 1d5d9ac0d..7aa4dae2e 100644 --- a/dmff/utils.py +++ b/dmff/utils.py @@ -1,5 +1,5 @@ from dmff.settings import DO_JIT -from jax import jit, vmap +from jax import jit, vmap, tree_util import jax.numpy as jnp def jit_condition(*args, **kwargs): @@ -28,3 +28,15 @@ def pair_buffer_scales(p): p[0] - p[1], (p[0] - p[1] < 0, p[0] - p[1] >= 0), (lambda x: jnp.array(1), lambda x: jnp.array(0))) + + +def isinstance_jnp(*args): + + def _check(arg): + if not isinstance(arg, jnp.ndarray): + raise TypeError('all arguments must be jnp.array, \ + otherwise they won\'t be able to take derivatives \ + on these variables from outside of potential_fn anyway') + + for arg in args: + tree_util.tree_map(lambda arg: _check(arg), args[0]) diff --git a/docs/dev_guide/arch.md b/docs/dev_guide/arch.md index 6d86e88dd..e7e4811d1 100644 --- a/docs/dev_guide/arch.md +++ b/docs/dev_guide/arch.md @@ -137,6 +137,7 @@ returned to users: ```python def potential_fn(positions, box, pairs, params): + isinstance_jnp(positions, box, params) return bforce.get_energy( positions, box, pairs, params["k"], params["length"] ) @@ -145,7 +146,7 @@ self._jaxPotential = potential_fn ``` The `potential_fn` function only takes `(positions, box, pairs, params)` as explicit input arguments. All these arguments except -`pairs` (neighbor list) should be differentiable. Non differentiable parameters are passed into it by closure (see code convention section). +`pairs` (neighbor list) should be differentiable. A helper function `isinstance_jnp` in `utils.py` can check take-in arguments whether they are `jnp.array`. Non differentiable parameters are passed into it by closure (see code convention section). Meanwhile, if the generator needs to initialize multiple calculators (e.g. `NonBondedJaxGenerator` will call both `LJ` and `PME` calculators), `potential_fn` should return the summation of the results of all calculators. diff --git a/tests/test_classical.py b/tests/test_classical/test_classical.py similarity index 59% rename from tests/test_classical.py rename to tests/test_classical/test_classical.py index 81eea0c00..162900bef 100644 --- a/tests/test_classical.py +++ b/tests/test_classical/test_classical.py @@ -5,18 +5,40 @@ import numpy as np import numpy.testing as npt from dmff.api import Hamiltonian -import dmff.api as api import pytest - -from dmff.classical.inter import LennardJonesForce +from jax import jit, make_jaxpr, grad +from dmff import NeighborList class TestClassical: + + @pytest.mark.parametrize( + "pdb, prm, value", + [("tests/data/lj2.pdb", "tests/data/lj2.xml", -1.85001802444458)]) + def test_lj_force(self, pdb, prm, value): + pdb = app.PDBFile(pdb) + h = Hamiltonian(prm) + system = h.createPotential(pdb.topology, + nonbondedMethod=app.NoCutoff, + constraints=None, + removeCMMotion=False) + pos = jnp.asarray(pdb.getPositions(asNumpy=True).value_in_unit(unit.nanometer)) + box = jnp.array([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], [0.0, 0.0, 10.0]]) + nblist = NeighborList(box, 4.0) + nblist.allocate(pos) + pairs = nblist.pairs + ljE = h._potentials[0] + energy = ljE(pos, box, pairs, h.getGenerators()[0].params) + npt.assert_almost_equal(energy, value, decimal=3) + + energy = jit(ljE)(pos, box, pairs, h.getGenerators()[0].params) + npt.assert_almost_equal(energy, value, decimal=3) + @pytest.mark.parametrize( "pdb, prm, value", [ - ("data/bond1.pdb", "data/bond1.xml", 1389.162109375), - #("data/bond2.pdb", "data/bond2.xml", 100.00), + ("tests/data/bond1.pdb", "tests/data/bond1.xml", 1389.162109375), + #("tests/data/bond2.pdb", "tests/data/bond2.xml", 100.00), ]) def test_harmonic_bond_force(self, pdb, prm, value): pdb = app.PDBFile(pdb) @@ -25,22 +47,25 @@ def test_harmonic_bond_force(self, pdb, prm, value): nonbondedMethod=app.NoCutoff, constraints=None, removeCMMotion=False) - pos = pdb.getPositions(asNumpy=True).value_in_unit(unit.nanometer) + pos = jnp.asarray(pdb.getPositions(asNumpy=True).value_in_unit(unit.nanometer)) box = np.array([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], [0.0, 0.0, 10.0]]) pairs = np.array([[]], dtype=int) bondE = h._potentials[0] energy = bondE(pos, box, pairs, h.getGenerators()[0].params) npt.assert_almost_equal(energy, value, decimal=3) + + energy = jit(bondE)(pos, box, pairs, h.getGenerators()[0].params) + npt.assert_almost_equal(energy, value, decimal=3) @pytest.mark.parametrize( "pdb, prm, value", [ - ("data/proper1.pdb", "data/proper1.xml", 8.368000030517578), - ("data/proper2.pdb", "data/proper2.xml", 20.931230545), - ("data/impr1.pdb", "data/impr1.xml", 2.9460556507110596), - ("data/proper1.pdb", "data/wild1.xml", 8.368000030517578), - ("data/impr1.pdb", "data/wild2.xml", 2.9460556507110596), - #("data/tor2.pdb", "data/tor2.xml", 100.00) + ("tests/data/proper1.pdb", "tests/data/proper1.xml", 8.368000030517578), + ("tests/data/proper2.pdb", "tests/data/proper2.xml", 20.931230545), + ("tests/data/impr1.pdb", "tests/data/impr1.xml", 2.9460556507110596), + ("tests/data/proper1.pdb", "tests/data/wild1.xml", 8.368000030517578), + ("tests/data/impr1.pdb", "tests/data/wild2.xml", 2.9460556507110596), + #("tests/data/tor2.pdb", "tests/data/tor2.xml", 100.00) ]) def test_periodic_torsion_force(self, pdb, prm, value): pdb = app.PDBFile(pdb) @@ -49,37 +74,43 @@ def test_periodic_torsion_force(self, pdb, prm, value): nonbondedMethod=app.NoCutoff, constraints=None, removeCMMotion=False) - pos = pdb.getPositions(asNumpy=True).value_in_unit(unit.nanometer) + pos = jnp.asarray(pdb.getPositions(asNumpy=True).value_in_unit(unit.nanometer)) box = np.array([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], [0.0, 0.0, 10.0]]) pairs = np.array([[]], dtype=int) bondE = h._potentials[0] energy = bondE(pos, box, pairs, h.getGenerators()[0].params) npt.assert_almost_equal(energy, value, decimal=3) + energy = jit(bondE)(pos, box, pairs, h.getGenerators()[0].params) + npt.assert_almost_equal(energy, value, decimal=3) + @pytest.mark.parametrize( "pdb, prm, value", - [("data/lj2.pdb", "data/lj2.xml", -1.85001802444458)]) - def test_lj_force(self, pdb, prm, value): + [("tests/data/lj3.pdb", "tests/data/lj3.xml", -2.001220464706421)]) + def test_lj_large_force(self, pdb, prm, value): pdb = app.PDBFile(pdb) h = Hamiltonian(prm) system = h.createPotential(pdb.topology, nonbondedMethod=app.NoCutoff, constraints=None, removeCMMotion=False) - pos = pdb.getPositions(asNumpy=True).value_in_unit(unit.nanometer) + pos = jnp.asarray(pdb.getPositions(asNumpy=True).value_in_unit(unit.nanometer)) box = np.array([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], [0.0, 0.0, 10.0]]) - pairs = np.array([[0, 1], [0, 2], [0, 3], [1, 2], [1, 3], [2, 3]], - dtype=int) + pairs = [] + for ii in range(10): + for jj in range(ii + 1, 10): + pairs.append((ii, jj)) + pairs = np.array(pairs, dtype=int) ljE = h._potentials[0] energy = ljE(pos, box, pairs, h.getGenerators()[0].params) npt.assert_almost_equal(energy, value, decimal=3) - - @pytest.mark.parametrize( - "pdb, prm, value", - [("data/lj3.pdb", "data/lj3.xml", -2.001220464706421)]) - def test_lj_large_force(self, pdb, prm, value): - pdb = app.PDBFile(pdb) - h = Hamiltonian(prm) + + energy = jit(ljE)(pos, box, pairs, h.getGenerators()[0].params) + npt.assert_almost_equal(energy, value, decimal=3) + + def test_lj_params_check(self): + pdb = app.PDBFile("tests/data/lj3.pdb") + h = Hamiltonian("tests/data/lj3.xml") system = h.createPotential(pdb.topology, nonbondedMethod=app.NoCutoff, constraints=None, @@ -90,14 +121,18 @@ def test_lj_large_force(self, pdb, prm, value): for ii in range(10): for jj in range(ii + 1, 10): pairs.append((ii, jj)) - pairs = np.array(pairs, dtype=int) + pairs = np.array(pairs, dtype=int) ljE = h._potentials[0] - energy = ljE(pos, box, pairs, h.getGenerators()[0].params) - npt.assert_almost_equal(energy, value, decimal=3) - + with pytest.raises(TypeError): + energy = ljE(pos, box, pairs, h.getGenerators()[0].params) + + energy = jit(ljE)(pos, box, pairs, h.getGenerators()[0].params) # jit will optimized away type check + force = grad(jit(ljE))(pos, box, pairs, h.getGenerators()[0].params) + + @pytest.mark.parametrize( "pdb, prm, value", - [("data/lj2.pdb", "data/coul2.xml", 83.72177124023438)]) + [("tests/data/lj2.pdb", "tests/data/coul2.xml", 83.72177124023438)]) def test_coul_force(self, pdb, prm, value): pdb = app.PDBFile(pdb) h = Hamiltonian(prm) @@ -105,17 +140,20 @@ def test_coul_force(self, pdb, prm, value): nonbondedMethod=app.NoCutoff, constraints=None, removeCMMotion=False) - pos = pdb.getPositions(asNumpy=True).value_in_unit(unit.nanometer) + pos = jnp.asarray(pdb.getPositions(asNumpy=True).value_in_unit(unit.nanometer)) box = np.array([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], [0.0, 0.0, 10.0]]) pairs = np.array([[0, 1], [0, 2], [0, 3], [1, 2], [1, 3], [2, 3]], dtype=int) coulE = h._potentials[0] energy = coulE(pos, box, pairs, h.getGenerators()[0].params) npt.assert_almost_equal(energy, value, decimal=3) + + energy = jit(coulE)(pos, box, pairs, h.getGenerators()[0].params) + npt.assert_almost_equal(energy, value, decimal=3) @pytest.mark.parametrize( "pdb, prm, value", - [("data/lj3.pdb", "data/coul3.xml", -446.82037353515625)]) + [("tests/data/lj3.pdb", "tests/data/coul3.xml", -446.82037353515625)]) def test_coul_large_force(self, pdb, prm, value): pdb = app.PDBFile(pdb) h = Hamiltonian(prm) @@ -123,7 +161,7 @@ def test_coul_large_force(self, pdb, prm, value): nonbondedMethod=app.NoCutoff, constraints=None, removeCMMotion=False) - pos = pdb.getPositions(asNumpy=True).value_in_unit(unit.nanometer) + pos = jnp.asarray(pdb.getPositions(asNumpy=True).value_in_unit(unit.nanometer)) box = np.array([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], [0.0, 0.0, 10.0]]) pairs = [] for ii in range(10): @@ -133,10 +171,13 @@ def test_coul_large_force(self, pdb, prm, value): coulE = h._potentials[0] energy = coulE(pos, box, pairs, h.getGenerators()[0].params) npt.assert_almost_equal(energy, value, decimal=3) + + energy = jit(coulE)(pos, box, pairs, h.getGenerators()[0].params) + npt.assert_almost_equal(energy, value, decimal=3) @pytest.mark.parametrize( "pdb, prm, value", - [("data/lj3.pdb", "data/coul3-res.xml", -446.82037353515625)]) + [("tests/data/lj3.pdb", "tests/data/coul3-res.xml", -446.82037353515625)]) def test_coul_res_large_force(self, pdb, prm, value): pdb = app.PDBFile(pdb) h = Hamiltonian(prm) @@ -144,7 +185,7 @@ def test_coul_res_large_force(self, pdb, prm, value): nonbondedMethod=app.NoCutoff, constraints=None, removeCMMotion=False) - pos = pdb.getPositions(asNumpy=True).value_in_unit(unit.nanometer) + pos = jnp.asarray(pdb.getPositions(asNumpy=True).value_in_unit(unit.nanometer)) box = np.array([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], [0.0, 0.0, 10.0]]) pairs = [] for ii in range(10): @@ -154,46 +195,51 @@ def test_coul_res_large_force(self, pdb, prm, value): coulE = h._potentials[0] energy = coulE(pos, box, pairs, h.getGenerators()[0].params) npt.assert_almost_equal(energy, value, decimal=3) + + energy = jit(coulE)(pos, box, pairs, h.getGenerators()[0].params) + npt.assert_almost_equal(energy, value, decimal=3) @pytest.mark.parametrize( "pdb, prm, value", - [("data/lig.pdb", "data/lig-prm-single-lj.xml", 22.77804946899414)]) + [("tests/data/lig.pdb", "tests/data/lig-prm-single-lj.xml", 22.77804946899414)]) def test_gaff2_lj_force(self, pdb, prm, value): - app.Topology.loadBondDefinitions("data/lig-top.xml") + app.Topology.loadBondDefinitions("tests/data/lig-top.xml") pdb = app.PDBFile(pdb) h = Hamiltonian(prm) system = h.createPotential(pdb.topology, nonbondedMethod=app.NoCutoff, constraints=None, removeCMMotion=False) - pos = pdb.getPositions(asNumpy=True).value_in_unit(unit.nanometer) + pos = jnp.asarray(pdb.getPositions(asNumpy=True).value_in_unit(unit.nanometer)) box = np.array([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], [0.0, 0.0, 10.0]]) pairs = [] for ii in range(pos.shape[0]): for jj in range(ii + 1, pos.shape[0]): pairs.append((ii, jj)) - pairs = np.array(pairs, dtype=int) - pairs = np.array(pairs, dtype=int) + pairs = jnp.array(pairs, dtype=int) ljE = h._potentials[0] energy = ljE(pos, box, pairs, h.getGenerators()[0].params) npt.assert_almost_equal(energy, value, decimal=3) + + energy = jit(ljE)(pos, box, pairs, h.getGenerators()[0].params) + npt.assert_almost_equal(energy, value, decimal=3) @pytest.mark.parametrize("pdb, prm, values", [ - ("data/lig.pdb", ["data/gaff-2.11.xml", "data/lig-prm-lj.xml"], [ + ("tests/data/lig.pdb", ["tests/data/gaff-2.11.xml", "tests/data/lig-prm-lj.xml"], [ 174.16702270507812, 99.81585693359375, 99.0631103515625, 22.778038024902344 ]), - #("data/lig.pdb", ["data/gaff-2.11.xml", "data/lig-prm.xml"], []), + #("tests/data/lig.pdb", ["tests/data/gaff-2.11.xml", "tests/data/lig-prm.xml"], []), ]) def test_gaff2_force(self, pdb, prm, values): - app.Topology.loadBondDefinitions("data/lig-top.xml") + app.Topology.loadBondDefinitions("tests/data/lig-top.xml") pdb = app.PDBFile(pdb) h = Hamiltonian(*prm) system = h.createPotential(pdb.topology, nonbondedMethod=app.NoCutoff, constraints=None, removeCMMotion=False) - pos = pdb.getPositions(asNumpy=True).value_in_unit(unit.nanometer) + pos = jnp.asarray(pdb.getPositions(asNumpy=True).value_in_unit(unit.nanometer)) box = np.array([[20.0, 0.0, 0.0], [0.0, 20.0, 0.0], [0.0, 0.0, 20.0]]) pairs = [] for ii in range(pos.shape[0]): @@ -204,5 +250,8 @@ def test_gaff2_force(self, pdb, prm, values): generators = h.getGenerators() be_ref, ae_ref, tore_ref, lj_ref = values for ne, energy in enumerate(h._potentials): - energy = energy(pos, box, pairs, h.getGenerators()[ne].params) - npt.assert_almost_equal(energy, values[ne], decimal=3) \ No newline at end of file + E = energy(pos, box, pairs, h.getGenerators()[ne].params) + npt.assert_almost_equal(E, values[ne], decimal=3) + + E = jit(energy)(pos, box, pairs, h.getGenerators()[ne].params) + npt.assert_almost_equal(E, values[ne], decimal=3)