In [None]:
from __future__ import annotations

import logging
import re
from dataclasses import dataclass
from typing import Dict, Iterable, List, Optional, Tuple

import torch
import torch.nn.functional as F
from transformers import (
    AutoModel,
    AutoModelForCausalLM,
    AutoTokenizer,
    GenerationConfig,
    PreTrainedTokenizerBase,
)

# Optional vLLM backend -----------------------------------------------------
try:
    from vllm import LLM, SamplingParams  # type: ignore
    _VLLM_AVAILABLE = True
except ImportError:  # pragma: no cover
    _VLLM_AVAILABLE = False

# ---------------------------------------------------------------------------
# Logging / constants
# ---------------------------------------------------------------------------
# Root logging configuration plus the module logger used throughout this file.
logging.basicConfig(format="[%(levelname)s] %(message)s", level=logging.INFO)
logger = logging.getLogger("ensemble_inference")

# Preferred compute device for this process.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# NOTE(review): every string ends with "" (str.endswith("") is always True), so an
# EOS check of the form `text.endswith(EOS_TEXT)` is trivially satisfied.
EOS_TEXT = ""
# Separator token; the reward model's scores are read at its positions.
STEP_TOKEN = "<extra_0>"
SYSTEM_PROMPT = "You are a helpful assistant."
# Text strings at which generated chunks are cut.
STOP_TOKENS_TEXT = {".", "\n"}

# ---------------------------------------------------------------------------
# Conversation Template
# ---------------------------------------------------------------------------

class ConversationTemplate:
    """Plain-text prompt builder holding one system message, one user question and
    the assistant text accepted so far.
    """

    def __init__(self, system_prompt: str, initial_question: str):
        self.system = system_prompt
        self.question = initial_question
        # Accepted assistant chunks, in order; each one is stored stripped.
        self.assistant_parts: List[str] = []

    def add_assistant(self, content: str):
        """Record one more assistant chunk with surrounding whitespace removed."""
        cleaned = content.strip()
        self.assistant_parts.append(cleaned)

    def render(self) -> str:
        """Return the prompt text.

        The system, user and assistant sections are concatenated with no separator
        between them. Inside the assistant section, the accepted chunks are joined by
        newlines.
        """
        system_section = f"[SYSTEM] {self.system} [/SYSTEM]"
        user_section = f"<user>\n{self.question.strip()}\n</user>"
        assistant_section = "<assistant>\n" + "\n".join(self.assistant_parts)
        return "".join((system_section, user_section, assistant_section))

# ---------------------------------------------------------------------------
# Utility: trim text at the last occurrence of stop tokens
# ---------------------------------------------------------------------------

def _trim_text(txt: str) -> str:
    """Truncate the text after the last known stop token for cleaner outputs."""
    best_pos = -1
    best_tok = None
    for tok in STOP_TOKENS_TEXT:
        pos = txt.rfind(tok)
        if pos > best_pos:
            best_pos = pos
            best_tok = tok
    if best_pos != -1:
        return txt[: best_pos + len(best_tok)]
    return txt

# ---------------------------------------------------------------------------
# Utility: extract token-level reward scores from logits
# ---------------------------------------------------------------------------

def _step_rewards(logits: torch.Tensor, mask: torch.Tensor):
    """
    Compute step-wise probabilities using softmax over logits.
    Only consider positions where mask is non-zero (STEP_TOKEN positions).
    """
    probs = F.softmax(logits, dim=-1) * mask.unsqueeze(-1)
    arr: List[List[float]] = []
    for sample in probs:
        pos = sample[sample != 0].view(-1, 2)[:, 1]
        arr.append(pos.cpu().tolist())
    return arr

# ---------------------------------------------------------------------------
# Output container for model generation
# ---------------------------------------------------------------------------

@dataclass
class GenOutput:
    """Result of one generation call: a decoded text chunk and an EOS flag."""

    text: str  # decoded continuation; backends may trim it at a stop token
    ended_with_eos: bool  # True when the backend reports that EOS was produced

# ---------------------------------------------------------------------------
# Abstract base class for any generator (HF or vLLM)
# ---------------------------------------------------------------------------

class BaseGenerator:
    """Minimal interface shared by the generator backends.

    Concrete subclasses set ``name`` and override :meth:`generate`.
    """

    name: str  # identifier shown in logs (the concrete backends use the model path)

    def generate(self, prompt: str, **kw) -> GenOutput:
        """Produce a continuation of *prompt*; subclasses must override this."""
        raise NotImplementedError

# ---------------------------------------------------------------------------
# HuggingFace Transformers-based Generator
# ---------------------------------------------------------------------------

