In [None]:
# Pinned to the version recorded in the install log; %pip targets the running kernel's environment.
%pip install -q torch_geometric==2.5.3

Collecting torch_geometric
  Downloading torch_geometric-2.5.3-py3-none-any.whl (1.1 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m15.0 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: torch_geometric
Successfully installed torch_geometric-2.5.3


In [None]:
import itertools
import time

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, random_split
from torch_geometric.data import Data
from torch_geometric.datasets.mnist_superpixels import MNISTSuperpixels
from torch_geometric.nn import GATConv
from tqdm import tqdm

In [None]:
# Seed torch's global RNG (all devices) so weight initialisation and loader
# shuffling are reproducible under "Restart & Run All".
seed = 42
torch.manual_seed(seed)

# Node features are data.x concatenated with data.pos (see collate_fn below);
# for MNISTSuperpixels that is 1 intensity value + 2 coordinates = 3.
in_chls = 3
num_classes = 10  # digits 0-9
num_epochs = 2
learning_rate = 3e-3
batch_size = 64

In [None]:
# Image pipeline for the CNN: conversion to tensor only, no normalisation step.
transform = transforms.Compose(transforms=[transforms.ToTensor()])

In [None]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [None]:
def build_collate_fn(device: "str | torch.device"):
    """Return a DataLoader ``collate_fn`` that keeps each graph in a batch separate.

    Graphs are not merged into one disjoint graph. Instead the collated batch
    holds parallel per-graph lists, which ``GNN.forward`` iterates over.

    Args:
        device: Device every produced tensor is moved to.

    Returns:
        A callable that maps a list of graph samples to a dict with keys
        ``"batch_node_features"`` (list of per-graph tensors: ``x`` concatenated
        with ``pos`` along the last dim), ``"batch_edge_indices"`` (list of
        per-graph ``edge_index`` tensors) and ``"classes"`` (1-D ``long`` tensor
        of labels).

    Note:
        Tensors are moved to ``device`` inside the collate function, i.e. in the
        DataLoader's loading process. Keep ``num_workers=0`` when ``device`` is CUDA.
    """
    def collate_fn(original_batch: "list[Data]"):
        batch_node_features: list[torch.Tensor] = []
        batch_edge_indices: list[torch.Tensor] = []
        classes: list[int] = []

        for data in original_batch:
            # Per-node features followed by the node's position coordinates.
            batch_node_features.append(torch.cat((data.x, data.pos), dim=-1).to(device))
            batch_edge_indices.append(data.edge_index.to(device))
            classes.append(int(data.y))

        return {
            "batch_node_features": batch_node_features,
            "batch_edge_indices": batch_edge_indices,
            # torch.tensor replaces the legacy torch.LongTensor constructor and
            # allocates directly on the target device (no extra host->device copy).
            "classes": torch.tensor(classes, dtype=torch.long, device=device),
        }
    return collate_fn

# Both splits of MNISTSuperpixels live under the same root directory.
train_dataset = MNISTSuperpixels(root="mnist-superpixels-dataset", train=True)
test_dataset = MNISTSuperpixels(root="mnist-superpixels-dataset", train=False)

# The visualisation section only reads graphs from the test split, so reuse the
# already-loaded object instead of constructing the same split a second time.
graph_dataset = test_dataset

# 80/20 train/validation split. The fixed-seed generator puts the same graphs
# in the validation set on every run.
train_dataset, val_dataset = torch.utils.data.random_split(
    train_dataset, [0.8, 0.2], generator=torch.Generator().manual_seed(42)
)

Downloading https://data.pyg.org/datasets/MNISTSuperpixels.zip
Extracting mnist-superpixels-dataset/raw/MNISTSuperpixels.zip
Processing...
Done!


In [None]:
# A single collate function serves all three loaders; it only closes over `device`.
gnn_collate_fn = build_collate_fn(device=device)

# Only the training loader is shuffled; validation/test order stays fixed.
gnn_train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True, collate_fn=gnn_collate_fn)
gnn_val_loader = DataLoader(dataset=val_dataset, batch_size=batch_size, shuffle=False, collate_fn=gnn_collate_fn)
gnn_test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False, collate_fn=gnn_collate_fn)

# GNN Architecture

