# PyTorch Geometric Temporal for Dynamic Network Adaptation

This notebook explores PyTorch Geometric Temporal and how it can be applied in the research project *Implementing Decay-Based Temporal Attention for Dynamic Network Adaptation*.

Code from: https://github.com/benedekrozemberczki/pytorch_geometric_temporal

## 1. Setup and Installation

In [None]:
# Install the required libraries.
# PyTorch must be installed first. If it is not, uncomment and run the command for your platform from
# https://pytorch.org/get-started/locally/ (the line below targets CUDA 12.8 GPU builds).
# !pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu128

# Install PyTorch Geometric and PyTorch Geometric Temporal
!pip install torch-geometric
!pip install torch-geometric-temporal

## 2. Import Libraries

In [None]:
import torch
import torch_geometric
import torch_geometric_temporal

print(f"PyTorch version: {torch.__version__}")
print(f"PyTorch Geometric version: {torch_geometric.__version__}")
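# Read the version from the installed package metadata, which does not rely on the package defining __version__.
from importlib.metadata import version
print(f"PyTorch Geometric Temporal version: {version('torch-geometric-temporal')}")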

## 3. Implementing Decay-Based Temporal Attention

This section builds attention layers that can be adapted or extended with decay-based temporal attention for dynamic networks, where the influence of an interaction shrinks as it ages.
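
To make "decay-based" concrete, the sketch below assumes an exponential kernel, weight = exp(-decay_rate * age), which is the form used by the layer later in this notebook. The helper names `decay_weight` and `half_life` are illustrative and not part of any library; the half-life is a convenient way to choose `decay_rate` from the time scale of the data.

```python
import math


def decay_weight(age: float, decay_rate: float) -> float:
    """Exponential decay weight for an interaction that is `age` time units old."""
    # Negative ages (events "in the future") are clamped to zero so the weight never exceeds 1.
    return math.exp(-decay_rate * max(age, 0.0))


def half_life(decay_rate: float) -> float:
    """Age at which the weight has dropped to 0.5."""
    return math.log(2.0) / decay_rate


decay_rate = 0.1  # illustrative value
for age in (0.0, 1.0, 10.0, 50.0):
    print(f"age={age:5.1f}  weight={decay_weight(age, decay_rate):.4f}")
print(f"half-life: {half_life(decay_rate):.2f} time units")
```

A larger `decay_rate` shortens the half-life, so the model forgets old interactions faster. Timestamps should be rescaled (for example to hours or days) so that typical ages are within a few half-lives. Otherwise nearly all weights collapse to 0 or stay close to 1.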

### 3.1. Baseline Implementation: TGAT-Style Attention

We start with a baseline attention layer and extend it with a decay mechanism. PyTorch Geometric Temporal does not appear to provide a Temporal Graph Attention Network (TGAT) layer such as `TGATConv`, so the baseline here builds on `GATConv` from PyTorch Geometric, which accepts per-edge features through `edge_attr` when `edge_dim` is set. This is a sketch of the idea rather than a faithful TGAT implementation: TGAT's learned time encoding is not reproduced, and the decay enters only as a scalar edge feature. The time difference is computed as node time minus edge time, so older interactions receive smaller weights.

In [None]:
import torch
from torch_geometric.nn import GATConv


class TGATWithDecay(GATConv):
    """GATConv whose attention also sees an exponential time-decay edge feature."""

    def __init__(self, in_channels, out_channels, heads=1, decay_rate=0.1, **kwargs):
        # edge_dim=1: each edge carries one scalar feature, the decay weight.
        # decay_rate is an explicit argument so it is not forwarded to GATConv via **kwargs.
        super().__init__(in_channels, out_channels, heads=heads, edge_dim=1, **kwargs)
        self.decay_rate = decay_rate

    def forward(self, x, edge_index, edge_time, node_time=None, size=None):
        # x:         [num_nodes, in_channels] node features (a plain tensor, so it has no .time attribute)
        # edge_time: [num_edges] timestamp of each interaction
        # node_time: [num_nodes] reference time per node, e.g. the query time;
        #            if None, edge_time is assumed to already hold time differences.
        if node_time is None:
            time_diff = edge_time
        else:
            # Age of each edge relative to its target node (row 1 of edge_index).
            # Use edge_index[0] instead if ages should be measured from the source node.
            time_diff = node_time[edge_index[1]] - edge_time
        # Clamp at zero so that edges newer than the reference time do not get weights above 1.
        time_diff = time_diff.float().clamp(min=0)
        decay = torch.exp(-self.decay_rate * time_diff).unsqueeze(-1)  # [num_edges, 1]
        return super().forward(x, edge_index, edge_attr=decay, size=size)


# Example usage on synthetic data (random values, for checking shapes only):
num_nodes, num_edges = 5, 12
x = torch.randn(num_nodes, 16)                          # node features
edge_index = torch.randint(0, num_nodes, (2, num_edges))  # edge connectivity
edge_time = torch.rand(num_edges) * 10.0                # timestamp of each edge
node_time = torch.full((num_nodes,), 10.0)              # query time for every node

model = TGATWithDecay(in_channels=16, out_channels=32, heads=2, decay_rate=0.05)
output = model(x, edge_index, edge_time, node_time)
print(output.shape)  # [num_nodes, heads * out_channels], since concat defaults to True
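
Passing the decay as an edge feature lets the layer learn how to use it, but it does not guarantee that older edges receive less attention. One alternative, sketched below in plain PyTorch, applies the decay directly to the attention coefficients by adding log-decay to the logits before the per-node softmax. The function name `decayed_attention` and its arguments are hypothetical, and `scores` stands in for whatever scoring function the attention layer uses.

```python
import torch


def decayed_attention(scores, edge_index, edge_age, decay_rate, num_nodes):
    """Softmax over each target node's incoming edges, scaled by exp(-decay_rate * age).

    scores:     [num_edges] raw attention logits
    edge_index: [2, num_edges], row 0 = source nodes, row 1 = target nodes
    edge_age:   [num_edges] time elapsed since each interaction
    """
    target = edge_index[1]
    # Adding log(decay) = -decay_rate * age to a logit multiplies the
    # unnormalised attention weight by the decay factor.
    logits = scores - decay_rate * edge_age.clamp(min=0)
    # Subtract the per-target maximum for numerical stability.
    max_per_target = torch.full((num_nodes,), float("-inf")).scatter_reduce(
        0, target, logits, reduce="amax"
    )
    exp = torch.exp(logits - max_per_target[target])
    # Sum the exponentials of all edges that share a target node.
    denom = torch.zeros(num_nodes, dtype=exp.dtype).index_add_(0, target, exp)
    return exp / denom[target]


# Three edges point at node 4 and one at node 0; logits are equal, so only age matters.
edge_index = torch.tensor([[0, 1, 2, 3], [4, 4, 4, 0]])
scores = torch.zeros(4)
edge_age = torch.tensor([0.0, 10.0, 20.0, 5.0])
alpha = decayed_attention(scores, edge_index, edge_age, decay_rate=0.1, num_nodes=5)
print(alpha)
```

With this formulation, older edges are guaranteed a smaller share of attention whenever their raw scores are equal, and the coefficients for each target node still sum to 1. The trade-off is that the decay shape is fixed rather than learned, unless `decay_rate` is made a trainable parameter.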