# Document Quality Assessment - Model Evaluation

This notebook provides interactive evaluation of the document quality assessment model.

## Features:
- Load trained model and test data
- Calculate comprehensive metrics
- Generate visualizations
- Test on individual images with explanations
- Analyze errors and edge cases

## Setup

In [None]:
import sys
from pathlib import Path

# Add project root to path so the `src.*` project imports below resolve.
# NOTE(review): assumes the notebook runs from a direct subdirectory of the
# project root (e.g. a notebooks/ folder), so that Path.cwd().parent is the root.
project_root = Path.cwd().parent
sys.path.insert(0, str(project_root))

import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from PIL import Image
import yaml
from torch.utils.data import DataLoader

# Import project modules (resolved via the sys.path entry added above)
from src.data.dataset import DocumentDataset
from src.models.model import DocumentQualityModel, ModelConfig
from src.evaluation.metrics import evaluate_model_comprehensive
from src.evaluation.visualize import (
    plot_calibration_curve,
    plot_confusion_matrix,
    plot_roc_curves,
    plot_prediction_distribution,
    plot_issue_analysis,
)
from src.evaluation.explainability import visualize_explanation, get_target_layer

# Configure plotting
# IPython magic (not plain Python): render matplotlib figures inline in the notebook.
%matplotlib inline
plt.rcParams['figure.figsize'] = (12, 8)  # default size for every new figure
sns.set_style("whitegrid")

print("✓ Imports successful")

## 1. Load Model and Configuration

In [None]:
# Configuration: paths are resolved relative to `project_root` from the setup cell.
MODEL_PATH = project_root / "models" / "best_model.pth"
CONFIG_PATH = project_root / "config" / "best_training.yaml"

# Load config. `yaml.safe_load` returns None for an empty file, so fall back
# to an empty dict to keep the `.get(...)` lookups below working.
with open(CONFIG_PATH, "r", encoding="utf-8") as f:
    config = yaml.safe_load(f) or {}

# Setup device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Model hyperparameters live under the "model" section. Each one falls back to
# the default given here when the key, or the whole section, is absent/empty.
model_section = config.get("model") or {}
model_config = ModelConfig(
    backbone=model_section.get("backbone", "efficientnet_b0"),
    num_quality_classes=model_section.get("num_quality_classes", 5),
    num_issue_classes=model_section.get("num_issue_classes", 10),
    hidden_dim=model_section.get("hidden_dim", 256),
    dropout_rate=model_section.get("dropout_rate", 0.5),
    use_attention=model_section.get("use_attention", True),
)

# Load model.
# NOTE(review): torch.load unpickles the file, so only load trusted checkpoints.
# PyTorch >= 2.6 defaults to weights_only=True, which may reject checkpoints
# that also store non-tensor objects (e.g. optimizer state or config).
model = DocumentQualityModel(model_config)
checkpoint = torch.load(MODEL_PATH, map_location=device)

# The checkpoint may either wrap the weights under "model_state_dict" or be a
# bare state dict; accept both.
state_dict = (
    checkpoint["model_state_dict"]
    if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint
    else checkpoint
)
model.load_state_dict(state_dict)

model = model.to(device)
model.eval()  # inference mode: disables dropout / uses running BN statistics

print(f"✓ Model loaded from {MODEL_PATH}")
print(f"  Backbone: {model_config.backbone}")
print(f"  Quality classes: {model_config.num_quality_classes}")
print(f"  Issue classes: {model_config.num_issue_classes}")

## 2. Load Test Dataset

In [None]:
# Locate the test split's metadata CSV.
datasets_dir = project_root / "datasets"
test_metadata = datasets_dir / "default" / "test" / "test_metadata.csv"

# Fail loudly: every later cell uses `test_dataset` / `test_loader`, so only
# printing here would defer the failure to a confusing NameError further down.
if not test_metadata.exists():
    raise FileNotFoundError(f"❌ Test metadata not found at {test_metadata}")

# Create dataset (deterministic test-time transform, no augmentation)
test_dataset = DocumentDataset(
    metadata_file=str(test_metadata),
    transform="test",
    use_advanced_augmentations=False,
)

# Create dataloader.
# shuffle=False keeps batches in dataset order. Section 6 indexes
# `test_dataset.image_paths` with positions taken from the evaluation's raw
# predictions, which only lines up if that order is preserved.
# Pinned host memory only helps copies to a CUDA device, so enable it only there.
test_loader = DataLoader(
    test_dataset,
    batch_size=32,
    shuffle=False,
    num_workers=4,
    pin_memory=device.type == "cuda",
)

print(f"✓ Test dataset loaded: {len(test_dataset)} images")
print(f"  Batches: {len(test_loader)}")

## 3. Run Comprehensive Evaluation

In [None]:
# Run the full evaluation pass over the test loader.
# `acceptance_threshold` is the accept/reject cutoff handed to the metrics code.
print("Running evaluation... (this may take a few minutes)")
evaluation_results = evaluate_model_comprehensive(
    model=model,
    dataloader=test_loader,
    device=device,
    acceptance_threshold=0.5,
)

print("\n" + "="*80)
print("EVALUATION RESULTS")
print("="*80)

# Regression metrics
reg = evaluation_results["regression"]
print("\n📊 Regression Metrics (Quality Scores):")
print(f"  MAE:  {reg['mae']:.4f}")
print(f"  RMSE: {reg['rmse']:.4f}")
print(f"  R²:   {reg['r2']:.4f}")
print(f"  Within 10%: {reg['within_10pct']:.2%}")

# Classification metrics
cls = evaluation_results["classification"]
print("\n🎯 Classification Metrics (Quality Classes):")
print(f"  Accuracy:          {cls['accuracy']:.4f}")
print(f"  Balanced Accuracy: {cls['balanced_accuracy']:.4f}")
print(f"  Weighted F1:       {cls['weighted_f1']:.4f}")

# Binary classification
# NOTE(review): assumes evaluate_model_comprehensive treats "accept"
# (score >= acceptance_threshold) as the positive class, matching the ROC cell
# in section 4.3 (`quality_targets >= 0.5`). Under that convention FPR is the
# share of unacceptable documents that get accepted (false accepts), and FNR is
# the share of acceptable documents that get rejected (false rejects).
# Confirm against src/evaluation/metrics.py.
binary = evaluation_results["binary_classification"]
print("\n✅/❌ Binary Classification (Accept/Reject):")
print(f"  F1 Score: {binary['f1']:.4f}")
print(f"  Precision: {binary['precision']:.4f}")
print(f"  Recall: {binary['recall']:.4f}")
print(f"  ROC AUC: {binary['roc_auc']:.4f}")
print(f"  False Accept Rate (FPR): {binary['fpr']:.4f}")
print(f"  False Reject Rate (FNR): {binary['fnr']:.4f}")

# Calibration
cal = evaluation_results["calibration"]
print("\n🎚️  Calibration:")
print(f"  ECE (Expected Calibration Error): {cal['ece']:.4f}")
print(f"  {'✓ Well calibrated' if cal['ece'] < 0.1 else '⚠ Needs calibration'}")

# Issue detection
issues = evaluation_results["issue_detection"]
print("\n🔍 Issue Detection:")
print(f"  Macro F1: {issues['macro_f1']:.4f}")
print(f"  Micro F1: {issues['micro_f1']:.4f}")

## 4. Visualizations

### 4.1 Calibration Curve

In [None]:
# Reliability plot built from the calibration section of the evaluation results.
calibration_metrics = evaluation_results["calibration"]
plot_calibration_curve(calibration_metrics)
plt.show()

### 4.2 Confusion Matrix

In [None]:
# Quality-class confusion matrix, converted to an ndarray before plotting.
confusion = evaluation_results["classification"]["confusion_matrix"]
cm = np.array(confusion)
plot_confusion_matrix(cm)
plt.show()

### 4.3 ROC Curve

In [None]:
# Per-sample predicted and true quality scores; later sections reuse both arrays.
raw_preds = evaluation_results["raw_predictions"]
quality_scores, quality_targets = (
    np.array(raw_preds[key]) for key in ("quality_scores", "quality_targets")
)

# Ground-truth accept labels: 1 when the true score is at least 0.5, else 0.
is_acceptable = quality_targets >= 0.5
binary_targets = is_acceptable.astype(int)

plot_roc_curves(quality_scores, binary_targets)
plt.show()

### 4.4 Prediction Distribution

In [None]:
# Compare predicted scores against ground-truth scores (arrays from section 4.3).
distribution_inputs = (quality_scores, quality_targets)
plot_prediction_distribution(*distribution_inputs)
plt.show()

### 4.5 Issue Detection Performance

In [None]:
# Per-issue detection performance from the issue_detection results.
issue_metrics = evaluation_results["issue_detection"]
plot_issue_analysis(issue_metrics)
plt.show()

## 5. Test on Individual Images with Explanations

In [None]:
# Get target layer for Grad-CAM
target_layer = get_target_layer(model, model_config.backbone)

# Test on a few randomly chosen images. A fixed seed makes reruns show the same
# images; set SAMPLE_SEED = None for a fresh selection each run. Never request
# more samples than the dataset holds, because replace=False would raise.
num_samples = 5
SAMPLE_SEED = 0
rng = np.random.default_rng(SAMPLE_SEED)
sample_indices = rng.choice(
    len(test_dataset), size=min(num_samples, len(test_dataset)), replace=False
)

for sample_no, idx in enumerate(sample_indices, start=1):
    # Load image; the context manager closes the underlying file handle
    # (convert() returns an independent in-memory copy).
    img_path = Path(test_dataset.image_paths[idx])
    with Image.open(img_path) as img:
        image = img.convert("RGB")

    # Generate explanation (prediction plus an overlay image to display)
    predictions, vis_image = visualize_explanation(
        image=image,
        model=model,
        transform=test_dataset.transform,
        target_layer=target_layer,
        save_path=None,
    )

    # Display. The title shows the running sample number and the dataset index.
    plt.figure(figsize=(20, 12))
    plt.imshow(vis_image)
    plt.axis("off")
    plt.title(
        f"Sample {sample_no} (index {idx}): {img_path.name}",
        fontsize=16,
        fontweight="bold",
    )
    plt.tight_layout()
    plt.show()

    print("\n" + "-"*80 + "\n")

## 6. Analyze Worst Predictions

In [None]:
# Find samples with largest errors
errors = np.abs(quality_scores - quality_targets)
worst_indices = np.argsort(errors)[-5:][::-1]

print("📉 Top 5 Worst Predictions:\n")
for rank, idx in enumerate(worst_indices, 1):
    print(f"{rank}. Index {idx}:")
    print(f"   Predicted: {quality_scores[idx]:.3f}")
    print(f"   Actual:    {quality_targets[idx]:.3f}")
    print(f"   Error:     {errors[idx]:.3f}")
    print()

# Visualize worst predictions
print("\nVisualizing worst predictions...\n")
for idx in worst_indices[:3]:  # Show top 3
    img_path = Path(test_dataset.image_paths[idx])
    image = Image.open(img_path).convert("RGB")
    
    predictions, vis_image = visualize_explanation(
        image=image,
        model=model,
        transform=test_dataset.transform,
        target_layer=target_layer,
    )
    
    plt.figure(figsize=(20, 12))
    plt.imshow(vis_image)
    plt.axis("off")
    plt.title(
        f"Error: {errors[idx]:.3f} | Pred: {quality_scores[idx]:.3f} | True: {quality_targets[idx]:.3f}",
        fontsize=16,
        fontweight="bold",
    )
    plt.tight_layout()
    plt.show()

## 7. Per-Class Analysis

In [None]:
# Analyze performance by quality class
class_names = ["High", "Good", "Moderate", "Poor", "Very Poor"]
quality_class_targets = np.array(evaluation_results["raw_predictions"]["quality_class_targets"])

fig, axes = plt.subplots(1, 5, figsize=(20, 4))

for class_idx, class_name in enumerate(class_names):
    # Get predictions for this class
    mask = quality_class_targets == class_idx
    class_scores = quality_scores[mask]
    class_targets = quality_targets[mask]
    
    if len(class_scores) > 0:
        ax = axes[class_idx]
        ax.scatter(class_targets, class_scores, alpha=0.5, s=20)
        ax.plot([0, 1], [0, 1], 'r--', lw=2)
        ax.set_xlim([0, 1])
        ax.set_ylim([0, 1])
        ax.set_xlabel("True Score")
        ax.set_ylabel("Predicted Score")
        ax.set_title(f"{class_name}\n(n={len(class_scores)})")
        ax.grid(True, alpha=0.3)
        
        # Calculate MAE for this class
        class_mae = np.mean(np.abs(class_scores - class_targets))
        ax.text(0.05, 0.95, f"MAE: {class_mae:.3f}", 
               transform=ax.transAxes, va="top", fontsize=10,
               bbox=dict(boxstyle="round", facecolor="wheat", alpha=0.5))

plt.tight_layout()
plt.suptitle("Prediction Performance by Quality Class", fontsize=16, y=1.02)
plt.show()

## 8. Summary and Recommendations

In [None]:
print("="*80)
print("MODEL EVALUATION SUMMARY")
print("="*80)

# ECE below this value counts as well calibrated. It is the same cutoff the
# evaluation cell uses (`ece < 0.1`), so both cells agree at the boundary.
ECE_TARGET = 0.1

# Overall assessment
print("\n🎯 Overall Performance:")
if reg["r2"] > 0.8 and cls["balanced_accuracy"] > 0.8 and cal["ece"] < ECE_TARGET:
    print("  ✅ EXCELLENT - Model is performing well across all metrics")
elif reg["r2"] > 0.6 and cls["balanced_accuracy"] > 0.7:
    print("  ⚠️  GOOD - Model is usable but has room for improvement")
else:
    print("  ❌ NEEDS IMPROVEMENT - Model requires retraining")

# Specific recommendations
print("\n📋 Recommendations:")

if cal["ece"] >= ECE_TARGET:
    print("  • Model is poorly calibrated - consider temperature scaling")

# NOTE(review): assumes "accept" is the positive class in the binary metrics
# (as in the section 4.3 ROC cell). Then FPR = unacceptable documents accepted
# (quality-control risk) and FNR = acceptable documents rejected (user
# friction). Confirm against src/evaluation/metrics.py.
if binary["fpr"] > 0.1:
    print(f"  • High false accept rate ({binary['fpr']:.2%}) - quality control issue")

if binary["fnr"] > 0.1:
    print(f"  • High false reject rate ({binary['fnr']:.2%}) - may frustrate users")

if cls["balanced_accuracy"] < 0.7:
    print("  • Poor class balance - use weighted sampling or class weights")

if reg["within_10pct"] < 0.7:
    print("  • Many predictions are off by >10% - model needs more training")

# Check if any class has very poor performance
for class_name, metrics in cls["per_class"].items():
    if metrics["f1"] < 0.5:
        print(f"  • '{class_name}' class has F1={metrics['f1']:.3f} - needs more training data")

print("\n💡 Next Steps:")
print("  1. Review the worst predictions (Section 6)")
print("  2. Analyze per-class performance (Section 7)")
print("  3. Collect more real-world training data if needed")
print("  4. Consider model calibration techniques")
print("  5. Test on production data before deployment")
print("\n" + "="*80)

## 9. Export Results

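Save the aggregate evaluation metrics (everything except the raw per-sample predictions) as JSON in a timestamped directory for later analysis or reporting.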

In [None]:
import json
from datetime import datetime

# Create a timestamped results directory so repeated runs never overwrite
# earlier results.
results_dir = project_root / "evaluation_results" / datetime.now().strftime("%Y%m%d_%H%M%S")
results_dir.mkdir(parents=True, exist_ok=True)


def _to_jsonable(obj):
    """Fallback for `json.dump` that converts NumPy/torch values to native types.

    NumPy scalars such as int64/float32, ndarrays and tensors are not JSON
    serializable by default. Any other unsupported type still raises TypeError,
    so unexpected values surface instead of being silently stringified.
    """
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, (np.ndarray, torch.Tensor)):
        return obj.tolist()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


# Save metrics as JSON, excluding the per-sample raw predictions.
serializable_results = {
    k: v for k, v in evaluation_results.items() if k != "raw_predictions"
}

with open(results_dir / "evaluation_metrics.json", "w", encoding="utf-8") as f:
    json.dump(serializable_results, f, indent=2, default=_to_jsonable)

print(f"✓ Results saved to {results_dir}")