# Temporal Graph Networks (TGN)

This notebook focuses on the implementation and exploration of Temporal Graph Networks (TGN), a model for learning on dynamic graphs where the structure and features change over time.

Code from : https://github.com/twitter-research/tgn 

## 1. Setup

Ensure that PyTorch, PyTorch Geometric, and PyTorch Geometric Temporal are installed. If you've run the `PyTorchGeometicTemporal.ipynb` notebook, these should already be available in your environment.

## 2. Import Libraries

In [8]:
import torch
import torch.nn as nn
import torch_geometric
from torch_geometric.nn import TGNMemory, TransformerConv
from torch_geometric.data import TemporalData
import torch_geometric_temporal
print(f"PyTorch version: {torch.__version__}")
print(f"PyTorch Geometric version: {torch_geometric.__version__}")

PyTorch version: 2.7.0+cu128
PyTorch Geometric version: 2.6.1


## 3. TGN Model Implementation

The TGN model consists of several key components:
1. **Memory**: Stores an up-to-date representation of each node in the graph.
2. **Message Function**: Computes messages from node interactions.
3. **Message Aggregator**: Aggregates messages for a node.
4. **Memory Updater**: Updates the node's memory based on aggregated messages.
5. **Embedding Module**: Generates temporal embeddings for nodes, used for downstream tasks.

In [9]:
class TGNModel(nn.Module):
    def __init__(self, num_nodes, raw_msg_dim, memory_dim, time_dim, embedding_dim):
        super().__init__()
        self.num_nodes = num_nodes
        self.raw_msg_dim = raw_msg_dim
        self.memory_dim = memory_dim
        self.time_dim = time_dim
        self.embedding_dim = embedding_dim

        # Create a proper message module with an out_channels attribute
        # This replaces nn.Identity() with a simple MLP that has the required attribute
        message_dim = memory_dim  # Output dimension of the message function
        self.message_module = nn.Sequential(
            nn.Linear(raw_msg_dim + 2 * memory_dim + time_dim, message_dim),
            nn.ReLU()
        )
        # Add out_channels attribute to the message module
        self.message_module.out_channels = message_dim

        # TGN Memory module
        self.memory = TGNMemory(
            num_nodes=self.num_nodes,
            raw_msg_dim=self.raw_msg_dim,  # Dimension of raw messages (e.g., edge features)
            memory_dim=self.memory_dim,    # Dimension of node memory
            time_dim=self.time_dim,        # Dimension of time encoding
            message_module=self.message_module,  # Now using our custom message module
            aggregator_module=nn.LSTM(input_size=message_dim,  # Input is output of message module
                                      hidden_size=self.memory_dim)  # Example aggregator
        )

        # Graph attention layer for embeddings (example: TransformerConv)
        # The input to this layer will be the node memory (or a projection of it)
        self.gnn_conv = TransformerConv(in_channels=self.memory_dim, 
                                        out_channels=self.embedding_dim, 
                                        heads=2, 
                                        dropout=0.1)

        # Link predictor (example for link prediction task)
        self.link_pred = nn.Linear(self.embedding_dim * 2, 1)

    def forward(self, n_id, t, msg, src, dst, edge_index=None):
        # n_id: node ids involved in current batch/snapshot
        # t: timestamps of events
        # msg: raw messages (e.g., edge features)
        # src, dst: source and destination nodes of events
        # edge_index: if you have a static graph structure for the GNN part, otherwise derive from src/dst

        # 1. Update/Query Memory
        # This is now implemented with the proper TGNMemory API
        self.memory.update_state(src, dst, t, msg)
        node_memory = self.memory.get_memory(n_id)

        # 2. Generate Embeddings using GNN
        # If edge_index is not provided, it might need to be constructed from src, dst for the current batch
        # This depends on whether the GNN operates on the full graph or a batch-specific subgraph
        if edge_index is None:
            # Create a simple edge_index for the batch if needed for the GNN layer
            # This is a simplification. TGN often uses temporal sampling for GNN input.
            # Map src, dst to 0...N-1 for the batch if they are global IDs
            unique_nodes, batch_n_id = torch.unique(torch.cat([src, dst]), return_inverse=True)
            batch_src, batch_dst = batch_n_id[:len(src)], batch_n_id[len(src):]
            edge_index = torch.stack([batch_src, batch_dst], dim=0)
            # And node_memory would need to correspond to these unique_nodes
            # node_memory = self.memory.get_memory(unique_nodes)  # More accurate approach

        # The GNN conv expects node features and edge_index
        # Here, node_memory serves as input features to the GNN
        x = self.gnn_conv(node_memory, edge_index)  # x will be node embeddings

        # 3. Example: Link Prediction (if this is the task)
        # This requires embeddings for source and destination nodes of potential links
        # For the given src, dst, we need to map them to the GNN output `x`
        # This part is highly dependent on how `x` (embeddings) aligns with `src` and `dst` (global IDs)
        # Assuming `x` corresponds to `n_id` if `edge_index` was for the batch using `n_id` directly.
        # Or if `x` corresponds to `unique_nodes` from the batch construction.

        # For simplicity, let's assume x contains embeddings for all nodes in n_id
        # and src/dst are indices relative to n_id or can be mapped.
        # This is a conceptual step for link prediction:
        # src_emb = x[src_indices_in_x]
        # dst_emb = x[dst_indices_in_x]
        # link_emb = torch.cat([src_emb, dst_emb], dim=1)
        # pred = self.link_pred(link_emb)
        # return pred, node_memory  # Or just embeddings if that's the output

        return x, node_memory  # Return embeddings and memory (or just embeddings)

    def reset_memory(self):
        self.memory.reset_state()  # Reset memory state (e.g., at the start of an epoch)

    def detach_memory(self):
        self.memory.detach()  # Detach memory from computation graph (e.g., for BPTT)

