In [None]:
import os
import shutil
import sys

import openmm.app as app
import openmm as mm
import openmm.unit as unit
import numpy as np
import jax
import jax.numpy as jnp
import dmff
from dmff.api.xmlio import XMLIO
from dmff.api.paramset import ParamSet
from dmff.generators.classical import CoulombGenerator, LennardJonesGenerator
from dmff.api.hamiltonian import Hamiltonian
from dmff.operators import ParmedLennardJonesOperator
from dmff import NeighborListFreud
from dmff.mbar import ReweightEstimator
import mdtraj as md
from tqdm import tqdm, trange
import parmed

In [None]:
# Load the ligand topology with ParmEd and replicate it 500 times, then build
# the DMFF topology from the replicated system.
prm_top = parmed.load_file("Lig.top")
prm_top_500 = prm_top * 500
dmfftop = dmff.DMFFTopology(from_top=prm_top_500.topology)
# The operator is applied to the DMFF topology together with the replicated
# ParmEd structure; presumably it attaches the Lennard-Jones typing (confirm).
prmop = ParmedLennardJonesOperator()
dmfftop = prmop(dmfftop, gmx_top = prm_top_500)

# Round-trip the Lennard-Jones terms through an XML file ("init.xml" is
# written to the working directory) to obtain the parsed force-field info.
prmop.renderLennardJonesXML("init.xml")
xmlio = XMLIO()
xmlio.loadXML("init.xml")
ffinfo = xmlio.parseXML()
# cov_mat is read as a module-level global by rerun_dmff_lennard_jones when it
# builds the neighbor list.
cov_mat = dmfftop.buildCovMat()

# paramset is handed to the generator here and is later optimized and written
# back with lj_gen.overwrite(...) in the training loop.
paramset = ParamSet()
lj_gen = LennardJonesGenerator(ffinfo, paramset)

In [None]:
# Build the Lennard-Jones energy function for the replicated topology. It is
# later called as lj_force(positions, box, pairs, params) by
# rerun_dmff_lennard_jones. The cutoff of 1.0 matches the value used for the
# neighbor list there (presumably nanometers -- confirm).
lj_force = lj_gen.createPotential(
    dmfftop, nonbondedMethod=app.CutoffPeriodic, nonbondedCutoff=1.0, args={})

### 基于给定拓扑进行采样的函数

In [None]:
def runMD(topfile, pdbfile, trajfile, length):
    """Sample an NPT trajectory for a 500-fold replica of ``topfile``.

    The topology is loaded with ParmEd, multiplied by 500 and written to the
    scratch file ``Lig_500.top`` in the working directory so that OpenMM can
    read it back. The scratch file is removed afterwards, including when the
    simulation raises.

    Simulation settings: PME with a 1.0 nm cutoff, H-bond constraints,
    hydrogen mass of 3 Da, dispersion correction disabled, Monte Carlo
    barostat (1 bar, 300 K, attempted every 25 steps) and a Langevin
    integrator (300 K, 1/ps friction, 2.5 fs time step). The energy is
    minimized for at most 200 iterations before dynamics start. Progress is
    reported to stdout every 400 steps.

    Parameters
    ----------
    topfile : str
        Topology file readable by ``parmed.load_file``.
    pdbfile : str
        PDB file supplying the starting positions and the periodic box
        vectors (presumably of the already replicated system -- confirm).
    trajfile : str
        Path of the DCD trajectory to write; one frame every 400 steps.
    length : int
        Number of 400-step blocks to run (400 steps x 2.5 fs = 1 ps each).
        The value is passed through ``int()``.
    """
    scratch_top = "Lig_500.top"
    steps_per_block = 400
    n_steps = int(length) * steps_per_block

    # Clear a stale scratch file left by an interrupted run; presumably needed
    # because ParmEd's save() refuses to overwrite an existing file (confirm).
    # Only "file is missing" is ignored -- other errors (e.g. permissions)
    # now surface instead of being swallowed.
    try:
        os.remove(scratch_top)
    except FileNotFoundError:
        pass

    top_prm = parmed.load_file(topfile)
    top_500 = top_prm * 500
    try:
        top_500.save(scratch_top)
        pdb = app.PDBFile(pdbfile)
        top = app.GromacsTopFile(scratch_top)
        top.topology.setPeriodicBoxVectors(pdb.topology.getPeriodicBoxVectors())
        system = top.createSystem(
            nonbondedMethod=app.PME,
            nonbondedCutoff=1.0*unit.nanometer,
            constraints=app.HBonds,
            hydrogenMass=3*unit.dalton,
        )
        for force in system.getForces():
            if isinstance(force, mm.NonbondedForce):
                force.setUseDispersionCorrection(False)
        system.addForce(mm.MonteCarloBarostat(1.0*unit.bar, 300*unit.kelvin, 25))
        integ = mm.LangevinIntegrator(300*unit.kelvin, 1/unit.picosecond, 2.5*unit.femtosecond)
        simulation = app.Simulation(top.topology, system, integ)
        simulation.reporters.append(
            app.StateDataReporter(
                sys.stdout,
                steps_per_block,
                time=True,
                potentialEnergy=True,
                temperature=True,
                density=True,
                speed=True,
                remainingTime=True,
                totalSteps=n_steps,
            )
        )
        simulation.reporters.append(app.DCDReporter(trajfile, steps_per_block))
        simulation.context.setPositions(pdb.getPositions())
        simulation.minimizeEnergy(maxIterations=200)
        simulation.step(n_steps)
    finally:
        # Always clean up the scratch topology, even when the run failed.
        if os.path.exists(scratch_top):
            os.remove(scratch_top)

