In [None]:
from data import MultiMolGraphDataset
from torch_geometric.loader import DataLoader
import torch
import torchmetrics
from torchmetrics import MeanAbsoluteError
import random
import pandas as pd
import numpy as np
from siamesepairwise import SiameseDimeNet
from torch.optim.lr_scheduler import CosineAnnealingLR
from schedulers import CosineRestartsDecay
import os
import argparse
import torch.nn.functional as F
from torch_geometric.data import Data
from tqdm import tqdm  # for progress bars
from loss_utils import cosine_angle_loss, AngularErrorMetric


device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [None]:
def seed_everything(seed: int):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


In [2]:
sdf_path = '/home/calvin/code/chemprop_phd_customised/habnet/data/processed/sdf_data'
target_data = '/home/calvin/code/chemprop_phd_customised/habnet/data/processed/target_data/target_data.csv'

In [None]:
# Read in target_data with pandas and then select 'psi_1_dihedral' column and convert the dihedral angles to sin and cos 

# 1 & 2. Load your data
target_df = pd.read_csv(target_data)

# Drop -10 rows in the dihedral angle column
target_df = target_df[target_df['psi_1_dihedral'] != -10]

# 3. Extract the dihedral angles
angles = target_df['psi_1_dihedral']

# 3.1 Remove any -10 rows in the dihedral angle column
angles = angles[angles != -10]
# 3.2 Remove any NaN values
angles = angles.dropna()

# 4. If angles are in degrees, convert to radians
angles_rad = np.deg2rad(angles)

# 5. Compute sin & cos and assign
target_df['psi_1_dihedral_sin'] = np.sin(angles_rad)
target_df['psi_1_dihedral_cos'] = np.cos(angles_rad)

# Optional: inspect
print(target_df[['psi_1_dihedral', 'psi_1_dihedral_sin', 'psi_1_dihedral_cos']].head())



# 6. Save the modified DataFrame back to CSV
target_data = '/home/calvin/code/chemprop_phd_customised/habnet/data/processed/target_data/target_data_sin_cos.csv'
target_df.to_csv(target_data, index=False)

   psi_1_dihedral  psi_1_dihedral_sin  psi_1_dihedral_cos
0      313.550300           -0.724770            0.688991
2      292.580884           -0.923338            0.383987
3      233.012473           -0.798766           -0.601641
4       34.468778            0.565957            0.824435
5      202.755671           -0.386802           -0.922163


In [4]:
from data import EquiMultiMolDataset
import os
# 1) Prepare a root folder for PyG to use
root = 'data/equi_dataset'
os.makedirs(root, exist_ok=True)
# 2) Point to your SDF files folder and the CSV of targets
sdf_folder     = '/home/calvin/code/chemprop_phd_customised/habnet/data/processed/sdf_data'      # ← replace with your actual path
target_csv     = '/home/calvin/code/chemprop_phd_customised/habnet/data/processed/target_data/target_data_sin_cos.csv'
input_types    = ['r1h', 'r2h']       # ← the values of the SDF “type” property you want
target_columns = ['psi_1_dihedral_sin', 'psi_1_dihedral_cos']    # ← your CSV columns for sin & cos

# 3) Instantiate the dataset (this will auto-run process() if needed)
equi_dataset = EquiMultiMolDataset(
    root=root,
    sdf_folder=sdf_folder,
    target_csv=target_csv,
    input_type=input_types,
    target_columns=target_columns,
    keep_hs=True,
    sanitize=False,
    force_reload=True
)
                                   

