In [1]:
import os
import os.path as osp
import torch
from torch_geometric.data import Dataset, Data
from torch_geometric.loader import DataLoader
import cv2
import mediapipe as mp
import numpy as np
from typing import List, Tuple
import json
import warnings

# Suppress MediaPipe warnings
warnings.filterwarnings('ignore', category=UserWarning)


def compute_edge_features(x: torch.Tensor, edge_index: torch.Tensor, 
                         availability_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Compute edge features (distance) and filter edges based on availability.
    
    Args:
        x: Node features of shape (T, N, 3) where 3 = (x, y, confidence)
        edge_index: Edge connectivity of shape (2, num_edges)
        availability_mask: Availability mask of shape (T, N)
        
    Returns:
        filtered_edge_index: Edge index with unavailable nodes removed, shape (2, num_valid_edges)
        edge_attr: Edge features (distance) of shape (num_valid_edges, T)
    """
    T, N, _ = x.shape
    num_edges = edge_index.size(1)
    
    # Extract source and target nodes
    source_nodes = edge_index[0]  # (num_edges,)
    target_nodes = edge_index[1]  # (num_edges,)
    
    # Get coordinates for all frames
    source_coords = x[:, source_nodes, :2]  # (T, num_edges, 2)
    target_coords = x[:, target_nodes, :2]  # (T, num_edges, 2)
    
    # Compute Euclidean distance for each edge at each frame
    distances = torch.norm(target_coords - source_coords, dim=2)  # (T, num_edges)
    
    # Check availability: edge is valid only if both nodes are available
    source_available = availability_mask[:, source_nodes]  # (T, num_edges)
    target_available = availability_mask[:, target_nodes]  # (T, num_edges)
    edge_available = source_available * target_available  # (T, num_edges)
    
    # An edge is valid if it's available in at least some frames
    edge_valid = edge_available.sum(dim=0) > 0  # (num_edges,)
    
    # Filter edges
    filtered_edge_index = edge_index[:, edge_valid]  # (2, num_valid_edges)
    filtered_distances = distances[:, edge_valid]  # (T, num_valid_edges)
    filtered_availability = edge_available[:, edge_valid]  # (T, num_valid_edges)
    
    # Set distance to 0 where edge is not available in that frame
    filtered_distances = filtered_distances * filtered_availability
    
    # Transpose to get (num_valid_edges, T)
    edge_attr = filtered_distances.t()  # (num_valid_edges, T)
    
    return filtered_edge_index, edge_attr


def create_edge_index(n_pose: int, n_face: int, n_hand: int) -> torch.Tensor:
    """
    Create edge index based on actual landmark counts.
    
    Args:
        n_pose: Number of pose landmarks (12 for upper body)
        n_face: Number of face landmarks (lips + eyebrows)
        n_hand: Number of hand landmarks (21 per hand)
        
    Returns:
        edge_index: Tensor of shape (2, num_edges)
    """
    edges = []
    
    # POSE CONNECTIONS (upper body: shoulders, elbows, wrists, hips)
    pose_connections = [
        (0, 1),   # Left shoulder to right shoulder
        (0, 2),   # Left shoulder to left elbow  
        (2, 4),   # Left elbow to left wrist
        (1, 3),   # Right shoulder to right elbow
        (3, 5),   # Right elbow to right wrist
        (0, 6),   # Left shoulder to left hip
        (1, 7),   # Right shoulder to right hip
        (6, 7),   # Left hip to right hip
        (6, 8),   # Left hip to left knee
        (7, 9),   # Right hip to right knee
        (8, 10),  # Left knee to left ankle
        (9, 11),  # Right knee to right ankle
    ]
    
    for i, j in pose_connections:
        if i < n_pose and j < n_pose:
            edges.append((i, j))
    
    # FACE CONNECTIONS (sequential for lips and eyebrows)
    face_start = n_pose
    for i in range(n_face - 1):
        edges.append((face_start + i, face_start + i + 1))
    # Close the loop
    if n_face > 0:
        edges.append((face_start, face_start + n_face - 1))
    
    # LEFT HAND CONNECTIONS
    left_hand_start = n_pose + n_face
    hand_connections = [
        (0, 1), (1, 2), (2, 3), (3, 4),          # Thumb
        (0, 5), (5, 6), (6, 7), (7, 8),          # Index
        (0, 9), (9, 10), (10, 11), (11, 12),     # Middle
        (0, 13), (13, 14), (14, 15), (15, 16),   # Ring
        (0, 17), (17, 18), (18, 19), (19, 20),   # Pinky
        (5, 9), (9, 13), (13, 17),               # Palm
    ]
    
    for i, j in hand_connections:
        edges.append((left_hand_start + i, left_hand_start + j))
    
    # RIGHT HAND CONNECTIONS
    right_hand_start = left_hand_start + n_hand
    for i, j in hand_connections:
        edges.append((right_hand_start + i, right_hand_start + j))
    
    # Convert to bidirectional edges
    bidirectional_edges = []
    for i, j in edges:
        bidirectional_edges.append([i, j])
        bidirectional_edges.append([j, i])
    
    edge_index = torch.tensor(bidirectional_edges, dtype=torch.long).t().contiguous()
    return edge_index


def extract_features_from_video(video_path: str) -> Tuple[np.ndarray, np.ndarray, Tuple[int, int, int]]:
    """
    Extract pose features from a video file (optimized version).
    
    Returns:
        features: Array of shape (T, N, 3) - x, y, confidence
        availability_mask: Array of shape (T, N) - 1 if keypoint available, 0 if missing
        counts: Tuple of (n_pose, n_face, n_hand)
    """
    mp_holistic = mp.solutions.holistic
    mp_face_mesh = mp.solutions.face_mesh
    
    # Define landmark indices
    POSE_INDICES = list(range(11, 23))  # Upper body
    
    face_connections = (
        mp_face_mesh.FACEMESH_LIPS | 
        mp_face_mesh.FACEMESH_LEFT_EYEBROW | 
        mp_face_mesh.FACEMESH_RIGHT_EYEBROW
    )
    FACE_INDICES = sorted(set([idx for pair in face_connections for idx in pair]))
    HAND_INDICES = list(range(21))
    
    N_POSE = len(POSE_INDICES)
    N_FACE = len(FACE_INDICES)
    N_HAND = len(HAND_INDICES)
    TOTAL_N = N_POSE + N_FACE + (N_HAND * 2)
    
    # Load video - optimized to read frame count first
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        raise ValueError(f"Cannot open video: {video_path}")
    
    # Get frame count and pre-allocate
    frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    
    # Pre-allocate arrays for all frames
    features = np.zeros((frame_count, TOTAL_N, 3), dtype=np.float32)
    availability_mask = np.zeros((frame_count, TOTAL_N), dtype=np.float32)
    
    # Process frames with holistic model
    holistic = mp_holistic.Holistic(
        static_image_mode=False,
        model_complexity=1,
        min_detection_confidence=0.5,
        min_tracking_confidence=0.5,
        enable_segmentation=False,  # Disable segmentation for speed
        refine_face_landmarks=False  # Disable face refinement for speed
    )
    
    t = 0
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        
        # Convert to RGB once
        rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        results = holistic.process(rgb_frame)
        
        idx = 0
        
        # Pose - vectorized approach
        if results.pose_landmarks:
            for pose_idx in POSE_INDICES:
                lm = results.pose_landmarks.landmark[pose_idx]
                features[t, idx] = [lm.x, lm.y, lm.visibility]
                availability_mask[t, idx] = 1.0 if lm.visibility > 0.5 else 0.0
                idx += 1
        else:
            idx += N_POSE
        
        # Face
        if results.face_landmarks:
            for face_idx in FACE_INDICES:
                lm = results.face_landmarks.landmark[face_idx]
                features[t, idx] = [lm.x, lm.y, 1.0]
                availability_mask[t, idx] = 1.0
                idx += 1
        else:
            idx += N_FACE
        
        # Left hand
        if results.left_hand_landmarks:
            for hand_idx in HAND_INDICES:
                lm = results.left_hand_landmarks.landmark[hand_idx]
                features[t, idx] = [lm.x, lm.y, 1.0]
                availability_mask[t, idx] = 1.0
                idx += 1
        else:
            idx += N_HAND
        
        # Right hand
        if results.right_hand_landmarks:
            for hand_idx in HAND_INDICES:
                lm = results.right_hand_landmarks.landmark[hand_idx]
                features[t, idx] = [lm.x, lm.y, 1.0]
                availability_mask[t, idx] = 1.0
                idx += 1
        else:
            idx += N_HAND
        
        t += 1
    
    cap.release()
    holistic.close()
    
    # Trim arrays if actual frame count differs
    if t < frame_count:
        features = features[:t]
        availability_mask = availability_mask[:t]
    
    return features, availability_mask, (N_POSE, N_FACE, N_HAND)


class SignLanguageDataset(Dataset):
    """
    PyTorch Geometric Dataset for Sign Language Recognition.
    
    Each sample is a temporal graph with:
    - x: Node features of shape (T, N, 3) where T=frames, N=keypoints, 3=(x,y,confidence)
    - edge_index: Graph connectivity of shape (2, num_edges)
    - y: Label (sign class)
    
    Directory structure:
        root/
            raw/
                video1.mp4
                video2.mp4
                ...
                labels.json  # {"video1.mp4": 0, "video2.mp4": 1, ...}
            processed/
                data_0.pt
                data_1.pt
                ...
    """
    
    def __init__(self, root, transform=None, pre_transform=None, pre_filter=None):
        super().__init__(root, transform, pre_transform, pre_filter)
        
    @property
    def raw_file_names(self) -> List[str]:
        """Return list of raw video files."""
        raw_dir = self.raw_dir
        if not osp.exists(raw_dir):
            return []
        
        video_files = [f for f in os.listdir(raw_dir) 
                      if f.endswith(('.mp4', '.avi', '.mov'))]
        return sorted(video_files)
    
    @property
    def processed_file_names(self) -> List[str]:
        """Return list of processed .pt files."""
        return [f'data_{i}.pt' for i in range(len(self.raw_file_names))]
    
    def download(self):
        """Download dataset (implement if needed)."""
        pass
    
    def process(self):
        """Process raw videos into graph data."""
        # Load labels
        labels_path = osp.join(self.raw_dir, 'labels.json')
        if osp.exists(labels_path):
            with open(labels_path, 'r') as f:
                labels = json.load(f)
        else:
            print("Warning: labels.json not found, using default label 0")
            labels = {}
        
        idx = 0
        for raw_path in self.raw_paths:
            if raw_path.endswith('labels.json'):
                continue
            
            print(f"Processing {osp.basename(raw_path)}...")
            
            try:
                # Extract features
                features, availability_mask, (n_pose, n_face, n_hand) = extract_features_from_video(raw_path)
                
                # Create initial edge index (all possible edges)
                edge_index = create_edge_index(n_pose, n_face, n_hand)
                
                # Convert to tensors
                x = torch.from_numpy(features).float()  # Shape: (T, N, 3)
                availability_mask_tensor = torch.from_numpy(availability_mask).float()  # (T, N)
                
                # Compute edge features and filter unavailable edges
                filtered_edge_index, edge_attr = compute_edge_features(
                    x, edge_index, availability_mask_tensor
                )
                
                # Get label
                video_name = osp.basename(raw_path)
                y = torch.tensor([labels.get(video_name, 0)], dtype=torch.long)
                
                # Create Data object
                data = Data(
                    x=x,  # (T, N, 3) - node features: x, y, confidence
                    edge_index=filtered_edge_index,  # (2, num_valid_edges)
                    edge_attr=edge_attr,  # (num_valid_edges, T) - distance features
                    availability_mask=availability_mask_tensor,  # (T, N) - 1 if available, 0 if missing
                    y=y,  # (1,)
                    num_frames=x.size(0),
                    num_nodes=x.size(1)
                )
                
                # Apply filters and transforms
                if self.pre_filter is not None and not self.pre_filter(data):
                    continue
                
                if self.pre_transform is not None:
                    data = self.pre_transform(data)
                
                # Save with optimized settings
                torch.save(
                    data, 
                    osp.join(self.processed_dir, f'data_{idx}.pt'),
                    _use_new_zipfile_serialization=True  # Use optimized serialization
                )
                idx += 1
                
            except Exception as e:
                print(f"Error processing {raw_path}: {e}")
                continue
    
    def len(self) -> int:
        """Return number of samples."""
        return len(self.processed_file_names)
    
    def get(self, idx: int) -> Data:
        """Load and return a single graph."""
        data = torch.load(
            osp.join(self.processed_dir, f'data_{idx}.pt'),
            weights_only=False  # Required for PyG Data objects
        )
        return data

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, global_mean_pool, global_add_pool
from torch_geometric.data import Data, Batch
import math


class MotionTopologyEnhancement(nn.Module):
    """
    Motion Topology Enhancement (MTE) module to capture rich motion representations.
    Learns dynamic graph topology based on motion features.
    """
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        
        # Learnable matrices for topology modeling
        self.theta = nn.Linear(in_channels, out_channels)
        self.phi = nn.Linear(in_channels, out_channels)
        
        # Graph convolution
        self.gcn = GCNConv(in_channels, out_channels)
        
    def forward(self, x, edge_index):
        """
        Args:
            x: Node features (num_nodes, in_channels)
            edge_index: Graph connectivity (2, num_edges)
        """
        # Compute attention-based topology
        theta_x = self.theta(x)  # (num_nodes, out_channels)
        phi_x = self.phi(x)      # (num_nodes, out_channels)
        
        # Apply GCN with original topology
        out = self.gcn(x, edge_index)
        
        return out


class PrototypeReconstructionNetwork(nn.Module):
    """
    Prototype Reconstruction Network (PRN) that decomposes features
    into learnable prototypes representing motion patterns.
    """
    def __init__(self, feature_dim, num_prototypes=128, prototype_dim=256):
        super().__init__()
        self.num_prototypes = num_prototypes
        self.prototype_dim = prototype_dim
        
        # Learnable prototype memory
        self.prototypes = nn.Parameter(torch.randn(num_prototypes, prototype_dim))
        nn.init.xavier_uniform_(self.prototypes)
        
        # Feature projection
        self.feature_proj = nn.Sequential(
            nn.Linear(feature_dim, prototype_dim),
            nn.ReLU(inplace=True)
        )
        
        # Reconstruction projection
        self.recon_proj = nn.Linear(prototype_dim, feature_dim)
        
    def forward(self, x):
        """
        Args:
            x: Input features (batch_size, feature_dim)
        Returns:
            reconstructed: Reconstructed features
            proto_scores: Prototype assignment scores
        """
        # Handle single sample case
        if x.dim() == 1:
            x = x.unsqueeze(0)
        
        batch_size = x.size(0)
        
        # Project features to prototype space
        x_proj = self.feature_proj(x)  # (batch_size, prototype_dim)
        
        # Compute similarity to prototypes (cosine similarity)
        x_norm = F.normalize(x_proj, p=2, dim=1)
        proto_norm = F.normalize(self.prototypes, p=2, dim=1)
        
        # Similarity matrix: (batch_size, num_prototypes)
        proto_scores = torch.matmul(x_norm, proto_norm.t())
        
        # Soft assignment using softmax
        proto_weights = F.softmax(proto_scores / 0.1, dim=1)  # Temperature = 0.1
        
        # Reconstruct features as weighted combination of prototypes
        reconstructed_proj = torch.matmul(proto_weights, self.prototypes)
        
        # Project back to original feature space
        reconstructed = self.recon_proj(reconstructed_proj)
        
        return reconstructed, proto_scores


class TemporalConvNet(nn.Module):
    """
    Temporal convolution to aggregate features across time.
    """
    def __init__(self, in_channels, out_channels, kernel_size=9):
        super().__init__()
        padding = (kernel_size - 1) // 2
        self.conv = nn.Sequential(
            nn.Conv1d(in_channels, out_channels, kernel_size, padding=padding),
            nn.BatchNorm1d(out_channels),
            nn.ReLU(inplace=True)
        )
        
    def forward(self, x):
        """
        Args:
            x: (batch_size, channels, time_steps)
        """
        return self.conv(x)


class ProtoGCN(nn.Module):
    """
    ProtoGCN: Prototype-based Graph Convolutional Network for Sign Language Recognition.
    
    The model consists of:
    1. Motion Topology Enhancement (MTE) - learns dynamic graph topology
    2. Spatial-Temporal GCN layers - extract spatio-temporal features
    3. Prototype Reconstruction Network (PRN) - decomposes into motion prototypes
    4. Classification head with contrastive learning support
    """
    def __init__(
        self,
        num_nodes,
        in_channels=3,  # x, y, confidence
        hidden_channels=64,
        num_classes=100,
        num_gcn_layers=4,
        num_prototypes=128,
        dropout=0.5,
        temporal_kernel_size=9
    ):
        super().__init__()
        
        self.num_nodes = num_nodes
        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        self.num_classes = num_classes
        self.num_gcn_layers = num_gcn_layers
        
        # Input projection (process each node independently)
        self.input_proj = nn.Linear(in_channels, hidden_channels)
        
        # Spatial GCN layers with MTE
        self.gcn_layers = nn.ModuleList()
        self.mte_layers = nn.ModuleList()
        self.temporal_convs = nn.ModuleList()
        self.batch_norms = nn.ModuleList()
        
        for i in range(num_gcn_layers):
            self.gcn_layers.append(GCNConv(hidden_channels, hidden_channels))
            self.mte_layers.append(MotionTopologyEnhancement(hidden_channels, hidden_channels))
            self.temporal_convs.append(TemporalConvNet(hidden_channels, hidden_channels, temporal_kernel_size))
            self.batch_norms.append(nn.BatchNorm1d(hidden_channels))
        
        self.dropout = nn.Dropout(dropout)
        
        # Prototype Reconstruction Network
        self.prn = PrototypeReconstructionNetwork(
            feature_dim=hidden_channels,
            num_prototypes=num_prototypes,
            prototype_dim=256
        )
        
        # Classification head
        self.fc = nn.Sequential(
            nn.Linear(hidden_channels, hidden_channels // 2),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(hidden_channels // 2, num_classes)
        )
        
        # Projection head for contrastive learning
        self.projector = nn.Sequential(
            nn.Linear(hidden_channels, 256),
            nn.ReLU(inplace=True),
            nn.Linear(256, 256)
        )
        
    def forward(self, data, return_embedding=False):
        """
        Args:
            data: PyG Data object with:
                - x: (T, N, 3) node features for single sample
                - edge_index: (2, num_edges) graph connectivity
                - num_frames: T
                - num_nodes: N
        Returns:
            logits: Class predictions
            embedding: Feature embedding (if return_embedding=True)
            proto_scores: Prototype scores (for contrastive learning)
        """
        x = data.x  # (T, N, 3)
        edge_index = data.edge_index  # (2, num_edges)
        
        # Get dimensions
        if x.dim() == 3:
            T, N, C = x.shape
        else:
            # Batched case would be handled differently
            raise NotImplementedError("Batched processing not yet implemented")
        
        # Input projection: (T, N, 3) -> (T, N, hidden_channels)
        x = self.input_proj(x)  # (T, N, hidden_channels)
        
        # Process through GCN layers with temporal convolution
        for i in range(self.num_gcn_layers):
            # Spatial processing: apply GCN at each timestep
            # Reshape: (T, N, C) -> (T*N, C)
            x_flat = x.reshape(T * N, self.hidden_channels)
            
            # Expand edge_index for all timesteps
            # Create edge_index for each timestep: add offset for each frame
            edge_index_expanded = []
            for t in range(T):
                edge_index_t = edge_index + (t * N)
                edge_index_expanded.append(edge_index_t)
            edge_index_temporal = torch.cat(edge_index_expanded, dim=1)  # (2, T*num_edges)
            
            # Apply GCN and MTE
            x_gcn = self.gcn_layers[i](x_flat, edge_index_temporal)
            x_mte = self.mte_layers[i](x_flat, edge_index_temporal)
            
            # Combine and normalize
            x_spatial = x_gcn + x_mte  # (T*N, hidden_channels)
            x_spatial = self.batch_norms[i](x_spatial)
            x_spatial = F.relu(x_spatial)
            x_spatial = self.dropout(x_spatial)
            
            # Reshape back: (T*N, C) -> (T, N, C)
            x = x_spatial.reshape(T, N, self.hidden_channels)
            
            # Temporal convolution: (T, N, C) -> (N, C, T)
            x = x.permute(1, 2, 0).contiguous()  # (N, hidden_channels, T)
            x = self.temporal_convs[i](x)  # (N, hidden_channels, T)
            
            # Reshape back: (N, C, T) -> (T, N, C)
            x = x.permute(2, 0, 1).contiguous()  # (T, N, hidden_channels)
        
        # Global pooling: average over time and nodes
        x = x.mean(dim=[0, 1])  # (hidden_channels,)
        
        # Prototype reconstruction
        x_recon, proto_scores = self.prn(x)
        
        # Handle batch dimension
        if x_recon.dim() == 2:
            x_recon = x_recon.squeeze(0)
        
        # Classification
        logits = self.fc(x_recon.unsqueeze(0)).squeeze(0)
        
        # Projection for contrastive learning
        embedding = self.projector(x_recon.unsqueeze(0)).squeeze(0)
        
        if return_embedding:
            return logits, embedding, proto_scores
        return logits


class ClassSpecificContrastiveLoss(nn.Module):
    """
    Class-Specific Contrastive (CSC) Loss to enhance inter-class distinction.
    """
    def __init__(self, temperature=0.1):
        super().__init__()
        self.temperature = temperature
        
    def forward(self, embeddings, labels):
        """
        Args:
            embeddings: (batch_size, embedding_dim)
            labels: (batch_size,)
        """
        # Handle single sample
        if embeddings.dim() == 1:
            embeddings = embeddings.unsqueeze(0)
        
        # Normalize embeddings
        embeddings = F.normalize(embeddings, p=2, dim=1)
        
        # Compute similarity matrix
        sim_matrix = torch.matmul(embeddings, embeddings.t()) / self.temperature
        
        # Create positive and negative masks
        labels = labels.view(-1, 1)
        pos_mask = (labels == labels.t()).float()
        neg_mask = (labels != labels.t()).float()
        
        # Remove diagonal
        pos_mask = pos_mask - torch.eye(pos_mask.size(0), device=pos_mask.device)
        
        # Compute loss
        exp_sim = torch.exp(sim_matrix)
        pos_sim = (exp_sim * pos_mask).sum(dim=1)
        neg_sim = (exp_sim * neg_mask).sum(dim=1)
        
        loss = -torch.log(pos_sim / (pos_sim + neg_sim + 1e-8))
        return loss.mean()

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam, AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR, ReduceLROnPlateau
from torch_geometric.loader import DataLoader
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
import json
from tqdm import tqdm
import time
from typing import Dict, Tuple, Optional, List
import random
from sklearn.metrics import (
    precision_score, recall_score, f1_score, 
    confusion_matrix, classification_report,
    top_k_accuracy_score
)
import seaborn as sns


class DataAugmentation:
    """
    On-the-fly data augmentation for temporal graph data.
    """
    def __init__(
        self,
        drop_frame_prob=0.2,
        drop_node_prob=0.1,
        noise_std=0.01,
        temporal_crop_ratio=0.8
    ):
        self.drop_frame_prob = drop_frame_prob
        self.drop_node_prob = drop_node_prob
        self.noise_std = noise_std
        self.temporal_crop_ratio = temporal_crop_ratio
    
    def __call__(self, data, training=True):
        """Apply augmentation to data."""
        if not training:
            return data
        
        x = data.x.clone()  # (T, N, 3)
        T, N, C = x.shape
        
        # 1. Random frame dropping (temporal augmentation)
        if random.random() < self.drop_frame_prob and T > 10:
            num_frames_to_keep = max(10, int(T * self.temporal_crop_ratio))
            frame_indices = sorted(random.sample(range(T), num_frames_to_keep))
            x = x[frame_indices]
            data.num_frames = len(frame_indices)
        
        # 2. Random node dropping (spatial augmentation)
        if random.random() < self.drop_node_prob:
            num_nodes_to_drop = max(1, int(N * 0.1))
            nodes_to_drop = random.sample(range(N), num_nodes_to_drop)
            
            # Set dropped nodes to zero and update availability mask
            x[:, nodes_to_drop, :] = 0
            if hasattr(data, 'availability_mask'):
                availability_mask = data.availability_mask.clone()
                availability_mask[:, nodes_to_drop] = 0
                data.availability_mask = availability_mask
        
        # 3. Add Gaussian noise to node coordinates
        if self.noise_std > 0:
            noise = torch.randn_like(x[:, :, :2]) * self.noise_std
            x[:, :, :2] = x[:, :, :2] + noise
            # Clip to valid range [0, 1]
            x[:, :, :2] = torch.clamp(x[:, :, :2], 0.0, 1.0)
        
        # 4. Random horizontal flip
        if random.random() < 0.5:
            x[:, :, 0] = 1.0 - x[:, :, 0]  # Flip x-coordinates
        
        data.x = x
        return data


class EarlyStopping:
    """Early stopping to stop training when validation loss doesn't improve."""
    def __init__(self, patience=10, min_delta=0.001, mode='min'):
        self.patience = patience
        self.min_delta = min_delta
        self.mode = mode
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.best_epoch = 0
        
        if mode == 'min':
            self.monitor_op = np.less
            self.min_delta *= -1
        else:
            self.monitor_op = np.greater
            self.min_delta *= 1
    
    def __call__(self, score, epoch):
        if self.best_score is None:
            self.best_score = score
            self.best_epoch = epoch
        elif self.monitor_op(score, self.best_score + self.min_delta):
            self.best_score = score
            self.best_epoch = epoch
            self.counter = 0
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
        
        return self.early_stop


class MetricsTracker:
    """Track and store training metrics including detailed classification metrics."""
    def __init__(self):
        self.metrics = {
            'train_loss': [],
            'train_acc': [],
            'train_f1': [],
            'train_precision': [],
            'train_recall': [],
            'val_loss': [],
            'val_acc': [],
            'val_f1': [],
            'val_precision': [],
            'val_recall': [],
            'val_top5_acc': [],
            'learning_rates': []
        }
        self.best_confusion_matrix = None
        self.best_classification_report = None
    
    def update(self, **kwargs):
        for key, value in kwargs.items():
            if key in self.metrics:
                self.metrics[key].append(value)
    
    def save(self, path):
        # Save numerical metrics
        metrics_to_save = {k: v for k, v in self.metrics.items() 
                          if k not in ['best_confusion_matrix', 'best_classification_report']}
        with open(path, 'w') as f:
            json.dump(metrics_to_save, f, indent=2)
    
    def plot(self, save_dir):
        """Plot comprehensive training curves."""
        save_dir = Path(save_dir)
        
        # Create main training curves plot
        fig, axes = plt.subplots(3, 2, figsize=(15, 15))
        
        # Loss curves
        axes[0, 0].plot(self.metrics['train_loss'], label='Train Loss', linewidth=2, color='#3498db')
        axes[0, 0].plot(self.metrics['val_loss'], label='Val Loss', linewidth=2, color='#e74c3c')
        axes[0, 0].set_xlabel('Epoch', fontsize=12)
        axes[0, 0].set_ylabel('Loss', fontsize=12)
        axes[0, 0].set_title('Loss Curves', fontsize=14, fontweight='bold')
        axes[0, 0].legend()
        axes[0, 0].grid(True, alpha=0.3)
        
        # Accuracy curves
        axes[0, 1].plot(self.metrics['train_acc'], label='Train Acc', linewidth=2, color='#3498db')
        axes[0, 1].plot(self.metrics['val_acc'], label='Val Acc', linewidth=2, color='#e74c3c')
        if len(self.metrics['val_top5_acc']) > 0:
            axes[0, 1].plot(self.metrics['val_top5_acc'], label='Val Top-5 Acc', 
                           linewidth=2, color='#2ecc71', linestyle='--')
        axes[0, 1].set_xlabel('Epoch', fontsize=12)
        axes[0, 1].set_ylabel('Accuracy (%)', fontsize=12)
        axes[0, 1].set_title('Accuracy Curves', fontsize=14, fontweight='bold')
        axes[0, 1].legend()
        axes[0, 1].grid(True, alpha=0.3)
        
        # F1 Score curves
        axes[1, 0].plot(self.metrics['train_f1'], label='Train F1', linewidth=2, color='#3498db')
        axes[1, 0].plot(self.metrics['val_f1'], label='Val F1', linewidth=2, color='#e74c3c')
        axes[1, 0].set_xlabel('Epoch', fontsize=12)
        axes[1, 0].set_ylabel('F1 Score', fontsize=12)
        axes[1, 0].set_title('F1 Score Curves', fontsize=14, fontweight='bold')
        axes[1, 0].legend()
        axes[1, 0].grid(True, alpha=0.3)
        
        # Precision curves
        axes[1, 1].plot(self.metrics['train_precision'], label='Train Precision', linewidth=2, color='#3498db')
        axes[1, 1].plot(self.metrics['val_precision'], label='Val Precision', linewidth=2, color='#e74c3c')
        axes[1, 1].set_xlabel('Epoch', fontsize=12)
        axes[1, 1].set_ylabel('Precision', fontsize=12)
        axes[1, 1].set_title('Precision Curves', fontsize=14, fontweight='bold')
        axes[1, 1].legend()
        axes[1, 1].grid(True, alpha=0.3)
        
        # Recall curves
        axes[2, 0].plot(self.metrics['train_recall'], label='Train Recall', linewidth=2, color='#3498db')
        axes[2, 0].plot(self.metrics['val_recall'], label='Val Recall', linewidth=2, color='#e74c3c')
        axes[2, 0].set_xlabel('Epoch', fontsize=12)
        axes[2, 0].set_ylabel('Recall', fontsize=12)
        axes[2, 0].set_title('Recall Curves', fontsize=14, fontweight='bold')
        axes[2, 0].legend()
        axes[2, 0].grid(True, alpha=0.3)
        
        # Learning rate
        axes[2, 1].plot(self.metrics['learning_rates'], linewidth=2, color='#9b59b6')
        axes[2, 1].set_xlabel('Epoch', fontsize=12)
        axes[2, 1].set_ylabel('Learning Rate', fontsize=12)
        axes[2, 1].set_title('Learning Rate Schedule', fontsize=14, fontweight='bold')
        axes[2, 1].set_yscale('log')
        axes[2, 1].grid(True, alpha=0.3)
        
        plt.tight_layout()
        plt.savefig(save_dir / 'training_curves.png', dpi=300, bbox_inches='tight')
        plt.close()
        
        # Create summary statistics plot
        self._plot_summary(save_dir)
        
        # Plot confusion matrix if available
        if self.best_confusion_matrix is not None:
            self._plot_confusion_matrix(save_dir)
        
        print(f"Training curves saved to {save_dir / 'training_curves.png'}")
    
    def _plot_summary(self, save_dir):
        """Plot summary statistics."""
        fig, ax = plt.subplots(figsize=(10, 8))
        
        best_val_acc = max(self.metrics['val_acc'])
        best_epoch = self.metrics['val_acc'].index(best_val_acc)
        best_val_f1 = self.metrics['val_f1'][best_epoch] if len(self.metrics['val_f1']) > best_epoch else 0
        best_val_precision = self.metrics['val_precision'][best_epoch] if len(self.metrics['val_precision']) > best_epoch else 0
        best_val_recall = self.metrics['val_recall'][best_epoch] if len(self.metrics['val_recall']) > best_epoch else 0
        
        summary_text = f"""
        TRAINING SUMMARY
        {'='*50}
        
        Best Validation Metrics (Epoch {best_epoch + 1}):
          • Accuracy:  {best_val_acc:.2f}%
          • F1 Score:  {best_val_f1:.4f}
          • Precision: {best_val_precision:.4f}
          • Recall:    {best_val_recall:.4f}
        
        {'='*50}
        
        Final Metrics:
          • Train Loss: {self.metrics['train_loss'][-1]:.4f}
          • Val Loss:   {self.metrics['val_loss'][-1]:.4f}
          • Train Acc:  {self.metrics['train_acc'][-1]:.2f}%
          • Val Acc:    {self.metrics['val_acc'][-1]:.2f}%
          • Val F1:     {self.metrics['val_f1'][-1]:.4f}
        
        {'='*50}
        
        Total Epochs: {len(self.metrics['train_loss'])}
        """
        
        ax.text(0.1, 0.5, summary_text, fontsize=11, verticalalignment='center',
               fontfamily='monospace',
               bbox=dict(boxstyle='round', facecolor='lightblue', alpha=0.3))
        ax.axis('off')
        
        plt.tight_layout()
        plt.savefig(save_dir / 'training_summary.png', dpi=300, bbox_inches='tight')
        plt.close()
    
    def _plot_confusion_matrix(self, save_dir):
        """Plot confusion matrix."""
        if self.best_confusion_matrix is None:
            return
        
        cm = self.best_confusion_matrix
        
        # If too many classes, show a smaller subset or simplified view
        if cm.shape[0] > 20:
            # Show only top 20 most confused classes
            fig, ax = plt.subplots(figsize=(12, 10))
            sns.heatmap(cm[:20, :20], annot=False, fmt='d', cmap='Blues', ax=ax)
            ax.set_title('Confusion Matrix (Top 20 Classes)', fontsize=14, fontweight='bold')
            ax.set_xlabel('Predicted Label', fontsize=12)
            ax.set_ylabel('True Label', fontsize=12)
        else:
            fig, ax = plt.subplots(figsize=(12, 10))
            sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', ax=ax)
            ax.set_title('Confusion Matrix', fontsize=14, fontweight='bold')
            ax.set_xlabel('Predicted Label', fontsize=12)
            ax.set_ylabel('True Label', fontsize=12)
        
        plt.tight_layout()
        plt.savefig(save_dir / 'confusion_matrix.png', dpi=300, bbox_inches='tight')
        plt.close()
        print(f"Confusion matrix saved to {save_dir / 'confusion_matrix.png'}")
    
    def save_classification_report(self, save_dir):
        """Save detailed classification report."""
        if self.best_classification_report is not None:
            report_path = save_dir / 'classification_report.txt'
            with open(report_path, 'w') as f:
                f.write("BEST MODEL CLASSIFICATION REPORT\n")
                f.write("="*70 + "\n\n")
                f.write(self.best_classification_report)
            print(f"Classification report saved to {report_path}")


def compute_metrics(predictions: List[int], targets: List[int], 
                   probabilities: Optional[np.ndarray] = None) -> Dict[str, float]:
    """
    Compute comprehensive classification metrics.
    
    Args:
        predictions: List of predicted class indices
        targets: List of true class indices
        probabilities: Optional array of class probabilities for top-k accuracy
        
    Returns:
        Dictionary of metrics
    """
    metrics = {}
    
    # Basic metrics
    metrics['accuracy'] = 100 * np.mean(np.array(predictions) == np.array(targets))
    
    # Precision, Recall, F1 (weighted average for multi-class)
    metrics['precision'] = precision_score(targets, predictions, average='weighted', zero_division=0)
    metrics['recall'] = recall_score(targets, predictions, average='weighted', zero_division=0)
    metrics['f1'] = f1_score(targets, predictions, average='weighted', zero_division=0)
    
    # Macro-averaged metrics (treats all classes equally)
    metrics['precision_macro'] = precision_score(targets, predictions, average='macro', zero_division=0)
    metrics['recall_macro'] = recall_score(targets, predictions, average='macro', zero_division=0)
    metrics['f1_macro'] = f1_score(targets, predictions, average='macro', zero_division=0)
    
    # Top-5 accuracy if probabilities provided
    if probabilities is not None and probabilities.shape[1] >= 5:
        metrics['top5_accuracy'] = 100 * top_k_accuracy_score(
            targets, probabilities, k=5, labels=list(range(probabilities.shape[1]))
        )
    
    return metrics


def train_epoch(model, loader, optimizer, criterion, contrastive_criterion, 
                augmentation, device, lambda_contrast=0.1, lambda_recon=0.1):
    """Train for one epoch with detailed metrics."""
    model.train()
    total_loss = 0
    all_predictions = []
    all_targets = []
    all_probs = []
    
    pbar = tqdm(loader, desc='Training')
    for data in pbar:
        # Apply augmentation
        data = augmentation(data, training=True)
        data = data.to(device)
        
        optimizer.zero_grad()
        
        # Forward pass
        logits, embeddings, proto_scores = model(data, return_embedding=True)
        
        # Classification loss
        if len(logits.shape) == 1:
            logits = logits.unsqueeze(0)
        loss_cls = criterion(logits, data.y)
        
        # Contrastive loss (if batch size > 1)
        loss_contrast = torch.tensor(0.0, device=device)
        if embeddings.dim() == 1:
            embeddings = embeddings.unsqueeze(0)
        if embeddings.size(0) > 1 and data.y.size(0) > 1:
            loss_contrast = contrastive_criterion(embeddings, data.y)
        
        # Reconstruction loss (prototype diversity)
        loss_recon = torch.tensor(0.0, device=device)
        if proto_scores.size(0) > 0:
            proto_entropy = -(proto_scores.softmax(dim=1) * proto_scores.log_softmax(dim=1)).sum(dim=1).mean()
            loss_recon = -proto_entropy
        
        # Total loss
        loss = loss_cls + lambda_contrast * loss_contrast + lambda_recon * loss_recon
        
        # Backward pass
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        
        # Collect predictions and targets
        total_loss += loss.item()
        pred = logits.argmax(dim=1)
        probs = F.softmax(logits, dim=1)
        
        all_predictions.extend(pred.cpu().numpy())
        all_targets.extend(data.y.cpu().numpy())
        all_probs.append(probs.detach().cpu().numpy())
        
        # Update progress bar
        current_acc = 100 * np.mean(np.array(all_predictions) == np.array(all_targets))
        pbar.set_postfix({'loss': f'{loss.item():.4f}', 'acc': f'{current_acc:.2f}%'})
    
    # Compute comprehensive metrics
    avg_loss = total_loss / len(loader)
    all_probs = np.vstack(all_probs)
    metrics = compute_metrics(all_predictions, all_targets, all_probs)
    
    return avg_loss, metrics


@torch.no_grad()
def evaluate(model, loader, criterion, device, return_detailed=False):
    """
    Evaluate model on validation/test set with comprehensive metrics.
    
    Args:
        model: Model to evaluate
        loader: Data loader
        criterion: Loss function
        device: Device to use
        return_detailed: If True, return confusion matrix and classification report
        
    Returns:
        avg_loss: Average loss
        metrics: Dictionary of metrics
        confusion_mat: Confusion matrix (if return_detailed=True)
        class_report: Classification report (if return_detailed=True)
    """
    model.eval()
    total_loss = 0
    all_predictions = []
    all_targets = []
    all_probs = []
    
    pbar = tqdm(loader, desc='Evaluating')
    for data in pbar:
        data = data.to(device)
        
        # Forward pass
        logits = model(data, return_embedding=False)
        
        # Loss
        if len(logits.shape) == 1:
            logits = logits.unsqueeze(0)
        loss = criterion(logits, data.y)
        
        # Collect predictions and targets
        total_loss += loss.item()
        pred = logits.argmax(dim=1)
        probs = F.softmax(logits, dim=1)
        
        all_predictions.extend(pred.cpu().numpy())
        all_targets.extend(data.y.cpu().numpy())
        all_probs.append(probs.cpu().numpy())
        
        # Update progress bar
        current_acc = 100 * np.mean(np.array(all_predictions) == np.array(all_targets))
        pbar.set_postfix({'loss': f'{loss.item():.4f}', 'acc': f'{current_acc:.2f}%'})
    
    # Compute comprehensive metrics
    avg_loss = total_loss / len(loader)
    all_probs = np.vstack(all_probs)
    metrics = compute_metrics(all_predictions, all_targets, all_probs)
    
    if return_detailed:
        # Confusion matrix
        confusion_mat = confusion_matrix(all_targets, all_predictions)
        
        # Classification report
        class_report = classification_report(
            all_targets, all_predictions,
            target_names=[f'Class_{i}' for i in range(max(all_targets) + 1)],
            zero_division=0
        )
        
        return avg_loss, metrics, confusion_mat, class_report
    
    return avg_loss, metrics


def train_model(
    model,
    train_loader,
    val_loader,
    num_epochs=100,
    learning_rate=0.001,
    weight_decay=1e-4,
    patience=15,
    save_dir='checkpoints',
    device='cuda',
    lambda_contrast=0.1,
    lambda_recon=0.1
):
    """
    Complete training pipeline with early stopping and learning rate scheduling.
    
    Args:
        model: ProtoGCN model
        train_loader: Training data loader
        val_loader: Validation data loader
        num_epochs: Maximum number of epochs
        learning_rate: Initial learning rate
        weight_decay: L2 regularization
        patience: Early stopping patience
        save_dir: Directory to save checkpoints
        device: Device to train on
        lambda_contrast: Weight for contrastive loss
        lambda_recon: Weight for reconstruction loss
    """
    save_dir = Path(save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)
    
    # Move model to device
    model = model.to(device)
    
    # Optimizer and scheduler
    optimizer = AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    scheduler = CosineAnnealingLR(optimizer, T_max=num_epochs, eta_min=1e-6)
    
    # Loss functions
    criterion = nn.CrossEntropyLoss()
    contrastive_criterion = ClassSpecificContrastiveLoss(temperature=0.1)
    
    # Data augmentation
    augmentation = DataAugmentation(
        drop_frame_prob=0.2,
        drop_node_prob=0.1,
        noise_std=0.01,
        temporal_crop_ratio=0.8
    )
    
    # Early stopping and metrics
    early_stopping = EarlyStopping(patience=patience, mode='max')  # Monitor validation accuracy
    metrics_tracker = MetricsTracker()
    
    best_val_acc = 0.0
    
    print(f"\n{'='*60}")
    print(f"Starting training on {device}")
    print(f"{'='*60}\n")
    
    start_time = time.time()
    
    for epoch in range(num_epochs):
        print(f"\nEpoch {epoch + 1}/{num_epochs}")
        print(f"{'-'*60}")
        
        # Train
        train_loss, train_metrics = train_epoch(
            model, train_loader, optimizer, criterion, contrastive_criterion,
            augmentation, device, lambda_contrast, lambda_recon
        )
        
        # Validate
        val_loss, val_metrics = evaluate(model, val_loader, criterion, device)
        
        # Update learning rate
        scheduler.step()
        current_lr = optimizer.param_groups[0]['lr']
        
        # Track metrics
        metrics_tracker.update(
            train_loss=train_loss,
            train_acc=train_metrics['accuracy'],
            train_f1=train_metrics['f1'],
            train_precision=train_metrics['precision'],
            train_recall=train_metrics['recall'],
            val_loss=val_loss,
            val_acc=val_metrics['accuracy'],
            val_f1=val_metrics['f1'],
            val_precision=val_metrics['precision'],
            val_recall=val_metrics['recall'],
            val_top5_acc=val_metrics.get('top5_accuracy', 0),
            learning_rates=current_lr
        )
        
        # Print summary
        print(f"\nEpoch {epoch + 1} Summary:")
        print(f"  Train - Loss: {train_loss:.4f} | Acc: {train_metrics['accuracy']:.2f}% | "
              f"F1: {train_metrics['f1']:.4f} | Precision: {train_metrics['precision']:.4f} | "
              f"Recall: {train_metrics['recall']:.4f}")
        print(f"  Val   - Loss: {val_loss:.4f} | Acc: {val_metrics['accuracy']:.2f}% | "
              f"F1: {val_metrics['f1']:.4f} | Precision: {val_metrics['precision']:.4f} | "
              f"Recall: {val_metrics['recall']:.4f}")
        if 'top5_accuracy' in val_metrics:
            print(f"  Val Top-5 Accuracy: {val_metrics['top5_accuracy']:.2f}%")
        print(f"  Learning Rate: {current_lr:.6f}")
        
        # Save best model and get detailed metrics
        if val_metrics['accuracy'] > best_val_acc:
            best_val_acc = val_metrics['accuracy']
            
            # Get detailed metrics for best model
            _, _, confusion_mat, class_report = evaluate(
                model, val_loader, criterion, device, return_detailed=True
            )
            metrics_tracker.best_confusion_matrix = confusion_mat
            metrics_tracker.best_classification_report = class_report
            
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'val_acc': val_metrics['accuracy'],
                'val_loss': val_loss,
                'val_f1': val_metrics['f1'],
                'val_precision': val_metrics['precision'],
                'val_recall': val_metrics['recall'],
            }, save_dir / 'best_model.pth')
            print(f"  ✓ New best model saved! (Val Acc: {val_metrics['accuracy']:.2f}%, "
                  f"F1: {val_metrics['f1']:.4f})")
        
        # Save checkpoint
        if (epoch + 1) % 10 == 0:
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
            }, save_dir / f'checkpoint_epoch_{epoch+1}.pth')
        
        # Early stopping
        if early_stopping(val_metrics['accuracy'], epoch):
            print(f"\n{'='*60}")
            print(f"Early stopping triggered at epoch {epoch + 1}")
            print(f"Best validation accuracy: {best_val_acc:.2f}% at epoch {early_stopping.best_epoch + 1}")
            print(f"{'='*60}\n")
            break
    
    # Training complete
    training_time = time.time() - start_time
    print(f"\n{'='*60}")
    print(f"Training completed in {training_time / 60:.2f} minutes")
    print(f"Best validation accuracy: {best_val_acc:.2f}%")
    print(f"{'='*60}\n")
    
    # Save final model
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
    }, save_dir / 'final_model.pth')
    
    # Save metrics and plots
    metrics_tracker.save(save_dir / 'metrics.json')
    metrics_tracker.plot(save_dir)
    metrics_tracker.save_classification_report(save_dir)
    
    return model, metrics_tracker


