## Coding assignment: Graph Neural Networks (GNN)

Graph structures are ubiquitous in various domains, from social networks to molecular interactions. Understanding these complex relationships requires advanced analytical tools, and Graph Neural Networks (GNNs) provide a powerful framework for extracting meaningful insights from graph-structured data.

In this assignment, you will:

- Gain a foundational understanding of GNNs and their underlying principles,
- Explore their applications in graph analysis, and
- Implement GNNs using state-of-the-art deep learning frameworks.

By the end of this assignment, you will have acquired both theoretical knowledge and hands-on experience in applying GNNs to real-world graph data.

## Environment Setup
For this notebook to run smoothly, make sure your Python environment is set up as follows.

Python version: we recommend Python 3.8 or higher.

Required packages: install the following libraries:

```
torch
torch_geometric
torch_scatter
torch_sparse
torchmetrics
networkx
numpy
jupyter
```

In [4]:
# Install required packages.
import os
import torch
os.environ['TORCH'] = torch.__version__
print(torch.__version__)

!pip install -q torch-scatter -f https://data.pyg.org/whl/torch-${TORCH}.html
!pip install -q torch-sparse -f https://data.pyg.org/whl/torch-${TORCH}.html
!pip install -q git+https://github.com/pyg-team/pytorch_geometric.git

2.6.0+cu124


In this assignment, we will use the **CLUSTER** dataset from `GNNBenchmarkDataset`, introduced in the paper *Benchmarking Graph Neural Networks*. It belongs to the Stochastic Block Model (SBM) datasets, which target node-level graph pattern recognition tasks originally explored by Scarselli et al. (2009). The SBM datasets cover two tasks:

- Graph Pattern Recognition (PATTERN)
- Semi-supervised Graph Clustering (CLUSTER)

The Stochastic Block Model (SBM), as described by Abbe (2017), serves as the foundation of these datasets. SBM is widely used for modeling community structures in social networks, where intra- and inter-community connections are probabilistically controlled. In particular:

- Two nodes within the same community are connected with probability p.
- Nodes belonging to different communities are connected with probability q, which acts as a noise parameter.
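The two connection rules above can be sketched with a small NumPy sampler. This is only an illustration: the function name `sample_sbm`, the block sizes, and the values of `p` and `q` are assumptions chosen for the example, not the settings used to generate CLUSTER.

```python
import numpy as np

def sample_sbm(block_sizes, p, q, seed=0):
    """Sample an undirected SBM graph; returns (adjacency matrix, community labels)."""
    rng = np.random.default_rng(seed)
    labels = np.repeat(np.arange(len(block_sizes)), block_sizes)
    n = labels.size
    # Pairwise edge probability: p inside a community, q across communities.
    same_block = labels[:, None] == labels[None, :]
    probs = np.where(same_block, p, q)
    # Draw each unordered pair once (strict upper triangle), then mirror it.
    upper = np.triu(rng.random((n, n)) < probs, k=1)
    adj = (upper | upper.T).astype(np.int64)
    return adj, labels

# Hypothetical settings: three communities of 20 nodes each.
adj, labels = sample_sbm([20, 20, 20], p=0.5, q=0.05)
same = labels[:, None] == labels[None, :]
off_diag = ~np.eye(len(labels), dtype=bool)
print("intra-community density:", adj[same & off_diag].mean())
print("inter-community density:", adj[~same].mean())
```

With `p` well above `q`, the intra-community density printed at the end should be clearly higher than the inter-community one, which is the structure a clustering model has to recover.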

The CLUSTER dataset has the following properties:

- Each node is represented by a 7-dimensional feature vector.
- The dataset contains 6 distinct node classes.
- The primary learning objective is multi-class classification at the node level.

This dataset provides a challenging yet insightful benchmark for evaluating the performance of Graph Neural Networks on community detection and clustering tasks. We will start by loading the dataset and inspecting one of its graphs.

In [6]:
import torch
import torch.nn as nn
from torch_geometric.datasets import GNNBenchmarkDataset
from torch_geometric.loader import DataLoader
import numpy as np
import matplotlib.pyplot as plt

# Load the Cluster dataset
data_root = './data'
dataset_train = GNNBenchmarkDataset(root=data_root, name='CLUSTER', split='train')
dataset_val = GNNBenchmarkDataset(root=data_root, name='CLUSTER', split='val')

# Print dataset statistics
print(f'Training dataset size: {len(dataset_train)}')
print(f'Validation dataset size: {len(dataset_val)}')

# Print a sample graph's details
def print_graph_info(graph, index):
    print(f'Graph {index}:')
    print(f'  - Number of nodes: {graph.num_nodes}')
    print(f'  - Number of edges: {graph.num_edges}')
    print(f'  - Node features shape: {graph.x.shape}')
    print(f'  - Edge index shape: {graph.edge_index.shape}')
    print(f'  - Labels shape: {graph.y.shape}')
    print('-' * 40)

# Display information about the first graph (increase the bound to show more)
for i in range(min(1, len(dataset_train))):
    print_graph_info(dataset_train[i], i)


Training dataset size: 10000
Validation dataset size: 1000
Graph 0:
  - Number of nodes: 117
  - Number of edges: 4104
  - Node features shape: torch.Size([117, 7])
  - Edge index shape: torch.Size([2, 4104])
  - Labels shape: torch.Size([117])
----------------------------------------


Next, we wrap the datasets in `torch_geometric`'s `DataLoader` class to group the graphs into batches:

In [7]:
# Create DataLoaders
batch_size = 32
dataloader_train = DataLoader(dataset_train, batch_size=batch_size, shuffle=True)
dataloader_val = DataLoader(dataset_val, batch_size=batch_size, shuffle=False)

# Display batched graph information
for batch in dataloader_train:
    print('Batched graph:')
    print(f'  - Batch size: {batch.num_graphs}')  # number of graphs actually in this batch
    print(f'  - Total nodes in batch: {batch.x.shape[0]}')
    print(f'  - Total edges in batch: {batch.edge_index.shape[1]}')
    print(f'  - Labels shape: {batch.y.shape}')
    break  # Only show one batch

Batched graph:
  - Batch size: 32
  - Total nodes in batch: 3561
  - Total edges in batch: 123310
  - Labels shape: torch.Size([3561])
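
The output above shows one label per node across the whole batch, because the loader merges the graphs of a batch into a single large disconnected graph. The following NumPy sketch mimics that idea under simplifying assumptions; `collate` is a hypothetical helper written for this illustration, not PyG's actual implementation.

```python
import numpy as np

def collate(graphs):
    """Merge graphs given as (x, edge_index) pairs into one disconnected graph.

    x has shape [n_i, d]; edge_index has shape [2, e_i] with node ids local to the graph.
    """
    xs, edges, batch = [], [], []
    offset = 0
    for graph_id, (x, edge_index) in enumerate(graphs):
        xs.append(x)
        edges.append(edge_index + offset)                # shift node ids past earlier graphs
        batch.append(np.full(x.shape[0], graph_id))      # remember which graph each node came from
        offset += x.shape[0]
    return np.concatenate(xs), np.concatenate(edges, axis=1), np.concatenate(batch)

g0 = (np.zeros((3, 7)), np.array([[0, 1], [1, 2]]))      # 3 nodes, edges 0->1 and 1->2
g1 = (np.ones((2, 7)), np.array([[0], [1]]))             # 2 nodes, edge 0->1
x, edge_index, batch = collate([g0, g1])
print(x.shape)   # (5, 7)
print(batch)     # [0 0 0 1 1]
```

The `batch` vector plays the same role as `batch.batch` in the notebook: it is what later allows node embeddings to be pooled per graph.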


## Task 1: MLP for node classification

To start with, we will build an MLP model as a baseline. The MLP should directly take the node features as input and output the predictions for each class. In this task, you need to complete the MLP model and the training function.

In [8]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print("Using device:", device)

class MLPNodeClassifier(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        #####
        # TODO:
        super(MLPNodeClassifier, self).__init__()

        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim, output_dim)
        #####

    def forward(self, x):
        #####
        # TODO:
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x
        #####

Using device: cuda:0


The training and evaluation functions for the MLP model:

In [10]:
# Evaluation function
def evaluate(model, dataloader, device=device):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for batch in dataloader:
            #####
            # TODO:
            batch = batch.to(device)
            outputs = model(batch.x)
            preds = outputs.argmax(dim=1)
            correct += (preds == batch.y).sum().item()
            total += batch.y.size(0)
            #####
    # return the accuracy
    return correct / total if total > 0 else 0

# Training function
def train(model, dataloader_train, dataloader_val, epochs=50, lr=0.001, patience=5, device=device):
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    best_val_acc = 0
    patience_counter = 0

    model.to(device)
    for epoch in range(epochs):
        model.train()
        total_loss = 0
        correct = 0 # number of correct predicted nodes
        total = 0 # number of nodes

        for batch in dataloader_train:
            #####
            # TODO:
            batch = batch.to(device)
            optimizer.zero_grad()
            outputs = model(batch.x)
            loss = criterion(outputs, batch.y)
            loss.backward()
            optimizer.step()

            # Accumulate the loss summed over nodes; the printed Loss is this total, not a per-node mean
            total_loss += loss.item() * batch.y.size(0)
            preds = outputs.argmax(dim=1)
            correct += (preds == batch.y).sum().item()
            total += batch.y.size(0)
            #####

        train_acc = correct / total if total > 0 else 0
        val_acc = evaluate(model, dataloader_val, device)
        print(f'Epoch {epoch+1}: Loss={total_loss:.4f}, Train Acc={train_acc:.4f}, Val Acc={val_acc:.4f}')

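        if val_acc > best_val_acc:
            best_val_acc = val_acc
            patience_counter = 0
            # Clone the tensors: state_dict() returns references that later updates would overwrite
            best_ckpt = {k: v.detach().clone() for k, v in model.state_dict().items()}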
        else:
            patience_counter += 1

        if patience_counter >= patience:
            print('Early stopping triggered')
            break
    return best_ckpt

We will now train the MLP on the node classification task. Several hyperparameters may need tuning: the batch size, the hidden dimension of the MLP, the number of epochs, the learning rate, and the early-stopping patience.

In [11]:
# Initialize model
input_dim = dataset_train.num_node_features
hidden_dim = 32
output_dim = dataset_train.num_classes

model = MLPNodeClassifier(input_dim, hidden_dim, output_dim)
print(model)

# Train the model
best_ckpt = train(model, dataloader_train, dataloader_val, epochs=5, lr=1e-2, patience=7)

MLPNodeClassifier(
  (fc1): Linear(in_features=7, out_features=32, bias=True)
  (relu): ReLU()
  (fc2): Linear(in_features=32, out_features=6, bias=True)
)
Epoch 1: Loss=2005784.9551, Train Acc=0.2088, Val Acc=0.2133
Epoch 2: Loss=1995162.7075, Train Acc=0.2095, Val Acc=0.2096
Epoch 3: Loss=1994486.4846, Train Acc=0.2096, Val Acc=0.2096
Epoch 4: Loss=1993541.8384, Train Acc=0.2093, Val Acc=0.2133
Epoch 5: Loss=1993437.9511, Train Acc=0.2080, Val Acc=0.2072


After the training, we will make predictions on the test set and save the prediction results. The saved predictions will be used for grading.

In [12]:
dataset_test = GNNBenchmarkDataset(root=data_root, name='CLUSTER', split='test')
dataloader_test = DataLoader(dataset_test, batch_size=batch_size, shuffle=False)

# Test set predictions and save results
def predict(model, dataloader, filename='predictions.txt', device=device):
    model.eval()
    predictions = [] # a list of predicted labels
    with torch.no_grad():
        for batch in dataloader:
            #####
            # TODO:
            batch = batch.to(device)
            outputs = model(batch.x)
            preds = outputs.argmax(dim=1)
            predictions.extend(preds.cpu().tolist())
            #####

    return predictions

model.load_state_dict(best_ckpt)
predictions = predict(model, dataloader_test)
np.savetxt('predictions_mlp_cluster.txt', predictions, fmt='%d')

## Task 2: GCN for node classification

Next, we will use graph convolutional layers to build a GNN model for the node classification task. Check PyG's [documentation](https://pytorch-geometric.readthedocs.io/en/latest/) to learn how the `GCNConv` module is used, then build a GNN model with it. In the following cells, you need to complete the `GNNNodeClassifier` class as well as the training function.
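
To make the role of a graph convolution concrete, here is a dense NumPy sketch of the propagation rule from Kipf and Welling's GCN, $H' = \hat{D}^{-1/2}(A + I)\hat{D}^{-1/2} H W$. It is a simplified illustration, not the `GCNConv` implementation: `gcn_layer` is a hypothetical helper, the bias term is omitted, and the edge list is assumed to contain each undirected edge in both directions with no duplicates.

```python
import numpy as np

def gcn_layer(x, edge_index, weight):
    """One GCN propagation step on a dense adjacency matrix (no bias, no activation)."""
    n = x.shape[0]
    adj = np.zeros((n, n))
    adj[edge_index[1], edge_index[0]] = 1.0       # row = target node, column = source node
    adj += np.eye(n)                              # add self-loops: A + I
    deg = adj.sum(axis=1)                         # degrees including the self-loop
    inv_sqrt_deg = 1.0 / np.sqrt(deg)
    adj_hat = inv_sqrt_deg[:, None] * adj * inv_sqrt_deg[None, :]
    return adj_hat @ x @ weight                   # aggregate neighbours, then transform

# Path graph 0 - 1 - 2, stored with both edge directions.
edge_index = np.array([[0, 1, 1, 2],
                       [1, 0, 2, 1]])
x = np.eye(3)          # one-hot features so the output exposes the normalized adjacency
weight = np.eye(3)
out = gcn_layer(x, edge_index, weight)
print(out.round(3))
```

Unlike the MLP baseline, each output row mixes a node's own features with those of its neighbours, which is why stacking such layers lets label information spread along edges.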

In [13]:
from torch_geometric.nn import GCNConv
import torch.nn.functional as F
# Define a GNN model using GCNConv
class GNNNodeClassifier(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, num_gcn_layers=2):
        #####
        # TODO:
        super(GNNNodeClassifier, self).__init__()
        dims = [input_dim] + [hidden_dim] * (num_gcn_layers - 1) + [output_dim]
        self.convs = nn.ModuleList(
            [GCNConv(dims[i], dims[i + 1]) for i in range(len(dims) - 1)]
        )
        #####

    def forward(self, x, edge_index):
        #####
        # TODO:
        for i, conv in enumerate(self.convs):
            x = conv(x, edge_index)
            if i != len(self.convs) - 1:
                x = F.relu(x)
        return x
        #####

In [14]:
from torch.optim.lr_scheduler import ReduceLROnPlateau

# Training function
def train(model, dataloader_train, dataloader_val, epochs=50, lr=0.001, patience=5, device=device):
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    # `verbose` is deprecated in recent PyTorch releases, so it is omitted here
    scheduler = ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=5)
    best_val_acc = 0
    patience_counter = 0

    for epoch in range(epochs):
        model.train()
        total_loss = 0
        correct = 0
        total = 0

        for batch in dataloader_train:
            #####
            # TODO:
            batch = batch.to(device)
            optimizer.zero_grad()
            outputs = model(batch.x, batch.edge_index)
            loss = criterion(outputs, batch.y)
            loss.backward()
            optimizer.step()

            total_loss += loss.item() * batch.y.size(0)
            preds = outputs.argmax(dim=1)
            correct += (preds == batch.y).sum().item()
            total += batch.y.size(0)

            #####

        train_acc = correct / total
        val_acc = evaluate(model, dataloader_val, device)
        scheduler.step(val_acc)
        print(f'Epoch {epoch+1}: Loss={total_loss:.4f}, Train Acc={train_acc:.4f}, Val Acc={val_acc:.4f}')

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            patience_counter = 0
        else:
            patience_counter += 1

        if patience_counter >= patience:
            print('Early stopping triggered')
            break

# Evaluation function
def evaluate(model, dataloader, device=device):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for batch in dataloader:
            #####
            # TODO:
            batch = batch.to(device)
            outputs = model(batch.x, batch.edge_index)
            preds = outputs.argmax(dim=1)
            correct += (preds == batch.y).sum().item()
            total += batch.y.size(0)
            #####
    return correct / total

We can now train the GCN model for node classification. Choose the hyperparameters carefully: batch size, hidden dimension, number of GCN layers, number of epochs, learning rate, and early-stopping patience.

In [15]:
# Initialize model
input_dim = dataset_train.num_node_features
hidden_dim = 128
output_dim = dataset_train.num_classes

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model = GNNNodeClassifier(input_dim, hidden_dim, output_dim, num_gcn_layers=2)
print(model)
print(f'Number of trainable params: {sum(p.numel() for p in model.parameters() if p.requires_grad)}')
model.to(device)

# Train the model
train(model, dataloader_train, dataloader_val, epochs=5, lr=1e-2, patience=3, device=device)

GNNNodeClassifier(
  (convs): ModuleList(
    (0): GCNConv(7, 128)
    (1): GCNConv(128, 6)
  )
)
Number of trainable params: 1798




Epoch 1: Loss=2008820.6207, Train Acc=0.3038, Val Acc=0.4442
Epoch 2: Loss=1759370.1928, Train Acc=0.4398, Val Acc=0.4139
Epoch 3: Loss=1682736.5758, Train Acc=0.4541, Val Acc=0.4610
Epoch 4: Loss=1660280.6751, Train Acc=0.4581, Val Acc=0.4562
Epoch 5: Loss=1652592.4957, Train Acc=0.4595, Val Acc=0.4603


After training, we can make predictions on the test set and save the results. The results will be used for the grading of the task.

In [16]:
# Test set predictions and save results
def predict(model, dataloader, filename='gnn_predictions.txt', device=device):
    model.eval()
    predictions = []
    with torch.no_grad():
        for batch in dataloader:
            #####
            # TODO:
            batch = batch.to(device)
            outputs = model(batch.x, batch.edge_index)
            preds = outputs.argmax(dim=1)
            predictions.extend(preds.cpu().tolist())
            #####

    return predictions


predictions = predict(model, dataloader_test)
np.savetxt('predictions_gcn_cluster.txt', predictions, fmt='%d')

## Graph classification task

In this section, we explore graph classification using Graph Neural Networks (GNNs). Unlike node classification, which focuses on predicting labels for individual nodes within a graph, graph classification aims to assign labels to entire graphs based on their structural and feature-based attributes. The primary challenge lies in effectively embedding entire graphs into a feature space where they become linearly separable for classification tasks.

One notable application of graph classification is the representation of image data as graphs, an approach demonstrated in super-pixel datasets. These datasets provide a novel way to transform traditional image classification tasks into graph learning problems. Prominent image datasets such as MNIST and CIFAR10 have been adapted into graph structures using this methodology. The motivation for utilizing these datasets is twofold:

- Benchmarking and Sanity-Checking – These datasets serve as standard benchmarks for evaluating the performance of GNN architectures. Most GNN models are expected to achieve near-perfect accuracy on MNIST and competitive performance on CIFAR10.
- Extending Image-Based Learning to Graphs – Super-pixel representations offer valuable insights into how conventional image datasets can be leveraged for graph-based learning and analysis.

### CIFAR10 Super-Pixel Dataset
In this assignment, we will work with the CIFAR10 super-pixel dataset for a graph classification task. The CIFAR10 images are transformed into graphs using super-pixel segmentation, where each super-pixel represents a small, homogeneous region of the image. This transformation is performed using the Simple Linear Iterative Clustering (SLIC) algorithm, introduced by Achanta et al. (2012).

By leveraging super-pixel representations, we can analyze the effectiveness of GNNs in graph classification while drawing connections between traditional computer vision tasks and graph-based learning techniques.
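
The image-to-graph conversion can be sketched in NumPy as follows. Several things here are assumptions made for illustration: a regular 4x4 grid stands in for the SLIC segmentation, node features are the mean colour of each region, and edges connect each region to its `k = 8` nearest neighbours by centroid distance (the text above does not specify how edges are built). `superpixel_graph` is a hypothetical helper.

```python
import numpy as np

def superpixel_graph(image, segments, k=8):
    """Build (node features, edge_index) from an image and a region-id map.

    image: [H, W, 3] float array; segments: [H, W] integer region ids 0..S-1.
    """
    num_regions = int(segments.max()) + 1
    ys, xs = np.mgrid[0:image.shape[0], 0:image.shape[1]]
    feats = np.zeros((num_regions, 3))
    centers = np.zeros((num_regions, 2))
    for s in range(num_regions):
        mask = segments == s
        feats[s] = image[mask].mean(axis=0)               # mean colour of the region
        centers[s] = [ys[mask].mean(), xs[mask].mean()]   # region centroid (row, col)
    # Connect every region to its k nearest regions by centroid distance.
    dists = np.linalg.norm(centers[:, None] - centers[None, :], axis=-1)
    np.fill_diagonal(dists, np.inf)                       # exclude self-loops
    neighbors = np.argsort(dists, axis=1)[:, :k]
    src = np.repeat(np.arange(num_regions), k)
    edge_index = np.stack([src, neighbors.reshape(-1)])
    return feats, edge_index

rng = np.random.default_rng(0)
image = rng.random((32, 32, 3))                           # random stand-in for a CIFAR10 image
rows, cols = np.mgrid[0:32, 0:32]
segments = (rows // 4) * 8 + (cols // 4)                  # grid stand-in for SLIC: 64 regions
feats, edge_index = superpixel_graph(image, segments)
print(feats.shape, edge_index.shape)  # (64, 3) (2, 512)
```

The resulting shapes mirror what the dataset exposes: a 3-dimensional feature vector per node and a `[2, num_edges]` edge index.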

In [17]:
from torch_geometric.nn import GCNConv, GATConv, global_mean_pool

# Load the CIFAR10 dataset
data_root = './data/GNNBenchmark'
dataset_train = GNNBenchmarkDataset(root=data_root, name='CIFAR10', split='train')
dataset_val = GNNBenchmarkDataset(root=data_root, name='CIFAR10', split='val')

print(f'Training dataset size: {len(dataset_train)}')
print(f'Validation dataset size: {len(dataset_val)}')
print_graph_info(dataset_train[0], 0)

Downloading https://data.pyg.org/datasets/benchmarking-gnns/CIFAR10_v2.zip
Extracting data/GNNBenchmark/CIFAR10/raw/CIFAR10_v2.zip
Processing...
Done!


Training dataset size: 45000
Validation dataset size: 5000
Graph 0:
  - Number of nodes: 110
  - Number of edges: 880
  - Node features shape: torch.Size([110, 3])
  - Edge index shape: torch.Size([2, 880])
  - Labels shape: torch.Size([1])
----------------------------------------


## Task 3: GCN for graph classification

In this task, you need to build a GNN model using the `GCNConv` module for the graph classification task. To classify whole graphs, you first need a graph-level embedding, obtained by pooling the node embeddings of each graph. Refer to PyG's [documentation](https://pytorch-geometric.readthedocs.io/en/latest/) for the available pooling functions (e.g., mean pooling, max pooling, sum pooling).
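
The three pooling variants mentioned above differ only in how the node rows of each graph are reduced. A minimal NumPy sketch, assuming a `batch` vector that maps each node to its graph and that every graph has at least one node (`global_pool` is a hypothetical helper, not a PyG function):

```python
import numpy as np

def global_pool(x, batch, reduce="mean"):
    """Reduce node embeddings x [N, d] to one row per graph using the batch vector."""
    reducers = {"mean": np.mean, "max": np.max, "sum": np.sum}
    num_graphs = int(batch.max()) + 1
    out = np.zeros((num_graphs, x.shape[1]))
    for g in range(num_graphs):
        out[g] = reducers[reduce](x[batch == g], axis=0)   # reduce over this graph's nodes
    return out

x = np.array([[1.0, 2.0],
              [3.0, 4.0],
              [5.0, 6.0]])
batch = np.array([0, 0, 1])               # nodes 0-1 belong to graph 0, node 2 to graph 1
print(global_pool(x, batch, "mean"))      # graph 0 -> [2, 3], graph 1 -> [5, 6]
print(global_pool(x, batch, "sum"))       # graph 0 -> [4, 6], graph 1 -> [5, 6]
print(global_pool(x, batch, "max"))       # graph 0 -> [3, 4], graph 1 -> [5, 6]
```

Mean pooling is insensitive to graph size, sum pooling retains it, and max pooling keeps only the strongest activation per feature; which one works best here is an empirical question.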

In [18]:
from torch_geometric.nn import global_mean_pool

# Define a GCN model for graph classification
class GCNGraphClassifier(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, num_gcn_layers=2, embedding_dim=512):
        #####
        # TODO:
        super(GCNGraphClassifier, self).__init__()
        self.embedding = nn.Sequential(
            nn.Linear(input_dim, embedding_dim),
            nn.BatchNorm1d(embedding_dim),
            nn.ReLU(),
            nn.Dropout(0.3)
        )

        self.convs = nn.ModuleList()
        self.convs.append(GCNConv(embedding_dim, hidden_dim))
        self.batch_norms = nn.ModuleList()
        self.batch_norms.append(nn.BatchNorm1d(hidden_dim))

        for _ in range(num_gcn_layers - 1):
            self.convs.append(GCNConv(hidden_dim, hidden_dim))
            self.batch_norms.append(nn.BatchNorm1d(hidden_dim))

        self.mlp = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim * 2),
            nn.BatchNorm1d(hidden_dim * 2),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(hidden_dim, output_dim)
        )
        #####

    def forward(self, x, edge_index, batch):
        #####
        # TODO:
        x = self.embedding(x)
        for i, conv in enumerate(self.convs):
            x = conv(x, edge_index)
            x = self.batch_norms[i](x)
            x = F.relu(x)
            x = F.dropout(x, p=0.2, training=self.training)
        x = global_mean_pool(x, batch)
        x = self.mlp(x)
        return x
        #####

In [19]:
# Training function
def train(model, dataloader_train, dataloader_val, epochs=50, lr=0.001, patience=5, device=device):
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
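    # Halve the learning rate when validation accuracy plateaus
    # (`verbose` is deprecated in recent PyTorch releases, so it is omitted)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode='max',
        factor=0.5,
        patience=3
    )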
    best_val_acc = 0
    patience_counter = 0

    for epoch in range(epochs):
        model.train()
        total_loss = 0
        correct = 0
        total = 0

        for batch in dataloader_train:
            #####
            # TODO:
            batch = batch.to(device)
            optimizer.zero_grad()
            out = model(batch.x, batch.edge_index, batch.batch)

            # Fix for y shape mismatch: get graph level labels
            if batch.y.shape[0] != out.shape[0]:
                graph_y = []
                for i in range(out.shape[0]):
                    graph_indices = (batch.batch == i).nonzero(as_tuple=True)[0]
                    graph_y.append(batch.y[graph_indices[0]])
                target = torch.tensor(graph_y, device=device)
            else:
                target = batch.y

            loss = criterion(out, target)
            loss.backward()
            optimizer.step()
            pred = out.argmax(dim=1)
            correct += (pred == target).sum().item()
            total += target.size(0)
            total_loss += loss.item() * target.size(0)
            #####

        train_acc = correct / total
        val_acc = evaluate(model, dataloader_val, device)
        scheduler.step(val_acc)  # without this call the scheduler never adjusts the learning rate
        print(f'Epoch {epoch+1}: Loss={total_loss:.4f}, Train Acc={train_acc:.4f}, Val Acc={val_acc:.4f}')

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            patience_counter = 0
        else:
            patience_counter += 1

        if patience_counter >= patience:
            print('Early stopping triggered')
            break

# Evaluation function
def evaluate(model, dataloader, device=device):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for batch in dataloader:
            #####
            # TODO:
            batch = batch.to(device)
            out = model(batch.x, batch.edge_index, batch.batch)

            # Fix for y shape mismatch: get graph level labels
            if batch.y.shape[0] != out.shape[0]:
                graph_y = []
                for i in range(out.shape[0]):
                    graph_indices = (batch.batch == i).nonzero(as_tuple=True)[0]
                    graph_y.append(batch.y[graph_indices[0]])
                target = torch.tensor(graph_y, device=device)
            else:
                target = batch.y

            pred = out.argmax(dim=1)
            correct += (pred == target).sum().item()
            total += target.size(0)
            #####
    return correct / total

Train the GCN model for graph classification:

In [21]:
# Create DataLoaders
batch_size = 60
dataloader_train = DataLoader(dataset_train, batch_size=batch_size, shuffle=True)
dataloader_val = DataLoader(dataset_val, batch_size=batch_size, shuffle=False)

# Initialize models
input_dim = dataset_train.num_node_features
hidden_dim = 256
embedding_dim = 512
output_dim = dataset_train.num_classes

gcn_model = GCNGraphClassifier(input_dim, hidden_dim, output_dim, num_gcn_layers=4, embedding_dim=embedding_dim)
gcn_model.to(device)
print(gcn_model)
print(f'Number of trainable params: {sum(p.numel() for p in gcn_model.parameters() if p.requires_grad)}')

# Train the GCN model
train(gcn_model, dataloader_train, dataloader_val, epochs=12, lr=5e-3, patience=5, device=device)

GCNGraphClassifier(
  (embedding): Sequential(
    (0): Linear(in_features=3, out_features=512, bias=True)
    (1): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Dropout(p=0.3, inplace=False)
  )
  (convs): ModuleList(
    (0): GCNConv(512, 256)
    (1-3): 3 x GCNConv(256, 256)
  )
  (batch_norms): ModuleList(
    (0-3): 4 x BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (mlp): Sequential(
    (0): Linear(in_features=256, out_features=512, bias=True)
    (1): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Dropout(p=0.5, inplace=False)
    (4): Linear(in_features=512, out_features=256, bias=True)
    (5): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): ReLU()
    (7): Dropout(p=0.3, inplace=False)
    (8): Linear(in_features=256, out_features=10, bias=True)
  )
)
Number of trainable params

After training, you can make predictions using the GCN model. Save the prediction results in a `.txt` file, where each line contains one prediction for one test data point.
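
The expected file layout can be checked with a short round trip. The sketch below writes a few made-up predictions with `np.savetxt` to a temporary file and reads them back; the prediction values, the file name, and the helper `check_prediction_file` are all hypothetical and not part of the grading pipeline.

```python
import os
import tempfile
import numpy as np

def check_prediction_file(path, num_classes, expected_count):
    """Return True if the file holds expected_count integer labels in [0, num_classes)."""
    with open(path) as f:
        lines = f.read().splitlines()
    values = [int(line) for line in lines]        # raises ValueError on non-integer lines
    in_range = all(0 <= v < num_classes for v in values)
    return len(values) == expected_count and in_range

example_predictions = [3, 0, 9, 1]                # made-up labels for four test graphs
path = os.path.join(tempfile.mkdtemp(), "predictions_example.txt")
np.savetxt(path, example_predictions, fmt="%d")   # one integer per line

with open(path) as f:
    lines = f.read().splitlines()
print(lines)                                      # ['3', '0', '9', '1']
print(check_prediction_file(path, num_classes=10, expected_count=4))  # True
```

Running a check like this before submitting helps catch files with the wrong number of lines or out-of-range labels.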

In [None]:
def predict(model, dataloader, device=device):
    model.eval()
    predictions = []
    with torch.no_grad():
        for batch in dataloader:
            #####
            # TODO:
            batch = batch.to(device)
            
            # Forward pass: one row of logits per graph
            out = model(batch.x, batch.edge_index, batch.batch)

            # Labels are not needed for prediction; take the arg-max class per graph
            pred = out.argmax(dim=1).cpu().tolist()
            predictions.extend(pred)
            #####

    return predictions

dataset_test = GNNBenchmarkDataset(root=data_root, name='CIFAR10', split='test')
dataloader_test = DataLoader(dataset_test, batch_size=batch_size, shuffle=False)

predictions_gcn = predict(gcn_model, dataloader_test)
np.savetxt('predictions_gcn_cifar10.txt', predictions_gcn, fmt='%d')

## Grading

All the tasks will be graded by the accuracy on the test set.

### Task 1 (5 points)
- Accuracy >= 0.2: 5 points
- Accuracy < 0.2: 0 points

### Task 2 (10 points)
- Accuracy >= 0.32: 10 points
- Accuracy < 0.32: 0 points

### Task 3 (10 points)
- Accuracy >= 0.4: 10 points
- Accuracy < 0.4: 0 points


## Submission

After completing all the tasks, you should submit the following four files to Gradescope:

- `hw4_gnn.ipynb`: The notebook with all tasks completed.
- `predictions_mlp_cluster.txt`: prediction results of the MLP model on the CLUSTER dataset.
- `predictions_gcn_cluster.txt`: prediction results of the GCN model on the CLUSTER dataset.
- `predictions_gcn_cifar10.txt`: prediction results of the GCN model on the CIFAR10 dataset.

Submit the files individually; **DO NOT** submit a zip file.