# Optimize a CG benzene model using MBAR to fit the radial distribution function


In this demo, we optimize a three-bead coarse-grained (CG) benzene model so that it reproduces the experimental center-of-mass radial distribution function (RDF). The potential function contains only a harmonic bond term and a Lennard-Jones term:

$$\begin{align*}
    V(\mathbf{R}) &= V_{\mathrm{bond}} + V_{\mathrm{vdW}} \\
    &=  \sum_{\mathrm{bonds}}\frac{1}{2}k_b(r - r_0)^2 \\
    &\quad+ \sum_{i<j}4\varepsilon_{ij}\left[\left(\frac{\sigma_{ij}}{r_{ij}}\right)^{12} - \left(\frac{\sigma_{ij}}{r_{ij}}\right)^6\right]
\end{align*}$$

## Import necessary packages & functions 

In [1]:
import openmm as mm
import openmm.app as app
import openmm.unit as unit
import numpy as np
import sys
import mdtraj as md
from tqdm import tqdm, trange
import matplotlib.pyplot as plt
from dmff.mbar import MBAREstimator, SampleState, TargetState, Sample, OpenMMSampleState, buildTrajEnergyFunction
from dmff.optimize import MultiTransform, genOptimizer
from dmff import Hamiltonian, NeighborListFreud
import optax
import jax
import jax.numpy as jnp

# Thermal energy at 303 K in kJ/mol: gas constant R = 8.314 J/(mol K) times T,
# converted from J to kJ.
kbT = 8.314 * 303 / 1000.0
# Register the residue bond templates from ben-top.xml with OpenMM's Topology
# class so bonds are assigned when the CG benzene PDB files are read.
app.Topology.loadBondDefinitions("ben-top.xml")


def readRDF(fname, xmin=2.0, xmax=14.0, npoints=121):
    """Read a two-column RDF file and resample it onto a uniform grid.

    Parameters
    ----------
    fname : str
        Path to a whitespace-separated text file whose first column is the
        distance and whose second column is g(r). Blank lines and lines
        starting with ``#`` are skipped. Rows do not need to be sorted.
    xmin, xmax : float
        First and last grid points (defaults 2.0 and 14.0). The notebook later
        converts this grid with ``x_ref*0.1`` before passing it to mdtraj
        (which works in nm), so the file is presumably in Angstrom -- confirm.
    npoints : int
        Number of grid points (default 121, i.e. a spacing of 0.1 over the
        default range).

    Returns
    -------
    xaxis : numpy.ndarray
        The uniform grid ``np.linspace(xmin, xmax, npoints)``.
    yinterp : numpy.ndarray
        g(r) linearly interpolated onto ``xaxis``. Outside the range covered
        by the file, ``np.interp`` holds the first/last tabulated value.
    """
    # ndmin=2 keeps a single-row file two-dimensional; loadtxt also skips
    # blank lines and '#' comments, which the hand-written parser choked on.
    data = np.loadtxt(fname, ndmin=2)
    xaxis = np.linspace(xmin, xmax, npoints)
    # np.interp silently returns meaningless values if xp is not increasing,
    # so sort the table by distance first (stable: already-sorted input is
    # left exactly as-is).
    order = np.argsort(data[:, 0], kind="stable")
    yinterp = np.interp(xaxis, data[order, 0], data[order, 1])
    return xaxis, yinterp

# read experimental benzene RDF
x_ref, y_ref = readRDF("benz.txt")


