# Comprehensive Cold-Start Analysis: All Models
## ML4G Course Project - Applications Research with GNNs

**Team:** Abhishek Indupally, Pranav Bhimrao Kapadne, Gaurav Suvarna

**Goal:** Compare MLP vs GCN vs GraphSAGE vs GraphSAINT on nodes stratified by connectivity

**Key Questions:**
1. Do GNNs outperform MLP when nodes have 0 edges (true cold-start)?
2. At what connectivity threshold do GNNs become worthwhile?
3. Which GNN architecture is most resilient to sparse edges?

**Degree Stratification:**
- Degree 0: Isolated nodes (true cold-start)
- Degree 1-5: Sparse connections
- Degree 6-20: Moderate connections
- Degree 21+ (more than 20): Well-connected nodes

## 1. Imports and Setup

In [1]:
from ogb.nodeproppred import PygNodePropPredDataset
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.serialization import add_safe_globals
from torch_geometric.data import Data, Batch
from torch_geometric.data.data import DataEdgeAttr, DataTensorAttr
from torch_geometric.data.storage import GlobalStorage, NodeStorage, EdgeStorage
from torch_geometric.nn import SAGEConv, GCNConv
from torch_geometric.utils import degree, subgraph
from sklearn.metrics import accuracy_score, classification_report
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import json
import os
from collections import defaultdict

print("All imports successful!")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

# Register these torch_geometric classes with torch's safe-unpickling
# allowlist; presumably the processed dataset file references them
# (NOTE(review): confirm which of them are actually required).
pyg_safe_classes = [DataEdgeAttr, DataTensorAttr, GlobalStorage, NodeStorage, EdgeStorage, Data, Batch]
add_safe_globals(pyg_safe_classes)

# Seed NumPy and PyTorch identically for reproducibility
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# Create output directories (no error if they already exist)
for output_dir in ('images/cold_start', 'results/cold_start'):
    os.makedirs(output_dir, exist_ok=True)
print("Created directories: images/cold_start/, results/cold_start/")

All imports successful!
PyTorch version: 2.7.1+cu118
CUDA available: True
Created directories: images/cold_start/, results/cold_start/


## 2. Load All Baseline Results

In [2]:
print("\n" + "="*60)
print("LOADING ALL BASELINE RESULTS")
print("="*60)

# Store all available results, keyed by the display name used by every later
# cell (these names must match the entries in `model_configs`).
baseline_results = {}

# (display name, results file) for every model we try to compare.
baseline_files = [
    ('MLP', 'mlp_500k_results.json'),
    ('GCN', 'gcn_results.json'),
    ('GraphSAGE', 'GraphSage_results.json'),
    ('GraphSAINT-RW', 'graphsaint_random_walk_results.json'),
    ('GraphSAINT-Node', 'graphsaint_node_results.json'),
]

for model_name, results_path in baseline_files:
    try:
        with open(results_path, 'r') as f:
            results = json.load(f)
    except FileNotFoundError:
        print(f"WARNING: {results_path} not found")
        continue
    except (OSError, json.JSONDecodeError) as e:
        # Distinguish unreadable/corrupt files from missing ones
        print(f"WARNING: could not read {results_path}: {e}")
        continue

    # Later cells read results['test_accuracy'] for every registered model,
    # so only register files that actually provide it.
    if not isinstance(results, dict) or 'test_accuracy' not in results:
        print(f"WARNING: {results_path} has no 'test_accuracy' entry; skipping {model_name}")
        continue

    baseline_results[model_name] = results
    print(f"{model_name} baseline: {results['test_accuracy']:.4f}")

if len(baseline_results) < 2:
    print("\nERROR: Need at least 2 models for meaningful comparison")
    print("Available models:", list(baseline_results.keys()))
else:
    print(f"\nLoaded {len(baseline_results)} model results for comparison")
    print("Available models:", list(baseline_results.keys()))


LOADING ALL BASELINE RESULTS
MLP baseline: 0.6192
GCN baseline: 0.7668
GraphSAGE baseline: 0.7609
GraphSAINT-RW baseline: 0.7715
GraphSAINT-Node baseline: 0.7641

Loaded 5 model results for comparison
Available models: ['MLP', 'GCN', 'GraphSAGE', 'GraphSAINT-RW', 'GraphSAINT-Node']


## 3. Reproduce Dataset Preparation (Same 500K Subsample)

In [4]:
print("\n" + "="*60)
print("LOADING AND PREPARING DATASET")
print("="*60)

# Load dataset
dataset = PygNodePropPredDataset(name="ogbn-products", root="data")
data = dataset[0]

print(f"Original dataset: {data.x.shape[0]:,} nodes, {data.edge_index.shape[1]:,} edges")

