## 4. Training and Evaluation

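Train the model with early stopping on validation AUC, then evaluate link prediction on the test set (AUC, AP, accuracy). The evaluation also returns the node embeddings and community-head outputs alongside the metrics.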

In [None]:
def train_model(model, train_loader, val_loader, num_epochs=50, lr=0.001, patience=10):
    """Train the TGN model with early stopping on validation AUC.

    Each epoch runs one pass over ``train_loader`` (BCE loss on
    ``outputs['link_prob']``, gradient-norm clipping at 1.0, AdamW step) and
    one no-grad pass over ``val_loader``.  The validation ROC-AUC drives both
    the ``ReduceLROnPlateau`` scheduler and early stopping.  Whenever the
    validation AUC improves, the weights are written to ``best_model.pth``;
    after training, the checkpoint written during *this* call (if any) is
    loaded back into ``model``.

    Args:
        model: Module called as ``model(src_idx, dst_idx, timestamp,
            src_features, dst_features, edge_features)`` and returning a dict
            with a ``'link_prob'`` entry.
        train_loader: Iterable of batches; each batch is a dict of tensors
            with the keys used above plus ``'label'``.
        val_loader: Same batch format, used for validation.
        num_epochs: Maximum number of epochs.
        lr: Initial learning rate for AdamW.
        patience: Number of consecutive epochs without a validation-AUC
            improvement after which training stops.

    Returns:
        Dict with ``'train_losses'`` (mean batch loss per epoch),
        ``'val_aucs'`` (validation AUC per epoch) and ``'best_val_auc'``.

    Note:
        Batches are moved to the notebook-level ``device``.
    """
    checkpoint_path = 'best_model.pth'

    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-5)
    # mode='max' because the scheduler is stepped with the validation AUC.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=5)
    criterion = nn.BCELoss()

    best_val_auc = 0.0
    patience_counter = 0
    # Tracks whether this call wrote the checkpoint, so that a stale file left
    # by an earlier run is never loaded and a missing file never crashes.
    checkpoint_saved = False
    train_losses = []
    val_aucs = []

    print(f"Starting training for {num_epochs} epochs...")
    print(f"Device: {device}")

    for epoch in range(num_epochs):
        # Training phase
        model.train()
        epoch_loss = 0.0
        all_preds = []
        all_labels = []

        start_time = time.time()

        for batch_idx, batch in enumerate(train_loader):
            # Move batch to device
            batch = {k: v.to(device) for k, v in batch.items()}

            optimizer.zero_grad()

            # Forward pass
            outputs = model(
                batch['src_idx'], batch['dst_idx'], batch['timestamp'],
                batch['src_features'], batch['dst_features'], batch['edge_features']
            )

            # Compute loss
            loss = criterion(outputs['link_prob'], batch['label'])

            # Backward pass
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

            epoch_loss += loss.item()
            all_preds.extend(outputs['link_prob'].cpu().detach().numpy())
            all_labels.extend(batch['label'].cpu().numpy())

            if batch_idx % 100 == 0:
                print(f"Epoch {epoch+1}/{num_epochs}, Batch {batch_idx}/{len(train_loader)}, Loss: {loss.item():.4f}")

        # Calculate training metrics
        train_auc = roc_auc_score(all_labels, all_preds)
        avg_loss = epoch_loss / len(train_loader)
        train_losses.append(avg_loss)

        # Validation phase
        model.eval()
        val_preds = []
        val_labels = []

        with torch.no_grad():
            for batch in val_loader:
                batch = {k: v.to(device) for k, v in batch.items()}

                outputs = model(
                    batch['src_idx'], batch['dst_idx'], batch['timestamp'],
                    batch['src_features'], batch['dst_features'], batch['edge_features']
                )

                val_preds.extend(outputs['link_prob'].cpu().numpy())
                val_labels.extend(batch['label'].cpu().numpy())

        val_auc = roc_auc_score(val_labels, val_preds)
        val_ap = average_precision_score(val_labels, val_preds)
        val_aucs.append(val_auc)

        # Learning rate scheduling
        scheduler.step(val_auc)

        # Covers both the training and the validation pass of this epoch.
        epoch_time = time.time() - start_time

        print(f"Epoch {epoch+1}/{num_epochs} ({epoch_time:.1f}s)")
        print(f"  Train Loss: {avg_loss:.4f}, Train AUC: {train_auc:.4f}")
        print(f"  Val AUC: {val_auc:.4f}, Val AP: {val_ap:.4f}")
        print(f"  Learning Rate: {optimizer.param_groups[0]['lr']:.6f}")

        # Early stopping
        if val_auc > best_val_auc:
            best_val_auc = val_auc
            patience_counter = 0
            # Save best model
            torch.save(model.state_dict(), checkpoint_path)
            checkpoint_saved = True
            print(f"  New best model saved (AUC: {best_val_auc:.4f})")
        else:
            patience_counter += 1

        if patience_counter >= patience:
            print(f"Early stopping triggered after {epoch+1} epochs")
            break

        print("-" * 60)

    # Restore the best weights, but only if this run actually produced them.
    if checkpoint_saved:
        model.load_state_dict(torch.load(checkpoint_path, map_location=device))

    return {
        'train_losses': train_losses,
        'val_aucs': val_aucs,
        'best_val_auc': best_val_auc
    }

def evaluate_model(model, test_loader, verbose=True):
    """Score the model on ``test_loader`` and gather its per-edge outputs.

    The model is switched to eval mode and run without gradients.  Link
    probabilities are scored with ROC-AUC and average precision, and with
    accuracy after thresholding at 0.5.

    Args:
        model: Module returning a dict with ``'link_prob'``,
            ``'src_embeddings'``, ``'dst_embeddings'``, ``'src_community'``
            and ``'dst_community'``.
        test_loader: Iterable of dict batches (same keys as in training).
        verbose: When true, print the three metrics.

    Returns:
        Dict with the metrics (``'auc'``, ``'ap'``, ``'accuracy'``), the raw
        ``'predictions'`` and ``'labels'`` lists, and the row-stacked
        embedding / community arrays.
    """
    model.eval()

    scores = []
    targets = []
    # Per-batch arrays for every tensor output that is stacked at the end.
    stacked_keys = ('src_embeddings', 'dst_embeddings', 'src_community', 'dst_community')
    chunks = {key: [] for key in stacked_keys}

    with torch.no_grad():
        for raw_batch in test_loader:
            batch = {name: tensor.to(device) for name, tensor in raw_batch.items()}

            outputs = model(
                batch['src_idx'],
                batch['dst_idx'],
                batch['timestamp'],
                batch['src_features'],
                batch['dst_features'],
                batch['edge_features'],
            )

            scores.extend(outputs['link_prob'].cpu().numpy())
            targets.extend(batch['label'].cpu().numpy())
            for key in stacked_keys:
                chunks[key].append(outputs[key].cpu().numpy())

    # Ranking metrics on the raw probabilities
    test_auc = roc_auc_score(targets, scores)
    test_ap = average_precision_score(targets, scores)

    # Accuracy on hard decisions at the 0.5 cut-off
    hard_decisions = (np.array(scores) > 0.5).astype(int)
    test_acc = accuracy_score(targets, hard_decisions)

    if verbose:
        print("\nTest Results:")
        print(f"AUC: {test_auc:.4f}")
        print(f"AP: {test_ap:.4f}")
        print(f"Accuracy: {test_acc:.4f}")

    results = {
        'auc': test_auc,
        'ap': test_ap,
        'accuracy': test_acc,
        'predictions': scores,
        'labels': targets,
    }
    results['src_embeddings'] = np.vstack(chunks['src_embeddings'])
    results['dst_embeddings'] = np.vstack(chunks['dst_embeddings'])
    results['src_communities'] = np.vstack(chunks['src_community'])
    results['dst_communities'] = np.vstack(chunks['dst_community'])
    return results

# Train the model if everything is set up.
# Runs only when the dataset loaded and a `model` name has been defined;
# `train_loader`, `val_loader` and `test_loader` are expected to exist already.
if dataset is not None and 'model' in locals():
    print("Starting model training...")
    
    # train_model returns the per-epoch history and the best validation AUC.
    training_results = train_model(
        model, train_loader, val_loader, 
        num_epochs=30, lr=0.001, patience=8
    )
    
    print("\nTraining completed!")
    print(f"Best validation AUC: {training_results['best_val_auc']:.4f}")
    
    # Evaluate on test set (evaluate_model prints the metrics itself,
    # since verbose defaults to True)
    print("\nEvaluating on test set...")
    test_results = evaluate_model(model, test_loader)
    
    print("\nTraining and evaluation completed successfully!")

## 3. Enhanced Decay Factor TGN Model

Implement a Temporal Graph Network whose per-node memory is exponentially decayed by a configurable factor before self-attention and projection produce the node embeddings.

In [None]:
class EnhancedDecayTGN(nn.Module):
    """
    Temporal Graph Network with exponentially decayed node memory.

    Every node owns a memory vector and a last-update time.  For each batch
    of interactions the memories of both endpoints are decayed by
    ``exp(-decay_factor * dt)``, updated through a shared GRU cell, written
    back, and then turned into embeddings that feed a link predictor and a
    community head.

    Args:
        num_nodes: Number of rows in the memory table.
        node_feat_dim: Size of the raw node feature vectors.
        edge_feat_dim: Size of the raw edge feature vectors.
        memory_dim: Size of each node's memory vector.
        time_dim: Size of the encoded timestamp.
        embedding_dim: Size of the output node embeddings.
        decay_factor: Rate used in the exponential memory decay.
        n_heads: Number of attention heads (must divide ``memory_dim``).
        dropout: Dropout probability used in the MLPs and the attention.
        device: Device the module is moved to at the end of ``__init__``.
    """

    def __init__(self, num_nodes, node_feat_dim, edge_feat_dim, memory_dim=128,
                 time_dim=32, embedding_dim=128, decay_factor=0.1,
                 n_heads=4, dropout=0.1, device='cuda'):
        super(EnhancedDecayTGN, self).__init__()

        self.num_nodes = num_nodes
        self.node_feat_dim = node_feat_dim
        self.edge_feat_dim = edge_feat_dim
        self.memory_dim = memory_dim
        self.time_dim = time_dim
        self.embedding_dim = embedding_dim
        self.decay_factor = decay_factor
        self.n_heads = n_heads
        self.device = device

        # Time encoder: scalar timestamp -> time_dim vector
        self.time_encoder = nn.Sequential(
            nn.Linear(1, time_dim),
            nn.ReLU(),
            nn.Linear(time_dim, time_dim)
        )

        # Node and edge embedding layers.  node_embedding projects raw node
        # features to memory_dim so they match the message function's input.
        # edge_embedding is kept so the parameter set is unchanged, but no
        # method in this class calls it.
        self.node_embedding = nn.Linear(node_feat_dim, memory_dim)
        self.edge_embedding = nn.Linear(edge_feat_dim, memory_dim)

        # Memory module.  The state tensors are non-trainable Parameters, so
        # they follow .to(device), appear in state_dict and are counted by
        # model.parameters(), but never receive gradients.
        self.memory = nn.Parameter(torch.zeros(num_nodes, memory_dim), requires_grad=False)
        self.last_update = nn.Parameter(torch.zeros(num_nodes), requires_grad=False)
        self.memory_updater = nn.GRUCell(memory_dim, memory_dim)

        # Attention applied to the decayed memory
        self.temporal_attention = nn.MultiheadAttention(
            embed_dim=memory_dim,
            num_heads=n_heads,
            dropout=dropout,
            batch_first=True
        )

        # Message function: [own embedding, other embedding, time encoding]
        self.message_function = nn.Sequential(
            nn.Linear(memory_dim * 2 + time_dim, memory_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(memory_dim, memory_dim)
        )

        # Embedding projection
        self.embedding_projection = nn.Sequential(
            nn.Linear(memory_dim, embedding_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(embedding_dim, embedding_dim)
        )

        # Link predictor (ends in a sigmoid, so it outputs probabilities)
        self.link_predictor = nn.Sequential(
            nn.Linear(embedding_dim * 2, embedding_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(embedding_dim, embedding_dim // 2),
            nn.ReLU(),
            nn.Linear(embedding_dim // 2, 1),
            nn.Sigmoid()
        )

        # Community detection head
        self.community_head = nn.Sequential(
            nn.Linear(embedding_dim, embedding_dim // 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(embedding_dim // 2, 32)  # 32 potential communities
        )

        self.to(device)

    def apply_decay(self, memory, current_time, last_update_time):
        """Scale ``memory`` by ``exp(-decay_factor * dt)``.

        ``dt = current_time - last_update_time`` is clamped at zero, so a
        timestamp older than the last update leaves the memory unscaled.
        """
        time_diff = current_time - last_update_time
        time_diff = torch.clamp(time_diff, min=0.0)
        decay_weight = torch.exp(-self.decay_factor * time_diff).unsqueeze(-1)
        return memory * decay_weight

    def update_memory(self, node_ids, timestamps, src_features, dst_features, edge_features):
        """Update the memory of both endpoints of every interaction.

        Args:
            node_ids: Long tensor with source ids in column 0 and destination
                ids in column 1.
            timestamps: One timestamp per interaction.
            src_features: Raw source-node features (last dim ``node_feat_dim``).
            dst_features: Raw destination-node features (same layout).
            edge_features: Accepted for interface compatibility; not used.

        Returns:
            Tuple ``(new_src_memory, new_dst_memory)`` as produced by the GRU
            (still attached to the autograd graph).  The stored memory table
            receives detached copies.
        """
        src_ids = node_ids[:, 0]
        dst_ids = node_ids[:, 1]

        # Current memory, decayed up to the interaction time
        src_memory = self.apply_decay(self.memory[src_ids], timestamps, self.last_update[src_ids])
        dst_memory = self.apply_decay(self.memory[dst_ids], timestamps, self.last_update[dst_ids])

        # Encode time
        time_features = self.time_encoder(timestamps.unsqueeze(-1))

        # Project raw node features to memory_dim.  Feeding the raw features
        # straight into message_function only has the right width when
        # node_feat_dim happens to equal memory_dim.
        src_hidden = self.node_embedding(src_features)
        dst_hidden = self.node_embedding(dst_features)

        # Create messages (each endpoint sees itself first, then the other)
        src_messages = self.message_function(
            torch.cat([src_hidden, dst_hidden, time_features], dim=-1)
        )
        dst_messages = self.message_function(
            torch.cat([dst_hidden, src_hidden, time_features], dim=-1)
        )

        # Update memory using GRU
        new_src_memory = self.memory_updater(src_messages, src_memory)
        new_dst_memory = self.memory_updater(dst_messages, dst_memory)

        # Write back detached values so the stored state carries no graph.
        # Destination writes come second and therefore win when a node is
        # both a source and a destination in the same batch.
        self.memory[src_ids] = new_src_memory.detach()
        self.memory[dst_ids] = new_dst_memory.detach()
        self.last_update[src_ids] = timestamps.detach()
        self.last_update[dst_ids] = timestamps.detach()

        return new_src_memory, new_dst_memory

    def get_node_embeddings(self, node_ids, timestamps, node_features):
        """Return embeddings computed from the decayed memory of ``node_ids``.

        ``node_features`` is accepted for interface compatibility but is not
        read; the embedding depends only on the stored memory and timestamps.
        """
        # Get decayed memory
        memory = self.memory[node_ids]
        memory = self.apply_decay(memory, timestamps, self.last_update[node_ids])

        # Self-attention over a length-1 sequence per node (simplification)
        tokens = memory.unsqueeze(1)
        memory_attended, _ = self.temporal_attention(tokens, tokens, tokens)
        memory_attended = memory_attended.squeeze(1)

        # Project to final embedding
        return self.embedding_projection(memory_attended)

    def forward(self, src_idx, dst_idx, timestamps, src_features, dst_features, edge_features):
        """Forward pass for link prediction.

        The memory is updated with the batch first, and the embeddings are
        then read from the updated memory.

        Returns:
            Dict with ``'link_prob'`` (one probability per interaction),
            ``'src_embeddings'``, ``'dst_embeddings'``, ``'src_community'``
            and ``'dst_community'`` (32 community logits per node).
        """
        node_ids = torch.stack([src_idx, dst_idx], dim=1)

        # Update memory (the returned GRU states are not needed here)
        self.update_memory(
            node_ids, timestamps, src_features, dst_features, edge_features
        )

        # Get node embeddings
        src_embeddings = self.get_node_embeddings(src_idx, timestamps, src_features)
        dst_embeddings = self.get_node_embeddings(dst_idx, timestamps, dst_features)

        # Link prediction
        link_features = torch.cat([src_embeddings, dst_embeddings], dim=-1)
        link_prob = self.link_predictor(link_features)

        # Community detection (optional)
        src_community = self.community_head(src_embeddings)
        dst_community = self.community_head(dst_embeddings)

        return {
            'link_prob': link_prob.squeeze(-1),
            'src_embeddings': src_embeddings,
            'dst_embeddings': dst_embeddings,
            'src_community': src_community,
            'dst_community': dst_community
        }

    def reset_memory(self):
        """Zero the memory table and the last-update times."""
        self.memory.data.zero_()
        self.last_update.data.zero_()

    def get_all_embeddings(self, timestamps=None):
        """Get embeddings for all nodes.

        With ``timestamps=None`` every node is evaluated at time zero, which
        the decay clamp treats as "no decay" for non-negative update times.
        """
        if timestamps is None:
            timestamps = torch.zeros(self.num_nodes, device=self.device)

        node_ids = torch.arange(self.num_nodes, device=self.device)
        # Placeholder only: get_node_embeddings does not read node features.
        dummy_features = torch.zeros(self.num_nodes, self.node_feat_dim, device=self.device)

        return self.get_node_embeddings(node_ids, timestamps, dummy_features)

# Initialize the model if dataset is available.
# Sizes that depend on the data come from the `dataset` dict; the remaining
# hyperparameters are fixed here.
if dataset is not None:
    model = EnhancedDecayTGN(
        num_nodes=dataset['num_nodes'],
        node_feat_dim=dataset['node_feat_dim'],
        edge_feat_dim=dataset['edge_feat_dim'],
        memory_dim=128,
        time_dim=32,
        embedding_dim=128,
        decay_factor=0.1,  # Configurable decay factor
        n_heads=4,
        dropout=0.1,
        device=device
    )
    
    # This total includes the non-trainable memory / last_update tensors,
    # because the model stores them as Parameters with requires_grad=False.
    print(f"Model initialized with {sum(p.numel() for p in model.parameters()):,} parameters")
    print(f"Model is on device: {next(model.parameters()).device}")
    
    # Count trainable parameters (only those with requires_grad=True)
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Trainable parameters: {trainable_params:,}")

In [None]:
# Comprehensive Link Prediction and Community Detection with Decay Factor TGN

This notebook implements a complete system for link prediction and community detection using Temporal Graph Networks (TGN) with decay-based temporal attention. The implementation includes:

1. **Decay Factor TGN Model**: Enhanced temporal attention with configurable decay factors
2. **Link Prediction**: Predicting future connections in temporal graphs
3. **Community Detection**: Identifying and tracking community evolution over time
4. **Comprehensive Visualizations**: Graph structures, temporal evolution, and community dynamics
5. **CUDA Optimization**: GPU acceleration for all computations

Dataset: Reddit Hyperlinks - representing connections between subreddit communities

In [None]:
# Import all necessary libraries
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import networkx as nx
from sklearn.metrics import roc_auc_score, average_precision_score, accuracy_score
from sklearn.cluster import SpectralClustering, KMeans
from sklearn.manifold import TSNE
import community.community_louvain as community_louvain

import os
import sys
import time
import warnings
from collections import defaultdict, Counter
import pickle
from datetime import datetime

# Add src directory to Python path
sys.path.append('../src')
sys.path.append('src')

# Import custom modules
from decay_tgn import DecayTemporalGraphNetwork, create_decay_tgn_model, analyze_decay_effects
from enhanced_tgn import TemporalGraphNetwork

# Set up device (prioritize CUDA, fall back to CPU when it is unavailable)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
if torch.cuda.is_available():
    # Report name and total memory (in GB) of GPU 0
    print(f"CUDA Device: {torch.cuda.get_device_name(0)}")
    print(f"CUDA Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

# Set random seeds for reproducibility (PyTorch, NumPy and, if present, CUDA)
torch.manual_seed(42)
np.random.seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed(42)

# Configure matplotlib and seaborn for better plots
plt.style.use('seaborn-v0_8')
sns.set_palette("husl")
# NOTE: this silences every warning for the rest of the session.
warnings.filterwarnings('ignore')

print("All libraries imported successfully!")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

In [None]:
## 1. Data Loading and Preprocessing

Load and preprocess the Reddit Hyperlinks dataset for temporal graph analysis.

In [None]:
def load_reddit_dataset(data_path='data/', sample_size=None, verbose=True):
    """
    Load and preprocess the Reddit dataset for TGN training

    The first file found among ``reddit_TGAT.csv``,
    ``soc-redditHyperlinks-title.tsv`` and ``soc-redditHyperlinks-body.tsv``
    is read.  Edges are sorted by time, node names are mapped to integer
    indices, and the edges are split chronologically 70/15/15.  Node features
    and all edge features except an optional sentiment column are random
    (drawn from the global NumPy RNG).

    Args:
        data_path: Path to data directory
        sample_size: If specified, sample this many edges for faster processing
        verbose: Print detailed information

    Returns:
        Dictionary containing processed data splits and metadata, or ``None``
        if anything goes wrong (the error and traceback are printed).
    """
    try:
        # Try multiple file formats
        reddit_files = [
            'reddit_TGAT.csv',
            'soc-redditHyperlinks-title.tsv',
            'soc-redditHyperlinks-body.tsv'
        ]

        df = None
        for filename in reddit_files:
            filepath = os.path.join(data_path, filename)
            if os.path.exists(filepath):
                if verbose:
                    print(f"Loading {filename}...")

                if filename.endswith('.csv'):
                    df = pd.read_csv(filepath)
                else:
                    df = pd.read_csv(filepath, sep='\t')
                break

        if df is None:
            raise FileNotFoundError("No Reddit dataset found in data directory")

        if verbose:
            print(f"Loaded dataset with {len(df)} edges")
            print(f"Columns: {list(df.columns)}")

        # Standardize column names (rename ignores names that are absent)
        column_mapping = {
            'SOURCE_SUBREDDIT': 'src',
            'TARGET_SUBREDDIT': 'dst',
            'TIMESTAMP': 'timestamp',
            'LINK_SENTIMENT': 'sentiment',
            'PROPERTIES': 'properties'
        }
        df = df.rename(columns=column_mapping)

        # If standard names not found, use first two columns as src/dst
        if 'src' not in df.columns:
            df = df.rename(columns={df.columns[0]: 'src', df.columns[1]: 'dst'})

        # Sample data if requested for faster processing
        if sample_size and len(df) > sample_size:
            df = df.sample(n=sample_size, random_state=42).reset_index(drop=True)
            if verbose:
                print(f"Sampled {sample_size} edges for processing")

        # Process timestamps
        if 'timestamp' in df.columns:
            if pd.api.types.is_numeric_dtype(df['timestamp']):
                # Already numeric: keep the values as they are.  Parsing them
                # as datetimes would read them as nanoseconds and shrink the
                # scale by 1e9.
                df['timestamp_numeric'] = pd.to_numeric(
                    df['timestamp'], errors='coerce'
                ).astype(float)
            else:
                parsed = pd.to_datetime(df['timestamp'], errors='coerce')
                df['timestamp'] = parsed
                if isinstance(parsed.dtype, pd.DatetimeTZDtype):
                    parsed = parsed.dt.tz_convert('UTC').dt.tz_localize(None)
                # Seconds since the Unix epoch.  Dividing a timedelta keeps
                # unparseable entries as NaN (so the dropna below removes
                # them) and does not depend on the datetime resolution or on
                # the platform's default integer width.
                epoch = pd.Timestamp('1970-01-01')
                df['timestamp_numeric'] = (parsed - epoch) / pd.Timedelta(seconds=1)
        else:
            # Create artificial timestamps
            df['timestamp_numeric'] = np.arange(len(df), dtype=float)
            if verbose:
                print("Created artificial timestamps")

        # Remove rows with invalid timestamps
        df = df.dropna(subset=['timestamp_numeric'])

        # Sort by timestamp
        df = df.sort_values('timestamp_numeric').reset_index(drop=True)

        # Create node mapping
        all_nodes = pd.concat([df['src'], df['dst']]).unique()
        node_to_idx = {node: idx for idx, node in enumerate(all_nodes)}
        idx_to_node = {idx: node for node, idx in node_to_idx.items()}

        # Map to indices
        df['src_idx'] = df['src'].map(node_to_idx)
        df['dst_idx'] = df['dst'].map(node_to_idx)

        num_nodes = len(all_nodes)

        # Extract edge features: random values, with the sentiment (when
        # available) written into the first column.
        edge_feat_dim = 50
        edge_features = np.random.randn(len(df), edge_feat_dim)
        if 'sentiment' in df.columns:
            sentiment_values = pd.to_numeric(df['sentiment'], errors='coerce').fillna(0)
            edge_features[:, 0] = sentiment_values.to_numpy()  # First feature is sentiment

        # Create node features (random initialization)
        node_feat_dim = 100
        node_features = np.random.randn(num_nodes, node_feat_dim)

        # Split data temporally (70% train, 15% val, 15% test)
        train_ratio, val_ratio = 0.7, 0.15
        train_end = int(len(df) * train_ratio)
        val_end = int(len(df) * (train_ratio + val_ratio))

        train_df = df.iloc[:train_end].copy()
        val_df = df.iloc[train_end:val_end].copy()
        test_df = df.iloc[val_end:].copy()

        # Create edge features for each split (rows align with df's order)
        train_edge_features = edge_features[:train_end]
        val_edge_features = edge_features[train_end:val_end]
        test_edge_features = edge_features[val_end:]

        if verbose:
            print(f"\nDataset Statistics:")
            print(f"Total nodes: {num_nodes:,}")
            print(f"Total edges: {len(df):,}")
            print(f"Train edges: {len(train_df):,}")
            print(f"Validation edges: {len(val_df):,}")
            print(f"Test edges: {len(test_df):,}")
            print(f"Node feature dim: {node_feat_dim}")
            print(f"Edge feature dim: {edge_feat_dim}")
            print(f"Time range: {df['timestamp_numeric'].min():.0f} - {df['timestamp_numeric'].max():.0f}")

        return {
            'train_df': train_df,
            'val_df': val_df,
            'test_df': test_df,
            'train_edge_features': train_edge_features,
            'val_edge_features': val_edge_features,
            'test_edge_features': test_edge_features,
            'node_features': node_features,
            'num_nodes': num_nodes,
            'node_feat_dim': node_feat_dim,
            'edge_feat_dim': edge_feat_dim,
            'node_to_idx': node_to_idx,
            'idx_to_node': idx_to_node,
            'full_df': df
        }

    except Exception as e:
        # Deliberate best-effort for notebook use: report the problem and
        # let the caller check for None instead of aborting the cell.
        print(f"Error loading dataset: {e}")
        import traceback
        traceback.print_exc()
        return None

# Load the dataset.
# load_reddit_dataset returns None on failure, so later cells guard on
# `dataset is not None` before using it.
print("Loading Reddit dataset...")
dataset = load_reddit_dataset(sample_size=50000)  # Sample for faster processing

if dataset is None:
    print("Failed to load dataset. Please check data files.")
else:
    print("Dataset loaded successfully!")

In [None]:
## 2. Temporal Graph Dataset Class

Create a PyTorch Dataset class for efficient batch processing of temporal graph data.

In [None]:
class TemporalGraphDataset(Dataset):
    """Dataset of positive (observed) and sampled negative temporal edges.

    Positive samples come first, in the row order of ``df``; negative samples
    follow.  Each item is a dict of tensors ready for a ``DataLoader``.
    """

    def __init__(self, df, edge_features, node_features, negative_sampling_ratio=1.0):
        """
        Args:
            df: DataFrame with columns ['src_idx', 'dst_idx', 'timestamp_numeric']
            edge_features: Edge features array, row ``i`` belonging to the
                ``i``-th row of ``df``
            node_features: Node features array, indexed by node index
            negative_sampling_ratio: Ratio of negative to positive samples
        """
        self.df = df.reset_index(drop=True)
        self.edge_features = edge_features
        self.node_features = node_features
        self.negative_sampling_ratio = negative_sampling_ratio
        self.num_nodes = len(node_features)

        # Create positive samples.  Iterating the columns (instead of
        # iterrows) keeps each column's own dtype; the explicit casts make
        # the node ids usable as array indices whatever dtype pandas chose.
        columns = zip(
            self.df['src_idx'], self.df['dst_idx'], self.df['timestamp_numeric']
        )
        self.positive_samples = [
            {
                'src': int(src),
                'dst': int(dst),
                'timestamp': float(timestamp),
                'edge_features': self.edge_features[row],
                'label': 1
            }
            for row, (src, dst, timestamp) in enumerate(columns)
        ]

        # Generate negative samples
        self.negative_samples = self._generate_negative_samples()

        # Combine positive and negative samples
        self.samples = self.positive_samples + self.negative_samples

        print(f"Dataset created with {len(self.positive_samples)} positive and {len(self.negative_samples)} negative samples")

    def _generate_negative_samples(self):
        """Generate negative samples by random node sampling.

        Each negative borrows the timestamp of a randomly chosen positive and
        connects two distinct random nodes that do not form an observed
        (src, dst) pair in this split.  A negative is skipped if no valid
        pair is found within 100 attempts, so fewer than the requested
        number may be returned.
        """
        num_negatives = int(len(self.positive_samples) * self.negative_sampling_ratio)
        # With fewer than two nodes no pair with src != dst exists.
        if num_negatives <= 0 or self.num_nodes < 2:
            return []

        # Observed edges in this split, ignoring timestamps
        existing_edges = {
            (sample['src'], sample['dst']) for sample in self.positive_samples
        }

        # Private generator with a fixed seed: reproducible sampling without
        # resetting NumPy's global random state as a side effect.
        rng = np.random.RandomState(42)
        feature_dim = self.edge_features.shape[1]
        max_attempts = 100

        negative_samples = []
        for _ in range(num_negatives):
            # Sample a random positive sample to get timestamp
            pos_idx = rng.randint(0, len(self.positive_samples))
            timestamp = self.positive_samples[pos_idx]['timestamp']

            # Sample random nodes that don't form an existing edge
            for _attempt in range(max_attempts):
                src = int(rng.randint(0, self.num_nodes))
                dst = int(rng.randint(0, self.num_nodes))

                if src != dst and (src, dst) not in existing_edges:
                    negative_samples.append({
                        'src': src,
                        'dst': dst,
                        'timestamp': timestamp,
                        'edge_features': rng.randn(feature_dim),
                        'label': 0
                    })
                    break

        return negative_samples

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return sample ``idx`` as a dict of tensors (ids long, rest float)."""
        sample = self.samples[idx]

        return {
            'src_idx': torch.tensor(sample['src'], dtype=torch.long),
            'dst_idx': torch.tensor(sample['dst'], dtype=torch.long),
            'timestamp': torch.tensor(sample['timestamp'], dtype=torch.float),
            'edge_features': torch.tensor(sample['edge_features'], dtype=torch.float),
            'src_features': torch.tensor(self.node_features[sample['src']], dtype=torch.float),
            'dst_features': torch.tensor(self.node_features[sample['dst']], dtype=torch.float),
            'label': torch.tensor(sample['label'], dtype=torch.float)
        }

# Create datasets if data loading was successful.
# Each split gets its own dataset (and its own negative samples) but shares
# the same node feature matrix.
if dataset is not None:
    train_dataset = TemporalGraphDataset(
        dataset['train_df'], 
        dataset['train_edge_features'], 
        dataset['node_features'],
        negative_sampling_ratio=1.0
    )
    
    val_dataset = TemporalGraphDataset(
        dataset['val_df'], 
        dataset['val_edge_features'], 
        dataset['node_features'],
        negative_sampling_ratio=1.0
    )
    
    test_dataset = TemporalGraphDataset(
        dataset['test_df'], 
        dataset['test_edge_features'], 
        dataset['node_features'],
        negative_sampling_ratio=1.0
    )
    
    # Create data loaders (larger batches when a GPU is available)
    batch_size = 512 if torch.cuda.is_available() else 128
    
    # Only the training loader shuffles; validation and test keep the
    # dataset order (positives first, then negatives).
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=2)
    
    print(f"\nData loaders created with batch size: {batch_size}")
    print(f"Train batches: {len(train_loader)}")
    print(f"Val batches: {len(val_loader)}")
    print(f"Test batches: {len(test_loader)}")