class HFGenerator(BaseGenerator):
    """Sampling generator backed by a HuggingFace causal language model.

    Args:
        path: Model name or local path passed to ``from_pretrained``.
        device: ``device_map`` used for loading. ``"auto"`` lets HF place the weights;
            any other value is also used as the device for the inputs.
        dtype: Weight dtype passed as ``torch_dtype``.
    """

    def __init__(self, path: str, *, device: str = "auto", dtype: torch.dtype = torch.bfloat16):
        self.tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
        self.model = AutoModelForCausalLM.from_pretrained(
            path,
            torch_dtype=dtype,
            device_map=device,
            trust_remote_code=True,
        ).eval()
        self.name = path
        # With device_map="auto" the weights may be spread over several devices;
        # inputs are sent to wherever the first parameter lives.
        self.device = next(self.model.parameters()).device if device == "auto" else torch.device(device)

        # Stop strings used when generate(..., use_stop_strings=True): the text stop
        # tokens plus the decoded EOS token, when the tokenizer defines one.
        self.stop_strings = list(STOP_TOKENS_TEXT)
        eos_id = self.tokenizer.eos_token_id
        if eos_id is not None:
            self.stop_strings.append(self.tokenizer.decode([eos_id], skip_special_tokens=False))

    @torch.inference_mode()
    def generate(
        self,
        prompt: str,
        *,
        max_tokens=64,
        temperature=0.95,
        top_p=0.7,
        use_stop_strings: bool = False,
    ) -> GenOutput:
        """Sample a continuation of *prompt*.

        Args:
            prompt: Raw prompt text (no chat template is applied).
            max_tokens: Maximum number of new tokens.
            temperature: Sampling temperature.
            top_p: Nucleus-sampling threshold.
            use_stop_strings: If True, generation halts at the first of
                ``self.stop_strings``. Otherwise it runs up to ``max_tokens`` or EOS,
                as before.

        Returns:
            A GenOutput whose ``text`` holds only newly generated text.
            - If EOS was generated, the text is cut just before the first EOS token
              and ``ended_with_eos`` is True.
            - Otherwise the text is trimmed after the last stop token by
              ``_trim_text``.
        """
        eos_id = self.tokenizer.eos_token_id
        ids = self.tokenizer(prompt, return_tensors="pt").to(self.device)
        prompt_len = ids["input_ids"].shape[1]

        cfg_kwargs = dict(
            do_sample=True,
            temperature=temperature,
            top_p=top_p,
            max_new_tokens=max_tokens,
            pad_token_id=eos_id,
        )
        if use_stop_strings:
            # stop_strings needs the tokenizer passed to generate(), which is done below.
            cfg_kwargs["stop_strings"] = self.stop_strings
        cfg = GenerationConfig(**cfg_kwargs)

        out = self.model.generate(**ids, generation_config=cfg, tokenizer=self.tokenizer)[0]
        # Inspect only the generated tokens: the prompt itself may contain the EOS id.
        new_tokens = out[prompt_len:]
        ended = False
        if eos_id is not None:
            eos_positions = (new_tokens == eos_id).nonzero(as_tuple=True)[0]
            if eos_positions.numel() > 0:
                ended = True
                # Drop EOS (and anything after it) so the marker never leaks into the
                # conversation or the final answer.
                new_tokens = new_tokens[: eos_positions[0].item()]
        txt = self.tokenizer.decode(new_tokens, skip_special_tokens=False)
        return GenOutput(txt if ended else _trim_text(txt), ended)

# ---------------------------------------------------------------------------
# vLLM-based Generator
# ---------------------------------------------------------------------------

class VLLMGenerator(BaseGenerator):
    """Sampling generator backed by a vLLM ``LLM`` engine.

    Raises:
        RuntimeError: If vLLM could not be imported.
    """

    def __init__(self, path: str):
        if not _VLLM_AVAILABLE:
            raise RuntimeError("vLLM is not installed.")
        self._llm = LLM(model=path)
        self._stop = list(STOP_TOKENS_TEXT)
        # Default parameters kept for reference. generate() builds fresh SamplingParams
        # per call rather than mutating this object's fields after construction.
        self._sp = SamplingParams(max_tokens=128, temperature=0.95, top_p=0.7, stop=list(self._stop))
        self.name = path

    @torch.inference_mode()
    def generate(self, prompt: str, *, max_tokens=30, temperature=0.95, top_p=0.7) -> GenOutput:
        """Sample one continuation of *prompt*, stopping at a stop string or EOS.

        Returns:
            A GenOutput whose ``text`` includes the matched stop string (as with the HF
            backend) and is trimmed by ``_trim_text``. ``ended_with_eos`` is True only
            when vLLM reports that generation ended on EOS.
        """
        params = SamplingParams(
            max_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            stop=list(self._stop),
            # Keep the terminating "." / "\n" in the text instead of silently dropping it.
            include_stop_str_in_output=True,
        )
        completion = self._llm.generate([prompt], params)[0].outputs[0]
        # finish_reason is "stop" both for EOS and for a matched stop string;
        # stop_reason is None only in the EOS case. The old text.endswith("") check
        # was always True.
        ended = completion.finish_reason == "stop" and completion.stop_reason is None
        return GenOutput(_trim_text(completion.text), ended)

# ---------------------------------------------------------------------------
# ModelPool: caches all loaded generators and reward models
# ---------------------------------------------------------------------------

class ModelPool:
    """Process-wide cache of loaded generators and reward models.

    The caches are class attributes, so every ``ModelPool()`` instance (and the class
    itself) shares the same loaded models.
    """

    # (engine, path) -> generator. The device is not part of the key: requesting an
    # already-loaded path on another device returns the existing instance.
    _gen_cache: Dict[Tuple[str, str], "BaseGenerator"] = {}
    # path -> reward model (previously mis-annotated as Dict[str, str]).
    _reward_cache: Dict[str, "PRMScorer"] = {}

    @classmethod
    def get_generator(cls, path: str, engine: str = "hf", device: Optional[str] = None) -> "BaseGenerator":
        """Return the cached generator for (*engine*, *path*), loading it on first use.

        Args:
            path: Model name or local path.
            engine: ``"hf"`` or ``"vllm"``.
            device: Device for the HF backend (e.g. ``"cuda:0"``); ``None`` means
                ``"auto"``. The vLLM backend ignores it.

        Raises:
            ValueError: If *engine* is unknown and nothing is cached under that key.
        """
        key = (engine, path)
        if key not in cls._gen_cache:
            if engine not in ("hf", "vllm"):
                raise ValueError(f"Unknown engine: {engine}")
            # Messages are pre-formatted (no %-args) so they render correctly with both
            # stdlib logging and loguru, which a later cell binds to the name `logger`.
            logger.info(f"[Pool] loading {path} ({engine})")
            resolved_device = device or "auto"
            logger.info(f"→ Assigned to device: {resolved_device}")
            if engine == "hf":
                cls._gen_cache[key] = HFGenerator(path, device=resolved_device)
            else:
                # vLLM manages placement via its own global config; `device` is not forwarded.
                cls._gen_cache[key] = VLLMGenerator(path)
        return cls._gen_cache[key]

    @classmethod
    def get_reward(cls, path: str, device: Optional[str] = None) -> "PRMScorer":
        """Return the cached reward model for *path*, loading it on first use.

        Args:
            path: Model name or local path.
            device: Device map for loading (e.g. ``"cuda:4"``); ``None`` means ``"auto"``.
        """
        if path not in cls._reward_cache:
            logger.info(f"[Pool] loading reward model {path}")
            resolved_device = device or "auto"
            logger.info(f"→ Reward model assigned to device: {resolved_device}")
            cls._reward_cache[path] = PRMScorer(path, device=resolved_device)
        return cls._reward_cache[path]


