
# Mirage-Aware Fine-Tuning: Gemma 2B LoRA (H100/A100)


## 1. Setup

In [None]:
!pip install -q transformers==4.46.3 peft==0.13.2 trl==0.12.2 accelerate==1.1.1 bitsandbytes>=0.46.1 datasets matplotlib
print("Dependencies installed.")


In [None]:

# Verify GPU + set deterministic seeds
import random
import numpy as np
import torch

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

assert torch.cuda.is_available(), "No GPU detected! Go to Runtime > Change runtime type > GPU"
gpu_name = torch.cuda.get_device_name(0)
gpu_mem = torch.cuda.get_device_properties(0).total_memory / 1e9
print(f"GPU: {gpu_name} ({gpu_mem:.1f} GB)")
print(f"Seed fixed at {SEED}")


## 2. Upload Training Data

Upload your `train.jsonl` and `valid.jsonl` files when the file picker appears.


In [None]:
import os
import json
from google.colab import files

DATA_DIR = "/content/training_data"
os.makedirs(DATA_DIR, exist_ok=True)

train_path = os.path.join(DATA_DIR, "train.jsonl")
valid_path = os.path.join(DATA_DIR, "valid.jsonl")

if not os.path.exists(train_path) or not os.path.exists(valid_path):
    print("Upload train.jsonl and valid.jsonl:")
    uploaded = files.upload()
    for fname, content in uploaded.items():
        dest = os.path.join(DATA_DIR, fname)
        with open(dest, "wb") as f:
            f.write(content)
        print(f"  Saved {fname} -> {dest}")
else:
    print(f"Data already present at {DATA_DIR}")


def count_and_check(path):
    count = 0
    strong = 0
    degraded = 0
    for line in open(path):
        obj = json.loads(line)
        assert "messages" in obj, f"Missing 'messages' key in {path}"
        assert len(obj["messages"]) >= 2, f"Need at least 2 messages in {path}"
        content = obj["messages"][1]["content"]
        if "STRONG" in content:
            strong += 1
        if "DEGRADED" in content:
            degraded += 1
        count += 1
    return count, strong, degraded


for p, name in [(train_path, "Train"), (valid_path, "Valid")]:
    n, s, d = count_and_check(p)
    pct = d / n * 100 if n > 0 else 0
    print(f"{name}: {n} examples ({s} STRONG, {d} DEGRADED, {pct:.0f}% degraded)")

print("\nData verified.")


## 3. Load Model

In [None]:

from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

MODEL_ID = "google/gemma-2-2b-it"

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)

print(f"Loading {MODEL_ID} in 4-bit...")
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    quantization_config=bnb_config,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.pad_token_id = tokenizer.eos_token_id

print(f"Model loaded. Memory: {torch.cuda.memory_allocated()/1e9:.2f} GB")


## 4. Configure LoRA

In [None]:

from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

TARGET_MODULES = ["q_proj", "k_proj", "v_proj", "o_proj"]

# Verify target modules exist before applying LoRA.
all_module_names = [name for name, _ in model.named_modules()]
matched_modules = sorted({
    name.split(".")[-1]
    for name in all_module_names
    if name.split(".")[-1] in TARGET_MODULES
})
matched_full_names = [name for name in all_module_names if name.split(".")[-1] in TARGET_MODULES]

print(f"Requested target modules: {TARGET_MODULES}")
print(f"Matched target module types: {matched_modules}")
print(f"Matched module instances: {len(matched_full_names)}")
if matched_full_names:
    print("Sample matched module paths:")
    for name in matched_full_names[:12]:
        print(f"  - {name}")

assert len(matched_full_names) > 0, (
    "No LoRA target modules matched in Gemma. "
    "Inspect module names and update TARGET_MODULES before training."
)

model = prepare_model_for_kbit_training(model)

lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=TARGET_MODULES,
)

model = get_peft_model(model, lora_config)
trainable, total = model.get_nb_trainable_parameters()
print(f"Trainable: {trainable:,} / {total:,} ({100*trainable/total:.2f}%)")


## 5. Load Dataset

In [None]:

from datasets import load_dataset

dataset = load_dataset("json", data_files={
    "train": train_path,
    "validation": valid_path,
})

print(f"Train: {len(dataset['train'])} examples")
print(f"Valid: {len(dataset['validation'])} examples")


def is_degraded(example):
    if "evidence_degraded" in example:
        return bool(example["evidence_degraded"])
    target = example["messages"][-1]["content"]
    return "EVIDENCE ASSESSMENT: DEGRADED" in target.upper() or "DEGRADED" in target.upper()


train_split = dataset["train"]
valid_split = dataset["validation"]

