<a href="https://colab.research.google.com/github/ghommidhWassim/GNN-variants/blob/main/graphSAGE.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Environment setup cell (IPython shell escapes).
# Install PyTorch Geometric quietly from its GitHub repository; no tag or commit is
# pinned, so the installed revision depends on when this cell is run.
!pip install -q git+https://github.com/pyg-team/pytorch_geometric.git
# Print the installed torch version and the CUDA version torch reports.
!python -c "import torch; print(torch.__version__)"
!python -c "import torch; print(torch.version.cuda)"
!pip install torchvision
# Install the PyG extension packages from the wheel index for torch 2.6.0 + CUDA 12.4.
# NOTE(review): the index URL is hard-coded; it presumably has to match the two
# versions printed above -- confirm whenever the runtime image changes.
!pip install pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-2.6.0+cu124.html


  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
  Building wheel for torch-geometric (pyproject.toml) ... [?25l[?25hdone
2.6.0+cu124
12.4
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch==2.6.0->torchvision)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch==2.6.0->torchvision)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch==2.6.0->torchvision)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch==2.6.0->torchvision)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (fr

In [55]:
# Standard libraries
import numpy as np
from scipy import sparse
import seaborn as sns
import pandas as pd
import time
# Plotting libraries
import matplotlib.pyplot as plt
import networkx as nx
from matplotlib import cm
from IPython.display import Javascript  # Restrict height of output cell.

# PyTorch
import torch
import torch.nn.functional as F
from torch.nn import Linear
import torch.nn as nn
from torch_sparse import spmm
# import pyg_lib
import torch_sparse

# PyTorch geometric
from torch_geometric.nn import GCNConv,SAGEConv
from torch_geometric.datasets import Planetoid
from torch_geometric.transforms import NormalizeFeatures
from torch_geometric import seed_everything
import torch
import os.path as osp
import torch.nn.functional as F
from torch_geometric.datasets import Planetoid
from torch_geometric.loader import NeighborLoader


In [56]:
# Prefer the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def dataset_load():
    """Load the PubMed Planetoid dataset and move its graph to ``device``.

    Prints the selected device, builds the dataset under ``data/Planetoid`` with
    the ``NormalizeFeatures`` transform applied, and takes the first graph
    object from it.

    Returns:
        A tuple ``(num_features, data, num_classes, device, dataset)`` where
        ``data`` is the first graph of ``dataset`` moved to the module-level
        ``device``.
    """
    print(f"Using device: {device}")
    planetoid = Planetoid(root='data/Planetoid', name='PubMed', transform=NormalizeFeatures())
    # Index 0 yields the first graph object of the dataset.
    graph = planetoid[0].to(device)
    return planetoid.num_features, graph, planetoid.num_classes, device, planetoid

def clean_gpu_memory():
    """Free cached GPU memory and reset peak tracking, keeping the CUDA context alive.

    Always runs Python garbage collection. When CUDA is available it also
    empties PyTorch's caching allocator, resets the peak-memory statistics and
    prints the amount of memory still allocated. Returns ``None``.
    """
    import gc

    gc.collect()  # drop unreachable Python objects first
    if not torch.cuda.is_available():
        return

    torch.cuda.empty_cache()  # release PyTorch's cached blocks
    torch.cuda.reset_peak_memory_stats()  # restart peak tracking from here
    print(f"Memory after cleanup: {torch.cuda.memory_allocated()/1024**2:.2f} MB")

# Load the dataset once; the unpacked names are notebook-level globals that
# later cells (loader, model, training, evaluation) read directly.
num_features, data, num_classes, device, dataset = dataset_load()
# Summary statistics of the loaded graph.
print(f'Number of nodes:          {data.num_nodes}')
print(f'Number of edges:          {data.num_edges}')
print(f'Average node degree:      {data.num_edges / data.num_nodes:.2f}')
print(f'Number of training nodes: {data.train_mask.sum()}')
print(f'Training node label rate: {int(data.train_mask.sum()) / data.num_nodes:.3f}')
print(f'Has isolated nodes:       {data.has_isolated_nodes()}')
print(f'Has self-loops:           {data.has_self_loops()}')
print(f'Is undirected:            {data.is_undirected()}')
# Bare last expression: shown as the cell's output value.
num_features

Using device: cuda
Number of nodes:          19717
Number of edges:          88648
Average node degree:      4.50
Number of training nodes: 60
Training node label rate: 0.003
Has isolated nodes:       False
Has self-loops:           False
Is undirected:            True


500

In [57]:
# Mini-batch loader that draws its batches from the nodes flagged in
# data.train_mask and samples a neighbourhood around them.
train_loader = NeighborLoader(
    data,
    num_neighbors=[10] * 2,        # s = 10 per layer, for 2 layers
    batch_size=128,
    shuffle=True,
    input_nodes=data.train_mask,   # restrict the loader's input nodes to the training split
)


In [58]:
class testGraphSAGE(torch.nn.Module):
    """GraphSAGE node classifier: a stack of SAGEConv layers followed by an MLP head.

    The convolution stack maps ``in_channels -> hidden_channels -> ... ->
    out_channels``. Its output is then passed through a two-layer MLP
    (``out_channels -> hidden_channels -> out_channels``) and a log-softmax.

    Args:
        in_channels: Size of each input node feature vector.
        hidden_channels: Width of the hidden SAGEConv layers and of the MLP head.
        out_channels: Size of the output vector per node (one entry per class).
        num_layers: Number of SAGEConv layers to build; must be at least 1.
        dropout: Dropout probability used after every hidden activation.

    Raises:
        ValueError: If ``num_layers`` is smaller than 1.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, num_layers, dropout=0.5):
        super().__init__()
        # Previously any num_layers <= 1 silently produced a 2-layer stack, so
        # self.num_layers disagreed with the real depth. Reject impossible
        # depths and build exactly the number of layers that was requested.
        if num_layers < 1:
            raise ValueError(f"num_layers must be >= 1, got {num_layers}")
        self.num_layers = num_layers
        self.dropout = dropout

        self.convs = torch.nn.ModuleList()
        if num_layers == 1:
            # Single layer: map the input features straight to out_channels.
            self.convs.append(SAGEConv(in_channels, out_channels))
        else:
            # First layer: in_channels -> hidden_channels
            self.convs.append(SAGEConv(in_channels, hidden_channels))
            # Intermediate layers: hidden_channels -> hidden_channels
            for _ in range(num_layers - 2):
                self.convs.append(SAGEConv(hidden_channels, hidden_channels))
            # Last convolution: hidden_channels -> out_channels
            self.convs.append(SAGEConv(hidden_channels, out_channels))

        # MLP head applied on top of the convolution output.
        self.lin1 = Linear(out_channels, hidden_channels)
        self.lin2 = Linear(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        """Return per-node log-probabilities (log-softmax over dim 1).

        Args:
            x: Node feature matrix fed to the first convolution.
            edge_index: Graph connectivity passed unchanged to every SAGEConv.
        """
        # All convolutions except the last: conv -> ReLU -> dropout.
        for conv in self.convs[:-1]:
            x = F.relu(conv(x, edge_index))
            x = F.dropout(x, p=self.dropout, training=self.training)

        # Last convolution has no activation of its own; the MLP head follows.
        x = self.convs[-1](x, edge_index)

        # MLP head: linear -> ReLU -> dropout -> linear.
        x = F.relu(self.lin1(x))
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.lin2(x)

        return F.log_softmax(x, dim=1)

In [59]:
# Seed the random number generators before the model is created so that weight
# initialisation and dropout are repeatable across "Restart & Run All".
SEED = 42
seed_everything(SEED)

hidden_channels = 64

model = testGraphSAGE(
    in_channels=dataset.num_features,   # Input feature dimension
    hidden_channels=hidden_channels,    # Hidden layer size (single source of truth above)
    num_layers=2,                       # Number of SAGEConv layers
    out_channels=dataset.num_classes,   # Output dimension (number of classes)
    dropout=0.5,                        # Dropout rate
).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
criterion = torch.nn.NLLLoss()  # Negative log likelihood; the model already returns log_softmax


In [60]:
# Minibatch training function
def train():
    """Run one training epoch over ``train_loader`` and return the mean loss.

    Uses the notebook globals ``model``, ``train_loader``, ``optimizer`` and
    ``criterion``. For every mini-batch the loss is computed on the seed nodes
    only: NeighborLoader places the nodes the batch was sampled for in the
    first ``batch.batch_size`` rows, and the remaining rows are sampled
    neighbours that only provide context. Selecting rows with
    ``batch.train_mask`` instead would also pick up neighbours that happen to
    be training nodes.

    Returns:
        The loss averaged over all seed nodes seen in the epoch, or ``nan``
        when the loader yielded no seed nodes.
    """
    model.train()
    total_loss = 0.0
    total_seeds = 0
    for batch in train_loader:
        optimizer.zero_grad()
        out = model(batch.x, batch.edge_index)
        num_seeds = batch.batch_size  # seed nodes come first in the batch
        loss = criterion(out[:num_seeds], batch.y[:num_seeds])
        loss.backward()
        optimizer.step()
        # Weight by batch size so a smaller last batch does not skew the mean.
        total_loss += loss.item() * num_seeds
        total_seeds += num_seeds
    if total_seeds == 0:
        # Empty loader: there is no loss to report (avoids a ZeroDivisionError).
        return float("nan")
    return total_loss / total_seeds

# Full-batch evaluation (for simplicity)
def evaluate(mask):
    """Return the classification accuracy on the nodes selected by ``mask``.

    Runs the global ``model`` in eval mode over the whole graph in ``data``
    (no neighbour sampling, no gradient tracking) and compares the arg-max
    prediction with ``data.y`` on the masked nodes.

    Args:
        mask: Index/boolean selector applied to both the model output and
            ``data.y``.

    Returns:
        Fraction of selected nodes predicted correctly, as a Python float.
    """
    model.eval()
    with torch.no_grad():
        log_probs = model(data.x, data.edge_index)
        predicted = log_probs[mask].argmax(dim=1)
        hits = predicted == data.y[mask]
        return hits.float().mean().item()

In [61]:
# Training loop
# Reset the CUDA peak-memory counter so that a later peak-memory report covers
# this run only, and drain pending GPU work so the timer starts from idle.
if torch.cuda.is_available():
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.synchronize()

start_time = time.time()
for epoch in range(1, 101):
    loss = train()
    val_acc = evaluate(data.val_mask)
    print(f"Epoch: {epoch:03d}, Loss: {loss:.4f}, Val Acc: {val_acc:.4f}")

# CUDA kernels run asynchronously; wait for them before reading the clock.
if torch.cuda.is_available():
    torch.cuda.synchronize()
end_time = time.time()
# Note: the measured interval also contains the per-epoch validation pass.
print(f"Training time: {end_time - start_time:.2f} seconds")

# Test accuracy
test_acc = evaluate(data.test_mask)
print(f"Test Accuracy: {test_acc:.4f}")

Epoch: 001, Loss: 1.1242, Val Acc: 0.1960
Epoch: 002, Loss: 1.1311, Val Acc: 0.4160
Epoch: 003, Loss: 1.1162, Val Acc: 0.3880
Epoch: 004, Loss: 1.1374, Val Acc: 0.3880
Epoch: 005, Loss: 1.0949, Val Acc: 0.3880
Epoch: 006, Loss: 1.0953, Val Acc: 0.3880
Epoch: 007, Loss: 1.0689, Val Acc: 0.3880
Epoch: 008, Loss: 1.0586, Val Acc: 0.5080
Epoch: 009, Loss: 1.0287, Val Acc: 0.5900
Epoch: 010, Loss: 1.0439, Val Acc: 0.5500
Epoch: 011, Loss: 0.9708, Val Acc: 0.5720
Epoch: 012, Loss: 0.9433, Val Acc: 0.5980
Epoch: 013, Loss: 0.8779, Val Acc: 0.6120
Epoch: 014, Loss: 0.7957, Val Acc: 0.6160
Epoch: 015, Loss: 0.7473, Val Acc: 0.6140
Epoch: 016, Loss: 0.6580, Val Acc: 0.6340
Epoch: 017, Loss: 0.5874, Val Acc: 0.6560
Epoch: 018, Loss: 0.5376, Val Acc: 0.6900
Epoch: 019, Loss: 0.4907, Val Acc: 0.7020
Epoch: 020, Loss: 0.4448, Val Acc: 0.7100
Epoch: 021, Loss: 0.3371, Val Acc: 0.7040
Epoch: 022, Loss: 0.2543, Val Acc: 0.7140
Epoch: 023, Loss: 0.2385, Val Acc: 0.7140
Epoch: 024, Loss: 0.1736, Val Acc:

In [62]:
# Re-evaluate on the test split (same call as at the end of the training cell)
# and overwrite the notebook-level test_acc with the result.
test_acc = evaluate(data.test_mask)
print(f"Test Accuracy: {test_acc:.4f}")


Test Accuracy: 0.7500


In [53]:
# Report PyTorch's CUDA memory counters in MiB (bytes / 1024**2): memory
# currently allocated, memory reserved, and the peak allocation recorded so far.
# NOTE(review): the labels say "after tensor creation", but the counters are read
# at whatever point this cell is executed -- confirm the intended wording.
print(f"Allocated memory after tensor creation: {torch.cuda.memory_allocated() / (1024**2):.2f} MB")
print(f"Reserved memory after tensor creation: {torch.cuda.memory_reserved() / (1024**2):.2f} MB")
print(f"Peak allocated memory: {torch.cuda.max_memory_allocated() / (1024**2):.2f} MB")


Allocated memory after tensor creation: 146.08 MB
Reserved memory after tensor creation: 446.00 MB
Peak allocated memory: 392.16 MB


In [None]:
import json

# Collect the run's metrics. Values are stored as formatted strings, matching
# the existing layout of the results file.
peak_memory_mb = f"{torch.cuda.max_memory_allocated()/1024**2:.2f}"
total_train_time = f"{end_time - start_time:.2f}"

metrics = {
    "model": "graphSAGE",
    # Use the measured test accuracy instead of a hard-coded figure, which
    # goes stale as soon as the model is retrained.
    "accuracy": f"{test_acc:.4f}",
    "memory_MB": peak_memory_mb,
    "train_time_sec": total_train_time
}

with open("graphSAGE_results.json", "w") as f:
    json.dump(metrics, f)

In [None]:
def calc_graphsage_memory_requirements(batch_size, fanout, num_layers, hidden_dim):
    """Estimate GraphSAGE mini-batch memory use, assuming 4-byte floats.

    The estimate has two parts:
      * embeddings: ``batch_size * fanout ** (num_layers - 1)`` sampled nodes
        (an approximation), each holding ``hidden_dim`` floats;
      * weights: ``num_layers`` square ``hidden_dim x hidden_dim`` matrices.

    Args:
        batch_size: Number of seed nodes per mini-batch.
        fanout: Neighbours sampled per node per layer.
        num_layers: Number of layers.
        hidden_dim: Embedding width.

    Returns:
        Dict with ``intermediate_embeddings_MB``, ``weight_matrices_MB`` and
        ``total_MB``, each expressed in units of 1024**2 bytes.
    """
    float_bytes = 4
    bytes_per_mb = 1024**2

    # Approximate count of sampled nodes for one batch.
    sampled_nodes = batch_size * (fanout ** (num_layers - 1))
    embedding_bytes = sampled_nodes * hidden_dim * float_bytes
    weight_bytes = num_layers * hidden_dim * hidden_dim * float_bytes

    sizes_in_bytes = {
        "intermediate_embeddings_MB": embedding_bytes,
        "weight_matrices_MB": weight_bytes,
        "total_MB": embedding_bytes + weight_bytes,
    }
    return {key: nbytes / bytes_per_mb for key, nbytes in sizes_in_bytes.items()}

# Example values (the same batch size, fan-out, depth and width used above):
batch_size, fanout, num_layers, hidden_dim = 128, 10, 2, 64

mem_usage = calc_graphsage_memory_requirements(
    batch_size=batch_size,
    fanout=fanout,
    num_layers=num_layers,
    hidden_dim=hidden_dim,
)

# Print each component of the estimate.
print(f"Intermediate Embeddings: {mem_usage['intermediate_embeddings_MB']:.2f} MB")
print(f"Weight Matrices        : {mem_usage['weight_matrices_MB']:.2f} MB")
print(f"Total Estimated        : {mem_usage['total_MB']:.2f} MB")
