<a href="https://colab.research.google.com/github/ghommidhWassim/GNN-variants/blob/main/spanGNN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [11]:
!pip install -q git+https://github.com/pyg-team/pytorch_geometric.git
!python -c "import torch; print(torch.__version__)"
!python -c "import torch; print(torch.version.cuda)"
!pip install torchvision
!pip install pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-2.6.0+cu124.html


  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
2.6.0+cu124
12.4
Looking in links: https://data.pyg.org/whl/torch-2.6.0+cu124.html


In [12]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from torch_geometric.datasets import Planetoid
from torch_geometric.nn import GCNConv
from torch_geometric.utils import add_self_loops
from torch_geometric.transforms import NormalizeFeatures
from torch_geometric.utils import subgraph

In [13]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def dataset_load():
  print(f"Using device: {device}")
  dataset = Planetoid(root='data/Planetoid', name='PubMed', transform=NormalizeFeatures())
  num_features = dataset.num_features
  num_classes = dataset.num_classes
  data = dataset[0].to(device)  # Get the first graph object.
  return num_features, data, num_classes, device,dataset

In [14]:
num_features, data, num_classes, device, dataset = dataset_load()

print(f"GPU memory allocated: {torch.cuda.memory_allocated()/1024**2:.2f} MB")
print(f"Max GPU memory used:  {torch.cuda.max_memory_allocated()/1024**2:.2f} MB")

Using device: cuda
GPU memory allocated: 95.69 MB
Max GPU memory used:  591.23 MB


In [5]:

class GCN(nn.Module):
    def __init__(self, in_size, hid_size, out_size):
        super().__init__()
        self.conv1 = GCNConv(in_size, hid_size)
        self.conv2 = GCNConv(hid_size, out_size)
        self.dropout = nn.Dropout(0.5)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = self.dropout(x)
        x = self.conv2(x, edge_index)
        return x


def drop_edge(edge_index, drop_prob):
    num_edges = edge_index.size(1)
    mask = torch.rand(num_edges, device=edge_index.device) > drop_prob
    return edge_index[:, mask]


def compute_gradient_prob(model, x, edge_index, labels, train_mask):
    model.eval()
    x = x.clone().detach().requires_grad_(True)

    logits = model(x, edge_index)
    loss = F.cross_entropy(logits[train_mask], labels[train_mask])
    model.zero_grad()
    loss.backward()

    grad = x.grad  # [num_nodes, num_features]

    src, dst = edge_index
    grad_diff = (grad[src] - grad[dst]).norm(p=2, dim=1)  # edge importance score

    prob = grad_diff / grad_diff.sum()
    return prob.detach()


def span_edge_sampling(edge_index, edge_sample_ratio, prob=None):
    num_edges = edge_index.size(1)
    sample_size = max(int(num_edges * edge_sample_ratio), 1)
    if prob is not None:
        sampled_indices = torch.multinomial(prob, sample_size, replacement=False)
    else:
        sampled_indices = torch.randperm(num_edges)[:sample_size]
    return edge_index[:, sampled_indices]


def train_one_epoch(data, model, optimizer, drop_prob, sample_ratio):
    model.train()
    x, edge_index = data.x, data.edge_index
    labels = data.y
    train_mask = data.train_mask

    # 1) DropEdge
    edge_index_dropped = drop_edge(edge_index, drop_prob)

    # 2) Compute gradient-based sampling probabilities on dropped graph
    prob = compute_gradient_prob(model, x, edge_index_dropped, labels, train_mask)

    # 3) Sample edges according to probabilities
    edge_index_sampled = span_edge_sampling(edge_index_dropped, sample_ratio, prob)

    # 4) Train on sampled subgraph
    optimizer.zero_grad()
    out = model(x, edge_index_sampled)
    loss = F.cross_entropy(out[train_mask], labels[train_mask])
    loss.backward()
    optimizer.step()

    return loss.item()


@torch.no_grad()
def evaluate(data, model):
    model.eval()
    x, edge_index = data.x, data.edge_index
    labels = data.y
    val_mask = data.val_mask

    out = model(x, edge_index)
    pred = out[val_mask].argmax(dim=1)
    acc = (pred == labels[val_mask]).float().mean().item()
    return acc


# Initialize model and optimizer
model = GCN(num_features, 16, num_classes).to(device)
optimizer = optim.Adam(model.parameters(), lr=0.01)

num_epochs = 100
drop_prob = 0.1
sample_ratio = 0.5

