In [1]:
import os
import shutil
import sys

import openmm.app as app
import openmm as mm
import openmm.unit as unit
import numpy as np
import jax
import jax.numpy as jnp
import dmff
from dmff.api.xmlio import XMLIO
from dmff.api.paramset import ParamSet
from dmff.generators.classical import CoulombGenerator, LennardJonesGenerator
from dmff.api.hamiltonian import Hamiltonian
from dmff.operators import ParmedLennardJonesOperator
from dmff import NeighborListFreud
from dmff.mbar import MBARSimpleEstimator
import mdtraj as md
from tqdm import tqdm, trange
import parmed

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


In [2]:
# Load the single-molecule GROMACS topology and replicate it 500 times to
# describe the whole simulation box (runMD / rerun_energy build the same
# 500-copy system on the OpenMM side).
prm_top = parmed.load_file("Lig.top")
prm_top_500 = prm_top * 500
# Wrap the replicated parmed topology in a DMFF topology, then apply the
# parmed LJ operator to it; presumably this attaches LJ atom types/parameters
# taken from `gmx_top` (TODO confirm against the DMFF operator docs).
dmfftop = dmff.DMFFTopology(from_top=prm_top_500.topology)
prmop = ParmedLennardJonesOperator()
dmfftop = prmop(dmfftop, gmx_top = prm_top_500)

# Dump what the operator collected to "init.xml" and parse it back with
# DMFF's XML reader; `ffinfo` feeds the LJ generator below.
prmop.renderLennardJonesXML("init.xml")
xmlio = XMLIO()
xmlio.loadXML("init.xml")
ffinfo = xmlio.parseXML()
# Covalent map of the full box; used by NeighborListFreud in
# rerun_dmff_lennard_jones.
cov_mat = dmfftop.buildCovMat()

# `paramset` is handed to the generator (presumably populated by it) and is
# later the argument differentiated with jax.value_and_grad.
paramset = ParamSet()
lj_gen = LennardJonesGenerator(ffinfo, paramset)

In [3]:
# DMFF Lennard-Jones energy function for the whole box. The 1.0 cutoff is the
# same value given to NeighborListFreud in rerun_dmff_lennard_jones.
lj_force = lj_gen.createPotential(
    dmfftop,
    nonbondedMethod=app.CutoffPeriodic,
    nonbondedCutoff=1.0,
    args={},
)

### Sampling function for a given topology (基于给定拓扑进行采样的函数)

In [4]:
def runMD(topfile, pdbfile, trajfile, length, n_copies=500, steps_per_frame=400):
    """Run NPT Langevin dynamics on a box of `n_copies` copies of one molecule.

    The single-molecule GROMACS topology is replicated with parmed and written to
    a temporary ``Lig_<n_copies>.top`` in the working directory, because OpenMM's
    GromacsTopFile reads from disk. The file is always removed again once parsed.

    Parameters
    ----------
    topfile : str
        Single-molecule GROMACS topology.
    pdbfile : str
        PDB with the starting coordinates and periodic box of the full system.
    trajfile : str
        Output DCD path; one frame is written every `steps_per_frame` steps.
    length : int-like
        Number of frames to produce; the run is ``int(length) * steps_per_frame`` steps.
    n_copies : int, optional
        Number of molecule copies in the box (default 500).
    steps_per_frame : int, optional
        Steps between DCD frames and log lines (default 400, i.e. 1 ps at 2.5 fs).
    """
    expanded_top = f"Lig_{n_copies}.top"
    # Clear a leftover from an aborted run (presumably parmed's save will not
    # overwrite an existing file). Only "missing file" is expected here.
    try:
        os.remove(expanded_top)
    except FileNotFoundError:
        pass
    (parmed.load_file(topfile) * n_copies).save(expanded_top)
    try:
        top = app.GromacsTopFile(expanded_top)
    finally:
        # The file is only needed while parsing; clean it up even on failure.
        os.remove(expanded_top)

    pdb = app.PDBFile(pdbfile)
    top.topology.setPeriodicBoxVectors(pdb.topology.getPeriodicBoxVectors())
    system = top.createSystem(nonbondedMethod=app.PME, nonbondedCutoff=1.0*unit.nanometer,
                              constraints=app.HBonds, hydrogenMass=3*unit.dalton)
    # No long-range dispersion correction, matching rerun_energy; presumably to
    # stay consistent with the plain-cutoff DMFF LJ term used for reweighting.
    for force in system.getForces():
        if isinstance(force, mm.NonbondedForce):
            force.setUseDispersionCorrection(False)
    system.addForce(mm.MonteCarloBarostat(1.0*unit.bar, 300*unit.kelvin, 25))
    integ = mm.LangevinIntegrator(300*unit.kelvin, 1/unit.picosecond, 2.5*unit.femtosecond)
    simulation = app.Simulation(top.topology, system, integ)

    n_steps = int(length) * steps_per_frame
    simulation.reporters.append(app.StateDataReporter(
        sys.stdout, steps_per_frame, time=True, potentialEnergy=True, temperature=True,
        density=True, speed=True, remainingTime=True, totalSteps=n_steps))
    simulation.reporters.append(app.DCDReporter(trajfile, steps_per_frame))
    simulation.context.setPositions(pdb.getPositions())
    simulation.minimizeEnergy(maxIterations=200)
    simulation.step(n_steps)