Processing 1810 SDF files from ['/home/calvin/code/chemprop_phd_customised/habnet/data/processed/sdf_data']
RXN IDs available in CSV: ['rmg_rxn_86', 'rxn_618', 'rxn_1833', 'rmg_rxn_182', 'rxn_1240', 'rmg_rxn_157', 'rxn_563', 'rmg_rxn_521', 'kfir_rxn_13503', 'rxn_1423']...
  • kfir_rxn_10218.sdf → found types ['r1h', 'r2h']
  • kfir_rxn_10391.sdf → found types ['r1h', 'r2h']
  • kfir_rxn_11197.sdf → found types ['r1h', 'r2h']
  • kfir_rxn_11663.sdf → found types ['r1h', 'r2h']
  • kfir_rxn_11797.sdf → found types ['r1h', 'r2h']
  • kfir_rxn_11915.sdf → found types ['r1h', 'r2h']
  • kfir_rxn_12959.sdf → found types ['r1h', 'r2h']
  • kfir_rxn_13108.sdf → found types ['r1h', 'r2h']
  • kfir_rxn_1315.sdf → found types ['r1h', 'r2h']
  • kfir_rxn_13503.sdf → found types ['r1h', 'r2h']
  • kfir_rxn_13651.sdf → found types ['r1h', 'r2h']
  • kfir_rxn_13712.sdf → found types ['r1h', 'r2h']
  • kfir_rxn_13897.sdf → found types ['r1h', 'r2h']
  • kfir_rxn_2.sdf → found types ['r1h', 'r2h']
  • 

Processing...


  • rmg_rxn_15515.sdf → found types ['r1h', 'r2h']
  • rmg_rxn_15516.sdf → found types ['r1h', 'r2h']
  • rmg_rxn_15517.sdf → found types ['r1h', 'r2h']
  • rmg_rxn_15521.sdf → found types ['r1h', 'r2h']
  • rmg_rxn_15524.sdf → found types ['r1h', 'r2h']
  • rmg_rxn_15526.sdf → found types ['r1h', 'r2h']
  • rmg_rxn_15531.sdf → found types ['r1h', 'r2h']
  • rmg_rxn_15538.sdf → found types ['r1h', 'r2h']
  • rmg_rxn_15539.sdf → found types ['r1h', 'r2h']
  • rmg_rxn_15540.sdf → found types ['r1h', 'r2h']
  • rmg_rxn_15541.sdf → found types ['r1h', 'r2h']
  • rmg_rxn_15545.sdf → found types ['r1h', 'r2h']
  • rmg_rxn_15565.sdf → found types ['r1h', 'r2h']
  • rmg_rxn_15566.sdf → found types ['r1h', 'r2h']
  • rmg_rxn_15567.sdf → found types ['r1h', 'r2h']
  • rmg_rxn_15575.sdf → found types ['r1h', 'r2h']
  • rmg_rxn_15576.sdf → found types ['r1h', 'r2h']
  • rmg_rxn_15577.sdf → found types ['r1h', 'r2h']
  • rmg_rxn_15579.sdf → found types ['r1h', 'r2h']
  • rmg_rxn_15580.sdf → found t

Done!


In [5]:
equi_dataset

EquiMultiMolDataset(1696)

In [6]:
from dimenet import DimeNetPPEncoder

encoder = DimeNetPPEncoder(hidden_channels=512,
                           dropout=0.2, num_blocks=6,
                           num_spherical=8, num_radial=7,
                           cutoff=6.0)
encoder = encoder.to(device)

In [7]:
# import dataloader from pytorch - not torch_geometric
from torch.utils.data import DataLoader
from data import siamese_collate
train_loader = DataLoader(equi_dataset, batch_size=1, shuffle=True, collate_fn=siamese_collate)

In [8]:
from torch_geometric.loader import DataLoader

# We need to split randomly into train/valid/test sets
from torch.utils.data import random_split

train_size = int(0.8 * len(equi_dataset))
valid_size = int(0.1 * len(equi_dataset))
test_size = len(equi_dataset) - train_size - valid_size
train_dataset, valid_dataset, test_dataset = random_split(equi_dataset, [train_size, valid_size, test_size])
# Create DataLoaders for each dataset
train_loader = DataLoader(
    train_dataset,
    batch_size=16,
    shuffle=True,
    follow_batch=['z_s', 'z_t']
)
valid_loader = DataLoader(
    valid_dataset,
    batch_size=32,
    shuffle=True,
    follow_batch=['z_s', 'z_t']
)
test_loader = DataLoader(
    test_dataset,
    batch_size=32,
    shuffle=True,
    follow_batch=['z_s', 'z_t']
)


