# HW2 part 2: GNN 

## Graph Attention Network (GAT) Implementation for Node Classification

**Objective:** Implement a GAT from scratch to perform node classification using an OGB dataset. Develop neural components, including the forward pass, as well as training and testing routines. There are 17 `todo`s and 2 questions for GAT.

A Graph Attention Network (GAT) applies the concept of self-attention to graph-structured data. Unlike traditional Graph Convolutional Networks (GCNs) that rely on fixed or uniform weights derived from adjacency structures, GAT learns to assign different weights (attention scores) to different edges. This allows the model to focus more on important neighbors while possibly ignoring less relevant ones. By doing so, GAT effectively captures how each node interacts with its neighbors in a more flexible and adaptive way.

In [1]:
# This notebook's first part demonstrates how to train and evaluate a GAT model on the ogbn-arxiv dataset for node classification (To predict the category of each paper). Make sure you know the basics of GNNs, PyG and pytorch before starting this notebook. If not, please check the corresponding tutorials first.

# If you use Google Colab, you can uncomment and run the following command to install the required packages.
# For this assigment, we recommend using your own local computer since the RAM usage is high, Colab may crash due to the high RAM usage.
# We passed our assignment on our local computers using python version 3.10.16.
# !pip install torch==2.5.0
# !pip install torch-geometric==2.6.1
# !pip install ogb==1.3.6

import torch
import torch.nn.functional as F
import torch.nn as nn
from torch_geometric.nn import GATConv
from ogb.nodeproppred import PygNodePropPredDataset,Evaluator

# We recommend to use GPU for training if possible. Because GAT is a relatively large model, training on CPU can be slow. But after testing, we find that training GAT on CPU is also acceptable.


# if torch.cuda.is_available():
#     device = torch.device("cuda")  # NVIDIA GPU
# else:
#     device = torch.device("cpu")   # CPU fallback
device = torch.device("cpu")   # CPU fallback
print(f"Using device: {device}")


import random
def set_seed(seed):
    random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    # make cudnn deterministic
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# set a random seed for reproducibility
set_seed(42)

Using device: cpu


In [2]:
# ## Define a GAT Model for Node Classification

#### Model Structure

## The model is divided into two GATConv layers:
##   The first layer uses multi-head attention (specified by num_heads) to produce richer representations. Here, each head outputs hidden_channels features, and they are concatenated, resulting in a total output dimension of hidden_channels * num_heads.
##   The second layer is a single-head output layer that directly predicts the final node embeddings or categories. Its output dimension corresponds to the number of classes (out_channels).

#### Layer Normalization

## After the first GAT layer, the output is passed through nn.LayerNorm(hidden_channels * num_heads). This helps stabilize training by normalizing feature distributions across different nodes.

#### Forward Pass

## First apply dropout to the input features (x) to reduce overfitting.
## Pass the data through the first GAT layer (gat1).
## Apply an ELU activation, then layer normalization, and another dropout.
## Finally, pass the features through the second GAT layer (gat2) to get the predictions (the paper category with the highest score).

class GAT(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, num_heads=8, dropout=0.6,add_self_loops=True):
        super(GAT, self).__init__()
        self.dropout = dropout
        # The first layer: multi-head attention, each head outputs hidden_channels, and the total output dimension = hidden_channels * num_heads (i.e. the input dimension of the second layer), you don't need to care the concatenation in this layer
        # The usage of GATConv is: GATConv(in_channels, out_channels, heads=XXX, concat=XXX, dropout=dropout, add_self_loops=XXX)
        ## TODO 1: Define the first GAT layer (2 points)

        self.gat1 = GATConv(in_channels, hidden_channels, heads=num_heads, concat=True, dropout=dropout, add_self_loops=add_self_loops)
        
        ## TODO 1: Define the first GAT layer

        # The second layer: single-head output, not concatenated, directly output the dimension of the node categories
        ## TODO 2: Define the second GAT layer (2 points)

        self.gat2 = GATConv(hidden_channels * num_heads, out_channels, heads=1, concat=False, dropout=dropout, add_self_loops=add_self_loops)
        
        ## TODO 2: Define the second GAT layer

        # Layer normalization for stable training, use nn.LayerNorm, the input is the hidden_channels * num_heads
        ## TODO 3: Define a LayerNorm layer (2 points)
        
        self.layer_norm = nn.LayerNorm(hidden_channels * num_heads)

        ## TODO 3: Define a LayerNorm layer

    def forward(self, data):
        # data contains x and edge_index
        x, edge_index = data.x, data.edge_index

        ## TODO 4: use a dropout layer for the input features, then apply the first GAT layer (2 points)
        # use F.dropout to perform dropout, the input is x, and the dropout rate is self.dropout, set training=self.training

        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.gat1(x, edge_index)

        ## TODO 4: use a dropout layer for the input features, then apply the first GAT layer

        ## TODO 5: apply ELU activation, layer normalization and dropout (2 points)
        # use F.elu for ELU activation, the input is x

        x = F.elu(x)
        x = self.layer_norm(x)
        x = F.dropout(x, p=self.dropout, training=self.training)

        ## TODO 5: apply ELU activation, layer normalization and dropout

        ## TODO 6: apply the second GAT layer (2 points)

        x = self.gat2(x, edge_index)

        ## TODO 6: apply the second GAT layer
        return x

# Load the ogbn-arxiv dataset
dataset_arxiv = PygNodePropPredDataset(name='ogbn-arxiv')

# Let's see some properties of the dataset
## TODO 7: print the number of graphs in the dataset (2 points)

print(f"Number of graphs in dataset: {len(dataset_arxiv)}")

## TODO 7: print the number of graphs in the dataset

# Let's see how many features and classes are in the dataset
## TODO 8: print the number of features and classes in the dataset (2 points)

print(f"Number of features: {dataset_arxiv.num_features}, Number of classes: {dataset_arxiv.num_classes}")

## TODO 8: print the number of features and classes in the dataset


data_arxiv = dataset_arxiv[0]
# Let's see the shape of the node features and the target labels
print(data_arxiv.x.shape)

# Get the data split index
split_idx_arxiv = dataset_arxiv.get_idx_split()
train_idx_arxiv = split_idx_arxiv['train']
valid_idx_arxiv = split_idx_arxiv['valid']
test_idx_arxiv  = split_idx_arxiv['test']

# Model, optimizer, scheduler, and evaluator settings
model_arxiv = GAT(in_channels=dataset_arxiv.num_features,
                         hidden_channels=64,
                         out_channels=dataset_arxiv.num_classes,
                         num_heads=8,
                         dropout=0.6).to(device)
data_arxiv = data_arxiv.to(device)
optimizer_arxiv = torch.optim.Adam(model_arxiv.parameters(), lr=0.005, weight_decay=5e-4)
scheduler_arxiv = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_arxiv, mode='min', factor=0.7, patience=10, verbose=True)
evaluator_arxiv = Evaluator(name='ogbn-arxiv')

def train_arxiv():
    model_arxiv.train()
    optimizer_arxiv.zero_grad()
    ## TODO 9: forward propagation, loss calculation and backward propagation (2 points)
    # use F.cross_entropy as the loss function, the input is the output of the model and the target labels, and only use the training set for loss calculation, then call loss.backward()
    
    out = model_arxiv(data_arxiv)
    loss = F.cross_entropy(out[train_idx_arxiv], data_arxiv.y[train_idx_arxiv].squeeze())
    loss.backward()

    ## TODO 9: forward propagation, loss calculation and backward propagation
    optimizer_arxiv.step()
    return loss.item()

@torch.no_grad()
def evaluate_arxiv():
    model_arxiv.eval()
    ## TODO 10: get the prediction and use argmax to get the final category (2 points)

    out = model_arxiv(data_arxiv)
    y_pred = out.argmax(dim=1).unsqueeze(1)

    ## TODO 10: get the prediction and use argmax to get the final category
    train_acc = evaluator_arxiv.eval({'y_true': data_arxiv.y[train_idx_arxiv],
                                      'y_pred': y_pred[train_idx_arxiv]})['acc']
    valid_acc = evaluator_arxiv.eval({'y_true': data_arxiv.y[valid_idx_arxiv],
                                      'y_pred': y_pred[valid_idx_arxiv]})['acc']
    test_acc  = evaluator_arxiv.eval({'y_true': data_arxiv.y[test_idx_arxiv],
                                      'y_pred': y_pred[test_idx_arxiv]})['acc']
    return train_acc, valid_acc, test_acc

Number of graphs in dataset: 1
Number of features: 128, Number of classes: 40
torch.Size([169343, 128])


  self.data, self.slices = torch.load(self.processed_paths[0])


In [3]:
## You can change these hyperparameters to see if you can get better results, but the default hyperparameters should work. And also make sure the three sets for hyperparameters are the same.
num_epochs = 30
best_valid_acc_arxiv = 0
patience_arxiv = 30
trigger_times_arxiv = 0
best_model_state_arxiv = None

In [None]:
# import os
# import gc

# os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
# torch.cuda.empty_cache()
# gc.collect()

In [4]:
print("---- ogbn-arxiv training start ----")
for epoch in range(1, num_epochs + 1):
    loss = train_arxiv()
    scheduler_arxiv.step(loss)
    train_acc, valid_acc, test_acc = evaluate_arxiv()
    ## TODO 11: early stopping, in an if-else block (2 points)
    
    if valid_acc > best_valid_acc_arxiv:
        best_valid_acc_arxiv = valid_acc
        best_model_state_arxiv = model_arxiv.state_dict()
        trigger_times_arxiv = 0  # Reset trigger count when improvement is observed
    else:
        trigger_times_arxiv += 1  # Increment trigger count when no improvement
       
    ## TODO 11: early stopping, in an if-else block

    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Train: {train_acc:.4f}, Valid: {valid_acc:.4f}, Test: {test_acc:.4f}')

    if trigger_times_arxiv >= patience_arxiv:
        print("Early stopping triggered!")
        break

