In [1]:
#!/usr/bin/env python3
"""
Per-Class Accuracy Analysis Script
Analyzes the per-class accuracy for each model across all test folds
"""

import numpy as np
import pandas as pd
import os

# Configuration
OUTPUT_ROOT = 'output'
METRICS_DIR = os.path.join(OUTPUT_ROOT, 'metrics')  # folder holding the per-fold metric .npy files
N_FOLDS = 5     # cross-validation folds; metric files are numbered 1..N_FOLDS
N_CLASSES = 7   # number of classes (one row per class in the output table)

# Model identifiers exactly as they appear in the metric file names,
# listed in the column order used for the output table.
MODEL_SET = [
    'FS_PCA_NN',
    'FS_PCA_SVM',
    'RF',
    'FS_PCA_QDA',
    'FS_PCA_KNN',
    'FS_PCA_LR',
]

# Short labels shown as table column headers and in the summary printouts.
MODEL_DISPLAY_NAMES = dict(
    FS_PCA_NN='NN',
    FS_PCA_SVM='SVM',
    RF='RF',
    FS_PCA_QDA='QDA',
    FS_PCA_KNN='KNN',
    FS_PCA_LR='LR',
)

def load_per_class_accuracies(model_set=None, n_folds=None, metrics_dir=None):
    """
    Load per-class accuracies (recall) for all models and folds.

    Note: per-class recall = per-class accuracy (TP / (TP + FN)).

    Each fold is read from ``<metrics_dir>/per_class_recall_<model>_<fold>.npy``
    with folds numbered 1..n_folds. Missing fold files are reported and skipped.

    Args:
        model_set: model names to load; defaults to MODEL_SET.
        n_folds: number of folds to look for; defaults to N_FOLDS.
        metrics_dir: directory containing the .npy files; defaults to METRICS_DIR.

    Returns:
        dict mapping model name -> array stacking that model's loaded folds
        along axis 0. Models for which no fold file exists are omitted from
        the dict (an error line is printed instead).

    Raises:
        ValueError: if the fold arrays of one model do not all share the same
            shape (they could not be averaged fold-wise).
    """
    if model_set is None:
        model_set = MODEL_SET
    if n_folds is None:
        n_folds = N_FOLDS
    if metrics_dir is None:
        metrics_dir = METRICS_DIR

    all_accuracies = {}

    for model_name in model_set:
        model_accuracies = []

        for fold_idx in range(1, n_folds + 1):
            # Load per-class recall which is equivalent to per-class accuracy
            filename = f"per_class_recall_{model_name}_{fold_idx}.npy"
            filepath = os.path.join(metrics_dir, filename)

            if os.path.exists(filepath):
                model_accuracies.append(np.load(filepath))
            else:
                print(f"Warning: File not found - {filename}")

        if not model_accuracies:
            print(f"Error: No data found for model {model_name}")
            continue

        # Fold arrays must agree in shape to be stacked and averaged; fail with
        # a clear message instead of an opaque numpy error / ragged object array.
        shapes = {acc.shape for acc in model_accuracies}
        if len(shapes) != 1:
            raise ValueError(
                f"Inconsistent per-class array shapes for model {model_name}: "
                f"{sorted(shapes)}"
            )

        # Statistics for this model will be computed over fewer folds than
        # requested; make that visible rather than silently reporting N folds.
        if len(model_accuracies) < n_folds:
            print(f"Warning: only {len(model_accuracies)} of {n_folds} folds "
                  f"found for model {model_name}")

        # Stack all folds: shape (n_loaded_folds, n_classes) for 1-D per-fold files
        all_accuracies[model_name] = np.stack(model_accuracies)

    return all_accuracies

def calculate_statistics(all_accuracies):
    """
    Reduce each model's fold-by-class accuracy matrix to per-class mean and std.

    Args:
        all_accuracies: dict mapping model name -> array-like whose axis 0 is
            the fold axis (e.g. the output of load_per_class_accuracies()).

    Returns:
        dict mapping model name -> {'mean': ndarray, 'std': ndarray}, both
        reduced over axis 0. The std is NumPy's default population std (ddof=0).

    Side effects:
        For every model, writes ``<model>_per_class_acc_mean.npy`` and
        ``<model>_per_class_acc_std.npy`` to the current working directory
        and prints a confirmation line.
    """
    summary_by_model = {}

    for model_name, fold_matrix in all_accuracies.items():
        # Collapse the fold axis so one value remains per class.
        per_class_mean = np.mean(fold_matrix, axis=0)
        per_class_std = np.std(fold_matrix, axis=0)

        np.save(f"{model_name}_per_class_acc_mean.npy", per_class_mean)
        np.save(f"{model_name}_per_class_acc_std.npy", per_class_std)
        print(f"Saved statistics for {model_name}")

        summary_by_model[model_name] = {'mean': per_class_mean, 'std': per_class_std}

    return summary_by_model