In [None]:

def train_epoch(model, loader, optimizer, loss_fn, metric_fn):
    model.train()
    total_loss = 0.0
    total_err  = 0.0
    n_samples  = 0

    for batch in tqdm(loader, desc="Training", leave=False):
        batch = batch.to(device)
        optimizer.zero_grad()

        # forward + loss
        out   = model(batch)              # [B,2]
        loss  = loss_fn(out, batch.y)     # scalar
        loss.backward()
        optimizer.step()

        # metric: mean abs angle error (degrees)
        err = metric_fn(out, batch.y)     # scalar tensor

        bsize = batch.y.size(0)
        total_loss += loss.item() * bsize
        total_err  += err.item()  * bsize
        n_samples  += bsize

    return total_loss / n_samples, total_err / n_samples

def eval_epoch(model, loader, loss_fn, metric_fn):
    model.eval()
    total_loss = 0.0
    total_err  = 0.0
    n_samples  = 0

    with torch.no_grad():
        for batch in tqdm(loader, desc="Validation", leave=False):
            batch = batch.to(device)
            out   = model(batch)
            loss  = loss_fn(out, batch.y)
            err   = metric_fn(out, batch.y)

            bsize = batch.y.size(0)
            total_loss += loss.item() * bsize
            total_err  += err.item()  * bsize
            n_samples  += bsize

    return total_loss / n_samples, total_err / n_samples


In [None]:
import argparse
import json
import os
from datetime import datetime
import yaml

def save_experiment(config: dict,
                    metrics: dict,
                    out_dir: str = "experiments"):
    """
    Dump one JSON per run, containing both hyper-parameters and
    your final validation metrics.
    """
    os.makedirs(out_dir, exist_ok=True)
    # timestamped filename
    ts = datetime.now().strftime("%Y%m%d-%H%M%S")
    fname = os.path.join(out_dir, f"run_{ts}.json")
    payload = {
        "config":  config,
        "metrics": metrics,
    }
    with open(fname, "w") as f:
        json.dump(payload, f, indent=2)
    print(f"👉 Saved experiment to {fname}")

# --- when you start training, assemble your hyper‐params: ---
config = {
    "encoder_hidden":    256,
    "num_blocks":        4,
    "num_spherical":     7,
    "num_radial":        6,
    "cutoff":            5.0,
    "fusion":            "diff-prod",   # or "cat"
    "dropout":           0.1,
    "batch_size":        32,
    "lr":                1e-4,
    "weight_decay":      1e-5,
    "scheduler":         "CosineRestartsDecay(T0=20,decay=0.6)",
    "loss":              "cosine_angle_loss",
    "epochs":            num_epochs
}

# ... inside your training loop you track best_val_error (degrees) etc ...
metrics = {
    "best_val_loss":   best_val_loss,
    "best_val_error":  best_val_err,   # e.g. AngularErrorMetric in ° 
    "final_val_loss":  va_loss,
    "final_val_error": va_err,
}

# at the end of training:
save_experiment(config, metrics)

# ─── parse config ────────────────────────────────────────────────────
parser = argparse.ArgumentParser()
parser.add_argument("--config", type=str, required=True,
                    help="Path to YAML config file")
args = parser.parse_args()

with open(args.config, "r") as f:
    cfg = yaml.safe_load(f)

# ─── make experiment directory ────────────────────────────────────────
ts = datetime.now().strftime("%Y%m%d-%H%M%S")
exp_name = f"dimenetpp_{ts}"
out_dir = os.path.join(cfg["logging"]["output_root"], exp_name)
os.makedirs(out_dir, exist_ok=True)


with open(os.path.join(out_dir, "used_config.yaml"), "w") as f:
    yaml.safe_dump(cfg, f)

In [None]:
# Set Seed
seed = cfg["training"]["seed"]
seed_everything(seed)