---- ogbn-arxiv training start ----
Epoch: 001, Loss: 6.1625, Train: 0.1791, Valid: 0.0764, Test: 0.0588
Epoch: 002, Loss: 5.5791, Train: 0.1312, Valid: 0.2622, Test: 0.2769
Epoch: 003, Loss: 5.3758, Train: 0.1888, Valid: 0.2703, Test: 0.2310
Epoch: 004, Loss: 5.0632, Train: 0.2032, Valid: 0.2394, Test: 0.2043
Epoch: 005, Loss: 4.7448, Train: 0.3108, Valid: 0.2866, Test: 0.2442
Epoch: 006, Loss: 4.3284, Train: 0.2624, Valid: 0.2499, Test: 0.2101
Epoch: 007, Loss: 4.2126, Train: 0.2613, Valid: 0.2461, Test: 0.2066
Epoch: 008, Loss: 4.2357, Train: 0.2906, Valid: 0.2955, Test: 0.2542
Epoch: 009, Loss: 4.0538, Train: 0.3509, Valid: 0.3923, Test: 0.3838
Epoch: 010, Loss: 3.8687, Train: 0.2964, Valid: 0.3573, Test: 0.3918
Epoch: 011, Loss: 3.7791, Train: 0.3802, Valid: 0.4532, Test: 0.4518
Epoch: 012, Loss: 3.6805, Train: 0.4049, Valid: 0.4212, Test: 0.3706
Epoch: 013, Loss: 3.6393, Train: 0.4142, Valid: 0.4163, Test: 0.3546
Epoch: 014, Loss: 3.5714, Train: 0.4137, Valid: 0.4062, Test: 0.350

In [5]:
# Load the best model state
model_arxiv.load_state_dict(best_model_state_arxiv)
final_train, final_valid, final_test = evaluate_arxiv()
print(f"[ogbn-arxiv] Best validation accuracy: {final_valid:.4f}, corresponding test accuracy: {final_test:.4f}")

[ogbn-arxiv] Best validation accuracy: 0.5099, corresponding test accuracy: 0.4793


* What's the best validation accuracy you can get? (2 points) What's the corresponding test accuracy? (2 points) Please report the results in this markdown cell.
<br><br>**ANS:**
Best Validation Accuracy: 0.5099
Corresponding Test Accuracy: 0.4793
These are the best results from the ogbn-arxiv dataset after training Graph Attention Network (GAT) model.

* Let's see if we remove the graph structure and only use the node features, how well the model can perform.

In [7]:
import copy

# construct a dummy edge_index with self-loops only
num_nodes = data_arxiv.num_nodes
dummy_edge_index = torch.arange(num_nodes, device=data_arxiv.x.device).unsqueeze(0).repeat(2, 1)

# copy the original data and replace the edge_index with the dummy one
data_arxiv_no_graph = copy.deepcopy(data_arxiv)
data_arxiv_no_graph.edge_index = dummy_edge_index

# define a new model for the data without graph structure, with the same hyperparameters
model_no_graph = GAT(
    in_channels=dataset_arxiv.num_features,
    hidden_channels=64,
    out_channels=dataset_arxiv.num_classes,
    num_heads=8,
    dropout=0.6
).to(device)


optimizer_no_graph = torch.optim.Adam(model_no_graph.parameters(), lr=0.005, weight_decay=5e-4)
scheduler_no_graph = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer_no_graph, mode='min', factor=0.7, patience=10, verbose=True
)

def train_no_graph():
    ## TODO 12: training code for the model without graph structure, should be the same as the original training code (2 points)

    # use the same model but with data without graph structure
    model_no_graph.train()
    optimizer_no_graph.zero_grad()

    # Forward pass with data without the graph structure
    out = model_no_graph(data_arxiv_no_graph)
    
    # Compute loss
    loss = F.cross_entropy(out[train_idx_arxiv], data_arxiv_no_graph.y[train_idx_arxiv].squeeze())
    
    # Backward pass and optimization
    loss.backward()
    optimizer_no_graph.step()

    ## TODO 12: training code for the model without graph structure, should be the same as the original training code
    return loss.item()

@torch.no_grad()
def evaluate_no_graph():
    ## TODO 13: evaluation code for the model without graph structure, should be the same as the original evaluation code (2 points)

    model_no_graph.eval()
    
    # Forward pass with data without the graph structure
    out = model_no_graph(data_arxiv_no_graph)
    y_pred = out.argmax(dim=1).unsqueeze(1) 
    
    # Calculate accuracy for train, valid, and test sets
    train_acc = evaluator_arxiv.eval({'y_true': data_arxiv_no_graph.y[train_idx_arxiv], 'y_pred': y_pred[train_idx_arxiv]})['acc']
    valid_acc = evaluator_arxiv.eval({'y_true': data_arxiv_no_graph.y[valid_idx_arxiv], 'y_pred': y_pred[valid_idx_arxiv]})['acc']
    test_acc  = evaluator_arxiv.eval({'y_true': data_arxiv_no_graph.y[test_idx_arxiv], 'y_pred': y_pred[test_idx_arxiv]})['acc']

    ## TODO 13: evaluation code for the model without graph structure, should be the same as the original evaluation code
    return train_acc, valid_acc, test_acc


## You can change these hyperparameters to see if you can get better results, but the default hyperparameters should work. And also make sure the three sets for hyperparameters are the same.
num_epochs_no_graph = 30
best_valid_acc_arxiv_no_graph = 0
patience_arxiv_no_graph = 30
trigger_times_arxiv_no_graph = 0
best_model_state_arxiv_no_graph = None
# start training and evaluation
for epoch in range(1, num_epochs_no_graph+1):
    loss = train_no_graph()
    scheduler_no_graph.step(loss)
    train_acc, valid_acc, test_acc = evaluate_no_graph()
    ## TODO 14: early stopping, in an if-else block (2 points)

    if valid_acc > best_valid_acc_arxiv_no_graph:
        best_valid_acc_arxiv_no_graph = valid_acc
        best_model_state_arxiv_no_graph = model_no_graph.state_dict()
        trigger_times_arxiv_no_graph = 0  # Reset trigger count when improvement is observed
    else:
        trigger_times_arxiv_no_graph += 1  # Increment trigger count when no improvement
        
    ## TODO 14: early stopping, in an if-else block

    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Train: {train_acc:.4f}, Valid: {valid_acc:.4f}, Test: {test_acc:.4f}')

    if trigger_times_arxiv_no_graph >= patience_arxiv_no_graph:
        print("Early stopping triggered!")
        break
# Load the best model state
model_no_graph.load_state_dict(best_model_state_arxiv_no_graph)
final_train, final_valid, final_test = evaluate_no_graph()
print(f"[ogbn-arxiv_no_graph] Best validation accuracy: {final_valid:.4f}, corresponding test accuracy: {final_test:.4f}")


Epoch: 001, Loss: 7.0533, Train: 0.1792, Valid: 0.0766, Test: 0.0587
Epoch: 002, Loss: 6.5416, Train: 0.1429, Valid: 0.2423, Test: 0.2257
Epoch: 003, Loss: 5.9118, Train: 0.2650, Valid: 0.3299, Test: 0.3116
Epoch: 004, Loss: 5.6113, Train: 0.3427, Valid: 0.3552, Test: 0.3180
Epoch: 005, Loss: 5.2832, Train: 0.3241, Valid: 0.3148, Test: 0.2801
Epoch: 006, Loss: 5.0222, Train: 0.2780, Valid: 0.2545, Test: 0.2253
Epoch: 007, Loss: 4.8658, Train: 0.2747, Valid: 0.2611, Test: 0.2304
Epoch: 008, Loss: 4.7486, Train: 0.3013, Valid: 0.3079, Test: 0.2811
Epoch: 009, Loss: 4.5691, Train: 0.3454, Valid: 0.3669, Test: 0.3491
Epoch: 010, Loss: 4.4079, Train: 0.3682, Valid: 0.3999, Test: 0.3833
Epoch: 011, Loss: 4.3124, Train: 0.3751, Valid: 0.4112, Test: 0.3891
Epoch: 012, Loss: 4.2166, Train: 0.4078, Valid: 0.4349, Test: 0.4084
Epoch: 013, Loss: 4.0907, Train: 0.4302, Valid: 0.4444, Test: 0.4163
Epoch: 014, Loss: 3.9909, Train: 0.4111, Valid: 0.4213, Test: 0.3970
Epoch: 015, Loss: 3.9094, Train: 0

* What's the best validation accuracy you can get? (2 points) What's the corresponding test accuracy? (2 points) Please report the results in this markdown cell.
<br><br>**ANS:**
Best Validation Accuracy: 0.4428
Corresponding Test Accuracy: 0.4226

* Let's see if we further remove the edge structure and only use the node features, how well the model can perform.

In [None]:
import copy

# an empty edge_index, even no self-loops
dummy_edge_index = torch.empty((2, 0), dtype=torch.long, device=data_arxiv.x.device)

# duplicate the original data and replace the edge_index with the dummy one
data_arxiv_no_edge = copy.deepcopy(data_arxiv)
data_arxiv_no_edge.edge_index = dummy_edge_index

# define a new model for the data without graph structure, with the same hyperparameters
model_no_edge = GAT(
    in_channels=dataset_arxiv.num_features,
    hidden_channels=64,
    out_channels=dataset_arxiv.num_classes,
    num_heads=8,
    dropout=0.6
).to(device)

optimizer_no_edge = torch.optim.Adam(model_no_edge.parameters(), lr=0.005, weight_decay=5e-4)
scheduler_no_edge = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer_no_edge, mode='min', factor=0.7, patience=10, verbose=True
)