@torch.no_grad()
def inference(model, data, device='cuda', return_probs=False):
    """
    Perform inference on a single sample or batch.
    
    Args:
        model: Trained ProtoGCN model
        data: PyG Data object
        device: Device to run inference on
        return_probs: Whether to return class probabilities
        
    Returns:
        prediction: Predicted class index
        probabilities: Class probabilities (if return_probs=True)
        confidence: Prediction confidence
    """
    model.eval()
    data = data.to(device)
    
    # Forward pass
    logits = model(data, return_embedding=False)
    
    # Get probabilities
    if len(logits.shape) == 1:
        logits = logits.unsqueeze(0)
    probs = F.softmax(logits, dim=1)
    
    # Get prediction and confidence
    confidence, prediction = probs.max(dim=1)
    
    if return_probs:
        return prediction.item(), probs.squeeze().cpu().numpy(), confidence.item()
    return prediction.item(), confidence.item()


def load_model(model, checkpoint_path, device='cuda'):
    """Load model from checkpoint."""
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
    model.load_state_dict(checkpoint['model_state_dict'])
    model = model.to(device)
    print(f"Model loaded from {checkpoint_path}")
    if 'val_acc' in checkpoint:
        print(f"Checkpoint metrics:")
        print(f"  • Accuracy:  {checkpoint['val_acc']:.2f}%")
        if 'val_f1' in checkpoint:
            print(f"  • F1 Score:  {checkpoint['val_f1']:.4f}")
            print(f"  • Precision: {checkpoint['val_precision']:.4f}")
            print(f"  • Recall:    {checkpoint['val_recall']:.4f}")
    return model


