# Notebook 4: Advanced GNNs — Graph Transformer, Heterogeneous Graph Transformer, R-GCN

This notebook covers advanced GNN architectures that push beyond standard message passing:

| Model | Paper | Key Idea |
|-------|-------|----------|
| **Graph Transformer (GT)** | Dwivedi & Bresson (2021) | Self-attention restricted to graph neighbourhoods, plus Laplacian positional encodings |
| **Graphormer** | Ying et al. (2021) | Global self-attention with centrality, spatial and edge encodings |
| **Heterogeneous Graph Transformer (HGT)** | Hu et al. (2020) | Relation-aware attention for heterogeneous graphs |
| **R-GCN** | Schlichtkrull et al. (2018) | Relation-specific weights for relational graphs |

**Contents**
1. [Graph Transformer (GT)](#1-graph-transformer)
2. [Graphormer](#2-graphormer)
3. [Heterogeneous Graph Transformer (HGT)](#3-heterogeneous-graph-transformer-hgt)
4. [Relational GCN (R-GCN)](#4-relational-gcn-r-gcn)
5. [Comparison Experiment](#5-comparison-experiment)
6. [Exercises](#6-exercises)

In [None]:
# Uncomment to install
# !pip install torch torch_geometric

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import math
import matplotlib.pyplot as plt

torch.manual_seed(42)
print(f'PyTorch version: {torch.__version__}')

try:
    import torch_geometric
    from torch_geometric.datasets import Planetoid, DBLP
    import torch_geometric.transforms as T
    print(f'PyG version: {torch_geometric.__version__}')
    pyg_available = True
except ImportError:
    print('PyTorch Geometric not installed')
    pyg_available = False

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

---
## 1. Graph Transformer

### 1.1 Motivation: Limitations of Standard GNNs

Standard message-passing GNNs have a few known limitations:

| Limitation | Description |
|------------|-------------|
| **Over-smoothing** | Deep GNNs converge to indistinguishable representations |
| **Over-squashing** | Information from distant nodes is compressed into fixed-size vectors |
| **Limited expressiveness** | Bounded by the 1-Weisfeiler-Leman (1-WL) test |
| **Local neighbourhood only** | Long-range dependencies require many layers |

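**Transformers** (Vaswani et al., 2017) use **global self-attention**, in which every node can attend to every other node in a single layer. This removes the need to stack many layers to reach distant nodes, which directly targets long-range dependencies and over-squashing. On its own, however, self-attention ignores the graph structure, so graph Transformers add structural biases or positional encodings.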

### 1.2 Vanilla Self-Attention on Graphs

The simplest approach is to apply **standard multi-head attention (MHA)** treating nodes as tokens:

$$\text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right)V$$

where $Q = X W_Q$, $K = X W_K$, $V = X W_V$, and $X$ is the node feature matrix.

Without any structural bias or positional encoding, this treats the graph as fully connected: the output is permutation-equivariant and does not depend on the edges at all.
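A quick check of this claim, using PyTorch's built-in `nn.MultiheadAttention` on a toy 6-node path graph (the sizes, seed and edges are arbitrary assumptions). Without a mask, permuting the nodes only permutes the outputs, so edges play no role. Passing the complement of the adjacency matrix (with self-loops) as a boolean `attn_mask` restricts attention to neighbourhoods, which is the idea behind the Graph Transformer layer below.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
N, d, H = 6, 16, 4
x = torch.randn(1, N, d)                      # a batch of one "sequence" of N node tokens
mha = nn.MultiheadAttention(d, H, batch_first=True).eval()

# 1) Vanilla attention: Q = K = V = node features, no graph information at all
out, _ = mha(x, x, x)
perm = torch.randperm(N)
out_perm, _ = mha(x[:, perm], x[:, perm], x[:, perm])
print(torch.allclose(out[:, perm], out_perm, atol=1e-5))  # True: permutation-equivariant

# 2) Masked attention: allow only neighbours (plus self-loops) on a path graph
adj = torch.eye(N, dtype=torch.bool)          # self-loops keep every row non-empty
for i in range(N - 1):
    adj[i, i + 1] = adj[i + 1, i] = True
_, w = mha(x, x, x, attn_mask=~adj)           # in a bool mask, True means "not allowed"
print(bool((w[0][~adj] == 0).all()))          # True: no weight outside the neighbourhood
```

Rows that are fully masked would produce NaNs, which is why the sketch keeps self-loops in the mask.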

### 1.3 Graph Transformer (Dwivedi & Bresson, 2021)

The **Graph Transformer** layer incorporates graph structure by:
1. **Masking attention** to only attend within the neighbourhood (local attention)
2. Using **Laplacian positional encodings (LPE)** to inject structural information

#### Laplacian Positional Encoding

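Let $\tilde{L} = I - D^{-1/2} A D^{-1/2}$ be the normalised graph Laplacian, with eigenvectors $\mathbf{u}_1, \mathbf{u}_2, \ldots$ sorted by increasing eigenvalue. Skipping the trivial eigenvector $\mathbf{u}_1$ (eigenvalue 0), the entries of the next $k$ eigenvectors at node $v$ are appended to its features:
$$\mathbf{x}_v \leftarrow \mathbf{x}_v \,\|\, [u_2(v), u_3(v), \ldots, u_{k+1}(v)]$$

Eigenvectors are only defined up to sign, so Dwivedi & Bresson randomly flip their signs during training. (In the paper the $k$-dimensional encoding is linearly projected and added to the node embedding; concatenation is a simple alternative.)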

This gives nodes a position-like signal that reflects the graph topology.
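A minimal sketch of the encoding on a toy 5-node path graph, assuming a small undirected graph given as a dense 0/1 adjacency matrix (`laplacian_pe_dense` is a hypothetical helper). A dense `torch.linalg.eigh` is only practical for small graphs; larger graphs need a sparse eigensolver.

```python
import torch

def laplacian_pe_dense(adj: torch.Tensor, k: int):
    """Laplacian PE for a small undirected graph (dense 0/1 adjacency).
    Returns (eigenvalues, pe) where pe[v] = [u_2(v), ..., u_{k+1}(v)]."""
    deg = adj.sum(dim=1)
    d_inv_sqrt = deg.clamp(min=1).pow(-0.5)            # clamp guards isolated nodes
    lap = torch.eye(adj.size(0)) - d_inv_sqrt[:, None] * adj * d_inv_sqrt[None, :]
    eigvals, eigvecs = torch.linalg.eigh(lap)           # eigenvalues in ascending order
    return eigvals, eigvecs[:, 1:k + 1]                 # drop the trivial eigenvector u_1

# Toy 5-node path graph 0-1-2-3-4
adj = torch.zeros(5, 5)
for i in range(4):
    adj[i, i + 1] = adj[i + 1, i] = 1.0

eigvals, pe = laplacian_pe_dense(adj, k=2)
sign = torch.randint(0, 2, (pe.size(1),)) * 2 - 1       # random +/-1 per eigenvector (training-time)
x = torch.randn(5, 3)
x_aug = torch.cat([x, pe * sign], dim=-1)               # [5, 3 + k]
print(eigvals[0].item(), x_aug.shape)
```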

#### GT Layer Formula

$$\hat{h}_i = \text{MHA}(h_i, \{h_j : j \in \mathcal{N}(i)\})$$

$$h_i^{\prime} = \text{LayerNorm}(h_i + \hat{h}_i)$$

$$h_i^{\prime\prime} = \text{LayerNorm}(h_i^{\prime} + \text{FFN}(h_i^{\prime}))$$

### 1.4 Graph Transformer Implementation

In [None]:
class GraphTransformerLayer(nn.Module):
    """
    Graph Transformer layer (Dwivedi & Bresson, 2021).
    Applies multi-head attention with a sparse mask (edge mask)
    so only neighbours can attend to each other.
    """

    def __init__(self, d_model, num_heads, dropout=0.1, use_edge_bias=True):
        super().__init__()
        assert d_model % num_heads == 0
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        self.W_Q = nn.Linear(d_model, d_model)
        self.W_K = nn.Linear(d_model, d_model)
        self.W_V = nn.Linear(d_model, d_model)
        self.W_O = nn.Linear(d_model, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, 4 * d_model),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(4 * d_model, d_model),
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, edge_index):
        """
        x:          [N, d_model]
        edge_index: [2, E]  (directed COO format)
        Returns: [N, d_model]
        """
        N = x.size(0)
        residual = x

        Q = self.W_Q(x).view(N, self.num_heads, self.d_k)   # [N, H, d_k]
        K = self.W_K(x).view(N, self.num_heads, self.d_k)
        V = self.W_V(x).view(N, self.num_heads, self.d_k)

        src, dst = edge_index   # src -> dst

        # Attention score per edge: Q[dst] · K[src] / sqrt(d_k)
        q_dst = Q[dst]           # [E, H, d_k]
        k_src = K[src]           # [E, H, d_k]
        attn  = (q_dst * k_src).sum(-1) / math.sqrt(self.d_k)  # [E, H]

        # Softmax over in-edges for each dst node (manual scatter-softmax)
        attn_max = torch.full((N, self.num_heads), float('-inf'), device=x.device)
        attn_max.scatter_reduce_(0,
                                 dst.unsqueeze(-1).expand_as(attn),
                                 attn, reduce='amax', include_self=True)
        attn_exp = (attn - attn_max[dst]).exp()
        attn_sum = torch.zeros(N, self.num_heads, device=x.device)
        attn_sum.scatter_add_(0, dst.unsqueeze(-1).expand_as(attn_exp), attn_exp)
        alpha = attn_exp / (attn_sum[dst] + 1e-16)    # [E, H]
        alpha = self.dropout(alpha)

        # Weighted sum of values
        v_src = V[src]    # [E, H, d_k]
        agg = torch.zeros(N, self.num_heads, self.d_k, device=x.device)
        agg.scatter_add_(0,
                         dst.unsqueeze(-1).unsqueeze(-1).expand_as(v_src),
                         alpha.unsqueeze(-1) * v_src)

        out = agg.view(N, self.d_model)    # [N, d_model]
        out = self.W_O(out)

        # Post-LN residuals (LayerNorm after each residual add), as in the formulas above
        x = self.norm1(residual + self.dropout(out))
        x = self.norm2(x + self.dropout(self.ffn(x)))
        return x


print('GraphTransformerLayer defined')

In [None]:
class GraphTransformer(nn.Module):
    """Graph Transformer for node classification."""

    def __init__(self, in_ch, d_model, num_heads, num_layers, out_ch, dropout=0.1):
        super().__init__()
        self.input_proj = nn.Linear(in_ch, d_model)
        self.layers = nn.ModuleList([
            GraphTransformerLayer(d_model, num_heads, dropout)
            for _ in range(num_layers)
        ])
        self.classifier = nn.Linear(d_model, out_ch)

    def forward(self, x, edge_index):
        x = self.input_proj(x)
        for layer in self.layers:
            x = layer(x, edge_index)
        return F.log_softmax(self.classifier(x), dim=1)


# Quick sanity check
N, F_in, d, H, L, C = 10, 16, 32, 4, 2, 5
x_test = torch.randn(N, F_in)
ei_test = torch.randint(0, N, (2, 20))
gt_model = GraphTransformer(F_in, d, H, L, C)
out_test = gt_model(x_test, ei_test)
print('GraphTransformer output shape:', out_test.shape)  # [N, C]
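The manual scatter-softmax in `GraphTransformerLayer` should match a dense softmax in which every non-edge is masked with $-\infty$. The standalone sketch below checks this on a toy edge list (values are arbitrary; `scatter_softmax` is a hypothetical helper that mirrors the layer's logic). The edge list has no duplicate `(src, dst)` pairs, since the dense reference cannot represent them.

```python
import torch

def scatter_softmax(scores, dst, num_nodes):
    """Softmax of per-edge scores [E, H] over the in-edges of each destination node."""
    idx = dst.unsqueeze(-1).expand_as(scores)
    smax = torch.full((num_nodes, scores.size(1)), float('-inf')).scatter_reduce(
        0, idx, scores, reduce='amax', include_self=True)
    exp = (scores - smax[dst]).exp()                   # subtract the per-node max for stability
    denom = torch.zeros(num_nodes, scores.size(1)).scatter_add(0, idx, exp)
    return exp / denom[dst]

torch.manual_seed(0)
N, H = 5, 2
src = torch.tensor([0, 1, 2, 3, 4, 0, 2])
dst = torch.tensor([1, 1, 1, 2, 2, 3, 4])
scores = torch.randn(src.numel(), H)

alpha_sparse = scatter_softmax(scores, dst, N)

# Dense reference: -inf everywhere except on edges, softmax over the source axis
dense = torch.full((N, N, H), float('-inf'))
dense[dst, src] = scores
alpha_dense = dense.softmax(dim=1)[dst, src]
print(torch.allclose(alpha_sparse, alpha_dense, atol=1e-6))  # True
```

Node 0 has no in-edges, so its dense row is all $-\infty$; that row is never read because only edge positions are compared.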

### 1.5 Graph Transformer with PyG's `TransformerConv`

In [None]:
if pyg_available:
    from torch_geometric.nn import TransformerConv

    class PyGGraphTransformer(nn.Module):
        def __init__(self, in_ch, hidden_ch, out_ch, heads=4, num_layers=2, dropout=0.1):
            super().__init__()
            self.convs = nn.ModuleList()
            self.norms = nn.ModuleList()

            in_c = in_ch
            for i in range(num_layers):
                is_last = (i == num_layers - 1)
                out_c = out_ch if is_last else hidden_ch
                concat = not is_last  # concatenate heads on hidden layers, average them on the last
                self.convs.append(
                    TransformerConv(in_c, out_c, heads=heads,
                                    dropout=dropout, concat=concat)
                )
                in_c = out_c * heads if concat else out_c
                if not is_last:
                    # LayerNorm on hidden representations only; the last layer outputs logits
                    self.norms.append(nn.LayerNorm(in_c))

            self.dropout = dropout

        def forward(self, x, edge_index):
            for i, conv in enumerate(self.convs):
                x = conv(x, edge_index)
                if i < len(self.convs) - 1:
                    x = self.norms[i](x)
                    x = F.elu(x)
                    x = F.dropout(x, p=self.dropout, training=self.training)
            return F.log_softmax(x, dim=1)

    print('PyGGraphTransformer defined')

---
## 2. Graphormer

### 2.1 Theory

**Graphormer** (Ying et al., 2021) achieves strong results on molecular property prediction (e.g. PCQM4M from OGB-LSC). It augments the standard Transformer with **graph-specific structural encodings**:

#### 2.1.1 Centrality Encoding

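Degree information is injected by adding learnable degree embeddings to the input node features:
$$\mathbf{h}_i^{(0)} = \mathbf{x}_i + z^-_{\deg^-(v_i)} + z^+_{\deg^+(v_i)}$$

where $z^-$ and $z^+$ are learnable embedding vectors indexed by in-degree and out-degree, respectively (for undirected graphs a single degree embedding is used).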

#### 2.1.2 Spatial Encoding

A learnable scalar bias $b_{\phi(v_i, v_j)}$ is added to the attention score based on the **shortest path distance** $\phi(v_i, v_j)$:

$$A_{ij} = \frac{(Q_i)(K_j)^\top}{\sqrt{d}} + b_{\phi(v_i, v_j)}$$

This allows nodes at different graph distances to have different attention biases.
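The spatial encoding needs $\phi(v_i, v_j)$ for every node pair. For unweighted graphs, a breadth-first search from each node is one simple way to compute it; it costs $O(N(N+E))$, so this sketch is only meant for small graphs. `shortest_path_matrix` is a hypothetical helper; it returns an `[N, N]` LongTensor with `-1` for unreachable pairs, the convention the `dist_matrix` argument of `GraphormerLayer` below expects.

```python
from collections import deque
import torch

def shortest_path_matrix(edge_index: torch.Tensor, num_nodes: int) -> torch.Tensor:
    """All-pairs hop distances by BFS from every node; -1 marks unreachable pairs.
    Edges are treated as directed (src -> dst): include both directions for undirected graphs."""
    neighbours = [[] for _ in range(num_nodes)]
    for s, t in edge_index.t().tolist():
        neighbours[s].append(t)

    dist = [[-1] * num_nodes for _ in range(num_nodes)]
    for source in range(num_nodes):
        dist[source][source] = 0
        queue = deque([source])
        while queue:
            u = queue.popleft()
            for v in neighbours[u]:
                if dist[source][v] < 0:                # first visit = shortest hop count
                    dist[source][v] = dist[source][u] + 1
                    queue.append(v)
    return torch.tensor(dist, dtype=torch.long)

# Undirected path 0-1-2-3 plus an isolated node 4
ei = torch.tensor([[0, 1, 1, 2, 2, 3],
                   [1, 0, 2, 1, 3, 2]])
D = shortest_path_matrix(ei, num_nodes=5)
print(D)
```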

#### 2.1.3 Edge Encoding

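For each pair $(v_i, v_j)$, take a shortest path $(e_1, e_2, \ldots, e_{N_{ij}})$ between them. The dot products of each edge's features $\mathbf{x}_{e_n}$ with a learnable weight vector $\mathbf{w}_n$ for position $n$ on the path are averaged, and the result $c_{ij}$ is added to the attention score $A_{ij}$ as a further bias:
$$c_{ij} = \frac{1}{N_{ij}} \sum_{n=1}^{N_{ij}} \mathbf{x}_{e_n}^\top \mathbf{w}_n$$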
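A minimal sketch of $c_{ij}$ for one node pair and one attention head, assuming the edge features along the shortest path are already gathered in order (`edge_encoding_bias` is a hypothetical helper; in a model, `weight_table` would be a learnable parameter, here it is a fixed tensor). Returning 0 when the path is empty (self-pairs or unreachable pairs) is an assumption of this sketch.

```python
import torch

def edge_encoding_bias(path_edge_feats: torch.Tensor, weight_table: torch.Tensor) -> torch.Tensor:
    """c_ij for one node pair and one head.
    path_edge_feats: [N_ij, d_e] features of the edges along the shortest path, in order
    weight_table:    [max_len, d_e] vectors, row n-1 holds w_n for path position n
    """
    n = path_edge_feats.size(0)
    if n == 0:                                   # i == j or no path: nothing to average
        return path_edge_feats.new_zeros(())
    dots = (path_edge_feats * weight_table[:n]).sum(dim=-1)  # x_{e_n}^T w_n for each n
    return dots.mean()

x_path = torch.tensor([[1.0, 0.0],
                       [0.0, 2.0]])              # a 2-edge path, d_e = 2
w = torch.tensor([[3.0, 1.0],
                  [1.0, 5.0],
                  [0.0, 0.0]])                   # max_len = 3
print(edge_encoding_bias(x_path, w))             # (3 + 10) / 2 = 6.5
```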

#### 2.1.4 Virtual Node (VNode)

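A **virtual node** ([VNode]) is added and connected to every real node. It is updated like an ordinary node, so it aggregates information from the whole graph and its final state can serve as a graph-level representation. In Graphormer, node pairs involving the VNode receive their own learnable spatial-encoding bias.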
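For edge-based layers such as `GraphTransformerLayer` or `TransformerConv`, one way to add a virtual node is to append an extra node index and connect it to every real node in both directions. This is a sketch under assumptions: `add_virtual_node` is a hypothetical helper, and the zero initial features stand in for what would normally be a learnable vector.

```python
import torch

def add_virtual_node(x: torch.Tensor, edge_index: torch.Tensor):
    """Append a virtual node (index N) linked to every real node in both directions.
    Its initial features are zeros here; a model would typically use a learnable vector."""
    N = x.size(0)
    real = torch.arange(N, device=edge_index.device)
    virtual = torch.full((N,), N, dtype=torch.long, device=edge_index.device)
    new_edges = torch.cat([torch.stack([real, virtual]),        # real -> virtual
                           torch.stack([virtual, real])], dim=1)  # virtual -> real
    x_aug = torch.cat([x, x.new_zeros(1, x.size(1))], dim=0)
    return x_aug, torch.cat([edge_index, new_edges], dim=1)

x = torch.randn(4, 8)
ei = torch.tensor([[0, 1, 2], [1, 2, 3]])
x_aug, ei_aug = add_virtual_node(x, ei)
print(x_aug.shape, ei_aug.shape)   # torch.Size([5, 8]) torch.Size([2, 11])
```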

#### Summary of Graphormer Contributions

| Component | Role |
|-----------|------|
| Centrality encoding | Node importance (degree) |
| Spatial encoding | Graph distance between nodes |
| Edge encoding | Edge feature integration |
| Virtual node | Global graph representation |

In [None]:
class GraphormerLayer(nn.Module):
    """
    Simplified Graphormer layer demonstrating spatial encoding bias.
    (Full Graphormer also includes edge encoding; omitted here for clarity.)
    """

    def __init__(self, d_model, num_heads, max_dist=5, dropout=0.1):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        self.W_Q = nn.Linear(d_model, d_model)
        self.W_K = nn.Linear(d_model, d_model)
        self.W_V = nn.Linear(d_model, d_model)
        self.W_O = nn.Linear(d_model, d_model)

        # Spatial bias: one learnable scalar per (distance bucket, head).
        # Buckets 0..max_dist are hop distances; bucket max_dist + 1 means unreachable.
        self.spatial_bias = nn.Embedding(max_dist + 2, num_heads)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, 4 * d_model),
            nn.GELU(),
            nn.Linear(4 * d_model, d_model),
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, dist_matrix):
        """
        x:           [N, d_model]
        dist_matrix: [N, N] integer shortest-path distances (-1 = unreachable)
        """
        N = x.size(0)

        Q = self.W_Q(x).view(N, self.num_heads, self.d_k).transpose(0, 1)  # [H, N, d_k]
        K = self.W_K(x).view(N, self.num_heads, self.d_k).transpose(0, 1)
        V = self.W_V(x).view(N, self.num_heads, self.d_k).transpose(0, 1)

        # Standard attention scores
        attn = torch.bmm(Q, K.transpose(1, 2)) / math.sqrt(self.d_k)  # [H, N, N]

        # Add spatial bias: hop distances > max_dist share bucket max_dist,
        # unreachable pairs (-1) get their own bucket max_dist + 1
        max_dist = self.spatial_bias.num_embeddings - 2
        d = dist_matrix.clamp(max=max_dist)
        d[dist_matrix < 0] = max_dist + 1
        bias = self.spatial_bias(d)           # [N, N, H]
        bias = bias.permute(2, 0, 1)          # [H, N, N]
        attn = attn + bias

        attn = F.softmax(attn, dim=-1)
        attn = self.dropout(attn)

        out = torch.bmm(attn, V)              # [H, N, d_k]
        out = out.transpose(0, 1).contiguous().view(N, self.d_model)
        out = self.W_O(out)

        x = self.norm1(x + self.dropout(out))
        x = self.norm2(x + self.dropout(self.ffn(x)))
        return x


# Test with a random distance matrix
N, d, H = 8, 16, 2
x_g = torch.randn(N, d)
dist_m = torch.randint(0, 5, (N, N))
dist_m.fill_diagonal_(0)
gl = GraphormerLayer(d, H)
out_g = gl(x_g, dist_m)
print('Graphormer layer output shape:', out_g.shape)

---
## 3. Heterogeneous Graph Transformer (HGT)

### 3.1 Theory

**HGT** (Hu et al., 2020) generalises the Transformer to **heterogeneous graphs** by defining **type-specific attention** for every combination of source node type, edge type, and target node type.

#### 3.1.1 HGT Attention

For a target node $t$ of type $\tau(t)$ and a source node $s$ of type $\tau(s)$ connected by edge type $\phi(e)$:

**Multi-head attention score (head $i$):**
$$\text{Attn}^{(i)}(s, e, t) = \left(\mathbf{K}^{(i)}(s)\, \mathbf{W}^{ATT}_{\phi(e)}\, \mathbf{Q}^{(i)}(t)^\top\right) \cdot \frac{\mu_{\langle \tau(s), \phi(e), \tau(t) \rangle}}{\sqrt{d/h}}$$

where:
- $\mathbf{K}^{(i)}(s) = \mathbf{h}_s \mathbf{W}_{\tau(s)}^{K,i}$ — type-specific key projection
- $\mathbf{Q}^{(i)}(t) = \mathbf{h}_t \mathbf{W}_{\tau(t)}^{Q,i}$ — type-specific query projection
- $\mathbf{W}^{ATT}_{\phi(e)} \in \mathbb{R}^{d/h \times d/h}$ — edge-type-specific attention matrix

and $\mu_{\langle \tau(s), \phi(e), \tau(t) \rangle}$ is a learnable scalar prior that weights each meta-relation (source type, edge type, target type).

**Message:**
$$\text{Msg}^{(i)}(s, e, t) = \mathbf{h}_s \mathbf{W}_{\tau(s)}^{V,i} \cdot \mathbf{W}^{MSG}_{\phi(e)}$$

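**Aggregation:**
$$\tilde{\mathbf{h}}_t = \big\Vert_{i=1}^{h} \sum_{s \in \mathcal{N}(t)} \text{softmax}_{s \in \mathcal{N}(t)}\!\left(\text{Attn}^{(i)}(s, e, t)\right) \cdot \text{Msg}^{(i)}(s, e, t)$$

where $\Vert$ concatenates the $h$ heads and the softmax normalises over all neighbours of $t$, across all edge types.

**Update:**
$$\mathbf{h}_t^{(l+1)} = \sigma\!\left(\tilde{\mathbf{h}}_t\right) \mathbf{W}_{\tau(t)}^{A} + \mathbf{h}_t^{(l)}$$

i.e. a target-type-specific linear map applied after the activation, plus a residual connection.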
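A minimal single-head sketch of the attention and aggregation steps for one meta-relation, (author, writes, paper), written in plain PyTorch. All weights are hypothetical random tensors standing in for learned parameters, and the update step (activation, $\mathbf{W}^{A}$, residual) is omitted. With several edge types, the softmax would run over all incoming edges of each target, across edge types.

```python
import math
import torch

torch.manual_seed(0)
d = 8                                           # one head, so d/h = d
h_author = torch.randn(3, d)                    # source nodes, tau(s) = author
h_paper = torch.randn(4, d)                     # target nodes, tau(t) = paper
src = torch.tensor([0, 1, 2, 0])                # author of each 'writes' edge
dst = torch.tensor([0, 0, 1, 2])                # paper of each 'writes' edge

# Hypothetical random parameters standing in for learned ones
W_K_author, W_V_author = torch.randn(d, d), torch.randn(d, d)      # type-specific (author)
W_Q_paper = torch.randn(d, d)                                       # type-specific (paper)
W_ATT_writes, W_MSG_writes = torch.randn(d, d), torch.randn(d, d)   # relation-specific
mu_writes = 1.0                                                     # meta-relation prior

K = h_author[src] @ W_K_author                  # [E, d]
Q = h_paper[dst] @ W_Q_paper                    # [E, d]
attn = ((K @ W_ATT_writes) * Q).sum(-1) * mu_writes / math.sqrt(d)  # [E], bilinear score

# Softmax over the incoming edges of each target paper
num_t = h_paper.size(0)
attn_max = torch.full((num_t,), float('-inf')).scatter_reduce(0, dst, attn, reduce='amax')
exp = (attn - attn_max[dst]).exp()
alpha = exp / torch.zeros(num_t).index_add(0, dst, exp)[dst]

# Relation-aware messages, weighted and summed per target
msg = h_author[src] @ W_V_author @ W_MSG_writes  # [E, d]
h_tilde = torch.zeros(num_t, d).index_add(0, dst, alpha.unsqueeze(-1) * msg)
print(alpha, h_tilde.shape)
```

Paper 3 has no incoming `writes` edge, so its row of `h_tilde` stays zero; this is the situation that adding reverse edge types avoids in the PyG example below.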

#### 3.1.2 Key Properties

| Property | Description |
|----------|-------------|
| **Type-specific projections** | Each node type has its own Q, K, V matrices |
| **Relation-specific attention** | Edge types modulate attention scores and messages |
| **Scalable** | Attention is sparse (only along edges) |
| **Inductive** | Can handle new nodes at test time |

### 3.2 HGT with PyG

In [None]:
if pyg_available:
    from torch_geometric.nn import HGTConv, Linear as PyGLinear
    from torch_geometric.data import HeteroData

    # Re-create the academic heterogeneous graph from Notebook 3
    hdata = HeteroData()
    hdata['author'].x  = torch.randn(4, 8)
    hdata['paper'].x   = torch.randn(6, 16)
    hdata['venue'].x   = torch.randn(2, 4)

    hdata['author', 'writes', 'paper'].edge_index = torch.tensor([
        [0, 1, 2, 3, 0], [0, 1, 2, 3, 4]
    ])
    hdata['paper', 'cites', 'paper'].edge_index = torch.tensor([
        [0, 1, 2, 3], [1, 2, 3, 4]
    ])
    hdata['paper', 'publishedIn', 'venue'].edge_index = torch.tensor([
        [0, 1, 2, 3, 4, 5], [0, 0, 1, 1, 0, 1]
    ])

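    # Add reverse edge types (e.g. ('paper', 'rev_writes', 'author')) so every node type
    # receives messages; otherwise 'author' gets no incoming edges and, depending on the
    # PyG version, may be missing from HGTConv's output in the second layer
    hdata = T.ToUndirected()(hdata)
    print(hdata)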

In [None]:
if pyg_available:
    class HGT(nn.Module):
        def __init__(self, hidden_channels, out_channels, num_heads, num_layers, data):
            super().__init__()

            # Input linear projections per node type (to unified hidden_channels)
            self.lin_dict = nn.ModuleDict()
            for ntype in data.node_types:
                in_ch = data[ntype].x.size(-1)
                self.lin_dict[ntype] = PyGLinear(in_ch, hidden_channels)

            # HGT layers
            self.convs = nn.ModuleList([
                HGTConv(hidden_channels, hidden_channels,
                        metadata=data.metadata(), heads=num_heads)
                for _ in range(num_layers)
            ])

            # Output classifier (for 'paper' nodes)
            self.lin_out = PyGLinear(hidden_channels, out_channels)

        def forward(self, x_dict, edge_index_dict):
            # Project all node types to hidden_channels
            x_dict = {
                ntype: self.lin_dict[ntype](x).relu_()
                for ntype, x in x_dict.items()
            }
            # Heterogeneous transformer layers
            for conv in self.convs:
                x_dict = conv(x_dict, edge_index_dict)
            # Classify paper nodes
            return self.lin_out(x_dict['paper'])


    hgt_model = HGT(hidden_channels=32, out_channels=4,
                    num_heads=2, num_layers=2, data=hdata)
    out_hgt = hgt_model(hdata.x_dict, hdata.edge_index_dict)
    print('HGT output (paper nodes):', out_hgt.shape)  # [6, 4]

### 3.3 HGT on DBLP Dataset

In [None]:
if pyg_available:
    try:
        # DBLP: author classification (4 areas: DB, DM, IR, ML)
        # Node types: author, paper, term, conference
        dataset_dblp = DBLP(root='/tmp/DBLP', transform=T.Constant(node_types='conference'))
        dblp_data = dataset_dblp[0]
        print(dblp_data)
        print('Author labels:', dblp_data['author'].y.unique())
    except Exception as e:
        print(f'DBLP dataset not available: {e}')

In [None]:
if pyg_available:
    try:
        class HGT_DBLP(nn.Module):
            def __init__(self, data, hidden_ch=64, out_ch=4, heads=4, num_layers=2):
                super().__init__()
                self.lin_dict = nn.ModuleDict()
                for ntype in data.node_types:
                    in_c = data[ntype].x.size(-1)
                    self.lin_dict[ntype] = PyGLinear(in_c, hidden_ch)

                self.convs = nn.ModuleList([
                    HGTConv(hidden_ch, hidden_ch, metadata=data.metadata(), heads=heads)
                    for _ in range(num_layers)
                ])
                self.classifier = PyGLinear(hidden_ch, out_ch)

            def forward(self, x_dict, edge_index_dict):
                x_dict = {nt: self.lin_dict[nt](x).relu_() for nt, x in x_dict.items()}
                for conv in self.convs:
                    x_dict = conv(x_dict, edge_index_dict)
                return self.classifier(x_dict['author'])


        dblp_data = dblp_data.to(device)
        hgt_dblp  = HGT_DBLP(dblp_data).to(device)
        opt_hgt   = optim.Adam(hgt_dblp.parameters(), lr=5e-4, weight_decay=1e-4)

        def train_hgt():
            hgt_dblp.train()
            opt_hgt.zero_grad()
            out = hgt_dblp(dblp_data.x_dict, dblp_data.edge_index_dict)
            loss = F.cross_entropy(out[dblp_data['author'].train_mask],
                                   dblp_data['author'].y[dblp_data['author'].train_mask])
            loss.backward()
            opt_hgt.step()
            return loss.item()

        @torch.no_grad()
        def test_hgt():
            hgt_dblp.eval()
            out = hgt_dblp(dblp_data.x_dict, dblp_data.edge_index_dict)
            pred = out.argmax(dim=1)
            accs = {}
            for split in ['train_mask', 'val_mask', 'test_mask']:
                mask = dblp_data['author'][split]
                acc  = (pred[mask] == dblp_data['author'].y[mask]).float().mean()
                accs[split.replace('_mask', '')] = acc.item()
            return accs

        hgt_losses = []
        for epoch in range(1, 101):
            l = train_hgt()
            hgt_losses.append(l)

        accs = test_hgt()
        print(f"Train: {accs['train']:.4f} | Val: {accs['val']:.4f} | Test: {accs['test']:.4f}")
    except NameError:
        print('DBLP dataset not loaded — skipping HGT training')

---
## 4. Relational GCN (R-GCN)

### 4.1 Theory

**R-GCN** (Schlichtkrull et al., 2018) extends GCN to **multi-relational (knowledge) graphs** by using **relation-type-specific weight matrices**.

#### 4.1.1 Propagation Rule

$$\mathbf{h}_i^{(l+1)} = \sigma\!\left(
    \mathbf{W}_0^{(l)} \mathbf{h}_i^{(l)}
    + \sum_{r \in \mathcal{R}} \sum_{j \in \mathcal{N}_r(i)}
      \frac{1}{c_{i,r}} \mathbf{W}_r^{(l)} \mathbf{h}_j^{(l)}
\right)$$

| Symbol | Meaning |
|--------|----------|
| $\mathbf{W}_0^{(l)}$ | Self-loop weight matrix |
| $\mathbf{W}_r^{(l)}$ | Relation-$r$-specific weight matrix |
| $\mathcal{N}_r(i)$ | Neighbours of $i$ via relation $r$ |
| $c_{i,r}$ | Normalisation constant (e.g., $|\mathcal{N}_r(i)|$) |

#### 4.1.2 Parameter Reduction

With many relation types, a separate $\mathbf{W}_r$ per relation makes the parameter count grow linearly with $|\mathcal{R}|$ and leads to overfitting on rare relations. R-GCN proposes two regularisation techniques:

**Basis decomposition:**
$$\mathbf{W}_r = \sum_{b=1}^B a_{rb} \mathbf{V}_b$$

The relation-specific matrices $\mathbf{W}_r$ are linear combinations of $B$ shared basis matrices $\mathbf{V}_b$, with relation-specific coefficients $a_{rb}$.

**Block-diagonal decomposition:**
$$\mathbf{W}_r = \bigoplus_{b=1}^B \mathbf{Q}_{br}$$

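Each $\mathbf{W}_r$ is block-diagonal, composed of $B$ blocks $\mathbf{Q}_{br} \in \mathbb{R}^{(d^{(l+1)}/B) \times (d^{(l)}/B)}$, where $d^{(l)}$ and $d^{(l+1)}$ are the input and output dimensions; both must therefore be divisible by $B$.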
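To see how much the two decompositions save, the relation-weight parameter counts can be computed directly. The sizes below ($|\mathcal{R}| = 90$, $d = 16$, $B = 4$) are hypothetical and chosen only for illustration; the self-loop matrix $\mathbf{W}_0$ and the bias are excluded.

```python
def rgcn_weight_params(d_in: int, d_out: int, R: int, B: int) -> dict:
    """Relation-weight parameter counts for one R-GCN layer (W_0 and bias excluded)."""
    counts = {
        "full": R * d_in * d_out,              # one dense W_r per relation
        "basis": B * d_in * d_out + R * B,     # B shared bases V_b + coefficients a_rb
    }
    if d_in % B == 0 and d_out % B == 0:
        # R relations x B blocks, each of size (d_out / B) x (d_in / B)
        counts["block"] = R * B * (d_in // B) * (d_out // B)
    return counts

print(rgcn_weight_params(d_in=16, d_out=16, R=90, B=4))
# {'full': 23040, 'basis': 1384, 'block': 5760}
```

Basis decomposition decouples the dominant term from $|\mathcal{R}|$, while block-diagonal decomposition divides the full count by $B$.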

#### 4.1.3 Applications

1. **Entity classification** — classify nodes in a KG
2. **Link prediction** — combined with a decoder (e.g., DistMult) for KG completion
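DistMult, the decoder used in the link-prediction example below, scores a triple $(h, r, t)$ with the trilinear product $\sum_k h_k r_k t_k$. A tiny worked example with hand-picked vectors (purely illustrative) also shows a known limitation: the score is symmetric in head and tail, so DistMult cannot distinguish $(h, r, t)$ from $(t, r, h)$.

```python
import torch

def distmult_score(h, r, t):
    """DistMult: trilinear product <h, r, t> = sum_k h_k * r_k * t_k."""
    return (h * r * t).sum(dim=-1)

h = torch.tensor([1.0, 2.0, -1.0])
r = torch.tensor([0.5, 1.0, 2.0])
t = torch.tensor([2.0, 0.0, 1.0])
print(distmult_score(h, r, t))   # 1*0.5*2 + 2*1*0 + (-1)*2*1 = -1.0
print(distmult_score(t, r, h))   # identical: DistMult is symmetric in head and tail
```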

### 4.2 R-GCN from Scratch

In [None]:
class RGCNConvFromScratch(nn.Module):
    """
    R-GCN convolution layer with basis decomposition.
    """

    def __init__(self, in_channels, out_channels, num_relations, num_bases=None):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_relations = num_relations
        self.num_bases = num_bases or num_relations  # full if None

        # Basis matrices
        self.basis = nn.Parameter(torch.empty(self.num_bases, in_channels, out_channels))
        # Relation-specific coefficients
        if num_bases and num_bases < num_relations:
            self.att = nn.Parameter(torch.empty(num_relations, self.num_bases))
        else:
            self.att = None

        # Self-loop weight
        self.root_weight = nn.Linear(in_channels, out_channels, bias=False)
        self.bias = nn.Parameter(torch.zeros(out_channels))

        nn.init.xavier_uniform_(self.basis)
        if self.att is not None:
            nn.init.xavier_uniform_(self.att)

    def weight_for_relation(self, r_idx):
        """Compute the weight matrix W_r using basis decomposition."""
        if self.att is not None:
            # W_r = sum_b a_{r,b} * V_b
            w = (self.att[r_idx].unsqueeze(-1).unsqueeze(-1) * self.basis).sum(0)
        else:
            w = self.basis[r_idx]
        return w   # [in_ch, out_ch]

    def forward(self, x, edge_index, edge_type):
        """
        x:          [N, in_channels]
        edge_index: [2, E]
        edge_type:  [E]  integer edge type for each edge
        """
        N = x.size(0)
        out = self.root_weight(x)   # self-loop

        for r in range(self.num_relations):
            mask = edge_type == r
            if not mask.any():
                continue
            src, dst = edge_index[:, mask]      # edges of relation r

            W_r = self.weight_for_relation(r)   # [in_ch, out_ch]
            msg = x[src] @ W_r                  # [E_r, out_ch]

            # Normalise by the number of relation-r neighbours, c_{i,r} = |N_r(i)|
            # (bincount stays on dst's device, so this also works on GPU)
            deg_r = torch.bincount(dst, minlength=N).to(x.dtype)
            msg = msg / deg_r[dst].unsqueeze(-1)

            out = out.index_add(0, dst, msg)    # sum messages into each target node

        return F.relu(out + self.bias)


print('RGCNConvFromScratch defined')

In [None]:
# Test on a small multi-relational graph
N_r = 6
num_rel = 3
x_r   = torch.randn(N_r, 8)
ei_r  = torch.tensor([[0,1,2,3,4,0,2], [1,2,3,4,5,3,5]], dtype=torch.long)
et_r  = torch.tensor([0,0,1,1,2,2,0], dtype=torch.long)   # edge types

rgcn_scratch = RGCNConvFromScratch(8, 16, num_relations=num_rel, num_bases=2)
out_r = rgcn_scratch(x_r, ei_r, et_r)
print('RGCN scratch output shape:', out_r.shape)  # [6, 16]

### 4.3 R-GCN with PyG's `RGCNConv`

In [None]:
if pyg_available:
    from torch_geometric.nn import RGCNConv

    class RGCN(nn.Module):
        def __init__(self, in_ch, hidden_ch, out_ch, num_relations, num_bases=None):
            super().__init__()
            self.conv1 = RGCNConv(in_ch, hidden_ch, num_relations=num_relations,
                                  num_bases=num_bases)
            self.conv2 = RGCNConv(hidden_ch, out_ch, num_relations=num_relations,
                                  num_bases=num_bases)

        def forward(self, x, edge_index, edge_type):
            x = F.relu(self.conv1(x, edge_index, edge_type))
            x = F.dropout(x, p=0.3, training=self.training)
            x = self.conv2(x, edge_index, edge_type)
            return F.log_softmax(x, dim=1)

    rgcn_pyg = RGCN(in_ch=8, hidden_ch=16, out_ch=4,
                    num_relations=num_rel, num_bases=2)
    out_rgcn = rgcn_pyg(x_r, ei_r, et_r)
    print('RGCN PyG output shape:', out_rgcn.shape)  # [6, 4]

### 4.4 R-GCN for Link Prediction on a KG

In [None]:
# R-GCN encoder + DistMult decoder for KG link prediction
if pyg_available:
    class RGCNLinkPredictor(nn.Module):
        """
        Encoder: R-GCN to produce node embeddings.
        Decoder: DistMult scoring function.
        """
        def __init__(self, num_entities, num_relations, hidden_ch, emb_dim, num_bases=None):
            super().__init__()
            # Learnable entity embeddings as input features
            self.entity_emb = nn.Embedding(num_entities, hidden_ch)

            # R-GCN encoder
            self.encoder = RGCNConv(hidden_ch, emb_dim,
                                    num_relations=num_relations,
                                    num_bases=num_bases)

            # DistMult relation embedding
            self.relation_emb = nn.Embedding(num_relations, emb_dim)

        def encode(self, edge_index, edge_type):
            x = self.entity_emb.weight   # [N, hidden_ch]
            z = F.relu(self.encoder(x, edge_index, edge_type))  # [N, emb_dim]
            return z

        def decode(self, z, h_idx, r_idx, t_idx):
            h = z[h_idx]
            r = self.relation_emb(r_idx)
            t = z[t_idx]
            return (h * r * t).sum(dim=-1)

        def forward(self, edge_index, edge_type, pos_triples, neg_triples):
            z = self.encode(edge_index, edge_type)

            ph, pr, pt = pos_triples[:, 0], pos_triples[:, 1], pos_triples[:, 2]
            nh, nr, nt = neg_triples[:, 0], neg_triples[:, 1], neg_triples[:, 2]

            pos_score = self.decode(z, ph, pr, pt)
            neg_score = self.decode(z, nh, nr, nt)

            all_scores = torch.cat([pos_score, neg_score])
            all_labels = torch.cat([
                torch.ones_like(pos_score),     # same device/dtype as the scores
                torch.zeros_like(neg_score)
            ])
            return F.binary_cross_entropy_with_logits(all_scores, all_labels)

    print('RGCNLinkPredictor defined')

In [None]:
if pyg_available:
    # Use toy KG from Notebook 3
    entity2id = {'Alice': 0, 'Bob': 1, 'Carol': 2, 'London': 3, 'Paris': 4, 'UK': 5}
    relation2id = {'knows': 0, 'livesIn': 1, 'locatedIn': 2}
    triples = [(0,0,1),(1,0,2),(0,1,3),(1,1,4),(3,2,5)]
    triples_tensor = torch.tensor(triples, dtype=torch.long)

    num_entities_toy  = len(entity2id)
    num_relations_toy = len(relation2id)

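    # Build edge_index and edge_type for R-GCN. Toy setup: the training triples double as the
    # message-passing graph, and messages flow head -> tail only (the R-GCN paper also adds
    # inverse relations).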
    edge_idx_toy  = triples_tensor[:, [0, 2]].T.contiguous()  # [2, E]
    edge_type_toy = triples_tensor[:, 1]                       # [E]

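    def corrupt_triples_kg(triples, num_ent):
        # Replace the head or the tail (50/50) of each triple with a random entity;
        # a corrupted triple may occasionally equal a true one, which is acceptable here.
        neg = triples.clone()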
        mask = torch.rand(len(triples)) > 0.5
        rand = torch.randint(0, num_ent, (len(triples),))
        neg[mask, 0] = rand[mask]
        neg[~mask, 2] = rand[~mask]
        return neg

    rgcn_lp = RGCNLinkPredictor(
        num_entities=num_entities_toy,
        num_relations=num_relations_toy,
        hidden_ch=16, emb_dim=8, num_bases=None
    )
    opt_rgcn_lp = optim.Adam(rgcn_lp.parameters(), lr=0.01)

    lp_losses = []
    for epoch in range(300):
        neg = corrupt_triples_kg(triples_tensor, num_entities_toy)
        opt_rgcn_lp.zero_grad()
        loss = rgcn_lp(edge_idx_toy, edge_type_toy, triples_tensor, neg)
        loss.backward()
        opt_rgcn_lp.step()
        lp_losses.append(loss.item())

    plt.plot(lp_losses)
    plt.xlabel('Epoch'); plt.ylabel('BCE Loss')
    plt.title('R-GCN Link Prediction Training Loss'); plt.show()
    print(f'Final loss: {lp_losses[-1]:.4f}')

---
## 5. Comparison Experiment

### 5.1 Node Classification on Cora: GCN vs Graph Transformer

In [None]:
if pyg_available:
    from torch_geometric.nn import GCNConv

    dataset_cora = Planetoid(root='/tmp/Cora', name='Cora',
                             transform=T.NormalizeFeatures())
    data_cora = dataset_cora[0].to(device)

    class GCN_Baseline(nn.Module):
        def __init__(self, in_ch, hidden, out_ch):
            super().__init__()
            self.c1 = GCNConv(in_ch, hidden)
            self.c2 = GCNConv(hidden, out_ch)
        def forward(self, x, edge_index):
            x = F.relu(self.c1(x, edge_index))
            x = F.dropout(x, 0.5, training=self.training)
            return F.log_softmax(self.c2(x, edge_index), dim=1)

    def run_experiment(model_cls, name, **kwargs):
        torch.manual_seed(42)
        model = model_cls(**kwargs).to(device)
        opt   = optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
        for _ in range(200):
            model.train()
            opt.zero_grad()
            out = model(data_cora.x, data_cora.edge_index)
            F.nll_loss(out[data_cora.train_mask],
                       data_cora.y[data_cora.train_mask]).backward()
            opt.step()
        model.eval()
        with torch.no_grad():
            out  = model(data_cora.x, data_cora.edge_index)
            pred = out.argmax(1)
            test_acc = (pred[data_cora.test_mask] ==
                        data_cora.y[data_cora.test_mask]).float().mean().item()
        print(f'{name}: Test Acc = {test_acc:.4f}')
        return test_acc

    gcn_acc = run_experiment(
        GCN_Baseline, 'GCN',
        in_ch=dataset_cora.num_features, hidden=64, out_ch=dataset_cora.num_classes
    )

    gt_acc = run_experiment(
        PyGGraphTransformer, 'Graph Transformer (TransformerConv)',
        in_ch=dataset_cora.num_features,
        hidden_ch=16, out_ch=dataset_cora.num_classes,
        heads=4, num_layers=2, dropout=0.1
    )

### 5.2 Model Comparison Summary

| Model | Node Types | Edge Types | Attention | Long-Range | Scalability |
|-------|-----------|-----------|-----------|-----------|-------------|
| **GCN** | 1 | 1 | No (degree norm) | No | O(\|E\|) |
| **Graph Transformer** | 1 | 1 | Yes (MHA over neighbours) | Limited (sparse attention; LPE adds positional signal) | O(\|E\|) sparse; O(N²) full |
| **Graphormer** | 1 | 1 | Yes + structural bias | Yes | O(N²) |
| **HGT** | Multiple | Multiple | Yes (type-specific) | No (sparse) | O(\|E\|) |
| **R-GCN** | 1 | Multiple | No | No | O(\|E\|) messages; \|R\| weight matrices |

---
## 6. Exercises

### Exercise 1 — Positional Encodings for Graph Transformers

Laplacian positional encoding (LPE) injects graph-structural information into node features:
1. Compute the normalised Laplacian $\tilde{L}$ for the Cora graph.
2. Extract the eigenvectors for the $k = 8$ smallest non-trivial eigenvalues.
3. Concatenate them to node features before passing to the Graph Transformer.
4. Compare test accuracy with and without LPE.

*Hint: use `scipy.sparse.linalg.eigsh` or `torch.linalg.eigh`.*

In [None]:
# Exercise 1 — your solution here
def laplacian_pe(edge_index, num_nodes, k=8):
    """Compute top-k Laplacian eigenvectors as positional encodings."""
    # TODO: build normalised Laplacian and compute eigenvectors
    ...

### Exercise 2 — Graphormer Centrality Encoding

Implement the **centrality encoding** from Graphormer:
1. Compute in-degree and out-degree for each node in a directed graph.
2. Create learnable degree embeddings.
3. Add them to the initial node features.
4. Train a simplified Graphormer on a classification task and compare with/without centrality encoding.

In [None]:
# Exercise 2 — your solution here
class CentralityEncoding(nn.Module):
    def __init__(self, max_degree, d_model):
        super().__init__()
        self.in_deg_emb  = nn.Embedding(max_degree + 1, d_model)
        self.out_deg_emb = nn.Embedding(max_degree + 1, d_model)

    def forward(self, x, edge_index, num_nodes):
        # TODO: compute in/out degrees and add embeddings to x
        ...

### Exercise 3 — R-GCN for Entity Classification

The **AIFB** and **MUTAG** datasets are popular KG entity classification benchmarks.
Using `torch_geometric.datasets.Entities`:
1. Load the AIFB dataset.
2. Build a 2-layer R-GCN with basis decomposition ($B = 30$).
3. Train and report test accuracy.
4. Compare the number of parameters with and without basis decomposition.

In [None]:
# Exercise 3 — your solution here
if pyg_available:
    try:
        from torch_geometric.datasets import Entities
        aifb = Entities(root='/tmp/AIFB', name='AIFB')
        print(aifb[0])
        # TODO: build and train R-GCN
    except Exception as e:
        print(f'Dataset not available: {e}')

### Exercise 4 — HGT vs R-GCN on DBLP

1. Load the DBLP dataset.
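2. Convert it to a **homogeneous** graph using `data.to_homogeneous()` and train an R-GCN, using the resulting `edge_type` vector as the relation index. DBLP's node types have different feature sizes, and `to_homogeneous()` only merges node attributes of consistent dimensionality, so you may first need per-type input projections or learnable node embeddings.
3. Train an HGT on the original heterogeneous DBLP graph.
4. Compare test accuracy and discuss why HGT might perform better on heterogeneous data.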

In [None]:
# Exercise 4 — your solution here
...

### Exercise 5 — Global Attention (Fully Connected) vs Sparse Attention

Implement a **fully-connected attention** variant of the Graph Transformer (where every node can attend to every other node, $O(N^2)$):
1. Replace the sparse neighbourhood-based attention in `GraphTransformerLayer` with full $N \times N$ attention.
2. Add a binary edge mask to optionally restrict attention to neighbourhoods.
3. Run both variants on a small graph, compare attention patterns and performance.
4. Discuss the trade-off between expressiveness and scalability.

In [None]:
# Exercise 5 — your solution here
class FullAttentionGraphTransformer(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        # TODO: full N x N attention with optional edge mask
        ...

---
## Summary

| Model | Core Mechanism | Graph Types | When to Use |
|-------|----------------|-------------|-------------|
| **Graph Transformer** | MHA over neighbours + LPE | Homogeneous | Structure-aware attention; switch to full attention when long-range dependencies matter |
| **Graphormer** | MHA + spatial/centrality/edge biases | Homogeneous | Molecular graphs, high accuracy |
| **HGT** | Type-specific MHA | Heterogeneous | Multiple node/edge types with attention |
| **R-GCN** | Relation-specific GCN | Multi-relational / KG | Knowledge graphs, entity classification, link prediction |

### What We've Covered Across All Notebooks

| # | Notebook | Topics |
|---|----------|--------|
| 1 | PyTorch & PyG | Tensors, autograd, `nn.Module`, `Data`, `MessagePassing` |
| 2 | GCN, GraphSAGE, GAT | Spectral GNNs, inductive learning, attention |
| 3 | KG Embeddings | TransE, TransR, RotatE, ComplEx, LiteralE, Hetero graphs |
| 4 | Advanced GNNs | Graph Transformer, Graphormer, HGT, R-GCN |