Name - Nidhi Shukla Roll no - 24AI60R47

Group- 17

In [10]:
# install required packages here
!pip3 install numpy
!pip3 install torch
!pip3 install networkx
!pip3 install matplotlib
!pip install torch_geometric



In [12]:

# Import required packages
import os
import torch
import numpy as np
from tqdm import tqdm
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, random_split
from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader
from torch_geometric.utils import to_dense_adj
from torch_geometric.nn.pool import global_mean_pool
from torch_geometric.nn import GCNConv, TopKPooling


# Implementation of GCN Model
class GCN(nn.Module):
    '''
    Implementation of GCN [Kipf et al.] as Basic GNN module.
    '''
    def __init__(self, in_channels, out_channels):
        super(GCN, self).__init__()
        self.conv = GCNConv(in_channels, out_channels)

    def forward(self, x, edge_index):
        x = self.conv(x, edge_index)
        x = F.relu(x)
        return x


# Implementation of DownSamplePool Model
class DownSamplePool(nn.Module):
    '''
    Implementation of DownSample & Pool Module.
    '''
    def __init__(self, in_channels, out_channels, k):
        super(DownSamplePool, self).__init__()
        # Top-K pooling layer, which performs adaptive node selection
        self.pool = TopKPooling(in_channels, ratio=k)

    def forward(self, x, edge_index, batch):
        # Down-sampling important nodes
        x, edge_index, _, batch, _, _ = self.pool(x, edge_index, None, batch)
        return x, edge_index, batch


# Hierarchical Pooling Class
class HierarchicalPooling(nn.Module):
    '''
    Adds multiple down-sample pooling layers based on "m" parameter.
    '''
    def __init__(self, in_channels, hidden_channels, num_pools):
        super(HierarchicalPooling, self).__init__()
        self.pools = nn.ModuleList([DownSamplePool(in_channels, hidden_channels, k=0.5) for _ in range(num_pools)])

    def forward(self, x, edge_index, batch):
        for pool in self.pools:
            x, edge_index, batch = pool(x, edge_index, batch)
        return x, edge_index, batch


# Complete Model Architecture
class Model(nn.Module):
    '''
    Architecture of Overall Graph Classifier.
    '''
    def __init__(self, in_channels, hidden_channels, out_channels, num_classes, k1, k2, m1, m2):
        super(Model, self).__init__()

        # GNN layers before the first down-sample & pool
        self.gnn1 = GCN(in_channels, hidden_channels)
        self.gnn2 = GCN(hidden_channels, hidden_channels)

        # First Down-Sample & Pool
        self.pool1 = DownSamplePool(hidden_channels, hidden_channels, k=k1)

        # Hierarchical Pooling (after first pool)
        self.hier_pool1 = HierarchicalPooling(hidden_channels, hidden_channels, num_pools=m1)

        # GNN layers before the second down-sample & pool
        self.gnn3 = GCN(hidden_channels, hidden_channels)
        self.gnn4 = GCN(hidden_channels, hidden_channels)

        # Second Down-Sample & Pool
        self.pool2 = DownSamplePool(hidden_channels, hidden_channels, k=k2)

        # Hierarchical Pooling (after second pool)
        self.hier_pool2 = HierarchicalPooling(hidden_channels, hidden_channels, num_pools=m2)

        # Final Classification Head
        self.fc = nn.Linear(hidden_channels, num_classes)

    def forward(self, x, edge_index, batch):
        # GNN Block 1
        x = self.gnn1(x, edge_index)
        x = self.gnn2(x, edge_index)

        # Down-Sample & Pool 1
        x, edge_index, batch = self.pool1(x, edge_index, batch)

        # Hierarchical Pooling 1
        x, edge_index, batch = self.hier_pool1(x, edge_index, batch)

        # GNN Block 2
        x = self.gnn3(x, edge_index)
        x = self.gnn4(x, edge_index)

        # Down-Sample & Pool 2
        x, edge_index, batch = self.pool2(x, edge_index, batch)

        # Hierarchical Pooling 2
        x, edge_index, batch = self.hier_pool2(x, edge_index, batch)

        # Global Pooling (Graph Level Readout)
        x = global_mean_pool(x, batch)

        # Final classification
        x = self.fc(x)

        return F.log_softmax(x, dim=1)