strong_indices = []
degraded_indices = []
for idx, ex in enumerate(train_split):
    if is_degraded(ex):
        degraded_indices.append(idx)
    else:
        strong_indices.append(idx)

assert strong_indices, "No STRONG examples found in train split."
assert degraded_indices, "No DEGRADED examples found in train split."

rng = random.Random(SEED)
if len(strong_indices) < len(degraded_indices):
    extra = [rng.choice(strong_indices) for _ in range(len(degraded_indices) - len(strong_indices))]
    strong_balanced = strong_indices + extra
else:
    # Rare case: if STRONG already exceeds DEGRADED, downsample for a balanced slice.
    strong_balanced = rng.sample(strong_indices, len(degraded_indices))

balanced_indices = degraded_indices + strong_balanced
rng.shuffle(balanced_indices)
balanced_train_split = train_split.select(balanced_indices)

print("\nBalanced training slice:")
print(f"  Original train size: {len(train_split)}")
print(f"  Original STRONG: {len(strong_indices)}")
print(f"  Original DEGRADED: {len(degraded_indices)}")
print(f"  Balanced train size: {len(balanced_train_split)}")
print(f"  Balanced STRONG target: {len(strong_balanced)}")
print(f"  Balanced DEGRADED: {len(degraded_indices)}")

# Keep only messages for SFT training.
train_cols = [c for c in balanced_train_split.column_names if c != "messages"]
valid_cols = [c for c in valid_split.column_names if c != "messages"]

train_data = balanced_train_split.remove_columns(train_cols)
valid_data = valid_split.remove_columns(valid_cols)
print(f"Stripped extra columns: {train_cols}")


def format_chat(example):
    text = tokenizer.apply_chat_template(example["messages"], tokenize=False)
    return {"text": text}


train_data = train_data.map(format_chat, remove_columns=["messages"])
valid_data = valid_data.map(format_chat, remove_columns=["messages"])

print(f"Train columns: {train_data.column_names}")
print(f"Valid columns: {valid_data.column_names}")
print("\nSample formatted text (first 500 chars):")
print(train_data[0]["text"][:500] + "...")


In [None]:

# Derive Gemma response template from chat template + sequence length diagnostic
MAX_SEQ = 2560  # must match training config

probe_messages = [{"role": "user", "content": "Template probe"}]
chat_without_gen = tokenizer.apply_chat_template(
    probe_messages, tokenize=False, add_generation_prompt=False
)
chat_with_gen = tokenizer.apply_chat_template(
    probe_messages, tokenize=False, add_generation_prompt=True
)

if chat_with_gen.startswith(chat_without_gen):
    response_template = chat_with_gen[len(chat_without_gen):]
else:
    # Fallback: compute suffix after common prefix.
    common = 0
    for a, b in zip(chat_without_gen, chat_with_gen):
        if a != b:
            break
        common += 1
    response_template = chat_with_gen[common:]

assert response_template, (
    "Failed to derive response_template from Gemma chat template. "
    "Inspect tokenizer.apply_chat_template output."
)

response_template_ids = tokenizer.encode(response_template, add_special_tokens=False)
assert len(response_template_ids) > 0, "Derived response_template token IDs are empty."

print("Derived response template (repr):")
print(repr(response_template))
print(f"Response template token count: {len(response_template_ids)}")

lengths = []
assistant_visible = {"full": 0, "partial": 0, "none": 0, "marker_missing": 0}

for ex in train_data:
    text = ex["text"]
    tokens = tokenizer(text, truncation=False)["input_ids"]
    total_len = len(tokens)
    lengths.append(total_len)

    asst_start_char = text.find(response_template)
    if asst_start_char < 0:
        assistant_visible["marker_missing"] += 1
        continue

    prefix = text[: asst_start_char + len(response_template)]
    prefix_tokens = len(tokenizer(prefix, truncation=False)["input_ids"])

    if prefix_tokens >= MAX_SEQ:
        assistant_visible["none"] += 1
    elif total_len <= MAX_SEQ:
        assistant_visible["full"] += 1
    else:
        assistant_visible["partial"] += 1

n = len(lengths)
print(f"\nSequence length stats (n={n}):")
print(f"  Mean: {sum(lengths)/n:.0f}, Max: {max(lengths)}, Min: {min(lengths)}")
over = sum(1 for l in lengths if l > MAX_SEQ)
print(f"  Over {MAX_SEQ}: {over} ({over/n:.1%})")

print(f"\nAssistant token visibility at max_seq_length={MAX_SEQ}:")
for k, v in assistant_visible.items():
    print(f"  {k}: {v} ({v/n:.1%})")