@torch.no_grad()
def detailed_evaluation(model, loader, device='cuda', save_dir='results'):
    """
    Perform detailed evaluation with comprehensive metrics and visualizations.
    
    Args:
        model: Trained model
        loader: Data loader
        device: Device to use
        save_dir: Directory to save results
    """
    save_dir = Path(save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)
    
    model.eval()
    all_predictions = []
    all_targets = []
    all_probs = []
    
    print("\nPerforming detailed evaluation...")
    for data in tqdm(loader, desc='Evaluating'):
        data = data.to(device)
        logits = model(data, return_embedding=False)
        
        if len(logits.shape) == 1:
            logits = logits.unsqueeze(0)
        
        pred = logits.argmax(dim=1)
        probs = F.softmax(logits, dim=1)
        
        all_predictions.extend(pred.cpu().numpy())
        all_targets.extend(data.y.cpu().numpy())
        all_probs.append(probs.cpu().numpy())
    
    all_probs = np.vstack(all_probs)
    
    # Compute all metrics
    metrics = compute_metrics(all_predictions, all_targets, all_probs)
    
    # Print metrics
    print("\n" + "="*70)
    print("DETAILED EVALUATION RESULTS")
    print("="*70)
    print(f"\nAccuracy Metrics:")
    print(f"  • Top-1 Accuracy:  {metrics['accuracy']:.2f}%")
    if 'top5_accuracy' in metrics:
        print(f"  • Top-5 Accuracy:  {metrics['top5_accuracy']:.2f}%")
    
    print(f"\nWeighted Average Metrics:")
    print(f"  • Precision:       {metrics['precision']:.4f}")
    print(f"  • Recall:          {metrics['recall']:.4f}")
    print(f"  • F1 Score:        {metrics['f1']:.4f}")
    
    print(f"\nMacro Average Metrics:")
    print(f"  • Precision:       {metrics['precision_macro']:.4f}")
    print(f"  • Recall:          {metrics['recall_macro']:.4f}")
    print(f"  • F1 Score:        {metrics['f1_macro']:.4f}")
    print("="*70 + "\n")
    
    # Confusion matrix
    confusion_mat = confusion_matrix(all_targets, all_predictions)
    
    # Classification report
    num_classes = len(np.unique(all_targets))
    class_report = classification_report(
        all_targets, all_predictions,
        target_names=[f'Class_{i}' for i in range(num_classes)],
        zero_division=0
    )
    
    # Save results
    with open(save_dir / 'evaluation_metrics.json', 'w') as f:
        json.dump(metrics, f, indent=2)
    
    with open(save_dir / 'classification_report.txt', 'w') as f:
        f.write("EVALUATION CLASSIFICATION REPORT\n")
        f.write("="*70 + "\n\n")
        f.write(class_report)
    
    # Plot confusion matrix
    if num_classes <= 20:
        fig, ax = plt.subplots(figsize=(12, 10))
        sns.heatmap(confusion_mat, annot=True, fmt='d', cmap='Blues', ax=ax)
        ax.set_title('Confusion Matrix', fontsize=14, fontweight='bold')
        ax.set_xlabel('Predicted Label', fontsize=12)
        ax.set_ylabel('True Label', fontsize=12)
    else:
        # Show subset for many classes
        fig, ax = plt.subplots(figsize=(14, 12))
        sns.heatmap(confusion_mat[:20, :20], annot=False, fmt='d', cmap='Blues', ax=ax)
        ax.set_title('Confusion Matrix (Top 20 Classes)', fontsize=14, fontweight='bold')
        ax.set_xlabel('Predicted Label', fontsize=12)
        ax.set_ylabel('True Label', fontsize=12)
    
    plt.tight_layout()
    plt.savefig(save_dir / 'confusion_matrix.png', dpi=300, bbox_inches='tight')
    plt.close()
    
    print(f"Results saved to {save_dir}/")
    print(f"  • evaluation_metrics.json")
    print(f"  • classification_report.txt")
    print(f"  • confusion_matrix.png")
    
    return metrics, confusion_mat, class_report


