# MirageBench Colab: Endogenous Pivot Theory on Real LLMs

This notebook operationalizes the **Tropical Endogenous Context Semiring** prediction on real LLM behavior:

- under context compression, models can keep answers coherent,
- while silently substituting the endogenous pivot,
- creating a **Validity Mirage**.

The notebook is structured for **Google Colab Pro (A100, 50GB+ RAM)** and includes:

1. Natural-language MirageBench task construction (12 tasks across incident/investment/narrative)
2. Black-box model evaluation (HF open models + API placeholders)
3. KV-cache surgery and neural pivot tracking
4. Divergence scaling probe up to `n = 1_000_000`
5. MirageBench packaging and export artifacts


## Notebook Roadmap

- **0. Setup & Installation**
- **1. Task Construction**
- **2. Black-Box Evaluation**
- **3. KV-Cache Surgery (GPU)**
- **4. Divergence Scaling**
- **5. MirageBench Packaging**
- **6. Summary & Success Criteria**

All figures and tables are saved under `/content/results/`.


In [5]:
# 0. Setup & Installation
# Colab Pro install cell. If packages are already present, this is fast.
# NOTE: the leading `!` is IPython shell-escape syntax, so this line only runs inside a notebook kernel.
# NOTE: `-U` upgrades every listed package to its latest release; no versions are pinned here.

!pip -q install -U     numpy pandas matplotlib seaborn scipy scikit-learn tqdm     transformers accelerate sentence-transformers datasets     openai anthropic huggingface_hub


In [6]:
# 0.2 Imports, reproducibility, and filesystem layout
from __future__ import annotations

import json
import math
import os
import random
import re
import textwrap
import time
import warnings
import zipfile
from dataclasses import asdict, dataclass, field
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from scipy.spatial.distance import cosine
from sklearn.feature_extraction.text import TfidfVectorizer
from tqdm.auto import tqdm

# Seed Python's `random` module and NumPy's legacy global RNG.
# The task builders further down create their own `np.random.default_rng(...)`
# generators, so they do not draw from these global streams.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)

# Plot defaults: on-screen figures at 120 dpi, saved figures at 180 dpi.
sns.set_theme(style="whitegrid")
plt.rcParams["figure.dpi"] = 120
plt.rcParams["savefig.dpi"] = 180

# Output layout: everything is written below RESULTS_ROOT (raw tables and figures
# in separate subdirectories). Directories are created eagerly and idempotently.
RESULTS_ROOT = Path("/content/results")
RAW_DIR = RESULTS_ROOT / "raw"
FIG_DIR = RESULTS_ROOT / "figures"
for p in [RESULTS_ROOT, RAW_DIR, FIG_DIR]:
    p.mkdir(parents=True, exist_ok=True)

print(f"Results root: {RESULTS_ROOT}")
print(f"Raw dir: {RAW_DIR}")
print(f"Figure dir: {FIG_DIR}")


Results root: /content/results
Raw dir: /content/results/raw
Figure dir: /content/results/figures


## 0.3 Runtime Check

This notebook is designed for Colab Pro. For the KV-cache section, a GPU is strongly recommended.


In [8]:
# 0.3 Runtime check
# Best-effort probe: report the torch version and, when CUDA is available, the
# name and total memory of device 0. Any failure (including torch not being
# importable) is printed instead of raised so the rest of the notebook can run.
try:
    import torch
    print("Torch:", torch.__version__)
    print("CUDA available:", torch.cuda.is_available())
    if torch.cuda.is_available():
        print("GPU:", torch.cuda.get_device_name(0))
        # total_memory is reported in bytes; divide by 1e9 to print decimal gigabytes.
        print("GPU memory (GB):", round(torch.cuda.get_device_properties(0).total_memory / 1e9, 2))
except Exception as exc:
    print("Torch check failed:", exc)


Torch: 2.9.0+cpu
CUDA available: False


In [None]:
# 0.4 Control Panel (edit only this cell for session planning)
# As configured below, only the Hugging Face black-box evaluation (Part 2) is enabled,
# for Llama 3.1 8B; the API, KV-cache, and divergence toggles are all off.

CONTROL = {
    # Part 2: Black-box evaluation
    "RUN_HF_MODELS": True,
    "RUN_API_MODELS": False,
    "HF_MODELS_TO_RUN": ["llama-3.1-8b-instruct"],  # keys from MODEL_SPECS
    "MAX_TASKS_PER_MODEL": 12,  # set 4 for a faster smoke test
    "COMPRESSION_LEVELS": [0.4, 0.5, 0.6],

    # Part 3: KV-cache surgery (disabled for now)
    "AUTO_LOAD_KV_MODEL": False,
    "KV_MODEL_NAME": "meta-llama/Llama-3.1-8B-Instruct",
    "KV_MAX_NEW_TOKENS": 160,
    "RUN_KV_EXPERIMENT": False,
    "KV_TASK_INDEX": 0,
    "KV_EVICTION_RATIO": 0.5,
    "RUN_KV_SWEEP": False,
    "KV_SWEEP_LEVELS": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9],

    # Part 4: Divergence scaling (disabled for now)
    "RUN_DIVERGENCE": False,
    "NS": [1_000, 5_000, 10_000, 50_000, 100_000, 500_000, 1_000_000],
    "EPS": [0.3, 0.5, 0.7],
    "KS": [1, 2, 3, 5],
    "SEEDS_PER_CELL": 200,
    "CHUNK_SIZE": 10,
}

# Publish every CONTROL key as a module-level global (e.g. RUN_HF_MODELS, NS, KS).
# NOTE: this overwrites any existing global that shares a key's name.
globals().update(CONTROL)

# Echo the boolean run toggles so the active session plan is visible in the cell output.
print("Control panel loaded. Active toggles:")
for key in [
    "RUN_HF_MODELS", "RUN_API_MODELS", "AUTO_LOAD_KV_MODEL",
    "RUN_KV_EXPERIMENT", "RUN_KV_SWEEP", "RUN_DIVERGENCE",
]:
    print(f"  {key} = {globals()[key]}")


# 1. Task Construction

We build **12 MirageBench tasks** (4 per category):

- **Category A: Incident triage**
- **Category B: Investment analysis**
- **Category C: Narrative analysis**

Each task contains:

- `full_context` (long context; target ~4k-8k tokens)
- `compressed_context` (40-60% drop, measured in whitespace-delimited words and biased toward low-salience setup records)
- `pivot_ground_truth`
- `answer_ground_truth`
- `decoy_pivot`
- `decoy_answer`


In [None]:
# 1.1 Core data model + embedded tropical semiring core (from validated repo logic)
# Sentinel for "no pivot recorded in this slot" in TropicalContext.W.
NEG_INF = float("-inf")

@dataclass(frozen=True)
class Event:
    """Immutable timeline event consumed by ``TropicalContext.from_event``.

    Only ``is_focal`` and ``weight`` are read by the semiring code in this
    section; the remaining fields are carried along unchanged.
    """

    eid: int  # event identifier (not read by the code in this section)
    timestamp: float  # not read by the code in this section
    weight: float  # stored in slot 0 of W when the event is focal
    actor: str  # not read by the code in this section
    is_focal: bool  # True -> pivot candidate; False -> counts toward d_total


@dataclass
class TropicalContext:
    """Summary of a block of events under the tropical (max-based) composition.

    ``W`` is a float array of length ``k + 1`` initialised to ``NEG_INF``;
    ``d_total`` counts the non-focal events in the block. Presumably ``W[j]``
    holds the best focal weight seen with ``j`` preceding non-focal events,
    capped at ``k`` (see ``compose_tropical``) -- confirm against the
    reference implementation.
    """

    k: int
    W: np.ndarray
    d_total: int

    @classmethod
    def empty(cls, k: int) -> "TropicalContext":
        """Return the identity context: no pivots recorded, zero non-focal events."""
        blank_slots = np.full(k + 1, NEG_INF, dtype=float)
        return cls(k=k, W=blank_slots, d_total=0)

    @classmethod
    def from_event(cls, event: Event, k: int) -> "TropicalContext":
        """Lift a single event into a context.

        A focal event records its weight in slot 0 and contributes no
        ``d_total``; a non-focal event leaves every slot empty and
        contributes ``d_total = 1``.
        """
        slots = np.full(k + 1, NEG_INF, dtype=float)
        if not event.is_focal:
            return cls(k=k, W=slots, d_total=1)
        slots[0] = event.weight
        return cls(k=k, W=slots, d_total=0)


def compose_tropical(left: TropicalContext, right: TropicalContext) -> TropicalContext:
    """Compose two contexts, with ``left`` preceding ``right``.

    Slots coming from ``left`` keep their index. Each occupied slot of
    ``right`` is shifted up by ``left.d_total`` (saturating at ``k``) and the
    larger weight wins per slot. The result's ``d_total`` is the sum of both.

    Raises:
        ValueError: if the two contexts were built with different ``k``.
    """
    if left.k != right.k:
        raise ValueError("Cannot compose contexts with different k")

    cap = left.k
    shift = left.d_total

    # Start from an all-empty row and fold the left block in unchanged.
    merged = np.full(cap + 1, NEG_INF, dtype=float)
    np.maximum(merged, left.W, out=merged)

    # Fold in the right block: occupied slots move up by the left block's budget.
    for slot, weight in enumerate(right.W):
        if not np.isneginf(weight):
            target = min(cap, slot + shift)
            merged[target] = max(merged[target], weight)

    return TropicalContext(k=cap, W=merged, d_total=shift + right.d_total)


def build_tropical_context(events: Sequence[Event], k: int) -> TropicalContext:
    """Fold ``events`` left-to-right into one context via ``compose_tropical``.

    An empty sequence yields ``TropicalContext.empty(k)``.
    """
    context = TropicalContext.empty(k)
    # The generator is lazy, so each event is lifted just before it is composed.
    singletons = (TropicalContext.from_event(item, k) for item in events)
    for singleton in singletons:
        context = compose_tropical(context, singleton)
    return context


@dataclass
class MirageBenchTask:
    """One benchmark item: a long context, its compressed variant, and ground truth.

    The builders in this notebook store event markers (e.g. ``A01-E057``) in
    ``pivot_ground_truth`` / ``decoy_pivot`` and put the full layout dict used
    to render the contexts into ``metadata``.
    """

    task_id: str
    category: str  # "incident", "investment", "narrative"
    full_context: str  # uncompressed context text
    compressed_context: str  # default compressed rendering (builders target a 0.50 drop)
    question: str  # prompt text; asks the model to start with PIVOT_ID=<event_id>
    pivot_ground_truth: str  # marker of the true pivot event
    answer_ground_truth: str  # reference answer citing the true pivot
    decoy_pivot: str  # marker of the plausible-but-wrong pivot
    decoy_answer: str  # reference answer citing the decoy pivot
    k: int  # stored as passed by the builder; not otherwise used in task construction
    metadata: Dict[str, Any] = field(default_factory=dict)  # per-instance dict (no shared default)


@dataclass
class MirageBenchResult:
    """Evaluation record keyed by task, model, and compression level.

    No instance is constructed in this section of the notebook, so the exact
    rules that populate the flags below are not visible here.
    NOTE(review): ``raw_validity`` and ``raw_validity_compressed`` both appear
    to describe compressed-condition validity -- confirm whether both are needed.
    """

    task_id: str
    model_name: str
    full_answer: str
    compressed_answer: str
    raw_validity: float  # compressed validity, 0-1
    pivot_preserved: bool
    semantic_regret: float  # 0-1
    compression_level: float = 0.5
    category: str = ""
    full_pivot: str = ""
    compressed_pivot: str = ""
    raw_validity_full: float = 0.0
    raw_validity_compressed: float = 0.0
    full_pivot_matches_ground_truth: bool = False
    compressed_pivot_matches_ground_truth: bool = False
    pivot_outcome: str = ""
    high_validity_flag: bool = False
    true_mirage_flag: bool = False
    rescue_flag: bool = False
    instability_flag: bool = False


In [None]:
# 1.2 Task synthesis helpers

def _long_note(rng: np.random.Generator, role: str, domain: str) -> str:
    """Compose one filler sentence pair for an event line.

    Draws, in this order, a role-specific opening, a connecting clause, and a
    closing sentence from ``rng`` and joins them as
    ``"<opening> in <domain>, <connector>. <closing>"``.

    Raises:
        KeyError: if ``role`` is not one of "setup", "pivot", "decoy", "routine".
    """
    openings_by_role = {
        "setup": [
            "operators logged a routine control-plane adjustment",
            "the team recorded a mundane baseline calibration",
            "an unremarkable handoff completed between on-call rotations",
            "a low-visibility maintenance annotation was filed",
        ],
        "pivot": [
            "cross-system coupling became explicit and destabilizing",
            "multiple subsystems synchronized into a high-impact transition",
            "latent dependencies converged into a decisive break",
            "the system crossed a threshold where all priors inverted",
        ],
        "decoy": [
            "the event looked dramatic on dashboards but was structurally secondary",
            "alerts peaked visually, yet this remained derivative",
            "stakeholders focused on this visible shock, although it was downstream",
            "this incident dominated discussion despite depending on prior setup",
        ],
        "routine": [
            "logs show repetitive housekeeping and status chatter",
            "the update was processed as ordinary background activity",
            "nothing in isolation appeared strategically dominant",
            "local checks passed with only minor variance",
        ],
    }
    connectors = [
        "and the note references dependencies that only matter when read with the full timeline",
        "with commentary that seems low salience unless one tracks the global score rule",
        "while preserving details that become causal prerequisites for later interpretation",
        "and includes procedural details that summarizers often collapse away",
    ]
    closers = [
        "This detail is intentionally plain-language and easy to discard during compression.",
        "In local context it seems boring, but globally it changes which pivot is admissible.",
        "Its narrative footprint is small even though it shifts downstream feasibility.",
        "This is a structural breadcrumb rather than a headline event.",
    ]

    # The draw order (opening, connector, closer) determines the RNG stream; keep it.
    lead = rng.choice(openings_by_role[role])
    connector = rng.choice(connectors)
    closer = rng.choice(closers)
    return f"{lead} in {domain}, {connector}. {closer}"


def _render_context(
    preamble: str,
    records: List[Dict[str, Any]],
    rule_block: str,
    appendix_target_words: int,
    rng: np.random.Generator,
) -> Tuple[str, List[Dict[str, Any]], str]:
    """Render a context document and record where each event line lands.

    Layout: right-stripped preamble, blank line, one line per record, blank
    line, stripped rule block, then (only if needed) a generated appendix.

    Args:
        preamble: Header text placed before the records.
        records: Dicts providing "line", "marker", and "role".
        rule_block: Reminder text placed after the records.
        appendix_target_words: Appendix notes are appended until the
            whitespace-delimited word count reaches at least this value.
        rng: Accepted for signature parity with the other helpers; not used.

    Returns:
        ``(text, spans, appendix_text)``. Each span gives the record's marker,
        role, character offsets ``start``/``end`` into ``text`` (``end``
        includes the trailing newline), and ``word_count``. ``appendix_text``
        is the newline-joined appendix notes without the appendix header, or
        "" when no padding was required.
    """
    header = preamble.rstrip() + "\n\n"
    pieces: List[str] = [header]
    cursor = len(header)  # running character offset into the final text
    spans: List[Dict[str, Any]] = []

    for record in records:
        rendered = record["line"].rstrip() + "\n"
        stop = cursor + len(rendered)
        spans.append({
            "marker": record["marker"],
            "role": record["role"],
            "start": cursor,
            "end": stop,
            "word_count": len(rendered.split()),
        })
        pieces.append(rendered)
        cursor = stop

    pieces.append("\n" + rule_block.strip() + "\n")
    text = "".join(pieces)

    # Pad with low-salience appendix notes so contexts are genuinely long and
    # vulnerable to truncation/summarization.
    appendix_lines: List[str] = []
    word_total = len(text.split())
    while word_total < appendix_target_words:
        note = (
            f"Appendix note {len(appendix_lines) + 1:03d}: A review clerk documented peripheral status drift, "
            "ticket routing metadata, and coordination chatter that reads as low priority in isolation "
            "but can encode prerequisite state for later high-consequence events."
        )
        appendix_lines.append(note)
        word_total += len(note.split())

    appendix_text = "\n".join(appendix_lines)
    if appendix_text:
        text = f"{text}\nOperational Appendix (intentionally low salience):\n{appendix_text}\n"

    return text, spans, appendix_text


