# Change baseline

In [2]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

# Hugging Face Hub id of the checkpoint loaded by this cell.
model_name = "Qwen/QwQ-32B"

# bitsandbytes settings: keep the weights in 4-bit NF4 with double quantisation
# and run the matmuls in bfloat16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,            # enable 4-bit (QLoRA-style) weights
    bnb_4bit_quant_type="nf4",    # NormalFloat4 quantisation data type
    bnb_4bit_use_double_quant=True, # optional: second quantisation pass to save ~0.4 bits/param
    bnb_4bit_compute_dtype=torch.bfloat16  # dtype used for the matmuls; fall back to float16 if bfloat16 is unsupported
)

# Downloads (on first use) and loads the quantised model.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",      # let Accelerate split layers across all visible GPUs
    quantization_config=bnb_config,
    torch_dtype="auto",     # take the dtype of the non-quantised modules from the checkpoint
    trust_remote_code=True  # NOTE(review): permits running code shipped with the checkpoint — confirm this model still needs it
)

# Slow (pure-Python) tokenizer for the same checkpoint.
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)

Fetching 14 files: 100%|██████████| 14/14 [30:17<00:00, 129.84s/it]  
Loading checkpoint shards: 100%|██████████| 14/14 [01:23<00:00,  5.93s/it]


In [3]:
# Single-turn smoke test: ask one question through the model's chat template.
prompt = "How many r's are in the word \"strawberry\""
messages = [
    {"role": "user", "content": prompt}
]
# Render the conversation to a plain string (tokenize=False) that ends with the
# assistant generation prompt.
text = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True
)

# Batch of one prompt, moved to the device reported by the model.
model_inputs = tokenizer([text], return_tensors="pt").to(model.device)

# NOTE(review): the output recorded for this cell shows `generate` aborting with
# a CUDA device-side assert ("probability tensor contains either `inf`, `nan` or
# element < 0"), i.e. the sampler received invalid probabilities. The cause is
# not visible in this cell — TODO investigate before relying on this baseline.
generated_ids = model.generate(
    **model_inputs,
    max_new_tokens=32768  # upper bound on the number of generated tokens
)
# `generate` returns prompt + continuation; drop the prompt tokens of each row.
generated_ids = [
    output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
]

# Decode the only sequence in the batch.
response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
print(response)

/pytorch/aten/src/ATen/native/cuda/TensorCompare.cu:112: _assert_async_cuda_kernel: block: [0,0,0], thread: [0,0,0] Assertion `probability tensor contains either `inf`, `nan` or element < 0` failed.


RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


# Using vllm

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig, PreTrainedTokenizer
from collections import defaultdict, deque
import math
import logging
import json
from typing import Dict, List, Tuple, Optional, Union
from dataclasses import dataclass
import numpy as np
from torch.utils.data import Dataset, DataLoader
import concurrent.futures
from tqdm import tqdm, trange
import wandb
import os
from pathlib import Path
from collections import deque
from torch.utils.data import Dataset
from typing import List, Tuple, Dict, Any, Optional
from collections import Counter
import random
import string
import re

# GPU selection for this kernel.
# The CUDA runtime reads these variables once, when CUDA is initialised, and
# ignores later changes. If an earlier cell already touched the GPU, warn
# instead of silently running on a different device than requested.
import warnings