# runMD("Lig.top", "init.pdb", "traj.dcd", 100)

In [None]:
def rerun_energy(pdb, traj, top, skip=20, removeLJ=True, skpi=0):
    """Re-evaluate OpenMM potential energies for the frames of a trajectory.

    The topology ``top`` is replicated 500 times through the scratch file
    ``Lig_500.top`` (removed again as soon as OpenMM has read it), a system
    is built with the same settings as ``runMD`` (PME, 1.0 nm cutoff, H-bond
    constraints, hydrogen mass 3 Da, no dispersion correction) and each frame
    is evaluated after constraints are applied with tolerance 1e-10.

    Parameters
    ----------
    pdb : str
        PDB file used both as the mdtraj topology for ``traj`` and as the
        source of the periodic box vectors for the OpenMM topology.
    traj : str
        Trajectory file loadable by ``mdtraj.load``.
    top : str
        Topology file readable by ``parmed.load_file``.
    skip : int, default 20
        Number of leading frames to discard.
    removeLJ : bool, default True
        When true, every particle and every exception of the NonbondedForce
        gets sigma = 1.0 and epsilon = 0.0 (charges are kept), so the
        returned energies contain no Lennard-Jones contribution from it.
    skpi : int, default 0
        Unused. Looks like a misspelling of ``skip``; retained only so that
        existing callers passing it keep working.

    Returns
    -------
    numpy.ndarray
        Potential energy of each retained frame in kJ/mol.
    """
    samples = md.load(traj, top=pdb)[skip:]
    scratch_top = "Lig_500.top"

    # Clear a stale scratch file; only a missing file is ignored so that real
    # problems (e.g. permissions) are no longer silently swallowed.
    try:
        os.remove(scratch_top)
    except FileNotFoundError:
        pass

    replicated = parmed.load_file(top) * 500
    pdb_file = app.PDBFile(pdb)
    try:
        replicated.save(scratch_top)
        gmx_top = app.GromacsTopFile(scratch_top)
    finally:
        # The scratch file is only needed while GromacsTopFile parses it.
        if os.path.exists(scratch_top):
            os.remove(scratch_top)

    gmx_top.topology.setPeriodicBoxVectors(pdb_file.topology.getPeriodicBoxVectors())
    system = gmx_top.createSystem(
        nonbondedMethod=app.PME,
        nonbondedCutoff=1.0*unit.nanometer,
        constraints=app.HBonds,
        hydrogenMass=3*unit.dalton,
    )

    for force in system.getForces():
        if not isinstance(force, mm.NonbondedForce):
            continue
        force.setUseDispersionCorrection(False)
        if removeLJ:
            # Zero out Lennard-Jones (epsilon = 0) while keeping the charges.
            for npart in range(force.getNumParticles()):
                chrg, _sig, _eps = force.getParticleParameters(npart)
                force.setParticleParameters(npart, chrg, 1.0, 0.0)
            for nex in range(force.getNumExceptions()):
                p1, p2, chrg, _sig, _eps = force.getExceptionParameters(nex)
                force.setExceptionParameters(nex, p1, p2, chrg, 1.0, 0.0)

    # The integrator is only required to construct a Context; no steps are run.
    integ = mm.LangevinIntegrator(300*unit.kelvin, 1/unit.picosecond, 2.5*unit.femtosecond)
    ctx = mm.Context(system, integ)
    energies = []
    for frame in tqdm(samples):
        ctx.setPositions(frame.xyz[0] * unit.nanometer)
        ctx.setPeriodicBoxVectors(*frame.unitcell_vectors[0])
        ctx.applyConstraints(1e-10)
        state = ctx.getState(getEnergy=True)
        energies.append(state.getPotentialEnergy().value_in_unit(unit.kilojoule_per_mole))
    return np.array(energies)

# rerun_energy("init.pdb", "traj.dcd", "Lig.top")