def train_no_edge():
    ## TODO 15: training code for the model without edge structure, should be the same as the original training code (2 points)

    # use the same model but with data without graph structure
    model_no_edge.train()
    optimizer_no_edge.zero_grad()

    # Forward pass with data without edge structure
    out = model_no_edge(data_arxiv_no_edge)
    
    # Compute the loss
    loss = F.cross_entropy(out[train_idx_arxiv], data_arxiv_no_edge.y[train_idx_arxiv].squeeze())
    
    # Backward pass and optimization
    loss.backward()
    optimizer_no_edge.step()

    ## TODO 15: training code for the model without edge structure, should be the same as the original training code
    return loss.item()

@torch.no_grad()
def evaluate_no_edge():
    ## TODO 16: evaluation code for the model without edge structure, should be the same as the original evaluation code (2 points)

    model_no_edge.eval()

    # Forward pass with data without edge structure
    out = model_no_edge(data_arxiv_no_edge)
    y_pred = out.argmax(dim=1).unsqueeze(1)  # Reshape y_pred to match the shape of y_true

    # Calculate accuracy for train, valid, and test sets
    train_acc = evaluator_arxiv.eval({'y_true': data_arxiv_no_edge.y[train_idx_arxiv], 'y_pred': y_pred[train_idx_arxiv]})['acc']
    valid_acc = evaluator_arxiv.eval({'y_true': data_arxiv_no_edge.y[valid_idx_arxiv], 'y_pred': y_pred[valid_idx_arxiv]})['acc']
    test_acc  = evaluator_arxiv.eval({'y_true': data_arxiv_no_edge.y[test_idx_arxiv], 'y_pred': y_pred[test_idx_arxiv]})['acc']

    ## TODO 16: evaluation code for the model without edge structure, should be the same as the original evaluation code
    return train_acc, valid_acc, test_acc


## You can change these hyperparameters to see if you can get better results, but the default hyperparameters should work. And also make sure the three sets for hyperparameters are the same.
num_epochs_no_edge = 30
best_valid_acc_no_edge = 0
patience_no_edge = 30
trigger_times_no_edge = 0
best_model_state_no_edge = None

# start training and evaluation
for epoch in range(1, num_epochs_no_edge+1):
    loss = train_no_edge()
    scheduler_no_edge.step(loss)
    train_acc, valid_acc, test_acc = evaluate_no_edge()
    ## TODO 17: early stopping, in an if-else block (2 points)

    if valid_acc > best_valid_acc_no_edge:
        best_valid_acc_no_edge = valid_acc
        best_model_state_no_edge = model_no_edge.state_dict()
        trigger_times_no_edge = 0  # Reset trigger count when improvement is observed
    else:
        trigger_times_no_edge += 1  # Increment trigger count when no improvement

    ## TODO 17: early stopping, in an if-else block

    print(f"Epoch: {epoch:03d}, Loss: {loss:.4f}, Train: {train_acc:.4f}, Valid: {valid_acc:.4f}, Test: {test_acc:.4f}")

    if trigger_times_no_edge >= patience_no_edge:
        print("Early stopping triggered!")
        break

# load the best model state
model_no_edge.load_state_dict(best_model_state_no_edge)
final_train, final_valid, final_test = evaluate_no_edge()
print(f"[ogbn-arxiv_no_edge] Best validation accuracy: {final_valid:.4f}, corresponding test accuracy: {final_test:.4f}")

Epoch: 001, Loss: 6.9605, Train: 0.1791, Valid: 0.0763, Test: 0.0586
Epoch: 002, Loss: 6.5561, Train: 0.1176, Valid: 0.2347, Test: 0.2204
Epoch: 003, Loss: 6.1048, Train: 0.1989, Valid: 0.2981, Test: 0.3183
Epoch: 004, Loss: 5.7291, Train: 0.2972, Valid: 0.2876, Test: 0.2595
Epoch: 005, Loss: 5.3743, Train: 0.3102, Valid: 0.2910, Test: 0.2552
Epoch: 006, Loss: 5.0659, Train: 0.2447, Valid: 0.2001, Test: 0.1749
Epoch: 007, Loss: 4.9449, Train: 0.2503, Valid: 0.2198, Test: 0.1909
Epoch: 008, Loss: 4.8078, Train: 0.2863, Valid: 0.2869, Test: 0.2582
Epoch: 009, Loss: 4.6061, Train: 0.3318, Valid: 0.3591, Test: 0.3436
Epoch: 010, Loss: 4.4422, Train: 0.3622, Valid: 0.4134, Test: 0.4135
Epoch: 011, Loss: 4.3661, Train: 0.3767, Valid: 0.4285, Test: 0.4242
Epoch: 012, Loss: 4.2621, Train: 0.3887, Valid: 0.4230, Test: 0.3997
Epoch: 013, Loss: 4.1627, Train: 0.4141, Valid: 0.4259, Test: 0.3943
Epoch: 014, Loss: 4.0409, Train: 0.4204, Valid: 0.4268, Test: 0.3930
Epoch: 015, Loss: 3.9257, Train: 0

* What's the best validation accuracy you can get? What's the corresponding test accuracy? Please report the results in this markdown cell (2 points).
<br><br>**ANS:**
Best Validation Accuracy: 0.4354
Corresponding Test Accuracy: 0.4122<br><br>
* What's your observation of the results (all the three situations)? Please write down your observation in this markdown cell (2 points).
<br><br>**ANS:**<br>

**Observation of the Results:**

**Full Graph Structure (GAT model):**

The best performance is achieved when the model utilizes both node features and graph edges. The GAT model captures the relationships between nodes effectively, leading to the highest validation and test accuracy.

**No Graph Structure (Only Node Features):**

When the graph structure is removed and the model only has access to node features, performance declines significantly. The model loses the ability to capture inter-node relationships, resulting in a noticeable drop in accuracy on both validation and test sets.

**No Edge Structure (Self-Loops Only):**

When the edge structure is further removed, leaving only self-loops (i.e., each node is connected only to itself), the performance drops even further. The minimal structure provided by self-loops doesn't offer enough information for the model to make accurate predictions, causing the test accuracy to be even lower than the model using only node features.

**Summary:**

The graph structure (edges) plays a crucial role in node classification tasks, as it enables the model to learn relationships between nodes. As we progressively remove the graph structure (first removing all edges, then using only self-loops), the model's performance declines. This highlights how important the topology of the graph is for capturing node dependencies, which is key for accurate predictions in graph-based tasks.



## Graph Convolutional Network (GCN) Implementation for Graph Classification

**Objective:** Implement a GCN from scratch to perform graph classification using an OGB dataset. Develop neural components, including the forward pass, as well as training and testing routines. There are 11 `todo`s and 7 questions for GCN.

## 1. Introduction

Provide an overview of Graph Neural Networks (GNNs) and the significance of GCNs in processing graph-structured data. Discuss the relevance of graph classification tasks and the role of the OGB datasets in benchmarking.

## 2. Dataset Exploration

### 2.1. Importing Libraries



In [5]:
import torch
import torch.nn.functional as F
from torch_geometric.data import DataLoader
from ogb.graphproppred import PygGraphPropPredDataset, Evaluator
import torch_geometric

### 2.2. Loading Dataset

In [32]:
# Import necessary modules
from ogb.graphproppred import PygGraphPropPredDataset  # OGB dataset loader for graph property prediction
from torch_geometric.data import DataLoader  # PyG DataLoader for handling batches of graphs

# Load the OGB dataset for molecular property prediction.
# 'ogbg-molhiv' is a graph-level dataset where each graph represents a molecular structure,
# and the task is binary classification to predict whether the molecule inhibits HIV replication.
dataset = PygGraphPropPredDataset(name='ogbg-molhiv')

# Split the dataset into training, validation, and test sets.
# The dataset provides predefined splits to ensure consistency in evaluation.
split_idx = dataset.get_idx_split()
train_dataset = dataset[split_idx['train']]  # Training set
valid_dataset = dataset[split_idx['valid']]  # Validation set
test_dataset = dataset[split_idx['test']]    # Test set

# Create DataLoaders for batch processing.
# The DataLoader enables efficient loading of graphs in batches for training and evaluation.
# - batch_size: Number of graphs per batch.
# - shuffle: Whether to shuffle the dataset at each epoch (only done for training).
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)   # Shuffle for training
valid_loader = DataLoader(valid_dataset, batch_size=128, shuffle=False)  # No shuffle for validation
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False)    # No shuffle for testing

  self.data, self.slices = torch.load(self.processed_paths[0])


In [33]:
# Display dataset information
print(f'Dataset name: {dataset.name}')
print(f'Number of graphs: {len(dataset)}')
print(f'Number of classes: {dataset.num_tasks}')
print(f'Number of node features: {dataset.num_node_features}')

Dataset name: ogbg-molhiv
Number of graphs: 41127
Number of classes: 1
Number of node features: 9


## 3. Model Implementation
### 3.1. GCN Convolution Layer Implementation

In [34]:
import torch.nn as nn
from torch_geometric.nn import MessagePassing  # Base class for defining message-passing layers
from torch_geometric.utils import add_self_loops, degree  # Utilities for graph processing

