In [None]:
import torch
import torch.nn.functional as F

from torch_geometric.loader import DataLoader
from torch_geometric.nn import GCNConv, global_mean_pool

import pandas as pd
import numpy as np
from rdkit import Chem
from rdkit.Chem.Scaffolds import MurckoScaffold

from sklearn.metrics import roc_auc_score, average_precision_score, f1_score


In [None]:
# The graph file holds pickled Python objects (presumably torch_geometric Data
# objects — later cells read graphs[i].x and batch them with the PyG DataLoader),
# not plain tensors. torch >= 2.6 defaults to weights_only=True, which refuses to
# unpickle such objects, so opt out explicitly (the kwarg exists since torch 1.13).
# NOTE: this unpickles arbitrary objects — only load files from a trusted source.
graphs = torch.load("qs_graphs.pt", weights_only=False)
df = pd.read_csv("qs_inhibitors_cleaned.csv")

# Later cells pair graphs[i] with row i of df, so the two must be the same length
# (and, assumed but not checked here, in the same order).
assert len(graphs) == len(df), (
    f"Graph/row count mismatch: {len(graphs)} graphs vs {len(df)} rows"
)
print("Loaded graphs:", len(graphs))


Loaded graphs: 168


In [None]:
def get_murcko_scaffold(smiles, include_chirality=False):
    """Return the canonical SMILES of the Bemis–Murcko scaffold of a molecule.

    Args:
        smiles: SMILES string of the molecule.
        include_chirality: keep stereochemistry in the scaffold SMILES. Defaults
            to False so that stereoisomers share one scaffold and therefore
            cannot end up on opposite sides of a scaffold train/test split.

    Returns:
        Canonical scaffold SMILES (RDKit yields an empty string for molecules
        without rings, so all acyclic molecules share one group).

    Raises:
        ValueError: if ``smiles`` is not a string or RDKit cannot parse it.
    """
    if not isinstance(smiles, str):
        raise ValueError(f"Expected a SMILES string, got {smiles!r}")
    mol = Chem.MolFromSmiles(smiles)
    # MolFromSmiles signals parse failure by returning None rather than raising.
    if mol is None:
        raise ValueError(f"RDKit could not parse SMILES: {smiles!r}")
    scaffold = MurckoScaffold.GetScaffoldForMol(mol)
    return Chem.MolToSmiles(scaffold, isomericSmiles=include_chirality)


df["scaffold"] = df["smiles_canonical"].apply(get_murcko_scaffold)


In [None]:
def scaffold_split(df, test_fraction=0.2, seed=42):
    """Split the rows of ``df`` into train/test so no scaffold appears in both.

    Rows sharing a value in ``df["scaffold"]`` form a group. Groups are shuffled
    with ``seed`` and assigned whole to the test set until it holds at least
    ``test_fraction`` of the rows; every remaining group goes to train. Because
    groups are never split, the test set can overshoot the target by up to one
    group's size.

    Args:
        df: DataFrame with a ``"scaffold"`` column.
        test_fraction: desired fraction of rows in the test set, strictly
            between 0 and 1.
        seed: seed for the scaffold-order shuffle.

    Returns:
        ``(train_idx, test_idx)``: lists of ``df.index`` labels. Callers that
        also use these as positions (e.g. into a parallel list of graphs) rely
        on ``df`` having the default 0..n-1 index.

    Raises:
        ValueError: if ``test_fraction`` is not in the open interval (0, 1).
    """
    if not 0.0 < test_fraction < 1.0:
        raise ValueError(f"test_fraction must be in (0, 1), got {test_fraction!r}")

    rng = np.random.default_rng(seed)

    scaffold_to_indices = {}
    for idx, scaffold in zip(df.index, df["scaffold"]):
        scaffold_to_indices.setdefault(scaffold, []).append(idx)

    scaffolds = list(scaffold_to_indices)
    rng.shuffle(scaffolds)

    # int() truncates, which gives 0 on tiny datasets and would leave the test
    # set empty (breaking every downstream metric); always target >= 1 row.
    target_test = max(1, int(test_fraction * len(df)))

    train_idx, test_idx = [], []
    for scaffold in scaffolds:
        idxs = scaffold_to_indices[scaffold]
        if len(test_idx) < target_test:
            test_idx.extend(idxs)
        else:
            train_idx.extend(idxs)

    return train_idx, test_idx


