diff --git a/.github/workflows/ut.yml b/.github/workflows/ut.yml index caa0c77c4..0371fe8db 100644 --- a/.github/workflows/ut.yml +++ b/.github/workflows/ut.yml @@ -22,7 +22,7 @@ jobs: $CONDA/bin/conda update -n base -c defaults conda conda install pip conda update pip - conda install numpy openmm pytest -c conda-forge + conda install numpy openmm pytest rdkit biopandas openbabel -c conda-forge pip install jax jax_md pip install mdtraj==1.9.7 pymbar==4.0.1 - name: Install DMFF diff --git a/README.md b/README.md index 4ad11df2c..1b530e68b 100644 --- a/README.md +++ b/README.md @@ -1,9 +1,20 @@ # DMFF +[](https://doi.org/10.26434/chemrxiv-2022-2c7gv) + +## About DMFF + **DMFF** (**D**ifferentiable **M**olecular **F**orce **F**ield) is a Jax-based python package that provides a full differentiable implementation of molecular force field models. This project aims to establish an extensible codebase to minimize the efforts in force field parameterization, and to ease the force and virial tensor evaluations for advanced complicated potentials (e.g., polarizable models with geometry-dependent atomic parameters). Currently, this project mainly focuses on the molecular systems such as: water, biological macromolecules (peptides, proteins, nucleic acids), organic polymers, and small organic molecules (organic electrolyte, drug-like molecules) etc. We support both the conventional point charge models (OPLS and AMBER like) and multipolar polarizable models (AMOEBA and MPID like). The entire project is backed by the XLA technique in JAX, thus can be "jitted" and run in GPU devices much more efficiently compared to normal python codes. The behavior of organic molecular systems (e.g., protein folding, polymer structure, etc.) is often determined by a complex effect of many different types of interactions. The existing organic molecular force fields are mainly empirically fitted and their performance relies heavily on error cancellation. Therefore, the transferability and the prediction power of these force fields are insufficient. For new molecules, the parameter fitting process requires essential manual intervention and can be quite cumbersome. In order to automate the parametrization process and increase the robustness of the model, it is necessary to apply modern AI techniques in conventional force field development. This project serves for this purpose by utilizing the automatic differentiable programming technique to develop a codebase, which allows a more convenient incorporation of modern AI optimization techniques. It also helps the realization of many exciting functions including (but not limited to): hybrid machine learning/force field models and parameter optimization based on trajectory. +### License and credits + +The project DMFF is licensed under [GNU LGPL v3.0](LICENSE). If you use this code in any future publications, please cite this using `Wang X, Li J, Yang L, Chen F, Wang Y, Chang J, et al. DMFF: An Open-Source Automatic +Differentiable Platform for Molecular Force Field +Development and Molecular Dynamics +Simulation. ChemRxiv. Cambridge: Cambridge Open Engage; 2022; This content is a preprint and has not been peer-reviewed.` + ## User Guide + [1. Introduction](docs/user_guide/introduction.md) @@ -18,9 +29,20 @@ The behavior of organic molecular systems (e.g., protein folding, polymer struct + [3. Coding conventions](docs/dev_guide/convention.md) + [4. Document writing](docs/dev_guide/write_docs.md) -## Modules -+ [1. ADMP](docs/modules/admp.md) +## Code Structure + +The code is organized as follows: ++ `examples`: demos presented in Jupyter Notebook. ++ `docs`: documentation. ++ `package`: files for constructing packages or images, such as conda recipe and docker files. ++ `tests`: unit tests. ++ `dmff`: DMFF python codes ++ `dmff/admp`: source code of automatic differentiable multipolar polarizable (ADMP) force field module. ++ `dmff/classical`: source code of classical force field module. ++ `dmff/common`: source code of common functions, such as neighbor list. ++ `dmff/generators`: source code of force generators. ++ `dmff/sgnn`: source of subgragh neural network force field model. ## Support and Contribution diff --git a/dmff/api.py b/dmff/api.py index d2af46683..ddc0533fc 100644 --- a/dmff/api.py +++ b/dmff/api.py @@ -1,4 +1,5 @@ import linecache +from typing import Callable, Dict, Any import numpy as np import jax.numpy as jnp @@ -82,7 +83,7 @@ def totalPE(positions, box, pairs, params): class Hamiltonian(app.forcefield.ForceField): - def __init__(self, *xmlnames): + def __init__(self, *xmlnames, **kwargs): super().__init__(*xmlnames) self._pseudo_ff = app.ForceField(*xmlnames) # parse XML forcefields @@ -104,6 +105,9 @@ def __init__(self, *xmlnames): self.extractParameterTree() # hook generators to self._forces + # use noOmmSys to disable all traditional openmm system + if kwargs.get("noOmmSys", False): + self._forces = [] for jaxGen in self._jaxGenerators: self._forces.append(jaxGen) @@ -184,6 +188,29 @@ def createPotential(self, print(e) pass + # virtual site + try: + addVsiteFunc = generator.getAddVsiteFunc() + self.setAddVirtualSiteFunc(addVsiteFunc) + vsiteObj = generator.getVsiteObj() + self.setVirtualSiteObj(vsiteObj) + except AttributeError as e: + pass + + # covalent map + try: + cov_map = generator.covalent_map + self.setCovalentMap(cov_map) + except AttributeError as e: + pass + + # topology matrix (for BCC usage) + try: + top_mat = generator.getTopologyMatrix() + self.setTopologyMatrix(top_mat) + except AttributeError as e: + pass + return potObj def render(self, filename): @@ -201,4 +228,151 @@ def update_iter(node, ref): else: node[key] = ref[key] - update_iter(self.paramtree, paramtree) \ No newline at end of file + update_iter(self.paramtree, paramtree) + + def setCovalentMap(self, cov_map: jnp.ndarray): + self._cov_map = cov_map + + def getCovalentMap(self) -> jnp.ndarray: + """ + Get covalent map + """ + if hasattr(self, "_cov_map"): + return self._cov_map + else: + raise DMFFException("Covalent map is not set.") + + def getAddVirtualSiteFunc(self) -> Callable: + return self._add_vsite_coords + + def setAddVirtualSiteFunc(self, func: Callable): + self._add_vsite_coords = func + + def setVirtualSiteObj(self, vsite): + self._vsite = vsite + + def getVirtualSiteObj(self): + return self._vsite + + def setTopologyMatrix(self, top_mat): + self._top_mat = top_mat + + def getTopologyMatrix(self): + return self._top_mat + + def addVirtualSiteCoords(self, pos: jnp.ndarray, params: Dict[str, Any]) -> jnp.ndarray: + """ + Add coordinates for virtual sites + + Parameters + ---------- + pos: jnp.ndarray + Coordinates without virtual sites + params: dict + Paramtree of hamiltonian, i.e. `dmff.Hamiltonian.paramtree` + + Return + ------ + newpos: jnp.ndarray + + Examples + -------- + >>> import jax.numpy as jnp + >>> import openmm.app as app + >>> from rdkit import Chem + >>> from dmff import Hamiltonian + >>> pdb = app.PDBFile("tests/data/chlorobenzene.pdb") + >>> pos = jnp.array(pdb.getPositions(asNumpy=True)._value) + >>> mol = Chem.MolFromMolFile("tests/data/chlorobenzene.mol", removeHs=False) + >>> h = Hamiltonian("tests/data/cholorobenzene_vsite.xml") + >>> potObj = h.createPotential(pdb.topology, rdmol=mol) + >>> newpos = h.addVirtualSiteCoords(pos, h.paramtree) + + """ + func = self.getAddVirtualSiteFunc() + newpos = func(pos, params) + return newpos + + def addVirtualSiteToMol(self, rdmol, params): + """ + Add coordinates for rdkit.Chem.Mol object + + Parameters + ---------- + rdmol: rdkit.Chem.Mol + Mol object to which virtual sites are added + params: dict + Paramtree of hamiltonian, i.e. `dmff.Hamiltonian.paramtree` + + Return + ------ + newmol: rdkit.Chem.Mol + Mol object with virtual sites added + + Examples + -------- + >>> import jax.numpy as jnp + >>> import openmm.app as app + >>> from rdkit import Chem + >>> from dmff import Hamiltonian + >>> pdb = app.PDBFile("tests/data/chlorobenzene.pdb") + >>> mol = Chem.MolFromMolFile("tests/data/chlorobenzene.mol", removeHs=False) + >>> h = Hamiltonian("tests/data/cholorobenzene_vsite.xml") + >>> potObj = h.createPotential(pdb.topology, rdmol=mol) + >>> newmol = h.addVirtualSiteToMol(mol, h.paramtree) + """ + vsiteObj = self.getVirtualSiteObj() + newmol = vsiteObj.addVirtualSiteToMol( + rdmol, + params['NonbondedForce']['vsite_types'], + params['NonbondedForce']['vsite_distances'] + ) + return newmol + + @staticmethod + def buildTopologyFromMol(rdmol, resname: str = "MOL") -> app.Topology: + """ + Build openmm.app.Topology from rdkit.Chem.Mol Object + + Parameters + ---------- + rdmol: rdkit.Chem.Mol + Mol object + resname: str + Name of the added residue, default "MOL" + + Return + ------ + top: `openmm.app.Topology` + Topology built based on the input rdkit Mol object + """ + from rdkit import Chem + + top = app.Topology() + chain = top.addChain(0) + res = top.addResidue(resname, chain, "1", "") + + atCount = {} + addedAtoms = [] + for idx, atom in enumerate(rdmol.GetAtoms()): + symb = atom.GetSymbol().upper() + atCount.update({symb: atCount.get(symb, 0) + 1}) + ele = app.Element.getBySymbol(symb) + atName = f'{symb}{atCount[symb]}' + + addedAtom = top.addAtom(atName, ele, res, str(idx+1)) + addedAtoms.append(addedAtom) + + bondTypeMap = { + Chem.rdchem.BondType.SINGLE: app.Single, + Chem.rdchem.BondType.DOUBLE: app.Double, + Chem.rdchem.BondType.TRIPLE: app.Triple, + Chem.rdchem.BondType.AROMATIC: app.Aromatic + } + for bond in rdmol.GetBonds(): + top.addBond( + addedAtoms[bond.GetBeginAtomIdx()], + addedAtoms[bond.GetEndAtomIdx()], + type=bondTypeMap.get(bond.GetBondType(), None) + ) + return top \ No newline at end of file diff --git a/dmff/classical/inter.py b/dmff/classical/inter.py index b33f2f1b8..8438b9ee5 100644 --- a/dmff/classical/inter.py +++ b/dmff/classical/inter.py @@ -1,4 +1,4 @@ -from typing import Iterable, Tuple +from typing import Iterable, Tuple, Optional import jax.numpy as jnp import numpy as np @@ -130,10 +130,11 @@ def get_energy(box, epsilon, sigma, epsfix, sigfix): class CoulNoCutoffForce: # E=\frac{{q}_{1}{q}_{2}}{4\pi\epsilon_0\epsilon_1 r} - def __init__(self, map_prm, epsilon_1=1.0) -> None: + def __init__(self, map_prm, epsilon_1=1.0, topology_matrix=None) -> None: self.eps_1 = epsilon_1 self.map_prm = map_prm + self.top_mat = topology_matrix def generate_get_energy(self): def get_coul_energy(dr_vec, chrgprod, box): @@ -145,7 +146,6 @@ def get_coul_energy(dr_vec, chrgprod, box): return E def get_energy(positions, box, pairs, charges, mscales): - pairs = pairs.at[:, :2].set(regularize_pairs(pairs[:, :2])) mask = pair_buffer_scales(pairs[:, :2]) map_prm = jnp.array(self.map_prm) @@ -163,9 +163,16 @@ def get_energy(positions, box, pairs, charges, mscales): E_inter = get_coul_energy(dr_vec, chrgprod_scale, box) - return jnp.sum(E_inter * mask) - - return get_energy + return jnp.sum(E_inter * mask) + + def get_energy_bcc(positions, box, pairs, pre_charges, bcc, mscales): + charges = pre_charges + jnp.dot(self.top_mat, bcc).flatten() + return get_energy(positions, box, pairs, charges, mscales) + + if self.top_mat is None: + return get_energy + else: + return get_energy_bcc class CoulReactionFieldForce: @@ -177,6 +184,7 @@ def __init__( epsilon_1=1.0, epsilon_solv=78.5, isPBC=True, + topology_matrix=None ) -> None: self.r_cut = r_cut @@ -186,6 +194,7 @@ def __init__( self.eps_1 = epsilon_1 self.map_prm = map_prm self.ifPBC = isPBC + self.top_mat = topology_matrix def generate_get_energy(self): def get_rf_energy(dr_vec, chrgprod, box): @@ -204,7 +213,6 @@ def get_rf_energy(dr_vec, chrgprod, box): return E def get_energy(positions, box, pairs, charges, mscales): - pairs = pairs.at[:, :2].set(regularize_pairs(pairs[:, :2])) mask = pair_buffer_scales(pairs[:, :2]) @@ -223,7 +231,14 @@ def get_energy(positions, box, pairs, charges, mscales): return jnp.sum(E_inter * mask) - return get_energy + def get_energy_bcc(positions, box, pairs, pre_charges, bcc, mscales): + charges = pre_charges + jnp.dot(self.top_mat, bcc).flatten() + return get_energy(positions, box, pairs, charges, mscales) + + if self.top_mat is None: + return get_energy + else: + return get_energy_bcc class CoulombPMEForce: @@ -235,6 +250,7 @@ def __init__( kappa: float, K: Tuple[int, int, int], pme_order: int = 6, + topology_matrix: Optional[jnp.array] = None, ): self.r_cut = r_cut self.map_prm = map_prm @@ -242,6 +258,7 @@ def __init__( self.kappa = kappa self.K1, self.K2, self.K3 = K[0], K[1], K[2] self.pme_order = pme_order + self.top_mat = topology_matrix assert pme_order == 6, "PME order other than 6 is not supported" def generate_get_energy(self): @@ -283,4 +300,11 @@ def get_energy(positions, box, pairs, charges, mscales): False, ) - return get_energy + def get_energy_bcc(positions, box, pairs, pre_charges, bcc, mscales): + charges = pre_charges + jnp.dot(self.top_mat, bcc).flatten() + return get_energy(positions, box, pairs, charges, mscales) + + if self.top_mat is None: + return get_energy + else: + return get_energy_bcc diff --git a/dmff/classical/vsite.py b/dmff/classical/vsite.py new file mode 100644 index 000000000..9da7e960b --- /dev/null +++ b/dmff/classical/vsite.py @@ -0,0 +1,109 @@ +from typing import Tuple, List, Dict, Callable, Optional +import numpy as np +import jax.numpy as jnp + + +class VirtualSite: + """ + Class for manipulation of virtual sites + """ + def __init__(self, matches_dict: Dict[Tuple[int], int]): + """ + Initialize a virtual site object + """ + self.num_vsites = len(matches_dict) + self.matches, self.indices = [], [] + for key, value in matches_dict.items(): + self.matches.append(key) + self.indices.append(value) + + def getAddVirtualSiteFunc(self) -> Callable: + """ + Get fuction to compute virtual site coordinates + """ + + def add_vsite_position(pos: jnp.ndarray, vtypes: jnp.ndarray, vdist: jnp.ndarray) -> jnp.ndarray: + newpos = jnp.zeros((pos.shape[0] + self.num_vsites, pos.shape[1])) + newpos = newpos.at[:pos.shape[0]].set(pos) + for i in range(self.num_vsites): + match = self.matches[i] + idx = self.indices[i] + if vtypes[idx] == 1: + vec = newpos[match[0]] - newpos[match[1]] + nvec = vec / jnp.linalg.norm(vec, ord=2) + newpos = newpos.at[pos.shape[0] + i].set(newpos[match[0]] + vdist[idx] * nvec) + elif vtypes[idx] == 2: + vec1 = newpos[match[0]] - newpos[match[1]] + vec2 = newpos[match[0]] - newpos[match[2]] + nvec1 = vec1 / jnp.linalg.norm(vec1, ord=2) + nvec2 = vec2 / jnp.linalg.norm(vec2, ord=2) + nvec = (nvec1 + nvec2) / jnp.linalg.norm(nvec1 + nvec2, ord=2) + newpos = newpos.at[pos.shape[0] + i].set(newpos[match[0]] + vdist[idx] * nvec) + return newpos + + return add_vsite_position + + def addVirtualSiteToMol(self, rdmol, vtypes=None, vdist=None): + """ + Add virtual site to rdkit.Chem.Mol object + + Parameters + ---------- + rdmol: rdkit.Chem.Mol + Mol object to which virtual sites are added + vtypes: jnp.ndarray or None + Virtual site types, can be obtained from `dmff.Hamiltonian.paramtree['vsite_types']` + vdist: jnp.ndarray or None + Virtual site distances params, can be obtained from `dmff.Hamiltonian.paramtree['vsite_distances']` + + Return + ------ + newmol: rdkit.Chem.Mol + Mol with virtual sites added + """ + if isinstance(vtypes, jnp.ndarray) and isinstance(vdist, jnp.ndarray): + func = self.getAddVirtualSiteFunc() + # convert between angstrom and nm + pos = jnp.array(rdmol.GetConformer(0).GetPositions()) / 10 + addCoords = func(pos, vtypes, vdist) * 10 + else: + addCoords = None + + newmol = self.add_dummy(rdmol, [m[0] for m in self.matches], addCoords) + return newmol + + @staticmethod + def add_dummy(mol, parentAtomIdx: List[int], addCoords: Optional[np.ndarray]): + """ + Add dummy atom to rdkit.Chem.Mol object and make a dummy bond between + the dummy atom and its parent atom + + Parameters + ---------- + mol: rdkit.Chem.Mol + Molecule to add dummy atom + parentAtomIdx: int + Index of the dummy atom's parent atom + addCoords: numpy.ndarray or None + Coordinates of the virtual sites. In unit of Angstrom + """ + from rdkit import Chem + ori_num_atoms = mol.GetNumAtoms() + rwmol = Chem.RWMol(mol) + + duIdxs = [] + for pidx in parentAtomIdx: + dummy = Chem.Atom(0) + duIdx = rwmol.AddAtom(dummy) + rwmol.AddBond(duIdx, pidx) + newmol = rwmol.GetMol() + duIdxs.append(duIdx) + + if addCoords is not None: + assert len(addCoords) == len(parentAtomIdx) + ori_num_atoms, f"Number of atoms in coordinates doesn't match" + conf = newmol.GetConformer() + for i, duIdx in enumerate(duIdxs): + conf.SetAtomPosition(duIdx, [float(x) for x in addCoords[duIdx]]) + + return newmol + diff --git a/dmff/fftree.py b/dmff/fftree.py index b68b9a882..e4b954f53 100644 --- a/dmff/fftree.py +++ b/dmff/fftree.py @@ -257,11 +257,13 @@ def write(self, path): class TypeMatcher: - def __init__(self, fftree: ForcefieldTree, parser): + def __init__(self, fftree: ForcefieldTree, parser: str): """ Freeze type matching list. """ # not convert to float for atom types + self.useSmirks = False + atypes = fftree.get_attribs("AtomTypes/Type", "name", convert_to_float=False) aclasses = fftree.get_attribs("AtomTypes/Type", "class", convert_to_float=False) self.class2type = {} @@ -288,8 +290,13 @@ def __init__(self, fftree: ForcefieldTree, parser): tmp.append((nit, self.class2type.get(node.attrs[key], [None]))) elif key == "class": tmp.append((1, self.class2type.get(node.attrs[key], [None]))) - tmp = sorted(tmp, key=lambda x: x[0]) - self.functions.append([i[1] for i in tmp]) + elif key == "smirks": + self.useSmirks = True + self.functions.append(node.attrs[key]) + + if not self.useSmirks: + tmp = sorted(tmp, key=lambda x: x[0]) + self.functions.append([i[1] for i in tmp]) def matchGeneral(self, types): matches = [] @@ -300,6 +307,82 @@ def matchGeneral(self, types): if len(matches) == 0: return False, False, -1 return matches[-1] + + def matchSmirks(self, rdmol): + """ + Match smirks + """ + from rdkit import Chem + + if rdmol is None: + raise DMFFException("No rdkit.Chem.Mol object is provided") + + matches_dict = {} + for idx, smk in enumerate(self.functions): + patt = Chem.MolFromSmarts(smk) + matches = rdmol.GetSubstructMatches(patt) + for match in matches: + if len(match) == 2: + canonical_match = (min(match), max(match)) + elif len(match) == 3: + canonical_match = (min([match[0], match[2]]), match[1], max([match[0], match[2]])) + elif len(match) == 4: + canonical_match = (match[3], match[2], match[1], match[0]) if match[2] < match[1] else match + elif len(match) == 1: + canonical_match = match + else: + raise DMFFException(f"Invalid SMIRKS: {smk}") + matches_dict.update({canonical_match: idx}) + + return matches_dict + + def matchSmirksNoSort(self, rdmol): + """ + Match smirks, but no sorting the matched atom indices + """ + from rdkit import Chem + + if rdmol is None: + raise DMFFException("No rdkit.Chem.Mol object is provided") + + matches_dict = {} + for idx, smk in enumerate(self.functions): + patt = Chem.MolFromSmarts(smk) + matches = rdmol.GetSubstructMatches(patt) + for match in matches: + matches_dict.update({match: idx}) + + return matches_dict + + def matchSmirksImproper(self, rdmol): + """ + Match smirks for improper torsions + """ + from rdkit import Chem + + if rdmol is None: + raise DMFFException("No rdkit.Chem.Mol object is provided") + + matches_dict = {} + for idx, smk in enumerate(self.functions): + patt = Chem.MolFromSmarts(smk) + matches = rdmol.GetSubstructMatches(patt) + hasWildcard = "*" in smk + for match in matches: + # Be the most consistent with AMBER, in which ordering is determined in this way + atnums = [rdmol.GetAtomWithIdx(i).GetAtomicNum() for i in match] + if hasWildcard: + if atnums[1] == atnums[2] and match[1] > match[2]: + canonical_match = (match[2], match[1], match[0], match[3]) + elif atnums[1] != 6 and (atnums[2] == 6 or atnums[1] < atnums[2]): + canonical_match = (match[2], match[1], match[0], match[3]) + else: + canonical_match = (match[1], match[2], match[0], match[3]) + else: + canonical_match = match + matches_dict.update({canonical_match: idx}) + + return matches_dict def _match(self, types, term): if len(types) != len(term): diff --git a/dmff/generators/classical.py b/dmff/generators/classical.py index b4584ba9a..cda71da45 100644 --- a/dmff/generators/classical.py +++ b/dmff/generators/classical.py @@ -1,5 +1,6 @@ from collections import defaultdict from typing import Dict +from copy import deepcopy import numpy as np import jax.numpy as jnp @@ -26,7 +27,8 @@ LennardJonesLongRangeFreeEnergyForce, CoulombPMEFreeEnergyForce ) -from dmff.admp.pme import setup_ewald_parameters +from dmff.classical.vsite import VirtualSite +from dmff.admp.pme import setup_ewald_parameters from dmff.utils import jit_condition, isinstance_jnp, DMFFException, findItemInList from dmff.fftree import ForcefieldTree, TypeMatcher from dmff.api import Hamiltonian, build_covalent_map @@ -77,24 +79,42 @@ def createForce(self, sys, data, nonbondedMethod, nonbondedCutoff, args): matcher = TypeMatcher(self.fftree, "HarmonicBondForce/Bond") map_atom1, map_atom2, map_param = [], [], [] - n_bonds = len(data.bonds) - # build map - for i in range(n_bonds): - idx1 = data.bonds[i].atom1 - idx2 = data.bonds[i].atom2 - type1 = data.atomType[data.atoms[idx1]] - type2 = data.atomType[data.atoms[idx2]] - ifFound, ifForward, nfunc = matcher.matchGeneral([type1, type2]) - if not ifFound: - raise BaseException( - f"No parameter for bond ({idx1},{type1}) - ({idx2},{type2})" - ) - map_atom1.append(idx1) - map_atom2.append(idx2) - map_param.append(nfunc) + + if not matcher.useSmirks: + n_bonds = len(data.bonds) + # build map + for i in range(n_bonds): + idx1 = data.bonds[i].atom1 + idx2 = data.bonds[i].atom2 + type1 = data.atomType[data.atoms[idx1]] + type2 = data.atomType[data.atoms[idx2]] + ifFound, ifForward, nfunc = matcher.matchGeneral([type1, type2]) + if not ifFound: + raise DMFFException( + f"No parameter for bond ({idx1},{type1}) - ({idx2},{type2})" + ) + map_atom1.append(idx1) + map_atom2.append(idx2) + map_param.append(nfunc) + else: + rdmol = args.get("rdmol", None) + matches_dict = matcher.matchSmirks(rdmol) + for bond in rdmol.GetBonds(): + beginAtomIdx = bond.GetBeginAtomIdx() + endAtomIdx = bond.GetEndAtomIdx() + query = (beginAtomIdx, endAtomIdx) if beginAtomIdx < endAtomIdx else (endAtomIdx, beginAtomIdx) + map_atom1.append(query[0]) + map_atom2.append(query[1]) + try: + map_param.append(matches_dict[query]) + except KeyError as e: + raise DMFFException( + f"No parameter for bond between Atom{beginAtomIdx} and Atom{endAtomIdx}" + ) + map_atom1 = np.array(map_atom1, dtype=int) map_atom2 = np.array(map_atom2, dtype=int) - map_param = np.array(map_param, dtype=int) + map_param = np.array(map_param, dtype=int) bforce = HarmonicBondJaxForce(map_atom1, map_atom2, map_param) @@ -137,25 +157,46 @@ def createForce(self, sys, data, nonbondedMethod, nonbondedCutoff, args): matcher = TypeMatcher(self.fftree, "HarmonicAngleForce/Angle") map_atom1, map_atom2, map_atom3, map_param = [], [], [], [] - n_angles = len(data.angles) - for nangle in range(n_angles): - idx1 = data.angles[nangle][0] - idx2 = data.angles[nangle][1] - idx3 = data.angles[nangle][2] - type1 = data.atomType[data.atoms[idx1]] - type2 = data.atomType[data.atoms[idx2]] - type3 = data.atomType[data.atoms[idx3]] - ifFound, ifForward, nfunc = matcher.matchGeneral( - [type1, type2, type3]) - if not ifFound: - print( - f"No parameter for angle ({idx1},{type1}) - ({idx2},{type2}) - ({idx3},{type3})" - ) - else: - map_atom1.append(idx1) - map_atom2.append(idx2) - map_atom3.append(idx3) - map_param.append(nfunc) + + if not matcher.useSmirks: + n_angles = len(data.angles) + for nangle in range(n_angles): + idx1 = data.angles[nangle][0] + idx2 = data.angles[nangle][1] + idx3 = data.angles[nangle][2] + type1 = data.atomType[data.atoms[idx1]] + type2 = data.atomType[data.atoms[idx2]] + type3 = data.atomType[data.atoms[idx3]] + ifFound, ifForward, nfunc = matcher.matchGeneral( + [type1, type2, type3]) + if not ifFound: + print( + f"No parameter for angle ({idx1},{type1}) - ({idx2},{type2}) - ({idx3},{type3})" + ) + else: + map_atom1.append(idx1) + map_atom2.append(idx2) + map_atom3.append(idx3) + map_param.append(nfunc) + else: + from rdkit import Chem + + rdmol = args.get("rdmol", None) + matches_dict = matcher.matchSmirks(rdmol) + angle_patt = Chem.MolFromSmarts("[*:1]~[*:2]~[*:3]") + angles = rdmol.GetSubstructMatches(angle_patt) + for angle in angles: + canonical_angle = (min([angle[0], angle[2]]), angle[1], max([angle[0], angle[2]])) + map_atom1.append(canonical_angle[0]) + map_atom2.append(canonical_angle[1]) + map_atom3.append(canonical_angle[2]) + try: + map_param.append(matches_dict[canonical_angle]) + except KeyError as e: + raise DMFFException( + f"No parameter for angle Atom{canonical_angle[0]}-Atom{canonical_angle[1]}-Atom{canonical_angle[2]}" + ) + map_atom1 = np.array(map_atom1, dtype=int) map_atom2 = np.array(map_atom2, dtype=int) map_atom3 = np.array(map_atom3, dtype=int) @@ -294,6 +335,10 @@ def overwrite(self): self.fftree.set_node("PeriodicTorsionForce/Improper", impr_data) def createForce(self, sys, data, nonbondedMethod, nonbondedCutoff, args): + """ + Create force for torsions + """ + # Proper Torsions proper_matcher = TypeMatcher(self.fftree, "PeriodicTorsionForce/Proper") map_prop_atom1 = {i: [] for i in range(1, self.max_pred_prop + 1)} @@ -302,24 +347,55 @@ def createForce(self, sys, data, nonbondedMethod, nonbondedCutoff, args): map_prop_atom4 = {i: [] for i in range(1, self.max_pred_prop + 1)} map_prop_param = {i: [] for i in range(1, self.max_pred_prop + 1)} n_matched_props = 0 - for torsion in data.propers: - types = [data.atomType[data.atoms[torsion[i]]] for i in range(4)] - ifFound, ifForward, nnode = proper_matcher.matchGeneral(types) - if not ifFound: - continue - # find terms for node - for periodicity in range(1, self.max_pred_prop + 1): - idx = findItemInList( - nnode, self.meta[f"prop_nodeidx"][f"{periodicity}"]) - if idx < 0: - continue - n_matched_props += 1 - map_prop_atom1[periodicity].append(torsion[0]) - map_prop_atom2[periodicity].append(torsion[1]) - map_prop_atom3[periodicity].append(torsion[2]) - map_prop_atom4[periodicity].append(torsion[3]) - map_prop_param[periodicity].append(idx) + if not proper_matcher.useSmirks: + for torsion in data.propers: + types = [data.atomType[data.atoms[torsion[i]]] for i in range(4)] + ifFound, ifForward, nnode = proper_matcher.matchGeneral(types) + if not ifFound: + continue + # find terms for node + for periodicity in range(1, self.max_pred_prop + 1): + idx = findItemInList( + nnode, self.meta[f"prop_nodeidx"][f"{periodicity}"]) + if idx < 0: + continue + n_matched_props += 1 + map_prop_atom1[periodicity].append(torsion[0]) + map_prop_atom2[periodicity].append(torsion[1]) + map_prop_atom3[periodicity].append(torsion[2]) + map_prop_atom4[periodicity].append(torsion[3]) + map_prop_param[periodicity].append(idx) + else: + from rdkit import Chem + + rdmol = args.get("rdmol", None) + proper_patt = Chem.MolFromSmarts("[*:1]~[*:2]-[*:3]~[*:4]") + propers = rdmol.GetSubstructMatches(proper_patt) + matches_dict = proper_matcher.matchSmirks(rdmol) + for match in propers: + torsion = (match[3], match[2], match[1], match[0]) if match[2] < match[1] else match + try: + nnode = matches_dict[torsion] + ifFound = True + n_matched_props += 1 + except KeyError: + ifFound = False + + if not ifFound: + continue + + for periodicity in range(1, self.max_pred_prop + 1): + idx = findItemInList(nnode, self.meta['prop_nodeidx'][f"{periodicity}"]) + if idx < 0: + continue + map_prop_atom1[periodicity].append(torsion[0]) + map_prop_atom2[periodicity].append(torsion[1]) + map_prop_atom3[periodicity].append(torsion[2]) + map_prop_atom4[periodicity].append(torsion[3]) + map_prop_param[periodicity].append(idx) + + # Improper Torsions impr_matcher = TypeMatcher(self.fftree, "PeriodicTorsionForce/Improper") try: @@ -327,47 +403,69 @@ def createForce(self, sys, data, nonbondedMethod, nonbondedCutoff, args): "ordering")[0] except KeyError as e: ordering = "default" + map_impr_atom1 = {i: [] for i in range(1, self.max_pred_impr + 1)} map_impr_atom2 = {i: [] for i in range(1, self.max_pred_impr + 1)} map_impr_atom3 = {i: [] for i in range(1, self.max_pred_impr + 1)} map_impr_atom4 = {i: [] for i in range(1, self.max_pred_impr + 1)} map_impr_param = {i: [] for i in range(1, self.max_pred_impr + 1)} n_matched_imprs = 0 - for impr in data.impropers: - match = impr_matcher.matchImproper(impr, data, ordering=ordering) - if match is not None: - (a1, a2, a3, a4, nnode) = match + + if not impr_matcher.useSmirks: + for impr in data.impropers: + match = impr_matcher.matchImproper(impr, data, ordering=ordering) + if match is not None: + (a1, a2, a3, a4, nnode) = match + n_matched_imprs += 1 + # find terms for node + for periodicity in range(1, self.max_pred_impr + 1): + idx = findItemInList( + nnode, self.meta[f"impr_nodeidx"][f"{periodicity}"]) + if idx < 0: + continue + if ordering == 'smirnoff': + # Add all torsions in trefoil + map_impr_atom1[periodicity].append(a1) + map_impr_atom2[periodicity].append(a2) + map_impr_atom3[periodicity].append(a3) + map_impr_atom4[periodicity].append(a4) + map_impr_param[periodicity].append(idx) + map_impr_atom1[periodicity].append(a1) + map_impr_atom2[periodicity].append(a3) + map_impr_atom3[periodicity].append(a4) + map_impr_atom4[periodicity].append(a2) + map_impr_param[periodicity].append(idx) + map_impr_atom1[periodicity].append(a1) + map_impr_atom2[periodicity].append(a4) + map_impr_atom3[periodicity].append(a2) + map_impr_atom4[periodicity].append(a3) + map_impr_param[periodicity].append(idx) + else: + map_impr_atom1[periodicity].append(a1) + map_impr_atom2[periodicity].append(a2) + map_impr_atom3[periodicity].append(a3) + map_impr_atom4[periodicity].append(a4) + map_impr_param[periodicity].append(idx) + else: + rdmol = args.get("rdmol", None) + + if rdmol is None: + raise DMFFException("No rdkit.Chem.Mol object is provided") + + matches_dict = impr_matcher.matchSmirksImproper(rdmol) + for torsion, nnode in matches_dict.items(): n_matched_imprs += 1 - # find terms for node - for periodicity in range(1, self.max_pred_impr + 1): - idx = findItemInList( - nnode, self.meta[f"impr_nodeidx"][f"{periodicity}"]) + for periodicity in range(1, self.max_pred_impr+ 1): + idx = findItemInList(nnode, self.meta['impr_nodeidx'][f"{periodicity}"]) if idx < 0: continue - if ordering == 'smirnoff': - # Add all torsions in trefoil - map_impr_atom1[periodicity].append(a1) - map_impr_atom2[periodicity].append(a2) - map_impr_atom3[periodicity].append(a3) - map_impr_atom4[periodicity].append(a4) - map_impr_param[periodicity].append(idx) - map_impr_atom1[periodicity].append(a1) - map_impr_atom2[periodicity].append(a3) - map_impr_atom3[periodicity].append(a4) - map_impr_atom4[periodicity].append(a2) - map_impr_param[periodicity].append(idx) - map_impr_atom1[periodicity].append(a1) - map_impr_atom2[periodicity].append(a4) - map_impr_atom3[periodicity].append(a2) - map_impr_atom4[periodicity].append(a3) - map_impr_param[periodicity].append(idx) - else: - map_impr_atom1[periodicity].append(a1) - map_impr_atom2[periodicity].append(a2) - map_impr_atom3[periodicity].append(a3) - map_impr_atom4[periodicity].append(a4) - map_impr_param[periodicity].append(idx) - + map_impr_atom1[periodicity].append(torsion[0]) + map_impr_atom2[periodicity].append(torsion[1]) + map_impr_atom3[periodicity].append(torsion[2]) + map_impr_atom4[periodicity].append(torsion[3]) + map_impr_param[periodicity].append(idx) + + # Sum proper and improper torsions props = [ PeriodicTorsionJaxForce(jnp.array(map_prop_atom1[p], dtype=int), jnp.array(map_prop_atom2[p], dtype=int), @@ -413,7 +511,7 @@ def getJaxPotential(self): class NonbondedJaxGenerator: - def __init__(self, ff): + def __init__(self, ff: Hamiltonian): self.name = "NonbondedForce" self.ff = ff self.fftree = ff.fftree @@ -427,6 +525,9 @@ def __init__(self, ff): self.ra2idx = {} self.idx2rai = {} + self.useBCC = False + self.useVsite = False + def extract(self): self.from_residue = self.fftree.get_attribs( "NonbondedForce/UseAttributeFromResidue", "name") @@ -438,6 +539,7 @@ def extract(self): for prm in self.from_force: vals = self.fftree.get_attribs("NonbondedForce/Atom", prm) self.paramtree[self.name][prm] = jnp.array(vals) + # Build per-atom array for from_residue residues = self.fftree.get_nodes("Residues/Residue") resvals = {k: [] for k in self.from_residue} @@ -454,6 +556,7 @@ def extract(self): resvals[prm].extend(atomval) for prm in self.from_residue: self.paramtree[self.name][prm] = jnp.array(resvals[prm]) + # Build coulomb14scale and lj14scale coulomb14scale, lj14scale = self.fftree.get_attribs( "NonbondedForce", ["coulomb14scale", "lj14scale"])[0] @@ -461,6 +564,18 @@ def extract(self): [coulomb14scale]) self.paramtree[self.name]["lj14scale"] = jnp.array([lj14scale]) + # Build BondChargeCorrection + bccs = self.fftree.get_attribs("NonbondedForce/BondChargeCorrection", "bcc") + self.paramtree[self.name]['bcc'] = jnp.array(bccs).reshape(-1, 1) + self.useBCC = len(bccs) > 0 + + # Build VirtualSite + vsite_types = self.fftree.get_attribs("NonbondedForce/VirtualSite", "vtype") + self.paramtree[self.name]['vsite_types'] = jnp.array(vsite_types, dtype=int) + vsite_distance = self.fftree.get_attribs("NonbondedForce/VirtualSite", "distance") + self.paramtree[self.name]['vsite_distances'] = jnp.array(vsite_distance) + self.useVsite = len(vsite_types) > 0 + def overwrite(self): # write coulomb14scale self.fftree.set_attrib("NonbondedForce", "coulomb14scale", @@ -486,9 +601,18 @@ def overwrite(self): [d for d in data if d[0] == resnode.attrs["name"]], key=lambda x: x[1]) resnode.set_attrib("Atom", prm, [t[2] for t in tmp]) - - def createForce(self, system, data, nonbondedMethod, nonbondedCutoff, - args): + + # write BCC + if self.useBCC: + self.fftree.set_attrib( + "NonbondedForce/BondChargeCorrection", "bcc", + self.paramtree[self.name]['bcc'] + ) + + def createForce(self, system, data, nonbondedMethod, nonbondedCutoff, args): + # Build Covalent Map + self.covalent_map = build_covalent_map(data, 6) + methodMap = { app.NoCutoff: "NoCutoff", app.CutoffPeriodic: "CutoffPeriodic", @@ -509,7 +633,6 @@ def createForce(self, system, data, nonbondedMethod, nonbondedCutoff, mscales_lj = mscales_lj.at[2].set( self.paramtree[self.name]["lj14scale"][0]) - # Coulomb: only support PME for now # set PBC if nonbondedMethod not in [app.NoCutoff, app.CutoffNonPeriodic]: ifPBC = True @@ -517,20 +640,73 @@ def createForce(self, system, data, nonbondedMethod, nonbondedCutoff, ifPBC = False nbmatcher = TypeMatcher(self.fftree, "NonbondedForce/Atom") - # load LJ from types - maps = {} - for prm in self.from_force: - maps[prm] = [] - for atom in data.atoms: - atype = data.atomType[atom] - ifFound, _, nnode = nbmatcher.matchGeneral([atype]) - if not ifFound: - raise DMFFException( - "AtomType of %s mismatched in NonbondedForce" % - (str(atom))) - maps[prm].append(nnode) - maps[prm] = jnp.array(maps[prm], dtype=int) + + rdmol = args.get("rdmol", None) + + if self.useVsite: + vsitematcher = TypeMatcher(self.fftree, "NonbondedForce/VirtualSite") + vsite_matches_dict = vsitematcher.matchSmirksNoSort(rdmol) + vsiteObj = VirtualSite(vsite_matches_dict) + + def addVsiteFunc(pos, params): + func = vsiteObj.getAddVirtualSiteFunc() + newpos = func(pos, params[self.name]['vsite_types'], params[self.name]['vsite_distances']) + return newpos + + self._addVsiteFunc = addVsiteFunc + rdmol = vsiteObj.addVirtualSiteToMol(rdmol) + self.vsiteObj = vsiteObj + + # expand covalent map + ori_dim = self.covalent_map.shape[0] + new_dim = ori_dim + len(vsite_matches_dict) + cov_map = np.zeros((new_dim, new_dim), dtype=int) + cov_map[:ori_dim, :ori_dim] += np.array(self.covalent_map, dtype=int) + + map_to_parents = np.arange(new_dim) + for i, match in enumerate(vsite_matches_dict.keys()): + map_to_parents[ori_dim + i] = match[0] + for i in range(len(vsite_matches_dict)): + parent_i = map_to_parents[ori_dim + i] + for j in range(new_dim): + parent_j = map_to_parents[j] + cov_map[ori_dim + i, j] = cov_map[parent_i, parent_j] + cov_map[j, ori_dim + i] = cov_map[parent_j, parent_i] + # keep diagonal 0 + cov_map[ori_dim + i, ori_dim + i] = 0 + # keep vsite and its parent atom 1 + cov_map[parent_i, ori_dim + i] = 1 + cov_map[ori_dim + i, parent_i] = 1 + self.covalent_map = jnp.array(cov_map) + + # Load Lennard-Jones parameters + maps = {} + if not nbmatcher.useSmirks: + for prm in self.from_force: + maps[prm] = [] + for atom in data.atoms: + atype = data.atomType[atom] + ifFound, _, nnode = nbmatcher.matchGeneral([atype]) + if not ifFound: + raise DMFFException( + "AtomType of %s mismatched in NonbondedForce" % + (str(atom))) + maps[prm].append(nnode) + maps[prm] = jnp.array(maps[prm], dtype=int) + else: + lj_matches_dict = nbmatcher.matchSmirks(rdmol) + for prm in self.from_force: + maps[prm] = [] + for i in range(rdmol.GetNumAtoms()): + try: + maps[prm].append(lj_matches_dict[(i,)]) + except KeyError as e: + raise DMFFException( + f"No parameter for atom {i}" + ) + maps[prm] = jnp.array(maps[prm], dtype=int) + for prm in self.from_residue: maps[prm] = [] for atom in data.atoms: @@ -538,10 +714,56 @@ def createForce(self, system, data, nonbondedMethod, nonbondedCutoff, aidx = data.atomTemplateIndexes[atom] resname, aname = templateName, atom.name maps[prm].append(self.ra2idx[(resname, aidx)]) + + # Virtual Site + if self.useVsite: + # expand charges + chg = jnp.zeros( + (len(self.paramtree[self.name]['charge']) + len(vsite_matches_dict),), + dtype=self.paramtree[self.name]['charge'].dtype + ) + self.paramtree[self.name]['charge'] = chg.at[:len(self.paramtree[self.name]['charge'])].set( + self.paramtree[self.name]['charge'] + ) + maps_chg = [int(x) for x in maps['charge']] + for i in range(len(vsite_matches_dict)): + maps_chg.append(len(maps['charge']) + i) + maps['charge'] = jnp.array(maps_chg, dtype=int) + + # BCC parameters + if self.useBCC: + bccmatcher = TypeMatcher(self.fftree, "NonbondedForce/BondChargeCorrection") + + if bccmatcher.useSmirks: + bcc_matches_dict = bccmatcher.matchSmirksNoSort(rdmol) + self.top_mat = np.zeros((rdmol.GetNumAtoms(), self.paramtree[self.name]['bcc'].shape[0])) + + for bond in rdmol.GetBonds(): + beginAtomIdx = bond.GetBeginAtomIdx() + endAtomIdx = bond.GetEndAtomIdx() + query1, query2 = (beginAtomIdx, endAtomIdx), (endAtomIdx, beginAtomIdx) + if query1 in bcc_matches_dict: + nnode = bcc_matches_dict[query1] + self.top_mat[query1[0], nnode] += 1 + self.top_mat[query1[1], nnode] -= 1 + elif query2 in bcc_matches_dict: + nnode = bcc_matches_dict[query2] + self.top_mat[query2[0], nnode] += 1 + self.top_mat[query2[1], nnode] -= 1 + else: + raise DMFFException( + f"No BCC parameter for bond between Atom{beginAtomIdx} and Atom{endAtomIdx}" + ) + else: + raise DMFFException( + "Only SMIRKS-based parametrization is supported for BCC" + ) + else: + self.top_mat = None + + # NBFIX map_nbfix = [] - map_nbfix = np.array(map_nbfix, dtype=int).reshape((-1, 2)) - - self.covalent_map = build_covalent_map(data, 6) + map_nbfix = jnp.array(map_nbfix, dtype=jnp.int32).reshape(-1, 2) if unit.is_quantity(nonbondedCutoff): r_cut = nonbondedCutoff.value_in_unit(unit.nanometer) @@ -569,7 +791,6 @@ def createForce(self, system, data, nonbondedMethod, nonbondedCutoff, self.coeff_method) map_lj = jnp.array(maps["sigma"]) - map_nbfix = jnp.array(map_nbfix) map_charge = jnp.array(maps["charge"]) # Free Energy Settings # @@ -664,15 +885,17 @@ def createForce(self, system, data, nonbondedMethod, nonbondedCutoff, # use Reaction Field coulforce = CoulReactionFieldForce(r_cut, map_charge, - isPBC=ifPBC) + isPBC=ifPBC, + topology_matrix=self.top_mat) if nonbondedMethod is app.NoCutoff: # use NoCutoff - coulforce = CoulNoCutoffForce(map_charge) + coulforce = CoulNoCutoffForce(map_charge, topology_matrix=self.top_mat) else: coulforce = CoulombPMEForce(r_cut, map_charge, kappa, - (K1, K2, K3)) + (K1, K2, K3), topology_matrix=self.top_mat) else: assert nonbondedMethod is app.PME, "Only PME is supported in free energy calculations" + assert not self.useBCC, "BCC usage in free energy calculations is not supported yet" coulforce = CoulombPMEFreeEnergyForce(r_cut, map_charge, kappa, (K1, K2, K3), @@ -699,8 +922,13 @@ def potential_fn(positions, box, pairs, params): params[self.name]["sigma"], params[self.name]["epsfix"], params[self.name]["sigfix"], mscales_lj) - coulE = coulenergy(positions, box, pairs, - params[self.name]["charge"], mscales_coul) + + if not self.useBCC: + coulE = coulenergy(positions, box, pairs, + params[self.name]["charge"], mscales_coul) + else: + coulE = coulenergy(positions, box, pairs, + params[self.name]["charge"], params[self.name]["bcc"], mscales_coul) if useDispersionCorrection: ljDispEnergy = ljDispEnergyFn(box, @@ -745,6 +973,26 @@ def potential_fn(positions, box, pairs, params, vdwLambda, def getJaxPotential(self): return self._jaxPotential + def getAddVsiteFunc(self): + """ + Get function to add coordinates for virtual sites + """ + return self._addVsiteFunc + + def getVsiteObj(self): + """ + Get `dmff.classical.vsite.VirtualSite` object + """ + if self.useVsite: + return self.vsiteObj + else: + return None + + def getTopologyMatrix(self): + """ + Get topology Matrix + """ + return self.top_mat dmff.api.jaxGenerators["NonbondedForce"] = NonbondedJaxGenerator diff --git a/docs/assets/arch.svg b/docs/assets/arch.svg new file mode 100644 index 000000000..f033f9175 --- /dev/null +++ b/docs/assets/arch.svg @@ -0,0 +1,121 @@ + + + diff --git a/docs/assets/clpy.png b/docs/assets/clpy.png new file mode 100644 index 000000000..4a2abebc5 Binary files /dev/null and b/docs/assets/clpy.png differ diff --git a/docs/assets/smirks.png b/docs/assets/smirks.png new file mode 100644 index 000000000..c79969bff Binary files /dev/null and b/docs/assets/smirks.png differ diff --git a/docs/assets/vsite.png b/docs/assets/vsite.png new file mode 100644 index 000000000..ca4874217 Binary files /dev/null and b/docs/assets/vsite.png differ diff --git a/docs/dev_guide/arch.md b/docs/dev_guide/arch.md index 4b1aa8c62..d31543c90 100644 --- a/docs/dev_guide/arch.md +++ b/docs/dev_guide/arch.md @@ -1,6 +1,6 @@ # 2. Software architecture - + The overall architechture of DMFF can be divided into two parts: 1. parser & typing and 2. calculators. We usually refer to the former as the *frontend* and the latter as the *backend* for ease of description. diff --git a/docs/gen_ref_pages.py b/docs/gen_ref_pages.py index 77fbb0a3c..ff00dc02c 100644 --- a/docs/gen_ref_pages.py +++ b/docs/gen_ref_pages.py @@ -7,6 +7,9 @@ nav = mkdocs_gen_files.Nav() for path in sorted(Path("dmff").rglob("*.py")): # + + if path.name.startswith("_"): + continue module_path = path.relative_to('dmff').with_suffix("") # diff --git a/docs/index.md b/docs/index.md index 5406747fb..d66fa2f09 100644 --- a/docs/index.md +++ b/docs/index.md @@ -1,25 +1,47 @@ # DMFF +[](https://doi.org/10.26434/chemrxiv-2022-2c7gv) + +## About DMFF + **DMFF** (**D**ifferentiable **M**olecular **F**orce **F**ield) is a Jax-based python package that provides a full differentiable implementation of molecular force field models. This project aims to establish an extensible codebase to minimize the efforts in force field parameterization, and to ease the force and virial tensor evaluations for advanced complicated potentials (e.g., polarizable models with geometry-dependent atomic parameters). Currently, this project mainly focuses on the molecular systems such as: water, biological macromolecules (peptides, proteins, nucleic acids), organic polymers, and small organic molecules (organic electrolyte, drug-like molecules) etc. We support both the conventional point charge models (OPLS and AMBER like) and multipolar polarizable models (AMOEBA and MPID like). The entire project is backed by the XLA technique in JAX, thus can be "jitted" and run in GPU devices much more efficiently compared to normal python codes. The behavior of organic molecular systems (e.g., protein folding, polymer structure, etc.) is often determined by a complex effect of many different types of interactions. The existing organic molecular force fields are mainly empirically fitted and their performance relies heavily on error cancellation. Therefore, the transferability and the prediction power of these force fields are insufficient. For new molecules, the parameter fitting process requires essential manual intervention and can be quite cumbersome. In order to automate the parametrization process and increase the robustness of the model, it is necessary to apply modern AI techniques in conventional force field development. This project serves for this purpose by utilizing the automatic differentiable programming technique to develop a codebase, which allows a more convenient incorporation of modern AI optimization techniques. It also helps the realization of many exciting functions including (but not limited to): hybrid machine learning/force field models and parameter optimization based on trajectory. +### License and credits + +The project DMFF is licensed under [GNU LGPL v3.0](LICENSE). If you use this code in any future publications, please cite this using `Wang X, Li J, Yang L, Chen F, Wang Y, Chang J, et al. DMFF: An Open-Source Automatic +Differentiable Platform for Molecular Force Field +Development and Molecular Dynamics +Simulation. ChemRxiv. Cambridge: Cambridge Open Engage; 2022; This content is a preprint and has not been peer-reviewed. + ## User Guide + [1. Introduction](user_guide/introduction.md) + [2. Installation](user_guide/installation.md) -+ [3. Compute energy and forces](user_guide/compute.md) -+ [4. Compute gradients with auto differentiable framework](user_guide/auto_diff.md) -+ [5. Theories](user_guide/theory.md) -+ [6. Introduction to force field xml files](user_guide/xml_spec.md) ++ [3. Basic usage](user_guide/usage.md) ++ [4. XML format force field](user_guide/xml_spec.md) ++ [5. Theory](user_guide/theory.md) ## Developer Guide + [1. Introduction](dev_guide/introduction.md) -+ [2. Architecture](dev_guide/arch.md) -+ [3. Convention](dev_guide/convention.md) - -## Modules -+ [1. ADMP](modules/admp.md) ++ [2. Software architecture](dev_guide/arch.md) ++ [3. Coding conventions](dev_guide/convention.md) ++ [4. Document writing](dev_guide/write_docs.md) + +## Code Structure +The code is organized as follows: + ++ `examples`: demos presented in Jupyter Notebook. ++ `docs`: documentation. ++ `package`: files for constructing packages or images, such as conda recipe and docker files. ++ `tests`: unit tests. ++ `dmff`: DMFF python codes ++ `dmff/admp`: source code of automatic differentiable multipolar polarizable (ADMP) force field module. ++ `dmff/classical`: source code of classical force field module. ++ `dmff/common`: source code of common functions, such as neighbor list. ++ `dmff/generators`: source code of force generators. ++ `dmff/sgnn`: source of subgragh neural network force field model. ## Support and Contribution diff --git a/docs/user_guide/installation.md b/docs/user_guide/installation.md index d4b3c4687..d4cd73550 100644 --- a/docs/user_guide/installation.md +++ b/docs/user_guide/installation.md @@ -17,6 +17,10 @@ pip install jax-md==0.2.0 ```bash conda install -c conda-forge openmm==7.7.0 ``` ++ Install [RDKit](https://www.rdkit.org/) (required for SMIRKS-based parametrization): +```bash +conda install -c conda-forge rdkit +``` ## 2.2 Install DMFF from Source Code One can download the DMFF source code from github: ```bash diff --git a/docs/user_guide/theory.md b/docs/user_guide/theory.md index 720476852..34e4928b7 100644 --- a/docs/user_guide/theory.md +++ b/docs/user_guide/theory.md @@ -10,7 +10,7 @@ DMFF project aims to implement organic molecular force fields using a differenti All interations involved in DMFF are briefly introduced below and the users are encouraged to read the references for more mathematical details: -## Electrostatic Interaction +## 5.1 Electrostatic Interaction The electrostatic interaction between two atoms can be described using multipole expansion, in which the electron cloud of an atom can be expanded as a series of multipole moments including charges, dipoles, quadrupoles, and octupoles etc. If only the charges (zero-moment) are considered, it is reduced to the point charge model in classical force fields: @@ -33,7 +33,7 @@ $$0, 10, 1c, 1s, 20, 21c, 21s, 22c, 22s, ...$$ The $T_{tu}^{AB}$ represents the interaction tensor between multipoles. The mathematical expression of these tensors can be found in the appendix F of Ref 1. The user can also find the conversion rule between different representations in Ref 1 & 5. -## Coordinate System for Multipoles +## 5.2 Coordinate System for Multipoles Different to charges, the definition of multipole moments depends on the coordinate system. The exact value of the moment tensor will be rotated in accord to different coordinate systems. There are three types of frames involved in DMFF, each used in a different scenario: @@ -44,7 +44,7 @@ Different to charges, the definition of multipole moments depends on the coordin - Quasi internal frame, aka. QI frame: this frame is defined for each pair of interaction sites, in which the z-axis is pointing from one site to another. In this frame, the real-space interaction tensor ($T_{tu}^{AB}$) can be greatly simplified due to symmetry. We thus use this frame in the real space calculation of PME. -## Polarization Interaction +## 5.3 Polarization Interaction DMFF supports polarizable force fields, in which the dipole moment of the atom can respond to the change of the external electric field. In practice, each atom has not only permanent multipoles $Q_t$, but also induced dipoles $U_{ind}$. The induced dipole-induced dipole and induced dipole-permanent multipole interactions needs to be damped at short-range to avoid polarization catastrophe. In DMFF, we use the Thole damping scheme identical to MPID (ref 6), which introduces a damping width ($a_i$) for each atom $i$. The damping function is then computed and applied to the corresponding interaction tensor. Taking $U_{ind}$-permanent charge interaction as an example, the definition of damping function is: @@ -76,7 +76,7 @@ where the off-diagonal term of $K$ matrix is induced-induced dipole interaction, In the current version, we temporarily assume that the polarizability is spherically symmetric, thus the polarizability $\alpha_i$ is a scalar, not a tensor. **Thus the inputs (`polarizabilityXX, polarizabilityYY, polarizabilityZZ`) in the xml API is averaged internally**. In future, it is relatively simple to relax this restriction: simply change the reciprocal of the polarizability to the inverse of the matrix when calculating the diagonal terms of the $K$ matrix. -## Dispersion Interaction +## 5.4 Dispersion Interaction In ADMP, we assume that the following expansion is used for the long-range dispersion interaction: @@ -96,7 +96,7 @@ In ADMP, this long-range dispersion is computed using PME (*vida infra*), just a In the classical module, dispersions are treated as short-range interactions using standard cutoff scheme. -## Long-Range Interaction with PME +## 5.5 Long-Range Interaction with PME The long-range potential includes electrostatic, polarization, and dispersion (in ADMP) interactions. Taking charge-charge interaction as example, the interaction decays in the form of $O(\frac{1}{r})$, and its energy does not converge with the increase of cutoff distance. The multipole electrostatics and dispersion interactions also converge slow with respect to cutoff distance. We therefore use Particle Meshed Ewald(PME) method to calculate these interactions. @@ -141,7 +141,7 @@ where the user needs to specify the cutoff distance $r_c$ when building the neig In the current version, the dispersion PME calculator uses the same parameters as in electrostatic PME. -## Short-Range Interaction +## 5.6 Short-Range Interaction Short-range pair interaction refers to all interactions with the following form: @@ -161,13 +161,14 @@ v(r)=\frac{C^{12}}{r^{12}} $$ - Tang-Tonnies Damping: damping function for short-range electrostatic and dispersion energies. + $$ - f_n(r, \beta) = 1 - e^{-\beta r}\sum_{k=0}^n {\frac{(\beta r)^k}{k!}} +f_n(r,\beta)=1-e^{-\beta r} \sum_{k=0}^{n}\frac{(\beta r)^k}{k!} $$ In ADMP, the user can define a pairwise kernel function $f(dr)=f(dr, m, a_i,a_j,b_i,b_j,\dots)$ ($a_i, b_i$ are atomic parameters), then use `generate_pairwise_interaction` to raise the kernel function into an energy calculator (see details in ADMP manual). -## Combination Rule +## 5.7 Combination Rule For most traditional force fields, pairwise parameters between interacting particles are determined by atomic parameters. This mathematical relationship is called the combination rule. For example, in the calculation of LJ potential, the following combination rule may be used: @@ -181,7 +182,7 @@ $$ In ADMP module, we do not make any assumptions about the specific mathematical forms of the combination rule and $v(r)$. Users need to write them in the definition of the pairwise kernel function. -## Neighbor List +## 5.8 Neighbor List All DMFF real space calculations depends on neighbor list (or "pair list" as we sometimes call in DMFF). Its purpose is to keep a record of all the "neighbors" within a certain distance of the central atom, thus avoiding to go over all pairs explicitly. @@ -189,7 +190,7 @@ In DMFF, we use external code ([jax-md](https://github.com/google/jax-md)) to bu Since the pair list only provides atom **id** information, it does not take part in the differentiation process, so it can be fed in as a normal numpy array (instead of a jax numpy array). -## Topological scaling +## 5.9 Topological scaling In order to avoid double-counting with the bonding term, we often need to scale the non-bonding interactions between two atoms that are topologically connected. The scaling factor depends on the topological distance between the two atoms. We define two atoms separated by one bond as "1-2" interaction, and those separated by two bonds as "1-3" interaction, and so on. For example, in the OPLS-AA force field, all "1-2" nonbonding interactions are turned off completely, while all "1-3" non-bonding interactions are scaled by 50%. DMFF supports such feature, and important variables related to topological scaling include: @@ -200,34 +201,36 @@ In order to avoid double-counting with the bonding term, we often need to scale - `pScales`/`dScales`: similar to `mScales`, but only related to polarizable calculations. They are scaling factors for induced-perm and induced-induced interactions, respectively. -## General Many-Body Interactions (such as ML force field) +## 5.10 General Many-Body Interactions - TODO: +(such as ML force field) TBA -## Bonding Interaction +## 5.11 Bonded Interaction Intramolecular bonding interactions refer to all interactions that depend on internal coordinates (IC), such as bonds, angles, and dihedrals, etc. - * Harmonic Bonding Terms ++ Harmonic Bonding Terms + The definition of the bonding term in DMFF is the same as in OpenMM. For each bond, we have: + $$ E=\frac{1}{2}k(x-x_0)^2 $$ - Note prefactor $1/2$ before the force constant. - * Harmonic Angle Terms - we have: ++ Harmonic Angle Terms + $$ E=\frac{1}{2} k\left(\theta-\theta_{0}\right)^{2} $$ - * Dihedral Terms ++ Dihedral Terms + 1. Proper dihedral 2. Improper dihedral - * Multi IC coupling term ++ Multi IC coupling term -## Typification +## 5.12 Typification Before energy calculation, atomic and IC parameters (such as charge, multipole moment, dispersion coefficient, polarizability, force constant of each bond and angle, etc.) need to be assigned first. @@ -237,7 +240,7 @@ In DMFF, the input parameters that need to be optimized are called **force field The design of the high-level DMFF API is based on the existing framework of OpenMM. DMFF needs to keep the derivation chain uninterrupted when dispatching the force field params into atomic params. Therefore, maintaining the basic design logic of OpenMM, we rewrite the typification part of OpenMM using Jax. Briefly speaking, OpenMM/DMFF requires the users to clearly define the type of each atom in each residue and the connection mode between atoms in residue templates. Then the residue templates are used to match the PDB file to typify the whole system. See the following [documents](../dev_guide/arch.MD) for details. -## References +## 5.13 References 1. [Anthony's book](https://oxford.universitypressscholarship.com/view/10.1093/acprof:oso/9780199672394.001.0001/acprof-9780199672394) 2. [The Multipolar Ewald paper in JCTC: J. Chem. Theory Comput. 2015, 11, 2, 436–450](https://pubs.acs.org/doi/abs/10.1021/ct5007983) diff --git a/docs/user_guide/usage.md b/docs/user_guide/usage.md index 8f6b00a82..4eeefbf34 100644 --- a/docs/user_guide/usage.md +++ b/docs/user_guide/usage.md @@ -1,4 +1,5 @@ # 3. Basic usage +This chapter will introduce some basic usage of DMFF. All scripts can be found in `examples/` directory in which Jupyter notebook-based demos are provided. ## 3.1 Compute energy DMFF uses OpenMM to parse input files, including coordinates file, topology specification file and force field parameter file. Then, the core class `Hamiltonian` inherited from `openmm.ForceField` will be initialized and the method `createPotential` will be called to create differentiable potential energy functions for different energy terms. Take parametrzing an organic moleclue with GAFF2 force field as an example: ```python @@ -81,7 +82,7 @@ force = -pos_grad_func(positions, box, pairs, params) ## 3.3 Compute parametric gradients Similarly, the derivative of energy with regard to force field parameters can also be computed easily. -``` +```python param_grad_func = jax.grad(nbfunc, argnums=-1) pgrad = param_grad_func(positions, box, pairs, params) print(pgrad["NonbondedForce"]["charge"]) @@ -103,3 +104,102 @@ print(pgrad["NonbondedForce"]["charge"]) 485.1427 512.1267 558.55896 560.4667 562.812 333.74194 ] ``` + +## 3.4 Parametrize molecules with SMIRKS-based force field +### 3.4.1 Background +Besides atom-typing based methods, DMFF also supports assigning force field parameters with [SMIRKS](https://www.daylight.com/dayhtml/doc/theory/theory.smirks.html). SMIRKS is an extenstion of [SMARTS](https://www.daylight.com/dayhtml/doc/theory/theory.smarts.html) language which allows users not only to specify chemical substructures with certain patterns, but also to numerically tag the matched atoms for assigning parameters. This approach avoid the duplicate atom-typing definition process, which enables new parameters to be easily introduced to existing force field parameters sets. [OpenFF](https://github.com/openforcefield/openff-toolkit)[[1-2]](#sminorff) series are examples of SMIRKS-based force fields for organic molecules. + + +