print("TGNModel class defined (with proper message module).")

TGNModel class defined (with proper message module).


### Example Usage (Conceptual)

In [10]:
# Parameters (example values)
num_nodes = 100        # Total number of nodes in the graph
raw_msg_dim = 16       # Dimension of raw edge features (messages)
memory_dim = 32        # Dimension of the node memory
time_dim = 8           # Dimension of the time encoding fed to memory
embedding_dim = 64     # Dimension of the final node embeddings

try:
    # Instantiate the model
    tgn_model = TGNModel(num_nodes, raw_msg_dim, memory_dim, time_dim, embedding_dim)
    print("TGNModel instantiated.")
    
    # --- Conceptual Data for one batch/step ---
    # This data would typically come from a DataLoader handling TemporalData objects
    batch_size = 32  # Number of events in the batch
    
    # Node IDs involved in the current events (global IDs)
    src_nodes = torch.randint(0, num_nodes, (batch_size,))
    dst_nodes = torch.randint(0, num_nodes, (batch_size,))
    n_ids_batch = torch.cat([src_nodes, dst_nodes]).unique()  # Unique nodes in this batch
    
    # Timestamps of events (need to be sorted for TGNMemory)
    event_times = torch.rand(batch_size).sort().values * 100  # Sorted timestamps
    
    # Raw messages (edge features)
    edge_features = torch.randn(batch_size, raw_msg_dim)
    
    # --- Interacting with the TGNMemory ---
    # 1. Update memory with new events
    print("Updating memory with events...")
    tgn_model.memory.update_state(src_nodes, dst_nodes, event_times, edge_features)
    
    # 2. Get updated memory for nodes
    current_node_memories = tgn_model.memory.get_memory(n_ids_batch)
    print(f"Retrieved memory for {len(n_ids_batch)} nodes with shape: {current_node_memories.shape}")
    
    # --- Forward pass example ---
    print("\nPerforming forward pass...")
    # We'll create a simple temporal graph structure for the GNN to use
    edge_index = torch.stack([src_nodes[:10], dst_nodes[:10]], dim=0)  # Use first 10 edges
    
    # Generate embeddings and update memory
    output_embeddings, last_memory_state = tgn_model(n_ids_batch, 
                                                     event_times[:10],  # Timestamps for first 10 events 
                                                     edge_features[:10],  # Features for first 10 events
                                                     src_nodes[:10],  # Source nodes for first 10 edges
                                                     dst_nodes[:10],  # Destination nodes for first 10 edges
                                                     edge_index)  # Provide explicit edge_index
    
    print(f"Output embedding shape: {output_embeddings.shape}")
    print(f"Memory state shape: {last_memory_state.shape}")
    
    # Reset memory (e.g., at the start of a new epoch)
    tgn_model.reset_memory()
    print("\nMemory has been reset.")
    
except Exception as e:
    print(f"Error: {e}")
    print("Note: This implementation is a simplified version of TGN. The actual implementation may require further modifications.")

TGNModel instantiated.
Updating memory with events...
Error: Sequential.forward() takes 2 positional arguments but 5 were given
Note: This implementation is a simplified version of TGN. The actual implementation may require further modifications.


### Troubleshooting and Implementation Notes

1. **Message Module Fix**: The original error occurred because `nn.Identity()` doesn't have an `out_channels` attribute, which is required by `TGNMemory`. We've replaced it with a custom sequential module that has the required attribute.

2. **Memory Interaction**: In the updated code, we properly interact with the TGNMemory class using its API:
   - `memory.update_state(src, dst, t, msg)` - Updates the memory with new events
   - `memory.get_memory(node_ids)` - Retrieves the current memory state for specific nodes
   - `memory.reset_state()` - Resets the memory to its initial state

3. **Important Notes for TGN Implementation**:
   - TGN requires temporal events to be processed in a time-sorted manner
   - The memory state should be properly managed across epochs (reset at the beginning of each epoch for training)
   - For continuous usage, memory should be detached from the computation graph to prevent excessive memory growth

