# Baseline Calibration — Full Pipeline Traceability

This notebook runs single-agent baselines on a stratified CUAD sample
with **full pipeline traceability**: every step from input to evaluation is recorded.

**What gets recorded per sample:**
- **Input**: system prompt, question, and user-message length (the full user message, with contract and question injected, is not stored)
- **Raw output**: the complete model response before any parsing
- **Parsed output**: extracted clauses after baseline-specific parsing
- **Ground truth**: the labeled answer spans from CUAD
- **Evaluation**: TP/FP/FN/TN classification, Jaccard similarity, grounding rate
- **Usage**: token counts, latency, estimated cost
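The two text-level evaluation fields can be illustrated with a minimal sketch. It assumes word-level token sets for Jaccard and whitespace-normalized verbatim substring matching for grounding. `jaccard` and `grounding_rate` are hypothetical stand-ins: the project's `compute_jaccard` and `compute_grounding_rate` in `src.evaluation.metrics` may tokenize or normalize differently.

```python
import re


def _tokens(text: str) -> set[str]:
    """Lowercased word tokens; punctuation is dropped (an assumption)."""
    return set(re.findall(r"\w+", text.lower()))


def jaccard(predicted: str, ground_truth: str) -> float:
    """Intersection over union of word-token sets; 1.0 when both sides are empty."""
    p, g = _tokens(predicted), _tokens(ground_truth)
    if not p and not g:
        return 1.0
    return len(p & g) / len(p | g)


def _normalize_ws(text: str) -> str:
    """Collapse runs of whitespace so line breaks do not defeat verbatim matching."""
    return " ".join(text.split())


def grounding_rate(clauses: list[str], contract_text: str) -> float:
    """Fraction of extracted clauses found verbatim in the contract; 1.0 if nothing was extracted."""
    if not clauses:
        return 1.0
    contract = _normalize_ws(contract_text)
    return sum(_normalize_ws(c) in contract for c in clauses) / len(clauses)


contract = "This Agreement is dated the 20th day of\nDecember 2018."
print(jaccard("dated the 20th day of December 2018", "20th day of December 2018"))  # 0.7142857142857143
print(grounding_rate(["20th day of December 2018", "Step 1: Identify key concepts"], contract))  # 0.5
```

Under this definition, a prediction that is model reasoning rather than contract text (compare the "Step 1: ..." predictions in the per-sample output) is not found in the contract and lowers the grounding rate.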

**Pipeline flow:**
```
Config → Load CUAD → Sample → Build prompt → Call model → Parse response → Evaluate → Save JSONL
```

**Crash-safe**: each result is appended to a JSONL file as soon as it is evaluated.
Re-running the extraction cell skips samples already recorded in that file (resume).

**Baselines:**
- `B1`: Zero-shot (ContractEval exact replication)
- `B4`: Chain-of-Thought

For multi-agent configurations (M1–M6), see `04_multiagent_experiment.ipynb`.

In [19]:
# ============================================================
# CONFIGURATION — Change these to switch model / baseline / sample size
# ============================================================

MODEL_KEY = "gpt-4.1-mini"        # Model key (see src/models/config.py for all options)
BASELINE_TYPE = "B4"              # B1=zero-shot, B4=chain-of-thought
SAMPLES_PER_TIER = 10             # Positive samples per tier (common/moderate/rare); negatives: max(1, this // 2)
INCLUDE_NEGATIVE_SAMPLES = True   # Also sample cases where the ground truth is empty
MAX_CONTRACT_CHARS = 100_000      # Skip contracts longer than this
TEMPERATURE = 0.0                 # Generation temperature
MAX_TOKENS = 4096                 # Max output tokens

In [20]:
import sys, os, time, json, datetime
from collections import defaultdict
from pathlib import Path

sys.path.insert(0, "..")

from dotenv import load_dotenv
load_dotenv("../.env")

from src.models.config import get_model_config, ModelProvider

config = get_model_config(MODEL_KEY)
baseline_labels = {"B1": "zero_shot", "B4": "cot"}
baseline_label = baseline_labels[BASELINE_TYPE]

print(f"Model:    {config.name} ({config.model_id})")
print(f"Provider: {config.provider.value}")
print(f"Baseline: {BASELINE_TYPE} ({baseline_label})")
print(f"Context:  {config.context_window:,} tokens")

# Verify provider connectivity
if config.provider == ModelProvider.OLLAMA:
    import urllib.request
    try:
        # timeout so an unreachable server fails fast instead of hanging the cell
        urllib.request.urlopen(f"{config.base_url or 'http://localhost:11434/v1'}/models", timeout=5)
        print("Ollama:   connected")
    except Exception as e:
        print(f"WARNING:  Ollama not reachable — {e}")
elif config.provider == ModelProvider.ANTHROPIC:
    assert os.getenv("ANTHROPIC_API_KEY"), "ANTHROPIC_API_KEY not set"
    print("API key:  set")
elif config.provider == ModelProvider.OPENAI:
    assert os.getenv("OPENAI_API_KEY"), "OPENAI_API_KEY not set"
    print("API key:  set")
elif config.provider == ModelProvider.GOOGLE:
    assert os.getenv("GEMINI_API_KEY") or os.getenv("GOOGLE_API_KEY"), "GEMINI_API_KEY not set"
    print("API key:  set")

Model:    GPT 4.1 Mini (gpt-4.1-mini)
Provider: openai
Baseline: B4 (cot)
Context:  1,047,576 tokens
API key:  set


## 1. Load and Sample CUAD Data

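Stratified sampling: up to `SAMPLES_PER_TIER` positive samples (has clause) from each tier (common / moderate / rare),
plus `max(1, SAMPLES_PER_TIER // 2)` negative samples (no clause) per tier when `INCLUDE_NEGATIVE_SAMPLES` is set.
Contracts longer than `MAX_CONTRACT_CHARS` are excluded before sampling.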
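The selection rule can be sketched independently of the CUAD loader. `DemoSample`, `make_tier`, and `stratified_select` are hypothetical names for illustration; the per-tier counts follow the selection cell, but the sketch uses a local `random.Random(seed)` instead of the notebook's global `random.seed(42)`, which is a design choice rather than what the notebook does.

```python
import random
from dataclasses import dataclass


@dataclass
class DemoSample:
    """Hypothetical stand-in for a CUAD sample; only the fields selection needs."""
    tier: str
    has_clause: bool


def stratified_select(by_tier: dict[str, list], per_tier: int,
                      include_negative: bool = True, seed: int = 42) -> list:
    """Per tier: up to `per_tier` positives, plus max(1, per_tier // 2) negatives if available."""
    rng = random.Random(seed)  # local RNG: the result does not depend on global random state
    selected = []
    for tier in ["common", "moderate", "rare"]:
        positive = [s for s in by_tier.get(tier, []) if s.has_clause]
        negative = [s for s in by_tier.get(tier, []) if not s.has_clause]
        selected.extend(rng.sample(positive, min(per_tier, len(positive))))
        if include_negative and negative:
            n_neg = min(max(1, per_tier // 2), len(negative))
            selected.extend(rng.sample(negative, n_neg))
    return selected


def make_tier(tier: str, n_pos: int, n_neg: int) -> list[DemoSample]:
    """Build a synthetic tier with n_pos positive and n_neg negative samples."""
    return [DemoSample(tier, True) for _ in range(n_pos)] + [DemoSample(tier, False) for _ in range(n_neg)]


# Synthetic pool: "rare" has only 4 positives and "common" only 3 negatives.
pool = {
    "common": make_tier("common", 20, 3),
    "moderate": make_tier("moderate", 12, 30),
    "rare": make_tier("rare", 4, 40),
}
sel = stratified_select(pool, per_tier=10)
print(len(sel), sum(s.has_clause for s in sel))  # 37 24
```

With `SAMPLES_PER_TIER = 10` and every tier holding at least 10 positives and 5 negatives after length filtering (see the tier counts in the output below), the rule yields 3 × (10 + 5) = 45 samples. A dedicated `random.Random` also keeps the selection identical if the sampling cell is re-run without reloading the data, whereas the global RNG would have advanced.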

In [21]:
from src.data.cuad_loader import CUADDataLoader, CATEGORY_TIERS
import random

random.seed(42)
loader = CUADDataLoader()
loader.load()
all_samples = list(loader)

print(f"Total samples: {len(all_samples):,}")
print(f"Contracts:     {len(loader.get_contracts())}")
print()

by_tier: dict[str, list] = defaultdict(list)
for s in all_samples:
    if len(s.contract_text) <= MAX_CONTRACT_CHARS:
        by_tier[s.tier].append(s)

for tier in ["common", "moderate", "rare"]:
    pos = sum(1 for s in by_tier[tier] if s.has_clause)
    neg = len(by_tier[tier]) - pos
    print(f"{tier:10s}: {len(by_tier[tier]):,} samples ({pos} pos, {neg} neg)")

Total samples: 20,910
Contracts:     510

common    : 2,574 samples (2278 pos, 296 neg)
moderate  : 7,722 samples (2050 pos, 5672 neg)
rare      : 7,293 samples (812 pos, 6481 neg)


In [22]:
selected = []
for tier in ["common", "moderate", "rare"]:
    tier_samples = by_tier[tier]
    positive = [s for s in tier_samples if s.has_clause]
    negative = [s for s in tier_samples if not s.has_clause]

    n_pos = min(SAMPLES_PER_TIER, len(positive))
    selected.extend(random.sample(positive, n_pos))

    if INCLUDE_NEGATIVE_SAMPLES and negative:
        n_neg = min(max(1, SAMPLES_PER_TIER // 2), len(negative))
        selected.extend(random.sample(negative, n_neg))

print(f"Selected {len(selected)} samples:\n")
for s in selected:
    info = f"{s.num_spans} spans" if s.has_clause else "no clause"
    print(f"  [{s.tier:8s}] {s.category:40s} ({info}) | {len(s.contract_text):,} chars")

Selected 45 samples:

  [common  ] Agreement Date                           (1 spans) | 32,321 chars
  [common  ] Expiration Date                          (1 spans) | 6,341 chars
  [common  ] Parties                                  (7 spans) | 62,272 chars
  [common  ] Document Name                            (1 spans) | 48,253 chars
  [common  ] Parties                                  (4 spans) | 29,724 chars
  [common  ] Parties                                  (2 spans) | 33,577 chars
  [common  ] Governing Law                            (2 spans) | 70,631 chars
  [common  ] Agreement Date                           (1 spans) | 63,376 chars
  [common  ] Effective Date                           (1 spans) | 26,753 chars
  [common  ] Document Name                            (1 spans) | 59,532 chars
  [common  ] Governing Law                            (no clause) | 6,341 chars
  [common  ] Effective Date                           (no clause) | 6,341 chars
  [common  ] Expiration Date 

## 2. Run Extraction with Full Traceability

Each sample goes through:
1. **Build prompt**: for B1, the ContractEval system prompt plus a user message with the contract text and question injected; for B4, a single user message built from the chain-of-thought template (no system prompt)
2. **Call model**: raw API call capturing the response text and token usage
3. **Parse response**: the baseline-specific parser extracts clauses from the raw response
4. **Evaluate**: classify as TP/FP/FN/TN, compute Jaccard, and compute the grounding rate
5. **Save**: append the full record to JSONL immediately (crash-safe)

The JSONL file enables **resume**: re-running skips already-completed samples.
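The crash-safe append and resume behaviour reduces to a small pattern. The sketch below is a simplified, self-contained version: `load_completed`, `append_record`, and `run_resumable` are hypothetical helper names, and `process` stands in for build prompt → call model → parse → evaluate. Resume only works when the same JSONL path is reused across runs, so the run ID that names the file must stay stable.

```python
import json
import tempfile
from pathlib import Path
from typing import Callable


def load_completed(path: Path) -> tuple[list[dict], set[str]]:
    """Read an existing JSONL file (if any); return its records and their sample IDs."""
    records: list[dict] = []
    done: set[str] = set()
    if path.exists():
        with open(path, encoding="utf-8") as f:
            for line in f:
                if line.strip():  # ignore blank lines
                    rec = json.loads(line)
                    records.append(rec)
                    done.add(rec["sample_id"])
    return records, done


def append_record(path: Path, record: dict) -> None:
    """Append one record as a single JSON line; the file is closed after every write."""
    with open(path, "a", encoding="utf-8") as f:
        f.write(json.dumps(record, default=str) + "\n")


def run_resumable(path: Path, sample_ids: list[str], process: Callable[[str], dict]) -> list[dict]:
    """Process only samples not yet recorded in `path`, persisting each result immediately."""
    results, done = load_completed(path)
    for sid in sample_ids:
        if sid in done:
            continue  # finished in an earlier (possibly crashed) run
        record = process(sid)
        append_record(path, record)
        results.append(record)
    return results


# Usage: the first run handles "a" and "b"; the second run resumes and only handles "c".
calls: list[str] = []


def fake_process(sample_id: str) -> dict:
    calls.append(sample_id)
    return {"sample_id": sample_id}


with tempfile.TemporaryDirectory() as tmp:
    jsonl = Path(tmp) / "demo_intermediate.jsonl"
    run_resumable(jsonl, ["a", "b"], fake_process)
    out = run_resumable(jsonl, ["a", "b", "c"], fake_process)

print(calls)     # ['a', 'b', 'c']
print(len(out))  # 3
```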

In [23]:
from src.models import invoke_model as model_invoke
from src.models.diagnostics import ModelDiagnostics, TokenUsage
from src.evaluation.metrics import span_overlap, compute_jaccard, compute_grounding_rate

# Import baseline prompts and parsers
from src.baselines.zero_shot import CONTRACTEVAL_PROMPT, ZeroShotBaseline
from src.baselines.chain_of_thought import COT_PROMPT, ChainOfThoughtBaseline


def build_messages(sample, baseline_type):
    """Build (system_prompt, user_message) for the selected baseline.

    This is exactly what each baseline's extract() method does internally,
    but exposed here so we can capture the raw input/output.
    """
    if baseline_type == "B1":
        system_prompt = CONTRACTEVAL_PROMPT
        user_msg = f"Context:\n{sample.contract_text}\n\nQuestion:\n{sample.question}"
        return system_prompt, user_msg
    elif baseline_type == "B4":
        system_prompt = None
        user_msg = COT_PROMPT.format(
            contract_text=sample.contract_text,
            question=sample.question,
        )
        return system_prompt, user_msg
    else:
        raise ValueError(f"Unknown baseline type: {baseline_type}")


# Create parser instances (just for their parse methods, not for extraction)
_parsers = {
    "B1": ZeroShotBaseline(),
    "B4": ChainOfThoughtBaseline(),
}


def parse_response(raw_response, category, baseline_type):
    """Parse raw model response using the baseline-specific parser."""
    parser = _parsers[baseline_type]
    result = parser.parse_response(raw_response)
    result.category = category
    return result


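# ── Run ID and file setup ──
# Reuse the run ID when this cell is re-run in the same kernel with the same
# baseline and model. A fresh timestamp would point at a new, empty JSONL file,
# so the resume step below would never find the completed samples.
run_prefix = f"{baseline_label}_{MODEL_KEY}_"
prev_run_id = globals().get("run_id")
if isinstance(prev_run_id, str) and prev_run_id.startswith(run_prefix):
    run_id = prev_run_id
else:
    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    run_id = f"{run_prefix}{timestamp}"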

output_dir = Path("../experiments/results")
output_dir.mkdir(parents=True, exist_ok=True)
intermediate_path = output_dir / f"{run_id}_intermediate.jsonl"

print(f"Run ID:       {run_id}")
print(f"Intermediate: {intermediate_path}")

# ── Resume: load existing completed samples ──
results = []
completed_ids = set()
if intermediate_path.exists():
    with open(intermediate_path) as f:
        for line in f:
            if line.strip():
                rec = json.loads(line)
                completed_ids.add(rec["sample_id"])
                results.append(rec)
    print(f"Resuming:     {len(completed_ids)} samples already completed")
print()

# ── Diagnostics tracker ──
diagnostics = ModelDiagnostics(experiment_id=run_id)

# ── Extraction loop ──
total = len(selected)
start_time = time.time()

for i, sample in enumerate(selected):
    if sample.id in completed_ids:
        print(f"[{i+1}/{total}] {sample.category} — SKIPPED (already done)")
        continue

    print(f"[{i+1}/{total}] {sample.category} ({sample.tier})...", end=" ", flush=True)

    try:
        t0 = time.time()

        # 1–3. Build prompt → call model → parse response
        system_prompt, user_message = build_messages(sample, BASELINE_TYPE)
        messages = [{"role": "user", "content": user_message}]

        raw_response, usage = await model_invoke(
            model_key=MODEL_KEY,
            messages=messages,
            system=system_prompt,
            temperature=TEMPERATURE,
            max_tokens=MAX_TOKENS,
            diagnostics=diagnostics,
            agent_name=baseline_label,
            category=sample.category,
        )

        result = parse_response(raw_response, sample.category, BASELINE_TYPE)
        elapsed = time.time() - t0

        # 4. Evaluate
        predicted_text = " ".join(result.extracted_clauses)
        has_prediction = len(result.extracted_clauses) > 0

        if sample.has_clause:
            if has_prediction:
                covers = any(
                    span_overlap(predicted_text, gt)
                    for gt in sample.ground_truth_spans
                )
                classification = "TP" if covers else "FN"
            else:
                classification = "FN"
        else:
            classification = "FP" if has_prediction else "TN"

        jacc = (
            compute_jaccard(predicted_text, sample.ground_truth)
            if sample.has_clause and has_prediction
            else (1.0 if not sample.has_clause and not has_prediction else 0.0)
        )
        grounding = (
            compute_grounding_rate(result.extracted_clauses, sample.contract_text)
            if has_prediction else 1.0
        )

        # 5. Save: build the full traceable record
        record = {
            "sample_id": sample.id,
            "run_id": run_id,
            "timestamp": datetime.datetime.now().isoformat(),
            "model_key": MODEL_KEY,
            "model_id": config.model_id,
            "baseline_type": BASELINE_TYPE,
            "baseline_label": baseline_label,
            "category": sample.category,
            "tier": sample.tier,
            "contract_title": sample.contract_title,
            "contract_chars": len(sample.contract_text),
            "input": {
                "system_prompt": system_prompt,
                "user_message_length": len(user_message),
                "question": sample.question,
            },
            "output": {
                "raw_response": raw_response,
                "parsed_clauses": result.extracted_clauses,
                "num_clauses": len(result.extracted_clauses),
                "reasoning": result.reasoning,
                "confidence": result.confidence,
            },
            "ground_truth": {
                "has_clause": sample.has_clause,
                "spans": sample.ground_truth_spans,
                "full_text": sample.ground_truth,
                "num_spans": sample.num_spans,
            },
            "evaluation": {
                "classification": classification,
                "jaccard": jacc,
                "grounding_rate": grounding,
            },
            "usage": {
                "input_tokens": usage.input_tokens,
                "output_tokens": usage.output_tokens,
                "cache_read_tokens": getattr(usage, "cache_read_tokens", 0),
                "cache_creation_tokens": getattr(usage, "cache_creation_tokens", 0),
                "latency_s": round(elapsed, 2),
            },
        }

        # ...then append it to JSONL immediately (crash-safe)
        with open(intermediate_path, "a") as f:
            f.write(json.dumps(record, default=str) + "\n")

        results.append(record)
        print(f"-> {classification} | {len(result.extracted_clauses)} clause(s) | J={jacc:.3f} | {elapsed:.1f}s")

    except Exception as e:
        print(f"-> ERROR: {e}")
        import traceback
        traceback.print_exc()

total_time = time.time() - start_time
print(f"\nCompleted: {len(results)} total ({len(completed_ids)} resumed)")
print(f"Intermediate saved to: {intermediate_path}")
print(f"Total wall time: {total_time:.1f}s")

Run ID:       cot_gpt-4.1-mini_20260216_012937
Intermediate: ../experiments/results/cot_gpt-4.1-mini_20260216_012937_intermediate.jsonl

[1/45] Agreement Date (common)... -> FN | 12 clause(s) | J=0.022 | 7.8s
[2/45] Expiration Date (common)... -> TP | 7 clause(s) | J=0.102 | 4.6s
[3/45] Parties (common)... -> TP | 26 clause(s) | J=0.005 | 31.8s
[4/45] Document Name (common)... -> TP | 13 clause(s) | J=0.010 | 16.2s
[5/45] Parties (common)... -> TP | 20 clause(s) | J=0.010 | 18.9s
[6/45] Parties (common)... -> TP | 20 clause(s) | J=0.004 | 11.0s
[7/45] Governing Law (common)... -> TP | 10 clause(s) | J=0.147 | 6.5s
[8/45] Agreement Date (common)... -> TP | 9 clause(s) | J=0.016 | 8.5s
[9/45] Effective Date (common)... -> TP | 15 clause(s) | J=0.183 | 4.6s
[10/45] Document Name (common)... -> TP | 25 clause(s) | J=0.002 | 18.0s
[11/45] Governing Law (common)... -> FP | 5 clause(s) | J=0.000 | 2.3s
[12/45] Effective Date (common)... -> FP | 7 clause(s) | J=0.000 | 5.8s
[13/45] Expiration 

## 3. Evaluation Metrics

Metrics follow ContractEval definitions:
- **TP**: Label not empty AND prediction covers at least one ground-truth span (checked with `span_overlap` against each span)
- **TN**: Label empty AND model predicts nothing / "no related clause"
- **FP**: Label empty BUT model predicts non-empty clause
- **FN**: Label not empty BUT model misses (no prediction or doesn't cover span)
- **Laziness**: FN where model produced 0 clauses (said "no related clause" when one exists)
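The classification rules above can be written as a small pure function. This is a sketch, not the notebook's evaluator: `classify`, `is_lazy`, and `substring_covers` are hypothetical names, and `substring_covers` (case-insensitive containment) is an assumed stand-in for `span_overlap` from `src.evaluation.metrics`, whose matching criterion may be looser. As in the extraction loop, all predicted clauses are joined before checking coverage.

```python
from typing import Callable, Sequence


def substring_covers(predicted_text: str, gt_span: str) -> bool:
    """Hypothetical stand-in for span_overlap: case-insensitive containment of the GT span."""
    return gt_span.strip().lower() in predicted_text.lower()


def classify(
    has_clause: bool,
    predicted_clauses: Sequence[str],
    gt_spans: Sequence[str],
    covers: Callable[[str, str], bool] = substring_covers,
) -> str:
    """Return TP/FP/FN/TN for one sample under the definitions above."""
    has_prediction = len(predicted_clauses) > 0
    if not has_clause:
        return "FP" if has_prediction else "TN"
    if not has_prediction:
        return "FN"
    predicted_text = " ".join(predicted_clauses)
    return "TP" if any(covers(predicted_text, gt) for gt in gt_spans) else "FN"


def is_lazy(outcome: str, num_predicted: int) -> bool:
    """Laziness: a miss (FN) where the model returned no clause at all."""
    return outcome == "FN" and num_predicted == 0


gt = ["the laws of the State of Delaware"]
print(classify(True, ["This Agreement is governed by the laws of the State of Delaware."], gt))  # TP
print(classify(True, ["Either party may terminate on 30 days notice."], gt))  # FN
print(classify(True, [], gt), is_lazy("FN", 0))  # FN True
print(classify(False, ["Step 1: Identify key concepts"], []))  # FP
print(classify(False, [], []))  # TN
```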

In [24]:
from src.evaluation.metrics import compute_f1, compute_f2, compute_precision, compute_recall

tp = sum(1 for r in results if r["evaluation"]["classification"] == "TP")
fp = sum(1 for r in results if r["evaluation"]["classification"] == "FP")
fn = sum(1 for r in results if r["evaluation"]["classification"] == "FN")
tn = sum(1 for r in results if r["evaluation"]["classification"] == "TN")

total_positive = tp + fn
laziness_count = sum(
    1 for r in results
    if r["evaluation"]["classification"] == "FN"
    and r["output"]["num_clauses"] == 0
)

precision = compute_precision(tp, fp)
recall = compute_recall(tp, fn)
f1 = compute_f1(tp, fp, fn)
f2 = compute_f2(tp, fp, fn)

jaccard_scores = [r["evaluation"]["jaccard"] for r in results if r["ground_truth"]["has_clause"]]
avg_jaccard = sum(jaccard_scores) / len(jaccard_scores) if jaccard_scores else 0
laziness_rate = laziness_count / total_positive if total_positive > 0 else 0

print(f"{'='*60}")
print(f"  {BASELINE_TYPE} {baseline_label} — {MODEL_KEY}")
print(f"{'='*60}")
print(f"  Samples:       {len(results)}")
print(f"  TP: {tp}  FP: {fp}  FN: {fn}  TN: {tn}")
print()
print(f"  Precision:     {precision:.3f}")
print(f"  Recall:        {recall:.3f}")
print(f"  F1:            {f1:.3f}")
print(f"  F2:            {f2:.3f}")
print(f"  Avg Jaccard:   {avg_jaccard:.3f}")
print(f"  Laziness rate: {laziness_rate:.1%} ({laziness_count}/{total_positive})")
print()
print(f"  ContractEval reference (GPT-4.1):")
print(f"  F1=0.641  F2=0.678  Jaccard=0.472  Laziness=7.1%")

# Per-tier breakdown
print(f"\n{'='*70}")
print(f"  Per-Tier Breakdown")
print(f"{'='*70}")
print(f"  {'Tier':<10} {'TP':>4} {'FP':>4} {'FN':>4} {'TN':>4} {'F1':>7} {'F2':>7} {'Jaccard':>8}")
print(f"  {'-'*60}")

for tier in ["common", "moderate", "rare"]:
    tr = [r for r in results if r["tier"] == tier]
    t_tp = sum(1 for r in tr if r["evaluation"]["classification"] == "TP")
    t_fp = sum(1 for r in tr if r["evaluation"]["classification"] == "FP")
    t_fn = sum(1 for r in tr if r["evaluation"]["classification"] == "FN")
    t_tn = sum(1 for r in tr if r["evaluation"]["classification"] == "TN")
    t_f1 = compute_f1(t_tp, t_fp, t_fn)
    t_f2 = compute_f2(t_tp, t_fp, t_fn)
    t_jaccs = [r["evaluation"]["jaccard"] for r in tr if r["ground_truth"]["has_clause"]]
    t_jacc = sum(t_jaccs) / len(t_jaccs) if t_jaccs else 0
    print(f"  {tier:<10} {t_tp:>4} {t_fp:>4} {t_fn:>4} {t_tn:>4} {t_f1:>7.3f} {t_f2:>7.3f} {t_jacc:>8.3f}")

  B4 cot — gpt-4.1-mini
  Samples:       45
  TP: 26  FP: 15  FN: 4  TN: 0

  Precision:     0.634
  Recall:        0.867
  F1:            0.732
  F2:            0.807
  Avg Jaccard:   0.123
  Laziness rate: 0.0% (0/30)

  ContractEval reference (GPT-4.1):
  F1=0.641  F2=0.678  Jaccard=0.472  Laziness=7.1%

  Per-Tier Breakdown
  Tier         TP   FP   FN   TN      F1      F2  Jaccard
  ------------------------------------------------------------
  common        9    5    1    0   0.750   0.833    0.050
  moderate     10    5    0    0   0.800   0.909    0.208
  rare          7    5    3    0   0.636   0.673    0.111
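The printed aggregates can be re-derived by hand. This sketch uses the standard precision, recall, and F-beta definitions; that `compute_f1` and `compute_f2` in `src.evaluation.metrics` implement exactly these (with β = 2 for F2) is an assumption, though it is consistent with every value in the output above.

```python
def precision(tp: int, fp: int) -> float:
    return tp / (tp + fp) if tp + fp else 0.0


def recall(tp: int, fn: int) -> float:
    return tp / (tp + fn) if tp + fn else 0.0


def f_beta(tp: int, fp: int, fn: int, beta: float) -> float:
    """F-beta = (1 + b^2) * TP / ((1 + b^2) * TP + b^2 * FN + FP); beta > 1 weights recall more."""
    b2 = beta ** 2
    denom = (1 + b2) * tp + b2 * fn + fp
    return (1 + b2) * tp / denom if denom else 0.0


# Overall counts from the run above: TP=26, FP=15, FN=4
print(f"P={precision(26, 15):.3f} R={recall(26, 4):.3f} "
      f"F1={f_beta(26, 15, 4, 1):.3f} F2={f_beta(26, 15, 4, 2):.3f}")
# P=0.634 R=0.867 F1=0.732 F2=0.807

# Per-tier (TP, FP, FN) from the breakdown above
per_tier = {"common": (9, 5, 1), "moderate": (10, 5, 0), "rare": (7, 5, 3)}
for tier, (tp, fp, fn) in per_tier.items():
    print(f"{tier:<10} F1={f_beta(tp, fp, fn, 1):.3f} F2={f_beta(tp, fp, fn, 2):.3f}")
# common     F1=0.750 F2=0.833
# moderate   F1=0.800 F2=0.909
# rare       F1=0.636 F2=0.673
```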


In [25]:
print(f"\n{'='*90}")
print(f"  Per-Sample Results")
print(f"{'='*90}")

for i, r in enumerate(results):
    cls = r["evaluation"]["classification"]
    ok = cls in ("TP", "TN")

    print(f"\n  [{i+1}] {'PASS' if ok else 'FAIL'} {cls} | {r['category']} ({r['tier']})")
    print(f"      Contract: {r['contract_title'][:60]}")
    print(f"      Question: {r['input']['question'][:80]}...")

    if r["ground_truth"]["has_clause"]:
        gt = r["ground_truth"]["full_text"][:120]
        print(f"      GT:   {gt}...")

    if r["output"]["num_clauses"] > 0:
        pred = r["output"]["parsed_clauses"][0][:120].replace("\n", " ")
        print(f"      Pred: {pred}...")
    else:
        print(f"      Pred: (no clause extracted)")

    print(f"      Jaccard: {r['evaluation']['jaccard']:.3f} | "
          f"Grounding: {r['evaluation']['grounding_rate']:.1%} | "
          f"Tokens: {r['usage']['input_tokens']:,} in / {r['usage']['output_tokens']:,} out | "
          f"Time: {r['usage']['latency_s']:.1f}s")

    # Raw response preview
    raw = r["output"]["raw_response"][:200].replace("\n", " ")
    print(f"      Raw:  {raw}...")


  Per-Sample Results

  [1] FAIL FN | Agreement Date (common)
      Contract: TodosMedicalLtd_20190328_20-F_EX-4.10_11587157_EX-4.10_Marke
      Question: Highlight the parts (if any) of this contract related to "Agreement Date" that s...
      GT:   20t h day of December 2018...
      Pred: Step 1: Identify key concepts in the Question  
- "Agreement Date"  
- "date of the contract"  
- Any references to the ...
      Jaccard: 0.022 | Grounding: 0.0% | Tokens: 6,865 in / 676 out | Time: 7.8s
      Raw:  Step 1: Identify key concepts in the Question   - "Agreement Date"   - "date of the contract"   - Any references to the effective date or signing date of the Agreement  Step 2: Scan the Context for se...

  [2] PASS TP | Expiration Date (common)
      Contract: Freecook_20180605_S-1_EX-10.3_11233807_EX-10.3_Hosting Agree
      Question: Highlight the parts (if any) of this contract related to "Expiration Date" that ...
      GT:   Terms of the project: 12 weeks from February 8, 2018 t

## 4. Model Diagnostics & Cost

In [26]:
diag_summary = diagnostics.summary()

print(f"Model Diagnostics ({MODEL_KEY})")
print("=" * 50)
print(f"API calls:       {diag_summary['total_calls']}")
print(f"Success rate:    {diag_summary['success_rate']:.0%}")
print(f"Input tokens:    {diag_summary['total_input_tokens']:,}")
print(f"Output tokens:   {diag_summary['total_output_tokens']:,}")
print(f"Total tokens:    {diag_summary['total_tokens']:,}")
print(f"Estimated cost:  ${diag_summary['total_cost_usd']:.4f}")
print(f"Avg latency:     {diag_summary['avg_latency_ms']:.0f} ms")
print(f"Total time:      {diag_summary['duration_seconds']:.1f} s")

if diag_summary["total_calls"] > 0:
    avg_in = diag_summary["total_input_tokens"] / diag_summary["total_calls"]
    avg_out = diag_summary["total_output_tokens"] / diag_summary["total_calls"]
    print(f"\nAvg tokens/call: {avg_in:,.0f} in / {avg_out:,.0f} out")

Model Diagnostics (gpt-4.1-mini)
API calls:       45
Success rate:    100%
Input tokens:    336,410
Output tokens:   44,295
Total tokens:    380,705
Estimated cost:  $0.2054
Avg latency:     11285 ms
Total time:      508.0 s

Avg tokens/call: 7,476 in / 984 out
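The estimated cost is linear in token counts. Assuming list rates of USD 0.40 per million input tokens and USD 1.60 per million output tokens for gpt-4.1-mini (these rates reproduce the 0.2054 USD reported above, but they are not read from `src/models/config.py`), the arithmetic is a short sketch; `estimate_cost` is a hypothetical helper, and it treats all input tokens as uncached.

```python
# Assumed gpt-4.1-mini rates in USD per 1M tokens; not taken from src/models/config.py.
INPUT_USD_PER_M = 0.40
OUTPUT_USD_PER_M = 1.60


def estimate_cost(input_tokens: int, output_tokens: int,
                  in_rate: float = INPUT_USD_PER_M,
                  out_rate: float = OUTPUT_USD_PER_M) -> float:
    """Linear token pricing; every input token is billed at the uncached rate (an assumption)."""
    return input_tokens / 1e6 * in_rate + output_tokens / 1e6 * out_rate


cost = estimate_cost(336_410, 44_295)   # token totals from the diagnostics above
print(f"Total:      ${cost:.4f}")       # $0.2054
print(f"Per sample: ${cost / 45:.4f}")  # $0.0046 over 45 API calls
```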


## 5. Save Summary

Each run produces three files (this cell writes the last two):
1. **Intermediate JSONL** — one full record per sample (already saved during extraction)
2. **Summary JSON** — config, prompt, aggregate metrics, per-tier, compact per-sample view
3. **Diagnostics JSON** — raw API call log from ModelDiagnostics

In [27]:
summary = {
    "run_id": run_id,
    "timestamp": datetime.datetime.now().isoformat(),
    "config": {
        "model_key": MODEL_KEY,
        "model_id": config.model_id,
        "provider": config.provider.value,
        "baseline_type": BASELINE_TYPE,
        "baseline_label": baseline_label,
        "samples_per_tier": SAMPLES_PER_TIER,
        "temperature": TEMPERATURE,
        "max_tokens": MAX_TOKENS,
        "max_contract_chars": MAX_CONTRACT_CHARS,
        "include_negative": INCLUDE_NEGATIVE_SAMPLES,
    },
    "prompt": {
        "system_prompt": results[0]["input"]["system_prompt"] if results else None,
        "template_name": baseline_label,
    },
    "metrics": {
        "precision": precision,
        "recall": recall,
        "f1": f1,
        "f2": f2,
        "avg_jaccard": avg_jaccard,
        "laziness_rate": laziness_rate,
        "tp": tp, "fp": fp, "fn": fn, "tn": tn,
    },
    "per_tier": {},
    "samples": [],
    "diagnostics": diag_summary,
    "intermediate_file": str(intermediate_path),
}

# Per-tier metrics
for tier in ["common", "moderate", "rare"]:
    tr = [r for r in results if r["tier"] == tier]
    t_tp = sum(1 for r in tr if r["evaluation"]["classification"] == "TP")
    t_fp = sum(1 for r in tr if r["evaluation"]["classification"] == "FP")
    t_fn = sum(1 for r in tr if r["evaluation"]["classification"] == "FN")
    t_tn = sum(1 for r in tr if r["evaluation"]["classification"] == "TN")
    t_jaccs = [r["evaluation"]["jaccard"] for r in tr if r["ground_truth"]["has_clause"]]
    summary["per_tier"][tier] = {
        "tp": t_tp, "fp": t_fp, "fn": t_fn, "tn": t_tn,
        "f1": compute_f1(t_tp, t_fp, t_fn),
        "f2": compute_f2(t_tp, t_fp, t_fn),
        "avg_jaccard": sum(t_jaccs) / len(t_jaccs) if t_jaccs else 0,
    }

# Compact per-sample view (full data is in intermediate JSONL)
for r in results:
    summary["samples"].append({
        "id": r["sample_id"],
        "category": r["category"],
        "tier": r["tier"],
        "classification": r["evaluation"]["classification"],
        "jaccard": r["evaluation"]["jaccard"],
        "grounding_rate": r["evaluation"]["grounding_rate"],
        "num_clauses_predicted": r["output"]["num_clauses"],
        "num_gt_spans": r["ground_truth"]["num_spans"],
        "input_tokens": r["usage"]["input_tokens"],
        "output_tokens": r["usage"]["output_tokens"],
        "latency_s": r["usage"]["latency_s"],
    })

# Save summary
summary_path = output_dir / f"{run_id}_summary.json"
with open(summary_path, "w") as f:
    json.dump(summary, f, indent=2, default=str)
print(f"Summary saved:      {summary_path}")

# Save diagnostics
diag_dir = Path("../experiments/diagnostics")
diag_dir.mkdir(parents=True, exist_ok=True)
diag_path = diag_dir / f"{run_id}_diagnostics.json"
diagnostics.export(diag_path)
print(f"Diagnostics saved:  {diag_path}")

# Remind about intermediate
print(f"Intermediate saved: {intermediate_path}")
print(f"\nTo inspect a single record:")
print(f"  head -1 {intermediate_path} | python -m json.tool")

Summary saved:      ../experiments/results/cot_gpt-4.1-mini_20260216_012937_summary.json
Diagnostics saved:  ../experiments/diagnostics/cot_gpt-4.1-mini_20260216_012937_diagnostics.json
Intermediate saved: ../experiments/results/cot_gpt-4.1-mini_20260216_012937_intermediate.jsonl

To inspect a single record:
  head -1 ../experiments/results/cot_gpt-4.1-mini_20260216_012937_intermediate.jsonl | python -m json.tool


## Next Steps

**Switch model** — change `MODEL_KEY` in the config cell:
```python
# Local models (Ollama)
MODEL_KEY = "qwen3-4b"           # Qwen3 4B
MODEL_KEY = "qwen3-8b"           # Qwen3 8B
MODEL_KEY = "llama-3.1-8b"       # LLaMA 3.1 8B

# Proprietary (need API keys)
MODEL_KEY = "claude-sonnet-4"    # Claude Sonnet 4
MODEL_KEY = "gpt-4.1"            # GPT 4.1
MODEL_KEY = "gpt-4.1-mini"       # GPT 4.1 Mini
MODEL_KEY = "gemini-2.5-pro"     # Gemini 2.5 Pro
```

**Switch baseline** — change `BASELINE_TYPE`:
```python
BASELINE_TYPE = "B1"  # Zero-shot (ContractEval replication)
BASELINE_TYPE = "B4"  # Chain-of-Thought
```

For multi-agent configurations (M1–M6), see `04_multiagent_experiment.ipynb`.

**Scale up** — increase `SAMPLES_PER_TIER` or run full test set via `scripts/run_experiment.py`.

**Output files:**
- `experiments/results/{run_id}_intermediate.jsonl` — full per-sample records
- `experiments/results/{run_id}_summary.json` — config + metrics + compact results
- `experiments/diagnostics/{run_id}_diagnostics.json` — raw API call log
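Because each intermediate record is self-contained, results can be re-aggregated outside the notebook. A minimal sketch, standard library only, with a hypothetical helper `tier_counts` that rebuilds the per-tier confusion counts from the `tier` and `evaluation.classification` fields written by the extraction loop:

```python
import json
import tempfile
from collections import Counter, defaultdict
from pathlib import Path


def tier_counts(jsonl_path: Path) -> dict[str, Counter]:
    """Count TP/FP/FN/TN per tier from an intermediate JSONL file (one record per line)."""
    counts: defaultdict[str, Counter] = defaultdict(Counter)
    with open(jsonl_path, encoding="utf-8") as f:
        for line in f:
            if line.strip():
                rec = json.loads(line)
                counts[rec["tier"]][rec["evaluation"]["classification"]] += 1
    return dict(counts)


# Usage on a tiny synthetic file; for a real run, pass
# Path("../experiments/results") / f"{run_id}_intermediate.jsonl" instead.
demo_records = [
    {"sample_id": "s1", "tier": "rare", "evaluation": {"classification": "TP"}},
    {"sample_id": "s2", "tier": "rare", "evaluation": {"classification": "FP"}},
    {"sample_id": "s3", "tier": "common", "evaluation": {"classification": "TP"}},
]
with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "demo_intermediate.jsonl"
    path.write_text("".join(json.dumps(r) + "\n" for r in demo_records), encoding="utf-8")
    counts = tier_counts(path)

for tier, c in counts.items():
    print(tier, dict(c))
# rare {'TP': 1, 'FP': 1}
# common {'TP': 1}
```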