In [None]:
import sys, os, math  # NOTE(review): not referenced in the cells shown here
import numpy as np
import warnings
# Silence every Python warning for the rest of the session; remove this line
# while debugging so deprecation / runtime warnings become visible again.
warnings.filterwarnings("ignore")

# NOTE(review): the RDKit and OpenFF imports below are not referenced in the cells
# shown here; presumably kept for interactive inspection of molecules — confirm.
import rdkit
from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit.Chem import Draw

import torch
import espaloma as esp

from openff.toolkit.topology import Molecule

In [None]:
# Seed torch's global RNG so model weight initialisation (and therefore the
# training / validation curves below) is reproducible across Restart & Run All.
# The dataset shuffle further down passes its own explicit seed.
torch.manual_seed(2666);

# Which pre-built espaloma dataset archive to download and train on.
# Other archives this notebook has been used with: "pepconf", "vehicle", "phalkethoh".
dataset_name = "gen2"

In [None]:
%%capture
# Fetch and extract the pre-built dataset archive named by `dataset_name`.
# `wget -nc` skips the download when the archive is already present (instead of
# saving another `<name>.zip.1` copy), and `unzip -o` overwrites existing files
# without an interactive prompt — a prompt that %%capture would hide, leaving the
# cell stuck or failing silently on a re-run.
!wget -nc "data.wangyq.net/esp_dataset/"$dataset_name".zip"
!unzip -o $dataset_name".zip"

In [None]:
# Load the extracted graph dataset, shuffle it with a fixed seed, and split it in
# (presumably proportional) 8:1:1 shares into training / validation / test sets.
ds = esp.data.dataset.GraphDataset.load(dataset_name)
ds.shuffle(seed=2666)
split_ratios = [8, 1, 1]  # train : valid : test
ds_tr, ds_vl, ds_te = ds.split(split_ratios)

## espaloma training

#### stage 1: graph -> atom latent representation

In [None]:
# Stage 1 network: three graph-convolution layers built from DGL's SAGEConv,
# each 128 units wide and followed by a ReLU, producing per-atom latent vectors.
representation = esp.nn.Sequential(
    layer=esp.nn.layers.dgl_legacy.gn("SAGEConv"),
    config=[128, "relu"] * 3,
)

#### stages 2 and 3: atom latent representations -> bond, angle, and torsion representations and parameters

In [None]:
# Stages 2 and 3: Janossy pooling maps the 128-d atom representations to MM
# parameters. Keys of `out_features` are the number of atoms in the term
# (1: atom, 2: bond, 3: angle, 4: torsion); values map parameter names to widths.
readout = esp.nn.readout.janossy.JanossyPooling(
    in_features=128,
    config=[128, "relu"] * 3,
    out_features={
        # atom-level "e" and "s" (presumably electronegativity and hardness) —
        # only consumed if a charge-equilibration stage is added to the model,
        # which this notebook does not do
        1: {"e": 1, "s": 1},
        # two log-coefficients per bond; made positive downstream (ExpCoefficients)
        2: {"log_coefficients": 2},
        # two log-coefficients per angle, handled the same way as bonds
        3: {"log_coefficients": 2},
        # six torsion barrier heights, which may be positive or negative
        4: {"k": 6},
    },
)

#### compose all three Espaloma stages into an end-to-end model

In [None]:
# End-to-end model: atom representation -> per-term parameters -> positive
# coefficients -> geometry and MM energy, which (presumably) are written into the
# graph where the loss reads "u". Optional extra stages not used here:
#   esp.mm.energy.EnergyInGraph(suffix="_ref")
#   esp.nn.readout.charge_equilibrium.ChargeEquilibrium()
espaloma_model = torch.nn.Sequential(
    representation,
    readout,
    esp.nn.readout.janossy.ExpCoefficients(),
    esp.mm.geometry.GeometryInGraph(),
    esp.mm.energy.EnergyInGraph(),
)

In [None]:
# Run on the GPU when one is present; otherwise stay on the CPU and say so.
if not torch.cuda.is_available():
    print("cuda not available")
else:
    espaloma_model = espaloma_model.cuda()

In [None]:
# Training objective: mean-squared error between the model energy "u" and the
# reference (QM) energy "u_ref", compared at the graph ("g") level.
loss_fn = esp.metrics.GraphMetric(
    level="g",
    between=["u", "u_ref"],
    base_metric=torch.nn.MSELoss(),
)

