Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions dmff/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,13 @@ class Potential:
def __init__(self):
self.dmff_potentials = {}
self.omm_system = None
self.meta = {}

def addDmffPotential(self, name, potential):
def addDmffPotential(self, name, potential, meta={}):
self.dmff_potentials[name] = potential
if len(meta):
for key in meta.keys():
self.meta[key] = meta[key]

def addOmmSystem(self, system):
self.omm_system = system
Expand Down Expand Up @@ -183,7 +187,8 @@ def createPotential(self,
continue
try:
potentialImpl = generator.getJaxPotential()
potObj.addDmffPotential(generator.name, potentialImpl)
meta = generator.getMetaData()
potObj.addDmffPotential(generator.name, potentialImpl, meta=meta)
except Exception as e:
print(e)
pass
Expand Down
43 changes: 41 additions & 2 deletions dmff/generators/admp.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def __init__(self, ff):
self.types = []
self.ethresh = 5e-4
self.pmax = 10
self._meta = {}

def extract(self):

Expand Down Expand Up @@ -92,6 +93,8 @@ def createForce(self, system, data, nonbondedMethod, nonbondedCutoff,
self.map_atomtype = map_atomtype
# build covalent map
self.covalent_map = build_covalent_map(data, 6)
self._meta["cov_map"] = self.covalent_map
self._meta["ADMPDispForce_map_atomtype"] = map_atomtype
# here box is only used to setup ewald parameters, no need to be differentiable
a, b, c = system.getDefaultPeriodicBoxVectors()
box = jnp.array([a._value, b._value, c._value]) * 10
Expand Down Expand Up @@ -160,6 +163,9 @@ def overwrite(self):

def getJaxPotential(self):
return self._jaxPotential

def getMetaData(self):
return self._meta


dmff.api.jaxGenerators['ADMPDispForce'] = ADMPDispGenerator
Expand All @@ -181,6 +187,7 @@ def __init__(self, ff):
self.ethresh = 5e-4
self.pmax = 10
self.name = "ADMPDispPmeForce"
self._meta = {}

def extract(self):

Expand Down Expand Up @@ -252,6 +259,9 @@ def createForce(self, system, data, nonbondedMethod, nonbondedCutoff,
# build covalent map
self.covalent_map = build_covalent_map(data, 6)

self._meta["cov_map"] = self.covalent_map
self._meta["ADMPDispPmeForce_map_atomtype"] = self.map_atomtype

# here box is only used to setup ewald parameters, no need to be differentiable
a, b, c = system.getDefaultPeriodicBoxVectors()
box = jnp.array([a._value, b._value, c._value]) * 10
Expand Down Expand Up @@ -285,6 +295,9 @@ def potential_fn(positions, box, pairs, params):

def getJaxPotential(self):
return self._jaxPotential

def getMetaData(self):
return self._meta


dmff.api.jaxGenerators['ADMPDispPmeForce'] = ADMPDispPmeGenerator
Expand All @@ -302,6 +315,7 @@ def __init__(self, ff):
self.paramtree = ff.paramtree
self._jaxPotnetial = None
self.name = "QqTtDampingForce"
self._meta = {}

def extract(self):
# get mscales
Expand Down Expand Up @@ -357,6 +371,9 @@ def createForce(self, system, data, nonbondedMethod, nonbondedCutoff,
# build covalent map
self.covalent_map = build_covalent_map(data, 6)

self._meta["cov_map"] = self.covalent_map
self._meta["QqTtDampingForce_map_atomtype"] = self.map_atomtype

pot_fn_sr = generate_pairwise_interaction(TT_damping_qq_kernel,
static_args={})

Expand All @@ -373,6 +390,9 @@ def potential_fn(positions, box, pairs, params):

def getJaxPotential(self):
return self._jaxPotential

def getMetaData(self):
return self._meta


# register all parsers
Expand All @@ -392,6 +412,7 @@ def __init__(self, ff):
self.fftree = ff.fftree
self.paramtree = ff.paramtree
self._jaxPotential = None
self._meta = {}

def extract(self):
# get mscales
Expand Down Expand Up @@ -453,6 +474,9 @@ def createForce(self, system, data, nonbondedMethod, nonbondedCutoff,
# build covalent map
self.covalent_map = build_covalent_map(data, 6)

self._meta["cov_map"] = self.covalent_map
self._meta[self.name+"_map_atomtype"] = self.map_atomtype

# WORKING
pot_fn_sr = generate_pairwise_interaction(slater_disp_damping_kernel,
static_args={})
Expand All @@ -474,6 +498,9 @@ def potential_fn(positions, box, pairs, params):

def getJaxPotential(self):
return self._jaxPotential

def getMetaData(self):
return self._meta


dmff.api.jaxGenerators['SlaterDampingForce'] = SlaterDampingGenerator
Expand All @@ -490,6 +517,7 @@ def __init__(self, ff):
self.fftree = ff.fftree
self.paramtree = ff.paramtree
self._jaxPotential = None
self._meta = {}

def extract(self):
# get mscales
Expand Down Expand Up @@ -543,6 +571,9 @@ def createForce(self, system, data, nonbondedMethod, nonbondedCutoff,
# build covalent map
self.covalent_map = build_covalent_map(data, 6)

self._meta["cov_map"] = self.covalent_map
self._meta[self.name+"_map_atomtype"] = self.map_atomtype

pot_fn_sr = generate_pairwise_interaction(slater_sr_kernel,
static_args={})

Expand All @@ -559,6 +590,9 @@ def potential_fn(positions, box, pairs, params):

def getJaxPotential(self):
return self._jaxPotential

def getMetaData(self):
return self._meta


dmff.api.jaxGenerators["SlaterExForce"] = SlaterExGenerator
Expand Down Expand Up @@ -613,6 +647,8 @@ def __init__(self, ff):
self.lpol = False
self.ref_dip = ""

self._meta = {}

def extract(self):

self.lmax = self.fftree.get_attribs(f'{self.name}',
Expand Down Expand Up @@ -850,7 +886,7 @@ def createForce(self, system, data, nonbondedMethod, nonbondedCutoff,

# build covalent map
self.covalent_map = covalent_map = build_covalent_map(data, 6)

self._meta["cov_map"] = self.covalent_map
# build intra-molecule axis
# the following code is the direct transplant of forcefield.py in openmm 7.4.0

Expand Down Expand Up @@ -1098,5 +1134,8 @@ def potential_fn(positions, box, pairs, params):
def getJaxPotential(self):
return self._jaxPotential

def getMetaData(self):
return self._meta


dmff.api.jaxGenerators["ADMPPmeForce"] = ADMPPmeGenerator
dmff.api.jaxGenerators["ADMPPmeForce"] = ADMPPmeGenerator
57 changes: 51 additions & 6 deletions dmff/generators/classical.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
CoulNoCutoffForce,
CoulombPMEForce,
CoulReactionFieldForce,
LennardJonesForce,
)
from dmff.classical.fep import (
LennardJonesFreeEnergyForce,
Expand All @@ -40,6 +39,7 @@ def __init__(self, ff: Hamiltonian):
self.ff: Hamiltonian = ff
self.fftree: ForcefieldTree = ff.fftree
self.paramtree: Dict = ff.paramtree
self._meta = {}

def extract(self):
"""
Expand Down Expand Up @@ -74,6 +74,7 @@ def createForce(self, sys, data, nonbondedMethod, nonbondedCutoff, args):
Args:
Those args are the same as those in createSystem.
"""
self._meta = {}

# initialize typemap
matcher = TypeMatcher(self.fftree, "HarmonicBondForce/Bond")
Expand Down Expand Up @@ -114,7 +115,10 @@ def createForce(self, sys, data, nonbondedMethod, nonbondedCutoff, args):

map_atom1 = np.array(map_atom1, dtype=int)
map_atom2 = np.array(map_atom2, dtype=int)
map_param = np.array(map_param, dtype=int)
map_param = np.array(map_param, dtype=int)
self._meta["HarmonicBondForce_atom1"] = map_atom1
self._meta["HarmonicBondForce_atom2"] = map_atom2
self._meta["HarmonicBondForce_param"] = map_param

bforce = HarmonicBondJaxForce(map_atom1, map_atom2, map_param)
self._force_latest = bforce
Expand All @@ -129,6 +133,9 @@ def potential_fn(positions, box, pairs, params):

def getJaxPotential(self):
return self._jaxPotential

def getMetaData(self):
return self._meta


dmff.api.jaxGenerators["HarmonicBondForce"] = HarmonicBondJaxGenerator
Expand All @@ -140,6 +147,7 @@ def __init__(self, ff):
self.ff = ff
self.fftree = ff.fftree
self.paramtree = ff.paramtree
self._meta = {}

def extract(self):
angles = self.fftree.get_attribs(f"{self.name}/Angle", "angle")
Expand All @@ -155,6 +163,8 @@ def overwrite(self):
self.paramtree[self.name]["k"])

def createForce(self, sys, data, nonbondedMethod, nonbondedCutoff, args):
self._meta = {}

matcher = TypeMatcher(self.fftree, "HarmonicAngleForce/Angle")

map_atom1, map_atom2, map_atom3, map_param = [], [], [], []
Expand Down Expand Up @@ -202,6 +212,10 @@ def createForce(self, sys, data, nonbondedMethod, nonbondedCutoff, args):
map_atom2 = np.array(map_atom2, dtype=int)
map_atom3 = np.array(map_atom3, dtype=int)
map_param = np.array(map_param, dtype=int)
self._meta["HarmonicAngleForce_atom1"] = map_atom1
self._meta["HarmonicAngleForce_atom2"] = map_atom2
self._meta["HarmonicAngleForce_atom3"] = map_atom3
self._meta["HarmonicAngleForce_param"] = map_param

aforce = HarmonicAngleJaxForce(map_atom1, map_atom2, map_atom3,
map_param)
Expand All @@ -217,6 +231,9 @@ def potential_fn(positions, box, pairs, params):

def getJaxPotential(self):
return self._jaxPotential

def getMetaData(self):
return self._meta


dmff.api.jaxGenerators["HarmonicAngleForce"] = HarmonicAngleJaxGenerator
Expand All @@ -229,7 +246,7 @@ def __init__(self, ff):
self.fftree = ff.fftree
self.paramtree = ff.paramtree
self.meta = {}

self._meta = {}
self.meta["prop_order"] = defaultdict(list)
self.meta["prop_nodeidx"] = defaultdict(list)

Expand Down Expand Up @@ -340,6 +357,7 @@ def createForce(self, sys, data, nonbondedMethod, nonbondedCutoff, args):
"""
Create force for torsions
"""

# Proper Torsions
proper_matcher = TypeMatcher(self.fftree,
"PeriodicTorsionForce/Proper")
Expand Down Expand Up @@ -487,6 +505,19 @@ def createForce(self, sys, data, nonbondedMethod, nonbondedCutoff, args):
self._props_latest = props
self._imprs_latest = imprs

self._meta["PeriodicTorsionForce_prop_atom1"] = map_prop_atom1
self._meta["PeriodicTorsionForce_prop_atom2"] = map_prop_atom2
self._meta["PeriodicTorsionForce_prop_atom3"] = map_prop_atom3
self._meta["PeriodicTorsionForce_prop_atom4"] = map_prop_atom4
self._meta["PeriodicTorsionForce_prop_param"] = map_prop_param

self._meta["PeriodicTorsionForce_impr_atom1"] = map_impr_atom1
self._meta["PeriodicTorsionForce_impr_atom2"] = map_impr_atom2
self._meta["PeriodicTorsionForce_impr_atom3"] = map_impr_atom3
self._meta["PeriodicTorsionForce_impr_atom4"] = map_impr_atom4
self._meta["PeriodicTorsionForce_impr_param"] = map_impr_param


def potential_fn(positions, box, pairs, params):
prop_sum = sum([
props[i].get_energy(
Expand All @@ -509,6 +540,9 @@ def potential_fn(positions, box, pairs, params):

def getJaxPotential(self):
return self._jaxPotential

def getMetaData(self):
return self._meta


dmff.api.jaxGenerators["PeriodicTorsionForce"] = PeriodicTorsionJaxGenerator
Expand All @@ -532,6 +566,8 @@ def __init__(self, ff: Hamiltonian):
self.useBCC = False
self.useVsite = False

self._meta = {}

def extract(self):
self.from_residue = self.fftree.get_attribs(
"NonbondedForce/UseAttributeFromResidue", "name")
Expand Down Expand Up @@ -684,6 +720,8 @@ def addVsiteFunc(pos, params):
cov_map[ori_dim + i, parent_i] = 1
self.covalent_map = jnp.array(cov_map)

self._meta["cov_map"] = self.covalent_map

# Load Lennard-Jones parameters
maps = {}
if not nbmatcher.useSmirks:
Expand Down Expand Up @@ -976,6 +1014,9 @@ def potential_fn(positions, box, pairs, params, vdwLambda,

def getJaxPotential(self):
return self._jaxPotential

def getMetaData(self):
return self._meta

def getAddVsiteFunc(self):
"""
Expand Down Expand Up @@ -1008,8 +1049,8 @@ def __init__(self, ff):
self.fftree = ff.fftree
self.paramtree = ff.paramtree
self.paramtree[self.name] = {}
self.paramtree[self.name]
self.paramtree[self.name]
self._meta = {}


def extract(self):
for prm in ["sigma", "epsilon"]:
Expand Down Expand Up @@ -1109,6 +1150,7 @@ def findIdx(labels, label):
map_nbfix = jnp.array(map_nbfix)

colv_map = build_covalent_map(data, 6)
self._meta["cov_map"] = colv_map

if unit.is_quantity(nonbondedCutoff):
r_cut = nonbondedCutoff.value_in_unit(unit.nanometer)
Expand Down Expand Up @@ -1189,6 +1231,9 @@ def potential_fn(positions, box, pairs, params):

def getJaxPotential(self):
return self._jaxPotential

def getMetaData(self):
return self._meta


dmff.api.jaxGenerators["LennardJonesForce"] = LennardJonesGenerator
dmff.api.jaxGenerators["LennardJonesForce"] = LennardJonesGenerator
Loading