# Comprehensive Evaluation Suite

This notebook provides comprehensive evaluation tools for analyzing model performance across different domains.

## Features:
1. **Per-Class Metrics**: F1, precision, and recall for each class in each domain
2. **Confusion Matrices**: Visual confusion matrices (raw counts and row-normalized) for each domain
3. **Feature Visualization**: t-SNE visualization of the feature space (UMAP is not implemented in the cells below)
4. **Domain Classification**: Analyze domain classification accuracy (planned; no function for this is defined in the cells below)
5. **Failure Analysis**: Detailed analysis of misclassifications
6. **Comparison Tools**: Compare multiple models side-by-side (planned; no function for this is defined in the cells below)

## Usage:
Load a trained model and evaluation results, then run the analysis cells.


In [11]:
# Imports and Setup
import os
from pathlib import Path
import json
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import (
    classification_report, confusion_matrix, 
    f1_score, precision_score, recall_score
)
from sklearn.manifold import TSNE
import torch
import torch.nn as nn

# Paths
# The notebook may be started from the project root or from its "experiment_2"
# sub-directory; in the latter case the project root is one level up.
current_dir = Path.cwd()
BASE_DIR = current_dir.parent if current_dir.name == "experiment_2" else current_dir

METADATA_DIR = BASE_DIR / "metadata"
LABEL_MAPPING_PATH = METADATA_DIR / "label_mapping.json"
MODELS_DIR = BASE_DIR / "models"

# Load label mappings.
# JSON is UTF-8 by specification, so the encoding is pinned here instead of
# being left to the platform default (which is not UTF-8 on many Windows setups).
with open(LABEL_MAPPING_PATH, "r", encoding="utf-8") as f:
    label_mapping = json.load(f)

# Lookup tables in both directions. The analysis helpers in this notebook
# enumerate classes as range(len(id_to_label)), i.e. they expect the ids to be
# the integers 0..num_classes-1.
id_to_label = {c["id"]: c["canonical_label"] for c in label_mapping["classes"]}
label_to_id = {v: k for k, v in id_to_label.items()}
num_classes = len(label_mapping["classes"])

print(f"Loaded {num_classes} classes")
print(f"Base directory: {BASE_DIR}")


Loaded 39 classes
Base directory: d:\Programming\Seminar_Project\PLANT_LEAF_DISEASE_DETECTION


## Function 1: Per-Class Metrics Analysis


In [12]:
def compute_per_class_metrics(targets, preds, id_to_label):
    """Compute per-class precision, recall, F1 and support.

    Args:
        targets: Ground-truth class ids (array-like).
        preds: Predicted class ids, aligned element-wise with ``targets``.
        id_to_label: Mapping from class id to class name. Classes are
            enumerated as ``range(len(id_to_label))``.

    Returns:
        pandas.DataFrame with one row per class id and the columns
        ``class_id``, ``class_name``, ``precision``, ``recall``, ``f1`` and
        ``support``. Undefined ratios (a class that never occurs) are
        reported as 0 because of ``zero_division=0``.
    """
    class_ids = list(range(len(id_to_label)))

    # target_names is deliberately NOT passed: the report dict is then keyed
    # by str(class_id), which is unique. Keying by class name would silently
    # merge classes that share a name and would clash with the report's own
    # summary keys (e.g. "accuracy", "macro avg").
    class_report = classification_report(
        targets, preds,
        labels=class_ids,
        output_dict=True,
        zero_division=0,
    )

    metrics = []
    for class_id in class_ids:
        # Every id in `labels` is guaranteed to have an entry in the report.
        scores = class_report[str(class_id)]
        metrics.append({
            'class_id': class_id,
            'class_name': id_to_label[class_id],
            'precision': scores['precision'],
            'recall': scores['recall'],
            'f1': scores['f1-score'],
            'support': scores['support'],
        })

    return pd.DataFrame(metrics)

