[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jman4162/rankcal/blob/main/examples/tutorial.ipynb)

# rankcal Tutorial: Calibration for Ranking Systems

This tutorial demonstrates how to use **rankcal** to calibrate ranking scores and make better decisions. You'll learn:

1. Why calibration matters for ranking
2. How to detect and visualize miscalibration
3. How to choose and apply the right calibrator
4. Why top-k calibration is different from overall calibration
5. How to use calibrated scores for decision-making

---

## Quick Start (TL;DR)

```python
from rankcal import IsotonicCalibrator, ece

# 1. Fit calibrator on held-out data (NOT training data!)
calibrator = IsotonicCalibrator()
calibrator.fit(validation_scores, validation_labels)

# 2. Apply to new data
calibrated = calibrator(test_scores)

# 3. Verify improvement
print(f"ECE before: {ece(test_scores, test_labels):.4f}")
print(f"ECE after:  {ece(calibrated, test_labels):.4f}")
```

That's it! Read on for the full story.

---

## Installation (Colab)

If you're running this in Google Colab, uncomment and run the cell below to install rankcal:

In [None]:
# !pip install rankcal

## Setup

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# Set up beautiful plots
sns.set_theme(style="whitegrid", palette="husl")
plt.rcParams["figure.figsize"] = (10, 6)
plt.rcParams["font.size"] = 11
plt.rcParams["axes.titlesize"] = 13
plt.rcParams["axes.labelsize"] = 11

# Color palette for consistency
COLORS = {
    "uncalibrated": "#e74c3c",  # red
    "calibrated": "#27ae60",    # green
    "perfect": "#2c3e50",       # dark gray
    "positive": "#3498db",      # blue
    "negative": "#e67e22",      # orange
    "highlight": "#9b59b6",     # purple
}

import rankcal
print(f"rankcal version: {rankcal.__version__}")

---
## 1. What is Calibration and Why Does It Matter?

A ranking model produces **scores** that rank items. But can you trust these scores as **probabilities**?

**Calibration** means: if your model outputs 0.7 for many items, about 70% of them should actually be relevant.

This matters when you need to:
- Set a threshold ("show items with score > 0.5")
- Estimate expected outcomes ("how many relevant items in top 100?")
- Combine scores from multiple models
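
For the expected-outcome use case, calibrated scores turn the estimate into a simple sum: by linearity of expectation, the expected number of relevant items in a set is the sum of their relevance probabilities. A minimal, framework-free sketch (the scores are made-up illustrative values, and `expected_relevant` is a hypothetical helper, not part of rankcal):

```python
def expected_relevant(calibrated_scores, k):
    """Expected number of relevant items among the k highest-scored items.

    Only meaningful if the scores are calibrated probabilities.
    """
    top_k = sorted(calibrated_scores, reverse=True)[:k]
    return sum(top_k)

# Made-up calibrated probabilities for six items
scores = [0.9, 0.8, 0.6, 0.5, 0.2, 0.1]
top3_estimate = expected_relevant(scores, k=3)
print(f"Expected relevant items in top 3: {top3_estimate:.1f}")
```

With miscalibrated scores the same sum would be systematically too high (overconfident model) or too low (underconfident model).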

Let's see what miscalibration looks like:

In [None]:
from rankcal import generate_miscalibrated_data, generate_calibrated_data

# Generate well-calibrated data
calibrated_scores, calibrated_labels = generate_calibrated_data(n_samples=5000, seed=42)

# Generate overconfident (miscalibrated) data - common in neural networks
overconfident_scores, overconfident_labels = generate_miscalibrated_data(
    n_samples=5000, temperature=0.5, seed=42  # temperature < 1 = overconfident
)

# Generate underconfident data
underconfident_scores, underconfident_labels = generate_miscalibrated_data(
    n_samples=5000, temperature=2.0, seed=42  # temperature > 1 = underconfident
)

print("Data generated!")
print(f"  - Calibrated: {len(calibrated_scores)} samples")
print(f"  - Overconfident: {len(overconfident_scores)} samples")
print(f"  - Underconfident: {len(underconfident_scores)} samples")

### Reliability Diagrams: Visualizing Calibration

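A **reliability diagram** groups predictions into score bins and plots each bin's mean predicted probability against the observed fraction of positives. A perfectly calibrated model lies on the diagonal; bars below the diagonal indicate overconfidence, bars above it indicate underconfidence.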

In [None]:
from rankcal.metrics.ece import calibration_error_per_bin

def to_np(tensor):
    """Convert a tensor (or array-like) to a NumPy array, dropping gradients."""
    if hasattr(tensor, 'detach'):
        return tensor.detach().cpu().numpy()
    return np.asarray(tensor)

def plot_reliability_diagram(ax, scores, labels, title, color, n_bins=10):
    """Create a beautiful reliability diagram."""
    # Detach scores if they have gradients
    if hasattr(scores, 'detach'):
        scores = scores.detach()
    
    centers, accs, confs, counts = calibration_error_per_bin(scores, labels, n_bins=n_bins)
    mask = counts > 0
    
    # Perfect calibration line
    ax.plot([0, 1], [0, 1], '--', color=COLORS["perfect"], linewidth=2, 
            label="Perfect calibration", zorder=1)
    
    # Confidence bars
    bar_width = 0.08
    bars = ax.bar(to_np(confs[mask]), to_np(accs[mask]), width=bar_width,
                  color=color, alpha=0.8, edgecolor="white", linewidth=1.5, zorder=2)
    
    # Gap visualization (shaded area showing miscalibration)
    for conf, acc in zip(to_np(confs[mask]), to_np(accs[mask])):
        if acc != conf:
            ax.fill_between([conf - bar_width/2, conf + bar_width/2], 
                          [min(acc, conf), min(acc, conf)],
                          [max(acc, conf), max(acc, conf)],
                          color="red", alpha=0.2, zorder=1)
    
    ax.set_xlabel("Mean Predicted Probability")
    ax.set_ylabel("Fraction of Positives")
    ax.set_title(title, fontweight="bold")
    ax.set_xlim(-0.02, 1.02)
    ax.set_ylim(-0.02, 1.02)
    ax.legend(loc="upper left", framealpha=0.9)
    ax.set_aspect("equal")

# Create comparison plot
fig, axes = plt.subplots(1, 3, figsize=(15, 5))

plot_reliability_diagram(axes[0], calibrated_scores, calibrated_labels, 
                        "Well-Calibrated", COLORS["calibrated"])
plot_reliability_diagram(axes[1], overconfident_scores, overconfident_labels,
                        "Overconfident\n(predictions too extreme)", COLORS["uncalibrated"])
plot_reliability_diagram(axes[2], underconfident_scores, underconfident_labels,
                        "Underconfident\n(predictions too moderate)", COLORS["highlight"])

plt.tight_layout()
plt.show()

### Measuring Calibration with ECE

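**Expected Calibration Error (ECE)** summarizes miscalibration as a single number: the gap between mean predicted probability and observed positive rate in each score bin, averaged across bins with weights proportional to bin size. Lower is better, and 0 means every bin sits on the diagonal.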
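
To make the definition concrete, here is a hand-rolled equal-width-bin ECE in plain Python. It is a sketch of the standard formula, not rankcal's implementation: `expected_calibration_error` is a hypothetical name, scores are assumed to lie in [0, 1], and details such as bin-edge handling may differ from `rankcal.ece`.

```python
def expected_calibration_error(scores, labels, n_bins=10):
    """Equal-width-bin ECE: sum over bins of (bin share) * |accuracy - confidence|."""
    n = len(scores)
    bin_scores = [[] for _ in range(n_bins)]
    bin_labels = [[] for _ in range(n_bins)]
    for s, y in zip(scores, labels):
        # Scores of exactly 1.0 go into the last bin
        b = min(int(s * n_bins), n_bins - 1)
        bin_scores[b].append(s)
        bin_labels[b].append(y)

    total = 0.0
    for ss, ys in zip(bin_scores, bin_labels):
        if not ss:
            continue  # empty bins contribute nothing
        confidence = sum(ss) / len(ss)  # mean predicted probability
        accuracy = sum(ys) / len(ys)    # observed fraction of positives
        total += (len(ss) / n) * abs(accuracy - confidence)
    return total

# Ten items scored 0.8, of which only six are relevant: one occupied bin, gap of 0.2
scores = [0.8] * 10
labels = [1] * 6 + [0] * 4
print(f"ECE = {expected_calibration_error(scores, labels):.2f}")
```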

In [None]:
from rankcal import ece, adaptive_ece, mce

datasets = [
    ("Well-Calibrated", calibrated_scores, calibrated_labels),
    ("Overconfident", overconfident_scores, overconfident_labels),
    ("Underconfident", underconfident_scores, underconfident_labels),
]

print("Calibration Metrics Comparison")
print("=" * 55)
print(f"{'Dataset':<18} {'ECE':>10} {'Adaptive ECE':>14} {'MCE':>10}")
print("-" * 55)

for name, scores, labels in datasets:
    ece_val = ece(scores, labels).item()
    aece_val = adaptive_ece(scores, labels).item()
    mce_val = mce(scores, labels).item()
    print(f"{name:<18} {ece_val:>10.4f} {aece_val:>14.4f} {mce_val:>10.4f}")

---
## 2. Fixing Miscalibration: Comparing Calibrators

rankcal provides 4 calibrators with different tradeoffs:

| Calibrator | Differentiable | Parameters | Best For |
|------------|----------------|------------|----------|
| `IsotonicCalibrator` | No | 0 | Post-hoc, production |
| `TemperatureScaling` | Yes | 1 | Simple over/underconfidence |
| `PiecewiseLinearCalibrator` | Yes | 10-20 | Moderate miscalibration |
| `MonotonicNNCalibrator` | Yes | ~100+ | Complex patterns |
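
To illustrate what a one-parameter calibrator does, here is a minimal sketch of temperature scaling for scores in (0, 1). It assumes the convention `calibrated = sigmoid(logit(score) / T)` and fits `T` by grid search on log loss; rankcal's `TemperatureScaling` may use a different parameterization or optimizer, and every function name below is hypothetical.

```python
import math

def logit(p):
    return math.log(p / (1.0 - p))

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def apply_temperature(scores, temperature):
    """Rescale scores in logit space; temperature > 1 pulls them toward 0.5."""
    return [sigmoid(logit(s) / temperature) for s in scores]

def log_loss(probs, labels):
    eps = 1e-12  # guard against log(0)
    return -sum(
        y * math.log(max(p, eps)) + (1 - y) * math.log(max(1 - p, eps))
        for p, y in zip(probs, labels)
    ) / len(labels)

def fit_temperature(scores, labels):
    """Pick the temperature on a coarse grid (0.5 to 4.0) that minimizes log loss."""
    grid = [0.5 + 0.1 * i for i in range(36)]
    return min(grid, key=lambda t: log_loss(apply_temperature(scores, t), labels))

# Synthetic overconfident model: true rates 0.1..0.9, logits doubled before scoring
scores, labels = [], []
for i in range(1, 10):
    true_rate = i / 10
    scores += [sigmoid(2.0 * logit(true_rate))] * 10
    labels += [1.0] * i + [0.0] * (10 - i)

best_t = fit_temperature(scores, labels)
calibrated = apply_temperature(scores, best_t)
print(f"fitted temperature: {best_t:.1f}")
```

Under this convention a fitted temperature above 1 softens overconfident scores. Because a single scalar can only stretch or compress all scores uniformly in logit space, the method suits simple over- or underconfidence rather than irregular patterns.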

In [None]:
from rankcal import (
    IsotonicCalibrator,
    TemperatureScaling,
    PiecewiseLinearCalibrator,
    MonotonicNNCalibrator,
)

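# Use overconfident data: the first 3000 samples are the held-out set used to fit
# the calibrators (named "train" here); the remaining samples are the test set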
n_train = 3000
train_scores, train_labels = overconfident_scores[:n_train], overconfident_labels[:n_train]
test_scores, test_labels = overconfident_scores[n_train:], overconfident_labels[n_train:]

# Fit all calibrators
calibrators = {
    "Isotonic": IsotonicCalibrator(),
    "Temperature": TemperatureScaling(),
    "Piecewise Linear": PiecewiseLinearCalibrator(n_knots=10),
    "Monotonic NN": MonotonicNNCalibrator(hidden_dims=(16, 16)),
}

calibrated_outputs = {}
for name, cal in calibrators.items():
    cal.fit(train_scores, train_labels)
    calibrated_outputs[name] = cal(test_scores)
    
print("All calibrators fitted!")

In [None]:
# Compare results
fig, axes = plt.subplots(2, 3, figsize=(15, 10))
axes = axes.flatten()

# Original uncalibrated
plot_reliability_diagram(axes[0], test_scores, test_labels,
                        f"Uncalibrated\nECE = {ece(test_scores, test_labels):.4f}",
                        COLORS["uncalibrated"])

# Each calibrator
cal_colors = [COLORS["calibrated"], COLORS["positive"], COLORS["highlight"], COLORS["negative"]]
for i, (name, cal_scores) in enumerate(calibrated_outputs.items()):
    ece_val = ece(cal_scores, test_labels)
    plot_reliability_diagram(axes[i+1], cal_scores, test_labels,
                            f"{name}\nECE = {ece_val:.4f}",
                            cal_colors[i])

# Hide last subplot
axes[5].axis("off")

# Add summary text
summary_text = "Summary:\n\n"
summary_text += f"{'Method':<20} {'ECE':>8}\n"
summary_text += "-" * 30 + "\n"
summary_text += f"{'Uncalibrated':<20} {ece(test_scores, test_labels):>8.4f}\n"
for name, cal_scores in calibrated_outputs.items():
    summary_text += f"{name:<20} {ece(cal_scores, test_labels):>8.4f}\n"

axes[5].text(0.1, 0.5, summary_text, transform=axes[5].transAxes, 
            fontsize=12, fontfamily="monospace", verticalalignment="center")

plt.suptitle("Calibration Comparison: Overconfident Model", fontsize=14, fontweight="bold", y=1.02)
plt.tight_layout()
plt.show()

### Visualizing the Calibration Functions

Each calibrator learns a different transformation from raw scores to calibrated probabilities:

In [None]:
# Plot calibration functions
fig, ax = plt.subplots(figsize=(10, 8))

# Input range
x = torch.linspace(0.01, 0.99, 200)

# Identity line
ax.plot(x.numpy(), x.numpy(), '--', color=COLORS["perfect"], linewidth=2,
        label="No calibration (identity)", zorder=1)

# Plot each calibrator's function
for (name, cal), color in zip(calibrators.items(), cal_colors):
    y = cal(x).detach()
    ax.plot(x.numpy(), y.numpy(), linewidth=2.5, label=name, color=color)

ax.set_xlabel("Raw Score", fontsize=12)
ax.set_ylabel("Calibrated Probability", fontsize=12)
ax.set_title("Learned Calibration Functions", fontsize=14, fontweight="bold")
ax.legend(loc="lower right", fontsize=11)
ax.set_xlim(0, 1)
ax.set_ylim(0, 1)
ax.set_aspect("equal")

# Add annotation
ax.annotate("Overconfident model:\nhigh scores pulled down",
           xy=(0.85, 0.6), fontsize=10, ha="center",
           bbox=dict(boxstyle="round,pad=0.3", facecolor="wheat", alpha=0.8))

plt.tight_layout()
plt.show()

---
## 3. The Key Insight: Top-k Calibration

**In ranking, you often only care about the top items.** A model might be well-calibrated overall but poorly calibrated in the top-k where decisions happen.

rankcal provides `ece_at_k` to measure calibration at specific positions.
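
The effect is easy to reproduce with synthetic data. The sketch below uses a deliberately coarse one-bin gap (mean score minus observed positive rate) instead of a binned ECE, and assumes that "top-k" means the k highest-scored items; the helper names are hypothetical and this is not how `ece_at_k` is implemented.

```python
def confidence_gap(scores, labels):
    """|mean predicted probability - observed positive rate| (a one-bin calibration gap)."""
    return abs(sum(scores) / len(scores) - sum(labels) / len(labels))

def top_k(scores, labels, k):
    """Return the scores and labels of the k highest-scored items."""
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:k]
    return [scores[i] for i in order], [labels[i] for i in order]

# Synthetic population: 900 low-scored items that are calibrated (score 0.2, 20% relevant)
# and 100 high-scored items that are overconfident (score 0.9, only 60% relevant).
scores = [0.2] * 900 + [0.9] * 100
labels = [1] * 180 + [0] * 720 + [1] * 60 + [0] * 40

overall_gap = confidence_gap(scores, labels)
top_scores, top_labels = top_k(scores, labels, k=100)
top_gap = confidence_gap(top_scores, top_labels)
print(f"overall gap: {overall_gap:.2f}, top-100 gap: {top_gap:.2f}")
```

The overall gap is small because the well-calibrated majority dominates the average, while the top 100 items, where decisions are made, are off by 0.30.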

In [None]:
from rankcal import ece_at_k, adaptive_ece_at_k, mce_at_k

# Use isotonic calibration
iso_calibrated = calibrated_outputs["Isotonic"]

# Measure ECE at different k values
k_values = [10, 25, 50, 100, 200, 500, 1000, len(test_scores)]

uncal_ece = [ece_at_k(test_scores, test_labels, k=k).item() for k in k_values]
cal_ece = [ece_at_k(iso_calibrated, test_labels, k=k).item() for k in k_values]

# Plot
fig, ax = plt.subplots(figsize=(10, 6))

x_pos = range(len(k_values))
width = 0.35

bars1 = ax.bar([p - width/2 for p in x_pos], uncal_ece, width, 
               label="Uncalibrated", color=COLORS["uncalibrated"], alpha=0.8)
bars2 = ax.bar([p + width/2 for p in x_pos], cal_ece, width,
               label="Calibrated (Isotonic)", color=COLORS["calibrated"], alpha=0.8)

ax.set_xlabel("k (position cutoff)", fontsize=12)
ax.set_ylabel("ECE@k", fontsize=12)
ax.set_title("Calibration Error at Different Ranking Depths", fontsize=14, fontweight="bold")
ax.set_xticks(x_pos)
ax.set_xticklabels([str(k) if k < len(test_scores) else "all" for k in k_values])
ax.legend()
ax.set_ylim(0, max(uncal_ece + cal_ece) * 1.2)  # leave headroom for both series

# Add improvement annotations
for i, (u, c) in enumerate(zip(uncal_ece, cal_ece)):
    if u > 0:
        improvement = (u - c) / u * 100
        ax.annotate(f"{improvement:.0f}%↓", xy=(i, max(u, c) + 0.01),
                   ha="center", fontsize=9, color=COLORS["calibrated"])

plt.tight_layout()
plt.show()

### Side-by-side: Overall vs Top-100 Reliability

In [None]:
def plot_topk_reliability(ax, scores, labels, k, title, color):
    """Plot reliability diagram for top-k items only."""
    # Detach if needed
    if hasattr(scores, 'detach'):
        scores = scores.detach()
    
    # Get top-k
    _, top_k_idx = torch.topk(scores, k)
    top_scores = scores[top_k_idx]
    top_labels = labels[top_k_idx]
    
    plot_reliability_diagram(ax, top_scores, top_labels, title, color)

fig, axes = plt.subplots(2, 2, figsize=(12, 12))

# Uncalibrated
plot_reliability_diagram(axes[0, 0], test_scores, test_labels,
                        f"Uncalibrated - All\nECE = {ece(test_scores, test_labels):.4f}",
                        COLORS["uncalibrated"])
plot_topk_reliability(axes[0, 1], test_scores, test_labels, 100,
                     f"Uncalibrated - Top 100\nECE@100 = {ece_at_k(test_scores, test_labels, 100):.4f}",
                     COLORS["uncalibrated"])

# Calibrated
plot_reliability_diagram(axes[1, 0], iso_calibrated, test_labels,
                        f"Calibrated - All\nECE = {ece(iso_calibrated, test_labels):.4f}",
                        COLORS["calibrated"])
plot_topk_reliability(axes[1, 1], iso_calibrated, test_labels, 100,
                     f"Calibrated - Top 100\nECE@100 = {ece_at_k(iso_calibrated, test_labels, 100):.4f}",
                     COLORS["calibrated"])

plt.suptitle("Why Top-k Calibration Matters", fontsize=14, fontweight="bold", y=1.02)
plt.tight_layout()
plt.show()

---
## 4. New in v0.2.0: Adaptive ECE and MCE

Standard ECE uses equal-width bins. When scores cluster in a narrow range, most bins end up empty or sparsely populated, so the estimate rests on a handful of bins.

**Adaptive ECE** uses equal-mass (quantile) bins, so each bin holds roughly the same number of samples.

**MCE (Maximum Calibration Error)** reports the worst-case bin error, important for safety-critical applications.
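
A plain-Python sketch of both ideas follows. It builds equal-mass bins by sorting and chunking, then reports the size-weighted mean gap (adaptive ECE) and the largest gap (MCE). Computing MCE on the equal-mass bins is a simplification for brevity; which binning `rankcal.mce` uses is not stated here, and all function names are hypothetical.

```python
def equal_mass_bin_gaps(scores, labels, n_bins=10):
    """Sort by score, cut into n_bins chunks of (nearly) equal size, and
    return a list of (bin_size, |accuracy - confidence|), one entry per chunk."""
    pairs = sorted(zip(scores, labels))
    n = len(pairs)
    gaps = []
    for b in range(n_bins):
        chunk = pairs[b * n // n_bins:(b + 1) * n // n_bins]
        if not chunk:
            continue  # happens only when there are fewer samples than bins
        confidence = sum(s for s, _ in chunk) / len(chunk)
        accuracy = sum(y for _, y in chunk) / len(chunk)
        gaps.append((len(chunk), abs(accuracy - confidence)))
    return gaps

def adaptive_ece_sketch(scores, labels, n_bins=10):
    """Size-weighted mean of the per-bin gaps."""
    gaps = equal_mass_bin_gaps(scores, labels, n_bins)
    n = sum(size for size, _ in gaps)
    return sum(size / n * gap for size, gap in gaps)

def mce_sketch(scores, labels, n_bins=10):
    """Largest per-bin gap."""
    return max(gap for _, gap in equal_mass_bin_gaps(scores, labels, n_bins))

# Scores clustered in [0.80, 0.90) while only half of the items are relevant
scores = [0.80 + 0.001 * i for i in range(100)]
labels = [i % 2 for i in range(100)]

aece_value = adaptive_ece_sketch(scores, labels)
mce_value = mce_sketch(scores, labels)
print(f"adaptive ECE: {aece_value:.4f}, MCE: {mce_value:.4f}")
```

With equal-width bins, all 100 of these scores would fall into a single bin; the equal-mass bins spread them over ten bins of ten samples each.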

In [None]:
# Create skewed score distribution (common in practice)
torch.manual_seed(42)
# Most scores clustered near 0.8-0.9 (overconfident model)
skewed_scores = torch.clamp(0.85 + 0.1 * torch.randn(2000), 0.01, 0.99)
skewed_labels = (torch.rand(2000) < skewed_scores * 0.7).float()  # Actual rate lower

# Visualize the score distribution
fig, axes = plt.subplots(1, 3, figsize=(15, 5))

# Score histogram
ax = axes[0]
ax.hist(skewed_scores.numpy(), bins=30, color=COLORS["positive"], alpha=0.7, edgecolor="white")
ax.set_xlabel("Score")
ax.set_ylabel("Count")
ax.set_title("Skewed Score Distribution", fontweight="bold")
ax.axvline(x=skewed_scores.mean().item(), color=COLORS["uncalibrated"], linestyle="--", 
          label=f"Mean: {skewed_scores.mean():.2f}")
ax.legend()

# Standard ECE bins (equal-width)
ax = axes[1]
_, _, _, counts_std = calibration_error_per_bin(skewed_scores, skewed_labels, n_bins=10)
bin_edges = torch.linspace(0, 1, 11)
bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2
ax.bar(bin_centers.numpy(), counts_std.numpy(), width=0.09, color=COLORS["uncalibrated"], 
       alpha=0.7, edgecolor="white")
ax.set_xlabel("Bin Center")
ax.set_ylabel("Samples per Bin")
ax.set_title("Standard ECE: Equal-Width Bins", fontweight="bold")
ax.annotate("Many empty bins!", xy=(0.2, 50), fontsize=10,
           bbox=dict(boxstyle="round", facecolor="wheat", alpha=0.8))

# Adaptive ECE bins (equal-mass)
ax = axes[2]
# For visualization, show that adaptive bins have equal samples
quantiles = torch.linspace(0, 1, 11)
adaptive_boundaries = torch.quantile(skewed_scores, quantiles)
adaptive_centers = (adaptive_boundaries[:-1] + adaptive_boundaries[1:]) / 2
# Nominal count per bin: 2000 samples / 10 bins = 200 (illustrative, not measured)
ax.bar(adaptive_centers.numpy(), [200]*10, width=0.05, color=COLORS["calibrated"],
       alpha=0.7, edgecolor="white")
ax.set_xlabel("Bin Center (data-driven)")
ax.set_ylabel("Samples per Bin")
ax.set_title("Adaptive ECE: Equal-Mass Bins", fontweight="bold")
ax.annotate("All bins well-populated!", xy=(0.75, 220), fontsize=10,
           bbox=dict(boxstyle="round", facecolor="lightgreen", alpha=0.8))

plt.tight_layout()
plt.show()

# Compare metrics
print("\nMetrics on Skewed Distribution:")
print(f"  Standard ECE:  {ece(skewed_scores, skewed_labels):.4f}")
print(f"  Adaptive ECE:  {adaptive_ece(skewed_scores, skewed_labels):.4f}")
print(f"  MCE (worst):   {mce(skewed_scores, skewed_labels):.4f}")

### MCE: Finding Worst-Case Calibration

MCE reports the largest calibration error in any single bin. It matters when one badly calibrated score range is costly even if the average error is low.

In [None]:
# Compare ECE vs MCE across calibrators
metrics_data = []
methods = ["Uncalibrated"] + list(calibrated_outputs.keys())
all_outputs = {"Uncalibrated": test_scores, **calibrated_outputs}

for method in methods:
    scores = all_outputs[method]
    metrics_data.append({
        "Method": method,
        "ECE": ece(scores, test_labels).item(),
        "MCE": mce(scores, test_labels).item(),
    })

# Plot
fig, ax = plt.subplots(figsize=(10, 6))

x = range(len(methods))
width = 0.35

ece_vals = [d["ECE"] for d in metrics_data]
mce_vals = [d["MCE"] for d in metrics_data]

bars1 = ax.bar([i - width/2 for i in x], ece_vals, width, label="ECE (average)",
               color=COLORS["positive"], alpha=0.8)
bars2 = ax.bar([i + width/2 for i in x], mce_vals, width, label="MCE (worst-case)",
               color=COLORS["negative"], alpha=0.8)

ax.set_xlabel("Method")
ax.set_ylabel("Calibration Error")
ax.set_title("ECE vs MCE: Average vs Worst-Case", fontsize=14, fontweight="bold")
ax.set_xticks(x)
ax.set_xticklabels(methods, rotation=15, ha="right")
ax.legend()

# Add "MCE >= ECE always" annotation
ax.annotate("MCE ≥ ECE always\n(max ≥ mean)", xy=(4, 0.35), fontsize=10, ha="center",
           bbox=dict(boxstyle="round", facecolor="lightyellow", alpha=0.9))

plt.tight_layout()
plt.show()

---
## 5. Decision Analysis: Using Calibrated Scores

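The real power of calibration is enabling **optimal decisions**: once scores can be read as probabilities, thresholds and budgets can be chosen by expected cost and benefit.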

In [None]:
from rankcal import (
    risk_coverage_curve,
    utility_curve,
    optimal_threshold,
    utility_budget_curve,
)

# Use calibrated scores for decision analysis
cal_scores = iso_calibrated

### Risk-Coverage Curve

Shows the tradeoff between risk (error rate on shown items) and coverage (fraction of items shown): lowering risk usually means showing fewer items.
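
A single point on such a curve can be computed by hand. The sketch below assumes that an item is shown when its score is at or above the threshold and that risk is the fraction of shown items that are not relevant; rankcal's `risk_coverage_curve` may define these quantities differently, and `risk_coverage_point` is a hypothetical helper.

```python
def risk_coverage_point(scores, labels, threshold):
    """Coverage and risk when showing only items with score >= threshold."""
    shown = [y for s, y in zip(scores, labels) if s >= threshold]
    coverage = len(shown) / len(scores)
    # Risk: fraction of shown items that are not relevant (0.0 if nothing is shown)
    risk = (1 - sum(shown) / len(shown)) if shown else 0.0
    return coverage, risk

scores = [0.9, 0.8, 0.6, 0.4, 0.2]
labels = [1, 1, 0, 1, 0]
for threshold in (0.7, 0.5, 0.0):
    coverage, risk = risk_coverage_point(scores, labels, threshold)
    print(f"threshold {threshold:.1f}: coverage {coverage:.0%}, risk {risk:.0%}")
```

Sweeping the threshold over many values and plotting the resulting pairs gives the full curve.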

In [None]:
coverage, risk = risk_coverage_curve(cal_scores, test_labels, n_thresholds=200)

fig, ax = plt.subplots(figsize=(10, 6))

ax.fill_between(coverage.numpy(), risk.numpy(), alpha=0.3, color=COLORS["positive"])
ax.plot(coverage.numpy(), risk.numpy(), linewidth=2.5, color=COLORS["positive"])

ax.set_xlabel("Coverage (fraction of items shown)", fontsize=12)
ax.set_ylabel("Risk (error rate on shown items)", fontsize=12)
ax.set_title("Risk-Coverage Tradeoff", fontsize=14, fontweight="bold")
ax.set_xlim(0, 1)
ax.set_ylim(0, 1)

# Add operating points
for target_cov in [0.2, 0.5, 0.8]:
    idx = (coverage - target_cov).abs().argmin()
    ax.plot(coverage[idx], risk[idx], 'o', markersize=10, color=COLORS["highlight"])
    ax.annotate(f"  {coverage[idx]:.0%} coverage\n  {risk[idx]:.1%} error",
               xy=(coverage[idx], risk[idx]), fontsize=9)

plt.tight_layout()
plt.show()

### Utility Optimization

Find the optimal threshold given a cost/benefit tradeoff.
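
With calibrated scores, the break-even threshold has a closed form. Assume (for this sketch only; check the `utility_curve` docstring for rankcal's exact definition) that utility is `benefit` for each relevant item shown minus `cost` for each irrelevant item shown. Showing an item with relevance probability `p` then has expected utility `benefit * p - cost * (1 - p)`, which is positive exactly when `p > cost / (benefit + cost)`. The function names below are hypothetical.

```python
def expected_utility(p, benefit, cost):
    """Expected utility of showing one item whose relevance probability is p."""
    return benefit * p - cost * (1 - p)

def break_even_threshold(benefit, cost):
    """Probability above which showing an item has positive expected utility."""
    return cost / (benefit + cost)

for benefit, cost in [(1, 1), (1, 3), (3, 1)]:
    t = break_even_threshold(benefit, cost)
    print(f"benefit={benefit}, cost={cost}: show items with probability > {t:.2f}")
```

This shortcut is valid only if the scores are calibrated. An empirical sweep over thresholds on labeled data makes a useful cross-check.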

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Different cost scenarios
scenarios = [
    (1, 1, "Equal (b=1, c=1)"),
    (1, 3, "Conservative (b=1, c=3)"),
    (3, 1, "Aggressive (b=3, c=1)"),
]

# Utility curves
ax = axes[0]
colors = [COLORS["positive"], COLORS["uncalibrated"], COLORS["calibrated"]]

for (benefit, cost, label), color in zip(scenarios, colors):
    thresholds, utility = utility_curve(cal_scores, test_labels, benefit=benefit, cost=cost)
    ax.plot(thresholds.numpy(), utility.numpy(), linewidth=2, label=label, color=color)
    
    # Mark optimal
    opt_thresh, opt_util = optimal_threshold(cal_scores, test_labels, benefit=benefit, cost=cost)
    ax.plot(opt_thresh.item(), opt_util.item(), 'o', markersize=10, color=color)

ax.set_xlabel("Threshold", fontsize=12)
ax.set_ylabel("Utility", fontsize=12)
ax.set_title("Utility vs Threshold", fontsize=13, fontweight="bold")
ax.legend(loc="best")
ax.axhline(y=0, color="gray", linestyle="--", alpha=0.5)

# Utility vs budget
ax = axes[1]
budgets, utilities = utility_budget_curve(cal_scores, test_labels, max_budget=len(test_labels))

ax.fill_between(budgets.numpy(), utilities.numpy(), alpha=0.3, color=COLORS["calibrated"])
ax.plot(budgets.numpy(), utilities.numpy(), linewidth=2, color=COLORS["calibrated"])

# Find break-even (largest budget with positive utility) and optimal budget
positive_idx = (utilities > 0).nonzero().flatten()
breakeven_idx = positive_idx[-1].item() if positive_idx.numel() > 0 else 0
optimal_idx = utilities.argmax().item()

ax.axvline(x=budgets[optimal_idx], color=COLORS["highlight"], linestyle="--", alpha=0.7)
ax.plot(budgets[optimal_idx], utilities[optimal_idx], 'o', markersize=12, 
        color=COLORS["highlight"], label=f"Optimal: {budgets[optimal_idx]} items")

ax.set_xlabel("Budget (# items to review)", fontsize=12)
ax.set_ylabel("Utility", fontsize=12)
ax.set_title("Utility vs Budget", fontsize=13, fontweight="bold")
ax.axhline(y=0, color="gray", linestyle="--", alpha=0.5)
ax.legend()

plt.tight_layout()
plt.show()

---
## 6. Complete Workflow Example

Here's the recommended pattern for using rankcal in practice:

In [None]:
# Simulate a real workflow
torch.manual_seed(123)

# 1. Your model produces raw scores (simulated here)
n_total = 10000
raw_scores = torch.sigmoid(torch.randn(n_total) * 1.5)  # Overconfident model
true_labels = (torch.rand(n_total) < raw_scores * 0.8).float()  # True relevance

# 2. Split: 70% train, 15% calibration, 15% test
n_train = int(0.70 * n_total)
n_cal = int(0.15 * n_total)

workflow_train_scores = raw_scores[:n_train]
workflow_train_labels = true_labels[:n_train]

workflow_cal_scores = raw_scores[n_train:n_train+n_cal]
workflow_cal_labels = true_labels[n_train:n_train+n_cal]

workflow_test_scores = raw_scores[n_train+n_cal:]
workflow_test_labels = true_labels[n_train+n_cal:]

print(f"Split: {n_train} train / {n_cal} calibration / {len(workflow_test_scores)} test")

# 3. Fit calibrator on calibration set (NOT training set!)
workflow_calibrator = IsotonicCalibrator()
workflow_calibrator.fit(workflow_cal_scores, workflow_cal_labels)

# 4. Apply to test set
workflow_calibrated = workflow_calibrator(workflow_test_scores)

# 5. Evaluate
print(f"\nTest Set Results:")
print(f"  ECE (before):     {ece(workflow_test_scores, workflow_test_labels):.4f}")
print(f"  ECE (after):      {ece(workflow_calibrated, workflow_test_labels):.4f}")
print(f"  ECE@50 (before):  {ece_at_k(workflow_test_scores, workflow_test_labels, k=50):.4f}")
print(f"  ECE@50 (after):   {ece_at_k(workflow_calibrated, workflow_test_labels, k=50):.4f}")

In [None]:
# Final visualization
fig, axes = plt.subplots(1, 3, figsize=(15, 5))

# Before
plot_reliability_diagram(axes[0], workflow_test_scores, workflow_test_labels,
                        f"Before Calibration\nECE = {ece(workflow_test_scores, workflow_test_labels):.4f}",
                        COLORS["uncalibrated"])

# After
plot_reliability_diagram(axes[1], workflow_calibrated, workflow_test_labels,
                        f"After Calibration\nECE = {ece(workflow_calibrated, workflow_test_labels):.4f}",
                        COLORS["calibrated"])

# Score distributions
ax = axes[2]
ax.hist(workflow_test_scores.numpy(), bins=30, alpha=0.5, label="Before", 
        color=COLORS["uncalibrated"], density=True)
ax.hist(workflow_calibrated.detach().numpy(), bins=30, alpha=0.5, label="After",
        color=COLORS["calibrated"], density=True)
ax.set_xlabel("Score")
ax.set_ylabel("Density")
ax.set_title("Score Distribution Shift", fontweight="bold")
ax.legend()

plt.suptitle("Complete Calibration Workflow Results", fontsize=14, fontweight="bold", y=1.02)
plt.tight_layout()
plt.show()

---
## 7. Bonus: Calibration Heatmap

Visualize how calibration varies across both score ranges and ranking positions:

In [None]:
def compute_calibration_heatmap(scores, labels, k_values, n_bins=5):
    """Compute calibration error for different k and score ranges."""
    # Detach if needed
    if hasattr(scores, 'detach'):
        scores = scores.detach()
    
    heatmap = np.zeros((len(k_values), n_bins))
    
    bin_edges = torch.linspace(0, 1, n_bins + 1)
    
    for i, k in enumerate(k_values):
        # Get top-k
        _, top_k_idx = torch.topk(scores, min(k, len(scores)))
        top_scores = scores[top_k_idx]
        top_labels = labels[top_k_idx]
        
        # Compute per-bin error
        for j in range(n_bins):
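            # Include the right edge in the last bin so scores of exactly 1.0 are counted
            upper = top_scores <= bin_edges[j+1] if j == n_bins - 1 else top_scores < bin_edges[j+1]
            mask = (top_scores >= bin_edges[j]) & upper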
            if mask.sum() > 0:
                conf = top_scores[mask].mean()
                acc = top_labels[mask].mean()
                heatmap[i, j] = abs(conf - acc).item()
            else:
                heatmap[i, j] = np.nan
                
    return heatmap

# Compare uncalibrated vs calibrated
k_values = [50, 100, 200, 500, 1000, 2000]

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Uncalibrated heatmap
heatmap_uncal = compute_calibration_heatmap(test_scores, test_labels, k_values)
im1 = axes[0].imshow(heatmap_uncal, cmap="RdYlGn_r", aspect="auto", vmin=0, vmax=0.4)
axes[0].set_yticks(range(len(k_values)))
axes[0].set_yticklabels(k_values)
axes[0].set_xticks(range(5))
axes[0].set_xticklabels(["0-0.2", "0.2-0.4", "0.4-0.6", "0.6-0.8", "0.8-1.0"])
axes[0].set_ylabel("Top-k")
axes[0].set_xlabel("Score Range")
axes[0].set_title("Uncalibrated: Calibration Error by Position & Score", fontweight="bold")

# Add values
for i in range(len(k_values)):
    for j in range(5):
        val = heatmap_uncal[i, j]
        if not np.isnan(val):
            axes[0].text(j, i, f"{val:.2f}", ha="center", va="center", 
                        color="white" if val > 0.2 else "black", fontsize=9)

# Calibrated heatmap
heatmap_cal = compute_calibration_heatmap(iso_calibrated, test_labels, k_values)
im2 = axes[1].imshow(heatmap_cal, cmap="RdYlGn_r", aspect="auto", vmin=0, vmax=0.4)
axes[1].set_yticks(range(len(k_values)))
axes[1].set_yticklabels(k_values)
axes[1].set_xticks(range(5))
axes[1].set_xticklabels(["0-0.2", "0.2-0.4", "0.4-0.6", "0.6-0.8", "0.8-1.0"])
axes[1].set_ylabel("Top-k")
axes[1].set_xlabel("Score Range")
axes[1].set_title("Calibrated: Calibration Error by Position & Score", fontweight="bold")

# Add values
for i in range(len(k_values)):
    for j in range(5):
        val = heatmap_cal[i, j]
        if not np.isnan(val):
            axes[1].text(j, i, f"{val:.2f}", ha="center", va="center",
                        color="white" if val > 0.2 else "black", fontsize=9)

# Colorbar
fig.colorbar(im2, ax=axes, label="Calibration Error", shrink=0.8)

plt.suptitle("Where is Miscalibration Worst?", fontsize=14, fontweight="bold", y=1.02)
plt.tight_layout()
plt.show()

---
## 8. Production Monitoring Dashboard

Track calibration metrics over time or across user segments:

In [None]:
# Simulate monitoring data over time (e.g., daily batches)
def simulate_monitoring_data(n_days=14):
    """Simulate calibration drift over time."""
    np.random.seed(42)
    data = []
    
    for day in range(n_days):
        # Simulate gradual drift (calibration gets worse over time)
        drift = day * 0.005
        n_samples = np.random.randint(800, 1200)
        
        scores = torch.rand(n_samples)
        # Add drift: model becomes overconfident over time
        labels = (torch.rand(n_samples) < (scores - drift).clamp(0, 1)).float()
        
        data.append({
            "day": day + 1,
            "ece": ece(scores, labels).item(),
            "mce": mce(scores, labels).item(),
            "ece_at_100": ece_at_k(scores, labels, k=100).item(),
            "n_samples": n_samples,
        })
    
    return data

monitoring_data = simulate_monitoring_data()

# Create monitoring dashboard
fig, axes = plt.subplots(2, 2, figsize=(14, 10))

days = [d["day"] for d in monitoring_data]

# ECE over time
ax = axes[0, 0]
ece_vals = [d["ece"] for d in monitoring_data]
ax.plot(days, ece_vals, 'o-', color=COLORS["positive"], linewidth=2, markersize=8)
ax.axhline(y=0.05, color=COLORS["uncalibrated"], linestyle="--", alpha=0.7, label="Alert threshold")
ax.fill_between(days, 0, ece_vals, alpha=0.2, color=COLORS["positive"])
ax.set_xlabel("Day")
ax.set_ylabel("ECE")
ax.set_title("Daily ECE Trend", fontweight="bold")
ax.legend()
ax.set_ylim(0, max(ece_vals) * 1.2)

# ECE vs MCE comparison
ax = axes[0, 1]
mce_vals = [d["mce"] for d in monitoring_data]
ax.plot(days, ece_vals, 'o-', color=COLORS["positive"], linewidth=2, markersize=6, label="ECE (avg)")
ax.plot(days, mce_vals, 's-', color=COLORS["negative"], linewidth=2, markersize=6, label="MCE (worst)")
ax.set_xlabel("Day")
ax.set_ylabel("Calibration Error")
ax.set_title("ECE vs MCE Over Time", fontweight="bold")
ax.legend()

# ECE@100 (top-k monitoring)
ax = axes[1, 0]
ece_100 = [d["ece_at_100"] for d in monitoring_data]
ax.bar(days, ece_100, color=COLORS["highlight"], alpha=0.7, edgecolor="white")
ax.axhline(y=0.10, color=COLORS["uncalibrated"], linestyle="--", alpha=0.7, label="Alert threshold")
ax.set_xlabel("Day")
ax.set_ylabel("ECE@100")
ax.set_title("Top-100 Calibration (Where Decisions Happen)", fontweight="bold")
ax.legend()

# Sample count
ax = axes[1, 1]
samples = [d["n_samples"] for d in monitoring_data]
ax.bar(days, samples, color=COLORS["calibrated"], alpha=0.7, edgecolor="white")
ax.set_xlabel("Day")
ax.set_ylabel("Samples")
ax.set_title("Daily Sample Volume", fontweight="bold")
ax.axhline(y=np.mean(samples), color="gray", linestyle="--", alpha=0.7, label=f"Avg: {np.mean(samples):.0f}")
ax.legend()

plt.suptitle("Calibration Monitoring Dashboard", fontsize=14, fontweight="bold", y=1.02)
plt.tight_layout()
plt.show()

# Alert summary
print("\n" + "="*50)
print("MONITORING SUMMARY")
print("="*50)
latest = monitoring_data[-1]
print(f"Latest ECE:     {latest['ece']:.4f}  {'⚠️  ALERT' if latest['ece'] > 0.05 else '✓'}")
print(f"Latest MCE:     {latest['mce']:.4f}  {'⚠️  ALERT' if latest['mce'] > 0.15 else '✓'}")
print(f"Latest ECE@100: {latest['ece_at_100']:.4f}  {'⚠️  ALERT' if latest['ece_at_100'] > 0.10 else '✓'}")
print(f"\nTrend: ECE changed {(ece_vals[-1] - ece_vals[0])/ece_vals[0]*100:+.1f}% over {len(days)} days")
if ece_vals[-1] > ece_vals[0] * 1.5:
    print("⚠️  Consider recalibrating the model!")

---
## Summary

### Key Takeaways

1. **Calibration matters** when you use scores for decisions, not just ranking
2. **Reliability diagrams** visualize calibration quality
3. **ECE** measures average miscalibration; **MCE** measures worst-case
4. **Top-k calibration** often differs from overall - use `ece_at_k`
5. **Adaptive ECE** handles skewed score distributions better
6. **Never calibrate on training data** - always use a held-out set
7. **Monitor calibration** in production, because it can drift as data and models change

### Choosing a Calibrator

```
Need differentiable? ─No──→ IsotonicCalibrator (recommended default)
         │
        Yes
         │
         ├─ Simple scaling needed? ──→ TemperatureScaling
         │
         └─ Complex pattern? ──→ PiecewiseLinearCalibrator or MonotonicNNCalibrator
```
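
The same flowchart can be written as a small helper. This is only a restatement of the diagram: `suggest_calibrators` is a hypothetical function, not part of rankcal, and it returns class names as strings so that it runs without the library installed.

```python
def suggest_calibrators(needs_gradients, complex_pattern=False):
    """Return candidate rankcal calibrator class names, following the flowchart."""
    if not needs_gradients:
        return ["IsotonicCalibrator"]  # recommended default
    if complex_pattern:
        return ["PiecewiseLinearCalibrator", "MonotonicNNCalibrator"]
    return ["TemperatureScaling"]  # simple scaling

print(suggest_calibrators(needs_gradients=False))
print(suggest_calibrators(needs_gradients=True, complex_pattern=True))
```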

### Metrics Cheat Sheet

| Metric | Use When | Interpretation |
|--------|----------|----------------|
| `ece` | General calibration check | Lower is better; this tutorial alerts above 0.05 (a rule of thumb, not a universal cutoff) |
| `ece_at_k` | Ranking systems | Focus on top-k where decisions happen |
| `adaptive_ece` | Skewed score distributions | More robust than standard ECE |
| `mce` | Safety-critical applications | Worst-case bin error |

### Learn More

- **Conceptual Guide**: `docs/guide.md`
- **API Reference**: Docstrings for all functions
- **Examples**: `examples/` directory

---

*Tutorial created for rankcal v0.2.0*