# Classical Force Field in DMFF

DMFF implements classical molecular mechanics force fields of the following form:

$$\begin{align*}
    V(\mathbf{R}) &= V_{\mathrm{bond}} + V_{\mathrm{angle}} + V_\mathrm{torsion} + V_\mathrm{vdW} + V_\mathrm{Coulomb} \\
    &=  \sum_{\mathrm{bonds}}\frac{1}{2}k_b(r - r_0)^2 + \sum_{\mathrm{angles}}\frac{1}{2}k_\theta (\theta -\theta_0)^2 + \sum_{\mathrm{torsion}}\sum_{n=1}^4 V_n[1+\cos(n\phi - \phi_s)] \\
    &\quad+ \sum_{ij}4\varepsilon_{ij}\left[\left(\frac{\sigma_{ij}}{r_{ij}}\right)^{12} - \left(\frac{\sigma_{ij}}{r_{ij}}\right)^6\right] + \sum_{ij}\frac{q_iq_j}{4\pi\varepsilon_0r_{ij}}
\end{align*}$$
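The plain-Python sketch below evaluates each functional form in the energy expression for a single bond, angle, torsion, or atom pair. It is illustrative only and is not DMFF code. The function names are hypothetical. The Coulomb prefactor assumes the OpenMM unit convention (nm, elementary charges, kJ/mol), where $1/(4\pi\varepsilon_0) \approx 138.935456$ kJ mol^-1 nm e^-2.

```python
import math

# Assumed OpenMM-style Coulomb prefactor 1/(4*pi*eps0) in kJ mol^-1 nm e^-2
ONE_4PI_EPS0 = 138.935456

def bond_energy(r, k_b, r0):
    """Harmonic bond: 0.5 * k_b * (r - r0)**2."""
    return 0.5 * k_b * (r - r0) ** 2

def angle_energy(theta, k_theta, theta0):
    """Harmonic angle (radians): 0.5 * k_theta * (theta - theta0)**2."""
    return 0.5 * k_theta * (theta - theta0) ** 2

def torsion_energy(phi, terms):
    """Periodic torsion: sum of V_n * (1 + cos(n*phi - phi_s)) over (n, V_n, phi_s) tuples."""
    return sum(v_n * (1.0 + math.cos(n * phi - phi_s)) for n, v_n, phi_s in terms)

def lj_energy(r, sigma, epsilon):
    """12-6 Lennard-Jones: 4*eps*((sigma/r)**12 - (sigma/r)**6)."""
    sr6 = (sigma / r) ** 6
    return 4.0 * epsilon * (sr6 * sr6 - sr6)

def coulomb_energy(r, q_i, q_j):
    """Coulomb: q_i*q_j / (4*pi*eps0*r), with r in nm and charges in units of e."""
    return ONE_4PI_EPS0 * q_i * q_j / r
```

For example, `lj_energy(2 ** (1 / 6) * sigma, sigma, epsilon)` returns approximately `-epsilon`, the depth of the Lennard-Jones well.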

## Import necessary packages

```python
import jax
import jax.numpy as jnp
import openmm.app as app
import openmm.unit as unit
from dmff import Hamiltonian, NeighborList
```

## Compute energy

DMFF uses **OpenMM** to parse input files, including coordinate files and topology specification files. The `Hamiltonian` class, which inherits from `openmm.app.ForceField`, is initialized with force field parameter files in XML format and parses them. As an example, we parametrize an organic molecule with the GAFF2 force field using these files:

- `lig-top.xml`: defines the bond connections (topology). It is not needed if this information is already given in the PDB file with `CONECT` records.
- `gaff-2.11.xml`: GAFF2 force field parameters for bonds, angles, torsions, and vdW interactions
- `lig-prm.xml`: atomic charges

```python
# Register the ligand's bond definitions, then load coordinates and force field files
app.Topology.loadBondDefinitions("lig-top.xml")
pdb = app.PDBFile("lig.pdb")
ff = Hamiltonian("gaff-2.11.xml", "lig-prm.xml")
```

```
Generator for HarmonicAngleForce is not implemented.
Generator for PeriodicTorsionForce is not implemented.
Generator for NonbondedForce is not implemented.
```

Call `Hamiltonian.createPotential` to create differentiable potential energy functions, one for each energy term.

```python
potentials = ff.createPotential(
    pdb.topology,
    nonbondedMethod=app.NoCutoff
)
for pot in potentials.dmff_potentials.values():
    print(pot)
```

```
<function HarmonicBondJaxGenerator.createForce.<locals>.potential_fn at 0x7ff8d0411940>
<function HarmonicAngleJaxGenerator.createForce.<locals>.potential_fn at 0x7ff743081d30>
<function PeriodicTorsionJaxGenerator.createForce.<locals>.potential_fn at 0x7ff743081ee0>
<function NonbondedJaxGenerator.createForce.<locals>.potential_fn at 0x7ff7431ee3a0>
```
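The printed names (`createForce.<locals>.potential_fn`) show that each generator returns a closure. Topology-dependent data is captured when the force is created, and the returned function only takes positions, box, pairs, and parameters. The minimal pure-Python sketch below illustrates that pattern. The generator `make_harmonic_bond_potential` and the parameter keys `k` and `length` are hypothetical, and DMFF's real functions operate on JAX arrays rather than Python lists.

```python
import math

def make_harmonic_bond_potential(bond_indices):
    """Hypothetical generator: capture the bond list once, return an energy closure."""
    bonds = [tuple(b) for b in bond_indices]  # fixed at creation time

    def potential_fn(positions, box, pairs, params):
        # box and pairs are unused by a bonded term but keep a uniform signature
        k = params["k"]
        r0 = params["length"]
        energy = 0.0
        for idx, (i, j) in enumerate(bonds):
            r = math.dist(positions[i], positions[j])
            energy += 0.5 * k[idx] * (r - r0[idx]) ** 2
        return energy

    return potential_fn

bond_fn = make_harmonic_bond_potential([(0, 1)])
print(bond_fn.__qualname__)  # make_harmonic_bond_potential.<locals>.potential_fn
# Bond stretched by 0.02 nm with k = 2.0e5: energy is about 40.0
print(bond_fn([(0.0, 0.0, 0.0), (0.12, 0.0, 0.0)], None, None, {"k": [2.0e5], "length": [0.1]}))
```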


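The returned object also carries metadata in `potentials.meta`. Its `cov_map` entry is an integer matrix that records the covalent separation (number of bonds) between each pair of atoms, with 0 on the diagonal and for distant pairs. It is used below to build the neighbor list.

```python
potentials.meta["cov_map"]
```

```
Array([[0, 1, 2, ..., 0, 0, 2],
       [1, 0, 1, ..., 0, 0, 3],
       [2, 1, 0, ..., 0, 0, 4],
       ...,
       [0, 0, 0, ..., 0, 2, 0],
       [0, 0, 0, ..., 2, 0, 0],
       [2, 3, 4, ..., 0, 0, 0]], dtype=int32)
```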
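Reading the printed block, `cov_map[i][j]` appears to be the number of bonds separating atoms `i` and `j`. Under that reading, such a matrix can be built with a breadth-first search from every atom over the bond graph. The sketch below rests on that assumption about the semantics and is not DMFF's implementation. The function name and the separation limit `max_sep` are hypothetical.

```python
from collections import deque

def build_cov_map(n_atoms, bonds, max_sep=6):
    """Bond-count separation between atom pairs (assumed semantics).

    Entries are 0 on the diagonal and for pairs more than max_sep bonds apart.
    """
    neighbors = [[] for _ in range(n_atoms)]
    for i, j in bonds:
        neighbors[i].append(j)
        neighbors[j].append(i)
    cov = [[0] * n_atoms for _ in range(n_atoms)]
    for start in range(n_atoms):
        dist = {start: 0}
        queue = deque([start])
        while queue:
            a = queue.popleft()
            if dist[a] == max_sep:
                continue  # do not expand past the separation limit
            for b in neighbors[a]:
                if b not in dist:
                    dist[b] = dist[a] + 1
                    cov[start][b] = dist[b]
                    queue.append(b)
    return cov

# Linear chain 0-1-2-3
cov = build_cov_map(4, [(0, 1), (1, 2), (2, 3)])
print(cov[0])  # [0, 1, 2, 3]
```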

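Each force generator stores its parameters as a Python dict in its `param` attribute, and `Hamiltonian.getParameters()` gathers them into a single nested dict keyed by force name and then by parameter name. For example, the atomic charges of the nonbonded term: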

```python
params = ff.getParameters()
nbparam = params['NonbondedForce']
nbparam["charge"]
```

```
Array([-0.75401515,  0.8628848 , -0.74901515,  0.44918483, -0.30991516,
        0.58678484, -0.01711515, -0.02311515, -0.13561516, -0.21381515,
        0.16038485, -0.47781515,  0.25678486, -0.7003152 ,  0.60158485,
       -0.6191152 ,  0.18658485, -0.15001515,  0.10558484, -0.13701515,
       -0.08301515,  0.03808485, -0.37091514,  0.81668484, -0.20798215,
       -0.20798215, -0.20798215, -0.65701514,  0.00528485,  0.11038485,
       -0.58381516, -0.6450152 ,  0.15478484,  0.10428485, -0.67541516,
        0.10428485,  0.15478484,  0.08308485,  0.08518485,  0.08518485,
        0.06118485,  0.06118485,  0.44978485,  0.16498485,  0.14798485,
        0.16998485,  0.32248485,  0.32248485,  0.08718485,  0.08718485,
        0.06668485,  0.06668485,  0.39998484,  0.08593485,  0.08593485,
        0.10518485,  0.10518485,  0.45978484,  0.10518485,  0.10518485,
        0.08593485,  0.08593485,  0.10535185,  0.10535185,  0.10535185,
        0.02608485], dtype=float32)
```
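Because `getParameters()` returns a nested dict, JAX treats it as a pytree, and `jax.grad` taken with respect to it returns a dict of the same shape. Gradients are therefore indexed with the same keys as `params`. The helper below is a hypothetical illustration that walks such a structure and lists each parameter array. It runs on a toy dict with made-up keys and values, not on real DMFF output. For real JAX pytrees, `jax.tree_util` offers equivalent traversal.

```python
def iter_leaves(tree, prefix=()):
    """Yield (path, value) for every non-dict leaf of a nested dict."""
    for key, value in tree.items():
        path = prefix + (key,)
        if isinstance(value, dict):
            yield from iter_leaves(value, path)
        else:
            yield path, value

# Toy stand-in for ff.getParameters(); keys and values are hypothetical
toy_params = {
    "HarmonicBondForce": {"length": [0.15, 0.11], "k": [2.5e5, 3.0e5]},
    "NonbondedForce": {"charge": [-0.4, 0.2, 0.2]},
}
for path, value in iter_leaves(toy_params):
    print("/".join(path), len(value))
# HarmonicBondForce/length 2
# HarmonicBondForce/k 2
# NonbondedForce/charge 3
```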

Each generated function takes **coordinates, box, pairs**, and the force field parameters as inputs. `pairs` is an integer array in which each row specifies two atoms considered neighbors within the cutoff `rcut`; in the output below, the third column holds the corresponding `cov_map` entry. `pairs` can be computed with the `dmff.NeighborList` class, which is built on `jax_md`.
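`dmff.NeighborList` relies on `jax_md` neighbor lists, which avoid checking every atom pair. To show only what each row of `pairs` means, here is a hypothetical brute-force O(N^2) version for an orthorhombic box under the minimum-image convention. It assumes rows of the form `[i, j, cov_map[i][j]]` with `i < j`, which is consistent with the printed `pairs` and `cov_map` in this notebook. It is not DMFF's algorithm.

```python
import itertools

def brute_force_pairs(positions, box_lengths, rcut, cov_map):
    """O(N^2) neighbor search in an orthorhombic box (minimum-image convention).

    Returns rows [i, j, cov_map[i][j]] for i < j with distance below rcut.
    """
    pairs = []
    for i, j in itertools.combinations(range(len(positions)), 2):
        d2 = 0.0
        for a in range(3):
            d = positions[j][a] - positions[i][a]
            d -= box_lengths[a] * round(d / box_lengths[a])  # wrap to nearest image
            d2 += d * d
        if d2 < rcut * rcut:
            pairs.append([i, j, cov_map[i][j]])
    return pairs

# Three atoms in a 10 nm box; atoms 0 and 2 are 0.5 nm apart through the boundary
pos = [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0), (9.5, 0.0, 0.0)]
cov = [[0, 1, 0], [1, 0, 0], [0, 0, 0]]  # atoms 0 and 1 bonded (toy data)
print(brute_force_pairs(pos, (10.0, 10.0, 10.0), 0.8, cov))  # [[0, 2, 0]]
```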

Each potential energy function returns the energy as a scalar in kJ/mol:
```python
# Coordinates in nm and a 10 nm cubic box
positions = jnp.array(pdb.getPositions(asNumpy=True).value_in_unit(unit.nanometer))
box = jnp.array([
    [10.0, 0.0, 0.0],
    [0.0, 10.0, 0.0],
    [0.0, 0.0, 10.0]
])
# Neighbor list with a 4 nm cutoff; cov_map supplies the covalent separation of each pair
nbList = NeighborList(box, 4, potentials.meta["cov_map"])
nbList.allocate(positions)
pairs = nbList.pairs
# Energy of the nonbonded term alone
nbfunc = potentials.dmff_potentials['NonbondedForce']
energy = nbfunc(positions, box, pairs, params)
print(energy)
print(pairs)
```

```
-425.40482
[[ 0  1  1]
 [ 0  2  2]
 [ 1  2  1]
 ...
 [62 65  0]
 [63 65  0]
 [64 65  0]]
```


You can also obtain a single function for the total potential energy, together with the full parameter set, instead of separate functions for each energy term.

```python
efunc = potentials.getPotentialFunc()
params = ff.getParameters()
totene = efunc(positions, box, pairs, params)
totene
```

```
Array(-52.358917, dtype=float32)
```
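Conceptually, the total energy is the sum of the per-term contributions in the energy expression at the top. Because all per-term functions share the signature `(positions, box, pairs, params)`, they compose directly. The minimal sketch below assumes only that shared signature. `make_total_potential` is a hypothetical helper and does not show how `getPotentialFunc` is actually implemented.

```python
def make_total_potential(term_fns):
    """Hypothetical: combine per-term potentials that share one signature into a total."""
    def total_fn(positions, box, pairs, params):
        return sum(fn(positions, box, pairs, params) for fn in term_fns)
    return total_fn

# Two dummy terms standing in for, e.g., bonded and nonbonded energies
total = make_total_potential([
    lambda pos, box, pairs, params: 1.5,
    lambda pos, box, pairs, params: -0.5,
])
print(total(None, None, None, None))  # 1.0
```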

## Compute forces and parametric gradients

Forces and parametric gradients can be computed automatically with `jax.grad`.

```python
# Forces are the negative gradient of the total energy w.r.t. positions (argument 0)
pos_grad_func = jax.grad(efunc, argnums=0)
force = -pos_grad_func(positions, box, pairs, params)
force.shape
```

```
(66, 3)
```
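`jax.grad` differentiates exactly, but a central finite-difference estimate is a common independent check of force code. The pure-Python sketch below applies one to a single harmonic bond with toy parameters. It is not part of DMFF, and the function names are hypothetical. The analytic force on atom 1 along x is $-k_b(r - r_0) = -2\times10^5 \times 0.02 = -4000$, and atom 0 feels the opposite force.

```python
import math

def harmonic_bond_energy(positions, k=2.0e5, r0=0.1):
    """Energy of one harmonic bond between atoms 0 and 1 (toy parameters)."""
    r = math.dist(positions[0], positions[1])
    return 0.5 * k * (r - r0) ** 2

def numerical_forces(energy_fn, positions, h=1e-6):
    """F = -dE/dx estimated by central differences, one coordinate at a time."""
    forces = []
    for i in range(len(positions)):
        f_atom = []
        for a in range(3):
            plus = [list(p) for p in positions]
            minus = [list(p) for p in positions]
            plus[i][a] += h
            minus[i][a] -= h
            f_atom.append(-(energy_fn(plus) - energy_fn(minus)) / (2 * h))
        forces.append(f_atom)
    return forces

pos = [[0.0, 0.0, 0.0], [0.12, 0.0, 0.0]]
f = numerical_forces(harmonic_bond_energy, pos)
print(f[1][0])  # close to -4000.0
```

Comparing such estimates against `-jax.grad(...)` on a few configurations is a quick way to catch unit or sign mistakes in a custom energy term.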

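```python
# Gradient of the nonbonded energy w.r.t. the parameter dict (the last argument);
# allow_int=True permits integer-valued inputs among the differentiated arguments
param_grad_func = jax.grad(nbfunc, argnums=-1, allow_int=True)
pgrad = param_grad_func(positions, box, pairs, params)
pgrad["NonbondedForce"]["charge"]
```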

Array([ 652.7753   ,   55.108738 ,  729.36115  , -171.4929   ,
        502.70837  ,  -44.917206 ,  129.63994  , -142.31796  ,
       -149.62088  ,  453.21503  ,   46.372574 ,  140.15303  ,
        575.488    ,  461.46902  ,  294.4358   ,  335.25153  ,
         27.828705 ,  671.3637   ,  390.8903   ,  519.6835   ,
        220.51129  ,  238.7695   ,  229.97302  ,  210.58838  ,
        237.08563  ,  196.40994  ,  231.8734   ,   35.663574 ,
        457.76416  ,   77.4798   ,  256.54382  ,  402.2121   ,
        592.46265  ,  421.86688  ,  -52.09662  ,  440.8465   ,
        611.9573   ,  237.98883  ,  110.286194 ,  150.65375  ,
        218.61087  ,  240.20477  , -211.85376  ,  150.7331   ,
        310.89404  ,  208.65228  , -139.23026  , -168.8883   ,
        114.3645   ,    3.7261353,  399.6282   ,  298.28455  ,
        422.06445  ,  485.1427   ,  512.1267   ,  549.84033  ,
        556.4724   ,  394.40845  ,  575.85767  ,  606.74744  ,
        526.18463  ,  521.27563  ,  558.55896  ,  560.4