In [None]:
# bp_llm_ultrafeedback_eval.py
# Evaluate BP-LLM (unary BP classifier with JJ bound) win rate on UltraFeedback
# with Llama-3.2-3B policy priors.

import math
import dataclasses
from dataclasses import dataclass
from typing import List, Tuple, Dict, Optional

import torch
from torch.nn.functional import log_softmax
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM

# -----------------------------
# Config
# -----------------------------

HF_TOKEN = " "  # your Read token


In [None]:
# bp_llm_ultrafeedback_eval_qwen25_3b.py
# UltraFeedback BCO baseline (leakage-safe) for Qwen2.5-3B

import os
import math
import random
from dataclasses import dataclass
from typing import List, Dict, Optional, Tuple, Iterable

import torch
from torch.nn.functional import log_softmax
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM

# =============================================================================
# Config
# =============================================================================
HF_TOKEN = os.environ.get("HF_TOKEN")

# Use Qwen 2.5 3B as policy; optional Base as reference
MODEL_NAME      = "Qwen/Qwen2.5-3B-Instruct"
REF_MODEL_NAME  = "Qwen/Qwen2.5-3B"   # None to disable reference subtraction

DATASET         = "openbmb/UltraFeedback"
DATASET_CONFIG  = None
SPLIT           = "train[:10%]"  # iterate small; use "train" later

# Inference/scoring limits
MAX_INPUT_TOKENS     = 1024
MAX_GEN_TOKENS       = 512
SCORE_MAX_GEN_TOKENS = 10000
BATCH_SIZE           = 4

# Chat-templating + scoring
FORCE_CHAT_TEMPLATE  = True
SCORING_MODE         = "mean"      # {"mean", "sum", "lp_alpha"}
LENGTH_PENALTY_ALPHA = 0.0         # only for SCORING_MODE == "lp_alpha"

# Pair filtering
MIN_GAP              = 1.0

# Train / test split
RNG_SEED   = 42
TRAIN_FRAC = 0.8

# Grid tuning (TRAIN only)
DO_TUNE = True
BETAS   = [0.5, 1.0, 2.0]
DELTAS  = ["bco", 0.0]  # try BCO shift and a constant zero shift

# Device / dtype
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE  = torch.bfloat16 if torch.cuda.is_available() else torch.float32
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.set_grad_enabled(False)


# =============================================================================
# UF helpers
# =============================================================================
def _extract_text(c: Dict) -> Optional[str]:
    if not isinstance(c, dict):
        return None
    for k in ("text", "response", "output", "completion", "content"):
        v = c.get(k)
        if isinstance(v, str) and v.strip():
            return v
    res = c.get("result")
    if isinstance(res, dict):
        v = res.get("text") or res.get("response") or res.get("output")
        if isinstance(v, str) and v.strip():
            return v
    return None

def _to_float(x) -> Optional[float]:
    try:
        return float(x)
    except Exception:
        return None

def _extract_score(c: Dict) -> Optional[float]:
    if not isinstance(c, dict):
        return None
    ann = c.get("annotations")
    if isinstance(ann, dict):
        help_ = ann.get("helpfulness")
        if isinstance(help_, dict):
            r = help_.get("Rating") or help_.get("rating") or help_.get("score")
            r = _to_float(r)
            if r is not None:
                return r
        for key in ("overall", "quality", "correctness", "honesty", "safety"):
            sub = ann.get(key)
            if isinstance(sub, dict):
                r = sub.get("Rating") or sub.get("rating") or sub.get("score")
                r = _to_float(r)
                if r is not None:
                    return r
        for _, v in ann.items():
            fv = _to_float(v)
            if fv is not None:
                return fv
    for k in ("score", "rating", "rank"):
        fv = _to_float(c.get(k))
        if fv is not None:
            return fv
    return None

def adapt_openbmb_ultrafeedback(record: Dict, min_gap: float = MIN_GAP) -> Optional[Dict]:
    """
    Build (prompt, chosen, rejected) as top-vs-bottom by score, with a minimum score gap.
    Drops near-ties to reduce label noise.
    """
    prompt = record.get("instruction")
    if not isinstance(prompt, str) or not prompt.strip():
        return None
    comps = record.get("completions")
    if not isinstance(comps, list) or len(comps) < 2:
        return None

    pairs = []
    for c in comps:
        text = _extract_text(c)
        score = _extract_score(c)
        if isinstance(text, str) and text.strip() and (score is not None):
            pairs.append((text, float(score)))
    if len(pairs) < 2:
        return None

    pairs.sort(key=lambda t: t[1])  # low ... high
    lo_txt, lo_s = pairs[0]
    hi_txt, hi_s = pairs[-1]
    if (hi_s - lo_s) < min_gap:
        return None

    return {"prompt": prompt, "chosen": hi_txt, "rejected": lo_txt}


# =============================================================================
# Model
# =============================================================================
def load_causal_lm(model_id: str, token: Optional[str], dtype=DTYPE, device=DEVICE):
    # Qwen tokenizers sometimes need use_fast=False; try fast then fallback.
    try:
        tok = AutoTokenizer.from_pretrained(model_id, token=token, use_fast=True)
    except Exception:
        tok = AutoTokenizer.from_pretrained(model_id, token=token, use_fast=False)

    # Ensure pad token is set (Qwen uses eos as pad for causal LM)
    if tok.pad_token_id is None:
        tok.pad_token = tok.eos_token

    # trust_remote_code is usually not needed for Qwen2.5, but allow override via env
    trust_remote = os.environ.get("TRUST_REMOTE_CODE", "0") == "1"
    model = AutoModelForCausalLM.from_pretrained(
        model_id, token=token, torch_dtype=dtype, trust_remote_code=trust_remote
    )
    model = model.to(device)
    model.eval()
    return tok, model


# =============================================================================
# Encoding + scoring
# =============================================================================
def _apply_chat_prefix(tokenizer, prompt: str) -> Optional[str]:
    if not FORCE_CHAT_TEMPLATE or not hasattr(tokenizer, "apply_chat_template"):
        return None
    try:
        msgs = [
            {"role": "user", "content": prompt.strip()},
            {"role": "assistant", "content": ""},  # add_generation_prompt=True will append assistant prefix
        ]
        return tokenizer.apply_chat_template(
            msgs, tokenize=False, add_generation_prompt=True
        )
    except Exception:
        return None

def encode_pair(tokenizer, prompt: str, response: str):
    """
    Returns (input_ids, attention_mask, prompt_len) for the concatenated prompt+response.
    Uses chat template when available (Qwen provides one).
    """
    chat_prefix = _apply_chat_prefix(tokenizer, prompt)
    if chat_prefix is not None:
        prompt_text = chat_prefix
        full_text   = chat_prefix + response.strip()
    else:
        prompt_text = prompt.strip()
        full_text   = prompt_text + ("\n" if not prompt_text.endswith("\n") else "") + response.strip()

    toks_full = tokenizer(
        full_text,
        truncation=True,
        max_length=min(MAX_INPUT_TOKENS + MAX_GEN_TOKENS, tokenizer.model_max_length),
        return_tensors="pt",
        add_special_tokens=True,
    )
    toks_prompt = tokenizer(
        prompt_text,
        truncation=True,
        max_length=MAX_INPUT_TOKENS,
        return_tensors="pt",
        add_special_tokens=True,
    )
    prompt_len = toks_prompt["input_ids"].shape[-1]
    return toks_full, prompt_len

