# Example training script

This tutorial shows a brief example of how the networks were trained. The specific example below is for $\Delta$-learning of formation energy in a single-task setting. We'll start with the imports:

In [6]:
import os
import h5py
import networkx as nx
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from delfta.net import EGNN
from delfta.net_utils import MODEL_HPARAMS
from delfta.utils import DATA_PATH, ROOT_PATH
from torch_geometric.data import Data, DataLoader, Dataset
from torch_geometric.utils import add_self_loops
from torch_geometric.utils.undirected import to_undirected
from tqdm import tqdm

Next, we'll download the training data if that hasn't happened yet (not done by default during the setup of `delfta`, since the original training files aren't needed to run the trained models). 

In [7]:
# The first conformer file acts as the marker for "dataset already extracted".
qmugs_first_file = os.path.join(DATA_PATH, "qmugs", "qmugs_conf00.h5")

if not os.path.exists(qmugs_first_file):
    # Only needed on the first run, hence the imports local to this branch.
    import tarfile

    from delfta.download import DATASET_REMOTE, download

    # Fetch the archive next to the other data, then unpack it in place.
    archive_path = os.path.join(DATA_PATH, "qmugs.tar.gz")
    download(DATASET_REMOTE, archive_path)
    with tarfile.open(archive_path) as archive:
        archive.extractall(DATA_PATH)

We'll define a new dataset class that is similar to the one in `delfta.net_utils`, but it doesn't load all the molecules into memory at once. Instead, each molecule is read from the HDF5 files on demand, which is better suited to training on a large number of molecules.

In [8]:
class DatasetSingletaskh5(Dataset):
    """Single-task graph dataset that reads molecules lazily from HDF5 files.

    The molecule keys are read from a text file (one key per line). Every
    item is fetched on demand from one of three HDF5 files under
    ``DATA_PATH/qmugs`` instead of being pre-loaded into memory.

    Args:
        txtfile: Path to a text file listing one molecule key per line.
            Blank lines are ignored.
        prop: Name of the HDF5 dataset (inside each molecule group) used as
            the regression target.

    Attributes:
        idx2chembl: Mapping from integer index to molecule key.
        h5f0, h5f1, h5f2: Open read-only ``h5py.File`` handles for the
            ``qmugs_conf00/01/02.h5`` files.
        prop: The target property name.
    """

    def __init__(
        self, txtfile, prop,
    ):

        # read txt; blank lines are dropped because they cannot be resolved
        # to a molecule group later on
        with open(txtfile, "r") as f:
            chembls = [line.rstrip("\n") for line in f if line.strip()]

        # idx -> chembl
        self.idx2chembl = dict(enumerate(chembls))

        # read h5; if one of the files cannot be opened, close the ones that
        # were already opened before re-raising
        qmugs_dir = os.path.join(DATA_PATH, "qmugs")
        self.h5f0 = self.h5f1 = self.h5f2 = None
        try:
            self.h5f0 = h5py.File(os.path.join(qmugs_dir, "qmugs_conf00.h5"), "r")
            self.h5f1 = h5py.File(os.path.join(qmugs_dir, "qmugs_conf01.h5"), "r")
            self.h5f2 = h5py.File(os.path.join(qmugs_dir, "qmugs_conf02.h5"), "r")
        except Exception:
            self.close()
            raise
        # NOTE(review): the handles are opened once here; confirm that h5py
        # handles this correctly before using DataLoader with num_workers > 0.

        # define property of interest
        self.prop = prop

    def close(self):
        """Close every HDF5 file handle that was opened by this dataset."""
        for attr in ("h5f0", "h5f1", "h5f2"):
            handle = getattr(self, attr, None)
            if handle is not None:
                handle.close()

    def _molecule_group(self, chembl_id):
        """Return the HDF5 group of `chembl_id` from the matching conformer file.

        The file is selected by the ``conf_0X`` tag contained in the key.

        Raises:
            ValueError: If the key contains none of the known conformer tags.
        """
        tagged_files = (
            ("conf_00", self.h5f0),
            ("conf_01", self.h5f1),
            ("conf_02", self.h5f2),
        )
        for tag, h5file in tagged_files:
            if tag in chembl_id:
                return h5file[str(chembl_id)]
        raise ValueError(
            f"Cannot determine the conformer file for key {chembl_id!r}: "
            "expected it to contain 'conf_00', 'conf_01' or 'conf_02'."
        )

    def __getitem__(self, idx):
        """Build the graph for the molecule stored at position `idx`.

        Returns:
            A ``Data`` object with ``atomids``, ``coords``, ``edge_index``,
            ``target`` and ``num_nodes`` set.

        Raises:
            KeyError: If `idx` is not a valid index of this dataset.
            ValueError: If the molecule key carries no known conformer tag.
        """

        chembl_id = self.idx2chembl[idx]
        group = self._molecule_group(chembl_id)

        #### nodes coordinates and target
        atomids = torch.LongTensor(group["atomids"])
        coords = torch.FloatTensor(group["coords"])
        target = torch.FloatTensor(group[self.prop])

        #### edges: fully connected graph over all atoms, stored in both
        #### directions, plus one self-loop per node
        edge_index = np.array(nx.complete_graph(atomids.size(0)).edges())
        edge_index = to_undirected(torch.from_numpy(edge_index).t().contiguous())
        edge_index, _ = add_self_loops(edge_index, num_nodes=coords.shape[0])

        #### graph object
        graph_data = Data(
            atomids=atomids,
            coords=coords,
            edge_index=edge_index,
            target=target,
            num_nodes=atomids.size(0),
        )

        return graph_data

    def __len__(self):
        """Return the number of molecule keys read from the text file."""
        return len(self.idx2chembl)



Now we'll define the training and evaluation loops:

In [9]:
# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


def mae_loss(x, y):
    """Return the mean absolute error between `x` and `y` as a Python float."""
    return F.l1_loss(x, y).item()


def train_loop(model, loader, optimizer, criterion):
    """Train `model` for one pass over `loader`.

    Each batch is moved to ``DEVICE``, the model output is squeezed along
    dimension 1 and compared against ``batch.target`` with `criterion`, and
    one optimizer step is taken. The MAE of every batch is recorded.

    Args:
        model: Network to optimise; it is called with the whole batch object.
        loader: Iterable of batches exposing ``.to(device)`` and ``.target``.
        optimizer: Optimizer holding the parameters of `model`.
        criterion: Callable ``(prediction, target) -> loss`` used for backprop.

    Returns:
        Tuple ``(mean, std)`` of the per-batch MAE values of this pass.
    """
    model.train()
    batch_maes = []

    progress = tqdm(loader, total=len(loader))
    for batch in progress:
        optimizer.zero_grad()

        batch = batch.to(DEVICE)
        labels = batch.target
        output = model(batch).squeeze(1)

        batch_loss = criterion(output, labels)
        batch_loss.backward()
        optimizer.step()

        # The MAE is only monitored, so no gradient tracking is needed here.
        with torch.no_grad():
            batch_maes.append(mae_loss(output, labels))

    return np.mean(batch_maes), np.std(batch_maes)


def eval_loop(model, loader):
    """Evaluate `model` on `loader` and return the mean absolute error.

    The MAE is computed over all individual target values seen, i.e. the
    absolute errors are summed across batches and divided by the total number
    of values. Averaging per-batch MAEs instead would over-weight the samples
    of a final batch that is smaller than the others.

    Args:
        model: Network to evaluate; it is called with the whole batch object.
        loader: Iterable of batches exposing ``.to(device)`` and ``.target``.

    Returns:
        The MAE over the whole loader as a Python float.

    Raises:
        ValueError: If `loader` yields no target values at all.
    """
    model.eval()

    abs_error_sum = 0.0
    n_values = 0

    with torch.no_grad():
        for g_batch in tqdm(loader, total=len(loader)):
            g_batch = g_batch.to(DEVICE)
            prediction = model(g_batch).squeeze(1)
            abs_error = (prediction - g_batch.target).abs()
            abs_error_sum += abs_error.sum().item()
            n_values += abs_error.numel()

    if n_values == 0:
        raise ValueError("eval_loop received an empty loader; cannot compute the MAE.")

    eval_mae = abs_error_sum / n_values
    return eval_mae

And finally we can start training the model:  

In [10]:
# ---- paths -----------------------------------------------------------------
save_path = os.path.join(ROOT_PATH, "tutorials", "training_evaluation")
os.makedirs(save_path, exist_ok=True)
train_path = os.path.join(DATA_PATH, "qmugs", "train_example.txt") # this includes only 50 molecules so that in this tutorial, we quickly see results
eval_path = os.path.join(DATA_PATH, "qmugs", "eval_example.txt")  

# ---- task and hyper-parameters ---------------------------------------------
prop = "DELTA_ENERGY"
model_param = MODEL_HPARAMS["single_energy_delta"]
BATCH_SIZE = 2
LEARNING_RATE = 1e-4
WEIGHT_DECAY = 1e-10
LR_DECAY_FACTOR = 0.7  # multiplicative LR reduction when validation stalls
LR_PATIENCE = 20  # epochs without improvement before the LR is reduced
EPOCHS = 100

# ---- data ------------------------------------------------------------------
train_data = DatasetSingletaskh5(txtfile=train_path, prop=prop)
train_loader = DataLoader(
    train_data, batch_size=BATCH_SIZE, shuffle=True, num_workers=0
)

validation_data = DatasetSingletaskh5(txtfile=eval_path, prop=prop)
validation_loader = DataLoader(
    validation_data, batch_size=BATCH_SIZE, shuffle=False, num_workers=0
)

# ---- model, loss, optimizer ------------------------------------------------
model = EGNN(
    n_outputs=model_param.n_outputs,
    global_prop=model_param.global_prop,
    n_kernels=model_param.n_kernels,
    mlp_dim=model_param.mlp_dim,
)
model = model.to(DEVICE)

# The training objective is MSE; the MAE is what gets monitored and reported.
# The loss module is created once rather than once per epoch.
criterion = nn.MSELoss()

optimizer = torch.optim.Adam(
    model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY
)
# NOTE(review): `verbose` is reported as deprecated in recent PyTorch
# releases - confirm against the installed version.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="min", factor=LR_DECAY_FACTOR, patience=LR_PATIENCE, verbose=True
)

