### This file contains code for baseline GNNs

In [4]:
from torch_geometric.loader import DataLoader
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, global_add_pool, global_mean_pool
from sklearn.metrics import roc_auc_score
import numpy as np
from torch_geometric.nn import AttentiveFP

from loadBBBP import train_set, val_set, test_set, dataset

In [5]:
# Mini-batch size shared by all three loaders.
batch_size = 256

# Only the training split is reshuffled each epoch; validation and test
# keep their dataset order.
train_loader, val_loader, test_loader = (
    DataLoader(split, batch_size=batch_size, shuffle=reshuffle)
    for split, reshuffle in (
        (train_set, True),
        (val_set, False),
        (test_set, False),
    )
)


In [17]:
class GCN(torch.nn.Module):
    """Graph-level classifier: stacked GCNConv layers, mean pooling, MLP head.

    Args:
        in_channels: Width of the input node features.
        hidden_channels: Output width of every GCNConv layer and width of
            the hidden layer in the MLP head.
        num_layers: Number of GCNConv layers to stack; must be at least 1.
        dropout: Dropout probability used after every conv layer and inside
            the MLP head.

    Raises:
        ValueError: If ``num_layers`` is smaller than 1.
    """

    def __init__(self, in_channels, hidden_channels=128, num_layers=3, dropout=0.2):
        super().__init__()
        if num_layers < 1:
            raise ValueError(f"num_layers must be >= 1, got {num_layers}")

        # Build exactly ``num_layers`` convolutions: the first maps
        # in_channels -> hidden_channels, the rest hidden -> hidden.
        widths = [in_channels] + [hidden_channels] * num_layers
        self.convs = nn.ModuleList(
            [GCNConv(n_in, n_out) for n_in, n_out in zip(widths[:-1], widths[1:])]
        )

        self.dropout = dropout
        self.lin = nn.Sequential(
            nn.Linear(hidden_channels, hidden_channels),
            nn.ReLU(),
            nn.Dropout(self.dropout),
            nn.Linear(hidden_channels, 1)
        )

    def forward(self, data):
        """Return a flat tensor of raw (pre-sigmoid) scores for ``data``.

        ``data`` must expose ``x``, ``edge_index`` and ``batch``; ``x`` is
        cast to float before the first convolution.
        """
        x, edge_index, batch = data.x, data.edge_index, data.batch
        x = x.float()
        for conv in self.convs:
            x = conv(x, edge_index)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)

        # Pool once, after the final layer; only that embedding feeds the head.
        g = global_mean_pool(x, batch)
        return self.lin(g).view(-1)

In [18]:
from torch_geometric.nn import GINEConv
class GINE(torch.nn.Module):
    """Graph-level classifier built from GINEConv layers with edge features.

    Args:
        in_channels: Width of the input node features.
        edge_dim: Width of the edge features, forwarded to each GINEConv.
        hidden: Output width of every GINEConv layer.
        layers: Number of GINEConv layers to stack.
    """

    def __init__(self, in_channels, edge_dim, hidden=128, layers=3):
        super().__init__()

        def mlp(input_dim):
            # Two-layer perceptron used as the update network of one GINEConv.
            return nn.Sequential(nn.Linear(input_dim, hidden),
                                 nn.ReLU(), nn.Linear(hidden, hidden))

        self.convs = nn.ModuleList()
        for layer_idx in range(layers):
            # Only the first layer sees the raw node features.
            input_dim = in_channels if layer_idx == 0 else hidden
            self.convs.append(GINEConv(mlp(input_dim), edge_dim=edge_dim))
        self.lin = nn.Linear(hidden, 1)

    def forward(self, data):
        """Return a flat tensor of raw (pre-sigmoid) scores for ``data``.

        ``data`` must expose ``x``, ``edge_index``, ``edge_attr`` and
        ``batch``; node and edge features are cast to float first.
        """
        x, edge_index, edge_attr, batch = data.x, data.edge_index, data.edge_attr, data.batch
        x = x.float()
        edge_attr = edge_attr.float()
        for conv in self.convs:
            x = conv(x, edge_index, edge_attr)
            x = F.relu(x)
        g = global_mean_pool(x, batch)
        return self.lin(g).view(-1)




In [19]:
# Run on the GPU when one is available, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Baseline under training; swap in the commented GINE line to train that model instead.
model = GCN(in_channels=dataset.num_features, hidden_channels=128, num_layers=3)
model = model.to(device)
# model = GINE(in_channels=dataset.num_features, edge_dim=dataset[0].edge_attr.size(-1)).to(device)