def create_formatted_table(statistics, model_set=None, n_classes=None, display_names=None):
    """
    Create a formatted table showing "mean ± std" for each class and model.

    Args:
        statistics: dict mapping model name -> {'mean': array, 'std': array},
            as returned by calculate_statistics().
        model_set: models to show, in column order; defaults to MODEL_SET.
            Models absent from ``statistics`` (e.g. no files could be loaded)
            are skipped instead of raising KeyError.
        n_classes: number of class rows; defaults to N_CLASSES.
        display_names: model name -> column header; defaults to
            MODEL_DISPLAY_NAMES. Models without an entry use their raw name.

    Returns:
        pandas DataFrame with a 'Class' column ("Class 1", "Class 2", ...)
        followed by one column per model; cells are formatted with 3 decimals.
    """
    if model_set is None:
        model_set = MODEL_SET
    if n_classes is None:
        n_classes = N_CLASSES
    if display_names is None:
        display_names = MODEL_DISPLAY_NAMES

    # load_per_class_accuracies() omits models with no data, so only tabulate
    # models that actually have statistics.
    models = [name for name in model_set if name in statistics]

    table_data = []
    for class_idx in range(n_classes):  # one row per class
        row_data = {'Class': f'Class {class_idx + 1}'}

        for model_name in models:  # one column per model
            mean_val = statistics[model_name]['mean'][class_idx]
            std_val = statistics[model_name]['std'][class_idx]
            column = display_names.get(model_name, model_name)
            row_data[column] = f"{mean_val:.3f} ± {std_val:.3f}"

        table_data.append(row_data)

    return pd.DataFrame(table_data)

def print_detailed_statistics(statistics, model_set=None, n_classes=None):
    """
    Print detailed per-class statistics for verification.

    For each model: per-class "mean ± std" with 4 decimals, the mean of the
    per-class means, and the mean of the per-class standard deviations.

    Args:
        statistics: dict mapping model name -> {'mean': array, 'std': array},
            as returned by calculate_statistics().
        model_set: models to print, in order; defaults to MODEL_SET. Models
            absent from ``statistics`` are skipped instead of raising KeyError.
        n_classes: number of classes to print; defaults to N_CLASSES.
    """
    if model_set is None:
        model_set = MODEL_SET
    if n_classes is None:
        n_classes = N_CLASSES

    print("\n" + "="*80)
    print("DETAILED PER-CLASS ACCURACY STATISTICS")
    print("="*80)

    for model_name in model_set:
        # Models with no loaded data have no statistics entry.
        if model_name not in statistics:
            continue

        print(f"\n{model_name}:")
        print("-" * 40)

        mean_acc = statistics[model_name]['mean']
        std_acc = statistics[model_name]['std']

        for class_idx in range(n_classes):
            print(f"  Class {class_idx + 1}: {mean_acc[class_idx]:.4f} ± {std_acc[class_idx]:.4f}")

        # Aggregates across classes. The second line is the average of the
        # per-class stds, not the std of any pooled quantity, hence the label.
        print(f"  Overall mean: {np.mean(mean_acc):.4f}")
        print(f"  Mean per-class std: {np.mean(std_acc):.4f}")

def save_latex_table(df, filename='per_class_accuracy_table.tex'):
    """
    Save the table in LaTeX format.

    Args:
        df: table to export (anything with a pandas-style ``to_latex``), e.g.
            the output of create_formatted_table().
        filename: destination path for the .tex file.

    The cells contain the non-ASCII "±" character and are written unescaped
    (``escape=False``), so the file is written explicitly as UTF-8 rather than
    in the platform's default encoding, which may not be able to encode it.
    """
    latex_table = df.to_latex(index=False, escape=False)
    with open(filename, 'w', encoding='utf-8') as f:
        f.write(latex_table)
    print(f"\nLaTeX table saved to {filename}")

def _print_best_model_per_class(statistics, models):
    """Print, per class, the model with the highest mean accuracy (first listed wins ties)."""
    print("\nBest model for each class (based on mean accuracy):")
    for class_idx in range(N_CLASSES):
        best_acc = -1
        best_model = ""

        for model_name in models:
            mean_acc = statistics[model_name]['mean'][class_idx]
            if mean_acc > best_acc:
                best_acc = mean_acc
                best_model = MODEL_DISPLAY_NAMES[model_name]

        print(f"  Class {class_idx + 1}: {best_model} ({best_acc:.3f})")


