# Model Evaluation Notebook

This notebook evaluates trained FSLR models (Transformer and IV3-GRU) on test data and reports accuracy, agreement metrics, and confusion matrices.

## What this notebook does:
1. **Loads trained models** from checkpoints (both Transformer and IV3-GRU)
2. **Evaluates on test data** with comprehensive metrics
3. **Visualizes results** with interactive plots and confusion matrices
4. **Analyzes errors** and misclassification patterns
5. **Compares models** side-by-side (optional)

## Prerequisites:
- Trained model checkpoints in `data/processed/`
- Test data prepared in the same format as training data
- Test labels CSV with columns: `file,gloss,cat`
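
As a quick pre-flight check, the label format listed above can be verified before running the notebook. The sketch below is a minimal, dependency-free example. It assumes `gloss` and `cat` are stored as non-negative integer class indices; the helper name `check_labels_csv` and the sample file names are hypothetical.

```python
import csv
import io

REQUIRED_COLUMNS = ("file", "gloss", "cat")


def check_labels_csv(text, num_gloss, num_cat):
    """Return a list of problems found in a labels CSV passed as a string."""
    reader = csv.DictReader(io.StringIO(text))
    header = reader.fieldnames or []
    missing = [c for c in REQUIRED_COLUMNS if c not in header]
    if missing:
        return [f"missing columns: {missing}"]

    problems = []
    # Row 1 is the header, so data rows start at 2.
    for row_number, row in enumerate(reader, start=2):
        for column, limit in (("gloss", num_gloss), ("cat", num_cat)):
            value = row[column] or ""
            # Assumption: labels are integer class indices in [0, limit).
            if not value.isdigit() or int(value) >= limit:
                problems.append(f"row {row_number}: bad {column} value {value!r}")
    return problems


# Hypothetical sample: the second data row uses gloss index 105, which is out of range.
sample = "file,gloss,cat\nclip_0001,12,3\nclip_0002,105,3\n"
print(check_labels_csv(sample, num_gloss=105, num_cat=10))
```

The limits 105 and 10 mirror the `num_gloss` and `num_cat` values in the configuration cell; adjust them if your model was trained with different class counts.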

## Usage:
1. Update the configuration cell with your model and data paths
2. Run cells sequentially
3. Explore results interactively


In [None]:
# Imports and setup
import sys
from pathlib import Path
import json
import warnings
warnings.filterwarnings('ignore')

# Add project root to path
CWD = Path.cwd()
ROOT = CWD.parent if CWD.name == 'notebooks' else CWD
if str(ROOT) not in sys.path:
    sys.path.append(str(ROOT))

# Core imports
import torch
import torch.nn as nn
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import classification_report, confusion_matrix, precision_recall_fscore_support
from sklearn.metrics import accuracy_score, top_k_accuracy_score
from sklearn.metrics import cohen_kappa_score, matthews_corrcoef
from scipy import stats
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import plotly.offline as pyo

# Project imports
from models.transformer import SignTransformer
from models.iv3_gru import InceptionV3GRU
from training.utils import FSLDataset, evaluate
from training.train import FSLKeypointFileDataset, FSLFeatureFileDataset, collate_keypoints_with_padding, collate_features_with_padding
from torch.utils.data import DataLoader

# Set up plotting
plt.style.use('default')
sns.set_palette("husl")
plt.rcParams['figure.figsize'] = (12, 8)
plt.rcParams['font.size'] = 10

print("All imports successful!")
print(f"Project root: {ROOT}")
print(f"PyTorch version: {torch.__version__}")
print(f"Device available: {'CUDA' if torch.cuda.is_available() else 'CPU'}")


## Configuration

Update these paths to point to your trained models and test data.


In [None]:
# Configuration - Update these paths for your setup
CONFIG = {
    # Model selection
    'model_type': 'transformer',  # 'transformer' or 'iv3_gru'
    'model_path': 'data/processed/SignTransformer_best.pt',  # Path to checkpoint
    
    # Test data paths
    'test_data_path': 'data/processed/test_keypoints',  # Directory with test .npz files
    'test_labels_path': 'data/processed/test_labels.csv',  # CSV with test labels
    
    # Model parameters (should match training)
    'num_gloss': 105,
    'num_cat': 10,
    'hidden1': 16,  # For IV3-GRU
    'hidden2': 12,  # For IV3-GRU
    'dropout': 0.3,  # For IV3-GRU
    
    # Evaluation parameters
    'batch_size': 32,
    'num_workers': 0,
    'device': 'cuda' if torch.cuda.is_available() else 'cpu',
    
    # Data format (for IV3-GRU)
    'feature_key': 'X2048',  # Key in .npz files for features
    'kp_key': 'X',  # Key in .npz files for keypoints
    
    # Visualization
    'show_plots': True,
    'save_plots': False,
    'plot_dpi': 100,
    
    # Statistical analysis
    'confidence_level': 0.95,
    'bootstrap_samples': 1000,
}

def validate_configuration(config):
    """Validate configuration and check file paths"""
    print("Configuration:")
    for key, value in config.items():
        print(f"  {key}: {value}")
    
    # Validate paths
    model_path = ROOT / config['model_path']
    test_data_path = ROOT / config['test_data_path']
    test_labels_path = ROOT / config['test_labels_path']
    
    print(f"\nPath validation:")
    print(f"  Model exists: {model_path.exists()}")
    print(f"  Test data exists: {test_data_path.exists()}")
    print(f"  Test labels exist: {test_labels_path.exists()}")
    
    missing_paths = []
    if not model_path.exists():
        missing_paths.append("Model checkpoint")
    if not test_data_path.exists():
        missing_paths.append("Test data directory")
    if not test_labels_path.exists():
        missing_paths.append("Test labels file")
    
    if missing_paths:
        print(f"WARNING: Missing required files: {', '.join(missing_paths)}")
        print("Please update the configuration above with correct paths.")
        return False
    
    return True

# Display and validate configuration
config_valid = validate_configuration(CONFIG)


## Model Loading

Load the trained model from checkpoint and prepare for evaluation.


In [None]:
def load_model(model_type, model_path, device, **model_kwargs):
    """Load trained model from checkpoint with comprehensive error handling"""
    print(f"Loading {model_type} model from {model_path}")
    
    try:
        # Check if file exists first
        if not model_path.exists():
            raise FileNotFoundError(f"Model checkpoint not found: {model_path}")
        
        # Load checkpoint
        checkpoint = torch.load(model_path, map_location=device)
        print(f"Checkpoint loaded (epoch {checkpoint.get('epoch', 'unknown')})")
        
        # Validate checkpoint structure
        if 'model' not in checkpoint:
            raise KeyError("Checkpoint missing 'model' key - invalid checkpoint format")
        
        # Create model architecture
        if model_type == 'transformer':
            model = SignTransformer(
                num_gloss=model_kwargs['num_gloss'],
                num_cat=model_kwargs['num_cat']
            )
        elif model_type == 'iv3_gru':
            model = InceptionV3GRU(
                num_gloss=model_kwargs['num_gloss'],
                num_cat=model_kwargs['num_cat'],
                hidden1=model_kwargs['hidden1'],
                hidden2=model_kwargs['hidden2'],
                dropout=model_kwargs['dropout'],
                pretrained_backbone=True,
                freeze_backbone=True
            )
        else:
            raise ValueError(f"Unknown model type: {model_type}. Expected 'transformer' or 'iv3_gru'")
        
        # Load state dict with error checking
        try:
            model.load_state_dict(checkpoint['model'])
        except RuntimeError as e:
            print(f"WARNING: Model state dict mismatch: {e}")
            print("Attempting to load with strict=False...")
            model.load_state_dict(checkpoint['model'], strict=False)
        
        model.to(device)
        model.eval()
        
        # Get model info
        total_params = sum(p.numel() for p in model.parameters())
        trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        
        print(f"Model loaded successfully!")
        print(f"  - Model type: {model.__class__.__name__}")
        print(f"  - Total parameters: {total_params:,}")
        print(f"  - Trainable parameters: {trainable_params:,}")
        print(f"  - Model size: {total_params * 4 / 1024 / 1024:.1f} MB (assuming float32 weights)")
        print(f"  - Best validation metric: {checkpoint.get('best_metric', 'unknown')}")
        print(f"  - Device: {device}")
        
        # Validate model is in eval mode
        if model.training:
            print("WARNING: Model is in training mode, switching to eval mode")
            model.eval()
        
        return model, checkpoint
        
    except FileNotFoundError:
        print(f"ERROR: Model file not found at {model_path}")
        print("Please check the model_path in the configuration section")
        raise
    except KeyError as e:
        print(f"ERROR: Invalid checkpoint format - missing key: {e}")
        print("The checkpoint file may be corrupted or from an incompatible version")
        raise
    except RuntimeError as e:
        print(f"ERROR: Model architecture mismatch: {e}")
        print("The checkpoint may have been saved with different model parameters")
        raise
    except Exception as e:
        print(f"ERROR: Unexpected error loading model: {type(e).__name__}: {e}")
        raise

# Load the model with better error handling
if config_valid:
    try:
        model, checkpoint = load_model(
            model_type=CONFIG['model_type'],
            model_path=ROOT / CONFIG['model_path'],
            device=CONFIG['device'],
            **{k: v for k, v in CONFIG.items() if k in ['num_gloss', 'num_cat', 'hidden1', 'hidden2', 'dropout']}
        )
        print("Model loading successful - ready for evaluation")
    except Exception as e:
        print(f"Failed to load model: {e}")
        print("Please check your configuration and model file")
        model = None
        checkpoint = None
else:
    print("Skipping model loading due to configuration validation failure.")
    model = None
    checkpoint = None

## Test Data Loading

Load and prepare the test dataset for evaluation.


In [None]:
def load_test_data(config):
    """Load test dataset and create dataloader with comprehensive validation"""
    print(f"Loading test data from {config['test_data_path']}")
    
    try:
        # Validate paths exist
        test_data_path = ROOT / config['test_data_path']
        test_labels_path = ROOT / config['test_labels_path']
        
        if not test_data_path.exists():
            raise FileNotFoundError(f"Test data directory not found: {test_data_path}")
        if not test_labels_path.exists():
            raise FileNotFoundError(f"Test labels file not found: {test_labels_path}")
        
        # Load labels
        labels_df = pd.read_csv(test_labels_path)
        print(f"Loaded {len(labels_df)} test samples")
        
        # Validate labels format
        required_columns = ['file', 'gloss', 'cat']
        missing_columns = [col for col in required_columns if col not in labels_df.columns]
        if missing_columns:
            raise ValueError(f"Missing required columns in labels: {missing_columns}")
        
        # Check for missing values
        if labels_df.isnull().any().any():
            print("WARNING: Found missing values in labels. Filling with defaults.")
            labels_df = labels_df.fillna({'gloss': 0, 'cat': 0})
        
        # Validate class ranges
        max_gloss = labels_df['gloss'].max()
        max_cat = labels_df['cat'].max()
        if max_gloss >= config['num_gloss']:
            raise ValueError(f"Gloss label {max_gloss} exceeds expected range [0, {config['num_gloss']-1}]")
        if max_cat >= config['num_cat']:
            raise ValueError(f"Category label {max_cat} exceeds expected range [0, {config['num_cat']-1}]")
        
        print(f"Label validation passed:")
        print(f"  - Gloss range: [0, {max_gloss}] (expected max: {config['num_gloss']-1})")
        print(f"  - Category range: [0, {max_cat}] (expected max: {config['num_cat']-1})")
        
        # Create dataset based on model type
        if config['model_type'] == 'transformer':
            # Use keypoint dataset for transformer
            dataset = FSLKeypointFileDataset(
                keypoints_dir=test_data_path,
                labels_csv=test_labels_path,
                kp_key=config['kp_key']
            )
            collate_fn = collate_keypoints_with_padding
            print("Using keypoint dataset for Transformer model")
        elif config['model_type'] == 'iv3_gru':
            # Use feature dataset for IV3-GRU
            dataset = FSLFeatureFileDataset(
                features_dir=test_data_path,
                labels_csv=test_labels_path,
                feature_key=config['feature_key']
            )
            collate_fn = collate_features_with_padding
            print("Using feature dataset for IV3-GRU model")
        else:
            raise ValueError(f"Unknown model type: {config['model_type']}")
        
        # Validate dataset has data
        if len(dataset) == 0:
            raise ValueError("Dataset is empty - check data paths and file formats")
        
        # Test loading one sample to validate format
        try:
            sample = dataset[0]
            print(f"Dataset sample validation:")
            if len(sample) == 4:  # With lengths
                X, gloss, cat, length = sample
                print(f"  - Input shape: {X.shape}")
                print(f"  - Length: {length}")
            else:  # Without lengths
                X, gloss, cat = sample
                print(f"  - Input shape: {X.shape}")
                print(f"  - No length info")
            print(f"  - Gloss: {gloss}, Category: {cat}")
        except Exception as e:
            print(f"WARNING: Failed to validate sample format: {e}")
        
        # Create dataloader
        dataloader = DataLoader(
            dataset,
            batch_size=config['batch_size'],
            shuffle=False,  # Don't shuffle for consistent evaluation
            num_workers=config['num_workers'],
            collate_fn=collate_fn,
            pin_memory=(config['device'] == 'cuda'),
            drop_last=False  # Keep all samples for evaluation
        )
        
        print(f"Test dataloader created with {len(dataloader)} batches")
        
        # Display dataset statistics
        print(f"\nTest dataset statistics:")
        print(f"  - Total samples: {len(dataset)}")
        print(f"  - Batch size: {config['batch_size']}")
        print(f"  - Number of batches: {len(dataloader)}")
        print(f"  - Gloss classes: {labels_df['gloss'].nunique()}")
        print(f"  - Category classes: {labels_df['cat'].nunique()}")
        
        # Show class distribution
        print(f"\nClass distribution:")
        gloss_counts = labels_df['gloss'].value_counts().sort_index()
        print(f"  - Gloss distribution:")
        print(f"    Range: {gloss_counts.min()} to {gloss_counts.max()} samples per class")
        print(f"    Mean: {gloss_counts.mean():.1f} samples per class")
        print(f"    Std: {gloss_counts.std():.1f} samples per class")
        
        cat_counts = labels_df['cat'].value_counts().sort_index()
        print(f"  - Category distribution:")
        print(f"    Range: {cat_counts.min()} to {cat_counts.max()} samples per class")
        print(f"    Mean: {cat_counts.mean():.1f} samples per class")
        print(f"    Std: {cat_counts.std():.1f} samples per class")
        
        # Check for class imbalance
        gloss_imbalance = gloss_counts.std() / gloss_counts.mean() if gloss_counts.mean() > 0 else 0
        cat_imbalance = cat_counts.std() / cat_counts.mean() if cat_counts.mean() > 0 else 0
        
        if gloss_imbalance > 0.5:
            print(f"  WARNING: High gloss class imbalance (CV: {gloss_imbalance:.2f})")
        if cat_imbalance > 0.5:
            print(f"  WARNING: High category class imbalance (CV: {cat_imbalance:.2f})")
        
        # Check for missing classes
        expected_gloss_classes = set(range(config['num_gloss']))
        actual_gloss_classes = set(labels_df['gloss'].unique())
        missing_gloss = expected_gloss_classes - actual_gloss_classes
        if missing_gloss:
            print(f"  WARNING: Missing gloss classes in test data: {sorted(missing_gloss)}")
        
        expected_cat_classes = set(range(config['num_cat']))
        actual_cat_classes = set(labels_df['cat'].unique())
        missing_cat = expected_cat_classes - actual_cat_classes
        if missing_cat:
            print(f"  WARNING: Missing category classes in test data: {sorted(missing_cat)}")
        
        return dataset, dataloader, labels_df
        
    except FileNotFoundError as e:
        print(f"ERROR: {e}")
        print("Please check the test data paths in the configuration")
        raise
    except ValueError as e:
        print(f"ERROR: Data validation failed: {e}")
        raise
    except Exception as e:
        print(f"ERROR: Unexpected error loading test data: {type(e).__name__}: {e}")
        raise

# Load test data with improved validation
if config_valid:
    try:
        test_dataset, test_dataloader, test_labels_df = load_test_data(CONFIG)
        print("Test data loading successful - ready for evaluation")
    except Exception as e:
        print(f"Failed to load test data: {e}")
        print("Please check your configuration and data files")
        test_dataset = None
        test_dataloader = None
        test_labels_df = None
else:
    print("Skipping test data loading due to configuration validation failure.")
    test_dataset = None
    test_dataloader = None
    test_labels_df = None

## Model Evaluation

Run comprehensive evaluation on the test dataset and calculate detailed metrics.
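
The accuracy confidence intervals in this section come from a percentile bootstrap: resample the test predictions with replacement, recompute accuracy on each resample, and read the interval off the sorted estimates. The following is a minimal standalone sketch of that idea using only the standard library (the notebook itself uses NumPy). The function name `bootstrap_accuracy_ci` and the toy labels are illustrative, and the nearest-rank percentile used here can differ slightly from the interpolation in `np.percentile`.

```python
import random


def bootstrap_accuracy_ci(targets, preds, n_resamples=1000, confidence_level=0.95, seed=0):
    """Percentile bootstrap interval for accuracy over paired labels."""
    if not targets or len(targets) != len(preds):
        raise ValueError("targets and preds must be non-empty and of equal length")

    rng = random.Random(seed)  # seeded so repeated runs give the same interval
    correct = [int(t == p) for t, p in zip(targets, preds)]
    n = len(correct)

    # Each resample draws n samples with replacement and records its accuracy.
    accuracies = sorted(
        sum(rng.choices(correct, k=n)) / n for _ in range(n_resamples)
    )

    # Nearest-rank percentiles (truncated index), an approximation.
    alpha = 1.0 - confidence_level
    lower_index = int((alpha / 2) * (n_resamples - 1))
    upper_index = int((1 - alpha / 2) * (n_resamples - 1))
    return accuracies[lower_index], accuracies[upper_index]


# Toy data: 8 of 10 predictions are correct.
targets = [0, 1, 2, 1, 0, 2, 1, 0, 2, 1]
preds = [0, 1, 2, 0, 0, 2, 1, 1, 2, 1]
accuracy = sum(t == p for t, p in zip(targets, preds)) / len(targets)
lower, upper = bootstrap_accuracy_ci(targets, preds)
print(f"accuracy={accuracy:.2f}, CI=[{lower:.2f}, {upper:.2f}]")
```

With only ten samples the interval is very wide, which is the point of reporting it: a small test set supports only a coarse accuracy estimate.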


In [None]:
def calculate_confidence_intervals(accuracies, confidence_level=0.95):
    """Return the percentile confidence interval of a list of bootstrap accuracy estimates"""
    if len(accuracies) == 0:
        return 0.0, 0.0
    
    n = len(accuracies)
    alpha = 1 - confidence_level
    lower_percentile = (alpha / 2) * 100
    upper_percentile = (1 - alpha / 2) * 100
    
    lower_bound = np.percentile(accuracies, lower_percentile)
    upper_bound = np.percentile(accuracies, upper_percentile)
    
    return lower_bound, upper_bound

def evaluate_model_comprehensive(model, dataloader, device, model_type, config):
    """Comprehensive model evaluation with detailed metrics and statistical analysis"""
    print(f"Running comprehensive evaluation...")
    
    if model is None:
        raise ValueError("Model is None - cannot evaluate")
    if dataloader is None:
        raise ValueError("Dataloader is None - cannot evaluate")
    
    model.eval()
    device = torch.device(device)
    
    # Storage for predictions and targets
    all_gloss_preds = []
    all_cat_preds = []
    all_gloss_targets = []
    all_cat_targets = []
    all_gloss_probs = []
    all_cat_probs = []
    
    # Loss calculation
    criterion = nn.CrossEntropyLoss()
    total_loss = 0.0
    num_batches = 0
    
    # Create forward function based on model type with better error handling
    if model_type == 'transformer':
        def forward_fn(m, X, lengths=None):
            try:
                if lengths is not None:
                    B, T, _ = X.shape
                    device = X.device
                    time_indices = torch.arange(T, device=device).unsqueeze(0)
                    mask = (time_indices < lengths.unsqueeze(1))
                else:
                    mask = None
                return m(X, mask=mask)
            except Exception as e:
                print(f"Error in transformer forward pass: {e}")
                raise
    else:  # iv3_gru
        def forward_fn(m, X, lengths=None):
            try:
                return m(X, lengths=lengths, features_already=True)
            except Exception as e:
                print(f"Error in IV3-GRU forward pass: {e}")
                raise
    
    # Evaluation loop with progress tracking
    total_samples = 0
    with torch.no_grad():
        for batch_idx, batch in enumerate(dataloader):
            try:
                if len(batch) == 4:
                    X, gloss_targets, cat_targets, lengths = batch
                    lengths = lengths.to(device)
                else:
                    X, gloss_targets, cat_targets = batch
                    lengths = None
                
                X = X.to(device)
                gloss_targets = gloss_targets.to(device)
                cat_targets = cat_targets.to(device)
                
                # Validate input shapes
                if X.dim() != 3:
                    raise ValueError(f"Expected 3D input tensor [B, T, D], got {X.shape}")
                
                # Forward pass
                gloss_logits, cat_logits = forward_fn(model, X, lengths)
                
                # Validate output shapes
                if gloss_logits.dim() != 2 or cat_logits.dim() != 2:
                    raise ValueError(f"Invalid output dimensions: gloss {gloss_logits.shape}, cat {cat_logits.shape}")
                
                # Calculate loss
                loss_gloss = criterion(gloss_logits, gloss_targets)
                loss_cat = criterion(cat_logits, cat_targets)
                batch_loss = loss_gloss + loss_cat
                total_loss += batch_loss.item()
                num_batches += 1
                total_samples += X.shape[0]
                
                # Get predictions
                gloss_preds = gloss_logits.argmax(dim=1)
                cat_preds = cat_logits.argmax(dim=1)
                
                # Get probabilities
                gloss_probs = torch.softmax(gloss_logits, dim=1)
                cat_probs = torch.softmax(cat_logits, dim=1)
                
                # Store results
                all_gloss_preds.extend(gloss_preds.cpu().numpy())
                all_cat_preds.extend(cat_preds.cpu().numpy())
                all_gloss_targets.extend(gloss_targets.cpu().numpy())
                all_cat_targets.extend(cat_targets.cpu().numpy())
                all_gloss_probs.extend(gloss_probs.cpu().numpy())
                all_cat_probs.extend(cat_probs.cpu().numpy())
                
                if (batch_idx + 1) % 10 == 0:
                    print(f"  Processed {batch_idx + 1}/{len(dataloader)} batches ({total_samples} samples)")
                    
            except Exception as e:
                print(f"Error processing batch {batch_idx}: {e}")
                print(f"Batch info: X shape: {X.shape if 'X' in locals() else 'Unknown'}")
                raise
    
    print(f"Evaluation completed! Processed {total_samples} samples in {num_batches} batches")
    
    # Convert to numpy arrays
    all_gloss_preds = np.array(all_gloss_preds)
    all_cat_preds = np.array(all_cat_preds)
    all_gloss_targets = np.array(all_gloss_targets)
    all_cat_targets = np.array(all_cat_targets)
    all_gloss_probs = np.array(all_gloss_probs)
    all_cat_probs = np.array(all_cat_probs)
    
    # Validate arrays
    if len(all_gloss_preds) == 0:
        raise ValueError("No predictions collected - check data loading and model forward pass")
    
    # Calculate basic metrics
    avg_loss = total_loss / num_batches if num_batches > 0 else 0.0
    gloss_accuracy = accuracy_score(all_gloss_targets, all_gloss_preds)
    cat_accuracy = accuracy_score(all_cat_targets, all_cat_preds)
    
    # Top-k accuracy (with error handling)
    try:
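        # Pass `labels` explicitly so scoring still works when some classes are absent from the test set
        gloss_labels = np.arange(all_gloss_probs.shape[1])
        gloss_top3 = top_k_accuracy_score(all_gloss_targets, all_gloss_probs, k=min(3, all_gloss_probs.shape[1]), labels=gloss_labels)
        gloss_top5 = top_k_accuracy_score(all_gloss_targets, all_gloss_probs, k=min(5, all_gloss_probs.shape[1]), labels=gloss_labels)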
    except Exception as e:
        print(f"Warning: Could not calculate top-k accuracy: {e}")
        gloss_top3 = gloss_accuracy
        gloss_top5 = gloss_accuracy
    
    # Additional metrics
    try:
        gloss_kappa = cohen_kappa_score(all_gloss_targets, all_gloss_preds)
        cat_kappa = cohen_kappa_score(all_cat_targets, all_cat_preds)
        gloss_mcc = matthews_corrcoef(all_gloss_targets, all_gloss_preds)
        cat_mcc = matthews_corrcoef(all_cat_targets, all_cat_preds)
    except Exception as e:
        print(f"Warning: Could not calculate advanced metrics: {e}")
        gloss_kappa = cat_kappa = gloss_mcc = cat_mcc = 0.0
    
    # Bootstrap confidence intervals
    n_samples = config.get('bootstrap_samples', 1000)
    confidence_level = config.get('confidence_level', 0.95)
    
    print(f"Calculating confidence intervals with {n_samples} bootstrap samples...")
    
    # Bootstrap gloss accuracy
    gloss_accuracies = []
    for _ in range(n_samples):
        indices = np.random.choice(len(all_gloss_targets), len(all_gloss_targets), replace=True)
        sample_acc = accuracy_score(all_gloss_targets[indices], all_gloss_preds[indices])
        gloss_accuracies.append(sample_acc)
    
    # Bootstrap category accuracy
    cat_accuracies = []
    for _ in range(n_samples):
        indices = np.random.choice(len(all_cat_targets), len(all_cat_targets), replace=True)
        sample_acc = accuracy_score(all_cat_targets[indices], all_cat_preds[indices])
        cat_accuracies.append(sample_acc)
    
    # Calculate confidence intervals
    gloss_ci_lower, gloss_ci_upper = calculate_confidence_intervals(gloss_accuracies, confidence_level)
    cat_ci_lower, cat_ci_upper = calculate_confidence_intervals(cat_accuracies, confidence_level)
    
    print(f"Evaluation completed!")
    print(f"  - Average loss: {avg_loss:.4f}")
    print(f"  - Gloss accuracy: {gloss_accuracy:.4f} [{gloss_ci_lower:.4f}, {gloss_ci_upper:.4f}]")
    print(f"  - Category accuracy: {cat_accuracy:.4f} [{cat_ci_lower:.4f}, {cat_ci_upper:.4f}]")
    print(f"  - Gloss top-3 accuracy: {gloss_top3:.4f}")
    print(f"  - Gloss top-5 accuracy: {gloss_top5:.4f}")
    print(f"  - Gloss Cohen's Kappa: {gloss_kappa:.4f}")
    print(f"  - Category Cohen's Kappa: {cat_kappa:.4f}")
    print(f"  - Gloss Matthews Correlation: {gloss_mcc:.4f}")
    print(f"  - Category Matthews Correlation: {cat_mcc:.4f}")
    
    return {
        'loss': avg_loss,
        'gloss_accuracy': gloss_accuracy,
        'cat_accuracy': cat_accuracy,
        'gloss_top3_accuracy': gloss_top3,
        'gloss_top5_accuracy': gloss_top5,
        'gloss_kappa': gloss_kappa,
        'cat_kappa': cat_kappa,
        'gloss_mcc': gloss_mcc,
        'cat_mcc': cat_mcc,
        'gloss_ci_lower': gloss_ci_lower,
        'gloss_ci_upper': gloss_ci_upper,
        'cat_ci_lower': cat_ci_lower,
        'cat_ci_upper': cat_ci_upper,
        'gloss_predictions': all_gloss_preds,
        'cat_predictions': all_cat_preds,
        'gloss_targets': all_gloss_targets,
        'cat_targets': all_cat_targets,
        'gloss_probabilities': all_gloss_probs,
        'cat_probabilities': all_cat_probs
    }

# Run evaluation with better error handling
if (config_valid and 'model' in locals() and model is not None and 
    'test_dataloader' in locals() and test_dataloader is not None):
    try:
        results = evaluate_model_comprehensive(
            model, test_dataloader, CONFIG['device'], CONFIG['model_type'], CONFIG
        )
        print("Model evaluation successful!")
    except Exception as e:
        print(f"Failed to evaluate model: {e}")
        print("Please check your model, data, and configuration")
        results = None
else:
    print("Skipping evaluation due to missing model or test data.")
    results = None

## Overall Performance Summary

Display key performance metrics in a clear summary format.
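
Cohen's kappa is reported alongside accuracy because accuracy alone can look strong on imbalanced data. It compares observed agreement `p_o` with the agreement expected by chance `p_e`, as `kappa = (p_o - p_e) / (1 - p_e)`. The sketch below is a hand-rolled illustration of that formula, not the notebook's implementation (the notebook calls scikit-learn's `cohen_kappa_score`); the function name `cohen_kappa` and the toy labels are assumptions made for the example.

```python
from collections import Counter


def cohen_kappa(targets, preds):
    """Cohen's kappa for two equal-length label sequences."""
    n = len(targets)
    # Observed agreement is plain accuracy.
    observed = sum(t == p for t, p in zip(targets, preds)) / n

    # Chance agreement: probability that independent draws from the two
    # marginal label distributions land on the same class.
    true_counts = Counter(targets)
    pred_counts = Counter(preds)
    expected = sum(true_counts[c] * pred_counts[c] for c in true_counts) / (n * n)

    if expected == 1.0:
        # Both sequences contain a single shared class; kappa is undefined.
        return float("nan")
    return (observed - expected) / (1.0 - expected)


# A predictor that always outputs the majority class reaches 90% accuracy
# here, yet its kappa is 0 because it does no better than chance.
imbalanced_targets = [0] * 9 + [1]
majority_preds = [0] * 10
print(cohen_kappa(imbalanced_targets, majority_preds))
```

This is why the summary interprets kappa separately from accuracy: a high accuracy paired with a low kappa usually points to class imbalance rather than a good classifier.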


In [None]:
def create_performance_summary(results, model_type, checkpoint_info=None):
    """Create a comprehensive performance summary with statistical analysis"""
    
    # Basic metrics with confidence intervals
    summary_data = {
        'Metric': [
            'Test Loss',
            'Gloss Accuracy',
            'Category Accuracy', 
            'Gloss Top-3 Accuracy',
            'Gloss Top-5 Accuracy',
            'Gloss Cohen\'s Kappa',
            'Category Cohen\'s Kappa',
            'Gloss Matthews Correlation',
            'Category Matthews Correlation'
        ],
        'Value': [
            f"{results['loss']:.4f}",
            f"{results['gloss_accuracy']:.4f}",
            f"{results['cat_accuracy']:.4f}",
            f"{results['gloss_top3_accuracy']:.4f}",
            f"{results['gloss_top5_accuracy']:.4f}",
            f"{results['gloss_kappa']:.4f}",
            f"{results['cat_kappa']:.4f}",
            f"{results['gloss_mcc']:.4f}",
            f"{results['cat_mcc']:.4f}"
        ],
        'Percentage': [
            'N/A',  # Loss is not a proportion, so a percentage is not meaningful
            f"{results['gloss_accuracy']*100:.2f}%",
            f"{results['cat_accuracy']*100:.2f}%",
            f"{results['gloss_top3_accuracy']*100:.2f}%",
            f"{results['gloss_top5_accuracy']*100:.2f}%",
            f"{results['gloss_kappa']*100:.2f}%",
            f"{results['cat_kappa']*100:.2f}%",
            f"{results['gloss_mcc']*100:.2f}%",
            f"{results['cat_mcc']*100:.2f}%"
        ],
        'Confidence Interval': [
            'N/A',
            f"[{results['gloss_ci_lower']:.4f}, {results['gloss_ci_upper']:.4f}]",
            f"[{results['cat_ci_lower']:.4f}, {results['cat_ci_upper']:.4f}]",
            'N/A',
            'N/A',
            'N/A',
            'N/A',
            'N/A',
            'N/A'
        ]
    }
    
    summary_df = pd.DataFrame(summary_data)
    
    print("PERFORMANCE SUMMARY")
    print("=" * 80)
    print(f"Model Type: {model_type.upper()}")
    if checkpoint_info:
        print(f"Training Epoch: {checkpoint_info.get('epoch', 'Unknown')}")
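        # Only apply numeric formatting when the checkpoint actually stores a metric
        best_metric = checkpoint_info.get('best_metric')
        best_metric_text = f"{float(best_metric):.4f}" if best_metric is not None else 'Unknown'
        print(f"Best Validation Metric: {best_metric_text}")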
    print("=" * 80)
    
    # Display table
    print(summary_df.to_string(index=False))
    
    # Additional insights
    print(f"\nKey Insights:")
    print(f"  • Gloss classification: {results['gloss_accuracy']*100:.1f}% accuracy")
    print(f"    - 95% CI: [{results['gloss_ci_lower']*100:.1f}%, {results['gloss_ci_upper']*100:.1f}%]")
    print(f"    - Cohen's Kappa: {results['gloss_kappa']:.3f} (agreement beyond chance)")
    print(f"    - Matthews Correlation: {results['gloss_mcc']:.3f} (robust to class imbalance)")
    
    print(f"  • Category classification: {results['cat_accuracy']*100:.1f}% accuracy")
    print(f"    - 95% CI: [{results['cat_ci_lower']*100:.1f}%, {results['cat_ci_upper']*100:.1f}%]")
    print(f"    - Cohen's Kappa: {results['cat_kappa']:.3f} (agreement beyond chance)")
    print(f"    - Matthews Correlation: {results['cat_mcc']:.3f} (robust to class imbalance)")
    
    print(f"  • Top-k performance:")
    print(f"    - Top-3: {results['gloss_top3_accuracy']*100:.1f}% (true gloss among the 3 highest-scoring classes)")
    print(f"    - Top-5: {results['gloss_top5_accuracy']*100:.1f}% (true gloss among the 5 highest-scoring classes)")
    
    # Performance assessment
    print(f"\nPerformance Assessment:")
    
    # Gloss performance
    if results['gloss_accuracy'] > 0.8:
        print(f"  • EXCELLENT gloss classification performance!")
    elif results['gloss_accuracy'] > 0.6:
        print(f"  • GOOD gloss classification performance")
    elif results['gloss_accuracy'] > 0.4:
        print(f"  • FAIR gloss classification performance")
    else:
        print(f"  • POOR gloss classification performance - needs improvement")
    
    # Category performance
    if results['cat_accuracy'] > 0.7:
        print(f"  • EXCELLENT category classification performance!")
    elif results['cat_accuracy'] > 0.5:
        print(f"  • GOOD category classification performance")
    elif results['cat_accuracy'] > 0.3:
        print(f"  • FAIR category classification performance")
    else:
        print(f"  • POOR category classification performance - needs improvement")
    
    # Agreement quality assessment
    if results['gloss_kappa'] > 0.8:
        print(f"  • EXCELLENT gloss agreement quality (Kappa > 0.8)")
    elif results['gloss_kappa'] > 0.6:
        print(f"  • GOOD gloss agreement quality (Kappa > 0.6)")
    elif results['gloss_kappa'] > 0.4:
        print(f"  • MODERATE gloss agreement quality (Kappa > 0.4)")
    else:
        print(f"  • POOR gloss agreement quality (Kappa < 0.4)")
    
    return summary_df

# Create and display summary
if 'results' in locals() and results is not None:
    summary_df = create_performance_summary(results, CONFIG['model_type'], checkpoint)
else:
    print("No results available for summary. Please run evaluation first.")


## Confusion Matrix Analysis

Visualize confusion matrices for both gloss and category classification to understand misclassification patterns.
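
To make the link between a confusion matrix and misclassification patterns concrete, here is a small dependency-free sketch. It builds a count matrix, normalizes each row so that entry `[t][p]` is the fraction of true class `t` predicted as `p`, and lists the largest off-diagonal cells. The helper names (`confusion_counts`, `normalize_rows`, `top_confusions`) and the toy labels are hypothetical; the notebook itself relies on scikit-learn's `confusion_matrix`.

```python
def confusion_counts(targets, preds, num_classes):
    """Count matrix where rows are true classes and columns are predictions."""
    matrix = [[0] * num_classes for _ in range(num_classes)]
    for t, p in zip(targets, preds):
        matrix[t][p] += 1
    return matrix


def normalize_rows(matrix):
    """Divide each row by its total; rows with no true samples stay all-zero."""
    normalized = []
    for row in matrix:
        total = sum(row)
        normalized.append([count / total if total else 0.0 for count in row])
    return normalized


def top_confusions(matrix, k=3):
    """Largest off-diagonal cells as (true_class, predicted_class, count)."""
    pairs = [
        (t, p, count)
        for t, row in enumerate(matrix)
        for p, count in enumerate(row)
        if t != p and count > 0
    ]
    return sorted(pairs, key=lambda item: -item[2])[:k]


# Toy data with 4 classes; class 3 never occurs, so its row is empty.
targets = [0, 0, 0, 1, 1, 2]
preds = [0, 0, 1, 1, 1, 1]
counts = confusion_counts(targets, preds, num_classes=4)
print(normalize_rows(counts)[1])  # class 1 is always predicted correctly
print(top_confusions(counts))
```

Guarding against empty rows matters whenever a class is missing from the test labels, since dividing by a zero row total would otherwise produce NaN cells in the normalized heatmap.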


In [None]:
def plot_confusion_matrices(results, show_plots=True, save_plots=False):
    """Plot confusion matrices for gloss and category classification with enhanced analysis"""
    
    # Calculate confusion matrices
    gloss_cm = confusion_matrix(results['gloss_targets'], results['gloss_predictions'])
    cat_cm = confusion_matrix(results['cat_targets'], results['cat_predictions'])
    
    # Create subplots with better layout
    fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(20, 16))
    
    # Gloss confusion matrix (raw counts)
    sns.heatmap(gloss_cm, annot=False, fmt='d', cmap='Blues', ax=ax1)
    ax1.set_title('Gloss Classification Confusion Matrix (Counts)', fontsize=14, fontweight='bold')
    ax1.set_xlabel('Predicted Gloss Class')
    ax1.set_ylabel('True Gloss Class')
    
    # Gloss confusion matrix (row-normalized; rows with no true samples stay at 0)
    gloss_cm_norm = gloss_cm.astype('float') / np.maximum(gloss_cm.sum(axis=1), 1)[:, np.newaxis]
    sns.heatmap(gloss_cm_norm, annot=False, fmt='.3f', cmap='Blues', ax=ax2)
    ax2.set_title('Gloss Classification Confusion Matrix (Normalized)', fontsize=14, fontweight='bold')
    ax2.set_xlabel('Predicted Gloss Class')
    ax2.set_ylabel('True Gloss Class')
    
    # Category confusion matrix (raw counts)
    sns.heatmap(cat_cm, annot=True, fmt='d', cmap='Greens', ax=ax3)
    ax3.set_title('Category Classification Confusion Matrix (Counts)', fontsize=14, fontweight='bold')
    ax3.set_xlabel('Predicted Category Class')
    ax3.set_ylabel('True Category Class')
    
    # Category confusion matrix (row-normalized; rows with no true samples stay at 0)
    cat_cm_norm = cat_cm.astype('float') / np.maximum(cat_cm.sum(axis=1), 1)[:, np.newaxis]
    sns.heatmap(cat_cm_norm, annot=True, fmt='.3f', cmap='Greens', ax=ax4)
    ax4.set_title('Category Classification Confusion Matrix (Normalized)', fontsize=14, fontweight='bold')
    ax4.set_xlabel('Predicted Category Class')
    ax4.set_ylabel('True Category Class')
    
    plt.tight_layout()
    
    # Save before showing: after plt.show() the current figure may be empty
    if save_plots:
        fig.savefig('confusion_matrices.png', dpi=300, bbox_inches='tight')
        print("Confusion matrices saved as 'confusion_matrices.png'")
    
    if show_plots:
        plt.show()
    
    # Print confusion matrix statistics
    print("Confusion Matrix Statistics:")
    print(f"  Gloss matrix shape: {gloss_cm.shape}")
    print(f"  Category matrix shape: {cat_cm.shape}")
    print(f"  Gloss diagonal accuracy: {np.trace(gloss_cm) / np.sum(gloss_cm):.4f}")
    print(f"  Category diagonal accuracy: {np.trace(cat_cm) / np.sum(cat_cm):.4f}")
    
    # Additional analysis
    print(f"\nDetailed Analysis:")
    
    # Gloss analysis
    # Note: indices are positions in the sorted set of labels seen in targets/predictions
    gloss_diag = np.diag(gloss_cm)
    gloss_row_sums = np.sum(gloss_cm, axis=1)
    gloss_col_sums = np.sum(gloss_cm, axis=0)
    # Per-class recall; classes with no true samples become NaN and are skipped below
    gloss_class_acc = np.divide(gloss_diag, gloss_row_sums, out=np.full(gloss_diag.shape, np.nan), where=gloss_row_sums > 0)
    
    print(f"  Gloss Classification:")
    print(f"    - Most predicted class: {np.argmax(gloss_col_sums)} ({np.max(gloss_col_sums)} predictions)")
    print(f"    - Most frequent true class: {np.argmax(gloss_row_sums)} ({np.max(gloss_row_sums)} samples)")
    print(f"    - Most accurate class: {np.nanargmax(gloss_class_acc)} ({np.nanmax(gloss_class_acc):.3f} accuracy)")
    print(f"    - Least accurate class: {np.nanargmin(gloss_class_acc)} ({np.nanmin(gloss_class_acc):.3f} accuracy)")
    
    # Category analysis
    # Note: indices are positions in the sorted set of labels seen in targets/predictions
    cat_diag = np.diag(cat_cm)
    cat_row_sums = np.sum(cat_cm, axis=1)
    cat_col_sums = np.sum(cat_cm, axis=0)
    # Per-class recall; classes with no true samples become NaN and are skipped below
    cat_class_acc = np.divide(cat_diag, cat_row_sums, out=np.full(cat_diag.shape, np.nan), where=cat_row_sums > 0)
    
    print(f"  Category Classification:")
    print(f"    - Most predicted class: {np.argmax(cat_col_sums)} ({np.max(cat_col_sums)} predictions)")
    print(f"    - Most frequent true class: {np.argmax(cat_row_sums)} ({np.max(cat_row_sums)} samples)")
    print(f"    - Most accurate class: {np.nanargmax(cat_class_acc)} ({np.nanmax(cat_class_acc):.3f} accuracy)")
    print(f"    - Least accurate class: {np.nanargmin(cat_class_acc)} ({np.nanmin(cat_class_acc):.3f} accuracy)")
    
    return gloss_cm, cat_cm

# Plot confusion matrices
if 'results' in locals():
    gloss_cm, cat_cm = plot_confusion_matrices(results, show_plots=CONFIG['show_plots'], save_plots=CONFIG['save_plots'])
else:
    print("No results available for confusion matrix analysis. Please run evaluation first.")


## Per-Class Performance Analysis

Analyze performance for individual classes to identify which classes are performing well and which need improvement.
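
For reference, per-class precision, recall, and F1 can all be read off a confusion matrix. The sketch below computes them with NumPy on a toy 2-class matrix, returning 0 where a denominator is zero (the same convention as `zero_division=0`). The helper name `per_class_prf` and the toy values are illustrative assumptions.

```python
import numpy as np

def per_class_prf(cm):
    """Per-class precision, recall and F1 from a confusion matrix (rows = true, cols = predicted)."""
    cm = np.asarray(cm, dtype=float)
    tp = np.diag(cm)              # correct predictions per class
    predicted = cm.sum(axis=0)    # column sums: how often each class was predicted
    actual = cm.sum(axis=1)       # row sums: how often each class truly occurred
    precision = np.divide(tp, predicted, out=np.zeros_like(tp), where=predicted > 0)
    recall = np.divide(tp, actual, out=np.zeros_like(tp), where=actual > 0)
    denom = precision + recall
    f1 = np.divide(2 * precision * recall, denom, out=np.zeros_like(tp), where=denom > 0)
    return precision, recall, f1

# Toy matrix (illustrative values only)
toy_cm = np.array([[3, 1],
                   [1, 5]])
precision, recall, f1 = per_class_prf(toy_cm)
print(precision, recall, f1)
```

Precision answers "when the model predicts this class, how often is it right?", while recall answers "of the true samples of this class, how many were found?".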


In [None]:
def analyze_per_class_performance(results, num_gloss_classes, num_cat_classes):
    """Analyze per-class performance metrics with comprehensive statistics"""
    
    # Calculate per-class metrics for gloss
    gloss_report = classification_report(
        results['gloss_targets'], 
        results['gloss_predictions'], 
        output_dict=True,
        zero_division=0
    )
    
    # Calculate per-class metrics for category
    cat_report = classification_report(
        results['cat_targets'], 
        results['cat_predictions'], 
        output_dict=True,
        zero_division=0
    )
    
    # Extract per-class data for gloss
    gloss_classes = []
    gloss_precision = []
    gloss_recall = []
    gloss_f1 = []
    gloss_support = []
    
    for i in range(num_gloss_classes):
        if str(i) in gloss_report:
            gloss_classes.append(f"Gloss {i}")
            gloss_precision.append(gloss_report[str(i)]['precision'])
            gloss_recall.append(gloss_report[str(i)]['recall'])
            gloss_f1.append(gloss_report[str(i)]['f1-score'])
            gloss_support.append(gloss_report[str(i)]['support'])
        else:
            gloss_classes.append(f"Gloss {i}")
            gloss_precision.append(0.0)
            gloss_recall.append(0.0)
            gloss_f1.append(0.0)
            gloss_support.append(0)
    
    # Extract per-class data for category
    cat_classes = []
    cat_precision = []
    cat_recall = []
    cat_f1 = []
    cat_support = []
    
    for i in range(num_cat_classes):
        if str(i) in cat_report:
            cat_classes.append(f"Category {i}")
            cat_precision.append(cat_report[str(i)]['precision'])
            cat_recall.append(cat_report[str(i)]['recall'])
            cat_f1.append(cat_report[str(i)]['f1-score'])
            cat_support.append(cat_report[str(i)]['support'])
        else:
            cat_classes.append(f"Category {i}")
            cat_precision.append(0.0)
            cat_recall.append(0.0)
            cat_f1.append(0.0)
            cat_support.append(0)
    
    # Create DataFrames
    gloss_df = pd.DataFrame({
        'Class': gloss_classes,
        'Precision': gloss_precision,
        'Recall': gloss_recall,
        'F1-Score': gloss_f1,
        'Support': gloss_support
    })
    
    cat_df = pd.DataFrame({
        'Class': cat_classes,
        'Precision': cat_precision,
        'Recall': cat_recall,
        'F1-Score': cat_f1,
        'Support': cat_support
    })
    
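    # Calculate additional statistics
    # Note: despite the column name, this is the arithmetic mean of precision and recall,
    # not scikit-learn's balanced accuracy (which is the mean of per-class recall)
    gloss_df['Balanced_Accuracy'] = (gloss_df['Precision'] + gloss_df['Recall']) / 2
    cat_df['Balanced_Accuracy'] = (cat_df['Precision'] + cat_df['Recall']) / 2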
    
    # Display summary statistics
    print("PER-CLASS PERFORMANCE SUMMARY")
    print("=" * 80)
    
    # Overall statistics
    print(f"\nOverall Statistics:")
    print(f"  Gloss Classification:")
    print(f"    - Mean F1-Score: {gloss_df['F1-Score'].mean():.3f} ± {gloss_df['F1-Score'].std():.3f}")
    print(f"    - Median F1-Score: {gloss_df['F1-Score'].median():.3f}")
    print(f"    - Min F1-Score: {gloss_df['F1-Score'].min():.3f}")
    print(f"    - Max F1-Score: {gloss_df['F1-Score'].max():.3f}")
    
    print(f"  Category Classification:")
    print(f"    - Mean F1-Score: {cat_df['F1-Score'].mean():.3f} ± {cat_df['F1-Score'].std():.3f}")
    print(f"    - Median F1-Score: {cat_df['F1-Score'].median():.3f}")
    print(f"    - Min F1-Score: {cat_df['F1-Score'].min():.3f}")
    print(f"    - Max F1-Score: {cat_df['F1-Score'].max():.3f}")
    
    print(f"\nTop 10 Gloss Classes by F1-Score:")
    top_gloss = gloss_df.nlargest(10, 'F1-Score')
    print(top_gloss[['Class', 'Precision', 'Recall', 'F1-Score', 'Support']].to_string(index=False, float_format='%.3f'))
    
    print(f"\nCategory Classification Performance:")
    print(cat_df[['Class', 'Precision', 'Recall', 'F1-Score', 'Support']].to_string(index=False, float_format='%.3f'))
    
    # Identify problematic classes
    print(f"\nProblematic Gloss Classes (F1 < 0.5):")
    problematic_gloss = gloss_df[gloss_df['F1-Score'] < 0.5]
    if len(problematic_gloss) > 0:
        print(f"  Found {len(problematic_gloss)} classes with poor performance:")
        print(problematic_gloss[['Class', 'Precision', 'Recall', 'F1-Score', 'Support']].to_string(index=False, float_format='%.3f'))
    else:
        print("  None found - all classes performing well!")
    
    print(f"\nProblematic Category Classes (F1 < 0.5):")
    problematic_cat = cat_df[cat_df['F1-Score'] < 0.5]
    if len(problematic_cat) > 0:
        print(f"  Found {len(problematic_cat)} classes with poor performance:")
        print(problematic_cat[['Class', 'Precision', 'Recall', 'F1-Score', 'Support']].to_string(index=False, float_format='%.3f'))
    else:
        print("  None found - all classes performing well!")
    
    # Performance distribution analysis
    print(f"\nPerformance Distribution Analysis:")
    
    # Gloss performance categories
    excellent_gloss = len(gloss_df[gloss_df['F1-Score'] >= 0.8])
    good_gloss = len(gloss_df[(gloss_df['F1-Score'] >= 0.6) & (gloss_df['F1-Score'] < 0.8)])
    fair_gloss = len(gloss_df[(gloss_df['F1-Score'] >= 0.4) & (gloss_df['F1-Score'] < 0.6)])
    poor_gloss = len(gloss_df[gloss_df['F1-Score'] < 0.4])
    
    print(f"  Gloss Classification Performance Distribution:")
    print(f"    - Excellent (F1 ≥ 0.8): {excellent_gloss} classes ({excellent_gloss/len(gloss_df)*100:.1f}%)")
    print(f"    - Good (0.6 ≤ F1 < 0.8): {good_gloss} classes ({good_gloss/len(gloss_df)*100:.1f}%)")
    print(f"    - Fair (0.4 ≤ F1 < 0.6): {fair_gloss} classes ({fair_gloss/len(gloss_df)*100:.1f}%)")
    print(f"    - Poor (F1 < 0.4): {poor_gloss} classes ({poor_gloss/len(gloss_df)*100:.1f}%)")
    
    # Category performance categories
    excellent_cat = len(cat_df[cat_df['F1-Score'] >= 0.8])
    good_cat = len(cat_df[(cat_df['F1-Score'] >= 0.6) & (cat_df['F1-Score'] < 0.8)])
    fair_cat = len(cat_df[(cat_df['F1-Score'] >= 0.4) & (cat_df['F1-Score'] < 0.6)])
    poor_cat = len(cat_df[cat_df['F1-Score'] < 0.4])
    
    print(f"  Category Classification Performance Distribution:")
    print(f"    - Excellent (F1 ≥ 0.8): {excellent_cat} classes ({excellent_cat/len(cat_df)*100:.1f}%)")
    print(f"    - Good (0.6 ≤ F1 < 0.8): {good_cat} classes ({good_cat/len(cat_df)*100:.1f}%)")
    print(f"    - Fair (0.4 ≤ F1 < 0.6): {fair_cat} classes ({fair_cat/len(cat_df)*100:.1f}%)")
    print(f"    - Poor (F1 < 0.4): {poor_cat} classes ({poor_cat/len(cat_df)*100:.1f}%)")
    
    return gloss_df, cat_df

# Analyze per-class performance
if 'results' in locals():
    gloss_per_class, cat_per_class = analyze_per_class_performance(
        results, 
        CONFIG['num_gloss'], 
        CONFIG['num_cat']
    )
else:
    print("No results available for per-class analysis. Please run evaluation first.")


## Error Analysis

Analyze misclassification patterns to understand what types of errors the model is making.
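
The core of the confused-pair analysis is counting `(true, predicted)` pairs over the misclassified samples only. A minimal sketch with `collections.Counter`, using made-up toy labels rather than real model output:

```python
from collections import Counter
import numpy as np

# Toy labels (illustrative values only)
y_true = np.array([0, 1, 2, 2, 1, 0, 2])
y_pred = np.array([0, 2, 2, 1, 2, 0, 2])

# Boolean mask of misclassified samples
errors = y_true != y_pred

# Count each (true, predicted) pair among the errors
pair_counts = Counter(zip(y_true[errors].tolist(), y_pred[errors].tolist()))
top_pairs = pair_counts.most_common(2)

for (true_class, pred_class), count in top_pairs:
    print(f"True: {true_class} -> Pred: {pred_class} ({count} times)")
```

Note that pairs are directional: class 1 being mistaken for class 2 is counted separately from class 2 being mistaken for class 1.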


In [None]:
def analyze_errors(results, top_n=10):
    """Analyze misclassification patterns and errors with comprehensive statistics"""
    
    print("ERROR ANALYSIS")
    print("=" * 80)
    
    # Gloss misclassifications
    gloss_errors = results['gloss_targets'] != results['gloss_predictions']
    gloss_error_rate = np.mean(gloss_errors)
    
    print(f"Gloss Classification Errors:")
    print(f"  - Error rate: {gloss_error_rate:.4f} ({gloss_error_rate*100:.2f}%)")
    print(f"  - Total errors: {np.sum(gloss_errors)} out of {len(gloss_errors)}")
    print(f"  - Correct predictions: {np.sum(~gloss_errors)} out of {len(gloss_errors)}")
    
    # Category misclassifications
    cat_errors = results['cat_targets'] != results['cat_predictions']
    cat_error_rate = np.mean(cat_errors)
    
    print(f"\nCategory Classification Errors:")
    print(f"  - Error rate: {cat_error_rate:.4f} ({cat_error_rate*100:.2f}%)")
    print(f"  - Total errors: {np.sum(cat_errors)} out of {len(cat_errors)}")
    print(f"  - Correct predictions: {np.sum(~cat_errors)} out of {len(cat_errors)}")
    
    # Most confused class pairs for gloss
    print(f"\nMost Confused Gloss Class Pairs:")
    gloss_confusion_pairs = {}
    for true_class, pred_class in zip(results['gloss_targets'][gloss_errors], results['gloss_predictions'][gloss_errors]):
        pair = (true_class, pred_class)
        gloss_confusion_pairs[pair] = gloss_confusion_pairs.get(pair, 0) + 1
    
    # Sort by frequency
    sorted_gloss_pairs = sorted(gloss_confusion_pairs.items(), key=lambda x: x[1], reverse=True)
    
    print(f"  Top {min(top_n, len(sorted_gloss_pairs))} most confused pairs:")
    for i, ((true_class, pred_class), count) in enumerate(sorted_gloss_pairs[:top_n]):
        print(f"    {i+1:2d}. True: {true_class:3d} → Pred: {pred_class:3d} ({count:3d} times)")
    
    # Most confused class pairs for category
    print(f"\nMost Confused Category Class Pairs:")
    cat_confusion_pairs = {}
    for true_class, pred_class in zip(results['cat_targets'][cat_errors], results['cat_predictions'][cat_errors]):
        pair = (true_class, pred_class)
        cat_confusion_pairs[pair] = cat_confusion_pairs.get(pair, 0) + 1
    
    # Sort by frequency
    sorted_cat_pairs = sorted(cat_confusion_pairs.items(), key=lambda x: x[1], reverse=True)
    
    print(f"  Top {min(top_n, len(sorted_cat_pairs))} most confused pairs:")
    for i, ((true_class, pred_class), count) in enumerate(sorted_cat_pairs[:top_n]):
        print(f"    {i+1:2d}. True: {true_class:3d} → Pred: {pred_class:3d} ({count:3d} times)")
    
    # Confidence analysis
    print(f"\nConfidence Analysis:")
    gloss_confidences = np.max(results['gloss_probabilities'], axis=1)
    cat_confidences = np.max(results['cat_probabilities'], axis=1)
    
    print(f"  Gloss confidence statistics:")
    print(f"    - Mean: {np.mean(gloss_confidences):.4f}")
    print(f"    - Median: {np.median(gloss_confidences):.4f}")
    print(f"    - Std: {np.std(gloss_confidences):.4f}")
    print(f"    - Min: {np.min(gloss_confidences):.4f}")
    print(f"    - Max: {np.max(gloss_confidences):.4f}")
    print(f"    - 25th percentile: {np.percentile(gloss_confidences, 25):.4f}")
    print(f"    - 75th percentile: {np.percentile(gloss_confidences, 75):.4f}")
    
    print(f"  Category confidence statistics:")
    print(f"    - Mean: {np.mean(cat_confidences):.4f}")
    print(f"    - Median: {np.median(cat_confidences):.4f}")
    print(f"    - Std: {np.std(cat_confidences):.4f}")
    print(f"    - Min: {np.min(cat_confidences):.4f}")
    print(f"    - Max: {np.max(cat_confidences):.4f}")
    print(f"    - 25th percentile: {np.percentile(cat_confidences, 25):.4f}")
    print(f"    - 75th percentile: {np.percentile(cat_confidences, 75):.4f}")
    
    # Confidence vs accuracy analysis
    print(f"\nConfidence vs Accuracy Analysis:")
    
    # Define confidence thresholds
    thresholds = [0.5, 0.6, 0.7, 0.8, 0.9]
    
    for threshold in thresholds:
        high_conf_gloss = gloss_confidences > threshold
        high_conf_cat = cat_confidences > threshold
        
        if np.sum(high_conf_gloss) > 0:
            high_conf_gloss_acc = np.mean(results['gloss_targets'][high_conf_gloss] == results['gloss_predictions'][high_conf_gloss])
            print(f"  Gloss accuracy for confidence > {threshold}: {high_conf_gloss_acc:.4f} ({np.sum(high_conf_gloss)} samples)")
        
        if np.sum(high_conf_cat) > 0:
            high_conf_cat_acc = np.mean(results['cat_targets'][high_conf_cat] == results['cat_predictions'][high_conf_cat])
            print(f"  Category accuracy for confidence > {threshold}: {high_conf_cat_acc:.4f} ({np.sum(high_conf_cat)} samples)")
    
    # Error pattern analysis
    print(f"\nError Pattern Analysis:")
    
    # Analyze errors by confidence level
    low_conf_gloss = gloss_confidences < 0.5
    high_conf_gloss = gloss_confidences > 0.8
    
    if np.sum(low_conf_gloss) > 0:
        low_conf_error_rate = np.mean(gloss_errors[low_conf_gloss])
        print(f"  Low confidence gloss error rate (<0.5): {low_conf_error_rate:.4f}")
    
    if np.sum(high_conf_gloss) > 0:
        high_conf_error_rate = np.mean(gloss_errors[high_conf_gloss])
        print(f"  High confidence gloss error rate (>0.8): {high_conf_error_rate:.4f}")
    
    # Analyze errors by class frequency
    unique_gloss_targets, gloss_counts = np.unique(results['gloss_targets'], return_counts=True)
    unique_cat_targets, cat_counts = np.unique(results['cat_targets'], return_counts=True)
    
    # Find rare classes (bottom 25% by frequency)
    gloss_rare_threshold = np.percentile(gloss_counts, 25)
    cat_rare_threshold = np.percentile(cat_counts, 25)
    
    rare_gloss_classes = unique_gloss_targets[gloss_counts <= gloss_rare_threshold]
    rare_cat_classes = unique_cat_targets[cat_counts <= cat_rare_threshold]
    
    if len(rare_gloss_classes) > 0:
        rare_gloss_mask = np.isin(results['gloss_targets'], rare_gloss_classes)
        rare_gloss_error_rate = np.mean(gloss_errors[rare_gloss_mask])
        print(f"  Rare gloss classes error rate: {rare_gloss_error_rate:.4f}")
    
    if len(rare_cat_classes) > 0:
        rare_cat_mask = np.isin(results['cat_targets'], rare_cat_classes)
        rare_cat_error_rate = np.mean(cat_errors[rare_cat_mask])
        print(f"  Rare category classes error rate: {rare_cat_error_rate:.4f}")
    
    return {
        'gloss_error_rate': gloss_error_rate,
        'cat_error_rate': cat_error_rate,
        'gloss_confusion_pairs': sorted_gloss_pairs,
        'cat_confusion_pairs': sorted_cat_pairs,
        'gloss_confidences': gloss_confidences,
        'cat_confidences': cat_confidences,
        'gloss_errors': gloss_errors,
        'cat_errors': cat_errors
    }

# Analyze errors
if 'results' in locals():
    error_analysis = analyze_errors(results)
else:
    print("No results available for error analysis. Please run evaluation first.")


## Interactive Visualizations

Create interactive plots to explore the results in detail.


In [None]:
def create_interactive_plots(results, error_analysis, gloss_per_class, cat_per_class, show_plots=True):
    """Create interactive visualizations using Plotly"""
    
    if not show_plots:
        print("Interactive plots disabled in configuration")
        return
    
    # 1. Performance Metrics Comparison
    fig1 = make_subplots(
        rows=1, cols=2,
        subplot_titles=('Gloss Classification', 'Category Classification'),
        specs=[[{"type": "bar"}, {"type": "bar"}]]
    )
    
    # Gloss metrics
    gloss_metrics = ['Accuracy', 'Top-3', 'Top-5']
    gloss_values = [
        results['gloss_accuracy'],
        results['gloss_top3_accuracy'], 
        results['gloss_top5_accuracy']
    ]
    
    fig1.add_trace(
        go.Bar(x=gloss_metrics, y=gloss_values, name='Gloss', marker_color='lightblue'),
        row=1, col=1
    )
    
    # Category metrics
    cat_metrics = ['Accuracy']
    cat_values = [results['cat_accuracy']]
    
    fig1.add_trace(
        go.Bar(x=cat_metrics, y=cat_values, name='Category', marker_color='lightgreen'),
        row=1, col=2
    )
    
    fig1.update_layout(
        title="Model Performance Metrics",
        showlegend=True,
        height=400
    )
    
    fig1.show()
    
    # 2. Confidence Distribution
    fig2 = make_subplots(
        rows=1, cols=2,
        subplot_titles=('Gloss Confidence Distribution', 'Category Confidence Distribution')
    )
    
    # Gloss confidence histogram
    fig2.add_trace(
        go.Histogram(x=error_analysis['gloss_confidences'], nbinsx=30, name='Gloss Confidence'),
        row=1, col=1
    )
    
    # Category confidence histogram
    fig2.add_trace(
        go.Histogram(x=error_analysis['cat_confidences'], nbinsx=30, name='Category Confidence'),
        row=1, col=2
    )
    
    fig2.update_layout(
        title="Confidence Score Distributions",
        showlegend=True,
        height=400
    )
    
    fig2.show()
    
    # 3. Per-Class F1 Scores (Top 20 for gloss) - only if data is available
    if gloss_per_class is not None and not gloss_per_class.empty:
        top_gloss_classes = gloss_per_class.nlargest(20, 'F1-Score')
        
        fig3 = go.Figure()
        fig3.add_trace(go.Bar(
            x=top_gloss_classes['Class'],
            y=top_gloss_classes['F1-Score'],
            name='F1-Score',
            marker_color='lightcoral'
        ))
        
        fig3.update_layout(
            title="Top 20 Gloss Classes by F1-Score",
            xaxis_title="Class",
            yaxis_title="F1-Score",
            height=500
        )
        
        fig3.show()
    else:
        print("Skipping per-class F1 plot - data not available")
    
    # 4. Error Analysis - Most Confused Pairs
    if len(error_analysis['gloss_confusion_pairs']) > 0:
        top_confused = error_analysis['gloss_confusion_pairs'][:10]
        true_classes = [f"True: {pair[0][0]}" for pair in top_confused]
        pred_classes = [f"Pred: {pair[0][1]}" for pair in top_confused]
        counts = [pair[1] for pair in top_confused]
        
        fig4 = go.Figure()
        fig4.add_trace(go.Bar(
            x=[f"{t} → {p}" for t, p in zip(true_classes, pred_classes)],
            y=counts,
            name='Confusion Count',
            marker_color='orange'
        ))
        
        fig4.update_layout(
            title="Top 10 Most Confused Gloss Class Pairs",
            xaxis_title="True → Predicted",
            yaxis_title="Count",
            height=500,
            xaxis_tickangle=-45
        )
        
        fig4.show()
    
    print("Interactive plots created successfully!")
    print("Hover over the plots to see detailed information")
    print("Use the toolbar to zoom, pan, and download plots")

# Create interactive visualizations (per-class data is optional)
if 'results' in locals() and 'error_analysis' in locals():
    # Check if per-class data is available
    gloss_per_class_data = gloss_per_class if 'gloss_per_class' in locals() else None
    cat_per_class_data = cat_per_class if 'cat_per_class' in locals() else None
    
    create_interactive_plots(
        results, 
        error_analysis, 
        gloss_per_class_data, 
        cat_per_class_data, 
        show_plots=CONFIG['show_plots']
    )
else:
    print("Skipping interactive plots - missing required data (results or error_analysis)")

## Model Comparison (Optional)

Compare multiple models side-by-side if you have multiple checkpoints available.


In [None]:
# Optional: Model Comparison
# Uncomment and modify this section if you want to compare multiple models

def compare_models(model_configs, test_dataloader, device):
    """Compare multiple models side-by-side"""
    
    print("🔄 Comparing multiple models...")
    
    comparison_results = []
    
    for config in model_configs:
        print(f"\n📊 Evaluating {config['name']}...")
        
        # Load model
        model, checkpoint = load_model(
            model_type=config['model_type'],
            model_path=ROOT / config['model_path'],
            device=device,
            **{k: v for k, v in config.items() if k in ['num_gloss', 'num_cat', 'hidden1', 'hidden2', 'dropout']}
        )
        
        # Evaluate model
        results = evaluate_model_comprehensive(model, test_dataloader, device, config['model_type'])
        
        # Store results
        comparison_results.append({
            'name': config['name'],
            'model_type': config['model_type'],
            'gloss_accuracy': results['gloss_accuracy'],
            'cat_accuracy': results['cat_accuracy'],
            'gloss_top3_accuracy': results['gloss_top3_accuracy'],
            'gloss_top5_accuracy': results['gloss_top5_accuracy'],
            'loss': results['loss']
        })
    
    # Create comparison DataFrame
    comparison_df = pd.DataFrame(comparison_results)
    
    print("\n📊 MODEL COMPARISON RESULTS")
    print("=" * 80)
    print(comparison_df.to_string(index=False, float_format='%.4f'))
    
    # Create comparison plot
    if CONFIG['show_plots']:
        fig = make_subplots(
            rows=2, cols=2,
            subplot_titles=('Gloss Accuracy', 'Category Accuracy', 'Top-3 Accuracy', 'Loss'),
            specs=[[{"type": "bar"}, {"type": "bar"}], [{"type": "bar"}, {"type": "bar"}]]
        )
        
        # Gloss accuracy
        fig.add_trace(
            go.Bar(x=comparison_df['name'], y=comparison_df['gloss_accuracy'], name='Gloss Accuracy'),
            row=1, col=1
        )
        
        # Category accuracy
        fig.add_trace(
            go.Bar(x=comparison_df['name'], y=comparison_df['cat_accuracy'], name='Category Accuracy'),
            row=1, col=2
        )
        
        # Top-3 accuracy
        fig.add_trace(
            go.Bar(x=comparison_df['name'], y=comparison_df['gloss_top3_accuracy'], name='Top-3 Accuracy'),
            row=2, col=1
        )
        
        # Loss
        fig.add_trace(
            go.Bar(x=comparison_df['name'], y=comparison_df['loss'], name='Loss'),
            row=2, col=2
        )
        
        fig.update_layout(
            title="Model Comparison",
            showlegend=False,
            height=600
        )
        
        fig.show()
    
    return comparison_df

# Example model comparison (uncomment to use)
# model_configs = [
#     {
#         'name': 'Transformer Best',
#         'model_type': 'transformer',
#         'model_path': 'data/processed/SignTransformer_best.pt',
#         'num_gloss': 105,
#         'num_cat': 10
#     },
#     {
#         'name': 'IV3-GRU Best',
#         'model_type': 'iv3_gru',
#         'model_path': 'data/processed/InceptionV3GRU_best.pt',
#         'num_gloss': 105,
#         'num_cat': 10,
#         'hidden1': 16,
#         'hidden2': 12,
#         'dropout': 0.3
#     }
# ]
# 
# comparison_df = compare_models(model_configs, test_dataloader, CONFIG['device'])

print("💡 To compare multiple models, uncomment and modify the model_configs section above")


## Export Results

Export evaluation results to files for further analysis and reporting.
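
One detail worth knowing when writing metrics to JSON: NumPy scalars such as `np.int64` or `np.float32` are not JSON-serializable by default, and a fallback of `default=str` stores them as strings (for example `"7"` instead of `7`). The sketch below shows one possible alternative that keeps numbers numeric. The helper name `to_builtin` and the payload values are illustrative assumptions, not part of this project.

```python
import json
import numpy as np

def to_builtin(obj):
    """Fallback for json.dumps: convert NumPy scalars/arrays to plain Python types."""
    if isinstance(obj, np.generic):
        return obj.item()      # e.g. np.int64(7) -> 7
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    return str(obj)            # last resort (e.g. Path objects)

# Toy payload (illustrative values only)
payload = {
    'accuracy': np.float32(0.5),
    'n_samples': np.int64(7),
    'confusion_pairs': [((np.int64(1), np.int64(2)), 3)],
}
encoded = json.dumps(payload, default=to_builtin)
decoded = json.loads(encoded)
print(decoded)
```

Tuples are written as JSON arrays, so confusion pairs come back as nested lists when the file is reloaded.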


In [None]:
def export_results(results, summary_df, gloss_per_class, cat_per_class, error_analysis, config, output_dir='evaluation_results'):
    """Export all evaluation results to files"""
    
    import os
    from datetime import datetime
    
    # Create output directory
    os.makedirs(output_dir, exist_ok=True)
    
    # Generate timestamp
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    
    print(f"Exporting results to {output_dir}/")
    
    # Export summary metrics
    summary_file = f"{output_dir}/summary_metrics_{timestamp}.csv"
    summary_df.to_csv(summary_file, index=False)
    print(f"  - Summary metrics: {summary_file}")
    
    # Export per-class performance
    gloss_per_class_file = f"{output_dir}/gloss_per_class_{timestamp}.csv"
    cat_per_class_file = f"{output_dir}/cat_per_class_{timestamp}.csv"
    gloss_per_class.to_csv(gloss_per_class_file, index=False)
    cat_per_class.to_csv(cat_per_class_file, index=False)
    print(f"  - Gloss per-class: {gloss_per_class_file}")
    print(f"  - Category per-class: {cat_per_class_file}")
    
    # Export detailed results
    detailed_results = {
        'config': config,
        'timestamp': timestamp,
        'model_type': config['model_type'],
        'test_samples': len(results['gloss_targets']),
        'gloss_accuracy': results['gloss_accuracy'],
        'cat_accuracy': results['cat_accuracy'],
        'gloss_top3_accuracy': results['gloss_top3_accuracy'],
        'gloss_top5_accuracy': results['gloss_top5_accuracy'],
        'gloss_kappa': results['gloss_kappa'],
        'cat_kappa': results['cat_kappa'],
        'gloss_mcc': results['gloss_mcc'],
        'cat_mcc': results['cat_mcc'],
        'gloss_ci_lower': results['gloss_ci_lower'],
        'gloss_ci_upper': results['gloss_ci_upper'],
        'cat_ci_lower': results['cat_ci_lower'],
        'cat_ci_upper': results['cat_ci_upper'],
        'gloss_error_rate': error_analysis['gloss_error_rate'],
        'cat_error_rate': error_analysis['cat_error_rate'],
        'gloss_confusion_pairs': error_analysis['gloss_confusion_pairs'][:20],  # Top 20
        'cat_confusion_pairs': error_analysis['cat_confusion_pairs'][:20],  # Top 20
    }
    
    results_file = f"{output_dir}/detailed_results_{timestamp}.json"
    with open(results_file, 'w') as f:
        json.dump(detailed_results, f, indent=2, default=str)
    print(f"  - Detailed results: {results_file}")
    
    # Export predictions
    predictions_df = pd.DataFrame({
        'gloss_target': results['gloss_targets'],
        'gloss_prediction': results['gloss_predictions'],
        'cat_target': results['cat_targets'],
        'cat_prediction': results['cat_predictions'],
        'gloss_confidence': np.max(results['gloss_probabilities'], axis=1),
        'cat_confidence': np.max(results['cat_probabilities'], axis=1),
        'gloss_correct': results['gloss_targets'] == results['gloss_predictions'],
        'cat_correct': results['cat_targets'] == results['cat_predictions']
    })
    
    predictions_file = f"{output_dir}/predictions_{timestamp}.csv"
    predictions_df.to_csv(predictions_file, index=False)
    print(f"  - Predictions: {predictions_file}")
    
    # Create a comprehensive report
    report_file = f"{output_dir}/evaluation_report_{timestamp}.txt"
    with open(report_file, 'w', encoding='utf-8') as f:  # UTF-8 so non-ASCII characters are written safely
        f.write("MODEL EVALUATION REPORT\n")
        f.write("=" * 50 + "\n\n")
        f.write(f"Model Type: {config['model_type'].upper()}\n")
        f.write(f"Evaluation Date: {timestamp}\n")
        f.write(f"Test Samples: {len(results['gloss_targets'])}\n\n")
        
        f.write("PERFORMANCE METRICS\n")
        f.write("-" * 30 + "\n")
        f.write(f"Gloss Accuracy: {results['gloss_accuracy']:.4f} [{results['gloss_ci_lower']:.4f}, {results['gloss_ci_upper']:.4f}]\n")
        f.write(f"Category Accuracy: {results['cat_accuracy']:.4f} [{results['cat_ci_lower']:.4f}, {results['cat_ci_upper']:.4f}]\n")
        f.write(f"Gloss Top-3 Accuracy: {results['gloss_top3_accuracy']:.4f}\n")
        f.write(f"Gloss Top-5 Accuracy: {results['gloss_top5_accuracy']:.4f}\n")
        f.write(f"Gloss Cohen's Kappa: {results['gloss_kappa']:.4f}\n")
        f.write(f"Category Cohen's Kappa: {results['cat_kappa']:.4f}\n")
        f.write(f"Gloss Matthews Correlation: {results['gloss_mcc']:.4f}\n")
        f.write(f"Category Matthews Correlation: {results['cat_mcc']:.4f}\n\n")
        
        f.write("ERROR ANALYSIS\n")
        f.write("-" * 30 + "\n")
        f.write(f"Gloss Error Rate: {error_analysis['gloss_error_rate']:.4f}\n")
        f.write(f"Category Error Rate: {error_analysis['cat_error_rate']:.4f}\n")
        f.write(f"Total Gloss Errors: {np.sum(error_analysis['gloss_errors'])}\n")
        f.write(f"Total Category Errors: {np.sum(error_analysis['cat_errors'])}\n\n")
        
        f.write("TOP CONFUSED CLASS PAIRS\n")
        f.write("-" * 30 + "\n")
        f.write("Gloss Classification:\n")
        for i, ((true_class, pred_class), count) in enumerate(error_analysis['gloss_confusion_pairs'][:10]):
            f.write(f"  {i+1:2d}. True: {true_class:3d} → Pred: {pred_class:3d} ({count:3d} times)\n")
        
        f.write("\nCategory Classification:\n")
        for i, ((true_class, pred_class), count) in enumerate(error_analysis['cat_confusion_pairs'][:10]):
            f.write(f"  {i+1:2d}. True: {true_class:3d} → Pred: {pred_class:3d} ({count:3d} times)\n")
    
    print(f"  - Evaluation report: {report_file}")
    
    print(f"\nExport completed! All files saved to {output_dir}/")
    return output_dir

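# Export results (only when every input produced by the earlier cells exists)
_required = ['results', 'summary_df', 'gloss_per_class', 'cat_per_class', 'error_analysis', 'CONFIG']
if all(name in globals() for name in _required):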
    export_dir = export_results(
        results, 
        summary_df, 
        gloss_per_class, 
        cat_per_class, 
        error_analysis, 
        CONFIG
    )
else:
    print("No results available for export. Please run evaluation first.")


## Summary and Next Steps

This notebook evaluates your trained FSLR models on test data, reporting accuracy with confidence intervals, agreement statistics, per-class metrics, and error analysis.

### What we evaluated:
- **Overall performance**: Accuracy with confidence intervals, loss, and top-3/top-5 gloss accuracy
- **Statistical analysis**: Cohen's Kappa, Matthews Correlation Coefficient, bootstrap confidence intervals
- **Per-class analysis**: Precision, recall, F1-score with performance distribution analysis
- **Confusion matrices**: Both raw counts and normalized versions with detailed analysis
- **Comprehensive error analysis**: Misclassification patterns, confidence analysis, and error rate by class frequency
- **Interactive visualizations**: Detailed plots for exploration and analysis
- **Export functionality**: Complete results export for further analysis and reporting
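
As a rough illustration of the bootstrap confidence intervals listed above, the sketch below uses a percentile bootstrap on toy labels. The helper name `bootstrap_accuracy_ci` and its parameters are hypothetical, and the percentile method is an assumption; the implementation earlier in this notebook may differ.

```python
import numpy as np

def bootstrap_accuracy_ci(y_true, y_pred, n_boot=1000, alpha=0.05, seed=0):
    """Percentile-bootstrap confidence interval for accuracy (illustrative helper)."""
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    correct = (y_true == y_pred).astype(float)  # 1.0 where the prediction is right
    rng = np.random.default_rng(seed)
    n = len(correct)
    # Resample test-sample indices with replacement; one row per bootstrap replicate
    idx = rng.integers(0, n, size=(n_boot, n))
    boot_acc = correct[idx].mean(axis=1)
    lower, upper = np.quantile(boot_acc, [alpha / 2, 1 - alpha / 2])
    return correct.mean(), lower, upper

# Toy labels for demonstration only, not real evaluation results
y_true = np.array([0, 1, 2, 2, 1, 0, 1, 2, 0, 1])
y_pred = np.array([0, 1, 2, 1, 1, 0, 2, 2, 0, 1])
acc, lo, hi = bootstrap_accuracy_ci(y_true, y_pred)
print(f"accuracy={acc:.2f}, 95% CI=[{lo:.2f}, {hi:.2f}]")
```

A wide interval on a small test set is a sign that differences between two models may not be meaningful.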

### Key insights you can gain:
- Model performance on test data with statistical confidence
- Which classes are performing well/poorly with detailed metrics
- Common misclassification patterns and their frequencies
- Confidence score distributions and their relationship to accuracy
- Performance distribution across all classes
- Error patterns by confidence level and class frequency

### Next steps you might consider:
1. **Improve problematic classes**: Focus on classes with low F1-scores or high error rates
2. **More or augmented data**: Collect or augment training samples for frequently confused class pairs
3. **Model architecture**: Try different architectures or hyperparameters
4. **Ensemble methods**: Combine multiple models for better performance
5. **Error analysis**: Investigate specific misclassified examples using the exported predictions
6. **Statistical validation**: Use the confidence intervals to assess model reliability
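
For the ensemble idea in the list above, one simple option is to average the class probabilities of several models. The sketch below is a minimal version under assumptions: `ensemble_predict` is a hypothetical helper, the toy logits stand in for per-model gloss outputs, and it presumes all models share the same class indexing.

```python
import numpy as np

def softmax(logits):
    """Row-wise softmax, subtracting the row maximum for numerical stability."""
    z = logits - logits.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

def ensemble_predict(logits_list, weights=None):
    """Average per-model class probabilities; return (predictions, probabilities)."""
    # Stacked shape: (n_models, n_samples, n_classes)
    probs = np.stack([softmax(np.asarray(l, dtype=float)) for l in logits_list])
    avg = np.average(probs, axis=0, weights=weights)  # optional per-model weights
    return avg.argmax(axis=1), avg

# Toy logits for two models, two samples, three classes (not real outputs)
logits_a = np.array([[2.0, 0.5, 0.1], [0.2, 0.3, 0.1]])
logits_b = np.array([[1.5, 0.2, 0.3], [0.1, 0.2, 2.5]])
preds, probs = ensemble_predict([logits_a, logits_b])
print(preds)
```

Averaging probabilities rather than raw logits keeps each model's contribution on the same scale; whether it helps here would need to be checked on held-out data.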

### Tips for using this notebook:
- Change the configuration cell to evaluate different models
- Use the model comparison section to compare multiple models
- Run the export cell after evaluation to save all results in one step
- Modify the analysis functions to focus on specific aspects
- Use the exported CSV files for additional analysis in other tools
- Check the evaluation report for a comprehensive summary

### Files generated:
- `summary_metrics_[timestamp].csv`: Main performance metrics
- `gloss_per_class_[timestamp].csv`: Per-class gloss performance
- `cat_per_class_[timestamp].csv`: Per-class category performance
- `predictions_[timestamp].csv`: All predictions with confidence scores
- `detailed_results_[timestamp].json`: Complete results in JSON format
- `evaluation_report_[timestamp].txt`: Human-readable summary report
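
To pick up the most recent export in a later session, a small lookup like the one below may help. It is a sketch under assumptions: `latest_export` is a hypothetical helper, `evaluation_results` is a placeholder for whatever directory `export_results` returned, and files are ordered by modification time because the timestamp format is not assumed.

```python
from pathlib import Path

# File prefixes and extensions taken from the list above
EXPORTS = [
    ("summary_metrics", ".csv"),
    ("gloss_per_class", ".csv"),
    ("cat_per_class", ".csv"),
    ("predictions", ".csv"),
    ("detailed_results", ".json"),
    ("evaluation_report", ".txt"),
]

def latest_export(output_dir, prefix, suffix):
    """Return the newest file named `<prefix>_*<suffix>` in output_dir, or None."""
    out = Path(output_dir)
    if not out.is_dir():
        return None
    matches = sorted(out.glob(f"{prefix}_*{suffix}"), key=lambda p: p.stat().st_mtime)
    return matches[-1] if matches else None

# Placeholder directory name; replace with the path returned by export_results
for prefix, suffix in EXPORTS:
    print(prefix, "->", latest_export("evaluation_results", prefix, suffix))
```

The returned paths can then be passed to `pd.read_csv` or `json.load` for further analysis.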

**Happy evaluating! 🎉**