In [None]:
class GNN(nn.Module):
    """Graph-level classifier built from four stacked GAT layers.

    Each graph is encoded on its own. The raw node features and the activated
    output of every GAT layer are concatenated per node (skip connections),
    averaged over the graph's nodes, and passed through a 3-layer MLP head.

    Args:
        in_chls: Number of input features per node.
        hidden_dim: Output width of every GAT layer.
        num_classes: Number of output logits.
    """

    def __init__(self, in_chls, hidden_dim, num_classes):
        super().__init__()
        self.in_channels = in_chls
        self.hidden_dim = hidden_dim
        self.num_classes = num_classes
        self.conv1 = GATConv(in_chls, hidden_dim)
        self.conv2 = GATConv(hidden_dim, hidden_dim)
        self.conv3 = GATConv(hidden_dim, hidden_dim)
        self.conv4 = GATConv(hidden_dim, hidden_dim)
        # Readout input = raw node features + the four GAT outputs, concatenated.
        self.l1 = nn.Linear(in_chls + 4 * hidden_dim, 256)
        self.l2 = nn.Linear(256, 128)
        self.l3 = nn.Linear(128, num_classes)
        # Out-of-place ReLU. Each activated GAT output is reused (next layer and
        # skip concat), and avoiding in-place ops sidesteps autograd conflicts
        # with however GATConv builds its output tensor.
        self.relu = nn.ReLU()

    def forward_one_base(self, node_features, edge_indices):
        """Encode a single graph into per-node features.

        Args:
            node_features: ``[num_nodes, in_channels]`` tensor.
            edge_indices: ``[2, num_edges]`` edge-index tensor.

        Returns:
            ``[num_nodes, in_channels + 4 * hidden_dim]`` tensor: the input
            features followed by the ReLU-activated output of each GAT layer.
        """
        assert node_features.ndim == 2 and node_features.shape[1] == self.in_channels
        assert edge_indices.ndim == 2 and edge_indices.shape[0] == 2

        s0 = node_features
        # Non-linearity between GAT layers. Without it, the attention weights are
        # the only non-linearity in the stack. The logged run without it stayed at
        # ~2.30 loss (ln 10, i.e. chance level).
        s1 = self.relu(self.conv1(s0, edge_indices))
        s2 = self.relu(self.conv2(s1, edge_indices))
        s3 = self.relu(self.conv3(s2, edge_indices))
        s4 = self.relu(self.conv4(s3, edge_indices))
        return torch.cat((s0, s1, s2, s3, s4), dim=-1)

    def forward(self, batch_node_features, batch_edge_indices):
        """Classify a batch of graphs given as parallel per-graph lists.

        Args:
            batch_node_features: list of ``[num_nodes_i, in_channels]`` tensors.
            batch_edge_indices: list of ``[2, num_edges_i]`` tensors, same length.

        Returns:
            ``[batch_size, num_classes]`` logits.
        """
        assert len(batch_node_features) == len(batch_edge_indices)
        # Mean-pool each graph before stacking, so graphs in one batch may have
        # different node counts. Stacking per-node features first required
        # every graph to have exactly the same number of nodes.
        pooled = [
            self.forward_one_base(node_features=nf, edge_indices=ei).mean(dim=0)
            for nf, ei in zip(batch_node_features, batch_edge_indices)
        ]
        features = torch.stack(pooled, dim=0)
        hidden = self.relu(self.l1(features))
        hidden = self.relu(self.l2(hidden))
        return self.l3(hidden)

In [None]:
# Baseline model: 152 hidden units per GAT layer.
GNNmodel = GNN(in_chls=in_chls, hidden_dim=152, num_classes=num_classes)
GNNmodel = GNNmodel.to(device)

# Training Loop GNN

