Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -778,4 +778,7 @@ FodyWeavers.xsd
.vscode/**

# acpype cache
*.acpype/
*.acpype/

*/_date.py
*/_version.py
37 changes: 20 additions & 17 deletions dmff/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from collections import defaultdict
import xml.etree.ElementTree as ET
from copy import deepcopy
import warnings

import numpy as np
import jax.numpy as jnp
Expand Down Expand Up @@ -1189,10 +1190,11 @@ def __init__(self, hamiltonian):
def registerBondType(self, bond):
typetxt = findAtomTypeTexts(bond, 2)
types = self.ff._findAtomTypes(bond, 2)
self.types.append(types)
self.typetexts.append(typetxt)
self.params["k"].append(float(bond["k"]))
self.params["length"].append(float(bond["length"])) # length := r0
if None not in types:
self.types.append(types)
self.typetexts.append(typetxt)
self.params["k"].append(float(bond["k"]))
self.params["length"].append(float(bond["length"])) # length := r0

@staticmethod
def parseElement(element, hamiltonian):
Expand Down Expand Up @@ -1287,9 +1289,10 @@ def __init__(self, hamiltonian):

def registerAngleType(self, angle):
types = self.ff._findAtomTypes(angle, 3)
self.types.append(types)
self.params["k"].append(float(angle["k"]))
self.params["angle"].append(float(angle["angle"]))
if None not in types:
self.types.append(types)
self.params["k"].append(float(angle["k"]))
self.params["angle"].append(float(angle["angle"]))

@staticmethod
def parseElement(element, hamiltonian):
Expand All @@ -1302,8 +1305,12 @@ def parseElement(element, hamiltonian):
<\HarmonicAngleForce>