class GCNConv(MessagePassing):
    def __init__(self, in_channels, out_channels):
        """
        Custom implementation of a Graph Convolutional Network (GCN) layer.

        Args:
            in_channels (int): Number of input node features.
            out_channels (int): Number of output node features.

        The layer follows the formulation:
            H' = σ(D^(-1/2) A D^(-1/2) H W)
        where:
            - A is the adjacency matrix with self-loops.
            - D is the degree matrix.
            - H is the input node feature matrix.
            - W is the trainable weight matrix.
            - σ is a non-linearity (like ReLU).
        """
        super(GCNConv, self).__init__(aggr='add')  # "Add" aggregation means summing messages from neighbors.

        # TODO 1: Define a linear transformation layer for node features (2 points)

        self.linear = nn.Linear(in_channels, out_channels)

        # TODO 1: Define a linear transformation layer for node features

    def forward(self, x, edge_index):
        """
        Forward pass of the GCN layer.

        Args:
            x (Tensor): Node feature matrix of shape [num_nodes, in_channels].
            edge_index (Tensor): Graph connectivity in COO format, shape [2, num_edges].

        Returns:
            Tensor: Updated node features of shape [num_nodes, out_channels].
        """

        # TODO 2: Add self-loops to the adjacency matrix (2 points)

        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))

        # TODO 2: Add self-loops to the adjacency matrix

        # TODO 3: Compute node degrees (2 points)

        row, col = edge_index  # Extract row (source nodes) and col (destination nodes)
        deg = degree(row, x.size(0), dtype=x.dtype)  # Compute degree of each node
        deg_inv_sqrt = deg.pow(-0.5)  # Compute D^(-1/2)

        # TODO 3: Compute node degrees 

        # Prevent division by zero for isolated nodes
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0  

        # TODO 4: Compute normalized adjacency matrix (2 points)

        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]  # Normalize using degrees

        # TODO 4: Compute normalized adjacency matrix 

        # TODO 5: Perform message passing using PyTorch Geometric's propagate function (2 points)

        out = self.propagate(edge_index, x=x, norm=norm)

        # TODO 5: Perform message passing using PyTorch Geometric's propagate function 

        # TODO 6: Apply the linear transformation after aggregation (2 points)

        out = self.linear(out)

        # TODO 6: Apply the linear transformation after aggregation 

        return out

    def message(self, x_j, norm):
        """
        Message function: Defines how information is aggregated from neighboring nodes.

        Args:
            x_j (Tensor): Features of neighboring nodes.
            norm (Tensor): Normalization coefficients computed from node degrees.

        Returns:
            Tensor: Normalized node features.
        """
        # TODO 7: Apply normalization to the node features (2 points)

        return norm.view(-1, 1) * x_j 
    
        # TODO 7: Apply normalization to the node features

## TODO:
### Why is it important to add self-loops in the graph convolution process, and how does the normalization strategy here help stabilize training in graph neural networks? (2 points)

**ANS:**

**Importance of Adding Self-Loops:**

Self-loops are crucial in graph convolutions as they allow each node to retain its own information during the aggregation process. Without self-loops, nodes could lose their identity across layers, leading to ineffective learning.

**Normalization Strategy for Stabilizing Training:**

The normalization strategy, using the inverse square root of node degrees, helps stabilize training by preventing nodes with high degrees from dominating the message aggregation. This ensures balanced information flow, avoids exploding gradients, and promotes stable learning in deep networks.

### 3.2. Implement GCN

In [35]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import global_mean_pool

class GCN(torch.nn.Module):
    def __init__(self, num_node_features, hidden_channels, num_classes, dropout_rate):
        """
        Graph Convolutional Network (GCN) for graph classification.

        Args:
            num_node_features (int): Number of input node features.
            hidden_channels (int): Number of hidden units in the GCN layers.
            num_classes (int): Number of output classes (for classification tasks).
            dropout_rate (float): Dropout probability for regularization.

        This model consists of:
        - Two GCNConv layers to extract graph structure-aware features.
        - ReLU activation and dropout for regularization.
        - A global mean pooling layer to aggregate node features into graph-level representations.
        - A final linear layer for classification.
        """
        super(GCN, self).__init__()

        self.dropout_rate = dropout_rate  # Store dropout rate

        # TODO 8: Complete the network layer definitions (2 points)

        # First graph convolution layer: transforms node features to hidden dimension
        self.gcn1 = GCNConv(num_node_features, hidden_channels)

        # Second graph convolution layer: refines node embeddings
        self.gcn2 = GCNConv(hidden_channels, hidden_channels)

        # Final linear layer to produce class predictions from the pooled graph representation
        self.fc = nn.Linear(hidden_channels, num_classes)

        # TODO 8: Complete the network layer definitions (2 points)

    def forward(self, x, edge_index, batch):
        """
        Forward pass through the GCN.

        Args:
            x (Tensor): Node feature matrix of shape [num_nodes, num_node_features].
            edge_index (Tensor): Graph connectivity in COO format [2, num_edges].
            batch (Tensor): Batch index for each node, used for global pooling.

        Returns:
            Tensor: Output class logits of shape [num_graphs, num_classes].
        """

        # TODO 9: Complete the implementation of the forward function (2 points)

        # First GCN layer: apply convolution, activation, and dropout
        x = self.gcn1(x, edge_index)
        x = F.relu(x)  # ReLU activation layer
        x = F.dropout(x, p=self.dropout_rate, training=self.training)  # Dropout layer for regularization

        # Second GCN layer
        x = self.gcn2(x, edge_index)
        x = F.relu(x)  # ReLU activation layer
        x = F.dropout(x, p=self.dropout_rate, training=self.training)  # Dropout layer for regularization

        # Global mean pooling to aggregate node features into a graph representation
        x = global_mean_pool(x, batch)  # Output shape: [num_graphs, hidden_channels]

        # Final linear layer for classification
        x = self.fc(x)  # Output shape: [num_graphs, num_classes]

        # TODO 9: Complete the implementation of the forward function (2 points)
        return x

## 4. Model Training
### 4.1. Implement Training and Evalutation Function

In [36]:
def train(model, loader, device):
    """
    Training function for the GCN model.

    Args:
        model (torch.nn.Module): The GCN model.
        loader (DataLoader): DataLoader for the training set.
        criterion (torch.nn.Module): Loss function.
        optimizer (torch.optim.Optimizer): Optimizer for training.
        device (torch.device): Device to run computations on (CPU or GPU).

    Returns:
        float: Average training loss per graph.
    """

    model.train()
    total_loss = 0
    # Iterate over batches in the DataLoader
    for data in loader:
        data = data.to(device)

        # TODO 10: Complete the training function (2 points)

        #  Zero the gradients of the Adam optimizer from `torch.optim`
        optimizer.zero_grad()

        # Perform a forward pass through the model
        out = model(data.x, data.edge_index, data.batch)

        # Compute the loss use `BCEWithLogitsLoss()``

        # This loss is suitable for binary classification tasks (e.g., graph classification)
        loss = F.binary_cross_entropy_with_logits(out, data.y.float())

        #  Perform backpropagation
        loss.backward()

        # Update model parameters
        optimizer.step()

        # TODO 10: Complete the training function

        # Accumulate loss
        total_loss += loss.item() * data.num_graphs

    return total_loss / len(loader.dataset)

## TODO:
### How does the choice of the BCEWithLogitsLoss and the Adam optimizer influence the training dynamics in this setting? (2 points)

**ANS:**

1. **BCEWithLogitsLoss**:
   - **BCEWithLogitsLoss** is a loss function specifically designed for binary classification tasks, where the outputs are logits (raw, unnormalized scores) rather than probabilities. It combines a **sigmoid activation** with **binary cross-entropy loss**, which makes it numerically stable by avoiding potential issues with small values from separate sigmoid and loss operations.
   - **Effect on Training**: Since the GCN model produces logits (not probabilities), BCEWithLogitsLoss directly compares the raw logits with the target labels. This is ideal for graph-based binary classification tasks, as it allows the model to learn and update its parameters efficiently while avoiding computational issues related to individual sigmoid and cross-entropy steps.

2. **Adam Optimizer**:
   - **Adam** (short for Adaptive Moment Estimation) is an adaptive learning rate optimization algorithm that computes individual learning rates for different parameters based on estimates of first and second moments (mean and uncentered variance). It combines the benefits of **Momentum** and **RMSprop**, making it well-suited for training deep learning models.
   - **Effect on Training**: Adam adapts the learning rate for each parameter, which helps speed up training and provides robustness against poor initialization. This is particularly useful in graph-based models where the structure of the data can vary significantly. Adam's ability to adapt to the landscape of the loss function helps the model converge faster and more reliably, especially when dealing with complex graph structures.


### What are potential alternative choices and their pros/cons for graph-based binary classification tasks? (2 points)

**ANS:**
1. **Alternative Loss Functions**:
   
   - **CrossEntropyLoss (for multi-class classification)**:
     - **Pros**: Useful if the binary classification task is extended to multi-class problems, as it can handle multi-class classification directly.
     - **Cons**: If the task is strictly binary, BCEWithLogitsLoss would be more appropriate since CrossEntropyLoss requires one-hot encoded labels and can be less efficient for binary tasks.
   
   - **Hinge Loss (SVM-based)**:
     - **Pros**: Effective for margin-based classification tasks where you care about not just correct classification but also the confidence in classification. It can lead to better generalization.
     - **Cons**: Not as commonly used in graph neural networks, especially for tasks like graph classification, where BCEWithLogitsLoss typically provides a smoother and more efficient training process.

2. **Alternative Optimizers**:
   
   - **Stochastic Gradient Descent (SGD)**:
     - **Pros**: One of the simplest and most widely used optimizers, suitable for large-scale datasets and relatively simple models. It is a good option if computational resources are limited.
     - **Cons**: Does not adapt learning rates, making it slower to converge in deep learning tasks. Requires careful tuning of the learning rate and momentum parameters, which can be cumbersome.

   - **RMSprop**:
     - **Pros**: Like Adam, it adapts the learning rate based on recent gradient information, but it does so by maintaining a moving average of squared gradients. Works well with non-stationary objectives.
     - **Cons**: While it adapts learning rates like Adam, it can be less robust for certain tasks and can require tuning of additional hyperparameters like the decay term.

