In [None]:
# Install/upgrade the libraries this notebook imports into the kernel's environment.
# NOTE(review): versions are unpinned and `--upgrade` pulls whatever is newest, so
# results may differ between runs; consider pinning once a working set is known.
%pip install transformers accelerate torch --upgrade

In [None]:
import re
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSequenceClassification

In [None]:
# --- simple "stepwise" reward using a final-answer reward model ---
def reward_step_score(prompt, text, rw, rw_tok, device="cuda"):
    """Average reward over the growing sentence prefixes of ``text``.

    ``text`` is split into sentences at whitespace that follows ``.``, ``!`` or
    ``?``. For i = 1..n the first i sentences (joined with single spaces) are
    scored by the reward model as an answer to ``prompt``; the mean of those n
    scores is returned.

    The question and the partial answer are handed to the tokenizer as a text
    *pair* (``rw_tok(question, answer, ...)``) rather than as one concatenated
    string, so the tokenizer can insert its own separator/segment information
    between them. This is the input form shown on the model card of the reward
    model used in this notebook's example.

    Args:
        prompt: The user question; surrounding whitespace is stripped.
        text: The candidate completion to score.
        rw: Reward model; called as ``rw(**encoding)`` and expected to expose
            ``.logits``.
        rw_tok: Tokenizer matching ``rw``.
        device: Device the encoded inputs are moved to before the forward pass.

    Returns:
        float: Mean prefix score, or ``0.0`` when ``text`` contains no sentences.
    """
    sentences = [s.strip() for s in re.split(r'(?<=[.!?])\s+', text.strip()) if s.strip()]
    if not sentences:
        return 0.0

    question = prompt.strip()
    total = 0.0
    for end in range(1, len(sentences) + 1):
        partial_answer = " ".join(sentences[:end])
        enc = rw_tok(question, partial_answer, return_tensors="pt", truncation=True).to(device)
        with torch.no_grad():
            # `.item()` requires a single value here, i.e. a one-logit reward head.
            total += float(rw(**enc).logits[0].item())
    return total / len(sentences)

In [None]:
# --- tiny helpers (short names) ---
def chat(p):
    """Wrap the user message ``p`` in a fixed system/user/assistant text layout.

    The layout is hard-coded; adjust it to your model's chat template if needed.
    The returned string ends right after the assistant tag and a newline.
    """
    rendered = f"<|system|>\nYou are a helpful assistant.\n<|user|>\n{p}\n<|assistant|>\n"
    return rendered

def gen(lm, lm_tok, prefix, k, device="cuda", max_new_tokens=96,
        temperature=0.9, top_p=0.95):
    """Sample ``k`` sequences from ``lm`` starting at ``prefix``.

    Args:
        lm: Model exposing ``generate(**inputs, ...)``.
        lm_tok: Tokenizer for ``lm``; its ``eos_token_id`` is used as the pad id.
        prefix: Text to condition on.
        k: Number of sequences to sample (``num_return_sequences``).
        device: Device the encoded prefix is moved to.
        max_new_tokens: Upper bound on newly generated tokens per sequence.
        temperature: Sampling temperature (previously fixed at 0.9).
        top_p: Nucleus-sampling threshold (previously fixed at 0.95).

    Returns:
        list[str]: One decoded string per returned sequence, with special
        tokens skipped. Callers in this notebook treat each string as the
        prefix followed by its continuation.
    """
    enc = lm_tok(prefix, return_tensors="pt").to(device)
    with torch.no_grad():
        sequences = lm.generate(
            **enc,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=temperature,
            top_p=top_p,
            num_return_sequences=k,
            # Reuse EOS as the pad id for generate().
            pad_token_id=lm_tok.eos_token_id,
        )
    return [lm_tok.decode(seq, skip_special_tokens=True) for seq in sequences]

def cont(parent, child):
    """Return the part of ``child`` that follows ``parent``.

    When ``child`` begins with ``parent`` the remainder is returned with
    surrounding whitespace stripped; otherwise ``child`` is returned unchanged.
    """
    if not child.startswith(parent):
        return child
    remainder = child[len(parent):]
    return remainder.strip()

