# LogGraph-SSL Complete Training Notebook 🚀

## High-Performance HDFS Anomaly Detection with Self-Supervised Learning

This notebook provides a complete implementation of LogGraph-SSL for anomaly detection in HDFS logs using Graph Neural Networks with Self-Supervised Learning tasks.

### Features:
- 🔥 **RTX 4090 Optimized**: Full GPU acceleration with CUDA 11.8
- 🧠 **Multiple GNN Architectures**: GCN, GAT, GraphSAGE encoders
- 🎯 **SSL Tasks**: Node masking, edge prediction, contrastive learning
- 📊 **Real-time Monitoring**: Training metrics and GPU utilization
- 🛡️ **Robust Architecture**: Memory efficient with gradient checkpointing

### Training Pipeline:
1. Environment setup and validation
2. Data preprocessing and graph construction
3. Model initialization and configuration
4. Self-supervised pre-training
5. Supervised fine-tuning for anomaly detection
6. Comprehensive evaluation and analysis

In [None]:
# Complete Environment Setup and Imports
import os
import sys
import time
import warnings
import logging
from datetime import datetime
from pathlib import Path

# Core data science libraries
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix, roc_auc_score
from sklearn.preprocessing import StandardScaler
import pickle
import json

# PyTorch and GPU libraries
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, TensorDataset
from torch.optim.lr_scheduler import ReduceLROnPlateau, CosineAnnealingLR

# PyTorch Geometric for GNNs
import torch_geometric
from torch_geometric.nn import GCNConv, GATConv, SAGEConv, global_mean_pool, global_max_pool
from torch_geometric.data import Data, Batch
from torch_geometric.utils import to_networkx, add_self_loops

# Additional utilities
import networkx as nx
from collections import defaultdict, Counter
import re
from tqdm.auto import tqdm
import psutil

# Configure logging once for the whole notebook.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
)
logger = logging.getLogger(__name__)

# Keep notebook output readable by silencing every library warning.
warnings.filterwarnings('ignore')

# Seed NumPy and the torch generators (CPU and, when present, all CUDA devices).
torch.manual_seed(42)
np.random.seed(42)
_cuda_available = torch.cuda.is_available()
if _cuda_available:
    torch.cuda.manual_seed(42)
    torch.cuda.manual_seed_all(42)

# Pick the compute device and report what the runtime provides.
device = torch.device('cuda') if _cuda_available else torch.device('cpu')
print(f"🔥 Using device: {device}")

if not _cuda_available:
    print("⚠️ CUDA not available. Using CPU.")
else:
    for _report_line in (
        f"🎯 GPU: {torch.cuda.get_device_name(0)}",
        f"💾 GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB",
        f"🔧 CUDA Version: {torch.version.cuda}",
        f"⚡ PyTorch Version: {torch.__version__}",
        f"📊 PyTorch Geometric Version: {torch_geometric.__version__}",
    ):
        print(_report_line)

print("✅ Environment setup complete!")

In [None]:
# Data Processing and Graph Building Utilities
class HDFSDataProcessor:
    """Turn raw HDFS log lines into token-ID sequences and PyTorch Geometric graphs.

    Every non-empty log line becomes one graph. Its nodes are the line's
    whitespace-separated tokens, stored as vocabulary IDs in ``x`` with shape
    ``[num_nodes, 1]``. Its edges encode sequential adjacency (type 1),
    short skip connections (type 2) and repeated tokens (type 3). A
    single-node graph gets one self-loop (type 0).
    """

    def __init__(self, vocab_size=5000):
        self.vocab_size = vocab_size
        self.token_to_id = {}
        self.id_to_token = {}
        self.template_to_id = {}
        self.sequences = []
        self.labels = []
        self.graphs = []

    def build_vocabulary(self, log_data):
        """Build the token vocabulary from an iterable of log lines.

        IDs 0 and 1 are reserved for ``<PAD>`` and ``<UNK>``. The remaining
        ``vocab_size - 2`` slots go to the most frequent tokens; ties keep
        first-seen order.

        Returns:
            The ``token_to_id`` mapping (also stored on the instance).
        """
        print("🔧 Building vocabulary...")

        # Stream tokens into the counter instead of materialising one big list.
        token_counts = Counter(token for line in log_data for token in line.split())
        vocab_tokens = [token for token, _ in token_counts.most_common(self.vocab_size - 2)]

        self.token_to_id = {'<PAD>': 0, '<UNK>': 1}
        self.id_to_token = {0: '<PAD>', 1: '<UNK>'}
        for token_id, token in enumerate(vocab_tokens, start=2):
            self.token_to_id[token] = token_id
            self.id_to_token[token_id] = token

        print(f"✅ Vocabulary built with {len(self.token_to_id)} tokens")
        return self.token_to_id

    def tokenize_sequence(self, sequence):
        """Map a log line to token IDs, using the ``<UNK>`` ID for out-of-vocabulary tokens."""
        return [
            self.token_to_id[token] if token in self.token_to_id else self.token_to_id['<UNK>']
            for token in sequence.split()
        ]

    @staticmethod
    def _build_edge_lists(token_ids, max_distance=3):
        """Return ``(edge_indices, edge_attrs)`` as Python lists for a token sequence.

        Every edge is emitted in both directions. Edge types:
        1 = consecutive tokens,
        2 = tokens 2..``max_distance`` positions apart,
        3 = two positions holding the same non-padding token,
        0 = self-loops, used only when no other edge exists.
        """
        num_nodes = len(token_ids)
        edge_indices = []
        edge_attrs = []

        # Type 1: sequential neighbours.
        for i in range(num_nodes - 1):
            edge_indices.extend([[i, i + 1], [i + 1, i]])
            edge_attrs.extend([1, 1])

        # Type 2: skip connections within max_distance.
        for i in range(num_nodes):
            for j in range(i + 2, min(i + max_distance + 1, num_nodes)):
                edge_indices.extend([[i, j], [j, i]])
                edge_attrs.extend([2, 2])

        # Type 3: identical tokens. Group positions by token so only matching
        # pairs are visited, instead of scanning every (i, j) pair. Edges come
        # out in the same order as a nested loop over i < j.
        positions = defaultdict(list)
        for idx, token in enumerate(token_ids):
            positions[token].append(idx)
        seen = Counter()
        for i, token in enumerate(token_ids):
            rank = seen[token]
            seen[token] += 1
            if token == 0:  # padding never links
                continue
            for j in positions[token][rank + 1:]:
                edge_indices.extend([[i, j], [j, i]])
                edge_attrs.extend([3, 3])

        if not edge_indices:
            # Only a single-node sequence gets here; GNN layers still need an edge.
            edge_indices = [[i, i] for i in range(num_nodes)]
            edge_attrs = [0] * num_nodes

        return edge_indices, edge_attrs

    def build_sequence_graph(self, token_ids, max_distance=3):
        """Build a PyG ``Data`` graph for one token sequence, or ``None`` if it is empty.

        ``x`` holds the raw token IDs (shape ``[num_nodes, 1]``); embeddings are
        learned by the model. ``edge_attr`` holds the integer edge type.
        """
        num_nodes = len(token_ids)
        if num_nodes == 0:
            return None

        x = torch.tensor(token_ids, dtype=torch.long).unsqueeze(1)
        edge_indices, edge_attrs = self._build_edge_lists(token_ids, max_distance)
        edge_index = torch.tensor(edge_indices, dtype=torch.long).t()
        edge_attr = torch.tensor(edge_attrs, dtype=torch.long)

        graph = Data(x=x, edge_index=edge_index, edge_attr=edge_attr)
        graph.num_nodes = num_nodes
        return graph

    @staticmethod
    def _read_labels(label_file):
        """Read one integer label per line, ignoring blank lines (e.g. a trailing newline)."""
        with open(label_file, 'r', encoding='utf-8') as f:
            return [int(text) for text in (raw.strip() for raw in f) if text]

    @staticmethod
    def _align_labels(kept_line_indices, labels):
        """Pick the label of each kept log line by its ORIGINAL line index.

        Skipped (empty) log lines must not shift later labels onto the wrong
        graph. With no labels (``None`` or empty), every graph is labelled 0
        (normal).

        Raises:
            ValueError: if a kept log line has no corresponding label.
        """
        if not labels:
            return [0] * len(kept_line_indices)
        if kept_line_indices and max(kept_line_indices) >= len(labels):
            raise ValueError(
                f"Label file has {len(labels)} labels but log line "
                f"{max(kept_line_indices) + 1} produced a graph; "
                "expected one label per log line"
            )
        return [labels[i] for i in kept_line_indices]

    def load_hdfs_data(self, log_file, label_file=None):
        """Load log lines (and optional per-line labels), build vocab and graphs.

        Returns:
            ``(graphs, labels)`` with ``labels[k]`` belonging to ``graphs[k]``.
        """
        print(f"📂 Loading data from {log_file}")

        with open(log_file, 'r', encoding='utf-8', errors='replace') as f:
            log_data = f.readlines()

        labels = None
        if label_file and os.path.exists(label_file):
            labels = self._read_labels(label_file)
            print(f"📊 Loaded {len(labels)} labels")

        self.build_vocabulary(log_data)

        print("🔄 Processing sequences and building graphs...")
        self.sequences = []
        self.graphs = []
        kept_line_indices = []  # original line index of every graph we keep

        for line_idx, line in enumerate(tqdm(log_data, desc="Building graphs")):
            token_ids = self.tokenize_sequence(line)
            if not token_ids:  # skip empty lines
                continue
            graph = self.build_sequence_graph(token_ids)
            if graph is None:
                continue
            self.sequences.append(token_ids)
            self.graphs.append(graph)
            kept_line_indices.append(line_idx)

        self.labels = self._align_labels(kept_line_indices, labels)

        print(f"✅ Processed {len(self.graphs)} sequences into graphs")
        if self.graphs:  # np.mean([]) would print nan with a RuntimeWarning
            print(f"📈 Average nodes per graph: {np.mean([g.num_nodes for g in self.graphs]):.1f}")

        return self.graphs, self.labels

