In [44]:
%load_ext autoreload
%autoreload 2
import espaloma as esp
import torch
import numpy as np


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [45]:
from simtk import unit
GAS_CONSTANT = 8.31446261815324 * unit.joule / (unit.kelvin * unit.mole)
GAS_CONSTANT = GAS_CONSTANT.value_in_unit(
    esp.units.ENERGY_UNIT / (unit.kelvin)
)
kT = GAS_CONSTANT * 300

In [46]:
WINDOWS = 50

In [47]:
def leapfrog(xs, vs, closure, dt=1.0):
    x = xs[-1]
    v = vs[-1]

    x = x + v * dt

    energy_old = closure(x)

    a = -torch.autograd.grad(
        energy_old.sum(),
        [x],
        create_graph=True,
        retain_graph=True,
    )[0]

    v = v + a * dt

    x = x + v * dt

    vs.append(v)
    xs.append(x)

    return xs, vs

In [48]:
gs = esp.data.dataset.GraphDataset([esp.Graph('C' * idx) for idx in range(1, 3)])
gs.apply(
    esp.graphs.LegacyForceField('smirnoff99Frosst').parametrize,
    in_place=True,
)
ds = gs.view(batch_size=len(gs))

In [49]:
layer = esp.nn.dgl_legacy.gn()

representation = esp.nn.Sequential(
    layer,
    [32, 'leaky_relu', 128, 'leaky_relu', 128, 'leaky_relu'],
)

readout = esp.nn.readout.janossy.JanossyPooling(
    in_features=128,
    config=[128, 'leaky_relu', 128, 'leaky_relu'],
    out_features={
        1: {'epsilons': WINDOWS, 'sigma': 1},
        2: {'ks': WINDOWS, 'eqs': WINDOWS},
        3: {'ks': WINDOWS, 'eqs': WINDOWS},
    }
)


net = torch.nn.Sequential(
    representation,
    readout,
)

realize = torch.nn.Sequential(
    esp.mm.geometry.GeometryInGraph(),
    esp.mm.energy.EnergyInGraph(suffix='_ref', terms=['n2', 'n3']),
)


In [50]:
def closure(x, idx, g):
    with g.local_scope():
        g.nodes['n1'].data['xyz'] = x
        
        if idx != -1:

            g.nodes['n2'].data['eq_ref'] = g.nodes['n2'].data['eqs'][:, idx][:, None]
            g.nodes['n2'].data['k_ref'] = g.nodes['n2'].data['ks'][:, idx][:, None]

            g.nodes['n3'].data['eq_ref'] = g.nodes['n3'].data['eqs'][:, idx][:, None]
            g.nodes['n3'].data['k_ref'] = g.nodes['n3'].data['ks'][:, idx][:, None]
            
        realize(g)
        return g.nodes['g'].data['u_ref']



In [51]:
def simulation(net, g):
    with g.local_scope():
        net(g)
        
        particle_distribution = torch.distributions.normal.Normal(
            loc=torch.zeros(g.number_of_nodes('n1'), 128, 3),
            scale=g.nodes['n1'].data['sigma'][:, :, None].repeat(1, 128, 3).exp()
        )

        #normal_distribution = torch.distributions.normal.Normal(0, 1.0)
        
        x = torch.nn.Parameter(
            particle_distribution.rsample()
        )
        
        v = torch.zeros_like(x)

        xs = [x]
        vs = [v]
        

        for idx in range(1, WINDOWS):

            xs, vs = leapfrog(xs, vs, lambda x: closure(x, idx, g=g), 1e-2)
        
        return xs, vs, particle_distribution

In [53]:
optimizer = torch.optim.Adam(net.parameters(), 1e-3)
normal_distribution = torch.distributions.normal.Normal(0, 1.0)

for _ in range(10000):
    for g in ds:
        optimizer.zero_grad()

        xs, vs, particle_distribution = simulation(net, g)

        energy = closure(xs[-1], idx=-1, g=g).sum()

        log_p = -energy/kT + normal_distribution.log_prob(vs[-1]).sum()

        log_q = normal_distribution.log_prob(vs[0]).sum() + particle_distribution.log_prob(xs[0]).sum()

        loss = -log_p + log_q

        loss.backward()

        print(loss, energy)

        optimizer.step()

