In [1]:
from pathlib import Path

import pandas as pd
import torch
from sklearn.model_selection import train_test_split
from torch_geometric.loader import DataLoader

from scripts.data_formatting import SmilesDataset
from scripts.downstream import get_prediction_smiles, split_batch_by_molecule
from scripts.nn_models import GCN

In [2]:
# Read the raw CSV and pull the 'SMILES Labelled' column out as a plain Python list.
df = pd.read_csv('../datasets/13k_All_Manual.csv')
smiles_list = df.loc[:, 'SMILES Labelled'].to_list()

In [3]:
# 70/15/15 train/test/val split, done in two stages with the same seed:
#   1. hold out 15% of all molecules for testing;
#   2. hold out 17.65% of the remaining 85% (~15% of the total) for validation.
train_val_smiles, test_smiles = train_test_split(
    smiles_list, test_size=0.15, random_state=42
)
train_smiles, val_smiles = train_test_split(
    train_val_smiles, test_size=0.1765, random_state=42
)

# One SmilesDataset per split (built in train, test, val order).
train_dataset, test_dataset, val_dataset = (
    SmilesDataset(split) for split in (train_smiles, test_smiles, val_smiles)
)

# Loaders share the batch size; only the training loader reshuffles its data.
loader_kwargs = dict(batch_size=32)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
test_loader, val_loader = (
    DataLoader(dataset, shuffle=False, **loader_kwargs)
    for dataset in (test_dataset, val_dataset)
)

In [4]:
# Seed torch's global RNG before anything below draws from it, so the shuffled
# batch order of train_loader and the model's initial weights are repeatable
# from one "Restart & Run All" to the next (same seed as the data split).
torch.manual_seed(42)

# Number of target classes; shared by the class-count vector and the model head.
NUM_CLASSES = 5

# class weights based on train split
# NOTE(review): assumes batch.y holds one row of per-class scores per labelled
# item, so argmax(dim=1) yields its integer class id -- confirm in SmilesDataset.
labels = torch.cat([batch.y.argmax(dim=1) for batch in train_loader])
# minlength keeps one slot per class even if the highest class ids never occur
# in the training split; without it the weight vector could be shorter than the
# model's output_dim and unusable as a CrossEntropyLoss weight.
class_counts = torch.bincount(labels, minlength=NUM_CLASSES)
# "Balanced" weighting: n_samples / (n_classes * class_count).
# clamp(min=1) keeps the weight finite for a class with zero training samples;
# the result is unchanged whenever every class is present.
class_weights = labels.size(0) / (len(class_counts) * class_counts.clamp(min=1).float())

# model instance
model = GCN(input_dim=9, hidden_dim=128, output_dim=NUM_CLASSES, num_layers=3)
optimizer = torch.optim.Adam(model.parameters(), lr=5e-4, weight_decay=1e-4)
# The unweighted loss is the active choice; swap in the commented line below to
# apply the class weights computed above.
# criterion = torch.nn.CrossEntropyLoss(weight=class_weights)
criterion = torch.nn.CrossEntropyLoss()


In [5]:
# Early-stopping state.
counter = 0                    # consecutive epochs without a new best validation loss
patience = 15                  # stop once `counter` reaches this value
best_val_loss = float('inf')
best_model_path = None         # checkpoint written at the lowest validation loss so far
num_epochs = 1000              # upper bound on epochs; early stopping may end training sooner

# torch.save does not create missing parent folders, so make sure the
# checkpoint directory exists before the first save (e.g. on a fresh clone).
Path('../models/gcn').mkdir(parents=True, exist_ok=True)

for epoch in range(num_epochs):
    # training
    model.train()
    total_loss = 0
    for batch in train_loader:
        optimizer.zero_grad()
        out = model(batch.x, batch.edge_index)
        loss = criterion(out, batch.y.argmax(dim=1))
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    # Mean of the per-batch losses (each batch counts equally).
    avg_train_loss = total_loss / len(train_loader)

    # validation
    model.eval()
    val_loss, correct, total = 0, 0, 0
    with torch.no_grad():
        for batch in val_loader:
            out = model(batch.x, batch.edge_index)
            loss = criterion(out, batch.y.argmax(dim=1))
            val_loss += loss.item()

            # An entry counts as correct only when its prediction tensor equals
            # the target tensor exactly (torch.equal).
            # NOTE(review): presumably one (pred, true) pair per molecule --
            # confirm against split_batch_by_molecule.
            preds, true = split_batch_by_molecule(out, batch)
            for p, t in zip(preds, true):
                if torch.equal(p, t):
                    correct += 1
                total += 1
    avg_val_loss = val_loss / len(val_loader)

    # validation accuracy
    val_acc = correct / total
    print(f"Epoch {epoch+1} | Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f} | Val Acc: {val_acc*100:.2f}%")

    # early stopping: checkpoint on every new best validation loss, otherwise
    # count the epoch towards `patience`
    if avg_val_loss < best_val_loss:
        counter = 0
        best_val_loss = avg_val_loss
        best_model_path = f"../models/gcn/gcn_{epoch+1}.pt"
        torch.save(model.state_dict(), best_model_path)
    else:
        counter += 1
        if counter >= patience:
            print(f"Training stopped at epoch {epoch+1}.")
            break