# Import the contrastive loss class


# Example usage
# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Load dataset
dataset = SignLanguageDataset(root='data/sign_language')

# Split dataset
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(
    dataset, [train_size, val_size]
)

# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True, num_workers=0)
val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False, num_workers=0)

# Initialize model
model = ProtoGCN(
    num_nodes=114,
    in_channels=3,
    hidden_channels=64,
    num_classes=2,
    num_gcn_layers=4,
    num_prototypes=128,
    dropout=0.5
)

print(f"\nModel parameters: {sum(p.numel() for p in model.parameters()):,}")

# Train model
trained_model, metrics = train_model(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    num_epochs=100,
    learning_rate=0.001,
    weight_decay=1e-4,
    patience=5,
    save_dir='checkpoints',
    device=device,
    lambda_contrast=0.1,
    lambda_recon=0.1
)

# Load best model for inference
best_model = ProtoGCN(
    num_nodes=114,
    in_channels=3,
    hidden_channels=64,
    num_classes=2,
    num_gcn_layers=4,
    num_prototypes=128,
    dropout=0.5
)
best_model = load_model(best_model, 'checkpoints/best_model.pth', device)

# Detailed evaluation
print("\n" + "="*70)
print("PERFORMING DETAILED EVALUATION ON VALIDATION SET")
print("="*70)
detailed_metrics, conf_matrix, class_rep = detailed_evaluation(
    best_model, val_loader, device, save_dir='results'
)

