In [5]:
"""
Notebook‑ready DeepSeek‑Qwen ensemble with proper **Qwen2.5‑Math‑PRM‑7B** reward
===============================================================================

* Two candidate generators (default: 1.5 B & 7 B Distill‑Qwen).  
* Reward model now follows the **official step‑scoring recipe**: we load the RM
  with `AutoModel`, pass the conversation through the chat template, identify
  `<extra_0>` delimiters, and extract the probability of the *positive* label
  at each step token.  The overall score is the **mean probability × 10** (so
  it roughly ranges 0‑10 like a human grade).

Usage in a notebook
-------------------
```python
from ensemble_inference import run_ensemble
print(run_ensemble("Explain gradient accumulation in simple terms."))
```
You can plug in extra candidate models by passing a list of paths to
`run_ensemble(..., candidates=[...])`.

Dependencies: `transformers >= 4.40`, `accelerate`, CUDA.
"""

from __future__ import annotations

import logging
from dataclasses import dataclass
from typing import List, Tuple

import torch
import torch.nn.functional as F
from transformers import (
    AutoModel,
    AutoModelForCausalLM,
    AutoTokenizer,
    GenerationConfig,
    PreTrainedModel,
    PreTrainedTokenizerBase,
)

# Named logger shared by every component in this module; basicConfig sets the
# root logging level to INFO with a "[LEVEL] message" line format.
logger = logging.getLogger("ensemble_inference")
logging.basicConfig(level=logging.INFO, format="[%(levelname)s] %(message)s")

# Substrings at which `_trim_at_stop` cuts a generated segment.
STOP_TOKENS = {".", "\n"}
# Device that tokenized inputs are moved to before a model call.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

def _trim_at_stop(text: str) -> str:
    for t in STOP_TOKENS:
        i = text.find(t)
        if i != -1:
            return text[: i + len(t)]
    return text


# ---------------------------------------------------------------------------
# Generator wrapper
# ---------------------------------------------------------------------------

@dataclass
class HFModel:
    """A causal-LM candidate generator bundled with its tokenizer."""

    # Identifier shown in logs; `load` sets it to the checkpoint path.
    name: str
    tokenizer: PreTrainedTokenizerBase
    model: PreTrainedModel

    @classmethod
    def load(cls, path: str, *, dtype: torch.dtype = torch.float16):
        """Load tokenizer and causal LM from *path* and return a wrapper.

        Args:
            path: Checkpoint directory or hub id passed to ``from_pretrained``.
            dtype: Torch dtype requested for the model weights.

        Returns:
            An ``HFModel`` whose ``name`` is *path* and whose model is in eval
            mode.
        """
        logger.info("Loading generator %s …", path)
        tok = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
        mod = AutoModelForCausalLM.from_pretrained(
            path,
            torch_dtype=dtype,
            device_map="auto",
            trust_remote_code=True,
        ).eval()
        return cls(path, tok, mod)

    @torch.inference_mode()
    def generate_segment(
        self,
        prompt: str,
        *,
        max_new_tokens: int = 128,
        temperature: float = 0.7,
        top_p: float = 0.95,
    ) -> str:
        """Sample one continuation of *prompt* and trim it to a single segment.

        Args:
            prompt: Text the model continues.
            max_new_tokens: Upper bound on generated tokens.
            temperature: Sampling temperature.
            top_p: Nucleus-sampling probability mass.

        Returns:
            The continuation with surrounding whitespace removed, cut after
            its first stop token (see ``_trim_at_stop``).
        """
        # The model was placed by device_map="auto"; send inputs to wherever
        # the model reports it lives instead of guessing via the global DEVICE.
        encoded = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
        cfg = GenerationConfig(
            do_sample=True,
            temperature=temperature,
            top_p=top_p,
            max_new_tokens=max_new_tokens,
            pad_token_id=self.tokenizer.eos_token_id,
        )
        out = self.model.generate(**encoded, generation_config=cfg)[0]
        # generate() returns prompt + continuation; keep only the new tokens.
        prompt_len = encoded["input_ids"].shape[1]
        decoded = self.tokenizer.decode(out[prompt_len:], skip_special_tokens=True)
        # Drop leading whitespace before trimming: a continuation that opens
        # with "\n" would otherwise be cut at that newline and come back empty.
        return _trim_at_stop(decoded.lstrip()).strip()


# ---------------------------------------------------------------------------
# Reward model wrapper (official step‑based scoring)
# ---------------------------------------------------------------------------

# Step delimiter: PRMScorer appends it to answers that lack one and reads the
# reward logits at every position holding this token.
STEP_TOKEN = "<extra_0>"
# System message placed first in the conversation sent to the reward model.
SYSTEM_PROMPT = "Please reason step by step, and put your final answer within \\boxed{}."