In [None]:
# --- basic beam search with reward-only ranking ---
def beam_search_reward(prompt, lm, lm_tok, rw, rw_tok,
                       N=4, M=2, rounds=2, max_new_tokens=96, device="cuda"):
    """
    Sampling-based beam search ranked purely by reward.

    N = total kept per round, M = parents to expand (N % M == 0),
    rounds = how many expansions. Ranking = reward only.

    Round 0 samples N completions of the chat-formatted prompt. Each later
    round takes the M best-scoring beams, samples N // M continuations of each
    parent's full text, scores them, and replaces the beam set with those
    children (the parents themselves are not carried over).

    Args:
        prompt: Raw user question (formatted with ``chat`` before generation).
        lm, lm_tok: Generator model and its tokenizer.
        rw, rw_tok: Reward model and its tokenizer.
        N: Number of beams sampled/kept per round.
        M: Number of top beams expanded per round.
        rounds: Number of expansion rounds after the initial sampling.
        max_new_tokens: Generation budget per ``gen`` call.
        device: Device for inputs, and for the models when they are movable.

    Returns:
        list[dict]: Beams sorted best-first; each has ``"full"`` (text returned
        by ``gen``), ``"comp"`` (completion with the prompt prefix removed) and
        ``"score"``.

    Raises:
        AssertionError: If N is not divisible by M.
    """
    assert N % M == 0, "N must be divisible by M"

    def _prepare(model):
        # Put the model on `device` and switch it to eval mode.
        # NOTE(review): a model loaded with a multi-device `device_map` is
        # already placed by accelerate; moving it logs a warning and fails when
        # weights are offloaded, so the move is skipped in that case only.
        placement = getattr(model, "hf_device_map", None)
        sharded = placement is not None and len(set(placement.values())) > 1
        if not sharded:
            model.to(device)
        model.eval()

    _prepare(lm)
    _prepare(rw)

    px = chat(prompt)

    def _beam(full):
        # Build one scored beam record from a full generated text.
        comp = cont(px, full)
        score = reward_step_score(prompt, comp, rw, rw_tok, device)
        return {"full": full, "comp": comp, "score": score}

    def _ranked(candidates):
        # Sort beams from highest to lowest score.
        return sorted(candidates, key=lambda b: b["score"], reverse=True)

    # round 0
    beams = [_beam(full) for full in gen(lm, lm_tok, px, N, device, max_new_tokens)]

    # expand rounds
    per_parent = N // M  # loop-invariant: children sampled from each parent
    for _ in range(rounds):
        kids = []
        for parent in _ranked(beams)[:M]:
            children = gen(lm, lm_tok, parent["full"], per_parent, device, max_new_tokens)
            kids.extend(_beam(child_full) for child_full in children)
        beams = _ranked(kids)[:N]

    # return best-first (the sort matters when rounds == 0)
    return _ranked(beams)


In [None]:
# --- example (swap models if you like) ---
if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    lm_name = "HuggingFaceH4/zephyr-7b-beta"
    rw_name = "OpenAssistant/reward-model-deberta-v3-large-v2"

    # Generation samples (do_sample=True); seed so reruns give comparable output.
    torch.manual_seed(0)

    # Half precision and automatic placement only when a GPU is available.
    lm_tok = AutoTokenizer.from_pretrained(lm_name, use_fast=True)
    lm = AutoModelForCausalLM.from_pretrained(
        lm_name,
        torch_dtype=torch.float16 if device == "cuda" else None,
        device_map="auto" if device == "cuda" else None,
    )

    rw_tok = AutoTokenizer.from_pretrained(rw_name, use_fast=True)
    rw = AutoModelForSequenceClassification.from_pretrained(
        rw_name,
        torch_dtype=torch.float16 if device == "cuda" else None,
        device_map="auto" if device == "cuda" else None,
    )

    prompt = "Roger has five tennis balls and buys two cans of three tennis balls each. How many tennis balls does he have now? Think step by step."
    res = beam_search_reward(prompt, lm, lm_tok, rw, rw_tok, N=4, M=2, rounds=2, device=device)
    for i, r in enumerate(res[:3], 1):
        # Flatten newlines outside the f-string: a backslash inside an f-string
        # expression is a SyntaxError before Python 3.12, and the previous
        # '\\n' literal matched a backslash followed by 'n', not a newline.
        preview = r["comp"][:160].replace("\n", " ")
        print(f"[{i}] score={r['score']:.3f}  {preview}")