# GraphSage Implementation

This notebook implements and evaluates GraphSage (Graph Sample and Aggregate) for node classification tasks using the Pubmed citation network dataset. GraphSage enables inductive learning on large graphs through neighbor sampling and aggregation.

## Setup and Data Loading

Import necessary libraries and load the Pubmed citation network dataset for node classification.

In [None]:
# Core libraries
import torch
import torch.nn.functional as F
from torch.nn import Linear, Dropout

# PyTorch Geometric for graph neural networks
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import SAGEConv
from torch_geometric.loader import NeighborLoader
from torch_geometric.utils import to_networkx

# Visualization and utilities
import matplotlib.pyplot as plt
import numpy as np
import networkx as nx
from matplotlib.patches import Rectangle

# Set random seed for reproducibility
torch.manual_seed(42)
np.random.seed(42)

# Set matplotlib style
plt.style.use('default')
plt.rcParams['figure.figsize'] = (12, 8)
plt.rcParams['font.size'] = 10

In [None]:
# Load Pubmed citation network dataset
dataset = Planetoid(root='.', name='Pubmed')
data = dataset[0]

In [None]:
# Display dataset statistics
print("\n" + "="*50)
print(f"{'PUBMED DATASET STATISTICS':^50}")
print("="*50)
print(f"Dataset Name:      {dataset.name}")
print(f"Number of Graphs:  {len(dataset):,}")
print(f"Number of Nodes:   {data.x.shape[0]:,}")
print(f"Number of Edges:   {data.edge_index.shape[1]:,}")
print(f"Node Features:     {dataset.num_features}")
print(f"Number of Classes: {dataset.num_classes}")
print(f"Train Nodes:       {data.train_mask.sum().item():,}")
print(f"Validation Nodes:  {data.val_mask.sum().item():,}")
print(f"Test Nodes:        {data.test_mask.sum().item():,}")
print("="*50)

## Neighbor Sampling

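GraphSage uses neighbor sampling to create mini-batches for scalable training on large graphs. Each target node draws a fixed-size sample of its neighborhood, so the cost per batch stays bounded. Because the model learns aggregation functions rather than per-node embeddings, it can also be applied inductively to unseen nodes.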
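To make the sampling step concrete, here is a minimal pure-Python sketch of fixed-size, multi-hop neighbor sampling on a toy adjacency list. It is an illustration only and does not reflect how `NeighborLoader` is implemented internally; the names `sample_neighborhood` and `toy_graph` are hypothetical.

```python
import random

def sample_neighborhood(adj, seeds, fanouts, rng):
    """Sample a multi-hop neighborhood around `seeds`.

    adj:     dict mapping node -> list of neighbor nodes
    seeds:   target nodes of the mini-batch
    fanouts: max neighbors to sample per node at each hop, e.g. [5, 10]
    Returns the sampled node list (seeds first) and the sampled edges.
    """
    nodes = list(seeds)          # seeds come first, as in a mini-batch
    seen = set(seeds)
    edges = []
    frontier = list(seeds)
    for fanout in fanouts:
        next_frontier = []
        for v in frontier:
            neighbors = adj.get(v, [])
            # Sample without replacement, capped at the node's degree
            picked = rng.sample(neighbors, min(fanout, len(neighbors)))
            for u in picked:
                edges.append((u, v))   # message flows from neighbor u to v
                if u not in seen:
                    seen.add(u)
                    nodes.append(u)
                    next_frontier.append(u)
        frontier = next_frontier
    return nodes, edges

# Toy graph: a hub (node 0) connected to 8 leaves
toy_graph = {0: list(range(1, 9))}
for leaf in range(1, 9):
    toy_graph[leaf] = [0]

nodes, edges = sample_neighborhood(toy_graph, seeds=[0], fanouts=[5, 10],
                                   rng=random.Random(42))
print(len(nodes), len(edges))  # 6 10
```

The hub has 8 neighbors but only 5 are kept at hop 1. At hop 2 each sampled leaf has a single neighbor (the hub), so the fan-out of 10 acts as a cap rather than a target.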

In [None]:
# Create neighbor sampling loaders
train_loader = NeighborLoader(
    data,
    num_neighbors=[5, 10],  # Sample up to 5 neighbors at hop 1, up to 10 at hop 2
    batch_size=16,
    input_nodes=data.train_mask,
    shuffle=True
)

# Validation and test loaders (same fan-out and batch size, no shuffling)
val_loader = NeighborLoader(
    data,
    num_neighbors=[5, 10],
    batch_size=16,
    input_nodes=data.val_mask,
    shuffle=False
)

test_loader = NeighborLoader(
    data,
    num_neighbors=[5, 10],
    batch_size=16,
    input_nodes=data.test_mask,
    shuffle=False
)

print(f"Training batches: {len(train_loader)}")
print(f"Validation batches: {len(val_loader)}")
print(f"Test batches: {len(test_loader)}")

In [None]:
# Visualize neighbor sampling structure
print("\n" + "="*60)
print(f"{'NEIGHBOR SAMPLING VISUALIZATION':^60}")
print("="*60)

# Print first few subgraphs to understand sampling
for i, subgraph in enumerate(train_loader):
    if i >= 4:  # Show only first 4 batches
        break
    print(f"Subgraph {i}: {subgraph}")
    
print("\n" + "-"*60)
print("Each subgraph contains:")
print("- x: Node features for sampled neighborhood")
print("- edge_index: Edges within the sampled subgraph")
print("- y: Labels for all nodes in subgraph")
print("- batch_size: Number of target nodes for this batch")
print("- input_id: Original node IDs of target nodes")
print("="*60)

## Utility Functions

Helper functions for model training and evaluation with mini-batch processing.
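One detail worth spelling out: in each sampled batch only the first `batch_size` rows correspond to target nodes, and the last batch of an epoch may be smaller than the others. The sketch below uses plain Python lists with made-up predictions (hypothetical data, not model output) to show why summing correct counts across batches gives the right overall accuracy, whereas averaging per-batch accuracies would not.

```python
def batch_counts(preds, labels, batch_size):
    """Count correct predictions among the first `batch_size` (target) rows."""
    target_preds = preds[:batch_size]
    target_labels = labels[:batch_size]
    correct = sum(p == y for p, y in zip(target_preds, target_labels))
    return correct, batch_size

# Two hypothetical batches: (predictions, labels, number of target nodes).
# Rows after the first `batch_size` belong to sampled neighbors and are ignored.
batches = [
    ([0, 1, 2, 1, 0, 2], [0, 1, 2, 1, 2, 2], 4),  # 4 targets, all correct
    ([1, 0, 2],          [2, 0, 1],          1),  # 1 target, incorrect
]

total_correct = total_samples = 0
per_batch_acc = []
for preds, labels, batch_size in batches:
    correct, n = batch_counts(preds, labels, batch_size)
    total_correct += correct
    total_samples += n
    per_batch_acc.append(correct / n)

overall = total_correct / total_samples                    # 4 / 5 = 0.8
mean_of_batches = sum(per_batch_acc) / len(per_batch_acc)  # (1.0 + 0.0) / 2 = 0.5
print(overall, mean_of_batches)
```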

In [None]:
def accuracy(y_pred, y_true):
    """Calculate classification accuracy"""
    return torch.sum(y_pred == y_true) / len(y_true)

@torch.no_grad()
def evaluate_model(model, loader, device):
    """Evaluate model on given data loader"""
    model.eval()
    total_correct = 0
    total_samples = 0
    
    for batch in loader:
        batch = batch.to(device)
        out = model(batch.x, batch.edge_index)
        pred = out[:batch.batch_size].argmax(dim=1)
        y = batch.y[:batch.batch_size]
        
        total_correct += (pred == y).sum().item()
        total_samples += batch.batch_size
    
    return total_correct / total_samples

def plot_training_curves(train_losses, train_accs, val_accs):
    """Plot training curves for loss and accuracy"""
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))
    
    # Plot training loss
    ax1.plot(train_losses, 'b-', label='Training Loss', linewidth=2)
    ax1.set_xlabel('Epoch')
    ax1.set_ylabel('Loss')
    ax1.set_title('Training Loss')
    ax1.grid(True, alpha=0.3)
    ax1.legend()
    
    # Plot accuracies
    ax2.plot(train_accs, 'b-', label='Training Accuracy', linewidth=2)
    ax2.plot(val_accs, 'r-', label='Validation Accuracy', linewidth=2)
    ax2.set_xlabel('Epoch')
    ax2.set_ylabel('Accuracy')
    ax2.set_title('Training vs Validation Accuracy')
    ax2.grid(True, alpha=0.3)
    ax2.legend()
    
    plt.tight_layout()
    plt.show()

def plot_subgraph_structure(subgraph, title="Sampled Subgraph Structure"):
    """Visualize the structure of a sampled subgraph"""
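    # Convert to a NetworkX graph, registering every sampled node so that
    # nodes without sampled edges still receive a layout position
    edge_index = subgraph.edge_index.cpu().numpy()
    G = nx.Graph()
    G.add_nodes_from(range(subgraph.x.shape[0]))
    
    # Add edges (cast to plain ints so node keys match the ranges used below)
    edges = [(int(u), int(v)) for u, v in zip(edge_index[0], edge_index[1])]
    G.add_edges_from(edges)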
    
    # Create layout
    pos = nx.spring_layout(G, k=1, iterations=50)
    
    plt.figure(figsize=(10, 8))
    
    # Color nodes based on whether they are target nodes or neighbors
    target_nodes = list(range(subgraph.batch_size))
    neighbor_nodes = list(range(subgraph.batch_size, subgraph.x.shape[0]))
    
    # Draw neighbor nodes (smaller, light blue)
    if neighbor_nodes:
        nx.draw_networkx_nodes(G, pos, nodelist=neighbor_nodes, 
                             node_color='lightblue', node_size=100, alpha=0.7)
    
    # Draw target nodes (larger, red)
    if target_nodes:
        nx.draw_networkx_nodes(G, pos, nodelist=target_nodes, 
                             node_color='red', node_size=300, alpha=0.8)
    
    # Draw edges
    nx.draw_networkx_edges(G, pos, alpha=0.5, width=1)
    
    # Add labels for target nodes only
    target_labels = {i: f'T{i}' for i in target_nodes}
    nx.draw_networkx_labels(G, pos, labels=target_labels, font_size=8, font_weight='bold')
    
    plt.title(f"{title}\nRed: Target Nodes ({len(target_nodes)}), Blue: Sampled Neighbors ({len(neighbor_nodes)})")
    plt.axis('off')
    plt.tight_layout()
    plt.show()

def plot_sampling_statistics(train_loader, num_batches=10):
    """Plot statistics about the sampling process"""
    batch_sizes = []
    subgraph_sizes = []
    edge_counts = []
    
    for i, batch in enumerate(train_loader):
        if i >= num_batches:
            break
        batch_sizes.append(batch.batch_size)
        subgraph_sizes.append(batch.x.shape[0])
        edge_counts.append(batch.edge_index.shape[1])
    
    fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(12, 8))
    
    # Batch sizes
    ax1.bar(range(len(batch_sizes)), batch_sizes, color='skyblue', alpha=0.7)
    ax1.set_title('Target Nodes per Batch')
    ax1.set_xlabel('Batch Index')
    ax1.set_ylabel('Number of Target Nodes')
    ax1.grid(True, alpha=0.3)
    
    # Subgraph sizes
    ax2.bar(range(len(subgraph_sizes)), subgraph_sizes, color='lightgreen', alpha=0.7)
    ax2.set_title('Total Nodes per Subgraph')
    ax2.set_xlabel('Batch Index')
    ax2.set_ylabel('Total Nodes (Target + Neighbors)')
    ax2.grid(True, alpha=0.3)
    
    # Edge counts
    ax3.bar(range(len(edge_counts)), edge_counts, color='salmon', alpha=0.7)
    ax3.set_title('Edges per Subgraph')
    ax3.set_xlabel('Batch Index')
    ax3.set_ylabel('Number of Edges')
    ax3.grid(True, alpha=0.3)
    
    # Sampling ratio
    sampling_ratios = [sg/bs for sg, bs in zip(subgraph_sizes, batch_sizes)]
    ax4.bar(range(len(sampling_ratios)), sampling_ratios, color='gold', alpha=0.7)
    ax4.set_title('Sampling Expansion Ratio')
    ax4.set_xlabel('Batch Index')
    ax4.set_ylabel('Total Nodes / Target Nodes')
    ax4.grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.show()
    
    # Print statistics
    print(f"\n{'='*50}")
    print(f"{'SAMPLING STATISTICS':^50}")
    print(f"{'='*50}")
    print(f"Average target nodes per batch: {np.mean(batch_sizes):.1f}")
    print(f"Average total nodes per subgraph: {np.mean(subgraph_sizes):.1f}")
    print(f"Average edges per subgraph: {np.mean(edge_counts):.1f}")
    print(f"Average sampling expansion: {np.mean(sampling_ratios):.1f}x")
    print(f"{'='*50}")

## GraphSage Implementation

Implementation of GraphSage using PyTorch Geometric's SAGEConv layers with neighbor sampling and aggregation.
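For reference, with mean aggregation each `SAGEConv` layer computes the update below, following the formula given in the PyTorch Geometric documentation. Here `x_i` is the feature vector of node `i`, `N(i)` is its set of (sampled) neighbors, and `W_1`, `W_2` are learnable weight matrices.

```latex
\mathbf{x}_i' = \mathbf{W}_1 \mathbf{x}_i + \mathbf{W}_2 \cdot \operatorname{mean}_{j \in \mathcal{N}(i)} \mathbf{x}_j
```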

In [None]:
class GraphSage(torch.nn.Module):
    """GraphSage model for inductive node classification"""
    
    def __init__(self, dim_in, dim_h, dim_out, num_layers=2, dropout=0.5):
        super().__init__()
        self.num_layers = num_layers
        self.convs = torch.nn.ModuleList()
        
        # First layer
        self.convs.append(SAGEConv(dim_in, dim_h))
        
        # Hidden layers
        for _ in range(num_layers - 2):
            self.convs.append(SAGEConv(dim_h, dim_h))
        
        # Output layer
        self.convs.append(SAGEConv(dim_h, dim_out))
        
        self.dropout = Dropout(dropout)

    def forward(self, x, edge_index):
        # Apply GraphSage layers
        for i, conv in enumerate(self.convs[:-1]):
            x = conv(x, edge_index)
            x = F.relu(x)
            x = self.dropout(x)
        
        # Final layer (no activation)
        x = self.convs[-1](x, edge_index)
        return F.log_softmax(x, dim=1)
    
    def fit(self, train_loader, val_loader, epochs, device='cpu', lr=0.01):
        """Train the GraphSage model using mini-batch training"""
        # forward() already returns log-probabilities, so pair it with NLLLoss
        criterion = torch.nn.NLLLoss()
        optimizer = torch.optim.Adam(self.parameters(), lr=lr, weight_decay=5e-4)
        
        self.to(device)
        
        # Training history
        train_losses = []
        train_accs = []
        val_accs = []
        
        print("\n" + "="*60)
        print(f"{'Training GraphSage (Mini-batch)':^60}")
        print("="*60)
        print(f"{'Epoch':>5} {'Train Loss':>12} {'Train Acc':>12} {'Val Acc':>12}")
        print("-"*60)
        
        for epoch in range(epochs + 1):
            self.train()
            total_loss = 0
            total_correct = 0
            total_samples = 0
            
            for batch in train_loader:
                batch = batch.to(device)
                optimizer.zero_grad()
                
                out = self(batch.x, batch.edge_index)
                loss = criterion(out[:batch.batch_size], batch.y[:batch.batch_size])
                
                loss.backward()
                optimizer.step()
                
                total_loss += loss.item()
                pred = out[:batch.batch_size].argmax(dim=1)
                total_correct += (pred == batch.y[:batch.batch_size]).sum().item()
                total_samples += batch.batch_size
            
            # Calculate metrics
            avg_loss = total_loss / len(train_loader)
            train_acc = total_correct / total_samples
            val_acc = evaluate_model(self, val_loader, device)
            
            # Store history
            train_losses.append(avg_loss)
            train_accs.append(train_acc)
            val_accs.append(val_acc)
            
            if epoch % 20 == 0:
                print(f"{epoch:5d} {avg_loss:12.4f} {train_acc*100:11.2f}% {val_acc*100:11.2f}%")
        
        print("-"*60)
        
        return train_losses, train_accs, val_accs
        
    def test(self, test_loader, device='cpu'):
        """Evaluate on test set"""
        return evaluate_model(self, test_loader, device)

## Model Training and Evaluation

Initialize, train and evaluate the GraphSage model using mini-batch training.

In [None]:
# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

In [None]:
# Initialize GraphSage model
model = GraphSage(dataset.num_features, 256, dataset.num_classes, num_layers=2)
print("\n" + "="*40)
print(f"{'GRAPHSAGE ARCHITECTURE':^40}")
print("="*40)
print(model)
print(f"Total Parameters: {sum(p.numel() for p in model.parameters()):,}")
print(f"Hidden Dimensions: {model.convs[0].out_channels}")
print(f"Number of Layers: {model.num_layers}")
print("Neighbor Sampling: [5, 10]")
print("="*40)

In [None]:
# Train GraphSage model
train_losses, train_accs, val_accs = model.fit(
    train_loader, val_loader, epochs=100, device=device, lr=0.01
)

# Test the model
test_acc = model.test(test_loader, device=device)

print(f"\n{'='*30}")
print(f"{'GRAPHSAGE FINAL RESULTS':^30}")
print(f"{'='*30}")
print(f"Test Accuracy: {test_acc*100:6.2f}%")
print(f"{'='*30}")

In [None]:
# Plot training curves
plot_training_curves(train_losses, train_accs, val_accs)

## Subgraph Structure Analysis

Analyze and visualize the sampled subgraphs to understand GraphSage's neighbor sampling mechanism in detail.

In [None]:
# Plot structure of first few subgraphs
print("\n" + "="*50)
print(f"{'SUBGRAPH STRUCTURE PLOTS':^50}")
print("="*50)

# Create a separate loader for visualization, independent of the training loader
viz_loader = NeighborLoader(
    data,
    num_neighbors=[5, 10],
    batch_size=16,
    input_nodes=data.train_mask,
    shuffle=True
)

subgraph_list = []
for i, subgraph in enumerate(viz_loader):
    if i >= 3:  # Plot first 3 subgraphs
        break
    subgraph_list.append(subgraph)
    plot_subgraph_structure(subgraph, f"Subgraph {i+1}")
    
    print(f"Subgraph {i+1} Details:")
    print(f"  - Target nodes: {subgraph.batch_size}")
    print(f"  - Total nodes: {subgraph.x.shape[0]}")
    print(f"  - Edges: {subgraph.edge_index.shape[1]}")
    print(f"  - Expansion ratio: {subgraph.x.shape[0]/subgraph.batch_size:.1f}x\n")

In [None]:
# Plot sampling statistics across multiple batches
stats_loader = NeighborLoader(
    data,
    num_neighbors=[5, 10],
    batch_size=16,
    input_nodes=data.train_mask,
    shuffle=True
)

plot_sampling_statistics(stats_loader, num_batches=20)

### Visualization Insights

**Subgraph Structure:**
- **Red nodes**: Target nodes for which we want to compute embeddings
- **Blue nodes**: Sampled neighbors used for aggregation
- **Edges**: Connections within the sampled neighborhood
- **Expansion ratio**: How much the neighborhood grows from target nodes

**Sampling Benefits:**
- **Fixed complexity**: Each target node has bounded neighborhood size
- **Scalability**: Per-batch memory depends on the sampled neighborhood, not on the full graph size
- **Parallelization**: Multiple subgraphs can be processed simultaneously
- **Inductive capability**: New nodes can be processed without retraining

**Key Observations:**
- Subgraph sizes vary based on local graph density
- Sampling creates diverse neighborhood structures
- Edge connectivity patterns reflect original graph topology
- Expansion ratios show efficiency of neighbor sampling strategy

## Technical Analysis

Understanding the GraphSage implementation and its advantages for large-scale graph learning.

### GraphSage (Graph Sample and Aggregate)

**Core Concepts:**
- **Inductive Learning**: Can generalize to unseen nodes and graphs during inference
- **Neighbor Sampling**: Samples fixed-size neighborhoods to enable mini-batch training
- **Aggregation Functions**: Combines neighbor information (mean, max, LSTM, etc.)
- **Scalability**: Handles large graphs through sampling and batching

**Key Advantages:**
- **Memory Efficient**: Fixed computational complexity per node regardless of graph size
- **Inductive**: Can handle dynamic graphs and new nodes without retraining
- **Parallelizable**: Mini-batch training enables GPU acceleration
- **Flexible**: Works with various aggregation functions and sampling strategies

**Architecture Details:**
- **Two SAGE Layers**: Input → Hidden (256) → Output (3 classes)
- **Neighbor Sampling**: [5, 10] neighbors per layer
- **Aggregation**: Mean aggregation (default in SAGEConv)
- **Regularization**: 50% dropout between layers
- **Optimization**: Adam optimizer with weight decay

**Sampling Strategy:**
- **Hop 1**: Sample up to 5 neighbors for each target node
- **Hop 2**: Sample up to 10 neighbors for each node reached at hop 1
- **Mini-batches**: Process 16 nodes at a time during training
- **Computational Graph**: Each node's representation depends on sampled neighborhood
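The sampling strategy above implies a simple upper bound on the size of each sampled subgraph. The following sketch computes it; `max_sampled_nodes` is a hypothetical helper, and real subgraphs are usually smaller because many nodes have fewer neighbors than the fan-out and sampled neighborhoods overlap.

```python
def max_sampled_nodes(batch_size, fanouts):
    """Upper bound on nodes in one sampled subgraph (targets plus all hops)."""
    total = batch_size
    frontier = batch_size
    for fanout in fanouts:
        frontier *= fanout   # each frontier node adds at most `fanout` new nodes
        total += frontier
    return total

print(max_sampled_nodes(16, [5, 10]))  # 16 * (1 + 5 + 50) = 896
```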

**Expected Performance:**
GraphSage typically reaches roughly 75-80% accuracy on Pubmed while scaling better than full-batch training of GCN or GAT. Its inductive capability makes it well suited to large, dynamic graphs.

### GraphSage vs Other GNNs

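**Full-Batch Methods (GCN, GAT in their usual transductive setup):**
- Require the entire graph during training
- Typically need retraining, or at least a new pass over the updated graph, when new nodes arrive
- Memory usage scales with graph size
- Often slightly better accuracy on small, fixed graphs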

**GraphSage (Inductive):**
- Uses neighbor sampling for scalability
- Can generalize to unseen nodes/graphs
- Per-batch memory is bounded by the sampled neighborhood, not by the graph size
- Often slightly lower accuracy, but much more scalable

**When to Use GraphSage:**
- Large graphs (>100K nodes)
- Dynamic graphs with new nodes
- Limited computational resources
- Industrial applications requiring scalability
- Multi-graph learning scenarios

**Aggregation Functions:**
- **Mean**: Simple average of neighbor features (default)
- **Max**: Element-wise maximum of neighbor features
- **LSTM**: Sequential processing of neighbors
- **Pool**: Max/mean pooling after linear transformation
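As a small illustration of the two simplest aggregators listed above, the sketch below applies element-wise mean and max to a hypothetical set of neighbor feature vectors using plain Python lists. In PyTorch Geometric the aggregator is selected through the `aggr` argument of `SAGEConv`; the helper names here are made up for the example.

```python
def mean_aggregate(neighbor_feats):
    """Element-wise mean over a list of equal-length feature vectors."""
    n = len(neighbor_feats)
    return [sum(col) / n for col in zip(*neighbor_feats)]

def max_aggregate(neighbor_feats):
    """Element-wise maximum over a list of equal-length feature vectors."""
    return [max(col) for col in zip(*neighbor_feats)]

# Two hypothetical neighbors with 3 features each
neighbors = [
    [1.0, 0.0, 4.0],
    [3.0, 2.0, 0.0],
]
print(mean_aggregate(neighbors))  # [2.0, 1.0, 2.0]
print(max_aggregate(neighbors))   # [3.0, 2.0, 4.0]
```

Both aggregators are invariant to the order of the neighbors, which is why they suit unordered neighborhoods. An LSTM aggregator is order-sensitive and is usually applied to a random permutation of the neighbors.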

### Performance Analysis

**Dataset Characteristics:**
- **Pubmed**: 19,717 nodes, 44,338 citation links (stored by PyTorch Geometric as 88,648 directed edges in `edge_index`), 500 features, 3 classes
- **Domain**: Citation network of biomedical papers
- **Task**: Multi-class node classification
- **Challenge**: High-dimensional features with sparse connectivity

**Model Complexity:**
- **Parameters**: ~258K trainable parameters (257,795 for this configuration)
- **Memory**: O(batch_size × product of per-hop fan-outs) sampled nodes per forward pass
- **Computation**: Linear in sampled neighborhood size
- **Scalability**: Can handle graphs with millions of nodes
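The parameter count can be checked by hand. The sketch below assumes the default `SAGEConv` layout: one linear map applied to the aggregated neighbors (with bias) and one applied to the node's own features (without bias). `sage_conv_params` is a hypothetical helper, not part of PyTorch Geometric.

```python
def sage_conv_params(dim_in, dim_out):
    """Parameters of one mean-aggregation SAGE layer.

    Assumes two linear maps of shape (dim_in, dim_out): one applied to the
    aggregated neighbors (with bias) and one to the node itself (no bias).
    """
    neighbor_weights = dim_in * dim_out + dim_out  # weights + bias
    root_weights = dim_in * dim_out                # weights only
    return neighbor_weights + root_weights

layer1 = sage_conv_params(500, 256)  # 256,256
layer2 = sage_conv_params(256, 3)    # 1,539
print(layer1 + layer2)               # 257,795
```

Almost all parameters sit in the first layer, since it maps the 500 input features to 256 hidden units.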

**Training Efficiency:**
- **Mini-batch Size**: 16 nodes per batch
- **Sampling**: [5, 10] neighbors per layer
- **Epochs**: 100 epochs for convergence
- **Time**: ~1-2 minutes on CPU, ~30 seconds on GPU
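With 16 target nodes per batch, the number of batches per pass follows directly from the split sizes. The sketch below assumes the standard Planetoid split for Pubmed (60 training, 500 validation, and 1,000 test nodes); compare it with the counts printed by the loaders if your split differs.

```python
import math

def num_batches(num_nodes, batch_size):
    """Number of mini-batches needed to cover `num_nodes` target nodes."""
    return math.ceil(num_nodes / batch_size)

# Split sizes assumed from the standard Planetoid split of Pubmed
splits = {"train": 60, "val": 500, "test": 1000}
counts = {name: num_batches(n, 16) for name, n in splits.items()}
print(counts)  # {'train': 4, 'val': 32, 'test': 63}
```

Under this assumption each training epoch takes only 4 optimizer steps, while validation after every epoch runs 32 forward passes, so evaluation accounts for a large share of the per-epoch time.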

**Key Insights:**
- GraphSage enables scalable training on large graphs
- Neighbor sampling provides computational efficiency
- Inductive learning allows generalization to new nodes
- Performance scales well with available computational resources