def _make_step_rewards(logits: torch.Tensor, token_masks: torch.Tensor):
    """Return list of positive‑label probabilities for each step."""
    probs = F.softmax(logits, dim=-1) * token_masks.unsqueeze(-1)  # B, T, C
    results = []
    for sample in probs:  # iterate batch (usually 1)
        positive = sample[sample != 0].view(-1, 2)[:, 1]  # steps × 2 → take label‑1 prob
        results.append(positive.cpu().tolist())
    return results


class PRMScorer:
    """Step-level answer scorer backed by a Qwen2.5-Math-PRM-7B checkpoint."""

    def __init__(self, path: str, *, dtype: torch.dtype = torch.bfloat16):
        """Load the reward model and tokenizer from *path*.

        Args:
            path: Checkpoint directory or hub id passed to ``from_pretrained``.
            dtype: Torch dtype requested for the model weights.
        """
        logger.info("Loading reward model %s …", path)
        self.tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
        reward_model = AutoModel.from_pretrained(
            path,
            torch_dtype=dtype,
            device_map="auto",
            trust_remote_code=True,
        )
        self.model = reward_model.eval()
        # Cache the separator's token id; `score` uses it to build the mask.
        separator_ids = self.tokenizer.encode(STEP_TOKEN)
        self.step_sep_id = separator_ids[0]

    @torch.inference_mode()
    def score(self, question: str, answer: str) -> float:
        """Score *answer* to *question* as mean step probability times ten.

        Args:
            question: User question placed in the conversation.
            answer: Assistant answer; steps are delimited by ``STEP_TOKEN``.
                A trailing delimiter is appended when it is missing.

        Returns:
            The mean positive-label probability over all step delimiters,
            multiplied by 10, or 0.0 when no step probability was extracted.
        """
        # Guarantee a closing delimiter so there is at least one step to read.
        already_terminated = answer.strip().endswith(STEP_TOKEN)
        if not already_terminated:
            answer = answer.rstrip() + STEP_TOKEN

        conversation = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": question},
            {"role": "assistant", "content": answer},
        ]
        rendered = self.tokenizer.apply_chat_template(
            conversation,
            tokenize=False,
            add_generation_prompt=False,
        )
        encoded = self.tokenizer(rendered, return_tensors="pt")
        input_ids = encoded.input_ids.to(DEVICE)

        outputs = self.model(input_ids=input_ids)
        separator_mask = input_ids == self.step_sep_id
        per_sample = _make_step_rewards(outputs.logits, separator_mask)
        step_probs = per_sample[0]  # single conversation, so one batch entry

        if len(step_probs) == 0:
            return 0.0
        mean_prob = sum(step_probs) / len(step_probs)
        return float(mean_prob * 10.0)


# ---------------------------------------------------------------------------
# Core ensemble loop
# ---------------------------------------------------------------------------

@dataclass
class EnsembleReasoner:
    """Build an answer segment by segment from competing generators.

    Each round, every candidate model continues the current context, the
    scorer rates each continuation, and the highest-scoring one is appended.
    """

    candidate_models: List[HFModel]
    scorer: PRMScorer
    max_rounds: int = 5
    # NOTE(review): PRMScorer.score returns values scaled to 0-10, so 0.5
    # corresponds to a mean step probability of 0.05 -- confirm this is the
    # intended cut-off.
    score_threshold: float = 0.5

    def __call__(self, question: str) -> Tuple[str, List[str]]:
        """Run up to ``max_rounds`` rounds of generate, score, and select.

        Args:
            question: The user question; it seeds the generation context.

        Returns:
            ``(answer, segments)`` where ``segments`` are the chosen
            continuations in order and ``answer`` is them joined by spaces.

        Raises:
            ValueError: If ``candidate_models`` is empty.
        """
        if not self.candidate_models:
            # Fail early with a clear message instead of an opaque error from
            # taking the arg-max of an empty score list.
            raise ValueError("EnsembleReasoner needs at least one candidate model")

        context = question.strip()
        chosen_segments: List[str] = []

        for rnd in range(1, self.max_rounds + 1):
            logger.info("⏩  Round %d", rnd)
            segments = [m.generate_segment(context) for m in self.candidate_models]
            # NOTE(review): each candidate is scored on its own, without the
            # segments chosen in earlier rounds -- confirm that is intended.
            scores = [self.scorer.score(question, seg) for seg in segments]

            for m, seg, sc in zip(self.candidate_models, segments, scores):
                logger.info("→ %s | score %.2f | %s", m.name, sc, seg.replace("\n", "\\n"))

            # Index of the highest score; on a tie the earlier model wins.
            best_idx = max(range(len(scores)), key=scores.__getitem__)
            best_score, best_seg = scores[best_idx], segments[best_idx]
            logger.info("✅  Chosen: model=%s score=%.2f", self.candidate_models[best_idx].name, best_score)

            if best_score < self.score_threshold:
                logger.info("Stopping early (score below threshold).")
                break

            chosen_segments.append(best_seg)
            context += " " + best_seg

        return " ".join(chosen_segments), chosen_segments


