# Your AI Metrics Are Lying to You: CJE Core Demo

This notebook demonstrates the **Causal Judge Evaluation (CJE)** concepts from first principles.

**The problem**: A response that scores high on quick metrics can deliver low value on what actually matters. For example, "You're absolutely right!" scored high on politeness but tanked developer productivity.

**The fix**: Learn how your cheap metrics (S) predict real outcomes (Y) using a calibration slice.

## The Deliberation Ladder

```
Y* │ Idealized Deliberation Oracle (unobservable)
   │ What you'd decide with unlimited time & perfect information
   │
Y  │ Oracle / High-Rung Outcome  
   │ Expensive but practical labels (expert audits, task success)
   │
S  │ Cheap Surrogate
   │ Fast signals at scale (LLM-judge scores, clicks)
```

CJE calibrates S→Y so you can aim abundant S data at Y*.

---

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/cimo-labs/cje/blob/main/examples/cje_core_demo.ipynb)

## Setup

In [None]:
# Install dependencies (uncomment for Colab); cje-eval is needed from Section 4 onward
# !pip install -q scikit-learn pandas numpy matplotlib seaborn cje-eval

In [None]:
import json
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.isotonic import IsotonicRegression
from pathlib import Path
import urllib.request
import os

# Set style
plt.style.use('seaborn-v0_8-whitegrid')
sns.set_palette("husl")

print("Setup complete!")

In [None]:
# Download Arena sample data from GitHub (for Colab)
DATA_DIR = Path("arena_sample")

# Check if data already exists locally
if (DATA_DIR / "fresh_draws" / "base_responses.jsonl").exists():
    print(f"Data already exists at {DATA_DIR.absolute()}")
else:
    print("Downloading data from GitHub...")
    DATA_DIR.mkdir(exist_ok=True)
    (DATA_DIR / "fresh_draws").mkdir(exist_ok=True)
    (DATA_DIR / "probe_slice").mkdir(exist_ok=True)
    
    BASE_URL = "https://raw.githubusercontent.com/cimo-labs/cje/main/examples/arena_sample"
    
    # Fresh draws
    fresh_files = ["base_responses.jsonl", "clone_responses.jsonl", 
                   "parallel_universe_prompt_responses.jsonl", "unhelpful_responses.jsonl"]
    for filename in fresh_files:
        print(f"  Downloading fresh_draws/{filename}...")
        urllib.request.urlretrieve(
            f"{BASE_URL}/fresh_draws/{filename}",
            DATA_DIR / "fresh_draws" / filename
        )
    
    # Probe slice (for transportability testing)
    probe_files = ["clone_probe.jsonl", "parallel_universe_prompt_probe.jsonl", "unhelpful_probe.jsonl"]
    for filename in probe_files:
        print(f"  Downloading probe_slice/{filename}...")
        urllib.request.urlretrieve(
            f"{BASE_URL}/probe_slice/{filename}",
            DATA_DIR / "probe_slice" / filename
        )
    print("Done!")

## 1. Load the Arena Sample Data

We have 4 policies evaluated on 1000 Chatbot Arena prompts:
- **base**: Llama 3.3 70B with standard prompt (logging policy)
- **clone**: Same as base, different seed (sanity check)
- **parallel_universe_prompt**: Modified system prompt
- **unhelpful**: Deliberately confusing responses (adversarial)

Each response has:
- **S (judge_score)**: GPT-4.1-nano score (cheap; 16x cheaper than the oracle)
- **Y (oracle_label)**: GPT-5 score (expensive oracle)

Oracle coverage:
- **base**: ~48% (calibration training set)
- **other policies**: ~5% each (probe slice for transportability testing)

In [None]:
def load_policy_data(policy: str, data_dir: Path) -> pd.DataFrame:
    """Load response data for a policy."""
    path = data_dir / "fresh_draws" / f"{policy}_responses.jsonl"
    records = []
    with open(path) as f:
        for line in f:
            r = json.loads(line)
            records.append({
                'prompt_id': r['prompt_id'],
                'prompt': r['prompt'],
                'response': r['response'],
                'policy': policy,
                'judge_score': r['judge_score'],  # S
                'oracle_label': r.get('oracle_label'),  # Y (may be None)
                'response_length': len(r['response'])
            })
    return pd.DataFrame(records)

def load_probe_slice(policy: str, data_dir: Path) -> pd.DataFrame:
    """Load probe slice with oracle labels for transportability testing."""
    path = data_dir / "probe_slice" / f"{policy}_probe.jsonl"
    if not path.exists():
        return pd.DataFrame()
    records = []
    with open(path) as f:
        for line in f:
            r = json.loads(line)
            records.append({
                'prompt_id': r['prompt_id'],
                'prompt': r['prompt'],
                'response': r['response'],
                'policy': policy,
                'judge_score': r['judge_score'],
                'oracle_label': r['oracle_label'],
                'response_length': len(r['response'])
            })
    return pd.DataFrame(records)

# Load data for all policies
POLICIES = ['base', 'clone', 'parallel_universe_prompt', 'unhelpful']

# Fresh draws (for estimation)
all_data = pd.concat([load_policy_data(p, DATA_DIR) for p in POLICIES], ignore_index=True)

# Probe slice (for transportability testing - separate from calibration training)
probe_data = pd.concat([load_probe_slice(p, DATA_DIR) for p in POLICIES if p != 'base'], ignore_index=True)

print(f"Loaded {len(all_data):,} fresh draw samples across {all_data['policy'].nunique()} policies")
print(f"Loaded {len(probe_data):,} probe slice samples for transportability testing")

# Show oracle coverage
oracle_coverage = all_data.groupby('policy')['oracle_label'].apply(lambda x: x.notna().mean())
print(f"\nOracle coverage (fresh_draws):")
print(oracle_coverage.round(2))

In [None]:
# Summary statistics
all_data.groupby('policy')[['judge_score', 'oracle_label']].agg(['mean', 'std', 'count']).round(3)

## 2. The Problem: Naive S ≠ Y

If we just use judge scores (S) directly, we might get the wrong answer. Let's see how S and Y differ across policies.
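
To make the failure mode concrete before looking at the real data, here is a minimal synthetic sketch. The two policies, the 0.3 inflation, and the noise levels are all made up for illustration; nothing here comes from the Arena sample.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 5000

# Hypothetical policy A: the judge score tracks true quality.
y_a = rng.uniform(0.4, 0.8, size=n)
s_a = np.clip(y_a + rng.normal(0, 0.05, size=n), 0, 1)

# Hypothetical policy B: confident-sounding output inflates S by 0.3.
y_b = rng.uniform(0.2, 0.6, size=n)
s_b = np.clip(y_b + 0.3 + rng.normal(0, 0.05, size=n), 0, 1)

naive_winner = "A" if s_a.mean() > s_b.mean() else "B"
true_winner = "A" if y_a.mean() > y_b.mean() else "B"
print(f"Mean S: A={s_a.mean():.3f}, B={s_b.mean():.3f} -> naive winner {naive_winner}")
print(f"Mean Y: A={y_a.mean():.3f}, B={y_b.mean():.3f} -> true winner {true_winner}")
```

Ranking by mean S and ranking by mean Y disagree here, which is the kind of inversion the rest of the notebook checks for.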

In [None]:
# Compare S vs Y by policy (only samples with oracle labels)
has_oracle = all_data[all_data['oracle_label'].notna()].copy()

policy_stats = has_oracle.groupby('policy').agg({
    'judge_score': 'mean',
    'oracle_label': 'mean'
}).round(3)
policy_stats['gap'] = (policy_stats['judge_score'] - policy_stats['oracle_label']).round(3)
policy_stats = policy_stats.sort_values('oracle_label', ascending=False)

print("Policy Rankings: Judge (S) vs Oracle (Y)")
print("="*50)
print(policy_stats)
print("\nNote the gap between judge and oracle for each policy!")

In [None]:
# Visualize S vs Y correlation
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Left: Overall correlation (samples with oracle)
ax = axes[0]
sample = has_oracle.sample(min(2000, len(has_oracle)), random_state=42)
ax.scatter(sample['judge_score'], sample['oracle_label'], alpha=0.3, s=10)
ax.plot([0, 1], [0, 1], 'r--', label='Perfect calibration')
ax.set_xlabel('Judge Score (S)', fontsize=12)
ax.set_ylabel('Oracle Label (Y)', fontsize=12)
ax.set_title('S vs Y: Overall Correlation', fontsize=14)
ax.legend()

# Right: By policy
ax = axes[1]
colors = {'base': 'C0', 'clone': 'C1', 'parallel_universe_prompt': 'C3', 'unhelpful': 'red'}
for policy in POLICIES:
    data = has_oracle[has_oracle['policy'] == policy]
    if len(data) > 0:
        ax.scatter(data['judge_score'].mean(), data['oracle_label'].mean(), 
                   s=200, label=policy, c=colors[policy], edgecolors='black', linewidth=2)
ax.plot([0, 1], [0, 1], 'k--', alpha=0.3)
ax.set_xlabel('Mean Judge Score (S)', fontsize=12)
ax.set_ylabel('Mean Oracle Label (Y)', fontsize=12)
ax.set_title('Policy Means: S vs Y', fontsize=14)
ax.legend(loc='lower right')

plt.tight_layout()
plt.show()

print("\nThe 'unhelpful' policy (red) shows the biggest gap - the judge is fooled!")

## 3. The Solution: Calibrate S→Y

We learn a calibration function on the **base policy** (our oracle slice), then apply it to all policies.

**AutoCal-R**: Isotonic regression ensures predictions are monotone in S and mean-preserving on the training slice.
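
The two properties can be checked directly on toy data. The sketch below uses synthetic scores and labels (an assumption for illustration) and plain scikit-learn `IsotonicRegression`; it does not reproduce any CJE-specific logic.

```python
import numpy as np
from sklearn.isotonic import IsotonicRegression

rng = np.random.default_rng(0)

# Synthetic judge scores and noisy, nonlinear oracle labels (illustrative only).
S = rng.uniform(0, 1, size=500)
Y = np.clip(S**2 + rng.normal(0, 0.1, size=500), 0, 1)

iso = IsotonicRegression(out_of_bounds="clip")
fitted = iso.fit_transform(S, Y)

# Monotone: predictions never decrease as S increases (tiny float tolerance).
grid_pred = iso.predict(np.linspace(0, 1, 200))
is_monotone = bool(np.all(np.diff(grid_pred) >= -1e-12))

# Mean-preserving on the training slice: fitted values average to mean(Y).
mean_gap = abs(fitted.mean() - Y.mean())

print(f"monotone: {is_monotone}, |mean(fitted) - mean(Y)| = {mean_gap:.2e}")
```

Mean preservation holds only on the data the calibrator was fit on. On other policies the mean residual can be far from zero, which is what the transportability check in the next section looks for.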

In [None]:
# Train calibrator on base policy only (samples with oracle labels)
base_with_oracle = all_data[(all_data['policy'] == 'base') & (all_data['oracle_label'].notna())].copy()
S_train = base_with_oracle['judge_score'].values
Y_train = base_with_oracle['oracle_label'].values

# Fit isotonic regression (AutoCal-R)
calibrator = IsotonicRegression(out_of_bounds='clip')
calibrator.fit(S_train, Y_train)

print(f"Calibrator trained on {len(S_train):,} base policy samples with oracle labels")
print(f"Mean S: {S_train.mean():.3f}, Mean Y: {Y_train.mean():.3f}")

In [None]:
# Visualize the calibration curve
fig, ax = plt.subplots(figsize=(10, 6))

# Scatter of training data
ax.scatter(S_train, Y_train, alpha=0.2, s=10, label='Training data (base with oracle)')

# Calibration curve
S_grid = np.linspace(0, 1, 100)
Y_pred = calibrator.predict(S_grid)
ax.plot(S_grid, Y_pred, 'r-', linewidth=3, label='Isotonic calibration')

# Perfect calibration line
ax.plot([0, 1], [0, 1], 'k--', alpha=0.5, label='Perfect calibration (S=Y)')

ax.set_xlabel('Judge Score (S)', fontsize=12)
ax.set_ylabel('Oracle Label (Y) / Calibrated Prediction', fontsize=12)
ax.set_title('AutoCal-R: Learning S→Y Mapping', fontsize=14)
ax.legend()
ax.set_xlim(0, 1)
ax.set_ylim(0, 1)

plt.tight_layout()
plt.show()

print("\nThe calibration curve shows how to map judge scores to oracle-scale predictions.")
print("Notice the curve is monotone (higher S -> higher predicted Y).")

## 4. Apply Calibration & Compute Residuals

Now we apply the calibrator to ALL policies and compute residuals (for samples with oracle labels):

**Residual = Y - Y_hat = Oracle - Calibrated_Prediction**

- Residual < 0: Calibrator **over-predicted** (predicted quality is higher than actual)
- Residual > 0: Calibrator **under-predicted** (predicted quality is lower than actual)
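
As a from-scratch sketch of the idea, the mean residual and a normal-approximation confidence interval can be computed as below. The helper name `mean_residual_check`, the PASS/FAIL rule (CI contains 0), and the numbers are assumptions for illustration; the `cje.diagnostics` audit used in the next cell may apply a different test.

```python
import numpy as np

def mean_residual_check(y, y_hat, z=1.96):
    """Return (delta_hat, ci_low, ci_high, status) for residuals y - y_hat."""
    residuals = np.asarray(y, dtype=float) - np.asarray(y_hat, dtype=float)
    delta_hat = residuals.mean()
    # Standard error of the mean, using the sample standard deviation.
    se = residuals.std(ddof=1) / np.sqrt(len(residuals))
    ci_low, ci_high = delta_hat - z * se, delta_hat + z * se
    status = "PASS" if ci_low <= 0.0 <= ci_high else "FAIL"
    return delta_hat, ci_low, ci_high, status

# Made-up numbers: the calibrator predicts 0.70 but the oracle says about 0.45.
y_hat = np.full(6, 0.70)
y = np.array([0.40, 0.50, 0.45, 0.50, 0.40, 0.45])
delta_hat, lo, hi, status = mean_residual_check(y, y_hat)
print(f"delta_hat={delta_hat:+.3f}, 95% CI=[{lo:+.3f}, {hi:+.3f}] -> {status}")
```

A clearly negative mean residual with a CI that excludes zero means the calibrator systematically over-predicts for that group.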

In [None]:
# Apply calibration to all data
all_data['calibrated_pred'] = calibrator.predict(all_data['judge_score'].values)

# Load probe slice and run transportability audit using CJE canonical interface
from cje.diagnostics import audit_transportability

print("Transportability Audit (CJE Canonical Interface)")
print("="*70)

transport_results = {}
for policy in ['clone', 'parallel_universe_prompt', 'unhelpful']:
    # Load probe directly as dicts - no boilerplate needed!
    probe_path = DATA_DIR / "probe_slice" / f"{policy}_probe.jsonl"
    with open(probe_path) as f:  # close the file handle after reading
        probe = [json.loads(line) for line in f]
    
    # Run audit
    diag = audit_transportability(calibrator, probe, group_label=f"policy:{policy}")
    transport_results[policy] = diag
    print(diag.summary())

# Also compute base residuals for comparison
base_with_oracle['calibrated_pred'] = calibrator.predict(base_with_oracle['judge_score'].values)
base_with_oracle['residual'] = base_with_oracle['oracle_label'] - base_with_oracle['calibrated_pred']
print(f"\nBase policy: mean residual = {base_with_oracle['residual'].mean():.4f} (training set, ~0 by construction)")

In [None]:
# Transportability visualization using canonical CJE interface
from cje.diagnostics import plot_transport_comparison

# Plot comparison of all policies
fig = plot_transport_comparison(transport_results, title="Transportability Audit: Does Calibration Transfer?")
plt.show()

# Also show detailed view for unhelpful (the failing policy)
print("\nDetailed view: unhelpful policy (FAIL)")
fig = transport_results['unhelpful'].plot()
plt.show()

print("\nKey insight: The unhelpful policy shows systematic overestimation (negative residuals)")
print("across ALL score deciles - the judge is fooled regardless of how confident it is.")

## 5. Deep Dive: Why Does 'unhelpful' Fail?

The transportability test showed `unhelpful` has a large negative mean residual - the calibrator **overestimates** its quality.

Let's look at **individual samples** with the biggest residuals to understand what's fooling the calibrator.

In [None]:
# Compute residuals for probe data (for visualization)
probe_data['calibrated_pred'] = calibrator.predict(probe_data['judge_score'].values)
probe_data['residual'] = probe_data['oracle_label'] - probe_data['calibrated_pred']

# Get unhelpful policy data from probe slice
unhelpful_data = probe_data[probe_data['policy'] == 'unhelpful'].copy()
unhelpful_data['abs_residual'] = unhelpful_data['residual'].abs()

if len(unhelpful_data) > 0:
    # Summary stats
    print("UNHELPFUL Policy Residual Analysis (from probe slice)")
    print("="*50)
    print(f"Samples with oracle labels: {len(unhelpful_data)}")
    print(f"Mean residual: {unhelpful_data['residual'].mean():.4f}")
    print(f"Std residual:  {unhelpful_data['residual'].std():.4f}")
    print(f"Min residual:  {unhelpful_data['residual'].min():.4f} (biggest overestimate)")
    print(f"Max residual:  {unhelpful_data['residual'].max():.4f} (biggest underestimate)")
    print(f"\nSamples with |residual| > 0.3: {(unhelpful_data['abs_residual'] > 0.3).sum():,}")
    print(f"Samples with |residual| > 0.5: {(unhelpful_data['abs_residual'] > 0.5).sum():,}")
else:
    print("No probe slice data for unhelpful policy")

In [None]:
# Show top 5 biggest overestimates (where calibrator was most fooled)
if len(unhelpful_data) > 0:
    worst_overestimates = unhelpful_data.nsmallest(5, 'residual')

    print("\n" + "="*80)
    print("TOP 5 BIGGEST OVERESTIMATES (Calibrator was most fooled)")
    print("="*80)

    for i, (_, row) in enumerate(worst_overestimates.iterrows(), 1):
        print(f"\n--- #{i} | Residual: {row['residual']:.3f}")
        print(f"    Judge: {row['judge_score']:.2f} -> Calibrated: {row['calibrated_pred']:.2f} | Oracle: {row['oracle_label']:.2f}")
        print(f"    Prompt: {row['prompt'][:100]}...")
        print(f"    Response: {row['response'][:200]}...")

In [None]:
# Visualize residual distribution by policy
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Combine base (training) and probe (test) data for visualization
with_oracle = pd.concat([base_with_oracle, probe_data], ignore_index=True)

# Left: Histogram of residuals
ax = axes[0]
for policy in ['base', 'clone', 'unhelpful']:
    data = with_oracle[with_oracle['policy'] == policy]['residual']
    if len(data) > 0:
        ax.hist(data, bins=30, alpha=0.5, label=f"{policy} (n={len(data)})", density=True)
ax.axvline(x=0, color='black', linestyle='--', linewidth=2)
ax.set_xlabel('Residual (Y - Ŷ)', fontsize=12)
ax.set_ylabel('Density', fontsize=12)
ax.set_title('Residual Distribution by Policy', fontsize=14)
ax.legend()

# Right: Residual vs Judge Score for unhelpful
ax = axes[1]
if len(unhelpful_data) > 0:
    scatter = ax.scatter(unhelpful_data['judge_score'], unhelpful_data['residual'], 
                         c=unhelpful_data['oracle_label'], cmap='RdYlGn', alpha=0.6, s=30)
    ax.axhline(y=0, color='black', linestyle='--', linewidth=2)
    ax.set_xlabel('Judge Score (S)', fontsize=12)
    ax.set_ylabel('Residual (Y - Ŷ)', fontsize=12)
    ax.set_title('Unhelpful Policy: Residual vs Judge Score', fontsize=14)
    plt.colorbar(scatter, ax=ax, label='Oracle Label')

plt.tight_layout()
plt.show()

print("\nKey insight: High judge scores (S > 0.6) often have negative residuals.")
print("The judge gives high scores to confident nonsense; the oracle sees through it.")

## 6. What Patterns Fool the Judge?

Let's analyze what characteristics are associated with large overestimates (negative residuals).
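
One simple way to look for such patterns is to bucket a candidate feature and compare mean residuals per bucket. The sketch below does this for response length on synthetic data; the length effect is built in by construction, so it only illustrates the mechanics, not a finding about the Arena sample.

```python
import numpy as np
import pandas as pd

rng = np.random.default_rng(0)
n = 400

# Synthetic stand-in for a probe slice: longer responses are over-predicted more.
length = rng.integers(100, 3000, size=n)
residual = -0.0001 * length + rng.normal(0, 0.05, size=n)
df = pd.DataFrame({"response_length": length, "residual": residual})

# Bucket by length quartile and summarize residuals per bucket.
df["length_bin"] = pd.qcut(df["response_length"], q=4, labels=["Q1", "Q2", "Q3", "Q4"])
by_bin = df.groupby("length_bin", observed=True)["residual"].agg(["mean", "count"])
print(by_bin.round(3))
```

If mean residuals trend steadily across buckets, that feature is a candidate covariate for the calibrator or a hint about what the judge is rewarding.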

In [None]:
# Analyze response length vs residual
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Left: Response length vs residual for unhelpful
ax = axes[0]
if len(unhelpful_data) > 0:
    ax.scatter(unhelpful_data['response_length'], unhelpful_data['residual'], alpha=0.3, s=20)
    ax.axhline(y=0, color='red', linestyle='--', linewidth=2)

    # Add trend line
    z = np.polyfit(unhelpful_data['response_length'], unhelpful_data['residual'], 1)
    p = np.poly1d(z)
    x_line = np.linspace(unhelpful_data['response_length'].min(), unhelpful_data['response_length'].max(), 100)
    ax.plot(x_line, p(x_line), 'orange', linewidth=2, label=f'Trend: slope={z[0]:.2e}')

    ax.set_xlabel('Response Length (chars)', fontsize=12)
    ax.set_ylabel('Residual (Y - Y_hat)', fontsize=12)
    ax.set_title('Unhelpful: Does Length Predict Residual?', fontsize=14)
    ax.legend()

# Right: Compare response lengths by policy
ax = axes[1]
all_data.boxplot(column='response_length', by='policy', ax=ax)
ax.set_xlabel('Policy', fontsize=12)
ax.set_ylabel('Response Length (chars)', fontsize=12)
ax.set_title('Response Length by Policy', fontsize=14)
plt.suptitle('')  # Remove auto-title

plt.tight_layout()
plt.show()

# Correlation analysis
if len(unhelpful_data) > 0:
    print("\nCorrelation with Residual (unhelpful policy):")
    print(f"  Response length: {unhelpful_data['response_length'].corr(unhelpful_data['residual']):.3f}")
    print(f"  Judge score:     {unhelpful_data['judge_score'].corr(unhelpful_data['residual']):.3f}")

## 7. The Fix: Use Residuals to Improve Your Metrics

The CJE loop: **Calibrate -> Inspect Residuals -> Improve S -> Recalibrate**

What we learned from the unhelpful policy:
1. The judge is fooled by **confident-sounding nonsense**
2. **High judge scores with low oracle** = reward hacking
3. **Transportability fails** when the response distribution shifts dramatically

### Potential Fixes:
- Add **response length** as a covariate (two-stage calibration)
- Use a **domain-specific judge prompt** that catches deliberate misinformation
- Add **factual verification** as an additional S signal
- **Reject transportability** for adversarial policies and require policy-specific calibration
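
The first fix can be sketched as follows, under one possible reading of "two-stage calibration": a first stage combines S and a length covariate into a single index, and a second stage runs isotonic regression on that index. The data, the linear first stage, and the in-sample evaluation (no cross-fitting) are assumptions for illustration; the CJE library's implementation may differ.

```python
import numpy as np
from sklearn.isotonic import IsotonicRegression
from sklearn.linear_model import LinearRegression

rng = np.random.default_rng(0)
n = 1000

# Synthetic data: the judge rewards length, the oracle does not.
quality = rng.uniform(0, 1, size=n)
log_len = rng.normal(6.0, 0.5, size=n)
S = np.clip(quality + 0.15 * (log_len - 6.0) + rng.normal(0, 0.05, size=n), 0, 1)
Y = np.clip(quality + rng.normal(0, 0.05, size=n), 0, 1)

# Baseline: isotonic regression on S alone.
one_stage = IsotonicRegression(out_of_bounds="clip").fit(S, Y)
mse_one = np.mean((Y - one_stage.predict(S)) ** 2)

# Stage 1: combine S and length into a single index.
X = np.column_stack([S, log_len])
index = LinearRegression().fit(X, Y).predict(X)

# Stage 2: isotonic regression on the index keeps the final map monotone in it.
two_stage = IsotonicRegression(out_of_bounds="clip").fit(index, Y)
mse_two = np.mean((Y - two_stage.predict(index)) ** 2)

print(f"In-sample MSE, S only: {mse_one:.4f}; S + length: {mse_two:.4f}")
```

Because the length bias is planted in this toy data, the covariate helps by construction. On real data, compare the two variants on held-out oracle labels before adopting the extra stage.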

In [None]:
# Summary: What to do with these insights
print("="*70)
print("SUMMARY: CJE Pipeline Insights")
print("="*70)

print("\nTransportability Results:")
print("  base: PASS (training set, residual ~0 by construction)")
for policy, diag in transport_results.items():
    status_emoji = "✓" if diag.status == "PASS" else "✗"
    print(f"  {policy}: {diag.status} {status_emoji} (δ̂={diag.delta_hat:+.3f})")

print("\nRECOMMENDED ACTIONS:")
print("  1. For PASS policies: Use calibrated estimates directly")
print("  2. For FAIL policies: Report rankings only, not absolute values")
print("  3. Investigate: What makes failing responses fool the judge?")
print("  4. Iterate: Improve S or collect policy-specific oracle labels")

# Show recommended action for failing policies
for policy, diag in transport_results.items():
    if diag.status == "FAIL" and diag.recommended_action:
        print(f"\n  → {policy}: {diag.recommended_action}")

## 8. Temporal Drift: When Calibration Goes Stale

In production, your calibration will **drift over time**:
- User behavior changes
- Model updates
- World events shift topics
- Judge/oracle relationship evolves

**The solution**: Periodically collect fresh (S, Y) pairs and run residual checks.

**Alert if**: Mean residual is significantly non-zero for 2+ consecutive checks.
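
The alert rule can be written down in a few lines. This is a minimal sketch: the helper name `drift_alerts`, the 95% normal-approximation interval, and the weekly numbers are illustrative assumptions.

```python
import numpy as np

def drift_alerts(mean_residuals, standard_errors, z=1.96, consecutive=2):
    """Flag each check where the CI has excluded 0 for `consecutive` checks in a row."""
    means = np.asarray(mean_residuals, dtype=float)
    ses = np.asarray(standard_errors, dtype=float)
    excludes_zero = (means - z * ses > 0) | (means + z * ses < 0)

    alerts, streak = [], 0
    for flag in excludes_zero:
        # Extend the streak on a significant check, reset it otherwise.
        streak = streak + 1 if flag else 0
        alerts.append(streak >= consecutive)
    return alerts

# Made-up weekly summaries: residuals drift negative from week 4 onward.
means = [0.01, -0.02, 0.00, -0.07, -0.09, -0.12]
ses = [0.03, 0.03, 0.03, 0.03, 0.03, 0.03]
alerts = drift_alerts(means, ses)
print(alerts)
```

Requiring consecutive significant checks trades a one-window delay for fewer false alarms from a single noisy batch.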

In [None]:
# Simulate temporal drift
# We'll pretend we're monitoring the base policy over 8 weeks
# In weeks 1-4, calibration holds. In weeks 5-8, drift begins.

np.random.seed(42)

def simulate_drift_monitoring(base_data, calibrator, n_weeks=8, samples_per_week=50):
    """Simulate weekly monitoring with gradual drift starting at week 5."""
    
    # Get samples with oracle labels
    oracle_samples = base_data[base_data['oracle_label'].notna()].copy()
    
    weekly_results = []
    
    for week in range(1, n_weeks + 1):
        # Sample a batch for this week
        batch = oracle_samples.sample(n=min(samples_per_week, len(oracle_samples)), 
                                       replace=True, random_state=week)
        
        S = batch['judge_score'].values
        Y = batch['oracle_label'].values
        
        # Simulate drift: starting week 5, oracle values degrade
        # (model update made judge scores less predictive)
        if week >= 5:
            drift_amount = 0.03 * (week - 4)  # Growing drift
            # Simulate: high S now corresponds to lower Y than expected
            Y_drifted = Y - drift_amount * (S - 0.5)  # Larger effect for high S
            Y_drifted = np.clip(Y_drifted, 0, 1)
            Y = Y_drifted
        
        # Compute calibrated predictions and residuals
        Y_hat = calibrator.predict(S)
        residuals = Y - Y_hat
        
        weekly_results.append({
            'week': week,
            'mean_residual': np.mean(residuals),
            'std_residual': np.std(residuals, ddof=1),  # sample std (ddof=1)
            'n_samples': len(residuals),
            'se': np.std(residuals, ddof=1) / np.sqrt(len(residuals))
        })
    
    return pd.DataFrame(weekly_results)

# Run simulation
base_only = all_data[all_data['policy'] == 'base']
drift_results = simulate_drift_monitoring(base_only, calibrator)

print("Weekly Residual Monitoring (Simulated)")
print("="*60)
print(drift_results.to_string(index=False))
print("\nNote: Drift injected starting week 5")

In [None]:
# Visualize drift detection
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Left: Time series of mean residual with CI
ax = axes[0]
weeks = drift_results['week']
means = drift_results['mean_residual']
ses = drift_results['se']

ax.errorbar(weeks, means, yerr=1.96*ses, fmt='o-', capsize=5, 
            capthick=2, markersize=8, linewidth=2, color='C0')

# Color points by alert status
alert_threshold = 0.04  # Example threshold
for i, (w, m, se) in enumerate(zip(weeks, means, ses)):
    ci_lower, ci_upper = m - 1.96*se, m + 1.96*se
    if ci_upper < -alert_threshold or ci_lower > alert_threshold:
        color = 'red'  # ALERT
    elif ci_upper < 0 or ci_lower > 0:
        color = 'orange'  # WARNING
    else:
        color = 'green'  # OK
    ax.scatter(w, m, s=150, c=color, zorder=5, edgecolors='black', linewidth=2)

ax.axhline(y=0, color='black', linestyle='--', linewidth=2, label='Perfect calibration')
ax.axhspan(-alert_threshold, alert_threshold, alpha=0.2, color='green', label='Acceptable range')
ax.axvline(x=4.5, color='red', linestyle=':', linewidth=2, alpha=0.7, label='Drift starts')

ax.set_xlabel('Week', fontsize=12)
ax.set_ylabel('Mean Residual (Y - Y_hat)', fontsize=12)
ax.set_title('Drift Monitoring: Weekly Residual Checks', fontsize=14)
ax.legend(loc='lower left')
ax.set_xticks(range(1, 9))

# Right: Decision logic
ax = axes[1]
ax.axis('off')

decision_text = """
DRIFT DETECTION LOGIC
=====================

For each monitoring window:
1. Collect ~50 samples with oracle labels
2. Compute residuals: Y - calibrator(S)
3. Test H0: mean(residual) = 0

ALERT LEVELS:
  GREEN  = CI includes 0
           Calibration still valid
           
  ORANGE = CI excludes 0 but overlaps
           the +/-0.04 band: monitor closely
           
  RED    = CI lies entirely outside
           the +/-0.04 band: recalibrate!

RULE OF THUMB:
  - 1 orange: Watch
  - 2 consecutive orange: Investigate  
  - 1 red: Recalibrate now
  
ACTIONS:
  1. Collect fresh oracle labels
  2. Refit calibrator on recent data
  3. Document what changed (model/users/world)
"""

ax.text(0.1, 0.95, decision_text, transform=ax.transAxes, fontsize=11,
        verticalalignment='top', fontfamily='monospace',
        bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))

plt.tight_layout()
plt.show()

print("\nIntended pattern (check the plot; actual colors depend on the sampled batches):")
print("  Weeks 1-4: Calibration holds (green)")
print("  Week 5-6:  Drift begins, warnings appear (orange)")
print("  Week 7-8:  Significant drift, alerts triggered (red)")

In [None]:
# Demonstrate recalibration after drift detection

def recalibrate_with_recent_data(old_calibrator, recent_S, recent_Y, blend_weight=0.5):
    """
    Refit the calibrator on recent data only (full recalibration).
    
    `old_calibrator` and `blend_weight` are unused in this demo. They are
    placeholders for alternatives you might use in practice:
    - Blend old and new (for gradual adaptation)
    - Use exponential decay weighting by time
    """
    # Full recalibration on recent data only (appropriate if drift is severe)
    new_calibrator = IsotonicRegression(out_of_bounds='clip')
    new_calibrator.fit(recent_S, recent_Y)
    
    return new_calibrator

# Simulate: we detected drift at week 6, collect 100 fresh labels
np.random.seed(123)
oracle_samples = base_only[base_only['oracle_label'].notna()]
recent_batch = oracle_samples.sample(n=100, replace=True)

# Apply simulated drift to this batch (as if collected during drift period)
S_recent = recent_batch['judge_score'].values
Y_recent = recent_batch['oracle_label'].values
drift_amount = 0.06  # Week 6 drift
Y_recent_drifted = Y_recent - drift_amount * (S_recent - 0.5)
Y_recent_drifted = np.clip(Y_recent_drifted, 0, 1)

# Recalibrate
new_calibrator = recalibrate_with_recent_data(calibrator, S_recent, Y_recent_drifted)

# Compare old vs new calibration curves
fig, ax = plt.subplots(figsize=(10, 6))

S_grid = np.linspace(0, 1, 100)
Y_old = calibrator.predict(S_grid)
Y_new = new_calibrator.predict(S_grid)

ax.plot(S_grid, Y_old, 'b-', linewidth=2, label='Old calibration (pre-drift)')
ax.plot(S_grid, Y_new, 'r-', linewidth=2, label='New calibration (post-drift)')
ax.plot([0, 1], [0, 1], 'k--', alpha=0.3, label='y=x')

ax.scatter(S_recent, Y_recent_drifted, alpha=0.3, s=20, c='red', label='Recent samples (drifted)')

ax.set_xlabel('Judge Score (S)', fontsize=12)
ax.set_ylabel('Calibrated Prediction', fontsize=12)
ax.set_title('Recalibration After Drift Detection', fontsize=14)
ax.legend()
ax.set_xlim(0, 1)
ax.set_ylim(0, 1)

plt.tight_layout()
plt.show()

print("After recalibration:")
print("  - New calibrator learned the drifted S->Y relationship")
print("  - High S values now map to lower Y (reflecting new reality)")
print("  - Residual monitoring should return to green")

## 9. Using the CJE Library

Most of the above was built from scratch (only the transportability audit in Section 4 used `cje.diagnostics`). For production, use `cje-eval` which provides:

- **AutoCal-R**: Isotonic/two-stage calibration with proper cross-fitting
- **OUA inference**: Confidence intervals that account for calibration uncertainty
- **Statistical tests**: Proper paired comparisons between policies
- **Forest plots**: Publication-ready visualizations
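
To illustrate what cross-fitting means, the sketch below produces out-of-fold isotonic predictions with scikit-learn's `KFold`. It is a generic illustration on synthetic data, not the library's implementation; the helper name `cross_fitted_predictions` is hypothetical.

```python
import numpy as np
from sklearn.isotonic import IsotonicRegression
from sklearn.model_selection import KFold

def cross_fitted_predictions(S, Y, n_splits=5, seed=0):
    """Out-of-fold isotonic predictions: each point is scored by a model that never saw it."""
    S, Y = np.asarray(S, dtype=float), np.asarray(Y, dtype=float)
    oof = np.empty_like(Y)
    folds = KFold(n_splits=n_splits, shuffle=True, random_state=seed)
    for train_idx, test_idx in folds.split(S):
        model = IsotonicRegression(out_of_bounds="clip").fit(S[train_idx], Y[train_idx])
        oof[test_idx] = model.predict(S[test_idx])
    return oof

rng = np.random.default_rng(0)
S = rng.uniform(0, 1, size=300)
Y = np.clip(S + rng.normal(0, 0.1, size=300), 0, 1)
oof = cross_fitted_predictions(S, Y)
print(f"Out-of-fold MSE: {np.mean((Y - oof) ** 2):.4f}")
```

Out-of-fold predictions avoid scoring a labeled sample with a calibrator that was fit on that same sample, which would make residuals look smaller than they are.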

In [None]:
# Install CJE (uncomment for Colab)
# !pip install cje-eval

from cje import analyze_dataset

# Run CJE analysis - one call covers the calibration and estimation steps we did manually above
results = analyze_dataset(
    fresh_draws_dir=str(DATA_DIR / "fresh_draws"),
    estimator="auto",
    verbose=False,
)

# Compare our manual estimates vs CJE library
print("Comparison: Manual vs CJE Library")
print("="*60)
print(f"{'Policy':<30} {'Manual':<12} {'CJE Library':<12}")
print("-"*60)

cje_estimates = dict(zip(results.metadata['target_policies'], results.estimates))
for policy in POLICIES:
    # Get our manual calibrated mean
    policy_data = all_data[all_data['policy'] == policy]
    manual_est = policy_data['calibrated_pred'].mean()
    cje_est = cje_estimates.get(policy, float('nan'))
    print(f"{policy:<30} {manual_est:.3f}       {cje_est:.3f}")

print("\nThe CJE library handles cross-fitting, uncertainty, and edge cases.")

In [None]:
# Forest plot with confidence intervals
fig = results.plot_estimates(figsize=(10, 6))
plt.title("Policy Comparison: CJE Calibrated Estimates with 95% CI", fontsize=14)
plt.tight_layout()
plt.show()

print("\nThe forest plot shows estimates with confidence intervals.")
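print("Non-overlapping CIs indicate a significant difference; overlapping CIs are inconclusive,")
print("so rely on the paired comparisons below rather than eyeballing overlap.")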

In [None]:
# Statistical comparisons between policies
target_policies = results.metadata['target_policies']
best_idx = results.best_policy()
best_policy = target_policies[best_idx]

print(f"Best policy: {best_policy}")
print("="*65)

# Compare all to base
if 'base' in target_policies:
    base_idx = target_policies.index('base')
    print(f"\nComparisons vs base:\n")
    print(f"{'Policy':<30} {'Difference':<12} {'p-value':<10} {'Significant?'}")
    print("-"*65)
    
    for i, policy in enumerate(target_policies):
        if i == base_idx:
            print(f"{policy:<30} {'(baseline)':<12}")
            continue
        comp = results.compare_policies(i, base_idx)
        sig = "Yes" if comp['significant'] else "No"
        print(f"{policy:<30} {comp['difference']:+.3f}       {comp['p_value']:.3f}      {sig}")

## Summary

This notebook demonstrated the core CJE workflow:

1. **Load data** with S (cheap) and Y (expensive) labels
2. **Fit calibration** on an oracle slice (base policy)
3. **Compute residuals** to check transportability
4. **Identify failures** (e.g., policies where calibration doesn't transfer)
5. **Analyze patterns** (what fools the judge?)
6. **Detect drift** over time with residual monitoring
7. **Recalibrate** when drift is detected

### Key insight:
**Ship value, not memes.** Calibrate your cheap metrics against real outcomes, catch inversions early, and stop shipping behavior that scores high on vibes but tanks on outcomes.

### Learn more:
- [CJE Advanced](cje_advanced.ipynb) - IPS and DR modes for off-policy evaluation
- [GitHub](https://github.com/cimo-labs/cje) - Full documentation
- [Blog: Your AI Metrics Are Lying to You](https://www.cimolabs.com/blog/your-ai-metrics-are-lying-to-you) - The full story