# ---------------------------------------------------------------------------
# PRMScorer: reward model used for evaluating step-level outputs
# ---------------------------------------------------------------------------

class PRMScorer:
    """Process reward model (PRM) wrapper that scores text at STEP_TOKEN positions.

    A score is the mean step probability returned by ``_step_rewards``, multiplied by
    10, so it lies in [0, 10]. It is 0.0 when no STEP_TOKEN is found.
    """

    def __init__(self, path: str, device: str = "auto"):
        self.tok = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
        if self.tok.pad_token is None:
            # score_batch_augmented pads the batch; padding=True raises without a pad token.
            self.tok.pad_token = self.tok.eos_token
        self.mod = AutoModel.from_pretrained(
            path,
            torch_dtype=torch.bfloat16,
            device_map=device,
            trust_remote_code=True,
        ).eval()
        # add_special_tokens=False: if the tokenizer prepended a BOS token, index 0
        # would be BOS rather than the step separator.
        self.sep_id = self.tok.encode(STEP_TOKEN, add_special_tokens=False)[0]

    @staticmethod
    def _aggregate(step_probs: List[float]) -> float:
        """Return the mean step probability times 10, or 0.0 for an empty list."""
        return float(sum(step_probs) / len(step_probs) * 10.0) if step_probs else 0.0

    @torch.inference_mode()
    def score(self, question: str, answer: str) -> float:
        """Score *answer* to *question* using the tokenizer's chat template.

        STEP_TOKEN is appended to *answer* if missing, so at least the final step is scored.
        """
        if not answer.endswith(STEP_TOKEN):
            answer += STEP_TOKEN
        msgs = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": question},
            {"role": "assistant", "content": answer},
        ]
        convo = self.tok.apply_chat_template(msgs, tokenize=False, add_generation_prompt=False)
        # Inputs must live on the model's device (previously they stayed on the CPU).
        ids = self.tok(convo, return_tensors="pt").input_ids.to(self.mod.device)
        mask = ids == self.sep_id
        probs = _step_rewards(self.mod(ids).logits, mask)[0]
        return self._aggregate(probs)

    @torch.inference_mode()
    def score_batch_augmented(self, prompt: str, completions: List[str]) -> List[float]:
        """Score each completion in one batched forward pass.

        Each input is ``prompt + STEP_TOKEN + completion + STEP_TOKEN`` (no chat template).
        Returns one score per completion, in order; ``[]`` for an empty list.

        NOTE(review): assuming the PRM is causal, the reward at the first STEP_TOKEN
        depends only on the shared prompt. It then adds roughly the same offset to
        every candidate: this affects threshold comparisons but not the ranking.
        With ``truncation=True``, an over-long input may lose its trailing STEP_TOKEN.
        """
        if not completions:
            return []
        inputs = [prompt + STEP_TOKEN + c + STEP_TOKEN for c in completions]
        enc = self.tok(inputs, return_tensors="pt", padding=True, truncation=True).to(self.mod.device)
        mask = enc["input_ids"] == self.sep_id
        logits = self.mod(**enc).logits
        return [self._aggregate(p) for p in _step_rewards(logits, mask)]

    
    
# ---------------------------------------------------------------------------
# EnsembleReasoner: multi-model decoding loop with step-wise reward scoring
# ---------------------------------------------------------------------------

