I've extended the architecture with node-level predictors. For each atom it now outputs a logit for the likelihood of that atom being the borylation site; the sigmoid is applied inside the loss. It also outputs a per-atom reactivity score. These `node_classifier` and `node_regressor` heads sit alongside the graph-level Set2Set readout, which is kept for yield regression.

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class MPNN(nn.Module):
    """Message-passing network with per-atom and per-graph prediction heads.

    Node features are projected to ``hidden_feats``, then refined by
    ``num_step_message_passing`` rounds of a shared NNConv + GRU update.
    After that:

    * ``node_classifier`` emits one borylation-site logit per node;
    * ``node_regressor`` emits one reactivity score per node;
    * a Set2Set readout over [initial projection || final node state] feeds
      ``sparsify`` and ``yield_regressor``, which produce the yield output
      (presumably one value per graph).

    NOTE(review): ``NNConv`` and ``Set2Set`` are not imported in the visible
    code. The constructor keywords used here follow DGL's API:
    ``in_feats``/``out_feats``/``edge_func``/``aggregator_type`` and
    ``input_dim``/``n_iters``/``n_layers``. The calls in ``forward`` instead
    use PyG-style ``(x, edge_index, edge_attr)`` and ``(x, batch)``
    signatures. Confirm which library (or local wrapper) provides them; as
    written this would not run against either stock library unmodified.
    """

    def __init__(self, node_in_feats: int, edge_in_feats: int, hidden_feats: int = 64,
                 num_step_message_passing: int = 3, num_step_set2set: int = 3, num_layer_set2set: int = 1,
                 readout_feats: int = 1024):
        """Build the network.

        Args:
            node_in_feats: Width of the raw node feature vectors.
            edge_in_feats: Width of the raw edge feature vectors.
            hidden_feats: Hidden width used throughout message passing.
            num_step_message_passing: Number of NNConv + GRU rounds.
            num_step_set2set: Set2Set iteration count (``n_iters``).
            num_layer_set2set: Set2Set recurrent layer count (``n_layers``).
            readout_feats: Width of the graph embedding produced by ``sparsify``.
        """
        super(MPNN, self).__init__()

        self.project_node_feats = nn.Sequential(
            nn.Linear(node_in_feats, hidden_feats), nn.ReLU()
        )

        self.num_step_message_passing = num_step_message_passing

        # Maps each edge's features to a flattened hidden_feats x hidden_feats
        # matrix, i.e. the edge-conditioned weights consumed by NNConv.
        edge_network = nn.Linear(edge_in_feats, hidden_feats * hidden_feats)

        # A single layer, reused at every message-passing step in forward().
        self.gnn_layer = NNConv(
            in_feats=hidden_feats,
            out_feats=hidden_feats,
            edge_func=edge_network,
            aggregator_type='sum'
        )

        self.activation = nn.ReLU()
        # Also reused at every step; carries the per-node hidden state.
        self.gru = nn.GRU(hidden_feats, hidden_feats)

        # Node-level heads (both read only the final node state in forward()).
        self.node_classifier = nn.Sequential(
            nn.Linear(hidden_feats, 1),  # No sigmoid here: it is applied in the loss (BCEWithLogits)
        )
        self.node_regressor = nn.Linear(hidden_feats, 1)

        # Graph-level readout over [initial projection || final state],
        # hence an input width of 2 * hidden_feats.
        self.readout = Set2Set(input_dim=hidden_feats * 2,
                               n_iters=num_step_set2set,
                               n_layers=num_layer_set2set)

        # Assumes the readout returns 2 * input_dim = 4 * hidden_feats features,
        # as a standard Set2Set does.
        self.sparsify = nn.Sequential(
            nn.Linear(hidden_feats * 4, readout_feats), nn.PReLU()
        )

        self.yield_regressor = nn.Sequential(
            nn.Linear(readout_feats, 128), nn.ReLU(),
            nn.Linear(128, 1)
        )

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor, edge_attr: torch.Tensor, batch: torch.Tensor):
        """Run message passing and return node- and graph-level predictions.

        Args:
            x: Node features, presumably (num_nodes, node_in_feats). The
                ``torch.cat(..., dim=1)`` below treats dim 1 as the feature axis.
            edge_index: Graph connectivity, passed straight to ``NNConv``.
            edge_attr: Edge features, passed to ``NNConv``. Presumably the last
                dim is ``edge_in_feats`` (the width ``edge_network`` expects).
            batch: Node-to-graph assignment passed to the readout. Presumably a
                per-node graph index; TODO confirm.

        Returns:
            Tuple ``(p_borylation, reactivity_score, predicted_yield)``:

            * per-node borylation-site logits (no sigmoid applied, despite the
              ``p_`` prefix);
            * per-node reactivity scores;
            * yield predictions from the graph-level readout.
        """
        # Project node features
        node_feats = self.project_node_feats(x)
        # Initial GRU hidden state. The leading length-1 axis matches nn.GRU's
        # (num_layers, batch, hidden) layout, with each node as a batch element.
        # (This local is unrelated to the hidden_feats constructor argument.)
        hidden_feats = node_feats.unsqueeze(0)

        # Keep the initial projection; the final state is appended after the loop.
        node_aggr = [node_feats]
        for _ in range(self.num_step_message_passing):
            # unsqueeze(0) turns the messages into a length-1 GRU input sequence.
            node_feats = self.activation(self.gnn_layer(node_feats, edge_index, edge_attr)).unsqueeze(0)
            node_feats, hidden_feats = self.gru(node_feats, hidden_feats)
            node_feats = node_feats.squeeze(0)

        # Only the final state is appended (outside the loop), so the concat is
        # always 2 * hidden_feats wide, matching the readout's input_dim.
        node_aggr.append(node_feats)
        node_aggr_cat = torch.cat(node_aggr, dim=1)

        # Node-level outputs (from the final node state only)
        p_borylation = self.node_classifier(node_feats).squeeze(-1)
        reactivity_score = self.node_regressor(node_feats).squeeze(-1)

        # Graph-level output
        readout = self.readout(node_aggr_cat, batch)
        graph_feats = self.sparsify(readout)
        predicted_yield = self.yield_regressor(graph_feats).squeeze(-1)

        return p_borylation, reactivity_score, predicted_yield