# Build the model
enc_cfg = cfg["model"]
encoder = DimeNetPPEncoder(
    hidden_channels=enc_cfg["hidden_channels"],
    num_blocks=enc_cfg["num_blocks"],
    num_spherical=enc_cfg["num_spherical"],
    num_radial=enc_cfg["num_radial"],
    cutoff=enc_cfg["cutoff"],
    dropout=enc_cfg["dropout"],
).to(device)

model = SiameseDimeNet(
    encoder=encoder,
    fusion=cfg["model"]["fusion"],
    dropout=cfg["model"]["dropout"],
).to(device)

# Build the optimizer & Scheduler
tr_cfg = cfg["training"]
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=tr_cfg["lr"],
    weight_decay=tr_cfg["weight_decay"],
)
sch_cfg = tr_cfg["scheduler"]
if sch_cfg["name"] == "CosineRestartsDecay":
    scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
        optimizer,
        T_0=sch_cfg["T_0"],
        T_mult=sch_cfg["T_mult"],
        eta_min=sch_cfg["eta_min"],
    )
else:
    raise ValueError(f"Unknown scheduler: {sch_cfg['name']}")
# ─── build the loss function ──────────────────────────────────────────
loss_fn = cosine_angle_loss
metric_fn = AngularErrorMetric(
    in_degrees=True,)
# Build data loaders

## Randomly split the dataset into train/valid/test sets
train_size = int(0.8 * len(equi_dataset))
valid_size = int(0.1 * len(equi_dataset))
test_size = len(equi_dataset) - train_size - valid_size
train_dataset, valid_dataset, test_dataset = random_split(equi_dataset, [train_size, valid_size, test_size])
# Create DataLoaders for each dataset
train_loader = DataLoader(
    train_dataset,
    batch_size=cfg["training"]["batch_size"],
    shuffle=True,
    follow_batch=['z_s', 'z_t']
)
valid_loader = DataLoader(
    valid_dataset,
    batch_size=cfg["training"]["batch_size"],
    shuffle=True,
    follow_batch=['z_s', 'z_t']
)
test_loader = DataLoader(
    test_dataset,
    batch_size=cfg["training"]["batch_size"],
    shuffle=True,
    follow_batch=['z_s', 'z_t']
)





EquiDataBatch(y=[16, 2], z_s=[128], z_s_batch=[128], z_s_ptr=[17], pos_s=[128, 3], z_t=[165], z_t_batch=[165], z_t_ptr=[17], pos_t=[165, 3], id=[16])


In [10]:

model = SiameseDimeNet(
    encoder=encoder)
model = model.to(device)
num_epochs  = 200
learning_rate = 1e-4
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=1e-5)
# scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
#     optimizer, mode='min', factor=0.5, patience=5, verbose=True
# )
#scheduler = CosineAnnealingLR(optimizer, T_max=num_epochs, eta_min=1e-6)
scheduler = CosineRestartsDecay(
    optimizer,
    T_0     = 20,
    T_mult  = 2,
    eta_min = 1e-4,
    decay   = 0.3) 


import torch.nn.functional as F

best_val_loss = float('inf')

loss_fn = cosine_angle_loss
metric_fn   = AngularErrorMetric(in_degrees=True)

# 3) Training loop

for epoch in range(1, num_epochs + 1):
    tr_loss, tr_err = train_epoch(model, train_loader, optimizer, loss_fn, metric_fn)
    va_loss, va_err = eval_epoch(model, valid_loader, loss_fn, metric_fn)

    # Step the scheduler once per epoch
    # If your scheduler wants the epoch number, do: scheduler.step(epoch)
    # Otherwise just:
    scheduler.step()

    lr = scheduler.get_last_lr()[0]
    print(f"Epoch {epoch:02d} | "
          f"Train L={tr_loss:.4f}, Err={tr_err:.1f}° | "
          f"Val   L={va_loss:.4f}, Err={va_err:.1f}° | "
          f"LR={lr:.2e}")

    # (Optional) save best‐model checkpoint
    if va_loss < best_val_loss:
        best_val_loss = va_loss
        torch.save(model.state_dict(), 'best_dimenet_model.pt')
        print(" ↳ New best model saved!")

  return F.linear(input, self.weight, self.bias)
                                                         

