# GraphSAINT on ogbn-products (500K Subsample)
## ML4G Course Project - Scalability Model Implementation
### Team: Abhishek Indupally, Pranav Bhimrao Kapadne, Gaurav Suvarna

This notebook implements GraphSAINT with three sampling strategies:
- Random Walk Sampling (recommended)
- Node Sampling
- Edge Sampling

In [10]:
import torch
import torch.nn.functional as F
from torch_geometric.nn import SAGEConv
from torch_geometric.loader import GraphSAINTRandomWalkSampler, GraphSAINTNodeSampler, GraphSAINTEdgeSampler
from ogb.nodeproppred import PygNodePropPredDataset, Evaluator
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import json
import time
from datetime import datetime
from torch.serialization import add_safe_globals

# Set random seeds for reproducibility
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed(SEED)
    torch.cuda.manual_seed_all(SEED)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Available GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

Using device: cuda
GPU: Tesla P100-PCIE-12GB
Available GPU memory: 12.78 GB


## 1. Load and Prepare Data (Consistent with Previous Models)

In [11]:
# Allowlist required torch_geometric classes for safe unpickling
add_safe_globals([DataEdgeAttr, DataTensorAttr, GlobalStorage, NodeStorage, EdgeStorage, Data, Batch])

NameError: name 'DataEdgeAttr' is not defined

In [None]:
# Load dataset
print("Loading ogbn-products dataset...")
dataset = PygNodePropPredDataset(name="ogbn-products", root="data")
data = dataset[0]

print("\n" + "="*60)
print("ORIGINAL DATASET INFO")
print("="*60)
print(f"Total nodes: {data.x.shape[0]:,}")
print(f"Node features: {data.x.shape[1]}")
print(f"Total edges: {data.edge_index.shape[1]:,}")
print(f"Labels shape: {data.y.shape}")
print(f"All unique labels: {torch.unique(data.y).numel()}")

In [None]:
# Create 500K subsample with seed=42 (same as other models)
def create_subsample(data, num_nodes=500000, seed=42):
    """
    Create a consistent subsample of the graph.
    This ensures fair comparison across all models.
    """
    torch.manual_seed(seed)
    np.random.seed(seed)
    
    # Random sample of nodes
    perm = torch.randperm(data.num_nodes)
    sampled_nodes = perm[:num_nodes]
    
    # Create node mapping
    node_mapping = torch.full((data.num_nodes,), -1, dtype=torch.long)
    node_mapping[sampled_nodes] = torch.arange(len(sampled_nodes))
    
    # Filter edges
    edge_index = data.edge_index
    mask = (node_mapping[edge_index[0]] >= 0) & (node_mapping[edge_index[1]] >= 0)
    new_edge_index = node_mapping[edge_index[:, mask]]
    
    # Create new data object
    new_data = data.__class__()
    new_data.x = data.x[sampled_nodes]
    new_data.edge_index = new_edge_index
    new_data.y = data.y[sampled_nodes]
    
    return new_data, sampled_nodes

print("Creating 500K subsample...")
subsampled_data, sampled_nodes = create_subsample(data, num_nodes=500000, seed=SEED)

print(f"\nSubsampled dataset:")
print(f"  Nodes: {subsampled_data.num_nodes:,}")
print(f"  Edges: {subsampled_data.num_edges:,}")
print(f"  Edge density: {subsampled_data.num_edges / (subsampled_data.num_nodes ** 2) * 100:.6f}%")