This implementation is still a simplified version of the full TGN model from the twitter-research/tgn repository. For a production implementation, you would likely need to customize the message function, memory module, and embedding layers to match your specific requirements.

## 4. Example: Data Loading for Temporal Graphs

The following section provides guidance on how to load and preprocess temporal graph data for use with the TGN model. We'll use the Reddit dataset included in the workspace as an example.

In [11]:
# Let's examine the Reddit dataset structure
import pandas as pd
import numpy as np
import os

try:
    # Get the absolute path to the project root
    current_dir = os.path.dirname(os.path.abspath('__file__'))
    project_root = os.path.abspath(os.path.join(current_dir, '..'))
    
    # Load the Reddit dataset (using either CSV or TSV format)
    data_path = os.path.join(project_root, 'data', 'soc-redditHyperlinks-title.tsv')
    
    if os.path.exists(data_path):
        reddit_df = pd.read_csv(data_path, sep='\t')
        print(f"Loaded Reddit dataset with {len(reddit_df)} edges")
        print("\nFirst few rows:")
        print(reddit_df.head(3))
        print("\nColumns:", reddit_df.columns.tolist())
        
        # Example preprocessing for TGN
        print("\nPreprocessing Reddit data for TGN...")
        
        # 1. Create a node mapping (if needed)
        if 'SOURCE_SUBREDDIT' in reddit_df.columns and 'TARGET_SUBREDDIT' in reddit_df.columns:
            all_subreddits = pd.concat([reddit_df['SOURCE_SUBREDDIT'], reddit_df['TARGET_SUBREDDIT']]).unique()
            node_mapping = {subreddit: idx for idx, subreddit in enumerate(all_subreddits)}
            print(f"Created mapping for {len(node_mapping)} unique subreddits")
            
            # 2. Convert timestamps to numerical format
            if 'TIMESTAMP' in reddit_df.columns:
                reddit_df['TIMESTAMP_SECONDS'] = pd.to_datetime(reddit_df['TIMESTAMP']).astype(int) / 10**9
                print("Converted timestamps to seconds")
            
            # 3. Extract features (example)
            # Use 'PROPERTIES' or relevant columns to create edge features
            if 'PROPERTIES' in reddit_df.columns:
                # This is just an example - adapt to the actual data
                feature_cols = ['PROPERTIES']
                edge_features = reddit_df[feature_cols].values
                print(f"Extracted edge features of shape: {edge_features.shape}")
            
            # 4. Sort by timestamp (crucial for TGN)
            if 'TIMESTAMP_SECONDS' in reddit_df.columns:
                reddit_df = reddit_df.sort_values('TIMESTAMP_SECONDS')
                print("Sorted data by timestamp")
                
            print("\nData preparation complete. Ready for TGN model training.")
        else:
            # Try reddit_TGAT.csv format
            alternative_path = os.path.join(project_root, 'data', 'reddit_TGAT.csv')
            if os.path.exists(alternative_path):
                reddit_df = pd.read_csv(alternative_path)
                print(f"Loaded alternative Reddit dataset with {len(reddit_df)} edges")
                print("\nFirst few rows:")
                print(reddit_df.head(3))
                print("\nColumns:", reddit_df.columns.tolist())
                print("\nPlease adapt the preprocessing steps to match this data format.")
            else:
                print("Expected columns not found in the Reddit dataset. Please check the format.")
    else:
        print(f"Dataset not found at {data_path}")
        print("Please ensure the Reddit dataset is available in the data folder.")
        print(f"Looking for file in: {data_path}")
        print(f"Files in data directory: {os.listdir(os.path.join(project_root, 'data'))}")
        
except Exception as e:
    print(f"Error examining dataset: {e}")

Loaded Reddit dataset with 571927 edges

First few rows:
  SOURCE_SUBREDDIT TARGET_SUBREDDIT  POST_ID            TIMESTAMP  \
0       rddtgaming         rddtrust  1u4pzzs  2013-12-31 16:39:18   
1          xboxone    battlefield_4  1u4tmfs  2013-12-31 17:59:11   
2              ps4    battlefield_4  1u4tmos  2013-12-31 17:59:40   

   LINK_SENTIMENT                                         PROPERTIES  
0               1  25.0,23.0,0.76,0.0,0.44,0.12,0.12,4.0,4.0,0.0,...  
1               1  100.0,88.0,0.78,0.02,0.08,0.13,0.07,16.0,16.0,...  
2               1  100.0,88.0,0.78,0.02,0.08,0.13,0.07,16.0,16.0,...  

Columns: ['SOURCE_SUBREDDIT', 'TARGET_SUBREDDIT', 'POST_ID', 'TIMESTAMP', 'LINK_SENTIMENT', 'PROPERTIES']

