# Mirage-Aware Fine-Tuning: Qwen 2.5 7B LoRA (A100)

This notebook fine-tunes **Qwen/Qwen2.5-7B-Instruct** with QLoRA (a 4-bit quantized base model plus LoRA adapters)
so that it detects and flags evidence degraded by context compression, instead of silently switching hypotheses.

**Requirements:** Colab with A100 GPU.

**Runtime:** ~45-60 minutes for training + full validation eval.

**Inputs:** `train.jsonl` and `valid.jsonl` (upload when prompted).

**Outputs:** LoRA adapter + eval CSV + summary JSON + raw response backups + training-loss plot.


## 1. Setup

In [None]:
!pip install -q transformers==4.46.3 peft==0.13.2 trl==0.12.2 accelerate==1.1.1 bitsandbytes>=0.46.1 datasets matplotlib
print("Dependencies installed.")


In [None]:
# Verify GPU
import torch
# Stop immediately when no CUDA device is visible to torch.
if not torch.cuda.is_available():
    raise AssertionError("No GPU detected! Go to Runtime > Change runtime type > GPU")

gpu_name = torch.cuda.get_device_name(0)
# total_memory is divided by 1e9 to display decimal gigabytes.
total_memory = torch.cuda.get_device_properties(0).total_memory
gpu_mem = total_memory / 1e9
print(f"GPU: {gpu_name} ({gpu_mem:.1f} GB)")


## 2. Upload Training Data

Upload your `train.jsonl` and `valid.jsonl` files when the file picker appears.


In [None]:
import os
import json
from google.colab import files

# Directory where the uploaded JSONL files are stored for the rest of the run.
DATA_DIR = "/content/training_data"
os.makedirs(DATA_DIR, exist_ok=True)

train_path = os.path.join(DATA_DIR, "train.jsonl")
valid_path = os.path.join(DATA_DIR, "valid.jsonl")

if not os.path.exists(train_path) or not os.path.exists(valid_path):
    print("Upload train.jsonl and valid.jsonl:")
    # Each uploaded item is (filename, content); content is written in binary mode.
    uploaded = files.upload()
    for fname, content in uploaded.items():
        dest = os.path.join(DATA_DIR, fname)
        with open(dest, "wb") as f:
            f.write(content)
        print(f"  Saved {fname} -> {dest}")

    # The picker accepts any file name, so confirm that both expected files
    # actually landed instead of failing later with a less obvious error.
    missing = [p for p in (train_path, valid_path) if not os.path.exists(p)]
    if missing:
        raise FileNotFoundError(
            "Expected file(s) not found after upload: " + ", ".join(missing)
        )
else:
    print(f"Data already present at {DATA_DIR}")


def count_and_check(path):
    # Validate a chat-format JSONL file and tally its evidence labels.
    #
    # Every non-blank line must be a JSON object with a "messages" list of at
    # least two entries. The label is read from the second message's content;
    # a record containing both "STRONG" and "DEGRADED" counts toward both.
    #
    # Returns (count, strong, degraded).
    # Raises AssertionError for a malformed record and json.JSONDecodeError
    # for a non-blank line that is not valid JSON.
    count = 0
    strong = 0
    degraded = 0
    # The context manager closes the handle, which the bare open() never did.
    with open(path, encoding="utf-8") as f:
        for line in f:
            # Tolerate blank lines (typically a trailing newline at end of file).
            if not line.strip():
                continue
            obj = json.loads(line)
            assert "messages" in obj, f"Missing 'messages' key in {path}"
            assert len(obj["messages"]) >= 2, f"Need at least 2 messages in {path}"
            content = obj["messages"][1]["content"]
            if "STRONG" in content:
                strong += 1
            if "DEGRADED" in content:
                degraded += 1
            count += 1
    return count, strong, degraded


# Report size and label balance for each split.
for name, p in (("Train", train_path), ("Valid", valid_path)):
    n, s, d = count_and_check(p)
    if n > 0:
        pct = d / n * 100
    else:
        pct = 0
    print(f"{name}: {n} examples ({s} STRONG, {d} DEGRADED, {pct:.0f}% degraded)")

print("\nData verified.")


## 3. Load Model

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

MODEL_ID = "Qwen/Qwen2.5-7B-Instruct"

# 4-bit NF4 quantization with double quantization and a bfloat16 compute dtype.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

print(f"Loading {MODEL_ID} in 4-bit...")
model_load_kwargs = {
    "quantization_config": bnb_config,
    "device_map": "auto",
    "torch_dtype": torch.bfloat16,
}
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, **model_load_kwargs)

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
if tokenizer.pad_token is None:
    # No pad token defined: reuse EOS (both the token and its id).
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.pad_token_id = tokenizer.eos_token_id

