# Model Comparison Workflow

This notebook demonstrates how to compare model predictions, assess accuracy, create ensembles, and generate reports.

**Prerequisites**: Trained classifiers and prediction rasters for at least two models.

In [None]:
import sys
sys.path.insert(0, '..')

import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path

from src.config_schema import CLASS_SCHEMA, CLASS_COLORS

## 1. Accuracy Assessment

The `AccuracyAssessor` compares predictions against reference points, computing overall accuracy, kappa, F1, and per-class metrics.

In [None]:
from src.validation.accuracy_assessor import AccuracyAssessor

# Example usage (requires actual raster and reference data):
# assessor = AccuracyAssessor()
# results = assessor.assess(
#     prediction='../data/outputs/landcover_prithvi.tif',
#     reference='../data/validation/reference_points.gpkg',
#     class_field='LC_CLASS',
#     output_dir='../data/outputs/accuracy',
# )
# print(f"Overall accuracy: {results['overall_accuracy']:.4f}")
# print(f"Kappa: {results['kappa']:.4f}")

# The example above needs real raster and reference inputs, so this cell does
# not run an assessment; it only prints the feature list.
supported_features = (
    'Overall accuracy, kappa, macro F1',
    'Per-class producer/user accuracy and F1',
    'Confusion matrix export',
    'Reference point CRS reprojection',
    'Multi-model comparison and ranking',
)

print('AccuracyAssessor supports:')
for feature in supported_features:
    print(f'  - {feature}')

## 2. Error Analysis

Spatial error analysis identifies where and why models fail.

In [None]:
from src.validation.error_analysis import (
    analyze_class_confusions,
    compute_spatial_error_density,
)

# Synthetic 3-class confusion matrix used to exercise analyze_class_confusions.
# Each row belongs to one class; the diagonal holds the correctly labelled count.
class_names = ['water', 'trees', 'built']
confusion = np.array([
    [85, 5, 10],   # water: 85 correct, 5 confused with trees, 10 with built
    [3, 90, 7],    # trees: 90 correct
    [8, 2, 90],    # built: 90 correct
])

analysis = analyze_class_confusions(confusion, class_names)

# Show at most the five first entries of the returned confusion-pair list.
print('Top confused class pairs:')
top_pairs = analysis['top_confusions'][:5]
for entry in top_pairs:
    summary = f"{entry['from']} -> {entry['to']}"
    print(f"  {summary}: {entry['count']} ({entry['rate']:.1%})")

# Omission rates are paired with class names positionally
# (assumes the analysis returns them in class_names order -- TODO confirm).
print("\nPer-class omission rates:")
for label, omission in zip(class_names, analysis['omission_rates']):
    print(f"  {label}: {omission:.1%}")

## 3. Report Generation

The report generator builds HTML reports with embedded plots for confusion matrices and per-class accuracy. The cell below renders those two plot types inline as matplotlib figures.

In [None]:
from src.validation.report_generator import (
    plot_confusion_matrix,
    plot_per_class_accuracy,
)

# Inputs are redefined here so this cell does not depend on the error-analysis cell.
class_names = ['water', 'trees', 'built']
confusion = np.array([[85, 5, 10],
                      [3, 90, 7],
                      [8, 2, 90]])

# Figure 1: confusion matrix.
fig_data = plot_confusion_matrix(confusion, class_names, output_mode='figure')
plt.show()

# Figure 2: per-class accuracy. Each class maps to its producer's accuracy,
# user's accuracy and F1 score (hand-written demo values).
metric_keys = ('producers_accuracy', 'users_accuracy', 'f1')
demo_scores = (
    ('water', (0.85, 0.89, 0.87)),
    ('trees', (0.90, 0.93, 0.91)),
    ('built', (0.90, 0.84, 0.87)),
)
per_class = {label: dict(zip(metric_keys, scores)) for label, scores in demo_scores}

fig_data = plot_per_class_accuracy(per_class, output_mode='figure')
plt.show()

## 4. Model Comparison

Compare two classification maps pixel-by-pixel to measure spatial agreement.

