From 2c688f2da5eac219180a82c41796a7ca4bddb5b1 Mon Sep 17 00:00:00 2001 From: Roy Kid Date: Sat, 14 May 2022 11:03:49 +0800 Subject: [PATCH 1/4] fix: LJ bug; all classical code can be jitted; bug at test_gaff2_force --- dmff/api.py | 11 +- dmff/classical/inter.py | 28 +++-- tests/{ => test_classical}/test_classical.py | 103 ++++++++++++------- 3 files changed, 91 insertions(+), 51 deletions(-) rename tests/{ => test_classical}/test_classical.py (71%) diff --git a/dmff/api.py b/dmff/api.py index fc8788231..849bb1d88 100644 --- a/dmff/api.py +++ b/dmff/api.py @@ -8,20 +8,17 @@ from collections import defaultdict import xml.etree.ElementTree as ET from .admp.disp_pme import ADMPDispPmeForce -from .admp.multipole import convert_cart2harm, rot_local2global +from .admp.multipole import convert_cart2harm from .admp.pairwise import TT_damping_qq_c6_kernel, generate_pairwise_interaction from .admp.pairwise import slater_disp_damping_kernel, slater_sr_kernel, TT_damping_qq_kernel from .admp.pme import ADMPPmeForce -from .admp.spatial import generate_construct_local_frames -from .admp.recip import Ck_1, generate_pme_recip -from .utils import jit_condition from .classical.intra import ( HarmonicBondJaxForce, HarmonicAngleJaxForce, PeriodicTorsionJaxForce, ) from jax_md import space, partition -from jax import grad, vmap +from jax import grad import linecache import itertools from .classical.inter import ( @@ -30,7 +27,6 @@ CoulNoCutoffForce, CoulReactionFieldForce, ) - import sys @@ -1988,6 +1984,9 @@ def createForce(self, system, data, nonbondedMethod, nonbondedCutoff, args): coulenergy = coulforce.generate_get_energy() def potential_fn(positions, box, pairs, params): + + positions = jnp.array(positions) + box = jnp.array(box) ljE = ljenergy( positions, diff --git a/dmff/classical/inter.py b/dmff/classical/inter.py index 7051c3903..e44724a16 100644 --- a/dmff/classical/inter.py +++ b/dmff/classical/inter.py @@ -1,4 +1,9 @@ -from dmff.admp.pairwise import distribute_scalar +import sys + +from dmff.utils import pair_buffer_scales, regularize_pairs +sys.path.append('/home/roy/work/DMFF') + + import jax.numpy as jnp from dmff.admp.pme import energy_pme, setup_ewald_parameters from dmff.admp.recip import generate_pme_recip @@ -59,6 +64,11 @@ def get_LJ_energy(dr_vec, sig, eps, box): return E def get_energy(positions, box, pairs, epsilon, sigma, epsfix, sigfix, mscales): + + positions = jnp.array(positions) + pairs = regularize_pairs(pairs) + mask = pair_buffer_scales(pairs) + map_prm = jnp.array(self.map_prm) eps_m1 = jnp.repeat(epsilon.reshape((-1, 1)), epsilon.shape[0], axis=1) eps_m2 = eps_m1.T @@ -77,8 +87,8 @@ def get_energy(positions, box, pairs, epsilon, sigma, epsfix, sigfix, mscales): dr_vec = positions[pairs[:, 0]] - positions[pairs[:, 1]] - prm_pair0 = self.map_prm[pairs[:, 0]] - prm_pair1 = self.map_prm[pairs[:, 1]] + prm_pair0 = map_prm[pairs[:, 0]] + prm_pair1 = map_prm[pairs[:, 1]] eps = eps_mat[prm_pair0, prm_pair1] sig = sig_mat[prm_pair0, prm_pair1] @@ -86,7 +96,7 @@ def get_energy(positions, box, pairs, epsilon, sigma, epsfix, sigfix, mscales): E_inter = get_LJ_energy(dr_vec, sig, eps_scale, box) - return jnp.sum(E_inter) + return jnp.sum(E_inter * mask) return get_energy @@ -110,12 +120,16 @@ def get_coul_energy(dr_vec, chrgprod, box): return E def get_energy(positions, box, pairs, charges, mscales): + + pairs = regularize_pairs(pairs) + mask = pair_buffer_scales(pairs) + map_prm = jnp.array(self.map_prm) colv_pair = self.colvmap[pairs[:,0],pairs[:,1]] mscale_pair = mscales[colv_pair-1] - chrg_map0 = self.map_prm[pairs[:, 0]] - chrg_map1 = self.map_prm[pairs[:, 1]] + chrg_map0 = map_prm[pairs[:, 0]] + chrg_map1 = map_prm[pairs[:, 1]] charge0 = charges[chrg_map0] charge1 = charges[chrg_map1] chrgprod = charge0 * charge1 @@ -124,7 +138,7 @@ def get_energy(positions, box, pairs, charges, mscales): E_inter = get_coul_energy(dr_vec, chrgprod_scale, box) - return jnp.sum(E_inter) + return jnp.sum(E_inter * mask) return get_energy diff --git a/tests/test_classical.py b/tests/test_classical/test_classical.py similarity index 71% rename from tests/test_classical.py rename to tests/test_classical/test_classical.py index 81eea0c00..e11cd04d4 100644 --- a/tests/test_classical.py +++ b/tests/test_classical/test_classical.py @@ -5,20 +5,17 @@ import numpy as np import numpy.testing as npt from dmff.api import Hamiltonian -import dmff.api as api import pytest - -from dmff.classical.inter import LennardJonesForce +from jax import jit +from dmff import NeighborList class TestClassical: + @pytest.mark.parametrize( "pdb, prm, value", - [ - ("data/bond1.pdb", "data/bond1.xml", 1389.162109375), - #("data/bond2.pdb", "data/bond2.xml", 100.00), - ]) - def test_harmonic_bond_force(self, pdb, prm, value): + [("tests/data/lj2.pdb", "tests/data/lj2.xml", -1.85001802444458)]) + def test_lj_force(self, pdb, prm, value): pdb = app.PDBFile(pdb) h = Hamiltonian(prm) system = h.createPotential(pdb.topology, @@ -27,22 +24,23 @@ def test_harmonic_bond_force(self, pdb, prm, value): removeCMMotion=False) pos = pdb.getPositions(asNumpy=True).value_in_unit(unit.nanometer) box = np.array([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], [0.0, 0.0, 10.0]]) - pairs = np.array([[]], dtype=int) - bondE = h._potentials[0] - energy = bondE(pos, box, pairs, h.getGenerators()[0].params) + nblist = NeighborList(box, 4.0) + nblist.allocate(pos) + pairs = nblist.pairs + ljE = h._potentials[0] + energy = ljE(pos, box, pairs, h.getGenerators()[0].params) npt.assert_almost_equal(energy, value, decimal=3) - + + energy = jit(ljE)(pos, box, pairs, h.getGenerators()[0].params) + npt.assert_almost_equal(energy, value, decimal=3) + @pytest.mark.parametrize( "pdb, prm, value", [ - ("data/proper1.pdb", "data/proper1.xml", 8.368000030517578), - ("data/proper2.pdb", "data/proper2.xml", 20.931230545), - ("data/impr1.pdb", "data/impr1.xml", 2.9460556507110596), - ("data/proper1.pdb", "data/wild1.xml", 8.368000030517578), - ("data/impr1.pdb", "data/wild2.xml", 2.9460556507110596), - #("data/tor2.pdb", "data/tor2.xml", 100.00) + ("tests/data/bond1.pdb", "tests/data/bond1.xml", 1389.162109375), + #("tests/data/bond2.pdb", "tests/data/bond2.xml", 100.00), ]) - def test_periodic_torsion_force(self, pdb, prm, value): + def test_harmonic_bond_force(self, pdb, prm, value): pdb = app.PDBFile(pdb) h = Hamiltonian(prm) system = h.createPotential(pdb.topology, @@ -55,11 +53,21 @@ def test_periodic_torsion_force(self, pdb, prm, value): bondE = h._potentials[0] energy = bondE(pos, box, pairs, h.getGenerators()[0].params) npt.assert_almost_equal(energy, value, decimal=3) + + energy = jit(bondE)(pos, box, pairs, h.getGenerators()[0].params) + npt.assert_almost_equal(energy, value, decimal=3) @pytest.mark.parametrize( "pdb, prm, value", - [("data/lj2.pdb", "data/lj2.xml", -1.85001802444458)]) - def test_lj_force(self, pdb, prm, value): + [ + ("tests/data/proper1.pdb", "tests/data/proper1.xml", 8.368000030517578), + ("tests/data/proper2.pdb", "tests/data/proper2.xml", 20.931230545), + ("tests/data/impr1.pdb", "tests/data/impr1.xml", 2.9460556507110596), + ("tests/data/proper1.pdb", "tests/data/wild1.xml", 8.368000030517578), + ("tests/data/impr1.pdb", "tests/data/wild2.xml", 2.9460556507110596), + #("tests/data/tor2.pdb", "tests/data/tor2.xml", 100.00) + ]) + def test_periodic_torsion_force(self, pdb, prm, value): pdb = app.PDBFile(pdb) h = Hamiltonian(prm) system = h.createPotential(pdb.topology, @@ -68,15 +76,17 @@ def test_lj_force(self, pdb, prm, value): removeCMMotion=False) pos = pdb.getPositions(asNumpy=True).value_in_unit(unit.nanometer) box = np.array([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], [0.0, 0.0, 10.0]]) - pairs = np.array([[0, 1], [0, 2], [0, 3], [1, 2], [1, 3], [2, 3]], - dtype=int) - ljE = h._potentials[0] - energy = ljE(pos, box, pairs, h.getGenerators()[0].params) + pairs = np.array([[]], dtype=int) + bondE = h._potentials[0] + energy = bondE(pos, box, pairs, h.getGenerators()[0].params) + npt.assert_almost_equal(energy, value, decimal=3) + + energy = jit(bondE)(pos, box, pairs, h.getGenerators()[0].params) npt.assert_almost_equal(energy, value, decimal=3) @pytest.mark.parametrize( "pdb, prm, value", - [("data/lj3.pdb", "data/lj3.xml", -2.001220464706421)]) + [("tests/data/lj3.pdb", "tests/data/lj3.xml", -2.001220464706421)]) def test_lj_large_force(self, pdb, prm, value): pdb = app.PDBFile(pdb) h = Hamiltonian(prm) @@ -94,10 +104,13 @@ def test_lj_large_force(self, pdb, prm, value): ljE = h._potentials[0] energy = ljE(pos, box, pairs, h.getGenerators()[0].params) npt.assert_almost_equal(energy, value, decimal=3) + + energy = jit(ljE)(pos, box, pairs, h.getGenerators()[0].params) + npt.assert_almost_equal(energy, value, decimal=3) @pytest.mark.parametrize( "pdb, prm, value", - [("data/lj2.pdb", "data/coul2.xml", 83.72177124023438)]) + [("tests/data/lj2.pdb", "tests/data/coul2.xml", 83.72177124023438)]) def test_coul_force(self, pdb, prm, value): pdb = app.PDBFile(pdb) h = Hamiltonian(prm) @@ -112,10 +125,13 @@ def test_coul_force(self, pdb, prm, value): coulE = h._potentials[0] energy = coulE(pos, box, pairs, h.getGenerators()[0].params) npt.assert_almost_equal(energy, value, decimal=3) + + energy = jit(coulE)(pos, box, pairs, h.getGenerators()[0].params) + npt.assert_almost_equal(energy, value, decimal=3) @pytest.mark.parametrize( "pdb, prm, value", - [("data/lj3.pdb", "data/coul3.xml", -446.82037353515625)]) + [("tests/data/lj3.pdb", "tests/data/coul3.xml", -446.82037353515625)]) def test_coul_large_force(self, pdb, prm, value): pdb = app.PDBFile(pdb) h = Hamiltonian(prm) @@ -133,10 +149,13 @@ def test_coul_large_force(self, pdb, prm, value): coulE = h._potentials[0] energy = coulE(pos, box, pairs, h.getGenerators()[0].params) npt.assert_almost_equal(energy, value, decimal=3) + + energy = jit(coulE)(pos, box, pairs, h.getGenerators()[0].params) + npt.assert_almost_equal(energy, value, decimal=3) @pytest.mark.parametrize( "pdb, prm, value", - [("data/lj3.pdb", "data/coul3-res.xml", -446.82037353515625)]) + [("tests/data/lj3.pdb", "tests/data/coul3-res.xml", -446.82037353515625)]) def test_coul_res_large_force(self, pdb, prm, value): pdb = app.PDBFile(pdb) h = Hamiltonian(prm) @@ -154,12 +173,15 @@ def test_coul_res_large_force(self, pdb, prm, value): coulE = h._potentials[0] energy = coulE(pos, box, pairs, h.getGenerators()[0].params) npt.assert_almost_equal(energy, value, decimal=3) + + energy = jit(coulE)(pos, box, pairs, h.getGenerators()[0].params) + npt.assert_almost_equal(energy, value, decimal=3) @pytest.mark.parametrize( "pdb, prm, value", - [("data/lig.pdb", "data/lig-prm-single-lj.xml", 22.77804946899414)]) + [("tests/data/lig.pdb", "tests/data/lig-prm-single-lj.xml", 22.77804946899414)]) def test_gaff2_lj_force(self, pdb, prm, value): - app.Topology.loadBondDefinitions("data/lig-top.xml") + app.Topology.loadBondDefinitions("tests/data/lig-top.xml") pdb = app.PDBFile(pdb) h = Hamiltonian(prm) system = h.createPotential(pdb.topology, @@ -172,21 +194,23 @@ def test_gaff2_lj_force(self, pdb, prm, value): for ii in range(pos.shape[0]): for jj in range(ii + 1, pos.shape[0]): pairs.append((ii, jj)) - pairs = np.array(pairs, dtype=int) - pairs = np.array(pairs, dtype=int) + pairs = jnp.array(pairs, dtype=int) ljE = h._potentials[0] energy = ljE(pos, box, pairs, h.getGenerators()[0].params) npt.assert_almost_equal(energy, value, decimal=3) + + energy = jit(ljE)(pos, box, pairs, h.getGenerators()[0].params) + npt.assert_almost_equal(energy, value, decimal=3) @pytest.mark.parametrize("pdb, prm, values", [ - ("data/lig.pdb", ["data/gaff-2.11.xml", "data/lig-prm-lj.xml"], [ + ("tests/data/lig.pdb", ["tests/data/gaff-2.11.xml", "tests/data/lig-prm-lj.xml"], [ 174.16702270507812, 99.81585693359375, 99.0631103515625, 22.778038024902344 ]), - #("data/lig.pdb", ["data/gaff-2.11.xml", "data/lig-prm.xml"], []), + #("tests/data/lig.pdb", ["tests/data/gaff-2.11.xml", "tests/data/lig-prm.xml"], []), ]) def test_gaff2_force(self, pdb, prm, values): - app.Topology.loadBondDefinitions("data/lig-top.xml") + app.Topology.loadBondDefinitions("tests/data/lig-top.xml") pdb = app.PDBFile(pdb) h = Hamiltonian(*prm) system = h.createPotential(pdb.topology, @@ -204,5 +228,8 @@ def test_gaff2_force(self, pdb, prm, values): generators = h.getGenerators() be_ref, ae_ref, tore_ref, lj_ref = values for ne, energy in enumerate(h._potentials): - energy = energy(pos, box, pairs, h.getGenerators()[ne].params) - npt.assert_almost_equal(energy, values[ne], decimal=3) \ No newline at end of file + E = energy(pos, box, pairs, h.getGenerators()[ne].params) + npt.assert_almost_equal(E, values[ne], decimal=3) + + E = jit(energy)(pos, box, pairs, h.getGenerators()[0].params) + npt.assert_almost_equal(E, values[ne], decimal=3) \ No newline at end of file From ad46c6bc529b8e868b47951b17419c62e26e3a42 Mon Sep 17 00:00:00 2001 From: Roy Kid Date: Sat, 14 May 2022 11:14:51 +0800 Subject: [PATCH 2/4] fix test_gaff2_force bug --- tests/test_classical/test_classical.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_classical/test_classical.py b/tests/test_classical/test_classical.py index e11cd04d4..19cb5b724 100644 --- a/tests/test_classical/test_classical.py +++ b/tests/test_classical/test_classical.py @@ -231,5 +231,5 @@ def test_gaff2_force(self, pdb, prm, values): E = energy(pos, box, pairs, h.getGenerators()[ne].params) npt.assert_almost_equal(E, values[ne], decimal=3) - E = jit(energy)(pos, box, pairs, h.getGenerators()[0].params) - npt.assert_almost_equal(E, values[ne], decimal=3) \ No newline at end of file + E = jit(energy)(pos, box, pairs, h.getGenerators()[ne].params) + npt.assert_almost_equal(E, values[ne], decimal=3) From c2bfb1c67981f07a8b664839319646e0522acbf1 Mon Sep 17 00:00:00 2001 From: Roy Kid Date: Tue, 17 May 2022 18:31:39 +0800 Subject: [PATCH 3/4] fix: non-differentiable error, move args check in the api.py --- dmff/api.py | 19 +++++++----- dmff/classical/inter.py | 18 +++++------ dmff/utils.py | 14 ++++++++- docs/dev_guide/arch.md | 3 +- tests/test_classical/test_classical.py | 41 +++++++++++++++++++------- 5 files changed, 64 insertions(+), 31 deletions(-) diff --git a/dmff/api.py b/dmff/api.py index 849bb1d88..989b77e48 100644 --- a/dmff/api.py +++ b/dmff/api.py @@ -7,6 +7,8 @@ import jax.numpy as jnp from collections import defaultdict import xml.etree.ElementTree as ET + +from dmff.utils import isinstance_jnp from .admp.disp_pme import ADMPDispPmeForce from .admp.multipole import convert_cart2harm from .admp.pairwise import TT_damping_qq_c6_kernel, generate_pairwise_interaction @@ -207,7 +209,7 @@ def renderXML(self): class ADMPDispPmeGenerator: - ''' + r''' This one computes the undamped C6/C8/C10 interactions u = \sum_{ij} c6/r^6 + c8/r^8 + c10/r^10 ''' @@ -312,7 +314,7 @@ def renderXML(self): app.forcefield.parsers["ADMPDispPmeForce"] = ADMPDispPmeGenerator.parseElement class QqTtDampingGenerator: - ''' + r''' This one calculates the tang-tonnies damping of charge-charge interaction E = \sum_ij exp(-B*r)*(1+B*r)*q_i*q_j/r ''' @@ -388,7 +390,7 @@ def renderXML(self): class SlaterDampingGenerator: - ''' + r''' This one computes the slater-type damping function for c6/c8/c10 dispersion E = \sum_ij (f6-1)*c6/r6 + (f8-1)*c8/r8 + (f10-1)*c10/r10 fn = f_tt(x, n) @@ -470,7 +472,7 @@ def renderXML(self): class SlaterExGenerator: - ''' + r''' This one computes the Slater-ISA type exchange interaction u = \sum_ij A * (1/3*(Br)^2 + Br + 1) ''' @@ -629,7 +631,7 @@ def registerAtomType(self, atom: dict): @staticmethod def parseElement(element, hamiltonian): - """ parse admp related parameters in XML file + r""" parse admp related parameters in XML file example: @@ -1064,7 +1066,7 @@ def registerBondType(self, bond): @staticmethod def parseElement(element, hamiltonian): - """parse section in XML file + r"""parse section in XML file example: @@ -1156,7 +1158,7 @@ def registerAngleType(self, angle): @staticmethod def parseElement(element, hamiltonian): - """ parse section in XML file + r""" parse section in XML file example: @@ -1985,7 +1987,8 @@ def createForce(self, system, data, nonbondedMethod, nonbondedCutoff, args): def potential_fn(positions, box, pairs, params): - positions = jnp.array(positions) + isinstance_jnp(positions, box, params) + box = jnp.array(box) ljE = ljenergy( diff --git a/dmff/classical/inter.py b/dmff/classical/inter.py index e44724a16..849b781df 100644 --- a/dmff/classical/inter.py +++ b/dmff/classical/inter.py @@ -1,9 +1,4 @@ -import sys - from dmff.utils import pair_buffer_scales, regularize_pairs -sys.path.append('/home/roy/work/DMFF') - - import jax.numpy as jnp from dmff.admp.pme import energy_pme, setup_ewald_parameters from dmff.admp.recip import generate_pme_recip @@ -33,7 +28,7 @@ def __init__( self.r_switch = r_switch self.r_cut = r_cut - self.map_prm = map_prm + self.map_prm = jnp.array(map_prm) self.map_nbfix = map_nbfix self.ifPBC = isPBC self.ifNoCut = isNoCut @@ -65,10 +60,9 @@ def get_LJ_energy(dr_vec, sig, eps, box): def get_energy(positions, box, pairs, epsilon, sigma, epsfix, sigfix, mscales): - positions = jnp.array(positions) pairs = regularize_pairs(pairs) mask = pair_buffer_scales(pairs) - map_prm = jnp.array(self.map_prm) + map_prm = self.map_prm eps_m1 = jnp.repeat(epsilon.reshape((-1, 1)), epsilon.shape[0], axis=1) eps_m2 = eps_m1.T @@ -183,6 +177,9 @@ def get_rf_energy(dr_vec, chrgprod, box): return E def get_energy(positions, box, pairs, charges, mscales): + + pairs = regularize_pairs(pairs) + mask = pair_buffer_scales(pairs) colv_pair = self.colvmap[pairs[:,0],pairs[:,1]] mscale_pair = mscales[colv_pair-1] @@ -197,7 +194,7 @@ def get_energy(positions, box, pairs, charges, mscales): E_inter = get_rf_energy(dr_vec, chrgprod_scale, box) - return jnp.sum(E_inter) + return jnp.sum(E_inter * mask) return get_energy @@ -212,7 +209,8 @@ def __init__(self, box, rc, ethresh, covalent_map): def generate_get_energy(self): def get_energy(positions, box, pairs, Q, mScales): - + # Not required regularize_pairs + # already done in the pme code return energy_pme( positions, box, diff --git a/dmff/utils.py b/dmff/utils.py index 1d5d9ac0d..7aa4dae2e 100644 --- a/dmff/utils.py +++ b/dmff/utils.py @@ -1,5 +1,5 @@ from dmff.settings import DO_JIT -from jax import jit, vmap +from jax import jit, vmap, tree_util import jax.numpy as jnp def jit_condition(*args, **kwargs): @@ -28,3 +28,15 @@ def pair_buffer_scales(p): p[0] - p[1], (p[0] - p[1] < 0, p[0] - p[1] >= 0), (lambda x: jnp.array(1), lambda x: jnp.array(0))) + + +def isinstance_jnp(*args): + + def _check(arg): + if not isinstance(arg, jnp.ndarray): + raise TypeError('all arguments must be jnp.array, \ + otherwise they won\'t be able to take derivatives \ + on these variables from outside of potential_fn anyway') + + for arg in args: + tree_util.tree_map(lambda arg: _check(arg), args[0]) diff --git a/docs/dev_guide/arch.md b/docs/dev_guide/arch.md index 6d86e88dd..e7e4811d1 100644 --- a/docs/dev_guide/arch.md +++ b/docs/dev_guide/arch.md @@ -137,6 +137,7 @@ returned to users: ```python def potential_fn(positions, box, pairs, params): + isinstance_jnp(positions, box, params) return bforce.get_energy( positions, box, pairs, params["k"], params["length"] ) @@ -145,7 +146,7 @@ self._jaxPotential = potential_fn ``` The `potential_fn` function only takes `(positions, box, pairs, params)` as explicit input arguments. All these arguments except -`pairs` (neighbor list) should be differentiable. Non differentiable parameters are passed into it by closure (see code convention section). +`pairs` (neighbor list) should be differentiable. A helper function `isinstance_jnp` in `utils.py` can check take-in arguments whether they are `jnp.array`. Non differentiable parameters are passed into it by closure (see code convention section). Meanwhile, if the generator needs to initialize multiple calculators (e.g. `NonBondedJaxGenerator` will call both `LJ` and `PME` calculators), `potential_fn` should return the summation of the results of all calculators. diff --git a/tests/test_classical/test_classical.py b/tests/test_classical/test_classical.py index 19cb5b724..71ec48d63 100644 --- a/tests/test_classical/test_classical.py +++ b/tests/test_classical/test_classical.py @@ -22,8 +22,8 @@ def test_lj_force(self, pdb, prm, value): nonbondedMethod=app.NoCutoff, constraints=None, removeCMMotion=False) - pos = pdb.getPositions(asNumpy=True).value_in_unit(unit.nanometer) - box = np.array([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], [0.0, 0.0, 10.0]]) + pos = jnp.asarray(pdb.getPositions(asNumpy=True).value_in_unit(unit.nanometer)) + box = jnp.array([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], [0.0, 0.0, 10.0]]) nblist = NeighborList(box, 4.0) nblist.allocate(pos) pairs = nblist.pairs @@ -47,7 +47,7 @@ def test_harmonic_bond_force(self, pdb, prm, value): nonbondedMethod=app.NoCutoff, constraints=None, removeCMMotion=False) - pos = pdb.getPositions(asNumpy=True).value_in_unit(unit.nanometer) + pos = jnp.asarray(pdb.getPositions(asNumpy=True).value_in_unit(unit.nanometer)) box = np.array([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], [0.0, 0.0, 10.0]]) pairs = np.array([[]], dtype=int) bondE = h._potentials[0] @@ -74,7 +74,7 @@ def test_periodic_torsion_force(self, pdb, prm, value): nonbondedMethod=app.NoCutoff, constraints=None, removeCMMotion=False) - pos = pdb.getPositions(asNumpy=True).value_in_unit(unit.nanometer) + pos = jnp.asarray(pdb.getPositions(asNumpy=True).value_in_unit(unit.nanometer)) box = np.array([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], [0.0, 0.0, 10.0]]) pairs = np.array([[]], dtype=int) bondE = h._potentials[0] @@ -94,7 +94,7 @@ def test_lj_large_force(self, pdb, prm, value): nonbondedMethod=app.NoCutoff, constraints=None, removeCMMotion=False) - pos = pdb.getPositions(asNumpy=True).value_in_unit(unit.nanometer) + pos = jnp.asarray(pdb.getPositions(asNumpy=True).value_in_unit(unit.nanometer)) box = np.array([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], [0.0, 0.0, 10.0]]) pairs = [] for ii in range(10): @@ -107,7 +107,26 @@ def test_lj_large_force(self, pdb, prm, value): energy = jit(ljE)(pos, box, pairs, h.getGenerators()[0].params) npt.assert_almost_equal(energy, value, decimal=3) - + + def test_lj_force(self): + pdb = app.PDBFile("tests/data/lj3.pdb") + h = Hamiltonian("tests/data/lj3.xml") + system = h.createPotential(pdb.topology, + nonbondedMethod=app.NoCutoff, + constraints=None, + removeCMMotion=False) + pos = pdb.getPositions(asNumpy=True).value_in_unit(unit.nanometer) + box = np.array([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], [0.0, 0.0, 10.0]]) + pairs = [] + for ii in range(10): + for jj in range(ii + 1, 10): + pairs.append((ii, jj)) + pairs = np.array(pairs, dtype=int) + ljE = h._potentials[0] + with pytest.raises(TypeError): + energy = ljE(pos, box, pairs, h.getGenerators()[0].params) + + @pytest.mark.parametrize( "pdb, prm, value", [("tests/data/lj2.pdb", "tests/data/coul2.xml", 83.72177124023438)]) @@ -118,7 +137,7 @@ def test_coul_force(self, pdb, prm, value): nonbondedMethod=app.NoCutoff, constraints=None, removeCMMotion=False) - pos = pdb.getPositions(asNumpy=True).value_in_unit(unit.nanometer) + pos = jnp.asarray(pdb.getPositions(asNumpy=True).value_in_unit(unit.nanometer)) box = np.array([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], [0.0, 0.0, 10.0]]) pairs = np.array([[0, 1], [0, 2], [0, 3], [1, 2], [1, 3], [2, 3]], dtype=int) @@ -139,7 +158,7 @@ def test_coul_large_force(self, pdb, prm, value): nonbondedMethod=app.NoCutoff, constraints=None, removeCMMotion=False) - pos = pdb.getPositions(asNumpy=True).value_in_unit(unit.nanometer) + pos = jnp.asarray(pdb.getPositions(asNumpy=True).value_in_unit(unit.nanometer)) box = np.array([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], [0.0, 0.0, 10.0]]) pairs = [] for ii in range(10): @@ -163,7 +182,7 @@ def test_coul_res_large_force(self, pdb, prm, value): nonbondedMethod=app.NoCutoff, constraints=None, removeCMMotion=False) - pos = pdb.getPositions(asNumpy=True).value_in_unit(unit.nanometer) + pos = jnp.asarray(pdb.getPositions(asNumpy=True).value_in_unit(unit.nanometer)) box = np.array([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], [0.0, 0.0, 10.0]]) pairs = [] for ii in range(10): @@ -188,7 +207,7 @@ def test_gaff2_lj_force(self, pdb, prm, value): nonbondedMethod=app.NoCutoff, constraints=None, removeCMMotion=False) - pos = pdb.getPositions(asNumpy=True).value_in_unit(unit.nanometer) + pos = jnp.asarray(pdb.getPositions(asNumpy=True).value_in_unit(unit.nanometer)) box = np.array([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], [0.0, 0.0, 10.0]]) pairs = [] for ii in range(pos.shape[0]): @@ -217,7 +236,7 @@ def test_gaff2_force(self, pdb, prm, values): nonbondedMethod=app.NoCutoff, constraints=None, removeCMMotion=False) - pos = pdb.getPositions(asNumpy=True).value_in_unit(unit.nanometer) + pos = jnp.asarray(pdb.getPositions(asNumpy=True).value_in_unit(unit.nanometer)) box = np.array([[20.0, 0.0, 0.0], [0.0, 20.0, 0.0], [0.0, 0.0, 20.0]]) pairs = [] for ii in range(pos.shape[0]): From 9afbba2f8ef7ca33cc86678efbd4a12037d2a4d5 Mon Sep 17 00:00:00 2001 From: Roy Kid Date: Wed, 18 May 2022 15:13:40 +0800 Subject: [PATCH 4/4] fix: remove redundancy `box=jnp.array(box)`; confirm isinstance_jnp is jit-compatiable --- dmff/api.py | 5 +++-- tests/test_classical/test_classical.py | 7 +++++-- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/dmff/api.py b/dmff/api.py index 989b77e48..d8c7c24a5 100644 --- a/dmff/api.py +++ b/dmff/api.py @@ -1987,10 +1987,11 @@ def createForce(self, system, data, nonbondedMethod, nonbondedCutoff, args): def potential_fn(positions, box, pairs, params): + # check whether args passed into potential_fn are jnp.array and differentiable + # note this check will be optimized away by jit + # it is jit-compatiable isinstance_jnp(positions, box, params) - box = jnp.array(box) - ljE = ljenergy( positions, box, diff --git a/tests/test_classical/test_classical.py b/tests/test_classical/test_classical.py index 71ec48d63..162900bef 100644 --- a/tests/test_classical/test_classical.py +++ b/tests/test_classical/test_classical.py @@ -6,7 +6,7 @@ import numpy.testing as npt from dmff.api import Hamiltonian import pytest -from jax import jit +from jax import jit, make_jaxpr, grad from dmff import NeighborList @@ -108,7 +108,7 @@ def test_lj_large_force(self, pdb, prm, value): energy = jit(ljE)(pos, box, pairs, h.getGenerators()[0].params) npt.assert_almost_equal(energy, value, decimal=3) - def test_lj_force(self): + def test_lj_params_check(self): pdb = app.PDBFile("tests/data/lj3.pdb") h = Hamiltonian("tests/data/lj3.xml") system = h.createPotential(pdb.topology, @@ -125,6 +125,9 @@ def test_lj_force(self): ljE = h._potentials[0] with pytest.raises(TypeError): energy = ljE(pos, box, pairs, h.getGenerators()[0].params) + + energy = jit(ljE)(pos, box, pairs, h.getGenerators()[0].params) # jit will optimized away type check + force = grad(jit(ljE))(pos, box, pairs, h.getGenerators()[0].params) @pytest.mark.parametrize(