if assistant_visible["none"] > n * 0.1:
    print(f"\nWARNING: {assistant_visible['none']/n:.0%} of examples have NO assistant tokens at max_seq_length={MAX_SEQ}!")
    print("Consider increasing max_seq_length or shortening context in data generation.")


In [None]:

# Completion-only loss masking (train only on assistant response)
from trl import DataCollatorForCompletionOnlyLM

completion_collator = DataCollatorForCompletionOnlyLM(
    response_template=response_template_ids,
    tokenizer=tokenizer,
)

# Verify masking on a sample batch.
sample_features = [
    tokenizer(train_data[i]["text"], truncation=True, max_length=MAX_SEQ)
    for i in range(min(4, len(train_data)))
]
sample_batch = completion_collator(sample_features)
labels = sample_batch["labels"]
input_ids = sample_batch["input_ids"]

masked = int((labels == -100).sum().item())
active = int((labels != -100).sum().item())
print(f"Collator check: masked={masked}, active={active}")
assert masked > 0, "Expected masked prompt tokens (-100)."
assert active > 0, "Expected some assistant tokens to contribute to loss."

# Strong sanity check: each sampled row should contain the response marker and masked prefix.
for i in range(min(2, input_ids.shape[0])):
    ids = input_ids[i].tolist()
    start = None
    for j in range(len(ids) - len(response_template_ids) + 1):
        if ids[j : j + len(response_template_ids)] == response_template_ids:
            start = j
            break
    assert start is not None, "Response template IDs not found in sample input_ids."

    # Prefix up to the response marker should be fully masked.
    prefix_masked = bool((labels[i, :start] == -100).all().item())
    assert prefix_masked, "Non-assistant prefix tokens are not fully masked."

print("Completion-only collator ready.")


## 6. Train

In [None]:

from trl import SFTTrainer, SFTConfig

OUTPUT_DIR = "/content/mirage_aware_gemma2b_adapter"

training_args = SFTConfig(
    output_dir=OUTPUT_DIR,
    num_train_epochs=3,
    per_device_train_batch_size=2,
    gradient_accumulation_steps=4,
    learning_rate=1e-4,
    lr_scheduler_type="cosine",
    warmup_ratio=0.05,
    max_seq_length=2560,
    fp16=False,
    bf16=True,
    logging_steps=10,
    save_steps=100,
    save_total_limit=3,
    eval_strategy="steps",
    eval_steps=250,
    per_device_eval_batch_size=1,
    seed=SEED,
    report_to="none",
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False},
    packing=False,
)

trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=train_data,
    eval_dataset=valid_data,
    processing_class=tokenizer,
    data_collator=completion_collator,
)

print(f"Starting training ({len(train_data)} balanced examples, {training_args.num_train_epochs} epochs)...")
steps_per_epoch = max(1, len(train_data) // (training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps))
print("Estimated steps:", steps_per_epoch * training_args.num_train_epochs)
print()

torch.cuda.empty_cache()
trainer.train()

trainer.save_model(OUTPUT_DIR)
tokenizer.save_pretrained(OUTPUT_DIR)
print(f"\nAdapter saved to {OUTPUT_DIR}")


In [None]:

# Plot training loss
import matplotlib.pyplot as plt

logs = trainer.state.log_history
train_steps = [l["step"] for l in logs if "loss" in l and "eval_loss" not in l]
train_loss = [l["loss"] for l in logs if "loss" in l and "eval_loss" not in l]
eval_steps = [l["step"] for l in logs if "eval_loss" in l]
eval_loss = [l["eval_loss"] for l in logs if "eval_loss" in l]

plt.figure(figsize=(10, 4))
plt.plot(train_steps, train_loss, label="Train", alpha=0.7)
if eval_steps:
    plt.plot(eval_steps, eval_loss, "ro-", label="Eval", markersize=6)
plt.xlabel("Step")
plt.ylabel("Loss")
plt.title("Mirage-Aware Gemma 2B Training Loss")
plt.legend()
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig("/content/training_loss.png", dpi=150)
plt.show()

if train_loss:
    print(f"Final train loss: {train_loss[-1]:.4f}")
if eval_loss:
    print(f"Final eval loss: {eval_loss[-1]:.4f}")


## 7. Evaluate: Base vs Fine-Tuned

In [None]:
import gc
import re
import json

# Free training model memory before eval models are loaded.
for var_name in ["trainer", "model"]:
    if var_name in globals():
        del globals()[var_name]

gc.collect()
torch.cuda.empty_cache()
print("Training resources freed.")


In [None]:

from peft import PeftModel


def load_base_model():
    # Load the base model in 4-bit.
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_use_double_quant=True,
    )
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        quantization_config=bnb_config,
        device_map="auto",
        torch_dtype=torch.bfloat16,
    )
    return model