def _compress_records_to_target(
    task_layout: Dict[str, Any],
    target_drop_fraction: float,
    rng: np.random.Generator,
) -> Tuple[str, float, List[str]]:
    """Simulate lossy compression of a task's context.

    The appendix is always removed. Unprotected records are then dropped one
    at a time -- lowest role priority first, shuffled within a priority --
    until the estimated word count is at or below
    ``full_words * (1 - target_drop_fraction)`` or nothing droppable remains.

    Args:
        task_layout: Layout dict with "records", "preamble", "rule_block" and
            optionally "appendix_text" and "protected_markers".
        target_drop_fraction: Fraction of the full context's words to remove.
        rng: Generator used to shuffle droppable records before the stable
            priority sort.

    Returns:
        ``(compressed_text, actual_drop, removed_markers)`` where
        ``actual_drop`` is measured against the full context including its
        appendix, and ``removed_markers`` follows the original record order.
    """
    records = task_layout["records"]
    preamble = task_layout["preamble"]
    rule_block = task_layout["rule_block"]
    appendix_text = task_layout.get("appendix_text", "")

    # A word budget no larger than preamble + rules is already met before any
    # padding, so _render_context never generates an appendix with it.
    no_appendix_budget = len((preamble + rule_block).split())

    def _render(selected: List[Dict[str, Any]]) -> str:
        rendered, _, _ = _render_context(
            preamble=preamble,
            records=selected,
            rule_block=rule_block,
            appendix_target_words=no_appendix_budget,
            rng=rng,
        )
        return rendered

    # Denominator: the full document, with the original appendix re-attached verbatim.
    full_text = _render(records)
    if appendix_text:
        full_text += "\nOperational Appendix (intentionally low salience):\n" + appendix_text + "\n"

    full_words = len(full_text.split())
    target_words = int(round(full_words * (1.0 - target_drop_fraction)))

    protected = set(task_layout.get("protected_markers", []))

    # Lower number = dropped earlier; unknown roles are treated like "support".
    role_priority = {
        "setup": 0,
        "routine": 1,
        "support": 2,
        "decoy_setup": 3,
        "candidate": 4,
        "pivot": 5,
    }

    droppable = [rec for rec in records if rec["marker"] not in protected]
    rng.shuffle(droppable)
    # Stable sort keeps the shuffled order inside each priority tier.
    droppable.sort(key=lambda rec: role_priority.get(rec["role"], 2))

    surviving = {rec["marker"] for rec in records}

    # Starting estimate already excludes the appendix (what many summarizers effectively do).
    compressed_words = len((preamble + "\n" + rule_block).split()) + sum(
        len(rec["line"].split()) for rec in records
    )

    for rec in droppable:
        if compressed_words <= target_words:
            break
        marker = rec["marker"]
        if marker in surviving:
            surviving.discard(marker)
            compressed_words -= len(rec["line"].split())

    kept_records = [rec for rec in records if rec["marker"] in surviving]
    compressed_text = _render(kept_records)

    actual_drop = 1.0 - (len(compressed_text.split()) / max(1, full_words))
    removed_markers = [rec["marker"] for rec in records if rec["marker"] not in surviving]
    return compressed_text, actual_drop, removed_markers


def _build_question(category: str) -> str:
    """Return the question prompt for a task category.

    "incident" and "investment" have dedicated prompts; any other value
    receives the narrative prompt. Every prompt ends with the instruction to
    begin the answer with ``PIVOT_ID=<event_id>``.
    """
    pivot_instruction = "Begin the answer with `PIVOT_ID=<event_id>`."
    prompts = (
        (
            "incident",
            "The root cause is the incident with the highest composite severity score (severity * scope). "
            "Identify the root cause and list prerequisite events that enabled it. ",
        ),
        (
            "investment",
            "The anchor position is the entry with the highest cumulative return. "
            "Analyze what market conditions preceded and enabled this position's success. ",
        ),
    )
    narrative_body = (
        "Identify the story turning point: the single action with the highest consequence score. "
        "Explain what earlier events made it possible. "
    )

    for name, body in prompts:
        if category == name:
            return body + pivot_instruction
    return narrative_body + pivot_instruction


In [None]:
# 1.3 Category builders (A incident, B investment, C narrative)

def build_incident_task(task_num: int, k: int = 3, target_words: int = 3600) -> MirageBenchTask:
    """Build one synthetic incident-triage task (category "incident").

    An 88-event timeline is generated from an RNG seeded with
    ``1000 + task_num``. One event is the true pivot (Severity 9 x Scope 12 =
    108) and a later event is the decoy (10 x 8 = 80); routine events score at
    most 5 x 7 = 35. Four low-scoring setup events precede the pivot and three
    decoy-setup events precede the decoy.

    The stored ``compressed_context`` is rendered with a 0.50 target drop.
    Only the pivot, decoy, and decoy-setup markers are protected, so the
    pivot's own setup events are eligible for removal.

    Args:
        task_num: Task number; determines the seed and the ``A<nn>`` task id.
        k: Stored on the returned task unchanged; not otherwise used here.
        target_words: Minimum word count of the full context; an appendix is
            generated only if the records alone fall short of it.

    Returns:
        A ``MirageBenchTask`` whose ``metadata`` is the layout dict used for
        rendering (records, spans, protected/candidate markers, default drop).
    """
    rng = np.random.default_rng(1000 + task_num)
    task_id = f"A{task_num:02d}"
    n_events = 88

    # Indices are 0-based; markers below are 1-based (E001 corresponds to i == 0).
    pivot_idx = int(rng.integers(54, 66))
    decoy_idx = int(min(n_events - 8, pivot_idx + rng.integers(8, 14)))
    true_setup_idx = [pivot_idx - 8, pivot_idx - 6, pivot_idx - 4, pivot_idx - 2]
    decoy_setup_idx = [decoy_idx - 6, decoy_idx - 4, decoy_idx - 2]

    services = ["auth", "ledger", "cache", "queue", "api-gateway", "billing", "search"]
    records: List[Dict[str, Any]] = []

    for i in range(n_events):
        marker = f"{task_id}-E{i+1:03d}"
        # NOTE(review): the day is (i % 27) + 1, so it wraps back to 01 every 27
        # events and the rendered timestamps are not monotonic in event order.
        ts = f"2026-05-{(i % 27) + 1:02d} {8 + (i % 11):02d}:{(i * 7) % 60:02d}"
        service = services[i % len(services)]

        # Routine defaults are drawn for every event (even ones overridden below),
        # which keeps the RNG stream aligned across roles.
        role = "routine"
        sev = int(rng.integers(2, 6))
        scope = int(rng.integers(2, 8))
        event_type = "Routine telemetry"

        if i in true_setup_idx:
            role = "setup"
            sev = 1
            scope = int(rng.integers(1, 3))
            event_type = "Config baseline update"
        elif i in decoy_setup_idx:
            role = "decoy_setup"
            sev = int(rng.integers(2, 4))
            scope = int(rng.integers(2, 5))
            event_type = "Regional failover rehearsal"
        elif i == pivot_idx:
            role = "pivot"
            sev = 9
            scope = 12
            event_type = "Cascade failure across service mesh"
        elif i == decoy_idx:
            role = "candidate"
            sev = 10
            scope = 8
            event_type = "Network partition across availability zones"

        composite = sev * scope
        # Decoy and decoy-setup lines share the "decoy" note vocabulary.
        note_role = "pivot" if role == "pivot" else "decoy" if role in {"candidate", "decoy_setup"} else role
        note = _long_note(rng, note_role if note_role in {"setup", "pivot", "decoy", "routine"} else "routine", "incident operations")

        line = (
            f"[{marker}] {ts} | Service={service} | Event={event_type} | "
            f"Severity={sev} | Scope={scope} | Composite={composite} | {note}"
        )
        records.append({"marker": marker, "role": role, "line": line, "sev": sev, "scope": scope, "composite": composite})

    pivot_marker = f"{task_id}-E{pivot_idx+1:03d}"
    decoy_marker = f"{task_id}-E{decoy_idx+1:03d}"
    true_setup_markers = [f"{task_id}-E{i+1:03d}" for i in true_setup_idx]
    decoy_setup_markers = [f"{task_id}-E{i+1:03d}" for i in decoy_setup_idx]

    preamble = (
        f"Incident Triage Dossier {task_id}\n"
        "You are reviewing a long forensic timeline. The timeline includes high-salience alerts and low-salience setup events.\n"
        "Interpretation rule: root cause selection is endogenous and depends on the global argmax over composite severity."
    )
    rule_block = "Rule reminder: root cause = event with max(Severity * Scope) over the entire timeline."

    full_context, spans, appendix_text = _render_context(
        preamble=preamble,
        records=records,
        rule_block=rule_block,
        appendix_target_words=target_words,
        rng=rng,
    )

    layout = {
        "preamble": preamble,
        "records": records,
        "rule_block": rule_block,
        "appendix_text": appendix_text,
        "spans": spans,
        # The true setup markers are deliberately absent from the protected list.
        "protected_markers": [pivot_marker, decoy_marker, *decoy_setup_markers],
        "candidate_markers": [pivot_marker, decoy_marker],
        "candidate_requirements": {
            pivot_marker: true_setup_markers,
            decoy_marker: decoy_setup_markers,
        },
        "pivot_setup_markers": true_setup_markers,
        "decoy_setup_markers": decoy_setup_markers,
    }

    compressed_context, actual_drop, removed_markers = _compress_records_to_target(layout, target_drop_fraction=0.50, rng=rng)

    question = _build_question("incident")
    # NOTE(review): the reference answers cite only the first three of the four
    # true setup markers (and two of the three decoy setups) -- confirm intended.
    answer_gt = (
        f"PIVOT_ID={pivot_marker}. The root cause is {pivot_marker}, because its composite severity is maximal. "
        f"Prerequisite chain: {true_setup_markers[0]} -> {true_setup_markers[1]} -> {true_setup_markers[2]} -> {pivot_marker}."
    )
    decoy_answer = (
        f"PIVOT_ID={decoy_marker}. A plausible but wrong chain is "
        f"{decoy_setup_markers[0]} -> {decoy_setup_markers[1]} -> {decoy_marker}."
    )

    # Record what the default (0.50) compression actually removed.
    layout["compression_default_drop"] = actual_drop
    layout["removed_markers_default"] = removed_markers

    return MirageBenchTask(
        task_id=task_id,
        category="incident",
        full_context=full_context,
        compressed_context=compressed_context,
        question=question,
        pivot_ground_truth=pivot_marker,
        answer_ground_truth=answer_gt,
        decoy_pivot=decoy_marker,
        decoy_answer=decoy_answer,
        k=k,
        metadata=layout,
    )


def build_investment_task(task_num: int, k: int = 3, target_words: int = 3600) -> MirageBenchTask:
    """Build one synthetic investment-analysis task (category "investment").

    An 84-entry weekly diary is generated from an RNG seeded with
    ``2000 + task_num``. Entries rotate through five position names, except
    that the pivot and decoy entries are forced onto fixed positions. The
    pivot entry's cumulative return is set 6.0 above the running maximum, and
    every later entry is clamped below that ceiling so the pivot stays the
    unique argmax of ``CumulativeReturn``.

    The stored ``compressed_context`` is rendered with a 0.50 target drop.
    Only the pivot, decoy, and decoy-setup markers are protected from removal.

    Args:
        task_num: Task number; determines the seed and the ``B<nn>`` task id.
        k: Stored on the returned task unchanged; not otherwise used here.
        target_words: Minimum word count of the full context; an appendix is
            generated only if the records alone fall short of it.

    Returns:
        A ``MirageBenchTask`` whose ``metadata`` is the layout dict used for
        rendering, including ``pivot_position`` and ``decoy_position``.

    Raises:
        RuntimeError: if the record with the highest cumulative return is not
            the pivot record.
    """
    rng = np.random.default_rng(2000 + task_num)
    task_id = f"B{task_num:02d}"
    n_events = 84

    # Indices are 0-based; markers below are 1-based (E001 corresponds to i == 0).
    pivot_idx = int(rng.integers(50, 62))
    decoy_idx = int(min(n_events - 6, pivot_idx + rng.integers(7, 13)))
    true_setup_idx = [pivot_idx - 9, pivot_idx - 6, pivot_idx - 3, pivot_idx - 1]
    decoy_setup_idx = [decoy_idx - 5, decoy_idx - 3, decoy_idx - 1]

    positions = [
        "NorthRiver Utilities Carry",
        "Aurelia AI Semiconductor Basket",
        "Helios Grid Infrastructure",
        "BlueHarbor Treasury Arbitrage",
        "Cinder Logistics Credit",
    ]
    pivot_position_name = "Helios Grid Infrastructure"
    decoy_position_name = "Aurelia AI Semiconductor Basket"

    records: List[Dict[str, Any]] = []
    cumulative = {p: 0.0 for p in positions}
    pivot_ceiling: Optional[float] = None  # set once the pivot entry has been processed
    pivot_peer_margin = 0.8  # gap kept between the pivot and any other position
    pivot_self_margin = 0.1  # gap kept between the pivot entry and later entries of its own position

    for i in range(n_events):
        marker = f"{task_id}-E{i+1:03d}"
        wk = f"Week-{i+1:02d}"
        position = positions[i % len(positions)]

        # A routine return is drawn for every entry (even ones overridden below),
        # which keeps the RNG stream aligned across roles.
        role = "routine"
        weekly = float(rng.normal(0.8, 0.9))

        if i in true_setup_idx:
            role = "setup"
            weekly = float(rng.normal(0.2, 0.2))
        elif i in decoy_setup_idx:
            role = "decoy_setup"
            weekly = float(rng.normal(0.6, 0.3))
        elif i == pivot_idx:
            role = "pivot"
            position = pivot_position_name
            weekly = 5.4
        elif i == decoy_idx:
            role = "candidate"
            position = decoy_position_name
            weekly = 4.8

        cumulative[position] += weekly

        # NOTE(review): the overrides and clamps below rewrite the running total,
        # so a line's WeeklyReturn need not equal the change in its CumulativeReturn.
        if i == pivot_idx:
            cumulative[position] = max(cumulative.values()) + 6.0
            pivot_ceiling = cumulative[position]
        if i == decoy_idx and pivot_ceiling is not None:
            # Make the decoy a strong runner-up while keeping it below the pivot.
            cumulative[position] = min(
                max(v for k2, v in cumulative.items() if k2 != position) + 1.2,
                pivot_ceiling - pivot_peer_margin,
            )

        # Hard clamp post-pivot cumulative values so later entries cannot overtake the pivot.
        if pivot_ceiling is not None and i > pivot_idx:
            cap = pivot_ceiling - (pivot_self_margin if position == pivot_position_name else pivot_peer_margin)
            cumulative[position] = min(cumulative[position], cap)

        cum_val = cumulative[position]
        regime = int(rng.integers(1, 6))

        # Decoy and decoy-setup lines share the "decoy" note vocabulary.
        note_role = "pivot" if role == "pivot" else "decoy" if role in {"candidate", "decoy_setup"} else role
        note = _long_note(rng, note_role if note_role in {"setup", "pivot", "decoy", "routine"} else "routine", "portfolio research")

        line = (
            f"[{marker}] {wk} | Position={position} | WeeklyReturn={weekly:+.2f}% | "
            f"CumulativeReturn={cum_val:.2f}% | RegimeScore={regime} | {note}"
        )
        records.append({"marker": marker, "role": role, "line": line, "position": position, "cum": cum_val})

    pivot_marker = f"{task_id}-E{pivot_idx+1:03d}"
    decoy_marker = f"{task_id}-E{decoy_idx+1:03d}"
    true_setup_markers = [f"{task_id}-E{i+1:03d}" for i in true_setup_idx]
    decoy_setup_markers = [f"{task_id}-E{i+1:03d}" for i in decoy_setup_idx]

    # Self-check: the stated rule (max CumulativeReturn) must select the pivot.
    max_cum_marker = max(records, key=lambda r: float(r["cum"]))["marker"]
    if max_cum_marker != pivot_marker:
        raise RuntimeError(
            f"Investment task {task_id} invalid: max cumulative marker {max_cum_marker} != pivot {pivot_marker}."
        )

    pivot_position = next(r["position"] for r in records if r["marker"] == pivot_marker)
    decoy_position = next(r["position"] for r in records if r["marker"] == decoy_marker)

    preamble = (
        f"Investment Committee Timeline {task_id}\n"
        "The portfolio diary contains noisy market commentary and low-salience condition markers.\n"
        "Interpretation rule: anchor analysis must follow the highest cumulative-return position."
    )
    rule_block = "Rule reminder: anchor position = entry with max(CumulativeReturn) in the full timeline."

    full_context, spans, appendix_text = _render_context(
        preamble=preamble,
        records=records,
        rule_block=rule_block,
        appendix_target_words=target_words,
        rng=rng,
    )

    layout = {
        "preamble": preamble,
        "records": records,
        "rule_block": rule_block,
        "appendix_text": appendix_text,
        "spans": spans,
        # The true setup markers are deliberately absent from the protected list.
        "protected_markers": [pivot_marker, decoy_marker, *decoy_setup_markers],
        "candidate_markers": [pivot_marker, decoy_marker],
        "candidate_requirements": {
            pivot_marker: true_setup_markers,
            decoy_marker: decoy_setup_markers,
        },
        "pivot_setup_markers": true_setup_markers,
        "decoy_setup_markers": decoy_setup_markers,
        "pivot_position": pivot_position,
        "decoy_position": decoy_position,
    }

    compressed_context, actual_drop, removed_markers = _compress_records_to_target(layout, target_drop_fraction=0.50, rng=rng)

    question = _build_question("investment")
    # NOTE(review): the reference answers cite only the first three of the four
    # true setup markers (and two of the three decoy setups) -- confirm intended.
    answer_gt = (
        f"PIVOT_ID={pivot_marker}. Anchor position is {pivot_position}. "
        f"Prerequisite market conditions are encoded in {true_setup_markers[0]}, {true_setup_markers[1]}, {true_setup_markers[2]} before {pivot_marker}."
    )
    decoy_answer = (
        f"PIVOT_ID={decoy_marker}. A coherent but wrong narrative centers {decoy_position} and cites "
        f"{decoy_setup_markers[0]}, {decoy_setup_markers[1]} as enabling conditions."
    )

    # Record what the default (0.50) compression actually removed.
    layout["compression_default_drop"] = actual_drop
    layout["removed_markers_default"] = removed_markers

    return MirageBenchTask(
        task_id=task_id,
        category="investment",
        full_context=full_context,
        compressed_context=compressed_context,
        question=question,
        pivot_ground_truth=pivot_marker,
        answer_ground_truth=answer_gt,
        decoy_pivot=decoy_marker,
        decoy_answer=decoy_answer,
        k=k,
        metadata=layout,
    )

