# CJE: Off-Policy Evaluation Methods

**IPS and DR modes for counterfactual policy evaluation**

This notebook covers:
1. **IPS Mode**: Importance sampling with logged data
2. **DR Mode**: Doubly robust estimation (IPS + outcome models)
3. **Diagnostics**: ESS, overlap, orthogonality
4. **When to use each method**

**Prerequisites**: Complete `cje_direct_mode_intro.ipynb` first!

## What is Off-Policy Evaluation?

**Direct Mode** answers: "Which policy is best *on this eval set*?"

**OPE** answers: "What would our *production metrics* be if we deployed policy π'?"

Key difference:
- Direct Mode: On-policy comparison (simple average)
- OPE: Counterfactual inference (reweight logged data)

**Use OPE when:**
- You have logged data from production/baseline
- You want to estimate deployment value before rolling out
- Your eval set may not match production distribution

## Step 1: Install and Download Data

**About the Arena Dataset**

We use a sample from the [LMSYS Chatbot Arena](https://huggingface.co/datasets/agie-ai/lmsys-chatbot_arena_conversations) dataset - real user conversations with human preference judgments. This dataset was used to validate CJE in our ablation studies.

**The Three Policies:**

1. **`clone`** (Baseline/Logging Policy)
   - The original Arena responses
   - This is π₀ - the policy that generated our logged data
   - Mean reward: ~0.76

2. **`parallel_universe_prompt`** (Target Policy)
   - Modified system prompt: "You are a helpful assistant from a parallel universe"
   - Tests whether a quirky prompt helps or hurts
   - Mean reward: ~0.76 (similar to baseline)
   - Good ESS (~40-50%) - reasonable overlap with baseline

3. **`unhelpful`** (Adversarial Target)
   - System prompt: "You are a very unhelpful assistant"
   - Intentionally bad responses for testing
   - Mean reward: ~0.14 (much worse, as expected)
   - Poor ESS (~5-10%) - demonstrates low overlap challenges

**Why this is a good OPE test case:**
- Real production-like data (Arena conversations)
- Mix of good overlap (parallel_universe) and poor overlap (unhelpful)
- Demonstrates when IPS works (parallel_universe) vs when DR is needed (unhelpful)
- Has oracle labels (human preferences) for calibration

**Data structure:**
- **Logged data**: Responses from `clone` policy with logprobs for all three policies
- **Fresh draws**: New responses generated from each target policy on the same prompts
- **Oracle labels**: Human preference judgments (subset of samples)

In [None]:
# Colab fix: Install compatible numpy first
!pip install -q 'numpy>=2.0,<2.1' --force-reinstall

# Install CJE
!pip install -q cje-eval

import cje
import numpy as np
print(f"✓ CJE version {cje.__version__}")
print(f"✓ NumPy version {np.__version__}")

In [None]:
import urllib.request
from pathlib import Path

# Create directories
DATA_DIR = Path("arena_sample")
FRESH_DRAWS_DIR = DATA_DIR / "fresh_draws"
FRESH_DRAWS_DIR.mkdir(parents=True, exist_ok=True)

BASE_URL = "https://raw.githubusercontent.com/cimo-labs/cje/main/examples/arena_sample"

# Download logged data (needed for OPE!)
print("Downloading logged_data.jsonl...")
urllib.request.urlretrieve(
    f"{BASE_URL}/logged_data.jsonl",
    DATA_DIR / "logged_data.jsonl"
)
print("✓ Downloaded logged_data.jsonl")

# Download fresh draws (for DR mode)
# Named fresh_draw_files so it is not confused with the `policies` list used in later cells
fresh_draw_files = {
    "clone": "clone_responses.jsonl",
    "parallel_universe_prompt": "parallel_universe_prompt_responses.jsonl",
    "unhelpful": "unhelpful_responses.jsonl"
}

for policy, filename in fresh_draw_files.items():
    print(f"Downloading {filename}...")
    urllib.request.urlretrieve(
        f"{BASE_URL}/fresh_draws/{filename}",
        FRESH_DRAWS_DIR / filename
    )
    print(f"✓ Downloaded {filename}")

print(f"\n✓ All data downloaded!")

## Step 2: Understand Logged Data Structure

For OPE, we need logged data with the **log-probabilities** from which importance weights are computed.

In [None]:
import json

# Load logged data
with open(DATA_DIR / "logged_data.jsonl") as f:
    logged_samples = [json.loads(line) for line in f]

print(f"Logged samples: {len(logged_samples)}")
print(f"\nExample logged sample:")
print(json.dumps(logged_samples[0], indent=2))

### Key Fields for OPE

**Required for IPS/DR:**
- `base_policy_logprob`: Log P(response | prompt, π₀)
  - This is the **logging policy** (what generated the data)
- `target_policy_logprobs`: {policy_name: log P(response | prompt, π')}
  - These are the **target policies** we want to evaluate

**Why logprobs?**
- Importance weight: W = P(a|x, π') / P(a|x, π₀) = exp(log π' - log π₀)
- Reweights logged data to match target policy distribution

**Also used:**
- `judge_score`: For calibration (judge → oracle)
- `oracle_label`: Ground truth for calibration
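
As a minimal sketch of the importance-weight formula above, the snippet below turns the two logprob fields into raw weights for one record. The record and its values are invented for illustration; only the field names `base_policy_logprob` and `target_policy_logprobs` come from the logged data, and `raw_importance_weights` is a hypothetical helper, not part of CJE.

```python
import math

# Hypothetical logged record: values are invented, field names follow the list above.
sample = {
    "base_policy_logprob": -12.0,
    "target_policy_logprobs": {
        "parallel_universe_prompt": -12.5,
        "unhelpful": -20.0,
    },
}

def raw_importance_weights(record):
    """Return {policy: exp(log pi' - log pi0)} for one logged record."""
    base_lp = record["base_policy_logprob"]
    return {
        name: math.exp(target_lp - base_lp)
        for name, target_lp in record["target_policy_logprobs"].items()
    }

weights = raw_importance_weights(sample)
for name, w in weights.items():
    print(f"{name}: W = {w:.6f}")
```

Subtracting in log space before exponentiating matters in practice: the probability of a full response is tiny, so dividing the raw probabilities directly would risk underflow.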

In [None]:
# Check logprob coverage
n_with_logprobs = sum(
    1 for s in logged_samples 
    if s.get('base_policy_logprob') is not None
    and s.get('target_policy_logprobs')
)

coverage = n_with_logprobs / len(logged_samples)
print(f"Logprob coverage: {n_with_logprobs}/{len(logged_samples)} ({coverage:.1%})")
print(f"Target policies: {list(logged_samples[0]['target_policy_logprobs'].keys())}")

# Check oracle coverage
n_with_oracle = sum(1 for s in logged_samples if s.get('oracle_label') is not None)
oracle_coverage = n_with_oracle / len(logged_samples)
print(f"\nOracle coverage: {n_with_oracle}/{len(logged_samples)} ({oracle_coverage:.1%})")
print(f"→ Used for calibrating judge scores to oracle scale")

## Step 3: IPS Mode (Importance Sampling)

**What IPS does:**
1. Compute importance weights: W_i = π'(a_i|x_i) / π₀(a_i|x_i)
2. Calibrate judge scores: S_i → R_i (AutoCal-R)
3. Stabilize weights: W_i → W_i^{cal} (SIMCal)
4. Estimate: V̂(π') = Σ W_i^{cal} · R_i / Σ W_i^{cal}

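**Key assumption**: Overlap. The logging policy must assign positive probability to every response the target policy could produce; otherwise no amount of reweighting can recover the target's behavior.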
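
Steps 1 and 4 above can be sketched in a few lines of NumPy. This is a simplified illustration with invented numbers: it uses raw weights and treats the rewards as already calibrated, so it skips AutoCal-R and SIMCal entirely. `snips_estimate` is a hypothetical helper, not CJE's implementation.

```python
import numpy as np

def snips_estimate(weights, rewards):
    """Self-normalized IPS: sum(W * R) / sum(W)."""
    weights = np.asarray(weights, dtype=float)
    rewards = np.asarray(rewards, dtype=float)
    return float(np.sum(weights * rewards) / np.sum(weights))

# Toy data (invented): logprobs under the logging and target policies, plus rewards.
log_p_base = np.array([-10.0, -11.0, -9.5, -12.0])
log_p_target = np.array([-10.2, -10.5, -11.0, -11.8])
rewards = np.array([0.8, 0.6, 0.9, 0.4])

weights = np.exp(log_p_target - log_p_base)  # step 1: raw importance weights
v_hat = snips_estimate(weights, rewards)     # step 4: weighted average of rewards
print(f"Self-normalized IPS estimate: {v_hat:.3f}")
```

Dividing by the sum of weights rather than by n keeps the estimate inside the range of observed rewards, and with all weights equal it reduces to the plain average.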

In [None]:
from cje import analyze_dataset

# IPS mode: logged data only
results_ips = analyze_dataset(
    logged_data_path=str(DATA_DIR / "logged_data.jsonl"),
    estimator="auto",  # Auto-detects IPS mode
    verbose=True,
)

print("\n" + "="*70)
print("IPS Mode Results")
print("="*70)

In [None]:
# Display IPS estimates
policies = results_ips.metadata['target_policies']
estimates = results_ips.estimates
std_errors = results_ips.standard_errors

print(f"{'Policy':<35} {'Estimate':<12} {'Std Error':<12} {'95% CI':<20}")
print("-" * 79)
for i, policy in enumerate(policies):
    est = estimates[i]
    se = std_errors[i]
    ci_low = est - 1.96 * se
    ci_high = est + 1.96 * se
    print(f"{policy:<35} {est:>6.3f}       {se:>6.3f}       [{ci_low:.3f}, {ci_high:.3f}]")

### Check IPS Diagnostics: ESS

**ESS (Effective Sample Size)** is the most important IPS diagnostic.

ESS = (Σ W)² / Σ W²

The percentages below report ESS as a fraction of the sample size n, i.e. ESS / n.

**Interpretation:**
- ESS = 100%: Perfect overlap (all samples equally useful)
- ESS = 50%: Good overlap
- ESS = 10%: Minimum for reliable IPS
- ESS < 10%: Poor overlap - switch to DR or regenerate

**Why ESS matters:**
- Low ESS → high variance → wide confidence intervals
- A few large weights dominate the estimate
- Results become unreliable
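
To make the effect of a few large weights concrete, here is a small sketch that computes ESS as a fraction of the sample size, (Σ W)² / (n · Σ W²), on two invented weight vectors. `ess_fraction` is a hypothetical helper written for this example, not a CJE function.

```python
import numpy as np

def ess_fraction(weights):
    """Effective sample size as a fraction of n: (sum W)^2 / (n * sum W^2)."""
    w = np.asarray(weights, dtype=float)
    return float(w.sum() ** 2 / (len(w) * np.sum(w ** 2)))

uniform = np.ones(100)                  # every sample counts equally
skewed = np.array([50.0] + [0.5] * 99)  # one sample dominates the rest

print(f"uniform weights: ESS = {ess_fraction(uniform):.1%}")
print(f"skewed weights:  ESS = {ess_fraction(skewed):.1%}")
```

A single weight of 50 among 100 samples is enough to push ESS below the 10% threshold, even though the other 99 weights are well behaved.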

In [None]:
print("IPS Diagnostics: Effective Sample Size")
print("="*70)

for policy in policies:
    ess = results_ips.diagnostics.ess_per_policy.get(policy, 0.0)
    max_weight = results_ips.diagnostics.max_weight_per_policy.get(policy, 0.0)
    
    # Status assessment
    if ess >= 0.5:
        status = "✓ EXCELLENT"
        advice = "IPS is reliable"
    elif ess >= 0.1:
        status = "⚠ MODERATE"
        advice = "Consider DR for better accuracy"
    else:
        status = "✗ POOR"
        advice = "Use DR or regenerate data"
    
    print(f"\n{policy}:")
    print(f"  ESS: {ess:.1%} {status}")
    print(f"  Max weight: {max_weight:.3f}")
    print(f"  → {advice}")

print("\n💡 ESS tells you how many 'effective' samples you have after reweighting.")
print("   ESS ≥ 50% is excellent, ESS ≥ 10% is acceptable, ESS < 10% is risky.")

## Step 4: DR Mode (Doubly Robust)

**What DR adds to IPS:**
- Trains outcome model: ĝ(S) predicts reward from judge score
- Uses fresh draws to fit model
- Combines IPS weights with outcome predictions

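**DR estimator:**
V̂_DR(π') = (1/m) Σ_j ĝ(S'_j) + (1/n) Σ_i W_i · (R_i - ĝ(S_i))

Here S_i are judge scores on the n logged samples and S'_j are judge scores on the m fresh draws from π'. The first term is the outcome model's prediction under the target policy; the second corrects it with the importance-weighted residual on the logged data.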

**Why DR is better:**
1. **Double robustness**: Consistent if *either* weights or model is correct
2. **Variance reduction**: Model predictions reduce noise
3. **Better with poor overlap**: Degrades more gracefully when ESS is low, because the outcome model carries more of the estimate
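
The sketch below illustrates one common form of the DR estimator: the outcome model's predictions are averaged over fresh draws from the target policy, and an importance-weighted residual computed on logged data is added as a correction. All data are synthetic, and the linear outcome model is an assumption chosen for brevity; CJE's own outcome models and weight calibration are not reproduced here.

```python
import numpy as np

rng = np.random.default_rng(0)

# Synthetic logged data (invented): judge scores S, rewards R, importance weights W.
n = 500
s_logged = rng.uniform(0, 1, n)
r_logged = np.clip(0.1 + 0.8 * s_logged + rng.normal(0, 0.05, n), 0, 1)
w = rng.lognormal(mean=0.0, sigma=0.5, size=n)
w = w / w.mean()  # normalize weights to mean 1

# Synthetic fresh draws from the target policy: judge scores only.
s_fresh = rng.uniform(0.2, 1, n)

# Outcome model g_hat(S): a linear fit standing in for whatever model is used in practice.
slope, intercept = np.polyfit(s_logged, r_logged, deg=1)

def g_hat(s):
    return intercept + slope * s

direct_term = g_hat(s_fresh).mean()                     # model prediction under the target policy
correction = np.mean(w * (r_logged - g_hat(s_logged)))  # weighted residual on logged data
v_dr = direct_term + correction

print(f"direct term: {direct_term:.3f}")
print(f"correction:  {correction:+.4f}")
print(f"DR estimate: {v_dr:.3f}")
```

When the outcome model fits well, the residuals are small and the correction term stays close to zero, which is where the variance reduction comes from.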

In [None]:
# DR mode: logged data + fresh draws
results_dr = analyze_dataset(
    logged_data_path=str(DATA_DIR / "logged_data.jsonl"),
    fresh_draws_dir=str(FRESH_DRAWS_DIR),
    estimator="auto",  # Auto-detects DR mode (uses stacked-dr)
    estimator_config={"parallel": False},  # Disable parallel for Colab
    verbose=True,
)

print("\n" + "="*70)
print("DR Mode Results")
print("="*70)

### Compare IPS vs DR

Let's see how the DR estimates and standard errors compare with IPS.

In [None]:
print("Estimate Comparison: IPS vs DR")
print("="*70)
print(f"{'Policy':<35} {'IPS Est':<12} {'DR Est':<12} {'Difference'}")
print("-" * 71)
for i, policy in enumerate(policies):
    ips_est = results_ips.estimates[i]
    dr_est = results_dr.estimates[i]
    diff = dr_est - ips_est
    print(f"{policy:<35} {ips_est:>6.3f}       {dr_est:>6.3f}       {diff:+.4f}")

print("\n" + "="*70)
print("Standard Error Comparison: IPS vs DR")
print("="*70)
print(f"{'Policy':<35} {'IPS SE':<12} {'DR SE':<12} {'Improvement'}")
print("-" * 72)
for i, policy in enumerate(policies):
    ips_se = results_ips.standard_errors[i]
    dr_se = results_dr.standard_errors[i]
    
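    if dr_se < ips_se:
        improvement = f"↓ {(1 - dr_se/ips_se)*100:.0f}%"
    elif dr_se > ips_se:
        improvement = f"↑ {(dr_se/ips_se - 1)*100:.0f}%"  # DR is less precise for this policy
    else:
        improvement = "(same)"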
    
    print(f"{policy:<35} {ips_se:>6.4f}       {dr_se:>6.4f}       {improvement:<12}")

print("\n💡 DR typically has lower standard errors → narrower confidence intervals")
print("   The outcome model reduces variance, making estimates more precise.")

### Check DR Diagnostics: Orthogonality

**Orthogonality score**: E[W · (R - ĝ)]

This tests if the DR correction term has mean zero.

**What it means:**
- CI contains 0: ✓ Good - weights and model are compatible
- CI excludes 0: ⚠ Problem - weights or model may be misspecified

**Why it matters:**
- Orthogonality ensures √n-consistency
- Non-orthogonal DR can be biased
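
A minimal sketch of the check, assuming a normal-approximation confidence interval on the per-sample terms W · (R - ĝ). The data are synthetic and `orthogonality_check` is a hypothetical helper; how CJE computes its own interval is not shown in this notebook.

```python
import numpy as np

def orthogonality_check(weights, rewards, predictions, z=1.96):
    """Mean of W * (R - g_hat) with a normal-approximation 95% CI."""
    terms = np.asarray(weights) * (np.asarray(rewards) - np.asarray(predictions))
    score = terms.mean()
    se = terms.std(ddof=1) / np.sqrt(len(terms))
    return score, score - z * se, score + z * se

rng = np.random.default_rng(1)
n = 1000
w = rng.lognormal(0.0, 0.5, n)
g = rng.uniform(0, 1, n)       # outcome-model predictions (synthetic)
r = g + rng.normal(0, 0.1, n)  # rewards whose residuals are pure noise

score, lo, hi = orthogonality_check(w, r, g)
print(f"well-specified: score={score:+.4f}  CI=[{lo:+.4f}, {hi:+.4f}]")

# A model whose predictions are shifted by 0.2 leaves a systematic residual.
score_b, lo_b, hi_b = orthogonality_check(w, r, g + 0.2)
print(f"biased model:   score={score_b:+.4f}  CI=[{lo_b:+.4f}, {hi_b:+.4f}]")
```

The biased model's interval sits entirely below zero, which is the pattern the ⚠ case above is meant to flag.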

In [None]:
print("DR Orthogonality Diagnostics")
print("="*70)

if 'orthogonality_scores' in results_dr.metadata:
    ortho_scores = results_dr.metadata['orthogonality_scores']
    
    for policy in policies:
        if policy in ortho_scores:
            ortho = ortho_scores[policy]
            score = ortho.get('score', 0)
            ci_low = ortho.get('ci_lower', 0)
            ci_high = ortho.get('ci_upper', 0)
            
            # Check if CI contains zero
            contains_zero = ci_low <= 0 <= ci_high
            status = "✓ PASS" if contains_zero else "⚠ CHECK"
            
            print(f"\n{policy}:")
            print(f"  Score: {score:.5f}")
            print(f"  95% CI: [{ci_low:.5f}, {ci_high:.5f}]")
            print(f"  Status: {status}")
            
            if not contains_zero:
                print(f"  → CI excludes 0. Consider:")
                print(f"    • Checking outcome model fit (R² below)")
                print(f"    • Adding more fresh draws")
                print(f"    • Reviewing weight calibration")
else:
    print("\nOrthogonality scores not available for this estimator.")
    print("(Stacked-DR may not report per-base-estimator orthogonality)")

### Outcome Model Quality

Check how well the outcome model predicts rewards.
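
For intuition about what an out-of-fold R² measures, the sketch below predicts each point with a model fitted on the other folds and then scores those predictions. The linear model, the 5-fold split, and the synthetic data are all assumptions made for this example; `out_of_fold_r2` is a hypothetical helper, not the routine CJE uses.

```python
import numpy as np

def out_of_fold_r2(x, y, n_folds=5, seed=0):
    """R^2 of a linear model where each point is predicted by a fit that excluded its fold."""
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    fold_ids = np.random.default_rng(seed).permutation(len(x)) % n_folds
    preds = np.empty_like(y)
    for k in range(n_folds):
        held_out = fold_ids == k
        slope, intercept = np.polyfit(x[~held_out], y[~held_out], deg=1)
        preds[held_out] = intercept + slope * x[held_out]
    ss_res = np.sum((y - preds) ** 2)
    ss_tot = np.sum((y - y.mean()) ** 2)
    return float(1.0 - ss_res / ss_tot)

rng = np.random.default_rng(42)
judge = rng.uniform(0, 1, 400)
informative = 0.1 + 0.8 * judge + rng.normal(0, 0.1, 400)  # reward tracks the judge score
uninformative = rng.uniform(0, 1, 400)                     # reward unrelated to the judge score

r2_good = out_of_fold_r2(judge, informative)
r2_bad = out_of_fold_r2(judge, uninformative)
print(f"informative judge:   R² = {r2_good:.3f}")
print(f"uninformative judge: R² = {r2_bad:.3f}")
```

Scoring held-out predictions keeps the number honest: an in-sample R² can look good for a model that has merely memorized noise.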

In [None]:
print("Outcome Model R² (out-of-fold)")
print("="*70)

diag = results_dr.diagnostics

if hasattr(diag, 'outcome_r2_range') and diag.outcome_r2_range:
    r2_min, r2_max = diag.outcome_r2_range
    print(f"\nR² range across policies: [{r2_min:.3f}, {r2_max:.3f}]")
    
    # Judge the fit by the worst policy (r2_min) so one good policy cannot mask a poor one
    if r2_min >= 0.5:
        print("✓ Good outcome model fit")
    elif r2_min >= 0.2:
        print("⚠ Moderate fit - DR still helps but may not be optimal")
    else:
        print("✗ Poor fit - outcome model not capturing reward structure well")
else:
    print("\nOutcome model R² not available in diagnostics.")

print("\n💡 R² measures how much variance the outcome model explains.")
print("   Higher R² → better variance reduction → narrower confidence intervals.")

## Step 5: Method Comparison Summary

Let's compare all three methods side-by-side.

In [None]:
# Run Direct mode on the fresh draws for comparison
results_direct = analyze_dataset(
    fresh_draws_dir=str(FRESH_DRAWS_DIR),
    estimator="auto",
    verbose=False,
)

print("Method Comparison: Direct vs IPS vs DR")
print("="*90)
print(f"{'Policy':<30} {'Direct':<12} {'IPS':<12} {'DR':<12} {'Lowest SE'}")
print("-" * 90)

# Assumes all three results report policies in the same order as `policies`
for i, policy in enumerate(policies):
    direct_est = results_direct.estimates[i]
    ips_est = results_ips.estimates[i]
    dr_est = results_dr.estimates[i]
    
    # Find lowest SE (best precision)
    direct_se = results_direct.standard_errors[i]
    ips_se = results_ips.standard_errors[i]
    dr_se = results_dr.standard_errors[i]
    
    best_method = ["Direct", "IPS", "DR"][np.argmin([direct_se, ips_se, dr_se])]
    
    print(f"{policy:<30} {direct_est:>6.3f}       {ips_est:>6.3f}       "
          f"{dr_est:>6.3f}       {best_method}")

print("\n" + "="*90)
print("Standard Errors (Lower is Better)")
print("="*90)
print(f"{'Policy':<30} {'Direct SE':<12} {'IPS SE':<12} {'DR SE':<12}")
print("-" * 66)

for i, policy in enumerate(policies):
    direct_se = results_direct.standard_errors[i]
    ips_se = results_ips.standard_errors[i]
    dr_se = results_dr.standard_errors[i]
    
    print(f"{policy:<30} {direct_se:>6.4f}       {ips_se:>6.4f}       {dr_se:>6.4f}")

### Interpretation

**Direct Mode:**
- Estimates performance *on the eval set*
- Not counterfactual (doesn't estimate production value)
- Good for: Policy comparison on fixed prompt distribution

**IPS Mode:**
- Estimates counterfactual deployment value
- Reweights logged data to match target policy
- Good for: When ESS ≥ 10% (ideally ≥ 50%) and no fresh draws are available

**DR Mode:**
- Also estimates counterfactual deployment value
- Uses outcome model to reduce variance
- Good for: When you have fresh draws (almost always best!)

## Summary: Choosing Your OPE Method

### Decision Tree

```
Do you need counterfactual estimates?
├─ No → Use Direct Mode
│         (on-policy evaluation, simplest)
│
└─ Yes → Do you have logged data with logprobs?
          ├─ No → Use Direct Mode
          │        (can't do OPE without logprobs)
          │
          └─ Yes → Do you have fresh draws?
                   ├─ No → Use IPS Mode
                   │        (check ESS ≥ 10%!)
                   │
                   └─ Yes → Use DR Mode ✓
                            (best accuracy + robustness)
```
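
The same decision tree can be written as a small function, which is handy for asserting the expected mode in a pipeline before calling `analyze_dataset`. `choose_mode` and its argument names are hypothetical and exist only in this sketch; CJE's `estimator="auto"` makes its own selection.

```python
def choose_mode(needs_counterfactual: bool, has_logprobs: bool, has_fresh_draws: bool) -> str:
    """Mirror the decision tree above and return 'direct', 'ips', or 'dr'."""
    if not needs_counterfactual:
        return "direct"  # on-policy evaluation is enough
    if not has_logprobs:
        return "direct"  # OPE is not possible without logprobs
    return "dr" if has_fresh_draws else "ips"

print(choose_mode(needs_counterfactual=True, has_logprobs=True, has_fresh_draws=True))
print(choose_mode(needs_counterfactual=True, has_logprobs=True, has_fresh_draws=False))
print(choose_mode(needs_counterfactual=True, has_logprobs=False, has_fresh_draws=True))
```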

### Data Requirements

| Method | Logged Data | Logprobs | Fresh Draws | Oracle Labels |
|--------|------------|----------|-------------|---------------|
| **Direct** | ✗ | ✗ | ✓ | Optional (5-10%) |
| **IPS** | ✓ | ✓ | ✗ | Recommended (5-10%) |
| **DR** | ✓ | ✓ | ✓ | Recommended (5-10%) |

### Key Diagnostics Checklist

**IPS Mode:**
- [ ] ESS ≥ 10% for all policies (prefer ≥ 50%)
- [ ] Max weight not too large (< 100)
- [ ] Calibration RMSE reasonable

**DR Mode:**
- [ ] Orthogonality CI contains 0
- [ ] Outcome model R² ≥ 0.2 (prefer ≥ 0.5)
- [ ] Standard errors lower than IPS
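
The numeric items on this checklist can be collected into one helper that returns whatever fails. This is a sketch using the thresholds listed above; `failed_checks` and its arguments are hypothetical, and the values passed in are invented rather than taken from a real run.

```python
def failed_checks(ess, max_weight, ortho_ci=None, outcome_r2=None):
    """Return the checklist items that fail; DR items are skipped when not supplied."""
    failures = []
    if ess < 0.10:
        failures.append(f"ESS {ess:.1%} is below 10%")
    if max_weight >= 100:
        failures.append(f"max weight {max_weight:.1f} is 100 or more")
    if ortho_ci is not None and not (ortho_ci[0] <= 0 <= ortho_ci[1]):
        failures.append("orthogonality CI excludes 0")
    if outcome_r2 is not None and outcome_r2 < 0.2:
        failures.append(f"outcome R² {outcome_r2:.2f} is below 0.2")
    return failures

healthy = failed_checks(ess=0.45, max_weight=12.0)
risky = failed_checks(ess=0.07, max_weight=250.0, ortho_ci=(0.01, 0.05), outcome_r2=0.1)

print("healthy run:", healthy or "all checks pass")
for item in risky:
    print("risky run:", item)
```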

### Common Issues

**Low ESS (< 10%):**
- Problem: Poor overlap between logging and target policies
- Solutions:
  1. Switch to DR mode (most robust)
  2. Regenerate data with better overlap
  3. Use a closer baseline policy for logging

**Non-orthogonal DR:**
- Problem: Outcome model or weights misspecified
- Solutions:
  1. Add more fresh draws (improves outcome model)
  2. Check calibration diagnostics
  3. Try different DR variant (e.g., oc-dr-cpo)

**High variance:**
- Problem: Not enough data or poor overlap
- Solutions:
  1. Collect more samples
  2. Use DR instead of IPS
  3. Use weight calibration (SIMCal, enabled by default)

### Pro Tips

1. **Start with `estimator="auto"`** - CJE selects the mode (Direct, IPS, or DR) from the data you provide
2. **Check diagnostics first** - Don't trust estimates with bad diagnostics
3. **DR > IPS when possible** - More robust and accurate
4. **5-10% oracle coverage is enough** - Don't need labels for all samples
5. **Report confidence intervals** - CJE's CIs account for all uncertainty

---

## Next Steps

- **Documentation**: [GitHub README](https://github.com/cimo-labs/cje)
- **API Reference**: See docstrings in `analyze_dataset()`
- **Paper**: Coming soon with full technical details
- **Questions?**: [Open an issue](https://github.com/cimo-labs/cje/issues)

Happy evaluating! 🎉