Epoch 01 | Train L=0.7585, Err=71.9° | Val   L=0.7352, Err=70.6° | LR=1.00e-04
 ↳ New best model saved!


                                                         

Epoch 02 | Train L=0.7192, Err=69.3° | Val   L=0.7241, Err=69.7° | LR=1.00e-04
 ↳ New best model saved!


                                                         

Epoch 03 | Train L=0.7097, Err=68.6° | Val   L=0.7324, Err=69.9° | LR=1.00e-04


                                                         

Epoch 04 | Train L=0.7068, Err=68.4° | Val   L=0.7066, Err=67.8° | LR=1.00e-04
 ↳ New best model saved!


                                                         

Epoch 05 | Train L=0.7009, Err=67.9° | Val   L=0.7268, Err=69.5° | LR=1.00e-04


                                                         

Epoch 06 | Train L=0.6901, Err=67.0° | Val   L=0.6876, Err=66.5° | LR=1.00e-04
 ↳ New best model saved!


                                                         

Epoch 07 | Train L=0.6752, Err=65.8° | Val   L=0.7240, Err=69.1° | LR=1.00e-04


                                                         

Epoch 08 | Train L=0.6726, Err=65.9° | Val   L=0.7402, Err=70.3° | LR=1.00e-04


                                                         

Epoch 09 | Train L=0.6641, Err=65.3° | Val   L=0.6869, Err=65.8° | LR=1.00e-04
 ↳ New best model saved!


                                                         

Epoch 10 | Train L=0.6536, Err=64.4° | Val   L=0.6819, Err=65.8° | LR=1.00e-04
 ↳ New best model saved!


                                                         

Epoch 11 | Train L=0.6523, Err=64.2° | Val   L=0.7021, Err=66.9° | LR=1.00e-04


                                                         

Epoch 12 | Train L=0.6404, Err=63.3° | Val   L=0.6821, Err=65.5° | LR=1.00e-04


                                                         

Epoch 13 | Train L=0.6594, Err=64.6° | Val   L=0.6966, Err=67.2° | LR=1.00e-04


                                                         

Epoch 14 | Train L=0.6452, Err=64.1° | Val   L=0.6809, Err=65.7° | LR=1.00e-04
 ↳ New best model saved!


                                                         

Epoch 15 | Train L=0.6280, Err=62.4° | Val   L=0.6729, Err=65.1° | LR=1.00e-04
 ↳ New best model saved!


                                                         

Epoch 16 | Train L=0.6233, Err=62.1° | Val   L=0.6536, Err=63.9° | LR=1.00e-04
 ↳ New best model saved!


                                                         

Epoch 17 | Train L=0.6078, Err=61.3° | Val   L=0.6046, Err=60.7° | LR=1.00e-04
 ↳ New best model saved!


                                                         

Epoch 18 | Train L=0.5964, Err=60.2° | Val   L=0.6320, Err=62.7° | LR=1.00e-04


                                                         

Epoch 19 | Train L=0.5988, Err=60.4° | Val   L=0.5889, Err=59.1° | LR=1.00e-04
 ↳ New best model saved!


                                                         

Epoch 20 | Train L=0.5899, Err=59.9° | Val   L=0.6219, Err=61.8° | LR=1.00e-04


                                                         

Epoch 21 | Train L=0.5852, Err=59.5° | Val   L=0.6164, Err=60.5° | LR=3.01e-05


                                                         

Epoch 22 | Train L=0.5610, Err=57.9° | Val   L=0.5401, Err=55.6° | LR=3.04e-05
 ↳ New best model saved!


                                                         

Epoch 23 | Train L=0.5504, Err=57.1° | Val   L=0.5259, Err=54.4° | LR=3.10e-05
 ↳ New best model saved!


                                                         

Epoch 24 | Train L=0.5482, Err=56.9° | Val   L=0.5255, Err=54.5° | LR=3.17e-05
 ↳ New best model saved!


                                                         

Epoch 25 | Train L=0.5328, Err=55.5° | Val   L=0.5354, Err=55.5° | LR=3.27e-05


                                                         