In [None]:
# Filter to labels 0-15 (excluding 4) - consistent with previous models
def filter_labels(data, valid_labels):
    """
    Filter data to only include specified labels.
    This maintains consistency with MLP, GCN, and GraphSAGE models.
    """
    y = data.y.squeeze()
    mask = torch.zeros(len(y), dtype=torch.bool)
    for label in valid_labels:
        mask |= (y == label)
    
    # Filter nodes
    valid_node_idx = torch.where(mask)[0]
    node_mapping = torch.full((len(y),), -1, dtype=torch.long)
    node_mapping[valid_node_idx] = torch.arange(len(valid_node_idx))
    
    # Filter edges
    edge_mask = mask[data.edge_index[0]] & mask[data.edge_index[1]]
    new_edge_index = node_mapping[data.edge_index[:, edge_mask]]
    
    # Remap labels to 0-14
    label_mapping = {label: idx for idx, label in enumerate(valid_labels)}
    new_y = torch.tensor([label_mapping[y[i].item()] for i in valid_node_idx])
    
    # Create filtered data
    filtered_data = data.__class__()
    filtered_data.x = data.x[valid_node_idx]
    filtered_data.edge_index = new_edge_index
    filtered_data.y = new_y.unsqueeze(1)
    
    return filtered_data, valid_node_idx

# Use labels 0-15 except 4 (same as other models)
valid_labels = [0, 1, 2, 3, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]
print(f"Filtering to {len(valid_labels)} labels...")
filtered_data, valid_nodes = filter_labels(subsampled_data, valid_labels)

print(f"\nFiltered dataset:")
print(f"  Nodes: {filtered_data.num_nodes:,}")
print(f"  Edges: {filtered_data.num_edges:,}")
print(f"  Classes: {len(valid_labels)}")

In [None]:
# Create train/val/test splits using sales_ranking (same as other models)
split_idx = dataset.get_idx_split()

# Map original indices to subsampled and filtered indices
def map_split_indices(original_indices, sampled_nodes, valid_nodes):
    """Map original dataset indices to filtered subsample indices."""
    # Find which sampled nodes are in the split
    mask = torch.isin(sampled_nodes, original_indices)
    split_in_sample = sampled_nodes[mask]
    
    # Find which of these are in valid_nodes (after label filtering)
    valid_mask = torch.isin(split_in_sample, sampled_nodes[valid_nodes])
    final_split = split_in_sample[valid_mask]
    
    # Map to new indices
    node_to_idx = {node.item(): idx for idx, node in enumerate(sampled_nodes[valid_nodes])}
    mapped_indices = torch.tensor([node_to_idx[node.item()] for node in final_split])
    
    return mapped_indices

train_idx = map_split_indices(split_idx['train'], sampled_nodes, valid_nodes)
val_idx = map_split_indices(split_idx['valid'], sampled_nodes, valid_nodes)
test_idx = map_split_indices(split_idx['test'], sampled_nodes, valid_nodes)

print(f"\nData splits:")
print(f"  Train: {len(train_idx):,} nodes ({len(train_idx)/filtered_data.num_nodes*100:.1f}%)")
print(f"  Val:   {len(val_idx):,} nodes ({len(val_idx)/filtered_data.num_nodes*100:.1f}%)")
print(f"  Test:  {len(test_idx):,} nodes ({len(test_idx)/filtered_data.num_nodes*100:.1f}%)")

# Create masks
train_mask = torch.zeros(filtered_data.num_nodes, dtype=torch.bool)
val_mask = torch.zeros(filtered_data.num_nodes, dtype=torch.bool)
test_mask = torch.zeros(filtered_data.num_nodes, dtype=torch.bool)

train_mask[train_idx] = True
val_mask[val_idx] = True
test_mask[test_idx] = True

filtered_data.train_mask = train_mask
filtered_data.val_mask = val_mask
filtered_data.test_mask = test_mask

## 2. GraphSAINT Model Architecture

In [None]:
class GraphSAINT(torch.nn.Module):
    """GraphSAINT model using SAGEConv layers.
    
    GraphSAINT samples subgraphs before training, then trains on complete subgraphs.
    This is more memory-efficient than GraphSAGE's per-node neighbor sampling.
    
    Architecture matches GCN and GraphSAGE for fair comparison:
    - 2 SAGEConv layers
    - 128 hidden channels
    - Dropout for regularization
    """
    def __init__(self, in_channels, hidden_channels, out_channels, dropout=0.5):
        super().__init__()
        self.conv1 = SAGEConv(in_channels, hidden_channels)
        self.conv2 = SAGEConv(hidden_channels, out_channels)
        self.dropout = dropout
        
    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.conv2(x, edge_index)
        return x
    
    def reset_parameters(self):
        self.conv1.reset_parameters()
        self.conv2.reset_parameters()