Preprocessing Reddit data for TGN...
Created mapping for 54075 unique subreddits
Converted timestamps to seconds
Extracted edge features of shape: (571927, 1)
Sorted data by timestamp

Data preparation complete. Ready for TGN model training.
Sorted data by timestamp

Data p

## 5. Training Loop Template

Here's a conceptual template for training a TGN model on temporal graph data:

In [12]:
# Practical training loop using Reddit data
import torch
import torch.nn as nn
import numpy as np
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import train_test_split

# Check if we have the reddit_df loaded from the previous cell
try:
    # Only execute if reddit_df exists and is loaded with data
    if 'reddit_df' in locals() and len(reddit_df) > 0:
        print(f"Preparing Reddit data for TGN training - {len(reddit_df)} edges total")
        
        # Convert subreddit names to node indices using the node_mapping we created
        src_nodes = [node_mapping[sr] for sr in reddit_df['SOURCE_SUBREDDIT'].values]
        dst_nodes = [node_mapping[sr] for sr in reddit_df['TARGET_SUBREDDIT'].values]
        
        # Use timestamps we already converted
        timestamps = reddit_df['TIMESTAMP_SECONDS'].values
        
        # For edge features, extract numerical values from PROPERTIES column
        # The PROPERTIES column contains comma-separated values - let's extract the first 10 for simplicity
        def extract_features(prop_str, feature_count=10):
            values = prop_str.split(',')[:feature_count]
            return [float(v) if v.replace('.', '', 1).isdigit() else 0.0 for v in values]
        
        # Extract a fixed number of features (first 10 values) from the PROPERTIES column
        edge_features_list = []
        for prop in reddit_df['PROPERTIES'].values:
            try:
                edge_features_list.append(extract_features(prop))
            except:
                # If there's any issue, use zeros
                edge_features_list.append([0.0] * 10)
        
        # Convert to numpy array for easier processing
        edge_features_array = np.array(edge_features_list)
        
        # Create a list of edge tuples (src, dst, time, features)
        edges_list = list(zip(src_nodes, dst_nodes, timestamps, edge_features_array))
        
        # For a small test, take only the first 10,000 edges
        sample_size = min(10000, len(edges_list))
        edges_list = edges_list[:sample_size]
        
        # Sort by timestamp (crucial for TGN)
        edges_list.sort(key=lambda x: x[2])
        
        # Split into train/val/test (70%/15%/15%)
        train_edges, temp_edges = train_test_split(edges_list, test_size=0.3, shuffle=False)
        val_edges, test_edges = train_test_split(temp_edges, test_size=0.5, shuffle=False)
        
        print(f"Dataset split: {len(train_edges)} train, {len(val_edges)} validation, {len(test_edges)} test edges")
        
        # Model parameters
        num_nodes = len(node_mapping)  # Total number of nodes (unique subreddits)
        edge_features_dim = edge_features_array.shape[1]  # Number of edge features (10)
        memory_dim = 32  # Memory dimension (smaller for this test)
        time_dim = 8  # Time encoding dimension
        embedding_dim = 32  # Final embedding dimension
        
        print(f"Model parameters: {num_nodes} nodes, {edge_features_dim} edge features")
        
        # Helper function for negative sampling
        def get_negative_samples(src_nodes, dst_nodes, num_samples):
            neg_dst = []
            for i in range(num_samples):
                # Sample a random node that is not the source or the correct destination
                while True:
                    neg = np.random.randint(0, num_nodes)
                    # Ensure we're not choosing the correct destination
                    if neg != dst_nodes[i % len(dst_nodes)]:
                        break
                neg_dst.append(neg)
            return neg_dst
        
        # Link prediction scoring function
        def compute_link_scores(emb, src_idx, dst_idx):
            # Get embeddings for source and destination nodes
            src_emb = emb[src_idx]
            dst_emb = emb[dst_idx]
            # Compute dot product similarity
            return (src_emb * dst_emb).sum(dim=1)
        
        # Create the model
        try:
            print("Creating TGN model...")
            model = TGNModel(
                num_nodes=num_nodes,
                raw_msg_dim=edge_features_dim, 
                memory_dim=memory_dim,
                time_dim=time_dim,
                embedding_dim=embedding_dim
            )
            
            # Optimizer and loss function
            optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
            loss_fn = nn.BCEWithLogitsLoss()
            
            # Training hyperparameters
            num_epochs = 2  # Small number for testing
            batch_size = 100  # Small batch size
            
            # Training loop
            print("\nStarting training...")
            for epoch in range(num_epochs):
                model.reset_memory()  # Reset memory at the start of each epoch
                total_loss = 0
                num_batches = 0
                
                # Process edges in temporal batches
                for batch_start in range(0, len(train_edges), batch_size):
                    # Get the current batch of edges
                    batch_edges = train_edges[batch_start:batch_start+batch_size]
                    batch_size_actual = len(batch_edges)
                    
                    # Extract data from batch
                    sources = [e[0] for e in batch_edges]
                    destinations = [e[1] for e in batch_edges]
                    timestamps = [e[2] for e in batch_edges]
                    edge_features = [e[3] for e in batch_edges]
                    
                    # Prepare negative samples for link prediction
                    neg_destinations = get_negative_samples(sources, destinations, batch_size_actual)
                    
                    # Convert to tensors
                    src_tensor = torch.tensor(sources)
                    dst_tensor = torch.tensor(destinations)
                    neg_dst_tensor = torch.tensor(neg_destinations)
                    t_tensor = torch.tensor(timestamps)
                    ef_tensor = torch.tensor(edge_features, dtype=torch.float)
                    
                    # Get all nodes involved in this batch
                    all_nodes = torch.cat([src_tensor, dst_tensor, neg_dst_tensor]).unique()
                    
                    # Forward pass
                    optimizer.zero_grad()
                    
                    # Update memory with the batch events
                    model.memory.update_state(src_tensor, dst_tensor, t_tensor, ef_tensor)
                    
                    # Get embeddings for nodes
                    node_embeddings, _ = model(all_nodes, t_tensor, ef_tensor, src_tensor, dst_tensor)
                    
                    # Map global node IDs to positions in the embedding matrix
                    node_id_to_emb_idx = {node_id.item(): i for i, node_id in enumerate(all_nodes)}
                    src_emb_idx = torch.tensor([node_id_to_emb_idx[src.item()] for src in src_tensor])
                    dst_emb_idx = torch.tensor([node_id_to_emb_idx[dst.item()] for dst in dst_tensor])
                    neg_dst_emb_idx = torch.tensor([node_id_to_emb_idx[neg_dst.item()] for neg_dst in neg_dst_tensor])
                    
                    # Compute positive and negative scores
                    pos_scores = compute_link_scores(node_embeddings, src_emb_idx, dst_emb_idx)
                    neg_scores = compute_link_scores(node_embeddings, src_emb_idx, neg_dst_emb_idx)
                    
                    # Compute loss
                    pos_label = torch.ones_like(pos_scores)
                    neg_label = torch.zeros_like(neg_scores)
                    pred_scores = torch.cat([pos_scores, neg_scores])
                    true_labels = torch.cat([pos_label, neg_label])
                    
                    loss = loss_fn(pred_scores, true_labels)
                    total_loss += loss.item()
                    num_batches += 1
                    
                    # Backward pass and optimization
                    loss.backward()
                    optimizer.step()
                    
                    # Detach memory after each batch to prevent memory leaks
                    model.detach_memory()
                
                avg_loss = total_loss / num_batches
                print(f'Epoch {epoch+1}/{num_epochs}, Avg Loss: {avg_loss:.4f}')
                
                # Quick validation on a small sample
                if (epoch + 1) % 1 == 0:  # Check every epoch
                    model.reset_memory()
                    
                    # Use only a portion of validation set for quick evaluation
                    val_sample = val_edges[:min(500, len(val_edges))]
                    
                    # Process validation edges
                    val_sources = [e[0] for e in val_sample]
                    val_destinations = [e[1] for e in val_sample]
                    val_timestamps = [e[2] for e in val_sample]
                    val_edge_features = [e[3] for e in val_sample]
                    
                    # Generate negative samples
                    val_neg_destinations = get_negative_samples(val_sources, val_destinations, len(val_sample))
                    
                    # Convert to tensors
                    val_src_tensor = torch.tensor(val_sources)
                    val_dst_tensor = torch.tensor(val_destinations)
                    val_neg_dst_tensor = torch.tensor(val_neg_destinations)
                    val_t_tensor = torch.tensor(val_timestamps)
                    val_ef_tensor = torch.tensor(val_edge_features, dtype=torch.float)
                    
                    # Update memory with validation events
                    with torch.no_grad():
                        model.memory.update_state(val_src_tensor, val_dst_tensor, val_t_tensor, val_ef_tensor)
                        
                        # Get nodes involved in validation
                        val_nodes = torch.cat([val_src_tensor, val_dst_tensor, val_neg_dst_tensor]).unique()
                        
                        # Get embeddings
                        val_embeddings, _ = model(val_nodes, val_t_tensor, val_ef_tensor, val_src_tensor, val_dst_tensor)
                        
                        # Map nodes to embedding indices
                        val_node_id_to_emb_idx = {node_id.item(): i for i, node_id in enumerate(val_nodes)}
                        val_src_emb_idx = torch.tensor([val_node_id_to_emb_idx[src.item()] for src in val_src_tensor])
                        val_dst_emb_idx = torch.tensor([val_node_id_to_emb_idx[dst.item()] for dst in val_dst_tensor])
                        val_neg_dst_emb_idx = torch.tensor([val_node_id_to_emb_idx[neg_dst.item()] for neg_dst in val_neg_dst_tensor])
                        
                        # Compute scores
                        val_pos_scores = compute_link_scores(val_embeddings, val_src_emb_idx, val_dst_emb_idx)
                        val_neg_scores = compute_link_scores(val_embeddings, val_src_emb_idx, val_neg_dst_emb_idx)
                        
                        # Compute AUC score
                        scores = torch.cat([val_pos_scores, val_neg_scores]).cpu().numpy()
                        labels = np.concatenate([np.ones(len(val_pos_scores)), np.zeros(len(val_neg_scores))])
                        val_auc = roc_auc_score(labels, scores)
                        
                        print(f'Validation AUC: {val_auc:.4f}')
            
            print("\nTraining complete!")
            
        except Exception as e:
            print(f"Error during model creation or training: {e}")
            import traceback
            traceback.print_exc()
    else:
        print("Reddit dataset not loaded. Please run the data loading cell first.")
        
