In [1]:
import os
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"]=".95"
import openmm.app as app
import openmm as mm
import openmm.unit as unit
import numpy as np
import jax
import jax.numpy as jnp
import dmff
from dmff.api.xmlio import XMLIO
from dmff.api.paramset import ParamSet
from dmff.generators.classical import CoulombGenerator, LennardJonesGenerator
from dmff.api.hamiltonian import Hamiltonian
from dmff.operators import ParmedLennardJonesOperator
from dmff import NeighborListFreud
from dmff.mbar import ReweightEstimator
import mdtraj as md
from tqdm import tqdm, trange
import parmed
import sys




In [47]:


particle_number =  500           #液体pdb文件中的分子数
target_han= 50.52                  #拟合的目标蒸发焓
SET_temperature=  293.15           #温度设定
time_gap=   2                      #分子动力学模拟过程中每一个frame的时间间隔，单位是皮秒picosecond   推荐2-4ps
loop_time =   3                   #迭代循环次数    推荐50-100
save_step =   3                  #保存的frame数  推荐150-200
skip_step =    2                   #计算各种物理常数之前，达到平衡所需要的步数
length_step = (skip_step+save_step) #分子动力学模拟生成的dcd文件存储的frame个数，与time_gap共同决定总模拟的时长，总时长=length_step*time_gap



In [48]:

prm_top = parmed.load_file("GMX.top")
prm_top_particle_number = prm_top * particle_number
dmfftop = dmff.DMFFTopology(from_top=prm_top_particle_number.topology)
prmop = ParmedLennardJonesOperator()
dmfftop = prmop(dmfftop, gmx_top = prm_top_particle_number)

prmop.renderLennardJonesXML("init.xml")
xmlio = XMLIO()
xmlio.loadXML("init.xml")
ffinfo = xmlio.parseXML()
cov_mat = dmfftop.buildCovMat()

In [49]:
paramset = ParamSet()
lj_gen = LennardJonesGenerator(ffinfo, paramset)
lj_force = lj_gen.createPotential(
    dmfftop, nonbondedMethod=app.CutoffPeriodic, nonbondedCutoff=1.0, args={})



In [51]:
gap_step=int(time_gap*4)
def runMD(topfile, pdbfile, trajfile, length):
    try:
        os.remove("Lig_particle_number.top")
    except:
        pass
    top_prm = parmed.load_file(topfile)
    top_500 = top_prm * particle_number
    top_500.save("Lig_particle_number.top")
    pdb = app.PDBFile(pdbfile)
    top = app.GromacsTopFile("Lig_particle_number.top")
    top.topology.setPeriodicBoxVectors(pdb.topology.getPeriodicBoxVectors())
    system = top.createSystem(nonbondedMethod=app.PME, nonbondedCutoff=1.0*unit.nanometer, constraints=app.HBonds, hydrogenMass=3*unit.dalton)
    for force in system.getForces():
        if isinstance(force, mm.NonbondedForce):
            force.setUseDispersionCorrection(False)
    system.addForce(mm.MonteCarloBarostat(1.0*unit.bar, SET_temperature*unit.kelvin, 25))
    integ = mm.LangevinIntegrator(SET_temperature*unit.kelvin, 1/unit.picosecond, 2.5*unit.femtosecond)
    simulation = app.Simulation(top.topology, system, integ)
    simulation.reporters.append(app.StateDataReporter(sys.stdout, gap_step, time=True, potentialEnergy=True, temperature=True, density=True, speed=True, remainingTime=True, totalSteps=int(length) * gap_step))
    simulation.reporters.append(app.DCDReporter(trajfile, gap_step))
    simulation.context.setPositions(pdb.getPositions())
    simulation.minimizeEnergy(maxIterations=200)
    simulation.step(int(length) * gap_step)
    os.remove("Lig_particle_number.top")


