# Graph Attention Network (GAT) Implementation

This notebook implements Graph Attention Networks (GATs) for node classification tasks using the Cora citation network dataset.

## Setup and Data Loading

Import necessary libraries and load the Cora citation network dataset.

In [None]:
# Core libraries
import torch
import torch.nn.functional as F
from torch.nn import Linear, Dropout

# PyTorch Geometric for graph neural networks
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import GATv2Conv

# Seed PyTorch's global random number generator so repeated runs start from the same state
SEED = 42
torch.manual_seed(SEED)

In [None]:
# Build the Planetoid 'Cora' citation dataset, rooted at the current directory
dataset = Planetoid(name='Cora', root='.')

# The first element of the dataset is the graph object used throughout the notebook
data = dataset[0]

In [None]:
# Summarise the dataset: gather every figure first, then emit the report in one go
rule = "=" * 50

n_graphs = len(dataset)
n_nodes = data.x.shape[0]
n_edges = data.edge_index.shape[1]
n_train = data.train_mask.sum().item()
n_val = data.val_mask.sum().item()
n_test = data.test_mask.sum().item()

report_lines = [
    "\n" + rule,
    f"{('CORA DATASET STATISTICS'):^50}",
    rule,
    f"Dataset Name:      {dataset.name}",
    f"Number of Graphs:  {n_graphs:,}",
    f"Number of Nodes:   {n_nodes:,}",
    f"Number of Edges:   {n_edges:,}",
    f"Node Features:     {dataset.num_features}",
    f"Number of Classes: {dataset.num_classes}",
    f"Train Nodes:       {n_train:,}",
    f"Validation Nodes:  {n_val:,}",
    f"Test Nodes:        {n_test:,}",
    rule,
]
print("\n".join(report_lines))

## Utility Functions

Helper functions for model training and evaluation.

In [None]:
def accuracy(y_pred, y_true):
    """Return the fraction of entries in ``y_pred`` that equal ``y_true``.

    The result is a tensor: the count of element-wise matches divided by
    ``len(y_true)``.
    """
    n_correct = (y_pred == y_true).sum()
    n_total = len(y_true)
    return n_correct / n_total

## Graph Attention Network (GAT)

Implementation of a two-layer Graph Attention Network with multi-head attention, built on PyTorch Geometric's `GATv2Conv` layer.