### **Summary**:
- **BCEWithLogitsLoss** is ideal for binary classification tasks where the outputs are logits, and **Adam** is well-suited for the dynamic and complex landscape of graph data, providing fast convergence and reliable training.
- **Alternative loss functions** like **CrossEntropyLoss** or **Hinge Loss** may be useful for different types of classification tasks, but they are generally less efficient or appropriate for binary graph classification tasks.
- **Alternative optimizers** like **SGD** or **RMSprop** could work but may require more hyperparameter tuning and could converge more slowly compared to Adam in graph-based models.

In [37]:
def evaluate(model, loader, evaluator, device):
    model.eval()
    y_true = []
    y_pred = []
    with torch.no_grad():
        for data in loader:
            data = data.to(device)
            out = model(data.x, data.edge_index, data.batch)
            y_true.append(data.y.view(-1, 1))
            y_pred.append(out)
    y_true = torch.cat(y_true, dim=0)
    y_pred = torch.cat(y_pred, dim=0)
    return evaluator.eval({'y_true': y_true, 'y_pred': y_pred})

## TODO:
### What potential pitfalls can arise during the evaluation phase of graph neural networks? (2 points)

**ANS:**

1. **Data Leakage**:
   - One common pitfall in evaluation is **data leakage**, where information from the validation or test set leaks into the model during training, leading to overly optimistic evaluation results. In graph-based models, leakage can occur if the graph structures or node features are inadvertently shared between training and test data (e.g., improper splitting of data, shared node identifiers across splits, etc.).
   - **Solution**: Always ensure proper data splits and that the validation/test graphs are completely independent from training graphs. In graph classification tasks, using **graph-level splits** ensures no leakage.

2. **Overestimating Model Performance**:
   - **Evaluation on the same dataset used for training** or performing **evaluation without proper validation** can result in overly optimistic performance metrics. For example, running evaluations on the training set or testing on graphs that have similar features can lead to inflated metrics.
   - **Solution**: Evaluate on a **separate validation/test set** that has not been seen by the model during training. Also, ensure the evaluation follows proper cross-validation procedures where applicable.


### How might one detect issues like overfitting or underfitting using evaluation metrics such as ROC-AUC? (2 points)

**ANS:**

1. **Overfitting**:
   - Overfitting occurs when a model performs well on the training data but poorly on the validation/test set. It can be detected by monitoring the **ROC-AUC** (Receiver Operating Characteristic - Area Under Curve) or any other performance metric during training and evaluation.
   - **Detection**: If the ROC-AUC on the training set is **significantly higher** than on the validation/test set, this is a clear indication of overfitting. The model has learned the specifics of the training data, but it fails to generalize to unseen data.
   - **Solution**: To mitigate overfitting, use techniques like **regularization**, **dropout**, **early stopping**, or **cross-validation** to ensure the model generalizes well.

2. **Underfitting**:
   - Underfitting occurs when a model fails to learn the underlying patterns in the data, typically due to an overly simplistic model or inadequate training. It can be detected if both the training and validation/test ROC-AUC scores are **significantly low**.
   - **Detection**: If the ROC-AUC score is low on both the training and test sets, it indicates that the model has not learned the relevant features, possibly due to insufficient complexity (e.g., too few layers or features) or inadequate training.
   - **Solution**: To address underfitting, increase the model complexity (e.g., adding more layers, increasing hidden units), provide better features, or train for more epochs to allow the model to learn more effectively.

### **Summary**:
- **Pitfalls**: Issues like data leakage or overestimating performance can arise if proper data splits and independent evaluation sets are not used.
- **Detection of Overfitting/Underfitting**: Using **ROC-AUC**, overfitting is detected when performance on the training set is much better than on the validation/test set, and underfitting is detected when performance is low on both.

### 4.2. Training Loop (Hyper-parameter Tuning to reach `Validation AUC` $\ge$ 0.7 ) 

1 pt: `Validation AUC` $\ge$ 0.69


1.5 pts: `Validation AUC` $\ge$ 0.71


2 pts: `Validation AUC` $\ge$ 0.73

In [38]:
# TODO 11: Tune your hyper-parameter to achieve high performance (2 points)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = GCN(
    num_node_features=dataset.num_node_features, 
    hidden_channels=128,  # Increased hidden units
    num_classes=dataset.num_tasks, 
    dropout_rate=0.3
).to(device)
evaluator = Evaluator(name='ogbg-molhiv')

# Initialize the Adam optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=0.0005, betas=(0.85, 0.999), weight_decay=1e-5)

# Learning Rate Scheduler (Cosine Annealing for smooth decay)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=50)

# TODO 11: Tune your hyper-parameter to achieve high performance

### 4.2. Training Loop (Hyper-parameter Tuning to reach `Validation AUC` $\ge$ 0.7 ) 

1 pt: `Validation AUC` $\ge$ 0.69


1.5 pts: `Validation AUC` $\ge$ 0.71


2 pts: `Validation AUC` $\ge$ 0.73

In [39]:
from tqdm import tqdm
num_epochs = 150
best_valid_auc = 0
patience = 20      # Early stopping patience to avoid overfitting
trigger_times = 0  # Counter for early stopping

for epoch in tqdm(range(1, num_epochs + 1)):
    # Training
    train_loss = train(model, train_loader, device)

    # Evaluation
    train_result = evaluate(model, train_loader, evaluator, device)
    valid_result = evaluate(model, valid_loader, evaluator, device)

    train_auc = train_result['rocauc']
    valid_auc = valid_result['rocauc']

    # Print metrics
    print(f'Epoch: {epoch:03d}, '
          f'Train Loss: {train_loss:.4f}, '
          f'Train AUC: {train_auc:.4f}, '
          f'Validation AUC: {valid_auc:.4f}')

    # Save the best model
    if valid_auc > best_valid_auc:
        best_valid_auc = valid_auc
        # Save the model with the best validation AUC
        torch.save(model.state_dict(), 'best_model.pth')
        print(f'Saved best model with Validation AUC: {best_valid_auc:.4f}')
        trigger_times = 0  # Reset early stopping trigger
    else:
        trigger_times += 1  # Increment early stopping trigger
    
    # Early stopping if validation AUC does not improve for `patience` epochs
    if trigger_times >= patience:
        print(f"Early stopping at epoch {epoch} due to no improvement in Validation AUC.")
        break

  1%|          | 1/150 [00:13<33:54, 13.65s/it]

Epoch: 001, Train Loss: 0.1781, Train AUC: 0.4759, Validation AUC: 0.5613
Saved best model with Validation AUC: 0.5613


  1%|▏         | 2/150 [00:21<25:40, 10.41s/it]

Epoch: 002, Train Loss: 0.1600, Train AUC: 0.5736, Validation AUC: 0.6429
Saved best model with Validation AUC: 0.6429


  2%|▏         | 3/150 [00:29<22:40,  9.25s/it]

Epoch: 003, Train Loss: 0.1597, Train AUC: 0.5815, Validation AUC: 0.6460
Saved best model with Validation AUC: 0.6460


  3%|▎         | 4/150 [00:36<20:01,  8.23s/it]

Epoch: 004, Train Loss: 0.1589, Train AUC: 0.5884, Validation AUC: 0.6470
Saved best model with Validation AUC: 0.6470


  3%|▎         | 5/150 [00:42<18:20,  7.59s/it]

Epoch: 005, Train Loss: 0.1588, Train AUC: 0.5996, Validation AUC: 0.6413


  4%|▍         | 6/150 [00:50<18:01,  7.51s/it]

Epoch: 006, Train Loss: 0.1588, Train AUC: 0.6090, Validation AUC: 0.6439


  5%|▍         | 7/150 [00:56<17:03,  7.16s/it]

Epoch: 007, Train Loss: 0.1585, Train AUC: 0.6184, Validation AUC: 0.6440


  5%|▌         | 8/150 [01:03<16:46,  7.09s/it]

Epoch: 008, Train Loss: 0.1586, Train AUC: 0.6205, Validation AUC: 0.6313


  6%|▌         | 9/150 [01:09<16:10,  6.88s/it]

Epoch: 009, Train Loss: 0.1583, Train AUC: 0.6212, Validation AUC: 0.6313


  7%|▋         | 10/150 [01:16<15:56,  6.83s/it]

Epoch: 010, Train Loss: 0.1586, Train AUC: 0.6294, Validation AUC: 0.6340


  7%|▋         | 11/150 [01:23<15:40,  6.77s/it]

Epoch: 011, Train Loss: 0.1578, Train AUC: 0.6362, Validation AUC: 0.6362


  8%|▊         | 12/150 [01:29<15:28,  6.73s/it]

Epoch: 012, Train Loss: 0.1580, Train AUC: 0.6468, Validation AUC: 0.6500
Saved best model with Validation AUC: 0.6500


  9%|▊         | 13/150 [01:36<15:04,  6.60s/it]

Epoch: 013, Train Loss: 0.1577, Train AUC: 0.6380, Validation AUC: 0.6336


  9%|▉         | 14/150 [01:43<15:25,  6.81s/it]

Epoch: 014, Train Loss: 0.1579, Train AUC: 0.6345, Validation AUC: 0.6271


 10%|█         | 15/150 [01:51<16:07,  7.17s/it]

Epoch: 015, Train Loss: 0.1569, Train AUC: 0.6496, Validation AUC: 0.6424


 11%|█         | 16/150 [01:58<16:01,  7.17s/it]

Epoch: 016, Train Loss: 0.1566, Train AUC: 0.6422, Validation AUC: 0.6319


 11%|█▏        | 17/150 [02:05<15:30,  7.00s/it]

Epoch: 017, Train Loss: 0.1570, Train AUC: 0.6477, Validation AUC: 0.6344


 12%|█▏        | 18/150 [02:11<14:55,  6.79s/it]