In [None]:
def trainGNN(model, learning_rate, num_epochs, train_loader=None, val_loader=None):
    """Train a GNN with Adam and cross-entropy, tracking per-epoch mean losses.

    Args:
        model: Module called as ``model(batch_node_features, batch_edge_indices)``.
        learning_rate: Adam learning rate.
        num_epochs: Number of passes over the training loader.
        train_loader: Iterable of collated batches (dicts with keys
            ``'batch_node_features'``, ``'batch_edge_indices'``, ``'classes'``).
            Defaults to the notebook's ``gnn_train_loader``.
        val_loader: Same format, used for the per-epoch validation loss.
            Defaults to the notebook's ``gnn_val_loader``.

    Returns:
        ``(model, train_losses, val_losses, elapsed_time)``. The loss lists hold
        one mean-over-batches value per epoch; ``elapsed_time`` is the
        wall-clock training time in seconds.
    """
    if train_loader is None:
        train_loader = gnn_train_loader
    if val_loader is None:
        val_loader = gnn_val_loader

    train_losses = []
    val_losses = []
    optimizer = optim.Adam(model.parameters(), learning_rate)
    criterion = nn.CrossEntropyLoss()
    start_time = time.time()
    for epoch in range(num_epochs):
        model.train()
        tr_loss = 0.0
        for batch in train_loader:
            node_features, edge_features = batch['batch_node_features'], batch['batch_edge_indices']
            logits = model(node_features, edge_features)
            optimizer.zero_grad()
            loss = criterion(logits, batch['classes'])
            loss.backward()
            optimizer.step()
            tr_loss += loss.item()

        model.eval()
        with torch.no_grad():
            vl_loss = 0.0
            for batch in val_loader:
                node_features, edge_features = batch['batch_node_features'], batch['batch_edge_indices']
                logits = model(node_features, edge_features)
                loss = criterion(logits, batch['classes'])
                vl_loss += loss.item()

        train_losses.append(tr_loss / len(train_loader))
        val_losses.append(vl_loss / len(val_loader))
        print(f'epoch : {epoch+1}/{num_epochs} || train loss : {train_losses[-1]} || validation loss : {val_losses[-1]}')

    elapsed_time = time.time() - start_time
    print(f"Training completed in {elapsed_time} seconds")

    # Return the local timing. The caller binds it to `elapsed_time_gnn`; that
    # name does not exist inside this function.
    return model, train_losses, val_losses, elapsed_time

In [None]:
trained_gnn_model, train_losses_gnn, val_losses_gnn, elapsed_time_gnn = trainGNN(
    model=GNNmodel,
    learning_rate=learning_rate,
    num_epochs=num_epochs,
)

epoch : 1/2 || train loss : 2.304510418256124 || validation loss : 2.3003661569128644
epoch : 2/2 || train loss : 2.3018108517328897 || validation loss : 2.300649809076431
Training completed in 850.387220621109 seconds


NameError: name 'elapsed_time_gnn' is not defined

# Hyperparameter Tuning for GNN

In [None]:
# Small grid search over learning rate x GAT hidden width. The model with the
# best validation accuracy is kept in `best_model_gnn` for the final comparison.
# The loop variable is `lr`, NOT `learning_rate`. Reusing that name overwrote the
# config-cell global, so later cells (e.g. the CNN training call) silently used
# the last grid value instead.
learning_rates = [0.003, 0.001]
hidden_dims = [64, 128]
best_val_accuracy = 0.0
best_model_gnn = None
best_hparams_gnn = None  # (lr, hidden_dim) of best_model_gnn

for lr, hidden_dim in itertools.product(learning_rates, hidden_dims):

    model = GNN(in_chls, hidden_dim, num_classes).to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    print(f"\nTraining with learning rate: {lr}, hidden dimension: {hidden_dim}")

    start_time = time.time()
    for epoch in range(num_epochs):
        model.train()
        for batch in tqdm(gnn_train_loader, desc=f"Epoch {epoch+1} - Training"):
            node_features, edge_features = batch['batch_node_features'], batch['batch_edge_indices']
            optimizer.zero_grad()
            output = model(node_features, edge_features)
            loss = criterion(output, batch['classes'])
            loss.backward()
            optimizer.step()

    # Evaluate on the validation set once, after all epochs for this configuration.
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for batch in tqdm(gnn_val_loader, desc="Validation"):
            node_features, edge_features = batch['batch_node_features'], batch['batch_edge_indices']
            output = model(node_features, edge_features)
            predicted = output.argmax(dim=1)
            total += len(node_features)
            correct += (predicted == batch['classes']).sum().item()

    val_accuracy = correct / total

    # The `is None` check guarantees a model is selected even if every
    # configuration scores 0.0, so the evaluation cell never receives None.
    if best_model_gnn is None or val_accuracy > best_val_accuracy:
        best_val_accuracy = val_accuracy
        best_model_gnn = model
        best_hparams_gnn = (lr, hidden_dim)

    print(f"Validation Accuracy: {val_accuracy}, Time: {time.time() - start_time}s")

print(f"Best GNN (lr, hidden_dim): {best_hparams_gnn}, validation accuracy: {best_val_accuracy}")