# Model parameters (consistent with other models)
in_channels = filtered_data.x.shape[1]
hidden_channels = 128
out_channels = len(valid_labels)

print(f"Model architecture:")
print(f"  Input: {in_channels} features")
print(f"  Hidden: {hidden_channels} channels")
print(f"  Output: {out_channels} classes")

## 3. GraphSAINT Samplers

GraphSAINT offers three sampling strategies:

1. **Random Walk Sampling**: Samples subgraphs via random walks (usually best performance)
2. **Node Sampling**: Randomly samples nodes and their induced subgraph
3. **Edge Sampling**: Randomly samples edges and their incident nodes

In [None]:
def create_samplers(data, sampling_strategy='random_walk'):
    """
    Create GraphSAINT sampler based on strategy.
    
    Args:
        data: PyG Data object
        sampling_strategy: 'random_walk', 'node', or 'edge'
    
    Returns:
        train_loader, val_loader, test_loader
    """
    if sampling_strategy == 'random_walk':
        print("Using Random Walk Sampling")
        train_loader = GraphSAINTRandomWalkSampler(
            data,
            batch_size=6000,      # nodes per subgraph
            walk_length=2,        # random walk length
            num_steps=30,         # number of subgraphs per epoch
            sample_coverage=100,  # number of times each node should be sampled
            save_dir=None
        )
        
    elif sampling_strategy == 'node':
        print("Using Node Sampling")
        train_loader = GraphSAINTNodeSampler(
            data,
            batch_size=6000,
            num_steps=30,
            sample_coverage=100,
            save_dir=None
        )
        
    elif sampling_strategy == 'edge':
        print("Using Edge Sampling")
        train_loader = GraphSAINTEdgeSampler(
            data,
            batch_size=6000,
            num_steps=30,
            sample_coverage=100,
            save_dir=None
        )
    
    else:
        raise ValueError(f"Unknown sampling strategy: {sampling_strategy}")
    
    # For validation and test, we don't need sampling
    # We'll evaluate on the full graph
    return train_loader

# We'll test all three strategies
SAMPLING_STRATEGIES = ['random_walk', 'node', 'edge']

## 4. Training and Evaluation Functions

In [None]:
def train_epoch(model, loader, optimizer, device):
    """Train for one epoch using GraphSAINT sampling."""
    model.train()
    total_loss = 0
    total_correct = 0
    total_nodes = 0
    
    for batch_data in loader:
        batch_data = batch_data.to(device)
        optimizer.zero_grad()
        
        # Forward pass
        out = model(batch_data.x, batch_data.edge_index)
        
        # Only compute loss on training nodes in this batch
        train_mask = batch_data.train_mask
        loss = F.cross_entropy(out[train_mask], batch_data.y.squeeze()[train_mask])
        
        # Backward pass
        loss.backward()
        optimizer.step()
        
        # Track metrics
        total_loss += loss.item() * train_mask.sum().item()
        pred = out[train_mask].argmax(dim=1)
        total_correct += (pred == batch_data.y.squeeze()[train_mask]).sum().item()
        total_nodes += train_mask.sum().item()
    
    avg_loss = total_loss / total_nodes
    accuracy = total_correct / total_nodes
    return avg_loss, accuracy

@torch.no_grad()
def evaluate(model, data, mask, device):
    """Evaluate on full graph (no sampling)."""
    model.eval()
    data = data.to(device)
    
    out = model(data.x, data.edge_index)
    loss = F.cross_entropy(out[mask], data.y.squeeze()[mask])
    
    pred = out[mask].argmax(dim=1)
    accuracy = (pred == data.y.squeeze()[mask]).float().mean().item()
    
    return loss.item(), accuracy