In [52]:
def runMD_gas_NVT(topfile, pdbfile, trajfile, length):
    try:
        os.remove("GAS.top")
    except:
        pass
    top_prm = parmed.load_file(topfile)
    top_GAS = top_prm 
    top_GAS.save("GAS.top")
    pdb = app.PDBFile(pdbfile)
    top = app.GromacsTopFile("GAS.top")
    top.topology.setPeriodicBoxVectors(pdb.topology.getPeriodicBoxVectors())
    system = top.createSystem(nonbondedMethod=app.PME, nonbondedCutoff=1.0*unit.nanometer, constraints=app.HBonds, hydrogenMass=3*unit.dalton)
    for force in system.getForces():
        if isinstance(force, mm.NonbondedForce):
            force.setUseDispersionCorrection(False)
    integ = mm.LangevinIntegrator(SET_temperature*unit.kelvin, 1/unit.picosecond, 2.5*unit.femtosecond)
    simulation = app.Simulation(top.topology, system, integ)
    simulation.reporters.append(app.StateDataReporter(sys.stdout, gap_step, time=True, potentialEnergy=True, temperature=True, density=True, speed=True, remainingTime=True, totalSteps=int(length) * gap_step))
    simulation.reporters.append(app.DCDReporter(trajfile, gap_step))
    simulation.context.setPositions(pdb.getPositions())
    simulation.minimizeEnergy(maxIterations=200)
    simulation.step(int(length) * gap_step)
    os.remove("GAS.top")

In [53]:
def rerun_energy(pdb, traj, top, skip=20, removeLJ=True, skpi=0):
    samples = md.load(traj, top=pdb)[skip:]
    try:
        os.remove("Lig_particle_number.top")
    except:
        pass
    top_prm = parmed.load_file(top)
    top_500 = top_prm * particle_number
    top_500.save("Lig_particle_number.top")
    pdb = app.PDBFile(pdb)
    top = app.GromacsTopFile("Lig_particle_number.top")
    os.remove("Lig_particle_number.top")
    top.topology.setPeriodicBoxVectors(pdb.topology.getPeriodicBoxVectors())
    system = top.createSystem(nonbondedMethod=app.PME, nonbondedCutoff=1.0*unit.nanometer, constraints=app.HBonds, hydrogenMass=3*unit.dalton)
    
    for force in system.getForces():
        if isinstance(force, mm.NonbondedForce):
            force.setUseDispersionCorrection(False)
            if removeLJ:
                for npart in range(force.getNumParticles()):
                    chrg, sig, eps = force.getParticleParameters(npart)
                    force.setParticleParameters(npart, chrg, 1.0, 0.0)
                for nex in range(force.getNumExceptions()):
                    p1, p2, chrg, sig, eps = force.getExceptionParameters(nex)
                    force.setExceptionParameters(nex, p1, p2, chrg, 1.0, 0.0)
    integ = mm.LangevinIntegrator(SET_temperature*unit.kelvin, 1/unit.picosecond, 2.5*unit.femtosecond)
    ctx = mm.Context(system, integ)
    energies = []
    for frame in tqdm(samples):
        ctx.setPositions(frame.xyz[0] * unit.nanometer)
        ctx.setPeriodicBoxVectors(*frame.unitcell_vectors[0])
        ctx.applyConstraints(1e-10)
        state = ctx.getState(getEnergy=True)
        energy = state.getPotentialEnergy().value_in_unit(unit.kilojoule_per_mole)
        energies.append(energy)
    return np.array(energies)

In [54]:
def rerun_energy_gas(pdb, traj, top, skip=50, removeLJ=True, skpi=0):
    samples = md.load(traj, top=pdb)[skip:]
    try:
        os.remove("GAS.top")
    except:
        pass
    top_prm = parmed.load_file(top)
    top_gas = top_prm
    top_gas.save("GAS.top")
    pdb = app.PDBFile(pdb)
    top = app.GromacsTopFile("GAS.top")
    os.remove("GAS.top")
    top.topology.setPeriodicBoxVectors(pdb.topology.getPeriodicBoxVectors())
    system = top.createSystem(nonbondedMethod=app.PME, nonbondedCutoff=1.0*unit.nanometer, constraints=app.HBonds, hydrogenMass=3*unit.dalton)
    
    for force in system.getForces():
        if isinstance(force, mm.NonbondedForce):
            force.setUseDispersionCorrection(False)
            if removeLJ:
                for npart in range(force.getNumParticles()):
                    chrg, sig, eps = force.getParticleParameters(npart)
                    force.setParticleParameters(npart, chrg, 1.0, 0.0)
                for nex in range(force.getNumExceptions()):
                    p1, p2, chrg, sig, eps = force.getExceptionParameters(nex)
                    force.setExceptionParameters(nex, p1, p2, chrg, 1.0, 0.0)
    integ = mm.LangevinIntegrator(SET_temperature*unit.kelvin, 1/unit.picosecond, 2.5*unit.femtosecond)
    ctx = mm.Context(system, integ)
    energies = []
    for frame in tqdm(samples):
        ctx.setPositions(frame.xyz[0] * unit.nanometer)
        ctx.setPeriodicBoxVectors(*frame.unitcell_vectors[0])
        ctx.applyConstraints(1e-10)
        state = ctx.getState(getEnergy=True)
        energy = state.getPotentialEnergy().value_in_unit(unit.kilojoule_per_mole)
        energies.append(energy)
    return np.array(energies)

