# Notebook 3: Knowledge Graph Embeddings

This notebook covers methods for **embedding Knowledge Graphs (KGs)**, that is, learning low-dimensional vector representations of entities and relations, and gives an introduction to **Heterogeneous Graphs**.

| Model | Paper | Key Idea |
|-------|-------|----------|
| **TransE** | Bordes et al. (2013) | $h + r \approx t$ in Euclidean space |
| **TransR** | Lin et al. (2015) | Entity/relation in separate spaces, projection matrix |
| **RotatE** | Sun et al. (2019) | Relation as rotation in complex space |
| **ComplEx** | Trouillon et al. (2016) | Complex-valued embeddings for asymmetric relations |
| **LiteralE** | Kristiadi et al. (2018) | Incorporate literal attributes into embeddings |
| **Heterogeneous Graphs** | — | Multiple node/edge types |

**Contents**
1. [Knowledge Graphs: Background](#1-knowledge-graphs-background)
2. [TransE](#2-transe)
3. [TransR](#3-transr)
4. [RotatE](#4-rotate)
5. [ComplEx](#5-complex)
6. [LiteralE](#6-literale)
7. [Heterogeneous Graphs](#7-heterogeneous-graphs)
8. [Exercises](#8-exercises)

In [None]:
# Uncomment to install
# !pip install torch torch_geometric

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt

torch.manual_seed(42)
print(f'PyTorch version: {torch.__version__}')

try:
    import torch_geometric
    print(f'PyG version: {torch_geometric.__version__}')
    pyg_available = True
except ImportError:
    print('PyTorch Geometric not installed')
    pyg_available = False

---
## 1. Knowledge Graphs: Background

### 1.1 What is a Knowledge Graph?

A **Knowledge Graph (KG)** is a multi-relational directed graph $\mathcal{G} = (\mathcal{E}, \mathcal{R}, \mathcal{T})$ where:
- $\mathcal{E}$ — set of **entities** (nodes)
- $\mathcal{R}$ — set of **relation types** (edge labels)
- $\mathcal{T} \subseteq \mathcal{E} \times \mathcal{R} \times \mathcal{E}$ — set of **triples** $(h, r, t)$ meaning "head entity $h$ has relation $r$ to tail entity $t$"

**Example triples:**
```
(London, capitalOf, UK)
(UK, locatedIn, Europe)
(Shakespeare, bornIn, StratfordUponAvon)
```

### 1.2 Knowledge Graph Completion (Link Prediction)

Real-world KGs are **incomplete**. The task of **KG completion** is to predict missing triples, i.e.:
- Given $(h, r, ?)$ → predict the most likely tail entity
- Given $(?, r, t)$ → predict the most likely head entity

### 1.3 KG Embedding Approach

All KG embedding methods share a common framework:
1. Assign each entity $e \in \mathcal{E}$ a vector $\mathbf{e} \in \mathbb{R}^d$ (or $\mathbb{C}^d$)
2. Assign each relation $r \in \mathcal{R}$ a vector (or matrix) $\mathbf{r}$
3. Define a **scoring function** $f(h, r, t) \in \mathbb{R}$ that should be high for true triples and low for false ones
4. Train with a margin-based or binary cross-entropy loss

### 1.4 Evaluation Metrics

| Metric | Description |
|--------|-------------|
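| **MRR** | Mean Reciprocal Rank: $\frac{1}{\lvert\mathcal{T}_{\text{test}}\rvert}\sum_{i} \frac{1}{\text{rank}_i}$, where $\text{rank}_i$ is the rank of the correct entity for test triple $i$ |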
| **Hits@k** | Fraction of test triples where the correct entity is in top-$k$ |
| **MR** | Mean Rank |
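
The metrics in the table can all be computed from the list of ranks alone. The following is a minimal sketch in plain Python; the function name `ranking_metrics` and the example ranks are made up for illustration, and ranks are assumed to be 1-based (rank 1 = correct entity scored best).

```python
def ranking_metrics(ranks, ks=(1, 3, 10)):
    """Compute MR, MRR and Hits@k from 1-based ranks of the correct entities."""
    n = len(ranks)
    metrics = {
        "MR": sum(ranks) / n,                      # mean rank (lower is better)
        "MRR": sum(1.0 / r for r in ranks) / n,    # mean reciprocal rank (higher is better)
    }
    for k in ks:
        # Fraction of queries whose correct entity is ranked within the top k
        metrics[f"Hits@{k}"] = sum(r <= k for r in ranks) / n
    return metrics


# Hypothetical ranks of the correct entity for four test queries
example_ranks = [1, 2, 5, 10]
print(ranking_metrics(example_ranks))
```

MRR is dominated by queries with small ranks, whereas MR is sensitive to a few very badly ranked queries, which is one reason both are usually reported.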

### 1.5 Toy Knowledge Graph

In [None]:
# Toy KG with 6 entities and 3 relations
entity2id = {
    'Alice': 0, 'Bob': 1, 'Carol': 2,
    'London': 3, 'Paris': 4, 'UK': 5
}
relation2id = {
    'knows': 0, 'livesIn': 1, 'locatedIn': 2
}

# Triples as (head_id, relation_id, tail_id)
triples = [
    (0, 0, 1),  # Alice knows Bob
    (1, 0, 2),  # Bob knows Carol
    (0, 1, 3),  # Alice livesIn London
    (1, 1, 4),  # Bob livesIn Paris
    (3, 2, 5),  # London locatedIn UK
]

num_entities = len(entity2id)
num_relations = len(relation2id)

triples_tensor = torch.tensor(triples, dtype=torch.long)  # [T, 3]
print('Entities:', num_entities)
print('Relations:', num_relations)
print('Training triples:', triples_tensor.shape)

In [None]:
# Visualise the toy KG
try:
    import networkx as nx

    id2entity = {v: k for k, v in entity2id.items()}
    id2relation = {v: k for k, v in relation2id.items()}

    DG = nx.MultiDiGraph()
    for h, r, t in triples:
        DG.add_edge(id2entity[h], id2entity[t], label=id2relation[r])

    pos = nx.spring_layout(DG, seed=7)
    nx.draw_networkx_nodes(DG, pos, node_color='lightblue', node_size=1500)
    nx.draw_networkx_labels(DG, pos, font_size=9)
    edge_labels = {(id2entity[h], id2entity[t]): id2relation[r] for h, r, t in triples}
    nx.draw_networkx_edge_labels(DG, pos, edge_labels=edge_labels, font_size=8)
    nx.draw_networkx_edges(DG, pos, arrows=True)
    plt.title('Toy Knowledge Graph'); plt.axis('off'); plt.show()
except ImportError:
    print('networkx not installed')

---
## 2. TransE

### 2.1 Theory

**TransE** (Bordes et al., 2013) is the seminal translational distance model. It models a relation $r$ as a **translation** in embedding space:

$$\mathbf{h} + \mathbf{r} \approx \mathbf{t}$$

#### Scoring Function

$$f(h, r, t) = -\|\mathbf{h} + \mathbf{r} - \mathbf{t}\|_{p}, \quad p \in \{1, 2\}$$

(Negative $L_1$ or $L_2$ distance — higher score = more plausible triple)

#### Training Loss — Margin-Based

$$\mathcal{L} = \sum_{(h,r,t) \in \mathcal{T}} \sum_{(h',r,t') \in \mathcal{T}^-}
\left[\gamma + f(h',r,t') - f(h,r,t)\right]_+$$

where $\gamma > 0$ is the margin, $[x]_+ = \max(0, x)$, and $\mathcal{T}^-$ contains **negative** (corrupted) triples.
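
To make the score and the margin loss concrete, here is a small worked example in plain Python. The 2-dimensional embeddings are hand-picked for illustration (they are not learned), and the helper names `transe_score` and `margin_loss` are hypothetical.

```python
def transe_score(h, r, t, p=1):
    """f(h, r, t) = -||h + r - t||_p for p in {1, 2}."""
    diffs = [abs(hi + ri - ti) for hi, ri, ti in zip(h, r, t)]
    if p == 1:
        return -sum(diffs)
    return -sum(d ** 2 for d in diffs) ** 0.5


def margin_loss(pos_score, neg_score, gamma=1.0):
    """[gamma + f(negative) - f(positive)]_+ for a single positive/negative pair."""
    return max(0.0, gamma + neg_score - pos_score)


# Hypothetical 2-d embeddings, chosen by hand
h, r = [0.0, 1.0], [1.0, 0.0]
t_true = [1.0, 1.0]      # exactly h + r, so the distance is zero
t_corrupt = [1.0, 0.5]   # close to h + r: a "hard" negative
t_far = [-1.0, -1.0]     # far from h + r: an "easy" negative

f_pos = transe_score(h, r, t_true)
loss_hard = margin_loss(f_pos, transe_score(h, r, t_corrupt))
loss_easy = margin_loss(f_pos, transe_score(h, r, t_far))
print(loss_hard, loss_easy)
```

The hard negative lies inside the margin and contributes a loss of 0.5, while the easy negative is already separated by more than $\gamma$ and contributes nothing.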

#### Negative Sampling

For a positive triple $(h, r, t)$, negative triples are created by replacing either the head or the tail (never both at once) with a random entity:
$$\mathcal{T}^- = \{(h', r, t) : h' \in \mathcal{E}\} \cup \{(h, r, t') : t' \in \mathcal{E}\}$$
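
Uniform corruption can accidentally produce a triple that is actually true, or even reproduce the positive triple itself. A common remedy is to reject such candidates. The sketch below shows one way to do this with rejection sampling in plain Python; the function name `corrupt_filtered`, the `max_tries` limit and the toy data are assumptions made for illustration.

```python
import random


def corrupt_filtered(triple, num_entities, known_triples, rng, max_tries=100):
    """Corrupt head or tail, rejecting candidates that are known true triples."""
    h, r, t = triple
    for _ in range(max_tries):
        e = rng.randrange(num_entities)
        # 50% chance to replace the head, otherwise replace the tail
        candidate = (e, r, t) if rng.random() < 0.5 else (h, r, e)
        if candidate not in known_triples:
            return candidate
    raise RuntimeError("no valid negative found; the graph may be too dense")


# Hypothetical toy data: 4 entities, 1 relation
known = {(0, 0, 1), (1, 0, 2), (0, 0, 2)}
rng = random.Random(0)
positives = sorted(known)
negatives = [corrupt_filtered(tr, 4, known, rng) for tr in positives]
print(negatives)
```

Because the positive triple is itself in `known_triples`, a candidate that leaves the triple unchanged is rejected as well.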

#### Limitations
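- Cannot model **1-to-N**, **N-to-1**, and **N-to-N** relations: all tails $t_1, \dots, t_N$ that share the same $(h, r)$ are pushed toward the single point $\mathbf{h} + \mathbf{r}$
- Cannot model **symmetric** relations well: $\mathbf{h} + \mathbf{r} \approx \mathbf{t}$ and $\mathbf{t} + \mathbf{r} \approx \mathbf{h}$ together force $\mathbf{r} \approx \mathbf{0}$ and hence $\mathbf{h} \approx \mathbf{t}$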

### 2.2 TransE Implementation

In [None]:
class TransE(nn.Module):
    def __init__(self, num_entities, num_relations, embedding_dim, margin=1.0, norm=1):
        super().__init__()
        self.margin = margin
        self.norm = norm

        self.entity_emb  = nn.Embedding(num_entities,  embedding_dim)
        self.relation_emb = nn.Embedding(num_relations, embedding_dim)

        # Initialise uniformly
        nn.init.uniform_(self.entity_emb.weight,  -6/embedding_dim**0.5, 6/embedding_dim**0.5)
        nn.init.uniform_(self.relation_emb.weight, -6/embedding_dim**0.5, 6/embedding_dim**0.5)

        # Normalise relation embeddings to unit norm
        with torch.no_grad():
            self.relation_emb.weight.data = F.normalize(self.relation_emb.weight.data, p=2, dim=1)

    def score(self, h_idx, r_idx, t_idx):
        """Return the distance ||h + r - t||_p, i.e. -f(h, r, t); lower = more plausible."""
        h = F.normalize(self.entity_emb(h_idx), p=2, dim=1)
        r = self.relation_emb(r_idx)
        t = F.normalize(self.entity_emb(t_idx), p=2, dim=1)
        return torch.norm(h + r - t, p=self.norm, dim=1)

    def forward(self, pos_triples, neg_triples):
        """Margin-ranking loss."""
        ph, pr, pt = pos_triples[:, 0], pos_triples[:, 1], pos_triples[:, 2]
        nh, nr, nt = neg_triples[:, 0], neg_triples[:, 1], neg_triples[:, 2]

        pos_score = self.score(ph, pr, pt)
        neg_score = self.score(nh, nr, nt)

        loss = F.relu(self.margin + pos_score - neg_score).mean()
        return loss


def corrupt_triples(triples, num_entities):
    """Generate negative triples by randomly replacing head or tail."""
    neg = triples.clone()
    mask = torch.rand(len(triples)) > 0.5   # 50% chance replace head, 50% tail
    random_entities = torch.randint(0, num_entities, (len(triples),))
    neg[mask, 0]  = random_entities[mask]    # replace head
    neg[~mask, 2] = random_entities[~mask]   # replace tail
    return neg


print('TransE defined')

In [None]:
# Train TransE on the toy KG
dim = 10
model_transe = TransE(num_entities, num_relations, dim, margin=1.0)
optimizer = optim.Adam(model_transe.parameters(), lr=0.01)

losses = []
for epoch in range(500):
    neg_triples = corrupt_triples(triples_tensor, num_entities)
    optimizer.zero_grad()
    loss = model_transe(triples_tensor, neg_triples)
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

plt.plot(losses)
plt.xlabel('Epoch'); plt.ylabel('Margin Loss')
plt.title('TransE Training Loss'); plt.show()
print(f'Final loss: {losses[-1]:.4f}')

In [None]:
# Inspect learned embeddings
model_transe.eval()
with torch.no_grad():
    ent_embs = F.normalize(model_transe.entity_emb.weight, p=2, dim=1)
    rel_embs = model_transe.relation_emb.weight

id2entity = {v: k for k, v in entity2id.items()}

# Verify: h + r ≈ t for (Alice, livesIn, London)
alice_id = entity2id['Alice']
london_id = entity2id['London']
livesIn_id = relation2id['livesIn']

pred = ent_embs[alice_id] + rel_embs[livesIn_id]
# Rank with the same norm the model was trained with
dists = torch.norm(ent_embs - pred.unsqueeze(0), p=model_transe.norm, dim=1)
ranked = dists.argsort()

print('Query: (Alice, livesIn, ?)')
for rank, idx in enumerate(ranked):
    print(f'  Rank {rank+1}: {id2entity[idx.item()]} (dist={dists[idx].item():.4f})')

---
## 3. TransR

### 3.1 Theory

**TransR** (Lin et al., 2015) addresses TransE's inability to handle complex relation patterns by introducing **relation-specific projection spaces**.

**Key idea:** Entities and relations live in *different* spaces. Entities are projected into the relation space before the translation:

$$\mathbf{h}_r = \mathbf{h}\mathbf{M}_r, \quad \mathbf{t}_r = \mathbf{t}\mathbf{M}_r$$

where $\mathbf{M}_r \in \mathbb{R}^{d_e \times d_r}$ is a relation-specific projection matrix.

#### Scoring Function

$$f(h, r, t) = -\|\mathbf{h}_r + \mathbf{r} - \mathbf{t}_r\|_2^2$$

#### Advantages over TransE
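- Each relation has its own projection, so entities that differ in entity space can coincide in a relation's space; a 1-to-N relation can therefore map several tails to the same projected point without collapsing their entity embeddings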
- Entity dimension $d_e$ and relation dimension $d_r$ can differ
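
The first advantage can be checked with a tiny hand-built example. In the sketch below (plain Python, hypothetical values, helper names `project` and `transr_distance` invented for illustration), two different tail entities receive the same projection, so both satisfy the translation exactly.

```python
def project(e, M):
    """Row-vector projection e M, with M given as d_e rows of d_r columns."""
    d_r = len(M[0])
    return [sum(e[i] * M[i][j] for i in range(len(e))) for j in range(d_r)]


def transr_distance(h, r, t, M):
    """||h M_r + r - t M_r||_2, the unsquared distance (lower = more plausible)."""
    h_r, t_r = project(h, M), project(t, M)
    return sum((a + b - c) ** 2 for a, b, c in zip(h_r, r, t_r)) ** 0.5


# Hypothetical hand-picked values with d_e = 2 and d_r = 1
M_r = [[1.0], [0.0]]               # keeps only the first entity coordinate
r = [1.0]
h = [0.0, 0.5]
t1, t2 = [1.0, 0.3], [1.0, -0.7]   # two different tails for the same (h, r)

d1 = transr_distance(h, r, t1, M_r)
d2 = transr_distance(h, r, t2, M_r)
print(d1, d2)
```

Both distances are zero even though `t1` and `t2` are distinct in entity space, which a single shared translation as in TransE cannot achieve.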

#### Disadvantage
- $|\mathcal{R}|$ projection matrices → large memory footprint
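
A quick parameter count shows how fast the projection matrices grow. The graph size below is hypothetical and only meant to give a sense of scale; the two helper functions are illustrative names.

```python
def transe_params(n_ent, n_rel, d):
    """One d-dimensional vector per entity and per relation."""
    return (n_ent + n_rel) * d


def transr_params(n_ent, n_rel, d_e, d_r):
    """Entity vectors + relation vectors + one d_e x d_r matrix per relation."""
    return n_ent * d_e + n_rel * d_r + n_rel * d_e * d_r


# Hypothetical graph size, for illustration only
n_ent, n_rel, d = 10_000, 1_000, 100
print(transe_params(n_ent, n_rel, d))     # 1100000
print(transr_params(n_ent, n_rel, d, d))  # 11100000
```

In this setting the projection matrices alone account for 10 million of the 11.1 million TransR parameters.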

### 3.2 TransR Implementation

In [None]:
class TransR(nn.Module):
    def __init__(self, num_entities, num_relations, entity_dim, relation_dim, margin=1.0):
        super().__init__()
        self.margin = margin
        self.entity_dim = entity_dim
        self.relation_dim = relation_dim

        self.entity_emb  = nn.Embedding(num_entities,  entity_dim)
        self.relation_emb = nn.Embedding(num_relations, relation_dim)
        # Projection matrix for each relation: [num_relations, entity_dim * relation_dim]
        self.proj_matrix = nn.Embedding(num_relations, entity_dim * relation_dim)

        nn.init.xavier_uniform_(self.entity_emb.weight)
        nn.init.xavier_uniform_(self.relation_emb.weight)
        nn.init.xavier_uniform_(self.proj_matrix.weight)

    def project(self, entity_emb, proj):
        """Project entity embedding into relation space."""
        # entity_emb: [B, entity_dim]
        # proj: [B, entity_dim * relation_dim]
        B = entity_emb.size(0)
        M = proj.view(B, self.entity_dim, self.relation_dim)
        # [B, 1, entity_dim] @ [B, entity_dim, relation_dim] -> [B, 1, relation_dim]
        projected = torch.bmm(entity_emb.unsqueeze(1), M).squeeze(1)
        return F.normalize(projected, p=2, dim=1)

    def score(self, h_idx, r_idx, t_idx):
        h   = F.normalize(self.entity_emb(h_idx),  p=2, dim=1)
        t   = F.normalize(self.entity_emb(t_idx),  p=2, dim=1)
        r   = self.relation_emb(r_idx)
        M_r = self.proj_matrix(r_idx)

        h_r = self.project(h, M_r)
        t_r = self.project(t, M_r)

        # Unsquared L2 distance (the formula above uses the squared norm);
        # lower = more plausible
        return torch.norm(h_r + r - t_r, p=2, dim=1)

    def forward(self, pos_triples, neg_triples):
        ph, pr, pt = pos_triples[:, 0], pos_triples[:, 1], pos_triples[:, 2]
        nh, nr, nt = neg_triples[:, 0], neg_triples[:, 1], neg_triples[:, 2]
        pos_score = self.score(ph, pr, pt)
        neg_score = self.score(nh, nr, nt)
        return F.relu(self.margin + pos_score - neg_score).mean()


model_transr = TransR(num_entities, num_relations, entity_dim=10, relation_dim=6)
optimizer_r = optim.Adam(model_transr.parameters(), lr=0.01)

for epoch in range(500):
    neg = corrupt_triples(triples_tensor, num_entities)
    optimizer_r.zero_grad()
    loss = model_transr(triples_tensor, neg)
    loss.backward()
    optimizer_r.step()

print(f'TransR final loss: {loss.item():.4f}')

---
## 4. RotatE

### 4.1 Theory

**RotatE** (Sun et al., 2019) models each relation as a **rotation in complex space**.

Entities are represented as complex vectors $\mathbf{h}, \mathbf{t} \in \mathbb{C}^d$ and each relation as a complex vector $\mathbf{r} \in \mathbb{C}^d$ with $|r_k| = 1$ for all $k$:

$$\mathbf{h} \circ \mathbf{r} \approx \mathbf{t}$$

where $\circ$ denotes element-wise (Hadamard) product in $\mathbb{C}^d$.

Because $|r_k| = 1$, each dimension performs a **rotation by angle $\theta_{r,k}$**:
$$r_k = e^{i\theta_{r,k}} = \cos\theta_{r,k} + i\sin\theta_{r,k}$$

#### Scoring Function

$$f(h, r, t) = -\|\mathbf{h} \circ \mathbf{r} - \mathbf{t}\|$$

#### Relation Patterns Captured

| Pattern | Example | How RotatE handles it |
|---------|---------|---------------------|
| **Symmetry** | *isSiblingOf* | $\theta_r = 0$ or $\pi$ |
| **Antisymmetry** | *isParentOf* | $\theta_r \neq 0, \pi$ |
| **Inversion** | *isChildOf* = inverse of *isParentOf* | $\theta_{r^{-1}} = -\theta_r$ |
| **Composition** | $r_1 \circ r_2 = r_3$ | $\theta_{r_3} = \theta_{r_1} + \theta_{r_2}$ |
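
The rows of this table can be verified numerically with Python's built-in complex numbers. This is a small sketch with a made-up entity vector and made-up angles; `rotate` and `close` are hypothetical helper names.

```python
import cmath
import math


def rotate(h, theta):
    """Apply a relation: multiply each h_k by the unit complex number e^{i*theta_k}."""
    return [hk * cmath.exp(1j * th) for hk, th in zip(h, theta)]


def close(a, b, tol=1e-9):
    """Element-wise comparison up to floating-point error."""
    return all(abs(x - y) < tol for x, y in zip(a, b))


h = [1 + 2j, -0.5 + 1j]   # hypothetical entity embedding in C^2

# Symmetry: with every angle in {0, pi}, applying the relation twice returns h
sym = [0.0, math.pi]
symmetry_ok = close(rotate(rotate(h, sym), sym), h)

# Inversion: rotating by -theta undoes a rotation by theta
theta1 = [0.3, 1.2]
inversion_ok = close(rotate(rotate(h, theta1), [-x for x in theta1]), h)

# Composition: rotating by theta1 and then theta2 equals rotating by theta1 + theta2
theta2 = [0.5, -0.4]
summed = [a + b for a, b in zip(theta1, theta2)]
composition_ok = close(rotate(rotate(h, theta1), theta2), rotate(h, summed))

# Rotations never change the modulus of a coordinate
modulus_ok = close([abs(x) for x in rotate(h, theta1)], [abs(x) for x in h])

print(symmetry_ok, inversion_ok, composition_ok, modulus_ok)
```

The last check illustrates a consequence of $|r_k| = 1$: a relation can only change the phase of each coordinate, never its modulus.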

#### Self-Adversarial Negative Sampling

RotatE uses a **self-adversarial** loss:
$$\mathcal{L} = -\log\sigma(\gamma - d_r(h,t)) - \sum_i p(h_i', r, t_i') \log\sigma(d_r(h_i', t_i') - \gamma)$$

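where $d_r(h, t) = \|\mathbf{h} \circ \mathbf{r} - \mathbf{t}\|$ is the distance (so $f = -d_r$), $\sigma$ is the sigmoid function, and $p(h', r, t') \propto \exp(\alpha \cdot f(h', r, t'))$ weights each negative triple $(h', r, t')$ by how plausible the current model finds it, with temperature $\alpha$. In the original paper this distribution is applied as a weight on uniformly sampled negatives rather than used to draw them. The implementation in Section 4.2 uses the simpler margin-ranking loss instead.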
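
The loss for a single positive triple can be sketched in plain Python as follows. The function names and the example distances are hypothetical, and the weights are treated as constants, which in a tensor implementation would typically mean detaching them from the gradient.

```python
import math


def log_sigmoid(x):
    """Numerically stable log(sigmoid(x))."""
    return -math.log1p(math.exp(-x)) if x >= 0 else x - math.log1p(math.exp(x))


def self_adversarial_loss(pos_dist, neg_dists, gamma=6.0, alpha=1.0):
    """Loss for one positive triple and its negatives, given distances d_r."""
    # Weights p_i = softmax(alpha * f_i) with f_i = -d_i (shifted by the max for stability)
    logits = [-alpha * d for d in neg_dists]
    m = max(logits)
    exps = [math.exp(l - m) for l in logits]
    total = sum(exps)
    weights = [e / total for e in exps]

    pos_term = -log_sigmoid(gamma - pos_dist)
    neg_term = -sum(w * log_sigmoid(d - gamma) for w, d in zip(weights, neg_dists))
    return pos_term + neg_term, weights


# Hypothetical distances: one positive, one hard negative (2.0), one easy negative (9.0)
loss, weights = self_adversarial_loss(pos_dist=1.0, neg_dists=[2.0, 9.0])
print(loss, weights)
```

The hard negative (small distance, hence high score) receives almost all of the weight; with `alpha=0.0` the weights become uniform and the loss reduces to ordinary negative sampling.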

### 4.2 RotatE Implementation

In [None]:
class RotatE(nn.Module):
    def __init__(self, num_entities, num_relations, embedding_dim, margin=6.0):
        super().__init__()
        assert embedding_dim % 2 == 0, 'embedding_dim must be even (real + imag)'
        self.dim = embedding_dim // 2  # complex dimension
        self.margin = margin

        # Store real and imaginary parts separately
        self.entity_emb_re = nn.Embedding(num_entities, self.dim)
        self.entity_emb_im = nn.Embedding(num_entities, self.dim)

        # Only angle for relations (|r| = 1 enforced by using cos/sin)
        self.relation_phase = nn.Embedding(num_relations, self.dim)

        range_init = (6 / self.dim) ** 0.5
        nn.init.uniform_(self.entity_emb_re.weight, -range_init, range_init)
        nn.init.uniform_(self.entity_emb_im.weight, -range_init, range_init)
        nn.init.uniform_(self.relation_phase.weight, -torch.pi, torch.pi)

    def score(self, h_idx, r_idx, t_idx):
        h_re = self.entity_emb_re(h_idx)
        h_im = self.entity_emb_im(h_idx)
        t_re = self.entity_emb_re(t_idx)
        t_im = self.entity_emb_im(t_idx)

        # Relation as unit complex: r = cos(θ) + i sin(θ)
        theta = self.relation_phase(r_idx)
        r_re = torch.cos(theta)
        r_im = torch.sin(theta)

        # h ○ r = (h_re * r_re - h_im * r_im) + i(h_re * r_im + h_im * r_re)
        hr_re = h_re * r_re - h_im * r_im
        hr_im = h_re * r_im + h_im * r_re

        # Distance: ||h ○ r - t||
        diff_re = hr_re - t_re
        diff_im = hr_im - t_im
        return torch.sqrt((diff_re ** 2 + diff_im ** 2).sum(dim=1) + 1e-8)

    def forward(self, pos_triples, neg_triples):
        ph, pr, pt = pos_triples[:, 0], pos_triples[:, 1], pos_triples[:, 2]
        nh, nr, nt = neg_triples[:, 0], neg_triples[:, 1], neg_triples[:, 2]
        pos_score = self.score(ph, pr, pt)
        neg_score = self.score(nh, nr, nt)
        return F.relu(self.margin + pos_score - neg_score).mean()


model_rotate = RotatE(num_entities, num_relations, embedding_dim=10, margin=1.0)
optimizer_rot = optim.Adam(model_rotate.parameters(), lr=0.01)

for epoch in range(500):
    neg = corrupt_triples(triples_tensor, num_entities)
    optimizer_rot.zero_grad()
    loss = model_rotate(triples_tensor, neg)
    loss.backward()
    optimizer_rot.step()

print(f'RotatE final loss: {loss.item():.4f}')

### 4.3 Visualising RotatE Relation Phases

In [None]:
model_rotate.eval()
with torch.no_grad():
    phases = model_rotate.relation_phase.weight.cpu()  # [num_relations, dim]

id2relation = {v: k for k, v in relation2id.items()}

fig, axes = plt.subplots(1, num_relations, figsize=(4 * num_relations, 4))
for r_id in range(num_relations):
    theta = phases[r_id].numpy()  # angles
    x = np.cos(theta)
    y = np.sin(theta)
    ax = axes[r_id]
    circle = plt.Circle((0, 0), 1, color='lightgray', fill=False)
    ax.add_patch(circle)
    ax.scatter(x, y, s=50)
    for i, (xi, yi) in enumerate(zip(x, y)):
        ax.annotate(str(i), (xi, yi), fontsize=7)
    ax.set_xlim(-1.3, 1.3); ax.set_ylim(-1.3, 1.3)
    ax.set_aspect('equal')
    ax.set_title(f'Relation: {id2relation[r_id]}')

plt.suptitle('RotatE: Rotation angles per relation dimension')
plt.tight_layout()
plt.show()

---
## 5. ComplEx

### 5.1 Theory

**ComplEx** (Trouillon et al., 2016) also uses **complex-valued** embeddings, but instead of a distance it scores triples with a trilinear product based on the **Hermitian inner product**.

All entities and relations are embedded in $\mathbb{C}^d$:
$$\mathbf{h}, \mathbf{r}, \mathbf{t} \in \mathbb{C}^d$$

#### Scoring Function

$$f(h, r, t) = \text{Re}(\langle \mathbf{h}, \mathbf{r}, \bar{\mathbf{t}} \rangle) = \text{Re}\!\left(\sum_k h_k \cdot r_k \cdot \overline{t_k}\right)$$

where $\bar{\mathbf{t}}$ is the **complex conjugate** of $\mathbf{t}$.

#### Why Complex Numbers?

The Hermitian inner product is **asymmetric**: $\langle h, r, \bar{t} \rangle \neq \langle t, r, \bar{h} \rangle$ in general, which allows modelling **asymmetric** and **antisymmetric** relations.

#### Connection to Other Models

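ComplEx can be written in real coordinates as the four-term sum below. Setting every imaginary part to zero leaves only the first term, which is the DistMult score, so DistMult is the real-valued special case of ComplEx.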

$$f(h, r, t) = \mathbf{h}_{re}^\top \text{diag}(\mathbf{r}_{re}) \mathbf{t}_{re}
+ \mathbf{h}_{re}^\top \text{diag}(\mathbf{r}_{im}) \mathbf{t}_{im}
+ \mathbf{h}_{im}^\top \text{diag}(\mathbf{r}_{re}) \mathbf{t}_{im}
- \mathbf{h}_{im}^\top \text{diag}(\mathbf{r}_{im}) \mathbf{t}_{re}$$
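
Both the real-coordinate expansion and the (a)symmetry behaviour can be checked with Python's built-in complex type. The embeddings below are arbitrary illustrative values, and the two function names are hypothetical.

```python
def complex_score(h, r, t):
    """Re(sum_k h_k * r_k * conj(t_k)) for lists of Python complex numbers."""
    return sum(hk * rk * tk.conjugate() for hk, rk, tk in zip(h, r, t)).real


def complex_score_real(h, r, t):
    """The same score, expanded into the four real-valued terms."""
    return sum(
        hk.real * rk.real * tk.real + hk.real * rk.imag * tk.imag
        + hk.imag * rk.real * tk.imag - hk.imag * rk.imag * tk.real
        for hk, rk, tk in zip(h, r, t)
    )


# Hypothetical embeddings in C^2
h = [1 + 2j, 0.5 - 1j]
t = [-1 + 1j, 2 + 0.5j]

r_full = [0.7 + 0.3j, -1.2 + 0.4j]   # general relation
r_real = [0.7 + 0j, -1.2 + 0j]       # purely real relation
r_imag = [0.7j, -1.2j]               # purely imaginary relation

# The expansion agrees with the direct complex computation
print(complex_score(h, r_full, t), complex_score_real(h, r_full, t))
# Real relation: swapping head and tail leaves the score unchanged (symmetric)
print(complex_score(h, r_real, t), complex_score(t, r_real, h))
# Imaginary relation: swapping head and tail flips the sign (antisymmetric)
print(complex_score(h, r_imag, t), complex_score(t, r_imag, h))
```

A general relation mixes both parts, so the model can learn per relation how symmetric it should be.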

### 5.2 ComplEx Implementation

In [None]:
class ComplEx(nn.Module):
    def __init__(self, num_entities, num_relations, embedding_dim, reg_lambda=1e-3):
        super().__init__()
        self.reg_lambda = reg_lambda

        # Real and imaginary parts
        self.ent_re  = nn.Embedding(num_entities,  embedding_dim)
        self.ent_im  = nn.Embedding(num_entities,  embedding_dim)
        self.rel_re  = nn.Embedding(num_relations, embedding_dim)
        self.rel_im  = nn.Embedding(num_relations, embedding_dim)

        for emb in [self.ent_re, self.ent_im, self.rel_re, self.rel_im]:
            nn.init.xavier_uniform_(emb.weight)

    def score(self, h_idx, r_idx, t_idx):
        h_re = self.ent_re(h_idx);  h_im = self.ent_im(h_idx)
        r_re = self.rel_re(r_idx);  r_im = self.rel_im(r_idx)
        t_re = self.ent_re(t_idx);  t_im = self.ent_im(t_idx)

        # Re(<h, r, conj(t)>)
        return (
            (h_re * r_re * t_re).sum(dim=1)
          + (h_re * r_im * t_im).sum(dim=1)
          + (h_im * r_re * t_im).sum(dim=1)
          - (h_im * r_im * t_re).sum(dim=1)
        )

    def regularisation(self, h_idx, r_idx, t_idx):
        # Squared L2 (not L3) regularisation on the entity and relation embeddings in the batch
        reg = (
            self.ent_re(h_idx).norm(p=2, dim=1).pow(2).mean() +
            self.ent_im(h_idx).norm(p=2, dim=1).pow(2).mean() +
            self.rel_re(r_idx).norm(p=2, dim=1).pow(2).mean() +
            self.rel_im(r_idx).norm(p=2, dim=1).pow(2).mean() +
            self.ent_re(t_idx).norm(p=2, dim=1).pow(2).mean() +
            self.ent_im(t_idx).norm(p=2, dim=1).pow(2).mean()
        )
        return self.reg_lambda * reg

    def forward(self, triples, labels):
        """Binary cross-entropy loss with labels +1 / -1."""
        h_idx, r_idx, t_idx = triples[:, 0], triples[:, 1], triples[:, 2]
        scores = self.score(h_idx, r_idx, t_idx)
        # Sigmoid BCE
        loss = F.softplus(-labels * scores).mean()
        loss += self.regularisation(h_idx, r_idx, t_idx)
        return loss


# Combine positive (+1) and negative (-1) triples
neg_triples_init = corrupt_triples(triples_tensor, num_entities)
all_triples  = torch.cat([triples_tensor, neg_triples_init], dim=0)
all_labels   = torch.cat([
    torch.ones(len(triples_tensor)),
    -torch.ones(len(neg_triples_init))
])

model_complex = ComplEx(num_entities, num_relations, embedding_dim=10)
optimizer_c   = optim.Adam(model_complex.parameters(), lr=0.01)

for epoch in range(500):
    neg = corrupt_triples(triples_tensor, num_entities)
    all_t = torch.cat([triples_tensor, neg], dim=0)
    all_l = torch.cat([torch.ones(len(triples_tensor)), -torch.ones(len(neg))])
    optimizer_c.zero_grad()
    loss = model_complex(all_t, all_l)
    loss.backward()
    optimizer_c.step()

print(f'ComplEx final loss: {loss.item():.4f}')

---
## 6. LiteralE

### 6.1 Theory

**LiteralE** (Kristiadi et al., 2018) extends KG embeddings to incorporate **literal attributes** (numerical or textual) associated with entities.

#### Motivation

Standard KG embeddings only use structural information (triples). However, KGs like Wikidata contain literals such as:
```
(Einstein, dateOfBirth, "1879-03-14")
(Einstein, height, 1.75)
```

These literals provide valuable signals for link prediction.

#### Core Idea

LiteralE modifies entity embeddings by incorporating literal information via a **gate function** $g$:

$$\tilde{\mathbf{e}} = g(\mathbf{e}, \mathbf{l}_e)$$

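where $\mathbf{l}_e \in \mathbb{R}^L$ is a vector of literals for entity $e$ and $g$ is a learned transformation, for example a single gating layer over both inputs:

$$g(\mathbf{e}, \mathbf{l}_e) = \sigma\!\left(\mathbf{W}_1 \mathbf{e} + \mathbf{W}_2 \mathbf{l}_e + \mathbf{b}\right)$$

Here $\sigma$ is an element-wise nonlinearity; the implementation in Section 6.2 uses $\tanh$ and realises $\mathbf{W}_1$ and $\mathbf{W}_2$ as one linear layer over the concatenation $[\mathbf{e}; \mathbf{l}_e]$.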

The enriched embedding $\tilde{\mathbf{e}}$ is then used in any base scoring function (e.g., DistMult, ComplEx).
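
A richer alternative to a single layer is a GRU-style gate that interpolates between the original embedding and a transformed one, $g(\mathbf{e}, \mathbf{l}) = \mathbf{z} \odot \mathbf{h} + (1 - \mathbf{z}) \odot \mathbf{e}$. The sketch below assumes this form, with $\mathbf{z} = \text{sigmoid}(\mathbf{W}_{ze}\mathbf{e} + \mathbf{W}_{zl}\mathbf{l} + \mathbf{b})$ and $\mathbf{h} = \tanh(\mathbf{W}_h[\mathbf{e}; \mathbf{l}])$. The parameter names and shapes are assumptions for illustration, and plain Python lists stand in for tensors.

```python
import math


def matvec(W, x):
    """Matrix-vector product for a matrix stored as a list of rows."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]


def gated_enrich(e, l, W_ze, W_zl, b, W_h):
    """g(e, l) = z * h + (1 - z) * e, computed element-wise."""
    z_pre = [a + c + bi for a, c, bi in zip(matvec(W_ze, e), matvec(W_zl, l), b)]
    z = [1.0 / (1.0 + math.exp(-v)) for v in z_pre]   # gate in (0, 1)
    h = [math.tanh(v) for v in matvec(W_h, e + l)]    # e + l concatenates the two lists
    return [zi * hi + (1.0 - zi) * ei for zi, hi, ei in zip(z, h, e)]


# Hypothetical sizes: embedding dimension 2, three literals
e = [1.0, -2.0]
l = [0.5, 0.1, 0.2]
zeros = lambda rows, cols: [[0.0] * cols for _ in range(rows)]

# With all weights at zero the gate is 0.5 and h is 0, so the output is e / 2
half = gated_enrich(e, l, zeros(2, 2), zeros(2, 3), [0.0, 0.0], zeros(2, 5))
# A strongly negative bias closes the gate and passes e through almost unchanged
passthrough = gated_enrich(e, l, zeros(2, 2), zeros(2, 3), [-30.0, -30.0], zeros(2, 5))
print(half, passthrough)
```

The interpolation lets the model fall back to the purely structural embedding for entities whose literals are uninformative.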

#### LiteralE variants
- **LiteralE (DistMult)**: base model = DistMult
- **LiteralE (ComplEx)**: base model = ComplEx
- **Numerical gate** vs **Text gate** (for different literal types)

### 6.2 LiteralE Implementation

In [None]:
class LiteralE(nn.Module):
    """LiteralE with DistMult as base model."""

    def __init__(self, num_entities, num_relations, embedding_dim, num_literals):
        super().__init__()
        self.entity_emb  = nn.Embedding(num_entities,  embedding_dim)
        self.relation_emb = nn.Embedding(num_relations, embedding_dim)

        # Gate: transforms [entity_emb || literals] -> entity_dim
        self.gate = nn.Sequential(
            nn.Linear(embedding_dim + num_literals, embedding_dim),
            nn.Tanh()
        )

        nn.init.xavier_uniform_(self.entity_emb.weight)
        nn.init.xavier_uniform_(self.relation_emb.weight)

    def enrich(self, entity_idx, literals):
        """Enrich entity embedding with literal information."""
        e = self.entity_emb(entity_idx)       # [B, d]
        l = literals[entity_idx]              # [B, L]
        combined = torch.cat([e, l], dim=-1)  # [B, d+L]
        return self.gate(combined)             # [B, d]

    def score(self, h_idx, r_idx, t_idx, literals):
        """DistMult scoring with enriched embeddings."""
        h_tilde = self.enrich(h_idx, literals)
        t_tilde = self.enrich(t_idx, literals)
        r = self.relation_emb(r_idx)
        # DistMult: <h, r, t>
        return (h_tilde * r * t_tilde).sum(dim=1)

    def forward(self, triples, labels, literals):
        h_idx, r_idx, t_idx = triples[:, 0], triples[:, 1], triples[:, 2]
        scores = self.score(h_idx, r_idx, t_idx, literals)
        return F.softplus(-labels * scores).mean()


# Toy literals: 3 numerical attributes per entity (e.g., age, height, income)
torch.manual_seed(0)
literals = torch.randn(num_entities, 3)  # [num_entities, 3]

model_literale = LiteralE(num_entities, num_relations,
                           embedding_dim=10, num_literals=3)
optimizer_le   = optim.Adam(model_literale.parameters(), lr=0.01)

for epoch in range(500):
    neg = corrupt_triples(triples_tensor, num_entities)
    all_t = torch.cat([triples_tensor, neg], dim=0)
    all_l = torch.cat([torch.ones(len(triples_tensor)), -torch.ones(len(neg))])
    optimizer_le.zero_grad()
    loss = model_literale(all_t, all_l, literals)
    loss.backward()
    optimizer_le.step()

print(f'LiteralE final loss: {loss.item():.4f}')

---
## 7. Heterogeneous Graphs

### 7.1 Theory

A **Heterogeneous Graph** is a graph where nodes and/or edges can be of **multiple types**:

$$G = (V, E, \phi, \psi)$$

- $\phi: V \to \mathcal{A}$ is the node type mapping
- $\psi: E \to \mathcal{R}$ is the edge (relation) type mapping

with $|\mathcal{A}| + |\mathcal{R}| > 2$, i.e. there is more than one node type or more than one edge type.

**Example:** An academic network with:
- Node types: `Author`, `Paper`, `Venue`
- Edge types: `writes`, `cites`, `publishedIn`

#### Meta-path

A **meta-path** is a composite relation defined over a sequence of node and edge types:
$$\mathcal{P} = A_1 \xrightarrow{R_1} A_2 \xrightarrow{R_2} \cdots \xrightarrow{R_l} A_{l+1}$$

Example: `Author → writes → Paper → cites → Paper` links each author to the papers cited by their own papers.
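
Following a meta-path amounts to chaining typed adjacency lookups. The sketch below does this in plain Python on a small hypothetical academic graph; the edge lists and the function name `follow_metapath` are made up for illustration.

```python
# Typed edge lists, keyed by (source type, relation, target type)
edges = {
    ("author", "writes", "paper"): [(0, 0), (1, 1), (2, 2), (3, 3), (0, 4)],
    ("paper", "cites", "paper"): [(0, 1), (1, 2), (2, 3), (3, 4)],
}


def follow_metapath(edges, metapath, start_nodes):
    """Return {start node: set of nodes reached by following the meta-path}."""
    reached = {s: {s} for s in start_nodes}
    for edge_type in metapath:
        # Build the adjacency for this edge type only
        adj = {}
        for u, v in edges[edge_type]:
            adj.setdefault(u, set()).add(v)
        # Advance every frontier by one hop along this edge type
        reached = {
            s: set().union(*(adj.get(u, set()) for u in frontier))
            for s, frontier in reached.items()
        }
    return reached


metapath = [("author", "writes", "paper"), ("paper", "cites", "paper")]
cited_by_author = follow_metapath(edges, metapath, start_nodes=range(4))
print(cited_by_author)  # {0: {1}, 1: {2}, 2: {3}, 3: {4}}
```

Author 0 wrote papers 0 and 4, but only paper 0 cites anything, so the meta-path reaches paper 1 alone.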

#### Why Heterogeneous GNNs?

Standard GNNs treat all nodes/edges equally. Heterogeneous GNNs maintain **type-specific transformations**:
$$\mathbf{h}_v^{(l+1)} = \text{AGG}\Bigl(\{\mathbf{W}_{\phi(v), \psi(e)} \mathbf{h}_u^{(l)} : e = (u, v) \in E\}\Bigr)$$

where the weight matrix is selected by the type of the target node $v$ and the type of the incoming edge $e$.

### 7.2 Heterogeneous Graph with PyG

In [None]:
if pyg_available:
    from torch_geometric.data import HeteroData

    # Build a small academic heterogeneous graph
    hdata = HeteroData()

    # Node types
    hdata['author'].x  = torch.randn(4, 8)   # 4 authors, 8 features
    hdata['paper'].x   = torch.randn(6, 16)  # 6 papers, 16 features
    hdata['venue'].x   = torch.randn(2, 4)   # 2 venues, 4 features

    # Edge types: (src_type, edge_type, dst_type)
    # author -writes-> paper
    hdata['author', 'writes', 'paper'].edge_index = torch.tensor([
        [0, 1, 2, 3, 0],
        [0, 1, 2, 3, 4]
    ])
    # paper -cites-> paper
    hdata['paper', 'cites', 'paper'].edge_index = torch.tensor([
        [0, 1, 2, 3],
        [1, 2, 3, 4]
    ])
    # paper -publishedIn-> venue
    hdata['paper', 'publishedIn', 'venue'].edge_index = torch.tensor([
        [0, 1, 2, 3, 4, 5],
        [0, 0, 1, 1, 0, 1]
    ])

    print(hdata)
    print('Node types:', hdata.node_types)
    print('Edge types:', hdata.edge_types)

### 7.3 Heterogeneous GNN with `HeteroConv`

In [None]:
if pyg_available:
    from torch_geometric.nn import HeteroConv, SAGEConv as _SAGEConv

    class HeteroGNN(nn.Module):
        def __init__(self, hidden_channels, out_channels):
            super().__init__()
            # Project each node type to the same hidden dimension
            self.proj = nn.ModuleDict({
                'author': nn.Linear(8,  hidden_channels),
                'paper' : nn.Linear(16, hidden_channels),
                'venue' : nn.Linear(4,  hidden_channels),
            })

            # One SAGEConv per edge type
            self.conv = HeteroConv({
                ('author', 'writes',      'paper'): _SAGEConv(hidden_channels, hidden_channels),
                ('paper',  'cites',       'paper'): _SAGEConv(hidden_channels, hidden_channels),
                ('paper',  'publishedIn', 'venue'): _SAGEConv(hidden_channels, hidden_channels),
            }, aggr='sum')

            self.lin = nn.Linear(hidden_channels, out_channels)

        def forward(self, x_dict, edge_index_dict):
            # Project all node types
            x_dict = {ntype: F.relu(self.proj[ntype](x))
                      for ntype, x in x_dict.items()}
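            # Heterogeneous convolution. HeteroConv only returns node types that are
            # the destination of at least one edge type, so 'author' is missing from
            # the output here (reverse edges would be needed to update authors).
            x_dict = self.conv(x_dict, edge_index_dict)
            # Output head: map every returned node type to out_channels
            return {k: self.lin(F.relu(v)) for k, v in x_dict.items()}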


    hetero_model = HeteroGNN(hidden_channels=32, out_channels=4)
    out = hetero_model(hdata.x_dict, hdata.edge_index_dict)
    for ntype, emb in out.items():
        print(f'{ntype}: {emb.shape}')

### 7.4 OGB-MAG Dataset (Large-Scale Heterogeneous Graph)

In [None]:
# Note: OGB datasets are large; this cell shows the API but does not download
print("""To load the OGB-MAG (Microsoft Academic Graph) dataset:

    from torch_geometric.datasets import OGB_MAG
    dataset = OGB_MAG(root='/tmp/mag', preprocess='metapath2vec')
    data = dataset[0]
    print(data)  # HeteroData with 4 node types and 4 edge types

This heterogeneous graph has:
  - 4 node types: paper, author, institution, field_of_study
  - 4 edge types: cites, writes, affiliated_with, has_topic
  - Task: node classification on papers into 349 classes
""")

---
## 8. Exercises

### Exercise 1 — Link Prediction Evaluation

Implement a **ranking-based evaluation** function:
1. For each test triple $(h, r, t)$, compute scores for all possible tails: $\{(h, r, t') : t' \in \mathcal{E}\}$.
2. Rank the correct tail $t$ among all candidates.
3. Compute **MRR** and **Hits@1**, **Hits@3**, **Hits@10** for TransE, RotatE, and ComplEx on the toy KG. (With only 6 entities, Hits@10 is trivially 1.)

In [None]:
def evaluate_kg(model, triples, num_entities, model_type='transe'):
    """Evaluate KG model via MRR and Hits@k."""
    model.eval()
    ranks = []
    with torch.no_grad():
        for triple in triples:
            h, r, t = triple
            # TODO: score all possible tails, rank the correct one
            ...
    # TODO: compute MRR and Hits@k
    ...

### Exercise 2 — DistMult

Implement **DistMult** (Yang et al., 2015):

$$f(h, r, t) = \mathbf{h}^\top \text{diag}(\mathbf{r}) \mathbf{t} = \sum_k h_k r_k t_k$$

Note: DistMult is a special case of ComplEx where all embeddings are real-valued. Compare its link prediction performance with TransE on the toy KG.

In [None]:
# Exercise 2 — your solution here
class DistMult(nn.Module):
    def __init__(self, num_entities, num_relations, embedding_dim):
        super().__init__()
        # TODO
        ...

    def score(self, h_idx, r_idx, t_idx):
        # TODO: <h, r, t>
        ...

    def forward(self, triples, labels):
        ...

### Exercise 3 — Relation Pattern Analysis

Using the RotatE model trained above:
1. Find the angle $\theta$ that best approximates the `knows` relation (each $\theta_{r,k}$ should be close to $0$ or $\pi$ if the relation is symmetric).
2. Check whether the relation `livesIn` has a different angle profile.
3. Modify the training loop to explicitly include a **symmetric** triple pair: $(Alice, knows, Bob)$ and $(Bob, knows, Alice)$. Retrain and compare the angles.

In [None]:
# Exercise 3 — your solution here
...

### Exercise 4 — LiteralE with Text Features

Extend the `LiteralE` implementation to support **text literals** by:
1. Adding a `TextEncoder` module (e.g., a simple bag-of-words encoder or pre-computed TF-IDF vectors).
2. Concatenating numerical and text literal vectors before the gate.
3. Testing on the toy KG with added entity descriptions.

In [None]:
# Exercise 4 — your solution here
# Entity descriptions (simplified bag-of-words vectors)
descriptions = {
    'Alice': 'software engineer based in london',
    'Bob':   'data scientist living in paris',
    'Carol': 'researcher at university',
    'London': 'capital city of the united kingdom',
    'Paris':  'capital city of france',
    'UK':     'country in northwestern europe',
}
# TODO: encode descriptions and extend LiteralE
...

### Exercise 5 — Heterogeneous Link Prediction

Using the academic heterogeneous graph:
1. Add a **reverse edge type** for each edge type (e.g., `paper -written_by-> author`).
2. Build a 2-layer `HeteroConv` model.
3. Add a link prediction head: given a pair of `author` nodes, predict whether they collaborated (both wrote the same paper).
4. Train with binary cross-entropy and report AUC.

In [None]:
# Exercise 5 — your solution here
...

---
## Summary

| Model | Space | Scoring Function | Key Strength |
|-------|-------|------------------|--------------|
| **TransE** | $\mathbb{R}^d$ | $-\lVert\mathbf{h}+\mathbf{r}-\mathbf{t}\rVert$ | Simple, effective for 1-to-1 relations |
| **TransR** | $\mathbb{R}^{d_e}, \mathbb{R}^{d_r}$ | $-\lVert\mathbf{h}_r+\mathbf{r}-\mathbf{t}_r\rVert_2^2$ | Handles 1-to-N via projection |
| **RotatE** | $\mathbb{C}^d$ | $-\lVert\mathbf{h}\circ\mathbf{r}-\mathbf{t}\rVert$ | Captures symmetry, antisymmetry, inversion, composition |
| **ComplEx** | $\mathbb{C}^d$ | $\text{Re}(\langle\mathbf{h},\mathbf{r},\bar{\mathbf{t}}\rangle)$ | Asymmetric relations, strong at large-scale |
| **LiteralE** | $\mathbb{R}^d$ | Any + gate | Incorporates entity attributes |
| **Hetero Graphs** | Multi-type | Type-specific transforms | Multiple node/edge types |

**Next notebook →** `04_advanced_gnns.ipynb` — Graph Transformer, Heterogeneous Graph Transformer, and R-GCN.