def build_narrative_task(task_num: int, k: int = 3, target_words: int = 3600) -> MirageBenchTask:
    """Build one synthetic narrative-analysis task (category "narrative").

    An 80-scene story ledger is generated from an RNG seeded with
    ``3000 + task_num``. The pivot scene has ConsequenceScore 24 and the later
    decoy scene 19; routine scenes score 2-11, setup scenes 1-3, and
    decoy-setup scenes 3-6.

    The stored ``compressed_context`` is rendered with a 0.50 target drop.
    Only the pivot, decoy, and decoy-setup markers are protected from removal.

    Args:
        task_num: Task number; determines the seed and the ``C<nn>`` task id.
        k: Stored on the returned task unchanged; not otherwise used here.
        target_words: Minimum word count of the full context; an appendix is
            generated only if the records alone fall short of it.

    Returns:
        A ``MirageBenchTask`` whose ``metadata`` is the layout dict used for
        rendering, including ``pivot_actor`` and ``decoy_actor``.
    """
    rng = np.random.default_rng(3000 + task_num)
    task_id = f"C{task_num:02d}"
    n_events = 80

    # Indices are 0-based; markers below are 1-based (E001 corresponds to i == 0).
    pivot_idx = int(rng.integers(48, 60))
    decoy_idx = int(min(n_events - 6, pivot_idx + rng.integers(8, 14)))
    true_setup_idx = [pivot_idx - 10, pivot_idx - 7, pivot_idx - 4, pivot_idx - 1]
    decoy_setup_idx = [decoy_idx - 6, decoy_idx - 3, decoy_idx - 1]

    characters = ["Mira", "Jonas", "Elio", "Sana", "Iris", "Cato"]
    places = ["market ward", "canal archive", "north gate", "assembly atrium", "river embankment"]

    records: List[Dict[str, Any]] = []

    for i in range(n_events):
        marker = f"{task_id}-E{i+1:03d}"
        scene = f"Scene-{i+1:02d}"
        actor = characters[i % len(characters)]
        place = places[i % len(places)]

        # A routine score is drawn for every scene (even ones overridden below),
        # which keeps the RNG stream aligned across roles.
        role = "routine"
        consequence = int(rng.integers(2, 12))
        action = "exchanged routine updates"

        if i in true_setup_idx:
            role = "setup"
            consequence = int(rng.integers(1, 4))
            action = "shared mundane logistical details"
        elif i in decoy_setup_idx:
            role = "decoy_setup"
            consequence = int(rng.integers(3, 7))
            action = "staged a visible confrontation"
        elif i == pivot_idx:
            role = "pivot"
            consequence = 24
            action = "released the sealed ledger proving council fraud"
        elif i == decoy_idx:
            role = "candidate"
            consequence = 19
            action = "challenged the council in a public square"

        # Decoy and decoy-setup lines share the "decoy" note vocabulary.
        note_role = "pivot" if role == "pivot" else "decoy" if role in {"candidate", "decoy_setup"} else role
        note = _long_note(rng, note_role if note_role in {"setup", "pivot", "decoy", "routine"} else "routine", "character dynamics")

        line = (
            f"[{marker}] {scene} | Actor={actor} | Location={place} | Action={action} | "
            f"ConsequenceScore={consequence} | {note}"
        )
        records.append({"marker": marker, "role": role, "line": line, "actor": actor, "consequence": consequence})

    pivot_marker = f"{task_id}-E{pivot_idx+1:03d}"
    decoy_marker = f"{task_id}-E{decoy_idx+1:03d}"
    true_setup_markers = [f"{task_id}-E{i+1:03d}" for i in true_setup_idx]
    decoy_setup_markers = [f"{task_id}-E{i+1:03d}" for i in decoy_setup_idx]

    pivot_actor = next(r["actor"] for r in records if r["marker"] == pivot_marker)
    decoy_actor = next(r["actor"] for r in records if r["marker"] == decoy_marker)

    preamble = (
        f"Narrative Consequence Ledger {task_id}\n"
        "This story timeline mixes dramatic beats and mundane setup scenes.\n"
        "Interpretation rule: turning point is the action with highest consequence score across the full story."
    )
    rule_block = "Rule reminder: turning point = argmax ConsequenceScore over all scenes."

    full_context, spans, appendix_text = _render_context(
        preamble=preamble,
        records=records,
        rule_block=rule_block,
        appendix_target_words=target_words,
        rng=rng,
    )

    layout = {
        "preamble": preamble,
        "records": records,
        "rule_block": rule_block,
        "appendix_text": appendix_text,
        "spans": spans,
        # The true setup markers are deliberately absent from the protected list.
        "protected_markers": [pivot_marker, decoy_marker, *decoy_setup_markers],
        "candidate_markers": [pivot_marker, decoy_marker],
        "candidate_requirements": {
            pivot_marker: true_setup_markers,
            decoy_marker: decoy_setup_markers,
        },
        "pivot_setup_markers": true_setup_markers,
        "decoy_setup_markers": decoy_setup_markers,
        "pivot_actor": pivot_actor,
        "decoy_actor": decoy_actor,
    }

    compressed_context, actual_drop, removed_markers = _compress_records_to_target(layout, target_drop_fraction=0.50, rng=rng)

    question = _build_question("narrative")
    # NOTE(review): the reference answers cite only the first three of the four
    # true setup markers (and two of the three decoy setups) -- confirm intended.
    answer_gt = (
        f"PIVOT_ID={pivot_marker}. Turning point is {pivot_marker} by {pivot_actor}. "
        f"Enabling setup beats are {true_setup_markers[0]}, {true_setup_markers[1]}, and {true_setup_markers[2]}."
    )
    decoy_answer = (
        f"PIVOT_ID={decoy_marker}. A plausible but wrong reading centers {decoy_actor}'s action at {decoy_marker}, "
        f"supported by {decoy_setup_markers[0]} and {decoy_setup_markers[1]}."
    )

    # Record what the default (0.50) compression actually removed.
    layout["compression_default_drop"] = actual_drop
    layout["removed_markers_default"] = removed_markers

    return MirageBenchTask(
        task_id=task_id,
        category="narrative",
        full_context=full_context,
        compressed_context=compressed_context,
        question=question,
        pivot_ground_truth=pivot_marker,
        answer_ground_truth=answer_gt,
        decoy_pivot=decoy_marker,
        decoy_answer=decoy_answer,
        k=k,
        metadata=layout,
    )


def _validate_investment_ground_truth(tasks: Sequence[MirageBenchTask]) -> None:
    """Check that each investment task's pivot is the max-cumulative-return record.

    Tasks of other categories are skipped. Records lacking a "cum" value are
    ranked as negative infinity; on ties the earliest record wins.

    Raises:
        RuntimeError: listing every investment task that has no records or
            whose top-"cum" record differs from ``pivot_ground_truth``.
    """

    def _cumulative(record: Dict[str, Any]) -> float:
        return float(record.get("cum", float("-inf")))

    problems: List[str] = []
    for task in (candidate for candidate in tasks if candidate.category == "investment"):
        records = task.metadata.get("records", [])
        if records:
            top_marker = max(records, key=_cumulative)["marker"]
            if top_marker != task.pivot_ground_truth:
                problems.append(
                    f"{task.task_id}: pivot_ground_truth={task.pivot_ground_truth}, max_cum_marker={top_marker}"
                )
        else:
            problems.append(f"{task.task_id}: missing records")

    if problems:
        raise RuntimeError("Investment ground-truth validation failed: " + "; ".join(problems))


def build_miragebench_v01() -> List[MirageBenchTask]:
    """Build the 12-task suite: tasks 1-4 for incident, investment, then narrative.

    Investment ground truth is validated before returning, so a
    ``RuntimeError`` from that check propagates to the caller.
    """
    builders = (build_incident_task, build_investment_task, build_narrative_task)
    # Outer loop over builders keeps the category-by-category task order.
    tasks: List[MirageBenchTask] = [
        build(task_num) for build in builders for task_num in range(1, 5)
    ]
    _validate_investment_ground_truth(tasks)
    return tasks


def render_compressed_variant(task: MirageBenchTask, drop_fraction: float, seed: int = 0) -> str:
    """Render the task's context compressed to ``drop_fraction``.

    The random generator is seeded with ``seed + int(drop_fraction * 1000)``, so
    the same (seed, drop_fraction) pair always uses the same generator state.
    Only the first element returned by ``_compress_records_to_target`` is kept.
    """
    variant_seed = seed + int(drop_fraction * 1000)
    generator = np.random.default_rng(variant_seed)
    text, _unused_second, _unused_third = _compress_records_to_target(
        task.metadata,
        target_drop_fraction=drop_fraction,
        rng=generator,
    )
    return text


In [None]:
# 1.4 Build MirageBench task suite
# Builds the task list once and tabulates whitespace-delimited word counts of
# the full vs. default compressed context for each task.
miragebench_tasks = build_miragebench_v01()

stats_rows = []
for task in miragebench_tasks:
    full_words = len(task.full_context.split())
    comp_words = len(task.compressed_context.split())
    stats_rows.append(
        {
            "task_id": task.task_id,
            "category": task.category,
            "full_words": full_words,
            "compressed_words": comp_words,
            # Percentage of words dropped; max(1, ...) guards an empty full context.
            "drop_pct": round(100 * (1 - comp_words / max(1, full_words)), 2),
            "pivot_gt": task.pivot_ground_truth,
            "decoy": task.decoy_pivot,
        }
    )

TASK_STATS_DF = pd.DataFrame(stats_rows)
# Bare expression so the notebook renders the table as the cell output.
TASK_STATS_DF


In [None]:
# 1.5 Quick verification: counts per category and compression range
# Prints the number of tasks per category and summary statistics of the
# word-level drop percentage, then saves the stats table as CSV under RAW_DIR.
print(TASK_STATS_DF.groupby("category").size())
print()
print("Drop % summary:")
print(TASK_STATS_DF["drop_pct"].describe())

TASK_STATS_DF.to_csv(RAW_DIR / "miragebench_task_stats.csv", index=False)
print("Saved:", RAW_DIR / "miragebench_task_stats.csv")


In [None]:
# 1.6 Optional: inspect one full + compressed pair
# Prints the first task's question, ground-truth and decoy answers, and the
# leading characters of both contexts for a manual sanity check.
preview_task = miragebench_tasks[0]
print("Task:", preview_task.task_id, "Category:", preview_task.category)
print("\n--- Question ---\n")
print(preview_task.question)
print("\n--- Ground truth answer ---\n")
print(preview_task.answer_ground_truth)
print("\n--- Decoy answer ---\n")
print(preview_task.decoy_answer)
print("\n--- Full context excerpt ---\n")
# Only the first 3000 characters of the full context are shown.
print(preview_task.full_context[:3000])
print("\n[...truncated...]\n")
print("\n--- Compressed context excerpt ---\n")
# Only the first 2200 characters of the compressed context are shown.
print(preview_task.compressed_context[:2200])


# 2. Black-Box Evaluation

This section runs each task in two conditions:

- `full_context + question`
- `compressed_context + question`

Metrics:

- **Raw validity** (coherence / answerability)
- **Pivot preservation** (same pivot as full-context answer)
- **Semantic regret** (distance from full answer)


In [None]:
# 2.1 Model adapters (HF local + API placeholders)

# Short model label -> Hugging Face repository id. The label is the name a
# model is reported under; the repository id is what gets loaded.
MODEL_SPECS = {
    "llama-3.1-8b-instruct": "meta-llama/Llama-3.1-8B-Instruct",
    "mistral-7b-instruct-v0.3": "mistralai/Mistral-7B-Instruct-v0.3",
    "qwen2.5-7b-instruct": "Qwen/Qwen2.5-7B-Instruct",
}


def make_prompt(context: str, question: str) -> str:
    """Assemble the evaluation prompt.

    Layout: a fixed instruction line, the stripped context, a ``Question:``
    header followed by the stripped question, and a trailing ``Answer:`` cue.
    """
    instruction = "You are a precise analyst. Follow the scoring rule in the prompt exactly."
    sections = [
        instruction,
        "\n\n",
        context.strip(),
        "\n\nQuestion:\n",
        question.strip(),
        "\n\nAnswer:",
    ]
    return "".join(sections)


def load_hf_generator(model_name: str, max_new_tokens: int = 220):
    """Load a Hugging Face text-generation pipeline and return a ``prompt -> text`` callable.

    The tokenizer's pad token falls back to its EOS token when no pad token is
    defined. Generation is greedy (``do_sample=False``) and only the newly
    generated text is returned, stripped of surrounding whitespace.
    """
    import torch
    from transformers import AutoTokenizer, pipeline

    tok = AutoTokenizer.from_pretrained(model_name, use_fast=True)
    lacks_pad = tok.pad_token_id is None
    if lacks_pad and tok.eos_token_id is not None:
        tok.pad_token = tok.eos_token

    text_pipe = pipeline(
        "text-generation",
        model=model_name,
        tokenizer=tok,
        torch_dtype=torch.float16,
        device_map="auto",
    )

    def _generate(prompt: str) -> str:
        generation_kwargs = dict(
            max_new_tokens=max_new_tokens,
            do_sample=False,
            return_full_text=False,
            pad_token_id=tok.pad_token_id,
            eos_token_id=tok.eos_token_id,
        )
        completions = text_pipe(prompt, **generation_kwargs)
        first = completions[0]
        return first["generated_text"].strip()

    return _generate


