From 1ee405137c936cddc7da1a42e5caa2feae4491ee Mon Sep 17 00:00:00 2001 From: WangXinyan940 Date: Tue, 29 Nov 2022 23:34:53 +0800 Subject: [PATCH 1/9] Update python version requirement to 3.8 --- setup.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/setup.py b/setup.py index 3461b0e97..a0d782e4f 100644 --- a/setup.py +++ b/setup.py @@ -39,12 +39,12 @@ def setup(scm=None): long_description=readme, long_description_content_type="text/markdown", url="https://github.com/deepmodeling/DMFF", - python_requires="~=3.6", + python_requires="~=3.8", packages=packages, data_files=[], package_data={}, classifiers=[ - "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.8", "License :: OSI Approved :: GNU Lesser General Public License v3 (LGPLv3)", ], keywords='DMFF', From 88c1626b7689d16e55b1334589d99ad115c22a1d Mon Sep 17 00:00:00 2001 From: WangXinyan940 Date: Tue, 29 Nov 2022 23:38:18 +0800 Subject: [PATCH 2/9] Save meta data in Potential object --- dmff/api.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/dmff/api.py b/dmff/api.py index ddc0533fc..d1f20324c 100644 --- a/dmff/api.py +++ b/dmff/api.py @@ -53,9 +53,13 @@ class Potential: def __init__(self): self.dmff_potentials = {} self.omm_system = None + self.meta = {} - def addDmffPotential(self, name, potential): + def addDmffPotential(self, name, potential, meta={}): self.dmff_potentials[name] = potential + if len(meta): + for key in meta.keys(): + self.meta[key] = meta[key] def addOmmSystem(self, system): self.omm_system = system From 88fa833baea73e0d1d0a27578cb78b640a96c0c2 Mon Sep 17 00:00:00 2001 From: WangXinyan940 Date: Tue, 29 Nov 2022 23:59:19 +0800 Subject: [PATCH 3/9] Add support to save meta data in Hamiltonian --- dmff/api.py | 4 +-- dmff/generators/admp.py | 21 ++++++++----- dmff/generators/classical.py | 59 ++++++++++++++++++++++++++++++------ 3 files changed, 66 insertions(+), 18 deletions(-) diff --git a/dmff/api.py b/dmff/api.py index d1f20324c..eb53ad9f1 100644 --- a/dmff/api.py +++ b/dmff/api.py @@ -186,8 +186,8 @@ def createPotential(self, if len(jaxForces) > 0 and generator.name not in jaxForces: continue try: - potentialImpl = generator.getJaxPotential() - potObj.addDmffPotential(generator.name, potentialImpl) + potentialImpl, meta = generator.getJaxPotential() + potObj.addDmffPotential(generator.name, potentialImpl, meta=meta) except Exception as e: print(e) pass diff --git a/dmff/generators/admp.py b/dmff/generators/admp.py index c388cde7e..5dad78065 100644 --- a/dmff/generators/admp.py +++ b/dmff/generators/admp.py @@ -33,6 +33,7 @@ def __init__(self, ff): self.types = [] self.ethresh = 5e-4 self.pmax = 10 + self._meta = {} def extract(self): @@ -159,7 +160,7 @@ def overwrite(self): [self.paramtree[self.name]['C10']]) def getJaxPotential(self): - return self._jaxPotential + return self._jaxPotential, self._meta dmff.api.jaxGenerators['ADMPDispForce'] = ADMPDispGenerator @@ -181,6 +182,7 @@ def __init__(self, ff): self.ethresh = 5e-4 self.pmax = 10 self.name = "ADMPDispPmeForce" + self._meta = {} def extract(self): @@ -284,7 +286,7 @@ def potential_fn(positions, box, pairs, params): # self._top_data = data def getJaxPotential(self): - return self._jaxPotential + return self._jaxPotential, self._meta dmff.api.jaxGenerators['ADMPDispPmeForce'] = ADMPDispPmeGenerator @@ -302,6 +304,7 @@ def __init__(self, ff): self.paramtree = ff.paramtree self._jaxPotnetial = None self.name = "QqTtDampingForce" + self._meta = {} def extract(self): # get mscales @@ -372,7 +375,7 @@ def potential_fn(positions, box, pairs, params): self._jaxPotential = potential_fn def getJaxPotential(self): - return self._jaxPotential + return self._jaxPotential, self._meta # register all parsers @@ -392,6 +395,7 @@ def __init__(self, ff): self.fftree = ff.fftree self.paramtree = ff.paramtree self._jaxPotential = None + self._meta = {} def extract(self): # get mscales @@ -473,7 +477,7 @@ def potential_fn(positions, box, pairs, params): # self._top_data = data def getJaxPotential(self): - return self._jaxPotential + return self._jaxPotential, self._meta dmff.api.jaxGenerators['SlaterDampingForce'] = SlaterDampingGenerator @@ -490,6 +494,7 @@ def __init__(self, ff): self.fftree = ff.fftree self.paramtree = ff.paramtree self._jaxPotential = None + self._meta = {} def extract(self): # get mscales @@ -558,7 +563,7 @@ def potential_fn(positions, box, pairs, params): # self._top_data = data def getJaxPotential(self): - return self._jaxPotential + return self._jaxPotential, self._meta dmff.api.jaxGenerators["SlaterExForce"] = SlaterExGenerator @@ -613,6 +618,8 @@ def __init__(self, ff): self.lpol = False self.ref_dip = "" + self._meta = {} + def extract(self): self.lmax = self.fftree.get_attribs(f'{self.name}', @@ -850,7 +857,7 @@ def createForce(self, system, data, nonbondedMethod, nonbondedCutoff, # build covalent map self.covalent_map = covalent_map = build_covalent_map(data, 6) - + self._meta["cov_map"] = self.covalent_map # build intra-molecule axis # the following code is the direct transplant of forcefield.py in openmm 7.4.0 @@ -1096,7 +1103,7 @@ def potential_fn(positions, box, pairs, params): self._jaxPotential = potential_fn def getJaxPotential(self): - return self._jaxPotential + return self._jaxPotential, self._meta dmff.api.jaxGenerators["ADMPPmeForce"] = ADMPPmeGenerator \ No newline at end of file diff --git a/dmff/generators/classical.py b/dmff/generators/classical.py index 2ae0c3eb9..96502038a 100644 --- a/dmff/generators/classical.py +++ b/dmff/generators/classical.py @@ -40,6 +40,7 @@ def __init__(self, ff: Hamiltonian): self.ff: Hamiltonian = ff self.fftree: ForcefieldTree = ff.fftree self.paramtree: Dict = ff.paramtree + self._meta = {} def extract(self): """ @@ -74,6 +75,7 @@ def createForce(self, sys, data, nonbondedMethod, nonbondedCutoff, args): Args: Those args are the same as those in createSystem. """ + self._meta = {} # initialize typemap matcher = TypeMatcher(self.fftree, "HarmonicBondForce/Bond") @@ -114,7 +116,10 @@ def createForce(self, sys, data, nonbondedMethod, nonbondedCutoff, args): map_atom1 = np.array(map_atom1, dtype=int) map_atom2 = np.array(map_atom2, dtype=int) - map_param = np.array(map_param, dtype=int) + map_param = np.array(map_param, dtype=int) + self._meta["HarmonicBondForce_atom1"] = map_atom1 + self._meta["HarmonicBondForce_atom2"] = map_atom2 + self._meta["HarmonicBondForce_param"] = map_param bforce = HarmonicBondJaxForce(map_atom1, map_atom2, map_param) self._force_latest = bforce @@ -128,7 +133,7 @@ def potential_fn(positions, box, pairs, params): # self._top_data = data def getJaxPotential(self): - return self._jaxPotential + return self._jaxPotential, self._meta dmff.api.jaxGenerators["HarmonicBondForce"] = HarmonicBondJaxGenerator @@ -140,6 +145,7 @@ def __init__(self, ff): self.ff = ff self.fftree = ff.fftree self.paramtree = ff.paramtree + self._meta = {} def extract(self): angles = self.fftree.get_attribs(f"{self.name}/Angle", "angle") @@ -155,6 +161,8 @@ def overwrite(self): self.paramtree[self.name]["k"]) def createForce(self, sys, data, nonbondedMethod, nonbondedCutoff, args): + self._meta = {} + matcher = TypeMatcher(self.fftree, "HarmonicAngleForce/Angle") map_atom1, map_atom2, map_atom3, map_param = [], [], [], [] @@ -202,6 +210,10 @@ def createForce(self, sys, data, nonbondedMethod, nonbondedCutoff, args): map_atom2 = np.array(map_atom2, dtype=int) map_atom3 = np.array(map_atom3, dtype=int) map_param = np.array(map_param, dtype=int) + self._meta["HarmonicAngleForce_atom1"] = map_atom1 + self._meta["HarmonicAngleForce_atom2"] = map_atom2 + self._meta["HarmonicAngleForce_atom3"] = map_atom3 + self._meta["HarmonicAngleForce_param"] = map_param aforce = HarmonicAngleJaxForce(map_atom1, map_atom2, map_atom3, map_param) @@ -216,7 +228,7 @@ def potential_fn(positions, box, pairs, params): # self._top_data = data def getJaxPotential(self): - return self._jaxPotential + return self._jaxPotential, self._meta dmff.api.jaxGenerators["HarmonicAngleForce"] = HarmonicAngleJaxGenerator @@ -229,7 +241,7 @@ def __init__(self, ff): self.fftree = ff.fftree self.paramtree = ff.paramtree self.meta = {} - + self._meta self.meta["prop_order"] = defaultdict(list) self.meta["prop_nodeidx"] = defaultdict(list) @@ -340,6 +352,17 @@ def createForce(self, sys, data, nonbondedMethod, nonbondedCutoff, args): """ Create force for torsions """ + self.meta = {} + self._meta + self.meta["prop_order"] = defaultdict(list) + self.meta["prop_nodeidx"] = defaultdict(list) + + self.meta["impr_order"] = defaultdict(list) + self.meta["impr_nodeidx"] = defaultdict(list) + + self.max_pred_prop = 0 + self.max_pred_impr = 0 + # Proper Torsions proper_matcher = TypeMatcher(self.fftree, "PeriodicTorsionForce/Proper") @@ -487,6 +510,19 @@ def createForce(self, sys, data, nonbondedMethod, nonbondedCutoff, args): self._props_latest = props self._imprs_latest = imprs + self._meta["PeriodicTorsionForce_prop_atom1"] = map_prop_atom1 + self._meta["PeriodicTorsionForce_prop_atom2"] = map_prop_atom2 + self._meta["PeriodicTorsionForce_prop_atom3"] = map_prop_atom3 + self._meta["PeriodicTorsionForce_prop_atom4"] = map_prop_atom4 + self._meta["PeriodicTorsionForce_prop_param"] = map_prop_param + + self._meta["PeriodicTorsionForce_impr_atom1"] = map_impr_atom1 + self._meta["PeriodicTorsionForce_impr_atom2"] = map_impr_atom2 + self._meta["PeriodicTorsionForce_impr_atom3"] = map_impr_atom3 + self._meta["PeriodicTorsionForce_impr_atom4"] = map_impr_atom4 + self._meta["PeriodicTorsionForce_impr_param"] = map_impr_param + + def potential_fn(positions, box, pairs, params): prop_sum = sum([ props[i].get_energy( @@ -508,7 +544,7 @@ def potential_fn(positions, box, pairs, params): self._jaxPotential = potential_fn def getJaxPotential(self): - return self._jaxPotential + return self._jaxPotential, self._meta dmff.api.jaxGenerators["PeriodicTorsionForce"] = PeriodicTorsionJaxGenerator @@ -532,6 +568,8 @@ def __init__(self, ff: Hamiltonian): self.useBCC = False self.useVsite = False + self._meta = {} + def extract(self): self.from_residue = self.fftree.get_attribs( "NonbondedForce/UseAttributeFromResidue", "name") @@ -684,6 +722,8 @@ def addVsiteFunc(pos, params): cov_map[ori_dim + i, parent_i] = 1 self.covalent_map = jnp.array(cov_map) + self._meta["cov_map"] = self.covalent_map + # Load Lennard-Jones parameters maps = {} if not nbmatcher.useSmirks: @@ -975,7 +1015,7 @@ def potential_fn(positions, box, pairs, params, vdwLambda, self._jaxPotential = potential_fn def getJaxPotential(self): - return self._jaxPotential + return self._jaxPotential, self._meta def getAddVsiteFunc(self): """ @@ -1008,8 +1048,8 @@ def __init__(self, ff): self.fftree = ff.fftree self.paramtree = ff.paramtree self.paramtree[self.name] = {} - self.paramtree[self.name] - self.paramtree[self.name] + self._meta + def extract(self): for prm in ["sigma", "epsilon"]: @@ -1109,6 +1149,7 @@ def findIdx(labels, label): map_nbfix = jnp.array(map_nbfix) colv_map = build_covalent_map(data, 6) + self._meta["cov_map"] = colv_map if unit.is_quantity(nonbondedCutoff): r_cut = nonbondedCutoff.value_in_unit(unit.nanometer) @@ -1188,7 +1229,7 @@ def potential_fn(positions, box, pairs, params): self._jaxPotential = potential_fn def getJaxPotential(self): - return self._jaxPotential + return self._jaxPotential, self._meta dmff.api.jaxGenerators["LennardJonesForce"] = LennardJonesGenerator \ No newline at end of file From 8b1df2974be331f1432bb2f2289e8b462d7e63fb Mon Sep 17 00:00:00 2001 From: Wang Xinyan Date: Wed, 30 Nov 2022 12:53:05 +0800 Subject: [PATCH 4/9] Create an attribute to save meta data --- dmff/generators/classical.py | 2 +- requirements.txt | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/dmff/generators/classical.py b/dmff/generators/classical.py index 96502038a..3c070b02e 100644 --- a/dmff/generators/classical.py +++ b/dmff/generators/classical.py @@ -241,7 +241,7 @@ def __init__(self, ff): self.fftree = ff.fftree self.paramtree = ff.paramtree self.meta = {} - self._meta + self._meta = {} self.meta["prop_order"] = defaultdict(list) self.meta["prop_nodeidx"] = defaultdict(list) diff --git a/requirements.txt b/requirements.txt index 9dd3170c3..9d9f862b5 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,4 +9,6 @@ mkdocstrings-python>=0.7.0 pygments>=2.12 jax>=0.3.7,<0.3.16 jaxlib>=0.3.7,<0.3.16 -pymbar==4.0.1 \ No newline at end of file +pymbar==4.0.1 +tqdm==4.64.1 +rdkit==2022.03.2 From 7d079ac683f687b8cab3040e9acc0f18d7f67c79 Mon Sep 17 00:00:00 2001 From: Wang Xinyan Date: Wed, 30 Nov 2022 13:48:24 +0800 Subject: [PATCH 5/9] Use potential.meta["cov_map"] in unittests --- dmff/api.py | 3 ++- dmff/generators/admp.py | 30 ++++++++++++++++++++----- dmff/generators/classical.py | 35 ++++++++++++++++------------- tests/test_classical/test_coul.py | 10 +++------ tests/test_classical/test_fep.py | 6 ++--- tests/test_classical/test_gaff2.py | 9 +++----- tests/test_classical/test_lj.py | 9 +++----- tests/test_classical/test_smirks.py | 4 ++-- tests/test_mbar/test_mbar.py | 12 ++-------- 9 files changed, 61 insertions(+), 57 deletions(-) diff --git a/dmff/api.py b/dmff/api.py index eb53ad9f1..00a14af2f 100644 --- a/dmff/api.py +++ b/dmff/api.py @@ -186,7 +186,8 @@ def createPotential(self, if len(jaxForces) > 0 and generator.name not in jaxForces: continue try: - potentialImpl, meta = generator.getJaxPotential() + potentialImpl = generator.getJaxPotential() + meta = generator.getMetaData() potObj.addDmffPotential(generator.name, potentialImpl, meta=meta) except Exception as e: print(e) diff --git a/dmff/generators/admp.py b/dmff/generators/admp.py index 5dad78065..dd0ebe9ec 100644 --- a/dmff/generators/admp.py +++ b/dmff/generators/admp.py @@ -160,7 +160,10 @@ def overwrite(self): [self.paramtree[self.name]['C10']]) def getJaxPotential(self): - return self._jaxPotential, self._meta + return self._jaxPotential + + def getMetaData(self): + return self._meta dmff.api.jaxGenerators['ADMPDispForce'] = ADMPDispGenerator @@ -286,7 +289,10 @@ def potential_fn(positions, box, pairs, params): # self._top_data = data def getJaxPotential(self): - return self._jaxPotential, self._meta + return self._jaxPotential + + def getMetaData(self): + return self._meta dmff.api.jaxGenerators['ADMPDispPmeForce'] = ADMPDispPmeGenerator @@ -375,7 +381,10 @@ def potential_fn(positions, box, pairs, params): self._jaxPotential = potential_fn def getJaxPotential(self): - return self._jaxPotential, self._meta + return self._jaxPotential + + def getMetaData(self): + return self._meta # register all parsers @@ -477,7 +486,10 @@ def potential_fn(positions, box, pairs, params): # self._top_data = data def getJaxPotential(self): - return self._jaxPotential, self._meta + return self._jaxPotential + + def getMetaData(self): + return self._meta dmff.api.jaxGenerators['SlaterDampingForce'] = SlaterDampingGenerator @@ -563,7 +575,10 @@ def potential_fn(positions, box, pairs, params): # self._top_data = data def getJaxPotential(self): - return self._jaxPotential, self._meta + return self._jaxPotential + + def getMetaData(self): + return self._meta dmff.api.jaxGenerators["SlaterExForce"] = SlaterExGenerator @@ -1103,7 +1118,10 @@ def potential_fn(positions, box, pairs, params): self._jaxPotential = potential_fn def getJaxPotential(self): - return self._jaxPotential, self._meta + return self._jaxPotential + + def getMetaData(self): + return self._meta dmff.api.jaxGenerators["ADMPPmeForce"] = ADMPPmeGenerator \ No newline at end of file diff --git a/dmff/generators/classical.py b/dmff/generators/classical.py index 3c070b02e..fbe8f1d96 100644 --- a/dmff/generators/classical.py +++ b/dmff/generators/classical.py @@ -133,7 +133,10 @@ def potential_fn(positions, box, pairs, params): # self._top_data = data def getJaxPotential(self): - return self._jaxPotential, self._meta + return self._jaxPotential + + def getMetaData(self): + return self._meta dmff.api.jaxGenerators["HarmonicBondForce"] = HarmonicBondJaxGenerator @@ -228,7 +231,10 @@ def potential_fn(positions, box, pairs, params): # self._top_data = data def getJaxPotential(self): - return self._jaxPotential, self._meta + return self._jaxPotential + + def getMetaData(self): + return self._meta dmff.api.jaxGenerators["HarmonicAngleForce"] = HarmonicAngleJaxGenerator @@ -352,16 +358,6 @@ def createForce(self, sys, data, nonbondedMethod, nonbondedCutoff, args): """ Create force for torsions """ - self.meta = {} - self._meta - self.meta["prop_order"] = defaultdict(list) - self.meta["prop_nodeidx"] = defaultdict(list) - - self.meta["impr_order"] = defaultdict(list) - self.meta["impr_nodeidx"] = defaultdict(list) - - self.max_pred_prop = 0 - self.max_pred_impr = 0 # Proper Torsions proper_matcher = TypeMatcher(self.fftree, @@ -544,7 +540,10 @@ def potential_fn(positions, box, pairs, params): self._jaxPotential = potential_fn def getJaxPotential(self): - return self._jaxPotential, self._meta + return self._jaxPotential + + def getMetaData(self): + return self._meta dmff.api.jaxGenerators["PeriodicTorsionForce"] = PeriodicTorsionJaxGenerator @@ -1015,7 +1014,10 @@ def potential_fn(positions, box, pairs, params, vdwLambda, self._jaxPotential = potential_fn def getJaxPotential(self): - return self._jaxPotential, self._meta + return self._jaxPotential + + def getMetaData(self): + return self._meta def getAddVsiteFunc(self): """ @@ -1229,7 +1231,10 @@ def potential_fn(positions, box, pairs, params): self._jaxPotential = potential_fn def getJaxPotential(self): - return self._jaxPotential, self._meta + return self._jaxPotential + + def getMetaData(self): + return self._meta dmff.api.jaxGenerators["LennardJonesForce"] = LennardJonesGenerator \ No newline at end of file diff --git a/tests/test_classical/test_coul.py b/tests/test_classical/test_coul.py index afbce7d01..32adad2f1 100644 --- a/tests/test_classical/test_coul.py +++ b/tests/test_classical/test_coul.py @@ -23,8 +23,7 @@ def test_coul_force(self, pdb, prm, value): pos = jnp.asarray(pdb.getPositions(asNumpy=True).value_in_unit(unit.nanometer)) box = np.array([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], [0.0, 0.0, 10.0]]) rc = 4 - gen = h.getGenerators()[-1] - nblist = NeighborList(box, rc, gen.covalent_map) + nblist = NeighborList(box, rc, potential.meta["cov_map"]) nblist.allocate(pos) pairs = nblist.pairs coulE = potential.getPotentialFunc(names="NonbondedForce") @@ -71,8 +70,7 @@ def test_coul_res_large_force(self, pdb, prm, value): pos = jnp.asarray(pdb.getPositions(asNumpy=True).value_in_unit(unit.nanometer)) box = np.array([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], [0.0, 0.0, 10.0]]) rc = 4 - gen = h.getGenerators()[-1] - nblist = NeighborList(box, rc, gen.covalent_map) + nblist = NeighborList(box, rc, potential.meta["cov_map"]) nblist.allocate(pos) pairs = nblist.pairs coulE = potential.getPotentialFunc() @@ -116,9 +114,7 @@ def test_coul_pme(self, pdb, prm, value): [ 0.00, 0.00, 1.20] ]) - gen = h.getGenerators()[-1] - - nbList = NeighborList(box, rcut, gen.covalent_map) + nbList = NeighborList(box, rcut, potential.meta["cov_map"]) nbList.allocate(positions) pairs = nbList.pairs func = potential.getPotentialFunc(names=["NonbondedForce"]) diff --git a/tests/test_classical/test_fep.py b/tests/test_classical/test_fep.py index e9cbf2ab9..dad7b1563 100644 --- a/tests/test_classical/test_fep.py +++ b/tests/test_classical/test_fep.py @@ -47,8 +47,7 @@ def test_coul(self, pdb, prm, lambdas, energies, dvdls): [ 0.00, 1.20, 0.00], [ 0.00, 0.00, 1.20] ]) - gen = h.getGenerators()[-1] - nbList = NeighborList(box, rcut, gen.covalent_map) + nbList = NeighborList(box, rcut, potential.meta["cov_map"]) nbList.allocate(positions) pairs = nbList.pairs func = jax.value_and_grad(potential.dmff_potentials["NonbondedForce"], argnums=-1) @@ -103,8 +102,7 @@ def test_vdw(self, pdb, prm, lambdas, energies, dvdls): [ 0.00, 1.20, 0.00], [ 0.00, 0.00, 1.20] ]) - gen = h.getGenerators()[-1] - nbList = NeighborList(box, rcut, gen.covalent_map) + nbList = NeighborList(box, rcut, potential.meta["cov_map"]) nbList.allocate(positions) pairs = nbList.pairs func = jax.value_and_grad(potential.dmff_potentials["NonbondedForce"], argnums=-2) diff --git a/tests/test_classical/test_gaff2.py b/tests/test_classical/test_gaff2.py index f8b8aca90..9b1a7fc65 100644 --- a/tests/test_classical/test_gaff2.py +++ b/tests/test_classical/test_gaff2.py @@ -24,8 +24,7 @@ def test_gaff2_lj_force(self, pdb, prm, value): pos = jnp.asarray(pdb.getPositions(asNumpy=True).value_in_unit(unit.nanometer)) box = np.array([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], [0.0, 0.0, 10.0]]) rc = 4 - gen = h.getGenerators()[-1] - nblist = NeighborList(box, rc, gen.covalent_map) + nblist = NeighborList(box, rc, potential.meta["cov_map"]) nblist.allocate(pos) pairs = nblist.pairs ljE = potential.getPotentialFunc() @@ -61,8 +60,7 @@ def test_gaff2_force(self, pdb, prm, values): pos = jnp.asarray(pdb.getPositions(asNumpy=True).value_in_unit(unit.nanometer)) box = np.array([[20.0, 0.0, 0.0], [0.0, 20.0, 0.0], [0.0, 0.0, 20.0]]) rc = 4 - gen = h.getGenerators()[-1] - nblist = NeighborList(box, rc, gen.covalent_map) + nblist = NeighborList(box, rc, potential.meta["cov_map"]) nblist.allocate(pos) pairs = nblist.pairs for ne, energy in enumerate(potential.dmff_potentials.values()): @@ -98,8 +96,7 @@ def test_gaff2_total(self, pdb, prm, values): pos = jnp.asarray(pdb.getPositions(asNumpy=True).value_in_unit(unit.nanometer)) box = np.array([[20.0, 0.0, 0.0], [0.0, 20.0, 0.0], [0.0, 0.0, 20.0]]) rc = 4 - gen = h.getGenerators()[-1] - nblist = NeighborList(box, rc, gen.covalent_map) + nblist = NeighborList(box, rc, potential.meta["cov_map"]) nblist.allocate(pos) pairs = nblist.pairs efunc = potential.getPotentialFunc() diff --git a/tests/test_classical/test_lj.py b/tests/test_classical/test_lj.py index fc6b5766f..711b4dd6a 100644 --- a/tests/test_classical/test_lj.py +++ b/tests/test_classical/test_lj.py @@ -21,10 +21,9 @@ def test_lj_force(self, pdb, prm, value): nonbondedMethod=app.NoCutoff, constraints=None, removeCMMotion=False) - gen = h.getGenerators()[0] pos = jnp.asarray(pdb.getPositions(asNumpy=True).value_in_unit(unit.nanometer)) box = jnp.array([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], [0.0, 0.0, 10.0]]) - nblist = NeighborList(box, 4.0, gen.covalent_map) + nblist = NeighborList(box, 4.0, potential.meta["cov_map"]) nblist.allocate(pos) pairs = nblist.pairs ljE = potential.getPotentialFunc() @@ -47,8 +46,7 @@ def test_lj_large_force(self, pdb, prm, value): removeCMMotion=False) pos = jnp.asarray(pdb.getPositions(asNumpy=True).value_in_unit(unit.nanometer)) box = np.array([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], [0.0, 0.0, 10.0]]) - gen = h.getGenerators()[0] - nblist = NeighborList(box, 4.0, gen.covalent_map) + nblist = NeighborList(box, 4.0, potential.meta["cov_map"]) nblist.allocate(pos) pairs = nblist.pairs ljE = potential.getPotentialFunc() @@ -67,8 +65,7 @@ def test_lj_params_check(self): removeCMMotion=False) pos = pdb.getPositions(asNumpy=True).value_in_unit(unit.nanometer) box = np.array([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], [0.0, 0.0, 10.0]]) - gen = h.getGenerators()[0] - nblist = NeighborList(box, 4.0, gen.covalent_map) + nblist = NeighborList(box, 4.0, potential.meta["cov_map"]) nblist.allocate(pos) pairs = nblist.pairs ljE = potential.getPotentialFunc() diff --git a/tests/test_classical/test_smirks.py b/tests/test_classical/test_smirks.py index 3f200b71a..5e23905e6 100644 --- a/tests/test_classical/test_smirks.py +++ b/tests/test_classical/test_smirks.py @@ -56,7 +56,7 @@ def test_vsite(name: str): pos_vsite = jnp.array(newmol.GetConformer().GetPositions()) / 10 box = jnp.eye(3, dtype=jnp.float32) - nblist = NeighborList(box, 1.0, h_smirks.getCovalentMap()) + nblist = NeighborList(box, 1.0, pot_vsite.meta["cov_map"]) nblist.allocate(pos_vsite) pairs_vsite = nblist.pairs @@ -80,7 +80,7 @@ def test_vsite(name: str): pot_typing = h_typing.createPotential(top) pos = jnp.array(rdmol.GetConformer().GetPositions()) / 10 box = jnp.eye(3, dtype=jnp.float32) - nblist = NeighborList(box, 1.0, h_typing.getCovalentMap()) + nblist = NeighborList(box, 1.0, pot_typing.meta["cov_map"]) nblist.allocate(pos) pairs = nblist.pairs nbfunc = jax.value_and_grad(pot_typing.dmff_potentials['NonbondedForce'], argnums=-1, allow_int=True) diff --git a/tests/test_mbar/test_mbar.py b/tests/test_mbar/test_mbar.py index 9482ec695..ba3f57dac 100644 --- a/tests/test_mbar/test_mbar.py +++ b/tests/test_mbar/test_mbar.py @@ -28,10 +28,6 @@ def test_mbar_free_energy_diff(self, pdb, prm1, traj1, prm2, traj2, prm3): nonbondedMethod=app.PME, nonbondedCutoff=0.9 * unit.nanometer) efunc = pot.getPotentialFunc() - nbgen = None - for gen in h.getGenerators(): - if isinstance(gen, dmff.generators.NonbondedJaxGenerator): - nbgen = gen def target_energy_function(traj, parameters): pos_list, box_list, pairs_list, vol_list = [], [], [], [] @@ -42,7 +38,7 @@ def target_energy_function(traj, parameters): [cc[0], cc[1], cc[2]]]) vol = aa[0] * bb[1] * cc[2] positions = jnp.array(frame.xyz[0, :, :]) - nbobj = NeighborListFreud(box, 0.9, nbgen.covalent_map) + nbobj = NeighborListFreud(box, 0.9, pot.meta["cov_map"]) nbobj.capacity_multiplier = 1 pairs = nbobj.allocate(positions) box_list.append(box) @@ -234,10 +230,6 @@ def test_mbar_weight(self, pdb, prm1, traj1, prm2, traj2, prm3): nonbondedMethod=app.PME, nonbondedCutoff=0.9 * unit.nanometer) efunc = pot.getPotentialFunc() - nbgen = None - for gen in h.getGenerators(): - if isinstance(gen, dmff.generators.NonbondedJaxGenerator): - nbgen = gen def target_energy_function(traj, parameters): pos_list, box_list, pairs_list, vol_list = [], [], [], [] @@ -248,7 +240,7 @@ def target_energy_function(traj, parameters): [cc[0], cc[1], cc[2]]]) vol = aa[0] * bb[1] * cc[2] positions = jnp.array(frame.xyz[0, :, :]) - nbobj = NeighborListFreud(box, 0.9, nbgen.covalent_map) + nbobj = NeighborListFreud(box, 0.9, pot.meta["cov_map"]) nbobj.capacity_multiplier = 1 pairs = nbobj.allocate(positions) box_list.append(box) From 15eb0e56ae30e1f5577a6fe8dba17d1e86aed2bd Mon Sep 17 00:00:00 2001 From: WangXinyan940 Date: Thu, 1 Dec 2022 00:40:36 +0800 Subject: [PATCH 6/9] Build a function generator to calculate energies of a trajectory --- dmff/mbar.py | 62 ++++++++++++++++++++++++++++- docs/user_guide/tutorial.md | 12 ++++-- examples/classical/demo.ipynb | 2 +- examples/mbar/demo.ipynb | 35 ++++------------- examples/smirks/demo.ipynb | 2 +- tests/test_mbar/test_mbar.py | 74 +++++------------------------------ 6 files changed, 88 insertions(+), 99 deletions(-) diff --git a/dmff/mbar.py b/dmff/mbar.py index a92486ab6..2de1b6ca3 100644 --- a/dmff/mbar.py +++ b/dmff/mbar.py @@ -2,14 +2,72 @@ import mdtraj as md from pymbar import MBAR import dmff + dmff.update_jax_precision(dmff.PRECISION) import jax import jax.numpy as jnp from jax import grad from tqdm import tqdm, trange import openmm as mm -import openmm.app as app +import openmm.app as app import openmm.unit as unit +from dmff import NeighborList, NeighborListFreud + + +def buildEnergyFunction(potential_func, + cov_map, + cutoff, + usePBC=True, + useFreud=True, + ensemble="nvt", + pressure=1.0): + def energy_function(traj, parameters): + pos_list, box_list, pairs_list, vol_list = [], [], [], [] + pair_full = [] + for na in range(traj.topology.n_atoms): + for nb in range(na + 1, traj.topology.n_atoms): + pair_full.append([na, nb, 0]) + pair_full = np.array(pair_full, dtype=int) + pair_full[:, 2] = cov_map[pair_full[:, 0], pair_full[:, 1]] + for frame in tqdm(traj): + aa, bb, cc = frame.openmm_boxes(0).value_in_unit(unit.nanometer) + box = jnp.array([[aa[0], aa[1], aa[2]], [bb[0], bb[1], bb[2]], + [cc[0], cc[1], cc[2]]]) + vol = aa[0] * bb[1] * cc[2] + positions = jnp.array(frame.xyz[0, :, :]) + if usePBC: + if useFreud: + nbobj = NeighborListFreud(box, cutoff, cov_map) + else: + nbobj = NeighborList(box, cutoff, cov_map) + nbobj.capacity_multiplier = 1 + pairs = nbobj.allocate(positions) + pairs_list.append(pairs) + else: + pairs_list.append(pair_full) + box_list.append(box) + vol_list.append(vol) + pos_list.append(positions) + + pmax = max([p.shape[0] for p in pairs_list]) + pairs_jax = np.zeros( + (traj.n_frames, pmax, 3), dtype=int) + traj.n_atoms + for nframe in range(traj.n_frames): + pair = pairs_list[nframe] + pairs_jax[nframe, :pair.shape[0], :] = pair[:, :] + pairs_jax = jax.numpy.array(pairs_jax) + if ensemble.upper() == "NVT": + ensemble_cns = 0.0 + elif ensemble.upper() == "NPT": + ensemble_cns = 1.0 + eners = [ + potential_func(pos_list[i], box_list[i], pairs_jax[i], parameters) + + ensemble_cns * pressure * 0.06023 * vol_list[i] + for i in trange(traj.n_frames) + ] + return eners + + return energy_function class TargetState: @@ -247,4 +305,4 @@ def estimate_free_energy_difference(self, f_target = self._estimate_free_energy(u_target) if return_energy: return f_target - f_ref, u_target, u_ref - return f_target - f_ref \ No newline at end of file + return f_target - f_ref diff --git a/docs/user_guide/tutorial.md b/docs/user_guide/tutorial.md index d402ef0f0..21c0f6315 100644 --- a/docs/user_guide/tutorial.md +++ b/docs/user_guide/tutorial.md @@ -133,7 +133,7 @@ rc = 4.0 # cutoff pot = H.createPotential(pdb.topology, nonbondedCutoff=rc) ``` -The `Hamiltonian` class will parse tags in XML file and invoke corresponding potential functions. We can access those potentials in the `Potential` object (`pot`) by the name of the corresponding force: +The `Hamiltonian` class will parse tags in XML file and invoke corresponding potential functions. We can access those potentials in the `Potential` object (`pot`) by the name of the corresponding force ``` bondE = pot.dmff_potentials['HarmonicBondForce'] @@ -141,6 +141,12 @@ angleE = pot.dmff_potentials['HarmonicAngleForce'] nonBondE = pot.dmff_potentials['NonbondedForce'] ``` +and access the covalent map from `Potential.meta["cov_map"]` + +``` +cov_map = pot.meta["cov_map"] +``` + > Note: only when the `createPotential` method is called can potentials be obtained Next, we need to construct neighbor list. Here we use the code from `jax_md`: @@ -157,7 +163,7 @@ Also, we provide a wrapper to simplify neighborList construction: ``` from dmff import NeighborList -nblist = NeighborList(box, rc) +nblist = NeighborList(box, rc, cov_map) nblist.allocate(positions) pairs = nblist.pairs # equivalent to nbr.idx.T distance = nblist.distance # distance between pairs @@ -165,7 +171,7 @@ dr = nblist.dr # distance vector ``` -`pairs` is a `(N, 2)` shape array, which indicates the index of atom i and atom j. ATTENTION: pairs array contains many **invalid** index. For example, in this case, we only have 6 atoms and pairs' shape maybe `(18, 2)`. And even there are three `[6, 6]` pairs which are obviously out of range. Because `jax-md` takes advantage of the feature of Jax.numpy, which will not throw an error when the index out of range, and return the [last element](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#out-of-bounds-indexing). +`pairs` is a `(N, 3)` shape array, which indicates the index of atom i, atom j, and the covalent distance between i and j. ATTENTION: pairs array contains many **invalid** index. For example, in this case, we only have 6 atoms and pairs' shape maybe `(18, 3)`. And even there are three `[6, 6]` pairs which are obviously out of range. Because `jax-md` takes advantage of the feature of Jax.numpy, which will not throw an error when the index out of range, and return the [last element](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#out-of-bounds-indexing). THe force field parameters are stored as a dictionary and can be accessed from Hamiltonian. Using it, we can calculate energy and force using the aforementioned potential: diff --git a/examples/classical/demo.ipynb b/examples/classical/demo.ipynb index 6671e310f..39777ad0a 100644 --- a/examples/classical/demo.ipynb +++ b/examples/classical/demo.ipynb @@ -131,7 +131,7 @@ " [0.0, 10.0, 0.0],\n", " [0.0, 0.0, 10.0]\n", "])\n", - "nbList = NeighborList(box, rc=4)\n", + "nbList = NeighborList(box, rc=4, potentials.meta[\"cov_map\"])\n", "nbList.allocate(positions)\n", "pairs = nbList.pairs\n", "nbfunc = potentials.dmff_potentials['NonbondedForce']\n", diff --git a/examples/mbar/demo.ipynb b/examples/mbar/demo.ipynb index 01da8e00d..584b9607d 100644 --- a/examples/mbar/demo.ipynb +++ b/examples/mbar/demo.ipynb @@ -44,7 +44,7 @@ "import mdtraj as md\n", "from tqdm import tqdm, trange\n", "import matplotlib.pyplot as plt\n", - "from dmff.mbar import MBAREstimator, SampleState, TargetState, Sample, OpenMMSampleState\n", + "from dmff.mbar import MBAREstimator, SampleState, TargetState, Sample, OpenMMSampleState, buildEnergyFunction\n", "from dmff.optimize import MultiTransform, genOptimizer\n", "from dmff import Hamiltonian, NeighborListFreud\n", "import optax\n", @@ -221,7 +221,9 @@ "id": "162885fc", "metadata": {}, "source": [ - "## define a function to calculate DMFF energy using mdtraj.Trajectory as input" + "## define a function to calculate DMFF energy using mdtraj.Trajectory as input\n", + "\n", + "Here we use \"buildEnergyFunction\" function generator to build a function which can calculate energies of a trajectory with the MDTraj trajectory itself and a parameter set." ] }, { @@ -235,32 +237,11 @@ "top_pdb = app.PDBFile(\"box_relaxed.pdb\")\n", "pot = hamilt.createPotential(top_pdb.topology, nonbondedMethod=app.PME, nonbondedCutoff=1.1*unit.nanometer, ethresh=1e-4)\n", "efunc = pot.getPotentialFunc()\n", - "nbgen = [g for g in hamilt.getGenerators() if g.name == \"NonbondedForce\"][0]\n", - "\n", - "def target_energy_function(traj, parameters):\n", - " pos_list, box_list, pairs_list, vol_list = [], [], [], []\n", - " for frame in tqdm(traj):\n", - " aa, bb, cc = frame.openmm_boxes(0).value_in_unit(unit.nanometer)\n", - " box = jnp.array([[aa[0], aa[1], aa[2]], [bb[0], bb[1], bb[2]],\n", - " [cc[0], cc[1], cc[2]]])\n", - " vol = aa[0] * bb[1] * cc[2]\n", - " positions = jnp.array(frame.xyz[0, :, :])\n", - " nbobj = NeighborListFreud(box, 0.9, nbgen.covalent_map)\n", - " nbobj.capacity_multiplier = 1\n", - " pairs = nbobj.allocate(positions)\n", - " box_list.append(box)\n", - " pairs_list.append(pairs)\n", - " vol_list.append(vol)\n", - " pos_list.append(positions)\n", "\n", - " pmax = max([p.shape[0] for p in pairs_list])\n", - " pairs_jax = np.zeros((traj.n_frames, pmax, 3), dtype=int) + traj.n_atoms\n", - " for nframe in range(traj.n_frames):\n", - " pair = pairs_list[nframe]\n", - " pairs_jax[nframe,:pair.shape[0],:] = pair[:,:]\n", - " pairs_jax = jax.numpy.array(pairs_jax)\n", - " eners = [efunc(pos_list[i], box_list[i], pairs_jax[i], parameters) + 0.06023 * vol_list[i] for i in trange(traj.n_frames)]\n", - " return eners" + "target_energy_function = buildEnergyFunction(efunc,\n", + " pot.meta[\"cov_map\"],\n", + " 1.1,\n", + " ensemble=\"npt\")" ] }, { diff --git a/examples/smirks/demo.ipynb b/examples/smirks/demo.ipynb index e7e98caeb..9c518f03e 100644 --- a/examples/smirks/demo.ipynb +++ b/examples/smirks/demo.ipynb @@ -99,7 +99,7 @@ "source": [ "pos = jnp.array(mol.GetConformer().GetPositions()) / 10 # angstrom -> nm\n", "box = jnp.eye(3, dtype=jnp.float32)\n", - "nblist = NeighborList(box, 1.0, h_smk.getCovalentMap())\n", + "nblist = NeighborList(box, 1.0, potObj.meta[\"cov_map\"])\n", "nblist.allocate(pos)\n", "pairs = nblist.pairs\n", "energy = func(pos, box, pairs, h_smk.getParameters())\n", diff --git a/tests/test_mbar/test_mbar.py b/tests/test_mbar/test_mbar.py index ba3f57dac..bd5505d02 100644 --- a/tests/test_mbar/test_mbar.py +++ b/tests/test_mbar/test_mbar.py @@ -1,4 +1,4 @@ -from dmff.mbar import MBAREstimator, Sample, SampleState, TargetState, OpenMMSampleState +from dmff.mbar import MBAREstimator, Sample, SampleState, TargetState, OpenMMSampleState, buildEnergyFunction import dmff import pytest import jax @@ -29,38 +29,10 @@ def test_mbar_free_energy_diff(self, pdb, prm1, traj1, prm2, traj2, prm3): nonbondedCutoff=0.9 * unit.nanometer) efunc = pot.getPotentialFunc() - def target_energy_function(traj, parameters): - pos_list, box_list, pairs_list, vol_list = [], [], [], [] - for frame in tqdm(traj): - aa, bb, cc = frame.openmm_boxes(0).value_in_unit( - unit.nanometer) - box = jnp.array([[aa[0], aa[1], aa[2]], [bb[0], bb[1], bb[2]], - [cc[0], cc[1], cc[2]]]) - vol = aa[0] * bb[1] * cc[2] - positions = jnp.array(frame.xyz[0, :, :]) - nbobj = NeighborListFreud(box, 0.9, pot.meta["cov_map"]) - nbobj.capacity_multiplier = 1 - pairs = nbobj.allocate(positions) - box_list.append(box) - pairs_list.append(pairs) - vol_list.append(vol) - pos_list.append(positions) - - pmax = max([p.shape[0] for p in pairs_list]) - pairs_jax = np.zeros( - (traj.n_frames, pmax, 3), dtype=int) + traj.n_atoms - for nframe in range(traj.n_frames): - pair = pairs_list[nframe] - pairs_jax[nframe, :pair.shape[0], :] = pair[:, :] - pairs_jax = jax.numpy.array(pairs_jax) - pos_list = jnp.array(pos_list) - box_list = jnp.array(box_list) - vol_list = jnp.array(vol_list) - eners = [ - efunc(pos_list[i], box_list[i], pairs_jax[i], parameters) + - 0.06023 * vol_list[i] for i in range(traj.n_frames) - ] - return eners + target_energy_function = buildEnergyFunction(efunc, + pot.meta["cov_map"], + 0.9, + ensemble="npt") target_state = TargetState(300.0, target_energy_function) @@ -231,38 +203,10 @@ def test_mbar_weight(self, pdb, prm1, traj1, prm2, traj2, prm3): nonbondedCutoff=0.9 * unit.nanometer) efunc = pot.getPotentialFunc() - def target_energy_function(traj, parameters): - pos_list, box_list, pairs_list, vol_list = [], [], [], [] - for frame in tqdm(traj): - aa, bb, cc = frame.openmm_boxes(0).value_in_unit( - unit.nanometer) - box = jnp.array([[aa[0], aa[1], aa[2]], [bb[0], bb[1], bb[2]], - [cc[0], cc[1], cc[2]]]) - vol = aa[0] * bb[1] * cc[2] - positions = jnp.array(frame.xyz[0, :, :]) - nbobj = NeighborListFreud(box, 0.9, pot.meta["cov_map"]) - nbobj.capacity_multiplier = 1 - pairs = nbobj.allocate(positions) - box_list.append(box) - pairs_list.append(pairs) - vol_list.append(vol) - pos_list.append(positions) - - pmax = max([p.shape[0] for p in pairs_list]) - pairs_jax = np.zeros( - (traj.n_frames, pmax, 3), dtype=int) + traj.n_atoms - for nframe in range(traj.n_frames): - pair = pairs_list[nframe] - pairs_jax[nframe, :pair.shape[0], :] = pair[:, :] - pairs_jax = jax.numpy.array(pairs_jax) - pos_list = jnp.array(pos_list) - box_list = jnp.array(box_list) - vol_list = jnp.array(vol_list) - eners = [ - efunc(pos_list[i], box_list[i], pairs_jax[i], parameters) + - 0.06023 * vol_list[i] for i in range(traj.n_frames) - ] - return eners + target_energy_function = buildEnergyFunction(efunc, + pot.meta["cov_map"], + 0.9, + ensemble="npt") target_state = TargetState(300.0, target_energy_function) From 593a0992b77811b13985ab110af459d2da13ace8 Mon Sep 17 00:00:00 2001 From: WangXinyan940 Date: Thu, 1 Dec 2022 00:44:08 +0800 Subject: [PATCH 7/9] Change API function name --- dmff/mbar.py | 2 +- examples/mbar/demo.ipynb | 10 +++++----- tests/test_mbar/test_mbar.py | 18 +++++++++--------- 3 files changed, 15 insertions(+), 15 deletions(-) diff --git a/dmff/mbar.py b/dmff/mbar.py index 2de1b6ca3..aafc4b986 100644 --- a/dmff/mbar.py +++ b/dmff/mbar.py @@ -14,7 +14,7 @@ from dmff import NeighborList, NeighborListFreud -def buildEnergyFunction(potential_func, +def buildTrajEnergyFunction(potential_func, cov_map, cutoff, usePBC=True, diff --git a/examples/mbar/demo.ipynb b/examples/mbar/demo.ipynb index 584b9607d..cbd7ea21d 100644 --- a/examples/mbar/demo.ipynb +++ b/examples/mbar/demo.ipynb @@ -44,7 +44,7 @@ "import mdtraj as md\n", "from tqdm import tqdm, trange\n", "import matplotlib.pyplot as plt\n", - "from dmff.mbar import MBAREstimator, SampleState, TargetState, Sample, OpenMMSampleState, buildEnergyFunction\n", + "from dmff.mbar import MBAREstimator, SampleState, TargetState, Sample, OpenMMSampleState, buildTrajEnergyFunction\n", "from dmff.optimize import MultiTransform, genOptimizer\n", "from dmff import Hamiltonian, NeighborListFreud\n", "import optax\n", @@ -238,10 +238,10 @@ "pot = hamilt.createPotential(top_pdb.topology, nonbondedMethod=app.PME, nonbondedCutoff=1.1*unit.nanometer, ethresh=1e-4)\n", "efunc = pot.getPotentialFunc()\n", "\n", - "target_energy_function = buildEnergyFunction(efunc,\n", - " pot.meta[\"cov_map\"],\n", - " 1.1,\n", - " ensemble=\"npt\")" + "target_energy_function = buildTrajEnergyFunction(efunc,\n", + " pot.meta[\"cov_map\"],\n", + " 1.1,\n", + " ensemble=\"npt\")" ] }, { diff --git a/tests/test_mbar/test_mbar.py b/tests/test_mbar/test_mbar.py index bd5505d02..3c9468d21 100644 --- a/tests/test_mbar/test_mbar.py +++ b/tests/test_mbar/test_mbar.py @@ -1,4 +1,4 @@ -from dmff.mbar import MBAREstimator, Sample, SampleState, TargetState, OpenMMSampleState, buildEnergyFunction +from dmff.mbar import MBAREstimator, Sample, SampleState, TargetState, OpenMMSampleState, buildTrajEnergyFunction import dmff import pytest import jax @@ -29,10 +29,10 @@ def test_mbar_free_energy_diff(self, pdb, prm1, traj1, prm2, traj2, prm3): nonbondedCutoff=0.9 * unit.nanometer) efunc = pot.getPotentialFunc() - target_energy_function = buildEnergyFunction(efunc, - pot.meta["cov_map"], - 0.9, - ensemble="npt") + target_energy_function = buildTrajEnergyFunction(efunc, + pot.meta["cov_map"], + 0.9, + ensemble="npt") target_state = TargetState(300.0, target_energy_function) @@ -203,10 +203,10 @@ def test_mbar_weight(self, pdb, prm1, traj1, prm2, traj2, prm3): nonbondedCutoff=0.9 * unit.nanometer) efunc = pot.getPotentialFunc() - target_energy_function = buildEnergyFunction(efunc, - pot.meta["cov_map"], - 0.9, - ensemble="npt") + target_energy_function = buildTrajEnergyFunction(efunc, + pot.meta["cov_map"], + 0.9, + ensemble="npt") target_state = TargetState(300.0, target_energy_function) From d5c7f7d3efac4795d8daa671e3ca17c9e3556304 Mon Sep 17 00:00:00 2001 From: KuangYu Date: Thu, 1 Dec 2022 10:26:27 +0800 Subject: [PATCH 8/9] fix meta data for admp generators --- dmff/generators/admp.py | 16 +++++++++++++++- dmff/generators/classical.py | 5 ++--- 2 files changed, 17 insertions(+), 4 deletions(-) diff --git a/dmff/generators/admp.py b/dmff/generators/admp.py index dd0ebe9ec..ce7c4edec 100644 --- a/dmff/generators/admp.py +++ b/dmff/generators/admp.py @@ -93,6 +93,8 @@ def createForce(self, system, data, nonbondedMethod, nonbondedCutoff, self.map_atomtype = map_atomtype # build covalent map self.covalent_map = build_covalent_map(data, 6) + self._meta["cov_map"] = self.covalent_map + self._meta["ADMPDispForce_map_atomtype"] = map_atomtype # here box is only used to setup ewald parameters, no need to be differentiable a, b, c = system.getDefaultPeriodicBoxVectors() box = jnp.array([a._value, b._value, c._value]) * 10 @@ -257,6 +259,9 @@ def createForce(self, system, data, nonbondedMethod, nonbondedCutoff, # build covalent map self.covalent_map = build_covalent_map(data, 6) + self._meta["cov_map"] = self.covalent_map + self._meta["ADMPDispPmeForce_map_atomtype"] = self.map_atomtype + # here box is only used to setup ewald parameters, no need to be differentiable a, b, c = system.getDefaultPeriodicBoxVectors() box = jnp.array([a._value, b._value, c._value]) * 10 @@ -366,6 +371,9 @@ def createForce(self, system, data, nonbondedMethod, nonbondedCutoff, # build covalent map self.covalent_map = build_covalent_map(data, 6) + self._meta["cov_map"] = self.covalent_map + self._meta["QqTtDampingForce_map_atomtype"] = self.map_atomtype + pot_fn_sr = generate_pairwise_interaction(TT_damping_qq_kernel, static_args={}) @@ -466,6 +474,9 @@ def createForce(self, system, data, nonbondedMethod, nonbondedCutoff, # build covalent map self.covalent_map = build_covalent_map(data, 6) + self._meta["cov_map"] = self.covalent_map + self._meta[self.name+"_map_atomtype"] = self.map_atomtype + # WORKING pot_fn_sr = generate_pairwise_interaction(slater_disp_damping_kernel, static_args={}) @@ -560,6 +571,9 @@ def createForce(self, system, data, nonbondedMethod, nonbondedCutoff, # build covalent map self.covalent_map = build_covalent_map(data, 6) + self._meta["cov_map"] = self.covalent_map + self._meta[self.name+"_map_atomtype"] = self.map_atomtype + pot_fn_sr = generate_pairwise_interaction(slater_sr_kernel, static_args={}) @@ -1124,4 +1138,4 @@ def getMetaData(self): return self._meta -dmff.api.jaxGenerators["ADMPPmeForce"] = ADMPPmeGenerator \ No newline at end of file +dmff.api.jaxGenerators["ADMPPmeForce"] = ADMPPmeGenerator diff --git a/dmff/generators/classical.py b/dmff/generators/classical.py index fbe8f1d96..2e5cc6306 100644 --- a/dmff/generators/classical.py +++ b/dmff/generators/classical.py @@ -20,7 +20,6 @@ CoulNoCutoffForce, CoulombPMEForce, CoulReactionFieldForce, - LennardJonesForce, ) from dmff.classical.fep import ( LennardJonesFreeEnergyForce, @@ -1050,7 +1049,7 @@ def __init__(self, ff): self.fftree = ff.fftree self.paramtree = ff.paramtree self.paramtree[self.name] = {} - self._meta + self._meta = {} def extract(self): @@ -1237,4 +1236,4 @@ def getMetaData(self): return self._meta -dmff.api.jaxGenerators["LennardJonesForce"] = LennardJonesGenerator \ No newline at end of file +dmff.api.jaxGenerators["LennardJonesForce"] = LennardJonesGenerator From 66906283796fc02ce6edb3fa5a82a740c9a16399 Mon Sep 17 00:00:00 2001 From: Wang Xinyan Date: Thu, 1 Dec 2022 13:08:02 +0800 Subject: [PATCH 9/9] Make buildEnergyFunction cleaner --- dmff/mbar.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/dmff/mbar.py b/dmff/mbar.py index aafc4b986..b0182d236 100644 --- a/dmff/mbar.py +++ b/dmff/mbar.py @@ -33,7 +33,6 @@ def energy_function(traj, parameters): aa, bb, cc = frame.openmm_boxes(0).value_in_unit(unit.nanometer) box = jnp.array([[aa[0], aa[1], aa[2]], [bb[0], bb[1], bb[2]], [cc[0], cc[1], cc[2]]]) - vol = aa[0] * bb[1] * cc[2] positions = jnp.array(frame.xyz[0, :, :]) if usePBC: if useFreud: @@ -45,9 +44,9 @@ def energy_function(traj, parameters): pairs_list.append(pairs) else: pairs_list.append(pair_full) - box_list.append(box) - vol_list.append(vol) - pos_list.append(positions) + pos_list = jnp.array(traj.xyz) + vol_list = jnp.array(traj.unitcell_volumes) + box_list = jnp.array(traj.unitcell_vectors) pmax = max([p.shape[0] for p in pairs_list]) pairs_jax = np.zeros(