In [1]:
from data import MultiMolGraphDataset, EquiDataset
from torch_geometric.loader import DataLoader
import torch
import torchmetrics
from torchmetrics import MeanAbsoluteError
import random
import pandas as pd
import numpy as np
from siamesepairwise import SiameseDimeNet
from torch.optim.lr_scheduler import CosineAnnealingLR
from schedulers import CosineRestartsDecay
import os
import argparse
import torch.nn.functional as F
from torch_geometric.data import Data
from tqdm import tqdm  # for progress bars
from loss_utils import cosine_angle_loss, AngularErrorMetric


device = 'cpu'

In [2]:
def seed_everything(seed: int):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


In [3]:
equi_data = EquiDataset(
    root='data/equi',
    geoms_csv='/home/calvin/code/GINE/rxn_geometries.csv',
    sdf_folder='/home/calvin/code/chemprop_phd_customised/habnet/data/processed/sdf_data' ,
    target_csv='/home/calvin/code/chemprop_phd_customised/habnet/data/processed/target_data/target_data_sin_cos.csv',
    target_columns=['psi_1_dihedral_sin', 'psi_1_dihedral_cos'],
    transform=None,
    pre_transform=None,
    pre_filter=None,
    force_reload=True
)

Processing...


  → Built 1696 EquiData examples
Saved processed data to data/equi/processed/equi_single_3f0c9001f7b2f62834a490d3aa840439d43bb5f5.pt


Done!


  → Built 1696 EquiData examples
Saved processed data to data/equi/processed/equi_single_3f0c9001f7b2f62834a490d3aa840439d43bb5f5.pt


In [4]:
equi_data[0].z

tensor([66, 68,  6,  6,  7,  1,  1,  7, 56, 56,  6,  7, 71,  1])

In [5]:
from dimenet import DimeNetPPEncoder, FlaggedDimeNetPPEncoder

# encoder = FlaggedDimeNetPPEncoder(
#     hidden_channels = 256,
#     out_channels    = 256,
#     num_blocks      = 8,
#     dropout         = 0.1,
# ).to(device)
# encoder = encoder.to(device)
encoder = DimeNetPPEncoder(
    hidden_channels = 512,
    out_channels    = 512,
    num_blocks      = 8,
    dropout         = 0.1,     # your original DimeNet wrapper handles dropout
).to(device)


In [6]:
from torch_geometric.loader import DataLoader
# Random split using torch
train_size = int(0.8 * len(equi_data))
valid_size = int(0.1 * len(equi_data))
test_size = len(equi_data) - train_size - valid_size
# Shuffle the dataset
train_data, valid_data, test_data = torch.utils.data.random_split(equi_data, [train_size, valid_size, test_size])
train_loader = DataLoader(train_data, batch_size=8, shuffle=True)
valid_loader = DataLoader(valid_data, batch_size=8, shuffle=False)
test_loader = DataLoader(test_data, batch_size=8, shuffle=False)


In [7]:
def train_epoch(model, loader, optimizer, loss_fn, metric_fn, scaler):
    model.train()
    total_loss = 0.0
    total_err  = 0.0
    n_samples  = 0

    for batch in tqdm(loader, desc="Training", leave=False):
        batch = batch.to(device)
        optimizer.zero_grad()

        # forward + loss
        with torch.autocast(device_type='cuda', dtype=torch.float16):
            out   = model(batch)              # [B,2]

            loss  = loss_fn(out, batch.y)     # scalar
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        # loss.backward()
        # optimizer.step()

        # metric: mean abs angle error (degrees)
        err = metric_fn(out, batch.y)     # scalar tensor

        bsize = batch.y.size(0)
        total_loss += loss.item() * bsize
        total_err  += err.item()  * bsize
        n_samples  += bsize

    return total_loss / n_samples, total_err / n_samples

def eval_epoch(model, loader, loss_fn, metric_fn):
    model.eval()
    total_loss = 0.0
    total_err  = 0.0
    n_samples  = 0

    with torch.no_grad():
        for batch in tqdm(loader, desc="Validation", leave=False):
            batch = batch.to(device)
            out   = model(batch)
            loss  = loss_fn(out, batch.y)
            err   = metric_fn(out, batch.y)

            bsize = batch.y.size(0)
            total_loss += loss.item() * bsize
            total_err  += err.item()  * bsize
            n_samples  += bsize

    return total_loss / n_samples, total_err / n_samples


In [8]:
from siamesepairwise import DimeNet

model = DimeNet(
    encoder=encoder,
    dropout=0.2,
    head_hidden_dims=[128, 128],
)
model = model.to(device)
num_epochs  = 200
learning_rate = 1e-4
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=1e-5)
# scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
#     optimizer, mode='min', factor=0.5, patience=5, verbose=True
# )
#scheduler = CosineAnnealingLR(optimizer, T_max=num_epochs, eta_min=1e-6)
scheduler = CosineRestartsDecay(
    optimizer,
    T_0     = 20,
    T_mult  = 2,
    eta_min = 1e-4,
    decay   = 0.3) 


