<a href="https://colab.research.google.com/github/hshen13/debias_tta/blob/main/OOD.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# ============================================
# 0) Install
# ============================================
!pip -q install -U "transformers>=4.51" "datasets>=2.20" "accelerate>=0.33" scikit-learn bitsandbytes

import random
import numpy as np
import torch
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

print("cuda:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("gpu:", torch.cuda.get_device_name(0))


[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m512.3/512.3 kB[0m [31m32.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m8.9/8.9 MB[0m [31m17.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m59.1/59.1 MB[0m [31m44.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m47.7/47.7 MB[0m [31m44.3 MB/s[0m eta [36m0:00:00[0m
[?25hcuda: True
gpu: NVIDIA A100-SXM4-80GB


In [None]:
# ============================================
# 1) Load Qwen3 (choose Base for "pretrain ID", Instruct for "policy ID")
# ============================================
QWEN_ID = "Qwen/Qwen3-4B-Base"   # pretrain-distribution ID anchor
# QWEN_ID = "Qwen/Qwen3-4B-Instruct-2507"  # policy/assistant-style anchor

MAX_LEN = 256
BATCH_SIZE = 16

bnb = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
)

tok = AutoTokenizer.from_pretrained(QWEN_ID, trust_remote_code=True)
if tok.pad_token is None:
    tok.pad_token = tok.eos_token

qwen = AutoModelForCausalLM.from_pretrained(
    QWEN_ID,
    quantization_config=bnb,
    device_map="auto",
    trust_remote_code=True,
).eval()

print("Loaded:", QWEN_ID)


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json: 0.00B [00:00, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

config.json:   0%|          | 0.00/727 [00:00<?, ?B/s]

model.safetensors.index.json: 0.00B [00:00, ?B/s]

Fetching 3 files:   0%|          | 0/3 [00:00<?, ?it/s]

model-00003-of-00003.safetensors:   0%|          | 0.00/99.6M [00:00<?, ?B/s]

model-00002-of-00003.safetensors:   0%|          | 0.00/3.99G [00:00<?, ?B/s]

model-00001-of-00003.safetensors:   0%|          | 0.00/3.96G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/138 [00:00<?, ?B/s]

Loaded: Qwen/Qwen3-4B-Base


In [None]:
# ============================================
# 2) (Optional but recommended) background LM for Likelihood Ratio (LLR)
#    LLR helps reduce "likelihood trap" by subtracting a generic model's NLL.
# ============================================
BG_ID = "gpt2"   # small & fast baseline; you can swap to other small LM

bg_tok = AutoTokenizer.from_pretrained(BG_ID)
if bg_tok.pad_token is None:
    bg_tok.pad_token = bg_tok.eos_token

bg = AutoModelForCausalLM.from_pretrained(BG_ID).to(qwen.device).eval()
print("Loaded BG:", BG_ID)


tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/548M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

Loaded BG: gpt2


In [None]:
# ============================================
# 3) Load RealToxicityPrompts
# ============================================
rtp = load_dataset("allenai/real-toxicity-prompts", split="train")
# dataset card notes: prompts are from OpenWebText; toxicity scored by Perspective API
# and stratified across toxicity ranges. (We just use the provided fields.)
# (See HF dataset card for details.)
print("RTP rows:", len(rtp))


README.md: 0.00B [00:00, ?B/s]

prompts.jsonl:   0%|          | 0.00/67.7M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/99442 [00:00<?, ? examples/s]

RTP rows: 99442


In [None]:
# # ============================================
# # 4) Choose ID set (two modes)
# #   - pretrain: use general web/wiki text as ID
# #   - policy  : use instruction/chat user prompts as ID
# #
# # NOTE: These are proxies: "correct ID" is what *your deployment* expects.
# # ============================================
# ID_MODE = "pretrain"     # change to "pretrain" to test the other definition

# N_ID  = 500           # increase for more stable threshold
# N_RTP = 2000          # how many RTP prompts to test (full 100k is slower)

# def sample_id_texts(mode, n):
#     if mode == "pretrain":
#         # Proxy for general text distribution: Wikipedia / wikitext.
#         ds = load_dataset("wikitext", "wikitext-103-raw-v1", split="train")
#         # take non-empty lines
#         pool = [x["text"] for x in ds if x["text"].strip()]
#         return random.sample(pool, n)
#     elif mode == "policy":
#         # Proxy for "normal user prompts": ultrachat user messages.
#         # This is a common open instruction/chat corpus.
#         ds = load_dataset("HuggingFaceH4/ultrachat_200k", "default", split="train_sft")
#         # each row has 'prompt' & 'response' in many versions; robustly pick prompt-like field
#         # If schema differs, print(ds.features) and adjust.
#         if "prompt" in ds.features:
#             pool = [x["prompt"] for x in ds if isinstance(x["prompt"], str) and x["prompt"].strip()]
#         elif "messages" in ds.features:
#             # fallback: take first user message
#             pool = []
#             for x in ds:
#                 msgs = x["messages"]
#                 if isinstance(msgs, list):
#                     for m in msgs:
#                         if m.get("role") == "user" and isinstance(m.get("content"), str) and m["content"].strip():
#                             pool.append(m["content"])
#                             break
#         else:
#             raise ValueError("UltraChat schema not recognized; inspect ds.features.")
#         return random.sample(pool, n)
#     else:
#         raise ValueError("ID_MODE must be 'pretrain' or 'policy'.")

# id_texts = sample_id_texts(ID_MODE, N_ID)

# # RTP prompts (use the provided 'text' field; it's a snippet)
# rtp_idx = random.sample(range(len(rtp)), N_RTP)
# # --- 先看一下 schema（建议你跑一次） ---
# print(rtp.features)
# print(rtp[0].keys())
# print(rtp[0])

# def get_prompt_text_and_tox(row):
#     # Case 1: 新/旧版本可能直接顶层 text/toxicity
#     if "text" in row:
#         txt = row["text"]
#         tox = row.get("toxicity", None)
#         return txt, tox

#     # Case 2: 常见 schema：prompt 是 dict，里面有 text/toxicity
#     if "prompt" in row:
#         p = row["prompt"]
#         if isinstance(p, dict):
#             txt = p.get("text", None)
#             tox = p.get("toxicity", row.get("toxicity", None))
#             return txt, tox
#         # 少数情况 prompt 可能就是字符串
#         if isinstance(p, str):
#             txt = p
#             tox = row.get("toxicity", None)
#             return txt, tox

#     raise KeyError(f"Cannot find prompt text in row keys={list(row.keys())}")

# # 重新抽样
# rtp_idx = random.sample(range(len(rtp)), N_RTP)

# rtp_texts = []
# rtp_tox = []
# for i in rtp_idx:
#     txt, tox = get_prompt_text_and_tox(rtp[i])
#     if txt is None or tox is None:
#         continue
#     rtp_texts.append(txt)
#     rtp_tox.append(float(tox))

# rtp_tox = np.array(rtp_tox, dtype=np.float32)

# print("RTP usable:", len(rtp_texts), "tox range:", float(rtp_tox.min()), float(rtp_tox.max()))
# print("example:", rtp_texts[0], "tox:", rtp_tox[0])
# print("ID_MODE:", ID_MODE, "ID samples:", len(id_texts), "RTP samples:", len(rtp_texts))


In [None]:
# ============================================
# 4) Choose ID set: WritingPrompts style (ID) vs RealToxicityPrompts (test)
#   Question: "Is RTP OOD relative to WritingPrompts prompt style?"
# ============================================
from datasets import load_dataset
import random
import numpy as np

SEED = 42
random.seed(SEED)

# ---- sample sizes ----
N_ID  = 500    # WritingPrompts prompts as ID
N_RTP = 2000   # RTP prompts to test

# -----------------------------
# 4.1 Load WritingPrompts as ID
# -----------------------------
# HF dataset id: euclaise/writingprompts  (has prompt/story; we use prompt)
wp = load_dataset("euclaise/writingprompts", split="train")

print("WritingPrompts columns:", wp.column_names)
# Robust pick prompt column
if "prompt" in wp.column_names:
    wp_pool = [x for x in wp["prompt"] if isinstance(x, str) and x.strip()]
elif "text" in wp.column_names:
    wp_pool = [x for x in wp["text"] if isinstance(x, str) and x.strip()]
else:
    # fallback: find first string column
    wp_pool = []
    for c in wp.column_names:
        try:
            # try reading a few values and keep strings
            vals = wp[c][:1000]
            if isinstance(vals[0], str):
                wp_pool = [x for x in wp[c] if isinstance(x, str) and x.strip()]
                break
        except Exception:
            pass
    if not wp_pool:
        raise ValueError("Cannot find a usable text column in WritingPrompts dataset.")

if len(wp_pool) < N_ID:
    raise ValueError(f"WritingPrompts usable prompts={len(wp_pool)} < N_ID={N_ID}, reduce N_ID.")

id_texts = random.sample(wp_pool, N_ID)
print("ID sample:", id_texts[0][:200])

# ---------------------------------
# 4.2 Load RealToxicityPrompts (RTP)
# ---------------------------------
rtp = load_dataset("allenai/real-toxicity-prompts", split="train")
print("RTP features:", rtp.features)
print("RTP example keys:", rtp[0].keys())
print("RTP example:", rtp[0])

def get_rtp_prompt_text_and_tox(row):
    """
    RTP 在 HF 上常见结构：
      row["prompt"] 是 dict，里面有 "text" 和 "toxicity"
    也兼容少数版本：顶层直接有 "text"/"toxicity"
    """
    # Case A: top-level
    if "text" in row:
        txt = row.get("text", None)
        tox = row.get("toxicity", None)
        return txt, tox

    # Case B: nested prompt dict
    if "prompt" in row:
        p = row["prompt"]
        if isinstance(p, dict):
            txt = p.get("text", None)
            tox = p.get("toxicity", row.get("toxicity", None))
            return txt, tox
        if isinstance(p, str):
            txt = p
            tox = row.get("toxicity", None)
            return txt, tox

    raise KeyError(f"Cannot find RTP prompt text in row keys={list(row.keys())}")

# 多抽一些以防过滤后不足
rtp_idx = random.sample(range(len(rtp)), min(len(rtp), N_RTP * 3))

rtp_texts, rtp_tox = [], []
for i in rtp_idx:
    txt, tox = get_rtp_prompt_text_and_tox(rtp[i])
    if isinstance(txt, str) and txt.strip() and tox is not None:
        rtp_texts.append(txt)
        rtp_tox.append(float(tox))
    if len(rtp_texts) >= N_RTP:
        break

if len(rtp_texts) < N_RTP:
    print(f"Warning: only got {len(rtp_texts)} RTP prompts after filtering (requested {N_RTP}).")

rtp_tox = np.array(rtp_tox, dtype=np.float32)

print("RTP usable:", len(rtp_texts),
      "| tox range:", float(rtp_tox.min()), float(rtp_tox.max()))
print("RTP sample:", rtp_texts[0][:200], "| tox:", float(rtp_tox[0]))

print("Final -> ID(WritingPrompts):", len(id_texts), "RTP:", len(rtp_texts))


README.md:   0%|          | 0.00/837 [00:00<?, ?B/s]

data/train-00000-of-00002-105e07cb0d1994(…):   0%|          | 0.00/272M [00:00<?, ?B/s]

data/train-00001-of-00002-4fdb982c110564(…):   0%|          | 0.00/272M [00:00<?, ?B/s]

data/test-00000-of-00001-16503b0c26ed00c(…):   0%|          | 0.00/30.0M [00:00<?, ?B/s]

data/validation-00000-of-00001-137b93e1e(…):   0%|          | 0.00/30.7M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/272600 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/15138 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/15620 [00:00<?, ? examples/s]

WritingPrompts columns: ['prompt', 'story']
ID sample: [ WP ] A massive wall of light , a few feet thick and a few hundred feet high , stretches from the north to south pole . This wall slowly circumnavigates the globe in 1 year . No one knows where it ca
RTP features: {'filename': Value('string'), 'begin': Value('int64'), 'end': Value('int64'), 'challenging': Value('bool'), 'prompt': {'text': Value('string'), 'toxicity': Value('float64'), 'profanity': Value('float64'), 'sexually_explicit': Value('float64'), 'flirtation': Value('float64'), 'identity_attack': Value('float64'), 'threat': Value('float64'), 'insult': Value('float64'), 'severe_toxicity': Value('float64')}, 'continuation': {'text': Value('string'), 'profanity': Value('float64'), 'sexually_explicit': Value('float64'), 'identity_attack': Value('float64'), 'flirtation': Value('float64'), 'threat': Value('float64'), 'insult': Value('float64'), 'severe_toxicity': Value('float64'), 'toxicity': Value('float64')}}
RTP example keys: 

In [None]:
# ============================================
# 5) Scoring: per-token NLL and Likelihood Ratio (LLR)
#   score_NLL(x)  = NLL_qwen(x)
#   score_LLR(x)  = NLL_qwen(x) - NLL_bg(x)
#
# Larger score => more "surprising" => more OOD-like (training-agnostic likelihood-based)
# ============================================
import torch
import torch.nn.functional as F
import numpy as np
from torch.utils.data import DataLoader

@torch.no_grad()
def avg_nll(model, tokenizer, texts, batch_size=8, max_len=256):
    """
    返回每条文本的 per-token 平均 NLL（忽略 padding）
    score 越大 => 越“意外/不匹配” => 越 OOD-like
    """
    out_scores = []
    loader = DataLoader(texts, batch_size=batch_size, shuffle=False)

    for batch in loader:
        enc = tokenizer(
            list(batch),
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=max_len,
        )
        input_ids = enc["input_ids"].to(model.device)
        attn = enc["attention_mask"].to(model.device)

        # labels: padding 位置设为 -100，后面用 ignore_index 跳过
        labels = input_ids.clone()
        labels[attn == 0] = -100

        logits = model(input_ids=input_ids, attention_mask=attn).logits  # [B,L,V]

        # shift for next-token prediction
        shift_logits = logits[:, :-1, :].contiguous()   # [B,L-1,V]
        shift_labels = labels[:, 1:].contiguous()       # [B,L-1]

        # token-level CE, ignore padding (-100)
        # reduction='none' -> [B*(L-1)]
        ce = F.cross_entropy(
            shift_logits.view(-1, shift_logits.size(-1)),
            shift_labels.view(-1),
            ignore_index=-100,
            reduction="none",
        ).view(shift_labels.size())  # [B,L-1]

        # per-sample average over non-ignored tokens
        mask = (shift_labels != -100)
        token_counts = mask.sum(dim=1).clamp_min(1)  # avoid div by 0
        sample_loss = (ce * mask).sum(dim=1) / token_counts

        out_scores.extend(sample_loss.detach().float().cpu().tolist())

    return np.array(out_scores, dtype=np.float32)


def llr_scores(texts, qwen, tok, bg, bg_tok, batch_size=8, max_len=256):
    nll_q = avg_nll(qwen, tok, texts, batch_size=batch_size, max_len=max_len)
    nll_b = avg_nll(bg, bg_tok, texts, batch_size=batch_size, max_len=max_len)
    return nll_q - nll_b

# Choose which score you want to use:

USE_LLR = True
if USE_LLR:
    id_scores  = llr_scores(id_texts, qwen, tok, bg, bg_tok, batch_size=BATCH_SIZE, max_len=MAX_LEN)
    rtp_scores = llr_scores(rtp_texts, qwen, tok, bg, bg_tok, batch_size=BATCH_SIZE, max_len=MAX_LEN)
else:
    id_scores  = avg_nll(qwen, tok, id_texts, batch_size=BATCH_SIZE, max_len=MAX_LEN)
    rtp_scores = avg_nll(qwen, tok, rtp_texts, batch_size=BATCH_SIZE, max_len=MAX_LEN)
# drop NaNs
id_scores = id_scores[~np.isnan(id_scores)]
rtp_scores = rtp_scores[~np.isnan(rtp_scores)]
print("score stats | ID mean:", float(id_scores.mean()), "RTP mean:", float(rtp_scores.mean()))


score stats | ID mean: -0.5676233768463135 RTP mean: -0.05840018391609192


In [None]:
# ============================================
# 6) Calibrate threshold on ID, then "detect" RTP OOD rate
#   Example: set threshold at 95th percentile of ID => ~5% ID false positives
# ============================================
q = 0.95
tau = float(np.quantile(id_scores, q))
is_ood = (rtp_scores > tau)

print(f"Threshold tau = ID {int(q*100)}th percentile = {tau:.4f}")
print(f"RTP OOD rate  = {is_ood.mean():.4f}  (fraction of RTP flagged as OOD)")


Threshold tau = ID 95th percentile = 0.3264
RTP OOD rate  = 0.3160  (fraction of RTP flagged as OOD)


In [None]:
# ============================================
# 7) Breakdown by RTP toxicity bins (0-0.25, 0.25-0.5, 0.5-0.75, 0.75-1)
#   RTP was stratified across toxicity ranges per dataset card.
# ============================================
bins = [0.0, 0.25, 0.5, 0.75, 1.0000001]
labels = ["[0,0.25)", "[0.25,0.5)", "[0.5,0.75)", "[0.75,1]"]

# Align lengths if any NaNs were dropped (rare)
m = min(len(rtp_tox), len(is_ood))
tox = rtp_tox[:m]
ood = is_ood[:m]

for i in range(4):
    lo, hi = bins[i], bins[i+1]
    idx = (tox >= lo) & (tox < hi)
    if idx.sum() == 0:
        continue
    print(f"{labels[i]}  count={int(idx.sum())}  OOD_rate={ood[idx].mean():.4f}  mean_score={rtp_scores[:m][idx].mean():.4f}")


[0,0.25)  count=1221  OOD_rate=0.3276  mean_score=-0.0474
[0.25,0.5)  count=338  OOD_rate=0.2959  mean_score=-0.1046
[0.5,0.75)  count=224  OOD_rate=0.2679  mean_score=-0.0848
[0.75,1]  count=217  OOD_rate=0.3318  mean_score=-0.0212


In [None]:
# =========================================================
# 0) Install
# =========================================================
!pip -q install -U "transformers>=4.51" "datasets>=2.20" "accelerate>=0.33" scikit-learn bitsandbytes pandas tqdm

import os, random, math
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from tqdm.auto import tqdm
from datasets import load_dataset
from sklearn.metrics import roc_auc_score, average_precision_score, roc_curve, precision_recall_curve
from sklearn.neighbors import NearestNeighbors

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

print("cuda:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("gpu:", torch.cuda.get_device_name(0))

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m91.2/91.2 kB[0m [31m5.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m512.3/512.3 kB[0m [31m26.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m8.9/8.9 MB[0m [31m145.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m59.1/59.1 MB[0m [31m42.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m12.4/12.4 MB[0m [31m132.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m47.7/47.7 MB[0m [31m51.1 MB/s[0m eta [36m0:00:00[0m
[?25h[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
google-colab 1.0.0 requires pandas==2.2.2, but you have pandas 2.3.3 which is incompatible.[0m[

In [None]:
# =========================================================
# 1) Load Qwen3 + Background LM (for LLR)
# =========================================================
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

QWEN_ID = "Qwen/Qwen3-4B-Base"   # 你也可以换成 Instruct，但概念就偏 policy 分布
BG_ID   = "gpt2"                # 背景LM：小而快，做 LLR 用

MAX_LEN = 256
BATCH_SIZE = 8

bnb = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
)

tok = AutoTokenizer.from_pretrained(QWEN_ID, trust_remote_code=True)
if tok.pad_token is None:
    tok.pad_token = tok.eos_token

qwen = AutoModelForCausalLM.from_pretrained(
    QWEN_ID,
    quantization_config=bnb,
    device_map="auto",
    trust_remote_code=True,
).eval()

bg_tok = AutoTokenizer.from_pretrained(BG_ID)
if bg_tok.pad_token is None:
    bg_tok.pad_token = bg_tok.eos_token

bg = AutoModelForCausalLM.from_pretrained(BG_ID).to(qwen.device).eval()

print("Loaded Qwen:", QWEN_ID)
print("Loaded BG  :", BG_ID)


tokenizer_config.json: 0.00B [00:00, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

config.json:   0%|          | 0.00/727 [00:00<?, ?B/s]

model.safetensors.index.json: 0.00B [00:00, ?B/s]

Fetching 3 files:   0%|          | 0/3 [00:00<?, ?it/s]

model-00003-of-00003.safetensors:   0%|          | 0.00/99.6M [00:00<?, ?B/s]

model-00002-of-00003.safetensors:   0%|          | 0.00/3.99G [00:00<?, ?B/s]

model-00001-of-00003.safetensors:   0%|          | 0.00/3.96G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/138 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/548M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

Loaded Qwen: Qwen/Qwen3-4B-Base
Loaded BG  : gpt2


In [None]:
# =========================================================
# 2) Load datasets & build ID/OOD splits
#   ID: WritingPrompts prompts
#   OOD: RealToxicityPrompts prompts (RTP)
# =========================================================

# -------- parameters (adjust for speed) --------
N_ID_TRAIN = 800     # for fitting threshold / density stats
N_ID_TEST  = 800     # for estimating realized FPR
N_OOD      = 2000    # RTP size for evaluation

# --- WritingPrompts as ID ---
wp = load_dataset("euclaise/writingprompts", split="train")  # has prompt/story
print("WritingPrompts columns:", wp.column_names)

if "prompt" in wp.column_names:
    wp_pool = [x for x in wp["prompt"] if isinstance(x, str) and x.strip()]
elif "text" in wp.column_names:
    wp_pool = [x for x in wp["text"] if isinstance(x, str) and x.strip()]
else:
    raise ValueError("WritingPrompts: cannot find prompt/text column.")

random.shuffle(wp_pool)
id_train_texts = wp_pool[:N_ID_TRAIN]
id_test_texts  = wp_pool[N_ID_TRAIN:N_ID_TRAIN+N_ID_TEST]
print("ID train:", len(id_train_texts), "ID test:", len(id_test_texts))

# --- RTP as OOD ---
rtp = load_dataset("allenai/real-toxicity-prompts", split="train")
print("RTP keys:", rtp[0].keys())
print("RTP features:", rtp.features)

def get_rtp_prompt_text_and_tox(row):
    # top-level schema
    if "text" in row:
        return row.get("text", None), row.get("toxicity", None)
    # nested prompt dict schema (common)
    if "prompt" in row:
        p = row["prompt"]
        if isinstance(p, dict):
            return p.get("text", None), p.get("toxicity", row.get("toxicity", None))
        if isinstance(p, str):
            return p, row.get("toxicity", None)
    raise KeyError(f"Cannot find RTP prompt text in row keys={list(row.keys())}")

rtp_texts, rtp_tox = [], []
# oversample then filter
for i in random.sample(range(len(rtp)), min(len(rtp), N_OOD * 5)):
    txt, tox = get_rtp_prompt_text_and_tox(rtp[i])
    if isinstance(txt, str) and txt.strip() and tox is not None:
        rtp_texts.append(txt)
        rtp_tox.append(float(tox))
    if len(rtp_texts) >= N_OOD:
        break

rtp_tox = np.array(rtp_tox, dtype=np.float32)
print("OOD (RTP):", len(rtp_texts), "tox range:", float(rtp_tox.min()), float(rtp_tox.max()))
print("Example ID:", id_train_texts[0][:140])
print("Example OOD:", rtp_texts[0][:140], "tox:", rtp_tox[0])


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


README.md:   0%|          | 0.00/837 [00:00<?, ?B/s]

data/train-00000-of-00002-105e07cb0d1994(…):   0%|          | 0.00/272M [00:00<?, ?B/s]

data/train-00001-of-00002-4fdb982c110564(…):   0%|          | 0.00/272M [00:00<?, ?B/s]

data/test-00000-of-00001-16503b0c26ed00c(…):   0%|          | 0.00/30.0M [00:00<?, ?B/s]

data/validation-00000-of-00001-137b93e1e(…):   0%|          | 0.00/30.7M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/272600 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/15138 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/15620 [00:00<?, ? examples/s]

WritingPrompts columns: ['prompt', 'story']
ID train: 800 ID test: 800


README.md: 0.00B [00:00, ?B/s]

prompts.jsonl:   0%|          | 0.00/67.7M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/99442 [00:00<?, ? examples/s]

RTP keys: dict_keys(['filename', 'begin', 'end', 'challenging', 'prompt', 'continuation'])
RTP features: {'filename': Value('string'), 'begin': Value('int64'), 'end': Value('int64'), 'challenging': Value('bool'), 'prompt': {'text': Value('string'), 'toxicity': Value('float64'), 'profanity': Value('float64'), 'sexually_explicit': Value('float64'), 'flirtation': Value('float64'), 'identity_attack': Value('float64'), 'threat': Value('float64'), 'insult': Value('float64'), 'severe_toxicity': Value('float64')}, 'continuation': {'text': Value('string'), 'profanity': Value('float64'), 'sexually_explicit': Value('float64'), 'identity_attack': Value('float64'), 'flirtation': Value('float64'), 'threat': Value('float64'), 'insult': Value('float64'), 'severe_toxicity': Value('float64'), 'toxicity': Value('float64')}}
OOD (RTP): 2000 tox range: 0.005985789000988007 0.9868122935295105
Example ID: [ WP ] A man in Australia gets in a car wreck . A woman in Seattle wins the lottery . Connect these two 

In [None]:
# =========================================================
# 3) Scoring functions (training-agnostic OOD detectors)
#   - NLL
#   - LLR (NLL_qwen - NLL_bg)
#   - Energy: mean(-logsumexp(logits))
#   - Entropy: mean entropy of next-token distribution
#   - MaxProb: -mean(max prob)  (convert to "higher = more OOD")
# =========================================================
from torch.utils.data import DataLoader

@torch.no_grad()
def _batched_forward(model, tokenizer, texts, batch_size=8, max_len=256, need_hidden=False):
    loader = DataLoader(texts, batch_size=batch_size, shuffle=False)
    for batch in loader:
        enc = tokenizer(list(batch), return_tensors="pt", padding=True, truncation=True, max_length=max_len)
        input_ids = enc["input_ids"].to(model.device)
        attn = enc["attention_mask"].to(model.device)
        outputs = model(
            input_ids=input_ids,
            attention_mask=attn,
            output_hidden_states=need_hidden,
            use_cache=False
        )
        yield input_ids, attn, outputs

@torch.no_grad()
def score_nll(model, tokenizer, texts, batch_size=8, max_len=256):
    scores = []
    for input_ids, attn, outputs in _batched_forward(model, tokenizer, texts, batch_size, max_len, need_hidden=False):
        labels = input_ids.clone()
        labels[attn == 0] = -100
        logits = outputs.logits

        shift_logits = logits[:, :-1, :].contiguous()
        shift_labels = labels[:, 1:].contiguous()

        ce = F.cross_entropy(
            shift_logits.view(-1, shift_logits.size(-1)),
            shift_labels.view(-1),
            ignore_index=-100,
            reduction="none",
        ).view(shift_labels.size())

        mask = (shift_labels != -100)
        denom = mask.sum(dim=1).clamp_min(1)
        per_sample = (ce * mask).sum(dim=1) / denom
        scores.extend(per_sample.float().cpu().tolist())
    return np.array(scores, dtype=np.float32)

@torch.no_grad()
def score_energy(model, tokenizer, texts, batch_size=8, max_len=256):
    # energy_t = -logsumexp(logits_t)
    scores = []
    for input_ids, attn, outputs in _batched_forward(model, tokenizer, texts, batch_size, max_len, need_hidden=False):
        labels = input_ids.clone()
        labels[attn == 0] = -100
        logits = outputs.logits

        shift_logits = logits[:, :-1, :].contiguous()
        shift_labels = labels[:, 1:].contiguous()
        mask = (shift_labels != -100)

        energy_tok = -torch.logsumexp(shift_logits, dim=-1)  # [B,L-1]
        denom = mask.sum(dim=1).clamp_min(1)
        per_sample = (energy_tok * mask).sum(dim=1) / denom
        # higher energy -> more OOD-like
        scores.extend(per_sample.float().cpu().tolist())
    return np.array(scores, dtype=np.float32)

@torch.no_grad()
def score_entropy(model, tokenizer, texts, batch_size=8, max_len=256):
    scores = []
    for input_ids, attn, outputs in _batched_forward(model, tokenizer, texts, batch_size, max_len, need_hidden=False):
        labels = input_ids.clone()
        labels[attn == 0] = -100
        logits = outputs.logits

        shift_logits = logits[:, :-1, :].contiguous()
        shift_labels = labels[:, 1:].contiguous()
        mask = (shift_labels != -100)

        logp = torch.log_softmax(shift_logits, dim=-1)
        p = torch.exp(logp)
        ent = -(p * logp).sum(dim=-1)  # [B,L-1]
        denom = mask.sum(dim=1).clamp_min(1)
        per_sample = (ent * mask).sum(dim=1) / denom
        scores.extend(per_sample.float().cpu().tolist())
    return np.array(scores, dtype=np.float32)

@torch.no_grad()
def score_neg_maxprob(model, tokenizer, texts, batch_size=8, max_len=256):
    # -mean(max prob): higher => more OOD-like
    scores = []
    for input_ids, attn, outputs in _batched_forward(model, tokenizer, texts, batch_size, max_len, need_hidden=False):
        labels = input_ids.clone()
        labels[attn == 0] = -100
        logits = outputs.logits

        shift_logits = logits[:, :-1, :].contiguous()
        shift_labels = labels[:, 1:].contiguous()
        mask = (shift_labels != -100)

        p = torch.softmax(shift_logits, dim=-1)
        maxp = p.max(dim=-1).values  # [B,L-1]
        denom = mask.sum(dim=1).clamp_min(1)
        per_sample = -(maxp * mask).sum(dim=1) / denom
        scores.extend(per_sample.float().cpu().tolist())
    return np.array(scores, dtype=np.float32)

def score_llr(texts):
    nll_q = score_nll(qwen, tok, texts, batch_size=BATCH_SIZE, max_len=MAX_LEN)
    nll_b = score_nll(bg, bg_tok, texts, batch_size=BATCH_SIZE, max_len=MAX_LEN)
    return nll_q - nll_b


In [None]:
# =========================================================
# 4) Embedding-based detectors (Mahalanobis / kNN)
#   Representation: mean of last-layer hidden states over non-pad tokens
# =========================================================
@torch.no_grad()
def extract_reps(model, tokenizer, texts, batch_size=8, max_len=256):
    reps = []
    for input_ids, attn, outputs in _batched_forward(model, tokenizer, texts, batch_size, max_len, need_hidden=True):
        # last hidden: [B,L,H]
        h = outputs.hidden_states[-1]
        mask = attn.unsqueeze(-1).float()  # [B,L,1]
        denom = mask.sum(dim=1).clamp_min(1.0)
        rep = (h * mask).sum(dim=1) / denom  # [B,H]
        reps.append(rep.float().cpu())
    return torch.cat(reps, dim=0).numpy()

def fit_mahalanobis(id_reps, eps=1e-5):
    mu = id_reps.mean(axis=0, keepdims=True)  # [1,H]
    X = id_reps - mu
    # covariance + ridge
    cov = (X.T @ X) / max(1, (id_reps.shape[0] - 1))
    cov = cov + eps * np.eye(cov.shape[0])
    inv = np.linalg.inv(cov)
    return mu, inv

def score_mahalanobis(reps, mu, inv):
    X = reps - mu
    # sqrt( (x-mu)^T inv (x-mu) )
    d2 = np.einsum("bi,ij,bj->b", X, inv, X)
    return np.sqrt(np.maximum(d2, 0.0)).astype(np.float32)

def fit_knn(id_reps, k=10):
    nn = NearestNeighbors(n_neighbors=k, algorithm="auto").fit(id_reps)
    return nn

def score_knn(reps, nn):
    dists, _ = nn.kneighbors(reps, return_distance=True)
    return dists.mean(axis=1).astype(np.float32)


In [None]:
# =========================================================
# 5) Metrics: AUROC / AUPR / FPR@95TPR + OOD rate @ ID-quantile threshold
# =========================================================
def fpr_at_95_tpr(y_true, scores):
    fpr, tpr, _ = roc_curve(y_true, scores)
    idx = np.where(tpr >= 0.95)[0]
    if len(idx) == 0:
        return float("nan")
    return float(np.min(fpr[idx]))

def eval_detector(scores_id_test, scores_ood, scores_id_train_for_tau, tau_q=0.95):
    # labels: ID=0, OOD=1
    y = np.concatenate([np.zeros_like(scores_id_test), np.ones_like(scores_ood)])
    s = np.concatenate([scores_id_test, scores_ood])

    auroc = roc_auc_score(y, s)
    aupr  = average_precision_score(y, s)
    fpr95 = fpr_at_95_tpr(y, s)

    tau = float(np.quantile(scores_id_train_for_tau, tau_q))
    fpr_real = float((scores_id_test > tau).mean())
    ood_rate = float((scores_ood > tau).mean())

    return {
        "AUROC": auroc,
        "AUPR": aupr,
        "FPR@95TPR": fpr95,
        f"tau(ID q={tau_q})": tau,
        "Realized_FPR_on_IDtest": fpr_real,
        "OOD_rate_on_RTP": ood_rate
    }


In [None]:
# =========================================================
# 6) Significance tests
#   (A) Bootstrap CI for AUROC (and AUPR if you want)
#   (B) Permutation test for score distribution gap (IDtest vs OOD)
# =========================================================
def bootstrap_ci_metric(y, s, metric_fn, B=300, alpha=0.05):
    n = len(y)
    vals = []
    rng = np.random.default_rng(SEED)
    for _ in range(B):
        idx = rng.integers(0, n, n)
        vals.append(metric_fn(y[idx], s[idx]))
    vals = np.sort(np.array(vals, dtype=np.float64))
    lo = float(np.quantile(vals, alpha/2))
    hi = float(np.quantile(vals, 1-alpha/2))
    return lo, hi

def permutation_test_gap(id_scores, ood_scores, stat="mean", B=2000):
    # H0: ID and OOD come from same distribution (exchangeable)
    rng = np.random.default_rng(SEED)
    x = np.array(id_scores, dtype=np.float64)
    y = np.array(ood_scores, dtype=np.float64)

    if stat == "mean":
        obs = y.mean() - x.mean()
        z = np.concatenate([x, y])
        n = len(x)
        count = 0
        for _ in range(B):
            rng.shuffle(z)
            x2 = z[:n]
            y2 = z[n:]
            if (y2.mean() - x2.mean()) >= obs:
                count += 1
        p = (count + 1) / (B + 1)
        return float(obs), float(p)
    else:
        raise ValueError("Only stat='mean' implemented.")


In [None]:
# =========================================================
# 7) Run all detectors + benchmark comparison table
# =========================================================
results = []

# --- 7.1 Likelihood / uncertainty detectors ---
detectors = {
    "NLL": lambda xs: score_nll(qwen, tok, xs, batch_size=BATCH_SIZE, max_len=MAX_LEN),
    "LLR": lambda xs: score_llr(xs),
    "Energy": lambda xs: score_energy(qwen, tok, xs, batch_size=BATCH_SIZE, max_len=MAX_LEN),
    "Entropy": lambda xs: score_entropy(qwen, tok, xs, batch_size=BATCH_SIZE, max_len=MAX_LEN),
    "NegMaxProb": lambda xs: score_neg_maxprob(qwen, tok, xs, batch_size=BATCH_SIZE, max_len=MAX_LEN),
}

cache = {}

for name, fn in detectors.items():
    print(f"\n=== Detector: {name} ===")
    id_tr = fn(id_train_texts)
    id_te = fn(id_test_texts)
    ood   = fn(rtp_texts)

    cache[name] = (id_tr, id_te, ood)

    r = eval_detector(id_te, ood, id_tr, tau_q=0.95)
    r["Detector"] = name

    # significance: bootstrap AUROC CI
    y = np.concatenate([np.zeros_like(id_te), np.ones_like(ood)])
    s = np.concatenate([id_te, ood])
    lo, hi = bootstrap_ci_metric(y, s, lambda yy, ss: roc_auc_score(yy, ss), B=300)
    r["AUROC_CI95"] = f"[{lo:.4f}, {hi:.4f}]"

    # permutation test on mean score gap
    gap, p = permutation_test_gap(id_te, ood, stat="mean", B=2000)
    r["MeanGap(OOD-ID)"] = gap
    r["PermTest_p"] = p

    results.append(r)

# --- 7.2 Embedding-based detectors (Mahalanobis / kNN) ---
print("\n=== Embeddings for Mahalanobis / kNN ===")
id_train_rep = extract_reps(qwen, tok, id_train_texts, batch_size=BATCH_SIZE, max_len=MAX_LEN)
id_test_rep  = extract_reps(qwen, tok, id_test_texts,  batch_size=BATCH_SIZE, max_len=MAX_LEN)
ood_rep      = extract_reps(qwen, tok, rtp_texts,      batch_size=BATCH_SIZE, max_len=MAX_LEN)

# Mahalanobis
mu, inv = fit_mahalanobis(id_train_rep, eps=1e-4)
id_tr_m = score_mahalanobis(id_train_rep, mu, inv)
id_te_m = score_mahalanobis(id_test_rep,  mu, inv)
ood_m   = score_mahalanobis(ood_rep,      mu, inv)

r = eval_detector(id_te_m, ood_m, id_tr_m, tau_q=0.95)
r["Detector"] = "Mahalanobis(last_hidden_mean)"

y = np.concatenate([np.zeros_like(id_te_m), np.ones_like(ood_m)])
s = np.concatenate([id_te_m, ood_m])
lo, hi = bootstrap_ci_metric(y, s, lambda yy, ss: roc_auc_score(yy, ss), B=300)
r["AUROC_CI95"] = f"[{lo:.4f}, {hi:.4f}]"
gap, p = permutation_test_gap(id_te_m, ood_m, stat="mean", B=2000)
r["MeanGap(OOD-ID)"] = gap
r["PermTest_p"] = p
results.append(r)

# kNN
nn = fit_knn(id_train_rep, k=10)
id_tr_k = score_knn(id_train_rep, nn)
id_te_k = score_knn(id_test_rep,  nn)
ood_k   = score_knn(ood_rep,      nn)

r = eval_detector(id_te_k, ood_k, id_tr_k, tau_q=0.95)
r["Detector"] = "kNNdist(k=10,last_hidden_mean)"

y = np.concatenate([np.zeros_like(id_te_k), np.ones_like(ood_k)])
s = np.concatenate([id_te_k, ood_k])
lo, hi = bootstrap_ci_metric(y, s, lambda yy, ss: roc_auc_score(yy, ss), B=300)
r["AUROC_CI95"] = f"[{lo:.4f}, {hi:.4f}]"
gap, p = permutation_test_gap(id_te_k, ood_k, stat="mean", B=2000)
r["MeanGap(OOD-ID)"] = gap
r["PermTest_p"] = p
results.append(r)

df = pd.DataFrame(results).sort_values("AUROC", ascending=False)
df



=== Detector: NLL ===

=== Detector: LLR ===

=== Detector: Energy ===

=== Detector: Entropy ===

=== Detector: NegMaxProb ===

=== Embeddings for Mahalanobis / kNN ===


Unnamed: 0,AUROC,AUPR,FPR@95TPR,tau(ID q=0.95),Realized_FPR_on_IDtest,OOD_rate_on_RTP,Detector,AUROC_CI95,MeanGap(OOD-ID),PermTest_p
6,0.992205,0.996157,0.03125,33.81263,0.32125,1.0,"kNNdist(k=10,last_hidden_mean)","[0.9888, 0.9951]",14.652404,0.0005
5,0.988055,0.994573,0.0575,28.242908,0.9925,1.0,Mahalanobis(last_hidden_mean),"[0.9839, 0.9917]",868.835344,0.0005
1,0.707419,0.866736,0.9675,0.130485,0.07625,0.416,LLR,"[0.6891, 0.7260]",0.479065,0.0005
2,0.592153,0.807329,0.98,-24.015625,0.06125,0.2185,Energy,"[0.5715, 0.6139]",0.263988,0.0005
0,0.577696,0.757807,0.96875,6.113281,0.06375,0.0845,NLL,"[0.5521, 0.5988]",0.2229,0.0005
3,0.572162,0.732814,0.92,5.699804,0.07875,0.05,Entropy,"[0.5478, 0.5980]",0.126816,0.0005
4,0.514572,0.711735,0.97625,-0.175037,0.0875,0.0565,NegMaxProb,"[0.4916, 0.5377]",-0.001789,0.687156


In [None]:
# =========================================================
# 8) Optional: Toxicity-bin breakdown for the BEST detector
# =========================================================
best = df.iloc[0]["Detector"]
print("Best detector:", best)

# retrieve scores for best detector
if best in cache:
    id_tr, id_te, ood = cache[best]
    scores = ood
else:
    # embedding ones
    if best.startswith("Mahalanobis"):
        id_tr, id_te, scores = id_tr_m, id_te_m, ood_m
    else:
        id_tr, id_te, scores = id_tr_k, id_te_k, ood_k

tau = float(np.quantile(id_tr, 0.95))
is_ood = scores > tau

bins = [0.0, 0.25, 0.5, 0.75, 1.0000001]
labels = ["[0,0.25)", "[0.25,0.5)", "[0.5,0.75)", "[0.75,1]"]

m = min(len(rtp_tox), len(is_ood))
tox = rtp_tox[:m]
ood_flag = is_ood[:m]
scores_m = scores[:m]

print(f"tau(ID 95%): {tau:.4f} | overall OOD_rate: {ood_flag.mean():.4f}")

for i in range(4):
    lo, hi = bins[i], bins[i+1]
    idx = (tox >= lo) & (tox < hi)
    if idx.sum() == 0:
        continue
    print(f"{labels[i]} count={int(idx.sum())} OOD_rate={ood_flag[idx].mean():.4f} mean_score={scores_m[idx].mean():.4f}")


Best detector: kNNdist(k=10,last_hidden_mean)
tau(ID 95%): 33.8126 | overall OOD_rate: 1.0000
[0,0.25) count=1229 OOD_rate=1.0000 mean_score=47.7064
[0.25,0.5) count=316 OOD_rate=1.0000 mean_score=47.4213
[0.5,0.75) count=238 OOD_rate=1.0000 mean_score=46.6864
[0.75,1] count=217 OOD_rate=1.0000 mean_score=45.4820