def get_openai_generator(model: str = "gpt-4o-mini", max_output_tokens: int = 280):
    """Return a ``prompt -> text`` callable backed by the OpenAI Responses API.

    Raises:
        RuntimeError: if the ``OPENAI_API_KEY`` environment variable is unset or empty.
    """
    from openai import OpenAI

    key = os.getenv("OPENAI_API_KEY", "")
    if not key:
        raise RuntimeError("OPENAI_API_KEY is not set.")

    client = OpenAI(api_key=key)

    def _generate(prompt: str) -> str:
        response = client.responses.create(
            model=model,
            input=prompt,
            max_output_tokens=max_output_tokens,
        )
        text = response.output_text
        # A missing/empty output is reported as an empty string.
        return text.strip() if text else ""

    return _generate


def get_anthropic_generator(model: str = "claude-3-5-sonnet-latest", max_tokens: int = 280):
    """Return a ``prompt -> text`` callable backed by the Anthropic Messages API.

    The prompt is sent as a single user message; the text of every response
    content block that has a ``text`` attribute is concatenated and stripped.

    Raises:
        RuntimeError: if the ``ANTHROPIC_API_KEY`` environment variable is unset or empty.
    """
    import anthropic

    key = os.getenv("ANTHROPIC_API_KEY", "")
    if not key:
        raise RuntimeError("ANTHROPIC_API_KEY is not set.")

    client = anthropic.Anthropic(api_key=key)

    def _generate(prompt: str) -> str:
        reply = client.messages.create(
            model=model,
            max_tokens=max_tokens,
            messages=[{"role": "user", "content": prompt}],
        )
        text_parts = [part.text for part in reply.content if hasattr(part, "text")]
        return "".join(text_parts).strip()

    return _generate


In [None]:
# 2.2 Metrics: pivot extraction, raw validity, semantic regret

# Explicit pivot declaration, e.g. "PIVOT_ID=A01-E003" (whitespace around "=" allowed).
PIVOT_REGEX = re.compile(r"PIVOT_ID\s*=\s*([A-Z]\d{2}-E\d{3})")
# Any event marker: one capital letter, two digits, "-E", three digits.
MARKER_REGEX = re.compile(r"([A-Z]\d{2}-E\d{3})")


def extract_pivot_id(text: str, fallback_candidates: Optional[List[str]] = None) -> str:
    """Extract the pivot marker named in a model answer.

    Resolution order:
      1. the marker of the first explicit ``PIVOT_ID=<marker>`` declaration;
      2. otherwise, the first entry of ``fallback_candidates`` (in candidate
         order) that is mentioned anywhere in the text;
      3. otherwise, the first marker mentioned in the text.

    Returns an empty string for empty input or when no marker is present.
    """
    if not text:
        return ""

    declared = PIVOT_REGEX.search(text)
    if declared is not None:
        return declared.group(1)

    mentioned = MARKER_REGEX.findall(text)
    if not mentioned:
        return ""

    preferred = next(
        (cand for cand in (fallback_candidates or ()) if cand in mentioned),
        None,
    )
    return mentioned[0] if preferred is None else preferred


# Causal connectives matched as whole words. A plain substring check would also
# fire on unrelated words ("failed", "called" and "scheduled" all contain "led"),
# which makes the causal-language signal meaningless.
_CAUSAL_LANGUAGE_REGEX = re.compile(
    r"\b(?:because|led|enabled|prerequisites?|therefore|causal(?:ly|ity)?|preceded)\b",
    re.IGNORECASE,
)


def raw_validity_score(answer: str, task: MirageBenchTask) -> float:
    """Heuristic surface-validity score in ``[0, 1]`` for a model answer.

    The score is a sum of four components:
      * length: 0.35 for >= 45 words, 0.2 for >= 20 words, else 0;
      * pivot: 0.3 if a pivot id can be extracted from the answer;
      * markers: 0.08 per distinct event marker mentioned, capped at 0.25;
      * causal language: 0.1 if a causal connective appears as a whole word.

    It does not check that the named pivot is correct. Empty or
    whitespace-only answers score 0.0.
    """
    if not answer or not answer.strip():
        return 0.0

    word_count = len(answer.split())
    marker_hits = len(set(MARKER_REGEX.findall(answer)))
    has_causal_language = _CAUSAL_LANGUAGE_REGEX.search(answer) is not None

    score = 0.0
    score += 0.35 if word_count >= 45 else (0.2 if word_count >= 20 else 0.0)
    score += 0.3 if extract_pivot_id(answer, [task.pivot_ground_truth, task.decoy_pivot]) else 0.0
    score += min(0.25, 0.08 * marker_hits)
    score += 0.1 if has_causal_language else 0.0
    return float(min(1.0, score))


_semantic_embedder = None
# True once a load has been attempted, so a failed load is not retried (and the
# warning is not re-emitted) on every later call. Re-running this cell resets it.
_semantic_embedder_load_attempted = False


def _get_semantic_embedder():
    """Return the cached sentence embedder, loading it on first use.

    Tries to load ``sentence-transformers/all-MiniLM-L6-v2`` exactly once. If
    the import or load fails, a warning is issued once and ``None`` is returned
    from then on.
    """
    global _semantic_embedder, _semantic_embedder_load_attempted
    if _semantic_embedder is not None or _semantic_embedder_load_attempted:
        return _semantic_embedder

    _semantic_embedder_load_attempted = True
    try:
        from sentence_transformers import SentenceTransformer
        _semantic_embedder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
    except Exception as exc:
        # Best-effort: any import/download/initialisation failure degrades to the fallback.
        warnings.warn(f"SentenceTransformer unavailable; falling back to TF-IDF semantic proxy. Error: {exc}")
        _semantic_embedder = None
    return _semantic_embedder


def semantic_regret(full_answer: str, compressed_answer: str) -> float:
    """Semantic distance in ``[0, 1]`` between the full- and compressed-context answers.

    Both answers are stripped first (``None`` counts as empty). Two empty
    answers give 0.0; exactly one empty answer gives 1.0. Otherwise a similarity
    in ``[-1, 1]`` is computed, from sentence embeddings when the embedder is
    available and from a TF-IDF (1-2 gram) cosine otherwise, and mapped to
    ``1 - (sim + 1) / 2``.
    """
    reference = (full_answer or "").strip()
    candidate = (compressed_answer or "").strip()
    if not (reference or candidate):
        return 0.0
    if not (reference and candidate):
        return 1.0

    texts = [reference, candidate]
    embedder = _get_semantic_embedder()
    if embedder is None:
        tfidf = TfidfVectorizer(min_df=1, ngram_range=(1, 2)).fit_transform(texts).toarray()
        similarity = 1.0 - float(cosine(tfidf[0], tfidf[1]))
        # A non-finite cosine distance is treated as zero similarity.
        if not np.isfinite(similarity):
            similarity = 0.0
    else:
        vectors = embedder.encode(texts, normalize_embeddings=True)
        similarity = float(np.dot(vectors[0], vectors[1]))

    similarity = max(-1.0, min(1.0, similarity))
    regret = 1.0 - ((similarity + 1.0) / 2.0)
    return float(np.clip(regret, 0.0, 1.0))


In [None]:
# 2.3 Evaluation harness

def classify_pivot_outcome(task: MirageBenchTask, full_pivot: str, compressed_pivot: str) -> str:
    """Label how the extracted pivot changed between full and compressed context.

    Returns one of:
      * ``"unresolved"``: either pivot is empty;
      * ``"stable_correct"`` / ``"stable_wrong"``: same pivot in both conditions;
      * ``"true_mirage"``: full-context pivot correct, compressed pivot not;
      * ``"rescue"``: full-context pivot wrong, compressed pivot correct;
      * ``"instability"``: the pivots differ and both are wrong;
      * ``"other"``: fallback for any remaining combination.
    """
    truth = task.pivot_ground_truth
    if not (full_pivot and compressed_pivot):
        return "unresolved"

    full_ok = bool(full_pivot == truth)
    comp_ok = bool(compressed_pivot == truth)

    if full_pivot == compressed_pivot:
        return "stable_correct" if full_ok else "stable_wrong"

    changed_outcomes = {
        (True, False): "true_mirage",
        (False, True): "rescue",
        (False, False): "instability",
    }
    return changed_outcomes.get((full_ok, comp_ok), "other")


def evaluate_task_pair(
    task: MirageBenchTask,
    model_name: str,
    generate_fn: Callable[[str], str],
    compression_level: float,
    full_answer_override: Optional[str] = None,
) -> MirageBenchResult:
    """Evaluate one task under full and compressed context for a single model.

    Args:
        task: the benchmark task to evaluate.
        model_name: label stored on the result.
        generate_fn: callable mapping a prompt to the model's answer text.
        compression_level: drop fraction used to render the compressed context.
        full_answer_override: if given, used as the full-context answer instead
            of calling ``generate_fn`` on the full prompt.

    Returns:
        A ``MirageBenchResult`` with both answers, extracted pivots, validity
        scores, semantic regret, the pivot outcome label and derived flags. The
        mirage/rescue/instability flags are set only when the compressed answer
        also reaches the 0.70 validity threshold.
    """
    pivot_candidates = [task.pivot_ground_truth, task.decoy_pivot]

    prompt_full = make_prompt(task.full_context, task.question)
    prompt_comp = make_prompt(
        render_compressed_variant(task, drop_fraction=compression_level, seed=SEED),
        task.question,
    )

    if full_answer_override is None:
        full_answer = generate_fn(prompt_full)
    else:
        full_answer = full_answer_override
    compressed_answer = generate_fn(prompt_comp)

    full_pivot = extract_pivot_id(full_answer, pivot_candidates)
    comp_pivot = extract_pivot_id(compressed_answer, pivot_candidates)

    validity_full = raw_validity_score(full_answer, task)
    validity_comp = raw_validity_score(compressed_answer, task)
    outcome = classify_pivot_outcome(task, full_pivot, comp_pivot)
    is_high_validity = validity_comp >= 0.70
    same_pivot = bool(full_pivot and comp_pivot and full_pivot == comp_pivot)

    return MirageBenchResult(
        task_id=task.task_id,
        model_name=model_name,
        full_answer=full_answer,
        compressed_answer=compressed_answer,
        raw_validity=validity_comp,
        pivot_preserved=same_pivot,
        semantic_regret=semantic_regret(full_answer, compressed_answer),
        compression_level=float(compression_level),
        category=task.category,
        full_pivot=full_pivot,
        compressed_pivot=comp_pivot,
        raw_validity_full=validity_full,
        raw_validity_compressed=validity_comp,
        full_pivot_matches_ground_truth=bool(full_pivot == task.pivot_ground_truth),
        compressed_pivot_matches_ground_truth=bool(comp_pivot == task.pivot_ground_truth),
        pivot_outcome=outcome,
        high_validity_flag=bool(is_high_validity),
        true_mirage_flag=bool((outcome == "true_mirage") and is_high_validity),
        rescue_flag=bool((outcome == "rescue") and is_high_validity),
        instability_flag=bool((outcome == "instability") and is_high_validity),
    )


def run_blackbox_eval(
    tasks: List[MirageBenchTask],
    model_generators: Dict[str, Callable[[str], str]],
    compression_levels: Sequence[float] = (0.4, 0.5, 0.6),
    max_tasks_per_model: Optional[int] = None,
) -> pd.DataFrame:
    """Run every model on every task at every compression level.

    The full-context answer is generated once per (task, model) and reused for
    all compression levels. Failures never abort the sweep: a failed
    full-context generation yields one error row per compression level, and a
    failure at a single level yields one error row for that level. Error rows
    carry only the identifying fields and an ``"error"`` message.

    Args:
        tasks: tasks to evaluate.
        model_generators: model label -> ``prompt -> answer`` callable.
        compression_levels: drop fractions to evaluate.
        max_tasks_per_model: if set, only the first N tasks are used.

    Returns:
        One DataFrame row per (model, task, compression level).
    """

    def _error_row(task, model_name, level, exc) -> Dict[str, Any]:
        return {
            "task_id": task.task_id,
            "model_name": model_name,
            "compression_level": level,
            "category": task.category,
            "error": str(exc),
        }

    selected_tasks = tasks if max_tasks_per_model is None else tasks[:max_tasks_per_model]
    records: List[Dict[str, Any]] = []

    for model_name, gen_fn in model_generators.items():
        for task in tqdm(selected_tasks, desc=f"Evaluating {model_name}"):
            candidates = [task.pivot_ground_truth, task.decoy_pivot]
            full_prompt = make_prompt(task.full_context, task.question)

            try:
                # One full-context generation per (task, model), shared by all levels.
                full_answer = gen_fn(full_prompt)
                full_pivot = extract_pivot_id(full_answer, candidates)
                raw_validity_full = raw_validity_score(full_answer, task)
            except Exception as exc:
                for level in compression_levels:
                    records.append(_error_row(task, model_name, level, exc))
                continue

            for level in compression_levels:
                try:
                    comp_prompt = make_prompt(
                        render_compressed_variant(task, drop_fraction=level, seed=SEED),
                        task.question,
                    )
                    compressed_answer = gen_fn(comp_prompt)
                    compressed_pivot = extract_pivot_id(compressed_answer, candidates)
                    raw_validity_compressed = raw_validity_score(compressed_answer, task)

                    pivot_outcome = classify_pivot_outcome(task, full_pivot, compressed_pivot)
                    high_validity = int(raw_validity_compressed >= 0.70)
                    same_pivot = bool(full_pivot and compressed_pivot and full_pivot == compressed_pivot)

                    record = {
                        "task_id": task.task_id,
                        "model_name": model_name,
                        "full_answer": full_answer,
                        "compressed_answer": compressed_answer,
                        "raw_validity": raw_validity_compressed,
                        "raw_validity_full": raw_validity_full,
                        "raw_validity_compressed": raw_validity_compressed,
                        "pivot_preserved": int(same_pivot),
                        "semantic_regret": semantic_regret(full_answer, compressed_answer),
                        "compression_level": float(level),
                        "category": task.category,
                        "full_pivot": full_pivot,
                        "compressed_pivot": compressed_pivot,
                        "full_pivot_matches_ground_truth": int(full_pivot == task.pivot_ground_truth),
                        "pivot_matches_ground_truth": int(compressed_pivot == task.pivot_ground_truth),
                        "pivot_outcome": pivot_outcome,
                        "high_validity_flag": high_validity,
                        "true_mirage_flag": int((pivot_outcome == "true_mirage") and high_validity),
                        "rescue_flag": int((pivot_outcome == "rescue") and high_validity),
                        "instability_flag": int((pivot_outcome == "instability") and high_validity),
                    }
                    # Backward-compatible legacy field.
                    record["mirage_flag"] = record["true_mirage_flag"]
                    records.append(record)
                except Exception as exc:
                    records.append(_error_row(task, model_name, level, exc))

    return pd.DataFrame(records)


In [None]:
# 2.4 Configure models and run black-box experiments
# NOTE: Loading all open models can take significant VRAM/time.
# Start with 1 model + 3 tasks, then scale up.

# Each setting keeps a value already defined in the notebook globals and only
# falls back to the default shown here when the name is not defined yet.
RUN_HF_MODELS = bool(globals().get("RUN_HF_MODELS", False))
RUN_API_MODELS = bool(globals().get("RUN_API_MODELS", False))
MAX_TASKS_PER_MODEL = globals().get("MAX_TASKS_PER_MODEL", 3)
COMPRESSION_LEVELS = list(globals().get("COMPRESSION_LEVELS", [0.4, 0.5, 0.6]))
HF_MODELS_TO_RUN = list(globals().get("HF_MODELS_TO_RUN", MODEL_SPECS.keys()))