# Module-level processor read by load_hdfs_data() and create_model_config().
data_processor = HDFSDataProcessor(5000)
print("✅ Data processor initialized")

In [None]:
# Complete GNN Model Architectures

class GCNEncoder(nn.Module):
    """Embed token IDs and refine them with a stack of GCNConv layers.

    Each layer applies GCNConv -> LayerNorm -> ReLU -> Dropout, and the output
    has ``hidden_dim`` features per node. ``num_layers`` below 1 still builds
    one convolution.
    """
    def __init__(self, vocab_size, embedding_dim, hidden_dim, num_layers=3, dropout=0.1):
        super().__init__()
        self.vocab_size = vocab_size
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers

        # padding_idx=0: the <PAD> row starts at zero and receives no gradient.
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)

        # The first layer consumes embeddings; the rest stay at hidden_dim.
        layer_inputs = [embedding_dim] + [hidden_dim] * (num_layers - 1)
        self.convs = nn.ModuleList([GCNConv(in_dim, hidden_dim) for in_dim in layer_inputs])

        self.dropout = nn.Dropout(dropout)
        self.norm_layers = nn.ModuleList([nn.LayerNorm(hidden_dim) for _ in range(num_layers)])

    def forward(self, x, edge_index, batch=None):
        """Return node embeddings of shape ``[num_nodes, hidden_dim]``.

        ``x`` holds token IDs as ``[num_nodes]`` or ``[num_nodes, 1]``.
        ``batch`` is accepted for interface parity and is not used.
        """
        token_ids = x.squeeze(1) if (x.dim() == 2 and x.size(1) == 1) else x
        h = self.embedding(token_ids)
        for layer_idx, conv in enumerate(self.convs):
            h = self.dropout(F.relu(self.norm_layers[layer_idx](conv(h, edge_index))))
        return h

class GATEncoder(nn.Module):
    """Embed token IDs and refine them with a stack of multi-head GATConv layers.

    Each GATConv has ``heads`` heads of ``hidden_dim // heads`` channels. Their
    outputs are concatenated (GATConv's default ``concat=True``) back to
    ``hidden_dim`` features, which the following LayerNorm requires. Each layer
    applies GATConv -> LayerNorm -> ReLU -> Dropout.

    Raises:
        ValueError: if ``heads`` is not positive or does not divide ``hidden_dim``.
            Otherwise the concatenated width would be ``hidden_dim // heads * heads``
            and the model would only fail later, inside forward().
    """
    def __init__(self, vocab_size, embedding_dim, hidden_dim, num_layers=3, heads=4, dropout=0.1):
        super(GATEncoder, self).__init__()
        if heads <= 0 or hidden_dim % heads != 0:
            raise ValueError(
                f"hidden_dim ({hidden_dim}) must be divisible by a positive "
                f"number of attention heads (got heads={heads})"
            )

        self.vocab_size = vocab_size
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.heads = heads

        # padding_idx=0: the <PAD> row starts at zero and receives no gradient.
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)

        # GAT layers; attention dropout uses the same rate as feature dropout.
        head_dim = hidden_dim // heads
        self.convs = nn.ModuleList()
        self.convs.append(GATConv(embedding_dim, head_dim, heads=heads, dropout=dropout))
        for _ in range(num_layers - 1):
            self.convs.append(GATConv(hidden_dim, head_dim, heads=heads, dropout=dropout))

        self.dropout = nn.Dropout(dropout)
        self.norm_layers = nn.ModuleList([nn.LayerNorm(hidden_dim) for _ in range(num_layers)])

    def forward(self, x, edge_index, batch=None):
        """Return node embeddings of shape ``[num_nodes, hidden_dim]``.

        ``x`` holds token IDs as ``[num_nodes]`` or ``[num_nodes, 1]``.
        ``batch`` is accepted for interface parity and is not used.
        """
        if x.dim() == 2 and x.size(1) == 1:
            x = x.squeeze(1)

        h = self.embedding(x)

        for i, conv in enumerate(self.convs):
            h = conv(h, edge_index)
            h = self.norm_layers[i](h)
            h = F.relu(h)
            h = self.dropout(h)

        return h

class GraphSAGEEncoder(nn.Module):
    """Embed token IDs and refine them with a stack of SAGEConv layers.

    Each layer applies SAGEConv -> LayerNorm -> ReLU -> Dropout, and the output
    has ``hidden_dim`` features per node. At least one convolution is always
    built.
    """
    def __init__(self, vocab_size, embedding_dim, hidden_dim, num_layers=3, dropout=0.1):
        super().__init__()
        self.vocab_size = vocab_size
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers

        # padding_idx=0: the <PAD> row starts at zero and receives no gradient.
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)

        self.convs = nn.ModuleList()
        in_channels = embedding_dim
        for _ in range(max(num_layers, 1)):
            self.convs.append(SAGEConv(in_channels, hidden_dim))
            in_channels = hidden_dim  # every layer after the first is hidden -> hidden

        self.dropout = nn.Dropout(dropout)
        self.norm_layers = nn.ModuleList([nn.LayerNorm(hidden_dim) for _ in range(num_layers)])

    def forward(self, x, edge_index, batch=None):
        """Return node embeddings of shape ``[num_nodes, hidden_dim]``.

        ``x`` holds token IDs as ``[num_nodes]`` or ``[num_nodes, 1]``.
        ``batch`` is accepted for interface parity and is not used.
        """
        if x.dim() == 2 and x.size(1) == 1:
            x = x.squeeze(1)
        h = self.embedding(x)
        for depth in range(len(self.convs)):
            h = self.convs[depth](h, edge_index)
            h = self.norm_layers[depth](h)
            h = self.dropout(torch.relu(h))
        return h

