In [1]:
# import de nodige packages
import os
import sys
import re
import math
from collections import defaultdict

import pandas as pd
import numpy as np
from tqdm import tqdm

import torch
from torch_geometric.data import Data, DataLoader
from sklearn.model_selection import train_test_split

import networkx as nx
import matplotlib.pyplot as plt

from rdkit import Chem
from rdkit.Chem import AllChem, Draw
from rdkit.Chem.rdmolfiles import MolFromXYZFile
from functions.data_loader import data_loader
from classes.smiles_to_graph import MolecularGraphFromSMILES
from classes.MPNN import MPNN
from functions.compute_loss import compute_loss
from functions.evaluations import evaluate_yield
from functions.evaluations import evaluate_borylation_site
from functions.evaluations import evaluate_reactivity
from functions.evaluations import evaluate_model
from functions.train import train_MPNN_model

# Load the data and couple the SMILES to the yields and remove nan's

In [2]:
yields_path = "data/compounds_yield.csv"
smiles_path = "data/compounds_smiles.csv"

df_merged = data_loader(yields_path, smiles_path)


#print("Merged DataFrame:")
#print(df_merged)


Convert the SMILES to Graphs

## Zet de SMILES om naar graphs

In [3]:
from rdkit import Chem

graphs = []
for _, row in tqdm(df_merged.iterrows(), total=len(df_merged), desc="Converting SMILES to graphs"):
    try:
        mol_graph = MolecularGraphFromSMILES(row['smiles_raw'])
        mol = Chem.MolFromSmiles(row['smiles_raw'])  # extra RDKit mol object
        num_atoms = mol.GetNumAtoms() if mol is not None else -1

        borylation_index = row['borylation_site']

        # Debug print vóór de fout
        if not (0 <= borylation_index < num_atoms):
            raise IndexError(f"index {borylation_index} is out of bounds for molecule with {num_atoms} atoms")

        graph = mol_graph.to_pyg_data(
            borylation_index=borylation_index,
            yield_value=row['yield']
        )
        graphs.append(graph)

    except Exception as e:
        print(f"\n🚨 Fout bij SMILES: {row['smiles_raw']}")
        print(f"  - borylation_site: {row['borylation_site']}")
        mol = Chem.MolFromSmiles(row['smiles_raw'])
        if mol:
            print(f"  - aantal atomen in RDKit mol: {mol.GetNumAtoms()}")
        else:
            print("  - RDKit kon mol niet parsen!")
        print(f"  - foutmelding: {e}")

# Verdeel de data in train, validatie en test sets
from sklearn.model_selection import train_test_split

# Eerste splitsing: 85% train+val, 15% test
train_val_graphs, test_graphs = train_test_split(
    graphs, test_size=0.15, random_state=42
)

# Tweede splitsing: 70/15 = 70/85 ≈ 0.8235 voor train
train_graphs, val_graphs = train_test_split(
    train_val_graphs, test_size=0.1765, random_state=42
)


Converting SMILES to graphs: 100%|██████████| 83/83 [00:00<00:00, 91.36it/s]


## Zet de graphs in een dataloader zodat het de GNN in kan

In [4]:
import torch
from torch_geometric.loader import DataLoader

# Instellingen
batch_size = 32
num_epochs = 20
learning_rate = 0.001
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Aantal kenmerken per node en edge
node_in_feats = train_graphs[0].x.shape[1]
edge_in_feats = train_graphs[0].edge_attr.shape[1]

# Grootte van de verborgen laag
hidden_feats = 64 

# DataLoaders
train_loader = DataLoader(train_graphs, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_graphs, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_graphs, batch_size=batch_size, shuffle=False)

# Initialiseer model
model = MPNN(node_in_feats=node_in_feats, edge_in_feats=edge_in_feats, hidden_feats=hidden_feats)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# Training loop met validatie
for epoch in range(num_epochs):
    train_losses = train_MPNN_model(model, train_loader, optimizer, device)
    val_metrics = evaluate_model(model, val_loader, device)

    print(f"[Epoch {epoch+1}] Train loss: {train_losses['total']:.4f} | "
        f"Site: {train_losses['site']:.4f}, "
        f"Reactivity: {train_losses['react']:.4f}, "
        f"Yield: {train_losses['yield']:.4f}")


# Evaluatie op testset na training
print("\n✅ Evaluatie op testset na training:")
test_metrics = evaluate_model(model, test_loader, device)

print("📊 Testresultaten:")
print(f"🔹 Borylation site prediction:")
print(f"   - Accuracy      : {test_metrics['site_Accuracy']:.3f}")
print(f"   - Precision     : {test_metrics['site_Precision']:.3f}")
print(f"   - Recall        : {test_metrics['site_Recall']:.3f}")
print(f"   - F1-score      : {test_metrics['site_F1']:.3f}")
print(f"   - ROC AUC       : {test_metrics['site_AUC']:.3f}")

print(f"\n🔹 Reactivity prediction:")
print(f"   - MSE           : {test_metrics['react_MSE']:.3f}")
print(f"   - Pearson R     : {test_metrics['react_Pearson']:.3f}")
print(f"   - Spearman Rho  : {test_metrics['react_Spearman']:.3f}")

print(f"\n🔹 Yield prediction:")
print(f"   - MSE           : {test_metrics['yield_MSE']:.3f}")
print(f"   - MAE           : {test_metrics['yield_MAE']:.3f}")
print(f"   - R²            : {test_metrics['yield_R2']:.3f}")



[Epoch 1] Train loss: 415.2318 | Site: 1.6306, Reactivity: 0.0317, Yield: 4135.6941
[Epoch 2] Train loss: 411.3606 | Site: 1.5094, Reactivity: 0.0158, Yield: 4098.3536
[Epoch 3] Train loss: 398.7181 | Site: 1.4996, Reactivity: 0.0146, Yield: 3972.0389
[Epoch 4] Train loss: 361.4044 | Site: 1.5081, Reactivity: 0.0145, Yield: 3598.8191
[Epoch 5] Train loss: 319.0465 | Site: 1.5226, Reactivity: 0.0068, Yield: 3175.1708
[Epoch 6] Train loss: 248.2540 | Site: 1.5473, Reactivity: 0.0054, Yield: 2467.0134
[Epoch 7] Train loss: 170.2375 | Site: 1.5546, Reactivity: 0.0052, Yield: 1686.7773
[Epoch 8] Train loss: 99.1494 | Site: 1.5548, Reactivity: 0.0042, Yield: 975.9045
[Epoch 9] Train loss: 75.3989 | Site: 1.5495, Reactivity: 0.0030, Yield: 738.4641
[Epoch 10] Train loss: 117.8478 | Site: 1.5415, Reactivity: 0.0022, Yield: 1163.0403
[Epoch 11] Train loss: 117.1911 | Site: 1.5331, Reactivity: 0.0034, Yield: 1156.5463
[Epoch 12] Train loss: 89.0004 | Site: 1.5485, Reactivity: 0.0107, Yield: 874.

New version of code used above