# Classical Force Field in DMFF

DMFF implements classical molecular mechanics force fields of the following form:

$$\begin{align*}
    V(\mathbf{R}) &= V_{\mathrm{bond}} + V_{\mathrm{angle}} + V_\mathrm{torsion} + V_\mathrm{vdW} + V_\mathrm{Coulomb} \\
    &=  \sum_{\mathrm{bonds}}\frac{1}{2}k_b(r - r_0)^2 + \sum_{\mathrm{angles}}\frac{1}{2}k_\theta (\theta -\theta_0)^2 + \sum_{\mathrm{torsions}}\sum_{n=1}^4 V_n[1+\cos(n\phi - \phi_s)] \\
    &\quad+ \sum_{i<j}4\varepsilon_{ij}\left[\left(\frac{\sigma_{ij}}{r_{ij}}\right)^{12} - \left(\frac{\sigma_{ij}}{r_{ij}}\right)^6\right] + \sum_{i<j}\frac{q_iq_j}{4\pi\varepsilon_0r_{ij}}
\end{align*}$$
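As a sketch of how each term above maps onto array code, the functions below evaluate the individual terms with `jax.numpy`, assuming the geometric quantities (bond lengths, angles, dihedrals, pair distances) have already been computed. The function names are hypothetical and are not part of the DMFF API; units are assumed to be nm, radians, kJ/mol and elementary charges.

```python
import jax.numpy as jnp

# Hypothetical helpers, one per term of V(R); not DMFF functions.
# Assumed units: nm, rad, kJ/mol, elementary charge e.

def bond_energy(r, k_b, r0):
    """Harmonic bonds: sum of 1/2 k_b (r - r0)^2."""
    return jnp.sum(0.5 * k_b * (r - r0) ** 2)

def angle_energy(theta, k_theta, theta0):
    """Harmonic angles: sum of 1/2 k_theta (theta - theta0)^2."""
    return jnp.sum(0.5 * k_theta * (theta - theta0) ** 2)

def torsion_energy(phi, v_n, phi_s):
    """Periodic torsions: sum over n = 1..4 of V_n [1 + cos(n phi - phi_s)].

    phi: (T,) dihedral angles; v_n and phi_s: (T, 4), one column per n.
    """
    n = jnp.arange(1, 5)  # periodicities 1, 2, 3, 4
    return jnp.sum(v_n * (1.0 + jnp.cos(n * phi[:, None] - phi_s)))

def lj_energy(r, epsilon, sigma):
    """12-6 Lennard-Jones: 4 eps [(sigma/r)^12 - (sigma/r)^6]."""
    sr6 = (sigma / r) ** 6
    return jnp.sum(4.0 * epsilon * (sr6 ** 2 - sr6))

# Approximate value of 1/(4 pi eps0) in kJ mol^-1 nm e^-2
ONE_4PI_EPS0 = 138.935456

def coulomb_energy(r, qi, qj):
    """Coulomb: q_i q_j / (4 pi eps0 r)."""
    return jnp.sum(ONE_4PI_EPS0 * qi * qj / r)
```

Because every term is a pure function of arrays, `jax.grad` can differentiate it with respect to either the geometry or the parameters.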

## Import necessary packages

```python
import jax
import jax.numpy as jnp
import openmm.app as app
import openmm.unit as unit
from dmff import Hamiltonian, NeighborList
```

## Compute energy

DMFF uses **OpenMM** to parse input files, including coordinate files and topology specification files. The `Hamiltonian` class, which inherits from `openmm.app.ForceField`, is initialized to parse force field parameters stored in XML format. As an example, we parametrize an organic molecule with the GAFF2 force field using the following XML files:

- `lig-top.xml`: bond connections (topology). Not needed if the PDB file already specifies them with `CONECT` records.
- `gaff-2.11.xml`: GAFF2 force field parameters (bonds, angles, torsions and vdW parameters).
- `lig-prm.xml`: atomic charges.

```python
# Register the ligand bond definitions before reading the PDB file
app.Topology.loadBondDefinitions("lig-top.xml")
pdb = app.PDBFile("lig.pdb")
ff = Hamiltonian("gaff-2.11.xml", "lig-prm.xml")
```

`Hamiltonian.createPotential` creates a differentiable potential energy function for each energy term.

```python
potentials = ff.createPotential(
    pdb.topology,
    nonbondedMethod=app.NoCutoff
)
for pot in potentials.dmff_potentials.values():
    print(pot)
```

```
<function HarmonicBondJaxGenerator.createForce.<locals>.potential_fn at 0x7fcf60e4ce00>
<function HarmonicAngleJaxGenerator.createForce.<locals>.potential_fn at 0x7fcf163649a0>
<function PeriodicTorsionJaxGenerator.createForce.<locals>.potential_fn at 0x7fcf16364b80>
<function NonbondedJaxGenerator.createForce.<locals>.potential_fn at 0x7fcf15a845e0>
```
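Each force generator stores its force field parameters as a Python dict in its `param` attribute. `ff.getParameters()` collects them into one nested dict keyed by force name, e.g. `params['NonbondedForce']`.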

```python
params = ff.getParameters()
nbparam = params['NonbondedForce']
nbparam["epsilon"]  # other keys include "sigma" and "charge"
```

```
Array([ 0.4133792,  0.4133792,  0.6677664,  0.4133792,  0.4510352,
        0.4133792,  0.4133792,  0.4133792,  0.4133792,  0.4133792,
        0.4133792,  0.4133792,  0.6677664,  0.6677664,  0.4510352,
        0.4510352,  0.4133792,  0.4133792,  0.4133792,  0.0870272,
        0.0870272,  0.0870272,  0.0673624,  0.0673624,  0.0673624,
        0.0870272,  0.04184  ,  0.0196648,  0.0602496,  0.0518816,
        0.       ,  0.0870272,  0.3481088,  1.1037391,  1.6451488,
        2.073172 ,  0.6845024,  0.4594032,  0.3937144,  0.3589872,
       16.212164 ,  0.8543728,  0.3937144,  0.3937144,  0.3937144,
        0.3937144,  0.3937144,  0.89956  ,  0.3589872,  0.4912016,
        0.3560584, 10.649535 ,  7.0956454,  4.79068  ,  3.2752352,
        0.646428 ,  0.468608 ,  0.2184048,  0.1351432,  0.039748 ,
        0.6845024,  0.6845024, 10.649535 , 10.649535 ,  0.89956  ,
        0.89956  ,  0.3589872,  0.3589872,  0.2184048,  0.2184048,
        0.6121192,  0.389112 ,  0.3037584,  0.3037584,  0.3037
```
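Because the parameters form a nested dict of arrays, JAX treats them as a pytree, which is why gradients with respect to `params` come back with the same nesting. The toy dict below only mimics that layout; the force names, keys and values are illustrative placeholders, not real GAFF2 parameters.

```python
import jax
import jax.numpy as jnp

# Toy parameter set in the {force name: {parameter name: array}} layout.
# All names and numbers here are hypothetical.
params = {
    "NonbondedForce": {
        "charge": jnp.array([0.4, -0.4]),
        "sigma": jnp.array([0.34, 0.30]),
        "epsilon": jnp.array([0.36, 0.71]),
    },
    "HarmonicBondForce": {
        "k": jnp.array([2.5e5]),
        "length": jnp.array([0.15]),
    },
}

# Every array is a leaf of the pytree.
leaves = jax.tree_util.tree_leaves(params)
print(len(leaves))  # 5

# tree_map applies a function to each leaf and keeps the nesting,
# here a hypothetical uniform 10% scaling of every parameter.
scaled = jax.tree_util.tree_map(lambda x: 1.1 * x, params)
```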

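Each generated function takes **coordinates, box, pairs** and the force field parameters as inputs. `pairs` is an integer array in which each row specifies a pair of atoms that are neighbors within the cutoff `rcut`. It can be computed with the `dmff.NeighborList` class, which is built on `jax_md`. In the output below, each row also carries a third column, which appears to encode the covalent separation of the pair taken from `cov_map` (0 for pairs that are not covalently close).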
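To make the structure of `pairs` concrete, here is a brute-force sketch that builds the same kind of list (two atom indices per row, each unordered pair once) for an orthorhombic box using the minimum-image convention. The function `brute_force_pairs` is hypothetical and O(N²); `dmff.NeighborList` relies on `jax_md` instead, and its rows carry the extra column mentioned above.

```python
import jax.numpy as jnp
import numpy as np

def brute_force_pairs(positions, box, rc):
    """Return an (M, 2) integer array of atom pairs (i < j) closer than rc.

    Illustrative O(N^2) sketch for an orthorhombic box with the
    minimum-image convention; not how dmff.NeighborList builds its list.
    """
    positions = jnp.asarray(positions)
    box_len = jnp.diag(jnp.asarray(box))          # box edge lengths
    dr = positions[:, None, :] - positions[None, :, :]
    dr = dr - box_len * jnp.round(dr / box_len)   # minimum image
    dist = jnp.linalg.norm(dr, axis=-1)
    n = positions.shape[0]
    i, j = np.triu_indices(n, k=1)                # each unordered pair once
    mask = np.asarray(dist[i, j] < rc)
    return np.stack([i[mask], j[mask]], axis=1)
```

For example, `brute_force_pairs(positions, box, 0.4)` lists every pair closer than 0.4 nm.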

The potential energy function returns the energy (a scalar, in kJ/mol):
```python
# OpenMM returns positions with units; convert to nm
positions = jnp.array(pdb.getPositions(asNumpy=True).value_in_unit(unit.nanometer))
# 10 nm cubic box
box = jnp.array([
    [10.0, 0.0, 0.0],
    [0.0, 10.0, 0.0],
    [0.0, 0.0, 10.0]
])
# Neighbor list with a cutoff of 4 (nm); cov_map supplies covalent connectivity
nbList = NeighborList(box, 4, potentials.meta["cov_map"])
nbList.allocate(positions)
pairs = nbList.pairs
nbfunc = potentials.dmff_potentials['NonbondedForce']
energy = nbfunc(positions, box, pairs, params)
print(energy)
print(pairs)
```

```
-425.40488
[[ 0  1  1]
 [ 0  2  2]
 [ 1  2  1]
 ...
 [62 65  0]
 [63 65  0]
 [64 65  0]]
```


```python
nbfunc(positions, box, pairs, params)
```

```
Array(-425.40488, dtype=float32)
```
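All term functions share the calling convention `f(positions, box, pairs, params)`. As an illustration of that convention only (not DMFF's implementation), the hypothetical `toy_nonbonded` below evaluates plain Lennard-Jones plus Coulomb over the listed pairs. It assumes Lorentz-Berthelot combining rules and a `params["NonbondedForce"]` dict with `charge`, `sigma` and `epsilon` arrays, and it ignores exclusions, 1-4 scaling, cutoffs and PME, so it will not reproduce the energies above.

```python
import jax.numpy as jnp

ONE_4PI_EPS0 = 138.935456  # approx. 1/(4 pi eps0) in kJ mol^-1 nm e^-2

def toy_nonbonded(positions, box, pairs, params):
    """LJ + Coulomb summed over `pairs` (extra columns beyond 2 are ignored)."""
    p = params["NonbondedForce"]
    i, j = pairs[:, 0], pairs[:, 1]
    box_len = jnp.diag(box)
    dr = positions[i] - positions[j]
    dr = dr - box_len * jnp.round(dr / box_len)         # minimum image, orthorhombic box
    r = jnp.linalg.norm(dr, axis=-1)
    sigma = 0.5 * (p["sigma"][i] + p["sigma"][j])       # Lorentz rule
    eps = jnp.sqrt(p["epsilon"][i] * p["epsilon"][j])   # Berthelot rule
    sr6 = (sigma / r) ** 6
    e_lj = 4.0 * eps * (sr6 ** 2 - sr6)
    e_coul = ONE_4PI_EPS0 * p["charge"][i] * p["charge"][j] / r
    return jnp.sum(e_lj + e_coul)
```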

You can also obtain a single potential energy function for the whole system, together with the full parameter set, instead of separate functions for each energy term.

```python
efunc = potentials.getPotentialFunc()
params = ff.getParameters()
totene = efunc(positions, box, pairs, params)
totene
```

```
Array(-52.35898, dtype=float32)
```
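Conceptually, a whole-system energy function can be formed by summing term functions that share one signature. The sketch below shows that idea with two hypothetical toy terms; it is an assumption about the design for illustration, not DMFF's actual `getPotentialFunc` implementation.

```python
import jax
import jax.numpy as jnp

def make_total_energy(term_fns):
    """Return one function summing energy terms that share the
    (positions, box, pairs, params) signature."""
    def total_energy(positions, box, pairs, params):
        return sum(fn(positions, box, pairs, params) for fn in term_fns)
    return total_energy

# Hypothetical toy terms used only for illustration.
def toy_bond(positions, box, pairs, params):
    """Harmonic bond between atoms 0 and 1."""
    p = params["HarmonicBondForce"]
    r = jnp.linalg.norm(positions[1] - positions[0])
    return 0.5 * p["k"] * (r - p["length"]) ** 2

def toy_pair_repulsion(positions, box, pairs, params):
    """Soft a / r^12 repulsion over the listed pairs."""
    i, j = pairs[:, 0], pairs[:, 1]
    r = jnp.linalg.norm(positions[i] - positions[j], axis=-1)
    return jnp.sum(params["Repulsion"]["a"] / r ** 12)

# The combined function is still a pure JAX function, so it can be jitted.
total_energy = jax.jit(make_total_energy([toy_bond, toy_pair_repulsion]))
```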

## Compute forces and parametric gradients

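Use `jax.grad` to compute forces and parametric gradients automatically. Forces are the negative gradient of the energy with respect to the atomic positions (argument 0 of the energy function):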

```python
# Gradient with respect to positions (argnums=0); forces are its negative
pos_grad_func = jax.grad(efunc, argnums=0)
force = -pos_grad_func(positions, box, pairs, params)
force.shape
```

```
(66, 3)
```
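The same pattern works on any JAX-traceable energy. In this self-contained sketch, a hypothetical two-atom Lennard-Jones energy `toy_lj` is differentiated with `jax.grad`: the force vanishes at the minimum $r = 2^{1/6}\sigma$, is repulsive at shorter distances, and the two atomic forces cancel, as Newton's third law requires.

```python
import jax
import jax.numpy as jnp

def toy_lj(positions, epsilon=0.5, sigma=0.3):
    """12-6 LJ energy between atoms 0 and 1 (kJ/mol, nm)."""
    r = jnp.linalg.norm(positions[1] - positions[0])
    sr6 = (sigma / r) ** 6
    return 4.0 * epsilon * (sr6 ** 2 - sr6)

def forces(positions):
    """F = -dV/dR, the same pattern as -jax.grad(efunc, argnums=0)."""
    return -jax.grad(toy_lj)(positions)

rmin = 2 ** (1 / 6) * 0.3                                     # LJ minimum
print(forces(jnp.array([[0.0, 0.0, 0.0], [rmin, 0.0, 0.0]])))  # close to zero
print(forces(jnp.array([[0.0, 0.0, 0.0], [0.3, 0.0, 0.0]])))   # atoms pushed apart
```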

```python
print(positions, box, pairs, params)
```

```
[[-4.3959  1.4987  0.8022]
 [-4.417   1.3741  0.8211]
 [-4.3809  1.3116  0.9424]
 [-4.3349  1.5775  0.9031]
 [-4.2992  1.5185  1.0193]
 [-4.3249  1.3756  1.0359]
 [-4.1405  1.5095  1.2249]
 [-4.2331  1.5977  1.1325]
 [-4.1861  1.3645  1.2355]
 [-4.2724  1.3081  1.1541]
 [-4.2914  1.1763  1.1977]
 [-4.2137  1.1543  1.3076]
 [-4.1488  1.2622  1.3335]
 [-4.4702  1.2908  0.7219]
 [-4.051   1.2742  1.4454]
 [-3.9982  1.3815  1.4738]
 [-4.5103  1.1514  0.7284]
 [-4.4195  1.059   0.772 ]
 [-4.4561  0.9134  0.7798]
 [-4.5816  0.8749  0.7423]
 [-4.6793  0.9732  0.6956]
 [-4.6474  1.105   0.6883]
 [-4.739   1.2033  0.6438]
 [-4.7858  1.1964  0.5076]
 [-4.8865  1.1048  0.4992]
 [-4.8336  1.3205  0.4673]
 [-4.6864  1.155   0.42  ]
 [-4.0171  1.1604  1.5192]
 [-4.3759  1.0766  1.135 ]
 [-4.4422  0.9875  1.2402]
 [-4.365   0.8677  1.2505]
 [-4.3622  0.815   0.8249]
 [-4.4124  0.6921  0.892 ]
 [-4.3078  0.6214  0.9812]
 [-4.1716  0.6273  0.9227]
 [-4.1308  0.7688  0.9088]
 [-4.2185  0.837   0.801 ]
```

The gradient with respect to the parameters is taken on the last argument and has the same nested structure as `params`:

```python
# Differentiate the nonbonded energy w.r.t. params (the last argument);
# allow_int=True lets jax.grad accept integer-valued inputs
param_grad_func = jax.grad(nbfunc, argnums=-1, allow_int=True)
pgrad = param_grad_func(positions, box, pairs, params)
pgrad["NonbondedForce"]["charge"]
```

```
Array([ 652.7753   ,   55.108738 ,  729.36115  , -171.4929   ,
        502.70837  ,  -44.917206 ,  129.63994  , -142.31796  ,
       -149.62088  ,  453.21503  ,   46.372574 ,  140.15303  ,
        575.488    ,  461.46902  ,  294.4358   ,  335.25153  ,
         27.828705 ,  671.3637   ,  390.8903   ,  519.6835   ,
        220.51129  ,  238.7695   ,  229.97302  ,  210.58838  ,
        237.08563  ,  196.40994  ,  231.8734   ,   35.663574 ,
        457.76416  ,   77.4798   ,  256.54382  ,  402.2121   ,
        592.46265  ,  421.86688  ,  -52.09662  ,  440.8465   ,
        611.9573   ,  237.98883  ,  110.286194 ,  150.65375  ,
        218.61087  ,  240.20477  , -211.85376  ,  150.7331   ,
        310.89404  ,  208.65228  , -139.23026  , -168.8883   ,
        114.3645   ,    3.7261353,  399.6282   ,  298.28455  ,
        422.06445  ,  485.1427   ,  512.1267   ,  549.84033  ,
        556.4724   ,  394.40845  ,  575.85767  ,  606.74744  ,
        526.18463  ,  521.27563  ,  558.55896  ,  560.4
```
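To see what a parametric gradient with respect to the charges means, consider a hypothetical bare Coulomb term over three atoms. For $E = \sum_{i<j} q_i q_j / (4\pi\varepsilon_0 r_{ij})$ the derivative is $\partial E/\partial q_i = \sum_{j\ne i} q_j / (4\pi\varepsilon_0 r_{ij})$, which `jax.grad` reproduces. The integer leaf `ntypes` is a made-up placeholder showing why `allow_int=True` is needed when the parameter dict contains integer arrays; its gradient comes back with JAX's `float0` dtype.

```python
import jax
import jax.numpy as jnp

ONE_4PI_EPS0 = 138.935456  # approx. 1/(4 pi eps0), kJ mol^-1 nm e^-2

def coulomb(positions, box, pairs, params):
    """Hypothetical bare Coulomb energy over the listed pairs (box unused)."""
    q = params["NonbondedForce"]["charge"]
    i, j = pairs[:, 0], pairs[:, 1]
    r = jnp.linalg.norm(positions[i] - positions[j], axis=-1)
    return jnp.sum(ONE_4PI_EPS0 * q[i] * q[j] / r)

positions = jnp.array([[0.0, 0.0, 0.0], [0.5, 0.0, 0.0], [0.0, 1.0, 0.0]])
pairs = jnp.array([[0, 1], [0, 2], [1, 2]])
params = {"NonbondedForce": {"charge": jnp.array([0.5, -0.5, 0.2]),
                             "ntypes": jnp.array(3)}}  # hypothetical integer leaf

# argnums=3 selects params; without allow_int=True the integer leaf raises an error
grads = jax.grad(coulomb, argnums=3, allow_int=True)(positions, None, pairs, params)
print(grads["NonbondedForce"]["charge"])
```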