# Mirage-Guarded KV Cache Demo (RoPE-Preserving Attention-Mask Ablation)

**Model target:** Llama 3.1 8B Instruct on H100 (Colab Pro+). The logged run below executed on an NVIDIA A100-SXM4-80GB.

This notebook demonstrates a validity-mirage failure mode under KV-cache compression:
standard structure-blind eviction keeps responses well-formed but can silently
substitute the wrong causal hypothesis (for example, reporting a low-severity decoy
incident as the root cause).

## Mathematical Framing (L2 Frontier)

Each chunk carries an L1 state `(weight, d_total, d_pre)`:

- `weight`: pivot severity if the chunk is a pivot, else `-inf`
- `d_total`: predecessor capacity contribution (1 for warning chunks, else 0)
- `d_pre`: predecessor count tracked through provenance
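
A minimal sketch of how one log line could be mapped to its L1 state. It assumes the `event_type`/`severity` fields that the prompt builder in Section 3 attaches to each line. `l1_state` is a hypothetical helper, and it shows only `weight` and `d_total` because the bookkeeping for `d_pre` is not spelled out here.

```python
import math
from typing import Optional, Tuple

def l1_state(event_type: str, severity: Optional[float]) -> Tuple[float, int]:
    """Hypothetical helper: map one log line to its (weight, d_total) L1 state."""
    # Pivots carry their severity as weight; every other line gets -inf.
    if event_type == "pivot" and severity is not None:
        weight = float(severity)
    else:
        weight = -math.inf
    # Warning lines contribute one unit of predecessor capacity; others contribute 0.
    d_total = 1 if event_type == "warning" else 0
    return weight, d_total

print(l1_state("pivot", 95.0))    # (95.0, 0)
print(l1_state("warning", None))  # (-inf, 1)
print(l1_state("noise", None))    # (-inf, 0)
```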

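The L2 streaming scan maintains a vector `W[0..k]` (here `k=3`), where `W[j]` is the
highest pivot weight seen so far that has at least `j` warning chunks before it (counts
capped at `k`). Composing a running summary `prev` with an incoming summary gives:

`W_new[j] = max(W_prev[j], W_incoming[max(0, j - d_total_prev)])`

`d_total_new = min(k, d_total_prev + d_total_incoming)`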
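
Here is a stripped-down sketch of the composition rule. It tracks only `W` and the capped `d_total`, with no provenance. The helpers `compose_w` and `leaf` are illustrative names, not the notebook's `compose_summaries`/`chunk_summary`. A pivot enters through `W[0]` of its own single-chunk summary.

```python
import math
from typing import List, Tuple

NEG_INF = -math.inf

def compose_w(w_prev: List[float], d_prev: int,
              w_inc: List[float], d_inc: int, k: int) -> Tuple[List[float], int]:
    """Combine two (W, d_total) summaries with the rule above; provenance omitted."""
    w_new = [max(w_prev[j], w_inc[max(0, j - d_prev)]) for j in range(k + 1)]
    return w_new, min(k, d_prev + d_inc)

def leaf(weight: float, d_total: int, k: int) -> Tuple[List[float], int]:
    """Single-chunk summary: a pivot's weight sits in W[0]; d_total is capped at k."""
    w = [NEG_INF] * (k + 1)
    w[0] = weight
    return w, min(k, d_total)

k = 3
W, d = [NEG_INF] * (k + 1), 0  # identity summary
# Three warnings followed by a severity-95 pivot, folded left to right.
for weight, d_total in [(NEG_INF, 1), (NEG_INF, 1), (NEG_INF, 1), (95.0, 0)]:
    W, d = compose_w(W, d, *leaf(weight, d_total, k), k)
print(W, d)  # [95.0, 95.0, 95.0, 95.0] 3
```

When only one warning precedes the pivot, the same fold fills `W[0]` and `W[1]` and leaves `W[2]` and `W[3]` at `-inf`, so that pivot can never supply `W[k]`.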

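Each composition step costs `O(k)`, so the scan is `O(n·k)`, which is `O(n)` in chunks for
fixed `k`. We track provenance for `W[k]`, so the final protected set is the winning pivot
plus the `k` warning chunks that most closely precede it. Predecessors are positional: any
warning chunk counts, whichever incident it mentions. The eviction policy enforces this
provenance set as a hard budget constraint.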
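
To sanity-check the provenance tracking, a brute-force reference can recover the protected set directly. This sketch rests on two assumptions that mirror the scan: chunks arrive in log order as `(chunk_id, weight, d_total)` tuples, and any warning chunk counts as a predecessor. `brute_force_protected` is a hypothetical name.

```python
import math
from typing import List, Optional, Set, Tuple

NEG_INF = -math.inf

def brute_force_protected(chunks: List[Tuple[str, float, int]], k: int) -> Set[str]:
    """Hypothetical reference. chunks: (chunk_id, weight, d_total) in log order."""
    warnings_seen: List[str] = []  # ids of warning chunks seen so far
    best: Optional[Tuple[float, str, List[str]]] = None
    for chunk_id, weight, d_total in chunks:
        # A pivot qualifies once at least k warnings precede it.
        if math.isfinite(weight) and len(warnings_seen) >= k:
            # Strict '>' keeps the earlier pivot on ties, as compose_summaries does.
            if best is None or weight > best[0]:
                preds = warnings_seen[-k:] if k > 0 else []
                best = (weight, chunk_id, preds)
        if d_total > 0:
            warnings_seen.append(chunk_id)
    if best is None:
        return set()
    return {best[1], *best[2]}

log = [
    ("W1", NEG_INF, 1), ("n1", NEG_INF, 0), ("W2", NEG_INF, 1), ("W3", NEG_INF, 1),
    ("P", 95.0, 0),
    ("D1", NEG_INF, 1), ("D2", NEG_INF, 1), ("D3", NEG_INF, 1),
    ("DP", 38.0, 0),
]
print(sorted(brute_force_protected(log, 3)))  # ['P', 'W1', 'W2', 'W3']
```

This reference keeps the full list of warning ids, so it uses `O(n)` memory, whereas the scan keeps only a `k`-length tail. It is meant for cross-checking on small logs, not as a replacement.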

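After each full scan, the contract `d_total(mu(B)) >= min(d_total(B), k)` is checked, where `B`
is the chunk sequence, `mu(B)` is its L2 summary, and `d_total(B)` is the uncapped sum of
per-chunk `d_total` values.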
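
Because `d_total` composes by capped addition, the scan's `d_total` equals `min(k, d_total(B))` exactly. The contract is therefore a sanity check rather than a tight bound. Below is a small randomized property check, a sketch that assumes nonnegative integer `d_total` values; `capped_add` and `fold_d_total` are illustrative helpers. It confirms the contract and also shows that splitting the sequence and composing the two halves gives the same result.

```python
import random
from functools import reduce
from typing import Sequence

def capped_add(a: int, b: int, k: int) -> int:
    """The d_total composition: min(k, a + b)."""
    return min(k, a + b)

def fold_d_total(d_values: Sequence[int], k: int) -> int:
    """Left fold from the identity summary (d_total = 0); leaves are capped at k."""
    return reduce(lambda acc, d: capped_add(acc, min(k, d), k), d_values, 0)

rng = random.Random(0)
k = 3
for _ in range(1000):
    ds = [rng.choice([0, 0, 0, 1]) for _ in range(rng.randint(0, 50))]
    # Contract from the text: d_total(mu(B)) >= min(d_total(B), k).
    assert fold_d_total(ds, k) >= min(sum(ds), k)
    # Splitting B anywhere and composing the two halves gives the same result.
    m = rng.randint(0, len(ds))
    assert capped_add(fold_d_total(ds[:m], k), fold_d_total(ds[m:], k), k) == fold_d_total(ds, k)
print("contract and split-invariance hold on 1000 random sequences")
```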

## Important Terminology

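Eviction is applied via **attention-mask ablation** over the full token stream: evicted
tokens stay in the sequence at their original positions, but their attention-mask entries
are zeroed so no token attends to them. This preserves RoPE geometry and simulates
information becoming inaccessible, but it is not literal `past_key_values` pruning and
saves no memory. The limitations section states this explicitly.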
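
The difference from dropping evicted text and re-encoding the survivors shows up most clearly in position IDs. In this toy sketch, integer indices stand in for tokens, and `ablate` and `drop_and_reencode` are illustrative names. Ablation zeros mask entries and leaves every position ID unchanged. Re-encoding the surviving tokens renumbers them contiguously, which shrinks the relative RoPE distances across the gap.

```python
from typing import List, Set, Tuple

def ablate(n_tokens: int, evicted: Set[int]) -> Tuple[List[int], List[int]]:
    """Mask ablation: keep every token, zero its mask entry, keep original positions."""
    attention_mask = [0 if i in evicted else 1 for i in range(n_tokens)]
    position_ids = list(range(n_tokens))
    return attention_mask, position_ids

def drop_and_reencode(n_tokens: int, evicted: Set[int]) -> Tuple[List[int], List[int]]:
    """Drop evicted tokens and re-encode: survivors are renumbered contiguously."""
    kept = [i for i in range(n_tokens) if i not in evicted]
    return kept, list(range(len(kept)))

mask, pos = ablate(8, {2, 3, 4})
print(mask)  # [1, 1, 0, 0, 0, 1, 1, 1]
print(pos)   # [0, 1, 2, 3, 4, 5, 6, 7]

kept, new_pos = drop_and_reencode(8, {2, 3, 4})
print(kept)     # [0, 1, 5, 6, 7]
print(new_pos)  # [0, 1, 2, 3, 4]
```

Token 5 sits 4 positions after token 1 under ablation but only 1 position after it once re-encoded. That shift is the geometry change the ablation avoids.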


In [None]:
# Optional one-time compatibility fix (only if transformers import complains about tokenizers):
# !python -m pip install --no-deps --force-reinstall tokenizers==0.20.3

In [1]:
import os, socket, torch, shutil, subprocess
print("HOST:", socket.gethostname())
print("PID:", os.getpid())
print("TORCH:", torch.__version__, "CUDA:", torch.version.cuda, "AVAILABLE:", torch.cuda.is_available())
print("NVIDIA_SMI:", shutil.which("nvidia-smi"))
if shutil.which("nvidia-smi"):
    print(subprocess.run(["nvidia-smi", "-L"], capture_output=True, text=True).stdout)


HOST: f5935bf15b00
PID: 11162
TORCH: 2.10.0+cu128 CUDA: 12.8 AVAILABLE: True
NVIDIA_SMI: /opt/bin/nvidia-smi
GPU 0: NVIDIA A100-SXM4-80GB (UUID: GPU-5e10842d-d072-6d92-cc1f-e21f668ebffd)



## 0. Setup & Installation

In [2]:
# Colab/Local runtime bootstrap
import os
from pathlib import Path

IS_COLAB = False
try:
    import google.colab  # type: ignore
    IS_COLAB = True
except Exception:
    IS_COLAB = Path('/content').exists()

if IS_COLAB:
    from google.colab import drive
    drive.mount('/content/drive')
    ARTIFACT_DIR = Path('/content/drive/MyDrive/mirage_outputs/mirage_rca_demo')
else:
    ARTIFACT_DIR = Path('artifacts/mirage_rca_demo')

print(f'Running in Colab: {IS_COLAB}')
print(f'Artifact directory: {ARTIFACT_DIR}')


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Running in Colab: True
Artifact directory: /content/drive/MyDrive/mirage_outputs/mirage_rca_demo


In [3]:
# ====== Kernel sanity check ======
import os, platform, time
print("KERNEL: alive")
print(f"python={platform.python_version()}")
print(f"pid={os.getpid()}")
print(f"cwd={os.getcwd()}")
print(f"ts={time.time()}")
print("Run this first cell whenever execution appears to do nothing")


KERNEL: alive
python=3.12.12
pid=11162
cwd=/content
ts=1772061562.2624938
Run this first cell whenever execution appears to do nothing


In [4]:
# Reproducible environment setup: pin deps, set seeds, and log environment.
import os
import sys
import json
import random
import platform
import subprocess
import shutil
import importlib.metadata as importlib_metadata
from datetime import datetime, timezone

# Keep torch out of the pinned install list.
# Pinning torch in Colab can accidentally replace the runtime CUDA wheel with CPU-only torch.
PINNED_PACKAGES = [
    "transformers==4.46.3",
    "tokenizers==0.20.3",
    "accelerate==1.1.1",
    "bitsandbytes==0.45.0",
    "huggingface_hub==0.26.2",
    "sentencepiece==0.2.0",
    "protobuf==5.28.3",
    "matplotlib==3.9.2",
    "numpy==2.1.2",
]

def _version_ok(spec: str) -> bool:
    name, want = spec.split("==")
    try:
        return importlib_metadata.version(name) == want
    except Exception:
        return False


def ensure_pinned_packages(specs):
    to_install = [s for s in specs if not _version_ok(s)]
    if not to_install:
        print("Dependency check: all pinned packages already present.")
        return

    print(f"Dependency check: installing {len(to_install)} package(s)...")
    print("  " + ", ".join(to_install))
    # `--no-deps` avoids mutating torch/runtime packages in managed notebooks.
    subprocess.run([sys.executable, "-m", "pip", "install", "--no-deps", *to_install], check=True)
    print("Dependency check: pin install complete.")


ensure_pinned_packages(PINNED_PACKAGES)

import numpy as np
import torch

SEED = 20260225
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

def _detect_runtime_gpus():
    if not shutil.which("nvidia-smi"):
        return []

    try:
        out = subprocess.run(
            ["nvidia-smi", "--query-gpu=name", "--format=csv,noheader"],
            capture_output=True,
            text=True,
            check=False,
        )
    except Exception:
        return []

    if out.returncode != 0:
        return []

    return [line.strip() for line in out.stdout.splitlines() if line.strip()]


RUNTIME_GPUS = _detect_runtime_gpus()
RUNTIME_GPU_VISIBLE = bool(RUNTIME_GPUS)
TORCH_CUDA_AVAILABLE = torch.cuda.is_available()
HAS_CUDA_TORCH = torch.version.cuda is not None
HAS_CUDA_READY = RUNTIME_GPU_VISIBLE and TORCH_CUDA_AVAILABLE and HAS_CUDA_TORCH

runtime_info = {
    "runtime_gpu_visible": RUNTIME_GPU_VISIBLE,
    "runtime_gpus": RUNTIME_GPUS,
    "torch_cuda_available": TORCH_CUDA_AVAILABLE,
}

if TORCH_CUDA_AVAILABLE:
    runtime_info.update({
        "gpu_name": torch.cuda.get_device_name(0),
        "compute_capability": f"{torch.cuda.get_device_properties(0).major}.{torch.cuda.get_device_properties(0).minor}",
        "gpu_mem_gb": round(torch.cuda.get_device_properties(0).total_memory / 1e9, 2),
    })
elif RUNTIME_GPU_VISIBLE and not HAS_CUDA_TORCH:
    runtime_info["message"] = "GPU runtime is present but torch is CPU-only"
elif not RUNTIME_GPU_VISIBLE:
    runtime_info["message"] = "CUDA device not visible at runtime"
else:
    runtime_info["message"] = "GPU runtime present but torch CUDA is unavailable"

env_info = {
    "timestamp_utc": datetime.now(timezone.utc).isoformat(),
    "python": platform.python_version(),
    "platform": platform.platform(),
    "seed": SEED,
    "is_colab": bool(globals().get("IS_COLAB", False)),
    "torch_version": torch.__version__,
    "torch_cuda_version": torch.version.cuda,
    "has_cuda_runtime": RUNTIME_GPU_VISIBLE,
    "has_cuda_torch": HAS_CUDA_TORCH,
    "cuda_ready": HAS_CUDA_READY,
    "pinned_packages": PINNED_PACKAGES,
    **runtime_info,
}

print(json.dumps(env_info, indent=2))

if not HAS_CUDA_READY:
    if RUNTIME_GPU_VISIBLE and not HAS_CUDA_TORCH:
        hint = (
            "GPU runtime is attached, but torch is CPU-only. Reinstall CUDA torch and restart runtime:\n"
            "python -m pip install --index-url https://download.pytorch.org/whl/cu124 --force-reinstall --no-cache-dir torch"
        )
    elif not RUNTIME_GPU_VISIBLE:
        hint = (
            "Colab attached a CPU VM (GPU is not visible via nvidia-smi). "
            "Choose Runtime -> Change runtime type -> GPU, then Runtime -> Restart session. "
            "If it still comes back CPU, quota/availability is likely exhausted right now."
        )
    elif HAS_CUDA_TORCH and not TORCH_CUDA_AVAILABLE:
        hint = "GPU is present and torch has CUDA build, but CUDA failed to initialize. Restart runtime and rerun from the first cell."
    else:
        hint = "Unknown CUDA mismatch. Restart runtime and retry."

    runtime_hint = "in Colab" if globals().get("IS_COLAB", False) else "in a local runtime"
    print(f"No GPU found ({runtime_hint}). {hint}")
    raise RuntimeError(f"No GPU found ({runtime_hint}). {hint}")



Dependency check: all pinned packages already present.
{
  "timestamp_utc": "2026-02-25T23:19:24.086736+00:00",
  "python": "3.12.12",
  "platform": "Linux-6.6.113+-x86_64-with-glibc2.35",
  "seed": 20260225,
  "is_colab": true,
  "torch_version": "2.10.0+cu128",
  "torch_cuda_version": "12.8",
  "has_cuda_runtime": true,
  "has_cuda_torch": true,
  "cuda_ready": true,
  "pinned_packages": [
    "transformers==4.46.3",
    "tokenizers==0.20.3",
    "accelerate==1.1.1",
    "bitsandbytes==0.45.0",
    "huggingface_hub==0.26.2",
    "sentencepiece==0.2.0",
    "protobuf==5.28.3",
    "matplotlib==3.9.2",
    "numpy==2.1.2"
  ],
  "runtime_gpu_visible": true,
  "runtime_gpus": [
    "NVIDIA A100-SXM4-80GB"
  ],
  "torch_cuda_available": true,
  "gpu_name": "NVIDIA A100-SXM4-80GB",
  "compute_capability": "8.0",
  "gpu_mem_gb": 85.09
}


In [5]:
# GPU sanity probe (run anytime execution stalls or after env setup)
import shutil
import subprocess
import platform
import torch

print("--- Runtime diagnostics ---")
print(f"Python: {platform.python_version()}")
print(f"Torch: {torch.__version__} | torch.cuda: {torch.version.cuda}")
print(f"torch.cuda.is_available(): {torch.cuda.is_available()}")
print(f"Detected CUDA devices: {torch.cuda.device_count()}")

if torch.cuda.is_available():
    idx = torch.cuda.current_device()
    props = torch.cuda.get_device_properties(idx)
    print(f"Current device: {idx}")
    print(f"GPU name: {torch.cuda.get_device_name(idx)}")
    print(f"Compute capability: {props.major}.{props.minor}")
    print(f"Total memory: {props.total_memory / 1e9:.2f} GB")
else:
    print("CUDA runtime not exposed to Torch.")

if shutil.which("nvidia-smi"):
    try:
        out = subprocess.run(["nvidia-smi", "--query-gpu=name,memory.total,memory.used", "--format=csv,noheader,nounits"],
                             capture_output=True, text=True, check=False)
        print("nvidia-smi:\n" + (out.stdout.strip() or "<no output>"))
    except Exception as exc:
        print(f"nvidia-smi failed: {exc}")
else:
    print("nvidia-smi not available in PATH.")



--- Runtime diagnostics ---
Python: 3.12.12
Torch: 2.10.0+cu128 | torch.cuda: 12.8
torch.cuda.is_available(): True
Detected CUDA devices: 1
Current device: 0
GPU name: NVIDIA A100-SXM4-80GB
Compute capability: 8.0
Total memory: 85.09 GB
nvidia-smi:
NVIDIA A100-SXM4-80GB, 81920, 6


In [6]:
import os
from huggingface_hub import login

HF_TOKEN = os.environ.get("HF_TOKEN")
if HF_TOKEN is None:
    try:
        from google.colab import userdata  # type: ignore
        HF_TOKEN = userdata.get("HF_TOKEN")
    except Exception:
        HF_TOKEN = None

if HF_TOKEN:
    try:
        login(token=HF_TOKEN, add_to_git_credential=False)
        print("HF login: authenticated from env var or Colab secret")
    except Exception as e:
        print(f"HF login skipped: token present but login failed ({type(e).__name__}: {e})")
else:
    print("HF_TOKEN not found.")
    print("Notebook will continue; add token later only if model loading prompts for auth.")
    print("To authenticate: set HF_TOKEN in environment, Colab secret, or run huggingface_hub.login manually when prompted.")


HF_TOKEN not found.
Notebook will continue; add token later only if model loading prompts for auth.
To authenticate: set HF_TOKEN in environment, Colab secret, or run huggingface_hub.login manually when prompted.


In [7]:
import os

try:
    from google.colab import userdata
    if not os.getenv("HF_TOKEN"):
        os.environ["HF_TOKEN"] = userdata.get("HF_TOKEN")
except Exception:
    pass

print("HF_TOKEN loaded:", bool(os.getenv("HF_TOKEN")))


HF_TOKEN loaded: False


In [8]:
import os, getpass
from huggingface_hub import login

if not os.getenv("HF_TOKEN"):
    os.environ["HF_TOKEN"] = getpass.getpass("HF_TOKEN: ").strip()  # hidden input

login(token=os.environ["HF_TOKEN"], add_to_git_credential=False)
print("HF_TOKEN set:", bool(os.getenv("HF_TOKEN")))


Note: Environment variable`HF_TOKEN` is set and is the current active token independently from the token you've just configured.


HF_TOKEN set: True


## 1. Load Model

In [9]:
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.utils import logging as hf_logging
import torch

hf_logging.set_verbosity_error()  # silence sampling-config warnings

MODEL_ID = "meta-llama/Llama-3.1-8B-Instruct"
MODEL_DTYPE = torch.bfloat16 if torch.cuda.is_available() else torch.float32

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=MODEL_DTYPE,
    device_map="auto",
    attn_implementation="eager",  # needed for output_attentions=True baselines
)
model.eval()

# Greedy decoding; reset temperature/top_p to neutral values so unused sampling settings do not trigger warnings.
model.generation_config.do_sample = False
model.generation_config.temperature = 1.0
model.generation_config.top_p = 1.0

if tokenizer.pad_token_id is None:
    tokenizer.pad_token = tokenizer.eos_token
model.config.pad_token_id = tokenizer.pad_token_id

assert tokenizer.is_fast, "Need fast tokenizer for offset_mapping support"

print(f"Model loaded: {sum(p.numel() for p in model.parameters()) / 1e9:.1f}B params")
print(f"Pad token: {tokenizer.pad_token} (id={tokenizer.pad_token_id})")


Error while fetching `HF_TOKEN` secret value from your vault: 'Requesting secret HF_TOKEN timed out. Secrets can only be fetched when running from the Colab UI.'.
You are not authenticated with the Hugging Face Hub in this notebook.
If the error persists, please let us know by opening an issue on GitHub (https://github.com/huggingface/huggingface_hub/issues/new).



Model loaded: 8.0B params
Pad token: <|eot_id|> (id=128009)


## 2. Core Functions

In [12]:
import numpy as np
import re
from typing import List, Optional
from dataclasses import dataclass


@torch.no_grad()
def generate(model, tokenizer, text: str, max_new_tokens: int = 80) -> str:
    """Deterministic generation from raw text."""
    inputs = tokenizer(text, return_tensors="pt", add_special_tokens=True).to(model.device)
    outputs = model.generate(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        max_new_tokens=max_new_tokens,
        do_sample=False,
        pad_token_id=tokenizer.pad_token_id,
    )
    new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True)


@torch.no_grad()
def build_chunk_masked_inputs(
    model,
    tokenizer,
    full_text: str,
    query: str,
    chunk_ids: List[str],
    chunk_char_spans: List[tuple],
    evicted_ids: set,
) -> dict:
    """Tokenize once and return masked inputs that preserve original position IDs."""
    combined = full_text + "\n\n" + query
    enc = tokenizer(
        combined,
        return_tensors="pt",
        add_special_tokens=True,
        return_offsets_mapping=True,
    )

    raw_offsets = enc.pop("offset_mapping")
    if hasattr(raw_offsets, "tolist"):
        raw_offsets = raw_offsets.tolist()
    if raw_offsets and isinstance(raw_offsets[0], (list, tuple)) and raw_offsets[0] and isinstance(raw_offsets[0][0], (list, tuple)):
        raw_offsets = raw_offsets[0]
    offsets = [tuple(pair) for pair in raw_offsets]
    input_ids = enc["input_ids"].to(model.device)
    attention_mask = enc["attention_mask"].clone().to(model.device)
    position_ids = torch.arange(input_ids.shape[1], device=model.device).unsqueeze(0)

    evicted_spans = [span for cid, span in zip(chunk_ids, chunk_char_spans) if cid in evicted_ids]
    context_end = len(full_text)
    context_tokens = 0
    masked_tokens = 0

    for tok_idx, (tok_start, tok_end) in enumerate(offsets):
        if tok_end <= tok_start:
            continue
        if tok_start >= context_end:
            continue

        context_tokens += 1
        for ch_start, ch_end in evicted_spans:
            if tok_end > ch_start and tok_start < ch_end:
                attention_mask[0, tok_idx] = 0
                masked_tokens += 1
                break

    if attention_mask.shape[1] > 0:
        attention_mask[0, 0] = 1  # never mask BOS

    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "position_ids": position_ids,
        "masked_tokens": int(masked_tokens),
        "context_tokens": int(context_tokens),
    }


@torch.no_grad()
def generate_with_chunk_mask(
    model,
    tokenizer,
    full_text: str,
    query: str,
    chunk_ids: List[str],
    chunk_char_spans: List[tuple],
    evicted_ids: set,
    max_new_tokens: int = 80,
) -> tuple:
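    """
    Approximate KV-style eviction without retokenizing surviving text:
    keep the full token stream, but zero evicted-chunk tokens in attention_mask.

    Decoding is a manual greedy loop that passes explicit position_ids at every
    step: prompt tokens keep their original RoPE positions and each generated
    token takes the next one. A single position_ids tensor handed to
    model.generate() is not reliably advanced between decode steps.
    """
    from transformers import DynamicCache  # local import keeps this fix self-contained

    mask_state = build_chunk_masked_inputs(
        model=model,
        tokenizer=tokenizer,
        full_text=full_text,
        query=query,
        chunk_ids=chunk_ids,
        chunk_char_spans=chunk_char_spans,
        evicted_ids=evicted_ids,
    )

    input_ids = mask_state["input_ids"]
    attention_mask = mask_state["attention_mask"]
    position_ids = mask_state["position_ids"]

    # The generation config may list several EOS ids; accept an int or a list.
    eos = model.generation_config.eos_token_id
    eos_ids = set(eos) if isinstance(eos, (list, tuple)) else {eos}

    past = DynamicCache()
    step_ids, step_pos = input_ids, position_ids
    next_pos = input_ids.shape[1]
    new_tokens = []
    for _ in range(max_new_tokens):
        out = model(
            input_ids=step_ids,
            attention_mask=attention_mask,  # covers cached + current tokens
            position_ids=step_pos,
            past_key_values=past,
            use_cache=True,
        )
        past = out.past_key_values
        next_tok = out.logits[:, -1, :].argmax(dim=-1, keepdim=True)  # greedy, like do_sample=False
        new_tokens.append(int(next_tok.item()))
        if new_tokens[-1] in eos_ids:
            break
        # The generated token is always visible and takes the next original position.
        attention_mask = torch.cat([attention_mask, attention_mask.new_ones((1, 1))], dim=-1)
        step_ids = next_tok
        step_pos = torch.tensor([[next_pos]], device=input_ids.device)
        next_pos += 1

    text = tokenizer.decode(new_tokens, skip_special_tokens=True)
    return text, int(mask_state["masked_tokens"]), int(mask_state["context_tokens"])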


@torch.no_grad()
def get_chunk_attention(
    model,
    tokenizer,
    full_text: str,
    query: Optional[str],
    chunk_char_spans: List[tuple],
    pool: str = "max",
    probe_tokens: int = 64,
) -> List[float]:
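    """
    Compute per-chunk salience from attention.

    If query is provided: score with the query tokens' attention (a query-aware, oracle-style baseline).
    If query is None: score with the last `probe_tokens` context tokens as probes (a streaming-style proxy).

    Attention is averaged over heads, layers, and probe tokens, then pooled per chunk.
    pool options: "max" (default), "mean", "sum".
    """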
    combined = full_text if query is None else (full_text + "\n\n" + query)
    enc = tokenizer(
        combined,
        return_tensors="pt",
        add_special_tokens=True,
        return_offsets_mapping=True,
    )

    raw_offsets = enc.pop("offset_mapping")
    if hasattr(raw_offsets, "tolist"):
        raw_offsets = raw_offsets.tolist()
    if raw_offsets and isinstance(raw_offsets[0], (list, tuple)) and raw_offsets[0] and isinstance(raw_offsets[0][0], (list, tuple)):
        raw_offsets = raw_offsets[0]
    offsets = [tuple(pair) for pair in raw_offsets]
    inputs = {k: v.to(model.device) for k, v in enc.items()}

    outputs = model(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        output_attentions=True,
    )

    seq_len = outputs.attentions[0].shape[-1]
    n_layers = len(outputs.attentions)
    context_end = len(full_text)

    valid_tok_idxs = [
        i for i, (s, e) in enumerate(offsets)
        if e > s
    ]
    context_tok_idxs = [
        i for i, (s, e) in enumerate(offsets)
        if e > s and s < context_end
    ]

    if query is None:
        probe_idxs = context_tok_idxs[-min(probe_tokens, len(context_tok_idxs)):] if context_tok_idxs else valid_tok_idxs[-1:]
    else:
        probe_idxs = [
            i for i, (s, e) in enumerate(offsets)
            if e > s and s >= context_end
        ]
        if not probe_idxs:
            probe_idxs = context_tok_idxs[-min(probe_tokens, len(context_tok_idxs)):] if context_tok_idxs else valid_tok_idxs[-1:]

    attn_sum = torch.zeros(seq_len, device=model.device)
    for layer_attn in outputs.attentions:
        # layer_attn: [batch=1, heads, q_len, k_len]
        probe_attn = layer_attn[0, :, probe_idxs, :].mean(dim=0)  # [probe, k_len]
        attn_sum += probe_attn.sum(dim=0)

    denom = max(1, n_layers * len(probe_idxs))
    attn_avg = (attn_sum / denom).detach().cpu().numpy()

    chunk_scores = []
    for char_start, char_end in chunk_char_spans:
        tok_vals = []
        for tok_idx, (tok_start, tok_end) in enumerate(offsets):
            if tok_end <= tok_start:
                continue
            if tok_end > char_start and tok_start < char_end:
                tok_vals.append(float(attn_avg[tok_idx]))

        if not tok_vals:
            chunk_scores.append(0.0)
        elif pool == "sum":
            chunk_scores.append(float(np.sum(tok_vals)))
        elif pool == "mean":
            chunk_scores.append(float(np.mean(tok_vals)))
        else:
            chunk_scores.append(float(np.max(tok_vals)))  # default: preserve peak signal

    return chunk_scores


print("Core functions defined ✓")


Core functions defined ✓
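
`build_chunk_masked_inputs` and `get_chunk_attention` both decide chunk membership with the same half-open overlap test on character offsets. A token `[s, e)` belongs to a chunk `[a, b)` when `e > a and s < b`, and zero-width tokens such as BOS (offset `(0, 0)`) are skipped. The standalone sketch below reproduces that test on hand-written offsets (not real tokenizer output) and applies the three pooling modes. `tokens_in_span` and `pool` are illustrative names.

```python
from typing import List, Sequence, Tuple

Span = Tuple[int, int]

def tokens_in_span(offsets: Sequence[Span], span: Span) -> List[int]:
    """Indices of non-empty tokens whose [start, end) overlaps the chunk span."""
    a, b = span
    return [i for i, (s, e) in enumerate(offsets) if e > s and e > a and s < b]

def pool(values: List[float], mode: str = "max") -> float:
    """Chunk-level pooling; empty chunks score 0.0, as in get_chunk_attention."""
    if not values:
        return 0.0
    if mode == "sum":
        return sum(values)
    if mode == "mean":
        return sum(values) / len(values)
    return max(values)

# Hand-written offsets for the text "AAA BB\nCC D": BOS, "AAA", " BB", "\n", "CC", " D".
offsets = [(0, 0), (0, 3), (3, 6), (6, 7), (7, 9), (9, 11)]
chunk_spans = [(0, 6), (7, 11)]  # the "\n" separator at char 6 belongs to neither chunk
token_attn = [0.5, 0.25, 0.5, 0.0625, 0.125, 0.125]

for span in chunk_spans:
    idxs = tokens_in_span(offsets, span)
    vals = [token_attn[i] for i in idxs]
    print(span, idxs, {m: pool(vals, m) for m in ("max", "mean", "sum")})
# (0, 6) [1, 2] {'max': 0.5, 'mean': 0.375, 'sum': 0.75}
# (7, 11) [4, 5] {'max': 0.125, 'mean': 0.125, 'sum': 0.25}
```

Under this test, a token that covers only the separator newline overlaps neither span, so it is never masked.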


## 3. Structured Prompts with Nonce Codes

In [13]:
import hashlib
import random as stdlib_random


TOTAL_LINES = 200
K = 3

QUERY = """Identify the highest-severity valid incident. An incident is valid only if exactly 3 prior warnings for that incident appear in the log. Return the incident and its three warnings with their verification codes.

Output format requirements:
- Output exactly 4 lines and no extra prose.
- Use these exact labels: ROOT_CAUSE, WARNING_1, WARNING_2, WARNING_3.
- Each value must be a 6-character alphanumeric verification code from the log.

ROOT_CAUSE: <6-char code>
WARNING_1: <6-char code>
WARNING_2: <6-char code>
WARNING_3: <6-char code>"""

VARIANT_LAYOUTS = [
    {
        "variant": "base_30_110",
        "description": "Primary arc centered early (approx lines 30-110).",
        "primary_lines": {"W1": 32, "W2": 55, "W3": 78, "P": 110},
    },
    {
        "variant": "shift_15_95",
        "description": "Primary arc shifted earlier (approx lines 15-95).",
        "primary_lines": {"W1": 15, "W2": 40, "W3": 68, "P": 95},
    },
    {
        "variant": "shift_50_130",
        "description": "Primary arc shifted later (approx lines 50-130).",
        "primary_lines": {"W1": 50, "W2": 75, "W3": 102, "P": 130},
    },
]

DECOY_LINES = {"W1": 156, "W2": 162, "W3": 168, "P": 175}


def make_nonce(seed_str: str) -> str:
    return hashlib.sha256(seed_str.encode()).hexdigest()[:6].upper()


def noise_text(line_no: int) -> str:
    routine = [
        "heartbeat check passed",
        "directory sync completed",
        "endpoint telemetry batch uploaded",
        "nightly policy bundle validated",
        "service mesh latency within baseline",
    ]
    mid = [
        "interactive login observed",
        "backup verification completed",
        "scheduled patch window opened",
        "IAM token refresh succeeded",
        "routine vulnerability scan queued",
    ]
    evening = [
        "change-control ticket archived",
        "regional mirror lag normalized",
        "SIEM digest exported",
        "off-peak replication checkpointed",
        "maintenance daemon heartbeat stable",
    ]
    generic = [
        "audit log rotation completed",
        "SLA probe response nominal",
        "queue depth returned to baseline",
        "diagnostic ping timeout recovered",
        "routine service health reported",
    ]

    if line_no <= 30:
        msg = routine[line_no % len(routine)]
    elif 111 <= line_no <= 154:
        msg = mid[line_no % len(mid)]
    elif 176 <= line_no <= 200:
        msg = evening[line_no % len(evening)]
    else:
        msg = generic[line_no % len(generic)]

    return f"L{line_no:03d} INFO {msg}."


def build_variant_prompt(layout: dict):
    variant = layout["variant"]
    primary = layout["primary_lines"]

    primary_nonces = {
        "ROOT_CAUSE": make_nonce(f"{variant}_primary_root"),
        "WARNING_1": make_nonce(f"{variant}_primary_w1"),
        "WARNING_2": make_nonce(f"{variant}_primary_w2"),
        "WARNING_3": make_nonce(f"{variant}_primary_w3"),
    }
    decoy_nonces = {
        "ROOT_CAUSE": make_nonce(f"{variant}_decoy_root"),
        "WARNING_1": make_nonce(f"{variant}_decoy_w1"),
        "WARNING_2": make_nonce(f"{variant}_decoy_w2"),
        "WARNING_3": make_nonce(f"{variant}_decoy_w3"),
    }

    by_line = {}

    # Primary arc (severity 95)
    by_line[primary["W1"]] = {
        "role": "predecessor",
        "event_type": "warning",
        "severity": None,
        "text": f"L{primary['W1']:03d} WARN phishing payload executed on HR-LAPTOP (verification code: {primary_nonces['WARNING_1']}).",
    }
    by_line[primary["W2"]] = {
        "role": "predecessor",
        "event_type": "warning",
        "severity": None,
        "text": f"L{primary['W2']:03d} WARN lateral movement detected to DB-SERVER (verification code: {primary_nonces['WARNING_2']}).",
    }
    by_line[primary["W3"]] = {
        "role": "predecessor",
        "event_type": "warning",
        "severity": None,
        "text": f"L{primary['W3']:03d} WARN admin privilege escalation confirmed on DB-SERVER (verification code: {primary_nonces['WARNING_3']}).",
    }
    by_line[primary["P"]] = {
        "role": "pivot",
        "event_type": "pivot",
        "severity": 95.0,
        "text": f"L{primary['P']:03d} CRITICAL APT data breach declared; severity score 95; 2.1TB exfiltrated (verification code: {primary_nonces['ROOT_CAUSE']}).",
    }

    # Decoy arc (severity 38)
    by_line[DECOY_LINES["W1"]] = {
        "role": "decoy_predecessor",
        "event_type": "warning",
        "severity": None,
        "text": f"L{DECOY_LINES['W1']:03d} WARN CPU temperature spike on cache node (verification code: {decoy_nonces['WARNING_1']}).",
    }
    by_line[DECOY_LINES["W2"]] = {
        "role": "decoy_predecessor",
        "event_type": "warning",
        "severity": None,
        "text": f"L{DECOY_LINES['W2']:03d} WARN fan RPM dropping, thermal alert triggered (verification code: {decoy_nonces['WARNING_2']}).",
    }
    by_line[DECOY_LINES["W3"]] = {
        "role": "decoy_predecessor",
        "event_type": "warning",
        "severity": None,
        "text": f"L{DECOY_LINES['W3']:03d} WARN thermal throttling engaged on cache node (verification code: {decoy_nonces['WARNING_3']}).",
    }
    by_line[DECOY_LINES["P"]] = {
        "role": "decoy_pivot",
        "event_type": "pivot",
        "severity": 38.0,
        "text": f"L{DECOY_LINES['P']:03d} CRITICAL cache node overheated; severity score 38; service degraded (verification code: {decoy_nonces['ROOT_CAUSE']}).",
    }

    sentences = []
    for line_no in range(1, TOTAL_LINES + 1):
        if line_no in by_line:
            row = by_line[line_no]
            sentences.append({
                "id": f"{variant}_L{line_no:03d}",
                "line_no": line_no,
                "text": row["text"],
                "role": row["role"],
                "event_type": row["event_type"],
                "severity": row["severity"],
            })
        else:
            sentences.append({
                "id": f"{variant}_L{line_no:03d}",
                "line_no": line_no,
                "text": noise_text(line_no),
                "role": "noise",
                "event_type": "noise",
                "severity": None,
            })

    return {
        "name": f"CyberOps Log ({variant})",
        "variant": variant,
        "variant_description": layout["description"],
        "sentences": sentences,
        "query": QUERY,
        "k": K,
        "primary_nonces": primary_nonces,
        "decoy_nonces": decoy_nonces,
        "primary_lines": primary,
        "decoy_lines": dict(DECOY_LINES),
    }


def build_text_and_spans(sentences):
    parts, spans, pos = [], [], 0
    sep = "\n"
    for i, sent in enumerate(sentences):
        text = sent["text"]
        if i > 0:
            parts.append(sep)
            pos += len(sep)
        start = pos
        parts.append(text)
        pos += len(text)
        spans.append((start, pos))
    return "".join(parts), spans


PROMPTS = [build_variant_prompt(layout) for layout in VARIANT_LAYOUTS]

for p in PROMPTS:
    sents = p["sentences"]
    full_text, spans = build_text_and_spans(sents)
    ids = [s["id"] for s in sents]
    assert len(ids) == len(set(ids)), f"Duplicate IDs in {p['variant']}"

    for sent, (start, end) in zip(sents, spans):
        assert full_text[start:end] == sent["text"], f"Span mismatch at {sent['id']}"

    for lbl, code in p["primary_nonces"].items():
        assert code in full_text, f"Missing primary nonce {lbl} in {p['variant']}"
    for lbl, code in p["decoy_nonces"].items():
        assert code in full_text, f"Missing decoy nonce {lbl} in {p['variant']}"

    n_tok = len(tokenizer.encode(full_text, add_special_tokens=False))
    print(f"{p['variant']}: {len(sents)} lines, {n_tok} context tokens, primary@{p['primary_lines']} decoy@{p['decoy_lines']} ✓")

print("\nPrompt variants verified ✓")


base_30_110: 200 lines, 1737 context tokens, primary@{'W1': 32, 'W2': 55, 'W3': 78, 'P': 110} decoy@{'W1': 156, 'W2': 162, 'W3': 168, 'P': 175} ✓
shift_15_95: 200 lines, 1735 context tokens, primary@{'W1': 15, 'W2': 40, 'W3': 68, 'P': 95} decoy@{'W1': 156, 'W2': 162, 'W3': 168, 'P': 175} ✓
shift_50_130: 200 lines, 1736 context tokens, primary@{'W1': 50, 'W2': 75, 'W3': 102, 'P': 130} decoy@{'W1': 156, 'W2': 162, 'W3': 168, 'P': 175} ✓

Prompt variants verified ✓


## 4. Eviction Policies & Label-Sensitive Scorer
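
The scorer's own definition lies outside this excerpt. As a hypothetical sketch of what a label-sensitive check could look like, the code below parses the four-label format requested in `QUERY` and sorts answers into three groups: correct answers, well-formed answers that name the decoy root cause (the mirage), and malformed output. The codes are placeholders; in the notebook they would come from each prompt's `primary_nonces` and `decoy_nonces`. The parser tolerates extra prose lines, which `QUERY` forbids, so a stricter variant could reject them.

```python
import re
from typing import Dict, Optional

LABELS = ("ROOT_CAUSE", "WARNING_1", "WARNING_2", "WARNING_3")
LINE_RE = re.compile(r"^\s*(ROOT_CAUSE|WARNING_[123])\s*:\s*([A-Za-z0-9]{6})\s*$")

def parse_answer(text: str) -> Optional[Dict[str, str]]:
    """Return {label: CODE} if every label appears exactly once, else None."""
    found: Dict[str, str] = {}
    for line in text.strip().splitlines():
        m = LINE_RE.match(line)
        if not m:
            continue  # lenient: ignore lines that are not label lines
        label, code = m.group(1), m.group(2).upper()
        if label in found:
            return None  # duplicate label: treat as malformed
        found[label] = code
    return found if set(found) == set(LABELS) else None

def score_answer(text: str, primary: Dict[str, str], decoy: Dict[str, str]) -> str:
    """Hypothetical label-sensitive score: correct / mirage_decoy / wrong / malformed."""
    parsed = parse_answer(text)
    if parsed is None:
        return "malformed"
    if all(parsed[lbl] == primary[lbl] for lbl in LABELS):
        return "correct"
    if parsed["ROOT_CAUSE"] == decoy["ROOT_CAUSE"]:
        return "mirage_decoy"  # well-formed, but the wrong causal hypothesis
    return "wrong"

# Placeholder codes, not real nonces from the notebook.
primary = {"ROOT_CAUSE": "AAAAAA", "WARNING_1": "BBBBBB", "WARNING_2": "CCCCCC", "WARNING_3": "DDDDDD"}
decoy = {"ROOT_CAUSE": "EEEEEE", "WARNING_1": "FFFFFF", "WARNING_2": "111111", "WARNING_3": "222222"}

good = "ROOT_CAUSE: AAAAAA\nWARNING_1: BBBBBB\nWARNING_2: CCCCCC\nWARNING_3: DDDDDD"
mirage = "ROOT_CAUSE: EEEEEE\nWARNING_1: FFFFFF\nWARNING_2: 111111\nWARNING_3: 222222"
bad = "The root cause is AAAAAA."

for answer in (good, mirage, bad):
    print(score_answer(answer, primary, decoy))
# correct
# mirage_decoy
# malformed
```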

In [15]:
@dataclass
class Chunk:
    chunk_id: str
    text: str
    role: str
    event_type: str
    l1_weight: float
    l1_d_total: int
    l1_d_pre: int
    line_no: int
    char_start: int
    char_end: int
    token_count: int = 0
    prefix_attention_score: float = 0.0
    query_attention_score: float = 0.0


@dataclass
class L2Summary:
    d_total: int
    W: List[float]
    provenance: List[Optional[dict]]
    predecessor_tail: List[str]


def identity_summary(k: int) -> L2Summary:
    return L2Summary(
        d_total=0,
        W=[float('-inf')] * (k + 1),
        provenance=[None] * (k + 1),
        predecessor_tail=[],
    )


def chunk_summary(chunk: Chunk, k: int) -> L2Summary:
    W = [float('-inf')] * (k + 1)
    provenance = [None] * (k + 1)

    if np.isfinite(chunk.l1_weight):
        W[0] = float(chunk.l1_weight)
        provenance[0] = {
            "pivot_id": chunk.chunk_id,
            "pred_ids": [],
            "weight": float(chunk.l1_weight),
        }

    predecessor_tail = [chunk.chunk_id] if chunk.l1_d_total > 0 else []

    return L2Summary(
        d_total=min(k, chunk.l1_d_total),
        W=W,
        provenance=provenance,
        predecessor_tail=predecessor_tail,
    )


def compose_summaries(prev: L2Summary, incoming: L2Summary, k: int) -> L2Summary:
    """
    Theorem-style composition:
      W_new[j] = max(W_prev[j], W_incoming[max(0, j - d_total_prev)])
    with provenance tracking for W[k].
    """
    new_d_total = min(k, prev.d_total + incoming.d_total)
    new_W = [float('-inf')] * (k + 1)
    new_prov = [None] * (k + 1)

    for j in range(k + 1):
        cand_prev = prev.W[j]
        prov_prev = prev.provenance[j]

        idx = max(0, j - prev.d_total)
        cand_incoming = incoming.W[idx]
        prov_incoming = None

        if np.isfinite(cand_incoming) and incoming.provenance[idx] is not None:
            needed_from_prev = j - idx
            left_preds = prev.predecessor_tail[-needed_from_prev:] if needed_from_prev > 0 else []
            pred_ids = left_preds + incoming.provenance[idx]["pred_ids"]
            if j > 0:
                pred_ids = pred_ids[-j:]
            else:
                pred_ids = []

            prov_incoming = {
                "pivot_id": incoming.provenance[idx]["pivot_id"],
                "pred_ids": pred_ids,
                "weight": incoming.provenance[idx]["weight"],
            }

        if cand_incoming > cand_prev:
            new_W[j] = cand_incoming
            new_prov[j] = prov_incoming
        else:
            new_W[j] = cand_prev
            new_prov[j] = prov_prev

    new_tail = (prev.predecessor_tail + incoming.predecessor_tail)[-k:]

    return L2Summary(
        d_total=new_d_total,
        W=new_W,
        provenance=new_prov,
        predecessor_tail=new_tail,
    )


def l2_frontier_scan(chunks: List[Chunk], k: int) -> L2Summary:
    s = identity_summary(k)
    for c in chunks:
        s = compose_summaries(s, chunk_summary(c, k), k)
    return s


def l2_protected_set(chunks: List[Chunk], k: int):
    summary = l2_frontier_scan(chunks, k)
    prov = summary.provenance[k]

    protected = set()
    if prov is not None:
        protected.add(prov["pivot_id"])
        protected.update(prov["pred_ids"])

    raw_d_total = sum(c.l1_d_total for c in chunks)
    contract_rhs = min(raw_d_total, k)
    contract_ok = summary.d_total >= contract_rhs
    return protected, summary, contract_ok


# ── Eviction policies (token-budgeted) ─────────────────────────────────────

def _take_until_token_budget(ranked_chunks: List[Chunk], target_evict_tokens: int):
    evicted, total_tokens = [], 0
    for c in ranked_chunks:
        if total_tokens >= target_evict_tokens:
            break
        evicted.append(c)
        total_tokens += c.token_count
    return {c.chunk_id for c in evicted}, total_tokens


def evict_recency(chunks: List[Chunk], target_evict_tokens: int):
    """StreamingLLM-style: keep tail, evict earliest chunks first."""
    earliest_first = sorted(chunks, key=lambda c: c.line_no)
    return _take_until_token_budget(earliest_first, target_evict_tokens)


def evict_by_score(chunks: List[Chunk], target_evict_tokens: int, score_attr: str):
    """Evict the lowest-scoring chunks first, ranked ascending by `score_attr`."""
    ranked = sorted(chunks, key=lambda c: getattr(c, score_attr))
    return _take_until_token_budget(ranked, target_evict_tokens)


def evict_structure_aware(chunks: List[Chunk], target_evict_tokens: int, protected_ids: set):
    """
    L2-guarded policy: evict non-protected chunks first, lowest prefix-attention score first.
    Protected chunks are evicted (and reported as breached) only when the budget exceeds
    the total token count of all non-protected chunks.
    """
    non_protected = sorted(
        [c for c in chunks if c.chunk_id not in protected_ids],
        key=lambda c: c.prefix_attention_score,
    )
    protected = sorted(
        [c for c in chunks if c.chunk_id in protected_ids],
        key=lambda c: c.prefix_attention_score,
    )

    evicted_ids, evicted_tokens = _take_until_token_budget(non_protected, target_evict_tokens)
    breached = []

    if evicted_tokens < target_evict_tokens:
        overflow_target = target_evict_tokens - evicted_tokens
        overflow_ids, overflow_tokens = _take_until_token_budget(protected, overflow_target)
        evicted_ids |= overflow_ids
        evicted_tokens += overflow_tokens
        breached = [c.chunk_id for c in protected if c.chunk_id in overflow_ids]

    return evicted_ids, evicted_tokens, breached


def reconstruct(chunks: List[Chunk], evicted_ids: set) -> str:
    return "\n".join(c.text for c in chunks if c.chunk_id not in evicted_ids)


# ── Deterministic nonce scoring ─────────────────────────────────────────────

def parse_rca_lines(response: str) -> dict:
    """Parse model output into canonical labels, tolerating minor format drift."""
    pat = r'(?im)^\s*(?:[-*]\s*|\d+[\.)]\s*)?([A-Z][A-Z0-9 _-]*?)\s*[:\-]\s*\[?\s*([A-Z0-9]{6})\s*\]?\s*$'
    pairs = re.findall(pat, response)
    out = {}
    for raw_label, raw_code in pairs:
        norm_label = re.sub(r"[\s\-]+", "_", raw_label.strip().upper())
        norm_label = re.sub(r"_+", "_", norm_label)

        if norm_label in {"ROOTCAUSE", "ROOT_CAUSE"}:
            key = "ROOT_CAUSE"
        else:
            m = re.fullmatch(r"WARNING_?([123])", norm_label)
            if not m:
                continue
            key = f"WARNING_{m.group(1)}"

        if key not in out:
            out[key] = raw_code.upper()
    return out


def score_rca_nonce(response: str, primary_nonces: dict, decoy_nonces: dict) -> dict:
    labels = ["ROOT_CAUSE", "WARNING_1", "WARNING_2", "WARNING_3"]
    got = parse_rca_lines(response)

    raw_valid = all(lbl in got and re.fullmatch(r"[A-Z0-9]{6}", got[lbl]) for lbl in labels)
    pivot_preservation = int(raw_valid and got["ROOT_CAUSE"] == primary_nonces["ROOT_CAUSE"])

    primary_full = int(raw_valid and all(got[lbl] == primary_nonces[lbl] for lbl in labels))
    decoy_full = int(raw_valid and all(got[lbl] == decoy_nonces[lbl] for lbl in labels))
    semantic_valid = int(primary_full or decoy_full)

    mirage_detected = int(decoy_full)

    if primary_full:
        mode = "primary_correct"
    elif decoy_full:
        mode = "mirage_substitution"
    elif raw_valid and pivot_preservation == 0:
        mode = "valid_wrong"
    elif raw_valid:
        mode = "valid_partial"
    else:
        mode = "invalid_format"

    return {
        "raw_validity": int(raw_valid),
        "pivot_preservation": pivot_preservation,
        "mirage_detected": mirage_detected,
        "primary_full": primary_full,
        "decoy_full": decoy_full,
        "semantic_valid": semantic_valid,
        "got": got,
        "mode": mode,
    }


print("L2 frontier + policies + deterministic RCA scorer defined ✓")


# Scorer sanity checks
p = {"ROOT_CAUSE": "AAAAAA", "WARNING_1": "BBBBBB", "WARNING_2": "CCCCCC", "WARNING_3": "DDDDDD"}
d = {"ROOT_CAUSE": "111111", "WARNING_1": "222222", "WARNING_2": "333333", "WARNING_3": "444444"}

s1 = score_rca_nonce("ROOT_CAUSE: AAAAAA\nWARNING_1: BBBBBB\nWARNING_2: CCCCCC\nWARNING_3: DDDDDD", p, d)
assert s1["raw_validity"] == 1 and s1["primary_full"] == 1 and s1["mirage_detected"] == 0

s2 = score_rca_nonce("ROOT_CAUSE: 111111\nWARNING_1: 222222\nWARNING_2: 333333\nWARNING_3: 444444", p, d)
assert s2["raw_validity"] == 1 and s2["decoy_full"] == 1 and s2["primary_full"] == 0 and s2["mirage_detected"] == 1

s3 = score_rca_nonce("root_cause: aaaaaa\nwarning_1: bbbbbb\nwarning_2: cccccc\nwarning_3: dddddd", p, d)
assert s3["raw_validity"] == 1 and s3["primary_full"] == 1 and s3["mirage_detected"] == 0

s4 = score_rca_nonce("1) ROOT-CAUSE: AAAAAA\n2) WARNING 1: BBBBBB\n3) WARNING-2: CCCCCC\n4) WARNING_3: DDDDDD", p, d)
assert s4["raw_validity"] == 1 and s4["primary_full"] == 1 and s4["mirage_detected"] == 0

print("Scorer checks ✓")


L2 frontier + policies + deterministic RCA scorer defined ✓
Scorer checks ✓
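The L2 composition rule can be checked by hand on a toy log. The sketch below is a provenance-free reduction of `compose_summaries`: it keeps only `(d_total, W)` and drops `provenance` and `predecessor_tail`, so it shows *which weight* reaches `W[k]` but not *which chunks* to protect. The names `toy_chunk`, `toy_compose`, and `toy_scan` are hypothetical helpers for illustration, not part of the notebook.

```python
from functools import reduce
from typing import List, Tuple

NEG_INF = float("-inf")
ToySummary = Tuple[int, List[float]]  # (d_total capped at k, W[0..k]); hypothetical reduced state


def toy_chunk(kind: str, weight: float, k: int) -> ToySummary:
    """L1 state: a warning adds one unit of predecessor capacity; a pivot seeds W[0]."""
    W = [NEG_INF] * (k + 1)
    if kind == "pivot":
        W[0] = weight
    d_total = min(k, 1 if kind == "warning" else 0)
    return d_total, W


def toy_compose(prev: ToySummary, incoming: ToySummary, k: int) -> ToySummary:
    """W_new[j] = max(W_prev[j], W_incoming[max(0, j - d_total_prev)])."""
    d_prev, W_prev = prev
    d_in, W_in = incoming
    W_new = [max(W_prev[j], W_in[max(0, j - d_prev)]) for j in range(k + 1)]
    return min(k, d_prev + d_in), W_new


def toy_scan(summaries: List[ToySummary], k: int) -> ToySummary:
    """Left fold from the identity summary: one O(k) compose per chunk."""
    identity: ToySummary = (0, [NEG_INF] * (k + 1))
    return reduce(lambda acc, s: toy_compose(acc, s, k), summaries, identity)


k_demo = 3
toy_log = [("warning", 0.0), ("warning", 0.0), ("pivot", 5.0),
           ("warning", 0.0), ("pivot", 9.0)]
toy_summaries = [toy_chunk(kind, w, k_demo) for kind, w in toy_log]

# The 5.0 pivot has only two warnings before it, so it cannot reach W[3].
print(toy_scan(toy_summaries[:3], k_demo))  # (2, [5.0, 5.0, 5.0, -inf])
# The 9.0 pivot has three warnings before it and fills W[3].
print(toy_scan(toy_summaries, k_demo))      # (3, [9.0, 9.0, 9.0, 9.0])
```

In this reduced form `W[j]` is the largest pivot weight with at least `j` warnings before it, and `toy_compose` is associative, so the same fold could also be evaluated as a tree reduction.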


## 5. Run the Experiment

- Eviction fractions: `0, 20, 35, 50, 60, 70, 75, 80, 85, 90%`
- Policies: `Recency`, `H2O-proxy`, `Query-aware Attention`, `Structure-aware (L2-guarded)`
- Application channel: **attention-mask ablation with explicit position IDs (RoPE-aware simulation)**
- Variants: one per entry in `PROMPTS`; every variant is run under every policy at every eviction fraction
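Each eviction fraction is turned into a whole-token budget, and the greedy selector then evicts whole chunks until the budget is met, so realized eviction can overshoot the target by less than the size of the last chunk it evicts. The sketch below mirrors that arithmetic with plain integers; `token_budget` and `greedy_evict` are hypothetical stand-ins for the rounding in the run loop and for `_take_until_token_budget`, and the chunk sizes are made up.

```python
from typing import List, Tuple


def token_budget(evictable_tokens: int, frac: float) -> int:
    """Whole-token eviction target, clamped to [0, evictable_tokens].

    Python's round() rounds halves to even, so 1737 * 0.5 = 868.5 becomes 868.
    """
    budget = int(round(evictable_tokens * frac))
    return max(0, min(budget, evictable_tokens))


def greedy_evict(ranked_token_counts: List[int], target: int) -> Tuple[List[int], int]:
    """Evict whole chunks in rank order until at least `target` tokens are gone."""
    evicted, total = [], 0
    for idx, count in enumerate(ranked_token_counts):
        if total >= target:
            break
        evicted.append(idx)
        total += count
    return evicted, total


budget = token_budget(1737, 0.20)
print(budget)  # 347, the same target as the 20% line in the logged run
evicted, realized = greedy_evict([120, 150, 90, 200], budget)
print(evicted, realized)  # [0, 1, 2] 360 -> 13 tokens over target
```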


In [None]:
import math
import csv
import json
from pathlib import Path
from collections import defaultdict


def wilson_ci(successes: int, n: int, z: float = 1.96):
    if n == 0:
        return (0.0, 0.0)
    p = successes / n
    denom = 1 + z * z / n
    center = (p + z * z / (2 * n)) / denom
    margin = z * math.sqrt((p * (1 - p) + z * z / (4 * n)) / n) / denom
    return max(0.0, center - margin), min(1.0, center + margin)


def annotate_chunk_token_counts(tokenizer, full_text: str, query: str, chunks: List[Chunk]):
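    """
    Attribute each context token to the chunk span it overlaps most.

    Token budgeting and masking share one tokenization (context, blank-line
    separator, then query, with add_special_tokens=True). Zero-width special
    tokens and tokens that start inside the query are skipped.
    Returns (context_tokens, attributed_tokens, unassigned).
    """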
    combined = full_text + "\n\n" + query
    enc = tokenizer(
        combined,
        add_special_tokens=True,
        return_offsets_mapping=True,
    )
    raw_offsets = enc["offset_mapping"]
    if hasattr(raw_offsets, "tolist"):
        raw_offsets = raw_offsets.tolist()
    if raw_offsets and isinstance(raw_offsets[0], (list, tuple)) and raw_offsets[0] and isinstance(raw_offsets[0][0], (list, tuple)):
        raw_offsets = raw_offsets[0]
    offsets = [tuple(pair) for pair in raw_offsets]

    spans = [(c.char_start, c.char_end) for c in chunks]
    counts = [0] * len(chunks)
    unassigned = 0
    context_tokens = 0
    context_end = len(full_text)

    for tok_start, tok_end in offsets:
        if tok_end <= tok_start:
            continue
        if tok_start >= context_end:
            continue

        context_tokens += 1
        best_idx = None
        best_overlap = 0
        for idx, (ch_start, ch_end) in enumerate(spans):
            overlap = min(tok_end, ch_end) - max(tok_start, ch_start)
            if overlap > best_overlap:
                best_overlap = overlap
                best_idx = idx
        if best_idx is None or best_overlap <= 0:
            unassigned += 1
        else:
            counts[best_idx] += 1

    for c, cnt in zip(chunks, counts):
        c.token_count = cnt

    return context_tokens, sum(counts), unassigned


EVICTION_FRACS = [0.00, 0.20, 0.35, 0.50, 0.60, 0.70, 0.75, 0.80, 0.85, 0.90]

POLICY_NAMES = [
    "Recency",
    "H2O-proxy",
    "Query-aware Attention",
    "Structure-aware (L2-guarded)",
]

artifact_dir = ARTIFACT_DIR
artifact_dir.mkdir(parents=True, exist_ok=True)

results = []
summary_rows = []
retention_floors = {}
num_prompt_variants = len(PROMPTS)
if num_prompt_variants < 10:
    print(f"[WARNING] Only {num_prompt_variants} prompt variant(s): Wilson intervals will be wide; publishable comparisons need many more prompts.")


for prompt in PROMPTS:
    print(f"\n{'=' * 96}")
    print(f"{prompt['name']} :: {prompt['variant']} :: {prompt['variant_description']}")
    print(f"{'=' * 96}")

    full_text, char_spans = build_text_and_spans(prompt["sentences"])

    chunks = []
    for sent, (cs, ce) in zip(prompt["sentences"], char_spans):
        severity = sent["severity"] if sent["severity"] is not None else float('-inf')
        l1_d_total = 1 if sent["event_type"] == "warning" else 0
        l1_d_pre = l1_d_total

        chunks.append(Chunk(
            chunk_id=sent["id"],
            text=sent["text"],
            role=sent["role"],
            event_type=sent["event_type"],
            l1_weight=float(severity),
            l1_d_total=l1_d_total,
            l1_d_pre=l1_d_pre,
            line_no=sent["line_no"],
            char_start=cs,
            char_end=ce,
        ))

    context_tokens, evictable_tokens, unassigned = annotate_chunk_token_counts(tokenizer, full_text, prompt["query"], chunks)
    print(f"Context tokens={context_tokens}, attributed tokens={evictable_tokens}, unassigned={unassigned}")
    if evictable_tokens <= 0:
        print("  No evictable context tokens found after attribution; skipping this prompt.")
        continue

    # L2 protected set from numeric L1 states only
    protected_ids, l2_summary, contract_ok = l2_protected_set(chunks, prompt["k"])
    protected_lines = sorted([c.line_no for c in chunks if c.chunk_id in protected_ids])
    print(f"L2 contract check: {'OK' if contract_ok else 'FAIL'}")
    print(f"Protected lines (k={prompt['k']}): {protected_lines}")

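    # Eviction ceiling for the primary arc (stored in `retention_floors`): the largest
    # eviction fraction that still leaves room for every primary structural chunk.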
    primary_struct = [c for c in chunks if c.role in {"predecessor", "pivot"}]
    primary_struct_tokens = sum(c.token_count for c in primary_struct)
    floor = 1.0 - (primary_struct_tokens / max(1, evictable_tokens))
    retention_floors[prompt["variant"]] = floor
    print(f"Primary structural floor (safe eviction <=): {floor:.1%}")

    # One prefix-only and one query-aware attention pass per prompt variant
    prefix_scores = get_chunk_attention(
        model,
        tokenizer,
        full_text,
        None,
        [(c.char_start, c.char_end) for c in chunks],
        pool="max",
        probe_tokens=64,
    )
    query_scores = get_chunk_attention(
        model,
        tokenizer,
        full_text,
        prompt["query"],
        [(c.char_start, c.char_end) for c in chunks],
        pool="max",
        probe_tokens=64,
    )
    for c, ps, qs in zip(chunks, prefix_scores, query_scores):
        c.prefix_attention_score = ps
        c.query_attention_score = qs

    shown_invalid_example = False
    for frac in EVICTION_FRACS:
        budgeted_evict_tokens = int(round(evictable_tokens * frac))
        budgeted_evict_tokens = max(0, min(budgeted_evict_tokens, evictable_tokens))
        if frac > 0 and budgeted_evict_tokens == 0:
            print(f"\n  Fraction {frac:.0%}: target evict {budgeted_evict_tokens}/{evictable_tokens} tokens (rounds to zero; no-op).")
        else:
            print(f"\n  Fraction {frac:.0%}: target evict {budgeted_evict_tokens}/{evictable_tokens} tokens")

        for policy in POLICY_NAMES:
            if policy == "Recency":
                evicted_ids, actual_evict_tokens = evict_recency(chunks, budgeted_evict_tokens)
                breach_ids = []
            elif policy == "H2O-proxy":
                evicted_ids, actual_evict_tokens = evict_by_score(chunks, budgeted_evict_tokens, "prefix_attention_score")
                breach_ids = []
            elif policy == "Query-aware Attention":
                evicted_ids, actual_evict_tokens = evict_by_score(chunks, budgeted_evict_tokens, "query_attention_score")
                breach_ids = []
            else:
                evicted_ids, actual_evict_tokens, breach_ids = evict_structure_aware(chunks, budgeted_evict_tokens, protected_ids)

            response, realized_evict_tokens, _context_tokens = generate_with_chunk_mask(
                model,
                tokenizer,
                full_text,
                prompt["query"],
                [c.chunk_id for c in chunks],
                [(c.char_start, c.char_end) for c in chunks],
                evicted_ids,
                max_new_tokens=100,
            )
            score = score_rca_nonce(response, prompt["primary_nonces"], prompt["decoy_nonces"])

            realized_ret = (evictable_tokens - realized_evict_tokens) / max(1, evictable_tokens)

            row = {
                "variant": prompt["variant"],
                "variant_description": prompt["variant_description"],
                "policy": policy,
                "fraction": frac,
                "budgeted_evict_tokens": budgeted_evict_tokens,
                "target_evict_tokens": budgeted_evict_tokens,
                "actual_evict_tokens": realized_evict_tokens,
                "realized_token_retention": realized_ret,
                "evicted_chunk_ids": sorted(evicted_ids),
                "l2_breach_chunk_ids": breach_ids,
                "response": response,
                "raw_validity": score["raw_validity"],
                "pivot_preservation": score["pivot_preservation"],
                "primary_full": score["primary_full"],
                "decoy_full": score["decoy_full"],
                "mirage_detected": score["mirage_detected"],
                "mode": score["mode"],
                "got_nonces": score["got"],
                "primary_nonces": prompt["primary_nonces"],
                "decoy_nonces": prompt["decoy_nonces"],
            }
            results.append(row)

            print(
                f"    {policy:24s} | raw={score['raw_validity']} | pivot={score['pivot_preservation']} "
                f"| primary={score['primary_full']} | decoy={score['decoy_full']} | mirage={score['mirage_detected']} | ret={realized_ret:.1%} "
                f"| mode={score['mode']}"
            )
            if score["raw_validity"] == 0 and not shown_invalid_example:
                preview = " | ".join(line.strip() for line in response.strip().splitlines()[:6])
                print(f"      invalid-format example: {preview[:240]}")
                print(f"      parsed labels: {sorted(score['got'].keys())}")
                shown_invalid_example = True
            if policy == "Structure-aware (L2-guarded)" and breach_ids:
                print(f"      ⚠ L2 protected set breached: {breach_ids}")

# Aggregate summary: policy x fraction across prompt variants
for policy in POLICY_NAMES:
    for frac in EVICTION_FRACS:
        subset = [r for r in results if r["policy"] == policy and abs(r["fraction"] - frac) < 1e-12]
        if not subset:
            continue

        n = len(subset)
        raw_s = int(sum(r["raw_validity"] for r in subset))
        piv_s = int(sum(r["pivot_preservation"] for r in subset))
        prim_s = int(sum(r["primary_full"] for r in subset))
        dec_s = int(sum(r["decoy_full"] for r in subset))
        mir_s = int(sum(r["mirage_detected"] for r in subset))

        raw_lo, raw_hi = wilson_ci(raw_s, n)
        piv_lo, piv_hi = wilson_ci(piv_s, n)
        prim_lo, prim_hi = wilson_ci(prim_s, n)
        dec_lo, dec_hi = wilson_ci(dec_s, n)
        mir_lo, mir_hi = wilson_ci(mir_s, n)

        summary_rows.append({
            "policy": policy,
            "fraction": frac,
            "n": n,
            "raw_validity_mean": raw_s / n,
            "raw_validity_ci_lo": raw_lo,
            "raw_validity_ci_hi": raw_hi,
            "pivot_preservation_mean": piv_s / n,
            "pivot_preservation_ci_lo": piv_lo,
            "pivot_preservation_ci_hi": piv_hi,
            "primary_full_mean": prim_s / n,
            "primary_full_ci_lo": prim_lo,
            "primary_full_ci_hi": prim_hi,
            "decoy_full_mean": dec_s / n,
            "decoy_full_ci_lo": dec_lo,
            "decoy_full_ci_hi": dec_hi,
            "mirage_rate_mean": mir_s / n,
            "mirage_rate_ci_lo": mir_lo,
            "mirage_rate_ci_hi": mir_hi,
            "realized_token_retention_mean": float(np.mean([r["realized_token_retention"] for r in subset])),
        })

# Model revision hash best effort
model_revision = (
    getattr(model.config, "_commit_hash", None)
    or tokenizer.init_kwargs.get("_commit_hash")
    or tokenizer.init_kwargs.get("revision")
    or "unknown"
)

config = {
    "model_id": MODEL_ID,
    "model_revision": model_revision,
    "seed": SEED,
    "eviction_fractions": EVICTION_FRACS,
    "policies": POLICY_NAMES,
    "k": K,
    "application": "attention-mask ablation with explicit position IDs (RoPE-aware simulation)",
    "prompt_variants": [
        {
            "variant": p["variant"],
            "description": p["variant_description"],
            "primary_lines": p["primary_lines"],
            "decoy_lines": p["decoy_lines"],
        }
        for p in PROMPTS
    ],
    "env": env_info,
}

(artifact_dir / "config.json").write_text(json.dumps(config, indent=2))
with (artifact_dir / "results.jsonl").open("w") as f:
    for row in results:
        f.write(json.dumps(row) + "\n")
with (artifact_dir / "summary.csv").open("w", newline="") as f:
    writer = csv.DictWriter(
        f,
        fieldnames=[
            "policy", "fraction", "n",
            "raw_validity_mean", "raw_validity_ci_lo", "raw_validity_ci_hi",
            "pivot_preservation_mean", "pivot_preservation_ci_lo", "pivot_preservation_ci_hi",
            "primary_full_mean", "primary_full_ci_lo", "primary_full_ci_hi",
            "decoy_full_mean", "decoy_full_ci_lo", "decoy_full_ci_hi",
            "mirage_rate_mean", "mirage_rate_ci_lo", "mirage_rate_ci_hi",
            "realized_token_retention_mean",
        ],
    )
    writer.writeheader()
    for row in summary_rows:
        writer.writerow(row)

print("\n" + "=" * 96)
print("RUN COMPLETE")
print(f"Wrote: {(artifact_dir / 'config.json').as_posix()}")
print(f"Wrote: {(artifact_dir / 'results.jsonl').as_posix()}")
print(f"Wrote: {(artifact_dir / 'summary.csv').as_posix()}")
print("=" * 96)



CyberOps Log (base_30_110) :: base_30_110 :: Primary arc centered early (approx lines 30-110).
Context tokens=1737, attributed tokens=1737, unassigned=0
L2 contract check: OK
Protected lines (k=3): [32, 55, 78, 110]
Primary structural floor (safe eviction <=): 94.8%

  Fraction 0%: target evict 0/1737 tokens
    Recency                  | raw=0 | pivot=0 | primary=0 | decoy=0 | mirage=0 | ret=100.0% | mode=invalid_format
      parsed labels: []
    H2O-proxy                | raw=0 | pivot=0 | primary=0 | decoy=0 | mirage=0 | ret=100.0% | mode=invalid_format
    Query-aware Attention    | raw=0 | pivot=0 | primary=0 | decoy=0 | mirage=0 | ret=100.0% | mode=invalid_format
    Structure-aware (L2-guarded) | raw=0 | pivot=0 | primary=0 | decoy=0 | mirage=0 | ret=100.0% | mode=invalid_format

  Fraction 20%: target evict 347/1737 tokens
    Recency                  | raw=0 | pivot=0 | primary=0 | decoy=0 | mirage=0 | ret=80.0% | mode=invalid_format
    H2O-proxy                | raw=0 | pi
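`annotate_chunk_token_counts` assigns every context token to the chunk span it overlaps most, skips zero-width special tokens and tokens that start in the query, and counts tokens that overlap no chunk as unassigned. A toy version on hand-written character offsets (not real Llama 3.1 tokenizer output) shows how a boundary-straddling token is handled; `attribute_tokens` is a hypothetical name.

```python
from typing import List, Tuple

Span = Tuple[int, int]  # half-open character range [start, end)


def attribute_tokens(offsets: List[Span], spans: List[Span], context_end: int):
    """Count tokens per chunk by maximum character overlap."""
    counts, unassigned = [0] * len(spans), 0
    for tok_start, tok_end in offsets:
        if tok_end <= tok_start or tok_start >= context_end:
            continue  # zero-width special token, or a token in the query
        best_idx, best_overlap = None, 0
        for idx, (ch_start, ch_end) in enumerate(spans):
            overlap = min(tok_end, ch_end) - max(tok_start, ch_start)
            if overlap > best_overlap:
                best_idx, best_overlap = idx, overlap
        if best_idx is None:
            unassigned += 1  # e.g. a separator-only token between chunks
        else:
            counts[best_idx] += 1
    return counts, unassigned


# Context "alpha beta\ngamma" with two chunks; after "\n\n" the query starts at char 18.
spans = [(0, 10), (11, 16)]
offsets = [(0, 0), (0, 5), (5, 10), (10, 16), (18, 23)]
# (0, 0) is a BOS-style special token; (10, 16) is "\ngamma", which straddles
# the separator and lands in the second chunk; (18, 23) belongs to the query.
print(attribute_tokens(offsets, spans, context_end=16))  # ([2, 1], 0)
```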

## 6. Visualization

In [None]:
import matplotlib.pyplot as plt
import matplotlib
matplotlib.rcParams.update({'font.size': 11})

COLORS = {
    "Recency": "#c0392b",
    "H2O-proxy": "#d35400",
    "Query-aware Attention": "#8e44ad",
    "Structure-aware (L2-guarded)": "#1e8449",
}
MARKERS = {
    "Recency": "o",
    "H2O-proxy": "s",
    "Query-aware Attention": "^",
    "Structure-aware (L2-guarded)": "D",
}


def get_series(policy: str, metric: str, lo: str = None, hi: str = None):
    xs, ys, los, his = [], [], [], []
    for frac in EVICTION_FRACS:
        rows = [r for r in summary_rows if r["policy"] == policy and abs(r["fraction"] - frac) < 1e-12]
        if not rows:
            continue
        row = rows[0]
        xs.append(frac)
        ys.append(row[metric])
        if lo and hi:
            los.append(row[lo])
            his.append(row[hi])
    return xs, ys, los, his


# Plot 1: main 2-panel (raw validity + pivot preservation)
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

ax = axes[0]
for policy in POLICY_NAMES:
    x, y, lo, hi = get_series(policy, "raw_validity_mean", "raw_validity_ci_lo", "raw_validity_ci_hi")
    ax.plot(x, y, f"{MARKERS[policy]}-", color=COLORS[policy], label=policy, linewidth=2.3, markersize=7)
    ax.fill_between(x, lo, hi, color=COLORS[policy], alpha=0.12)
ax.set_title("Raw Validity vs Eviction")
ax.set_xlabel("Eviction Fraction")
ax.set_ylabel("Raw Validity")
ax.set_ylim(-0.05, 1.05)
ax.grid(True, alpha=0.3)
ax.legend(fontsize=8)

ax = axes[1]
for policy in POLICY_NAMES:
    x, y, lo, hi = get_series(policy, "pivot_preservation_mean", "pivot_preservation_ci_lo", "pivot_preservation_ci_hi")
    ax.plot(x, y, f"{MARKERS[policy]}-", color=COLORS[policy], label=policy, linewidth=2.3, markersize=7)
    ax.fill_between(x, lo, hi, color=COLORS[policy], alpha=0.12)
ax.set_title("Pivot Preservation vs Eviction")
ax.set_xlabel("Eviction Fraction")
ax.set_ylabel("Pivot Preservation")
ax.set_ylim(-0.05, 1.05)
ax.grid(True, alpha=0.3)
ax.legend(fontsize=8)

plt.tight_layout()
plot1_path = artifact_dir / "plot1_validity_vs_pivot.png"
plt.savefig(plot1_path, dpi=170, bbox_inches="tight")
plt.show()
print(f"Saved: {plot1_path.as_posix()}")

# Plot 2: mirage rate
fig, ax = plt.subplots(1, 1, figsize=(7, 5))
for policy in POLICY_NAMES:
    x, y, lo, hi = get_series(policy, "mirage_rate_mean", "mirage_rate_ci_lo", "mirage_rate_ci_hi")
    ax.plot(x, y, f"{MARKERS[policy]}-", color=COLORS[policy], label=policy, linewidth=2.3, markersize=7)
    ax.fill_between(x, lo, hi, color=COLORS[policy], alpha=0.12)
ax.set_title("Decoy Substitution Rate (strict coherent substitution)")
ax.set_xlabel("Eviction Fraction")
ax.set_ylabel("Mirage Rate")
ax.set_ylim(-0.05, 1.05)
ax.grid(True, alpha=0.3)
ax.legend(fontsize=8)
plt.tight_layout()
plot2_path = artifact_dir / "plot2_mirage_rate.png"
plt.savefig(plot2_path, dpi=170, bbox_inches="tight")
plt.show()
print(f"Saved: {plot2_path.as_posix()}")

# Plot 3: pivot preservation + retention floor annotation
floor_vals = list(retention_floors.values())
floor_mean = float(np.mean(floor_vals)) if floor_vals else 0.0

fig, ax = plt.subplots(1, 1, figsize=(8, 5))
for policy in POLICY_NAMES:
    x, y, lo, hi = get_series(policy, "pivot_preservation_mean", "pivot_preservation_ci_lo", "pivot_preservation_ci_hi")
    ax.plot(x, y, f"{MARKERS[policy]}-", color=COLORS[policy], label=policy, linewidth=2.3, markersize=7)
    ax.fill_between(x, lo, hi, color=COLORS[policy], alpha=0.10)

ax.axvline(floor_mean, color="black", linestyle="--", linewidth=1.4, label=f"primary-arc eviction ceiling (mean) ~ {floor_mean:.2f}")
ax.set_title("Pivot Preservation with Primary-Arc Eviction Ceiling")
ax.set_xlabel("Eviction Fraction")
ax.set_ylabel("Pivot Preservation")
ax.set_ylim(-0.05, 1.05)
ax.grid(True, alpha=0.3)
ax.legend(fontsize=8)

plt.tight_layout()
plot3_path = artifact_dir / "plot3_pivot_with_floor.png"
plt.savefig(plot3_path, dpi=170, bbox_inches="tight")
plt.show()
print(f"Saved: {plot3_path.as_posix()}")
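The dashed line in plot 3 sits at `1 - primary_struct_tokens / evictable_tokens`, averaged over variants. Because the guarded policy drains the non-protected pool before touching protected chunks, it breaches the protected set only when the rounded budget exceeds the non-protected total. A small sketch of that condition, assuming the protected set coincides with the primary structural chunks; `eviction_ceiling` and `guarded_policy_breaches` are hypothetical names and the token counts are made up.

```python
def eviction_ceiling(protected_tokens: int, evictable_tokens: int) -> float:
    """Largest eviction fraction that leaves every protected token in place."""
    return 1.0 - protected_tokens / max(1, evictable_tokens)


def guarded_policy_breaches(protected_tokens: int, evictable_tokens: int, frac: float) -> bool:
    """True when the rounded budget exceeds the non-protected token pool."""
    budget = int(round(evictable_tokens * frac))
    return budget > evictable_tokens - protected_tokens


# Hypothetical variant: 100 protected tokens out of 2000 evictable tokens.
print(eviction_ceiling(100, 2000))               # 0.95
print(guarded_policy_breaches(100, 2000, 0.90))  # False: 1800 <= 1900
print(guarded_policy_breaches(100, 2000, 0.96))  # True: 1920 > 1900
```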


## 7. Summary Tables


In [None]:
from collections import Counter, defaultdict

print("=" * 108)
print(f"TABLE 1: Aggregated Metrics by Policy x Eviction Fraction (n={len(PROMPTS)} variants, Wilson 95% CI)")
print("=" * 108)
print(f"{'Policy':28s} {'Frac':>6s} {'Raw':>16s} {'Primary':>16s} {'Decoy':>16s} {'Mirage':>16s} {'Ret':>8s}")
print("-" * 108)

for policy in POLICY_NAMES:
    for frac in EVICTION_FRACS:
        rows = [r for r in summary_rows if r["policy"] == policy and abs(r["fraction"] - frac) < 1e-12]
        if not rows:
            continue
        r = rows[0]
        raw = f"{r['raw_validity_mean']*100:5.1f}%[{r['raw_validity_ci_lo']*100:4.1f},{r['raw_validity_ci_hi']*100:4.1f}]"
        # Pivot preservation is omitted from this table; it is plotted above and written to summary.csv.
        prim = f"{r['primary_full_mean']*100:5.1f}%[{r['primary_full_ci_lo']*100:4.1f},{r['primary_full_ci_hi']*100:4.1f}]"
        dec = f"{r['decoy_full_mean']*100:5.1f}%[{r['decoy_full_ci_lo']*100:4.1f},{r['decoy_full_ci_hi']*100:4.1f}]"
        mir = f"{r['mirage_rate_mean']*100:5.1f}%[{r['mirage_rate_ci_lo']*100:4.1f},{r['mirage_rate_ci_hi']*100:4.1f}]"
        print(f"{policy:28s} {frac:6.0%} {raw:>16s} {prim:>16s} {dec:>16s} {mir:>16s} {r['realized_token_retention_mean']:8.1%}")

print()
print("=" * 108)
print("TABLE 2: Mode Counts (trial-level)")
print("=" * 108)
mode_counts = defaultdict(Counter)
for row in results:
    mode_counts[row["policy"]][row["mode"]] += 1
for policy in POLICY_NAMES:
    print(f"{policy}:")
    for mode, cnt in mode_counts[policy].most_common():
        print(f"  {mode:22s}: {cnt}")

print()
print("Primary-arc eviction ceilings by variant (max safe eviction fraction):")
for v, f in sorted(retention_floors.items()):
    print(f"  {v:18s}: {f:.2f}")

print()
print("Artifacts:")
print(f"  {(artifact_dir / 'config.json').as_posix()}")
print(f"  {(artifact_dir / 'results.jsonl').as_posix()}")
print(f"  {(artifact_dir / 'summary.csv').as_posix()}")
print(f"  {(artifact_dir / 'plot1_validity_vs_pivot.png').as_posix()}")
print(f"  {(artifact_dir / 'plot2_mirage_rate.png').as_posix()}")
print(f"  {(artifact_dir / 'plot3_pivot_with_floor.png').as_posix()}")


## 8. Interpretation & Limits

### What this demo establishes

- Structure-blind eviction can keep **raw validity** high while **pivot preservation** drops.
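- Mirage detection is deterministic: a trial counts as a mirage only when the output is format-valid (`raw_validity=1`) and all four fields (`ROOT_CAUSE`, `WARNING_1`, `WARNING_2`, `WARNING_3`) match the decoy nonces exactly; partial decoy matches are not counted.
- L2-guarded protection is derived from the provenance of `W[k]` in the frontier scan, not from matching label strings.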
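The strictness of the mirage criterion is easiest to see on parsed fields. The sketch below condenses the mode precedence of `score_rca_nonce` (parsing and the six-character format check are omitted; `classify_mode` is a hypothetical name) and shows that a decoy pivot paired with primary warnings is reported as `valid_wrong`, not as a mirage.

```python
from typing import Dict

LABELS = ("ROOT_CAUSE", "WARNING_1", "WARNING_2", "WARNING_3")


def classify_mode(got: Dict[str, str], primary: Dict[str, str], decoy: Dict[str, str]) -> str:
    """Precedence: invalid > primary > full decoy (mirage) > wrong pivot > partial."""
    if not all(lbl in got for lbl in LABELS):
        return "invalid_format"
    if all(got[lbl] == primary[lbl] for lbl in LABELS):
        return "primary_correct"
    if all(got[lbl] == decoy[lbl] for lbl in LABELS):
        return "mirage_substitution"
    if got["ROOT_CAUSE"] != primary["ROOT_CAUSE"]:
        return "valid_wrong"
    return "valid_partial"


primary = {"ROOT_CAUSE": "AAAAAA", "WARNING_1": "BBBBBB", "WARNING_2": "CCCCCC", "WARNING_3": "DDDDDD"}
decoy = {"ROOT_CAUSE": "111111", "WARNING_1": "222222", "WARNING_2": "333333", "WARNING_3": "444444"}

full_decoy = dict(decoy)
mixed = {**primary, "ROOT_CAUSE": decoy["ROOT_CAUSE"]}  # decoy pivot, primary warnings
print(classify_mode(full_decoy, primary, decoy))  # mirage_substitution
print(classify_mode(mixed, primary, decoy))       # valid_wrong
```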

### Limitations (explicit)

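1. This is **attention-mask ablation**, not literal `past_key_values` pruning. Masked tokens still go through the forward pass, so the demo measures information loss, not memory or latency savings.
2. The prompt set is intentionally small to keep the Colab runtime manageable, so the Wilson intervals are wide.
3. Query-aware attention scores chunks with the query already in hand, which a streaming cache does not have at eviction time. It is included as an oracle-style upper bound, not as a deployable streaming policy.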
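To gauge how wide the intervals are at demo scale, the Wilson score interval used by `wilson_ci` can be evaluated for a few small `n`. The function below restates the notebook's formula so the cell runs on its own; the name `wilson_interval` is hypothetical.

```python
import math


def wilson_interval(successes: int, n: int, z: float = 1.96):
    """Wilson score interval for a binomial proportion (z = 1.96 for ~95%)."""
    if n == 0:
        return (0.0, 0.0)
    p = successes / n
    denom = 1 + z * z / n
    center = (p + z * z / (2 * n)) / denom
    margin = z * math.sqrt((p * (1 - p) + z * z / (4 * n)) / n) / denom
    return max(0.0, center - margin), min(1.0, center + margin)


for successes, n in [(0, 1), (5, 5), (10, 20)]:
    lo, hi = wilson_interval(successes, n)
    print(f"{successes}/{n}: [{lo:.3f}, {hi:.3f}]")
# 0/1: [0.000, 0.793]   5/5: [0.566, 1.000]   10/20: [0.299, 0.701]
```

With a single variant, zero successes is still compatible with a true rate near 0.79, which is why the run cell warns when there are fewer than 10 variants.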

### Recommended scale-up path

- Expand to many generated incident logs for tighter error bars.
- Implement real KV-cache pruning (removing entries from `past_key_values`) for production benchmarks.
- Profile runtime and cost alongside the semantic metrics.
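One detail to watch when moving from mask ablation to literal `past_key_values` pruning is position bookkeeping. Assuming keys are cached after RoPE has been applied at their original positions, retained keys keep the right rotations, but the position ID of each newly generated token must still continue the original numbering rather than be derived from the shorter, compacted cache length. The sketch uses plain lists in place of tensors; `next_position_ids` is a hypothetical helper, not a Transformers API.

```python
from typing import List, Set, Tuple


def next_position_ids(seq_len: int, evicted: Set[int], new_tokens: int) -> Tuple[List[int], List[int], List[int]]:
    """Return (kept cache slots, correct new-token positions, naive new-token positions).

    Assumption: cached keys already carry RoPE rotations for their original
    positions, so only the positions of *new* tokens are at stake here.
    """
    kept = [i for i in range(seq_len) if i not in evicted]
    correct = list(range(seq_len, seq_len + new_tokens))    # continue original numbering
    naive = list(range(len(kept), len(kept) + new_tokens))  # derived from compacted length
    return kept, correct, naive


kept, correct, naive = next_position_ids(seq_len=6, evicted={1, 2}, new_tokens=2)
print(kept)     # [0, 3, 4, 5]
print(correct)  # [6, 7]
print(naive)    # [4, 5] -> new tokens would look 2 positions closer to every cached key
```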