In [None]:
def make_loaders(graphs, train_idx, test_idx, batch_size=32):
    """Wrap the train and test subsets of ``graphs`` in DataLoaders.

    ``train_idx`` and ``test_idx`` are used as positions into ``graphs``. The
    training loader is built with ``shuffle=True``; the test loader keeps the
    order of ``test_idx``.

    Returns:
        ``(train_loader, test_loader)``.
    """
    def _subset(positions):
        # Materialise the selected graphs as a plain list for the loader.
        return [graphs[pos] for pos in positions]

    return (
        DataLoader(_subset(train_idx), batch_size=batch_size, shuffle=True),
        DataLoader(_subset(test_idx), batch_size=batch_size, shuffle=False),
    )


In [None]:
class GCN(torch.nn.Module):
    """Two-layer graph convolutional network that emits one logit per graph.

    Node features go through two GCNConv + ReLU layers, are averaged per graph
    with global mean pooling, and a linear head maps each pooled embedding to a
    single unnormalised score (pair with a logits-based loss).
    """

    def __init__(self, in_channels, hidden_dim=64):
        """
        Args:
            in_channels: number of input node features.
            hidden_dim: width of both convolution layers and the pooled embedding.
        """
        super().__init__()
        # Registration order (conv1, conv2, lin) determines parameter order.
        self.conv1 = GCNConv(in_channels, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)
        self.lin = torch.nn.Linear(hidden_dim, 1)

    def forward(self, x, edge_index, batch):
        """Return a flat tensor of logits, one entry per graph in ``batch``."""
        for conv in (self.conv1, self.conv2):
            x = F.relu(conv(x, edge_index))
        graph_embeddings = global_mean_pool(x, batch)
        return self.lin(graph_embeddings).view(-1)


In [None]:
def train_epoch(model, loader, optimizer, device):
    """Run one optimisation pass over ``loader``; return the mean per-graph loss.

    Args:
        model: module called as ``model(x, edge_index, batch)`` and returning
            one logit per graph.
        loader: iterable of batches exposing ``x``, ``edge_index``, ``batch``,
            ``y``, ``num_graphs`` and ``.to(device)``.
        optimizer: optimizer over ``model``'s parameters.
        device: device each batch is moved to.

    Returns:
        Binary cross-entropy averaged over all graphs seen this epoch.

    Raises:
        ValueError: if ``loader`` yields no graphs.
    """
    model.train()
    total_loss = 0.0
    n_graphs = 0

    for data in loader:
        data = data.to(device)
        optimizer.zero_grad()

        logits = model(data.x, data.edge_index, data.batch)
        # Flatten targets so labels stored per graph as [1] or [1, 1] both line
        # up with the model's [num_graphs] logits.
        loss = F.binary_cross_entropy_with_logits(logits, data.y.float().reshape(-1))

        loss.backward()
        optimizer.step()

        # Weight each batch by its size: the last batch is usually smaller, so
        # a plain mean of per-batch losses would over-weight it.
        total_loss += loss.item() * data.num_graphs
        n_graphs += data.num_graphs

    if n_graphs == 0:
        raise ValueError("train_epoch received an empty loader; nothing to train on.")

    return total_loss / n_graphs


