In [1]:
import pandas as pd
import torch
from torch_geometric.loader import DataLoader

from scripts.data_formatting import SmilesDataset
from scripts.downstream import get_prediction_smiles, split_batch_by_molecule
from scripts.nn_models import GINE

In [2]:
# raw dataset
df = pd.read_csv('../datasets/289-test-rxns.csv')
smiles_list = df['SMILES Labelled'].tolist()

In [3]:
# dataset objects
test_dataset  = SmilesDataset(smiles_list)

# dataloaders
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=True)

In [4]:
# model instance
model = GINE(input_dim=9, hidden_dim=128, output_dim=5, edge_dim=8, num_layers=3)
optimizer = torch.optim.Adam(model.parameters(), lr=5e-4, weight_decay=1e-4)
criterion = torch.nn.CrossEntropyLoss()

In [12]:
MODEL_PATH = '../models/gine/gine_156.pt'
OUTPUT_CSV = '../results/289_test_eval.csv'

# load best model
model.load_state_dict(torch.load(MODEL_PATH))

smiles_original = []
smiles_predicted = []
smiles_matched = []

# test evaluation
model.eval()
test_loss, correct, total = 0, 0, 0
with torch.no_grad():
    for batch in test_loader:
        out = model(batch.x, batch.edge_index, batch.edge_attr)
        loss = criterion(out, batch.y.argmax(dim=1))
        test_loss += loss.item()

        preds, true = split_batch_by_molecule(out, batch)
        for p, t in zip(preds, true):
            mask = (p != 2) & (t != 2) & (p != 4) & (t != 4)
            if torch.equal(p[mask], t[mask]):
                smiles_matched.append(True)
                correct += 1
            else:
                smiles_matched.append(False)
            total += 1

        smiles_original.extend(batch.smiles)
        smiles_predicted.extend(get_prediction_smiles(preds, batch.smiles))

avg_test_loss = test_loss / len(test_loader)
test_acc = correct / total

print(f"Test Loss: {avg_test_loss:.4f}")
print(f"Test Accuracy: {test_acc*100:.2f}%")

# saves test predictions to .csv file
test_df = pd.DataFrame({
    'SMILES Original': smiles_original,
    'SMILES Predicted': smiles_predicted,
    'SMILES Matched': smiles_matched
})
test_df.to_csv(OUTPUT_CSV, index=False)
print(f"Saved predictions to {OUTPUT_CSV}.")

Test Loss: 0.2844
Test Accuracy: 68.51%
Saved predictions to ../results/289_test_eval.csv.


In [None]:
# OVERALL CORRECT: 64.71%
# SOURCE-SINK CORRECT: 68.51%
# 10 correct: 74.74%
# 11 correct: 87.89%
# 20 correct: 83.04%
# 21 correct: 80.97%