<a href="https://colab.research.google.com/github/ghommidhWassim/GNN-variants/blob/main/LADIES.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install -q git+https://github.com/pyg-team/pytorch_geometric.git
!python -c "import torch; print(torch.__version__)"
!python -c "import torch; print(torch.version.cuda)"
!pip install torchvision
!pip install pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-2.6.0+cu124.html


  Installing build dependencies ... done
  Getting requirements to build wheel ... done
  Preparing metadata (pyproject.toml) ... done
2.6.0+cu124
12.4
Looking in links: https://data.pyg.org/whl/torch-2.6.0+cu124.html


In [None]:
!pip install pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-2.6.0+cu124.html


Looking in links: https://data.pyg.org/whl/torch-2.6.0+cu124.html


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import scipy.sparse as sp
from torch_geometric.datasets import Planetoid
from torch_geometric.transforms import NormalizeFeatures
from torch_geometric.utils import to_scipy_sparse_matrix
from sklearn.metrics import accuracy_score
import random

In [None]:


# Ensure reproducibility
def seed_everything(seed=42):
    """Seed every RNG this notebook draws from.

    Seeds PyTorch (``torch.manual_seed``), NumPy's legacy global RNG
    (``np.random.seed``) and Python's ``random`` module with the same value.

    Args:
        seed: Integer seed shared by all three generators.
    """
    for seed_fn in (torch.manual_seed, np.random.seed, random.seed):
        seed_fn(seed)


seed_everything()

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# ------------------- Load Dataset -------------------
def dataset_load():
    """Load the PubMed Planetoid dataset and move its graph to ``device``.

    Node features go through PyG's ``NormalizeFeatures`` transform; files are
    stored under ``data/Planetoid``.

    Returns:
        tuple: ``(num_features, data, num_classes)`` where ``data`` is
        ``dataset[0]`` already moved to ``device``.
    """
    print(f"Using device: {device}")
    pubmed = Planetoid(root='data/Planetoid', name='PubMed', transform=NormalizeFeatures())
    graph = pubmed[0].to(device)
    return pubmed.num_features, graph, pubmed.num_classes


num_features, data, num_classes = dataset_load()

# ------------------- Prepare Adjacency -------------------
# Sparse adjacency A of the full graph (SciPy), then A + I to add self-loops;
# the result is row-normalised further below.
adj = to_scipy_sparse_matrix(data.edge_index, num_nodes=data.num_nodes)
lap_matrix = adj + sp.identity(adj.shape[0])

def row_normalize(mx):
    """Scale each row of a sparse matrix by the inverse of its sum (``D^{-1} M``).

    Rows whose sum is zero are left as all-zero rows instead of producing
    NaN/inf values.

    Args:
        mx: SciPy sparse matrix.

    Returns:
        SciPy sparse matrix equal to ``diag(1 / rowsum) @ mx``.
    """
    row_sums = np.asarray(mx.sum(axis=1)).ravel()
    # Substitute 1 for zero sums so empty rows stay zero rather than dividing by zero.
    row_sums = np.where(row_sums == 0, 1, row_sums)
    inv_row_sums = np.power(row_sums, -1)
    return sp.diags(inv_row_sums).dot(mx)


# Propagation matrix D^{-1}(A + I): self-looped adjacency with every row summing to one.
lap_matrix = row_normalize(mx=lap_matrix)

# ------------------- Sampler -------------------
def sparse_mx_to_torch_sparse_tensor(sparse_mx):
    """Convert a SciPy sparse matrix into a float32 torch COO sparse tensor on ``device``.

    Args:
        sparse_mx: Any SciPy sparse matrix; it is converted to COO first.

    Returns:
        ``torch.sparse_coo_tensor`` with the same shape and stored entries
        (no explicit ``coalesce()`` is performed).
    """
    coo = sparse_mx.tocoo().astype(np.float32)
    coords = np.stack([coo.row, coo.col]).astype(np.int64)  # (2, nnz) index array
    return torch.sparse_coo_tensor(
        torch.from_numpy(coords),
        torch.from_numpy(coo.data),
        torch.Size(coo.shape),
        device=device,
    )

def ladies_sampler(seed, batch_nodes, samp_num_list, num_nodes, lap_matrix, depth):
    """Sample a LADIES (layer-dependent importance sampling) computation graph.

    Starting from ``batch_nodes`` and moving ``depth`` layers outwards, each layer
    samples nodes with probability proportional to the squared column norms of
    ``lap_matrix`` restricted to the rows of the previous layer, then builds the
    1/p-reweighted, row-normalised bipartite adjacency between the two layers.

    Args:
        seed: Seed for this call's private NumPy generator. The notebook-wide
            NumPy RNG is NOT reseeded.
        batch_nodes: 1-D tensor of output node ids; they are added to every sampled layer.
        samp_num_list: Per-layer sample budget; entry ``d`` is used for layer ``d``.
        num_nodes: Size of the sampling population (total number of nodes).
        lap_matrix: SciPy sparse propagation matrix supporting row/column fancy indexing.
        depth: Number of layers to sample.

    Returns:
        tuple: ``(adjs, input_nodes, output_nodes)``.
        ``adjs`` is a list of ``depth`` torch sparse tensors ordered from the
        outermost (input) layer to the batch layer. ``input_nodes`` is a tensor of
        the outermost layer's node ids on ``device``. ``output_nodes`` is
        ``batch_nodes`` unchanged.
    """
    # A private generator keeps the sampler from clobbering the global NumPy RNG;
    # RandomState(seed).choice yields the same draws as np.random.seed + np.random.choice.
    rng = np.random.RandomState(seed)
    batch_np = batch_nodes.cpu().numpy()  # device->host copy done once, reused per layer
    previous_nodes = batch_np
    adjs = []
    for d in range(depth):
        U = lap_matrix[previous_nodes, :]
        # Squared column norms; asarray/ravel works for both spmatrix and sparray sums.
        pi = np.asarray(U.multiply(U).sum(axis=0)).ravel()
        p = pi / pi.sum()
        # Sampling without replacement cannot draw more nodes than have p > 0.
        s_num = min(int(np.count_nonzero(p)), samp_num_list[d])
        after_nodes = rng.choice(num_nodes, s_num, p=p, replace=False)
        after_nodes = np.unique(np.concatenate((after_nodes, batch_np)))
        # Importance reweighting by 1/p, then row normalisation.
        # NOTE(review): assumes every batch node has p > 0 (holds when lap_matrix has self-loops).
        adj = U[:, after_nodes].multiply(1 / p[after_nodes])
        adj = row_normalize(adj)
        adjs.append(sparse_mx_to_torch_sparse_tensor(adj))
        previous_nodes = after_nodes
    adjs.reverse()  # input layer first, batch layer last
    return adjs, torch.tensor(previous_nodes, device=device), batch_nodes