Epoch: 018, Train Loss: 0.1574, Train AUC: 0.6405, Validation AUC: 0.6273


 13%|█▎        | 19/150 [02:17<14:26,  6.61s/it]

Epoch: 019, Train Loss: 0.1569, Train AUC: 0.6496, Validation AUC: 0.6351


 13%|█▎        | 20/150 [02:24<14:06,  6.51s/it]

Epoch: 020, Train Loss: 0.1566, Train AUC: 0.6467, Validation AUC: 0.6277


 14%|█▍        | 21/150 [02:30<13:56,  6.49s/it]

Epoch: 021, Train Loss: 0.1564, Train AUC: 0.6611, Validation AUC: 0.6486


 15%|█▍        | 22/150 [02:37<14:00,  6.57s/it]

Epoch: 022, Train Loss: 0.1570, Train AUC: 0.6638, Validation AUC: 0.6481


 15%|█▌        | 23/150 [02:43<13:56,  6.59s/it]

Epoch: 023, Train Loss: 0.1562, Train AUC: 0.6609, Validation AUC: 0.6438


 16%|█▌        | 24/150 [02:50<13:33,  6.46s/it]

Epoch: 024, Train Loss: 0.1563, Train AUC: 0.6537, Validation AUC: 0.6320


 17%|█▋        | 25/150 [02:56<13:22,  6.42s/it]

Epoch: 025, Train Loss: 0.1556, Train AUC: 0.6530, Validation AUC: 0.6320


 17%|█▋        | 26/150 [03:02<13:07,  6.35s/it]

Epoch: 026, Train Loss: 0.1557, Train AUC: 0.6634, Validation AUC: 0.6428


 18%|█▊        | 27/150 [03:09<13:27,  6.56s/it]

Epoch: 027, Train Loss: 0.1558, Train AUC: 0.6704, Validation AUC: 0.6564
Saved best model with Validation AUC: 0.6564


 19%|█▊        | 28/150 [03:16<13:45,  6.77s/it]

Epoch: 028, Train Loss: 0.1549, Train AUC: 0.6666, Validation AUC: 0.6421


 19%|█▉        | 29/150 [03:23<13:27,  6.67s/it]

Epoch: 029, Train Loss: 0.1555, Train AUC: 0.6703, Validation AUC: 0.6448


 20%|██        | 30/150 [03:29<13:05,  6.54s/it]

Epoch: 030, Train Loss: 0.1547, Train AUC: 0.6771, Validation AUC: 0.6507


 21%|██        | 31/150 [03:36<13:08,  6.63s/it]

Epoch: 031, Train Loss: 0.1547, Train AUC: 0.6795, Validation AUC: 0.6561


 21%|██▏       | 32/150 [03:42<12:51,  6.53s/it]

Epoch: 032, Train Loss: 0.1541, Train AUC: 0.6763, Validation AUC: 0.6514


 22%|██▏       | 33/150 [03:48<12:36,  6.46s/it]

Epoch: 033, Train Loss: 0.1543, Train AUC: 0.6705, Validation AUC: 0.6418


 23%|██▎       | 34/150 [03:55<12:20,  6.39s/it]

Epoch: 034, Train Loss: 0.1536, Train AUC: 0.6842, Validation AUC: 0.6510


 23%|██▎       | 35/150 [04:02<12:46,  6.67s/it]

Epoch: 035, Train Loss: 0.1540, Train AUC: 0.6839, Validation AUC: 0.6614
Saved best model with Validation AUC: 0.6614


 24%|██▍       | 36/150 [04:08<12:25,  6.54s/it]

Epoch: 036, Train Loss: 0.1532, Train AUC: 0.6871, Validation AUC: 0.6642
Saved best model with Validation AUC: 0.6642


 25%|██▍       | 37/150 [04:14<12:07,  6.44s/it]

Epoch: 037, Train Loss: 0.1539, Train AUC: 0.6921, Validation AUC: 0.6699
Saved best model with Validation AUC: 0.6699


 25%|██▌       | 38/150 [04:21<11:48,  6.33s/it]

Epoch: 038, Train Loss: 0.1530, Train AUC: 0.6932, Validation AUC: 0.6604


 26%|██▌       | 39/150 [04:27<11:51,  6.41s/it]

Epoch: 039, Train Loss: 0.1526, Train AUC: 0.6886, Validation AUC: 0.6677


 27%|██▋       | 40/150 [04:33<11:41,  6.38s/it]

Epoch: 040, Train Loss: 0.1532, Train AUC: 0.6907, Validation AUC: 0.6559


 27%|██▋       | 41/150 [04:41<12:03,  6.64s/it]

Epoch: 041, Train Loss: 0.1525, Train AUC: 0.6959, Validation AUC: 0.6612


 28%|██▊       | 42/150 [04:47<12:01,  6.68s/it]

Epoch: 042, Train Loss: 0.1524, Train AUC: 0.6956, Validation AUC: 0.6648


 29%|██▊       | 43/150 [04:54<11:42,  6.56s/it]

Epoch: 043, Train Loss: 0.1519, Train AUC: 0.7001, Validation AUC: 0.6737
Saved best model with Validation AUC: 0.6737


 29%|██▉       | 44/150 [05:00<11:23,  6.45s/it]

Epoch: 044, Train Loss: 0.1519, Train AUC: 0.6987, Validation AUC: 0.6827
Saved best model with Validation AUC: 0.6827


 30%|███       | 45/150 [05:06<11:13,  6.42s/it]

Epoch: 045, Train Loss: 0.1519, Train AUC: 0.7014, Validation AUC: 0.6809


 31%|███       | 46/150 [05:14<11:42,  6.76s/it]

Epoch: 046, Train Loss: 0.1510, Train AUC: 0.6951, Validation AUC: 0.6635


 31%|███▏      | 47/150 [05:20<11:22,  6.62s/it]

Epoch: 047, Train Loss: 0.1512, Train AUC: 0.6979, Validation AUC: 0.6694


 32%|███▏      | 48/150 [05:27<11:18,  6.65s/it]

Epoch: 048, Train Loss: 0.1513, Train AUC: 0.7065, Validation AUC: 0.6838
Saved best model with Validation AUC: 0.6838


 33%|███▎      | 49/150 [05:33<10:58,  6.52s/it]

Epoch: 049, Train Loss: 0.1506, Train AUC: 0.7057, Validation AUC: 0.6707


 33%|███▎      | 50/150 [05:39<10:42,  6.42s/it]

Epoch: 050, Train Loss: 0.1502, Train AUC: 0.7054, Validation AUC: 0.6821


 34%|███▍      | 51/150 [05:45<10:27,  6.34s/it]

Epoch: 051, Train Loss: 0.1502, Train AUC: 0.7039, Validation AUC: 0.6696


 35%|███▍      | 52/150 [05:52<10:16,  6.29s/it]

Epoch: 052, Train Loss: 0.1499, Train AUC: 0.7072, Validation AUC: 0.6897
Saved best model with Validation AUC: 0.6897


 35%|███▌      | 53/150 [05:58<10:09,  6.28s/it]

Epoch: 053, Train Loss: 0.1503, Train AUC: 0.7096, Validation AUC: 0.6855


 36%|███▌      | 54/150 [06:04<10:12,  6.38s/it]

Epoch: 054, Train Loss: 0.1502, Train AUC: 0.7110, Validation AUC: 0.6786


 37%|███▋      | 55/150 [06:11<10:22,  6.55s/it]

Epoch: 055, Train Loss: 0.1493, Train AUC: 0.7077, Validation AUC: 0.7004
Saved best model with Validation AUC: 0.7004


 37%|███▋      | 56/150 [06:18<10:11,  6.50s/it]

Epoch: 056, Train Loss: 0.1492, Train AUC: 0.7127, Validation AUC: 0.6960


 38%|███▊      | 57/150 [06:25<10:10,  6.57s/it]

Epoch: 057, Train Loss: 0.1491, Train AUC: 0.7080, Validation AUC: 0.6953


 39%|███▊      | 58/150 [06:31<09:55,  6.47s/it]

Epoch: 058, Train Loss: 0.1492, Train AUC: 0.7129, Validation AUC: 0.7012
Saved best model with Validation AUC: 0.7012


 39%|███▉      | 59/150 [06:37<09:42,  6.40s/it]

Epoch: 059, Train Loss: 0.1489, Train AUC: 0.7145, Validation AUC: 0.7015
Saved best model with Validation AUC: 0.7015


 40%|████      | 60/150 [06:43<09:32,  6.36s/it]

Epoch: 060, Train Loss: 0.1486, Train AUC: 0.7147, Validation AUC: 0.6943


 41%|████      | 61/150 [06:49<09:22,  6.32s/it]

Epoch: 061, Train Loss: 0.1493, Train AUC: 0.7122, Validation AUC: 0.6957


 41%|████▏     | 62/150 [06:56<09:32,  6.51s/it]

Epoch: 062, Train Loss: 0.1485, Train AUC: 0.7172, Validation AUC: 0.7013


 42%|████▏     | 63/150 [07:03<09:37,  6.64s/it]

Epoch: 063, Train Loss: 0.1486, Train AUC: 0.7142, Validation AUC: 0.6843


 43%|████▎     | 64/150 [07:10<09:21,  6.52s/it]

Epoch: 064, Train Loss: 0.1482, Train AUC: 0.7126, Validation AUC: 0.7041
Saved best model with Validation AUC: 0.7041


 43%|████▎     | 65/150 [07:16<09:16,  6.55s/it]

Epoch: 065, Train Loss: 0.1485, Train AUC: 0.7165, Validation AUC: 0.7138
Saved best model with Validation AUC: 0.7138


 44%|████▍     | 66/150 [07:23<09:03,  6.47s/it]