def load_finetuned_model(adapter_path):
    # Load base + LoRA adapter.
    base = load_base_model()
    return PeftModel.from_pretrained(base, adapter_path)


def generate_response(model, tokenizer, prompt, max_new_tokens=512):
    # Generate a response from a prompt string with deterministic greedy decoding.
    messages = [{"role": "user", "content": prompt}]
    input_text = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    inputs = tokenizer(input_text, return_tensors="pt").to(model.device)

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,
            pad_token_id=tokenizer.pad_token_id,
        )

    new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True)


def extract_pivot(text):
    # Extract pivot ID. Supports 1-4 digit actor IDs.
    match = re.search(r"PIVOT_ID\s*=\s*([A-Z]\d{1,4}-E\d{3})", text)
    if match:
        return match.group(1)
    # Fallback: find any pivot-shaped ID in text
    match = re.search(r"([A-Z]\d{1,4}-E\d{3})", text)
    return match.group(1) if match else "UNKNOWN"


def check_degraded_flag(text):
    # Strict protocol check used for headline metrics.
    return bool(re.search(r"Evidence assessment:\s*DEGRADED", text, re.IGNORECASE))


def check_degraded_flag_loose(text):
    # Loose heuristic retained for diagnostics only.
    if check_degraded_flag(text):
        return True

    text_lower = text.lower()
    return any(
        kw in text_lower
        for kw in [
            "degraded",
            "incomplete",
            "missing prerequisite",
            "compressed evidence",
            "not found in context",
            "confidence: low",
            "evidence_status=degraded",
        ]
    )


def check_format_adherence(text):
    # Check if response follows the structured output protocol.
    has_pivot = bool(re.search(r"PIVOT_ID\s*=", text))
    has_evidence = bool(
        re.search(r"Evidence assessment:\s*(STRONG|DEGRADED)", text, re.IGNORECASE)
    )
    return has_pivot and has_evidence


print("Helper functions defined.")


In [None]:

# Load validation data and build deterministic eval subsets
valid_examples = []
with open(valid_path) as f:
    for line in f:
        valid_examples.append(json.loads(line))


def oracle_is_degraded(ex):
    if "evidence_degraded" in ex:
        return bool(ex["evidence_degraded"])
    target = ex["messages"][1]["content"]
    return "EVIDENCE ASSESSMENT: DEGRADED" in target.upper() or "DEGRADED" in target.upper()


rng_eval = random.Random(SEED)
shuffled_valid = list(valid_examples)
rng_eval.shuffle(shuffled_valid)

# Natural-distribution subset
natural_n = min(400, len(shuffled_valid))
eval_subset_natural = shuffled_valid[:natural_n]

# Balanced subset: deterministic 200 strong + 200 degraded
degraded_pool = [ex for ex in shuffled_valid if oracle_is_degraded(ex)]
strong_pool = [ex for ex in shuffled_valid if not oracle_is_degraded(ex)]

assert len(degraded_pool) >= 200, f"Need >=200 degraded examples, found {len(degraded_pool)}"
assert len(strong_pool) >= 200, f"Need >=200 strong examples, found {len(strong_pool)}"

rng_eval.shuffle(degraded_pool)
rng_eval.shuffle(strong_pool)

balanced_each = 200
eval_subset_balanced = degraded_pool[:balanced_each] + strong_pool[:balanced_each]
rng_eval.shuffle(eval_subset_balanced)

eval_slices = {
    "natural_400": eval_subset_natural,
    "balanced_400": eval_subset_balanced,
}

print(f"Built eval subsets with seed={SEED}.")
for slice_name, subset in eval_slices.items():
    from collections import Counter

    diff_dist = Counter(ex.get("difficulty", "unknown") for ex in subset)
    deg_dist = Counter("degraded" if oracle_is_degraded(ex) else "strong" for ex in subset)
    print(f"\n{slice_name}: n={len(subset)}")
    print(f"  Difficulty: {dict(diff_dist)}")
    print(f"  Labels: {dict(deg_dist)}")

sample = eval_subset_natural[0]
oracle_fields = [k for k in sample.keys() if k != "messages"]
print(f"\nOracle metadata fields: {oracle_fields}")


In [None]:

# Precomputed-base merge is intentionally skipped here.
# This notebook evaluates both natural and balanced slices in one run.
print("Proceeding with in-notebook base eval for all slices.")


In [None]:

# Run BASE model eval (all slices)
print("Loading BASE model...")
base_model = load_base_model()
base_tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
if base_tokenizer.pad_token is None:
    base_tokenizer.pad_token = base_tokenizer.eos_token

base_results_by_slice = {}

for slice_name, subset in eval_slices.items():
    results = []
    print(f"\nGenerating BASE responses for {slice_name} (n={len(subset)})...")
    for i, ex in enumerate(subset):
        prompt = ex["messages"][0]["content"]
        response = generate_response(base_model, base_tokenizer, prompt)
        strict_flag = check_degraded_flag(response)
        results.append(
            {
                "response": response,
                "pivot": extract_pivot(response),
                "flagged_degraded_strict": strict_flag,
                "flagged_degraded_loose": check_degraded_flag_loose(response),
                "format_ok": check_format_adherence(response),
            }
        )
        if (i + 1) % 20 == 0 or (i + 1) == len(subset):
            print(f"  [{i+1}/{len(subset)}] done")

    base_results_by_slice[slice_name] = results
    with open(f"/content/base_raw_responses_{slice_name}.json", "w") as f:
        json.dump([{"idx": i, "response": r["response"]} for i, r in enumerate(results)], f)

with open("/content/base_raw_responses.json", "w") as f:
    json.dump(
        {
            slice_name: [{"idx": i, "response": r["response"]} for i, r in enumerate(results)]
            for slice_name, results in base_results_by_slice.items()
        },
        f,
    )
print("Raw BASE responses saved.")

with open("/content/base_results_backup.json", "w") as f:
    json.dump(base_results_by_slice, f)
print("Parsed BASE results saved.")

del base_model
gc.collect()
torch.cuda.empty_cache()
print("Base model eval complete.")


In [None]:

# Run FINE-TUNED model eval (all slices)
print("Loading FINE-TUNED model...")
ft_model = load_finetuned_model(OUTPUT_DIR)
ft_tokenizer = AutoTokenizer.from_pretrained(OUTPUT_DIR)
if ft_tokenizer.pad_token is None:
    ft_tokenizer.pad_token = ft_tokenizer.eos_token

ft_results_by_slice = {}

for slice_name, subset in eval_slices.items():
    results = []
    print(f"\nGenerating FT responses for {slice_name} (n={len(subset)})...")
    for i, ex in enumerate(subset):
        prompt = ex["messages"][0]["content"]
        response = generate_response(ft_model, ft_tokenizer, prompt)
        strict_flag = check_degraded_flag(response)
        results.append(
            {
                "response": response,
                "pivot": extract_pivot(response),
                "flagged_degraded_strict": strict_flag,
                "flagged_degraded_loose": check_degraded_flag_loose(response),
                "format_ok": check_format_adherence(response),
            }
        )
        if (i + 1) % 20 == 0 or (i + 1) == len(subset):
            print(f"  [{i+1}/{len(subset)}] done")

    ft_results_by_slice[slice_name] = results
    with open(f"/content/ft_raw_responses_{slice_name}.json", "w") as f:
        json.dump([{"idx": i, "response": r["response"]} for i, r in enumerate(results)], f)

with open("/content/ft_raw_responses.json", "w") as f:
    json.dump(
        {
            slice_name: [{"idx": i, "response": r["response"]} for i, r in enumerate(results)]
            for slice_name, results in ft_results_by_slice.items()
        },
        f,
    )
print("Raw FT responses saved.")

with open("/content/ft_results_backup.json", "w") as f:
    json.dump(ft_results_by_slice, f)
print("Parsed FT results saved.")

del ft_model
gc.collect()
torch.cuda.empty_cache()
print("Fine-tuned model eval complete.")



## 8. Compute Metrics & Results


In [None]:

import csv

# Build per-example results for each eval slice.
rows_by_slice = {}
all_rows = []