# ---- bookkeeping -----------------------------------------------------------
train_m_losses, train_std_losses = [], []
val_losses = []
val_maes = []  # validation MAE at every epoch that set a new minimum
epoch_maes = []  # 0-based epoch index of each of those minima
min_mae = 1e10

model_file = os.path.join(save_path, "model.pt")
history_file = os.path.join(save_path, "loss_train_eval.pt")

# ---- training --------------------------------------------------------------
for epoch in range(EPOCHS):
    print(f"Epoch {epoch + 1}/{EPOCHS}...", flush=True)

    m_loss, std_loss = train_loop(model, train_loader, optimizer, criterion)
    train_m_losses.append(m_loss)
    train_std_losses.append(std_loss)

    eval_mae = eval_loop(model, validation_loader)
    val_losses.append(eval_mae)
    scheduler.step(eval_mae)

    # Keep only the weights of the best model seen so far.
    if eval_mae < min_mae:

        min_mae = eval_mae
        val_maes.append(eval_mae)
        epoch_maes.append(epoch)
        print(f"New min eval_mae in epoch {epoch}: {eval_mae:.6f}", flush=True)
        torch.save(
            model.state_dict(), model_file,
        )

    # Persist the loss history after every epoch. Writing it only when the
    # validation MAE improves would leave the saved curves truncated at the
    # last improving epoch.
    torch.save(
        [train_m_losses, train_std_losses, val_losses],
        history_file,
    )


Epoch 1/100...


100%|██████████| 25/25 [00:46<00:00,  1.88s/it]
100%|██████████| 25/25 [00:13<00:00,  1.84it/s]

New min eval_mae in epoch 0: 0.931423
Epoch 2/100...



 56%|█████▌    | 14/25 [00:20<00:12,  1.10s/it]