"""
generator = HarmonicAngleJaxGenerator(hamiltonian)
hamiltonian.registerGenerator(generator)
existing = [f for f in hamiltonian._forces if isinstance(f, HarmonicAngleJaxGenerator)]
if len(existing) == 0:
generator = HarmonicAngleJaxGenerator(hamiltonian)
hamiltonian.registerGenerator(generator)
else:
generator = existing[0]
for angletype in element.findall("Angle"):
generator.registerAngleType(angletype.attrib)

Expand Down Expand Up @@ -1342,7 +1349,7 @@ def createForce(self, system, data, nonbondedMethod, nonbondedCutoff, args):
n_angles += 1
break
if not ifFound:
print(
warnings.warn(
"No parameter for angle %i - %i - %i" % (idx1, idx2, idx3)
)

Expand Down Expand Up @@ -1994,7 +2001,7 @@ def renderXML(self):

class NonbondJaxGenerator:

SCALETOL = 1e-5
SCALETOL = 1e-3

def __init__(self, hamiltionian, coulomb14scale, lj14scale):

Expand All @@ -2019,7 +2026,8 @@ def __init__(self, hamiltionian, coulomb14scale, lj14scale):
def registerAtom(self, atom):
# use types in nb cards or resname+atomname in residue cards
types = self.ff._findAtomTypes(atom, 1)[0]
self.types.append(types)
if None not in types:
self.types.append(types)

for key in ["sigma", "epsilon", "charge"]:
if key not in self.useAttributeFromResidue:
Expand Down Expand Up @@ -2065,11 +2073,6 @@ def parseElement(element, ff):

generator.n_atoms = len(element.findall("Atom"))

# jax it!
for k in generator.params.keys():
generator.params[k] = jnp.array(generator.params[k])
generator.types = np.array(generator.types)

def createForce(self, system, data, nonbondedMethod, nonbondedCutoff, args):
methodMap = {
app.NoCutoff: "NoCutoff",
Expand Down
5 changes: 0 additions & 5 deletions dmff/classical/inter.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,6 @@ def get_LJ_energy(dr_vec, sig, eps, box):
if self.ifPBC:
dr_vec = v_pbc_shift(dr_vec, box, jnp.linalg.inv(box))
dr_norm = jnp.linalg.norm(dr_vec, axis=1)
if not self.ifNoCut:
msk = dr_norm <= self.r_cut
sig = sig[msk]
eps = eps[msk]
dr_norm = dr_norm[msk]

dr_inv = 1.0 / dr_norm
sig_dr = sig * dr_inv
Expand Down
221 changes: 221 additions & 0 deletions examples/classical/demo.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,221 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Classical Force Field in DMFF"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"DMFF implements classcial molecular mechanics force fields with the following forms:\n",
"\n",
"$$\\begin{align*}\n",
" V(\\mathbf{R}) &= V_{\\mathrm{bond}} + V_{\\mathrm{angle}} + V_\\mathrm{torsion} + V_\\mathrm{vdW} + V_\\mathrm{Coulomb} \\\\\n",
" &= \\sum_{\\mathrm{bonds}}\\frac{1}{2}k_b(r - r_0)^2 + \\sum_{\\mathrm{angles}}\\frac{1}{2}k_\\theta (\\theta -\\theta_0)^2 + \\sum_{\\mathrm{torsion}}\\sum_{n=1}^4 V_n[1+\\cos(n\\phi - \\phi_s)] \\\\\n",
" &\\quad+ \\sum_{ij}4\\varepsilon_{ij}\\left[\\left(\\frac{\\sigma_{ij}}{r_{ij}}\\right)^{12} - \\left(\\frac{\\sigma_{ij}}{r_{ij}}\\right)^6\\right] + \\sum_{ij}\\frac{q_iq_j}{4\\pi\\varepsilon_0r_{ij}}\n",
"\\end{align*}$$"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Import necessary packages"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import jax\n",
"import jax.numpy as jnp\n",
"import openmm.app as app\n",
"import openmm.unit as unit\n",
"from dmff import Hamiltonian, NeighborList"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Compute energy"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"DMFF uses **OpenMM** to parse input files, including coordinates files, topology specification files. Class `Hamiltonian` inherited from `openmm.ForceField` will be initialized and used to parse force field parameters in XML format. Take parametrzing an organic moleclue with GAFF2 force field as an example.\n",
"\n",
"- `lig_top.xml`: Define bond connections (topology). Not necessary if such information is specified in pdb with `CONNECT` keyword.\n",
"- `gaff-2.11.xml`: GAFF2 force field parameters: bonds, angles, torsions and vdW params\n",
"- `lig-prm.xml`: Atomic charges"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"app.Topology.loadBondDefinitions(\"lig-top.xml\")\n",
"pdb = app.PDBFile(\"lig.pdb\")\n",
"ff = Hamiltonian(\"gaff-2.11.xml\", \"lig-prm.xml\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The method `Hamiltonian.createPotential` will be called to create differentiable potential energy functions for different energy terms. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"potentials = ff.createPotential(\n",
" pdb.topology,\n",
" nonbondedMethod=app.NoCutoff\n",
")\n",
"for pot in potentials:\n",
" print(pot)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The force field parameters are stored as a Python dict in the `param` attribute of force generators."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"nbparam = ff.getGenerators()[3].params\n",
"nbparam[\"charge\"] # also \"epsilon\", \"sigma\" etc. keys"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Each generated function will read **coordinates, box, pairs** and force field parameters as inputs. `pairs` is a integer array in which each row specifying atoms condsidered as neighbors within rcut. This can be calculated with `dmff.NeighborList` class which is supported by `jax_md`.\n",
"\n",
"The potential energy function will give energy (a scalar, in kJ/mol) as output:\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"positions = jnp.array(pdb.getPositions(asNumpy=True).value_in_unit(unit.nanometer))\n",
"box = jnp.array([\n",
" [10.0, 0.0, 0.0], \n",
" [0.0, 10.0, 0.0],\n",
" [0.0, 0.0, 10.0]\n",
"])\n",
"nbList = NeighborList(box, rc=4)\n",
"nbList.allocate(positions)\n",
"pairs = nbList.pairs\n",
"nbfunc = potentials[-1]\n",
"energy = nbfunc(positions, box, pairs, ff.getGenerators()[-1].params)\n",
"print(energy)\n",
"print(pairs)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"You can also obtain the whole potential energy function and force field parameter set, instead of seperated functions for different energy terms."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"efunc = ff.getPotentialFunc()\n",
"params = ff.getParameters()\n",
"totene = efunc(positions, box, pairs, params)\n",
"totene"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Compute forces and parametric gradients"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Use `jax.grad` to compute forces and parametric gradients automatically"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"pos_grad_func = jax.grad(efunc, argnums=0)\n",
"force = -pos_grad_func(positions, box, pairs, params)\n",
"force.shape"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"param_grad_func = jax.grad(nbfunc, argnums=-1)\n",
"pgrad = param_grad_func(positions, box, pairs, nbparam)\n",
"pgrad[\"charge\"]"
]
}
],
"metadata": {
"interpreter": {
"hash": "44fe82502fda871be637af1aa98d2b3ddaac01204dd30f1519cbec4e95000815"
},
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.12"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
Loading