# Binary cross-entropy on raw logits, optimised with Adam plus a small weight decay.
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=1e-3,
    weight_decay=1e-5,
)

def do_epoch(loader, training: bool, model):
    """Run one full pass over ``loader`` and report loss and ROC-AUC.

    Relies on the module-level ``device``, ``criterion`` and ``optimizer``;
    when ``training`` is true the module-level ``optimizer`` is stepped, so
    it must have been built from ``model``'s parameters.

    Args:
        loader: Iterable of batches exposing ``.to(device)`` and ``.y``.
        training: If true, put the model in train mode and update weights
            after every batch; otherwise use eval mode with gradients off.
        model: Module called as ``model(batch)`` to produce logits.

    Returns:
        ``(loss, auc)``: the batch losses averaged with each batch weighted
        by its label count, and the ROC-AUC over the whole pass. ``auc`` is
        NaN when ``roc_auc_score`` raises ``ValueError``; both values are
        NaN when the loader yields no labels.
    """
    model.train(training)

    all_logits, all_labels, total_loss, n = [], [], 0.0, 0
    for batch in loader:
        batch = batch.to(device)
        labels = batch.y.view(-1).float()

        # Skip autograd bookkeeping during evaluation passes.
        with torch.set_grad_enabled(training):
            logits = model(batch).view(-1)
            loss = criterion(logits, labels)

        if training:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        total_loss += loss.item() * labels.size(0)
        n += labels.size(0)
        all_logits.append(logits.detach().cpu())
        all_labels.append(labels.detach().cpu())

    if n == 0:
        # Nothing to average or rank; avoid torch.cat([]) / division by zero.
        return float("nan"), float("nan")

    # torch.sigmoid does not overflow for large-magnitude logits, unlike a
    # hand-written 1 / (1 + exp(-x)).
    scores = torch.sigmoid(torch.cat(all_logits)).numpy()
    targets = torch.cat(all_labels).numpy()

    # Some splits can be all-one or all-zero; handle safely for ROC-AUC
    try:
        auc = roc_auc_score(targets, scores)
    except ValueError:
        auc = float("nan")
    return total_loss / n, auc

# Train for a fixed number of epochs, checkpoint the weights with the best
# validation AUC, then score that checkpoint once on the test split.
best_val_auc, best_test_auc = -1, -1
epochs = 100

# NOTE: despite the "GINE" in the file name, this stores whichever `model`
# is currently active; the name is kept so existing checkpoints still match.
checkpoint_path = "baselineGINE.pt"
checkpoint_saved = False  # guards against loading a stale file from an earlier run

for epoch in range(1, epochs+1):
    train_loss, train_auc = do_epoch(train_loader, training=True, model=model)
    val_loss,   val_auc   = do_epoch(val_loader,   training=False, model=model)
    # The test split is deliberately not touched here: it plays no part in
    # model selection and is evaluated once after training.

    # A NaN validation AUC compares false and therefore never checkpoints.
    if val_auc > best_val_auc:
        best_val_auc = val_auc
        torch.save(model.state_dict(), checkpoint_path)
        checkpoint_saved = True

    if epoch % 5 == 0 or epoch == 1:
        print(f"Epoch {epoch:02d} | "
              f"Train loss {train_loss:.4f} AUC {train_auc:.3f} | "
              f"Val AUC {val_auc:.3f}")

print(f"\nBest Val AUC: {best_val_auc:.3f}")

if checkpoint_saved:
    ckpt = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(ckpt)
else:
    print("Warning: no checkpoint was written in this run; evaluating the final weights instead.")
_, test_auc = do_epoch(test_loader, training=False, model=model)
best_test_auc = test_auc
print(f"Test AUC (best-val checkpoint): {test_auc:.3f}")
 
# TODO: Investigate the gap between validation AUC and test AUC.
# Is the scaffold split the cause? Seems unlikely.
# Is it a training issue? Consider switching to cross-validation with scaffold-based splits.