if torch.cuda.is_initialized():
    warnings.warn(
        "CUDA is already initialised in this kernel; CUDA_DEVICE_ORDER / "
        "CUDA_VISIBLE_DEVICES set below will not take effect. Restart the "
        "kernel and run this cell before any other GPU work.",
        RuntimeWarning,
    )

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"  # enumerate GPUs in PCI bus order
os.environ["CUDA_VISIBLE_DEVICES"] = "2"        # expose only physical GPU 2 to this process
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
@dataclass
class PRMConfig:
    """Configuration class for PRM hyperparameters and settings.

    Declared as a dataclass so individual values can be overridden at
    construction time, e.g. ``PRMConfig(num_rollouts=4)``. ``PRMConfig()``
    and class-level access such as ``PRMConfig.model_name`` keep returning
    the defaults listed below.
    """
    # MC config
    model_name: str = "Qwen/Qwen2.5-Math-7B"
    max_new_tokens: int = 384          # generation cap passed to every rollout
    num_rollouts: int = 8              # rollouts sampled per step prefix
    reward_threshold: float = 0.5
    samples_per_question: int = 1      # solutions sampled per question in MCReward.build_datasets
    use_llm: bool = True               # True: mask steps with the LLM; False: mask with the regex
    use_contri: bool = False
    # PRM Model config
    hidden_size: int = 512
    num_layers: int = 3
    dropout: float = 0.2
    # PRMTrainer config
    batch_size: int = 12
    learning_rate: float = 5e-4
    num_workers: int = 4
    weight_decay: float = 1e-2
    lr_scheduler: str = "cosine"
    dataset_size: int = 0
    warmup_steps: int = 22
    grad_clip: float = 1.0
    epochs: int = 25
    # Misc config
    use_wandb: bool = True
    wandb_project: str = "mc_prm"
    run_name: str = "test_gsm8k_100_ori_mse"
    checkpoint_dir: str = "./checkpoints/gsm8k/ori_mse"
    seed: int = 42                     # shuffle seed used by MCReward.build_datasets_gsm8k

In [None]:
################################################################################
#                        UTILITY: ANSWER NORMALISATION                         #
################################################################################
import sympy as sp

def _strip_markup(ans: str) -> str:
    """Remove common LaTeX/markup & variable tags."""
    # # Remove LaTeX inline math wrappers \( … \) or \[ … \]
    # ans = re.sub(r"\\[\[(](.*?)[\\\])]", r"\1", ans)
    # # Remove \boxed{…}
    # ans = re.sub(r"\\boxed\{([^}]*)\}", r"\1", ans)
    ans = re.sub(r"\\\[.*?\\\]", "", ans)
    ans = re.sub(r"\$\$.*?\$\$", "", ans)
    # Remove inline LaTeX: \( ... \) and $...$
    ans = re.sub(r"\\\((.*?)\\\)", r"\1", ans)
    ans = re.sub(r"\$(.*?)\$", r"\1", ans)
    # Remove \boxed{...}
    ans = re.sub(r"\\boxed\s*{([^}]*)}", r"\1", ans)
    # Remove LaTeX commands like \text{...}, \frac{...}, etc.
    ans = re.sub(r"\\[a-zA-Z]+\s*(\{[^{}]*\})?", "", ans)
    # Remove variable assignments like "y =" or "x=" at start
    ans = re.sub(r"^[a-zA-Z]\s*=\s*", "", ans)
    # Trim outer $ … $ if present
    ans = ans.strip()
    if ans.startswith("$") and ans.endswith("$"):
        ans = ans[1:-1]
    return ans.strip()

def _sanitize(text: str) -> str:
    """Normalise a candidate answer string for comparison.

    Strips markup via ``_strip_markup``, removes trailing whitespace and
    ``. ; : ,`` characters, then collapses every whitespace run to one space.
    """
    bare = _strip_markup(text).strip()
    without_tail = re.sub(r"[\s\.;:,]+$", "", bare)   # trailing punctuation
    return re.sub(r"\s+", " ", without_tail)          # collapse spaces

def _to_float(expr: str) -> Optional[float]:
    try:
        return float(eval(expr.replace("^", "**")))
    except Exception:
        return None

def _numeric_equiv(a: str, b: str) -> bool:
    """Return True if `a` and `b` are numerically equivalent or exact match.

    Checks, in order: equality of the sanitised strings; closeness of their
    float values (relative tolerance 1e-6) when both evaluate; otherwise
    whether SymPy simplifies their difference to zero. Any SymPy failure
    counts as "not equivalent".
    """
    lhs = _sanitize(a)
    rhs = _sanitize(b)
    if lhs == rhs:
        return True

    # Attempt simple numeric evaluation
    lhs_num = _to_float(lhs)
    rhs_num = _to_float(rhs)
    if lhs_num is not None and rhs_num is not None:
        return math.isclose(lhs_num, rhs_num, rel_tol=1e-6)

    if sp is None:
        return False
    try:
        lhs_expr = sp.sympify(lhs.replace("^", "**"))
        rhs_expr = sp.sympify(rhs.replace("^", "**"))
        return sp.simplify(lhs_expr - rhs_expr) == 0
    except Exception:
        # Best effort: unparsable input is simply "not equivalent".
        return False

