# CJE (Causal Judge Evaluation) Demo

**Interactive demo using Arena 10K sample data**

This notebook demonstrates how to use CJE to evaluate LLM policies using judge scores and oracle labels. We'll walk through:

1. **Setup**: Install CJE and download sample data
2. **Understanding Modes**: IPS, DR, and Direct evaluation
3. **Policy Comparison**: Find the best policy with confidence intervals
4. **Diagnostics**: Check reliability with ESS and other metrics

---

## What is CJE?

CJE turns LLM-as-judge scores into causally interpretable estimates. Instead of naively averaging judge scores, CJE:
- **Calibrates** judge scores to match oracle labels (AutoCal-R)
- **Stabilizes** importance weights for off-policy evaluation (SIMCal)
- **Reports** confidence intervals that account for all sources of uncertainty

**Key insight**: Judge scores are correlational. CJE makes them causal.
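To make the calibration step concrete, here is a minimal sketch. It learns a non-decreasing map from judge score to oracle label with plain isotonic regression (pool-adjacent-violators), then applies that map to unlabeled scores. It is a simplified stand-in, not CJE's AutoCal-R implementation. `fit_isotonic`, `calibrate`, and the toy numbers are hypothetical.

```python
import numpy as np


def fit_isotonic(scores, labels):
    """Fit a non-decreasing map score -> label by pool-adjacent-violators.

    Returns the sorted unique scores and the fitted level at each one.
    """
    scores = np.asarray(scores, dtype=float)
    labels = np.asarray(labels, dtype=float)
    xs, inverse = np.unique(scores, return_inverse=True)
    # Pool tied scores first so identical scores always get identical levels.
    sums = np.bincount(inverse, weights=labels)
    counts = np.bincount(inverse).astype(float)

    block_sums, block_counts, block_sizes = [], [], []
    for s, c in zip(sums, counts):
        block_sums.append(s)
        block_counts.append(c)
        block_sizes.append(1)
        # Merge the last two blocks while their means violate monotonicity.
        while (len(block_sums) > 1
               and block_sums[-2] / block_counts[-2] > block_sums[-1] / block_counts[-1]):
            s2, c2, n2 = block_sums.pop(), block_counts.pop(), block_sizes.pop()
            block_sums[-1] += s2
            block_counts[-1] += c2
            block_sizes[-1] += n2

    levels = np.repeat([s / c for s, c in zip(block_sums, block_counts)], block_sizes)
    return xs, levels


def calibrate(new_scores, xs, levels):
    """Map judge scores to the oracle scale (linear between knots, clipped at the ends)."""
    return np.interp(new_scores, xs, levels)


# Toy labeled slice: judge scores with binary oracle labels.
judge = [0.1, 0.2, 0.3, 0.4]
oracle = [0.0, 1.0, 0.0, 1.0]
xs, levels = fit_isotonic(judge, oracle)
print(levels)                              # fitted oracle-scale levels
print(calibrate([0.25, 0.9], xs, levels))  # calibrated rewards for unlabeled scores
```

Monotonicity is the key design choice. A higher judge score can never map to a lower calibrated reward, and the oracle labels determine where the mapping sits on the oracle scale.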

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/cimo-labs/cje/blob/main/examples/cje_arena_demo.ipynb)

## Step 1: Setup

Install CJE and download the Arena sample data.

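**Note:** Colab ships with NumPy 2.3+, which is incompatible with its preinstalled SciPy, so we force-reinstall NumPy 2.0.x first. If Colab prompts you to restart the runtime after the install cell, restart it before continuing.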

In [None]:
# Colab comes with numpy 2.3+ which breaks scipy
# Force install compatible numpy first
!pip install -q 'numpy>=2.0,<2.1' --force-reinstall

# Install CJE from PyPI
!pip install -q cje-eval

# Verify installation
import cje
print(f"✓ CJE version {cje.__version__} installed")

# Check numpy version
import numpy as np
print(f"✓ NumPy version {np.__version__}")

In [None]:
# Download Arena sample data from GitHub
import urllib.request
from pathlib import Path

# Create data directory
DATA_DIR = Path("arena_sample")
DATA_DIR.mkdir(exist_ok=True)
(DATA_DIR / "fresh_draws").mkdir(exist_ok=True)

# Base URL for raw files on GitHub
BASE_URL = "https://raw.githubusercontent.com/cimo-labs/cje/main/examples/arena_sample"

# Download logged data
print("Downloading logged_data.jsonl...")
urllib.request.urlretrieve(
    f"{BASE_URL}/logged_data.jsonl",
    DATA_DIR / "logged_data.jsonl"
)
print("✓ Downloaded logged_data.jsonl")

# Download fresh draws (note: actual filenames have _responses.jsonl suffix)
fresh_draw_files = {
    "clone": "clone_responses.jsonl",
    "parallel_universe_prompt": "parallel_universe_prompt_responses.jsonl",
    "unhelpful": "unhelpful_responses.jsonl"
}

for policy, filename in fresh_draw_files.items():
    print(f"Downloading fresh_draws/{filename}...")
    urllib.request.urlretrieve(
        f"{BASE_URL}/fresh_draws/{filename}",
        DATA_DIR / "fresh_draws" / filename
    )
    print(f"✓ Downloaded {filename}")

print("\n✓ All data downloaded successfully!")
print(f"\nData location: {DATA_DIR.absolute()}")
print(f"Policies available: {list(fresh_draw_files.keys())}")

## Step 2: Inspect the Data

Let's look at what we're working with.

In [None]:
import json

# Load and inspect logged data
with open(DATA_DIR / "logged_data.jsonl") as f:
    logged_samples = [json.loads(line) for line in f]

print(f"Logged data: {len(logged_samples)} samples")
print(f"\nTarget policies: {list(logged_samples[0]['target_policy_logprobs'].keys())}")
print(f"\nExample sample:")
sample = logged_samples[0]
print(f"  Prompt: {sample['prompt'][:100]}...")
print(f"  Response: {sample['response'][:150]}...")
print(f"  Judge score: {sample['judge_score']}")
print(f"  Oracle label: {sample.get('oracle_label')}")  # None if this sample is unlabeled
print(f"  Has target logprobs: {'✓' if sample.get('target_policy_logprobs') else '✗'}")

# Check oracle coverage
n_with_oracle = sum(1 for s in logged_samples if s.get('oracle_label') is not None)
coverage = n_with_oracle / len(logged_samples)
print(f"\nOracle label coverage: {n_with_oracle}/{len(logged_samples)} ({coverage:.1%})")
print("→ Oracle labels are used to calibrate judge scores; the mode (IPS/DR/Direct) depends on which data you pass in")

## Step 3: Mode 1 - IPS (Importance Sampling)

**Use case**: Reuse logged data to estimate counterfactual performance

**Estimand**: "What would the KPI be if we deployed policy π' instead of π₀?"

**How it works**:
1. Calibrate judge scores → oracle-scale rewards (AutoCal-R)
2. Compute importance weights: W = π'(a|x) / π₀(a|x)
3. Stabilize weights with monotone projection (SIMCal)
4. Estimate: V̂(π') = (1/n) Σ W_i · R_i
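The IPS steps can be sketched in a few lines of NumPy. The sketch assumes you already have each logged response's sequence log-probability under π₀ and π', plus a calibrated reward Rᵢ. It forms Wᵢ = exp(log π'(aᵢ|xᵢ) − log π₀(aᵢ|xᵢ)), rescales the weights to mean one, and averages Wᵢ · Rᵢ. It deliberately skips the SIMCal projection (step 3). `ips_estimate` and the toy log-probs are hypothetical and are not CJE's API.

```python
import numpy as np


def ips_estimate(logp_target, logp_base, rewards):
    """Mean-one (self-normalized) IPS estimate from sequence log-probs.

    W_i = pi'(a_i|x_i) / pi0(a_i|x_i) = exp(logp_target_i - logp_base_i).
    """
    log_w = np.asarray(logp_target, dtype=float) - np.asarray(logp_base, dtype=float)
    # Shift by the max before exponentiating to avoid overflow. The shift is a
    # constant factor that cancels in the mean-one normalization below.
    w = np.exp(log_w - log_w.max())
    w = w / w.mean()
    estimate = float(np.mean(w * np.asarray(rewards, dtype=float)))
    return estimate, w


# Toy example: three logged responses with calibrated rewards.
est, w = ips_estimate(
    logp_target=[-10.0, -12.0, -15.0],
    logp_base=[-11.0, -12.0, -14.0],
    rewards=[0.9, 0.5, 0.2],
)
print(f"V_hat = {est:.3f}, weights = {np.round(w, 3)}")
```

Sequence-level log-prob differences can span many nats, so a handful of raw weights can dominate the average. SIMCal and the ESS diagnostic exist to address that problem.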

In [None]:
from cje import analyze_dataset

# IPS mode: Logged data only (auto-selects calibrated-ips estimator)
results_ips = analyze_dataset(
    logged_data_path=str(DATA_DIR / "logged_data.jsonl"),
    estimator="auto",  # Auto-detects IPS mode
    verbose=True,
)

print("\n" + "="*70)
print("IPS Results")
print("="*70)
print(f"Mode: {results_ips.metadata['mode']}")
print(f"Estimator: {results_ips.metadata['estimator']}")
print(f"Logprob coverage: {results_ips.metadata['mode_selection']['logprob_coverage']:.1%}")
print()

# Show estimates with confidence intervals
policies = results_ips.metadata['target_policies']
print(f"{'Policy':<30} {'Estimate':<12} {'Std Error':<12} {'95% CI':<20}")
print("-" * 74)
for i, policy in enumerate(policies):
    est = results_ips.estimates[i]
    se = results_ips.standard_errors[i]
    ci_low = est - 1.96 * se
    ci_high = est + 1.96 * se
    print(f"{policy:<30} {est:>6.3f}       {se:>6.3f}       [{ci_low:.3f}, {ci_high:.3f}]")

### Check IPS Diagnostics

**ESS (Effective Sample Size)** is the key diagnostic for IPS reliability:
- ESS ≥ 50%: Excellent overlap
- ESS ∈ [10%, 50%): Moderate (DR recommended)
- ESS < 10%: Poor overlap (switch to DR or regenerate)
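Here is a sketch of the ESS fraction. It assumes the standard Kish definition, ESS = (Σ Wᵢ)² / Σ Wᵢ², divided by n; CJE's exact computation may differ in details. `ess_fraction` and `overlap_status` are hypothetical helpers that apply the thresholds listed above.

```python
import numpy as np


def ess_fraction(weights):
    """Kish effective sample size divided by n: (sum w)^2 / (n * sum w^2)."""
    w = np.asarray(weights, dtype=float)
    return float(w.sum() ** 2 / (len(w) * np.sum(w ** 2)))


def overlap_status(ess):
    """Traffic-light label using the ESS thresholds from this section."""
    if ess >= 0.5:
        return "EXCELLENT"
    if ess >= 0.1:
        return "MODERATE (consider DR)"
    return "POOR (use DR or regenerate)"


# Uniform weights, mildly uneven weights, and one weight carrying everything.
for w in ([1.0] * 4, [2.5, 0.5, 0.5, 0.5], [10.0] + [0.0] * 19):
    ess = ess_fraction(w)
    print(f"n={len(w)}: ESS {ess:.0%} ({overlap_status(ess)})")
```

ESS is scale-invariant, so it gives the same answer for raw weights and mean-one weights.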

In [None]:
# Check diagnostics for each policy
print("IPS Diagnostics")
print("="*70)

for policy in policies:
    diag = results_ips.diagnostics.get(policy, {})
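    ess = diag.get('ess_fraction')
    if ess is None:
        # Don't report a missing diagnostic as "poor overlap"
        print(f"\n{policy}:\n  ESS: not reported")
        continue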
    
    # Traffic light assessment
    if ess >= 0.5:
        status = "✓ EXCELLENT"
    elif ess >= 0.1:
        status = "⚠ MODERATE (consider DR)"
    else:
        status = "✗ POOR (use DR or regenerate)"
    
    print(f"\n{policy}:")
    print(f"  ESS: {ess:.1%} {status}")
    
    # Show weight statistics
    if 'weights' in diag:
        weights_info = diag['weights']
        print(f"  Weight stats (after SIMCal):")
        print(f"    Min:    {weights_info.get('min', 0):.3f}")
        print(f"    Median: {weights_info.get('median', 0):.3f}")
        print(f"    Max:    {weights_info.get('max', 0):.3f}")

## Step 4: Mode 2 - DR (Doubly Robust)

**Use case**: Most accurate counterfactual estimates when you have fresh draws

**Estimand**: Same as IPS; the improvement is in the estimator (typically lower variance), not the target quantity

**How it works**:
1. Everything from IPS mode
2. Train outcome model ĝ(S) on fresh draws
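3. Combine: V̂_DR(π') = (1/n) Σ [ĝ(S'_i) + W_i · (R_i − ĝ(S_i))], where S_i is the judge score of the logged response and S'_i is the judge score of π''s fresh response to the same prompt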

**Double robustness**: Consistent if *either* weights or outcome model is correct.
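The sketch below shows the DR combination in its standard form. The direct-method term uses the outcome model's prediction on π''s fresh draw for each prompt, and the correction term uses the logged response. The inputs are assumed to be precomputed: mean-one weights, calibrated rewards, and ĝ predictions on both logged and fresh responses. `dr_estimate` is a hypothetical helper and is not CJE's API.

```python
import numpy as np


def dr_estimate(weights, rewards, g_logged, g_fresh):
    """DR estimate: mean of g_hat(S'_i) + W_i * (R_i - g_hat(S_i)).

    weights  : mean-one importance weights for the logged responses
    rewards  : calibrated rewards R_i of the logged responses
    g_logged : outcome-model predictions g_hat(S_i) on the logged responses
    g_fresh  : predictions g_hat(S'_i) on pi' fresh draws for the same prompts
    """
    w, r, g0, g1 = (np.asarray(a, dtype=float) for a in (weights, rewards, g_logged, g_fresh))
    return float(np.mean(g1 + w * (r - g0)))


# If g_hat predicts the logged rewards exactly, the correction term vanishes
# and the estimate reduces to the average of g_fresh, however extreme the weights.
print(dr_estimate(weights=[3.0, 0.5, 0.5, 0.0], rewards=[0.8, 0.4, 0.6, 0.2],
                  g_logged=[0.8, 0.4, 0.6, 0.2], g_fresh=[0.7, 0.5, 0.6, 0.4]))
```

The converse case also works. If ĝ is poor but the weights are right, the weighted residual term removes the outcome model's bias in expectation. That is where the "either/or" guarantee comes from.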

In [None]:
# DR mode: Logged data + fresh draws (auto-selects stacked-dr estimator)
results_dr = analyze_dataset(
    logged_data_path=str(DATA_DIR / "logged_data.jsonl"),
    fresh_draws_dir=str(DATA_DIR / "fresh_draws"),
    estimator="auto",  # Auto-detects DR mode
    estimator_config={"parallel": False},  # Disable parallel for Colab
    verbose=True,
)

print("\n" + "="*70)
print("DR Results")
print("="*70)
print(f"Mode: {results_dr.metadata['mode']}")
print(f"Estimator: {results_dr.metadata['estimator']}")
print()

# Compare IPS vs DR
print(f"{'Policy':<30} {'IPS Est':<12} {'DR Est':<12} {'Improvement':<15}")
print("-" * 69)
for i, policy in enumerate(policies):
    ips_est = results_ips.estimates[i]
    dr_est = results_dr.estimates[i]
    
    ips_se = results_ips.standard_errors[i]
    dr_se = results_dr.standard_errors[i]
    
    # Check if DR improved standard error
    if dr_se < ips_se:
        improvement = f"↓ {(1 - dr_se/ips_se)*100:.0f}% SE"
    else:
        improvement = "(no SE reduction)"
    
    print(f"{policy:<30} {ips_est:>6.3f}       {dr_est:>6.3f}       {improvement:<15}")

### Check DR Orthogonality

**Orthogonality score**: E[W · (R - ĝ)] should be ≈ 0
- If CI contains 0: ✓ Good
- If CI excludes 0: ⚠ Weights or outcome model may be poor
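One way to compute this check is shown below. It assumes the per-sample terms Wᵢ · (Rᵢ − ĝ(Sᵢ)) are available, then takes their mean with a normal-approximation 95% CI. CJE's own orthogonality diagnostic may use a different variance estimator, and `orthogonality_check` is a hypothetical helper.

```python
import numpy as np


def orthogonality_check(weights, rewards, g_logged, z=1.96):
    """Mean of W_i * (R_i - g_hat(S_i)) with a Wald CI; passes if the CI covers 0."""
    w, r, g = (np.asarray(a, dtype=float) for a in (weights, rewards, g_logged))
    terms = w * (r - g)
    score = terms.mean()
    se = terms.std(ddof=1) / np.sqrt(len(terms))
    lo, hi = score - z * se, score + z * se
    return float(score), float(lo), float(hi), bool(lo <= 0.0 <= hi)


# Residuals that cancel out pass; a systematic under-prediction by g_hat fails.
print(orthogonality_check([1.0] * 4, [0.6, 0.4, 0.6, 0.4], [0.5] * 4))
print(orthogonality_check([1.0] * 4, [0.9, 0.8, 0.9, 0.8], [0.5] * 4))
```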

In [None]:
print("DR Orthogonality Diagnostics")
print("="*70)

for policy in policies:
    diag = results_dr.diagnostics.get(policy, {})
    
    # Check orthogonality
    if 'orthogonality' in diag:
        ortho = diag['orthogonality']
        score = ortho.get('score', 0)
        ci_low = ortho.get('ci_lower', 0)
        ci_high = ortho.get('ci_upper', 0)
        
        # Check if CI contains zero
        contains_zero = ci_low <= 0 <= ci_high
        status = "✓ PASS" if contains_zero else "⚠ CHECK"
        
        print(f"\n{policy}:")
        print(f"  Orthogonality: {score:.4f} [{ci_low:.4f}, {ci_high:.4f}] {status}")
        
        if not contains_zero:
            print(f"  → CI does not contain 0. Consider:")
            print(f"    • Improving outcome model")
            print(f"    • Adding more fresh draws")
            print(f"    • Revisiting SIMCal settings")

## Step 5: Mode 3 - Direct (On-Policy Evaluation)

**Use case**: Compare policies on a specific evaluation set (non-counterfactual)

**Estimand**: "Which policy performs best on *this* prompt set?"

**How it works**:
1. Generate fresh responses from each policy on same prompts
2. Calibrate judge scores using oracle labels in fresh draws
3. Average the calibrated rewards: V̂(π) = (1/m) Σ R_i^π over the m fresh responses from π

**Key difference**: No importance weighting (on-policy), not counterfactual.
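In Direct mode the estimate is simply the mean calibrated reward over π's fresh responses. The sketch below adds a normal-approximation 95% CI and assumes the rewards are already calibrated. This interval reflects sampling noise only, whereas CJE's own intervals also account for calibration uncertainty. `direct_estimate` is a hypothetical helper.

```python
import numpy as np


def direct_estimate(calibrated_rewards, z=1.96):
    """On-policy value: mean calibrated reward with a sampling-only Wald CI."""
    r = np.asarray(calibrated_rewards, dtype=float)
    est = r.mean()
    se = r.std(ddof=1) / np.sqrt(len(r))
    return float(est), float(se), (float(est - z * se), float(est + z * se))


est, se, ci = direct_estimate([0.2, 0.4, 0.6, 0.8])
print(f"V_hat = {est:.3f} ± {se:.3f}, 95% CI [{ci[0]:.3f}, {ci[1]:.3f}]")
```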

In [None]:
# Direct mode: Fresh draws only (learns calibration from oracle labels in fresh draws)
results_direct = analyze_dataset(
    fresh_draws_dir=str(DATA_DIR / "fresh_draws"),  # No logged data!
    estimator="auto",  # Auto-detects Direct mode
    verbose=True,
)

print("\n" + "="*70)
print("Direct Mode Results")
print("="*70)
print(f"Mode: {results_direct.metadata['mode']}")
print(f"Estimator: {results_direct.metadata['estimator']}")
print(f"Calibration: {results_direct.metadata.get('calibration', 'none')}")
print(f"Oracle coverage: {results_direct.metadata.get('oracle_coverage', 0):.1%}")
print()

# Show all three modes side-by-side
print(f"{'Policy':<30} {'IPS':<10} {'DR':<10} {'Direct':<10}")
print("-" * 60)
for i, policy in enumerate(policies):
    print(f"{policy:<30} {results_ips.estimates[i]:>6.3f}     {results_dr.estimates[i]:>6.3f}     {results_direct.estimates[i]:>6.3f}")

## Step 6: Policy Selection and Comparison

Find the best policy and compare against a baseline using proper statistical inference.
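The comparison cell below treats the policy estimates as independent. When both policies are evaluated on the same prompts, a paired comparison is usually tighter: difference the per-prompt contributions first, then take the SE of the mean difference. The sketch below assumes per-prompt contributions are available for each policy (for example, calibrated rewards from Direct mode). `paired_difference` and the numbers are illustrative.

```python
import numpy as np


def paired_difference(contrib_a, contrib_b):
    """Delta = mean(a - b) over shared prompts, with SE from per-prompt differences."""
    a = np.asarray(contrib_a, dtype=float)
    b = np.asarray(contrib_b, dtype=float)
    n = len(a)
    d = a - b
    delta = d.mean()
    paired_se = d.std(ddof=1) / np.sqrt(n)
    # For contrast: the SE obtained by treating a and b as independent samples.
    unpaired_se = np.sqrt(a.var(ddof=1) / n + b.var(ddof=1) / n)
    return float(delta), float(paired_se), float(unpaired_se)


# Policy A beats B by ~0.1 on every prompt, but prompt difficulty varies a lot.
a = [0.5, 0.6, 0.7, 0.8]
b = [0.4, 0.5, 0.6, 0.7]
delta, paired_se, unpaired_se = paired_difference(a, b)
print(f"delta = {delta:+.3f}, paired SE = {paired_se:.3f}, unpaired SE = {unpaired_se:.3f}")
```

Here every prompt favors A by about 0.1, so the paired SE is near zero. The unpaired SE, by contrast, is inflated by prompt-to-prompt variation that both policies share.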

In [None]:
import numpy as np

# Use DR results (most accurate)
estimates = np.array(results_dr.estimates)
std_errors = np.array(results_dr.standard_errors)

# Find best policy
best_idx = np.argmax(estimates)
best_policy = policies[best_idx]
best_est = estimates[best_idx]
best_se = std_errors[best_idx]

print("Policy Selection (DR Mode)")
print("="*70)
print(f"\n🏆 Best policy: {best_policy}")
print(f"   Estimate: {best_est:.3f} ± {best_se:.3f}")
print(f"   95% CI: [{best_est - 1.96*best_se:.3f}, {best_est + 1.96*best_se:.3f}]")

# Compare against baseline (clone)
baseline_idx = policies.index('clone')
baseline_est = estimates[baseline_idx]
baseline_se = std_errors[baseline_idx]

print(f"\n📊 Comparison to baseline ({policies[baseline_idx]}):")
print(f"   Baseline: {baseline_est:.3f} ± {baseline_se:.3f}")
print()
print(f"{'Policy':<30} {'Delta':<12} {'Std Error':<12} {'Significant?':<15}")
print("-" * 69)

for i, policy in enumerate(policies):
    if i == baseline_idx:
        print(f"{policy:<30} {'(baseline)':<12} {'':<12} {'':<15}")
        continue
    
    delta = estimates[i] - baseline_est
    # Standard error of difference (assuming independence)
    delta_se = np.sqrt(std_errors[i]**2 + baseline_se**2)
    
    # Z-test
    z_score = delta / delta_se
    significant = abs(z_score) > 1.96
    
    sig_text = "✓ Yes (p<0.05)" if significant else "No"
    delta_text = f"{delta:+.3f}"
    
    print(f"{policy:<30} {delta_text:<12} {delta_se:>6.3f}       {sig_text:<15}")

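print("\nNote: SEs of differences assume the policy estimates are independent.")
print("Policies share prompts, so their estimates are usually positively correlated,")
print("which makes this conservative. For paired comparisons, use CJE's built-in contrast computation.")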

## Summary: When to Use Each Mode

| Mode | Data Required | Estimand | Use When |
|------|--------------|----------|----------|
| **IPS** | Logged data + logprobs | Counterfactual deployment value | Have logged data, want off-policy estimates |
| **DR** | Logged data + fresh draws | Counterfactual (most accurate) | Have both logged and fresh data |
| **Direct** | Fresh draws (+ optional calibration data) | Performance on eval set | Just want on-policy comparison |

### Key Diagnostics to Check

- **IPS mode**: ESS ≥ 10% (prefer ≥ 50%)
- **DR mode**: Orthogonality CI contains 0
- **All modes**: Oracle coverage ≥ 85% for reliable calibration

### Pro Tips

1. **Use `estimator="auto"`**: CJE picks the mode and estimator based on the data you provide
2. **Check diagnostics first**: Don't trust estimates if ESS < 10% or orthogonality fails
3. **Report confidence intervals**: CJE's CIs account for all uncertainty sources (oracle, calibration, sampling)
4. **DR > IPS when possible**: DR typically has lower variance than IPS and stays consistent if either the weights or the outcome model is correct

---

## Next Steps

- **Documentation**: [CJE README](https://github.com/cimo-labs/cje)
- **Paper**: Coming soon with full technical details
- **Install locally**: `pip install cje-eval`
- **More examples**: See [examples/](https://github.com/cimo-labs/cje/tree/main/examples)

Questions? Open an issue on [GitHub](https://github.com/cimo-labs/cje/issues)!