In [1]:
import os

import pandas as pd
import torch
from sklearn.model_selection import train_test_split
from torch_geometric.loader import DataLoader

from scripts.data_formatting import SmilesDataset
from scripts.downstream import get_prediction_smiles, split_batch_by_molecule
from scripts.nn_models import EGAT

In [2]:
# Read the raw CSV and pull its 'SMILES Labelled' column out as a plain Python list.
df = pd.read_csv('../datasets/13k_All_Manual.csv')
smiles_list = df.loc[:, 'SMILES Labelled'].to_list()

In [3]:
# Two-stage split giving roughly 70/15/15 train/val/test:
# hold out 15% for testing first, then take 17.65% of the remaining 85%
# (about 15% of the full set) for validation. Both stages use the same seed.
train_smiles, test_smiles = train_test_split(
    smiles_list, test_size=0.15, random_state=42
)
train_smiles, val_smiles = train_test_split(
    train_smiles, test_size=0.1765, random_state=42
)

# Wrap each SMILES subset in a dataset object (built in train, test, val order).
train_dataset, test_dataset, val_dataset = (
    SmilesDataset(subset) for subset in (train_smiles, test_smiles, val_smiles)
)

# Only the training loader shuffles; the evaluation loaders keep a fixed order.
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader, val_loader = (
    DataLoader(dataset, batch_size=32, shuffle=False)
    for dataset in (test_dataset, val_dataset)
)

In [4]:
# Seed PyTorch's global RNG before anything draws from it, so that weight
# initialisation (and other consumers of the default generator, such as
# DataLoader shuffling) are repeatable across runs. The value mirrors the
# random_state=42 already used for the data split.
torch.manual_seed(42)

# model instance
# The hyperparameters are passed straight to EGAT; their exact meaning is
# defined in scripts.nn_models.
model = EGAT(input_dim=9, hidden_dim=128, output_dim=5, edge_dim=8, heads=4, num_layers=3)

# Adam with weight decay, applied to every model parameter.
optimizer = torch.optim.Adam(model.parameters(), lr=5e-4, weight_decay=1e-4)

# Classification loss used for training, validation and testing.
criterion = torch.nn.CrossEntropyLoss()

In [5]:
# Early-stopping state: `counter` is the number of consecutive epochs that
# failed to improve on the best validation loss seen so far.
counter = 0
patience = 20
best_val_loss = float('inf')
num_epochs = 1000

# Remember which checkpoint is the best one, so later cells do not have to
# guess the epoch number from the log.
best_epoch = None
best_model_path = None

# Create the checkpoint directory up front; otherwise the first torch.save
# would only fail after a full epoch of training has already been spent.
os.makedirs("../models/egat", exist_ok=True)

for epoch in range(num_epochs):
    # training
    model.train()
    total_loss = 0
    for batch in train_loader:
        optimizer.zero_grad()
        out = model(batch.x, batch.edge_index, batch.edge_attr)
        # batch.y is reduced to class indices via argmax before the loss.
        loss = criterion(out, batch.y.argmax(dim=1))
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    # Mean of the per-batch losses.
    avg_train_loss = total_loss / len(train_loader)

    # validation
    model.eval()
    val_loss, correct, total = 0, 0, 0
    with torch.no_grad():
        for batch in val_loader:
            out = model(batch.x, batch.edge_index, batch.edge_attr)
            loss = criterion(out, batch.y.argmax(dim=1))
            val_loss += loss.item()

            # A molecule counts as correct only when its whole prediction
            # tensor equals the target tensor exactly.
            preds, true = split_batch_by_molecule(out, batch)
            for p, t in zip(preds, true):
                if torch.equal(p, t):
                    correct += 1
                total += 1
    avg_val_loss = val_loss / len(val_loader)

    # validation accuracy (guarded so an empty comparison set reports 0 instead of crashing)
    val_acc = correct / total if total else 0.0
    print(f"Epoch {epoch+1} | Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f} | Val Acc: {val_acc*100:.2f}%")

    # early stopping
    if avg_val_loss < best_val_loss:
        # New best: reset the patience counter and checkpoint the weights.
        counter = 0
        best_val_loss = avg_val_loss
        best_epoch = epoch + 1
        best_model_path = f"../models/egat/egat_{epoch+1}.pt"
        torch.save(model.state_dict(), best_model_path)
    else:
        counter += 1
        if counter >= patience:
            print(f"Training stopped at epoch {epoch+1}.")
            break

