In [None]:
# bp_llm_ultrafeedback_eval.py
# Evaluate BP-LLM (unary BP classifier with JJ bound) win rate on UltraFeedback
# with Llama-3.2-3B policy priors.

import math
import dataclasses
from dataclasses import dataclass
from typing import List, Tuple, Dict, Optional

import torch
from torch.nn.functional import log_softmax
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM

# -----------------------------
# Config
# -----------------------------

HF_TOKEN = " "  # add your HF Read token


In [None]:
# BCO with Llama-3.1-8B-Instruct (HelpSteer2)

# bp_llm_helpsteer2_eval.py
# HelpSteer2 BCO baseline (leakage-safe)
#
# What this does:
#   • Uses HelpSteer2 in either of two modes:
#       (A) preference mode (data_dir="preference"): map (response_1, response_2, preference_strength) -> (chosen, rejected)
#       (B) ratings mode (default split): build top-vs-bottom pairs by helpfulness, with a min gap.
#   • Splits into TRAIN/TEST (tune β and δ on TRAIN; report on TEST).
#   • δ estimated by BCO on TRAIN: δ = 0.5 ( E_pos[β s] + E_neg[β s] ).
#   • Scoring is chat-template aware and length-normalized (mean log-prob),
#     limited to SCORE_MAX_GEN_TOKENS response tokens.
#   • Any pair with a truncated side is dropped (don’t bias as zeros).
#   • TEST evaluation is fully label-free (BCO doesn’t need labels).
#
# Notes:
#   • Optionally set a reference model (REF_MODEL_NAME) to use log π − log π_ref.
#   • Lower MIN_GAP to admit more pairs if you see too few examples (ratings mode).
#   • This file is BCO-only for apples-to-apples comparison with your BP-LLM script.

import os
import math
import random
from dataclasses import dataclass
from typing import List, Dict, Optional, Tuple, Iterable, DefaultDict
from collections import defaultdict

import torch
from torch.nn.functional import log_softmax
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM

# =============================================================================
# Config
# =============================================================================
HF_TOKEN = os.environ.get("HF_TOKEN")  # set if using gated models

# Use an *Instruct* model for the policy; optional Base as reference
MODEL_NAME      = "meta-llama/Llama-3.2-3B-Instruct"
REF_MODEL_NAME  = "meta-llama/Llama-3.2-3B"  # e.g., "meta-llama/Llama-3.1-8B" to enable reference subtraction

# ---- Dataset: HelpSteer2 ----
# Mode A (preferred): USE_PREFERENCE=True loads pairwise data from data_dir="preference"
# Mode B: USE_PREFERENCE=False builds pairs from ratings by helpfulness gap
DATASET             = "nvidia/HelpSteer2"
DATASET_CONFIG      = None           # keep None
USE_PREFERENCE      = True           # set False to use ratings mode
PREFERENCE_SPLIT    = "train"        # HS2 preference usually ships as one 'train' file
RATINGS_SPLIT       = "train[:20%]"  # iterate smaller first; use "train" later

# Preference-mode knobs
PREF_MIN_STRENGTH   = 1              # skip weak prefs; keep abs(strength) >= this
PREF_KEEP_SPLIT     = None           # or "train" / "validation" if you want to filter by the file's 'split' column

# Inference/scoring limits
MAX_INPUT_TOKENS     = 1024
MAX_GEN_TOKENS       = 512
SCORE_MAX_GEN_TOKENS = 10000
BATCH_SIZE           = 4

# Chat-templating + scoring
FORCE_CHAT_TEMPLATE  = True
SCORING_MODE         = "mean"      # {"mean", "sum", "lp_alpha"}
LENGTH_PENALTY_ALPHA = 0.0         # only for SCORING_MODE == "lp_alpha"

# Pair filtering (applies to ratings-mode only)
MIN_GAP              = 1.0         # minimum helpfulness-score gap (0–4 scale) to keep a pair

# Train / test split
RNG_SEED   = 42
TRAIN_FRAC = 0.8

# Grid tuning (TRAIN only)
DO_TUNE = True
BETAS   = [0.5, 1.0, 2.0]
DELTAS  = ["bco", 0.0]  # try BCO shift and a constant zero shift

# Device / dtype
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE  = torch.bfloat16 if torch.cuda.is_available() else torch.float32
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.set_grad_enabled(False)


# =============================================================================
# HelpSteer2 helpers
# =============================================================================
def _safe_str(x) -> Optional[str]:
    return x if isinstance(x, str) and x.strip() else None

def _to_float(x) -> Optional[float]:
    try:
        return float(x)
    except Exception:
        return None

def _get_int(value) -> Optional[int]:
    """Best-effort parse to int (accepts floats/strings); returns None if not possible."""
    try:
        if value is None:
            return None
        if isinstance(value, bool):
            return int(value)
        if isinstance(value, (int,)):
            return int(value)
        if isinstance(value, float):
            return int(round(value))
        # string
        s = str(value).strip()
        if s == "":
            return None
        return int(round(float(s)))
    except Exception:
        return None

def load_hs2_preference_pairs(
    min_strength: int = PREF_MIN_STRENGTH,
    keep_split: Optional[str] = PREF_KEEP_SPLIT
) -> List[Dict]:
    """
    Build (prompt, chosen, rejected) from HelpSteer2-preference by mapping
    response_1/response_2 + signed strength -> chosen/rejected.

    Positive strength => response_2 preferred
    Negative strength => response_1 preferred
    Zero or |strength| < min_strength => skip
    """
    ds = load_dataset(DATASET, data_dir="preference", split=PREFERENCE_SPLIT)

    out: List[Dict] = []
    for r in ds:
        sp = r.get("split")
        if keep_split is not None and sp != keep_split:
            continue

        p  = _safe_str(r.get("prompt") or r.get("instruction"))
        r1 = _safe_str(
            r.get("response_1")
            or r.get("response1")
            or r.get("resp1")
            or r.get("candidate_1")
            or r.get("output_1")
        )
        r2 = _safe_str(
            r.get("response_2")
            or r.get("response2")
            or r.get("resp2")
            or r.get("candidate_2")
            or r.get("output_2")
        )

        # robustly retrieve the signed strength
        s = (
            r.get("preference_strength")
            or r.get("preference")
            or r.get("label")
            or r.get("preference_score")
        )
        s = _get_int(s)

        if not (p and r1 and r2 and s is not None):
            continue
        if abs(s) < int(min_strength) or r1 == r2:
            continue

        if s > 0:
            chosen, rejected = r2, r1
        else:  # s < 0
            chosen, rejected = r1, r2

        out.append({"prompt": p, "chosen": chosen, "rejected": rejected})

    return out

def load_hs2_ratings_pairs(min_gap: float = MIN_GAP, split: str = RATINGS_SPLIT) -> List[Dict]:
    """
    Build (prompt, chosen, rejected) from the ratings split by taking the higher-vs-lower
    'helpfulness' response for each prompt, requiring a minimum rating gap.
    Typical fields: prompt, response, helpfulness (0..4), etc.
    """
    ds = load_dataset(DATASET, split=split) if DATASET_CONFIG is None else load_dataset(DATASET, name=DATASET_CONFIG, split=split)
    # Group up to 2 responses per prompt
    buckets: DefaultDict[str, List[Tuple[str, float]]] = defaultdict(list)
    for r in ds:
        p = _safe_str(r.get("prompt"))
        resp = _safe_str(r.get("response"))
        help_ = _to_float(r.get("helpfulness"))
        if p and resp and (help_ is not None):
            if len(buckets[p]) < 2:  # dataset typically has 2 responses/prompt
                buckets[p].append((resp, float(help_)))

    # Form pairs: top vs bottom by helpfulness, enforce min gap
    pairs: List[Dict] = []
    for p, lst in buckets.items():
        if len(lst) < 2:
            continue
        lst.sort(key=lambda t: t[1])  # low ... high
        lo_txt, lo_s = lst[0]
        hi_txt, hi_s = lst[-1]
        if (hi_s - lo_s) >= min_gap and (hi_txt != lo_txt):
            pairs.append({"prompt": p, "chosen": hi_txt, "rejected": lo_txt})
    return pairs


# =============================================================================
# Model
# =============================================================================
def load_causal_lm(model_id: str, token: Optional[str], dtype=DTYPE, device=DEVICE):
    tok = AutoTokenizer.from_pretrained(model_id, token=token, use_fast=True)
    if tok.pad_token_id is None:
        tok.pad_token = tok.eos_token
    model = AutoModelForCausalLM.from_pretrained(model_id, token=token, torch_dtype=dtype)
    model = model.to(device)
    model.eval()
    return tok, model


# =============================================================================
# Encoding + scoring
# =============================================================================
def _apply_chat_prefix(tokenizer, prompt: str) -> Optional[str]:
    if not FORCE_CHAT_TEMPLATE or not hasattr(tokenizer, "apply_chat_template"):
        return None
    try:
        msgs = [
            {"role": "user", "content": prompt.strip()},
            {"role": "assistant", "content": ""},
        ]
        return tokenizer.apply_chat_template(
            msgs, tokenize=False, add_generation_prompt=True
        )
    except Exception:
        return None

def encode_pair(tokenizer, prompt: str, response: str):
    """
    Returns (input_ids, attention_mask, prompt_len) for the concatenated prompt+response.
    Uses chat template when available.
    """
    chat_prefix = _apply_chat_prefix(tokenizer, prompt)
    if chat_prefix is not None:
        prompt_text = chat_prefix
        full_text   = chat_prefix + response.strip()
    else:
        prompt_text = prompt.strip()
        full_text   = prompt_text + ("\n" if not prompt_text.endswith("\n") else "") + response.strip()

    toks_full = tokenizer(
        full_text,
        truncation=True,
        max_length=min(MAX_INPUT_TOKENS + MAX_GEN_TOKENS, tokenizer.model_max_length),
        return_tensors="pt",
        add_special_tokens=True,
    )
    toks_prompt = tokenizer(
        prompt_text,
        truncation=True,
        max_length=MAX_INPUT_TOKENS,
        return_tensors="pt",
        add_special_tokens=True,
    )
    prompt_len = toks_prompt["input_ids"].shape[-1]
    return toks_full, prompt_len