# Utility functions for dataset loading and training
def load_dataset(name, split_ratio=(0.8, 0.1, 0.1)):
    dataset = TUDataset(root='/tmp/' + name, name=name)
    train_size = int(split_ratio[0] * len(dataset))
    val_size = int(split_ratio[1] * len(dataset))
    test_size = len(dataset) - train_size - val_size
    return random_split(dataset, [train_size, val_size, test_size])


def get_data_loaders(dataset_splits, batch_size=32):
    train_loader = DataLoader(dataset_splits[0], batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(dataset_splits[1], batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(dataset_splits[2], batch_size=batch_size, shuffle=False)
    return train_loader, val_loader, test_loader


# Accuracy calculation function
def calculate_accuracy(loader, model):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for batch in loader:
            x, edge_index, batch_idx = batch.x, batch.edge_index, batch.batch
            out = model(x, edge_index, batch_idx)
            _, predicted = out.max(1)
            total += batch.y.size(0)
            correct += (predicted == batch.y).sum().item()
    accuracy = (correct / total) * 100
    return accuracy


# Main training function
def main(dataset_name, num_classes, epochs, k1, k2, m1, m2):
    # Hyperparameters
    # Get the number of features from the dataset
    dataset = TUDataset(root='/tmp/' + dataset_name, name=dataset_name)
    in_channels = dataset.num_node_features  # Get the actual number of features
    hidden_channels = 64
    out_channels = 64
    learning_rate = 0.001

    # Load and prepare dataset
    dataset_splits = load_dataset(dataset_name)
    train_loader, val_loader, test_loader = get_data_loaders(dataset_splits)

    # Initialize model, loss, optimizer
    model = Model(in_channels, hidden_channels, out_channels, num_classes, k1=k1, k2=k2, m1=m1, m2=m2)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    criterion = nn.CrossEntropyLoss()

    # Training loop
    for epoch in range(epochs):
        model.train()
        t = 0
        for batch in train_loader:
            x, edge_index, batch_idx = batch.x, batch.edge_index, batch.batch
            optimizer.zero_grad()
            out = model(x, edge_index, batch_idx)
            loss = criterion(out, batch.y)
            loss.backward()
            optimizer.step()
            t += loss.item()

        print(f'Epoch {epoch+1}/{epochs}, Loss: {t/10:.4f}')

        # Validation accuracy
        val_accuracy = calculate_accuracy(val_loader, model)
        print(f'Validation Accuracy: {val_accuracy:.4f}')

    # Test accuracy
    test_accuracy = calculate_accuracy(test_loader, model)
    print(f'Test Accuracy: {test_accuracy:.4f}')
    return test_accuracy


# Experiment configurations
if __name__ == "__main__":
    # Experiment configurations
    k1_values = [0.9, 0.8, 0.6]  # First downsampling ratio
    k2_values = [0.9, 0.8, 0.6]  # Second downsampling ratio
    m1_values = [6, 3]  # Number of hierarchical pooling layers after first down-sample
    m2_values = [6, 3]  # Number of hierarchical pooling layers after second down-sample

    # Dataset configurations
    datasets = {
        'DD': 2,  # Binary Classification
        'ENZYMES': 6  # 6-Class Classification
    }

    results = {}

    # Running experiments for both datasets
    for dataset_name, num_classes in datasets.items():
        print(f"\nTraining on {dataset_name} Dataset...")
        for k1 in k1_values:
            for k2 in k2_values:
                for m1 in m1_values:
                    for m2 in m2_values:
                        print(f"Running with k1={k1}, k2={k2}, m1={m1}, m2={m2}")
                        test_accuracy = main(dataset_name, num_classes, epochs=10, k1=k1, k2=k2, m1=m1, m2=m2)
                        results[f'{dataset_name}_k1={k1}_k2={k2}_m1={m1}_m2={m2}'] = test_accuracy

    # Print all results
    print("\nFinal Results:")
    for config, accuracy in results.items():
        print(f"{config}: Test Accuracy = {accuracy:.4f}")



Training on DD Dataset...
Running with k1=0.9, k2=0.9, m1=6, m2=6


Downloading https://www.chrsmrrs.com/graphkerneldatasets/DD.zip
Processing...
Done!


Epoch 1/10, Loss: 2.0304
Validation Accuracy: 54.7009
Epoch 2/10, Loss: 2.0260
Validation Accuracy: 54.7009
Epoch 3/10, Loss: 2.0223
Validation Accuracy: 54.7009
Epoch 4/10, Loss: 2.0252
Validation Accuracy: 54.7009
Epoch 5/10, Loss: 2.0188
Validation Accuracy: 54.7009
Epoch 6/10, Loss: 2.0176
Validation Accuracy: 54.7009
Epoch 7/10, Loss: 2.0165
Validation Accuracy: 54.7009
Epoch 8/10, Loss: 2.0202
Validation Accuracy: 54.7009
Epoch 9/10, Loss: 2.0183
Validation Accuracy: 54.7009
Epoch 10/10, Loss: 2.0166
Validation Accuracy: 54.7009
Test Accuracy: 50.4202
Running with k1=0.9, k2=0.9, m1=6, m2=3
Epoch 1/10, Loss: 2.0483
Validation Accuracy: 56.4103
Epoch 2/10, Loss: 2.0447
Validation Accuracy: 56.4103
Epoch 3/10, Loss: 2.0416
Validation Accuracy: 56.4103
Epoch 4/10, Loss: 2.0415
Validation Accuracy: 56.4103
Epoch 5/10, Loss: 2.0411
Validation Accuracy: 56.4103
Epoch 6/10, Loss: 2.0370
Validation Accuracy: 56.4103
Epoch 7/10, Loss: 2.0396
Validation Accuracy: 56.4103
Epoch 8/10, Loss: 

Downloading https://www.chrsmrrs.com/graphkerneldatasets/ENZYMES.zip
Processing...
Done!


Epoch 1/10, Loss: 2.6927
Validation Accuracy: 23.3333
Epoch 2/10, Loss: 2.6920
Validation Accuracy: 23.3333
Epoch 3/10, Loss: 2.6917
Validation Accuracy: 23.3333
Epoch 4/10, Loss: 2.6913
Validation Accuracy: 23.3333
Epoch 5/10, Loss: 2.6910
Validation Accuracy: 23.3333
Epoch 6/10, Loss: 2.6906
Validation Accuracy: 23.3333
Epoch 7/10, Loss: 2.6904
Validation Accuracy: 23.3333
Epoch 8/10, Loss: 2.6901
Validation Accuracy: 23.3333
Epoch 9/10, Loss: 2.6898
Validation Accuracy: 23.3333
Epoch 10/10, Loss: 2.6895
Validation Accuracy: 23.3333
Test Accuracy: 15.0000
Running with k1=0.9, k2=0.9, m1=6, m2=3
Epoch 1/10, Loss: 2.6892
Validation Accuracy: 10.0000
Epoch 2/10, Loss: 2.6889
Validation Accuracy: 10.0000
Epoch 3/10, Loss: 2.6888
Validation Accuracy: 10.0000
Epoch 4/10, Loss: 2.6886
Validation Accuracy: 10.0000
Epoch 5/10, Loss: 2.6885
Validation Accuracy: 10.0000
Epoch 6/10, Loss: 2.6885
Validation Accuracy: 10.0000
Epoch 7/10, Loss: 2.6883
Validation Accuracy: 10.0000
Epoch 8/10, Loss: 