# Restrict MODEL_SPECS to the requested short names.
selected_model_specs = {
    short_name: repo_name
    for short_name, repo_name in MODEL_SPECS.items()
    if short_name in HF_MODELS_TO_RUN
}
# If no requested name matches a known spec, fall back to every model.
if not selected_model_specs:
    selected_model_specs = MODEL_SPECS.copy()

# Model label -> prompt-to-answer callable, filled in below.
model_generators: Dict[str, Callable[[str], str]] = {}

if RUN_HF_MODELS:
    for short_name, repo_name in selected_model_specs.items():
        # A model that fails to load is skipped; the remaining models still run.
        try:
            print(f"Loading HF model: {repo_name}")
            model_generators[short_name] = load_hf_generator(repo_name)
        except Exception as exc:
            print(f"Skipping {repo_name}: {exc}")

if RUN_API_MODELS:
    # Each API backend is set up independently, so one failing does not block the other.
    try:
        model_generators["gpt-4o-mini"] = get_openai_generator(model="gpt-4o-mini")
    except Exception as exc:
        print("OpenAI setup not ready:", exc)

    try:
        model_generators["claude-sonnet"] = get_anthropic_generator(model="claude-3-5-sonnet-latest")
    except Exception as exc:
        print("Anthropic setup not ready:", exc)

if model_generators:
    blackbox_results_df = run_blackbox_eval(
        tasks=miragebench_tasks,
        model_generators=model_generators,
        compression_levels=COMPRESSION_LEVELS,
        max_tasks_per_model=MAX_TASKS_PER_MODEL,
    )
    blackbox_results_df.to_csv(RAW_DIR / "miragebench_blackbox_results.csv", index=False)
    print("Saved:", RAW_DIR / "miragebench_blackbox_results.csv")
else:
    # Keep the name defined (as an empty frame) so later cells can test `.empty`.
    blackbox_results_df = pd.DataFrame()
    print("No model generators configured. Set RUN_HF_MODELS or RUN_API_MODELS to True.")

blackbox_results_df.head()


In [None]:
# 2.5 Aggregate metrics and mirage plot
# Averages the per-row metrics by (model, category, compression level), saves the
# summary table, and plots raw validity against pivot preservation per model.
if blackbox_results_df.empty:
    print("No black-box results yet.")
else:
    # Keep only rows from successful evaluations. Failed runs carry a non-null
    # "error" value and none of the metric columns. This must be a row filter
    # on a boolean Series: a column-membership mask used as a row indexer raises
    # whenever the number of rows differs from the number of columns.
    usable = blackbox_results_df.copy()
    if "error" in usable.columns:
        # .copy() so the numeric coercion below writes to an independent frame.
        usable = usable[usable["error"].isna()].copy()

    # Columns the aggregation needs; they are absent when every evaluation errored.
    agg_inputs = ["raw_validity", "pivot_preserved", "semantic_regret", "mirage_flag"]
    missing_cols = [c for c in agg_inputs if c not in usable.columns]

    if usable.empty or missing_cols:
        print("No successful black-box rows to aggregate (all evaluations errored).")
    else:
        numeric_cols = ["raw_validity", "semantic_regret", "pivot_preserved", "pivot_matches_ground_truth", "mirage_flag"]
        for c in numeric_cols:
            if c in usable.columns:
                usable[c] = pd.to_numeric(usable[c], errors="coerce")

        summary = (
            usable.groupby(["model_name", "category", "compression_level"], as_index=False)
            .agg(
                raw_validity=("raw_validity", "mean"),
                pivot_preservation=("pivot_preserved", "mean"),
                semantic_regret=("semantic_regret", "mean"),
                mirage_rate=("mirage_flag", "mean"),
                n=("task_id", "count"),
            )
            .sort_values(["model_name", "category", "compression_level"])
        )

        display(summary)
        summary.to_csv(RAW_DIR / "miragebench_blackbox_summary.csv", index=False)

        # Mirage plot: raw validity vs pivot preservation across compression levels.
        fig, ax = plt.subplots(figsize=(9, 6))
        for model_name, sub in summary.groupby("model_name"):
            # One point per compression level: unweighted mean over the category-level means.
            xy = sub.groupby("compression_level", as_index=False).agg(
                raw_validity=("raw_validity", "mean"),
                pivot_preservation=("pivot_preservation", "mean"),
            )
            ax.plot(xy["raw_validity"], xy["pivot_preservation"], marker="o", label=model_name)
            for _, row in xy.iterrows():
                ax.text(row["raw_validity"] + 0.003, row["pivot_preservation"] + 0.003, f"c={row['compression_level']:.1f}", fontsize=8)

        ax.set_xlabel("Raw validity (higher is better)")
        ax.set_ylabel("Pivot preservation (higher is better)")
        ax.set_title("Mirage Plot: Validity vs Pivot Preservation")
        ax.set_xlim(0, 1.02)
        ax.set_ylim(0, 1.02)
        ax.legend(loc="lower left")
        fig.tight_layout()

        mirage_plot_path = FIG_DIR / "mirage_plot_blackbox.png"
        fig.savefig(mirage_plot_path)
        plt.show()
        print("Saved:", mirage_plot_path)


In [None]:
# 2.6 Flag high-validity / low-pivot-preservation mirage cases
# Lists successful rows whose compressed answer scores >= 0.75 raw validity while
# the pivot was not preserved, and saves them as CSV.
if blackbox_results_df.empty:
    print("No black-box results yet.")
else:
    tmp = blackbox_results_df.copy()
    if "error" in tmp.columns:
        # Drop failed runs; .copy() so the column assignments below write to an
        # independent frame rather than a filtered slice.
        tmp = tmp[tmp["error"].isna()].copy()

    tmp["raw_validity"] = pd.to_numeric(tmp.get("raw_validity"), errors="coerce")
    tmp["pivot_preserved"] = pd.to_numeric(tmp.get("pivot_preserved"), errors="coerce")

    mirage_cases = tmp[(tmp["raw_validity"] >= 0.75) & (tmp["pivot_preserved"] < 0.5)].copy()
    mirage_cases = mirage_cases.sort_values(["model_name", "compression_level", "raw_validity"], ascending=[True, True, False])

    print(f"Mirage cases found: {len(mirage_cases)}")
    display_cols = [
        "task_id", "category", "model_name", "compression_level", "raw_validity", "pivot_preserved", "semantic_regret", "full_pivot", "compressed_pivot"
    ]
    # reindex tolerates metric columns that are absent when every evaluation
    # errored (they show up as empty columns instead of raising KeyError).
    display(mirage_cases.reindex(columns=display_cols).head(30))

    mirage_cases.to_csv(RAW_DIR / "mirage_cases.csv", index=False)
    print("Saved:", RAW_DIR / "mirage_cases.csv")


# 3. KV-Cache Surgery (GPU)

Goal: probe the neural mechanism behind pivot substitution.

For a selected task and model:

1. Build prefix KV-cache from full context.
2. Identify token spans for pivot/setup events.
3. Apply eviction strategies:
   - random
   - attention-based (H2O-style)
   - setup-targeted
   - contract-guarded
4. Regenerate and compare coherence, pivot preservation, and semantic regret.
5. Plot attention redistribution and per-layer pivot tracking.


In [None]:
# 3.1 Load model for KV experiments
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# KV-experiment settings. Each keeps a value already defined in the notebook
# globals and only falls back to the default shown here when the name is unset.
KV_MODEL_NAME = str(globals().get("KV_MODEL_NAME", "meta-llama/Llama-3.1-8B-Instruct"))
KV_MAX_NEW_TOKENS = int(globals().get("KV_MAX_NEW_TOKENS", 160))
AUTO_LOAD_KV_MODEL = bool(globals().get("AUTO_LOAD_KV_MODEL", False))

# Placeholders; they stay None unless the model is loaded.
kv_tokenizer = None
kv_model = None


def load_kv_model(model_name: str = KV_MODEL_NAME):
    """Load the tokenizer and causal LM used for the KV-cache experiments.

    The tokenizer's pad token falls back to its EOS token when no pad token is
    defined. The model is loaded in float16 with automatic device placement and
    the eager attention implementation, then switched to eval mode.

    Returns:
        ``(tokenizer, model)``.
    """
    tok = AutoTokenizer.from_pretrained(model_name, use_fast=True)
    lacks_pad = tok.pad_token_id is None
    if lacks_pad and tok.eos_token_id is not None:
        tok.pad_token = tok.eos_token

    load_kwargs = dict(
        torch_dtype=torch.float16,
        device_map="auto",
        attn_implementation="eager",  # explicit for attention/KV introspection
    )
    lm = AutoModelForCausalLM.from_pretrained(model_name, **load_kwargs)
    lm.eval()
    return tok, lm


# Loading is opt-in: the model is only loaded when AUTO_LOAD_KV_MODEL is true;
# otherwise kv_tokenizer / kv_model keep their current values.
if AUTO_LOAD_KV_MODEL:
    kv_tokenizer, kv_model = load_kv_model(KV_MODEL_NAME)
    print("Loaded:", KV_MODEL_NAME)
else:
    print("Set AUTO_LOAD_KV_MODEL=True in the control panel to auto-load KV model.")


In [None]:
# 3.2 KV helpers: token mapping, cache pruning, generation from cached prefix

def to_legacy_past(past_key_values):
    """Return the cache in legacy form.

    Objects exposing ``to_legacy_cache()`` are converted through it; anything
    else is returned unchanged.
    """
    if not hasattr(past_key_values, "to_legacy_cache"):
        return past_key_values
    return past_key_values.to_legacy_cache()


def make_kv_prompt(task: MirageBenchTask, context: str) -> str:
    """Build the prompt for ``context`` and the task's question via ``make_prompt``."""
    return make_prompt(context=context, question=task.question)


def _tokenize_with_offsets(tokenizer, text: str):
    """Tokenize ``text`` without special tokens and return ids plus offsets.

    Returns:
        ``(input_ids, offsets)`` where ``input_ids`` is the tokenizer's
        ``"input_ids"`` entry and ``offsets`` is the first row of its
        ``"offset_mapping"`` entry converted to a plain list.
    """
    encoding = tokenizer(
        text,
        return_tensors="pt",
        return_offsets_mapping=True,
        add_special_tokens=False,
    )
    first_row_offsets = encoding["offset_mapping"][0]
    return encoding["input_ids"], first_row_offsets.tolist()


def spans_to_token_positions(offsets: List[Tuple[int, int]], spans: List[Dict[str, Any]], max_token: Optional[int] = None):
    """Map each character span to the token indices that overlap it.

    Args:
        offsets: per-token ``(start, end)`` character offsets.
        spans: dicts with ``"marker"``, ``"start"``, ``"end"`` and an optional
            ``"role"`` (defaults to ``"routine"``).
        max_token: if given, only token indices below it are considered.

    Returns:
        ``(event_token_map, role_map)``: marker -> overlapping token indices and
        marker -> role. Spans with no overlapping token appear in neither map.
    """
    limit = len(offsets) if max_token is None else min(len(offsets), max_token)

    token_hits: Dict[str, List[int]] = {}
    roles: Dict[str, str] = {}

    for span in spans:
        marker = span["marker"]
        span_start, span_end = int(span["start"]), int(span["end"])
        # Half-open interval overlap: the token ends after the span starts and
        # starts before the span ends.
        overlapping = [
            idx
            for idx, (tok_start, tok_end) in zip(range(limit), offsets)
            if tok_end > span_start and tok_start < span_end
        ]
        if overlapping:
            token_hits[marker] = overlapping
            roles[marker] = span.get("role", "routine")

    return token_hits, roles


def prune_past_key_values(past_key_values, keep_positions: List[int]):
    """Keep only the selected positions (dimension 2) of every layer's key/value tensors.

    Negative and duplicate positions are discarded and the rest are applied in
    ascending order. Any extra per-layer entries after the key and value are
    passed through unchanged.

    Raises:
        ValueError: if no non-negative position remains.
    """
    kept = sorted({int(pos) for pos in keep_positions if pos >= 0})
    if not kept:
        raise ValueError("Cannot prune to empty cache; keep_positions is empty.")

    def _prune_layer(layer):
        keys, values = layer[0], layer[1]
        # Build the index on each layer's own device.
        selector = torch.tensor(kept, dtype=torch.long, device=keys.device)
        return (
            keys.index_select(2, selector),
            values.index_select(2, selector),
            *layer[2:],
        )

    return tuple(_prune_layer(layer) for layer in past_key_values)


def decode_from_prefix_cache(
    model,
    tokenizer,
    prefix_past,
    last_token_id,
    max_new_tokens: int = 160,
):
    """Greedy-decode a continuation from a prefix KV cache and the last prompt token.

    Args:
        model: callable accepting ``input_ids``, ``past_key_values``,
            ``use_cache`` and ``output_attentions`` and returning an object with
            ``logits``, ``past_key_values`` and ``attentions``.
        tokenizer: provides ``eos_token_id`` and ``decode``.
        prefix_past: cache covering the prompt except its last token.
        last_token_id: the last prompt token, fed as the first decoding input.
        max_new_tokens: maximum number of tokens to generate. At least one step
            is always taken so first-step attentions can be captured.

    Returns:
        ``(text, first_step_attn)``: the decoded, stripped continuation and the
        stacked per-layer attention row from the first step (``None`` if the
        model returned no attentions).
    """
    past = prefix_past
    input_token = last_token_id
    eos = tokenizer.eos_token_id

    generated_ids: List[int] = []
    first_step_attn = None

    num_steps = max(1, max_new_tokens)

    with torch.no_grad():
        for step in range(num_steps):
            is_first_step = step == 0
            out = model(
                input_ids=input_token,
                past_key_values=past,
                use_cache=True,
                # Attentions are only needed for the first step's layer-wise analysis.
                output_attentions=is_first_step,
            )
            past = to_legacy_past(out.past_key_values)

            if is_first_step and out.attentions is not None:
                # Presumably each tensor is (batch, heads, query, key) - TODO confirm.
                # Takes batch item 0 / query position 0 and averages over dim 0
                # of what remains, giving one row per layer.
                first_step_attn = torch.stack([
                    att[0, :, 0, :].mean(dim=0).detach().float().cpu()
                    for att in out.attentions
                ])

            next_token = torch.argmax(out.logits[:, -1, :], dim=-1, keepdim=True)
            tid = int(next_token.item())
            generated_ids.append(tid)
            input_token = next_token

            # Checked on every step, including the first, so decoding never
            # continues past an immediate EOS.
            if eos is not None and tid == eos:
                break

    text = tokenizer.decode(generated_ids, skip_special_tokens=True)
    return text.strip(), first_step_attn


def run_prompt_with_cache(model, tokenizer, prompt: str, max_new_tokens: int = KV_MAX_NEW_TOKENS):
    """Run a prompt in two stages: cache the prefix, then decode from the last token.

    The prompt is tokenized without special tokens. All tokens but the last are
    run once to build the prefix cache, and decoding then starts from the last
    token.

    Returns:
        A dict with the prompt, its token ids and character offsets, the prefix
        length and cache, the last prompt token, the generated answer and the
        first-step attentions.

    Raises:
        ValueError: if the prompt tokenizes to fewer than two tokens.
    """
    token_ids, char_offsets = _tokenize_with_offsets(tokenizer, prompt)
    token_ids = token_ids.to(model.device)

    if token_ids.shape[1] < 2:
        raise ValueError("Prompt too short for cache split.")

    prefix_ids, last_token = token_ids[:, :-1], token_ids[:, -1:]

    with torch.no_grad():
        prefix_out = model(prefix_ids, use_cache=True, output_attentions=False)
    prefix_past = to_legacy_past(prefix_out.past_key_values)

    answer, first_step_attn = decode_from_prefix_cache(
        model=model,
        tokenizer=tokenizer,
        prefix_past=prefix_past,
        last_token_id=last_token,
        max_new_tokens=max_new_tokens,
    )

    result = {
        "prompt": prompt,
        "input_ids": token_ids,
        "offsets": char_offsets,
        "prefix_len": int(prefix_ids.shape[1]),
        "prefix_past": prefix_past,
        "last_token": last_token,
        "answer": answer,
        "first_step_attn": first_step_attn,
    }
    return result