@torch.no_grad()
def sequence_logprob_stats(
    model,
    tokenizer,
    prompts: List[str],
    responses: List[str],
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Returns two tensors of shape [B]: (lp_sum, tok_count)
    – lp_sum is the sum of log-probs over scored response tokens
    – tok_count is the number of response tokens scored
    If a response is fully truncated (no tokens past the prompt), tok_count=0
    """
    sums, counts = [], []
    for i in range(0, len(prompts), BATCH_SIZE):
        p_batch = prompts[i:i+BATCH_SIZE]
        r_batch = responses[i:i+BATCH_SIZE]

        batch_inputs, batch_prompt_lens = [], []
        for p, r in zip(p_batch, r_batch):
            toks, p_len = encode_pair(tokenizer, p, r)
            batch_inputs.append(toks)
            batch_prompt_lens.append(p_len)

        pad_id = tokenizer.pad_token_id or 0
        input_ids = torch.nn.utils.rnn.pad_sequence(
            [bi["input_ids"].squeeze(0) for bi in batch_inputs],
            batch_first=True, padding_value=pad_id
        )
        attention_mask = torch.nn.utils.rnn.pad_sequence(
            [bi["attention_mask"].squeeze(0) for bi in batch_inputs],
            batch_first=True, padding_value=0
        )
        input_ids = input_ids.to(DEVICE)
        attention_mask = attention_mask.to(DEVICE)

        logits = model(input_ids=input_ids, attention_mask=attention_mask).logits
        logprobs = log_softmax(logits.to(torch.float32), dim=-1)

        for b in range(input_ids.size(0)):
            p_len = batch_prompt_lens[b]
            ids   = input_ids[b]
            masks = attention_mask[b]
            L = int(masks.sum().item())
            if p_len >= L:
                sums.append(torch.tensor(0.0)); counts.append(torch.tensor(0)); continue
            end  = min(L, p_len + SCORE_MAX_GEN_TOKENS)
            targ = ids[p_len:end]
            pred = logprobs[b, p_len-1:end-1]
            lp_sum = pred.gather(-1, targ.unsqueeze(-1)).squeeze(-1).sum()
            sums.append(lp_sum.cpu())
            counts.append(torch.tensor(targ.numel()))
    return torch.stack(sums, dim=0), torch.stack(counts, dim=0)

def _score_from_stats(lp_sum: torch.Tensor, tok_count: torch.Tensor) -> torch.Tensor:
    if SCORING_MODE == "mean":
        denom = torch.clamp(tok_count.to(torch.float32), min=1.0)
        return lp_sum.to(torch.float32) / denom
    elif SCORING_MODE == "lp_alpha":
        return lp_sum.to(torch.float32) + (LENGTH_PENALTY_ALPHA * tok_count.to(torch.float32))
    else:  # "sum"
        return lp_sum.to(torch.float32)


# =============================================================================
# BCO pieces
# =============================================================================
@dataclass
class BCOParams:
    beta: float  = 1.0
    delta: float = 0.0

@torch.no_grad()
def cache_policy_and_ref_scores(records: List[Dict], model, ref_model, tokenizer) -> Tuple[torch.Tensor, int]:
    """
    Build base scores s = (policy - ref) for valid pairs only.
    A pair is valid if BOTH sides have >= 1 scored response token
    (i.e., not truncated past the prompt) for BOTH models if ref is used.
    Returns:
        base_s: tensor of shape [2N_valid] (first N are chosen, next N are rejected)
        N_valid: number of valid pairs.
    """
    prompts   = [r["prompt"] for r in records for _ in (0, 1)]
    responses = [r["chosen"] for r in records] + [r["rejected"] for r in records]

    pol_sum, pol_cnt = sequence_logprob_stats(model, tokenizer, prompts, responses)
    pol_score = _score_from_stats(pol_sum, pol_cnt)

    if ref_model is None:
        ref_score = torch.zeros_like(pol_score)
        ref_cnt   = torch.ones_like(pol_cnt)  # pretend "valid"
    else:
        ref_sum, ref_cnt = sequence_logprob_stats(ref_model, tokenizer, prompts, responses)
        ref_score = _score_from_stats(ref_sum, ref_cnt)

    valid_mask = (pol_cnt > 0) & (ref_cnt > 0)
    # Keep only pairs where BOTH chosen and rejected are valid
    valid_pairs = []
    for i in range(len(records)):
        i_ch = i
        i_rj = i + len(records)
        if valid_mask[i_ch].item() and valid_mask[i_rj].item():
            valid_pairs.append(i)

    if len(valid_pairs) == 0:
        raise RuntimeError("No valid (non-truncated) pairs after scoring. Increase limits or switch to preference mode.")

    chosen_vals, rejected_vals = [], []
    for i in valid_pairs:
        i_ch = i
        i_rj = i + len(records)
        chosen_vals.append(float(pol_score[i_ch] - ref_score[i_ch]))
        rejected_vals.append(float(pol_score[i_rj] - ref_score[i_rj]))

    base_s = torch.tensor(chosen_vals + rejected_vals, dtype=torch.float32)
    N = len(valid_pairs)
    return base_s, N

def estimate_delta_bco(base_s: torch.Tensor, beta: float, N: int) -> float:
    """δ_BCO = 0.5 ( E_pos[β s] + E_neg[β s] )."""
    r_pos = beta * base_s[:N]
    r_neg = beta * base_s[N:]
    return 0.5 * (float(r_pos.mean()) + float(r_neg.mean()))

def bco_win_rate(base_s: torch.Tensor, N: int, params: BCOParams) -> float:
    """
    BCO decision rule (label-free): predict chosen if β s_pos - δ > β s_neg - δ  ⇔  s_pos > s_neg.
    δ cancels out for pairwise comparison, but we still compute it for completeness/reporting.
    """
    mu = (params.beta * base_s) - params.delta
    mu = mu.tolist()
    correct = 0
    for i in range(N):
        if mu[i] > mu[i + N]:
            correct += 1
    return 100.0 * correct / N

def binom_ci_95(pct: float, N: int) -> Tuple[float, float]:
    p = pct / 100.0
    se = math.sqrt(p * (1 - p) / max(N, 1))
    lo = max(0.0, 100.0 * (p - 1.96 * se))
    hi = min(100.0, 100.0 * (p + 1.96 * se))
    return lo, hi

def grid_search_train_bco(
    records: List[Dict], model, ref_model, tokenizer,
    betas: Iterable[float], deltas: Iterable[Optional[str]]
):
    """
    Tune β and δ on TRAIN only. δ can be 'bco' (estimated) or a numeric constant.
    """
    base_s, N = cache_policy_and_ref_scores(records, model, ref_model, tokenizer)

    tried = []
    best = (-1.0, None)  # (wr, params)
    for beta in betas:
        for delta_spec in deltas:
            if delta_spec == "bco":
                delta = estimate_delta_bco(base_s, beta, N)
            else:
                delta = float(delta_spec)
            params = BCOParams(beta=beta, delta=delta)
            wr = bco_win_rate(base_s, N, params)
            lo, hi = binom_ci_95(wr, N)
            tried.append((wr, lo, hi, params))
            if wr > best[0]:
                best = (wr, params)

            print(f"[TUNE/TRAIN BCO] WR={wr:.2f}% [{lo:.1f}, {hi:.1f}] "
                  f"| beta={beta} delta={'bco' if delta_spec=='bco' else f'{delta:.4f}'}")

    best_wr, best_params = best
    lo, hi = binom_ci_95(best_wr, N)
    print(f"\n[BCO Tuning (TRAIN)] Best WR={best_wr:.2f}% [{lo:.1f}, {hi:.1f}] with "
          f"beta={best_params.beta} delta={best_params.delta:.4f}")
    return tried, best_params


# =============================================================================
# Main
# =============================================================================
def main():
    print(f"Device: {DEVICE} | Dtype: {DTYPE}")
    print("Loading policy model ...")
    tok, policy = load_causal_lm(MODEL_NAME, HF_TOKEN)
    ref = None
    if REF_MODEL_NAME:
        print("Loading reference model ...")
        _, ref = load_causal_lm(REF_MODEL_NAME, HF_TOKEN)

    # ---- Load pairs from HelpSteer2
    if USE_PREFERENCE:
        print("Loading HelpSteer2 (preference pairs) ...")
        adapted = load_hs2_preference_pairs(min_strength=PREF_MIN_STRENGTH, keep_split=PREF_KEEP_SPLIT)
        print(f"Preference pairs: {len(adapted)}")
    else:
        print(f"Loading HelpSteer2 (ratings split: {RATINGS_SPLIT}) and building pairs by helpfulness gap ≥ {MIN_GAP} ...")
        adapted = load_hs2_ratings_pairs(min_gap=MIN_GAP, split=RATINGS_SPLIT)
        print(f"Built rating-derived pairs: {len(adapted)}")

    if not adapted:
        raise RuntimeError("No valid examples. If using ratings mode, lower MIN_GAP or use USE_PREFERENCE=True.")

    # Optional peek
    try:
        ex = adapted[0]
        print("Example pair:",
              {k: (ex[k][:120] + "…") if isinstance(ex.get(k), str) else ex.get(k)
               for k in ("prompt", "chosen", "rejected")})
    except Exception:
        pass

    # Train/Test split
    random.Random(RNG_SEED).shuffle(adapted)
    n_all   = len(adapted)
    n_train = max(1, int(TRAIN_FRAC * n_all))
    train_recs = adapted[:n_train]
    test_recs  = adapted[n_train:]
    print(f"Split sizes: TRAIN={len(train_recs)} | TEST={len(test_recs)}")

    # Cache TRAIN scores and baseline WR with default β=1, δ=0
    print("\nScoring TRAIN...")
    base_s_tr, N_tr = cache_policy_and_ref_scores(train_recs, policy, ref, tok)
    wr_tr_base = bco_win_rate(base_s_tr, N_tr, BCOParams(beta=1.0, delta=0.0))
    lo_trb, hi_trb = binom_ci_95(wr_tr_base, N_tr)
    print(f"[TRAIN BCO baseline] WR={wr_tr_base:.2f}% [{lo_trb:.1f}, {hi_trb:.1f}] on {N_tr} valid pairs")

    # Tune β, δ (TRAIN only)
    if DO_TUNE:
        print("\nTuning on TRAIN (BCO)...")
        _tries, best_params = grid_search_train_bco(
            train_recs, policy, ref, tok, betas=BETAS, deltas=DELTAS
        )
    else:
        # Optionally estimate δ via BCO at default β
        best_params = BCOParams(beta=1.0, delta=estimate_delta_bco(base_s_tr, 1.0, N_tr))

    # TEST evaluation (label-free)
    eval_set = test_recs if len(test_recs) > 0 else train_recs
    if len(test_recs) == 0:
        print("\nWARNING: TEST set is empty (small dataset or large TRAIN_FRAC). Using TRAIN as a proxy sanity check.")

    print("\nScoring TEST...")
    base_s_te, N_te = cache_policy_and_ref_scores(eval_set, policy, ref, tok)
    wr_te = bco_win_rate(base_s_te, N_te, best_params)
    lo_te, hi_te = binom_ci_95(wr_te, N_te)
    print(f"[BCO EVAL (TEST, no labels)] WR={wr_te:.2f}% [{lo_te:.1f}, {hi_te:.1f}] "
          f"| beta={best_params.beta} delta={best_params.delta:.4f}")

if __name__ == "__main__":
    main()


Device: cuda | Dtype: torch.bfloat16
Loading policy model ...


tokenizer_config.json:   0%|          | 0.00/54.5k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/9.09M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/296 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/878 [00:00<?, ?B/s]

`torch_dtype` is deprecated! Use `dtype` instead!


model.safetensors.index.json:   0%|          | 0.00/20.9k [00:00<?, ?B/s]

Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/4.97G [00:00<?, ?B/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/1.46G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/189 [00:00<?, ?B/s]

Loading reference model ...


tokenizer_config.json:   0%|          | 0.00/50.5k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/9.09M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/301 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/844 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/20.9k [00:00<?, ?B/s]

Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/1.46G [00:00<?, ?B/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/4.97G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/185 [00:00<?, ?B/s]

Loading HelpSteer2 (preference pairs) ...


README.md: 0.00B [00:00, ?B/s]

preference/preference.jsonl.gz:   0%|          | 0.00/15.3M [00:00<?, ?B/s]

Generating train split: 0 examples [00:00, ? examples/s]

Preference pairs: 7117
Example pair: {'prompt': 'Define Signal Discuss its various properties with the help of diagram…', 'chosen': 'A signal is a message that is conveyed from a sender to a receiver through a communication channel. The message can be i…', 'rejected': 'A signal is a form of energy that is used to transmit information from one place to another. It can be in the form of so…'}
Split sizes: TRAIN=5693 | TEST=1424

Scoring TRAIN...
[TRAIN BCO baseline] WR=56.91% [55.6, 58.2] on 5693 valid pairs

Tuning on TRAIN (BCO)...
[TUNE/TRAIN BCO] WR=56.91% [55.6, 58.2] | beta=0.5 delta=bco
[TUNE/TRAIN BCO] WR=56.91% [55.6, 58.2] | beta=0.5 delta=0.0000
[TUNE/TRAIN BCO] WR=56.91% [55.6, 58.2] | beta=1.0 delta=bco
[TUNE/TRAIN BCO] WR=56.91% [55.6, 58.2] | beta=1.0 delta=0.0000
[TUNE/TRAIN BCO] WR=56.91% [55.6, 58.2] | beta=2.0 delta=bco
[TUNE/TRAIN BCO] WR=56.91% [55.6, 58.2] | beta=2.0 delta=0.0000

[BCO Tuning (TRAIN)] Best WR=56.91% [55.6, 58.2] with beta=0.5 delta=-0.1338

Scoring 

In [None]:
# BCO/BP-LLM with Llama-3.x (HelpSteer2)

# bp_llm_helpsteer2_eval.py
# Evaluate BP-LLM (unary JJ) win rate on nvidia/HelpSteer2.
#
# What this does:
#   • Uses HelpSteer2 in either of two modes:
#       (A) preference mode (data_dir="preference"): map (response_1, response_2, preference_strength) -> (chosen, rejected)
#       (B) ratings mode (default split): build top-vs-bottom pairs by helpfulness, with a min gap.
#   • Splits into TRAIN/TEST (tune β,δ,τ,γ, JJ on TRAIN; report on TEST with NO labels).
#   • δ estimated by BCO on TRAIN: δ = 0.5 ( E_pos[β s] + E_neg[β s] ).
#   • Scoring is chat-template aware and length-normalized (mean log-prob),
#     limited to SCORE_MAX_GEN_TOKENS response tokens.
#   • Any pair with a truncated side is dropped (don’t bias as zeros).
#
# Notes:
#   • Optionally set a reference model (REF_MODEL_NAME) to use log π − log π_ref.
#   • Lower MIN_GAP to admit more pairs if you see too few examples (ratings mode).
#   • Increase PREF_MIN_STRENGTH to ignore very weak preferences (preference mode).

import os
import math
import itertools
import random
from dataclasses import dataclass
from typing import List, Dict, Optional, Tuple, Iterable, DefaultDict
from collections import defaultdict

import torch
from torch.nn.functional import log_softmax
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM

# =========================
# Config
# =========================
# Put your HF token if you use gated models
HF_TOKEN = os.environ.get("HF_TOKEN")

# Suggested: Instruct as policy, Base as reference
MODEL_NAME      = "meta-llama/Llama-3.2-3B-Instruct"
REF_MODEL_NAME  = "meta-llama/Llama-3.2-3B"   # None to disable reference subtraction

# ---- Dataset: HelpSteer2 ----
DATASET             = "nvidia/HelpSteer2"
USE_PREFERENCE      = True            # True -> use built-in pairwise preference file
PREFERENCE_SPLIT    = "train"         # HF exposes a single file; it includes its own 'split' column
PREF_MIN_STRENGTH   = 1               # keep pairs with |preference_strength| >= this
PREF_KEEP_SPLIT     = None            # or "train" / "validation" to filter within the preference file

# Ratings mode fallback (set USE_PREFERENCE=False to use this)
RATINGS_SPLIT       = "train[:20%]"   # iterate small first; use "train" later
MIN_GAP             = 1.0             # helpfulness gap (0-4 scale) for ratings mode

# Train/Test split from the adapted pairs (stratification not required here)
TRAIN_FRAC          = 0.8
SEED                = 42

# Input limits
MAX_INPUT_TOKENS    = 1024
MAX_GEN_TOKENS      = 512
SCORE_MAX_GEN_TOKENS= 10000
BATCH_SIZE          = 4

# Chat template toggle (recommended for *Instruct*)
USE_CHAT_TEMPLATE   = True

# Length normalization: "mean" or "sum"
LENGTH_NORM         = "mean"

# BP hyperparams (defaults)
BETA            = 1.0
DELTA           = 0.0
TAU             = 1.0
GAMMA           = 1.0
JJ_INNER_STEPS  = 2

# Grid tuning (TRAIN)
DO_TUNE         = True
BETAS           = [0.5, 1.0, 2.0]
DELTAS          = ['bco', 0.0]        # try BCO shift and zero shift
TAUS            = [0.5, 1.0, 2.0]
GAMMAS          = [0.5, 1.0, 1.5, 2.0]
JJ_STEPS_LIST   = [1, 2, 3, 0]        # 0 => adaptive (tolerance-based)

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE  = torch.bfloat16 if torch.cuda.is_available() else torch.float32
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True


# =========================
# HelpSteer2 pairing helpers
# =========================
def _safe_str(x) -> Optional[str]:
    return x if isinstance(x, str) and x.strip() else None

def _to_float(x) -> Optional[float]:
    try:
        return float(x)
    except Exception:
        return None

def _to_int(x) -> Optional[int]:
    try:
        if x is None: return None
        if isinstance(x, bool): return int(x)
        if isinstance(x, int):  return x
        if isinstance(x, float): return int(round(x))
        s = str(x).strip()
        if not s: return None
        return int(round(float(s)))
    except Exception:
        return None

def load_hs2_preference_pairs(
    min_strength: int = PREF_MIN_STRENGTH,
    keep_split: Optional[str] = PREF_KEEP_SPLIT
) -> List[Dict]:
    """
    Build (prompt, chosen, rejected) from HelpSteer2 'preference' by mapping:
      response_1/response_2 + signed preference_strength -> chosen/rejected
    Positive strength => response_2 preferred; negative => response_1.
    """
    ds = load_dataset(DATASET, data_dir="preference", split=PREFERENCE_SPLIT)
    out: List[Dict] = []
    for r in ds:
        sp = r.get("split")
        if keep_split is not None and sp != keep_split:
            continue

        p  = _safe_str(r.get("prompt") or r.get("instruction"))
        r1 = _safe_str(r.get("response_1") or r.get("response1") or r.get("candidate_1") or r.get("output_1"))
        r2 = _safe_str(r.get("response_2") or r.get("response2") or r.get("candidate_2") or r.get("output_2"))
        s  = _to_int(r.get("preference_strength") or r.get("preference") or r.get("label") or r.get("preference_score"))

        if not (p and r1 and r2 and s is not None):  # require all fields
            continue
        if abs(s) < int(min_strength) or r1 == r2:   # skip weak prefs and degenerate pairs
            continue

        if s > 0:
            chosen, rejected = r2, r1
        else:  # s < 0
            chosen, rejected = r1, r2

        out.append({"prompt": p, "chosen": chosen, "rejected": rejected})
    return out

def load_hs2_ratings_pairs(min_gap: float = MIN_GAP, split: str = RATINGS_SPLIT) -> List[Dict]:
    """
    Build (prompt, chosen, rejected) from the ratings split by taking
    higher-vs-lower 'helpfulness' responses per prompt, requiring a min gap.
    Typical fields: prompt, response, helpfulness (0..4), etc.
    """
    ds = load_dataset(DATASET, split=split)
    buckets: DefaultDict[str, List[Tuple[str, float]]] = defaultdict(list)
    for r in ds:
        p    = _safe_str(r.get("prompt"))
        resp = _safe_str(r.get("response"))
        help_ = _to_float(r.get("helpfulness"))
        if p and resp and (help_ is not None):
            if len(buckets[p]) < 2:  # HS2 usually has exactly 2 responses per prompt
                buckets[p].append((resp, float(help_)))

    pairs: List[Dict] = []
    for p, lst in buckets.items():
        if len(lst) < 2:
            continue
        lst.sort(key=lambda t: t[1])   # low ... high
        lo_txt, lo_s = lst[0]
        hi_txt, hi_s = lst[-1]
        if (hi_s - lo_s) >= min_gap and (hi_txt != lo_txt):
            pairs.append({"prompt": p, "chosen": hi_txt, "rejected": lo_txt})
    return pairs


# =========================
# Tokenization / Scoring
# =========================
def format_prompt_with_chat_template(tokenizer, prompt: str) -> str:
    if USE_CHAT_TEMPLATE and hasattr(tokenizer, "apply_chat_template"):
        messages = [{"role": "user", "content": prompt}]
        try:
            return tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True  # leaves assistant prefix for continuation
            )
        except Exception:
            pass
    return prompt

def concat_prompt_response_text(tokenizer, prompt: str, response: str) -> Tuple[str, str]:
    """Return (full_text, prompt_only_text) for token counting."""
    prompt_text = format_prompt_with_chat_template(tokenizer, prompt).strip()
    full_text = prompt_text + ("\n" if not prompt_text.endswith("\n") else "") + response.strip()
    return full_text, prompt_text

@torch.no_grad()
def sequence_logprob_list(
    model,
    tokenizer,
    prompts: List[str],
    responses: List[str],
    length_norm: str = LENGTH_NORM,
) -> List[Optional[float]]:
    """
    Returns mean/sum log-prob for each (prompt, response). If response was truncated
    (i.e., prompt consumed the full sequence), returns None for that item.
    """
    vals: List[Optional[float]] = []
    for i in range(0, len(prompts), BATCH_SIZE):
        batch_prompts = prompts[i:i+BATCH_SIZE]
        batch_resps   = responses[i:i+BATCH_SIZE]

        batch_inputs, batch_prompt_lens = [], []
        for p, r in zip(batch_prompts, batch_resps):
            full_text, prompt_text = concat_prompt_response_text(tokenizer, p, r)
            toks = tokenizer(
                full_text,
                truncation=True,
                max_length=min(MAX_INPUT_TOKENS + MAX_GEN_TOKENS, tokenizer.model_max_length),
                return_tensors="pt",
                add_special_tokens=True,
            )
            tok_prompt = tokenizer(
                prompt_text,
                truncation=True,
                max_length=MAX_INPUT_TOKENS,
                return_tensors="pt",
                add_special_tokens=True,
            )
            p_len = tok_prompt["input_ids"].shape[-1]
            batch_inputs.append(toks)
            batch_prompt_lens.append(p_len)

        pad_id = tokenizer.pad_token_id or 0
        input_ids = torch.nn.utils.rnn.pad_sequence(
            [bi["input_ids"].squeeze(0) for bi in batch_inputs],
            batch_first=True, padding_value=pad_id
        )
        attention_mask = torch.nn.utils.rnn.pad_sequence(
            [bi["attention_mask"].squeeze(0) for bi in batch_inputs],
            batch_first=True, padding_value=0
        )
        input_ids = input_ids.to(DEVICE)
        attention_mask = attention_mask.to(DEVICE)

        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        logits = outputs.logits
        logprobs = log_softmax(logits.to(torch.float32), dim=-1)

        for b in range(input_ids.size(0)):
            p_len = batch_prompt_lens[b]
            ids   = input_ids[b]
            masks = attention_mask[b]
            L     = int(masks.sum().item())
            if p_len >= L:
                vals.append(None)  # response totally truncated -> skip
                continue
            targ = ids[p_len:L]              # response tokens
            pred = logprobs[b, p_len-1:L-1]  # shifted predictions
            lp = pred.gather(-1, targ.unsqueeze(-1)).squeeze(-1)
            if length_norm == "mean":
                score = float(lp.mean().cpu())
            else:
                score = float(lp.sum().cpu())
            vals.append(score)
    return vals


# =========================
# Model loader
# =========================
def load_causal_lm(model_id: str, token: Optional[str], dtype=DTYPE, device=DEVICE):
    tok = AutoTokenizer.from_pretrained(model_id, token=token, use_fast=True)
    if tok.pad_token_id is None:
        tok.pad_token = tok.eos_token
    model = AutoModelForCausalLM.from_pretrained(model_id, token=token, torch_dtype=dtype)
    model = model.to(device)
    model.eval()
    return tok, model


# =========================
# BP-LLM (unary JJ) posterior
# =========================
@dataclass
class BPParams:
    beta: float = BETA
    delta: float = DELTA
    tau: float = TAU
    gamma: float = GAMMA
    jj_steps: int = JJ_INNER_STEPS

def jj_lambda(xi: float) -> float:
    if xi < 1e-8:
        return 1.0 / 8.0
    return math.tanh(xi / 2.0) / (4.0 * xi)

def bp_unary_posterior(mu_prior: float, b: int, params: BPParams) -> Tuple[float, float]:
    tau2 = params.tau ** 2
    gamma_tilde = params.gamma * (2 * b - 1)
    mu_hat = mu_prior
    tau2_hat = tau2
    for _ in range(params.jj_steps):
        xi = abs(params.gamma) * math.sqrt(mu_hat * mu_hat + tau2_hat)
        lam = jj_lambda(xi)
        Lambda = (1.0 / tau2) + 2.0 * lam
        eta    = (mu_prior / tau2) + 0.5 * gamma_tilde
        mu_hat = eta / Lambda
        tau2_hat = 1.0 / Lambda
    return mu_hat, tau2_hat

def bp_unary_posterior_adaptive(mu_prior: float, b: int, params: BPParams,
                                tol: float = 1e-4, max_steps: int = 50) -> Tuple[float, float]:
    tau2 = params.tau ** 2
    gamma_tilde = params.gamma * (2 * b - 1)
    mu_hat = mu_prior
    tau2_hat = tau2
    steps = params.jj_steps if params.jj_steps and params.jj_steps > 0 else max_steps
    for _ in range(steps):
        mu_prev = mu_hat
        xi = abs(params.gamma) * math.sqrt(mu_hat * mu_hat + tau2_hat)
        lam = jj_lambda(xi)
        Lambda = (1.0 / tau2) + 2.0 * lam
        eta    = (mu_prior / tau2) + 0.5 * gamma_tilde
        mu_hat = eta / Lambda
        tau2_hat = 1.0 / Lambda
        if (params.jj_steps is None or params.jj_steps <= 0) and abs(mu_hat - mu_prev) < tol:
            break
    return mu_hat, tau2_hat


# =========================
# Caching scores (pairwise)
# =========================
@torch.no_grad()
def cache_pairwise_scores(
    records: List[Dict],
    model,
    ref_model,
    tokenizer
) -> Tuple[torch.Tensor, int]:
    """
    Returns base_s tensor shaped [2N] where first N are chosen scores, next N are rejected,
    and N is the number of *valid* pairs (both sides scored).
    Each score is (log pi - log pi_ref) with length normalization.
    """
    prompts_ch  = [r["prompt"] for r in records]
    prompts_rj  = [r["prompt"] for r in records]
    resps_ch    = [r["chosen"] for r in records]
    resps_rj    = [r["rejected"] for r in records]

    policy_ch = sequence_logprob_list(model, tokenizer, prompts_ch, resps_ch, LENGTH_NORM)
    policy_rj = sequence_logprob_list(model, tokenizer, prompts_rj, resps_rj, LENGTH_NORM)

    if ref_model is None:
        ref_ch = [0.0 if (s is not None) else None for s in policy_ch]
        ref_rj = [0.0 if (s is not None) else None for s in policy_rj]
    else:
        ref_ch = sequence_logprob_list(ref_model, tokenizer, prompts_ch, resps_ch, LENGTH_NORM)
        ref_rj = sequence_logprob_list(ref_model, tokenizer, prompts_rj, resps_rj, LENGTH_NORM)

    chosen_vals, rejected_vals = [], []
    for sc, sr, rc, rr in zip(policy_ch, policy_rj, ref_ch, ref_rj):
        if (sc is None) or (sr is None) or (rc is None) or (rr is None):
            continue  # drop pair if any side was truncated
        chosen_vals.append(sc - rc)
        rejected_vals.append(sr - rr)

    if len(chosen_vals) == 0:
        raise RuntimeError("No valid pairs after scoring. Consider increasing MAX_*_TOKENS or using 'sum' norm.")

    base_s = torch.tensor(chosen_vals + rejected_vals, dtype=torch.float32)
    N = len(chosen_vals)
    return base_s, N


# =========================
# Evaluation & tuning
# =========================
def estimate_delta_bco(base_s: torch.Tensor, beta: float, N: int) -> float:
    r_pos = beta * base_s[:N]
    r_neg = beta * base_s[N:]
    return 0.5 * (float(r_pos.mean()) + float(r_neg.mean()))

def evaluate_with_cached(
    base_s: torch.Tensor, N: int, params: BPParams,
    use_adaptive_jj: bool = True, use_labels: bool = True
) -> float:
    """
    If use_labels=False, compare mu_prior only (label-free test).
    """
    mu = (params.beta * base_s) - params.delta
    mu = mu.tolist()

    correct = 0
    for i in range(N):
        mu_w_prior = mu[i]
        mu_l_prior = mu[i + N]

        if not use_labels:
            # label-free: just compare priors
            if mu_w_prior > mu_l_prior:
                correct += 1
            continue

        if use_adaptive_jj:
            mu_w_post, _ = bp_unary_posterior_adaptive(mu_w_prior, b=1, params=params)
            mu_l_post, _ = bp_unary_posterior_adaptive(mu_l_prior, b=0, params=params)
        else:
            mu_w_post, _ = bp_unary_posterior(mu_w_prior, b=1, params=params)
            mu_l_post, _ = bp_unary_posterior(mu_l_prior, b=0, params=params)
        if mu_w_post > mu_l_post:
            correct += 1
    return 100.0 * correct / N

def binom_ci_95(pct: float, N: int) -> Tuple[float, float]:
    p = pct / 100.0
    se = math.sqrt(p * (1 - p) / max(N, 1))
    lo = max(0.0, 100.0 * (p - 1.96 * se))
    hi = min(100.0, 100.0 * (p + 1.96 * se))
    return lo, hi

def grid_search(
    records: List[Dict], model, ref_model, tokenizer,
    betas: Iterable[float], deltas: Iterable[Optional[str]],
    taus: Iterable[float], gammas: Iterable[float],
    jj_steps_list: Iterable[int], use_adaptive_jj: bool = True,
    estimate_delta: bool = True, split_name: str = "TRAIN"
):
    base_s, N = cache_pairwise_scores(records, model, ref_model, tokenizer)
    tried = []
    best = (-1.0, None)

    for beta, delta_spec, tau, gamma, jj_steps in itertools.product(
        betas, deltas, taus, gammas, jj_steps_list
    ):
        if estimate_delta and (delta_spec == 'bco'):
            delta = estimate_delta_bco(base_s, beta, N)
        else:
            delta = float(delta_spec)

        params = BPParams(beta=beta, delta=delta, tau=tau, gamma=gamma, jj_steps=jj_steps)
        # TRAIN uses labels:
        wr = evaluate_with_cached(base_s, N, params, use_adaptive_jj=use_adaptive_jj, use_labels=True)
        lo, hi = binom_ci_95(wr, N)
        tried.append((wr, lo, hi, params))
        if wr > best[0]:
            best = (wr, params)

        print(f"[TUNE/{split_name}] WR={wr:.2f}% [{lo:.1f}, {hi:.1f}] "
              f"| beta={beta} delta={'bco' if delta_spec=='bco' else f'{delta:.4f}'} "
              f"| tau={tau} gamma={gamma} jj_steps={jj_steps}")

    best_wr, best_params = best
    lo, hi = binom_ci_95(best_wr, N)
    print(f"\n[BP-LLM Tuning ({split_name})] Best WR={best_wr:.2f}% [{lo:.1f}, {hi:.1f}] with "
          f"beta={best_params.beta} delta={best_params.delta:.4f} "
          f"tau={best_params.tau} gamma={best_params.gamma} jj_steps={best_params.jj_steps}")
    return tried, best_params


# =========================
# Main
# =========================
def main():
    print(f"Device: {DEVICE} | Dtype: {DTYPE}")
    print("Loading policy model ...")
    tok, policy = load_causal_lm(MODEL_NAME, HF_TOKEN)

    ref = None
    if REF_MODEL_NAME:
        print("Loading reference model ...")
        _, ref = load_causal_lm(REF_MODEL_NAME, HF_TOKEN)

    # ---- Load pairs from HelpSteer2
    if USE_PREFERENCE:
        print("Loading HelpSteer2 (preference pairs) ...")
        adapted = load_hs2_preference_pairs(min_strength=PREF_MIN_STRENGTH, keep_split=PREF_KEEP_SPLIT)
        print(f"Preference pairs: {len(adapted)}")
    else:
        print(f"Loading HelpSteer2 (ratings split: {RATINGS_SPLIT}) and building pairs by helpfulness gap ≥ {MIN_GAP} ...")
        adapted = load_hs2_ratings_pairs(min_gap=MIN_GAP, split=RATINGS_SPLIT)
        print(f"Built rating-derived pairs: {len(adapted)}")

    if not adapted:
        raise RuntimeError("No valid examples. If using ratings mode, lower MIN_GAP or switch USE_PREFERENCE=True.")

    # Optional peek
    try:
        ex = adapted[0]
        print("Example pair:", {k: (ex[k][:120] + "…") if isinstance(ex.get(k), str) else ex.get(k)
                                for k in ("prompt","chosen","rejected")})
    except Exception:
        pass

    # Deterministic split
    g = torch.Generator().manual_seed(SEED)
    idx = torch.randperm(len(adapted), generator=g).tolist()
    split = int(len(idx) * TRAIN_FRAC)
    train_idx, test_idx = idx[:split], idx[split:]
    train_records = [adapted[i] for i in train_idx]
    test_records  = [adapted[i] for i in test_idx]
    print(f"Train records: {len(train_records)} | Test records: {len(test_records)}")

    # QUICK sanity check (non-adaptive JJ) on TRAIN with default params
    params0 = BPParams(beta=BETA, delta=DELTA, tau=TAU, gamma=GAMMA, jj_steps=JJ_INNER_STEPS)
    base_s_tr, N_tr = cache_pairwise_scores(train_records, policy, ref, tok)
    wr0 = evaluate_with_cached(base_s_tr, N_tr, params0, use_adaptive_jj=False, use_labels=True)
    lo0, hi0 = binom_ci_95(wr0, N_tr)
    print(f"[Sanity TRAIN] WR={wr0:.2f}% [{lo0:.1f}, {hi0:.1f}] on {N_tr} pairs "
          f"| beta={params0.beta} delta={params0.delta} tau={params0.tau} gamma={params0.gamma} "
          f"jj_steps={params0.jj_steps}")

    # Tune on TRAIN (uses labels)
    if DO_TUNE:
        print("\nTuning on TRAIN...")
        _tries, best_params = grid_search(
            train_records, policy, ref, tok,
            betas=BETAS, deltas=DELTAS, taus=TAUS, gammas=GAMMAS,
            jj_steps_list=JJ_STEPS_LIST, use_adaptive_jj=True, estimate_delta=True, split_name="TRAIN"
        )
    else:
        best_params = params0

    # Final eval on TEST with NO labels (gamma=0 and JJ disabled)
    base_s_te, N_te = cache_pairwise_scores(test_records, policy, ref, tok)
    test_params = BPParams(
        beta=best_params.beta,
        delta=best_params.delta,   # cancels in pairwise compare, kept for completeness
        tau=best_params.tau,
        gamma=0.0,                 # NO labels on test
        jj_steps=0                 # disable JJ on test
    )
    wr_te = evaluate_with_cached(base_s_te, N_te, test_params, use_adaptive_jj=False, use_labels=False)
    lo_te, hi_te = binom_ci_95(wr_te, N_te)
    print(f"\n[BP-LLM EVAL (TEST, no labels)] WR={wr_te:.2f}% [{lo_te:.1f}, {hi_te:.1f}] "
          f"| beta={test_params.beta} delta={test_params.delta:.4f} "
          f"| tau={test_params.tau} gamma={test_params.gamma} jj_steps={test_params.jj_steps}")

if __name__ == "__main__":
    main()


Device: cuda | Dtype: torch.bfloat16
Loading policy model ...


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loading reference model ...


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loading HelpSteer2 (preference pairs) ...
Preference pairs: 7117
Example pair: {'prompt': 'Define Signal Discuss its various properties with the help of diagram…', 'chosen': 'A signal is a message that is conveyed from a sender to a receiver through a communication channel. The message can be i…', 'rejected': 'A signal is a form of energy that is used to transmit information from one place to another. It can be in the form of so…'}
Train records: 5693 | Test records: 1424
[Sanity TRAIN] WR=99.21% [99.0, 99.4] on 5693 pairs | beta=1.0 delta=0.0 tau=1.0 gamma=1.0 jj_steps=2

Tuning on TRAIN...
[TUNE/TRAIN] WR=93.76% [93.1, 94.4] | beta=0.5 delta=bco | tau=0.5 gamma=0.5 jj_steps=1
[TUNE/TRAIN] WR=93.76% [93.1, 94.4] | beta=0.5 delta=bco | tau=0.5 gamma=0.5 jj_steps=2
[TUNE/TRAIN] WR=93.76% [93.1, 94.4] | beta=0.5 delta=bco | tau=0.5 gamma=0.5 jj_steps=3
[TUNE/TRAIN] WR=93.76% [93.1, 94.4] | beta=0.5 delta=bco | tau=0.5 gamma=0.5 jj_steps=0
[TUNE/TRAIN] WR=97.52% [97.1, 97.9] | beta=0.5 de

In [None]:
# BCO with Qwen2.5-3B-Instruct (HelpSteer2)

# bp_llm_helpsteer2_eval.py
# HelpSteer2 BCO baseline (leakage-safe)
#
# What this does:
#   • Uses HelpSteer2 in either of two modes:
#       (A) preference mode (data_dir="preference"): map (response_1, response_2, preference_strength) -> (chosen, rejected)
#       (B) ratings mode (default split): build top-vs-bottom pairs by helpfulness, with a min gap.
#   • Splits into TRAIN/TEST (tune β and δ on TRAIN; report on TEST).
#   • δ estimated by BCO on TRAIN: δ = 0.5 ( E_pos[β s] + E_neg[β s] ).
#   • Scoring is chat-template aware and length-normalized (mean log-prob),
#     limited to SCORE_MAX_GEN_TOKENS response tokens.
#   • Any pair with a truncated side is dropped (don’t bias as zeros).
#   • TEST evaluation is fully label-free (BCO doesn’t need labels).

import os
import math
import random
from dataclasses import dataclass
from typing import List, Dict, Optional, Tuple, Iterable, DefaultDict
from collections import defaultdict

import torch
from torch.nn.functional import log_softmax
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM

# =============================================================================
# Config
# =============================================================================
HF_TOKEN = os.environ.get("HF_TOKEN")  # set if using gated models

# Use an *Instruct* model for the policy; optional Base as reference
MODEL_NAME      = "Qwen/Qwen2.5-3B-Instruct"
REF_MODEL_NAME  = "Qwen/Qwen2.5-3B"   # set to None to disable reference subtraction

# ---- Dataset: HelpSteer2 ----
DATASET             = "nvidia/HelpSteer2"
DATASET_CONFIG      = None           # keep None
USE_PREFERENCE      = True           # set False to use ratings mode
PREFERENCE_SPLIT    = "train"        # HS2 preference usually ships as one 'train' file
RATINGS_SPLIT       = "train[:20%]"  # iterate smaller first; use "train" later

# Preference-mode knobs
PREF_MIN_STRENGTH   = 1              # skip weak prefs; keep abs(strength) >= this
PREF_KEEP_SPLIT     = None           # or "train" / "validation" to filter by the file's 'split' column

# Inference/scoring limits
MAX_INPUT_TOKENS     = 1024
MAX_GEN_TOKENS       = 512
SCORE_MAX_GEN_TOKENS = 10000
BATCH_SIZE           = 4

# Chat-templating + scoring
FORCE_CHAT_TEMPLATE  = True
SCORING_MODE         = "mean"      # {"mean", "sum", "lp_alpha"}
LENGTH_PENALTY_ALPHA = 0.0         # only for SCORING_MODE == "lp_alpha"

# Pair filtering (applies to ratings-mode only)
MIN_GAP              = 1.0         # minimum helpfulness-score gap (0–4 scale) to keep a pair

# Train / test split
RNG_SEED   = 42
TRAIN_FRAC = 0.8

# Grid tuning (TRAIN only)
DO_TUNE = True
BETAS   = [0.5, 1.0, 2.0]
DELTAS  = ["bco", 0.0]  # try BCO shift and a constant zero shift

# Device / dtype
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE  = torch.bfloat16 if torch.cuda.is_available() else torch.float32
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.set_grad_enabled(False)


# =============================================================================
# HelpSteer2 helpers
# =============================================================================
def _safe_str(x) -> Optional[str]:
    return x if isinstance(x, str) and x.strip() else None

def _to_float(x) -> Optional[float]:
    try:
        return float(x)
    except Exception:
        return None

def _get_int(value) -> Optional[int]:
    """Best-effort parse to int (accepts floats/strings); returns None if not possible."""
    try:
        if value is None:
            return None
        if isinstance(value, bool):
            return int(value)
        if isinstance(value, (int,)):
            return int(value)
        if isinstance(value, float):
            return int(round(value))
        s = str(value).strip()
        if s == "":
            return None
        return int(round(float(s)))
    except Exception:
        return None

def load_hs2_preference_pairs(
    min_strength: int = PREF_MIN_STRENGTH,
    keep_split: Optional[str] = PREF_KEEP_SPLIT
) -> List[Dict]:
    """
    Build (prompt, chosen, rejected) from HelpSteer2-preference by mapping
    response_1/response_2 + signed strength -> chosen/rejected.

    Positive strength => response_2 preferred
    Negative strength => response_1 preferred
    Zero or |strength| < min_strength => skip
    """
    ds = load_dataset(DATASET, data_dir="preference", split=PREFERENCE_SPLIT)

    out: List[Dict] = []
    for r in ds:
        sp = r.get("split")
        if keep_split is not None and sp != keep_split:
            continue

        p  = _safe_str(r.get("prompt") or r.get("instruction"))
        r1 = _safe_str(
            r.get("response_1")
            or r.get("response1")
            or r.get("resp1")
            or r.get("candidate_1")
            or r.get("output_1")
        )
        r2 = _safe_str(
            r.get("response_2")
            or r.get("response2")
            or r.get("resp2")
            or r.get("candidate_2")
            or r.get("output_2")
        )

        # robustly retrieve the signed strength
        s = (
            r.get("preference_strength")
            or r.get("preference")
            or r.get("label")
            or r.get("preference_score")
        )
        s = _get_int(s)

        if not (p and r1 and r2 and s is not None):
            continue
        if abs(s) < int(min_strength) or r1 == r2:
            continue

        if s > 0:
            chosen, rejected = r2, r1
        else:  # s < 0
            chosen, rejected = r1, r2

        out.append({"prompt": p, "chosen": chosen, "rejected": rejected})

    return out

def load_hs2_ratings_pairs(min_gap: float = MIN_GAP, split: str = RATINGS_SPLIT) -> List[Dict]:
    """
    Build (prompt, chosen, rejected) from the ratings split by taking the higher-vs-lower
    'helpfulness' response for each prompt, requiring a minimum rating gap.
    Typical fields: prompt, response, helpfulness (0..4), etc.
    """
    ds = load_dataset(DATASET, split=split) if DATASET_CONFIG is None else load_dataset(DATASET, name=DATASET_CONFIG, split=split)
    # Group up to 2 responses per prompt
    buckets: DefaultDict[str, List[Tuple[str, float]]] = defaultdict(list)
    for r in ds:
        p = _safe_str(r.get("prompt"))
        resp = _safe_str(r.get("response"))
        help_ = _to_float(r.get("helpfulness"))
        if p and resp and (help_ is not None):
            if len(buckets[p]) < 2:  # dataset typically has 2 responses/prompt
                buckets[p].append((resp, float(help_)))

    # Form pairs: top vs bottom by helpfulness, enforce min gap
    pairs: List[Dict] = []
    for p, lst in buckets.items():
        if len(lst) < 2:
            continue
        lst.sort(key=lambda t: t[1])  # low ... high
        lo_txt, lo_s = lst[0]
        hi_txt, hi_s = lst[-1]
        if (hi_s - lo_s) >= min_gap and (hi_txt != lo_txt):
            pairs.append({"prompt": p, "chosen": hi_txt, "rejected": lo_txt})
    return pairs


# =============================================================================
# Model
# =============================================================================
def load_causal_lm(model_id: str, token: Optional[str], dtype=DTYPE, device=DEVICE):
    tok = AutoTokenizer.from_pretrained(model_id, token=token, use_fast=True, trust_remote_code=True)
    if tok.pad_token_id is None:
        tok.pad_token = tok.eos_token
    model = AutoModelForCausalLM.from_pretrained(
        model_id, token=token, torch_dtype=dtype, trust_remote_code=True
    )
    model = model.to(device)
    model.eval()
    return tok, model


# =============================================================================
# Encoding + scoring (unchanged)
# =============================================================================
def _apply_chat_prefix(tokenizer, prompt: str) -> Optional[str]:
    if not FORCE_CHAT_TEMPLATE or not hasattr(tokenizer, "apply_chat_template"):
        return None
    try:
        msgs = [
            {"role": "user", "content": prompt.strip()},
            {"role": "assistant", "content": ""},
        ]
        return tokenizer.apply_chat_template(
            msgs, tokenize=False, add_generation_prompt=True
        )
    except Exception:
        return None

def encode_pair(tokenizer, prompt: str, response: str):
    chat_prefix = _apply_chat_prefix(tokenizer, prompt)
    if chat_prefix is not None:
        prompt_text = chat_prefix
        full_text   = chat_prefix + response.strip()
    else:
        prompt_text = prompt.strip()
        full_text   = prompt_text + ("\n" if not prompt_text.endswith("\n") else "") + response.strip()

    toks_full = tokenizer(
        full_text,
        truncation=True,
        max_length=min(MAX_INPUT_TOKENS + MAX_GEN_TOKENS, tokenizer.model_max_length),
        return_tensors="pt",
        add_special_tokens=True,
    )
    toks_prompt = tokenizer(
        prompt_text,
        truncation=True,
        max_length=MAX_INPUT_TOKENS,
        return_tensors="pt",
        add_special_tokens=True,
    )
    prompt_len = toks_prompt["input_ids"].shape[-1]
    return toks_full, prompt_len

@torch.no_grad()
def sequence_logprob_stats(
    model,
    tokenizer,
    prompts: List[str],
    responses: List[str],
) -> Tuple[torch.Tensor, torch.Tensor]:
    sums, counts = [], []
    for i in range(0, len(prompts), BATCH_SIZE):
        p_batch = prompts[i:i+BATCH_SIZE]
        r_batch = responses[i:i+BATCH_SIZE]

        batch_inputs, batch_prompt_lens = [], []
        for p, r in zip(p_batch, r_batch):
            toks, p_len = encode_pair(tokenizer, p, r)
            batch_inputs.append(toks)
            batch_prompt_lens.append(p_len)

        pad_id = tokenizer.pad_token_id or 0
        input_ids = torch.nn.utils.rnn.pad_sequence(
            [bi["input_ids"].squeeze(0) for bi in batch_inputs],
            batch_first=True, padding_value=pad_id
        )
        attention_mask = torch.nn.utils.rnn.pad_sequence(
            [bi["attention_mask"].squeeze(0) for bi in batch_inputs],
            batch_first=True, padding_value=0
        )
        input_ids = input_ids.to(DEVICE)
        attention_mask = attention_mask.to(DEVICE)

        logits = model(input_ids=input_ids, attention_mask=attention_mask).logits
        logprobs = log_softmax(logits.to(torch.float32), dim=-1)

        for b in range(input_ids.size(0)):
            p_len = batch_prompt_lens[b]
            ids   = input_ids[b]
            masks = attention_mask[b]
            L = int(masks.sum().item())
            if p_len >= L:
                sums.append(torch.tensor(0.0)); counts.append(torch.tensor(0)); continue
            end  = min(L, p_len + SCORE_MAX_GEN_TOKENS)
            targ = ids[p_len:end]
            pred = logprobs[b, p_len-1:end-1]
            lp_sum = pred.gather(-1, targ.unsqueeze(-1)).squeeze(-1).sum()
            sums.append(lp_sum.cpu())
            counts.append(torch.tensor(targ.numel()))
    return torch.stack(sums, dim=0), torch.stack(counts, dim=0)

def _score_from_stats(lp_sum: torch.Tensor, tok_count: torch.Tensor) -> torch.Tensor:
    if SCORING_MODE == "mean":
        denom = torch.clamp(tok_count.to(torch.float32), min=1.0)
        return lp_sum.to(torch.float32) / denom
    elif SCORING_MODE == "lp_alpha":
        return lp_sum.to(torch.float32) + (LENGTH_PENALTY_ALPHA * tok_count.to(torch.float32))
    else:  # "sum"
        return lp_sum.to(torch.float32)


# =============================================================================
# BCO pieces (unchanged)
# =============================================================================
@dataclass
class BCOParams:
    beta: float  = 1.0
    delta: float = 0.0

@torch.no_grad()
def cache_policy_and_ref_scores(records: List[Dict], model, ref_model, tokenizer) -> Tuple[torch.Tensor, int]:
    prompts   = [r["prompt"] for r in records for _ in (0, 1)]
    responses = [r["chosen"] for r in records] + [r["rejected"] for r in records]

    pol_sum, pol_cnt = sequence_logprob_stats(model, tokenizer, prompts, responses)
    pol_score = _score_from_stats(pol_sum, pol_cnt)

    if ref_model is None:
        ref_score = torch.zeros_like(pol_score)
        ref_cnt   = torch.ones_like(pol_cnt)  # pretend "valid"
    else:
        ref_sum, ref_cnt = sequence_logprob_stats(ref_model, tokenizer, prompts, responses)
        ref_score = _score_from_stats(ref_sum, ref_cnt)

    valid_mask = (pol_cnt > 0) & (ref_cnt > 0)
    valid_pairs = []
    for i in range(len(records)):
        i_ch = i
        i_rj = i + len(records)
        if valid_mask[i_ch].item() and valid_mask[i_rj].item():
            valid_pairs.append(i)

    if len(valid_pairs) == 0:
        raise RuntimeError("No valid (non-truncated) pairs after scoring. Increase limits or switch to preference mode.")

    chosen_vals, rejected_vals = [], []
    for i in valid_pairs:
        i_ch = i
        i_rj = i + len(records)
        chosen_vals.append(float(pol_score[i_ch] - ref_score[i_ch]))
        rejected_vals.append(float(pol_score[i_rj] - ref_score[i_rj]))

    base_s = torch.tensor(chosen_vals + rejected_vals, dtype=torch.float32)
    N = len(valid_pairs)
    return base_s, N

def estimate_delta_bco(base_s: torch.Tensor, beta: float, N: int) -> float:
    r_pos = beta * base_s[:N]
    r_neg = beta * base_s[N:]
    return 0.5 * (float(r_pos.mean()) + float(r_neg.mean()))

def bco_win_rate(base_s: torch.Tensor, N: int, params: BCOParams) -> float:
    mu = (params.beta * base_s) - params.delta
    mu = mu.tolist()
    correct = 0
    for i in range(N):
        if mu[i] > mu[i + N]:
            correct += 1
    return 100.0 * correct / N

def binom_ci_95(pct: float, N: int) -> Tuple[float, float]:
    p = pct / 100.0
    se = math.sqrt(p * (1 - p) / max(N, 1))
    lo = max(0.0, 100.0 * (p - 1.96 * se))
    hi = min(100.0, 100.0 * (p + 1.96 * se))
    return lo, hi

def grid_search_train_bco(
    records: List[Dict], model, ref_model, tokenizer,
    betas: Iterable[float], deltas: Iterable[Optional[str]]
):
    base_s, N = cache_policy_and_ref_scores(records, model, ref_model, tokenizer)

    tried = []
    best = (-1.0, None)  # (wr, params)
    for beta in betas:
        for delta_spec in deltas:
            if delta_spec == "bco":
                delta = estimate_delta_bco(base_s, beta, N)
            else:
                delta = float(delta_spec)
            params = BCOParams(beta=beta, delta=delta)
            wr = bco_win_rate(base_s, N, params)
            lo, hi = binom_ci_95(wr, N)
            tried.append((wr, lo, hi, params))
            if wr > best[0]:
                best = (wr, params)

            print(f"[TUNE/TRAIN BCO] WR={wr:.2f}% [{lo:.1f}, {hi:.1f}] "
                  f"| beta={beta} delta={'bco' if delta_spec=='bco' else f'{delta:.4f}'}")

    best_wr, best_params = best
    lo, hi = binom_ci_95(best_wr, N)
    print(f"\n[BCO Tuning (TRAIN)] Best WR={best_wr:.2f}% [{lo:.1f}, {hi:.1f}] with "
          f"beta={best_params.beta} delta={best_params.delta:.4f}")
    return tried, best_params


# =============================================================================
# Main
# =============================================================================
def main():
    print(f"Device: {DEVICE} | Dtype: {DTYPE}")
    print("Loading policy model ...")
    tok, policy = load_causal_lm(MODEL_NAME, HF_TOKEN)
    ref = None
    if REF_MODEL_NAME:
        print("Loading reference model ...")
        _, ref = load_causal_lm(REF_MODEL_NAME, HF_TOKEN)

    # ---- Load pairs from HelpSteer2
    if USE_PREFERENCE:
        print("Loading HelpSteer2 (preference pairs) ...")
        adapted = load_hs2_preference_pairs(min_strength=PREF_MIN_STRENGTH, keep_split=PREF_KEEP_SPLIT)
        print(f"Preference pairs: {len(adapted)}")
    else:
        print(f"Loading HelpSteer2 (ratings split: {RATINGS_SPLIT}) and building pairs by helpfulness gap ≥ {MIN_GAP} ...")
        adapted = load_hs2_ratings_pairs(min_gap=MIN_GAP, split=RATINGS_SPLIT)
        print(f"Built rating-derived pairs: {len(adapted)}")

    if not adapted:
        raise RuntimeError("No valid examples. If using ratings mode, lower MIN_GAP or use USE_PREFERENCE=True.")

    # Optional peek
    try:
        ex = adapted[0]
        print("Example pair:",
              {k: (ex[k][:120] + "…") if isinstance(ex.get(k), str) else ex.get(k)
               for k in ("prompt", "chosen", "rejected")})
    except Exception:
        pass

    # Train/Test split
    random.Random(RNG_SEED).shuffle(adapted)
    n_all   = len(adapted)
    n_train = max(1, int(TRAIN_FRAC * n_all))
    train_recs = adapted[:n_train]
    test_recs  = adapted[n_train:]
    print(f"Split sizes: TRAIN={len(train_recs)} | TEST={len(test_recs)}")

    # Cache TRAIN scores and baseline WR with default β=1, δ=0
    print("\nScoring TRAIN...")
    base_s_tr, N_tr = cache_policy_and_ref_scores(train_recs, policy, ref, tok)
    wr_tr_base = bco_win_rate(base_s_tr, N_tr, BCOParams(beta=1.0, delta=0.0))
    lo_trb, hi_trb = binom_ci_95(wr_tr_base, N_tr)
    print(f"[TRAIN BCO baseline] WR={wr_tr_base:.2f}% [{lo_trb:.1f}, {hi_trb:.1f}] on {N_tr} valid pairs")

    # Tune β, δ (TRAIN only)
    if DO_TUNE:
        print("\nTuning on TRAIN (BCO)...")
        _tries, best_params = grid_search_train_bco(
            train_recs, policy, ref, tok, betas=BETAS, deltas=DELTAS
        )
    else:
        best_params = BCOParams(beta=1.0, delta=estimate_delta_bco(base_s_tr, 1.0, N_tr))

    # TEST evaluation (label-free)
    eval_set = test_recs if len(test_recs) > 0 else train_recs
    if len(test_recs) == 0:
        print("\nWARNING: TEST set is empty (small dataset or large TRAIN_FRAC). Using TRAIN as a proxy sanity check.")

    print("\nScoring TEST...")
    base_s_te, N_te = cache_policy_and_ref_scores(eval_set, policy, ref, tok)
    wr_te = bco_win_rate(base_s_te, N_te, best_params)
    lo_te, hi_te = binom_ci_95(wr_te, N_te)
    print(f"[BCO EVAL (TEST, no labels)] WR={wr_te:.2f}% [{lo_te:.1f}, {hi_te:.1f}] "
          f"| beta={best_params.beta} delta={best_params.delta:.4f}")

if __name__ == "__main__":
    main()


Device: cuda | Dtype: torch.bfloat16
Loading policy model ...


tokenizer_config.json: 0.00B [00:00, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

config.json:   0%|          | 0.00/661 [00:00<?, ?B/s]

model.safetensors.index.json: 0.00B [00:00, ?B/s]

Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/3.97G [00:00<?, ?B/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/2.20G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/242 [00:00<?, ?B/s]

Loading reference model ...


tokenizer_config.json: 0.00B [00:00, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

config.json:   0%|          | 0.00/683 [00:00<?, ?B/s]

model.safetensors.index.json: 0.00B [00:00, ?B/s]

Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/3.97G [00:00<?, ?B/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/2.20G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/138 [00:00<?, ?B/s]

Loading HelpSteer2 (preference pairs) ...
Preference pairs: 7117
Example pair: {'prompt': 'Define Signal Discuss its various properties with the help of diagram…', 'chosen': 'A signal is a message that is conveyed from a sender to a receiver through a communication channel. The message can be i…', 'rejected': 'A signal is a form of energy that is used to transmit information from one place to another. It can be in the form of so…'}
Split sizes: TRAIN=5693 | TEST=1424

Scoring TRAIN...
[TRAIN BCO baseline] WR=58.14% [56.9, 59.4] on 5693 valid pairs

Tuning on TRAIN (BCO)...
[TUNE/TRAIN BCO] WR=58.14% [56.9, 59.4] | beta=0.5 delta=bco
[TUNE/TRAIN BCO] WR=58.14% [56.9, 59.4] | beta=0.5 delta=0.0000
[TUNE/TRAIN BCO] WR=58.14% [56.9, 59.4] | beta=1.0 delta=bco
[TUNE/TRAIN BCO] WR=58.14% [56.9, 59.4] | beta=1.0 delta=0.0000
[TUNE/TRAIN BCO] WR=58.14% [56.9, 59.4] | beta=2.0 delta=bco
[TUNE/TRAIN BCO] WR=58.14% [56.9, 59.4] | beta=2.0 delta=0.0000

[BCO Tuning (TRAIN)] Best WR=58.14% [56.9, 5

In [None]:
# BCO/BP-LLM with Qwen2.5-3B (HelpSteer2)

# bp_llm_helpsteer2_eval.py
# Evaluate BP-LLM (unary JJ) win rate on nvidia/HelpSteer2.
#
# What this does:
#   • Uses HelpSteer2 in either of two modes:
#       (A) preference mode (data_dir="preference"): map (response_1, response_2, preference_strength) -> (chosen, rejected)
#       (B) ratings mode (default split): build top-vs-bottom pairs by helpfulness, with a min gap.
#   • Splits into TRAIN/TEST (tune β,δ,τ,γ, JJ on TRAIN; report on TEST with NO labels).
#   • δ estimated by BCO on TRAIN: δ = 0.5 ( E_pos[β s] + E_neg[β s] ).
#   • Scoring is chat-template aware and length-normalized (mean log-prob),
#     limited to SCORE_MAX_GEN_TOKENS response tokens.
#   • Any pair with a truncated side is dropped (don’t bias as zeros).
#
# Notes:
#   • Optionally set a reference model (REF_MODEL_NAME) to use log π − log π_ref.
#   • Lower MIN_GAP to admit more pairs if you see too few examples (ratings mode).
#   • Increase PREF_MIN_STRENGTH to ignore very weak preferences (preference mode).

import os
import math
import itertools
import random
from dataclasses import dataclass
from typing import List, Dict, Optional, Tuple, Iterable, DefaultDict
from collections import defaultdict

import torch
from torch.nn.functional import log_softmax
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM

# =========================
# Config
# =========================
# Put your HF token if you use gated models
HF_TOKEN = os.environ.get("HF_TOKEN")

# Suggested: Instruct as policy, Base as reference
MODEL_NAME      = "Qwen/Qwen2.5-3B-Instruct"
REF_MODEL_NAME  = "Qwen/Qwen2.5-3B"   # None to disable reference subtraction

# ---- Dataset: HelpSteer2 ----
DATASET             = "nvidia/HelpSteer2"
USE_PREFERENCE      = True            # True -> use built-in pairwise preference file
PREFERENCE_SPLIT    = "train"         # HF exposes a single file; it includes its own 'split' column
PREF_MIN_STRENGTH   = 1               # keep pairs with |preference_strength| >= this
PREF_KEEP_SPLIT     = None            # or "train" / "validation" to filter within the preference file

# Ratings mode fallback (set USE_PREFERENCE=False to use this)
RATINGS_SPLIT       = "train[:20%]"   # iterate small first; use "train" later
MIN_GAP             = 1.0             # helpfulness gap (0-4 scale) for ratings mode

# Train/Test split from the adapted pairs (stratification not required here)
TRAIN_FRAC          = 0.8
SEED                = 42

# Input limits
MAX_INPUT_TOKENS    = 1024
MAX_GEN_TOKENS      = 512
SCORE_MAX_GEN_TOKENS= 10000
BATCH_SIZE          = 4

# Chat template toggle (recommended for *Instruct*)
USE_CHAT_TEMPLATE   = True

# Length normalization: "mean" or "sum"
LENGTH_NORM         = "mean"

# BP hyperparams (defaults)
BETA            = 1.0
DELTA           = 0.0
TAU             = 1.0
GAMMA           = 1.0
JJ_INNER_STEPS  = 2

# Grid tuning (TRAIN)
DO_TUNE         = True
BETAS           = [0.5, 1.0, 2.0]
DELTAS          = ['bco', 0.0]        # try BCO shift and zero shift
TAUS            = [0.5, 1.0, 2.0]
GAMMAS          = [0.5, 1.0, 1.5, 2.0]
JJ_STEPS_LIST   = [1, 2, 3, 0]        # 0 => adaptive (tolerance-based)

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE  = torch.bfloat16 if torch.cuda.is_available() else torch.float32
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True


# =========================
# HelpSteer2 pairing helpers
# =========================
def _safe_str(x) -> Optional[str]:
    return x if isinstance(x, str) and x.strip() else None

def _to_float(x) -> Optional[float]:
    try:
        return float(x)
    except Exception:
        return None

def _to_int(x) -> Optional[int]:
    try:
        if x is None: return None
        if isinstance(x, bool): return int(x)
        if isinstance(x, int):  return x
        if isinstance(x, float): return int(round(x))
        s = str(x).strip()
        if not s: return None
        return int(round(float(s)))
    except Exception:
        return None

def load_hs2_preference_pairs(
    min_strength: int = PREF_MIN_STRENGTH,
    keep_split: Optional[str] = PREF_KEEP_SPLIT
) -> List[Dict]:
    """
    Build (prompt, chosen, rejected) from HelpSteer2 'preference' by mapping:
      response_1/response_2 + signed preference_strength -> chosen/rejected
    Positive strength => response_2 preferred; negative => response_1.
    """
    ds = load_dataset(DATASET, data_dir="preference", split=PREFERENCE_SPLIT)
    out: List[Dict] = []
    for r in ds:
        sp = r.get("split")
        if keep_split is not None and sp != keep_split:
            continue

        p  = _safe_str(r.get("prompt") or r.get("instruction"))
        r1 = _safe_str(r.get("response_1") or r.get("response1") or r.get("candidate_1") or r.get("output_1"))
        r2 = _safe_str(r.get("response_2") or r.get("response2") or r.get("candidate_2") or r.get("output_2"))
        s  = _to_int(r.get("preference_strength") or r.get("preference") or r.get("label") or r.get("preference_score"))

        if not (p and r1 and r2 and s is not None):  # require all fields
            continue
        if abs(s) < int(min_strength) or r1 == r2:   # skip weak prefs and degenerate pairs
            continue

        if s > 0:
            chosen, rejected = r2, r1
        else:  # s < 0
            chosen, rejected = r1, r2

        out.append({"prompt": p, "chosen": chosen, "rejected": rejected})
    return out

def load_hs2_ratings_pairs(min_gap: float = MIN_GAP, split: str = RATINGS_SPLIT) -> List[Dict]:
    """
    Build (prompt, chosen, rejected) from the ratings split by taking
    higher-vs-lower 'helpfulness' responses per prompt, requiring a min gap.
    Typical fields: prompt, response, helpfulness (0..4), etc.
    """
    ds = load_dataset(DATASET, split=split)
    buckets: DefaultDict[str, List[Tuple[str, float]]] = defaultdict(list)
    for r in ds:
        p    = _safe_str(r.get("prompt"))
        resp = _safe_str(r.get("response"))
        help_ = _to_float(r.get("helpfulness"))
        if p and resp and (help_ is not None):
            if len(buckets[p]) < 2:  # HS2 usually has exactly 2 responses per prompt
                buckets[p].append((resp, float(help_)))

    pairs: List[Dict] = []
    for p, lst in buckets.items():
        if len(lst) < 2:
            continue
        lst.sort(key=lambda t: t[1])   # low ... high
        lo_txt, lo_s = lst[0]
        hi_txt, hi_s = lst[-1]
        if (hi_s - lo_s) >= min_gap and (hi_txt != lo_txt):
            pairs.append({"prompt": p, "chosen": hi_txt, "rejected": lo_txt})
    return pairs


# =========================
# Tokenization / Scoring
# =========================
def format_prompt_with_chat_template(tokenizer, prompt: str) -> str:
    if USE_CHAT_TEMPLATE and hasattr(tokenizer, "apply_chat_template"):
        messages = [{"role": "user", "content": prompt}]
        try:
            return tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True  # leaves assistant prefix for continuation
            )
        except Exception:
            pass
    return prompt

def concat_prompt_response_text(tokenizer, prompt: str, response: str) -> Tuple[str, str]:
    """Return (full_text, prompt_only_text) for token counting."""
    prompt_text = format_prompt_with_chat_template(tokenizer, prompt).strip()
    full_text = prompt_text + ("\n" if not prompt_text.endswith("\n") else "") + response.strip()
    return full_text, prompt_text

@torch.no_grad()
def sequence_logprob_list(
    model,
    tokenizer,
    prompts: List[str],
    responses: List[str],
    length_norm: str = LENGTH_NORM,
) -> List[Optional[float]]:
    """
    Returns mean/sum log-prob for each (prompt, response). If response was truncated
    (i.e., prompt consumed the full sequence), returns None for that item.
    """
    vals: List[Optional[float]] = []
    for i in range(0, len(prompts), BATCH_SIZE):
        batch_prompts = prompts[i:i+BATCH_SIZE]
        batch_resps   = responses[i:i+BATCH_SIZE]

        batch_inputs, batch_prompt_lens = [], []
        for p, r in zip(batch_prompts, batch_resps):
            full_text, prompt_text = concat_prompt_response_text(tokenizer, p, r)
            toks = tokenizer(
                full_text,
                truncation=True,
                max_length=min(MAX_INPUT_TOKENS + MAX_GEN_TOKENS, tokenizer.model_max_length),
                return_tensors="pt",
                add_special_tokens=True,
            )
            tok_prompt = tokenizer(
                prompt_text,
                truncation=True,
                max_length=MAX_INPUT_TOKENS,
                return_tensors="pt",
                add_special_tokens=True,
            )
            p_len = tok_prompt["input_ids"].shape[-1]
            batch_inputs.append(toks)
            batch_prompt_lens.append(p_len)

        pad_id = tokenizer.pad_token_id or 0
        input_ids = torch.nn.utils.rnn.pad_sequence(
            [bi["input_ids"].squeeze(0) for bi in batch_inputs],
            batch_first=True, padding_value=pad_id
        )
        attention_mask = torch.nn.utils.rnn.pad_sequence(
            [bi["attention_mask"].squeeze(0) for bi in batch_inputs],
            batch_first=True, padding_value=0
        )
        input_ids = input_ids.to(DEVICE)
        attention_mask = attention_mask.to(DEVICE)

        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        logits = outputs.logits
        logprobs = log_softmax(logits.to(torch.float32), dim=-1)

        for b in range(input_ids.size(0)):
            p_len = batch_prompt_lens[b]
            ids   = input_ids[b]
            masks = attention_mask[b]
            L     = int(masks.sum().item())
            if p_len >= L:
                vals.append(None)  # response totally truncated -> skip
                continue
            targ = ids[p_len:L]              # response tokens
            pred = logprobs[b, p_len-1:L-1]  # shifted predictions
            lp = pred.gather(-1, targ.unsqueeze(-1)).squeeze(-1)
            if length_norm == "mean":
                score = float(lp.mean().cpu())
            else:
                score = float(lp.sum().cpu())
            vals.append(score)
    return vals


# =========================
# Model loader
# =========================
def load_causal_lm(model_id: str, token: Optional[str], dtype=DTYPE, device=DEVICE):
    tok = AutoTokenizer.from_pretrained(model_id, token=token, use_fast=True, trust_remote_code=True)
    if tok.pad_token_id is None:
        tok.pad_token = tok.eos_token
    model = AutoModelForCausalLM.from_pretrained(
        model_id, token=token, torch_dtype=dtype, trust_remote_code=True
    )
    model = model.to(device)
    model.eval()
    return tok, model


# =========================
# BP-LLM (unary JJ) posterior
# =========================
@dataclass
class BPParams:
    beta: float = BETA
    delta: float = DELTA
    tau: float = TAU
    gamma: float = GAMMA
    jj_steps: int = JJ_INNER_STEPS

def jj_lambda(xi: float) -> float:
    if xi < 1e-8:
        return 1.0 / 8.0
    return math.tanh(xi / 2.0) / (4.0 * xi)

def bp_unary_posterior(mu_prior: float, b: int, params: BPParams) -> Tuple[float, float]:
    tau2 = params.tau ** 2
    gamma_tilde = params.gamma * (2 * b - 1)
    mu_hat = mu_prior
    tau2_hat = tau2
    for _ in range(params.jj_steps):
        xi = abs(params.gamma) * math.sqrt(mu_hat * mu_hat + tau2_hat)
        lam = jj_lambda(xi)
        Lambda = (1.0 / tau2) + 2.0 * lam
        eta    = (mu_prior / tau2) + 0.5 * gamma_tilde
        mu_hat = eta / Lambda
        tau2_hat = 1.0 / Lambda
    return mu_hat, tau2_hat

def bp_unary_posterior_adaptive(mu_prior: float, b: int, params: BPParams,
                                tol: float = 1e-4, max_steps: int = 50) -> Tuple[float, float]:
    tau2 = params.tau ** 2
    gamma_tilde = params.gamma * (2 * b - 1)
    mu_hat = mu_prior
    tau2_hat = tau2
    steps = params.jj_steps if params.jj_steps and params.jj_steps > 0 else max_steps
    for _ in range(steps):
        mu_prev = mu_hat
        xi = abs(params.gamma) * math.sqrt(mu_hat * mu_hat + tau2_hat)
        lam = jj_lambda(xi)
        Lambda = (1.0 / tau2) + 2.0 * lam
        eta    = (mu_prior / tau2) + 0.5 * gamma_tilde
        mu_hat = eta / Lambda
        tau2_hat = 1.0 / Lambda
        if (params.jj_steps is None or params.jj_steps <= 0) and abs(mu_hat - mu_prev) < tol:
            break
    return mu_hat, tau2_hat


# =========================
# Caching scores (pairwise)
# =========================
@torch.no_grad()
def cache_pairwise_scores(
    records: List[Dict],
    model,
    ref_model,
    tokenizer
) -> Tuple[torch.Tensor, int]:
    """
    Returns base_s tensor shaped [2N] where first N are chosen scores, next N are rejected,
    and N is the number of *valid* pairs (both sides scored).
    Each score is (log pi - log pi_ref) with length normalization.
    """
    prompts_ch  = [r["prompt"] for r in records]
    prompts_rj  = [r["prompt"] for r in records]
    resps_ch    = [r["chosen"] for r in records]
    resps_rj    = [r["rejected"] for r in records]

    policy_ch = sequence_logprob_list(model, tokenizer, prompts_ch, resps_ch, LENGTH_NORM)
    policy_rj = sequence_logprob_list(model, tokenizer, prompts_rj, resps_rj, LENGTH_NORM)

    if ref_model is None:
        ref_ch = [0.0 if (s is not None) else None for s in policy_ch]
        ref_rj = [0.0 if (s is not None) else None for s in policy_rj]
    else:
        ref_ch = sequence_logprob_list(ref_model, tokenizer, prompts_ch, resps_ch, LENGTH_NORM)
        ref_rj = sequence_logprob_list(ref_model, tokenizer, prompts_rj, resps_rj, LENGTH_NORM)

    chosen_vals, rejected_vals = [], []
    for sc, sr, rc, rr in zip(policy_ch, policy_rj, ref_ch, ref_rj):
        if (sc is None) or (sr is None) or (rc is None) or (rr is None):
            continue  # drop pair if any side was truncated
        chosen_vals.append(sc - rc)
        rejected_vals.append(sr - rr)

    if len(chosen_vals) == 0:
        raise RuntimeError("No valid pairs after scoring. Consider increasing MAX_*_TOKENS or using 'sum' norm.")

    base_s = torch.tensor(chosen_vals + rejected_vals, dtype=torch.float32)
    N = len(chosen_vals)
    return base_s, N


# =========================
# Evaluation & tuning
# =========================
def estimate_delta_bco(base_s: torch.Tensor, beta: float, N: int) -> float:
    r_pos = beta * base_s[:N]
    r_neg = beta * base_s[N:]
    return 0.5 * (float(r_pos.mean()) + float(r_neg.mean()))

def evaluate_with_cached(
    base_s: torch.Tensor, N: int, params: BPParams,
    use_adaptive_jj: bool = True, use_labels: bool = True
) -> float:
    """
    If use_labels=False, compare mu_prior only (label-free test).
    """
    mu = (params.beta * base_s) - params.delta
    mu = mu.tolist()

    correct = 0
    for i in range(N):
        mu_w_prior = mu[i]
        mu_l_prior = mu[i + N]

        if not use_labels:
            # label-free: just compare priors
            if mu_w_prior > mu_l_prior:
                correct += 1
            continue

        if use_adaptive_jj:
            mu_w_post, _ = bp_unary_posterior_adaptive(mu_w_prior, b=1, params=params)
            mu_l_post, _ = bp_unary_posterior_adaptive(mu_l_prior, b=0, params=params)
        else:
            mu_w_post, _ = bp_unary_posterior(mu_w_prior, b=1, params=params)
            mu_l_post, _ = bp_unary_posterior(mu_l_prior, b=0, params=params)
        if mu_w_post > mu_l_post:
            correct += 1
    return 100.0 * correct / N

def binom_ci_95(pct: float, N: int) -> Tuple[float, float]:
    p = pct / 100.0
    se = math.sqrt(p * (1 - p) / max(N, 1))
    lo = max(0.0, 100.0 * (p - 1.96 * se))
    hi = min(100.0, 100.0 * (p + 1.96 * se))
    return lo, hi

def grid_search(
    records: List[Dict], model, ref_model, tokenizer,
    betas: Iterable[float], deltas: Iterable[Optional[str]],
    taus: Iterable[float], gammas: Iterable[float],
    jj_steps_list: Iterable[int], use_adaptive_jj: bool = True,
    estimate_delta: bool = True, split_name: str = "TRAIN"
):
    base_s, N = cache_pairwise_scores(records, model, ref_model, tokenizer)
    tried = []
    best = (-1.0, None)

    for beta, delta_spec, tau, gamma, jj_steps in itertools.product(
        betas, deltas, taus, gammas, jj_steps_list
    ):
        if estimate_delta and (delta_spec == 'bco'):
            delta = estimate_delta_bco(base_s, beta, N)
        else:
            delta = float(delta_spec)

        params = BPParams(beta=beta, delta=delta, tau=tau, gamma=gamma, jj_steps=jj_steps)
        # TRAIN uses labels:
        wr = evaluate_with_cached(base_s, N, params, use_adaptive_jj=use_adaptive_jj, use_labels=True)
        lo, hi = binom_ci_95(wr, N)
        tried.append((wr, lo, hi, params))
        if wr > best[0]:
            best = (wr, params)

        print(f"[TUNE/{split_name}] WR={wr:.2f}% [{lo:.1f}, {hi:.1f}] "
              f"| beta={beta} delta={'bco' if delta_spec=='bco' else f'{delta:.4f}'} "
              f"| tau={tau} gamma={gamma} jj_steps={jj_steps}")

    best_wr, best_params = best
    lo, hi = binom_ci_95(best_wr, N)
    print(f"\n[BP-LLM Tuning ({split_name})] Best WR={best_wr:.2f}% [{lo:.1f}, {hi:.1f}] with "
          f"beta={best_params.beta} delta={best_params.delta:.4f} "
          f"tau={best_params.tau} gamma={best_params.gamma} jj_steps={best_params.jj_steps}")
    return tried, best_params


# =========================
# Main
# =========================
def main():
    print(f"Device: {DEVICE} | Dtype: {DTYPE}")
    print("Loading policy model ...")
    tok, policy = load_causal_lm(MODEL_NAME, HF_TOKEN)

    ref = None
    if REF_MODEL_NAME:
        print("Loading reference model ...")
        _, ref = load_causal_lm(REF_MODEL_NAME, HF_TOKEN)

    # ---- Load pairs from HelpSteer2
    if USE_PREFERENCE:
        print("Loading HelpSteer2 (preference pairs) ...")
        adapted = load_hs2_preference_pairs(min_strength=PREF_MIN_STRENGTH, keep_split=PREF_KEEP_SPLIT)
        print(f"Preference pairs: {len(adapted)}")
    else:
        print(f"Loading HelpSteer2 (ratings split: {RATINGS_SPLIT}) and building pairs by helpfulness gap ≥ {MIN_GAP} ...")
        adapted = load_hs2_ratings_pairs(min_gap=MIN_GAP, split=RATINGS_SPLIT)
        print(f"Built rating-derived pairs: {len(adapted)}")

    if not adapted:
        raise RuntimeError("No valid examples. If using ratings mode, lower MIN_GAP or switch USE_PREFERENCE=True.")

    # Optional peek
    try:
        ex = adapted[0]
        print("Example pair:", {k: (ex[k][:120] + "…") if isinstance(ex.get(k), str) else ex.get(k)
                                for k in ("prompt","chosen","rejected")})
    except Exception:
        pass

    # Deterministic split
    g = torch.Generator().manual_seed(SEED)
    idx = torch.randperm(len(adapted), generator=g).tolist()
    split = int(len(idx) * TRAIN_FRAC)
    train_idx, test_idx = idx[:split], idx[split:]
    train_records = [adapted[i] for i in train_idx]
    test_records  = [adapted[i] for i in test_idx]
    print(f"Train records: {len(train_records)} | Test records: {len(test_records)}")

    # QUICK sanity check (non-adaptive JJ) on TRAIN with default params
    params0 = BPParams(beta=BETA, delta=DELTA, tau=TAU, gamma=GAMMA, jj_steps=JJ_INNER_STEPS)
    base_s_tr, N_tr = cache_pairwise_scores(train_records, policy, ref, tok)
    wr0 = evaluate_with_cached(base_s_tr, N_tr, params0, use_adaptive_jj=False, use_labels=True)
    lo0, hi0 = binom_ci_95(wr0, N_tr)
    print(f"[Sanity TRAIN] WR={wr0:.2f}% [{lo0:.1f}, {hi0:.1f}] on {N_tr} pairs "
          f"| beta={params0.beta} delta={params0.delta} tau={params0.tau} gamma={params0.gamma} "
          f"jj_steps={params0.jj_steps}")

    # Tune on TRAIN (uses labels)
    if DO_TUNE:
        print("\nTuning on TRAIN...")
        _tries, best_params = grid_search(
            train_records, policy, ref, tok,
            betas=BETAS, deltas=DELTAS, taus=TAUS, gammas=GAMMAS,
            jj_steps_list=JJ_STEPS_LIST, use_adaptive_jj=True, estimate_delta=True, split_name="TRAIN"
        )
    else:
        best_params = params0

    # Final eval on TEST with NO labels (gamma=0 and JJ disabled)
    base_s_te, N_te = cache_pairwise_scores(test_records, policy, ref, tok)
    test_params = BPParams(
        beta=best_params.beta,
        delta=best_params.delta,   # cancels in pairwise compare, kept for completeness
        tau=best_params.tau,
        gamma=0.0,                 # NO labels on test
        jj_steps=0                 # disable JJ on test
    )
    wr_te = evaluate_with_cached(base_s_te, N_te, test_params, use_adaptive_jj=False, use_labels=False)
    lo_te, hi_te = binom_ci_95(wr_te, N_te)
    print(f"\n[BP-LLM EVAL (TEST, no labels)] WR={wr_te:.2f}% [{lo_te:.1f}, {hi_te:.1f}] "
          f"| beta={test_params.beta} delta={test_params.delta:.4f} "
          f"| tau={test_params.tau} gamma={test_params.gamma} jj_steps={test_params.jj_steps}")

if __name__ == "__main__":
    main()


Device: cuda | Dtype: torch.bfloat16
Loading policy model ...


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loading reference model ...


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loading HelpSteer2 (preference pairs) ...
Preference pairs: 7117
Example pair: {'prompt': 'Define Signal Discuss its various properties with the help of diagram…', 'chosen': 'A signal is a message that is conveyed from a sender to a receiver through a communication channel. The message can be i…', 'rejected': 'A signal is a form of energy that is used to transmit information from one place to another. It can be in the form of so…'}
Train records: 5693 | Test records: 1424
[Sanity TRAIN] WR=99.03% [98.8, 99.3] on 5693 pairs | beta=1.0 delta=0.0 tau=1.0 gamma=1.0 jj_steps=2

Tuning on TRAIN...
[TUNE/TRAIN] WR=93.85% [93.2, 94.5] | beta=0.5 delta=bco | tau=0.5 gamma=0.5 jj_steps=1
[TUNE/TRAIN] WR=93.85% [93.2, 94.5] | beta=0.5 delta=bco | tau=0.5 gamma=0.5 jj_steps=2
[TUNE/TRAIN] WR=93.85% [93.2, 94.5] | beta=0.5 delta=bco | tau=0.5 gamma=0.5 jj_steps=3
[TUNE/TRAIN] WR=93.85% [93.2, 94.5] | beta=0.5 delta=bco | tau=0.5 gamma=0.5 jj_steps=0
[TUNE/TRAIN] WR=97.22% [96.8, 97.7] | beta=0.5 de

# The Following is for Qwen-7B

In [None]:
# bp_llm_helpsteer2_eval.py
# BP-LLM PRIOR-WR evaluation on nvidia/HelpSteer2 with Qwen2.5 (3B/7B).
# Key upgrades vs. your script:
#   1) Uses the *correct tokenizer per model* (policy vs reference).
#   2) Prior-only calibration of λ_ref, α (length term), and η (uncond penalty).
#   3) Cross-model token equalization + per-pair cap sweep (to reduce length bias).
#   4) Optional min-token filtering to drop noisy/short pairs.
#   5) TEST remains label-free (γ=0, JJ disabled); TRAIN tuning is prior-only (no label leakage).

import os, math, itertools, random
from dataclasses import dataclass
from typing import List, Dict, Optional, Tuple, Iterable, DefaultDict, Union
from collections import defaultdict

import torch
from torch.nn.functional import log_softmax
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM

# =========================
# Config
# =========================
HF_TOKEN = os.environ.get("HF_TOKEN")

# ---------- Choose your model here ----------
# Default: 3B. To try 7B, swap these two lines:
# MODEL_NAME      = "Qwen/Qwen2.5-3B-Instruct"
# REF_MODEL_NAME  = "Qwen/Qwen2.5-3B"             # set None to disable reference subtraction
MODEL_NAME    = "Qwen/Qwen2.5-7B-Instruct"
REF_MODEL_NAME= "Qwen/Qwen2.5-7B"

DATASET             = "nvidia/HelpSteer2"
USE_PREFERENCE      = True
PREFERENCE_SPLIT    = "train"
PREF_MIN_STRENGTH   = 1
PREF_KEEP_SPLIT     = None

# Ratings-mode fallback (unused unless USE_PREFERENCE=False)
RATINGS_SPLIT       = "train"
MIN_GAP             = 1.0

# Deterministic split (HS2 preference file has 7117 total)
TRAIN_FRAC          = 0.8
SEED                = 0
STRICT_COUNTS       = True
EXPECTED_TOTAL      = 7117
EXPECTED_TRAIN      = 5693

# Limits
MAX_INPUT_TOKENS     = 1024
MAX_GEN_TOKENS       = 512
BATCH_SIZE           = 4

# Chat template (recommended for *Instruct*)
USE_CHAT_TEMPLATE    = True
PROMPT_SUFFIX        = "\n\nAssistant:"

# ====== PRIOR scoring mode ======
# 'lp_alpha': logprob_sum + α * (#scored_tokens)  (set α via grid below)
# 'mean'    : mean logprob (α ignored)
# 'sum'     : raw sum (α ignored)
SCORING_MODE          = "lp_alpha"
LP_ALPHA_GRID         = (-1.50, -1.25, -1.00, -0.75, -0.50, -0.35, -0.25, 0.0, 0.25)

# Reference subtraction weight λ
LAMBDA_REF_GRID       = (0.0, 0.25, 0.5, 0.75, 1.0, 1.25, 1.5, 1.75, 2.0, 2.5, 3.0, 3.5, 4.0)

# Unconditional (response-only) penalty weight η (0 disables)
USE_UNCOND_PENALTY    = True
UNCOND_WEIGHT_GRID    = (0.0, 0.15, 0.25, 0.40)

# Token equalization & cap
EQUALIZE_PAIR_TOKENS  = True
EQUALIZE_CROSS_MODEL  = True
DO_CAP_SWEEP          = True
K_PAIR_CAP_GRID       = (128, 256, 384, 512)

# Min-token filtering (train/test)
APPLY_MIN_TOK_FILTER  = True
MIN_TOK_TRAIN         = 8
MIN_TOK_TEST          = 8

# BP (used for diagnostic / sanity only; TEST uses γ=0)
BETA, DELTA, TAU, GAMMA, JJ_INNER_STEPS = 1.0, 0.0, 1.0, 1.0, 2

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE  = torch.bfloat16 if torch.cuda.is_available() else torch.float32
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

# =========================
# HelpSteer2 adapters
# =========================
def _safe_str(x) -> Optional[str]:
    return x if isinstance(x, str) and x.strip() else None

def _to_float(x) -> Optional[float]:
    try:
        if x is None: return None
        if isinstance(x, (int, float)): return float(x)
        s = str(x).strip()
        return float(s) if s else None
    except Exception:
        return None

def _to_int(x) -> Optional[int]:
    try:
        if x is None: return None
        if isinstance(x, bool): return int(x)
        if isinstance(x, int):  return x
        if isinstance(x, float): return int(round(x))
        s = str(x).strip()
        if not s: return None
        return int(round(float(s)))
    except Exception:
        return None

def load_hs2_preference_pairs(min_strength=PREF_MIN_STRENGTH, keep_split=PREF_KEEP_SPLIT) -> List[Dict]:
    ds = load_dataset(DATASET, data_dir="preference", split=PREFERENCE_SPLIT)
    out: List[Dict] = []
    for r in ds:
        sp = r.get("split")
        if keep_split is not None and sp != keep_split:
            continue
        p  = _safe_str(r.get("prompt") or r.get("instruction"))
        r1 = _safe_str(r.get("response_1") or r.get("response1") or r.get("candidate_1") or r.get("output_1"))
        r2 = _safe_str(r.get("response_2") or r.get("response2") or r.get("candidate_2") or r.get("output_2"))
        s  = _to_int(r.get("preference_strength") or r.get("preference") or r.get("label") or r.get("preference_score"))
        if not (p and r1 and r2 and s is not None): continue
        if r1 == r2 or abs(s) < int(min_strength):   continue
        chosen, rejected = (r2, r1) if s > 0 else (r1, r2)
        out.append({"prompt": p, "chosen": chosen, "rejected": rejected})
    return out

def load_hs2_ratings_pairs(min_gap=MIN_GAP, split=RATINGS_SPLIT) -> List[Dict]:
    ds = load_dataset(DATASET, split=split)
    buckets: DefaultDict[str, List[Tuple[str, float]]] = defaultdict(list)
    for r in ds:
        p = _safe_str(r.get("prompt")); resp = _safe_str(r.get("response")); h = _to_float(r.get("helpfulness"))
        if p and resp and (h is not None):
            buckets[p].append((resp, float(h)))
    pairs: List[Dict] = []
    for p, lst in buckets.items():
        if len(lst) < 2: continue
        lst.sort(key=lambda t: t[1])
        lo_txt, lo_s = lst[0]; hi_txt, hi_s = lst[-1]
        if (hi_s - lo_s) >= min_gap and (hi_txt != lo_txt):
            pairs.append({"prompt": p, "chosen": hi_txt, "rejected": lo_txt})
    return pairs

# =========================
# Prompt rendering
# =========================
def render_prompt_with_policy_template(policy_tokenizer, prompt: Union[str, List[Dict[str,str]]]) -> str:
    if USE_CHAT_TEMPLATE and hasattr(policy_tokenizer, "apply_chat_template"):
        try:
            msgs = [{"role": "user", "content": str(prompt).strip()}]
            return policy_tokenizer.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
        except Exception:
            pass
    p = str(prompt).strip()
    s = PROMPT_SUFFIX if not USE_CHAT_TEMPLATE else ""
    if not p.endswith("\n"): p += "\n"
    return (p + s).strip()

def render_prompt_list(policy_tokenizer, prompts: List[Union[str, List[Dict[str,str]]]]) -> List[str]:
    return [render_prompt_with_policy_template(policy_tokenizer, p).strip() for p in prompts]

# =========================
# Tokenization / scoring helpers
# =========================
@torch.no_grad()
def _sequence_token_counts_by_tokenizer(tokenizer, prompt_texts: List[str], responses: List[str]) -> torch.Tensor:
    counts = []
    for p_txt, r in zip(prompt_texts, responses):
        toks_prompt = tokenizer(p_txt, truncation=True, max_length=MAX_INPUT_TOKENS,
                                return_tensors="pt", add_special_tokens=True)
        toks_full   = tokenizer(p_txt + r.strip(), truncation=True,
                                max_length=min(MAX_INPUT_TOKENS + MAX_GEN_TOKENS, tokenizer.model_max_length),
                                return_tensors="pt", add_special_tokens=True)
        p_len = int(toks_prompt["input_ids"].shape[-1])
        L     = int(toks_full["input_ids"].shape[-1])
        cnt   = max(0, L - p_len)
        counts.append(cnt)
    return torch.tensor(counts, dtype=torch.long)

@torch.no_grad()
def _sequence_logprob_stats_text_with_k(model, tokenizer, prompt_texts, responses, k_list: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    sums, counts = [], []
    for i in range(0, len(prompt_texts), BATCH_SIZE):
        p_batch = prompt_texts[i:i+BATCH_SIZE]
        r_batch = responses[i:i+BATCH_SIZE]
        k_batch = k_list[i:i+BATCH_SIZE]

        batch_inputs, batch_prompt_lens = [], []
        for p_txt, r in zip(p_batch, r_batch):
            full_text = p_txt + r.strip()
            toks_full = tokenizer(full_text, truncation=True,
                                  max_length=min(MAX_INPUT_TOKENS + MAX_GEN_TOKENS, tokenizer.model_max_length),
                                  return_tensors="pt", add_special_tokens=True)
            toks_prompt = tokenizer(p_txt, truncation=True, max_length=MAX_INPUT_TOKENS,
                                    return_tensors="pt", add_special_tokens=True)
            p_len = toks_prompt["input_ids"].shape[-1]
            batch_inputs.append(toks_full); batch_prompt_lens.append(p_len)

        pad_id = tokenizer.pad_token_id or 0
        input_ids = torch.nn.utils.rnn.pad_sequence([bi["input_ids"].squeeze(0) for bi in batch_inputs],
                                                    batch_first=True, padding_value=pad_id)
        attention_mask = torch.nn.utils.rnn.pad_sequence([bi["attention_mask"].squeeze(0) for bi in batch_inputs],
                                                         batch_first=True, padding_value=0)
        input_ids, attention_mask = input_ids.to(DEVICE), attention_mask.to(DEVICE)
        logits = model(input_ids=input_ids, attention_mask=attention_mask).logits
        logprobs = log_softmax(logits.to(torch.float32), dim=-1)

        for b in range(input_ids.size(0)):
            p_len = batch_prompt_lens[b]; ids = input_ids[b]; masks = attention_mask[b]
            L     = int(masks.sum().item()); k = int(k_batch[b].item())
            k_eff = max(0, min(k, L - p_len))
            if k_eff <= 0:
                sums.append(torch.tensor(0.0)); counts.append(torch.tensor(0)); continue
            end  = p_len + k_eff
            targ = ids[p_len:end]
            pred = logprobs[b, p_len-1:end-1]
            lp_sum = pred.gather(-1, targ.unsqueeze(-1)).squeeze(-1).sum()
            sums.append(lp_sum.cpu()); counts.append(torch.tensor(k_eff))
    return torch.stack(sums, dim=0), torch.stack(counts, dim=0)

@torch.no_grad()
def _response_only_scores(model, tok, responses: List[str], k_list: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    sums, cnts = [], []
    for i in range(0, len(responses), BATCH_SIZE):
        r_batch = [r.strip() for r in responses[i:i+BATCH_SIZE]]
        toks_full = tok(r_batch, truncation=True,
                        max_length=min(MAX_GEN_TOKENS, tok.model_max_length),
                        return_tensors="pt", add_special_tokens=True, padding=True)
        input_ids = toks_full["input_ids"].to(DEVICE)
        attn = toks_full["attention_mask"].to(DEVICE)
        logits = model(input_ids=input_ids, attention_mask=attn).logits
        logprobs = log_softmax(logits.to(torch.float32), dim=-1)
        for b in range(input_ids.size(0)):
            L = int(attn[b].sum().item())
            k = int(min(int(k_list[i+b].item()), max(0, L - 1)))
            if k <= 0:
                sums.append(torch.tensor(0.0)); cnts.append(torch.tensor(0)); continue
            ids = input_ids[b]
            targ = ids[1:1+k]
            pred = logprobs[b, :k]
            lp = pred.gather(-1, targ.unsqueeze(-1)).squeeze(-1).sum()
            sums.append(lp.cpu()); cnts.append(torch.tensor(k))
    return torch.stack(sums), torch.stack(cnts)

def _score_from_sums_counts(lp_sum: torch.Tensor, tok_count: torch.Tensor, mode: str, alpha: float) -> torch.Tensor:
    if mode == "mean":
        denom = torch.clamp(tok_count.to(torch.float32), min=1.0)
        return lp_sum.to(torch.float32) / denom
    elif mode == "lp_alpha":
        return lp_sum.to(torch.float32) + (alpha * tok_count.to(torch.float32))
    else:
        return lp_sum.to(torch.float32)

# =========================
# Model loader
# =========================
def load_causal_lm(model_id: str, token: Optional[str], dtype=DTYPE, device=DEVICE):
    tok = AutoTokenizer.from_pretrained(model_id, token=token, use_fast=True, trust_remote_code=True)
    if tok.pad_token_id is None:
        tok.pad_token = tok.eos_token
    model = AutoModelForCausalLM.from_pretrained(model_id, token=token, torch_dtype=dtype, trust_remote_code=True)
    model = model.to(device); model.eval()
    return tok, model

# =========================
# PRIOR cache (pairwise)
# =========================
@torch.no_grad()
def cache_pairwise_scores(
    records: List[Dict],
    policy_model, ref_model,
    policy_tok, ref_tok=None,
    k_pair_cap: Optional[int] = None,
    min_tok_required: int = 0,
    compute_uncond: bool = False
):
    prompts_raw = [r["prompt"] for r in records]
    resps_ch    = [r["chosen"] for r in records]
    resps_rj    = [r["rejected"] for r in records]
    prompt_texts = render_prompt_list(policy_tok, prompts_raw)

    # Nominal counts (per tokenizer)
    pol_cnt_ch_nom = _sequence_token_counts_by_tokenizer(policy_tok, prompt_texts, resps_ch)
    pol_cnt_rj_nom = _sequence_token_counts_by_tokenizer(policy_tok, prompt_texts, resps_rj)
    pol_pair_min   = torch.minimum(pol_cnt_ch_nom, pol_cnt_rj_nom)

    if ref_model is not None:
        ref_cnt_ch_nom = _sequence_token_counts_by_tokenizer(ref_tok, prompt_texts, resps_ch)
        ref_cnt_rj_nom = _sequence_token_counts_by_tokenizer(ref_tok, prompt_texts, resps_rj)
        ref_pair_min   = torch.minimum(ref_cnt_ch_nom, ref_cnt_rj_nom)
    else:
        ref_cnt_ch_nom = pol_cnt_ch_nom.clone(); ref_cnt_rj_nom = pol_cnt_rj_nom.clone()
        ref_pair_min   = pol_pair_min.clone()

    # Shared per-pair k (equalize across sides/models) + cap
    if EQUALIZE_PAIR_TOKENS:
        k_shared = torch.minimum(pol_pair_min, ref_pair_min) if (EQUALIZE_CROSS_MODEL and ref_model is not None) else pol_pair_min
    else:
        k_shared = pol_pair_min
    if k_pair_cap is not None and k_pair_cap > 0:
        k_shared = torch.clamp(k_shared, max=int(k_pair_cap))

    # Score policy with shared k
    pol_sum_ch, pol_cnt_ch = _sequence_logprob_stats_text_with_k(policy_model, policy_tok, prompt_texts, resps_ch, k_shared)
    pol_sum_rj, pol_cnt_rj = _sequence_logprob_stats_text_with_k(policy_model, policy_tok, prompt_texts, resps_rj, k_shared)

    # Score reference with same k (using its OWN tokenizer!)
    if ref_model is None:
        ref_sum_ch = torch.zeros_like(pol_sum_ch); ref_cnt_ch = torch.ones_like(pol_cnt_ch)
        ref_sum_rj = torch.zeros_like(pol_sum_rj); ref_cnt_rj = torch.ones_like(pol_cnt_rj)
    else:
        ref_sum_ch, ref_cnt_ch = _sequence_logprob_stats_text_with_k(ref_model, ref_tok, prompt_texts, resps_ch, k_shared)
        ref_sum_rj, ref_cnt_rj = _sequence_logprob_stats_text_with_k(ref_model, ref_tok, prompt_texts, resps_rj, k_shared)

    # Min-token filter
    if min_tok_required and min_tok_required > 0:
        keep = (pol_cnt_ch >= min_tok_required) & (pol_cnt_rj >= min_tok_required)
        if ref_model is not None:
            keep = keep & (ref_cnt_ch >= min_tok_required) & (ref_cnt_rj >= min_tok_required)
    else:
        keep = torch.ones_like(pol_cnt_ch, dtype=torch.bool)

    valid_ch = (pol_cnt_ch > 0) & (ref_cnt_ch > 0)
    valid_rj = (pol_cnt_rj > 0) & (ref_cnt_rj > 0)
    valid = keep & valid_ch & valid_rj
    idx = [i for i in range(len(records)) if valid[i].item()]
    dropped = len(records) - len(idx)

    if not idx:
        raise RuntimeError("No valid pairs after scoring/filtering.")
    t = torch.tensor(idx, dtype=torch.long)

    pol_sum_ch, pol_cnt_ch = pol_sum_ch[t], pol_cnt_ch[t]
    pol_sum_rj, pol_cnt_rj = pol_sum_rj[t], pol_cnt_rj[t]
    ref_sum_ch, ref_cnt_ch = ref_sum_ch[t], ref_cnt_ch[t]
    ref_sum_rj, ref_cnt_rj = ref_sum_rj[t], ref_cnt_rj[t]

    return (pol_sum_ch, pol_cnt_ch, pol_sum_rj, pol_cnt_rj,
            ref_sum_ch, ref_cnt_ch, ref_sum_rj, ref_cnt_rj,
            len(idx), dropped)

def make_scores_from_components(sum_t, cnt_t, mode, alpha):
    if mode == "lp_alpha": return sum_t.to(torch.float32) + alpha * cnt_t.to(torch.float32)
    if mode == "mean":     return sum_t.to(torch.float32) / torch.clamp(cnt_t.to(torch.float32), min=1.0)
    return sum_t.to(torch.float32)

def make_base_s_from_components(pol_sum_ch, pol_cnt_ch, pol_sum_rj, pol_cnt_rj,
                                ref_sum_ch, ref_cnt_ch, ref_sum_rj, ref_cnt_rj,
                                lambda_ref, mode, alpha,
                                uncond_sum_ch=None, uncond_cnt_ch=None, uncond_sum_rj=None, uncond_cnt_rj=None, uncond_weight: float = 0.0):
    pol_score_ch = make_scores_from_components(pol_sum_ch, pol_cnt_ch, mode, alpha)
    pol_score_rj = make_scores_from_components(pol_sum_rj, pol_cnt_rj, mode, alpha)
    ref_score_ch = make_scores_from_components(ref_sum_ch, ref_cnt_ch, mode, alpha)
    ref_score_rj = make_scores_from_components(ref_sum_rj, ref_cnt_rj, mode, alpha)
    chosen  = pol_score_ch - lambda_ref * ref_score_ch
    rejected= pol_score_rj - lambda_ref * ref_score_rj
    if (uncond_sum_ch is not None) and (uncond_cnt_ch is not None):
        u_ch = make_scores_from_components(uncond_sum_ch, uncond_cnt_ch, mode, alpha)
        u_rj = make_scores_from_components(uncond_sum_rj, uncond_cnt_rj, mode, alpha)
        chosen  = chosen  - (uncond_weight * u_ch)
        rejected= rejected - (uncond_weight * u_rj)
    return torch.cat([chosen.to(torch.float32), rejected.to(torch.float32)], dim=0)

# =========================
# BP (JJ) pieces — diagnostic only
# =========================
@dataclass
class BPParams:
    beta: float = BETA
    delta: float = DELTA
    tau: float = TAU
    gamma: float = GAMMA
    jj_steps: int = JJ_INNER_STEPS

def jj_lambda(xi: float) -> float:
    return (1.0/8.0) if xi < 1e-8 else math.tanh(xi/2.0)/(4.0*xi)

def bp_unary_posterior(mu_prior: float, b: int, params: BPParams) -> Tuple[float, float]:
    tau2 = params.tau ** 2
    gamma_tilde = params.gamma * (2*b - 1)
    mu_hat, tau2_hat = mu_prior, tau2
    for _ in range(params.jj_steps):
        xi = abs(params.gamma) * math.sqrt(mu_hat*mu_hat + tau2_hat)
        lam = jj_lambda(xi)
        Lambda = (1.0/tau2) + 2.0*lam
        eta    = (mu_prior/tau2) + 0.5 * gamma_tilde
        mu_hat = eta / Lambda
        tau2_hat = 1.0 / Lambda
    return mu_hat, tau2_hat

def evaluate_prior_wr(base_s: torch.Tensor, N: int, beta: float = 1.0, delta: float = 0.0) -> float:
    mu = (beta * base_s) - delta
    mu = mu.tolist()
    correct = sum(1 for i in range(N) if mu[i] > mu[i+N])
    return 100.0 * correct / N

def binom_ci_95(pct: float, N: int) -> Tuple[float, float]:
    p = pct / 100.0; se = math.sqrt(p * (1 - p) / max(N, 1))
    return max(0.0, 100*(p-1.96*se)), min(100.0, 100*(p+1.96*se))

# =========================
# Main
# =========================
def main():
    random.seed(SEED); torch.manual_seed(SEED)
    print(f"Device: {DEVICE} | Dtype: {DTYPE}")

    print("Loading policy model ...")
    tok_policy, policy = load_causal_lm(MODEL_NAME, HF_TOKEN)
    tok_ref, ref = (None, None)
    if REF_MODEL_NAME:
        print("Loading reference model ...")
        tok_ref, ref = load_causal_lm(REF_MODEL_NAME, HF_TOKEN)

    # Load pairs
    if USE_PREFERENCE:
        print("Loading HelpSteer2 (preference pairs) ...")
        adapted = load_hs2_preference_pairs(PREF_MIN_STRENGTH, PREF_KEEP_SPLIT)
        print(f"Preference pairs loaded: {len(adapted)}")
    else:
        print("Loading HelpSteer2 (ratings) and building pairs ...")
        adapted = load_hs2_ratings_pairs(MIN_GAP, RATINGS_SPLIT)
        print(f"Rating-derived pairs: {len(adapted)}")
    if not adapted: raise RuntimeError("No pairs found.")

    # Deterministic split
    random.Random(SEED).shuffle(adapted)
    if STRICT_COUNTS and len(adapted) == EXPECTED_TOTAL:
        n_train = EXPECTED_TRAIN
    else:
        n_train = max(1, int(TRAIN_FRAC * len(adapted)))
        if STRICT_COUNTS and len(adapted) != EXPECTED_TOTAL:
            print(f"[WARN] Found {len(adapted)} pairs, expected {EXPECTED_TOTAL}; using {100*TRAIN_FRAC:.1f}% split.")
    train_recs, test_recs = adapted[:n_train], adapted[n_train:]
    print(f"Split sizes: TRAIN={len(train_recs)} | TEST={len(test_recs)}")

    # ----- PRIOR-only calibration on TRAIN (with cap sweep + optional uncond penalty)
    best_overall = (-1.0, None)
    for cap in (K_PAIR_CAP_GRID if DO_CAP_SWEEP else (None,)):
        print(f"\n[CAP={cap}] Caching TRAIN prior components ...")
        (pol_sum_ch_tr, pol_cnt_ch_tr, pol_sum_rj_tr, pol_cnt_rj_tr,
         ref_sum_ch_tr, ref_cnt_ch_tr, ref_sum_rj_tr, ref_cnt_rj_tr,
         N_tr, dropped_tr) = cache_pairwise_scores(
            train_recs, policy, ref, tok_policy, tok_ref,
            k_pair_cap=cap, min_tok_required=(MIN_TOK_TRAIN if APPLY_MIN_TOK_FILTER else 0),
            compute_uncond=False
        )
        print(f"TRAIN kept N={N_tr} (dropped {dropped_tr})")

        # Optional: unconditional penalty components (from reference if available else policy)
        if USE_UNCOND_PENALTY:
            print("[CAP={}] Scoring unconditional (response-only) components ...".format(cap))
            # reuse counts from policy side for k's
            u_model, u_tok = (ref, tok_ref) if ref is not None else (policy, tok_policy)
            u_sum_ch_tr, u_cnt_ch_tr = _response_only_scores(u_model, u_tok, [r['chosen'] for r in train_recs][:N_tr], pol_cnt_ch_tr)
            u_sum_rj_tr, u_cnt_rj_tr = _response_only_scores(u_model, u_tok, [r['rejected'] for r in train_recs][:N_tr], pol_cnt_rj_tr)
        else:
            u_sum_ch_tr = u_cnt_ch_tr = u_sum_rj_tr = u_cnt_rj_tr = None

        tried = []
        alpha_grid = LP_ALPHA_GRID if SCORING_MODE == "lp_alpha" else (0.0,)
        for lam, alpha, eta in itertools.product(LAMBDA_REF_GRID, alpha_grid, (UNCOND_WEIGHT_GRID if USE_UNCOND_PENALTY else (0.0,))):
            base_s = make_base_s_from_components(
                pol_sum_ch_tr, pol_cnt_ch_tr, pol_sum_rj_tr, pol_cnt_rj_tr,
                ref_sum_ch_tr, ref_cnt_ch_tr, ref_sum_rj_tr, ref_cnt_rj_tr,
                lambda_ref=lam, mode=SCORING_MODE, alpha=alpha,
                uncond_sum_ch=u_sum_ch_tr, uncond_cnt_ch=u_cnt_ch_tr,
                uncond_sum_rj=u_sum_rj_tr, uncond_cnt_rj=u_cnt_rj_tr, uncond_weight=eta
            )
            wr = evaluate_prior_wr(base_s, N_tr)
            lo, hi = binom_ci_95(wr, N_tr)
            tried.append((wr, lam, alpha, eta))
            print(f"[TRAIN-PRIOR cap={cap}] WR={wr:.2f}% [{lo:.1f}, {hi:.1f}] | λ={lam:.2f} α={alpha:+.2f} η={eta:.2f}")

        wr_best, lam_best, alpha_best, eta_best = max(tried, key=lambda t: t[0])
        lo_b, hi_b = binom_ci_95(wr_best, N_tr)
        print(f"[SELECT cap={cap}] Best TRAIN-PRIOR WR={wr_best:.2f}% [{lo_b:.1f}, {hi_b:.1f}] "
              f"| λ={lam_best:.2f} α={alpha_best:+.2f} η={eta_best:.2f}")

        if wr_best > best_overall[0]:
            best_overall = (wr_best, {"cap": cap, "lambda": lam_best, "alpha": alpha_best, "eta": eta_best, "N": N_tr})

    sel = best_overall[1]
    print(f"\n[FINAL TRAIN PICK] cap={sel['cap']} λ={sel['lambda']:.2f} α={sel['alpha']:+.2f} η={sel['eta']:.2f} "
          f"| TRAIN-PRIOR WR={best_overall[0]:.2f}% on N={sel['N']}")

    # ----- TEST (label-free)
    print("\nScoring TEST (PRIOR-ONLY) with selected settings ...")
    (pol_sum_ch_te, pol_cnt_ch_te, pol_sum_rj_te, pol_cnt_rj_te,
     ref_sum_ch_te, ref_cnt_ch_te, ref_sum_rj_te, ref_cnt_rj_te,
     N_te, dropped_te) = cache_pairwise_scores(
        test_recs if len(test_recs) > 0 else train_recs,
        policy, ref, tok_policy, tok_ref,
        k_pair_cap=sel['cap'],
        min_tok_required=(MIN_TOK_TEST if APPLY_MIN_TOK_FILTER else 0),
        compute_uncond=False
    )
    print(f"TEST kept N={N_te} (dropped {dropped_te})")

    if USE_UNCOND_PENALTY:
        u_model, u_tok = (ref, tok_ref) if ref is not None else (policy, tok_policy)
        u_sum_ch_te, u_cnt_ch_te = _response_only_scores(u_model, u_tok, [r['chosen'] for r in test_recs][:N_te], pol_cnt_ch_te)
        u_sum_rj_te, u_cnt_rj_te = _response_only_scores(u_model, u_tok, [r['rejected'] for r in test_recs][:N_te], pol_cnt_rj_te)
    else:
        u_sum_ch_te = u_cnt_ch_te = u_sum_rj_te = u_cnt_rj_te = None

    base_s_te = make_base_s_from_components(
        pol_sum_ch_te, pol_cnt_ch_te, pol_sum_rj_te, pol_cnt_rj_te,
        ref_sum_ch_te, ref_cnt_ch_te, ref_sum_rj_te, ref_cnt_rj_te,
        lambda_ref=sel['lambda'], mode=SCORING_MODE, alpha=sel['alpha'],
        uncond_sum_ch=u_sum_ch_te, uncond_cnt_ch=u_cnt_ch_te,
        uncond_sum_rj=u_sum_rj_te, uncond_cnt_rj=u_cnt_rj_te, uncond_weight=sel['eta']
    )
    wr_te = evaluate_prior_wr(base_s_te, N_te)
    lo_te, hi_te = binom_ci_95(wr_te, N_te)
    print(f"\n[BP-LLM EVAL (TEST, PRIOR-ONLY)] WR={wr_te:.2f}% [{lo_te:.1f}, {hi_te:.1f}] "
          f"| λ={sel['lambda']:.2f} α={sel['alpha']:+.2f} η={sel['eta']:.2f} "
          f"| mode={SCORING_MODE} cap={sel['cap']} N={N_te}")

if __name__ == "__main__":
    main()


Device: cuda | Dtype: torch.bfloat16
Loading policy model ...


tokenizer_config.json: 0.00B [00:00, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

config.json:   0%|          | 0.00/663 [00:00<?, ?B/s]

`torch_dtype` is deprecated! Use `dtype` instead!


model.safetensors.index.json: 0.00B [00:00, ?B/s]

Fetching 4 files:   0%|          | 0/4 [00:00<?, ?it/s]

model-00004-of-00004.safetensors:   0%|          | 0.00/3.56G [00:00<?, ?B/s]

model-00002-of-00004.safetensors:   0%|          | 0.00/3.86G [00:00<?, ?B/s]

model-00001-of-00004.safetensors:   0%|          | 0.00/3.95G [00:00<?, ?B/s]

model-00003-of-00004.safetensors:   0%|          | 0.00/3.86G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/243 [00:00<?, ?B/s]

Loading reference model ...


tokenizer_config.json: 0.00B [00:00, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

config.json:   0%|          | 0.00/686 [00:00<?, ?B/s]

model.safetensors.index.json: 0.00B [00:00, ?B/s]

Fetching 4 files:   0%|          | 0/4 [00:00<?, ?it/s]

model-00002-of-00004.safetensors:   0%|          | 0.00/3.86G [00:00<?, ?B/s]

model-00004-of-00004.safetensors:   0%|          | 0.00/3.56G [00:00<?, ?B/s]

model-00003-of-00004.safetensors:   0%|          | 0.00/3.86G [00:00<?, ?B/s]

model-00001-of-00004.safetensors:   0%|          | 0.00/3.95G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/138 [00:00<?, ?B/s]

Loading HelpSteer2 (preference pairs) ...


README.md: 0.00B [00:00, ?B/s]

preference/preference.jsonl.gz:   0%|          | 0.00/15.3M [00:00<?, ?B/s]

Generating train split: 0 examples [00:00, ? examples/s]

Preference pairs loaded: 7117
Split sizes: TRAIN=5693 | TEST=1424

[CAP=128] Caching TRAIN prior components ...
TRAIN kept N=5597 (dropped 96)
[CAP=128] Scoring unconditional (response-only) components ...
[TRAIN-PRIOR cap=128] WR=56.08% [54.8, 57.4] | λ=0.00 α=-1.50 η=0.00
[TRAIN-PRIOR cap=128] WR=56.99% [55.7, 58.3] | λ=0.00 α=-1.50 η=0.15
[TRAIN-PRIOR cap=128] WR=56.91% [55.6, 58.2] | λ=0.00 α=-1.50 η=0.25
[TRAIN-PRIOR cap=128] WR=56.96% [55.7, 58.3] | λ=0.00 α=-1.50 η=0.40
[TRAIN-PRIOR cap=128] WR=56.08% [54.8, 57.4] | λ=0.00 α=-1.25 η=0.00
[TRAIN-PRIOR cap=128] WR=56.96% [55.7, 58.3] | λ=0.00 α=-1.25 η=0.15
[TRAIN-PRIOR cap=128] WR=57.03% [55.7, 58.3] | λ=0.00 α=-1.25 η=0.25
[TRAIN-PRIOR cap=128] WR=56.96% [55.7, 58.3] | λ=0.00 α=-1.25 η=0.40
[TRAIN-PRIOR cap=128] WR=56.08% [54.8, 57.4] | λ=0.00 α=-1.00 η=0.00
[TRAIN-PRIOR cap=128] WR=57.05% [55.8, 58.3] | λ=0.00 α=-1.00 η=0.15
[TRAIN-PRIOR cap=128] WR=57.05% [55.8, 58.3] | λ=0.00 α=-1.00 η=0.25
[TRAIN-PRIOR cap=128] WR=56.71% [55