class EnsembleReasoner:
    """Step-wise ensemble decoding guided by a process reward model.

    In each round, every generator whose tokenizer can fit the current prompt proposes
    a continuation. The reward model scores all candidates, and the best one is
    appended to the conversation.

    Args:
        generators: Candidate generators, queried in parallel each round.
        scorer: Reward model that ranks candidates via ``score_batch_augmented``.
        max_rounds: Upper bound on decoding rounds, counting accepted and rejected ones.
        score_threshold: A round whose best score is below this value is rejected.
            Nothing is appended, and the next round resamples from the same prompt.
        accumulate_context: Stored, but not currently read by ``__call__``.
    """

    def __init__(self, generators: List[BaseGenerator], scorer: PRMScorer, max_rounds: int = 500,
                 score_threshold: float = 0.5, accumulate_context: bool = True):
        self.generators = generators
        self.scorer = scorer
        self.max_rounds = max_rounds
        self.score_threshold = score_threshold
        self.accumulate_context = accumulate_context

    def _fits_context(self, gen: BaseGenerator, prompt: str) -> bool:
        """Return False (and log) when *prompt* exceeds *gen*'s tokenizer max length.

        Generators without a ``tokenizer`` attribute (e.g. the vLLM backend) are always
        treated as eligible.
        """
        tok = getattr(gen, "tokenizer", None)
        if tok is None:
            return True
        length = tok(prompt, return_tensors="pt").input_ids.size(1)
        if length > tok.model_max_length:
            logger.info(f"Skip {gen.name}: prompt length {length} > max {tok.model_max_length}")
            return False
        return True

    def __call__(self, question: str) -> str:
        """Decode an answer to *question* and return the accepted chunks joined by newlines.

        Decoding stops when any of these happens:
        - ``max_rounds`` is reached;
        - no generator can fit the prompt;
        - the accepted candidate reports EOS.
        Rejected rounds (best score below ``score_threshold``) append nothing and resample.
        """
        from concurrent.futures import ThreadPoolExecutor

        convo = ConversationTemplate(SYSTEM_PROMPT, question)

        for _ in range(self.max_rounds):
            prompt = convo.render()

            available_gens = [g for g in self.generators if self._fits_context(g, prompt)]
            if not available_gens:
                logger.error("No generators available for current prompt length; stopping early.")
                break

            # Query all generators concurrently; each may sit on its own device.
            with ThreadPoolExecutor(max_workers=len(available_gens)) as executor:
                outs = list(executor.map(lambda g: g.generate(prompt), available_gens))

            completions = [o.text for o in outs]
            # The scorer sees prompt + STEP_TOKEN + candidate + STEP_TOKEN for each candidate.
            scores = self.scorer.score_batch_augmented(prompt, completions)

            for g, text, s in zip(available_gens, completions, scores):
                # Built outside the f-string: a backslash inside an f-string expression
                # is a SyntaxError before Python 3.12. The message is pre-formatted so
                # it also renders correctly when `logger` is loguru.
                shown = text.replace("\n", "\\n")
                logger.info(f"→ {g.name} | {s:.2f} | {shown}")

            # On ties the first maximum wins, as with torch.argmax.
            best_idx = max(range(len(scores)), key=scores.__getitem__)
            best_out = outs[best_idx]
            best_score = scores[best_idx]

            if best_score < self.score_threshold:
                # Rejected: nothing is appended, so the next round resamples the same prompt.
                logger.info(
                    f"Reject: best score {best_score:.2f} < threshold {self.score_threshold:.2f}; resampling"
                )
                continue

            convo.add_assistant(best_out.text)

            if best_out.ended_with_eos:
                logger.info("Early stop: EOS token emitted")
                break

        # Return the final composed assistant response
        return "\n".join(convo.assistant_parts)

### 第一版 直接调用 (Version 1: direct call)

In [None]:
# Version 1 setup: two small generators on separate GPUs, plus a process reward model.
model_specs = [
    {"path": "Qwen/Qwen2.5-Math-1.5B-Instruct", "engine": "hf", "device": "cuda:0"},
    {"path": "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B", "engine": "hf", "device": "cuda:1"},
    # Larger variants, currently disabled:
    # {"path": "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B", "engine": "hf", "device": "cuda:2"},
    # {"path": "deepseek-ai/DeepSeek-R1-Distill-Qwen-14B", "engine": "hf", "device": "cuda:3"},
]
reward_spec = {"path": "Qwen/Qwen2.5-Math-PRM-7B", "device": "cuda:4"}

model_pool = ModelPool()
gens = [
    model_pool.get_generator(s["path"], s.get("engine", "hf"), s.get("device"))
    for s in model_specs
]
scorer = model_pool.get_reward(reward_spec["path"], device=reward_spec["device"])

In [None]:
text = r"The expression $2\cdot 3 \cdot 4\cdot 5+1$ is equal to 121, since multiplication is carried out before addition. However, we can obtain values other than 121 for this expression if we are allowed to change it by inserting parentheses. For example, we can obtain 144 by writing \[ (2\cdot (3\cdot 4)) \cdot (5+1) = 144. \]In total, how many values can be obtained from the expression $2\cdot 3\cdot 4 \cdot 5 + 1$ by inserting parentheses? (Note that rearranging terms is not allowed, only inserting parentheses)."

# Ensemble decoding over the generators loaded above.
reasoner = EnsembleReasoner(
    generators=gens,
    scorer=scorer,
    max_rounds=100,
    score_threshold=2.0,
    accumulate_context=True,
)

response = reasoner(text)
print(response)

### 第二版 封装多模型调用 (Version 2: multi-model call wrapped in a function)

In [None]:
import math
from typing import List, Dict

def _prompt_ppl_and_confidence(model, tokenizer, device, prompt: str) -> Tuple[float, float]:
    """Return (perplexity, confidence) of *model* on *prompt* under teacher forcing.

    - Perplexity is exp of the mean next-token negative log-likelihood.
    - Confidence is the mean, over positions, of the maximum predicted probability.
    - Positions whose target is the pad token are excluded. When the tokenizer has no
      pad token, all positions are used.
    - Returns (nan, nan) when no position is left to score, e.g. a one-token prompt.
    """
    with torch.inference_mode():
        inputs = tokenizer(prompt, return_tensors="pt").to(device)
        outputs = model(**inputs, output_attentions=False, output_hidden_states=False)
        # float32: log-softmax over the full vocabulary is imprecise in bf16.
        logits = outputs.logits[:, :-1, :].float()
        labels = inputs["input_ids"][:, 1:]

        log_probs = F.log_softmax(logits, dim=-1)
        token_log_probs = log_probs.gather(2, labels.unsqueeze(-1)).squeeze(-1)

        pad_id = tokenizer.pad_token_id
        # Comparing a tensor against None does not give an element-wise mask, so a
        # missing pad token is handled explicitly.
        mask = torch.ones_like(labels, dtype=torch.bool) if pad_id is None else labels != pad_id
        if not bool(mask.any()):
            return float("nan"), float("nan")

        avg_nll = -token_log_probs[mask].mean().item()
        # max softmax probability == exp(max log-probability); avoids a second full softmax.
        confidence = log_probs.max(dim=-1).values[mask].exp().mean().item()
    return math.exp(avg_nll), confidence