In [None]:
def eval_model(model, loader, device, threshold=0.5):
    """Score ``model`` on every batch in ``loader`` and return binary metrics.

    Args:
        model: module called as ``model(x, edge_index, batch)`` and returning
            one logit per graph.
        loader: iterable of batches exposing ``x``, ``edge_index``, ``batch``,
            ``y`` and ``.to(device)``.
        device: device each batch is moved to.
        threshold: probability at or above which a graph is predicted positive
            (used for F1 only).

    Returns:
        Dict with ``"ROC_AUC"``, ``"PR_AUC"`` and ``"F1"``. Both AUCs are NaN
        when the evaluated labels contain a single class, where they are
        undefined.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    model.eval()

    true_chunks, prob_chunks = [], []

    with torch.no_grad():
        for data in loader:
            data = data.to(device)
            logits = model(data.x, data.edge_index, data.batch)
            probs = torch.sigmoid(logits)

            true_chunks.append(data.y.reshape(-1).cpu().numpy())
            prob_chunks.append(probs.reshape(-1).cpu().numpy())

    if not true_chunks:
        raise ValueError("eval_model received an empty loader; nothing to score.")

    y_true = np.concatenate(true_chunks)
    y_prob = np.concatenate(prob_chunks)
    y_pred = (y_prob >= threshold).astype(int)

    # Scaffold splits of small datasets can yield a single-class test set;
    # roc_auc_score raises there, so report NaN instead of crashing the run.
    has_both_classes = np.unique(y_true).size > 1
    nan = float("nan")

    return {
        "ROC_AUC": roc_auc_score(y_true, y_prob) if has_both_classes else nan,
        "PR_AUC": average_precision_score(y_true, y_prob) if has_both_classes else nan,
        # zero_division=0 keeps the default value but silences the warning
        # when the model predicts no positives.
        "F1": f1_score(y_true, y_pred, zero_division=0),
    }


In [None]:
def run_gcn_experiment(
    graphs,
    df,
    n_splits=5,
    epochs=100,
    seed_start=42,
    test_fraction=0.2,
    batch_size=32,
    lr=1e-3,
    weight_decay=1e-4,
):
    """Train and evaluate a fresh GCN on ``n_splits`` independent scaffold splits.

    Split ``i`` uses seed ``seed_start + i`` both for the scaffold shuffle and
    for torch, so model initialisation and training-batch order are repeatable
    (GPU scatter kernels may still add small nondeterminism).

    Args:
        graphs: list of graph objects aligned row-for-row with ``df``.
        df: DataFrame with a ``"scaffold"`` column.
        n_splits: number of scaffold splits to run.
        epochs: training epochs per split.
        seed_start: seed of the first split.
        test_fraction: target test fraction passed to ``scaffold_split``.
        batch_size: mini-batch size for both loaders.
        lr: Adam learning rate.
        weight_decay: Adam weight decay.

    Returns:
        DataFrame with one row of test metrics per split.
    """
    all_results = []

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    for i in range(n_splits):
        seed = seed_start + i
        print(f"\n=== Scaffold split {i+1}/{n_splits} ===")

        train_idx, test_idx = scaffold_split(df, test_fraction=test_fraction, seed=seed)
        print(f"Train size: {len(train_idx)}, Test size: {len(test_idx)}")

        train_scaffolds = set(df.loc[train_idx, "scaffold"])
        test_scaffolds = set(df.loc[test_idx, "scaffold"])
        assert train_scaffolds.isdisjoint(test_scaffolds), \
            "Scaffold leakage detected!"

        train_loader, test_loader = make_loaders(
            graphs, train_idx, test_idx, batch_size=batch_size
        )

        # Seed before building the model: covers weight init and the training
        # DataLoader's shuffling, which both draw from torch's global RNG.
        torch.manual_seed(seed)

        model = GCN(in_channels=graphs[0].x.shape[1]).to(device)
        optimizer = torch.optim.Adam(
            model.parameters(), lr=lr, weight_decay=weight_decay
        )

        for _ in range(epochs):
            train_epoch(model, train_loader, optimizer, device)

        metrics = eval_model(model, test_loader, device)
        print(metrics)

        all_results.append(metrics)

    return pd.DataFrame(all_results)



Train size: 131, Test size: 37


In [None]:
# Repeat the scaffold-split train/evaluate cycle five times, 100 epochs each,
# collecting one row of test metrics per split.
gcn_results = run_gcn_experiment(graphs=graphs, df=df, n_splits=5, epochs=100)

gcn_results


=== Scaffold split 1/5 ===
{'ROC_AUC': 0.6339285714285714, 'PR_AUC': 0.6887158689637582, 'F1': 0.7241379310344828}

=== Scaffold split 2/5 ===
{'ROC_AUC': 0.444078947368421, 'PR_AUC': 0.5586653306913212, 'F1': 0.5909090909090909}

=== Scaffold split 3/5 ===
{'ROC_AUC': 0.5277777777777778, 'PR_AUC': 0.6730323240141669, 'F1': 0.47058823529411764}

=== Scaffold split 4/5 ===
{'ROC_AUC': 0.7708333333333333, 'PR_AUC': 0.823707486013098, 'F1': 0.6923076923076923}

=== Scaffold split 5/5 ===
{'ROC_AUC': 0.6283333333333333, 'PR_AUC': 0.6579250198503072, 'F1': 0.5769230769230769}


Unnamed: 0,ROC_AUC,PR_AUC,F1
0,0.633929,0.688716,0.724138
1,0.444079,0.558665,0.590909
2,0.527778,0.673032,0.470588
3,0.770833,0.823707,0.692308
4,0.628333,0.657925,0.576923


In [None]:
# Mean and sample standard deviation (ddof=1) of each metric across splits.
gcn_summary = gcn_results.describe().loc[["mean", "std"]]
gcn_summary

Unnamed: 0,ROC_AUC,PR_AUC,F1
mean,0.60099,0.680409,0.610973
std,0.123147,0.094848,0.100861
