# Classical Force Field in DMFF

DMFF implements classical molecular mechanics force fields of the following form:

$$\begin{align*}
    V(\mathbf{R}) &= V_{\mathrm{bond}} + V_{\mathrm{angle}} + V_\mathrm{torsion} + V_\mathrm{vdW} + V_\mathrm{Coulomb} \\
    &=  \sum_{\mathrm{bonds}}\frac{1}{2}k_b(r - r_0)^2 + \sum_{\mathrm{angles}}\frac{1}{2}k_\theta (\theta -\theta_0)^2 + \sum_{\mathrm{torsion}}\sum_{n=1}^4 V_n[1+\cos(n\phi - \phi_s)] \\
    &\quad+ \sum_{ij}4\varepsilon_{ij}\left[\left(\frac{\sigma_{ij}}{r_{ij}}\right)^{12} - \left(\frac{\sigma_{ij}}{r_{ij}}\right)^6\right] + \sum_{ij}\frac{q_iq_j}{4\pi\varepsilon_0r_{ij}}
\end{align*}$$
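
Each term above can be made concrete with a minimal pure-JAX sketch of the functional forms, evaluated for a single bond, angle, torsion, or atom pair. This is not DMFF's implementation. The function names are our own, and so is the constant `ONE_4PI_EPS0`, which is approximately 138.935 and represents $1/(4\pi\varepsilon_0)$ in kJ mol^-1 nm e^-2. Energies come out in kJ/mol when force constants and $\varepsilon$ are given in kJ/mol-based units, distances in nm, and charges in units of e.

```python
import jax.numpy as jnp

ONE_4PI_EPS0 = 138.935456  # 1/(4 pi eps0) in kJ mol^-1 nm e^-2 (approximate)


def bond_energy(r, k_b, r0):
    # Harmonic bond stretch: 1/2 k_b (r - r0)^2
    return 0.5 * k_b * (r - r0) ** 2


def angle_energy(theta, k_theta, theta0):
    # Harmonic angle bend: 1/2 k_theta (theta - theta0)^2
    return 0.5 * k_theta * (theta - theta0) ** 2


def torsion_energy(phi, v_n, phi_s):
    # Periodic torsion: sum over n = 1..len(v_n) of V_n [1 + cos(n phi - phi_s)]
    n = jnp.arange(1, v_n.shape[0] + 1)
    return jnp.sum(v_n * (1.0 + jnp.cos(n * phi - phi_s)))


def lj_energy(r, sigma, epsilon):
    # 12-6 Lennard-Jones: 4 eps [(sigma/r)^12 - (sigma/r)^6]
    sr6 = (sigma / r) ** 6
    return 4.0 * epsilon * (sr6 ** 2 - sr6)


def coulomb_energy(r, q_i, q_j):
    # Coulomb: q_i q_j / (4 pi eps0 r), with r in nm and charges in e
    return ONE_4PI_EPS0 * q_i * q_j / r


# Example: a cosine series with only the n = 3 term switched on
print(torsion_energy(jnp.pi / 3, jnp.array([0.0, 0.0, 1.5, 0.0]), 0.0))
```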

## Import necessary packages

```python
import jax
import jax.numpy as jnp
import openmm.app as app
import openmm.unit as unit
from dmff import Hamiltonian, NeighborList
```

## Compute energy

DMFF uses **OpenMM** to parse input files, such as coordinate files and topology specification files. The `Hamiltonian` class, which inherits from `openmm.app.ForceField`, is initialized with force field XML files and parses the parameters they contain. As an example, consider parametrizing an organic molecule with the GAFF2 force field, using the following files:

- `lig-top.xml`: Defines bond connections (topology). Not needed if the PDB file already specifies them with `CONECT` records.
- `gaff-2.11.xml`: GAFF2 force field parameters for bonds, angles, torsions, and vdW interactions
- `lig-prm.xml`: Atomic charges

```python
app.Topology.loadBondDefinitions("lig-top.xml")
pdb = app.PDBFile("lig.pdb")
ff = Hamiltonian("gaff-2.11.xml", "lig-prm.xml")
```

`Hamiltonian.createPotential` creates differentiable potential energy functions for the individual energy terms.

```python
potentials = ff.createPotential(
    pdb.topology,
    nonbondedMethod=app.NoCutoff
)
# One differentiable energy function per force term, keyed by force name
for name, pot in potentials.dmff_potentials.items():
    print(name, pot)
```

The force field parameters are held by the force generators. `Hamiltonian.getParameters()` collects them into a nested, dict-like container keyed first by force name and then by parameter name.

```python
params = ff.getParameters()
nbparam = params['NonbondedForce']
nbparam["charge"]  # also "epsilon", "sigma", etc. keys
```

Each generated function takes the **coordinates, box, pairs**, and force field parameters as inputs. `pairs` is an integer array in which each row identifies a pair of atoms considered neighbors within the cutoff distance. It can be computed with the `dmff.NeighborList` class, which relies on `jax_md`.
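
To give some intuition for what `pairs` contains, the brute-force sketch below lists every atom pair within a cutoff under the minimum-image convention. It is an O(N²) reference under stated assumptions, not how `dmff.NeighborList` works. It assumes an orthorhombic box and returns only two index columns, and the array DMFF returns may be laid out differently. The number of pairs depends on the data, so the final filtering is done in NumPy rather than under `jax.jit`. This is one reason `jax_md`-style neighbor lists allocate a fixed-capacity buffer up front.

```python
import jax.numpy as jnp
import numpy as np


def brute_force_pairs(positions, box, rc):
    """Return an (M, 2) array of atom index pairs (i < j) closer than rc.

    O(N^2) reference sketch; assumes an orthorhombic box given as a 3x3
    matrix of box vectors and applies the minimum-image convention.
    """
    n = positions.shape[0]
    lengths = jnp.diag(box)                        # box edge lengths (nm)
    dr = positions[:, None, :] - positions[None, :, :]
    dr = dr - lengths * jnp.round(dr / lengths)    # wrap to the nearest image
    dist = np.asarray(jnp.linalg.norm(dr, axis=-1))
    i, j = np.triu_indices(n, k=1)                 # each unordered pair once
    keep = dist[i, j] < rc                         # data-dependent size, so NumPy
    return np.stack([i[keep], j[keep]], axis=1)


box = jnp.eye(3) * 3.0
positions = jnp.array([
    [0.1, 0.1, 0.1],
    [0.3, 0.1, 0.1],
    [2.9, 0.1, 0.1],   # close to atoms 0 and 1 through the periodic boundary
    [1.5, 1.5, 1.5],
])
pairs = brute_force_pairs(positions, box, rc=0.5)
print(pairs)
```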

The potential energy function returns the energy as a scalar, in kJ/mol:

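```python
# Positions in nm as a JAX array
positions = jnp.array(pdb.getPositions(asNumpy=True).value_in_unit(unit.nanometer))
# Periodic box vectors (nm), one vector per row
box = jnp.array([
    [10.0, 0.0, 0.0],
    [0.0, 10.0, 0.0],
    [0.0, 0.0, 10.0]
])
# Neighbor list with a 4 nm cutoff; cov_map encodes the covalent (topological) connectivity
nbList = NeighborList(box, 4.0, potentials.meta["cov_map"])
nbList.allocate(positions)
pairs = nbList.pairs
# Energy of the NonbondedForce term only
nbfunc = potentials.dmff_potentials['NonbondedForce']
energy = nbfunc(positions, box, pairs, params)
print(energy)
print(pairs)
```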
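
The sketch below shows, under simplifying assumptions, what a nonbonded energy evaluated over such a pair list looks like. It combines the Lennard-Jones and Coulomb terms from the formula at the top with Lorentz-Berthelot mixing, $\sigma_{ij} = (\sigma_i+\sigma_j)/2$ and $\varepsilon_{ij} = \sqrt{\varepsilon_i\varepsilon_j}$, the combining rule used by AMBER-family force fields such as GAFF. It ignores exclusions, 1-4 scaling, and cutoff treatment. The flat `params` layout is also hypothetical, so this will not reproduce DMFF's `NonbondedForce` energy; it only illustrates the computation.

```python
import jax.numpy as jnp

ONE_4PI_EPS0 = 138.935456  # 1/(4 pi eps0) in kJ mol^-1 nm e^-2 (approximate)


def nonbonded_energy(positions, box, pairs, params):
    """LJ + Coulomb energy summed over an explicit (M, 2) pair list.

    Simplifying assumptions: orthorhombic box, minimum image,
    Lorentz-Berthelot mixing, no exclusions, no 1-4 scaling, no cutoff.
    """
    i, j = pairs[:, 0], pairs[:, 1]
    lengths = jnp.diag(box)
    dr = positions[i] - positions[j]
    dr = dr - lengths * jnp.round(dr / lengths)   # minimum image
    r = jnp.linalg.norm(dr, axis=-1)

    # Lorentz-Berthelot combining rules
    sigma = 0.5 * (params["sigma"][i] + params["sigma"][j])
    eps = jnp.sqrt(params["epsilon"][i] * params["epsilon"][j])

    sr6 = (sigma / r) ** 6
    e_lj = 4.0 * eps * (sr6 ** 2 - sr6)
    e_coul = ONE_4PI_EPS0 * params["charge"][i] * params["charge"][j] / r
    return jnp.sum(e_lj + e_coul)


box = jnp.eye(3) * 10.0
# The two atoms sit near opposite faces; their minimum-image distance is 0.5 nm
positions = jnp.array([[0.25, 0.0, 0.0], [9.75, 0.0, 0.0]])
pairs = jnp.array([[0, 1]])
params = {  # hypothetical flat layout, one entry per atom
    "charge": jnp.array([0.5, -0.5]),
    "sigma": jnp.array([0.4, 0.5]),
    "epsilon": jnp.array([0.1, 0.4]),
}
energy = nonbonded_energy(positions, box, pairs, params)
print(energy)
```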

You can also obtain a single function for the total potential energy, together with the full force field parameter set, instead of separate functions for each energy term.

```python
efunc = potentials.getPotentialFunc()
params = ff.getParameters()
totene = efunc(positions, box, pairs, params)
totene
```
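
Conceptually, the total-energy function is the sum of the per-term functions, which all share the `(positions, box, pairs, params)` signature. The sketch below composes a dict of such functions under that assumption. `make_total_energy`, the two toy terms, and the `"HarmonicBondForce"` parameter names are hypothetical. They only mimic the nested, force-keyed parameter layout used above and are not DMFF's internals.

```python
import jax.numpy as jnp

ONE_4PI_EPS0 = 138.935456  # 1/(4 pi eps0) in kJ mol^-1 nm e^-2 (approximate)


def make_total_energy(term_funcs):
    """Return one function that sums a dict of per-term energy functions."""
    def total_energy(positions, box, pairs, params):
        return sum(f(positions, box, pairs, params) for f in term_funcs.values())
    return total_energy


def bond_term(positions, box, pairs, params):
    # Hypothetical term: one harmonic bond between atoms 0 and 1
    p = params["HarmonicBondForce"]
    r = jnp.linalg.norm(positions[1] - positions[0])
    return 0.5 * p["k"] * (r - p["length"]) ** 2


def coulomb_term(positions, box, pairs, params):
    # Hypothetical term: bare Coulomb energy over the pair list (no box wrapping)
    q = params["NonbondedForce"]["charge"]
    i, j = pairs[:, 0], pairs[:, 1]
    r = jnp.linalg.norm(positions[i] - positions[j], axis=-1)
    return jnp.sum(ONE_4PI_EPS0 * q[i] * q[j] / r)


terms = {"HarmonicBondForce": bond_term, "NonbondedForce": coulomb_term}
efunc = make_total_energy(terms)

positions = jnp.array([[0.0, 0.0, 0.0], [0.2, 0.0, 0.0]])
box = jnp.eye(3) * 10.0
pairs = jnp.array([[0, 1]])
params = {
    "HarmonicBondForce": {"k": 1000.0, "length": 0.1},
    "NonbondedForce": {"charge": jnp.array([0.5, -0.5])},
}

e_bond = bond_term(positions, box, pairs, params)
e_coul = coulomb_term(positions, box, pairs, params)
total = efunc(positions, box, pairs, params)
print(e_bond, e_coul, total)
```

Keeping one shared signature means the same `params` container can be passed to any single term or to the total, and `jax.grad` can differentiate either one.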

## Compute forces and parametric gradients

Use `jax.grad` to compute forces (the negative gradient of the energy with respect to positions) and parametric gradients automatically:

```python
pos_grad_func = jax.grad(efunc, argnums=0)
force = -pos_grad_func(positions, box, pairs, params)
force.shape
```
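
As a sanity check on the sign convention, the sketch below differentiates a single harmonic bond with `jax.grad`. It compares the result with the analytic force $\mathbf{F}_1 = -k_b (r - r_0)\,(\mathbf{R}_1 - \mathbf{R}_0)/r$, where $\mathbf{F}_0 = -\mathbf{F}_1$. The bond parameters are made up for illustration.

```python
import jax
import jax.numpy as jnp


def bond_energy(positions, k=1000.0, r0=0.1):
    # Harmonic bond between atoms 0 and 1 (illustrative parameters, kJ/mol and nm)
    r = jnp.linalg.norm(positions[1] - positions[0])
    return 0.5 * k * (r - r0) ** 2


positions = jnp.array([[0.0, 0.0, 0.0], [0.12, 0.0, 0.0]])
forces = -jax.grad(bond_energy)(positions)   # F = -dV/dR, shape (2, 3)

# Analytic force on atom 1; atom 0 feels the opposite force
r_vec = positions[1] - positions[0]
r = jnp.linalg.norm(r_vec)
f1 = -1000.0 * (r - 0.1) * r_vec / r
print(forces)
print(f1)
```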

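```python
# Differentiate the nonbonded energy with respect to params (argument index 3);
# allow_int=True permits integer-valued leaves in the differentiated argument
param_grad_func = jax.grad(nbfunc, argnums=3, allow_int=True)
pgrad = param_grad_func(positions, box, pairs, params)
pgrad["NonbondedForce"]["charge"]
```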
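
Parametric gradients work the same way for any nested dict of arrays. JAX treats such a container as a pytree and returns a gradient with the same structure. The sketch below differentiates a pure Coulomb energy with respect to hypothetical charges stored under `params["NonbondedForce"]["charge"]`. You can check the result against the analytic derivative $\partial V/\partial q_i = \sum_{j\neq i} q_j/(4\pi\varepsilon_0 r_{ij})$. No integer leaves are differentiated here, so `allow_int` is not needed.

```python
import jax
import jax.numpy as jnp

ONE_4PI_EPS0 = 138.935456  # 1/(4 pi eps0) in kJ mol^-1 nm e^-2 (approximate)


def coulomb_energy(positions, box, pairs, params):
    # Hypothetical nested layout: force name -> parameter name -> per-atom array
    q = params["NonbondedForce"]["charge"]
    i, j = pairs[:, 0], pairs[:, 1]
    r = jnp.linalg.norm(positions[i] - positions[j], axis=-1)
    return jnp.sum(ONE_4PI_EPS0 * q[i] * q[j] / r)


positions = jnp.array([[0.0, 0.0, 0.0], [0.5, 0.0, 0.0], [0.0, 0.5, 0.0]])
box = jnp.eye(3) * 3.0
pairs = jnp.array([[0, 1], [0, 2], [1, 2]])
params = {"NonbondedForce": {"charge": jnp.array([0.4, -0.2, -0.2])}}

# argnums=3 selects params; the gradient mirrors its nested structure
pgrad = jax.grad(coulomb_energy, argnums=3)(positions, box, pairs, params)
print(pgrad["NonbondedForce"]["charge"])
```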