In [55]:
def rerun_dmff_lennard_jones(params, pdb, traj, efunc, skip=0):
    samples = md.load(traj, top=pdb)[skip:]
    energies = []
    nblist = NeighborListFreud(samples.unitcell_vectors[0], 1.0, cov_mat)
    xyzs_jnp = jnp.array(samples.xyz)
    cell_jnp = jnp.array(samples.unitcell_vectors)
    energies = []
    nblist = NeighborListFreud(samples.unitcell_vectors[0], 1.0, cov_mat)
    xyzs_jnp = jnp.array(samples.xyz)
    cell_jnp = jnp.array(samples.unitcell_vectors)
    energies = []
    for nframe in trange(len(samples)):
        frame = samples[nframe]
        # calc pair
        pairs = jnp.array(nblist.allocate(frame.xyz[0], frame.unitcell_vectors[0]))
        ener = efunc(xyzs_jnp[nframe,:,:], cell_jnp[nframe,:,:], pairs, params)
        energies.append(ener.reshape((1,)))
    energies = jnp.concatenate(energies)
    return energies

In [21]:
import optax
optimizer = optax.adam(0.001)
opt_state = optimizer.init(paramset)

In [22]:
print(paramset.to_jax())

None


In [23]:
lbfgs = None


os.system("cp GMX.top loop-0.top")
Losslist=[]
enthalpy_of_evaporation_list=[]
energy_gas_list=[]
energy_liquid_list=[]
for nloop in range(1, loop_time+1):
    # sample liquid
    print("SAMPLE_liquid")
    try:
        runMD(f"loop-{nloop-1}.top", "liquid.pdb", f"loop-{nloop}.dcd", length=length_step)
    except:
        runMD(f"loop-{nloop-1}.top", "liquid.pdb", f"loop-{nloop}.dcd", length=length_step)
        
    # sample GAs
    print("SAMPLE_GAS")

    try:
        runMD_gas_NVT(f"loop-{nloop-1}.top", "GAS.pdb", f"loop-{nloop}-gas.dcd", length=skip_step+1)
    except:
        runMD_gas_NVT(f"loop-{nloop-1}.top", "GAS.pdb", f"loop-{nloop}-gas.dcd", length=skip_step+1)
        
    #重新计算能量   不可微的计算液态能量 、液态不包含LJ势能的能量 、气态能量
    print("RERUN")
    ener = rerun_energy("liquid.pdb", f"loop-{nloop}.dcd", f"loop-{nloop-1}.top", removeLJ=False, skip=skip_step)
    energy_liquid_list.append(ener.mean())
    ener_gas = rerun_energy_gas("GAS.pdb", f"loop-{nloop}-gas.dcd", f"loop-{nloop-1}.top", removeLJ=False, skip=skip_step).mean()
    energy_gas_list.append(ener_gas)
    ener_no_lj = rerun_energy("liquid.pdb", f"loop-{nloop}.dcd", f"loop-{nloop-1}.top", skip=skip_step)
    
    print("ESTIMATOR")
    traj = md.load(f"loop-{nloop}.dcd", top="liquid.pdb")[skip_step:]
    estimator = ReweightEstimator(ener, base_energies=ener_no_lj, volume=traj.unitcell_volumes)
    
    #计算当前蒸发焓
    print("CALC han")
    ener_lquid_now = (ener / particle_number).mean()
    enthalpy_of_evaporation_now = ener_gas - ener_lquid_now + 8.314 * SET_temperature * 0.001
    enthalpy_of_evaporation_list.append(enthalpy_of_evaporation_now)
    with open('enthalpy_of_evaporation_list.txt', 'a') as f:
        f.write("%s\n" % str(enthalpy_of_evaporation_now))
        
    # get loss & grad
    def loss(paramset):
        lj_jax = rerun_dmff_lennard_jones(paramset, "liquid.pdb", f"loop-{nloop}.dcd", lj_force, skip=skip_step)
        weight = estimator.estimate_weight(lj_jax)
        ener_lquid = (weight * ener) / particle_number 
        ener_lquid = ener_lquid.mean()
        enthalpy_of_evaporation = ener_gas - ener_lquid + 8.314*SET_temperature*0.001
        return  jnp.power(enthalpy_of_evaporation - target_han, 2)
    
    v_and_g = jax.value_and_grad(loss, 0)
    v, g = v_and_g(paramset)
    print("Loss:", v)
    Losslist.append(v)
    Losslist_np =jax.device_get(v)
    
    with open('Losslist.txt', 'a') as f:
        f.write("%s\n" % str(Losslist_np))
        
    # update parameters
    updates, opt_state = optimizer.update(g, opt_state)
    paramset = optax.apply_updates(paramset, updates)
    paramset = jax.tree_map(lambda x: jnp.clip(x, 0.0, 1e8), paramset)
    
    # upate ffinfo
    lj_gen.overwrite(paramset)
    prmop.overwriteLennardJones(prm_top, lj_gen.ffinfo)
    prm_top.save(f"loop-{nloop}.top")
    break