# Test inference
print("\n" + "="*70)
print("TESTING INFERENCE")
print("="*70)
test_data = dataset[0]
prediction, confidence = inference(best_model, test_data, device)
print(f"\nSingle Sample Inference:")
print(f"  • Prediction: Class {prediction}")
print(f"  • Confidence: {confidence:.4f}")

# Test with probabilities
prediction, probs, confidence = inference(best_model, test_data, device, return_probs=True)
print(f"\nTop 5 Predictions:")
top5_indices = np.argsort(probs)[-5:][::-1]
for i, idx in enumerate(top5_indices):
    print(f"  {i+1}. Class {idx}: {probs[idx]:.4f}")
print("="*70)

Using device: cuda

Model parameters: 365,986

Starting training on cuda


Epoch 1/100
------------------------------------------------------------


Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 10.95it/s, loss=0.2340, acc=25.00%]
Evaluating: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 56.70it/s, loss=0.7379, acc=0.00%]



Epoch 1 Summary:
  Train - Loss: 0.2191 | Acc: 25.00% | F1: 0.1000 | Precision: 0.0625 | Recall: 0.2500
  Val   - Loss: 0.7379 | Acc: 0.00% | F1: 0.0000 | Precision: 0.0000 | Recall: 0.0000
  Learning Rate: 0.001000

