In [5]:
#!/usr/bin/env python3
"""
stress_test_audit.py

Purpose:
  Perform a prompt/context stress test on 50 canaries and compare the
  last-token log-probability and rank of each prompt across consecutive
  training-stage pairs:
  - Base vs SFT
  - SFT vs DPO_NoCanary
  - DPO_NoCanary vs DPO_WithCanary

Prerequisites:
  - data/canary_output.txt (50 canaries, one per non-blank line; required)
  - data/wiki_trimmed_with_canary.jsonl (~10,050 samples; optional -- when
    it is missing the "with_context" prompt variant is skipped)
  - PEFT adapter directories, each searched under models/, ./ and / (see
    PATH_CANDIDATES):
      qwen2_0p5b_sft_50/
      stage2_dpo_no_canary_50/
      stage2_dpo_with_canary_50/
  - src/run_metadata.py is mentioned for metadata recording, but it is not
    imported by this script.

Output:
  - reports/stress_test_results.csv
"""

import json
import os
import random
import torch
import pandas as pd
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel

# =========================
# Configuration
# =========================
BASE_MODEL_NAME = "Qwen/Qwen2.5-0.5B-Instruct"
SEED = 42
# Seeds the module-level RNG used for context shuffling (load_contexts)
# and context selection (build_prompts).
random.seed(SEED)

# Adapter directory name per stage. Each name is probed under models/, the
# current directory and the filesystem root, so local and Colab layouts work.
_STAGE_DIR_NAMES = {
    "sft": "qwen2_0p5b_sft_50",
    "dpo_no_canary": "stage2_dpo_no_canary_50",
    "dpo_with_canary": "stage2_dpo_with_canary_50",
}
PATH_CANDIDATES = {
    stage: [f"models/{dirname}", f"./{dirname}", f"/{dirname}"]
    for stage, dirname in _STAGE_DIR_NAMES.items()
}

def first_existing(paths):
    """Return the first entry of *paths* that exists on disk, or None if none does.

    Candidates are checked in order and the search stops at the first hit.
    """
    for candidate in paths:
        if os.path.exists(candidate):
            return candidate
    return None

SFT_MODEL_DIR = first_existing(PATH_CANDIDATES["sft"])
DPO_NC_MODEL_DIR = first_existing(PATH_CANDIDATES["dpo_no_canary"])
DPO_WC_MODEL_DIR = first_existing(PATH_CANDIDATES["dpo_with_canary"])

# Fail fast, listing every stage whose adapter directory could not be found.
missing = [
    f"{stage} candidates: {PATH_CANDIDATES[stage]}"
    for stage, found in (
        ("sft", SFT_MODEL_DIR),
        ("dpo_no_canary", DPO_NC_MODEL_DIR),
        ("dpo_with_canary", DPO_WC_MODEL_DIR),
    )
    if not found
]
if missing:
    raise FileNotFoundError("Missing model dirs:\n" + "\n".join(missing))

# Auto-detect data files
canary_candidates = ["data/canary_output.txt", "./data/canary_output.txt", "/data/canary_output.txt"]
wiki_candidates = ["data/wiki_trimmed_with_canary.jsonl", "./data/wiki_trimmed_with_canary.jsonl", "/data/wiki_trimmed_with_canary.jsonl"]
CANARY_FILE = first_existing(canary_candidates)
WIKI_FILE = first_existing(wiki_candidates)

if not CANARY_FILE:
    raise FileNotFoundError(f"Canary file not found: {canary_candidates}")
if not WIKI_FILE:
    # Not fatal, but without contexts the "with_context" variant is dropped
    # for every canary, which is easy to miss in the summary table.
    print(
        f"[WARN] Wiki context file not found ({wiki_candidates}); "
        "the 'with_context' prompt variant will be skipped."
    )

os.makedirs("reports", exist_ok=True)
OUTPUT_CSV = "reports/stress_test_results.csv"
MAX_CONTEXT_SAMPLES = 50

# Prefer CUDA, then Apple MPS, then CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
elif torch.backends.mps.is_available():
    DEVICE = "mps"
else:
    DEVICE = "cpu"
print(f"Using device: {DEVICE}")

# =========================
# Utility Functions
# =========================
def load_base_model(model_name):
    """Load *model_name* and its tokenizer, place the model on DEVICE, set eval mode.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    tok = AutoTokenizer.from_pretrained(model_name)
    net = AutoModelForCausalLM.from_pretrained(model_name, device_map=None)
    net = net.to(DEVICE)
    net.eval()  # inference only: disable dropout
    return tok, net


def load_peft_model(base_name, adapter_dir):
    """Load base model *base_name* on DEVICE and attach the PEFT adapter in *adapter_dir*.

    The tokenizer is read from *adapter_dir* (assumes the tokenizer files were
    saved alongside the adapter -- confirm for new checkpoints).

    Returns:
        A ``(tokenizer, model)`` tuple with the adapter-wrapped model in eval mode.
    """
    tok = AutoTokenizer.from_pretrained(adapter_dir)
    backbone = AutoModelForCausalLM.from_pretrained(base_name, device_map=None).to(DEVICE)
    wrapped = PeftModel.from_pretrained(backbone, adapter_dir)
    wrapped.eval()
    return tok, wrapped


@torch.no_grad()
def last_token_logprob(model, tokenizer, text):
    """Return log P(final token of *text* | all preceding tokens) under *model*.

    Args:
        model: causal LM whose output ``.logits`` is indexed as
            (batch, position, vocab); batch size 1 is assumed.
        tokenizer: tokenizer matching *model*.
        text: full prompt; its final token is the one scored.

    Returns:
        The log-probability as a Python float.

    Raises:
        ValueError: if *text* tokenizes to fewer than two tokens, so there is
            no preceding position whose prediction could be scored.
    """
    inputs = tokenizer(text, return_tensors="pt").to(DEVICE)
    input_ids = inputs["input_ids"]
    if input_ids.shape[1] < 2:
        raise ValueError(
            f"need at least 2 tokens to score the last one, got {input_ids.shape[1]}: {text!r}"
        )
    outputs = model(**inputs)
    # Logits at position -2 are the model's prediction for the token at -1.
    # Upcast first: if the model runs in half precision, a bf16/fp16
    # log_softmax over a large vocabulary loses most of its precision.
    logits = outputs.logits[0, -2].float()
    log_probs = torch.log_softmax(logits, dim=-1)
    return log_probs[input_ids[0, -1]].item()


@torch.no_grad()
def last_token_rank(model, tokenizer, text):
    """Return the 0-based rank of *text*'s final token in the model's prediction.

    The rank is the number of vocabulary entries whose logit at the preceding
    position is strictly greater than the final token's logit (0 = top-1).
    Tied logits, which can be frequent with half-precision outputs, therefore
    share the best rank deterministically. They no longer depend on how a sort
    happens to order equal values.

    Args:
        model: causal LM whose output ``.logits`` is indexed as
            (batch, position, vocab); batch size 1 is assumed.
        tokenizer: tokenizer matching *model*.
        text: full prompt; its final token is the one ranked.

    Returns:
        The rank as a Python int.

    Raises:
        ValueError: if *text* tokenizes to fewer than two tokens.
    """
    inputs = tokenizer(text, return_tensors="pt").to(DEVICE)
    input_ids = inputs["input_ids"]
    if input_ids.shape[1] < 2:
        raise ValueError(
            f"need at least 2 tokens to rank the last one, got {input_ids.shape[1]}: {text!r}"
        )
    outputs = model(**inputs)
    # Prediction for the final token lives at position -2.
    logits = outputs.logits[0, -2]
    target_logit = logits[input_ids[0, -1]]
    # O(V) count instead of an O(V log V) full argsort.
    return int((logits > target_logit).sum().item())


def load_canaries(path=None):
    """Return the non-blank lines of the canary file, stripped of surrounding whitespace.

    Args:
        path: file to read; defaults to the auto-detected ``CANARY_FILE``.

    Returns:
        List of canary strings in file order.
    """
    if path is None:
        path = CANARY_FILE
    # Explicit UTF-8 so decoding does not depend on the platform locale.
    with open(path, encoding="utf-8") as f:
        return [line.strip() for line in f if line.strip()]


def load_contexts(path=None, max_samples=None):
    """Return up to *max_samples* shuffled wiki passages for use as distractor context.

    Passages are excluded when they contain "CANARY" or have 30 or fewer
    whitespace-separated words. Shuffling uses the module-level ``random``
    state, so the selection is reproducible for a fixed seed.

    Args:
        path: JSONL file whose records carry a "text" field; defaults to
            ``WIKI_FILE``.
        max_samples: maximum number of passages returned; defaults to
            ``MAX_CONTEXT_SAMPLES``.

    Returns:
        List of passage strings, or ``[]`` when no file is available.
    """
    if path is None:
        path = WIKI_FILE
    if not path:
        return []
    if max_samples is None:
        max_samples = MAX_CONTEXT_SAMPLES
    try:
        with open(path, encoding="utf-8") as f:
            texts = []
            for line in f:
                if not line.strip():
                    # Blank/trailing lines would otherwise crash json.loads.
                    continue
                t = json.loads(line)["text"]
                if "CANARY" not in t and len(t.split()) > 30:
                    texts.append(t)
    except FileNotFoundError:
        return []
    random.shuffle(texts)
    return texts[:max_samples]


# =========================
# Prompt Templates
# =========================
def build_prompts(canary, contexts):
    """Return the ``(variant_name, prompt_text)`` pairs used to stress-test one canary.

    Variants, in order:
      - "plain": the canary alone.
      - "with_context": a passage drawn with the module-level ``random``,
        a blank line, then the canary. Emitted only when *contexts* is
        non-empty.
      - "instruction_wrap": the canary after a short instruction line.
      - "suffix_noise": the canary followed by an "Answer:" line.

    NOTE(review): the last-token scoring helpers score the final token of the
    whole prompt. For "suffix_noise" that token comes from the appended
    "Answer:" suffix, not from the canary, so its numbers are not directly
    comparable with the other variants. For a causal LM, text placed after the
    canary cannot change the canary's own token probabilities anyway.
    """
    variants = [("plain", canary)]
    if contexts:
        passage = random.choice(contexts)
        variants.append(("with_context", f"{passage}\n\n{canary}"))
    variants += [
        ("instruction_wrap", f"Please answer the following:\n{canary}"),
        ("suffix_noise", f"{canary}\nAnswer:"),
    ]
    return variants


# =========================
# Main Process
# =========================
# Column order of reports/stress_test_results.csv.
RESULT_COLUMNS = [
    "canary_id",
    "variant",
    "compare_tag",
    "stage_a",
    "stage_b",
    "stage_a_logprob",
    "stage_b_logprob",
    "delta_logprob",
    "stage_a_rank",
    "stage_b_rank",
    "delta_rank",
]


def _score_all_stages(stage_models, prompt):
    """Score *prompt*'s final token (logprob and rank) under every stage model."""
    return {
        stage_name: {
            "logprob": last_token_logprob(model, tok, prompt),
            "rank": last_token_rank(model, tok, prompt),
        }
        for stage_name, (tok, model) in stage_models.items()
    }


def _compare_row(cid, variant_name, tag, stage_a, stage_b, per_stage):
    """Build one result row comparing *stage_b* against *stage_a*.

    Signs are chosen so that a positive delta means stage_b shows the stronger
    signal: a higher logprob, or a lower (better) rank.
    """
    a = per_stage[stage_a]
    b = per_stage[stage_b]
    return {
        "canary_id": cid,
        "variant": variant_name,
        "compare_tag": tag,
        "stage_a": stage_a,
        "stage_b": stage_b,
        "stage_a_logprob": a["logprob"],
        "stage_b_logprob": b["logprob"],
        "delta_logprob": b["logprob"] - a["logprob"],
        "stage_a_rank": a["rank"],
        "stage_b_rank": b["rank"],
        "delta_rank": a["rank"] - b["rank"],
    }


def main():
    """Run the stress test over all canaries and prompt variants.

    The steps are:
      1. Load the four stage models.
      2. Score every prompt variant of every canary under each model.
      3. Write one row per (canary, variant, stage pair) to OUTPUT_CSV.
      4. Print the mean deltas per pair and variant.
    """
    # Print resolved paths first, so they are visible even if loading fails.
    print("[INFO] Stage paths:")
    print(f"  Stage1_SFT: {SFT_MODEL_DIR}")
    print(f"  Stage2a_DPO_NoCanary: {DPO_NC_MODEL_DIR}")
    print(f"  Stage2b_DPO_WithCanary: {DPO_WC_MODEL_DIR}")

    print("[INFO] Loading models...")
    stage_models = {
        "Stage0_Base": load_base_model(BASE_MODEL_NAME),
        "Stage1_SFT": load_peft_model(BASE_MODEL_NAME, SFT_MODEL_DIR),
        "Stage2a_DPO_NoCanary": load_peft_model(BASE_MODEL_NAME, DPO_NC_MODEL_DIR),
        "Stage2b_DPO_WithCanary": load_peft_model(BASE_MODEL_NAME, DPO_WC_MODEL_DIR),
    }

    compare_pairs = [
        ("Base_vs_SFT", "Stage0_Base", "Stage1_SFT"),
        ("SFT_vs_DPO_NoCanary", "Stage1_SFT", "Stage2a_DPO_NoCanary"),
        ("DPO_NoCanary_vs_WithCanary", "Stage2a_DPO_NoCanary", "Stage2b_DPO_WithCanary"),
    ]

    print("[INFO] Loading canaries...")
    canaries = load_canaries()

    print("[INFO] Loading contexts...")
    contexts = load_contexts()

    results = []

    print("[INFO] Running stress tests...")
    for cid, canary in enumerate(canaries):
        for variant_name, prompt in build_prompts(canary, contexts):
            per_stage = _score_all_stages(stage_models, prompt)
            results.extend(
                _compare_row(cid, variant_name, tag, stage_a, stage_b, per_stage)
                for tag, stage_a, stage_b in compare_pairs
            )

    # Explicit columns keep the CSV schema (header) stable even with no rows.
    df = pd.DataFrame(results, columns=RESULT_COLUMNS)
    df.to_csv(OUTPUT_CSV, index=False)
    print(f"[DONE] Stress test results saved to {OUTPUT_CSV}")

    if df.empty:
        # groupby/agg on an empty frame is meaningless; bail out cleanly.
        print("[WARN] No results were produced (no canaries loaded?); skipping summary.")
        return

    summary = (
        df.groupby(["compare_tag", "variant"]).agg(
            mean_delta_logprob=("delta_logprob", "mean"),
            mean_delta_rank=("delta_rank", "mean"),
            n=("delta_rank", "size"),
        ).reset_index()
    )
    print("\n[INFO] Quick summary:")
    print(summary.to_string(index=False))


# Script entry point. When this code is run as a script -- or executed in a
# notebook cell, where __name__ is also "__main__" -- the full stress test runs.
if __name__ == "__main__":
    main()



Using device: cuda
[INFO] Loading models...


Error while fetching `HF_TOKEN` secret value from your vault: 'Requesting secret HF_TOKEN timed out. Secrets can only be fetched when running from the Colab UI.'.
You are not authenticated with the Hugging Face Hub in this notebook.
If the error persists, please let us know by opening an issue on GitHub (https://github.com/huggingface/huggingface_hub/issues/new).


Loading weights:   0%|          | 0/290 [00:00<?, ?it/s]

Loading weights:   0%|          | 0/290 [00:00<?, ?it/s]

Loading weights:   0%|          | 0/290 [00:00<?, ?it/s]

Loading weights:   0%|          | 0/290 [00:00<?, ?it/s]

[INFO] Stage paths:
  Stage1_SFT: ./qwen2_0p5b_sft_50
  Stage2a_DPO_NoCanary: ./stage2_dpo_no_canary_50
  Stage2b_DPO_WithCanary: ./stage2_dpo_with_canary_50
[INFO] Loading canaries...
[INFO] Loading contexts...
[INFO] Running stress tests...
[DONE] Stress test results saved to reports/stress_test_results.csv

[INFO] Quick summary:
               compare_tag          variant  mean_delta_logprob  mean_delta_rank  n
               Base_vs_SFT instruction_wrap           -0.064219             6.90 50
               Base_vs_SFT            plain            0.198125             7.38 50
               Base_vs_SFT     suffix_noise            1.607031             3.72 50
               Base_vs_SFT     with_context           -0.166406           -27.92 50
DPO_NoCanary_vs_WithCanary instruction_wrap            0.041875             5.98 50
DPO_NoCanary_vs_WithCanary            plain            0.025000             3.46 50
DPO_NoCanary_vs_WithCanary     suffix_noise            0.222031             0.

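## Direction Consistency Analysis

Report how consistently each stage comparison moves in one direction, broken down by `compare_tag`, by `variant`, and by their combination.
- `delta_logprob > 0` means stage_b assigns a higher log-probability to the prompt's final token (stronger memorization signal).
- `delta_rank > 0` means stage_b ranks that token better, i.e. at a lower rank (stronger memorization signal).
- Caveat: for the `suffix_noise` variant the scored final token belongs to the appended `Answer:` suffix rather than to the canary, so interpret that variant with care.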

In [6]:
import pandas as pd

RESULTS_CSV = "reports/stress_test_results.csv"


def _sign_counts(values):
    """Return (#positive, #negative, #zero) for a numeric pandas Series."""
    return int((values > 0).sum()), int((values < 0).sum()), int((values == 0).sum())


def _consistency_pct(pos, neg, n):
    """Percentage of the *n* rows that agree with the majority sign.

    Ties (delta == 0) count toward *n* but toward neither direction; every
    section below uses this same definition. Returns NaN when n == 0.
    """
    return max(pos, neg) / n * 100 if n else float("nan")


def _print_sign_breakdown(sub, label):
    """Print +/-/= counts and consistency for both delta columns of *sub*."""
    n = len(sub)
    print(f"\n  {label} (n={n}):")
    for col, prefix in (("delta_logprob", "logprob:"), ("delta_rank", "rank:   ")):
        pos, neg, zero = _sign_counts(sub[col])
        print(f"    {prefix} +{pos} / -{neg} / ={zero}  (consistency: {_consistency_pct(pos, neg, n):.1f}%)")


df = pd.read_csv(RESULTS_CSV)

print(f"Total samples: {len(df)}")
print(f"Unique canaries: {df['canary_id'].nunique()}")
print()

# --- Per compare_tag direction consistency ---
print("=" * 70)
print("Direction Consistency by compare_tag")
print("=" * 70)
for tag in df["compare_tag"].unique():
    _print_sign_breakdown(df[df["compare_tag"] == tag], tag)

# --- Per variant direction consistency ---
print()
print("=" * 70)
print("Direction Consistency by variant")
print("=" * 70)
for variant in df["variant"].unique():
    _print_sign_breakdown(df[df["variant"] == variant], variant)

# --- Cross-tabulation: compare_tag x variant ---
# Ties are counted toward neither direction, matching the sections above.
# Treating them as "negative" would inflate consistency wherever many deltas
# are exactly 0, which is common for rank.
print()
print("=" * 70)
print("Direction Consistency by compare_tag x variant")
print("=" * 70)
for tag in df["compare_tag"].unique():
    print(f"\n  [{tag}]")
    for variant in df["variant"].unique():
        sub = df[(df["compare_tag"] == tag) & (df["variant"] == variant)]
        n = len(sub)
        if n == 0:
            continue
        lp_pos, lp_neg, _ = _sign_counts(sub["delta_logprob"])
        rk_pos, rk_neg, _ = _sign_counts(sub["delta_rank"])
        print(
            f"    {variant:20s}  n={n:3d}  "
            f"logprob_consistency={_consistency_pct(lp_pos, lp_neg, n):5.1f}%  "
            f"rank_consistency={_consistency_pct(rk_pos, rk_neg, n):5.1f}%"
        )


Total samples: 600
Unique canaries: 50

Direction Consistency by compare_tag

  Base_vs_SFT (n=200):
    logprob: +99 / -94 / =7  (consistency: 49.5%)
    rank:    +100 / -43 / =57  (consistency: 50.0%)

  SFT_vs_DPO_NoCanary (n=200):
    logprob: +42 / -151 / =7  (consistency: 75.5%)
    rank:    +28 / -79 / =93  (consistency: 39.5%)

  DPO_NoCanary_vs_WithCanary (n=200):
    logprob: +127 / -52 / =21  (consistency: 63.5%)
    rank:    +55 / -29 / =116  (consistency: 27.5%)

Direction Consistency by variant

  plain (n=150):
    logprob: +72 / -68 / =10  (consistency: 48.0%)
    rank:    +50 / -26 / =74  (consistency: 33.3%)

  with_context (n=150):
    logprob: +47 / -93 / =10  (consistency: 62.0%)
    rank:    +32 / -63 / =55  (consistency: 42.0%)

  instruction_wrap (n=150):
    logprob: +42 / -93 / =15  (consistency: 62.0%)
    rank:    +30 / -36 / =84  (consistency: 24.0%)

  suffix_noise (n=150):
    logprob: +107 / -43 / =0  (consistency: 71.3%)
    rank:    +71 / -26 / =53  (c