for epoch in range(1, num_epochs + 1):
    loss = train_one_epoch(data, model, optimizer, drop_prob, sample_ratio)
    val_acc = evaluate(data, model)
    print(f"Epoch {epoch:03d} | Loss: {loss:.4f} | Val Acc: {val_acc:.4f}")

Epoch 001 | Loss: 1.0991 | Val Acc: 0.5520
Epoch 002 | Loss: 1.0926 | Val Acc: 0.5140
Epoch 003 | Loss: 1.0860 | Val Acc: 0.5660
Epoch 004 | Loss: 1.0794 | Val Acc: 0.6300
Epoch 005 | Loss: 1.0709 | Val Acc: 0.6940
Epoch 006 | Loss: 1.0616 | Val Acc: 0.7160
Epoch 007 | Loss: 1.0524 | Val Acc: 0.6940
Epoch 008 | Loss: 1.0440 | Val Acc: 0.7100
Epoch 009 | Loss: 1.0324 | Val Acc: 0.7160
Epoch 010 | Loss: 1.0227 | Val Acc: 0.7260
Epoch 011 | Loss: 1.0101 | Val Acc: 0.7400
Epoch 012 | Loss: 0.9931 | Val Acc: 0.7280
Epoch 013 | Loss: 0.9839 | Val Acc: 0.7220
Epoch 014 | Loss: 0.9705 | Val Acc: 0.7280
Epoch 015 | Loss: 0.9502 | Val Acc: 0.7300
Epoch 016 | Loss: 0.9355 | Val Acc: 0.7340
Epoch 017 | Loss: 0.9219 | Val Acc: 0.7260
Epoch 018 | Loss: 0.9038 | Val Acc: 0.7380
Epoch 019 | Loss: 0.8808 | Val Acc: 0.7400
Epoch 020 | Loss: 0.8686 | Val Acc: 0.7400
Epoch 021 | Loss: 0.8582 | Val Acc: 0.7460
Epoch 022 | Loss: 0.8197 | Val Acc: 0.7440
Epoch 023 | Loss: 0.8100 | Val Acc: 0.7480
Epoch 024 |

In [6]:
print(f"GPU memory allocated: {torch.cuda.memory_allocated()/1024**2:.2f} MB")
print(f"Max GPU memory used:  {torch.cuda.max_memory_allocated()/1024**2:.2f} MB")

GPU memory allocated: 55.94 MB
Max GPU memory used:  591.23 MB


In [17]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.data import Data

class GCN(nn.Module):
    def __init__(self, in_size, hid_size, out_size):
        super().__init__()
        self.conv1 = GCNConv(in_size, hid_size)
        self.conv2 = GCNConv(hid_size, out_size)
        self.dropout = nn.Dropout(0.5)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = self.dropout(x)
        x = self.conv2(x, edge_index)
        return x