Epoch 2/100
------------------------------------------------------------


Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 55.66it/s, loss=0.1846, acc=50.00%]
Evaluating: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 85.81it/s, loss=0.7232, acc=0.00%]



Epoch 2 Summary:
  Train - Loss: 0.2311 | Acc: 50.00% | F1: 0.5000 | Precision: 0.8333 | Recall: 0.5000
  Val   - Loss: 0.7232 | Acc: 0.00% | F1: 0.0000 | Precision: 0.0000 | Recall: 0.0000
  Learning Rate: 0.000999

Epoch 3/100
------------------------------------------------------------


Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 59.29it/s, loss=0.2259, acc=25.00%]
Evaluating: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 90.25it/s, loss=0.7135, acc=0.00%]



Epoch 3 Summary:
  Train - Loss: 0.2263 | Acc: 25.00% | F1: 0.1000 | Precision: 0.0625 | Recall: 0.2500
  Val   - Loss: 0.7135 | Acc: 0.00% | F1: 0.0000 | Precision: 0.0000 | Recall: 0.0000
  Learning Rate: 0.000998

Epoch 4/100
------------------------------------------------------------


Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 60.69it/s, loss=0.1836, acc=75.00%]
Evaluating: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 86.85it/s, loss=0.6948, acc=0.00%]



