### This file contains code for the baseline GNNs (GCN and GINE) on the BBBP dataset

In [7]:
from torch_geometric.loader import DataLoader
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, global_add_pool, global_mean_pool
from sklearn.metrics import roc_auc_score
import numpy as np

from loadBBBP import train_set, val_set, test_set, dataset

In [8]:
# Number of graphs per mini-batch, shared by all three loaders.
batch_size = 256

# One loader per split; only the training split is reshuffled each epoch so
# that validation and test batches are visited in a fixed order.
train_loader, val_loader, test_loader = (
    DataLoader(split, batch_size=batch_size, shuffle=reshuffle)
    for split, reshuffle in (
        (train_set, True),
        (val_set, False),
        (test_set, False),
    )
)


In [9]:
class GCN(torch.nn.Module):
    """Graph-level binary classifier built from stacked ``GCNConv`` layers.

    Node features go through ``num_layers`` graph convolutions (each followed
    by ReLU and dropout), are mean-pooled per graph, and are mapped to a single
    logit per graph by a two-layer MLP head.

    Args:
        in_channels: Size of each input node feature vector.
        hidden_channels: Width of every convolution and of the MLP head.
        num_layers: Number of ``GCNConv`` layers; must be at least 1.

    Raises:
        ValueError: If ``num_layers`` is smaller than 1.
    """

    def __init__(self, in_channels, hidden_channels=128, num_layers=3):
        super().__init__()
        if num_layers < 1:
            raise ValueError(f"num_layers must be >= 1, got {num_layers}")

        # Build exactly ``num_layers`` convolutions: the first maps the input
        # features to the hidden width, the rest keep the hidden width.
        widths = [in_channels] + [hidden_channels] * num_layers
        self.convs = nn.ModuleList(
            GCNConv(w_in, w_out) for w_in, w_out in zip(widths[:-1], widths[1:])
        )

        # Dropout probability shared by the convolution stack and the head.
        self.dropout = 0.2
        self.lin = nn.Sequential(
            nn.Linear(hidden_channels, hidden_channels),
            nn.ReLU(),
            nn.Dropout(self.dropout),
            nn.Linear(hidden_channels, 1)
        )

    def forward(self, data):
        """Return one logit per graph in ``data`` as a 1-D tensor.

        ``data`` must provide ``x``, ``edge_index`` and ``batch`` attributes.
        """
        x, edge_index, batch = data.x, data.edge_index, data.batch
        x = x.float()  # node features are cast to float before the first conv
        for conv in self.convs:
            x = conv(x, edge_index)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)

        # Pool once, after the last layer; only the final node embeddings
        # feed the classifier head.
        g = global_mean_pool(x, batch)
        return self.lin(g).view(-1)

In [10]:
from torch_geometric.nn import GINEConv
class GINE(torch.nn.Module):
    """Graph-level binary classifier built from stacked ``GINEConv`` layers.

    Each layer wraps a two-layer MLP and consumes edge attributes in addition
    to node features. Node embeddings are ReLU-activated after every layer,
    mean-pooled per graph, and projected to a single logit per graph.

    Args:
        in_channels: Size of each input node feature vector.
        edge_dim: Size of each edge attribute vector (passed to ``GINEConv``).
        hidden: Width of every layer's MLP output.
        layers: Number of ``GINEConv`` layers; must be at least 1.

    Raises:
        ValueError: If ``layers`` is smaller than 1.
    """

    def __init__(self, in_channels, edge_dim, hidden=128, layers=3):
        super().__init__()
        if layers < 1:
            raise ValueError(f"layers must be >= 1, got {layers}")

        def mlp(in_dim):
            # The input width is passed explicitly rather than read from a
            # loop variable of the enclosing scope.
            return nn.Sequential(nn.Linear(in_dim, hidden),
                                 nn.ReLU(), nn.Linear(hidden, hidden))

        self.convs = nn.ModuleList()
        for layer_idx in range(layers):
            # Only the first layer sees the raw node features.
            in_dim = in_channels if layer_idx == 0 else hidden
            self.convs.append(GINEConv(mlp(in_dim), edge_dim=edge_dim))
        self.lin = nn.Linear(hidden, 1)

    def forward(self, data):
        """Return one logit per graph in ``data`` as a 1-D tensor.

        ``data`` must provide ``x``, ``edge_index``, ``edge_attr`` and
        ``batch`` attributes.
        """
        x, edge_index, edge_attr, batch = data.x, data.edge_index, data.edge_attr, data.batch
        # Node and edge features are both cast to float before the first layer.
        x = x.float()
        edge_attr = edge_attr.float()
        for conv in self.convs:
            x = conv(x, edge_index, edge_attr)
            x = F.relu(x)
        g = global_mean_pool(x, batch)
        return self.lin(g).view(-1)



In [11]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Alternative baseline (swap in for the GINE model below):
# model = GCN(in_channels=dataset.num_features, hidden_channels=128, num_layers=3).to(device)

# The edge-attribute width is read off the first graph in the dataset.
model = GINE(
    in_channels=dataset.num_features,
    edge_dim=dataset[0].edge_attr.shape[-1],
)
model = model.to(device)

# Binary classification on raw logits, optimised with weight-decayed Adam.
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)