class AnomalyDetectionHead(nn.Module):
    """Pool node embeddings to one vector per graph and classify it.

    The MLP narrows hidden_dim -> hidden_dim // 2 -> hidden_dim // 4 ->
    num_classes, with ReLU and dropout between the linear layers.
    """
    def __init__(self, hidden_dim, num_classes=2, dropout=0.3):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_classes = num_classes

        self.global_pool = global_mean_pool
        half_dim, quarter_dim = hidden_dim // 2, hidden_dim // 4
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim, half_dim), nn.ReLU(), nn.Dropout(dropout),
            nn.Linear(half_dim, quarter_dim), nn.ReLU(), nn.Dropout(dropout),
            nn.Linear(quarter_dim, num_classes),
        )

    def forward(self, node_embeddings, batch=None):
        """Return class logits of shape ``[num_graphs, num_classes]``.

        With ``batch=None`` all nodes are treated as a single graph (mean over
        nodes). Otherwise nodes are mean-pooled per graph id in ``batch``.
        """
        if batch is None:
            pooled = node_embeddings.mean(dim=0, keepdim=True)
        else:
            pooled = self.global_pool(node_embeddings, batch)
        return self.classifier(pooled)

class SSLTaskManager(nn.Module):
    """Heads and losses for the three self-supervised tasks.

    - node masking: predict the original token of masked nodes,
    - edge prediction: score real vs. sampled (negative) node pairs,
    - contrastive: align two views of graph embeddings (InfoNCE-style).
    """
    def __init__(self, hidden_dim, vocab_size):
        super(SSLTaskManager, self).__init__()
        self.hidden_dim = hidden_dim
        self.vocab_size = vocab_size

        # Node masking task (token prediction)
        self.token_predictor = nn.Linear(hidden_dim, vocab_size)

        # Edge prediction task. The trailing Sigmoid is kept so that external
        # callers still get probabilities; the loss bypasses it (see _edge_logits).
        self.edge_predictor = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid()
        )

        # Contrastive learning projection head
        self.projection_head = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim // 2)
        )

    def node_masking_loss(self, node_embeddings, original_tokens, masked_indices):
        """Cross-entropy between predicted and original tokens at ``masked_indices``.

        Returns a zero scalar when nothing is masked.
        """
        if len(masked_indices) == 0:
            return torch.tensor(0.0, device=node_embeddings.device)

        token_logits = self.token_predictor(node_embeddings[masked_indices])
        true_tokens = original_tokens[masked_indices]
        return F.cross_entropy(token_logits, true_tokens)

    def _edge_logits(self, node_embeddings, pairs):
        """Pre-sigmoid scores for each (src, dst) column of a ``[2, E]`` index tensor."""
        pair_embeddings = torch.cat([node_embeddings[pairs[0]], node_embeddings[pairs[1]]], dim=1)
        # Run every edge_predictor layer except the final Sigmoid.
        return self.edge_predictor[:-1](pair_embeddings).squeeze(-1)

    def edge_prediction_loss(self, node_embeddings, edge_index, neg_edge_index):
        """Binary loss pushing real edges toward 1 and negative edges toward 0.

        The loss is computed from logits with ``binary_cross_entropy_with_logits``,
        which stays finite and keeps a useful gradient when scores saturate.
        BCE on sigmoid outputs clamps the loss and kills the gradient there.
        The result is the mean of the positive and negative terms; an empty
        side is left out, and if both are empty a zero scalar is returned.
        """
        terms = []

        pos_logits = self._edge_logits(node_embeddings, edge_index)
        if pos_logits.numel() > 0:
            terms.append(F.binary_cross_entropy_with_logits(pos_logits, torch.ones_like(pos_logits)))

        neg_logits = self._edge_logits(node_embeddings, neg_edge_index)
        if neg_logits.numel() > 0:
            terms.append(F.binary_cross_entropy_with_logits(neg_logits, torch.zeros_like(neg_logits)))

        if not terms:
            return torch.tensor(0.0, device=node_embeddings.device)
        return sum(terms) / len(terms)

    def contrastive_loss(self, embeddings1, embeddings2, temperature=0.1):
        """InfoNCE loss where row i of ``embeddings1`` pairs with row i of ``embeddings2``."""
        z1 = F.normalize(self.projection_head(embeddings1), dim=1)
        z2 = F.normalize(self.projection_head(embeddings2), dim=1)

        # Cosine similarities scaled by temperature; positives on the diagonal.
        similarity_matrix = torch.matmul(z1, z2.T) / temperature
        labels = torch.arange(z1.size(0), device=z1.device)
        return F.cross_entropy(similarity_matrix, labels)

class LogGraphSSL(nn.Module):
    """Full model: a GNN encoder shared by an anomaly classifier and SSL heads.

    ``encoder_type`` selects 'gcn', 'gat' or 'sage'; ``heads`` only matters for
    'gat'.
    """
    def __init__(self, vocab_size, embedding_dim=128, hidden_dim=256,
                 encoder_type='gcn', num_layers=3, dropout=0.1, heads=4):
        super().__init__()

        self.vocab_size = vocab_size
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.encoder_type = encoder_type

        # Factories defer construction so only the selected encoder is built.
        encoder_factories = {
            'gcn': lambda: GCNEncoder(vocab_size, embedding_dim, hidden_dim, num_layers, dropout),
            'gat': lambda: GATEncoder(vocab_size, embedding_dim, hidden_dim, num_layers, heads, dropout),
            'sage': lambda: GraphSAGEEncoder(vocab_size, embedding_dim, hidden_dim, num_layers, dropout),
        }
        if encoder_type not in encoder_factories:
            raise ValueError(f"Unknown encoder type: {encoder_type}")
        self.encoder = encoder_factories[encoder_type]()

        self.anomaly_head = AnomalyDetectionHead(hidden_dim)
        self.ssl_manager = SSLTaskManager(hidden_dim, vocab_size)

        # Flag kept for callers; nothing in this class reads it.
        self.ssl_training = True

    def forward(self, x, edge_index, batch=None):
        """Encode the graph(s) and return node embeddings plus anomaly logits."""
        node_embeddings = self.encoder(x, edge_index, batch)
        return {
            'node_embeddings': node_embeddings,
            'anomaly_logits': self.anomaly_head(node_embeddings, batch),
        }

    def ssl_forward(self, x, edge_index, batch=None):
        """Same as ``forward`` but also exposes the SSL task manager."""
        return {**self.forward(x, edge_index, batch), 'ssl_manager': self.ssl_manager}

print("✅ Model architectures defined successfully!")

In [None]:
# Data Loading and Model Configuration

def load_hdfs_data():
    """Load and process HDFS data"""
    print("📂 Loading HDFS data...")
    
    # Define data files
    data_files = {
        'hdfs_full.txt': 'hdfs_full_labels.txt',
        'hdfs_train.txt': 'hdfs_train_labels.txt',  
        'hdfs_test.txt': 'hdfs_test_labels.txt'
    }
    
    # Try to find available data files
    available_file = None
    for log_file, label_file in data_files.items():
        if os.path.exists(log_file):
            if os.path.exists(label_file):
                print(f"✅ Found {log_file} with labels {label_file}")
                available_file = (log_file, label_file)
                break
            else:
                print(f"⚠️  Found {log_file} but no labels {label_file}")
                available_file = (log_file, None)
    
    if available_file is None:
        print("❌ No HDFS data files found!")
        return None, None
    
    # Load data using processor
    log_file, label_file = available_file
    graphs, labels = data_processor.load_hdfs_data(log_file, label_file)
    
    # Split data
    if len(graphs) > 1000:  # Only split if we have enough data
        train_graphs, test_graphs, train_labels, test_labels = train_test_split(
            graphs, labels, test_size=0.2, random_state=42, stratify=labels
        )
        print(f"📊 Train: {len(train_graphs)}, Test: {len(test_graphs)}")
    else:
        train_graphs, train_labels = graphs, labels
        test_graphs, test_labels = graphs[-100:], labels[-100:]  # Use last 100 for testing
        print(f"📊 Small dataset - Train: {len(train_graphs)}, Test: {len(test_graphs)}")
    
    return (train_graphs, train_labels), (test_graphs, test_labels)

