
# BAP Experiment Notebook (Full Coverage): **All Your Models + All Your Datasets**
This version updates the merged notebook so that it **actually plugs in everything you used**:

## Models (from your v10/v12 notebooks)
- `Qwen/Qwen3-4B`
- `Qwen/Qwen2.5-Math-7B-Instruct`
- `deepseek-ai/deepseek-math-7b-instruct`
- `mistralai/Mathstral-7B-v0.1`
- `nvidia/AceMath-7B-Instruct`

## Datasets / task sources you used
### A) Your harmonized JSONL (binary tasks)
- `all_tasks_harmonized.jsonl` (task-level split, including `gsm8k_binary`, `aqua_binary`, `arc_challenge_binary`, `group_word_generic_easy`, …)
- Optional: **SGU slice** (`__sgu`) driven by `complexity_family`

### B) Real public benchmarks (direct HF loading)
- `gsm8k` (test split, **answer-mode**)
- `aqua_rat` (test split, **multi-choice-mode**)
- `ai2_arc` (ARC-Challenge test split, **multi-choice-mode**)

### C) Baseline corpora for contamination baselines (train splits)
- `gsm8k` train
- `aqua_rat` train
- `ai2_arc` ARC-Challenge train
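
A minimal sketch of how the public benchmarks in B) and the baseline corpora in C) might be pulled with `datasets.load_dataset`. The config names (`main`, `raw`, `ARC-Challenge`) and the `format_mc_prompt` helper are assumptions for illustration, not taken from the v10/v12 notebooks; check the dataset cards on the Hub before relying on them.

```python
from typing import Dict, List, Tuple

# (hub id, config) per benchmark; the config names are assumptions.
BENCHMARK_SPECS: Dict[str, Tuple[str, str]] = {
    "gsm8k": ("gsm8k", "main"),
    "aqua_rat": ("aqua_rat", "raw"),
    "ai2_arc": ("ai2_arc", "ARC-Challenge"),
}

def load_benchmark(name: str, split: str = "test"):
    """Load one split ("test" for evaluation, "train" for contamination baselines)."""
    from datasets import load_dataset  # imported lazily; needs network access
    hub_id, config = BENCHMARK_SPECS[name]
    return load_dataset(hub_id, config, split=split)

def format_mc_prompt(question: str, options: List[str]) -> str:
    """Hypothetical multi-choice layout: question, then one option per line."""
    lines = [question.strip(), "Options:"] + [o.strip() for o in options]
    return "\n".join(lines)

print(format_mc_prompt("2 + 2 = ?", ["A)3", "B)4"]))
```

`load_benchmark("gsm8k", split="train")` would then cover the baseline corpus for the same benchmark.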

---

## What stays the same (BAP measurement items)
We keep the same **evaluation modules** from the BAP design:
- **Exp-1 Auditability stress test**: model swap, resampling/cherry-picking, code substitution, post-hoc bit editing
- **Exp-2 Contamination evidence**: fast proxy + optional LoRA fine-tuning contamination (realism)
- **Exp-3 Discrimination**: complexity-tiered curves using **verifier compute cost**
- **Exp-4 Overhead**: wall time + per-item time + verifier-cost distributions
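
For Exp-4, per-item wall time can be collected with `time.perf_counter` around each call. A minimal sketch, assuming a hypothetical `run_item` callable stands in for the model call; the helper name `time_items` and the summary fields are illustrative, not part of the BAP design.

```python
import statistics
import time
from typing import Callable, Dict, Iterable, List

def time_items(items: Iterable, run_item: Callable) -> Dict[str, float]:
    """Time each call to run_item and summarise total and per-item wall time."""
    per_item: List[float] = []
    t_start = time.perf_counter()
    for item in items:
        t0 = time.perf_counter()
        run_item(item)
        per_item.append(time.perf_counter() - t0)
    total = time.perf_counter() - t_start
    return {
        "n": len(per_item),
        "wall_time_s": total,
        "mean_item_s": statistics.fmean(per_item) if per_item else 0.0,
        "max_item_s": max(per_item, default=0.0),
    }

# Stand-in workload; replace the lambda with the real per-instance call.
summary = time_items(range(5), lambda x: sum(range(1000)))
print(summary)
```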

> This notebook still runs in **simulation-mode attestation** by default (HMAC / optional Ed25519).
Replace that block with real TEE remote attestation when you deploy.
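
A minimal sketch of what simulation-mode attestation with HMAC can look like, and why it catches post-hoc bit editing (Exp-1). The record fields, the key, and the helper names `sign_record` / `verify_record` are assumptions for illustration; the notebook's own attestation block may differ.

```python
import hashlib
import hmac
import json

def sign_record(record: dict, key: bytes) -> str:
    """HMAC-SHA256 over a canonical (sorted-key) JSON encoding of the record."""
    msg = json.dumps(record, sort_keys=True, separators=(",", ":")).encode("utf-8")
    return hmac.new(key, msg, hashlib.sha256).hexdigest()

def verify_record(record: dict, tag: str, key: bytes) -> bool:
    """Constant-time comparison of the recomputed tag against the stored one."""
    return hmac.compare_digest(sign_record(record, key), tag)

key = b"demo-key-not-for-production"  # hypothetical key; a TEE would hold the real one
record = {"model_id": "Qwen/Qwen3-4B", "instance_id": "gsm8k_binary-0", "output": "1"}
tag = sign_record(record, key)

tampered = dict(record, output="0")  # post-hoc edit of a single answer bit
print(verify_record(record, tag, key), verify_record(tampered, tag, key))
```

Because `model_id` is part of the signed record, swapping the model after the fact invalidates the tag in the same way.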



## 0) Setup
The install lines below run as-is on Colab. If the packages are already installed in your environment, comment them out.


In [None]:
import os

IN_COLAB = "COLAB_GPU" in os.environ

if IN_COLAB:
    from google.colab import drive
    drive.mount("/content/drive")
    print("Drive mounted at /content/drive")
else:
    print("Not running in Colab; skipping Drive mount.")

In [None]:

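!pip -q install "transformers>=4.40.0" "datasets>=2.16.0" accelerate scipy pandas matplotlib
!pip -q install cryptography
# Optional: only needed for LoRA adapters (BAP_ADAPTER_PATH) or 4-bit loading (BAP_4BIT=1)
# !pip -q install peft bitsandbytes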

import os, re, json, time, math, hashlib, hmac
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from scipy import stats

np.random.seed(0)
print("Ready.")


In [None]:
# --- Progress bar utilities (tqdm) ---
import os
TQDM_ENABLED = os.environ.get('BAP_TQDM', '1') != '0'
try:
    from tqdm.auto import tqdm  # type: ignore
except Exception:
    try:
        import sys, subprocess
        subprocess.check_call([sys.executable, '-m', 'pip', '-q', 'install', 'tqdm'])
        from tqdm.auto import tqdm  # type: ignore
    except Exception:
        # Fallback: no-op tqdm
        def tqdm(it=None, *args, **kwargs):
            return it if it is not None else range(0)

print('TQDM_ENABLED:', TQDM_ENABLED)
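
`TQDM_ENABLED` is only printed in the cell above; one way to make loops honour it is a thin wrapper around `tqdm`. A sketch, assuming a hypothetical helper name `progress`:

```python
import os

TQDM_ENABLED = os.environ.get("BAP_TQDM", "1") != "0"

try:
    from tqdm.auto import tqdm
except ImportError:
    tqdm = None  # tqdm not installed; fall back to the plain iterable

def progress(iterable, desc: str = ""):
    """Wrap an iterable in a progress bar unless BAP_TQDM=0 or tqdm is missing."""
    if tqdm is None or not TQDM_ENABLED:
        return iterable
    return tqdm(iterable, desc=desc)

total = sum(x for x in progress(range(4), desc="demo"))
print(total)
```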


In [None]:
import os, json, re
import numpy as np
import pandas as pd
from dataclasses import dataclass
from typing import Any, Dict, List, Optional

# -------- settings (same style as your notebook) --------
BASE_DIR = os.environ.get("BAP_BASE_DIR", "/content/drive/MyDrive/complexity_data6")
HARMONIZED_PATH = os.environ.get("BAP_HARMONIZED_JSONL", os.path.join(BASE_DIR, "all_tasks_harmonized.jsonl"))

MODEL_ID = os.environ.get("BAP_MODEL_ID", "Qwen/Qwen3-4B")
USE_4BIT = bool(int(os.environ.get("BAP_4BIT", "0")))
ADAPTER_PATH = os.environ.get("BAP_ADAPTER_PATH", "").strip() or None  # optional LoRA adapter dir

SGU_SUFFIX = "__sgu"
SGU_COMPLEXITY_FAMILY_PATTERNS = [
    "strongly_generically_undecidable",
    "strongly generically undecidable",
    "sgu",
    "undecidable",
]

# Greedy decoding: with do_sample=False the temperature is ignored
# (recent transformers versions warn about it), so it is omitted here.
GEN_CFG_BINARY = dict(max_new_tokens=1, do_sample=False)
GEN_CFG_ANSWER = dict(max_new_tokens=256, do_sample=False)
GEN_CFG_SEEN = dict(max_new_tokens=4, do_sample=False)
GEN_CFG_LABEL = dict(max_new_tokens=16, do_sample=False)

DATASET_LABELS = [
    "gsm8k",
    "sgu",
    "np_hard",
    "easy",
    "aqua_mc_test",
    "arc_challenge_mc_test",
    "unknown",
]

# -------- data structs --------
@dataclass
class TaskInstance:
    instance_id: str
    task_type: str          # "binary" or "answer"
    prompt: str
    label01: Optional[int] = None
    ground_truth: Optional[str] = None
    meta: Optional[Dict[str, Any]] = None

@dataclass
class TaskDataset:
    name: str
    instances: List[TaskInstance]

# -------- loader (harmonized JSONL) --------
def load_harmonized_jsonl_df(path: str) -> pd.DataFrame:
    if not os.path.exists(path):
        raise FileNotFoundError(f"Missing harmonized JSONL: {path}")
    rows = []
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if line:
                rows.append(json.loads(line))
    df = pd.DataFrame(rows)

    def _to01(x):
        if isinstance(x, (bool, np.bool_)):
            return int(x)
        if isinstance(x, (int, np.integer)):
            return int(x != 0)
        s = str(x).strip().lower()
        if s in ["1", "true", "yes", "y", "t"]:
            return 1
        if s in ["0", "false", "no", "n", "f"]:
            return 0
        m = re.search(r"[01]", s)
        return int(m.group(0)) if m else 0

    # required columns: input, label, task
    for req in ["input", "label", "task"]:
        if req not in df.columns:
            raise ValueError(f"harmonized JSONL missing required column: {req}")

    df["label"] = df["label"].apply(_to01).astype(int)

    # optional cols used for SGU slice
    if "complexity_family" not in df.columns:
        df["complexity_family"] = np.nan

    return df

def build_harmonized_datasets(df: pd.DataFrame, include_sgu_slice: bool = True) -> Dict[str, TaskDataset]:
    out: Dict[str, TaskDataset] = {}

    def _make(inst_df: pd.DataFrame, name: str):
        insts = []
        for i, r in inst_df.reset_index(drop=True).iterrows():
            meta = {k: r[k] for k in inst_df.columns if k not in ["input", "label", "task"]}
            insts.append(TaskInstance(
                instance_id=f"{name}-{i}",
                task_type="binary",
                prompt=str(r["input"]),
                label01=int(r["label"]),
                meta=meta
            ))
        out[name] = TaskDataset(name=name, instances=insts)

    tasks = sorted(df["task"].dropna().astype(str).unique().tolist())
    for t in tasks:
        df_t = df[df["task"] == t].copy()
        if len(df_t) == 0:
            continue
        _make(df_t, t)

        if include_sgu_slice:
            cf = df_t["complexity_family"].fillna("").astype(str).str.lower()
            mask = np.zeros(len(df_t), dtype=bool)
            for pat in SGU_COMPLEXITY_FAMILY_PATTERNS:
                mask |= cf.str.contains(pat.lower(), na=False).to_numpy()
            if mask.any():
                _make(df_t.loc[mask].copy(), t + SGU_SUFFIX)

    return out

# -------- model wrapper (HF) --------
@dataclass
class HFModelWrapper:
    model_id: str
    tokenizer: Any
    model: Any

    @staticmethod
    def load(model_id: str, adapter_path: Optional[str] = None, use_4bit: bool = False):
        import torch
        from transformers import AutoTokenizer, AutoModelForCausalLM

        tok = AutoTokenizer.from_pretrained(model_id, use_fast=True)
        if tok.pad_token is None:
            tok.pad_token = tok.eos_token

        # Build the from_pretrained kwargs in one place so that device_map is
        # never passed twice (that would raise a TypeError).
        kwargs: Dict[str, Any] = {}
        if torch.cuda.is_available():
            kwargs["torch_dtype"] = torch.float16
            kwargs["device_map"] = "auto"
        if use_4bit:
            try:
                from transformers import BitsAndBytesConfig
                kwargs["quantization_config"] = BitsAndBytesConfig(load_in_4bit=True)
                kwargs["device_map"] = "auto"
            except Exception as e:
                print("[WARN] 4bit requested but BitsAndBytes not available:", e)

        model = AutoModelForCausalLM.from_pretrained(model_id, **kwargs)
        model.eval()

        if adapter_path is not None:
            from peft import PeftModel
            model = PeftModel.from_pretrained(model, adapter_path)
            model.eval()

        return HFModelWrapper(model_id=model_id, tokenizer=tok, model=model)

    def _generate(self, prompt: str, gen_cfg: Dict[str, Any]) -> str:
        import torch
        from transformers import GenerationConfig
        inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True)
        inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
        gcfg = GenerationConfig(**gen_cfg)
        with torch.no_grad():
            out = self.model.generate(**inputs, generation_config=gcfg)
        # Decode only the newly generated tokens: the decoded prompt may not
        # round-trip exactly, which makes a string-prefix check unreliable.
        new_tokens = out[0][inputs["input_ids"].shape[1]:]
        return self.tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

    # Q1
    def ask_seen_level(self, question: str) -> str:
        prompt = (
            "You will see a problem. Output ONLY a single digit: 0, 1, or 2.\n\n"
            "2 = I have seen this exact problem before.\n"
            "1 = I have not seen this exact problem, but I have seen very similar problems.\n"
            "0 = I have never seen this problem and have not seen similar ones.\n\n"
            "Do NOT output any other text.\n\n"
            f"Problem:\n{question.strip()}\n"
        )
        return self._generate(prompt, GEN_CFG_SEEN)

    # Q2
    def ask_dataset_label(self, question: str) -> str:
        labels = "\n".join(DATASET_LABELS)
        prompt = (
            "You will see a problem. Output ONLY one label from the list below, EXACTLY as written:\n\n"
            f"{labels}\n\n"
            "Do NOT output any other text.\n\n"
            f"Problem:\n{question.strip()}\n"
        )
        return self._generate(prompt, GEN_CFG_LABEL)

    # Q3 (the actual answer for this instance)
    def answer_instance(self, inst: TaskInstance) -> str:
        if inst.task_type == "binary":
            prompt = (
                "Answer the following decision problem.\n"
                "Return ONLY a single token:\n"
                "1 if YES, 0 if NO.\n\n"
                f"Problem:\n{inst.prompt}\n\nReturn only a single token (0 or 1):"
            )
            return self._generate(prompt, GEN_CFG_BINARY)
        else:
            prompt = (
                "Solve the following problem. Return ONLY the final answer.\n\n"
                f"Problem:\n{inst.prompt}\n\nFinal answer:"
            )
            return self._generate(prompt, GEN_CFG_ANSWER)

# -------- run: read 1 SGU question, ask the 3 questions, print --------
df = load_harmonized_jsonl_df(HARMONIZED_PATH)
datasets = build_harmonized_datasets(df, include_sgu_slice=True)

sgu_keys = [k for k in datasets.keys() if k.endswith(SGU_SUFFIX)]
if not sgu_keys:
    raise RuntimeError("No __sgu slice found. Check complexity_family in your harmonized JSONL.")

ds = datasets[sgu_keys[0]]
inst = ds.instances[0]  # take one SGU question

mw = HFModelWrapper.load(MODEL_ID, adapter_path=ADAPTER_PATH, use_4bit=USE_4BIT)

print("SGU_DATASET_KEY =", ds.name)
print("\n[SGU QUESTION]\n", inst.prompt)

print("\n[Q1 seen-level 0/1/2]\n", mw.ask_seen_level(inst.prompt))
print("\n[Q2 dataset-label]\n", mw.ask_dataset_label(inst.prompt))
print("\n[Q3 answer]\n", mw.answer_instance(inst))
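
To make the expected input concrete, here is a pure-Python sketch of harmonized JSONL records and of the substring test that routes a row into the `__sgu` slice. The two sample records are invented for illustration (only the `input`, `label`, `task`, and `complexity_family` field names come from the loader above), and `is_sgu` is a hypothetical helper that mirrors the mask logic in `build_harmonized_datasets`.

```python
import io
import json

SGU_PATTERNS = ["strongly_generically_undecidable", "strongly generically undecidable",
                "sgu", "undecidable"]

# Two made-up records in the harmonized layout.
sample_jsonl = io.StringIO(
    '{"input": "Is 7 prime?", "label": 1, "task": "toy_easy", "complexity_family": "P"}\n'
    '{"input": "Does word w equal 1 in G?", "label": 0, "task": "toy_group_word", '
    '"complexity_family": "Strongly_Generically_Undecidable"}\n'
)

def is_sgu(family) -> bool:
    """Case-insensitive substring match against the SGU patterns."""
    text = "" if family is None else str(family).lower()
    return any(pat in text for pat in SGU_PATTERNS)

rows = [json.loads(line) for line in sample_jsonl if line.strip()]
sgu_rows = [r for r in rows if is_sgu(r.get("complexity_family"))]
print(len(rows), [r["task"] for r in sgu_rows])
```

Note that the short patterns (`sgu`, `undecidable`) match any family string containing them, so family names are worth checking before trusting the slice.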


In [None]:
import re
import torch
from transformers import GenerationConfig, StoppingCriteria, StoppingCriteriaList

# ==============================================================================
# ==============================================================================

class StopOnTokens(StoppingCriteria):
    def __init__(self, tokenizer, stop_strings):
        self.tokenizer = tokenizer
        self.stop_token_ids = []
        for s in stop_strings:
            encoded = tokenizer.encode(s, add_special_tokens=False)
            if encoded:
                self.stop_token_ids.append(encoded[-1])

    def __call__(self, input_ids, scores, **kwargs):
        if input_ids.shape[1] == 0: return False
        return input_ids[0, -1].item() in self.stop_token_ids

def generate_strict(prompt: str, max_tokens: int = 10):
    tok = mw.tokenizer
    model = mw.model

    inputs = tok(prompt, return_tensors="pt", truncation=True)
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    stopper = StopOnTokens(tok, stop_strings=["\n", "\r", "Question:", "Options:"])

    # Greedy decoding; temperature is omitted because it is ignored when do_sample=False
    gen_config = GenerationConfig(
        do_sample=False,
        max_new_tokens=max_tokens,
        pad_token_id=tok.pad_token_id,
        eos_token_id=tok.eos_token_id,
    )

    with torch.no_grad():
        out = model.generate(
            **inputs,
            generation_config=gen_config,
            stopping_criteria=StoppingCriteriaList([stopper])
        )

    # Decode only the newly generated tokens instead of relying on a string-prefix match
    new_tokens = out[0][inputs["input_ids"].shape[1]:]
    response = tok.decode(new_tokens, skip_special_tokens=True)
    # Keep only the first line of the completion
    return response.split("\n")[0].strip()

# ==============================================================================
# ==============================================================================

def build_solver_prompt(question: str) -> str:
    instruction = (
        "You are a calculator. Output strictly ONLY the final numerical answer or Yes/No.\n"
        "Do NOT explain. Do NOT show work.\n"
    )
    return f"{instruction}\nQuestion: {question.strip()}\nAnswer:"

def build_contamination_prompt(question: str) -> str:
    instruction = (
        # NOTE: this scale is inverted relative to HFModelWrapper.ask_seen_level (there 2 = seen).
        "You are an AI assistant. Analyze the following math problem.\n"
        "Tell me if you have seen this problem in your training data.\n"
        "Reply with strictly ONLY a single digit:\n"
        "0 = SEEN the same question before\n"
        "1 = Seen a similar question before, but not the same question\n"
        "2 = Not seen a similar question before\n"
    )
    return f"{instruction}\nProblem Content: {question.strip()}\n\nReply (0, 1, or 2):"

# ==============================================================================
# ==============================================================================

def run_experiment(ds, n=5):
    print(f"Starting Experiment on {n} samples...\n")

    for i, inst in enumerate(ds.instances[:n]):
        original_question = inst.prompt

        solver_prompt = build_solver_prompt(original_question)
        solver_response = generate_strict(solver_prompt, max_tokens=15)

        check_prompt = build_contamination_prompt(original_question)
        check_response = generate_strict(check_prompt, max_tokens=5)

        print("=" * 80)
        print(f"Instance ID: {inst.instance_id} (Index {i})")
        print("-" * 40)

        print(">>> [TASK 1] MATH SOLVER PROMPT:")
        print(solver_prompt)
        print("-" * 20)
        print(f">>> [TASK 1] MODEL RESPONSE: {solver_response!r}")

        print("-" * 40)

        print(">>> [TASK 2] CONTAMINATION CHECK PROMPT:")
        print(check_prompt)
        print("-" * 20)
        print(f">>> [TASK 2] MODEL RESPONSE: {check_response!r}")

        print("=" * 80 + "\n\n")

# ==============================================================================
# ==============================================================================

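# Pick the first non-SGU GSM8K task from the harmonized datasets;
# without this, gsm_ds is undefined when the cell runs.
gsm_keys = [k for k in datasets.keys() if "gsm8k" in k.lower() and not k.endswith(SGU_SUFFIX)]
if not gsm_keys:
    raise RuntimeError("No gsm8k task key found in `datasets`.")
gsm_ds = datasets[gsm_keys[0]]

run_experiment(gsm_ds, n=3)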
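
`StopOnTokens` compares only the last token id of each stop string, so a multi-token stop such as `Question:` may reduce to a check on its final token (often `:`), depending on the tokenizer. A string-level post-processing step is a simple complement. The sketch below assumes a hypothetical helper `truncate_at_stop` applied to the decoded completion; it is not part of the notebook.

```python
from typing import Sequence

def truncate_at_stop(text: str, stop_strings: Sequence[str]) -> str:
    """Cut text at the earliest occurrence of any stop string."""
    cut = len(text)
    for s in stop_strings:
        idx = text.find(s)
        if idx != -1:
            cut = min(cut, idx)
    return text[:cut].strip()

STOPS = ["\n", "\r", "Question:", "Options:"]
print(repr(truncate_at_stop(" 42 Question: next one", STOPS)))
print(repr(truncate_at_stop("Yes\nBecause...", STOPS)))
```

Matching on the decoded string keeps answers such as `ratio 3:4` intact, which a token-level check on `:` would cut short.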

In [None]:
import re
import torch
from transformers import GenerationConfig, StoppingCriteria, StoppingCriteriaList

# ==============================================================================
# ==============================================================================

class StopOnTokens(StoppingCriteria):
    def __init__(self, tokenizer, stop_strings):
        self.tokenizer = tokenizer
        self.stop_token_ids = []
        for s in stop_strings:
            encoded = tokenizer.encode(s, add_special_tokens=False)
            if encoded:
                self.stop_token_ids.append(encoded[-1])

    def __call__(self, input_ids, scores, **kwargs):
        if input_ids.shape[1] == 0:
            return False
        return input_ids[0, -1].item() in self.stop_token_ids

def generate_strict(mw, prompt: str, max_tokens: int = 10):
    tok = mw.tokenizer
    model = mw.model

    inputs = tok(prompt, return_tensors="pt", truncation=True)
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    stopper = StopOnTokens(tok, stop_strings=["\n", "\r", "Question:", "Options:"])

    # Greedy decoding; temperature is omitted because it is ignored when do_sample=False
    gen_config = GenerationConfig(
        do_sample=False,
        max_new_tokens=max_tokens,
        pad_token_id=tok.pad_token_id,
        eos_token_id=tok.eos_token_id,
    )

    with torch.no_grad():
        out = model.generate(
            **inputs,
            generation_config=gen_config,
            stopping_criteria=StoppingCriteriaList([stopper])
        )

    # Decode only the newly generated tokens instead of relying on a string-prefix match
    new_tokens = out[0][inputs["input_ids"].shape[1]:]
    resp = tok.decode(new_tokens, skip_special_tokens=True)
    return resp.split("\n")[0].strip()

# ==============================================================================
# ==============================================================================

def build_solver_prompt(question: str) -> str:
    instruction = (
        "You are a calculator. Output strictly ONLY the final numerical answer or Yes/No.\n"
        "Do NOT explain. Do NOT show work.\n"
    )
    return f"{instruction}\nQuestion: {question.strip()}\nAnswer:"

def build_contamination_prompt(question: str) -> str:
    instruction = (
        "You are an AI assistant. Analyze the following math problem.\n"
        "Tell me if you have seen this problem in your training data.\n"
        "Reply with strictly ONLY a single digit:\n"
        "0 = SEEN the same question before\n"
        "1 = Seen a similar question before, but not the same question\n"
        "2 = Not seen a similar question before\n"
    )
    return f"{instruction}\nProblem Content: {question.strip()}\n\nReply (0, 1, or 2):"

# ==============================================================================
# ==============================================================================

def run_two_questions_on_dataset(mw, ds, n=3, tag="DATASET"):
    print(f"\n===== Running on {tag}: {ds.name} | n={n} =====\n")
    for i, inst in enumerate(ds.instances[:n]):
        q = inst.prompt

        solver_prompt = build_solver_prompt(q)
        solver_resp = generate_strict(mw, solver_prompt, max_tokens=15)

        check_prompt = build_contamination_prompt(q)
        check_resp = generate_strict(mw, check_prompt, max_tokens=5)

        print("=" * 90)
        print(f"[{tag} #{i}] instance_id={inst.instance_id}")
        print("\n>>> [TASK 1] SOLVER PROMPT (instruction -> question):")
        print(solver_prompt)
        print(">>> [TASK 1] MODEL RESPONSE:", repr(solver_resp))

        print("\n>>> [TASK 2] CONTAMINATION PROMPT (instruction -> question):")
        print(check_prompt)
        print(">>> [TASK 2] MODEL RESPONSE:", repr(check_resp))
        print("=" * 90, "\n")

# ==============================================================================
# ==============================================================================

gsm_keys = [k for k in datasets.keys() if "gsm8k" in k.lower() and not k.endswith("__sgu")]
if not gsm_keys:
    raise RuntimeError("No gsm8k task key found in `datasets`.")
gsm_ds = datasets[gsm_keys[0]]

sgu_ds = ds  # the SGU slice selected in the loader cell above

run_two_questions_on_dataset(mw, gsm_ds, n=3, tag="GSM8K")
run_two_questions_on_dataset(mw, sgu_ds, n=3, tag="SGU")
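
The two contamination probes in this notebook use opposite scales: `ask_seen_level` treats `2` as "seen this exact problem", while `build_contamination_prompt` treats `0` as "seen". Before comparing their outputs it helps to map both onto one scale. A sketch, assuming a hypothetical helper `normalize_seen_level` and the convention that a higher value means more familiar:

```python
import re
from typing import Optional

def normalize_seen_level(raw: str, zero_means_seen: bool) -> Optional[int]:
    """Return 0 (unseen), 1 (similar seen), 2 (exact seen), or None if unparseable."""
    m = re.search(r"\b([012])\b", raw or "")
    if m is None:
        return None
    level = int(m.group(1))
    return 2 - level if zero_means_seen else level

# ask_seen_level style: 2 = exact problem seen
print(normalize_seen_level("2", zero_means_seen=False))
# build_contamination_prompt style: 0 = exact problem seen
print(normalize_seen_level("0", zero_means_seen=True))
```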


In [None]:
import gc
import re
import torch
from transformers import GenerationConfig, StoppingCriteria, StoppingCriteriaList

# ==============================================================================
# ==============================================================================

class StopOnTokens(StoppingCriteria):
    def __init__(self, tokenizer, stop_strings):
        self.tokenizer = tokenizer
        self.stop_token_ids = []
        for s in stop_strings:
            encoded = tokenizer.encode(s, add_special_tokens=False)
            if encoded:
                self.stop_token_ids.append(encoded[-1])

    def __call__(self, input_ids, scores, **kwargs):
        if input_ids.shape[1] == 0:
            return False
        return input_ids[0, -1].item() in self.stop_token_ids

def generate_strict(mw, prompt: str, max_tokens: int = 10):
    tok = mw.tokenizer
    model = mw.model

    inputs = tok(prompt, return_tensors="pt", truncation=True)
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    stopper = StopOnTokens(tok, stop_strings=["\n", "\r", "Question:", "Options:"])

    # Greedy decoding; temperature is omitted because it is ignored when do_sample=False
    gen_config = GenerationConfig(
        do_sample=False,
        max_new_tokens=max_tokens,
        pad_token_id=tok.pad_token_id,
        eos_token_id=tok.eos_token_id,
    )

    with torch.no_grad():
        out = model.generate(
            **inputs,
            generation_config=gen_config,
            stopping_criteria=StoppingCriteriaList([stopper])
        )

    # Decode only the newly generated tokens instead of relying on a string-prefix match
    new_tokens = out[0][inputs["input_ids"].shape[1]:]
    resp = tok.decode(new_tokens, skip_special_tokens=True)
    return resp.split("\n")[0].strip()

# ==============================================================================
# ==============================================================================

def build_solver_prompt(question: str) -> str:
    instruction = (
        "You are a calculator. Output strictly ONLY the final numerical answer or Yes/No.\n"
        "Do NOT explain. Do NOT show work.\n"
    )
    return f"{instruction}\nQuestion: {question.strip()}\nAnswer:"

def build_contamination_prompt(question: str) -> str:
    instruction = (
        "You are an AI assistant. Analyze the following math problem.\n"
        "Tell me if you have seen this problem in your training data.\n"
        "Reply with strictly ONLY a single digit:\n"
        "0 = SEEN the same question before\n"
        "1 = Seen a similar question before, but not the same question\n"
        "2 = Not seen a similar question before\n"
    )
    return f"{instruction}\nProblem Content: {question.strip()}\n\nReply (0, 1, or 2):"

def run_two_questions_on_dataset(mw, ds, n=3, tag="DATASET"):
    print(f"\n===== {tag}: {ds.name} | n={n} =====\n")
    for i, inst in enumerate(ds.instances[:n]):
        q = inst.prompt

        solver_prompt = build_solver_prompt(q)
        solver_resp = generate_strict(mw, solver_prompt, max_tokens=15)

        check_prompt = build_contamination_prompt(q)
        check_resp = generate_strict(mw, check_prompt, max_tokens=5)

        print("=" * 90)
        print(f"[{tag} #{i}] instance_id={inst.instance_id}")
        print("\n>>> [TASK 1] SOLVER PROMPT:")
        print(solver_prompt)
        print(">>> [TASK 1] MODEL RESPONSE:", repr(solver_resp))

        print("\n>>> [TASK 2] CONTAMINATION PROMPT:")
        print(check_prompt)
        print(">>> [TASK 2] MODEL RESPONSE:", repr(check_resp))
        print("=" * 90, "\n")

# ==============================================================================
# ==============================================================================

sgu_ds = ds

gsm_keys = [k for k in datasets.keys() if "gsm8k" in k.lower() and not k.endswith("__sgu")]
gsm_ds = datasets[gsm_keys[0]] if gsm_keys else None

# ==============================================================================
# ==============================================================================

MODEL_SPECS = [
    {"name": "qwen3-4b",         "model_id": "Qwen/Qwen3-4B",                         "adapter": None},
    {"name": "qwen2.5-math-7b",  "model_id": "Qwen/Qwen2.5-Math-7B-Instruct",         "adapter": None},
    {"name": "deepseek-math-7b", "model_id": "deepseek-ai/deepseek-math-7b-instruct", "adapter": None},
    {"name": "acemath-7b",       "model_id": "nvidia/AceMath-7B-Instruct",            "adapter": None},
    {"name": "Mathstral-7B",     "model_id": "mistralai/Mathstral-7B-v0.1",           "adapter": None},
    # General instruction-tuned Gemma 2 checkpoint, not a math-specialised model
    {"name": "gemma-2-9b-it",    "model_id": "google/gemma-2-9b-it",                  "adapter": None},
]


# ==============================================================================
# ==============================================================================

def run_all_models(model_specs, n_gsm=3, n_sgu=3, use_4bit=False):
    for spec in model_specs:
        print("\n" + "#" * 110)
        print(f"LOADING MODEL: {spec['name']} | {spec['model_id']} | adapter={spec['adapter']} | 4bit={use_4bit}")
        print("#" * 110 + "\n")

        mw_local = HFModelWrapper.load(spec["model_id"], adapter_path=spec["adapter"], use_4bit=use_4bit)

        if gsm_ds is not None:
            run_two_questions_on_dataset(mw_local, gsm_ds, n=n_gsm, tag=f"GSM8K@{spec['name']}")
        else:
            print(f"[SKIP] No GSM8K dataset found in `datasets`, skipping GSM8K for {spec['name']}")

        run_two_questions_on_dataset(mw_local, sgu_ds, n=n_sgu, tag=f"SGU@{spec['name']}")

        del mw_local
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

run_all_models(MODEL_SPECS, n_gsm=3, n_sgu=3, use_4bit=USE_4BIT)


In [None]:
import gc
import re
import pandas as pd
import torch
from transformers import GenerationConfig, StoppingCriteria, StoppingCriteriaList

# ==============================================================================
# ==============================================================================

class StopOnTokens(StoppingCriteria):
    def __init__(self, tokenizer, stop_strings):
        self.tokenizer = tokenizer
        self.stop_token_ids = []
        for s in stop_strings:
            encoded = tokenizer.encode(s, add_special_tokens=False)
            if encoded:
                self.stop_token_ids.append(encoded[-1])

    def __call__(self, input_ids, scores, **kwargs):
        if input_ids.shape[1] == 0:
            return False
        return input_ids[0, -1].item() in self.stop_token_ids

def generate_strict(mw, prompt: str, max_tokens: int = 10):
    tok = mw.tokenizer
    model = mw.model

    inputs = tok(prompt, return_tensors="pt", truncation=True)
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    stopper = StopOnTokens(tok, stop_strings=["\n", "\r", "Question:", "Options:"])

    # Greedy decoding; temperature is omitted because it is ignored when do_sample=False
    gen_config = GenerationConfig(
        do_sample=False,
        max_new_tokens=max_tokens,
        pad_token_id=tok.pad_token_id,
        eos_token_id=tok.eos_token_id,
    )

    with torch.no_grad():
        out = model.generate(
            **inputs,
            generation_config=gen_config,
            stopping_criteria=StoppingCriteriaList([stopper])
        )

    # Decode only the newly generated tokens instead of relying on a string-prefix match
    new_tokens = out[0][inputs["input_ids"].shape[1]:]
    resp = tok.decode(new_tokens, skip_special_tokens=True)
    return resp.split("\n")[0].strip()

# ==============================================================================
# ==============================================================================

def build_solver_prompt(question: str) -> str:
    instruction = (
        "You are a calculator. Output strictly ONLY the final numerical answer or Yes/No.\n"
        "Do NOT explain. Do NOT show work.\n"
    )
    return f"{instruction}\nQuestion: {question.strip()}\nAnswer:"

def build_contamination_prompt(question: str) -> str:
    instruction = (
        "You are an AI assistant. Analyze the following math problem.\n"
        "Tell me if you have seen this problem in your training data.\n"
        "Reply with strictly ONLY a single digit:\n"
        "0 = SEEN the same question before\n"
        "1 = Seen a similar question before, but not the same question\n"
        "2 = Not seen a similar question before\n"
    )
    return f"{instruction}\nProblem Content: {question.strip()}\n\nReply (0, 1, or 2):"

# ==============================================================================
# ==============================================================================

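_yes_pat = re.compile(r"\bYES\b", re.IGNORECASE)
_no_pat  = re.compile(r"\bNO\b", re.IGNORECASE)

def parse_yesno(s: str):
    """Return "YES" or "NO" if the response contains that word, else None."""
    if not s or not s.strip():
        return None
    # Word-boundary match, so "NOT", "NONE" or "UNKNOWN" do not count as "NO"
    if _yes_pat.search(s):
        return "YES"
    if _no_pat.search(s):
        return "NO"
    return None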

def parse_012(s: str):
    if not s or not s.strip():
        return None
    m = re.search(r"\b([012])\b", s.strip())
    return m.group(1) if m else None

def expected_yesno_from_label01(label01: int):
    return "YES" if int(label01) == 1 else "NO"

# ==============================================================================
# ==============================================================================

def eval_model_on_ds(mw, ds, n, tag):
    n = min(n, len(ds.instances))
    solver_ok = 0
    solver_fmt = 0
    contam_fmt = 0
    contam_counts = {"0": 0, "1": 0, "2": 0, "other": 0, "empty": 0}

    for inst in ds.instances[:n]:
        q = inst.prompt

        # Task 1: YES/NO
        solver_prompt = build_solver_prompt(q)
        solver_resp = generate_strict(mw, solver_prompt, max_tokens=15)
        pred_yesno = parse_yesno(solver_resp)
        if pred_yesno is not None:
            solver_fmt += 1
            exp = expected_yesno_from_label01(inst.label01)
            solver_ok += int(pred_yesno == exp)

        # Task 2: 0/1/2
        contam_prompt = build_contamination_prompt(q)
        contam_resp = generate_strict(mw, contam_prompt, max_tokens=5)
        pred012 = parse_012(contam_resp)

        if contam_resp is None or contam_resp == "":
            contam_counts["empty"] += 1
        elif pred012 is None:
            contam_counts["other"] += 1
        else:
            contam_fmt += 1
            contam_counts[pred012] += 1

    out = {
        f"{tag}_n": n,
        f"{tag}_solver_fmt_rate": solver_fmt / n if n else 0.0,
        f"{tag}_solver_acc": solver_ok / n if n else 0.0,
        f"{tag}_contam_fmt_rate": contam_fmt / n if n else 0.0,
        f"{tag}_contam_0": contam_counts["0"],
        f"{tag}_contam_1": contam_counts["1"],
        f"{tag}_contam_2": contam_counts["2"],
        f"{tag}_contam_other": contam_counts["other"],
        f"{tag}_contam_empty": contam_counts["empty"],
    }
    return out

# ==============================================================================
# ==============================================================================

def summarize_all_models(model_specs, gsm_ds, sgu_ds, n_gsm=50, n_sgu=50, use_4bit=False):
    rows = []
    for spec in model_specs:
        print("\n" + "#" * 100)
        print(f"LOADING: {spec['name']} | {spec['model_id']} | 4bit={use_4bit} | adapter={spec.get('adapter')}")
        print("#" * 100)

        mw_local = HFModelWrapper.load(spec["model_id"], adapter_path=spec.get("adapter"), use_4bit=use_4bit)

        row = {
            "model": spec["name"],
            "model_id": spec["model_id"],
            "4bit": bool(use_4bit),
        }

        if gsm_ds is not None:
            row.update(eval_model_on_ds(mw_local, gsm_ds, n_gsm, "gsm"))
        else:
            row.update({"gsm_n": 0, "gsm_solver_fmt_rate": 0.0, "gsm_solver_acc": 0.0, "gsm_contam_fmt_rate": 0.0})

        row.update(eval_model_on_ds(mw_local, sgu_ds, n_sgu, "sgu"))
        rows.append(row)

        del mw_local
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    df = pd.DataFrame(rows)

    cols = [
        "model","model_id","4bit",
        "gsm_n","gsm_solver_acc","gsm_solver_fmt_rate","gsm_contam_fmt_rate","gsm_contam_0","gsm_contam_1","gsm_contam_2","gsm_contam_other","gsm_contam_empty",
        "sgu_n","sgu_solver_acc","sgu_solver_fmt_rate","sgu_contam_fmt_rate","sgu_contam_0","sgu_contam_1","sgu_contam_2","sgu_contam_other","sgu_contam_empty",
    ]
    cols = [c for c in cols if c in df.columns]
    df = df[cols]

    with pd.option_context("display.max_rows", 200, "display.max_columns", 200, "display.width", 200):
        print("\n\n===== SUMMARY TABLE =====")
        print(df.to_string(index=False))

    return df

# ==============================================================================
# Run the multi-model summary
# ==============================================================================

summary_df = summarize_all_models(
    MODEL_SPECS,
    gsm_ds=gsm_ds if "gsm_ds" in globals() else None,
    sgu_ds=sgu_ds,
    n_gsm=50,
    n_sgu=50,
    use_4bit=USE_4BIT
)
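
The `*_contam_0` … `*_contam_empty` columns in the summary table are bucket counts over raw replies. A minimal sketch of that bucketing, using only the standard library, is shown below. The names `bucket_reply` and `tally` are hypothetical helpers for illustration and are not part of the notebook; unlike the notebook, whitespace-only replies are also treated as empty here.

```python
import re
from collections import Counter

def bucket_reply(reply):
    """Map a raw contamination reply to '0', '1', '2', 'other', or 'empty'."""
    if reply is None or not reply.strip():
        return "empty"
    m = re.search(r"\b([012])\b", reply)
    return m.group(1) if m else "other"

def tally(replies):
    """Return (bucket counts, share of replies that parsed as 0/1/2)."""
    counts = Counter({"0": 0, "1": 0, "2": 0, "other": 0, "empty": 0})
    counts.update(bucket_reply(r) for r in replies)
    n = len(replies)
    fmt_rate = (counts["0"] + counts["1"] + counts["2"]) / n if n else 0.0
    return dict(counts), fmt_rate

counts, fmt_rate = tally(["2", "Answer: 1", "", "maybe", "0"])
print(counts, fmt_rate)
```

The format rate is computed over all sampled items, so empty and unparseable replies lower it rather than being dropped from the denominator.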


In [None]:
import gc
import re
import pandas as pd
import torch
from transformers import GenerationConfig, StoppingCriteria, StoppingCriteriaList
from tqdm.auto import tqdm  # used by eval_model_on_ds and the grouped summary below

MODEL_SPECS = [
    {"name": "qwen3-4b",        "model_id": "Qwen/Qwen3-4B",                   "adapter": None},
    {"name": "deepseek-llama-8b", "model_id": "deepseek-ai/DeepSeek-R1-Distill-Llama-8B",   "adapter": None},
    {"name": "deepseek-math-7b","model_id": "deepseek-ai/deepseek-math-7b-instruct", "adapter": None},
    {"name": "acemath-7b",      "model_id": "nvidia/AceMath-7B-Instruct",      "adapter": None},
    {"name": "Mathstral-7B",      "model_id": "mistralai/Mathstral-7B-v0.1",      "adapter": None},
]


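# ==============================================================================
# Generation helpers
# ==============================================================================

class StopOnTokens(StoppingCriteria):
    """Stop generation when the newest token is the last token of a stop string.

    Only the final token id of each stop string is kept, so a multi-token stop
    string such as "Question:" triggers on its last token alone.
    """

    def __init__(self, tokenizer, stop_strings):
        self.tokenizer = tokenizer
        self.stop_token_ids = []
        for s in stop_strings:
            encoded = tokenizer.encode(s, add_special_tokens=False)
            if encoded:
                self.stop_token_ids.append(encoded[-1])

    def __call__(self, input_ids, scores, **kwargs):
        if input_ids.shape[1] == 0:
            return False
        # Only the first sequence in the batch is inspected (batch size 1 is assumed).
        return input_ids[0, -1].item() in self.stop_token_ids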

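def generate_strict(mw, prompt: str, max_tokens: int = 10):
    """Greedy-decode a short reply and return only its first line."""
    tok = mw.tokenizer
    model = mw.model
    inputs = tok(prompt, return_tensors="pt", truncation=True)
    inputs = {k: v.to(model.device) for k, v in inputs.items()}
    prompt_len = inputs["input_ids"].shape[1]

    stopper = StopOnTokens(tok, stop_strings=["\n", "\r", "Question:", "Options:"])

    # Some tokenizers define no pad token; fall back to EOS in that case.
    pad_id = tok.pad_token_id if tok.pad_token_id is not None else tok.eos_token_id

    # Greedy decoding: temperature only matters when sampling, so it is not set.
    gen_config = GenerationConfig(
        do_sample=False,
        max_new_tokens=max_tokens,
        pad_token_id=pad_id,
        eos_token_id=tok.eos_token_id,
    )

    with torch.no_grad():
        out = model.generate(
            **inputs,
            generation_config=gen_config,
            stopping_criteria=StoppingCriteriaList([stopper]),
        )

    # Decode only the newly generated tokens. Matching the decoded prompt as a
    # string prefix can fail when tokenization does not round-trip exactly, and
    # the prompt text would then leak into the parsed reply.
    resp = tok.decode(out[0][prompt_len:], skip_special_tokens=True)
    return resp.split("\n")[0].strip()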

# ==============================================================================
# Prompt builders
# ==============================================================================

def build_solver_prompt(question: str) -> str:
    instruction = (
        "You are a calculator. Output strictly ONLY the final numerical answer or Yes/No.\n"
        "Do NOT explain. Do NOT show work.\n"
    )
    return f"{instruction}\nQuestion: {question.strip()}\nAnswer:"

def build_contamination_prompt(question: str) -> str:
    instruction = (
        "You are an AI assistant. Analyze the following math problem.\n"
        "Tell me if you have seen this problem in your training data.\n"
        "Reply with strictly ONLY a single digit:\n"
        "0 = SEEN the same question before\n"
        "1 = Seen a similar question before, but not the same question\n"
        "2 = Not seen a similar question before\n"
    )
    return f"{instruction}\nProblem Content: {question.strip()}\n\nReply (0, 1, or 2):"

# ==============================================================================
# Reply parsers
# ==============================================================================

def parse_yesno(s: str):
    """Return "YES" or "NO" for the first standalone yes/no word, else None."""
    if not s or not s.strip():
        return None
    # Word boundaries avoid false hits such as "NO" inside "NOT" or "UNKNOWN".
    m = re.search(r"\b(YES|NO)\b", s.upper())
    return m.group(1) if m else None

def parse_012(s: str):
    if not s or not s.strip():
        return None
    m = re.search(r"\b([012])\b", s.strip())
    return m.group(1) if m else None

def expected_yesno_from_label01(label01: int):
    return "YES" if int(label01) == 1 else "NO"

# ==============================================================================
# Per-dataset evaluation (solver + contamination self-report)
# ==============================================================================

def eval_model_on_ds(mw, ds, n, tag, pbar=None):
    n = min(n, len(ds.instances))

    solver_ok = 0
    solver_fmt = 0
    solver_has_gold = 0

    contam_fmt = 0
    contam_counts = {"0": 0, "1": 0, "2": 0, "other": 0, "empty": 0}

    it = ds.instances[:n]
    if pbar is None:
        it = tqdm(it, desc=f"{tag}", leave=False)

    for inst in it:
        q = inst.prompt

        # Task 1
        solver_prompt = build_solver_prompt(q)
        solver_resp = generate_strict(mw, solver_prompt, max_tokens=15)
        pred_yesno = parse_yesno(solver_resp)
        if pred_yesno is not None:
            solver_fmt += 1
            if inst.label01 is not None:
                solver_has_gold += 1
                exp = expected_yesno_from_label01(inst.label01)
                solver_ok += int(pred_yesno == exp)

        # Task 2
        contam_prompt = build_contamination_prompt(q)
        contam_resp = generate_strict(mw, contam_prompt, max_tokens=5)
        pred012 = parse_012(contam_resp)

        if contam_resp is None or contam_resp == "":
            contam_counts["empty"] += 1
        elif pred012 is None:
            contam_counts["other"] += 1
        else:
            contam_fmt += 1
            contam_counts[pred012] += 1

        if pbar is not None:
            pbar.update(1)

    solver_acc = (solver_ok / solver_has_gold) if solver_has_gold > 0 else float("nan")

    out = {
        f"{tag}_n": n,
        f"{tag}_solver_fmt_rate": solver_fmt / n if n else 0.0,
        f"{tag}_solver_acc": solver_acc,
        f"{tag}_contam_fmt_rate": contam_fmt / n if n else 0.0,
        f"{tag}_contam_0": contam_counts["0"],
        f"{tag}_contam_1": contam_counts["1"],
        f"{tag}_contam_2": contam_counts["2"],
        f"{tag}_contam_other": contam_counts["other"],
        f"{tag}_contam_empty": contam_counts["empty"],
    }
    return out

# ==============================================================================
# Fixed task groups
# ==============================================================================

fixed_groups: dict[str, list[str]] = {
    "aqua": ["aqua_binary", "aqua_mc_test"],
    "gsm8k": ["gsm8k_answer_test", "gsm8k_binary"],
    "p_time": ["p_graph_connectivity", "ptime_arith"],
    "np_time": ["random_np_rbh_word_avg", "np_subset_sum_avg", "np_subset_sum_worst", "np_rbh_word_worst_3sat"],
}

def pick_first_existing_dataset(group_tasks):
    for k in group_tasks:
        if k in datasets:
            return datasets[k], k
    return None, None

# ==============================================================================
# Multi-model summary over the fixed groups (single progress bar)
# ==============================================================================

def summarize_all_models_groups_with_progress(model_specs, fixed_groups, n_each=50, use_4bit=False):
    group_selected = {}
    for g, task_list in fixed_groups.items():
        ds_obj, task_key = pick_first_existing_dataset(task_list)
        group_selected[g] = (ds_obj, task_key)

    print("\n[GROUP SELECTION]")
    active_groups = []
    for g, (ds_obj, task_key) in group_selected.items():
        if ds_obj is None:
            print(f"  - {g}: (MISSING) none of {fixed_groups[g]}")
        else:
            print(f"  - {g}: using task='{task_key}' (N={len(ds_obj.instances)})")
            active_groups.append(g)

    total_steps = 0
    for _ in model_specs:
        for g in active_groups:
            ds_obj, _ = group_selected[g]
            total_steps += min(n_each, len(ds_obj.instances))

    rows = []
    pbar = tqdm(total=total_steps, desc="EVAL progress", leave=True)

    for spec in model_specs:
        pbar.set_postfix_str(f"loading {spec['name']}")
        print("\n" + "#" * 100)
        print(f"LOADING: {spec['name']} | {spec['model_id']} | 4bit={use_4bit} | adapter={spec.get('adapter')}")
        print("#" * 100)

        mw_local = HFModelWrapper.load(spec["model_id"], adapter_path=spec.get("adapter"), use_4bit=use_4bit)

        row = {"model": spec["name"], "model_id": spec["model_id"], "4bit": bool(use_4bit)}
        for g in fixed_groups.keys():
            ds_obj, task_key = group_selected[g]
            row[f"{g}_task"] = task_key

            if ds_obj is None:
                row.update({
                    f"{g}_n": 0,
                    f"{g}_solver_acc": float("nan"),
                    f"{g}_solver_fmt_rate": 0.0,
                    f"{g}_contam_fmt_rate": 0.0,
                    f"{g}_contam_0": 0, f"{g}_contam_1": 0, f"{g}_contam_2": 0,
                    f"{g}_contam_other": 0, f"{g}_contam_empty": 0,
                })
                continue

            pbar.set_postfix_str(f"{spec['name']} | {g}")
            stats = eval_model_on_ds(mw_local, ds_obj, n_each, g, pbar=pbar)
            row.update(stats)

        rows.append(row)

        del mw_local
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    pbar.close()

    df = pd.DataFrame(rows)

    cols = ["model", "model_id", "4bit"]
    for g in fixed_groups.keys():
        cols += [
            f"{g}_task", f"{g}_n",
            f"{g}_solver_acc", f"{g}_solver_fmt_rate",
            f"{g}_contam_fmt_rate",
            f"{g}_contam_0", f"{g}_contam_1", f"{g}_contam_2",
            f"{g}_contam_other", f"{g}_contam_empty",
        ]
    cols = [c for c in cols if c in df.columns]
    df = df[cols]

    with pd.option_context("display.max_rows", 200, "display.max_columns", 200, "display.width", 260):
        print("\n\n===== SUMMARY TABLE (GROUPS) =====")
        print(df.to_string(index=False))

    return df

# ==============================================================================
# Run the grouped summary
# ==============================================================================

summary_groups_df = summarize_all_models_groups_with_progress(
    MODEL_SPECS,
    fixed_groups=fixed_groups,
    n_each=30,
    use_4bit=USE_4BIT
)
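
`StopOnTokens` compares token ids, so how a stop string such as `"Question:"` behaves depends on how each tokenizer splits it. A tokenizer-independent complement is to cut the decoded reply at the earliest stop string after generation. The sketch below assumes this post-hoc approach; `truncate_at_stop` is a hypothetical helper and is not defined elsewhere in the notebook.

```python
def truncate_at_stop(text, stop_strings=("\n", "\r", "Question:", "Options:")):
    """Cut `text` at the earliest occurrence of any stop string, then strip it."""
    cut = len(text)
    for s in stop_strings:
        idx = text.find(s)
        if idx != -1:
            cut = min(cut, idx)
    return text[:cut].strip()

print(truncate_at_stop(" 42\nQuestion: next one"))
print(truncate_at_stop("Ratio 3:4 Options: A B"))
```

Because the check runs on whole strings, a bare colon inside an answer (as in `3:4`) does not end the reply; only the full `Options:` marker does.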


In [None]:
import gc
import re
import pandas as pd
import torch
from transformers import GenerationConfig, StoppingCriteria, StoppingCriteriaList
from tqdm.auto import tqdm

# =========================
# Select the SGU dataset
# =========================
sgu_keys = [k for k in datasets.keys() if k.endswith("__sgu")]
if not sgu_keys:
    raise RuntimeError("No __sgu dataset found in `datasets`.")
sgu_ds = datasets[sgu_keys[0]]
print("SGU_DATASET_KEY =", sgu_ds.name, "| N =", len(sgu_ds.instances))

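# =========================
# Generation helpers
# =========================
class StopOnTokens(StoppingCriteria):
    """Stop generation when the newest token is the last token of a stop string.

    Only the final token id of each stop string is kept, so a multi-token stop
    string such as "Question:" triggers on its last token alone.
    """

    def __init__(self, tokenizer, stop_strings):
        self.tokenizer = tokenizer
        self.stop_token_ids = []
        for s in stop_strings:
            encoded = tokenizer.encode(s, add_special_tokens=False)
            if encoded:
                self.stop_token_ids.append(encoded[-1])

    def __call__(self, input_ids, scores, **kwargs):
        if input_ids.shape[1] == 0:
            return False
        # Only the first sequence in the batch is inspected (batch size 1 is assumed).
        return input_ids[0, -1].item() in self.stop_token_ids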

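def generate_strict(mw, prompt: str, max_tokens: int = 10):
    """Greedy-decode a short reply and return only its first line."""
    tok = mw.tokenizer
    model = mw.model
    inputs = tok(prompt, return_tensors="pt", truncation=True)
    inputs = {k: v.to(model.device) for k, v in inputs.items()}
    prompt_len = inputs["input_ids"].shape[1]

    stopper = StopOnTokens(tok, stop_strings=["\n", "\r", "Question:", "Options:"])

    # Some tokenizers define no pad token; fall back to EOS in that case.
    pad_id = tok.pad_token_id if tok.pad_token_id is not None else tok.eos_token_id

    # Greedy decoding: temperature only matters when sampling, so it is not set.
    gen_config = GenerationConfig(
        do_sample=False,
        max_new_tokens=max_tokens,
        pad_token_id=pad_id,
        eos_token_id=tok.eos_token_id,
    )

    with torch.no_grad():
        out = model.generate(
            **inputs,
            generation_config=gen_config,
            stopping_criteria=StoppingCriteriaList([stopper]),
        )

    # Decode only the newly generated tokens. Matching the decoded prompt as a
    # string prefix can fail when tokenization does not round-trip exactly, and
    # the prompt text would then leak into the parsed reply.
    resp = tok.decode(out[0][prompt_len:], skip_special_tokens=True)
    return resp.split("\n")[0].strip()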

# =========================
# Prompt builders
# =========================
def build_solver_prompt(question: str) -> str:
    instruction = (
        "You are a calculator. Output strictly ONLY the final numerical answer or Yes/No.\n"
        "Do NOT explain. Do NOT show work.\n"
    )
    return f"{instruction}\nQuestion: {question.strip()}\nAnswer:"

def build_contamination_prompt(question: str) -> str:
    instruction = (
        "You are an AI assistant. Analyze the following math problem.\n"
        "Tell me if you have seen this problem in your training data.\n"
        "Reply with strictly ONLY a single digit:\n"
        "0 = SEEN the same question before\n"
        "1 = Seen a similar question before, but not the same question\n"
        "2 = Not seen a similar question before\n"
    )
    return f"{instruction}\nProblem Content: {question.strip()}\n\nReply (0, 1, or 2):"

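# =========================
# Reply parsers
# =========================
def parse_yesno(s: str):
    """Return "YES" or "NO" for the first standalone yes/no word, else None."""
    if not s or not s.strip():
        return None
    # Word boundaries avoid false hits such as "NO" inside "NOT" or "UNKNOWN".
    m = re.search(r"\b(YES|NO)\b", s.upper())
    return m.group(1) if m else None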

def parse_012(s: str):
    if not s or not s.strip():
        return None
    m = re.search(r"\b([012])\b", s.strip())
    return m.group(1) if m else None

def expected_yesno_from_label01(label01: int):
    return "YES" if int(label01) == 1 else "NO"

# =========================
# SGU-only evaluation
# =========================
def eval_one_model_on_sgu(spec, sgu_ds, n=50, use_4bit=False):
    mw_local = HFModelWrapper.load(spec["model_id"], adapter_path=spec.get("adapter"), use_4bit=use_4bit)

    n = min(n, len(sgu_ds.instances))
    solver_ok = 0
    solver_fmt = 0
    solver_has_gold = 0

    contam_fmt = 0
    contam_counts = {"0": 0, "1": 0, "2": 0, "other": 0, "empty": 0}

    for inst in tqdm(sgu_ds.instances[:n], desc=f"SGU@{spec['name']}", leave=False):
        q = inst.prompt

        # Task 1
        solver_resp = generate_strict(mw_local, build_solver_prompt(q), max_tokens=15)
        pred_yesno = parse_yesno(solver_resp)
        if pred_yesno is not None:
            solver_fmt += 1
            if inst.label01 is not None:
                solver_has_gold += 1
                solver_ok += int(pred_yesno == expected_yesno_from_label01(inst.label01))

        # Task 2
        contam_resp = generate_strict(mw_local, build_contamination_prompt(q), max_tokens=5)
        pred012 = parse_012(contam_resp)
        if contam_resp is None or contam_resp == "":
            contam_counts["empty"] += 1
        elif pred012 is None:
            contam_counts["other"] += 1
        else:
            contam_fmt += 1
            contam_counts[pred012] += 1

    solver_acc = (solver_ok / solver_has_gold) if solver_has_gold > 0 else float("nan")

    row = {
        "model": spec["name"],
        "model_id": spec["model_id"],
        "4bit": bool(use_4bit),
        "sgu_task": sgu_ds.name,
        "sgu_n": n,
        "sgu_solver_acc": solver_acc,
        "sgu_solver_fmt_rate": solver_fmt / n if n else 0.0,
        "sgu_contam_fmt_rate": contam_fmt / n if n else 0.0,
        "sgu_contam_0": contam_counts["0"],
        "sgu_contam_1": contam_counts["1"],
        "sgu_contam_2": contam_counts["2"],
        "sgu_contam_other": contam_counts["other"],
        "sgu_contam_empty": contam_counts["empty"],
    }

    # cleanup
    del mw_local
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    return row

def run_sgu_only(model_specs, sgu_ds, n_each=50, use_4bit=False):
    rows = []
    for spec in model_specs:
        print("\n" + "#" * 100)
        print(f"LOADING SGU MODEL: {spec['name']} | {spec['model_id']} | 4bit={use_4bit}")
        print("#" * 100)
        rows.append(eval_one_model_on_sgu(spec, sgu_ds, n=n_each, use_4bit=use_4bit))

    df = pd.DataFrame(rows)
    with pd.option_context("display.max_rows", 200, "display.max_columns", 200, "display.width", 200):
        print("\n\n===== SGU ONLY SUMMARY =====")
        print(df.to_string(index=False))
    return df

sgu_only_df = run_sgu_only(MODEL_SPECS, sgu_ds, n_each=50, use_4bit=USE_4BIT)
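
When reading Yes/No replies, a plain substring test such as `"NO" in text` also fires on words like "not" or "unknown", which can inflate the solver format rate. The sketch below contrasts a substring parser with a word-boundary parser on a few made-up replies. Both function names are hypothetical and exist only for this comparison.

```python
import re

def parse_yesno_substring(s):
    """Substring test: returns YES/NO if either appears anywhere in the text."""
    t = (s or "").strip().upper()
    if "YES" in t:
        return "YES"
    if "NO" in t:
        return "NO"
    return None

_YESNO = re.compile(r"\b(YES|NO)\b", re.IGNORECASE)

def parse_yesno_word(s):
    """Word-boundary test: only a standalone yes/no word counts."""
    m = _YESNO.search(s or "")
    return m.group(1).upper() if m else None

for reply in ["Unknown", "I do not know", "No.", "yes, it halts"]:
    print(repr(reply), parse_yesno_substring(reply), parse_yesno_word(reply))
```

The first two replies contain no actual verdict, yet the substring version reports "NO" for both; the word-boundary version returns `None` and so leaves them out of the accuracy denominator.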


In [None]:

import pandas as pd
fixed_groups: dict[str, list[str]] = {
    "aqua": ["aqua_binary", "aqua_mc_test"],
    "gsm8k": ["gsm8k_answer_test", "gsm8k_binary"],
    "p_time": ["p_graph_connectivity", "ptime_arith"],
    "np_time": ["random_np_rbh_word_avg", "np_subset_sum_avg", "np_subset_sum_worst", "np_rbh_word_worst_3sat"],
}
assert "summary_groups_df" in globals(), "summary_groups_df not found (the GROUPS summary DataFrame)"
assert "sgu_only_df" in globals(), "sgu_only_df not found (the SGU-only summary DataFrame)"

for df in [summary_groups_df, sgu_only_df]:
    for c in ["model", "model_id", "4bit"]:
        assert c in df.columns, f"missing column {c} in {list(df.columns)}"
    df["model"] = df["model"].astype(str)
    df["model_id"] = df["model_id"].astype(str)
    df["4bit"] = df["4bit"].astype(bool)

sgu_cols = [c for c in sgu_only_df.columns if c.startswith("sgu_")] + ["model", "model_id", "4bit"]
sgu_small = sgu_only_df[sgu_cols].copy()

merged = summary_groups_df.merge(
    sgu_small,
    on=["model", "model_id", "4bit"],
    how="left",
    suffixes=("", "_sgu_dup")
)

dup_cols = [c for c in merged.columns if c.endswith("_sgu_dup")]
if dup_cols:
    merged.drop(columns=dup_cols, inplace=True)

csv_text = merged.to_csv(index=False)
print(csv_text)

print(merged.to_string(index=False))


In [None]:
# import gc
# import re
# import pandas as pd
# import torch
# from tqdm.auto import tqdm
# from transformers import GenerationConfig, StoppingCriteria, StoppingCriteriaList

# # =========================
# # =========================

# class StopOnTokens(StoppingCriteria):
#     def __init__(self, tokenizer, stop_strings):
#         self.stop_token_ids = []
#         for s in stop_strings:
#             ids = tokenizer.encode(s, add_special_tokens=False)
#             if ids:
#                 self.stop_token_ids.append(ids[-1])

#     def __call__(self, input_ids, scores, **kwargs):
#         if input_ids.shape[1] == 0:
#             return False
#         return input_ids[0, -1].item() in self.stop_token_ids

# def generate_strict(mw, prompt: str, max_tokens: int = 10):
#     tok, model = mw.tokenizer, mw.model
#     inputs = tok(prompt, return_tensors="pt", truncation=True)
#     inputs = {k: v.to(model.device) for k, v in inputs.items()}

#     stopper = StopOnTokens(tok, stop_strings=["\n", "\r", "Question:", "Options:"])
#     gen_config = GenerationConfig(
#         do_sample=False, temperature=0.0, max_new_tokens=max_tokens,
#         pad_token_id=tok.pad_token_id, eos_token_id=tok.eos_token_id
#     )

#     with torch.no_grad():
#         out = model.generate(**inputs, generation_config=gen_config,
#                              stopping_criteria=StoppingCriteriaList([stopper]))
#     full = tok.decode(out[0], skip_special_tokens=True)
#     resp = full[len(prompt):] if full.startswith(prompt) else full
#     return resp.split("\n")[0].strip()

# def build_solver_prompt(question: str) -> str:
#     instruction = (
#         "You are a calculator. Output strictly ONLY the final numerical answer or Yes/No.\n"
#         "Do NOT explain. Do NOT show work.\n"
#     )
#     return f"{instruction}\nQuestion: {question.strip()}\nAnswer:"

# def build_contamination_prompt(question: str) -> str:
#     instruction = (
#         "You are an AI assistant. Analyze the following math problem.\n"
#         "Tell me if you have seen this problem in your training data.\n"
#         "Reply with strictly ONLY a single digit:\n"
#         "0 = SEEN the same question before\n"
#         "1 = Seen a similar question before, but not the same question\n"
#         "2 = Not seen a similar question before\n"
#         "Ignore the YES/NO question in the following question even if it asks. "
#     )
#     return f"{instruction}\nProblem Content: {question.strip()}\n\nReply (0, 1, or 2):"

# def parse_yesno(s: str):
#     if not s: return None
#     t = s.strip().upper()
#     if "YES" in t: return "YES"
#     if "NO" in t:  return "NO"
#     return None

# def parse_012(s: str):
#     if not s: return None
#     m = re.search(r"\b([012])\b", s.strip())
#     return m.group(1) if m else None

# def expected_yesno_from_label01(label01: int):
#     return "YES" if int(label01) == 1 else "NO"


# fixed_groups: dict[str, list[str]] = {
#     "aqua": ["aqua_binary", "aqua_mc_test"],
#     "gsm8k": ["gsm8k_answer_test", "gsm8k_binary"],
#     "p_time": ["p_graph_connectivity", "ptime_arith"],
#     "np_time": ["random_np_rbh_word_avg", "np_subset_sum_avg", "np_subset_sum_worst", "np_rbh_word_worst_3sat"],
# }
# # =========================
# # =========================
# def pick_first_existing_dataset(task_list):
#     for k in task_list:
#         if k in datasets:
#             return datasets[k], k
#     return None, None

# group_selected = {}
# for g, task_list in fixed_groups.items():
#     ds_obj, task_key = pick_first_existing_dataset(task_list)
#     group_selected[g] = (ds_obj, task_key)

# print("[GROUP SELECTION]")
# for g, (ds_obj, task_key) in group_selected.items():
#     print(f"  - {g}: {task_key if task_key else 'MISSING'}")

# # =========================
# # =========================
# def collect_records_for_model(spec, n_each=30, use_4bit=False):
#     mw_local = HFModelWrapper.load(spec["model_id"], adapter_path=spec.get("adapter"), use_4bit=use_4bit)

#     recs = []
#     for g, (ds_obj, task_key) in group_selected.items():
#         if ds_obj is None:
#             continue
#         n = min(n_each, len(ds_obj.instances))
#         for inst in tqdm(ds_obj.instances[:n], desc=f"{spec['name']}|{g}", leave=False):
#             q = inst.prompt

#             # solver
#             solver_resp = generate_strict(mw_local, build_solver_prompt(q), max_tokens=15)
#             pred_yesno = parse_yesno(solver_resp)

#             correct = None
#             if pred_yesno is not None and inst.label01 is not None:
#                 correct = int(pred_yesno == expected_yesno_from_label01(inst.label01))

#             # contamination
#             contam_resp = generate_strict(mw_local, build_contamination_prompt(q), max_tokens=5)
#             contam = parse_012(contam_resp)  # "0"/"1"/"2" or None

#             recs.append({
#                 "model": spec["name"],
#                 "model_id": spec["model_id"],
#                 "group": g,
#                 "task": task_key,
#                 "solver_pred": pred_yesno,
#                 "solver_correct": correct,   # 1/0/None
#                 "contam": contam,            # "0"/"1"/"2"/None
#                 "solver_raw": solver_resp,
#                 "contam_raw": contam_resp,
#             })

#     # cleanup
#     del mw_local
#     gc.collect()
#     if torch.cuda.is_available():
#         torch.cuda.empty_cache()

#     return pd.DataFrame(recs)

# def run_conditional_accuracy(model_specs, n_each=30, use_4bit=False):
#     all_df = []
#     for spec in model_specs:
#         print("\n" + "#" * 90)
#         print(f"Collecting joint records: {spec['name']} | {spec['model_id']} | 4bit={use_4bit}")
#         dfm = collect_records_for_model(spec, n_each=n_each, use_4bit=use_4bit)
#         all_df.append(dfm)
#     df = pd.concat(all_df, ignore_index=True)

#     df_valid = df[df["solver_correct"].notna() & df["contam"].isin(["0","1","2"])].copy()
#     df_valid["solver_correct"] = df_valid["solver_correct"].astype(int)

#     cover = df.groupby(["model","group"]).size().rename("total").reset_index()
#     used  = df_valid.groupby(["model","group"]).size().rename("usable").reset_index()
#     cov = cover.merge(used, on=["model","group"], how="left").fillna({"usable":0})
#     cov["usable"] = cov["usable"].astype(int)
#     cov["usable_rate"] = cov["usable"] / cov["total"]

#     overall = (
#         df_valid.groupby(["model","contam"])["solver_correct"]
#         .agg(acc="mean", n="count")
#         .reset_index()
#         .sort_values(["model","contam"])
#     )

#     by_group = (
#         df_valid.groupby(["model","group","contam"])["solver_correct"]
#         .agg(acc="mean", n="count")
#         .reset_index()
#         .sort_values(["model","group","contam"])
#     )

#     print("\n===== COVERAGE (how many samples have BOTH valid solver and valid contam) =====")
#     with pd.option_context("display.max_rows", 500, "display.max_columns", 50, "display.width", 220):
#         print(cov.sort_values(["model","group"]).to_string(index=False))

#     print("\n===== OVERALL conditional accuracy by self-report (0/1/2) =====")
#     with pd.option_context("display.max_rows", 500, "display.max_columns", 50, "display.width", 220):
#         print(overall.to_string(index=False))

#     print("\n===== BY-GROUP conditional accuracy by self-report (0/1/2) =====")
#     with pd.option_context("display.max_rows", 500, "display.max_columns", 80, "display.width", 260):
#         print(by_group.to_string(index=False))

#     return df, df_valid, cov, overall, by_group

# raw_df, valid_df, coverage_df, overall_df, by_group_df = run_conditional_accuracy(
#     MODEL_SPECS,
#     n_each=50,
#     use_4bit=USE_4BIT
# )
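
The conditional-accuracy tables group solver correctness by the model's self-reported contamination code, using only records where both fields are valid. A minimal standard-library sketch of that computation follows. The record layout mirrors the `model`, `contam`, and `solver_correct` fields above; the function name and the toy records are made up for illustration and are not results from any run.

```python
from collections import defaultdict

def conditional_accuracy(records):
    """Return {(model, contam): (accuracy, n)} over records with both fields valid."""
    hits = defaultdict(int)
    totals = defaultdict(int)
    for r in records:
        # Skip records with no parsed solver verdict or no valid 0/1/2 self-report.
        if r["solver_correct"] is None or r["contam"] not in {"0", "1", "2"}:
            continue
        key = (r["model"], r["contam"])
        totals[key] += 1
        hits[key] += int(r["solver_correct"])
    return {k: (hits[k] / totals[k], totals[k]) for k in sorted(totals)}

# Toy records for illustration only.
toy = [
    {"model": "m", "contam": "0", "solver_correct": 1},
    {"model": "m", "contam": "0", "solver_correct": 0},
    {"model": "m", "contam": "2", "solver_correct": 1},
    {"model": "m", "contam": None, "solver_correct": 1},
    {"model": "m", "contam": "1", "solver_correct": None},
]
print(conditional_accuracy(toy))
```

Reporting `n` next to each accuracy matters here: with 30 to 50 items per group, a single contamination bucket can easily hold only a handful of records.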


In [None]:
# import random
# import pandas as pd
# from collections import Counter
# from tqdm.auto import tqdm
# import gc
# import torch


# import re

# # =========================
# # =========================
# def build_contamination_prompt(question: str) -> str:
#     instruction = (
#         "You are an AI assistant. Analyze the following problem.\n"
#         "Tell me if you have seen this problem in your training data.\n"
#         "Reply with strictly ONLY a single digit from {3, 5, 7}:\n"
#         "3 = SEEN the same question before\n"
#         "5 = Seen very similar questions before but not the same question\n"
#         "7 = Not seen a similar question before\n"
#         "Ignore any YES/NO question in the content even if it asks.\n"
#     )
#     return f"{instruction}\nProblem Content: {question.strip()}\n\nReply (3, 5, or 7):"

# # =========================
# # =========================
# def parse_357(s: str):
#     if not s:
#         return None
#     m = re.search(r"\b([357])\b", s.strip())
#     return m.group(1) if m else None


# # =========================
# # =========================
# SGU_TASKS = [
#     "sgu_collatz_aligned", "sgu_collatz_aligned__sgu",
#     "sgu_index_empty_language", "sgu_index_empty_language__sgu",
#     "sgu_index_total_halt", "sgu_index_total_halt__sgu",
#     "sgu_semigroup_wp_amp", "sgu_semigroup_wp_amp__sgu",
#     "tm_generic_halt", "tm_hard_halt",
# ]

# # =========================
# # =========================
# fixed_groups = {
#     "aqua": ["aqua_binary", "aqua_mc_test"],
#     "gsm8k": ["gsm8k_answer_test", "gsm8k_binary"],
#     "p_time": ["p_graph_connectivity", "ptime_arith"],
#     "np_time": ["random_np_rbh_word_avg", "np_subset_sum_avg", "np_subset_sum_worst", "np_rbh_word_worst_3sat"],
#     "sgu": SGU_TASKS,
# }

# def pick_first_existing_dataset(task_list):
#     for k in task_list:
#         if k in datasets:
#             return k
#     return None

# # =========================
# # =========================

# def sample_group_instances(group_name, n_each, seed=0, sgu_sampling="mix_pool"):
#     """
#     Return (samples, task_counts); each sample is a dict {group, task, inst}.
#     """
#     rng = random.Random(seed)

#     if group_name != "sgu":
#         task_key = pick_first_existing_dataset(fixed_groups[group_name])
#         if task_key is None:
#             return [], Counter()
#         ds = datasets[task_key]
#         n = min(n_each, len(ds.instances))
#         idxs = list(range(len(ds.instances)))
#         rng.shuffle(idxs)
#         idxs = idxs[:n]
#         samples = [{"group": group_name, "task": task_key, "inst": ds.instances[i]} for i in idxs]
#         return samples, Counter([task_key]*len(samples))

#     sgu_existing = [k for k in fixed_groups["sgu"] if k in datasets]
#     if not sgu_existing:
#         return [], Counter()

#     if sgu_sampling == "mix_pool":
#         pool = []
#         for tk in sgu_existing:
#             ds = datasets[tk]
#             for inst in ds.instances:
#                 pool.append((tk, inst))
#         rng.shuffle(pool)
#         pool = pool[:min(n_each, len(pool))]
#         samples = [{"group": "sgu", "task": tk, "inst": inst} for (tk, inst) in pool]
#         return samples, Counter([tk for tk, _ in pool])

#     elif sgu_sampling == "uniform_task":
#         per_task_idxs = {}
#         for tk in sgu_existing:
#             idxs = list(range(len(datasets[tk].instances)))
#             rng.shuffle(idxs)
#             per_task_idxs[tk] = idxs  # keep the shuffled indices for this task

#         samples = []
#         counts = Counter()
#         for _ in range(n_each * 20):
#             if len(samples) >= n_each:
#                 break
#             tk = rng.choice(sgu_existing)
#             if not per_task_idxs[tk]:
#                 continue
#             i = per_task_idxs[tk].pop()
#             samples.append({"group": "sgu", "task": tk, "inst": datasets[tk].instances[i]})
#             counts[tk] += 1

#         return samples, counts

#     else:
#         raise ValueError("sgu_sampling must be 'mix_pool' or 'uniform_task'")

# # =========================
# # =========================
# def eval_samples_for_model(mw_local, samples, model_name):
#     """
#     samples: list of dict {group, task, inst}
#     """
#     recs = []
#     if not samples:
#         return recs

#     groups = {s["group"] for s in samples}
#     group_tag = list(groups)[0] if len(groups) == 1 else "mixed"

#     for s in tqdm(samples, desc=f"{model_name}|{group_tag}", leave=False):
#         g = s["group"]
#         tk = s["task"]
#         inst = s["inst"]
#         q = inst.prompt

#         # Task 1: solver
#         solver_resp = generate_strict(mw_local, build_solver_prompt(q), max_tokens=15)
#         pred_yesno = parse_yesno(solver_resp)

#         correct = None
#         if pred_yesno is not None and inst.label01 is not None:
#             correct = int(pred_yesno == expected_yesno_from_label01(inst.label01))

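#         # Task 2: contamination. This cell's prompt asks for 3/5/7, so parse with
#         # parse_357 and map back to the 0/1/2 codes that the downstream filters expect.
#         contam_resp = generate_strict(mw_local, build_contamination_prompt(q), max_tokens=5)
#         contam = {"3": "0", "5": "1", "7": "2"}.get(parse_357(contam_resp))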

#         recs.append({
#             "model": model_name,
#             "group": g,
#             "task": tk,
#             "solver_pred": pred_yesno,
#             "solver_correct": correct,
#             "contam": contam,
#             "solver_raw": solver_resp,
#             "contam_raw": contam_resp,
#         })

#     return recs

# # =========================
# # =========================
# def run_all_models_with_sgu_mixed(model_specs, n_each=50, seed=0, sgu_sampling="mix_pool", use_4bit=False):
#     group_samples = {}
#     group_task_counts = {}

#     for g in fixed_groups.keys():
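#         # str hashes are randomized per process, so derive the seed offset from the bytes instead.
#         group_seed = seed + sum(g.encode("utf-8")) % 100000
#         samples, counts = sample_group_instances(g, n_each=n_each, seed=group_seed, sgu_sampling=sgu_sampling)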
#         group_samples[g] = samples
#         group_task_counts[g] = counts

#     print("\n===== SAMPLING SUMMARY =====")
#     for g in fixed_groups.keys():
#         print(f"\n[{g}] sampled={len(group_samples[g])}")
#         if g == "sgu":
#             print("SGU subtask counts:")
#             for k, v in group_task_counts[g].most_common():
#                 print(f"  - {k}: {v}")

#     all_records = []

#     for spec in model_specs:
#         print("\n" + "#" * 90)
#         print(f"LOADING: {spec['name']} | {spec['model_id']} | 4bit={use_4bit} | SGU_sampling={sgu_sampling}")
#         print("#" * 90)

#         mw_local = HFModelWrapper.load(spec["model_id"], adapter_path=spec.get("adapter"), use_4bit=use_4bit)

#         for g, samples in group_samples.items():
#             if not samples:
#                 continue
#             recs = eval_samples_for_model(mw_local, samples, spec["name"])
#             all_records.extend(recs)

#         # cleanup
#         del mw_local
#         gc.collect()
#         if torch.cuda.is_available():
#             torch.cuda.empty_cache()

#     df = pd.DataFrame(all_records)

#     task_freq = df.groupby(["group","task"]).size().rename("count").reset_index().sort_values(["group","count"], ascending=[True, False])
#     print("\n===== ACTUAL TASK FREQUENCY IN EVAL (all models concatenated; per-model same samples) =====")
#     with pd.option_context("display.max_rows", 500, "display.max_columns", 10, "display.width", 140):
#         print(task_freq.to_string(index=False))

#     return df

# # =========================
# # =========================
# raw_df_mixed = run_all_models_with_sgu_mixed(
#     MODEL_SPECS,
#     n_each=50,
#     seed=123,
#     use_4bit=USE_4BIT
# )

# df_valid = raw_df_mixed[ raw_df_mixed["solver_correct"].notna() & raw_df_mixed["contam"].isin(["0","1","2"]) ].copy()
# # =========================
# # =========================

# import pandas as pd


# df = raw_df_mixed.copy()

# # solver_correct: None/0/1
# # contam: "0"/"1"/"2"/None
# df_valid = df[df["solver_correct"].notna() & df["contam"].isin(["0","1","2"])].copy()
# df_valid["solver_correct"] = df_valid["solver_correct"].astype(int)

# print("TOTAL records:", len(df), "| VALID records:", len(df_valid), "| valid_rate:", len(df_valid)/max(1,len(df)))

# cover = df.groupby(["model","group","task"]).size().rename("total").reset_index()
# used  = df_valid.groupby(["model","group","task"]).size().rename("usable").reset_index()
# coverage = cover.merge(used, on=["model","group","task"], how="left").fillna({"usable":0})
# coverage["usable"] = coverage["usable"].astype(int)
# coverage["usable_rate"] = coverage["usable"] / coverage["total"]

# print("\n===== COVERAGE by model/group/task (valid solver + valid contam) =====")
# with pd.option_context("display.max_rows", 5000, "display.max_columns", 20, "display.width", 220):
#     print(coverage.sort_values(["model","group","task"]).to_string(index=False))

# overall = (
#     df_valid.groupby(["model","contam"])["solver_correct"]
#     .agg(acc="mean", n="count")
#     .reset_index()
#     .sort_values(["model","contam"])
# )

# print("\n===== OVERALL conditional accuracy by self-report (0/1/2) =====")
# with pd.option_context("display.max_rows", 500, "display.max_columns", 10, "display.width", 120):
#     print(overall.to_string(index=False))

# by_group = (
#     df_valid.groupby(["model","group","contam"])["solver_correct"]
#     .agg(acc="mean", n="count")
#     .reset_index()
#     .sort_values(["model","group","contam"])
# )

# print("\n===== BY-GROUP conditional accuracy by self-report (0/1/2) =====")
# with pd.option_context("display.max_rows", 5000, "display.max_columns", 12, "display.width", 160):
#     print(by_group.to_string(index=False))

# by_task = (
#     df_valid.groupby(["model","group","task","contam"])["solver_correct"]
#     .agg(acc="mean", n="count")
#     .reset_index()
#     .sort_values(["model","group","task","contam"])
# )

# print("\n===== BY-TASK conditional accuracy by self-report (0/1/2) =====")
# with pd.option_context("display.max_rows", 20000, "display.max_columns", 14, "display.width", 220):
#     print(by_task.to_string(index=False))

# dist_by_group = (
#     df_valid.groupby(["model","group","contam"]).size()
#     .rename("count")
#     .reset_index()
#     .sort_values(["model","group","contam"])
# )

# print("\n===== BY-GROUP self-report distribution counts (VALID samples) =====")
# with pd.option_context("display.max_rows", 5000, "display.max_columns", 10, "display.width", 140):
#     print(dist_by_group.to_string(index=False))

# dist_by_task = (
#     df_valid.groupby(["model","group","task","contam"]).size()
#     .rename("count")
#     .reset_index()
#     .sort_values(["model","group","task","contam"])
# )

# print("\n===== BY-TASK self-report distribution counts (VALID samples) =====")
# with pd.option_context("display.max_rows", 20000, "display.max_columns", 12, "display.width", 180):
#     print(dist_by_task.to_string(index=False))

# first_model = df["model"].iloc[0]
# df_one = df[df["model"] == first_model].copy()

# sampled_task_counts = (
#     df_one.groupby(["group","task"]).size()
#     .rename("sampled_n")
#     .reset_index()
#     .sort_values(["group","sampled_n"], ascending=[True, False])
# )

# print("\n===== SAMPLED COUNTS per group/task (from one model; represents the shared sampled set) =====")
# with pd.option_context("display.max_rows", 5000, "display.max_columns", 10, "display.width", 140):
#     print(sampled_task_counts.to_string(index=False))

# overall_wide = overall.pivot(index="model", columns="contam", values="acc").reset_index()
# overall_wide = overall_wide.rename(columns={"0":"acc_if_0", "1":"acc_if_1", "2":"acc_if_2"})

# print("\n===== OVERALL (wide) =====")
# with pd.option_context("display.max_rows", 200, "display.max_columns", 20, "display.width", 120):
#     print(overall_wide.to_string(index=False))
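
The analysis code in this notebook repeatedly computes one quantity: solver accuracy conditioned on the model's contamination self-report, restricted to records where both the solver answer and the self-report parsed. The sketch below shows that computation in isolation on a tiny hand-made frame. The function name `conditional_accuracy`, its `valid_codes` parameter, and the toy rows are assumptions made for illustration, not part of the original pipeline; the column names follow the records built by `eval_samples_for_model`.

```python
import pandas as pd

def conditional_accuracy(df, valid_codes=("3", "5", "7")):
    """Per-model solver accuracy conditioned on the self-report code."""
    # Keep only rows where the solver answer was scored AND the self-report parsed.
    valid = df[df["solver_correct"].notna() & df["contam"].isin(list(valid_codes))].copy()
    valid["solver_correct"] = valid["solver_correct"].astype(int)

    # Long table: one row per (model, code) with accuracy and support.
    long = (
        valid.groupby(["model", "contam"])["solver_correct"]
        .agg(acc="mean", n="count")
        .reset_index()
    )
    # Wide table: one column per code; codes that never occur show up as NaN.
    wide = (
        long.pivot(index="model", columns="contam", values="acc")
        .reindex(columns=list(valid_codes))
    )
    return long, wide

# Hypothetical toy records (not real results).
toy = pd.DataFrame({
    "model": ["m1"] * 5,
    "contam": ["3", "3", "7", None, "7"],
    "solver_correct": [1, 0, 1, 1, None],
})
toy_long, toy_wide = conditional_accuracy(toy)
print(toy_long.to_string(index=False))
print(toy_wide)
```

Reindexing the wide table on the full code list is a design choice of this sketch: it keeps the column layout fixed even when a model never emits one of the codes, which makes tables from different runs easier to compare.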



In [None]:
# =========================
# FULL SCRIPT: SGU-mixed evaluation + contamination self-report as {3,5,7}
# - mixes the subtasks of every group in MIX_GROUPS (sgu, p_time, np_time) into one pool
#   and samples n_each questions total per group
# - for the remaining groups (aqua, gsm8k): picks the first existing task and samples n_each
# - runs all models, collects per-question joint records:
#     (solver YES/NO correctness) + (contam self-report 3/5/7)
# - prints: sampling summary, task frequency, coverage, conditional accuracies,
#           distribution counts, and wide overall table
# - DOES NOT SAVE anything
#
# PREREQS expected in your notebook:
#   - datasets: Dict[str, TaskDataset] with TaskDataset.instances of TaskInstance(prompt,label01,...)
#   - HFModelWrapper with .load() and mw.tokenizer / mw.model
#   - MODEL_SPECS list
#   - USE_4BIT boolean
#   - generate_strict(...) already defined OR define it below (recommended)
#   - build_solver_prompt(...) and parse_yesno(...) already defined OR define below
#   If you already defined generate_strict/build_solver_prompt/parse_yesno earlier,
#   you can keep them; this script defines them for completeness.
# =========================

import random
import re
import gc
import torch
import pandas as pd
from collections import Counter
from tqdm.auto import tqdm
from transformers import GenerationConfig, StoppingCriteria, StoppingCriteriaList

# =========================
# 0) Stop + strict generation (first-line cut)
# =========================
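class StopOnTokens(StoppingCriteria):
    # Stops as soon as the newest token equals the LAST token id of any stop string.
    # For multi-token stop strings (e.g. "Question:") only that final token is matched,
    # so this is a coarse early stop; generate_strict() still cuts at the first newline.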
    def __init__(self, tokenizer, stop_strings):
        self.stop_token_ids = []
        for s in stop_strings:
            ids = tokenizer.encode(s, add_special_tokens=False)
            if ids:
                self.stop_token_ids.append(ids[-1])

    def __call__(self, input_ids, scores, **kwargs):
        if input_ids.shape[1] == 0:
            return False
        return input_ids[0, -1].item() in self.stop_token_ids

def generate_strict(mw, prompt: str, max_tokens: int = 10):
    tok, model = mw.tokenizer, mw.model
    inputs = tok(prompt, return_tensors="pt", truncation=True)
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    stopper = StopOnTokens(tok, stop_strings=["\n", "\r", "Question:", "Options:"])
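    # Greedy decoding: temperature only applies when do_sample=True, so it is omitted.
    # Some tokenizers ship without a pad token; fall back to EOS in that case.
    pad_id = tok.pad_token_id if tok.pad_token_id is not None else tok.eos_token_id
    gen_config = GenerationConfig(
        do_sample=False,
        max_new_tokens=max_tokens,
        pad_token_id=pad_id,
        eos_token_id=tok.eos_token_id,
    )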

    with torch.no_grad():
        out = model.generate(
            **inputs,
            generation_config=gen_config,
            stopping_criteria=StoppingCriteriaList([stopper]),
        )

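    # Decode only the newly generated tokens. Slicing the decoded string by len(prompt)
    # breaks whenever decoding does not reproduce the prompt exactly, and the old fallback
    # then handed the prompt itself (which contains "Yes/No") to the parsers.
    prompt_len = inputs["input_ids"].shape[1]
    resp = tok.decode(out[0][prompt_len:], skip_special_tokens=True)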
    return resp.split("\n")[0].strip()

# =========================
# 1) Prompts + parsers
# =========================
def build_solver_prompt(question: str) -> str:
    instruction = (
        "You are a calculator. Output strictly ONLY the final numerical answer or Yes/No.\n"
        "Do NOT explain. Do NOT show work.\n"
    )
    return f"{instruction}\nQuestion: {question.strip()}\nAnswer:"

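def parse_yesno(s: str):
    if not s:
        return None
    # Match whole words only, so "NOT", "NONE" or "UNKNOWN" are not read as "NO".
    # If both words appear, the first one wins.
    m = re.search(r"\b(YES|NO)\b", s.upper())
    return m.group(1) if m else None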

def expected_yesno_from_label01(label01: int):
    return "YES" if int(label01) == 1 else "NO"

# ---- contamination: use 3/5/7 ----
def build_contamination_prompt(question: str) -> str:
    instruction = (
        "You are an AI assistant. Analyze the following problem.\n"
        "Tell me if you have seen this problem during your training data.\n"
        "Reply with strictly ONLY a single digit from {3, 5, 7}:\n"
        "3 = SEEN the same question before\n"
        "5 = Seen very similar questions before but not the same question\n"
        "7 = Not seen similar question before\n"
        "Ignore any YES/NO question in the content even if it asks.\n"
    )
    return f"{instruction}\nProblem Content: {question.strip()}\n\nReply (3, 5, or 7):"

def parse_357(s: str):
    if not s:
        return None
    m = re.search(r"\b([357])\b", s.strip())
    return m.group(1) if m else None

# =========================
# 2) Groups + SGU subtasks (SGU combined)
# =========================
SGU_TASKS = [
    "sgu_collatz_aligned", "sgu_collatz_aligned__sgu",
    "sgu_index_empty_language", "sgu_index_empty_language__sgu",
    "sgu_index_total_halt", "sgu_index_total_halt__sgu",
    "sgu_semigroup_wp_amp", "sgu_semigroup_wp_amp__sgu",
    "tm_generic_halt", "tm_hard_halt",
]

fixed_groups = {
    "aqua": ["aqua_binary", "aqua_mc_test"],
    "gsm8k": ["gsm8k_answer_test", "gsm8k_binary"],
    "p_time": ["p_graph_connectivity", "ptime_arith"],
    "np_time": ["random_np_rbh_word_avg", "np_subset_sum_avg", "np_subset_sum_worst", "np_rbh_word_worst_3sat"],
    "sgu": SGU_TASKS,
}

def pick_first_existing_dataset(task_list):
    for k in task_list:
        if k in datasets:
            return k
    return None

# =========================
# 3) Sampling
#    - groups outside MIX_GROUPS: pick first existing task and sample n_each w/o replacement
#    - groups in MIX_GROUPS: mix subtasks into one pool and sample n_each total (mix_pool) OR
#           choose task uniformly then sample inside (uniform_task)
# =========================
MIX_GROUPS = {"sgu", "np_time", "p_time"}

def sample_group_instances(group_name, n_each, seed=0, sgu_sampling="mix_pool"):
    rng = random.Random(seed)

    if group_name in MIX_GROUPS:
        existing = [k for k in fixed_groups[group_name] if k in datasets]
        if not existing:
            return [], Counter()

        if sgu_sampling == "mix_pool":
            pool = []
            for tk in existing:
                for inst in datasets[tk].instances:
                    pool.append((tk, inst))
            rng.shuffle(pool)
            pool = pool[:min(n_each, len(pool))]
            samples = [{"group": group_name, "task": tk, "inst": inst} for (tk, inst) in pool]
            return samples, Counter([tk for tk, _ in pool])

        elif sgu_sampling == "uniform_task":
            per_task_idxs = {}
            for tk in existing:
                idxs = list(range(len(datasets[tk].instances)))
                rng.shuffle(idxs)
                per_task_idxs[tk] = idxs

            samples = []
            counts = Counter()
            for _ in range(n_each * 30):
                if len(samples) >= n_each:
                    break
                tk = rng.choice(existing)
                if not per_task_idxs[tk]:
                    continue
                i = per_task_idxs[tk].pop()
                samples.append({"group": group_name, "task": tk, "inst": datasets[tk].instances[i]})
                counts[tk] += 1
            return samples, counts

        else:
            raise ValueError("sgu_sampling must be 'mix_pool' or 'uniform_task'")

    task_key = pick_first_existing_dataset(fixed_groups[group_name])
    if task_key is None:
        return [], Counter()
    ds = datasets[task_key]
    n = min(n_each, len(ds.instances))
    idxs = list(range(len(ds.instances)))
    rng.shuffle(idxs)
    idxs = idxs[:n]
    samples = [{"group": group_name, "task": task_key, "inst": ds.instances[i]} for i in idxs]
    return samples, Counter([task_key] * len(samples))

# =========================
# 4) Eval one model on a sample list
# =========================
def eval_samples_for_model(mw_local, samples, model_name):
    recs = []
    if not samples:
        return recs

    groups = {s["group"] for s in samples}
    group_tag = list(groups)[0] if len(groups) == 1 else "mixed"

    for s in tqdm(samples, desc=f"{model_name}|{group_tag}", leave=False):
        g = s["group"]
        tk = s["task"]
        inst = s["inst"]
        q = inst.prompt

        # solver
        solver_resp = generate_strict(mw_local, build_solver_prompt(q), max_tokens=15)
        pred_yesno = parse_yesno(solver_resp)

        correct = None
        if pred_yesno is not None and inst.label01 is not None:
            correct = int(pred_yesno == expected_yesno_from_label01(inst.label01))

        # contamination (3/5/7)
        contam_resp = generate_strict(mw_local, build_contamination_prompt(q), max_tokens=5)
        contam = parse_357(contam_resp)

        recs.append({
            "model": model_name,
            "group": g,
            "task": tk,
            "solver_pred": pred_yesno,
            "solver_correct": correct,
            "contam": contam,         # "3"/"5"/"7"/None
            "solver_raw": solver_resp,
            "contam_raw": contam_resp,
        })
    return recs

# =========================
# 5) Run all models
# =========================
def run_all_models_with_sgu_mixed(model_specs, n_each=50, seed=123, sgu_sampling="mix_pool", use_4bit=False):
    # sample once (shared across models)
    group_samples = {}
    group_task_counts = {}
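    import zlib  # crc32 is stable across sessions; built-in hash() of a str is salted per process
    for g in fixed_groups.keys():
        samples, counts = sample_group_instances(
            g, n_each=n_each, seed=seed + (zlib.crc32(g.encode("utf-8")) % 100000), sgu_sampling=sgu_sampling
        )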
        group_samples[g] = samples
        group_task_counts[g] = counts

    print("\n===== SAMPLING SUMMARY =====")
    for g in fixed_groups.keys():
        print(f"\n[{g}] sampled={len(group_samples[g])}")
        if g in MIX_GROUPS:
            print(f"{g} subtask counts:")
            for k, v in group_task_counts[g].most_common():
                print(f"  - {k}: {v}")

    all_records = []
    for spec in model_specs:
        print("\n" + "#" * 90)
        print(f"LOADING: {spec['name']} | {spec['model_id']} | 4bit={use_4bit} | SGU_sampling={sgu_sampling}")
        print("#" * 90)

        mw_local = HFModelWrapper.load(spec["model_id"], adapter_path=spec.get("adapter"), use_4bit=use_4bit)

        for g, samples in group_samples.items():
            if not samples:
                continue
            all_records.extend(eval_samples_for_model(mw_local, samples, spec["name"]))

        del mw_local
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    df = pd.DataFrame(all_records)

    task_freq = (
        df.groupby(["group", "task"]).size()
        .rename("count")
        .reset_index()
        .sort_values(["group", "count"], ascending=[True, False])
    )
    print("\n===== ACTUAL TASK FREQUENCY IN EVAL (all models concatenated; per-model same samples) =====")
    with pd.option_context("display.max_rows", 500, "display.max_columns", 10, "display.width", 140):
        print(task_freq.to_string(index=False))

    return df

# =========================
# 6) Aggregations (conditional accuracies + distributions)
# =========================
def analyze(df):
    df_valid = df[df["solver_correct"].notna() & df["contam"].isin(["3", "5", "7"])].copy()
    df_valid["solver_correct"] = df_valid["solver_correct"].astype(int)

    print(f"\nTOTAL records: {len(df)} | VALID records: {len(df_valid)} | valid_rate: {len(df_valid)/max(1,len(df)):.4f}")

    # coverage by model/group/task
    cover = df.groupby(["model","group","task"]).size().rename("total").reset_index()
    used  = df_valid.groupby(["model","group","task"]).size().rename("usable").reset_index()
    coverage = cover.merge(used, on=["model","group","task"], how="left").fillna({"usable":0})
    coverage["usable"] = coverage["usable"].astype(int)
    coverage["usable_rate"] = coverage["usable"] / coverage["total"]

    print("\n===== COVERAGE by model/group/task (valid solver + valid contam {3,5,7}) =====")
    with pd.option_context("display.max_rows", 5000, "display.max_columns", 20, "display.width", 220):
        print(coverage.sort_values(["model","group","task"]).to_string(index=False))

    # overall conditional accuracy
    overall = (
        df_valid.groupby(["model","contam"])["solver_correct"]
        .agg(acc="mean", n="count")
        .reset_index()
        .sort_values(["model","contam"])
    )
    print("\n===== OVERALL conditional accuracy by self-report (3/5/7) =====")
    with pd.option_context("display.max_rows", 500, "display.max_columns", 10, "display.width", 120):
        print(overall.to_string(index=False))

    # by-group conditional accuracy
    by_group = (
        df_valid.groupby(["model","group","contam"])["solver_correct"]
        .agg(acc="mean", n="count")
        .reset_index()
        .sort_values(["model","group","contam"])
    )
    print("\n===== BY-GROUP conditional accuracy by self-report (3/5/7) =====")
    with pd.option_context("display.max_rows", 5000, "display.max_columns", 12, "display.width", 160):
        print(by_group.to_string(index=False))

    # by-task conditional accuracy
    by_task = (
        df_valid.groupby(["model","group","task","contam"])["solver_correct"]
        .agg(acc="mean", n="count")
        .reset_index()
        .sort_values(["model","group","task","contam"])
    )
    print("\n===== BY-TASK conditional accuracy by self-report (3/5/7) =====")
    with pd.option_context("display.max_rows", 20000, "display.max_columns", 14, "display.width", 220):
        print(by_task.to_string(index=False))

    # distributions (counts) within VALID
    dist_by_group = (
        df_valid.groupby(["model","group","contam"]).size()
        .rename("count")
        .reset_index()
        .sort_values(["model","group","contam"])
    )
    print("\n===== BY-GROUP self-report distribution counts (VALID samples) =====")
    with pd.option_context("display.max_rows", 5000, "display.max_columns", 10, "display.width", 140):
        print(dist_by_group.to_string(index=False))

    dist_by_task = (
        df_valid.groupby(["model","group","task","contam"]).size()
        .rename("count")
        .reset_index()
        .sort_values(["model","group","task","contam"])
    )
    print("\n===== BY-TASK self-report distribution counts (VALID samples) =====")
    with pd.option_context("display.max_rows", 20000, "display.max_columns", 12, "display.width", 180):
        print(dist_by_task.to_string(index=False))

    # sampled counts (from one model) to show SGU mix distribution
    first_model = df["model"].iloc[0]
    df_one = df[df["model"] == first_model].copy()
    sampled_task_counts = (
        df_one.groupby(["group","task"]).size()
        .rename("sampled_n")
        .reset_index()
        .sort_values(["group","sampled_n"], ascending=[True, False])
    )
    print("\n===== SAMPLED COUNTS per group/task (from one model; represents shared sampled set) =====")
    with pd.option_context("display.max_rows", 5000, "display.max_columns", 10, "display.width", 140):
        print(sampled_task_counts.to_string(index=False))

    # wide summary
    overall_wide = overall.pivot(index="model", columns="contam", values="acc").reset_index()
    overall_wide = overall_wide.rename(columns={"3":"acc_if_3", "5":"acc_if_5", "7":"acc_if_7"})
    print("\n===== OVERALL (wide) =====")
    with pd.option_context("display.max_rows", 200, "display.max_columns", 20, "display.width", 140):
        print(overall_wide.to_string(index=False))

    return df_valid, coverage, overall, by_group, by_task, dist_by_group, dist_by_task, sampled_task_counts, overall_wide

# =========================
# 7) RUN (edit params here)
# =========================
raw_df_mixed_357 = run_all_models_with_sgu_mixed(
    MODEL_SPECS,
    n_each=300,
    seed=123,
    sgu_sampling="mix_pool",   # or "uniform_task"
    use_4bit=USE_4BIT
)

df_valid_357, coverage_357, overall_357, by_group_357, by_task_357, dist_by_group_357, dist_by_task_357, sampled_task_counts_357, overall_wide_357 = analyze(raw_df_mixed_357)
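
The run above samples each group once and reuses that sample for every model, with a per-group seed derived from the base seed and the group name. For the sampled set to be identical across notebook sessions, that derivation has to be process-independent; Python salts the built-in `hash()` of strings per process unless `PYTHONHASHSEED` is fixed. A minimal sketch of one process-independent derivation, using `zlib.crc32`, is shown below. The helper names `stable_group_seed` and `sample_indices` are hypothetical and exist only for this illustration.

```python
import random
import zlib

def stable_group_seed(base_seed: int, group_name: str, modulus: int = 100000) -> int:
    """Derive a per-group seed that does not depend on PYTHONHASHSEED."""
    return base_seed + (zlib.crc32(group_name.encode("utf-8")) % modulus)

def sample_indices(n_items: int, n_each: int, seed: int):
    """Shuffle indices with a private RNG and keep the first n_each (no replacement)."""
    rng = random.Random(seed)
    idxs = list(range(n_items))
    rng.shuffle(idxs)
    return idxs[:min(n_each, n_items)]

# Same base seed + same group name -> same seed -> same sample, in any session.
seed_sgu = stable_group_seed(123, "sgu")
first = sample_indices(1000, 50, seed_sgu)
second = sample_indices(1000, 50, seed_sgu)
print(first == second)
```

Using a private `random.Random(seed)` instance, as the sampling function in this notebook already does, keeps the draw independent of any other code that touches the global random state.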


In [None]:
# =========================
# FULL SCRIPT (UPDATED + SAVE):
# - contamination self-report as {3,5,7}
# - SGU / P-time / NP-time ALL do multi-subtask mixed sampling (mix_pool or uniform_task)
# - sampling summary prints subtask breakdown for SGU, P-time, NP-time (sampled vs total)
# - runs all models, collects joint records, prints all aggregations
# - SAVES outputs to CSV (raw + valid + summary tables) to your chosen folder
# =========================

import os
import random
import re
import gc
import torch
import pandas as pd
from collections import Counter
from tqdm.auto import tqdm
from transformers import GenerationConfig, StoppingCriteria, StoppingCriteriaList

# =========================
# 0) Stop + strict generation (first-line cut)
# =========================
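class StopOnTokens(StoppingCriteria):
    # Stops as soon as the newest token equals the LAST token id of any stop string.
    # For multi-token stop strings (e.g. "Question:") only that final token is matched,
    # so this is a coarse early stop; generate_strict() still cuts at the first newline.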
    def __init__(self, tokenizer, stop_strings):
        self.stop_token_ids = []
        for s in stop_strings:
            ids = tokenizer.encode(s, add_special_tokens=False)
            if ids:
                self.stop_token_ids.append(ids[-1])

    def __call__(self, input_ids, scores, **kwargs):
        if input_ids.shape[1] == 0:
            return False
        return input_ids[0, -1].item() in self.stop_token_ids

def generate_strict(mw, prompt: str, max_tokens: int = 10):
    tok, model = mw.tokenizer, mw.model
    inputs = tok(prompt, return_tensors="pt", truncation=True)
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    stopper = StopOnTokens(tok, stop_strings=["\n", "\r", "Question:", "Options:"])
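    # Greedy decoding: temperature only applies when do_sample=True, so it is omitted.
    # Some tokenizers ship without a pad token; fall back to EOS in that case.
    pad_id = tok.pad_token_id if tok.pad_token_id is not None else tok.eos_token_id
    gen_config = GenerationConfig(
        do_sample=False,
        max_new_tokens=max_tokens,
        pad_token_id=pad_id,
        eos_token_id=tok.eos_token_id,
    )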

    with torch.no_grad():
        out = model.generate(
            **inputs,
            generation_config=gen_config,
            stopping_criteria=StoppingCriteriaList([stopper]),
        )

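    # Decode only the newly generated tokens. Slicing the decoded string by len(prompt)
    # breaks whenever decoding does not reproduce the prompt exactly, and the old fallback
    # then handed the prompt itself (which contains "Yes/No") to the parsers.
    prompt_len = inputs["input_ids"].shape[1]
    resp = tok.decode(out[0][prompt_len:], skip_special_tokens=True)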
    return resp.split("\n")[0].strip()

# =========================
# 1) Prompts + parsers
# =========================
def build_solver_prompt(question: str) -> str:
    instruction = (
        "You are a calculator. Output strictly ONLY the final numerical answer or Yes/No.\n"
        "Do NOT explain. Do NOT show work.\n"
    )
    return f"{instruction}\nQuestion: {question.strip()}\nAnswer:"

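def parse_yesno(s: str):
    if not s:
        return None
    # Match whole words only, so "NOT", "NONE" or "UNKNOWN" are not read as "NO".
    # If both words appear, the first one wins.
    m = re.search(r"\b(YES|NO)\b", s.upper())
    return m.group(1) if m else None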

def expected_yesno_from_label01(label01: int):
    return "YES" if int(label01) == 1 else "NO"

# ---- contamination: 3/5/7 ----
def build_contamination_prompt(question: str) -> str:
    instruction = (
        "You are an AI assistant. Analyze the following problem.\n"
        "Tell me if you have seen this problem during your training data.\n"
        "Reply with strictly ONLY a single digit from {3, 5, 7}:\n"
        "3 = SEEN the same question before\n"
        "5 = Seen very similar questions before but not the same question\n"
        "7 = Not seen similar question before\n"
        "Ignore any YES/NO question in the content even if it asks.\n"
    )
    return f"{instruction}\nProblem Content: {question.strip()}\n\nReply (3, 5, or 7):"

def parse_357(s: str):
    if not s:
        return None
    m = re.search(r"\b([357])\b", s.strip())
    return m.group(1) if m else None

# =========================
# 2) Groups + subtasks
# =========================
SGU_TASKS = [
    "sgu_collatz_aligned", "sgu_collatz_aligned__sgu",
    "sgu_index_empty_language", "sgu_index_empty_language__sgu",
    "sgu_index_total_halt", "sgu_index_total_halt__sgu",
    "sgu_semigroup_wp_amp", "sgu_semigroup_wp_amp__sgu",
    "tm_generic_halt", "tm_hard_halt",
]

fixed_groups = {
    "aqua": ["aqua_binary", "aqua_mc_test"],
    "gsm8k": ["gsm8k_answer_test", "gsm8k_binary"],
    "p_time": ["p_graph_connectivity", "ptime_arith"],
    "np_time": ["random_np_rbh_word_avg", "np_subset_sum_avg", "np_subset_sum_worst", "np_rbh_word_worst_3sat"],
    "sgu": SGU_TASKS,
}

def pick_first_existing_dataset(task_list):
    for k in task_list:
        if k in datasets:
            return k
    return None

# ✅ These groups do multi-subtask mixed sampling
MIX_GROUPS = {"sgu", "np_time", "p_time"}

# =========================
# 3) Sampling (mixed for SGU/P/NP)
# =========================
def sample_group_instances(group_name, n_each, seed=0, mix_sampling="mix_pool"):
    rng = random.Random(seed)

    # ---- mixed groups: sgu / p_time / np_time ----
    if group_name in MIX_GROUPS:
        existing = [k for k in fixed_groups[group_name] if k in datasets]
        if not existing:
            return [], Counter()

        if mix_sampling == "mix_pool":
            pool = []
            for tk in existing:
                for inst in datasets[tk].instances:
                    pool.append((tk, inst))
            rng.shuffle(pool)
            pool = pool[:min(n_each, len(pool))]
            samples = [{"group": group_name, "task": tk, "inst": inst} for (tk, inst) in pool]
            return samples, Counter([tk for tk, _ in pool])

        elif mix_sampling == "uniform_task":
            per_task_idxs = {}
            for tk in existing:
                idxs = list(range(len(datasets[tk].instances)))
                rng.shuffle(idxs)
                per_task_idxs[tk] = idxs

            samples = []
            counts = Counter()
            for _ in range(n_each * 50):
                if len(samples) >= n_each:
                    break
                tk = rng.choice(existing)
                if not per_task_idxs[tk]:
                    continue
                i = per_task_idxs[tk].pop()
                samples.append({"group": group_name, "task": tk, "inst": datasets[tk].instances[i]})
                counts[tk] += 1
            return samples, counts

        else:
            raise ValueError("mix_sampling must be 'mix_pool' or 'uniform_task'")

    # ---- other groups: pick first existing task ----
    task_key = pick_first_existing_dataset(fixed_groups[group_name])
    if task_key is None:
        return [], Counter()
    ds = datasets[task_key]
    n = min(n_each, len(ds.instances))
    idxs = list(range(len(ds.instances)))
    rng.shuffle(idxs)
    idxs = idxs[:n]
    samples = [{"group": group_name, "task": task_key, "inst": ds.instances[i]} for i in idxs]
    return samples, Counter([task_key] * len(samples))

def _print_subtask_breakdown(group_name, sampled_counts: Counter):
    existing = [k for k in fixed_groups[group_name] if k in datasets]
    if not existing:
        print(f"\n[{group_name}] MISSING (no tasks found)")
        return

    rows = []
    for tk in existing:
        rows.append({
            "task": tk,
            "sampled": int(sampled_counts.get(tk, 0)),
            "total_available": len(datasets[tk].instances),
        })
    dfb = pd.DataFrame(rows).sort_values(["sampled", "total_available"], ascending=[False, False])

    print(f"\n[{group_name}] subtask breakdown (sampled vs total)")
    with pd.option_context("display.max_rows", 500, "display.max_columns", 10, "display.width", 160):
        print(dfb.to_string(index=False))

# =========================
# 4) Eval one model on a sample list
# =========================
def eval_samples_for_model(mw_local, samples, model_name):
    recs = []
    if not samples:
        return recs

    groups = {s["group"] for s in samples}
    group_tag = list(groups)[0] if len(groups) == 1 else "mixed"

    for s in tqdm(samples, desc=f"{model_name}|{group_tag}", leave=False):
        g = s["group"]
        tk = s["task"]
        inst = s["inst"]
        q = inst.prompt

        solver_resp = generate_strict(mw_local, build_solver_prompt(q), max_tokens=15)
        pred_yesno = parse_yesno(solver_resp)

        correct = None
        if pred_yesno is not None and inst.label01 is not None:
            correct = int(pred_yesno == expected_yesno_from_label01(inst.label01))

        contam_resp = generate_strict(mw_local, build_contamination_prompt(q), max_tokens=5)
        contam = parse_357(contam_resp)

        recs.append({
            "model": model_name,
            "group": g,
            "task": tk,
            "solver_pred": pred_yesno,
            "solver_correct": correct,
            "contam": contam,         # "3"/"5"/"7"/None
            "solver_raw": solver_resp,
            "contam_raw": contam_resp,
        })
    return recs

# =========================
# 5) Run all models + SAVE raw
# =========================
def run_all_models_mixed(model_specs, n_each=50, seed=123, mix_sampling="mix_pool", use_4bit=False):
    group_samples = {}
    group_task_counts = {}

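    import zlib  # crc32 is stable across sessions; built-in hash() of a str is salted per process
    for g in fixed_groups.keys():
        samples, counts = sample_group_instances(
            g, n_each=n_each, seed=seed + (zlib.crc32(g.encode("utf-8")) % 100000), mix_sampling=mix_sampling
        )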
        group_samples[g] = samples
        group_task_counts[g] = counts

    print("\n===== SAMPLING SUMMARY =====")
    for g in fixed_groups.keys():
        print(f"\n[{g}] sampled={len(group_samples[g])}")
        if g in MIX_GROUPS:
            _print_subtask_breakdown(g, group_task_counts[g])

    all_records = []
    for spec in model_specs:
        print("\n" + "#" * 90)
        print(f"LOADING: {spec['name']} | {spec['model_id']} | 4bit={use_4bit} | mix_sampling={mix_sampling}")
        print("#" * 90)

        mw_local = HFModelWrapper.load(spec["model_id"], adapter_path=spec.get("adapter"), use_4bit=use_4bit)

        for g, samples in group_samples.items():
            if not samples:
                continue
            all_records.extend(eval_samples_for_model(mw_local, samples, spec["name"]))

        del mw_local
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    df = pd.DataFrame(all_records)

    task_freq = (
        df.groupby(["group", "task"]).size()
        .rename("count")
        .reset_index()
        .sort_values(["group", "count"], ascending=[True, False])
    )
    print("\n===== ACTUAL TASK FREQUENCY IN EVAL (all models concatenated; per-model same samples) =====")
    with pd.option_context("display.max_rows", 500, "display.max_columns", 10, "display.width", 140):
        print(task_freq.to_string(index=False))

    return df

# =========================
# 6) Aggregations + SAVE tables
# =========================
def analyze_and_save(df, out_dir, prefix):
    os.makedirs(out_dir, exist_ok=True)

    df_valid = df[df["solver_correct"].notna() & df["contam"].isin(["3", "5", "7"])].copy()
    df_valid["solver_correct"] = df_valid["solver_correct"].astype(int)

    print(f"\nTOTAL records: {len(df)} | VALID records: {len(df_valid)} | valid_rate: {len(df_valid)/max(1,len(df)):.4f}")

    cover = df.groupby(["model","group","task"]).size().rename("total").reset_index()
    used  = df_valid.groupby(["model","group","task"]).size().rename("usable").reset_index()
    coverage = cover.merge(used, on=["model","group","task"], how="left").fillna({"usable":0})
    coverage["usable"] = coverage["usable"].astype(int)
    coverage["usable_rate"] = coverage["usable"] / coverage["total"]

    overall = (
        df_valid.groupby(["model","contam"])["solver_correct"]
        .agg(acc="mean", n="count")
        .reset_index()
        .sort_values(["model","contam"])
    )

    by_group = (
        df_valid.groupby(["model","group","contam"])["solver_correct"]
        .agg(acc="mean", n="count")
        .reset_index()
        .sort_values(["model","group","contam"])
    )

    by_task = (
        df_valid.groupby(["model","group","task","contam"])["solver_correct"]
        .agg(acc="mean", n="count")
        .reset_index()
        .sort_values(["model","group","task","contam"])
    )

    dist_by_group = (
        df_valid.groupby(["model","group","contam"]).size()
        .rename("count")
        .reset_index()
        .sort_values(["model","group","contam"])
    )

    dist_by_task = (
        df_valid.groupby(["model","group","task","contam"]).size()
        .rename("count")
        .reset_index()
        .sort_values(["model","group","task","contam"])
    )

    task_freq = (
        df.groupby(["group", "task"]).size()
        .rename("count")
        .reset_index()
        .sort_values(["group", "count"], ascending=[True, False])
    )

    overall_wide = overall.pivot(index="model", columns="contam", values="acc").reset_index()
    overall_wide = overall_wide.rename(columns={"3":"acc_if_3", "5":"acc_if_5", "7":"acc_if_7"})

    # ---------- print ----------
    print("\n===== COVERAGE by model/group/task =====")
    with pd.option_context("display.max_rows", 5000, "display.max_columns", 20, "display.width", 220):
        print(coverage.sort_values(["model","group","task"]).to_string(index=False))

    print("\n===== OVERALL conditional accuracy (3/5/7) =====")
    with pd.option_context("display.max_rows", 500, "display.max_columns", 10, "display.width", 120):
        print(overall.to_string(index=False))

    print("\n===== OVERALL (wide) =====")
    with pd.option_context("display.max_rows", 200, "display.max_columns", 20, "display.width", 140):
        print(overall_wide.to_string(index=False))

    # ---------- save ----------
    paths = {}
    def _save(df_, name):
        path = os.path.join(out_dir, f"{prefix}_{name}.csv")
        df_.to_csv(path, index=False)
        paths[name] = path

    _save(df, "raw_records")
    _save(df_valid, "valid_records")
    _save(task_freq, "task_frequency")
    _save(coverage, "coverage")
    _save(overall, "overall_long")
    _save(overall_wide, "overall_wide")
    _save(by_group, "by_group_acc")
    _save(by_task, "by_task_acc")
    _save(dist_by_group, "by_group_dist")
    _save(dist_by_task, "by_task_dist")

    print("\n===== SAVED FILES =====")
    for k, v in paths.items():
        print(f"{k}: {v}")

    return df_valid, coverage, overall, by_group, by_task, dist_by_group, dist_by_task, task_freq, overall_wide, paths

# =========================
# 7) RUN (edit params here)
# =========================


import os
import pandas as pd

# ----------------------------
# Helper: save raw records and summary tables as CSV
# ----------------------------
def save_raw_and_tables(
    df_raw: pd.DataFrame,
    out_dir: str,
    prefix: str,
    extra_tables: dict | None = None
):
    os.makedirs(out_dir, exist_ok=True)
    paths = {}

    raw_path = os.path.join(out_dir, f"{prefix}_RAW_RECORDS.csv")
    df_raw.to_csv(raw_path, index=False)
    paths["RAW_RECORDS"] = raw_path

    cols_keep = [c for c in [
        "model","group","task",
        "solver_pred","solver_correct","solver_raw",
        "contam","contam_raw"
    ] if c in df_raw.columns]
    raw_min_path = os.path.join(out_dir, f"{prefix}_RAW_MIN.csv")
    df_raw[cols_keep].to_csv(raw_min_path, index=False)
    paths["RAW_MIN"] = raw_min_path

    if extra_tables:
        for name, df_tab in extra_tables.items():
            if df_tab is None:
                continue
            p = os.path.join(out_dir, f"{prefix}_{name}.csv")
            df_tab.to_csv(p, index=False)
            paths[name] = p

    print("\n===== SAVED FILES =====")
    for k, v in paths.items():
        print(f"{k}: {v}")
    return paths


# ----------------------------
# Run the mixed 3/5/7 evaluation and save outputs
# ----------------------------
OUT_DIR = os.environ.get("EVAL_OUT_DIR", "/content/drive/MyDrive/complexity7/eval_outputs")
PREFIX = "mixed357_n300_seed123_mixpool"

raw_df_mixed_357 = run_all_models_mixed(
    MODEL_SPECS,
    n_each=300,
    seed=123,
    mix_sampling="mix_pool",   # or "uniform_task"
    use_4bit=USE_4BIT
)

df_valid_357, coverage_357, overall_357, by_group_357, by_task_357, dist_by_group_357, dist_by_task_357, sampled_task_counts_357, overall_wide_357 = analyze(raw_df_mixed_357)

task_freq_357 = (
    raw_df_mixed_357.groupby(["group","task"]).size()
    .rename("count").reset_index()
    .sort_values(["group","count"], ascending=[True, False])
)

paths = save_raw_and_tables(
    df_raw=raw_df_mixed_357,
    out_dir=OUT_DIR,
    prefix=PREFIX,
    extra_tables={
        "VALID_RECORDS": df_valid_357,
        "task_frequency": task_freq_357,
        "coverage": coverage_357,
        "overall_long": overall_357,
        "overall_wide": overall_wide_357,
        "by_group_acc": by_group_357,
        "by_task_acc": by_task_357,
        "by_group_dist": dist_by_group_357,
        "by_task_dist": dist_by_task_357,
        "sampled_task_counts": sampled_task_counts_357,
    }
)
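
Each group is sampled with its own seed, derived from the base `seed` and the group name. Python salts `hash()` for `str` values per process unless `PYTHONHASHSEED` is fixed, so an offset built from `hash()` can differ between sessions even when `seed=123` stays the same. A minimal sketch of a process-independent alternative, assuming a CRC32 checksum is acceptable as the offset; `stable_group_seed` is a hypothetical helper name, not something defined elsewhere in this notebook:

```python
import zlib

def stable_group_seed(base_seed: int, group: str, modulus: int = 100_000) -> int:
    """Return base_seed plus a group-specific offset that is identical in every process."""
    # zlib.crc32 is a deterministic checksum of the bytes, unlike the salted str hash.
    offset = zlib.crc32(group.encode("utf-8")) % modulus
    return base_seed + offset

# Group names follow the tables in this notebook; the resulting seeds are illustrative.
group_seeds = {g: stable_group_seed(123, g) for g in ["gsm8k", "aqua", "p_time", "np_time", "sgu"]}
for name, s in group_seeds.items():
    print(f"{name}: {s}")
```

Running the snippet in two separate Python sessions should print the same seeds, which is the property the sampling step needs for the saved CSVs to be reproducible.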



In [None]:
import pandas as pd
import numpy as np

df = raw_df_mixed_357.copy()

def agg_acc(g: pd.DataFrame) -> pd.Series:
    total = len(g)
    fmt_n = int(g["solver_correct"].notna().sum())
    acc_on_fmt = float(g.loc[g["solver_correct"].notna(), "solver_correct"].mean()) if fmt_n > 0 else np.nan
    acc_overall = float(g["solver_correct"].fillna(0).mean()) if total > 0 else np.nan
    fmt_rate = fmt_n / total if total > 0 else 0.0
    return pd.Series({
        "n_total": total,
        "n_fmt": fmt_n,
        "fmt_rate": fmt_rate,
        "acc_on_fmt": acc_on_fmt,
        "acc_overall": acc_overall,
    })

group_acc = (
    df.groupby(["group"])
      .apply(agg_acc)
      .reset_index()
      .sort_values(["group"])
)

print("\n===== ACCURACY BY GROUP =====")
with pd.option_context("display.max_rows", 200, "display.max_columns", 50, "display.width", 160):
    print(group_acc.to_string(index=False))

task_acc = (
    df.groupby(["group","task"])
      .apply(agg_acc)
      .reset_index()
      .sort_values(["group","acc_overall"], ascending=[True, False])
)

print("\n===== ACCURACY BY GROUP + TASK =====")
with pd.option_context("display.max_rows", 2000, "display.max_columns", 50, "display.width", 200):
    print(task_acc.to_string(index=False))

model_task_acc = (
    df.groupby(["model","group","task"])
      .apply(agg_acc)
      .reset_index()
      .sort_values(["model","group","acc_overall"], ascending=[True, True, False])
)

print("\n===== (OPTIONAL) ACCURACY BY MODEL + GROUP + TASK =====")
with pd.option_context("display.max_rows", 5000, "display.max_columns", 50, "display.width", 220):
    print(model_task_acc.to_string(index=False))

task_only_acc = (
    df.groupby(["task"])
      .apply(agg_acc)
      .reset_index()
      .sort_values(["acc_overall"], ascending=False)
)

print("\n===== ACCURACY BY TASK (overall, across all groups/models) =====")
with pd.option_context("display.max_rows", 2000, "display.max_columns", 50, "display.width", 200):
    print(task_only_acc.to_string(index=False))
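
The tables above report two accuracies. `acc_on_fmt` averages only over items whose answer could be parsed (`solver_correct` is not NaN), while `acc_overall` counts unparseable answers as wrong, so `acc_overall = fmt_rate * acc_on_fmt`. A toy illustration in plain Python, with made-up values and `None` standing in for NaN:

```python
# solver_correct per item: 1 = right, 0 = wrong, None = output could not be parsed
solver_correct = [1, 0, 1, None]

parsed = [c for c in solver_correct if c is not None]
n_total = len(solver_correct)
n_fmt = len(parsed)
fmt_rate = n_fmt / n_total
acc_on_fmt = sum(parsed) / n_fmt      # accuracy among parseable answers only
acc_overall = sum(parsed) / n_total   # unparseable answers count as wrong

print(f"fmt_rate={fmt_rate:.2f} acc_on_fmt={acc_on_fmt:.3f} acc_overall={acc_overall:.3f}")
```

Here two of the three parseable answers are right, so `acc_on_fmt` is 2/3, while `acc_overall` is 2/4. A large gap between the two columns points to a formatting problem rather than a reasoning problem.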


In [None]:
import numpy as np
import pandas as pd

df = raw_df_mixed_357.copy()

d = df[df["solver_correct"].notna()].copy()
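d["y_pred"] = d["solver_pred"].astype(str).str.upper().map({"YES": 1, "NO": 0})
d = d[d["y_pred"].notna()].copy()  # drop predictions that are neither YES nor NO before casting
d["y_pred"] = d["y_pred"].astype(int)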
d["correct"] = d["solver_correct"].astype(int)
d["y_true"] = np.where(d["correct"] == 1, d["y_pred"], 1 - d["y_pred"])

def metrics(g):
    y = g["y_true"].to_numpy()
    p = g["y_pred"].to_numpy()
    pos = y.mean()
    tpr = ((p==1) & (y==1)).sum() / max(1, (y==1).sum())
    tnr = ((p==0) & (y==0)).sum() / max(1, (y==0).sum())
    return pd.Series({"n": len(g), "pos_rate": pos, "balanced_acc": 0.5*(tpr+tnr)})

out = (d.groupby("task").apply(metrics).reset_index()
       .sort_values("balanced_acc", ascending=False))

print(out.to_string(index=False))
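
Balanced accuracy is the mean of the true-positive rate and the true-negative rate, which is why it is reported next to `pos_rate`: on a task where most labels are YES, a solver that always answers YES gets a high plain accuracy but only 0.5 balanced accuracy. A small self-contained sketch with made-up labels; the function mirrors the `max(1, ...)` guard used in this notebook:

```python
def balanced_accuracy(y_true, y_pred):
    """Mean of TPR and TNR for 0/1 labels; an empty class contributes a rate of 0."""
    n_pos = sum(1 for y in y_true if y == 1)
    n_neg = sum(1 for y in y_true if y == 0)
    tp = sum(1 for y, p in zip(y_true, y_pred) if y == 1 and p == 1)
    tn = sum(1 for y, p in zip(y_true, y_pred) if y == 0 and p == 0)
    tpr = tp / max(1, n_pos)
    tnr = tn / max(1, n_neg)
    return 0.5 * (tpr + tnr)

# Toy task: three YES items and one NO item; the solver always answers YES.
y_true = [1, 1, 1, 0]
y_pred = [1, 1, 1, 1]

plain_acc = sum(int(y == p) for y, p in zip(y_true, y_pred)) / len(y_true)
bal_acc = balanced_accuracy(y_true, y_pred)
print(f"plain accuracy = {plain_acc:.2f}, balanced accuracy = {bal_acc:.2f}")
```

One consequence of the guard is worth keeping in mind when reading the tables: if a task happens to contain only one class, the missing class contributes a rate of 0, so the balanced accuracy for that task cannot exceed 0.5.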


In [None]:
import numpy as np
import pandas as pd

df = raw_df_mixed_357.copy()

# ---------- 1) rebuild binary predictions and labels ----------
d = df[df["solver_correct"].notna()].copy()
d["y_pred"] = d["solver_pred"].astype(str).str.upper().map({"YES": 1, "NO": 0})
d = d[d["y_pred"].notna()].copy()
d["y_pred"] = d["y_pred"].astype(int)
d["correct"] = d["solver_correct"].astype(int)
d["y_true"] = np.where(d["correct"] == 1, d["y_pred"], 1 - d["y_pred"]).astype(int)

# ---------- 2) balanced accuracy ----------
def balanced_acc(y_true, y_pred):
    y_true = np.asarray(y_true, dtype=int)
    y_pred = np.asarray(y_pred, dtype=int)
    pos = (y_true == 1)
    neg = (y_true == 0)
    tpr = ((y_pred == 1) & pos).sum() / max(1, pos.sum())
    tnr = ((y_pred == 0) & neg).sum() / max(1, neg.sum())
    return 0.5 * (tpr + tnr)

def task_stat(df_task):
    return balanced_acc(df_task["y_true"].values, df_task["y_pred"].values)

def bootstrap_ci_task(d, task, B=2000, seed=0):
    rng = np.random.default_rng(seed)
    sub = d[d["task"] == task].copy()
    if len(sub) == 0:
        return np.nan, (np.nan, np.nan), 0

    by_model = {m: g for m, g in sub.groupby("model")}
    point = task_stat(sub)

    boots = []
    for _ in range(B):
        parts = []
        for m, g in by_model.items():
            idx = rng.integers(0, len(g), size=len(g))
            parts.append(g.iloc[idx])
        bb = pd.concat(parts, ignore_index=True)
        boots.append(task_stat(bb))
    boots = np.array(boots)
    lo, hi = np.quantile(boots, [0.025, 0.975])
    return point, (lo, hi), len(sub)

def perm_test_diff(d, task_a, task_b, B=20000, seed=0):
    rng = np.random.default_rng(seed)
    a = d[d["task"] == task_a].copy()
    b = d[d["task"] == task_b].copy()
    if len(a) == 0 or len(b) == 0:
        return np.nan, np.nan

    obs = task_stat(a) - task_stat(b)

    models = sorted(set(a["model"]).intersection(set(b["model"])))
    if not models:
        pool = pd.concat([a.assign(_t=0), b.assign(_t=1)], ignore_index=True)
        n_a = len(a)
        diffs = []
        for _ in range(B):
            perm = rng.permutation(len(pool))
            aa = pool.iloc[perm[:n_a]]
            bb = pool.iloc[perm[n_a:]]
            diffs.append(balanced_acc(aa["y_true"], aa["y_pred"]) - balanced_acc(bb["y_true"], bb["y_pred"]))
        diffs = np.array(diffs)
        p = (np.abs(diffs) >= abs(obs)).mean()
        return obs, p

    diffs = []
    for _ in range(B):
        parts_a = []
        parts_b = []
        for m in models:
            am = a[a["model"] == m]
            bm = b[b["model"] == m]
            pool = pd.concat([am, bm], ignore_index=True)
            perm = rng.permutation(len(pool))
            aa = pool.iloc[perm[:len(am)]]
            bb = pool.iloc[perm[len(am):]]
            parts_a.append(aa)
            parts_b.append(bb)
        aa = pd.concat(parts_a, ignore_index=True)
        bb = pd.concat(parts_b, ignore_index=True)
        diffs.append(task_stat(aa) - task_stat(bb))
    diffs = np.array(diffs)
    p = (np.abs(diffs) >= abs(obs)).mean()
    return obs, p

def holm_bonferroni(pvals):
    pvals = np.array(pvals, dtype=float)
    m = len(pvals)
    order = np.argsort(pvals)
    adj = np.empty(m, dtype=float)
    for i, idx in enumerate(order):
        adj[idx] = min(1.0, (m - i) * pvals[idx])
    for i in range(1, m):
        adj[order[i]] = max(adj[order[i]], adj[order[i-1]])
    return adj

tasks = sorted(d["task"].unique().tolist())
rows = []
for t in tasks:
    point, (lo, hi), n = bootstrap_ci_task(d, t, B=3000, seed=123)
    rows.append({"task": t, "n": n, "balanced_acc": point, "ci95_lo": lo, "ci95_hi": hi})
ci_table = pd.DataFrame(rows).sort_values("balanced_acc", ascending=False)

print("\n=== Balanced Accuracy + 95% Bootstrap CI (stratified by model) ===")
with pd.option_context("display.max_rows", 2000, "display.max_columns", 20, "display.width", 160):
    print(ci_table.to_string(index=False))

baseline1 = "gsm8k_binary"
baseline2 = "tm_generic_halt"

comparisons = []
for base in [baseline1, baseline2]:
    if base not in tasks:
        continue
    for t in tasks:
        if t == base:
            continue
        obs, p = perm_test_diff(d, t, base, B=20000, seed=123)
        comparisons.append({"contrast": f"{t} - {base}", "diff_balacc": obs, "p_perm": p})

comp_df = pd.DataFrame(comparisons)
if len(comp_df) > 0:
    comp_df["p_holm"] = holm_bonferroni(comp_df["p_perm"].values)
    comp_df = comp_df.sort_values(["p_perm", "contrast"])

    print("\n=== Permutation test (stratified by model): diff in balanced_acc ===")
    with pd.option_context("display.max_rows", 5000, "display.max_columns", 20, "display.width", 200):
        print(comp_df.to_string(index=False))
else:
    print("\n[WARN] No comparisons run (baseline task missing).")
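
The Holm step-down adjustment behind `p_holm` sorts the raw p-values, multiplies the i-th smallest (0-based) by `m - i`, caps the result at 1, and then enforces that adjusted values never decrease along the sorted order. A worked example in plain Python with made-up p-values; `holm_adjust` is an illustrative name and is meant to follow the same steps as the NumPy version in the cell above:

```python
def holm_adjust(pvals):
    """Holm step-down adjusted p-values, returned in the original order."""
    m = len(pvals)
    order = sorted(range(m), key=lambda i: pvals[i])
    adjusted = [0.0] * m
    running_max = 0.0
    for rank, idx in enumerate(order):
        candidate = min(1.0, (m - rank) * pvals[idx])
        # Monotonicity: a larger raw p-value never gets a smaller adjusted value.
        running_max = max(running_max, candidate)
        adjusted[idx] = running_max
    return adjusted

raw = [0.01, 0.04, 0.03]
adj = holm_adjust(raw)
for p, q in zip(raw, adj):
    print(f"raw={p:.2f} -> holm={q:.2f}")
```

The raw value 0.04 is the largest, so its multiplier is 1, but it is lifted to 0.06 because the second-smallest value (0.03 times 2) already reached 0.06.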

In [None]:
import numpy as np
import pandas as pd

df = raw_df_mixed_357.copy()

d = df[df["solver_correct"].notna()].copy()
d["y_pred"] = d["solver_pred"].astype(str).str.upper().map({"YES": 1, "NO": 0})
d = d[d["y_pred"].notna()].copy()
d["y_pred"] = d["y_pred"].astype(int)
d["correct"] = d["solver_correct"].astype(int)
d["y_true"] = np.where(d["correct"] == 1, d["y_pred"], 1 - d["y_pred"]).astype(int)

def balanced_acc(y_true, y_pred):
    y_true = np.asarray(y_true, dtype=int)
    y_pred = np.asarray(y_pred, dtype=int)
    pos = (y_true == 1)
    neg = (y_true == 0)
    tpr = ((y_pred == 1) & pos).sum() / max(1, pos.sum())
    tnr = ((y_pred == 0) & neg).sum() / max(1, neg.sum())
    return 0.5 * (tpr + tnr)

def map_group_for_table(row):
    if row["task"] == "tm_generic_halt":
        return "tm_generic_halt"
    if row["group"] == "sgu":
        return "sgu_other"
    return row["group"]

d["group_for_table"] = d.apply(map_group_for_table, axis=1)

def agg(g):
    return pd.Series({
        "n": len(g),
        "balanced_acc": balanced_acc(g["y_true"].values, g["y_pred"].values),
        "acc_on_fmt": g["correct"].mean(),
        "pos_rate": g["y_true"].mean(),
    })

stats = (
    d.groupby(["model", "group_for_table"])
     .apply(agg)
     .reset_index()
)

order_cols = ["gsm8k", "aqua", "p_time", "np_time", "sgu_other", "tm_generic_halt"]

bal_wide = stats.pivot(index="model", columns="group_for_table", values="balanced_acc")
n_wide   = stats.pivot(index="model", columns="group_for_table", values="n")
acc_wide = stats.pivot(index="model", columns="group_for_table", values="acc_on_fmt")
pos_wide = stats.pivot(index="model", columns="group_for_table", values="pos_rate")

# Give every table the same columns in the same order; groups that are
# missing from the data become all-NaN columns.
bal_wide = bal_wide.reindex(columns=order_cols)
n_wide   = n_wide.reindex(columns=order_cols)
acc_wide = acc_wide.reindex(columns=order_cols)
pos_wide = pos_wide.reindex(columns=order_cols)

print("\n===== MODEL x GROUP/TASK (Balanced Accuracy) =====")
with pd.option_context("display.max_rows", 200, "display.max_columns", 50, "display.width", 200):
    print(bal_wide.to_string())

print("\n===== MODEL x GROUP/TASK (N used) =====")
with pd.option_context("display.max_rows", 200, "display.max_columns", 50, "display.width", 200):
    print(n_wide.to_string())

print("\n===== (OPTIONAL) MODEL x GROUP/TASK (acc_on_fmt) =====")
with pd.option_context("display.max_rows", 200, "display.max_columns", 50, "display.width", 200):
    print(acc_wide.to_string())

print("\n===== (OPTIONAL) MODEL x GROUP/TASK (pos_rate) =====")
with pd.option_context("display.max_rows", 200, "display.max_columns", 50, "display.width", 200):
    print(pos_wide.to_string())


In [None]:
import numpy as np
import pandas as pd
import math

df = raw_df_mixed_357.copy()

def map_group_for_table(row):
    if row["task"] == "tm_generic_halt":
        return "tm_generic_halt"
    if row["group"] == "sgu":
        return "sgu_other"
    return row["group"]

df["group_for_table"] = df.apply(map_group_for_table, axis=1)

# Number the items BEFORE dropping unparseable rows. Every model is evaluated on the
# same samples in the same order, so the running index identifies the item. Numbering
# after the filter would shift indices and pair different items across models.
df["item_idx"] = df.groupby(["model", "group_for_table", "task"]).cumcount()
df["pair_id"] = df["group_for_table"].astype(str) + "|" + df["task"].astype(str) + "|" + df["item_idx"].astype(str)

d = df[df["solver_correct"].notna()].copy()
d["correct"] = d["solver_correct"].astype(int)

def cochran_q(mat):
    """
    mat: numpy array shape (n_items, k_models) with 0/1
    returns (Q, p)
    """
    n, k = mat.shape
    Tj = mat.sum(axis=0)          # col sums
    Ri = mat.sum(axis=1)          # row sums
    T = Tj.sum()
    denom = (k * T - (Ri**2).sum())
    if denom == 0:
        return np.nan, np.nan
    Q = (k - 1) * (k * (Tj**2).sum() - T**2) / denom
    # chi-square df = k-1
    try:
        from scipy.stats import chi2
        p = chi2.sf(Q, df=k-1)
    except Exception:
        p = np.nan
    return float(Q), float(p)

def binom_cdf(k, n, p=0.5):
    s = 0.0
    for i in range(k + 1):
        s += math.comb(n, i) * (p**i) * ((1-p)**(n-i))
    return s

def mcnemar_exact(a, b):
    """
    a,b: 0/1 arrays of same length
    return (b_count, c_count, p_exact)
    b = a=1,b=0 ; c = a=0,b=1
    """
    a = np.asarray(a, dtype=int)
    b2 = np.asarray(b, dtype=int)
    b01 = ((a == 1) & (b2 == 0)).sum()
    c10 = ((a == 0) & (b2 == 1)).sum()
    n = b01 + c10
    if n == 0:
        return int(b01), int(c10), 1.0
    k = min(b01, c10)
    p = 2.0 * binom_cdf(k, n, 0.5)
    p = min(1.0, p)
    return int(b01), int(c10), float(p)

def holm(pvals):
    pvals = np.asarray(pvals, dtype=float)
    m = len(pvals)
    order = np.argsort(pvals)
    adj = np.empty(m, dtype=float)
    for i, idx in enumerate(order):
        adj[idx] = min(1.0, (m - i) * pvals[idx])
    for i in range(1, m):
        adj[order[i]] = max(adj[order[i]], adj[order[i-1]])
    return adj

targets = ["gsm8k", "aqua", "p_time", "np_time", "sgu_other", "tm_generic_halt"]

all_overall = []
all_pairwise = []

for gname in targets:
    sub = d[d["group_for_table"] == gname].copy()
    if sub.empty:
        continue

    mat = sub.pivot_table(index="pair_id", columns="model", values="correct", aggfunc="first")

    mat = mat.dropna(axis=0, how="any")
    if mat.shape[0] == 0 or mat.shape[1] < 2:
        continue

    models = list(mat.columns)
    X = mat.values.astype(int)

    Q, pQ = cochran_q(X)
    all_overall.append({
        "group_for_table": gname,
        "n_items_paired": mat.shape[0],
        "n_models": mat.shape[1],
        "cochran_Q": Q,
        "p_cochranQ": pQ
    })

    for i in range(len(models)):
        for j in range(i+1, len(models)):
            m1, m2 = models[i], models[j]
            b01, c10, p = mcnemar_exact(mat[m1].values, mat[m2].values)
            all_pairwise.append({
                "group_for_table": gname,
                "model_a": m1,
                "model_b": m2,
                "n_items_paired": mat.shape[0],
                "b(a=1,b=0)": b01,
                "c(a=0,b=1)": c10,
                "p_mcnemar": p
            })

overall_df = pd.DataFrame(all_overall).sort_values("p_cochranQ", na_position="last")
pair_df = pd.DataFrame(all_pairwise)

pair_df["p_holm_within_group"] = np.nan
for gname in pair_df["group_for_table"].unique():
    mask = pair_df["group_for_table"] == gname
    pair_df.loc[mask, "p_holm_within_group"] = holm(pair_df.loc[mask, "p_mcnemar"].values)

pair_df = pair_df.sort_values(["group_for_table", "p_holm_within_group", "p_mcnemar"])

print("\n=== OVERALL: Cochran's Q (is there ANY model difference?) ===")
with pd.option_context("display.max_rows", 200, "display.max_columns", 20, "display.width", 140):
    print(overall_df.to_string(index=False))

print("\n=== PAIRWISE: McNemar exact (paired model differences), Holm-corrected within group ===")
with pd.option_context("display.max_rows", 2000, "display.max_columns", 20, "display.width", 200):
    print(pair_df.to_string(index=False))
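

In [None]:
import numpy as np
import pandas as pd
import math

df = raw_df_mixed_357.copy()

# Number the items per (model, task) BEFORE dropping unparseable rows, so the same
# index refers to the same sampled item for every model.
df["item_idx"] = df.groupby(["model", "task"]).cumcount()
df["pair_id"] = df["task"].astype(str) + "|" + df["item_idx"].astype(str)

d = df[df["solver_correct"].notna()].copy()
d["correct"] = d["solver_correct"].astype(int)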

# --- 3) Cochran’s Q ---
def cochran_q(mat):
    n, k = mat.shape
    Tj = mat.sum(axis=0)          # col sums
    Ri = mat.sum(axis=1)          # row sums
    T = Tj.sum()
    denom = (k * T - (Ri**2).sum())
    if denom == 0:
        return np.nan, np.nan
    Q = (k - 1) * (k * (Tj**2).sum() - T**2) / denom
    try:
        from scipy.stats import chi2
        p = chi2.sf(Q, df=k-1)
    except Exception:
        p = np.nan
    return float(Q), float(p)

# --- 4) McNemar exact + Holm ---
def binom_cdf(k, n, p=0.5):
    s = 0.0
    for i in range(k + 1):
        s += math.comb(n, i) * (p**i) * ((1-p)**(n-i))
    return s

def mcnemar_exact(a, b):
    a = np.asarray(a, dtype=int)
    b2 = np.asarray(b, dtype=int)
    b01 = ((a == 1) & (b2 == 0)).sum()  # a correct, b wrong
    c10 = ((a == 0) & (b2 == 1)).sum()  # a wrong, b correct
    n = b01 + c10
    if n == 0:
        return int(b01), int(c10), 1.0
    k = min(b01, c10)
    p = 2.0 * binom_cdf(k, n, 0.5)
    return int(b01), int(c10), float(min(1.0, p))

def holm(pvals):
    pvals = np.asarray(pvals, dtype=float)
    m = len(pvals)
    order = np.argsort(pvals)
    adj = np.empty(m, dtype=float)
    for i, idx in enumerate(order):
        adj[idx] = min(1.0, (m - i) * pvals[idx])
    for i in range(1, m):
        adj[order[i]] = max(adj[order[i]], adj[order[i-1]])
    return adj

mat = d.pivot_table(index="pair_id", columns="model", values="correct", aggfunc="first")
mat = mat.dropna(axis=0, how="any")
X = mat.values.astype(int)

print("\n=== OVERALL (ALL TASKS): PAIRED ITEMS ===")
print("n_items_paired =", X.shape[0], " | n_models =", X.shape[1])

Q, pQ = cochran_q(X)
print("\n=== OVERALL: Cochran's Q ===")
print("Q =", Q, "p =", pQ)

models = list(mat.columns)
pairs = []
pvals = []
for i in range(len(models)):
    for j in range(i+1, len(models)):
        m1, m2 = models[i], models[j]
        b01, c10, p = mcnemar_exact(mat[m1].values, mat[m2].values)
        pairs.append({
            "model_a": m1,
            "model_b": m2,
            "n_items_paired": X.shape[0],
            "b(a=1,b=0)": b01,
            "c(a=0,b=1)": c10,
            "p_mcnemar": p,
        })
        pvals.append(p)

pair_df = pd.DataFrame(pairs)
pair_df["p_holm"] = holm(np.array(pvals))
pair_df = pair_df.sort_values(["p_holm", "p_mcnemar"])

print("\n=== OVERALL: Pairwise McNemar exact (Holm corrected) ===")
with pd.option_context("display.max_rows", 2000, "display.max_columns", 20, "display.width", 200):
    print(pair_df.to_string(index=False))
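
McNemar's exact test looks only at discordant items, those that one model got right and the other got wrong. Under the null hypothesis each discordant item is equally likely to favour either model, so the two-sided p-value is twice the Binomial(n, 0.5) lower tail at the smaller discordant count, capped at 1. A worked toy example with made-up correctness vectors; `mcnemar_exact_p` is an illustrative name:

```python
import math

def mcnemar_exact_p(a, b):
    """Two-sided exact McNemar p-value for paired 0/1 correctness vectors."""
    only_a = sum(1 for x, y in zip(a, b) if x == 1 and y == 0)  # a right, b wrong
    only_b = sum(1 for x, y in zip(a, b) if x == 0 and y == 1)  # a wrong, b right
    n = only_a + only_b
    if n == 0:
        return only_a, only_b, 1.0
    k = min(only_a, only_b)
    # Sum integer binomial coefficients first, then divide once by 2**n.
    tail = sum(math.comb(n, i) for i in range(k + 1)) / 2 ** n
    return only_a, only_b, min(1.0, 2.0 * tail)

# Nine paired items: model A wins 6 discordant items, model B wins 1, 2 are ties.
model_a = [1, 1, 1, 1, 1, 1, 0, 1, 0]
model_b = [0, 0, 0, 0, 0, 0, 1, 1, 0]
only_a, only_b, p = mcnemar_exact_p(model_a, model_b)
print(only_a, only_b, p)
```

With 7 discordant items and a split of 6 to 1, the lower tail is (1 + 7) / 128, giving p = 0.125. Summing the integer coefficients before the single division keeps the arithmetic exact and avoids float overflow when the number of discordant items is large.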


In [None]:

import numpy as np
import pandas as pd

df = raw_df_mixed_357.copy()

d = df[df["contam"].isin(["3","5","7"])].copy()

def map_group_for_table(row):
    if row["task"] == "tm_generic_halt":
        return "tm_generic_halt"
    if row["group"] == "sgu":
        return "sgu_other"
    return row["group"]

d["group_for_table"] = d.apply(map_group_for_table, axis=1)

CATS = ["3","5","7"]

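def chi2_test_2x3(counts_a, counts_b):
    """
    Pearson chi-square test of homogeneity for two count vectors.
    counts_a/b: length-3 array for categories 3/5/7
    return chi2, p, dof
    """
    obs = np.vstack([counts_a, counts_b]).astype(float)
    # If either side has no observations there is nothing to compare.
    if (obs.sum(axis=1) == 0).any():
        return np.nan, np.nan, 0
    # Drop categories that neither side ever reported; they carry no information
    # and would otherwise be counted in the degrees of freedom.
    obs = obs[:, obs.sum(axis=0) > 0]
    dof = (obs.shape[0]-1)*(obs.shape[1]-1)  # (2-1)*(3-1)=2 when all three categories occur
    if dof == 0:
        return np.nan, np.nan, 0
    rsum = obs.sum(axis=1, keepdims=True)
    csum = obs.sum(axis=0, keepdims=True)
    total = obs.sum()
    exp = (rsum @ csum) / total
    chi2 = ((obs - exp)**2 / exp).sum()
    try:
        from scipy.stats import chi2 as chi2dist
        p = chi2dist.sf(chi2, df=dof)
    except Exception:
        p = np.nan
    return float(chi2), float(p), int(dof)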

def holm(pvals):
    pvals = np.asarray(pvals, dtype=float)
    m = len(pvals)
    order = np.argsort(pvals)
    adj = np.empty(m, dtype=float)
    for i, idx in enumerate(order):
        adj[idx] = min(1.0, (m - i) * pvals[idx])
    for i in range(1, m):
        adj[order[i]] = max(adj[order[i]], adj[order[i-1]])
    return adj

def counts_for(subdf):
    vc = subdf["contam"].value_counts()
    return np.array([int(vc.get(c, 0)) for c in CATS], dtype=int)

# ============================
# A) GSM8K vs all other groups pooled (per model)
# ============================
rows = []
for model, dm in d.groupby("model"):
    a = dm[dm["group_for_table"] == "gsm8k"]
    b = dm[dm["group_for_table"] != "gsm8k"]
    ca = counts_for(a)
    cb = counts_for(b)
    chi2, p, dof = chi2_test_2x3(ca, cb)
    rows.append({
        "model": model,
        "contrast": "gsm8k vs all_others",
        "n_gsm8k": int(ca.sum()),
        "n_others": int(cb.sum()),
        "gsm8k_3/5/7": f"{ca[0]}/{ca[1]}/{ca[2]}",
        "others_3/5/7": f"{cb[0]}/{cb[1]}/{cb[2]}",
        "chi2": chi2,
        "dof": dof,
        "p": p,
    })

overall_cmp = pd.DataFrame(rows).sort_values("p")
overall_cmp["p_holm_within_modelset"] = holm(overall_cmp["p"].values)

print("\n=== GSM8K vs ALL OTHERS (per model) ===")
with pd.option_context("display.max_rows", 500, "display.max_columns", 50, "display.width", 220):
    print(overall_cmp.to_string(index=False))

# ============================
# B) GSM8K vs each other group (per model)
# ============================
groups_to_compare = ["aqua", "p_time", "np_time", "sgu_other", "tm_generic_halt"]
rows = []
for model, dm in d.groupby("model"):
    gsm = dm[dm["group_for_table"] == "gsm8k"]
    c_gsm = counts_for(gsm)
    for g in groups_to_compare:
        other = dm[dm["group_for_table"] == g]
        c_o = counts_for(other)
        chi2, p, dof = chi2_test_2x3(c_gsm, c_o)
        rows.append({
            "model": model,
            "contrast": f"gsm8k vs {g}",
            "n_gsm8k": int(c_gsm.sum()),
            f"n_{g}": int(c_o.sum()),
            "gsm8k_3/5/7": f"{c_gsm[0]}/{c_gsm[1]}/{c_gsm[2]}",
            f"{g}_3/5/7": f"{c_o[0]}/{c_o[1]}/{c_o[2]}",
            "chi2": chi2,
            "dof": dof,
            "p": p,
        })

pair_cmp = pd.DataFrame(rows)

pair_cmp["p_holm_within_model"] = np.nan
for model in pair_cmp["model"].unique():
    mask = pair_cmp["model"] == model
    pair_cmp.loc[mask, "p_holm_within_model"] = holm(pair_cmp.loc[mask, "p"].values)

pair_cmp = pair_cmp.sort_values(["model", "p_holm_within_model", "p"])

print("\n=== GSM8K vs EACH OTHER GROUP (per model, Holm corrected within model) ===")
with pd.option_context("display.max_rows", 2000, "display.max_columns", 60, "display.width", 260):
    print(pair_cmp.to_string(index=False))
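
Each contrast above is a Pearson chi-square test of homogeneity on a 2x3 table of self-reported contamination levels (3/5/7): expected counts are `row_total * column_total / grand_total`, and the statistic sums `(observed - expected)^2 / expected` over the cells. A worked example with made-up counts in plain Python; it relies on the fact that for 2 degrees of freedom the chi-square survival function is `exp(-x / 2)`, so SciPy is not needed for this particular case:

```python
import math

def chi2_homogeneity(counts_a, counts_b):
    """Pearson chi-square statistic and dof; assumes every row and column total is positive."""
    rows = [list(counts_a), list(counts_b)]
    row_tot = [sum(r) for r in rows]
    col_tot = [sum(c) for c in zip(*rows)]
    total = sum(row_tot)
    stat = 0.0
    for i, r in enumerate(rows):
        for j, obs in enumerate(r):
            exp = row_tot[i] * col_tot[j] / total
            stat += (obs - exp) ** 2 / exp
    dof = (len(rows) - 1) * (len(col_tot) - 1)
    return stat, dof

# Made-up counts of answers 3 / 5 / 7 for two groups of 100 items each.
gsm8k_counts = [30, 50, 20]
other_counts = [20, 50, 30]
stat, dof = chi2_homogeneity(gsm8k_counts, other_counts)
p_value = math.exp(-stat / 2)  # closed form, valid only because dof == 2
print(f"chi2={stat:.3f} dof={dof} p={p_value:.4f}")
```

The expected counts are 25/50/25 in both rows, so only the outer cells contribute, one unit each, for a statistic of 4 on 2 degrees of freedom.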


In [None]:
# =========================
# FIX: build `datasets` from harmonized JSONL (standalone)
# =========================
import os, json, re
import numpy as np
import pandas as pd
from dataclasses import dataclass
from typing import Any, Dict, List, Optional

BASE_DIR = os.environ.get("BAP_BASE_DIR", "/content/drive/MyDrive/complexity_data6")
HARMONIZED_PATH = os.environ.get("BAP_HARMONIZED_JSONL", os.path.join(BASE_DIR, "all_tasks_harmonized.jsonl"))

SGU_SUFFIX = "__sgu"
SGU_COMPLEXITY_FAMILY_PATTERNS = [
    "strongly_generically_undecidable",
    "strongly generically undecidable",
    "sgu",
    "undecidable",
]

@dataclass
class TaskInstance:
    instance_id: str
    task_type: str          # here we use "binary"
    prompt: str
    label01: Optional[int] = None
    ground_truth: Optional[str] = None
    meta: Optional[Dict[str, Any]] = None

@dataclass
class TaskDataset:
    name: str
    instances: List[TaskInstance]

def load_harmonized_jsonl_df(path: str) -> pd.DataFrame:
    if not os.path.exists(path):
        raise FileNotFoundError(f"Missing harmonized JSONL: {path}")
    rows = []
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if line:
                rows.append(json.loads(line))
    df = pd.DataFrame(rows)

    def _to01(x):
        if isinstance(x, (bool, np.bool_)): return int(x)
        if isinstance(x, (int, np.integer)): return int(x != 0)
        s = str(x).strip().lower()
        if s in ["1","true","yes","y","t"]: return 1
        if s in ["0","false","no","n","f"]: return 0
        m = re.search(r"[01]", s)
        return int(m.group(0)) if m else 0

    for req in ["input", "label", "task"]:
        if req not in df.columns:
            raise ValueError(f"harmonized JSONL missing required column: {req}")

    df["label"] = df["label"].apply(_to01).astype(int)

    if "complexity_family" not in df.columns:
        df["complexity_family"] = np.nan

    return df

def build_harmonized_datasets(df: pd.DataFrame, include_sgu_slice: bool = True) -> Dict[str, TaskDataset]:
    out: Dict[str, TaskDataset] = {}

    def _make(inst_df: pd.DataFrame, name: str):
        insts = []
        for i, r in inst_df.reset_index(drop=True).iterrows():
            meta = {k: r[k] for k in inst_df.columns if k not in ["input", "label", "task"]}
            insts.append(TaskInstance(
                instance_id=f"{name}-{i}",
                task_type="binary",
                prompt=str(r["input"]),
                label01=int(r["label"]),
                meta=meta
            ))
        out[name] = TaskDataset(name=name, instances=insts)

    tasks = sorted(df["task"].dropna().astype(str).unique().tolist())
    for t in tasks:
        df_t = df[df["task"] == t].copy()
        if len(df_t) == 0:
            continue
        _make(df_t, t)

        if include_sgu_slice:
            cf = df_t["complexity_family"].fillna("").astype(str).str.lower()
            mask = np.zeros(len(df_t), dtype=bool)
            for pat in SGU_COMPLEXITY_FAMILY_PATTERNS:
                mask |= cf.str.contains(pat.lower(), na=False).to_numpy()
            if mask.any():
                _make(df_t.loc[mask].copy(), t + SGU_SUFFIX)

    return out

# ---- build datasets ----
df_h = load_harmonized_jsonl_df(HARMONIZED_PATH)
datasets = build_harmonized_datasets(df_h, include_sgu_slice=True)

print("Built datasets. num_tasks =", len(datasets))
print("Example keys:", list(datasets.keys())[:20])
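
The SGU slice built in the cell above is selected by matching `complexity_family` against `SGU_COMPLEXITY_FAMILY_PATTERNS`. The sketch below reproduces that masking step on a toy frame. The pattern list `toy_patterns`, the helper name `sgu_mask`, and the rows are hypothetical stand-ins, not the notebook's real values.

```python
import numpy as np
import pandas as pd

# Hypothetical patterns; the notebook's SGU_COMPLEXITY_FAMILY_PATTERNS may differ.
toy_patterns = ["undecidable", "halting"]

toy = pd.DataFrame({
    "task": ["t1", "t1", "t1", "t1"],
    "complexity_family": ["P", "Undecidable-halting", None, "NP-complete"],
})

def sgu_mask(frame: pd.DataFrame, patterns) -> np.ndarray:
    """Boolean mask of rows whose complexity_family contains any pattern (case-insensitive)."""
    cf = frame["complexity_family"].fillna("").astype(str).str.lower()
    mask = np.zeros(len(frame), dtype=bool)
    for pat in patterns:
        # regex=False treats each pattern as a literal substring (a choice of this sketch)
        mask |= cf.str.contains(pat.lower(), na=False, regex=False).to_numpy()
    return mask

mask = sgu_mask(toy, toy_patterns)
print(toy.loc[mask])
```

Unlike the cell above, this sketch passes `regex=False`, so patterns are matched literally; drop that argument if your patterns are meant to be regular expressions.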

In [None]:
# ============================================
# FULL SCRIPT (LEAK EXPERIMENT + RAW PER-ITEM SAVES)
# - contamination self-report as {3,5,7}
# - Groups: gsm8k, aqua, p_time (2 subtasks), np_time (4 subtasks), sgu (10 subtasks)
# - Leak levels: 0%, 50%, 100% (Train=300; Test=300)
# - For each (group, leak, model): BEFORE eval -> LoRA SFT -> AFTER eval on SAME test
# - Saves:
#   1) splits: OUT_ROOT/splits/<group>/test_300.jsonl and train_leak{0,50,100}_300.jsonl
#   2) adapters: OUT_ROOT/adapters/<group>/<leak>/<model>/
#   3) per-item raw outputs: OUT_ROOT/raw_eval/<group>/<leak>/<model>/{before,after}.jsonl
#   4) summary CSV: OUT_ROOT/results_before_after.csv
#
# PREREQS expected in your notebook:
#   - datasets: Dict[str, TaskDataset], each TaskDataset.instances contains TaskInstance(prompt,label01)
#   - Your environment has transformers + peft installed
#   - GPU strongly recommended
# ============================================

import os, json, random, gc, re
import numpy as np
import pandas as pd
import torch
from tqdm.auto import tqdm
from transformers import (
    AutoTokenizer, AutoModelForCausalLM,
    TrainingArguments, Trainer,
    GenerationConfig, StoppingCriteria, StoppingCriteriaList
)
from peft import LoraConfig, get_peft_model, PeftModel

# ----------------------------
# 0) Your MODEL_SPECS style
# ----------------------------
MODEL_SPECS = [
    {"name": "qwen3-4b",         "model_id": "Qwen/Qwen3-4B",                              "adapter": None},
    {"name": "deepseek-llama-8b","model_id": "deepseek-ai/DeepSeek-R1-Distill-Llama-8B",   "adapter": None},
    {"name": "deepseek-math-7b", "model_id": "deepseek-ai/deepseek-math-7b-instruct",      "adapter": None},
    {"name": "acemath-7b",       "model_id": "nvidia/AceMath-7B-Instruct",                 "adapter": None},
    {"name": "Mathstral-7B",     "model_id": "mistralai/Mathstral-7B-v0.1",                "adapter": None},
]

# ----------------------------
# 1) Config
# ----------------------------
SEED = 123
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

OUT_ROOT = os.environ.get("EVAL_OUT_DIR", "/content/drive/MyDrive/complexity7/eval_outputs_leak")
os.makedirs(OUT_ROOT, exist_ok=True)

USE_4BIT = bool(int(os.environ.get("BAP_4BIT", "0")))  # reuse your env var if set

N_TEST = 300
N_TRAIN = 300
LEAK_LEVELS = {"leak0": 0.0, "leak50": 0.5, "leak100": 1.0}

# LoRA hyperparams
LORA_R = 16
LORA_ALPHA = 32
LORA_DROPOUT = 0.05
TARGET_MODULES = ["q_proj","k_proj","v_proj","o_proj"]  # adjust if a model complains

# Train hyperparams (start small; scale up after sanity check)
LR = 2e-4
EPOCHS = 4
BATCH = 2
GRAD_ACCUM = 8
MAX_LEN = 512

# ----------------------------
# 2) Groups (include SGU + NP/P multi-subtasks)
# ----------------------------
SGU_TASKS = [
    "sgu_collatz_aligned", "sgu_collatz_aligned__sgu",
    "sgu_index_empty_language", "sgu_index_empty_language__sgu",
    "sgu_index_total_halt", "sgu_index_total_halt__sgu",
    "sgu_semigroup_wp_amp", "sgu_semigroup_wp_amp__sgu",
    "tm_generic_halt", "tm_hard_halt",
]

GROUP_DEFS = {
    "gsm8k": ["gsm8k_binary"],
    "aqua": ["aqua_binary"],
    "p_time": ["p_graph_connectivity", "ptime_arith"],
    "np_time": ["random_np_rbh_word_avg", "np_subset_sum_avg", "np_subset_sum_worst", "np_rbh_word_worst_3sat"],
    "sgu": SGU_TASKS,
}

# ----------------------------
# 3) Model meta (family/size)
# ----------------------------
def model_family(model_id: str) -> str:
    s = model_id.lower()
    if "qwen" in s: return "qwen"
    if "deepseek" in s: return "deepseek"
    if "mistral" in s or "mathstral" in s: return "mistral"
    if "nvidia" in s or "acemath" in s: return "nvidia"
    return "other"

def model_size_tag(s: str) -> str:
    # Extract a parameter-count tag such as "7B" or "1.5B" from a model name or id.
    m = re.search(r"(\d+(?:\.\d+)?)\s*b\b", s.lower())
    return f"{m.group(1)}B" if m else ""

# ----------------------------
# 4) Decoding + prompts/parsers (YES/NO + 3/5/7)
# ----------------------------
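class StopOnTokens(StoppingCriteria):
    """Stop once the text generated after the prompt contains any stop string.

    Matching on decoded text (rather than on the last token id of each stop
    string) keeps multi-token stops such as "Question:" from firing on a bare ":".
    """
    def __init__(self, tokenizer, stop_strings, prompt_len: int = 0):
        self.tokenizer = tokenizer
        self.stop_strings = list(stop_strings)
        self.prompt_len = prompt_len  # number of prompt tokens to skip

    def __call__(self, input_ids, scores, **kwargs):
        if input_ids.shape[1] <= self.prompt_len:
            return False
        text = self.tokenizer.decode(input_ids[0, self.prompt_len:], skip_special_tokens=True)
        return any(s in text for s in self.stop_strings)

def generate_strict(model, tok, prompt: str, max_tokens: int = 10):
    inputs = tok(prompt, return_tensors="pt", truncation=True)
    inputs = {k: v.to(model.device) for k, v in inputs.items()}
    prompt_len = inputs["input_ids"].shape[1]
    stopper = StopOnTokens(tok, stop_strings=["\n", "\r", "Question:", "Options:"], prompt_len=prompt_len)
    # Greedy decoding; temperature is omitted because it is ignored when do_sample=False.
    gen_config = GenerationConfig(
        do_sample=False, max_new_tokens=max_tokens,
        pad_token_id=tok.pad_token_id, eos_token_id=tok.eos_token_id,
    )
    with torch.no_grad():
        out = model.generate(**inputs, generation_config=gen_config,
                             stopping_criteria=StoppingCriteriaList([stopper]))
    # Decode only the newly generated tokens. Slicing the decoded string by len(prompt)
    # is unreliable: if decode(encode(prompt)) != prompt, the instruction text itself
    # (which contains "Yes/No") would be parsed as the answer.
    gen = tok.decode(out[0][prompt_len:], skip_special_tokens=True)
    return gen.split("\n")[0].strip()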

def build_solver_prompt(question: str) -> str:
    instruction = (
        "You are a calculator. Output strictly ONLY the final numerical answer or Yes/No.\n"
        "Do NOT explain. Do NOT show work.\n"
    )
    return f"{instruction}\nQuestion: {question.strip()}\nAnswer:"

def build_contamination_prompt_357(question: str) -> str:
    instruction = (
        "You are an AI assistant. Analyze the following problem.\n"
        "Tell me if you have seen this problem during your training data.\n"
        "Reply with strictly ONLY a single digit from {3, 5, 7}:\n"
        "3 = SEEN the same question before\n"
        "5 = Seen very similar questions before but not the same question\n"
        "7 = Not seen similar question before\n"
        "Ignore any YES/NO question in the content even if it asks.\n"
    )
    return f"{instruction}\nProblem Content: {question.strip()}\n\nReply (3, 5, or 7):"

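def parse_yesno(s: str):
    if not s: return None
    t = s.strip().upper()
    # Match whole words only, so "NOT", "KNOW" or "NONE" are not read as "NO".
    if re.search(r"\bYES\b", t): return "YES"
    if re.search(r"\bNO\b", t): return "NO"
    return None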

def parse_357(s: str):
    if not s: return None
    m = re.search(r"\b([357])\b", s.strip())
    return m.group(1) if m else None

# ----------------------------
# 5) JSONL helpers
# ----------------------------
def write_jsonl(path, rows):
    os.makedirs(os.path.dirname(path), exist_ok=True)
    with open(path, "w", encoding="utf-8") as f:
        for r in rows:
            f.write(json.dumps(r, ensure_ascii=False) + "\n")

def load_jsonl(path):
    rows = []
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            if line.strip():  # skip blank lines instead of failing in json.loads
                rows.append(json.loads(line))
    return rows

# ----------------------------
# 6) Build & save splits (per group: test_300 + train_300 with leak 0/50/100)
# ----------------------------
def group_pool(group_name: str):
    pool = []
    for tk in GROUP_DEFS[group_name]:
        if tk not in datasets:
            continue
        for i, inst in enumerate(datasets[tk].instances):
            if inst.label01 is None:
                continue
            pool.append({
                "group": group_name,
                "task": tk,
                "local_idx": i,
                "prompt": inst.prompt,
                "label01": int(inst.label01),
            })
    return pool

def build_and_save_splits():
    """
    Auto-adjust per group so that we can ALWAYS form:
      test_n + train_n <= pool_size, and train can be (1-leak)*new + leak*test
    Strategy:
      - Try (TARGET_TEST=300, TARGET_TRAIN=300)
      - If pool too small, set test_n = min(300, pool_size//2), train_n = test_n
        (so aqua with 300 -> test=150, train=150)
    """
    manifest = {}

    TARGET_TEST = N_TEST   # your global (e.g. 300)
    TARGET_TRAIN = N_TRAIN # your global (e.g. 300)

    for g in GROUP_DEFS:
        pool = group_pool(g)
        pool_size = len(pool)
        if pool_size < 2:
            raise RuntimeError(f"[{g}] pool too small: {pool_size}")

        # choose per-group sizes
        if pool_size >= (TARGET_TEST + TARGET_TRAIN):
            test_n = TARGET_TEST
            train_n = TARGET_TRAIN
        else:
            # fallback: split roughly half/half
            test_n = min(TARGET_TEST, pool_size // 2)
            train_n = min(TARGET_TRAIN, pool_size - test_n)
            # to keep leak definitions symmetric, make train_n == test_n when possible
            train_n = min(train_n, test_n)

        if test_n == 0 or train_n == 0:
            raise RuntimeError(f"[{g}] cannot form non-empty test/train with pool_size={pool_size}")

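        # hash() of a str is salted per Python process, so it cannot seed a reproducible split.
        # A string seed is hashed deterministically by random.Random.
        rng = random.Random(f"{SEED}:{g}")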
        rng.shuffle(pool)

        test = pool[:test_n]
        rest = pool[test_n:]

        if len(rest) < train_n:
            # if still insufficient (edge), shrink train_n
            train_n = len(rest)
            if train_n == 0:
                raise RuntimeError(f"[{g}] no remaining items for train after test selection.")

        fresh = rest[:train_n]

        # save test
        test_path = os.path.join(OUT_ROOT, "splits", g, f"test_{test_n}.jsonl")
        write_jsonl(test_path, test)

        # save trains for leak levels
        for tag, leak in LEAK_LEVELS.items():
            n_leak = int(round(train_n * leak))
            n_new = train_n - n_leak

            # leaked part always comes from test
            leaked = test[:min(n_leak, len(test))]

            # new part comes from fresh (rest)
            newpart = fresh[:min(n_new, len(fresh))]

            # if fresh not enough for required newpart (can happen when pool is tiny),
            # we shrink newpart accordingly (still valid but deviates from exact ratio)
            train = newpart + leaked
            train_path = os.path.join(OUT_ROOT, "splits", g, f"train_{tag}_{train_n}.jsonl")
            write_jsonl(train_path, train)

            manifest[f"{g}::{tag}"] = {
                "group": g,
                "leak_tag": tag,
                "leak": leak,
                "pool_size": pool_size,
                "test_path": test_path,
                "train_path": train_path,
                "n_test": len(test),
                "n_train": len(train),
                "target_test": test_n,
                "target_train": train_n,
                "n_new": len(newpart),
                "n_leak": len(leaked),
            }

        print(f"[SPLIT] {g}: pool={pool_size}, test={len(test)}, train_target={train_n}")

    man_path = os.path.join(OUT_ROOT, "splits", "split_manifest.json")
    os.makedirs(os.path.dirname(man_path), exist_ok=True)
    with open(man_path, "w", encoding="utf-8") as f:
        json.dump(manifest, f, ensure_ascii=False, indent=2)
    print("Saved split manifest:", man_path)
    return manifest


# ----------------------------
# 7) SFT dataset (prompt -> YES/NO)
# ----------------------------
def label_to_yesno(label01: int) -> str:
    return "YES" if int(label01) == 1 else "NO"

def make_sft_items(rows, tok):
    items = []
    for r in rows:
        prompt = build_solver_prompt(r["prompt"])
        ans = label_to_yesno(r["label01"])
        full = prompt + " " + ans

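        enc_full = tok(full, truncation=True, max_length=MAX_LEN)
        enc_prompt = tok(prompt, truncation=True, max_length=MAX_LEN)
        n_prompt = len(enc_prompt["input_ids"])
        if n_prompt >= len(enc_full["input_ids"]):
            # The answer was truncated away; every label would be masked, so skip the item.
            continue

        # Supervise only the answer tokens; prompt positions are ignored by the loss.
        labels = np.array(enc_full["input_ids"], dtype=np.int64)
        labels[:n_prompt] = -100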

        items.append({
            "input_ids": enc_full["input_ids"],
            "attention_mask": enc_full["attention_mask"],
            "labels": labels.tolist(),
        })
    return items

# ----------------------------
# 8) Load base + finetune LoRA
# ----------------------------
def load_base(model_id, use_4bit=False):
    tok = AutoTokenizer.from_pretrained(model_id, use_fast=True)
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token

    kwargs = {}
    if use_4bit:
        try:
            from transformers import BitsAndBytesConfig
            kwargs["quantization_config"] = BitsAndBytesConfig(load_in_4bit=True)
        except Exception as e:
            print("[WARN] 4bit requested but BitsAndBytes not available:", e)

    # device_map is passed exactly once below; also placing it in kwargs would raise
    # "got multiple values for keyword argument 'device_map'".
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map="auto" if torch.cuda.is_available() else None,
        torch_dtype=torch.float16 if torch.cuda.is_available() else None,
        **kwargs
    )
    model.eval()
    return tok, model

class SimpleDataset(torch.utils.data.Dataset):
    """
    Return python lists; let collator pad dynamically.
    """
    def __init__(self, items):
        self.items = items
    def __len__(self):
        return len(self.items)
    def __getitem__(self, i):
        return self.items[i]

from torch.nn.utils.rnn import pad_sequence

class DataCollatorForCausalLMWithLabelPadding:
    """
    Pads input_ids/attention_mask to max length in batch.
    Pads labels to same length using -100.
    """
    def __init__(self, pad_token_id: int, label_pad_token_id: int = -100):
        self.pad_token_id = pad_token_id
        self.label_pad_token_id = label_pad_token_id

    def __call__(self, features):
        input_ids = [torch.tensor(f["input_ids"], dtype=torch.long) for f in features]
        attention_mask = [torch.tensor(f["attention_mask"], dtype=torch.long) for f in features]
        labels = [torch.tensor(f["labels"], dtype=torch.long) for f in features]

        input_ids = pad_sequence(input_ids, batch_first=True, padding_value=self.pad_token_id)
        attention_mask = pad_sequence(attention_mask, batch_first=True, padding_value=0)
        labels = pad_sequence(labels, batch_first=True, padding_value=self.label_pad_token_id)

        return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}

def finetune_lora(model_id, train_rows, out_dir, use_4bit=False):
    tok, base = load_base(model_id, use_4bit=use_4bit)

    lora_cfg = LoraConfig(
        r=LORA_R, lora_alpha=LORA_ALPHA, lora_dropout=LORA_DROPOUT,
        bias="none", task_type="CAUSAL_LM", target_modules=TARGET_MODULES,
    )
    model = get_peft_model(base, lora_cfg)
    model.train()

    train_items = make_sft_items(train_rows, tok)  # your function that returns list of dicts
    train_ds = SimpleDataset(train_items)

    collator = DataCollatorForCausalLMWithLabelPadding(
        pad_token_id=tok.pad_token_id,
        label_pad_token_id=-100
    )

    args = TrainingArguments(
        output_dir=out_dir,
        per_device_train_batch_size=BATCH,
        gradient_accumulation_steps=GRAD_ACCUM,
        num_train_epochs=EPOCHS,
        learning_rate=LR,
        logging_steps=20,
        save_strategy="no",
        report_to="none",
        fp16=torch.cuda.is_available(),
        remove_unused_columns=False,  # important when using custom collator/dataset dicts
    )

    trainer = Trainer(
        model=model,
        args=args,
        train_dataset=train_ds,
        data_collator=collator,
    )
    trainer.train()

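    os.makedirs(out_dir, exist_ok=True)
    model.save_pretrained(out_dir)
    tok.save_pretrained(out_dir)

    # Release GPU memory before the caller reloads the base model for the AFTER eval.
    del trainer, model, base
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return out_dir

# ----------------------------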
# 9) Eval on test (DETAILED per-item + summary)
# ----------------------------
def eval_model_on_test_detailed(model, tok, test_rows):
    total = 0
    fmt = 0
    correct_sum = 0
    cc = {"3":0,"5":0,"7":0,"other":0,"empty":0}
    detailed = []

    for i, r in enumerate(test_rows):
        total += 1
        q = r["prompt"]
        y_true = int(r["label01"])

        solver_prompt = build_solver_prompt(q)
        solver_raw = generate_strict(model, tok, solver_prompt, max_tokens=15)
        solver_pred = parse_yesno(solver_raw)

        solver_ok = None
        if solver_pred is not None:
            fmt += 1
            y_pred = 1 if solver_pred == "YES" else 0
            solver_ok = int(y_pred == y_true)
            correct_sum += solver_ok

        contam_prompt = build_contamination_prompt_357(q)
        contam_raw = generate_strict(model, tok, contam_prompt, max_tokens=5)
        contam_pred = parse_357(contam_raw)

        if contam_raw is None or contam_raw == "":
            cc["empty"] += 1
        elif contam_pred in ["3","5","7"]:
            cc[contam_pred] += 1
        else:
            cc["other"] += 1

        detailed.append({
            "idx": i,
            "task": r.get("task"),
            "local_idx": r.get("local_idx"),
            "label01": y_true,
            "solver_pred": solver_pred,
            "solver_correct": solver_ok,
            "solver_raw": solver_raw,
            "contam_pred": contam_pred,
            "contam_raw": contam_raw,
        })

    acc_on_fmt = correct_sum / fmt if fmt else float("nan")
    fmt_rate = fmt / total if total else 0.0
    acc_overall = correct_sum / total if total else float("nan")

    summary = {
        "n_total": total,
        "n_fmt": fmt,
        "fmt_rate": fmt_rate,
        "acc_on_fmt": acc_on_fmt,
        "acc_overall": acc_overall,
        "c3": cc["3"], "c5": cc["5"], "c7": cc["7"],
        "c_other": cc["other"], "c_empty": cc["empty"],
    }
    return summary, detailed

# ----------------------------
# 10) Run experiment (before/after) + save summary + raw per-item
#     (manifest-aware: auto adapts to actual test/train sizes per group)
# ----------------------------
def run_leak_experiment_with_raw_manifest_aware(skip_groups=None):
    """
    Reads split_manifest.json produced by build_and_save_splits() and uses the
    actual generated test/train paths and sizes per (group, leak_tag).

    This avoids hard-coding test_{N_TEST}.jsonl / train_{leak_tag}_{N_TRAIN}.jsonl
    when a group's pool is too small and the splitter auto-shrinks sizes.
    """
    skip_groups = set(skip_groups or [])

    # build splits + manifest
    manifest = build_and_save_splits()

    results = []

    # iterate by group then leak_tag for reproducible order
    for g in GROUP_DEFS.keys():
        if g in skip_groups:
            print(f"[SKIP] group={g}")
            continue

        for leak_tag in LEAK_LEVELS.keys():
            key = f"{g}::{leak_tag}"
            if key not in manifest:
                print(f"[WARN] manifest missing key={key}, skipping")
                continue

            m = manifest[key]
            test_path = m["test_path"]
            train_path = m["train_path"]

            # load the actual files that were created
            if not os.path.exists(test_path):
                print(f"[WARN] missing test file: {test_path}, skipping {key}")
                continue
            if not os.path.exists(train_path):
                print(f"[WARN] missing train file: {train_path}, skipping {key}")
                continue

            test_rows = load_jsonl(test_path)
            train_rows = load_jsonl(train_path)

            # sanity print
            print("\n" + "-" * 90)
            print(f"[SPLIT-USE] group={g} leak={leak_tag} "
                  f"test={len(test_rows)} train={len(train_rows)} "
                  f"(target_test={m.get('target_test')} target_train={m.get('target_train')} "
                  f"new={m.get('n_new')} leak_items={m.get('n_leak')})")

            for spec in MODEL_SPECS:
                name = spec["name"]
                model_id = spec["model_id"]
                fam = model_family(model_id)
                size = model_size_tag(name) or model_size_tag(model_id)

                print("\n" + "=" * 90)
                print(f"[RUN] group={g} leak={leak_tag} model={name} ({model_id})")

                # -------- BEFORE --------
                tok0, base0 = load_base(model_id, use_4bit=USE_4BIT)
                before_summary, before_rows = eval_model_on_test_detailed(base0, tok0, test_rows)

                before_raw_path = os.path.join(OUT_ROOT, "raw_eval", g, leak_tag, name, "before.jsonl")
                for rr in before_rows:
                    rr.update({"group": g, "leak": leak_tag, "model": name, "phase": "before"})
                write_jsonl(before_raw_path, before_rows)

                del base0
                gc.collect()
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()

                # -------- TRAIN LoRA --------
                adapter_dir = os.path.join(OUT_ROOT, "adapters", g, leak_tag, name)
                finetune_lora(model_id, train_rows, out_dir=adapter_dir, use_4bit=USE_4BIT)

                # -------- AFTER --------
                tok1, base1 = load_base(model_id, use_4bit=USE_4BIT)
                model1 = PeftModel.from_pretrained(base1, adapter_dir)
                model1.eval()
                after_summary, after_rows = eval_model_on_test_detailed(model1, tok1, test_rows)

                after_raw_path = os.path.join(OUT_ROOT, "raw_eval", g, leak_tag, name, "after.jsonl")
                for rr in after_rows:
                    rr.update({"group": g, "leak": leak_tag, "model": name, "phase": "after"})
                write_jsonl(after_raw_path, after_rows)

                del model1, base1
                gc.collect()
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()

                # -------- SUMMARY ROW --------
                row = {
                    "group": g,
                    "leak": leak_tag,
                    "model": name,
                    "model_id": model_id,
                    "model_family": fam,
                    "model_size": size,

                    # from manifest
                    "pool_size": m.get("pool_size"),
                    "target_test": m.get("target_test"),
                    "target_train": m.get("target_train"),
                    "n_new": m.get("n_new"),
                    "n_leak": m.get("n_leak"),

                    "train_path": train_path,
                    "test_path": test_path,
                    "adapter_dir": adapter_dir,
                    "before_raw_path": before_raw_path,
                    "after_raw_path": after_raw_path,
                    "n_train": len(train_rows),
                    "n_test": len(test_rows),
                }
                for k, v in before_summary.items():
                    row[f"before_{k}"] = v
                for k, v in after_summary.items():
                    row[f"after_{k}"] = v

                row["delta_acc_on_fmt"] = row["after_acc_on_fmt"] - row["before_acc_on_fmt"]
                row["delta_acc_overall"] = row["after_acc_overall"] - row["before_acc_overall"]
                row["delta_fmt_rate"] = row["after_fmt_rate"] - row["before_fmt_rate"]
                results.append(row)

    res_df = pd.DataFrame(results)
    out_csv = os.path.join(OUT_ROOT, "results_before_after.csv")
    res_df.to_csv(out_csv, index=False)

    print("\nSaved summary results:", out_csv)
    print("Per-item raw outputs saved under:", os.path.join(OUT_ROOT, "raw_eval"))
    return res_df


# =========================
# RUN (manifest-aware)
# =========================
# Run every group (including aqua if its split exists); pass skip_groups={"aqua"} to skip one.
results_df = run_leak_experiment_with_raw_manifest_aware()
results_df.head()
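
The leak levels used in the cell above mix fresh items with items copied from the test set. The sketch below restates that arithmetic on a toy pool so the `n_new` / `n_leak` columns in the manifest are easier to read. The helper name `mix_train`, the toy pool, and the string seed are assumptions of this sketch, not part of the notebook.

```python
import random

def mix_train(test, fresh, train_n, leak):
    """Return (train, n_new, n_leak): about leak*train_n items copied from test, the rest from fresh."""
    n_leak = int(round(train_n * leak))
    n_new = train_n - n_leak
    leaked = test[:min(n_leak, len(test))]      # leaked part always comes from the test set
    newpart = fresh[:min(n_new, len(fresh))]    # new part comes from held-out items
    return newpart + leaked, len(newpart), len(leaked)

# Toy pool of 12 hypothetical item ids, shuffled with a string seed so the
# order is identical in every Python process.
pool = list(range(12))
random.Random("123:toy").shuffle(pool)
test, fresh = pool[:6], pool[6:]

for tag, leak in {"leak0": 0.0, "leak50": 0.5, "leak100": 1.0}.items():
    train, n_new, n_leak = mix_train(test, fresh, train_n=6, leak=leak)
    overlap = len(set(train) & set(test))
    print(tag, "n_new =", n_new, "n_leak =", n_leak, "overlap with test =", overlap)
```

Python's `round` rounds halves to the nearest even integer, so at 50% leak an odd `train_n` does not always round up (for example, `round(2.5)` is 2 while `round(3.5)` is 4).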

In [None]:
import os, json, glob, hashlib
import pandas as pd
import numpy as np
from scipy.stats import chi2_contingency

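# Honor EVAL_OUT_DIR like the experiment cell does, so the analysis reads the directory that was written.
OUT_ROOT = os.environ.get("EVAL_OUT_DIR", "/content/drive/MyDrive/complexity7/eval_outputs_leak")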
man_path = os.path.join(OUT_ROOT, "splits", "split_manifest.json")

def read_jsonl(path):
    rows = []
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            if line.strip():  # skip blank lines
                rows.append(json.loads(line))
    return rows

def stable_hash(s: str) -> str:
    return hashlib.md5(s.encode("utf-8")).hexdigest()

with open(man_path, "r", encoding="utf-8") as f:
    manifest = json.load(f)

def item_key(r):
    t = r.get("task", None)
    li = r.get("local_idx", None)
    if t is not None and li is not None:
        return ("task_local", t, int(li))
    # fallback: prompt hash
    p = r.get("prompt") or r.get("question") or ""
    return ("prompt_hash", stable_hash(p.strip()))

leaked_keys = {}  # (group, leak_tag) -> set(keys)
for key, m in manifest.items():
    g = m["group"]
    leak_tag = m["leak_tag"]
    test = read_jsonl(m["test_path"])
    train = read_jsonl(m["train_path"])
    test_set = set(item_key(r) for r in test)
    train_set = set(item_key(r) for r in train)
    leaked = test_set & train_set
    leaked_keys[(g, leak_tag)] = leaked

raw_paths = glob.glob(os.path.join(OUT_ROOT, "raw_eval", "*", "*", "*", "*.jsonl"))
all_rows = []
for p in raw_paths:
    all_rows.extend(read_jsonl(p))

df = pd.DataFrame(all_rows)

df["solver_correct"] = pd.to_numeric(df.get("solver_correct"), errors="coerce")
df_fmt = df[df["solver_correct"].notna()].copy()

df_fmt["phase"] = df_fmt["phase"].astype(str)

def compute_item_key_row(row):
    t = row.get("task", None)
    li = row.get("local_idx", None)
    if pd.notna(t) and pd.notna(li):
        return ("task_local", str(t), int(li))
    p = row.get("prompt", None)
    if isinstance(p, str) and p.strip():
        return ("prompt_hash", stable_hash(p.strip()))
    # last resort:
    return ("fallback", str(row.get("task", "")), int(row.get("idx", -1)))

df_fmt["item_key"] = df_fmt.apply(compute_item_key_row, axis=1)

run = (
    df_fmt
    .groupby(["group","leak","model","phase"], as_index=False)
    .agg(n=("solver_correct","size"), acc=("solver_correct","mean"))
)

pivot = run.pivot_table(index=["group","leak","model"], columns="phase", values="acc").reset_index()
pivot["delta"] = pivot["after"] - pivot["before"]

base = pivot[pivot["leak"]=="leak0"][["group","model","delta"]].rename(columns={"delta":"delta_leak0"})
did = pivot.merge(base, on=["group","model"], how="left")
did["leak_effect"] = did["delta"] - did["delta_leak0"]

key_cols = ["group","leak","model","item_key"]

before = df_fmt[df_fmt["phase"]=="before"][key_cols + ["solver_correct","contam_pred"]].rename(columns={"solver_correct":"y_before"})
after  = df_fmt[df_fmt["phase"]=="after"][ key_cols + ["solver_correct","contam_pred"]].rename(columns={"solver_correct":"y_after"})

merged = before.merge(after, on=key_cols, how="inner", suffixes=("_b","_a"))
merged["delta_item"] = merged["y_after"] - merged["y_before"]

def leaked_indicator(row):
    g, leak = row["group"], row["leak"]
    lk = leaked_keys.get((g, leak), set())
    return int(row["item_key"] in lk)

merged["is_leaked_item"] = merged.apply(leaked_indicator, axis=1)

item_te = (
    merged
    .groupby(["group","leak","model","is_leaked_item"], as_index=False)
    .agg(n=("delta_item","size"), delta_mean=("delta_item","mean"))
)

leaked_delta = item_te[item_te["is_leaked_item"]==1][["group","leak","model","delta_mean"]].rename(columns={"delta_mean":"delta_leaked"})
non_delta    = item_te[item_te["is_leaked_item"]==0][["group","leak","model","delta_mean"]].rename(columns={"delta_mean":"delta_nonleaked"})

item_contrast = leaked_delta.merge(non_delta, on=["group","leak","model"], how="inner")
item_contrast["delta_leaked_minus_nonleaked"] = item_contrast["delta_leaked"] - item_contrast["delta_nonleaked"]

def chi2_test(sub):
    tab = pd.crosstab(sub["contam_pred"], sub["solver_correct"])
    if tab.shape[0] < 2 or tab.shape[1] < 2:
        return pd.Series({"chi2":np.nan,"pval":np.nan,"dof":np.nan})
    chi2, p, dof, _ = chi2_contingency(tab.values)
    return pd.Series({"chi2":chi2,"pval":p,"dof":dof})

det = (
    # cast to str so integer-coded predictions (3/5/7) are not silently filtered out
    df_fmt[df_fmt["contam_pred"].astype(str).isin(["3","5","7"])]
    .groupby(["group","leak","model","phase"])
    .apply(chi2_test)
    .reset_index()
)


leak_order = ["leak0", "leak50", "leak100"]
for _df in [did, item_contrast, det]:
    if "leak" in _df.columns:
        _df["leak"] = pd.Categorical(_df["leak"], categories=leak_order, ordered=True)

did_sorted = did.sort_values(["group","model","leak"])
item_sorted = item_contrast.sort_values(["group","model","leak"])
det_sorted = det.sort_values(["group","model","leak","phase"])

print("\n=== Diff-in-Diff (ALL leaks) ===")
print(did_sorted[["group","model","leak","before","after","delta","delta_leak0","leak_effect"]].to_string(index=False))

print("\n=== Diff-in-Diff (leak100 only) ===")
print(did_sorted[did_sorted["leak"]=="leak100"][["group","model","leak","before","after","delta","delta_leak0","leak_effect"]].to_string(index=False))

print("\n=== Item-level contrast (ALL leaks) ===")
print(item_sorted[["group","model","leak","delta_leaked","delta_nonleaked","delta_leaked_minus_nonleaked"]].to_string(index=False))

print("\n=== Item-level contrast (leak100 only) ===")
print(item_sorted[item_sorted["leak"]=="leak100"][["group","model","leak","delta_leaked","delta_nonleaked","delta_leaked_minus_nonleaked"]].to_string(index=False))

print("\n=== Detection (chi-square p-values) ===")
print(det_sorted[["group","model","leak","phase","chi2","pval"]].to_string(index=False))

print("\n=== Detection (leak100 only) ===")
print(det_sorted[det_sorted["leak"]=="leak100"][["group","model","leak","phase","chi2","pval"]].to_string(index=False))

did.to_csv(os.path.join(OUT_ROOT, "analysis_did.csv"), index=False)
item_contrast.to_csv(os.path.join(OUT_ROOT, "analysis_item_contrast.csv"), index=False)
det.to_csv(os.path.join(OUT_ROOT, "analysis_detection_chi2.csv"), index=False)

print("\nSaved: analysis_did.csv, analysis_item_contrast.csv, analysis_detection_chi2.csv")
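
The difference-in-differences step above can be checked by hand on a toy table. This is a minimal sketch with invented accuracies (not results from this notebook); `toy_acc` and `did_effect` are hypothetical names used only for illustration.

```python
# Toy accuracies keyed by (leak, phase); the numbers are invented for illustration.
toy_acc = {
    ("leak0", "before"): 0.50, ("leak0", "after"): 0.55,
    ("leak50", "before"): 0.50, ("leak50", "after"): 0.70,
    ("leak100", "before"): 0.50, ("leak100", "after"): 0.90,
}

def did_effect(acc, leak, baseline="leak0"):
    """(after - before) at `leak`, minus the same difference at the baseline leak level."""
    delta = acc[(leak, "after")] - acc[(leak, "before")]
    delta_base = acc[(baseline, "after")] - acc[(baseline, "before")]
    return delta - delta_base

effects = {leak: round(did_effect(toy_acc, leak), 4) for leak in ("leak0", "leak50", "leak100")}
print(effects)
```

The leak0 row is zero by construction, so `leak_effect` only measures the gain beyond what fine-tuning on clean data already provides.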

if not os.path.exists(MANIFEST_PATH):
    raise FileNotFoundError(f"missing {MANIFEST_PATH}")

with open(MANIFEST_PATH, "r", encoding="utf-8") as f:
    manifest = json.load(f)

index_rows = []

for key, m in manifest.items():
    g = m["group"]
    leak_tag = m["leak_tag"]
    test_path = m["test_path"]
    train_path = m["train_path"]

    if not (os.path.exists(test_path) and os.path.exists(train_path)):
        print(f"[SKIP] missing files for {key}")
        continue

    test = load_jsonl(test_path)
    train = load_jsonl(train_path)

    test_keys = [stable_key(r) for r in test]
    train_keys = set(stable_key(r) for r in train)
    leaked_set = set(k for k in test_keys if k in train_keys)

    test_out = []
    for r in test:
        k = stable_key(r)
        test_out.append({
            "group": g,
            "leak_tag": leak_tag,
            "item_key_type": k[0],
            "item_key": "::".join(map(str, k[1:])),
            "is_leaked_into_train": int(k in leaked_set),

            "task": r.get("task"),
            "local_idx": r.get("local_idx"),
            "label01": r.get("label01"),
            "prompt": r.get("prompt"),
        })

    out_dir = os.path.join(OUT_AUDIT, g, leak_tag)
    os.makedirs(out_dir, exist_ok=True)

    test_audit_path = os.path.join(out_dir, "test_items_with_leak_flag.jsonl")
    write_jsonl(test_audit_path, test_out)

    leaked_only = [r for r in test_out if r["is_leaked_into_train"] == 1]
    leaked_path = os.path.join(out_dir, "leaked_test_items_only.jsonl")
    write_jsonl(leaked_path, leaked_only)

    nonleaked_only = [r for r in test_out if r["is_leaked_into_train"] == 0]
    nonleaked_path = os.path.join(out_dir, "nonleaked_test_items_only.jsonl")
    write_jsonl(nonleaked_path, nonleaked_only)

    index_rows.append({
        "group": g,
        "leak_tag": leak_tag,
        "test_path": test_path,
        "train_path": train_path,
        "audit_test_path": test_audit_path,
        "audit_leaked_only_path": leaked_path,
        "audit_nonleaked_only_path": nonleaked_path,
        "n_test": len(test),
        "n_train": len(train),
        "n_leaked_test_items": len(leaked_only),
        "n_nonleaked_test_items": len(nonleaked_only),
        "pool_size": m.get("pool_size"),
        "target_test": m.get("target_test"),
        "target_train": m.get("target_train"),
        "n_new": m.get("n_new"),
        "n_leak": m.get("n_leak"),
    })

idx_df = pd.DataFrame(index_rows).sort_values(["group","leak_tag"])
idx_csv = os.path.join(OUT_AUDIT, "audit_index.csv")
idx_df.to_csv(idx_csv, index=False)

print("[DONE] Saved audit files under:", OUT_AUDIT)
print("[DONE] Index CSV:", idx_csv)
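
The leak flag written by the audit step is a set-membership test on item keys. A minimal sketch of that idea follows, assuming every row carries `task` and `local_idx`; `toy_key` is a hypothetical stand-in for the notebook's key helper, and the rows are made up.

```python
def toy_key(row):
    # Hypothetical key: mirrors the (task, local_idx) identity used for split rows.
    return ("task_local", str(row["task"]), int(row["local_idx"]))

# Made-up splits: test items 1 and 3 also appear in train.
test_rows = [{"task": "t", "local_idx": i} for i in range(4)]
train_rows = [{"task": "t", "local_idx": i} for i in (1, 3, 7)]

train_keys = {toy_key(r) for r in train_rows}
flags = [int(toy_key(r) in train_keys) for r in test_rows]
print(flags)
```

Building the train keys as a set keeps each lookup constant-time, which matters once the train split is large.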
# ========== Leak100 item-level alternative: use leak50 nonleaked as generalization baseline ==========
# goal: for each (group, model):
#   gen_baseline = delta_nonleaked at leak50 (if exists)
#   mem_lift_100_est = delta(leak100) - gen_baseline

gen50 = (
    item_te[(item_te["leak"]=="leak50") & (item_te["is_leaked_item"]==0)]
    .rename(columns={"delta_mean":"delta_nonleaked_leak50"})
    [["group","model","delta_nonleaked_leak50"]]
)

delta100 = did[did["leak"]=="leak100"][["group","model","delta","leak_effect"]].rename(columns={"delta":"delta_leak100"})

alt = delta100.merge(gen50, on=["group","model"], how="left")
alt["mem_lift100_minus_gen50"] = alt["delta_leak100"] - alt["delta_nonleaked_leak50"]

print("\n=== Leak100 alternative item-level attribution ===")
print(alt.sort_values(["group","model"])[
    ["group","model","delta_leak100","leak_effect","delta_nonleaked_leak50","mem_lift100_minus_gen50"]
].to_string(index=False))

alt.to_csv(os.path.join(OUT_ROOT, "analysis_leak100_alt_itemlevel.csv"), index=False)
print("\nSaved: analysis_leak100_alt_itemlevel.csv")



import os, json, glob
import pandas as pd
import numpy as np

OUT_ROOT = "/content/drive/MyDrive/complexity7/eval_outputs_leak"
RES_CSV  = os.path.join(OUT_ROOT, "results_before_after.csv")
MAN_JSON = os.path.join(OUT_ROOT, "splits", "split_manifest.json")
OUT_CSV  = os.path.join(OUT_ROOT, "master_results_with_trainloss.csv")

# -------------------------
# 0) Load baseline results
# -------------------------
if not os.path.exists(RES_CSV):
    raise FileNotFoundError(f"Missing {RES_CSV}. Run your experiment first.")

res = pd.read_csv(RES_CSV)

# Normalize column names if needed
for col in ["group","leak","model"]:
    if col not in res.columns:
        raise KeyError(f"results_before_after.csv missing required column: {col}")

# -------------------------
# 1) Load split manifest -> table
# -------------------------
if not os.path.exists(MAN_JSON):
    raise FileNotFoundError(f"Missing {MAN_JSON}. Expected splits manifest.")

with open(MAN_JSON, "r", encoding="utf-8") as f:
    manifest = json.load(f)
man_rows = []
for k, m in manifest.items():
    man_rows.append({
        "group": m["group"],
        "leak": m["leak_tag"],   # align with res["leak"]
        "pool_size": m.get("pool_size"),
        "target_test": m.get("target_test"),
        "target_train": m.get("target_train"),
        "n_test_manifest": m.get("n_test"),
        "n_train_manifest": m.get("n_train"),
        "n_new_manifest": m.get("n_new"),
        "n_leak_manifest": m.get("n_leak"),
        "test_path_manifest": m.get("test_path"),
        "train_path_manifest": m.get("train_path"),
    })
man_df = pd.DataFrame(man_rows)

# Merge manifest info
df = res.merge(man_df, on=["group","leak"], how="left")

# -------------------------
# 2) Add DiD leak_effect using leak0 baseline
# -------------------------
# We compute per (group, model) baseline delta at leak0
if "delta_acc_on_fmt" in df.columns:
    df["delta_acc"] = df["delta_acc_on_fmt"]
elif "delta_acc_overall" in df.columns:
    df["delta_acc"] = df["delta_acc_overall"]
else:
    # fallback: compute from before/after if present
    if "before_acc_on_fmt" in df.columns and "after_acc_on_fmt" in df.columns:
        df["delta_acc"] = df["after_acc_on_fmt"] - df["before_acc_on_fmt"]
    elif "before_acc_overall" in df.columns and "after_acc_overall" in df.columns:
        df["delta_acc"] = df["after_acc_overall"] - df["before_acc_overall"]
    else:
        df["delta_acc"] = np.nan

base = (
    df[df["leak"]=="leak0"][["group","model","delta_acc"]]
    .rename(columns={"delta_acc":"delta_acc_leak0"})
)
df = df.merge(base, on=["group","model"], how="left")
df["did_leak_effect_acc"] = df["delta_acc"] - df["delta_acc_leak0"]

# -------------------------
# 3) (Optional) Attach training loss summaries if train_loss.csv exists
# Expected path: OUT_ROOT/train_logs/<group>/<leak>/<model>/train_loss.csv
# -------------------------
loss_paths = glob.glob(os.path.join(OUT_ROOT, "train_logs", "*", "*", "*", "train_loss.csv"))

loss_rows = []
for p in loss_paths:
    # parse path components
    # .../train_logs/<group>/<leak>/<model>/train_loss.csv
    parts = p.split(os.sep)
    try:
        i = parts.index("train_logs")
        g = parts[i+1]
        leak = parts[i+2]
        model = parts[i+3]
    except (ValueError, IndexError):
        # ValueError: no "train_logs" component; IndexError: path too short
        continue

    try:
        tdf = pd.read_csv(p)
    except Exception:
        continue

    if "loss" not in tdf.columns or len(tdf) == 0:
        continue

    tdf = tdf.sort_values("global_step") if "global_step" in tdf.columns else tdf

    loss_min = float(tdf["loss"].min())
    loss_last = float(tdf["loss"].iloc[-1])
    loss_first = float(tdf["loss"].iloc[0])
    loss_drop = loss_first - loss_last

    # optional: epoch_last
    epoch_last = float(tdf["epoch"].iloc[-1]) if "epoch" in tdf.columns and pd.notna(tdf["epoch"].iloc[-1]) else np.nan
    steps = int(tdf["global_step"].max()) if "global_step" in tdf.columns else np.nan

    loss_rows.append({
        "group": g,
        "leak": leak,
        "model": model,
        "train_loss_path": p,
        "train_loss_first": loss_first,
        "train_loss_last": loss_last,
        "train_loss_min": loss_min,
        "train_loss_drop": loss_drop,
        "train_epoch_last": epoch_last,
        "train_global_step_max": steps,
        "train_log_points": int(len(tdf)),
    })

loss_df = pd.DataFrame(loss_rows)

if len(loss_df) > 0:
    df = df.merge(loss_df, on=["group","leak","model"], how="left")
else:
    # create empty columns for consistent schema
    for c in ["train_loss_path","train_loss_first","train_loss_last","train_loss_min","train_loss_drop",
              "train_epoch_last","train_global_step_max","train_log_points"]:
        df[c] = np.nan

# -------------------------
# 4) Final ordering + save
# -------------------------
# Make leak ordering stable
leak_order = ["leak0","leak50","leak100"]
df["leak"] = pd.Categorical(df["leak"], categories=leak_order, ordered=True)

# Put key columns first
front = [
    "group","leak","model","model_id",
    "n_test","n_train",
    "n_test_manifest","n_train_manifest","n_new_manifest","n_leak_manifest",
    "before_acc_on_fmt","after_acc_on_fmt","delta_acc_on_fmt",
    "before_acc_overall","after_acc_overall","delta_acc_overall",
    "delta_acc","delta_acc_leak0","did_leak_effect_acc",
    "train_loss_first","train_loss_last","train_loss_min","train_loss_drop","train_log_points",
    "train_path","test_path","train_path_manifest","test_path_manifest",
    "adapter_dir","before_raw_path","after_raw_path","train_loss_path"
]
front = [c for c in front if c in df.columns]
rest = [c for c in df.columns if c not in front]
df = df[front + rest].sort_values(["group","model","leak"])

df.to_csv(OUT_CSV, index=False)
print("Saved master table:", OUT_CSV)
print("Rows:", len(df))
df.head(10)
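
The training-loss attachment relies on two small steps: recovering `(group, leak, model)` from the log path, and reducing a loss curve to first/last/min/drop. The sketch below illustrates both with the standard library only. The function names, the example path, and the loss values are hypothetical.

```python
import os

def parse_run_from_path(path):
    """Return (group, leak, model) from .../train_logs/<group>/<leak>/<model>/train_loss.csv."""
    parts = os.path.normpath(path).split(os.sep)
    i = parts.index("train_logs")  # raises ValueError if the marker directory is absent
    return parts[i + 1], parts[i + 2], parts[i + 3]

def summarize_loss(points):
    """points: list of (global_step, loss) pairs in any order."""
    ordered = [loss for _, loss in sorted(points)]  # sort by step before taking first/last
    return {
        "first": ordered[0],
        "last": ordered[-1],
        "min": min(ordered),
        "drop": ordered[0] - ordered[-1],
    }

# Hypothetical path and loss values, for illustration only.
example_path = os.path.join("root", "train_logs", "toy_group", "leak50", "toy-model", "train_loss.csv")
run_id = parse_run_from_path(example_path)
loss_summary = summarize_loss([(30, 0.8), (10, 2.0), (20, 0.6)])
print(run_id, loss_summary)
```

Note that `min` and `last` can differ, as in this toy curve, so a run whose final loss rebounds is still visible in the master table.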


In [None]:
import os, json, glob, re, hashlib
import numpy as np
import pandas as pd

OUT_ROOT = "/content/drive/MyDrive/complexity7/eval_outputs_leak"
MAN_PATH = os.path.join(OUT_ROOT, "splits", "split_manifest.json")
RAW_GLOB = os.path.join(OUT_ROOT, "raw_eval", "*", "*", "*", "*.jsonl")
RES_CSV  = os.path.join(OUT_ROOT, "results_before_after.csv")

OUT_ITEM_CSV   = os.path.join(OUT_ROOT, "baseline_item_features.csv")
OUT_METRIC_CSV = os.path.join(OUT_ROOT, "baseline_detector_metrics.csv")
OUT_RUN_CSV    = os.path.join(OUT_ROOT, "baseline_runlevel_compare.csv")

# --------------------------
# IO helpers
# --------------------------
def read_jsonl(path):
    rows = []
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            rows.append(json.loads(line))
    return rows

def normalize_text(s: str) -> str:
    s = (s or "").strip().lower()
    s = re.sub(r"\s+", " ", s)
    return s

def prompt_md5(s: str) -> str:
    return hashlib.md5(normalize_text(s).encode("utf-8")).hexdigest()

def item_key_from_splitrow(r):
    # Prefer task+local_idx; fallback to prompt hash
    t = r.get("task", None)
    li = r.get("local_idx", None)
    if t is not None and li is not None:
        return ("task_local", str(t), int(li))
    return ("prompt_md5", prompt_md5(r.get("prompt","")))

def item_key_from_rawrow(row, test_idx_map=None):
    # Prefer task+local_idx; fallback to (task, idx) mapping; fallback to idx only
    t = row.get("task", None)
    li = row.get("local_idx", None)
    if pd.notna(t) and pd.notna(li):
        return ("task_local", str(t), int(li))
    raw_idx = row.get("idx", -1)
    idx = int(raw_idx) if pd.notna(raw_idx) else -1  # idx can be missing/NaN in a DataFrame row
    if test_idx_map is not None and idx in test_idx_map:
        # map idx -> task_local if possible, else prompt_md5
        return test_idx_map[idx]
    return ("fallback_idx", idx)

# --------------------------
# Basic AUC (no sklearn)
# --------------------------
def roc_auc_score(y_true, y_score):
    """
    Mann-Whitney U / rank-based AUROC.
    Returns NaN if only one class present.
    """
    y_true = np.asarray(y_true).astype(int)
    y_score = np.asarray(y_score).astype(float)
    mask = np.isfinite(y_score)
    y_true = y_true[mask]
    y_score = y_score[mask]
    if len(y_true) == 0:
        return np.nan
    n_pos = np.sum(y_true == 1)
    n_neg = np.sum(y_true == 0)
    if n_pos == 0 or n_neg == 0:
        return np.nan
    # rank scores (average ranks for ties)
    order = np.argsort(y_score)
    ranks = np.empty_like(order, dtype=float)
    ranks[order] = np.arange(1, len(y_score) + 1)

    # tie handling: average ranks for equal scores
    # group equal scores
    sorted_scores = y_score[order]
    i = 0
    while i < len(sorted_scores):
        j = i
        while j + 1 < len(sorted_scores) and sorted_scores[j + 1] == sorted_scores[i]:
            j += 1
        if j > i:
            avg_rank = (i + 1 + j + 1) / 2.0
            ranks[order[i:j+1]] = avg_rank
        i = j + 1

    sum_ranks_pos = np.sum(ranks[y_true == 1])
    auc = (sum_ranks_pos - n_pos * (n_pos + 1) / 2.0) / (n_pos * n_neg)
    return float(auc)

# --------------------------
# N-gram overlap
# --------------------------
_word_re = re.compile(r"[a-z0-9]+")

def tokenize_words(s: str):
    return _word_re.findall((s or "").lower())

def ngram_set(tokens, n):
    if len(tokens) < n:
        return set()
    return set(tuple(tokens[i:i+n]) for i in range(len(tokens)-n+1))

def jaccard(a: set, b: set):
    if not a and not b:
        return 1.0
    if not a or not b:
        return 0.0
    inter = len(a & b)
    union = len(a) + len(b) - inter
    return inter / union if union else 0.0

def max_jaccard_against_train(test_ng, train_ng_list):
    # return max Jaccard over all train items
    best = 0.0
    for ng in train_ng_list:
        score = jaccard(test_ng, ng)
        if score > best:
            best = score
            if best >= 0.999999:
                break
    return best

# --------------------------
# 1) Load manifest, build (group, leak) -> test/train rows and leaked set
# --------------------------
if not os.path.exists(MAN_PATH):
    raise FileNotFoundError(f"Missing {MAN_PATH}")

with open(MAN_PATH, "r", encoding="utf-8") as f:
    manifest = json.load(f)

# For each group, store test rows + idx->key map (to align raw_eval rows)
group_test_rows = {}
group_test_idx_map = {}  # group -> {idx: item_key}
# For each (group, leak): leaked set + train prompts
leaked_sets = {}
train_prompts = {}       # (group, leak) -> list[str]
test_prompts  = {}       # (group, leak) -> list[str] (same as group test, but keep aligned)
test_keys     = {}       # (group, leak) -> list[item_key] aligned to test_rows order

for k, m in manifest.items():
    g = m["group"]
    leak = m["leak_tag"]
    test = read_jsonl(m["test_path"])
    train = read_jsonl(m["train_path"])

    # cache group test (same across leaks)
    if g not in group_test_rows:
        group_test_rows[g] = test
        idx_map = {}
        for i, r in enumerate(test):
            idx_map[i] = item_key_from_splitrow(r)
        group_test_idx_map[g] = idx_map

    # compute leaked = test ∩ train
    test_set = set(item_key_from_splitrow(r) for r in test)
    train_set = set(item_key_from_splitrow(r) for r in train)
    leaked = test_set & train_set
    leaked_sets[(g, leak)] = leaked

    # store prompts aligned to test
    test_prompts[(g, leak)] = [r.get("prompt","") for r in test]
    test_keys[(g, leak)] = [item_key_from_splitrow(r) for r in test]

    # store train prompts
    train_prompts[(g, leak)] = [r.get("prompt","") for r in train]

print("[INFO] Loaded manifest. Groups:", sorted(group_test_rows.keys()))

# --------------------------
# 2) Load raw_eval and enrich each row with: prompt, is_leaked_item, label01 (from test split)
# --------------------------
raw_paths = glob.glob(RAW_GLOB)
if not raw_paths:
    raise FileNotFoundError(f"No raw_eval found under {RAW_GLOB}")

raw_rows = []
for p in raw_paths:
    raw_rows.extend(read_jsonl(p))

df = pd.DataFrame(raw_rows)

required_cols = ["group","leak","model","phase"]
for c in required_cols:
    if c not in df.columns:
        raise KeyError(f"raw_eval missing required column: {c}")

# numeric correctness fields
df["solver_correct"] = pd.to_numeric(df.get("solver_correct"), errors="coerce")
df["logp_correct"]   = pd.to_numeric(df.get("logp_correct"), errors="coerce") if "logp_correct" in df.columns else np.nan

# attach prompt/label from test split
def attach_from_test(row):
    g = row["group"]
    leak = row["leak"]
    raw_idx = row.get("idx", -1)
    idx = int(raw_idx) if pd.notna(raw_idx) else -1  # guard against missing/NaN idx
    # locate test row by idx (most stable in your pipeline)
    tr = None
    if g in group_test_rows and 0 <= idx < len(group_test_rows[g]):
        tr = group_test_rows[g][idx]
    prompt = tr.get("prompt","") if tr else ""
    label01 = tr.get("label01", None) if tr else None

    # is_leaked depends on leak tag
    key = item_key_from_rawrow(row, test_idx_map=group_test_idx_map.get(g))
    leaked = leaked_sets.get((g, leak), set())
    is_leaked = int(key in leaked) if leaked else 0

    return pd.Series({
        "prompt": prompt,
        "label01_from_split": label01,
        "item_key_type": key[0],
        "item_key": "::".join(map(str, key[1:])),
        "is_leaked_item": is_leaked,
        "prompt_md5": prompt_md5(prompt),
    })

enriched = df.apply(attach_from_test, axis=1)
df = pd.concat([df, enriched], axis=1)

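# self-report score
# 3=SEEN, 5=SIMILAR, 7=NOT SEEN -> map to higher = "more seen"
map_seen = {"3": 2.0, "5": 1.0, "7": 0.0}
if "contam_pred" in df.columns:
    # cast to str so integer-coded predictions (3/5/7) map the same way as "3"/"5"/"7"
    df["seen_score"] = df["contam_pred"].astype(str).str.strip().map(map_seen).astype(float)
else:
    # no self-report column in raw_eval: keep the schema, leave the score undefined
    df["seen_score"] = np.nan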

# keep only rows where we have solver_correct (since that's your main evaluation)
df_eval = df[df["solver_correct"].notna()].copy()

print("[INFO] raw_eval rows:", len(df), "eval rows with solver_correct:", len(df_eval))

# --------------------------
# 3) Build n-gram overlap scores per (group, leak) on the TEST split
#    Score is computed per test item, then merged to df_eval via (group, leak, idx)
# --------------------------
NGRAM_N = 5  # you can try 3/4/5
overlap_rows = []
for (g, leak), tr_prompts in train_prompts.items():
    te_prompts = test_prompts[(g, leak)]
    te_keys = test_keys[(g, leak)]
    # precompute train ngram sets
    train_ng = []
    train_md5 = set()
    for p in tr_prompts:
        norm = normalize_text(p)
        train_md5.add(prompt_md5(norm))
        toks = tokenize_words(norm)
        train_ng.append(ngram_set(toks, NGRAM_N))

    # compute per test item
    for idx, p in enumerate(te_prompts):
        norm = normalize_text(p)
        md5 = prompt_md5(norm)
        toks = tokenize_words(norm)
        te_ng = ngram_set(toks, NGRAM_N)
        max_j = max_jaccard_against_train(te_ng, train_ng) if train_ng else 0.0
        exact = int(md5 in train_md5)
        # true leaked label via key intersection (your controlled overlap)
        key = te_keys[idx]
        is_leaked = int(key in leaked_sets[(g, leak)])
        overlap_rows.append({
            "group": g,
            "leak": leak,
            "idx": idx,
            "item_key_type": key[0],
            "item_key": "::".join(map(str, key[1:])),
            "is_leaked_item": is_leaked,
            "ngram_n": NGRAM_N,
            "ngram_max_jaccard": float(max_j),
            "exact_prompt_match": exact,
        })

ov = pd.DataFrame(overlap_rows)
print("[INFO] Computed ngram overlap rows:", len(ov))

# merge overlap features into df_eval
df_eval = df_eval.merge(
    ov[["group","leak","idx","ngram_max_jaccard","exact_prompt_match"]],
    on=["group","leak","idx"],
    how="left"
)

# --------------------------
# 4) Evaluate detectors on leak50 only (since leak0 has no positives, leak100 has no negatives)
#    We compute AUROC for predicting is_leaked_item.
#    Detectors:
#      - ngram_max_jaccard
#      - exact_prompt_match
#      - seen_score (3/5/7)
#      - (optional) logp_margin if present
# --------------------------
df50 = df_eval[df_eval["leak"]=="leak50"].copy()

# If you saved logp_margin in raw_eval, include it
has_logp_margin = "logp_margin" in df50.columns
if has_logp_margin:
    df50["logp_margin"] = pd.to_numeric(df50["logp_margin"], errors="coerce")

metric_rows = []
for (g, model, phase), sub in df50.groupby(["group","model","phase"]):
    y = sub["is_leaked_item"].astype(int).values

    # ngram detector
    auc_ng = roc_auc_score(y, sub["ngram_max_jaccard"].values)

    # exact match detector (binary score; AUROC is defined but tie-heavy)
    auc_ex = roc_auc_score(y, sub["exact_prompt_match"].values)

    # self report
    auc_seen = roc_auc_score(y, sub["seen_score"].values) if sub["seen_score"].notna().any() else np.nan

    row = {
        "group": g,
        "model": model,
        "phase": phase,
        "n_items": int(len(sub)),
        "pos_leaked": int(np.sum(y==1)),
        "neg_nonleaked": int(np.sum(y==0)),
        "auc_ngram_max_jaccard": auc_ng,
        "auc_exact_prompt_match": auc_ex,
        "auc_seen_score_357": auc_seen,
    }

    if has_logp_margin:
        row["auc_logp_margin"] = roc_auc_score(y, sub["logp_margin"].values)
    metric_rows.append(row)

metrics = pd.DataFrame(metric_rows).sort_values(["group","model","phase"])
metrics.to_csv(OUT_METRIC_CSV, index=False)

# --------------------------
# 5) Run-level compare: DiD leak_effect (from results_before_after.csv) vs detectors
#    - load results_before_after.csv
#    - compute leak_effect on accuracy (delta - delta(leak0))
#    - attach detector AUROC (leak50 only) for before/after, for context
# --------------------------
if not os.path.exists(RES_CSV):
    print("[WARN] results_before_after.csv not found; run-level DiD compare will be skipped.")
    run_cmp = pd.DataFrame()
else:
    res = pd.read_csv(RES_CSV)
    # choose an accuracy column
    if "delta_acc_on_fmt" in res.columns:
        res["delta_acc"] = res["delta_acc_on_fmt"]
    elif "delta_acc_overall" in res.columns:
        res["delta_acc"] = res["delta_acc_overall"]
    else:
        # fallback
        if "before_acc_on_fmt" in res.columns and "after_acc_on_fmt" in res.columns:
            res["delta_acc"] = res["after_acc_on_fmt"] - res["before_acc_on_fmt"]
        elif "before_acc_overall" in res.columns and "after_acc_overall" in res.columns:
            res["delta_acc"] = res["after_acc_overall"] - res["before_acc_overall"]
        else:
            res["delta_acc"] = np.nan

    base = res[res["leak"]=="leak0"][["group","model","delta_acc"]].rename(columns={"delta_acc":"delta_acc_leak0"})
    res = res.merge(base, on=["group","model"], how="left")
    res["did_leak_effect_acc"] = res["delta_acc"] - res["delta_acc_leak0"]

    # attach detector metrics (use leak50 only; for each phase)
    # merge by (group, model, phase)
    # (res is run-level without phase; we attach both before/after AUC columns)
    m_before = metrics[metrics["phase"]=="before"].copy()
    m_after  = metrics[metrics["phase"]=="after"].copy()

    def rename_auc(dfm, suffix):
        cols = [c for c in dfm.columns if c.startswith("auc_")]
        ren = {c: f"{c}_{suffix}" for c in cols}
        return dfm.rename(columns=ren)[["group","model"] + list(ren.values())]

    res = res.merge(rename_auc(m_before, "leak50_before"), on=["group","model"], how="left")
    res = res.merge(rename_auc(m_after,  "leak50_after"),  on=["group","model"], how="left")

    run_cmp = res.sort_values(["group","model","leak"])
    run_cmp.to_csv(OUT_RUN_CSV, index=False)

# --------------------------
# 6) Save per-item features table (for paper plots / ablations)
# --------------------------
# Keep one row per item per run (group, leak, model, phase, idx)
keep_cols = [
    "group","leak","model","phase","idx",
    "task","local_idx",
    "item_key_type","item_key",
    "is_leaked_item",
    "solver_correct",
    "seen_score",
    "ngram_max_jaccard","exact_prompt_match",
    "prompt_md5","prompt"
]
# include optional logp fields if present
for c in ["logp_margin","logp_pred01","logp_correct","logp_yes","logp_no"]:
    if c in df_eval.columns:
        keep_cols.append(c)

keep_cols = [c for c in keep_cols if c in df_eval.columns]
item_out = df_eval[keep_cols].copy()
item_out.to_csv(OUT_ITEM_CSV, index=False)

print("\n[DONE] Saved:")
print("  Per-item features:", OUT_ITEM_CSV)
print("  Detector metrics (leak50 AUROC):", OUT_METRIC_CSV)
if len(run_cmp) > 0:
    print("  Run-level compare (DiD + detector AUC):", OUT_RUN_CSV)
else:
    print("  Run-level compare skipped (missing results_before_after.csv).")

# --------------------------
# 7) Quick textual summary in console
# --------------------------
print("\n=== QUICK SUMMARY: leak50 detector AUROC (higher is better for detecting leaked items) ===")
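# Only report a logp column when the detector was actually computed; otherwise the
# summary would silently repeat the seen-score AUROC under the name "auc_logp".
agg_spec = {
    "auc_ngram": ("auc_ngram_max_jaccard","mean"),
    "auc_exact": ("auc_exact_prompt_match","mean"),
    "auc_seen": ("auc_seen_score_357","mean"),
}
if "auc_logp_margin" in metrics.columns:
    agg_spec["auc_logp"] = ("auc_logp_margin","mean")
summary = metrics.groupby(["group","phase"], as_index=False).agg(**agg_spec)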
print(summary.to_string(index=False))

print("\nNOTE:")
print("- AUROC is computed only on leak50 because leak0 has no positives and leak100 has no negatives.")
print("- If logp_margin columns are absent in raw_eval, logp-based detector is skipped automatically.")
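
The n-gram detector above scores each test prompt by its best word-level Jaccard overlap with any train prompt. A minimal standalone sketch of that score follows. The prompts are invented, the sketch uses trigrams instead of the notebook's `NGRAM_N = 5`, and `word_ngrams` / `max_overlap` are hypothetical names.

```python
import re

_WORD = re.compile(r"[a-z0-9]+")

def word_ngrams(text, n):
    """Set of lowercase word n-grams; empty when the text has fewer than n tokens."""
    toks = _WORD.findall(text.lower())
    return {tuple(toks[i:i + n]) for i in range(len(toks) - n + 1)}

def jaccard(a, b):
    if not a and not b:
        return 1.0  # same convention as the notebook: two empty sets count as identical
    return len(a & b) / len(a | b)

def max_overlap(test_prompt, train_prompts, n=3):
    """Best Jaccard score of the test prompt against any train prompt."""
    te = word_ngrams(test_prompt, n)
    return max((jaccard(te, word_ngrams(p, n)) for p in train_prompts), default=0.0)

# Invented prompts for illustration.
train = ["Tom has 3 apples and buys 2 more apples"]
exact_score = max_overlap("Tom has 3 apples and buys 2 more apples", train)
near_score = max_overlap("Tom has 3 apples and sells 2 of them", train)
print(exact_score, near_score)
```

One caveat of the empty-set convention: two prompts that are both shorter than `n` tokens score 1.0 against each other even when they share no words, so very short prompts may deserve a smaller `n` or a separate check.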

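
The rank-based AUROC used above is equivalent to counting, over all (positive, negative) pairs, how often the positive gets the higher score, with ties counted as half. The brute-force version below is a sketch for cross-checking the rank formula on small inputs. `pairwise_auc` is a hypothetical name, and the function is quadratic, so it is only meant for spot checks.

```python
def pairwise_auc(y_true, y_score):
    """AUROC as the fraction of (positive, negative) pairs ranked correctly; ties count 0.5."""
    pos = [s for y, s in zip(y_true, y_score) if y == 1]
    neg = [s for y, s in zip(y_true, y_score) if y == 0]
    if not pos or not neg:
        return float("nan")  # undefined when only one class is present
    wins = sum((p > n) + 0.5 * (p == n) for p in pos for n in neg)
    return wins / (len(pos) * len(neg))

auc_mixed = pairwise_auc([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8])
auc_tied = pairwise_auc([0, 1], [0.5, 0.5])
auc_perfect = pairwise_auc([0, 0, 1], [0.1, 0.2, 0.9])
print(auc_mixed, auc_tied, auc_perfect)
```

A binary detector such as `exact_prompt_match` produces many ties, which is why the half-credit rule (average ranks in the rank formulation) matters for its AUROC.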
import os
import numpy as np
import pandas as pd

OUT_ROOT = "/content/drive/MyDrive/complexity7/eval_outputs_leak"
RUN_CSV = os.path.join(OUT_ROOT, "baseline_runlevel_compare.csv")

df = pd.read_csv(RUN_CSV)

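# pick one row per (group, model): use leak100 row as representative for DiD (effect size)
# Note: groupby(...).first() returns the first non-null value per column, which can mix
# values from different leak rows; drop_duplicates keeps the whole first row instead.
df["leak"] = pd.Categorical(df["leak"], categories=["leak100","leak50","leak0"], ordered=True)
one = (
    df.sort_values(["group","model","leak"])
      .drop_duplicates(subset=["group","model"], keep="first")
      .reset_index(drop=True)
)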

# choose AUROC column (ngram after) if exists
auc_col = "auc_ngram_max_jaccard_leak50_after"
if auc_col not in one.columns:
    raise KeyError(f"{auc_col} not found in baseline_runlevel_compare.csv")

x = pd.to_numeric(one["did_leak_effect_acc"], errors="coerce")
y = pd.to_numeric(one[auc_col], errors="coerce")
mask = x.notna() & y.notna()
x = x[mask].values
y = y[mask].values

def safe_corr(x, y):
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    if len(x) < 3:
        return (np.nan, np.nan, "too_few_points")
    if np.std(x) == 0 or np.std(y) == 0:
        return (np.nan, np.nan, "zero_variance")
    pearson = float(np.corrcoef(x, y)[0,1])

    # Spearman via ranks (average ranks)
    rx = pd.Series(x).rank(method="average").values
    ry = pd.Series(y).rank(method="average").values
    if np.std(rx) == 0 or np.std(ry) == 0:
        spearman = np.nan
    else:
        spearman = float(np.corrcoef(rx, ry)[0,1])
    return (pearson, spearman, "ok")

pearson_all, spearman_all, status_all = safe_corr(x, y)

print("\n=== Overall correlation (DiD contribution vs AUROC detector) ===")
print(f"col={auc_col}")
print(f"n={len(x)}, status={status_all}, pearson={pearson_all}, spearman={spearman_all}")

# Per-group correlations (skip groups where AUROC constant or too few)
print("\n=== Per-group correlations ===")
rows = []
for g in sorted(one["group"].unique()):
    sub = one[one["group"]==g]
    xx = pd.to_numeric(sub["did_leak_effect_acc"], errors="coerce")
    yy = pd.to_numeric(sub[auc_col], errors="coerce")
    m = xx.notna() & yy.notna()
    xx = xx[m].values
    yy = yy[m].values
    p, s, st = safe_corr(xx, yy)
    rows.append({"group": g, "n": len(xx), "status": st, "pearson": p, "spearman": s,
                 "auc_min": float(np.nanmin(yy)) if len(yy) else np.nan,
                 "auc_max": float(np.nanmax(yy)) if len(yy) else np.nan,
                 "did_min": float(np.nanmin(xx)) if len(xx) else np.nan,
                 "did_max": float(np.nanmax(xx)) if len(xx) else np.nan})

rep = pd.DataFrame(rows)
print(rep.to_string(index=False))

# "oracle detection but zero contribution" cases + fraction
thr_auc = 0.99
thr_did = 0.01
oracle_zero = one[(pd.to_numeric(one[auc_col], errors="coerce") >= thr_auc) &
                  (pd.to_numeric(one["did_leak_effect_acc"], errors="coerce").abs() <= thr_did)].copy()

print(f"\n=== Oracle-detection but ~zero contribution cases (AUROC>={thr_auc}, |DiD|<={thr_did}) ===")
print(f"count={len(oracle_zero)} out of total group-model pairs={len(one)} "
      f"({len(oracle_zero)/max(1,len(one))*100:.1f}%)")
print(oracle_zero.sort_values(["group","model"])[["group","model",auc_col,"did_leak_effect_acc"]].to_string(index=False))

# Save the report
out_path = os.path.join(OUT_ROOT, "corr_detection_vs_contribution_clean.csv")
rep.to_csv(out_path, index=False)
print("\nSaved per-group corr report:", out_path)


In [None]:
import os
import numpy as np
import pandas as pd

OUT_ROOT = "/content/drive/MyDrive/complexity7/eval_outputs_leak"
RUN_CSV = os.path.join(OUT_ROOT, "baseline_runlevel_compare.csv")

df = pd.read_csv(RUN_CSV)

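# One row per (group, model): use leak100 row (effect size). AUROC columns are leak50-based but repeated.
# Note: GroupBy.first() returns the first non-null value per column, so a NaN in the leak100 row
# is filled from the next leak level in the sort order (leak50, then leak0).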
df["leak"] = pd.Categorical(df["leak"], categories=["leak100","leak50","leak0"], ordered=True)
one = (
    df.sort_values(["group","model","leak"])
      .groupby(["group","model"], as_index=False)
      .first()
)

auc_col = "auc_ngram_max_jaccard_leak50_after"
if auc_col not in one.columns:
    raise KeyError(f"{auc_col} not found in baseline_runlevel_compare.csv")

# Clean numeric
one["did_leak_effect_acc"] = pd.to_numeric(one["did_leak_effect_acc"], errors="coerce")
one[auc_col] = pd.to_numeric(one[auc_col], errors="coerce")

# Oracle-but-zero definition
THR_AUC = 0.99
THR_DID = 0.01
one["oracle_detect"] = (one[auc_col] >= THR_AUC).astype(int)
one["zero_contrib"] = (one["did_leak_effect_acc"].abs() <= THR_DID).astype(int)
one["oracle_but_zero"] = ((one["oracle_detect"] == 1) & (one["zero_contrib"] == 1)).astype(int)

# Group-level summary: AUROC mean/min/max; DiD min/max; count oracle-but-zero; N models
grp = (
    one.groupby("group", as_index=False)
       .agg(
           n_models=("model","nunique"),
           auc_mean=(auc_col,"mean"),
           auc_min=(auc_col,"min"),
           auc_max=(auc_col,"max"),
           did_min=("did_leak_effect_acc","min"),
           did_max=("did_leak_effect_acc","max"),
           oracle_but_zero=("oracle_but_zero","sum"),
       )
)

# Overall row; for ALL, n_models holds the number of (group, model) pairs
overall = pd.DataFrame([{
    "group": "ALL",
    "n_models": len(one),
    "auc_mean": one[auc_col].mean(),
    "auc_min": one[auc_col].min(),
    "auc_max": one[auc_col].max(),
    "did_min": one["did_leak_effect_acc"].min(),
    "did_max": one["did_leak_effect_acc"].max(),
    "oracle_but_zero": int(one["oracle_but_zero"].sum()),
}])

table = pd.concat([grp, overall], ignore_index=True)

# Formatting helpers
def f3(x):
    if pd.isna(x): return "--"
    return f"{x:.3f}"

def did_range(a, b):
    if pd.isna(a) or pd.isna(b): return "--"
    return f"[{a:.3f}, {b:.3f}]"

def auc_triplet(mean_, mn, mx):
    if pd.isna(mean_) or pd.isna(mn) or pd.isna(mx): return "--"
    # show mean (min–max)
    return f"{mean_:.3f} ({mn:.3f}--{mx:.3f})"

# Build LaTeX rows
rows = []
for _, r in table.iterrows():
    group = r["group"]
    n = int(r["n_models"])
    auc_str = auc_triplet(r["auc_mean"], r["auc_min"], r["auc_max"])
    did_str = did_range(r["did_min"], r["did_max"])
    obz = int(r["oracle_but_zero"])
    rows.append((group, n, auc_str, did_str, obz))

latex_lines = []
latex_lines.append(r"\begin{table}[t]")
latex_lines.append(r"\centering")
latex_lines.append(r"\small")
latex_lines.append(r"\begin{tabular}{lrrrr}")
latex_lines.append(r"\toprule")
latex_lines.append(r"Stratum & \#Pairs & AUROC$_{\text{ngram}}$ (mean (min--max)) & DiD Range & Oracle$\wedge$Zero \\")
latex_lines.append(r"\midrule")

for group, n, auc_str, did_str, obz in rows:
    # escape underscores if any
    g = str(group).replace("_", r"\_")
    latex_lines.append(f"{g} & {n} & {auc_str} & {did_str} & {obz} \\\\")
latex_lines.append(r"\bottomrule")
latex_lines.append(r"\end{tabular}")
latex_lines.append(r"\caption{Detection vs. attribution. AUROC measures how well lexical overlap detects leaked items (computed on the 50\% overlap setting). DiD range reports the min/max causal leakage effect on accuracy across models (computed using 100\% vs 0\% overlap). Oracle$\wedge$Zero counts cases with AUROC$\geq$0.99 but $|$DiD$|\leq$0.01, illustrating detection does not imply contribution.}")
latex_lines.append(r"\label{tab:detection_vs_attribution}")
latex_lines.append(r"\end{table}")

latex = "\n".join(latex_lines)

out_tex = os.path.join(OUT_ROOT, "latex_detection_vs_attribution_table.tex")
with open(out_tex, "w", encoding="utf-8") as f:
    f.write(latex)

print(latex)
print("\nSaved LaTeX table to:", out_tex)

# ============================
# Anchor Certificate (Mainline-2 Only)
# - No interventions, no injection experiments
# - Compute p(4) on anchor stratum (default: group="sgu")
# - Certify leakage lower bound LB_E(u)=max(0,p-u) and FracLeakLB=LB_E/p
# - Sensitivity over u in [U_MIN, U_MAX]
# - Optional stratification by self-report S in {3,5,7} (contam_pred)
# ============================



In [None]:
import os
import numpy as np
import pandas as pd

OUT_ROOT = "/content/drive/MyDrive/complexity7/eval_outputs_leak"
RES_CSV  = os.path.join(OUT_ROOT, "results_before_after.csv")

ANCHOR_GROUP = "sgu"
LEAK_USE  = "leak0"
PHASE_USE = "before"

U_MIN = 0.50
U_MAX = 0.85
U_STAR = 0.68

OUT_CERT_CSV = os.path.join(OUT_ROOT, "anchor_certificates_simple.csv")
OUT_CERT_TXT = os.path.join(OUT_ROOT, "anchor_certificates_simple.txt")

if not os.path.exists(RES_CSV):
    raise FileNotFoundError(f"Missing {RES_CSV}")

df = pd.read_csv(RES_CSV)

def pick_p_col(phase):
    for c in [f"{phase}_acc_overall", f"{phase}_acc_on_fmt"]:
        if c in df.columns:
            return c
    raise KeyError(f"Neither {phase}_acc_overall nor {phase}_acc_on_fmt found.")

p_col = pick_p_col(PHASE_USE)

sub = df[(df["group"] == ANCHOR_GROUP) & (df["leak"] == LEAK_USE)].copy()
if sub.empty:
    raise ValueError(f"No rows for group={ANCHOR_GROUP}, leak={LEAK_USE}")

sub[p_col] = pd.to_numeric(sub[p_col], errors="coerce")
sub = sub.dropna(subset=[p_col])

def lb_e(p, u):
    return max(0.0, float(p) - float(u))

rows = []
for _, r in sub.iterrows():
    model = r["model"]
    p = float(r[p_col])
    rows.append({
        "group": ANCHOR_GROUP,
        "leak": LEAK_USE,
        "phase": PHASE_USE,
        "model": model,
        "p4": p,
        "u_min": U_MIN,
        "u_max": U_MAX,
        "LB_E_at_u_min": lb_e(p, U_MIN),
        "LB_E_at_u_max": lb_e(p, U_MAX),
        "FracLeakLB_at_u_min": (lb_e(p, U_MIN)/p) if p>0 else np.nan,
        "FracLeakLB_at_u_max": (lb_e(p, U_MAX)/p) if p>0 else np.nan,
        "u_cert_max": p
    })

cert = pd.DataFrame(rows).sort_values("model").reset_index(drop=True)

# mean row
p_bar = float(cert["p4"].mean())
mean_row = {
    "group": ANCHOR_GROUP,
    "leak": LEAK_USE,
    "phase": PHASE_USE,
    "model": "__MEAN__",
    "p4": p_bar,
    "u_min": U_MIN,
    "u_max": U_MAX,
    "LB_E_at_u_min": lb_e(p_bar, U_MIN),
    "LB_E_at_u_max": lb_e(p_bar, U_MAX),
    "FracLeakLB_at_u_min": (lb_e(p_bar, U_MIN)/p_bar) if p_bar>0 else np.nan,
    "FracLeakLB_at_u_max": (lb_e(p_bar, U_MAX)/p_bar) if p_bar>0 else np.nan,
    "u_cert_max": p_bar
}
cert = pd.concat([cert, pd.DataFrame([mean_row])], ignore_index=True)

# save CSV
cert.to_csv(OUT_CERT_CSV, index=False)

# save a plain-text report
lines = []
lines.append("ANCHOR CERTIFICATE (no-JS mode)")
lines.append(f"anchor_group={ANCHOR_GROUP}, leak={LEAK_USE}, phase={PHASE_USE}")
lines.append(f"p_col={p_col}")
lines.append(f"u_range=[{U_MIN},{U_MAX}], u_star={U_STAR}")
lines.append("")
lines.append("TABLE:")
lines.append(cert.to_string(index=False))
lines.append("")
lines.append("CERTIFICATE STATEMENTS:")
for _, r in cert.iterrows():
    if r["model"] == "__MEAN__":
        continue
    p = float(r["p4"])
    lb = lb_e(p, U_STAR)
    frac = (lb/p) if p>0 else np.nan
    lines.append(
        f"[CERT] model={r['model']}: p(4)={p:.3f}. "
        f"For u <= {U_STAR:.2f}, certify Pr(E=1|C=4) >= {lb:.3f} "
        f"(FracLeakLB >= {frac:.3f})."
    )
lb_mean = lb_e(p_bar, U_STAR)
frac_mean = (lb_mean/p_bar) if p_bar>0 else np.nan
lines.append(
    f"[CERT] __MEAN__: p(4)={p_bar:.3f}. For u <= {U_STAR:.2f}, "
    f"certify Pr(E=1|C=4) >= {lb_mean:.3f} (FracLeakLB >= {frac_mean:.3f})."
)

os.makedirs(OUT_ROOT, exist_ok=True)
with open(OUT_CERT_TXT, "w", encoding="utf-8") as f:
    f.write("\n".join(lines))

print("Saved CSV :", OUT_CERT_CSV)
print("Saved TXT :", OUT_CERT_TXT)
print("\n(If Colab output is broken, open the TXT in Drive to view the table.)")


In [None]:
import matplotlib.pyplot as plt
import numpy as np

U_GRID = np.linspace(U_MIN, U_MAX, 71)

def certify_curve(p):
    lb = np.maximum(0.0, p - U_GRID)
    frac = (lb / p) if p > 0 else np.full_like(lb, np.nan)
    return lb, frac

plt.figure()
for _, r in cert.iterrows():
    if r["model"] == "__MEAN__":
        continue
    lb, _ = certify_curve(float(r["p4"]))
    plt.plot(U_GRID, lb, label=r["model"])
plt.xlabel("u"); plt.ylabel("LB_E(u)")
plt.title(f"Anchor LB_E(u) on {ANCHOR_GROUP} ({LEAK_USE}/{PHASE_USE})")
plt.legend(fontsize=7)

png1 = os.path.join(OUT_ROOT, "anchor_LB_E_curve.png")
plt.savefig(png1, dpi=200, bbox_inches="tight")
plt.close()

plt.figure()
for _, r in cert.iterrows():
    if r["model"] == "__MEAN__":
        continue
    _, frac = certify_curve(float(r["p4"]))
    plt.plot(U_GRID, frac, label=r["model"])
plt.xlabel("u"); plt.ylabel("FracLeakLB(u)")
plt.title(f"Anchor FracLeakLB(u) on {ANCHOR_GROUP} ({LEAK_USE}/{PHASE_USE})")
plt.legend(fontsize=7)

png2 = os.path.join(OUT_ROOT, "anchor_FracLeakLB_curve.png")
plt.savefig(png2, dpi=200, bbox_inches="tight")
plt.close()

print("Saved PNG:", png1)
print("Saved PNG:", png2)


## 1) Global Config: paths, models, datasets
Edit this cell first. It collects the paths, model list, dataset switches, and evaluation budget in one place.


In [None]:

# -----------------------------
# Paths: support your Drive layout + local fallback
# -----------------------------
BASE_DIR = os.environ.get("BAP_BASE_DIR", "/content/drive/MyDrive/complexity_data6")
HARMONIZED_PATH = os.environ.get("BAP_HARMONIZED_JSONL", os.path.join(BASE_DIR, "all_tasks_harmonized.jsonl"))
RESULTS_DIR = os.environ.get("BAP_RESULTS_DIR", os.path.join(BASE_DIR, "bap_runs_full_coverage"))

os.makedirs(RESULTS_DIR, exist_ok=True)
print("HARMONIZED_PATH:", HARMONIZED_PATH)
print("RESULTS_DIR:", RESULTS_DIR)

# -----------------------------
# Models: all models you used (from your v10/v12)
# -----------------------------
HF_MODELS = [
    "Qwen/Qwen3-4B",
    "Qwen/Qwen2.5-Math-7B-Instruct",
    "deepseek-ai/deepseek-math-7b-instruct",
    "mistralai/Mathstral-7B-v0.1",
    "nvidia/AceMath-7B-Instruct",
]

# -----------------------------
# Datasets: all public ones you used + your JSONL
# -----------------------------
RUN_HARMONIZED_JSONL = True
RUN_GSM8K_ANSWER = True
RUN_AQUA_MC = True
RUN_ARC_CHALLENGE_MC = True

# For contamination baselines (optional)
RUN_BASELINE_CORPORA = True

# -----------------------------
# Evaluation budget and seeds
# -----------------------------
N_EVAL = int(os.environ.get("BAP_N_EVAL", "400"))     # per (dataset, model, run)
WITH_REPLACEMENT = True
EVAL_SEEDS = [101, 202, 303]  # replicates

# Public randomness beacon for binding seeds (anti-cherry-pick)
PUBLIC_BEACON = os.environ.get("BAP_PUBLIC_BEACON", "CHANGE_ME_TO_PUBLIC_RANDOMNESS_BEACON")

# Generation configs (kept in config digest)
GEN_CFG_BINARY = dict(max_new_tokens=1, do_sample=False, temperature=0.0)
GEN_CFG_ANSWER = dict(max_new_tokens=256, do_sample=False, temperature=0.0)
GEN_CFG_MC = dict(max_new_tokens=8, do_sample=False, temperature=0.0)


# For the 3-question protocol (asked in separate fresh sessions)
GEN_CFG_SEEN = dict(max_new_tokens=4, do_sample=False, temperature=0.0)
GEN_CFG_LABEL = dict(max_new_tokens=16, do_sample=False, temperature=0.0)
# -----------------------------
# Harmonized task selection
# -----------------------------
# If TASKS=None => run ALL tasks found in JSONL.
TASKS: Optional[List[str]] = None

# Optional SGU slice (from your v12 notebook)
INCLUDE_SGU_SLICE = True
SGU_SUFFIX = "__sgu"
SGU_COMPLEXITY_FAMILY_PATTERNS = [
    "strongly_generically_undecidable",
    "strongly generically undecidable",
    "sgu",
    "undecidable",
]

print("HF_MODELS:", HF_MODELS)



## 2) Utilities: hashing, deterministic seed derivation, sampling
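
As a rough illustration of how the utilities in this section fit together, the sketch below derives a seed from public inputs, samples indices from it, and commits to those indices. The `demo_*` names, the digest strings, and the nonces are placeholders invented for this example. Casting to a fixed little-endian 64-bit dtype before hashing is an assumption made here so that the commitment does not depend on the platform's default integer type.

```python
import hashlib
import numpy as np

def demo_derive_seed(beacon, dataset_digest, model_digest, config_digest, run_nonce):
    # Hash all public inputs together; keep the first 64 bits as the RNG seed.
    msg = "|".join([beacon, dataset_digest, model_digest, config_digest, run_nonce])
    return int(hashlib.sha256(msg.encode("utf-8")).hexdigest()[:16], 16)

def demo_sample(seed, N, n):
    # Sampling with replacement from range(N), fully determined by the seed.
    return np.random.default_rng(seed).integers(0, N, size=n)

def demo_commit(indices):
    # Assumption: fixed "<i8" dtype so the same indices hash identically everywhere.
    return hashlib.sha256(np.asarray(indices, dtype="<i8").tobytes()).hexdigest()

seed_a = demo_derive_seed("beacon-placeholder", "ds", "model", "cfg", "run-1")
seed_b = demo_derive_seed("beacon-placeholder", "ds", "model", "cfg", "run-2")
idx_a = demo_sample(seed_a, N=1000, n=5)
idx_a_again = demo_sample(seed_a, N=1000, n=5)

print("different nonce, different seed:", seed_a != seed_b)
print("same seed, same commitment:", demo_commit(idx_a) == demo_commit(idx_a_again))
```

Because the seed depends on the run nonce and the public beacon, an evaluator who reruns until a favorable sample appears would have to change a committed input, which a verifier can detect by recomputing the seed and the index commitment.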


In [None]:

def sha256_bytes(x: bytes) -> str:
    return hashlib.sha256(x).hexdigest()

def sha256_json(obj: Any) -> str:
    return sha256_bytes(json.dumps(obj, sort_keys=True, default=str).encode("utf-8"))

def derive_seed(beacon: str, dataset_digest: str, model_digest: str, config_digest: str, run_nonce: str) -> int:
    msg = f"{beacon}|{dataset_digest}|{model_digest}|{config_digest}|{run_nonce}".encode("utf-8")
    return int(sha256_bytes(msg)[:16], 16)  # 64-bit from prefix

def sample_indices(seed: int, N: int, n: int, with_replacement: bool = True) -> np.ndarray:
    rng = np.random.default_rng(seed)
    if with_replacement:
        return rng.integers(0, N, size=n, endpoint=False)
    n = min(n, N)
    return rng.choice(N, size=n, replace=False)

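def indices_commitment(indices: np.ndarray) -> str:
    # Hash a fixed little-endian int64 encoding so the commitment does not depend on the platform dtype.
    return sha256_bytes(np.asarray(indices, dtype="<i8").tobytes())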

def safe_name(s: str) -> str:
    return re.sub(r"[^A-Za-z0-9_.-]+", "_", s)



## 3) Task Dataset Format (unified)
We unify:
- harmonized JSONL **binary** tasks,
- GSM8K **answer** tasks,
- AQUA / ARC **multi-choice** tasks.


In [None]:

@dataclass(frozen=True)
class TaskInstance:
    instance_id: str
    task_type: str      # "binary" | "answer" | "mc"
    prompt: str
    # labels:
    label01: Optional[int] = None
    ground_truth: Optional[str] = None   # answer string (GSM8K)
    mc_answer: Optional[str] = None      # option key like "A"/"B"/"C"/"D" or "E"
    mc_choices: Optional[List[str]] = None
    meta: Optional[Dict[str, Any]] = None

@dataclass
class TaskDataset:
    name: str
    instances: List[TaskInstance]



## 4) Load Your Harmonized JSONL (ALL tasks)
This loader matches your v10/v12 schema:
- required: `input`, `label`, `task`
- optional: `tier`, `complexity_family`, `case_type`, `enum_regime`, `n_enum`, `group_id`, `within`, `is_base`, …

We also support an optional `__sgu` slice.
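
To make the schema and the SGU slice concrete, here is a minimal sketch on three made-up records. The records and the `demo_df` name are hypothetical. The patterns are a subset of `SGU_COMPLEXITY_FAMILY_PATTERNS` from the config cell, and the matching is a case-insensitive substring test on `complexity_family`.

```python
import pandas as pd

# Hypothetical records; only `input`, `label`, and `task` are required by the loader.
records = [
    {"input": "q1", "label": 1, "task": "demo_task", "complexity_family": "SGU"},
    {"input": "q2", "label": 0, "task": "demo_task", "complexity_family": "np_hard"},
    {"input": "q3", "label": 1, "task": "demo_task"},  # optional field missing
]
demo_df = pd.DataFrame(records)

patterns = ["strongly_generically_undecidable", "sgu", "undecidable"]
family = demo_df["complexity_family"].fillna("").astype(str).str.lower()

mask = pd.Series(False, index=demo_df.index)
for pat in patterns:
    # regex=False treats each pattern as a literal substring
    mask |= family.str.contains(pat, regex=False)

sgu_slice = demo_df[mask]
print(sgu_slice["input"].tolist())  # ['q1']
```

Since this is a substring test, any family name that merely contains `sgu` or `undecidable` lands in the slice, so it is worth checking the distinct `complexity_family` values in your JSONL before relying on it.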


In [None]:

def load_harmonized_jsonl_df(path: str) -> pd.DataFrame:
    if not os.path.exists(path):
        raise FileNotFoundError(f"Missing harmonized JSONL: {path}")
    rows = []
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if line:
                rows.append(json.loads(line))
    df = pd.DataFrame(rows)

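    # normalize label to int {0,1}; missing or unparseable labels fall back to 0
    def _to01(x):
        if isinstance(x, (bool, np.bool_)):
            return int(x)
        if isinstance(x, (float, np.floating)) and np.isnan(x):
            return 0  # int(NaN) would raise; use the same fallback as unparseable strings
        if isinstance(x, (int, np.integer, float, np.floating)):
            return int(x)
        s = str(x).strip().lower()
        if s in ["1","true","yes","y","t"]:
            return 1
        if s in ["0","false","no","n","f"]:
            return 0
        m = re.search(r"[01]", s)
        return int(m.group(0)) if m else 0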

    if "label" in df.columns:
        df["label"] = df["label"].apply(_to01).astype(int)

    # fill optional cols
    for col in ["complexity_family", "case_type", "enum_regime", "n_enum", "tier", "group_id", "within", "is_base", "source"]:
        if col not in df.columns:
            df[col] = np.nan

    # required cols check
    for req in ["input", "label", "task"]:
        if req not in df.columns:
            raise ValueError(f"Harmonized JSONL missing required field: {req}")

    return df

def build_taskdatasets_from_harmonized(df: pd.DataFrame,
                                      tasks: Optional[List[str]] = None,
                                      include_sgu_slice: bool = True) -> Dict[str, TaskDataset]:
    if tasks is None:
        tasks = sorted(df["task"].dropna().unique().tolist())

    out: Dict[str, TaskDataset] = {}

    def _make(inst_df: pd.DataFrame, name: str):
        instances = []
        for i, r in inst_df.reset_index(drop=True).iterrows():
            meta = {k: r[k] for k in inst_df.columns if k not in ["input","label","task"]}
            instances.append(TaskInstance(
                instance_id=f"{name}-{i}",
                task_type="binary",
                prompt=str(r["input"]),
                label01=int(r["label"]),
                meta=meta
            ))
        out[name] = TaskDataset(name=name, instances=instances)

    for t in tasks:
        df_t = df[df["task"] == t].copy()
        if len(df_t) == 0:
            continue
        _make(df_t, t)

        if include_sgu_slice and "complexity_family" in df_t.columns:
            cf = df_t["complexity_family"].fillna("").astype(str).str.lower()
            mask = np.zeros(len(df_t), dtype=bool)
            for pat in SGU_COMPLEXITY_FAMILY_PATTERNS:
                mask = mask | cf.str.contains(pat.lower(), na=False).to_numpy()
            if mask.any():
                _make(df_t.loc[mask].copy(), t + SGU_SUFFIX)

    return out

harmonized_tasks: Dict[str, TaskDataset] = {}
if RUN_HARMONIZED_JSONL:
    df_all = load_harmonized_jsonl_df(HARMONIZED_PATH)
    harmonized_tasks = build_taskdatasets_from_harmonized(df_all, tasks=TASKS, include_sgu_slice=INCLUDE_SGU_SLICE)
    print("Harmonized tasks loaded:", len(harmonized_tasks))
    print("Example tasks:", list(harmonized_tasks.keys())[:10])



## 5) Load Public Benchmarks (HF)
We load:
- GSM8K (answer-mode)
- AQUA-RAT (multi-choice)
- ARC-Challenge (multi-choice)

These are *in addition to* your harmonized binary versions.


In [None]:

from typing import Iterable

def load_gsm8k_answer(split: str = "test") -> TaskDataset:
    from datasets import load_dataset
    ds = load_dataset("gsm8k", "main", split=split)
    inst = []
    for i, row in enumerate(ds):
        inst.append(TaskInstance(
            instance_id=f"gsm8k-{split}-{i}",
            task_type="answer",
            prompt=row["question"].strip(),
            ground_truth=row["answer"],
            meta={"source": "gsm8k", "split": split}
        ))
    return TaskDataset(name=f"gsm8k_answer_{split}", instances=inst)

def load_aqua_mc(split: str = "test") -> TaskDataset:
    from datasets import load_dataset
    ds = load_dataset("aqua_rat", split=split)
    inst = []
    for i, row in enumerate(ds):
        q = row.get("question", "").strip()
        options = row.get("options", [])
        correct = row.get("correct", None)  # typically "A"/"B"/"C"/"D"/"E"
        prompt = "Choose the correct option (A/B/C/D/E). Return only the letter.\n\n"
        prompt += f"Question:\n{q}\n\nOptions:\n"
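        for j, opt in enumerate(options):
            letter = chr(ord("A") + j)
            # AQUA-RAT options usually already start with "A)", "B)", ...; strip that to avoid "A) A)..."
            text = re.sub(r"^\s*[A-E]\s*\)\s*", "", str(opt))
            prompt += f"{letter}) {text}\n"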
        prompt += "\nAnswer:"
        inst.append(TaskInstance(
            instance_id=f"aqua-{split}-{i}",
            task_type="mc",
            prompt=prompt,
            mc_answer=str(correct).strip() if correct is not None else None,
            mc_choices=options,
            meta={"source": "aqua_rat", "split": split}
        ))
    return TaskDataset(name=f"aqua_mc_{split}", instances=inst)

def load_arc_challenge_mc(split: str = "test") -> TaskDataset:
    from datasets import load_dataset
    ds = load_dataset("ai2_arc", "ARC-Challenge", split=split)
    inst = []
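    def _letter(k) -> str:
        # Some ARC items label choices "1".."4"; map those to letters so the A-E scorer can match.
        k = str(k).strip()
        return chr(ord("A") + int(k) - 1) if k.isdigit() else k.upper()

    for i, row in enumerate(ds):
        q = row["question"]
        if isinstance(q, dict):
            # Raw ARC JSONL layout: question = {"stem": ..., "choices": [{"label", "text"}, ...]}
            stem = str(q.get("stem", "")).strip()
            pairs = [(c.get("label"), c.get("text")) for c in q.get("choices", [])]
        else:
            # HF `ai2_arc` layout: question is a string; choices = {"label": [...], "text": [...]}
            stem = str(q).strip()
            ch = row.get("choices") or {}
            pairs = list(zip(ch.get("label", []), ch.get("text", [])))
        correct = row.get("answerKey", None)
        prompt = "Choose the correct option (A/B/C/D). Return only the letter.\n\n"
        prompt += f"Question:\n{stem}\n\nOptions:\n"
        for lab, txt in pairs:
            prompt += f"{_letter(lab)}) {txt}\n"
        prompt += "\nAnswer:"
        inst.append(TaskInstance(
            instance_id=f"arc-challenge-{split}-{i}",
            task_type="mc",
            prompt=prompt,
            mc_answer=_letter(correct) if correct is not None else None,
            mc_choices=[txt for _, txt in pairs],
            meta={"source": "ai2_arc", "subset": "ARC-Challenge", "split": split}
        ))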
    return TaskDataset(name=f"arc_challenge_mc_{split}", instances=inst)

public_datasets: Dict[str, TaskDataset] = {}

try:
    if RUN_GSM8K_ANSWER:
        public_datasets["gsm8k_answer_test"] = load_gsm8k_answer("test")
        print("Loaded gsm8k_answer_test:", len(public_datasets["gsm8k_answer_test"].instances))
except Exception as e:
    print("Failed to load GSM8K (need internet in runtime):", e)

try:
    if RUN_AQUA_MC:
        public_datasets["aqua_mc_test"] = load_aqua_mc("test")
        print("Loaded aqua_mc_test:", len(public_datasets["aqua_mc_test"].instances))
except Exception as e:
    print("Failed to load AQUA-RAT:", e)

try:
    if RUN_ARC_CHALLENGE_MC:
        public_datasets["arc_challenge_mc_test"] = load_arc_challenge_mc("test")
        print("Loaded arc_challenge_mc_test:", len(public_datasets["arc_challenge_mc_test"].instances))
except Exception as e:
    print("Failed to load ARC-Challenge:", e)



## 6) Verifiers (bit + verifier-cost complexity)
Complexity is defined as **verifier computation cost** (step-count proxy).
We bin costs coarsely before releasing.
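
The binning step can be sketched in isolation as follows. The edges mirror `COST_BINS` from the verifier cell and are repeated under a `DEMO_` name only so that the example is self-contained; the step counts in the loop are arbitrary.

```python
import numpy as np

# Same edges as the verifier cell, copied here so the sketch runs on its own.
DEMO_COST_BINS = [0, 40, 80, 140, 220, 320, 500, 800, 1200, 2000]

def demo_cost_bin(steps):
    # np.digitize returns i such that bins[i-1] <= steps < bins[i];
    # subtracting 1 gives a 0-based bin id. Values past the last edge share the top bin.
    return int(np.digitize([steps], DEMO_COST_BINS, right=False)[0] - 1)

for steps in [0, 39, 40, 150, 5000]:
    print(steps, "->", demo_cost_bin(steps))
```

Only the bin id is released, so two outputs whose step counts differ by a few units are indistinguishable to the client unless they straddle a bin edge.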


In [None]:

# Cost bins: coarse disclosure to reduce leakage via timing/covert channels
COST_BINS = [0, 40, 80, 140, 220, 320, 500, 800, 1200, 2000]

def cost_bin(steps: int) -> int:
    return int(np.digitize([steps], COST_BINS, right=False)[0] - 1)

_FINAL_ANS_RE = re.compile(r"####\s*([-+]?[\d\.,]+)")
_LAST_NUM_RE = re.compile(r"([-+]?[\d\.,]+)\s*$")

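def normalize_number_string(s: str) -> str:
    # Drop thousands separators, then trailing zeros and any trailing "."
    # so that "42.", "42.0", and "42" compare equal.
    s = s.strip().replace(",", "")
    if "." in s:
        s = s.rstrip("0").rstrip(".")
    return s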

def extract_gsm8k_final(text: str) -> Optional[str]:
    m = _FINAL_ANS_RE.search(text)
    if m:
        return normalize_number_string(m.group(1))
    m2 = _LAST_NUM_RE.search(text.strip())
    if m2:
        return normalize_number_string(m2.group(1))
    return None

def extract_model_final(text: str) -> Optional[str]:
    m = _FINAL_ANS_RE.search(text)
    if m:
        return normalize_number_string(m.group(1))
    nums = re.findall(r"[-+]?\d[\d,]*\.?\d*", text.replace(",", ""))
    if nums:
        return normalize_number_string(nums[-1])
    return None

def score_and_cost(inst: TaskInstance, model_out: str) -> Tuple[int, int]:
    # deterministic scoring + cost proxy
    out = (model_out or "").strip()
    steps = 0

    if inst.task_type == "binary":
        steps += 5 + len(out)//4
        tok = None
        for ch in out:
            if ch in ["0","1"]:
                tok = int(ch); break
        steps += 10
        if tok is None:
            return 0, steps + 10
        return int(tok == int(inst.label01)), steps

    if inst.task_type == "answer":
        gt = extract_gsm8k_final(inst.ground_truth or "")
        steps += 20 + len(inst.ground_truth or "")//5
        pred = extract_model_final(out)
        steps += 30 + len(out)//5
        if gt is None or pred is None:
            return 0, steps + 20
        return int(gt == pred), steps

    if inst.task_type == "mc":
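        # parse the first standalone A-E letter; word boundaries avoid matching
        # letters inside words (e.g. the "E" in "THE ANSWER IS B")
        steps += 8 + len(out)//4
        m = re.search(r"\b[A-E]\b", out.upper())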
        steps += 12
        if m is None or inst.mc_answer is None:
            return 0, steps + 10
        pred = m.group(0)
        return int(pred == inst.mc_answer.upper().strip()), steps

    raise ValueError(f"Unknown task_type: {inst.task_type}")



## 7) Real HF Models: wrapper + lazy loading
To keep memory manageable, each model is **loaded lazily**, only when a run first needs it.
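
The caching idea can be sketched without any model downloads. The `DemoModelCache` class and the string-returning loader below are hypothetical stand-ins, not the notebook's `get_model`; they only illustrate load-on-first-use with oldest-first eviction.

```python
from collections import OrderedDict

class DemoModelCache:
    """Keep at most `max_items` loaded objects; evict the oldest first."""

    def __init__(self, loader, max_items=1):
        self.loader = loader
        self.max_items = max(1, max_items)  # the entry just loaded is never evicted
        self.items = OrderedDict()
        self.loads = 0

    def get(self, key):
        if key in self.items:
            return self.items[key]          # cache hit: no reload
        self.items[key] = self.loader(key)  # cache miss: load on first use
        self.loads += 1
        while len(self.items) > self.max_items:
            self.items.popitem(last=False)  # drop the oldest entry
        return self.items[key]

cache = DemoModelCache(loader=lambda key: f"loaded:{key}", max_items=1)
cache.get("model-a")
cache.get("model-a")   # hit
cache.get("model-b")   # evicts model-a
cache.get("model-a")   # loaded again
print(cache.loads, list(cache.items))
```

With real models, dropping the cache entry only releases GPU memory once no other reference to the model remains, so callers should avoid holding on to evicted wrappers.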


In [None]:

DRY_RUN = bool(int(os.environ.get("BAP_DRY_RUN", "0")))  # set 1 to test pipeline without model downloads
USE_4BIT = bool(int(os.environ.get("BAP_4BIT", "0")))    # 4-bit loading if supported

DEVICE = "cuda" if (os.environ.get("CUDA_VISIBLE_DEVICES", "") != "" or os.environ.get("BAP_FORCE_CUDA","0")=="1") else "cpu"

print("DRY_RUN:", DRY_RUN, "| USE_4BIT:", USE_4BIT, "| DEVICE:", DEVICE)

# Label set for dataset identification question (Q2)
DATASET_LABELS = [
    "gsm8k",
    "sgu",
    "np_hard",
    "easy",
    "aqua_mc_test",
    "arc_challenge_mc_test",
    "unknown",
]

@dataclass
class HFModelWrapper:
    model_id: str
    tokenizer: Any
    model: Any

    @staticmethod
    def load(model_id: str, adapter_path: Optional[str] = None, use_4bit: bool = False):
        import torch
        from transformers import AutoTokenizer, AutoModelForCausalLM

        tok = AutoTokenizer.from_pretrained(model_id, use_fast=True)
        if tok.pad_token is None:
            tok.pad_token = tok.eos_token

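        kwargs = dict(device_map="auto")
        if use_4bit:
            # Requires the bitsandbytes package; passing load_in_4bit directly is deprecated.
            from transformers import BitsAndBytesConfig
            kwargs["quantization_config"] = BitsAndBytesConfig(load_in_4bit=True)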
        model = AutoModelForCausalLM.from_pretrained(model_id, **kwargs)
        model.eval()
        if adapter_path is not None:
            from peft import PeftModel
            model = PeftModel.from_pretrained(model, adapter_path)
            model.eval()
        return HFModelWrapper(model_id=model_id, tokenizer=tok, model=model)

    def _generate(self, prompt: str, gen_cfg: Dict[str, Any]) -> str:
        import torch
        from transformers import GenerationConfig
        inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True)
        inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
        gcfg = GenerationConfig(**gen_cfg)
        with torch.no_grad():
            out = self.model.generate(**inputs, generation_config=gcfg)
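        # Decode only the newly generated tokens. Matching the decoded prompt as a string
        # prefix is fragile because decoding does not always reproduce the prompt exactly,
        # and a leftover prompt would be scanned by the scorer.
        new_tokens = out[0][inputs["input_ids"].shape[1]:]
        return self.tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
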
    def ask_seen_level(self, question: str) -> str:
        """Q1: Ask whether the model has seen the problem (0/1/2). One-shot prompt; no chat history."""
        prompt = (
            "You will see a problem. Output ONLY a single digit: 0, 1, or 2.\n\n"
            "2 = I have seen this exact problem before.\n"
            "1 = I have not seen this exact problem, but I have seen very similar problems.\n"
            "0 = I have never seen this problem and have not seen similar ones.\n\n"
            "Do NOT output any other text.\n\n"
            f"Problem:\n{question.strip()}\n"
        )
        return self._generate(prompt, GEN_CFG_SEEN)

    def ask_dataset_label(self, question: str) -> str:
        """Q2: Predict which dataset/complexity bucket the problem belongs to. One-shot prompt."""
        labels = "\n".join(DATASET_LABELS)
        prompt = (
            "You will see a problem. Output ONLY one label from the list below, EXACTLY as written:\n\n"
            f"{labels}\n\n"
            "Do NOT output any other text.\n\n"
            f"Problem:\n{question.strip()}\n"
        )
        return self._generate(prompt, GEN_CFG_LABEL)
    def generate(self, inst: TaskInstance) -> str:
        if inst.task_type == "binary":
            prompt = (
                "Answer the following decision problem.\n"
                "Return ONLY a single token:\n"
                "1 if YES, 0 if NO.\n\n"
                f"Problem:\n{inst.prompt}\n\nReturn only a single token (0 or 1):"
            )
            return self._generate(prompt, GEN_CFG_BINARY)
        if inst.task_type == "answer":
            prompt = (
                "Solve the following math problem. Return ONLY the final numeric answer.\n\n"
                f"Problem:\n{inst.prompt}\n\nFinal answer:"
            )
            return self._generate(prompt, GEN_CFG_ANSWER)
        if inst.task_type == "mc":
            # prompt already contains instruction + choices
            return self._generate(inst.prompt, GEN_CFG_MC)
        raise ValueError(inst.task_type)

def digest_model_public(model_id: str) -> str:
    return sha256_json({"model_id": model_id})


def _sha256_file(path: str) -> str:
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(1024 * 1024), b""):
            h.update(chunk)
    return h.hexdigest()

def digest_model_state(model_id: str, adapter_path: Optional[str] = None) -> str:
    """Digest for base model + optional LoRA adapter directory (if provided)."""
    if adapter_path is None:
        return digest_model_public(model_id)
    adapter_path = os.path.abspath(adapter_path)
    cfg = os.path.join(adapter_path, "adapter_config.json")
    # peft saves weights as adapter_model.safetensors or adapter_model.bin
    w1 = os.path.join(adapter_path, "adapter_model.safetensors")
    w2 = os.path.join(adapter_path, "adapter_model.bin")
    weights = w1 if os.path.exists(w1) else (w2 if os.path.exists(w2) else None)
    payload = {"model_id": model_id, "adapter_path": adapter_path}
    if os.path.exists(cfg):
        payload["adapter_config_sha256"] = _sha256_file(cfg)
    if weights is not None and os.path.exists(weights):
        payload["adapter_weights_sha256"] = _sha256_file(weights)
    return sha256_json(payload)

# Lazy cache: keep at most this many loaded models. Values below 1 behave like 1,
# because the model that was just loaded is never evicted.
MAX_CACHED_MODELS = int(os.environ.get("BAP_MAX_CACHED_MODELS", "1"))
MODEL_CACHE: Dict[str, HFModelWrapper] = {}  # key: f"{model_id}||{adapter_path or 'base'}"


def get_model(model_id: str, adapter_path: Optional[str] = None) -> HFModelWrapper:
    """Lazy-load base model, and optionally attach a LoRA adapter (PEFT). Cached per (model_id, adapter_path)."""
    if DRY_RUN:
        return None  # type: ignore
    key = f"{model_id}||{os.path.abspath(adapter_path) if adapter_path is not None else 'base'}"
    if key in MODEL_CACHE:
        return MODEL_CACHE[key]
    print("[LOAD MODEL]", model_id, ("+LoRA" if adapter_path else ""))
    mw = HFModelWrapper.load(model_id, adapter_path=adapter_path, use_4bit=USE_4BIT)
    MODEL_CACHE[key] = mw
    # evict if needed
    if len(MODEL_CACHE) > MAX_CACHED_MODELS:
        k0 = next(iter(MODEL_CACHE.keys()))
        if k0 != key:
            del MODEL_CACHE[k0]
    return mw



## 8) Evidence Package + Mock Attestation
We bind:
- code digest
- dataset digest
- model digest
- config digest
- derived seed + index commitment
- output hash over `(bits, cost_bins)`
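
As a minimal sketch of this binding, the example below signs a small evidence record with HMAC and shows that changing any bound field invalidates the signature. The key, the `demo_*` helpers, and the field names in `evidence` are assumptions chosen for illustration; the notebook's actual evidence dictionary may use different keys.

```python
import hashlib
import hmac
import json

DEMO_KEY = b"demo-key-not-for-production"  # placeholder key for this sketch only

def demo_sign(message):
    # Canonical JSON (sorted keys) so signer and verifier hash identical bytes.
    payload = json.dumps(message, sort_keys=True).encode("utf-8")
    return hmac.new(DEMO_KEY, payload, hashlib.sha256).hexdigest()

def demo_verify(message, signature):
    # Constant-time comparison of the recomputed and the supplied signature.
    return hmac.compare_digest(demo_sign(message), signature)

bits = [1, 0, 1, 1]
cost_bins = [0, 1, 0, 2]
evidence = {  # hypothetical field names
    "code_digest": "placeholder-code",
    "dataset_digest": "placeholder-dataset",
    "model_digest": "placeholder-model",
    "config_digest": "placeholder-config",
    "seed": 12345,
    "index_commitment": "placeholder-commitment",
    "output_hash": hashlib.sha256(
        json.dumps({"bits": bits, "cost_bins": cost_bins}, sort_keys=True).encode("utf-8")
    ).hexdigest(),
}
sig = demo_sign(evidence)

tampered = dict(evidence, seed=99999)  # e.g. a cherry-picked seed
print(demo_verify(evidence, sig))   # True
print(demo_verify(tampered, sig))   # False
```

HMAC only proves that whoever holds the shared key produced the record, which is why it serves as a mock here; a deployed setup would replace it with a signature rooted in TEE remote attestation.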


In [None]:

CODE_ID = "eval_v_full"

def digest_code(code_id: str) -> str:
    return sha256_json({"code_id": code_id})

def digest_config(task_type: str, n: int, with_replacement: bool, gen_cfg: Dict[str, Any]) -> str:
    return sha256_json({"task_type": task_type, "n": n, "with_replacement": with_replacement, "gen_cfg": gen_cfg})

def digest_dataset(dataset: TaskDataset, include_text: bool = False) -> str:
    if include_text:
        payload = [(x.instance_id, x.task_type, x.prompt, x.label01, x.ground_truth, x.mc_answer, x.meta) for x in dataset.instances]
    else:
        payload = []
        for x in dataset.instances:
            gt_hash = sha256_json({"gt": x.ground_truth}) if x.ground_truth is not None else None
            prompt_hash = sha256_json({"prompt": x.prompt})  # optional; you can remove to reduce leakage
            payload.append((x.instance_id, x.task_type, x.label01, gt_hash, x.mc_answer, prompt_hash, x.meta))
    return sha256_json({"name": dataset.name, "payload": payload})

def hash_outputs(bits: np.ndarray, cost_bins: np.ndarray) -> str:
    return sha256_json({"bits": bits.astype(int).tolist(), "cost_bins": cost_bins.astype(int).tolist()})

# Mock attestation via HMAC
ATT_KEY = os.environ.get("BAP_ATT_KEY", "mock-attestation-key-change-me").encode("utf-8")

def sign_hmac(message: Dict[str, Any]) -> str:
    msg = json.dumps(message, sort_keys=True).encode("utf-8")
    return hmac.new(ATT_KEY, msg, hashlib.sha256).hexdigest()

def verify_hmac(message: Dict[str, Any], sig: str) -> bool:
    return hmac.compare_digest(sign_hmac(message), sig)



## 9) BAP Server + Verifier Client (bit-only interface)


In [None]:

@dataclass
class EvalPackage:
    bits: np.ndarray
    cost_bins: np.ndarray
    evidence: Dict[str, Any]
    signature: str

@dataclass
class BAPServer:
    code_id: str

    def run(self, dataset: TaskDataset, model_id: str, run_nonce: str,
            adapter_path: Optional[str] = None, seed_model_id: Optional[str] = None,
            n: int = N_EVAL, with_replacement: bool = WITH_REPLACEMENT) -> EvalPackage:

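        if len(dataset.instances) == 0:
            raise ValueError(f"Dataset {dataset.name!r} has no instances to evaluate.")

        # Resolve the generation config from the first instance's task_type. This assumes a
        # single-type dataset; for a dataset that mixes task types, only the first instance's
        # type (and its gen cfg) enters the config digest.
        task_type = dataset.instances[0].task_type
        gen_cfg = GEN_CFG_BINARY if task_type=="binary" else (GEN_CFG_ANSWER if task_type=="answer" else GEN_CFG_MC)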

        cfg_digest = digest_config(task_type, n, with_replacement, gen_cfg)
        ds_digest = digest_dataset(dataset, include_text=False)
        m_digest = digest_model_state(model_id, adapter_path)
        m_digest_seed = digest_model_public(seed_model_id or model_id)

        seed = derive_seed(PUBLIC_BEACON, ds_digest, m_digest_seed, cfg_digest, run_nonce)
        idx = sample_indices(seed, len(dataset.instances), n, with_replacement)
        idx_commit = indices_commitment(idx)

        bits = np.zeros(n, dtype=int)
        cb = np.zeros(n, dtype=int)

        seen_levels = np.full(n, -1, dtype=int)
        dataset_label_preds = ["unknown"] * n
        mw = get_model(model_id, adapter_path=adapter_path)

        for t, j in enumerate(tqdm(
              idx,
              desc=f"{dataset.name} | {model_id}",
              unit="q",
              dynamic_ncols=True,
              leave=False,
              disable=not TQDM_ENABLED
          )):
            inst = dataset.instances[int(j)]
            if DRY_RUN:
                out = "0"
                seen_out = "0"
                label_out = "unknown"
            else:
                # Q1/Q2 are asked in separate one-shot prompts (fresh context each call)
                seen_out = mw.ask_seen_level(inst.prompt)
                label_out = mw.ask_dataset_label(inst.prompt)
                out = mw.generate(inst)
            # Parse and record Q1/Q2 outputs (robust to occasional extra tokens)
            m_seen = re.search(r"[0-2]", str(seen_out))
            seen_levels[t] = int(m_seen.group(0)) if m_seen else -1

            lo = str(label_out).strip().lower()
            m_lab = None
            for lab in DATASET_LABELS:
                if re.search(rf"\b{re.escape(lab)}\b", lo):
                    m_lab = lab
                    break
            dataset_label_preds[t] = m_lab if m_lab is not None else "unknown"

            b, steps = score_and_cost(inst, out)
            bits[t] = b
            cb[t] = cost_bin(int(steps))

        evidence = {
            "code_digest": digest_code(self.code_id),
            "dataset_digest": ds_digest,
            "model_digest": m_digest,
            "seed_model_digest": m_digest_seed,
            "adapter_path": os.path.abspath(adapter_path) if adapter_path is not None else None,
            "config_digest": cfg_digest,
            "public_beacon": PUBLIC_BEACON,
            "run_nonce": run_nonce,
            "seed": int(seed),
            "index_commitment": idx_commit,
            "output_hash": hash_outputs(bits, cb),
            "seen_levels": seen_levels.astype(int).tolist(),
            "dataset_label_preds": dataset_label_preds,
        }
        sig = sign_hmac(evidence)
        return EvalPackage(bits=bits, cost_bins=cb, evidence=evidence, signature=sig)

@dataclass
class VerifierClient:
    trusted_code_digests: set
    trusted_dataset_digests: set
    trusted_model_digests: set

    def verify(self, pkg: EvalPackage) -> bool:
        if not verify_hmac(pkg.evidence, pkg.signature):
            return False
        if pkg.evidence["code_digest"] not in self.trusted_code_digests:
            return False
        if pkg.evidence["dataset_digest"] not in self.trusted_dataset_digests:
            return False
        if pkg.evidence["model_digest"] not in self.trusted_model_digests:
            return False
        if pkg.evidence["output_hash"] != hash_outputs(pkg.bits, pkg.cost_bins):
            return False
        return True

server = BAPServer(code_id=CODE_ID)

# Build dataset registry: (harmonized tasks) + (public HF datasets)
DATASETS: Dict[str, TaskDataset] = {}
DATASETS.update(harmonized_tasks)
DATASETS.update(public_datasets)

trusted_code = {digest_code(CODE_ID)}
trusted_models = {digest_model_public(m) for m in HF_MODELS}
trusted_datasets = {digest_dataset(ds, include_text=False) for ds in DATASETS.values()}

verifier = VerifierClient(trusted_code, trusted_datasets, trusted_models)

print("Datasets registered:", len(DATASETS))
print("Trusted datasets:", len(trusted_datasets), "| Trusted models:", len(trusted_models))


In [None]:
# ==========================
# 9b) Build grouped datasets (6 groups) with fixed 300/300 train/eval and save to disk
# ==========================
# Groups (as requested):
#   aqua           = aqua_binary + aqua_mc_test
#   arc_challenge  = arc_challenge_binary
#   gsm8k          = gsm8k_answer_test + gsm8k_binary
#   p_time         = p_graph_connectivity + ptime_arith  (+ group_word_generic_easy, if present)
#   np_time        = random_np_rbh_word_avg + np_subset_sum_avg + np_subset_sum_worst + np_rbh_word_worst_3sat
#   undecidable    = remaining halting-related tasks (from the universe list below)
#
# Output keys are:
#   <group>_train  (300 samples)
#   <group>_test   (300 samples)
# These are persisted as JSONL, and registered back into DATASETS for later LoRA/post-train.

import json, os
import numpy as np
from typing import Dict, List, Tuple

GROUPED_DIR = os.environ.get("BAP_GROUPED_DIR", os.path.join(BASE_DIR, "bap_grouped_splits_v1"))
os.makedirs(GROUPED_DIR, exist_ok=True)

N_GROUP_TRAIN = int(os.environ.get("BAP_GROUP_TRAIN_N", "300"))
N_GROUP_TEST  = int(os.environ.get("BAP_GROUP_TEST_N", "300"))
GROUP_BASE_SEED = int(os.environ.get("BAP_GROUP_BASE_SEED", "20260112"))

# Restrict grouping to the canonical set used in your full-coverage runs.
# This prevents accidentally pulling in extra tasks from the harmonized JSONL.
UNIVERSE_KEYS = [
    # AQuA / ARC / GSM8K
    "aqua_binary", "aqua_mc_test",
    "arc_challenge_binary",
    "gsm8k_answer_test", "gsm8k_binary",
    # P-time
    "p_graph_connectivity", "ptime_arith",
    # (optional easy P-time-ish task from harmonized JSONL)
    "group_word_generic_easy",
    # NP-time (NP-hard-ish)
    "random_np_rbh_word_avg", "np_subset_sum_avg", "np_subset_sum_worst", "np_rbh_word_worst_3sat",
    # Halting/undecidable family (SGU/TM keys from your runs)
    "sgu_collatz_aligned", "sgu_collatz_aligned__sgu",
    "sgu_index_empty_language", "sgu_index_empty_language__sgu",
    "sgu_index_total_halt", "sgu_index_total_halt__sgu",
    "sgu_semigroup_wp_amp", "sgu_semigroup_wp_amp__sgu",
    "tm_generic_halt", "tm_hard_halt",
]

# Keep only keys that exist in current DATASETS registry.
UNIVERSE_KEYS = [k for k in UNIVERSE_KEYS if k in DATASETS]
print("Grouping universe keys:", len(UNIVERSE_KEYS))

fixed_groups: Dict[str, List[str]] = {
    "aqua": ["aqua_binary", "aqua_mc_test"],
    "arc_challenge": ["arc_challenge_binary"],
    "gsm8k": ["gsm8k_answer_test", "gsm8k_binary"],
    "p_time": ["p_graph_connectivity", "ptime_arith"],
    "np_time": ["random_np_rbh_word_avg", "np_subset_sum_avg", "np_subset_sum_worst", "np_rbh_word_worst_3sat"],
}

# Put group_word_generic_easy into p_time if present (keeps 6 groups total and avoids dropping it).
# UNIVERSE_KEYS was already filtered to keys present in DATASETS above.
if "group_word_generic_easy" in UNIVERSE_KEYS and "group_word_generic_easy" not in fixed_groups["p_time"]:
    fixed_groups["p_time"].append("group_word_generic_easy")

used = set(k for ks in fixed_groups.values() for k in ks if k in UNIVERSE_KEYS)
undecidable_keys = [k for k in UNIVERSE_KEYS if k not in used]
fixed_groups["undecidable"] = undecidable_keys

print("=== Group plan ===")
for g, ks in fixed_groups.items():
    ks2 = [k for k in ks if k in DATASETS]
    print(f"[{g}] {len(ks2)} sources:", ks2)

def _flatten_instances(dataset_keys: List[str]) -> List[TaskInstance]:
    inst: List[TaskInstance] = []
    for k in dataset_keys:
        if k not in DATASETS:
            continue
        inst.extend(DATASETS[k].instances)
    return inst

def _sample_disjoint(instances: List[TaskInstance], n_train: int, n_eval: int, seed: int) -> Tuple[List[TaskInstance], List[TaskInstance]]:
    rng = np.random.default_rng(seed)
    N = len(instances)
    if N == 0:
        raise ValueError("No instances to sample from.")
    need = n_train + n_eval
    perm = rng.permutation(N)
    if N >= need:
        tr_idx = perm[:n_train]
        ev_idx = perm[n_train:need]
        train = [instances[int(i)] for i in tr_idx]
        evals = [instances[int(i)] for i in ev_idx]
        return train, evals

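    print(f"[WARN] Only {N} instances available < {need}. Sampling with replacement to reach "
          f"{n_train}/{n_eval}; train and eval may overlap.")
    # Fill train first without replacement, as far as the pool allows.
    tr_take = min(n_train, N)
    train = [instances[int(i)] for i in perm[:tr_take]]
    # Whatever is left over goes to eval.
    remaining = [instances[int(i)] for i in perm[tr_take:]]
    evals = remaining[:min(len(remaining), n_eval)]
    # Top up by sampling indices from the full pool rather than the objects themselves
    # (NumPy would first coerce the instance list into an array). Extras can duplicate
    # train items, so the split is no longer strictly disjoint in this branch.
    if len(evals) < n_eval:
        extra_idx = rng.choice(N, size=(n_eval - len(evals)), replace=True)
        evals = evals + [instances[int(i)] for i in extra_idx]
    if len(train) < n_train:
        extra_idx = rng.choice(N, size=(n_train - len(train)), replace=True)
        train = train + [instances[int(i)] for i in extra_idx]
    return train, evals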

def _inst_to_dict(x: TaskInstance) -> dict:
    return {
        "instance_id": x.instance_id,
        "task_type": x.task_type,
        "prompt": x.prompt,
        "label01": x.label01,
        "ground_truth": x.ground_truth,
        "mc_answer": x.mc_answer,
        "mc_choices": x.mc_choices,
        "meta": x.meta,
    }

def _save_jsonl(path: str, instances: List[TaskInstance]):
    with open(path, "w", encoding="utf-8") as f:
        for inst in instances:
            f.write(json.dumps(_inst_to_dict(inst), ensure_ascii=False) + "\n")

def _make_dataset(name: str, instances: List[TaskInstance]) -> TaskDataset:
    return TaskDataset(name=name, instances=instances)

manifest = {
    "version": "bap_grouped_splits_v1",
    "n_train": N_GROUP_TRAIN,
    "n_test": N_GROUP_TEST,
    "base_seed": GROUP_BASE_SEED,
    "universe_keys": UNIVERSE_KEYS,
    "groups": {}
}

for gi, (gname, keys) in enumerate(fixed_groups.items()):
    keys = [k for k in keys if k in DATASETS]
    merged = _flatten_instances(keys)
    seed = GROUP_BASE_SEED + gi * 1000
    train_inst, test_inst = _sample_disjoint(merged, N_GROUP_TRAIN, N_GROUP_TEST, seed=seed)

    train_key = f"{gname}_train"
    test_key  = f"{gname}_test"  # important: _make_train_eval_300() detects *_train/*_test
    train_path = os.path.join(GROUPED_DIR, f"{train_key}.jsonl")
    test_path  = os.path.join(GROUPED_DIR, f"{test_key}.jsonl")
    _save_jsonl(train_path, train_inst)
    _save_jsonl(test_path, test_inst)

    DATASETS[train_key] = _make_dataset(train_key, train_inst)
    DATASETS[test_key]  = _make_dataset(test_key, test_inst)

    manifest["groups"][gname] = {
        "sources": keys,
        "n_merged": len(merged),
        "seed": seed,
        "train_key": train_key,
        "test_key": test_key,
        "train_path": train_path,
        "test_path": test_path,
    }
    print(f"[SAVED] {gname}: merged={len(merged)} train={len(train_inst)} test={len(test_inst)}")

manifest_path = os.path.join(GROUPED_DIR, "manifest.json")
with open(manifest_path, "w", encoding="utf-8") as f:
    json.dump(manifest, f, indent=2, ensure_ascii=False)

GROUP_TEST_KEYS = [f"{g}_test" for g in fixed_groups.keys()]
GROUP_TRAIN_KEYS = [f"{g}_train" for g in fixed_groups.keys()]
print("Grouped dataset keys ready.")
print("GROUP_TEST_KEYS:", GROUP_TEST_KEYS)
print("manifest:", manifest_path)
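
Since the sampler falls back to replacement when a group has too few instances, it can be worth checking the saved JSONL files for duplicates and train/test overlap. The following is a minimal sketch, assuming `instance_id` uniquely identifies an instance within a group (the source does not guarantee this). `load_instance_ids`, `split_report`, and the toy file names are hypothetical; in practice the paths would come from `manifest.json`.

```python
import json
import os
import tempfile

def load_instance_ids(path: str) -> list:
    """Read one JSON object per line and collect its instance_id field."""
    ids = []
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if line:
                ids.append(json.loads(line)["instance_id"])
    return ids

def split_report(train_path: str, test_path: str) -> dict:
    """Count within-split duplicates and train/test overlap by instance_id."""
    train_ids = load_instance_ids(train_path)
    test_ids = load_instance_ids(test_path)
    return {
        "n_train": len(train_ids),
        "n_test": len(test_ids),
        "dup_train": len(train_ids) - len(set(train_ids)),
        "dup_test": len(test_ids) - len(set(test_ids)),
        "overlap": len(set(train_ids) & set(test_ids)),
    }

# Toy files standing in for <group>_train.jsonl and <group>_test.jsonl.
with tempfile.TemporaryDirectory() as tmp:
    train_path = os.path.join(tmp, "toy_train.jsonl")
    test_path = os.path.join(tmp, "toy_test.jsonl")
    with open(train_path, "w", encoding="utf-8") as f:
        for i in ["a", "b", "c"]:
            f.write(json.dumps({"instance_id": i}) + "\n")
    with open(test_path, "w", encoding="utf-8") as f:
        for i in ["c", "d", "d"]:
            f.write(json.dumps({"instance_id": i}) + "\n")
    report = split_report(train_path, test_path)

print(report)
```

A non-zero `overlap` or `dup_*` count for a real group indicates that the replacement fallback was used, which matters when the train split later feeds LoRA contamination runs.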



## 10) Batch Runner: evaluate **all models × all datasets**
This runner connects everything above: it evaluates each (dataset, model, seed) combination, verifies the returned package, and stores these artifacts per run:
- `bits`, `cost_bins`, `evidence`, `signature`

We store each run as JSON in `RESULTS_DIR`.
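
A stored run can be inspected without loading any model. The sketch below reads one run file with the field names listed above and reports accuracy and the cost-bin histogram. `summarize_run` and the toy file contents are hypothetical and only illustrate the layout; real files hold full evidence dictionaries and real signatures.

```python
import json
import os
import tempfile
from collections import Counter

def summarize_run(path: str) -> dict:
    """Load one stored run and compute simple summary statistics."""
    with open(path, "r", encoding="utf-8") as f:
        saved = json.load(f)
    bits = [int(b) for b in saved.get("bits", [])]
    cost_bins = [int(c) for c in saved.get("cost_bins", [])]
    # Guard against empty runs instead of dividing by zero.
    acc = sum(bits) / len(bits) if bits else float("nan")
    return {
        "n": len(bits),
        "acc": acc,
        "cost_bin_counts": dict(sorted(Counter(cost_bins).items())),
        "signature_type": saved.get("signature_type"),
        "run_nonce": (saved.get("evidence") or {}).get("run_nonce"),
    }

# Toy run file (assumed values) mirroring the stored field names.
toy_run = {
    "bits": [1, 0, 1, 1],
    "cost_bins": [0, 2, 2, 1],
    "evidence": {"run_nonce": "toy|model|seed0"},
    "signature": "deadbeef",
    "signature_type": "hmac",
}
with tempfile.TemporaryDirectory() as tmp:
    run_path = os.path.join(tmp, "toy__model__0.json")
    with open(run_path, "w", encoding="utf-8") as f:
        json.dump(toy_run, f, indent=2, sort_keys=True)
    summary = summarize_run(run_path)

print(summary)
```

Checking `signature_type` before trusting a file is useful because a results directory may also hold runs saved without verification.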


In [None]:

def save_package(pkg: EvalPackage, path: str):
    out = {
        "bits": pkg.bits.astype(int).tolist(),
        "cost_bins": pkg.cost_bins.astype(int).tolist(),
        "evidence": pkg.evidence,
        "signature": pkg.signature,
        "signature_type": "hmac",
    }
    with open(path, "w", encoding="utf-8") as f:
        json.dump(out, f, indent=2, sort_keys=True)

def run_all(models: List[str], datasets: Dict[str, TaskDataset], seeds: List[int]) -> pd.DataFrame:
    """Evaluate all (dataset × model × seed) combinations with caching + verification.
    Adds a progress bar and avoids NaN/empty-bit pitfalls.
    Cached results are loaded from disk as-is and are not re-verified.
    """
    rows = []
    # Build a verifier that trusts exactly the dataset digests evaluated in this call.
    trusted_datasets_local = set(digest_dataset(_ds, include_text=False) for _ds in datasets.values())
    verifier_local = VerifierClient(trusted_code, trusted_datasets_local, set(trusted_models))

    jobs = []
    for dname, ds in datasets.items():
        if len(getattr(ds, "instances", [])) == 0:
            continue
        for mid in models:
            for s in seeds:
                jobs.append((dname, ds, mid, int(s)))

    pbar = tqdm(jobs, desc="Evaluating (dataset×model×seed)", unit="run",
                dynamic_ncols=True, disable=not TQDM_ENABLED)

    for dname, ds, mid, s in pbar:
        pbar.set_postfix(dataset=dname, model=mid, seed=s)

        run_nonce = f"{dname}|{mid}|seed{s}"
        out_path = os.path.join(RESULTS_DIR, f"{safe_name(dname)}__{safe_name(mid)}__{s}.json")

        if os.path.exists(out_path):
            # cached
            with open(out_path, "r", encoding="utf-8") as f:
                saved = json.load(f)
            bits = np.array(saved.get("bits", []), dtype=int)
            acc = float(bits.mean()) if len(bits) > 0 else float("nan")
            rows.append({"dataset": dname, "model": mid, "seed": s, "acc": acc, "cached": True, "path": out_path})
            continue

        pkg = server.run(ds, mid, run_nonce=run_nonce, n=N_EVAL, with_replacement=WITH_REPLACEMENT)
        ok = verifier_local.verify(pkg)
        if not ok:
            raise RuntimeError(f"Verification failed for {dname} {mid} seed{s}")
        save_package(pkg, out_path)

        acc = float(pkg.bits.mean()) if len(pkg.bits) > 0 else float("nan")
        rows.append({"dataset": dname, "model": mid, "seed": s, "acc": acc, "cached": False, "path": out_path})
        print("[DONE]", dname, mid, "seed", s, "acc", acc)

    return pd.DataFrame(rows)


# WARNING: This can be very expensive if DATASETS contains many harmonized tasks.
# Start small by restricting DATASETS_TO_RUN to a short list of DATASETS keys.
# Example: DATASETS_TO_RUN = {k: DATASETS[k] for k in ["gsm8k_binary","aqua_binary"] if k in DATASETS}
GROUP_TEST_KEYS = [
    "aqua_test",
    "arc_challenge_test",
    "gsm8k_test",
    "p_time_test",
    "np_time_test",
    "undecidable_test",
]
DATASETS_TO_RUN = {k: DATASETS[k] for k in GROUP_TEST_KEYS if k in DATASETS}


df_runs = run_all(HF_MODELS, DATASETS_TO_RUN, EVAL_SEEDS)
display(df_runs.head())
print("Total runs:", len(df_runs))



# =======================
# 10b) Before/After runner with per-dataset LoRA state switching
# =======================

LORA_PER_DATASET = bool(int(os.environ.get("BAP_LORA_PER_DATASET", "0")))  # set 1 to train adapters
LORA_TRAIN_N = int(os.environ.get("BAP_LORA_TRAIN_N", "300"))
LORA_EVAL_N = int(os.environ.get("BAP_LORA_EVAL_N", "300"))
LORA_SPLIT_SEED = int(os.environ.get("BAP_LORA_SPLIT_SEED", "123"))

LORA_MAX_STEPS_PER_DATASET = int(os.environ.get("BAP_LORA_MAX_STEPS", "200"))
LORA_LR_PER_DATASET = float(os.environ.get("BAP_LORA_LR", "2e-4"))
LORA_R_PER_DATASET = int(os.environ.get("BAP_LORA_R", "8"))
LORA_ALPHA_PER_DATASET = int(os.environ.get("BAP_LORA_ALPHA", "16"))
LORA_DROPOUT_PER_DATASET = float(os.environ.get("BAP_LORA_DROPOUT", "0.05"))
LORA_MAX_LEN_PER_DATASET = int(os.environ.get("BAP_LORA_MAX_LEN", "512"))

LORA_ADAPTERS_DIR = os.path.join(RESULTS_DIR, "lora_per_dataset")
os.makedirs(LORA_ADAPTERS_DIR, exist_ok=True)

def _make_train_eval_300(ds_name: str, datasets: Dict[str, TaskDataset], train_n: int, eval_n: int, seed: int):
    """Return (train_ds, eval_ds). Prefer explicit *_train/*_test pairs, else split within ds."""
    # Explicit pair if available
    if ds_name.endswith("_test"):
        cand_train = ds_name[:-5] + "_train"
        if cand_train in datasets:
            train_ds_full = datasets[cand_train]
            eval_ds_full = datasets[ds_name]
            # sample up to train_n / eval_n from each side (no replacement)
            rng = np.random.default_rng(seed)
            tr_idx = rng.choice(len(train_ds_full.instances), size=min(train_n, len(train_ds_full.instances)), replace=False)
            ev_idx = rng.choice(len(eval_ds_full.instances), size=min(eval_n, len(eval_ds_full.instances)), replace=False)
            tr = [train_ds_full.instances[int(i)] for i in tr_idx]
            ev = [eval_ds_full.instances[int(i)] for i in ev_idx]
            return TaskDataset(name=f"{cand_train}_n{len(tr)}", instances=tr), TaskDataset(name=f"{ds_name}_n{len(ev)}", instances=ev)

    # Fallback: split within ds. The eval part is short (or empty) when ds has
    # fewer than train_n + eval_n instances.
    base = datasets[ds_name]
    rng = np.random.default_rng(seed)
    idx = rng.permutation(len(base.instances)).tolist()
    tr = [base.instances[i] for i in idx[:min(train_n, len(idx))]]
    ev = [base.instances[i] for i in idx[min(train_n, len(idx)):min(train_n+eval_n, len(idx))]]
    return TaskDataset(name=f"{ds_name}_train_n{len(tr)}", instances=tr), TaskDataset(name=f"{ds_name}_eval_n{len(ev)}", instances=ev)

def _train_rows_for_inst(inst: TaskInstance) -> Optional[str]:
    """Build a single training text example aligned with HFModelWrapper.generate() prompts."""
    if inst.task_type == "answer":
        gt_raw = inst.ground_truth or ""
        gt = extract_gsm8k_final(gt_raw) or extract_model_final(gt_raw) or gt_raw.strip()
        prompt = (
            "Solve the following math problem. Return ONLY the final numeric answer.\n\n"
            f"Problem:\n{inst.prompt.strip()}\n\nFinal answer:"
        )
        return (prompt + " " + str(gt).strip()).strip()
    if inst.task_type == "mc":
        if inst.mc_answer is None:
            return None
        return (inst.prompt.strip() + " " + inst.mc_answer.strip()).strip()
    if inst.task_type == "binary":
        if inst.label01 is None:
            return None
        prompt = (
            "Read the following problem and decide if the statement is TRUE (1) or FALSE (0).\n\n"
            f"Problem:\n{inst.prompt.strip()}\n\nReturn only a single token (0 or 1):"
        )
        return (prompt + " " + str(int(inst.label01))).strip()
    return None

def build_lora_train_rows(train_ds: TaskDataset, n: int, seed: int) -> List[Dict[str, str]]:
    rng = np.random.default_rng(seed)
    m = min(n, len(train_ds.instances))
    idx = rng.choice(len(train_ds.instances), size=m, replace=False)
    rows = []
    for i in idx:
        txt = _train_rows_for_inst(train_ds.instances[int(i)])
        if txt is not None:
            rows.append({"text": txt})
    return rows

def train_lora_adapter_simple(base_model_id: str, train_rows: List[Dict[str,str]], out_dir: str,
                             max_steps: int, lr: float, r: int, alpha: int, dropout: float, max_len: int = 512):
    """Train and save a PEFT LoRA adapter to out_dir."""
    import torch
    from datasets import Dataset as HFDataset
    from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer, DataCollatorForLanguageModeling
    from peft import LoraConfig, get_peft_model

    os.makedirs(out_dir, exist_ok=True)

    tok = AutoTokenizer.from_pretrained(base_model_id, use_fast=True)
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token

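    kwargs = {"device_map": "auto"}
    if USE_4BIT:
        # Pass quantization settings through BitsAndBytesConfig; the bare
        # load_in_4bit keyword on from_pretrained is deprecated.
        from transformers import BitsAndBytesConfig
        kwargs["quantization_config"] = BitsAndBytesConfig(load_in_4bit=True)

    model = AutoModelForCausalLM.from_pretrained(base_model_id, **kwargs)
    model.train()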

    target_modules = ["q_proj","k_proj","v_proj","o_proj","gate_proj","up_proj","down_proj"]
    peft_cfg = LoraConfig(
        r=r,
        lora_alpha=alpha,
        lora_dropout=dropout,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=target_modules,
    )
    model = get_peft_model(model, peft_cfg)

    hf = HFDataset.from_list(train_rows)

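    def tok_fn(ex):
        # Labels are built by DataCollatorForLanguageModeling(mlm=False) from input_ids
        # (padding masked to -100), so they are not set here. Pre-set, unpadded labels
        # would also break batching if the batch size were raised above 1.
        return tok(ex["text"], truncation=True, max_length=max_len)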

    hf = hf.map(tok_fn, remove_columns=["text"])
    collator = DataCollatorForLanguageModeling(tok, mlm=False)

    args = TrainingArguments(
        output_dir=out_dir,
        per_device_train_batch_size=1,
        gradient_accumulation_steps=8,
        learning_rate=lr,
        max_steps=max_steps,
        logging_steps=max(1, max_steps//10),
        save_steps=max_steps,
        save_total_limit=1,
        report_to=[],
        fp16=torch.cuda.is_available(),
    )

    trainer = Trainer(model=model, args=args, train_dataset=hf, data_collator=collator)
    trainer.train()
    model.save_pretrained(out_dir)
    tok.save_pretrained(out_dir)
    print("[LoRA saved]", out_dir)

def ensure_lora_adapter(model_id: str, train_ds: TaskDataset, adapter_dir: str, seed: int) -> str:
    """Train adapter if not already present; return adapter path."""
    marker = os.path.join(adapter_dir, "adapter_config.json")
    if os.path.exists(marker):
        return adapter_dir
    train_rows = build_lora_train_rows(train_ds, n=LORA_TRAIN_N, seed=seed)
    if len(train_rows) == 0:
        raise RuntimeError(f"No train rows constructed for {train_ds.name}")
    train_lora_adapter_simple(
        base_model_id=model_id,
        train_rows=train_rows,
        out_dir=adapter_dir,
        max_steps=LORA_MAX_STEPS_PER_DATASET,
        lr=LORA_LR_PER_DATASET,
        r=LORA_R_PER_DATASET,
        alpha=LORA_ALPHA_PER_DATASET,
        dropout=LORA_DROPOUT_PER_DATASET,
        max_len=LORA_MAX_LEN_PER_DATASET,
    )
    return adapter_dir

def _state_dir(phase: str, ds_base: str, split: str, model_id: str, state: str) -> str:
    return os.path.join(RESULTS_DIR, "before_after", phase, safe_name(ds_base), split, safe_name(model_id), state)

def run_before_after(models: List[str], datasets: Dict[str, TaskDataset], dataset_keys: List[str], seeds: List[int]) -> pd.DataFrame:
    """Runs:
      - BEFORE (base): train + eval
      - AFTER  (lora): eval only, where LoRA is trained on train split for the same dataset key.

    Adds caching + verification + progress bar. Also fixes:
      - verifier_local is defined per dataset key (so digests match the split datasets)
      - uses verifier_local (not a stray global verifier) when trusting adapter digests
      - guards empty bit arrays when computing accuracy
    """
    rows = []

    # Pre-compute total units for a single global progress bar
    present_keys = [k for k in dataset_keys if k in datasets]
    total_units = len(present_keys) * len(models) * len(seeds) * 3  # before(train)+before(eval)+after(eval)
    pbar = tqdm(total=total_units, desc="Before/After (base vs LoRA)", unit="run",
                dynamic_ncols=True, disable=not TQDM_ENABLED)

    try:
        for ds_key in dataset_keys:
            if ds_key not in datasets:
                print("[SKIP missing dataset]", ds_key)
                continue

            train_ds, eval_ds = _make_train_eval_300(ds_key, datasets, LORA_TRAIN_N, LORA_EVAL_N, LORA_SPLIT_SEED)

            # Verifier must trust *these* split datasets (their digests differ from the original full dataset)
            trusted_datasets_local = {
                digest_dataset(train_ds, include_text=False),
                digest_dataset(eval_ds, include_text=False),
            }
            verifier_local = VerifierClient(trusted_code, trusted_datasets_local, set(trusted_models))

            for mid in models:
                for s in seeds:
                    s = int(s)

                    # ---------- BEFORE: base on train + eval ----------
                    for split_name, ds_obj in [("train", train_ds), ("eval", eval_ds)]:
                        pbar.set_postfix(dataset=ds_key, model=mid, seed=s, phase=f"before/{split_name}")

                        run_nonce = f"{ds_key}|{mid}|base|{split_name}|seed{s}"
                        out_dir = _state_dir("before", ds_key, split_name, mid, "base")
                        os.makedirs(out_dir, exist_ok=True)
                        out_path = os.path.join(out_dir, f"seed{s}.json")

                        if os.path.exists(out_path):
                            with open(out_path, "r", encoding="utf-8") as f:
                                saved = json.load(f)
                            bits = np.array(saved.get("bits", []), dtype=int)
                            acc = float(bits.mean()) if len(bits) > 0 else float("nan")
                            rows.append({
                                "dataset": ds_key, "split": split_name, "phase": "before", "state": "base",
                                "model": mid, "seed": s, "acc": acc, "cached": True, "path": out_path
                            })
                            pbar.update(1)
                        else:
                            pkg = server.run(
                                ds_obj, mid, adapter_path=None, seed_model_id=mid,
                                run_nonce=run_nonce, n=min(N_EVAL, len(ds_obj.instances)), with_replacement=False
                            )
                            ok = verifier_local.verify(pkg)
                            if not ok:
                                raise RuntimeError(f"Verification failed for {ds_key} {mid} before {split_name} seed{s}")
                            save_package(pkg, out_path)

                            acc = float(pkg.bits.mean()) if len(pkg.bits) > 0 else float("nan")
                            rows.append({
                                "dataset": ds_key, "split": split_name, "phase": "before", "state": "base",
                                "model": mid, "seed": s, "acc": acc, "cached": False, "path": out_path
                            })
                            print("[DONE]", ds_key, mid, "before", split_name, "seed", s, "acc", acc)
                            pbar.update(1)

                    # ---------- AFTER: lora on eval only ----------
                    pbar.set_postfix(dataset=ds_key, model=mid, seed=s, phase="after/eval")

                    if LORA_PER_DATASET:
                        adapter_dir = os.path.join(LORA_ADAPTERS_DIR, safe_name(mid), safe_name(ds_key))
                        adapter_path = ensure_lora_adapter(mid, train_ds, adapter_dir, seed=LORA_SPLIT_SEED)
                        # Trust adapter digest for verifier (so verification passes for LoRA-modified model state)
                        verifier_local.trusted_model_digests.add(digest_model_state(mid, adapter_path))
                    else:
                        adapter_path = None

                    run_nonce = f"{ds_key}|{mid}|lora|eval|seed{s}"
                    out_dir = _state_dir("after", ds_key, "eval", mid, "lora" if adapter_path else "no_lora")
                    os.makedirs(out_dir, exist_ok=True)
                    out_path = os.path.join(out_dir, f"seed{s}.json")

                    if os.path.exists(out_path):
                        with open(out_path, "r", encoding="utf-8") as f:
                            saved = json.load(f)
                        bits = np.array(saved.get("bits", []), dtype=int)
                        acc = float(bits.mean()) if len(bits) > 0 else float("nan")
                        rows.append({
                            "dataset": ds_key, "split": "eval", "phase": "after",
                            "state": "lora" if adapter_path else "no_lora",
                            "model": mid, "seed": s, "acc": acc, "cached": True, "path": out_path
                        })
                        pbar.update(1)
                        continue

                    pkg = server.run(
                        eval_ds, mid, adapter_path=adapter_path, seed_model_id=mid,
                        run_nonce=run_nonce, n=min(N_EVAL, len(eval_ds.instances)), with_replacement=False
                    )
                    ok = verifier_local.verify(pkg)
                    if not ok:
                        raise RuntimeError(f"Verification failed for {ds_key} {mid} after eval seed{s}")
                    save_package(pkg, out_path)

                    acc = float(pkg.bits.mean()) if len(pkg.bits) > 0 else float("nan")
                    rows.append({
                        "dataset": ds_key, "split": "eval", "phase": "after",
                        "state": "lora" if adapter_path else "no_lora",
                        "model": mid, "seed": s, "acc": acc, "cached": False, "path": out_path
                    })
                    print("[DONE]", ds_key, mid, "after eval", "seed", s, "acc", acc)
                    pbar.update(1)

        return pd.DataFrame(rows)
    finally:
        try:
            pbar.close()
        except Exception:
            pass


# Example usage:
# DATASET_KEYS = ["gsm8k_answer_test", "aqua_mc_test", "arc_challenge_mc_test"]  # add more keys as desired
# df_ba = run_before_after(HF_MODELS, DATASETS, DATASET_KEYS, EVAL_SEEDS)
# display(df_ba.head())
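
The before/after rows are easiest to read as a per-dataset accuracy gap. The sketch below is one possible aggregation, assuming the column names produced by `run_before_after` (`dataset`, `split`, `phase`, `model`, `seed`, `acc`). `before_after_delta` and the toy rows are hypothetical and are not results from any run.

```python
import pandas as pd

def before_after_delta(df: pd.DataFrame) -> pd.DataFrame:
    """Mean eval accuracy per (dataset, model, phase) plus the after-minus-before gap."""
    ev = df[df["split"] == "eval"]
    mean_acc = ev.groupby(["dataset", "model", "phase"])["acc"].mean().unstack("phase")
    # Force both columns to exist, even if one phase has no rows yet.
    mean_acc = mean_acc.reindex(columns=["before", "after"])
    mean_acc["delta"] = mean_acc["after"] - mean_acc["before"]
    return mean_acc.reset_index()

# Toy rows (made-up accuracies) in the shape returned by the runner.
toy = pd.DataFrame([
    {"dataset": "toy_test", "split": "train", "phase": "before", "model": "m", "seed": 0, "acc": 0.40},
    {"dataset": "toy_test", "split": "eval", "phase": "before", "model": "m", "seed": 0, "acc": 0.50},
    {"dataset": "toy_test", "split": "eval", "phase": "before", "model": "m", "seed": 1, "acc": 0.60},
    {"dataset": "toy_test", "split": "eval", "phase": "after", "model": "m", "seed": 0, "acc": 0.70},
    {"dataset": "toy_test", "split": "eval", "phase": "after", "model": "m", "seed": 1, "acc": 0.80},
])
delta_table = before_after_delta(toy)
print(delta_table)
```

When `BAP_LORA_PER_DATASET` is 0, the "after" rows come from the base model again (state `no_lora`), so the gap then reflects only sampling differences between the two runs rather than any fine-tuning effect.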


In [None]:
# =========================
# Minimal runner WITHOUT verifier
# - drop VerifierClient / trusted_code / digests
# - keep caching, saving JSON, and aggregation
# =========================

import os, time, json
import numpy as np
import pandas as pd
from typing import List, Dict, Optional

# Assumes these already exist in your notebook:
# - RESULTS_DIR: str
# - server: object with .run(ds, model_id, ..., run_nonce=..., n=..., with_replacement=...)
# - TaskDataset / TaskInstance types (only need ds.name and ds.instances)
# - safe_name(name: str) -> str
# - N_EVAL: int
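# - WITH_REPLACEMENT: bool
# - tqdm and TQDM_ENABLED: bool (progress bars)
# - for the before/after runner: _make_train_eval_300, LORA_TRAIN_N, LORA_EVAL_N, LORA_SPLIT_SEED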

def save_package_noverify(pkg, path: str, meta: Optional[Dict] = None):
    """
    Save EvalPackage-like object to JSON without any verifier.
    Expected pkg fields:
      - bits: array-like of {0,1}
      - cost_bins: array-like ints (optional)
      - evidence: any JSON-serializable payload (optional)
    """
    bits = np.asarray(getattr(pkg, "bits", []), dtype=int)
    cost_bins = np.asarray(getattr(pkg, "cost_bins", []), dtype=int)

    out = {
        "bits": bits.tolist(),
        "cost_bins": cost_bins.tolist(),
        "evidence": getattr(pkg, "evidence", None),
        # keep lightweight provenance for debugging/repro
        "signature": {
            "meta": meta or {},
            "saved_at": time.time(),
        },
        "signature_type": "meta_only",
    }
    out_dir = os.path.dirname(path)
    if out_dir:  # os.makedirs("") raises when path has no directory component
        os.makedirs(out_dir, exist_ok=True)
    with open(path, "w", encoding="utf-8") as f:
        json.dump(out, f, indent=2, sort_keys=True)

def _load_cached_run(path: str):
    with open(path, "r", encoding="utf-8") as f:
        saved = json.load(f)
    bits = np.asarray(saved.get("bits", []), dtype=int)
    cb   = np.asarray(saved.get("cost_bins", []), dtype=int)
    acc  = float(bits.mean()) if bits.size > 0 else float("nan")
    return bits, cb, acc, saved

def run_all_noverify(models: List[str], datasets: Dict[str, "TaskDataset"], seeds: List[int]) -> pd.DataFrame:
    """Evaluate each (dataset, model, seed), cache to JSON, and return a tidy DataFrame.
    Adds a progress bar and guards empty bits.
    """
    rows = []

    jobs = []
    for dname, ds in datasets.items():
        if (not hasattr(ds, "instances")) or len(ds.instances) == 0:
            continue
        for mid in models:
            for s in seeds:
                jobs.append((dname, ds, mid, int(s)))

    pbar = tqdm(jobs, desc="Evaluating (no-verify)", unit="run",
                dynamic_ncols=True, disable=not TQDM_ENABLED)

    for dname, ds, mid, s in pbar:
        pbar.set_postfix(dataset=dname, model=mid, seed=s)

        run_nonce = f"{dname}|{mid}|seed{s}"
        out_path = os.path.join(RESULTS_DIR, f"{safe_name(dname)}__{safe_name(mid)}__{int(s)}.json")

        # cached?
        if os.path.exists(out_path):
            _, _, acc, _ = _load_cached_run(out_path)
            rows.append({"dataset": dname, "model": mid, "seed": int(s), "acc": float(acc), "cached": True, "path": out_path})
            continue

        pkg = server.run(ds, mid, run_nonce=run_nonce, n=N_EVAL, with_replacement=WITH_REPLACEMENT)

        bits = np.asarray(getattr(pkg, "bits", []), dtype=int)
        acc = float(bits.mean()) if len(bits) > 0 else float("nan")

        meta = {
            "dataset": dname,
            "model": mid,
            "seed": int(s),
            "run_nonce": run_nonce,
            "n": int(N_EVAL),
            "with_replacement": bool(WITH_REPLACEMENT),
        }
        save_package_noverify(pkg, out_path, meta=meta)

        rows.append({"dataset": dname, "model": mid, "seed": int(s), "acc": acc, "cached": False, "path": out_path})
        print("[DONE]", dname, mid, "seed", int(s), "acc", acc)

    return pd.DataFrame(rows)



# =========================
# Optional: BEFORE/AFTER runner WITHOUT verifier
# - Keeps your LoRA-per-dataset workflow
# - Removes all verifier checks and digest tracking
# =========================

def run_before_after_noverify(models: List[str], datasets: Dict[str, "TaskDataset"], dataset_keys: List[str], seeds: List[int]) -> pd.DataFrame:
    """Before/After runner without verifier.
    Adds caching + progress bar; guards empty bits.
    """
    rows = []

    present_keys = [k for k in dataset_keys if k in datasets]
    total_units = len(present_keys) * len(models) * len(seeds) * 3
    pbar = tqdm(total=total_units, desc="Before/After (no-verify)", unit="run",
                dynamic_ncols=True, disable=not TQDM_ENABLED)

    try:
        for ds_key in dataset_keys:
            if ds_key not in datasets:
                print("[SKIP missing dataset]", ds_key)
                continue

            train_ds, eval_ds = _make_train_eval_300(ds_key, datasets, LORA_TRAIN_N, LORA_EVAL_N, LORA_SPLIT_SEED)

            for mid in models:
                for s in seeds:
                    s = int(s)

                    # BEFORE: base on train + eval
                    for split_name, ds_obj in [("train", train_ds), ("eval", eval_ds)]:
                        pbar.set_postfix(dataset=ds_key, model=mid, seed=s, phase=f"before/{split_name}")

                        run_nonce = f"{ds_key}|{mid}|base|{split_name}|seed{s}"
                        out_dir = _state_dir("before", ds_key, split_name, mid, "base")
                        os.makedirs(out_dir, exist_ok=True)
                        out_path = os.path.join(out_dir, f"seed{s}.json")

                        if os.path.exists(out_path):
                            _, _, acc, _ = _load_cached_run(out_path)
                            rows.append({"dataset": ds_key, "split": split_name, "phase": "before", "state": "base",
                                         "model": mid, "seed": s, "acc": float(acc), "cached": True, "path": out_path})
                            pbar.update(1)
                        else:
                            pkg = server.run(ds_obj, mid, adapter_path=None, seed_model_id=mid,
                                             run_nonce=run_nonce, n=min(N_EVAL, len(ds_obj.instances)), with_replacement=False)
                            bits = np.asarray(getattr(pkg, "bits", []), dtype=int)
                            acc = float(bits.mean()) if len(bits) > 0 else float("nan")
                            save_package_noverify(pkg, out_path, meta={"dataset": ds_key, "model": mid, "seed": s, "run_nonce": run_nonce})
                            rows.append({"dataset": ds_key, "split": split_name, "phase": "before", "state": "base",
                                         "model": mid, "seed": s, "acc": acc, "cached": False, "path": out_path})
                            print("[DONE]", ds_key, mid, "before", split_name, "seed", s, "acc", acc)
                            pbar.update(1)

                    # AFTER: lora eval
                    pbar.set_postfix(dataset=ds_key, model=mid, seed=s, phase="after/eval")

                    if LORA_PER_DATASET:
                        adapter_dir = os.path.join(LORA_ADAPTERS_DIR, safe_name(mid), safe_name(ds_key))
                        adapter_path = ensure_lora_adapter(mid, train_ds, adapter_dir, seed=LORA_SPLIT_SEED)
                    else:
                        adapter_path = None

                    run_nonce = f"{ds_key}|{mid}|lora|eval|seed{s}"
                    out_dir = _state_dir("after", ds_key, "eval", mid, "lora" if adapter_path else "no_lora")
                    os.makedirs(out_dir, exist_ok=True)
                    out_path = os.path.join(out_dir, f"seed{s}.json")

                    if os.path.exists(out_path):
                        _, _, acc, _ = _load_cached_run(out_path)
                        rows.append({"dataset": ds_key, "split": "eval", "phase": "after",
                                     "state": "lora" if adapter_path else "no_lora",
                                     "model": mid, "seed": s, "acc": float(acc), "cached": True, "path": out_path})
                        pbar.update(1)
                        continue

                    pkg = server.run(eval_ds, mid, adapter_path=adapter_path, seed_model_id=mid,
                                     run_nonce=run_nonce, n=min(N_EVAL, len(eval_ds.instances)), with_replacement=False)
                    bits = np.asarray(getattr(pkg, "bits", []), dtype=int)
                    acc = float(bits.mean()) if len(bits) > 0 else float("nan")
                    save_package_noverify(pkg, out_path, meta={"dataset": ds_key, "model": mid, "seed": s, "run_nonce": run_nonce})
                    rows.append({"dataset": ds_key, "split": "eval", "phase": "after",
                                 "state": "lora" if adapter_path else "no_lora",
                                 "model": mid, "seed": s, "acc": acc, "cached": False, "path": out_path})
                    print("[DONE]", ds_key, mid, "after eval", "seed", s, "acc", acc)
                    pbar.update(1)

        return pd.DataFrame(rows)
    finally:
        try:
            pbar.close()
        except Exception:
            pass




## 11) Analysis Helpers (bit-only)
We keep:
- binomial test for contamination evidence (vs p0),
- paired McNemar exact test,
- complexity-tiered curves (pass rate by cost bin) + AUC.
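
As a quick sanity check on the paired test, the exact McNemar p-value depends only on the discordant pairs (items where exactly one of the two runs passes). The sketch below recomputes it with the standard library alone, so it runs without SciPy. The name `mcnemar_exact_stdlib` and the toy bit vectors are illustrative assumptions, not part of the notebook.

```python
from math import comb

def binom_pmf(k: int, n: int, p: float = 0.5) -> float:
    # Probability of exactly k successes in n Bernoulli(p) trials.
    return comb(n, k) * p**k * (1.0 - p)**(n - k)

def mcnemar_exact_stdlib(bits_a, bits_b) -> float:
    # Count discordant pairs: A passes / B fails, and A fails / B passes.
    n10 = sum(1 for a, b in zip(bits_a, bits_b) if a == 1 and b == 0)
    n01 = sum(1 for a, b in zip(bits_a, bits_b) if a == 0 and b == 1)
    n = n10 + n01
    if n == 0:
        return 1.0  # no discordant pairs: no evidence of a difference
    # Two-sided exact p-value under H0: n10 ~ Binomial(n, 0.5).
    lower = sum(binom_pmf(k, n) for k in range(0, n10 + 1))
    upper = sum(binom_pmf(k, n) for k in range(n10, n + 1))
    return min(1.0, 2.0 * min(lower, upper))

# Toy example: 8 items only A solves, 2 items only B solves, 5 both solve.
bits_a = [1] * 8 + [0] * 2 + [1] * 5
bits_b = [0] * 8 + [1] * 2 + [1] * 5
p_value = mcnemar_exact_stdlib(bits_a, bits_b)
print("McNemar exact p-value:", p_value)
```

Concordant items (both pass or both fail) do not affect the result, which is why a large accuracy on both sides can still give a non-significant paired test.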


In [None]:

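def leakage_test_binomial(bits: np.ndarray, p0: float, alpha: float = 0.05) -> Dict[str, Any]:
    # One-sided exact binomial test: is the pass rate significantly above the baseline p0?
    n = len(bits); k = int(np.sum(bits))
    if n == 0:
        # No samples: nothing to test (avoids division by zero and binomtest with n=0).
        return {"n": 0, "k": 0, "p_hat": float("nan"), "p0": p0, "p_value": float("nan"), "reject": False}
    pv = stats.binomtest(k, n, p=p0, alternative="greater").pvalue
    return {"n": n, "k": k, "p_hat": k/n, "p0": p0, "p_value": float(pv), "reject": bool(pv < alpha)}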

def mcnemar_exact(bA: np.ndarray, bB: np.ndarray) -> float:
    n10 = int(np.sum((bA==1) & (bB==0)))
    n01 = int(np.sum((bA==0) & (bB==1)))
    n = n10 + n01
    if n == 0:
        return 1.0
    cdf = stats.binom.cdf(n10, n, 0.5)
    sf = stats.binom.sf(n10-1, n, 0.5)
    return float(min(1.0, 2.0*min(cdf, sf)))

def pass_rate_by_costbin(bits: np.ndarray, cost_bins_arr: np.ndarray, K: int) -> pd.DataFrame:
    rows = []
    for k in range(K):
        mask = cost_bins_arr == k
        nk = int(mask.sum())
        rows.append({"bin": k, "n": nk, "p_hat": float(bits[mask].mean()) if nk>0 else np.nan})
    return pd.DataFrame(rows)

def auc_mass_weighted(df: pd.DataFrame) -> float:
    v = df.dropna()
    w = v["n"].to_numpy().astype(float)
    if w.sum() == 0:
        return float("nan")
    w = w / w.sum()
    return float(np.sum(w * v["p_hat"].to_numpy()))

def load_pkg(path: str) -> EvalPackage:
    with open(path, "r", encoding="utf-8") as f:
        obj = json.load(f)
    bits = np.array(obj["bits"], dtype=int)
    cb = np.array(obj["cost_bins"], dtype=int)
    ev = obj["evidence"]
    sig = obj["signature"]
    return EvalPackage(bits=bits, cost_bins=cb, evidence=ev, signature=sig)


In [None]:
df_runs = run_all_noverify(HF_MODELS, DATASETS_TO_RUN, EVAL_SEEDS)
display(df_runs.head())
print("Total runs:", len(df_runs))


## 12) Quick Discrimination Plot (complexity-tiered) for one dataset across all models


In [None]:

# Pick one dataset to visualize
example_dataset = None
for cand in ["gsm8k_answer_test", "gsm8k_binary", "aqua_binary", "arc_challenge_binary"]:
    if cand in DATASETS:
        example_dataset = cand
        break

if example_dataset is None:
    example_dataset = list(DATASETS.keys())[0] if len(DATASETS)>0 else None

print("Example dataset:", example_dataset)

if example_dataset is not None and len(df_runs) > 0:
    K = len(COST_BINS) - 1
    plt.figure()
    for mid in HF_MODELS:
        sub = df_runs[(df_runs["dataset"]==example_dataset) & (df_runs["model"]==mid)].head(1)
        if len(sub)==0:
            continue
        p = sub.iloc[0]["path"]
        pkg = load_pkg(p)
        dfb = pass_rate_by_costbin(pkg.bits, pkg.cost_bins, K)
        plt.plot(dfb["bin"], dfb["p_hat"], marker="o", label=mid.split("/")[-1])
    plt.xlabel("Verifier-cost bin (coarse)")
    plt.ylabel("Pass rate")
    plt.title(f"Complexity-tiered discrimination: {example_dataset}")
    plt.legend()
    plt.show()



## 13) Optional: Baseline Corpora Loader (for n-gram overlap / contamination baselines)
This just **connects** the corpora you used (train splits).  
You can plug in your existing n-gram/PPL baseline code on top of this.


In [None]:

baseline_corpora: Dict[str, List[str]] = {}

def load_corpus_texts() -> Dict[str, List[str]]:
    from datasets import load_dataset
    corp = {}

    # GSM8K train
    ds = load_dataset("gsm8k", "main", split="train")
    corp["gsm8k_train_questions"] = [x["question"] for x in ds]

    # AQUA train
    ds = load_dataset("aqua_rat", split="train")
    corp["aqua_train_questions"] = [x.get("question","") for x in ds]

    # ARC train
    ds = load_dataset("ai2_arc", "ARC-Challenge", split="train")
    corp["arc_challenge_train_stems"] = [x["question"].get("stem","") for x in ds]

    return corp

if RUN_BASELINE_CORPORA:
    try:
        baseline_corpora = load_corpus_texts()
        for k,v in baseline_corpora.items():
            print(k, "docs:", len(v))
    except Exception as e:
        print("Failed to load baseline corpora (need internet):", e)



## 14) Where your original experiment modules fit
At this point, you have:
- all models connected,
- all datasets connected,
- stable artifacts saved.

Now you can run:
- **Exp-1 auditability** (tamper tests) on any `(dataset,model)` pair by editing/altering `evidence/bits`.
- **Exp-2 contamination** (fast proxy + LoRA) by creating a contaminated model variant and re-running `run_all`.
- **Exp-3 discrimination** is already available (cost-binned curves).
- **Exp-4 overhead** is computed via timing wrappers around `server.run`.

The **coverage attack (S4)** blocks are not ported into this artifact schema here; this notebook focuses on connecting all models and datasets.



# 15) ICML Upgrade: Real LoRA Contamination Attack (GSM8K)
This section closes the **realism gap** by training a **real cheater model** via LoRA:
- Leak a fraction \(\rho\) of a fixed GSM8K test **eval pool** into the training mix.
- Keep training compute constant across \(\rho\) (same steps, same mix size).
- Run a BAP-style **bit-only** evaluation and plot **Reject Rate vs \(\rho\)**.

⚠️ Real training code. Default is OFF. Set `RUN_LORA=True` when ready.
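
The leakage design in the bullets above can be sketched without any model: leaked items are taken as a prefix of a fixed eval pool (so a larger \(\rho\) always leaks a superset of a smaller one), and the rest of the mix is filled with clean training items so the total size stays constant. This is a minimal illustration using only the standard library; `make_mix` and the synthetic IDs are hypothetical names introduced here.

```python
import random

def make_mix(clean_ids, pool_ids, rho: float, mix_size: int, seed: int = 0):
    # Leak the first round(rho * |pool|) items of the fixed eval pool.
    m = int(round(rho * len(pool_ids)))
    leaked = list(pool_ids[:m])
    need = mix_size - len(leaked)
    if need < 0:
        raise ValueError("mix_size is smaller than the number of leaked items")
    # Fill the remainder with clean items so every rho sees the same mix size.
    rng = random.Random(seed)
    filler = rng.sample(list(clean_ids), need)
    mix = leaked + filler
    rng.shuffle(mix)
    return mix, set(leaked)

pool = [f"test-{i}" for i in range(1000)]    # stand-in for the eval pool
clean = [f"train-{i}" for i in range(5000)]  # stand-in for the train split

for rho in (0.0, 0.01, 0.05):
    mix, leaked = make_mix(clean, pool, rho, mix_size=512)
    print(f"rho={rho}: mix size={len(mix)}, leaked={len(leaked)}")
```

Because leakage is a prefix of the same pool, the leaked sets are nested across \(\rho\), which keeps the contamination curve comparable from one rate to the next.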


In [None]:

# =============== LoRA config knobs ===============
RUN_LORA = False  # default OFF; set True to actually train

LORA_BASE_MODEL = "Qwen/Qwen3-4B"  # fast
LORA_RHOS = [0.0, 0.001, 0.005, 0.01, 0.05]  # contamination rates

# Fixed compute across rhos
LORA_TRAIN_MIX_SIZE = 512    # total training examples (constant)
LORA_MAX_STEPS = 200         # constant steps (you can raise)
LORA_LR = 2e-4
LORA_R = 8
LORA_ALPHA = 16
LORA_DROPOUT = 0.05
LORA_MAX_LEN = 512

# Define leakage set on a fixed eval pool
LORA_EVAL_POOL_SIZE = 1000
LORA_POOL_SEED = 123

# Where to store adapters
LORA_DIR = os.path.join(RESULTS_DIR, "lora_cheaters")
os.makedirs(LORA_DIR, exist_ok=True)
print("LORA_DIR:", LORA_DIR)



## 15.1 Load GSM8K train/test (for LoRA)
We reuse the existing HF loader; we need the **train** split for LoRA.


In [None]:
# Load GSM8K train/test for contamination-style experiments.
# Prefer the fixed grouped split (gsm8k_train/gsm8k_test) if available; otherwise fall back to HF GSM8K.
try:
    if "gsm8k_train" in DATASETS and "gsm8k_test" in DATASETS:
        gsm8k_train_ds = DATASETS["gsm8k_train"]
        gsm8k_test_ds  = DATASETS["gsm8k_test"]
        print("[Grouped] GSM8K train:", len(gsm8k_train_ds.instances), "| test:", len(gsm8k_test_ds.instances))
    else:
        if "gsm8k_answer_test" not in public_datasets:
            public_datasets["gsm8k_answer_test"] = load_gsm8k_answer("test")
        gsm8k_test_ds = public_datasets["gsm8k_answer_test"]
        gsm8k_train_ds = load_gsm8k_answer("train")
        print("[HF] GSM8K train:", len(gsm8k_train_ds.instances), "| test:", len(gsm8k_test_ds.instances))
except Exception as e:
    gsm8k_train_ds = None
    gsm8k_test_ds = None
    print("Failed to load GSM8K train/test:", e)



## 15.2 LoRA training + evaluation helpers
We avoid triple-quote docstrings here to keep notebook generation robust.


In [None]:

def gsm8k_prompt_completion(inst: TaskInstance) -> str:
    # Same "final answer only" style as evaluation; no CoT.
    gt = extract_gsm8k_final(inst.ground_truth or "") or ""
    prompt = (
        "Solve the following math problem. Return ONLY the final numeric answer.\n\n"
        f"Problem:\n{inst.prompt}\n\nFinal answer:"
    )
    return prompt + " " + gt

def make_eval_pool_indices(N_test: int, pool_size: int, seed: int):
    rng = np.random.default_rng(seed)
    m = min(pool_size, N_test)
    return rng.choice(N_test, size=m, replace=False).tolist()

def make_training_mix_for_rho(train_ds: TaskDataset, test_ds: TaskDataset, eval_pool_idx, rho: float, mix_size: int, seed: int = 0):
    # Returns (train_rows, leaked_ids)
    rng = np.random.default_rng(seed)
    pool = [test_ds.instances[i] for i in eval_pool_idx]
    m = int(round(rho * len(pool)))
    leaked = pool[:m]  # monotone leakage by prefix
    leaked_ids = set(x.instance_id for x in leaked)

    rows = [{"text": gsm8k_prompt_completion(x)} for x in leaked]
    need = mix_size - len(rows)
    if need > 0:
        idx = rng.choice(len(train_ds.instances), size=need, replace=False)
        rows.extend({"text": gsm8k_prompt_completion(train_ds.instances[int(i)])} for i in idx)
    rng.shuffle(rows)
    return rows, leaked_ids

def train_lora_adapter(base_model_id: str, train_rows, out_dir: str, max_steps: int, lr: float, r: int, alpha: int, dropout: float, max_len: int = 512):
    # Real LoRA training. Requires transformers + peft.
    import torch
    from datasets import Dataset as HFDataset
    from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer, DataCollatorForLanguageModeling
    from peft import LoraConfig, get_peft_model, TaskType

    tok = AutoTokenizer.from_pretrained(base_model_id, use_fast=True)
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token

    model = AutoModelForCausalLM.from_pretrained(base_model_id, device_map="auto")
    model.train()

    hf = HFDataset.from_list(train_rows)

    def tok_fn(batch):
        return tok(batch["text"], truncation=True, max_length=max_len)

    hf = hf.map(tok_fn, batched=True, remove_columns=["text"])

    lcfg = LoraConfig(r=r, lora_alpha=alpha, lora_dropout=dropout, task_type=TaskType.CAUSAL_LM)
    model = get_peft_model(model, lcfg)

    args = TrainingArguments(
        output_dir=out_dir,
        per_device_train_batch_size=1,
        gradient_accumulation_steps=8,
        max_steps=max_steps,
        learning_rate=lr,
        logging_steps=max(1, max_steps//10),
        save_steps=max_steps,
        fp16=True,
        report_to=[],
        remove_unused_columns=False,
    )

    collator = DataCollatorForLanguageModeling(tok, mlm=False)
    trainer = Trainer(model=model, args=args, train_dataset=hf, data_collator=collator)
    trainer.train()
    model.save_pretrained(out_dir)
    tok.save_pretrained(out_dir)
    print("Saved LoRA adapter:", out_dir)

# --- Minimal adapter-aware evaluator (bit-only output) ---
from dataclasses import dataclass
from typing import Optional, Any, Dict

@dataclass
class EvalPkg2:
    bits: np.ndarray
    cost_bins: np.ndarray
    evidence: Dict[str, Any]
    signature: str

def digest_model_with_adapter(model_id: str, adapter_path: Optional[str]) -> str:
    return sha256_json({"model_id": model_id, "adapter_path": adapter_path})

def load_base_or_lora(model_id: str, adapter_path: Optional[str]):
    # Load base model and optionally attach a LoRA adapter.
    from transformers import AutoTokenizer, AutoModelForCausalLM
    tok = AutoTokenizer.from_pretrained(model_id, use_fast=True)
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token

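    kwargs = {"device_map": "auto"}
    if "USE_4BIT" in globals() and USE_4BIT:
        # Request 4-bit loading through BitsAndBytesConfig; the bare load_in_4bit kwarg is deprecated.
        from transformers import BitsAndBytesConfig
        kwargs["quantization_config"] = BitsAndBytesConfig(load_in_4bit=True)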

    model = AutoModelForCausalLM.from_pretrained(model_id, **kwargs)
    model.eval()
    if adapter_path is not None:
        from peft import PeftModel
        model = PeftModel.from_pretrained(model, adapter_path)
        model.eval()
    return tok, model

def generate_answer(tok, model, question: str, gen_cfg: dict):
    import torch
    from transformers import GenerationConfig
    prompt = (
        "Solve the following math problem. Return ONLY the final numeric answer.\n\n"
        f"Problem:\n{question}\n\nFinal answer:"
    )
    inputs = tok(prompt, return_tensors="pt", truncation=True)
    inputs = {k: v.to(model.device) for k, v in inputs.items()}
    gcfg = GenerationConfig(**gen_cfg)
    with torch.no_grad():
        out = model.generate(**inputs, generation_config=gcfg)
    # Decode only the newly generated tokens. Matching the decoded text against the
    # prompt string is fragile because decode() may not reproduce the prompt exactly.
    new_tokens = out[0][inputs["input_ids"].shape[1]:]
    return tok.decode(new_tokens, skip_special_tokens=True).strip()

def bap_run_with_adapter(dataset: TaskDataset, model_id: str, adapter_path: Optional[str], run_nonce: str, n: int, with_replacement: bool):
    # BAP-style: seed->indices->bits/costbins, plus receipt.
    task_type = "answer"
    gen_cfg = GEN_CFG_ANSWER if "GEN_CFG_ANSWER" in globals() else dict(max_new_tokens=256, do_sample=False, temperature=0.0)

    cfg_digest = digest_config(task_type, n, with_replacement, gen_cfg)
    ds_digest = digest_dataset(dataset, include_text=False)
    m_digest = digest_model_with_adapter(model_id, adapter_path)

    seed = derive_seed(PUBLIC_BEACON, ds_digest, m_digest, cfg_digest, run_nonce)
    idx = sample_indices(seed, len(dataset.instances), n, with_replacement)
    idx_commit = indices_commitment(idx)

    # Size by the sampled indices so a short sample never leaves phantom zero bits.
    bits = np.zeros(len(idx), dtype=int)
    cb = np.zeros(len(idx), dtype=int)

    if not DRY_RUN:
        tok, model = load_base_or_lora(model_id, adapter_path)
    for t, j in enumerate(idx):
        inst = dataset.instances[int(j)]
        out = "0" if DRY_RUN else generate_answer(tok, model, inst.prompt, gen_cfg)
        b, steps = score_and_cost(inst, out)
        bits[t] = b
        cb[t] = cost_bin(int(steps))

    evidence = {
        "code_digest": digest_code(CODE_ID),  # reuse
        "dataset_digest": ds_digest,
        "model_digest": m_digest,
        "config_digest": cfg_digest,
        "public_beacon": PUBLIC_BEACON,
        "run_nonce": run_nonce,
        "seed": int(seed),
        "index_commitment": idx_commit,
        "output_hash": hash_outputs(bits, cb),
    }
    sig = sign_hmac(evidence)
    return EvalPkg2(bits=bits, cost_bins=cb, evidence=evidence, signature=sig)



## 15.3 Run the LoRA contamination curve (Reject Rate vs \(\rho\))
This produces:
- `df_lora_curve`: per-rho per-seed results
- a plot: reject rate vs rho


In [None]:

if RUN_LORA:
    assert gsm8k_train_ds is not None and gsm8k_test_ds is not None, "GSM8K train/test not available."

    eval_pool_idx = make_eval_pool_indices(len(gsm8k_test_ds.instances), LORA_EVAL_POOL_SIZE, LORA_POOL_SEED)

    # Pilot p0 from base model (same BAP sampling interface)
    p0_list = []
    for s in EVAL_SEEDS:
        pkg0 = bap_run_with_adapter(gsm8k_test_ds, LORA_BASE_MODEL, adapter_path=None,
                                    run_nonce=f"pilot|base|seed{s}", n=N_EVAL, with_replacement=WITH_REPLACEMENT)
        p0_list.append(float(pkg0.bits.mean()))
    p0 = float(np.mean(p0_list))
    print("Pilot p0 (base acc estimate):", p0)

    rows = []
    leaked_registry = {}  # rho -> leaked_ids (open-track only)
    for rho in LORA_RHOS:
        adapter_dir = os.path.join(LORA_DIR, f"{safe_name(LORA_BASE_MODEL)}__rho{rho}")
        os.makedirs(adapter_dir, exist_ok=True)

        train_rows, leaked_ids = make_training_mix_for_rho(
            gsm8k_train_ds, gsm8k_test_ds, eval_pool_idx, rho=rho, mix_size=LORA_TRAIN_MIX_SIZE, seed=int(10_000*rho)+7
        )
        leaked_registry[rho] = leaked_ids

        if rho == 0.0:
            adapter_use = None
        else:
            marker = os.path.join(adapter_dir, "adapter_config.json")
            if not os.path.exists(marker):
                train_lora_adapter(
                    base_model_id=LORA_BASE_MODEL,
                    train_rows=train_rows,
                    out_dir=adapter_dir,
                    max_steps=LORA_MAX_STEPS,
                    lr=LORA_LR,
                    r=LORA_R,
                    alpha=LORA_ALPHA,
                    dropout=LORA_DROPOUT,
                    max_len=LORA_MAX_LEN,
                )
            else:
                print("Adapter exists; skipping training:", adapter_dir)
            adapter_use = adapter_dir

        for s in EVAL_SEEDS:
            pkg = bap_run_with_adapter(gsm8k_test_ds, LORA_BASE_MODEL, adapter_path=adapter_use,
                                       run_nonce=f"lora|rho{rho}|seed{s}", n=N_EVAL, with_replacement=WITH_REPLACEMENT)
            assert verify_hmac(pkg.evidence, pkg.signature)
            k = int(pkg.bits.sum()); n = len(pkg.bits)
            pv = stats.binomtest(k, n, p=p0, alternative="greater").pvalue
            rows.append({"rho": rho, "seed": s, "acc": float(pkg.bits.mean()), "p_value": float(pv), "reject": bool(pv < 0.05)})

    df_lora_curve = pd.DataFrame(rows)
    display(df_lora_curve.groupby("rho").agg(acc_mean=("acc","mean"), reject_rate=("reject","mean"), p_med=("p_value","median")).reset_index())

    g = df_lora_curve.groupby("rho")["reject"].mean().reset_index()
    plt.figure()
    plt.plot(g["rho"]*100, g["reject"], marker="o")
    plt.axhline(0.05, linestyle="--")
    plt.xlabel("Contamination rate rho (%)")
    plt.ylabel("Reject rate (power proxy at alpha=0.05)")
    plt.title("BAP Detection Power vs REAL LoRA Contamination (GSM8K)")
    plt.show()
else:
    print("RUN_LORA=False; skipping real LoRA training. Set RUN_LORA=True to run.")



# 16) ICML Upgrade: Baselines (N-gram overlap / PPL / Min-K% Prob)
This section closes the **baseline gap**.

Baselines:
1. **N-gram overlap** between GSM8K test questions and GSM8K train questions.
2. **PPL** and **Min-K% Prob** of the suspect model on the question text.

We report **AUC** for identifying contaminated items (labels are known in the LoRA experiment).


In [None]:

# ---- AUC helper (no sklearn required) ----
def auc_from_scores(scores: np.ndarray, labels01: np.ndarray) -> float:
    scores = np.asarray(scores, dtype=float)
    labels01 = np.asarray(labels01, dtype=int)
    pos = scores[labels01 == 1]
    neg = scores[labels01 == 0]
    if len(pos) == 0 or len(neg) == 0:
        return float("nan")
    order = np.argsort(scores)
    ranks = np.empty_like(order, dtype=float)
    ranks[order] = np.arange(1, len(scores)+1)

    uniq, inv, cnt = np.unique(scores, return_inverse=True, return_counts=True)
    if np.any(cnt > 1):
        for i, c in enumerate(cnt):
            if c > 1:
                idx = np.where(inv == i)[0]
                ranks[idx] = ranks[idx].mean()

    rank_sum_pos = ranks[labels01 == 1].sum()
    n_pos = len(pos); n_neg = len(neg)
    return float((rank_sum_pos - n_pos*(n_pos+1)/2) / (n_pos*n_neg))

# ---- N-gram overlap baseline ----
_word_re = re.compile(r"[A-Za-z0-9]+")
def word_ngrams(text: str, n: int = 5):
    toks = [t.lower() for t in _word_re.findall(text)]
    if len(toks) < n:
        return []
    return [" ".join(toks[i:i+n]) for i in range(len(toks)-n+1)]

def hash_ngram(s: str) -> int:
    h = hashlib.blake2b(s.encode("utf-8"), digest_size=8).digest()
    return int.from_bytes(h, "big")

def build_ngram_set(corpus_texts, n: int = 5):
    hs = set()
    for txt in corpus_texts:
        for ng in word_ngrams(txt, n=n):
            hs.add(hash_ngram(ng))
    return hs

def ngram_overlap_score(text: str, ngram_set: set, n: int = 5) -> float:
    ngs = word_ngrams(text, n=n)
    if not ngs:
        return 0.0
    hit = sum((hash_ngram(ng) in ngram_set) for ng in ngs)
    return float(hit / len(ngs))



## 16.1 Build GSM8K train 5-gram set


In [None]:

ngset_gsm8k_5 = None
if gsm8k_train_ds is not None:
    ngset_gsm8k_5 = build_ngram_set([x.prompt for x in gsm8k_train_ds.instances], n=5)
    print("Built GSM8K train 5-gram set size:", len(ngset_gsm8k_5))
else:
    print("GSM8K train not available; cannot build n-gram baseline.")



## 16.2 PPL / Min-K% baselines (likelihood features)
These require a forward pass on each question. Start with `BASELINE_MAX_ITEMS=256`.
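
To make the two likelihood features concrete before running a model, here is a toy calculation on a hand-written list of per-token log-probabilities. It is a sketch using only the standard library; the values in `toy_logprobs` are made up for illustration, and `perplexity` / `min_k_avg` are stand-in names that mirror the definitions used in this section.

```python
import math

def perplexity(token_logprobs) -> float:
    # PPL = exp(mean negative log-likelihood per token).
    return math.exp(-sum(token_logprobs) / len(token_logprobs))

def min_k_avg(token_logprobs, k_frac: float = 0.2) -> float:
    # Average log-prob over the k_frac least likely tokens (at least one token).
    k = max(1, int(round(k_frac * len(token_logprobs))))
    return sum(sorted(token_logprobs)[:k]) / k

# Ten made-up token log-probs: mostly confident, with two surprising tokens.
toy_logprobs = [-0.1, -0.2, -3.0, -0.1, -0.4, -2.0, -0.2, -0.1, -0.3, -0.6]

ppl = perplexity(toy_logprobs)
mink = min_k_avg(toy_logprobs, k_frac=0.2)
print("PPL:", ppl)
print("Min-20% avg logprob:", mink)
```

Min-K% looks only at the hardest tokens, so it reacts to a few surprising positions that the full-sequence average in PPL would smooth over.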


In [None]:

def compute_token_logprobs(tok, model, text: str, max_length: int = 512) -> np.ndarray:
    import torch
    enc = tok(text, return_tensors="pt", truncation=True, max_length=max_length)
    input_ids = enc["input_ids"].to(model.device)
    attn = enc.get("attention_mask", None)
    if attn is not None:
        attn = attn.to(model.device)

    with torch.no_grad():
        out = model(input_ids=input_ids, attention_mask=attn)
        logits = out.logits

    shift_logits = logits[:, :-1, :]
    shift_labels = input_ids[:, 1:]
    log_probs = torch.log_softmax(shift_logits, dim=-1)
    token_lp = log_probs.gather(-1, shift_labels.unsqueeze(-1)).squeeze(-1)
    return token_lp.squeeze(0).detach().cpu().numpy()

def perplexity_from_logprobs(token_logprobs: np.ndarray) -> float:
    nll = -float(np.mean(token_logprobs))
    return float(np.exp(nll))

def min_k_avg_logprob(token_logprobs: np.ndarray, k_frac: float = 0.2) -> float:
    L = len(token_logprobs)
    if L == 0:
        return float("nan")
    k = max(1, int(round(k_frac * L)))
    idx = np.argsort(token_logprobs)[:k]
    return float(np.mean(token_logprobs[idx]))



## 16.3 Compute baseline AUC (requires `leaked_registry` from LoRA section)
If you haven't run LoRA yet, this cell will safely skip.


In [None]:

BASELINE_MAX_ITEMS = 256
K_FRAC = 0.2

if "leaked_registry" in globals() and gsm8k_test_ds is not None and "eval_pool_idx" in globals():
    # Choose a rho to evaluate (pick 1% if present; else highest rho)
    rho_star = 0.01 if 0.01 in leaked_registry else sorted(leaked_registry.keys())[-1]
    contam_ids = leaked_registry[rho_star]

    pool_idx = eval_pool_idx[:BASELINE_MAX_ITEMS]
    y = np.array([1 if gsm8k_test_ds.instances[i].instance_id in contam_ids else 0 for i in pool_idx], dtype=int)

    # N-gram overlap AUC (higher overlap => more likely contaminated)
    if ngset_gsm8k_5 is not None:
        ng_scores = np.array([ngram_overlap_score(gsm8k_test_ds.instances[i].prompt, ngset_gsm8k_5, n=5) for i in pool_idx], dtype=float)
        print(f"AUC(N-gram overlap, rho={rho_star}):", auc_from_scores(ng_scores, y))
    else:
        print("No n-gram set; skipping overlap baseline.")

    # Likelihood baselines under the suspect model
    adapter_dir = os.path.join(LORA_DIR, f"{safe_name(LORA_BASE_MODEL)}__rho{rho_star}")
    if rho_star == 0.0:
        adapter_dir = None

    try:
        tok, model = load_base_or_lora(LORA_BASE_MODEL, adapter_dir)
        ppl_scores = []
        mink_scores = []
        for i in pool_idx:
            q = gsm8k_test_ds.instances[i].prompt
            lp = compute_token_logprobs(tok, model, q, max_length=512)
            ppl_scores.append(-perplexity_from_logprobs(lp))  # negative PPL => higher means more memorized
            mink_scores.append(min_k_avg_logprob(lp, k_frac=K_FRAC))

        ppl_scores = np.array(ppl_scores, dtype=float)
        mink_scores = np.array(mink_scores, dtype=float)

        print(f"AUC(-PPL, rho={rho_star}):", auc_from_scores(ppl_scores, y))
        print(f"AUC(Min-K avg logprob, rho={rho_star}):", auc_from_scores(mink_scores, y))
    except Exception as e:
        print("Likelihood baselines failed (missing deps/GPU?):", e)
else:
    print("LoRA artifacts not found yet. Run section 15 first, then rerun this cell.")



# 17) ICML Upgrade: Statistical Power Analysis (justify `N_EVAL`)
Reviewer question: “why `N_EVAL=400`?”

Conservative distribution-free planning (Hoeffding-style):
\[
n \ge \frac{2}{\Delta^2}\log\Big(\frac{1}{\min(\alpha,\beta)}\Big),
\]
where \(\Delta = p_1 - p_0\), \(\alpha\) is significance, and \(1-\beta\) is power.

We provide:
- a planning table for typical \(\Delta\),
- an empirical estimate of \(\Delta(\rho)\) from the LoRA curve (if available).
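
For reference, a sketch of where the planning bound comes from. It assumes a one-sided test that rejects when the observed pass rate \(\hat p\) exceeds the midpoint threshold \(p_0 + \Delta/2\); that threshold is an assumption made here for the derivation, whereas the notebook's own test uses an exact binomial p-value. Under the null (true rate \(p_0\)), Hoeffding's inequality gives
\[
\Pr\Big(\hat p \ge p_0 + \frac{\Delta}{2}\Big) \le \exp\Big(-2n\Big(\frac{\Delta}{2}\Big)^2\Big) = \exp\Big(-\frac{n\Delta^2}{2}\Big),
\]
and requiring the right-hand side to be at most \(\alpha\) yields
\[
n \ge \frac{2}{\Delta^2}\log\Big(\frac{1}{\alpha}\Big).
\]
The same argument under the alternative (true rate \(p_1 = p_0 + \Delta\)) bounds the miss probability by \(\beta\) when \(n \ge \frac{2}{\Delta^2}\log(1/\beta)\). Taking the larger of the two requirements gives the \(\min(\alpha,\beta)\) form above. As a worked instance, \(\Delta = 0.10\) with \(\alpha = \beta = 0.05\) gives \(n \ge 200 \log 20 \approx 599.1\), so \(n = 600\).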


In [None]:

def n_required_hoeffding(delta: float, alpha: float = 0.05, beta: float = 0.05) -> int:
    if delta <= 0:
        return math.inf
    return int(math.ceil((2.0 / (delta**2)) * math.log(1.0 / min(alpha, beta))))

print("Planning table (alpha=beta=0.05):")
for d in [0.01, 0.02, 0.03, 0.05, 0.08, 0.10]:
    print(f"delta={d:.2f} -> n_required={n_required_hoeffding(d)}")
print("Current N_EVAL:", N_EVAL)



## 17.1 Empirical power from LoRA curve (if available)


In [None]:

if "df_lora_curve" in globals():
    df = df_lora_curve.copy()
    p0 = float(df[df["rho"]==0.0]["acc"].mean())
    rows = []
    for rho in sorted(df["rho"].unique()):
        p1 = float(df[df["rho"]==rho]["acc"].mean())
        delta = p1 - p0
        rows.append({"rho": rho, "p0": p0, "p1": p1, "delta": delta, "n_required(alpha=beta=0.05)": n_required_hoeffding(delta, 0.05, 0.05)})
    display(pd.DataFrame(rows))
    print("Current N_EVAL:", N_EVAL)
else:
    print("No LoRA curve yet. Run section 15 to estimate delta(rho) empirically.")