Epoch 4 Summary:
  Train - Loss: 0.1990 | Acc: 75.00% | F1: 0.7667 | Precision: 0.8750 | Recall: 0.7500
  Val   - Loss: 0.6948 | Acc: 0.00% | F1: 0.0000 | Precision: 0.0000 | Recall: 0.0000
  Learning Rate: 0.000996

Epoch 5/100
------------------------------------------------------------


Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 54.75it/s, loss=0.1686, acc=75.00%]
Evaluating: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 77.53it/s, loss=0.6729, acc=100.00%]



Epoch 5 Summary:
  Train - Loss: 0.1911 | Acc: 75.00% | F1: 0.6429 | Precision: 0.5625 | Recall: 0.7500
  Val   - Loss: 0.6729 | Acc: 100.00% | F1: 1.0000 | Precision: 1.0000 | Recall: 1.0000
  Learning Rate: 0.000994


Evaluating: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 70.75it/s, loss=0.6729, acc=100.00%]


  ✓ New best model saved! (Val Acc: 100.00%, F1: 1.0000)

Epoch 6/100
------------------------------------------------------------


Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 57.85it/s, loss=0.2100, acc=50.00%]
Evaluating: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 80.01it/s, loss=0.6508, acc=100.00%]



Epoch 6 Summary:
  Train - Loss: 0.2114 | Acc: 50.00% | F1: 0.5000 | Precision: 0.5000 | Recall: 0.5000
  Val   - Loss: 0.6508 | Acc: 100.00% | F1: 1.0000 | Precision: 1.0000 | Recall: 1.0000
  Learning Rate: 0.000991

Epoch 7/100
------------------------------------------------------------


Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 55.82it/s, loss=0.1240, acc=50.00%]
Evaluating: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 88.30it/s, loss=0.6040, acc=100.00%]



Epoch 7 Summary:
  Train - Loss: 0.1559 | Acc: 50.00% | F1: 0.5000 | Precision: 0.5000 | Recall: 0.5000
  Val   - Loss: 0.6040 | Acc: 100.00% | F1: 1.0000 | Precision: 1.0000 | Recall: 1.0000
  Learning Rate: 0.000988

