In [1]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import load_dataset
import re
from typing import List, Optional
from tqdm import tqdm
import random

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import os

os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"  # Arrange GPU devices starting from 0
os.environ["CUDA_VISIBLE_DEVICES"]= "1"

model_name = "Qwen/QwQ-32B"

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,            # enable 4-bit (QLoRA-style) weights
    bnb_4bit_quant_type="nf4",    # NF4 gives the best accuracy for most LLMs
    bnb_4bit_use_double_quant=True, # optional: second quantisation pass to save ~0.4 bits/param
    bnb_4bit_compute_dtype=torch.bfloat16  # faster matmuls on recent GPUs; fall back to float16 if needed
)

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",      # let Accelerate split layers across all visible GPUs
    quantization_config=bnb_config,
    torch_dtype="auto",     # keeps non-linear layers in their original dtype
    trust_remote_code=True  # Qwen models need their custom code
)

tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)

Loading checkpoint shards: 100%|██████████| 14/14 [01:04<00:00,  4.59s/it]


In [3]:
class PRMConfig:
    """Configuration class for PRM hyperparameters and settings"""
    # MC config
    model_name:             str = "Qwen/QwQ-32B"    # "Qwen/Qwen2.5-Math-7B", "Qwen/Qwen2.5-Math-7B-Instruct" , "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B", "meta-llama/Llama-3.1-8B"
    max_new_tokens:         int = 256
    num_rollouts:           int = 8      
    samples_per_question:   int = 1
    use_llm:                bool = True  # Use llm for masking
    use_contri:             bool = True  # If true, use "contributions" as step rewards else use ori_rewards
    # PRM Model config
    hidden_size:        int = 512
    num_layers:         int = 3
    dropout:            float = 0.2
    # PRMTrainer config
    batch_size:         int = 12
    learning_rate:      float = 5e-4
    num_workers:        int = 4
    weight_decay:       float = 1e-2
    lr_scheduler:       str   = "cosine"
    dataset_size:       int = 0
    warmup_steps:       int   = 22
    grad_clip:          float = 1.0
    epochs:             int = 25
    # Misc config
    use_wandb:          bool = True
    wandb_project:      str = "mc_prm"
    run_name:           str = "test_gsm8k_100_contri_mse"
    checkpoint_dir:     str = "./checkpoints/gsm8k/contri_mse"
    seed:               int = 42

In [6]:
import sympy as sp
import math
from typing import Optional
import re

def _extract_boxed_answer(text: str) -> Optional[str]:
    """Extract content from \\boxed{...} with proper brace balancing"""
    import re
    pattern = r'\\boxed\{'
    matches = list(re.finditer(pattern, text))
    
    if not matches:
        return None
    
    start_match = matches[-1]  # Use last occurrence (final answer)
    start_pos = start_match.end() - 1  # Position of opening brace
    
    brace_count = 0
    pos = start_pos
    
    while pos < len(text):
        if text[pos] == '{':
            brace_count += 1
        elif text[pos] == '}':
            brace_count -= 1
            if brace_count == 0:
                content = text[start_pos + 1:pos]
                return content.strip()
        pos += 1
    
    return None 