def run_selective_ensemble(
    question: str,
    model_specs: Optional[List[Dict]] = None,
    reward_spec: Optional[Dict] = None,
    max_rounds: int = 500,
    score_threshold: float = 0.5
) -> str:
    """
    Select the top-2 generators by a combined score (low perplexity, high confidence)
    on the rendered prompt, then run ensemble reasoning with them.

    Args:
        question (str): The input question to answer.
        model_specs (List[dict]): Each model needs 'path'; 'engine' and 'device' are optional.
        reward_spec (dict): Dict with key 'path' and optional 'device'.
        max_rounds (int): Maximum reasoning rounds for EnsembleReasoner.
        score_threshold (float): Rounds whose best reward is below this are rejected.

    Returns:
        str: Final answer from EnsembleReasoner using the top-2 models. Returns "" when
        the reward model fails to load or fewer than two generators could be scored.
    """
    if model_specs is None:
        model_specs = [
            {"path": "Qwen/Qwen2.5-Math-1.5B-Instruct", "engine": "hf", "device": "cuda:0"},
            {"path": "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B", "engine": "hf", "device": "cuda:1"},
            {"path": "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B", "engine": "hf", "device": "cuda:2"},
            {"path": "deepseek-ai/DeepSeek-R1-Distill-Qwen-14B", "engine": "hf", "device": "cuda:3"},
        ]

    if reward_spec is None:
        reward_spec = {"path": "Qwen/Qwen2.5-Math-PRM-7B", "device": "cuda:4"}

    model_pool = ModelPool()

    # Load generators; a model that fails to load is logged and skipped.
    generators = []
    for spec in model_specs:
        try:
            gen = model_pool.get_generator(
                path=spec["path"],
                engine=spec.get("engine", "hf"),
                device=spec.get("device")
            )
        except Exception as e:
            logger.warning(f"Failed to load model {spec['path']}: {e}")
            continue
        generators.append(gen)

    # Load reward model
    try:
        scorer = model_pool.get_reward(
            path=reward_spec["path"],
            device=reward_spec.get("device")
        )
    except Exception as e:
        logger.error(f"Failed to load reward model: {e}")
        return ""

    # The prompt is identical for every generator, so render it once.
    prompt = ConversationTemplate(SYSTEM_PROMPT, question).render()

    # Score each generator using PPL + confidence
    scores = []
    for gen in generators:
        tokenizer = getattr(gen, "tokenizer", None)
        model = getattr(gen, "model", None)
        if tokenizer is None or model is None:
            logger.warning(f"Skipping {gen.name}: missing tokenizer or model")
            continue

        try:
            perplexity, confidence = _prompt_ppl_and_confidence(model, tokenizer, gen.device, prompt)
        except Exception as e:
            logger.warning(f"Error scoring model {gen.name}: {e}")
            continue

        if not (math.isfinite(perplexity) and math.isfinite(confidence)):
            # NaN would make the ranking below meaningless.
            logger.warning(f"Skipping {gen.name}: no tokens to score in prompt")
            continue

        combined_score = -perplexity + confidence
        scores.append((combined_score, gen))
        logger.info(f"{gen.name} | PPL: {perplexity:.2f} | Confidence: {confidence:.2f} | Score: {combined_score:.2f}")

    if len(scores) < 2:
        logger.error("Not enough valid models to run ensemble reasoning.")
        return ""

    # Select top-2 generators
    top_gens = sorted(scores, key=lambda x: x[0], reverse=True)[:2]
    selected_generators = [item[1] for item in top_gens]
    logger.info(f"Selected models: {[g.name for g in selected_generators]}")

    # Run ensemble reasoning
    reasoner = EnsembleReasoner(
        generators=selected_generators,
        scorer=scorer,
        max_rounds=max_rounds,
        score_threshold=score_threshold
    )
    return reasoner(question)


In [None]:
text = r"The expression $2\cdot 3 \cdot 4\cdot 5+1$ is equal to 121, since multiplication is carried out before addition. However, we can obtain values other than 121 for this expression if we are allowed to change it by inserting parentheses. For example, we can obtain 144 by writing \[ (2\cdot (3\cdot 4)) \cdot (5+1) = 144. \]In total, how many values can be obtained from the expression $2\cdot 3\cdot 4 \cdot 5 + 1$ by inserting parentheses? (Note that rearranging terms is not allowed, only inserting parentheses)."

# Uses the default model/reward specs defined inside run_selective_ensemble.
answer = run_selective_ensemble(question=text)
print("Final Answer:\n", answer)


### 第三版 完整ensemble调用 (Version 3: full ensemble pipeline)

In [None]:
import math
from typing import List, Dict, Callable
import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer
from loguru import logger

# Assumes SYSTEM_PROMPT and ConversationTemplate already defined

class ModelStatStore:
    """In-memory cache of per-model reference statistics, keyed by model path."""

    def __init__(self):
        self._stats: Dict[str, Dict[str, float]] = dict()

    def has(self, model_path: str) -> bool:
        """Return True when statistics for *model_path* are cached."""
        return model_path in self._stats

    def get(self, model_path: str) -> Dict[str, float]:
        """Return cached statistics for *model_path*; raises KeyError if absent."""
        return self._stats[model_path]

    def set(self, model_path: str, stats: Dict[str, float]):
        """Store (or overwrite) the statistics for *model_path*."""
        self._stats[model_path] = stats

    def maybe_compute(self, model_path: str, model, tokenizer, device, dataset: List[str]):
        """Return stats for *model_path*, computing and caching them on first request."""
        if self.has(model_path):
            return self.get(model_path)
        computed = compute_model_stats_on_dataset(model, tokenizer, device, dataset)
        self.set(model_path, computed)
        return self.get(model_path)

