# CJE (Causal Judge Evaluation) Demo

**Interactive demo using Arena 5K sample data (1000 prompts)**

This notebook demonstrates how to use CJE to evaluate LLM policies using judge scores and oracle labels. We'll walk through:

1. **Setup**: Install CJE and download sample data
2. **Inspect Data**: Understand the dataset structure
3. **Direct Mode**: Simple policy comparison (no logprobs needed!)
4. **IPS Mode**: Counterfactual estimates from logged data
5. **DR Mode**: Maximum accuracy with both logged + fresh data
6. **Policy Selection**: Statistical comparison with confidence intervals
7. **Visualization**: Forest plots of estimates with confidence intervals

---

## What is CJE?

CJE turns LLM-as-judge scores into causally interpretable estimates. Instead of naively averaging judge scores, CJE:
- **Calibrates** judge scores to match oracle labels (AutoCal-R)
- **Stabilizes** importance weights for off-policy evaluation (SIMCal)
- **Reports** confidence intervals that account for sampling, calibration, and oracle-label uncertainty

**Key insight**: Judge scores are correlational. CJE makes them causal.

**Three modes, increasing complexity:**
- **Direct**: Compare policies on eval set (simplest - no logprobs!)
- **IPS**: Counterfactual estimates from logged data (reuse existing logs)
- **DR**: Best of both worlds (doubly robust - most accurate)

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/cimo-labs/cje/blob/main/examples/cje_tutorial.ipynb)

## Step 1: Setup

Install CJE and download the Arena sample data.

**Note:** Colab comes with numpy 2.3+ pre-installed, which breaks scipy. We force-reinstall numpy 2.0.x first.

In [None]:
# Colab comes with numpy 2.3+ which breaks scipy
# Force install compatible numpy first
!pip install -q 'numpy>=2.0,<2.1' --force-reinstall

# Install latest CJE from PyPI with cache busting
!pip install --no-cache-dir --upgrade cje-eval

# Verify installation
import cje
print(f"✓ CJE version {cje.__version__} installed")

# Check numpy version
import numpy as np
print(f"✓ NumPy version {np.__version__}")

In [None]:
# Download Arena sample data from GitHub
import urllib.request
from pathlib import Path

# Create data directory
DATA_DIR = Path("arena_sample")
DATA_DIR.mkdir(exist_ok=True)
(DATA_DIR / "fresh_draws").mkdir(exist_ok=True)

# Base URL for raw files on GitHub
BASE_URL = "https://raw.githubusercontent.com/cimo-labs/cje/main/examples/arena_sample"

# Download logged data
print("Downloading logged_data.jsonl...")
urllib.request.urlretrieve(
    f"{BASE_URL}/logged_data.jsonl",
    DATA_DIR / "logged_data.jsonl"
)
print("✓ Downloaded logged_data.jsonl")

# Download fresh draws (note: actual filenames have _responses.jsonl suffix)
fresh_draw_files = {
    "clone": "clone_responses.jsonl",
    "parallel_universe_prompt": "parallel_universe_prompt_responses.jsonl",
    "unhelpful": "unhelpful_responses.jsonl"
}

for policy, filename in fresh_draw_files.items():
    print(f"Downloading fresh_draws/{filename}...")
    urllib.request.urlretrieve(
        f"{BASE_URL}/fresh_draws/{filename}",
        DATA_DIR / "fresh_draws" / filename
    )
    print(f"✓ Downloaded {filename}")

print("\n✓ All data downloaded successfully!")
print(f"\nData location: {DATA_DIR.absolute()}")
print(f"Policies available: {list(fresh_draw_files.keys())}")

## Step 2: Inspect the Data

Let's look at what we're working with.

In [None]:
import json

# Load and inspect logged data
with open(DATA_DIR / "logged_data.jsonl") as f:
    logged_samples = [json.loads(line) for line in f]

print(f"Logged data: {len(logged_samples)} samples")
print(f"\nTarget policies: {list(logged_samples[0]['target_policy_logprobs'].keys())}")
print(f"\nExample sample:")
sample = logged_samples[0]
print(f"  Prompt: {sample['prompt'][:100]}...")
print(f"  Response: {sample['response'][:150]}...")
print(f"  Judge score: {sample['judge_score']}")
oracle_label = sample.get('oracle_label', 'N/A')
print(f"  Oracle label: {oracle_label}")
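has_logprobs = bool(sample.get('target_policy_logprobs'))
print(f"  Has target-policy logprobs: {'✓' if has_logprobs else '✗'}")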

# Check oracle coverage
n_with_oracle = sum(1 for s in logged_samples if s.get('oracle_label') is not None)
coverage = n_with_oracle / len(logged_samples)
print(f"\nOracle label coverage: {n_with_oracle}/{len(logged_samples)} ({coverage:.1%})")
print(f"→ AutoCal-R will use these {n_with_oracle} oracle labels to calibrate judge scores")

## Step 3: Mode 1 - Direct (On-Policy Evaluation)

**Use case**: Compare policies on a specific evaluation set (simplest!)

**Estimand**: "Which policy performs best on *this* prompt set?"

**How it works**:
1. Generate fresh responses from each policy on same prompts
2. Calibrate judge scores using oracle labels (AutoCal-R)
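3. Average: V̂(π) = (1/m) Σᵢ R_{π,i}, where R_{π,i} is the calibrated reward of policy π's response to prompt i and m is the number of prompts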

**Key advantage**: No logprobs needed! Simplest mode for quick policy comparisons.
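
The three steps above can be sketched in plain NumPy on made-up numbers. This sketch assumes the judge-to-oracle map is monotone and fits it with pool-adjacent-violators (isotonic regression); it illustrates the idea behind calibrated Direct estimates and is not CJE's AutoCal-R implementation. `fit_isotonic`, `calibrate`, and the policy names are hypothetical.

```python
import numpy as np


def fit_isotonic(x, y):
    """Fit a non-decreasing step function y ≈ f(x) by pool-adjacent-violators (PAV)."""
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    # Pool tied judge scores first: one block per unique x, weighted by its count
    ux, inverse, counts = np.unique(x, return_inverse=True, return_counts=True)
    sums = np.bincount(inverse, weights=y)
    blocks = []  # each block: [sum of y, number of samples, number of unique x values]
    for s, c in zip(sums, counts):
        blocks.append([s, c, 1])
        # Merge with the previous block while its mean exceeds the current block's mean
        while len(blocks) > 1 and blocks[-2][0] / blocks[-2][1] > blocks[-1][0] / blocks[-1][1]:
            s2, c2, k2 = blocks.pop()
            blocks[-1][0] += s2
            blocks[-1][1] += c2
            blocks[-1][2] += k2
    # One fitted value per unique x: the mean of the block it ended up in
    fitted = np.repeat([b[0] / b[1] for b in blocks], [b[2] for b in blocks])

    def predict(x_new):
        # Step-function lookup; scores outside the training range get the end values
        idx = np.searchsorted(ux, np.asarray(x_new, dtype=float), side="right") - 1
        return fitted[np.clip(idx, 0, len(ux) - 1)]

    return predict


# Made-up oracle slice: judge scores with their oracle labels
judge_oracle = np.array([0.2, 0.4, 0.5, 0.7, 0.9])
oracle_labels = np.array([0.1, 0.5, 0.3, 0.8, 0.9])
calibrate = fit_isotonic(judge_oracle, oracle_labels)

# Made-up fresh draws: judge scores for each policy on the same m = 4 prompts
fresh_scores = {
    "policy_a": np.array([0.5, 0.7, 0.9, 0.4]),
    "policy_b": np.array([0.2, 0.4, 0.5, 0.2]),
}

# V̂(π) = (1/m) Σ_i R_{π,i}, where R is the calibrated judge score
direct_estimates = {p: float(calibrate(s).mean()) for p, s in fresh_scores.items()}
print(direct_estimates)
```

Note how the oracle labels 0.5 and 0.3 violate monotonicity and get pooled to 0.4; the handful of oracle labels sets the reward scale for every fresh draw.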

In [None]:
from cje import analyze_dataset

# Direct mode: Fresh draws only (learns calibration from oracle labels in fresh draws)
results_direct = analyze_dataset(
    fresh_draws_dir=str(DATA_DIR / "fresh_draws"),  # No logged data!
    estimator="auto",  # Auto-detects Direct mode
    verbose=True,
)

print("\n" + "="*70)
print("Direct Mode Results")
print("="*70)
print(f"Mode: {results_direct.metadata['mode']}")
print(f"Estimator: {results_direct.metadata['estimator']}")
print(f"Calibration: {results_direct.metadata.get('calibration', 'none')}")
print(f"Oracle coverage: {results_direct.metadata.get('oracle_coverage', 0):.1%}")
print()

# Show estimates with confidence intervals
policies = results_direct.metadata['target_policies']
print(f"{'Policy':<30} {'Estimate':<12} {'Std Error':<12} {'95% CI':<20}")
print("-" * 74)
for i, policy in enumerate(policies):
    est = results_direct.estimates[i]
    se = results_direct.standard_errors[i]
    ci_low = est - 1.96 * se
    ci_high = est + 1.96 * se
    print(f"{policy:<30} {est:>6.3f}       {se:>6.3f}       [{ci_low:.3f}, {ci_high:.3f}]")

## Step 4: Mode 2 - IPS (Importance Sampling)

**Use case**: Reuse logged data to estimate counterfactual performance

**Estimand**: "What would the KPI be if we deployed policy π' instead of π₀?"

**How it works**:
1. Calibrate judge scores → oracle-scale rewards (AutoCal-R)
2. Compute importance weights: W = π'(a|x) / π₀(a|x)
3. Stabilize weights with monotone projection (SIMCal)
4. Estimate: V̂(π') = (1/n) Σ W_i · R_i
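
A minimal NumPy sketch of steps 2 and 4 on made-up log-probabilities. The weight is computed in log space as W = exp(log π'(a|x) − log π₀(a|x)). In place of SIMCal, the sketch only rescales the weights to mean one (self-normalized IPS); it does not perform SIMCal's monotone projection. `ips_estimate` is a hypothetical helper, not a CJE function.

```python
import numpy as np


def ips_estimate(base_logprobs, target_logprobs, rewards, self_normalize=True):
    """Importance-sampling estimate of V(π') from responses logged under π₀.

    base_logprobs[i]   = log π₀(a_i | x_i) of the logged response
    target_logprobs[i] = log π'(a_i | x_i) of the same response under the target policy
    rewards[i]         = calibrated reward R_i
    """
    log_ratio = np.asarray(target_logprobs, float) - np.asarray(base_logprobs, float)
    weights = np.exp(log_ratio)  # W_i = π'(a|x) / π₀(a|x)
    if self_normalize:
        # Rescale to mean one (a simple stabilizer, NOT SIMCal's monotone projection)
        weights = weights / weights.mean()
    estimate = float(np.mean(weights * np.asarray(rewards, float)))
    return estimate, weights


base_lp = np.array([-10.0, -12.0, -9.0, -11.0])
target_lp = np.array([-10.0, -11.0, -10.0, -13.0])
rewards = np.array([0.6, 0.8, 0.3, 0.5])

v_hat, w = ips_estimate(base_lp, target_lp, rewards)
print(f"IPS estimate: {v_hat:.3f}, weights: {np.round(w, 3)}")
```

With real sequence-level log-probabilities, log ratios can be large, so a few responses can dominate the raw weights; that is what weight stabilization and the ESS diagnostic are for.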

In [None]:
# IPS mode: Logged data only (auto-selects calibrated-ips estimator)
results_ips = analyze_dataset(
    logged_data_path=str(DATA_DIR / "logged_data.jsonl"),
    estimator="auto",  # Auto-detects IPS mode
    verbose=True,
)

print("\n" + "="*70)
print("IPS Results")
print("="*70)
print(f"Mode: {results_ips.metadata['mode']}")
print(f"Estimator: {results_ips.metadata['estimator']}")
print(f"Logprob coverage: {results_ips.metadata['mode_selection']['logprob_coverage']:.1%}")
print()

# Show estimates with confidence intervals.
# Iterate over the IPS result's own policy list: it is not guaranteed to match Direct mode's.
ips_policies = results_ips.metadata['target_policies']
print(f"{'Policy':<30} {'Estimate':<12} {'Std Error':<12} {'95% CI':<20}")
print("-" * 74)
for i, policy in enumerate(ips_policies):
    est = results_ips.estimates[i]
    se = results_ips.standard_errors[i]
    ci_low = est - 1.96 * se
    ci_high = est + 1.96 * se
    print(f"{policy:<30} {est:>6.3f}       {se:>6.3f}       [{ci_low:.3f}, {ci_high:.3f}]")

### Check IPS Diagnostics

**ESS (Effective Sample Size)** is the key diagnostic for IPS reliability:
- ESS ≥ 50%: Excellent overlap
- ESS ∈ [10%, 50%): Moderate (DR recommended)
- ESS < 10%: Poor overlap (switch to DR or regenerate)
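
ESS is commonly computed as the Kish effective sample size of the weights, (Σw)² / Σw², divided by n to give a fraction. A small sketch with a hypothetical `ess_fraction` helper; CJE's diagnostic may differ in detail (for example, in which weights it is computed on).

```python
import numpy as np


def ess_fraction(weights):
    """Kish effective sample size as a fraction of n: (Σw)² / (n · Σw²)."""
    w = np.asarray(weights, dtype=float)
    return float(w.sum() ** 2 / (len(w) * np.sum(w ** 2)))


print(ess_fraction([1.0, 1.0, 1.0, 1.0]))  # → 1.0  (equal weights: every sample counts fully)
print(ess_fraction([2.0, 2.0, 0.0, 0.0]))  # → 0.5  (half the samples carry all the weight)
print(ess_fraction([4.0, 0.0, 0.0, 0.0]))  # → 0.25 (a single sample carries all the weight)
```

The ratio is unchanged by rescaling the weights, so it measures how concentrated they are, not their overall size.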

In [None]:
# Check diagnostics for each policy
print("IPS Diagnostics")
print("="*70)

for policy in results_ips.metadata['target_policies']:
    ess = results_ips.diagnostics.ess_per_policy.get(policy, 0.0)
    
    # Traffic light assessment
    if ess >= 0.5:
        status = "✓ EXCELLENT"
    elif ess >= 0.1:
        status = "⚠ MODERATE (consider DR)"
    else:
        status = "✗ POOR (use DR or regenerate)"
    
    print(f"\n{policy}:")
    print(f"  ESS: {ess:.1%} {status}")
    
    # Show weight statistics if available
    max_weight = results_ips.diagnostics.max_weight_per_policy.get(policy)
    if max_weight:
        print(f"  Max weight: {max_weight:.2f}")

## Step 5: Mode 3 - DR (Doubly Robust)

**Use case**: Most accurate counterfactual estimates when you have fresh draws

**Estimand**: Same as IPS, but with better accuracy

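**How it works**:
1. Everything from IPS mode (calibrated rewards R_i and stabilized weights W_i on the logged data)
2. Use an outcome model ĝ(S), which maps a judge score to a calibrated reward, to score the fresh draws from π'
3. Combine: V̂_DR(π') = (1/n) Σᵢ [ĝ(S'_i) + W_i · (R_i - ĝ(S_i))], where S_i is the judge score of the logged response to prompt i and S'_i is the judge score of π''s fresh response to the same prompt

**Double robustness**: The estimate is consistent if *either* the weights *or* the outcome model is correct.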
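
A NumPy sketch of the standard DR form, (1/n) Σᵢ [ĝ(S'ᵢ) + Wᵢ · (Rᵢ − ĝ(Sᵢ))], in which the outcome model is evaluated on the target policy's fresh draw (S'ᵢ) for the direct term and on the logged response (Sᵢ) for the weighted residual. All arrays are made up, and `g_hat` is a hypothetical fixed outcome model standing in for a learned one; this is not CJE's stacked-DR estimator.

```python
import numpy as np


def g_hat(s):
    # Hypothetical outcome model: a fixed monotone map from judge score to reward.
    # In practice this would be learned rather than hard-coded.
    return 0.9 * np.asarray(s, dtype=float) + 0.05


def dr_estimate(weights, rewards, s_logged, s_fresh, outcome_model):
    """Doubly robust estimate: (1/n) Σ_i [g(S'_i) + W_i · (R_i - g(S_i))]."""
    w = np.asarray(weights, dtype=float)
    r = np.asarray(rewards, dtype=float)
    direct_term = outcome_model(s_fresh)            # prediction on π' fresh draws
    correction = w * (r - outcome_model(s_logged))  # weighted residual on logged data
    return float(np.mean(direct_term + correction))


weights = np.array([1.0, 0.5, 1.5, 1.0])   # stabilized, mean-one weights
rewards = np.array([0.7, 0.4, 0.9, 0.5])   # calibrated rewards of logged responses
s_logged = np.array([0.6, 0.5, 0.8, 0.5])  # judge scores of logged responses
s_fresh = np.array([0.7, 0.6, 0.9, 0.4])   # judge scores of π' fresh draws, same prompts

print(f"DR estimate: {dr_estimate(weights, rewards, s_logged, s_fresh, g_hat):.3f}")
```

If the outcome model fits the logged rewards exactly, the correction term vanishes and DR reduces to the outcome-model average over fresh draws; if the weights are right, the correction removes the outcome model's bias in expectation. That is the "either/or" behind double robustness.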

In [None]:
# DR mode: Logged data + fresh draws (auto-selects stacked-dr estimator)
results_dr = analyze_dataset(
    logged_data_path=str(DATA_DIR / "logged_data.jsonl"),
    fresh_draws_dir=str(DATA_DIR / "fresh_draws"),
    estimator="auto",  # Auto-detects DR mode
    estimator_config={"parallel": False},  # Disable parallel for Colab
    verbose=True,
)

print("\n" + "="*70)
print("DR Results")
print("="*70)
print(f"Mode: {results_dr.metadata['mode']}")
print(f"Estimator: {results_dr.metadata['estimator']}")
print()

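# Compare all three modes side-by-side.
# Look each policy up by name: every result carries its own target_policies list,
# and the order (or set) of policies is not guaranteed to match across modes.
def estimate_and_se(result, policy):
    names = result.metadata['target_policies']
    if policy not in names:
        return float('nan'), float('nan')
    j = names.index(policy)
    return result.estimates[j], result.standard_errors[j]

dr_policies = results_dr.metadata['target_policies']
print(f"{'Policy':<30} {'Direct':<10} {'IPS':<10} {'DR':<10}")
print("-" * 60)
for policy in dr_policies:
    direct_est, _ = estimate_and_se(results_direct, policy)
    ips_est, _ = estimate_and_se(results_ips, policy)
    dr_est, _ = estimate_and_se(results_dr, policy)
    print(f"{policy:<30} {direct_est:>6.3f}     {ips_est:>6.3f}     {dr_est:>6.3f}")

print("\n" + "="*70)
print("Standard Error Comparison (IPS vs DR)")
print("="*70)
print(f"{'Policy':<30} {'IPS SE':<12} {'DR SE':<12} {'Improvement':<15}")
print("-" * 69)
for policy in dr_policies:
    _, ips_se = estimate_and_se(results_ips, policy)
    _, dr_se = estimate_and_se(results_dr, policy)

    # Report the relative SE reduction when DR is tighter than IPS
    if dr_se < ips_se:
        improvement = f"↓ {(1 - dr_se/ips_se)*100:.0f}% SE"
    else:
        improvement = "(no reduction)"

    print(f"{policy:<30} {ips_se:>6.3f}       {dr_se:>6.3f}       {improvement:<15}")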

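## Step 6: Policy Selection

Use the DR estimates to pick the best policy, then test each policy's difference from a baseline (`clone`) for statistical significance.

In [None]: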
import numpy as np

# Use DR results (most accurate)
policies = results_dr.metadata['target_policies']
estimates = np.array(results_dr.estimates)
std_errors = np.array(results_dr.standard_errors)

# Find best policy using built-in method
best_idx = results_dr.best_policy()
best_policy = policies[best_idx]
best_est = estimates[best_idx]
best_se = std_errors[best_idx]

print("Policy Selection (DR Mode)")
print("="*70)
print(f"\n🏆 Best policy: {best_policy}")
print(f"   Estimate: {best_est:.3f} ± {best_se:.3f}")
print(f"   95% CI: [{best_est - 1.96*best_se:.3f}, {best_est + 1.96*best_se:.3f}]")

# Compare against baseline (clone) using built-in comparison
baseline_policy = 'clone'
baseline_idx = policies.index(baseline_policy)

print(f"\n📊 Comparison to baseline ({baseline_policy}):")
print(f"   Baseline: {estimates[baseline_idx]:.3f} ± {std_errors[baseline_idx]:.3f}")
print()
print(f"{'Policy':<30} {'Delta':>10} {'SE(Δ)':>10} {'Method':<10} {'p-value':>10} {'Significant?':<15}")
print("-" * 90)

for i, policy in enumerate(policies):
    if i == baseline_idx:
        print(f"{policy:<30} {'(baseline)':>10}")
        continue
    
    # Use built-in comparison (uses influence functions for paired variance)
    comp = results_dr.compare_policies(i, baseline_idx)
    
    sig_text = "✓ Yes (p<0.05)" if comp['significant'] else "No"
    method_text = "(paired)" if comp['used_influence'] else "(indep.)"
    
    print(f"{policy:<30} {comp['difference']:>+10.3f} {comp['se_difference']:>10.3f} "
          f"{method_text:<10} {comp['p_value']:>10.3f} {sig_text:<15}")

print("\n💡 Note: Using influence functions for paired comparisons when available.")
print("   This properly accounts for correlation between estimates on the same prompts.")
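
The paired-versus-independent distinction can be sketched with per-prompt influence values. When both policies are scored on the same prompts, the SE of the difference should come from the per-prompt difference φ_a − φ_b, which includes their covariance; treating the estimates as independent overstates the SE when they are positively correlated. `compare_paired` and the influence values below are hypothetical and made up; `compare_policies` computes its own.

```python
import math

import numpy as np


def compare_paired(est_a, est_b, phi_a, phi_b):
    """Two-sided z-test on est_a - est_b from per-prompt influence values on the SAME prompts."""
    phi_a = np.asarray(phi_a, dtype=float)
    phi_b = np.asarray(phi_b, dtype=float)
    n = len(phi_a)
    diff = est_a - est_b
    # Paired SE: spread of the per-prompt difference (includes the covariance term)
    se_paired = float(np.std(phi_a - phi_b, ddof=1) / math.sqrt(n))
    # Independent SE: pretends the two estimates are uncorrelated
    se_indep = float(math.sqrt((np.var(phi_a, ddof=1) + np.var(phi_b, ddof=1)) / n))
    z = diff / se_paired
    p_value = math.erfc(abs(z) / math.sqrt(2))  # two-sided normal p-value
    return {"difference": diff, "se_paired": se_paired, "se_indep": se_indep, "p_value": p_value}


# Made-up influence values: a shared per-prompt difficulty term correlates the two policies
rng = np.random.default_rng(0)
difficulty = rng.normal(0.0, 1.0, size=500)
phi_a = difficulty + rng.normal(0.0, 0.2, size=500)
phi_b = difficulty + rng.normal(0.0, 0.2, size=500)

result = compare_paired(0.62, 0.55, phi_a, phi_b)
print({k: round(v, 4) for k, v in result.items()})
```

Here the shared difficulty term cancels in φ_a − φ_b, so the paired SE is far smaller than the independent one, and the same 0.07 difference becomes much easier to detect.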

### Jupyter Auto-Display

In Jupyter notebooks (including Colab), results automatically display as formatted HTML tables when you evaluate them:

In [None]:
# Evaluate the result as the last expression in a cell to see an HTML table
results_dr

## Step 7: Visualization

CJE provides built-in visualization functions to help understand and communicate results. You can import them directly from the main `cje` namespace:

**Available plot functions:**
- `plot_policy_estimates` - Forest plots with confidence intervals
- `plot_calibration_comparison` - Judge → oracle calibration curves
- `plot_weight_dashboard_summary` - Weight diagnostics (ESS, concentration)
- `plot_weight_dashboard_detailed` - Per-policy weight analysis
- `plot_dr_dashboard` - Doubly robust diagnostics
- `plot_transport_audit` - Calibrator transportability tests (advanced)
- `plot_transport_comparison` - Multi-policy transport comparison (advanced)

Below we show a simple example with `plot_policy_estimates`.

In [None]:
# Import matplotlib and the forest-plot function from cje
import matplotlib.pyplot as plt
from cje import plot_policy_estimates

# Extract DR estimates and standard errors as dictionaries keyed by policy name
policies = results_dr.metadata['target_policies']
estimates_dict = {policy: float(results_dr.estimates[i]) for i, policy in enumerate(policies)}
ses_dict = {policy: float(results_dr.standard_errors[i]) for i, policy in enumerate(policies)}

# Create forest plot with confidence intervals
fig = plot_policy_estimates(
    estimates=estimates_dict,
    standard_errors=ses_dict,
    figsize=(10, 5)
)

# Display
plt.tight_layout()
plt.show()

print("\n✓ Forest plot shows point estimates with 95% confidence intervals")
print("  Green dot = best policy, Gray square would be baseline if provided")

### Quick Plotting with Convenience Method

You can also use the convenience method on `EstimationResult` for quick plotting:

In [None]:
# Import matplotlib for plotting
import matplotlib.pyplot as plt

# Use the convenience method
fig = results_dr.plot_estimates()
plt.show()

print("\n💡 The .plot_estimates() method automatically extracts data from the result object")

## Summary: When to Use Each Mode

| Mode | Data Required | Estimand | Use When |
|------|--------------|----------|----------|
| **Direct** | Fresh draws (+ optional oracle labels) | Performance on eval set | Just want on-policy comparison, simplest! |
| **IPS** | Logged data + logprobs | Counterfactual deployment value | Have logged data, want off-policy estimates |
| **DR** | Logged data + fresh draws | Counterfactual (most accurate) | Have both logged and fresh data |
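
The table's decision rule can be written as a small function. `choose_mode` is a hypothetical helper that mirrors the table for illustration (IPS and DR both need logprobs to form importance weights); it is not how `estimator="auto"` is implemented.

```python
def choose_mode(has_logged_data: bool, has_logprobs: bool, has_fresh_draws: bool) -> str:
    """Pick an evaluation mode from the data available, following the summary table."""
    can_weight = has_logged_data and has_logprobs  # importance weights need logprobs
    if can_weight and has_fresh_draws:
        return "dr"
    if can_weight:
        return "ips"
    if has_fresh_draws:
        return "direct"
    raise ValueError("Need fresh draws, or logged data with logprobs")


print(choose_mode(has_logged_data=True, has_logprobs=True, has_fresh_draws=True))
```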

### Key Diagnostics to Check

- **Direct mode**: Oracle coverage (5-10% is often enough for calibration)
- **IPS mode**: ESS ≥ 10% (prefer ≥ 50%)
- **DR mode**: Standard errors typically lower than IPS

### Pro Tips

1. **Start with Direct mode**: Simplest, no logprobs needed, great for quick comparisons
2. **Use `estimator="auto"`**: CJE selects the best mode and estimator automatically
3. **Check diagnostics first**: Don't trust estimates if ESS < 10%
4. **Report confidence intervals**: CJE's CIs account for all uncertainty sources (oracle, calibration, sampling)
5. **DR > IPS when possible**: Doubly robust is more accurate and robust

---

## Next Steps

- **Documentation**: [CJE README](https://github.com/cimo-labs/cje)
- **Technical Details**: [Arena Experiment - Full Benchmarking](https://www.cimolabs.com/blog/arena-experiment)
- **Install locally**: `pip install cje-eval`
- **More examples**: See [examples/](https://github.com/cimo-labs/cje/tree/main/examples)

Questions? Open an issue on [GitHub](https://github.com/cimo-labs/cje/issues)!