@torch.no_grad()
def compute_topk_accuracy(model, data, mask, device, k_values=[1, 3, 5]):
    """Compute top-k accuracy."""
    model.eval()
    data = data.to(device)
    
    out = model(data.x, data.edge_index)
    logits = out[mask]
    targets = data.y.squeeze()[mask]
    
    topk_accs = {}
    for k in k_values:
        _, topk_pred = torch.topk(logits, k, dim=1)
        correct = (topk_pred == targets.unsqueeze(1)).any(dim=1)
        topk_accs[f'top{k}'] = correct.float().mean().item()
    
    return topk_accs

## 5. Training Loop

In [12]:
def train_graphsaint(data, sampling_strategy, num_epochs=400, lr=0.01, patience=50):
    """
    Train GraphSAINT model with specified sampling strategy.
    
    Args:
        data: PyG Data object
        sampling_strategy: 'random_walk', 'node', or 'edge'
        num_epochs: maximum number of epochs
        lr: learning rate
        patience: early stopping patience
    
    Returns:
        model, training history, results dictionary
    """
    print(f"\n{'='*70}")
    print(f"Training GraphSAINT with {sampling_strategy} sampling")
    print(f"{'='*70}\n")
    
    # Create sampler
    train_loader = create_samplers(data, sampling_strategy)
    
    # Initialize model
    model = GraphSAINT(
        in_channels=data.x.shape[1],
        hidden_channels=128,
        out_channels=len(valid_labels),
        dropout=0.5
    ).to(device)
    
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    
    # Training history
    history = {
        'train_loss': [],
        'train_acc': [],
        'val_loss': [],
        'val_acc': [],
        'test_loss': [],
        'test_acc': []
    }
    
    best_val_acc = 0
    best_epoch = 0
    patience_counter = 0
    
    start_time = time.time()
    
    for epoch in range(1, num_epochs + 1):
        # Train
        train_loss, train_acc = train_epoch(model, train_loader, optimizer, device)
        
        # Evaluate
        val_loss, val_acc = evaluate(model, data, data.val_mask, device)
        test_loss, test_acc = evaluate(model, data, data.test_mask, device)
        
        # Store history
        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)
        history['val_loss'].append(val_loss)
        history['val_acc'].append(val_acc)
        history['test_loss'].append(test_loss)
        history['test_acc'].append(test_acc)
        
        # Early stopping
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            best_epoch = epoch
            best_model_state = model.state_dict().copy()
            patience_counter = 0
        else:
            patience_counter += 1
        
        # Print progress
        if epoch % 10 == 0 or epoch == 1:
            print(f"Epoch {epoch:3d} | "
                  f"Train Loss: {train_loss:.4f}, Acc: {train_acc:.4f} | "
                  f"Val Loss: {val_loss:.4f}, Acc: {val_acc:.4f} | "
                  f"Test Acc: {test_acc:.4f}")
        
        if patience_counter >= patience:
            print(f"\nEarly stopping at epoch {epoch} (no improvement for {patience} epochs)")
            break
    
    training_time = time.time() - start_time
    
    # Load best model
    model.load_state_dict(best_model_state)
    
    # Final evaluation
    _, train_acc = evaluate(model, data, data.train_mask, device)
    _, val_acc = evaluate(model, data, data.val_mask, device)
    _, test_acc = evaluate(model, data, data.test_mask, device)
    
    # Compute top-k accuracy
    topk_results = compute_topk_accuracy(model, data, data.test_mask, device)
    
    print(f"\n{'='*70}")
    print(f"Best model at epoch {best_epoch}:")
    print(f"  Train Acc: {train_acc:.4f}")
    print(f"  Val Acc:   {val_acc:.4f}")
    print(f"  Test Acc:  {test_acc:.4f}")
    print(f"  Top-1 Acc: {topk_results['top1']:.4f}")
    print(f"  Top-3 Acc: {topk_results['top3']:.4f}")
    print(f"  Top-5 Acc: {topk_results['top5']:.4f}")
    print(f"  Training time: {training_time:.2f}s")
    print(f"{'='*70}\n")
    
    # Create results dictionary
    results = {
        'model': f'GraphSAINT_{sampling_strategy}',
        'sampling_strategy': sampling_strategy,
        'test_accuracy': test_acc,
        'val_accuracy': val_acc,
        'train_accuracy': train_acc,
        'best_epoch': best_epoch,
        'training_time': training_time,
        'num_parameters': sum(p.numel() for p in model.parameters()),
        'top1_acc': topk_results['top1'],
        'top3_acc': topk_results['top3'],
        'top5_acc': topk_results['top5'],
        'hidden_channels': 128,
        'num_classes': len(valid_labels),
        'num_nodes': data.num_nodes,
        'num_edges': data.num_edges,
        'random_seed': SEED
    }
    
    return model, history, results

