<a href="https://colab.research.google.com/github/dp457/Graph-Neural-Network/blob/main/Temporal_Graph_Networks.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Temporal Graph Network (TGN) link-prediction example on the JODIE Wikipedia dataset. It implements:
1. Learned node memory.
2. Message/aggregation pipeline.
3. Temporal graph attention-style embedding.
4. MLP decoder for link probability.

1. **High level:** the model keeps a per-node memory $s_{i} (t)$ that is updated only when node $i$ participates in an event.
2. **Temporal GNN embedding module:** aggregates the most recent neighbors using attention with **time encodings** and **edge messages**.
3. **Decoder:** predicts whether a source-destination interaction occurs.

This follows the encoder–decoder view of dynamic graphs described by TGN, with specific choices: identity message function, last-message aggregation, GRU-style memory, and a **one-layer temporal attention embedding**.



In [1]:
!pip install torch_geometric

Collecting torch_geometric
  Downloading torch_geometric-2.6.1-py3-none-any.whl.metadata (63 kB)
Downloading torch_geometric-2.6.1-py3-none-any.whl (1.1 MB)
Installing collected packages: torch_geometric
Successfully installed torch_geometric-2.6.1


# 1. Data and Task

1. **Temporal events:** A chronologically ordered list of interactions $e_{ij} (t)$ (user $i$ edits page $j$ at time $t$), possibly with edge features/messages $x_{ij} (t)$. Dynamic graphs are formalized as a sequence of time-stamped events $G = \{ x(t_1), x(t_2), \cdots \}.$

2. **Task:** Given past events up to some time, predict whether the edges in the next batch occur. TGN uses this task as the canonical self-supervised training objective.
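
To make the setup concrete, the following plain-Python sketch splits a toy event stream chronologically. The `chronological_split` helper and the `(src, dst, t)` tuples are hypothetical and exist only for illustration; the split here is by event count, and the notebook's own `train_val_test_split` may choose its boundaries differently.

```python
def chronological_split(events, val_ratio=0.15, test_ratio=0.15):
    """Split time-stamped events into train/val/test without shuffling.

    `events` is a list of (src, dst, t) tuples; this helper is illustrative only.
    """
    events = sorted(events, key=lambda e: e[2])  # oldest first
    n = len(events)
    n_test = int(n * test_ratio)
    n_val = int(n * val_ratio)
    n_train = n - n_val - n_test
    return (events[:n_train],
            events[n_train:n_train + n_val],
            events[n_train + n_val:])


# Toy stream: 20 interactions with increasing timestamps.
toy_events = [(i % 4, 4 + i % 3, float(i)) for i in range(20)]
train_ev, val_ev, test_ev = chronological_split(toy_events)
```

Every validation event is later than every training event, so the model is never evaluated on interactions that precede what it was trained on.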

*TemporalDataLoader* iterates over the events in temporal order and samples negative destinations at a 1:1 ratio; *LastNeighborLoader* keeps the 10 most recent neighbors of each node. Most-recent sampling is the choice recommended for TGN, where it is reported to outperform uniform neighbor sampling on dynamic data.
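
The idea of most-recent neighbor sampling can be sketched in a few lines of plain Python. `ToyLastNeighborStore` is a hypothetical stand-in, not the *LastNeighborLoader* implementation; in particular, storing each interaction in both directions is an assumption made for this illustration.

```python
from collections import defaultdict, deque


class ToyLastNeighborStore:
    """Keeps only the `size` most recent neighbors of every node (illustrative)."""

    def __init__(self, size):
        # deque(maxlen=size) silently drops the oldest entry when full.
        self.neighbors = defaultdict(lambda: deque(maxlen=size))

    def insert(self, src, dst, event_id):
        # Interactions are stored in both directions, tagged with the event id.
        self.neighbors[src].append((dst, event_id))
        self.neighbors[dst].append((src, event_id))

    def most_recent(self, node):
        return list(self.neighbors[node])


store = ToyLastNeighborStore(size=2)
for event_id, (src, dst) in enumerate([(0, 5), (0, 6), (0, 7)]):
    store.insert(src, dst, event_id)

recent = store.most_recent(0)  # the interaction with node 5 has been evicted
```

With `size=2`, node 0 retains only its two latest interactions, which is the behaviour the notebook relies on with `size=10`.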

In [2]:
import os.path as osp

import torch
from sklearn.metrics import average_precision_score, roc_auc_score
from torch.nn import Linear

from torch_geometric.datasets import JODIEDataset
from torch_geometric.loader import TemporalDataLoader
from torch_geometric.nn import TGNMemory, TransformerConv
from torch_geometric.nn.models.tgn import (
    IdentityMessage,
    LastAggregator,
    LastNeighborLoader,
)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

path = osp.join('data', 'JODIE')
dataset = JODIEDataset(path, name='wikipedia')
data = dataset[0]

# For small datasets, we can put the whole dataset on GPU and thus avoid
# expensive memory transfer costs for mini-batches:
data = data.to(device)

train_data, val_data, test_data = data.train_val_test_split(
    val_ratio=0.15, test_ratio=0.15)

train_loader = TemporalDataLoader(
    train_data,
    batch_size=200,
    neg_sampling_ratio=1.0,
)
val_loader = TemporalDataLoader(
    val_data,
    batch_size=200,
    neg_sampling_ratio=1.0,
)
test_loader = TemporalDataLoader(
    test_data,
    batch_size=200,
    neg_sampling_ratio=1.0,
)
neighbor_loader = LastNeighborLoader(data.num_nodes, size=10, device=device)

Downloading http://snap.stanford.edu/jodie/wikipedia.csv
Processing...
Done!


# 2. TGN Core Modules

## 2.1 Per-node memory $s_i (t)$

TGN maintains a vector state $s_i (t)$ for each node $i$, updated after processing an event:

\begin{equation}
 s_{i} (t) = \text{mem} (\tilde{m}_i (t), s_i (t^-))
 \end{equation}

Here $t^-$ is the previous time $i$ was involved in an event, and $\tilde{m}_i (t)$ is the aggregate of the batch's messages for $i$.
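
For reference, a GRU-style choice of $\text{mem}$ can be written out as below. This is a sketch of the standard GRU cell update with input $\tilde{m}_i (t)$ and hidden state $s_i (t^-)$; the weight names $W_\cdot, U_\cdot$ are introduced here for illustration and biases are omitted.

\begin{align}
r &= \sigma\left(W_r \tilde{m}_i (t) + U_r s_i (t^-)\right), \\
u &= \sigma\left(W_u \tilde{m}_i (t) + U_u s_i (t^-)\right), \\
n &= \tanh\left(W_n \tilde{m}_i (t) + r \odot U_n s_i (t^-)\right), \\
s_i (t) &= (1 - u) \odot n + u \odot s_i (t^-)
\end{align}

The update gate $u$ controls how much of the old memory survives, so a node whose new message carries little information can keep its state nearly unchanged.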


**Message functions.** For an interaction $e_{ij} (t)$, TGN forms two messages:

$m_{i} (t) = \text{msg}_s (s_i (t^-), s_j (t^-), \Delta t, e_{ij}(t))$,

$m_{j} (t) = \text{msg}_d (s_j (t^-), s_i (t^-), \Delta t, e_{ij}(t))$.

Here $\text{msg}$ is the identity: the inputs are simply concatenated, with no learned transformation.
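
A minimal plain-Python sketch of these two choices, identity messages and last-message aggregation, is given below. The helper names `identity_message` and `last_aggregate`, the toy dimensions, and the concatenation order are assumptions for illustration rather than the library's code.

```python
def identity_message(s_src, s_dst, raw_msg, t_enc):
    """Identity message: concatenate the inputs without any learned transform."""
    return s_src + s_dst + raw_msg + t_enc  # list concatenation


def last_aggregate(messages):
    """Keep only the most recent message for each node.

    `messages` is a list of (node, timestamp, message) tuples.
    """
    latest = {}
    for node, t, msg in messages:
        if node not in latest or t >= latest[node][0]:
            latest[node] = (t, msg)
    return {node: msg for node, (t, msg) in latest.items()}


# Toy sizes: memory_dim=2, raw message dim=3, time encoding dim=2.
m = identity_message([0.1, 0.2], [0.3, 0.4], [1.0, 0.0, 1.0], [0.5, -0.5])

batch_messages = [(7, 1.0, "m_a"), (7, 3.0, "m_b"), (9, 2.0, "m_c")]
aggregated = last_aggregate(batch_messages)
```

The message length is the sum of the input lengths (here 2 + 2 + 3 + 2 = 9), and node 7, which appears twice in the batch, keeps only its later message.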

## 2.2 Embedding module: temporal attention with time encoding

TGN computes a temporal embedding $z_i (t)$ by aggregating the temporal $k$-hop neighborhood using a graph attention layer with time encoder $\phi(\cdot)$:

\begin{align}
h_i^{(\ell)}(t) &= \text{MLP}^{(\ell)}\!\left(h_i^{(\ell-1)}(t) \; \| \; \tilde{h}_i^{(\ell)}(t)\right), \\
\tilde{h}_i^{(\ell)}(t) &= \text{MultiHeadAttn}^{(\ell)} \!\left(q^{(\ell)}(t), K^{(\ell)}(t), V^{(\ell)}(t)\right), \\
q^{(\ell)}(t) &= h_i^{(\ell-1)}(t) \; \| \; \phi(0), \\
K^{(\ell)}(t) = V^{(\ell)}(t) &= \left[h_j^{(\ell-1)}(t) \; \| \; e_{ij} \; \| \; \phi(t - t_j)\right]_{j \in \mathcal{N}_i([0,t])}, \\
z_i(t) &= h_i^{(L)}(t)
\end{align}

This is the TGN-attention embedding: a TGAT-style temporal graph attention driven by node memories and time encodings.
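
The time encoder $\phi(\cdot)$ maps a scalar time gap to a vector. The sketch below assumes a cosine of a linear map, $\phi(\Delta t)_k = \cos(w_k \Delta t + b_k)$, which is one common form; the `time_encode` helper and the fixed parameter values are hypothetical, and in a trained model the weights are learned.

```python
import math


def time_encode(delta_t, weights, biases):
    """Map a scalar time gap to a vector: phi(dt)_k = cos(w_k * dt + b_k)."""
    return [math.cos(w * delta_t + b) for w, b in zip(weights, biases)]


# Hypothetical fixed parameters; in a trained model these are learned.
weights = [1.0, 0.1, 0.01]
biases = [0.0, 0.0, 0.0]

phi_zero = time_encode(0.0, weights, biases)  # query side, phi(0)
phi_gap = time_encode(50.0, weights, biases)  # a neighbor seen 50 time units ago
```

Different frequencies $w_k$ let the encoding distinguish short gaps from long ones: the low-frequency component changes slowly with $\Delta t$, while the high-frequency one oscillates quickly.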

**Embedding via temporal attention**

*TransformerConv:* multi-head graph attention operator.
*   For each node $i$, it attends over neighbours $j$ with queries from $s_i (t)$ and keys and values from $[s_j (t) \| e_{ij} (t)]$, where $e_{ij} (t)$ here stands for the edge attribute built in the code (relative-time encoding concatenated with the raw message).
*   The attention weight is computed as
\begin{equation}
\alpha_{ij} = \frac{\exp\left( \langle Q_i, K_{ij} \rangle \right)}{\sum_{k \in \mathcal{N}(i)} \exp\left( \langle Q_i, K_{ik} \rangle \right)},
\end{equation}
where $Q_i = W_q s_i (t)$ and $K_{ij} = W_k [s_j (t) \| e_{ij} (t)]$. PyG's operator additionally scales the scores by $1/\sqrt{d}$ and adds a skip connection from $s_i (t)$; both are omitted here for brevity.
*   The time-aware temporal embedding $z_{i} (t)$ is then given by
\begin{equation}
z_i(t) = \sum_{j \in \mathcal{N}(i)} \alpha_{ij} W_v \big[ s_j(t) \,\|\, e_{ij}(t) \big].
\end{equation}
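
The two formulas above can be checked by hand on a toy example. The sketch below is plain Python for a single head and a single target node; the keys and values are assumed to be already projected, and the helper names `attention_weights` and `attend` are illustrative.

```python
import math


def attention_weights(query, keys):
    """Softmax over dot-product scores <Q_i, K_ij> for one target node."""
    scores = [sum(q * k for q, k in zip(query, key)) for key in keys]
    max_score = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - max_score) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def attend(query, keys, values):
    """Weighted sum of value vectors, z_i = sum_j alpha_ij * V_ij."""
    alphas = attention_weights(query, keys)
    dim = len(values[0])
    return [sum(a * v[d] for a, v in zip(alphas, values)) for d in range(dim)]


# Toy example: one query and two neighbors with already-projected keys/values.
q = [1.0, 0.0]
K = [[2.0, 0.0], [0.0, 2.0]]
V = [[1.0, 0.0], [0.0, 1.0]]

alphas = attention_weights(q, K)
z = attend(q, K, V)
```

The first neighbor's key is aligned with the query, so it receives the larger weight, $1 / (1 + e^{-2}) \approx 0.88$, and dominates the resulting embedding.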

The *gnn* object transforms node memories into temporal embeddings that capture:
*   memory of past interactions
*   recent-neighbour influence via attention
*   edge-level context (message + time difference)




In [3]:
class GraphAttentionEmbedding(torch.nn.Module):
    def __init__(self, in_channels, out_channels, msg_dim, time_enc):
        super().__init__()
        self.time_enc = time_enc
        edge_dim = msg_dim + time_enc.out_channels
        self.conv = TransformerConv(in_channels, out_channels // 2, heads=2,
                                    dropout=0.1, edge_dim=edge_dim) # Realizes multi-head attention over sampled neighbourhood

    def forward(self, x, last_update, edge_index, t, msg):
        rel_t = last_update[edge_index[0]] - t   # relative time
        rel_t_enc = self.time_enc(rel_t.to(x.dtype)) # relative time encoding
        edge_attr = torch.cat([rel_t_enc, msg], dim=-1)
        return self.conv(x, edge_index, edge_attr)  # temporal embeddings


class LinkPredictor(torch.nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.lin_src = Linear(in_channels, in_channels)
        self.lin_dst = Linear(in_channels, in_channels)
        self.lin_final = Linear(in_channels, 1)

    def forward(self, z_src, z_dst):
        h = self.lin_src(z_src) + self.lin_dst(z_dst)
        h = h.relu()
        return self.lin_final(h)


memory_dim = time_dim = embedding_dim = 100


# For each node i, TGN keeps a memory vector s_i. The memory is updated with a
# GRU-like cell from the most recent message m_i(t), which is built from
# s_i, s_j, a time encoding of the delay since the last update, and x_ij(t).

memory = TGNMemory(
    data.num_nodes,
    data.msg.size(-1),
    memory_dim,
    time_dim,
    message_module=IdentityMessage(data.msg.size(-1), memory_dim, time_dim),
    aggregator_module=LastAggregator(),
).to(device)



gnn = GraphAttentionEmbedding(
    in_channels=memory_dim,  # input - memory vector size
    out_channels=embedding_dim, # output embedding dimension
    msg_dim=data.msg.size(-1), # dimension of edge message
    time_enc=memory.time_enc, # reuse the time encoder owned by the memory module
).to(device)


The *LinkPredictor* computes the score
\begin{equation}
\text{score} (i,j) = w^\top \text{ReLU} (W_s z_i (t) + W_d z_j (t))
\end{equation}

Training uses a balanced BCE loss over positives and negatives, applied to the sigmoid of the score (*BCEWithLogitsLoss* in the code):

$\mathcal{L} = \text{BCE} ( \text{score}(i,j), 1) + \text{BCE} (\text{score} (i, \tilde{j}),0)$

This is the TGN link-prediction objective: compute embeddings, decode an edge probability, and apply BCE.
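
As a worked example of the loss, the sketch below evaluates BCE on raw scores (logits) for one positive and one negative pair in plain Python. The helper `bce_with_logits` and the two score values are hypothetical; the notebook's criterion also averages over the batch, which is not shown here.

```python
import math


def bce_with_logits(logit, target):
    """Numerically stable binary cross-entropy on a raw score (logit)."""
    # Equivalent to -[y*log(sigmoid(x)) + (1-y)*log(1-sigmoid(x))].
    return max(logit, 0.0) - logit * target + math.log1p(math.exp(-abs(logit)))


pos_score = 2.0   # hypothetical score for an observed edge (i, j)
neg_score = -1.0  # hypothetical score for a sampled negative (i, j~)

loss = bce_with_logits(pos_score, 1.0) + bce_with_logits(neg_score, 0.0)
```

Working on logits avoids computing `log(sigmoid(x))` directly, which can underflow for large negative scores.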

The key subtlety in TGN is when interactions update the memory. If the memory is updated with the current batch's interactions before predicting those same interactions, the model peeks into the future.

**Problem: Temporal Leakage**

Imagine predicting whether the edge $(i,j,t)$ exists at time $t$. Each node has a memory vector $s_i (t)$.

*   If $s_i (t)$ is updated with information about $(i,j,t)$ before the edge is predicted, the model has "peeked" at the answer. This is known as **data leakage**.
*   The prediction would be unrealistic, because in a real-world setting you only have access to past events, not the one you are trying to predict.

**Solution: Raw Message Store**

1. Store messages but do not update immediately. When a new interaction $(i,j,t)$ arrives, create a raw message for nodes $i$ and $j$ (containing memories, edge features, and time differences), but do not update their memories yet; just set the message aside.
2. Predict using the old memory.
3. After prediction, update the memory.
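
The ordering in these steps is what prevents leakage. The toy loop below shows it in plain Python; `ToyMemory` is a hypothetical scalar memory that merely counts past interactions and is not a model of *TGNMemory*.

```python
class ToyMemory:
    """Scalar per-node memory that counts past interactions (illustrative only)."""

    def __init__(self):
        self.state = {}

    def read(self, node):
        return self.state.get(node, 0)

    def update(self, src, dst):
        self.state[src] = self.read(src) + 1
        self.state[dst] = self.read(dst) + 1


def run_stream(batches):
    """For each batch: read memory, 'predict', and only then update."""
    memory = ToyMemory()
    seen_at_prediction = []
    for batch in batches:
        # Prediction uses the memory as it was before this batch.
        seen_at_prediction.append(
            [(memory.read(s), memory.read(d)) for s, d in batch])
        # Only now do the batch's interactions enter the memory.
        for s, d in batch:
            memory.update(s, d)
    return seen_at_prediction


history = run_stream([[(0, 1)], [(0, 1)], [(0, 2)]])
```

In `history`, the first batch is predicted from an all-zero memory: the interaction being predicted has not yet influenced the state that the prediction reads.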



In [4]:
link_pred = LinkPredictor(in_channels=embedding_dim).to(device)

optimizer = torch.optim.Adam(
    set(memory.parameters()) | set(gnn.parameters())
    | set(link_pred.parameters()), lr=0.0001)
criterion = torch.nn.BCEWithLogitsLoss()

# Helper vector to map global node indices to local ones.
assoc = torch.empty(data.num_nodes, dtype=torch.long, device=device)


In [5]:
def train():
    memory.train()
    gnn.train()
    link_pred.train()

    memory.reset_state()  # Start with a fresh memory.
    neighbor_loader.reset_state()  # Start with an empty graph.

    total_loss = 0
    for batch in train_loader:
        optimizer.zero_grad()
        batch = batch.to(device)

        n_id, edge_index, e_id = neighbor_loader(batch.n_id)
        assoc[n_id] = torch.arange(n_id.size(0), device=device)

        # Get updated memory of all nodes involved in the computation.
        z, last_update = memory(n_id)
        z = gnn(z, last_update, edge_index, data.t[e_id].to(device),
                data.msg[e_id].to(device))
        pos_out = link_pred(z[assoc[batch.src]], z[assoc[batch.dst]])
        neg_out = link_pred(z[assoc[batch.src]], z[assoc[batch.neg_dst]])

        loss = criterion(pos_out, torch.ones_like(pos_out))
        loss += criterion(neg_out, torch.zeros_like(neg_out))

        # Update memory and neighbor loader with ground-truth state.
        memory.update_state(batch.src, batch.dst, batch.t, batch.msg)
        neighbor_loader.insert(batch.src, batch.dst)

        loss.backward()
        optimizer.step()
        memory.detach()
        total_loss += float(loss) * batch.num_events

    return total_loss / train_data.num_events


@torch.no_grad()
def test(loader):
    memory.eval()
    gnn.eval()
    link_pred.eval()

    torch.manual_seed(12345)  # Ensure deterministic sampling across epochs.

    aps, aucs = [], []
    for batch in loader:
        batch = batch.to(device)

        n_id, edge_index, e_id = neighbor_loader(batch.n_id)
        assoc[n_id] = torch.arange(n_id.size(0), device=device)

        z, last_update = memory(n_id)
        z = gnn(z, last_update, edge_index, data.t[e_id].to(device),
                data.msg[e_id].to(device))
        pos_out = link_pred(z[assoc[batch.src]], z[assoc[batch.dst]])
        neg_out = link_pred(z[assoc[batch.src]], z[assoc[batch.neg_dst]])

        y_pred = torch.cat([pos_out, neg_out], dim=0).sigmoid().cpu()
        y_true = torch.cat(
            [torch.ones(pos_out.size(0)),
             torch.zeros(neg_out.size(0))], dim=0)

        aps.append(average_precision_score(y_true, y_pred))
        aucs.append(roc_auc_score(y_true, y_pred))

        memory.update_state(batch.src, batch.dst, batch.t, batch.msg)
        neighbor_loader.insert(batch.src, batch.dst)
    return float(torch.tensor(aps).mean()), float(torch.tensor(aucs).mean())


for epoch in range(1, 51):
    loss = train()
    print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}')
    val_ap, val_auc = test(val_loader)
    test_ap, test_auc = test(test_loader)
    print(f'Val AP: {val_ap:.4f}, Val AUC: {val_auc:.4f}')
    print(f'Test AP: {test_ap:.4f}, Test AUC: {test_auc:.4f}')




Epoch: 01, Loss: 1.1204
Val AP: 0.8507, Val AUC: 0.8676
Test AP: 0.8165, Test AUC: 0.8394
Epoch: 02, Loss: 0.8843
Val AP: 0.9203, Val AUC: 0.9193
Test AP: 0.9158, Test AUC: 0.9131
Epoch: 03, Loss: 0.7189
Val AP: 0.9464, Val AUC: 0.9407
Test AP: 0.9406, Test AUC: 0.9352
Epoch: 04, Loss: 0.6698
Val AP: 0.9499, Val AUC: 0.9440
Test AP: 0.9443, Test AUC: 0.9384
Epoch: 05, Loss: 0.6475
Val AP: 0.9531, Val AUC: 0.9483
Test AP: 0.9504, Test AUC: 0.9453
Epoch: 06, Loss: 0.6359
Val AP: 0.9524, Val AUC: 0.9485
Test AP: 0.9432, Test AUC: 0.9386
Epoch: 07, Loss: 0.6218
Val AP: 0.9553, Val AUC: 0.9508
Test AP: 0.9480, Test AUC: 0.9428
Epoch: 08, Loss: 0.6077
Val AP: 0.9590, Val AUC: 0.9547
Test AP: 0.9520, Test AUC: 0.9478
Epoch: 09, Loss: 0.5997
Val AP: 0.9564, Val AUC: 0.9522
Test AP: 0.9508, Test AUC: 0.9466
Epoch: 10, Loss: 0.5820
Val AP: 0.9587, Val AUC: 0.9546
Test AP: 0.9516, Test AUC: 0.9472
Epoch: 11, Loss: 0.5740
Val AP: 0.9602, Val AUC: 0.9564
Test AP: 0.9537, Test AUC: 0.9497
Epoch: 12,

The operations above can be summarized as follows:

1. Neighbourhood extraction

Let $(\mathbf{n}, E, \mathbf{e})$ be the output of the neighbourhood loader, which samples the 10 most recent edges per node: node ids $\mathbf{n}$, edges $E$, and event ids $\mathbf{e}$. For each edge $u \rightarrow v$ with event id $e$ we have the timestamp $t_e$ and the message $x_e$.

2. Memory read and temporal attention

Read the memories $S = \{ s_{i} (t) \}_{i \in \mathbf{n}}$ and the last-update times $\tilde{t}_i$. The edge features are built as

\begin{equation}
a_{uv} = [ \phi(\tilde{t}_u - t_e) || x_e]
\end{equation}

A one-layer transformer-style attention over this subgraph is computed as

\begin{equation}
z_{i} (t) = \text{TransformerConv} (s_{i} (t), (u \rightarrow v), a_{uv})
\end{equation}
This is an instance of TGN’s temporal attention with time encodings and edge features in the keys/values.

3. Decoding and Loss

For each positive $(i,j)$ and a sampled negative $(i, \tilde{j})$,

\begin{equation}
\hat{y}_{ij} = \sigma (w^\top \text{ReLU} (W_s z_i + W_d z_j)), \qquad \hat{y}_{i\tilde{j}} = \sigma (w^\top \text{ReLU} (W_s z_i + W_d z_{\tilde{j}}))
\end{equation}

where $\sigma$ is the logistic sigmoid, and $\mathcal{L} = \text{BCE}(\hat{y}_{ij}, 1)+ \text{BCE}(\hat{y}_{i\tilde{j}}, 0)$.

4. Memory update

Store the raw messages for the batch interactions

$m^{\text{raw}}_i = (s_i (t), s_j (t), t, x_{ij}(t))$

Design Choices

1. One attention layer combined with memory is both accurate and fast.
2. Last message aggregation is efficient and competitive; mean aggregation can be slightly better but is much slower.
3. Most-recent neighbor sampling outperforms uniform sampling on dynamic graphs.