def create_model_config():
    """Return the default model/training hyper-parameters as a dict.

    ``vocab_size`` is taken from the module-level ``data_processor``. It is 0
    until that processor has built its vocabulary.
    """
    vocab_size = len(data_processor.token_to_id)
    return dict(
        vocab_size=vocab_size,
        embedding_dim=128,
        hidden_dim=256,
        encoder_type='gcn',  # one of 'gcn', 'gat', 'sage'
        num_layers=3,
        dropout=0.1,
        heads=4,  # only used by the GAT encoder
        learning_rate=0.001,
        weight_decay=1e-5,
        batch_size=32,
        ssl_epochs=50,
        supervised_epochs=30,
        ssl_weight=1.0,
        supervised_weight=1.0,
    )

def create_model(config):
    """Build a LogGraphSSL from ``config``, move it to ``device`` and print its size."""
    print(f"🏗️ Creating {config['encoder_type'].upper()} model...")

    constructor_keys = ('vocab_size', 'embedding_dim', 'hidden_dim', 'encoder_type',
                        'num_layers', 'dropout', 'heads')
    model = LogGraphSSL(**{key: config[key] for key in constructor_keys}).to(device)

    params = list(model.parameters())
    total_count = sum(p.numel() for p in params)
    trainable_count = sum(p.numel() for p in params if p.requires_grad)
    print(f"📊 Model parameters: {total_count:,}")
    print(f"🔧 Trainable parameters: {trainable_count:,}")

    return model

class GraphDataLoader:
    """Minimal mini-batch iterator over PyG graphs and their integer labels.

    Each iteration yields ``(Batch, LongTensor labels)``, both already moved to
    the module-level ``device``. A batch that fails to collate or transfer is
    reported and skipped, which is best-effort by design.

    Raises:
        ValueError: if ``batch_size`` < 1, or there are fewer labels than graphs.
            Both used to fail only later, with ZeroDivisionError/ValueError or
            IndexError in the middle of an epoch.
    """
    def __init__(self, graphs, labels, batch_size=32, shuffle=True):
        if batch_size < 1:
            raise ValueError(f"batch_size must be a positive integer, got {batch_size}")
        if len(labels) < len(graphs):
            raise ValueError(f"Got {len(graphs)} graphs but only {len(labels)} labels")

        self.graphs = graphs
        self.labels = labels
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.indices = list(range(len(graphs)))

    def __iter__(self):
        if self.shuffle:
            # Reshuffles in place each epoch using NumPy's global RNG.
            np.random.shuffle(self.indices)

        for start in range(0, len(self.indices), self.batch_size):
            batch_indices = self.indices[start:start + self.batch_size]
            batch_graphs = [self.graphs[idx] for idx in batch_indices]
            label_values = [self.labels[idx] for idx in batch_indices]

            try:
                batch = Batch.from_data_list(batch_graphs)
                label_tensor = torch.tensor(label_values, dtype=torch.long)
                yield batch.to(device), label_tensor.to(device)
            except Exception as e:
                print(f"⚠️ Skipping batch due to error: {e}")

    def __len__(self):
        """Number of batches per epoch, counting the final partial batch."""
        return (len(self.indices) + self.batch_size - 1) // self.batch_size

print("✅ Data loading and configuration functions ready!")

In [None]:
# Training Functions

class TrainingManager:
    """Optimizers, schedulers and per-step logic for SSL pre-training and fine-tuning.

    Both phases optimise the same model parameters with separate Adam
    optimizers. Fine-tuning starts at one tenth of the configured learning rate.
    """

    def __init__(self, model, config):
        self.model = model
        self.config = config
        self.device = device  # module-level device from the setup cell

        self.ssl_optimizer = optim.Adam(model.parameters(),
                                        lr=config['learning_rate'],
                                        weight_decay=config['weight_decay'])
        self.supervised_optimizer = optim.Adam(model.parameters(),
                                               lr=config['learning_rate'] * 0.1,  # lower LR for fine-tuning
                                               weight_decay=config['weight_decay'])

        self.ssl_scheduler = ReduceLROnPlateau(self.ssl_optimizer, patience=10, factor=0.5)
        self.supervised_scheduler = ReduceLROnPlateau(self.supervised_optimizer, patience=5, factor=0.5)

        # Metrics tracking
        self.ssl_metrics = {'loss': [], 'node_loss': [], 'edge_loss': [], 'contrastive_loss': []}
        self.supervised_metrics = {'loss': [], 'accuracy': [], 'f1': []}

    def _zero_loss(self):
        """Scalar zero on the training device, used when an SSL term is skipped."""
        return torch.tensor(0.0, device=self.device)

    def generate_negative_edges(self, edge_index, num_nodes, num_neg_samples=None):
        """Sample up to ``num_neg_samples`` directed node pairs that are not edges.

        Defaults to as many samples as there are edges. Sampling uses NumPy's
        global RNG and gives up after ``10 * num_neg_samples`` attempts. If
        nothing was found, a single placeholder pair is returned. Pairs are drawn
        over all ``num_nodes``, so in a PyG batch they may span different graphs.

        Returns:
            LongTensor of shape ``[2, K]`` on ``self.device``.
        """
        if num_neg_samples is None:
            num_neg_samples = edge_index.size(1)

        # Existing (positive) edges. One bulk tolist() instead of 2*E .item()
        # calls, each of which forces a device->host sync on GPU.
        existing_edges = set(map(tuple, edge_index.t().tolist()))

        neg_edges = []
        attempts = 0
        max_attempts = num_neg_samples * 10

        while len(neg_edges) < num_neg_samples and attempts < max_attempts:
            src = np.random.randint(0, num_nodes)
            dst = np.random.randint(0, num_nodes)
            if src != dst and (src, dst) not in existing_edges:
                neg_edges.append([src, dst])
            attempts += 1

        if not neg_edges:
            # Fallback: keep the edge task computable even for tiny/dense graphs.
            neg_edges = [[0, 1]] if num_nodes > 1 else [[0, 0]]

        return torch.tensor(neg_edges, dtype=torch.long, device=self.device).t()

    def mask_nodes(self, x, mask_ratio=0.15):
        """Replace a random ``mask_ratio`` share of tokens (at least one) with ``<UNK>`` (ID 1).

        Returns:
            ``(x_masked [N, 1], original_tokens [N], mask_indices)``. The indices
            come from ``torch.randperm`` on the CPU generator.
        """
        if x.dim() == 2 and x.size(1) == 1:
            x = x.squeeze(1)

        num_nodes = x.size(0)
        num_mask = max(1, int(num_nodes * mask_ratio))
        mask_indices = torch.randperm(num_nodes)[:num_mask]

        original_tokens = x.clone()
        x_masked = x.clone()
        x_masked[mask_indices] = 1  # <UNK> token

        return x_masked.unsqueeze(1), original_tokens, mask_indices

    def ssl_train_step(self, batch, batch_labels):
        """Run one optimisation step on the combined SSL objective.

        Total loss = node masking + edge prediction + 0.1 * contrastive. Edge
        and contrastive terms are best-effort: a failure is logged and that term
        contributes zero. ``batch_labels`` is accepted for interface symmetry and
        is not used.
        """
        self.model.train()
        self.ssl_optimizer.zero_grad()

        x, edge_index = batch.x, batch.edge_index
        batch_info = getattr(batch, 'batch', None)

        # Node masking: predict the original tokens from the masked input.
        x_masked, original_tokens, mask_indices = self.mask_nodes(x.squeeze(1) if x.dim() == 2 else x)

        outputs = self.model.ssl_forward(x_masked, edge_index, batch_info)
        node_embeddings = outputs['node_embeddings']
        ssl_manager = outputs['ssl_manager']

        losses = {}

        if len(mask_indices) > 0:
            losses['node'] = ssl_manager.node_masking_loss(node_embeddings, original_tokens, mask_indices)
        else:
            losses['node'] = self._zero_loss()

        try:
            neg_edge_index = self.generate_negative_edges(edge_index, batch.num_nodes)
            losses['edge'] = ssl_manager.edge_prediction_loss(node_embeddings, edge_index, neg_edge_index)
        except Exception as exc:  # not bare: let KeyboardInterrupt/SystemExit through
            logger.warning("Edge-prediction loss skipped for this batch: %s", exc)
            losses['edge'] = self._zero_loss()

        # Contrastive term on graph-level embeddings; needs >= 2 graphs in the batch.
        losses['contrastive'] = self._zero_loss()
        try:
            if batch_info is not None and len(torch.unique(batch_info)) > 1:
                graph_embeddings = global_mean_pool(node_embeddings, batch_info)
                if graph_embeddings.size(0) > 1:
                    # Second view via simple feature dropout.
                    aug_embeddings = F.dropout(graph_embeddings, p=0.1, training=True)
                    losses['contrastive'] = ssl_manager.contrastive_loss(graph_embeddings, aug_embeddings)
        except Exception as exc:
            logger.warning("Contrastive loss skipped for this batch: %s", exc)
            losses['contrastive'] = self._zero_loss()

        total_loss = losses['node'] + losses['edge'] + 0.1 * losses['contrastive']

        total_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
        self.ssl_optimizer.step()

        return {
            'total_loss': total_loss.item(),
            'node_loss': losses['node'].item(),
            'edge_loss': losses['edge'].item(),
            'contrastive_loss': losses['contrastive'].item()
        }

    def supervised_train_step(self, batch, batch_labels):
        """One cross-entropy step on the anomaly head; returns loss, accuracy, preds, labels."""
        self.model.train()
        self.supervised_optimizer.zero_grad()

        outputs = self.model(batch.x, batch.edge_index, getattr(batch, 'batch', None))
        logits = outputs['anomaly_logits']

        loss = F.cross_entropy(logits, batch_labels)

        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
        self.supervised_optimizer.step()

        pred = torch.argmax(logits, dim=1)
        accuracy = (pred == batch_labels).float().mean()

        return {
            'loss': loss.item(),
            'accuracy': accuracy.item(),
            'predictions': pred.cpu().numpy(),
            'labels': batch_labels.cpu().numpy()
        }

    def evaluate(self, data_loader):
        """Evaluate the anomaly head without gradient tracking.

        ``loss`` is the mean cross-entropy per SAMPLE, so a smaller final batch
        is not over-weighted. It counts only batches the loader actually
        yielded. An empty loader gives NaN loss and accuracy.
        """
        self.model.eval()
        loss_sum = 0.0
        num_samples = 0
        all_preds = []
        all_labels = []

        with torch.no_grad():
            for batch, batch_labels in data_loader:
                outputs = self.model(batch.x, batch.edge_index, getattr(batch, 'batch', None))
                logits = outputs['anomaly_logits']

                loss_sum += F.cross_entropy(logits, batch_labels, reduction='sum').item()
                num_samples += batch_labels.size(0)

                pred = torch.argmax(logits, dim=1)
                all_preds.extend(pred.cpu().numpy())
                all_labels.extend(batch_labels.cpu().numpy())

        if num_samples == 0:
            return {'loss': float('nan'), 'accuracy': float('nan'), 'predictions': [], 'labels': []}

        accuracy = np.mean(np.array(all_preds) == np.array(all_labels))

        return {
            'loss': loss_sum / num_samples,
            'accuracy': accuracy,
            'predictions': all_preds,
            'labels': all_labels
        }