def sample_with_prm(parameter, trajectory, init_struct="box_relaxed.pdb",
                    temperature=303.0, pressure=1.0, nsteps=500 * 1000,
                    dcd_interval=4000, report_interval=20000):
    """Run an NPT OpenMM simulation with a given force field and save a DCD.

    Parameters
    ----------
    parameter : str
        Force-field XML file passed to ``app.ForceField``.
    trajectory : str
        Output DCD path; a frame is written every ``dcd_interval`` steps.
    init_struct : str
        PDB file providing the topology and starting coordinates.
    temperature : float
        Temperature in K for both the Langevin integrator and the barostat
        (default 303.0).
    pressure : float
        Barostat pressure in bar (default 1.0).
    nsteps : int
        Number of 1 fs MD steps run after energy minimization (default 500000).
    dcd_interval, report_interval : int
        Step intervals for the DCD writer and for the density/speed log that
        is printed to stdout.

    Returns
    -------
    None. The trajectory is written to ``trajectory`` as a side effect.
    """
    pdb = app.PDBFile(init_struct)
    ff = app.ForceField(parameter)
    # PME with a 1.1 nm cutoff -- the same settings used for the DMFF
    # potential later in the notebook.
    system = ff.createSystem(pdb.topology, nonbondedMethod=app.PME,
                             nonbondedCutoff=1.1 * unit.nanometer,
                             constraints=None)
    system.addForce(mm.MonteCarloBarostat(pressure * unit.bar,
                                          temperature * unit.kelvin, 20))
    for force in system.getForces():
        if isinstance(force, mm.NonbondedForce):
            # Plain truncated LJ: no long-range dispersion correction and no
            # switching function (presumably so OpenMM energies match the DMFF
            # Hamiltonian used for reweighting -- confirm).
            force.setUseDispersionCorrection(False)
            force.setUseSwitchingFunction(False)
    integ = mm.LangevinIntegrator(temperature * unit.kelvin,
                                  5 / unit.picosecond, 1 * unit.femtosecond)

    simulation = app.Simulation(pdb.topology, system, integ)
    simulation.context.setPositions(pdb.getPositions())
    simulation.reporters.append(app.DCDReporter(trajectory, dcd_interval))
    # totalSteps uses the same nsteps as simulation.step() so the
    # remaining-time estimate can never drift from the actual run length.
    simulation.reporters.append(app.StateDataReporter(
        sys.stdout, report_interval, density=True, step=True,
        remainingTime=True, speed=True, totalSteps=nsteps))
    simulation.minimizeEnergy()
    simulation.step(nsteps)


# Type object of JAX pytree structures. The former top-level alias
# jax.tree_structure is deprecated/removed in newer JAX releases;
# jax.tree_util.tree_structure is the supported spelling.
PyTreeDef = type(jax.tree_util.tree_structure(None))


## sample with initial parameter set

In [2]:
# Reference simulation with the starting force field; frames go to init.dcd.
sample_with_prm("ben-prm.xml", "init.dcd", init_struct="box_relaxed.pdb")
# Drop the first 50 saved frames (presumably equilibration). With 500000 steps
# and a frame every 4000 steps this keeps 75 of the 125 frames.
traj = md.load("init.dcd", top="box_relaxed.pdb")
traj = traj[50:]

#"Step","Density (g/mL)","Speed (ns/day)","Time Remaining"
20000,0.5609269274230193,0,--
40000,0.5459109958008428,154,4:18
60000,0.5327412068778505,152,4:09
80000,0.5458843337093509,153,3:57
100000,0.5499541448891778,153,3:45
120000,0.5559041043396893,152,3:35
140000,0.552264674142506,152,3:24
160000,0.5497341640178879,153,3:12
180000,0.5429753752031576,153,3:01
200000,0.5483391135340221,153,2:49
220000,0.5505668336047832,153,2:38
240000,0.5335132149170563,153,2:26
260000,0.5612835029864318,153,2:15
280000,0.5539964740924908,153,2:04
300000,0.5562792831995799,153,1:52
320000,0.5810152336723056,153,1:41
340000,0.5578735559454526,153,1:30
360000,0.5646608102872624,153,1:19
380000,0.5715438669104802,153,1:07
400000,0.5615453644018904,153,0:56
420000,0.5729058327239069,153,0:45
440000,0.5551077875519419,153,0:33
460000,0.575714406552948,153,0:22
480000,0.5542066355298249,153,0:11
500000,0.5602510266387776,153,0:00


## compute radial distribution function per frame