In [None]:
from src.validation.comparison_metrics import compute_spatial_agreement

# Demo with synthetic classification arrays: Model B is a copy of Model A in
# which a random ~20% of the pixels are forced to a different class.
np.random.seed(42)
size = (100, 100)
n_classes = 7
pred_a = np.random.randint(0, n_classes, size=size).astype(np.uint8)
pred_b = pred_a.copy()

# Pick ~20% of the pixels that should disagree between the two models.
mask = np.random.random(size) < 0.2

# Redrawing the selected pixels uniformly from all classes would leave about
# 1/n_classes of them unchanged (real disagreement ~17%, not 20%). Adding a
# non-zero offset modulo n_classes guarantees every selected pixel changes.
offsets = np.random.randint(1, n_classes, size=mask.sum())
pred_b[mask] = ((pred_a[mask] + offsets) % n_classes).astype(np.uint8)

# Pixel-wise comparison of the two synthetic maps; 255 is passed as the nodata value.
agreement = compute_spatial_agreement(pred_a, pred_b, nodata=255)

agreement_pct = agreement['agreement_pct']
print(f"Agreement: {agreement_pct:.1f}%")
print(f"Valid pixels: {agreement['valid_pixels']:,}")

# Three panels: both class maps plus a binary map (1 = same class, 0 = different).
agree_map = (pred_a == pred_b).astype(np.uint8)
class_style = dict(cmap='tab10', vmin=0, vmax=6)
panels = (
    (pred_a, 'Model A', class_style),
    (pred_b, 'Model B', class_style),
    (agree_map, f'Agreement ({agreement_pct:.0f}%)', dict(cmap='RdYlGn', vmin=0, vmax=1)),
)

fig, axes = plt.subplots(1, 3, figsize=(14, 4))
for ax, (image, title, style) in zip(axes, panels):
    ax.imshow(image, **style)
    ax.set_title(title)
    ax.axis('off')
plt.tight_layout()
plt.show()

## 5. Ensemble Methods

Combine predictions from multiple models to improve accuracy.

In [None]:
from src.classification.ensemble import EnsembleClassifier

def _noisy_copy(reference, flip_fraction=0.15, n_classes=7):
    """Return a copy of ``reference`` with a random subset of pixels redrawn.

    Roughly ``flip_fraction`` of the pixels are selected and replaced by a
    class drawn uniformly from ``[0, n_classes)``; the draw may repeat the
    original class. Uses the global numpy random state.
    """
    noisy = reference.copy()
    flip = np.random.random(reference.shape) < flip_fraction
    noisy[flip] = np.random.randint(0, n_classes, size=flip.sum()).astype(np.uint8)
    return noisy


# Three synthetic model outputs derived from one shared base map, so the
# models mostly agree and differ only where noise was injected.
np.random.seed(42)
size = (50, 50)
base = np.random.randint(0, 7, size=size).astype(np.uint8)

predictions = {
    model_name: _noisy_copy(base)
    for model_name in ('prithvi', 'satlas', 'ssl4eo')
}

ec = EnsembleClassifier()

# Majority vote across the three synthetic predictions.
result = ec.majority_vote(predictions)
print(f"Majority vote: {result['valid_pixels']} valid pixels")
print(f"Mean agreement: {result['mean_agreement']:.3f}")

# Weighted vote: one weight per model name.
weights = dict(prithvi=0.5, satlas=0.3, ssl4eo=0.2)
w_result = ec.weighted_vote(predictions, weights)
print(f"\nWeighted vote: mean confidence = {w_result['mean_confidence']:.3f}")

# Agreement statistics over all models.
agreement = ec.compute_agreement_map(predictions)
print(f"\nFull agreement pixels: {agreement['full_agreement_pct']:.1f}%")

# Side-by-side view of the majority-vote classes and the agreement layer
# returned with the majority-vote result.
fig, axes = plt.subplots(1, 2, figsize=(10, 4))
panels = (
    (result['classification'], 'Ensemble (Majority Vote)', dict(cmap='tab10', vmin=0, vmax=6)),
    (result['agreement'], 'Agreement Map', dict(cmap='YlGn', vmin=0, vmax=1)),
)
for ax, (image, title, style) in zip(axes, panels):
    ax.imshow(image, **style)
    ax.set_title(title)
    ax.axis('off')