# ------------------- Model -------------------
class GCN(nn.Module):
    """Stack of Linear layers with sparse propagation between them.

    ``num_layers - 2`` hidden-to-hidden layers sit between an input layer and an
    output layer; at least two Linear layers are always built. In ``forward`` every
    layer except the last is followed by ``torch.sparse.mm(adj, .)`` and ReLU; the
    last layer is a plain Linear classifier.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, num_layers):
        super().__init__()
        layers = [nn.Linear(in_channels, hidden_channels)]
        layers += [nn.Linear(hidden_channels, hidden_channels) for _ in range(num_layers - 2)]
        layers.append(nn.Linear(hidden_channels, out_channels))
        self.convs = nn.ModuleList(layers)

    def forward(self, x, adjs):
        """Run the network.

        Args:
            x: Dense node-feature matrix whose rows match the columns of ``adjs[0]``.
            adjs: Sparse propagation matrices, one per non-final layer. ``zip``
                pairs them with the non-final layers, so any extra matrices are ignored.

        Returns:
            Logits with one row per row of the last adjacency used.
        """
        *hidden_layers, classifier = self.convs
        for layer, adj in zip(hidden_layers, adjs):
            x = F.relu(torch.sparse.mm(adj, layer(x)))
        return classifier(x)

# ------------------- Training -------------------
# The GCN has `num_layers` Linear layers; every layer except the last is followed by
# one sparse propagation, so the sampler must produce exactly `num_layers - 1`
# adjacencies. The last adjacency in `adjs` has one row per batch node (in batch
# order), so the model output rows already line up with `output_nodes`.
num_layers = 3
model = GCN(num_features, 64, num_classes, num_layers=num_layers).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()

train_nodes = torch.where(data.train_mask)[0]
valid_nodes = torch.where(data.val_mask)[0]
labels = data.y
features = data.x
batch_size = 512
depth = num_layers - 1          # one sampled adjacency per propagation step
samp_num_list = [64] * depth    # nodes sampled per layer

# Full-graph propagation matrix D^{-1}(A + I) for validation. It is constant, so
# convert it once here instead of on every evaluation.
full_adj = sparse_mx_to_torch_sparse_tensor(lap_matrix)

for epoch in range(1, 101):
    model.train()
    optimizer.zero_grad()

    idx = torch.randperm(train_nodes.size(0), device=device)[:batch_size]
    batch_nodes = train_nodes[idx]

    adjs, input_nodes, output_nodes = ladies_sampler(
        seed=np.random.randint(0, 100000),
        batch_nodes=batch_nodes,
        samp_num_list=samp_num_list,
        num_nodes=data.num_nodes,
        lap_matrix=lap_matrix,
        depth=depth,
    )

    # `out` has one row per entry of `output_nodes`, in the same order, so it must
    # not be re-indexed with global node ids.
    out = model(features[input_nodes], adjs)
    loss = criterion(out, labels[output_nodes])

    loss.backward()
    optimizer.step()

    if epoch % 10 == 0:
        model.eval()
        with torch.no_grad():
            val_out = model(features, [full_adj] * depth)
            preds = val_out[valid_nodes].argmax(dim=1)
            acc = (preds == labels[valid_nodes]).float().mean().item()
        print(f"Epoch {epoch:03d} | Loss: {loss.item():.4f} | Val Acc: {acc:.4f}")

# ------------------- GPU Usage -------------------
# Current and peak GPU memory occupied by PyTorch tensors, reported in MB (1024**2 bytes).
allocated_mb = torch.cuda.memory_allocated() / 1024**2
peak_mb = torch.cuda.max_memory_allocated() / 1024**2
print(f"GPU memory allocated: {allocated_mb:.2f} MB")
print(f"Max GPU memory used:  {peak_mb:.2f} MB")


Using device: cuda
Epoch 010 | Loss: 0.9635 | Val Acc: 0.6600
Epoch 020 | Loss: 0.4166 | Val Acc: 0.7340
Epoch 030 | Loss: 0.0463 | Val Acc: 0.7760
Epoch 040 | Loss: 0.0046 | Val Acc: 0.7720
Epoch 050 | Loss: 0.0015 | Val Acc: 0.7380
Epoch 060 | Loss: 0.0017 | Val Acc: 0.7440
Epoch 070 | Loss: 0.0004 | Val Acc: 0.7100
Epoch 080 | Loss: 0.0096 | Val Acc: 0.6940
Epoch 090 | Loss: 0.0002 | Val Acc: 0.7660
Epoch 100 | Loss: 0.0006 | Val Acc: 0.7200
GPU memory allocated: 61.74 MB
Max GPU memory used:  83.33 MB