def compute_model_stats_on_dataset(model, tokenizer, device, dataset: List[str]) -> Dict[str, float]:
    """Compute reference perplexity/confidence statistics of *model* over *dataset*.

    Each problem is scored as raw text (no prompt template) via
    ``score_question_for_model``. Problems that yield a non-finite score (e.g. nothing
    left to score) are skipped.

    NOTE(review): select_top_models_by_z_score scores the ConversationTemplate-rendered
    prompt, while these reference stats come from raw problem text. The z-scores
    therefore compare different input formats; confirm this is intended.

    Returns:
        Dict with ``ppl_mean``, ``ppl_std``, ``conf_mean``, ``conf_std``. The std values
        are sample standard deviations (torch default, n-1).

    Raises:
        ValueError: If fewer than two problems could be scored. The std would then be
            undefined, and NaN stats would silently break the z-score model selection.
    """
    all_ppls: List[float] = []
    all_confs: List[float] = []
    for problem in dataset:
        score = score_question_for_model(problem, model, tokenizer, device, lambda q: q)
        if math.isfinite(score["ppl"]) and math.isfinite(score["conf"]):
            all_ppls.append(score["ppl"])
            all_confs.append(score["conf"])

    if len(all_ppls) < 2:
        raise ValueError(
            f"Need at least 2 scorable problems to compute reference statistics, got {len(all_ppls)}."
        )

    # float64 so sums over hundreds of problems stay precise.
    ppls = torch.tensor(all_ppls, dtype=torch.float64)
    confs = torch.tensor(all_confs, dtype=torch.float64)
    return {
        "ppl_mean": float(ppls.mean()),
        "ppl_std": float(ppls.std()),
        "conf_mean": float(confs.mean()),
        "conf_std": float(confs.std()),
    }

def score_question_for_model(question: str, model, tokenizer, device: str, prompt_builder: Callable) -> Dict[str, float]:
    """Score how well *model* predicts the prompt built from *question*.

    Args:
        question: Raw question text.
        model: Causal LM whose forward output exposes ``.logits``.
        tokenizer: Matching tokenizer.
        device: Device the inputs are moved to (should be where the model expects inputs).
        prompt_builder: Maps *question* to the text that is actually scored.

    Returns:
        ``{"ppl": perplexity, "conf": mean max-probability}`` over next-token
        predictions, excluding positions whose target is the pad token. Both values
        are NaN when no position is left to score (e.g. a one-token prompt).
    """
    prompt = prompt_builder(question)
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    with torch.inference_mode():
        outputs = model(**inputs)
        # float32: log-softmax over the full vocabulary is imprecise in bf16.
        logits = outputs.logits[:, :-1, :].float()
        labels = inputs["input_ids"][:, 1:]
        log_probs = F.log_softmax(logits, dim=-1)
        token_log_probs = log_probs.gather(2, labels.unsqueeze(-1)).squeeze(-1)

        pad_id = tokenizer.pad_token_id
        # Comparing a tensor against None does not give an element-wise mask.
        mask = torch.ones_like(labels, dtype=torch.bool) if pad_id is None else labels != pad_id
        if not bool(mask.any()):
            return {"ppl": float("nan"), "conf": float("nan")}

        avg_nll = -token_log_probs[mask].mean().item()
        ppl = math.exp(avg_nll)
        # max softmax probability == exp(max log-probability); avoids a second full softmax.
        conf = log_probs.max(dim=-1).values[mask].exp().mean().item()

    return {"ppl": ppl, "conf": conf}

def determine_model_count(
    question_scores: List[Dict[str, float]],
    model_stats: Dict[str, Dict[str, float]],
    *,
    ppl_margin: float = 2.0,
    hard_fraction: float = 0.90,
    hard_count: int = 3,
    default_count: int = 2,
) -> int:
    """Decide how many generators to ensemble for the current question.

    A model "finds the question hard" when its prompt perplexity exceeds its reference
    mean by more than ``ppl_margin``. This is an absolute PPL offset, not a multiple of
    the std. If at least ``hard_fraction`` of the models find it hard, ``hard_count``
    models are used; otherwise ``default_count``.

    ``question_scores[i]`` is paired positionally with the i-th entry of
    ``model_stats`` (dict insertion order), so both must be built from the same ordered
    list of model specs.
    """
    hard = sum(
        1
        for score, stats in zip(question_scores, model_stats.values())
        if score["ppl"] > stats["ppl_mean"] + ppl_margin
    )
    return hard_count if hard >= len(question_scores) * hard_fraction else default_count

def select_top_models_by_z_score(question: str, model_specs: List[Dict], prompt_builder, model_stats: Dict[str, Dict[str, float]], model_pool, model_count: int = -1) -> List[Dict]:
    """Rank *model_specs* for *question* by z-scored perplexity and confidence.

    For each model:
        total = (ppl_mean - ppl) / ppl_std + (conf - conf_mean) / conf_std
    Higher is better: lower-than-usual perplexity and higher-than-usual confidence.

    Args:
        question: Question to score.
        model_specs: Specs with 'path' and optional 'engine' / 'device'.
        prompt_builder: Maps the question to the text actually scored.
        model_stats: Reference stats per model path (see compute_model_stats_on_dataset).
        model_pool: Object providing ``get_generator(path, engine, device)``.
        model_count: Number of specs to return; -1 lets determine_model_count decide.

    Returns:
        The best-ranked specs, best first.
    """

    def _z(delta: float, std: float) -> float:
        # A zero or non-finite std cannot standardize; contribute 0 rather than
        # raising ZeroDivisionError or injecting NaN into the ranking.
        if not (math.isfinite(std) and std > 0):
            return 0.0
        return delta / std

    results = []
    question_scores = []
    for spec in model_specs:
        generator = model_pool.get_generator(spec["path"], spec.get("engine", "hf"), spec.get("device"))
        # Use the device the model actually lives on. spec["device"] may be missing,
        # or may differ from a cached generator's real placement.
        score = score_question_for_model(question, generator.model, generator.tokenizer, generator.device, prompt_builder)
        stats = model_stats[spec["path"]]
        total_score = (
            _z(stats["ppl_mean"] - score["ppl"], stats["ppl_std"])
            + _z(score["conf"] - stats["conf_mean"], stats["conf_std"])
        )
        if not math.isfinite(total_score):
            # e.g. NaN perplexity for an unscorable prompt: rank last instead of
            # leaving the sort order undefined.
            total_score = float("-inf")
        results.append((total_score, spec))
        question_scores.append(score)

    if model_count == -1:
        model_count = determine_model_count(question_scores, model_stats)

    results.sort(key=lambda item: item[0], reverse=True)
    return [spec for _, spec in results[:model_count]]