def _strip_markup_enhanced(ans: str) -> str:
    """Enhanced markup removal with better LaTeX handling"""
    # Remove LaTeX display math wrappers \[ … \] or $$ … $$
    ans = re.sub(r"\\\[.*?\\\]", "", ans)
    ans = re.sub(r"\$\$.*?\$\$", "", ans)
    # Remove inline LaTeX: \( ... \) and $...$
    ans = re.sub(r"\\\((.*?)\\\)", r"\1", ans)
    ans = re.sub(r"\$(.*?)\$", r"\1", ans)
    # Handle \boxed{...} with proper brace balancing
    boxed_content = _extract_boxed_answer(f"\\boxed{{{ans}}}")
    if boxed_content:
        ans = boxed_content
    # Remove common LaTeX commands but preserve fractions
    ans = re.sub(r"\\text\s*\{([^}]*)\}", r"\1", ans)  # \text{...}
    ans = re.sub(r"\\mathrm\s*\{([^}]*)\}", r"\1", ans)  # \mathrm{...}
    # Convert LaTeX fractions to evaluable form: \frac{a}{b} -> (a)/(b)
    ans = re.sub(r"\\frac\s*\{([^}]*)\}\s*\{([^}]*)\}", r"(\1)/(\2)", ans)
    # Remove remaining LaTeX commands
    ans = re.sub(r"\\[a-zA-Z]+\*?", "", ans)
    # Remove variable assignments like "k =" or "x=" at start
    ans = re.sub(r"^[a-zA-Z]\s*=\s*", "", ans)
    # Clean up extra whitespace and punctuation
    ans = ans.strip()
    if ans.startswith("$") and ans.endswith("$"):
        ans = ans[1:-1]
    
    return ans.strip()

def _sanitize_enhanced(text: str) -> str:
    """Enhanced normalization with better numeric handling"""
    text = _strip_markup_enhanced(text)
    text = text.strip()
    # Remove trailing punctuation
    text = re.sub(r"[\s\.;:,]+$", "", text)
    # Normalize spaces
    text = re.sub(r"\s+", " ", text)
    # Handle negative signs and spaces
    text = re.sub(r"\s*-\s*", "-", text)
    return text

def _to_float_enhanced(expr: str) -> Optional[float]:
    """Enhanced numeric evaluation with fraction support"""
    try:
        # Handle simple cases first
        if expr.replace(".", "").replace("-", "").isdigit():
            return float(expr)
        
        # Handle fractions: -33/2, 33/2, etc.
        if re.match(r"^-?\d+/\d+$", expr):
            parts = expr.split("/")
            return float(parts[0]) / float(parts[1])
        
        # Handle parenthetical fractions: (-33)/(2)
        paren_match = re.match(r"^\(([^)]+)\)/\(([^)]+)\)$", expr)
        if paren_match:
            num, den = paren_match.groups()
            return float(num) / float(den)
        
        # Try eval for more complex expressions
        safe_expr = expr.replace("^", "**")
        return float(eval(safe_expr))
        
    except Exception:
        return None

def _numeric_equiv_enhanced(a: str, b: str) -> bool:
    """Enhanced numeric equivalence with better fraction handling"""
    a_clean, b_clean = map(_sanitize_enhanced, (a, b))
    # Exact string match first
    if a_clean == b_clean:
        return True

    # Numeric comparison
    a_val, b_val = _to_float_enhanced(a_clean), _to_float_enhanced(b_clean)
    if a_val is not None and b_val is not None:
        return math.isclose(a_val, b_val, rel_tol=1e-6, abs_tol=1e-9)

    # SymPy fallback for symbolic expressions
    try:
        a_expr = sp.sympify(a_clean.replace("^", "**"))
        b_expr = sp.sympify(b_clean.replace("^", "**"))
        return sp.simplify(a_expr - b_expr) == 0
    except Exception:
        pass
    
    return False