In [None]:
# Per-epoch mean training vs. validation loss of the baseline GNN run.
fig, ax = plt.subplots()
ax.plot(train_losses_gnn, label='Training Loss')
ax.plot(val_losses_gnn, label='Validation Loss')
ax.set_title('Loss Curves GNN', weight='bold')
ax.set_xlabel('Epochs')
ax.set_ylabel('Loss')
ax.legend()
plt.show()

# Dataset Split CNN

In [None]:
# Both raw-image MNIST splits share the same root, transform and download settings.
mnist_kwargs = dict(root='./data', transform=transform, download=True)
train_dataset_cnn = torchvision.datasets.MNIST(train=True, **mnist_kwargs)
test_dataset_cnn = torchvision.datasets.MNIST(train=False, **mnist_kwargs)

In [None]:
# 80/20 train/validation split of the raw-image training set.
# The fixed-seed generator makes the split reproducible across runs.
# NOTE: this reassigns `train_dataset_cnn`. Re-running only this cell would split
# the already-split subset again, so re-run the dataset-loading cell first.
train_size = int(0.8 * len(train_dataset_cnn))
val_size = len(train_dataset_cnn) - train_size
train_dataset_cnn, val_dataset_cnn = random_split(
    train_dataset_cnn, [train_size, val_size], generator=torch.Generator().manual_seed(42)
)

In [None]:
# Only the training loader is shuffled; validation/test order stays fixed.
cnn_train_loader = DataLoader(dataset=train_dataset_cnn, batch_size=batch_size, shuffle=True)
cnn_val_loader = DataLoader(dataset=val_dataset_cnn, batch_size=batch_size, shuffle=False)
cnn_test_loader = DataLoader(dataset=test_dataset_cnn, batch_size=batch_size, shuffle=False)

# CNN Architecture

In [None]:
class CNN_block(nn.Module):
    """Conv2d (no padding) -> BatchNorm2d -> ReLU.

    With no padding, each block shrinks both spatial dims by ``kernel_size - 1``.
    """

    def __init__(self, in_chls, out_chls, kernel_size):
        super(CNN_block, self).__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_chls, out_chls, kernel_size),
            nn.BatchNorm2d(out_chls),
            nn.ReLU()
        )

    def forward(self, x):
        return self.block(x)


class CNN(nn.Module):
    """Four conv blocks (16/32/64/128 channels, 3x3 kernels) plus a linear classifier.

    Args:
        in_chls: Number of input image channels.
        num_classes: Number of output logits.
        image_size: Height/width of the square input images. The classifier's
            input width is derived from it. The default 28 reproduces the
            original ``128 * 20 * 20`` layer.

    Raises:
        ValueError: if ``image_size`` is too small for four unpadded 3x3 convs.
    """

    def __init__(self, in_chls, num_classes, image_size=28):
        super(CNN, self).__init__()
        self.in_chls = in_chls
        self.num_classes = num_classes
        self.nblocks = nn.Sequential(
            CNN_block(in_chls, 16, 3),
            CNN_block(16, 32, 3),
            CNN_block(32, 64, 3),
            CNN_block(64, 128, 3)
        )

        # Each unpadded 3x3 conv trims 2 pixels per spatial dim: 4 blocks -> -8.
        out_size = image_size - 4 * (3 - 1)
        if out_size <= 0:
            raise ValueError(f"image_size={image_size} is too small for 4 unpadded 3x3 convolutions")
        self.fc = nn.Linear(128 * out_size * out_size, num_classes)

    def forward(self, x):
        x = self.nblocks(x)
        x = torch.flatten(x, start_dim=1)
        return self.fc(x)

In [None]:
# Single-channel (grayscale) MNIST images, hence in_chls=1.
CNNmodel = CNN(in_chls=1, num_classes=num_classes)
CNNmodel = CNNmodel.to(device)

# Training Loop CNN