Epoch 01 | Train loss 0.5884 AUC 0.482 | Val AUC 0.611
Epoch 05 | Train loss 0.4684 AUC 0.530 | Val AUC 0.761
Epoch 10 | Train loss 0.4647 AUC 0.589 | Val AUC 0.860
Epoch 15 | Train loss 0.4498 AUC 0.674 | Val AUC 0.863
Epoch 20 | Train loss 0.4240 AUC 0.724 | Val AUC 0.858
Epoch 25 | Train loss 0.4128 AUC 0.744 | Val AUC 0.866
Epoch 30 | Train loss 0.4073 AUC 0.749 | Val AUC 0.876
Epoch 35 | Train loss 0.3959 AUC 0.771 | Val AUC 0.882
Epoch 40 | Train loss 0.3948 AUC 0.777 | Val AUC 0.880
Epoch 45 | Train loss 0.3862 AUC 0.790 | Val AUC 0.889
Epoch 50 | Train loss 0.3832 AUC 0.790 | Val AUC 0.889
Epoch 55 | Train loss 0.3739 AUC 0.803 | Val AUC 0.891
Epoch 60 | Train loss 0.3708 AUC 0.807 | Val AUC 0.900
Epoch 65 | Train loss 0.3713 AUC 0.804 | Val AUC 0.895
Epoch 70 | Train loss 0.3623 AUC 0.814 | Val AUC 0.901
Epoch 75 | Train loss 0.3649 AUC 0.809 | Val AUC 0.897
Epoch 80 | Train loss 0.3546 AUC 0.822 | Val AUC 0.907
Epoch 85 | Train loss 0.3683 AUC 0.802 | Val AUC 0.903
Epoch 90 |

In [30]:
# Hyper-parameter grid for the GCN baseline; the Cartesian product of these
# lists is searched, iterating keys in the order given here.
param_grid = dict(
    hidden_channels=[64, 128, 256],
    num_layers=[2, 3, 4, 5],
    dropout=[0.1],
    learning_rate=[1e-2, 1e-3, 1e-4],
    weight_decay=[1e-5, 1e-6, 1e-7],
)

In [31]:
import itertools, json, time
from copy import deepcopy

# Reuse loaders and dataset from previous cells

def train_one_config(config, max_epochs=60, patience=10, device=torch.device("cuda" if torch.cuda.is_available() else "cpu")):
    """Train one GCN for a single hyper-parameter setting with early stopping.

    Uses the module-level ``dataset``, ``train_loader``, ``val_loader`` and
    ``test_loader``.

    Args:
        config: Mapping with keys ``'hidden_channels'``, ``'num_layers'``,
            ``'dropout'``, ``'learning_rate'`` and ``'weight_decay'``.
        max_epochs: Upper bound on the number of training epochs.
        patience: Stop once this many consecutive epochs fail to improve
            the validation AUC.
        device: Device the model and batches are moved to.

    Returns:
        Dict with ``'best_val_auc'``, ``'best_epoch'``,
        ``'test_auc_at_best'`` (test AUC of the selected weights) and
        ``'state_dict'`` (a copy of those weights). If no epoch produced a
        comparable validation AUC, the final weights are used,
        ``'best_val_auc'`` is NaN and ``'best_epoch'`` stays ``-1``.
    """
    model = GCN(
        in_channels=dataset.num_features,
        hidden_channels=config['hidden_channels'],
        num_layers=config['num_layers'],
        dropout=config['dropout']
    ).to(device)

    optimizer = torch.optim.Adam(
        model.parameters(), lr=config['learning_rate'], weight_decay=config['weight_decay']
    )
    criterion = nn.BCEWithLogitsLoss()

    def run_epoch(loader, training: bool):
        # One pass over `loader`; returns (label-count-weighted mean loss, ROC-AUC).
        model.train(training)
        all_logits, all_labels, total_loss, n = [], [], 0.0, 0
        for batch in loader:
            batch = batch.to(device)
            labels = batch.y.view(-1).float()
            # No autograd graph is needed for validation/test passes.
            with torch.set_grad_enabled(training):
                logits = model(batch).view(-1)
                loss = criterion(logits, labels)
            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            total_loss += loss.item() * labels.size(0)
            n += labels.size(0)
            all_logits.append(logits.detach().cpu())
            all_labels.append(labels.detach().cpu())
        if n == 0:
            # Empty loader: nothing to average or rank.
            return float("nan"), float("nan")
        # torch.sigmoid stays finite where a manual exp(-x) would overflow.
        scores = torch.sigmoid(torch.cat(all_logits)).numpy()
        targets = torch.cat(all_labels).numpy()
        try:
            auc = roc_auc_score(targets, scores)
        except ValueError:
            auc = float("nan")
        return total_loss / n, auc

    best_val_auc = -1.0
    best_state = None
    best_epoch = -1
    epochs_without_improvement = 0

    for epoch in range(1, max_epochs+1):
        run_epoch(train_loader, training=True)
        _, val_auc = run_epoch(val_loader, training=False)
        # A NaN validation AUC compares false and counts as "no improvement".
        if val_auc > best_val_auc:
            best_val_auc = val_auc
            best_state = deepcopy(model.state_dict())
            best_epoch = epoch
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
        if epochs_without_improvement >= patience:
            break

    if best_state is None:
        # No snapshot was ever taken (validation AUC never comparable, or
        # max_epochs < 1): fall back to the current weights rather than
        # crashing in load_state_dict(None).
        best_state = deepcopy(model.state_dict())
        best_val_auc = float("nan")
    else:
        # Load best state and compute test AUC
        model.load_state_dict(best_state)
    _, test_auc = run_epoch(test_loader, training=False)
    return {
        'best_val_auc': float(best_val_auc),
        'best_epoch': int(best_epoch),
        'test_auc_at_best': float(test_auc),
        'state_dict': best_state
    }