def train_ssl_phase(model, train_loader, config):
    """Self-supervised pre-training for ``config['ssl_epochs']`` epochs.

    A failing step is reported and skipped. After each epoch the mean losses
    are appended to ``trainer.ssl_metrics`` and the plateau scheduler is
    stepped on the mean total loss. An epoch with no successful step is
    reported and leaves both untouched, rather than feeding NaN to the
    scheduler.

    Returns:
        The ``TrainingManager`` holding the optimizers for the next phase.
    """
    print("🎯 Starting Self-Supervised Learning Phase...")

    trainer = TrainingManager(model, config)

    for epoch in range(config['ssl_epochs']):
        epoch_losses = {'total': [], 'node': [], 'edge': [], 'contrastive': []}

        pbar = tqdm(train_loader, desc=f"SSL Epoch {epoch+1}/{config['ssl_epochs']}")
        for batch, batch_labels in pbar:
            try:
                losses = trainer.ssl_train_step(batch, batch_labels)
            except Exception as e:
                print(f"⚠️ Error in SSL training step: {e}")
                continue

            epoch_losses['total'].append(losses['total_loss'])
            epoch_losses['node'].append(losses['node_loss'])
            epoch_losses['edge'].append(losses['edge_loss'])
            epoch_losses['contrastive'].append(losses['contrastive_loss'])

            pbar.set_postfix({
                'Loss': f"{losses['total_loss']:.4f}",
                'Node': f"{losses['node_loss']:.4f}",
                'Edge': f"{losses['edge_loss']:.4f}"
            })

        if not epoch_losses['total']:
            print(f"⚠️ SSL Epoch {epoch+1}: no successful training steps, skipping scheduler update")
            continue

        avg_loss = float(np.mean(epoch_losses['total']))
        avg_node = float(np.mean(epoch_losses['node']))
        avg_edge = float(np.mean(epoch_losses['edge']))
        avg_contrastive = float(np.mean(epoch_losses['contrastive']))
        trainer.ssl_scheduler.step(avg_loss)

        trainer.ssl_metrics['loss'].append(avg_loss)
        trainer.ssl_metrics['node_loss'].append(avg_node)
        trainer.ssl_metrics['edge_loss'].append(avg_edge)
        trainer.ssl_metrics['contrastive_loss'].append(avg_contrastive)

        if (epoch + 1) % 10 == 0:
            print(f"📊 SSL Epoch {epoch+1}: "
                  f"Loss: {avg_loss:.4f}, "
                  f"Node: {avg_node:.4f}, "
                  f"Edge: {avg_edge:.4f}")

    print("✅ SSL pre-training completed!")
    return trainer

def train_supervised_phase(trainer, train_loader, test_loader, config):
    """Supervised fine-tuning phase.

    Runs ``config['supervised_epochs']`` epochs of
    ``trainer.supervised_train_step`` over ``train_loader``, evaluates on
    ``test_loader`` after every epoch, steps the supervised scheduler with the
    test loss, and keeps a snapshot of the weights with the highest test
    accuracy. That snapshot is loaded back into ``trainer.model`` at the end.

    Args:
        trainer: ``TrainingManager`` returned by the SSL phase.
        train_loader: Iterable yielding ``(batch, batch_labels)`` pairs.
        test_loader: Loader passed to ``trainer.evaluate``.
        config: Dict providing at least ``'supervised_epochs'``.

    Returns:
        Tuple ``(trainer, best_accuracy)``. ``best_accuracy`` stays 0 if no
        epoch beat it.
    """
    import copy  # deep-copy the best checkpoint (see note below)

    print("🎯 Starting Supervised Fine-tuning Phase...")

    num_epochs = config['supervised_epochs']
    best_accuracy = 0
    best_model_state = None

    for epoch in range(num_epochs):
        # Training
        epoch_losses = []
        epoch_accuracies = []

        pbar = tqdm(train_loader, desc=f"Supervised Epoch {epoch+1}/{num_epochs}")
        for batch, batch_labels in pbar:
            try:
                metrics = trainer.supervised_train_step(batch, batch_labels)
            except Exception as e:
                # Best-effort training: report the failing batch and move on.
                print(f"⚠️ Error in supervised training step: {e}")
                continue

            epoch_losses.append(metrics['loss'])
            epoch_accuracies.append(metrics['accuracy'])

            pbar.set_postfix({
                'Loss': f"{metrics['loss']:.4f}",
                'Acc': f"{metrics['accuracy']:.4f}"
            })

        # Avoid np.mean([]) (RuntimeWarning) when every batch failed.
        train_metrics = {
            'loss': np.mean(epoch_losses) if epoch_losses else float('nan'),
            'accuracy': np.mean(epoch_accuracies) if epoch_accuracies else float('nan')
        }

        # Evaluation
        test_metrics = trainer.evaluate(test_loader)

        # Learning rate scheduling
        trainer.supervised_scheduler.step(test_metrics['loss'])

        # Save best model
        if test_metrics['accuracy'] > best_accuracy:
            best_accuracy = test_metrics['accuracy']
            # state_dict() holds references to the live parameter tensors, so a
            # shallow dict copy would keep changing as training continues and
            # the "best" weights would silently become the final weights.
            best_model_state = copy.deepcopy(trainer.model.state_dict())

        # Log metrics
        if (epoch + 1) % 5 == 0:
            print(f"📊 Supervised Epoch {epoch+1}: "
                  f"Train Loss: {train_metrics['loss']:.4f}, "
                  f"Train Acc: {train_metrics['accuracy']:.4f}, "
                  f"Test Loss: {test_metrics['loss']:.4f}, "
                  f"Test Acc: {test_metrics['accuracy']:.4f}")

    # Load best model
    if best_model_state is not None:
        trainer.model.load_state_dict(best_model_state)
        print(f"✅ Loaded best model with accuracy: {best_accuracy:.4f}")

    print("✅ Supervised fine-tuning completed!")
    return trainer, best_accuracy