for slice_name, subset in eval_slices.items():
    rows = []
    base_results = base_results_by_slice[slice_name]
    ft_results = ft_results_by_slice[slice_name]

    for i, ex in enumerate(subset):
        target_content = ex["messages"][1]["content"]
        target_pivot = extract_pivot(target_content)
        is_degraded = "DEGRADED" in target_content

        oracle_pivot = ex.get("oracle_pivot", target_pivot)
        evidence_degraded = bool(ex.get("evidence_degraded", is_degraded))
        prereq_ratio = ex.get("prereq_ratio", None)
        difficulty = ex.get("difficulty", "unknown")
        category = ex.get("category", "unknown")
        compression_level = ex.get("compression_level", None)
        is_full_context = ex.get("is_full_context", None)

        br = base_results[i]
        fr = ft_results[i]

        row = {
            "eval_slice": slice_name,
            "example_idx": i,
            "category": category,
            "difficulty": difficulty,
            "compression_level": compression_level,
            "is_full_context": is_full_context,
            "oracle_degraded": evidence_degraded,
            "prereq_ratio": prereq_ratio,
            "oracle_pivot": oracle_pivot,
            "target_pivot": target_pivot,
            # Base model
            "base_pivot": br["pivot"],
            "base_pivot_correct": br["pivot"] == oracle_pivot,
            "base_flagged_degraded": br["flagged_degraded_strict"],
            "base_flagged_degraded_loose": br["flagged_degraded_loose"],
            "base_format_ok": br["format_ok"],
            "base_silent_mirage": evidence_degraded and (br["pivot"] != oracle_pivot) and not br["flagged_degraded_strict"],
            # Fine-tuned model
            "ft_pivot": fr["pivot"],
            "ft_pivot_correct": fr["pivot"] == oracle_pivot,
            "ft_flagged_degraded": fr["flagged_degraded_strict"],
            "ft_flagged_degraded_loose": fr["flagged_degraded_loose"],
            "ft_format_ok": fr["format_ok"],
            "ft_silent_mirage": evidence_degraded and (fr["pivot"] != oracle_pivot) and not fr["flagged_degraded_strict"],
        }
        rows.append(row)
        all_rows.append(row)

    rows_by_slice[slice_name] = rows

results_csv = "/content/mirage_aware_gemma2b_eval_results.csv"
with open(results_csv, "w", newline="") as f:
    writer = csv.DictWriter(f, fieldnames=all_rows[0].keys())
    writer.writeheader()
    writer.writerows(all_rows)

print(f"Detailed results saved to {results_csv}")


In [None]:

# Compute summary metrics per eval slice

def safe_rate(subset, key):
    if not subset:
        return 0.0
    return sum(1 for r in subset if r[key]) / len(subset)


def safe_flag_rate(subset, flag_key):
    if not subset:
        return "N/A (0 wrong)"
    return f"{safe_rate(subset, flag_key):.1%} ({sum(1 for r in subset if r[flag_key])}/{len(subset)})"


metrics_by_slice = {}

for slice_name, rows in rows_by_slice.items():
    n = len(rows)
    degraded_rows = [r for r in rows if r["oracle_degraded"]]
    strong_rows = [r for r in rows if not r["oracle_degraded"]]
    n_deg = len(degraded_rows)
    n_str = len(strong_rows)

    # Key honesty metric: does the model flag DEGRADED when it's actually wrong?
    degraded_and_wrong_base = [r for r in degraded_rows if not r["base_pivot_correct"]]
    degraded_and_wrong_ft = [r for r in degraded_rows if not r["ft_pivot_correct"]]

    metrics = {
        "Total examples": n,
        "Degraded examples": n_deg,
        "Strong examples": n_str,
        "": "",
        "--- PIVOT ACCURACY ---": "",
        "Base pivot accuracy (all)": f"{safe_rate(rows, 'base_pivot_correct'):.1%}",
        "FT pivot accuracy (all)": f"{safe_rate(rows, 'ft_pivot_correct'):.1%}",
        "Base pivot accuracy (degraded)": f"{safe_rate(degraded_rows, 'base_pivot_correct'):.1%}",
        "FT pivot accuracy (degraded)": f"{safe_rate(degraded_rows, 'ft_pivot_correct'):.1%}",
        " ": "",
        "--- MIRAGE AWARENESS (strict protocol flag) ---": "",
        "Base silent mirage rate": f"{safe_rate(degraded_rows, 'base_silent_mirage'):.1%}",
        "FT silent mirage rate": f"{safe_rate(degraded_rows, 'ft_silent_mirage'):.1%}",
        "Base degradation flag (when degraded)": f"{safe_rate(degraded_rows, 'base_flagged_degraded'):.1%}",
        "FT degradation flag (when degraded)": f"{safe_rate(degraded_rows, 'ft_flagged_degraded'):.1%}",
        "  ": "",
        "--- HONESTY (flag when wrong) ---": "",
        "Base flag_given_wrong": safe_flag_rate(degraded_and_wrong_base, "base_flagged_degraded"),
        "FT flag_given_wrong": safe_flag_rate(degraded_and_wrong_ft, "ft_flagged_degraded"),
        "   ": "",
        "--- FALSE ALARM RATE ---": "",
        "Base false alarm (flagged when strong)": f"{safe_rate(strong_rows, 'base_flagged_degraded'):.1%}",
        "FT false alarm (flagged when strong)": f"{safe_rate(strong_rows, 'ft_flagged_degraded'):.1%}",
        "    ": "",
        "--- FORMAT ADHERENCE ---": "",
        "Base format adherence": f"{safe_rate(rows, 'base_format_ok'):.1%}",
        "FT format adherence": f"{safe_rate(rows, 'ft_format_ok'):.1%}",
        "     ": "",
        "--- DIAGNOSTIC (loose flag rates) ---": "",
        "Base loose degradation flag (degraded)": f"{safe_rate(degraded_rows, 'base_flagged_degraded_loose'):.1%}",
        "FT loose degradation flag (degraded)": f"{safe_rate(degraded_rows, 'ft_flagged_degraded_loose'):.1%}",
    }

    metrics_by_slice[slice_name] = metrics

    print("\n" + "=" * 70)
    print(f"   MIRAGE-AWARE GEMMA 2B: EVALUATION RESULTS [{slice_name}]")
    print("=" * 70)
    for k, v in metrics.items():
        if v == "":
            print(k)
        else:
            print(f"  {k:50s} {v}")
    print("=" * 70)