def do_epoch(loader, training: bool, model):
    """Run one full pass over ``loader`` and report loss and ROC-AUC.

    Uses the module-level ``device``, ``criterion`` and ``optimizer``. When
    ``training`` is true the model is put in train mode and the optimizer is
    stepped after every batch; otherwise the model is put in eval mode and no
    gradients are tracked.

    Args:
        loader: Iterable of batches exposing ``.to(device)`` and ``.y``.
        training: Whether to update the model's parameters.
        model: Module mapping a batch to one logit per label.

    Returns:
        ``(mean_loss, auc)``: the per-sample average loss and the ROC-AUC over
        the whole pass. ``auc`` is ``nan`` when ``roc_auc_score`` raises
        ``ValueError``; both values are ``nan`` when the loader yields no
        samples.
    """
    model.train(training)

    all_logits, all_labels, total_loss, n = [], [], 0.0, 0
    # Skip autograd bookkeeping during evaluation passes.
    with torch.set_grad_enabled(training):
        for batch in loader:
            batch = batch.to(device)
            logits = model(batch)
            labels = batch.y.view(-1).float()
            loss = criterion(logits, labels)

            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            # Weight the batch loss by its size so the final average is per
            # sample even when the last batch is smaller.
            total_loss += loss.item() * labels.size(0)
            n += labels.size(0)
            all_logits.append(logits.detach().cpu())
            all_labels.append(labels.detach().cpu())

    if n == 0:
        # Nothing was processed; there is no loss or AUC to report.
        return float("nan"), float("nan")

    all_logits = torch.cat(all_logits).numpy()
    all_labels = torch.cat(all_labels).numpy()

    # Some splits can be all-one or all-zero; handle safely for ROC-AUC.
    # ROC-AUC depends only on the ranking of the scores, so the raw logits are
    # passed directly; this avoids overflow in a hand-rolled sigmoid.
    try:
        auc = roc_auc_score(all_labels, all_logits)
    except ValueError:
        auc = float("nan")
    return total_loss / n, auc

# Best validation AUC seen so far and the test AUC measured at that same epoch.
best_val_auc, best_test_auc = -1, -1
epochs = 100
for epoch in range(1, epochs+1):
    train_loss, train_auc = do_epoch(train_loader, training=True, model=model)
    val_loss,   val_auc   = do_epoch(val_loader,   training=False, model=model)
    test_loss,  test_auc  = do_epoch(test_loader,  training=False, model=model)

    # Model selection is driven by validation AUC only. A NaN validation AUC
    # never compares greater, so such epochs are never checkpointed.
    if val_auc > best_val_auc:
        best_val_auc = val_auc
        # Remember the test AUC of the selected epoch instead of discarding it.
        best_test_auc = test_auc
        torch.save(model.state_dict(), "baselineGINE.pt")

    if epoch % 5 == 0 or epoch == 1:
        print(f"Epoch {epoch:02d} | "
              f"Train loss {train_loss:.4f} AUC {train_auc:.3f} | "
              f"Val AUC {val_auc:.3f}")

print(f"\nBest Val AUC: {best_val_auc:.3f}")

# Restore the weights of the best-validation epoch and report its test AUC.
ckpt = torch.load("baselineGINE.pt", map_location=device)
model.load_state_dict(ckpt)
_, test_auc = do_epoch(test_loader, training=False, model=model)
print(f"Test AUC (best-val checkpoint): {test_auc:.3f}")
 
# TODO: Address the gap between test AUC and validation AUC.
# Is it an issue with the scaffold split? That seems unlikely.
# Is it an issue with training? Consider switching to cross-validation with scaffold-based splits.

Epoch 01 | Train loss 0.4829 AUC 0.571 | Val AUC 0.713
Epoch 05 | Train loss 0.4554 AUC 0.658 | Val AUC 0.808
Epoch 10 | Train loss 0.4131 AUC 0.735 | Val AUC 0.863
Epoch 15 | Train loss 0.3995 AUC 0.752 | Val AUC 0.879
Epoch 20 | Train loss 0.3898 AUC 0.768 | Val AUC 0.890
Epoch 25 | Train loss 0.3807 AUC 0.770 | Val AUC 0.885
Epoch 30 | Train loss 0.3716 AUC 0.786 | Val AUC 0.889
Epoch 35 | Train loss 0.3627 AUC 0.793 | Val AUC 0.894
Epoch 40 | Train loss 0.3510 AUC 0.806 | Val AUC 0.897
Epoch 45 | Train loss 0.3582 AUC 0.801 | Val AUC 0.908
Epoch 50 | Train loss 0.3522 AUC 0.810 | Val AUC 0.916
Epoch 55 | Train loss 0.3332 AUC 0.826 | Val AUC 0.906
Epoch 60 | Train loss 0.3227 AUC 0.841 | Val AUC 0.905
Epoch 65 | Train loss 0.3323 AUC 0.830 | Val AUC 0.884
Epoch 70 | Train loss 0.3159 AUC 0.853 | Val AUC 0.895
Epoch 75 | Train loss 0.3063 AUC 0.863 | Val AUC 0.911
Epoch 80 | Train loss 0.3304 AUC 0.830 | Val AUC 0.923
Epoch 85 | Train loss 0.3000 AUC 0.867 | Val AUC 0.913
Epoch 90 |