In [None]:
# bp_llm_ultrafeedback_eval.py
# Evaluate BP-LLM (unary BP classifier with JJ bound) win rate on UltraFeedback
# with Llama-3.2-3B policy priors.

import math
import dataclasses
from dataclasses import dataclass
from typing import List, Tuple, Dict, Optional

import torch
from torch.nn.functional import log_softmax
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM

# -----------------------------
# Config
# -----------------------------

HF_TOKEN = " "  # your Read token


In [None]:
# BCO with Llama-3.x on Capybara-Preferences

# bp_llm_capybara_eval.py
# Capybara-Preferences BCO baseline (leakage-safe)
#
# What this does:
#   • Uses argilla/Capybara-Preferences (or the *-Filtered* variant) which provides
#     multi-turn conversations (`messages`), multiple candidate generations (`generations`)
#     and UltraFeedback `ratings`.
#   • Builds (prompt, chosen, rejected) by top-vs-bottom rating with a minimum gap.
#   • Splits into TRAIN/TEST (tune β and δ on TRAIN; report on TEST).
#   • δ estimated by BCO on TRAIN: δ = 0.5 ( E_pos[β s] + E_neg[β s] ).
#   • Scoring is chat-template aware and length-normalized (mean log-prob),
#     limited to SCORE_MAX_GEN_TOKENS response tokens.
#   • Any pair with a truncated side is dropped (don’t bias as zeros).
#   • TEST evaluation is fully label-free (BCO doesn’t need labels).
#
# Notes:
#   • Optionally set a reference model (REF_MODEL_NAME) to use log π − log π_ref.
#   • To reduce obvious hallucination/URL noise, you can switch to the
#     "argilla/Capybara-Preferences-Filtered" dataset.
#   • This file is BCO-only for apples-to-apples comparison with your BP-LLM script.

import os
import math
import random
from dataclasses import dataclass
from typing import List, Dict, Optional, Tuple, Iterable, Union

import torch
from torch.nn.functional import log_softmax
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM

# =============================================================================
# Config
# =============================================================================
#HF_TOKEN = os.environ.get("HF_TOKEN")

# Use an *Instruct* model for the policy; optional Base as reference
MODEL_NAME      = "meta-llama/Llama-3.1-8B-Instruct"
REF_MODEL_NAME  = "meta-llama/Llama-3.1-8B"   # None to disable reference subtraction

# --- DATASET: switch to Capybara-Preferences ---
DATASET         = "argilla/Capybara-Preferences"   # or "argilla/Capybara-Preferences-Filtered"
DATASET_CONFIG  = None
SPLIT           = "train[:30%]"  # iterate small; use "train" later

# Inference/scoring limits
MAX_INPUT_TOKENS     = 1024
MAX_GEN_TOKENS       = 512
SCORE_MAX_GEN_TOKENS = 256
BATCH_SIZE           = 4

# Chat-templating + scoring
FORCE_CHAT_TEMPLATE  = True
SCORING_MODE         = "mean"      # {"mean", "sum", "lp_alpha"}
LENGTH_PENALTY_ALPHA = 0.0         # only for SCORING_MODE == "lp_alpha"

# Pair filtering
MIN_GAP              = 0.0         # minimum rating gap to keep a pair

# Train / test split
RNG_SEED   = 42
TRAIN_FRAC = 0.8

# Grid tuning (TRAIN only)
DO_TUNE = True
BETAS   = [0.5, 1.0, 2.0]
DELTAS  = ["bco", 0.0]  # try BCO shift and a constant zero shift

# Device / dtype
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE  = torch.bfloat16 if torch.cuda.is_available() else torch.float32
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.set_grad_enabled(False)


# =============================================================================
# Capybara helpers
# =============================================================================
def _to_float(x) -> Optional[float]:
    try:
        return float(x)
    except Exception:
        return None

def _last_role_is_user(messages: List[Dict[str, str]]) -> bool:
    try:
        return isinstance(messages, list) and len(messages) > 0 and messages[-1].get("role") == "user"
    except Exception:
        return False

def _ensure_assistant_last(msgs):
    """Return (history_without_last, last_assistant_content) or (None, None) if malformed."""
    if not isinstance(msgs, list) or len(msgs) == 0:
        return None, None
    last = msgs[-1]
    if not isinstance(last, dict) or last.get("role") != "assistant":
        return None, None
    resp = (last.get("content") or "").strip()
    if not resp:
        return None, None
    history = [m for m in msgs[:-1] if isinstance(m, dict) and "role" in m and "content" in m]
    # ensure history ends with a user turn for clean prompting (optional)
    if len(history) == 0 or history[-1].get("role") != "user":
        # try to trim trailing assistant turns if any (defensive)
        while len(history) and history[-1].get("role") == "assistant":
            history.pop()
        if len(history) == 0 or history[-1].get("role") != "user":
            return None, None
    return history, resp

def adapt_capybara_preferences(record: Dict, min_gap: float = 0.0) -> Optional[Dict]:
    """
    Build (prompt, chosen, rejected) from the *published* Capybara-Preferences schema:
      - chosen, rejected: list[{role, content}, ...] (full convo ending with assistant)
      - chosen_rating, rejected_rating: ints
    """
    chosen = record.get("chosen")
    rejected = record.get("rejected")
    cr = record.get("chosen_rating")
    rr = record.get("rejected_rating")

    # basic checks
    if not (isinstance(chosen, list) and isinstance(rejected, list)):
        return None
    try:
        cr = float(cr); rr = float(rr)
    except Exception:
        return None

    # require a minimum gap if desired
    if (abs(cr - rr) < min_gap):
        return None

    # split each side into (prompt history, last assistant response)
    ch_hist, ch_resp = _ensure_assistant_last(chosen)
    rj_hist, rj_resp = _ensure_assistant_last(rejected)
    if ch_hist is None or rj_hist is None:
        return None

    # Use the *longest common prefix* of histories as the prompt (they should usually match).
    m = min(len(ch_hist), len(rj_hist))
    k = 0
    while k < m and (ch_hist[k].get("role") == rj_hist[k].get("role")) \
                 and (ch_hist[k].get("content") == rj_hist[k].get("content")):
        k += 1
    prompt = ch_hist[:k]  # common history up to the last matching turn
    if not prompt or prompt[-1].get("role") != "user":
        # fall back: use the chosen history without last assistant
        prompt = ch_hist

    return {
        "prompt": prompt,          # list[{role, content}] (ends with user)
        "chosen": ch_resp,         # string (assistant)
        "rejected": rj_resp,       # string (assistant)
    }


# =============================================================================
# Model
# =============================================================================
def load_causal_lm(model_id: str, token: Optional[str], dtype=DTYPE, device=DEVICE):
    tok = AutoTokenizer.from_pretrained(model_id, token=token, use_fast=True)
    if tok.pad_token_id is None:
        tok.pad_token = tok.eos_token
    model = AutoModelForCausalLM.from_pretrained(model_id, token=token, torch_dtype=dtype)
    model = model.to(device)
    model.eval()
    return tok, model


# =============================================================================
# Encoding + scoring
# =============================================================================
def _apply_chat_prefix(
    tokenizer,
    prompt_or_messages: Union[str, List[Dict[str, str]]]
) -> Optional[str]:
    """
    If FORCE_CHAT_TEMPLATE and tokenizer supports it:
      • For list[{'role','content'}] -> apply chat template directly (multi-turn)
      • For str -> wrap as single user turn, then add generation prompt
    Returns a textual prefix (no tokens) or None if templating not available.
    """
    if not FORCE_CHAT_TEMPLATE or not hasattr(tokenizer, "apply_chat_template"):
        return None
    try:
        if isinstance(prompt_or_messages, list):
            msgs = prompt_or_messages
        else:
            msgs = [
                {"role": "user", "content": str(prompt_or_messages).strip()},
                {"role": "assistant", "content": ""},
            ]
            # the trailing empty assistant is just to emphasize gen prompt;
            # apply_chat_template(add_generation_prompt=True) will handle it anyway
            msgs = msgs[:-1]
        return tokenizer.apply_chat_template(
            msgs, tokenize=False, add_generation_prompt=True
        )
    except Exception:
        return None

def encode_pair(tokenizer, prompt_or_messages: Union[str, List[Dict[str, str]]], response: str):
    """
    Returns (input_ids, attention_mask, prompt_len) for concatenated prompt+response.
    Uses chat template when available and supports multi-turn 'messages'.
    """
    chat_prefix = _apply_chat_prefix(tokenizer, prompt_or_messages)
    if chat_prefix is not None:
        prompt_text = chat_prefix
        full_text   = chat_prefix + response.strip()
    else:
        if isinstance(prompt_or_messages, list):
            # Fallback: flatten messages into a plain text prompt
            flat = []
            for m in prompt_or_messages:
                role = m.get("role", "user")
                content = (m.get("content") or "").strip()
                if content:
                    flat.append(f"{role.upper()}: {content}")
            prompt_text = "\n".join(flat)
        else:
            prompt_text = str(prompt_or_messages).strip()
        sep = "" if prompt_text.endswith("\n") else "\n"
        full_text = prompt_text + sep + response.strip()

    toks_full = tokenizer(
        full_text,
        truncation=True,
        max_length=min(MAX_INPUT_TOKENS + MAX_GEN_TOKENS, tokenizer.model_max_length),
        return_tensors="pt",
        add_special_tokens=True,
    )
    toks_prompt = tokenizer(
        prompt_text,
        truncation=True,
        max_length=MAX_INPUT_TOKENS,
        return_tensors="pt",
        add_special_tokens=True,
    )
    prompt_len = toks_prompt["input_ids"].shape[-1]
    return toks_full, prompt_len

