# PyTorch Geometric Temporal for Dynamic Network Adaptation

This notebook sets up PyTorch Geometric Temporal and builds a TGAT baseline for the research project: *Implementing Decay-Based Temporal Attention for Dynamic Network Adaptation*.

Implementation based on: https://github.com/StatsDLMathsRecomSys/Inductive-representation-learning-on-temporal-graphs

## 1. Setup and Installation

In [None]:
# Installation of required libraries
# Ensure PyTorch is installed. If it is not, uncomment the command below, or pick the one matching your system at https://pytorch.org/get-started/locally/
# The cu128 index used below provides wheels built for CUDA 12.8 (GPU support).
# !pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu128

# Install PyTorch Geometric and PyTorch Geometric Temporal
!pip install torch-geometric
!pip install torch-geometric-temporal

## 2. Import Libraries

In [None]:
import torch
import torch_geometric
import torch_geometric_temporal

print(f"PyTorch version: {torch.__version__}")
print(f"PyTorch Geometric version: {torch_geometric.__version__}")
print(f"PyTorch Geometric Temporal version: {torch_geometric_temporal.__version__}")

PyTorch version: 2.7.0+cu128
PyTorch Geometric version: 2.6.1
PyTorch Geometric Temporal version: 0.54.0
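
Since the install above targets a CUDA build, it can help to confirm which device PyTorch will actually use. The following is a minimal sketch; the variable names (`device`, `layer`, `x`) are illustrative and not part of the original notebook.

```python
import torch

# Use the GPU when the CUDA build of PyTorch can see one, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Tensors and modules must live on the same device before a forward pass.
x = torch.randn(4, 8, device=device)
layer = torch.nn.Linear(8, 2).to(device)
print(layer(x).shape)  # torch.Size([4, 2])
```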


## 3. Baselining the TGAT Model

This section builds a TGAT baseline that can later be adapted or extended with decay-based temporal attention for dynamic networks.

### 3.1. Baseline Implementation: TGAT

We start by implementing the Temporal Graph Attention Network (TGAT) from the paper [Inductive Representation Learning on Temporal Graphs](https://arxiv.org/abs/2002.07962). The implementation is based on the original source code available on [GitHub](https://github.com/StatsDLMathsRecomSys/Inductive-representation-learning-on-temporal-graphs). We will then extend this model to incorporate a decay mechanism.
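
As a preview of what a decay mechanism could look like, the sketch below penalizes each raw attention score by the age of the corresponding interaction before normalizing. This is only one possible formulation, written in plain Python for readability. The function names, `decay_rate`, and the exponential form are assumptions made for illustration. They are not taken from the project or from the TGAT reference code.

```python
import math

def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def decayed_attention(scores, time_deltas, decay_rate):
    """Subtract decay_rate * dt from each raw score, then normalize (hypothetical helper)."""
    penalized = [s - decay_rate * dt for s, dt in zip(scores, time_deltas)]
    return softmax(penalized)

# Three neighbors with identical raw scores but different interaction ages.
scores = [1.0, 1.0, 1.0]
time_deltas = [0.0, 1.0, 5.0]

plain = softmax(scores)
decayed = decayed_attention(scores, time_deltas, decay_rate=0.5)

print(plain)    # uniform weights
print(decayed)  # older interactions receive smaller weights
```

Subtracting `decay_rate * dt` before the softmax is equivalent to multiplying each unnormalized weight by `exp(-decay_rate * dt)` and renormalizing, so the weights still sum to one.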

In [9]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math

# Time Encoding Module
class TimeEncoder(nn.Module):
    def __init__(self, dimension):
        super(TimeEncoder, self).__init__()
        self.dimension = dimension
        self.w = nn.Linear(1, dimension)
        
        # Initialize frequencies on a log scale from 1 down to 1e-9.
        # Wrapped in nn.Parameter, they remain trainable.
        self.w.weight = nn.Parameter((torch.from_numpy(1 / 10 ** np.linspace(0, 9, dimension))).float().reshape(dimension, 1))
        self.w.bias = nn.Parameter(torch.zeros(dimension))
        
    def forward(self, t):
        # t has shape [batch_size, 1] or [batch_size]
        if t.dim() == 1:
            t = t.unsqueeze(1)  # [batch_size] -> [batch_size, 1]
        
        # Return shape [batch_size, dimension]
        return torch.cos(self.w(t))

# Multi-Head Attention Layer
class MultiHeadAttention(nn.Module):
    def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1):
        super().__init__()
        
        self.n_head = n_head
        self.d_k = d_k
        self.d_v = d_v
        
        # Linear projections for Query, Key, Value
        self.w_qs = nn.Linear(d_model, n_head * d_k, bias=False)
        self.w_ks = nn.Linear(d_model, n_head * d_k, bias=False)
        self.w_vs = nn.Linear(d_model, n_head * d_v, bias=False)
        
        # Final output projection
        self.fc = nn.Linear(n_head * d_v, d_model, bias=False)
        
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
        
    def forward(self, q, k, v, mask=None):
        d_k, d_v, n_head = self.d_k, self.d_v, self.n_head
        batch_size, len_q, len_k, len_v = q.size(0), q.size(1), k.size(1), v.size(1)
        
        residual = q  # For residual connection
        
        # Linear projections and reshape
        q = self.w_qs(q).view(batch_size, len_q, n_head, d_k)  # [b, lq, n, dk]
        k = self.w_ks(k).view(batch_size, len_k, n_head, d_k)  # [b, lk, n, dk]
        v = self.w_vs(v).view(batch_size, len_v, n_head, d_v)  # [b, lv, n, dv]
        
        # Transpose for attention calculation
        q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)  # [b, n, lq, dk], [b, n, lk, dk], [b, n, lv, dv]
        
        # Calculate attention scores
        scores = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(d_k)  # [b, n, lq, lk]
        
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)  # Apply mask (if provided)
        
        attn = F.softmax(scores, dim=-1)  # [b, n, lq, lk]
        attn = self.dropout(attn)
        
        # Multiply attention with values
        output = torch.matmul(attn, v)  # [b, n, lq, dv]
        
        # Reshape back and apply output projection
        output = output.transpose(1, 2).contiguous().view(batch_size, len_q, -1)  # [b, lq, n*dv]
        output = self.fc(output)  # [b, lq, d_model]
        
        # Apply residual connection and layer normalization
        output = self.layer_norm(output + residual)
        
        return output, attn

# Feed-Forward Network
class PositionwiseFeedForward(nn.Module):
    def __init__(self, d_in, d_hid, dropout=0.1):
        super().__init__()
        self.w_1 = nn.Linear(d_in, d_hid)
        self.w_2 = nn.Linear(d_hid, d_in)
        self.layer_norm = nn.LayerNorm(d_in, eps=1e-6)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x):
        residual = x
        
        x = self.w_2(F.relu(self.w_1(x)))
        x = self.dropout(x)
        x = self.layer_norm(residual + x)
        
        return x

# Temporal Graph Attention Layer
class TGATLayer(nn.Module):
    def __init__(self, n_head, d_model, d_k, d_v, d_time, dropout=0.1):
        super().__init__()
        self.attention = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout)
        self.time_enc = TimeEncoder(d_time)
        # Project [neighbor features || time encoding] to d_model so that keys and
        # values match the width MultiHeadAttention expects. LazyLinear infers the
        # input width on the first forward pass.
        self.neighbor_proj = nn.LazyLinear(d_model)
        self.position_ffn = PositionwiseFeedForward(d_model, d_model * 4, dropout)
        
    def forward(self, node_features, neighbor_features, edge_times, masks=None):
        """
        Args:
            node_features: [batch_size, d_model]
            neighbor_features: [batch_size, num_neighbors, d_neighbor]
            edge_times: [batch_size, num_neighbors]
            masks: [batch_size, num_neighbors], 0 marks padded neighbors (optional)
        Returns:
            output: [batch_size, d_model]
        """
        batch_size, num_neighbors = neighbor_features.size(0), neighbor_features.size(1)
        
        # Encode the time information
        time_features = self.time_enc(edge_times.reshape(-1, 1)).view(batch_size, num_neighbors, -1)  # [b, num_n, d_time]
        
        # Concatenate the time encoding with the neighbor features and project to d_model
        time_enhanced_features = self.neighbor_proj(
            torch.cat([neighbor_features, time_features], dim=-1)
        )  # [b, num_n, d_model]
        
        # Expand node features for attention
        node_features_expanded = node_features.unsqueeze(1)  # [b, 1, d_model]
        
        # Reshape the mask so it broadcasts over heads and the single query position
        if masks is not None:
            masks = masks.view(batch_size, 1, 1, num_neighbors)
        
        # Apply attention between the node and its neighbors
        attn_output, attn_weights = self.attention(node_features_expanded, time_enhanced_features, time_enhanced_features, masks)
        
        # Apply feed-forward network
        output = self.position_ffn(attn_output.squeeze(1))  # [b, d_model]
        
        return output

# Complete TGAT Model
class TGAT(nn.Module):
    def __init__(self, node_features_dim, edge_features_dim, time_dim, hidden_dim, num_layers, num_heads, dropout=0.1):
        super(TGAT, self).__init__()
        
        self.node_features_dim = node_features_dim
        self.edge_features_dim = edge_features_dim
        self.time_dim = time_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.num_heads = num_heads
        
        # Initial node embedding layer
        self.node_embedding = nn.Linear(node_features_dim, hidden_dim)
        
        # Edge embedding layer
        self.edge_embedding = nn.Linear(edge_features_dim, hidden_dim)
        
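        # TGAT layers. d_model is hidden_dim so that the query (the node embedding)
        # and each layer's output keep the width expected by the link predictor.
        self.layers = nn.ModuleList([
            TGATLayer(num_heads, hidden_dim, 
                      hidden_dim // num_heads, hidden_dim // num_heads, 
                      time_dim, dropout) 
            for _ in range(num_layers)
        ])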
        
        # For link prediction
        self.link_predictor = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid()
        )
        
    def compute_temporal_embeddings(self, node_features, neighbor_indices, neighbor_times, neighbor_features=None, edge_features=None):
        """
        Compute node embeddings using the temporal graph attention mechanism.
        
        Args:
            node_features: Node features [batch_size, node_features_dim]
            neighbor_indices: Indices of neighbors [batch_size, num_neighbors]
            neighbor_times: Timestamps of neighbor connections [batch_size, num_neighbors]
            neighbor_features: Features of neighbors [batch_size, num_neighbors, node_features_dim]
            edge_features: Edge features [batch_size, num_neighbors, edge_features_dim]
            
        Returns:
            node_embeddings: Computed temporal node embeddings [batch_size, hidden_dim]
        """
        # Initial embedding
        node_embeddings = self.node_embedding(node_features)  # [b, hidden_dim]
        
        # If neighbor features are not provided, use a lookup table or zeros
        if neighbor_features is None:
            neighbor_embeddings = torch.zeros(node_features.size(0), neighbor_indices.size(1), self.hidden_dim, device=node_features.device)
        else:
            neighbor_embeddings = self.node_embedding(neighbor_features)  # [b, num_n, hidden_dim]
        
        # Process edge features if provided
        if edge_features is not None:
            edge_embeddings = self.edge_embedding(edge_features)  # [b, num_n, hidden_dim]
            # Incorporate edge features into neighbor embeddings
            neighbor_embeddings = torch.cat([neighbor_embeddings, edge_embeddings], dim=-1)  # [b, num_n, 2*hidden_dim]
        
        # Create masks if needed (e.g., for padding)
        masks = None  # Set to appropriate mask if needed
        
        # Apply TGAT layers
        x = node_embeddings
        for layer in self.layers:
            x = layer(x, neighbor_embeddings, neighbor_times, masks)
        
        return x
    
    def forward(self, src_features, dst_features, src_neighbor_indices, dst_neighbor_indices, 
                src_neighbor_times, dst_neighbor_times, src_neighbor_features=None, 
                dst_neighbor_features=None, src_edge_features=None, dst_edge_features=None):
        """
        Forward pass for link prediction task.
        
        Returns:
            link_prob: Probability of link existence [batch_size, 1]
        """
        # Compute node embeddings for both source and destination nodes
        src_embeddings = self.compute_temporal_embeddings(
            src_features, src_neighbor_indices, src_neighbor_times, 
            src_neighbor_features, src_edge_features
        )
        
        dst_embeddings = self.compute_temporal_embeddings(
            dst_features, dst_neighbor_indices, dst_neighbor_times, 
            dst_neighbor_features, dst_edge_features
        )
        
        # Concatenate embeddings for link prediction
        link_features = torch.cat([src_embeddings, dst_embeddings], dim=1)  # [b, 2*hidden_dim]
        
        # Predict link
        link_prob = self.link_predictor(link_features)
        
        return link_prob

# Example usage:
"""
# Initialize model
model = TGAT(
    node_features_dim=100,  # Dimension of node features
    edge_features_dim=50,    # Dimension of edge features 
    time_dim=10,            # Dimension of time encoding
    hidden_dim=128,         # Hidden dimension size
    num_layers=2,           # Number of TGAT layers
    num_heads=2,            # Number of attention heads
    dropout=0.1            # Dropout rate
)

# Example data
batch_size = 32
num_neighbors = 20

# Node features
src_features = torch.randn(batch_size, 100)
dst_features = torch.randn(batch_size, 100)

# Neighbor indices
src_neighbor_indices = torch.randint(0, 1000, (batch_size, num_neighbors))
dst_neighbor_indices = torch.randint(0, 1000, (batch_size, num_neighbors))

# Neighbor timestamps
src_neighbor_times = torch.rand(batch_size, num_neighbors)
dst_neighbor_times = torch.rand(batch_size, num_neighbors)

# Edge features
src_edge_features = torch.randn(batch_size, num_neighbors, 50)
dst_edge_features = torch.randn(batch_size, num_neighbors, 50)

# Forward pass
link_prob = model(src_features, dst_features, src_neighbor_indices, dst_neighbor_indices,
                 src_neighbor_times, dst_neighbor_times, src_edge_features=src_edge_features,
                 dst_edge_features=dst_edge_features)

print(link_prob.shape)  # Should be [batch_size, 1]
"""

### Implementation Details

This TGAT implementation consists of several components:

1. **TimeEncoder**: Encodes timestamps into feature vectors by applying a cosine at frequencies initialized on a log scale from 1 down to 1e-9. This allows the model to capture temporal patterns at different scales.

2. **MultiHeadAttention**: Implements the standard multi-head attention mechanism from the Transformer architecture, enabling the model to attend to different parts of the neighborhood.

3. **PositionwiseFeedForward**: A point-wise feed-forward network applied after the attention mechanism.

4. **TGATLayer**: Combines the attention mechanism with time encoding to create a temporal graph attention layer that can process both structural and temporal information.

5. **TGAT**: The full model that stacks multiple TGAT layers and includes components for link prediction.
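
To make the interaction between the time encoding and the attention step concrete, here is a small self-contained sketch. It swaps in PyTorch's built-in `nn.MultiheadAttention` instead of the custom `MultiHeadAttention` class above, and keeps the frequencies fixed for simplicity. All sizes and variable names are illustrative assumptions.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

batch_size, num_neighbors = 2, 5
d_model, d_time, n_head = 16, 4, 4

# Log-spaced frequencies from 1 down to 1e-9 (kept fixed in this sketch).
freqs = torch.logspace(0, -9, d_time)                        # [d_time]

node = torch.randn(batch_size, 1, d_model)                   # one target node per row
neighbors = torch.randn(batch_size, num_neighbors, d_model)
edge_times = torch.rand(batch_size, num_neighbors) * 100.0

# cos(t * w) for every neighbor and frequency -> [b, num_n, d_time]
time_features = torch.cos(edge_times.unsqueeze(-1) * freqs)

# Keys and values carry both the neighbor features and the time encoding.
keys = torch.cat([neighbors, time_features], dim=-1)         # [b, num_n, d_model + d_time]

# kdim/vdim allow keys and values to be wider than the query.
attention = nn.MultiheadAttention(
    embed_dim=d_model, num_heads=n_head,
    kdim=d_model + d_time, vdim=d_model + d_time,
    batch_first=True,
)

# True marks padded neighbor slots that must be ignored.
padding_mask = torch.zeros(batch_size, num_neighbors, dtype=torch.bool)
padding_mask[:, -2:] = True

output, weights = attention(node, keys, keys, key_padding_mask=padding_mask)
print(output.shape)   # torch.Size([2, 1, 16])
print(weights.shape)  # torch.Size([2, 1, 5]), averaged over heads
```

The padded neighbors receive zero attention weight, and the remaining weights for each node sum to one. Note that the key width (`d_model + d_time`) differs from the query width, which any attention module used here has to account for.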

The key innovation in TGAT is the incorporation of temporal information directly into the attention mechanism, allowing the model to learn how node relationships evolve over time.
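
In outline, and using symbols chosen here for illustration rather than taken from the paper, one attention head of the layer above can be written as follows. For a target node with embedding $x_0$ and neighbors $j = 1, \dots, N$ with embeddings $x_j$ and time values $t_j$, the time encoding of dimension $d_T \ge 2$ is

$$
\Phi(t)_i = \cos(\omega_i t + b_i), \qquad \omega_i = 10^{-9i/(d_T - 1)}, \quad b_i = 0 \quad \text{at initialization}, \quad i = 0, \dots, d_T - 1 .
$$

Each neighbor is represented by its features concatenated with the encoding of its time value:

$$
z_j = \left[\, x_j \,\Vert\, \Phi(t_j) \,\right].
$$

With $q = W_Q x_0$, $k_j = W_K z_j$ and $v_j = W_V z_j$, the attention weights and the aggregated output are

$$
\alpha_j = \frac{\exp\!\left(q^\top k_j / \sqrt{d_k}\right)}{\sum_{m=1}^{N} \exp\!\left(q^\top k_m / \sqrt{d_k}\right)}, \qquad h = \sum_{j=1}^{N} \alpha_j v_j .
$$

The outputs of all heads are concatenated and passed through the output projection, the residual connection, and layer normalization, as in `MultiHeadAttention.forward`.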

In [10]:
# Example instantiation of the TGAT model

# Node/edge feature dimensions - these should match your data
node_features_dim = 64
edge_features_dim = 32

# Model hyperparameters
time_dim = 10
hidden_dim = 128
num_layers = 2
num_heads = 4
dropout = 0.1

# Create the model
tgat_model = TGAT(
    node_features_dim=node_features_dim,
    edge_features_dim=edge_features_dim,
    time_dim=time_dim,
    hidden_dim=hidden_dim,
    num_layers=num_layers,
    num_heads=num_heads,
    dropout=dropout
)

print(f"TGAT model initialized with {num_layers} layers and {num_heads} attention heads.")

TGAT model initialized with 2 layers and 4 attention heads.
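
A common sanity check after building a model is to count its trainable parameters. The helper below is a hypothetical utility, not part of the original notebook, and is demonstrated on a small stand-in module. It applies to any `nn.Module` whose parameters are already initialized.

```python
import torch.nn as nn

def count_trainable_parameters(model: nn.Module) -> int:
    """Sum the number of elements over all parameters that require gradients."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

# Stand-in module: 4 * 2 weights + 2 biases = 10 parameters.
demo = nn.Linear(4, 2)
print(count_trainable_parameters(demo))  # 10
```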


### 3.2. Preprocessing the data