# runMD("Lig.top", "init.pdb", "traj.dcd", 100)

In [16]:
def rerun_energy(pdb, traj, top, skip=20, removeLJ=True, n_copies=500):
    """Re-evaluate the OpenMM potential energy of every frame of a trajectory.

    Parameters
    ----------
    pdb : str
        PDB file: topology for loading `traj` and source of the periodic box.
    traj : str
        Trajectory file to re-evaluate (loaded with mdtraj).
    top : str
        Single-molecule GROMACS topology, replicated `n_copies` times.
    skip : int, optional
        NOTE(review): currently ignored; every frame is evaluated. Callers slice
        the returned array themselves (e.g. ``ener[50:]``). Kept for compatibility.
    removeLJ : bool, optional
        If True (default), every particle and exception gets epsilon = 0, so the
        returned energy excludes all Lennard-Jones terms (charges are kept).
    n_copies : int, optional
        Number of molecule copies in the box (default 500).

    Returns
    -------
    numpy.ndarray
        One potential energy per frame, in kJ/mol.
    """
    samples = md.load(traj, top=pdb)

    expanded_top = f"Lig_{n_copies}.top"
    # Clear a leftover from an aborted run; only "missing file" is expected.
    try:
        os.remove(expanded_top)
    except FileNotFoundError:
        pass
    (parmed.load_file(top) * n_copies).save(expanded_top)
    try:
        omm_top = app.GromacsTopFile(expanded_top)
    finally:
        # Parsing is done in the constructor; remove the file even on failure.
        os.remove(expanded_top)

    omm_pdb = app.PDBFile(pdb)
    omm_top.topology.setPeriodicBoxVectors(omm_pdb.topology.getPeriodicBoxVectors())
    system = omm_top.createSystem(nonbondedMethod=app.PME, nonbondedCutoff=1.0*unit.nanometer,
                                  constraints=app.HBonds, hydrogenMass=3*unit.dalton)

    for force in system.getForces():
        if not isinstance(force, mm.NonbondedForce):
            continue
        force.setUseDispersionCorrection(False)
        if removeLJ:
            # Keep electrostatics, zero LJ: sigma=1.0 is arbitrary once epsilon=0.
            for npart in range(force.getNumParticles()):
                charge, _sigma, _eps = force.getParticleParameters(npart)
                force.setParticleParameters(npart, charge, 1.0, 0.0)
            for nex in range(force.getNumExceptions()):
                p1, p2, charge_prod, _sigma, _eps = force.getExceptionParameters(nex)
                force.setExceptionParameters(nex, p1, p2, charge_prod, 1.0, 0.0)

    # The integrator is never stepped; a Context simply requires one.
    integ = mm.LangevinIntegrator(300*unit.kelvin, 1/unit.picosecond, 2.5*unit.femtosecond)
    ctx = mm.Context(system, integ)
    energies = []
    for frame in tqdm(samples):
        ctx.setPositions(frame.xyz[0] * unit.nanometer)
        ctx.setPeriodicBoxVectors(*frame.unitcell_vectors[0])
        # Re-impose the HBond constraints on the stored coordinates before evaluating.
        ctx.applyConstraints(1e-10)
        potential = ctx.getState(getEnergy=True).getPotentialEnergy()
        energies.append(potential.value_in_unit(unit.kilojoule_per_mole))
    return np.array(energies)