Epoch: 066, Train Loss: 0.1485, Train AUC: 0.7199, Validation AUC: 0.7053


 45%|████▍     | 67/150 [07:29<08:49,  6.38s/it]

Epoch: 067, Train Loss: 0.1483, Train AUC: 0.7194, Validation AUC: 0.7027


 45%|████▌     | 68/150 [07:35<08:37,  6.32s/it]

Epoch: 068, Train Loss: 0.1474, Train AUC: 0.7212, Validation AUC: 0.7091


 46%|████▌     | 69/150 [07:41<08:29,  6.29s/it]

Epoch: 069, Train Loss: 0.1482, Train AUC: 0.7127, Validation AUC: 0.6868


 47%|████▋     | 70/150 [07:48<08:31,  6.39s/it]

Epoch: 070, Train Loss: 0.1485, Train AUC: 0.7220, Validation AUC: 0.6929


 47%|████▋     | 71/150 [07:54<08:20,  6.33s/it]

Epoch: 071, Train Loss: 0.1475, Train AUC: 0.7230, Validation AUC: 0.6951


 48%|████▊     | 72/150 [08:00<08:13,  6.33s/it]

Epoch: 072, Train Loss: 0.1477, Train AUC: 0.7229, Validation AUC: 0.7011


 49%|████▊     | 73/150 [08:07<08:11,  6.38s/it]

Epoch: 073, Train Loss: 0.1473, Train AUC: 0.7243, Validation AUC: 0.6997


 49%|████▉     | 74/150 [08:14<08:19,  6.57s/it]

Epoch: 074, Train Loss: 0.1472, Train AUC: 0.7255, Validation AUC: 0.7063


 50%|█████     | 75/150 [08:20<08:16,  6.61s/it]

Epoch: 075, Train Loss: 0.1470, Train AUC: 0.7225, Validation AUC: 0.7031


 51%|█████     | 76/150 [08:27<08:13,  6.67s/it]

Epoch: 076, Train Loss: 0.1475, Train AUC: 0.7250, Validation AUC: 0.6973


 51%|█████▏    | 77/150 [08:34<08:01,  6.59s/it]

Epoch: 077, Train Loss: 0.1466, Train AUC: 0.7264, Validation AUC: 0.7011


 52%|█████▏    | 78/150 [08:40<07:59,  6.66s/it]

Epoch: 078, Train Loss: 0.1470, Train AUC: 0.7248, Validation AUC: 0.7042


 53%|█████▎    | 79/150 [08:47<07:43,  6.53s/it]

Epoch: 079, Train Loss: 0.1463, Train AUC: 0.7244, Validation AUC: 0.7117


 53%|█████▎    | 80/150 [08:53<07:32,  6.46s/it]

Epoch: 080, Train Loss: 0.1461, Train AUC: 0.7250, Validation AUC: 0.7155
Saved best model with Validation AUC: 0.7155


 54%|█████▍    | 81/150 [08:59<07:20,  6.38s/it]

Epoch: 081, Train Loss: 0.1472, Train AUC: 0.7282, Validation AUC: 0.6988


 55%|█████▍    | 82/150 [09:09<08:24,  7.42s/it]

Epoch: 082, Train Loss: 0.1473, Train AUC: 0.7262, Validation AUC: 0.7094


 55%|█████▌    | 83/150 [09:18<08:41,  7.78s/it]

Epoch: 083, Train Loss: 0.1470, Train AUC: 0.7280, Validation AUC: 0.6991


 56%|█████▌    | 84/150 [09:24<08:14,  7.49s/it]

Epoch: 084, Train Loss: 0.1460, Train AUC: 0.7277, Validation AUC: 0.7151


 57%|█████▋    | 85/150 [09:32<07:58,  7.36s/it]

Epoch: 085, Train Loss: 0.1466, Train AUC: 0.7296, Validation AUC: 0.7091


 57%|█████▋    | 86/150 [09:41<08:23,  7.87s/it]

Epoch: 086, Train Loss: 0.1460, Train AUC: 0.7283, Validation AUC: 0.7068


 58%|█████▊    | 87/150 [09:49<08:31,  8.12s/it]

Epoch: 087, Train Loss: 0.1458, Train AUC: 0.7305, Validation AUC: 0.7147


 59%|█████▊    | 88/150 [09:59<08:46,  8.49s/it]

Epoch: 088, Train Loss: 0.1457, Train AUC: 0.7308, Validation AUC: 0.7163
Saved best model with Validation AUC: 0.7163


 59%|█████▉    | 89/150 [10:07<08:40,  8.54s/it]

Epoch: 089, Train Loss: 0.1461, Train AUC: 0.7288, Validation AUC: 0.7143


 60%|██████    | 90/150 [10:17<09:01,  9.03s/it]

Epoch: 090, Train Loss: 0.1458, Train AUC: 0.7310, Validation AUC: 0.7047


 61%|██████    | 91/150 [10:27<09:01,  9.18s/it]

Epoch: 091, Train Loss: 0.1460, Train AUC: 0.7278, Validation AUC: 0.7188
Saved best model with Validation AUC: 0.7188


 61%|██████▏   | 92/150 [10:39<09:38,  9.97s/it]

Epoch: 092, Train Loss: 0.1463, Train AUC: 0.7298, Validation AUC: 0.7187


 62%|██████▏   | 93/150 [10:51<10:06, 10.64s/it]

Epoch: 093, Train Loss: 0.1454, Train AUC: 0.7326, Validation AUC: 0.7143


 63%|██████▎   | 94/150 [11:02<09:55, 10.64s/it]

Epoch: 094, Train Loss: 0.1456, Train AUC: 0.7326, Validation AUC: 0.7180


 63%|██████▎   | 95/150 [11:11<09:29, 10.35s/it]

Epoch: 095, Train Loss: 0.1457, Train AUC: 0.7340, Validation AUC: 0.7214
Saved best model with Validation AUC: 0.7214


 64%|██████▍   | 96/150 [11:23<09:37, 10.69s/it]

Epoch: 096, Train Loss: 0.1460, Train AUC: 0.7330, Validation AUC: 0.7182


 65%|██████▍   | 97/150 [11:34<09:33, 10.83s/it]

Epoch: 097, Train Loss: 0.1457, Train AUC: 0.7357, Validation AUC: 0.7133


 65%|██████▌   | 98/150 [11:44<09:17, 10.73s/it]

Epoch: 098, Train Loss: 0.1455, Train AUC: 0.7348, Validation AUC: 0.7193


 66%|██████▌   | 99/150 [11:53<08:37, 10.15s/it]

Epoch: 099, Train Loss: 0.1459, Train AUC: 0.7312, Validation AUC: 0.7103


 67%|██████▋   | 100/150 [12:04<08:35, 10.31s/it]

Epoch: 100, Train Loss: 0.1460, Train AUC: 0.7353, Validation AUC: 0.7151


 67%|██████▋   | 101/150 [12:15<08:37, 10.57s/it]

Epoch: 101, Train Loss: 0.1447, Train AUC: 0.7348, Validation AUC: 0.7198


 68%|██████▊   | 102/150 [12:24<08:07, 10.15s/it]

Epoch: 102, Train Loss: 0.1449, Train AUC: 0.7333, Validation AUC: 0.7126


 69%|██████▊   | 103/150 [12:32<07:20,  9.38s/it]

Epoch: 103, Train Loss: 0.1454, Train AUC: 0.7335, Validation AUC: 0.7056


 69%|██████▉   | 104/150 [12:39<06:41,  8.73s/it]

Epoch: 104, Train Loss: 0.1458, Train AUC: 0.7352, Validation AUC: 0.7219
Saved best model with Validation AUC: 0.7219


 70%|███████   | 105/150 [12:46<06:03,  8.07s/it]

Epoch: 105, Train Loss: 0.1450, Train AUC: 0.7374, Validation AUC: 0.7236
Saved best model with Validation AUC: 0.7236


 71%|███████   | 106/150 [12:53<05:49,  7.93s/it]

Epoch: 106, Train Loss: 0.1454, Train AUC: 0.7387, Validation AUC: 0.7236
Saved best model with Validation AUC: 0.7236


 71%|███████▏  | 107/150 [13:01<05:39,  7.90s/it]

Epoch: 107, Train Loss: 0.1448, Train AUC: 0.7388, Validation AUC: 0.7207


 72%|███████▏  | 108/150 [13:09<05:26,  7.76s/it]

Epoch: 108, Train Loss: 0.1447, Train AUC: 0.7379, Validation AUC: 0.7192


 73%|███████▎  | 109/150 [13:18<05:36,  8.21s/it]

Epoch: 109, Train Loss: 0.1447, Train AUC: 0.7369, Validation AUC: 0.7184


 73%|███████▎  | 110/150 [13:26<05:23,  8.08s/it]

Epoch: 110, Train Loss: 0.1448, Train AUC: 0.7384, Validation AUC: 0.7226


 74%|███████▍  | 111/150 [13:33<05:05,  7.82s/it]

Epoch: 111, Train Loss: 0.1447, Train AUC: 0.7381, Validation AUC: 0.7249
Saved best model with Validation AUC: 0.7249


 75%|███████▍  | 112/150 [13:40<04:52,  7.70s/it]

Epoch: 112, Train Loss: 0.1444, Train AUC: 0.7410, Validation AUC: 0.7199


 75%|███████▌  | 113/150 [13:48<04:41,  7.62s/it]

Epoch: 113, Train Loss: 0.1448, Train AUC: 0.7388, Validation AUC: 0.7145


 76%|███████▌  | 114/150 [13:55<04:34,  7.64s/it]

Epoch: 114, Train Loss: 0.1445, Train AUC: 0.7402, Validation AUC: 0.7250
Saved best model with Validation AUC: 0.7250


 77%|███████▋  | 115/150 [14:03<04:27,  7.63s/it]