class SpanGNN:
    def __init__(self, model, edge_ratio=0.3, strategy='vm'):
        self.model = model
        self.edge_ratio = edge_ratio  # α_up
        self.strategy = strategy  # 'vm' or 'gnr'
        self.current_edges = None

    def _compute_degrees(self, edge_index, num_nodes):
        """Compute node degrees with dtype matching edge_index."""
        return torch.zeros(num_nodes, dtype=edge_index.dtype, device=edge_index.device).scatter_add_(
            0, edge_index[0], torch.ones_like(edge_index[0])
        )

    def _variance_minimized_sampling(self, edge_index, degrees):
        """Variance-minimized sampling (p_e ∝ 1/deg(u) + 1/deg(v))."""
        src, dst = edge_index
        prob = 1.0 / (degrees[src] + degrees[dst] + 1e-10)  # Avoid division by zero
        return prob / prob.sum()

    def _gradient_noise_reduced_sampling(self, model, x, edge_index, labels, train_mask):
        """Gradient-noise-reduced sampling (p_e ∝ ‖∇W L‖)."""
        x = x.clone().detach().requires_grad_(True)
        model.eval()
        logits = model(x, edge_index)
        loss = F.cross_entropy(logits[train_mask], labels[train_mask])
        loss.backward()
        grad_norm = x.grad.norm(p=2, dim=1)[edge_index[0]]  # Use source node gradients
        prob = grad_norm / (grad_norm.sum() + 1e-10)
        return prob.detach()

    def update_subgraph(self, data, edge_index_full):
        """Update subgraph with quality-aware edges."""
        num_edges_to_add = int(self.edge_ratio * edge_index_full.size(1))

        # Step 1: Random candidate pool
        candidate_idx = torch.randperm(edge_index_full.size(1))[:min(10_000, edge_index_full.size(1))]
        edge_candidates = edge_index_full[:, candidate_idx]

        # Step 2: Importance sampling
        if self.strategy == 'vm':
            degrees = self._compute_degrees(edge_index_full, data.num_nodes)
            prob = self._variance_minimized_sampling(edge_candidates, degrees)
        else:
            prob = self._gradient_noise_reduced_sampling(
                self.model, data.x, edge_candidates, data.y, data.train_mask
            )

        # Sample edges
        selected = torch.multinomial(prob, min(num_edges_to_add, prob.size(0)), replacement=False)
        new_edges = edge_candidates[:, selected]

        # Merge with existing edges
        if self.current_edges is None:
            self.current_edges = new_edges
        else:
            self.current_edges = torch.unique(
                torch.cat([self.current_edges, new_edges], dim=1),
                dim=1
            )

        # Enforce edge budget
        if self.current_edges.size(1) > num_edges_to_add:
            keep = torch.randperm(self.current_edges.size(1))[:num_edges_to_add]
            self.current_edges = self.current_edges[:, keep]

        return self.current_edges

    def train(self, data, epochs=100):
        optimizer = torch.optim.Adam(self.model.parameters(), lr=0.01)

        for epoch in range(epochs):
            # Update subgraph
            edge_index = self.update_subgraph(data, data.edge_index)

            # Train step
            self.model.train()
            optimizer.zero_grad()
            out = self.model(data.x, edge_index)
            loss = F.cross_entropy(out[data.train_mask], data.y[data.train_mask])
            loss.backward()
            optimizer.step()

            # Validation
            val_acc = self.evaluate(data)
            print(f"Epoch {epoch:03d} | Loss: {loss:.4f} | Val Acc: {val_acc:.4f}")

    @torch.no_grad()
    def evaluate(self, data):
        self.model.eval()
        out = self.model(data.x, self.current_edges)
        pred = out[data.val_mask].argmax(dim=1)
        return (pred == data.y[data.val_mask]).float().mean().item()

In [18]:
model = GCN(data.num_features, 16, dataset.num_classes).to(device)
span_gnn = SpanGNN(model, edge_ratio=0.3, strategy='vm')  # or 'gnr'

# Train
span_gnn.train(data, epochs=100)

Epoch 000 | Loss: 1.0974 | Val Acc: 0.4020
Epoch 001 | Loss: 1.0884 | Val Acc: 0.5580
Epoch 002 | Loss: 1.0820 | Val Acc: 0.5780
Epoch 003 | Loss: 1.0664 | Val Acc: 0.6020
Epoch 004 | Loss: 1.0653 | Val Acc: 0.6080
Epoch 005 | Loss: 1.0468 | Val Acc: 0.6240
Epoch 006 | Loss: 1.0381 | Val Acc: 0.6420
Epoch 007 | Loss: 1.0203 | Val Acc: 0.6720
Epoch 008 | Loss: 1.0215 | Val Acc: 0.6840
Epoch 009 | Loss: 1.0087 | Val Acc: 0.7000
Epoch 010 | Loss: 0.9788 | Val Acc: 0.7240
Epoch 011 | Loss: 0.9671 | Val Acc: 0.7120
Epoch 012 | Loss: 0.9544 | Val Acc: 0.7160
Epoch 013 | Loss: 0.9169 | Val Acc: 0.7020
Epoch 014 | Loss: 0.9084 | Val Acc: 0.7120
Epoch 015 | Loss: 0.9141 | Val Acc: 0.7120
Epoch 016 | Loss: 0.8653 | Val Acc: 0.7280
Epoch 017 | Loss: 0.8542 | Val Acc: 0.7420
Epoch 018 | Loss: 0.8698 | Val Acc: 0.7380
Epoch 019 | Loss: 0.8282 | Val Acc: 0.7360
Epoch 020 | Loss: 0.8045 | Val Acc: 0.7380
Epoch 021 | Loss: 0.7740 | Val Acc: 0.7240
Epoch 022 | Loss: 0.7577 | Val Acc: 0.7260
Epoch 023 |

In [19]:
print(f"GPU memory allocated: {torch.cuda.memory_allocated()/1024**2:.2f} MB")
print(f"Max GPU memory used:  {torch.cuda.max_memory_allocated()/1024**2:.2f} MB")

GPU memory allocated: 96.34 MB
Max GPU memory used:  591.23 MB
