diff --git a/.gitignore b/.gitignore index 7df232a6f..abe26b321 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,10 @@ + +# temporary +err +out +sub.sh +*.npy + ### C++ ### # Prerequisites *.d diff --git a/dmff/api.py b/dmff/api.py index 8d612ed51..9fb6074f3 100644 --- a/dmff/api.py +++ b/dmff/api.py @@ -15,9 +15,8 @@ import linecache - def get_line_context(file_path, line_number): - return linecache.getline(file_path,line_number).strip() + return linecache.getline(file_path, line_number).strip() def build_covalent_map(data, max_neighbor): @@ -30,8 +29,8 @@ def build_covalent_map(data, max_neighbor): for i in range(n_atoms): # current neighbors j_list = np.where( - np.logical_and(covalent_map[i] <= n_curr, covalent_map[i] > 0) - )[0] + np.logical_and(covalent_map[i] <= n_curr, + covalent_map[i] > 0))[0] for j in j_list: k_list = np.where(covalent_map[j] == 1)[0] for k in k_list: @@ -119,10 +118,17 @@ def set_axis_type(map_atomtypes, types, params): class ADMPDispGenerator: def __init__(self, hamiltonian): self.ff = hamiltonian - self.params = {"A": [], "B": [], "Q": [], "C6": [], "C8": [], "C10": []} + self.params = { + "A": [], + "B": [], + "Q": [], + "C6": [], + "C8": [], + "C10": [] + } self._jaxPotential = None self.types = [] - self.ethresh = 1.0e-5 + self.ethresh = 5e-4 self.pmax = 10 def registerAtomType(self, atom): @@ -150,7 +156,8 @@ def parseElement(element, hamiltonian): generator.params[k] = jnp.array(generator.params[k]) generator.types = np.array(generator.types) - def createForce(self, system, data, nonbondedMethod, nonbondedCutoff, args): + def createForce(self, system, data, nonbondedMethod, nonbondedCutoff, + args): n_atoms = len(data.atoms) # build index map @@ -168,22 +175,22 @@ def createForce(self, system, data, nonbondedMethod, nonbondedCutoff, args): rc = nonbondedCutoff.value_in_unit(unit.angstrom) # get calculator - Force_DispPME = ADMPDispPmeForce(box, covalent_map, rc, self.ethresh, self.pmax) + Force_DispPME = ADMPDispPmeForce(box, covalent_map, rc, self.ethresh, + self.pmax) # debugging # Force_DispPME.update_env('kappa', 0.657065221219616) # Force_DispPME.update_env('K1', 96) # Force_DispPME.update_env('K2', 96) # Force_DispPME.update_env('K3', 96) pot_fn_lr = Force_DispPME.get_energy - pot_fn_sr = generate_pairwise_interaction( - TT_damping_qq_c6_kernel, covalent_map, static_args={} - ) + pot_fn_sr = generate_pairwise_interaction(TT_damping_qq_c6_kernel, + covalent_map, + static_args={}) def potential_fn(positions, box, pairs, params): mScales = params["mScales"] - a_list = ( - params["A"][map_atomtype] / 2625.5 - ) # kj/mol to au, as expected by TT_damping kernel + a_list = (params["A"][map_atomtype] / 2625.5 + ) # kj/mol to au, as expected by TT_damping kernel b_list = params["B"][map_atomtype] * 0.0529177249 # nm^-1 to au q_list = params["Q"][map_atomtype] c6_list = jnp.sqrt(params["C6"][map_atomtype] * 1e6) @@ -191,9 +198,8 @@ def potential_fn(positions, box, pairs, params): c10_list = jnp.sqrt(params["C10"][map_atomtype] * 1e10) c_list = jnp.vstack((c6_list, c8_list, c10_list)) - E_sr = pot_fn_sr( - positions, box, pairs, mScales, a_list, b_list, q_list, c_list[0] - ) + E_sr = pot_fn_sr(positions, box, pairs, mScales, a_list, b_list, + q_list, c_list[0]) E_lr = pot_fn_lr(positions, box, pairs, c_list.T, mScales) return E_sr - E_lr @@ -244,19 +250,24 @@ def __init__(self, hamiltonian): "thole": [], "polarizabilityXX": [], "polarizabilityYY": [], - "polarizabilityZZ": [] + "polarizabilityZZ": [], + + } + self.params = { + "mScales": [], + "pScales": [], + "dScales": [], } # if more or optional input params - # self._input_params = defaultDict(list) + # self._input_params = defaultDict(list) self._jaxPotential = None self.types = [] - self.ethresh = 1.0e-5 - self.params = {} + self.ethresh = 5e-4 self.lpol = False self.ref_dip = '' - def registerAtomType(self, atom:dict): - + def registerAtomType(self, atom: dict): + self.types.append(atom.pop("type")) kStrings = ["kz", "kx", "ky"] @@ -273,29 +284,29 @@ def registerAtomType(self, atom:dict): def parseElement(element, hamiltonian): generator = ADMPPmeGenerator(hamiltonian) generator.lmax = int(element.attrib.get('lmax')) - generator.pmax = int(element.attrib.get('pmax')) - + generator.defaultTholeWidth = 5 + hamiltonian.registerGenerator(generator) - mScales = [] - pScales = [] - dScales = [] for i in range(2, 7): - mScales.append(float(element.attrib["mScale1%d" % i])) - pScales.append(float(element.attrib["pScale1%d" % i])) - dScales.append(float(element.attrib["dScale1%d" % i])) - generator.params["mScales"] = jnp.array(mScales) - generator.params["pScales"] = jnp.array(pScales) - generator.params["dScales"] = jnp.array(dScales) - + generator.params["mScales"].append( + float(element.attrib["mScale1%d" % i])) + generator.params["pScales"].append( + float(element.attrib["pScale1%d" % i])) + generator.params["dScales"].append( + float(element.attrib["dScale1%d" % i])) + if element.findall('Polarize'): generator.lpol = True for atomType in element.findall("Atom"): atomAttrib = atomType.attrib + # if not set + atomAttrib.update({'polarizabilityXX': 0, 'polarizabilityYY': 0, 'polarizabilityZZ': 0}) for polarInfo in element.findall("Polarize"): polarAttrib = polarInfo.attrib if polarInfo.attrib['type'] == atomAttrib['type']: + # cover default atomAttrib.update(polarAttrib) break generator.registerAtomType(atomAttrib) @@ -303,87 +314,50 @@ def parseElement(element, hamiltonian): for k in generator._input_params.keys(): generator._input_params[k] = jnp.array(generator._input_params[k]) generator.types = np.array(generator.types) - - def createForce(self, system, data, nonbondedMethod, nonbondedCutoff, args): - - n_atoms = len(data.atoms) - # build index map - map_atomtype = np.zeros(n_atoms, dtype=int) - for i in range(n_atoms): - atype = data.atomType[data.atoms[i]] - map_atomtype[i] = np.where(self.types == atype)[0][0] + + n_atoms = len(element.findall('Atom')) # map atom multipole moments - p = self._input_params Q = np.zeros((n_atoms, 10)) - Q[:, 0] = p["c0"][map_atomtype] - Q[:, 1] = p["dX"][map_atomtype] * 10 - Q[:, 2] = p["dY"][map_atomtype] * 10 - Q[:, 3] = p["dZ"][map_atomtype] * 10 - Q[:, 4] = p["qXX"][map_atomtype] * 300 - Q[:, 5] = p["qYY"][map_atomtype] * 300 - Q[:, 6] = p["qZZ"][map_atomtype] * 300 - Q[:, 7] = p["qXY"][map_atomtype] * 300 - Q[:, 8] = p["qXZ"][map_atomtype] * 300 - Q[:, 9] = p["qYZ"][map_atomtype] * 300 + Q[:, 0] = generator._input_params["c0"] + Q[:, 1] = generator._input_params["dX"] * 10 + Q[:, 2] = generator._input_params["dY"] * 10 + Q[:, 3] = generator._input_params["dZ"] * 10 + Q[:, 4] = generator._input_params["qXX"] * 300 + Q[:, 5] = generator._input_params["qYY"] * 300 + Q[:, 6] = generator._input_params["qZZ"] * 300 + Q[:, 7] = generator._input_params["qXY"] * 300 + Q[:, 8] = generator._input_params["qXZ"] * 300 + Q[:, 9] = generator._input_params["qYZ"] * 300 + + # add all differentiable params to self.params + Q_local = convert_cart2harm(Q, 2) + generator.params['Q_local'] = Q_local - # map polarization-related params - pol = jnp.vstack((p['polarizabilityXX'][map_atomtype], p['polarizabilityYY'][map_atomtype], p['polarizabilityZZ'][map_atomtype])).T.astype(jnp.float32) + pol = jnp.vstack((generator._input_params['polarizabilityXX'], generator._input_params['polarizabilityYY'], generator._input_params['polarizabilityZZ'])).T pol = 1000*jnp.mean(pol,axis=1) - self.params['pol'] = pol + + tholes = jnp.array(generator._input_params['thole']) + # tholes = jnp.mean(jnp.atleast_2d(tholes), axis=1) - tholes = jnp.array(p['thole'][map_atomtype]).astype(jnp.float32) - tholes = jnp.mean(jnp.atleast_2d(tholes), axis=1) - self.params['tholes'] = tholes + generator.params['pol'] = pol + generator.params['tholes'] = tholes + # generator.params[''] + for k in generator.params.keys(): + generator.params[k] = jnp.array(generator.params[k]) - # defaultTholeWidth = 8 - Uind_global = jnp.zeros([n_atoms,3]) - ref_dip = self.ref_dip + + def createForce(self, system, data, nonbondedMethod, nonbondedCutoff, + args): + + n_atoms = len(data.atoms) + # build index map + map_atomtype = np.zeros(n_atoms, dtype=int) + for i in range(n_atoms): - a = get_line_context(ref_dip, i+1) - b = a.split() - t = np.array([10*float(b[0]),10*float(b[1]),10*float(b[2])]) - Uind_global = Uind_global.at[i].set(t) - - # construct the C list - c_list = np.zeros((3, n_atoms)) - a_list = np.zeros(n_atoms) - q_list = np.zeros(n_atoms) - b_list = np.zeros(n_atoms) - - - nmol=int(n_atoms/3) # WARNING: HARD CODE! - for i in range(nmol): - a = i*3 - b = i*3+1 - c = i*3+2 - # dispersion coeff - c_list[0][a]=37.19677405 - c_list[0][b]=7.6111103 - c_list[0][c]=7.6111103 - c_list[1][a]=85.26810658 - c_list[1][b]=11.90220148 - c_list[1][c]=11.90220148 - c_list[2][a]=134.44874488 - c_list[2][b]=15.05074749 - c_list[2][c]=15.05074749 - # q - q_list[a] = -0.741706 - q_list[b] = 0.370853 - q_list[c] = 0.370853 - # b, Bohr^-1 - b_list[a] = 2.00095977 - b_list[b] = 1.999519942 - b_list[c] = 1.999519942 - # a, Hartree - a_list[a] = 458.3777 - a_list[b] = 0.0317 - a_list[c] = 0.0317 - - # add all differentiable params to self.params - Q = jnp.array(Q) - Q_local = convert_cart2harm(Q, 2) - self.params["Q_local"] = Q_local + atype = data.atomType[data.atoms[i]] + map_atomtype[i] = np.where(self.types == atype)[0][0] + # here box is only used to setup ewald parameters, no need to be differentiable a, b, c = system.getDefaultPeriodicBoxVectors() box = jnp.array([a._value, b._value, c._value]) * 10 @@ -396,15 +370,15 @@ def createForce(self, system, data, nonbondedMethod, nonbondedCutoff, args): # build intra-molecule axis self.axis_types, self.axis_indices = set_axis_type( - map_atomtype, self.types, self.kStrings - ) + map_atomtype, self.types, self.kStrings) map_axis_indices = [] # map axis_indices for i in range(n_atoms): catom = data.atoms[i] residue = catom.residue._atoms atom_indices = [ - index if index != "" else -1 for index in self.axis_indices[i][1:] + index if index != "" else -1 + for index in self.axis_indices[i][1:] ] for atom in residue: if atom == catom: @@ -416,40 +390,21 @@ def createForce(self, system, data, nonbondedMethod, nonbondedCutoff, args): map_axis_indices.append(atom_indices) self.axis_indices = np.array(map_axis_indices) - - # Finish data preparation - # ------------------------------------------------------------------------------------- - # parameters should be ready: - # geometric variables: positions, box - # atomic parameters: Q_local, c_list - # topological parameters: covalent_map, mScales, pScales, dScales - # general force field setting parameters: rc, ethresh, lmax, pmax - - pme_force = ADMPPmeForce( - box, - self.axis_types, - self.axis_indices, - covalent_map, - rc, - self.ethresh, - self.lmax, - self.lpol - ) - self.params['U_ind'] = pme_force.U_ind + pme_force = ADMPPmeForce(box, self.axis_types, self.axis_indices, + covalent_map, rc, self.ethresh, self.lmax, + self.lpol) def potential_fn(positions, box, pairs, params): mScales = params["mScales"] - Q_local = params["Q_local"] - - - # positions, box, pairs, Q_local, mScales + Q_local = params["Q_local"][map_atomtype] if self.lpol: pScales = params["pScales"] dScales = params["dScales"] - U_ind = params["U_ind"] - return pme_force.get_energy(positions, box, pairs, Q_local, pol, tholes, mScales, pScales, dScales, U_init=U_ind) + pol = params["pol"][map_atomtype] + tholes = params["tholes"][map_atomtype] + return pme_force.get_energy(positions, box, pairs, Q_local, pol, tholes, mScales, pScales, dScales, pme_force.U_ind) else: return pme_force.get_energy(positions, box, pairs, Q_local, mScales) @@ -465,6 +420,313 @@ def renderXML(self): app.forcefield.parsers["ADMPPmeForce"] = ADMPPmeGenerator.parseElement +class HarmonicBondGenerator: + def __init__(self, hamiltonian): + self.ff = hamiltonian + self.params = {'k': [], 'length': []} + self._jaxPotential = None + self.types = [] + + def registerBondType(self, bond): + types = self.ff._findAtomTypes(bond, 2) + self.types.append(types) + self.params['k'].append(float(bond['k'])) + self.params['length'].append(float(bond['length'])) + + @staticmethod + def parseElement(element, hamiltonian): + generator = HarmonicBondGenerator(hamiltonian) + hamiltonian.registerGenerator(generator) + for bondtype in element.findall("Bond"): + generator.registerBondType(bondtype.attrib) + # jax it! + for k in generator.params.keys(): + generator.params[k] = jnp.array(generator.params[k]) + generator.types = np.array(generator.types) + + def createForce(self, system, data, nonbondedMethod, nonbondedCutoff, + args): + + n_bonds = len(data.bonds) + # build map + map_atom1 = np.zeros(n_bonds, dtype=int) + map_atom2 = np.zeros(n_bonds, dtype=int) + map_param = np.zeros(n_bonds, dtype=int) + for i in range(n_bonds): + idx1 = data.bonds[i].atom1 + idx2 = data.bonds[i].atom2 + type1 = data.atomType[data.atoms[idx1]] + type2 = data.atomType[data.atoms[idx2]] + ifFound = False + for ii in range(len(self.types)): + if (type1 in self.types[ii][0] and type2 in self.types[ii][1] + ) or (type1 in self.types[ii][1] + and type2 in self.types[ii][0]): + map_atom1[i] = idx1 + map_atom2[i] = idx2 + map_param[i] = ii + ifFound = True + break + if not ifFound: + raise BaseException("No parameter for bond %i - %i" % + (idx1, idx2)) + + bforce = HarmonicBondJaxForce(map_atom1, map_atom2, map_param) + + def potential_fn(positions, box, pairs, params): + return bforce.get_energy(positions, box, pairs, params["k"], + params["length"]) + + self._jaxPotential = potential_fn + # self._top_data = data + + def getJaxPotential(self): + return self._jaxPotential + + def renderXML(self): + # generate xml force field file + pass + + +# register all parsers +app.forcefield.parsers[ + "HarmonicBondForce"] = HarmonicBondGenerator.parseElement + + +class HarmonicAngleGenerator: + def __init__(self, hamiltonian): + self.ff = hamiltonian + self.params = {'k': [], 'theta0': []} + self._jaxPotential = None + self.types = [] + + def registerAngleType(self, angle): + types = self.ff._findAtomTypes(angle, 3) + self.types.append(types) + self.params['k'].append(float(angle['k'])) + self.params['theta0'].append(float(angle['theta0'])) + + @staticmethod + def parseElement(element, hamiltonian): + generator = HarmonicAngleGenerator(hamiltonian) + hamiltonian.registerGenerator(generator) + for bondtype in element.findall("Angle"): + generator.registerAngleType(bondtype.attrib) + # jax it! + for k in generator.params.keys(): + generator.params[k] = jnp.array(generator.params[k]) + generator.types = np.array(generator.types) + + def createForce(self, system, data, nonbondedMethod, nonbondedCutoff, + args): + + n_angles = len(data.angles) + # build map + map_atom1 = np.zeros(n_angles, dtype=int) + map_atom2 = np.zeros(n_angles, dtype=int) + map_atom3 = np.zeros(n_angles, dtype=int) + map_param = np.zeros(n_angles, dtype=int) + for i in range(n_angles): + idx1 = data.angles[i].atom1 + idx2 = data.angles[i].atom2 + idx3 = data.angles[i].atom3 + type1 = data.atomType[data.atoms[idx1]] + type2 = data.atomType[data.atoms[idx2]] + type3 = data.atomType[data.atoms[idx3]] + ifFound = False + for ii in range(len(self.types)): + if type2 in self.types[ii][1]: + if (type1 in self.types[ii][0] + and type3 in self.types[ii][2]) or ( + type1 in self.types[ii][2] + and type3 in self.types[ii][0]): + map_atom1[i] = idx1 + map_atom2[i] = idx2 + map_atom3[i] = idx3 + map_param[i] = ii + ifFound = True + break + if not ifFound: + raise BaseException("No parameter for angle %i - %i - %i" % + (idx1, idx2, idx3)) + + aforce = HarmonicAngleJaxForce(map_atom1, map_atom2, map_atom3, + map_param) + + def potential_fn(positions, box, pairs, params): + return aforce.get_energy(positions, box, pairs, params["k"], + params["theta0"]) + + self._jaxPotential = potential_fn + # self._top_data = data + + def getJaxPotential(self): + return self._jaxPotential + + def renderXML(self): + # generate xml force field file + pass + + +# register all parsers +app.forcefield.parsers[ + "HarmonicAngleForce"] = HarmonicAngleGenerator.parseElement + + +class PeriodicTorsion(object): + """A PeriodicTorsion records the information for a periodic torsion definition.""" + def __init__(self, types): + self.types1 = types[0] + self.types2 = types[1] + self.types3 = types[2] + self.types4 = types[3] + self.periodicity = [] + self.phase = [] + self.k = [] + self.ordering = 'default' + + +## @private +class PeriodicTorsionGenerator(object): + """A PeriodicTorsionGenerator constructs a PeriodicTorsionForce.""" + def __init__(self, hamiltonian): + self.ff = hamiltonian + self.proper = [] + self.improper = [] + self.params = {'k': [], 'theta0': []} + self.propersForAtomType = defaultdict(set) + + def registerProperTorsion(self, parameters): + torsion = self.ff._parseTorsion(parameters) + if torsion is not None: + index = len(self.proper) + self.proper.append(torsion) + for t in torsion.types2: + self.propersForAtomType[t].add(index) + for t in torsion.types3: + self.propersForAtomType[t].add(index) + + def registerImproperTorsion(self, parameters, ordering='default'): + torsion = self.ff._parseTorsion(parameters) + if torsion is not None: + if ordering in ['default', 'charmm', 'amber', 'smirnoff']: + torsion.ordering = ordering + else: + raise ValueError( + 'Illegal ordering type %s for improper torsion %s' % + (ordering, torsion)) + self.improper.append(torsion) + + @staticmethod + def parseElement(element, ff): + existing = [ + f for f in ff._forces if isinstance(f, PeriodicTorsionGenerator) + ] + if len(existing) == 0: + generator = PeriodicTorsionGenerator(ff) + ff.registerGenerator(generator) + else: + generator = existing[0] + for torsion in element.findall('Proper'): + generator.registerProperTorsion(torsion.attrib) + for torsion in element.findall('Improper'): + if 'ordering' in element.attrib: + generator.registerImproperTorsion(torsion.attrib, + element.attrib['ordering']) + else: + generator.registerImproperTorsion(torsion.attrib) + # jax it! + for k in generator.params.keys(): + generator.params[k] = jnp.array(generator.params[k]) + generator.types = np.array(generator.types) + + def createForce(self, sys, data, nonbondedMethod, nonbondedCutoff, args): + wildcard = self.ff._atomClasses[''] + proper_cache = {} + for torsion in data.propers: + type1, type2, type3, type4 = [ + data.atomType[data.atoms[torsion[i]]] for i in range(4) + ] + sig = (type1, type2, type3, type4) + sig = frozenset((sig, sig[::-1])) + match = proper_cache.get(sig, None) + if match == -1: + continue + if match is None: + for index in self.propersForAtomType[type2]: + tordef = self.proper[index] + types1 = tordef.types1 + types2 = tordef.types2 + types3 = tordef.types3 + types4 = tordef.types4 + if (type2 in types2 and type3 in types3 and type4 in types4 + and type1 in types1) or (type2 in types3 + and type3 in types2 + and type4 in types1 + and type1 in types4): + hasWildcard = (wildcard + in (types1, types2, types3, types4)) + if match is None or not hasWildcard: # Prefer specific definitions over ones with wildcards + match = tordef + if not hasWildcard: + break + if match is None: + proper_cache[sig] = -1 + else: + proper_cache[sig] = match + if match is not None: + for i in range(len(match.phase)): + if match.k[i] != 0: + force.addTorsion(torsion[0], torsion[1], torsion[2], + torsion[3], match.periodicity[i], + match.phase[i], match.k[i]) + impr_cache = {} + for torsion in data.impropers: + t1, t2, t3, t4 = [ + data.atomType[data.atoms[torsion[i]]] for i in range(4) + ] + sig = (t1, t2, t3, t4) + match = impr_cache.get(sig, None) + if match == -1: + # Previously checked, and doesn't appear in the database + continue + elif match: + i1, i2, i3, i4, tordef = match + a1, a2, a3, a4 = (torsion[i] for i in (i1, i2, i3, i4)) + match = (a1, a2, a3, a4, tordef) + if match is None: + match = _matchImproper(data, torsion, self) + if match is not None: + order = match[:4] + i1, i2, i3, i4 = tuple(torsion.index(a) for a in order) + impr_cache[sig] = (i1, i2, i3, i4, match[-1]) + else: + impr_cache[sig] = -1 + if match is not None: + (a1, a2, a3, a4, tordef) = match + for i in range(len(tordef.phase)): + if tordef.k[i] != 0: + if tordef.ordering == 'smirnoff': + # Add all torsions in trefoil + force.addTorsion(a1, a2, a3, a4, + tordef.periodicity[i], + tordef.phase[i], tordef.k[i]) + force.addTorsion(a1, a3, a4, a2, + tordef.periodicity[i], + tordef.phase[i], tordef.k[i]) + force.addTorsion(a1, a4, a2, a3, + tordef.periodicity[i], + tordef.phase[i], tordef.k[i]) + else: + force.addTorsion(a1, a2, a3, a4, + tordef.periodicity[i], + tordef.phase[i], tordef.k[i]) + + +app.forcefield.parsers[ + "PeriodicTorsionForce"] = PeriodicTorsionGenerator.parseElement + + class Hamiltonian(app.forcefield.ForceField): def __init__(self, xmlname): super().__init__(xmlname) @@ -476,9 +738,9 @@ def createPotential( nonbondedMethod=app.NoCutoff, nonbondedCutoff=1.0 * unit.nanometer, ): - system = self.createSystem( - topology, nonbondedMethod=nonbondedMethod, nonbondedCutoff=nonbondedCutoff - ) + system = self.createSystem(topology, + nonbondedMethod=nonbondedMethod, + nonbondedCutoff=nonbondedCutoff) # load_constraints_from_system_if_needed # create potentials for generator in self._forces: @@ -493,7 +755,8 @@ def createPotential( app.Topology.loadBondDefinitions("residues.xml") pdb = app.PDBFile("../water1024.pdb") rc = 4.0 - potentials = H.createPotential(pdb.topology, nonbondedCutoff=rc * unit.angstrom) + potentials = H.createPotential(pdb.topology, + nonbondedCutoff=rc * unit.angstrom) pot_disp = potentials[0] positions = jnp.array(pdb.positions._value) * 10 @@ -502,13 +765,15 @@ def createPotential( # neighbor list displacement_fn, shift_fn = space.periodic_general( - box, fractional_coordinates=False - ) - neighbor_list_fn = partition.neighbor_list( - displacement_fn, box, rc, 0, format=partition.OrderedSparse - ) + box, fractional_coordinates=False) + neighbor_list_fn = partition.neighbor_list(displacement_fn, + box, + rc, + 0, + format=partition.OrderedSparse) nbr = neighbor_list_fn.allocate(positions) pairs = nbr.idx.T - param_grad = grad(pot_disp, argnums=3)(positions, box, pairs, generator.params) + param_grad = grad(pot_disp, argnums=3)(positions, box, pairs, + generator.params) print(param_grad) diff --git a/docs/uder_guide/tutorial.md b/dmff/classical/__init__.py similarity index 100% rename from docs/uder_guide/tutorial.md rename to dmff/classical/__init__.py diff --git a/dmff/classical/inter.py b/dmff/classical/inter.py new file mode 100644 index 000000000..12abd91bd --- /dev/null +++ b/dmff/classical/inter.py @@ -0,0 +1,5 @@ +class LennardJonesForce: + pass + +class CoulombForce: + pass \ No newline at end of file diff --git a/dmff/classical/intra.py b/dmff/classical/intra.py new file mode 100644 index 000000000..8c0d187b4 --- /dev/null +++ b/dmff/classical/intra.py @@ -0,0 +1,119 @@ +import sys +import numpy as np +import jax +import jax.numpy as jnp +from jax import grad, value_and_grad, vmap, jit +from jax.scipy.special import erf + +def distance(p1v, p2v): + pass + +def angle(p1v, p2v, p3v): + pass + +def dihedral(p1v, p2v, p3v, p4v): + pass + +class HarmonicBondJaxForce: + def __init__(self, p1idx, p2idx, prmidx): + self.p1idx = p1idx + self.p2idx = p2idx + self.prmidx = prmidx + self.refresh_calculators() + + def generate_get_energy(self): + def get_energy(positions, box, pairs, k, length): + p1 = positions[self.p1idx] + p2 = positions[self.p2idx] + kprm = k[self.prmidx][0] + b0prm = length[self.prmidx][1] + dist = distance(p1, p2) + return jnp.sum(0.5 * kprm * jnp.power(dist - b0prm, 2)) + + return get_energy + + def update_env(self, attr, val): + ''' + Update the environment of the calculator + ''' + setattr(self, attr, val) + self.refresh_calculators() + + def refresh_calculators(self): + ''' + refresh the energy and force calculators according to the current environment + ''' + self.get_energy = self.generate_get_energy() + self.get_forces = value_and_grad(self.get_energy) + + +class HarmonicAngleJaxForce: + def __init__(self, p1idx, p2idx, p3idx, prmidx): + self.p1idx = p1idx + self.p2idx = p2idx + self.p3idx = p3idx + self.prmidx = prmidx + self.refresh_calculators() + + def generate_get_energy(self): + def get_energy(positions, box, pairs, k, theta0): + p1 = positions[self.p1idx] + p2 = positions[self.p2idx] + p3 = positions[self.p3idx] + kprm = k[self.prmidx][0] + theta0prm = theta0[self.prmidx][1] + ang = angle(p1, p2, p3) + return jnp.sum(0.5 * kprm * jnp.power(ang - theta0prm, 2)) + + return get_energy + + def update_env(self, attr, val): + ''' + Update the environment of the calculator + ''' + setattr(self, attr, val) + self.refresh_calculators() + + def refresh_calculators(self): + ''' + refresh the energy and force calculators according to the current environment + ''' + self.get_energy = self.generate_get_energy() + self.get_forces = value_and_grad(self.get_energy) + + +class PeriodicTorsionJaxForce: + def __init__(self, p1idx, p2idx, p3idx, p4idx, prmidx): + self.p1idx = p1idx + self.p2idx = p2idx + self.p3idx = p3idx + self.p4idx = p4idx + self.prmidx = prmidx + self.refresh_calculators() + + def generate_get_energy(self): + def get_energy(positions, box, pairs, k, psi0): + p1 = positions[self.p1idx] + p2 = positions[self.p2idx] + p3 = positions[self.p3idx] + p4 = positions[self.p4idx] + kprm = k[self.prmidx][0] + psi0prm = psi0[self.prmidx][1] + dih = dihedral(p1, p2, p3, p4) + return jnp.sum(0.5 * k * jnp.power(dih - psi0, 2)) + + return get_energy + + def update_env(self, attr, val): + ''' + Update the environment of the calculator + ''' + setattr(self, attr, val) + self.refresh_calculators() + + def refresh_calculators(self): + ''' + refresh the energy and force calculators according to the current environment + ''' + self.get_energy = self.generate_get_energy() + self.get_forces = value_and_grad(self.get_energy) diff --git a/dmff/settings.py b/dmff/settings.py index fa2e2bd92..8ba7097c3 100644 --- a/dmff/settings.py +++ b/dmff/settings.py @@ -1,8 +1,9 @@ from jax.config import config -PRECISION = 'double' # 'double' +PRECISION = 'single' # 'double' DO_JIT = True if PRECISION == 'double': config.update("jax_enable_x64", True) + \ No newline at end of file diff --git a/docs/user_guide/tutorial.md b/docs/user_guide/tutorial.md new file mode 100644 index 000000000..e69de29bb diff --git a/examples/openmm_api/forcefield.xml b/examples/openmm_api/forcefield.xml index f3e9c3f3d..0dc44c980 100644 --- a/examples/openmm_api/forcefield.xml +++ b/examples/openmm_api/forcefield.xml @@ -20,7 +20,7 @@ - + - + diff --git a/examples/openmm_api/run.py b/examples/openmm_api/run.py index b8565fcf1..affb697df 100755 --- a/examples/openmm_api/run.py +++ b/examples/openmm_api/run.py @@ -9,18 +9,20 @@ from dmff.api import Hamiltonian from jax_md import space, partition from jax import grad - +from time import time if __name__ == '__main__': H = Hamiltonian('forcefield.xml') app.Topology.loadBondDefinitions("residues.xml") - pdb = app.PDBFile("water1024.pdb") + pdb = app.PDBFile("waterbox_31ang.pdb") rc = 4.0 # generator stores all force field parameters generator = H.getGenerators() disp_generator = generator[0] pme_generator = generator[1] + + pme_generator.lpol = True # debug pme_generator.ref_dip = 'dipole_1024' potentials = H.createPotential(pdb.topology, nonbondedCutoff=rc*unit.angstrom) # pot_fn is the actual energy calculator @@ -35,14 +37,13 @@ displacement_fn, shift_fn = space.periodic_general(box, fractional_coordinates=False) neighbor_list_fn = partition.neighbor_list(displacement_fn, box, rc, 0, format=partition.OrderedSparse) nbr = neighbor_list_fn.allocate(positions) - pairs = nbr.idx.T + pairs = nbr.idx.T - print(pot_disp(positions, box, pairs, disp_generator.params)) - param_grad = grad(pot_disp, argnums=3)(positions, box, pairs, generator[0].params) - print(param_grad['mScales']) + # print(pot_disp(positions, box, pairs, disp_generator.params)) + # param_grad = grad(pot_disp, argnums=3)(positions, box, pairs, generator[0].params) + # print(param_grad['mScales']) print(pot_pme(positions, box, pairs, pme_generator.params)) - param_grad = grad(pot_pme, argnums=3)(positions, box, pairs, generator[1].params) - print(param_grad['mScales']) - + param_grad = grad(pot_pme, argnums=(3))(positions, box, pairs, pme_generator.params) + print(param_grad) diff --git a/examples/water_1024/run_admp.py b/examples/water_1024/run.py similarity index 100% rename from examples/water_1024/run_admp.py rename to examples/water_1024/run.py diff --git a/examples/water_pol_1024/forcefield.xml b/examples/water_pol_1024/forcefield.xml new file mode 100644 index 000000000..0dc44c980 --- /dev/null +++ b/examples/water_pol_1024/forcefield.xml @@ -0,0 +1,44 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/examples/water_pol_1024/run_admp.py b/examples/water_pol_1024/run.py similarity index 86% rename from examples/water_pol_1024/run_admp.py rename to examples/water_pol_1024/run.py index 89ebba585..8df0f19c2 100755 --- a/examples/water_pol_1024/run_admp.py +++ b/examples/water_pol_1024/run.py @@ -9,6 +9,7 @@ from dmff.admp.multipole import convert_cart2harm from dmff.admp.pme import ADMPPmeForce from dmff.admp.parser import * +from jax import grad import linecache @@ -134,11 +135,20 @@ def get_line_context(file_path, line_number): # electrostatic pme_force = ADMPPmeForce(box, axis_type, axis_indices, covalent_map, rc, ethresh, lmax, lpol=True) pme_force.update_env('kappa', 0.657065221219616) - E, F = pme_force.get_forces(positions, box, pairs, Q_local, pol, tholes, mScales, pScales, dScales) - print('# Electrostatic Energy (kJ/mol)') + pot_pme = pme_force.get_energy + jnp.save('mScales', mScales) + jnp.save('Q_local', Q_local) + jnp.save('pol', pol) + jnp.save('tholes', tholes) + jnp.save('pScales', pScales) + jnp.save('dScales', dScales) + jnp.save('U_ind', pme_force.U_ind) + # E, F = pme_force.get_forces(positions, box, pairs, Q_local, pol, tholes, mScales, pScales, dScales) + # print('# Electrostatic Energy (kJ/mol)') # E = pme_force.get_energy(positions, box, pairs, Q_local, mScales, pScales, dScales) - E, F = pme_force.get_forces(positions, box, pairs, Q_local, pol, tholes, mScales, pScales, dScales, U_init=pme_force.U_ind) - print(E) + E = pot_pme(positions, box, pairs, Q_local, pol, tholes, mScales, pScales, dScales, U_init=pme_force.U_ind) + grad_params = grad(pot_pme, argnums=(3,4,5,6,7,8,9))(positions, box, pairs, Q_local, pol, tholes, mScales, pScales, dScales, pme_force.U_ind) + # print(E) U_ind = pme_force.U_ind # compare U_ind with reference for i in range(1024):