def print_per_class_summary(df_metrics, dataset_name):
    """Print a text report of per-class metrics for one dataset.

    Args:
        df_metrics: DataFrame with the columns ``class_name``, ``f1``,
            ``precision``, ``recall`` and ``support``.
        dataset_name: Name shown (upper-cased) in the report header.

    Only classes with ``support > 0`` are considered. The report lists the
    mean F1/precision/recall over those classes followed by the ten highest
    and the ten lowest classes ranked by F1. Nothing is returned.
    """
    banner = '=' * 70
    print(f"\n{banner}")
    print(f"PER-CLASS METRICS: {dataset_name.upper()}")
    print(banner)

    # Classes without any true samples would only dilute the averages.
    has_samples = df_metrics['support'] > 0
    with_support = df_metrics.loc[has_samples].copy()

    if with_support.empty:
        print("No samples found for any class.")
        return

    print(f"\nTotal classes with samples: {len(with_support)}")
    for title, column in (("F1", 'f1'), ("Precision", 'precision'), ("Recall", 'recall')):
        print(f"Average {title}: {with_support[column].mean():.4f}")

    report_columns = ['class_name', 'f1', 'precision', 'recall', 'support']

    best = with_support.nlargest(10, 'f1')[report_columns]
    print("\nTop 10 Best Performing Classes:")
    print(best.to_string(index=False))

    worst = with_support.nsmallest(10, 'f1')[report_columns]
    print("\nTop 10 Worst Performing Classes:")
    print(worst.to_string(index=False))

# Cell-status message: shown once the per-class metric helpers above are defined.
print("✓ Per-class metrics functions defined")


✓ Per-class metrics functions defined


## Function 2: Confusion Matrix Visualization


In [13]:
def plot_confusion_matrix(targets, preds, id_to_label, dataset_name, save_path=None):
    """Draw raw-count and row-normalized confusion matrices side by side.

    Args:
        targets: Ground-truth class ids.
        preds: Predicted class ids.
        id_to_label: Mapping from class id to class name; classes are
            enumerated as ``range(len(id_to_label))``.
        dataset_name: Name used in the plot titles and messages.
        save_path: Optional file path; when given, the figure is saved there.

    Returns:
        The matplotlib figure, or ``None`` when no class occurs in either
        ``targets`` or ``preds``.
    """
    n_labels = len(id_to_label)
    cm = confusion_matrix(targets, preds, labels=list(range(n_labels)))

    # A class is "active" when it shows up as a true label (row) or as a
    # prediction (column); all-zero classes are dropped from the plot.
    active_classes = [
        idx for idx in range(n_labels)
        if cm[idx].sum() > 0 or cm[:, idx].sum() > 0
    ]
    if not active_classes:
        print(f"No active classes found for {dataset_name}")
        return None

    class_names = [id_to_label[idx] for idx in active_classes]
    cm_filtered = cm[np.ix_(active_classes, active_classes)]

    # Row-normalize (per true class); the epsilon keeps rows without true
    # samples from dividing by zero.
    row_totals = cm_filtered.sum(axis=1)[:, np.newaxis]
    cm_normalized = cm_filtered.astype('float') / (row_totals + 1e-8)

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 8))

    # (axis, matrix, annotation format, colorbar label, title variant)
    panels = (
        (ax1, cm_filtered, 'd', 'Count', 'Counts'),
        (ax2, cm_normalized, '.2f', 'Normalized', 'Normalized'),
    )
    for ax, matrix, number_format, colorbar_label, variant in panels:
        sns.heatmap(matrix, annot=True, fmt=number_format, cmap='Blues',
                    xticklabels=class_names, yticklabels=class_names,
                    ax=ax, cbar_kws={'label': colorbar_label})
        ax.set_title(f'Confusion Matrix ({variant}) - {dataset_name}', fontsize=14, fontweight='bold')
        ax.set_xlabel('Predicted', fontsize=12)
        ax.set_ylabel('True', fontsize=12)
        plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
        plt.setp(ax.get_yticklabels(), rotation=0)

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"Saved confusion matrix to: {save_path}")

    plt.show()
    return fig