def run_zscore_ensemble(
    question: str,
    dataset_problems: List[str],
    model_specs: List[Dict],
    reward_spec: Dict,
    stat_store: ModelStatStore,
    max_rounds: int = 500,
    score_threshold: float = 0.5
) -> str:
    """Pick generators for *question* by z-scored PPL/confidence, then run the ensemble.

    Stages:
        1. Compute (or reuse from *stat_store*) reference stats of every model on
           *dataset_problems*.
        2. Rank the models on the rendered prompt for *question* and pick how many to use.
        3. Load the selected generators and the reward model.
        4. Run EnsembleReasoner.

    Args:
        reward_spec: Dict with 'path' and optional 'device', consistent with
            run_selective_ensemble.

    NOTE(review): stages 1-2 read ``generator.model`` / ``.tokenizer``, which only the
    HF backend provides; vLLM specs would fail here.
    """
    logger.info("[Stage 1] Computing or retrieving reference statistics for all models...")
    model_pool = ModelPool()
    model_stats: Dict[str, Dict[str, float]] = {}
    for spec in model_specs:
        model_path = spec["path"]
        generator = model_pool.get_generator(model_path, spec.get("engine", "hf"), spec.get("device"))
        stats = stat_store.maybe_compute(model_path, generator.model, generator.tokenizer, generator.device, dataset_problems)
        model_stats[model_path] = stats
        logger.info(
            f"→ Stats for {model_path}: "
            f"PPL µ={stats['ppl_mean']:.2f}, σ={stats['ppl_std']:.2f} | "
            f"Conf µ={stats['conf_mean']:.2f}, σ={stats['conf_std']:.2f}"
        )

    logger.info("[Stage 2] Selecting top models based on z-score (auto model count)...")

    def prompt_builder(q: str) -> str:
        return ConversationTemplate(SYSTEM_PROMPT, q).render()

    selected_specs = select_top_models_by_z_score(
        question=question,
        model_specs=model_specs,
        prompt_builder=prompt_builder,
        model_stats=model_stats,
        model_pool=model_pool,
        model_count=-1
    )
    logger.info(f"✅ Selected models: {[s['path'] for s in selected_specs]}")

    logger.info("[Stage 3] Loading selected generators and reward model...")
    generators = [
        model_pool.get_generator(spec["path"], spec.get("engine", "hf"), spec.get("device"))
        for spec in selected_specs
    ]
    # .get: 'device' is optional, as in run_selective_ensemble (None -> "auto").
    scorer = model_pool.get_reward(reward_spec["path"], device=reward_spec.get("device"))

    logger.info("[Stage 4] Running ensemble reasoner...")
    reasoner = EnsembleReasoner(
        generators=generators,
        scorer=scorer,
        max_rounds=max_rounds,
        score_threshold=score_threshold
    )
    return reasoner(question)

In [None]:
from datasets import load_dataset

# Cache for per-model reference statistics (filled lazily via stat_store.maybe_compute).
stat_store = ModelStatStore()

# Load the MATH-500 test split and keep only the problem statements.
math_dataset = load_dataset("/mnt/data/zichuanfu/.cache/huggingface/hub/datasets--HuggingFaceH4--MATH-500/snapshots/ff5b20257d8185524591543f8ff5993951537bb8", split="test")
math_problems = [record["problem"] for record in math_dataset]


In [None]:
# Hub model IDs, kept for reference; the local snapshot paths below are what gets loaded.
# model_specs = [
#     {"path": "Qwen/Qwen2.5-Math-1.5B-Instruct", "engine": "hf", "device": "cuda:0"},
#     {"path": "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B", "engine": "hf", "device": "cuda:1"},
#     {"path": "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B", "engine": "hf", "device": "cuda:2"},
#     {"path": "deepseek-ai/DeepSeek-R1-Distill-Qwen-14B", "engine": "hf", "device": "cuda:3"},
# ]

# Candidate generators as local HF snapshot paths, one GPU per model.
model_specs = [
    {"path": "/mnt/data/zichuanfu/.cache/huggingface/hub/models--Qwen--Qwen2.5-Math-1.5B-Instruct/snapshots/aafeb0fc6f22cbf0eaeed126eff8be45b0360a35", "engine": "hf", "device": "cuda:0"},
    {"path": "/mnt/data/zichuanfu/.cache/huggingface/hub/models--deepseek-ai--DeepSeek-R1-Distill-Qwen-1.5B/snapshots/ad9f0ae0864d7fbcd1cd905e3c6c5b069cc8b562", "engine": "hf", "device": "cuda:1"},
    {"path": "/mnt/data/zichuanfu/.cache/huggingface/hub/models--Qwen--Qwen2.5-Math-7B-Instruct/snapshots/ef9926d75ab1d54532f6a30dd5e760355eb9aa4d", "engine": "hf", "device": "cuda:2"},
    {"path": "/mnt/data/zichuanfu/.cache/huggingface/hub/models--deepseek-ai--DeepSeek-R1-Distill-Qwen-7B/snapshots/916b56a44061fd5cd7d6a8fb632557ed4f724f60", "engine": "hf", "device": "cuda:3"},
    {"path": "/mnt/data/zichuanfu/.cache/huggingface/hub/models--deepseek-ai--DeepSeek-R1-Distill-Qwen-14B/snapshots/1df8507178afcc1bef68cd8c393f61a886323761", "engine": "hf", "device": "cuda:4"},
]