## 6. Train All Three Sampling Strategies

In [13]:
# Store results for all strategies
all_results = {}
all_histories = {}
all_models = {}

# Train with each sampling strategy
for strategy in SAMPLING_STRATEGIES:
    model, history, results = train_graphsaint(
        filtered_data,
        sampling_strategy=strategy,
        num_epochs=400,
        lr=0.01,
        patience=50
    )
    
    all_results[strategy] = results
    all_histories[strategy] = history
    all_models[strategy] = model
    
    # Save individual results
    with open(f'graphsaint_{strategy}_results.json', 'w') as f:
        json.dump(results, f, indent=2)
    
    print(f"\nResults saved to graphsaint_{strategy}_results.json\n")

NameError: name 'SAMPLING_STRATEGIES' is not defined

## 7. Select Best Strategy

In [None]:
# Compare all strategies
print("\n" + "="*70)
print("COMPARISON OF SAMPLING STRATEGIES")
print("="*70)

for strategy in SAMPLING_STRATEGIES:
    results = all_results[strategy]
    print(f"\n{strategy.upper()} Sampling:")
    print(f"  Test Accuracy: {results['test_accuracy']:.4f}")
    print(f"  Val Accuracy:  {results['val_accuracy']:.4f}")
    print(f"  Top-3 Accuracy: {results['top3_acc']:.4f}")
    print(f"  Training Time: {results['training_time']:.2f}s")
    print(f"  Best Epoch:    {results['best_epoch']}")

# Select best based on validation accuracy
best_strategy = max(all_results.items(), key=lambda x: x[1]['val_accuracy'])[0]
print(f"\n{'='*70}")
print(f"BEST STRATEGY: {best_strategy.upper()} (Val Acc: {all_results[best_strategy]['val_accuracy']:.4f})")
print(f"{'='*70}\n")

# Use best strategy for final results
best_model = all_models[best_strategy]
best_history = all_histories[best_strategy]
best_results = all_results[best_strategy]

# Save best results
with open('graphsaint_best_results.json', 'w') as f:
    json.dump(best_results, f, indent=2)

## 8. Comparison with Other Models

In [None]:
# Load results from other models for comparison
try:
    with open('mlp_500k_results.json', 'r') as f:
        mlp_results = json.load(f)
    mlp_test_acc = mlp_results['test_accuracy']
except:
    mlp_test_acc = 0.6192  # Use your reported value

try:
    with open('gcn_results.json', 'r') as f:
        gcn_results = json.load(f)
    gcn_test_acc = gcn_results['test_accuracy']
except:
    gcn_test_acc = 0.7668  # Use your reported value

try:
    with open('GraphSage_results.json', 'r') as f:
        sage_results = json.load(f)
    sage_test_acc = sage_results['test_accuracy']
except:
    sage_test_acc = 0.7606  # Use your reported value