allocated_gb = torch.cuda.memory_allocated() / 1e9
print(f"Model loaded. Memory: {allocated_gb:.2f} GB")


## 4. Configure LoRA

In [None]:
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

# Prepare the quantized model for k-bit training before attaching adapters.
model = prepare_model_for_kbit_training(model)

# Adapters target the modules named q_proj / k_proj / v_proj / o_proj.
adapter_targets = ["q_proj", "k_proj", "v_proj", "o_proj"]
lora_config = LoraConfig(
    task_type="CAUSAL_LM",
    target_modules=adapter_targets,
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    bias="none",
)

model = get_peft_model(model, lora_config)
trainable, total = model.get_nb_trainable_parameters()
trainable_pct = 100 * trainable / total
print(f"Trainable: {trainable:,} / {total:,} ({trainable_pct:.2f}%)")


## 5. Load Dataset

In [None]:
from datasets import load_dataset

# Load both JSONL files as named splits.
split_files = {"train": train_path, "validation": valid_path}
dataset = load_dataset("json", data_files=split_files)

print(f"Train: {len(dataset['train'])} examples")
print(f"Valid: {len(dataset['validation'])} examples")


def _extra_columns(split):
    # Names of every column in `split` other than "messages".
    return [col for col in split.column_names if col != "messages"]


# Keep only messages for SFT training.
train_cols = _extra_columns(dataset["train"])
valid_cols = _extra_columns(dataset["validation"])

train_data = dataset["train"].remove_columns(train_cols)
valid_data = dataset["validation"].remove_columns(valid_cols)
print(f"Stripped extra columns: {train_cols}")


def format_chat(example):
    # Render one example's "messages" through the tokenizer's chat template
    # (as a string, not token ids) and return it under the "text" key.
    rendered = tokenizer.apply_chat_template(example["messages"], tokenize=False)
    return {"text": rendered}


# Replace the "messages" column with the rendered "text" column in both splits.
map_kwargs = {"remove_columns": ["messages"]}
train_data = train_data.map(format_chat, **map_kwargs)
valid_data = valid_data.map(format_chat, **map_kwargs)

for split_label, split in (("Train", train_data), ("Valid", valid_data)):
    print(f"{split_label} columns: {split.column_names}")
print("\nSample formatted text (first 500 chars):")
preview = train_data[0]["text"][:500]
print(preview + "...")


In [None]:
# Sequence length diagnostic
MAX_SEQ = 2560  # must match training config

response_marker = "<|im_start|>assistant\n"
lengths = []
assistant_visible = {"full": 0, "partial": 0, "none": 0}


def _token_count(text):
    # Number of token ids produced for `text` with truncation disabled.
    return len(tokenizer(text, truncation=False)["input_ids"])


for ex in train_data:
    text = ex["text"]
    total_len = _token_count(text)
    lengths.append(total_len)

    marker_pos = text.find(response_marker)
    if marker_pos < 0:
        # No assistant marker found: the example is left out of the visibility tally.
        continue

    # Token length of everything up to and including the first assistant marker.
    prefix_tokens = _token_count(text[: marker_pos + len(response_marker)])
    if prefix_tokens >= MAX_SEQ:
        bucket = "none"
    elif total_len <= MAX_SEQ:
        bucket = "full"
    else:
        bucket = "partial"
    assistant_visible[bucket] += 1

n = len(lengths)
print(f"Sequence length stats (n={n}):")
print(f"  Mean: {sum(lengths)/n:.0f}, Max: {max(lengths)}, Min: {min(lengths)}")
over = len([length for length in lengths if length > MAX_SEQ])
print(f"  Over {MAX_SEQ}: {over} ({over/n:.1%})")

print(f"\nAssistant token visibility at max_seq_length={MAX_SEQ}:")
for bucket_name, bucket_count in assistant_visible.items():
    print(f"  {bucket_name}: {bucket_count} ({bucket_count/n:.1%})")

none_count = assistant_visible["none"]
if none_count > n * 0.1:
    print(f"\nWARNING: {none_count/n:.0%} of examples have NO assistant tokens at max_seq_length={MAX_SEQ}!")
    print("Consider increasing max_seq_length or shortening context in data generation.")


In [None]:
# Completion-only loss masking (train only on assistant response)
from trl import DataCollatorForCompletionOnlyLM

response_template = "<|im_start|>assistant\n"
completion_collator = DataCollatorForCompletionOnlyLM(
    tokenizer=tokenizer,
    response_template=response_template,
)

# Verify masking on a small sample batch (at most two training examples).
sample_count = min(2, len(train_data))
sample_features = []
for sample_idx in range(sample_count):
    sample_text = train_data[sample_idx]["text"]
    sample_features.append(tokenizer(sample_text, truncation=True, max_length=2560))