tensor(329006.0312, grad_fn=<AddBackward0>) tensor(318.9930, grad_fn=<SumBackward0>)
tensor(332908.4688, grad_fn=<AddBackward0>) tensor(322.5504, grad_fn=<SumBackward0>)
tensor(324993.3125, grad_fn=<AddBackward0>) tensor(314.7919, grad_fn=<SumBackward0>)
tensor(324919.4062, grad_fn=<AddBackward0>) tensor(314.3269, grad_fn=<SumBackward0>)
tensor(319555.5938, grad_fn=<AddBackward0>) tensor(309.1364, grad_fn=<SumBackward0>)
tensor(312144.4375, grad_fn=<AddBackward0>) tensor(302.0568, grad_fn=<SumBackward0>)
tensor(303424.1562, grad_fn=<AddBackward0>) tensor(293.7985, grad_fn=<SumBackward0>)
tensor(279711.7812, grad_fn=<AddBackward0>) tensor(271.2476, grad_fn=<SumBackward0>)
tensor(261853.5469, grad_fn=<AddBackward0>) tensor(254.0311, grad_fn=<SumBackward0>)
tensor(242486.4844, grad_fn=<AddBackward0>) tensor(235.2502, grad_fn=<SumBackward0>)
tensor(258003.1250, grad_fn=<AddBackward0>) tensor(249.7893, grad_fn=<SumBackward0>)
tensor(255138.7969, grad_fn=<AddBackward0>) tensor(247.1566, grad

tensor(72434.0938, grad_fn=<AddBackward0>) tensor(70.9084, grad_fn=<SumBackward0>)
tensor(73990.2344, grad_fn=<AddBackward0>) tensor(71.9548, grad_fn=<SumBackward0>)
tensor(69212.5312, grad_fn=<AddBackward0>) tensor(68.3079, grad_fn=<SumBackward0>)
tensor(69807.4844, grad_fn=<AddBackward0>) tensor(69.4598, grad_fn=<SumBackward0>)
tensor(74262.8516, grad_fn=<AddBackward0>) tensor(74.1278, grad_fn=<SumBackward0>)
tensor(70397.7266, grad_fn=<AddBackward0>) tensor(70.1263, grad_fn=<SumBackward0>)
tensor(75042.4531, grad_fn=<AddBackward0>) tensor(73.7766, grad_fn=<SumBackward0>)
tensor(75358.0078, grad_fn=<AddBackward0>) tensor(73.7975, grad_fn=<SumBackward0>)
tensor(75858.1719, grad_fn=<AddBackward0>) tensor(74.4207, grad_fn=<SumBackward0>)
tensor(73377.2422, grad_fn=<AddBackward0>) tensor(71.7046, grad_fn=<SumBackward0>)
tensor(72292.1719, grad_fn=<AddBackward0>) tensor(70.3769, grad_fn=<SumBackward0>)
tensor(67200.1875, grad_fn=<AddBackward0>) tensor(66.4130, grad_fn=<SumBackward0>)
tens

tensor(60813.8086, grad_fn=<AddBackward0>) tensor(57.4113, grad_fn=<SumBackward0>)
tensor(55540.4648, grad_fn=<AddBackward0>) tensor(52.4310, grad_fn=<SumBackward0>)
tensor(57009.7188, grad_fn=<AddBackward0>) tensor(53.6724, grad_fn=<SumBackward0>)
tensor(59862.6016, grad_fn=<AddBackward0>) tensor(56.8509, grad_fn=<SumBackward0>)
tensor(60770.0234, grad_fn=<AddBackward0>) tensor(57.5375, grad_fn=<SumBackward0>)
tensor(57909.4844, grad_fn=<AddBackward0>) tensor(54.8872, grad_fn=<SumBackward0>)
tensor(55824.7734, grad_fn=<AddBackward0>) tensor(52.3933, grad_fn=<SumBackward0>)
tensor(56051.3945, grad_fn=<AddBackward0>) tensor(54.8254, grad_fn=<SumBackward0>)
tensor(61011.5156, grad_fn=<AddBackward0>) tensor(59.3455, grad_fn=<SumBackward0>)
tensor(54907.0742, grad_fn=<AddBackward0>) tensor(52.9731, grad_fn=<SumBackward0>)
tensor(53961.7031, grad_fn=<AddBackward0>) tensor(51.1415, grad_fn=<SumBackward0>)
tensor(61204.8750, grad_fn=<AddBackward0>) tensor(56.3073, grad_fn=<SumBackward0>)
tens

tensor(52488.6172, grad_fn=<AddBackward0>) tensor(48.3294, grad_fn=<SumBackward0>)
tensor(51048.6367, grad_fn=<AddBackward0>) tensor(46.7572, grad_fn=<SumBackward0>)
tensor(52779.4961, grad_fn=<AddBackward0>) tensor(48.5236, grad_fn=<SumBackward0>)
tensor(51054.3516, grad_fn=<AddBackward0>) tensor(47.1998, grad_fn=<SumBackward0>)
tensor(50660.2539, grad_fn=<AddBackward0>) tensor(46.7002, grad_fn=<SumBackward0>)
tensor(52065.1094, grad_fn=<AddBackward0>) tensor(47.7271, grad_fn=<SumBackward0>)
tensor(51346.1797, grad_fn=<AddBackward0>) tensor(47.1185, grad_fn=<SumBackward0>)
tensor(53244.4844, grad_fn=<AddBackward0>) tensor(47.9335, grad_fn=<SumBackward0>)
tensor(48985.2031, grad_fn=<AddBackward0>) tensor(45.6042, grad_fn=<SumBackward0>)
tensor(53135.3828, grad_fn=<AddBackward0>) tensor(50.0540, grad_fn=<SumBackward0>)
tensor(52543.6328, grad_fn=<AddBackward0>) tensor(49.7406, grad_fn=<SumBackward0>)
tensor(56995.9766, grad_fn=<AddBackward0>) tensor(54.5832, grad_fn=<SumBackward0>)
tens

tensor(247162.7500, grad_fn=<AddBackward0>) tensor(229.4249, grad_fn=<SumBackward0>)
tensor(255819.7656, grad_fn=<AddBackward0>) tensor(239.3114, grad_fn=<SumBackward0>)
tensor(256227.7344, grad_fn=<AddBackward0>) tensor(239.6719, grad_fn=<SumBackward0>)
tensor(235466.0938, grad_fn=<AddBackward0>) tensor(219.9346, grad_fn=<SumBackward0>)
tensor(234038.7656, grad_fn=<AddBackward0>) tensor(218.0672, grad_fn=<SumBackward0>)
tensor(227504.1094, grad_fn=<AddBackward0>) tensor(211.5413, grad_fn=<SumBackward0>)
tensor(225547.2812, grad_fn=<AddBackward0>) tensor(209.5779, grad_fn=<SumBackward0>)
tensor(231001.9375, grad_fn=<AddBackward0>) tensor(214.1197, grad_fn=<SumBackward0>)
tensor(223046.0781, grad_fn=<AddBackward0>) tensor(205.8928, grad_fn=<SumBackward0>)
tensor(218681.3281, grad_fn=<AddBackward0>) tensor(201.2074, grad_fn=<SumBackward0>)
tensor(257035.7656, grad_fn=<AddBackward0>) tensor(236.0344, grad_fn=<SumBackward0>)
tensor(344580.1562, grad_fn=<AddBackward0>) tensor(315.2585, grad

tensor(192047.6562, grad_fn=<AddBackward0>) tensor(175.3944, grad_fn=<SumBackward0>)
tensor(238383.0938, grad_fn=<AddBackward0>) tensor(217.4411, grad_fn=<SumBackward0>)
tensor(261698.9062, grad_fn=<AddBackward0>) tensor(236.2791, grad_fn=<SumBackward0>)
tensor(256758.8750, grad_fn=<AddBackward0>) tensor(232.4551, grad_fn=<SumBackward0>)
tensor(232668.1094, grad_fn=<AddBackward0>) tensor(212.2494, grad_fn=<SumBackward0>)
tensor(232842.5156, grad_fn=<AddBackward0>) tensor(215.9680, grad_fn=<SumBackward0>)
tensor(229502.2344, grad_fn=<AddBackward0>) tensor(214.9768, grad_fn=<SumBackward0>)
tensor(229862.4219, grad_fn=<AddBackward0>) tensor(215.9447, grad_fn=<SumBackward0>)
tensor(229709.7812, grad_fn=<AddBackward0>) tensor(216.7875, grad_fn=<SumBackward0>)
tensor(239430.2656, grad_fn=<AddBackward0>) tensor(226.2998, grad_fn=<SumBackward0>)
tensor(243717.7188, grad_fn=<AddBackward0>) tensor(230.7130, grad_fn=<SumBackward0>)
tensor(246006.4688, grad_fn=<AddBackward0>) tensor(232.4584, grad

tensor(244433.5469, grad_fn=<AddBackward0>) tensor(217.4123, grad_fn=<SumBackward0>)
tensor(295903., grad_fn=<AddBackward0>) tensor(260.9845, grad_fn=<SumBackward0>)
tensor(322557.3125, grad_fn=<AddBackward0>) tensor(279.4186, grad_fn=<SumBackward0>)
tensor(318768.3125, grad_fn=<AddBackward0>) tensor(273.9192, grad_fn=<SumBackward0>)
tensor(323342.0625, grad_fn=<AddBackward0>) tensor(279.8835, grad_fn=<SumBackward0>)
tensor(323792.4062, grad_fn=<AddBackward0>) tensor(282.9439, grad_fn=<SumBackward0>)
tensor(424326.0312, grad_fn=<AddBackward0>) tensor(370.4921, grad_fn=<SumBackward0>)
tensor(315894.1875, grad_fn=<AddBackward0>) tensor(270.8043, grad_fn=<SumBackward0>)
tensor(315625.5000, grad_fn=<AddBackward0>) tensor(267.1221, grad_fn=<SumBackward0>)
tensor(324811.6562, grad_fn=<AddBackward0>) tensor(273.4082, grad_fn=<SumBackward0>)
tensor(342678.3750, grad_fn=<AddBackward0>) tensor(286.2694, grad_fn=<SumBackward0>)
tensor(381676.8125, grad_fn=<AddBackward0>) tensor(320.6440, grad_fn=

tensor(227496.9219, grad_fn=<AddBackward0>) tensor(206.0938, grad_fn=<SumBackward0>)
tensor(219836., grad_fn=<AddBackward0>) tensor(200.1925, grad_fn=<SumBackward0>)
tensor(229487.1562, grad_fn=<AddBackward0>) tensor(207.9289, grad_fn=<SumBackward0>)
tensor(226085.4062, grad_fn=<AddBackward0>) tensor(205.8613, grad_fn=<SumBackward0>)
tensor(224153.8281, grad_fn=<AddBackward0>) tensor(204.9453, grad_fn=<SumBackward0>)
tensor(209722.5312, grad_fn=<AddBackward0>) tensor(192.5626, grad_fn=<SumBackward0>)
tensor(209920.8125, grad_fn=<AddBackward0>) tensor(193.3245, grad_fn=<SumBackward0>)
tensor(215137.8125, grad_fn=<AddBackward0>) tensor(199.7410, grad_fn=<SumBackward0>)
tensor(211009.9531, grad_fn=<AddBackward0>) tensor(195.5526, grad_fn=<SumBackward0>)
tensor(209081.6094, grad_fn=<AddBackward0>) tensor(193.2923, grad_fn=<SumBackward0>)
tensor(237530.3750, grad_fn=<AddBackward0>) tensor(217.7864, grad_fn=<SumBackward0>)
tensor(318662.6875, grad_fn=<AddBackward0>) tensor(291.2039, grad_fn=

tensor(219995.6875, grad_fn=<AddBackward0>) tensor(206.3724, grad_fn=<SumBackward0>)
tensor(217048.2812, grad_fn=<AddBackward0>) tensor(203.7095, grad_fn=<SumBackward0>)
tensor(225486.4531, grad_fn=<AddBackward0>) tensor(211.8668, grad_fn=<SumBackward0>)
tensor(214821.5938, grad_fn=<AddBackward0>) tensor(201.5731, grad_fn=<SumBackward0>)
tensor(221567.0938, grad_fn=<AddBackward0>) tensor(207.5082, grad_fn=<SumBackward0>)
tensor(219500.6250, grad_fn=<AddBackward0>) tensor(206.1804, grad_fn=<SumBackward0>)
tensor(213983.8438, grad_fn=<AddBackward0>) tensor(201.5202, grad_fn=<SumBackward0>)
tensor(215480.3281, grad_fn=<AddBackward0>) tensor(202.4161, grad_fn=<SumBackward0>)
tensor(215035.9375, grad_fn=<AddBackward0>) tensor(201.6011, grad_fn=<SumBackward0>)
tensor(213065.3750, grad_fn=<AddBackward0>) tensor(200.2352, grad_fn=<SumBackward0>)
tensor(201216.3281, grad_fn=<AddBackward0>) tensor(189.7112, grad_fn=<SumBackward0>)
tensor(207534.4375, grad_fn=<AddBackward0>) tensor(195.9459, grad

tensor(112259.4219, grad_fn=<AddBackward0>) tensor(107.3231, grad_fn=<SumBackward0>)
tensor(110568.4766, grad_fn=<AddBackward0>) tensor(105.9317, grad_fn=<SumBackward0>)
tensor(112480.2266, grad_fn=<AddBackward0>) tensor(108.1600, grad_fn=<SumBackward0>)
tensor(107977.0156, grad_fn=<AddBackward0>) tensor(104.3848, grad_fn=<SumBackward0>)
tensor(107866.4375, grad_fn=<AddBackward0>) tensor(104.5999, grad_fn=<SumBackward0>)
tensor(105436.8672, grad_fn=<AddBackward0>) tensor(102.7095, grad_fn=<SumBackward0>)
tensor(108053.8516, grad_fn=<AddBackward0>) tensor(105.2370, grad_fn=<SumBackward0>)
tensor(100658.3438, grad_fn=<AddBackward0>) tensor(98.5525, grad_fn=<SumBackward0>)
tensor(103070.4844, grad_fn=<AddBackward0>) tensor(100.7184, grad_fn=<SumBackward0>)
tensor(99547.6172, grad_fn=<AddBackward0>) tensor(97.4753, grad_fn=<SumBackward0>)
tensor(97095.2656, grad_fn=<AddBackward0>) tensor(95.0179, grad_fn=<SumBackward0>)
tensor(94791.7266, grad_fn=<AddBackward0>) tensor(93.0682, grad_fn=<Su

tensor(229034.0312, grad_fn=<AddBackward0>) tensor(219.9174, grad_fn=<SumBackward0>)
tensor(231667.1875, grad_fn=<AddBackward0>) tensor(222.6156, grad_fn=<SumBackward0>)
tensor(232180.5000, grad_fn=<AddBackward0>) tensor(223.3775, grad_fn=<SumBackward0>)
tensor(235530.2188, grad_fn=<AddBackward0>) tensor(226.7575, grad_fn=<SumBackward0>)
tensor(233421.1250, grad_fn=<AddBackward0>) tensor(224.9290, grad_fn=<SumBackward0>)
tensor(235849.8125, grad_fn=<AddBackward0>) tensor(227.1478, grad_fn=<SumBackward0>)
tensor(233277.3594, grad_fn=<AddBackward0>) tensor(224.7902, grad_fn=<SumBackward0>)
tensor(224327.6719, grad_fn=<AddBackward0>) tensor(216.4016, grad_fn=<SumBackward0>)
tensor(229130.6719, grad_fn=<AddBackward0>) tensor(220.8784, grad_fn=<SumBackward0>)
tensor(229858.3281, grad_fn=<AddBackward0>) tensor(221.3576, grad_fn=<SumBackward0>)
tensor(230447.1562, grad_fn=<AddBackward0>) tensor(221.8573, grad_fn=<SumBackward0>)
tensor(229504.2344, grad_fn=<AddBackward0>) tensor(220.8980, grad

tensor(103759.7656, grad_fn=<AddBackward0>) tensor(98.7631, grad_fn=<SumBackward0>)
tensor(100716.8750, grad_fn=<AddBackward0>) tensor(95.9743, grad_fn=<SumBackward0>)
tensor(102489.1406, grad_fn=<AddBackward0>) tensor(97.4806, grad_fn=<SumBackward0>)
tensor(106939.5000, grad_fn=<AddBackward0>) tensor(101.6107, grad_fn=<SumBackward0>)
tensor(95238.4141, grad_fn=<AddBackward0>) tensor(90.5424, grad_fn=<SumBackward0>)
tensor(96240.4062, grad_fn=<AddBackward0>) tensor(91.7634, grad_fn=<SumBackward0>)
tensor(95260.8203, grad_fn=<AddBackward0>) tensor(90.8790, grad_fn=<SumBackward0>)
tensor(92972.2891, grad_fn=<AddBackward0>) tensor(88.7986, grad_fn=<SumBackward0>)
tensor(94811.4062, grad_fn=<AddBackward0>) tensor(90.4925, grad_fn=<SumBackward0>)
tensor(95084.4375, grad_fn=<AddBackward0>) tensor(91.1123, grad_fn=<SumBackward0>)
tensor(93002.8828, grad_fn=<AddBackward0>) tensor(89.2902, grad_fn=<SumBackward0>)
tensor(89769.8828, grad_fn=<AddBackward0>) tensor(86.1475, grad_fn=<SumBackward0>)

tensor(195856.7812, grad_fn=<AddBackward0>) tensor(189.5904, grad_fn=<SumBackward0>)
tensor(194076.9844, grad_fn=<AddBackward0>) tensor(187.5139, grad_fn=<SumBackward0>)
tensor(198814.9844, grad_fn=<AddBackward0>) tensor(192.0396, grad_fn=<SumBackward0>)
tensor(204771.8750, grad_fn=<AddBackward0>) tensor(197.4704, grad_fn=<SumBackward0>)
tensor(194545.4062, grad_fn=<AddBackward0>) tensor(188.0250, grad_fn=<SumBackward0>)
tensor(193679.4688, grad_fn=<AddBackward0>) tensor(187.4990, grad_fn=<SumBackward0>)
tensor(192361.7188, grad_fn=<AddBackward0>) tensor(186.4603, grad_fn=<SumBackward0>)
tensor(190429.2656, grad_fn=<AddBackward0>) tensor(184.7566, grad_fn=<SumBackward0>)
tensor(195205.3438, grad_fn=<AddBackward0>) tensor(189.3195, grad_fn=<SumBackward0>)
tensor(186464.2812, grad_fn=<AddBackward0>) tensor(180.8230, grad_fn=<SumBackward0>)
tensor(188156.9375, grad_fn=<AddBackward0>) tensor(182.3782, grad_fn=<SumBackward0>)
tensor(184879.8281, grad_fn=<AddBackward0>) tensor(179.0846, grad

tensor(67507.2734, grad_fn=<AddBackward0>) tensor(66.6812, grad_fn=<SumBackward0>)
tensor(66592.0938, grad_fn=<AddBackward0>) tensor(65.9933, grad_fn=<SumBackward0>)
tensor(65900.2500, grad_fn=<AddBackward0>) tensor(65.1458, grad_fn=<SumBackward0>)
tensor(63563.4844, grad_fn=<AddBackward0>) tensor(62.6403, grad_fn=<SumBackward0>)
tensor(60770.5938, grad_fn=<AddBackward0>) tensor(59.3586, grad_fn=<SumBackward0>)
tensor(64675.3242, grad_fn=<AddBackward0>) tensor(61.9203, grad_fn=<SumBackward0>)
tensor(61868.7734, grad_fn=<AddBackward0>) tensor(59.1354, grad_fn=<SumBackward0>)
tensor(62317.8164, grad_fn=<AddBackward0>) tensor(59.7062, grad_fn=<SumBackward0>)
tensor(58930.3516, grad_fn=<AddBackward0>) tensor(56.9834, grad_fn=<SumBackward0>)
tensor(59400.9727, grad_fn=<AddBackward0>) tensor(57.5497, grad_fn=<SumBackward0>)
tensor(60759.3945, grad_fn=<AddBackward0>) tensor(58.9721, grad_fn=<SumBackward0>)
tensor(59656.6133, grad_fn=<AddBackward0>) tensor(57.4605, grad_fn=<SumBackward0>)
tens

KeyboardInterrupt: 

In [None]:
torch.save(
    net.state_dict(),
    'net.th'
)

In [41]:
g = esp.Graph('CC')

In [42]:
xs, vs, particle_distribution = simulation(net, g=g.heterograph)

In [43]:
import nglview as nv
from rdkit.Geometry import Point3D
from rdkit import Chem
from rdkit.Chem import AllChem

conf_idx = 1

mol = g.mol.to_rdkit()
AllChem.EmbedMolecule(mol)
conf = mol.GetConformer()

xs, vs, particle_distribution = simulation(net, g=g.heterograph)
x = xs[-1]


for idx_atom in range(mol.GetNumAtoms()):
    conf.SetAtomPosition(
        idx_atom,
        Point3D(
            float(x[idx_atom, conf_idx, 0]),
            float(x[idx_atom, conf_idx, 1]),
            float(x[idx_atom, conf_idx, 2]),
        ))
    
nv.show_rdkit(mol)

NGLWidget()