def system_prompt(type):
    """Return the instruction prompt for the requested generation mode.

    Args:
        type: "sample" for solving a problem from scratch, "rollout" for
            continuing a partially written solution.

    Returns:
        The prompt text, or an empty string for any other value.
    """
    if type == "sample":
        return """You are a math-problem expert. Your task is to complete the step-by-step solution for the problem provided. Write each reasoning step on its own line in the exact form \"Step k: [your reasoning step]\n\", numbering start from Step 1. When the final answer is obtained, write exactly one final line, \"Answer: [Final answer]\". Do NOT add explanations, extra steps, or any text after the "Answer:" line.

**Format Guide**: (You MUST write "Step " before numbering the step.)
Step 1: [Step 1 reasoning]\n
Step 2: [Step 2 reasoning]\n
...
Step k: [Step k reasoning]\n
...
Answer: [Final answer]

Format Guide with Examples:
<Example 1>
Problem: Find the sum of the first 8 positive even integers.
Step 1: The first 8 even integers are 2, 4, 6, 8, 10, 12, 14, 16.
Step 2: Use the formula for an arithmetic series: S = n·(first + last)/2.
Step 3: Substitute n=8, first=2, last=16 to get S = 8·(2+16)/2 = 8·9 = 72.
Answer: 72

<Example 2>
Problem: Determine the next number in the sequence 2, 4, 8, 16.
Step 1: Notice each term is obtained by multiplying the previous term by 2.
Step 2: Multiply 16 by 2, 16 * 2 = 32.
Answer: 32

Follow the FORMAT GUIDE structure exactly. Generate rationales step-by-step, not directly to the final answer. **Do NOT** write anything after the final 'Answer:' line. Always start stepwise reasoning with "Step {i-th}: " form."""
    if type == "rollout":
        return """You are a math problem-solving expert. Continue solving the given problem step by step, strictly following the required format. Each new step must begin with \"Step k+1: ...\", \"Step k+2:...\", and so on, continuing from the last given step number. When the final answer is reached, write only one final line starting with: \"Answer: [Final Answer]\". Do not add any explanations, extra commentary, or additional text after the "Answer:" line. Your output must follow this exact step-by-step format with no deviations.

**Format Guide**: (You MUST write "Step " before numbering the step.)
Step 1: [Step 1 reasoning]\n
Step 2: [Step 2 reasoning]\n
...
Step k: [Step k reasoning]\n
Continue and finish the solution:
Step k+1: [Step k+1 reasoning]\n
...
Answer: [Final answer]

Format Guide with Examples:
<Example 1>
Current solution steps:
Problem: Find the sum of the first 8 positive even integers.
Step 1: The first 8 even integers are 2, 4, 6, 8, 10, 12, 14, 16.
Step 2: Use the formula for an arithmetic series: S = n·(first + last)/2.
Continue and finish the solution:
Step 3: Substitute n=8, first=2, last=16 to get S = 8·(2+16)/2 = 8·9 = 72.
Answer: 72

<Example 2>
Current solution steps:
Problem: Determine the next number in the sequence 2, 4, 8, 16.
Step 1: Notice each term is obtained by multiplying the previous term by 2.
Continue and finish the solution:
Step 2: Multiply 16 by 2, 16 * 2 = 32.
Answer: 32

Keep the reasoning steps precise and factual and complete the solution. Follow the FORMAT GUIDE structure exactly. **Do NOT** write anything after the final 'Answer:' line. Always start stepwise reasoning with "Step {i-th}: " form."""
    # Unknown mode: the caller gets an empty prompt.
    return ""


In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import load_dataset
import re
from typing import List, Optional
from tqdm import tqdm
import random

# Project‑level helpers 
from utils import _sanitize, _numeric_equiv, _strip_markup, _to_float, system_prompt
from config import PRMConfig

