In [None]:
# bp_llm_ultrafeedback_eval.py
# Evaluate BP-LLM (unary BP classifier with JJ bound) win rate on UltraFeedback
# with Llama-3.2-3B policy priors.

import math
import dataclasses
from dataclasses import dataclass
from typing import List, Tuple, Dict, Optional

import torch
from torch.nn.functional import log_softmax
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM

# -----------------------------
# Config
# -----------------------------

HF_TOKEN = " "  # your Read token


In [None]:
# BCO with Llama-3.1-8B-instruct

# bp_llm_ultrafeedback_eval_openbmb.py
# UltraFeedback BCO baseline (leakage-safe)
#
# What this does:
#   • Adapts UltraFeedback to (prompt, chosen, rejected) using top-vs-bottom with a minimum score gap.
#   • Splits into TRAIN/TEST (tune β and δ on TRAIN; report on TEST).
#   • δ estimated by BCO on TRAIN: δ = 0.5 ( E_pos[β s] + E_neg[β s] ).
#   • Scoring is chat-template aware and length-normalized (mean log-prob),
#     limited to SCORE_MAX_GEN_TOKENS response tokens.
#   • Any pair with a truncated side is dropped (don’t bias as zeros).
#   • TEST evaluation is fully label-free (BCO doesn’t need labels).
#
# Notes:
#   • Optionally set a reference model (REF_MODEL_NAME) to use log π − log π_ref.
#   • Lower MIN_GAP to admit more pairs if you see too few examples.
#   • This file is BCO-only for apples-to-apples comparison with your BP-LLM script.

import os
import math
import random
from dataclasses import dataclass
from typing import List, Dict, Optional, Tuple, Iterable

import torch
from torch.nn.functional import log_softmax
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM

# =============================================================================
# Config
# =============================================================================
HF_TOKEN = os.environ.get("HF_TOKEN")

# Use an *Instruct* model for the policy; optional Base as reference
MODEL_NAME      = "meta-llama/Llama-3.1-8B-Instruct"
REF_MODEL_NAME  = "meta-llama/Llama-3.1-8B"   # None to disable reference subtraction


DATASET         = "openbmb/UltraFeedback"
DATASET_CONFIG  = None
SPLIT           = "train[:10%]"  # iterate small; use "train" later

# Inference/scoring limits
MAX_INPUT_TOKENS     = 1024
MAX_GEN_TOKENS       = 512
SCORE_MAX_GEN_TOKENS = 256
BATCH_SIZE           = 4

# Chat-templating + scoring
FORCE_CHAT_TEMPLATE  = True
SCORING_MODE         = "mean"      # {"mean", "sum", "lp_alpha"}
LENGTH_PENALTY_ALPHA = 0.0         # only for SCORING_MODE == "lp_alpha"

# Pair filtering
MIN_GAP              = 1.0         # minimum helpfulness-score gap to keep a pair

# Train / test split
RNG_SEED   = 42
TRAIN_FRAC = 0.8

# Grid tuning (TRAIN only)
DO_TUNE = True
BETAS   = [0.5, 1.0, 2.0]
DELTAS  = ["bco", 0.0]  # try BCO shift and a constant zero shift

# Device / dtype
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE  = torch.bfloat16 if torch.cuda.is_available() else torch.float32
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.set_grad_enabled(False)


# =============================================================================
# UF helpers
# =============================================================================
def _extract_text(c: Dict) -> Optional[str]:
    if not isinstance(c, dict):
        return None
    for k in ("text", "response", "output", "completion", "content"):
        v = c.get(k)
        if isinstance(v, str) and v.strip():
            return v
    res = c.get("result")
    if isinstance(res, dict):
        v = res.get("text") or res.get("response") or res.get("output")
        if isinstance(v, str) and v.strip():
            return v
    return None

def _to_float(x) -> Optional[float]:
    try:
        return float(x)
    except Exception:
        return None

def _extract_score(c: Dict) -> Optional[float]:
    if not isinstance(c, dict):
        return None
    ann = c.get("annotations")
    if isinstance(ann, dict):
        help_ = ann.get("helpfulness")
        if isinstance(help_, dict):
            r = help_.get("Rating") or help_.get("rating") or help_.get("score")
            r = _to_float(r)
            if r is not None:
                return r
        for key in ("overall", "quality", "correctness", "honesty", "safety"):
            sub = ann.get(key)
            if isinstance(sub, dict):
                r = sub.get("Rating") or sub.get("rating") or sub.get("score")
                r = _to_float(r)
                if r is not None:
                    return r
        for _, v in ann.items():
            fv = _to_float(v)
            if fv is not None:
                return fv
    for k in ("score", "rating", "rank"):
        fv = _to_float(c.get(k))
        if fv is not None:
            return fv
    return None

def adapt_openbmb_ultrafeedback(record: Dict, min_gap: float = MIN_GAP) -> Optional[Dict]:
    """
    Build (prompt, chosen, rejected) as top-vs-bottom by score, with a minimum score gap.
    Drops near-ties to reduce label noise.
    """
    prompt = record.get("instruction")
    if not isinstance(prompt, str) or not prompt.strip():
        return None
    comps = record.get("completions")
    if not isinstance(comps, list) or len(comps) < 2:
        return None

    pairs = []
    for c in comps:
        text = _extract_text(c)
        score = _extract_score(c)
        if isinstance(text, str) and text.strip() and (score is not None):
            pairs.append((text, float(score)))
    if len(pairs) < 2:
        return None

    pairs.sort(key=lambda t: t[1])  # low ... high
    lo_txt, lo_s = pairs[0]
    hi_txt, hi_s = pairs[-1]
    if (hi_s - lo_s) < min_gap:
        return None

    return {"prompt": prompt, "chosen": hi_txt, "rejected": lo_txt}


# =============================================================================
# Model
# =============================================================================
def load_causal_lm(model_id: str, token: Optional[str], dtype=DTYPE, device=DEVICE):
    tok = AutoTokenizer.from_pretrained(model_id, token=token, use_fast=True)
    if tok.pad_token_id is None:
        tok.pad_token = tok.eos_token
    model = AutoModelForCausalLM.from_pretrained(model_id, token=token, torch_dtype=dtype)
    model = model.to(device)
    model.eval()
    return tok, model


# =============================================================================
# Encoding + scoring
# =============================================================================
def _apply_chat_prefix(tokenizer, prompt: str) -> Optional[str]:
    if not FORCE_CHAT_TEMPLATE or not hasattr(tokenizer, "apply_chat_template"):
        return None
    try:
        msgs = [
            {"role": "user", "content": prompt.strip()},
            {"role": "assistant", "content": ""},
        ]
        return tokenizer.apply_chat_template(
            msgs, tokenize=False, add_generation_prompt=True
        )
    except Exception:
        return None

def encode_pair(tokenizer, prompt: str, response: str):
    """
    Returns (input_ids, attention_mask, prompt_len) for the concatenated prompt+response.
    Uses chat template when available.
    """
    chat_prefix = _apply_chat_prefix(tokenizer, prompt)
    if chat_prefix is not None:
        prompt_text = chat_prefix
        full_text   = chat_prefix + response.strip()
    else:
        prompt_text = prompt.strip()
        full_text   = prompt_text + ("\n" if not prompt_text.endswith("\n") else "") + response.strip()

    toks_full = tokenizer(
        full_text,
        truncation=True,
        max_length=min(MAX_INPUT_TOKENS + MAX_GEN_TOKENS, tokenizer.model_max_length),
        return_tensors="pt",
        add_special_tokens=True,
    )
    toks_prompt = tokenizer(
        prompt_text,
        truncation=True,
        max_length=MAX_INPUT_TOKENS,
        return_tensors="pt",
        add_special_tokens=True,
    )
    prompt_len = toks_prompt["input_ids"].shape[-1]
    return toks_full, prompt_len