import torch.nn.functional as F

best_val_loss = float('inf')

loss_fn = cosine_angle_loss
metric_fn   = AngularErrorMetric(in_degrees=True)

In [None]:
for epoch in range(1, num_epochs + 1):
    scaler = torch.amp.GradScaler()
    tr_loss, tr_err = train_epoch(model, train_loader, optimizer, loss_fn, metric_fn, scaler)
    va_loss, va_err = eval_epoch(model, valid_loader, loss_fn, metric_fn)

    # Step the scheduler once per epoch
    # If your scheduler wants the epoch number, do: scheduler.step(epoch)
    # Otherwise just:
    scheduler.step()

    lr = scheduler.get_last_lr()[0]
    print(f"Epoch {epoch:02d} | "
          f"Train L={tr_loss:.4f}, Err={tr_err:.1f}° | "
          f"Val   L={va_loss:.4f}, Err={va_err:.1f}° | "
          f"LR={lr:.2e}")

    # (Optional) save best‐model checkpoint
    if va_loss < best_val_loss:
        best_val_loss = va_loss
        torch.save(model.state_dict(), 'best_dimenet_model.pt')
        print(" ↳ New best model saved!")

                                                           

Epoch 01 | Train L=0.7533, Err=71.7° | Val   L=0.7974, Err=75.3° | LR=1.00e-04
 ↳ New best model saved!


                                                           

Epoch 02 | Train L=0.7261, Err=69.5° | Val   L=0.7952, Err=75.0° | LR=1.00e-04
 ↳ New best model saved!


                                                           

Epoch 03 | Train L=0.7284, Err=69.8° | Val   L=0.7962, Err=75.1° | LR=1.00e-04


                                                           

Epoch 04 | Train L=0.7263, Err=69.5° | Val   L=0.7972, Err=75.2° | LR=1.00e-04


                                                           

Epoch 05 | Train L=0.7307, Err=69.8° | Val   L=0.8000, Err=75.7° | LR=1.00e-04


                                                           

Epoch 06 | Train L=0.7289, Err=69.7° | Val   L=0.7978, Err=75.3° | LR=1.00e-04


                                                           

Epoch 07 | Train L=0.7296, Err=69.7° | Val   L=0.7966, Err=75.1° | LR=1.00e-04


                                                           

Epoch 08 | Train L=0.7292, Err=69.6° | Val   L=0.7976, Err=75.3° | LR=1.00e-04


                                                           

Epoch 09 | Train L=0.7267, Err=69.5° | Val   L=0.7985, Err=75.4° | LR=1.00e-04


Training:  32%|███▏      | 55/170 [03:15<07:19,  3.82s/it]

In [None]:
# make sure predictions really lie on the circle
with torch.no_grad():
    out = model(next(iter(train_loader)).to(device))
    print(out.norm(dim=1).mean())  # should be ≈1.0


tensor(1.)


In [None]:
for batch in train_loader:
    batch = batch.to(device)
    print(batch.y)
    break

tensor([[-2.0573e-01,  9.7861e-01],
        [ 2.9312e-01,  9.5607e-01],
        [ 8.4652e-01, -5.3236e-01],
        [-3.8566e-02,  9.9926e-01],
        [ 8.7958e-01,  4.7576e-01],
        [ 9.9661e-02,  9.9502e-01],
        [-5.3904e-04,  1.0000e+00],
        [-8.3266e-01,  5.5379e-01],
        [ 9.6501e-01,  2.6222e-01],
        [ 8.8336e-01,  4.6869e-01],
        [ 4.8802e-01,  8.7283e-01],
        [-5.8331e-04,  1.0000e+00],
        [-6.3842e-04,  1.0000e+00],
        [ 8.0287e-01,  5.9615e-01],
        [-5.3548e-01, -8.4455e-01],
        [ 9.8448e-01, -1.7550e-01],
        [-9.4104e-01,  3.3829e-01],
        [ 3.3494e-01, -9.4224e-01],
        [-4.1226e-01,  9.1107e-01],
        [ 9.9049e-01,  1.3759e-01],
        [-1.9703e-01,  9.8040e-01],
        [-6.0893e-01,  7.9322e-01],
        [-4.1957e-01,  9.0772e-01],
        [-5.3910e-01,  8.4224e-01],
        [ 7.4858e-01,  6.6305e-01],
        [-8.5713e-01, -5.1510e-01],
        [-1.5980e-01,  9.8715e-01],
        [ 9.9824e-01, -5.926

In [None]:
torch.cuda.empty_cache()
print("CUDA memory cleared")
# Print CUDA memory summary