# ---------------------------------------------------------------------------
# Convenience functions
# ---------------------------------------------------------------------------

def load_default_models(
    dtype: torch.dtype = torch.float16,
    gen_paths: List[str] | None = None,
):
    """Load the candidate generators and the reward-model scorer.

    Args:
        dtype: Torch dtype requested for the generator weights.
        gen_paths: Checkpoint paths of the generators to load.  When omitted,
            the two default DeepSeek-R1-Distill-Qwen checkpoints are used.

    Returns:
        ``(candidates, scorer)``: the loaded ``HFModel`` list and a
        ``PRMScorer``.
    """
    if gen_paths is None:
        gen_paths = [
            "/root/autodl-tmp/DeepSeek-R1-Distill-Qwen-1.5B",
            "/root/autodl-tmp/DeepSeek-R1-Distill-Qwen-7B",
        ]
    candidates = [HFModel.load(p, dtype=dtype) for p in gen_paths]
    scorer = PRMScorer("/root/autodl-tmp/Qwen2.5-Math-PRM-7B")
    return candidates, scorer


def run_ensemble(
    question: str,
    *,
    max_rounds: int = 5,
    score_threshold: float = 0.5,
    candidates: List[str] | None = None,
) -> str:
    """Answer *question* with the generator ensemble and return the text.

    Args:
        question: The user question.
        max_rounds: Maximum number of generate/score/select rounds.
        score_threshold: Rounds stop once the best score falls below this.
        candidates: Optional list of generator checkpoint paths to use in
            place of the defaults.

    Returns:
        The chosen segments joined by spaces.
    """
    # NOTE: every call loads all models again; build an EnsembleReasoner once
    # and reuse it when asking several questions.
    models, scorer = load_default_models(gen_paths=candidates)
    reasoner = EnsembleReasoner(
        models,
        scorer,
        max_rounds=max_rounds,
        score_threshold=score_threshold,
    )
    answer, _ = reasoner(question)
    return answer

In [6]:
# Demo cell: run the ensemble on one prompt and print the joined answer text
# that run_ensemble returns.
q = "Explain gradient accumulation in simple terms."
print(run_ensemble(q))

[INFO] Loading generator /root/autodl-tmp/DeepSeek-R1-Distill-Qwen-1.5B …
[INFO] We will use 90% of the memory on device 0 for storing the model, and 10% for the buffer to avoid OOM. You can set `max_memory` in to a higher value to use more memory (at your own risk).
[INFO] Loading generator /root/autodl-tmp/DeepSeek-R1-Distill-Qwen-7B …
[INFO] We will use 90% of the memory on device 0 for storing the model, and 10% for the buffer to avoid OOM. You can set `max_memory` in to a higher value to use more memory (at your own risk).


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

[INFO] Loading reward model /root/autodl-tmp/Qwen2.5-Math-PRM-7B …
[INFO] We will use 90% of the memory on device 0 for storing the model, and 10% for the buffer to avoid OOM. You can set `max_memory` in to a higher value to use more memory (at your own risk).


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Some weights of the model checkpoint at /root/autodl-tmp/Qwen2.5-Math-PRM-7B were not used when initializing Qwen2ForProcessRewardModel: ['lm_head.weight']
- This IS expected if you are initializing Qwen2ForProcessRewardModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing Qwen2ForProcessRewardModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
[INFO] ⏩  Round 1
[INFO] → /root/autodl-tmp/DeepSeek-R1-Distill-Qwen-1.5B | score 9.49 | Need to be careful about the wording.
[INFO] → /root/autodl-tmp/DeepSeek-R1-Distill-Qwen-7B | score 9.77 | For someone who has some basic knowledge of machine learning, how would you explain gradient accumulation? How is it different from gradient descent? What are th

For someone who has some basic knowledge of machine learning, how would you explain gradient accumulation? How is it different from gradient descent? What are the scenarios where gradient accumulation is useful?

Gradient accumulation is a technique used in training machine learning models, particularly deep learning models, that helps in managing memory usage and can be useful in certain scenarios. It involves computing gradients across multiple examples or data points at once and storing these gradients, rather than computing them for each example separately. This can be particularly beneficial when dealing with large datasets, as it reduces the memory required and can speed up training. The main idea is to compute the gradient for a batch of data, and then use that gradient to update the model's parameters. This approach is different from traditional gradient descent, which computes gradients one example at a time, which can be slow and memory-intensive for large datasets.
