Merged
2 changes: 1 addition & 1 deletion .github/workflows/ut.yml
@@ -22,7 +22,7 @@ jobs:
$CONDA/bin/conda update -n base -c defaults conda
conda install pip
conda update pip
- conda install numpy openmm pytest -c conda-forge
+ conda install numpy openmm pytest rdkit biopandas openbabel -c conda-forge
pip install jax jax_md
pip install mdtraj==1.9.7 pymbar==4.0.1
- name: Install DMFF
26 changes: 24 additions & 2 deletions README.md
@@ -1,9 +1,20 @@
# DMFF

[![doi:10.26434/chemrxiv-2022-2c7gv](https://img.shields.io/badge/DOI-10.26434%2Fchemrxiv--2022--2c7gv-blue)](https://doi.org/10.26434/chemrxiv-2022-2c7gv)

## About DMFF

**DMFF** (**D**ifferentiable **M**olecular **F**orce **F**ield) is a JAX-based Python package that provides a fully differentiable implementation of molecular force field models. This project aims to establish an extensible codebase that minimizes the effort of force field parameterization and eases force and virial tensor evaluation for advanced, complex potentials (e.g., polarizable models with geometry-dependent atomic parameters). Currently, the project focuses mainly on molecular systems such as water, biological macromolecules (peptides, proteins, nucleic acids), organic polymers, and small organic molecules (organic electrolytes, drug-like molecules). We support both conventional point charge models (OPLS- and AMBER-like) and multipolar polarizable models (AMOEBA- and MPID-like). The entire project is backed by XLA through JAX, so it can be "jitted" and run on GPU devices much more efficiently than ordinary Python code.

The behavior of organic molecular systems (e.g., protein folding, polymer structure) is often determined by a complex interplay of many different types of interactions. Existing organic molecular force fields are mainly fitted empirically, and their performance relies heavily on error cancellation. As a result, their transferability and predictive power are insufficient. For new molecules, parameter fitting requires considerable manual intervention and can be quite cumbersome. To automate parametrization and make the models more robust, it is necessary to apply modern AI techniques to conventional force field development. This project serves that purpose by using automatic differentiation to build a codebase that makes it easier to incorporate modern AI optimization techniques. It also enables many exciting capabilities, including (but not limited to) hybrid machine learning/force field models and trajectory-based parameter optimization.

### License and credits

The project DMFF is licensed under [GNU LGPL v3.0](LICENSE). If you use this code in any future publications, please cite it as: `Wang X, Li J, Yang L, Chen F, Wang Y, Chang J, et al. DMFF: An Open-Source Automatic Differentiable Platform for Molecular Force Field Development and Molecular Dynamics Simulation. ChemRxiv. Cambridge: Cambridge Open Engage; 2022; This content is a preprint and has not been peer-reviewed.`

## User Guide

+ [1. Introduction](docs/user_guide/introduction.md)
@@ -18,9 +29,20 @@ The behavior of organic molecular systems (e.g., protein folding, polymer struct
+ [3. Coding conventions](docs/dev_guide/convention.md)
+ [4. Document writing](docs/dev_guide/write_docs.md)

## Modules
+ [1. ADMP](docs/modules/admp.md)
## Code Structure

The code is organized as follows:

+ `examples`: demos presented in Jupyter Notebook.
+ `docs`: documentation.
+ `package`: files for constructing packages or images, such as conda recipe and docker files.
+ `tests`: unit tests.
+ `dmff`: DMFF Python code.
+ `dmff/admp`: source code of automatic differentiable multipolar polarizable (ADMP) force field module.
+ `dmff/classical`: source code of classical force field module.
+ `dmff/common`: source code of common functions, such as neighbor list.
+ `dmff/generators`: source code of force generators.
+ `dmff/sgnn`: source code of the subgraph neural network force field model.

## Support and Contribution

178 changes: 176 additions & 2 deletions dmff/api.py
@@ -1,4 +1,5 @@
import linecache
from typing import Callable, Dict, Any

import numpy as np
import jax.numpy as jnp
@@ -82,7 +83,7 @@ def totalPE(positions, box, pairs, params):


class Hamiltonian(app.forcefield.ForceField):
- def __init__(self, *xmlnames):
+ def __init__(self, *xmlnames, **kwargs):
super().__init__(*xmlnames)
self._pseudo_ff = app.ForceField(*xmlnames)
# parse XML forcefields
@@ -104,6 +105,9 @@ def __init__(self, *xmlnames):
self.extractParameterTree()

# hook generators to self._forces
# use noOmmSys to disable all traditional openmm system
if kwargs.get("noOmmSys", False):
self._forces = []
for jaxGen in self._jaxGenerators:
self._forces.append(jaxGen)

@@ -184,6 +188,29 @@ def createPotential(self,
print(e)
pass

# virtual site
try:
addVsiteFunc = generator.getAddVsiteFunc()
self.setAddVirtualSiteFunc(addVsiteFunc)
vsiteObj = generator.getVsiteObj()
self.setVirtualSiteObj(vsiteObj)
except AttributeError as e:
pass

# covalent map
try:
cov_map = generator.covalent_map
self.setCovalentMap(cov_map)
except AttributeError as e:
pass

# topology matrix (for BCC usage)
try:
top_mat = generator.getTopologyMatrix()
self.setTopologyMatrix(top_mat)
except AttributeError as e:
pass

return potObj
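
The three `try`/`except AttributeError` blocks above treat each generator hook as optional. One caveat is that such a block also hides an `AttributeError` raised inside the hook itself. A minimal sketch of an alternative, assuming the goal is only to detect whether a hook exists: probe with `getattr` and a default. The generator classes and `collect_vsite_funcs` below are hypothetical stand-ins for illustration, not part of DMFF.

```python
class GeneratorWithVsite:
    """Hypothetical generator that exposes the optional virtual-site hook."""

    def getAddVsiteFunc(self):
        # Identity function standing in for a real virtual-site builder.
        return lambda pos, params: pos


class PlainGenerator:
    """Hypothetical generator that provides no optional hooks."""


class BrokenGenerator:
    """Hypothetical generator whose hook contains a bug."""

    def getAddVsiteFunc(self):
        return self.missing_attribute  # raises AttributeError when called


def collect_vsite_funcs(generators):
    """Collect hook results, skipping only generators that lack the hook."""
    funcs = []
    for gen in generators:
        getter = getattr(gen, "getAddVsiteFunc", None)
        if getter is None:
            continue  # hook not provided: this is the expected, silent case
        funcs.append(getter())  # errors raised inside the hook still surface
    return funcs


funcs = collect_vsite_funcs([GeneratorWithVsite(), PlainGenerator()])
print(len(funcs))
```

With this shape, `BrokenGenerator` raises instead of being skipped, which makes a faulty hook visible during development.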

def render(self, filename):
@@ -201,4 +228,151 @@ def update_iter(node, ref):
else:
node[key] = ref[key]

- update_iter(self.paramtree, paramtree)
+ update_iter(self.paramtree, paramtree)

def setCovalentMap(self, cov_map: jnp.ndarray):
self._cov_map = cov_map

def getCovalentMap(self) -> jnp.ndarray:
"""
Get covalent map
"""
if hasattr(self, "_cov_map"):
return self._cov_map
else:
raise DMFFException("Covalent map is not set.")

def getAddVirtualSiteFunc(self) -> Callable:
return self._add_vsite_coords

def setAddVirtualSiteFunc(self, func: Callable):
self._add_vsite_coords = func

def setVirtualSiteObj(self, vsite):
self._vsite = vsite

def getVirtualSiteObj(self):
return self._vsite

def setTopologyMatrix(self, top_mat):
self._top_mat = top_mat

def getTopologyMatrix(self):
return self._top_mat
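
Of the getters above, only `getCovalentMap` checks that its attribute has been set; the others raise a bare `AttributeError` if they are called before `createPotential` has populated them. A small sketch of applying the same guard uniformly, assuming that behaviour is wanted. `HookRegistry` is a hypothetical stand-in for the relevant part of `Hamiltonian`, and `RuntimeError` stands in for `DMFFException`.

```python
class HookRegistry:
    """Hypothetical stand-in for the setter/getter pairs on Hamiltonian."""

    def __init__(self):
        # Initialise to None so the getter can tell "unset" from "set".
        self._top_mat = None

    def setTopologyMatrix(self, top_mat):
        self._top_mat = top_mat

    def getTopologyMatrix(self):
        if self._top_mat is None:
            # RuntimeError is a placeholder for the project's own exception.
            raise RuntimeError("Topology matrix is not set.")
        return self._top_mat


registry = HookRegistry()
registry.setTopologyMatrix([[1.0], [-1.0]])
print(registry.getTopologyMatrix())
```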

def addVirtualSiteCoords(self, pos: jnp.ndarray, params: Dict[str, Any]) -> jnp.ndarray:
"""
Add coordinates for virtual sites

Parameters
----------
pos: jnp.ndarray
Coordinates without virtual sites
params: dict
Paramtree of hamiltonian, i.e. `dmff.Hamiltonian.paramtree`

Return
------
newpos: jnp.ndarray

Examples
--------
>>> import jax.numpy as jnp
>>> import openmm.app as app
>>> from rdkit import Chem
>>> from dmff import Hamiltonian
>>> pdb = app.PDBFile("tests/data/chlorobenzene.pdb")
>>> pos = jnp.array(pdb.getPositions(asNumpy=True)._value)
>>> mol = Chem.MolFromMolFile("tests/data/chlorobenzene.mol", removeHs=False)
>>> h = Hamiltonian("tests/data/cholorobenzene_vsite.xml")
>>> potObj = h.createPotential(pdb.topology, rdmol=mol)
>>> newpos = h.addVirtualSiteCoords(pos, h.paramtree)

"""
func = self.getAddVirtualSiteFunc()
newpos = func(pos, params)
return newpos

def addVirtualSiteToMol(self, rdmol, params):
"""
Add virtual sites to an rdkit.Chem.Mol object

Parameters
----------
rdmol: rdkit.Chem.Mol
Mol object to which virtual sites are added
params: dict
Paramtree of hamiltonian, i.e. `dmff.Hamiltonian.paramtree`

Return
------
newmol: rdkit.Chem.Mol
Mol object with virtual sites added

Examples
--------
>>> import jax.numpy as jnp
>>> import openmm.app as app
>>> from rdkit import Chem
>>> from dmff import Hamiltonian
>>> pdb = app.PDBFile("tests/data/chlorobenzene.pdb")
>>> mol = Chem.MolFromMolFile("tests/data/chlorobenzene.mol", removeHs=False)
>>> h = Hamiltonian("tests/data/cholorobenzene_vsite.xml")
>>> potObj = h.createPotential(pdb.topology, rdmol=mol)
>>> newmol = h.addVirtualSiteToMol(mol, h.paramtree)
"""
vsiteObj = self.getVirtualSiteObj()
newmol = vsiteObj.addVirtualSiteToMol(
rdmol,
params['NonbondedForce']['vsite_types'],
params['NonbondedForce']['vsite_distances']
)
return newmol

@staticmethod
def buildTopologyFromMol(rdmol, resname: str = "MOL") -> app.Topology:
"""
Build openmm.app.Topology from rdkit.Chem.Mol Object

Parameters
----------
rdmol: rdkit.Chem.Mol
Mol object
resname: str
Name of the added residue, default "MOL"

Return
------
top: `openmm.app.Topology`
Topology built based on the input rdkit Mol object
"""
from rdkit import Chem

top = app.Topology()
chain = top.addChain(0)
res = top.addResidue(resname, chain, "1", "")

atCount = {}
addedAtoms = []
for idx, atom in enumerate(rdmol.GetAtoms()):
symb = atom.GetSymbol().upper()
atCount.update({symb: atCount.get(symb, 0) + 1})
ele = app.Element.getBySymbol(symb)
atName = f'{symb}{atCount[symb]}'

addedAtom = top.addAtom(atName, ele, res, str(idx+1))
addedAtoms.append(addedAtom)

bondTypeMap = {
Chem.rdchem.BondType.SINGLE: app.Single,
Chem.rdchem.BondType.DOUBLE: app.Double,
Chem.rdchem.BondType.TRIPLE: app.Triple,
Chem.rdchem.BondType.AROMATIC: app.Aromatic
}
for bond in rdmol.GetBonds():
top.addBond(
addedAtoms[bond.GetBeginAtomIdx()],
addedAtoms[bond.GetEndAtomIdx()],
type=bondTypeMap.get(bond.GetBondType(), None)
)
return top
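
The atom-naming loop in `buildTopologyFromMol` can be tried without RDKit or OpenMM installed. The sketch below reproduces only the naming step (upper-cased element symbol plus a per-element running count); `name_atoms` is a hypothetical helper written for this illustration and takes plain symbol strings instead of RDKit atoms.

```python
def name_atoms(symbols):
    """Return names such as C1, C2, CL1 from a sequence of element symbols."""
    counts = {}
    names = []
    for symb in symbols:
        symb = symb.upper()  # mirrors atom.GetSymbol().upper()
        counts[symb] = counts.get(symb, 0) + 1
        names.append(f"{symb}{counts[symb]}")
    return names


print(name_atoms(["C", "C", "Cl", "H"]))  # ['C1', 'C2', 'CL1', 'H1']
```

Because the count is kept per element, names are unique within the single residue that the method creates.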
42 changes: 33 additions & 9 deletions dmff/classical/inter.py
@@ -1,4 +1,4 @@
- from typing import Iterable, Tuple
+ from typing import Iterable, Tuple, Optional

import jax.numpy as jnp
import numpy as np
@@ -130,10 +130,11 @@ def get_energy(box, epsilon, sigma, epsfix, sigfix):
class CoulNoCutoffForce:
# E=\frac{{q}_{1}{q}_{2}}{4\pi\epsilon_0\epsilon_1 r}

- def __init__(self, map_prm, epsilon_1=1.0) -> None:
+ def __init__(self, map_prm, epsilon_1=1.0, topology_matrix=None) -> None:

self.eps_1 = epsilon_1
self.map_prm = map_prm
self.top_mat = topology_matrix

def generate_get_energy(self):
def get_coul_energy(dr_vec, chrgprod, box):
@@ -145,7 +146,6 @@ def get_coul_energy(dr_vec, chrgprod, box):
return E

def get_energy(positions, box, pairs, charges, mscales):

pairs = pairs.at[:, :2].set(regularize_pairs(pairs[:, :2]))
mask = pair_buffer_scales(pairs[:, :2])
map_prm = jnp.array(self.map_prm)
@@ -163,9 +163,16 @@ def get_energy(positions, box, pairs, charges, mscales):

E_inter = get_coul_energy(dr_vec, chrgprod_scale, box)

- return jnp.sum(E_inter * mask)
-
- return get_energy
+ return jnp.sum(E_inter * mask)

def get_energy_bcc(positions, box, pairs, pre_charges, bcc, mscales):
charges = pre_charges + jnp.dot(self.top_mat, bcc).flatten()
return get_energy(positions, box, pairs, charges, mscales)

if self.top_mat is None:
return get_energy
else:
return get_energy_bcc
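
When a topology matrix is supplied, `get_energy_bcc` first builds the charges as `pre_charges + top_mat · bcc` and then delegates to `get_energy`. The sketch below works through that one step with plain Python lists in place of `jnp.dot`. The three-atom chain, the numeric values, and the +1/-1 sign convention of the matrix are assumptions made for illustration; the diff does not show how `getTopologyMatrix` builds its matrix.

```python
def apply_bcc(pre_charges, top_mat, bcc):
    """Return pre_charges + top_mat @ bcc, computed with plain lists."""
    corrections = [sum(t * b for t, b in zip(row, bcc)) for row in top_mat]
    return [q + dq for q, dq in zip(pre_charges, corrections)]


# Hypothetical chain A-B-C with one BCC parameter per bond.
# Assumed convention: +1 for the atom gaining charge, -1 for the atom losing it.
top_mat = [
    [1.0, 0.0],   # atom A takes part in bond A-B
    [-1.0, 1.0],  # atom B takes part in bonds A-B and B-C
    [0.0, -1.0],  # atom C takes part in bond B-C
]
pre_charges = [0.0, 0.0, 0.0]
bcc = [0.1, 0.2]

charges = apply_bcc(pre_charges, top_mat, bcc)
print(charges)
print(sum(charges))
```

Under the assumed convention each column of the matrix sums to zero, so the corrections move charge between bonded atoms without changing the total.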


class CoulReactionFieldForce:
@@ -177,6 +184,7 @@ def __init__(
epsilon_1=1.0,
epsilon_solv=78.5,
isPBC=True,
topology_matrix=None
) -> None:

self.r_cut = r_cut
@@ -186,6 +194,7 @@ def __init__(
self.eps_1 = epsilon_1
self.map_prm = map_prm
self.ifPBC = isPBC
self.top_mat = topology_matrix

def generate_get_energy(self):
def get_rf_energy(dr_vec, chrgprod, box):
@@ -204,7 +213,6 @@ def get_rf_energy(dr_vec, chrgprod, box):
return E

def get_energy(positions, box, pairs, charges, mscales):

pairs = pairs.at[:, :2].set(regularize_pairs(pairs[:, :2]))
mask = pair_buffer_scales(pairs[:, :2])

@@ -223,7 +231,14 @@ def get_energy(positions, box, pairs, charges, mscales):

return jnp.sum(E_inter * mask)

- return get_energy
def get_energy_bcc(positions, box, pairs, pre_charges, bcc, mscales):
charges = pre_charges + jnp.dot(self.top_mat, bcc).flatten()
return get_energy(positions, box, pairs, charges, mscales)

if self.top_mat is None:
return get_energy
else:
return get_energy_bcc


class CoulombPMEForce:
@@ -235,13 +250,15 @@ def __init__(
kappa: float,
K: Tuple[int, int, int],
pme_order: int = 6,
topology_matrix: Optional[jnp.ndarray] = None,
):
self.r_cut = r_cut
self.map_prm = map_prm
self.lmax = 0
self.kappa = kappa
self.K1, self.K2, self.K3 = K[0], K[1], K[2]
self.pme_order = pme_order
self.top_mat = topology_matrix
assert pme_order == 6, "PME order other than 6 is not supported"

def generate_get_energy(self):
@@ -283,4 +300,11 @@ def get_energy(positions, box, pairs, charges, mscales):
False,
)

- return get_energy
def get_energy_bcc(positions, box, pairs, pre_charges, bcc, mscales):
charges = pre_charges + jnp.dot(self.top_mat, bcc).flatten()
return get_energy(positions, box, pairs, charges, mscales)

if self.top_mat is None:
return get_energy
else:
return get_energy_bcc
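
All three `get_energy_bcc` wrappers make the energy depend on the bond charge corrections only through the charges. A short derivation of what that implies for gradients and for total charge follows. The symbols $T$, $b$, and $q^{\mathrm{pre}}$ are shorthand introduced here for `top_mat`, `bcc`, and `pre_charges`, and the zero column sum in the last line is an assumption about how the topology matrix is built, not something this diff shows.

```latex
% Charges as built in get_energy_bcc
q_i = q_i^{\mathrm{pre}} + \sum_k T_{ik}\, b_k

% Chain rule: gradient with respect to each BCC parameter
\frac{\partial E}{\partial b_k} = \sum_i T_{ik}\, \frac{\partial E}{\partial q_i}

% Total charge is unchanged if every column of T sums to zero (assumed)
\sum_i q_i = \sum_i q_i^{\mathrm{pre}} + \sum_k b_k \sum_i T_{ik}
           = \sum_i q_i^{\mathrm{pre}}
\quad \text{when } \sum_i T_{ik} = 0 \text{ for all } k
```

The second line is what automatic differentiation evaluates when the returned function is differentiated with respect to its `bcc` argument.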