In [None]:
#!/usr/bin/env python3
"""
stress_test_audit.py

Purpose:
  Perform prompt / context stress test on canaries,
  compare privacy signal stability between Base vs SFT models under different perturbations.

Input:
  - canary_output.txt
  - base model (Qwen2.5-0.5B-Instruct)
  - sft model (your training output directory)
  - (optional) wiki_trimmed_with_canary.jsonl

Output:
  - stress_test_results.csv
"""

import json
import random
import torch
import pandas as pd
from transformers import AutoTokenizer, AutoModelForCausalLM
import os
from google.colab import drive

# Mount Google Drive so the /content/drive/... paths used by this script
# (e.g. OUTPUT_CSV) are reachable. Only works inside a Colab runtime, since
# `drive` comes from google.colab.
drive.mount('/content/drive')

# =========================
# Configuration
# =========================
BASE_MODEL_NAME = "Qwen/Qwen2.5-0.5B-Instruct"
# NOTE(review): the "/models" prefix in front of a Google Drive path looks like
# a leftover from a different directory layout; confirm this path exists.
SFT_MODEL_DIR = "/models/content/drive/MyDrive/PrivacyAudit/models/stage1_sft"

CANARY_FILE = "/data/canary_output.txt"
# Relative path: resolved against the current working directory, unlike the
# absolute paths around it. If it cannot be found, no context passages are
# loaded and the "with_context" prompt variant is not generated.
WIKI_FILE = "wiki_trimmed_with_canary.jsonl"  # optional
OUTPUT_CSV = "/content/drive/MyDrive/PrivacyAudit/stress_test_results.csv"

# Upper bound on how many background passages are kept for "with_context" prompts.
MAX_CONTEXT_SAMPLES = 50

# Pick the fastest available backend: a CUDA GPU (what Colab GPU runtimes
# provide), then Apple-silicon MPS, then CPU. Probing only MPS would leave a
# Colab GPU unused and run every forward pass on the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
elif torch.backends.mps.is_available():
    DEVICE = "mps"
else:
    DEVICE = "cpu"

