In [None]:
# import de nodige packages
import os
import sys
import re
import math
from collections import defaultdict

import pandas as pd
import numpy as np
from tqdm import tqdm

import torch
from torch_geometric.data import Data, DataLoader
from sklearn.model_selection import train_test_split

import networkx as nx
import matplotlib.pyplot as plt

from rdkit import Chem
from rdkit.Chem import AllChem, Draw
from rdkit.Chem.rdmolfiles import MolFromXYZFile

# Load the data, couple the SMILES to the yields, and remove NaNs

In [None]:
# --- 1. Input files (parsed below as raw text) ---
yields_path = "data/compounds_yield.csv"
smiles_path = "data/compounds_smiles.csv"

# --- 2. Yield parser: keep the highest percentage mentioned on each line ---
# Each non-empty line is expected to look like "<compound_id> <free text with N%>".
# The pattern accepts an optional decimal part so that "85.5%" is read as 85.5
# (a plain r'(\d+)%' would silently read it as 5).
yield_data = []
skipped_yield_lines = 0
# utf-8-sig strips a leading BOM (which would otherwise end up inside the first
# compound_id and break the merge); errors="replace" keeps stray non-UTF-8 bytes
# in the free-text part from aborting the parse.
with open(yields_path, "r", encoding="utf-8-sig", errors="replace") as f:
    for line in f:
        line = line.strip()
        if not line:
            continue
        parts = line.split(" ", 1)
        percentages = (
            re.findall(r'(\d+(?:\.\d+)?)%', parts[1]) if len(parts) == 2 else []
        )
        if not percentages:
            # No yield information on this line: the compound is left out.
            skipped_yield_lines += 1
            continue
        compound_id = parts[0]
        max_yield = max(map(float, percentages))
        yield_data.append((compound_id, max_yield))

df_yields_clean = pd.DataFrame(yield_data, columns=["compound_id", "yield"])

# --- 3. SMILES parser ---
# Only lines with exactly four comma-separated fields are kept; the fourth
# field (smiles_normalized) is ignored.
smiles_data = []
skipped_smiles_lines = 0
with open(smiles_path, "r", encoding="utf-8-sig", errors="replace") as f:
    for line in f:
        line = line.strip()
        if not line:
            continue
        parts = [p.strip() for p in line.split(",")]
        if len(parts) != 4:
            skipped_smiles_lines += 1
            continue
        compound_id, smiles_raw, number, _ = parts  # ignore smiles_normalized
        smiles_data.append((compound_id, smiles_raw, number))

df_smiles_clean = pd.DataFrame(
    smiles_data,
    columns=["compound_id", "smiles_raw", "borylation_site"]
)

# --- 4. Merge on compound_id ---
# The inner join keeps only compounds that have both a SMILES and a parsed yield.
df_merged = pd.merge(df_smiles_clean, df_yields_clean, on="compound_id", how="inner")

print(
    f"Skipped {skipped_yield_lines} yield line(s) without a parsable percentage "
    f"and {skipped_smiles_lines} SMILES line(s) without exactly 4 fields."
)
print(f"Merged DataFrame: {len(df_merged)} rows")
print(df_merged.head())


## Convert the SMILES to graphs

In [None]:

# Element symbols that get their own position in the one-hot atom encoding.
ALLOWED_ATOMS = [
    "H", "C", "N", "O", "S", "Br",
    "F", "Cl", "I", "Si", "B",
]

# Electronegativity per element symbol.
# NOTE(review): the values look like Pauling-scale numbers -- confirm the source.
ELECTRONEGATIVITY = {
    "H": 2.20,
    "C": 2.55,
    "N": 3.04,
    "O": 3.44,
    "S": 2.58,
    "Br": 2.96,
    "F": 3.98,
    "Cl": 3.16,
    "I": 2.66,
    "Si": 1.90,
    "B": 2.04,
}

# Numeric bond order for each RDKit bond type handled explicitly.
BOND_ORDER_MAP = {
    Chem.rdchem.BondType.SINGLE: 1,
    Chem.rdchem.BondType.DOUBLE: 2,
    Chem.rdchem.BondType.TRIPLE: 3,
    Chem.rdchem.BondType.AROMATIC: 1.5,
}