sample_batch = completion_collator(sample_features)
labels = sample_batch["labels"]
# Label positions equal to -100 are the ones counted as masked below.
ignored = labels == -100
masked = int(ignored.sum().item())
active = int((~ignored).sum().item())
print(f"Collator check: masked={masked}, active={active}")
assert masked > 0, "Expected masked prompt tokens (-100)."
assert active > 0, "Expected some assistant tokens to contribute to loss."
print("Completion-only collator ready.")


## 6. Train

In [None]:
from trl import SFTTrainer, SFTConfig

# Directory passed to the trainer as output_dir and used for the final
# adapter and tokenizer save below.
OUTPUT_DIR = "/content/mirage_aware_adapter"

# Training hyperparameters. With a per-device batch of 2 and 4 accumulation
# steps, each optimizer step covers 8 sequences per device.
training_args = SFTConfig(
    output_dir=OUTPUT_DIR,
    num_train_epochs=3,
    per_device_train_batch_size=2,
    gradient_accumulation_steps=4,
    learning_rate=1e-4,
    lr_scheduler_type="cosine",
    warmup_ratio=0.05,
    max_seq_length=2560,  # same value as MAX_SEQ in the sequence-length diagnostic
    fp16=False,
    bf16=True,
    logging_steps=10,
    save_steps=100,
    save_total_limit=3,
    eval_strategy="steps",
    eval_steps=250,
    per_device_eval_batch_size=1,
    seed=42,
    report_to="none",
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False},
    packing=False,
)

# The completion-only collator is passed in so that loss masking follows the
# assistant response template configured earlier in the notebook.
trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=train_data,
    eval_dataset=valid_data,
    processing_class=tokenizer,
    data_collator=completion_collator,
)

print(f"Starting training ({len(dataset['train'])} examples, {training_args.num_train_epochs} epochs)...")
# Floor-division estimate of optimizer steps: examples // (batch * accumulation)
# per epoch, times epochs. NOTE(review): the trainer's own step count may
# differ slightly from this estimate.
print(
    "Estimated steps:",
    len(dataset["train"]) // (training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps) * training_args.num_train_epochs,
)
print()

torch.cuda.empty_cache()
trainer.train()

# Persist the trained model and the tokenizer side by side in OUTPUT_DIR; the
# evaluation cells load both from this directory.
trainer.save_model(OUTPUT_DIR)
tokenizer.save_pretrained(OUTPUT_DIR)
print(f"\nAdapter saved to {OUTPUT_DIR}")


In [None]:
# Plot training loss
import matplotlib.pyplot as plt

logs = trainer.state.log_history

# Split the log history in one pass: records carrying "eval_loss" are
# evaluation points; the remaining records that carry "loss" are training points.
train_steps, train_loss = [], []
eval_steps, eval_loss = [], []
for entry in logs:
    if "eval_loss" in entry:
        eval_steps.append(entry["step"])
        eval_loss.append(entry["eval_loss"])
    elif "loss" in entry:
        train_steps.append(entry["step"])
        train_loss.append(entry["loss"])

plt.figure(figsize=(10, 4))
plt.plot(train_steps, train_loss, label="Train", alpha=0.7)
if eval_steps:
    plt.plot(eval_steps, eval_loss, "ro-", label="Eval", markersize=6)
plt.xlabel("Step")
plt.ylabel("Loss")
plt.title("Mirage-Aware Training Loss")
plt.legend()
plt.grid(True, alpha=0.3)
plt.tight_layout()
# Save before show() so the PNG is written for the download package.
plt.savefig("/content/training_loss.png", dpi=150)
plt.show()

for series_label, series in (("train", train_loss), ("eval", eval_loss)):
    if series:
        print(f"Final {series_label} loss: {series[-1]:.4f}")


## 7. Evaluate: Base vs Fine-Tuned

In [None]:
import gc
import re
import json

# Free training model memory before eval models are loaded.
for var_name in ("trainer", "model"):
    # pop() with a default removes the global only if it exists.
    globals().pop(var_name, None)

gc.collect()
torch.cuda.empty_cache()
print("Training resources freed.")


In [None]:
from peft import PeftModel


def load_base_model():
    # Load MODEL_ID with 4-bit NF4 quantization (double quantization, bfloat16
    # compute dtype) and return the model.
    quantization = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )
    return AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        quantization_config=quantization,
    )


def load_finetuned_model(adapter_path):
    # Load a fresh 4-bit base model and wrap it with the LoRA adapter stored
    # at `adapter_path`.
    quantized_base = load_base_model()
    adapted = PeftModel.from_pretrained(quantized_base, adapter_path)
    return adapted


