# Lecture 1: Heterogeneous Graphs & Typed Message Passing (PyTorch Geometric)

This notebook is designed for a graduate-level Graph ML course.

**Learning goals**

By the end of this notebook, you should be able to:

- Model real systems as heterogeneous graphs from a network science perspective.
- Understand the formal definition of a heterogeneous graph.
- Implement a toy heterogeneous academic graph using `torch_geometric.data.HeteroData`.
- Train a simple heterogeneous GNN for paper node classification using `to_hetero`.


## 1. Heterogeneous Graphs: Intuition & Definition

In many network science examples, we model a system as a *homogeneous* graph:

- One node type (e.g., "person")
- One edge type (e.g., "friendship")

But real systems are often richer:

- Academic network: **authors – papers – institutions – fields**
- E-commerce: **users – items – categories**
- Biology: **genes – proteins – diseases – drugs**

A **heterogeneous graph** (a.k.a. *heterogeneous information network*) is:

\begin{equation}
G = (V, E, \phi, \psi)
\end{equation}

- \(V\): set of nodes
- \(E\): set of edges
- \(\phi: V \rightarrow \mathcal{T}_V\): node type mapping
- \(\psi: E \rightarrow \mathcal{T}_E\): edge (relation) type mapping

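Here, \(\mathcal{T}_V\) is the set of node types and \(\mathcal{T}_E\) is the set of edge types. The graph is heterogeneous when \(|\mathcal{T}_V| + |\mathcal{T}_E| > 2\); with exactly one node type and one edge type, the definition reduces to a homogeneous graph.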
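To make the definition concrete, here is a minimal sketch in plain Python, with no PyG involved. The node names and edges are hypothetical and serve only to illustrate the mappings \(\phi\) (`node_type`) and \(\psi\) (`edge_type`).

```python
from collections import defaultdict

# Hypothetical toy graph: node ids are strings, edges are (source, target) pairs.
node_type = {  # phi: V -> T_V
    "alice": "author",
    "bob": "author",
    "p1": "paper",
    "p2": "paper",
    "mit": "institution",
}
edge_type = {  # psi: E -> T_E
    ("alice", "p1"): "writes",
    ("bob", "p1"): "writes",
    ("bob", "p2"): "writes",
    ("p2", "p1"): "cites",
    ("alice", "mit"): "affiliated_with",
}

node_types = set(node_type.values())  # T_V
edge_types = set(edge_type.values())  # T_E

# Group edges by relation: one edge list per relation type.
edges_by_relation = defaultdict(list)
for edge, relation in edge_type.items():
    edges_by_relation[relation].append(edge)

# Collect the (source type, target type) pairs that each relation connects.
signatures = {
    relation: {(node_type[u], node_type[v]) for u, v in edges}
    for relation, edges in edges_by_relation.items()
}
print(signatures)
```

In this toy graph every relation connects exactly one (source type, target type) pair, which is the same information PyG encodes in its `(source, relation, target)` edge-type triplets.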

You can think of each relation type as a different **layer** in a multilayer network.

### Typed Message Passing

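In a standard (homogeneous) GNN layer, every node aggregates its neighbors' features using one shared weight matrix \(W^{(l)}\), a normalization constant \(c_{ij}\) (e.g., the degree of node \(i\)), and a nonlinearity \(\sigma\):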

\begin{equation}
\mathbf{h}_i^{(l+1)} = \sigma\left( \sum_{j \in \mathcal{N}(i)} \frac{1}{c_{ij}} W^{(l)} \mathbf{h}_j^{(l)} \right)
\end{equation}

For a **heterogeneous** graph, neighbors come via *different* relation types. For each relation \(r \in \mathcal{T}_E\) we use a relation-specific weight matrix \(W_r^{(l)}\):

\begin{equation}
\mathbf{h}_i^{(l+1)} = \sigma\left(
    \sum_{r \in \mathcal{T}_E}
    \sum_{j \in \mathcal{N}_r(i)}
    \frac{1}{c_{ijr}} W_r^{(l)} \mathbf{h}_j^{(l)}
\right).
\end{equation}

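- \(\mathcal{N}_r(i)\): neighbors of node \(i\) through relation \(r\)
- \(c_{ijr}\): relation-specific normalization constant (e.g., \(|\mathcal{N}_r(i)|\) for mean aggregation)
- Each relation acts as a separate "channel" of information with its own parameters; the channels are summed before the nonlinearity.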
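The typed update can be sketched directly in plain PyTorch. This is a minimal illustration, not how PyG implements it: the relation names, edge lists, and the choice \(c_{ijr} = |\mathcal{N}_r(i)|\) are assumptions made for the example, and all nodes share one index space and feature dimension for simplicity.

```python
import torch

torch.manual_seed(0)

num_nodes, in_dim, out_dim = 5, 4, 3
h = torch.randn(num_nodes, in_dim)  # h_j^{(l)} for all nodes

# Hypothetical edge lists per relation, shape [2, num_edges]: row 0 = source j, row 1 = target i.
edge_index_by_relation = {
    "cites": torch.tensor([[0, 1, 2], [1, 2, 3]]),
    "same_venue": torch.tensor([[3, 4], [0, 0]]),
}

# One weight matrix W_r per relation.
weights = {r: torch.randn(out_dim, in_dim) for r in edge_index_by_relation}


def typed_message_passing(h, edge_index_by_relation, weights, out_dim):
    out = torch.zeros(h.size(0), out_dim)
    for r, edge_index in edge_index_by_relation.items():
        src, dst = edge_index
        # Assumed normalization: c_{ijr} = |N_r(i)|, the number of type-r edges into node i.
        deg = torch.bincount(dst, minlength=h.size(0)).clamp(min=1).float()
        messages = (h[src] @ weights[r].T) / deg[dst].unsqueeze(-1)
        out.index_add_(0, dst, messages)  # sum over j in N_r(i), accumulated over r
    return torch.relu(out)  # sigma


h_next = typed_message_passing(h, edge_index_by_relation, weights, out_dim)
print(h_next.shape)
```

Nodes with no incoming edges under any relation (node 4 here) end up with an all-zero representation, which is one reason practical layers usually add a self-connection term.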


## 2. Setup

Install and import dependencies. In a fresh environment you may need to install `torch` and `torch-geometric` following the official instructions.

> **Note:** The installation commands are commented out; uncomment and adapt them as needed in your own environment.


In [6]:
# !pip install torch torch_geometric -q

import torch
from torch_geometric.data import HeteroData
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import SAGEConv, to_hetero
import torch.optim as optim

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cpu')

## 3. Build a Toy Academic Heterogeneous Graph

We will build a small academic-style network with three node types:

- `author`
- `paper`
- `institution`

Edges will include:

- `author` —**writes**→ `paper`
- `paper` —**written_by**→ `author` (reverse)
- `author` —**affiliated_with**→ `institution`
- `institution` —**has_member**→ `author` (reverse)
- `paper` —**cites**→ `paper`

In [7]:
# Create a HeteroData object
data = HeteroData()

# Define sizes
num_authors = 100
num_papers = 200
num_institutions = 10

# Node features (16-dim for each type)
data['author'].x = torch.randn(num_authors, 16)
data['paper'].x = torch.randn(num_papers, 16)
data['institution'].x = torch.randn(num_institutions, 16)

# Edges: author -> paper (writes)
num_auth_paper_edges = 400
author_ids = torch.randint(0, num_authors, (num_auth_paper_edges,))
paper_ids  = torch.randint(0, num_papers, (num_auth_paper_edges,))
data[('author', 'writes', 'paper')].edge_index = torch.stack([author_ids, paper_ids], dim=0)

# Reverse edges: paper -> author (written_by)
data[('paper', 'written_by', 'author')].edge_index = torch.stack([paper_ids, author_ids], dim=0)

# Edges: author -> institution (affiliated_with)
num_auth_inst_edges = num_authors  # one institution per author for simplicity
author_ids2 = torch.arange(num_authors)
inst_ids = torch.randint(0, num_institutions, (num_auth_inst_edges,))
data[('author', 'affiliated_with', 'institution')].edge_index = torch.stack([author_ids2, inst_ids], dim=0)

# Reverse edges: institution -> author (has_member)
data[('institution', 'has_member', 'author')].edge_index = torch.stack([inst_ids, author_ids2], dim=0)

# Edges: paper -> paper (cites)
num_citations = 500
src_papers = torch.randint(0, num_papers, (num_citations,))
dst_papers = torch.randint(0, num_papers, (num_citations,))
data[('paper', 'cites', 'paper')].edge_index = torch.stack([src_papers, dst_papers], dim=0)

data

HeteroData(
  author={ x=[100, 16] },
  paper={ x=[200, 16] },
  institution={ x=[10, 16] },
  (author, writes, paper)={ edge_index=[2, 400] },
  (paper, written_by, author)={ edge_index=[2, 400] },
  (author, affiliated_with, institution)={ edge_index=[2, 100] },
  (institution, has_member, author)={ edge_index=[2, 100] },
  (paper, cites, paper)={ edge_index=[2, 500] }
)

### Create a Paper Classification Task

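We assign each paper a random label from 0 to 3 (simulating 4 research fields) and create train/val/test masks with a 60/20/20 split. Because the labels are drawn independently of the features and the graph structure, validation accuracy should stay near chance (about 25%); the task exercises the pipeline rather than the model's predictive power.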

In [8]:
num_classes = 4
num_papers = data['paper'].x.size(0)

# Random labels for papers
data['paper'].y = torch.randint(0, num_classes, (num_papers,))

# Train/val/test masks for papers
train_ratio, val_ratio = 0.6, 0.2
perm = torch.randperm(num_papers)
train_end = int(train_ratio * num_papers)
val_end = int((train_ratio + val_ratio) * num_papers)

train_mask = torch.zeros(num_papers, dtype=torch.bool)
val_mask = torch.zeros(num_papers, dtype=torch.bool)
test_mask = torch.zeros(num_papers, dtype=torch.bool)

train_mask[perm[:train_end]] = True
val_mask[perm[train_end:val_end]] = True
test_mask[perm[val_end:]] = True

data['paper'].train_mask = train_mask
data['paper'].val_mask = val_mask
data['paper'].test_mask = test_mask

data

HeteroData(
  author={ x=[100, 16] },
  paper={
    x=[200, 16],
    y=[200],
    train_mask=[200],
    val_mask=[200],
    test_mask=[200],
  },
  institution={ x=[10, 16] },
  (author, writes, paper)={ edge_index=[2, 400] },
  (paper, written_by, author)={ edge_index=[2, 400] },
  (author, affiliated_with, institution)={ edge_index=[2, 100] },
  (institution, has_member, author)={ edge_index=[2, 100] },
  (paper, cites, paper)={ edge_index=[2, 500] }
)
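As a quick sanity check on the split logic above, the sketch below wraps the same permutation-based procedure in a helper and verifies that the three masks are disjoint and cover every node. The function name `make_split_masks` and the `seed` argument are hypothetical additions for illustration.

```python
import torch


def make_split_masks(num_nodes, train_ratio=0.6, val_ratio=0.2, seed=0):
    """Return disjoint boolean train/val/test masks over num_nodes nodes."""
    generator = torch.Generator().manual_seed(seed)
    perm = torch.randperm(num_nodes, generator=generator)
    train_end = int(train_ratio * num_nodes)
    val_end = int((train_ratio + val_ratio) * num_nodes)

    masks = {
        name: torch.zeros(num_nodes, dtype=torch.bool)
        for name in ("train", "val", "test")
    }
    masks["train"][perm[:train_end]] = True
    masks["val"][perm[train_end:val_end]] = True
    masks["test"][perm[val_end:]] = True
    return masks


masks = make_split_masks(200)

# Every node should belong to exactly one split.
membership = torch.stack(list(masks.values())).int().sum(dim=0)
covers_once = bool((membership == 1).all())

sizes = {name: int(mask.sum()) for name, mask in masks.items()}
print(covers_once, sizes)
```

Seeding the generator makes the split reproducible across runs, which the original cell does not guarantee.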

## 4. Define a Base GNN and Convert with `to_hetero`

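We first define a standard GraphSAGE-style GNN as if it operated on a homogeneous graph. `to_hetero` then rewrites it for the graph's metadata: each message-passing layer is duplicated once per edge type, and the results arriving at the same destination node type are combined according to `aggr` (here `'sum'`).

Conceptually, this corresponds to the typed update from Section 1 (note that `SAGEConv` additionally applies a separate weight matrix to each node's own features):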

\begin{equation}
\mathbf{h}_i^{(l+1)} = \sigma\left(
    \sum_{r \in \mathcal{T}_E}
    \sum_{j \in \mathcal{N}_r(i)}
    \frac{1}{c_{ijr}} W_r^{(l)} \mathbf{h}_j^{(l)}
\right).
\end{equation}

In [9]:
class BaseGNN(nn.Module):
    def __init__(self, hidden_channels, out_channels):
        super().__init__()
        # (-1, -1) lets PyG infer input dims for each node type
        self.conv1 = SAGEConv((-1, -1), hidden_channels)
        self.conv2 = SAGEConv((-1, -1), hidden_channels)
        self.lin = nn.Linear(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        # x: node features (for a single meta-type)
        # edge_index: adjacency
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = self.conv2(x, edge_index)
        x = F.relu(x)
        return self.lin(x)


metadata = data.metadata()  # (node_types, edge_types)
hidden_channels = 32
model = BaseGNN(hidden_channels, num_classes)

# Convert to heterogeneous model
hetero_model = to_hetero(model, metadata, aggr='sum')
hetero_model = hetero_model.to(device)
data = data.to(device)

hetero_model

AttributeError: module 'torch.fx._symbolic_trace' has no attribute 'List'

## 5. Train the Heterogeneous GNN for Paper Classification

We now train the model using cross-entropy loss on the `paper` node type.

In [None]:
optimizer = optim.Adam(hetero_model.parameters(), lr=0.01, weight_decay=5e-4)
criterion = nn.CrossEntropyLoss()

def train():
    hetero_model.train()
    optimizer.zero_grad()
    out_dict = hetero_model(data.x_dict, data.edge_index_dict)
    out = out_dict['paper']  # logits for paper nodes

    loss = criterion(out[data['paper'].train_mask],
                     data['paper'].y[data['paper'].train_mask])
    loss.backward()
    optimizer.step()
    return loss.item()

@torch.no_grad()
def evaluate(split='val'):
    hetero_model.eval()
    out_dict = hetero_model(data.x_dict, data.edge_index_dict)
    out = out_dict['paper']  # raw logits; softmax would not change the argmax
    y_true = data['paper'].y

    if split == 'train':
        mask = data['paper'].train_mask
    elif split == 'val':
        mask = data['paper'].val_mask
    else:
        mask = data['paper'].test_mask

    pred = out[mask].argmax(dim=-1)
    correct = (pred == y_true[mask]).sum().item()
    total = mask.sum().item()
    return correct / total if total > 0 else 0.0

for epoch in range(1, 101):
    loss = train()
    if epoch % 10 == 0:
        train_acc = evaluate('train')
        val_acc = evaluate('val')
        print(f"Epoch {epoch:03d}, Loss: {loss:.4f}, Train Acc: {train_acc:.3f}, Val Acc: {val_acc:.3f}")

## 6. Discussion & Exercises

**Questions / prompts:**

1. Network science view: Interpret each edge type as a *layer* of a multilayer network, as in Section 1. How is the GNN aggregating across layers?
2. What happens if you drop the `cites` relation? Does validation accuracy change?
3. Try modifying the model to use `aggr='mean'` instead of `sum`. Does it matter on this toy data?

**Extensions:**

- Add a new node type `field` and connect papers to fields (topics).
- Visualize the degree distributions for each node type and compare with classical measures like degree centrality.
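As a starting point for the degree-distribution extension, per-relation degrees can be computed with `torch.bincount`. This is a minimal sketch: the helper name `relation_degrees` is hypothetical, and the random edge list stands in for `data['author', 'writes', 'paper'].edge_index` from Section 3.

```python
import torch

torch.manual_seed(0)
num_authors, num_papers = 100, 200

# Stand-in for data['author', 'writes', 'paper'].edge_index (row 0 = authors, row 1 = papers).
edge_index = torch.stack([
    torch.randint(0, num_authors, (400,)),
    torch.randint(0, num_papers, (400,)),
])


def relation_degrees(edge_index, num_src, num_dst):
    """Out-degree of each source node and in-degree of each target node for one relation."""
    src, dst = edge_index
    out_deg = torch.bincount(src, minlength=num_src)
    in_deg = torch.bincount(dst, minlength=num_dst)
    return out_deg, in_deg


out_deg, in_deg = relation_degrees(edge_index, num_authors, num_papers)
print("mean papers per author:", out_deg.float().mean().item())
print("mean authors per paper:", in_deg.float().mean().item())
```

Randomly sampled edges may repeat, so these counts treat duplicates as multi-edges; deduplicate the edge list first if simple-graph degrees are needed. The resulting tensors can be passed to any histogram routine for plotting.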