In [2]:
def _as_target(target, pred):
    """Return ``target`` with the shape, dtype and device of ``pred``."""
    # Reshaping (instead of relying on broadcasting) matters: an (N, 1) target
    # against an (N,) prediction would broadcast to (N, N) in MSE and silently
    # produce a wrong loss. A real size mismatch now raises instead.
    return target.reshape(pred.shape).to(device=pred.device, dtype=pred.dtype)


def compute_loss(p_borylation, borylation_mask, reactivity_score, reactivity_target, predicted_yield, true_yield,
                 alpha=1.0, beta=1.0, gamma=0.1):
    """Combine the three task losses into one weighted training objective.

    Args:
        p_borylation: Per-node borylation-site logits. No sigmoid is applied
            beforehand; ``binary_cross_entropy_with_logits`` applies it.
        borylation_mask: Per-node 0/1 site labels. Bool/int masks are accepted
            and cast to the logits' dtype, because BCE needs float targets.
        reactivity_score: Per-node predicted reactivity.
        reactivity_target: Per-node target reactivity (same number of elements).
        predicted_yield: Predicted yield per graph.
        true_yield: Target yield per graph (same number of elements).
        alpha: Weight of the borylation-site loss.
        beta: Weight of the reactivity loss.
        gamma: Weight of the yield loss.

    Returns:
        Tuple ``(total_loss, loss_site, loss_react, loss_yield)`` of scalar
        tensors, where ``total_loss = alpha*site + beta*react + gamma*yield``.
    """
    # Borylation site: binary classification per node (sigmoid applied inside the loss)
    loss_site = F.binary_cross_entropy_with_logits(
        p_borylation, _as_target(borylation_mask, p_borylation)
    )

    # Reactivity: regression per node
    loss_react = F.mse_loss(reactivity_score, _as_target(reactivity_target, reactivity_score))

    # Yield: regression per graph
    loss_yield = F.mse_loss(predicted_yield, _as_target(true_yield, predicted_yield))

    total_loss = alpha * loss_site + beta * loss_react + gamma * loss_yield
    return total_loss, loss_site, loss_react, loss_yield