# Add comparison metrics to best results
best_results['mlp_baseline'] = mlp_test_acc
best_results['gcn_baseline'] = gcn_test_acc
best_results['graphsage_baseline'] = sage_test_acc
best_results['improvement_over_mlp'] = best_results['test_accuracy'] - mlp_test_acc
best_results['improvement_over_mlp_pct'] = (best_results['improvement_over_mlp'] / mlp_test_acc) * 100
best_results['improvement_over_gcn'] = best_results['test_accuracy'] - gcn_test_acc
best_results['improvement_over_gcn_pct'] = (best_results['improvement_over_gcn'] / gcn_test_acc) * 100
best_results['improvement_over_sage'] = best_results['test_accuracy'] - sage_test_acc
best_results['improvement_over_sage_pct'] = (best_results['improvement_over_sage'] / sage_test_acc) * 100

print("\n" + "="*70)
print("MODEL COMPARISON")
print("="*70)
print(f"\nMLP Baseline:        {mlp_test_acc:.4f}")
print(f"GCN:                 {gcn_test_acc:.4f} (+{(gcn_test_acc-mlp_test_acc)/mlp_test_acc*100:.2f}% vs MLP)")
print(f"GraphSAGE:           {sage_test_acc:.4f} (+{(sage_test_acc-mlp_test_acc)/mlp_test_acc*100:.2f}% vs MLP)")
print(f"GraphSAINT ({best_strategy}): {best_results['test_accuracy']:.4f} "
      f"(+{best_results['improvement_over_mlp_pct']:.2f}% vs MLP)")
print(f"\nGraphSAINT vs GCN:       {best_results['improvement_over_gcn']:+.4f} ({best_results['improvement_over_gcn_pct']:+.2f}%)")
print(f"GraphSAINT vs GraphSAGE: {best_results['improvement_over_sage']:+.4f} ({best_results['improvement_over_sage_pct']:+.2f}%)")
print("="*70 + "\n")

# Save final results with comparisons
with open('graphsaint_best_results.json', 'w') as f:
    json.dump(best_results, f, indent=2)

## 9. Visualizations

In [None]:
# Set plotting style
plt.style.use('seaborn-v0_8-darkgrid')
sns.set_palette("husl")

In [None]:
# Plot 1: Training curves for best strategy
fig, axes = plt.subplots(1, 2, figsize=(15, 5))

# Loss curves
axes[0].plot(best_history['train_loss'], label='Train Loss', linewidth=2)
axes[0].plot(best_history['val_loss'], label='Val Loss', linewidth=2)
axes[0].plot(best_history['test_loss'], label='Test Loss', linewidth=2)
axes[0].set_xlabel('Epoch', fontsize=12)
axes[0].set_ylabel('Loss', fontsize=12)
axes[0].set_title(f'GraphSAINT ({best_strategy}) Training Loss Curves', fontsize=14)
axes[0].legend(fontsize=10)
axes[0].grid(True, alpha=0.3)

# Accuracy curves
axes[1].plot(best_history['train_acc'], label='Train Acc', linewidth=2)
axes[1].plot(best_history['val_acc'], label='Val Acc', linewidth=2)
axes[1].plot(best_history['test_acc'], label='Test Acc', linewidth=2)
axes[1].set_xlabel('Epoch', fontsize=12)
axes[1].set_ylabel('Accuracy', fontsize=12)
axes[1].set_title(f'GraphSAINT ({best_strategy}) Training Accuracy Curves', fontsize=14)
axes[1].legend(fontsize=10)
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('graphsaint_training_curves.png', dpi=300, bbox_inches='tight')
plt.show()

print("Saved: graphsaint_training_curves.png")

In [None]:
# Plot 2: Top-k accuracy comparison
fig, ax = plt.subplots(figsize=(10, 6))