In [None]:
# 3.3 Eviction strategies (random, attention-H2O, setup-targeted, contract-guarded)

def _marker_scores_from_attention(event_token_map: Dict[str, List[int]], token_importance: np.ndarray) -> Dict[str, float]:
    """Average the token importance over each marker's token positions.

    Positions at or beyond the end of ``token_importance`` are ignored; a
    marker with no usable position scores 0.0.
    """
    n_tokens = len(token_importance)

    def _mean_importance(positions: List[int]) -> float:
        in_range = [token_importance[p] for p in positions if p < n_tokens]
        return float(np.mean(in_range)) if in_range else 0.0

    return {
        marker: _mean_importance(positions)
        for marker, positions in event_token_map.items()
    }


def _candidate_event_order(
    markers: List[str],
    role_map: Dict[str, str],
    marker_scores: Dict[str, float],
    strategy: str,
    rng: np.random.Generator,
) -> List[str]:
    """Return ``markers`` in the order they should be considered for eviction.

    Strategies:
      * ``"random"``: a shuffled copy (the input list is not modified);
      * ``"attention_h2o"``: ascending attention score (lowest evicted first);
      * ``"setup_targeted"`` / ``"contract_guarded"``: by role rank (setup,
        routine, support, decoy_setup, candidate, pivot), then ascending score.
        Markers without a role count as ``"routine"``; unknown roles rank with
        ``"support"``.

    Raises:
        ValueError: for any other strategy name.
    """

    def _score(marker: str) -> float:
        return marker_scores.get(marker, 0.0)

    if strategy == "random":
        shuffled = markers[:]
        rng.shuffle(shuffled)
        return shuffled

    if strategy == "attention_h2o":
        # H2O-like: evict lowest attention first.
        return sorted(markers, key=_score)

    if strategy in ("setup_targeted", "contract_guarded"):
        role_rank = {"setup": 0, "routine": 1, "support": 2, "decoy_setup": 3, "candidate": 4, "pivot": 5}

        def _role_then_score(marker: str):
            role = role_map.get(marker, "routine")
            return (role_rank.get(role, 2), _score(marker))

        return sorted(markers, key=_role_then_score)

    raise ValueError(f"Unknown strategy: {strategy}")


def choose_eviction_markers(
    task: MirageBenchTask,
    event_token_map: Dict[str, List[int]],
    role_map: Dict[str, str],
    token_importance: np.ndarray,
    strategy: str,
    eviction_ratio: float,
    seed: int = 0,
) -> Tuple[List[str], int]:
    """Pick event markers to evict until a token budget is reached.

    Markers listed in ``task.metadata["candidate_markers"]`` are never evicted.
    The budget is ``eviction_ratio`` of the tokens covered by the remaining
    markers; markers are taken in the strategy's order until the budget is met.
    Under ``"contract_guarded"`` a marker is skipped when removing it would
    leave any candidate in ``task.metadata["candidate_requirements"]`` with
    fewer than ``task.k`` surviving required markers.

    Returns:
        ``(removed_markers, target_remove)``: the chosen markers in eviction
        order and the token budget that was targeted.
    """
    generator = np.random.default_rng(seed)

    protected = set(task.metadata.get("candidate_markers", []))
    removable = [m for m in list(event_token_map.keys()) if m not in protected]
    marker_scores = _marker_scores_from_attention(event_token_map, token_importance)

    removable_tokens = sum(len(event_token_map[m]) for m in removable)
    target_remove = int(round(removable_tokens * eviction_ratio))

    ordered = _candidate_event_order(removable, role_map, marker_scores, strategy=strategy, rng=generator)

    requirements: Dict[str, List[str]] = task.metadata.get("candidate_requirements", {})
    k = int(task.k)
    guarded = strategy == "contract_guarded"

    removed: List[str] = []
    removed_lookup = set()
    removed_tokens = 0

    def _breaks_contract(marker: str) -> bool:
        # Contract: removing marker cannot reduce surviving setup markers below k for any candidate pivot.
        for req_markers in requirements.values():
            required = set(req_markers)
            if marker in required and (len(required - removed_lookup) - 1) < k:
                return True
        return False

    for marker in ordered:
        if removed_tokens >= target_remove:
            break
        if guarded and _breaks_contract(marker):
            continue
        removed.append(marker)
        removed_lookup.add(marker)
        removed_tokens += len(event_token_map[marker])

    return removed, target_remove


def markers_to_token_positions(markers: List[str], event_token_map: Dict[str, List[int]], max_prefix_len: int) -> List[int]:
    """Collect the sorted, de-duplicated token positions of the given markers.

    Markers missing from ``event_token_map`` contribute nothing, and positions
    at or beyond ``max_prefix_len`` are dropped.
    """
    selected = {
        int(pos)
        for marker in markers
        for pos in event_token_map.get(marker, [])
        if pos < max_prefix_len
    }
    return sorted(selected)


def kv_strategy_run(
    task: MirageBenchTask,
    model,
    tokenizer,
    strategy: str,
    eviction_ratio: float = 0.5,
    seed: int = 0,
) -> Dict[str, Any]:
    """Evict context events from the prompt KV cache and compare pruned vs. full decoding.

    The full-context prompt is run once to obtain the baseline answer, the prefix
    cache and the first-step attention. Event markers are then selected for eviction
    by ``choose_eviction_markers``, their token positions are dropped from the cache,
    and decoding is repeated from the pruned cache.

    Args:
        task: Task providing the context, spans metadata and ground-truth/decoy pivots.
        model: Model passed through to the cache/decoding helpers.
        tokenizer: Tokenizer matching ``model``.
        strategy: Eviction strategy name forwarded to ``choose_eviction_markers``.
        eviction_ratio: Fraction of removable event tokens targeted for eviction.
        seed: Seed forwarded to ``choose_eviction_markers``.

    Returns:
        A dict with the answers, extracted pivots, ``pivot_preserved`` flag, validity and
        regret scores, both first-step attention tensors, the event token map and the
        context length in tokens.

    Raises:
        RuntimeError: If the baseline run returned no attentions, or if fewer than 8
            cache positions would remain after eviction.
    """
    full_prompt = make_kv_prompt(task, task.full_context)
    base = run_prompt_with_cache(model, tokenizer, full_prompt, max_new_tokens=KV_MAX_NEW_TOKENS)

    # Token map for context events (use only context token range).
    context_ids = tokenizer(task.full_context, add_special_tokens=False, return_tensors="pt")["input_ids"]
    context_len = int(context_ids.shape[1])

    spans = task.metadata.get("spans", [])
    event_token_map, role_map = spans_to_token_positions(base["offsets"], spans, max_token=context_len)

    if base["first_step_attn"] is None:
        raise RuntimeError("No attentions returned; make sure model supports output_attentions.")

    # Average over the leading axis to get one importance score per token, then keep
    # only the scores that fall inside the context token range.
    token_importance = base["first_step_attn"].mean(dim=0).numpy()
    token_importance = token_importance[:context_len]

    removed_markers, target_remove = choose_eviction_markers(
        task=task,
        event_token_map=event_token_map,
        role_map=role_map,
        token_importance=token_importance,
        strategy=strategy,
        eviction_ratio=eviction_ratio,
        seed=seed,
    )

    remove_positions = markers_to_token_positions(removed_markers, event_token_map, max_prefix_len=base["prefix_len"])
    # Build the membership set once; rebuilding it per index made this step quadratic.
    removed_lookup = set(remove_positions)
    keep_positions = [i for i in range(base["prefix_len"]) if i not in removed_lookup]

    if len(keep_positions) < 8:
        raise RuntimeError("Too many tokens removed; insufficient cache length remains.")

    pruned_past = prune_past_key_values(base["prefix_past"], keep_positions)

    pruned_answer, pruned_attn = decode_from_prefix_cache(
        model=model,
        tokenizer=tokenizer,
        prefix_past=pruned_past,
        last_token_id=base["last_token"],
        max_new_tokens=KV_MAX_NEW_TOKENS,
    )

    # Both answers are matched against the same two candidates: true pivot and decoy.
    full_pivot = extract_pivot_id(base["answer"], [task.pivot_ground_truth, task.decoy_pivot])
    pruned_pivot = extract_pivot_id(pruned_answer, [task.pivot_ground_truth, task.decoy_pivot])

    return {
        "task_id": task.task_id,
        "strategy": strategy,
        "eviction_ratio": eviction_ratio,
        "target_remove_tokens": target_remove,
        "removed_markers": removed_markers,
        "removed_tokens": len(remove_positions),
        "full_answer": base["answer"],
        "pruned_answer": pruned_answer,
        "full_pivot": full_pivot,
        "pruned_pivot": pruned_pivot,
        # Preserved only when both pivots were extracted and they agree.
        "pivot_preserved": bool(full_pivot and pruned_pivot and full_pivot == pruned_pivot),
        "raw_validity": raw_validity_score(pruned_answer, task),
        "semantic_regret": semantic_regret(base["answer"], pruned_answer),
        "full_first_step_attn": base["first_step_attn"],
        "pruned_first_step_attn": pruned_attn,
        "event_token_map": event_token_map,
        "context_len": context_len,
    }


In [None]:
# 3.4 Run KV-cache surgery on one task
# Control flags keep any value already present in the notebook globals; otherwise the
# defaults given here apply.
RUN_KV_EXPERIMENT = bool(globals().get("RUN_KV_EXPERIMENT", False))
KV_TASK_INDEX = int(globals().get("KV_TASK_INDEX", 0))
KV_EVICTION_RATIO = float(globals().get("KV_EVICTION_RATIO", 0.5))

# Full per-strategy outputs (including attention tensors) are kept here for later cells.
kv_strategy_results = []

if RUN_KV_EXPERIMENT:
    # Load the model/tokenizer only if they have not been loaded yet.
    if kv_model is None or kv_tokenizer is None:
        kv_tokenizer, kv_model = load_kv_model(KV_MODEL_NAME)

    task = miragebench_tasks[KV_TASK_INDEX]
    strategies = ["random", "attention_h2o", "setup_targeted", "contract_guarded"]

    for strat in strategies:
        print(f"Running strategy: {strat}")
        # A failing strategy is reported and skipped so the remaining ones still run.
        try:
            out = kv_strategy_run(
                task=task,
                model=kv_model,
                tokenizer=kv_tokenizer,
                strategy=strat,
                eviction_ratio=KV_EVICTION_RATIO,
                seed=SEED,
            )
            kv_strategy_results.append(out)
            print(
                f"  pivot_preserved={out['pivot_preserved']} | raw_validity={out['raw_validity']:.3f} | "
                f"semantic_regret={out['semantic_regret']:.3f}"
            )
        except Exception as exc:
            print(f"  Failed for {strat}: {exc}")

    # Tabulate only the scalar/text fields; tensors and token maps stay in kv_strategy_results.
    kv_results_df = pd.DataFrame(
        [
            {
                "task_id": r["task_id"],
                "strategy": r["strategy"],
                "eviction_ratio": r["eviction_ratio"],
                "removed_tokens": r["removed_tokens"],
                "pivot_preserved": r["pivot_preserved"],
                "raw_validity": r["raw_validity"],
                "semantic_regret": r["semantic_regret"],
                "full_pivot": r["full_pivot"],
                "pruned_pivot": r["pruned_pivot"],
            }
            for r in kv_strategy_results
        ]
    )
    kv_results_df.to_csv(RAW_DIR / "kv_surgery_results_single_task.csv", index=False)
    display(kv_results_df)
else:
    # Keep kv_results_df defined even when the experiment is skipped.
    kv_results_df = pd.DataFrame()
    print("Set RUN_KV_EXPERIMENT=True in the control panel to execute KV-cache surgery.")


In [None]:
# 3.5 Attention visualization + per-layer pivot tracking

def _downsample_vector(v: np.ndarray, bins: int = 140) -> np.ndarray:
    if len(v) <= bins:
        return v
    edges = np.linspace(0, len(v), bins + 1).astype(int)
    out = []
    for i in range(bins):
        s, e = edges[i], edges[i + 1]
        if e <= s:
            out.append(0.0)
        else:
            out.append(float(v[s:e].mean()))
    return np.array(out)


def _layer_group_attention(first_step_attn: torch.Tensor, positions: List[int]) -> np.ndarray:
    if first_step_attn is None:
        return np.array([])
    if not positions:
        return np.zeros(first_step_attn.shape[0], dtype=float)
    arr = first_step_attn.numpy()
    pos = [p for p in positions if p < arr.shape[1]]
    if not pos:
        return np.zeros(arr.shape[0], dtype=float)
    return arr[:, pos].sum(axis=1)


# Two figures are produced when KV results exist: a two-row attention heatmap
# (full vs. setup-targeted eviction) and a per-layer pivot/setup attention plot
# (full vs. compressed context).
if RUN_KV_EXPERIMENT and kv_strategy_results:
    # Compare baseline vs setup-targeted attention redistribution.
    # Either lookup falls back to the first available result if its strategy is absent.
    baseline = next((r for r in kv_strategy_results if r["strategy"] == "attention_h2o"), kv_strategy_results[0])
    setup_targeted = next((r for r in kv_strategy_results if r["strategy"] == "setup_targeted"), kv_strategy_results[0])

    # Collapse the leading axis so each vector holds one value per attended position.
    full_vec = baseline["full_first_step_attn"].mean(dim=0).numpy()
    if setup_targeted["pruned_first_step_attn"] is not None:
        pruned_vec = setup_targeted["pruned_first_step_attn"].mean(dim=0).numpy()
    else:
        pruned_vec = np.zeros_like(full_vec)

    # NOTE(review): _downsample_vector returns inputs of length <= bins unchanged, so if
    # either vector has at most 160 entries and the two lengths differ, np.vstack below
    # will raise — confirm prompts are always long enough for both to be downsampled.
    full_ds = _downsample_vector(full_vec, bins=160)
    pruned_ds = _downsample_vector(pruned_vec, bins=160)

    heat = np.vstack([full_ds, pruned_ds])

    fig, ax = plt.subplots(figsize=(12, 2.8))
    im = ax.imshow(heat, aspect="auto", cmap="magma")
    ax.set_yticks([0, 1])
    ax.set_yticklabels(["Full", "Setup-targeted eviction"])
    ax.set_title("First-step attention heatmap (downsampled)")
    ax.set_xlabel("Context position bins")
    fig.colorbar(im, ax=ax, shrink=0.8)
    fig.tight_layout()
    path = FIG_DIR / "kv_attention_heatmap_full_vs_setup_targeted.png"
    fig.savefig(path)
    plt.show()
    print("Saved:", path)

    # Per-layer pivot/setup tracking (full vs compressed-context baseline).
    # Both prompts are run fresh here rather than reusing the strategy results above.
    task = miragebench_tasks[KV_TASK_INDEX]
    comp_prompt = make_kv_prompt(task, task.compressed_context)
    comp_run = run_prompt_with_cache(kv_model, kv_tokenizer, comp_prompt, max_new_tokens=KV_MAX_NEW_TOKENS)

    full_prompt = make_kv_prompt(task, task.full_context)
    full_run = run_prompt_with_cache(kv_model, kv_tokenizer, full_prompt, max_new_tokens=KV_MAX_NEW_TOKENS)

    full_ids, full_offsets = _tokenize_with_offsets(kv_tokenizer, full_prompt)
    comp_ids, comp_offsets = _tokenize_with_offsets(kv_tokenizer, comp_prompt)

    full_event_map, _ = spans_to_token_positions(full_offsets, task.metadata.get("spans", []), max_token=full_ids.shape[1])

    # Token positions of the true pivot and of all its setup markers in the full prompt.
    pivot_positions = full_event_map.get(task.pivot_ground_truth, [])
    setup_positions = []
    for m in task.metadata.get("pivot_setup_markers", []):
        setup_positions.extend(full_event_map.get(m, []))

    full_pivot_curve = _layer_group_attention(full_run["first_step_attn"], pivot_positions)
    full_setup_curve = _layer_group_attention(full_run["first_step_attn"], setup_positions)

    # For compressed context, remap via marker search from spans derived on compressed context rendering.
    # NOTE(review): the span records reused below are the same task.metadata["spans"] entries
    # used for the full prompt, filtered to markers found in the compressed text — confirm
    # their offsets are valid against comp_offsets.
    comp_context_len = kv_tokenizer(task.compressed_context, add_special_tokens=False, return_tensors="pt")["input_ids"].shape[1]
    comp_records = []
    kept_markers = set(re.findall(r"[A-Z]\d{2}-E\d{3}", task.compressed_context))
    for sp in task.metadata.get("spans", []):
        if sp["marker"] in kept_markers:
            comp_records.append(sp)
    comp_event_map, _ = spans_to_token_positions(comp_offsets, comp_records, max_token=comp_context_len)

    comp_pivot_positions = comp_event_map.get(task.pivot_ground_truth, [])
    comp_setup_positions = []
    for m in task.metadata.get("pivot_setup_markers", []):
        comp_setup_positions.extend(comp_event_map.get(m, []))

    comp_pivot_curve = _layer_group_attention(comp_run["first_step_attn"], comp_pivot_positions)
    comp_setup_curve = _layer_group_attention(comp_run["first_step_attn"], comp_setup_positions)

    # The x-axis length is taken from the full-context pivot curve; the other three
    # curves must have the same length for the plot calls to succeed.
    layers = np.arange(len(full_pivot_curve))
    fig, ax = plt.subplots(figsize=(9, 5))
    ax.plot(layers, full_pivot_curve, label="Full: pivot attention", linewidth=2)
    ax.plot(layers, full_setup_curve, label="Full: setup attention", linewidth=2)
    ax.plot(layers, comp_pivot_curve, "--", label="Compressed: pivot attention", linewidth=2)
    ax.plot(layers, comp_setup_curve, "--", label="Compressed: setup attention", linewidth=2)
    ax.set_xlabel("Layer")
    ax.set_ylabel("Attention mass")
    ax.set_title("Per-layer pivot/setup tracking (first answer transition)")
    ax.legend()
    fig.tight_layout()
    path = FIG_DIR / "kv_per_layer_pivot_setup_tracking.png"
    fig.savefig(path)
    plt.show()
    print("Saved:", path)