# Integer code for the bond class derived from the electronegativity difference.
BOND_TYPE_ENCODING = dict(covalent=0, polar=1, ionic=2)

class MolecularGraphFromSMILES:
    """Graph view of a molecule parsed from a SMILES string with RDKit.

    Attributes:
        smiles: The SMILES string the instance was built from.
        mol: The RDKit molecule returned by ``Chem.MolFromSmiles``.
        atoms: Element symbol of every atom, in the order RDKit yields them.
        atom_objects: The RDKit atom objects, in the same order as ``atoms``.
        bond_objects: The RDKit bond objects of the molecule.
    """

    # Scalar node features appended after the one-hot element encoding:
    # aromatic, non-aromatic, formal charge, in-ring, electronegativity.
    _NUM_SCALAR_NODE_FEATURES = 5
    # Width of one edge_attr row: bond order, aromatic flag, bond-type code.
    _NUM_EDGE_FEATURES = 3

    def __init__(self, smiles: str):
        """Parse ``smiles`` and cache its atoms and bonds.

        Raises:
            ValueError: If RDKit cannot parse ``smiles`` (``Chem.MolFromSmiles``
                returns ``None`` in that case).
        """
        self.smiles = smiles
        self.mol = Chem.MolFromSmiles(smiles)
        if self.mol is None:
            # Fail with a clear message instead of an AttributeError on None.
            raise ValueError(f"Could not parse SMILES: {smiles!r}")
        self.atom_objects = list(self.mol.GetAtoms())
        self.bond_objects = list(self.mol.GetBonds())
        self.atoms = [atom.GetSymbol() for atom in self.atom_objects]

    def _one_hot(self, value, choices):
        """Return a 0/1 list over ``choices``; all zeros if ``value`` is absent."""
        encoding = [0] * len(choices)
        if value in choices:
            encoding[choices.index(value)] = 1
        return encoding

    def _atom_features(self, atom):
        """Build the feature list for one RDKit atom.

        Layout: one-hot element over ``ALLOWED_ATOMS``, aromatic flag,
        non-aromatic flag, formal charge, in-ring flag, electronegativity
        (0.0 for elements missing from ``ELECTRONEGATIVITY``).
        """
        symbol = atom.GetSymbol()
        is_aromatic = atom.GetIsAromatic()
        return (
            self._one_hot(symbol, ALLOWED_ATOMS) +
            [int(is_aromatic), int(not is_aromatic)] +
            [atom.GetFormalCharge()] +
            [int(atom.IsInRing())] +
            [ELECTRONEGATIVITY.get(symbol, 0.0)]
        )

    @staticmethod
    def _bond_class(sym_i, sym_j):
        """Classify a bond from the electronegativity difference of its atoms.

        NOTE(review): elements missing from ``ELECTRONEGATIVITY`` count as 0,
        so their bonds usually come out as "ionic" -- confirm this is intended.
        """
        diff = abs(ELECTRONEGATIVITY.get(sym_i, 0) - ELECTRONEGATIVITY.get(sym_j, 0))
        if diff > 1.7:
            return "ionic"
        if diff > 0.4:
            return "polar"
        return "covalent"

    def to_pyg_data(self) -> Data:
        """Convert the molecule to a PyTorch Geometric ``Data`` object.

        Returns:
            ``Data`` with
            ``x`` of shape (num_atoms, len(ALLOWED_ATOMS) + 5),
            ``edge_index`` of shape (2, 2 * num_bonds) (each bond is stored in
            both directions) and
            ``edge_attr`` of shape (2 * num_bonds, 3) holding bond order,
            aromatic flag and bond-type code.
            The shapes stay two-dimensional for molecules without bonds or
            without atoms.
        """
        node_rows = [self._atom_features(atom) for atom in self.atom_objects]
        num_node_features = len(ALLOWED_ATOMS) + self._NUM_SCALAR_NODE_FEATURES
        # The explicit reshape is a no-op for non-empty input and keeps the
        # tensor 2-D when the list is empty.
        x = torch.tensor(node_rows, dtype=torch.float).reshape(
            len(node_rows), num_node_features
        )

        edge_pairs = []
        edge_rows = []
        for bond in self.bond_objects:
            i = bond.GetBeginAtomIdx()
            j = bond.GetEndAtomIdx()
            bond_type = self._bond_class(self.atoms[i], self.atoms[j])

            attr = [
                BOND_ORDER_MAP.get(bond.GetBondType(), 1),
                int(bond.GetIsAromatic()),
                BOND_TYPE_ENCODING[bond_type]
            ]

            # Undirected bond -> two directed edges sharing the same attributes.
            edge_pairs += [[i, j], [j, i]]
            edge_rows += [attr, attr]

        # Without the reshape a bond-free molecule (e.g. "C") would yield a
        # 1-D edge_index of shape (0,) instead of (2, 0).
        edge_index = (
            torch.tensor(edge_pairs, dtype=torch.long)
            .reshape(len(edge_pairs), 2)
            .t()
            .contiguous()
        )
        edge_attr = torch.tensor(edge_rows, dtype=torch.float).reshape(
            len(edge_rows), self._NUM_EDGE_FEATURES
        )

        return Data(x=x, edge_index=edge_index, edge_attr=edge_attr)

    def visualize(self, with_labels=True):
        """Draw the molecule as a networkx graph with bond orders on the edges.

        Args:
            with_labels: Passed to ``nx.draw``; when true, nodes are labelled
                with their element symbol.
        """
        G = nx.Graph()
        for i, el in enumerate(self.atoms):
            G.add_node(i, label=el)
        for bond in self.bond_objects:
            i = bond.GetBeginAtomIdx()
            j = bond.GetEndAtomIdx()
            order = BOND_ORDER_MAP.get(bond.GetBondType(), 1)
            G.add_edge(i, j, label=str(order))

        pos = nx.spring_layout(G)
        nx.draw(G, pos, with_labels=with_labels,
                labels=nx.get_node_attributes(G, 'label'),
                node_color='lightblue', node_size=700, font_size=10)
        edge_labels = nx.get_edge_attributes(G, 'label')
        nx.draw_networkx_edge_labels(G, pos, edge_labels=edge_labels)
        plt.show()


