# MirageBench KV Cache Eviction (Colab)

This notebook measures how evicting the middle of the KV cache (keeping an anchor prefix plus the most recent tokens) affects MirageBench outputs for `meta-llama/Llama-3.1-8B-Instruct`.

Experiment design:
- Load MirageBench runtime via `_load_notebook_runtime` + `_patch_runtime_with_methodology_fixes` + `_validate_investment_ground_truth`.
- Build the standard 12-task set with `build_miragebench_v01`.
- Compute a full-cache baseline answer/pivot for each task.
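- Re-run each task with KV retentions `[1.0, 0.7, 0.5, 0.3, 0.1]` (`1.0` evicts nothing) by keeping, per layer, the first 256 cache tokens as an anchor plus the rightmost cache tokens, and evicting the middle.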
- Save per-retention checkpoint CSVs to Google Drive immediately and one final merged CSV.


In [None]:
# 1) Install dependencies
# Only `transformers` is version-pinned here; the other packages install at their
# latest release. NOTE(review): the pin is presumably there because the KV-cache
# helpers below rely on that release's cache container API -- confirm before bumping.
!pip -q install transformers==4.46.3 accelerate sentence-transformers scikit-learn pandas tqdm
print('Dependencies installed.')


In [None]:
# 2) Mount Google Drive + prepare repo path
from google.colab import drive
import subprocess
import sys
from pathlib import Path

# Mount Drive first so later cells can write checkpoints under /content/drive.
drive.mount('/content/drive')

# Reuse an existing checkout; only clone when the directory is absent.
REPO_DIR = Path('/content/mirage')
if REPO_DIR.exists():
    print(f'Repo already present at {REPO_DIR}')
else:
    print('Cloning mirage repo...')
    clone_cmd = ['git', 'clone', 'https://github.com/jack-chaudier/mirage.git', str(REPO_DIR)]
    # check=True turns a failed clone into an exception instead of a silent miss.
    subprocess.run(clone_cmd, check=True)

# Put both the repo root and its package directory at the front of sys.path.
# Each entry is inserted at index 0, so the later one ends up first.
for import_root in (REPO_DIR, REPO_DIR / 'endogenous_context_theory'):
    if str(import_root) not in sys.path:
        sys.path.insert(0, str(import_root))

print('Python paths configured.')


In [None]:
# 3) Imports + reproducibility
import gc
import random
import re
from pathlib import Path
from typing import Any, Dict, List

import numpy as np
import pandas as pd
import torch
from tqdm.auto import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed

from endogenous_context_theory.scripts.run_miragebench_ollama import (
    _load_notebook_runtime,
    _patch_runtime_with_methodology_fixes,
    _validate_investment_ground_truth,
)

SEED = 42

# Seed every RNG the notebook touches, in a fixed order.
for seed_rng in (random.seed, np.random.seed, set_seed, torch.manual_seed):
    seed_rng(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

# Ask cuDNN for deterministic kernels and disable its autotuner.
if hasattr(torch.backends, 'cudnn'):
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

# Environment report.
print(f'Torch: {torch.__version__}')
print('CUDA available:', torch.cuda.is_available())
if torch.cuda.is_available():
    print('GPU:', torch.cuda.get_device_name(0))


In [None]:
# 4) Load MirageBench runtime exactly as used in blackbox pipeline
ROOT = REPO_DIR / 'endogenous_context_theory'
NB_PATH = ROOT / 'notebooks' / 'legacy' / 'miragebench_experiments_colab.ipynb'

# Load the legacy notebook's runtime, then patch it before reading anything out of it.
# NOTE(review): what the patch changes is not visible here -- see the imported helper.
runtime = _load_notebook_runtime(NB_PATH)
_patch_runtime_with_methodology_fixes(runtime)

# Pull out the four runtime callables used by the rest of this notebook.
build_miragebench_v01, make_prompt, raw_validity_score, semantic_regret = (
    runtime[name]
    for name in (
        'build_miragebench_v01',
        'make_prompt',
        'raw_validity_score',
        'semantic_regret',
    )
)

# Validate the complete task list, then keep the first 12 tasks.
all_tasks = build_miragebench_v01()
_validate_investment_ground_truth(all_tasks)
tasks = all_tasks[:12]

print(f'Loaded {len(tasks)} MirageBench tasks.')
print('Task IDs:', [t.task_id for t in tasks])


In [None]:
# 5) Model + output settings
# Model under test and sweep parameters.
MODEL_ID = 'meta-llama/Llama-3.1-8B-Instruct'
RETENTIONS = [1.0, 0.7, 0.5, 0.3, 0.1]  # fraction of cache positions kept per run
MAX_NEW_TOKENS = 220  # decode budget per answer
SKIP_IF_RETENTION_CSV_EXISTS = True  # reuse a Drive checkpoint instead of recomputing

# Results are written to both locations; create them up front.
DRIVE_OUT_DIR = Path('/content/drive/MyDrive/miragebench_kv_cache_eviction_mirage')
LOCAL_OUT_DIR = Path('/content/miragebench_kv_cache_eviction_mirage')
for out_dir in (DRIVE_OUT_DIR, LOCAL_OUT_DIR):
    out_dir.mkdir(parents=True, exist_ok=True)

FINAL_CSV_NAME = 'kv_cache_eviction_mirage_results.csv'
SUMMARY_CSV_NAME = 'kv_cache_eviction_mirage_summary_by_retention.csv'

for label, setting in (
    ('Model:', MODEL_ID),
    ('Retentions:', RETENTIONS),
    ('Drive output:', DRIVE_OUT_DIR),
    ('Local output:', LOCAL_OUT_DIR),
):
    print(label, setting)


In [None]:
# 6) Helpers: pivot extraction, feasibility, cache truncation, greedy decode
PIVOT_PRIMARY_RE = re.compile(r'PIVOT_ID\s*=\s*([A-Z]\d{1,4}-E\d{3})')
PIVOT_FALLBACK_RE = re.compile(r'([A-Z]\d{1,4}-E\d{3})')

def extract_pivot_id(text: str, fallback_candidates: List[str] | None = None) -> str:
    if not text:
        return ''
    m = PIVOT_PRIMARY_RE.search(text)
    if m:
        return m.group(1)
    markers = PIVOT_FALLBACK_RE.findall(text)
    if markers and fallback_candidates:
        for c in fallback_candidates:
            if c in markers:
                return c
    return markers[0] if markers else ''

def compute_fixed_pivot_feasible(task: Any, full_pivot: str, context_text: str) -> bool:
    """Report whether ``full_pivot`` and all its required markers occur in ``context_text``.

    Required markers are read from ``task.metadata['candidate_requirements'][full_pivot]``;
    a non-dict ``metadata`` or a missing entry means there are no extra requirements.
    An empty ``full_pivot`` is never feasible.
    """
    if not full_pivot:
        return False

    if isinstance(task.metadata, dict):
        requirements_by_pivot = task.metadata.get('candidate_requirements', {})
    else:
        requirements_by_pivot = {}
    required_markers = requirements_by_pivot.get(full_pivot, [])

    # The pivot itself must be present before its prerequisites are checked.
    if full_pivot not in context_text:
        return False
    for marker in required_markers:
        if marker not in context_text:
            return False
    return True

def format_chat_prompt(tokenizer, prompt: str) -> str:
    """Wrap ``prompt`` as a single user turn using the tokenizer's chat template.

    Returns the rendered template text (not token ids) with the generation
    prompt appended. Tokenizers without ``apply_chat_template`` or with an
    empty ``chat_template`` get the prompt back unchanged.
    """
    if not (hasattr(tokenizer, 'apply_chat_template') and tokenizer.chat_template):
        return prompt
    conversation = [{'role': 'user', 'content': prompt}]
    return tokenizer.apply_chat_template(
        conversation,
        tokenize=False,
        add_generation_prompt=True,
    )

def to_legacy_past(past_key_values):
    """Return the cache in legacy form when the object offers ``to_legacy_cache()``.

    Anything without that method is passed through untouched.
    """
    return (
        past_key_values.to_legacy_cache()
        if hasattr(past_key_values, 'to_legacy_cache')
        else past_key_values
    )

def get_model_device(model) -> torch.device:
    """Return ``model.device`` if present, else the device of the first parameter."""
    if not hasattr(model, 'device'):
        first_param = next(model.parameters())
        return first_param.device
    return model.device

def past_seq_len(past_key_values) -> int:
    """Return the cached sequence length, read from the first layer's first tensor.

    The length is taken from the second-to-last axis of ``past_key_values[0][0]``.
    """
    first_layer_key = past_key_values[0][0]
    seq_axis_size = first_layer_key.shape[-2]
    return int(seq_axis_size)

def load_bf16_model(model_id: str):
    """Load a causal LM in bfloat16 together with its tokenizer.

    Returns ``(model, tokenizer)`` with the model in eval mode. When the
    tokenizer has no pad token, the EOS token is reused for padding.

    NOTE(review): ``trust_remote_code=True`` lets the hub repo run its own
    Python code during loading -- confirm it is needed for the models used here.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.pad_token_id = tokenizer.eos_token_id

    load_kwargs = {
        'device_map': 'auto',
        'torch_dtype': torch.bfloat16,
        'trust_remote_code': True,
    }
    model = AutoModelForCausalLM.from_pretrained(model_id, **load_kwargs)
    model.eval()
    return model, tokenizer

def build_full_prompt_cache(model, tokenizer, prompt: str) -> Dict[str, Any]:
    """Prefill the model on a prompt and return what is needed to resume decoding.

    The KV cache covers every prompt token EXCEPT the last one; that last token
    is returned as ``last_token`` so the decode loop can feed it as its first
    input. Prefilling all tokens and then feeding the last token again would
    show that token to the model twice (once from the cache, once as input).

    Args:
        model: Causal LM called as ``model(input_ids=..., attention_mask=..., use_cache=True)``.
        tokenizer: Tokenizer used to template and encode ``prompt``.
        prompt: Raw user prompt text.

    Returns:
        Dict with:
        - ``input_ids``: all prompt token ids, on the model device.
        - ``attention_mask``: mask for the full prompt, on the model device.
        - ``last_token``: the final prompt token id, kept 2-D via ``[:, -1:]``.
        - ``past_key_values``: legacy-format cache for all tokens but the last.

    Raises:
        ValueError: If the prompt tokenizes to fewer than 2 tokens, so it cannot
            be split into a cached prefix plus a continuation token.
    """
    input_text = format_chat_prompt(tokenizer, prompt)

    # A chat template may already emit the BOS token as text. Tokenizing that
    # text with the default special-token insertion can add a second BOS, so
    # special tokens are only added when the text does not start with one.
    bos_token = getattr(tokenizer, 'bos_token', None)
    already_has_bos = bool(bos_token) and input_text.startswith(bos_token)
    encoded = tokenizer(
        input_text,
        return_tensors='pt',
        add_special_tokens=not already_has_bos,
    )

    device = get_model_device(model)
    input_ids = encoded['input_ids'].to(device)
    attention_mask = encoded.get('attention_mask', torch.ones_like(input_ids)).to(device)
    prompt_len = input_ids.shape[1]
    if prompt_len < 1:
        raise ValueError('Prompt tokenization is empty.')
    if prompt_len < 2:
        raise ValueError(
            'Prompt must tokenize to at least 2 tokens to split it into a '
            'cached prefix and a continuation token.'
        )

    # Prefill on the prefix only; the last token is left for the decode loop.
    with torch.no_grad():
        out = model(
            input_ids=input_ids[:, :-1],
            attention_mask=attention_mask[:, :-1],
            use_cache=True,
        )

    return {
        'input_ids': input_ids,
        'attention_mask': attention_mask,
        'last_token': input_ids[:, -1:],
        'past_key_values': to_legacy_past(out.past_key_values),
    }

def truncate_past_middle(past_key_values, retention: float, anchor_len: int = 256):
    """Evict the middle of a legacy KV cache, keeping an anchor prefix and the tail.

    For every layer, the first ``anchor_len`` positions and the most recent
    positions are kept so that ``keep_total`` positions remain; everything in
    between is dropped. The sequence axis is the second-to-last axis of each
    key/value tensor.

    A cache too short to hold ``anchor_len + 1`` positions is returned whole
    (``keep_total == full_len``) instead of failing on an out-of-range slice.

    Args:
        past_key_values: Sequence of per-layer entries ``(key, value, *extras)``.
            Extras beyond key/value are carried over unchanged.
        retention: Target fraction of positions to keep, in ``(0, 1]``.
        anchor_len: Number of leading positions that are always kept.

    Returns:
        ``(new_past, keep_total, full_len)``: the truncated cache as a tuple of
        per-layer tuples, the number of positions kept, and the original length.

    Raises:
        ValueError: If ``retention`` is outside ``(0, 1]`` or ``anchor_len`` is negative.
    """
    if retention <= 0 or retention > 1:
        raise ValueError(f'retention must be in (0,1], got {retention}')
    if anchor_len < 0:
        raise ValueError(f'anchor_len must be non-negative, got {anchor_len}')

    full_len = int(past_key_values[0][0].shape[-2])
    # At least one tail position beyond the anchor is kept, but never more
    # positions than the cache actually holds.
    keep_total = min(full_len, max(anchor_len + 1, int(retention * full_len)))
    head_len = min(anchor_len, keep_total)
    tail_len = keep_total - head_len

    def _keep_ends(tensor):
        # Slicing (unlike a fixed-size narrow) tolerates a zero-length tail.
        head = tensor[..., :head_len, :]
        tail = tensor[..., full_len - tail_len:, :]
        return torch.cat([head, tail], dim=-2).contiguous()

    new_past = tuple(
        (_keep_ends(layer[0]), _keep_ends(layer[1]), *layer[2:])
        for layer in past_key_values
    )
    return new_past, keep_total, full_len


def greedy_decode_from_cache(
    model,
    tokenizer,
    past_key_values,
    continuation_input_ids,
    max_new_tokens: int = MAX_NEW_TOKENS,
) -> str:
    """Greedily decode from an existing KV cache, one token per forward pass.

    Each step feeds the current input together with the cache, takes the
    argmax of the last position's logits, and uses that token as the next
    input. Decoding stops after ``max_new_tokens`` steps or right after the
    tokenizer's EOS id is produced.

    Args:
        model: Causal LM accepting ``past_key_values`` and returning logits plus an updated cache.
        tokenizer: Supplies ``eos_token_id`` and ``decode``.
        past_key_values: Cache to continue from (converted to legacy form if possible).
        continuation_input_ids: First input for the decode loop.
        max_new_tokens: Upper bound on generated tokens.

    Returns:
        The generated ids decoded with special tokens skipped and whitespace stripped.

    Note:
        ``.item()`` is called on the chosen token, so only a single sequence is supported.
        NOTE(review): no ``position_ids`` are passed; after cache eviction the model
        presumably derives new-token positions from the shortened cache length rather
        than the original prompt length -- confirm this is the intended setup.
    """
    cache = to_legacy_past(past_key_values)
    step_input = continuation_input_ids
    stop_id = tokenizer.eos_token_id
    emitted: List[int] = []

    with torch.no_grad():
        for _ in range(max_new_tokens):
            # Mask spans all cached positions plus the tokens fed this step.
            batch_size, step_len = step_input.shape[0], step_input.shape[1]
            mask = torch.ones(
                (batch_size, past_seq_len(cache) + step_len),
                dtype=torch.long,
                device=step_input.device,
            )
            step = model(
                input_ids=step_input,
                past_key_values=cache,
                attention_mask=mask,
                use_cache=True,
            )
            cache = to_legacy_past(step.past_key_values)

            # Greedy choice becomes both the output token and the next input.
            step_input = step.logits[:, -1, :].argmax(dim=-1, keepdim=True)
            chosen = int(step_input.item())
            emitted.append(chosen)
            if stop_id is not None and chosen == stop_id:
                break

    return tokenizer.decode(emitted, skip_special_tokens=True).strip()


In [None]:
# 7) Load model and compute full-cache baselines for all 12 tasks
model, tokenizer = load_bf16_model(MODEL_ID)

# task_id -> prompt text, full-cache answer and the pivot extracted from it.
baseline_by_task: Dict[str, Dict[str, Any]] = {}

for task in tqdm(tasks, desc='Full-cache baseline'):
    prompt_text = make_prompt(task.full_context, task.question)
    prefill = build_full_prompt_cache(model, tokenizer, prompt_text)
    answer = greedy_decode_from_cache(
        model=model,
        tokenizer=tokenizer,
        past_key_values=prefill['past_key_values'],
        continuation_input_ids=prefill['last_token'],
        max_new_tokens=MAX_NEW_TOKENS,
    )
    pivot_candidates = [task.pivot_ground_truth, task.decoy_pivot]
    baseline_by_task[task.task_id] = dict(
        full_prompt=prompt_text,
        full_answer=answer,
        full_pivot=extract_pivot_id(answer, pivot_candidates),
    )

    # Drop this task's cache before the next prefill to limit peak GPU memory.
    del prefill
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

print('Computed baselines:', len(baseline_by_task))


In [None]:
# 8) Retention sweep with per-retention checkpoint CSVs
# Rows from every retention level (freshly computed or loaded from a checkpoint).
all_rows: List[Dict[str, Any]] = []

for retention in RETENTIONS:
    # File-name-safe retention label, e.g. 0.7 -> '0p7'.
    retention_slug = str(retention).replace('.', 'p')
    local_ckpt = LOCAL_OUT_DIR / f'kv_cache_eviction_retention_{retention_slug}.csv'
    drive_ckpt = DRIVE_OUT_DIR / f'kv_cache_eviction_retention_{retention_slug}.csv'

    # Resume support: an existing Drive checkpoint replaces recomputation.
    # NOTE(review): the loaded rows are not checked against the current task
    # list, model, or decode settings -- stale checkpoints are reused as-is.
    if SKIP_IF_RETENTION_CSV_EXISTS and drive_ckpt.exists():
        print(f'Loading existing checkpoint for retention={retention}: {drive_ckpt}')
        cached_df = pd.read_csv(drive_ckpt)
        all_rows.extend(cached_df.to_dict(orient='records'))
        continue

    retention_rows: List[Dict[str, Any]] = []
    for task in tqdm(tasks, desc=f'Retention {retention:.1f}'):
        baseline = baseline_by_task[task.task_id]
        # The prompt cache is rebuilt per task and retention, then truncated.
        cache = build_full_prompt_cache(model, tokenizer, baseline['full_prompt'])
        truncated_past, kept_len, full_len = truncate_past_middle(cache['past_key_values'], retention=retention)
        truncated_answer = greedy_decode_from_cache(
            model=model,
            tokenizer=tokenizer,
            past_key_values=truncated_past,
            continuation_input_ids=cache['last_token'],
            max_new_tokens=MAX_NEW_TOKENS,
        )

        compressed_pivot = extract_pivot_id(
            truncated_answer,
            [task.pivot_ground_truth, task.decoy_pivot],
        )
        row = {
            'task_id': task.task_id,
            'category': task.category,
            'retention': float(retention),
            'full_pivot': baseline['full_pivot'],
            'compressed_pivot': compressed_pivot,
            # 1 only when the answer carries a complete 'PIVOT_ID = <id>' header.
            'has_pivot_header': int(bool(PIVOT_PRIMARY_RE.search(truncated_answer))),
            # 1 only when both pivots are non-empty and equal.
            'pivot_preserved': int(bool(baseline['full_pivot'] and compressed_pivot and baseline['full_pivot'] == compressed_pivot)),
            # NOTE(review): evaluated against task.full_context, so this value does
            # not depend on the retention level -- confirm that is intended.
            'fixed_pivot_feasible': int(compute_fixed_pivot_feasible(task, baseline['full_pivot'], task.full_context)),
            'raw_validity': float(raw_validity_score(truncated_answer, task)),
            'semantic_regret': float(semantic_regret(baseline['full_answer'], truncated_answer)),
            'full_answer': baseline['full_answer'],
            'truncated_answer': truncated_answer,
            'cache_tokens_full': int(full_len),
            'cache_tokens_kept': int(kept_len),
        }
        retention_rows.append(row)

        # Release both caches before the next task to limit peak GPU memory.
        del cache
        del truncated_past
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    # Checkpoint this retention level (local first, then Drive) once all tasks finish.
    retention_df = pd.DataFrame(retention_rows)
    retention_df.to_csv(local_ckpt, index=False)
    retention_df.to_csv(drive_ckpt, index=False)
    print(f'Saved retention checkpoint for {retention:.1f}:')
    print('  Local:', local_ckpt)
    print('  Drive:', drive_ckpt)

    all_rows.extend(retention_rows)

# Merge all retention levels into one table ordered by retention, then task.
results_df = pd.DataFrame(all_rows).sort_values(['retention', 'task_id']).reset_index(drop=True)

final_local = LOCAL_OUT_DIR / FINAL_CSV_NAME
final_drive = DRIVE_OUT_DIR / FINAL_CSV_NAME
results_df.to_csv(final_local, index=False)
results_df.to_csv(final_drive, index=False)

print('\nSaved final merged CSV:')
print('  Local:', final_local)
print('  Drive:', final_drive)
print('Rows:', len(results_df))

# Last expression of the cell: shown as the cell output in the notebook.
results_df.head()


In [None]:
# 9) Summary by retention
# Backfill the header-presence metric for legacy checkpoints that predate this column.
# The backfill applies PIVOT_PRIMARY_RE -- the same test the retention sweep uses for
# fresh rows -- so every row in the summary shares one definition of "has a pivot
# header" (a complete 'PIVOT_ID = <id>' match, not just the 'PIVOT_ID =' prefix).
answer_text = results_df['truncated_answer'].fillna('').astype(str)
header_from_text = answer_text.map(
    lambda text: int(bool(PIVOT_PRIMARY_RE.search(text)))
).astype(int)
if 'has_pivot_header' in results_df.columns:
    # Keep stored values; only rows where the column is missing (NaN) are filled.
    results_df['has_pivot_header'] = results_df['has_pivot_header'].fillna(header_from_text).astype(int)
else:
    results_df['has_pivot_header'] = header_from_text

# Mean of each metric per retention level, highest retention first.
summary = (
    results_df.groupby('retention', as_index=False)
    .agg(
        has_pivot_header=('has_pivot_header', 'mean'),
        pivot_preserved=('pivot_preserved', 'mean'),
        fixed_pivot_feasible=('fixed_pivot_feasible', 'mean'),
        raw_validity=('raw_validity', 'mean'),
        semantic_regret=('semantic_regret', 'mean'),
    )
    .sort_values('retention', ascending=False)
)

summary_local = LOCAL_OUT_DIR / SUMMARY_CSV_NAME
summary_drive = DRIVE_OUT_DIR / SUMMARY_CSV_NAME
summary.to_csv(summary_local, index=False)
summary.to_csv(summary_drive, index=False)

print('Saved summary CSV:')
print('  Local:', summary_local)
print('  Drive:', summary_drive)

# Last expression of the cell: shown as the cell output in the notebook.
summary


## KV API Compatibility Notes

- This notebook converts cache objects via `to_legacy_cache()` when available so truncation can operate on per-layer `(key, value)` tensors.
- If a future `transformers` release changes cache container types/shapes, update `to_legacy_past` and `truncate_past_middle` accordingly.
- The decode path intentionally uses minimal continuation input (`last prompt token`) with an attention mask sized as `kept_cache_len + current_input_len` on each step.