In [None]:
class GAT(torch.nn.Module):
    """Two-layer graph attention network for node classification.

    The first ``GATv2Conv`` layer uses ``heads`` attention heads, and the second
    layer takes ``dim_h * heads`` input features and uses a single head to
    produce one score per class. ``forward`` returns log-probabilities
    (``log_softmax`` over dimension 1).

    Args:
        dim_in: Number of input features per node.
        dim_h: Hidden size per attention head in the first layer.
        dim_out: Number of output classes.
        heads: Number of attention heads in the first layer.
        dropout: Dropout probability applied to the input features and to the
            hidden representation between the two layers.
    """

    def __init__(self, dim_in, dim_h, dim_out, heads=8, dropout=0.6):
        super().__init__()
        self.gat1 = GATv2Conv(dim_in, dim_h, heads=heads)  # First GAT layer with multi-head attention
        self.gat2 = GATv2Conv(dim_h * heads, dim_out, heads=1)  # Second GAT layer (single head for output)
        self.dropout = Dropout(dropout)  # Shared by the input and hidden dropout steps

    def forward(self, x, edge_index):
        """Compute per-node class log-probabilities.

        Args:
            x: Node feature matrix.
            edge_index: Graph connectivity passed straight to both GAT layers.

        Returns:
            Tensor of log-probabilities, normalised over dimension 1.
        """
        # Apply dropout to input features
        h = self.dropout(x)

        # First GAT layer with ELU activation
        h = self.gat1(h, edge_index)
        h = F.elu(h)
        h = self.dropout(h)

        # Second GAT layer
        h = self.gat2(h, edge_index)
        return F.log_softmax(h, dim=1)

    @torch.no_grad()
    def _evaluate(self, data, mask, criterion):
        """Return ``(loss, accuracy)`` on the nodes selected by ``mask``.

        Runs one forward pass in evaluation mode (dropout disabled) without
        tracking gradients, then restores whichever train/eval mode the module
        was in before the call.
        """
        was_training = self.training
        self.eval()
        try:
            out = self(data.x, data.edge_index)
            masked_out = out[mask]
            masked_y = data.y[mask]
            loss = criterion(masked_out, masked_y)
            acc = accuracy(masked_out.argmax(dim=1), masked_y)
        finally:
            # Always hand the module back in its previous mode
            self.train(was_training)
        return loss, acc

    def fit(self, data, epochs, lr=0.01, weight_decay=0.01, log_every=20):
        """Train the GAT model on the nodes selected by ``data.train_mask``.

        Runs ``epochs + 1`` optimisation steps (epochs ``0..epochs`` inclusive)
        with Adam and prints a progress row every ``log_every`` epochs. The
        validation figures in each row are computed with dropout disabled, on
        the weights obtained after that epoch's update.

        Args:
            data: Graph object providing ``x``, ``edge_index``, ``y``,
                ``train_mask`` and ``val_mask``.
            epochs: Index of the last epoch to run.
            lr: Adam learning rate.
            weight_decay: Adam weight decay (L2 penalty).
            log_every: Print a progress row every this many epochs; a value
                of 0 or less disables the per-epoch rows.
        """
        # forward() already returns log-probabilities, so NLLLoss is the matching
        # criterion; CrossEntropyLoss expects raw logits and would apply
        # log_softmax a second time.
        criterion = torch.nn.NLLLoss()
        optimizer = torch.optim.Adam(self.parameters(), lr=lr, weight_decay=weight_decay)

        # Loop-invariant: the training labels do not change between epochs
        train_mask = data.train_mask
        train_y = data.y[train_mask]

        self.train()
        print("\n" + "="*60)
        print(f"{('Training Graph Attention Network'):^60}")
        print("="*60)
        print(f"{('Epoch'):>5} {('Train Loss'):>12} {('Train Acc'):>12} {('Val Loss'):>12} {('Val Acc'):>12}")
        print("-"*60)

        for epoch in range(epochs + 1):
            optimizer.zero_grad()
            out = self(data.x, data.edge_index)
            train_out = out[train_mask]
            loss = criterion(train_out, train_y)
            acc = accuracy(train_out.argmax(dim=1), train_y)
            loss.backward()
            optimizer.step()

            if log_every > 0 and epoch % log_every == 0:
                # Validation must not reuse the training-mode output above:
                # that pass had dropout active, which distorts the metrics.
                val_loss, val_acc = self._evaluate(data, data.val_mask, criterion)
                print(f"{epoch:5d} {loss.item():12.4f} {acc.item()*100:11.2f}% {val_loss.item():12.4f} {val_acc.item()*100:11.2f}%")

        print("-"*60)

    @torch.no_grad()
    def test(self, data):
        """Evaluate on test set.

        Switches the module to evaluation mode (and leaves it there) and
        returns the accuracy tensor over the nodes selected by ``data.test_mask``.
        """
        self.eval()
        out = self(data.x, data.edge_index)
        acc = accuracy(out[data.test_mask].argmax(dim=1), data.y[data.test_mask])
        return acc

## Model Training and Evaluation

Initialize, train and evaluate the GAT model.

In [None]:
# Model hyperparameters, named once so the constructor call is self-describing
HIDDEN_DIM = 32  # hidden size per attention head in the first layer
NUM_HEADS = 8    # attention heads in the first layer

# Initialize GAT model
gat = GAT(dataset.num_features, HIDDEN_DIM, dataset.num_classes, heads=NUM_HEADS)

print("\n" + "="*40)
print(f"{('GAT ARCHITECTURE'):^40}")
print("="*40)
print(gat)
print(f"Total Parameters: {sum(p.numel() for p in gat.parameters()):,}")
# Read the head counts from the layers themselves so the summary cannot
# drift from the model that was actually built
print(f"Attention Heads (Layer 1): {gat.gat1.heads}")
print(f"Attention Heads (Layer 2): {gat.gat2.heads}")
print("="*40)

In [None]:
# Fit the model; the epoch count is named so it is easy to find and tune
NUM_EPOCHS = 100
gat.fit(data, NUM_EPOCHS)

# Score the trained model on the held-out test nodes
gat_test_acc = gat.test(data)

# Final report, framed by a 30-character rule
rule = '=' * 30
test_acc_pct = gat_test_acc.item() * 100
print(f"\n{rule}")
print(f"{('GAT FINAL RESULTS'):^30}")
print(f"{rule}")
print(f"Test Accuracy: {test_acc_pct:6.2f}%")
print(f"{rule}")