In [4]:
def _bead_centroids(xyz, beads_per_mol):
    """Average consecutive groups of ``beads_per_mol`` atoms into one point each.

    ``xyz`` has shape (n_frames, n_atoms, 3); the result has shape
    (n_frames, n_atoms // beads_per_mol, 3) and is accumulated in float64.
    Raises ValueError if n_atoms is not a multiple of ``beads_per_mol``.
    """
    n_frames, n_atoms, _ = xyz.shape
    if n_atoms % beads_per_mol:
        raise ValueError(
            f"{n_atoms} atoms cannot be split into molecules of {beads_per_mol} beads"
        )
    grouped = xyz.reshape(n_frames, n_atoms // beads_per_mol, beads_per_mol, 3)
    return grouped.mean(axis=2, dtype=np.float64)


def _unique_pairs(n):
    """Return every index pair (i, j) with 0 <= i < j < n as an (m, 2) int array."""
    first, second = np.triu_indices(n, k=1)
    return np.column_stack((first, second))


def compute_rdf_frame(traj, xaxis, beads_per_mol=3, n_molecules=None):
    """Compute a centre-of-mass RDF for every frame of a trajectory.

    Molecules are assumed to be stored as consecutive groups of
    ``beads_per_mol`` atoms. Each molecule is replaced by the unweighted mean
    of its bead positions; this equals the centre of mass only if the beads
    have equal masses (assumed for the CG model -- confirm).

    Parameters
    ----------
    traj : mdtraj.Trajectory
        Trajectory whose unit-cell information is used by ``md.compute_rdf``.
    xaxis : numpy.ndarray
        Uniformly spaced bin centres in nm. Bins of width
        ``xaxis[1] - xaxis[0]`` are centred on these values.
    beads_per_mol : int, optional
        Beads per molecule (default 3, the three-bead benzene model).
    n_molecules : int, optional
        Number of molecules to include, counted from the first atom. Defaults
        to ``traj.n_atoms // beads_per_mol``. The original version hard-coded
        200.

    Returns
    -------
    numpy.ndarray
        Array of shape (n_frames, len(xaxis)) holding g(r) for each frame.

    Raises
    ------
    ValueError
        If ``n_molecules`` needs more atoms than the trajectory has.
    """
    delta = xaxis[1] - xaxis[0]
    if n_molecules is None:
        n_molecules = traj.n_atoms // beads_per_mol
    n_used = n_molecules * beads_per_mol
    if n_used > traj.n_atoms:
        raise ValueError(
            f"{n_molecules} molecules of {beads_per_mol} beads need {n_used} "
            f"atoms, but the trajectory has {traj.n_atoms}"
        )

    com = _bead_centroids(traj.xyz[:, :n_used, :], beads_per_mol)

    # Keep one atom per molecule so the topology and unit cell are carried
    # over, then overwrite its coordinates with the molecular centroids.
    tsub = traj.atom_slice(np.arange(n_molecules) * beads_per_mol)
    tsub.xyz = com

    pairs = _unique_pairs(n_molecules)
    # Half-bin padding centres the histogram bins on xaxis. The tiny epsilon
    # presumably keeps rounding in mdtraj's bin count from dropping the last
    # bin.
    r_range = (xaxis[0] - 0.5 * delta, xaxis[-1] + 0.5 * delta + 1e-10)

    rdf_list = [
        md.compute_rdf(frame, pairs, r_range=r_range, bin_width=delta)[1]
        for frame in tsub
    ]
    if not rdf_list:
        # np.stack rejects an empty list; return a correctly shaped empty array.
        return np.empty((0, len(xaxis)))
    return np.stack(rdf_list, axis=0)


## initialize MBAR estimator

In [5]:
# Wrap the reference simulation (force field ben-prm.xml, 303 K, 1 bar) as a
# sampled state and attach the trajectory loaded above as its samples.
state_name = "ben-prm"
state = OpenMMSampleState(state_name, "ben-prm.xml", "box_relaxed.pdb", temperature=303.0, pressure=1.0)
sample = Sample(traj, state_name)


estimator = MBAREstimator()
estimator.add_state(state)
estimator.add_sample(sample)
# Presumably solves the MBAR equations for the currently registered states.
estimator.optimize_mbar()
# Per-frame centre-of-mass RDFs of all pooled samples. x_ref*0.1 converts the
# reference grid to nm (assumes benz.txt is in Angstrom -- confirm).
# NOTE(review): relies on the private MBAREstimator attribute _full_samples.
rdf_frames = compute_rdf_frame(estimator._full_samples, x_ref*0.1)

100%|██████████████████████████████████████████| 75/75 [00:00<00:00, 184.40it/s]


## define a function to calculate DMFF energy using mdtraj.Trajectory as input

Here we use the `buildTrajEnergyFunction` generator to build a function that computes the energy of every frame of an MDTraj trajectory for a given parameter set.

In [6]:
# DMFF (JAX) Hamiltonian built from the same force-field XML and starting
# structure as the OpenMM runs, with the same PME method and 1.1 nm cutoff.
hamilt = Hamiltonian("ben-prm.xml")
top_pdb = app.PDBFile("box_relaxed.pdb")
pot = hamilt.createPotential(
    top_pdb.topology,
    nonbondedMethod=app.PME,
    nonbondedCutoff=1.1*unit.nanometer,
    ethresh=1e-4,  # presumably the PME accuracy threshold -- confirm
)
efunc = pot.getPotentialFunc()

# Per-frame energy function for the reweighting target. The positional 1.1 is
# presumably the neighbour-list cutoff in nm (matching nonbondedCutoff above);
# ensemble="npt" matches the barostatted sampling.
target_energy_function = buildTrajEnergyFunction(
    efunc, pot.meta["cov_map"], 1.1, ensemble="npt"
)

## Create optax transforms 

We also need to create a transform for each force field parameter to be fitted. Parameters that are not assigned a transform below are kept fixed during optimization.

In [7]:
# One optimizer per fitted parameter group; groups not assigned here stay fixed.
# Learning rate and clip are set per group: the bond force constant k gets much
# larger values than the LJ sigma/epsilon (presumably because of its larger
# magnitude -- see dmff.optimize.genOptimizer for the exact meaning of clip).
multiTrans = MultiTransform(hamilt.paramtree)
multiTrans["LennardJonesForce/sigma"] = genOptimizer(learning_rate=0.005, clip=0.05)
multiTrans["LennardJonesForce/epsilon"] = genOptimizer(learning_rate=0.005, clip=0.05)
multiTrans["HarmonicBondForce/k"] = genOptimizer(learning_rate=10.0, clip=10.0)
# Must be called after all assignments; presumably builds .transforms/.labels.
multiTrans.finalize()

## Initialize optimizer

In [8]:
# Route each leaf of the parameter tree to its transform via multiTrans.labels,
# then initialize the optimizer state from the current parameters.
grad_transform = optax.multi_transform(multiTrans.transforms, multiTrans.labels)
opt_state = grad_transform.init(hamilt.paramtree)

## Run optimization loop

In [None]:
for nloop in range(1, 51):
    print("LOOP", nloop)
    # Reweighting target: the DMFF energy at 303 K, evaluated with whatever
    # parameters are handed to estimate_weight.
    target_state = TargetState(303.0, target_energy_function)

    def lossfunc(param):
        """Log mean-squared error between the reweighted and reference RDFs.

        Returns ``(loss, utarget)``. utarget is returned as auxiliary output so
        the effective sample size can be estimated without re-evaluating it.
        """
        weight, utarget = estimator.estimate_weight(target_state, parameters=param)
        # Reweighted average of the per-frame RDFs.
        rdf_pert = (rdf_frames * weight.reshape((-1, 1))).sum(axis=0)
        loss_ref = jnp.log(jnp.power(rdf_pert - y_ref, 2).mean())
        return loss_ref, utarget

    (loss, utarget), g = jax.value_and_grad(lossfunc, 0, has_aux=True)(hamilt.paramtree)
    print("Loss:", loss)
    # Effective sample size per sampled state, plus a "Total" entry. It is
    # measured at the parameters the loss was just evaluated with, i.e. before
    # the update below.
    ieff = estimator.estimate_effective_sample(utarget, decompose=True)

    updates, opt_state = grad_transform.update(g, opt_state, params=hamilt.paramtree)
    newprm = optax.apply_updates(hamilt.paramtree, updates)
    hamilt.updateParameters(newprm)
    # render optimized parameters in xml force field
    hamilt.render(f"loop-{nloop}.xml")

    print("Neff:", ieff["Total"])

    # Report the effective sample size of every state.
    print("Total effective samples:")
    for k, v in ieff.items():
        print(f"{k}: {v}")

    # Remove every sampled state whose effective sample size dropped below 5.
    for k, v in ieff.items():
        if v < 5 and k != "Total":
            estimator.remove_state(k)

    # If all states were removed, resample with the freshly updated parameters.
    if len(estimator.states) < 1:
        print("Add", f"loop-{nloop}")
        sample_with_prm(f"loop-{nloop}.xml", f"loop-{nloop}.dcd")
        traj = md.load(f"loop-{nloop}.dcd", top="box_relaxed.pdb")[50:]
        state = OpenMMSampleState(f"loop-{nloop}", f"loop-{nloop}.xml", "box_relaxed.pdb", temperature=303.0, pressure=1.0)
        sample = Sample(traj, f"loop-{nloop}")
        estimator.add_state(state)
        estimator.add_sample(sample)

        # Save a comparison of the new simulation's RDF with the reference.
        draw_frames = compute_rdf_frame(traj, x_ref*0.1)
        fig = plt.figure()
        plt.plot(x_ref, draw_frames.mean(axis=0), label="simulated")
        plt.plot(x_ref, y_ref, label="reference")
        plt.legend()
        fig.savefig(f"com-{nloop}.png")
        # Close explicitly: pyplot keeps every figure alive until closed, so
        # without this one figure would accumulate per resampling event.
        plt.close(fig)

    estimator.optimize_mbar()
    rdf_frames = compute_rdf_frame(estimator._full_samples, x_ref*0.1)

LOOP 1


100%|███████████████████████████████████████████| 75/75 [00:04<00:00, 16.20it/s]
100%|███████████████████████████████████████████| 75/75 [00:46<00:00,  1.61it/s]


Loss: -1.8043019
Neff: 74.99999916040227
Total effective samples:
ben-prm: 73
Total: 74.99999916040227
LOOP 2


100%|██████████████████████████████████████████| 75/75 [00:00<00:00, 219.57it/s]
100%|███████████████████████████████████████████| 75/75 [00:45<00:00,  1.64it/s]


Loss: -1.9461998
Neff: 1.0820234332264933
Total effective samples:
ben-prm: 1
Total: 1.0820234332264933
Add loop-2
#"Step","Density (g/mL)","Speed (ns/day)","Time Remaining"
20000,0.5760809744941384,0,--
40000,0.5833184680811015,156,4:14
60000,0.5781627044833334,156,4:04
80000,0.5940449427033085,155,3:53
100000,0.5572647706997856,155,3:43
120000,0.5734518088240227,155,3:32
140000,0.5678675808280508,155,3:21
160000,0.5802075827897538,155,3:09
180000,0.5636034778610327,155,2:58
200000,0.5828136632596796,155,2:47
220000,0.607298999952365,155,2:36
240000,0.5981345756912775,154,2:25
260000,0.5811187528304593,154,2:14
280000,0.5692368156592845,153,2:03
300000,0.5830765331944094,153,1:52
320000,0.5859934675864289,153,1:41
340000,0.5906165150125336,153,1:30
360000,0.593990819511393,152,1:19
380000,0.5870797573618849,152,1:08
400000,0.5827299014964887,152,0:56
420000,0.591129033152864,152,0:45
440000,0.5828923868029461,152,0:34
460000,0.5858575619688796,152,0:22
480000,0.5841300183857974,152,0:

100%|██████████████████████████████████████████| 75/75 [00:00<00:00, 122.79it/s]


LOOP 3


100%|███████████████████████████████████████████| 75/75 [00:04<00:00, 15.09it/s]
100%|███████████████████████████████████████████| 75/75 [00:48<00:00,  1.56it/s]


Loss: -2.02044
Neff: 74.99999890053131
Total effective samples:
loop-2: 73
Total: 74.99999890053131
LOOP 4


100%|██████████████████████████████████████████| 75/75 [00:00<00:00, 108.39it/s]
100%|███████████████████████████████████████████| 75/75 [00:46<00:00,  1.60it/s]


Loss: -2.0885944
Neff: 7.05485934835645
Total effective samples:
loop-2: 7
Total: 7.05485934835645
LOOP 5


100%|██████████████████████████████████████████| 75/75 [00:00<00:00, 193.46it/s]
 13%|█████▋                                     | 10/75 [00:06<00:41,  1.58it/s]