else:
    print("Run KV experiment first to generate attention visuals.")


In [None]:
# 3.6 Neural mirage sweep across eviction levels
# Flags keep any value already set in the notebook globals; otherwise these defaults apply.
RUN_KV_SWEEP = bool(globals().get("RUN_KV_SWEEP", False))
KV_SWEEP_LEVELS = list(globals().get("KV_SWEEP_LEVELS", [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]))

if RUN_KV_SWEEP:
    # Load the model/tokenizer only if they have not been loaded yet.
    if kv_model is None or kv_tokenizer is None:
        kv_tokenizer, kv_model = load_kv_model(KV_MODEL_NAME)

    task = miragebench_tasks[KV_TASK_INDEX]
    rows = []

    # One run per (strategy, eviction level). A failed run is recorded as a row that
    # carries only strategy, eviction_ratio and the error text.
    for strat in ["random", "attention_h2o", "setup_targeted", "contract_guarded"]:
        for lvl in KV_SWEEP_LEVELS:
            try:
                out = kv_strategy_run(task, kv_model, kv_tokenizer, strat, eviction_ratio=lvl, seed=SEED)
                rows.append(
                    {
                        "strategy": strat,
                        "eviction_ratio": lvl,
                        "raw_validity": out["raw_validity"],
                        "pivot_preserved": float(out["pivot_preserved"]),
                        "semantic_regret": out["semantic_regret"],
                    }
                )
            except Exception as exc:
                rows.append(
                    {
                        "strategy": strat,
                        "eviction_ratio": lvl,
                        "error": str(exc),
                    }
                )

    kv_sweep_df = pd.DataFrame(rows)
    kv_sweep_df.to_csv(RAW_DIR / "kv_neural_mirage_sweep.csv", index=False)

    # Plot only the rows without a recorded error.
    ok = kv_sweep_df[kv_sweep_df.get("error").isna()] if "error" in kv_sweep_df.columns else kv_sweep_df

    fig, ax = plt.subplots(figsize=(9, 6))
    # One trajectory per strategy, ordered by eviction ratio, with each point labelled by its ratio.
    for strat, sub in ok.groupby("strategy"):
        sub = sub.sort_values("eviction_ratio")
        ax.plot(sub["raw_validity"], sub["pivot_preserved"], marker="o", label=strat)
        for _, row in sub.iterrows():
            ax.text(row["raw_validity"] + 0.002, row["pivot_preserved"] + 0.002, f"e={row['eviction_ratio']:.1f}", fontsize=7)

    ax.set_xlabel("Raw generation quality")
    ax.set_ylabel("Pivot consistency")
    ax.set_title("Neural Mirage Plot across KV eviction levels")
    ax.set_xlim(0, 1.02)
    ax.set_ylim(0, 1.02)
    ax.legend()
    fig.tight_layout()
    path = FIG_DIR / "kv_neural_mirage_plot.png"
    fig.savefig(path)
    plt.show()
    print("Saved:", path)
else:
    # Keep kv_sweep_df defined even when the sweep is skipped.
    kv_sweep_df = pd.DataFrame()
    print("Set RUN_KV_SWEEP=True in the control panel to run eviction-level sweep.")


In [None]:
# 3.7 Contract-guarded vs H2O direct comparison
# Failed sweep runs record only strategy / eviction_ratio / error, so the metric columns
# are missing when every run failed. Check for them up front instead of letting the
# aggregation below raise a KeyError.
_kv_cmp_columns = {"strategy", "pivot_preserved", "raw_validity", "semantic_regret"}

if 'kv_sweep_df' not in globals() or kv_sweep_df.empty:
    print("Run KV sweep first to compare contract-guarded vs H2O.")
elif not _kv_cmp_columns.issubset(kv_sweep_df.columns):
    print("KV sweep has no successful runs; nothing to compare for contract-guarded vs H2O.")
else:
    # Drop rows that recorded an error before aggregating.
    ok = kv_sweep_df[kv_sweep_df.get("error").isna()] if "error" in kv_sweep_df.columns else kv_sweep_df
    # Mean of each metric over all eviction levels, for the two strategies being compared.
    cmp_df = (
        ok[ok["strategy"].isin(["attention_h2o", "contract_guarded"])]
        .groupby("strategy", as_index=False)
        .agg(
            mean_pivot_preservation=("pivot_preserved", "mean"),
            mean_raw_validity=("raw_validity", "mean"),
            mean_semantic_regret=("semantic_regret", "mean"),
        )
    )
    display(cmp_df)

    fig, ax = plt.subplots(figsize=(6, 4))
    ax.bar(cmp_df["strategy"], cmp_df["mean_pivot_preservation"], color=["#4c78a8", "#72b7b2"])
    ax.set_ylim(0, 1.0)
    ax.set_ylabel("Pivot preservation")
    ax.set_title("Contract-guarded vs H2O (mean over eviction sweep)")
    fig.tight_layout()
    path = FIG_DIR / "kv_contract_vs_h2o_bar.png"
    fig.savefig(path)
    plt.show()
    print("Saved:", path)


# 4. Divergence Scaling

Extended probe:

- `n ∈ {1K, 5K, 10K, 50K, 100K, 500K, 1M}`
- `ε ∈ {0.3, 0.5, 0.7}`
- `k ∈ {1, 2, 3, 5}`
- `200` seeds per cell

We fit the mean effective cost at each `n` to `cost ~ a * n^b` (least squares in log-log space) and report 95% bootstrap CIs for `b`.


In [None]:
# 4.1 Vectorized divergence simulator

def interleaved_focal_mask(n: int, focal_fraction: float = 0.5) -> np.ndarray:
    """Return a boolean mask of length ``n`` with focal positions spread evenly.

    The number of focal positions is ``round(n * focal_fraction)`` clamped to
    ``[1, n - 1]`` (and to at least 1). Position ``i`` is assigned to bucket
    ``i * n_focal // n``; the first position of each bucket is marked focal, so
    exactly one position per bucket is ``True``.

    Args:
        n: Sequence length.
        focal_fraction: Target fraction of focal positions.

    Returns:
        Boolean array of shape ``(n,)``.
    """
    n_focal = int(round(n * focal_fraction))
    n_focal = max(1, min(n - 1, n_focal))
    # Explicit int64 keeps i * n_focal exact at n = 1e6 even where the default int is 32-bit.
    bucket = (np.arange(n, dtype=np.int64) * n_focal) // n
    mask = np.zeros(n, dtype=bool)
    if n > 0:
        # A position is focal when it opens a new bucket. Bucket ids never reach
        # n_focal because i < n, so no upper-bound check is needed.
        mask[0] = True
        mask[1:] = bucket[1:] != bucket[:-1]
    return mask


def simulate_divergence_cell(
    n: int,
    epsilon: float,
    k: int,
    seeds: int = 200,
    chunk_size: int = 10,
    focal_fraction: float = 0.5,
    base_seed: int = 123,
) -> pd.DataFrame:
    """Simulate record-shift costs for one ``(n, epsilon, k)`` cell, one row per seed.

    Weights are drawn uniformly from [5, 20) at positions whose relative index is
    ``<= epsilon`` and from [0.1, 5) elsewhere; non-focal positions are set to ``-inf``.
    A "record" is a position whose weight strictly exceeds every earlier weight.

    * ``total_cost``: sum of record positions minus the first record position.
    * ``effective_cost``: sum of the positions of those records (after the first) that
      have fewer than ``k`` non-focal positions between them and the previous record.

    Seeds are processed in chunks of ``chunk_size``; each chunk uses a single generator
    seeded with the chunk's first seed value, so the ``seed`` column labels rows rather
    than seeding them individually.
    """
    focal_mask = interleaved_focal_mask(n, focal_fraction=focal_fraction)
    nonfocal_running = np.cumsum(~focal_mask).astype(np.int64)

    index = np.arange(n, dtype=np.int64)
    in_high_zone = (np.arange(n) / max(1, n - 1)) <= float(epsilon)
    seed_labels = np.arange(base_seed, base_seed + seeds, dtype=np.int64)

    records_out = []
    for chunk_start in range(0, seeds, chunk_size):
        chunk_end = min(seeds, chunk_start + chunk_size)
        batch = chunk_end - chunk_start

        # Draw order (low first, then high) is part of the reproducible stream.
        gen = np.random.default_rng(int(seed_labels[chunk_start]))
        low_draw = gen.uniform(0.1, 5.0, size=(batch, n)).astype(np.float32)
        high_draw = gen.uniform(5.0, 20.0, size=(batch, n)).astype(np.float32)
        weights = np.where(in_high_zone[None, :], high_draw, low_draw)
        weights[:, ~focal_mask] = -np.inf

        # Running maximum shifted right by one gives "best weight seen strictly before i".
        running_best = np.maximum.accumulate(weights, axis=1)
        leading_pad = np.full((batch, 1), -np.inf, dtype=np.float32)
        best_before = np.concatenate([leading_pad, running_best[:, :-1]], axis=1)
        is_record = np.isfinite(weights) & (weights > best_before)

        record_counts = is_record.sum(axis=1)
        record_pos_total = (is_record * index[None, :]).sum(axis=1)
        first_record = np.where(record_counts > 0, np.argmax(is_record, axis=1), 0)
        total_cost = record_pos_total - first_record

        for row in range(batch):
            hits = np.flatnonzero(is_record[row])
            k_sensitive_cost = 0.0
            if hits.size > 1:
                earlier, later = hits[:-1], hits[1:]
                support = nonfocal_running[later] - nonfocal_running[earlier]
                k_sensitive_cost = float(later[support < k].sum())

            records_out.append(
                {
                    "n": n,
                    "epsilon": epsilon,
                    "k": k,
                    "seed": int(seed_labels[chunk_start + row]),
                    "n_records": int(record_counts[row]),
                    "total_cost": float(total_cost[row]),
                    "effective_cost": float(k_sensitive_cost),
                }
            )

    return pd.DataFrame(records_out)


def fit_power_law(x: np.ndarray, y: np.ndarray) -> Dict[str, float]:
    """Fit ``y ~ a * x**b`` by a degree-1 least-squares fit on ``log(x)`` vs ``log(y)``.

    Only pairs where both ``x`` and ``y`` are strictly positive take part. If fewer
    than two such pairs remain, both ``a`` and ``b`` are returned as NaN.
    """
    usable = (x > 0) & (y > 0)
    log_x, log_y = np.log(x[usable]), np.log(y[usable])
    if len(log_x) < 2:
        return {"a": np.nan, "b": np.nan}

    slope, offset = np.polyfit(log_x, log_y, 1)
    return {"a": float(np.exp(offset)), "b": float(slope)}


def bootstrap_power_law_b(
    df: pd.DataFrame,
    eps: float,
    k: int,
    n_boot: int = 400,
    seed: int = 123,
    cost_col: str = "effective_cost",
) -> Dict[str, float]:
    """Fit the power-law exponent for one ``(eps, k)`` cell with a bootstrap interval.

    The point estimate fits the per-``n`` mean of ``cost_col``. Each of the ``n_boot``
    replicates resamples the costs at every ``n`` with replacement, refits on the
    resampled means, and keeps the exponent when it is finite. The interval bounds are
    the 2.5th and 97.5th percentiles of the kept exponents (NaN if none were kept).
    """
    cell = df[(df["epsilon"] == eps) & (df["k"] == k)].copy()
    sizes = sorted(cell["n"].unique())
    size_axis = np.array(sizes, dtype=float)

    point_means = [cell[cell["n"] == size][cost_col].mean() for size in sizes]
    point_fit = fit_power_law(size_axis, np.array(point_means, dtype=float))

    gen = np.random.default_rng(seed)
    costs_by_size = {size: cell[cell["n"] == size][cost_col].to_numpy() for size in sizes}

    kept_slopes = []
    for _ in range(n_boot):
        resampled_means = []
        for size in sizes:
            pool = costs_by_size[size]
            if len(pool):
                draw = gen.choice(pool, size=len(pool), replace=True)
                resampled_means.append(float(np.mean(draw)))
            else:
                resampled_means.append(np.nan)
        slope = fit_power_law(size_axis, np.array(resampled_means, dtype=float))["b"]
        if np.isfinite(slope):
            kept_slopes.append(slope)

    if kept_slopes:
        ci_low, ci_high = np.percentile(kept_slopes, [2.5, 97.5])
    else:
        ci_low = ci_high = np.nan

    return {
        "epsilon": eps,
        "k": k,
        "a": point_fit["a"],
        "b": point_fit["b"],
        "b_ci_low": float(ci_low),
        "b_ci_high": float(ci_high),
    }


In [None]:
# 4.2 Run extended divergence sweep
# Grid settings keep any value already set in the notebook globals; otherwise these defaults apply.
RUN_DIVERGENCE = bool(globals().get("RUN_DIVERGENCE", False))

NS = list(globals().get("NS", [1_000, 5_000, 10_000, 50_000, 100_000, 500_000, 1_000_000]))
EPS = list(globals().get("EPS", [0.3, 0.5, 0.7]))
KS = list(globals().get("KS", [1, 2, 3, 5]))
SEEDS_PER_CELL = int(globals().get("SEEDS_PER_CELL", 200))
CHUNK_SIZE = int(globals().get("CHUNK_SIZE", 10))  # memory-safe for large n

