In [1]:


import os
import random
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GINConv, global_add_pool



  from .autonotebook import tqdm as notebook_tqdm


In [2]:

# Reproducibility helpers

def seed_everything(seed: int = 42):
    """Seed every RNG used in this notebook and request deterministic cuDNN.

    Seeds Python's ``random`` module, NumPy's global generator, and PyTorch's
    CPU and CUDA generators with the same value, then configures cuDNN so it
    does not autotune or pick non-deterministic kernels between runs.

    Args:
        seed: Integer handed to each generator.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Documented by PyTorch as a silent no-op when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)

    # Seeding alone does not fix which cuDNN algorithm gets selected; without
    # these flags two GPU runs with the same seed can still diverge.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    # NOTE(review): some CUDA ops outside cuDNN (e.g. scatter-style reductions)
    # may remain non-deterministic; confirm against the PyTorch reproducibility
    # notes if bit-exact GPU runs are required.

# Fix all RNGs before any data shuffling or weight initialisation happens.
seed_everything(seed=42)

# Run on the GPU when PyTorch can see one, otherwise stay on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Device:", device)



Device: cuda


In [3]:

# Load MUTAG

# The dataset root defaults to the original location but can be redirected
# without editing the notebook by exporting TUDATASET_ROOT before the kernel
# starts (e.g. when running on a different machine).
dataset = TUDataset(
    root=os.environ.get("TUDATASET_ROOT", "/data/b23_chiranjeevi/GIN_MUTAG/data/TUDataset"),
    name="MUTAG"
)


# The model below consumes data.x, so a feature-less dataset cannot be used.
if dataset.num_node_features == 0:
    raise ValueError("Dataset has no node features ")

# Shuffle and split (80/10/10)
# NOTE(review): the shuffle presumably draws from torch's global RNG, so it is
# only repeatable if seed_everything() ran first — confirm against PyG docs.
dataset = dataset.shuffle()
n = len(dataset)
n_train = int(0.8 * n)
n_val = int(0.1 * n)
train_ds = dataset[:n_train]
val_ds = dataset[n_train:n_train + n_val]
# The test split takes whatever remains, so rounding losses land here.
test_ds = dataset[n_train + n_val:]

# Guard against an accidental overlap or gap if the slicing above is edited.
assert len(train_ds) + len(val_ds) + len(test_ds) == n, "splits must cover the dataset exactly"

# Only the training loader reshuffles each epoch; evaluation order is fixed.
train_loader = DataLoader(train_ds, batch_size=32, shuffle=True)
val_loader = DataLoader(val_ds, batch_size=64, shuffle=False)
test_loader = DataLoader(test_ds, batch_size=64, shuffle=False)

print(dataset)
print("Num graphs:", len(dataset))
print("Num node features:", dataset.num_node_features)
print("Num classes:", dataset.num_classes)



MUTAG(188)
Num graphs: 188
Num node features: 7
Num classes: 2


In [None]:
# GIN model

class GIN(nn.Module):
    """Graph-level classifier built from stacked GINConv layers.

    Each of the ``num_layers`` message-passing layers wraps a two-layer MLP in
    a ``GINConv`` and is followed by ``BatchNorm1d``, ReLU and dropout. Node
    embeddings are then pooled per graph with ``global_add_pool`` and passed
    through an MLP head that emits ``out_dim`` logits.

    Args:
        in_dim: Width of the input node features.
        hidden_dim: Width of every hidden node embedding.
        out_dim: Number of output logits per graph.
        num_layers: Total number of GINConv layers (must be at least 2).
        dropout: Dropout probability used after each layer and in the head.
    """

    @staticmethod
    def _build_mlp(in_features: int, hidden_features: int) -> nn.Sequential:
        """Return the Linear -> ReLU -> Linear update network for one layer."""
        return nn.Sequential(
            nn.Linear(in_features, hidden_features),
            nn.ReLU(),
            nn.Linear(hidden_features, hidden_features),
        )

    def __init__(self, in_dim: int, hidden_dim: int, out_dim: int, num_layers: int = 3, dropout: float = 0.5):
        super().__init__()
        assert num_layers >= 2, "Use at least 2 layers for GIN."

        # Only the first layer sees raw features; the rest map hidden -> hidden.
        layer_inputs = [in_dim] + [hidden_dim] * (num_layers - 1)

        self.convs = nn.ModuleList()
        self.bns = nn.ModuleList()
        for width in layer_inputs:
            self.convs.append(GINConv(self._build_mlp(width, hidden_dim)))
            self.bns.append(nn.BatchNorm1d(hidden_dim))

        self.dropout = dropout
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, out_dim),
        )

    def forward(self, x, edge_index, batch):
        """Compute per-graph logits.

        Args:
            x: Node feature matrix fed to the first GINConv.
            edge_index: Edge connectivity passed to every GINConv.
            batch: Node-to-graph assignment passed to ``global_add_pool``.

        Returns:
            The classifier output for each pooled graph embedding.
        """
        h = x
        for layer_idx in range(len(self.convs)):
            h = self.convs[layer_idx](h, edge_index)
            h = F.relu(self.bns[layer_idx](h))
            # Functional dropout needs the training flag passed explicitly.
            h = F.dropout(h, p=self.dropout, training=self.training)

        # Collapse node embeddings into one embedding per graph, then classify.
        graph_embedding = global_add_pool(h, batch)
        return self.classifier(graph_embedding)

# Build the network from the dataset's own dimensions, then move it to the
# selected device before the optimizer captures its parameters.
model = GIN(
    dataset.num_node_features,
    64,
    dataset.num_classes,
    num_layers=3,
    dropout=0.5,
)
model = model.to(device)

# Adam with L2 regularisation applied through weight_decay.
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3, weight_decay=5e-4)



In [None]:
# Train / Eval loops

@torch.no_grad()
def evaluate(loader: DataLoader):
    """Score the notebook-level ``model`` on every batch of ``loader``.

    Puts the model in eval mode and runs without gradient tracking.

    Args:
        loader: Loader yielding batches with ``x``, ``edge_index``, ``batch``
            and ``y`` attributes.

    Returns:
        Tuple ``(mean_loss, accuracy)`` averaged over all labels seen. Both
        are 0 when the loader yields nothing.
    """
    model.eval()
    n_seen, n_correct, weighted_loss = 0, 0, 0.0

    for graphs in loader:
        graphs = graphs.to(device)

        # Flatten labels so both [batch_size] and [batch_size, 1] layouts work.
        targets = graphs.y.view(-1).long()
        scores = model(graphs.x.float(), graphs.edge_index, graphs.batch)
        batch_loss = F.cross_entropy(scores, targets)

        batch_size = targets.numel()
        hits = (scores.argmax(dim=-1) == targets).sum().item()

        n_seen += batch_size
        n_correct += int(hits)
        # Re-weight the batch mean by batch size so uneven batches average correctly.
        weighted_loss += float(batch_loss.item()) * batch_size

    # Avoid dividing by zero on an empty loader.
    denom = max(n_seen, 1)
    return weighted_loss / denom, n_correct / denom

def train_one_epoch(loader: DataLoader):
    """Run one optimisation pass of the notebook-level ``model`` over ``loader``.

    Puts the model in train mode and takes one ``optimizer`` step per batch.

    Args:
        loader: Loader yielding batches with ``x``, ``edge_index``, ``batch``
            and ``y`` attributes.

    Returns:
        Training loss averaged over all labels seen (0 for an empty loader).
    """
    model.train()
    n_seen = 0
    weighted_loss = 0.0

    for graphs in loader:
        graphs = graphs.to(device)
        targets = graphs.y.view(-1).long()

        # Clear gradients left over from the previous step before the new pass.
        optimizer.zero_grad()
        scores = model(graphs.x.float(), graphs.edge_index, graphs.batch)
        batch_loss = F.cross_entropy(scores, targets)
        batch_loss.backward()
        optimizer.step()

        batch_size = targets.numel()
        n_seen += batch_size
        # Re-weight the batch mean by batch size so uneven batches average correctly.
        weighted_loss += float(batch_loss.item()) * batch_size

    return weighted_loss / max(n_seen, 1)



In [6]:

# Training

# Track the best validation accuracy and the weights that produced it.
best_val_acc, best_state = 0.0, None

epochs = 100
for epoch in range(1, epochs + 1):
    train_loss = train_one_epoch(train_loader)
    val_loss, val_acc = evaluate(val_loader)

    # Strict comparison: on ties the earliest best checkpoint is kept.
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        # Clone onto the CPU so later training steps cannot alter the snapshot.
        best_state = {
            name: tensor.detach().cpu().clone()
            for name, tensor in model.state_dict().items()
        }

    # Log the first epoch and then every tenth one.
    if epoch == 1 or epoch % 10 == 0:
        print(
            f"Epoch {epoch:03d} | train_loss={train_loss:.4f} | "
            f"val_loss={val_loss:.4f} | val_acc={val_acc:.4f}"
        )

# Restore the best checkpoint; skipped if validation accuracy never rose above 0.
if best_state is not None:
    model.load_state_dict(best_state)

test_loss, test_acc = evaluate(test_loader)
print(f"\nBest val_acc={best_val_acc:.4f}")
print(f"Test: loss={test_loss:.4f} | acc={test_acc:.4f}")


Epoch 001 | train_loss=2.1231 | val_loss=0.9108 | val_acc=0.5000
Epoch 010 | train_loss=0.5635 | val_loss=0.4107 | val_acc=0.7222
Epoch 020 | train_loss=0.3513 | val_loss=0.4434 | val_acc=0.8889
Epoch 030 | train_loss=0.3616 | val_loss=0.4631 | val_acc=0.8333
Epoch 040 | train_loss=0.3742 | val_loss=0.4431 | val_acc=0.8333
Epoch 050 | train_loss=0.3726 | val_loss=0.4302 | val_acc=0.8333
Epoch 060 | train_loss=0.3332 | val_loss=0.5142 | val_acc=0.7778
Epoch 070 | train_loss=0.3015 | val_loss=0.4217 | val_acc=0.8889
Epoch 080 | train_loss=0.3215 | val_loss=0.4977 | val_acc=0.6667
Epoch 090 | train_loss=0.3656 | val_loss=0.4559 | val_acc=0.7222
Epoch 100 | train_loss=0.2955 | val_loss=0.5220 | val_acc=0.8889

Best val_acc=0.9444
Test: loss=0.4062 | acc=0.8000