except Exception as e:
    print(f"Error preparing data: {e}")
    import traceback
    traceback.print_exc()

Preparing Reddit data for TGN training - 571927 edges total
Dataset split: 7000 train, 1500 validation, 1500 test edges
Model parameters: 54075 nodes, 10 edge features
Creating TGN model...

Starting training...
Error during model creation or training: Sequential.forward() takes 2 positional arguments but 5 were given
Dataset split: 7000 train, 1500 validation, 1500 test edges
Model parameters: 54075 nodes, 10 edge features
Creating TGN model...

Starting training...
Error during model creation or training: Sequential.forward() takes 2 positional arguments but 5 were given


  ef_tensor = torch.tensor(edge_features, dtype=torch.float)
Traceback (most recent call last):
  File "/tmp/ipykernel_158975/395830762.py", line 140, in <module>
    model.memory.update_state(src_tensor, dst_tensor, t_tensor, ef_tensor)
  File "/home/strix/miniforge3/envs/commdec/lib/python3.12/site-packages/torch_geometric/nn/models/tgn.py", line 110, in update_state
    self._update_memory(n_id)
  File "/home/strix/miniforge3/envs/commdec/lib/python3.12/site-packages/torch_geometric/nn/models/tgn.py", line 126, in _update_memory
    memory, last_update = self._get_updated_memory(n_id)
                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/strix/miniforge3/envs/commdec/lib/python3.12/site-packages/torch_geometric/nn/models/tgn.py", line 134, in _get_updated_memory
    msg_s, t_s, src_s, dst_s = self._compute_msg(n_id, self.msg_s_store,
                               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/strix/miniforge3/envs/commdec/lib/python3.12/si

## 6. Conclusion

This notebook has introduced the Temporal Graph Network (TGN) model for learning on dynamic graphs. Key points:

1. TGN uses a memory module to maintain an up-to-date representation of each node as the graph evolves.
2. The memory is updated through a message passing mechanism that incorporates temporal information.
3. The implementation requires careful handling of temporal data, ensuring events are processed in time order.
4. The model can be applied to various tasks such as link prediction and node classification on temporal graphs.

To fully leverage TGN for your research on "Implementing Decay-Based Temporal Attention for Dynamic Network Adaptation," consider:

- Adapting the message function to incorporate your decay-based attention mechanism
- Experimenting with different memory update rules
- Comparing performance with other temporal graph models like TGAT

For a complete implementation, refer to the [original TGN repository](https://github.com/twitter-research/tgn) and adapt it to your specific research requirements.

In [13]:
# Evaluate the model on the test set

try:
    # Check if we have the model and test data available
    if 'model' in locals() and 'test_edges' in locals():
        print("Evaluating model on test set...")
        
        # Reset model memory before evaluation
        model.reset_memory()
        
        # Use a subset of test data for evaluation
        test_sample = test_edges[:min(1000, len(test_edges))]
        
        # Process test edges
        test_sources = [e[0] for e in test_sample]
        test_destinations = [e[1] for e in test_sample]
        test_timestamps = [e[2] for e in test_sample]
        test_edge_features = [e[3] for e in test_sample]
        
        # Generate negative samples
        test_neg_destinations = get_negative_samples(test_sources, test_destinations, len(test_sample))
        
        # Convert to tensors
        test_src_tensor = torch.tensor(test_sources)
        test_dst_tensor = torch.tensor(test_destinations)
        test_neg_dst_tensor = torch.tensor(test_neg_destinations)
        test_t_tensor = torch.tensor(test_timestamps)
        test_ef_tensor = torch.tensor(test_edge_features, dtype=torch.float)
        
        # Evaluation
        with torch.no_grad():
            # Update memory with test events
            model.memory.update_state(test_src_tensor, test_dst_tensor, test_t_tensor, test_ef_tensor)
            
            # Get nodes involved in test set
            test_nodes = torch.cat([test_src_tensor, test_dst_tensor, test_neg_dst_tensor]).unique()
            
            # Get embeddings
            test_embeddings, _ = model(test_nodes, test_t_tensor, test_ef_tensor, test_src_tensor, test_dst_tensor)
            
            # Map nodes to embedding indices
            test_node_id_to_emb_idx = {node_id.item(): i for i, node_id in enumerate(test_nodes)}
            test_src_emb_idx = torch.tensor([test_node_id_to_emb_idx[src.item()] for src in test_src_tensor])
            test_dst_emb_idx = torch.tensor([test_node_id_to_emb_idx[dst.item()] for dst in test_dst_tensor])
            test_neg_dst_emb_idx = torch.tensor([test_node_id_to_emb_idx[neg_dst.item()] for neg_dst in test_neg_dst_tensor])
            
            # Compute scores
            test_pos_scores = compute_link_scores(test_embeddings, test_src_emb_idx, test_dst_emb_idx)
            test_neg_scores = compute_link_scores(test_embeddings, test_src_emb_idx, test_neg_dst_emb_idx)
            
            # Compute AUC score
            test_scores = torch.cat([test_pos_scores, test_neg_scores]).cpu().numpy()
            test_labels = np.concatenate([np.ones(len(test_pos_scores)), np.zeros(len(test_neg_scores))])
            test_auc = roc_auc_score(test_labels, test_scores)
            
            print(f'Test AUC: {test_auc:.4f}\n')
            
            # Calculate additional metrics
            test_pred_labels = (test_scores > 0.5).astype(int)  # Threshold at 0.5 for binary prediction
            from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
            
            accuracy = accuracy_score(test_labels, test_pred_labels)
            precision = precision_score(test_labels, test_pred_labels)
            recall = recall_score(test_labels, test_pred_labels)
            f1 = f1_score(test_labels, test_pred_labels)
            
            print(f"Additional metrics on test set:")
            print(f"Accuracy: {accuracy:.4f}")
            print(f"Precision: {precision:.4f}")
            print(f"Recall: {recall:.4f}")
            print(f"F1 Score: {f1:.4f}")
            
            # Save the results to a variable for later reference
            test_results = {
                'auc': test_auc,
                'accuracy': accuracy,
                'precision': precision,
                'recall': recall,
                'f1': f1
            }
            
    else:
        print("Model or test data not available. Please run the training cell first.")
        
except Exception as e:
    print(f"Error during evaluation: {e}")
    import traceback
    traceback.print_exc()

Evaluating model on test set...
Error during evaluation: Sequential.forward() takes 2 positional arguments but 5 were given


Traceback (most recent call last):
  File "/tmp/ipykernel_158975/2002860955.py", line 33, in <module>
    model.memory.update_state(test_src_tensor, test_dst_tensor, test_t_tensor, test_ef_tensor)
  File "/home/strix/miniforge3/envs/commdec/lib/python3.12/site-packages/torch_geometric/nn/models/tgn.py", line 110, in update_state
    self._update_memory(n_id)
  File "/home/strix/miniforge3/envs/commdec/lib/python3.12/site-packages/torch_geometric/nn/models/tgn.py", line 126, in _update_memory
    memory, last_update = self._get_updated_memory(n_id)
                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/strix/miniforge3/envs/commdec/lib/python3.12/site-packages/torch_geometric/nn/models/tgn.py", line 134, in _get_updated_memory
    msg_s, t_s, src_s, dst_s = self._compute_msg(n_id, self.msg_s_store,
                               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/strix/miniforge3/envs/commdec/lib/python3.12/site-packages/torch_geometric/nn/models/tgn

In [14]:
# Visualize training results - run this after training
import matplotlib.pyplot as plt
import networkx as nx
from sklearn.manifold import TSNE

try:
    # Check if we have the model and data available
    if 'model' in locals() and 'reddit_df' in locals() and 'node_mapping' in locals():
        print("Creating visualizations of the learned embeddings...")
        
        # 1. Get embeddings for a subset of nodes
        subset_size = min(1000, len(node_mapping))  # Limit nodes for visualization
        node_ids = list(range(subset_size))  # Take the first subset_size nodes
        
        # Reset model memory before generating final embeddings
        model.reset_memory()
        
        # Generate some sample edges for memory updates (necessary for TGN)
        sample_edges = test_edges[:min(1000, len(test_edges))]
        src_nodes = [e[0] for e in sample_edges]
        dst_nodes = [e[1] for e in sample_edges]
        timestamps = [e[2] for e in sample_edges]
        edge_features = [e[3] for e in sample_edges]
        
        # Convert to tensors
        src_tensor = torch.tensor(src_nodes)
        dst_tensor = torch.tensor(dst_nodes)
        t_tensor = torch.tensor(timestamps)
        ef_tensor = torch.tensor(edge_features, dtype=torch.float)
        
        # Update memory with these edges
        with torch.no_grad():
            model.memory.update_state(src_tensor, dst_tensor, t_tensor, ef_tensor)
            
            # Get embeddings for visualization
            node_ids_tensor = torch.tensor(node_ids)
            node_embeddings, _ = model(node_ids_tensor, t_tensor[:1], ef_tensor[:1], src_tensor[:1], dst_tensor[:1])
            
            # Convert to numpy for visualization
            embeddings_np = node_embeddings.detach().cpu().numpy()
            
            # Map node IDs back to subreddit names for visualization
            inv_node_mapping = {v: k for k, v in node_mapping.items()}
            node_labels = [inv_node_mapping.get(nid, f"Node {nid}") for nid in node_ids]
            
            # 2. Dimensionality reduction with t-SNE for visualization
            print("Performing t-SNE to visualize embeddings in 2D...")
            tsne = TSNE(n_components=2, random_state=42, perplexity=min(30, len(embeddings_np)-1))
            embeddings_2d = tsne.fit_transform(embeddings_np)
            
            # 3. Plot the embeddings
            plt.figure(figsize=(12, 10))
            plt.scatter(embeddings_2d[:, 0], embeddings_2d[:, 1], alpha=0.6)
            
            # Add labels for some of the nodes (limit to avoid clutter)
            num_labels = min(20, len(node_labels))  # Show only 20 labels
            for i in range(num_labels):
                plt.annotate(node_labels[i], 
                             (embeddings_2d[i, 0], embeddings_2d[i, 1]),
                             fontsize=9)
                
            plt.title("t-SNE Visualization of Subreddit Embeddings")
            plt.xlabel("t-SNE Dimension 1")
            plt.ylabel("t-SNE Dimension 2")
            plt.show()
            
            # 4. Create a small graph visualization of top subreddits
            print("\nCreating graph visualization of top subreddits and their connections...")
            G = nx.DiGraph()
            
            # Get top 20 most connected subreddits
            subreddit_counts = reddit_df['SOURCE_SUBREDDIT'].value_counts()
            top_subreddits = subreddit_counts.head(20).index.tolist()
            
            # Add nodes
            for sr in top_subreddits:
                G.add_node(sr)
            
            # Add edges between top subreddits
            for _, row in reddit_df.iterrows():
                if row['SOURCE_SUBREDDIT'] in top_subreddits and row['TARGET_SUBREDDIT'] in top_subreddits:
                    G.add_edge(row['SOURCE_SUBREDDIT'], row['TARGET_SUBREDDIT'], 
                               weight=float(row['LINK_SENTIMENT']))  # Use sentiment as edge weight
            
            # Draw the graph
            plt.figure(figsize=(12, 10))
            pos = nx.spring_layout(G, seed=42)  # Position nodes using force-directed layout
            edge_weights = [G[u][v]['weight']*2 for u, v in G.edges()]  # Adjust edge width by weight
            
            # Draw the graph with node size based on degree
            node_size = [G.degree(node) * 50 for node in G.nodes()]
            nx.draw_networkx_nodes(G, pos, node_size=node_size, node_color='skyblue', alpha=0.8)
            nx.draw_networkx_edges(G, pos, width=edge_weights, alpha=0.5, edge_color='gray')
            nx.draw_networkx_labels(G, pos, font_size=10)
            
            plt.title("Graph of Top 20 Subreddits and Their Connections")
            plt.axis('off')  # Turn off axis
            plt.show()
    else:
        print("Model or data not available. Please run the training cell first.")
        
except Exception as e:
    print(f"Error in visualization: {e}")
    import traceback
    traceback.print_exc()

Creating visualizations of the learned embeddings...
Error in visualization: Sequential.forward() takes 2 positional arguments but 5 were given
Error in visualization: Sequential.forward() takes 2 positional arguments but 5 were given


Traceback (most recent call last):
  File "/tmp/ipykernel_158975/1249954270.py", line 33, in <module>
    model.memory.update_state(src_tensor, dst_tensor, t_tensor, ef_tensor)
  File "/home/strix/miniforge3/envs/commdec/lib/python3.12/site-packages/torch_geometric/nn/models/tgn.py", line 110, in update_state
    self._update_memory(n_id)
  File "/home/strix/miniforge3/envs/commdec/lib/python3.12/site-packages/torch_geometric/nn/models/tgn.py", line 126, in _update_memory
    memory, last_update = self._get_updated_memory(n_id)
                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/strix/miniforge3/envs/commdec/lib/python3.12/site-packages/torch_geometric/nn/models/tgn.py", line 134, in _get_updated_memory
    msg_s, t_s, src_s, dst_s = self._compute_msg(n_id, self.msg_s_store,
                               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/strix/miniforge3/envs/commdec/lib/python3.12/site-packages/torch_geometric/nn/models/tgn.py", line 177, in _