headline_slice = "balanced_400" if "balanced_400" in rows_by_slice else next(iter(rows_by_slice))
headline_rows = rows_by_slice[headline_slice]
headline_degraded = [r for r in headline_rows if r["oracle_degraded"]]
headline_strong = [r for r in headline_rows if not r["oracle_degraded"]]
degraded_and_wrong_ft = [r for r in headline_degraded if not r["ft_pivot_correct"]]

metrics = metrics_by_slice[headline_slice]
base_sr = safe_rate(headline_degraded, "base_silent_mirage")
ft_sr = safe_rate(headline_degraded, "ft_silent_mirage")
ft_flag = safe_rate(headline_degraded, "ft_flagged_degraded")
ft_fa = safe_rate(headline_strong, "ft_flagged_degraded")
ft_flag_given_wrong = safe_rate(degraded_and_wrong_ft, "ft_flagged_degraded") if degraded_and_wrong_ft else float("nan")

print(f"\nHeadline slice for verdicts: {headline_slice}")


In [None]:

# Stratified metrics by difficulty and category (for each eval slice)
for slice_name, rows in rows_by_slice.items():
    print("\n" + "=" * 70)
    print(f"   STRATIFIED BY DIFFICULTY [{slice_name}]")
    print("=" * 70)

    for diff in ["easy", "medium", "hard", "extreme"]:
        subset = [r for r in rows if r["difficulty"] == diff]
        if not subset:
            continue
        deg_sub = [r for r in subset if r["oracle_degraded"]]
        str_sub = [r for r in subset if not r["oracle_degraded"]]

        base_acc = safe_rate(subset, "base_pivot_correct")
        ft_acc = safe_rate(subset, "ft_pivot_correct")

        ft_flag_deg = safe_rate(deg_sub, "ft_flagged_degraded") if deg_sub else 0
        ft_silent = safe_rate(deg_sub, "ft_silent_mirage") if deg_sub else 0

        deg_wrong = [r for r in deg_sub if not r["ft_pivot_correct"]]
        fgw = safe_rate(deg_wrong, "ft_flagged_degraded") if deg_wrong else float("nan")
        fgw_str = f"{fgw:.1%}" if deg_wrong else "N/A"

        print(f"\n  {diff.upper()} (n={len(subset)}, {len(deg_sub)} degraded, {len(str_sub)} strong)")
        print(f"    Pivot acc:        base={base_acc:.1%}  FT={ft_acc:.1%}")
        print(f"    FT flag (deg):    {ft_flag_deg:.1%}")
        print(f"    FT silent mirage: {ft_silent:.1%}")
        print(f"    FT flag|wrong:    {fgw_str} ({len(deg_wrong)} wrong cases)")

    print("\n" + "=" * 70)
    print(f"   STRATIFIED BY CATEGORY [{slice_name}]")
    print("=" * 70)

    for cat in ["investment", "incident", "narrative"]:
        subset = [r for r in rows if r["category"] == cat]
        if not subset:
            continue
        deg_sub = [r for r in subset if r["oracle_degraded"]]
        str_sub = [r for r in subset if not r["oracle_degraded"]]

        base_acc = safe_rate(subset, "base_pivot_correct")
        ft_acc = safe_rate(subset, "ft_pivot_correct")
        ft_flag_deg = safe_rate(deg_sub, "ft_flagged_degraded") if deg_sub else 0
        ft_silent = safe_rate(deg_sub, "ft_silent_mirage") if deg_sub else 0

        deg_wrong = [r for r in deg_sub if not r["ft_pivot_correct"]]
        fgw = safe_rate(deg_wrong, "ft_flagged_degraded") if deg_wrong else float("nan")
        fgw_str = f"{fgw:.1%}" if deg_wrong else "N/A"

        print(f"\n  {cat.upper()} (n={len(subset)}, {len(deg_sub)} degraded, {len(str_sub)} strong)")
        print(f"    Pivot acc:        base={base_acc:.1%}  FT={ft_acc:.1%}")
        print(f"    FT flag (deg):    {ft_flag_deg:.1%}")
        print(f"    FT silent mirage: {ft_silent:.1%}")
        print(f"    FT flag|wrong:    {fgw_str} ({len(deg_wrong)} wrong cases)")