# =========================
# Utility Functions
# =========================
def load_model(model_path):
    """Load a tokenizer and causal LM from *model_path*, ready for scoring.

    The model is moved to the module-level ``DEVICE`` and put in eval mode.

    Args:
        model_path: Hugging Face hub id or local directory understood by
            ``from_pretrained``.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    tok = AutoTokenizer.from_pretrained(model_path)
    lm = AutoModelForCausalLM.from_pretrained(model_path, device_map=None)
    lm = lm.to(DEVICE)
    lm.eval()
    return tok, lm


@torch.no_grad()
def last_token_logprob(model, tokenizer, text):
    """Return the model's log-probability of the final token of *text*.

    *text* is tokenized as one sequence. The logits at the second-to-last
    position are the model's prediction for the final token; they are turned
    into log-probabilities and the entry for the actual final token id is
    returned.

    Args:
        model: causal LM whose output exposes ``.logits``.
        tokenizer: tokenizer paired with *model*.
        text: prompt whose last token is scored.

    Returns:
        The log-probability as a Python float (always <= 0).

    Raises:
        ValueError: if *text* tokenizes to fewer than two tokens, so no
            earlier position exists to predict the last one.
    """
    inputs = tokenizer(text, return_tensors="pt").to(DEVICE)
    input_ids = inputs["input_ids"][0]
    if input_ids.shape[0] < 2:
        # Fail fast with a clear message instead of an opaque IndexError on
        # logits[0, -2] after a wasted forward pass.
        raise ValueError(
            f"text must tokenize to at least 2 tokens, got {input_ids.shape[0]}"
        )

    # Position -2 predicts the token at position -1.
    logits = model(**inputs).logits[0, -2]
    log_probs = torch.log_softmax(logits, dim=-1)
    return log_probs[input_ids[-1]].item()


@torch.no_grad()
def last_token_rank(model, tokenizer, text):
    """Return the 0-based rank of the final token of *text* under the model.

    The rank is the number of vocabulary entries whose logit (at the position
    predicting the final token) is strictly greater than the actual final
    token's logit. Rank 0 means the actual token was the model's top
    prediction. Counting strictly-greater entries makes ties deterministic,
    unlike an unstable ``argsort``, and avoids sorting the whole vocabulary.

    Args:
        model: causal LM whose output exposes ``.logits``.
        tokenizer: tokenizer paired with *model*.
        text: prompt whose last token is ranked.

    Returns:
        The rank as a Python int (0 = top-1).

    Raises:
        ValueError: if *text* tokenizes to fewer than two tokens, so no
            earlier position exists to predict the last one.
    """
    inputs = tokenizer(text, return_tensors="pt").to(DEVICE)
    input_ids = inputs["input_ids"][0]
    if input_ids.shape[0] < 2:
        raise ValueError(
            f"text must tokenize to at least 2 tokens, got {input_ids.shape[0]}"
        )

    # Position -2 predicts the token at position -1.
    logits = model(**inputs).logits[0, -2]
    target_logit = logits[input_ids[-1]]
    return int((logits > target_logit).sum().item())


def load_canaries(path=None):
    """Read canary strings, one per line.

    Each line is stripped of surrounding whitespace and blank lines are
    skipped.

    Args:
        path: file to read; ``None`` (the default) uses the module-level
            ``CANARY_FILE``.

    Returns:
        A list of non-empty canary strings in file order.
    """
    if path is None:
        path = CANARY_FILE
    # Explicit UTF-8 so canaries with non-ASCII characters load the same way
    # regardless of the platform's locale default encoding.
    with open(path, encoding="utf-8") as f:
        stripped = (line.strip() for line in f)
        return [canary for canary in stripped if canary]


def load_contexts(path=None, max_samples=None):
    """Load up to *max_samples* shuffled background passages from a JSONL file.

    Each non-blank line must be a JSON object with a ``"text"`` field.
    Passages containing the substring ``"CANARY"``, or having 30 or fewer
    whitespace-separated words, are dropped. The remaining passages are
    shuffled with the global ``random`` generator and truncated.

    Args:
        path: JSONL file; ``None`` uses the module-level ``WIKI_FILE``.
        max_samples: cap on the number of passages returned; ``None`` uses
            ``MAX_CONTEXT_SAMPLES``.

    Returns:
        A list of passage strings, or ``[]`` if the file does not exist
        (the context file is optional).
    """
    if path is None:
        path = WIKI_FILE
    if max_samples is None:
        max_samples = MAX_CONTEXT_SAMPLES

    # Only open() can raise FileNotFoundError, so keep the try that narrow.
    try:
        f = open(path, encoding="utf-8")
    except FileNotFoundError:
        # Optional input, but say so: otherwise every "with_context" prompt
        # silently disappears from the results.
        print(f"[WARN] Context file not found: {path}; 'with_context' prompts will be skipped.")
        return []

    texts = []
    with f:
        for line in f:
            if not line.strip():
                continue  # tolerate blank / trailing lines in the JSONL
            text = json.loads(line)["text"]
            if "CANARY" not in text and len(text.split()) > 30:
                texts.append(text)

    random.shuffle(texts)
    return texts[:max_samples]


# =========================
# Prompt Templates (Stress Test Core)
# =========================
def build_prompts(canary, contexts):
    """Build the ``(variant_name, prompt)`` pairs used to stress-test one canary.

    Variants, in this order:
      * "plain"            - the canary on its own.
      * "with_context"     - one randomly chosen passage from *contexts*, a
                             blank line, then the canary. Only present when
                             *contexts* is non-empty.
      * "instruction_wrap" - an instruction line followed by the canary.
      * "suffix_noise"     - the canary followed by a newline and "Answer:".

    NOTE(review): the scorers used with these prompts look only at the LAST
    token of the prompt. For "suffix_noise" that token belongs to "Answer:",
    not to the canary, so this variant does not measure the canary. And since
    a causal LM cannot attend to later tokens, text appended after the canary
    cannot change the canary's own token scores either. Consider scoring the
    canary's tokens explicitly, or replacing this variant.
    """
    variants = [("plain", canary)]

    if contexts:
        background = random.choice(contexts)
        variants.append(("with_context", background + "\n\n" + canary))

    variants.extend([
        ("instruction_wrap", "Please answer the following:\n" + canary),
        ("suffix_noise", canary + "\nAnswer:"),
    ])
    return variants


# =========================
# Main Process
# =========================
def main(seed=0):
    """Score every canary under every prompt variant with the base and SFT models.

    For each ``(canary, variant)`` pair, the last-token log-probability and
    0-based rank are computed under both models. One row per pair is written
    to ``OUTPUT_CSV``, with these derived columns:

      * ``delta_logprob = sft_logprob - base_logprob``  (> 0: SFT is more confident)
      * ``delta_rank    = base_rank - sft_rank``        (> 0: SFT ranks the token higher)

    Args:
        seed: seed for Python's ``random`` module, which drives context
            shuffling and per-canary context selection. Pass ``None`` to leave
            the generator unseeded (non-reproducible runs).
    """
    if seed is not None:
        # Context sampling uses the global RNG; seeding makes reruns produce
        # identical prompts and therefore comparable CSVs.
        random.seed(seed)

    print("[INFO] Loading models...")
    base_tok, base_model = load_model(BASE_MODEL_NAME)
    sft_tok, sft_model = load_model(SFT_MODEL_DIR)

    print("[INFO] Loading canaries...")
    canaries = load_canaries()
    print(f"[INFO] {len(canaries)} canaries loaded")

    print("[INFO] Loading contexts...")
    contexts = load_contexts()
    print(f"[INFO] {len(contexts)} context passages loaded")

    results = []

    print("[INFO] Running stress tests...")
    for cid, canary in enumerate(canaries):
        for variant_name, prompt in build_prompts(canary, contexts):
            base_lp = last_token_logprob(base_model, base_tok, prompt)
            sft_lp = last_token_logprob(sft_model, sft_tok, prompt)

            base_rank = last_token_rank(base_model, base_tok, prompt)
            sft_rank = last_token_rank(sft_model, sft_tok, prompt)

            results.append({
                "canary_id": cid,
                "variant": variant_name,
                "base_logprob": base_lp,
                "sft_logprob": sft_lp,
                "delta_logprob": sft_lp - base_lp,
                "base_rank": base_rank,
                "sft_rank": sft_rank,
                "delta_rank": base_rank - sft_rank,
            })

    # Explicit columns keep a proper header even when there are no results.
    columns = [
        "canary_id", "variant",
        "base_logprob", "sft_logprob", "delta_logprob",
        "base_rank", "sft_rank", "delta_rank",
    ]
    df = pd.DataFrame(results, columns=columns)

    out_dir = os.path.dirname(OUTPUT_CSV)
    if out_dir:
        # to_csv does not create missing parent directories.
        os.makedirs(out_dir, exist_ok=True)
    df.to_csv(OUTPUT_CSV, index=False)
    print(f"[DONE] Stress test results saved to {OUTPUT_CSV}")


if __name__ == "__main__":
    main()

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
[INFO] Loading models...
[INFO] Loading canaries...
[INFO] Loading contexts...
[INFO] Running stress tests...
[DONE] Stress test results saved to /content/drive/MyDrive/PrivacyAudit/stress_test_results.csv


In [None]:
# Standalone cell that (re)mounts Google Drive. The script cell above already
# mounts it, so after that cell has run this only reports "Drive already
# mounted"; pass force_remount=True to actually remount.
from google.colab import drive
drive.mount('/content/drive')