# Confirms, when this notebook cell runs, that the training helpers above are defined.
print("✅ Training functions ready!")

In [None]:
# Main Training Execution

def main_training_pipeline():
    """Complete training pipeline execution.

    Steps: load HDFS data with ``load_hdfs_data()``, build the config and
    model, wrap the splits in ``GraphDataLoader``s, run SSL pre-training and
    supervised fine-tuning, evaluate on the test split, print a report, and
    save a checkpoint ``loggraph_ssl_model_<timestamp>.pth`` containing the
    model weights, config, vocabulary (``data_processor.token_to_id``) and
    final test accuracy.

    Returns:
        Dict with ``'model'``, ``'trainer'``, ``'config'``, ``'final_metrics'``
        and ``'best_accuracy'``, or ``None`` if data loading failed or any step
        raised (the traceback is printed).
    """
    print("🚀 Starting LogGraph-SSL Training Pipeline...")
    print("=" * 60)

    # Check GPU memory
    if torch.cuda.is_available():
        print(f"🔥 GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")
        torch.cuda.empty_cache()

    try:
        # Step 1: Load and process data
        print("\n📂 Step 1: Loading and processing data...")
        train_data, test_data = load_hdfs_data()

        # Both splits are unpacked below; a missing test split would otherwise
        # surface as an opaque TypeError during unpacking.
        if train_data is None or test_data is None:
            print("❌ Failed to load data. Please check data files.")
            return None

        train_graphs, train_labels = train_data
        test_graphs, test_labels = test_data

        print("✅ Data loaded successfully!")
        print(f"📊 Training samples: {len(train_graphs)}")
        print(f"📊 Test samples: {len(test_graphs)}")
        print(f"📊 Vocabulary size: {len(data_processor.token_to_id)}")

        # Check label distribution
        label_counts = Counter(train_labels)
        print(f"📈 Label distribution: {dict(label_counts)}")

        # Step 2: Create model configuration
        print("\n⚙️ Step 2: Creating model configuration...")
        config = create_model_config()
        print("✅ Configuration created:")
        for key, value in config.items():
            print(f"   {key}: {value}")

        # Step 3: Initialize model
        print(f"\n🏗️ Step 3: Initializing {config['encoder_type'].upper()} model...")
        model = create_model(config)
        print("✅ Model initialized successfully!")

        # Step 4: Create data loaders
        print("\n📦 Step 4: Creating data loaders...")
        train_loader = GraphDataLoader(train_graphs, train_labels,
                                       config['batch_size'], shuffle=True)
        test_loader = GraphDataLoader(test_graphs, test_labels,
                                      config['batch_size'], shuffle=False)

        print("✅ Data loaders created:")
        print(f"   Training batches: {len(train_loader)}")
        print(f"   Test batches: {len(test_loader)}")

        # Step 5: Self-supervised pre-training
        print(f"\n🎯 Step 5: Self-supervised pre-training ({config['ssl_epochs']} epochs)...")
        start_time = time.time()

        trainer = train_ssl_phase(model, train_loader, config)

        ssl_time = time.time() - start_time
        print(f"⏱️ SSL training completed in {ssl_time/60:.1f} minutes")

        # Step 6: Supervised fine-tuning
        print(f"\n🎯 Step 6: Supervised fine-tuning ({config['supervised_epochs']} epochs)...")
        start_time = time.time()

        trainer, best_accuracy = train_supervised_phase(trainer, train_loader, test_loader, config)

        supervised_time = time.time() - start_time
        print(f"⏱️ Supervised training completed in {supervised_time/60:.1f} minutes")

        # Step 7: Final evaluation
        print("\n📊 Step 7: Final comprehensive evaluation...")
        final_test_metrics = trainer.evaluate(test_loader)

        print("🎉 Training Pipeline Completed Successfully!")
        print("=" * 60)
        print("📈 Final Results:")
        print(f"   Test Accuracy: {final_test_metrics['accuracy']:.4f}")
        print(f"   Test Loss: {final_test_metrics['loss']:.4f}")
        print(f"   Best Accuracy: {best_accuracy:.4f}")
        print(f"   Total Training Time: {(ssl_time + supervised_time)/60:.1f} minutes")

        # Classification report (only meaningful when both classes appear;
        # classification_report is imported at the top of the notebook).
        if len(np.unique(final_test_metrics['labels'])) > 1:
            print("\n📋 Classification Report:")
            print(classification_report(final_test_metrics['labels'],
                                        final_test_metrics['predictions'],
                                        target_names=['Normal', 'Anomaly']))

        # Save model
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        model_path = f"loggraph_ssl_model_{timestamp}.pth"
        torch.save({
            'model_state_dict': model.state_dict(),
            'config': config,
            'vocab': data_processor.token_to_id,
            'test_accuracy': final_test_metrics['accuracy']
        }, model_path)
        print(f"💾 Model saved to: {model_path}")

        return {
            'model': model,
            'trainer': trainer,
            'config': config,
            'final_metrics': final_test_metrics,
            'best_accuracy': best_accuracy
        }

    except Exception as e:
        # Top-level boundary of the pipeline: report with traceback, signal via None.
        print(f"❌ Error in training pipeline: {e}")
        import traceback
        traceback.print_exc()
        return None

# Execute the training pipeline
if __name__ == "__main__":
    # Running as a script or notebook: launch the pipeline and report its outcome.
    print("🚀 Executing LogGraph-SSL Training Pipeline...")
    results = main_training_pipeline()

    if results is None:
        print("❌ Training failed. Please check the logs above.")
    else:
        print("✅ Training completed successfully!")
        print(f"🎯 Final test accuracy: {results['final_metrics']['accuracy']:.4f}")

In [None]:
# Monitoring and Visualization