In [None]:
def trainCNN(model, learning_rate, num_epochs, train_loader=None, val_loader=None):
    """Train an image classifier with Adam and cross-entropy, tracking per-epoch mean losses.

    Args:
        model: Module mapping an image batch to logits. Its parameters' device
            decides where each batch is sent.
        learning_rate: Adam learning rate.
        num_epochs: Number of passes over the training loader.
        train_loader: Iterable of ``(images, labels)`` batches. Defaults to the
            notebook's ``cnn_train_loader``.
        val_loader: Same format, used for the per-epoch validation loss.
            Defaults to the notebook's ``cnn_val_loader``.

    Returns:
        ``(model, train_losses_cnn, val_losses_cnn, elapsed_time)``. The loss
        lists hold one mean-over-batches value per epoch; ``elapsed_time`` is
        the wall-clock training time in seconds.
    """
    if train_loader is None:
        train_loader = cnn_train_loader
    if val_loader is None:
        val_loader = cnn_val_loader
    # Send batches to wherever the model lives instead of relying on a global.
    device = next(model.parameters()).device

    train_losses_cnn = []
    val_losses_cnn = []
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    criterion = nn.CrossEntropyLoss()
    start_time = time.time()

    for epoch in range(num_epochs):
        model.train()
        tr_loss = 0.0
        for img, label in train_loader:
            img = img.to(device)
            label = label.to(device)
            optimizer.zero_grad()
            logits = model(img)
            loss = criterion(logits, label)
            loss.backward()
            optimizer.step()
            tr_loss += loss.item()

        model.eval()
        with torch.no_grad():
            vl_loss = 0.0
            for img, label in val_loader:
                img = img.to(device)
                label = label.to(device)
                logits = model(img)
                loss = criterion(logits, label)
                vl_loss += loss.item()

        train_losses_cnn.append(tr_loss / len(train_loader))
        val_losses_cnn.append(vl_loss / len(val_loader))
        print(f'epoch : {epoch+1}/{num_epochs} || train loss : {train_losses_cnn[-1]} || validation loss : {val_losses_cnn[-1]}')

    elapsed_time = time.time() - start_time
    print(f"Training completed in {elapsed_time} seconds")

    # Return the local timing. The caller binds it to `elapsed_time_cnn`; that
    # name does not exist inside this function.
    return model, train_losses_cnn, val_losses_cnn, elapsed_time


In [None]:
trained_cnn, train_losses_cnn, val_losses_cnn, elapsed_time_cnn = trainCNN(
    model=CNNmodel,
    learning_rate=learning_rate,
    num_epochs=num_epochs,
)

# Visualising Images in Graph Format

In [None]:
def create_superpixel_image(record, scale=30, edge_width=1):
    """Render one superpixel graph as a grayscale image using numpy only.

    The original version called ``np``/``cv2``, neither of which this notebook
    imports. Drawing is now done with plain array writes, clipped to the canvas.

    Each node becomes a filled ``scale x scale`` square whose bottom-right corner
    sits at the node position times ``scale`` (x = column, y = row). Its gray
    level is ``min(int((intensity + 0.15) * 255), 255)``. Edges are then drawn
    on top, in gray level 125, as straight lines between square centres.

    Args:
        record: Graph with ``x`` (one intensity value per node), ``pos``
            (per-node x/y positions) and ``edge_index`` (``[2, num_edges]``).
        scale: Pixels per position unit; also the side length of each square.
        edge_width: Line thickness in pixels (odd widths are exact, even widths
            are rounded up to the next odd width).

    Returns:
        ``uint8`` array of shape ``(26 * scale, 26 * scale, 1)``.
    """
    size = scale * 26
    image = np.zeros((size, size, 1), dtype=np.uint8)
    pos = (record.pos.clone() * scale).int()

    def _fill(row0, row1, col0, col1, value):
        # Inclusive pixel ranges clipped to the canvas; unclipped negative
        # indices would wrap around in numpy slicing.
        row0, col0 = max(row0, 0), max(col0, 0)
        row1, col1 = min(row1, size - 1), min(col1, size - 1)
        if row0 <= row1 and col0 <= col1:
            image[row0:row1 + 1, col0:col1 + 1, 0] = value

    # Draw a filled square for each superpixel.
    for color, (x, y) in zip(record.x, pos):
        x0, y0 = int(x), int(y)
        gray = min(int(float(color + 0.15) * 255), 255)
        _fill(y0 - scale, y0, x0 - scale, x0, gray)

    # Sample every edge as a straight pixel line between square centres,
    # then paint all line pixels at once (vectorised, clipped to the canvas).
    half = scale // 2
    radius = edge_width // 2
    line_rows, line_cols = [], []
    for node_ix_0, node_ix_1 in record.edge_index.T:
        x0, y0 = (int(v) - half for v in pos[node_ix_0])
        x1, y1 = (int(v) - half for v in pos[node_ix_1])
        steps = max(abs(x1 - x0), abs(y1 - y0)) + 1
        line_cols.append(np.rint(np.linspace(x0, x1, steps)).astype(np.int64))
        line_rows.append(np.rint(np.linspace(y0, y1, steps)).astype(np.int64))

    if line_rows:
        rows = np.concatenate(line_rows)
        cols = np.concatenate(line_cols)
        for dr in range(-radius, radius + 1):
            for dc in range(-radius, radius + 1):
                r, c = rows + dr, cols + dc
                inside = (r >= 0) & (r < size) & (c >= 0) & (c < size)
                image[r[inside], c[inside], 0] = 125

    return image

