In [None]:
import os
import random

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from dgl.data import ZINCDataset
from torch.utils.data import DataLoader

from components.graph_transformer_model import GraphTransformerModel
from utils.calculate_eigenvectors import k_eigenvectors
from utils.collate_regression import collate


In [None]:
def set_seed(seed):
    """Seed every random number generator this notebook relies on.

    The same ``seed`` is handed, in order, to Python's ``random`` module,
    NumPy's legacy global generator, PyTorch's default generator and
    PyTorch's CUDA generators (current device, then all devices).

    Args:
        seed: Integer seed forwarded unchanged to each seeding function.

    Returns:
        None.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

In [None]:
# Instantiate the three ZINC splits in a fixed order: train, valid, test.
train_dataset, val_dataset, test_dataset = (
    ZINCDataset(mode=split_name) for split_name in ('train', 'valid', 'test')
)

In [None]:
# Size of the positional encoding; the same value is later passed to the model
# as `pos_enc_dim`.
pos_dim = 4

# The return value of k_eigenvectors is discarded, so it is used here purely
# for its effect on each dataset.
# NOTE(review): presumably this is what stores the 'PE' node data read by the
# train/eval loops - confirm against utils.calculate_eigenvectors.
for split_dataset in (train_dataset, val_dataset, test_dataset):
    k_eigenvectors(split_dataset, k=pos_dim)

In [None]:
# Build one mini-batch loader per split (batch size 32, shared collate function).
# Only the training loader shuffles; validation and test keep dataset order.
train_loader, val_loader, test_loader = (
    DataLoader(split_dataset, batch_size=32, shuffle=reshuffle, collate_fn=collate)
    for split_dataset, reshuffle in (
        (train_dataset, True),
        (val_dataset, False),
        (test_dataset, False),
    )
)


In [None]:
# Use the GPU when PyTorch reports CUDA as available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
device  # last expression of the cell: shows which device was selected

In [None]:
def train_epoch(model, loss_fn, optimizer, loader):
    """Run one optimisation pass over ``loader`` and return the mean loss.

    Reads the notebook-level ``device`` variable to decide where each batch
    is moved before the forward pass.

    Args:
        model: Module invoked as ``model(g, h, pos_enc)``; put in train mode.
        loss_fn: Callable ``loss_fn(preds, labels)`` returning a scalar tensor.
            Each batch loss is weighted by the batch size when averaging, so
            it is treated as a per-batch mean.
        optimizer: Optimizer whose gradients are zeroed and stepped per batch.
        loader: Iterable of ``(g, labels)`` batches, where ``g`` exposes
            ``to(device)`` and the node data entries ``'feat'`` and ``'PE'``.

    Returns:
        float: Sum of ``batch_loss * batch_size`` divided by the number of
        samples actually processed.

    Raises:
        ValueError: If ``loader`` yields no samples.
        RuntimeError: If predictions and labels hold a different number of
            elements and therefore cannot be aligned.
    """
    model.train()
    total_loss = 0.0
    num_samples = 0
    for g, labels in loader:
        g = g.to(device)
        h = g.ndata['feat'].to(device)
        pos_enc = g.ndata['PE'].to(device)

        labels = labels.to(device).float()

        preds = model(g, h, pos_enc)
        # Align e.g. (B, 1) predictions with (B,) labels. Without this the
        # loss would silently broadcast to a (B, B) matrix and be wrong.
        if preds.shape != labels.shape:
            preds = preds.reshape(labels.shape)
        loss = loss_fn(preds, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_size = labels.size(0)
        total_loss += loss.item() * batch_size
        num_samples += batch_size

    if num_samples == 0:
        raise ValueError("train_epoch: loader produced no samples")
    # Divide by the samples seen rather than len(loader.dataset) so the mean
    # stays correct if the loader skips samples (drop_last, samplers).
    return total_loss / num_samples

In [None]:
@torch.no_grad()
def eval_epoch(model, loader):
    """Evaluate ``model`` on ``loader`` and return the mean absolute error.

    Runs with gradient tracking disabled and reads the notebook-level
    ``device`` variable to decide where each batch is moved.

    Args:
        model: Module invoked as ``model(g, h, pos_enc)``; put in eval mode.
        loader: Iterable of ``(g, labels)`` batches, where ``g`` exposes
            ``to(device)`` and the node data entries ``'feat'`` and ``'PE'``.

    Returns:
        float: Sum of ``|preds - labels|`` over all batches divided by the
        number of samples actually processed.

    Raises:
        ValueError: If ``loader`` yields no samples.
        RuntimeError: If predictions and labels hold a different number of
            elements and therefore cannot be aligned.
    """
    model.eval()
    total_abs_error = 0.0
    num_samples = 0
    for g, labels in loader:
        g = g.to(device)
        h = g.ndata['feat'].to(device)
        pos_enc = g.ndata['PE'].to(device)

        labels = labels.to(device).float()
        preds = model(g, h, pos_enc)
        # Align e.g. (B, 1) predictions with (B,) labels. Without this the
        # subtraction would broadcast to (B, B) and inflate the error sum.
        if preds.shape != labels.shape:
            preds = preds.reshape(labels.shape)
        total_abs_error += torch.abs(preds - labels).sum().item()
        num_samples += labels.size(0)

    if num_samples == 0:
        raise ValueError("eval_epoch: loader produced no samples")
    # Divide by the samples seen rather than len(loader.dataset) so the mean
    # stays correct if the loader skips samples (drop_last, samplers).
    return total_abs_error / num_samples


In [None]:
# Train one model per seed, log per-epoch train loss / validation MAE, and
# write those metrics plus the final test MAE to one CSV per seed.
num_epochs = 50

# Create the output directory up front: DataFrame.to_csv does not create
# missing parent directories, and failing after a full training run is costly.
results_dir = "results"
os.makedirs(results_dir, exist_ok=True)

for seed in [42, 7, 5, 9]:
    print(f"seed: {seed}")

    # Seed BEFORE building the model so that any random parameter
    # initialisation in the constructor is governed by this seed too.
    set_seed(seed)

    model = GraphTransformerModel(
        num_atom_type=train_dataset.num_atom_types,
        pos_enc_dim=pos_dim,
        out_size=1,  # regression
        hidden_size=16,
        num_heads=8,
        num_layers=10
    ).to(device)

    loss_fn = nn.L1Loss()
    optimizer = optim.Adam(model.parameters(), lr=1e-3)

    metrics = {
        "epoch": [],
        "train_loss": [],
        "val_mae": []
    }

    for epoch in range(num_epochs):
        train_loss = train_epoch(model, loss_fn, optimizer, train_loader)
        val_mae = eval_epoch(model, val_loader)

        metrics["epoch"].append(epoch + 1)
        metrics["train_loss"].append(train_loss)
        metrics["val_mae"].append(val_mae)

        print(
            f"Epoch {epoch+1}: "
            f"Train loss={train_loss:.4f}, "
            f"Val MAE={val_mae:.4f}"
        )

    # Evaluate the final-epoch model on the test split.
    test_mae = eval_epoch(model, test_loader)
    print(f"seed: {seed}")
    print(f"Test MAE: {test_mae:.4f}")

    # Save metrics: the test MAE is a single value, so it is stored only on
    # the last epoch's row and padded with NaN on all earlier rows.
    metrics["test_mae"] = [np.nan] * (num_epochs - 1) + [test_mae]
    df = pd.DataFrame(metrics)
    df.to_csv(os.path.join(results_dir, f"training_metrics_{seed}.csv"), index=False)