SAMPLE_liquid
#"Time (ps)","Potential Energy (kJ/mole)","Temperature (K)","Density (g/mL)","Speed (ns/day)","Time Remaining"
1.9999999999999685,25559.93282267052,261.6260643891745,0.587855285704582,0,--
3.999999999999926,27195.271742220953,291.08799390638137,0.636128354599301,24.6,0:21
6.000000000000238,26198.66408430355,297.4209345472045,0.7114427550023118,24.5,0:14
8.00000000000055,24379.765959753087,297.92881545950445,0.7950640339725844,24.2,0:07
10.000000000000153,23553.00651716363,299.56995957207,0.8611451269801752,24,0:00
SAMPLE_GAS
#"Time (ps)","Potential Energy (kJ/mole)","Temperature (K)","Density (g/mL)","Speed (ns/day)","Time Remaining"
1.9999999999999685,89.27207064711897,248.68547287686042,0.0001793382192547754,0,--
3.999999999999926,84.5346151109794,397.8411736285588,0.0001793382192547754,38.5,0:04
6.000000000000238,104.47298539977615,441.51972991557955,0.0001793382192547754,38.5,0:00
RERUN


100%|██████████| 3/3 [00:00<00:00, 10.28it/s]
100%|██████████| 1/1 [00:00<00:00, 87.70it/s]
100%|██████████| 3/3 [00:00<00:00, 32.23it/s]


ESTIMATOR
CALC han


100%|██████████| 3/3 [00:03<00:00,  1.15s/it]


Loss: 48.56942415389927


OSError: loop-1.top exists; not overwriting

In [56]:
os.system("cp GMX.top loop-0.top")
Losslist=[]
enthalpy_of_evaporation_list=[]
energy_gas_list=[]
energy_liquid_list=[]
nloop=1
# sample liquid
print("SAMPLE_liquid")
try:
    runMD(f"loop-{nloop-1}.top", "liquid.pdb", f"loop-{nloop}.dcd", length=length_step)
except:
    runMD(f"loop-{nloop-1}.top", "liquid.pdb", f"loop-{nloop}.dcd", length=length_step)
    
# sample GAs
print("SAMPLE_GAS")

try:
    runMD_gas_NVT(f"loop-{nloop-1}.top", "GAS.pdb", f"loop-{nloop}-gas.dcd", length=skip_step+1)