def visualize_superpixels(dataset, examples_per_class=5, classes=tuple(range(10)), figsize=(30, 50), edge_width=1):
    """Plot a grid of rendered superpixel graphs: one row per class.

    Args:
        dataset: Iterable of graph records with ``y`` (class label) plus the
            fields read by ``create_superpixel_image``.
        examples_per_class: Maximum number of examples shown per class (columns).
        classes: Class labels to show, in row order; other labels are skipped.
        figsize: Matplotlib figure size in inches.
        edge_width: Edge thickness passed to ``create_superpixel_image``.
    """
    class_to_examples = {class_ix: [] for class_ix in classes}
    remaining = examples_per_class * len(class_to_examples)

    # Collect examples for each class. Stop as soon as every row is full
    # instead of scanning (and indexing) the entire dataset.
    for record in dataset:
        if remaining <= 0:
            break
        class_ix = int(record.y)
        examples = class_to_examples.get(class_ix)
        if examples is None or len(examples) >= examples_per_class:
            continue
        examples.append(create_superpixel_image(record, edge_width=edge_width))
        remaining -= 1

    # Plot the collected examples
    plt.figure(figsize=figsize)
    for i, class_ix in enumerate(classes):
        for j, example in enumerate(class_to_examples[class_ix]):
            plt.subplot(len(classes), examples_per_class, i * examples_per_class + j + 1)
            plt.imshow(example, cmap=plt.cm.binary)
    plt.show()


In [None]:
# Render the first few test-split graphs of each digit.
visualize_superpixels(dataset=graph_dataset)

# Comparing Performance of GNN and CNN

In [None]:
def evaluate_gnn(model, test_loader, device):
    """Return the classification accuracy (percent) of a GNN on ``test_loader``.

    Args:
        model: Module called as ``model(batch_node_features, batch_edge_indices)``.
        test_loader: Iterable of collated batches (dicts with keys
            ``'batch_node_features'``, ``'batch_edge_indices'``, ``'classes'``).
        device: Device the labels are moved to before comparison (the collate
            function normally places them there already).

    Returns:
        Accuracy as a float in ``[0, 100]``.
    """
    correct = 0
    total = 0
    model.eval()
    with torch.no_grad():
        # Iterate the loader passed in. The previous version ignored this
        # argument and always read the global `gnn_test_loader`.
        for batch in test_loader:
            node_features, edge_features = batch['batch_node_features'], batch['batch_edge_indices']
            labels = batch['classes'].to(device)
            logits = model(node_features, edge_features)
            predicted = logits.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    accuracy = 100 * correct / total
    return accuracy

In [None]:
def evaluate_cnn(model, test_loader, device):
    """Compute the percentage of correctly classified images in ``test_loader``.

    Args:
        model: Image classifier returning one row of logits per image.
        test_loader: Iterable of ``(images, labels)`` batches.
        device: Device that both images and labels are moved to.

    Returns:
        Accuracy in percent (``100 * correct / total``).
    """
    model.eval()
    n_correct, n_seen = 0, 0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            # Index of the highest logit per row is the predicted class.
            predictions = model(images).max(dim=1).indices
            n_seen += labels.size(0)
            n_correct += (predictions == labels).sum().item()
    return 100 * n_correct / n_seen

In [None]:
# `best_model_gnn` is the configuration selected by the grid search.
gnn_accuracy = evaluate_gnn(model=best_model_gnn, test_loader=gnn_test_loader, device=device)
print(f"The accuracy of GNN is : {gnn_accuracy}")

In [None]:
# CNNmodel was trained in place by trainCNN.
cnn_accuracy = evaluate_cnn(model=CNNmodel, test_loader=cnn_test_loader, device=device)
print(f"The accuracy of CNN is : {cnn_accuracy}")