# Grid search: train one model per point of `param_grid`, keep the weights of
# the configuration with the highest validation AUC, and log every run.
keys = list(param_grid.keys())
all_results = []
best_overall = None
best_auc = -1.0
start_time = time.time()

for values in itertools.product(*(param_grid[k] for k in keys)):
    cfg = dict(zip(keys, values))
    print(f"\nConfig: {cfg}")
    result = train_one_config(cfg, max_epochs=80, patience=12)
    # Log the scalar metrics only; the state dict is not serialisable to JSON/CSV.
    rec = {**cfg, **{m: result[m] for m in ('best_val_auc', 'best_epoch', 'test_auc_at_best')}}
    all_results.append(rec)
    # Selection is on validation AUC only; test AUC is reported, never used to choose.
    if result['best_val_auc'] > best_auc:
        best_auc = result['best_val_auc']
        best_overall = {'config': cfg, **result}
        torch.save(result['state_dict'], 'best_gcn_grid.pt')
        print(f"New best AUC: {best_auc:.4f} with cfg={cfg}")

# Persist results
with open('grid_results.json', 'w') as f:
    json.dump(all_results, f, indent=2)

import pandas as pd
pd.DataFrame(all_results).to_csv('grid_results.csv', index=False)

if best_overall is None:
    # Every run reported a validation AUC that never beat the sentinel
    # (e.g. NaN), so there is no winner to summarise.
    print("\nNo configuration produced a valid validation AUC; no best model was saved.")
else:
    print(f"\nBest config: {best_overall['config']} | Val AUC {best_overall['best_val_auc']:.4f} | Test AUC {best_overall['test_auc_at_best']:.4f}")
print(f"Total time: {time.time() - start_time:.1f}s for {len(all_results)} runs")



Config: {'hidden_channels': 64, 'num_layers': 2, 'dropout': 0.1, 'learning_rate': 0.01, 'weight_decay': 1e-05}
New best AUC: 0.9251 with cfg={'hidden_channels': 64, 'num_layers': 2, 'dropout': 0.1, 'learning_rate': 0.01, 'weight_decay': 1e-05}

Config: {'hidden_channels': 64, 'num_layers': 2, 'dropout': 0.1, 'learning_rate': 0.01, 'weight_decay': 1e-06}

Config: {'hidden_channels': 64, 'num_layers': 2, 'dropout': 0.1, 'learning_rate': 0.01, 'weight_decay': 1e-07}

Config: {'hidden_channels': 64, 'num_layers': 2, 'dropout': 0.1, 'learning_rate': 0.001, 'weight_decay': 1e-05}

Config: {'hidden_channels': 64, 'num_layers': 2, 'dropout': 0.1, 'learning_rate': 0.001, 'weight_decay': 1e-06}

Config: {'hidden_channels': 64, 'num_layers': 2, 'dropout': 0.1, 'learning_rate': 0.001, 'weight_decay': 1e-07}

Config: {'hidden_channels': 64, 'num_layers': 2, 'dropout': 0.1, 'learning_rate': 0.0001, 'weight_decay': 1e-05}

Config: {'hidden_channels': 64, 'num_layers': 2, 'dropout': 0.1, 'learning_ra