In [None]:
!pip install torch-geometric
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torch_geometric.utils import add_self_loops
from torch_geometric.datasets import Planetoid, Reddit
from sklearn.metrics import f1_score
import numpy as np
import os.path as osp

Collecting torch-geometric
  Downloading torch_geometric-2.7.0-py3-none-any.whl.metadata (63 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/63.7 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m[90m━[0m [32m61.4/63.7 kB[0m [31m4.4 MB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.7/63.7 kB[0m [31m1.1 MB/s[0m eta [36m0:00:00[0m
Downloading torch_geometric-2.7.0-py3-none-any.whl (1.3 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m27.4 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: torch-geometric
Successfully installed torch-geometric-2.7.0


In [None]:
# Per-dataset hyperparameters and download locations, keyed by dataset name.
# Every entry carries the same set of keys; only the hidden width, the
# sample size, the learning rate and the root directory differ between them.
DATASET_CONFIG = dict(
    Cora=dict(
        hidden_channels=32,
        sample_size=512,
        lr=0.005,
        weight_decay=5e-4,
        epochs=200,
        heads=4,
        warmup_epochs=10,
        log_interval=20,
        root='/tmp/Cora',
    ),
    Pubmed=dict(
        hidden_channels=64,
        sample_size=3000,
        lr=0.01,
        weight_decay=5e-4,
        epochs=200,
        heads=4,
        warmup_epochs=10,
        log_interval=20,
        root='/tmp/Pubmed',
    ),
)


In [None]:
def get_dataset(name):
    """Load a dataset by name and return ``(dataset, data)``.

    Args:
        name: ``'Reddit'`` selects the ``Reddit`` loader; any other value is
            passed to ``Planetoid`` as its ``name`` argument.

    Returns:
        A ``(dataset, data)`` tuple where ``data`` is ``dataset[0]``.

    Notes:
        * The storage root is taken from ``DATASET_CONFIG[name]['root']`` when
          that entry exists and falls back to ``/tmp/<name>`` otherwise. The
          previous unconditional ``DATASET_CONFIG[name]`` lookup raised
          ``KeyError`` for ``'Reddit'`` (which has no config entry) before the
          Reddit branch could ever run.
        * For Planetoid datasets the returned ``data.train_mask`` is
          overwritten so that every node outside the validation and test masks
          is a training node.
    """
    root = DATASET_CONFIG.get(name, {}).get('root', f'/tmp/{name}')

    if name == 'Reddit':
        dataset = Reddit(root=root)
        data = dataset[0]
    else:  # Cora, Pubmed, or any other name Planetoid accepts
        dataset = Planetoid(root=root, name=name)
        data = dataset[0]
        # Train on everything that is not held out for validation/testing.
        data.train_mask = ~(data.val_mask | data.test_mask)

    return dataset, data

In [None]:
class MultiHeadFastGateLayer(nn.Module):
    """Multi-head gated neighbour aggregation restricted to sampled neighbours.

    For every head, node features are linearly projected, each kept edge
    ``(v, u)`` (``v`` a target node, ``u`` a sampled node) receives a scalar
    gate, and the gated neighbour features are summed per target node and
    divided by the sum of that node's gates. The target node's own projection
    is added as a residual. Head outputs are concatenated along the feature
    dimension.

    Args:
        in_channels: Width of the input node features.
        out_channels: Requested total output width. Each head produces
            ``out_channels // heads`` features, so the actual output width is
            ``heads * (out_channels // heads)``.
        heads: Number of independent heads.
        dropout: Dropout probability applied to the learned gates.
    """

    def __init__(self, in_channels, out_channels, heads=4, dropout=0.2):
        super().__init__()
        self.heads = heads
        # Integer division: a remainder of out_channels / heads is dropped.
        self.out_per_head = out_channels // heads

        # Per-head linear projection and per-head attention vector
        # (the latter scores the concatenation [h_v, h_u]).
        self.W = nn.ModuleList(
            [nn.Linear(in_channels, self.out_per_head, bias=False) for _ in range(heads)]
        )
        self.a = nn.ModuleList(
            [nn.Linear(2 * self.out_per_head, 1, bias=False) for _ in range(heads)]
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, edge_index, sampled_nodes, target_nodes, probs, num_nodes, precomputed_alpha=None):
        """Aggregate sampled-neighbour features into the target nodes.

        Args:
            x: Node feature matrix, indexed by the node ids in ``edge_index``.
            edge_index: Pair ``(row, col)`` of node-id tensors; ``row`` holds
                the receiving node and ``col`` the contributing neighbour.
            sampled_nodes: Node ids allowed to act as neighbours
                (duplicates are harmless).
            target_nodes: Node ids whose outputs are computed.
            probs: Per-node sampling probability; learned gates are divided by
                ``probs[u]`` for neighbour ``u``.
            num_nodes: Total number of nodes (number of output rows).
            precomputed_alpha: Optional per-edge gate values aligned with
                ``edge_index``. When given, they replace the learned gates
                (no division by ``probs`` and no dropout is applied).

        Returns:
            Tensor with ``num_nodes`` rows and ``heads * out_per_head``
            columns; rows of nodes that are not in ``target_nodes`` are zero.
        """
        device = x.device

        # --- Head-independent bookkeeping (built once, not once per head) ---
        target_mask = torch.zeros(num_nodes, dtype=torch.bool, device=device)
        sample_mask = torch.zeros(num_nodes, dtype=torch.bool, device=device)
        target_mask[target_nodes] = True
        sample_mask[sampled_nodes] = True

        # Keep only edges that end in a target node and start in a sampled node.
        row, col = edge_index
        mask = target_mask[row] & sample_mask[col]
        row_m, col_m = row[mask], col[mask]

        # Map global target ids to compact positions 0..T-1. Every entry of
        # row_m is a target node by construction of `mask`, so the lookup is
        # always non-negative and needs no extra filtering.
        unique_targets = torch.unique(target_nodes, sorted=True)
        num_targets = unique_targets.size(0)
        target_map = torch.full((num_nodes,), -1, dtype=torch.long, device=device)
        target_map[unique_targets] = torch.arange(num_targets, device=device)
        local_row = target_map[row_m]

        if precomputed_alpha is None:
            fixed_gate = None
            sample_weight = probs[col_m] + 1e-12  # epsilon guards against division by zero
        else:
            fixed_gate = precomputed_alpha[mask]
            sample_weight = None

        out_heads = []
        for h in range(self.heads):
            h_x = self.W[h](x)  # projected features for this head
            h_u = h_x[col_m]

            if fixed_gate is None:
                h_v = h_x[row_m]
                # squeeze(-1) removes only the trailing singleton dimension.
                e_vu = self.a[h](torch.cat([h_v, h_u], dim=1)).squeeze(-1)
                gate = torch.sigmoid(e_vu) / sample_weight
                gate = self.dropout(gate)
            else:
                gate = fixed_gate
            # index_add_ requires matching dtypes; align the gate with the
            # projected features (a no-op when everything is float32).
            gate = gate.to(h_x.dtype)

            # Buffers follow the dtype/device of the projected features so the
            # layer also works when it is not running in float32.
            gate_sum = h_x.new_zeros(num_targets)
            gate_sum.index_add_(0, local_row, gate)

            out_local = h_x.new_zeros(num_targets, self.out_per_head)
            out_local.index_add_(0, local_row, gate.unsqueeze(1) * h_u)

            # Normalise by the total gate mass of each target node.
            out_local = out_local / (gate_sum.unsqueeze(1) + 1e-6)

            # Residual connection
            out_local = out_local + h_x[unique_targets]

            # Scatter the compact result back into a num_nodes-row matrix.
            out_global = h_x.new_zeros(num_nodes, self.out_per_head)
            out_global[unique_targets] = out_local
            out_heads.append(out_global)

        # Concatenate heads
        return torch.cat(out_heads, dim=1)


class MultiHeadFastGateModel(nn.Module):
    """Two-layer gated aggregation network operating on precomputed features.

    Pipeline: bias-free linear layer + ReLU, a multi-head
    ``MultiHeadFastGateLayer`` + ReLU, then a single-head
    ``MultiHeadFastGateLayer`` producing class scores.

    Args:
        in_channels: Width of the input features.
        hidden_channels: Width of the hidden representation.
        out_channels: Number of output classes.
        sample_size: Number of node draws per training forward pass.
        heads: Number of heads in the first aggregation layer.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, sample_size, heads=4):
        super().__init__()
        self.sample_size = sample_size
        self.heads = heads
        self.lin1 = nn.Linear(in_channels, hidden_channels, bias=False)
        self.conv1 = MultiHeadFastGateLayer(hidden_channels, hidden_channels, heads=heads)
        self.conv2 = MultiHeadFastGateLayer(hidden_channels, out_channels, heads=1)  # single head for output

    def forward(self, x_pre, edge_index, probs, target_nodes, num_nodes, precomputed_alpha=None):
        """Return log-probabilities for ``target_nodes``.

        Args:
            x_pre: Input node features (one row per node).
            edge_index: Pair ``(row, col)`` of node-id tensors.
            probs: Per-node sampling distribution used by ``torch.multinomial``
                in training mode and forwarded to both aggregation layers.
            target_nodes: 1-D tensor of node ids to classify.
            num_nodes: Total number of nodes.
            precomputed_alpha: Optional per-edge gates forwarded to both
                aggregation layers.

        Returns:
            Log-softmax class scores, one row per entry of ``target_nodes``.
        """
        h = F.relu(self.lin1(x_pre))

        # Training: draw a node subset (with replacement) to act as neighbours.
        # Evaluation: every node may act as a neighbour.
        if self.training:
            sampled = torch.multinomial(probs, self.sample_size, replacement=True)
        else:
            sampled = torch.arange(num_nodes, device=x_pre.device)

        # The aggregation layer zeroes every row that is not one of its target
        # nodes. The output layer reads hidden features of the sampled
        # neighbours as well as of the targets themselves, so the first layer
        # has to be evaluated on both sets; evaluating it on the targets alone
        # would feed all-zero features for every neighbour outside that set.
        hidden_nodes = torch.unique(torch.cat([target_nodes.view(-1), sampled]))

        h = F.relu(self.conv1(h, edge_index, sampled, hidden_nodes, probs, num_nodes, precomputed_alpha))
        h = self.conv2(h, edge_index, sampled, target_nodes, probs, num_nodes, precomputed_alpha)

        return F.log_softmax(h[target_nodes], dim=1)


In [None]:
# Use the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Cora from the Planetoid collection; its single graph is moved to `device`.
dataset = Planetoid(root='/tmp/Cora', name='Cora')
data = dataset[0].to(device)

# Add one self-loop per node; the second return value is not needed here.
edge_index, _ = add_self_loops(data.edge_index, num_nodes=data.num_nodes)
row, col = edge_index

# Per-node degree: number of edges (self-loops included) listed under `row`.
deg = torch.zeros(data.num_nodes, device=device)
deg.index_add_(0, row, deg.new_ones(row.size(0)))

# Degree-proportional sampling distribution, floored at 1e-4 and renormalised
# so that no node has a vanishing sampling probability.
probs = (deg / deg.sum()).clamp(min=1e-4)
probs = probs / probs.sum()

# One propagation step with symmetric normalisation, computed once up front:
# x_precomputed[v] = sum over edges (v, u) of x[u] / sqrt(deg[v] * deg[u]).
norm = 1.0 / (deg[row] * deg[col]).sqrt()
x_precomputed = torch.zeros_like(data.x)
x_precomputed.index_add_(0, row, norm[:, None] * data.x[col])

Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.x
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.tx
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.allx
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.y
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.ty
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.ally
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.graph
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.test.index
Processing...
Done!


tensor([[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.2236, 0.0000],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        ...,
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]])

In [None]:
# Hyperparameters for this run. Note that SAMPLE_SIZE and EPOCHS here differ
# from the 'Cora' entry of DATASET_CONFIG (512 and 200 respectively).
SEED = 42
HIDDEN_CHANNELS = 32
SAMPLE_SIZE = 256
LR = 0.005
WEIGHT_DECAY = 5e-4
EPOCHS = 300
WARMUP_EPOCHS = 10
HEADS = 4

# Weight initialisation, dropout and neighbour sampling all draw from torch's
# global RNG; seeding it before the model is built makes a
# Restart-and-Run-All reproduce the same run.
torch.manual_seed(SEED)

model = MultiHeadFastGateModel(
    in_channels=dataset.num_features,
    hidden_channels=HIDDEN_CHANNELS,
    out_channels=dataset.num_classes,
    sample_size=SAMPLE_SIZE,
    heads=HEADS
).to(device)

optimizer = torch.optim.Adam(
    model.parameters(),
    lr=LR,
    weight_decay=WEIGHT_DECAY
)

In [None]:
def _node_accuracy(node_idx):
    """Accuracy of the global ``model`` (in eval mode) on ``node_idx``, as a 0-d tensor."""
    model.eval()
    with torch.no_grad():
        log_probs = model(x_precomputed, edge_index, probs, node_idx, data.num_nodes)
        return (log_probs.argmax(dim=1) == data.y[node_idx]).float().mean()


def _warmup_gates():
    """Per-edge gates from conv1's current attention parameters, averaged over heads.

    Computed without gradient tracking, so gates obtained this way act as
    constants during the training step that uses them.
    """
    with torch.no_grad():
        hidden = F.relu(model.lin1(x_precomputed))
        per_head = []
        for head in range(model.conv1.heads):
            projected = model.conv1.W[head](hidden)  # project once, index twice
            pair = torch.cat([projected[edge_index[0]], projected[edge_index[1]]], dim=1)
            per_head.append(torch.sigmoid(model.conv1.a[head](pair).squeeze(-1)))
        return torch.stack(per_head).mean(dim=0)


def train_and_eval(epochs=EPOCHS, log_interval=10):
    """Train the global ``model`` and report accuracy on the test split.

    During the first ``WARMUP_EPOCHS`` epochs the layers are driven by fixed,
    head-averaged gates recomputed each epoch; afterwards the gates are
    learned. Previously the last warm-up gates were kept for every later
    epoch, so the attention vectors never received gradients even though
    evaluation relies on them.

    Args:
        epochs: Number of training epochs.
        log_interval: Print loss and validation/test accuracy every this many
            epochs; a falsy value disables periodic logging.

    Returns:
        Final test accuracy as a 0-d tensor.
    """
    # nonzero() already yields indices in ascending order.
    train_idx = data.train_mask.nonzero(as_tuple=False).view(-1)
    val_idx = data.val_mask.nonzero(as_tuple=False).view(-1)
    test_idx = data.test_mask.nonzero(as_tuple=False).view(-1)

    for epoch in range(1, epochs + 1):
        model.train()
        optimizer.zero_grad()

        # Fixed gates only while warming up; None switches to learned gates.
        precomputed_alpha = _warmup_gates() if epoch <= WARMUP_EPOCHS else None

        out = model(x_precomputed, edge_index, probs, train_idx, data.num_nodes, precomputed_alpha)
        loss = F.nll_loss(out, data.y[train_idx])
        loss.backward()
        optimizer.step()

        if log_interval and epoch % log_interval == 0:
            val_acc = _node_accuracy(val_idx)
            test_acc = _node_accuracy(test_idx)
            print(f"Epoch {epoch:03d} | Loss {loss.item():.4f} | Val Acc {val_acc:.4f} | Test Acc {test_acc:.4f}")

    # evaluation
    test_acc = _node_accuracy(test_idx)

    print("-" * 30)
    print(f"Final Test Accuracy: {test_acc:.4f}")
    print("-" * 30)

    return test_acc

In [None]:
# Train with the default epoch budget; as the cell's last expression, the
# returned final test accuracy is also displayed as the cell output.
train_and_eval()


Epoch 010 | Loss 1.2510 | Val Acc 0.6560 | Test Acc 0.6770
Epoch 020 | Loss 0.3624 | Val Acc 0.7440 | Test Acc 0.7580
Epoch 030 | Loss 0.1061 | Val Acc 0.7080 | Test Acc 0.7340
Epoch 040 | Loss 0.0301 | Val Acc 0.7040 | Test Acc 0.7410
Epoch 050 | Loss 0.0128 | Val Acc 0.6920 | Test Acc 0.7380
Epoch 060 | Loss 0.0097 | Val Acc 0.7140 | Test Acc 0.7430
Epoch 070 | Loss 0.1273 | Val Acc 0.6960 | Test Acc 0.7390
Epoch 080 | Loss 0.0084 | Val Acc 0.7040 | Test Acc 0.7300
Epoch 090 | Loss 0.0102 | Val Acc 0.7220 | Test Acc 0.7600
Epoch 100 | Loss 0.0167 | Val Acc 0.7000 | Test Acc 0.7400
Epoch 110 | Loss 0.0993 | Val Acc 0.7180 | Test Acc 0.7620
Epoch 120 | Loss 0.0397 | Val Acc 0.7100 | Test Acc 0.7400
Epoch 130 | Loss 0.0743 | Val Acc 0.7040 | Test Acc 0.7490
Epoch 140 | Loss 0.0523 | Val Acc 0.7000 | Test Acc 0.7390
Epoch 150 | Loss 0.1176 | Val Acc 0.7060 | Test Acc 0.7520
Epoch 160 | Loss 0.0283 | Val Acc 0.6800 | Test Acc 0.7240
Epoch 170 | Loss 0.0129 | Val Acc 0.6960 | Test Acc 0.75

tensor(0.7520)

In [None]:
#1  heads: Final Test Accuracy: 0.7510
#2  heads: Final Test Accuracy: 0.7580
#4  heads: Final Test Accuracy: 0.7340
#8  heads: Final Test Accuracy: 0.7460
#16 heads: Final Test Accuracy: 0.7500