# rerun_energy("init.pdb", "traj.dcd", "Lig.top")

In [17]:
def rerun_dmff_lennard_jones(params, pdb, traj, efunc, skip=20, cutoff=1.0, covalent_map=None):
    """Evaluate a DMFF energy function on every frame of a trajectory.

    Neighbor lists (NeighborListFreud) are built for all frames first, padded to
    a common length, and then ``efunc(xyz, box, pairs, params)`` is called frame
    by frame. `params` is passed straight through, so gradients w.r.t. it can
    flow through the returned energies.

    Parameters
    ----------
    params
        Parameter object forwarded as the last argument of `efunc` (a DMFF
        ``ParamSet`` in this notebook).
    pdb, traj : str
        Topology and trajectory, loaded with ``md.load(traj, top=pdb)``.
    efunc : callable
        DMFF potential, called as ``efunc(xyz, box, pairs, params)``.
    skip : int, optional
        NOTE(review): currently ignored; all frames are evaluated. Kept for
        compatibility.
    cutoff : float, optional
        Neighbor-list cutoff in trajectory length units (nm for mdtraj). The default
        of 1.0 matches the ``nonbondedCutoff`` used to create `lj_force`.
    covalent_map : optional
        Covalent map for the neighbor list; defaults to the notebook-global ``cov_mat``.

    Returns
    -------
    jax.Array
        Shape ``(n_frames,)``, one energy per frame.
    """
    if covalent_map is None:
        covalent_map = cov_mat
    samples = md.load(traj, top=pdb)
    n_frames, n_atoms = samples.n_frames, samples.n_atoms
    if n_frames == 0:
        return jnp.zeros((0,))

    nblist = NeighborListFreud(samples.unitcell_vectors[0], cutoff, covalent_map)
    # Pair-list lengths differ between frames, so compute all of them first and
    # pad to one shape.
    pairs_per_frame = [
        nblist.allocate(frame.xyz[0], frame.unitcell_vectors[0])
        for frame in tqdm(samples)
    ]
    pmax = max(pairs.shape[0] for pairs in pairs_per_frame)
    # Pad with the out-of-range index n_atoms. NOTE(review): assumes efunc treats
    # index == n_atoms as a dummy pair; confirm against DMFF.
    pairs_padded = np.full((n_frames, pmax, 3), n_atoms, dtype=int)
    for n, pairs in enumerate(pairs_per_frame):
        pairs_padded[n, :pairs.shape[0], :] = pairs
    pairs_jnp = jnp.array(pairs_padded)

    xyzs_jnp = jnp.array(samples.xyz)
    cell_jnp = jnp.array(samples.unitcell_vectors)
    energies = [
        efunc(xyzs_jnp[i], cell_jnp[i], pairs_jnp[i], params).reshape((1,))
        for i in trange(n_frames)
    ]
    return jnp.concatenate(energies)

In [7]:
# OpenMM rerun of traj.dcd: the full potential energy, then the same energy
# with every Lennard-Jones term switched off.
ener = rerun_energy(pdb="init.pdb", traj="traj.dcd", top="Lig.top", removeLJ=False)
ener_no_lj = rerun_energy(pdb="init.pdb", traj="traj.dcd", top="Lig.top", removeLJ=True)

100%|██████████| 80/80 [00:02<00:00, 28.26it/s]
100%|██████████| 80/80 [00:03<00:00, 26.37it/s]


