# Classical Force Field in DMFF

DMFF implements classical molecular mechanics force fields of the following form:

$$\begin{align*}
    V(\mathbf{R}) &= V_{\mathrm{bond}} + V_{\mathrm{angle}} + V_\mathrm{torsion} + V_\mathrm{vdW} + V_\mathrm{Coulomb} \\
    &=  \sum_{\mathrm{bonds}}\frac{1}{2}k_b(r - r_0)^2 + \sum_{\mathrm{angles}}\frac{1}{2}k_\theta (\theta -\theta_0)^2 + \sum_{\mathrm{torsion}}\sum_{n=1}^4 V_n[1+\cos(n\phi - \phi_s)] \\
    &\quad+ \sum_{ij}4\varepsilon_{ij}\left[\left(\frac{\sigma_{ij}}{r_{ij}}\right)^{12} - \left(\frac{\sigma_{ij}}{r_{ij}}\right)^6\right] + \sum_{ij}\frac{q_iq_j}{4\pi\varepsilon_0r_{ij}}
\end{align*}$$
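To make the functional form concrete, here is a minimal sketch of each term in plain JAX, independent of DMFF. The function names are hypothetical. Each function evaluates a single bond, angle, dihedral, or nonbonded pair in nm and kJ/mol, and the Coulomb prefactor $1/(4\pi\varepsilon_0)$ is taken as 138.935456 kJ·mol⁻¹·nm·e⁻².

```python
import jax.numpy as jnp

# 1/(4*pi*eps0) in kJ/mol * nm / e^2 (assumes units of nm, kJ/mol and elementary charge)
ONE_4PI_EPS0 = 138.935456

def bond_energy(r, k_b, r0):
    # Harmonic bond: 1/2 k_b (r - r0)^2
    return 0.5 * k_b * (r - r0) ** 2

def angle_energy(theta, k_theta, theta0):
    # Harmonic angle: 1/2 k_theta (theta - theta0)^2
    return 0.5 * k_theta * (theta - theta0) ** 2

def torsion_energy(phi, v_n, phi_s):
    # Fourier series: sum_{n=1}^{4} V_n [1 + cos(n phi - phi_s)]
    # v_n is a length-4 array indexed by periodicity n = 1..4; phi_s may be scalar or length 4
    n = jnp.arange(1, 5)
    return jnp.sum(v_n * (1.0 + jnp.cos(n * phi - phi_s)))

def lj_energy(r, eps, sigma):
    # 12-6 Lennard-Jones: 4 eps [(sigma/r)^12 - (sigma/r)^6]
    sr6 = (sigma / r) ** 6
    return 4.0 * eps * (sr6 ** 2 - sr6)

def coulomb_energy(r, qi, qj):
    # Point-charge Coulomb: q_i q_j / (4 pi eps0 r)
    return ONE_4PI_EPS0 * qi * qj / r
```

As a quick check of the Lennard-Jones form, `lj_energy` reaches its minimum value of $-\varepsilon$ at $r = 2^{1/6}\sigma$.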

## Import necessary packages

```python
import jax
import jax.numpy as jnp
import openmm.app as app
import openmm.unit as unit
from dmff import Hamiltonian, NeighborList
```

## Compute energy

DMFF uses **OpenMM** to parse input files, including coordinate files and topology specification files. The `Hamiltonian` class, which inherits from `openmm.app.ForceField`, is initialized with force field XML files and parses their parameters. As an example, consider parametrizing an organic molecule with the GAFF2 force field using the following input files:

- `lig-top.xml`: defines the bond connections (topology). It is not needed if this information is already given in the PDB file through `CONECT` records.
- `gaff-2.11.xml`: GAFF2 force field parameters for bonds, angles, torsions, and van der Waals interactions.
- `lig-prm.xml`: atomic charges.
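If the PDB file lacks `CONECT` records, the topology file supplies the bonds for each residue. The sketch below assumes the residue bond-definition format read by OpenMM's `Topology.loadBondDefinitions`: a `<Residues>` root containing `<Residue name=...>` elements with `<Bond from=... to=...>` children. The helper `make_bond_definitions`, the residue name `MOL`, and the atom names are hypothetical placeholders. They do not reproduce the contents of the actual `lig-top.xml`.

```python
import xml.etree.ElementTree as ET

def make_bond_definitions(resname, bonds):
    """Build an OpenMM-style bond-definition XML string for one residue.

    resname: residue name as it appears in the PDB file (hypothetical here)
    bonds:   list of (atom_name_1, atom_name_2) tuples
    """
    root = ET.Element("Residues")
    residue = ET.SubElement(root, "Residue", name=resname)
    for a1, a2 in bonds:
        # "from" is a Python keyword, so attributes are passed as a dict
        ET.SubElement(residue, "Bond", {"from": a1, "to": a2})
    return ET.tostring(root, encoding="unicode")

# Hypothetical three-atom fragment with bonds C1-C2 and C2-O1
xml_text = make_bond_definitions("MOL", [("C1", "C2"), ("C2", "O1")])
print(xml_text)
```

Writing `xml_text` to a file and passing that file to `app.Topology.loadBondDefinitions` before loading the PDB would let OpenMM add these bonds to every residue named `MOL`.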

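```python
# Register the ligand's bond definitions before reading the PDB file
app.Topology.loadBondDefinitions("lig-top.xml")
pdb = app.PDBFile("lig.pdb")
# Build the Hamiltonian from the GAFF2 parameters and the ligand charges
ff = Hamiltonian("gaff-2.11.xml", "lig-prm.xml")
```

```
Generator for HarmonicAngleForce is not implemented.
Generator for PeriodicTorsionForce is not implemented.
Generator for NonbondedForce is not implemented.
```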


The method `Hamiltonian.createJaxPotential` is then called to create differentiable potential energy functions, one for each energy term.

```python
potentials = ff.createJaxPotential(
    pdb.topology,
    nonbondedMethod=app.NoCutoff
)
# One differentiable energy function per force term, keyed by force name
for k in potentials.dmff_potentials.keys():
    pot = potentials.dmff_potentials[k]
    print(pot)
```

```
<function HarmonicBondGenerator.createPotential.<locals>.potential_fn at 0x151d319d0>
```
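Each entry of `dmff_potentials` is keyed by force name and, as used later in this notebook, is called as `fn(positions, box, pairs, params)` with the full parameter set. The toy sketch below mimics that layout in plain JAX to show why a shared signature is convenient. The two term functions, the force names `ToyBondForce` and `ToyLJForce`, and the parameter names are hypothetical stand-ins, not DMFF's generators.

```python
import jax.numpy as jnp

def toy_bond_fn(positions, box, pairs, params):
    # Harmonic energy for one hypothetical bond between atoms 0 and 1
    p = params["ToyBondForce"]
    r = jnp.linalg.norm(positions[1] - positions[0])
    return 0.5 * p["k"] * (r - p["length"]) ** 2

def toy_lj_fn(positions, box, pairs, params):
    # Lennard-Jones energy summed over the (i, j) rows of `pairs`
    p = params["ToyLJForce"]
    i, j = pairs[:, 0], pairs[:, 1]
    r = jnp.linalg.norm(positions[j] - positions[i], axis=-1)
    sr6 = (p["sigma"] / r) ** 6
    return jnp.sum(4.0 * p["eps"] * (sr6 ** 2 - sr6))

# Same layout as potentials.dmff_potentials: force name -> energy function
toy_potentials = {"ToyBondForce": toy_bond_fn, "ToyLJForce": toy_lj_fn}

positions = jnp.array([[0.0, 0.0, 0.0], [0.12, 0.0, 0.0], [0.5, 0.0, 0.0]])
box = 10.0 * jnp.eye(3)             # unused by these toy terms
pairs = jnp.array([[0, 2], [1, 2]])  # nonbonded pairs (atom indices)
params = {
    "ToyBondForce": {"k": 1.0e5, "length": 0.1},
    "ToyLJForce": {"eps": 0.5, "sigma": 0.3},
}

# Every term is called the same way, so callers can loop over them generically
for name, fn in toy_potentials.items():
    print(name, fn(positions, box, pairs, params))
```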


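The force field parameters, held by the force generators, can be collected as a nested Python dict with `ff.getParameters()`. The dict is keyed first by force name (e.g. `'HarmonicBondForce'`) and then by parameter name (e.g. `'length'`).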

```python
params = ff.getParameters()
bondparam = params['HarmonicBondForce']
bondparam["length"]  # equilibrium bond lengths (nm)
```

```
Array([0.09572, 0.15136, 0.2542 , 0.1787 , 0.1893 , 0.1946 , 0.1978 ,
       0.1908 , 0.1885 , 0.1937 , 0.186  , 0.2038 , 0.1873 , 0.1952 ,
       0.1926 , 0.2002 , 0.1944 , 0.2101 , 0.18   , 0.1866 , 0.1887 ,
       0.221  , 0.2231 , 0.2171 , 0.2196 , 0.222  , 0.2341 , 0.2214 ,
       0.2209 , 0.2203 , 0.1198 , 0.1307 , 0.1467 , 0.144  , 0.1315 ,
       0.1216 , 0.1219 , 0.1631 , 0.1445 , 0.127  , 0.1067 , 0.106  ,
       0.1989 , 0.1153 , 0.1197 , 0.1347 , 0.1417 , 0.133  , 0.1362 ,
       0.1202 , 0.1202 , 0.1342 , 0.1405 , 0.1172 , 0.1326 , 0.1318 ,
       0.177  , 0.179  , 0.179  , 0.1753 , 0.1595 , 0.1603 , 0.1746 ,
       0.1722 , 0.168  , 0.169  , 0.1334 , 0.151  , 0.1385 , 0.1359 ,
       0.1359 , 0.1346 , 0.1346 , 0.1731 , 0.1324 , 0.1485 , 0.1511 ,
       0.1339 , 0.1087 , 0.1091 , 0.1088 , 0.1087 , 0.1083 , 0.217  ,
       0.1306 , 0.1282 , 0.134  , 0.1399 , 0.1512 , 0.1401 , 0.1313 ,
       0.1313 , 0.1292 , 0.1292 , 0.1387 , 0.1448 , 0.1225 , 0.1339 ,
       0.136  , 0.16
```
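Because the parameters form a nested mapping of arrays, JAX can treat them as a pytree. This is what allows them to be passed to `jax.grad` later on. The sketch below uses a plain nested dict with hypothetical values, and assumes DMFF's parameter set behaves like one for these operations. It shows a typical manipulation: applying a gradient-descent update leaf by leaf with `jax.tree_util.tree_map` while keeping the nesting intact.

```python
import jax
import jax.numpy as jnp

# Hypothetical parameter set with the same nesting as ff.getParameters():
# force name -> parameter name -> array
params = {
    "HarmonicBondForce": {
        "length": jnp.array([0.09572, 0.15136]),
        "k": jnp.array([4.0e5, 2.5e5]),
    }
}

# The nested dict flattens into its array leaves
leaves = jax.tree_util.tree_leaves(params)
print(len(leaves))

# Hypothetical gradient with the same structure (in practice, from jax.grad)
grads = jax.tree_util.tree_map(jnp.ones_like, params)

# One gradient-descent step applied leaf by leaf
lr = 1e-3
updated = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
print(updated["HarmonicBondForce"]["length"])
```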

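Each generated function takes **coordinates, box, pairs**, and force field parameters as inputs. `pairs` is an integer array in which each row specifies a pair of atoms considered neighbors within the cutoff radius `rcut`. It can be computed with the `dmff.NeighborList` class, which is built on `jax_md`.

Each potential energy function returns the energy as a scalar in kJ/mol.

`NeighborList` also needs the covalent map stored in `potentials.meta["cov_map"]`. Judging from its printed values, this is an N×N integer matrix whose entry (i, j) counts the bonds separating atoms i and j, with 0 on the diagonal and for atoms that are not connected within the tracked range.
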
```python
potentials.meta["cov_map"]
```

```
Array([[0, 1, 2, ..., 0, 0, 2],
       [1, 0, 1, ..., 0, 0, 3],
       [2, 1, 0, ..., 0, 0, 4],
       ...,
       [0, 0, 0, ..., 0, 2, 0],
       [0, 0, 0, ..., 2, 0, 0],
       [2, 3, 4, ..., 0, 0, 0]], dtype=int32)
```
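Conceptually, the neighbor list enumerates all atom pairs closer than the cutoff. The brute-force O(N²) sketch below stands in for that idea. It is not `dmff.NeighborList`, which relies on `jax_md`. The helper `brute_force_pairs` is hypothetical. It applies the minimum-image convention for an orthorhombic box and appends `cov_map[i, j]` as a third column. That column layout is an assumption, made to mirror the three-column rows of `nbList.pairs` printed in the next cell. NumPy is used because boolean masking produces a data-dependent shape, which does not fit inside `jax.jit`.

```python
import numpy as np

def brute_force_pairs(positions, box, rcut, cov_map):
    """Return an (M, 3) int array of rows (i, j, cov_map[i, j]) with i < j and
    minimum-image distance below rcut. Assumes an orthorhombic box given as a
    3x3 matrix whose diagonal holds the box lengths."""
    positions = np.asarray(positions)
    lengths = np.diag(np.asarray(box))
    n = positions.shape[0]
    i, j = np.triu_indices(n, k=1)          # every unique pair with i < j
    d = positions[j] - positions[i]
    d -= lengths * np.round(d / lengths)    # minimum-image convention
    r = np.linalg.norm(d, axis=1)
    keep = r < rcut
    i, j = i[keep], j[keep]
    return np.stack([i, j, np.asarray(cov_map)[i, j]], axis=1)

# Toy system: atom 2 sits near atom 0 only through the periodic boundary
pos = np.array([[0.0, 0.0, 0.0], [0.15, 0.0, 0.0], [9.9, 0.0, 0.0], [5.0, 5.0, 5.0]])
box = np.diag([10.0, 10.0, 10.0])
cov = np.array([[0, 1, 2, 0], [1, 0, 1, 0], [2, 1, 0, 0], [0, 0, 0, 0]])
print(brute_force_pairs(pos, box, rcut=1.0, cov_map=cov))
```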

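```python
# Atomic positions in nm and a 10 nm cubic box
positions = jnp.array(pdb.getPositions(asNumpy=True).value_in_unit(unit.nanometer))
box = jnp.array([
    [10.0, 0.0, 0.0],
    [0.0, 10.0, 0.0],
    [0.0, 0.0, 10.0]
])
# Build the neighbor list with a 4 nm cutoff and the covalent map
nbList = NeighborList(box, rcut=4, cov_map=potentials.meta["cov_map"])
nbList.allocate(positions)
pairs = nbList.pairs
# Evaluate the harmonic bond energy (kJ/mol)
nbfunc = potentials.dmff_potentials['HarmonicBondForce']
energy = nbfunc(positions, box, pairs, params)
print(energy)
print(pairs)
```

```
174.16698
[[ 0  1  1]
 [ 0  2  2]
 [ 0  3  1]
 ...
 [66 66  0]
 [66 66  0]
 [66 66  0]]
```

Each printed row of `pairs` holds two atom indices followed by a third value that matches the corresponding `cov_map` entry (for example, `[0 2 2]` pairs atoms separated by two bonds). Rows such as `[66 66 0]` point one index past the last atom of this 66-atom system and appear to be padding.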


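Instead of separate functions for individual energy terms, you can also obtain a single function for the total potential energy, together with the full parameter set. In this run the total equals the bond energy computed above, because `HarmonicBondForce` is the only term whose generator is implemented (see the warnings printed when the force field was loaded).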

```python
efunc = potentials.getPotentialFunc()
params = ff.getParameters()
totene = efunc(positions, box, pairs, params)
totene
```

```
Array(174.16698, dtype=float32)
```
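A total-energy function can be thought of as the sum of the per-term functions, all sharing the `(positions, box, pairs, params)` signature. The sketch below composes two hypothetical toy terms this way and compiles the result with `jax.jit`. It illustrates the idea only and makes no claim about how `getPotentialFunc` is implemented. The helper `make_total_energy` and the parameter names `kb`, `r0`, `ka`, and `theta0` are hypothetical.

```python
import jax
import jax.numpy as jnp

def make_total_energy(term_fns):
    """Combine energy functions with a shared signature into one function."""
    def total(positions, box, pairs, params):
        return sum(fn(positions, box, pairs, params) for fn in term_fns)
    return total

# Hypothetical toy terms: a harmonic bond (atoms 0-1) and a harmonic angle (atoms 0-1-2)
def bond_term(positions, box, pairs, params):
    r = jnp.linalg.norm(positions[1] - positions[0])
    return 0.5 * params["kb"] * (r - params["r0"]) ** 2

def angle_term(positions, box, pairs, params):
    a = positions[0] - positions[1]
    b = positions[2] - positions[1]
    cos_t = jnp.dot(a, b) / (jnp.linalg.norm(a) * jnp.linalg.norm(b))
    theta = jnp.arccos(jnp.clip(cos_t, -1.0, 1.0))
    return 0.5 * params["ka"] * (theta - params["theta0"]) ** 2

total_fn = jax.jit(make_total_energy([bond_term, angle_term]))

positions = jnp.array([[0.11, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.1, 0.0]])
params = {"kb": 2.0e5, "r0": 0.1, "ka": 300.0, "theta0": jnp.pi / 2}
box = 10.0 * jnp.eye(3)                     # unused by these toy terms
pairs = jnp.zeros((0, 2), dtype=jnp.int32)  # unused by these toy terms
print(total_fn(positions, box, pairs, params))
```

Because every term has the same signature, adding a term to the total only requires appending another function to the list.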

## Compute forces and parametric gradients

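Use `jax.grad` to compute forces and parametric gradients automatically. Forces are the negative gradient of the energy with respect to the positions (`argnums=0`). Parametric gradients are taken with respect to the parameter set, which is the last argument (`argnums=-1`).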

```python
pos_grad_func = jax.grad(efunc, argnums=0)
force = -pos_grad_func(positions, box, pairs, params)  # F = -dV/dR, in kJ/mol/nm
force.shape
```

```
(66, 3)
```
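Since the force is the negative gradient of the energy, $\mathbf{F}_i = -\partial V/\partial \mathbf{r}_i$, comparing autodiff forces with central finite differences is a common sanity check. The sketch below runs this check on a hypothetical single harmonic bond rather than on the DMFF potential. The step size `h` and the tolerance are illustrative choices, picked to stay well above float32 round-off.

```python
import jax
import jax.numpy as jnp

def energy(positions, k=1000.0, r0=0.1):
    # Hypothetical single harmonic bond between atoms 0 and 1 (kJ/mol, nm)
    r = jnp.linalg.norm(positions[1] - positions[0])
    return 0.5 * k * (r - r0) ** 2

positions = jnp.array([[0.0, 0.0, 0.0], [0.12, 0.0, 0.0]])

# Autodiff forces: F = -dV/dR, same shape as positions (n_atoms, 3)
forces = -jax.grad(energy)(positions)

# Central finite differences, one Cartesian component at a time
h = 1e-3
fd = jnp.zeros_like(positions)
for a in range(positions.shape[0]):
    for d in range(3):
        step = jnp.zeros_like(positions).at[a, d].set(h)
        fd = fd.at[a, d].set(-(energy(positions + step) - energy(positions - step)) / (2 * h))

print(forces)
print(fd)
```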

```python
# allow_int=True lets jax.grad accept integer-valued leaves in params
# (their gradients come back with dtype float0)
param_grad_func = jax.grad(nbfunc, argnums=-1, allow_int=True)
pgrad = param_grad_func(positions, box, pairs, params)
pgrad["HarmonicBondForce"]["length"]
```

```
Array([    0.     ,     0.     ,     0.     ,     0.     ,     0.     ,
           0.     ,     0.     ,     0.     ,     0.     ,     0.     ,
           0.     ,     0.     ,     0.     ,     0.     ,     0.     ,
           0.     ,     0.     ,     0.     ,     0.     ,     0.     ,
           0.     ,     0.     ,     0.     ,     0.     ,     0.     ,
           0.     ,     0.     ,     0.     ,     0.     ,     0.     ,
           0.     ,     0.     ,     0.     ,     0.     ,     0.     ,
           0.     ,     0.     ,     0.     ,     0.     ,     0.     ,
           0.     ,     0.     ,     0.     ,     0.     ,     0.     ,
           0.     ,     0.     ,     0.     ,     0.     ,     0.     ,
           0.     ,     0.     ,     0.     ,     0.     ,     0.     ,
           0.     ,     0.     ,     0.     ,     0.     ,     0.     ,
           0.     ,     0.     ,     0.     ,     0.     ,     0.     ,
           0.     ,     0.     ,     0.     ,     0.     ,     0
```
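Most of the printed gradient entries are exactly zero. One plausible reading, stated here as an assumption, is that the `length` array holds one value per bond type defined in `gaff-2.11.xml`. Entries for types that no bond in the ligand uses cannot affect the energy, so their gradient is zero. The toy sketch below reproduces this effect with a hypothetical type-indexed harmonic bond term, in which each bond looks up its parameters through an integer type index. The type table and geometry are invented for illustration.

```python
import jax
import jax.numpy as jnp

# Hypothetical per-type parameters: 4 bond types, but only types 0 and 2 are used
params = {"length": jnp.array([0.10, 0.15, 0.12, 0.14]),
          "k": jnp.array([3.0e5, 2.5e5, 2.8e5, 2.6e5])}
bond_atoms = jnp.array([[0, 1], [1, 2]])  # atom index pairs
bond_types = jnp.array([0, 2])            # type index of each bond

def bond_energy(positions, params):
    r = jnp.linalg.norm(positions[bond_atoms[:, 1]] - positions[bond_atoms[:, 0]], axis=1)
    r0 = params["length"][bond_types]     # gather per-bond values from the type table
    k = params["k"][bond_types]
    return jnp.sum(0.5 * k * (r - r0) ** 2)

positions = jnp.array([[0.0, 0.0, 0.0], [0.11, 0.0, 0.0], [0.11, 0.13, 0.0]])
pgrad = jax.grad(bond_energy, argnums=1)(positions, params)
print(pgrad["length"])  # entries 1 and 3 (unused types) are exactly zero
```

The gradient of a gather is a scatter-add, so only the type entries that some bond actually reads receive a nonzero contribution.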