def generate_response(model, tokenizer, prompt, max_new_tokens=512):
    # Generate a reply to `prompt`, which is sent as a single user turn.
    # Sampling is disabled (do_sample=False). Only the tokens produced after
    # the prompt are decoded, with special tokens stripped.
    chat = [{"role": "user", "content": prompt}]
    rendered = tokenizer.apply_chat_template(
        chat, tokenize=False, add_generation_prompt=True
    )
    encoded = tokenizer(rendered, return_tensors="pt").to(model.device)
    prompt_len = encoded["input_ids"].shape[1]

    generation_kwargs = {
        "max_new_tokens": max_new_tokens,
        "do_sample": False,
        "pad_token_id": tokenizer.pad_token_id,
    }
    with torch.no_grad():
        generated = model.generate(**encoded, **generation_kwargs)

    # The output sequence starts with the prompt; slice it off.
    completion_ids = generated[0][prompt_len:]
    return tokenizer.decode(completion_ids, skip_special_tokens=True)


def extract_pivot(text):
    # Return the pivot ID found in `text`, or "UNKNOWN" when there is none.
    # An explicit "PIVOT_ID = <id>" declaration takes precedence; otherwise the
    # first pivot-shaped token anywhere in the text is used. An ID is one
    # uppercase letter, 1-4 digits, "-E", then three digits (e.g. "A12-E003").
    patterns = (
        r"PIVOT_ID\s*=\s*([A-Z]\d{1,4}-E\d{3})",
        r"([A-Z]\d{1,4}-E\d{3})",
    )
    for pattern in patterns:
        found = re.search(pattern, text)
        if found is not None:
            return found.group(1)
    return "UNKNOWN"


def check_degraded_flag(text):
    # True when the response signals degraded evidence: either the structured
    # "Evidence assessment: DEGRADED" line, or any of the looser keywords
    # below found as a case-insensitive substring.
    if re.search(r"Evidence assessment:\s*DEGRADED", text, re.IGNORECASE) is not None:
        return True

    keywords = (
        "degraded",
        "incomplete",
        "missing prerequisite",
        "compressed evidence",
        "not found in context",
        "confidence: low",
        "evidence_status=degraded",
    )
    lowered = text.lower()
    for keyword in keywords:
        if keyword in lowered:
            return True
    return False


def check_format_adherence(text):
    # A response follows the output protocol only when it contains both a
    # "PIVOT_ID =" field (case-sensitive) and an
    # "Evidence assessment: STRONG|DEGRADED" line (case-insensitive).
    if re.search(r"PIVOT_ID\s*=", text) is None:
        return False
    evidence = re.search(
        r"Evidence assessment:\s*(STRONG|DEGRADED)", text, re.IGNORECASE
    )
    return evidence is not None


print("Helper functions defined.")


In [None]:
# Load validation data with oracle metadata
valid_examples = []
with open(valid_path, encoding="utf-8") as f:
    for line in f:
        # Skip blank lines (e.g. a trailing newline) instead of failing in json.loads.
        if line.strip():
            valid_examples.append(json.loads(line))

if not valid_examples:
    raise ValueError(f"No validation examples found in {valid_path}")

# Evaluate on at most EVAL_LIMIT examples. EVAL_N is the number actually
# evaluated, so the progress counters and the row-count checks that compare
# against EVAL_N stay correct when the file holds fewer examples than the cap.
EVAL_LIMIT = 400
EVAL_N = min(EVAL_LIMIT, len(valid_examples))
eval_subset = valid_examples[:EVAL_N]

print(f"Will evaluate on {EVAL_N} examples.")

# Every top-level key other than "messages" is treated as oracle metadata.
sample = eval_subset[0]
oracle_fields = [k for k in sample.keys() if k != "messages"]
print(f"Oracle metadata fields: {oracle_fields}")


In [None]:
# Optional: upload precomputed base results (from mirage_aware_base_eval.ipynb)
import os
import json
from google.colab import files


def _normalize_base_results(records):
    # Bring precomputed base-eval records to the schema used in this notebook.
    # Fields already present in a record are kept; missing ones are derived
    # from the record's "response" text (empty string when absent).
    normalized = []
    for record in records:
        response = record.get("response", "")
        strict_match = re.search(r"Evidence assessment:\s*DEGRADED", response, re.IGNORECASE)
        # Fallbacks are computed up front, in output-column order.
        derived = {
            "pivot": extract_pivot(response),
            "flagged_degraded": check_degraded_flag(response),
            "flagged_degraded_strict": bool(strict_match),
            "format_ok": check_format_adherence(response),
        }
        entry = {"response": response}
        for field, fallback in derived.items():
            entry[field] = record.get(field, fallback)
        normalized.append(entry)
    return normalized