def plot_training_metrics(trainer, save_path=None):
    """Plot training curves, GPU memory and parameter counts in a 2x2 grid.

    Panels:
        [0, 0] SSL total/node/edge losses from ``trainer.ssl_metrics``.
        [0, 1] Supervised loss (left axis) and accuracy (right axis) from
               ``trainer.supervised_metrics``, with one legend covering both.
        [1, 0] Current CUDA allocated / reserved / total memory in GB.
        [1, 1] Total vs. trainable parameter counts of ``trainer.model``.
    A panel with nothing to show displays a short placeholder message instead
    of empty axes.

    Args:
        trainer: Object exposing ``ssl_metrics`` and ``supervised_metrics``
            dicts of per-epoch lists and, optionally, a ``model`` attribute.
        save_path: If given, the figure is also saved there at 300 dpi.
    """
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))

    def _placeholder(ax, message):
        """Show a centred message on a panel that has no data."""
        ax.text(0.5, 0.5, message, ha='center', va='center', transform=ax.transAxes)
        ax.set_axis_off()

    # SSL losses
    ssl_ax = axes[0, 0]
    if trainer.ssl_metrics['loss']:
        ssl_ax.plot(trainer.ssl_metrics['loss'], label='Total Loss')
        ssl_ax.plot(trainer.ssl_metrics['node_loss'], label='Node Loss')
        ssl_ax.plot(trainer.ssl_metrics['edge_loss'], label='Edge Loss')
        ssl_ax.set_title('Self-Supervised Learning Losses')
        ssl_ax.set_xlabel('Epoch')
        ssl_ax.set_ylabel('Loss')
        ssl_ax.legend()
        ssl_ax.grid(True)
    else:
        _placeholder(ssl_ax, 'No SSL metrics recorded')

    # Supervised metrics
    loss_ax = axes[0, 1]
    if trainer.supervised_metrics['loss']:
        loss_ax.plot(trainer.supervised_metrics['loss'], label='Loss', color='red')
        loss_ax.set_title('Supervised Learning Loss')
        loss_ax.set_xlabel('Epoch')
        loss_ax.set_ylabel('Loss')
        loss_ax.grid(True)

        # Accuracy on secondary y-axis
        acc_ax = loss_ax.twinx()
        acc_ax.plot(trainer.supervised_metrics['accuracy'], label='Accuracy', color='blue')
        acc_ax.set_ylabel('Accuracy')

        # Each twin axis only knows its own lines; merge them so the legend
        # shows both the loss and the accuracy curve.
        loss_handles, loss_labels = loss_ax.get_legend_handles_labels()
        acc_handles, acc_labels = acc_ax.get_legend_handles_labels()
        acc_ax.legend(loss_handles + acc_handles, loss_labels + acc_labels,
                      loc='upper right')
    else:
        _placeholder(loss_ax, 'No supervised metrics recorded')

    # GPU memory usage (if available)
    mem_ax = axes[1, 0]
    if torch.cuda.is_available():
        memory_allocated = torch.cuda.memory_allocated() / 1024**3
        memory_reserved = torch.cuda.memory_reserved() / 1024**3
        total_memory = torch.cuda.get_device_properties(0).total_memory / 1024**3

        mem_ax.bar(['Allocated', 'Reserved', 'Total'],
                   [memory_allocated, memory_reserved, total_memory],
                   color=['orange', 'red', 'blue'])
        mem_ax.set_title('GPU Memory Usage (GB)')
        mem_ax.set_ylabel('Memory (GB)')
    else:
        _placeholder(mem_ax, 'CUDA not available')

    # Model architecture summary
    param_ax = axes[1, 1]
    if hasattr(trainer, 'model'):
        total_params = sum(p.numel() for p in trainer.model.parameters())
        trainable_params = sum(p.numel() for p in trainer.model.parameters() if p.requires_grad)

        param_ax.bar(['Total', 'Trainable'],
                     [total_params, trainable_params],
                     color=['lightblue', 'darkblue'])
        param_ax.set_title('Model Parameters')
        param_ax.set_ylabel('Number of Parameters')
        param_ax.ticklabel_format(style='scientific', axis='y', scilimits=(0, 0))
    else:
        _placeholder(param_ax, 'No model attached')

    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"📊 Training plots saved to: {save_path}")

    plt.show()

def monitor_gpu_usage():
    """Print CUDA device name and memory usage, plus temperature/power if obtainable.

    Memory figures come from ``torch.cuda``. Temperature and power draw are
    queried from ``nvidia-smi`` on a best-effort basis: they are skipped
    silently if the tool is missing, times out, exits non-zero, or prints
    output in an unexpected format. When several GPUs are listed, only the
    first line of ``nvidia-smi`` output is shown.

    Returns:
        None. Prints a warning and returns early when CUDA is unavailable.
    """
    import subprocess

    if not torch.cuda.is_available():
        print("⚠️ CUDA not available")
        return

    print("🔥 GPU Monitoring:")
    print(f"   Device: {torch.cuda.get_device_name(0)}")
    print(f"   Memory Allocated: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")
    print(f"   Memory Reserved: {torch.cuda.memory_reserved() / 1024**3:.2f} GB")
    print(f"   Total Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB")

    # Memory utilization percentage
    memory_percent = (torch.cuda.memory_allocated() / torch.cuda.get_device_properties(0).total_memory) * 100
    print(f"   Utilization: {memory_percent:.1f}%")

    # Temperature and power (if available)
    try:
        result = subprocess.run(['nvidia-smi', '--query-gpu=temperature.gpu,power.draw',
                                 '--format=csv,noheader,nounits'],
                                capture_output=True, text=True, timeout=5)
        if result.returncode == 0:
            # One CSV line per GPU; a multi-GPU host would otherwise break the unpack.
            first_gpu = result.stdout.strip().splitlines()[0]
            temp, power = first_gpu.split(', ')
            print(f"   Temperature: {temp}°C")
            print(f"   Power Draw: {power}W")
    except (OSError, subprocess.SubprocessError, ValueError, IndexError):
        # Best-effort extras: missing binary, timeout or odd output are not fatal.
        pass

def create_training_dashboard(trainer=None):
    """Print a plain-text dashboard of the environment and training progress.

    Always prints Python/PyTorch versions, the active ``device``, RAM usage
    and, when CUDA is present, GPU name, CUDA version and allocated-memory
    percentage. When ``trainer`` is given, also prints epoch counts and the
    latest SSL loss / supervised accuracy from its metric histories.

    Args:
        trainer: Optional object with ``ssl_metrics`` and ``supervised_metrics``
            dicts of per-epoch lists (keys ``'loss'`` and ``'accuracy'``).
    """
    divider = "=" * 50
    print("📊 LogGraph-SSL Training Dashboard")
    print(divider)

    # --- System information ---
    print("🖥️ System Information:")
    python_version = sys.version.split()[0]
    print(f"   Python Version: {python_version}")
    print(f"   PyTorch Version: {torch.__version__}")
    print(f"   Device: {device}")

    has_cuda = torch.cuda.is_available()
    if has_cuda:
        print(f"   GPU: {torch.cuda.get_device_name(0)}")
        print(f"   CUDA Version: {torch.version.cuda}")

    # --- Memory usage ---
    print(f"\n💾 Memory Usage:")
    ram_percent = psutil.virtual_memory().percent
    print(f"   RAM: {ram_percent:.1f}% used")

    if has_cuda:
        gpu_total = torch.cuda.get_device_properties(0).total_memory
        gpu_percent = torch.cuda.memory_allocated() / gpu_total * 100
        print(f"   GPU: {gpu_percent:.1f}% used")

    # --- Training status (only when a trainer is supplied) ---
    if trainer is None:
        print(divider)
        return

    print(f"\n🎯 Training Status:")
    ssl_losses = trainer.ssl_metrics['loss']
    print(f"   SSL Epochs Completed: {len(ssl_losses)}")
    supervised_losses = trainer.supervised_metrics['loss']
    print(f"   Supervised Epochs Completed: {len(supervised_losses)}")

    if ssl_losses:
        print(f"   Latest SSL Loss: {ssl_losses[-1]:.4f}")

    supervised_accuracy = trainer.supervised_metrics['accuracy']
    if supervised_accuracy:
        print(f"   Latest Accuracy: {supervised_accuracy[-1]:.4f}")

    print(divider)

# Test GPU and create initial dashboard
# Runs at cell execution so environment problems show up before any training starts.
print("🔍 Initial System Check:")
monitor_gpu_usage()
print("\n📊 Creating training dashboard...")
# No trainer is passed, so the dashboard prints only system and memory sections.
create_training_dashboard()
print("✅ Monitoring setup complete!")

In [None]:
# Quick Test and Validation