# Subsample to 500K nodes (same as training notebooks).
# The permutation is drawn from a dedicated generator seeded with 42, so the
# sample no longer depends on how much of the global RNG earlier cells (or
# re-runs of this cell) consumed. A fresh CPU generator seeded with 42 matches
# the first torch.randperm call made right after torch.manual_seed(42).
# NOTE(review): this reproduces the training notebooks only if they drew the
# permutation immediately after seeding - confirm against those notebooks.
subsample_size = 500000
subsample_seed = 42
total_nodes = data.x.shape[0]
subsample_generator = torch.Generator().manual_seed(subsample_seed)
sampled_indices = torch.randperm(total_nodes, generator=subsample_generator)[:subsample_size]
sampled_indices = sampled_indices.sort()[0]

# Create mapping from old indices to new indices (position in the sorted sample)
subsample_mapping = dict(zip(sampled_indices.tolist(), range(len(sampled_indices))))

# Extract subgraph; kept nodes are relabelled to 0..subsample_size-1
subsampled_edge_index, _ = subgraph(
    subset=sampled_indices,
    edge_index=data.edge_index,
    relabel_nodes=True,
    num_nodes=total_nodes
)

# Update data object
data.x = data.x[sampled_indices]
data.y = data.y[sampled_indices]
data.edge_index = subsampled_edge_index

print(f"After subsampling: {data.x.shape[0]:,} nodes, {subsampled_edge_index.shape[1]:,} edges")


LOADING AND PREPARING DATASET
Original dataset: 2,449,029 nodes, 123,718,280 edges
After subsampling: 500,000 nodes, 5,190,420 edges


In [None]:
# Load the OGB splits and restrict them to the subsampled nodes (same as training).
# The OGB split files store one node index per line with NO header row (OGB's
# own get_idx_split reads them with header=None). With pandas' default
# header=0, the first index of every split would be consumed as a column name.
split_dir = "data/ogbn_products/split/sales_ranking/"
train_df = pd.read_csv(split_dir + "train.csv.gz", header=None)
valid_df = pd.read_csv(split_dir + "valid.csv.gz", header=None)
test_df = pd.read_csv(split_dir + "test.csv.gz", header=None)

original_train = torch.tensor(train_df.iloc[:, 0].values, dtype=torch.long)
original_valid = torch.tensor(valid_df.iloc[:, 0].values, dtype=torch.long)
original_test = torch.tensor(test_df.iloc[:, 0].values, dtype=torch.long)

# Filter splits to subsample
train_in_sample = torch.isin(original_train, sampled_indices)
valid_in_sample = torch.isin(original_valid, sampled_indices)
test_in_sample = torch.isin(original_test, sampled_indices)

filtered_train_original = original_train[train_in_sample]
filtered_valid_original = original_valid[valid_in_sample]
filtered_test_original = original_test[test_in_sample]


def _to_subsample_ids(original_ids):
    """Map original-graph node ids (all present in the subsample) to subsample ids.

    The dtype is pinned to long so an empty split still yields a usable
    index tensor (torch.tensor([]) would default to float).
    """
    return torch.tensor([subsample_mapping[i] for i in original_ids.tolist()], dtype=torch.long)


# Map to new indices
split_idx = {
    'train': _to_subsample_ids(filtered_train_original),
    'valid': _to_subsample_ids(filtered_valid_original),
    'test': _to_subsample_ids(filtered_test_original)
}

print(f"Splits - Train: {len(split_idx['train']):,}, Valid: {len(split_idx['valid']):,}, Test: {len(split_idx['test']):,}")

In [None]:
# Filter to selected 15 labels (same as training): labels 0-15 except 4
selected_labels = set(range(16)) - {4}
selected_label_tensor = torch.tensor(sorted(selected_labels), dtype=torch.long)
# Vectorised membership test over every node's label
label_mask = torch.isin(data.y.reshape(-1), selected_label_tensor)
filtered_node_indices = torch.where(label_mask)[0]

# Filter splits to selected labels
train_mask = torch.isin(split_idx['train'], filtered_node_indices)
valid_mask = torch.isin(split_idx['valid'], filtered_node_indices)
test_mask = torch.isin(split_idx['test'], filtered_node_indices)

filtered_train_idx = split_idx['train'][train_mask]
filtered_valid_idx = split_idx['valid'][valid_mask]
filtered_test_idx = split_idx['test'][test_mask]

# Extract final features and labels
X = data.x[filtered_node_indices]
y = data.y[filtered_node_indices].squeeze()

# Remap labels to 0 to num_classes-1 through a lookup table indexed by the
# original label (y only contains selected labels, so -1 is never produced).
label_map = {orig: new for new, orig in enumerate(sorted(selected_labels))}
lookup_size = max(int(data.y.max()), max(selected_labels)) + 1
label_lookup = torch.full((lookup_size,), -1, dtype=torch.long)
label_lookup[selected_label_tensor] = torch.arange(len(selected_label_tensor))
y_mapped = label_lookup[y]

