# Lecture 2: Knowledge Graphs & R-GCN for Link Prediction (PyTorch Geometric)

This notebook builds on Lecture 1 and focuses on **knowledge graphs (KGs)** and **relational graph convolutional networks (R-GCN)** for link prediction.

**Learning goals**

- Understand the knowledge graph formalism: entities, relations, triples.
- See the connection between KGs and heterogeneous / multilayer networks.
- Implement a toy biomedical KG in PyG format.
- Train a simple R-GCN-based link prediction model with DistMult-style scoring.


## 1. Knowledge Graphs: Definition

A **knowledge graph** is a directed, labeled multigraph consisting of:

- One shared entity set \(\mathcal{E}\)
- A set of relations \(\mathcal{R}\)
- Facts represented as **triples** \((h, r, t)\):

\begin{equation}
(h, r, t) \in \mathcal{E} \times \mathcal{R} \times \mathcal{E}
\end{equation}

- \(h\): head entity
- \(r\): relation type
- \(t\): tail entity

Examples:

- (Aspirin, *treats*, Headache)
- (Alan Turing, *born_in*, London)
- (DrugM, *targets*, GeneA)

From a network science point of view:

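- All nodes share **one node type, "entity"**; finer types such as gene or drug are implicit in the relations a node takes part in.
- Each relation \(r\) defines a **layer** with its own adjacency matrix \(A_r\), where \((A_r)_{ht} = 1\) if \((h, r, t)\) is a fact and \(0\) otherwise.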
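
To make the layer view concrete, the following plain-Python sketch builds one adjacency matrix per relation from the three example triples listed above. The helper name `relation_adjacency` and the alphabetical entity ordering are assumptions made for illustration only.

```python
# Example triples (h, r, t), taken from the examples in this section
triples = [
    ("Aspirin", "treats", "Headache"),
    ("Alan Turing", "born_in", "London"),
    ("DrugM", "targets", "GeneA"),
]

def relation_adjacency(triples):
    """Return (entities, {relation: A_r}) with A_r[h][t] = 1 for each triple (h, r, t)."""
    # One shared entity set, indexed in sorted order
    entities = sorted({e for h, _, t in triples for e in (h, t)})
    index = {name: i for i, name in enumerate(entities)}
    n = len(entities)
    adjacency = {}
    for h, r, t in triples:
        # Create the n x n layer for relation r on first use
        A = adjacency.setdefault(r, [[0] * n for _ in range(n)])
        A[index[h]][index[t]] = 1  # directed edge h -> t
    return entities, adjacency

entities, adjacency = relation_adjacency(triples)
for r, A in adjacency.items():
    print(r, "has", sum(map(sum, A)), "edge(s)")
```

Every layer shares the same row and column indexing, which is what makes the KG a multilayer network over a single entity set.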

## 2. R-GCN: Relational Message Passing

Relational GCN (R-GCN) extends GCN to handle many relation types. For entity \(i\), the layer update is:

\begin{equation}
\mathbf{h}_i^{(l+1)} = \sigma\left(
    \sum_{r \in \mathcal{R}} \sum_{j \in \mathcal{N}_r(i)}
        \frac{1}{c_{i,r}} W_r^{(l)} \mathbf{h}_j^{(l)}
    + W_0^{(l)} \mathbf{h}_i^{(l)}
\right),
\end{equation}

where:
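- \(\mathcal{N}_r(i)\) is the set of neighbors of \(i\) under relation \(r\),
- \(W_r^{(l)}\) is the layer-\(l\) weight matrix for relation \(r\),
- \(W_0^{(l)}\) is the weight matrix of the self-connection,
- \(c_{i,r}\) is a normalization constant, either learned or fixed in advance (a common choice is \(c_{i,r} = |\mathcal{N}_r(i)|\)), and
- \(\sigma\) is an element-wise nonlinearity such as ReLU.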

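In the multilayer view, each relation layer contributes its own relation-specific, normalized message to the update of \(i\), and the contributions are summed across layers before the nonlinearity.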
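
The update can be traced term by term in a small plain-Python sketch. It is not how PyG implements the layer; it assumes \(c_{i,r} = |\mathcal{N}_r(i)|\) and ReLU for \(\sigma\), and all names and numeric values are hypothetical.

```python
def matvec(W, v):
    """Multiply matrix W (a list of rows) by vector v."""
    return [sum(w * x for w, x in zip(row, v)) for row in W]

def rgcn_layer(h, edges, W_rel, W_self):
    """One R-GCN update.

    h      : list of feature vectors, one per entity
    edges  : list of (j, r, i), meaning j is a neighbor of i under relation r
    W_rel  : dict mapping relation -> weight matrix W_r
    W_self : self-connection weight matrix W_0
    """
    # Assumed normalization: c_{i,r} = number of r-neighbors of i
    counts = {}
    for j, r, i in edges:
        counts[(i, r)] = counts.get((i, r), 0) + 1

    # Self-connection term W_0 h_i
    out = [matvec(W_self, h_i) for h_i in h]

    # Add the normalized messages (1 / c_{i,r}) W_r h_j
    for j, r, i in edges:
        msg = matvec(W_rel[r], h[j])
        c = counts[(i, r)]
        out[i] = [o + m / c for o, m in zip(out[i], msg)]

    # sigma = ReLU, applied element-wise
    return [[max(0.0, x) for x in row] for row in out]

# Toy input: 3 entities, 2-dimensional features, 2 relations
h = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
edges = [(0, "a", 2), (1, "a", 2), (0, "b", 1)]
W_rel = {"a": [[1.0, 0.0], [0.0, 1.0]], "b": [[0.0, 1.0], [1.0, 0.0]]}
W_self = [[0.5, 0.0], [0.0, 0.5]]

h_next = rgcn_layer(h, edges, W_rel, W_self)
print(h_next)  # [[0.5, 0.0], [0.0, 1.5], [1.0, 1.0]]
```

Entity 0 has no incoming edges, so its new representation comes from the self-connection alone; entity 2 averages two messages under relation `a`.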

## 3. Setup

As before, we import `torch` and `torch_geometric`. 

> Installation commands are commented out; adapt to your environment if needed.


In [None]:
# !pip install torch torch_geometric -q

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch_geometric.nn import RGCNConv

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

## 4. Build a Toy Biomedical Knowledge Graph

We create a small KG with entities of implicit types:

- Genes: `GeneA`, `GeneB`
- Diseases: `DiseaseX`, `DiseaseY`
- Drugs: `DrugM`, `DrugN`

Relations:

- `associates_with` (Gene → Disease)
- `treats` (Drug → Disease)
- `targets` (Drug → Gene)

We will encode these as integer IDs and then as PyG tensors.


In [None]:
# Define entity and relation vocabularies
entity2id = {
    'GeneA': 0,
    'GeneB': 1,
    'DiseaseX': 2,
    'DiseaseY': 3,
    'DrugM': 4,
    'DrugN': 5,
}
num_entities = len(entity2id)

rel2id = {
    'associates_with': 0,  # Gene -> Disease
    'treats': 1,           # Drug -> Disease
    'targets': 2,          # Drug -> Gene
}
num_relations = len(rel2id)

# Define triples (h, r, t)
triples = [
    ('GeneA', 'associates_with', 'DiseaseX'),
    ('GeneB', 'associates_with', 'DiseaseY'),
    ('DrugM', 'treats', 'DiseaseX'),
    ('DrugN', 'treats', 'DiseaseY'),
    ('DrugM', 'targets', 'GeneA'),
    ('DrugN', 'targets', 'GeneB'),
]

heads = torch.tensor([entity2id[h] for (h, r, t) in triples])
rels  = torch.tensor([rel2id[r]    for (h, r, t) in triples])
tails = torch.tensor([entity2id[t] for (h, r, t) in triples])

edge_index = torch.stack([heads, tails], dim=0)   # shape [2, num_edges]
edge_type  = rels                                  # shape [num_edges]

num_edges = edge_index.size(1)
num_entities, num_edges, num_relations
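
One property of this graph is worth checking before training. By default, PyG message passing follows the edge direction (source to target), so an entity that never appears as a tail receives no neighbor messages. The plain-Python sketch below counts incoming edges per entity and shows one common remedy, adding an inverse edge for every triple. The `_inv` relation names and both helper names are hypothetical, and whether inverse relations help on a given task is an empirical question.

```python
from collections import Counter

# Same six triples as in the cell above
triples = [
    ("GeneA", "associates_with", "DiseaseX"),
    ("GeneB", "associates_with", "DiseaseY"),
    ("DrugM", "treats", "DiseaseX"),
    ("DrugN", "treats", "DiseaseY"),
    ("DrugM", "targets", "GeneA"),
    ("DrugN", "targets", "GeneB"),
]

def in_degree(triples):
    """Number of incoming edges per entity (how often it appears as a tail)."""
    counts = Counter(t for _, _, t in triples)
    entities = {e for h, _, t in triples for e in (h, t)}
    return {e: counts.get(e, 0) for e in sorted(entities)}

def add_inverse_triples(triples):
    """Append (t, r_inv, h) for every (h, r, t); r_inv is a new relation type."""
    return triples + [(t, f"{r}_inv", h) for h, r, t in triples]

before = in_degree(triples)
after = in_degree(add_inverse_triples(triples))
print("before:", before)
print("after: ", after)
```

On these triples, `DrugM` and `DrugN` have in-degree 0 before the inverse edges are added. If inverse relations were used in the notebook, `rel2id` and `num_relations` would need to cover the new relation IDs as well.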

## 5. Define an R-GCN Encoder

We define a small R-GCN with two layers. Node features are learnable embeddings initialized randomly.

In [None]:
class RGCN(nn.Module):
    def __init__(self, num_entities, num_relations, emb_dim=32, hidden_dim=64):
        super().__init__()
        # Initial entity embeddings
        self.entity_emb = nn.Embedding(num_entities, emb_dim)

        self.conv1 = RGCNConv(
            in_channels=emb_dim,
            out_channels=hidden_dim,
            num_relations=num_relations
        )
        self.conv2 = RGCNConv(
            in_channels=hidden_dim,
            out_channels=hidden_dim,
            num_relations=num_relations
        )

    def forward(self, edge_index, edge_type):
        x = self.entity_emb.weight  # [num_entities, emb_dim]
        x = self.conv1(x, edge_index, edge_type)
        x = F.relu(x)
        x = self.conv2(x, edge_index, edge_type)
        return x  # final entity embeddings
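
A rough way to see what the two layers cost is to count parameters. The sketch below assumes that each `RGCNConv` holds one `in_channels x out_channels` matrix per relation, one root (self-connection) matrix, and a bias vector, with no basis or block decomposition. This is an assumption about the PyG defaults and is worth verifying with `sum(p.numel() for p in model.parameters())`. The helper name is hypothetical.

```python
def rgcn_conv_params(in_channels, out_channels, num_relations, bias=True):
    """Parameter count of one relational conv layer under the stated assumptions."""
    relation_weights = num_relations * in_channels * out_channels  # one W_r per relation
    root_weight = in_channels * out_channels                       # W_0
    return relation_weights + root_weight + (out_channels if bias else 0)

# Sizes used in this notebook
num_entities, num_relations = 6, 3
emb_dim, hidden_dim = 32, 64

embedding = num_entities * emb_dim
conv1 = rgcn_conv_params(emb_dim, hidden_dim, num_relations)
conv2 = rgcn_conv_params(hidden_dim, hidden_dim, num_relations)

print(embedding, conv1, conv2, embedding + conv1 + conv2)
# 192 8256 16448 24896
```

The relation-specific term grows linearly with the number of relations, which is negligible for three relations but dominates in large KGs.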

## 6. Link Prediction with DistMult-Style Scoring

We use a simple DistMult-like score function:

\begin{equation}
\phi(h, r, t) = \langle \mathbf{e}_h, \mathbf{r}_r, \mathbf{e}_t \rangle
= \sum_k e_{h,k} \, r_{r,k} \, e_{t,k}
\end{equation}

where:
- \(\mathbf{e}_h\) and \(\mathbf{e}_t\) are entity embeddings,
- \(\mathbf{r}_r\) is a relation embedding.
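
A short worked example shows one consequence of this score function: because the product is element-wise, swapping head and tail leaves the score unchanged, so DistMult treats every relation as symmetric. The vectors below are hypothetical values chosen for easy arithmetic.

```python
def distmult_score(e_h, r, e_t):
    """DistMult score: sum over k of e_h[k] * r[k] * e_t[k]."""
    return sum(a * b * c for a, b, c in zip(e_h, r, e_t))

e_h = [1.0, 2.0, 0.5]   # head embedding
r   = [0.5, -1.0, 2.0]  # relation embedding
e_t = [2.0, 1.0, 4.0]   # tail embedding

forward = distmult_score(e_h, r, e_t)   # 1.0 - 2.0 + 4.0
backward = distmult_score(e_t, r, e_h)  # same terms, head and tail swapped

print(forward, backward)  # 3.0 3.0
```

For a directed relation such as `treats`, this means the scorer alone cannot tell (Drug, *treats*, Disease) from (Disease, *treats*, Drug).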

In [None]:
class RGCNLinkPredictor(nn.Module):
    def __init__(self, num_entities, num_relations, emb_dim=32, hidden_dim=64):
        super().__init__()
        self.rgcn = RGCN(num_entities, num_relations, emb_dim, hidden_dim)
        self.rel_emb = nn.Embedding(num_relations, hidden_dim)

    def forward(self, edge_index, edge_type):
        # Return entity embeddings
        return self.rgcn(edge_index, edge_type)

    def score_triples(self, entity_emb, heads, rels, tails):
        # entity_emb: [num_entities, hidden_dim]
        h = entity_emb[heads]           # [B, d]
        r = self.rel_emb(rels)          # [B, d]
        t = entity_emb[tails]           # [B, d]
        return (h * r * t).sum(dim=-1)  # [B]

## 7. Negative Sampling and Training Loop

For each positive triple \((h, r, t)\), we generate a negative triple by corrupting the tail with a random entity. We then optimize a binary cross-entropy loss over positive and negative scores.
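
To see what the loss computes, the sketch below evaluates binary cross-entropy directly on raw scores (logits) in plain Python, using the usual numerically stable form. With mean reduction this should agree with `nn.BCEWithLogitsLoss`, up to floating-point precision. The scores are hypothetical.

```python
import math

def bce_with_logits(scores, labels):
    """Mean binary cross-entropy on raw scores, computed in a stable form.

    Per element: max(s, 0) - s * y + log(1 + exp(-|s|)),
    which equals -[y * log(sigmoid(s)) + (1 - y) * log(1 - sigmoid(s))].
    """
    total = 0.0
    for s, y in zip(scores, labels):
        total += max(s, 0.0) - s * y + math.log1p(math.exp(-abs(s)))
    return total / len(scores)

pos_scores = [2.0, 0.0]    # scores of true triples (label 1)
neg_scores = [-2.0, 0.0]   # scores of corrupted triples (label 0)

scores = pos_scores + neg_scores
labels = [1.0] * len(pos_scores) + [0.0] * len(neg_scores)

loss = bce_with_logits(scores, labels)
print(f"{loss:.4f}")
```

A score of 0 corresponds to a predicted probability of 0.5 and costs \(\log 2\) regardless of the label; the loss falls as positive scores rise and negative scores drop.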

In [None]:
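import random

# Known facts as ID triples, used to reject accidental positives
true_triples = set(zip(heads.tolist(), rels.tolist(), tails.tolist()))

def negative_sampling(num_entities, heads, rels, tails):
    """Replace each tail with a random entity, rejecting known true triples."""
    neg_tails = tails.clone()
    for i in range(len(tails)):
        h, r = int(heads[i]), int(rels[i])
        t = random.randrange(num_entities)
        while (h, r, t) in true_triples:  # resample if we hit a real fact
            t = random.randrange(num_entities)
        neg_tails[i] = t
    return heads, rels, neg_tails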

# Move tensors and model to device
model = RGCNLinkPredictor(num_entities, num_relations).to(device)
edge_index = edge_index.to(device)
edge_type = edge_type.to(device)
heads = heads.to(device)
rels = rels.to(device)
tails = tails.to(device)

optimizer = optim.Adam(model.parameters(), lr=0.01)
bce_loss = nn.BCEWithLogitsLoss()

def train_link_pred(epochs=200):
    for epoch in range(1, epochs + 1):
        model.train()
        optimizer.zero_grad()

        # 1) Get entity embeddings from R-GCN
        entity_emb = model(edge_index, edge_type)

        # 2) Positive scores
        pos_scores = model.score_triples(entity_emb, heads, rels, tails)

        # 3) Negative samples & scores
        neg_heads, neg_rels, neg_tails = negative_sampling(num_entities, heads, rels, tails)
        neg_heads = neg_heads.to(device)
        neg_rels  = neg_rels.to(device)
        neg_tails = neg_tails.to(device)
        neg_scores = model.score_triples(entity_emb, neg_heads, neg_rels, neg_tails)

        # 4) Labels: 1 for pos, 0 for neg
        scores = torch.cat([pos_scores, neg_scores], dim=0)
        labels = torch.cat([
            torch.ones_like(pos_scores),
            torch.zeros_like(neg_scores)
        ], dim=0)

        loss = bce_loss(scores, labels)
        loss.backward()
        optimizer.step()

        if epoch % 20 == 0:
            with torch.no_grad():
                pos_prob = torch.sigmoid(pos_scores)
                neg_prob = torch.sigmoid(neg_scores)
                print(f"Epoch {epoch:03d}, Loss: {loss.item():.4f}, "
                      f"Pos prob mean: {pos_prob.mean().item():.3f}, "
                      f"Neg prob mean: {neg_prob.mean().item():.3f}")


train_link_pred(epochs=200)

## 8. Inspect Learned Embeddings and Scores

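We can inspect the predicted probabilities (the sigmoid of the DistMult scores) for all true triples and compare them with randomly corrupted triples to see whether the model separates the two. Because the model was trained on these same triples, this is a sanity check rather than an evaluation.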

In [None]:
@torch.no_grad()
def inspect_scores():
    model.eval()
    entity_emb = model(edge_index, edge_type)

    pos_scores = torch.sigmoid(model.score_triples(entity_emb, heads, rels, tails))

    print("\nPositive triples and their scores:")
    for i, (h, r, t) in enumerate(triples):
        print(f"{(h, r, t)} -> score={pos_scores[i].item():.3f}")

    # Sample some random negative triples
    neg_heads, neg_rels, neg_tails = negative_sampling(num_entities, heads, rels, tails)
    neg_scores = torch.sigmoid(model.score_triples(entity_emb, neg_heads.to(device),
                                                   neg_rels.to(device), neg_tails.to(device)))

    print("\nExample negative triples and their scores:")
    # Build the reverse lookups once rather than on every iteration
    inv_entity = {v: k for k, v in entity2id.items()}
    inv_rel = {v: k for k, v in rel2id.items()}
    for i in range(len(triples)):
        h_id, r_id, t_id = int(neg_heads[i]), int(neg_rels[i]), int(neg_tails[i])
        h_name, r_name, t_name = inv_entity[h_id], inv_rel[r_id], inv_entity[t_id]
        print(f"{(h_name, r_name, t_name)} -> score={neg_scores[i].item():.3f}")


inspect_scores()

## 9. Discussion & Exercises

**Conceptual questions:**

1. View each relation as a distinct **layer** in a multilayer network. How does R-GCN combine layer-specific signals?
2. What happens if we remove one relation type (e.g., `targets`)? Can we still infer which drug treats which disease?

**Coding exercises:**

1. Extend the KG with an additional entity type, e.g., variants such as `VariantX`, and a relation `has_variant` (Gene → Variant).
2. Implement a small evaluation routine that computes Hits@k or Mean Reciprocal Rank (MRR) on a held-out set of triples.
3. Compare R-GCN-based link prediction with a simple translational KGE model like TransE on this toy data.
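
As a possible starting point for coding exercise 2, the plain-Python helpers below turn candidate scores into a rank and then into MRR and Hits@k. They leave out the model-specific part (scoring every candidate tail for a held-out triple, and filtering out other known facts). The function names, the pessimistic tie handling, and the example scores are all assumptions for illustration.

```python
def rank_of_true(scores, true_index):
    """1-based rank of the true candidate among all candidates.

    Ties are counted against the true candidate (pessimistic ranking).
    """
    true_score = scores[true_index]
    better_or_equal = sum(
        1 for i, s in enumerate(scores) if i != true_index and s >= true_score
    )
    return 1 + better_or_equal

def mrr(ranks):
    """Mean reciprocal rank."""
    return sum(1.0 / r for r in ranks) / len(ranks)

def hits_at_k(ranks, k):
    """Fraction of test triples whose true candidate is ranked in the top k."""
    return sum(r <= k for r in ranks) / len(ranks)

# Hypothetical candidate scores for three held-out triples
ranks = [
    rank_of_true([0.9, 0.2, 0.1], true_index=0),  # rank 1
    rank_of_true([0.9, 0.8, 0.1], true_index=2),  # rank 3
    rank_of_true([0.1, 0.9, 0.5], true_index=2),  # rank 2
]

print(ranks, round(mrr(ranks), 4), hits_at_k(ranks, 1))
```

With only six entities the ranks are coarse, so these metrics are mainly useful here for checking that the evaluation code runs.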

These exercises connect classical network science ideas (layers, adjacency matrices, diffusion) with modern graph ML models on heterogeneous and knowledge graphs.