In [3]:
def train(model, dataloader, optimizer, device):
    """Run one training epoch and return the mean per-batch total loss.

    Args:
        model: Network called as ``model(x, edge_index, edge_attr, batch)``.
            It returns ``(site_logits, reactivity_score, predicted_yield)``.
        dataloader: Iterable of batch objects. Each exposes ``.to(device)``,
            ``x``, ``edge_index``, ``edge_attr``, ``batch``,
            ``borylation_mask``, ``reactivity`` and ``y``. It does not need to
            define ``__len__``; batches are counted as they are consumed.
        optimizer: Optimizer over the model's parameters.
        device: Device each batch is moved to.

    Returns:
        Average of the per-batch total losses. Every batch has equal weight,
        regardless of how many graphs it contains.

    Raises:
        ValueError: If ``dataloader`` yields no batches.
    """
    model.train()
    total_loss = 0.0
    num_batches = 0
    for batch in dataloader:
        batch = batch.to(device)

        optimizer.zero_grad()
        p_borylation, reactivity_score, predicted_yield = model(
            batch.x, batch.edge_index, batch.edge_attr, batch.batch
        )

        # Only the combined loss drives the update; the per-task components
        # returned alongside it are not needed here.
        loss, _, _, _ = compute_loss(
            p_borylation, batch.borylation_mask,
            reactivity_score, batch.reactivity,
            predicted_yield, batch.y,
        )

        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        num_batches += 1

    if num_batches == 0:
        raise ValueError("dataloader yielded no batches")
    return total_loss / num_batches


In [4]:
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score

def evaluate_yield(y_true, y_pred):
    """Score yield predictions with standard regression metrics.

    Returns a dict with keys ``yield_MSE``, ``yield_MAE`` and ``yield_R2``,
    computed (in that order) with the corresponding scikit-learn functions.
    """
    scorers = (
        ("yield_MSE", mean_squared_error),
        ("yield_MAE", mean_absolute_error),
        ("yield_R2", r2_score),
    )
    return {label: scorer(y_true, y_pred) for label, scorer in scorers}


In [5]:
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score

def evaluate_borylation_site(pred_logits, true_mask, threshold=0.5):
    """Score per-atom borylation-site predictions against the true site mask.

    Args:
        pred_logits: Tensor of raw per-atom logits (sigmoid not yet applied).
        true_mask: Tensor of per-atom 0/1 (or bool) site labels, with the same
            number of elements as ``pred_logits``.
        threshold: Probability at or above which an atom counts as a predicted
            site. Defaults to 0.5, the previously hard-coded value.

    Returns:
        Dict with two kinds of metric:

        * accuracy, precision, recall and F1 at ``threshold``;
        * the threshold-free ROC AUC, which is NaN when ``true_mask`` contains
          only one class (AUC is undefined there).
    """
    # Flatten so (N, 1)-shaped inputs work too: set() over a 2-D array would
    # fail on its unhashable rows, and the sklearn metrics want 1-D labels.
    pred_probs = torch.sigmoid(pred_logits).detach().cpu().numpy().reshape(-1)
    true_mask = true_mask.detach().cpu().numpy().reshape(-1)

    pred_binary = (pred_probs >= threshold).astype(int)

    return {
        "site_Accuracy": accuracy_score(true_mask, pred_binary),
        "site_Precision": precision_score(true_mask, pred_binary, zero_division=0),
        "site_Recall": recall_score(true_mask, pred_binary, zero_division=0),
        "site_F1": f1_score(true_mask, pred_binary, zero_division=0),
        "site_AUC": roc_auc_score(true_mask, pred_probs) if len(set(true_mask)) > 1 else float("nan"),
    }