Epoch 26 | Train L=0.5366, Err=55.9° | Val   L=0.5288, Err=54.7° | LR=3.38e-05


                                                         

Epoch 27 | Train L=0.5277, Err=55.1° | Val   L=0.5464, Err=55.7° | LR=3.52e-05


                                                         

Epoch 28 | Train L=0.5400, Err=56.4° | Val   L=0.5291, Err=54.5° | LR=3.67e-05


                                                         

Epoch 29 | Train L=0.5175, Err=54.6° | Val   L=0.5373, Err=55.3° | LR=3.84e-05


                                                         

Epoch 30 | Train L=0.5169, Err=54.6° | Val   L=0.5393, Err=55.8° | LR=4.03e-05


                                                         

Epoch 31 | Train L=0.5181, Err=54.4° | Val   L=0.5384, Err=55.7° | LR=4.23e-05


                                                         

Epoch 32 | Train L=0.5129, Err=54.2° | Val   L=0.5445, Err=56.2° | LR=4.44e-05


                                                         

Epoch 33 | Train L=0.5070, Err=53.7° | Val   L=0.5341, Err=55.1° | LR=4.67e-05


                                                         

Epoch 34 | Train L=0.5042, Err=53.2° | Val   L=0.5134, Err=53.8° | LR=4.91e-05
 ↳ New best model saved!


                                                         

Epoch 35 | Train L=0.5020, Err=53.4° | Val   L=0.5329, Err=55.2° | LR=5.16e-05


                                                         

Epoch 36 | Train L=0.5051, Err=53.5° | Val   L=0.5468, Err=56.0° | LR=5.42e-05


                                                         

Epoch 37 | Train L=0.4901, Err=52.2° | Val   L=0.5397, Err=56.0° | LR=5.68e-05


                                                         

Epoch 38 | Train L=0.4924, Err=52.5° | Val   L=0.5257, Err=55.1° | LR=5.95e-05


                                                         

Epoch 39 | Train L=0.4978, Err=52.8° | Val   L=0.5264, Err=54.6° | LR=6.23e-05


                                                         

Epoch 40 | Train L=0.4875, Err=52.2° | Val   L=0.4981, Err=52.4° | LR=6.50e-05
 ↳ New best model saved!


                                                         

Epoch 41 | Train L=0.4969, Err=52.8° | Val   L=0.5268, Err=55.5° | LR=6.77e-05


                                                         

Epoch 42 | Train L=0.4802, Err=51.5° | Val   L=0.4958, Err=51.7° | LR=7.05e-05
 ↳ New best model saved!


                                                         

Epoch 43 | Train L=0.4848, Err=51.9° | Val   L=0.5233, Err=54.9° | LR=7.32e-05


                                                         

Epoch 44 | Train L=0.4944, Err=52.6° | Val   L=0.5155, Err=54.8° | LR=7.58e-05


                                                         

Epoch 45 | Train L=0.4699, Err=50.8° | Val   L=0.5202, Err=54.4° | LR=7.84e-05


                                                         

Epoch 46 | Train L=0.4781, Err=51.4° | Val   L=0.5242, Err=55.6° | LR=8.09e-05


                                                         

Epoch 47 | Train L=0.4847, Err=51.7° | Val   L=0.4915, Err=52.5° | LR=8.33e-05
 ↳ New best model saved!


                                                         

Epoch 48 | Train L=0.5387, Err=56.0° | Val   L=0.5322, Err=54.9° | LR=8.56e-05


                                                         

Epoch 49 | Train L=0.5263, Err=55.2° | Val   L=0.5298, Err=55.1° | LR=8.77e-05


                                                         

Epoch 50 | Train L=0.4941, Err=52.4° | Val   L=0.5153, Err=54.5° | LR=8.97e-05


                                                         

Epoch 51 | Train L=0.4868, Err=52.0° | Val   L=0.5115, Err=54.1° | LR=9.16e-05


                                                         

Epoch 52 | Train L=0.4733, Err=51.2° | Val   L=0.5131, Err=54.4° | LR=9.33e-05


                                                         