# Prompt for an optional upload; every uploaded file is written to /content
# under its uploaded name.
print("Upload base_results_backup.json from your base-eval run (H100) to skip base eval here.")
uploaded = files.upload()
for fname, content in uploaded.items():
    dest = f"/content/{fname}"
    with open(dest, "wb") as f:
        f.write(content)
    print(f"  Saved {fname} -> {dest}")

# Precomputed results are used only if a file exists at exactly this path
# (whether it was uploaded just now or was already on disk).
precomp_path = "/content/base_results_backup.json"
if os.path.exists(precomp_path):
    with open(precomp_path) as f:
        base_results = _normalize_base_results(json.load(f))

    # Only the row count is compared against EVAL_N here.
    if len(base_results) != EVAL_N:
        raise ValueError(
            f"Loaded base_results has {len(base_results)} rows but EVAL_N={EVAL_N}. "
            "Use matching base-eval settings (seed=42, EVAL_N=400) or update EVAL_N."
        )

    # Ensure package artifacts exist even when we skip base inference.
    with open("/content/base_raw_responses.json", "w") as f:
        json.dump(
            [{"idx": i, "response": r["response"]} for i, r in enumerate(base_results)],
            f,
        )
    # Rewrite the backup in normalized form (same path that was read above).
    with open("/content/base_results_backup.json", "w") as f:
        json.dump(base_results, f)

    print(f"Loaded precomputed base results: {len(base_results)} rows")
    print("Next cell will auto-skip base model inference.")
else:
    print("No base_results_backup.json uploaded. Next cell will run base model eval.")


In [None]:
# Run BASE model eval
# Inference is skipped when a `base_results` list already exists as a global.
have_precomputed = "base_results" in globals() and isinstance(base_results, list)

if not have_precomputed:
    print("Loading BASE model...")
    base_model = load_base_model()
    base_tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    if base_tokenizer.pad_token is None:
        base_tokenizer.pad_token = base_tokenizer.eos_token

    base_results = []
    print(f"Generating {EVAL_N} base model responses...")
    for done, ex in enumerate(eval_subset, start=1):
        # The first message of each example is used as the prompt.
        prompt = ex["messages"][0]["content"]
        response = generate_response(base_model, base_tokenizer, prompt)
        record = {
            "response": response,
            "pivot": extract_pivot(response),
            "flagged_degraded": check_degraded_flag(response),
            "format_ok": check_format_adherence(response),
        }
        base_results.append(record)
        if done % 20 == 0 or done == EVAL_N:
            print(f"  [{done}/{EVAL_N}] done")

    # Save raw responses for re-parsing without rerunning inference.
    model_type = "base"
    results = base_results
    raw_dump = [{"idx": i, "response": r["response"]} for i, r in enumerate(results)]
    with open(f"/content/{model_type}_raw_responses.json", "w") as f:
        json.dump(raw_dump, f)
    print("Raw responses saved.")

    with open("/content/base_results_backup.json", "w") as f:
        json.dump(base_results, f)
    print(f"Saved {len(base_results)} base results")

    # Free base model
    del base_model
    gc.collect()
    torch.cuda.empty_cache()
    print("Base model eval complete.")
elif len(base_results) != EVAL_N:
    raise ValueError(
        f"Loaded precomputed base_results has {len(base_results)} rows but EVAL_N={EVAL_N}. "
        "Use a matching eval slice or update EVAL_N."
    )
else:
    print(f"Using precomputed base results ({len(base_results)} rows). Skipping base model inference.")


In [None]:
# Run FINE-TUNED model eval
OUTPUT_DIR = "/content/mirage_aware_adapter"
print("Loading FINE-TUNED model...")
ft_model = load_finetuned_model(OUTPUT_DIR)
# The tokenizer is loaded from the adapter directory.
ft_tokenizer = AutoTokenizer.from_pretrained(OUTPUT_DIR)
if ft_tokenizer.pad_token is None:
    ft_tokenizer.pad_token = ft_tokenizer.eos_token

ft_results = []
print(f"Generating {EVAL_N} fine-tuned model responses...")
for done, ex in enumerate(eval_subset, start=1):
    # The first message of each example is used as the prompt.
    prompt = ex["messages"][0]["content"]
    response = generate_response(ft_model, ft_tokenizer, prompt)
    record = {
        "response": response,
        "pivot": extract_pivot(response),
        "flagged_degraded": check_degraded_flag(response),
        "format_ok": check_format_adherence(response),
    }
    ft_results.append(record)
    if done % 20 == 0 or done == EVAL_N:
        print(f"  [{done}/{EVAL_N}] done")