# Report the checkpoint to load for evaluation; the in-memory model holds the
# weights of the last epoch run, which is not necessarily the best one.
if best_model_path is not None:
    print(f"Best checkpoint: {best_model_path} (epoch {best_epoch}, val loss {best_val_loss:.4f}).")

Epoch 1 | Train Loss: 0.5418 | Val Loss: 0.4602 | Val Acc: 16.30%
Epoch 2 | Train Loss: 0.4449 | Val Loss: 0.4275 | Val Acc: 20.95%
Epoch 3 | Train Loss: 0.4202 | Val Loss: 0.4042 | Val Acc: 22.93%
Epoch 4 | Train Loss: 0.3980 | Val Loss: 0.3827 | Val Acc: 23.89%
Epoch 5 | Train Loss: 0.3766 | Val Loss: 0.3643 | Val Acc: 26.72%
Epoch 6 | Train Loss: 0.3600 | Val Loss: 0.3522 | Val Acc: 27.95%
Epoch 7 | Train Loss: 0.3491 | Val Loss: 0.3415 | Val Acc: 28.22%
Epoch 8 | Train Loss: 0.3398 | Val Loss: 0.3313 | Val Acc: 28.86%
Epoch 9 | Train Loss: 0.3321 | Val Loss: 0.3293 | Val Acc: 30.63%
Epoch 10 | Train Loss: 0.3261 | Val Loss: 0.3187 | Val Acc: 31.00%
Epoch 11 | Train Loss: 0.3201 | Val Loss: 0.3139 | Val Acc: 31.37%
Epoch 12 | Train Loss: 0.3150 | Val Loss: 0.3075 | Val Acc: 30.57%
Epoch 13 | Train Loss: 0.3102 | Val Loss: 0.3065 | Val Acc: 31.37%
Epoch 14 | Train Loss: 0.3046 | Val Loss: 0.3015 | Val Acc: 32.28%
Epoch 15 | Train Loss: 0.3038 | Val Loss: 0.3009 | Val Acc: 31.59%
Epoc

In [14]:
# NOTE(review): this checkpoint name is tied to one specific training run
# (presumably its best epoch); update it after retraining.
MODEL_PATH = '../models/egat/egat_156.pt'
OUTPUT_CSV = '../results/egat_test_eval.csv'

# load best model
# map_location='cpu' keeps the load working even when the checkpoint was
# written from a device that is not available on this machine.
model.load_state_dict(torch.load(MODEL_PATH, map_location='cpu'))

# Per-molecule records collected for the results CSV.
smiles_original = []
smiles_predicted = []
smiles_matched = []

# test evaluation
model.eval()
test_loss, correct, total = 0, 0, 0
with torch.no_grad():
    for batch in test_loader:
        out = model(batch.x, batch.edge_index, batch.edge_attr)
        loss = criterion(out, batch.y.argmax(dim=1))
        test_loss += loss.item()

        # A molecule counts as matched only when its whole prediction tensor
        # equals the target tensor exactly.
        preds, true = split_batch_by_molecule(out, batch)
        for p, t in zip(preds, true):
            is_match = torch.equal(p, t)
            smiles_matched.append(is_match)
            correct += int(is_match)
            total += 1

        smiles_original.extend(batch.smiles)
        smiles_predicted.extend(get_prediction_smiles(preds, batch.smiles))

# Mean of the per-batch losses, and the fraction of exactly matched molecules
# (guarded so an empty comparison set reports 0 instead of crashing).
avg_test_loss = test_loss / len(test_loader)
test_acc = correct / total if total else 0.0

print(f"Test Loss: {avg_test_loss:.4f}")
print(f"Test Accuracy: {test_acc*100:.2f}%")

# saves test predictions to .csv file
test_df = pd.DataFrame({
    'SMILES Original': smiles_original,
    'SMILES Predicted': smiles_predicted,
    'SMILES Matched': smiles_matched
})
# Make sure the target directory exists so the export cannot fail after the
# evaluation has already been run.
output_dir = os.path.dirname(OUTPUT_CSV)
if output_dir:
    os.makedirs(output_dir, exist_ok=True)
test_df.to_csv(OUTPUT_CSV, index=False)
print(f"Saved predictions to {OUTPUT_CSV}.")

Test Loss: 0.1121
Test Accuracy: 72.55%
Saved predictions to ../results/gine_test_eval.csv.


In [15]:
# 10 correct: 84.48%
# 11 correct: 96.25%
# 20 correct: 91.97%
# 21 correct: 88.92%