models = ['MLP', 'GCN', 'GraphSAGE', f'GraphSAINT\n({best_strategy})']
top1_accs = [mlp_results.get('top1_acc', mlp_test_acc), 
             gcn_results.get('top1_acc', gcn_test_acc),
             sage_results.get('top1_acc', sage_test_acc),
             best_results['top1_acc']]
top3_accs = [mlp_results.get('top3_acc', 0.85), 
             gcn_results.get('top3_acc', 0.93),
             sage_results.get('top3_acc', 0.92),
             best_results['top3_acc']]
top5_accs = [mlp_results.get('top5_acc', 0.92), 
             gcn_results.get('top5_acc', 0.96),
             sage_results.get('top5_acc', 0.96),
             best_results['top5_acc']]

x = np.arange(len(models))
width = 0.25

bars1 = ax.bar(x - width, top1_accs, width, label='Top-1', alpha=0.8)
bars2 = ax.bar(x, top3_accs, width, label='Top-3', alpha=0.8)
bars3 = ax.bar(x + width, top5_accs, width, label='Top-5', alpha=0.8)

ax.set_xlabel('Model', fontsize=12)
ax.set_ylabel('Accuracy', fontsize=12)
ax.set_title('Top-k Accuracy Comparison: All Models', fontsize=14, fontweight='bold')
ax.set_xticks(x)
ax.set_xticklabels(models)
ax.legend(fontsize=11)
ax.grid(True, alpha=0.3, axis='y')
ax.set_ylim([0.5, 1.0])

# Add value labels on bars
def add_value_labels(bars):
    for bar in bars:
        height = bar.get_height()
        ax.text(bar.get_x() + bar.get_width()/2., height,
                f'{height:.3f}',
                ha='center', va='bottom', fontsize=8)

add_value_labels(bars1)
add_value_labels(bars2)
add_value_labels(bars3)

plt.tight_layout()
plt.savefig('graphsaint_topk_comparison.png', dpi=300, bbox_inches='tight')
plt.show()

print("Saved: graphsaint_topk_comparison.png")

In [None]:
# Plot 3: Sampling strategy comparison
fig, axes = plt.subplots(1, 2, figsize=(15, 5))

strategies = [s.replace('_', ' ').title() for s in SAMPLING_STRATEGIES]
test_accs = [all_results[s]['test_accuracy'] for s in SAMPLING_STRATEGIES]
val_accs = [all_results[s]['val_accuracy'] for s in SAMPLING_STRATEGIES]
train_times = [all_results[s]['training_time'] for s in SAMPLING_STRATEGIES]

# Accuracy comparison
x = np.arange(len(strategies))
width = 0.35

bars1 = axes[0].bar(x - width/2, val_accs, width, label='Validation', alpha=0.8)
bars2 = axes[0].bar(x + width/2, test_accs, width, label='Test', alpha=0.8)

axes[0].set_xlabel('Sampling Strategy', fontsize=12)
axes[0].set_ylabel('Accuracy', fontsize=12)
axes[0].set_title('GraphSAINT: Sampling Strategy Accuracy Comparison', fontsize=14)
axes[0].set_xticks(x)
axes[0].set_xticklabels(strategies)
axes[0].legend(fontsize=11)
axes[0].grid(True, alpha=0.3, axis='y')
axes[0].set_ylim([0.7, 0.9])

for bars in [bars1, bars2]:
    for bar in bars:
        height = bar.get_height()
        axes[0].text(bar.get_x() + bar.get_width()/2., height,
                    f'{height:.4f}',
                    ha='center', va='bottom', fontsize=9)

# Training time comparison
bars = axes[1].bar(strategies, train_times, alpha=0.8, color='coral')
axes[1].set_xlabel('Sampling Strategy', fontsize=12)
axes[1].set_ylabel('Training Time (seconds)', fontsize=12)
axes[1].set_title('GraphSAINT: Training Time Comparison', fontsize=14)
axes[1].grid(True, alpha=0.3, axis='y')