In [6]:
from scipy.stats import spearmanr, pearsonr

def evaluate_reactivity(pred_score, true_score):
    """Compare predicted per-atom reactivity scores with their targets.

    Args:
        pred_score: Tensor of predicted scores.
        true_score: Tensor of target scores with the same number of elements.

    Returns:
        Dict with the MSE, the Spearman (rank) correlation and the Pearson
        (linear) correlation. Both correlations are NaN when fewer than two
        points are available, since they are undefined there.
    """
    # Flatten so (N, 1)-shaped inputs are treated as one vector of N points
    # rather than as a 2-D array by the scipy correlation functions.
    pred = pred_score.detach().cpu().numpy().reshape(-1)
    true = true_score.detach().cpu().numpy().reshape(-1)

    mse = mean_squared_error(true, pred)

    if true.size < 2:
        # pearsonr raises on fewer than two samples; report "undefined" instead.
        spearman_corr = pearson_corr = float("nan")
    else:
        # Index the result tuple rather than using a named attribute
        # (``.correlation`` vs ``.statistic``), which differs across scipy versions.
        spearman_corr = spearmanr(true, pred)[0]
        pearson_corr = pearsonr(true, pred)[0]

    return {
        "react_MSE": mse,
        "react_Spearman": spearman_corr,
        "react_Pearson": pearson_corr
    }


In [7]:
def evaluate_model(model, dataloader, device):
    """Evaluate ``model`` over every batch of ``dataloader`` and return metrics.

    Predictions and targets are collected over the whole dataloader and scored
    once at the end. Site and reactivity metrics are therefore pooled over all
    atoms of all molecules, not averaged per batch. The model is left in eval
    mode.

    Args:
        model: Network called as ``model(x, edge_index, edge_attr, batch)``.
            It returns ``(site_logits, reactivity_score, predicted_yield)``.
        dataloader: Iterable of batch objects exposing ``.to(device)``, ``x``,
            ``edge_index``, ``edge_attr``, ``batch``, ``borylation_mask``,
            ``reactivity`` and ``y``.
        device: Device each batch is moved to for the forward pass.

    Returns:
        Dict merging the outputs of ``evaluate_yield``,
        ``evaluate_borylation_site`` and ``evaluate_reactivity``.

    Raises:
        ValueError: If ``dataloader`` yields no batches.
    """
    model.eval()

    all_y_true, all_y_pred = [], []
    all_site_logits, all_site_masks = [], []
    all_reactivity_pred, all_reactivity_true = [], []

    with torch.no_grad():
        for batch in dataloader:
            batch = batch.to(device)

            p_borylation, reactivity_score, predicted_yield = model(
                batch.x, batch.edge_index, batch.edge_attr, batch.batch
            )

            # Move every collected tensor to CPU right away, as the yield
            # tensors already were. Otherwise predictions for the whole
            # dataset would pile up in accelerator memory until the final cat.

            # Yield
            all_y_true.append(batch.y.cpu())
            all_y_pred.append(predicted_yield.cpu())

            # Borylation site
            all_site_logits.append(p_borylation.cpu())
            all_site_masks.append(batch.borylation_mask.cpu())

            # Reactivity
            all_reactivity_pred.append(reactivity_score.cpu())
            all_reactivity_true.append(batch.reactivity.cpu())

    if not all_y_true:
        raise ValueError("dataloader yielded no batches")

    # Concatenate across batches
    y_true = torch.cat(all_y_true).numpy()
    y_pred = torch.cat(all_y_pred).numpy()

    site_logits = torch.cat(all_site_logits)
    site_masks = torch.cat(all_site_masks)

    react_pred = torch.cat(all_reactivity_pred)
    react_true = torch.cat(all_reactivity_true)

    # Evaluate
    metrics = {}
    metrics.update(evaluate_yield(y_true, y_pred))
    metrics.update(evaluate_borylation_site(site_logits, site_masks))
    metrics.update(evaluate_reactivity(react_pred, react_true))

    return metrics