Epoch 53 | Train L=0.4795, Err=51.5° | Val   L=0.5395, Err=56.4° | LR=9.48e-05


                                                         

Epoch 54 | Train L=0.4849, Err=51.9° | Val   L=0.4952, Err=53.4° | LR=9.62e-05


                                                         

Epoch 55 | Train L=0.4699, Err=50.8° | Val   L=0.5032, Err=53.2° | LR=9.73e-05


                                                         

Epoch 56 | Train L=0.4866, Err=52.0° | Val   L=0.5038, Err=53.7° | LR=9.83e-05


                                                         

Epoch 57 | Train L=0.4652, Err=50.2° | Val   L=0.4976, Err=52.6° | LR=9.90e-05


                                                         

Epoch 58 | Train L=0.4816, Err=51.7° | Val   L=0.5315, Err=55.2° | LR=9.96e-05


                                                         

Epoch 59 | Train L=0.4651, Err=50.0° | Val   L=0.4950, Err=52.4° | LR=9.99e-05


                                                         

Epoch 60 | Train L=0.4357, Err=47.7° | Val   L=0.4708, Err=51.4° | LR=3.00e-05
 ↳ New best model saved!


                                                         

Epoch 61 | Train L=0.4250, Err=47.1° | Val   L=0.4884, Err=52.5° | LR=9.04e-06


                                                         

Epoch 62 | Train L=0.4131, Err=46.2° | Val   L=0.4806, Err=51.9° | LR=9.14e-06


                                                         

Epoch 63 | Train L=0.4070, Err=45.7° | Val   L=0.4779, Err=51.5° | LR=9.32e-06


                                                         

Epoch 64 | Train L=0.3983, Err=45.0° | Val   L=0.4723, Err=51.2° | LR=9.56e-06


                                                         

Epoch 65 | Train L=0.4019, Err=45.2° | Val   L=0.4682, Err=50.9° | LR=9.87e-06
 ↳ New best model saved!


                                                         

Epoch 66 | Train L=0.3946, Err=44.5° | Val   L=0.4747, Err=51.4° | LR=1.03e-05


                                                         

Epoch 67 | Train L=0.4007, Err=45.2° | Val   L=0.4721, Err=51.4° | LR=1.07e-05


                                                         

Epoch 68 | Train L=0.3966, Err=44.9° | Val   L=0.4788, Err=51.8° | LR=1.12e-05


                                                         

Epoch 69 | Train L=0.3899, Err=44.3° | Val   L=0.4977, Err=53.2° | LR=1.18e-05


                                                         

Epoch 70 | Train L=0.3908, Err=44.7° | Val   L=0.4989, Err=53.5° | LR=1.25e-05


                                                         

Epoch 71 | Train L=0.3840, Err=43.9° | Val   L=0.4892, Err=52.3° | LR=1.32e-05


                                                         

Epoch 72 | Train L=0.3851, Err=44.0° | Val   L=0.4873, Err=52.4° | LR=1.40e-05


                                                         

Epoch 73 | Train L=0.3846, Err=43.7° | Val   L=0.4950, Err=52.8° | LR=1.48e-05


                                                         

Epoch 74 | Train L=0.3852, Err=43.9° | Val   L=0.5062, Err=53.7° | LR=1.57e-05


                                                         

Epoch 75 | Train L=0.3863, Err=44.1° | Val   L=0.5020, Err=53.4° | LR=1.67e-05


                                                         

Epoch 76 | Train L=0.3830, Err=43.7° | Val   L=0.4922, Err=52.7° | LR=1.77e-05


                                                         

Epoch 77 | Train L=0.3773, Err=43.3° | Val   L=0.4877, Err=52.4° | LR=1.88e-05


                                                         

Epoch 78 | Train L=0.3774, Err=43.3° | Val   L=0.5002, Err=53.3° | LR=1.99e-05


                                                         

Epoch 79 | Train L=0.3707, Err=42.4° | Val   L=0.4919, Err=51.9° | LR=2.11e-05


                                                         