def quick_test_setup():
    """Quick test to validate everything works before full training.

    Runs five smoke checks in order:
      1. Look for HDFS data files in the working directory. If none exist,
         write a tiny ``sample_logs.txt`` / ``sample_labels.txt`` pair.
      2. Process the data with a fresh ``HDFSDataProcessor(vocab_size=100)``,
         failing if no graphs are produced.
      3. Build a small GCN ``LogGraphSSL`` model on ``device``.
      4. Run one forward pass on the first graph (eval mode, no gradients).
      5. Pull one batch of up to 3 graphs from ``GraphDataLoader``.

    Returns:
        True if every check passed, False otherwise (the traceback is printed).
    """
    print("🧪 Running Quick Setup Validation...")
    print("-" * 40)

    try:
        # Test 1: Check data files
        print("📂 Test 1: Checking data files...")
        data_files = ['hdfs_full.txt', 'hdfs_train.txt', 'hdfs_test.txt']
        found_files = [f for f in data_files if os.path.exists(f)]

        if found_files:
            print(f"✅ Found data files: {found_files}")
        else:
            print("⚠️ No data files found. Creating sample data...")
            # Minimal event sequences (one per line) with matching 0/1 labels.
            sample_logs = [
                "E1 E2 E3 E4",
                "E1 E5 E6",
                "E2 E3 E7 E8",
                "E1 E2 E9",
                "E5 E6 E10 E11"
            ]
            sample_labels = [0, 1, 0, 1, 0]

            with open('sample_logs.txt', 'w', encoding='utf-8') as f:
                f.writelines(f"{log}\n" for log in sample_logs)

            with open('sample_labels.txt', 'w', encoding='utf-8') as f:
                f.writelines(f"{label}\n" for label in sample_labels)

            print("✅ Created sample data files")

        # Test 2: Data processing
        print("\n🔧 Test 2: Testing data processing...")
        test_processor = HDFSDataProcessor(vocab_size=100)

        # Sample files take precedence whenever present (e.g. left by an earlier run).
        if os.path.exists('sample_logs.txt'):
            test_graphs, test_labels = test_processor.load_hdfs_data('sample_logs.txt', 'sample_labels.txt')
        elif found_files:
            test_graphs, test_labels = test_processor.load_hdfs_data(found_files[0])
        else:
            raise FileNotFoundError("No data files available")

        # With no graphs the remaining checks would be skipped or iterate over
        # nothing, and the validation would wrongly report success.
        if not test_graphs:
            raise ValueError("Data processing produced no graphs")

        print(f"✅ Processed {len(test_graphs)} graphs")
        print(f"   Vocabulary size: {len(test_processor.token_to_id)}")

        # Test 3: Model creation
        print("\n🏗️ Test 3: Testing model creation...")
        test_config = {
            'vocab_size': len(test_processor.token_to_id),
            'embedding_dim': 64,
            'hidden_dim': 128,
            'encoder_type': 'gcn',
            'num_layers': 2,
            'dropout': 0.1,
            'heads': 2
        }

        # 'heads' is only passed through for the GAT encoder.
        model_kwargs = {k: v for k, v in test_config.items()
                        if k != 'heads' or test_config['encoder_type'] == 'gat'}
        test_model = LogGraphSSL(**model_kwargs).to(device)
        print(f"✅ Created {test_config['encoder_type'].upper()} model")
        print(f"   Parameters: {sum(p.numel() for p in test_model.parameters()):,}")

        # Test 4: Single forward pass
        print("\n⚡ Test 4: Testing forward pass...")
        sample_graph = test_graphs[0].to(device)
        test_model.eval()  # disable dropout for this inference-only check

        with torch.no_grad():
            outputs = test_model(sample_graph.x, sample_graph.edge_index)
        print("✅ Forward pass successful")
        print(f"   Node embeddings shape: {outputs['node_embeddings'].shape}")
        print(f"   Anomaly logits shape: {outputs['anomaly_logits'].shape}")

        # Test 5: Data loader
        print("\n📦 Test 5: Testing data loader...")
        test_loader = GraphDataLoader(test_graphs[:3], test_labels[:3], batch_size=2)

        for batch, batch_labels in test_loader:
            print("✅ Data loader working")
            # The graph-assignment vector may be missing or None; guard both cases
            # instead of calling .max() on None.
            batch_vector = getattr(batch, 'batch', None)
            num_graphs = batch_vector.max().item() + 1 if batch_vector is not None else 1
            print(f"   Batch size: {num_graphs}")
            print(f"   Labels shape: {batch_labels.shape}")
            break

        print("\n🎉 All tests passed! Ready for full training.")
        print("-" * 40)

        return True

    except Exception as e:
        print(f"\n❌ Test failed: {e}")
        import traceback
        traceback.print_exc()
        return False

def run_mini_training(max_graphs=8):
    """Run a mini training session to test the pipeline end to end.

    Data source, in order of preference:
      * ``sample_logs.txt`` / ``sample_labels.txt`` (written by
        ``quick_test_setup()`` when no HDFS files exist), or
      * the first existing HDFS file (``hdfs_full.txt``, ``hdfs_train.txt``,
        ``hdfs_test.txt``), truncated to ``max_graphs`` graphs. Without this
        fallback the mini run would always fail on machines that have real
        data, because ``quick_test_setup()`` then never creates sample files.

    Trains a small model for 3 SSL epochs and 2 supervised epochs, one batch
    per epoch.

    Args:
        max_graphs: Upper bound on graphs kept when falling back to a real
            HDFS file.

    Returns:
        True if the session completed, False if no data was available or any
        step raised (the traceback is printed).
    """
    print("🏃‍♂️ Running Mini Training Session...")
    print("-" * 40)

    try:
        # Load minimal data
        test_processor = HDFSDataProcessor(vocab_size=100)
        real_files = [f for f in ('hdfs_full.txt', 'hdfs_train.txt', 'hdfs_test.txt')
                      if os.path.exists(f)]

        if os.path.exists('sample_logs.txt'):
            graphs, labels = test_processor.load_hdfs_data('sample_logs.txt', 'sample_labels.txt')
        elif real_files:
            graphs, labels = test_processor.load_hdfs_data(real_files[0])
            # Keep the session "mini" even when backed by the full dataset.
            graphs, labels = graphs[:max_graphs], labels[:max_graphs]
        else:
            print("⚠️ No sample data available. Run quick_test_setup() first.")
            return False

        if not graphs:
            print("⚠️ Data processing produced no graphs.")
            return False

        # Create mini config
        mini_config = {
            'vocab_size': len(test_processor.token_to_id),
            'embedding_dim': 32,
            'hidden_dim': 64,
            'encoder_type': 'gcn',
            'num_layers': 2,
            'dropout': 0.1,
            'heads': 2,
            'learning_rate': 0.01,
            'weight_decay': 1e-4,
            'batch_size': 2,
            'ssl_epochs': 3,
            'supervised_epochs': 2
        }

        # Create model and trainer; 'heads' is forwarded only for GAT, as in
        # quick_test_setup(), so switching encoder_type keeps working.
        model_kwargs = {key: mini_config[key] for key in
                        ('vocab_size', 'embedding_dim', 'hidden_dim',
                         'encoder_type', 'num_layers', 'dropout')}
        if mini_config['encoder_type'] == 'gat':
            model_kwargs['heads'] = mini_config['heads']
        model = LogGraphSSL(**model_kwargs).to(device)

        # Create data loader
        train_loader = GraphDataLoader(graphs, labels, mini_config['batch_size'])

        # Quick SSL training
        print("🎯 Mini SSL training...")
        trainer = TrainingManager(model, mini_config)

        for epoch in range(mini_config['ssl_epochs']):
            for batch, batch_labels in train_loader:
                losses = trainer.ssl_train_step(batch, batch_labels)
                print(f"   SSL Epoch {epoch+1}, Loss: {losses['total_loss']:.4f}")
                break  # One batch per epoch for speed

        # Quick supervised training
        print("🎯 Mini supervised training...")
        for epoch in range(mini_config['supervised_epochs']):
            for batch, batch_labels in train_loader:
                metrics = trainer.supervised_train_step(batch, batch_labels)
                print(f"   Supervised Epoch {epoch+1}, Loss: {metrics['loss']:.4f}, Acc: {metrics['accuracy']:.4f}")
                break  # One batch per epoch for speed

        print("✅ Mini training completed successfully!")
        print("-" * 40)
        return True

    except Exception as e:
        print(f"❌ Mini training failed: {e}")
        import traceback
        traceback.print_exc()
        return False

# Run validation tests
# Both steps must succeed before the notebook reports that local validation passed;
# run_mini_training() returns False (rather than raising) on failure.
print("🚀 Starting Local Validation...")
if quick_test_setup():
    print("\n🏃‍♂️ Running mini training test...")
    if run_mini_training():
        print("\n✅ Local validation complete! Ready for full training on Jupyter server.")
    else:
        print("\n❌ Mini training failed. Please fix issues before proceeding.")
else:
    print("\n❌ Validation failed. Please fix issues before proceeding.")