@torch.no_grad()
def sequence_logprob_stats(
    model,
    tokenizer,
    prompts: List[str],
    responses: List[str],
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Returns two tensors of shape [B]: (lp_sum, tok_count)
    – lp_sum is the sum of log-probs over scored response tokens
    – tok_count is the number of response tokens scored
    If a response is fully truncated (no tokens past the prompt), tok_count=0
    """
    sums, counts = [], []
    for i in range(0, len(prompts), BATCH_SIZE):
        p_batch = prompts[i:i+BATCH_SIZE]
        r_batch = responses[i:i+BATCH_SIZE]

        batch_inputs, batch_prompt_lens = [], []
        for p, r in zip(p_batch, r_batch):
            toks, p_len = encode_pair(tokenizer, p, r)
            batch_inputs.append(toks)
            batch_prompt_lens.append(p_len)

        pad_id = tokenizer.pad_token_id or 0
        input_ids = torch.nn.utils.rnn.pad_sequence(
            [bi["input_ids"].squeeze(0) for bi in batch_inputs],
            batch_first=True, padding_value=pad_id
        )
        attention_mask = torch.nn.utils.rnn.pad_sequence(
            [bi["attention_mask"].squeeze(0) for bi in batch_inputs],
            batch_first=True, padding_value=0
        )
        input_ids = input_ids.to(DEVICE)
        attention_mask = attention_mask.to(DEVICE)

        logits = model(input_ids=input_ids, attention_mask=attention_mask).logits
        logprobs = log_softmax(logits.to(torch.float32), dim=-1)

        for b in range(input_ids.size(0)):
            p_len = batch_prompt_lens[b]
            ids   = input_ids[b]
            masks = attention_mask[b]
            L = int(masks.sum().item())
            if p_len >= L:
                sums.append(torch.tensor(0.0)); counts.append(torch.tensor(0)); continue
            end  = min(L, p_len + SCORE_MAX_GEN_TOKENS)
            targ = ids[p_len:end]
            pred = logprobs[b, p_len-1:end-1]
            lp_sum = pred.gather(-1, targ.unsqueeze(-1)).squeeze(-1).sum()
            sums.append(lp_sum.cpu())
            counts.append(torch.tensor(targ.numel()))
    return torch.stack(sums, dim=0), torch.stack(counts, dim=0)

def _score_from_stats(lp_sum: torch.Tensor, tok_count: torch.Tensor) -> torch.Tensor:
    if SCORING_MODE == "mean":
        denom = torch.clamp(tok_count.to(torch.float32), min=1.0)
        return lp_sum.to(torch.float32) / denom
    elif SCORING_MODE == "lp_alpha":
        return lp_sum.to(torch.float32) + (LENGTH_PENALTY_ALPHA * tok_count.to(torch.float32))
    else:  # "sum"
        return lp_sum.to(torch.float32)


# =============================================================================
# BCO pieces
# =============================================================================
@dataclass
class BCOParams:
    beta: float  = 1.0
    delta: float = 0.0

@torch.no_grad()
def cache_policy_and_ref_scores(records: List[Dict], model, ref_model, tokenizer) -> Tuple[torch.Tensor, int]:
    """
    Build base scores s = (policy - ref) for valid pairs only.
    A pair is valid if BOTH sides have >= 1 scored response token
    (i.e., not truncated past the prompt) for BOTH models if ref is used.
    Returns:
        base_s: tensor of shape [2N_valid] (first N are chosen, next N are rejected)
        N_valid: number of valid pairs.
    """
    prompts   = [r["prompt"] for r in records for _ in (0, 1)]
    responses = [r["chosen"] for r in records] + [r["rejected"] for r in records]

    pol_sum, pol_cnt = sequence_logprob_stats(model, tokenizer, prompts, responses)
    pol_score = _score_from_stats(pol_sum, pol_cnt)

    if ref_model is None:
        ref_score = torch.zeros_like(pol_score)
        ref_cnt   = torch.ones_like(pol_cnt)  # pretend "valid"
    else:
        ref_sum, ref_cnt = sequence_logprob_stats(ref_model, tokenizer, prompts, responses)
        ref_score = _score_from_stats(ref_sum, ref_cnt)

    valid_mask = (pol_cnt > 0) & (ref_cnt > 0)
    # Keep only pairs where BOTH chosen and rejected are valid
    valid_pairs = []
    for i in range(len(records)):
        i_ch = i
        i_rj = i + len(records)
        if valid_mask[i_ch].item() and valid_mask[i_rj].item():
            valid_pairs.append(i)

    if len(valid_pairs) == 0:
        raise RuntimeError("No valid (non-truncated) pairs after scoring. Increase limits or lower MIN_GAP.")

    chosen_vals, rejected_vals = [], []
    for i in valid_pairs:
        i_ch = i
        i_rj = i + len(records)
        chosen_vals.append(float(pol_score[i_ch] - ref_score[i_ch]))
        rejected_vals.append(float(pol_score[i_rj] - ref_score[i_rj]))

    base_s = torch.tensor(chosen_vals + rejected_vals, dtype=torch.float32)
    N = len(valid_pairs)
    return base_s, N

def estimate_delta_bco(base_s: torch.Tensor, beta: float, N: int) -> float:
    """δ_BCO = 0.5 ( E_pos[β s] + E_neg[β s] )."""
    r_pos = beta * base_s[:N]
    r_neg = beta * base_s[N:]
    return 0.5 * (float(r_pos.mean()) + float(r_neg.mean()))

def bco_win_rate(base_s: torch.Tensor, N: int, params: BCOParams) -> float:
    """
    BCO decision rule (label-free): predict chosen if β s_pos - δ > β s_neg - δ  ⇔  s_pos > s_neg.
    δ cancels out for pairwise comparison, but we still compute it for completeness/reporting.
    """
    mu = (params.beta * base_s) - params.delta
    mu = mu.tolist()
    correct = 0
    for i in range(N):
        if mu[i] > mu[i + N]:
            correct += 1
    return 100.0 * correct / N

def binom_ci_95(pct: float, N: int) -> Tuple[float, float]:
    p = pct / 100.0
    se = math.sqrt(p * (1 - p) / max(N, 1))
    lo = max(0.0, 100.0 * (p - 1.96 * se))
    hi = min(100.0, 100.0 * (p + 1.96 * se))
    return lo, hi

def grid_search_train_bco(
    records: List[Dict], model, ref_model, tokenizer,
    betas: Iterable[float], deltas: Iterable[Optional[str]]
):
    """
    Tune β and δ on TRAIN only. δ can be 'bco' (estimated) or a numeric constant.
    """
    base_s, N = cache_policy_and_ref_scores(records, model, ref_model, tokenizer)

    tried = []
    best = (-1.0, None)  # (wr, params)
    for beta in betas:
        for delta_spec in deltas:
            if delta_spec == "bco":
                delta = estimate_delta_bco(base_s, beta, N)
            else:
                delta = float(delta_spec)
            params = BCOParams(beta=beta, delta=delta)
            wr = bco_win_rate(base_s, N, params)
            lo, hi = binom_ci_95(wr, N)
            tried.append((wr, lo, hi, params))
            if wr > best[0]:
                best = (wr, params)

            print(f"[TUNE/TRAIN BCO] WR={wr:.2f}% [{lo:.1f}, {hi:.1f}] "
                  f"| beta={beta} delta={'bco' if delta_spec=='bco' else f'{delta:.4f}'}")

    best_wr, best_params = best
    lo, hi = binom_ci_95(best_wr, N)
    print(f"\n[BCO Tuning (TRAIN)] Best WR={best_wr:.2f}% [{lo:.1f}, {hi:.1f}] with "
          f"beta={best_params.beta} delta={best_params.delta:.4f}")
    return tried, best_params


# =============================================================================
# Main
# =============================================================================
def main():
    print(f"Device: {DEVICE} | Dtype: {DTYPE}")
    print("Loading policy model ...")
    tok, policy = load_causal_lm(MODEL_NAME, HF_TOKEN)
    ref = None
    if REF_MODEL_NAME:
        print("Loading reference model ...")
        _, ref = load_causal_lm(REF_MODEL_NAME, HF_TOKEN)

    print(f"Loading dataset: {DATASET} [{SPLIT}] config={DATASET_CONFIG} ...")
    ds = load_dataset(DATASET, name=DATASET_CONFIG, split=SPLIT) if DATASET_CONFIG else load_dataset(DATASET, split=SPLIT)

    # Adapt to pairs with min gap
    adapted: List[Dict] = []
    for rec in ds:
        a = adapt_openbmb_ultrafeedback(rec, min_gap=MIN_GAP)
        if a is not None:
            adapted.append(a)
    print(f"Adapted examples (gap ≥ {MIN_GAP}): {len(adapted)}")
    if not adapted:
        raise RuntimeError("No valid examples after adaptation. Lower MIN_GAP or inspect data.")

    # Train/Test split
    random.Random(RNG_SEED).shuffle(adapted)
    n_all   = len(adapted)
    n_train = max(1, int(TRAIN_FRAC * n_all))
    train_recs = adapted[:n_train]
    test_recs  = adapted[n_train:]
    print(f"Split sizes: TRAIN={len(train_recs)} | TEST={len(test_recs)}")

    # Cache TRAIN scores and baseline WR with default β=1, δ=0
    print("\nScoring TRAIN...")
    base_s_tr, N_tr = cache_policy_and_ref_scores(train_recs, policy, ref, tok)
    wr_tr_base = bco_win_rate(base_s_tr, N_tr, BCOParams(beta=1.0, delta=0.0))
    lo_trb, hi_trb = binom_ci_95(wr_tr_base, N_tr)
    print(f"[TRAIN BCO baseline] WR={wr_tr_base:.2f}% [{lo_trb:.1f}, {hi_trb:.1f}] on {N_tr} valid pairs")

    # Tune β, δ (TRAIN only)
    if DO_TUNE:
        print("\nTuning on TRAIN (BCO)...")
        _tries, best_params = grid_search_train_bco(
            train_recs, policy, ref, tok, betas=BETAS, deltas=DELTAS
        )
    else:
        # Optionally estimate δ via BCO at default β
        best_params = BCOParams(beta=1.0, delta=estimate_delta_bco(base_s_tr, 1.0, N_tr))

    # TEST evaluation (label-free)
    eval_set = test_recs if len(test_recs) > 0 else train_recs
    if len(test_recs) == 0:
        print("\nWARNING: TEST set is empty (small SPLIT/TRAIN_FRAC). Using TRAIN as a proxy sanity check.")

    print("\nScoring TEST...")
    base_s_te, N_te = cache_policy_and_ref_scores(eval_set, policy, ref, tok)
    wr_te = bco_win_rate(base_s_te, N_te, best_params)
    lo_te, hi_te = binom_ci_95(wr_te, N_te)
    print(f"[BCO EVAL (TEST, no labels)] WR={wr_te:.2f}% [{lo_te:.1f}, {hi_te:.1f}] "
          f"| beta={best_params.beta} delta={best_params.delta:.4f}")

if __name__ == "__main__":
    main()


Device: cuda | Dtype: torch.bfloat16
Loading policy model ...


tokenizer_config.json:   0%|          | 0.00/55.4k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/9.09M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/296 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/855 [00:00<?, ?B/s]

`torch_dtype` is deprecated! Use `dtype` instead!


model.safetensors.index.json:   0%|          | 0.00/23.9k [00:00<?, ?B/s]

Fetching 4 files:   0%|          | 0/4 [00:00<?, ?it/s]

model-00004-of-00004.safetensors:   0%|          | 0.00/1.17G [00:00<?, ?B/s]

model-00003-of-00004.safetensors:   0%|          | 0.00/4.92G [00:00<?, ?B/s]

model-00001-of-00004.safetensors:   0%|          | 0.00/4.98G [00:00<?, ?B/s]

model-00002-of-00004.safetensors:   0%|          | 0.00/5.00G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/184 [00:00<?, ?B/s]

Loading reference model ...


tokenizer_config.json:   0%|          | 0.00/50.5k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/9.09M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/73.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/826 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/23.9k [00:00<?, ?B/s]

Fetching 4 files:   0%|          | 0/4 [00:00<?, ?it/s]

model-00002-of-00004.safetensors:   0%|          | 0.00/5.00G [00:00<?, ?B/s]

model-00003-of-00004.safetensors:   0%|          | 0.00/4.92G [00:00<?, ?B/s]

model-00004-of-00004.safetensors:   0%|          | 0.00/1.17G [00:00<?, ?B/s]

model-00001-of-00004.safetensors:   0%|          | 0.00/4.98G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/185 [00:00<?, ?B/s]

Loading dataset: openbmb/UltraFeedback [train[:10%]] config=None ...


README.md: 0.00B [00:00, ?B/s]

evol_instruct.jsonl:   0%|          | 0.00/168M [00:00<?, ?B/s]

false_qa.jsonl:   0%|          | 0.00/25.9M [00:00<?, ?B/s]

flan.jsonl:   0%|          | 0.00/240M [00:00<?, ?B/s]

sharegpt.jsonl:   0%|          | 0.00/313M [00:00<?, ?B/s]

truthful_qa.jsonl: 0.00B [00:00, ?B/s]

ultrachat.jsonl:   0%|          | 0.00/182M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/63967 [00:00<?, ? examples/s]

Adapted examples (gap ≥ 1.0): 6205
Split sizes: TRAIN=4964 | TEST=1241

Scoring TRAIN...
[TRAIN BCO baseline] WR=75.44% [74.2, 76.6] on 4964 valid pairs

Tuning on TRAIN (BCO)...
[TUNE/TRAIN BCO] WR=75.44% [74.2, 76.6] | beta=0.5 delta=bco
[TUNE/TRAIN BCO] WR=75.44% [74.2, 76.6] | beta=0.5 delta=0.0000
[TUNE/TRAIN BCO] WR=75.44% [74.2, 76.6] | beta=1.0 delta=bco
[TUNE/TRAIN BCO] WR=75.44% [74.2, 76.6] | beta=1.0 delta=0.0000
[TUNE/TRAIN BCO] WR=75.44% [74.2, 76.6] | beta=2.0 delta=bco
[TUNE/TRAIN BCO] WR=75.44% [74.2, 76.6] | beta=2.0 delta=0.0000

[BCO Tuning (TRAIN)] Best WR=75.44% [74.2, 76.6] with beta=0.5 delta=-0.2136

Scoring TEST...
[BCO EVAL (TEST, no labels)] WR=74.38% [71.9, 76.8] | beta=0.5 delta=-0.2136


In [None]:
# bp_llm_ultrafeedback_eval_openbmb.py
# Evaluate BP-LLM (unary JJ) win rate on openbmb/UltraFeedback.
# Improvements:
#  - Pairing: top-vs-bottom with minimum score gap (drop near-ties)
#  - Length norm: mean log-prob per token (reduces length bias)
#  - Skip truncated examples (no fake zeros)
#  - Optional chat template for Instruct models
#  - Optional reference model (log pi - log pi_ref)
#  - Train/Test split: tune on train (labels), evaluate on test (NO labels)

import math
import itertools
from dataclasses import dataclass
from typing import List, Dict, Optional, Tuple, Iterable

import torch
from torch.nn.functional import log_softmax
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM

# =========================
# Config
# =========================
# Put your HF token if you use gated models
HF_TOKEN = None

# Suggested: Instruct as policy, Base as reference
MODEL_NAME      = "meta-llama/Llama-3.1-8B-Instruct"
REF_MODEL_NAME  = "meta-llama/Llama-3.1-8B"   # None to disable reference subtraction

DATASET         = "openbmb/UltraFeedback"
DATASET_CONFIG  = None                # only 'default' typically
SPLIT           = "train[:10%]"       # keep small while iterating

# Train/Test split from the adapted pairs (stratification not required here)
TRAIN_FRAC      = 0.8
SEED            = 42
MIN_GAP         = 1.0                 # require at least this helpfulness score gap

# Input limits
MAX_INPUT_TOKENS = 1024
MAX_GEN_TOKENS   = 512
BATCH_SIZE       = 4

# Chat template toggle (recommended for *Instruct*)
USE_CHAT_TEMPLATE = True

# Length normalization: "mean" or "sum"
LENGTH_NORM = "mean"

# BP hyperparams (defaults)
BETA            = 1.0
DELTA           = 0.0
TAU             = 1.0
GAMMA           = 1.0
JJ_INNER_STEPS  = 2

# Grid tuning
DO_TUNE         = True
BETAS           = [0.5, 1.0, 2.0]
DELTAS          = ['bco', 0.0]        # try BCO shift and zero shift
TAUS            = [0.5, 1.0, 2.0]
GAMMAS          = [0.5, 1.0, 1.5, 2.0]
JJ_STEPS_LIST   = [1, 2, 3, 0]        # 0 => adaptive (tolerance-based)

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE  = torch.bfloat16 if torch.cuda.is_available() else torch.float32
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True


# =========================
# Schema helpers
# =========================
def _extract_text(c: Dict) -> Optional[str]:
    if not isinstance(c, dict):
        return None
    for k in ("text", "response", "output", "completion", "content"):
        v = c.get(k)
        if isinstance(v, str) and v.strip():
            return v
    res = c.get("result")
    if isinstance(res, dict):
        v = res.get("text") or res.get("response") or res.get("output")
        if isinstance(v, str) and v.strip():
            return v
    return None

def _to_float(x) -> Optional[float]:
    try:
        return float(x)
    except Exception:
        return None

def _extract_score(c: Dict) -> Optional[float]:
    """
    Extract a scalar score from completion annotations.
    Prefer helpfulness rating; fall back to any numeric fields.
    """
    if not isinstance(c, dict):
        return None
    ann = c.get("annotations")
    if isinstance(ann, dict):
        help_ = ann.get("helpfulness")
        if isinstance(help_, dict):
            r = help_.get("Rating") or help_.get("rating") or help_.get("score")
            r = _to_float(r)
            if r is not None:
                return r
        for key in ("overall", "quality", "correctness", "honesty", "safety"):
            sub = ann.get(key)
            if isinstance(sub, dict):
                r = sub.get("Rating") or sub.get("rating") or sub.get("score")
                r = _to_float(r)
                if r is not None:
                    return r
        # Any flat numeric
        for _, v in ann.items():
            fv = _to_float(v)
            if fv is not None:
                return fv
    for k in ("score", "rating", "rank"):
        fv = _to_float(c.get(k))
        if fv is not None:
            return fv
    return None

def adapt_openbmb_ultrafeedback(record: Dict, min_gap: float = MIN_GAP) -> Optional[Dict]:
    """
    Build (prompt, chosen, rejected) by picking the highest-scored completion
    vs the lowest-scored completion, and drop near-ties (< min_gap).
    """
    prompt = record.get("instruction")
    if not isinstance(prompt, str) or not prompt.strip():
        return None

    comps = record.get("completions")
    if not isinstance(comps, list) or len(comps) < 2:
        return None

    pairs = []
    for c in comps:
        text = _extract_text(c)
        score = _extract_score(c)
        if isinstance(text, str) and text.strip() and score is not None:
            pairs.append((text, float(score)))

    if len(pairs) < 2:
        return None

    pairs.sort(key=lambda t: t[1])  # low ... high
    lo_txt, lo_s = pairs[0]
    hi_txt, hi_s = pairs[-1]
    if hi_s - lo_s < min_gap:
        return None

    return {"prompt": prompt, "chosen": hi_txt, "rejected": lo_txt}


# =========================
# Tokenization / Scoring
# =========================
def format_prompt_with_chat_template(tokenizer, prompt: str) -> str:
    if USE_CHAT_TEMPLATE and hasattr(tokenizer, "apply_chat_template"):
        messages = [{"role": "user", "content": prompt}]
        try:
            return tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True  # leaves assistant prefix for continuation
            )
        except Exception:
            pass
    return prompt

def concat_prompt_response_text(tokenizer, prompt: str, response: str) -> Tuple[str, str]:
    """Return (full_text, prompt_only_text) for token counting."""
    prompt_text = format_prompt_with_chat_template(tokenizer, prompt).strip()
    # Append the response after the assistant prompt prefix (if any)
    full_text = prompt_text + ("\n" if not prompt_text.endswith("\n") else "") + response.strip()
    return full_text, prompt_text

@torch.no_grad()
def sequence_logprob_list(
    model,
    tokenizer,
    prompts: List[str],
    responses: List[str],
    length_norm: str = LENGTH_NORM,
) -> List[Optional[float]]:
    """
    Returns mean/sum log-prob for each (prompt, response). If response was truncated
    (i.e., prompt consumed the full sequence), returns None for that item.
    """
    vals: List[Optional[float]] = []
    for i in range(0, len(prompts), BATCH_SIZE):
        batch_prompts = prompts[i:i+BATCH_SIZE]
        batch_resps   = responses[i:i+BATCH_SIZE]

        batch_inputs, batch_prompt_lens = [], []
        for p, r in zip(batch_prompts, batch_resps):
            full_text, prompt_text = concat_prompt_response_text(tokenizer, p, r)
            toks = tokenizer(
                full_text,
                truncation=True,
                max_length=min(MAX_INPUT_TOKENS + MAX_GEN_TOKENS, tokenizer.model_max_length),
                return_tensors="pt",
                add_special_tokens=True,
            )
            tok_prompt = tokenizer(
                prompt_text,
                truncation=True,
                max_length=MAX_INPUT_TOKENS,
                return_tensors="pt",
                add_special_tokens=True,
            )
            p_len = tok_prompt["input_ids"].shape[-1]
            batch_inputs.append(toks)
            batch_prompt_lens.append(p_len)

        pad_id = tokenizer.pad_token_id or 0
        input_ids = torch.nn.utils.rnn.pad_sequence(
            [bi["input_ids"].squeeze(0) for bi in batch_inputs],
            batch_first=True, padding_value=pad_id
        )
        attention_mask = torch.nn.utils.rnn.pad_sequence(
            [bi["attention_mask"].squeeze(0) for bi in batch_inputs],
            batch_first=True, padding_value=0
        )
        input_ids = input_ids.to(DEVICE)
        attention_mask = attention_mask.to(DEVICE)

        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        logits = outputs.logits
        logprobs = log_softmax(logits.to(torch.float32), dim=-1)

        for b in range(input_ids.size(0)):
            p_len = batch_prompt_lens[b]
            ids   = input_ids[b]
            masks = attention_mask[b]
            L     = int(masks.sum().item())
            if p_len >= L:
                vals.append(None)  # response totally truncated -> skip
                continue
            targ = ids[p_len:L]              # response tokens
            pred = logprobs[b, p_len-1:L-1]  # shifted predictions
            lp = pred.gather(-1, targ.unsqueeze(-1)).squeeze(-1)
            if length_norm == "mean":
                score = float(lp.mean().cpu())
            else:
                score = float(lp.sum().cpu())
            vals.append(score)
    return vals


# =========================
# Model loader
# =========================
def load_causal_lm(model_id: str, token: Optional[str], dtype=DTYPE, device=DEVICE):
    tok = AutoTokenizer.from_pretrained(model_id, token=token, use_fast=True)
    if tok.pad_token_id is None:
        tok.pad_token = tok.eos_token
    model = AutoModelForCausalLM.from_pretrained(model_id, token=token, torch_dtype=dtype)
    model = model.to(device)
    model.eval()
    return tok, model


# =========================
# BP-LLM (unary JJ) posterior
# =========================
@dataclass
class BPParams:
    beta: float = BETA
    delta: float = DELTA
    tau: float = TAU
    gamma: float = GAMMA
    jj_steps: int = JJ_INNER_STEPS

def jj_lambda(xi: float) -> float:
    if xi < 1e-8:
        return 1.0 / 8.0
    return math.tanh(xi / 2.0) / (4.0 * xi)

def bp_unary_posterior(mu_prior: float, b: int, params: BPParams) -> Tuple[float, float]:
    tau2 = params.tau ** 2
    gamma_tilde = params.gamma * (2 * b - 1)
    mu_hat = mu_prior
    tau2_hat = tau2
    for _ in range(params.jj_steps):
        xi = abs(params.gamma) * math.sqrt(mu_hat * mu_hat + tau2_hat)
        lam = jj_lambda(xi)
        Lambda = (1.0 / tau2) + 2.0 * lam
        eta    = (mu_prior / tau2) + 0.5 * gamma_tilde
        mu_hat = eta / Lambda
        tau2_hat = 1.0 / Lambda
    return mu_hat, tau2_hat

def bp_unary_posterior_adaptive(mu_prior: float, b: int, params: BPParams,
                                tol: float = 1e-4, max_steps: int = 50) -> Tuple[float, float]:
    tau2 = params.tau ** 2
    gamma_tilde = params.gamma * (2 * b - 1)
    mu_hat = mu_prior
    tau2_hat = tau2
    steps = params.jj_steps if params.jj_steps and params.jj_steps > 0 else max_steps
    for _ in range(steps):
        mu_prev = mu_hat
        xi = abs(params.gamma) * math.sqrt(mu_hat * mu_hat + tau2_hat)
        lam = jj_lambda(xi)
        Lambda = (1.0 / tau2) + 2.0 * lam
        eta    = (mu_prior / tau2) + 0.5 * gamma_tilde
        mu_hat = eta / Lambda
        tau2_hat = 1.0 / Lambda
        if (params.jj_steps is None or params.jj_steps <= 0) and abs(mu_hat - mu_prev) < tol:
            break
    return mu_hat, tau2_hat


# =========================
# Caching scores (pairwise)
# =========================
@torch.no_grad()
def cache_pairwise_scores(
    records: List[Dict],
    model,
    ref_model,
    tokenizer
) -> Tuple[torch.Tensor, int]:
    """
    Returns base_s tensor shaped [2N] where first N are chosen scores, next N are rejected,
    and N is the number of *valid* pairs (both sides scored).
    Each score is (log pi - log pi_ref) with length normalization.
    """
    prompts_ch  = [r["prompt"] for r in records]
    prompts_rj  = [r["prompt"] for r in records]
    resps_ch    = [r["chosen"] for r in records]
    resps_rj    = [r["rejected"] for r in records]

    policy_ch = sequence_logprob_list(model, tokenizer, prompts_ch, resps_ch, LENGTH_NORM)
    policy_rj = sequence_logprob_list(model, tokenizer, prompts_rj, resps_rj, LENGTH_NORM)

    if ref_model is None:
        ref_ch = [0.0 if (s is not None) else None for s in policy_ch]
        ref_rj = [0.0 if (s is not None) else None for s in policy_rj]
    else:
        ref_ch = sequence_logprob_list(ref_model, tokenizer, prompts_ch, resps_ch, LENGTH_NORM)
        ref_rj = sequence_logprob_list(ref_model, tokenizer, prompts_rj, resps_rj, LENGTH_NORM)

    chosen_vals, rejected_vals = [], []
    for sc, sr, rc, rr in zip(policy_ch, policy_rj, ref_ch, ref_rj):
        if (sc is None) or (sr is None) or (rc is None) or (rr is None):
            continue  # drop pair if any side was truncated
        chosen_vals.append(sc - rc)
        rejected_vals.append(sr - rr)

    if len(chosen_vals) == 0:
        raise RuntimeError("No valid pairs after scoring. Consider increasing MAX_*_TOKENS or using 'sum' norm.")

    base_s = torch.tensor(chosen_vals + rejected_vals, dtype=torch.float32)
    N = len(chosen_vals)
    return base_s, N


# =========================
# Evaluation & tuning
# =========================
def estimate_delta_bco(base_s: torch.Tensor, beta: float, N: int) -> float:
    r_pos = beta * base_s[:N]
    r_neg = beta * base_s[N:]
    return 0.5 * (float(r_pos.mean()) + float(r_neg.mean()))

def evaluate_with_cached(
    base_s: torch.Tensor, N: int, params: BPParams,
    use_adaptive_jj: bool = True, use_labels: bool = True
) -> float:
    """
    If use_labels=False, compare mu_prior only (label-free test).
    """
    mu = (params.beta * base_s) - params.delta
    mu = mu.tolist()

    correct = 0
    for i in range(N):
        mu_w_prior = mu[i]
        mu_l_prior = mu[i + N]

        if not use_labels:
            # label-free: just compare priors
            if mu_w_prior > mu_l_prior:
                correct += 1
            continue

        if use_adaptive_jj:
            mu_w_post, _ = bp_unary_posterior_adaptive(mu_w_prior, b=1, params=params)
            mu_l_post, _ = bp_unary_posterior_adaptive(mu_l_prior, b=0, params=params)
        else:
            mu_w_post, _ = bp_unary_posterior(mu_w_prior, b=1, params=params)
            mu_l_post, _ = bp_unary_posterior(mu_l_prior, b=0, params=params)
        if mu_w_post > mu_l_post:
            correct += 1
    return 100.0 * correct / N

def binom_ci_95(pct: float, N: int) -> Tuple[float, float]:
    p = pct / 100.0
    se = math.sqrt(p * (1 - p) / max(N, 1))
    lo = max(0.0, 100.0 * (p - 1.96 * se))
    hi = min(100.0, 100.0 * (p + 1.96 * se))
    return lo, hi

def grid_search(
    records: List[Dict], model, ref_model, tokenizer,
    betas: Iterable[float], deltas: Iterable[Optional[str]],
    taus: Iterable[float], gammas: Iterable[float],
    jj_steps_list: Iterable[int], use_adaptive_jj: bool = True,
    estimate_delta: bool = True, split_name: str = "TRAIN"
):
    base_s, N = cache_pairwise_scores(records, model, ref_model, tokenizer)
    tried = []
    best = (-1.0, None)

    for beta, delta_spec, tau, gamma, jj_steps in itertools.product(
        betas, deltas, taus, gammas, jj_steps_list
    ):
        if estimate_delta and (delta_spec == 'bco'):
            delta = estimate_delta_bco(base_s, beta, N)
        else:
            delta = float(delta_spec)

        params = BPParams(beta=beta, delta=delta, tau=tau, gamma=gamma, jj_steps=jj_steps)
        # TRAIN uses labels:
        wr = evaluate_with_cached(base_s, N, params, use_adaptive_jj=use_adaptive_jj, use_labels=True)
        lo, hi = binom_ci_95(wr, N)
        tried.append((wr, lo, hi, params))
        if wr > best[0]:
            best = (wr, params)

        print(f"[TUNE/{split_name}] WR={wr:.2f}% [{lo:.1f}, {hi:.1f}] "
              f"| beta={beta} delta={'bco' if delta_spec=='bco' else f'{delta:.4f}'} "
              f"| tau={tau} gamma={gamma} jj_steps={jj_steps}")

    best_wr, best_params = best
    lo, hi = binom_ci_95(best_wr, N)
    print(f"\n[BP-LLM Tuning ({split_name})] Best WR={best_wr:.2f}% [{lo:.1f}, {hi:.1f}] with "
          f"beta={best_params.beta} delta={best_params.delta:.4f} "
          f"tau={best_params.tau} gamma={best_params.gamma} jj_steps={best_params.jj_steps}")
    return tried, best_params


# =========================
# Main
# =========================
def main():
    print(f"Device: {DEVICE} | Dtype: {DTYPE}")
    print("Loading policy model ...")
    tok, policy = load_causal_lm(MODEL_NAME, HF_TOKEN)

    ref = None
    if REF_MODEL_NAME:
        print("Loading reference model ...")
        _, ref = load_causal_lm(REF_MODEL_NAME, HF_TOKEN)

    print(f"Loading dataset: {DATASET} [{SPLIT}] config={DATASET_CONFIG} ...")
    if DATASET_CONFIG:
        ds = load_dataset(DATASET, name=DATASET_CONFIG, split=SPLIT)
    else:
        ds = load_dataset(DATASET, split=SPLIT)

    print("Columns:", ds.column_names)
    if len(ds) > 0:
        sample0 = ds[0]
        print("Row[0] keys:", list(sample0.keys()))
        preview = {k: (str(sample0[k])[:200] + "…") for k in sample0.keys()}
        print("Row[0] preview:", preview)

    # Adapt to (prompt, chosen, rejected) with top-vs-bottom + gap
    adapted: List[Dict] = []
    for rec in ds:
        a = adapt_openbmb_ultrafeedback(rec, min_gap=MIN_GAP)
        if a and a["prompt"] and a["chosen"] and a["rejected"]:
            adapted.append(a)

    print(f"Adapted examples (after gap>={MIN_GAP}): {len(adapted)}")
    if not adapted:
        raise RuntimeError(
            "No valid examples after adapting UltraFeedback. "
            "Try reducing MIN_GAP, or inspect rows to tune _extract_*."
        )

    # Deterministic split
    g = torch.Generator().manual_seed(SEED)
    idx = torch.randperm(len(adapted), generator=g).tolist()
    split = int(len(idx) * TRAIN_FRAC)
    train_idx, test_idx = idx[:split], idx[split:]
    train_records = [adapted[i] for i in train_idx]
    test_records  = [adapted[i] for i in test_idx]
    print(f"Train records: {len(train_records)} | Test records: {len(test_records)}")

    # QUICK sanity check (non-adaptive JJ) on TRAIN with default params
    params0 = BPParams(beta=BETA, delta=DELTA, tau=TAU, gamma=GAMMA, jj_steps=JJ_INNER_STEPS)
    base_s_tr, N_tr = cache_pairwise_scores(train_records, policy, ref, tok)
    wr0 = evaluate_with_cached(base_s_tr, N_tr, params0, use_adaptive_jj=False, use_labels=True)
    lo0, hi0 = binom_ci_95(wr0, N_tr)
    print(f"[Sanity TRAIN] WR={wr0:.2f}% [{lo0:.1f}, {hi0:.1f}] on {N_tr} pairs "
          f"| beta={params0.beta} delta={params0.delta} tau={params0.tau} gamma={params0.gamma} "
          f"jj_steps={params0.jj_steps}")

    # Tune on TRAIN (uses labels)
    if DO_TUNE:
        print("\nTuning on TRAIN...")
        _tries, best_params = grid_search(
            train_records, policy, ref, tok,
            betas=BETAS, deltas=DELTAS, taus=TAUS, gammas=GAMMAS,
            jj_steps_list=JJ_STEPS_LIST, use_adaptive_jj=True, estimate_delta=True, split_name="TRAIN"
        )
    else:
        best_params = params0

    # Final eval on TEST with NO labels (gamma=0 and JJ disabled)
    base_s_te, N_te = cache_pairwise_scores(test_records, policy, ref, tok)
    test_params = BPParams(
        beta=best_params.beta,
        delta=best_params.delta,   # note: cancels if you compare priors, but keep for consistency
        tau=best_params.tau,
        gamma=0.0,                 # NO labels on test
        jj_steps=0                 # disable JJ on test
    )
    wr_te = evaluate_with_cached(base_s_te, N_te, test_params, use_adaptive_jj=False, use_labels=False)
    lo_te, hi_te = binom_ci_95(wr_te, N_te)
    print(f"\n[BP-LLM EVAL (TEST, no labels)] WR={wr_te:.2f}% [{lo_te:.1f}, {hi_te:.1f}] "
          f"| beta={test_params.beta} delta={test_params.delta:.4f} "
          f"| tau={test_params.tau} gamma={test_params.gamma} jj_steps={test_params.jj_steps}")



if __name__ == "__main__":
    main()


Device: cuda | Dtype: torch.bfloat16
Loading policy model ...


tokenizer_config.json:   0%|          | 0.00/55.4k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/9.09M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/296 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/855 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/23.9k [00:00<?, ?B/s]

Fetching 4 files:   0%|          | 0/4 [00:00<?, ?it/s]

model-00001-of-00004.safetensors:   0%|          | 0.00/4.98G [00:00<?, ?B/s]

model-00004-of-00004.safetensors:   0%|          | 0.00/1.17G [00:00<?, ?B/s]

model-00002-of-00004.safetensors:   0%|          | 0.00/5.00G [00:00<?, ?B/s]

model-00003-of-00004.safetensors:   0%|          | 0.00/4.92G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/184 [00:00<?, ?B/s]

Loading reference model ...


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Loading dataset: openbmb/UltraFeedback [train[:10%]] config=None ...
Columns: ['source', 'instruction', 'models', 'completions', 'correct_answers', 'incorrect_answers']
Row[0] keys: ['source', 'instruction', 'models', 'completions', 'correct_answers', 'incorrect_answers']
Row[0] preview: {'source': 'evol_instruct…', 'instruction': "Can you write a C++ program that prompts the user to enter the name of a country and checks if it borders the Mediterranean Sea? Here's some starter code to help you out:\n#include <iostream>\n#include …", 'models': "['alpaca-7b', 'pythia-12b', 'starchat', 'vicuna-33b']…", 'completions': "[{'annotations': {'helpfulness': {'Rating': '2', 'Rationale': 'The response is clear and not lengthy, but it lacks useful and comprehensive information.', 'Rationale For Rating': 'The code is partiall…", 'correct_answers': "['None']…", 'incorrect_answers': "['None']…"}
Adapted examples (after gap>=1.0): 6205
Train records: 4964 | Test records: 1241
[Sanity TRAIN] WR=99.36% 

In [None]:
# dpo_ultrafeedback_eval_openbmb.py
# UltraFeedback DPO baseline (leakage-safe)
#
# What this does:
#   • Adapts UltraFeedback to (prompt, chosen, rejected) via top-vs-bottom with a minimum score gap.
#   • Renders the chat template ONCE with the policy tokenizer; feeds same rendered text to policy and ref.
#   • Computes DPO margins Δ = (s_pos - s_neg) with s = log π(y|x) − log π_ref(y|x) (ref optional).
#   • TRAIN: tune β by minimizing mean DPO loss  L_DPO = E[-log σ(βΔ)]  (also reports WR).
#   • TEST: evaluate win rate (labels only for metric; scoring is label-free).
#   • Length-normalized scoring over response tokens; drops truncated pairs.
#
# Notes:
#   • Set REF_MODEL_NAME=None to disable reference subtraction (plain log π).
#   • Lower MIN_GAP if you get too few pairs; raise *_TOKENS if truncation is frequent.
#   • This is an evaluation/tuning script (no gradient updates).
#
# Requires: datasets, transformers, torch

import os
import math
import random
from dataclasses import dataclass
from typing import List, Dict, Optional, Tuple, Iterable

import torch
from torch.nn.functional import log_softmax
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM

# =============================================================================
# Config
# =============================================================================
# HF_TOKEN = os.environ.get("HF_TOKEN")

# Policy (Instruct) + optional Base reference (Llama 3.2)
MODEL_NAME      = "meta-llama/Llama-3.1-8B-Instruct"
REF_MODEL_NAME  = "meta-llama/Llama-3.1-8B"     # set to None to disable reference subtraction

DATASET         = "openbmb/UltraFeedback"
DATASET_CONFIG  = None
SPLIT           = "train[:10%]"                 # iterate small; use "train" later

# Inference/scoring limits
MAX_INPUT_TOKENS     = 1024
MAX_GEN_TOKENS       = 512
SCORE_MAX_GEN_TOKENS = 256
BATCH_SIZE           = 4

# Scoring
SCORING_MODE         = "mean"      # {"mean", "sum", "lp_alpha"}
LENGTH_PENALTY_ALPHA = 0.0         # only for SCORING_MODE == "lp_alpha"

# Pair filtering
MIN_GAP              = 1.0         # minimum helpfulness-score gap to keep a pair

# Train / test split
RNG_SEED   = 42
TRAIN_FRAC = 0.8

# Grid tuning (TRAIN only)
DO_TUNE = True
BETAS   = [0.25, 0.5, 1.0, 2.0]    # temperature-like scale for Δ

# Device / dtype
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE  = torch.bfloat16 if torch.cuda.is_available() else torch.float32
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.set_grad_enabled(False)

# =============================================================================
# UF helpers (adapt)
# =============================================================================
def _extract_text(c: Dict) -> Optional[str]:
    if not isinstance(c, dict):
        return None
    for k in ("text", "response", "output", "completion", "content"):
        v = c.get(k)
        if isinstance(v, str) and v.strip():
            return v
    res = c.get("result")
    if isinstance(res, dict):
        v = res.get("text") or res.get("response") or res.get("output")
        if isinstance(v, str) and v.strip():
            return v
    return None

def _to_float(x) -> Optional[float]:
    try:
        return float(x)
    except Exception:
        return None

def _extract_score(c: Dict) -> Optional[float]:
    if not isinstance(c, dict):
        return None
    ann = c.get("annotations")
    if isinstance(ann, dict):
        help_ = ann.get("helpfulness")
        if isinstance(help_, dict):
            r = help_.get("Rating") or help_.get("rating") or help_.get("score")
            r = _to_float(r)
            if r is not None:
                return r
        for key in ("overall", "quality", "correctness", "honesty", "safety"):
            sub = ann.get(key)
            if isinstance(sub, dict):
                r = sub.get("Rating") or sub.get("rating") or sub.get("score")
                r = _to_float(r)
                if r is not None:
                    return r
        for _, v in ann.items():
            fv = _to_float(v)
            if fv is not None:
                return fv
    for k in ("score", "rating", "rank"):
        fv = _to_float(c.get(k))
        if fv is not None:
            return fv
    return None

def adapt_openbmb_ultrafeedback(record: Dict, min_gap: float = MIN_GAP) -> Optional[Dict]:
    """
    Build (prompt, chosen, rejected) as top-vs-bottom by score, with a minimum score gap.
    Drops near-ties to reduce label noise.
    """
    prompt = record.get("instruction")
    if not isinstance(prompt, str) or not prompt.strip():
        return None
    comps = record.get("completions")
    if not isinstance(comps, list) or len(comps) < 2:
        return None

    pairs = []
    for c in comps:
        text = _extract_text(c)
        score = _extract_score(c)
        if isinstance(text, str) and text.strip() and (score is not None):
            pairs.append((text, float(score)))
    if len(pairs) < 2:
        return None

    pairs.sort(key=lambda t: t[1])  # low ... high
    lo_txt, lo_s = pairs[0]
    hi_txt, hi_s = pairs[-1]
    if (hi_s - lo_s) < min_gap:
        return None

    return {"prompt": prompt, "chosen": hi_txt, "rejected": lo_txt}

# =============================================================================
# Models
# =============================================================================
def load_causal_lm(model_id: str, token: Optional[str], dtype=DTYPE, device=DEVICE):
    tok = AutoTokenizer.from_pretrained(model_id, token=token, use_fast=True, trust_remote_code=True)
    if tok.pad_token_id is None:
        tok.pad_token = tok.eos_token
    model = AutoModelForCausalLM.from_pretrained(model_id, token=token, torch_dtype=dtype, trust_remote_code=True).to(device)
    model.eval()
    return tok, model

# =============================================================================
# Prompt rendering (policy chat template) + scoring
# =============================================================================
def render_prompt_with_policy_template(policy_tokenizer, prompt: str) -> str:
    """
    Render chat template ONCE with the policy tokenizer.
    The resulting text is fed to BOTH policy and reference models.
    """
    if hasattr(policy_tokenizer, "apply_chat_template"):
        msgs = [{"role": "user", "content": prompt.strip()}]
        try:
            return policy_tokenizer.apply_chat_template(
                msgs, tokenize=False, add_generation_prompt=True
            )
        except Exception:
            pass
    p = prompt.strip()
    return p if p.endswith("\n") else (p + "\n")

def render_prompt_list(policy_tokenizer, prompts: List[str]) -> List[str]:
    return [render_prompt_with_policy_template(policy_tokenizer, p) for p in prompts]

def _concat_prompt_response_text(prompt_text: str, response: str) -> Tuple[str, str]:
    full = prompt_text + response.strip()
    return full, prompt_text

@torch.no_grad()
def sequence_logprob_stats_text(
    model,
    tokenizer,
    prompt_texts: List[str],   # already-rendered with policy template
    responses: List[str],
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Returns two tensors of shape [B]: (lp_sum, tok_count)
    – lp_sum is the sum of log-probs over response tokens
    – tok_count is the number of response tokens scored
    If response is fully truncated (no tokens past the prompt), tok_count=0
    """
    sums, counts = [], []
    for i in range(0, len(prompt_texts), BATCH_SIZE):
        p_batch = prompt_texts[i:i+BATCH_SIZE]
        r_batch = responses[i:i+BATCH_SIZE]

        batch_inputs, batch_prompt_lens = [], []
        for p_txt, r in zip(p_batch, r_batch):
            full_text, prompt_only = _concat_prompt_response_text(p_txt, r)

            toks_full = tokenizer(
                full_text,
                truncation=True,
                max_length=min(MAX_INPUT_TOKENS + MAX_GEN_TOKENS, tokenizer.model_max_length),
                return_tensors="pt",
                add_special_tokens=True,
            )
            toks_prompt = tokenizer(
                prompt_only,
                truncation=True,
                max_length=MAX_INPUT_TOKENS,
                return_tensors="pt",
                add_special_tokens=True,
            )
            p_len = toks_prompt["input_ids"].shape[-1]
            batch_inputs.append(toks_full)
            batch_prompt_lens.append(p_len)

        pad_id = tokenizer.pad_token_id or 0
        input_ids = torch.nn.utils.rnn.pad_sequence(
            [bi["input_ids"].squeeze(0) for bi in batch_inputs],
            batch_first=True, padding_value=pad_id
        )
        attention_mask = torch.nn.utils.rnn.pad_sequence(
            [bi["attention_mask"].squeeze(0) for bi in batch_inputs],
            batch_first=True, padding_value=0
        )
        input_ids = input_ids.to(DEVICE)
        attention_mask = attention_mask.to(DEVICE)

        logits = model(input_ids=input_ids, attention_mask=attention_mask).logits
        logprobs = log_softmax(logits.to(torch.float32), dim=-1)

        for b in range(input_ids.size(0)):
            p_len = batch_prompt_lens[b]
            ids   = input_ids[b]
            masks = attention_mask[b]
            L = int(masks.sum().item())
            if p_len >= L:
                sums.append(torch.tensor(0.0)); counts.append(torch.tensor(0)); continue
            end  = min(L, p_len + SCORE_MAX_GEN_TOKENS)
            targ = ids[p_len:end]
            pred = logprobs[b, p_len-1:end-1]
            lp_sum = pred.gather(-1, targ.unsqueeze(-1)).squeeze(-1).sum()
            sums.append(lp_sum.cpu())
            counts.append(torch.tensor(targ.numel()))
    return torch.stack(sums, dim=0), torch.stack(counts, dim=0)

def _score_from_stats(lp_sum: torch.Tensor, tok_count: torch.Tensor) -> torch.Tensor:
    if SCORING_MODE == "mean":
        denom = torch.clamp(tok_count.to(torch.float32), min=1.0)
        return lp_sum.to(torch.float32) / denom
    elif SCORING_MODE == "lp_alpha":
        return lp_sum.to(torch.float32) + (LENGTH_PENALTY_ALPHA * tok_count.to(torch.float32))
    else:  # "sum"
        return lp_sum.to(torch.float32)

# =============================================================================
# DPO pieces
# =============================================================================
@dataclass
class DPOTune:
    beta: float = 1.0

def sigmoid(x: torch.Tensor) -> torch.Tensor:
    return 1 / (1 + torch.exp(-x))

@torch.no_grad()
def cache_dpo_margins(
    records: List[Dict],
    policy_model,
    ref_model,
    policy_tok,
    ref_tok=None
) -> Tuple[torch.Tensor, int]:
    """
    Returns Δ tensor of shape [N_valid] where
      Δ_i = (s_pos_i - s_neg_i),
      s = log π(y|x) - log π_ref(y|x)  (ref optional)
    Valid pairs require ≥1 response token scored for BOTH sides (and for ref if used).
    """
    prompts_raw = [r["prompt"] for r in records]
    prompt_texts = render_prompt_list(policy_tok, prompts_raw)

    resps_ch = [r["chosen"] for r in records]
    resps_rj = [r["rejected"] for r in records]

    # Policy (own tokenizer)
    pol_sum_ch, pol_cnt_ch = sequence_logprob_stats_text(policy_model, policy_tok, prompt_texts, resps_ch)
    pol_sum_rj, pol_cnt_rj = sequence_logprob_stats_text(policy_model, policy_tok, prompt_texts, resps_rj)
    pol_score_ch = _score_from_stats(pol_sum_ch, pol_cnt_ch)
    pol_score_rj = _score_from_stats(pol_sum_rj, pol_cnt_rj)

    # Reference (optional; own tokenizer; same prompt_texts)
    if ref_model is None:
        ref_score_ch = torch.zeros_like(pol_score_ch); ref_cnt_ch = torch.ones_like(pol_cnt_ch)
        ref_score_rj = torch.zeros_like(pol_score_rj); ref_cnt_rj = torch.ones_like(pol_cnt_rj)
    else:
        if ref_tok is None:
            raise ValueError("ref_model provided without ref_tok")
        ref_sum_ch, ref_cnt_ch = sequence_logprob_stats_text(ref_model, ref_tok, prompt_texts, resps_ch)
        ref_sum_rj, ref_cnt_rj = sequence_logprob_stats_text(ref_model, ref_tok, prompt_texts, resps_rj)
        ref_score_ch = _score_from_stats(ref_sum_ch, ref_cnt_ch)
        ref_score_rj = _score_from_stats(ref_sum_rj, ref_cnt_rj)

    # Validity mask: both sides non-truncated for both models (if ref used)
    valid_ch = (pol_cnt_ch > 0) & (ref_cnt_ch > 0)
    valid_rj = (pol_cnt_rj > 0) & (ref_cnt_rj > 0)
    valid_idx = [i for i in range(len(records)) if valid_ch[i].item() and valid_rj[i].item()]
    if not valid_idx:
        raise RuntimeError("No valid (non-truncated) pairs. Increase token limits or lower MIN_GAP.")

    # s = log π - log π_ref (ref=0 if disabled); Δ = s_pos - s_neg
    s_pos = pol_score_ch[valid_idx] - ref_score_ch[valid_idx]
    s_neg = pol_score_rj[valid_idx] - ref_score_rj[valid_idx]
    delta = (s_pos - s_neg).to(torch.float32)  # [N_valid]
    return delta, len(valid_idx)

def dpo_loss_and_wr(delta: torch.Tensor, beta: float) -> Tuple[float, float]:
    """
    Mean DPO loss and win rate for a fixed β.
    Loss = E[-log σ(βΔ)], WR = mean(Δ > 0)*100
    """
    z = beta * delta
    # stabilize: -log σ(z) = softplus(-z)
    loss = torch.nn.functional.softplus(-z).mean().item()
    wr = (delta > 0).to(torch.float32).mean().item() * 100.0
    return loss, wr

def binom_ci_95(pct: float, N: int) -> Tuple[float, float]:
    p = pct / 100.0
    se = math.sqrt(p * (1 - p) / max(N, 1))
    lo = max(0.0, 100.0 * (p - 1.96 * se))
    hi = min(100.0, 100.0 * (p + 1.96 * se))
    return lo, hi

def tune_beta_dpo(
    records: List[Dict], policy_model, ref_model, policy_tok, ref_tok,
    betas: Iterable[float], split_name: str = "TRAIN"
) -> Tuple[List[Tuple[float,float,float,float]], DPOTune]:
    """
    Returns (trials, best), where each trial is (beta, loss, wr, N).
    Selects β minimizing mean DPO loss on TRAIN (ties broken by higher WR).
    """
    delta, N = cache_dpo_margins(records, policy_model, ref_model, policy_tok, ref_tok)
    trials = []
    best = (float("inf"), -1.0, None)  # (loss, wr, beta)
    for beta in betas:
        loss, wr = dpo_loss_and_wr(delta, beta)
        trials.append((beta, loss, wr, N))
        lo, hi = binom_ci_95(wr, N)
        print(f"[TUNE/{split_name} DPO] beta={beta:.3g} | loss={loss:.4f} | WR={wr:.2f}% [{lo:.1f}, {hi:.1f}]")
        if (loss < best[0]) or (abs(loss - best[0]) < 1e-6 and wr > best[1]):
            best = (loss, wr, beta)

    best_beta = best[2]
    print(f"\n[DPO Tuning ({split_name})] Best beta={best_beta:.3g} "
          f"(loss={best[0]:.4f}, WR={best[1]:.2f}%) on {N} pairs")
    return trials, DPOTune(beta=best_beta)

# =============================================================================
# Main
# =============================================================================
def main():
    print(f"Device: {DEVICE} | Dtype: {DTYPE}")
    print("Loading policy model ...")
    tok_policy, policy = load_causal_lm(MODEL_NAME, HF_TOKEN)

    tok_ref, ref = None, None
    if REF_MODEL_NAME:
        print("Loading reference model ...")
        tok_ref, ref = load_causal_lm(REF_MODEL_NAME, HF_TOKEN)

    print(f"Loading dataset: {DATASET} [{SPLIT}] config={DATASET_CONFIG} ...")
    ds = load_dataset(DATASET, name=DATASET_CONFIG, split=SPLIT) if DATASET_CONFIG else load_dataset(DATASET, split=SPLIT)

    # Adapt to pairs with min gap
    adapted: List[Dict] = []
    for rec in ds:
        a = adapt_openbmb_ultrafeedback(rec, min_gap=MIN_GAP)
        if a is not None:
            adapted.append(a)
    print(f"Adapted examples (gap ≥ {MIN_GAP}): {len(adapted)}")
    if not adapted:
        raise RuntimeError("No valid examples after adaptation. Lower MIN_GAP or inspect data.")

    # Deterministic split
    random.Random(RNG_SEED).shuffle(adapted)
    n_all   = len(adapted)
    n_train = max(1, int(TRAIN_FRAC * n_all))
    train_recs = adapted[:n_train]
    test_recs  = adapted[n_train:]
    print(f"Split sizes: TRAIN={len(train_recs)} | TEST={len(test_recs)}")

    # Tune β on TRAIN (min DPO loss; report WR too)
    if DO_TUNE:
        print("\nTuning beta on TRAIN (DPO)...")
        _tries, best = tune_beta_dpo(train_recs, policy, ref, tok_policy, tok_ref, betas=BETAS, split_name="TRAIN")
        beta = best.beta
    else:
        beta = 1.0

    # Evaluate on TEST
    eval_set = test_recs if len(test_recs) > 0 else train_recs
    if len(test_recs) == 0:
        print("\nWARNING: TEST set is empty (small SPLIT/TRAIN_FRAC). Using TRAIN as proxy sanity check.")

    print("\nScoring TEST (DPO metrics)...")
    delta_te, N_te = cache_dpo_margins(eval_set, policy, ref, tok_policy, tok_ref)
    loss_te, wr_te = dpo_loss_and_wr(delta_te, beta)
    lo_te, hi_te = binom_ci_95(wr_te, N_te)
    print(f"[DPO EVAL (TEST)] WR={wr_te:.2f}% [{lo_te:.1f}, {hi_te:.1f}] | loss={loss_te:.4f} | beta={beta:.3g} | N={N_te}")

if __name__ == "__main__":
    main()


Device: cuda | Dtype: torch.bfloat16
Loading policy model ...


tokenizer_config.json:   0%|          | 0.00/55.4k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/9.09M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/296 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/855 [00:00<?, ?B/s]

`torch_dtype` is deprecated! Use `dtype` instead!


model.safetensors.index.json:   0%|          | 0.00/23.9k [00:00<?, ?B/s]

Fetching 4 files:   0%|          | 0/4 [00:00<?, ?it/s]

model-00002-of-00004.safetensors:   0%|          | 0.00/5.00G [00:00<?, ?B/s]

model-00001-of-00004.safetensors:   0%|          | 0.00/4.98G [00:00<?, ?B/s]

model-00004-of-00004.safetensors:   0%|          | 0.00/1.17G [00:00<?, ?B/s]

model-00003-of-00004.safetensors:   0%|          | 0.00/4.92G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/184 [00:00<?, ?B/s]

Loading reference model ...


tokenizer_config.json:   0%|          | 0.00/50.5k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/9.09M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/73.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/826 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/23.9k [00:00<?, ?B/s]

Fetching 4 files:   0%|          | 0/4 [00:00<?, ?it/s]

model-00001-of-00004.safetensors:   0%|          | 0.00/4.98G [00:00<?, ?B/s]

model-00003-of-00004.safetensors:   0%|          | 0.00/4.92G [00:00<?, ?B/s]

model-00002-of-00004.safetensors:   0%|          | 0.00/5.00G [00:00<?, ?B/s]

model-00004-of-00004.safetensors:   0%|          | 0.00/1.17G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/185 [00:00<?, ?B/s]

Loading dataset: openbmb/UltraFeedback [train[:10%]] config=None ...


README.md: 0.00B [00:00, ?B/s]

evol_instruct.jsonl:   0%|          | 0.00/168M [00:00<?, ?B/s]

false_qa.jsonl:   0%|          | 0.00/25.9M [00:00<?, ?B/s]

flan.jsonl:   0%|          | 0.00/240M [00:00<?, ?B/s]

sharegpt.jsonl:   0%|          | 0.00/313M [00:00<?, ?B/s]

truthful_qa.jsonl: 0.00B [00:00, ?B/s]

ultrachat.jsonl:   0%|          | 0.00/182M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/63967 [00:00<?, ? examples/s]

Adapted examples (gap ≥ 1.0): 6205
Split sizes: TRAIN=4964 | TEST=1241

Tuning beta on TRAIN (DPO)...
[TUNE/TRAIN DPO] beta=0.25 | loss=0.6677 | WR=78.85% [77.7, 80.0]
[TUNE/TRAIN DPO] beta=0.5 | loss=0.6475 | WR=78.85% [77.7, 80.0]
[TUNE/TRAIN DPO] beta=1 | loss=0.6170 | WR=78.85% [77.7, 80.0]
[TUNE/TRAIN DPO] beta=2 | loss=0.5818 | WR=78.85% [77.7, 80.0]

[DPO Tuning (TRAIN)] Best beta=2 (loss=0.5818, WR=78.85%) on 4964 pairs

Scoring TEST (DPO metrics)...
[DPO EVAL (TEST)] WR=78.65% [76.4, 80.9] | loss=0.5855 | beta=2 | N=1241
