In [314]:
%load_ext autoreload
%autoreload 2
import espaloma as esp
import torch
import numpy as np


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [315]:
from simtk import unit
GAS_CONSTANT = 8.31446261815324 * unit.joule / (unit.kelvin * unit.mole)
GAS_CONSTANT = GAS_CONSTANT.value_in_unit(
    esp.units.ENERGY_UNIT / (unit.kelvin)
)
kT = GAS_CONSTANT * 300

In [316]:
WINDOWS = 50

In [317]:
def leapfrog(xs, vs, closure, dt=1.0):
    x = xs[-1]
    v = vs[-1]

    x = x + v * dt

    energy_old = closure(x)

    a = -torch.autograd.grad(
        energy_old.sum(),
        [x],
        create_graph=True,
        retain_graph=True,
    )[0]

    v = v + a * dt

    x = x + v * dt

    vs.append(v)
    xs.append(x)

    return xs, vs

In [318]:
g = esp.Graph('CC')
g = esp.graphs.LegacyForceField('smirnoff99Frosst').parametrize(g)

In [319]:
layer = esp.nn.dgl_legacy.gn()

representation = esp.nn.Sequential(
    layer,
    [32, 'leaky_relu', 128, 'leaky_relu', 128, 'leaky_relu'],
)

readout = esp.nn.readout.janossy.JanossyPooling(
    in_features=128,
    config=[128, 'leaky_relu', 128, 'leaky_relu'],
    out_features={
        1: {'epsilons': WINDOWS, 'sigma': 1},
        2: {'ks': WINDOWS, 'eqs': WINDOWS},
        3: {'ks': WINDOWS, 'eqs': WINDOWS},
    }
)


net = torch.nn.Sequential(
    representation,
    readout,
)

realize = torch.nn.Sequential(
    esp.mm.geometry.GeometryInGraph(),
    esp.mm.energy.EnergyInGraph(suffix='_ref', terms=['n2', 'n3']),
)


In [320]:
def closure(x, idx, g=g):
    with g.heterograph.local_scope():
        g.nodes['n1'].data['xyz'] = x
        
        if idx != -1:

            g.nodes['n2'].data['eq_ref'] = g.nodes['n2'].data['eqs'][:, idx][:, None].exp()
            g.nodes['n2'].data['k_ref'] = g.nodes['n2'].data['ks'][:, idx][:, None].exp()

            g.nodes['n3'].data['eq_ref'] = g.nodes['n3'].data['eqs'][:, idx][:, None].exp()
            g.nodes['n3'].data['k_ref'] = g.nodes['n3'].data['ks'][:, idx][:, None].exp()
            
        realize(g.heterograph)
        return g.nodes['g'].data['u_ref']



In [321]:
def simulation(net, g=g):
    with g.heterograph.local_scope():
        net(g.heterograph)
        
        particle_distribution = torch.distributions.normal.Normal(
            loc=torch.zeros(g.heterograph.number_of_nodes('n1'), 128, 3),
            scale=g.nodes['n1'].data['sigma'][:, :, None].repeat(1, 128, 3).exp()
        )

        #normal_distribution = torch.distributions.normal.Normal(0, 1.0)
        
        x = torch.nn.Parameter(
            particle_distribution.rsample()
        )
        
        v = torch.zeros_like(x)

        xs = [x]
        vs = [v]
        

        for idx in range(1, WINDOWS):

            xs, vs = leapfrog(xs, vs, lambda x: closure(x, idx, g=g), 1e-2)
        
        return xs, vs, particle_distribution

In [322]:
optimizer = torch.optim.Adam(net.parameters(), 1e-3)
normal_distribution = torch.distributions.normal.Normal(0, 1.0)

for _ in range(1000):
    optimizer.zero_grad()
    
    xs, vs, particle_distribution = simulation(net)
    
    energy = closure(xs[-1], idx=-1).sum()
    
    log_p = -energy/kT + normal_distribution.log_prob(vs[-1]).sum()
    
    log_q = normal_distribution.log_prob(vs[0]).sum() + particle_distribution.log_prob(xs[0]).sum()
    
    loss = -log_p + log_q
    
    loss.backward()
    
    print(loss, energy)
    
    optimizer.step()