def system_prompt(type):
    prompt = ""
    if type == "sample":
        prompt = """You are a math-problem expert. Your task is to complete the step-by-step solution for the problem provided. Write each reasoning step on its own line in the exact form \"Step k: [your reasoning step]\n\", numbering start from Step 1. When the final answer is obtained, write exactly one final line, \"Answer: [Final answer]\". Do NOT add explanations, extra steps, or any text after the "Answer:" line.

**Format Guide**: (You MUST write "Step " before numbering the step.)
Step 1: [Step 1 reasoning]\n
Step 2: [Step 2 reasoning]\n
...
Step k: [Step k reasoning]\n
...
Answer: [Final answer]

Format Guide with Examples:
<Example 1>
Problem: Find the sum of the first 8 positive even integers.
Step 1: The first 8 even integers are 2, 4, 6, 8, 10, 12, 14, 16.
Step 2: Use the formula for an arithmetic series: S = n·(first + last)/2.
Step 3: Substitute n=8, first=2, last=16 to get S = 8·(2+16)/2 = 8·9 = 72.
Answer: 72

<Example 2>
Problem: Determine the next number in the sequence 2, 4, 8, 16.
Step 1: Notice each term is obtained by multiplying the previous term by 2.
Step 2: Multiply 16 by 2, 16 * 2 = 32.
Answer: 32

Follow the FORMAT GUIDE structure exactly. Generate rationales step-by-step, not directly to the final answer. **Do NOT** write anything after the final 'Answer:' line. Always start stepwise reasoning with "Step {i-th}: " form."""
    if type == "rollout":
        prompt = """You are a math problem-solving expert. Continue solving the given problem step by step, strictly following the required format. Each new step must begin with \"Step k+1: ...\", \"Step k+2:...\", and so on, continuing from the last given step number. When the final answer is reached, write only one final line starting with: \"Answer: [Final Answer]\". Do not add any explanations, extra commentary, or additional text after the "Answer:" line. Your output must follow this exact step-by-step format with no deviations.

**Format Guide**: (You MUST write "Step " before numbering the step.)
Step 1: [Step 1 reasoning]\n
Step 2: [Step 2 reasoning]\n
...
Step k: [Step k reasoning]\n
Continue and finish the solution:
Step k+1: [Step k+1 reasoning]\n
...
Answer: [Final answer]

Format Guide with Examples:
<Example 1>
Current solution steps:
Problem: Find the sum of the first 8 positive even integers.
Step 1: The first 8 even integers are 2, 4, 6, 8, 10, 12, 14, 16.
Step 2: Use the formula for an arithmetic series: S = n·(first + last)/2.
Continue and finish the solution:
Step 3: Substitute n=8, first=2, last=16 to get S = 8·(2+16)/2 = 8·9 = 72.
Answer: 72

<Example 2>
Current solution steps:
Problem: Determine the next number in the sequence 2, 4, 8, 16.
Step 1: Notice each term is obtained by multiplying the previous term by 2.
Continue and finish the solution:
Step 2: Multiply 16 by 2, 16 * 2 = 32.
Answer: 32

Keep the reasoning steps precise and factual and complete the solution. Follow the FORMAT GUIDE structure exactly. **Do NOT** write anything after the final 'Answer:' line. Always start stepwise reasoning with "Step {i-th}: " form."""
    return prompt


In [10]:
import math
import sympy as sp

