diff --git a/.gitignore b/.gitignore
index 7df232a6f..abe26b321 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,3 +1,10 @@
+
+# temporary
+err
+out
+sub.sh
+*.npy
+
### C++ ###
# Prerequisites
*.d
diff --git a/dmff/api.py b/dmff/api.py
index 8d612ed51..9fb6074f3 100644
--- a/dmff/api.py
+++ b/dmff/api.py
@@ -15,9 +15,8 @@
import linecache
-
def get_line_context(file_path, line_number):
- return linecache.getline(file_path,line_number).strip()
+ return linecache.getline(file_path, line_number).strip()
def build_covalent_map(data, max_neighbor):
@@ -30,8 +29,8 @@ def build_covalent_map(data, max_neighbor):
for i in range(n_atoms):
# current neighbors
j_list = np.where(
- np.logical_and(covalent_map[i] <= n_curr, covalent_map[i] > 0)
- )[0]
+ np.logical_and(covalent_map[i] <= n_curr,
+ covalent_map[i] > 0))[0]
for j in j_list:
k_list = np.where(covalent_map[j] == 1)[0]
for k in k_list:
@@ -119,10 +118,17 @@ def set_axis_type(map_atomtypes, types, params):
class ADMPDispGenerator:
def __init__(self, hamiltonian):
self.ff = hamiltonian
- self.params = {"A": [], "B": [], "Q": [], "C6": [], "C8": [], "C10": []}
+ self.params = {
+ "A": [],
+ "B": [],
+ "Q": [],
+ "C6": [],
+ "C8": [],
+ "C10": []
+ }
self._jaxPotential = None
self.types = []
- self.ethresh = 1.0e-5
+ self.ethresh = 5e-4
self.pmax = 10
def registerAtomType(self, atom):
@@ -150,7 +156,8 @@ def parseElement(element, hamiltonian):
generator.params[k] = jnp.array(generator.params[k])
generator.types = np.array(generator.types)
- def createForce(self, system, data, nonbondedMethod, nonbondedCutoff, args):
+ def createForce(self, system, data, nonbondedMethod, nonbondedCutoff,
+ args):
n_atoms = len(data.atoms)
# build index map
@@ -168,22 +175,22 @@ def createForce(self, system, data, nonbondedMethod, nonbondedCutoff, args):
rc = nonbondedCutoff.value_in_unit(unit.angstrom)
# get calculator
- Force_DispPME = ADMPDispPmeForce(box, covalent_map, rc, self.ethresh, self.pmax)
+ Force_DispPME = ADMPDispPmeForce(box, covalent_map, rc, self.ethresh,
+ self.pmax)
# debugging
# Force_DispPME.update_env('kappa', 0.657065221219616)
# Force_DispPME.update_env('K1', 96)
# Force_DispPME.update_env('K2', 96)
# Force_DispPME.update_env('K3', 96)
pot_fn_lr = Force_DispPME.get_energy
- pot_fn_sr = generate_pairwise_interaction(
- TT_damping_qq_c6_kernel, covalent_map, static_args={}
- )
+ pot_fn_sr = generate_pairwise_interaction(TT_damping_qq_c6_kernel,
+ covalent_map,
+ static_args={})
def potential_fn(positions, box, pairs, params):
mScales = params["mScales"]
- a_list = (
- params["A"][map_atomtype] / 2625.5
- ) # kj/mol to au, as expected by TT_damping kernel
+ a_list = (params["A"][map_atomtype] / 2625.5
+ ) # kj/mol to au, as expected by TT_damping kernel
b_list = params["B"][map_atomtype] * 0.0529177249 # nm^-1 to au
q_list = params["Q"][map_atomtype]
c6_list = jnp.sqrt(params["C6"][map_atomtype] * 1e6)
@@ -191,9 +198,8 @@ def potential_fn(positions, box, pairs, params):
c10_list = jnp.sqrt(params["C10"][map_atomtype] * 1e10)
c_list = jnp.vstack((c6_list, c8_list, c10_list))
- E_sr = pot_fn_sr(
- positions, box, pairs, mScales, a_list, b_list, q_list, c_list[0]
- )
+ E_sr = pot_fn_sr(positions, box, pairs, mScales, a_list, b_list,
+ q_list, c_list[0])
E_lr = pot_fn_lr(positions, box, pairs, c_list.T, mScales)
return E_sr - E_lr
@@ -244,19 +250,24 @@ def __init__(self, hamiltonian):
"thole": [],
"polarizabilityXX": [],
"polarizabilityYY": [],
- "polarizabilityZZ": []
+ "polarizabilityZZ": [],
+
+ }
+ self.params = {
+ "mScales": [],
+ "pScales": [],
+ "dScales": [],
}
# if more or optional input params
- # self._input_params = defaultDict(list)
+ # self._input_params = defaultDict(list)
self._jaxPotential = None
self.types = []
- self.ethresh = 1.0e-5
- self.params = {}
+ self.ethresh = 5e-4
self.lpol = False
self.ref_dip = ''
- def registerAtomType(self, atom:dict):
-
+ def registerAtomType(self, atom: dict):
+
self.types.append(atom.pop("type"))
kStrings = ["kz", "kx", "ky"]
@@ -273,29 +284,29 @@ def registerAtomType(self, atom:dict):
def parseElement(element, hamiltonian):
generator = ADMPPmeGenerator(hamiltonian)
generator.lmax = int(element.attrib.get('lmax'))
- generator.pmax = int(element.attrib.get('pmax'))
-
+ generator.defaultTholeWidth = 5
+
hamiltonian.registerGenerator(generator)
- mScales = []
- pScales = []
- dScales = []
for i in range(2, 7):
- mScales.append(float(element.attrib["mScale1%d" % i]))
- pScales.append(float(element.attrib["pScale1%d" % i]))
- dScales.append(float(element.attrib["dScale1%d" % i]))
- generator.params["mScales"] = jnp.array(mScales)
- generator.params["pScales"] = jnp.array(pScales)
- generator.params["dScales"] = jnp.array(dScales)
-
+ generator.params["mScales"].append(
+ float(element.attrib["mScale1%d" % i]))
+ generator.params["pScales"].append(
+ float(element.attrib["pScale1%d" % i]))
+ generator.params["dScales"].append(
+ float(element.attrib["dScale1%d" % i]))
+
if element.findall('Polarize'):
generator.lpol = True
for atomType in element.findall("Atom"):
atomAttrib = atomType.attrib
+ # if not set
+ atomAttrib.update({'polarizabilityXX': 0, 'polarizabilityYY': 0, 'polarizabilityZZ': 0})
for polarInfo in element.findall("Polarize"):
polarAttrib = polarInfo.attrib
if polarInfo.attrib['type'] == atomAttrib['type']:
+ # cover default
atomAttrib.update(polarAttrib)
break
generator.registerAtomType(atomAttrib)
@@ -303,87 +314,50 @@ def parseElement(element, hamiltonian):
for k in generator._input_params.keys():
generator._input_params[k] = jnp.array(generator._input_params[k])
generator.types = np.array(generator.types)
-
- def createForce(self, system, data, nonbondedMethod, nonbondedCutoff, args):
-
- n_atoms = len(data.atoms)
- # build index map
- map_atomtype = np.zeros(n_atoms, dtype=int)
- for i in range(n_atoms):
- atype = data.atomType[data.atoms[i]]
- map_atomtype[i] = np.where(self.types == atype)[0][0]
+
+ n_atoms = len(element.findall('Atom'))
# map atom multipole moments
- p = self._input_params
Q = np.zeros((n_atoms, 10))
- Q[:, 0] = p["c0"][map_atomtype]
- Q[:, 1] = p["dX"][map_atomtype] * 10
- Q[:, 2] = p["dY"][map_atomtype] * 10
- Q[:, 3] = p["dZ"][map_atomtype] * 10
- Q[:, 4] = p["qXX"][map_atomtype] * 300
- Q[:, 5] = p["qYY"][map_atomtype] * 300
- Q[:, 6] = p["qZZ"][map_atomtype] * 300
- Q[:, 7] = p["qXY"][map_atomtype] * 300
- Q[:, 8] = p["qXZ"][map_atomtype] * 300
- Q[:, 9] = p["qYZ"][map_atomtype] * 300
+ Q[:, 0] = generator._input_params["c0"]
+ Q[:, 1] = generator._input_params["dX"] * 10
+ Q[:, 2] = generator._input_params["dY"] * 10
+ Q[:, 3] = generator._input_params["dZ"] * 10
+ Q[:, 4] = generator._input_params["qXX"] * 300
+ Q[:, 5] = generator._input_params["qYY"] * 300
+ Q[:, 6] = generator._input_params["qZZ"] * 300
+ Q[:, 7] = generator._input_params["qXY"] * 300
+ Q[:, 8] = generator._input_params["qXZ"] * 300
+ Q[:, 9] = generator._input_params["qYZ"] * 300
+
+ # add all differentiable params to self.params
+ Q_local = convert_cart2harm(Q, 2)
+ generator.params['Q_local'] = Q_local
- # map polarization-related params
- pol = jnp.vstack((p['polarizabilityXX'][map_atomtype], p['polarizabilityYY'][map_atomtype], p['polarizabilityZZ'][map_atomtype])).T.astype(jnp.float32)
+ pol = jnp.vstack((generator._input_params['polarizabilityXX'], generator._input_params['polarizabilityYY'], generator._input_params['polarizabilityZZ'])).T
pol = 1000*jnp.mean(pol,axis=1)
- self.params['pol'] = pol
+
+ tholes = jnp.array(generator._input_params['thole'])
+ # tholes = jnp.mean(jnp.atleast_2d(tholes), axis=1)
- tholes = jnp.array(p['thole'][map_atomtype]).astype(jnp.float32)
- tholes = jnp.mean(jnp.atleast_2d(tholes), axis=1)
- self.params['tholes'] = tholes
+ generator.params['pol'] = pol
+ generator.params['tholes'] = tholes
+ # generator.params['']
+ for k in generator.params.keys():
+ generator.params[k] = jnp.array(generator.params[k])
- # defaultTholeWidth = 8
- Uind_global = jnp.zeros([n_atoms,3])
- ref_dip = self.ref_dip
+
+ def createForce(self, system, data, nonbondedMethod, nonbondedCutoff,
+ args):
+
+ n_atoms = len(data.atoms)
+ # build index map
+ map_atomtype = np.zeros(n_atoms, dtype=int)
+
for i in range(n_atoms):
- a = get_line_context(ref_dip, i+1)
- b = a.split()
- t = np.array([10*float(b[0]),10*float(b[1]),10*float(b[2])])
- Uind_global = Uind_global.at[i].set(t)
-
- # construct the C list
- c_list = np.zeros((3, n_atoms))
- a_list = np.zeros(n_atoms)
- q_list = np.zeros(n_atoms)
- b_list = np.zeros(n_atoms)
-
-
- nmol=int(n_atoms/3) # WARNING: HARD CODE!
- for i in range(nmol):
- a = i*3
- b = i*3+1
- c = i*3+2
- # dispersion coeff
- c_list[0][a]=37.19677405
- c_list[0][b]=7.6111103
- c_list[0][c]=7.6111103
- c_list[1][a]=85.26810658
- c_list[1][b]=11.90220148
- c_list[1][c]=11.90220148
- c_list[2][a]=134.44874488
- c_list[2][b]=15.05074749
- c_list[2][c]=15.05074749
- # q
- q_list[a] = -0.741706
- q_list[b] = 0.370853
- q_list[c] = 0.370853
- # b, Bohr^-1
- b_list[a] = 2.00095977
- b_list[b] = 1.999519942
- b_list[c] = 1.999519942
- # a, Hartree
- a_list[a] = 458.3777
- a_list[b] = 0.0317
- a_list[c] = 0.0317
-
- # add all differentiable params to self.params
- Q = jnp.array(Q)
- Q_local = convert_cart2harm(Q, 2)
- self.params["Q_local"] = Q_local
+ atype = data.atomType[data.atoms[i]]
+ map_atomtype[i] = np.where(self.types == atype)[0][0]
+
# here box is only used to setup ewald parameters, no need to be differentiable
a, b, c = system.getDefaultPeriodicBoxVectors()
box = jnp.array([a._value, b._value, c._value]) * 10
@@ -396,15 +370,15 @@ def createForce(self, system, data, nonbondedMethod, nonbondedCutoff, args):
# build intra-molecule axis
self.axis_types, self.axis_indices = set_axis_type(
- map_atomtype, self.types, self.kStrings
- )
+ map_atomtype, self.types, self.kStrings)
map_axis_indices = []
# map axis_indices
for i in range(n_atoms):
catom = data.atoms[i]
residue = catom.residue._atoms
atom_indices = [
- index if index != "" else -1 for index in self.axis_indices[i][1:]
+ index if index != "" else -1
+ for index in self.axis_indices[i][1:]
]
for atom in residue:
if atom == catom:
@@ -416,40 +390,21 @@ def createForce(self, system, data, nonbondedMethod, nonbondedCutoff, args):
map_axis_indices.append(atom_indices)
self.axis_indices = np.array(map_axis_indices)
-
- # Finish data preparation
- # -------------------------------------------------------------------------------------
- # parameters should be ready:
- # geometric variables: positions, box
- # atomic parameters: Q_local, c_list
- # topological parameters: covalent_map, mScales, pScales, dScales
- # general force field setting parameters: rc, ethresh, lmax, pmax
-
- pme_force = ADMPPmeForce(
- box,
- self.axis_types,
- self.axis_indices,
- covalent_map,
- rc,
- self.ethresh,
- self.lmax,
- self.lpol
- )
- self.params['U_ind'] = pme_force.U_ind
+ pme_force = ADMPPmeForce(box, self.axis_types, self.axis_indices,
+ covalent_map, rc, self.ethresh, self.lmax,
+ self.lpol)
def potential_fn(positions, box, pairs, params):
mScales = params["mScales"]
- Q_local = params["Q_local"]
-
-
- # positions, box, pairs, Q_local, mScales
+ Q_local = params["Q_local"][map_atomtype]
if self.lpol:
pScales = params["pScales"]
dScales = params["dScales"]
- U_ind = params["U_ind"]
- return pme_force.get_energy(positions, box, pairs, Q_local, pol, tholes, mScales, pScales, dScales, U_init=U_ind)
+ pol = params["pol"][map_atomtype]
+ tholes = params["tholes"][map_atomtype]
+ return pme_force.get_energy(positions, box, pairs, Q_local, pol, tholes, mScales, pScales, dScales, pme_force.U_ind)
else:
return pme_force.get_energy(positions, box, pairs, Q_local, mScales)
@@ -465,6 +420,313 @@ def renderXML(self):
app.forcefield.parsers["ADMPPmeForce"] = ADMPPmeGenerator.parseElement
+class HarmonicBondGenerator:
+ def __init__(self, hamiltonian):
+ self.ff = hamiltonian
+ self.params = {'k': [], 'length': []}
+ self._jaxPotential = None
+ self.types = []
+
+ def registerBondType(self, bond):
+ types = self.ff._findAtomTypes(bond, 2)
+ self.types.append(types)
+ self.params['k'].append(float(bond['k']))
+ self.params['length'].append(float(bond['length']))
+
+ @staticmethod
+ def parseElement(element, hamiltonian):
+ generator = HarmonicBondGenerator(hamiltonian)
+ hamiltonian.registerGenerator(generator)
+ for bondtype in element.findall("Bond"):
+ generator.registerBondType(bondtype.attrib)
+ # jax it!
+ for k in generator.params.keys():
+ generator.params[k] = jnp.array(generator.params[k])
+ generator.types = np.array(generator.types)
+
+ def createForce(self, system, data, nonbondedMethod, nonbondedCutoff,
+ args):
+
+ n_bonds = len(data.bonds)
+ # build map
+ map_atom1 = np.zeros(n_bonds, dtype=int)
+ map_atom2 = np.zeros(n_bonds, dtype=int)
+ map_param = np.zeros(n_bonds, dtype=int)
+ for i in range(n_bonds):
+ idx1 = data.bonds[i].atom1
+ idx2 = data.bonds[i].atom2
+ type1 = data.atomType[data.atoms[idx1]]
+ type2 = data.atomType[data.atoms[idx2]]
+ ifFound = False
+ for ii in range(len(self.types)):
+ if (type1 in self.types[ii][0] and type2 in self.types[ii][1]
+ ) or (type1 in self.types[ii][1]
+ and type2 in self.types[ii][0]):
+ map_atom1[i] = idx1
+ map_atom2[i] = idx2
+ map_param[i] = ii
+ ifFound = True
+ break
+ if not ifFound:
+ raise BaseException("No parameter for bond %i - %i" %
+ (idx1, idx2))
+
+ bforce = HarmonicBondJaxForce(map_atom1, map_atom2, map_param)
+
+ def potential_fn(positions, box, pairs, params):
+ return bforce.get_energy(positions, box, pairs, params["k"],
+ params["length"])
+
+ self._jaxPotential = potential_fn
+ # self._top_data = data
+
+ def getJaxPotential(self):
+ return self._jaxPotential
+
+ def renderXML(self):
+ # generate xml force field file
+ pass
+
+
+# register all parsers
+app.forcefield.parsers[
+ "HarmonicBondForce"] = HarmonicBondGenerator.parseElement
+
+
+class HarmonicAngleGenerator:
+ def __init__(self, hamiltonian):
+ self.ff = hamiltonian
+ self.params = {'k': [], 'theta0': []}
+ self._jaxPotential = None
+ self.types = []
+
+ def registerAngleType(self, angle):
+ types = self.ff._findAtomTypes(angle, 3)
+ self.types.append(types)
+ self.params['k'].append(float(angle['k']))
+ self.params['theta0'].append(float(angle['theta0']))
+
+ @staticmethod
+ def parseElement(element, hamiltonian):
+ generator = HarmonicAngleGenerator(hamiltonian)
+ hamiltonian.registerGenerator(generator)
+ for bondtype in element.findall("Angle"):
+ generator.registerAngleType(bondtype.attrib)
+ # jax it!
+ for k in generator.params.keys():
+ generator.params[k] = jnp.array(generator.params[k])
+ generator.types = np.array(generator.types)
+
+ def createForce(self, system, data, nonbondedMethod, nonbondedCutoff,
+ args):
+
+ n_angles = len(data.angles)
+ # build map
+ map_atom1 = np.zeros(n_angles, dtype=int)
+ map_atom2 = np.zeros(n_angles, dtype=int)
+ map_atom3 = np.zeros(n_angles, dtype=int)
+ map_param = np.zeros(n_angles, dtype=int)
+ for i in range(n_angles):
+ idx1 = data.angles[i].atom1
+ idx2 = data.angles[i].atom2
+ idx3 = data.angles[i].atom3
+ type1 = data.atomType[data.atoms[idx1]]
+ type2 = data.atomType[data.atoms[idx2]]
+ type3 = data.atomType[data.atoms[idx3]]
+ ifFound = False
+ for ii in range(len(self.types)):
+ if type2 in self.types[ii][1]:
+ if (type1 in self.types[ii][0]
+ and type3 in self.types[ii][2]) or (
+ type1 in self.types[ii][2]
+ and type3 in self.types[ii][0]):
+ map_atom1[i] = idx1
+ map_atom2[i] = idx2
+ map_atom3[i] = idx3
+ map_param[i] = ii
+ ifFound = True
+ break
+ if not ifFound:
+ raise BaseException("No parameter for angle %i - %i - %i" %
+ (idx1, idx2, idx3))
+
+ aforce = HarmonicAngleJaxForce(map_atom1, map_atom2, map_atom3,
+ map_param)
+
+ def potential_fn(positions, box, pairs, params):
+ return aforce.get_energy(positions, box, pairs, params["k"],
+ params["theta0"])
+
+ self._jaxPotential = potential_fn
+ # self._top_data = data
+
+ def getJaxPotential(self):
+ return self._jaxPotential
+
+ def renderXML(self):
+ # generate xml force field file
+ pass
+
+
+# register all parsers
+app.forcefield.parsers[
+ "HarmonicAngleForce"] = HarmonicAngleGenerator.parseElement
+
+
+class PeriodicTorsion(object):
+ """A PeriodicTorsion records the information for a periodic torsion definition."""
+ def __init__(self, types):
+ self.types1 = types[0]
+ self.types2 = types[1]
+ self.types3 = types[2]
+ self.types4 = types[3]
+ self.periodicity = []
+ self.phase = []
+ self.k = []
+ self.ordering = 'default'
+
+
+## @private
+class PeriodicTorsionGenerator(object):
+ """A PeriodicTorsionGenerator constructs a PeriodicTorsionForce."""
+ def __init__(self, hamiltonian):
+ self.ff = hamiltonian
+ self.proper = []
+ self.improper = []
+ self.params = {'k': [], 'theta0': []}
+ self.propersForAtomType = defaultdict(set)
+
+ def registerProperTorsion(self, parameters):
+ torsion = self.ff._parseTorsion(parameters)
+ if torsion is not None:
+ index = len(self.proper)
+ self.proper.append(torsion)
+ for t in torsion.types2:
+ self.propersForAtomType[t].add(index)
+ for t in torsion.types3:
+ self.propersForAtomType[t].add(index)
+
+ def registerImproperTorsion(self, parameters, ordering='default'):
+ torsion = self.ff._parseTorsion(parameters)
+ if torsion is not None:
+ if ordering in ['default', 'charmm', 'amber', 'smirnoff']:
+ torsion.ordering = ordering
+ else:
+ raise ValueError(
+ 'Illegal ordering type %s for improper torsion %s' %
+ (ordering, torsion))
+ self.improper.append(torsion)
+
+ @staticmethod
+ def parseElement(element, ff):
+ existing = [
+ f for f in ff._forces if isinstance(f, PeriodicTorsionGenerator)
+ ]
+ if len(existing) == 0:
+ generator = PeriodicTorsionGenerator(ff)
+ ff.registerGenerator(generator)
+ else:
+ generator = existing[0]
+ for torsion in element.findall('Proper'):
+ generator.registerProperTorsion(torsion.attrib)
+ for torsion in element.findall('Improper'):
+ if 'ordering' in element.attrib:
+ generator.registerImproperTorsion(torsion.attrib,
+ element.attrib['ordering'])
+ else:
+ generator.registerImproperTorsion(torsion.attrib)
+ # jax it!
+ for k in generator.params.keys():
+ generator.params[k] = jnp.array(generator.params[k])
+ generator.types = np.array(generator.types)
+
+ def createForce(self, sys, data, nonbondedMethod, nonbondedCutoff, args):
+ wildcard = self.ff._atomClasses['']
+ proper_cache = {}
+ for torsion in data.propers:
+ type1, type2, type3, type4 = [
+ data.atomType[data.atoms[torsion[i]]] for i in range(4)
+ ]
+ sig = (type1, type2, type3, type4)
+ sig = frozenset((sig, sig[::-1]))
+ match = proper_cache.get(sig, None)
+ if match == -1:
+ continue
+ if match is None:
+ for index in self.propersForAtomType[type2]:
+ tordef = self.proper[index]
+ types1 = tordef.types1
+ types2 = tordef.types2
+ types3 = tordef.types3
+ types4 = tordef.types4
+ if (type2 in types2 and type3 in types3 and type4 in types4
+ and type1 in types1) or (type2 in types3
+ and type3 in types2
+ and type4 in types1
+ and type1 in types4):
+ hasWildcard = (wildcard
+ in (types1, types2, types3, types4))
+ if match is None or not hasWildcard: # Prefer specific definitions over ones with wildcards
+ match = tordef
+ if not hasWildcard:
+ break
+ if match is None:
+ proper_cache[sig] = -1
+ else:
+ proper_cache[sig] = match
+ if match is not None:
+ for i in range(len(match.phase)):
+ if match.k[i] != 0:
+ force.addTorsion(torsion[0], torsion[1], torsion[2],
+ torsion[3], match.periodicity[i],
+ match.phase[i], match.k[i])
+ impr_cache = {}
+ for torsion in data.impropers:
+ t1, t2, t3, t4 = [
+ data.atomType[data.atoms[torsion[i]]] for i in range(4)
+ ]
+ sig = (t1, t2, t3, t4)
+ match = impr_cache.get(sig, None)
+ if match == -1:
+ # Previously checked, and doesn't appear in the database
+ continue
+ elif match:
+ i1, i2, i3, i4, tordef = match
+ a1, a2, a3, a4 = (torsion[i] for i in (i1, i2, i3, i4))
+ match = (a1, a2, a3, a4, tordef)
+ if match is None:
+ match = _matchImproper(data, torsion, self)
+ if match is not None:
+ order = match[:4]
+ i1, i2, i3, i4 = tuple(torsion.index(a) for a in order)
+ impr_cache[sig] = (i1, i2, i3, i4, match[-1])
+ else:
+ impr_cache[sig] = -1
+ if match is not None:
+ (a1, a2, a3, a4, tordef) = match
+ for i in range(len(tordef.phase)):
+ if tordef.k[i] != 0:
+ if tordef.ordering == 'smirnoff':
+ # Add all torsions in trefoil
+ force.addTorsion(a1, a2, a3, a4,
+ tordef.periodicity[i],
+ tordef.phase[i], tordef.k[i])
+ force.addTorsion(a1, a3, a4, a2,
+ tordef.periodicity[i],
+ tordef.phase[i], tordef.k[i])
+ force.addTorsion(a1, a4, a2, a3,
+ tordef.periodicity[i],
+ tordef.phase[i], tordef.k[i])
+ else:
+ force.addTorsion(a1, a2, a3, a4,
+ tordef.periodicity[i],
+ tordef.phase[i], tordef.k[i])
+
+
+app.forcefield.parsers[
+ "PeriodicTorsionForce"] = PeriodicTorsionGenerator.parseElement
+
+
class Hamiltonian(app.forcefield.ForceField):
def __init__(self, xmlname):
super().__init__(xmlname)
@@ -476,9 +738,9 @@ def createPotential(
nonbondedMethod=app.NoCutoff,
nonbondedCutoff=1.0 * unit.nanometer,
):
- system = self.createSystem(
- topology, nonbondedMethod=nonbondedMethod, nonbondedCutoff=nonbondedCutoff
- )
+ system = self.createSystem(topology,
+ nonbondedMethod=nonbondedMethod,
+ nonbondedCutoff=nonbondedCutoff)
# load_constraints_from_system_if_needed
# create potentials
for generator in self._forces:
@@ -493,7 +755,8 @@ def createPotential(
app.Topology.loadBondDefinitions("residues.xml")
pdb = app.PDBFile("../water1024.pdb")
rc = 4.0
- potentials = H.createPotential(pdb.topology, nonbondedCutoff=rc * unit.angstrom)
+ potentials = H.createPotential(pdb.topology,
+ nonbondedCutoff=rc * unit.angstrom)
pot_disp = potentials[0]
positions = jnp.array(pdb.positions._value) * 10
@@ -502,13 +765,15 @@ def createPotential(
# neighbor list
displacement_fn, shift_fn = space.periodic_general(
- box, fractional_coordinates=False
- )
- neighbor_list_fn = partition.neighbor_list(
- displacement_fn, box, rc, 0, format=partition.OrderedSparse
- )
+ box, fractional_coordinates=False)
+ neighbor_list_fn = partition.neighbor_list(displacement_fn,
+ box,
+ rc,
+ 0,
+ format=partition.OrderedSparse)
nbr = neighbor_list_fn.allocate(positions)
pairs = nbr.idx.T
- param_grad = grad(pot_disp, argnums=3)(positions, box, pairs, generator.params)
+ param_grad = grad(pot_disp, argnums=3)(positions, box, pairs,
+ generator.params)
print(param_grad)
diff --git a/docs/uder_guide/tutorial.md b/dmff/classical/__init__.py
similarity index 100%
rename from docs/uder_guide/tutorial.md
rename to dmff/classical/__init__.py
diff --git a/dmff/classical/inter.py b/dmff/classical/inter.py
new file mode 100644
index 000000000..12abd91bd
--- /dev/null
+++ b/dmff/classical/inter.py
@@ -0,0 +1,5 @@
+class LennardJonesForce:
+ pass
+
+class CoulombForce:
+ pass
\ No newline at end of file
diff --git a/dmff/classical/intra.py b/dmff/classical/intra.py
new file mode 100644
index 000000000..8c0d187b4
--- /dev/null
+++ b/dmff/classical/intra.py
@@ -0,0 +1,119 @@
+import sys
+import numpy as np
+import jax
+import jax.numpy as jnp
+from jax import grad, value_and_grad, vmap, jit
+from jax.scipy.special import erf
+
+def distance(p1v, p2v):
+ pass
+
+def angle(p1v, p2v, p3v):
+ pass
+
+def dihedral(p1v, p2v, p3v, p4v):
+ pass
+
+class HarmonicBondJaxForce:
+ def __init__(self, p1idx, p2idx, prmidx):
+ self.p1idx = p1idx
+ self.p2idx = p2idx
+ self.prmidx = prmidx
+ self.refresh_calculators()
+
+ def generate_get_energy(self):
+ def get_energy(positions, box, pairs, k, length):
+ p1 = positions[self.p1idx]
+ p2 = positions[self.p2idx]
+ kprm = k[self.prmidx][0]
+ b0prm = length[self.prmidx][1]
+ dist = distance(p1, p2)
+ return jnp.sum(0.5 * kprm * jnp.power(dist - b0prm, 2))
+
+ return get_energy
+
+ def update_env(self, attr, val):
+ '''
+ Update the environment of the calculator
+ '''
+ setattr(self, attr, val)
+ self.refresh_calculators()
+
+ def refresh_calculators(self):
+ '''
+ refresh the energy and force calculators according to the current environment
+ '''
+ self.get_energy = self.generate_get_energy()
+ self.get_forces = value_and_grad(self.get_energy)
+
+
+class HarmonicAngleJaxForce:
+ def __init__(self, p1idx, p2idx, p3idx, prmidx):
+ self.p1idx = p1idx
+ self.p2idx = p2idx
+ self.p3idx = p3idx
+ self.prmidx = prmidx
+ self.refresh_calculators()
+
+ def generate_get_energy(self):
+ def get_energy(positions, box, pairs, k, theta0):
+ p1 = positions[self.p1idx]
+ p2 = positions[self.p2idx]
+ p3 = positions[self.p3idx]
+ kprm = k[self.prmidx][0]
+ theta0prm = theta0[self.prmidx][1]
+ ang = angle(p1, p2, p3)
+ return jnp.sum(0.5 * kprm * jnp.power(ang - theta0prm, 2))
+
+ return get_energy
+
+ def update_env(self, attr, val):
+ '''
+ Update the environment of the calculator
+ '''
+ setattr(self, attr, val)
+ self.refresh_calculators()
+
+ def refresh_calculators(self):
+ '''
+ refresh the energy and force calculators according to the current environment
+ '''
+ self.get_energy = self.generate_get_energy()
+ self.get_forces = value_and_grad(self.get_energy)
+
+
+class PeriodicTorsionJaxForce:
+ def __init__(self, p1idx, p2idx, p3idx, p4idx, prmidx):
+ self.p1idx = p1idx
+ self.p2idx = p2idx
+ self.p3idx = p3idx
+ self.p4idx = p4idx
+ self.prmidx = prmidx
+ self.refresh_calculators()
+
+ def generate_get_energy(self):
+ def get_energy(positions, box, pairs, k, psi0):
+ p1 = positions[self.p1idx]
+ p2 = positions[self.p2idx]
+ p3 = positions[self.p3idx]
+ p4 = positions[self.p4idx]
+ kprm = k[self.prmidx][0]
+ psi0prm = psi0[self.prmidx][1]
+ dih = dihedral(p1, p2, p3, p4)
+ return jnp.sum(0.5 * k * jnp.power(dih - psi0, 2))
+
+ return get_energy
+
+ def update_env(self, attr, val):
+ '''
+ Update the environment of the calculator
+ '''
+ setattr(self, attr, val)
+ self.refresh_calculators()
+
+ def refresh_calculators(self):
+ '''
+ refresh the energy and force calculators according to the current environment
+ '''
+ self.get_energy = self.generate_get_energy()
+ self.get_forces = value_and_grad(self.get_energy)
diff --git a/dmff/settings.py b/dmff/settings.py
index fa2e2bd92..8ba7097c3 100644
--- a/dmff/settings.py
+++ b/dmff/settings.py
@@ -1,8 +1,9 @@
from jax.config import config
-PRECISION = 'double' # 'double'
+PRECISION = 'single' # 'double'
DO_JIT = True
if PRECISION == 'double':
config.update("jax_enable_x64", True)
+
\ No newline at end of file
diff --git a/docs/user_guide/tutorial.md b/docs/user_guide/tutorial.md
new file mode 100644
index 000000000..e69de29bb
diff --git a/examples/openmm_api/forcefield.xml b/examples/openmm_api/forcefield.xml
index f3e9c3f3d..0dc44c980 100644
--- a/examples/openmm_api/forcefield.xml
+++ b/examples/openmm_api/forcefield.xml
@@ -20,7 +20,7 @@
-
+
-
+
diff --git a/examples/openmm_api/run.py b/examples/openmm_api/run.py
index b8565fcf1..affb697df 100755
--- a/examples/openmm_api/run.py
+++ b/examples/openmm_api/run.py
@@ -9,18 +9,20 @@
from dmff.api import Hamiltonian
from jax_md import space, partition
from jax import grad
-
+from time import time
if __name__ == '__main__':
H = Hamiltonian('forcefield.xml')
app.Topology.loadBondDefinitions("residues.xml")
- pdb = app.PDBFile("water1024.pdb")
+ pdb = app.PDBFile("waterbox_31ang.pdb")
rc = 4.0
# generator stores all force field parameters
generator = H.getGenerators()
disp_generator = generator[0]
pme_generator = generator[1]
+
+ pme_generator.lpol = True # debug
pme_generator.ref_dip = 'dipole_1024'
potentials = H.createPotential(pdb.topology, nonbondedCutoff=rc*unit.angstrom)
# pot_fn is the actual energy calculator
@@ -35,14 +37,13 @@
displacement_fn, shift_fn = space.periodic_general(box, fractional_coordinates=False)
neighbor_list_fn = partition.neighbor_list(displacement_fn, box, rc, 0, format=partition.OrderedSparse)
nbr = neighbor_list_fn.allocate(positions)
- pairs = nbr.idx.T
+ pairs = nbr.idx.T
- print(pot_disp(positions, box, pairs, disp_generator.params))
- param_grad = grad(pot_disp, argnums=3)(positions, box, pairs, generator[0].params)
- print(param_grad['mScales'])
+ # print(pot_disp(positions, box, pairs, disp_generator.params))
+ # param_grad = grad(pot_disp, argnums=3)(positions, box, pairs, generator[0].params)
+ # print(param_grad['mScales'])
print(pot_pme(positions, box, pairs, pme_generator.params))
- param_grad = grad(pot_pme, argnums=3)(positions, box, pairs, generator[1].params)
- print(param_grad['mScales'])
-
+ param_grad = grad(pot_pme, argnums=(3))(positions, box, pairs, pme_generator.params)
+ print(param_grad)
diff --git a/examples/water_1024/run_admp.py b/examples/water_1024/run.py
similarity index 100%
rename from examples/water_1024/run_admp.py
rename to examples/water_1024/run.py
diff --git a/examples/water_pol_1024/forcefield.xml b/examples/water_pol_1024/forcefield.xml
new file mode 100644
index 000000000..0dc44c980
--- /dev/null
+++ b/examples/water_pol_1024/forcefield.xml
@@ -0,0 +1,44 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/examples/water_pol_1024/run_admp.py b/examples/water_pol_1024/run.py
similarity index 86%
rename from examples/water_pol_1024/run_admp.py
rename to examples/water_pol_1024/run.py
index 89ebba585..8df0f19c2 100755
--- a/examples/water_pol_1024/run_admp.py
+++ b/examples/water_pol_1024/run.py
@@ -9,6 +9,7 @@
from dmff.admp.multipole import convert_cart2harm
from dmff.admp.pme import ADMPPmeForce
from dmff.admp.parser import *
+from jax import grad
import linecache
@@ -134,11 +135,20 @@ def get_line_context(file_path, line_number):
# electrostatic
pme_force = ADMPPmeForce(box, axis_type, axis_indices, covalent_map, rc, ethresh, lmax, lpol=True)
pme_force.update_env('kappa', 0.657065221219616)
- E, F = pme_force.get_forces(positions, box, pairs, Q_local, pol, tholes, mScales, pScales, dScales)
- print('# Electrostatic Energy (kJ/mol)')
+ pot_pme = pme_force.get_energy
+ jnp.save('mScales', mScales)
+ jnp.save('Q_local', Q_local)
+ jnp.save('pol', pol)
+ jnp.save('tholes', tholes)
+ jnp.save('pScales', pScales)
+ jnp.save('dScales', dScales)
+ jnp.save('U_ind', pme_force.U_ind)
+ # E, F = pme_force.get_forces(positions, box, pairs, Q_local, pol, tholes, mScales, pScales, dScales)
+ # print('# Electrostatic Energy (kJ/mol)')
# E = pme_force.get_energy(positions, box, pairs, Q_local, mScales, pScales, dScales)
- E, F = pme_force.get_forces(positions, box, pairs, Q_local, pol, tholes, mScales, pScales, dScales, U_init=pme_force.U_ind)
- print(E)
+ E = pot_pme(positions, box, pairs, Q_local, pol, tholes, mScales, pScales, dScales, U_init=pme_force.U_ind)
+ grad_params = grad(pot_pme, argnums=(3,4,5,6,7,8,9))(positions, box, pairs, Q_local, pol, tholes, mScales, pScales, dScales, pme_force.U_ind)
+ # print(E)
U_ind = pme_force.U_ind
# compare U_ind with reference
for i in range(1024):