tensor(186593.4844, grad_fn=<AddBackward0>) tensor(180.9468, grad_fn=<SumBackward0>)
tensor(190234.2188, grad_fn=<AddBackward0>) tensor(184.3761, grad_fn=<SumBackward0>)
tensor(190979.5938, grad_fn=<AddBackward0>) tensor(185.0009, grad_fn=<SumBackward0>)
tensor(180450.9219, grad_fn=<AddBackward0>) tensor(175.0479, grad_fn=<SumBackward0>)
tensor(176517.3594, grad_fn=<AddBackward0>) tensor(171.3614, grad_fn=<SumBackward0>)
tensor(172670.6875, grad_fn=<AddBackward0>) tensor(167.6714, grad_fn=<SumBackward0>)
tensor(165825.2969, grad_fn=<AddBackward0>) tensor(161.1991, grad_fn=<SumBackward0>)
tensor(154841.3438, grad_fn=<AddBackward0>) tensor(150.7540, grad_fn=<SumBackward0>)
tensor(143151.3906, grad_fn=<AddBackward0>) tensor(139.5247, grad_fn=<SumBackward0>)
tensor(136697.2969, grad_fn=<AddBackward0>) tensor(132.9081, grad_fn=<SumBackward0>)
tensor(145672.5156, grad_fn=<AddBackward0>) tensor(141.0376, grad_fn=<SumBackward0>)
tensor(121934.6328, grad_fn=<AddBackward0>) tensor(118.9739, grad

tensor(37198.8672, grad_fn=<AddBackward0>) tensor(28.0597, grad_fn=<SumBackward0>)
tensor(37126.7188, grad_fn=<AddBackward0>) tensor(30.0662, grad_fn=<SumBackward0>)
tensor(37309.3516, grad_fn=<AddBackward0>) tensor(31.5883, grad_fn=<SumBackward0>)
tensor(38578.4414, grad_fn=<AddBackward0>) tensor(31.7920, grad_fn=<SumBackward0>)
tensor(32290.2930, grad_fn=<AddBackward0>) tensor(22.7956, grad_fn=<SumBackward0>)
tensor(35959.0625, grad_fn=<AddBackward0>) tensor(25.7351, grad_fn=<SumBackward0>)
tensor(33174.3516, grad_fn=<AddBackward0>) tensor(25.3682, grad_fn=<SumBackward0>)
tensor(34003.5117, grad_fn=<AddBackward0>) tensor(28.5207, grad_fn=<SumBackward0>)
tensor(32185.6250, grad_fn=<AddBackward0>) tensor(25.6084, grad_fn=<SumBackward0>)
tensor(34362.1250, grad_fn=<AddBackward0>) tensor(23.9991, grad_fn=<SumBackward0>)
tensor(33932.5312, grad_fn=<AddBackward0>) tensor(28.4069, grad_fn=<SumBackward0>)
tensor(37057.4883, grad_fn=<AddBackward0>) tensor(29.5942, grad_fn=<SumBackward0>)
tens

tensor(141667.8125, grad_fn=<AddBackward0>) tensor(137.6611, grad_fn=<SumBackward0>)
tensor(138499.4688, grad_fn=<AddBackward0>) tensor(134.6620, grad_fn=<SumBackward0>)
tensor(134494.0312, grad_fn=<AddBackward0>) tensor(130.7610, grad_fn=<SumBackward0>)
tensor(130440.9453, grad_fn=<AddBackward0>) tensor(126.7955, grad_fn=<SumBackward0>)
tensor(129470.6719, grad_fn=<AddBackward0>) tensor(125.7094, grad_fn=<SumBackward0>)
tensor(134293.9844, grad_fn=<AddBackward0>) tensor(129.9443, grad_fn=<SumBackward0>)
tensor(135104.7812, grad_fn=<AddBackward0>) tensor(130.5447, grad_fn=<SumBackward0>)
tensor(130716.0625, grad_fn=<AddBackward0>) tensor(126.4716, grad_fn=<SumBackward0>)
tensor(124129.9766, grad_fn=<AddBackward0>) tensor(120.5229, grad_fn=<SumBackward0>)
tensor(121960.1406, grad_fn=<AddBackward0>) tensor(118.6596, grad_fn=<SumBackward0>)
tensor(119371.6406, grad_fn=<AddBackward0>) tensor(116.2960, grad_fn=<SumBackward0>)
tensor(114497.6094, grad_fn=<AddBackward0>) tensor(111.7050, grad

tensor(160574.8125, grad_fn=<AddBackward0>) tensor(129.8873, grad_fn=<SumBackward0>)
tensor(136019.1094, grad_fn=<AddBackward0>) tensor(127.5361, grad_fn=<SumBackward0>)
tensor(152408.6719, grad_fn=<AddBackward0>) tensor(145.0010, grad_fn=<SumBackward0>)
tensor(167169.3125, grad_fn=<AddBackward0>) tensor(160.6127, grad_fn=<SumBackward0>)
tensor(174925.3750, grad_fn=<AddBackward0>) tensor(168.6722, grad_fn=<SumBackward0>)
tensor(171860.0625, grad_fn=<AddBackward0>) tensor(166.1376, grad_fn=<SumBackward0>)
tensor(181374.7344, grad_fn=<AddBackward0>) tensor(175.2779, grad_fn=<SumBackward0>)
tensor(178720.3750, grad_fn=<AddBackward0>) tensor(172.9112, grad_fn=<SumBackward0>)
tensor(177920.6875, grad_fn=<AddBackward0>) tensor(172.2164, grad_fn=<SumBackward0>)
tensor(180363.5625, grad_fn=<AddBackward0>) tensor(174.5706, grad_fn=<SumBackward0>)
tensor(175129.9375, grad_fn=<AddBackward0>) tensor(169.6760, grad_fn=<SumBackward0>)
tensor(176121.9844, grad_fn=<AddBackward0>) tensor(170.6753, grad

tensor(40002.7656, grad_fn=<AddBackward0>) tensor(33.5660, grad_fn=<SumBackward0>)
tensor(41938.0156, grad_fn=<AddBackward0>) tensor(34.9605, grad_fn=<SumBackward0>)
tensor(40191.6250, grad_fn=<AddBackward0>) tensor(33.6412, grad_fn=<SumBackward0>)
tensor(42200.5352, grad_fn=<AddBackward0>) tensor(34.2651, grad_fn=<SumBackward0>)
tensor(37802.3438, grad_fn=<AddBackward0>) tensor(29.5942, grad_fn=<SumBackward0>)
tensor(39244.1172, grad_fn=<AddBackward0>) tensor(31.1299, grad_fn=<SumBackward0>)
tensor(40416.7695, grad_fn=<AddBackward0>) tensor(33.0236, grad_fn=<SumBackward0>)
tensor(38784.3633, grad_fn=<AddBackward0>) tensor(31.4313, grad_fn=<SumBackward0>)
tensor(38119.2305, grad_fn=<AddBackward0>) tensor(31.1541, grad_fn=<SumBackward0>)
tensor(37730.0781, grad_fn=<AddBackward0>) tensor(30.0351, grad_fn=<SumBackward0>)
tensor(37683.0625, grad_fn=<AddBackward0>) tensor(29.8802, grad_fn=<SumBackward0>)
tensor(37562.0195, grad_fn=<AddBackward0>) tensor(30.4721, grad_fn=<SumBackward0>)
tens

tensor(26728.5977, grad_fn=<AddBackward0>) tensor(19.4724, grad_fn=<SumBackward0>)
tensor(25326.5098, grad_fn=<AddBackward0>) tensor(18.1674, grad_fn=<SumBackward0>)
tensor(27257.6406, grad_fn=<AddBackward0>) tensor(20.1024, grad_fn=<SumBackward0>)
tensor(26511.9414, grad_fn=<AddBackward0>) tensor(19.0778, grad_fn=<SumBackward0>)
tensor(25012.3594, grad_fn=<AddBackward0>) tensor(18.5737, grad_fn=<SumBackward0>)
tensor(26802.3242, grad_fn=<AddBackward0>) tensor(19.6609, grad_fn=<SumBackward0>)
tensor(25608.2734, grad_fn=<AddBackward0>) tensor(18.9204, grad_fn=<SumBackward0>)
tensor(25394.9297, grad_fn=<AddBackward0>) tensor(18.9342, grad_fn=<SumBackward0>)
tensor(28304.9434, grad_fn=<AddBackward0>) tensor(20.8273, grad_fn=<SumBackward0>)
tensor(26720.0996, grad_fn=<AddBackward0>) tensor(18.5188, grad_fn=<SumBackward0>)
tensor(24634.8301, grad_fn=<AddBackward0>) tensor(16.5656, grad_fn=<SumBackward0>)
tensor(27007.8398, grad_fn=<AddBackward0>) tensor(19.1467, grad_fn=<SumBackward0>)
tens

KeyboardInterrupt: 

In [323]:
xs, vs, particle_distribution = simulation(net)

In [326]:
import nglview as nv
from rdkit.Geometry import Point3D
from rdkit import Chem
from rdkit.Chem import AllChem

conf_idx = 1

mol = g.mol.to_rdkit()
AllChem.EmbedMolecule(mol)
conf = mol.GetConformer()

xs, vs, particle_distribution = simulation(net)
x = xs[-1]


for idx_atom in range(mol.GetNumAtoms()):
    conf.SetAtomPosition(
        idx_atom,
        Point3D(
            float(x[idx_atom, conf_idx, 0]),
            float(x[idx_atom, conf_idx, 1]),
            float(x[idx_atom, conf_idx, 2]),
        ))
    
nv.show_rdkit(mol)

NGLWidget()