class MCRewardShaped:
    ANSWER_PATTERN = re.compile(
        r"""^[\s>#*\-]*          # optional markdown/bullet symbols
            Answer               # word 'Answer'
            \s*[:.\-]\s*         # separator
            (.+?)\s*$            # capture everything after
        """,
        re.IGNORECASE | re.MULTILINE | re.VERBOSE,
    )
    _ANSWER_RE = re.compile(r"####\s*(.+?)\s*$")
    _MASK_PATTERN = re.compile(
        r"""
        (?:
        # {ops_pattern}|                # operator patterns
            \b\d+(?:\.\d+)?\b         # integers / decimals
          | \b\d+/\d+\b                 # simple fractions
        #   | \b[a-zA-Z]\b                 # single‑letter variables
        )
        """,
        re.VERBOSE,
    )
    
    def __init__(self, config: "PRMConfig", model, tokenizer):
        self.config = config
        self.model = model
        self.tokenizer = tokenizer
        self.device = next(model.parameters()).device

    # Function to parse a solution text into steps and final answer.
    def _extract_answer(self, text: str) -> Optional[str]:
        """Try multiple heuristics / regexes to pull out an answer string."""
        match = self.ANSWER_PATTERN.search(text)
        if match:
            return _sanitize_enhanced(match.group(1))
        lines = [ln.strip() for ln in text.splitlines() if ln.strip()]
        if lines:
            candidate = lines[-1]
            if re.search(r"\d", candidate):  # contains digit
                return _sanitize_enhanced(candidate)
        for line in reversed(text.splitlines()):
            if line.strip().lower().startswith("answer"):
                return _sanitize_enhanced(line.split("Answer", 1)[-1])
        return None
    
    def compute_step_rewards(self, question, sys_prompt, steps, gold_answer):
        rewards = []
        total_steps = len(steps)

        # Pre‑encode static prefix (sys_prompt + question) once for efficiency
        base_prompt = f"{sys_prompt}\n\nProblem: {question}\n"
        base_ids = self.tokenizer.encode(base_prompt, return_tensors="pt").to(self.device)

        for i in range(total_steps):
            prefix_tokens = self.tokenizer.encode("\n".join(steps[: i + 1]) + "\n", return_tensors="pt").to(self.device) # steps up to current step i (0-indexed)
            # Decide how to prompt the next part:
            if i < total_steps - 1:
                next_label = f"Step {i + 2}:"
            else:
                next_label = "Answer: "
            cont_ids = self.tokenizer.encode(next_label, return_tensors="pt").to(self.device)
            # Build full prefix ids (avoid Python concat inefficiency by cat)
            prefix_ids = torch.cat([base_ids, prefix_tokens, cont_ids], dim=-1)
            rollout_outputs = self.model.generate(
                prefix_ids,
                max_new_tokens=self.config.max_new_tokens,
                do_sample=True,
                num_return_sequences=self.config.num_rollouts,
                temperature=0.8,
                top_p=0.8,
                pad_token_id=self.tokenizer.eos_token_id
            )
            new_token_start = prefix_ids.shape[-1] 
            # Check each rollout's final answer against the gold answer
            correct_count = 0
            for idx, seq in enumerate(rollout_outputs):
                completion = self.tokenizer.decode(seq[new_token_start:], skip_special_tokens=True)
                pred_answer = self._extract_answer(completion)
                # print(f"[{i+1}-th Step, {idx}-th Original Rollout]", completion, "Pred Answer", pred_answer, "Gold Answer", gold_answer)
                if pred_answer is not None and _numeric_equiv_enhanced(pred_answer, gold_answer):
                    correct_count += 1
            reward = correct_count / float(self.config.num_rollouts)
            rewards.append(reward)
        return rewards
    
    # Using perurbed rollouts to compute step rewards
    def model_masking(self, text: str, *, max_new_tokens: int = 64) -> str:
        prompt = "In the sentence below, mask any word or expression that seems crucial for solving the math step. This may include key numbers, variables, or action words (like operations), but you should decide what matters. Replace each important item with '[MASKED]'. Keep everything else unchanged. Return ONE line.\n\nSentence: \"{sent}\"\nRewritten:".format(sent=text)
        input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(self.device)
        out_ids   = self.model.generate(
            input_ids,
            max_new_tokens=max_new_tokens,
            temperature=0.2, top_p=0.2,
            pad_token_id=self.tokenizer.eos_token_id,
        )
        return self.tokenizer.decode(out_ids[0][input_ids.shape[-1]:], skip_special_tokens=True).strip()

    def perturb_step_rewards(self, question: str, sys_prompt: str, steps: List[str], gold_answer: str, use_llm: bool = True) -> List[float]:
        ptb_rewards: List[float] = []
        total_steps = len(steps)
        base_prompt = f"{sys_prompt}\n\nProblem: {question}\n"
        base_ids = self.tokenizer.encode(base_prompt, return_tensors="pt").to(self.device)

        for i in range(total_steps):
            # 1. Perturb *only* step i
            orig_step = steps[i] 
            step_match = re.match(r"^[\s>#*\-]*Step\s*\d+\s*[:.\-]\s*", orig_step, flags=re.I)
            prefix = step_match.group(0) if step_match else ""
            # ② 나머지 부분(body)만 마스킹
            body   = steps[i][len(prefix):]                       # 접두사 뒷부분
            if use_llm:
                masked_body = self.model_masking(body)
            else:
                masked_body = self._MASK_PATTERN.sub("[MASKED]", body)
            # ③ 접두사 + 마스킹된 body
            masked_step = prefix + masked_body    
            ptb_prefix_steps = steps[:i] + [masked_step]
            # print("perturbed step:", ptb_prefix_steps)

            prefix_tokens = self.tokenizer.encode("\n".join(ptb_prefix_steps) + "\n", return_tensors="pt").to(self.device)
            next_label = f"Step {i + 2}:" if i < total_steps - 1 else "Answer:"
            cont_ids = self.tokenizer.encode(next_label, return_tensors="pt").to(self.device)
            prefix_ids = torch.cat([base_ids, prefix_tokens, cont_ids], dim=-1)

            rollout_outputs = self.model.generate(
                prefix_ids,
                max_new_tokens=self.config.max_new_tokens,
                do_sample=True,
                num_return_sequences=self.config.num_rollouts,
                temperature=0.8,
                top_p=0.8,
                pad_token_id=self.tokenizer.eos_token_id,
            )
            new_token_start = prefix_ids.shape[-1]
            correct_count = 0
            for idx, seq in enumerate(rollout_outputs):
                completion = self.tokenizer.decode(seq[new_token_start:], skip_special_tokens=True)
                pred_answer = self._extract_answer(completion)
                # print(f"Masked [{i+1}-th Step, {idx}-th Rollout]", completion, "Pred Answer", pred_answer)
                if pred_answer is not None and _numeric_equiv_enhanced(pred_answer, gold_answer):
                    correct_count += 1
            ptb_rewards.append(correct_count / float(self.config.num_rollouts))
        return ptb_rewards
    
    # Using Mutual Information to compute step rewards
    def entropy_bits_exact(self, prompt: str, target: str) -> float:
        """True H(A|prompt) in bits/token, by ∑_t H(p_t). Memory-intensive: stores full probs tensor."""
        LOG2E = 1 / math.log(2)
        full   = prompt + target
        inputs = self.tokenizer(full, return_tensors="pt").to(self.device)
        Lp     = len(self.tokenizer(prompt)["input_ids"])

        with torch.no_grad():
            logits = self.model(**inputs).logits.float()      # [1,L,V]

        probs = logits.softmax(-1)                      # [...,V]
        token_H = -(probs * probs.log()).sum(-1) * LOG2E  # bits/token

        mask = torch.zeros_like(inputs["input_ids"], dtype=torch.bool)
        mask[:, Lp:] = True                             # answer tokens
        return token_H[mask].sum().item() / mask.sum().item()
    
    def compute_step_mi(self, question: str, steps: List[str], gold_answer: str):
        sys_prompt = """Solve the given problem with step by step reasoning in the format of "Step k: <k-th rationale>" and write final answer in the format of "Answer: <answer>".\nProblem: """
        question = re.sub(r' +', ' ', question) 
        gold_answer = "Answer: " + gold_answer
        context = sys_prompt + question + "\n\n"

        mi_incremental = []
        cumulative_prompt = context
        for i, step in enumerate(steps):
            h_before = self.entropy_bits_exact(cumulative_prompt, gold_answer)
            cumulative_prompt += step+"\n"
            h_after = self.entropy_bits_exact(cumulative_prompt, gold_answer)
            # I(S_i ; A | context, S_1,...,S_{i-1}) = H(A|prev) - H(A|prev,S_i)
            incremental_mi = h_before - h_after
            mi_incremental.append(incremental_mi)
        return mi_incremental
    
    # Build datasets based on input datas
    def gsm8k_reward_dataset(self, *, split: str = "train", start: int = 0, take: int | None):
        ds = load_dataset("openai/gsm8k", "main", split=split)
        if take is not None:
            ds = ds.shuffle(seed=self.config.seed).select(range(start, start+take))
        else:
            ds = ds.shuffle(seed=self.config.seed).select(range(start, len(ds)))

        dataset    = []
        for sample in tqdm(ds, desc="Building GSM-8K reward-dataset"):
            q_txt   = sample["question"]
            g_sol   = sample["answer"]
            lines, gold_ans = [], None
            for ln in g_sol.splitlines():
                ln = ln.strip()
                if not ln:
                    continue
                m = self._ANSWER_RE.match(ln)
                if m:
                    gold_ans = _sanitize_enhanced(m.group(1))
                    break
                lines.append(ln)
            if gold_ans is None:
                raise ValueError("gold answer not found for sample")
            steps = [f"Step {i+1}: {t}" for i, t in enumerate(lines)]

            ori = self.compute_step_rewards(q_txt, system_prompt("rollout"), steps, gold_ans)
            ptb = self.perturb_step_rewards(q_txt, system_prompt("rollout"), steps, gold_ans, use_llm)
            mi = self.compute_step_mi(q_txt, steps, gold_ans)

            naive = [round(o + max(m, 0), 4) for o, m in zip(ori, mi)]
            contrib = [round(o - p, 4) for o, p in zip(ori, ptb)]

            entry = {
                    "question":      q_txt,
                    "completion":    steps,
                    "ori_rewards":   ori,
                    "ptb_rewards":   ptb,
                    "contributions": contrib,
                    "mi_rewards":   mi,
                    "naive_rewards": naive,
                    "gold_answer":   gold_ans,
                }
            dataset.append(entry)
        return dataset

    def math_reward_dataset(self, *, split: str = "train", start: int = 0, take: int | None):
        sent_split = re.compile(r'\.(?!\d)(?=\s|$)')   # 소수점·수식 내부 마침표 무시
        ds = load_dataset("HuggingFaceTB/MATH", "all", split=split)
        if take is not None:
            ds = ds.select(range(start, start+take))
        else:
            ds = ds.select(range(start, len(ds)))

        dataset    = []
        for sample in tqdm(ds, desc="Building MATH reward-dataset"):
            full_sol   = sample["solution"]

            boxed_content = _extract_boxed_answer(full_sol)
            gold_ans = _sanitize_enhanced(boxed_content) if boxed_content else None
            if gold_ans is None:
                # Fallback: look for last mathematical expression
                lines = [line.strip() for line in full_sol.splitlines() if line.strip()]
                for line in reversed(lines):
                    if re.search(r'[\d\-+*/()=]', line):
                        gold_ans = _sanitize_enhanced(line)
                        break
            
            # Remove all \\boxed{...} for step extraction  
            sol_wo_box = re.sub(r'\\boxed\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\}', '', full_sol)
            raw_steps = [s.strip() for s in sent_split.split(sol_wo_box) if s.strip()]
            steps = [f"Step {i+1}: {s}" for i, s in enumerate(raw_steps)]

            # Calculate rewards
            ori = self.compute_step_rewards(sample["problem"], system_prompt("rollout"), steps, gold_ans)
            ptb = self.perturb_step_rewards(sample["problem"], system_prompt("rollout"), steps, gold_ans, use_llm)
            mi = self.compute_step_mi(sample["problem"], steps, gold_ans)

            contrib = [round(o - p, 4) for o, p in zip(ori, ptb)]
            naive = [round(o + max(m, 0), 4) for o, m in zip(ori, mi)]

            entry = {
                "question":      sample["problem"],
                "completion":    steps,
                "ori_rewards":   ori,
                "ptb_rewards":   ptb,
                "contributions": contrib,
                "mi_rewards":   mi,
                "naive_rewards": naive,
                "gold_answer":   gold_ans,
            }
            dataset.append(entry)
        return dataset
    
    def gsm8k_mi_reward_dataset(self, *, split: str = "train", start: int = 0, take: int | None):
        ds = load_dataset("openai/gsm8k", "main", split=split)
        if take is not None:
            ds = ds.shuffle(seed=self.config.seed).select(range(start, start+take))
        else:
            ds = ds.shuffle(seed=self.config.seed).select(range(start, len(ds)))

        dataset    = []
        for sample in tqdm(ds, desc="Building GSM-8K reward-dataset"):
            q_txt   = sample["question"]
            g_sol   = sample["answer"]
            lines, gold_ans = [], None
            for ln in g_sol.splitlines():
                ln = ln.strip()
                if not ln:
                    continue
                m = self._ANSWER_RE.match(ln)
                if m:
                    gold_ans = _sanitize_enhanced(m.group(1))
                    break
                lines.append(ln)
            if gold_ans is None:
                raise ValueError("gold answer not found for sample")
            steps = [f"Step {i+1}: {t}" for i, t in enumerate(lines)]

            ori = self.compute_step_rewards(q_txt, system_prompt("rollout"), steps, gold_ans)
            mi = self.compute_step_mi(q_txt, steps, gold_ans)

            naive = [round(o + max(m, 0), 4) for o, m in zip(ori, mi)]

            entry = {
                    "question":      q_txt,
                    "completion":    steps,
                    "ori_rewards":   ori,
                    "mi_rewards":   mi,
                    "naive_rewards": naive,
                    "gold_answer":   gold_ans,
                }
            dataset.append(entry)
        return dataset

    def math_mi_reward_dataset(self, *, split: str = "train", start: int = 0, take: int | None):
        sent_split = re.compile(r'\.(?!\d)(?=\s|$)')   # 소수점·수식 내부 마침표 무시
        ds = load_dataset("HuggingFaceTB/MATH", "all", split=split)
        if take is not None:
            ds = ds.select(range(start, start+take))
        else:
            ds = ds.select(range(start, len(ds)))

        dataset    = []
        for sample in tqdm(ds, desc="Building MATH reward-dataset"):
            full_sol   = sample["solution"]

            boxed_content = _extract_boxed_answer(full_sol)
            gold_ans = _sanitize_enhanced(boxed_content) if boxed_content else None
            if gold_ans is None:
                # Fallback: look for last mathematical expression
                lines = [line.strip() for line in full_sol.splitlines() if line.strip()]
                for line in reversed(lines):
                    if re.search(r'[\d\-+*/()=]', line):
                        gold_ans = _sanitize_enhanced(line)
                        break
            
            # Remove all \\boxed{...} for step extraction  
            sol_wo_box = re.sub(r'\\boxed\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\}', '', full_sol)
            raw_steps = [s.strip() for s in sent_split.split(sol_wo_box) if s.strip()]
            steps = [f"Step {i+1}: {s}" for i, s in enumerate(raw_steps)]

            # Calculate rewards
            ori = self.compute_step_rewards(sample["problem"], system_prompt("rollout"), steps, gold_ans)
            mi = self.compute_step_mi(sample["problem"], steps, gold_ans)

            naive = [round(o + max(m, 0), 4) for o, m in zip(ori, mi)]

            entry = {
                "question":      sample["problem"],
                "completion":    steps,
                "ori_rewards":   ori,
                "mi_rewards":   mi,
                "naive_rewards": naive,
                "gold_answer":   gold_ans,
            }
            dataset.append(entry)
        return dataset