Epoch 80 | Train L=0.3717, Err=42.5° | Val   L=0.4956, Err=53.0° | LR=2.23e-05


                                                         

Epoch 81 | Train L=0.3739, Err=43.0° | Val   L=0.4870, Err=51.8° | LR=2.36e-05


                                                         

Epoch 82 | Train L=0.3650, Err=42.4° | Val   L=0.4900, Err=52.0° | LR=2.50e-05


                                                         

Epoch 83 | Train L=0.3626, Err=42.1° | Val   L=0.5101, Err=53.5° | LR=2.63e-05


                                                         

Epoch 84 | Train L=0.3937, Err=44.4° | Val   L=0.5032, Err=53.6° | LR=2.78e-05


                                                         

Epoch 85 | Train L=0.3702, Err=42.7° | Val   L=0.4891, Err=52.5° | LR=2.92e-05


                                                         

KeyboardInterrupt: 

In [None]:
batch = next(iter(train_loader))
print(batch)

EquiDataBatch(y=[32, 2], z_s=[334], z_s_batch=[334], z_s_ptr=[33], pos_s=[334, 3], z_t=[341], z_t_batch=[341], z_t_ptr=[33], pos_t=[341, 3], id=[32])


In [16]:
torch.cuda.empty_cache()
print("CUDA memory cleared")
# Print CUDA memory summary


CUDA memory cleared


In [15]:
print("CUDA memory summary:")
print(torch.cuda.memory_summary(device=None, abbreviated=False))

CUDA memory summary:
|                  PyTorch CUDA memory summary, device ID 0                 |
|---------------------------------------------------------------------------|
|            CUDA OOMs: 0            |        cudaMalloc retries: 10        |
|        Metric         | Cur Usage  | Peak Usage | Tot Alloc  | Tot Freed  |
|---------------------------------------------------------------------------|
| Allocated memory      | 419315 KiB |   6020 MiB |  63448 GiB |  63448 GiB |
|       from large pool |  79195 KiB |   5742 MiB |  60360 GiB |  60360 GiB |
|       from small pool | 340120 KiB |    415 MiB |   3088 GiB |   3087 GiB |
|---------------------------------------------------------------------------|
| Active memory         | 419315 KiB |   6020 MiB |  63448 GiB |  63448 GiB |
|       from large pool |  79195 KiB |   5742 MiB |  60360 GiB |  60360 GiB |
|       from small pool | 340120 KiB |    415 MiB |   3088 GiB |   3087 GiB |
|------------------------------------------

In [None]:
batch.y

tensor([[ 9.9498e-01,  1.0010e-01],
        [-8.8480e-01,  4.6598e-01],
        [-7.0213e-01, -7.1205e-01],
        [-9.9791e-01,  6.4688e-02],
        [ 7.9861e-02, -9.9681e-01],
        [ 3.9891e-01, -9.1699e-01],
        [-4.3020e-02,  9.9907e-01],
        [-1.6591e-01,  9.8614e-01],
        [ 1.8761e-01, -9.8224e-01],
        [-9.6775e-01, -2.5191e-01],
        [-7.0285e-01,  7.1134e-01],
        [ 2.6074e-01,  9.6541e-01],
        [-5.3904e-04,  1.0000e+00],
        [-1.3275e-01, -9.9115e-01],
        [-4.9938e-04,  1.0000e+00],
        [-1.6487e-01, -9.8632e-01],
        [ 8.2383e-01,  5.6684e-01],
        [ 9.9917e-01,  4.0807e-02],
        [ 1.1098e-01,  9.9382e-01],
        [-9.3298e-01, -3.5993e-01],
        [ 2.6962e-01,  9.6297e-01],
        [-4.5483e-03, -9.9999e-01],
        [-3.3019e-01,  9.4391e-01],
        [ 3.2036e-01, -9.4730e-01],
        [ 9.9998e-01,  6.6410e-03],
        [ 1.0000e+00, -1.0429e-03],
        [ 3.4122e-01, -9.3999e-01],
        [-9.9093e-01, -1.344