except:
    runMD_gas_NVT(f"loop-{nloop-1}.top", "GAS.pdb", f"loop-{nloop}-gas.dcd", length=skip_step+1)
    
#重新计算能量   不可微的计算液态能量 、液态不包含LJ势能的能量 、气态能量
print("RERUN")
ener = rerun_energy("liquid.pdb", f"loop-{nloop}.dcd", f"loop-{nloop-1}.top", removeLJ=False, skip=skip_step)
energy_liquid_list.append(ener.mean())
ener_gas = rerun_energy_gas("GAS.pdb", f"loop-{nloop}-gas.dcd", f"loop-{nloop-1}.top", removeLJ=False, skip=skip_step).mean()
energy_gas_list.append(ener_gas)
ener_no_lj = rerun_energy("liquid.pdb", f"loop-{nloop}.dcd", f"loop-{nloop-1}.top", skip=skip_step)

print("ESTIMATOR")
traj = md.load(f"loop-{nloop}.dcd", top="liquid.pdb")[skip_step:]
estimator = ReweightEstimator(ener, base_energies=ener_no_lj, volume=traj.unitcell_volumes)

#计算当前蒸发焓
print("CALC han")
ener_lquid_now = (ener / particle_number).mean()
enthalpy_of_evaporation_now = ener_gas - ener_lquid_now + 8.314 * SET_temperature * 0.001
enthalpy_of_evaporation_list.append(enthalpy_of_evaporation_now)
with open('enthalpy_of_evaporation_list.txt', 'a') as f:
    f.write("%s\n" % str(enthalpy_of_evaporation_now))


SAMPLE_liquid
#"Time (ps)","Potential Energy (kJ/mole)","Temperature (K)","Density (g/mL)","Speed (ns/day)","Time Remaining"
0.02,8173.932404217377,6.703857642670677,0.5219787647741313,0,--
0.04,8621.532058682706,13.061114815352228,0.5219787647741313,15.6,0:00
0.06000000000000002,9020.91340330848,19.508548603822213,0.5219787647741313,16,0:00
0.08000000000000003,9404.903082206589,25.197891467701712,0.5236844941653742,15.4,0:00
0.10000000000000005,9767.991039853769,31.02650253047768,0.5236844941653742,15.6,0:00
SAMPLE_GAS
#"Time (ps)","Potential Energy (kJ/mole)","Temperature (K)","Density (g/mL)","Speed (ns/day)","Time Remaining"
0.02,51.96034450469789,4.588350418875178,0.0001793382192547754,0,--
0.04,52.61792733776507,13.034562597753565,0.0001793382192547754,30.2,0:00
0.06000000000000002,53.57526517610673,17.251957124479418,0.0001793382192547754,28.9,0:00
RERUN


100%|██████████| 3/3 [00:00<00:00, 45.54it/s]
100%|██████████| 1/1 [00:00<00:00, 96.25it/s]
100%|██████████| 3/3 [00:00<00:00, 29.49it/s]


ESTIMATOR
CALC han


In [61]:
a = rerun_energy("liquid.pdb", f"loop-{nloop}.dcd", f"loop-{nloop-1}.top", removeLJ=False, skip=skip_step)

100%|██████████| 3/3 [00:00<00:00, 42.52it/s]


In [62]:
b = rerun_energy("liquid.pdb", f"loop-{nloop}.dcd", f"loop-{nloop-1}.top", removeLJ=True, skip=skip_step)

100%|██████████| 3/3 [00:00<00:00,  7.12it/s]


In [57]:
print(enthalpy_of_evaporation_now)

37.21659175121799


In [68]:
def rerun_dmff_lennard_jones(params, pdb, traj, efunc, skip=0):
    samples = md.load(traj, top=pdb)[skip:]
    energies = []
    nblist = NeighborListFreud(samples.unitcell_vectors[0], 1.0, cov_mat)
    xyzs_jnp = jnp.array(samples.xyz)
    cell_jnp = jnp.array(samples.unitcell_vectors)
    energies = []
    nblist = NeighborListFreud(samples.unitcell_vectors[0], 1.0, cov_mat)
    xyzs_jnp = jnp.array(samples.xyz)
    cell_jnp = jnp.array(samples.unitcell_vectors)
    energies = []
    for nframe in trange(len(samples)):
        frame = samples[nframe]
        # calc pair
        pairs = jnp.array(nblist.allocate(frame.xyz[0], frame.unitcell_vectors[0]))
        ener = efunc(xyzs_jnp[nframe,:,:], cell_jnp[nframe,:,:], pairs, params)
        energies.append(ener.reshape((1,)))
    energies = jnp.concatenate(energies)
    print("this is nframe", nframe)
    print(frame.xyz[0])
    print()
    return energies