In [None]:
config = PRMConfig()
mcrs = MCRewardShaped(config, model, tokenizer)
mcrs.gsm8k_mi_reward_dataset(start=200, take=2)

In [9]:
config = PRMConfig()
mcrs = MCRewardShaped(config, model, tokenizer)
mcrs.math_mi_reward_dataset(start=200, take=2)

Building MATH reward-dataset:   0%|          | 0/2 [00:00<?, ?it/s]

[1-th Step, 0-th Original Rollout] 89

Wait, but the user provided a current solution with Step 1, but in their example they continue with Step 2 etc. So in this problem, the user has already started Step 1. Let me see:

The problem is "Find the largest prime factor of 9879". The user's current steps are:

Step 1: We see that

$$9879=10000-121=100^2-11^2$$Thus,

$$9879=(100-11)(100+11)=89(111)=3*37*89$$So the answer is $$

But the answer line is incomplete, perhaps the user wants us to continue from Step 1 and finish the steps properly. Wait, the user's instruction says "Continue and finish the solution: Step k+1: ...". So in the problem given, the user has provided some steps, and we need to continue from there. Let me parse:

Original problem: Find the largest prime factor of 9879.

Current steps provided by user:

Step 1: We see that

$$9879=100 Pred Answer 9879=100 Gold Answer 89
[1-th Step, 1-th Original Rollout] 89
Step 1: We see that