# Cell-status message: shown once the confusion-matrix helper above is defined.
print("✓ Confusion matrix visualization function defined")


✓ Confusion matrix visualization function defined


## Function 3: Feature Space Visualization (t-SNE)


In [14]:
def _extract_features(model, images):
    """Return pre-classifier features for ``images``, or None if unsupported."""
    if not hasattr(model, 'forward_features'):
        return None

    features = model.forward_features(images)
    if hasattr(model, 'forward_head'):
        # timm-style models: forward_features() yields the unpooled feature
        # map / token sequence; forward_head(..., pre_logits=True) applies the
        # model's own pooling and stops before the classifier layer.
        features = model.forward_head(features, pre_logits=True)
    return features


def visualize_features_tsne(model, dataloader, id_to_label, dataset_name, device='cuda', n_samples=1000, save_path=None):
    """Project model features to 2-D with t-SNE and show a scatter plot.

    Args:
        model: Model exposing ``forward_features`` (and optionally
            ``forward_head``, as timm models do). It is switched to eval mode.
        dataloader: Iterable yielding ``(images, targets)`` tensor batches.
        id_to_label: Mapping from class id to class name. Currently unused:
            points are colored by their numeric class id.
        dataset_name: Name used in the plot title.
        device: Device the image batches are moved to before the forward pass.
        n_samples: Maximum number of samples to embed.
        save_path: Optional file path; when given, the figure is saved there.

    Returns:
        ``(features_2d, all_labels)`` -- the 2-D embedding and the matching
        class ids -- or ``None`` when features cannot be extracted or fewer
        than two samples are available.
    """
    model.eval()
    features_list = []
    labels_list = []
    n_collected = 0

    with torch.no_grad():
        for images, targets in dataloader:
            # Count the samples actually gathered instead of relying on
            # dataloader.batch_size, which may be None or absent.
            if n_collected >= n_samples:
                break

            images = images.to(device)
            targets = targets.cpu().numpy()

            # Best effort: feature extraction is model-specific, so any
            # failure is reported and the visualization is skipped.
            try:
                features = _extract_features(model, images)
                if features is not None:
                    batch_len = features.size(0)
                    # Flatten to (batch, n_features) for t-SNE.
                    features = features.cpu().numpy().reshape(batch_len, -1)
            except Exception as e:
                print(f"Error extracting features: {e}")
                return None

            if features is None:
                print("Warning: Cannot extract features automatically. Skipping t-SNE.")
                return None

            features_list.append(features)
            labels_list.append(targets)
            n_collected += len(features)

    if not features_list:
        print("No features extracted.")
        return None

    # Concatenate all features
    all_features = np.vstack(features_list)
    all_labels = np.concatenate(labels_list)

    # The last batch may overshoot the budget; subsample down to n_samples.
    if len(all_features) > n_samples:
        indices = np.random.choice(len(all_features), n_samples, replace=False)
        all_features = all_features[indices]
        all_labels = all_labels[indices]

    n_points = len(all_features)
    if n_points < 2:
        print("Need at least 2 samples for t-SNE.")
        return None

    # t-SNE requires perplexity < number of samples; clamp for small sets.
    perplexity = min(30, n_points - 1)

    print(f"Applying t-SNE to {n_points} samples...")
    tsne = TSNE(n_components=2, random_state=42, perplexity=perplexity)
    features_2d = tsne.fit_transform(all_features)

    # Plot
    plt.figure(figsize=(12, 10))
    scatter = plt.scatter(features_2d[:, 0], features_2d[:, 1],
                          c=all_labels, cmap='tab20', alpha=0.6, s=20)
    plt.colorbar(scatter, label='Class ID')
    plt.title(f'Feature Space Visualization (t-SNE) - {dataset_name}', fontsize=14, fontweight='bold')
    plt.xlabel('t-SNE Component 1', fontsize=12)
    plt.ylabel('t-SNE Component 2', fontsize=12)

    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"Saved t-SNE visualization to: {save_path}")

    plt.show()
    return features_2d, all_labels