c = rerun_dmff_lennard_jones(paramset, "liquid.pdb", f"loop-{nloop}.dcd", lj_force, skip=skip_step)

100%|██████████| 3/3 [00:01<00:00,  2.17it/s]

this is nframe 2
[[3.4191964 3.3035588 4.1944   ]
 [3.4458673 3.3379915 4.0948067]
 [3.3057523 3.2243938 4.2152944]
 ...
 [5.4177365 3.5355847 2.864259 ]
 [5.326687  3.6647923 2.7327852]
 [5.2525873 3.6593964 2.6698024]]






In [65]:
a-b-c

Array([ 0.00010779, -0.00236301, -0.0025776 ], dtype=float64)

In [None]:
    
# get loss & grad
def loss(paramset):
    lj_jax = rerun_dmff_lennard_jones(paramset, "liquid.pdb", f"loop-{nloop}.dcd", lj_force, skip=skip_step)
    weight = estimator.estimate_weight(lj_jax)
    ener_lquid = (weight * ener) / particle_number 
    ener_lquid = ener_lquid.mean()
    enthalpy_of_evaporation = ener_gas - ener_lquid + 8.314*SET_temperature*0.001
    return  jnp.power(enthalpy_of_evaporation - target_han, 2)

v_and_g = jax.value_and_grad(loss, 0)
v, g = v_and_g(paramset)
print("Loss:", v)
Losslist.append(v)
Losslist_np =jax.device_get(v)

with open('Losslist.txt', 'a') as f:
    f.write("%s\n" % str(Losslist_np))
    
# update parameters
updates, opt_state = optimizer.update(g, opt_state)
paramset = optax.apply_updates(paramset, updates)
paramset = jax.tree_map(lambda x: jnp.clip(x, 0.0, 1e8), paramset)

# upate ffinfo
lj_gen.overwrite(paramset)
prmop.overwriteLennardJones(prm_top, lj_gen.ffinfo)
prm_top.save(f"loop-{nloop}.top")
break


In [31]:
key = lj_gen.name
paramset[key]

{'epsilon': Array([0.358824 , 0.06176  , 0.45673  , 0.0646888, 0.879314 , 0.001    ],      dtype=float64),
 'epsilon_nbfix': Array([], dtype=float64),
 'sigma': Array([0.338967, 0.260964, 0.338967, 0.246135, 0.305647, 0.      ],      dtype=float64),
 'sigma_nbfix': Array([], dtype=float64)}

In [26]:
print(paramset.to_jax())

None


In [27]:
v, g = v_and_g(paramset)

100%|██████████| 3/3 [00:02<00:00,  1.44it/s]


In [29]:
g

<dmff.api.paramset.ParamSet at 0x7fe8c6cac850>

In [34]:
samples = md.load("loop-11.dcd",top="init.pdb")

In [35]:
samples[0]

<mdtraj.Trajectory with 1 frames, 8000 atoms, 500 residues, and unitcells at 0x7fe8c78133d0>

In [37]:
samples[0].unitcell_vectors

array([[[4.415939, 0.      , 0.      ],
        [0.      , 4.415939, 0.      ],
        [0.      , 0.      , 4.415939]]], dtype=float32)

In [38]:
samples[0].xyz

array([[[3.7995355, 2.9408348, 1.4133935],
        [3.758309 , 2.847964 , 1.4520082],
        [3.8661735, 2.9332044, 1.2929636],
        ...,
        [4.6403117, 2.387629 , 2.1475122],
        [4.5987806, 2.212958 , 2.2514014],
        [4.5509157, 2.174051 , 2.1760225]]], dtype=float32)