# Temporal Graph Networks (TGN)

This notebook focuses on the implementation and exploration of Temporal Graph Networks (TGN), a model for learning on dynamic graphs where the structure and features change over time.

Code from: https://github.com/twitter-research/tgn

## 1. Setup

Ensure that PyTorch, PyTorch Geometric, and PyTorch Geometric Temporal are installed. If you've run the `PyTorchGeometicTemporal.ipynb` notebook, these should already be available in your environment.
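
As a quick sanity check before running the rest of the notebook, the sketch below tests whether the three packages can be imported. It assumes the import names are `torch`, `torch_geometric`, and `torch_geometric_temporal`; the helper `missing_packages` is a hypothetical name introduced here for illustration.

```python
from importlib.util import find_spec

# Import names assumed for the three libraries this notebook relies on
REQUIRED = ["torch", "torch_geometric", "torch_geometric_temporal"]


def missing_packages(names):
    """Return the import names that cannot be located in the current environment."""
    return [name for name in names if find_spec(name) is None]


missing = missing_packages(REQUIRED)
if missing:
    print("Missing packages:", ", ".join(missing))
else:
    print("All required packages are importable.")
```

If anything is reported missing, install it before continuing; the later cells fail at import time otherwise.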

## 2. Import Libraries

In [None]:
import torch
import torch.nn as nn
import torch_geometric  # Bound explicitly so the version check below works
from torch_geometric.nn import TGNMemory, TransformerConv
from torch_geometric.data import TemporalData
import torch_geometric_temporal  # Imported only to confirm it is installed

print(f"PyTorch version: {torch.__version__}")
print(f"PyTorch Geometric version: {torch_geometric.__version__}")

## 3. TGN Model Implementation

The TGN model consists of several key components:
1. **Memory**: Stores an up-to-date representation of each node in the graph.
2. **Message Function**: Computes messages from node interactions.
3. **Message Aggregator**: Aggregates messages for a node.
4. **Memory Updater**: Updates the node's memory based on aggregated messages.
5. **Embedding Module**: Generates temporal embeddings for nodes, used for downstream tasks.
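
To make the roles of these components concrete, here is a minimal sketch in plain PyTorch of one memory update. It is an illustration under simplifying assumptions, not the TGN reference implementation: the time encoding is omitted, only source-side messages are produced, the aggregator keeps the most recent message per node, and a `GRUCell` serves as the memory updater. All variable names are chosen for this example.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
num_nodes, memory_dim, raw_msg_dim = 5, 4, 3

# 1. Memory: one vector per node, initialised to zeros
memory = torch.zeros(num_nodes, memory_dim)

# 4. Memory updater: a GRU cell taking the aggregated message as input
updater = nn.GRUCell(2 * memory_dim + raw_msg_dim, memory_dim)

# Three events (src -> dst) with timestamps; node 0 is a source twice
src = torch.tensor([0, 1, 0])
dst = torch.tensor([2, 3, 4])
t = torch.tensor([1, 2, 3])
raw_msg = torch.randn(3, raw_msg_dim)

# 2. Message function: concatenate both memories with the raw message
msgs = torch.cat([memory[src], memory[dst], raw_msg], dim=-1)

# 3. Message aggregator: keep only the most recent message per source node
last_msg = {}
for i in torch.argsort(t).tolist():
    last_msg[int(src[i])] = msgs[i]
node_ids = torch.tensor(sorted(last_msg))
aggregated = torch.stack([last_msg[n] for n in node_ids.tolist()])

# 4. Memory update for the nodes that received a message
with torch.no_grad():
    memory[node_ids] = updater(aggregated, memory[node_ids])

# 5. Embedding module: here simply the memory itself (identity embedding)
embeddings = memory.clone()
print(node_ids.tolist(), embeddings.shape)
```

Node 0 appears in two events, so only its later message (at `t = 3`) reaches the updater. Nodes 2, 3, and 4 keep their zero memory because this sketch sends no destination-side messages.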

In [None]:
from torch_geometric.nn.models.tgn import IdentityMessage, LastAggregator


class TGNModel(nn.Module):
    def __init__(self, num_nodes, raw_msg_dim, memory_dim, time_dim, embedding_dim):
        super().__init__()
        self.num_nodes = num_nodes
        self.raw_msg_dim = raw_msg_dim
        self.memory_dim = memory_dim
        self.time_dim = time_dim
        self.embedding_dim = embedding_dim

        # TGN Memory module
        self.memory = TGNMemory(
            num_nodes=self.num_nodes,
            raw_msg_dim=self.raw_msg_dim,  # Dimension of raw messages (e.g., edge features)
            memory_dim=self.memory_dim,    # Dimension of node memory
            time_dim=self.time_dim,        # Dimension of time encoding
            # Message function: concatenates source memory, destination memory,
            # raw message, and time encoding. TGNMemory reads its `out_channels`,
            # which a plain nn.Identity() does not provide.
            message_module=IdentityMessage(self.raw_msg_dim, self.memory_dim, self.time_dim),
            # Aggregator: keeps only the most recent message per node
            aggregator_module=LastAggregator(),
        )

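        # Graph attention layer for embeddings (example: TransformerConv)
        # The input to this layer is the node memory (or a projection of it).
        # With the default concat=True the output width is heads * out_channels,
        # so out_channels is divided by the head count to yield embedding_dim.
        heads = 2
        self.gnn_conv = TransformerConv(in_channels=self.memory_dim,
                                        out_channels=self.embedding_dim // heads,
                                        heads=heads,
                                        dropout=0.1)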

        # Link predictor (example for link prediction task)
        self.link_pred = nn.Linear(self.embedding_dim * 2, 1)

    def forward(self, n_id, t, msg, src, dst, edge_index=None):
        # n_id: node ids involved in current batch/snapshot
        # t: timestamps of events
        # msg: raw messages (e.g., edge features)
        # src, dst: source and destination nodes of events
        # edge_index: if you have a static graph structure for the GNN part, otherwise derive from src/dst

        # 1. Query memory
        # Calling the TGNMemory module with node IDs returns a tuple
        # (memory, last_update). In training mode it first applies any pending
        # messages for those nodes; in eval mode it reads the stored state.
        # New events are registered separately, in the training loop, with
        # `self.memory.update_state(src, dst, t, msg)`.
        node_memory, _ = self.memory(n_id)

        # 2. Generate embeddings with the GNN
        # If no edge_index is given, build one from the events in this batch.
        # This is a simplification: TGN usually samples a temporal neighborhood.
        if edge_index is None:
            # Relabel global node IDs to 0..N-1 so they index rows of node_memory
            n_id, local_ids = torch.unique(torch.cat([src, dst]), return_inverse=True)
            edge_index = torch.stack([local_ids[:len(src)], local_ids[len(src):]], dim=0)
            # Re-query memory so its rows line up with the relabeled nodes
            node_memory, _ = self.memory(n_id)

        # The GNN conv expects node features and edge_index
        # Here, node_memory serves as input features to the GNN
        x = self.gnn_conv(node_memory, edge_index) # x will be node embeddings

        # 3. Example: link prediction (if this is the task)
        # Row i of `x` is the embedding of the i-th node passed to the GNN, so the
        # global IDs in `src` and `dst` must first be mapped to row positions in `x`.
        # Conceptually:
        # src_emb = x[src_indices_in_x]
        # dst_emb = x[dst_indices_in_x]
        # link_emb = torch.cat([src_emb, dst_emb], dim=1)
        # pred = self.link_pred(link_emb)
        # return pred, node_memory  # Or just embeddings if that's the output

        return x, node_memory # Return embeddings and memory (or just embeddings)

    def reset_memory(self):
        self.memory.reset_state() # Reset memory state (e.g., at the start of an epoch)

    def detach_memory(self):
        self.memory.detach() # Detach memory from computation graph (e.g., for BPTT)

print("TGNModel class defined.")
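
The link-prediction step sketched in the comments of `forward` needs a mapping from global node IDs to rows of the embedding matrix. One common way to do this is an association vector, shown below as a standalone sketch. The names `local_index`, `assoc`, and the random `x` standing in for the GNN output are assumptions made for this illustration, not part of `TGNModel`.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
num_nodes, embedding_dim = 10, 8


def local_index(n_id, num_nodes):
    """Map global node IDs to their row position in a batch-local tensor."""
    assoc = torch.full((num_nodes,), -1, dtype=torch.long)  # -1 marks "not in batch"
    assoc[n_id] = torch.arange(n_id.size(0))
    return assoc


# Three events given by global node IDs
src = torch.tensor([7, 2, 7])
dst = torch.tensor([5, 9, 2])

n_id = torch.cat([src, dst]).unique()   # sorted unique nodes in the batch
assoc = local_index(n_id, num_nodes)

# Stand-in for the GNN output: one embedding row per entry of n_id
x = torch.randn(n_id.size(0), embedding_dim)

# Score each (src, dst) pair from the concatenated embeddings
link_pred = nn.Linear(2 * embedding_dim, 1)
logits = link_pred(torch.cat([x[assoc[src]], x[assoc[dst]]], dim=-1))
print(logits.shape)
```

Because `n_id` is sorted, `assoc[src]` and `assoc[dst]` select exactly the rows of `x` that belong to each event's endpoints, whatever order the events arrive in.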

### Example Usage (Conceptual)

In [None]:
# Parameters (example values)
num_nodes = 100        # Total number of nodes in the graph
raw_msg_dim = 16       # Dimension of raw edge features (messages)
memory_dim = 32        # Dimension of the node memory
time_dim = 8           # Dimension of the time encoding fed to memory
embedding_dim = 64     # Dimension of the final node embeddings

# Instantiate the model
tgn_model = TGNModel(num_nodes, raw_msg_dim, memory_dim, time_dim, embedding_dim)
print("TGNModel instantiated.")

# --- Conceptual Data for one batch/step ---
# This data would typically come from a DataLoader handling TemporalData objects
batch_size = 32 # Number of events in the batch

# Node IDs involved in the current events (global IDs)
src_nodes = torch.randint(0, num_nodes, (batch_size,))
dst_nodes = torch.randint(0, num_nodes, (batch_size,))
n_ids_batch = torch.cat([src_nodes, dst_nodes]).unique() # Unique nodes in this batch

# Timestamps of events (integers, since TGNMemory stores last-update times as a long tensor)
event_times = torch.randint(0, 100, (batch_size,)).sort().values  # Sorted timestamps

# Raw messages (edge features)
edge_features = torch.randn(batch_size, raw_msg_dim)

# --- Interacting with TGNMemory (typically done inside the training loop) ---
# 1. Register new events. This modifies the memory's internal state.
# tgn_model.memory.update_state(src_nodes, dst_nodes, event_times, edge_features)
# Note: the arguments are (src, dst, t, raw_msg). The `message_module` and
# `aggregator_module` passed to TGNMemory are applied when the stored
# messages are consumed.

# 2. Get current memory for nodes (e.g., for GNN input or final embeddings)
# current_node_memories, last_update = tgn_model.memory(n_ids_batch)
# The raw buffer for all nodes is `tgn_model.memory.memory`, but calling the
# module is the intended way to read the state of specific nodes.

# --- Forward pass (simplified for this example) ---
# The forward pass in the TGNModel class is a bit conceptual.
# A more complete TGN pipeline would involve:
#   a. Processing a batch of events: update memory, compute messages.
#   b. For nodes involved in these events (or a sampled neighborhood), generate embeddings using the GNN.
#   c. Use these embeddings for a downstream task (e.g., link prediction, node classification).

print("Conceptual data prepared. Note: Actual TGN training involves careful handling of memory updates and batching.")

# To actually run the forward pass as defined (which is conceptual):
# output_embeddings, last_memory_state = tgn_model(n_ids_batch, event_times, edge_features, src_nodes, dst_nodes)
# print(f"Output embedding shape (conceptual): {output_embeddings.shape}")
# print(f"Last memory state shape (conceptual): {last_memory_state.shape}")

# Reset memory (e.g., at the start of an epoch)
# tgn_model.reset_memory()
# print("Memory reset.")