In [1]:
# import de nodige packages
import os
import sys
import re
import math
from collections import defaultdict

import pandas as pd
import numpy as np
from tqdm import tqdm

import torch
from torch_geometric.data import Data, DataLoader
from sklearn.model_selection import train_test_split

import networkx as nx
import matplotlib.pyplot as plt

from rdkit import Chem
from rdkit.Chem import AllChem, Draw
from rdkit.Chem.rdmolfiles import MolFromXYZFile
from functions.data_loader import data_loader
from classes.smiles_to_graph import MolecularGraphFromSMILES
from classes.MPNN import MPNN
from functions.compute_loss import compute_loss
from functions.evaluations import evaluate_yield
from functions.evaluations import evaluate_borylation_site
from functions.evaluations import evaluate_model
from functions.train import train_MPNN_model

# Load the data and couple the SMILES to the yields and remove nan's

In [2]:
yields_path = "data/compounds_yield.csv"
smiles_path = "data/compounds_smiles.csv"

df_merged = data_loader(yields_path, smiles_path)


#print("Merged DataFrame:")
#print(df_merged)


Convert the SMILES to Graphs

## Zet de SMILES om naar graphs

In [3]:
from rdkit import Chem

graphs = []
for _, row in tqdm(df_merged.iterrows(), total=len(df_merged), desc="Converting SMILES to graphs"):
    try:
        mol_graph = MolecularGraphFromSMILES(row['smiles_raw'])
        mol = Chem.MolFromSmiles(row['smiles_raw'])  # extra RDKit mol object
        num_atoms = mol.GetNumAtoms() if mol is not None else -1

        borylation_index = row['borylation_site']

        # Debug print vóór de fout
        if not (0 <= borylation_index < num_atoms):
            raise IndexError(f"index {borylation_index} is out of bounds for molecule with {num_atoms} atoms")

        graph = mol_graph.to_pyg_data(
            borylation_index=borylation_index,
            yield_value=row['yield']
        )
        graphs.append(graph)

    except Exception as e:
        print(f"\n🚨 Fout bij SMILES: {row['smiles_raw']}")
        print(f"  - borylation_site: {row['borylation_site']}")
        mol = Chem.MolFromSmiles(row['smiles_raw'])
        if mol:
            print(f"  - aantal atomen in RDKit mol: {mol.GetNumAtoms()}")
        else:
            print("  - RDKit kon mol niet parsen!")
        print(f"  - foutmelding: {e}")

# Verdeel de data in train, validatie en test sets
from sklearn.model_selection import train_test_split

# Eerste splitsing: 85% train+val, 15% test
train_val_graphs, test_graphs = train_test_split(
    graphs, test_size=0.15, random_state=42
)

# Tweede splitsing: 70/15 = 70/85 ≈ 0.8235 voor train
train_graphs, val_graphs = train_test_split(
    train_val_graphs, test_size=0.1765, random_state=42
)


Converting SMILES to graphs:   0%|          | 0/83 [00:00<?, ?it/s]

Converting SMILES to graphs: 100%|██████████| 83/83 [00:00<00:00, 526.73it/s]


## Zet de graphs in een dataloader zodat het de GNN in kan

In [6]:
import torch
from torch_geometric.loader import DataLoader

#dus reactivity fixen
# hebben we dingen genormaliseerd?->zeker yield enzo wel handig denk ik

# Instellingen
batch_size = 16
num_epochs = 500
learning_rate = 0.001
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Architecture MPNN
node_in_feats=train_graphs[0].x.shape[1]                # Aantal input features per node
edge_in_feats=train_graphs[0].edge_attr.shape[1]        # Aantal input features per edge
hidden_feats=128                                         # Aantal verborgen features
num_step_message_passing=3                              # Aantal stappen voor message passing
num_step_set2set=3                                      # Aantal stappen voor Set2Set
num_layer_set2set=1                                     # Aantal lagen voor Set2Set
readout_feats=1024                                      # Aantal features voor readout
activation='leaky_relu'                                 # Activatiefunctie
dropout=0.4                                             # Dropout percentage

# DataLoaders
train_loader = DataLoader(train_graphs, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_graphs, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_graphs, batch_size=batch_size, shuffle=False)

# Initialiseer model
model = MPNN(
    node_in_feats=node_in_feats,
    edge_in_feats=edge_in_feats,
    hidden_feats=hidden_feats,
    num_step_message_passing=num_step_message_passing,
    num_step_set2set=num_step_set2set,
    num_layer_set2set=num_layer_set2set,
    readout_feats=readout_feats,
    activation=activation,
    dropout=dropout
)

optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# Training loop met validatie
for epoch in range(num_epochs):
    train_losses = train_MPNN_model(model, train_loader, optimizer, device)
    val_metrics = evaluate_model(model, val_loader, device)

    print(f"[Epoch {epoch+1}] Train loss: {train_losses['total']:.4f} | "
        f"Site: {train_losses['site']:.4f}, "
        f"Yield: {train_losses['yield']:.4f}")


# Evaluatie op testset na training
print("\n✅ Evaluatie op testset na training:")
test_metrics = evaluate_model(model, test_loader, device)

print("📊 Testresultaten:")
print(f"🔹 Borylation site prediction:")
print(f"   - Accuracy      : {test_metrics['site_Accuracy']:.3f}")
print(f"   - Precision     : {test_metrics['site_Precision']:.3f}")
print(f"   - Recall        : {test_metrics['site_Recall']:.3f}")
print(f"   - F1-score      : {test_metrics['site_F1']:.3f}")
print(f"   - ROC AUC       : {test_metrics['site_AUC']:.3f}")

print(f"\n🔹 Yield prediction:")
print(f"   - MSE           : {test_metrics['yield_MSE']:.3f}")
print(f"   - MAE           : {test_metrics['yield_MAE']:.3f}")
print(f"   - R²            : {test_metrics['yield_R2']:.3f}")



[Epoch 1] Train loss: 414.8612 | Site: 1.5812, Yield: 4122.7995
[Epoch 2] Train loss: 296.3749 | Site: 1.4915, Yield: 2938.8343
[Epoch 3] Train loss: 89.3294 | Site: 1.5318, Yield: 867.9763
[Epoch 4] Train loss: 130.7162 | Site: 1.5211, Yield: 1281.9517
[Epoch 5] Train loss: 81.1804 | Site: 1.5710, Yield: 786.0938
[Epoch 6] Train loss: 86.8865 | Site: 1.5289, Yield: 843.5758
[Epoch 7] Train loss: 72.0565 | Site: 1.5160, Yield: 695.4049
[Epoch 8] Train loss: 83.4164 | Site: 1.4947, Yield: 809.2170
[Epoch 9] Train loss: 72.6422 | Site: 1.5032, Yield: 701.3899
[Epoch 10] Train loss: 74.3494 | Site: 1.4538, Yield: 718.9560
[Epoch 11] Train loss: 72.7043 | Site: 1.4838, Yield: 702.2052
[Epoch 12] Train loss: 72.2409 | Site: 1.4658, Yield: 697.7508
[Epoch 13] Train loss: 69.2545 | Site: 1.4606, Yield: 667.9394
[Epoch 14] Train loss: 73.0251 | Site: 1.5283, Yield: 704.9679
[Epoch 15] Train loss: 70.4524 | Site: 1.4695, Yield: 679.8294
[Epoch 16] Train loss: 75.1935 | Site: 1.4743, Yield: 727.

New version of code used above