In [None]:
def rerun_dmff_lennard_jones(params, pdb, traj, efunc, skip=0):
    """Evaluate a DMFF energy function on every frame of a trajectory.

    For each retained frame the neighbor list is re-allocated from that
    frame's coordinates and box, and ``efunc`` is called as
    ``efunc(positions, box, pairs, params)``.

    The neighbor list is built with a cutoff of 1.0 and the module-level
    ``cov_mat``, so that global must be defined before this is called.

    Parameters
    ----------
    params
        Parameter object forwarded unchanged to ``efunc`` as its last
        argument.
    pdb : str
        Topology file passed to ``mdtraj.load`` for reading ``traj``.
    traj : str
        Trajectory file loadable by ``mdtraj.load``.
    efunc : callable
        Energy function with signature ``(positions, box, pairs, params)``.
    skip : int, default 0
        Number of leading frames to discard.

    Returns
    -------
    jax.numpy.ndarray
        One energy per retained frame, concatenated into a 1-D array.
    """
    samples = md.load(traj, top=pdb)[skip:]
    # Index the underlying arrays directly rather than slicing `samples`,
    # which would build a new Trajectory object for every frame.
    xyzs = samples.xyz
    cells = samples.unitcell_vectors

    nblist = NeighborListFreud(cells[0], 1.0, cov_mat)
    xyzs_jnp = jnp.array(xyzs)
    cell_jnp = jnp.array(cells)

    energies = []
    for nframe in trange(len(samples)):
        # Rebuild the pair list for this frame's coordinates and box.
        pairs = jnp.array(nblist.allocate(xyzs[nframe], cells[nframe]))
        ener = efunc(xyzs_jnp[nframe, :, :], cell_jnp[nframe, :, :], pairs, params)
        # reshape to (1,) so the per-frame results can be concatenated.
        energies.append(ener.reshape((1,)))
    return jnp.concatenate(energies)

In [None]:
# Adam optimizer with a step size of 0.001. Its state is initialized from the
# current paramset and carried across iterations of the training loop, which
# rebinds both `opt_state` and `paramset`.
import optax
optimizer = optax.adam(0.001)
opt_state = optimizer.init(paramset)

In [None]:
# Unused in this cell; kept so the module-level name remains defined.
lbfgs = None

# Leading frames discarded from every trajectory before analysis; the same
# value must be used for the reruns, the density and the DMFF energies so that
# all per-frame arrays line up.
N_EQUIL_FRAMES = 20
# Density the reweighted average is driven towards. md.density is scaled by
# 0.001 below (presumably kg/m^3 -> g/cm^3 -- confirm against mdtraj docs).
TARGET_DENSITY = 0.85

# Seed the loop: iteration n samples with the topology written by iteration
# n-1. shutil.copy is portable and raises if "Lig.top" is missing, whereas a
# failed shell `cp` would go unnoticed until the first load.
shutil.copy("Lig.top", "loop-0.top")
for nloop in range(1, 51):
    prev_top = f"loop-{nloop-1}.top"
    traj_file = f"loop-{nloop}.dcd"

    # sample
    print("SAMPLE")
    runMD(prev_top, "init.pdb", traj_file, length=70)
    print("RERUN")
    # Full energy and the same energy with Lennard-Jones switched off.
    ener = rerun_energy("init.pdb", traj_file, prev_top, removeLJ=False, skip=N_EQUIL_FRAMES)
    ener_no_lj = rerun_energy("init.pdb", traj_file, prev_top, skip=N_EQUIL_FRAMES)
    print("ESTIMATOR")
    traj = md.load(traj_file, top="init.pdb")[N_EQUIL_FRAMES:]
    estimator = ReweightEstimator(ener, base_energies=ener_no_lj, volume=traj.unitcell_volumes)

    print("CALC DENSE")
    density = md.density(traj) * 0.001

    # get loss & grad
    def loss(paramset):
        """Squared deviation of the reweighted mean density from the target."""
        lj_jax = rerun_dmff_lennard_jones(paramset, "init.pdb", traj_file, lj_force, skip=N_EQUIL_FRAMES)
        weight = estimator.estimate_weight(lj_jax)
        dens = weight * density
        dens = dens.mean()
        return jnp.power(dens - TARGET_DENSITY, 2)

    v_and_g = jax.value_and_grad(loss, 0)
    v, g = v_and_g(paramset)
    print("Loss:", v)
    # update parameters
    updates, opt_state = optimizer.update(g, opt_state)
    paramset = optax.apply_updates(paramset, updates)
    # Clamp every parameter to [0, 1e8]. jax.tree_util.tree_map replaces the
    # deprecated top-level jax.tree_map alias, which newer JAX no longer has.
    paramset = jax.tree_util.tree_map(lambda x: jnp.clip(x, 0.0, 1e8), paramset)

    # update ffinfo and write the topology used by the next iteration
    lj_gen.overwrite(paramset)
    prmop.overwriteLennardJones(prm_top, lj_gen.ffinfo)
    prm_top.save(f"loop-{nloop}.top")
    