Epoch 8/100
------------------------------------------------------------


Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 65.85it/s, loss=0.2741, acc=75.00%]
Evaluating: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 81.73it/s, loss=0.5516, acc=100.00%]



Epoch 8 Summary:
  Train - Loss: 0.1383 | Acc: 75.00% | F1: 0.6429 | Precision: 0.5625 | Recall: 0.7500
  Val   - Loss: 0.5516 | Acc: 100.00% | F1: 1.0000 | Precision: 1.0000 | Recall: 1.0000
  Learning Rate: 0.000984

Epoch 9/100
------------------------------------------------------------


Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 59.10it/s, loss=0.5709, acc=75.00%]
Evaluating: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 88.05it/s, loss=0.4980, acc=100.00%]



Epoch 9 Summary:
  Train - Loss: 0.2032 | Acc: 75.00% | F1: 0.6429 | Precision: 0.5625 | Recall: 0.7500
  Val   - Loss: 0.4980 | Acc: 100.00% | F1: 1.0000 | Precision: 1.0000 | Recall: 1.0000
  Learning Rate: 0.000980

Epoch 10/100
------------------------------------------------------------


Training: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 65.35it/s, loss=-0.1128, acc=75.00%]
Evaluating: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 89.76it/s, loss=0.4463, acc=100.00%]



Epoch 10 Summary:
  Train - Loss: 0.0456 | Acc: 75.00% | F1: 0.6429 | Precision: 0.5625 | Recall: 0.7500
  Val   - Loss: 0.4463 | Acc: 100.00% | F1: 1.0000 | Precision: 1.0000 | Recall: 1.0000
  Learning Rate: 0.000976

Epoch 11/100
------------------------------------------------------------


Training: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 61.96it/s, loss=-0.2486, acc=75.00%]
Evaluating: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 84.49it/s, loss=0.3894, acc=100.00%]



Epoch 11 Summary:
  Train - Loss: 0.0605 | Acc: 75.00% | F1: 0.6429 | Precision: 0.5625 | Recall: 0.7500
  Val   - Loss: 0.3894 | Acc: 100.00% | F1: 1.0000 | Precision: 1.0000 | Recall: 1.0000
  Learning Rate: 0.000970

Epoch 12/100
------------------------------------------------------------


Training: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 60.94it/s, loss=-0.2733, acc=75.00%]
Evaluating: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 89.41it/s, loss=0.3227, acc=100.00%]



Epoch 12 Summary:
  Train - Loss: -0.0887 | Acc: 75.00% | F1: 0.6429 | Precision: 0.5625 | Recall: 0.7500
  Val   - Loss: 0.3227 | Acc: 100.00% | F1: 1.0000 | Precision: 1.0000 | Recall: 1.0000
  Learning Rate: 0.000965

Epoch 13/100
------------------------------------------------------------


Training: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 60.03it/s, loss=-0.2852, acc=75.00%]
Evaluating: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 92.40it/s, loss=0.2502, acc=100.00%]



Epoch 13 Summary:
  Train - Loss: 0.0053 | Acc: 75.00% | F1: 0.6429 | Precision: 0.5625 | Recall: 0.7500
  Val   - Loss: 0.2502 | Acc: 100.00% | F1: 1.0000 | Precision: 1.0000 | Recall: 1.0000
  Learning Rate: 0.000959

Epoch 14/100
------------------------------------------------------------


Training: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 64.45it/s, loss=-0.3332, acc=75.00%]
Evaluating: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 93.45it/s, loss=0.1959, acc=100.00%]



Epoch 14 Summary:
  Train - Loss: 0.0790 | Acc: 75.00% | F1: 0.6429 | Precision: 0.5625 | Recall: 0.7500
  Val   - Loss: 0.1959 | Acc: 100.00% | F1: 1.0000 | Precision: 1.0000 | Recall: 1.0000
  Learning Rate: 0.000952

Epoch 15/100
------------------------------------------------------------


Training: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 60.12it/s, loss=-0.3991, acc=75.00%]
Evaluating: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 89.81it/s, loss=0.1526, acc=100.00%]



Epoch 15 Summary:
  Train - Loss: -0.1196 | Acc: 75.00% | F1: 0.6429 | Precision: 0.5625 | Recall: 0.7500
  Val   - Loss: 0.1526 | Acc: 100.00% | F1: 1.0000 | Precision: 1.0000 | Recall: 1.0000
  Learning Rate: 0.000946

Epoch 16/100
------------------------------------------------------------


Training: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 63.18it/s, loss=-0.2370, acc=75.00%]
Evaluating: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 86.27it/s, loss=0.1178, acc=100.00%]



Epoch 16 Summary:
  Train - Loss: 0.2665 | Acc: 75.00% | F1: 0.6429 | Precision: 0.5625 | Recall: 0.7500
  Val   - Loss: 0.1178 | Acc: 100.00% | F1: 1.0000 | Precision: 1.0000 | Recall: 1.0000
  Learning Rate: 0.000938

Epoch 17/100
------------------------------------------------------------


Training: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 61.65it/s, loss=-0.4420, acc=75.00%]
Evaluating: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 92.41it/s, loss=0.0864, acc=100.00%]



Epoch 17 Summary:
  Train - Loss: 0.6400 | Acc: 75.00% | F1: 0.6429 | Precision: 0.5625 | Recall: 0.7500
  Val   - Loss: 0.0864 | Acc: 100.00% | F1: 1.0000 | Precision: 1.0000 | Recall: 1.0000
  Learning Rate: 0.000930

Epoch 18/100
------------------------------------------------------------


Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 59.88it/s, loss=1.5496, acc=75.00%]
Evaluating: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 86.46it/s, loss=0.0654, acc=100.00%]



Epoch 18 Summary:
  Train - Loss: 0.0687 | Acc: 75.00% | F1: 0.6429 | Precision: 0.5625 | Recall: 0.7500
  Val   - Loss: 0.0654 | Acc: 100.00% | F1: 1.0000 | Precision: 1.0000 | Recall: 1.0000
  Learning Rate: 0.000922

Epoch 19/100
------------------------------------------------------------


Training: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 63.25it/s, loss=-0.0526, acc=75.00%]
Evaluating: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 90.21it/s, loss=0.0559, acc=100.00%]



Epoch 19 Summary:
  Train - Loss: 0.5193 | Acc: 75.00% | F1: 0.6429 | Precision: 0.5625 | Recall: 0.7500
  Val   - Loss: 0.0559 | Acc: 100.00% | F1: 1.0000 | Precision: 1.0000 | Recall: 1.0000
  Learning Rate: 0.000914

Epoch 20/100
------------------------------------------------------------


Training: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 61.64it/s, loss=-0.2974, acc=75.00%]
Evaluating: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 91.67it/s, loss=0.0483, acc=100.00%]



Epoch 20 Summary:
  Train - Loss: -0.0187 | Acc: 75.00% | F1: 0.6429 | Precision: 0.5625 | Recall: 0.7500
  Val   - Loss: 0.0483 | Acc: 100.00% | F1: 1.0000 | Precision: 1.0000 | Recall: 1.0000
  Learning Rate: 0.000905

Early stopping triggered at epoch 20
Best validation accuracy: 100.00% at epoch 5


Training completed in 0.04 minutes
Best validation accuracy: 100.00%

Confusion matrix saved to checkpoints/confusion_matrix.png
Training curves saved to checkpoints/training_curves.png
Classification report saved to checkpoints/classification_report.txt
Model loaded from checkpoints/best_model.pth
Checkpoint metrics:
  • Accuracy:  100.00%
  • F1 Score:  1.0000
  • Precision: 1.0000
  • Recall:    1.0000

PERFORMING DETAILED EVALUATION ON VALIDATION SET

Performing detailed evaluation...


Evaluating: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 91.74it/s]


DETAILED EVALUATION RESULTS

Accuracy Metrics:
  • Top-1 Accuracy:  100.00%

Weighted Average Metrics:
  • Precision:       1.0000
  • Recall:          1.0000
  • F1 Score:        1.0000

Macro Average Metrics:
  • Precision:       1.0000
  • Recall:          1.0000
  • F1 Score:        1.0000






Results saved to results/
  • evaluation_metrics.json
  • classification_report.txt
  • confusion_matrix.png

TESTING INFERENCE

Single Sample Inference:
  • Prediction: Class 0
  • Confidence: 0.5103

Top 5 Predictions:
  1. Class 0: 0.5103
  2. Class 1: 0.4897