if RUN_DIVERGENCE:
    rows = []
    for eps in EPS:
        for k in KS:
            for n in NS:
                print(f"Simulating n={n}, eps={eps}, k={k} ...")
                # base_seed depends on eps and k but not on n, so every n within a
                # given (eps, k) pair starts from the same seed range.
                cell_df = simulate_divergence_cell(
                    n=n,
                    epsilon=eps,
                    k=k,
                    seeds=SEEDS_PER_CELL,
                    chunk_size=CHUNK_SIZE,
                    focal_fraction=0.5,
                    base_seed=SEED + int(100 * eps) + 17 * k,
                )
                rows.append(cell_df)

    divergence_raw_df = pd.concat(rows, ignore_index=True)
    divergence_raw_df.to_csv(RAW_DIR / "divergence_extended_raw.csv", index=False)
    print("Saved:", RAW_DIR / "divergence_extended_raw.csv")
else:
    # Keep divergence_raw_df defined so the following cells can test .empty.
    divergence_raw_df = pd.DataFrame()
    print("Set RUN_DIVERGENCE=True in the control panel to run full divergence sweep.")

# Last expression of the cell: shown as the cell output in the notebook.
divergence_raw_df.head()


In [None]:
# 4.3 Fit exponents + stability analysis
if divergence_raw_df.empty:
    print("No divergence data yet.")
    # Define both result frames as empty so the figure cell can test .empty on them.
    divergence_fit_df = pd.DataFrame()
    divergence_stability_df = pd.DataFrame()
else:
    # Bootstrap fits per (epsilon, k)
    fit_rows = []
    for eps in EPS:
        for k in KS:
            fit_rows.append(bootstrap_power_law_b(divergence_raw_df, eps=eps, k=k, n_boot=300, seed=SEED))

    divergence_fit_df = pd.DataFrame(fit_rows)
    divergence_fit_df.to_csv(RAW_DIR / "divergence_powerlaw_fits.csv", index=False)
    display(divergence_fit_df.sort_values(["epsilon", "k"]))

    # Exponent stability as n_max increases
    # For each (epsilon, k), refit using only n <= n_max, with n_max taken from NS[2:]
    # (so the smallest fit uses the entries of NS up to its third value).
    stability_rows = []
    for eps in EPS:
        for k in KS:
            sub = divergence_raw_df[(divergence_raw_df["epsilon"] == eps) & (divergence_raw_df["k"] == k)]
            for n_max in NS[2:]:
                s2 = sub[sub["n"] <= n_max]
                agg = s2.groupby("n", as_index=False).agg(mean_cost=("effective_cost", "mean"))
                fit = fit_power_law(agg["n"].to_numpy(dtype=float), agg["mean_cost"].to_numpy(dtype=float))
                stability_rows.append(
                    {
                        "epsilon": eps,
                        "k": k,
                        "n_max": n_max,
                        "b": fit["b"],
                    }
                )

    divergence_stability_df = pd.DataFrame(stability_rows)
    divergence_stability_df.to_csv(RAW_DIR / "divergence_exponent_stability.csv", index=False)
    print("Saved fits + stability CSVs.")


In [None]:
# 4.4 Divergence figures
if divergence_raw_df.empty:
    print("No divergence data yet.")
else:
    # Mean effective cost per (epsilon, k, n) cell, used by Figure 1.
    agg = (
        divergence_raw_df.groupby(["epsilon", "k", "n"], as_index=False)
        .agg(mean_cost=("effective_cost", "mean"))
        .sort_values(["epsilon", "k", "n"])
    )

    # Figure 1: log-log divergence plot (for k=3, all epsilon)
    # NOTE(review): base_k is fixed at 3; if KS was overridden without 3, `sub` is empty
    # and the figure shows only reference lines scaled by NaN — confirm this is acceptable.
    fig, ax = plt.subplots(figsize=(8.5, 5.5))
    base_k = 3
    sub = agg[agg["k"] == base_k]
    for eps, g in sub.groupby("epsilon"):
        x = g["n"].to_numpy(dtype=float)
        y = g["mean_cost"].to_numpy(dtype=float)
        fit = fit_power_law(x, y)
        ax.loglog(x, y, "o-", label=f"eps={eps}, b={fit['b']:.3f}")

    # Reference slopes b = 1.00, 1.05 and 1.10, each equal to 1 at the smallest n.
    x_ref = np.array(NS, dtype=float)
    y_ref = x_ref / x_ref[0]
    y_ref_105 = (x_ref / x_ref[0]) ** 1.05
    y_ref_11 = (x_ref / x_ref[0]) ** 1.10
    # Normalize reference curves for visual comparability.
    # The scale is the mean cost at the smallest n, averaged over all epsilon for base_k.
    scale = sub[sub["n"] == NS[0]]["mean_cost"].mean()
    ax.loglog(x_ref, y_ref * scale, "--", color="black", alpha=0.6, label="b=1.00")
    ax.loglog(x_ref, y_ref_105 * scale, ":", color="black", alpha=0.8, label="b=1.05")
    ax.loglog(x_ref, y_ref_11 * scale, "-.", color="black", alpha=0.8, label="b=1.10")

    ax.set_xlabel("n")
    ax.set_ylabel("mean effective cost")
    ax.set_title("Log-log divergence scaling (k=3)")
    ax.legend()
    fig.tight_layout()
    path1 = FIG_DIR / "divergence_loglog_k3.png"
    fig.savefig(path1)
    plt.show()
    print("Saved:", path1)

    # Figure 2: exponent stability plot
    # One line per (epsilon, k): fitted b as a function of the largest n included in the fit.
    if not divergence_stability_df.empty:
        fig, ax = plt.subplots(figsize=(9, 5.5))
        for (eps, k), g in divergence_stability_df.groupby(["epsilon", "k"]):
            ax.plot(g["n_max"], g["b"], marker="o", label=f"eps={eps}, k={k}")
        ax.set_xscale("log")
        ax.set_xlabel("n_max used for fitting")
        ax.set_ylabel("Estimated b")
        ax.set_title("Exponent stability vs fit range")
        ax.legend(ncol=2, fontsize=8)
        fig.tight_layout()
        path2 = FIG_DIR / "divergence_exponent_stability.png"
        fig.savefig(path2)
        plt.show()
        print("Saved:", path2)

    # Figure 3: heatmap of b by (epsilon, k)
    # Rows are k, columns are epsilon, cell values are the point-estimate exponents.
    if not divergence_fit_df.empty:
        pivot = divergence_fit_df.pivot(index="k", columns="epsilon", values="b")
        fig, ax = plt.subplots(figsize=(6, 4.5))
        sns.heatmap(pivot, annot=True, fmt=".3f", cmap="viridis", ax=ax)
        ax.set_title("Power-law exponent b by (epsilon, k)")
        fig.tight_layout()
        path3 = FIG_DIR / "divergence_b_heatmap.png"
        fig.savefig(path3)
        plt.show()
        print("Saved:", path3)


# 5. MirageBench Packaging

This section exports:

- `miragebench_v0_1_tasks.json`
- model results CSVs
- leaderboard table
- iconic mirage plots
- zipped result bundle


In [None]:
# 5.1 Export MirageBench tasks JSON (release-ready schema)

def task_release_dict(task: MirageBenchTask) -> Dict[str, Any]:
    """Project ``task`` onto the released schema as a plain dict.

    Only the fields listed below are exported, in this order; each value is read
    from the attribute of the same name on ``task``.
    """
    released_fields = (
        "task_id",
        "category",
        "full_context",
        "compressed_context",
        "question",
        "pivot_ground_truth",
        "answer_ground_truth",
        "decoy_pivot",
        "decoy_answer",
        "k",
    )
    return {name: getattr(task, name) for name in released_fields}

# Serialize every task to the release schema and write them as one JSON array.
release_tasks = [task_release_dict(t) for t in miragebench_tasks]

tasks_json_path = RAW_DIR / "miragebench_v0_1_tasks.json"
# ensure_ascii=False writes non-ASCII characters as-is instead of \u escapes.
with open(tasks_json_path, "w", encoding="utf-8") as f:
    json.dump(release_tasks, f, indent=2, ensure_ascii=False)

print("Saved:", tasks_json_path)
print("Task count:", len(release_tasks))


In [None]:
# 5.2 Build leaderboard + summary table (if black-box results exist)
if 'blackbox_results_df' in globals() and (not blackbox_results_df.empty):
    tmp = blackbox_results_df.copy()
    # Drop rows that recorded an error (a no-op when no error was recorded).
    if "error" in tmp.columns:
        tmp = tmp[tmp["error"].isna()] if tmp["error"].notna().any() else tmp

    # Coerce metric columns to numbers; values that cannot be parsed become NaN.
    for c in ["raw_validity", "semantic_regret", "pivot_preserved", "mirage_flag", "pivot_matches_ground_truth"]:
        if c in tmp.columns:
            tmp[c] = pd.to_numeric(tmp[c], errors="coerce")

    # One row per model, ranked by pivot preservation and then by raw validity.
    leaderboard = (
        tmp.groupby("model_name", as_index=False)
        .agg(
            mean_raw_validity=("raw_validity", "mean"),
            mean_pivot_preservation=("pivot_preserved", "mean"),
            mean_semantic_regret=("semantic_regret", "mean"),
            mirage_rate=("mirage_flag", "mean"),
            pivot_gt_rate=("pivot_matches_ground_truth", "mean"),
            n=("task_id", "count"),
        )
        .sort_values(["mean_pivot_preservation", "mean_raw_validity"], ascending=[False, False])
    )

    display(leaderboard)

    leaderboard_path = RAW_DIR / "miragebench_leaderboard.csv"
    leaderboard.to_csv(leaderboard_path, index=False)
    print("Saved:", leaderboard_path)

    # Paper-style summary table per model x category
    summary_table = (
        tmp.groupby(["model_name", "category"], as_index=False)
        .agg(
            raw_validity=("raw_validity", "mean"),
            pivot_preservation=("pivot_preserved", "mean"),
            semantic_regret=("semantic_regret", "mean"),
            mirage_rate=("mirage_flag", "mean"),
            n=("task_id", "count"),
        )
    )
    summary_path = RAW_DIR / "miragebench_summary_table.csv"
    summary_table.to_csv(summary_path, index=False)
    print("Saved:", summary_path)
else:
    print("No black-box results available yet.")


In [None]:
# 5.3 Save an iconic mirage plot per model (if results exist)
if 'blackbox_results_df' in globals() and (not blackbox_results_df.empty):
    tmp = blackbox_results_df.copy()
    # Drop rows that recorded an error (a no-op when no error was recorded).
    if "error" in tmp.columns:
        tmp = tmp[tmp["error"].isna()] if tmp["error"].notna().any() else tmp

    # Coerce both plotted metrics to numbers; unparseable values become NaN.
    tmp["raw_validity"] = pd.to_numeric(tmp.get("raw_validity"), errors="coerce")
    tmp["pivot_preserved"] = pd.to_numeric(tmp.get("pivot_preserved"), errors="coerce")

    for model_name, sub in tmp.groupby("model_name"):
        # One point per compression level: mean validity vs. mean pivot preservation.
        agg = sub.groupby("compression_level", as_index=False).agg(
            raw_validity=("raw_validity", "mean"),
            pivot_preservation=("pivot_preserved", "mean"),
        )

        fig, ax = plt.subplots(figsize=(6.4, 4.8))
        ax.plot(agg["raw_validity"], agg["pivot_preservation"], marker="o", linewidth=2)
        for _, row in agg.iterrows():
            ax.text(row["raw_validity"] + 0.003, row["pivot_preservation"] + 0.003, f"c={row['compression_level']:.1f}")
        ax.set_xlim(0, 1.02)
        ax.set_ylim(0, 1.02)
        ax.set_xlabel("Raw validity")
        ax.set_ylabel("Pivot preservation")
        ax.set_title(f"Mirage plot: {model_name}")
        fig.tight_layout()

        # Slashes in the model name are replaced so the name is a single file component.
        path = FIG_DIR / f"mirage_plot_{model_name.replace('/', '_')}.png"
        fig.savefig(path)
        # The figure is closed rather than shown, so it is only written to disk.
        plt.close(fig)
        print("Saved:", path)
else:
    print("No black-box results available yet.")


In [None]:
# 5.4 Zip all outputs for one-click download
zip_path = RESULTS_ROOT / "miragebench_colab_results.zip"

# The archive is written inside RESULTS_ROOT, so it is excluded from its own contents.
# Entry names are relative to RESULTS_ROOT's parent, i.e. they start with the results folder name.
with zipfile.ZipFile(zip_path, "w", compression=zipfile.ZIP_DEFLATED) as zf:
    for path in RESULTS_ROOT.rglob("*"):
        if path.is_file() and path != zip_path:
            zf.write(path, arcname=path.relative_to(RESULTS_ROOT.parent))

print("Created:", zip_path)

# Optional Colab download helper
# Best effort: if google.colab cannot be imported, the hint is skipped silently.
try:
    from google.colab import files
    print("You can download the zip with:")
    print(f"files.download('{zip_path}')")
except Exception:
    pass


# 6. Summary & Key Findings

Use this section after running the experiments:

1. Which models are most susceptible to the Validity Mirage?
2. Does KV-cache surgery show attention shift from true pivot/setup to decoy pivot?
3. Does contract-guarded eviction outperform H2O-style eviction on pivot preservation?
4. Is divergence exponent `b` stable at scale, or does it drift upward?


In [None]:
# 6.1 Success criteria checks
# Each criterion stays False unless its source table exists, is non-empty and has the needed columns.
criterion_1 = False
criterion_2 = False
criterion_3 = False

# (1) High validity + low pivot preservation in black-box results
# Fires if any single error-free row has raw_validity >= 0.75 and pivot_preserved < 0.5.
if 'blackbox_results_df' in globals() and not blackbox_results_df.empty:
    tmp = blackbox_results_df.copy()
    if "error" in tmp.columns:
        tmp = tmp[tmp["error"].isna()] if tmp["error"].notna().any() else tmp
    if {"raw_validity", "pivot_preserved"}.issubset(tmp.columns):
        tmp["raw_validity"] = pd.to_numeric(tmp["raw_validity"], errors="coerce")
        tmp["pivot_preserved"] = pd.to_numeric(tmp["pivot_preserved"], errors="coerce")
        criterion_1 = bool(((tmp["raw_validity"] >= 0.75) & (tmp["pivot_preserved"] < 0.5)).any())

# (2) Attention shift / pivot switch in KV sweep
# Fires if any error-free sweep row has raw_validity >= 0.65 and pivot_preserved < 0.5.
if 'kv_sweep_df' in globals() and not kv_sweep_df.empty:
    ok = kv_sweep_df[kv_sweep_df.get("error").isna()] if "error" in kv_sweep_df.columns else kv_sweep_df
    if {"raw_validity", "pivot_preserved"}.issubset(ok.columns):
        criterion_2 = bool(((ok["raw_validity"] >= 0.65) & (ok["pivot_preserved"] < 0.5)).any())

# (3) Contract-guarded better than H2O
# Compares mean pivot preservation over the sweep; requires both strategies to be present.
if 'kv_sweep_df' in globals() and not kv_sweep_df.empty:
    ok = kv_sweep_df[kv_sweep_df.get("error").isna()] if "error" in kv_sweep_df.columns else kv_sweep_df
    if {"strategy", "pivot_preserved"}.issubset(ok.columns):
        means = ok.groupby("strategy", as_index=False)["pivot_preserved"].mean()
        if {"contract_guarded", "attention_h2o"}.issubset(set(means["strategy"])):
            c = float(means[means["strategy"] == "contract_guarded"]["pivot_preserved"].iloc[0])
            h = float(means[means["strategy"] == "attention_h2o"]["pivot_preserved"].iloc[0])
            # Strict inequality: a tie does not count as contract-guarded winning.
            criterion_3 = c > h

print("Success criteria status:")
print(f"1) Real-model validity mirage observed: {criterion_1}")
print(f"2) Neural-level shift under KV surgery observed: {criterion_2}")
print(f"3) Contract-guarded beats H2O on pivot preservation: {criterion_3}")

if criterion_1 or criterion_2 or criterion_3:
    print("\nAt least one publishable bridge result is present.")
else:
    print("\nNo criterion has fired yet. Run more models/tasks or increase sweep depth.")