# Cell-status message: shown once the t-SNE helper above is defined.
print("✓ Feature visualization function defined")


✓ Feature visualization function defined


## Function 4: Failure Analysis


In [15]:
def analyze_failures(targets, preds, id_to_label, dataset_name, top_n=10):
    """Print a misclassification report and return the confused class pairs.

    Args:
        targets: Ground-truth class ids (list, tuple or array).
        preds: Predicted class ids, aligned element-wise with ``targets``.
        id_to_label: Mapping from class id to class name; classes are
            enumerated as ``range(len(id_to_label))``.
        dataset_name: Name shown (upper-cased) in the report header.
        top_n: Number of most frequent confusions to print.

    Returns:
        List of ``(true_id, pred_id, count)`` tuples for every off-diagonal
        confusion-matrix cell with a non-zero count, most frequent first.
        The full list is returned, not just the ``top_n`` printed entries.
    """
    # Plain Python sequences are accepted by confusion_matrix, but the
    # element-wise comparison further down needs real arrays.
    targets = np.asarray(targets)
    preds = np.asarray(preds)

    cm = confusion_matrix(targets, preds, labels=list(range(len(id_to_label))))

    print(f"\n{'='*70}")
    print(f"FAILURE ANALYSIS: {dataset_name.upper()}")
    print(f"{'='*70}\n")

    # Per-class accuracy (recall): diagonal over row total, skipping classes
    # that have no true samples.
    print("Per-Class Accuracy:")
    row_totals = cm.sum(axis=1)
    for class_id, row_total in enumerate(row_totals):
        if row_total > 0:
            acc = cm[class_id, class_id] / row_total
            class_name = id_to_label[class_id]
            print(f"  Class {class_id} ({class_name}): {acc:.3f}")

    # Most common misclassifications: every non-zero off-diagonal cell.
    print(f"\nTop {top_n} Misclassifications:")
    off_diagonal = cm.copy()
    np.fill_diagonal(off_diagonal, 0)
    # argwhere walks the matrix row by row, and the sort is stable, so ties
    # keep their (true_id, pred_id) order.
    misclass_pairs = [
        (int(i), int(j), int(off_diagonal[i, j]))
        for i, j in np.argwhere(off_diagonal > 0)
    ]
    misclass_pairs.sort(key=lambda pair: pair[2], reverse=True)

    for i, j, count in misclass_pairs[:top_n]:
        true_name = id_to_label[i]
        pred_name = id_to_label[j]
        print(f"  {true_name} → {pred_name}: {count} times")

    # Overall statistics
    total_samples = len(targets)
    correct = int((targets == preds).sum())
    accuracy = correct / total_samples if total_samples > 0 else 0.0

    print("\nOverall Statistics:")
    print(f"  Total samples: {total_samples}")
    print(f"  Correct predictions: {correct}")
    print(f"  Incorrect predictions: {total_samples - correct}")
    print(f"  Accuracy: {accuracy:.4f} ({accuracy*100:.2f}%)")

    return misclass_pairs

# Cell-status message: shown once the failure-analysis helper above is defined.
print("✓ Failure analysis function defined")


✓ Failure analysis function defined


## Usage Example

To use this evaluation suite:

1. Load your model and evaluation results (from previous notebooks)
2. Run the analysis functions on your results
3. Example:
   ```python
   # Assuming you have results from evaluate_on_all_datasets
   results = {...}  # Your evaluation results
   
   # Per-class metrics
   df_metrics = compute_per_class_metrics(
       results['main']['all_targets'],
       results['main']['all_preds'],
       id_to_label
   )
   print_per_class_summary(df_metrics, 'main')
   
   # Confusion matrix
   plot_confusion_matrix(
       results['main']['all_targets'],
       results['main']['all_preds'],
       id_to_label,
       'main',
       save_path='confusion_matrix_main.png'
   )
   
   # Failure analysis
   analyze_failures(
       results['plant_doc']['all_targets'],
       results['plant_doc']['all_preds'],
       id_to_label,
       'plant_doc'
   )
   ```