In [8]:
# MBAR estimator for the sampled state. The trajectory, `ener` and `ener_no_lj`
# must cover exactly the same frames. All frames are used here, which also
# matches the full-length `pj` reweighted below. To discard burn-in, slice the
# trajectory, `ener`, `ener_no_lj` and `pj` by the same amount.
estimator = MBARSimpleEstimator(md.load("traj.dcd", top="init.pdb"), ener, base_energies=ener_no_lj)

In [9]:
# DMFF Lennard-Jones energy of every frame of traj.dcd under the current `paramset`.
pj = rerun_dmff_lennard_jones(params=paramset, pdb="init.pdb", traj="traj.dcd", efunc=lj_force)

100%|██████████| 80/80 [00:11<00:00,  6.69it/s]
100%|██████████| 80/80 [00:08<00:00,  9.88it/s]


In [10]:
# Fit the MBAR estimator, then compute per-frame reweighting weights for the
# DMFF LJ energies `pj` (one weight per frame, ~1/80 each in the output below).
estimator.optimize_mbar()
estimator.estimate_weight(pj)

Array([0.01250913, 0.01250913, 0.01253358, 0.01249692, 0.01250913,
       0.01247253, 0.01248472, 0.01248472, 0.01255809, 0.01249692,
       0.01241178, 0.01243605, 0.01248472, 0.01253358, 0.01248472,
       0.01248472, 0.01246036, 0.01249692, 0.01248472, 0.01250913,
       0.01246036, 0.01253358, 0.01246036, 0.01250913, 0.01246036,
       0.01253358, 0.01247253, 0.01249692, 0.01246036, 0.01246036,
       0.01254583, 0.01252135, 0.01252135, 0.01247253, 0.01243605,
       0.01250913, 0.01247253, 0.01250913, 0.01254583, 0.01253358,
       0.01253358, 0.01250913, 0.01248472, 0.01250913, 0.01252135,
       0.01247253, 0.01252135, 0.01249692, 0.01255809, 0.01252135,
       0.01248472, 0.01252135, 0.01253358, 0.01246036, 0.01253358,
       0.01252135, 0.01249692, 0.01248472, 0.01248472, 0.01249692,
       0.01248472, 0.01247253, 0.01250913, 0.01250913, 0.01250913,
       0.01249692, 0.01254583, 0.01250913, 0.01252135, 0.01248472,
       0.01248472, 0.01252135, 0.01253358, 0.01249692, 0.01250

In [19]:
# Iterative reweighting loop (work in progress): sample with the current
# topology, rerun the OpenMM energies, build an MBAR estimator, and evaluate the
# density loss and its gradient w.r.t. the DMFF parameters.
N_LOOPS = 50
SKIP_FRAMES = 50        # burn-in frames discarded from each trajectory
TARGET_DENSITY = 0.85   # g/mL

shutil.copyfile("Lig.top", "loop-0.top")
for nloop in range(1, N_LOOPS + 1):
    # sample
    print("SAMPLE")
    runMD(f"loop-{nloop-1}.top", "init.pdb", f"loop-{nloop}.dcd", length=250)
    print("RERUN")
    ener = rerun_energy("init.pdb", f"loop-{nloop}.dcd", f"loop-{nloop-1}.top", removeLJ=False)
    ener_no_lj = rerun_energy("init.pdb", f"loop-{nloop}.dcd", f"loop-{nloop-1}.top")
    print("ESTIMATOR")
    traj = md.load(f"loop-{nloop}.dcd", top="init.pdb")
    estimator = MBARSimpleEstimator(traj[SKIP_FRAMES:], ener[SKIP_FRAMES:],
                                    base_energies=ener_no_lj[SKIP_FRAMES:])
    estimator.optimize_mbar()

    print("CALC DENSE")
    # md.density returns kg/m^3 (~1020 for this system); convert to g/mL so it
    # is comparable with TARGET_DENSITY.
    density = md.density(traj[SKIP_FRAMES:]) / 1000.0

    def loss(paramset):
        """Squared deviation of the reweighted mean density (g/mL) from TARGET_DENSITY."""
        lj_jax = rerun_dmff_lennard_jones(paramset, "init.pdb", f"loop-{nloop}.dcd", lj_force)
        # Drop the same burn-in frames as the estimator so the weights line up
        # with `density` frame by frame.
        weight = estimator.estimate_weight(lj_jax[SKIP_FRAMES:])
        dens = (weight * density).sum()
        return jnp.power(dens - TARGET_DENSITY, 2)

    v, g = jax.value_and_grad(loss, 0)(paramset)
    # TODO: update the parameters from `g`, write loop-{nloop}.top, then remove
    # this break. Without it, the next iteration would look for a
    # loop-{nloop}.top that is never written.
    break

SAMPLE
#"Time (ps)","Potential Energy (kJ/mole)","Temperature (K)","Density (g/mL)","Speed (ns/day)","Time Remaining"
0.9999999999999897,9906.195431707602,189.56557559842452,1.0589467160860682,0,--
1.9999999999999685,15665.554806707602,263.83358448725124,1.0536357158694805,504,0:42
2.999999999999947,18350.3204317076,288.0161941328282,1.0392838755533151,518,0:41
3.999999999999926,19489.2969942076,293.4450387466756,1.0348227925294033,522,0:40
5.000000000000082,19767.3829317076,296.74764647104627,1.0234088379128177,524,0:40
6.000000000000238,19926.4923067076,297.4872797015903,1.0269794644060593,525,0:40
7.000000000000394,20177.0938692076,304.59497914631845,1.0223392051765439,525,0:40
8.00000000000055,20548.8751192076,297.69283068792157,1.0268097919043793,526,0:39
9.000000000000352,20522.4610567076,299.3977485317577,1.025601564648714,526,0:39
10.000000000000153,20879.9766817076,295.1012849058824,1.0265477306771427,527,0:39
10.999999999999954,20399.9688692076,302.54388628814496,1.0260564413

100%|██████████| 250/250 [00:08<00:00, 30.82it/s]
100%|██████████| 250/250 [00:08<00:00, 30.05it/s]


ESTIMATOR
CALC DENSE


100%|██████████| 250/250 [00:35<00:00,  7.07it/s]


: 

: 

In [12]:
# Reload the exploratory trajectory "traj.dcd" and drop its first 20 frames.
traj = md.load("traj.dcd", top="init.pdb")[20:]

In [14]:
# Per-frame density of the trimmed trajectory. mdtraj reports kg/m^3 (~1020
# here), whereas the OpenMM StateDataReporter log above prints g/mL (~1.02).
md.density(traj)

array([1020.7423 , 1020.54297, 1023.51685, 1023.11426, 1017.919  ,
       1018.255  , 1019.1031 , 1025.371  , 1027.7782 , 1029.2797 ,
       1022.16235, 1019.91113, 1032.4197 , 1013.8764 , 1025.8369 ,
       1028.1682 , 1031.771  , 1024.9465 , 1023.2987 , 1012.95087,
       1024.6608 , 1028.2667 , 1027.732  , 1011.63086, 1020.11383,
       1018.0711 , 1022.37286, 1022.6274 , 1022.18243, 1025.5342 ,
       1023.39355, 1025.1353 , 1015.5803 , 1013.7192 , 1019.0011 ,
       1025.2487 , 1032.2894 , 1025.2408 , 1025.2838 , 1028.072  ,
       1031.7653 , 1029.5503 , 1030.5808 , 1032.1357 , 1027.9296 ,
       1030.8942 , 1023.5372 , 1018.97437, 1025.3467 , 1023.34485,
       1030.7388 , 1023.0421 , 1017.71844, 1021.6171 , 1005.0613 ,
       1006.9328 , 1016.72064, 1021.7599 , 1022.00946, 1024.5492 ,
       1025.2864 , 1030.8772 , 1022.34924, 1024.1471 , 1019.6108 ,
       1017.5288 , 1021.27637, 1008.66846, 1021.22546, 1026.0709 ,
       1024.3787 , 1012.9853 , 1019.5186 , 1021.5435 , 1025.31