# Process reward model (local snapshot) used to score candidate steps.
reward_spec = {"path": "/mnt/data/zichuanfu/.cache/huggingface/hub/models--Qwen--Qwen2.5-Math-PRM-7B/snapshots/0610740060112df12585d00a1c5f4624d2f59051", "device": "cuda:5"}


In [None]:
# Smoke test of the full pipeline on a trivial question.
final_answer = run_zscore_ensemble(
    stat_store=stat_store,
    question="If x^2 = 49, what is the positive value of x?",
    model_specs=model_specs,
    reward_spec=reward_spec,
    dataset_problems=math_problems,
)

print("Final Answer:\n", final_answer)


In [None]:
text = r"The expression $2\cdot 3 \cdot 4\cdot 5+1$ is equal to 121, since multiplication is carried out before addition. However, we can obtain values other than 121 for this expression if we are allowed to change it by inserting parentheses. For example, we can obtain 144 by writing \[ (2\cdot (3\cdot 4)) \cdot (5+1) = 144. \]In total, how many values can be obtained from the expression $2\cdot 3\cdot 4 \cdot 5 + 1$ by inserting parentheses? (Note that rearranging terms is not allowed, only inserting parentheses)."

# Reuses stat_store, so reference stats computed by the previous cell are not recomputed.
final_answer = run_zscore_ensemble(
    stat_store=stat_store,
    question=text,
    model_specs=model_specs,
    reward_spec=reward_spec,
    dataset_problems=math_problems,
)
print("Final Answer:\n", final_answer)


In [None]:
import json
from pathlib import Path
from tqdm import tqdm

# Assumes the following names were already defined by earlier cells:
# from your_module import run_zscore_ensemble, ModelStatStore, model_specs, reward_spec

def load_json_dataset(input_path: str) -> list:
    """Load a dataset stored as a single JSON document (UTF-8), e.g. a list of records.

    Named ``load_json_dataset`` rather than ``load_dataset`` so that it does not shadow
    ``datasets.load_dataset``, which an earlier cell imports and calls.
    """
    with open(input_path, "r", encoding="utf-8") as f:
        return json.load(f)


def save_predictions(predictions: list, output_path: str):
    """Write *predictions* to *output_path* as JSON Lines: one object per line, UTF-8,
    non-ASCII characters kept as-is."""
    with open(output_path, "w", encoding="utf-8") as f:
        f.writelines(json.dumps(item, ensure_ascii=False) + "\n" for item in predictions)


def run_batch_inference(
    input_path: str,
    output_path: str,
    model_specs: list,
    reward_spec: dict,
    math_problem_stats: list,
    max_examples: Optional[int] = None,
    stat_store: Optional["ModelStatStore"] = None,
):
    """Run run_zscore_ensemble over a JSON dataset and save the predictions as JSONL.

    Args:
        input_path: JSON file with records holding 'instruction', 'input' and 'output'.
        output_path: Destination JSONL file with 'prompt', 'predict' and 'label' fields.
        model_specs: Generator specs forwarded to run_zscore_ensemble.
        reward_spec: Reward-model spec forwarded to run_zscore_ensemble.
        math_problem_stats: Problems used to compute each model's reference statistics.
        max_examples: If given, process only the first ``max_examples`` records
            (0 processes none). ``None`` processes all of them.
        stat_store: Optional pre-populated statistics cache, so reference stats are not
            recomputed on every call. A fresh store is created when omitted.
    """
    dataset = load_json_dataset(input_path)
    if stat_store is None:
        stat_store = ModelStatStore()

    # `is not None`, so that max_examples=0 no longer means "everything".
    examples = dataset if max_examples is None else dataset[:max_examples]

    predictions = []
    for example in tqdm(examples):
        instruction = example["instruction"].strip()
        question = example["input"].strip()
        answer = example["output"].strip()

        # Prompt string recorded next to each prediction. The ensemble itself receives
        # only `question`.
        prompt = f"\n{instruction}\n{question}\nassistant\n"

        # A failure on one example is reported and recorded as an empty prediction.
        try:
            result = run_zscore_ensemble(
                question=question,
                dataset_problems=math_problem_stats,
                model_specs=model_specs,
                reward_spec=reward_spec,
                stat_store=stat_store
            )
        except Exception as e:
            print(f"⚠️ Error on question: {question[:80]}... -> {e}")
            result = ""

        predictions.append({
            "prompt": prompt,
            "predict": result.strip(),
            "label": answer,
        })

    save_predictions(predictions, output_path)
    print(f"✅ Saved {len(predictions)} predictions to {output_path}")


# Example invocation:
run_batch_inference(
    input_path="/mnt/data/zichuanfu/LLaMA-Factory/data/hendrycks_math/train.json",
    output_path="deepseek-r1-1.5b-generated-predictions.jsonl",
    model_specs=model_specs,
    reward_spec=reward_spec,
    math_problem_stats=math_problems,  # MATH-500 problems loaded in an earlier cell
    max_examples=100,  # optional cap on the number of examples processed
)
