# In [None]:
import json
import math
import os

import torch
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM

# -------------------------------
# Configuration
model_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"

# Problem set (a single JSON document), per-sample accuracy records
# (JSON Lines, one object per line), and the output path for the z-scores.
input_file = "data/hendrycks_math_train.json"
acc_file = "output/deepseek-r1-1.5b-generated-predictions-detailed-results.jsonl"
output_file = "z_score/deepseek-r1-1.5b_ppl_conf_acc_z_scores_results.json"

# Reference mean / standard deviation used to z-normalize per-sample PPL and
# confidence.
# NOTE(review): the dataset these were measured on is not shown here —
# presumably the same model and scoring procedure; confirm before reuse.
model_stats = dict(
    ppl_mean=9.795982360839844,
    ppl_std=22.284496307373047,
    conf_mean=0.6799513101577759,
    conf_std=0.08082679659128189,
)

# -------------------------------
# Load the tokenizer and the causal LM; device_map="auto" lets the library
# decide where the weights are placed.
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
)
model.eval()  # inference mode (e.g. disables dropout); returns the same object

# -------------------------------
# Load data: the problem set (one JSON document) and the per-sample accuracy
# records (JSON Lines). The two are paired element by element later, so their
# lengths must match.
with open(input_file, "r", encoding="utf-8") as f:
    data = json.load(f)

with open(acc_file, "r", encoding="utf-8") as f:
    # Skip blank lines (e.g. stray empty lines) so json.loads never sees "".
    acc_data = [json.loads(line) for line in f if line.strip()]

# Raise explicitly rather than `assert`: asserts are stripped under `python -O`.
if len(data) != len(acc_data):
    raise ValueError(f"样本数量不匹配: {len(data)} != {len(acc_data)}")

# -------------------------------
# Compute perplexity (PPL) and confidence
def compute_ppl_and_conf(text):
    """Score *text* with the module-level causal LM and return ``(ppl, conf)``.

    ppl:  ``exp`` of the loss the model returns when the input ids are passed
          as their own labels (the mean next-token cross-entropy).
    conf: mean, over positions, of the highest softmax probability of the
          model's next-token distribution, excluding the first and last
          positions.

    Raises:
        ValueError: if *text* tokenizes to fewer than 3 tokens. In that case the
            confidence slice would be empty (mean -> NaN), and with a single
            token there is no next-token target for the loss either.
    """
    # Move inputs to the model's own device (`model.device`) instead of a
    # hard-coded "mps": with device_map="auto" the model may live on CUDA or CPU,
    # and a mismatch would make every call fail.
    inputs = tokenizer(text, return_tensors="pt").to(model.device)
    num_tokens = inputs["input_ids"].shape[1]
    if num_tokens < 3:
        raise ValueError(f"text too short to score ({num_tokens} tokens)")

    with torch.no_grad():
        outputs = model(**inputs, labels=inputs["input_ids"])
        ppl = math.exp(outputs.loss.item())

        # logits[0, t] is the next-token distribution after seeing tokens 0..t.
        # Drop position 0 (conditioned only on the first token, typically BOS)
        # and the last position (it predicts beyond the end of the text).
        # Softmax is per-position, so slicing first gives identical values
        # without materializing probabilities for discarded positions.
        probs = torch.softmax(outputs.logits[0, 1:-1], dim=-1)
        conf = probs.max(dim=-1).values.mean().item()

    return ppl, conf

# -------------------------------
# Batch processing: score each problem and pair it with its accuracy label.
# One entry is produced per sample (either z-scores or an error), so the
# output stays positionally aligned with the inputs.
results = []
for sample, acc_record in tqdm(zip(data, acc_data), total=len(data)):
    input_text = sample["input"]
    # Binarize: only a stored accuracy of at least 99.9 counts as correct
    # (presumably a percentage — verify against the accuracy file).
    acc = int(acc_record.get("accuracy", 0.0) >= 99.9)

    try:
        ppl, conf = compute_ppl_and_conf(input_text)
        z_ppl = (ppl - model_stats["ppl_mean"]) / model_stats["ppl_std"]
        z_conf = (conf - model_stats["conf_mean"]) / model_stats["conf_std"]
    except Exception as e:
        # Best-effort: record the failure for this sample and keep going.
        record = {"error": str(e), "acc": acc}
    else:
        record = {"z_ppl": z_ppl, "z_conf": z_conf, "acc": acc}
    results.append(record)

# -------------------------------
# Save results.
# Create the output directory first: it may not exist yet, and failing on
# open() here would throw away all of the scoring work done above.
os.makedirs(os.path.dirname(output_file) or ".", exist_ok=True)
with open(output_file, "w", encoding="utf-8") as f:
    json.dump(results, f, indent=2)

print(f"✅ 完成，已保存到: {output_file}")