plt.tight_layout()
plt.show()

## 6. Hierarchical Fusion

Merge predictions from sources with different spatial resolutions.

In [None]:
from src.classification.ensemble import HierarchicalFusion

# Stand-ins for a base (10m, Sentinel-2) and a refinement (1m, NAIP) prediction.
# Both arrays share one grid here; only the content differs.
np.random.seed(42)
size = (50, 50)
base_pred = np.random.randint(0, 7, size=size).astype(np.uint8)

# Build the refinement layer region by region:
#   rows 40+           -> nodata (255), i.e. no refinement coverage
#   rows <40, cols 40+ -> same classes as the base layer
#   rows <40, cols <40 -> a uniform patch of class 3 (grass)
refinement_pred = np.full_like(base_pred, 255)
refinement_pred[:40, 40:] = base_pred[:40, 40:]
refinement_pred[:40, :40] = 3  # grass

# Fuse the two layers with the 'high_res_priority' strategy; 255 marks nodata.
fusion = HierarchicalFusion(nodata=255)
result = fusion.fuse(base_pred, refinement_pred, strategy='high_res_priority')

for source_label, pct_key in (('Base', 'base_pct'), ('Refinement', 'refinement_pct')):
    print(f"{source_label} pixels: {result[pct_key]:.1f}%")

# For display only: turn nodata into NaN so imshow leaves those pixels blank.
# astype(float) already returns a new array, so refinement_pred is untouched.
ref_display = refinement_pred.astype(float)
ref_display[refinement_pred == 255] = np.nan

# Base, refinement and fused maps share the same class colour scale.
class_style = dict(cmap='tab10', vmin=0, vmax=6)
fig, axes = plt.subplots(1, 3, figsize=(14, 4))
panels = (
    (base_pred, 'Base (Sentinel-2)'),
    (ref_display, 'Refinement (NAIP)'),
    (result['classification'], 'Fused Result'),
)
for ax, (image, title) in zip(axes, panels):
    ax.imshow(image, **class_style)
    ax.set_title(title)
    ax.axis('off')
plt.tight_layout()
plt.show()

## 7. Resolution Matching

When combining rasters of different resolutions, the `ResolutionMatcher` aligns them to a common grid.

In [None]:
from src.processing.resolution_matcher import ResolutionMatcher

# This cell is informational only: it prints the capability list and a usage
# sketch as text and never calls ResolutionMatcher.
resolution_matcher_capabilities = [
    'Auto-detect finest resolution across inputs',
    'Compute intersection or union extents',
    'Reproject across CRS boundaries',
    'Nearest neighbor (classification) or bilinear (probability) resampling',
]

# The sketch previously assigned `aligned` but passed an undefined
# `aligned_paths` to load_aligned_arrays; one name is now used throughout.
resolution_matcher_example = [
    '  matcher = ResolutionMatcher()',
    '  aligned = matcher.align_rasters({',
    '      "sentinel": "10m_prediction.tif",',
    '      "naip": "1m_prediction.tif",',
    '  }, output_dir="aligned/")',
    '  arrays, meta = matcher.load_aligned_arrays(aligned)',
]

print('ResolutionMatcher capabilities:')
for capability in resolution_matcher_capabilities:
    print(f'  - {capability}')
print()
print('Example usage:')
for example_line in resolution_matcher_example:
    print(example_line)

## Summary

This notebook demonstrated:

1. **Accuracy assessment** against reference data
2. **Error analysis** to identify confusion patterns
3. **Report generation** with matplotlib plots
4. **Model comparison** with spatial agreement maps
5. **Ensemble methods** (majority vote, weighted vote)
6. **Hierarchical fusion** for multi-resolution data
7. **Resolution matching** for raster alignment

For the full CLI workflow, see `docs/workflow.md`.