In [None]:
# Quick sanity check of the graph class on a small molecule (acetic acid).
smiles = "CC(=O)O"
mol_graph = MolecularGraphFromSMILES(smiles)

# Plot the atoms as nodes and the bonds as edges labelled with their bond order.
mol_graph.visualize(with_labels=True)

# Build the PyTorch Geometric representation and show its summary.
pyg_data = mol_graph.to_pyg_data()
print(pyg_data)

## Zet de SMILES om naar graphs

In [None]:

# MolecularGraphFromSMILES must already be defined (or imported) at this point:
# from your_module import MolecularGraphFromSMILES

# One PyG graph per merged row, with the yield attached as regression target `y`.
graphs = []
smiles_and_yields = zip(df_merged['smiles_raw'], df_merged['yield'])
for smiles_raw, yield_value in tqdm(smiles_and_yields, total=len(df_merged), desc="Converting SMILES to graphs"):
    try:
        graph = MolecularGraphFromSMILES(smiles_raw).to_pyg_data()
        graph.y = torch.tensor([yield_value], dtype=torch.float)
    except Exception as e:
        # Report the failing SMILES and continue; that row yields no graph.
        print(f"Fout bij SMILES: {smiles_raw}, error: {e}")
    else:
        graphs.append(graph)


## Zet de graphs in een dataloader zodat het de GNN in kan

In [None]:
# Hold out 20% of the generated graphs for evaluation; the fixed random_state
# makes the split reproducible.
train_graphs, test_graphs = train_test_split(
    graphs,
    test_size=0.2,
    random_state=42,
)

# Wrap both subsets in DataLoaders; only the training loader shuffles.
train_loader = DataLoader(
    train_graphs,
    batch_size=32,
    shuffle=True,
)
test_loader = DataLoader(
    test_graphs,
    batch_size=32,
    shuffle=False,
)

In [None]:
# Walk through every training batch and unpack the tensors a GNN would consume.
# After the loop the names refer to the last batch that was visited.
for batch in train_loader:
    x, edge_index, edge_attr = batch.x, batch.edge_index, batch.edge_attr
    y, batch_vector = batch.y, batch.batch