# Model Evaluation Notebook

**Federated Learning for Skin Cancer Classification with DSCATNet**

This notebook provides comprehensive evaluation tools for trained models, including:

- **Performance Metrics**: Accuracy, Balanced Accuracy, F1, AUC-ROC
- **Confusion Matrix**: Visual analysis of predictions
- **Per-Class Analysis**: Precision, recall (sensitivity), and F1-score for each skin lesion type
- **ROC Curves**: Multi-class ROC curves with AUC
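- **Prediction Confidence**: Confidence distribution and calibration analysis
- **Misclassification Analysis**: Most frequent true/predicted class confusions
- **Export**: All metrics saved to a JSON file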

---

## 1. Setup and Imports

In [None]:
# Standard library
import sys
import json
from pathlib import Path
from datetime import datetime

# Data science
import numpy as np
import pandas as pd

# Visualization
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.colors import LinearSegmentedColormap

# PyTorch
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

# Scikit-learn metrics
from sklearn.metrics import (
    accuracy_score, balanced_accuracy_score,
    precision_score, recall_score, f1_score,
    roc_auc_score, roc_curve, auc,
    confusion_matrix, classification_report
)

# Add project root to path
# Assumes the notebook is launched from a direct subdirectory of the repository
# (e.g. notebooks/), so the parent of the working directory is the project
# root — TODO confirm if the notebook is moved. This must run before the
# `src.*` imports below, otherwise they cannot be resolved.
project_root = Path().resolve().parent
if str(project_root) not in sys.path:
    sys.path.insert(0, str(project_root))

# Project imports (resolved via the sys.path entry added above)
from src.models.dscatnet import create_dscatnet
from src.data.datasets import UNIFIED_CLASSES
from src.data.preprocessing import get_val_transforms
from src.evaluation.metrics import ModelEvaluator, EvaluationResults

# Settings: global plot style and colour palette for every figure below.
plt.style.use('seaborn-v0_8-whitegrid')
sns.set_palette('husl')
# IPython magic (only valid inside Jupyter/IPython): render figures inline.
%matplotlib inline

# Report the runtime environment so results can be tied to a torch/CUDA setup.
print(f"PyTorch: {torch.__version__}")
print(f"CUDA Available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")

## 2. Configuration

Configure the evaluation by setting the checkpoint path and dataset. The model settings (variant, number of classes, image size) must match those used during training.

In [None]:
# =============================================================================
# CONFIGURATION - Modify these paths for your experiment
# =============================================================================

# Path to trained model checkpoint.
# Replace the "federated_YYYYMMDD_HHMMSS" placeholder with your run directory.
CHECKPOINT_PATH = project_root / "outputs" / "federated_YYYYMMDD_HHMMSS" / "checkpoints" / "best_model.pt"

# Dataset for evaluation
DATASET_NAME = "ISIC2019"  # Options: HAM10000, ISIC2018, ISIC2019, ISIC2020, PAD-UFES-20
DATA_ROOT = project_root / "data"

# Model configuration (must match training)
MODEL_VARIANT = "small"  # tiny, small, base
NUM_CLASSES = 7
IMAGE_SIZE = 224

# Evaluation settings
BATCH_SIZE = 32
NUM_WORKERS = 4

# Class names, in the iteration order of UNIFIED_CLASSES.values().
# NOTE(review): later cells index CLASS_NAMES by integer label, which assumes
# UNIFIED_CLASSES is ordered by label index — confirm in src.data.datasets.
CLASS_NAMES = list(UNIFIED_CLASSES.values())
print(f"Classes: {CLASS_NAMES}")

# Later cells pair CLASS_NAMES with per-class arrays of width NUM_CLASSES
# (ROC one-vs-rest columns, per-class AUC export). Fail here with a clear
# message instead of mislabelling classes or crashing deep in plotting code.
if len(CLASS_NAMES) != NUM_CLASSES:
    raise ValueError(
        f"NUM_CLASSES={NUM_CLASSES} does not match the {len(CLASS_NAMES)} "
        f"entries in UNIFIED_CLASSES: {CLASS_NAMES}"
    )

## 3. Load Model and Data

In [None]:
# Device setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Create model architecture; trained weights come from the checkpoint below.
model = create_dscatnet(
    variant=MODEL_VARIANT,
    num_classes=NUM_CLASSES,
    pretrained=False
)

# Load checkpoint. Without it the model keeps freshly initialized weights and
# every metric below (including the exported JSON) would describe an untrained
# network, so stop here instead of silently continuing.
if not CHECKPOINT_PATH.exists():
    raise FileNotFoundError(
        f"⚠️ Checkpoint not found: {CHECKPOINT_PATH}\n"
        "Please update CHECKPOINT_PATH to point to your trained model."
    )

print(f"Loading checkpoint: {CHECKPOINT_PATH}")
# NOTE: torch.load unpickles the file — only load checkpoints from trusted sources.
checkpoint = torch.load(CHECKPOINT_PATH, map_location=device)

# Handle different checkpoint formats: a training dict holding
# "model_state_dict" (plus optional "epoch"/"metrics"), or a bare state dict.
if "model_state_dict" in checkpoint:
    model.load_state_dict(checkpoint["model_state_dict"])
    epoch = checkpoint.get("epoch", "unknown")
    metrics = checkpoint.get("metrics", {})
    print(f"Loaded from epoch/round: {epoch}")
    if metrics:
        print(f"Checkpoint metrics: {metrics}")
else:
    model.load_state_dict(checkpoint)
    print("Loaded raw state dict")

# Move to the target device and switch to inference mode (e.g. dropout off).
model = model.to(device)
model.eval()

# Count parameters
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"\nModel Parameters: {total_params:,} total, {trainable_params:,} trainable")

In [None]:
# Load dataset
from src.data.datasets import (
    HAM10000Dataset, ISIC2018Dataset, ISIC2019Dataset, 
    ISIC2020Dataset, PADUFES20Dataset
)

# Dataset mapping: DATASET_NAME -> dataset class.
DATASET_CLASSES = {
    "HAM10000": HAM10000Dataset,
    "ISIC2018": ISIC2018Dataset,
    "ISIC2019": ISIC2019Dataset,
    "ISIC2020": ISIC2020Dataset,
    "PAD-UFES-20": PADUFES20Dataset,
}

# Dataset paths (adjust if different).
# NOTE(review): the ISIC entries point at the official *training* splits. If the
# model was trained on these same files, the metrics below are train-set
# (optimistic) numbers — confirm this is a held-out set for your experiment.
DATASET_PATHS = {
    "HAM10000": {
        "root_dir": DATA_ROOT / "HAM10000",
        "csv_path": DATA_ROOT / "HAM10000" / "HAM10000_metadata.csv"
    },
    "ISIC2018": {
        "root_dir": DATA_ROOT / "ISIC2018" / "ISIC2018_Task3_Training_Input",
        "csv_path": DATA_ROOT / "ISIC2018" / "ISIC2018_Task3_Training_GroundTruth.csv"
    },
    "ISIC2019": {
        "root_dir": DATA_ROOT / "ISIC2019" / "ISIC_2019_Training_Input",
        "csv_path": DATA_ROOT / "ISIC2019" / "ISIC_2019_Training_GroundTruth.csv"
    },
    "ISIC2020": {
        "root_dir": DATA_ROOT / "ISIC2020" / "train",
        "csv_path": DATA_ROOT / "ISIC2020" / "train.csv"
    },
    "PAD-UFES-20": {
        "root_dir": DATA_ROOT / "PAD-UFES-20",
        "csv_path": DATA_ROOT / "PAD-UFES-20" / "metadata.csv"
    },
}

# Validation transforms (deterministic preprocessing, no augmentation expected)
val_transform = get_val_transforms(img_size=IMAGE_SIZE)

# Validate the configured name up front for a clearer error than a bare KeyError.
if DATASET_NAME not in DATASET_CLASSES:
    raise ValueError(
        f"Unknown DATASET_NAME {DATASET_NAME!r}; expected one of {sorted(DATASET_CLASSES)}"
    )

dataset_class = DATASET_CLASSES[DATASET_NAME]
paths = DATASET_PATHS[DATASET_NAME]

try:
    test_dataset = dataset_class(
        root_dir=str(paths["root_dir"]),
        csv_path=str(paths["csv_path"]),
        transform=val_transform
    )
except Exception as e:
    print(f"✗ Failed to load {DATASET_NAME}: {e}")
    print("\nPlease check that the dataset is properly downloaded and extracted.")
    # Re-raise so "Run All" stops here rather than failing in the next cell
    # with a confusing NameError on `test_dataset`.
    raise
else:
    print(f"✓ Loaded {DATASET_NAME}: {len(test_dataset)} samples")

In [None]:
# Build the evaluation loader. Shuffling stays off so batches follow the
# dataset's own order, keeping predictions aligned with sample indices.
test_loader = DataLoader(
    dataset=test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    # Page-locked host memory only pays off when batches are copied to a GPU.
    pin_memory=torch.cuda.is_available(),
)

print(f"DataLoader: {len(test_loader)} batches of size {BATCH_SIZE}")

## 4. Run Evaluation

In [None]:
# Initialize evaluator with the loaded model and the class layout from the config.
evaluator = ModelEvaluator(
    model=model,
    device=device,
    num_classes=NUM_CLASSES,
    class_names=CLASS_NAMES
)

# Run evaluation over the whole test loader; compute_auc=True requests AUC
# metrics as well (results.auc_macro may still be None — later cells handle it).
print("Running evaluation...")
results = evaluator.evaluate(test_loader, compute_auc=True)
print("✓ Evaluation complete!")

## 5. Performance Metrics Summary

### Key Performance Indicators (KPIs)

In [None]:
# Create metrics summary DataFrame.
# A missing AUC (None) is recorded as NaN so it renders as "N/A" instead of
# masquerading as a score of 0.0.
metrics_summary = pd.DataFrame({
    "Metric": [
        "Accuracy",
        "Balanced Accuracy",
        "Precision (macro)",
        "Recall (macro)",
        "F1-Score (macro)",
        "F1-Score (weighted)",
        "AUC-ROC (macro)"
    ],
    "Value": [
        results.accuracy,
        results.balanced_accuracy,
        results.precision_macro,
        results.recall_macro,
        results.f1_macro,
        results.f1_weighted,
        results.auc_macro if results.auc_macro is not None else np.nan
    ]
})


def highlight_kpi(val):
    """Return a CSS background colour for a 0-1 metric (no colour when missing)."""
    if pd.isna(val):
        return ''
    if val >= 0.90:
        return 'background-color: #90EE90'  # Light green
    elif val >= 0.80:
        return 'background-color: #FFFFE0'  # Light yellow
    elif val >= 0.70:
        return 'background-color: #FFD700'  # Gold
    else:
        return 'background-color: #FFB6C1'  # Light red


base_style = metrics_summary.style.format({"Value": "{:.4f}"}, na_rep="N/A")
# Styler.applymap was renamed to Styler.map in pandas 2.1 (applymap is deprecated).
if hasattr(base_style, "map"):
    styled_metrics = base_style.map(highlight_kpi, subset=["Value"])
else:
    styled_metrics = base_style.applymap(highlight_kpi, subset=["Value"])

print("\n" + "="*60)
print("📊 MODEL PERFORMANCE SUMMARY")
print("="*60)
display(styled_metrics)

In [None]:
# Visual KPI Dashboard: seven 0-1 metrics drawn as horizontal gauges plus one
# panel showing the number of evaluated samples.
fig, axes = plt.subplots(2, 4, figsize=(16, 8))

# Each entry: (panel title, value, is_ratio). AUC is passed through unchanged so
# a missing value (None) shows as "N/A" rather than a fake 0%.
kpis = [
    ("Accuracy", results.accuracy, True),
    ("Balanced\nAccuracy", results.balanced_accuracy, True),
    ("Precision", results.precision_macro, True),
    ("Recall", results.recall_macro, True),
    ("F1 (macro)", results.f1_macro, True),
    ("F1 (weighted)", results.f1_weighted, True),
    ("AUC-ROC", results.auc_macro, True),
    ("Total\nSamples", len(results.labels), False),
]

colors = ['#3498db', '#2ecc71', '#9b59b6', '#e74c3c',
          '#f39c12', '#1abc9c', '#e67e22', '#34495e']

for ax, (name, value, is_ratio), color in zip(axes.flat, kpis, colors):
    if is_ratio:
        has_value = value is not None and np.isfinite(value)
        # Draw the gray "track" first so it sits behind the coloured value bar.
        ax.barh([0], [1], color='lightgray', height=0.5, alpha=0.3)
        if has_value:
            ax.barh([0], [value], color=color, height=0.5, alpha=0.8)
        ax.set_xlim(0, 1)
        ax.set_yticks([])
        ax.text(0.5, 0.7, f"{value:.2%}" if has_value else "N/A",
                ha='center', va='bottom',
                fontsize=24, fontweight='bold', transform=ax.transAxes)
        for side in ('top', 'right', 'left'):
            ax.spines[side].set_visible(False)
    else:  # Sample count
        ax.text(0.5, 0.5, f"{int(value):,}", ha='center', va='center',
                fontsize=28, fontweight='bold', color=color, transform=ax.transAxes)
        ax.set_xlim(0, 1)
        ax.set_ylim(0, 1)
        ax.axis('off')

    ax.set_title(name, fontsize=12, fontweight='bold', pad=10)

plt.suptitle(f"Model Performance Dashboard - {DATASET_NAME}", fontsize=16, fontweight='bold', y=1.02)
plt.tight_layout()
plt.show()

## 6. Confusion Matrix Analysis

In [None]:
# Plot confusion matrix (row-normalized and absolute counts side by side)
fig, axes = plt.subplots(1, 2, figsize=(16, 7))

# Row-normalize so each row shows how samples of one true class were predicted.
# Rows with no samples stay at 0 instead of producing divide-by-zero warnings.
cm_counts = np.asarray(results.confusion_matrix)
row_totals = cm_counts.sum(axis=1, keepdims=True)
cm_norm = np.divide(
    cm_counts.astype(float),
    row_totals,
    out=np.zeros(cm_counts.shape, dtype=float),
    where=row_totals > 0,
)

sns.heatmap(
    cm_norm,
    annot=True,
    fmt='.2%',
    cmap='Blues',
    xticklabels=CLASS_NAMES,
    yticklabels=CLASS_NAMES,
    ax=axes[0],
    cbar_kws={'label': 'Proportion'}
)
axes[0].set_xlabel('Predicted Label', fontsize=12)
axes[0].set_ylabel('True Label', fontsize=12)
axes[0].set_title('Normalized Confusion Matrix', fontsize=14, fontweight='bold')

# Absolute confusion matrix (raw counts; fmt='d' expects integer counts)
sns.heatmap(
    results.confusion_matrix,
    annot=True,
    fmt='d',
    cmap='Blues',
    xticklabels=CLASS_NAMES,
    yticklabels=CLASS_NAMES,
    ax=axes[1],
    cbar_kws={'label': 'Count'}
)
axes[1].set_xlabel('Predicted Label', fontsize=12)
axes[1].set_ylabel('True Label', fontsize=12)
axes[1].set_title('Absolute Confusion Matrix', fontsize=14, fontweight='bold')

plt.tight_layout()
# Create the output directory if needed; savefig does not create parents.
output_dir = project_root / 'outputs'
output_dir.mkdir(parents=True, exist_ok=True)
plt.savefig(output_dir / 'confusion_matrix_evaluation.png', dpi=150, bbox_inches='tight')
plt.show()

## 7. Per-Class Performance Analysis

In [None]:
# Per-class metrics DataFrame: one row per key of results.per_class_metrics.
# NOTE(review): assumes each entry holds 'accuracy', 'precision', 'recall' and
# 'support' keys, and that the keys are the class names used in CLASS_NAMES.
per_class_df = pd.DataFrame(results.per_class_metrics).T
per_class_df['Class'] = per_class_df.index
per_class_df = per_class_df[['Class', 'accuracy', 'precision', 'recall', 'support']]
per_class_df.columns = ['Class', 'Accuracy', 'Precision', 'Recall', 'Support']

# Add F1-score per class (harmonic mean of precision and recall; the epsilon
# keeps classes with precision = recall = 0 at F1 = 0 instead of NaN).
per_class_df['F1-Score'] = 2 * (per_class_df['Precision'] * per_class_df['Recall']) / \
                           (per_class_df['Precision'] + per_class_df['Recall'] + 1e-8)

# Sort by support (most samples first)
per_class_df = per_class_df.sort_values('Support', ascending=False)

print("\n" + "="*70)
print("📋 PER-CLASS PERFORMANCE")
print("="*70)
display(per_class_df.style.format({
    'Accuracy': '{:.4f}',
    'Precision': '{:.4f}',
    'Recall': '{:.4f}',
    'F1-Score': '{:.4f}',
    'Support': '{:,.0f}'
}).background_gradient(subset=['Accuracy', 'Precision', 'Recall', 'F1-Score'], cmap='RdYlGn'))

In [None]:
# Visualize per-class metrics: grouped bars (left) and test-set class mix (right).
fig, axes = plt.subplots(1, 2, figsize=(16, 6))

# Align rows with CLASS_NAMES order in one step. reindex() tolerates a class
# missing from per_class_df (plotted as 0) where a per-class `.values[0]`
# lookup would raise IndexError.
class_table = per_class_df.set_index('Class').reindex(CLASS_NAMES).fillna(0)

# Bar chart of metrics per class
x = np.arange(len(CLASS_NAMES))
width = 0.2

metrics_to_plot = ['Precision', 'Recall', 'F1-Score']
colors = ['#3498db', '#2ecc71', '#e74c3c']

for i, (metric, color) in enumerate(zip(metrics_to_plot, colors)):
    axes[0].bar(x + i*width, class_table[metric].to_numpy(), width,
                label=metric, color=color, alpha=0.8)

axes[0].set_xlabel('Class', fontsize=12)
axes[0].set_ylabel('Score', fontsize=12)
axes[0].set_title('Per-Class Metrics', fontsize=14, fontweight='bold')
axes[0].set_xticks(x + width)  # center ticks under the middle bar of each group
axes[0].set_xticklabels(CLASS_NAMES, rotation=45, ha='right')
axes[0].legend(loc='lower right')
axes[0].set_ylim(0, 1.1)
axes[0].grid(True, alpha=0.3, axis='y')

# Class distribution (support)
supports = class_table['Support'].to_numpy()
colors_pie = plt.cm.Set3(np.linspace(0, 1, len(CLASS_NAMES)))

wedges, texts, autotexts = axes[1].pie(
    supports,
    labels=CLASS_NAMES,
    autopct='%1.1f%%',
    colors=colors_pie,
    explode=[0.02]*len(CLASS_NAMES)
)
axes[1].set_title('Class Distribution in Test Set', fontsize=14, fontweight='bold')

plt.tight_layout()
plt.show()

## 8. ROC Curves and AUC Analysis

In [None]:
# Compute ROC curves for each class (one-vs-rest)
from sklearn.preprocessing import label_binarize

# Binarize labels for one-vs-rest. With exactly two classes label_binarize
# returns a single column, so expand it to one column per class.
y_true_bin = label_binarize(results.labels, classes=range(NUM_CLASSES))
if y_true_bin.shape[1] == 1:
    y_true_bin = np.hstack([1 - y_true_bin, y_true_bin])
y_score = results.probabilities

# Compute ROC curve and AUC for each class
fpr = {}
tpr = {}
roc_auc = {}
evaluable = []  # class indices with both positive and negative samples

n_samples = len(y_true_bin)
for i, class_name in enumerate(CLASS_NAMES):
    n_pos = y_true_bin[:, i].sum()
    if 0 < n_pos < n_samples:
        fpr[i], tpr[i], _ = roc_curve(y_true_bin[:, i], y_score[:, i])
        roc_auc[i] = auc(fpr[i], tpr[i])
        evaluable.append(i)
    else:
        # ROC is undefined without both positives and negatives. Keep a
        # chance-level placeholder so every class has an entry in roc_auc
        # (later cells index it by class), but exclude it from the plot and
        # from the macro average below.
        fpr[i], tpr[i], roc_auc[i] = [0, 1], [0, 1], 0.5

# Plot ROC curves for the classes where ROC is defined
fig, ax = plt.subplots(figsize=(10, 8))

colors = plt.cm.Set1(np.linspace(0, 1, NUM_CLASSES))

for i in evaluable:
    ax.plot(
        fpr[i], tpr[i],
        color=colors[i],
        lw=2,
        label=f'{CLASS_NAMES[i]} (AUC = {roc_auc[i]:.3f})'
    )

# Plot diagonal (random classifier)
ax.plot([0, 1], [0, 1], 'k--', lw=2, label='Random (AUC = 0.500)')

ax.set_xlim([0.0, 1.0])
ax.set_ylim([0.0, 1.05])
ax.set_xlabel('False Positive Rate', fontsize=12)
ax.set_ylabel('True Positive Rate', fontsize=12)
ax.set_title('Multi-Class ROC Curves (One-vs-Rest)', fontsize=14, fontweight='bold')
ax.legend(loc='lower right', fontsize=10)
ax.grid(True, alpha=0.3)

plt.tight_layout()
# Create the output directory if needed; savefig does not create parents.
output_dir = project_root / 'outputs'
output_dir.mkdir(parents=True, exist_ok=True)
plt.savefig(output_dir / 'roc_curves_evaluation.png', dpi=150, bbox_inches='tight')
plt.show()

# Print AUC summary. The macro average covers evaluable classes only; it may
# differ from results.auc_macro if the evaluator averages differently.
print("\n" + "="*50)
print("AUC-ROC per Class")
print("="*50)
for i, class_name in enumerate(CLASS_NAMES):
    auc_text = f"{roc_auc[i]:.4f}" if i in evaluable else "n/a (needs positive and negative samples)"
    print(f"  {class_name:15s}: {auc_text}")
macro_auc = np.mean([roc_auc[i] for i in evaluable]) if evaluable else float('nan')
print(f"\n  {'Macro Average':15s}: {macro_auc:.4f}")

## 9. Prediction Confidence Analysis

In [None]:
# Analyze prediction confidence: the highest class probability per sample.
confidence = np.max(results.probabilities, axis=1)
correct = np.asarray(results.predictions) == np.asarray(results.labels)

fig, axes = plt.subplots(1, 3, figsize=(16, 5))

# Overall confidence distribution
axes[0].hist(confidence, bins=50, color='steelblue', edgecolor='black', alpha=0.7)
axes[0].axvline(confidence.mean(), color='red', linestyle='--', lw=2, label=f'Mean: {confidence.mean():.3f}')
axes[0].set_xlabel('Prediction Confidence', fontsize=12)
axes[0].set_ylabel('Count', fontsize=12)
axes[0].set_title('Confidence Distribution', fontsize=14, fontweight='bold')
axes[0].legend()

# Confidence: Correct vs Incorrect
axes[1].hist(confidence[correct], bins=30, alpha=0.7, label=f'Correct (n={correct.sum():,})', color='green')
axes[1].hist(confidence[~correct], bins=30, alpha=0.7, label=f'Incorrect (n={(~correct).sum():,})', color='red')
axes[1].set_xlabel('Prediction Confidence', fontsize=12)
axes[1].set_ylabel('Count', fontsize=12)
axes[1].set_title('Confidence by Correctness', fontsize=14, fontweight='bold')
axes[1].legend()

# Accuracy vs Confidence bins (reliability diagram, 10 equal-width bins)
confidence_bins = np.linspace(0, 1, 11)
n_bins = len(confidence_bins) - 1
# np.digitize returns n_bins + 1 for confidence == 1.0 (saturated softmax),
# which the loop below never visits; clip so those samples land in the top bin.
bin_indices = np.clip(np.digitize(confidence, confidence_bins), 1, n_bins)
bin_accuracies = []
bin_counts = []

for i in range(1, n_bins + 1):
    mask = bin_indices == i
    count = int(mask.sum())
    bin_counts.append(count)
    # Empty bins get NaN (not 0) so they are not mistaken for 0% accuracy.
    bin_accuracies.append(correct[mask].mean() if count else np.nan)

bin_centers = (confidence_bins[:-1] + confidence_bins[1:]) / 2
bin_accuracies = np.asarray(bin_accuracies)
filled = np.asarray(bin_counts) > 0

ax3 = axes[2]
ax3.bar(bin_centers[filled], bin_accuracies[filled], width=0.08, color='steelblue', alpha=0.7, edgecolor='black')
ax3.plot([0, 1], [0, 1], 'r--', lw=2, label='Perfect Calibration')
ax3.set_xlabel('Confidence', fontsize=12)
ax3.set_ylabel('Accuracy', fontsize=12)
ax3.set_title('Calibration: Confidence vs Accuracy', fontsize=14, fontweight='bold')
ax3.set_xlim(0, 1)
ax3.set_ylim(0, 1.05)
ax3.legend()
ax3.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()


def _format_mean(values):
    """Format the mean of `values` to 4 decimals, or 'n/a' for an empty array."""
    return f"{values.mean():.4f}" if values.size else "n/a"


# Confidence statistics
print("\n" + "="*50)
print("Confidence Statistics")
print("="*50)
print(f"  Mean Confidence:     {confidence.mean():.4f}")
print(f"  Median Confidence:   {np.median(confidence):.4f}")
print(f"  Std Confidence:      {confidence.std():.4f}")
print(f"  Correct Mean Conf:   {_format_mean(confidence[correct])}")
print(f"  Incorrect Mean Conf: {_format_mean(confidence[~correct])}")

## 10. Misclassification Analysis

In [None]:
# Find most common misclassifications
from collections import Counter

# One tuple per misclassified sample:
# (true class name, predicted class name, probability given to the wrong class).
misclass_pairs = [
    (CLASS_NAMES[true_idx], CLASS_NAMES[pred_idx], results.probabilities[row, pred_idx])
    for row, (true_idx, pred_idx) in enumerate(zip(results.labels, results.predictions))
    if pred_idx != true_idx
]

# Count misclassification patterns
pattern_counts = Counter((t, p) for t, p, _ in misclass_pairs)
top_misclass = pattern_counts.most_common(10)
total_errors = len(misclass_pairs)

print("\n" + "="*60)
print("🔍 TOP 10 MISCLASSIFICATION PATTERNS")
print("="*60)
print(f"{'True Class':<15} {'Predicted As':<15} {'Count':>10} {'%':>8}")
print("-"*60)
if not top_misclass:
    print("No misclassifications found.")
for (true_cls, pred_cls), count in top_misclass:
    pct = count / total_errors * 100 if total_errors > 0 else 0
    print(f"{true_cls:<15} {pred_cls:<15} {count:>10,} {pct:>7.1f}%")
print("-"*60)
print(f"{'Total Errors':<31} {total_errors:>10,}")

## 11. Export Results

In [None]:
# Export results to JSON


def _finite_or_none(value):
    """Return `value` as a float, or None if it is missing or not finite (JSON has no NaN)."""
    if value is None:
        return None
    value = float(value)
    return value if np.isfinite(value) else None


def _json_default(obj):
    """Convert NumPy scalars/arrays (e.g. int64 supports) that `json` cannot serialize."""
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


# Single timestamp so the JSON field and the file name agree.
timestamp = datetime.now()
# Classes actually present in the test labels; AUC for absent classes is
# reported as null rather than any placeholder value.
present_labels = set(np.unique(results.labels).tolist())

export_results = {
    "evaluation_timestamp": timestamp.isoformat(),
    "dataset": DATASET_NAME,
    "model_variant": MODEL_VARIANT,
    "checkpoint_path": str(CHECKPOINT_PATH),
    "num_samples": len(results.labels),
    "metrics": {
        "accuracy": float(results.accuracy),
        "balanced_accuracy": float(results.balanced_accuracy),
        "precision_macro": float(results.precision_macro),
        "recall_macro": float(results.recall_macro),
        "f1_macro": float(results.f1_macro),
        "f1_weighted": float(results.f1_weighted),
        # `is not None` semantics: a genuine AUC of 0.0 is kept, not nulled.
        "auc_macro": _finite_or_none(results.auc_macro)
    },
    "per_class_metrics": results.per_class_metrics,
    "per_class_auc": {
        CLASS_NAMES[i]: _finite_or_none(roc_auc[i]) if i in present_labels else None
        for i in range(NUM_CLASSES)
    },
    "confusion_matrix": np.asarray(results.confusion_matrix).tolist(),
    "confidence_stats": {
        "mean": float(confidence.mean()),
        "std": float(confidence.std()),
        "correct_mean": float(confidence[correct].mean()) if correct.any() else None,
        "incorrect_mean": float(confidence[~correct].mean()) if (~correct).any() else None
    }
}

# Save to file (create outputs/ if this is the first export)
output_dir = project_root / 'outputs'
output_dir.mkdir(parents=True, exist_ok=True)
output_path = output_dir / f'evaluation_results_{DATASET_NAME}_{timestamp.strftime("%Y%m%d_%H%M%S")}.json'
with open(output_path, 'w', encoding='utf-8') as f:
    json.dump(export_results, f, indent=2, default=_json_default)

print(f"\n✓ Results exported to: {output_path}")

---

## Summary

This notebook provides comprehensive evaluation metrics for your trained DSCATNet model:

| Metric | Description | Interpretation |
|--------|-------------|----------------|
| **Accuracy** | Fraction of all samples predicted correctly | Can be misleading with imbalanced data |
| **Balanced Accuracy** | Mean recall across classes | Better suited to imbalanced datasets |
| **F1 (macro)** | Unweighted mean of per-class F1 | Treats all classes equally |
| **F1 (weighted)** | Per-class F1 weighted by class support | Dominated by majority classes; mirrors the class distribution |
| **AUC-ROC** | Macro-averaged one-vs-rest AUC | Threshold-independent measure of ranking ability |

**Key insights:**
- Check per-class metrics for underperforming classes (often minority classes)
- Analyze confusion matrix for systematic misclassification patterns
- Review confidence calibration for deployment decisions

---