# Multipolar polarizable force field with fluctuating charges

In this demo, we show how to implement a **multipolar polarizable potential with fluctuating charges** using the DMFF API.

In conventional models, atomic charges are pre-defined and remain unchanged during the simulation. Here, we implement a model that treats atomic charges as *conformer-dependent*, so that they vary with the molecular geometry during a molecular dynamics simulation. This allows the model to capture how the charge distribution responds to changes in molecular geometry.

## System preparation
Load the coordinates and the periodic box of a water dimer system. The frontend API uses **nanometers** as the length unit.

```python
import sys
import jax
import jax.numpy as jnp
import openmm.app as app
import openmm.unit as unit
from dmff.api import Hamiltonian
from jax_md import space, partition
from jax import value_and_grad, jit
import pickle
from dmff.admp.pme import trim_val_0
from dmff.admp.spatial import v_pbc_shift
from dmff.common import nblist
from dmff.utils import jit_condition
from dmff.admp.pairwise import (
    TT_damping_qq_c6_kernel,
    generate_pairwise_interaction,
    slater_disp_damping_kernel,
    slater_sr_kernel,
    TT_damping_qq_kernel
)

rc = 0.4
pdb = app.PDBFile("water_dimer.pdb")
# construct inputs
positions = jnp.array(pdb.positions._value)
a, b, c = pdb.topology.getPeriodicBoxVectors()
box = jnp.array([a._value, b._value, c._value])
```




## Generate auto-differentiable multipolar polarizable (ADMP) forces

First, we use `dmff` to create a multipolar polarizable potential with **fixed** atomic charges.

The force field defines two types of force:

- Dispersion force
- Multipolar polarizable PME force

We will focus on the PME force.

```python
H = Hamiltonian('forcefield.xml')
# the generators held by H store all force field parameters
pots = H.createPotential(pdb.topology, nonbondedCutoff=rc*unit.nanometer, step_pol=5)
pme_pot = pots.dmff_potentials['ADMPPmeForce']
disp_generator, pme_generator = H.getGenerators()
```

The function `pme_pot` performs the following steps:

- Expand the **force field parameters** (oxygen and hydrogen charges, polarizabilities, etc.) pre-defined in `forcefield.xml` to each atom; we call the result **atomic parameters**.
- Call the real PME kernel function to evaluate the energy.

The force field parameters are stored in the Hamiltonian `H`. The expansion uses the integer-array (fancy) indexing feature of `jax.numpy` arrays: indexing a per-atomtype parameter array with `map_atomtype`, which maps each atom to its corresponding atomtype, yields one row of parameters per atom.
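The expansion step itself is just integer-array indexing. Below is a minimal sketch with a hypothetical two-type table (type 0 = oxygen, type 1 = hydrogen) that keeps only the charge column, with values copied from the `Q_local` printout of this demo. The names `type_charges` and this hand-written `map_atomtype` are illustrative; the real `Q_local` table has 9 columns per type here.

```python
import jax.numpy as jnp

# Hypothetical per-type table: row 0 = oxygen type, row 1 = hydrogen type.
# Only the charge column is kept (values from the Q_local printout).
type_charges = jnp.array([[-0.803721], [0.401876]])

# Hypothetical atom -> atomtype map for a water dimer ordered O, H, H, O, H, H
map_atomtype = jnp.array([0, 1, 1, 0, 1, 1])

# Integer-array indexing gathers one row per atom:
# shape (n_types, k) -> (n_atoms, k)
atomic_charges = type_charges[map_atomtype]
print(atomic_charges.shape)  # (6, 1)
```

Because every atom of the same type receives an identical copy of its type's row, this step alone cannot produce atom-specific charges.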

```python
params = H.getParameters()['ADMPPmeForce']
map_atomtype = pots.meta["ADMPPmeForce_map_atomtype"]
params['Q_local'][map_atomtype]
```

```
Array([[-0.803721  , -0.0784325 ,  0.        ,  0.        ,  0.00459693,
         0.        ,  0.        ,  0.12960503,  0.        ],
       [ 0.401876  , -0.0095895 , -0.0121713 ,  0.        ,  0.00812139,
         0.00436148,  0.        ,  0.00701541,  0.        ],
       [ 0.401876  , -0.0095895 , -0.0121713 ,  0.        ,  0.00812139,
         0.00436148,  0.        ,  0.00701541,  0.        ],
       [-0.803721  , -0.0784325 ,  0.        ,  0.        ,  0.00459693,
         0.        ,  0.        ,  0.12960503,  0.        ],
       [ 0.401876  , -0.0095895 , -0.0121713 ,  0.        ,  0.00812139,
         0.00436148,  0.        ,  0.00701541,  0.        ],
       [ 0.401876  , -0.0095895 , -0.0121713 ,  0.        ,  0.00812139,
         0.00436148,  0.        ,  0.00701541,  0.        ]],      dtype=float64)
```

## Implement fluctuating charges

Since this expansion happens internally within `pme_pot`, it is **not flexible** enough for us to specify atom-specific charges, i.e., **fluctuating charges**.

As a result, we must rewrite `pme_pot` so that the atomic charges can be modified after the force field parameters have been expanded.

Benefiting from the flexible APIs in DMFF, we can reuse most of the functions and variables in `pme_generator` and only modify the charges in the input parameters, i.e., the first column of the `Q_local` argument of the `pme_generator.pme_force.get_energy` function. Note that all ADMP backend functions assume the inputs (`positions` and `box`) are in Angstrom, not nm!
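Two small mechanics are involved: JAX arrays are immutable, so the charge column is replaced with `.at[...].set(...)`, which returns an updated copy, and lengths are scaled by 10 to go from nm to Angstrom. A toy sketch of both, using made-up arrays rather than the real `Q_local` (the constant `NM_TO_ANGSTROM` is a name introduced here for illustration):

```python
import jax.numpy as jnp

NM_TO_ANGSTROM = 10.0  # frontend works in nm; ADMP backend expects Angstrom

# Toy expanded multipoles: 3 atoms x 9 components, column 0 = charge
Q_local = jnp.ones((3, 9))
new_charges = jnp.array([-0.8, 0.4, 0.4])  # illustrative values only

# .at[...].set(...) returns a new array; Q_local itself is left unchanged
Q_new = Q_local.at[:, 0].set(new_charges)

# Convert toy coordinates from nm to Angstrom before calling the backend
positions_nm = jnp.array([[0.0, 0.0, 0.0], [0.0957, 0.0, 0.0]])
positions_A = positions_nm * NM_TO_ANGSTROM

print(Q_new[:, 0])
print(positions_A)
```

Because the update is functional rather than in-place, it can be traced by `jax.jit` and differentiated by `jax.grad`.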

```python
from dmff.utils import jit_condition
from dmff.admp.pme import trim_val_0
from dmff.admp.spatial import v_pbc_shift


@jit_condition(static_argnums=())
def compute_leading_terms(positions, box):
    # positions and box in Angstrom; atoms ordered O, H1, H2 for each water
    n_atoms = len(positions)
    c0 = jnp.zeros(n_atoms)
    c6_list = jnp.zeros(n_atoms)
    box_inv = jnp.linalg.inv(box)
    O = positions[::3]
    H1 = positions[1::3]
    H2 = positions[2::3]
    # O-H bond vectors, shifted to account for periodic boundary conditions
    ROH1 = H1 - O
    ROH2 = H2 - O
    ROH1 = v_pbc_shift(ROH1, box, box_inv)
    ROH2 = v_pbc_shift(ROH2, box, box_inv)
    dROH1 = jnp.linalg.norm(ROH1, axis=1)
    dROH2 = jnp.linalg.norm(ROH2, axis=1)
    # H-O-H angle in degrees
    costh = jnp.sum(ROH1 * ROH2, axis=1) / (dROH1 * dROH2)
    angle = jnp.arccos(costh) * 180 / jnp.pi
    # geometry-dependent dipole fit converted to charges;
    # charge_O = -2 * charge_H keeps every molecule neutral
    dipole = -0.016858755 + 0.002287251 * angle + 0.239667591 * dROH1 + (-0.070483437) * dROH2
    charge_H = dipole / dROH1
    charge_O = charge_H * (-2)
    # geometry-dependent C6 fits, converted from atomic units
    # (Hartree * Bohr^6) to kJ/mol * Angstrom^6
    C6_H = (-2.36066199 + (-0.007049238) * angle + 1.949429648 * dROH1+ 2.097120784 * dROH2) * 0.529**6 * 2625.5
    C6_O = (-8.641301261 + 0.093247893 * angle + 11.90395358 * (dROH1+ dROH2)) * 0.529**6 * 2625.5
    C6_H = trim_val_0(C6_H)
    # scatter per-molecule values back to per-atom arrays
    c0 = c0.at[::3].set(charge_O)
    c0 = c0.at[1::3].set(charge_H)
    c0 = c0.at[2::3].set(charge_H)
    # c6_list stores sqrt(C6); it is returned but not used by the PME term below
    c6_list = c6_list.at[::3].set(jnp.sqrt(C6_O))
    c6_list = c6_list.at[1::3].set(jnp.sqrt(C6_H))
    c6_list = c6_list.at[2::3].set(jnp.sqrt(C6_H))
    return c0, c6_list


def generate_calculator(pots, pme_generator, params):
    map_atomtype = pots.meta["ADMPPmeForce_map_atomtype"]
    map_poltype = pots.meta["ADMPPmeForce_map_poltype"]
    def admp_calculator(positions, box, pairs):
        positions = positions * 10 # convert from nm to angstrom
        box = box * 10
        c0, c6_list = compute_leading_terms(positions, box) # compute fluctuating charges
        Q_local = params["Q_local"][map_atomtype]
        Q_local = Q_local.at[:,0].set(c0)  # replace fixed charges with fluctuating ones
        # per-atom polarizabilities and Thole damping parameters
        pol = params["pol"][map_poltype]
        tholes = params["thole"][map_poltype]
        # topology-based scaling factors, reused unchanged from the generator
        mScales = pme_generator.mScales
        pScales = pme_generator.pScales
        dScales = pme_generator.dScales
        E_pme = pme_generator.pme_force.get_energy(
            positions, 
            box, 
            pairs, 
            Q_local, 
            pol, 
            tholes, 
            mScales, 
            pScales, 
            dScales
        )
        return E_pme 
    return jax.jit(admp_calculator)
```
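To look at the charge part of `compute_leading_terms` in isolation, here is a simplified sketch that reuses the same fit coefficients but skips the periodic-boundary shift and the C6 terms, so it assumes every molecule is whole (not split across the box). `water_charges` is a hypothetical helper written for illustration, not part of DMFF. It makes two properties of the model visible: atoms must be ordered O, H1, H2, and each molecule stays neutral because q_O = -2 q_H.

```python
import math
import jax.numpy as jnp


def water_charges(positions):
    """PBC-free sketch of the charge part of compute_leading_terms.

    positions: (n_atoms, 3) array in Angstrom, ordered O, H1, H2 per molecule.
    Returns per-atom charges in the same order.
    """
    O, H1, H2 = positions[::3], positions[1::3], positions[2::3]
    r1 = H1 - O
    r2 = H2 - O
    d1 = jnp.linalg.norm(r1, axis=1)
    d2 = jnp.linalg.norm(r2, axis=1)
    # H-O-H angle in degrees
    angle = jnp.degrees(jnp.arccos(jnp.sum(r1 * r2, axis=1) / (d1 * d2)))
    # same dipole fit as in compute_leading_terms
    dipole = -0.016858755 + 0.002287251 * angle + 0.239667591 * d1 - 0.070483437 * d2
    q_H = dipole / d1
    q_O = -2.0 * q_H
    # interleave back into O, H1, H2 order: (n_mol, 3) -> (n_atoms,)
    return jnp.stack([q_O, q_H, q_H], axis=1).reshape(-1)


# Usage: one water with O-H = 0.9572 A and H-O-H = 104.52 deg (Angstrom)
d, theta = 0.9572, math.radians(104.52)
water = jnp.array([
    [0.0, 0.0, 0.0],                                  # O
    [d, 0.0, 0.0],                                    # H1
    [d * math.cos(theta), d * math.sin(theta), 0.0],  # H2
])
charges = water_charges(water)
print(charges)        # per-atom charges, O, H1, H2
print(charges.sum())  # net charge of the molecule
```

Because the charges depend only on internal coordinates (bond lengths and angle), translating or rotating a molecule leaves them unchanged.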


**Finally, compute the energy and force!**

```python
# neighbor list
nbl = nblist.NeighborList(box, rc, pots.meta["cov_map"])
nbl.allocate(positions)
pairs = nbl.pairs

potential_fn = generate_calculator(pots, pme_generator, params)
ene = potential_fn(positions, box, pairs)
print(ene)
```

```
-41.261709056188494
```


```python
force_fn = jax.grad(potential_fn, argnums=0)
force = -force_fn(positions, box, pairs)
print(force)
```

```
[[ -76.31268719  117.49783627  -79.89266772]
 [ 751.2499921  -582.24588471 -251.82070224]
 [ -18.97483886  -49.68783375  146.28345763]
 [-675.35013452  382.30839617  204.50616711]
 [ -25.65479533  -52.55337869   41.92507785]
 [  45.04246381  184.68086471  -61.00133263]]
```
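Because `compute_leading_terms` is called inside `admp_calculator`, `jax.grad` differentiates through the charge model as well, so the forces above include the extra dq/dr contribution that a fixed-charge model lacks. The toy one-dimensional sketch below (not DMFF code; `charge` is a made-up linear charge model in arbitrary units) checks this against the chain rule worked out by hand:

```python
import jax


def charge(r):
    # hypothetical geometry-dependent charge (arbitrary units), linear in r
    return 0.4 + 0.1 * r


def energy_fluctuating(r):
    # Coulomb-like energy of two identical charges whose value depends on r
    q = charge(r)
    return q * q / r


def energy_fixed(r, q):
    # same energy with the charge frozen at q
    return q * q / r


r0 = 2.0
q0 = charge(r0)

dE_auto = jax.grad(energy_fluctuating)(r0)   # autodiff through charge(r)
dE_fixed = jax.grad(energy_fixed)(r0, q0)    # ignores dq/dr
# chain rule by hand: dE/dr = -q^2/r^2 + 2*q*(dq/dr)/r, with dq/dr = 0.1
dE_chain = -q0**2 / r0**2 + 2 * q0 * 0.1 / r0

print(dE_auto, dE_chain, dE_fixed)
```

The autodiff gradient matches the hand-derived chain rule, while the fixed-charge gradient differs by exactly the dq/dr term.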


The first call is slow because JAX traces the data flow and compiles the code. Once compiled, subsequent calls run much faster, until the shapes of the inputs change, which triggers a recompilation.
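A generic JAX sketch (unrelated to DMFF; `f` and `timed` are illustrative names) for observing this behavior. `block_until_ready()` is needed because JAX dispatches work asynchronously, and absolute timings depend on the hardware.

```python
import time

import jax
import jax.numpy as jnp


@jax.jit
def f(x):
    return jnp.sum(jnp.sin(x) ** 2)


def timed(x):
    t0 = time.perf_counter()
    f(x).block_until_ready()  # wait for the asynchronous computation to finish
    return time.perf_counter() - t0


x = jnp.ones(1000)
t_first = timed(x)                   # trace + compile + run
t_second = timed(x)                  # same shape/dtype: cached executable reused
t_new_shape = timed(jnp.ones(2000))  # new shape: retrace and recompile
print(f"first: {t_first:.4f} s, second: {t_second:.6f} s, new shape: {t_new_shape:.4f} s")
```

Typically the first and the new-shape calls are much slower than the second, because only the second reuses a compiled executable.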

```python
print(-force_fn(positions, box, pairs))
```

```
[[ -76.31268719  117.49783627  -79.89266772]
 [ 751.2499921  -582.24588471 -251.82070224]
 [ -18.97483886  -49.68783375  146.28345763]
 [-675.35013452  382.30839617  204.50616711]
 [ -25.65479533  -52.55337869   41.92507785]
 [  45.04246381  184.68086471  -61.00133263]]
```