In [None]:
# Adam over every trainable parameter of the end-to-end model, learning rate 1e-4.
optimizer = torch.optim.Adam(params=espaloma_model.parameters(), lr=1e-4)

## train model

In [None]:
# Passes over the training set. 10 keeps this demo quick; a full training run
# used 10000.
n_epochs = 10

In [None]:
# One optimisation step per molecule graph in the training split. After each
# epoch the weights are saved to "<epoch>.th" so the inspection cell below can
# reload and score every checkpoint.
for idx_epoch in range(n_epochs):
    for g in ds_tr:
        if torch.cuda.is_available():
            # write the device copy back onto the dataset entry
            g.heterograph = g.heterograph.to("cuda:0")
        optimizer.zero_grad()
        predicted_graph = espaloma_model(g.heterograph)
        step_loss = loss_fn(predicted_graph)
        step_loss.backward()
        optimizer.step()
    torch.save(espaloma_model.state_dict(), "%s.th" % idx_epoch)

## inspect loss 

In [None]:
# Metric for inspecting saved checkpoints (not used for training): L1 loss, i.e.
# mean absolute error, wrapped by esp.metrics.center — presumably mean-centres
# predicted and reference energies before comparing, so constant offsets are
# ignored; confirm against the espaloma docs.
inspect_metric = esp.metrics.center(torch.nn.L1Loss())

In [None]:
# Per-epoch inspection metric for the training and validation splits; filled in
# by the checkpoint evaluation loop below.
loss_tr, loss_vl = [], []

In [None]:
def _evaluate_split(dataset):
    """Score the currently loaded model weights on one dataset split.

    Runs `espaloma_model` on every graph in `dataset` (moving each graph to the
    GPU when available), gathers the predicted energies `u` and the reference
    energies `u_ref` stored on each graph's "g" node, and applies
    `inspect_metric` to the concatenated tensors.

    Parameters
    ----------
    dataset : iterable of espaloma graphs, e.g. `ds_tr` or `ds_vl`.

    Returns
    -------
    float
        The inspection metric as a plain Python number, so no (possibly CUDA)
        tensor is kept around and the history converts cleanly with NumPy.
    """
    u_pred = []
    u_ref = []
    for g in dataset:
        if torch.cuda.is_available():
            g.heterograph = g.heterograph.to("cuda:0")
        # the model writes its prediction into the graph; the return value is unused
        espaloma_model(g.heterograph)
        u_pred.append(g.nodes['g'].data['u'])
        # fix: previously appended the node view itself, which torch.cat rejects
        u_ref.append(g.nodes['g'].data['u_ref'])
    # NOTE(review): torch.cat along dim 0 requires every per-molecule tensor to
    # match in all other dimensions — confirm this holds for the chosen dataset.
    return inspect_metric(torch.cat(u_pred, dim=0), torch.cat(u_ref, dim=0)).item()


with torch.no_grad():
    for idx_epoch in range(n_epochs):
        # reload the weights the training loop saved at the end of this epoch
        espaloma_model.load_state_dict(torch.load("%s.th" % idx_epoch))
        loss_tr.append(_evaluate_split(ds_tr))  # training set performance
        loss_vl.append(_evaluate_split(ds_vl))  # validation set performance

In [None]:
# Convert the per-epoch metric history to NumPy and rescale by 627.5, the
# Hartree -> kcal/mol factor (so this assumes the energies are in Hartree).
# float() accepts Python floats as well as 0-d tensors on CPU or GPU, whereas
# np.array() on a list of CUDA tensors raises.
# NOTE: re-running this cell on its own rescales the arrays a second time.
HARTREE_TO_KCAL_PER_MOL = 627.5
loss_tr = np.array([float(value) for value in loss_tr]) * HARTREE_TO_KCAL_PER_MOL
loss_vl = np.array([float(value) for value in loss_vl]) * HARTREE_TO_KCAL_PER_MOL

In [None]:
from matplotlib import pyplot as plt

# Learning curves: the inspection metric of each saved checkpoint on the training
# and validation splits. The log y-axis keeps both early (large) and late (small)
# errors readable on one plot.
fig, ax = plt.subplots()
ax.plot(loss_tr, label="train")
ax.plot(loss_vl, label="valid")
ax.set_yscale("log")
ax.set_xlabel("epoch")
ax.set_ylabel("centered MAE (kcal/mol)")
ax.set_title("espaloma energy error per checkpoint")
ax.legend()
plt.show()