# Report which checkpoint holds the best weights so it does not have to be
# worked out by hand from the epoch log.
if best_model_path is not None:
    print(f"Best checkpoint: {best_model_path} (val loss {best_val_loss:.4f})")

Epoch 1 | Train Loss: 0.6085 | Val Loss: 0.5308 | Val Acc: 15.55%
Epoch 2 | Train Loss: 0.5183 | Val Loss: 0.4976 | Val Acc: 19.51%
Epoch 3 | Train Loss: 0.5001 | Val Loss: 0.4905 | Val Acc: 19.19%
Epoch 4 | Train Loss: 0.4954 | Val Loss: 0.4820 | Val Acc: 19.40%
Epoch 98 | Train Loss: 0.3165 | Val Loss: 0.3128 | Val Acc: 35.28%
Epoch 99 | Train Loss: 0.3146 | Val Loss: 0.3152 | Val Acc: 35.76%
Epoch 100 | Train Loss: 0.3135 | Val Loss: 0.3058 | Val Acc: 35.76%
Epoch 101 | Train Loss: 0.3106 | Val Loss: 0.3078 | Val Acc: 35.17%
Epoch 102 | Train Loss: 0.3136 | Val Loss: 0.3105 | Val Acc: 37.31%
Epoch 103 | Train Loss: 0.3109 | Val Loss: 0.3094 | Val Acc: 34.63%
Epoch 104 | Train Loss: 0.3106 | Val Loss: 0.3131 | Val Acc: 35.86%
Epoch 105 | Train Loss: 0.3105 | Val Loss: 0.3085 | Val Acc: 35.38%
Epoch 106 | Train Loss: 0.3087 | Val Loss: 0.3255 | Val Acc: 34.85%
Epoch 107 | Train Loss: 0.3095 | Val Loss: 0.2995 | Val Acc: 35.86%
Epoch 108 | Train Loss: 0.3095 | Val Loss: 0.3001 | Val Ac

In [6]:
# NOTE: the epoch number in MODEL_PATH is hard-coded from one particular
# training run; update it to that run's best (last saved) checkpoint after retraining.
MODEL_PATH = '../models/gcn/gcn_217.pt'
OUTPUT_CSV = '../results/gcn_test_eval.csv'

# load best model
# torch.load unpickles the file -- only load checkpoints you produced yourself.
model.load_state_dict(torch.load(MODEL_PATH))

# Per-entry records collected across all test batches; they become the CSV columns.
smiles_original = []
smiles_predicted = []
smiles_matched = []

# test evaluation
model.eval()
test_loss, correct, total = 0, 0, 0
with torch.no_grad():
    for batch in test_loader:
        out = model(batch.x, batch.edge_index)
        loss = criterion(out, batch.y.argmax(dim=1))
        test_loss += loss.item()

        # An entry is a match only when its prediction tensor equals the target
        # tensor exactly (torch.equal).
        # NOTE(review): presumably one (pred, true) pair per molecule, in the
        # same order as batch.smiles -- confirm against split_batch_by_molecule.
        preds, true = split_batch_by_molecule(out, batch)
        for p, t in zip(preds, true):
            is_match = torch.equal(p, t)
            smiles_matched.append(is_match)
            correct += is_match
            total += 1

        smiles_original.extend(batch.smiles)
        smiles_predicted.extend(get_prediction_smiles(preds, batch.smiles))

# Mean of the per-batch losses (each batch counts equally).
avg_test_loss = test_loss / len(test_loader)
test_acc = correct / total

print(f"Test Loss: {avg_test_loss:.4f}")
print(f"Test Accuracy: {test_acc*100:.2f}%")

# saves test predictions to .csv file
test_df = pd.DataFrame({
    'SMILES Original': smiles_original,
    'SMILES Predicted': smiles_predicted,
    'SMILES Matched': smiles_matched
})
# to_csv raises OSError when the target folder is missing, so create it first
# rather than losing the finished evaluation.
Path(OUTPUT_CSV).parent.mkdir(parents=True, exist_ok=True)
test_df.to_csv(OUTPUT_CSV, index=False)
print(f"Saved predictions to {OUTPUT_CSV}.")

Test Loss: 0.2733
Test Accuracy: 43.45%
Saved predictions to ../results/gcn_test_eval.csv.