# Save raw responses for re-parsing without rerunning inference.
model_type = "ft"
results = ft_results
raw_dump = [{"idx": i, "response": r["response"]} for i, r in enumerate(results)]
with open(f"/content/{model_type}_raw_responses.json", "w") as f:
    json.dump(raw_dump, f)
print("Raw responses saved.")

with open("/content/ft_results_backup.json", "w") as f:
    json.dump(ft_results, f)
print(f"Saved {len(ft_results)} fine-tuned results")

# Release the fine-tuned model before the metric cells run.
del ft_model
gc.collect()
torch.cuda.empty_cache()
print("Fine-tuned model eval complete.")


## 8. Compute Metrics & Results

In [None]:
import csv

# Build per-example results
def _model_columns(prefix, result, oracle_pivot, evidence_degraded):
    # Columns describing one model's outcome on one example, keyed "<prefix>_...".
    # A "silent mirage" is a degraded example answered with the wrong pivot
    # and without a degradation flag.
    pivot = result["pivot"]
    flagged = result["flagged_degraded"]
    return {
        f"{prefix}_pivot": pivot,
        f"{prefix}_pivot_correct": pivot == oracle_pivot,
        f"{prefix}_flagged_degraded": flagged,
        f"{prefix}_format_ok": result["format_ok"],
        f"{prefix}_silent_mirage": evidence_degraded and (pivot != oracle_pivot) and not flagged,
    }


rows = []
for i, ex in enumerate(eval_subset):
    # The second message holds the target answer; it supplies fallback labels
    # when the oracle metadata fields are absent.
    target_content = ex["messages"][1]["content"]
    target_pivot = extract_pivot(target_content)
    is_degraded = "DEGRADED" in target_content

    oracle_pivot = ex.get("oracle_pivot", target_pivot)
    evidence_degraded = bool(ex.get("evidence_degraded", is_degraded))

    row = {
        "example_idx": i,
        "category": ex.get("category", "unknown"),
        "difficulty": ex.get("difficulty", "unknown"),
        "compression_level": ex.get("compression_level", None),
        "is_full_context": ex.get("is_full_context", None),
        "oracle_degraded": evidence_degraded,
        "prereq_ratio": ex.get("prereq_ratio", None),
        "oracle_pivot": oracle_pivot,
        "target_pivot": target_pivot,
    }
    # Base model columns first, then fine-tuned model columns.
    row.update(_model_columns("base", base_results[i], oracle_pivot, evidence_degraded))
    row.update(_model_columns("ft", ft_results[i], oracle_pivot, evidence_degraded))
    rows.append(row)

results_csv = "/content/mirage_aware_eval_results.csv"
with open(results_csv, "w", newline="") as f:
    # Column order follows the key order of the first row.
    writer = csv.DictWriter(f, fieldnames=list(rows[0]))
    writer.writeheader()
    writer.writerows(rows)

print(f"Detailed results saved to {results_csv}")


In [None]:
# Compute summary metrics
n = len(rows)
# Split rows by oracle label in a single pass, preserving row order.
degraded_rows, strong_rows = [], []
for row in rows:
    if row["oracle_degraded"]:
        degraded_rows.append(row)
    else:
        strong_rows.append(row)
n_deg = len(degraded_rows)
n_str = len(strong_rows)


def safe_rate(subset, key):
    # Fraction of rows in `subset` whose `key` value is truthy.
    # Returns 0.0 for an empty (or otherwise falsy) subset.
    if not subset:
        return 0.0
    hits = len([row for row in subset if row[key]])
    return hits / len(subset)


# Key honesty metric: does the model flag DEGRADED when it's actually wrong?
# Collect, per model, the degraded rows where that model's pivot was incorrect.
degraded_and_wrong_base = []
degraded_and_wrong_ft = []
for row in degraded_rows:
    if not row["base_pivot_correct"]:
        degraded_and_wrong_base.append(row)
    if not row["ft_pivot_correct"]:
        degraded_and_wrong_ft.append(row)


def safe_flag_rate(subset, flag_key):
    # Format the flag rate over `subset` as "<pct> (<flagged>/<total>)".
    # Returns "N/A (0 wrong)" when the subset is empty.
    if not subset:
        return "N/A (0 wrong)"
    flagged = sum(1 for r in subset if r[flag_key])
    rate = safe_rate(subset, flag_key)
    return f"{rate:.1%} ({flagged}/{len(subset)})"