@torch.no_grad()
def sequence_logprob_stats(
    model,
    tokenizer,
    prompts: List[str],
    responses: List[str],
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Returns two tensors of shape [B]: (lp_sum, tok_count)
    – lp_sum is the sum of log-probs over scored response tokens
    – tok_count is the number of response tokens scored
    If a response is fully truncated (no tokens past the prompt), tok_count=0
    """
    sums, counts = [], []
    for i in range(0, len(prompts), BATCH_SIZE):
        p_batch = prompts[i:i+BATCH_SIZE]
        r_batch = responses[i:i+BATCH_SIZE]

        batch_inputs, batch_prompt_lens = [], []
        for p, r in zip(p_batch, r_batch):
            toks, p_len = encode_pair(tokenizer, p, r)
            batch_inputs.append(toks)
            batch_prompt_lens.append(p_len)

        pad_id = tokenizer.pad_token_id or 0
        input_ids = torch.nn.utils.rnn.pad_sequence(
            [bi["input_ids"].squeeze(0) for bi in batch_inputs],
            batch_first=True, padding_value=pad_id
        )
        attention_mask = torch.nn.utils.rnn.pad_sequence(
            [bi["attention_mask"].squeeze(0) for bi in batch_inputs],
            batch_first=True, padding_value=0
        )
        input_ids = input_ids.to(DEVICE)
        attention_mask = attention_mask.to(DEVICE)

        logits = model(input_ids=input_ids, attention_mask=attention_mask).logits
        logprobs = log_softmax(logits.to(torch.float32), dim=-1)

        for b in range(input_ids.size(0)):
            p_len = batch_prompt_lens[b]
            ids   = input_ids[b]
            masks = attention_mask[b]
            L = int(masks.sum().item())
            if p_len >= L:
                sums.append(torch.tensor(0.0)); counts.append(torch.tensor(0)); continue
            end  = min(L, p_len + SCORE_MAX_GEN_TOKENS)
            targ = ids[p_len:end]
            pred = logprobs[b, p_len-1:end-1]
            lp_sum = pred.gather(-1, targ.unsqueeze(-1)).squeeze(-1).sum()
            sums.append(lp_sum.cpu())
            counts.append(torch.tensor(targ.numel()))
    return torch.stack(sums, dim=0), torch.stack(counts, dim=0)

def _score_from_stats(lp_sum: torch.Tensor, tok_count: torch.Tensor) -> torch.Tensor:
    if SCORING_MODE == "mean":
        denom = torch.clamp(tok_count.to(torch.float32), min=1.0)
        return lp_sum.to(torch.float32) / denom
    elif SCORING_MODE == "lp_alpha":
        return lp_sum.to(torch.float32) + (LENGTH_PENALTY_ALPHA * tok_count.to(torch.float32))
    else:  # "sum"
        return lp_sum.to(torch.float32)


# =============================================================================
# BCO pieces
# =============================================================================
@dataclass
class BCOParams:
    beta: float  = 1.0
    delta: float = 0.0

@torch.no_grad()
def cache_policy_and_ref_scores(records: List[Dict], model, ref_model, tokenizer) -> Tuple[torch.Tensor, int]:
    """
    Build base scores s = (policy - ref) for valid pairs only.
    A pair is valid if BOTH sides have >= 1 scored response token
    (not truncated past the prompt) for BOTH models if ref is used.
    Returns:
        base_s: tensor of shape [2N_valid] (first N are chosen, next N are rejected)
        N_valid: number of valid pairs.
    """
    prompts   = [r["prompt"] for r in records for _ in (0, 1)]
    responses = [r["chosen"] for r in records] + [r["rejected"] for r in records]

    pol_sum, pol_cnt = sequence_logprob_stats(model, tokenizer, prompts, responses)
    pol_score = _score_from_stats(pol_sum, pol_cnt)

    if ref_model is None:
        ref_score = torch.zeros_like(pol_score)
        ref_cnt   = torch.ones_like(pol_cnt)  # pretend "valid"
    else:
        ref_sum, ref_cnt = sequence_logprob_stats(ref_model, tokenizer, prompts, responses)
        ref_score = _score_from_stats(ref_sum, ref_cnt)

    valid_mask = (pol_cnt > 0) & (ref_cnt > 0)
    valid_pairs = []
    for i in range(len(records)):
        i_ch = i
        i_rj = i + len(records)
        if valid_mask[i_ch].item() and valid_mask[i_rj].item():
            valid_pairs.append(i)

    if len(valid_pairs) == 0:
        raise RuntimeError("No valid (non-truncated) pairs after scoring. Increase limits or lower MIN_GAP.")

    chosen_vals, rejected_vals = [], []
    for i in valid_pairs:
        i_ch = i
        i_rj = i + len(records)
        chosen_vals.append(float(pol_score[i_ch] - ref_score[i_ch]))
        rejected_vals.append(float(pol_score[i_rj] - ref_score[i_rj]))

    base_s = torch.tensor(chosen_vals + rejected_vals, dtype=torch.float32)
    N = len(valid_pairs)
    return base_s, N

def estimate_delta_bco(base_s: torch.Tensor, beta: float, N: int) -> float:
    """δ_BCO = 0.5 ( E_pos[β s] + E_neg[β s] )."""
    r_pos = beta * base_s[:N]
    r_neg = beta * base_s[N:]
    return 0.5 * (float(r_pos.mean()) + float(r_neg.mean()))

def bco_win_rate(base_s: torch.Tensor, N: int, params: BCOParams) -> float:
    """
    BCO decision rule (label-free): predict chosen if β s_pos - δ > β s_neg - δ  ⇔  s_pos > s_neg.
    δ cancels out for pairwise comparison, but we still compute it for completeness/reporting.
    """
    mu = (params.beta * base_s) - params.delta
    mu = mu.tolist()
    correct = 0
    for i in range(N):
        if mu[i] > mu[i + N]:
            correct += 1
    return 100.0 * correct / N

def binom_ci_95(pct: float, N: int) -> Tuple[float, float]:
    p = pct / 100.0
    se = math.sqrt(p * (1 - p) / max(N, 1))
    lo = max(0.0, 100.0 * (p - 1.96 * se))
    hi = min(100.0, 100.0 * (p + 1.96 * se))
    return lo, hi

def grid_search_train_bco(
    records: List[Dict], model, ref_model, tokenizer,
    betas: Iterable[float], deltas: Iterable[Optional[str]]
):
    """
    Tune β and δ on TRAIN only. δ can be 'bco' (estimated) or a numeric constant.
    """
    base_s, N = cache_policy_and_ref_scores(records, model, ref_model, tokenizer)

    tried = []
    best = (-1.0, None)  # (wr, params)
    for beta in betas:
        for delta_spec in deltas:
            if delta_spec == "bco":
                delta = estimate_delta_bco(base_s, beta, N)
            else:
                delta = float(delta_spec)
            params = BCOParams(beta=beta, delta=delta)
            wr = bco_win_rate(base_s, N, params)
            lo, hi = binom_ci_95(wr, N)
            tried.append((wr, lo, hi, params))
            if wr > best[0]:
                best = (wr, params)

            print(f"[TUNE/TRAIN BCO] WR={wr:.2f}% [{lo:.1f}, {hi:.1f}] "
                  f"| beta={beta} delta={'bco' if delta_spec=='bco' else f'{delta:.4f}'}")

    best_wr, best_params = best
    lo, hi = binom_ci_95(best_wr, N)
    print(f"\n[BCO Tuning (TRAIN)] Best WR={best_wr:.2f}% [{lo:.1f}, {hi:.1f}] with "
          f"beta={best_params.beta} delta={best_params.delta:.4f}")
    return tried, best_params


# =============================================================================
# Main
# =============================================================================
def main():
    print(f"Device: {DEVICE} | Dtype: {DTYPE}")
    print("Loading policy model ...")
    tok, policy = load_causal_lm(MODEL_NAME, HF_TOKEN)
    ref = None
    if REF_MODEL_NAME:
        print("Loading reference model ...")
        _, ref = load_causal_lm(REF_MODEL_NAME, HF_TOKEN)

    print(f"Loading dataset: {DATASET} [{SPLIT}] config={DATASET_CONFIG} ...")
    ds = load_dataset(DATASET, name=DATASET_CONFIG, split=SPLIT) if DATASET_CONFIG else load_dataset(DATASET, split=SPLIT)

    # Adapt to pairs with min gap
    adapted: List[Dict] = []
    for rec in ds:
        a = adapt_openbmb_ultrafeedback(rec, min_gap=MIN_GAP)
        if a is not None:
            adapted.append(a)
    print(f"Adapted examples (gap ≥ {MIN_GAP}): {len(adapted)}")
    if not adapted:
        raise RuntimeError("No valid examples after adaptation. Lower MIN_GAP or inspect data.")

    # Train/Test split
    random.Random(RNG_SEED).shuffle(adapted)
    n_all   = len(adapted)
    n_train = max(1, int(TRAIN_FRAC * n_all))
    train_recs = adapted[:n_train]
    test_recs  = adapted[n_train:]
    print(f"Split sizes: TRAIN={len(train_recs)} | TEST={len(test_recs)}")

    # Cache TRAIN scores and baseline WR with default β=1, δ=0
    print("\nScoring TRAIN...")
    base_s_tr, N_tr = cache_policy_and_ref_scores(train_recs, policy, ref, tok)
    wr_tr_base = bco_win_rate(base_s_tr, N_tr, BCOParams(beta=1.0, delta=0.0))
    lo_trb, hi_trb = binom_ci_95(wr_tr_base, N_tr)
    print(f"[TRAIN BCO baseline] WR={wr_tr_base:.2f}% [{lo_trb:.1f}, {hi_trb:.1f}] on {N_tr} valid pairs")

    # Tune β, δ (TRAIN only)
    if DO_TUNE:
        print("\nTuning on TRAIN (BCO)...")
        _tries, best_params = grid_search_train_bco(
            train_recs, policy, ref, tok, betas=BETAS, deltas=DELTAS
        )
    else:
        # Optionally estimate δ via BCO at default β
        best_params = BCOParams(beta=1.0, delta=estimate_delta_bco(base_s_tr, 1.0, N_tr))

    # TEST evaluation (label-free)
    eval_set = test_recs if len(test_recs) > 0 else train_recs
    if len(test_recs) == 0:
        print("\nWARNING: TEST set is empty (small SPLIT/TRAIN_FRAC). Using TRAIN as a proxy sanity check.")

    print("\nScoring TEST...")
    base_s_te, N_te = cache_policy_and_ref_scores(eval_set, policy, ref, tok)
    wr_te = bco_win_rate(base_s_te, N_te, best_params)
    lo_te, hi_te = binom_ci_95(wr_te, N_te)
    print(f"[BCO EVAL (TEST, no labels)] WR={wr_te:.2f}% [{lo_te:.1f}, {hi_te:.1f}] "
          f"| beta={best_params.beta} delta={best_params.delta:.4f}")

if __name__ == "__main__":
    main()


Device: cuda | Dtype: torch.bfloat16
Loading policy model ...


tokenizer_config.json: 0.00B [00:00, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

config.json:   0%|          | 0.00/661 [00:00<?, ?B/s]

`torch_dtype` is deprecated! Use `dtype` instead!


model.safetensors.index.json: 0.00B [00:00, ?B/s]

Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/3.97G [00:00<?, ?B/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/2.20G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/242 [00:00<?, ?B/s]

Loading reference model ...


tokenizer_config.json: 0.00B [00:00, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

config.json:   0%|          | 0.00/683 [00:00<?, ?B/s]

model.safetensors.index.json: 0.00B [00:00, ?B/s]

Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/2.20G [00:00<?, ?B/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/3.97G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/138 [00:00<?, ?B/s]

Loading dataset: openbmb/UltraFeedback [train[:10%]] config=None ...


README.md: 0.00B [00:00, ?B/s]

evol_instruct.jsonl:   0%|          | 0.00/168M [00:00<?, ?B/s]

false_qa.jsonl:   0%|          | 0.00/25.9M [00:00<?, ?B/s]

flan.jsonl:   0%|          | 0.00/240M [00:00<?, ?B/s]

sharegpt.jsonl:   0%|          | 0.00/313M [00:00<?, ?B/s]

truthful_qa.jsonl: 0.00B [00:00, ?B/s]

ultrachat.jsonl:   0%|          | 0.00/182M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/63967 [00:00<?, ? examples/s]

Adapted examples (gap ≥ 1.0): 6205
Split sizes: TRAIN=4964 | TEST=1241

Scoring TRAIN...
[TRAIN BCO baseline] WR=78.08% [76.9, 79.2] on 4964 valid pairs

Tuning on TRAIN (BCO)...
[TUNE/TRAIN BCO] WR=78.08% [76.9, 79.2] | beta=0.5 delta=bco
[TUNE/TRAIN BCO] WR=78.08% [76.9, 79.2] | beta=0.5 delta=0.0000
[TUNE/TRAIN BCO] WR=78.08% [76.9, 79.2] | beta=1.0 delta=bco
[TUNE/TRAIN BCO] WR=78.08% [76.9, 79.2] | beta=1.0 delta=0.0000
[TUNE/TRAIN BCO] WR=78.08% [76.9, 79.2] | beta=2.0 delta=bco
[TUNE/TRAIN BCO] WR=78.08% [76.9, 79.2] | beta=2.0 delta=0.0000

[BCO Tuning (TRAIN)] Best WR=78.08% [76.9, 79.2] with beta=0.5 delta=-0.2492

Scoring TEST...
[BCO EVAL (TEST, no labels)] WR=79.45% [77.2, 81.7] | beta=0.5 delta=-0.2492


In [None]:
# bp_llm_ultrafeedback_eval_qwen25_7b.py
# UltraFeedback BCO baseline (leakage-safe) for Qwen2.5-3B

import os
import math
import random
from dataclasses import dataclass
from typing import List, Dict, Optional, Tuple, Iterable

import torch
from torch.nn.functional import log_softmax
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM

# =============================================================================
# Config
# =============================================================================
# HF_TOKEN = os.environ.get("HF_TOKEN")

# Use Qwen 2.5 3B as policy; optional Base as reference
MODEL_NAME      = "Qwen/Qwen2.5-7B-Instruct"
REF_MODEL_NAME  = "Qwen/Qwen2.5-7B"   # None to disable reference subtraction

DATASET         = "openbmb/UltraFeedback"
DATASET_CONFIG  = None
SPLIT           = "train[:10%]"  # iterate small; use "train" later

# Inference/scoring limits
MAX_INPUT_TOKENS     = 1024
MAX_GEN_TOKENS       = 512
SCORE_MAX_GEN_TOKENS = 10000
BATCH_SIZE           = 4

# Chat-templating + scoring
FORCE_CHAT_TEMPLATE  = True
SCORING_MODE         = "mean"      # {"mean", "sum", "lp_alpha"}
LENGTH_PENALTY_ALPHA = 0.0         # only for SCORING_MODE == "lp_alpha"

# Pair filtering
MIN_GAP              = 1.0

# Train / test split
RNG_SEED   = 42
TRAIN_FRAC = 0.8

# Grid tuning (TRAIN only)
DO_TUNE = True
BETAS   = [0.5, 1.0, 2.0]
DELTAS  = ["bco", 0.0]  # try BCO shift and a constant zero shift

# Device / dtype
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE  = torch.bfloat16 if torch.cuda.is_available() else torch.float32
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.set_grad_enabled(False)


# =============================================================================
# UF helpers
# =============================================================================
def _extract_text(c: Dict) -> Optional[str]:
    if not isinstance(c, dict):
        return None
    for k in ("text", "response", "output", "completion", "content"):
        v = c.get(k)
        if isinstance(v, str) and v.strip():
            return v
    res = c.get("result")
    if isinstance(res, dict):
        v = res.get("text") or res.get("response") or res.get("output")
        if isinstance(v, str) and v.strip():
            return v
    return None

def _to_float(x) -> Optional[float]:
    try:
        return float(x)
    except Exception:
        return None

def _extract_score(c: Dict) -> Optional[float]:
    if not isinstance(c, dict):
        return None
    ann = c.get("annotations")
    if isinstance(ann, dict):
        help_ = ann.get("helpfulness")
        if isinstance(help_, dict):
            r = help_.get("Rating") or help_.get("rating") or help_.get("score")
            r = _to_float(r)
            if r is not None:
                return r
        for key in ("overall", "quality", "correctness", "honesty", "safety"):
            sub = ann.get(key)
            if isinstance(sub, dict):
                r = sub.get("Rating") or sub.get("rating") or sub.get("score")
                r = _to_float(r)
                if r is not None:
                    return r
        for _, v in ann.items():
            fv = _to_float(v)
            if fv is not None:
                return fv
    for k in ("score", "rating", "rank"):
        fv = _to_float(c.get(k))
        if fv is not None:
            return fv
    return None

def adapt_openbmb_ultrafeedback(record: Dict, min_gap: float = MIN_GAP) -> Optional[Dict]:
    """
    Build (prompt, chosen, rejected) as top-vs-bottom by score, with a minimum score gap.
    Drops near-ties to reduce label noise.
    """
    prompt = record.get("instruction")
    if not isinstance(prompt, str) or not prompt.strip():
        return None
    comps = record.get("completions")
    if not isinstance(comps, list) or len(comps) < 2:
        return None

    pairs = []
    for c in comps:
        text = _extract_text(c)
        score = _extract_score(c)
        if isinstance(text, str) and text.strip() and (score is not None):
            pairs.append((text, float(score)))
    if len(pairs) < 2:
        return None

    pairs.sort(key=lambda t: t[1])  # low ... high
    lo_txt, lo_s = pairs[0]
    hi_txt, hi_s = pairs[-1]
    if (hi_s - lo_s) < min_gap:
        return None

    return {"prompt": prompt, "chosen": hi_txt, "rejected": lo_txt}


# =============================================================================
# Model
# =============================================================================
def load_causal_lm(model_id: str, token: Optional[str], dtype=DTYPE, device=DEVICE):
    # Qwen tokenizers sometimes need use_fast=False; try fast then fallback.
    try:
        tok = AutoTokenizer.from_pretrained(model_id, token=token, use_fast=True)
    except Exception:
        tok = AutoTokenizer.from_pretrained(model_id, token=token, use_fast=False)

    # Ensure pad token is set (Qwen uses eos as pad for causal LM)
    if tok.pad_token_id is None:
        tok.pad_token = tok.eos_token

    # trust_remote_code is usually not needed for Qwen2.5, but allow override via env
    trust_remote = os.environ.get("TRUST_REMOTE_CODE", "0") == "1"
    model = AutoModelForCausalLM.from_pretrained(
        model_id, token=token, torch_dtype=dtype, trust_remote_code=trust_remote
    )
    model = model.to(device)
    model.eval()
    return tok, model


# =============================================================================
# Encoding + scoring
# =============================================================================
def _apply_chat_prefix(tokenizer, prompt: str) -> Optional[str]:
    if not FORCE_CHAT_TEMPLATE or not hasattr(tokenizer, "apply_chat_template"):
        return None
    try:
        msgs = [
            {"role": "user", "content": prompt.strip()},
            {"role": "assistant", "content": ""},  # add_generation_prompt=True will append assistant prefix
        ]
        return tokenizer.apply_chat_template(
            msgs, tokenize=False, add_generation_prompt=True
        )
    except Exception:
        return None

def encode_pair(tokenizer, prompt: str, response: str):
    """
    Returns (input_ids, attention_mask, prompt_len) for the concatenated prompt+response.
    Uses chat template when available (Qwen provides one).
    """
    chat_prefix = _apply_chat_prefix(tokenizer, prompt)
    if chat_prefix is not None:
        prompt_text = chat_prefix
        full_text   = chat_prefix + response.strip()
    else:
        prompt_text = prompt.strip()
        full_text   = prompt_text + ("\n" if not prompt_text.endswith("\n") else "") + response.strip()

    toks_full = tokenizer(
        full_text,
        truncation=True,
        max_length=min(MAX_INPUT_TOKENS + MAX_GEN_TOKENS, tokenizer.model_max_length),
        return_tensors="pt",
        add_special_tokens=True,
    )
    toks_prompt = tokenizer(
        prompt_text,
        truncation=True,
        max_length=MAX_INPUT_TOKENS,
        return_tensors="pt",
        add_special_tokens=True,
    )
    prompt_len = toks_prompt["input_ids"].shape[-1]
    return toks_full, prompt_len

@torch.no_grad()
def sequence_logprob_stats(
    model,
    tokenizer,
    prompts: List[str],
    responses: List[str],
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Returns two tensors of shape [B]: (lp_sum, tok_count)
    – lp_sum is the sum of log-probs over scored response tokens
    – tok_count is the number of response tokens scored
    If a response is fully truncated (no tokens past the prompt), tok_count=0
    """
    sums, counts = [], []
    for i in range(0, len(prompts), BATCH_SIZE):
        p_batch = prompts[i:i+BATCH_SIZE]
        r_batch = responses[i:i+BATCH_SIZE]

        batch_inputs, batch_prompt_lens = [], []
        for p, r in zip(p_batch, r_batch):
            toks, p_len = encode_pair(tokenizer, p, r)
            batch_inputs.append(toks)
            batch_prompt_lens.append(p_len)

        pad_id = tokenizer.pad_token_id or 0
        input_ids = torch.nn.utils.rnn.pad_sequence(
            [bi["input_ids"].squeeze(0) for bi in batch_inputs],
            batch_first=True, padding_value=pad_id
        )
        attention_mask = torch.nn.utils.rnn.pad_sequence(
            [bi["attention_mask"].squeeze(0) for bi in batch_inputs],
            batch_first=True, padding_value=0
        )
        input_ids = input_ids.to(DEVICE)
        attention_mask = attention_mask.to(DEVICE)

        logits = model(input_ids=input_ids, attention_mask=attention_mask).logits
        logprobs = log_softmax(logits.to(torch.float32), dim=-1)

        for b in range(input_ids.size(0)):
            p_len = batch_prompt_lens[b]
            ids   = input_ids[b]
            masks = attention_mask[b]
            L = int(masks.sum().item())
            if p_len >= L:
                sums.append(torch.tensor(0.0)); counts.append(torch.tensor(0)); continue
            end  = min(L, p_len + SCORE_MAX_GEN_TOKENS)
            targ = ids[p_len:end]
            pred = logprobs[b, p_len-1:end-1]
            lp_sum = pred.gather(-1, targ.unsqueeze(-1)).squeeze(-1).sum()
            sums.append(lp_sum.cpu())
            counts.append(torch.tensor(targ.numel()))
    return torch.stack(sums, dim=0), torch.stack(counts, dim=0)

def _score_from_stats(lp_sum: torch.Tensor, tok_count: torch.Tensor) -> torch.Tensor:
    if SCORING_MODE == "mean":
        denom = torch.clamp(tok_count.to(torch.float32), min=1.0)
        return lp_sum.to(torch.float32) / denom
    elif SCORING_MODE == "lp_alpha":
        return lp_sum.to(torch.float32) + (LENGTH_PENALTY_ALPHA * tok_count.to(torch.float32))
    else:  # "sum"
        return lp_sum.to(torch.float32)


# =============================================================================
# BCO pieces
# =============================================================================
@dataclass
class BCOParams:
    beta: float  = 1.0
    delta: float = 0.0

@torch.no_grad()
def cache_policy_and_ref_scores(records: List[Dict], model, ref_model, tokenizer) -> Tuple[torch.Tensor, int]:
    """
    Build base scores s = (policy - ref) for valid pairs only.
    A pair is valid if BOTH sides have >= 1 scored response token
    (not truncated past the prompt) for BOTH models if ref is used.
    Returns:
        base_s: tensor of shape [2N_valid] (first N are chosen, next N are rejected)
        N_valid: number of valid pairs.
    """
    prompts   = [r["prompt"] for r in records for _ in (0, 1)]
    responses = [r["chosen"] for r in records] + [r["rejected"] for r in records]

    pol_sum, pol_cnt = sequence_logprob_stats(model, tokenizer, prompts, responses)
    pol_score = _score_from_stats(pol_sum, pol_cnt)

    if ref_model is None:
        ref_score = torch.zeros_like(pol_score)
        ref_cnt   = torch.ones_like(pol_cnt)  # pretend "valid"
    else:
        ref_sum, ref_cnt = sequence_logprob_stats(ref_model, tokenizer, prompts, responses)
        ref_score = _score_from_stats(ref_sum, ref_cnt)

    valid_mask = (pol_cnt > 0) & (ref_cnt > 0)
    valid_pairs = []
    for i in range(len(records)):
        i_ch = i
        i_rj = i + len(records)
        if valid_mask[i_ch].item() and valid_mask[i_rj].item():
            valid_pairs.append(i)

    if len(valid_pairs) == 0:
        raise RuntimeError("No valid (non-truncated) pairs after scoring. Increase limits or lower MIN_GAP.")

    chosen_vals, rejected_vals = [], []
    for i in valid_pairs:
        i_ch = i
        i_rj = i + len(records)
        chosen_vals.append(float(pol_score[i_ch] - ref_score[i_ch]))
        rejected_vals.append(float(pol_score[i_rj] - ref_score[i_rj]))

    base_s = torch.tensor(chosen_vals + rejected_vals, dtype=torch.float32)
    N = len(valid_pairs)
    return base_s, N

def estimate_delta_bco(base_s: torch.Tensor, beta: float, N: int) -> float:
    """δ_BCO = 0.5 ( E_pos[β s] + E_neg[β s] )."""
    r_pos = beta * base_s[:N]
    r_neg = beta * base_s[N:]
    return 0.5 * (float(r_pos.mean()) + float(r_neg.mean()))

def bco_win_rate(base_s: torch.Tensor, N: int, params: BCOParams) -> float:
    """
    BCO decision rule (label-free): predict chosen if β s_pos - δ > β s_neg - δ  ⇔  s_pos > s_neg.
    δ cancels out for pairwise comparison, but we still compute it for completeness/reporting.
    """
    mu = (params.beta * base_s) - params.delta
    mu = mu.tolist()
    correct = 0
    for i in range(N):
        if mu[i] > mu[i + N]:
            correct += 1
    return 100.0 * correct / N

def binom_ci_95(pct: float, N: int) -> Tuple[float, float]:
    p = pct / 100.0
    se = math.sqrt(p * (1 - p) / max(N, 1))
    lo = max(0.0, 100.0 * (p - 1.96 * se))
    hi = min(100.0, 100.0 * (p + 1.96 * se))
    return lo, hi

def grid_search_train_bco(
    records: List[Dict], model, ref_model, tokenizer,
    betas: Iterable[float], deltas: Iterable[Optional[str]]
):
    """
    Tune β and δ on TRAIN only. δ can be 'bco' (estimated) or a numeric constant.
    """
    base_s, N = cache_policy_and_ref_scores(records, model, ref_model, tokenizer)

    tried = []
    best = (-1.0, None)  # (wr, params)
    for beta in betas:
        for delta_spec in deltas:
            if delta_spec == "bco":
                delta = estimate_delta_bco(base_s, beta, N)
            else:
                delta = float(delta_spec)
            params = BCOParams(beta=beta, delta=delta)
            wr = bco_win_rate(base_s, N, params)
            lo, hi = binom_ci_95(wr, N)
            tried.append((wr, lo, hi, params))
            if wr > best[0]:
                best = (wr, params)

            print(f"[TUNE/TRAIN BCO] WR={wr:.2f}% [{lo:.1f}, {hi:.1f}] "
                  f"| beta={beta} delta={'bco' if delta_spec=='bco' else f'{delta:.4f}'}")

    best_wr, best_params = best
    lo, hi = binom_ci_95(best_wr, N)
    print(f"\n[BCO Tuning (TRAIN)] Best WR={best_wr:.2f}% [{lo:.1f}, {hi:.1f}] with "
          f"beta={best_params.beta} delta={best_params.delta:.4f}")
    return tried, best_params


# =============================================================================
# Main
# =============================================================================
def main():
    print(f"Device: {DEVICE} | Dtype: {DTYPE}")
    print("Loading policy model ...")
    tok, policy = load_causal_lm(MODEL_NAME, HF_TOKEN)
    ref = None
    if REF_MODEL_NAME:
        print("Loading reference model ...")
        _, ref = load_causal_lm(REF_MODEL_NAME, HF_TOKEN)

    print(f"Loading dataset: {DATASET} [{SPLIT}] config={DATASET_CONFIG} ...")
    ds = load_dataset(DATASET, name=DATASET_CONFIG, split=SPLIT) if DATASET_CONFIG else load_dataset(DATASET, split=SPLIT)

    # Adapt to pairs with min gap
    adapted: List[Dict] = []
    for rec in ds:
        a = adapt_openbmb_ultrafeedback(rec, min_gap=MIN_GAP)
        if a is not None:
            adapted.append(a)
    print(f"Adapted examples (gap ≥ {MIN_GAP}): {len(adapted)}")
    if not adapted:
        raise RuntimeError("No valid examples after adaptation. Lower MIN_GAP or inspect data.")

    # Train/Test split
    random.Random(RNG_SEED).shuffle(adapted)
    n_all   = len(adapted)
    n_train = max(1, int(TRAIN_FRAC * n_all))
    train_recs = adapted[:n_train]
    test_recs  = adapted[n_train:]
    print(f"Split sizes: TRAIN={len(train_recs)} | TEST={len(test_recs)}")

    # Cache TRAIN scores and baseline WR with default β=1, δ=0
    print("\nScoring TRAIN...")
    base_s_tr, N_tr = cache_policy_and_ref_scores(train_recs, policy, ref, tok)
    wr_tr_base = bco_win_rate(base_s_tr, N_tr, BCOParams(beta=1.0, delta=0.0))
    lo_trb, hi_trb = binom_ci_95(wr_tr_base, N_tr)
    print(f"[TRAIN BCO baseline] WR={wr_tr_base:.2f}% [{lo_trb:.1f}, {hi_trb:.1f}] on {N_tr} valid pairs")

    # Tune β, δ (TRAIN only)
    if DO_TUNE:
        print("\nTuning on TRAIN (BCO)...")
        _tries, best_params = grid_search_train_bco(
            train_recs, policy, ref, tok, betas=BETAS, deltas=DELTAS
        )
    else:
        # Optionally estimate δ via BCO at default β
        best_params = BCOParams(beta=1.0, delta=estimate_delta_bco(base_s_tr, 1.0, N_tr))

    # TEST evaluation (label-free)
    eval_set = test_recs if len(test_recs) > 0 else train_recs
    if len(test_recs) == 0:
        print("\nWARNING: TEST set is empty (small SPLIT/TRAIN_FRAC). Using TRAIN as a proxy sanity check.")

    print("\nScoring TEST...")
    base_s_te, N_te = cache_policy_and_ref_scores(eval_set, policy, ref, tok)
    wr_te = bco_win_rate(base_s_te, N_te, best_params)
    lo_te, hi_te = binom_ci_95(wr_te, N_te)
    print(f"[BCO EVAL (TEST, no labels)] WR={wr_te:.2f}% [{lo_te:.1f}, {hi_te:.1f}] "
          f"| beta={best_params.beta} delta={best_params.delta:.4f}")

if __name__ == "__main__":
    main()


Device: cuda | Dtype: torch.bfloat16
Loading policy model ...


tokenizer_config.json: 0.00B [00:00, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

config.json:   0%|          | 0.00/663 [00:00<?, ?B/s]

model.safetensors.index.json: 0.00B [00:00, ?B/s]

Fetching 4 files:   0%|          | 0/4 [00:00<?, ?it/s]

model-00003-of-00004.safetensors:   0%|          | 0.00/3.86G [00:00<?, ?B/s]

model-00001-of-00004.safetensors:   0%|          | 0.00/3.95G [00:00<?, ?B/s]

model-00002-of-00004.safetensors:   0%|          | 0.00/3.86G [00:00<?, ?B/s]

model-00004-of-00004.safetensors:   0%|          | 0.00/3.56G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/243 [00:00<?, ?B/s]

Loading reference model ...


tokenizer_config.json: 0.00B [00:00, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

config.json:   0%|          | 0.00/686 [00:00<?, ?B/s]

model.safetensors.index.json: 0.00B [00:00, ?B/s]

Fetching 4 files:   0%|          | 0/4 [00:00<?, ?it/s]

model-00002-of-00004.safetensors:   0%|          | 0.00/3.86G [00:00<?, ?B/s]

model-00001-of-00004.safetensors:   0%|          | 0.00/3.95G [00:00<?, ?B/s]

model-00004-of-00004.safetensors:   0%|          | 0.00/3.56G [00:00<?, ?B/s]

model-00003-of-00004.safetensors:   0%|          | 0.00/3.86G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/138 [00:00<?, ?B/s]

Loading dataset: openbmb/UltraFeedback [train[:10%]] config=None ...
Adapted examples (gap ≥ 1.0): 6205
Split sizes: TRAIN=4964 | TEST=1241

Scoring TRAIN...
[TRAIN BCO baseline] WR=80.02% [78.9, 81.1] on 4964 valid pairs

Tuning on TRAIN (BCO)...
[TUNE/TRAIN BCO] WR=80.02% [78.9, 81.1] | beta=0.5 delta=bco
[TUNE/TRAIN BCO] WR=80.02% [78.9, 81.1] | beta=0.5 delta=0.0000
[TUNE/TRAIN BCO] WR=80.02% [78.9, 81.1] | beta=1.0 delta=bco
[TUNE/TRAIN BCO] WR=80.02% [78.9, 81.1] | beta=1.0 delta=0.0000
[TUNE/TRAIN BCO] WR=80.02% [78.9, 81.1] | beta=2.0 delta=bco
[TUNE/TRAIN BCO] WR=80.02% [78.9, 81.1] | beta=2.0 delta=0.0000

[BCO Tuning (TRAIN)] Best WR=80.02% [78.9, 81.1] with beta=0.5 delta=-0.3762

Scoring TEST...
[BCO EVAL (TEST, no labels)] WR=80.42% [78.2, 82.6] | beta=0.5 delta=-0.3762


In [None]:
# bp_llm_capybara_eval_qwen25_3b.py
# Capybara-Preferences BCO baseline (leakage-safe) for Qwen2.5-3B

import os
import math
import random
from dataclasses import dataclass
from typing import List, Dict, Optional, Tuple, Iterable, Union

import torch
from torch.nn.functional import log_softmax
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM

# =============================================================================
# Config
# =============================================================================
# HF_TOKEN = os.environ.get("HF_TOKEN")

# Use Qwen 2.5 3B as policy; optional Base as reference
MODEL_NAME      = "Qwen/Qwen2.5-3B-Instruct"
REF_MODEL_NAME  = "Qwen/Qwen2.5-3B"   # set to None to disable reference subtraction

# --- DATASET: Capybara-Preferences ---
DATASET         = "argilla/Capybara-Preferences"   # or "argilla/Capybara-Preferences-Filtered"
DATASET_CONFIG  = None
SPLIT           = "train[:30%]"  # iterate small; use "train" later

# Inference/scoring limits
MAX_INPUT_TOKENS     = 1024
MAX_GEN_TOKENS       = 512
SCORE_MAX_GEN_TOKENS = 10000
BATCH_SIZE           = 4

# Chat-templating + scoring
FORCE_CHAT_TEMPLATE  = True
SCORING_MODE         = "mean"      # {"mean", "sum", "lp_alpha"}
LENGTH_PENALTY_ALPHA = 0.0         # only for SCORING_MODE == "lp_alpha"

# Pair filtering
MIN_GAP              = 0.0         # min rating gap to keep a pair

# Train / test split
RNG_SEED   = 42
TRAIN_FRAC = 0.8

# Grid tuning (TRAIN only)
DO_TUNE = True
BETAS   = [0.5, 1.0, 2.0]
DELTAS  = ["bco", 0.0]  # try BCO shift and a constant zero shift

# Device / dtype
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE  = torch.bfloat16 if torch.cuda.is_available() else torch.float32
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.set_grad_enabled(False)


# =============================================================================
# Capybara helpers
# =============================================================================
def _to_float(x) -> Optional[float]:
    try:
        return float(x)
    except Exception:
        return None

def _ensure_assistant_last(msgs: List[Dict[str, str]]):
    """Return (history_without_last, last_assistant_content) or (None, None) if malformed."""
    if not isinstance(msgs, list) or len(msgs) == 0:
        return None, None
    last = msgs[-1]
    if not isinstance(last, dict) or last.get("role") != "assistant":
        return None, None
    resp = (last.get("content") or "").strip()
    if not resp:
        return None, None
    history = [m for m in msgs[:-1] if isinstance(m, dict) and "role" in m and "content" in m]
    while len(history) and history[-1].get("role") == "assistant":
        history.pop()
    if len(history) == 0 or history[-1].get("role") != "user":
        return None, None
    return history, resp

def adapt_capybara_preferences(record: Dict, min_gap: float = 0.0) -> Optional[Dict]:
    """
    Build (prompt, chosen, rejected) from Capybara-Preferences:
      - chosen/rejected: list[{role, content}, ...] (full convo ending with assistant)
      - chosen_rating/rejected_rating: ints
    """
    chosen = record.get("chosen")
    rejected = record.get("rejected")
    cr = _to_float(record.get("chosen_rating"))
    rr = _to_float(record.get("rejected_rating"))
    if not (isinstance(chosen, list) and isinstance(rejected, list) and cr is not None and rr is not None):
        return None
    if abs(cr - rr) < min_gap:
        return None

    ch_hist, ch_resp = _ensure_assistant_last(chosen)
    rj_hist, rj_resp = _ensure_assistant_last(rejected)
    if ch_hist is None or rj_hist is None:
        return None

    # longest common prefix of histories (they usually match)
    m = min(len(ch_hist), len(rj_hist))
    k = 0
    while k < m and (ch_hist[k].get("role") == rj_hist[k].get("role")) \
                 and (ch_hist[k].get("content") == rj_hist[k].get("content")):
        k += 1
    prompt = ch_hist[:k] if k > 0 else ch_hist
    if not prompt or prompt[-1].get("role") != "user":
        prompt = ch_hist  # fallback

    return {"prompt": prompt, "chosen": ch_resp, "rejected": rj_resp}


# =============================================================================
# Model (Qwen-friendly)
# =============================================================================
def load_causal_lm(model_id: str, token: Optional[str], dtype=DTYPE, device=DEVICE):
    # Qwen tokenizers: try fast, then fallback to slow
    try:
        tok = AutoTokenizer.from_pretrained(model_id, token=token, use_fast=True)
    except Exception:
        tok = AutoTokenizer.from_pretrained(model_id, token=token, use_fast=False)

    if tok.pad_token_id is None:
        tok.pad_token = tok.eos_token

    trust_remote = os.environ.get("TRUST_REMOTE_CODE", "0") == "1"
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        token=token,
        torch_dtype=dtype,
        trust_remote_code=trust_remote
    )
    model = model.to(device)
    model.eval()
    return tok, model


# =============================================================================
# Encoding + scoring
# =============================================================================
def _apply_chat_prefix(tokenizer, prompt_or_messages: Union[str, List[Dict[str, str]]]) -> Optional[str]:
    """
    If FORCE_CHAT_TEMPLATE and tokenizer supports it:
      • For list[{'role','content'}] -> apply chat template directly (multi-turn)
      • For str -> wrap as single user turn, then add generation prompt
    """
    if not FORCE_CHAT_TEMPLATE or not hasattr(tokenizer, "apply_chat_template"):
        return None
    try:
        msgs = prompt_or_messages if isinstance(prompt_or_messages, list) else [
            {"role": "user", "content": str(prompt_or_messages).strip()}
        ]
        return tokenizer.apply_chat_template(
            msgs, tokenize=False, add_generation_prompt=True
        )
    except Exception:
        return None

def encode_pair(tokenizer, prompt_or_messages: Union[str, List[Dict[str, str]]], response: str):
    """
    Returns (input_ids, attention_mask, prompt_len) for concatenated prompt+response.
    Supports multi-turn 'messages' and Qwen chat template.
    """
    chat_prefix = _apply_chat_prefix(tokenizer, prompt_or_messages)
    if chat_prefix is not None:
        prompt_text = chat_prefix
        full_text   = chat_prefix + response.strip()
    else:
        # Fallback: flatten messages to plain text
        if isinstance(prompt_or_messages, list):
            flat = []
            for m in prompt_or_messages:
                role = m.get("role", "user")
                content = (m.get("content") or "").strip()
                if content:
                    flat.append(f"{role.upper()}: {content}")
            prompt_text = "\n".join(flat)
        else:
            prompt_text = str(prompt_or_messages).strip()
        sep = "" if prompt_text.endswith("\n") else "\n"
        full_text = prompt_text + sep + response.strip()

    toks_full = tokenizer(
        full_text,
        truncation=True,
        max_length=min(MAX_INPUT_TOKENS + MAX_GEN_TOKENS, tokenizer.model_max_length),
        return_tensors="pt",
        add_special_tokens=True,
    )
    toks_prompt = tokenizer(
        prompt_text,
        truncation=True,
        max_length=MAX_INPUT_TOKENS,
        return_tensors="pt",
        add_special_tokens=True,
    )
    prompt_len = toks_prompt["input_ids"].shape[-1]
    return toks_full, prompt_len

@torch.no_grad()
def sequence_logprob_stats(
    model,
    tokenizer,
    prompts: List[Union[str, List[Dict[str, str]]]],
    responses: List[str],
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Returns two tensors of shape [B]: (lp_sum, tok_count) over response tokens only.
    """
    sums, counts = [], []
    for i in range(0, len(prompts), BATCH_SIZE):
        p_batch = prompts[i:i+BATCH_SIZE]
        r_batch = responses[i:i+BATCH_SIZE]

        batch_inputs, batch_prompt_lens = [], []
        for p, r in zip(p_batch, r_batch):
            toks, p_len = encode_pair(tokenizer, p, r)
            batch_inputs.append(toks)
            batch_prompt_lens.append(p_len)

        pad_id = tokenizer.pad_token_id or 0
        input_ids = torch.nn.utils.rnn.pad_sequence(
            [bi["input_ids"].squeeze(0) for bi in batch_inputs],
            batch_first=True, padding_value=pad_id
        )
        attention_mask = torch.nn.utils.rnn.pad_sequence(
            [bi["attention_mask"].squeeze(0) for bi in batch_inputs],
            batch_first=True, padding_value=0
        )
        input_ids = input_ids.to(DEVICE)
        attention_mask = attention_mask.to(DEVICE)

        logits = model(input_ids=input_ids, attention_mask=attention_mask).logits
        logprobs = log_softmax(logits.to(torch.float32), dim=-1)

        for b in range(input_ids.size(0)):
            p_len = batch_prompt_lens[b]
            ids   = input_ids[b]
            masks = attention_mask[b]
            L = int(masks.sum().item())
            if p_len >= L:
                sums.append(torch.tensor(0.0)); counts.append(torch.tensor(0)); continue
            end  = min(L, p_len + SCORE_MAX_GEN_TOKENS)
            targ = ids[p_len:end]
            pred = logprobs[b, p_len-1:end-1]
            lp_sum = pred.gather(-1, targ.unsqueeze(-1)).squeeze(-1).sum()
            sums.append(lp_sum.cpu())
            counts.append(torch.tensor(targ.numel()))
    return torch.stack(sums, dim=0), torch.stack(counts, dim=0)

def _score_from_stats(lp_sum: torch.Tensor, tok_count: torch.Tensor) -> torch.Tensor:
    if SCORING_MODE == "mean":
        denom = torch.clamp(tok_count.to(torch.float32), min=1.0)
        return lp_sum.to(torch.float32) / denom
    elif SCORING_MODE == "lp_alpha":
        return lp_sum.to(torch.float32) + (LENGTH_PENALTY_ALPHA * tok_count.to(torch.float32))
    else:  # "sum"
        return lp_sum.to(torch.float32)


# =============================================================================
# BCO pieces
# =============================================================================
@dataclass
class BCOParams:
    beta: float  = 1.0
    delta: float = 0.0

@torch.no_grad()
def cache_policy_and_ref_scores(records: List[Dict], model, ref_model, tokenizer) -> Tuple[torch.Tensor, int]:
    """
    Build base scores s = (policy - ref) for valid pairs only.
    Keep only pairs where BOTH chosen and rejected have >=1 scored token
    (and for BOTH models if a reference is used).
    """
    prompts   = [r["prompt"] for r in records for _ in (0, 1)]
    responses = [r["chosen"] for r in records] + [r["rejected"] for r in records]

    pol_sum, pol_cnt = sequence_logprob_stats(model, tokenizer, prompts, responses)
    pol_score = _score_from_stats(pol_sum, pol_cnt)

    if ref_model is None:
        ref_score = torch.zeros_like(pol_score)
        ref_cnt   = torch.ones_like(pol_cnt)
    else:
        ref_sum, ref_cnt = sequence_logprob_stats(ref_model, tokenizer, prompts, responses)
        ref_score = _score_from_stats(ref_sum, ref_cnt)

    valid_mask = (pol_cnt > 0) & (ref_cnt > 0)
    valid_pairs = []
    for i in range(len(records)):
        if valid_mask[i].item() and valid_mask[i + len(records)].item():
            valid_pairs.append(i)

    if len(valid_pairs) == 0:
        raise RuntimeError("No valid (non-truncated) pairs after scoring. Increase limits or lower MIN_GAP.")

    chosen_vals, rejected_vals = [], []
    for i in valid_pairs:
        i_ch = i
        i_rj = i + len(records)
        chosen_vals.append(float(pol_score[i_ch] - ref_score[i_ch]))
        rejected_vals.append(float(pol_score[i_rj] - ref_score[i_rj]))

    base_s = torch.tensor(chosen_vals + rejected_vals, dtype=torch.float32)
    N = len(valid_pairs)
    return base_s, N

def estimate_delta_bco(base_s: torch.Tensor, beta: float, N: int) -> float:
    """δ_BCO = 0.5 ( E_pos[β s] + E_neg[β s] )."""
    r_pos = beta * base_s[:N]
    r_neg = beta * base_s[N:]
    return 0.5 * (float(r_pos.mean()) + float(r_neg.mean()))

def bco_win_rate(base_s: torch.Tensor, N: int, params: BCOParams) -> float:
    """Label-free pairwise rule: predict chosen if s_pos > s_neg (δ cancels)."""
    mu = (params.beta * base_s) - params.delta
    mu = mu.tolist()
    correct = sum(1 for i in range(N) if mu[i] > mu[i + N])
    return 100.0 * correct / N

def binom_ci_95(pct: float, N: int) -> Tuple[float, float]:
    p = pct / 100.0
    se = math.sqrt(p * (1 - p) / max(N, 1))
    lo = max(0.0, 100.0 * (p - 1.96 * se))
    hi = min(100.0, 100.0 * (p + 1.96 * se))
    return lo, hi

def grid_search_train_bco(records, model, ref_model, tokenizer, betas, deltas):
    base_s, N = cache_policy_and_ref_scores(records, model, ref_model, tokenizer)
    tried = []
    best = (-1.0, None)
    for beta in betas:
        for delta_spec in deltas:
            delta = estimate_delta_bco(base_s, beta, N) if delta_spec == "bco" else float(delta_spec)
            params = BCOParams(beta=beta, delta=delta)
            wr = bco_win_rate(base_s, N, params)
            lo, hi = binom_ci_95(wr, N)
            tried.append((wr, lo, hi, params))
            if wr > best[0]:
                best = (wr, params)
            print(f"[TUNE/TRAIN BCO] WR={wr:.2f}% [{lo:.1f}, {hi:.1f}] | beta={beta} delta={'bco' if delta_spec=='bco' else f'{delta:.4f}'}")
    best_wr, best_params = best
    lo, hi = binom_ci_95(best_wr, N)
    print(f"\n[BCO Tuning (TRAIN)] Best WR={best_wr:.2f}% [{lo:.1f}, {hi:.1f}] with beta={best_params.beta} delta={best_params.delta:.4f}")
    return tried, best_params


# =============================================================================
# Main
# =============================================================================
def main():
    print(f"Device: {DEVICE} | Dtype: {DTYPE}")
    print("Loading policy model ...")
    tok, policy = load_causal_lm(MODEL_NAME, HF_TOKEN)
    ref = None
    if REF_MODEL_NAME:
        print("Loading reference model ...")
        _, ref = load_causal_lm(REF_MODEL_NAME, HF_TOKEN)

    print(f"Loading dataset: {DATASET} [{SPLIT}] config={DATASET_CONFIG} ...")
    ds = load_dataset(DATASET, name=DATASET_CONFIG, split=SPLIT) if DATASET_CONFIG else load_dataset(DATASET, split=SPLIT)

    # Adapt to pairs with min rating gap
    adapted: List[Dict] = []
    for rec in ds:
        a = adapt_capybara_preferences(rec, min_gap=MIN_GAP)
        if a is not None:
            adapted.append(a)
    print(f"Adapted examples (rating gap ≥ {MIN_GAP}): {len(adapted)}")
    if not adapted:
        raise RuntimeError("No valid examples after adaptation. Lower MIN_GAP or inspect data.")

    # Train/Test split
    random.Random(RNG_SEED).shuffle(adapted)
    n_all   = len(adapted)
    n_train = max(1, int(TRAIN_FRAC * n_all))
    train_recs = adapted[:n_train]
    test_recs  = adapted[n_train:]
    print(f"Split sizes: TRAIN={len(train_recs)} | TEST={len(test_recs)}")

    # TRAIN baseline
    print("\nScoring TRAIN...")
    base_s_tr, N_tr = cache_policy_and_ref_scores(train_recs, policy, ref, tok)
    wr_tr_base = bco_win_rate(base_s_tr, N_tr, BCOParams(beta=1.0, delta=0.0))
    lo_trb, hi_trb = binom_ci_95(wr_tr_base, N_tr)
    print(f"[TRAIN BCO baseline] WR={wr_tr_base:.2f}% [{lo_trb:.1f}, {hi_trb:.1f}] on {N_tr} valid pairs")

    # Tune β, δ (TRAIN only)
    if DO_TUNE:
        print("\nTuning on TRAIN (BCO)...")
        _tries, best_params = grid_search_train_bco(train_recs, policy, ref, tok, betas=BETAS, deltas=DELTAS)
    else:
        best_params = BCOParams(beta=1.0, delta=estimate_delta_bco(base_s_tr, 1.0, N_tr))

    # TEST evaluation (label-free)
    eval_set = test_recs if len(test_recs) > 0 else train_recs
    if len(test_recs) == 0:
        print("\nWARNING: TEST set is empty (small SPLIT/TRAIN_FRAC). Using TRAIN as a proxy sanity check.")

    print("\nScoring TEST...")
    base_s_te, N_te = cache_policy_and_ref_scores(eval_set, policy, ref, tok)
    wr_te = bco_win_rate(base_s_te, N_te, best_params)
    lo_te, hi_te = binom_ci_95(wr_te, N_te)
    print(f"[BCO EVAL (TEST, no labels)] WR={wr_te:.2f}% [{lo_te:.1f}, {hi_te:.1f}] | beta={best_params.beta} delta={best_params.delta:.4f}")

if __name__ == "__main__":
    main()


Device: cuda | Dtype: torch.bfloat16
Loading policy model ...


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loading reference model ...


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loading dataset: argilla/Capybara-Preferences [train[:30%]] config=None ...


README.md: 0.00B [00:00, ?B/s]

data/train-00000-of-00001.parquet:   0%|          | 0.00/78.8M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/15404 [00:00<?, ? examples/s]

Adapted examples (rating gap ≥ 0.0): 4621
Split sizes: TRAIN=3696 | TEST=925

Scoring TRAIN...
[TRAIN BCO baseline] WR=55.06% [53.5, 56.7] on 3696 valid pairs

Tuning on TRAIN (BCO)...
[TUNE/TRAIN BCO] WR=55.06% [53.5, 56.7] | beta=0.5 delta=bco
[TUNE/TRAIN BCO] WR=55.06% [53.5, 56.7] | beta=0.5 delta=0.0000
[TUNE/TRAIN BCO] WR=55.06% [53.5, 56.7] | beta=1.0 delta=bco
[TUNE/TRAIN BCO] WR=55.06% [53.5, 56.7] | beta=1.0 delta=0.0000
[TUNE/TRAIN BCO] WR=55.06% [53.5, 56.7] | beta=2.0 delta=bco
[TUNE/TRAIN BCO] WR=55.06% [53.5, 56.7] | beta=2.0 delta=0.0000

[BCO Tuning (TRAIN)] Best WR=55.06% [53.5, 56.7] with beta=0.5 delta=-0.1202

Scoring TEST...
[BCO EVAL (TEST, no labels)] WR=52.86% [49.6, 56.1] | beta=0.5 delta=-0.1202


In [None]:
# bp_llm_capybara_eval_qwen25_7b.py
# Capybara-Preferences BCO baseline (leakage-safe) for Qwen2.5-3B

import os
import math
import random
from dataclasses import dataclass
from typing import List, Dict, Optional, Tuple, Iterable, Union

import torch
from torch.nn.functional import log_softmax
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM

# =============================================================================
# Config
# =============================================================================
# HF_TOKEN = os.environ.get("HF_TOKEN")

# Use Qwen 2.5 3B as policy; optional Base as reference
MODEL_NAME      = "Qwen/Qwen2.5-7B-Instruct"
REF_MODEL_NAME  = "Qwen/Qwen2.5-7B"   # set to None to disable reference subtraction

# --- DATASET: Capybara-Preferences ---
DATASET         = "argilla/Capybara-Preferences"   # or "argilla/Capybara-Preferences-Filtered"
DATASET_CONFIG  = None
SPLIT           = "train[:30%]"  # iterate small; use "train" later

# Inference/scoring limits
MAX_INPUT_TOKENS     = 1024
MAX_GEN_TOKENS       = 512
SCORE_MAX_GEN_TOKENS = 10000
BATCH_SIZE           = 4

# Chat-templating + scoring
FORCE_CHAT_TEMPLATE  = True
SCORING_MODE         = "mean"      # {"mean", "sum", "lp_alpha"}
LENGTH_PENALTY_ALPHA = 0.0         # only for SCORING_MODE == "lp_alpha"

# Pair filtering
MIN_GAP              = 0.0         # min rating gap to keep a pair

# Train / test split
RNG_SEED   = 42
TRAIN_FRAC = 0.8

# Grid tuning (TRAIN only)
DO_TUNE = True
BETAS   = [0.5, 1.0, 2.0]
DELTAS  = ["bco", 0.0]  # try BCO shift and a constant zero shift

# Device / dtype
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE  = torch.bfloat16 if torch.cuda.is_available() else torch.float32
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.set_grad_enabled(False)


# =============================================================================
# Capybara helpers
# =============================================================================
def _to_float(x) -> Optional[float]:
    try:
        return float(x)
    except Exception:
        return None

def _ensure_assistant_last(msgs: List[Dict[str, str]]):
    """Return (history_without_last, last_assistant_content) or (None, None) if malformed."""
    if not isinstance(msgs, list) or len(msgs) == 0:
        return None, None
    last = msgs[-1]
    if not isinstance(last, dict) or last.get("role") != "assistant":
        return None, None
    resp = (last.get("content") or "").strip()
    if not resp:
        return None, None
    history = [m for m in msgs[:-1] if isinstance(m, dict) and "role" in m and "content" in m]
    while len(history) and history[-1].get("role") == "assistant":
        history.pop()
    if len(history) == 0 or history[-1].get("role") != "user":
        return None, None
    return history, resp

def adapt_capybara_preferences(record: Dict, min_gap: float = 0.0) -> Optional[Dict]:
    """
    Build (prompt, chosen, rejected) from Capybara-Preferences:
      - chosen/rejected: list[{role, content}, ...] (full convo ending with assistant)
      - chosen_rating/rejected_rating: ints
    """
    chosen = record.get("chosen")
    rejected = record.get("rejected")
    cr = _to_float(record.get("chosen_rating"))
    rr = _to_float(record.get("rejected_rating"))
    if not (isinstance(chosen, list) and isinstance(rejected, list) and cr is not None and rr is not None):
        return None
    if abs(cr - rr) < min_gap:
        return None

    ch_hist, ch_resp = _ensure_assistant_last(chosen)
    rj_hist, rj_resp = _ensure_assistant_last(rejected)
    if ch_hist is None or rj_hist is None:
        return None

    # longest common prefix of histories (they usually match)
    m = min(len(ch_hist), len(rj_hist))
    k = 0
    while k < m and (ch_hist[k].get("role") == rj_hist[k].get("role")) \
                 and (ch_hist[k].get("content") == rj_hist[k].get("content")):
        k += 1
    prompt = ch_hist[:k] if k > 0 else ch_hist
    if not prompt or prompt[-1].get("role") != "user":
        prompt = ch_hist  # fallback

    return {"prompt": prompt, "chosen": ch_resp, "rejected": rj_resp}


# =============================================================================
# Model (Qwen-friendly)
# =============================================================================
def load_causal_lm(model_id: str, token: Optional[str], dtype=DTYPE, device=DEVICE):
    # Qwen tokenizers: try fast, then fallback to slow
    try:
        tok = AutoTokenizer.from_pretrained(model_id, token=token, use_fast=True)
    except Exception:
        tok = AutoTokenizer.from_pretrained(model_id, token=token, use_fast=False)

    if tok.pad_token_id is None:
        tok.pad_token = tok.eos_token

    trust_remote = os.environ.get("TRUST_REMOTE_CODE", "0") == "1"
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        token=token,
        torch_dtype=dtype,
        trust_remote_code=trust_remote
    )
    model = model.to(device)
    model.eval()
    return tok, model


# =============================================================================
# Encoding + scoring
# =============================================================================
def _apply_chat_prefix(tokenizer, prompt_or_messages: Union[str, List[Dict[str, str]]]) -> Optional[str]:
    """
    If FORCE_CHAT_TEMPLATE and tokenizer supports it:
      • For list[{'role','content'}] -> apply chat template directly (multi-turn)
      • For str -> wrap as single user turn, then add generation prompt
    """
    if not FORCE_CHAT_TEMPLATE or not hasattr(tokenizer, "apply_chat_template"):
        return None
    try:
        msgs = prompt_or_messages if isinstance(prompt_or_messages, list) else [
            {"role": "user", "content": str(prompt_or_messages).strip()}
        ]
        return tokenizer.apply_chat_template(
            msgs, tokenize=False, add_generation_prompt=True
        )
    except Exception:
        return None

def encode_pair(tokenizer, prompt_or_messages: Union[str, List[Dict[str, str]]], response: str):
    """
    Returns (input_ids, attention_mask, prompt_len) for concatenated prompt+response.
    Supports multi-turn 'messages' and Qwen chat template.
    """
    chat_prefix = _apply_chat_prefix(tokenizer, prompt_or_messages)
    if chat_prefix is not None:
        prompt_text = chat_prefix
        full_text   = chat_prefix + response.strip()
    else:
        # Fallback: flatten messages to plain text
        if isinstance(prompt_or_messages, list):
            flat = []
            for m in prompt_or_messages:
                role = m.get("role", "user")
                content = (m.get("content") or "").strip()
                if content:
                    flat.append(f"{role.upper()}: {content}")
            prompt_text = "\n".join(flat)
        else:
            prompt_text = str(prompt_or_messages).strip()
        sep = "" if prompt_text.endswith("\n") else "\n"
        full_text = prompt_text + sep + response.strip()

    toks_full = tokenizer(
        full_text,
        truncation=True,
        max_length=min(MAX_INPUT_TOKENS + MAX_GEN_TOKENS, tokenizer.model_max_length),
        return_tensors="pt",
        add_special_tokens=True,
    )
    toks_prompt = tokenizer(
        prompt_text,
        truncation=True,
        max_length=MAX_INPUT_TOKENS,
        return_tensors="pt",
        add_special_tokens=True,
    )
    prompt_len = toks_prompt["input_ids"].shape[-1]
    return toks_full, prompt_len

@torch.no_grad()
def sequence_logprob_stats(
    model,
    tokenizer,
    prompts: List[Union[str, List[Dict[str, str]]]],
    responses: List[str],
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Returns two tensors of shape [B]: (lp_sum, tok_count) over response tokens only.
    """
    sums, counts = [], []
    for i in range(0, len(prompts), BATCH_SIZE):
        p_batch = prompts[i:i+BATCH_SIZE]
        r_batch = responses[i:i+BATCH_SIZE]

        batch_inputs, batch_prompt_lens = [], []
        for p, r in zip(p_batch, r_batch):
            toks, p_len = encode_pair(tokenizer, p, r)
            batch_inputs.append(toks)
            batch_prompt_lens.append(p_len)

        pad_id = tokenizer.pad_token_id or 0
        input_ids = torch.nn.utils.rnn.pad_sequence(
            [bi["input_ids"].squeeze(0) for bi in batch_inputs],
            batch_first=True, padding_value=pad_id
        )
        attention_mask = torch.nn.utils.rnn.pad_sequence(
            [bi["attention_mask"].squeeze(0) for bi in batch_inputs],
            batch_first=True, padding_value=0
        )
        input_ids = input_ids.to(DEVICE)
        attention_mask = attention_mask.to(DEVICE)

        logits = model(input_ids=input_ids, attention_mask=attention_mask).logits
        logprobs = log_softmax(logits.to(torch.float32), dim=-1)

        for b in range(input_ids.size(0)):
            p_len = batch_prompt_lens[b]
            ids   = input_ids[b]
            masks = attention_mask[b]
            L = int(masks.sum().item())
            if p_len >= L:
                sums.append(torch.tensor(0.0)); counts.append(torch.tensor(0)); continue
            end  = min(L, p_len + SCORE_MAX_GEN_TOKENS)
            targ = ids[p_len:end]
            pred = logprobs[b, p_len-1:end-1]
            lp_sum = pred.gather(-1, targ.unsqueeze(-1)).squeeze(-1).sum()
            sums.append(lp_sum.cpu())
            counts.append(torch.tensor(targ.numel()))
    return torch.stack(sums, dim=0), torch.stack(counts, dim=0)

def _score_from_stats(lp_sum: torch.Tensor, tok_count: torch.Tensor) -> torch.Tensor:
    if SCORING_MODE == "mean":
        denom = torch.clamp(tok_count.to(torch.float32), min=1.0)
        return lp_sum.to(torch.float32) / denom
    elif SCORING_MODE == "lp_alpha":
        return lp_sum.to(torch.float32) + (LENGTH_PENALTY_ALPHA * tok_count.to(torch.float32))
    else:  # "sum"
        return lp_sum.to(torch.float32)


# =============================================================================
# BCO pieces
# =============================================================================
@dataclass
class BCOParams:
    beta: float  = 1.0
    delta: float = 0.0

@torch.no_grad()
def cache_policy_and_ref_scores(records: List[Dict], model, ref_model, tokenizer) -> Tuple[torch.Tensor, int]:
    """
    Build base scores s = (policy - ref) for valid pairs only.
    Keep only pairs where BOTH chosen and rejected have >=1 scored token
    (and for BOTH models if a reference is used).
    """
    prompts   = [r["prompt"] for r in records for _ in (0, 1)]
    responses = [r["chosen"] for r in records] + [r["rejected"] for r in records]

    pol_sum, pol_cnt = sequence_logprob_stats(model, tokenizer, prompts, responses)
    pol_score = _score_from_stats(pol_sum, pol_cnt)

    if ref_model is None:
        ref_score = torch.zeros_like(pol_score)
        ref_cnt   = torch.ones_like(pol_cnt)
    else:
        ref_sum, ref_cnt = sequence_logprob_stats(ref_model, tokenizer, prompts, responses)
        ref_score = _score_from_stats(ref_sum, ref_cnt)

    valid_mask = (pol_cnt > 0) & (ref_cnt > 0)
    valid_pairs = []
    for i in range(len(records)):
        if valid_mask[i].item() and valid_mask[i + len(records)].item():
            valid_pairs.append(i)

    if len(valid_pairs) == 0:
        raise RuntimeError("No valid (non-truncated) pairs after scoring. Increase limits or lower MIN_GAP.")

    chosen_vals, rejected_vals = [], []
    for i in valid_pairs:
        i_ch = i
        i_rj = i + len(records)
        chosen_vals.append(float(pol_score[i_ch] - ref_score[i_ch]))
        rejected_vals.append(float(pol_score[i_rj] - ref_score[i_rj]))

    base_s = torch.tensor(chosen_vals + rejected_vals, dtype=torch.float32)
    N = len(valid_pairs)
    return base_s, N

def estimate_delta_bco(base_s: torch.Tensor, beta: float, N: int) -> float:
    """δ_BCO = 0.5 ( E_pos[β s] + E_neg[β s] )."""
    r_pos = beta * base_s[:N]
    r_neg = beta * base_s[N:]
    return 0.5 * (float(r_pos.mean()) + float(r_neg.mean()))

def bco_win_rate(base_s: torch.Tensor, N: int, params: BCOParams) -> float:
    """Label-free pairwise rule: predict chosen if s_pos > s_neg (δ cancels)."""
    mu = (params.beta * base_s) - params.delta
    mu = mu.tolist()
    correct = sum(1 for i in range(N) if mu[i] > mu[i + N])
    return 100.0 * correct / N

def binom_ci_95(pct: float, N: int) -> Tuple[float, float]:
    p = pct / 100.0
    se = math.sqrt(p * (1 - p) / max(N, 1))
    lo = max(0.0, 100.0 * (p - 1.96 * se))
    hi = min(100.0, 100.0 * (p + 1.96 * se))
    return lo, hi

def grid_search_train_bco(records, model, ref_model, tokenizer, betas, deltas):
    base_s, N = cache_policy_and_ref_scores(records, model, ref_model, tokenizer)
    tried = []
    best = (-1.0, None)
    for beta in betas:
        for delta_spec in deltas:
            delta = estimate_delta_bco(base_s, beta, N) if delta_spec == "bco" else float(delta_spec)
            params = BCOParams(beta=beta, delta=delta)
            wr = bco_win_rate(base_s, N, params)
            lo, hi = binom_ci_95(wr, N)
            tried.append((wr, lo, hi, params))
            if wr > best[0]:
                best = (wr, params)
            print(f"[TUNE/TRAIN BCO] WR={wr:.2f}% [{lo:.1f}, {hi:.1f}] | beta={beta} delta={'bco' if delta_spec=='bco' else f'{delta:.4f}'}")
    best_wr, best_params = best
    lo, hi = binom_ci_95(best_wr, N)
    print(f"\n[BCO Tuning (TRAIN)] Best WR={best_wr:.2f}% [{lo:.1f}, {hi:.1f}] with beta={best_params.beta} delta={best_params.delta:.4f}")
    return tried, best_params


# =============================================================================
# Main
# =============================================================================
def main():
    print(f"Device: {DEVICE} | Dtype: {DTYPE}")
    print("Loading policy model ...")
    tok, policy = load_causal_lm(MODEL_NAME, HF_TOKEN)
    ref = None
    if REF_MODEL_NAME:
        print("Loading reference model ...")
        _, ref = load_causal_lm(REF_MODEL_NAME, HF_TOKEN)

    print(f"Loading dataset: {DATASET} [{SPLIT}] config={DATASET_CONFIG} ...")
    ds = load_dataset(DATASET, name=DATASET_CONFIG, split=SPLIT) if DATASET_CONFIG else load_dataset(DATASET, split=SPLIT)

    # Adapt to pairs with min rating gap
    adapted: List[Dict] = []
    for rec in ds:
        a = adapt_capybara_preferences(rec, min_gap=MIN_GAP)
        if a is not None:
            adapted.append(a)
    print(f"Adapted examples (rating gap ≥ {MIN_GAP}): {len(adapted)}")
    if not adapted:
        raise RuntimeError("No valid examples after adaptation. Lower MIN_GAP or inspect data.")

    # Train/Test split
    random.Random(RNG_SEED).shuffle(adapted)
    n_all   = len(adapted)
    n_train = max(1, int(TRAIN_FRAC * n_all))
    train_recs = adapted[:n_train]
    test_recs  = adapted[n_train:]
    print(f"Split sizes: TRAIN={len(train_recs)} | TEST={len(test_recs)}")

    # TRAIN baseline
    print("\nScoring TRAIN...")
    base_s_tr, N_tr = cache_policy_and_ref_scores(train_recs, policy, ref, tok)
    wr_tr_base = bco_win_rate(base_s_tr, N_tr, BCOParams(beta=1.0, delta=0.0))
    lo_trb, hi_trb = binom_ci_95(wr_tr_base, N_tr)
    print(f"[TRAIN BCO baseline] WR={wr_tr_base:.2f}% [{lo_trb:.1f}, {hi_trb:.1f}] on {N_tr} valid pairs")

    # Tune β, δ (TRAIN only)
    if DO_TUNE:
        print("\nTuning on TRAIN (BCO)...")
        _tries, best_params = grid_search_train_bco(train_recs, policy, ref, tok, betas=BETAS, deltas=DELTAS)
    else:
        best_params = BCOParams(beta=1.0, delta=estimate_delta_bco(base_s_tr, 1.0, N_tr))

    # TEST evaluation (label-free)
    eval_set = test_recs if len(test_recs) > 0 else train_recs
    if len(test_recs) == 0:
        print("\nWARNING: TEST set is empty (small SPLIT/TRAIN_FRAC). Using TRAIN as a proxy sanity check.")

    print("\nScoring TEST...")
    base_s_te, N_te = cache_policy_and_ref_scores(eval_set, policy, ref, tok)
    wr_te = bco_win_rate(base_s_te, N_te, best_params)
    lo_te, hi_te = binom_ci_95(wr_te, N_te)
    print(f"[BCO EVAL (TEST, no labels)] WR={wr_te:.2f}% [{lo_te:.1f}, {hi_te:.1f}] | beta={best_params.beta} delta={best_params.delta:.4f}")

if __name__ == "__main__":
    main()


Device: cuda | Dtype: torch.bfloat16
Loading policy model ...


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Loading reference model ...


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Loading dataset: argilla/Capybara-Preferences [train[:30%]] config=None ...
Adapted examples (rating gap ≥ 0.0): 4621
Split sizes: TRAIN=3696 | TEST=925

Scoring TRAIN...
[TRAIN BCO baseline] WR=56.28% [54.7, 57.9] on 3696 valid pairs

Tuning on TRAIN (BCO)...
[TUNE/TRAIN BCO] WR=56.28% [54.7, 57.9] | beta=0.5 delta=bco
[TUNE/TRAIN BCO] WR=56.28% [54.7, 57.9] | beta=0.5 delta=0.0000
[TUNE/TRAIN BCO] WR=56.28% [54.7, 57.9] | beta=1.0 delta=bco
[TUNE/TRAIN BCO] WR=56.28% [54.7, 57.9] | beta=1.0 delta=0.0000
[TUNE/TRAIN BCO] WR=56.28% [54.7, 57.9] | beta=2.0 delta=bco
[TUNE/TRAIN BCO] WR=56.28% [54.7, 57.9] | beta=2.0 delta=0.0000

[BCO Tuning (TRAIN)] Best WR=56.28% [54.7, 57.9] with beta=0.5 delta=-0.2013

Scoring TEST...
[BCO EVAL (TEST, no labels)] WR=53.51% [50.3, 56.7] | beta=0.5 delta=-0.2013
