# Multi-Agent Experiment (M1–M6)

This notebook runs **multi-agent and ablation configurations** on the same stratified
CUAD sample used in `03_baseline_calibration.ipynb`, then performs statistical
comparison against baselines (B1/B4).

**Configurations:**
- `M1`: Full multi-agent system (orchestrator + 3 specialists + validation via LangGraph)
- `M2`–`M5`: Reserved for ablation studies (not yet implemented)
- `M6`: Combined specialist prompts in a single agent (critical ablation: architecture vs prompting)

**Key hypotheses tested here:**

| ID | Hypothesis | Test |
|----|-----------|------|
| H1 | Multi-agent beats single-agent baselines | F2(M1) > F2(B1), McNemar p < 0.05 |
| H2 | Specialists help rare categories most | ΔF2_rare > ΔF2_common |
| H3 | Architecture matters, not just prompts | M1 > M6 significantly |
| H4 | Multi-agent produces auditable reasoning | Trace completeness > 90% |

**Pipeline:** Same as notebook 03 (crash-safe JSONL, resume support, full traceability).

In [None]:
# ============================================================
# CONFIGURATION
# ============================================================
# Every knob for this run lives in this cell; later cells only read these names.

MODEL_KEY = "claude-sonnet-4"       # Model key (see src/models/config.py)
EXPERIMENT_TYPE = "M1"              # M1=full-multiagent, M6=combined-prompts, M2-M5=reserved
SAMPLES_PER_TIER = 5               # Must match baseline runs for fair comparison
# When True, the sampling cell also draws clause-absent samples per tier
# (up to max(1, SAMPLES_PER_TIER // 2) of them).
INCLUDE_NEGATIVE_SAMPLES = True
# Samples whose contract_text is longer than this are dropped before sampling.
MAX_CONTRACT_CHARS = 100_000
# Passed to model_invoke on the M6 path; the M1 path shown here does not pass
# them explicitly — NOTE(review): confirm how the orchestrator picks these up.
TEMPERATURE = 0.0
MAX_TOKENS = 4096

# Path to baseline results for statistical comparison
BASELINE_RESULTS_DIR = "../experiments/results"

In [None]:
import sys, os, time, json, datetime
from collections import defaultdict
from pathlib import Path

sys.path.insert(0, "..")

from dotenv import load_dotenv
load_dotenv("../.env")

from src.models.config import get_model_config, ModelProvider

# Resolve the model configuration for the key chosen in the CONFIGURATION cell.
config = get_model_config(MODEL_KEY)

# Experiment code -> label used as the prefix of run IDs and result file names.
experiment_labels = {
    "M1": "multiagent",
    "M2": "ablation_no_validation",
    "M3": "ablation_single_specialist",
    "M4": "ablation_no_routing",
    "M5": "ablation_no_specialist_prompts",
    "M6": "combined_prompts",
}
assert EXPERIMENT_TYPE in experiment_labels, (
    f"Unknown EXPERIMENT_TYPE={EXPERIMENT_TYPE!r}. "
    f"Valid options: {list(experiment_labels)}"
)
experiment_label = experiment_labels[EXPERIMENT_TYPE]

print(f"Model:      {config.name} ({config.model_id})")
print(f"Provider:   {config.provider.value}")
print(f"Experiment: {EXPERIMENT_TYPE} ({experiment_label})")
print(f"Context:    {config.context_window:,} tokens")

# Verify provider connectivity
if config.provider == ModelProvider.OLLAMA:
    import urllib.request
    try:
        # Bound the probe with a timeout so an unresponsive server cannot hang
        # the cell, and close the response instead of leaking the connection.
        with urllib.request.urlopen(
            f"{config.base_url or 'http://localhost:11434/v1'}/models",
            timeout=5,
        ):
            pass
        print("Ollama:     connected")
    except Exception as e:
        # Deliberately a warning only: the run may still be wanted for inspection.
        print(f"WARNING:    Ollama not reachable — {e}")
elif config.provider == ModelProvider.ANTHROPIC:
    assert os.getenv("ANTHROPIC_API_KEY"), "ANTHROPIC_API_KEY not set"
    print("API key:    set")
elif config.provider == ModelProvider.OPENAI:
    assert os.getenv("OPENAI_API_KEY"), "OPENAI_API_KEY not set"
    print("API key:    set")
elif config.provider == ModelProvider.GOOGLE:
    # Either variable name is accepted for Google models.
    assert os.getenv("GEMINI_API_KEY") or os.getenv("GOOGLE_API_KEY"), "GEMINI_API_KEY not set"
    print("API key:    set")

## 1. Load and Sample CUAD Data

Same stratified sampling as notebook 03 (`random.seed(42)`, same tier counts)
to ensure a fair comparison against the baselines.

In [None]:
from src.data.cuad_loader import CUADDataLoader, CATEGORY_TIERS
import random

# Seed the global RNG first: the sampling cell below draws from it.
random.seed(42)

loader = CUADDataLoader()
loader.load()
all_samples = list(loader)

print(f"Total samples: {len(all_samples):,}")
print(f"Contracts:     {len(loader.get_contracts())}")
print()

# Bucket samples by tier, skipping contracts longer than MAX_CONTRACT_CHARS.
by_tier: dict[str, list] = defaultdict(list)
for sample in all_samples:
    if len(sample.contract_text) > MAX_CONTRACT_CHARS:
        continue
    by_tier[sample.tier].append(sample)

# Report how many clause-present / clause-absent samples each tier holds.
for tier in ("common", "moderate", "rare"):
    bucket = by_tier[tier]
    pos = len([s for s in bucket if s.has_clause])
    neg = len(bucket) - pos
    print(f"{tier:10s}: {len(bucket):,} samples ({pos} pos, {neg} neg)")

In [None]:
# Stratified draw: per tier, up to SAMPLES_PER_TIER clause-present samples and,
# optionally, about half as many clause-absent ones (at least one).
# The draws use the global RNG seeded in the previous cell, so re-running this
# cell on its own yields a different selection.
selected = []
for tier in ("common", "moderate", "rare"):
    positive, negative = [], []
    for s in by_tier[tier]:
        (positive if s.has_clause else negative).append(s)

    selected += random.sample(positive, min(SAMPLES_PER_TIER, len(positive)))

    if INCLUDE_NEGATIVE_SAMPLES and negative:
        neg_quota = max(1, SAMPLES_PER_TIER // 2)
        selected += random.sample(negative, min(neg_quota, len(negative)))

print(f"Selected {len(selected)} samples:\n")
for s in selected:
    if s.has_clause:
        info = f"{s.num_spans} spans"
    else:
        info = "no clause"
    print(f"  [{s.tier:8s}] {s.category:40s} ({info}) | {len(s.contract_text):,} chars")

## 2. Run Extraction

Extraction path depends on `EXPERIMENT_TYPE`:

- **M1**: Orchestrator routes to a specialist agent (risk/temporal/IP) via LangGraph,
  then a validation layer checks grounding. Note: the setup cell below currently builds
  the orchestrator with `validation_agent=None`. Full trace captured for H4.
- **M6**: Single agent with combined specialist prompts (all domain knowledge in one prompt).
  Tests whether multi-agent architecture provides benefit beyond prompt engineering.
- **M2–M5**: Reserved for ablation studies (raise error if selected).

Same crash-safe JSONL + resume logic as notebook 03.

In [None]:
from src.models import invoke_model as model_invoke
from src.models.diagnostics import ModelDiagnostics, TokenUsage
from src.evaluation.metrics import span_overlap, compute_jaccard, compute_grounding_rate

# ── Agent setup ──
# Exactly one configuration is active per run. These handles are filled in for
# the active one and stay None otherwise; later cells branch on them.
_orchestrator = None
_m1_diagnostics = None
_m6_baseline = None

if EXPERIMENT_TYPE == "M1":
    # ── M1: orchestrator plus three specialists sharing one diagnostics tracker ──
    from src.agents.base import AgentConfig
    from src.agents import (
        Orchestrator,
        RiskLiabilityAgent,
        TemporalRenewalAgent,
        IPCommercialAgent,
    )

    _m1_diagnostics = ModelDiagnostics()
    # Each specialist's agent name doubles as its prompt name.
    _risk_config, _temporal_config, _ip_config = (
        AgentConfig(name=key, model_key=MODEL_KEY, prompt_name=key)
        for key in ("risk_liability", "temporal_renewal", "ip_commercial")
    )
    _specialists = {
        "risk_liability": RiskLiabilityAgent(config=_risk_config, diagnostics=_m1_diagnostics),
        "temporal_renewal": TemporalRenewalAgent(config=_temporal_config, diagnostics=_m1_diagnostics),
        "ip_commercial": IPCommercialAgent(config=_ip_config, diagnostics=_m1_diagnostics),
    }
    _orchestrator = Orchestrator(
        specialists=_specialists,
        validation_agent=None,
        config=AgentConfig(name="orchestrator", model_key=MODEL_KEY),
    )
    print(f"M1 Orchestrator ready with {len(_specialists)} specialists (validation=None)")

elif EXPERIMENT_TYPE == "M6":
    # ── M6: single agent carrying the combined specialist prompts ──
    from src.agents.base import AgentConfig
    from src.baselines.combined_prompts import COMBINED_PROMPT, CombinedPromptsBaseline

    _m6_baseline = CombinedPromptsBaseline(
        config=AgentConfig(name="combined_prompts", model_key=MODEL_KEY),
        diagnostics=None,  # The diagnostics tracker is attached later in this cell
    )
    print("M6 Combined Prompts baseline ready")

elif EXPERIMENT_TYPE in ("M2", "M3", "M4", "M5"):
    # ── M2–M5: reserved labels with no implementation yet ──
    raise NotImplementedError(
        f"{EXPERIMENT_TYPE} ablation is not yet implemented. "
        f"Currently available: M1 (full multi-agent), M6 (combined prompts). "
        f"See CLAUDE.md for planned ablation definitions."
    )

# ── Run ID and file setup ──
# A fresh run_id embeds the current timestamp, so its intermediate file can never
# exist yet and nothing would ever be resumed. To continue an interrupted run,
# define RESUME_RUN_ID (e.g. in the CONFIGURATION cell) with the run ID that the
# interrupted run printed; without it a new run is started, as before.
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
_resume_run_id = globals().get("RESUME_RUN_ID")
run_id = _resume_run_id or f"{experiment_label}_{MODEL_KEY}_{timestamp}"

output_dir = Path("../experiments/results")
output_dir.mkdir(parents=True, exist_ok=True)
intermediate_path = output_dir / f"{run_id}_intermediate.jsonl"

print(f"Run ID:       {run_id}")
print(f"Intermediate: {intermediate_path}")

# ── Resume: load existing completed samples ──
results = []
completed_ids = set()
if intermediate_path.exists():
    _ends_with_newline = True
    with open(intermediate_path, encoding="utf-8") as f:
        for _line_no, line in enumerate(f, start=1):
            _ends_with_newline = line.endswith("\n")
            if not line.strip():
                continue
            try:
                rec = json.loads(line)
            except json.JSONDecodeError:
                # A crash in the middle of a write leaves a truncated line; skip it
                # so the affected sample is simply extracted again.
                print(f"WARNING:      skipping unparseable line {_line_no} in {intermediate_path.name}")
                continue
            completed_ids.add(rec["sample_id"])
            results.append(rec)
    if not _ends_with_newline:
        # Terminate the dangling last line so the next appended record starts clean.
        with open(intermediate_path, "a", encoding="utf-8") as f:
            f.write("\n")
    print(f"Resuming:     {len(completed_ids)} samples already completed")
print()

# ── Diagnostics tracker ──
# M1 reuses the tracker already handed to its specialists; other configurations
# get a fresh tracker tagged with this run's ID.
diagnostics = _m1_diagnostics if EXPERIMENT_TYPE == "M1" else ModelDiagnostics(experiment_id=run_id)
if _m6_baseline is not None:
    _m6_baseline.diagnostics = diagnostics

# ── Extraction loop ──
total = len(selected)
start_time = time.time()

for i, sample in enumerate(selected):
    if sample.id in completed_ids:
        print(f"[{i+1}/{total}] {sample.category} — SKIPPED (already done)")
        continue

    print(f"[{i+1}/{total}] {sample.category} ({sample.tier})...", end=" ", flush=True)

    try:
        t0 = time.time()
        trace_nodes = []  # For H4 trace completeness

        if EXPERIMENT_TYPE == "M1":
            # ── M1: Use Orchestrator (LangGraph) ──
            n_calls_before = len(diagnostics.calls)
            result = await _orchestrator.extract(
                contract_text=sample.contract_text,
                category=sample.category,
                question=sample.question,
            )
            raw_response = result.reasoning
            system_prompt = "M1 multi-agent (orchestrator → specialist → validation)"
            user_message = f"Category: {sample.category}\nQuestion: {sample.question}"

            # Aggregate usage from all calls made during this extraction
            recent_calls = diagnostics.calls[n_calls_before:]
            agg_input = sum(c.usage.input_tokens for c in recent_calls)
            agg_output = sum(c.usage.output_tokens for c in recent_calls)
            usage = type("Usage", (), {
                "input_tokens": agg_input,
                "output_tokens": agg_output,
                "cache_read_tokens": 0,
                "cache_creation_tokens": 0,
            })()

            # Capture trace nodes for H4
            trace_nodes = [c.agent_name for c in recent_calls]

        elif EXPERIMENT_TYPE == "M6":
            # ── M6: Combined Prompts (single-agent ablation) ──
            from src.baselines.combined_prompts import COMBINED_PROMPT
            system_prompt = None
            user_message = COMBINED_PROMPT.format(
                category=sample.category,
                contract_text=sample.contract_text,
                question=sample.question,
            )
            messages = [{"role": "user", "content": user_message}]

            raw_response, usage = await model_invoke(
                model_key=MODEL_KEY,
                messages=messages,
                system=system_prompt,
                temperature=TEMPERATURE,
                max_tokens=MAX_TOKENS,
                diagnostics=diagnostics,
                agent_name="combined_prompts",
                category=sample.category,
            )

            # Parse M6 response (JSON or plaintext fallback)
            data = _m6_baseline.parse_json_response(raw_response)
            if data and "extracted_clauses" in data:
                result = _m6_baseline.result_from_dict(data, sample.category)
            else:
                result = _m6_baseline._parse_plaintext(raw_response, sample.category)

        elapsed = time.time() - t0

        # 4. Evaluate
        predicted_text = " ".join(result.extracted_clauses)
        has_prediction = len(result.extracted_clauses) > 0

        if sample.has_clause:
            if has_prediction:
                covers = any(
                    span_overlap(predicted_text, gt)
                    for gt in sample.ground_truth_spans
                )
                classification = "TP" if covers else "FN"
            else:
                classification = "FN"
        else:
            classification = "FP" if has_prediction else "TN"

        jacc = (
            compute_jaccard(predicted_text, sample.ground_truth)
            if sample.has_clause and has_prediction
            else (1.0 if not sample.has_clause and not has_prediction else 0.0)
        )
        grounding = (
            compute_grounding_rate(result.extracted_clauses, sample.contract_text)
            if has_prediction else 1.0
        )

        # 5. Build full traceable record
        record = {
            "sample_id": sample.id,
            "run_id": run_id,
            "timestamp": datetime.datetime.now().isoformat(),
            "model_key": MODEL_KEY,
            "model_id": config.model_id,
            "experiment_type": EXPERIMENT_TYPE,
            "experiment_label": experiment_label,
            "category": sample.category,
            "tier": sample.tier,
            "contract_title": sample.contract_title,
            "contract_chars": len(sample.contract_text),
            "input": {
                "system_prompt": system_prompt,
                "user_message_length": len(user_message),
                "question": sample.question,
            },
            "output": {
                "raw_response": raw_response,
                "parsed_clauses": result.extracted_clauses,
                "num_clauses": len(result.extracted_clauses),
                "reasoning": result.reasoning,
                "confidence": result.confidence,
            },
            "ground_truth": {
                "has_clause": sample.has_clause,
                "spans": sample.ground_truth_spans,
                "full_text": sample.ground_truth,
                "num_spans": sample.num_spans,
            },
            "evaluation": {
                "classification": classification,
                "jaccard": jacc,
                "grounding_rate": grounding,
            },
            "usage": {
                "input_tokens": getattr(usage, "input_tokens", 0),
                "output_tokens": getattr(usage, "output_tokens", 0),
                "cache_read_tokens": getattr(usage, "cache_read_tokens", 0),
                "cache_creation_tokens": getattr(usage, "cache_creation_tokens", 0),
                "latency_s": round(elapsed, 2),
            },
            "trace": {
                "nodes_visited": trace_nodes,
                "num_llm_calls": len(trace_nodes) if trace_nodes else 1,
            },
        }

        # 6. Append to JSONL immediately (crash-safe)
        with open(intermediate_path, "a") as f:
            f.write(json.dumps(record, default=str) + "\n")

        results.append(record)
        print(f"-> {classification} | {len(result.extracted_clauses)} clause(s) | J={jacc:.3f} | {elapsed:.1f}s")

    except Exception as e:
        print(f"-> ERROR: {e}")
        import traceback; traceback.print_exc()

total_time = time.time() - start_time
print(f"\nCompleted: {len(results)} total ({len(completed_ids)} resumed)")
print(f"Intermediate saved to: {intermediate_path}")
print(f"Total wall time: {total_time:.1f}s")

## 3. Evaluation Metrics

Same ContractEval definitions as notebook 03.

In [None]:
from src.evaluation.metrics import compute_f1, compute_f2, compute_precision, compute_recall

def _confusion_counts(records):
    """Return (TP, FP, FN, TN) counts taken from the records' evaluation labels."""
    labels = [rec["evaluation"]["classification"] for rec in records]
    return tuple(labels.count(name) for name in ("TP", "FP", "FN", "TN"))


# Overall confusion counts across every record of this run.
tp, fp, fn, tn = _confusion_counts(results)

total_positive = tp + fn
# "Laziness": a clause existed but the model extracted nothing at all.
laziness_count = len([
    rec for rec in results
    if rec["evaluation"]["classification"] == "FN" and rec["output"]["num_clauses"] == 0
])

precision = compute_precision(tp, fp)
recall = compute_recall(tp, fn)
f1 = compute_f1(tp, fp, fn)
f2 = compute_f2(tp, fp, fn)

# Jaccard is averaged over clause-present samples only.
jaccard_scores = [
    rec["evaluation"]["jaccard"] for rec in results if rec["ground_truth"]["has_clause"]
]
avg_jaccard = sum(jaccard_scores) / len(jaccard_scores) if jaccard_scores else 0
laziness_rate = laziness_count / total_positive if total_positive > 0 else 0

_rule = "=" * 60
print(_rule)
print(f"  {EXPERIMENT_TYPE} {experiment_label} — {MODEL_KEY}")
print(_rule)
print(f"  Samples:       {len(results)}")
print(f"  TP: {tp}  FP: {fp}  FN: {fn}  TN: {tn}")
print()
print(f"  Precision:     {precision:.3f}")
print(f"  Recall:        {recall:.3f}")
print(f"  F1:            {f1:.3f}")
print(f"  F2:            {f2:.3f}")
print(f"  Avg Jaccard:   {avg_jaccard:.3f}")
print(f"  Laziness rate: {laziness_rate:.1%} ({laziness_count}/{total_positive})")
print()
print("  ContractEval reference (GPT-4.1):")
print("  F1=0.641  F2=0.678  Jaccard=0.472  Laziness=7.1%")

# Per-tier breakdown
_wide_rule = "=" * 70
print(f"\n{_wide_rule}")
print("  Per-Tier Breakdown")
print(_wide_rule)
print(f"  {'Tier':<10} {'TP':>4} {'FP':>4} {'FN':>4} {'TN':>4} {'F1':>7} {'F2':>7} {'Jaccard':>8}")
print(f"  {'-'*60}")

for tier in ("common", "moderate", "rare"):
    tier_records = [rec for rec in results if rec["tier"] == tier]
    t_tp, t_fp, t_fn, t_tn = _confusion_counts(tier_records)
    t_f1 = compute_f1(t_tp, t_fp, t_fn)
    t_f2 = compute_f2(t_tp, t_fp, t_fn)
    t_jaccs = [
        rec["evaluation"]["jaccard"] for rec in tier_records if rec["ground_truth"]["has_clause"]
    ]
    t_jacc = sum(t_jaccs) / len(t_jaccs) if t_jaccs else 0
    print(f"  {tier:<10} {t_tp:>4} {t_fp:>4} {t_fn:>4} {t_tn:>4} {t_f1:>7.3f} {t_f2:>7.3f} {t_jacc:>8.3f}")

In [None]:
# Per-sample report: verdict, truncated ground truth / prediction, and cost.
_banner = "=" * 90
print(f"\n{_banner}")
print("  Per-Sample Results")
print(_banner)

for idx, rec in enumerate(results, start=1):
    verdict = rec["evaluation"]["classification"]
    status = "PASS" if verdict in ("TP", "TN") else "FAIL"

    print(f"\n  [{idx}] {status} {verdict} | {rec['category']} ({rec['tier']})")
    print(f"      Contract: {rec['contract_title'][:60]}")
    print(f"      Question: {rec['input']['question'][:80]}...")

    if rec["ground_truth"]["has_clause"]:
        truth_preview = rec["ground_truth"]["full_text"][:120]
        print(f"      GT:   {truth_preview}...")

    # Only the first extracted clause is previewed.
    if rec["output"]["num_clauses"] > 0:
        first_clause = rec["output"]["parsed_clauses"][0][:120]
        print(f"      Pred: {first_clause}...")
    else:
        print("      Pred: (no clause extracted)")

    print(f"      Jaccard: {rec['evaluation']['jaccard']:.3f} | "
          f"Grounding: {rec['evaluation']['grounding_rate']:.1%} | "
          f"Tokens: {rec['usage']['input_tokens']:,} in / {rec['usage']['output_tokens']:,} out | "
          f"Time: {rec['usage']['latency_s']:.1f}s")

    # Trace info (M1 only)
    visited = rec.get("trace", {}).get("nodes_visited")
    if visited:
        print(f"      Trace: {' -> '.join(visited)}")

## 4. Model Diagnostics & Cost

In [None]:
# Print the aggregate API usage gathered by the diagnostics tracker.
diag_summary = diagnostics.summary()

print(f"Model Diagnostics ({MODEL_KEY})")
print("=" * 50)

# (label, summary key, format spec, prefix, suffix) — one printed line per row.
_diag_rows = (
    ("API calls:", "total_calls", "", "", ""),
    ("Success rate:", "success_rate", ".0%", "", ""),
    ("Input tokens:", "total_input_tokens", ",", "", ""),
    ("Output tokens:", "total_output_tokens", ",", "", ""),
    ("Total tokens:", "total_tokens", ",", "", ""),
    ("Estimated cost:", "total_cost_usd", ".4f", "$", ""),
    ("Avg latency:", "avg_latency_ms", ".0f", "", " ms"),
    ("Total time:", "duration_seconds", ".1f", "", " s"),
)
for _label, _key, _spec, _prefix, _suffix in _diag_rows:
    # Labels are padded to 17 columns so the values line up.
    print(f"{_label:<17}{_prefix}{format(diag_summary[_key], _spec)}{_suffix}")

_n_calls = diag_summary["total_calls"]
if _n_calls > 0:
    avg_in = diag_summary["total_input_tokens"] / _n_calls
    avg_out = diag_summary["total_output_tokens"] / _n_calls
    print(f"\nAvg tokens/call: {avg_in:,.0f} in / {avg_out:,.0f} out")

# Per-agent breakdown (useful for M1 to see specialist distribution)
_by_agent = diag_summary.get("by_agent")
if _by_agent:
    print("\nCalls by agent:")
    for agent in sorted(_by_agent):
        print(f"  {agent:25s}: {_by_agent[agent]}")

## 5. Statistical Comparison Against Baselines

Load baseline results (B1/B4 from notebook 03) and run hypothesis tests:

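- **H1**: M-variant vs B1 (and B4) — McNemar on paired per-sample correctness (TP/TN = correct), Cohen's d on Jaccard, plus a Wilcoxon test on Jaccard for B1; the F2 values themselves are compared in the summary table (section 6)
- **H2**: Per-tier comparison vs B1 → per-sample accuracy difference for the common, moderate, and rare tiers (per-tier F2 is reported in section 3)
- **H3**: McNemar → M1 vs M6 (architecture vs prompts) — requires both M1 and M6 runs
- **H4**: Trace completeness from M1 records (share of records with a non-empty `trace.nodes_visited`)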

In [None]:
# ── Load baseline results ──
# Directory that load_latest_run() searches for saved `*_summary.json` files.
baseline_dir = Path(BASELINE_RESULTS_DIR)

def load_latest_run(label_prefix: str) -> tuple[dict | None, list[dict]]:
    """Find the most recent summary + intermediate files for a given baseline label."""
    summaries = sorted(baseline_dir.glob(f"{label_prefix}_*_summary.json"), reverse=True)
    if not summaries:
        return None, []
    summary_path = summaries[0]
    with open(summary_path) as f:
        summary = json.load(f)

    # Load per-sample intermediate
    inter_path = summary_path.with_name(summary_path.name.replace("_summary.json", "_intermediate.jsonl"))
    records = []
    if inter_path.exists():
        with open(inter_path) as f:
            for line in f:
                if line.strip():
                    records.append(json.loads(line))
    return summary, records

# Latest saved run for each configuration that takes part in the comparisons.
b1_summary, b1_records = load_latest_run("zero_shot")
b4_summary, b4_records = load_latest_run("cot")
m6_summary, m6_records = load_latest_run("combined_prompts")

# A previously saved M1 run is loaded too, so H3 can be evaluated while running M6.
m1_summary, m1_records = load_latest_run("multiagent")

print("Loaded baseline results:")
_loaded_runs = (
    ("B1 (zero_shot)", b1_summary, b1_records),
    ("B4 (cot)", b4_summary, b4_records),
    ("M6 (combined)", m6_summary, m6_records),
    ("M1 (multiagent)", m1_summary, m1_records),
)
for label, summ, recs in _loaded_runs:
    if not summ:
        print(f"  {label:20s}: not found")
        continue
    m = summ["metrics"]
    print(f"  {label:20s}: {len(recs)} samples | F1={m['f1']:.3f} F2={m['f2']:.3f} J={m['avg_jaccard']:.3f}")

In [None]:
from src.evaluation.statistical import (
    bootstrap_ci, mcnemar_test, wilcoxon_test,
    benjamini_hochberg, cohens_d, format_result,
)

# ── Helper: build paired outcome vectors from two sets of records ──
def build_paired_outcomes(records_a: list[dict], records_b: list[dict]):
    """Pair two runs by ``sample_id`` and extract their per-sample outcomes.

    Only sample IDs present in both inputs are kept, in sorted order. A record
    counts as correct when its classification is ``TP`` or ``TN``.

    Returns:
        ``(correct_a, correct_b, jaccard_a, jaccard_b, shared_ids)`` — four
        lists aligned element-wise with ``shared_ids``.
    """
    index_a = {rec["sample_id"]: rec for rec in records_a}
    index_b = {rec["sample_id"]: rec for rec in records_b}
    shared_ids = sorted(index_a.keys() & index_b.keys())

    def is_correct(rec: dict) -> bool:
        return rec["evaluation"]["classification"] in ("TP", "TN")

    correct_a, correct_b, jaccard_a, jaccard_b = [], [], [], []
    for sid in shared_ids:
        rec_a, rec_b = index_a[sid], index_b[sid]
        correct_a.append(is_correct(rec_a))
        correct_b.append(is_correct(rec_b))
        jaccard_a.append(rec_a["evaluation"]["jaccard"])
        jaccard_b.append(rec_b["evaluation"]["jaccard"])
    return correct_a, correct_b, jaccard_a, jaccard_b, shared_ids

print("=" * 70)
print("  HYPOTHESIS TESTS")
print("=" * 70)

p_values = []  # Collect for BH correction

# ── H1: Multi-agent vs B1 baseline ──
print(f"\n--- H1: {EXPERIMENT_TYPE} vs B1 (zero-shot baseline) ---")
if b1_records and results:
    correct_exp, correct_b1, jacc_exp, jacc_b1, shared = build_paired_outcomes(results, b1_records)
    print(f"  Paired samples: {len(shared)}")

    if not shared:
        # Without overlapping sample_ids the accuracy ratios below would divide
        # by zero and the paired tests would have nothing to compare.
        print("  SKIPPED: no sample_ids shared between this run and B1")
    else:
        # McNemar test on binary correctness; Cohen's d on the Jaccard scores
        chi2, p_val = mcnemar_test(correct_exp, correct_b1)
        d = cohens_d(jacc_exp, jacc_b1)
        p_values.append(p_val)

        exp_acc = sum(correct_exp) / len(correct_exp)
        b1_acc = sum(correct_b1) / len(correct_b1)
        ci = bootstrap_ci([1 if c else 0 for c in correct_exp])

        print(format_result("Accuracy", exp_acc, ci=ci, baseline_value=b1_acc, p_value=p_val, effect_size=d))

        # Wilcoxon on Jaccard
        w_stat, w_p = wilcoxon_test(jacc_exp, jacc_b1)
        p_values.append(w_p)
        print(f"  Jaccard Wilcoxon: W={w_stat:.1f}, p={w_p:.4f}")
else:
    print("  SKIPPED: B1 results not found")

# ── H1b: Multi-agent vs B4 baseline ──
print(f"\n--- H1b: {EXPERIMENT_TYPE} vs B4 (chain-of-thought) ---")
if b4_records and results:
    correct_exp, correct_b4, jacc_exp, jacc_b4, shared = build_paired_outcomes(results, b4_records)
    print(f"  Paired samples: {len(shared)}")

    if not shared:
        # Without overlapping sample_ids the accuracy ratios below would divide
        # by zero and the paired test would have nothing to compare.
        print("  SKIPPED: no sample_ids shared between this run and B4")
    else:
        # McNemar on paired correctness; Cohen's d on the Jaccard scores
        chi2, p_val = mcnemar_test(correct_exp, correct_b4)
        d = cohens_d(jacc_exp, jacc_b4)
        p_values.append(p_val)

        exp_acc = sum(correct_exp) / len(correct_exp)
        b4_acc = sum(correct_b4) / len(correct_b4)
        ci = bootstrap_ci([1 if c else 0 for c in correct_exp])

        print(format_result("Accuracy", exp_acc, ci=ci, baseline_value=b4_acc, p_value=p_val, effect_size=d))
else:
    print("  SKIPPED: B4 results not found")

# ── H2: Per-tier improvement (rare vs common) ──
# Compares per-sample accuracy (TP/TN counted as correct) on the samples that
# both this run and B1 contain, one tier at a time.
print(f"\n--- H2: Per-tier improvement ({EXPERIMENT_TYPE} vs B1) ---")
if not (b1_records and results):
    print("  SKIPPED: B1 results not found")
else:
    by_id_exp = {rec["sample_id"]: rec for rec in results}
    by_id_b1 = {rec["sample_id"]: rec for rec in b1_records}
    shared_ids = by_id_exp.keys() & by_id_b1.keys()

    for tier in ("common", "moderate", "rare"):
        # Tier membership is taken from this run's record.
        tier_ids = [sid for sid in shared_ids if by_id_exp[sid]["tier"] == tier]
        if not tier_ids:
            print(f"  {tier}: no paired samples")
            continue
        tier_correct_exp = [
            by_id_exp[sid]["evaluation"]["classification"] in ("TP", "TN") for sid in tier_ids
        ]
        tier_correct_b1 = [
            by_id_b1[sid]["evaluation"]["classification"] in ("TP", "TN") for sid in tier_ids
        ]
        n_tier = len(tier_ids)
        exp_rate = sum(tier_correct_exp) / n_tier
        b1_rate = sum(tier_correct_b1) / n_tier
        delta = exp_rate - b1_rate
        print(f"  {tier:10s}: {EXPERIMENT_TYPE}={exp_rate:.1%}  B1={b1_rate:.1%}  Δ={delta:+.1%}  (n={n_tier})")

# ── H3: M1 vs M6 (architecture vs prompts) ──
print(f"\n--- H3: M1 vs M6 (architecture vs prompts) ---")
# The current run supplies its own side of the comparison; the other side comes
# from the most recent saved run of that configuration.
h3_m1_records = m1_records if EXPERIMENT_TYPE != "M1" else results
h3_m6_records = m6_records if EXPERIMENT_TYPE != "M6" else results

if h3_m1_records and h3_m6_records:
    correct_m1, correct_m6, jacc_m1, jacc_m6, shared = build_paired_outcomes(h3_m1_records, h3_m6_records)
    print(f"  Paired samples: {len(shared)}")

    if not shared:
        # Without overlapping sample_ids the accuracy ratios below would divide
        # by zero and the paired test would have nothing to compare.
        print("  SKIPPED: no sample_ids shared between the M1 and M6 runs")
    else:
        chi2, p_val = mcnemar_test(correct_m1, correct_m6)
        d = cohens_d(jacc_m1, jacc_m6)
        p_values.append(p_val)

        m1_acc = sum(correct_m1) / len(correct_m1)
        m6_acc = sum(correct_m6) / len(correct_m6)
        print(format_result("M1 Accuracy", m1_acc, baseline_value=m6_acc, p_value=p_val, effect_size=d))

        # Interpret on the uncorrected p-value at the 0.05 level.
        if p_val < 0.05 and m1_acc > m6_acc:
            print("  => Architecture provides genuine benefit beyond prompting")
        elif p_val >= 0.05:
            print("  => No significant difference: multi-agent overhead may not be justified")
        else:
            print("  => M6 outperforms M1: combined prompts sufficient")
else:
    missing = []
    if not h3_m1_records:
        missing.append("M1")
    if not h3_m6_records:
        missing.append("M6")
    print(f"  SKIPPED: Need both M1 and M6 results (missing: {', '.join(missing)})")

# ── H4: Trace completeness (M1 only) ──
# Share of M1 records that carry a non-empty list of visited trace nodes.
print(f"\n--- H4: Trace completeness ---")
trace_records = h3_m1_records or []
if not trace_records:
    print("  SKIPPED: No M1 records with trace data available")
else:
    with_trace = [rec for rec in trace_records if rec.get("trace", {}).get("nodes_visited")]
    completeness = len(with_trace) / len(trace_records)
    print(f"  Records with trace: {len(with_trace)} / {len(trace_records)}")
    print(f"  Trace completeness: {completeness:.1%} (target: > 90%)")
    if completeness > 0.9:
        verdict, relation = "PASS", "Meets"
    else:
        verdict, relation = "FAIL", "Below"
    print(f"  {verdict}: {relation} 90% threshold")

# ── Multiple comparison correction ──
# Applies Benjamini-Hochberg to every p-value collected above, in the order the
# tests were run.
if p_values:
    print("\n--- Benjamini-Hochberg Correction ---")
    significant = benjamini_hochberg(p_values, alpha=0.05)
    for test_no, (p, sig) in enumerate(zip(p_values, significant), start=1):
        flag = "*" if sig else "ns"
        print(f"  Test {test_no}: p={p:.4f} {flag}")
    print(f"  Significant after BH correction: {sum(significant)} / {len(significant)}")

## 6. Summary Comparison Table

In [None]:
# Build summary table across all available configs
# The current run comes first, taken from the metrics computed in this session.
configs = [{
    "config": EXPERIMENT_TYPE,
    "f1": f1, "f2": f2, "precision": precision, "recall": recall,
    "jaccard": avg_jaccard, "laziness": laziness_rate,
    "samples": len(results),
}]

# Saved runs; the one matching the current experiment type is skipped so the
# current run is not listed twice.
_saved_runs = {"B1": b1_summary, "B4": b4_summary, "M6": m6_summary, "M1": m1_summary}
for label, summ in _saved_runs.items():
    if not summ or label == EXPERIMENT_TYPE:
        continue
    m = summ["metrics"]
    configs.append({
        "config": label,
        "f1": m["f1"], "f2": m["f2"],
        "precision": m["precision"], "recall": m["recall"],
        "jaccard": m["avg_jaccard"], "laziness": m["laziness_rate"],
        "samples": sum(m[k] for k in ["tp", "fp", "fn", "tn"]),
    })

# Sort: baselines first, then M-variants
order = {
    name: rank
    for rank, name in enumerate(["B1", "B4", "M6", "M1", "M2", "M3", "M4", "M5"])
}
configs.sort(key=lambda entry: order.get(entry["config"], 99))

_table_rule = "=" * 85
print(_table_rule)
print("  Cross-Configuration Comparison")
print(_table_rule)
print(f"  {'Config':<8} {'N':>4} {'Prec':>7} {'Rec':>7} {'F1':>7} {'F2':>7} {'Jaccard':>8} {'Lazy':>7}")
print(f"  {'-'*75}")
for c in configs:
    # The arrow marks the row of the run executed in this session.
    marker = " <--" if c["config"] == EXPERIMENT_TYPE else ""
    print(f"  {c['config']:<8} {c['samples']:>4} {c['precision']:>7.3f} {c['recall']:>7.3f} "
          f"{c['f1']:>7.3f} {c['f2']:>7.3f} {c['jaccard']:>8.3f} {c['laziness']:>6.1%}{marker}")

## 7. Save Results

In [None]:
def _tier_metrics(records):
    """Confusion counts, F1/F2 and mean clause-present Jaccard for one tier."""
    labels = [rec["evaluation"]["classification"] for rec in records]
    t_tp, t_fp, t_fn, t_tn = (labels.count(name) for name in ("TP", "FP", "FN", "TN"))
    t_jaccs = [
        rec["evaluation"]["jaccard"] for rec in records if rec["ground_truth"]["has_clause"]
    ]
    return {
        "tp": t_tp, "fp": t_fp, "fn": t_fn, "tn": t_tn,
        "f1": compute_f1(t_tp, t_fp, t_fn),
        "f2": compute_f2(t_tp, t_fp, t_fn),
        "avg_jaccard": sum(t_jaccs) / len(t_jaccs) if t_jaccs else 0,
    }


# Per-tier metrics
_per_tier = {
    tier: _tier_metrics([rec for rec in results if rec["tier"] == tier])
    for tier in ("common", "moderate", "rare")
}

# Compact per-sample view
_compact_samples = [
    {
        "id": rec["sample_id"],
        "category": rec["category"],
        "tier": rec["tier"],
        "classification": rec["evaluation"]["classification"],
        "jaccard": rec["evaluation"]["jaccard"],
        "grounding_rate": rec["evaluation"]["grounding_rate"],
        "num_clauses_predicted": rec["output"]["num_clauses"],
        "num_gt_spans": rec["ground_truth"]["num_spans"],
        "input_tokens": rec["usage"]["input_tokens"],
        "output_tokens": rec["usage"]["output_tokens"],
        "latency_s": rec["usage"]["latency_s"],
    }
    for rec in results
]

# Everything needed to reproduce and compare this run, in one JSON document.
summary = {
    "run_id": run_id,
    "timestamp": datetime.datetime.now().isoformat(),
    "config": {
        "model_key": MODEL_KEY,
        "model_id": config.model_id,
        "provider": config.provider.value,
        "experiment_type": EXPERIMENT_TYPE,
        "experiment_label": experiment_label,
        "samples_per_tier": SAMPLES_PER_TIER,
        "temperature": TEMPERATURE,
        "max_tokens": MAX_TOKENS,
        "max_contract_chars": MAX_CONTRACT_CHARS,
        "include_negative": INCLUDE_NEGATIVE_SAMPLES,
    },
    "metrics": {
        "precision": precision,
        "recall": recall,
        "f1": f1,
        "f2": f2,
        "avg_jaccard": avg_jaccard,
        "laziness_rate": laziness_rate,
        "tp": tp, "fp": fp, "fn": fn, "tn": tn,
    },
    "per_tier": _per_tier,
    "samples": _compact_samples,
    "diagnostics": diag_summary,
    "intermediate_file": str(intermediate_path),
}

# Save summary
summary_path = output_dir / f"{run_id}_summary.json"
with open(summary_path, "w") as f:
    json.dump(summary, f, indent=2, default=str)
print(f"Summary saved:      {summary_path}")

# Save diagnostics
diag_dir = Path("../experiments/diagnostics")
diag_dir.mkdir(parents=True, exist_ok=True)
diag_path = diag_dir / f"{run_id}_diagnostics.json"
diagnostics.export(diag_path)
print(f"Diagnostics saved:  {diag_path}")

# Remind about intermediate
print(f"Intermediate saved: {intermediate_path}")
print("\nTo inspect a single record:")
print(f"  head -1 {intermediate_path} | python -m json.tool")

## Next Steps

**Switch experiment** — change `EXPERIMENT_TYPE`:
```python
EXPERIMENT_TYPE = "M1"  # Full multi-agent (orchestrator + 3 specialists)
EXPERIMENT_TYPE = "M6"  # Combined prompts (architecture ablation)
# M2–M5 reserved for future ablations
```

**Workflow for complete comparison:**
1. Run notebook 03 with `B1` and `B4` to establish baselines
2. Run this notebook with `M1` (core thesis contribution)
3. Run this notebook with `M6` (critical ablation)
4. Re-run the statistical comparison cell — it auto-loads all available results

**Output files:**
- `experiments/results/{run_id}_intermediate.jsonl` — full per-sample records
- `experiments/results/{run_id}_summary.json` — config + metrics + compact results
- `experiments/diagnostics/{run_id}_diagnostics.json` — raw API call log