# Extract subgraph for filtered nodes
remapped_edges, _ = subgraph(
    subset=filtered_node_indices,
    edge_index=data.edge_index,
    relabel_nodes=True,
    num_nodes=data.x.shape[0]
)

# Position of each kept node in the filtered graph (-1 marks dropped nodes).
# Every filtered split id is in filtered_node_indices (see the isin masks above).
new_position = torch.full((data.x.shape[0],), -1, dtype=torch.long)
new_position[filtered_node_indices] = torch.arange(len(filtered_node_indices))

train_idx = new_position[filtered_train_idx.long()]
valid_idx = new_position[filtered_valid_idx.long()]
test_idx = new_position[filtered_test_idx.long()]

num_features = X.shape[1]
num_classes = len(selected_labels)

print(f"Final dataset: {X.shape[0]:,} nodes, {remapped_edges.shape[1]:,} edges")
print(f"Features: {num_features}, Classes: {num_classes}")
print(f"Test set: {len(test_idx):,} nodes")

## 4. Analyze Node Degree Distribution with Stratification

In [None]:
print("\n" + "="*60)
print("NODE DEGREE ANALYSIS WITH STRATIFICATION")
print("="*60)

# Compute node degrees by counting source endpoints.
# NOTE(review): this equals the full degree only if edge_index stores both
# directions of every edge (assumed for ogbn-products) - confirm.
node_degrees = degree(remapped_edges[0], num_nodes=X.shape[0])

# Define degree strata
degree_strata = {
    'isolated': (node_degrees == 0),
    'sparse': ((node_degrees >= 1) & (node_degrees <= 5)),
    'moderate': ((node_degrees >= 6) & (node_degrees <= 20)),
    'well_connected': (node_degrees > 20)
}

print(f"Degree stratification:")
total_nodes = X.shape[0]
for stratum_name, mask in degree_strata.items():
    count = mask.sum().item()
    percentage = (count / total_nodes) * 100
    # degree() yields float tensors; show whole numbers, like the stats below
    min_deg = int(node_degrees[mask].min().item()) if count > 0 else 0
    max_deg = int(node_degrees[mask].max().item()) if count > 0 else 0
    print(f"  {stratum_name:<15}: {count:>6,} nodes ({percentage:>5.1f}%) [degree {min_deg}-{max_deg}]")

# Overall degree statistics
degree_stats = {
    'mean': node_degrees.mean().item(),
    'median': node_degrees.median().item(),
    'min': node_degrees.min().item(),
    'max': node_degrees.max().item(),
    'std': node_degrees.std().item()
}

print(f"\nOverall degree statistics:")
print(f"  Mean: {degree_stats['mean']:.2f}")
print(f"  Median: {degree_stats['median']:.2f}")
print(f"  Range: {degree_stats['min']:.0f} - {degree_stats['max']:.0f}")
print(f"  Std: {degree_stats['std']:.2f}")

## 5. Load All Trained Models

In [None]:
section_rule = "=" * 60
print("\n" + section_rule)
print("LOADING ALL TRAINED MODELS")
print(section_rule)

# Run on the GPU when one is visible, otherwise on the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Define model architectures
class MLP(nn.Module):
    """Feature-only baseline: three linear layers, graph structure ignored.

    Two hidden layers of width ``hidden_channels`` (each followed by ReLU and
    dropout with probability ``dropout``) feed a linear layer that returns
    ``out_channels`` raw logits per row of ``x``.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, dropout=0.5):
        super().__init__()
        # Attribute names become state_dict keys of saved checkpoints - keep them.
        self.lin1 = nn.Linear(in_channels, hidden_channels)
        self.lin2 = nn.Linear(hidden_channels, hidden_channels)
        self.lin3 = nn.Linear(hidden_channels, out_channels)
        self.dropout = dropout

    def forward(self, x, edge_index=None):
        """Return logits for ``x``; ``edge_index`` only mirrors the GNN call signature."""
        hidden = x
        for layer in (self.lin1, self.lin2):
            hidden = F.dropout(F.relu(layer(hidden)), p=self.dropout, training=self.training)
        return self.lin3(hidden)

class GCN(nn.Module):
    """Two-layer network of ``GCNConv`` layers returning raw class logits.

    The hidden representation passes through ReLU and dropout (probability
    ``dropout``, active only in training mode) before the second layer.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, dropout=0.5):
        super().__init__()
        # conv1/conv2 names are state_dict keys of the saved checkpoint
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)
        self.dropout = dropout

    def forward(self, x, edge_index):
        hidden = F.relu(self.conv1(x, edge_index))
        hidden = F.dropout(hidden, p=self.dropout, training=self.training)
        return self.conv2(hidden, edge_index)