for bar in bars:
    height = bar.get_height()
    axes[1].text(bar.get_x() + bar.get_width()/2., height,
                f'{height:.1f}s',
                ha='center', va='bottom', fontsize=10)

plt.tight_layout()
plt.savefig('graphsaint_sampling_comparison.png', dpi=300, bbox_inches='tight')
plt.show()

print("Saved: graphsaint_sampling_comparison.png")

In [None]:
# Plot 4: Confusion Matrix
@torch.no_grad()
def compute_confusion_matrix(model, data, mask, device):
    """Compute confusion matrix."""
    model.eval()
    data = data.to(device)
    
    out = model(data.x, data.edge_index)
    pred = out[mask].argmax(dim=1)
    targets = data.y.squeeze()[mask]
    
    # Compute confusion matrix
    num_classes = out.shape[1]
    conf_matrix = torch.zeros(num_classes, num_classes)
    
    for t, p in zip(targets, pred):
        conf_matrix[t, p] += 1
    
    return conf_matrix.cpu().numpy()

# Compute confusion matrix
conf_matrix = compute_confusion_matrix(best_model, filtered_data, filtered_data.test_mask, device)

# Plot
fig, ax = plt.subplots(figsize=(12, 10))
sns.heatmap(conf_matrix, annot=False, fmt='g', cmap='Blues', ax=ax, cbar_kws={'label': 'Count'})
ax.set_xlabel('Predicted Label', fontsize=12)
ax.set_ylabel('True Label', fontsize=12)
ax.set_title(f'GraphSAINT ({best_strategy}) Confusion Matrix on Test Set ({len(valid_labels)} Classes)', 
             fontsize=14, fontweight='bold')

plt.tight_layout()
plt.savefig('graphsaint_confusion_matrix.png', dpi=300, bbox_inches='tight')
plt.show()

print("Saved: graphsaint_confusion_matrix.png")

## 10. Summary and Next Steps

In [None]:
print("\n" + "="*70)
print("GRAPHSAINT IMPLEMENTATION COMPLETE")
print("="*70)

print(f"\nBest Sampling Strategy: {best_strategy.upper()}")
print(f"\nFinal Results:")
print(f"  Test Accuracy:  {best_results['test_accuracy']:.4f}")
print(f"  Top-3 Accuracy: {best_results['top3_acc']:.4f}")
print(f"  Top-5 Accuracy: {best_results['top5_acc']:.4f}")
print(f"  Training Time:  {best_results['training_time']:.2f}s")

print(f"\nComparison with Other Models:")
print(f"  MLP:       {mlp_test_acc:.4f}")
print(f"  GCN:       {gcn_test_acc:.4f}")
print(f"  GraphSAGE: {sage_test_acc:.4f}")
print(f"  GraphSAINT: {best_results['test_accuracy']:.4f} ‚≠ê")

print(f"\nFiles Generated:")
print(f"  - graphsaint_best_results.json")
for strategy in SAMPLING_STRATEGIES:
    print(f"  - graphsaint_{strategy}_results.json")
print(f"  - graphsaint_training_curves.png")
print(f"  - graphsaint_topk_comparison.png")
print(f"  - graphsaint_sampling_comparison.png")
print(f"  - graphsaint_confusion_matrix.png")

print(f"\n{'='*70}")
print("NEXT STEPS:")
print("="*70)
print("\n1. Edge Quality Sensitivity Analysis")
print("   - Test robustness to edge removal (sparsification)")
print("   - Test robustness to edge noise (random edge addition)")
print("   - Compare: GCN vs GraphSAGE vs GraphSAINT")

print("\n2. Cold-Start Resilience Analysis")
print("   - Evaluate performance by node degree")
print("   - Test classification with edge removal (new products)")
print("   - Identify threshold where GNNs become beneficial")

print("\n3. Final Report/Blog Post")
print("   - Synthesize all results")
print("   - Create comprehensive visualizations")
print("   - Write up insights and conclusions")
print("\n" + "="*70 + "\n")