# Enhanced TGN Model with Visualizations

This notebook demonstrates the Enhanced Temporal Graph Network (TGN) model with:
1. The decay factor in `DecayTemporalAttention` disabled (set to 0), so the model can be evaluated without temporal decay
2. Visualizations of the Reddit dataset in graph format

We'll use the implementation from the main Enhanced_TGN.ipynb notebook but with these modifications.

In [None]:
# Import necessary libraries
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pandas as pd
import os
import matplotlib.pyplot as plt
import networkx as nx
from torch.utils.data import Dataset, DataLoader
from src.enhanced_tgn import TemporalGraphNetwork
from src.graph_visualization import visualize_temporal_graph, visualize_community_structure, visualize_temporal_communities
from src.dataset_visualization import visualize_reddit_dataset

## 1. Decay Factor in TGN

We've modified the DecayTemporalAttention class to disable the decay factor by setting it to zero. Let's implement it here to show the change:

In [None]:
class DecayTemporalAttention(nn.Module):
    """Multi-head temporal attention over a node's neighbors, with time decay disabled.

    The query is built from the node's own features and time encoding; keys and
    values are built from each neighbor's features, time encoding and edge
    features. The mechanism supports scaling the attention scores by
    ``exp(-decay_factor * time_diff)``, but this evaluation variant pins
    ``decay_factor`` to 0 so the scores are left untouched.

    Args:
        node_feat_dim: Width of the node (and neighbor) feature vectors.
        edge_feat_dim: Width of the edge feature vectors.
        time_feat_dim: Width of the time-encoding vectors.
        memory_dim: Per-head width of the query/key/value projections.
        output_dim: Width of the returned embedding.
        n_heads: Number of attention heads.
        dropout: Dropout probability applied to the attention weights.
        decay_factor: Accepted for interface compatibility only; it is ignored
            and the effective decay factor is always 0.0.
    """

    def __init__(self, node_feat_dim, edge_feat_dim, time_feat_dim, memory_dim, output_dim,
                 n_heads=2, dropout=0.1, decay_factor=0.1):
        super().__init__()
        self.n_heads = n_heads
        self.memory_dim = memory_dim
        # Deliberately ignore the decay_factor argument: decay is disabled for evaluation.
        self.decay_factor = 0.0

        # Dimension calculations
        self.query_dim = node_feat_dim + time_feat_dim
        self.key_dim = node_feat_dim + edge_feat_dim + time_feat_dim
        self.value_dim = self.key_dim

        # Query, key, value projections (all heads packed into one linear layer each)
        self.w_query = nn.Linear(self.query_dim, n_heads * memory_dim)
        self.w_key = nn.Linear(self.key_dim, n_heads * memory_dim)
        self.w_value = nn.Linear(self.value_dim, n_heads * memory_dim)

        # Output projection
        self.output_layer = nn.Linear(n_heads * memory_dim, output_dim)

        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(output_dim)

    def forward(self, node_features, node_time_features, neighbor_features,
                neighbor_time_features, edge_features, time_diffs, attention_mask=None):
        """Attend from each node to its neighbors and return the projected context.

        Args:
            node_features: ``[batch, node_feat_dim]`` features of the target nodes.
            node_time_features: ``[batch, time_feat_dim]`` time encoding of the target nodes.
            neighbor_features: ``[batch, n_neighbors, node_feat_dim]`` neighbor features.
            neighbor_time_features: Per-neighbor time encodings; reshaped to
                ``[batch, n_neighbors, -1]``.
            edge_features: Per-neighbor edge features; reshaped to
                ``[batch, n_neighbors, -1]``.
            time_diffs: Time differences, expected to broadcast as
                ``[batch, n_neighbors]``. Unused while the decay factor is 0.
            attention_mask: Optional ``[batch, n_neighbors]`` mask; entries equal
                to 0 are excluded from attention.

        Returns:
            A ``[batch, output_dim]`` tensor. When ``output_dim`` equals the width
            of ``node_features``, a residual connection and layer norm are applied.
        """
        batch_size, n_neighbors = neighbor_features.size(0), neighbor_features.size(1)
        memory_dim = self.memory_dim

        # Create query from node features and time features
        query = torch.cat([node_features, node_time_features], dim=1)
        query = self.w_query(query).view(batch_size, self.n_heads, memory_dim)

        # Create key/value input by concatenating the per-neighbor feature groups
        key_input = torch.cat([
            neighbor_features.reshape(batch_size, n_neighbors, -1),
            neighbor_time_features.reshape(batch_size, n_neighbors, -1),
            edge_features.reshape(batch_size, n_neighbors, -1),
        ], dim=-1)

        key = self.w_key(key_input).view(batch_size, n_neighbors, self.n_heads, memory_dim)
        key = key.permute(0, 2, 1, 3)  # [batch_size, n_heads, n_neighbors, memory_dim]

        value = self.w_value(key_input).view(batch_size, n_neighbors, self.n_heads, memory_dim)
        value = value.permute(0, 2, 1, 3)  # [batch_size, n_heads, n_neighbors, memory_dim]

        # Scaled dot-product scores: [batch_size, n_heads, 1, n_neighbors].
        # `** 0.5` keeps this class independent of a module-level `import math`.
        scores = torch.matmul(query.unsqueeze(2), key.transpose(-2, -1)) / (memory_dim ** 0.5)

        # Decay-based attenuation. With decay_factor == 0 the multiplier is exactly 1,
        # so skip it; this also avoids 0 * inf = NaN for infinite time differences.
        if self.decay_factor != 0.0:
            time_decay = torch.exp(-self.decay_factor * time_diffs).unsqueeze(1).unsqueeze(1)
            scores = scores * time_decay

        # Apply attention mask if provided (mask value 0 means "do not attend")
        if attention_mask is not None:
            mask = attention_mask.unsqueeze(1).unsqueeze(1)
            scores = scores.masked_fill(mask == 0, -1e9)

        # Apply softmax to get attention weights
        attn_weights = F.softmax(scores, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # Weighted sum of values, then merge the heads: [batch_size, n_heads * memory_dim]
        context = torch.matmul(attn_weights, value)
        context = context.transpose(1, 2).reshape(batch_size, -1)

        # Apply output projection
        output = self.output_layer(context)

        # Apply residual connection if dimensions match
        if output.size(1) == node_features.size(1):
            output = self.layer_norm(output + node_features)

        return output

print("DecayTemporalAttention class with disabled decay factor defined.")

## 2. Load the Dataset

We'll load the Reddit dataset using the same function as in the main notebook.

In [None]:
# Helper functions for data preparation
def load_reddit_dataset(*, data_dir='data', node_feat_dim=100, edge_feat_dim=50,
                        train_ratio=0.7, val_ratio=0.15, seed=None):
    """Load the Reddit edge list and prepare chronological train/val/test splits.

    Looks for ``reddit_TGAT.csv`` (comma-separated) in ``data_dir`` first and
    falls back to ``soc-redditHyperlinks-title.tsv`` (tab-separated). Node and
    edge features are random placeholders, not derived from the data.

    Args:
        data_dir: Directory that contains the dataset file.
        node_feat_dim: Width of the random node feature vectors.
        edge_feat_dim: Width of the random edge feature vectors.
        train_ratio: Fraction of the (time-sorted) edges used for training.
        val_ratio: Fraction of the edges used for validation; the remainder
            after train and validation becomes the test split.
        seed: Optional seed for the random features. When ``None`` the global
            NumPy random state is used, as before.

    Returns:
        A dict with keys ``train_df``, ``val_df``, ``test_df``,
        ``node_features``, ``edge_features`` and ``num_nodes``, or ``None`` if
        the dataset could not be loaded (the error is printed, not raised).
    """
    try:
        # Candidate files in priority order, each with its field separator
        candidates = (
            (os.path.join(data_dir, 'reddit_TGAT.csv'), ','),
            (os.path.join(data_dir, 'soc-redditHyperlinks-title.tsv'), '\t'),
        )
        for data_path, sep in candidates:
            if os.path.exists(data_path):
                df = pd.read_csv(data_path, sep=sep)
                break
        else:
            raise FileNotFoundError("Reddit dataset not found")

        print(f"Loaded Reddit dataset with {len(df)} edges")

        # Pick the endpoint columns: named subreddit columns if present,
        # otherwise assume the first two columns are source and destination
        if 'SOURCE_SUBREDDIT' in df.columns and 'TARGET_SUBREDDIT' in df.columns:
            src_col, dst_col = 'SOURCE_SUBREDDIT', 'TARGET_SUBREDDIT'
        else:
            src_col, dst_col = df.columns[0], df.columns[1]

        # Assign a contiguous integer index to every distinct node
        all_nodes = pd.concat([df[src_col], df[dst_col]]).unique()
        node_mapping = {node: idx for idx, node in enumerate(all_nodes)}
        num_nodes = len(node_mapping)

        # Normalise timestamps to float seconds since the Unix epoch
        if 'TIMESTAMP' in df.columns:
            # Subtracting the epoch is independent of the datetime resolution and of
            # the platform's integer width, unlike `.astype(int) / 10**9`.
            parsed = pd.to_datetime(df['TIMESTAMP'], utc=True)
            epoch = pd.Timestamp('1970-01-01', tz='UTC')
            df['timestamp'] = (parsed - epoch) / pd.Timedelta(seconds=1)
        elif 'timestamp' in df.columns:
            df['timestamp'] = df['timestamp'].astype(float)
        else:
            # If no timestamp, create artificial ones that follow file order
            df['timestamp'] = range(len(df))

        # Sort chronologically; a stable sort keeps file order for tied timestamps
        df = df.sort_values('timestamp', kind='mergesort')

        # Map nodes to indices
        df['src_idx'] = df[src_col].map(node_mapping)
        df['dst_idx'] = df[dst_col].map(node_mapping)

        # Random placeholder features (the PROPERTIES column is not parsed)
        if seed is None:
            node_features = np.random.randn(num_nodes, node_feat_dim)
            edge_features = np.random.randn(len(df), edge_feat_dim)
        else:
            rng = np.random.default_rng(seed)
            node_features = rng.standard_normal((num_nodes, node_feat_dim))
            edge_features = rng.standard_normal((len(df), edge_feat_dim))

        # Chronological split into train/val/test
        train_end = int(len(df) * train_ratio)
        val_end = int(len(df) * (train_ratio + val_ratio))

        train_df = df.iloc[:train_end]
        val_df = df.iloc[train_end:val_end]
        test_df = df.iloc[val_end:]

        print(f"Split dataset into {len(train_df)} train, {len(val_df)} validation, and {len(test_df)} test samples")
        print(f"Number of unique nodes: {num_nodes}")

        return {
            'train_df': train_df,
            'val_df': val_df,
            'test_df': test_df,
            'node_features': node_features,
            'edge_features': edge_features,
            'num_nodes': num_nodes
        }
    except Exception as e:
        # Best-effort loader for a notebook: report the problem and let callers
        # branch on the None result instead of crashing the cell.
        print(f"Error loading dataset: {e}")
        return None

# Load the dataset. dataset_info is None when loading fails: the loader prints
# the error instead of raising, so later cells must check it before use.
dataset_info = load_reddit_dataset()

## 3. Visualize the Dataset

Now let's visualize the Reddit dataset as a graph using our new visualization functions.

In [None]:
# Draw the Reddit graph when the dataset was loaded; otherwise report that
# there is nothing to visualize.
if not dataset_info:
    print("Dataset not available for visualization.")
else:
    G = visualize_reddit_dataset(dataset_info)

## 4. Initialize and Use the Modified TGN Model

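Now we'll initialize the TGN model for evaluation with the decay factor set to 0. Note that the model is built by `TemporalGraphNetwork` from `src.enhanced_tgn`, so the zero decay factor takes effect only if that module's `DecayTemporalAttention` contains the same change as the class shown above.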

In [None]:
# Model configuration: take the node count from the dataset when it was loaded,
# otherwise fall back to a default of 1000 nodes.
num_nodes = dataset_info['num_nodes'] if dataset_info else 1000

# Model hyperparameters
node_feat_dim, edge_feat_dim = 100, 50   # dimensions of node / edge features
memory_dim, message_dim = 100, 100       # dimensions of node memory / messages
time_dim, embedding_dim = 10, 100        # dimensions of time encoding / final node embeddings
n_layers, n_heads = 2, 2                 # graph attention layers / attention heads
dropout = 0.1                            # dropout probability

# Initialize the model
model = TemporalGraphNetwork(
    num_nodes=num_nodes,
    # feature and state sizes
    node_feat_dim=node_feat_dim,
    edge_feat_dim=edge_feat_dim,
    memory_dim=memory_dim,
    time_dim=time_dim,
    embedding_dim=embedding_dim,
    message_dim=message_dim,
    # attention stack
    n_layers=n_layers,
    n_heads=n_heads,
    dropout=dropout,
    # memory module
    use_memory=True,
    message_function='mlp',     # Options: 'mlp', 'identity'
    memory_updater='gru',       # Options: 'gru', 'rnn'
    aggregator='lstm',          # Options: 'lstm', 'mean'
)

print(f"Enhanced TGN model initialized with {n_layers} layers and {n_heads} attention heads.")
print("Note: The decay factor in DecayTemporalAttention is set to 0 for evaluation without decay metrics.")

## 5. Comparison with and without Decay Factor

For illustration, let's plot the attention-weight multiplier exp(-decay_factor × time difference) for several decay factors, to show what setting the decay factor to 0 changes.

In [None]:
import math

# Function to demonstrate how the decay factor affects attention weights
def compare_decay_factors():
    """Plot the multiplier exp(-factor * time_diff) for several decay factors.

    One curve is drawn per decay factor over a fixed set of simulated time
    differences, the figure is shown, and a short explanation of the effect
    is printed afterwards.
    """
    # Simulated gaps between the current time and past interactions
    elapsed = torch.tensor([0.1, 1.0, 5.0, 10.0, 20.0])
    x_values = elapsed.numpy()

    plt.figure(figsize=(10, 6))

    # One curve per decay factor; 0.0 is the value used by the modified class
    for rate in (0.0, 0.1, 0.5, 1.0):
        multiplier = torch.exp(-rate * elapsed)
        plt.plot(x_values, multiplier.numpy(), marker='o', label=f"Decay factor = {rate}")

    plt.title("Effect of Decay Factor on Attention Weights")
    plt.xlabel("Time Difference")
    plt.ylabel("Attention Weight Multiplier")
    plt.legend()
    plt.grid(True)
    plt.show()

    # Explain the impact (emitted as one write; the text is unchanged)
    explanation = [
        "Impact of decay factor on the model:",
        "- With decay_factor = 0.0 (our modified version): All interactions receive equal weight regardless of time.",
        "- With decay_factor > 0: Recent interactions get higher weight than older ones.",
        "- Higher decay factors cause weights to decay more rapidly with time.",
        "\nBy setting decay_factor = 0, we can evaluate the model's performance without the temporal decay effect.",
    ]
    print("\n".join(explanation))

# Run the comparison
compare_decay_factors()

## 6. Conclusion

In this notebook, we've:

1. Modified the DecayTemporalAttention class by setting the decay factor to zero to evaluate the model without temporal decay.
2. Added visualization capabilities to view the Reddit dataset as a graph, including community structure and temporal evolution.
3. Demonstrated how the decay factor affects the attention weights in the TGN model.

These modifications allow for a more comprehensive analysis of the TGN model's behavior and the structure of temporal graph datasets.