class MCReward:
    """Monte-Carlo step-reward estimator.

    For every step of a worked solution, rollouts are sampled from a causal
    LM and the fraction that reaches the gold answer is used as that step's
    reward. The same is done after masking the step to measure how much the
    step contributes.
    """
    # A line that starts (after optional markdown/bullet characters) with
    # "Step <number>" followed by ':', '.' or '-'. Group 1 is the step number.
    STEP_PATTERN = re.compile(
    r"""^[\s>#*\-]*          # optional markdown/bullet symbols
        Step\s*              # word 'Step' (case-insensitive)
        (\d+)                # capture step number
        \s*[:.\-]            # separator (: . or -)
    """,
    re.IGNORECASE | re.VERBOSE,
    )
    # A line "Answer: <text>" anywhere in a multi-line string. Group 1 is the
    # answer text up to the end of that line.
    ANSWER_PATTERN = re.compile(
        r"""^[\s>#*\-]*          # optional markdown/bullet symbols
            Answer               # word 'Answer'
            \s*[:.\-]\s*         # separator
            (.+?)\s*$            # capture everything after
        """,
        re.IGNORECASE | re.MULTILINE | re.VERBOSE,
    )
    ## Masked rewards ##
    # Operation words. Not referenced by _MASK_PATTERN at the moment.
    OP_TOKENS = ["add", "plus", "sum", "subtract", "minus",
             "multiply", "times", "product", "divide", "quotient"]
    # Tokens replaced by "[MASKED]" in regex masking mode. The fraction
    # alternative must come first: with the integer alternative in front,
    # "3/4" was consumed as two separate integers and the fraction branch
    # could never match.
    _MASK_PATTERN = re.compile(
        r"""
        (?:
            \b\d+/\d+\b                 # simple fractions
          | \b\d+(?:\.\d+)?\b           # integers / decimals
        )
        """,
        re.VERBOSE,
    )

    def __init__(self, config: "PRMConfig", model, tokenizer):
        """Store the HF model/tokenizer and build a vLLM engine for the same checkpoint.

        Args:
            config: settings object; ``max_new_tokens`` and ``num_rollouts``
                are read here.
            model: Hugging Face causal LM. Its first parameter's device
                becomes ``self.device`` and its ``name_or_path`` selects the
                checkpoint loaded by vLLM.
            tokenizer: tokenizer matching ``model``.

        Note:
            ``self.llm`` and ``self.sparams`` are created here, but none of
            the methods shown in this class reads them yet; all rollouts still
            go through ``self.model.generate``.
        """
        # Imported locally: no import cell of the notebook brings these names
        # into module scope, so referencing them here used to raise NameError.
        from vllm import LLM, SamplingParams

        self.config = config
        self.model = model
        self.tokenizer = tokenizer
        self.device = next(model.parameters()).device
        self.llm = LLM(
            model=model.name_or_path,        # same checkpoint as the HF model (was the undefined `hf_model`)
            dtype="float16",                 # fp16 / bfloat16
            tensor_parallel_size=torch.cuda.device_count(),  # shard across every visible GPU
            trust_remote_code=True,
        )
        self.sparams = SamplingParams(
            max_tokens=config.max_new_tokens,
            temperature=0.8,
            top_p=0.8,
            n=config.num_rollouts,          # sample num_rollouts completions per request
            # add stop tokens here if needed
        )

    # Function to generate one or more step-by-step solutions for a given question.
    def generate_solutions(self, question: str, sys_prompt: str, num_solutions: int):
        prompt = f"{sys_prompt}\n\n{question}\n"  # Prompt the model to start the step-by-step solution
        input_ids = self.tokenizer.encode(prompt, return_tensors='pt').to(self.device)
        # Generate multiple solutions via sampling
        outputs = self.model.generate(
            input_ids,
            max_new_tokens=self.config.max_new_tokens,
            do_sample=True,
            num_return_sequences=num_solutions,
            temperature=0.8,         # sampling temperature for diversity (adjust as needed)
            top_p=0.8,               # top-p sampling for diversity
            pad_token_id=self.tokenizer.eos_token_id  # pad token ID to avoid warning for some models
        )
        solutions = []
        prompt_len = input_ids.shape[-1]
        for i in range(num_solutions):
            # Each output is the concatenation of the prompt and the generated completion.
            generated_ids = outputs[i]
            # Extract only the newly generated tokens (skip the prompt tokens).
            gen_ids = generated_ids[prompt_len:]
            text = self.tokenizer.decode(gen_ids, skip_special_tokens=True)
            solutions.append(text)
            # print(f"{i}-th Sampled Solutions:",text)
        return solutions
    
    def gsm8k_solutions(self, question: str, gold_solution: str):
        # 1. Split lines *before* the final answer marker (#### …)
        lines: List[str] = []
        gold_answer: str = ""
        _ANSWER_RE = re.compile(r"####\s*(.+?)\s*$")

        for raw_ln in gold_solution.splitlines():
            ln = raw_ln.strip()
            if not ln:
                continue  # skip empty
            ans_match = _ANSWER_RE.match(ln)
            if ans_match:
                gold_answer = ans_match.group(1).strip()
                break  # everything after #### is ignored
            lines.append(ln)

        if not gold_answer:
            raise ValueError("Could not find final answer marker '#### <answer>' in gold_solution.")

        # 2. Prefix each explanatory line with "Step i:"
        solution_steps = [f"Step {i + 1}: {txt}" for i, txt in enumerate(lines)]
        return {
            "question": question,
            "solution": solution_steps,
            "gold_answer": gold_answer,
        }

    # Function to parse a solution text into steps and final answer.
    def _extract_answer(self, text: str) -> Optional[str]:
        """Try multiple heuristics / regexes to pull out an answer string."""
        # Primary regex (robust to Answer:, Answer ‑, etc.)
        match = self.ANSWER_PATTERN.search(text)
        if match:
            return _sanitize(match.group(1))
        
        # Fallback 1: last non‑empty line if it looks simple / numeric
        lines = [ln.strip() for ln in text.splitlines() if ln.strip()]
        if lines:
            candidate = lines[-1]
            if re.search(r"\d", candidate):  # contains digit
                return _sanitize(candidate)

        # Fallback 2: look for last line that starts with 'Answer'
        for line in reversed(text.splitlines()):
            if line.strip().lower().startswith("answer"):
                return _sanitize(line.split("Answer", 1)[-1])
        
        return None

    def parse_solution(self, solution_text: str):
        """Split each step to start with 'Step X:' and the answer to start with 'Answer:'."""
        steps = []
        # Split by lines to identify steps and answer
        for line in solution_text.splitlines():
            line = line.strip()
            if not line:
                continue
            if self.STEP_PATTERN.match(line):
                cleaned = re.sub(r'^[\s>#*\-]+', '', line)
                steps.append(cleaned)
            answer = self._extract_answer(solution_text)
        return steps, answer
    
    # Function to estimate intermediate rewards for each step via rollouts.
    def compute_step_rewards(self, question, sys_prompt, steps, gold_answer):
        """
        For each prefix ending at a given step in 'steps', generate rollouts and compute the reward 
        (fraction of rollouts ending in the correct answer). Returns a list of reward values corresponding to each step.
        """
        rewards = []
        total_steps = len(steps)

        # Pre‑encode static prefix (sys_prompt + question) once for efficiency
        base_prompt = f"{sys_prompt}\n\nProblem: {question}\n"
        base_ids = self.tokenizer.encode(base_prompt, return_tensors="pt").to(self.device)

        for i in range(total_steps):
            prefix_tokens = self.tokenizer.encode("\n".join(steps[: i + 1]) + "\n", return_tensors="pt").to(self.device) # steps up to current step i (0-indexed)
            # Decide how to prompt the next part:
            if i < total_steps - 1:
                next_label = f"Step {i + 2}:"
            else:
                next_label = "Answer:"
            cont_ids = self.tokenizer.encode(next_label, return_tensors="pt").to(self.device)
            # Build full prefix ids (avoid Python concat inefficiency by cat)
            prefix_ids = torch.cat([base_ids, prefix_tokens, cont_ids], dim=-1)
            rollout_outputs = self.model.generate(
                prefix_ids,
                max_new_tokens=self.config.max_new_tokens,
                do_sample=True,
                num_return_sequences=self.config.num_rollouts,
                temperature=0.8,
                top_p=0.8,
                pad_token_id=self.tokenizer.eos_token_id
            )
            new_token_start = prefix_ids.shape[-1] 
            # Check each rollout's final answer against the gold answer
            correct_count = 0
            for idx, seq in enumerate(rollout_outputs):
                completion = self.tokenizer.decode(seq[new_token_start:], skip_special_tokens=True)
                pred_answer = self._extract_answer(completion)
                print(f"[{i+1}-th Step, {idx}-th Original Rollout]", completion, "Pred Answer", pred_answer)
                if pred_answer is not None and _numeric_equiv(pred_answer, gold_answer):
                    correct_count += 1
            reward = correct_count / float(self.config.num_rollouts)
            rewards.append(reward)
        return rewards
    
    # Masked solution paths
    def model_masking(self, text: str, *, max_new_tokens: int = 64) -> str:
        prompt = "In the sentence below, mask any word or expression that seems crucial for solving the math step. This may include key numbers, variables, or action words (like operations), but you should decide what matters. Replace each important item with '[MASKED]'. Keep everything else unchanged. Return ONE line.\n\nSentence: \"{sent}\"\nRewritten:".format(sent=text)
        input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(self.device)
        out_ids   = self.model.generate(
            input_ids,
            max_new_tokens=max_new_tokens,
            temperature=0.2, top_p=0.2,
            pad_token_id=self.tokenizer.eos_token_id,
        )
        return self.tokenizer.decode(out_ids[0][input_ids.shape[-1]:],
                                     skip_special_tokens=True).strip()

    def perturbed_step_rewards(self, question: str, sys_prompt: str, steps: List[str], gold_answer: str, use_llm: bool = True) -> List[float]:
        """Compute MC correctness rates *after masking* the current step.
        Each step `i` is replaced with a *perturbed* version where important
        tokens (numbers, fractions, single‑letter variables) are substituted by
        the literal string ``[MASKED]``. All preceding steps remain intact.
        """
        ptb_rewards: List[float] = []
        total_steps = len(steps)
        base_prompt = f"{sys_prompt}\n\nProblem: {question}\n"
        base_ids = self.tokenizer.encode(base_prompt, return_tensors="pt").to(self.device)

        for i in range(total_steps):
            # 1. Perturb *only* step i
            orig_step = steps[i] 
            step_match = re.match(r"^[\s>#*\-]*Step\s*\d+\s*[:.\-]\s*", orig_step, flags=re.I)
            prefix = step_match.group(0) if step_match else ""
            # ② 나머지 부분(body)만 마스킹
            body   = steps[i][len(prefix):]                       # 접두사 뒷부분
            if use_llm:
                masked_body = self.model_masking(body)
            else:
                masked_body = self._MASK_PATTERN.sub("[MASKED]", body)
            # ③ 접두사 + 마스킹된 body
            masked_step = prefix + masked_body    
            ptb_prefix_steps = steps[:i] + [masked_step]
            # print("perturbed step:", ptb_prefix_steps)

            prefix_tokens = self.tokenizer.encode("\n".join(ptb_prefix_steps) + "\n", return_tensors="pt").to(self.device)
            next_label = f"Step {i + 2}:" if i < total_steps - 1 else "Answer:"
            cont_ids = self.tokenizer.encode(next_label, return_tensors="pt").to(self.device)
            prefix_ids = torch.cat([base_ids, prefix_tokens, cont_ids], dim=-1)

            rollout_outputs = self.model.generate(
                prefix_ids,
                max_new_tokens=self.config.max_new_tokens,
                do_sample=True,
                num_return_sequences=self.config.num_rollouts,
                temperature=0.8,
                top_p=0.8,
                pad_token_id=self.tokenizer.eos_token_id,
            )
            new_token_start = prefix_ids.shape[-1]
            correct_count = 0
            for idx, seq in enumerate(rollout_outputs):
                completion = self.tokenizer.decode(seq[new_token_start:], skip_special_tokens=True)
                pred_answer = self._extract_answer(completion)
                print(f"Masked [{i+1}-th Step, {idx}-th Rollout]", completion, "Pred Answer", pred_answer)
                if pred_answer is not None and _numeric_equiv(pred_answer, gold_answer):
                    correct_count += 1
            ptb_rewards.append(correct_count / float(self.config.num_rollouts))
        return ptb_rewards

    # Build datasets based on input datas
    def build_datasets(self, problems: List):
        dataset = []  # will hold the output list of dicts
        for problem in problems:
            question = problem["question"]
            # gold_answer = problem["gold_answer"]
            gold_answer = _sanitize(problem["gold_answer"])
            # Generate one or more solutions for this question
            sample_prompt = system_prompt("sample")
            rollout_prompt = system_prompt("rollout")
            solutions = self.generate_solutions(question, sys_prompt=sample_prompt, num_solutions=self.config.samples_per_question)
            
            for sol_text in solutions:
                steps, answer = self.parse_solution(sol_text)
                # print("Parsed solution:", steps, answer)
                if answer is None: # If no answer was found in the solution (edge case), skip this solution
                    continue
                # 2. Compute *original* & *perturbed* per‑step rewards
                # ----------------------------------------------------------
                ori_rewards = self.compute_step_rewards(
                    question=question,
                    sys_prompt=rollout_prompt,
                    steps=steps,
                    gold_answer=gold_answer,
                )
                ptb_rewards = self.perturbed_step_rewards(
                    question=question,
                    sys_prompt=rollout_prompt,
                    steps=steps,
                    gold_answer=gold_answer,
                )
                # Align lengths (robustness)
                if len(ptb_rewards) != len(ori_rewards):
                    ptb_rewards = ptb_rewards[: len(ori_rewards)]
                # contributions = [max(0, o - p) for o, p in zip(ori_rewards, ptb_rewards)]
                contributions = [o - p for o, p in zip(ori_rewards, ptb_rewards)]
                entry = {
                    "question": question,
                    "completion": steps,          # list[str] (Step i: ...)
                    "ori_rewards": ori_rewards,    # list[float]
                    "ptb_rewards": ptb_rewards,    # list[float]
                    "contributions": contributions,  # ori − ptb
                    "answer": answer,
                    "gold_answer": gold_answer,
                }
                dataset.append(entry)
        return dataset
    
    # Build datasets based on input datas
    def build_datasets_gsm8k(self, *, split: str = "train", start: int = 0, take: int | None):
        _ANSWER_RE = re.compile(r"####\s*(.+?)\s*$")

        rollout_pr = system_prompt("rollout")
        ds = load_dataset("openai/gsm8k", "main", split=split)
        if take is not None:
            ds = ds.shuffle(seed=self.config.seed).select(range(start, start+take))

        csr, psr   = self.compute_step_rewards, self.perturbed_step_rewards
        sanitize   = _sanitize
        use_llm    = self.config.use_llm
        dataset    = []
        
        for sample in tqdm(ds, desc="Building GSM-8K reward-dataset"):
            # ── (1) extract step solutions ──────────────────────────────────────────
            q_txt   = sample["question"]
            g_sol   = sample["answer"]
            lines, gold_ans = [], None
            for ln in g_sol.splitlines():
                ln = ln.strip()
                if not ln:
                    continue
                m = _ANSWER_RE.match(ln)
                if m:
                    gold_ans = sanitize(m.group(1))
                    break
                lines.append(ln)
            if gold_ans is None:
                raise ValueError("gold answer not found for sample")
            steps = [f"Step {i+1}: {t}" for i, t in enumerate(lines)]

            # ── (2) compute rewards ───────────────────────────────────────────────────
            ori = csr(q_txt, rollout_pr, steps, gold_ans)
            ptb = psr(q_txt, rollout_pr, steps, gold_ans, use_llm)
            if len(ptb) != len(ori):
                ptb = ptb[: len(ori)]
            contrib = [round(o - p, 4) for o, p in zip(ori, ptb)]

            #  ── (3) Append entry ───────────────────────────────────────────
            entry = {
                    "question":      q_txt,
                    "completion":    steps,
                    "ori_rewards":   ori,
                    "ptb_rewards":   ptb,
                    "contributions": contrib,
                    "answer":        gold_ans,
                    "gold_answer":   gold_ans,
                }
            dataset.append(entry)
            # print(entry)
        return dataset

    def build_datasets_math(self, *, split: str = "train", start: int = 0, take: int | None):
        """
        ① MATH 데이터셋 로드 → ② 정답·스텝 추출 → ③ 보상 계산 → ④ dict 리스트 반환
        """
        boxed_re   = re.compile(r'\\boxed\{(.+?)\}', re.S)
        sent_split = re.compile(r'\.(?!\d)(?=\s|$)')   # 소수점·수식 내부 마침표 무시

        rollout_prompt = system_prompt("rollout")
        ds = load_dataset("HuggingFaceTB/MATH", "all", split=split)

        # shuffle & take
        if take is not None:
            ds = ds.select(range(start, start+take))

        # (alias) time optimize
        csr, psr   = self.compute_step_rewards, self.perturbed_step_rewards
        sanitize   = _sanitize
        use_llm    = self.config.use_llm
        dataset    = []

        for sample in tqdm(ds, desc="Building MATH reward-dataset"):
            # ── (1) extract step solutions ──────────────────────────────────────────
            full_sol   = sample["solution"]
            m          = boxed_re.search(full_sol)
            gold_ans   = sanitize(m.group(1)) if m else None
            sol_wo_box = boxed_re.sub("", full_sol)
            raw_steps  = [s.strip() for s in sent_split.split(sol_wo_box) if s.strip()]
            steps      = [f"Step {i+1}: {s}" for i, s in enumerate(raw_steps)]

            # ── (2) compute rewards ───────────────────────────────────────────────────
            ori = csr(sample["problem"], rollout_prompt, steps, gold_ans)
            ptb = psr(sample["problem"], rollout_prompt, steps, gold_ans, use_llm)
            if len(ptb) != len(ori):
                ptb = ptb[: len(ori)]
            contrib = [round(o - p, 4) for o, p in zip(ori, ptb)]

            # ── (3) Append entry ───────────────────────────────────────────
            entry = {
                "question":      sample["problem"],
                "completion":    steps,
                "ori_rewards":   ori,
                "ptb_rewards":   ptb,
                "contributions": contrib,
                "answer":        gold_ans,
                "gold_answer":   gold_ans,
                "level":         sample["level"],
                "type":          sample["type"],
            }
            dataset.append(entry)
            # print(entry)
        return dataset

    def build_datasets_math_vllm(self, *, split: str = "train", start: int = 0, take: int | None):
        from vllm import LLM, SamplingParams
        boxed_re   = re.compile(r'\\boxed\{(.+?)\}', re.S)
        sent_split = re.compile(r'\.(?!\d)(?=\s|$)')   # 소수점·수식 내부 마침표 무시

        rollout_prompt = system_prompt("rollout")
        ds = load_dataset("HuggingFaceTB/MATH", "all", split=split)

        # shuffle & take
        if take is not None:
            ds = ds.select(range(start, start+take))

        # (alias) time optimize
        csr, psr   = self.compute_step_rewards, self.perturbed_step_rewards
        sanitize   = _sanitize
        use_llm    = self.config.use_llm
        dataset    = []

        for sample in tqdm(ds, desc="Building MATH reward-dataset"):
            # ── (1) extract step solutions ──────────────────────────────────────────
            full_sol   = sample["solution"]
            m          = boxed_re.search(full_sol)
            gold_ans   = sanitize(m.group(1)) if m else None
            sol_wo_box = boxed_re.sub("", full_sol)
            raw_steps  = [s.strip() for s in sent_split.split(sol_wo_box) if s.strip()]
            steps      = [f"Step {i+1}: {s}" for i, s in enumerate(raw_steps)]

            # ── (2) compute rewards ───────────────────────────────────────────────────
            ori = csr(sample["problem"], rollout_prompt, steps, gold_ans)
            ptb = psr(sample["problem"], rollout_prompt, steps, gold_ans, use_llm)
            if len(ptb) != len(ori):
                ptb = ptb[: len(ori)]
            contrib = [round(o - p, 4) for o, p in zip(ori, ptb)]

            # ── (3) Append entry ───────────────────────────────────────────
            entry = {
                "question":      sample["problem"],
                "completion":    steps,
                "ori_rewards":   ori,
                "ptb_rewards":   ptb,
                "contributions": contrib,
                "answer":        gold_ans,
                "gold_answer":   gold_ans,
                "level":         sample["level"],
                "type":          sample["type"],
            }
            dataset.append(entry)
        return dataset


In [None]:
# `cfg` was never defined anywhere in the notebook, so this cell failed with
# NameError on a fresh kernel; start from the default settings.
cfg = PRMConfig()
# `model` and `tokenizer` are the objects created by the model-loading cell.
mcr = MCReward(config=cfg, model=model, tokenizer=tokenizer)
# Small smoke run: the first two samples of the MATH test split.
math_dataset = mcr.build_datasets_math_vllm(split="test", take=2)
math_dataset[1]