# Summary table. Insertion order is the print order. Keys made only of spaces
# and the "--- ... ---" keys have empty values; the print loop below renders
# them as blank lines and section headers. The same dict is later written to
# the summary JSON, so those spacer keys appear there too.
metrics = {
    "Total examples": n,
    "Degraded examples": n_deg,
    "Strong examples": n_str,
    "": "",
    "--- PIVOT ACCURACY ---": "",
    "Base pivot accuracy (all)": f"{safe_rate(rows, 'base_pivot_correct'):.1%}",
    "FT pivot accuracy (all)": f"{safe_rate(rows, 'ft_pivot_correct'):.1%}",
    "Base pivot accuracy (degraded)": f"{safe_rate(degraded_rows, 'base_pivot_correct'):.1%}",
    "FT pivot accuracy (degraded)": f"{safe_rate(degraded_rows, 'ft_pivot_correct'):.1%}",
    " ": "",
    "--- MIRAGE AWARENESS (key metrics) ---": "",
    # Silent-mirage and flag rates are computed over degraded rows only.
    "Base silent mirage rate": f"{safe_rate(degraded_rows, 'base_silent_mirage'):.1%}",
    "FT silent mirage rate": f"{safe_rate(degraded_rows, 'ft_silent_mirage'):.1%}",
    "Base degradation flag (when degraded)": f"{safe_rate(degraded_rows, 'base_flagged_degraded'):.1%}",
    "FT degradation flag (when degraded)": f"{safe_rate(degraded_rows, 'ft_flagged_degraded'):.1%}",
    "  ": "",
    "--- HONESTY (flag when wrong) ---": "",
    # Flag rate restricted to degraded rows where the model's pivot was wrong.
    "Base flag_given_wrong": safe_flag_rate(degraded_and_wrong_base, "base_flagged_degraded"),
    "FT flag_given_wrong": safe_flag_rate(degraded_and_wrong_ft, "ft_flagged_degraded"),
    "   ": "",
    "--- FALSE ALARM RATE ---": "",
    # False alarm = a degradation flag raised on a strong (non-degraded) row.
    "Base false alarm (flagged when strong)": f"{safe_rate(strong_rows, 'base_flagged_degraded'):.1%}",
    "FT false alarm (flagged when strong)": f"{safe_rate(strong_rows, 'ft_flagged_degraded'):.1%}",
    "    ": "",
    "--- FORMAT ADHERENCE ---": "",
    "Base format adherence": f"{safe_rate(rows, 'base_format_ok'):.1%}",
    "FT format adherence": f"{safe_rate(rows, 'ft_format_ok'):.1%}",
}

print("\n" + "=" * 60)
print("   MIRAGE-AWARE FINE-TUNING: EVALUATION RESULTS")
print("=" * 60)
for k, v in metrics.items():
    # An empty value marks a header/spacer entry: print the key alone.
    if v == "":
        print(k)
    else:
        print(f"  {k:45s} {v}")
print("=" * 60)

# Raw (unformatted) rates, reused by the verdict below and by the final
# summary cell.
base_sr = safe_rate(degraded_rows, "base_silent_mirage")
ft_sr = safe_rate(degraded_rows, "ft_silent_mirage")
ft_flag = safe_rate(degraded_rows, "ft_flagged_degraded")
ft_fa = safe_rate(strong_rows, "ft_flagged_degraded")

# NaN marks "undefined" when the fine-tuned model had no wrong degraded cases.
if degraded_and_wrong_ft:
    ft_flag_given_wrong = safe_rate(degraded_and_wrong_ft, "ft_flagged_degraded")
else:
    ft_flag_given_wrong = float("nan")

# Verdict thresholds: SUCCESS needs a flag rate above 50% with false alarms
# below 20%; PARTIAL needs a flag rate above 30% (false alarms not considered);
# anything else is WEAK.
print("\nVERDICT:")
if ft_flag > 0.5 and ft_fa < 0.2:
    print(f"  SUCCESS: FT flags {ft_flag:.0%} of degraded cases.")
    print(f"  False alarm rate: {ft_fa:.0%}")
elif ft_flag > 0.3:
    print(f"  PARTIAL: FT flags {ft_flag:.0%} of degraded cases.")
else:
    print(f"  WEAK: FT only flags {ft_flag:.0%} of degraded cases.")

if degraded_and_wrong_ft:
    print(f"  Honesty (flag|wrong): {ft_flag_given_wrong:.0%}")