In [None]:

# Print 5 side-by-side examples from headline slice
print("\n" + "=" * 60)
print(f"SAMPLE OUTPUTS (5 examples) [{headline_slice}]")
print("=" * 60)

rows = rows_by_slice[headline_slice]
base_results = base_results_by_slice[headline_slice]
ft_results = ft_results_by_slice[headline_slice]

deg_idx = [i for i, r in enumerate(rows) if r["oracle_degraded"]][:3]
str_idx = [i for i, r in enumerate(rows) if not r["oracle_degraded"]][:2]
sample_indices = deg_idx + str_idx

for idx in sample_indices:
    r = rows[idx]
    status = "DEGRADED" if r["oracle_degraded"] else "STRONG"
    print(f"\n--- Example {idx} ({status}, difficulty={r['difficulty']}, category={r['category']}) ---")
    print(f"Oracle pivot: {r['oracle_pivot']}")

    print("\nBASE response (first 400 chars):")
    print(base_results[idx]["response"][:400])

    print("\nFINE-TUNED response (first 400 chars):")
    print(ft_results[idx]["response"][:400])

    print(f"\nBase: pivot={r['base_pivot']}, strict_flagged={r['base_flagged_degraded']}, loose_flagged={r['base_flagged_degraded_loose']}")
    print(f"FT:   pivot={r['ft_pivot']}, strict_flagged={r['ft_flagged_degraded']}, loose_flagged={r['ft_flagged_degraded_loose']}")



## 9. Download Results


In [None]:

# Save summary
summary_path = "/content/mirage_aware_gemma2b_eval_summary.json"
summary_payload = {
    "headline_slice": headline_slice,
    "metrics_by_slice": metrics_by_slice,
}
with open(summary_path, "w") as f:
    json.dump(summary_payload, f, indent=2)

# Save per-example rows as JSON for easy reuse.
rows_path = "/content/mirage_aware_gemma2b_eval_rows.json"
with open(rows_path, "w") as f:
    json.dump(all_rows, f, indent=2)

# Package adapter + results for download
!cd /content && tar czf mirage_aware_gemma2b_package.tar.gz \
    mirage_aware_gemma2b_adapter/ \
    mirage_aware_gemma2b_eval_results.csv \
    mirage_aware_gemma2b_eval_summary.json \
    mirage_aware_gemma2b_eval_rows.json \
    base_raw_responses.json \
    ft_raw_responses.json \
    base_raw_responses_natural_400.json \
    base_raw_responses_balanced_400.json \
    ft_raw_responses_natural_400.json \
    ft_raw_responses_balanced_400.json \
    base_results_backup.json \
    ft_results_backup.json \
    training_loss.png

print("\nPackage created. Downloading...")
print("Contents:")
print("  - mirage_aware_gemma2b_adapter/ (LoRA weights)")
print("  - mirage_aware_gemma2b_eval_results.csv (per-example results)")
print("  - mirage_aware_gemma2b_eval_summary.json (summary metrics by slice)")
print("  - mirage_aware_gemma2b_eval_rows.json (full rows)")
print("  - base_raw_responses*.json / ft_raw_responses*.json (raw outputs)")
print("  - base_results_backup.json / ft_results_backup.json (parsed outputs)")
print("  - training_loss.png (loss curve)")

files.download("/content/mirage_aware_gemma2b_package.tar.gz")


In [None]:

print("Done. Check the summary tables above for results.")
print(f"\nHeadline slice: {headline_slice}")
print("Key question: Did the fine-tuned model learn to flag DEGRADED when it is wrong?")
print(f"  Base silent mirage rate: {base_sr:.1%}")
print(f"  FT silent mirage rate: {ft_sr:.1%}")
print(f"  FT degradation detection (strict): {ft_flag:.1%}")
print(f"  FT false alarm rate (strict): {ft_fa:.1%}")
if degraded_and_wrong_ft:
    print(f"  FT flag_given_wrong (strict): {ft_flag_given_wrong:.1%}")