Epoch: 115, Train Loss: 0.1444, Train AUC: 0.7401, Validation AUC: 0.7229


 77%|███████▋  | 116/150 [14:14<04:52,  8.60s/it]

Epoch: 116, Train Loss: 0.1442, Train AUC: 0.7391, Validation AUC: 0.7216


 78%|███████▊  | 117/150 [14:22<04:37,  8.41s/it]

Epoch: 117, Train Loss: 0.1449, Train AUC: 0.7402, Validation AUC: 0.7202


 79%|███████▊  | 118/150 [14:29<04:18,  8.09s/it]

Epoch: 118, Train Loss: 0.1444, Train AUC: 0.7415, Validation AUC: 0.7194


 79%|███████▉  | 119/150 [14:36<03:58,  7.70s/it]

Epoch: 119, Train Loss: 0.1451, Train AUC: 0.7437, Validation AUC: 0.7255
Saved best model with Validation AUC: 0.7255


 80%|████████  | 120/150 [14:42<03:39,  7.31s/it]

Epoch: 120, Train Loss: 0.1444, Train AUC: 0.7412, Validation AUC: 0.7209


 81%|████████  | 121/150 [14:49<03:23,  7.03s/it]

Epoch: 121, Train Loss: 0.1444, Train AUC: 0.7414, Validation AUC: 0.7347
Saved best model with Validation AUC: 0.7347


 81%|████████▏ | 122/150 [14:55<03:12,  6.86s/it]

Epoch: 122, Train Loss: 0.1433, Train AUC: 0.7416, Validation AUC: 0.7219


 82%|████████▏ | 123/150 [15:02<03:01,  6.73s/it]

Epoch: 123, Train Loss: 0.1435, Train AUC: 0.7445, Validation AUC: 0.7311


 83%|████████▎ | 124/150 [15:10<03:04,  7.10s/it]

Epoch: 124, Train Loss: 0.1441, Train AUC: 0.7423, Validation AUC: 0.7451
Saved best model with Validation AUC: 0.7451


 83%|████████▎ | 125/150 [15:18<03:07,  7.50s/it]

Epoch: 125, Train Loss: 0.1437, Train AUC: 0.7452, Validation AUC: 0.7318


 84%|████████▍ | 126/150 [15:24<02:51,  7.14s/it]

Epoch: 126, Train Loss: 0.1442, Train AUC: 0.7420, Validation AUC: 0.7413


 85%|████████▍ | 127/150 [15:31<02:41,  7.01s/it]

Epoch: 127, Train Loss: 0.1439, Train AUC: 0.7422, Validation AUC: 0.7268


 85%|████████▌ | 128/150 [15:37<02:31,  6.87s/it]

Epoch: 128, Train Loss: 0.1437, Train AUC: 0.7441, Validation AUC: 0.7287


 86%|████████▌ | 129/150 [15:44<02:24,  6.88s/it]

Epoch: 129, Train Loss: 0.1444, Train AUC: 0.7467, Validation AUC: 0.7386


 87%|████████▋ | 130/150 [15:51<02:17,  6.90s/it]

Epoch: 130, Train Loss: 0.1435, Train AUC: 0.7454, Validation AUC: 0.7282


 87%|████████▋ | 131/150 [15:59<02:15,  7.14s/it]

Epoch: 131, Train Loss: 0.1433, Train AUC: 0.7407, Validation AUC: 0.7263


 88%|████████▊ | 132/150 [16:08<02:17,  7.66s/it]

Epoch: 132, Train Loss: 0.1441, Train AUC: 0.7438, Validation AUC: 0.7284


 89%|████████▊ | 133/150 [16:16<02:13,  7.85s/it]

Epoch: 133, Train Loss: 0.1434, Train AUC: 0.7433, Validation AUC: 0.7315


 89%|████████▉ | 134/150 [16:24<02:06,  7.88s/it]

Epoch: 134, Train Loss: 0.1440, Train AUC: 0.7419, Validation AUC: 0.7359


 90%|█████████ | 135/150 [16:32<01:58,  7.89s/it]

Epoch: 135, Train Loss: 0.1448, Train AUC: 0.7469, Validation AUC: 0.7264


 91%|█████████ | 136/150 [16:41<01:56,  8.31s/it]

Epoch: 136, Train Loss: 0.1429, Train AUC: 0.7445, Validation AUC: 0.7343


 91%|█████████▏| 137/150 [16:49<01:46,  8.22s/it]

Epoch: 137, Train Loss: 0.1431, Train AUC: 0.7481, Validation AUC: 0.7387


 92%|█████████▏| 138/150 [16:58<01:38,  8.24s/it]

Epoch: 138, Train Loss: 0.1432, Train AUC: 0.7485, Validation AUC: 0.7403


 93%|█████████▎| 139/150 [17:07<01:34,  8.57s/it]

Epoch: 139, Train Loss: 0.1430, Train AUC: 0.7493, Validation AUC: 0.7322


 93%|█████████▎| 140/150 [17:17<01:29,  8.92s/it]

Epoch: 140, Train Loss: 0.1429, Train AUC: 0.7407, Validation AUC: 0.7255


 94%|█████████▍| 141/150 [17:24<01:16,  8.47s/it]

Epoch: 141, Train Loss: 0.1431, Train AUC: 0.7472, Validation AUC: 0.7427


 95%|█████████▍| 142/150 [17:31<01:03,  7.90s/it]

Epoch: 142, Train Loss: 0.1428, Train AUC: 0.7495, Validation AUC: 0.7251


 95%|█████████▌| 143/150 [17:37<00:52,  7.55s/it]

Epoch: 143, Train Loss: 0.1426, Train AUC: 0.7476, Validation AUC: 0.7363


 95%|█████████▌| 143/150 [17:44<00:52,  7.44s/it]

Epoch: 144, Train Loss: 0.1439, Train AUC: 0.7516, Validation AUC: 0.7373
Early stopping at epoch 144 due to no improvement in Validation AUC.





In [40]:
# Load the best model for final evaluation
model.load_state_dict(torch.load('best_model.pth'))

# Final evaluation on the test set
test_result = evaluate(model, test_loader, evaluator, device)
test_auc = test_result['rocauc']
print(f'Final Test AUC: {test_auc:.4f}')

  model.load_state_dict(torch.load('best_model.pth'))


Final Test AUC: 0.6865


### OPEN ENDED QUESTION:
#### Reflect on the entire implementation of the GCN model for molecular property prediction.
#### - What potential improvements or modifications could you suggest to enhance the model's performance? Consider aspects such as network architecture, hyperparameter tuning, data preprocessing, and evaluation strategies. (2 points)

#### **ANS: Potential Improvements & Modifications for Enhancing Performance**
Several aspects of the current GCN model can be improved to achieve better performance:

1. **Network Architecture Modifications:**
   - Introduce **Graph Attention Networks (GAT)** instead of GCN to allow adaptive weighting of neighboring nodes.
   - Use **Graph Isomorphism Network (GIN)**, which has been shown to be more powerful in capturing molecular graph structures.
   - Add **residual connections** between layers to prevent vanishing gradients and allow deeper architectures.

2. **Hyperparameter Tuning:**
   - Perform a **grid search** or **Bayesian optimization** over hidden dimensions, learning rate, dropout, and weight decay.
   - Try **learning rate warm-up** before applying decay to stabilize training.
   - Experiment with **larger batch sizes** (e.g., 512) to achieve smoother gradient updates.

3. **Data Preprocessing & Augmentation:**
   - Apply **feature engineering** on molecular graphs (e.g., adding edge attributes for bond types).
   - Consider **self-supervised pretraining** using contrastive learning techniques like **GraphCL** to improve node representations before fine-tuning.

4. **Evaluation Strategies:**
   - Use **stratified k-fold cross-validation** to ensure robustness in model evaluation.
   - Implement **class balancing techniques** to handle label imbalance in molecular property datasets.
   - Monitor **precision-recall curves** in addition to ROC-AUC for better insight into class performance.

---
#### - How might you leverage additional graph-specific techniques or modern architectures to push the boundaries of performance on this task? (2 points)

#### **ANS: Leveraging Additional Graph-Specific Techniques & Modern Architectures**
Beyond traditional GCNs, we can incorporate **more advanced graph neural network techniques**:

1. **Graph Transformers:**
   - Use **Graphormer** or **SAN (Structure-Aware Transformer)** to capture long-range dependencies in molecular graphs.
   - These models **outperform standard GNNs** on certain molecular property prediction tasks.

2. **Contrastive Learning for Graph Representations:**
   - Implement **Graph Contrastive Learning (GraphCL, InfoGraph)** to pretrain the model on unlabeled molecular graphs.
   - This can significantly improve feature representations and generalization.

3. **Multi-View Graph Learning:**
   - Use **multi-scale GCNs** that extract information at **different neighborhood levels**.
   - Combine **node-level and subgraph-level embeddings** for richer representations.

4. **Combining GNNs with Domain Knowledge:**
   - Incorporate **physically meaningful descriptors** (e.g., quantum chemical properties) as input features.
   - Hybrid models that integrate **GNNs with traditional cheminformatics features** (e.g., molecular fingerprints) can enhance prediction accuracy.

5. **Meta-Learning for Adaptive GNN Training:**
   - Use **meta-learning (MAML, Reptile)** to adapt the model quickly to different molecular tasks.
   - Helps in few-shot learning scenarios where labeled data is limited.


### **Summary**
To push the boundaries of performance, we can:
Adopt **advanced GNN architectures** (GAT, GIN, Graph Transformers).  
Use **contrastive learning & self-supervised pretraining**.  
Explore **multi-scale graph learning** & integrate domain knowledge.  
Improve **evaluation strategies & hyperparameter tuning**.

By combining these approaches, we can significantly enhance the model's effectiveness for molecular property prediction