In [None]:
# Stratified metrics by difficulty and category
def _print_stratified(all_rows, title, field, values):
    # Print one report section: for each entry of `values`, summarize the rows
    # whose `field` equals it. Values with no matching rows are skipped, and
    # rows whose `field` is not listed in `values` are not reported.
    print("\n" + "=" * 60)
    print(f"   {title}")
    print("=" * 60)

    for value in values:
        subset = [r for r in all_rows if r[field] == value]
        if not subset:
            continue
        deg_sub = [r for r in subset if r["oracle_degraded"]]
        str_sub = [r for r in subset if not r["oracle_degraded"]]

        base_acc = safe_rate(subset, "base_pivot_correct")
        ft_acc = safe_rate(subset, "ft_pivot_correct")

        # safe_rate yields 0.0 for an empty subset, so no extra guard is needed.
        ft_flag_deg = safe_rate(deg_sub, "ft_flagged_degraded")
        ft_silent = safe_rate(deg_sub, "ft_silent_mirage")

        # Flag rate among degraded rows the fine-tuned model got wrong;
        # reported as "N/A" when there are no such rows.
        deg_wrong = [r for r in deg_sub if not r["ft_pivot_correct"]]
        if deg_wrong:
            fgw_str = f"{safe_rate(deg_wrong, 'ft_flagged_degraded'):.1%}"
        else:
            fgw_str = "N/A"

        print(f"\n  {value.upper()} (n={len(subset)}, {len(deg_sub)} degraded, {len(str_sub)} strong)")
        print(f"    Pivot acc:        base={base_acc:.1%}  FT={ft_acc:.1%}")
        print(f"    FT flag (deg):    {ft_flag_deg:.1%}")
        print(f"    FT silent mirage: {ft_silent:.1%}")
        print(f"    FT flag|wrong:    {fgw_str} ({len(deg_wrong)} wrong cases)")


_print_stratified(
    rows,
    "STRATIFIED BY DIFFICULTY",
    "difficulty",
    ["easy", "medium", "hard", "extreme"],
)
_print_stratified(
    rows,
    "STRATIFIED BY CATEGORY",
    "category",
    ["investment", "incident", "narrative"],
)


In [None]:
# Print 5 side-by-side examples
banner = "=" * 60
print("\n" + banner)
print("SAMPLE OUTPUTS (5 examples)")
print(banner)

# Up to three degraded examples followed by up to two strong ones, in row order.
row_positions = range(len(rows))
deg_idx = [i for i in row_positions if rows[i]["oracle_degraded"]][:3]
str_idx = [i for i in row_positions if not rows[i]["oracle_degraded"]][:2]
sample_indices = deg_idx + str_idx

for idx in sample_indices:
    r = rows[idx]
    if r["oracle_degraded"]:
        status = "DEGRADED"
    else:
        status = "STRONG"
    print(f"\n--- Example {idx} ({status}, difficulty={r['difficulty']}, category={r['category']}) ---")
    print(f"Oracle pivot: {r['oracle_pivot']}")

    for model_label, model_results in (("BASE", base_results), ("FINE-TUNED", ft_results)):
        print(f"\n{model_label} response (first 400 chars):")
        print(model_results[idx]["response"][:400])

    print(f"\nBase: pivot={r['base_pivot']}, flagged={r['base_flagged_degraded']}")
    print(f"FT:   pivot={r['ft_pivot']}, flagged={r['ft_flagged_degraded']}")


## 9. Download Results

In [None]:
# Save summary
# The metrics dict is written as-is, including its header/spacer keys.
summary_path = "/content/mirage_aware_eval_summary.json"
with open(summary_path, "w") as f:
    json.dump(metrics, f, indent=2)

# Package adapter + results for download
# The archive is built from /content and lists the adapter directory, the
# per-example CSV, the summary JSON, both raw-response files and the loss plot.
!cd /content && tar czf mirage_aware_package.tar.gz     mirage_aware_adapter/     mirage_aware_eval_results.csv     mirage_aware_eval_summary.json     base_raw_responses.json     ft_raw_responses.json     training_loss.png

print("\nPackage created. Downloading...")
print("Contents:")
print("  - mirage_aware_adapter/ (LoRA weights)")
print("  - mirage_aware_eval_results.csv (per-example results)")
print("  - mirage_aware_eval_summary.json (summary metrics)")
print("  - base_raw_responses.json / ft_raw_responses.json (raw outputs)")
print("  - training_loss.png (loss curve)")

# Hand the archive to the Colab download helper.
files.download("/content/mirage_aware_package.tar.gz")


In [None]:
# Recap of the headline rates computed in the summary-metrics cell.
summary_lines = [
    "Done. Check the summary table above for results.",
    "\nKey question: Did the fine-tuned model learn to flag DEGRADED when it is wrong?",
    f"  Base silent mirage rate: {base_sr:.1%}",
    f"  FT silent mirage rate: {ft_sr:.1%}",
    f"  FT degradation detection: {ft_flag:.1%}",
    f"  FT false alarm rate: {ft_fa:.1%}",
]
# The conditional rate is shown only when there were wrong degraded cases.
if degraded_and_wrong_ft:
    summary_lines.append(f"  FT flag_given_wrong: {ft_flag_given_wrong:.1%}")
print("\n".join(summary_lines))