@torch.no_grad()
def sequence_logprob_stats(
    model,
    tokenizer,
    prompts: List[Union[str, List[Dict[str, str]]]],
    responses: List[str],
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Returns two tensors of shape [B]: (lp_sum, tok_count)
    – lp_sum is the sum of log-probs over scored response tokens
    – tok_count is the number of response tokens scored
    If a response is fully truncated (no tokens past the prompt), tok_count=0
    """
    sums, counts = [], []
    for i in range(0, len(prompts), BATCH_SIZE):
        p_batch = prompts[i:i+BATCH_SIZE]
        r_batch = responses[i:i+BATCH_SIZE]

        batch_inputs, batch_prompt_lens = [], []
        for p, r in zip(p_batch, r_batch):
            toks, p_len = encode_pair(tokenizer, p, r)
            batch_inputs.append(toks)
            batch_prompt_lens.append(p_len)

        pad_id = tokenizer.pad_token_id or 0
        input_ids = torch.nn.utils.rnn.pad_sequence(
            [bi["input_ids"].squeeze(0) for bi in batch_inputs],
            batch_first=True, padding_value=pad_id
        )
        attention_mask = torch.nn.utils.rnn.pad_sequence(
            [bi["attention_mask"].squeeze(0) for bi in batch_inputs],
            batch_first=True, padding_value=0
        )
        input_ids = input_ids.to(DEVICE)
        attention_mask = attention_mask.to(DEVICE)

        logits = model(input_ids=input_ids, attention_mask=attention_mask).logits
        logprobs = log_softmax(logits.to(torch.float32), dim=-1)

        for b in range(input_ids.size(0)):
            p_len = batch_prompt_lens[b]
            ids   = input_ids[b]
            masks = attention_mask[b]
            L = int(masks.sum().item())
            if p_len >= L:
                sums.append(torch.tensor(0.0)); counts.append(torch.tensor(0)); continue
            end  = min(L, p_len + SCORE_MAX_GEN_TOKENS)
            targ = ids[p_len:end]
            pred = logprobs[b, p_len-1:end-1]
            lp_sum = pred.gather(-1, targ.unsqueeze(-1)).squeeze(-1).sum()
            sums.append(lp_sum.cpu())
            counts.append(torch.tensor(targ.numel()))
    return torch.stack(sums, dim=0), torch.stack(counts, dim=0)

def _score_from_stats(lp_sum: torch.Tensor, tok_count: torch.Tensor) -> torch.Tensor:
    if SCORING_MODE == "mean":
        denom = torch.clamp(tok_count.to(torch.float32), min=1.0)
        return lp_sum.to(torch.float32) / denom
    elif SCORING_MODE == "lp_alpha":
        return lp_sum.to(torch.float32) + (LENGTH_PENALTY_ALPHA * tok_count.to(torch.float32))
    else:  # "sum"
        return lp_sum.to(torch.float32)


# =============================================================================
# BCO pieces
# =============================================================================
@dataclass
class BCOParams:
    beta: float  = 1.0
    delta: float = 0.0

@torch.no_grad()
def cache_policy_and_ref_scores(records: List[Dict], model, ref_model, tokenizer) -> Tuple[torch.Tensor, int]:
    """
    Build base scores s = (policy - ref) for valid pairs only.
    A pair is valid if BOTH sides have >= 1 scored response token
    (i.e., not truncated past the prompt) for BOTH models if ref is used.
    Returns:
        base_s: tensor of shape [2N_valid] (first N are chosen, next N are rejected)
        N_valid: number of valid pairs.
    """
    # prompts may be list[dict] (chat messages) here
    prompts   = [r["prompt"] for r in records for _ in (0, 1)]
    responses = [r["chosen"] for r in records] + [r["rejected"] for r in records]

    pol_sum, pol_cnt = sequence_logprob_stats(model, tokenizer, prompts, responses)
    pol_score = _score_from_stats(pol_sum, pol_cnt)

    if ref_model is None:
        ref_score = torch.zeros_like(pol_score)
        ref_cnt   = torch.ones_like(pol_cnt)  # pretend "valid"
    else:
        ref_sum, ref_cnt = sequence_logprob_stats(ref_model, tokenizer, prompts, responses)
        ref_score = _score_from_stats(ref_sum, ref_cnt)

    valid_mask = (pol_cnt > 0) & (ref_cnt > 0)
    # Keep only pairs where BOTH chosen and rejected are valid
    valid_pairs = []
    for i in range(len(records)):
        i_ch = i
        i_rj = i + len(records)
        if valid_mask[i_ch].item() and valid_mask[i_rj].item():
            valid_pairs.append(i)

    if len(valid_pairs) == 0:
        raise RuntimeError("No valid (non-truncated) pairs after scoring. Increase limits or lower MIN_GAP.")

    chosen_vals, rejected_vals = [], []
    for i in valid_pairs:
        i_ch = i
        i_rj = i + len(records)
        chosen_vals.append(float(pol_score[i_ch] - ref_score[i_ch]))
        rejected_vals.append(float(pol_score[i_rj] - ref_score[i_rj]))

    base_s = torch.tensor(chosen_vals + rejected_vals, dtype=torch.float32)
    N = len(valid_pairs)
    return base_s, N

def estimate_delta_bco(base_s: torch.Tensor, beta: float, N: int) -> float:
    """δ_BCO = 0.5 ( E_pos[β s] + E_neg[β s] )."""
    r_pos = beta * base_s[:N]
    r_neg = beta * base_s[N:]
    return 0.5 * (float(r_pos.mean()) + float(r_neg.mean()))

def bco_win_rate(base_s: torch.Tensor, N: int, params: BCOParams) -> float:
    """
    BCO decision rule (label-free): predict chosen if β s_pos - δ > β s_neg - δ  ⇔  s_pos > s_neg.
    δ cancels out for pairwise comparison, but we still compute it for completeness/reporting.
    """
    mu = (params.beta * base_s) - params.delta
    mu = mu.tolist()
    correct = 0
    for i in range(N):
        if mu[i] > mu[i + N]:
            correct += 1
    return 100.0 * correct / N

def binom_ci_95(pct: float, N: int) -> Tuple[float, float]:
    p = pct / 100.0
    se = math.sqrt(p * (1 - p) / max(N, 1))
    lo = max(0.0, 100.0 * (p - 1.96 * se))
    hi = min(100.0, 100.0 * (p + 1.96 * se))
    return lo, hi

def grid_search_train_bco(
    records: List[Dict], model, ref_model, tokenizer,
    betas: Iterable[float], deltas: Iterable[Optional[str]]
):
    """
    Tune β and δ on TRAIN only. δ can be 'bco' (estimated) or a numeric constant.
    """
    base_s, N = cache_policy_and_ref_scores(records, model, ref_model, tokenizer)

    tried = []
    best = (-1.0, None)  # (wr, params)
    for beta in betas:
        for delta_spec in deltas:
            if delta_spec == "bco":
                delta = estimate_delta_bco(base_s, beta, N)
            else:
                delta = float(delta_spec)
            params = BCOParams(beta=beta, delta=delta)
            wr = bco_win_rate(base_s, N, params)
            lo, hi = binom_ci_95(wr, N)
            tried.append((wr, lo, hi, params))
            if wr > best[0]:
                best = (wr, params)

            print(f"[TUNE/TRAIN BCO] WR={wr:.2f}% [{lo:.1f}, {hi:.1f}] "
                  f"| beta={beta} delta={'bco' if delta_spec=='bco' else f'{delta:.4f}'}")

    best_wr, best_params = best
    lo, hi = binom_ci_95(best_wr, N)
    print(f"\n[BCO Tuning (TRAIN)] Best WR={best_wr:.2f}% [{lo:.1f}, {hi:.1f}] with "
          f"beta={best_params.beta} delta={best_params.delta:.4f}")
    return tried, best_params


# =============================================================================
# Main
# =============================================================================
def main():
    print(f"Device: {DEVICE} | Dtype: {DTYPE}")
    print("Loading policy model ...")
    tok, policy = load_causal_lm(MODEL_NAME, HF_TOKEN)
    ref = None
    if REF_MODEL_NAME:
        print("Loading reference model ...")
        _, ref = load_causal_lm(REF_MODEL_NAME, HF_TOKEN)

    print(f"Loading dataset: {DATASET} [{SPLIT}] config={DATASET_CONFIG} ...")
    ds = load_dataset(DATASET, name=DATASET_CONFIG, split=SPLIT) if DATASET_CONFIG else load_dataset(DATASET, split=SPLIT)

    # Adapt to pairs with min rating gap
    adapted: List[Dict] = []
    for rec in ds:
        a = adapt_capybara_preferences(rec, min_gap=MIN_GAP)
        if a is not None:
            adapted.append(a)
    print(f"Adapted examples (rating gap ≥ {MIN_GAP}): {len(adapted)}")
    if not adapted:
        raise RuntimeError("No valid examples after adaptation. Lower MIN_GAP or inspect data.")

    # Train/Test split
    random.Random(RNG_SEED).shuffle(adapted)
    n_all   = len(adapted)
    n_train = max(1, int(TRAIN_FRAC * n_all))
    train_recs = adapted[:n_train]
    test_recs  = adapted[n_train:]
    print(f"Split sizes: TRAIN={len(train_recs)} | TEST={len(test_recs)}")

    # Cache TRAIN scores and baseline WR with default β=1, δ=0
    print("\nScoring TRAIN...")
    base_s_tr, N_tr = cache_policy_and_ref_scores(train_recs, policy, ref, tok)
    wr_tr_base = bco_win_rate(base_s_tr, N_tr, BCOParams(beta=1.0, delta=0.0))
    lo_trb, hi_trb = binom_ci_95(wr_tr_base, N_tr)
    print(f"[TRAIN BCO baseline] WR={wr_tr_base:.2f}% [{lo_trb:.1f}, {hi_trb:.1f}] on {N_tr} valid pairs")

    # Tune β, δ (TRAIN only)
    if DO_TUNE:
        print("\nTuning on TRAIN (BCO)...")
        _tries, best_params = grid_search_train_bco(
            train_recs, policy, ref, tok, betas=BETAS, deltas=DELTAS
        )
    else:
        # Optionally estimate δ via BCO at default β
        best_params = BCOParams(beta=1.0, delta=estimate_delta_bco(base_s_tr, 1.0, N_tr))

    # TEST evaluation (label-free)
    eval_set = test_recs if len(test_recs) > 0 else train_recs
    if len(test_recs) == 0:
        print("\nWARNING: TEST set is empty (small SPLIT/TRAIN_FRAC). Using TRAIN as a proxy sanity check.")

    print("\nScoring TEST...")
    base_s_te, N_te = cache_policy_and_ref_scores(eval_set, policy, ref, tok)
    wr_te = bco_win_rate(base_s_te, N_te, best_params)
    lo_te, hi_te = binom_ci_95(wr_te, N_te)
    print(f"[BCO EVAL (TEST, no labels)] WR={wr_te:.2f}% [{lo_te:.1f}, {hi_te:.1f}] "
          f"| beta={best_params.beta} delta={best_params.delta:.4f}")

if __name__ == "__main__":
    main()


Device: cuda | Dtype: torch.bfloat16
Loading policy model ...


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Loading reference model ...


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Loading dataset: argilla/Capybara-Preferences [train[:30%]] config=None ...
Adapted examples (rating gap ≥ 0.0): 4621
Split sizes: TRAIN=3696 | TEST=925

Scoring TRAIN...
[TRAIN BCO baseline] WR=52.65% [51.0, 54.3] on 3696 valid pairs

Tuning on TRAIN (BCO)...
[TUNE/TRAIN BCO] WR=52.65% [51.0, 54.3] | beta=0.5 delta=bco
[TUNE/TRAIN BCO] WR=52.65% [51.0, 54.3] | beta=0.5 delta=0.0000
[TUNE/TRAIN BCO] WR=52.65% [51.0, 54.3] | beta=1.0 delta=bco
[TUNE/TRAIN BCO] WR=52.65% [51.0, 54.3] | beta=1.0 delta=0.0000
[TUNE/TRAIN BCO] WR=52.65% [51.0, 54.3] | beta=2.0 delta=bco
[TUNE/TRAIN BCO] WR=52.65% [51.0, 54.3] | beta=2.0 delta=0.0000

[BCO Tuning (TRAIN)] Best WR=52.65% [51.0, 54.3] with beta=0.5 delta=-0.1190

Scoring TEST...
[BCO EVAL (TEST, no labels)] WR=53.08% [49.9, 56.3] | beta=0.5 delta=-0.1190


In [None]:
# bp_llm_capybara_eval.py
# Evaluate BP-LLM (unary JJ) win rate on argilla/Capybara-Preferences.
# Adds a Turbo outer loop (K_TURBO >= 1) and improves prior-only calibration:
#   • λ-ref sweep without re-scoring (cache sums+counts once)
#   • Optional length penalty α without re-scoring when SCORING_MODE='lp_alpha'
#   • Per-pair token equalization (same k_i for chosen/rejected per model)
#
# Fixes kept:
#   • Render chat ONCE with policy tokenizer; feed same text to BOTH models
#   • Use the correct tokenizer for the reference model
#   • Cap scored response tokens to reduce length bias

import os
import math
import itertools
from dataclasses import dataclass
from typing import List, Dict, Optional, Tuple, Iterable, Union

import torch
from torch.nn.functional import log_softmax
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM

# =========================
# Config
# =========================
#HF_TOKEN = os.environ.get("HF_TOKEN")

MODEL_NAME      = "meta-llama/Llama-3.1-8B-Instruct"
REF_MODEL_NAME  = "meta-llama/Llama-3.1-8B"   # None to disable reference subtraction

DATASET         = "argilla/Capybara-Preferences"   # or "-Filtered"
DATASET_CONFIG  = None
SPLIT           = "train[:30%]"

TRAIN_FRAC      = 0.8
SEED            = 42
MIN_GAP         = 0.0

MAX_INPUT_TOKENS     = 1024
MAX_GEN_TOKENS       = 512
SCORE_MAX_GEN_TOKENS = 256
BATCH_SIZE           = 4

USE_CHAT_TEMPLATE = True

# Scoring: {"mean", "sum", "lp_alpha"}
SCORING_MODE          = "mean"   # >>> NEW/UPDATED <<< change to "lp_alpha" to tune α grid below
LENGTH_PENALTY_ALPHA  = 0.0      # default α; can be tuned in prior-only grid when SCORING_MODE='lp_alpha'

# >>> NEW/UPDATED <<< reference-weight & equalization knobs
LAMBDA_REF_DEFAULT    = 1.0
LAMBDA_REF_GRID       = (0.0, 0.5, 1.0, 1.5, 2.0)  # swept in prior-only grid
LP_ALPHA_GRID         = (-0.25, 0.0, 0.25)         # used only if SCORING_MODE == "lp_alpha"
EQUALIZE_PAIR_TOKENS  = True                       # per-pair k_i = min(#tok_ch, #tok_rj) per model

# BP hyperparams (defaults)
BETA            = 1.0
DELTA           = 0.0
TAU             = 1.0
GAMMA           = 1.0
JJ_INNER_STEPS  = 2

# === Turbo outer loop ===
# K_TURBO=1 reproduces the original single-pass behavior (no policy re-score).
# K_TURBO=2 runs: (1) classifier update -> extrinsic -> policy re-score (virtual),
#                  then (2) final classifier update used for the decision.
K_TURBO         = 2
TURBO_MIX       = 1.0  # 0..1 mix of extrinsic when nudging "scores" (1.0 = full)

# Grid tuning
DO_TUNE         = True
BETAS           = [0.5, 1.0, 2.0]
DELTAS          = ['bco', 0.0]
TAUS            = [0.5, 1.0, 2.0]
# Add a gentler γ to avoid overconfident JJ when labels are noisy/close:
GAMMAS          = [0.25, 0.5, 1.0, 1.5, 2.0]
JJ_STEPS_LIST   = [1, 2, 3, 0]        # 0 => adaptive (tolerance-based)

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE  = torch.bfloat16 if torch.cuda.is_available() else torch.float32
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True


# =========================
# Capybara adapter
# =========================
def _to_float(x) -> Optional[float]:
    try:
        return float(x)
    except Exception:
        return None

def _ensure_assistant_last(msgs: List[Dict[str, str]]) -> Tuple[Optional[List[Dict[str,str]]], Optional[str]]:
    if not isinstance(msgs, list) or len(msgs) == 0:
        return None, None
    last = msgs[-1]
    if not isinstance(last, dict) or last.get("role") != "assistant":
        return None, None
    resp = (last.get("content") or "").strip()
    if not resp:
        return None, None
    history = [m for m in msgs[:-1] if isinstance(m, dict) and "role" in m and "content" in m]
    while len(history) and history[-1].get("role") == "assistant":
        history.pop()
    if len(history) == 0 or history[-1].get("role") != "user":
        return None, None
    return history, resp

def adapt_capybara_preferences(record: Dict, min_gap: float = MIN_GAP) -> Optional[Dict]:
    chosen = record.get("chosen")
    rejected = record.get("rejected")
    cr = _to_float(record.get("chosen_rating"))
    rr = _to_float(record.get("rejected_rating"))

    if not (isinstance(chosen, list) and isinstance(rejected, list)):
        return None
    if (cr is None) or (rr is None):
        return None
    if abs(cr - rr) < min_gap:
        return None

    ch_hist, ch_resp = _ensure_assistant_last(chosen)
    rj_hist, rj_resp = _ensure_assistant_last(rejected)
    if ch_hist is None or rj_hist is None:
        return None

    # Longest common prefix
    m = min(len(ch_hist), len(rj_hist))
    k = 0
    while k < m and (ch_hist[k].get("role") == rj_hist[k].get("role")) \
                 and (ch_hist[k].get("content") == rj_hist[k].get("content")):
        k += 1
    prompt = ch_hist[:k] if k > 0 else ch_hist
    if not prompt or prompt[-1].get("role") != "user":
        prompt = ch_hist

    return {"prompt": prompt, "chosen": ch_resp, "rejected": rj_resp}


# =========================
# Chat rendering (render ONCE with policy tokenizer)
# =========================
def render_prompt_with_policy_template(policy_tokenizer, prompt_or_messages: Union[str, List[Dict[str,str]]]) -> str:
    """
    Render the chat template ONCE using the policy tokenizer.
    The resulting TEXT is fed to BOTH policy and reference models.
    """
    if USE_CHAT_TEMPLATE and hasattr(policy_tokenizer, "apply_chat_template"):
        try:
            if isinstance(prompt_or_messages, list):
                msgs = prompt_or_messages
            else:
                msgs = [{"role": "user", "content": str(prompt_or_messages).strip()}]
            return policy_tokenizer.apply_chat_template(
                msgs, tokenize=False, add_generation_prompt=True
            )
        except Exception:
            pass
    if isinstance(prompt_or_messages, list):
        flat = []
        for m in prompt_or_messages:
            role = m.get("role", "user")
            content = (m.get("content") or "").strip()
            if content:
                flat.append(f"{role.upper()}: {content}")
        p = "\n".join(flat)
    else:
        p = str(prompt_or_messages).strip()
    return p if p.endswith("\n") else (p + "\n")

def render_prompt_list(policy_tokenizer, prompts: List[Union[str, List[Dict[str,str]]]]) -> List[str]:
    return [render_prompt_with_policy_template(policy_tokenizer, p).strip() for p in prompts]


# =========================
# Tokenization / Scoring helpers
# =========================
@torch.no_grad()
def _sequence_token_counts_by_tokenizer(
    tokenizer,
    prompt_texts: List[str],
    responses: List[str],
) -> torch.Tensor:
    """
    Count how many response tokens would be scored (no model forward).
    Applies same truncation rules as scoring (prompt capped at MAX_INPUT_TOKENS,
    full text capped at MAX_INPUT_TOKENS+MAX_GEN_TOKENS), then clamps to SCORE_MAX_GEN_TOKENS.
    """
    counts = []
    for p_txt, r in zip(prompt_texts, responses):
        toks_prompt = tokenizer(
            p_txt, truncation=True, max_length=MAX_INPUT_TOKENS,
            return_tensors="pt", add_special_tokens=True
        )
        toks_full = tokenizer(
            p_txt + r.strip(),
            truncation=True,
            max_length=min(MAX_INPUT_TOKENS + MAX_GEN_TOKENS, tokenizer.model_max_length),
            return_tensors="pt", add_special_tokens=True
        )
        p_len = int(toks_prompt["input_ids"].shape[-1])
        L     = int(toks_full["input_ids"].shape[-1])
        cnt   = max(0, min(L - p_len, SCORE_MAX_GEN_TOKENS))
        counts.append(cnt)
    return torch.tensor(counts, dtype=torch.long)

@torch.no_grad()
def _sequence_logprob_stats_text_with_k(
    model,
    tokenizer,
    prompt_texts: List[str],
    responses: List[str],
    k_list: torch.Tensor,  # per-example cap on response tokens (equalized or not)
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Returns (lp_sum, tok_count) for each pair using per-example k caps.
    """
    assert len(prompt_texts) == len(responses) == int(k_list.numel())
    sums, counts = [], []
    for i in range(0, len(prompt_texts), BATCH_SIZE):
        p_batch = prompt_texts[i:i+BATCH_SIZE]
        r_batch = responses[i:i+BATCH_SIZE]
        k_batch = k_list[i:i+BATCH_SIZE]

        batch_inputs, batch_prompt_lens = [], []
        for p_txt, r in zip(p_batch, r_batch):
            full_text = p_txt + r.strip()
            toks_full = tokenizer(
                full_text,
                truncation=True,
                max_length=min(MAX_INPUT_TOKENS + MAX_GEN_TOKENS, tokenizer.model_max_length),
                return_tensors="pt",
                add_special_tokens=True,
            )
            toks_prompt = tokenizer(
                p_txt,
                truncation=True,
                max_length=MAX_INPUT_TOKENS,
                return_tensors="pt",
                add_special_tokens=True,
            )
            p_len = toks_prompt["input_ids"].shape[-1]
            batch_inputs.append(toks_full)
            batch_prompt_lens.append(p_len)

        pad_id = tokenizer.pad_token_id or 0
        input_ids = torch.nn.utils.rnn.pad_sequence(
            [bi["input_ids"].squeeze(0) for bi in batch_inputs],
            batch_first=True, padding_value=pad_id
        )
        attention_mask = torch.nn.utils.rnn.pad_sequence(
            [bi["attention_mask"].squeeze(0) for bi in batch_inputs],
            batch_first=True, padding_value=0
        )
        input_ids = input_ids.to(DEVICE)
        attention_mask = attention_mask.to(DEVICE)

        logits = model(input_ids=input_ids, attention_mask=attention_mask).logits
        logprobs = log_softmax(logits.to(torch.float32), dim=-1)

        for b in range(input_ids.size(0)):
            p_len = batch_prompt_lens[b]
            ids   = input_ids[b]
            masks = attention_mask[b]
            L     = int(masks.sum().item())
            k     = int(k_batch[b].item())
            # ensure we don't exceed actual available tokens
            k_eff = max(0, min(k, L - p_len))
            if k_eff <= 0:
                sums.append(torch.tensor(0.0)); counts.append(torch.tensor(0)); continue
            end  = p_len + k_eff
            targ = ids[p_len:end]
            pred = logprobs[b, p_len-1:end-1]
            lp_sum = pred.gather(-1, targ.unsqueeze(-1)).squeeze(-1).sum()
            sums.append(lp_sum.cpu())
            counts.append(torch.tensor(k_eff))
    return torch.stack(sums, dim=0), torch.stack(counts, dim=0)

def _score_from_sums_counts(lp_sum: torch.Tensor, tok_count: torch.Tensor,
                            mode: str, alpha: float) -> torch.Tensor:
    if mode == "mean":
        denom = torch.clamp(tok_count.to(torch.float32), min=1.0)
        return lp_sum.to(torch.float32) / denom
    elif mode == "lp_alpha":
        return lp_sum.to(torch.float32) + (alpha * tok_count.to(torch.float32))
    else:  # "sum"
        return lp_sum.to(torch.float32)


# =========================
# Model loader
# =========================
def load_causal_lm(model_id: str, token: Optional[str], dtype=DTYPE, device=DEVICE):
    tok = AutoTokenizer.from_pretrained(model_id, token=token, use_fast=True)
    if tok.pad_token_id is None:
        tok.pad_token = tok.eos_token
    model = AutoModelForCausalLM.from_pretrained(model_id, token=token, torch_dtype=dtype)
    model = model.to(device)
    model.eval()
    return tok, model


# =========================
# BP-LLM (unary JJ) posterior
# =========================
@dataclass
class BPParams:
    beta: float = BETA
    delta: float = DELTA
    tau: float = TAU
    gamma: float = GAMMA
    jj_steps: int = JJ_INNER_STEPS

def jj_lambda(xi: float) -> float:
    if xi < 1e-8:
        return 1.0 / 8.0
    return math.tanh(xi / 2.0) / (4.0 * xi)

def bp_unary_posterior(mu_prior: float, b: int, params: BPParams) -> Tuple[float, float]:
    tau2 = params.tau ** 2
    gamma_tilde = params.gamma * (2 * b - 1)
    mu_hat = mu_prior
    tau2_hat = tau2
    for _ in range(params.jj_steps):
        xi = abs(params.gamma) * math.sqrt(mu_hat * mu_hat + tau2_hat)
        lam = jj_lambda(xi)
        Lambda = (1.0 / tau2) + 2.0 * lam
        eta    = (mu_prior / tau2) + 0.5 * gamma_tilde
        mu_hat = eta / Lambda
        tau2_hat = 1.0 / Lambda
    return mu_hat, tau2_hat

def bp_unary_posterior_adaptive(mu_prior: float, b: int, params: BPParams,
                                tol: float = 1e-4, max_steps: int = 50) -> Tuple[float, float]:
    tau2 = params.tau ** 2
    gamma_tilde = params.gamma * (2 * b - 1)
    mu_hat = mu_prior
    tau2_hat = tau2
    steps = params.jj_steps if params.jj_steps and params.jj_steps > 0 else max_steps
    for _ in range(steps):
        mu_prev = mu_hat
        xi = abs(params.gamma) * math.sqrt(mu_hat * mu_hat + tau2_hat)
        lam = jj_lambda(xi)
        Lambda = (1.0 / tau2) + 2.0 * lam
        eta    = (mu_prior / tau2) + 0.5 * gamma_tilde
        mu_hat = eta / Lambda
        tau2_hat = 1.0 / Lambda
        if (params.jj_steps is None or params.jj_steps <= 0) and abs(mu_hat - mu_prev) < tol:
            break
    return mu_hat, tau2_hat


# =========================
# Turbo outer loop on cached scores
# =========================
def _jj_update(mu_prior: float, b: int, params: BPParams, use_adaptive_jj: bool) -> float:
    if use_adaptive_jj:
        mu_post, _ = bp_unary_posterior_adaptive(mu_prior, b=b, params=params)
    else:
        mu_post, _ = bp_unary_posterior(mu_prior, b=b, params=params)
    return mu_post

def turbo_refine_scores(
    base_s: torch.Tensor, N: int, params: BPParams,
    K_turbo: int, use_adaptive_jj: bool = True, mix: float = 1.0
) -> torch.Tensor:
    """
    Simulate Turbo policy<->classifier exchanges by nudging the *scores*
    using the classifier's μ-domain extrinsic: Δμ = μ_post - μ_prior.
    We update s := s + (mix/β)*Δμ for each side, then recompute μ = β s - δ.
    """
    assert K_turbo >= 2, "Use K_turbo>=2 for Turbo refinement; K=1 is the baseline."
    s = base_s.clone().to(torch.float32)

    for _ in range(K_turbo - 1):
        # Current priors from (virtual) policy score
        mu = (params.beta * s) - params.delta

        # One JJ update per side using labels
        for i in range(N):
            # chosen (b=1), rejected (b=0)
            mu_w_prior = float(mu[i])
            mu_l_prior = float(mu[i + N])

            mu_w_post  = _jj_update(mu_w_prior, b=1, params=params, use_adaptive_jj=use_adaptive_jj)
            mu_l_post  = _jj_update(mu_l_prior, b=0, params=params, use_adaptive_jj=use_adaptive_jj)

            # μ-extrinsics
            dmu_w = mu_w_post - mu_w_prior
            dmu_l = mu_l_post - mu_l_prior

            # Feed extrinsic back to the policy score (virtual re-score)
            s[i]   = s[i]   + (mix * dmu_w) / max(params.beta, 1e-8)
            s[i+N] = s[i+N] + (mix * dmu_l) / max(params.beta, 1e-8)

        # loop continues with updated s

    return s


# =========================
# Evaluation & tuning
# =========================
def estimate_delta_bco(base_s: torch.Tensor, beta: float, N: int) -> float:
    r_pos = beta * base_s[:N]
    r_neg = beta * base_s[N:]
    return 0.5 * (float(r_pos.mean()) + float(r_neg.mean()))

def evaluate_with_cached(
    base_s: torch.Tensor, N: int, params: BPParams,
    use_adaptive_jj: bool = True, use_labels: bool = True, K_turbo: int = 1
) -> float:
    """
    K_turbo semantics:
      - K_turbo=1: original single classifier pass (no policy re-score).
      - K_turbo>=2: run K_turbo-1 Turbo exchanges that nudge 's', then a final
                    classifier pass for the decision.
    """
    if not use_labels:
        # Label-free: can't run JJ; just compare priors (policy-only).
        mu = (params.beta * base_s) - params.delta
        mu = mu.tolist()
        correct = 0
        for i in range(N):
            if mu[i] > mu[i + N]:
                correct += 1
        return 100.0 * correct / N

    if K_turbo <= 1:
        # Original behavior: one classifier update, no score refinement
        mu = (params.beta * base_s) - params.delta
        mu = mu.tolist()
        correct = 0
        for i in range(N):
            mu_w_prior = mu[i]
            mu_l_prior = mu[i + N]
            mu_w_post  = _jj_update(mu_w_prior, b=1, params=params, use_adaptive_jj=use_adaptive_jj)
            mu_l_post  = _jj_update(mu_l_prior, b=0, params=params, use_adaptive_jj=use_adaptive_jj)
            if mu_w_post > mu_l_post:
                correct += 1
        return 100.0 * correct / N

    # Turbo refinement of scores, then a final classifier pass
    s_ref = turbo_refine_scores(
        base_s=base_s, N=N, params=params, K_turbo=K_turbo,
        use_adaptive_jj=use_adaptive_jj, mix=TURBO_MIX
    )
    mu = (params.beta * s_ref) - params.delta
    mu = mu.tolist()
    correct = 0
    for i in range(N):
        mu_w_post = _jj_update(mu[i], b=1, params=params, use_adaptive_jj=use_adaptive_jj)
        mu_l_post = _jj_update(mu[i + N], b=0, params=params, use_adaptive_jj=use_adaptive_jj)
        if mu_w_post > mu_l_post:
            correct += 1
    return 100.0 * correct / N

def binom_ci_95(pct: float, N: int) -> Tuple[float, float]:
    p = pct / 100.0
    se = math.sqrt(p * (1 - p) / max(N, 1))
    lo = max(0.0, 100.0 * (p - 1.96 * se))
    hi = min(100.0, 100.0 * (p + 1.96 * se))
    return lo, hi

def grid_search(
    records: List[Dict], policy_model, ref_model, policy_tok, ref_tok,
    betas: Iterable[float], deltas: Iterable[Optional[str]],
    taus: Iterable[float], gammas: Iterable[float],
    jj_steps_list: Iterable[int], use_adaptive_jj: bool = True,
    estimate_delta: bool = True, split_name: str = "TRAIN",
    K_turbo: int = 1
):
    # cache base with defaults
    (base_s, N,
     pol_sum_ch, pol_cnt_ch, pol_sum_rj, pol_cnt_rj,
     ref_sum_ch, ref_cnt_ch, ref_sum_rj, ref_cnt_rj) = cache_pairwise_scores(
        records, policy_model, ref_model, policy_tok, ref_tok,
        lambda_ref=LAMBDA_REF_DEFAULT
    )

    tried = []
    best = (-1.0, None)

    for beta, delta_spec, tau, gamma, jj_steps in itertools.product(
        betas, deltas, taus, gammas, jj_steps_list
    ):
        if estimate_delta and (delta_spec == 'bco'):
            delta = estimate_delta_bco(base_s, beta, N)
        else:
            delta = float(delta_spec)

        params = BPParams(beta=beta, delta=delta, tau=tau, gamma=gamma, jj_steps=jj_steps)
        wr = evaluate_with_cached(
            base_s, N, params,
            use_adaptive_jj=use_adaptive_jj, use_labels=True, K_turbo=K_turbo
        )
        lo, hi = binom_ci_95(wr, N)
        tried.append((wr, lo, hi, params))
        if wr > best[0]:
            best = (wr, params)

        print(f"[TUNE/{split_name}] WR={wr:.2f}% [{lo:.1f}, {hi:.1f}] "
              f"| beta={beta} delta={'bco' if delta_spec=='bco' else f'{delta:.4f}'} "
              f"| tau={tau} gamma={gamma} jj_steps={jj_steps} K_turbo={K_turbo}")

    best_wr, best_params = best
    lo, hi = binom_ci_95(best_wr, N)
    print(f"\n[BP-LLM Tuning ({split_name})] Best WR={best_wr:.2f}% [{lo:.1f}, {hi:.1f}] with "
          f"beta={best_params.beta} delta={best_params.delta:.4f} "
          f"tau={best_params.tau} gamma={best_params.gamma} jj_steps={best_params.jj_steps} "
          f"K_turbo={K_turbo}")
    return tried, best_params


# =========================
# Caching scores (pairwise) — now returns components (sums+counts)
# =========================
@torch.no_grad()
def cache_pairwise_scores(
    records: List[Dict],
    policy_model,
    ref_model,
    policy_tok,
    ref_tok=None,
    lambda_ref: float = LAMBDA_REF_DEFAULT,
) -> Tuple[
    torch.Tensor, int,  # base_s, N
    torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,  # pol_sum_ch, pol_cnt_ch, pol_sum_rj, pol_cnt_rj
    torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor   # ref_sum_ch, ref_cnt_ch, ref_sum_rj, ref_cnt_rj
]:
    """
    Returns:
      base_s: tensor [2N] (first N chosen, next N rejected) computed with given lambda_ref and current SCORING_MODE/alpha.
      N: number of valid pairs
      plus sums and counts per side per model to allow recombination without re-scoring.
    """
    prompts_raw = [r["prompt"] for r in records]
    resps_ch    = [r["chosen"] for r in records]
    resps_rj    = [r["rejected"] for r in records]

    # Render once with policy tokenizer to TEXT
    prompt_texts = render_prompt_list(policy_tok, prompts_raw)

    # --- POLICY: compute per-pair k (equalized or not) via token counts (no forward)
    pol_cnt_ch_nom = _sequence_token_counts_by_tokenizer(policy_tok, prompt_texts, resps_ch)
    pol_cnt_rj_nom = _sequence_token_counts_by_tokenizer(policy_tok, prompt_texts, resps_rj)
    if EQUALIZE_PAIR_TOKENS:
        k_pol = torch.minimum(pol_cnt_ch_nom, pol_cnt_rj_nom)
    else:
        # use individual caps
        k_pol = None

    # Summed logprobs under policy with caps
    pol_sum_ch, pol_cnt_ch = _sequence_logprob_stats_text_with_k(
        policy_model, policy_tok, prompt_texts, resps_ch,
        k_pol if k_pol is not None else pol_cnt_ch_nom
    )
    pol_sum_rj, pol_cnt_rj = _sequence_logprob_stats_text_with_k(
        policy_model, policy_tok, prompt_texts, resps_rj,
        k_pol if k_pol is not None else pol_cnt_rj_nom
    )
    pol_score_ch = _score_from_sums_counts(pol_sum_ch, pol_cnt_ch, SCORING_MODE, LENGTH_PENALTY_ALPHA)
    pol_score_rj = _score_from_sums_counts(pol_sum_rj, pol_cnt_rj, SCORING_MODE, LENGTH_PENALTY_ALPHA)

    # --- REFERENCE: same process if provided
    if ref_model is None:
        ref_sum_ch = torch.zeros_like(pol_sum_ch); ref_cnt_ch = torch.ones_like(pol_cnt_ch)
        ref_sum_rj = torch.zeros_like(pol_sum_rj); ref_cnt_rj = torch.ones_like(pol_cnt_rj)
        ref_score_ch = torch.zeros_like(pol_score_ch)
        ref_score_rj = torch.zeros_like(pol_score_rj)
    else:
        # Rendered TEXT is the same; tokenize with ref_tok internally
        ref_cnt_ch_nom = _sequence_token_counts_by_tokenizer(ref_tok, prompt_texts, resps_ch)
        ref_cnt_rj_nom = _sequence_token_counts_by_tokenizer(ref_tok, prompt_texts, resps_rj)
        if EQUALIZE_PAIR_TOKENS:
            k_ref = torch.minimum(ref_cnt_ch_nom, ref_cnt_rj_nom)
        else:
            k_ref = None

        ref_sum_ch, ref_cnt_ch = _sequence_logprob_stats_text_with_k(
            ref_model, ref_tok, prompt_texts, resps_ch,
            k_ref if k_ref is not None else ref_cnt_ch_nom
        )
        ref_sum_rj, ref_cnt_rj = _sequence_logprob_stats_text_with_k(
            ref_model, ref_tok, prompt_texts, resps_rj,
            k_ref if k_ref is not None else ref_cnt_rj_nom
        )
        ref_score_ch = _score_from_sums_counts(ref_sum_ch, ref_cnt_ch, SCORING_MODE, LENGTH_PENALTY_ALPHA)
        ref_score_rj = _score_from_sums_counts(ref_sum_rj, ref_cnt_rj, SCORING_MODE, LENGTH_PENALTY_ALPHA)

    # Valid pairs: both sides have ≥1 scored token under BOTH models (if ref used)
    valid_ch = (pol_cnt_ch > 0) & (ref_cnt_ch > 0)
    valid_rj = (pol_cnt_rj > 0) & (ref_cnt_rj > 0)
    idx = [i for i in range(len(records)) if valid_ch[i].item() and valid_rj[i].item()]
    if not idx:
        raise RuntimeError("No valid pairs after scoring. Increase limits or reduce MIN_GAP.")

    idx_t = torch.tensor(idx, dtype=torch.long)
    # slice components
    pol_sum_ch = pol_sum_ch[idx_t]; pol_cnt_ch = pol_cnt_ch[idx_t]
    pol_sum_rj = pol_sum_rj[idx_t]; pol_cnt_rj = pol_cnt_rj[idx_t]
    ref_sum_ch = ref_sum_ch[idx_t]; ref_cnt_ch = ref_cnt_ch[idx_t]
    ref_sum_rj = ref_sum_rj[idx_t]; ref_cnt_rj = ref_cnt_rj[idx_t]

    # recompute scores after slicing
    pol_score_ch = _score_from_sums_counts(pol_sum_ch, pol_cnt_ch, SCORING_MODE, LENGTH_PENALTY_ALPHA)
    pol_score_rj = _score_from_sums_counts(pol_sum_rj, pol_cnt_rj, SCORING_MODE, LENGTH_PENALTY_ALPHA)
    ref_score_ch = _score_from_sums_counts(ref_sum_ch, ref_cnt_ch, SCORING_MODE, LENGTH_PENALTY_ALPHA)
    ref_score_rj = _score_from_sums_counts(ref_sum_rj, ref_cnt_rj, SCORING_MODE, LENGTH_PENALTY_ALPHA)

    chosen_vals   = (pol_score_ch - lambda_ref * ref_score_ch).to(torch.float32)
    rejected_vals = (pol_score_rj - lambda_ref * ref_score_rj).to(torch.float32)
    base_s = torch.cat([chosen_vals, rejected_vals], dim=0)
    N = len(idx)

    return (base_s, N,
            pol_sum_ch, pol_cnt_ch, pol_sum_rj, pol_cnt_rj,
            ref_sum_ch, ref_cnt_ch, ref_sum_rj, ref_cnt_rj)


# >>> NEW/UPDATED <<< helpers to recompute scores/base_s without model passes
def make_scores_from_components(
    sum_t: torch.Tensor, cnt_t: torch.Tensor,
    mode: str, alpha: float
) -> torch.Tensor:
    return _score_from_sums_counts(sum_t, cnt_t, mode, alpha)

def make_base_s_from_components(
    pol_sum_ch, pol_cnt_ch, pol_sum_rj, pol_cnt_rj,
    ref_sum_ch, ref_cnt_ch, ref_sum_rj, ref_cnt_rj,
    lambda_ref: float, mode: str, alpha: float
) -> torch.Tensor:
    pol_score_ch = make_scores_from_components(pol_sum_ch, pol_cnt_ch, mode, alpha)
    pol_score_rj = make_scores_from_components(pol_sum_rj, pol_cnt_rj, mode, alpha)
    ref_score_ch = make_scores_from_components(ref_sum_ch, ref_cnt_ch, mode, alpha)
    ref_score_rj = make_scores_from_components(ref_sum_rj, ref_cnt_rj, mode, alpha)
    return torch.cat([
        (pol_score_ch - lambda_ref * ref_score_ch).to(torch.float32),
        (pol_score_rj - lambda_ref * ref_score_rj).to(torch.float32)
    ], dim=0)


# --- PRIOR-ONLY (no labels) calibration over {lambda_ref} (+ optional α) ---
def grid_search_prior_only_components(
    N: int,
    pol_sum_ch, pol_cnt_ch, pol_sum_rj, pol_cnt_rj,
    ref_sum_ch, ref_cnt_ch, ref_sum_rj, ref_cnt_rj,
    lambda_grid: Iterable[float] = LAMBDA_REF_GRID,
    alpha_grid: Optional[Iterable[float]] = None,  # only used if SCORING_MODE=='lp_alpha'
    split_name: str = "TRAIN-PRIOR"
):
    tried = []
    best = (-1.0, None, None)

    if SCORING_MODE == "lp_alpha":
        if alpha_grid is None:
            alpha_grid = LP_ALPHA_GRID
    else:
        alpha_grid = [LENGTH_PENALTY_ALPHA]

    for lam, alpha in itertools.product(lambda_grid, alpha_grid):
        base_s = make_base_s_from_components(
            pol_sum_ch, pol_cnt_ch, pol_sum_rj, pol_cnt_rj,
            ref_sum_ch, ref_cnt_ch, ref_sum_rj, ref_cnt_rj,
            lambda_ref=lam, mode=SCORING_MODE, alpha=alpha
        )
        # prior-only: γ=0, jj=0, K_turbo=1
        params = BPParams(beta=1.0, delta=0.0, tau=1.0, gamma=0.0, jj_steps=0)
        wr = evaluate_with_cached(
            base_s, N, params,
            use_adaptive_jj=False, use_labels=False, K_turbo=1
        )
        lo, hi = binom_ci_95(wr, N)
        tried.append((wr, lo, hi, lam, alpha))
        if wr > best[0]:
            best = (wr, lam, alpha)
        print(f"[TUNE/{split_name}] PRIOR-ONLY WR={wr:.2f}% [{lo:.1f}, {hi:.1f}] "
              f"| λ={lam:.2f} α={alpha:+.2f} | (γ=0, jj=0, K_turbo=1)")

    best_wr, best_lam, best_alpha = best
    lo, hi = binom_ci_95(best_wr, N)
    print(f"\n[Prior-only calibration] Best WR={best_wr:.2f}% [{lo:.1f}, {hi:.1f}] "
          f"with λ={best_lam:.2f} α={best_alpha:+.2f} (SCORING_MODE={SCORING_MODE})")
    return tried, best_lam, best_alpha


# =========================
# Main
# =========================
def main():
    print(f"Device: {DEVICE} | Dtype: {DTYPE}")
    print("Loading policy model ...")
    tok_policy, policy = load_causal_lm(MODEL_NAME, HF_TOKEN)

    tok_ref, ref = None, None
    if REF_MODEL_NAME:
        print("Loading reference model ...")
        tok_ref, ref = load_causal_lm(REF_MODEL_NAME, HF_TOKEN)

    print(f"Loading dataset: {DATASET} [{SPLIT}] config={DATASET_CONFIG} ...")
    ds = load_dataset(DATASET, name=DATASET_CONFIG, split=SPLIT) if DATASET_CONFIG else load_dataset(DATASET, split=SPLIT)

    adapted: List[Dict] = []
    for rec in ds:
        a = adapt_capybara_preferences(rec, min_gap=MIN_GAP)
        if a and a["prompt"] and a["chosen"] and a["rejected"]:
            adapted.append(a)

    if not adapted:
        raise RuntimeError("No valid examples after adapting Capybara. Try MIN_GAP=0.0 or expand SPLIT.")

    # Deterministic split
    g = torch.Generator().manual_seed(SEED)
    idx = torch.randperm(len(adapted), generator=g).tolist()
    split = int(len(idx) * TRAIN_FRAC)
    train_idx, test_idx = idx[:split], idx[split:]
    train_records = [adapted[i] for i in train_idx]
    test_records  = [adapted[i] for i in test_idx]
    print(f"Train records: {len(train_records)} | Test records: {len(test_records)}")

    # ----- Cache TRAIN (with default λ & α for Turbo tuning path) -----
    (base_s_tr, N_tr,
     pol_sum_ch_tr, pol_cnt_ch_tr, pol_sum_rj_tr, pol_cnt_rj_tr,
     ref_sum_ch_tr, ref_cnt_ch_tr, ref_sum_rj_tr, ref_cnt_rj_tr) = cache_pairwise_scores(
        train_records, policy, ref, tok_policy, tok_ref, lambda_ref=LAMBDA_REF_DEFAULT
    )

    # A) (Optional) Turbo tuning WITH labels on TRAIN
    if DO_TUNE:
        print("\nTuning Turbo (WITH labels) on TRAIN...")
        _tries_turbo, best_params_turbo = grid_search(
            train_records, policy, ref, tok_policy, tok_ref,
            betas=BETAS, deltas=DELTAS, taus=TAUS, gammas=GAMMAS,
            jj_steps_list=JJ_STEPS_LIST, use_adaptive_jj=True,
            estimate_delta=True, split_name="TRAIN", K_turbo=K_TURBO
        )
    else:
        best_params_turbo = BPParams(beta=BETA, delta=DELTA, tau=TAU, gamma=GAMMA, jj_steps=JJ_INNER_STEPS)

    # Sanity print for TRAIN with Turbo settings
    wr_train_turbo = evaluate_with_cached(
        base_s_tr, N_tr, best_params_turbo,
        use_adaptive_jj=True, use_labels=True, K_turbo=K_TURBO
    )
    lo0, hi0 = binom_ci_95(wr_train_turbo, N_tr)
    print(f"[Sanity TRAIN w/ Turbo] WR={wr_train_turbo:.2f}% [{lo0:.1f}, {hi0:.1f}] "
          f"| beta={best_params_turbo.beta} delta={best_params_turbo.delta:.4f} "
          f"| tau={best_params_turbo.tau} gamma={best_params_turbo.gamma} "
          f"jj_steps={best_params_turbo.jj_steps} K_turbo={K_TURBO}")

    # B) PRIOR-ONLY calibration for TEST: pick {lambda_ref, alpha} (β,δ cancel in γ=0 regime)
    print("\nCalibrating PRIOR-ONLY (no labels) on TRAIN to pick {lambda_ref, alpha} for TEST...")
    _tries_prior, best_lambda, best_alpha = grid_search_prior_only_components(
        N_tr,
        pol_sum_ch_tr, pol_cnt_ch_tr, pol_sum_rj_tr, pol_cnt_rj_tr,
        ref_sum_ch_tr, ref_cnt_ch_tr, ref_sum_rj_tr, ref_cnt_rj_tr,
        lambda_grid=LAMBDA_REF_GRID,
        alpha_grid=LP_ALPHA_GRID if SCORING_MODE == "lp_alpha" else None,
        split_name="TRAIN-PRIOR"
    )

    # Cache TEST components (no need to re-run for multiple λ/α)
    (base_s_te_default, N_te,
     pol_sum_ch_te, pol_cnt_ch_te, pol_sum_rj_te, pol_cnt_rj_te,
     ref_sum_ch_te, ref_cnt_ch_te, ref_sum_rj_te, ref_cnt_rj_te) = cache_pairwise_scores(
        test_records, policy, ref, tok_policy, tok_ref, lambda_ref=LAMBDA_REF_DEFAULT
    )

    # Build TEST base_s with best λ, α and evaluate PRIOR-ONLY (γ=0)
    base_s_te = make_base_s_from_components(
        pol_sum_ch_te, pol_cnt_ch_te, pol_sum_rj_te, pol_cnt_rj_te,
        ref_sum_ch_te, ref_cnt_ch_te, ref_sum_rj_te, ref_cnt_rj_te,
        lambda_ref=best_lambda, mode=SCORING_MODE, alpha=best_alpha
    )
    test_params = BPParams(beta=1.0, delta=0.0, tau=1.0, gamma=0.0, jj_steps=0)
    wr_te = evaluate_with_cached(
        base_s_te, N_te, test_params,
        use_adaptive_jj=False, use_labels=False, K_turbo=1
    )
    lo_te, hi_te = binom_ci_95(wr_te, N_te)
    print(f"\n[BP-LLM EVAL (TEST, PRIOR-ONLY)] WR={wr_te:.2f}% [{lo_te:.1f}, {hi_te:.1f}] "
          f"| λ={best_lambda:.2f} α={best_alpha:+.2f} | (γ=0, jj=0, K_turbo=1, mode={SCORING_MODE})")


if __name__ == "__main__":
    main()


Device: cuda | Dtype: torch.bfloat16
Loading policy model ...


tokenizer_config.json:   0%|          | 0.00/55.4k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/9.09M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/296 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/855 [00:00<?, ?B/s]

`torch_dtype` is deprecated! Use `dtype` instead!


model.safetensors.index.json:   0%|          | 0.00/23.9k [00:00<?, ?B/s]

Fetching 4 files:   0%|          | 0/4 [00:00<?, ?it/s]

model-00001-of-00004.safetensors:   0%|          | 0.00/4.98G [00:00<?, ?B/s]

model-00002-of-00004.safetensors:   0%|          | 0.00/5.00G [00:00<?, ?B/s]

model-00003-of-00004.safetensors:   0%|          | 0.00/4.92G [00:00<?, ?B/s]

model-00004-of-00004.safetensors:   0%|          | 0.00/1.17G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/184 [00:00<?, ?B/s]

Loading reference model ...


tokenizer_config.json:   0%|          | 0.00/50.5k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/9.09M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/73.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/826 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/23.9k [00:00<?, ?B/s]

Fetching 4 files:   0%|          | 0/4 [00:00<?, ?it/s]

model-00003-of-00004.safetensors:   0%|          | 0.00/4.92G [00:00<?, ?B/s]

model-00001-of-00004.safetensors:   0%|          | 0.00/4.98G [00:00<?, ?B/s]

model-00002-of-00004.safetensors:   0%|          | 0.00/5.00G [00:00<?, ?B/s]

model-00004-of-00004.safetensors:   0%|          | 0.00/1.17G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/185 [00:00<?, ?B/s]

Loading dataset: argilla/Capybara-Preferences [train[:30%]] config=None ...


README.md: 0.00B [00:00, ?B/s]

data/train-00000-of-00001.parquet:   0%|          | 0.00/78.8M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/15404 [00:00<?, ? examples/s]

Train records: 3696 | Test records: 925

Tuning Turbo (WITH labels) on TRAIN...
[TUNE/TRAIN] WR=97.89% [97.4, 98.4] | beta=0.5 delta=bco | tau=0.5 gamma=0.25 jj_steps=1 K_turbo=2
[TUNE/TRAIN] WR=97.89% [97.4, 98.4] | beta=0.5 delta=bco | tau=0.5 gamma=0.25 jj_steps=2 K_turbo=2
[TUNE/TRAIN] WR=97.89% [97.4, 98.4] | beta=0.5 delta=bco | tau=0.5 gamma=0.25 jj_steps=3 K_turbo=2
[TUNE/TRAIN] WR=97.89% [97.4, 98.4] | beta=0.5 delta=bco | tau=0.5 gamma=0.25 jj_steps=0 K_turbo=2
[TUNE/TRAIN] WR=98.70% [98.3, 99.1] | beta=0.5 delta=bco | tau=0.5 gamma=0.5 jj_steps=1 K_turbo=2
[TUNE/TRAIN] WR=98.70% [98.3, 99.1] | beta=0.5 delta=bco | tau=0.5 gamma=0.5 jj_steps=2 K_turbo=2
[TUNE/TRAIN] WR=98.70% [98.3, 99.1] | beta=0.5 delta=bco | tau=0.5 gamma=0.5 jj_steps=3 K_turbo=2
[TUNE/TRAIN] WR=98.70% [98.3, 99.1] | beta=0.5 delta=bco | tau=0.5 gamma=0.5 jj_steps=0 K_turbo=2
[TUNE/TRAIN] WR=99.65% [99.5, 99.8] | beta=0.5 delta=bco | tau=0.5 gamma=1.0 jj_steps=1 K_turbo=2
[TUNE/TRAIN] WR=99.65% [99.5, 99.8

In [None]:
# dpo_capybara_eval.py
# DPO baseline (leakage-safe) on argilla/Capybara-Preferences
#
# What this does:
#   • Adapts Capybara to (prompt, chosen, rejected) using the published schema:
#       - 'chosen' / 'rejected' are chat transcripts (list of {role, content})
#       - 'chosen_rating' / 'rejected_rating' are numeric
#     It extracts a common user-ended history as the prompt and the last assistant
#     message as the response for each side.
#   • Renders the chat template ONCE with the policy tokenizer; feeds the same
#     rendered text to policy and reference models.
#   • Computes DPO margins Δ = (s_pos - s_neg) with s = log π − log π_ref (ref optional).
#   • TRAIN: tune β by minimizing mean DPO loss  L_DPO = E[-log σ(βΔ)]  (also reports WR).
#   • TEST: evaluate win rate (labels only for metric; scoring is label-free).
#   • Length-normalized scoring over response tokens; drops truncated pairs.
#
# Notes:
#   • Set REF_MODEL_NAME=None to disable reference subtraction (use plain log π).
#   • If you get too few pairs, use a bigger SPLIT or reduce MIN_GAP.
#   • Also works with "argilla/Capybara-Preferences-Filtered".
#
# Requires: datasets, transformers, torch

import os
import math
import random
from dataclasses import dataclass
from typing import List, Dict, Optional, Tuple, Iterable, Union

import torch
from torch.nn.functional import log_softmax
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM

# =============================================================================
# Config
# =============================================================================
# HF_TOKEN = os.environ.get("HF_TOKEN")

# Policy (Instruct) + optional Base reference (Llama 3.2)
MODEL_NAME      = "meta-llama/Llama-3.1-8B-Instruct"
REF_MODEL_NAME  = "meta-llama/Llama-3.1-8B"     # set to None to disable reference subtraction

DATASET         = "argilla/Capybara-Preferences"  # or "argilla/Capybara-Preferences-Filtered"
DATASET_CONFIG  = None
SPLIT           = "train[:30%]"                   # iterate small; use "train" later

# Inference/scoring limits
MAX_INPUT_TOKENS     = 1024
MAX_GEN_TOKENS       = 512
SCORE_MAX_GEN_TOKENS = 256
BATCH_SIZE           = 4

# Scoring
SCORING_MODE         = "mean"      # {"mean", "sum", "lp_alpha"}
LENGTH_PENALTY_ALPHA = 0.0         # only for SCORING_MODE == "lp_alpha"

# Pair filtering
MIN_GAP              = 0.0         # Capybara ratings often tie; start at 0.0, raise to 0.5/1.0 to drop near-ties

# Train / test split
RNG_SEED   = 42
TRAIN_FRAC = 0.8

# Grid tuning (TRAIN only)
DO_TUNE = True
BETAS   = [0.25, 0.5, 1.0, 2.0]    # temperature-like scale for Δ

# Device / dtype
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE  = torch.bfloat16 if torch.cuda.is_available() else torch.float32
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.set_grad_enabled(False)

# =============================================================================
# Capybara adapter
# =============================================================================
def _to_float(x) -> Optional[float]:
    try:
        return float(x)
    except Exception:
        return None

def _ensure_assistant_last(msgs: List[Dict[str, str]]) -> Tuple[Optional[List[Dict[str,str]]], Optional[str]]:
    """
    Return (history_without_last, last_assistant_content) if the last turn is assistant
    and the remaining history ends with user; otherwise (None, None).
    """
    if not isinstance(msgs, list) or len(msgs) == 0:
        return None, None
    last = msgs[-1]
    if not isinstance(last, dict) or last.get("role") != "assistant":
        return None, None
    resp = (last.get("content") or "").strip()
    if not resp:
        return None, None
    history = [m for m in msgs[:-1] if isinstance(m, dict) and "role" in m and "content" in m]
    # ensure history ends with user
    while len(history) and history[-1].get("role") == "assistant":
        history.pop()
    if len(history) == 0 or history[-1].get("role") != "user":
        return None, None
    return history, resp

def _common_user_ended_prefix(a: List[Dict[str,str]], b: List[Dict[str,str]]) -> List[Dict[str,str]]:
    m = min(len(a), len(b))
    k = 0
    while k < m and (a[k].get("role") == b[k].get("role")) and (a[k].get("content") == b[k].get("content")):
        k += 1
    prefix = a[:k] if k > 0 else a
    # if it doesn't end with user, fall back to 'a' (which we already ensured ends with user)
    return prefix if prefix and prefix[-1].get("role") == "user" else a

def adapt_capybara(record: Dict, min_gap: float = MIN_GAP) -> Optional[Dict]:
    """
    Capybara schema:
      - 'chosen' / 'rejected': list[{role, content}] (assistant last)
      - 'chosen_rating' / 'rejected_rating': numeric
    Produces:
      prompt: common history (list[{role, content}], ends with user)
      chosen/rejected: final assistant strings
    """
    chosen = record.get("chosen")
    rejected = record.get("rejected")
    cr = _to_float(record.get("chosen_rating"))
    rr = _to_float(record.get("rejected_rating"))

    if not (isinstance(chosen, list) and isinstance(rejected, list)):
        return None
    if (cr is None) or (rr is None) or (abs(cr - rr) < min_gap):
        return None

    ch_hist, ch_resp = _ensure_assistant_last(chosen)
    rj_hist, rj_resp = _ensure_assistant_last(rejected)
    if ch_hist is None or rj_hist is None:
        return None

    prompt = _common_user_ended_prefix(ch_hist, rj_hist)
    return {"prompt": prompt, "chosen": ch_resp, "rejected": rj_resp}

# =============================================================================
# Models
# =============================================================================
def load_causal_lm(model_id: str, token: Optional[str], dtype=DTYPE, device=DEVICE):
    tok = AutoTokenizer.from_pretrained(model_id, token=token, use_fast=True, trust_remote_code=True)
    if tok.pad_token_id is None:
        tok.pad_token = tok.eos_token
    model = AutoModelForCausalLM.from_pretrained(model_id, token=token, torch_dtype=dtype, trust_remote_code=True).to(device)
    model.eval()
    return tok, model

# =============================================================================
# Prompt rendering (policy chat template) + scoring
# =============================================================================
def render_prompt_with_policy_template(policy_tokenizer, prompt_or_messages: Union[str, List[Dict[str,str]]]) -> str:
    """
    Render chat template ONCE with the policy tokenizer.
    The resulting text is fed to BOTH policy and reference models.
    Supports list-of-messages prompts.
    """
    if hasattr(policy_tokenizer, "apply_chat_template"):
        try:
            if isinstance(prompt_or_messages, list):
                msgs = prompt_or_messages
            else:
                msgs = [{"role": "user", "content": str(prompt_or_messages).strip()}]
            return policy_tokenizer.apply_chat_template(
                msgs, tokenize=False, add_generation_prompt=True
            )
        except Exception:
            pass
    p = (str(prompt_or_messages).strip() if not isinstance(prompt_or_messages, list) else
         "\n".join(f"{m.get('role','user').upper()}: {(m.get('content') or '').strip()}"
                   for m in prompt_or_messages))
    return p if p.endswith("\n") else (p + "\n")

def render_prompt_list(policy_tokenizer, prompts: List[Union[str, List[Dict[str,str]]]]) -> List[str]:
    return [render_prompt_with_policy_template(policy_tokenizer, p) for p in prompts]

def _concat_prompt_response_text(prompt_text: str, response: str) -> Tuple[str, str]:
    return prompt_text + response.strip(), prompt_text

@torch.no_grad()
def sequence_logprob_stats_text(
    model,
    tokenizer,
    prompt_texts: List[str],   # already rendered with policy template
    responses: List[str],
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Returns two tensors of shape [B]: (lp_sum, tok_count)
    – lp_sum is the sum of log-probs over response tokens
    – tok_count is the number of response tokens scored
    If response is fully truncated (no tokens past the prompt), tok_count=0
    """
    sums, counts = [], []
    for i in range(0, len(prompt_texts), BATCH_SIZE):
        p_batch = prompt_texts[i:i+BATCH_SIZE]
        r_batch = responses[i:i+BATCH_SIZE]

        batch_inputs, batch_prompt_lens = [], []
        for p_txt, r in zip(p_batch, r_batch):
            full_text, prompt_only = _concat_prompt_response_text(p_txt, r)
            toks_full = tokenizer(
                full_text,
                truncation=True,
                max_length=min(MAX_INPUT_TOKENS + MAX_GEN_TOKENS, tokenizer.model_max_length),
                return_tensors="pt",
                add_special_tokens=True,
            )
            toks_prompt = tokenizer(
                prompt_only,
                truncation=True,
                max_length=MAX_INPUT_TOKENS,
                return_tensors="pt",
                add_special_tokens=True,
            )
            p_len = toks_prompt["input_ids"].shape[-1]
            batch_inputs.append(toks_full)
            batch_prompt_lens.append(p_len)

        pad_id = tokenizer.pad_token_id or 0
        input_ids = torch.nn.utils.rnn.pad_sequence(
            [bi["input_ids"].squeeze(0) for bi in batch_inputs],
            batch_first=True, padding_value=pad_id
        )
        attention_mask = torch.nn.utils.rnn.pad_sequence(
            [bi["attention_mask"].squeeze(0) for bi in batch_inputs],
            batch_first=True, padding_value=0
        )
        input_ids = input_ids.to(DEVICE)
        attention_mask = attention_mask.to(DEVICE)

        logits = model(input_ids=input_ids, attention_mask=attention_mask).logits
        logprobs = log_softmax(logits.to(torch.float32), dim=-1)

        for b in range(input_ids.size(0)):
            p_len = batch_prompt_lens[b]
            ids   = input_ids[b]
            masks = attention_mask[b]
            L = int(masks.sum().item())
            if p_len >= L:
                sums.append(torch.tensor(0.0)); counts.append(torch.tensor(0)); continue
            end  = min(L, p_len + SCORE_MAX_GEN_TOKENS)
            targ = ids[p_len:end]
            pred = logprobs[b, p_len-1:end-1]
            lp_sum = pred.gather(-1, targ.unsqueeze(-1)).squeeze(-1).sum()
            sums.append(lp_sum.cpu())
            counts.append(torch.tensor(targ.numel()))
    return torch.stack(sums, dim=0), torch.stack(counts, dim=0)

def _score_from_stats(lp_sum: torch.Tensor, tok_count: torch.Tensor) -> torch.Tensor:
    if SCORING_MODE == "mean":
        denom = torch.clamp(tok_count.to(torch.float32), min=1.0)
        return lp_sum.to(torch.float32) / denom
    elif SCORING_MODE == "lp_alpha":
        return lp_sum.to(torch.float32) + (LENGTH_PENALTY_ALPHA * tok_count.to(torch.float32))
    else:  # "sum"
        return lp_sum.to(torch.float32)

# =============================================================================
# DPO pieces
# =============================================================================
@dataclass
class DPOTune:
    beta: float = 1.0

def dpo_loss_and_wr(delta: torch.Tensor, beta: float) -> Tuple[float, float]:
    """
    Mean DPO loss and win rate for a fixed β.
    Loss = E[-log σ(βΔ)] = mean(softplus(-βΔ)), WR = mean(Δ > 0)*100
    """
    z = beta * delta
    loss = torch.nn.functional.softplus(-z).mean().item()
    wr = (delta > 0).to(torch.float32).mean().item() * 100.0
    return loss, wr

@torch.no_grad()
def cache_dpo_margins(
    records: List[Dict],
    policy_model,
    ref_model,
    policy_tok,
    ref_tok=None
) -> Tuple[torch.Tensor, int]:
    """
    Returns Δ tensor of shape [N_valid] where
      Δ_i = (s_pos_i - s_neg_i),
      s = log π(y|x) - log π_ref(y|x)  (ref optional)
    Valid pairs require ≥1 response token scored for BOTH sides (and for ref if used).
    """
    prompts_raw = [r["prompt"] for r in records]  # list[Union[str, List[dict]]], here list-of-messages
    prompt_texts = render_prompt_list(policy_tok, prompts_raw)

    resps_ch = [r["chosen"] for r in records]
    resps_rj = [r["rejected"] for r in records]

    # Policy (own tokenizer)
    pol_sum_ch, pol_cnt_ch = sequence_logprob_stats_text(policy_model, policy_tok, prompt_texts, resps_ch)
    pol_sum_rj, pol_cnt_rj = sequence_logprob_stats_text(policy_model, policy_tok, prompt_texts, resps_rj)
    pol_score_ch = _score_from_stats(pol_sum_ch, pol_cnt_ch)
    pol_score_rj = _score_from_stats(pol_sum_rj, pol_cnt_rj)

    # Reference (optional; own tokenizer; same prompt_texts)
    if ref_model is None:
        ref_score_ch = torch.zeros_like(pol_score_ch); ref_cnt_ch = torch.ones_like(pol_cnt_ch)
        ref_score_rj = torch.zeros_like(pol_score_rj); ref_cnt_rj = torch.ones_like(pol_cnt_rj)
    else:
        if ref_tok is None:
            raise ValueError("ref_model provided without ref_tok")
        ref_sum_ch, ref_cnt_ch = sequence_logprob_stats_text(ref_model, ref_tok, prompt_texts, resps_ch)
        ref_sum_rj, ref_cnt_rj = sequence_logprob_stats_text(ref_model, ref_tok, prompt_texts, resps_rj)
        ref_score_ch = _score_from_stats(ref_sum_ch, ref_cnt_ch)
        ref_score_rj = _score_from_stats(ref_sum_rj, ref_cnt_rj)

    # Validity mask: both sides non-truncated for both models (if ref used)
    valid_ch = (pol_cnt_ch > 0) & (ref_cnt_ch > 0)
    valid_rj = (pol_cnt_rj > 0) & (ref_cnt_rj > 0)
    valid_idx = [i for i in range(len(records)) if valid_ch[i].item() and valid_rj[i].item()]
    if not valid_idx:
        raise RuntimeError("No valid (non-truncated) pairs. Increase token limits, expand SPLIT, or lower MIN_GAP.")

    # s = log π - log π_ref (ref=0 if disabled); Δ = s_pos - s_neg
    s_pos = pol_score_ch[valid_idx] - ref_score_ch[valid_idx]
    s_neg = pol_score_rj[valid_idx] - ref_score_rj[valid_idx]
    delta = (s_pos - s_neg).to(torch.float32)  # [N_valid]
    return delta, len(valid_idx)

def binom_ci_95(pct: float, N: int) -> Tuple[float, float]:
    p = pct / 100.0
    se = math.sqrt(p * (1 - p) / max(N, 1))
    lo = max(0.0, 100.0 * (p - 1.96 * se))
    hi = min(100.0, 100.0 * (p + 1.96 * se))
    return lo, hi

def tune_beta_dpo(
    records: List[Dict], policy_model, ref_model, policy_tok, ref_tok,
    betas: Iterable[float], split_name: str = "TRAIN"
) -> Tuple[List[Tuple[float,float,float,float]], DPOTune]:
    """
    Returns (trials, best), where each trial is (beta, loss, wr, N).
    Select β minimizing mean DPO loss on TRAIN (ties broken by higher WR).
    """
    delta, N = cache_dpo_margins(records, policy_model, ref_model, policy_tok, ref_tok)
    trials = []
    best = (float("inf"), -1.0, None)  # (loss, wr, beta)
    for beta in betas:
        loss, wr = dpo_loss_and_wr(delta, beta)
        trials.append((beta, loss, wr, N))
        lo, hi = binom_ci_95(wr, N)
        print(f"[TUNE/{split_name} DPO] beta={beta:.3g} | loss={loss:.4f} | WR={wr:.2f}% [{lo:.1f}, {hi:.1f}]")
        if (loss < best[0]) or (abs(loss - best[0]) < 1e-6 and wr > best[1]):
            best = (loss, wr, beta)

    best_beta = best[2]
    print(f"\n[DPO Tuning ({split_name})] Best beta={best_beta:.3g} "
          f"(loss={best[0]:.4f}, WR={best[1]:.2f}%) on {N} pairs")
    return trials, DPOTune(beta=best_beta)

# =============================================================================
# Main
# =============================================================================
def main():
    print(f"Device: {DEVICE} | Dtype: {DTYPE}")
    print("Loading policy model ...")
    tok_policy, policy = load_causal_lm(MODEL_NAME, HF_TOKEN)

    tok_ref, ref = None, None
    if REF_MODEL_NAME:
        print("Loading reference model ...")
        tok_ref, ref = load_causal_lm(REF_MODEL_NAME, HF_TOKEN)

    print(f"Loading dataset: {DATASET} [{SPLIT}] config={DATASET_CONFIG} ...")
    ds = load_dataset(DATASET, name=DATASET_CONFIG, split=SPLIT) if DATASET_CONFIG else load_dataset(DATASET, split=SPLIT)

    print("Columns:", ds.column_names)
    if len(ds) > 0:
        sample0 = ds[0]
        print("Row[0] keys:", list(sample0.keys()))
        preview = {k: (str(sample0[k])[:200] + "…") for k in sample0.keys()}
        print("Row[0] preview:", preview)

    # Adapt to pairs with min gap
    adapted: List[Dict] = []
    for rec in ds:
        a = adapt_capybara(rec, min_gap=MIN_GAP)
        if a is not None:
            adapted.append(a)
    print(f"Adapted examples (gap ≥ {MIN_GAP}): {len(adapted)}")
    if not adapted:
        raise RuntimeError("No valid examples after adaptation. Try MIN_GAP=0.0, expand SPLIT, or inspect data.")

    # Deterministic split
    random.Random(RNG_SEED).shuffle(adapted)
    n_all   = len(adapted)
    n_train = max(1, int(TRAIN_FRAC * n_all))
    train_recs = adapted[:n_train]
    test_recs  = adapted[n_train:]
    print(f"Split sizes: TRAIN={len(train_recs)} | TEST={len(test_recs)}")

    # Tune β on TRAIN (min DPO loss; report WR too)
    if DO_TUNE:
        print("\nTuning beta on TRAIN (DPO)...")
        _tries, best = tune_beta_dpo(train_recs, policy, ref, tok_policy, tok_ref, betas=BETAS, split_name="TRAIN")
        beta = best.beta
    else:
        beta = 1.0

    # Evaluate on TEST
    eval_set = test_recs if len(test_recs) > 0 else train_recs
    if len(test_recs) == 0:
        print("\nWARNING: TEST set is empty (small SPLIT/TRAIN_FRAC). Using TRAIN as proxy sanity check.")

    print("\nScoring TEST (DPO metrics)...")
    delta_te, N_te = cache_dpo_margins(eval_set, policy, ref, tok_policy, tok_ref)
    loss_te, wr_te = dpo_loss_and_wr(delta_te, beta)
    lo_te, hi_te = binom_ci_95(wr_te, N_te)
    print(f"[DPO EVAL (TEST)] WR={wr_te:.2f}% [{lo_te:.1f}, {hi_te:.1f}] | loss={loss_te:.4f} | beta={beta:.3g} | N={N_te}")

if __name__ == "__main__":
    main()


Device: cuda | Dtype: torch.bfloat16
Loading policy model ...


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Loading reference model ...


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Loading dataset: argilla/Capybara-Preferences [train[:30%]] config=None ...
Columns: ['source', 'chosen', 'chosen_rating', 'chosen_model', 'rejected', 'rejected_rating', 'rejected_model']
Row[0] keys: ['source', 'chosen', 'chosen_rating', 'chosen_model', 'rejected', 'rejected_rating', 'rejected_model']
Row[0] preview: {'source': 'Airoboros…', 'chosen': '[{\'content\': \'The setting is an otherworldly, yet eerily familiar, metropolis known as "Zephyria." It\\\'s a city suspended in the ether, floating amidst nebulous clouds of cosmic dust. The architecture…', 'chosen_rating': '5…', 'chosen_model': 'teknium/OpenHermes-2.5-Mistral-7B…', 'rejected': '[{\'content\': \'The setting is an otherworldly, yet eerily familiar, metropolis known as "Zephyria." It\\\'s a city suspended in the ether, floating amidst nebulous clouds of cosmic dust. The architecture…', 'rejected_rating': '4…', 'rejected_model': 'gpt-4-1106-preview…'}
Adapted examples (gap ≥ 0.0): 4621
Split sizes: TRAIN=3696 | TEST=925


# The Following is for Llama3B

In [None]:
# BCO with Llama-3.x on Capybara-Preferences

# bp_llm_capybara_eval.py
# Capybara-Preferences BCO baseline (leakage-safe)
#
# What this does:
#   • Uses argilla/Capybara-Preferences (or the *-Filtered* variant) which provides
#     multi-turn conversations (`messages`), multiple candidate generations (`generations`)
#     and UltraFeedback `ratings`.
#   • Builds (prompt, chosen, rejected) by top-vs-bottom rating with a minimum gap.
#   • Splits into TRAIN/TEST (tune β and δ on TRAIN; report on TEST).
#   • δ estimated by BCO on TRAIN: δ = 0.5 ( E_pos[β s] + E_neg[β s] ).
#   • Scoring is chat-template aware and length-normalized (mean log-prob),
#     limited to SCORE_MAX_GEN_TOKENS response tokens.
#   • Any pair with a truncated side is dropped (don’t bias as zeros).
#   • TEST evaluation is fully label-free (BCO doesn’t need labels).
#
# Notes:
#   • Optionally set a reference model (REF_MODEL_NAME) to use log π − log π_ref.
#   • To reduce obvious hallucination/URL noise, you can switch to the
#     "argilla/Capybara-Preferences-Filtered" dataset.
#   • This file is BCO-only for apples-to-apples comparison with your BP-LLM script.

import os
import math
import random
from dataclasses import dataclass
from typing import List, Dict, Optional, Tuple, Iterable, Union

import torch
from torch.nn.functional import log_softmax
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM

# =============================================================================
# Config
# =============================================================================
#HF_TOKEN = os.environ.get("HF_TOKEN")

# Use an *Instruct* model for the policy; optional Base as reference
MODEL_NAME      = "meta-llama/Llama-3.2-3B-Instruct"
REF_MODEL_NAME  = "meta-llama/Llama-3.2-3B"   # None to disable reference subtraction

# --- DATASET: switch to Capybara-Preferences ---
DATASET         = "argilla/Capybara-Preferences"   # or "argilla/Capybara-Preferences-Filtered"
DATASET_CONFIG  = None
SPLIT           = "train[:30%]"  # iterate small; use "train" later

# Inference/scoring limits
MAX_INPUT_TOKENS     = 1024
MAX_GEN_TOKENS       = 512
SCORE_MAX_GEN_TOKENS = 256
BATCH_SIZE           = 4

# Chat-templating + scoring
FORCE_CHAT_TEMPLATE  = True
SCORING_MODE         = "mean"      # {"mean", "sum", "lp_alpha"}
LENGTH_PENALTY_ALPHA = 0.0         # only for SCORING_MODE == "lp_alpha"

# Pair filtering
MIN_GAP              = 0.0         # minimum rating gap to keep a pair

# Train / test split
RNG_SEED   = 42
TRAIN_FRAC = 0.8

# Grid tuning (TRAIN only)
DO_TUNE = True
BETAS   = [0.5, 1.0, 2.0]
DELTAS  = ["bco", 0.0]  # try BCO shift and a constant zero shift

# Device / dtype
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE  = torch.bfloat16 if torch.cuda.is_available() else torch.float32
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.set_grad_enabled(False)


# =============================================================================
# Capybara helpers
# =============================================================================
def _to_float(x) -> Optional[float]:
    try:
        return float(x)
    except Exception:
        return None

def _last_role_is_user(messages: List[Dict[str, str]]) -> bool:
    try:
        return isinstance(messages, list) and len(messages) > 0 and messages[-1].get("role") == "user"
    except Exception:
        return False

def _ensure_assistant_last(msgs):
    """Return (history_without_last, last_assistant_content) or (None, None) if malformed."""
    if not isinstance(msgs, list) or len(msgs) == 0:
        return None, None
    last = msgs[-1]
    if not isinstance(last, dict) or last.get("role") != "assistant":
        return None, None
    resp = (last.get("content") or "").strip()
    if not resp:
        return None, None
    history = [m for m in msgs[:-1] if isinstance(m, dict) and "role" in m and "content" in m]
    # ensure history ends with a user turn for clean prompting (optional)
    if len(history) == 0 or history[-1].get("role") != "user":
        # try to trim trailing assistant turns if any (defensive)
        while len(history) and history[-1].get("role") == "assistant":
            history.pop()
        if len(history) == 0 or history[-1].get("role") != "user":
            return None, None
    return history, resp

def adapt_capybara_preferences(record: Dict, min_gap: float = 0.0) -> Optional[Dict]:
    """
    Build (prompt, chosen, rejected) from the *published* Capybara-Preferences schema:
      - chosen, rejected: list[{role, content}, ...] (full convo ending with assistant)
      - chosen_rating, rejected_rating: ints
    """
    chosen = record.get("chosen")
    rejected = record.get("rejected")
    cr = record.get("chosen_rating")
    rr = record.get("rejected_rating")

    # basic checks
    if not (isinstance(chosen, list) and isinstance(rejected, list)):
        return None
    try:
        cr = float(cr); rr = float(rr)
    except Exception:
        return None

    # require a minimum gap if desired
    if (abs(cr - rr) < min_gap):
        return None

    # split each side into (prompt history, last assistant response)
    ch_hist, ch_resp = _ensure_assistant_last(chosen)
    rj_hist, rj_resp = _ensure_assistant_last(rejected)
    if ch_hist is None or rj_hist is None:
        return None

    # Use the *longest common prefix* of histories as the prompt (they should usually match).
    m = min(len(ch_hist), len(rj_hist))
    k = 0
    while k < m and (ch_hist[k].get("role") == rj_hist[k].get("role")) \
                 and (ch_hist[k].get("content") == rj_hist[k].get("content")):
        k += 1
    prompt = ch_hist[:k]  # common history up to the last matching turn
    if not prompt or prompt[-1].get("role") != "user":
        # fall back: use the chosen history without last assistant
        prompt = ch_hist

    return {
        "prompt": prompt,          # list[{role, content}] (ends with user)
        "chosen": ch_resp,         # string (assistant)
        "rejected": rj_resp,       # string (assistant)
    }


# =============================================================================
# Model
# =============================================================================
def load_causal_lm(model_id: str, token: Optional[str], dtype=DTYPE, device=DEVICE):
    tok = AutoTokenizer.from_pretrained(model_id, token=token, use_fast=True)
    if tok.pad_token_id is None:
        tok.pad_token = tok.eos_token
    model = AutoModelForCausalLM.from_pretrained(model_id, token=token, torch_dtype=dtype)
    model = model.to(device)
    model.eval()
    return tok, model


# =============================================================================
# Encoding + scoring
# =============================================================================
def _apply_chat_prefix(
    tokenizer,
    prompt_or_messages: Union[str, List[Dict[str, str]]]
) -> Optional[str]:
    """
    If FORCE_CHAT_TEMPLATE and tokenizer supports it:
      • For list[{'role','content'}] -> apply chat template directly (multi-turn)
      • For str -> wrap as single user turn, then add generation prompt
    Returns a textual prefix (no tokens) or None if templating not available.
    """
    if not FORCE_CHAT_TEMPLATE or not hasattr(tokenizer, "apply_chat_template"):
        return None
    try:
        if isinstance(prompt_or_messages, list):
            msgs = prompt_or_messages
        else:
            msgs = [
                {"role": "user", "content": str(prompt_or_messages).strip()},
                {"role": "assistant", "content": ""},
            ]
            # the trailing empty assistant is just to emphasize gen prompt;
            # apply_chat_template(add_generation_prompt=True) will handle it anyway
            msgs = msgs[:-1]
        return tokenizer.apply_chat_template(
            msgs, tokenize=False, add_generation_prompt=True
        )
    except Exception:
        return None

def encode_pair(tokenizer, prompt_or_messages: Union[str, List[Dict[str, str]]], response: str):
    """
    Returns (input_ids, attention_mask, prompt_len) for concatenated prompt+response.
    Uses chat template when available and supports multi-turn 'messages'.
    """
    chat_prefix = _apply_chat_prefix(tokenizer, prompt_or_messages)
    if chat_prefix is not None:
        prompt_text = chat_prefix
        full_text   = chat_prefix + response.strip()
    else:
        if isinstance(prompt_or_messages, list):
            # Fallback: flatten messages into a plain text prompt
            flat = []
            for m in prompt_or_messages:
                role = m.get("role", "user")
                content = (m.get("content") or "").strip()
                if content:
                    flat.append(f"{role.upper()}: {content}")
            prompt_text = "\n".join(flat)
        else:
            prompt_text = str(prompt_or_messages).strip()
        sep = "" if prompt_text.endswith("\n") else "\n"
        full_text = prompt_text + sep + response.strip()

    toks_full = tokenizer(
        full_text,
        truncation=True,
        max_length=min(MAX_INPUT_TOKENS + MAX_GEN_TOKENS, tokenizer.model_max_length),
        return_tensors="pt",
        add_special_tokens=True,
    )
    toks_prompt = tokenizer(
        prompt_text,
        truncation=True,
        max_length=MAX_INPUT_TOKENS,
        return_tensors="pt",
        add_special_tokens=True,
    )
    prompt_len = toks_prompt["input_ids"].shape[-1]
    return toks_full, prompt_len

@torch.no_grad()
def sequence_logprob_stats(
    model,
    tokenizer,
    prompts: List[Union[str, List[Dict[str, str]]]],
    responses: List[str],
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Returns two tensors of shape [B]: (lp_sum, tok_count)
    – lp_sum is the sum of log-probs over scored response tokens
    – tok_count is the number of response tokens scored
    If a response is fully truncated (no tokens past the prompt), tok_count=0
    """
    sums, counts = [], []
    for i in range(0, len(prompts), BATCH_SIZE):
        p_batch = prompts[i:i+BATCH_SIZE]
        r_batch = responses[i:i+BATCH_SIZE]

        batch_inputs, batch_prompt_lens = [], []
        for p, r in zip(p_batch, r_batch):
            toks, p_len = encode_pair(tokenizer, p, r)
            batch_inputs.append(toks)
            batch_prompt_lens.append(p_len)

        pad_id = tokenizer.pad_token_id or 0
        input_ids = torch.nn.utils.rnn.pad_sequence(
            [bi["input_ids"].squeeze(0) for bi in batch_inputs],
            batch_first=True, padding_value=pad_id
        )
        attention_mask = torch.nn.utils.rnn.pad_sequence(
            [bi["attention_mask"].squeeze(0) for bi in batch_inputs],
            batch_first=True, padding_value=0
        )
        input_ids = input_ids.to(DEVICE)
        attention_mask = attention_mask.to(DEVICE)

        logits = model(input_ids=input_ids, attention_mask=attention_mask).logits
        logprobs = log_softmax(logits.to(torch.float32), dim=-1)

        for b in range(input_ids.size(0)):
            p_len = batch_prompt_lens[b]
            ids   = input_ids[b]
            masks = attention_mask[b]
            L = int(masks.sum().item())
            if p_len >= L:
                sums.append(torch.tensor(0.0)); counts.append(torch.tensor(0)); continue
            end  = min(L, p_len + SCORE_MAX_GEN_TOKENS)
            targ = ids[p_len:end]
            pred = logprobs[b, p_len-1:end-1]
            lp_sum = pred.gather(-1, targ.unsqueeze(-1)).squeeze(-1).sum()
            sums.append(lp_sum.cpu())
            counts.append(torch.tensor(targ.numel()))
    return torch.stack(sums, dim=0), torch.stack(counts, dim=0)

def _score_from_stats(lp_sum: torch.Tensor, tok_count: torch.Tensor) -> torch.Tensor:
    if SCORING_MODE == "mean":
        denom = torch.clamp(tok_count.to(torch.float32), min=1.0)
        return lp_sum.to(torch.float32) / denom
    elif SCORING_MODE == "lp_alpha":
        return lp_sum.to(torch.float32) + (LENGTH_PENALTY_ALPHA * tok_count.to(torch.float32))
    else:  # "sum"
        return lp_sum.to(torch.float32)


# =============================================================================
# BCO pieces
# =============================================================================
@dataclass
class BCOParams:
    beta: float  = 1.0
    delta: float = 0.0

@torch.no_grad()
def cache_policy_and_ref_scores(records: List[Dict], model, ref_model, tokenizer) -> Tuple[torch.Tensor, int]:
    """
    Build base scores s = (policy - ref) for valid pairs only.
    A pair is valid if BOTH sides have >= 1 scored response token
    (i.e., not truncated past the prompt) for BOTH models if ref is used.
    Returns:
        base_s: tensor of shape [2N_valid] (first N are chosen, next N are rejected)
        N_valid: number of valid pairs.
    """
    # prompts may be list[dict] (chat messages) here
    prompts   = [r["prompt"] for r in records for _ in (0, 1)]
    responses = [r["chosen"] for r in records] + [r["rejected"] for r in records]

    pol_sum, pol_cnt = sequence_logprob_stats(model, tokenizer, prompts, responses)
    pol_score = _score_from_stats(pol_sum, pol_cnt)

    if ref_model is None:
        ref_score = torch.zeros_like(pol_score)
        ref_cnt   = torch.ones_like(pol_cnt)  # pretend "valid"
    else:
        ref_sum, ref_cnt = sequence_logprob_stats(ref_model, tokenizer, prompts, responses)
        ref_score = _score_from_stats(ref_sum, ref_cnt)

    valid_mask = (pol_cnt > 0) & (ref_cnt > 0)
    # Keep only pairs where BOTH chosen and rejected are valid
    valid_pairs = []
    for i in range(len(records)):
        i_ch = i
        i_rj = i + len(records)
        if valid_mask[i_ch].item() and valid_mask[i_rj].item():
            valid_pairs.append(i)

    if len(valid_pairs) == 0:
        raise RuntimeError("No valid (non-truncated) pairs after scoring. Increase limits or lower MIN_GAP.")

    chosen_vals, rejected_vals = [], []
    for i in valid_pairs:
        i_ch = i
        i_rj = i + len(records)
        chosen_vals.append(float(pol_score[i_ch] - ref_score[i_ch]))
        rejected_vals.append(float(pol_score[i_rj] - ref_score[i_rj]))

    base_s = torch.tensor(chosen_vals + rejected_vals, dtype=torch.float32)
    N = len(valid_pairs)
    return base_s, N

def estimate_delta_bco(base_s: torch.Tensor, beta: float, N: int) -> float:
    """δ_BCO = 0.5 ( E_pos[β s] + E_neg[β s] )."""
    r_pos = beta * base_s[:N]
    r_neg = beta * base_s[N:]
    return 0.5 * (float(r_pos.mean()) + float(r_neg.mean()))

def bco_win_rate(base_s: torch.Tensor, N: int, params: BCOParams) -> float:
    """
    BCO decision rule (label-free): predict chosen if β s_pos - δ > β s_neg - δ  ⇔  s_pos > s_neg.
    δ cancels out for pairwise comparison, but we still compute it for completeness/reporting.
    """
    mu = (params.beta * base_s) - params.delta
    mu = mu.tolist()
    correct = 0
    for i in range(N):
        if mu[i] > mu[i + N]:
            correct += 1
    return 100.0 * correct / N

def binom_ci_95(pct: float, N: int) -> Tuple[float, float]:
    p = pct / 100.0
    se = math.sqrt(p * (1 - p) / max(N, 1))
    lo = max(0.0, 100.0 * (p - 1.96 * se))
    hi = min(100.0, 100.0 * (p + 1.96 * se))
    return lo, hi

def grid_search_train_bco(
    records: List[Dict], model, ref_model, tokenizer,
    betas: Iterable[float], deltas: Iterable[Optional[str]]
):
    """
    Tune β and δ on TRAIN only. δ can be 'bco' (estimated) or a numeric constant.
    """
    base_s, N = cache_policy_and_ref_scores(records, model, ref_model, tokenizer)

    tried = []
    best = (-1.0, None)  # (wr, params)
    for beta in betas:
        for delta_spec in deltas:
            if delta_spec == "bco":
                delta = estimate_delta_bco(base_s, beta, N)
            else:
                delta = float(delta_spec)
            params = BCOParams(beta=beta, delta=delta)
            wr = bco_win_rate(base_s, N, params)
            lo, hi = binom_ci_95(wr, N)
            tried.append((wr, lo, hi, params))
            if wr > best[0]:
                best = (wr, params)

            print(f"[TUNE/TRAIN BCO] WR={wr:.2f}% [{lo:.1f}, {hi:.1f}] "
                  f"| beta={beta} delta={'bco' if delta_spec=='bco' else f'{delta:.4f}'}")

    best_wr, best_params = best
    lo, hi = binom_ci_95(best_wr, N)
    print(f"\n[BCO Tuning (TRAIN)] Best WR={best_wr:.2f}% [{lo:.1f}, {hi:.1f}] with "
          f"beta={best_params.beta} delta={best_params.delta:.4f}")
    return tried, best_params


# =============================================================================
# Main
# =============================================================================
def main():
    print(f"Device: {DEVICE} | Dtype: {DTYPE}")
    print("Loading policy model ...")
    tok, policy = load_causal_lm(MODEL_NAME, HF_TOKEN)
    ref = None
    if REF_MODEL_NAME:
        print("Loading reference model ...")
        _, ref = load_causal_lm(REF_MODEL_NAME, HF_TOKEN)

    print(f"Loading dataset: {DATASET} [{SPLIT}] config={DATASET_CONFIG} ...")
    ds = load_dataset(DATASET, name=DATASET_CONFIG, split=SPLIT) if DATASET_CONFIG else load_dataset(DATASET, split=SPLIT)

    # Adapt to pairs with min rating gap
    adapted: List[Dict] = []
    for rec in ds:
        a = adapt_capybara_preferences(rec, min_gap=MIN_GAP)
        if a is not None:
            adapted.append(a)
    print(f"Adapted examples (rating gap ≥ {MIN_GAP}): {len(adapted)}")
    if not adapted:
        raise RuntimeError("No valid examples after adaptation. Lower MIN_GAP or inspect data.")

    # Train/Test split
    random.Random(RNG_SEED).shuffle(adapted)
    n_all   = len(adapted)
    n_train = max(1, int(TRAIN_FRAC * n_all))
    train_recs = adapted[:n_train]
    test_recs  = adapted[n_train:]
    print(f"Split sizes: TRAIN={len(train_recs)} | TEST={len(test_recs)}")

    # Cache TRAIN scores and baseline WR with default β=1, δ=0
    print("\nScoring TRAIN...")
    base_s_tr, N_tr = cache_policy_and_ref_scores(train_recs, policy, ref, tok)
    wr_tr_base = bco_win_rate(base_s_tr, N_tr, BCOParams(beta=1.0, delta=0.0))
    lo_trb, hi_trb = binom_ci_95(wr_tr_base, N_tr)
    print(f"[TRAIN BCO baseline] WR={wr_tr_base:.2f}% [{lo_trb:.1f}, {hi_trb:.1f}] on {N_tr} valid pairs")

    # Tune β, δ (TRAIN only)
    if DO_TUNE:
        print("\nTuning on TRAIN (BCO)...")
        _tries, best_params = grid_search_train_bco(
            train_recs, policy, ref, tok, betas=BETAS, deltas=DELTAS
        )
    else:
        # Optionally estimate δ via BCO at default β
        best_params = BCOParams(beta=1.0, delta=estimate_delta_bco(base_s_tr, 1.0, N_tr))

    # TEST evaluation (label-free)
    eval_set = test_recs if len(test_recs) > 0 else train_recs
    if len(test_recs) == 0:
        print("\nWARNING: TEST set is empty (small SPLIT/TRAIN_FRAC). Using TRAIN as a proxy sanity check.")

    print("\nScoring TEST...")
    base_s_te, N_te = cache_policy_and_ref_scores(eval_set, policy, ref, tok)
    wr_te = bco_win_rate(base_s_te, N_te, best_params)
    lo_te, hi_te = binom_ci_95(wr_te, N_te)
    print(f"[BCO EVAL (TEST, no labels)] WR={wr_te:.2f}% [{lo_te:.1f}, {hi_te:.1f}] "
          f"| beta={best_params.beta} delta={best_params.delta:.4f}")

if __name__ == "__main__":
    main()


Device: cuda | Dtype: torch.bfloat16
Loading policy model ...


tokenizer_config.json:   0%|          | 0.00/54.5k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/9.09M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/296 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/878 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/20.9k [00:00<?, ?B/s]

Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/1.46G [00:00<?, ?B/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/4.97G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/189 [00:00<?, ?B/s]

Loading reference model ...


tokenizer_config.json:   0%|          | 0.00/50.5k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/9.09M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/301 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/844 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/20.9k [00:00<?, ?B/s]

Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/4.97G [00:00<?, ?B/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/1.46G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/185 [00:00<?, ?B/s]

Loading dataset: argilla/Capybara-Preferences [train[:30%]] config=None ...
Adapted examples (rating gap ≥ 0.0): 4621
Split sizes: TRAIN=3696 | TEST=925

Scoring TRAIN...
[TRAIN BCO baseline] WR=55.22% [53.6, 56.8] on 3696 valid pairs

Tuning on TRAIN (BCO)...
[TUNE/TRAIN BCO] WR=55.22% [53.6, 56.8] | beta=0.5 delta=bco
[TUNE/TRAIN BCO] WR=55.22% [53.6, 56.8] | beta=0.5 delta=0.0000
[TUNE/TRAIN BCO] WR=55.22% [53.6, 56.8] | beta=1.0 delta=bco
[TUNE/TRAIN BCO] WR=55.22% [53.6, 56.8] | beta=1.0 delta=0.0000
[TUNE/TRAIN BCO] WR=55.22% [53.6, 56.8] | beta=2.0 delta=bco
[TUNE/TRAIN BCO] WR=55.22% [53.6, 56.8] | beta=2.0 delta=0.0000

[BCO Tuning (TRAIN)] Best WR=55.22% [53.6, 56.8] with beta=0.5 delta=-0.0730

Scoring TEST...
[BCO EVAL (TEST, no labels)] WR=53.62% [50.4, 56.8] | beta=0.5 delta=-0.0730


In [None]:
# bp_llm_capybara_eval.py
# Evaluate BP-LLM (unary JJ) win rate on argilla/Capybara-Preferences.
# Changes from UltraFeedback version:
#  - Dataset: argilla/Capybara-Preferences (or *-Filtered)
#  - Adapter uses published schema: chosen/rejected (list of messages) + ratings
#  - Prompt = common chat history (ends with user); Response = last assistant turn
#  - Chat-templating accepts Union[str, List[Dict[str,str]]]

import math
import itertools
from dataclasses import dataclass
from typing import List, Dict, Optional, Tuple, Iterable, Union

import torch
from torch.nn.functional import log_softmax
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM

# =========================
# Config
# =========================
# Put your HF token if you use gated models
# HF_TOKEN = os.environ.get("HF_TOKEN")


# Suggested: Instruct as policy, Base as reference
MODEL_NAME      = "meta-llama/Llama-3.2-3B-Instruct"
REF_MODEL_NAME  = "meta-llama/Llama-3.2-3B"   # None to disable reference subtraction

# --- DATASET: Capybara Preferences ---
DATASET         = "argilla/Capybara-Preferences"   # or "argilla/Capybara-Preferences-Filtered"
DATASET_CONFIG  = None
SPLIT           = "train[:30%]"       # keep small while iterating

# Train/Test split from the adapted pairs (stratification not required here)
TRAIN_FRAC      = 0.8
SEED            = 42
MIN_GAP         = 0.0                 # ratings are often tied; start at 0.0, raise to 0.5/1.0 if desired

# Input limits
MAX_INPUT_TOKENS = 1024
MAX_GEN_TOKENS   = 512
BATCH_SIZE       = 4

# Chat template toggle (recommended for *Instruct*)
USE_CHAT_TEMPLATE = True

# Length normalization: "mean" or "sum"
LENGTH_NORM = "mean"

# BP hyperparams (defaults)
BETA            = 1.0
DELTA           = 0.0
TAU             = 1.0
GAMMA           = 1.0
JJ_INNER_STEPS  = 2

# Grid tuning
DO_TUNE         = True
BETAS           = [0.5, 1.0, 2.0]
DELTAS          = ['bco', 0.0]        # try BCO shift and zero shift
TAUS            = [0.5, 1.0, 2.0]
GAMMAS          = [0.5, 1.0, 1.5, 2.0]
JJ_STEPS_LIST   = [1, 2, 3, 0]        # 0 => adaptive (tolerance-based)

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE  = torch.bfloat16 if torch.cuda.is_available() else torch.float32
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True


# =========================
# Capybara adapter
# =========================
def _to_float(x) -> Optional[float]:
    try:
        return float(x)
    except Exception:
        return None

def _ensure_assistant_last(msgs: List[Dict[str, str]]) -> Tuple[Optional[List[Dict[str,str]]], Optional[str]]:
    """
    Return (history_without_last, last_assistant_content) or (None, None) if malformed.
    Ensures history ends with a 'user' turn (trim trailing assistants if needed).
    """
    if not isinstance(msgs, list) or len(msgs) == 0:
        return None, None
    last = msgs[-1]
    if not isinstance(last, dict) or last.get("role") != "assistant":
        return None, None
    resp = (last.get("content") or "").strip()
    if not resp:
        return None, None
    history = [m for m in msgs[:-1] if isinstance(m, dict) and "role" in m and "content" in m]
    # ensure history ends with user
    while len(history) and history[-1].get("role") == "assistant":
        history.pop()
    if len(history) == 0 or history[-1].get("role") != "user":
        return None, None
    return history, resp

def adapt_capybara_preferences(record: Dict, min_gap: float = MIN_GAP) -> Optional[Dict]:
    """
    Use published Capybara schema:
      chosen, rejected: list[{role, content}, ...] (assistant last)
      chosen_rating, rejected_rating: ints/floats
    Produces:
      prompt: common history (list[{role, content}], ends with user)
      chosen/rejected: final assistant strings
    """
    chosen = record.get("chosen")
    rejected = record.get("rejected")
    cr = _to_float(record.get("chosen_rating"))
    rr = _to_float(record.get("rejected_rating"))

    if not (isinstance(chosen, list) and isinstance(rejected, list)):
        return None
    if (cr is None) or (rr is None):
        return None
    if abs(cr - rr) < min_gap:
        return None

    ch_hist, ch_resp = _ensure_assistant_last(chosen)
    rj_hist, rj_resp = _ensure_assistant_last(rejected)
    if ch_hist is None or rj_hist is None:
        return None

    # longest common prefix for prompt
    m = min(len(ch_hist), len(rj_hist))
    k = 0
    while k < m and (ch_hist[k].get("role") == rj_hist[k].get("role")) \
                 and (ch_hist[k].get("content") == rj_hist[k].get("content")):
        k += 1
    prompt = ch_hist[:k] if k > 0 else ch_hist
    if not prompt or prompt[-1].get("role") != "user":
        # fallback—use chosen history; we already ensured it ends with user
        prompt = ch_hist

    return {"prompt": prompt, "chosen": ch_resp, "rejected": rj_resp}


# =========================
# Tokenization / Scoring
# =========================
def _apply_chat_prefix(
    tokenizer,
    prompt_or_messages: Union[str, List[Dict[str, str]]]
) -> Optional[str]:
    """
    If USE_CHAT_TEMPLATE and tokenizer supports it:
      • For list[{'role','content'}] -> apply chat template directly (multi-turn)
      • For str -> wrap as single user turn, then add generation prompt
    Returns a textual prefix or None if templating not available.
    """
    if not USE_CHAT_TEMPLATE or not hasattr(tokenizer, "apply_chat_template"):
        return None
    try:
        if isinstance(prompt_or_messages, list):
            msgs = prompt_or_messages
        else:
            msgs = [{"role": "user", "content": str(prompt_or_messages).strip()}]
        return tokenizer.apply_chat_template(
            msgs, tokenize=False, add_generation_prompt=True
        )
    except Exception:
        return None

def concat_prompt_response_text(
    tokenizer,
    prompt_or_messages: Union[str, List[Dict[str, str]]],
    response: str
) -> Tuple[str, str]:
    """
    Return (full_text, prompt_only_text) for token counting.
    Works with str or list-of-messages prompts.
    """
    chat_prefix = _apply_chat_prefix(tokenizer, prompt_or_messages)
    if chat_prefix is not None:
        prompt_text = chat_prefix.strip()
        full_text   = prompt_text + response.strip()
        return full_text, prompt_text

    # Fallback: flatten messages if needed
    if isinstance(prompt_or_messages, list):
        flat = []
        for m in prompt_or_messages:
            role = m.get("role", "user")
            content = (m.get("content") or "").strip()
            if content:
                flat.append(f"{role.upper()}: {content}")
        prompt_text = "\n".join(flat)
    else:
        prompt_text = str(prompt_or_messages).strip()

    sep = "" if prompt_text.endswith("\n") else "\n"
    full_text = prompt_text + sep + response.strip()
    return full_text, prompt_text

@torch.no_grad()
def sequence_logprob_list(
    model,
    tokenizer,
    prompts: List[Union[str, List[Dict[str, str]]]],
    responses: List[str],
    length_norm: str = LENGTH_NORM,
) -> List[Optional[float]]:
    """
    Returns mean/sum log-prob for each (prompt, response). If response was truncated
    (i.e., prompt consumed the full sequence), returns None for that item.
    """
    vals: List[Optional[float]] = []
    for i in range(0, len(prompts), BATCH_SIZE):
        batch_prompts = prompts[i:i+BATCH_SIZE]
        batch_resps   = responses[i:i+BATCH_SIZE]

        batch_inputs, batch_prompt_lens = [], []
        for p, r in zip(batch_prompts, batch_resps):
            full_text, prompt_text = concat_prompt_response_text(tokenizer, p, r)
            toks = tokenizer(
                full_text,
                truncation=True,
                max_length=min(MAX_INPUT_TOKENS + MAX_GEN_TOKENS, tokenizer.model_max_length),
                return_tensors="pt",
                add_special_tokens=True,
            )
            tok_prompt = tokenizer(
                prompt_text,
                truncation=True,
                max_length=MAX_INPUT_TOKENS,
                return_tensors="pt",
                add_special_tokens=True,
            )
            p_len = tok_prompt["input_ids"].shape[-1]
            batch_inputs.append(toks)
            batch_prompt_lens.append(p_len)

        pad_id = tokenizer.pad_token_id or 0
        input_ids = torch.nn.utils.rnn.pad_sequence(
            [bi["input_ids"].squeeze(0) for bi in batch_inputs],
            batch_first=True, padding_value=pad_id
        )
        attention_mask = torch.nn.utils.rnn.pad_sequence(
            [bi["attention_mask"].squeeze(0) for bi in batch_inputs],
            batch_first=True, padding_value=0
        )
        input_ids = input_ids.to(DEVICE)
        attention_mask = attention_mask.to(DEVICE)

        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        logits = outputs.logits
        logprobs = log_softmax(logits.to(torch.float32), dim=-1)

        for b in range(input_ids.size(0)):
            p_len = batch_prompt_lens[b]
            ids   = input_ids[b]
            masks = attention_mask[b]
            L     = int(masks.sum().item())
            if p_len >= L:
                vals.append(None)  # response totally truncated -> skip
                continue
            targ = ids[p_len:L]              # response tokens
            pred = logprobs[b, p_len-1:L-1]  # shifted predictions
            lp = pred.gather(-1, targ.unsqueeze(-1)).squeeze(-1)
            score = float((lp.mean() if length_norm == "mean" else lp.sum()).cpu())
            vals.append(score)
    return vals


# =========================
# Model loader
# =========================
def load_causal_lm(model_id: str, token: Optional[str], dtype=DTYPE, device=DEVICE):
    tok = AutoTokenizer.from_pretrained(model_id, token=token, use_fast=True)
    if tok.pad_token_id is None:
        tok.pad_token = tok.eos_token
    model = AutoModelForCausalLM.from_pretrained(model_id, token=token, torch_dtype=dtype)
    model = model.to(device)
    model.eval()
    return tok, model


# =========================
# BP-LLM (unary JJ) posterior
# =========================
@dataclass
class BPParams:
    beta: float = BETA
    delta: float = DELTA
    tau: float = TAU
    gamma: float = GAMMA
    jj_steps: int = JJ_INNER_STEPS

def jj_lambda(xi: float) -> float:
    if xi < 1e-8:
        return 1.0 / 8.0
    return math.tanh(xi / 2.0) / (4.0 * xi)

def bp_unary_posterior(mu_prior: float, b: int, params: BPParams) -> Tuple[float, float]:
    tau2 = params.tau ** 2
    gamma_tilde = params.gamma * (2 * b - 1)
    mu_hat = mu_prior
    tau2_hat = tau2
    for _ in range(params.jj_steps):
        xi = abs(params.gamma) * math.sqrt(mu_hat * mu_hat + tau2_hat)
        lam = jj_lambda(xi)
        Lambda = (1.0 / tau2) + 2.0 * lam
        eta    = (mu_prior / tau2) + 0.5 * gamma_tilde
        mu_hat = eta / Lambda
        tau2_hat = 1.0 / Lambda
    return mu_hat, tau2_hat

def bp_unary_posterior_adaptive(mu_prior: float, b: int, params: BPParams,
                                tol: float = 1e-4, max_steps: int = 50) -> Tuple[float, float]:
    tau2 = params.tau ** 2
    gamma_tilde = params.gamma * (2 * b - 1)
    mu_hat = mu_prior
    tau2_hat = tau2
    steps = params.jj_steps if params.jj_steps and params.jj_steps > 0 else max_steps
    for _ in range(steps):
        mu_prev = mu_hat
        xi = abs(params.gamma) * math.sqrt(mu_hat * mu_hat + tau2_hat)
        lam = jj_lambda(xi)
        Lambda = (1.0 / tau2) + 2.0 * lam
        eta    = (mu_prior / tau2) + 0.5 * gamma_tilde
        mu_hat = eta / Lambda
        tau2_hat = 1.0 / Lambda
        if (params.jj_steps is None or params.jj_steps <= 0) and abs(mu_hat - mu_prev) < tol:
            break
    return mu_hat, tau2_hat


# =========================
# Caching scores (pairwise)
# =========================
@torch.no_grad()
def cache_pairwise_scores(
    records: List[Dict],
    model,
    ref_model,
    tokenizer
) -> Tuple[torch.Tensor, int]:
    """
    Returns base_s tensor shaped [2N] where first N are chosen scores, next N are rejected,
    and N is the number of *valid* pairs (both sides scored).
    Each score is (log pi - log pi_ref) with length normalization.
    """
    prompts_ch  = [r["prompt"] for r in records]
    prompts_rj  = [r["prompt"] for r in records]
    resps_ch    = [r["chosen"] for r in records]
    resps_rj    = [r["rejected"] for r in records]

    policy_ch = sequence_logprob_list(model, tokenizer, prompts_ch, resps_ch, LENGTH_NORM)
    policy_rj = sequence_logprob_list(model, tokenizer, prompts_rj, resps_rj, LENGTH_NORM)

    if ref_model is None:
        ref_ch = [0.0 if (s is not None) else None for s in policy_ch]
        ref_rj = [0.0 if (s is not None) else None for s in policy_rj]
    else:
        ref_ch = sequence_logprob_list(ref_model, tokenizer, prompts_ch, resps_ch, LENGTH_NORM)
        ref_rj = sequence_logprob_list(ref_model, tokenizer, prompts_rj, resps_rj, LENGTH_NORM)

    chosen_vals, rejected_vals = [], []
    for sc, sr, rc, rr in zip(policy_ch, policy_rj, ref_ch, ref_rj):
        if (sc is None) or (sr is None) or (rc is None) or (rr is None):
            continue  # drop pair if any side was truncated
        chosen_vals.append(sc - rc)
        rejected_vals.append(sr - rr)

    if len(chosen_vals) == 0:
        raise RuntimeError("No valid pairs after scoring. Consider increasing MAX_*_TOKENS or using 'sum' norm.")

    base_s = torch.tensor(chosen_vals + rejected_vals, dtype=torch.float32)
    N = len(chosen_vals)
    return base_s, N


# =========================
# Evaluation & tuning
# =========================
def estimate_delta_bco(base_s: torch.Tensor, beta: float, N: int) -> float:
    r_pos = beta * base_s[:N]
    r_neg = beta * base_s[N:]
    return 0.5 * (float(r_pos.mean()) + float(r_neg.mean()))

def evaluate_with_cached(
    base_s: torch.Tensor, N: int, params: BPParams,
    use_adaptive_jj: bool = True, use_labels: bool = True
) -> float:
    """
    If use_labels=False, compare mu_prior only (label-free test).
    """
    mu = (params.beta * base_s) - params.delta
    mu = mu.tolist()

    correct = 0
    for i in range(N):
        mu_w_prior = mu[i]
        mu_l_prior = mu[i + N]

        if not use_labels:
            # label-free: just compare priors
            if mu_w_prior > mu_l_prior:
                correct += 1
            continue

        if use_adaptive_jj:
            mu_w_post, _ = bp_unary_posterior_adaptive(mu_w_prior, b=1, params=params)
            mu_l_post, _ = bp_unary_posterior_adaptive(mu_l_prior, b=0, params=params)
        else:
            mu_w_post, _ = bp_unary_posterior(mu_w_prior, b=1, params=params)
            mu_l_post, _ = bp_unary_posterior(mu_l_prior, b=0, params=params)
        if mu_w_post > mu_l_post:
            correct += 1
    return 100.0 * correct / N

def binom_ci_95(pct: float, N: int) -> Tuple[float, float]:
    p = pct / 100.0
    se = math.sqrt(p * (1 - p) / max(N, 1))
    lo = max(0.0, 100.0 * (p - 1.96 * se))
    hi = min(100.0, 100.0 * (p + 1.96 * se))
    return lo, hi

def grid_search(
    records: List[Dict], model, ref_model, tokenizer,
    betas: Iterable[float], deltas: Iterable[Optional[str]],
    taus: Iterable[float], gammas: Iterable[float],
    jj_steps_list: Iterable[int], use_adaptive_jj: bool = True,
    estimate_delta: bool = True, split_name: str = "TRAIN"
):
    base_s, N = cache_pairwise_scores(records, model, ref_model, tokenizer)
    tried = []
    best = (-1.0, None)

    for beta, delta_spec, tau, gamma, jj_steps in itertools.product(
        betas, deltas, taus, gammas, jj_steps_list
    ):
        if estimate_delta and (delta_spec == 'bco'):
            delta = estimate_delta_bco(base_s, beta, N)
        else:
            delta = float(delta_spec)

        params = BPParams(beta=beta, delta=delta, tau=tau, gamma=gamma, jj_steps=jj_steps)
        # TRAIN uses labels:
        wr = evaluate_with_cached(base_s, N, params, use_adaptive_jj=use_adaptive_jj, use_labels=True)
        lo, hi = binom_ci_95(wr, N)
        tried.append((wr, lo, hi, params))
        if wr > best[0]:
            best = (wr, params)

        print(f"[TUNE/{split_name}] WR={wr:.2f}% [{lo:.1f}, {hi:.1f}] "
              f"| beta={beta} delta={'bco' if delta_spec=='bco' else f'{delta:.4f}'} "
              f"| tau={tau} gamma={gamma} jj_steps={jj_steps}")

    best_wr, best_params = best
    lo, hi = binom_ci_95(best_wr, N)
    print(f"\n[BP-LLM Tuning ({split_name})] Best WR={best_wr:.2f}% [{lo:.1f}, {hi:.1f}] with "
          f"beta={best_params.beta} delta={best_params.delta:.4f} "
          f"tau={best_params.tau} gamma={best_params.gamma} jj_steps={best_params.jj_steps}")
    return tried, best_params


# =========================
# Main
# =========================
def main():
    print(f"Device: {DEVICE} | Dtype: {DTYPE}")
    print("Loading policy model ...")
    tok, policy = load_causal_lm(MODEL_NAME, HF_TOKEN)

    ref = None
    if REF_MODEL_NAME:
        print("Loading reference model ...")
        _, ref = load_causal_lm(REF_MODEL_NAME, HF_TOKEN)

    print(f"Loading dataset: {DATASET} [{SPLIT}] config={DATASET_CONFIG} ...")
    if DATASET_CONFIG:
        ds = load_dataset(DATASET, name=DATASET_CONFIG, split=SPLIT)
    else:
        ds = load_dataset(DATASET, split=SPLIT)

    print("Columns:", ds.column_names)
    if len(ds) > 0:
        sample0 = ds[0]
        print("Row[0] keys:", list(sample0.keys()))
        preview = {k: (str(sample0[k])[:200] + "…") for k in sample0.keys()}
        print("Row[0] preview:", preview)

    # Adapt to (prompt, chosen, rejected) from Capybara schema
    adapted: List[Dict] = []
    for rec in ds:
        a = adapt_capybara_preferences(rec, min_gap=MIN_GAP)
        if a and a["prompt"] and a["chosen"] and a["rejected"]:
            adapted.append(a)

    print(f"Adapted examples (rating gap ≥ {MIN_GAP}): {len(adapted)}")
    if not adapted:
        raise RuntimeError(
            "No valid examples after adapting Capybara. "
            "Try MIN_GAP=0.0, expand SPLIT, or inspect rows."
        )

    # Deterministic split
    g = torch.Generator().manual_seed(SEED)
    idx = torch.randperm(len(adapted), generator=g).tolist()
    split = int(len(idx) * TRAIN_FRAC)
    train_idx, test_idx = idx[:split], idx[split:]
    train_records = [adapted[i] for i in train_idx]
    test_records  = [adapted[i] for i in test_idx]
    print(f"Train records: {len(train_records)} | Test records: {len(test_records)}")

    # QUICK sanity check (non-adaptive JJ) on TRAIN with default params
    params0 = BPParams(beta=BETA, delta=DELTA, tau=TAU, gamma=GAMMA, jj_steps=JJ_INNER_STEPS)
    base_s_tr, N_tr = cache_pairwise_scores(train_records, policy, ref, tok)
    wr0 = evaluate_with_cached(base_s_tr, N_tr, params0, use_adaptive_jj=False, use_labels=True)
    lo0, hi0 = binom_ci_95(wr0, N_tr)
    print(f"[Sanity TRAIN] WR={wr0:.2f}% [{lo0:.1f}, {hi0:.1f}] on {N_tr} pairs "
          f"| beta={params0.beta} delta={params0.delta} tau={params0.tau} gamma={params0.gamma} "
          f"jj_steps={params0.jj_steps}")

    # Tune on TRAIN (uses labels)
    if DO_TUNE:
        print("\nTuning on TRAIN...")
        _tries, best_params = grid_search(
            train_records, policy, ref, tok,
            betas=BETAS, deltas=DELTAS, taus=TAUS, gammas=GAMMAS,
            jj_steps_list=JJ_STEPS_LIST, use_adaptive_jj=True, estimate_delta=True, split_name="TRAIN"
        )
    else:
        best_params = params0

    # Final eval on TEST with NO labels (gamma=0 and JJ disabled)
    base_s_te, N_te = cache_pairwise_scores(test_records, policy, ref, tok)
    test_params = BPParams(
        beta=best_params.beta,
        delta=best_params.delta,   # note: cancels in prior-only comparison, kept for completeness
        tau=best_params.tau,
        gamma=0.0,                 # NO labels on test
        jj_steps=0                 # disable JJ on test
    )
    wr_te = evaluate_with_cached(base_s_te, N_te, test_params, use_adaptive_jj=False, use_labels=False)
    lo_te, hi_te = binom_ci_95(wr_te, N_te)
    print(f"\n[BP-LLM EVAL (TEST, no labels)] WR={wr_te:.2f}% [{lo_te:.1f}, {hi_te:.1f}] "
          f"| beta={test_params.beta} delta={test_params.delta:.4f} "
          f"| tau={test_params.tau} gamma={test_params.gamma} jj_steps={test_params.jj_steps}")


if __name__ == "__main__":
    main()


Device: cuda | Dtype: torch.bfloat16
Loading policy model ...


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loading reference model ...


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loading dataset: argilla/Capybara-Preferences [train[:30%]] config=None ...
Columns: ['source', 'chosen', 'chosen_rating', 'chosen_model', 'rejected', 'rejected_rating', 'rejected_model']
Row[0] keys: ['source', 'chosen', 'chosen_rating', 'chosen_model', 'rejected', 'rejected_rating', 'rejected_model']
Row[0] preview: {'source': 'Airoboros…', 'chosen': '[{\'content\': \'The setting is an otherworldly, yet eerily familiar, metropolis known as "Zephyria." It\\\'s a city suspended in the ether, floating amidst nebulous clouds of cosmic dust. The architecture…', 'chosen_rating': '5…', 'chosen_model': 'teknium/OpenHermes-2.5-Mistral-7B…', 'rejected': '[{\'content\': \'The setting is an otherworldly, yet eerily familiar, metropolis known as "Zephyria." It\\\'s a city suspended in the ether, floating amidst nebulous clouds of cosmic dust. The architecture…', 'rejected_rating': '4…', 'rejected_model': 'gpt-4-1106-preview…'}
Adapted examples (rating gap ≥ 0.0): 4621
Train records: 3696 | Test r

In [None]:
# dpo_capybara_eval.py
# DPO baseline (leakage-safe) on argilla/Capybara-Preferences
#
# What this does:
#   • Adapts Capybara to (prompt, chosen, rejected) using the published schema:
#       - 'chosen' / 'rejected' are chat transcripts (list of {role, content})
#       - 'chosen_rating' / 'rejected_rating' are numeric
#     It extracts a common user-ended history as the prompt and the last assistant
#     message as the response for each side.
#   • Renders the chat template ONCE with the policy tokenizer; feeds the same
#     rendered text to policy and reference models.
#   • Computes DPO margins Δ = (s_pos - s_neg) with s = log π − log π_ref (ref optional).
#   • TRAIN: tune β by minimizing mean DPO loss  L_DPO = E[-log σ(βΔ)]  (also reports WR).
#   • TEST: evaluate win rate (labels only for metric; scoring is label-free).
#   • Length-normalized scoring over response tokens; drops truncated pairs.
#
# Notes:
#   • Set REF_MODEL_NAME=None to disable reference subtraction (use plain log π).
#   • If you get too few pairs, use a bigger SPLIT or reduce MIN_GAP.
#   • Also works with "argilla/Capybara-Preferences-Filtered".
#
# Requires: datasets, transformers, torch

import os
import math
import random
from dataclasses import dataclass
from typing import List, Dict, Optional, Tuple, Iterable, Union

import torch
from torch.nn.functional import log_softmax
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM

# =============================================================================
# Config
# =============================================================================
# HF_TOKEN = os.environ.get("HF_TOKEN")

# Policy (Instruct) + optional Base reference (Llama 3.2)
MODEL_NAME      = "meta-llama/Llama-3.2-3B-Instruct"
REF_MODEL_NAME  = "meta-llama/Llama-3.2-3B"     # set to None to disable reference subtraction

DATASET         = "argilla/Capybara-Preferences"  # or "argilla/Capybara-Preferences-Filtered"
DATASET_CONFIG  = None
SPLIT           = "train[:30%]"                   # iterate small; use "train" later

# Inference/scoring limits
MAX_INPUT_TOKENS     = 1024
MAX_GEN_TOKENS       = 512
SCORE_MAX_GEN_TOKENS = 256
BATCH_SIZE           = 4

# Scoring
SCORING_MODE         = "mean"      # {"mean", "sum", "lp_alpha"}
LENGTH_PENALTY_ALPHA = 0.0         # only for SCORING_MODE == "lp_alpha"

# Pair filtering
MIN_GAP              = 0.0         # Capybara ratings often tie; start at 0.0, raise to 0.5/1.0 to drop near-ties

# Train / test split
RNG_SEED   = 42
TRAIN_FRAC = 0.8

# Grid tuning (TRAIN only)
DO_TUNE = True
BETAS   = [0.25, 0.5, 1.0, 2.0]    # temperature-like scale for Δ

# Device / dtype
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE  = torch.bfloat16 if torch.cuda.is_available() else torch.float32
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.set_grad_enabled(False)

# =============================================================================
# Capybara adapter
# =============================================================================
def _to_float(x) -> Optional[float]:
    try:
        return float(x)
    except Exception:
        return None

def _ensure_assistant_last(msgs: List[Dict[str, str]]) -> Tuple[Optional[List[Dict[str,str]]], Optional[str]]:
    """
    Return (history_without_last, last_assistant_content) if the last turn is assistant
    and the remaining history ends with user; otherwise (None, None).
    """
    if not isinstance(msgs, list) or len(msgs) == 0:
        return None, None
    last = msgs[-1]
    if not isinstance(last, dict) or last.get("role") != "assistant":
        return None, None
    resp = (last.get("content") or "").strip()
    if not resp:
        return None, None
    history = [m for m in msgs[:-1] if isinstance(m, dict) and "role" in m and "content" in m]
    # ensure history ends with user
    while len(history) and history[-1].get("role") == "assistant":
        history.pop()
    if len(history) == 0 or history[-1].get("role") != "user":
        return None, None
    return history, resp

def _common_user_ended_prefix(a: List[Dict[str,str]], b: List[Dict[str,str]]) -> List[Dict[str,str]]:
    m = min(len(a), len(b))
    k = 0
    while k < m and (a[k].get("role") == b[k].get("role")) and (a[k].get("content") == b[k].get("content")):
        k += 1
    prefix = a[:k] if k > 0 else a
    # if it doesn't end with user, fall back to 'a' (which we already ensured ends with user)
    return prefix if prefix and prefix[-1].get("role") == "user" else a

def adapt_capybara(record: Dict, min_gap: float = MIN_GAP) -> Optional[Dict]:
    """
    Capybara schema:
      - 'chosen' / 'rejected': list[{role, content}] (assistant last)
      - 'chosen_rating' / 'rejected_rating': numeric
    Produces:
      prompt: common history (list[{role, content}], ends with user)
      chosen/rejected: final assistant strings
    """
    chosen = record.get("chosen")
    rejected = record.get("rejected")
    cr = _to_float(record.get("chosen_rating"))
    rr = _to_float(record.get("rejected_rating"))

    if not (isinstance(chosen, list) and isinstance(rejected, list)):
        return None
    if (cr is None) or (rr is None) or (abs(cr - rr) < min_gap):
        return None

    ch_hist, ch_resp = _ensure_assistant_last(chosen)
    rj_hist, rj_resp = _ensure_assistant_last(rejected)
    if ch_hist is None or rj_hist is None:
        return None

    prompt = _common_user_ended_prefix(ch_hist, rj_hist)
    return {"prompt": prompt, "chosen": ch_resp, "rejected": rj_resp}

# =============================================================================
# Models
# =============================================================================
def load_causal_lm(model_id: str, token: Optional[str], dtype=DTYPE, device=DEVICE):
    tok = AutoTokenizer.from_pretrained(model_id, token=token, use_fast=True, trust_remote_code=True)
    if tok.pad_token_id is None:
        tok.pad_token = tok.eos_token
    model = AutoModelForCausalLM.from_pretrained(model_id, token=token, torch_dtype=dtype, trust_remote_code=True).to(device)
    model.eval()
    return tok, model

# =============================================================================
# Prompt rendering (policy chat template) + scoring
# =============================================================================
def render_prompt_with_policy_template(policy_tokenizer, prompt_or_messages: Union[str, List[Dict[str,str]]]) -> str:
    """
    Render chat template ONCE with the policy tokenizer.
    The resulting text is fed to BOTH policy and reference models.
    Supports list-of-messages prompts.
    """
    if hasattr(policy_tokenizer, "apply_chat_template"):
        try:
            if isinstance(prompt_or_messages, list):
                msgs = prompt_or_messages
            else:
                msgs = [{"role": "user", "content": str(prompt_or_messages).strip()}]
            return policy_tokenizer.apply_chat_template(
                msgs, tokenize=False, add_generation_prompt=True
            )
        except Exception:
            pass
    p = (str(prompt_or_messages).strip() if not isinstance(prompt_or_messages, list) else
         "\n".join(f"{m.get('role','user').upper()}: {(m.get('content') or '').strip()}"
                   for m in prompt_or_messages))
    return p if p.endswith("\n") else (p + "\n")

def render_prompt_list(policy_tokenizer, prompts: List[Union[str, List[Dict[str,str]]]]) -> List[str]:
    return [render_prompt_with_policy_template(policy_tokenizer, p) for p in prompts]

def _concat_prompt_response_text(prompt_text: str, response: str) -> Tuple[str, str]:
    return prompt_text + response.strip(), prompt_text

@torch.no_grad()
def sequence_logprob_stats_text(
    model,
    tokenizer,
    prompt_texts: List[str],   # already rendered with policy template
    responses: List[str],
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Returns two tensors of shape [B]: (lp_sum, tok_count)
    – lp_sum is the sum of log-probs over response tokens
    – tok_count is the number of response tokens scored
    If response is fully truncated (no tokens past the prompt), tok_count=0
    """
    sums, counts = [], []
    for i in range(0, len(prompt_texts), BATCH_SIZE):
        p_batch = prompt_texts[i:i+BATCH_SIZE]
        r_batch = responses[i:i+BATCH_SIZE]

        batch_inputs, batch_prompt_lens = [], []
        for p_txt, r in zip(p_batch, r_batch):
            full_text, prompt_only = _concat_prompt_response_text(p_txt, r)
            toks_full = tokenizer(
                full_text,
                truncation=True,
                max_length=min(MAX_INPUT_TOKENS + MAX_GEN_TOKENS, tokenizer.model_max_length),
                return_tensors="pt",
                add_special_tokens=True,
            )
            toks_prompt = tokenizer(
                prompt_only,
                truncation=True,
                max_length=MAX_INPUT_TOKENS,
                return_tensors="pt",
                add_special_tokens=True,
            )
            p_len = toks_prompt["input_ids"].shape[-1]
            batch_inputs.append(toks_full)
            batch_prompt_lens.append(p_len)

        pad_id = tokenizer.pad_token_id or 0
        input_ids = torch.nn.utils.rnn.pad_sequence(
            [bi["input_ids"].squeeze(0) for bi in batch_inputs],
            batch_first=True, padding_value=pad_id
        )
        attention_mask = torch.nn.utils.rnn.pad_sequence(
            [bi["attention_mask"].squeeze(0) for bi in batch_inputs],
            batch_first=True, padding_value=0
        )
        input_ids = input_ids.to(DEVICE)
        attention_mask = attention_mask.to(DEVICE)

        logits = model(input_ids=input_ids, attention_mask=attention_mask).logits
        logprobs = log_softmax(logits.to(torch.float32), dim=-1)

        for b in range(input_ids.size(0)):
            p_len = batch_prompt_lens[b]
            ids   = input_ids[b]
            masks = attention_mask[b]
            L = int(masks.sum().item())
            if p_len >= L:
                sums.append(torch.tensor(0.0)); counts.append(torch.tensor(0)); continue
            end  = min(L, p_len + SCORE_MAX_GEN_TOKENS)
            targ = ids[p_len:end]
            pred = logprobs[b, p_len-1:end-1]
            lp_sum = pred.gather(-1, targ.unsqueeze(-1)).squeeze(-1).sum()
            sums.append(lp_sum.cpu())
            counts.append(torch.tensor(targ.numel()))
    return torch.stack(sums, dim=0), torch.stack(counts, dim=0)

def _score_from_stats(lp_sum: torch.Tensor, tok_count: torch.Tensor) -> torch.Tensor:
    if SCORING_MODE == "mean":
        denom = torch.clamp(tok_count.to(torch.float32), min=1.0)
        return lp_sum.to(torch.float32) / denom
    elif SCORING_MODE == "lp_alpha":
        return lp_sum.to(torch.float32) + (LENGTH_PENALTY_ALPHA * tok_count.to(torch.float32))
    else:  # "sum"
        return lp_sum.to(torch.float32)

# =============================================================================
# DPO pieces
# =============================================================================
@dataclass
class DPOTune:
    beta: float = 1.0

def dpo_loss_and_wr(delta: torch.Tensor, beta: float) -> Tuple[float, float]:
    """
    Mean DPO loss and win rate for a fixed β.
    Loss = E[-log σ(βΔ)] = mean(softplus(-βΔ)), WR = mean(Δ > 0)*100
    """
    z = beta * delta
    loss = torch.nn.functional.softplus(-z).mean().item()
    wr = (delta > 0).to(torch.float32).mean().item() * 100.0
    return loss, wr

@torch.no_grad()
def cache_dpo_margins(
    records: List[Dict],
    policy_model,
    ref_model,
    policy_tok,
    ref_tok=None
) -> Tuple[torch.Tensor, int]:
    """
    Returns Δ tensor of shape [N_valid] where
      Δ_i = (s_pos_i - s_neg_i),
      s = log π(y|x) - log π_ref(y|x)  (ref optional)
    Valid pairs require ≥1 response token scored for BOTH sides (and for ref if used).
    """
    prompts_raw = [r["prompt"] for r in records]  # list[Union[str, List[dict]]], here list-of-messages
    prompt_texts = render_prompt_list(policy_tok, prompts_raw)

    resps_ch = [r["chosen"] for r in records]
    resps_rj = [r["rejected"] for r in records]

    # Policy (own tokenizer)
    pol_sum_ch, pol_cnt_ch = sequence_logprob_stats_text(policy_model, policy_tok, prompt_texts, resps_ch)
    pol_sum_rj, pol_cnt_rj = sequence_logprob_stats_text(policy_model, policy_tok, prompt_texts, resps_rj)
    pol_score_ch = _score_from_stats(pol_sum_ch, pol_cnt_ch)
    pol_score_rj = _score_from_stats(pol_sum_rj, pol_cnt_rj)

    # Reference (optional; own tokenizer; same prompt_texts)
    if ref_model is None:
        ref_score_ch = torch.zeros_like(pol_score_ch); ref_cnt_ch = torch.ones_like(pol_cnt_ch)
        ref_score_rj = torch.zeros_like(pol_score_rj); ref_cnt_rj = torch.ones_like(pol_cnt_rj)
    else:
        if ref_tok is None:
            raise ValueError("ref_model provided without ref_tok")
        ref_sum_ch, ref_cnt_ch = sequence_logprob_stats_text(ref_model, ref_tok, prompt_texts, resps_ch)
        ref_sum_rj, ref_cnt_rj = sequence_logprob_stats_text(ref_model, ref_tok, prompt_texts, resps_rj)
        ref_score_ch = _score_from_stats(ref_sum_ch, ref_cnt_ch)
        ref_score_rj = _score_from_stats(ref_sum_rj, ref_cnt_rj)

    # Validity mask: both sides non-truncated for both models (if ref used)
    valid_ch = (pol_cnt_ch > 0) & (ref_cnt_ch > 0)
    valid_rj = (pol_cnt_rj > 0) & (ref_cnt_rj > 0)
    valid_idx = [i for i in range(len(records)) if valid_ch[i].item() and valid_rj[i].item()]
    if not valid_idx:
        raise RuntimeError("No valid (non-truncated) pairs. Increase token limits, expand SPLIT, or lower MIN_GAP.")

    # s = log π - log π_ref (ref=0 if disabled); Δ = s_pos - s_neg
    s_pos = pol_score_ch[valid_idx] - ref_score_ch[valid_idx]
    s_neg = pol_score_rj[valid_idx] - ref_score_rj[valid_idx]
    delta = (s_pos - s_neg).to(torch.float32)  # [N_valid]
    return delta, len(valid_idx)

def binom_ci_95(pct: float, N: int) -> Tuple[float, float]:
    p = pct / 100.0
    se = math.sqrt(p * (1 - p) / max(N, 1))
    lo = max(0.0, 100.0 * (p - 1.96 * se))
    hi = min(100.0, 100.0 * (p + 1.96 * se))
    return lo, hi

def tune_beta_dpo(
    records: List[Dict], policy_model, ref_model, policy_tok, ref_tok,
    betas: Iterable[float], split_name: str = "TRAIN"
) -> Tuple[List[Tuple[float,float,float,float]], DPOTune]:
    """
    Returns (trials, best), where each trial is (beta, loss, wr, N).
    Select β minimizing mean DPO loss on TRAIN (ties broken by higher WR).
    """
    delta, N = cache_dpo_margins(records, policy_model, ref_model, policy_tok, ref_tok)
    trials = []
    best = (float("inf"), -1.0, None)  # (loss, wr, beta)
    for beta in betas:
        loss, wr = dpo_loss_and_wr(delta, beta)
        trials.append((beta, loss, wr, N))
        lo, hi = binom_ci_95(wr, N)
        print(f"[TUNE/{split_name} DPO] beta={beta:.3g} | loss={loss:.4f} | WR={wr:.2f}% [{lo:.1f}, {hi:.1f}]")
        if (loss < best[0]) or (abs(loss - best[0]) < 1e-6 and wr > best[1]):
            best = (loss, wr, beta)

    best_beta = best[2]
    print(f"\n[DPO Tuning ({split_name})] Best beta={best_beta:.3g} "
          f"(loss={best[0]:.4f}, WR={best[1]:.2f}%) on {N} pairs")
    return trials, DPOTune(beta=best_beta)

# =============================================================================
# Main
# =============================================================================
def main():
    print(f"Device: {DEVICE} | Dtype: {DTYPE}")
    print("Loading policy model ...")
    tok_policy, policy = load_causal_lm(MODEL_NAME, HF_TOKEN)

    tok_ref, ref = None, None
    if REF_MODEL_NAME:
        print("Loading reference model ...")
        tok_ref, ref = load_causal_lm(REF_MODEL_NAME, HF_TOKEN)

    print(f"Loading dataset: {DATASET} [{SPLIT}] config={DATASET_CONFIG} ...")
    ds = load_dataset(DATASET, name=DATASET_CONFIG, split=SPLIT) if DATASET_CONFIG else load_dataset(DATASET, split=SPLIT)

    print("Columns:", ds.column_names)
    if len(ds) > 0:
        sample0 = ds[0]
        print("Row[0] keys:", list(sample0.keys()))
        preview = {k: (str(sample0[k])[:200] + "…") for k in sample0.keys()}
        print("Row[0] preview:", preview)

    # Adapt to pairs with min gap
    adapted: List[Dict] = []
    for rec in ds:
        a = adapt_capybara(rec, min_gap=MIN_GAP)
        if a is not None:
            adapted.append(a)
    print(f"Adapted examples (gap ≥ {MIN_GAP}): {len(adapted)}")
    if not adapted:
        raise RuntimeError("No valid examples after adaptation. Try MIN_GAP=0.0, expand SPLIT, or inspect data.")

    # Deterministic split
    random.Random(RNG_SEED).shuffle(adapted)
    n_all   = len(adapted)
    n_train = max(1, int(TRAIN_FRAC * n_all))
    train_recs = adapted[:n_train]
    test_recs  = adapted[n_train:]
    print(f"Split sizes: TRAIN={len(train_recs)} | TEST={len(test_recs)}")

    # Tune β on TRAIN (min DPO loss; report WR too)
    if DO_TUNE:
        print("\nTuning beta on TRAIN (DPO)...")
        _tries, best = tune_beta_dpo(train_recs, policy, ref, tok_policy, tok_ref, betas=BETAS, split_name="TRAIN")
        beta = best.beta
    else:
        beta = 1.0

    # Evaluate on TEST
    eval_set = test_recs if len(test_recs) > 0 else train_recs
    if len(test_recs) == 0:
        print("\nWARNING: TEST set is empty (small SPLIT/TRAIN_FRAC). Using TRAIN as proxy sanity check.")

    print("\nScoring TEST (DPO metrics)...")
    delta_te, N_te = cache_dpo_margins(eval_set, policy, ref, tok_policy, tok_ref)
    loss_te, wr_te = dpo_loss_and_wr(delta_te, beta)
    lo_te, hi_te = binom_ci_95(wr_te, N_te)
    print(f"[DPO EVAL (TEST)] WR={wr_te:.2f}% [{lo_te:.1f}, {hi_te:.1f}] | loss={loss_te:.4f} | beta={beta:.3g} | N={N_te}")

if __name__ == "__main__":
    main()


Device: cuda | Dtype: torch.bfloat16
Loading policy model ...


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loading reference model ...


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loading dataset: argilla/Capybara-Preferences [train[:30%]] config=None ...
Columns: ['source', 'chosen', 'chosen_rating', 'chosen_model', 'rejected', 'rejected_rating', 'rejected_model']
Row[0] keys: ['source', 'chosen', 'chosen_rating', 'chosen_model', 'rejected', 'rejected_rating', 'rejected_model']
Row[0] preview: {'source': 'Airoboros…', 'chosen': '[{\'content\': \'The setting is an otherworldly, yet eerily familiar, metropolis known as "Zephyria." It\\\'s a city suspended in the ether, floating amidst nebulous clouds of cosmic dust. The architecture…', 'chosen_rating': '5…', 'chosen_model': 'teknium/OpenHermes-2.5-Mistral-7B…', 'rejected': '[{\'content\': \'The setting is an otherworldly, yet eerily familiar, metropolis known as "Zephyria." It\\\'s a city suspended in the ether, floating amidst nebulous clouds of cosmic dust. The architecture…', 'rejected_rating': '4…', 'rejected_model': 'gpt-4-1106-preview…'}
Adapted examples (gap ≥ 0.0): 4621
Split sizes: TRAIN=3696 | TEST=925