class GraphSAGE(nn.Module):
    """Two-layer network of ``SAGEConv`` layers returning raw class logits.

    The hidden representation passes through ReLU and dropout (probability
    ``dropout``, active only in training mode) before the second layer.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, dropout=0.5):
        super().__init__()
        # conv1/conv2 names are state_dict keys of the saved checkpoints
        self.conv1 = SAGEConv(in_channels, hidden_channels)
        self.conv2 = SAGEConv(hidden_channels, out_channels)
        self.dropout = dropout

    def forward(self, x, edge_index):
        hidden = F.relu(self.conv1(x, edge_index))
        hidden = F.dropout(hidden, p=self.dropout, training=self.training)
        return self.conv2(hidden, edge_index)

# Load models
hidden_channels = 128
models = {}

# (display name, checkpoint path, architecture). Both GraphSAINT checkpoints
# are loaded into the GraphSAGE architecture defined above.
model_configs = [
    ('MLP', 'models/mlp_best.pt', MLP),
    ('GCN', 'models/gcn_best.pt', GCN),
    ('GraphSAGE', 'models/graphsage_best.pt', GraphSAGE),
    ('GraphSAINT-RW', 'models/graphsaint_rw_best.pt', GraphSAGE),
    ('GraphSAINT-Node', 'models/graphsaint_node_best.pt', GraphSAGE)
]

for model_name, model_path, model_class in model_configs:
    # A model is only usable with both a baseline entry and a checkpoint file
    if model_name not in baseline_results:
        print(f"Skipping {model_name}: no baseline results")
        continue
    if not os.path.exists(model_path):
        print(f"Skipping {model_name}: model file not found")
        continue
    try:
        model = model_class(num_features, hidden_channels, num_classes, dropout=0.5)
        checkpoint = torch.load(model_path, map_location=device)
        model.load_state_dict(checkpoint)
        model = model.to(device)
        model.eval()
        models[model_name] = model
        print(f"Loaded {model_name} model")
    except Exception as e:
        # Best effort: report the failure and continue with the other models
        print(f"ERROR loading {model_name}: {e}")

if len(models) < 2:
    print("\nERROR: Need at least 2 models for comparison")
    print("Loaded models:", list(models.keys()))
    raise RuntimeError("Insufficient models loaded")

print(f"\nSuccessfully loaded {len(models)} models: {list(models.keys())}")

# Move data to device
X, y_mapped, remapped_edges, test_idx, node_degrees = [
    tensor.to(device) for tensor in (X, y_mapped, remapped_edges, test_idx, node_degrees)
]

degree_strata = {name: mask.to(device) for name, mask in degree_strata.items()}

## 6. Cold-Start Analysis by Degree Strata

In [None]:
print("\n" + "="*60)
print("COLD-START ANALYSIS BY DEGREE STRATA")
print("="*60)

@torch.no_grad()
def evaluate_on_stratum(model, x, edge_index, test_mask, y_test, stratum_mask, model_name):
    """Return ``(accuracy, count)`` over the test nodes that lie in a stratum.

    Args:
        model: An ``MLP`` (row-wise) or a GNN taking ``(x, edge_index)``.
        x: Node feature matrix for the whole graph.
        edge_index: Graph edges (only used for non-MLP models).
        test_mask: Index tensor of test node ids (despite its name).
        y_test: Labels for all nodes, indexed by node id.
        stratum_mask: Boolean mask over all nodes selecting the stratum.
        model_name: Unused; kept for call-site compatibility.

    Returns ``(0.0, 0)`` when no test node falls inside the stratum.
    """
    model.eval()

    selected = test_mask[stratum_mask[test_mask]]
    n_selected = len(selected)
    if not n_selected:
        return 0.0, 0

    # MLPs only need the selected rows; GNNs need the whole graph for message
    # passing, so run them on everything and slice afterwards.
    if isinstance(model, MLP):
        logits = model(x[selected])
    else:
        logits = model(x, edge_index)[selected]

    correct = logits.argmax(dim=1) == y_test[selected]
    return correct.float().mean().item(), n_selected

# Run analysis across all degree strata
results_by_stratum = {
    'stratum': [],
    'degree_range': [],
    'num_test_nodes': []
}

# Add columns for each model
for model_name in models.keys():
    results_by_stratum[f'{model_name}_accuracy'] = []

# Width of each model column: at least 12 and wide enough for the longest
# model name (e.g. 'GraphSAINT-Node'), so headers and values line up.
col_width = max([12] + [len(name) for name in models])

print("Testing performance across degree strata...")
print(f"{'Stratum':<15} {'Range':<12} {'Test Nodes':<10} ", end='')
for model_name in models.keys():
    print(f"{model_name:<{col_width}}", end=' ')
print()
print("-" * (15 + 1 + 12 + 1 + 10 + 1 + (col_width + 1) * len(models)))

for stratum_name, stratum_mask in degree_strata.items():
    # Skip if no test nodes in this stratum
    test_in_stratum_mask = stratum_mask[test_idx]
    num_test_nodes = test_in_stratum_mask.sum().item()

    if num_test_nodes == 0:
        print(f"{stratum_name:<15} {'N/A':<12} {0:<10} No test nodes")
        continue

    # Degree range for this stratum. It is non-empty here because it holds at
    # least one test node; degree() yields floats, so cast for display.
    stratum_degrees = node_degrees[stratum_mask]
    min_deg = int(stratum_degrees.min().item())
    max_deg = int(stratum_degrees.max().item())
    degree_range = f"{min_deg}" if min_deg == max_deg else f"{min_deg}-{max_deg}"

    # Store basic info
    results_by_stratum['stratum'].append(stratum_name)
    results_by_stratum['degree_range'].append(degree_range)
    results_by_stratum['num_test_nodes'].append(num_test_nodes)

    print(f"{stratum_name:<15} {degree_range:<12} {num_test_nodes:<10} ", end='')

    # Test each model on this stratum
    for model_name, model in models.items():
        accuracy, _ = evaluate_on_stratum(
            model, X, remapped_edges, test_idx, y_mapped, stratum_mask, model_name
        )
        results_by_stratum[f'{model_name}_accuracy'].append(accuracy)
        print(f"{accuracy:<{col_width}.4f}", end=' ')

    print()  # New line after each stratum

# Convert to DataFrame
results_df = pd.DataFrame(results_by_stratum)
print(f"\nCold-Start Analysis Results by Degree Strata:")
print(results_df.to_string(index=False))

## 7. Isolated Nodes Analysis (Degree = 0)

In [None]:
section_rule = "=" * 60
print("\n" + section_rule)
print("DETAILED ISOLATED NODES ANALYSIS (DEGREE = 0)")
print(section_rule)

# Isolated nodes have no incident edges at all: the true cold-start case
isolated_mask = node_degrees == 0
isolated_test_nodes = test_idx[isolated_mask[test_idx]]
num_isolated_test = len(isolated_test_nodes)

print(f"Total isolated nodes: {isolated_mask.sum().item():,}")
print(f"Isolated test nodes: {num_isolated_test:,}")

if num_isolated_test == 0:
    print("No isolated nodes in test set")
    isolated_results = None
else:
    print(f"\nPerformance on isolated nodes (true cold-start scenario):")
    print(f"{'Model':<15} {'Accuracy':<10} {'vs Baseline':<12}")
    print("-" * 40)

    isolated_results = {}
    isolated_targets = y_mapped[isolated_test_nodes]

    for model_name, model in models.items():
        with torch.no_grad():
            # MLPs see only the isolated rows; GNNs run on the full graph and
            # their outputs are sliced down to the isolated test nodes.
            if isinstance(model, MLP):
                logits = model(X[isolated_test_nodes])
            else:
                logits = model(X, remapped_edges)[isolated_test_nodes]

            accuracy = (logits.argmax(dim=1) == isolated_targets).float().mean().item()
            isolated_results[model_name] = accuracy

            # Positive means better on isolated nodes than on the full test set
            difference = accuracy - baseline_results[model_name]['test_accuracy']
            print(f"{model_name:<15} {accuracy:.4f}     {difference:+.4f}")

    # Best and worst performers on isolated nodes
    best_model = max(isolated_results, key=isolated_results.get)
    worst_model = min(isolated_results, key=isolated_results.get)
    best_isolated_acc = isolated_results[best_model]
    worst_isolated_acc = isolated_results[worst_model]

    print(f"\nBest on isolated nodes: {best_model} ({best_isolated_acc:.4f})")
    print(f"Worst on isolated nodes: {worst_model} ({worst_isolated_acc:.4f})")
    print(f"Gap: {best_isolated_acc - worst_isolated_acc:.4f}")

    # Label distribution (remapped label ids) of the isolated test nodes
    unique_labels, label_counts = torch.unique(isolated_targets, return_counts=True)

    print(f"\nLabel distribution of isolated test nodes:")
    for label, count in zip(unique_labels.tolist(), label_counts.tolist()):
        percentage = (count / num_isolated_test) * 100
        print(f"  Label {label}: {count} nodes ({percentage:.1f}%)")

## 8. Connectivity Threshold Analysis

In [None]:
print("\n" + "="*60)
print("CONNECTIVITY THRESHOLD ANALYSIS")
print("="*60)

# Test performance at different degree thresholds (cumulative: degree <= t)
degree_thresholds = [0, 1, 2, 3, 5, 10, 20, 50]
# A crossover is only reported when more test nodes than this back it up
min_crossover_nodes = 50
threshold_results = {
    'threshold': [],
    'num_test_nodes': []
}

for model_name in models.keys():
    threshold_results[f'{model_name}_accuracy'] = []

# Width of each model column: at least 12 and wide enough for the longest name
col_width = max([12] + [len(name) for name in models])

print("Finding connectivity thresholds where GNNs become worthwhile...")
print(f"\n{'Degree <=':<10} {'Test Nodes':<12} ", end='')
for model_name in models.keys():
    print(f"{model_name:<{col_width}}", end=' ')
print()
print("-" * (10 + 1 + 12 + 1 + (col_width + 1) * len(models)))

for threshold in degree_thresholds:
    # Create mask for nodes with degree <= threshold
    threshold_mask = node_degrees <= threshold
    test_in_threshold = test_idx[threshold_mask[test_idx]]
    num_test_nodes = len(test_in_threshold)

    threshold_results['threshold'].append(threshold)
    threshold_results['num_test_nodes'].append(num_test_nodes)

    print(f"{threshold:<10} {num_test_nodes:<12} ", end='')

    if num_test_nodes == 0:
        # No nodes at this threshold; 0.0 is a placeholder, not a measurement
        for model_name in models.keys():
            threshold_results[f'{model_name}_accuracy'].append(0.0)
        print("No nodes")
        continue

    # Evaluate each model
    for model_name, model in models.items():
        accuracy, _ = evaluate_on_stratum(
            model, X, remapped_edges, test_idx, y_mapped, threshold_mask, model_name
        )
        threshold_results[f'{model_name}_accuracy'].append(accuracy)
        print(f"{accuracy:<{col_width}.4f}", end=' ')

    print()

# Convert to DataFrame
threshold_df = pd.DataFrame(threshold_results)

# Find the first (smallest) cumulative threshold where each GNN beats MLP
if 'MLP' in models:
    mlp_accs = threshold_df['MLP_accuracy'].values
    print(f"\nCrossover analysis (where GNNs beat MLP):")
    print(f"{'Model':<15} {'Crossover Threshold':<20} {'Improvement':<12}")
    print("-" * 50)

    for model_name in models.keys():
        if model_name == 'MLP':
            continue

        model_accs = threshold_df[f'{model_name}_accuracy'].values
        crossover_threshold = None

        for i, (mlp_acc, model_acc) in enumerate(zip(mlp_accs, model_accs)):
            # Require sufficient sample size before trusting the comparison
            if model_acc > mlp_acc and threshold_df['num_test_nodes'].iloc[i] > min_crossover_nodes:
                crossover_threshold = threshold_df['threshold'].iloc[i]
                improvement = model_acc - mlp_acc
                crossover_label = f"degree <= {crossover_threshold}"
                print(f"{model_name:<15} {crossover_label:<20} {improvement:+.4f}")
                break

        if crossover_threshold is None:
            print(f"{model_name:<15} {'No crossover':<20} {'N/A':<12}")

print(f"\nThreshold Analysis Results:")
print(threshold_df.to_string(index=False))

## 9. Visualizations

In [None]:
section_rule = "=" * 60
print("\n" + section_rule)
print("CREATING VISUALIZATIONS")
print(section_rule)

# Set plotting style
plt.style.use('seaborn-v0_8-whitegrid')
# Sample one tab10 colour per model and key it by model name, so every figure
# draws a given model in the same colour.
model_colors = plt.cm.tab10(np.linspace(0, 1, len(models)))
color_map = dict(zip(models.keys(), model_colors))

In [None]:
# Plot 1: Performance by Degree Strata
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6))

# Left plot: grouped bars - one group per stratum, one bar per model
strata_names = results_df['stratum'].values
n_models = len(models)
x_pos = np.arange(len(strata_names))
bar_width = 0.8 / n_models

for slot, model_name in enumerate(models):
    ax1.bar(
        x_pos + slot * bar_width,
        results_df[f'{model_name}_accuracy'].values,
        bar_width,
        label=model_name,
        color=color_map[model_name],
        alpha=0.8,
    )

ax1.set_xlabel('Degree Stratum', fontweight='bold')
ax1.set_ylabel('Test Accuracy', fontweight='bold')
ax1.set_title('Performance by Node Connectivity Level', fontweight='bold')
# Centre each tick under its group of bars
ax1.set_xticks(x_pos + bar_width * (n_models - 1) / 2)
ax1.set_xticklabels(strata_names, rotation=0)
ax1.legend()
ax1.grid(True, alpha=0.3)
top_accuracy = max(results_df[f'{m}_accuracy'].max() for m in models.keys())
ax1.set_ylim(0, max(1, top_accuracy * 1.1))

# Right plot: Node count by stratum
node_counts = results_df['num_test_nodes'].values
ax2.bar(strata_names, node_counts, color='lightblue', alpha=0.7, edgecolor='navy')
ax2.set_xlabel('Degree Stratum', fontweight='bold')
ax2.set_ylabel('Number of Test Nodes', fontweight='bold')
ax2.set_title('Test Node Distribution by Connectivity', fontweight='bold')
ax2.grid(True, alpha=0.3)

# Label each bar with its count, lifted by 1% of the tallest bar
label_lift = max(node_counts) * 0.01
for position, count in enumerate(node_counts):
    ax2.text(position, count + label_lift, f'{count:,}',
             ha='center', va='bottom', fontsize=10)

plt.tight_layout()
plt.savefig('images/cold_start/performance_by_strata.png', dpi=300, bbox_inches='tight')
plt.show()

In [None]:
# Plot 2: Connectivity Threshold Analysis
fig, ax = plt.subplots(figsize=(12, 8))

# One accuracy curve per model over the cumulative degree thresholds
thresholds = threshold_df['threshold'].values
line_style = dict(marker='o', linewidth=2, markersize=8)
for model_name in models:
    ax.plot(thresholds, threshold_df[f'{model_name}_accuracy'].values,
            label=model_name, color=color_map[model_name], **line_style)

ax.set_xlabel('Maximum Node Degree (Cumulative)', fontweight='bold')
ax.set_ylabel('Test Accuracy', fontweight='bold')
ax.set_title('Performance vs Connectivity Threshold\n(Cumulative: nodes with degree <= threshold)', fontweight='bold')
ax.legend()
ax.grid(True, alpha=0.3)
ax.set_xlim(-1, max(thresholds) + 1)

plt.tight_layout()
plt.savefig('images/cold_start/threshold_analysis.png', dpi=300, bbox_inches='tight')
plt.show()

In [None]:
# Plot 3: GNN vs MLP Advantage by Connectivity
if 'MLP' in models:
    fig, ax = plt.subplots(figsize=(12, 6))

    mlp_accs = threshold_df['MLP_accuracy'].values
    thresholds = threshold_df['threshold'].values
    gnn_names = [name for name in models.keys() if name != 'MLP']

    # Thresholds are unevenly spaced (0, 1, 2, 3, 5, ..., 50), so draw them as
    # evenly spaced groups labelled with the real threshold. Each GNN gets its
    # own non-overlapping slot inside a group, whatever the model order is.
    group_pos = np.arange(len(thresholds))
    adv_bar_width = 0.8 / max(len(gnn_names), 1)

    for slot, model_name in enumerate(gnn_names):
        advantages = threshold_df[f'{model_name}_accuracy'].values - mlp_accs
        ax.bar(group_pos + slot * adv_bar_width, advantages, width=adv_bar_width, alpha=0.7,
               label=f'{model_name} - MLP', color=color_map[model_name])

    ax.axhline(y=0, color='black', linestyle='-', alpha=0.5)
    # Centre each tick under its group of bars
    ax.set_xticks(group_pos + adv_bar_width * (len(gnn_names) - 1) / 2)
    ax.set_xticklabels([str(t) for t in thresholds])
    ax.set_xlabel('Maximum Node Degree (Cumulative)', fontweight='bold')
    ax.set_ylabel('Accuracy Advantage over MLP', fontweight='bold')
    ax.set_title('GNN Advantage over MLP by Connectivity Level', fontweight='bold')
    ax.legend()
    ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig('images/cold_start/gnn_mlp_advantage.png', dpi=300, bbox_inches='tight')
    plt.show()

## 10. Analysis Summary and Insights

In [None]:
print("\n" + "="*70)
print("COMPREHENSIVE COLD-START ANALYSIS SUMMARY")
print("="*70)

print("\nKEY FINDINGS:")

# Cumulative thresholds backed by this many test nodes or fewer are ignored
min_worthwhile_nodes = 50

# 1. Isolated nodes performance
if isolated_results:
    print("\n1. TRUE COLD-START PERFORMANCE (Degree = 0):")
    for model_name, acc in isolated_results.items():
        baseline_acc = baseline_results[model_name]['test_accuracy']
        # Positive = worse on isolated nodes than on the whole test set
        degradation = baseline_acc - acc
        print(f"   {model_name:<15}: {acc:.4f} (degradation: {degradation:.4f})")

    best_isolated = max(isolated_results, key=isolated_results.get)
    worst_isolated = min(isolated_results, key=isolated_results.get)
    print(f"   Best isolated: {best_isolated} ({isolated_results[best_isolated]:.4f})")
    print(f"   Worst isolated: {worst_isolated} ({isolated_results[worst_isolated]:.4f})")

# 2. Connectivity advantages
print("\n2. CONNECTIVITY ADVANTAGES:")
for stratum_name in ['sparse', 'moderate', 'well_connected']:
    if stratum_name in results_df['stratum'].values:
        idx = results_df[results_df['stratum'] == stratum_name].index[0]
        print(f"   {stratum_name.replace('_', ' ').title()} connections:")
        for model_name in models.keys():
            acc = results_df.loc[idx, f'{model_name}_accuracy']
            print(f"     {model_name}: {acc:.4f}")

# 3. Best model by connectivity level (first model wins ties)
print("\n3. BEST MODEL BY CONNECTIVITY LEVEL:")
for _, row in results_df.iterrows():
    stratum = row['stratum']
    best_model = max(models.keys(), key=lambda name: row[f'{name}_accuracy'])
    best_acc = row[f'{best_model}_accuracy']
    print(f"   {stratum.replace('_', ' ').title():<15}: {best_model} ({best_acc:.4f})")

# 4. When do GNNs become worthwhile?
# Thresholds are cumulative: threshold t covers test nodes with degree <= t.
if 'MLP' in models:
    print("\n4. GNN WORTHINESS THRESHOLDS:")
    mlp_accs = threshold_df['MLP_accuracy'].values
    for model_name in models.keys():
        if model_name == 'MLP':
            continue

        model_accs = threshold_df[f'{model_name}_accuracy'].values
        worthwhile_threshold = None

        for i, (mlp_acc, model_acc) in enumerate(zip(mlp_accs, model_accs)):
            if model_acc > mlp_acc and threshold_df['num_test_nodes'].iloc[i] > min_worthwhile_nodes:
                worthwhile_threshold = threshold_df['threshold'].iloc[i]
                break

        if worthwhile_threshold is not None:
            # Same cumulative semantics as the crossover table (degree <= t)
            print(f"   {model_name}: first beats MLP on nodes with degree <= {worthwhile_threshold}")
        else:
            print(f"   {model_name}: never clearly better than MLP")

print(f"\n5. BASELINE COMPARISON (All test nodes):")
for model_name in models.keys():
    acc = baseline_results[model_name]['test_accuracy']
    print(f"   {model_name}: {acc:.4f}")

## 11. Save Results

In [None]:
# Save comprehensive results
def _json_default(obj):
    """Fallback for ``json.dump``: convert NumPy scalars/arrays to built-ins.

    Values taken from pandas/NumPy (e.g. ``.iloc`` results, or
    ``DataFrame.to_dict`` on some pandas versions) may be NumPy scalars, which
    ``json`` cannot serialise on its own. Any other type still raises.
    """
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


comprehensive_results = {
    'experiment': 'Comprehensive Cold-Start Analysis',
    'models_analyzed': list(models.keys()),
    'dataset_info': {
        'total_nodes': int(X.shape[0]),
        'total_edges': int(remapped_edges.shape[1]),
        'test_nodes': int(len(test_idx)),
        'features': int(num_features),
        'classes': int(num_classes)
    },
    'degree_statistics': degree_stats,
    'degree_strata': {
        stratum: {
            'node_count': int(mask.sum().item()),
            'percentage': float((mask.sum().item() / X.shape[0]) * 100)
        }
        for stratum, mask in degree_strata.items()
    },
    'performance_by_strata': results_df.to_dict('records'),
    'connectivity_threshold_analysis': threshold_df.to_dict('records'),
    'isolated_nodes_analysis': isolated_results if isolated_results else {},
    'baseline_accuracies': {
        model: baseline_results[model]['test_accuracy']
        for model in models.keys()
    }
}

# Save results
with open('results/cold_start/comprehensive_cold_start_results.json', 'w') as f:
    json.dump(comprehensive_results, f, indent=2, default=_json_default)

# Save detailed tables
results_df.to_csv('results/cold_start/performance_by_strata.csv', index=False)
threshold_df.to_csv('results/cold_start/threshold_analysis.csv', index=False)

print("\nResults saved:")
print("  - results/cold_start/comprehensive_cold_start_results.json")
print("  - results/cold_start/performance_by_strata.csv")
print("  - results/cold_start/threshold_analysis.csv")

print("\nGenerated visualizations:")
print("  - images/cold_start/performance_by_strata.png")
print("  - images/cold_start/threshold_analysis.png") 
if 'MLP' in models:
    print("  - images/cold_start/gnn_mlp_advantage.png")

print(f"\nCOMPREHENSIVE COLD-START ANALYSIS COMPLETE!")
print(f"Analyzed {len(models)} models across {len(degree_strata)} connectivity levels")