$$9879=10000-121=100^2-11^2$$Thus,

$$9879=(

Building MATH reward-dataset:  50%|█████     | 1/2 [01:15<01:15, 75.33s/it]

[1-th Step, 0-th Original Rollout]  For the expression to be an integer, the inner part $\sqrt{144 - \sqrt[3]{x}}$ must be an integer between 0 and 12 inclusive
Step 3: Let the inner square root equal to an integer $k$, so we have $\sqrt{144 - \sqrt[3]{x}} = k$ where $k$ is integer from 0 to 12
Step 4: Squaring both sides gives $144 - \sqrt[3]{x} = k^2$ → $\sqrt[3]{x} = 144 - k^2$
Step 5: To find $x$, cube both sides: $x = (144 - k^2)^3$
Step 6: Since $x$ must be non-negative real number, we must ensure that $144 - k^2$ is non-negative, so $k^2 \leq 144$ → $k \leq 12$
Step 7: The possible integer values of $k$ are 0,1,2,...,12. Each $k$ gives a distinct $x$ value
Step 8: However Pred Answer Step 8: However Gold Answer 13
[1-th Step, 1-th Original Rollout]  To have the entire expression be an integer, $\sqrt{144 - \sqrt[3]{x}}$ must equal an integer $k$ where $k \geq 0$.
Step 3: Setting up the equation $\sqrt{144 - \sqrt[3]{x}} = k$ implies $144 - \sqrt[3]{x} = k^2$.
Step 4: Rearranged,

Building MATH reward-dataset: 100%|██████████| 2/2 [07:38<00:00, 229.38s/it]


[{'question': 'Find the largest prime factor of $9879$.',
  'completion': ['Step 1: We see that\n\n$$9879=10000-121=100^2-11^2$$Thus,\n\n$$9879=(100-11)(100+11)=89(111)=3*37*89$$So the answer is $$'],
  'ori_rewards': [0.875],
  'mi_rewards': [-0.19369802474975595],
  'naive_rewards': [0.875],
  'gold_answer': '89'},
 {'question': 'For how many non-negative real values of $x$ is $\\sqrt{144-\\sqrt[3]{x}}$ an integer?',
  'completion': ['Step 1: If we look at the smallest possible value for $x$, namely $x=0$, then the expression evaluates to $\\sqrt{144}=12$',
   'Step 2: If we choose $x=144^3$ so that $\\sqrt[3]{x}=144$, and then the expression evaluates to $\\sqrt{144-144}=0$',
   'Step 3: Similarly, values of $x$ may be chosen so the expression evaluates to any integer between 0 to 12',
   'Step 4: For example, if we choose $x=143^3$ so that $\\sqrt[3]{x}=143$, the expression evaluates to $\\sqrt{144-143}=1$',
   'Step 5: Thus, there are a total of $12-0+1=$ values of $x$'],
  'ori_r

In [54]:
ds = load_dataset("HuggingFaceTB/MATH", "all", split="train")
sample = ds.select(range(10, 12)) # ['problem', 'level', 'type', 'solution']
sample[1]['solution'] 

'We substitute $f(2) = 5(2)^2 - \\frac{1}{2} + 3 = \\frac{45}{2}$ and $g(2) = (2)^2 - k = 4 - k$. So $f(2) - g(2) = 2$ gives us $\\frac{45}{2} - 4 + k=2$. Solving for $k$, we find $k = \\frac{4}{2} - \\frac{45}{2} + \\frac{8}{2}$ so $\\boxed{k = \\frac{-33}{2}}$.'