def _print_stability_ranking(statistics, models):
    """Print models ordered by average per-class std, lowest (most stable) first."""
    print("\nModel stability ranking (based on average std across classes):")
    stability_scores = sorted(
        ((MODEL_DISPLAY_NAMES[name], np.mean(statistics[name]['std'])) for name in models),
        key=lambda item: item[1],
    )
    for rank, (model, avg_std) in enumerate(stability_scores, 1):
        print(f"  {rank}. {model}: {avg_std:.4f}")


def _print_class_difficulty(statistics, models):
    """Print classes ordered by mean accuracy averaged over models, highest (easiest) first."""
    print("\nClass difficulty ranking (based on average accuracy across all models):")
    class_difficulties = [
        (f"Class {class_idx + 1}",
         np.mean([statistics[name]['mean'][class_idx] for name in models]))
        for class_idx in range(N_CLASSES)
    ]
    class_difficulties.sort(key=lambda item: item[1], reverse=True)
    for rank, (class_name, avg_acc) in enumerate(class_difficulties, 1):
        print(f"  {rank}. {class_name}: {avg_acc:.3f}")


def main():
    """
    Main analysis function.

    Loads per-fold per-class recall for every model, computes per-class mean
    and std across folds (calculate_statistics also saves them as .npy),
    prints the "mean ± std" table and saves it as CSV and LaTeX, prints the
    detailed per-model statistics, and finishes with three summaries: best
    model per class, model stability ranking, and class difficulty ranking.
    """
    print("Per-Class Accuracy Analysis")
    print("=" * 50)

    # Load per-class accuracies
    print("\nLoading per-class accuracies from output folder...")
    all_accuracies = load_per_class_accuracies()

    if not all_accuracies:
        print("Error: No data loaded. Please check the output folder.")
        return

    # Calculate statistics
    print("\nCalculating mean and standard deviation across folds...")
    statistics = calculate_statistics(all_accuracies)

    # Models whose fold files were all missing are absent from `statistics`;
    # the summaries below only cover models that actually loaded.
    models = [name for name in MODEL_SET if name in statistics]

    # Create formatted table
    print("\nCreating formatted table...")
    table_df = create_formatted_table(statistics)

    # Report the number of folds actually loaded rather than assuming N_FOLDS,
    # because missing fold files are skipped during loading.
    fold_counts = sorted({len(acc) for acc in all_accuracies.values()})
    if len(fold_counts) == 1:
        folds_label = str(fold_counts[0])
    else:
        folds_label = f"{fold_counts[0]}-{fold_counts[-1]}"

    # Display table
    print("\n" + "="*80)
    print("PER-CLASS ACCURACY TABLE (mean ± std across {} folds)".format(folds_label))
    print("="*80)
    print(table_df.to_string(index=False))

    # Save table to CSV
    csv_filename = 'per_class_accuracy_table.csv'
    table_df.to_csv(csv_filename, index=False)
    print(f"\nTable saved to {csv_filename}")

    # Save LaTeX version
    save_latex_table(table_df)

    # Print detailed statistics
    print_detailed_statistics(statistics)

    # Additional analysis
    print("\n" + "="*80)
    print("SUMMARY ANALYSIS")
    print("="*80)

    _print_best_model_per_class(statistics, models)
    _print_stability_ranking(statistics, models)
    _print_class_difficulty(statistics, models)

    print("\n" + "="*80)
    print("Analysis complete!")
    print("="*80)


if __name__ == "__main__":
    main()

Per-Class Accuracy Analysis

Loading per-class accuracies from output folder...

Calculating mean and standard deviation across folds...
Saved statistics for FS_PCA_NN
Saved statistics for FS_PCA_SVM
Saved statistics for RF
Saved statistics for FS_PCA_QDA
Saved statistics for FS_PCA_KNN
Saved statistics for FS_PCA_LR

Creating formatted table...

PER-CLASS ACCURACY TABLE (mean ± std across 5 folds)
  Class            NN           SVM            RF           QDA           KNN            LR
Class 1 0.859 ± 0.029 0.941 ± 0.053 0.894 ± 0.101 0.918 ± 0.060 0.941 ± 0.053 0.847 ± 0.080
Class 2 0.812 ± 0.044 0.812 ± 0.069 0.812 ± 0.058 0.718 ± 0.044 0.788 ± 0.088 0.765 ± 0.064
Class 3 0.824 ± 0.105 0.882 ± 0.098 0.859 ± 0.109 0.894 ± 0.108 0.788 ± 0.096 0.824 ± 0.112
Class 4 0.859 ± 0.029 0.929 ± 0.044 0.918 ± 0.071 0.800 ± 0.080 0.918 ± 0.029 0.882 ± 0.037
Class 5 0.953 ± 0.044 0.965 ± 0.047 0.953 ± 0.024 0.976 ± 0.029 0.965 ± 0.029 0.941 ± 0.037
Class 6 0.847 ± 0.103 0.824 ± 0.098 0.788 ± 0.