In [6]:
import torch
import torch.nn as nn
import torch.optim as optim
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig, PreTrainedTokenizer
from collections import defaultdict, deque
import math
import logging
import json
from typing import Dict, List, Tuple, Optional, Union
from dataclasses import dataclass
import numpy as np
from torch.utils.data import Dataset, DataLoader
import concurrent.futures
from tqdm import tqdm, trange
import wandb
import os
from pathlib import Path
from collections import deque
from torch.utils.data import Dataset
from typing import List, Tuple, Dict, Any, Optional
from collections import Counter
import random
import string
import re

# Enumerate GPUs in PCI bus order and expose only index "2" to this process.
# These variables generally only take effect if CUDA has not been initialised
# yet in this kernel (e.g. on a fresh Restart & Run All).
os.environ.update(
    {
        "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
        "CUDA_VISIBLE_DEVICES": "2",
    }
)
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")


# Config

In [7]:
@dataclass
class PRMConfig:
    """Configuration for PRM hyperparameters and settings.

    Declared as a dataclass so any field can be overridden per run, e.g.
    ``PRMConfig(epochs=5, use_wandb=False)``. ``PRMConfig()`` yields the
    defaults below, and the defaults remain readable on the class itself
    (``PRMConfig.batch_size``).
    """
    # MC config
    model_name:             str = "Qwen/Qwen2.5-Math-7B"
    max_new_tokens:         int = 384    # generation budget per sampled solution / rollout
    num_rollouts:           int = 8      # rollouts per step prefix (reward = fraction correct)
    reward_threshold:       float = 0.5
    samples_per_question:   int = 1      # solutions sampled per question in MCReward.build_datasets
    use_llm:                bool = True  # mask steps with the LLM (True) or with a number regex (False)
    use_contri:             bool = False
    # PRM Model config
    hidden_size:        int = 512
    num_layers:         int = 3
    dropout:            float = 0.2
    # PRMTrainer config
    batch_size:         int = 12
    learning_rate:      float = 5e-4
    num_workers:        int = 4
    weight_decay:       float = 1e-2
    lr_scheduler:       str   = "cosine"
    dataset_size:       int = 0
    warmup_steps:       int   = 22
    grad_clip:          float = 1.0
    epochs:             int = 25
    # Misc config
    use_wandb:          bool = True
    wandb_project:      str = "mc_prm"
    run_name:           str = "test_gsm8k_100_ori_mse"
    checkpoint_dir:     str = "./checkpoints/gsm8k/ori_mse"
    seed:               int = 42         # also seeds GSM8K shuffling in MCReward.build_datasets_gsm8k

# MC Rewards

## Utility

In [8]:
################################################################################
#                        UTILITY: ANSWER NORMALISATION                         #
################################################################################
import sympy as sp

def _strip_markup(ans: str) -> str:
    """Remove common LaTeX/markup & variable tags."""
    # # Remove LaTeX inline math wrappers \( … \) or \[ … \]
    # ans = re.sub(r"\\[\[(](.*?)[\\\])]", r"\1", ans)
    # # Remove \boxed{…}
    # ans = re.sub(r"\\boxed\{([^}]*)\}", r"\1", ans)
    ans = re.sub(r"\\\[.*?\\\]", "", ans)
    ans = re.sub(r"\$\$.*?\$\$", "", ans)
    # Remove inline LaTeX: \( ... \) and $...$
    ans = re.sub(r"\\\((.*?)\\\)", r"\1", ans)
    ans = re.sub(r"\$(.*?)\$", r"\1", ans)
    # Remove \boxed{...}
    ans = re.sub(r"\\boxed\s*{([^}]*)}", r"\1", ans)
    # Remove LaTeX commands like \text{...}, \frac{...}, etc.
    ans = re.sub(r"\\[a-zA-Z]+\s*(\{[^{}]*\})?", "", ans)
    # Remove variable assignments like "y =" or "x=" at start
    ans = re.sub(r"^[a-zA-Z]\s*=\s*", "", ans)
    # Trim outer $ … $ if present
    ans = ans.strip()
    if ans.startswith("$") and ans.endswith("$"):
        ans = ans[1:-1]
    return ans.strip()

def _sanitize(text: str) -> str:
    """Canonicalise a candidate answer string so equivalent answers compare equal.

    Markup is removed via `_strip_markup`; then surrounding whitespace and any
    trailing run of whitespace/``. ; : ,`` are dropped, and internal whitespace
    runs are collapsed to a single space.
    """
    cleaned = _strip_markup(text).strip()
    without_tail = re.sub(r"[\s\.;:,]+$", "", cleaned)
    return re.sub(r"\s+", " ", without_tail)

def _to_float(expr: str) -> Optional[float]:
    try:
        return float(eval(expr.replace("^", "**")))
    except Exception:
        return None

def _numeric_equiv(a: str, b: str) -> bool:
    """Return True if answers `a` and `b` are equivalent.

    Both strings are normalised with `_sanitize` and then compared by
    (1) exact string match, (2) numeric closeness (``rel_tol=1e-6``) when both
    evaluate with `_to_float`, and (3) SymPy: ``simplify(a - b) == 0``.

    `sympy.sympify` evaluates its input with `eval`, and `a` is typically
    model-generated text, so step (3) only runs when both strings are at most
    200 characters drawn from letters, digits, underscores, whitespace and
    ``. + - * / ^ ( ) !`` and contain no ``__`` (dunder access).
    NOTE(review): SymPy still evaluates integer powers eagerly, so an input such
    as ``9^9^9^9`` can stall in step (3); consider adding a timeout.
    """
    a_clean, b_clean = map(_sanitize, (a, b))
    if a_clean == b_clean:
        return True

    # Attempt simple numeric evaluation
    a_val, b_val = _to_float(a_clean), _to_float(b_clean)
    if a_val is not None and b_val is not None:
        return math.isclose(a_val, b_val, rel_tol=1e-6)

    def _sympy_safe(s: str) -> bool:
        # Whitelist for strings handed to the eval-based sympify.
        return (
            len(s) <= 200
            and "__" not in s
            and re.fullmatch(r"[\w\s.+\-*/^()!]+", s) is not None
        )

    if sp is not None and _sympy_safe(a_clean) and _sympy_safe(b_clean):
        try:
            a_expr = sp.sympify(a_clean.replace("^", "**"))
            b_expr = sp.sympify(b_clean.replace("^", "**"))
            return sp.simplify(a_expr - b_expr) == 0
        except Exception:
            pass
    return False

def system_prompt(type):
    """Return the system prompt for a generation mode.

    Args:
        type: ``"sample"`` for the prompt that asks the model to write a full
            step-by-step solution, or ``"rollout"`` for the prompt that asks it
            to continue a partial solution from the last given step.

    Returns:
        The prompt text, or an empty string for any other ``type``.
    """
    if type == "rollout":
        return """You are a math problem-solving expert. Continue solving the given problem step by step, strictly following the required format. Each new step must begin with \"Step k+1: ...\", \"Step k+2:...\", and so on, continuing from the last given step number. When the final answer is reached, write only one final line starting with: \"Answer: [Final Answer]\". Do not add any explanations, extra commentary, or additional text after the "Answer:" line. Your output must follow this exact step-by-step format with no deviations.

**Format Guide**: (You MUST write "Step " before numbering the step.)
Step 1: [Step 1 reasoning]\n
Step 2: [Step 2 reasoning]\n
...
Step k: [Step k reasoning]\n
Continue and finish the solution:
Step k+1: [Step k+1 reasoning]\n
...
Answer: [Final answer]

Format Guide with Examples:
<Example 1>
Current solution steps:
Problem: Find the sum of the first 8 positive even integers.
Step 1: The first 8 even integers are 2, 4, 6, 8, 10, 12, 14, 16.
Step 2: Use the formula for an arithmetic series: S = n·(first + last)/2.
Continue and finish the solution:
Step 3: Substitute n=8, first=2, last=16 to get S = 8·(2+16)/2 = 8·9 = 72.
Answer: 72

<Example 2>
Current solution steps:
Problem: Determine the next number in the sequence 2, 4, 8, 16.
Step 1: Notice each term is obtained by multiplying the previous term by 2.
Continue and finish the solution:
Step 2: Multiply 16 by 2, 16 * 2 = 32.
Answer: 32

Keep the reasoning steps precise and factual and complete the solution. Follow the FORMAT GUIDE structure exactly. **Do NOT** write anything after the final 'Answer:' line. Always start stepwise reasoning with "Step {i-th}: " form."""
    if type == "sample":
        return """You are a math-problem expert. Your task is to complete the step-by-step solution for the problem provided. Write each reasoning step on its own line in the exact form \"Step k: [your reasoning step]\n\", numbering start from Step 1. When the final answer is obtained, write exactly one final line, \"Answer: [Final answer]\". Do NOT add explanations, extra steps, or any text after the "Answer:" line.

**Format Guide**: (You MUST write "Step " before numbering the step.)
Step 1: [Step 1 reasoning]\n
Step 2: [Step 2 reasoning]\n
...
Step k: [Step k reasoning]\n
...
Answer: [Final answer]

Format Guide with Examples:
<Example 1>
Problem: Find the sum of the first 8 positive even integers.
Step 1: The first 8 even integers are 2, 4, 6, 8, 10, 12, 14, 16.
Step 2: Use the formula for an arithmetic series: S = n·(first + last)/2.
Step 3: Substitute n=8, first=2, last=16 to get S = 8·(2+16)/2 = 8·9 = 72.
Answer: 72

<Example 2>
Problem: Determine the next number in the sequence 2, 4, 8, 16.
Step 1: Notice each term is obtained by multiplying the previous term by 2.
Step 2: Multiply 16 by 2, 16 * 2 = 32.
Answer: 32

Follow the FORMAT GUIDE structure exactly. Generate rationales step-by-step, not directly to the final answer. **Do NOT** write anything after the final 'Answer:' line. Always start stepwise reasoning with "Step {i-th}: " form."""
    return ""


## Original Rewards

In [9]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import load_dataset
import re
from typing import List, Optional
from tqdm import tqdm
import random

# Project‑level helpers 
from utils import _sanitize, _numeric_equiv, _strip_markup, _to_float, system_prompt
from config import PRMConfig

class MCReward:
    STEP_PATTERN = re.compile(
    r"""^[\s>#*\-]*          # optional markdown/bullet symbols
        Step\s*              # word 'Step' (case-insensitive)
        (\d+)                # capture step number
        \s*[:.\-]            # separator (: . or -)
    """,
    re.IGNORECASE | re.VERBOSE,
    )
    ANSWER_PATTERN = re.compile(
        r"""^[\s>#*\-]*          # optional markdown/bullet symbols
            Answer               # word 'Answer'
            \s*[:.\-]\s*         # separator
            (.+?)\s*$            # capture everything after
        """,
        re.IGNORECASE | re.MULTILINE | re.VERBOSE,
    )
    ## Masked rewards ##
    OP_TOKENS = ["add", "plus", "sum", "subtract", "minus",
             "multiply", "times", "product", "divide", "quotient"]
    _MASK_PATTERN = re.compile(
        r"""
        (?:
        # {ops_pattern}|                # operator patterns
            \b\d+(?:\.\d+)?\b         # integers / decimals
          | \b\d+/\d+\b                 # simple fractions
        #   | \b[a-zA-Z]\b                 # single‑letter variables
        )
        """,
        re.VERBOSE,
    )

    def __init__(self, config: "PRMConfig", model, tokenizer):
        self.config = config
        self.model = model
        self.tokenizer = tokenizer
        self.device = next(model.parameters()).device

    # Function to generate one or more step-by-step solutions for a given question.
    def generate_solutions(self, question: str, sys_prompt: str, num_solutions: int):
        prompt = f"{sys_prompt}\n\n{question}\n"  # Prompt the model to start the step-by-step solution
        input_ids = self.tokenizer.encode(prompt, return_tensors='pt').to(self.device)
        # Generate multiple solutions via sampling
        outputs = self.model.generate(
            input_ids,
            max_new_tokens=self.config.max_new_tokens,
            do_sample=True,
            num_return_sequences=num_solutions,
            temperature=0.8,         # sampling temperature for diversity (adjust as needed)
            top_p=0.8,               # top-p sampling for diversity
            pad_token_id=self.tokenizer.eos_token_id  # pad token ID to avoid warning for some models
        )
        solutions = []
        prompt_len = input_ids.shape[-1]
        for i in range(num_solutions):
            # Each output is the concatenation of the prompt and the generated completion.
            generated_ids = outputs[i]
            # Extract only the newly generated tokens (skip the prompt tokens).
            gen_ids = generated_ids[prompt_len:]
            text = self.tokenizer.decode(gen_ids, skip_special_tokens=True)
            solutions.append(text)
            # print(f"{i}-th Sampled Solutions:",text)
        return solutions
    
    def gsm8k_solutions(self, question: str, gold_solution: str):
        # 1. Split lines *before* the final answer marker (#### …)
        lines: List[str] = []
        gold_answer: str = ""
        _ANSWER_RE = re.compile(r"####\s*(.+?)\s*$")

        for raw_ln in gold_solution.splitlines():
            ln = raw_ln.strip()
            if not ln:
                continue  # skip empty
            ans_match = _ANSWER_RE.match(ln)
            if ans_match:
                gold_answer = ans_match.group(1).strip()
                break  # everything after #### is ignored
            lines.append(ln)

        if not gold_answer:
            raise ValueError("Could not find final answer marker '#### <answer>' in gold_solution.")

        # 2. Prefix each explanatory line with "Step i:"
        solution_steps = [f"Step {i + 1}: {txt}" for i, txt in enumerate(lines)]
        return {
            "question": question,
            "solution": solution_steps,
            "gold_answer": gold_answer,
        }

    # Function to parse a solution text into steps and final answer.
    def _extract_answer(self, text: str) -> Optional[str]:
        """Try multiple heuristics / regexes to pull out an answer string."""
        # Primary regex (robust to Answer:, Answer ‑, etc.)
        match = self.ANSWER_PATTERN.search(text)
        if match:
            return _sanitize(match.group(1))
        
        # Fallback 1: last non‑empty line if it looks simple / numeric
        lines = [ln.strip() for ln in text.splitlines() if ln.strip()]
        if lines:
            candidate = lines[-1]
            if re.search(r"\d", candidate):  # contains digit
                return _sanitize(candidate)

        # Fallback 2: look for last line that starts with 'Answer'
        for line in reversed(text.splitlines()):
            if line.strip().lower().startswith("answer"):
                return _sanitize(line.split("Answer", 1)[-1])
        
        return None

    def parse_solution(self, solution_text: str):
        """Split each step to start with 'Step X:' and the answer to start with 'Answer:'."""
        steps = []
        # Split by lines to identify steps and answer
        for line in solution_text.splitlines():
            line = line.strip()
            if not line:
                continue
            if self.STEP_PATTERN.match(line):
                cleaned = re.sub(r'^[\s>#*\-]+', '', line)
                steps.append(cleaned)
            answer = self._extract_answer(solution_text)
        return steps, answer
    
    # Function to estimate intermediate rewards for each step via rollouts.
    def compute_step_rewards(self, question, sys_prompt, steps, gold_answer):
        """
        For each prefix ending at a given step in 'steps', generate rollouts and compute the reward 
        (fraction of rollouts ending in the correct answer). Returns a list of reward values corresponding to each step.
        """
        rewards = []
        total_steps = len(steps)

        # Pre‑encode static prefix (sys_prompt + question) once for efficiency
        base_prompt = f"{sys_prompt}\n\nProblem: {question}\n"
        base_ids = self.tokenizer.encode(base_prompt, return_tensors="pt").to(self.device)

        for i in range(total_steps):
            prefix_tokens = self.tokenizer.encode("\n".join(steps[: i + 1]) + "\n", return_tensors="pt").to(self.device) # steps up to current step i (0-indexed)
            # Decide how to prompt the next part:
            if i < total_steps - 1:
                next_label = f"Step {i + 2}:"
            else:
                next_label = "Answer:"
            cont_ids = self.tokenizer.encode(next_label, return_tensors="pt").to(self.device)
            # Build full prefix ids (avoid Python concat inefficiency by cat)
            prefix_ids = torch.cat([base_ids, prefix_tokens, cont_ids], dim=-1)
            rollout_outputs = self.model.generate(
                prefix_ids,
                max_new_tokens=self.config.max_new_tokens,
                do_sample=True,
                num_return_sequences=self.config.num_rollouts,
                temperature=0.8,
                top_p=0.8,
                pad_token_id=self.tokenizer.eos_token_id
            )
            new_token_start = prefix_ids.shape[-1] 
            # Check each rollout's final answer against the gold answer
            correct_count = 0
            for idx, seq in enumerate(rollout_outputs):
                completion = self.tokenizer.decode(seq[new_token_start:], skip_special_tokens=True)
                pred_answer = self._extract_answer(completion)
                print(f"[{i+1}-th Step, {idx}-th Original Rollout]", completion, "Pred Answer", pred_answer)
                if pred_answer is not None and _numeric_equiv(pred_answer, gold_answer):
                    correct_count += 1
            reward = correct_count / float(self.config.num_rollouts)
            rewards.append(reward)
        return rewards
    
    # Masked solution paths
    def model_masking(self, text: str, *, max_new_tokens: int = 64) -> str:
        prompt = "In the sentence below, mask any word or expression that seems crucial for solving the math step. This may include key numbers, variables, or action words (like operations), but you should decide what matters. Replace each important item with '[MASKED]'. Keep everything else unchanged. Return ONE line.\n\nSentence: \"{sent}\"\nRewritten:".format(sent=text)
        input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(self.device)
        out_ids   = self.model.generate(
            input_ids,
            max_new_tokens=max_new_tokens,
            temperature=0.2, top_p=0.2,
            pad_token_id=self.tokenizer.eos_token_id,
        )
        return self.tokenizer.decode(out_ids[0][input_ids.shape[-1]:],
                                     skip_special_tokens=True).strip()

    def perturbed_step_rewards(self, question: str, sys_prompt: str, steps: List[str], gold_answer: str, use_llm: bool = True) -> List[float]:
        """Compute MC correctness rates *after masking* the current step.
        Each step `i` is replaced with a *perturbed* version where important
        tokens (numbers, fractions, single‑letter variables) are substituted by
        the literal string ``[MASKED]``. All preceding steps remain intact.
        """
        ptb_rewards: List[float] = []
        total_steps = len(steps)
        base_prompt = f"{sys_prompt}\n\nProblem: {question}\n"
        base_ids = self.tokenizer.encode(base_prompt, return_tensors="pt").to(self.device)

        for i in range(total_steps):
            # 1. Perturb *only* step i
            orig_step = steps[i] 
            step_match = re.match(r"^[\s>#*\-]*Step\s*\d+\s*[:.\-]\s*", orig_step, flags=re.I)
            prefix = step_match.group(0) if step_match else ""
            # ② 나머지 부분(body)만 마스킹
            body   = steps[i][len(prefix):]                       # 접두사 뒷부분
            if use_llm:
                masked_body = self.model_masking(body)
            else:
                masked_body = self._MASK_PATTERN.sub("[MASKED]", body)
            # ③ 접두사 + 마스킹된 body
            masked_step = prefix + masked_body    
            ptb_prefix_steps = steps[:i] + [masked_step]
            # print("perturbed step:", ptb_prefix_steps)

            prefix_tokens = self.tokenizer.encode("\n".join(ptb_prefix_steps) + "\n", return_tensors="pt").to(self.device)
            next_label = f"Step {i + 2}:" if i < total_steps - 1 else "Answer:"
            cont_ids = self.tokenizer.encode(next_label, return_tensors="pt").to(self.device)
            prefix_ids = torch.cat([base_ids, prefix_tokens, cont_ids], dim=-1)

            rollout_outputs = self.model.generate(
                prefix_ids,
                max_new_tokens=self.config.max_new_tokens,
                do_sample=True,
                num_return_sequences=self.config.num_rollouts,
                temperature=0.8,
                top_p=0.8,
                pad_token_id=self.tokenizer.eos_token_id,
            )
            new_token_start = prefix_ids.shape[-1]
            correct_count = 0
            for idx, seq in enumerate(rollout_outputs):
                completion = self.tokenizer.decode(seq[new_token_start:], skip_special_tokens=True)
                pred_answer = self._extract_answer(completion)
                print(f"Masked [{i+1}-th Step, {idx}-th Rollout]", completion, "Pred Answer", pred_answer)
                if pred_answer is not None and _numeric_equiv(pred_answer, gold_answer):
                    correct_count += 1
            ptb_rewards.append(correct_count / float(self.config.num_rollouts))
        return ptb_rewards

    # Build datasets based on input datas
    def build_datasets(self, problems: List):
        dataset = []  # will hold the output list of dicts
        for problem in problems:
            question = problem["question"]
            # gold_answer = problem["gold_answer"]
            gold_answer = _sanitize(problem["gold_answer"])
            # Generate one or more solutions for this question
            sample_prompt = system_prompt("sample")
            rollout_prompt = system_prompt("rollout")
            solutions = self.generate_solutions(question, sys_prompt=sample_prompt, num_solutions=self.config.samples_per_question)
            
            for sol_text in solutions:
                steps, answer = self.parse_solution(sol_text)
                # print("Parsed solution:", steps, answer)
                if answer is None: # If no answer was found in the solution (edge case), skip this solution
                    continue
                # 2. Compute *original* & *perturbed* per‑step rewards
                # ----------------------------------------------------------
                ori_rewards = self.compute_step_rewards(
                    question=question,
                    sys_prompt=rollout_prompt,
                    steps=steps,
                    gold_answer=gold_answer,
                )
                ptb_rewards = self.perturbed_step_rewards(
                    question=question,
                    sys_prompt=rollout_prompt,
                    steps=steps,
                    gold_answer=gold_answer,
                )
                # Align lengths (robustness)
                if len(ptb_rewards) != len(ori_rewards):
                    ptb_rewards = ptb_rewards[: len(ori_rewards)]
                # contributions = [max(0, o - p) for o, p in zip(ori_rewards, ptb_rewards)]
                contributions = [o - p for o, p in zip(ori_rewards, ptb_rewards)]
                entry = {
                    "question": question,
                    "completion": steps,          # list[str] (Step i: ...)
                    "ori_rewards": ori_rewards,    # list[float]
                    "ptb_rewards": ptb_rewards,    # list[float]
                    "contributions": contributions,  # ori − ptb
                    "answer": answer,
                    "gold_answer": gold_answer,
                }
                dataset.append(entry)
        return dataset
    
    # Build datasets based on input datas
    def build_datasets_gsm8k(self, *, split: str = "train", start: int = 0, take: int | None):
        _ANSWER_RE = re.compile(r"####\s*(.+?)\s*$")

        rollout_pr = system_prompt("rollout")
        ds = load_dataset("openai/gsm8k", "main", split=split)
        if take is not None:
            ds = ds.shuffle(seed=self.config.seed).select(range(start, start+take))

        csr, psr   = self.compute_step_rewards, self.perturbed_step_rewards
        sanitize   = _sanitize
        use_llm    = self.config.use_llm
        dataset    = []
        
        for sample in tqdm(ds, desc="Building GSM-8K reward-dataset"):
            # ── (1) extract step solutions ──────────────────────────────────────────
            q_txt   = sample["question"]
            g_sol   = sample["answer"]
            lines, gold_ans = [], None
            for ln in g_sol.splitlines():
                ln = ln.strip()
                if not ln:
                    continue
                m = _ANSWER_RE.match(ln)
                if m:
                    gold_ans = sanitize(m.group(1))
                    break
                lines.append(ln)
            if gold_ans is None:
                raise ValueError("gold answer not found for sample")
            steps = [f"Step {i+1}: {t}" for i, t in enumerate(lines)]

            # ── (2) compute rewards ───────────────────────────────────────────────────
            ori = csr(q_txt, rollout_pr, steps, gold_ans)
            ptb = psr(q_txt, rollout_pr, steps, gold_ans, use_llm)
            if len(ptb) != len(ori):
                ptb = ptb[: len(ori)]
            contrib = [round(o - p, 4) for o, p in zip(ori, ptb)]

            #  ── (3) Append entry ───────────────────────────────────────────
            entry = {
                    "question":      q_txt,
                    "completion":    steps,
                    "ori_rewards":   ori,
                    "ptb_rewards":   ptb,
                    "contributions": contrib,
                    "answer":        gold_ans,
                    "gold_answer":   gold_ans,
                }
            dataset.append(entry)
            # print(entry)
        return dataset

    def build_datasets_math(self, *, split: str = "train", start: int = 0, take: int | None):
        """
        ① MATH 데이터셋 로드 → ② 정답·스텝 추출 → ③ 보상 계산 → ④ dict 리스트 반환
        """
        boxed_re   = re.compile(r'\\boxed\{(.+?)\}', re.S)
        sent_split = re.compile(r'\.(?!\d)(?=\s|$)')   # 소수점·수식 내부 마침표 무시

        rollout_prompt = system_prompt("rollout")
        ds = load_dataset("HuggingFaceTB/MATH", "all", split=split)

        # shuffle & take
        if take is not None:
            ds = ds.select(range(start, start+take))

        # (alias) time optimize
        csr, psr   = self.compute_step_rewards, self.perturbed_step_rewards
        sanitize   = _sanitize
        use_llm    = self.config.use_llm
        dataset    = []

        for sample in tqdm(ds, desc="Building MATH reward-dataset"):
            # ── (1) extract step solutions ──────────────────────────────────────────
            full_sol   = sample["solution"]
            m          = boxed_re.search(full_sol)
            gold_ans   = sanitize(m.group(1)) if m else None
            sol_wo_box = boxed_re.sub("", full_sol)
            raw_steps  = [s.strip() for s in sent_split.split(sol_wo_box) if s.strip()]
            steps      = [f"Step {i+1}: {s}" for i, s in enumerate(raw_steps)]

            # ── (2) compute rewards ───────────────────────────────────────────────────
            ori = csr(sample["problem"], rollout_prompt, steps, gold_ans)
            ptb = psr(sample["problem"], rollout_prompt, steps, gold_ans, use_llm)
            if len(ptb) != len(ori):
                ptb = ptb[: len(ori)]
            contrib = [round(o - p, 4) for o, p in zip(ori, ptb)]

            # ── (3) Append entry ───────────────────────────────────────────
            entry = {
                "question":      sample["problem"],
                "completion":    steps,
                "ori_rewards":   ori,
                "ptb_rewards":   ptb,
                "contributions": contrib,
                "answer":        gold_ans,
                "gold_answer":   gold_ans,
                "level":         sample["level"],
                "type":          sample["type"],
            }
            dataset.append(entry)
            # print(entry)
        return dataset



In [16]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import load_dataset
import re
from typing import List, Optional
from tqdm import tqdm
import random
from vllm import LLM, SamplingParams

# Project‑level helpers 
from utils import _sanitize, _numeric_equiv, _strip_markup, _to_float, system_prompt
from config import PRMConfig

class MCReward:
    STEP_PATTERN = re.compile(
    r"""^[\s>#*\-]*          # optional markdown/bullet symbols
        Step\s*              # word 'Step' (case-insensitive)
        (\d+)                # capture step number
        \s*[:.\-]            # separator (: . or -)
    """,
    re.IGNORECASE | re.VERBOSE,
    )
    ANSWER_PATTERN = re.compile(
        r"""^[\s>#*\-]*          # optional markdown/bullet symbols
            Answer               # word 'Answer'
            \s*[:.\-]\s*         # separator
            (.+?)\s*$            # capture everything after
        """,
        re.IGNORECASE | re.MULTILINE | re.VERBOSE,
    )
    ## Masked rewards ##
    OP_TOKENS = ["add", "plus", "sum", "subtract", "minus",
             "multiply", "times", "product", "divide", "quotient"]
    _MASK_PATTERN = re.compile(
        r"""
        (?:
        # {ops_pattern}|                # operator patterns
            \b\d+(?:\.\d+)?\b         # integers / decimals
          | \b\d+/\d+\b                 # simple fractions
        #   | \b[a-zA-Z]\b                 # single‑letter variables
        )
        """,
        re.VERBOSE,
    )

    def __init__(self, config: "PRMConfig", model, tokenizer):
        self.config = config
        self.model = model
        self.tokenizer = tokenizer
        self.device = next(model.parameters()).device
        self.llm = LLM(
            model=config.model_name,     # 동일 체크포인트
            dtype="float16",                 # fp16 / bfloat16
            tensor_parallel_size=torch.cuda.device_count(),  # 여러 GPU → 자동 shard
            trust_remote_code=True,
        )
        self.sparams = SamplingParams(
            max_tokens=config.max_new_tokens,
            temperature=0.8,
            top_p=0.8,
            n=config.num_rollouts,          # 한번에 num_rollouts 샘플
        )
        

    # Function to generate one or more step-by-step solutions for a given question.
    def generate_solutions(self, question: str, sys_prompt: str, num_solutions: int):
        prompt = f"{sys_prompt}\n\n{question}\n"  # Prompt the model to start the step-by-step solution
        input_ids = self.tokenizer.encode(prompt, return_tensors='pt').to(self.device)
        # Generate multiple solutions via sampling
        outputs = self.model.generate(
            input_ids,
            max_new_tokens=self.config.max_new_tokens,
            do_sample=True,
            num_return_sequences=num_solutions,
            temperature=0.8,         # sampling temperature for diversity (adjust as needed)
            top_p=0.8,               # top-p sampling for diversity
            pad_token_id=self.tokenizer.eos_token_id  # pad token ID to avoid warning for some models
        )
        solutions = []
        prompt_len = input_ids.shape[-1]
        for i in range(num_solutions):
            # Each output is the concatenation of the prompt and the generated completion.
            generated_ids = outputs[i]
            # Extract only the newly generated tokens (skip the prompt tokens).
            gen_ids = generated_ids[prompt_len:]
            text = self.tokenizer.decode(gen_ids, skip_special_tokens=True)
            solutions.append(text)
            # print(f"{i}-th Sampled Solutions:",text)
        return solutions
    
    def gsm8k_solutions(self, question: str, gold_solution: str):
        # 1. Split lines *before* the final answer marker (#### …)
        lines: List[str] = []
        gold_answer: str = ""
        _ANSWER_RE = re.compile(r"####\s*(.+?)\s*$")

        for raw_ln in gold_solution.splitlines():
            ln = raw_ln.strip()
            if not ln:
                continue  # skip empty
            ans_match = _ANSWER_RE.match(ln)
            if ans_match:
                gold_answer = ans_match.group(1).strip()
                break  # everything after #### is ignored
            lines.append(ln)

        if not gold_answer:
            raise ValueError("Could not find final answer marker '#### <answer>' in gold_solution.")

        # 2. Prefix each explanatory line with "Step i:"
        solution_steps = [f"Step {i + 1}: {txt}" for i, txt in enumerate(lines)]
        return {
            "question": question,
            "solution": solution_steps,
            "gold_answer": gold_answer,
        }

    # Function to parse a solution text into steps and final answer.
    def _extract_answer(self, text: str) -> Optional[str]:
        """Try multiple heuristics / regexes to pull out an answer string."""
        # Primary regex (robust to Answer:, Answer ‑, etc.)
        match = self.ANSWER_PATTERN.search(text)
        if match:
            return _sanitize(match.group(1))
        
        # Fallback 1: last non‑empty line if it looks simple / numeric
        lines = [ln.strip() for ln in text.splitlines() if ln.strip()]
        if lines:
            candidate = lines[-1]
            if re.search(r"\d", candidate):  # contains digit
                return _sanitize(candidate)

        # Fallback 2: look for last line that starts with 'Answer'
        for line in reversed(text.splitlines()):
            if line.strip().lower().startswith("answer"):
                return _sanitize(line.split("Answer", 1)[-1])
        
        return None

    def parse_solution(self, solution_text: str):
        """Split each step to start with 'Step X:' and the answer to start with 'Answer:'."""
        steps = []
        # Split by lines to identify steps and answer
        for line in solution_text.splitlines():
            line = line.strip()
            if not line:
                continue
            if self.STEP_PATTERN.match(line):
                cleaned = re.sub(r'^[\s>#*\-]+', '', line)
                steps.append(cleaned)
            answer = self._extract_answer(solution_text)
        return steps, answer
    
    # Function to estimate intermediate rewards for each step via rollouts.
    def _generate_rollouts(self, prompt: str) -> list[str]:
        """
        vLLM 에서 동일 프롬프트를 n번(=num_rollouts) 샘플링해서 텍스트만 반환
        """
        outs = self.llm.generate([prompt], self.sparams)   # 배치 길이 1
        return [o.outputs[ri].text for o in outs for ri in range(len(o.outputs))]

    def compute_step_rewards(self, question, sys_prompt, steps, gold_answer):
        """
        For each prefix ending at a given step in 'steps', generate rollouts and compute the reward 
        (fraction of rollouts ending in the correct answer). Returns a list of reward values corresponding to each step.
        """
        rewards = []
        total_steps = len(steps)

        # Pre‑encode static prefix (sys_prompt + question) once for efficiency
        base_prompt = f"{sys_prompt}\n\nProblem: {question}\n"
        base_ids = self.tokenizer.encode(base_prompt, return_tensors="pt").to(self.device)

        for i in range(total_steps):
            step_prefix = "\n".join(steps[:i+1]) + "\n"
            next_label  = f"Step {i+2}:" if i < total_steps-1 else "Answer:"
            prompt = f"{base_prompt}{step_prefix}{next_label}"
            rollout_outputs = self._generate_rollouts(prompt)  
            print("Original rollout outputs:", rollout_outputs)

            correct_count = 0
            for idx, seq in enumerate(rollout_outputs):
                pred_answer = self._extract_answer(seq)
                print(f"[{i+1}-th Step, {idx}-th Original Rollout]", seq, "Pred Answer", pred_answer)
                if pred_answer and _numeric_equiv(pred_answer, gold_answer):
                    correct_count += 1
            reward = correct_count / float(self.config.num_rollouts)
            rewards.append(reward)
        return rewards
    
    # Masked solution paths
    def model_masking(self, text: str, *, max_new_tokens: int = 64) -> str:
        prompt = "In the sentence below, mask any word or expression that seems crucial for solving the math step. This may include key numbers, variables, or action words (like operations), but you should decide what matters. Replace each important item with '[MASKED]'. Keep everything else unchanged. Return ONE line.\n\nSentence: \"{sent}\"\nRewritten:".format(sent=text)
        input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(self.device)
        out_ids   = self.model.generate(
            input_ids,
            max_new_tokens=max_new_tokens,
            temperature=0.2, top_p=0.2,
            pad_token_id=self.tokenizer.eos_token_id,
        )
        return self.tokenizer.decode(out_ids[0][input_ids.shape[-1]:],
                                     skip_special_tokens=True).strip()

    def perturbed_step_rewards(self, question: str, sys_prompt: str, steps: List[str], gold_answer: str, use_llm: bool = True) -> List[float]:
        ptb_rewards: List[float] = []
        total_steps = len(steps)
        base_prompt = f"{sys_prompt}\n\nProblem: {question}\n"

        for i in range(total_steps):
            # 1. Perturb *only* step i
            orig_step = steps[i] 
            step_match = re.match(r"^[\s>#*\-]*Step\s*\d+\s*[:.\-]\s*", orig_step, flags=re.I)
            prefix = step_match.group(0) if step_match else ""
            # ② 나머지 부분(body)만 마스킹
            body   = steps[i][len(prefix):]                       # 접두사 뒷부분
            if use_llm:
                masked_body = self.model_masking(body)
            else:
                masked_body = self._MASK_PATTERN.sub("[MASKED]", body)
            # ③ 접두사 + 마스킹된 body
            masked_step = prefix + masked_body    
            ptb_prefix_steps = steps[:i] + [masked_step]
            # print("perturbed step:", ptb_prefix_steps)

            step_prefix = "\n".join(ptb_prefix_steps) + "\n"
            next_label = f"Step {i + 2}:" if i < total_steps - 1 else "Answer:"
            prompt = f"{base_prompt}{step_prefix}{next_label}"
            rollout_outputs = self._generate_rollouts(prompt)  
            print("Masked rollout outputs:", rollout_outputs)
            
            correct_count = 0
            for idx, seq in enumerate(rollout_outputs):
                pred_answer = self._extract_answer(seq)
                print(f"Masked [{i+1}-th Step, {idx}-th Rollout]", seq, "Pred Answer", pred_answer)
                if pred_answer is not None and _numeric_equiv(pred_answer, gold_answer):
                    correct_count += 1
            ptb_rewards.append(correct_count / float(self.config.num_rollouts))
        return ptb_rewards

    # Build datasets based on input datas
    def build_datasets(self, problems: List):
        dataset = []  # will hold the output list of dicts
        for problem in problems:
            question = problem["question"]
            # gold_answer = problem["gold_answer"]
            gold_answer = _sanitize(problem["gold_answer"])
            # Generate one or more solutions for this question
            sample_prompt = system_prompt("sample")
            rollout_prompt = system_prompt("rollout")
            solutions = self.generate_solutions(question, sys_prompt=sample_prompt, num_solutions=self.config.samples_per_question)
            
            for sol_text in solutions:
                steps, answer = self.parse_solution(sol_text)
                # print("Parsed solution:", steps, answer)
                if answer is None: # If no answer was found in the solution (edge case), skip this solution
                    continue
                # 2. Compute *original* & *perturbed* per‑step rewards
                # ----------------------------------------------------------
                ori_rewards = self.compute_step_rewards(
                    question=question,
                    sys_prompt=rollout_prompt,
                    steps=steps,
                    gold_answer=gold_answer,
                )
                ptb_rewards = self.perturbed_step_rewards(
                    question=question,
                    sys_prompt=rollout_prompt,
                    steps=steps,
                    gold_answer=gold_answer,
                )
                # Align lengths (robustness)
                if len(ptb_rewards) != len(ori_rewards):
                    ptb_rewards = ptb_rewards[: len(ori_rewards)]
                # contributions = [max(0, o - p) for o, p in zip(ori_rewards, ptb_rewards)]
                contributions = [o - p for o, p in zip(ori_rewards, ptb_rewards)]
                entry = {
                    "question": question,
                    "completion": steps,          # list[str] (Step i: ...)
                    "ori_rewards": ori_rewards,    # list[float]
                    "ptb_rewards": ptb_rewards,    # list[float]
                    "contributions": contributions,  # ori − ptb
                    "answer": answer,
                    "gold_answer": gold_answer,
                }
                dataset.append(entry)
        return dataset
    
    # Build datasets based on input datas
    def build_datasets_gsm8k(self, *, split: str = "train", start: int = 0, take: int | None):
        _ANSWER_RE = re.compile(r"####\s*(.+?)\s*$")

        rollout_pr = system_prompt("rollout")
        ds = load_dataset("openai/gsm8k", "main", split=split)
        if take is not None:
            ds = ds.shuffle(seed=self.config.seed).select(range(start, start+take))

        csr, psr   = self.compute_step_rewards, self.perturbed_step_rewards
        sanitize   = _sanitize
        use_llm    = self.config.use_llm
        dataset    = []
        
        for sample in tqdm(ds, desc="Building GSM-8K reward-dataset"):
            # ── (1) extract step solutions ──────────────────────────────────────────
            q_txt   = sample["question"]
            g_sol   = sample["answer"]
            lines, gold_ans = [], None
            for ln in g_sol.splitlines():
                ln = ln.strip()
                if not ln:
                    continue
                m = _ANSWER_RE.match(ln)
                if m:
                    gold_ans = sanitize(m.group(1))
                    break
                lines.append(ln)
            if gold_ans is None:
                raise ValueError("gold answer not found for sample")
            steps = [f"Step {i+1}: {t}" for i, t in enumerate(lines)]

            # ── (2) compute rewards ───────────────────────────────────────────────────
            ori = csr(q_txt, rollout_pr, steps, gold_ans)
            ptb = psr(q_txt, rollout_pr, steps, gold_ans, use_llm)
            if len(ptb) != len(ori):
                ptb = ptb[: len(ori)]
            contrib = [round(o - p, 4) for o, p in zip(ori, ptb)]

            #  ── (3) Append entry ───────────────────────────────────────────
            entry = {
                    "question":      q_txt,
                    "completion":    steps,
                    "ori_rewards":   ori,
                    "ptb_rewards":   ptb,
                    "contributions": contrib,
                    "answer":        gold_ans,
                    "gold_answer":   gold_ans,
                }
            dataset.append(entry)
            # print(entry)
        return dataset

    def build_datasets_math(self, *, split: str = "train", start: int = 0, take: int | None):
        """
        ① MATH 데이터셋 로드 → ② 정답·스텝 추출 → ③ 보상 계산 → ④ dict 리스트 반환
        """
        boxed_re   = re.compile(r'\\boxed\{(.+?)\}', re.S)
        sent_split = re.compile(r'\.(?!\d)(?=\s|$)')   # 소수점·수식 내부 마침표 무시

        rollout_prompt = system_prompt("rollout")
        ds = load_dataset("HuggingFaceTB/MATH", "all", split=split)

        # shuffle & take
        if take is not None:
            ds = ds.select(range(start, start+take))

        # (alias) time optimize
        csr, psr   = self.compute_step_rewards, self.perturbed_step_rewards
        sanitize   = _sanitize
        use_llm    = self.config.use_llm
        dataset    = []

        for sample in tqdm(ds, desc="Building MATH reward-dataset"):
            # ── (1) extract step solutions ──────────────────────────────────────────
            full_sol   = sample["solution"]
            m          = boxed_re.search(full_sol)
            gold_ans   = sanitize(m.group(1)) if m else None
            sol_wo_box = boxed_re.sub("", full_sol)
            raw_steps  = [s.strip() for s in sent_split.split(sol_wo_box) if s.strip()]
            steps      = [f"Step {i+1}: {s}" for i, s in enumerate(raw_steps)]

            # ── (2) compute rewards ───────────────────────────────────────────────────
            ori = csr(sample["problem"], rollout_prompt, steps, gold_ans)
            ptb = psr(sample["problem"], rollout_prompt, steps, gold_ans, use_llm)
            if len(ptb) != len(ori):
                ptb = ptb[: len(ori)]
            contrib = [round(o - p, 4) for o, p in zip(ori, ptb)]

            # ── (3) Append entry ───────────────────────────────────────────
            entry = {
                "question":      sample["problem"],
                "completion":    steps,
                "ori_rewards":   ori,
                "ptb_rewards":   ptb,
                "contributions": contrib,
                "answer":        gold_ans,
                "gold_answer":   gold_ans,
                "level":         sample["level"],
                "type":          sample["type"],
            }
            dataset.append(entry)
            # print(entry)
        return dataset


In [19]:
# Smoke test: build MC reward entries for the first two MATH test problems.
cfg = PRMConfig()
tokenizer = AutoTokenizer.from_pretrained(cfg.model_name)
model = AutoModelForCausalLM.from_pretrained(cfg.model_name)
mcr = MCReward(config=cfg, model=model, tokenizer=tokenizer)
math_dataset = mcr.build_datasets_math(take=2, split="test")
math_dataset[1]

Loading checkpoint shards: 100%|██████████| 4/4 [00:02<00:00,  1.50it/s]


INFO 07-04 17:16:15 [config.py:823] This model supports multiple tasks: {'score', 'classify', 'embed', 'generate', 'reward'}. Defaulting to 'generate'.
INFO 07-04 17:16:15 [config.py:2195] Chunked prefill is enabled with max_num_batched_tokens=8192.
INFO 07-04 17:16:18 [core.py:455] Waiting for init message from front-end.
INFO 07-04 17:16:18 [core.py:70] Initializing a V1 LLM engine (v0.9.1) with config: model='Qwen/Qwen2.5-Math-7B', speculative_config=None, tokenizer='Qwen/Qwen2.5-Math-7B', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config={}, tokenizer_revision=None, trust_remote_code=True, dtype=torch.float16, max_seq_len=4096, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=False, kv_cache_dtype=auto,  device_config=cuda, decoding_config=DecodingConfig(backend='auto', disable_fallback=False, disable_any_whitespace=False, disable_ad

Process EngineCore_0:
Traceback (most recent call last):
  File "/home/leena/anaconda3/envs/prm/lib/python3.12/multiprocessing/process.py", line 314, in _bootstrap
    self.run()
  File "/home/leena/anaconda3/envs/prm/lib/python3.12/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/home/leena/anaconda3/envs/prm/lib/python3.12/site-packages/vllm/v1/engine/core.py", line 519, in run_engine_core
    raise e
  File "/home/leena/anaconda3/envs/prm/lib/python3.12/site-packages/vllm/v1/engine/core.py", line 506, in run_engine_core
    engine_core = EngineCoreProc(*args, **kwargs)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/leena/anaconda3/envs/prm/lib/python3.12/site-packages/vllm/v1/engine/core.py", line 390, in __init__
    super().__init__(vllm_config, executor_class, log_stats,
  File "/home/leena/anaconda3/envs/prm/lib/python3.12/site-packages/vllm/v1/engine/core.py", line 76, in __init__
    self.model_executor = exe

RuntimeError: Engine core initialization failed. See root cause above. Failed core proc(s): {'EngineCore_0': 1}

In [None]:
# Query a vLLM OpenAI-compatible server with one partially solved problem.
# openai>=1.0 (required by recent vLLM) removed the module-level
# `openai.ChatCompletion` API, so an explicit client is used instead.
import os
from openai import OpenAI

client = OpenAI(
    api_key=os.environ.get("VLLM_API_KEY", "EMPTY"),  # vLLM's default server accepts any key
    base_url=os.environ.get("VLLM_BASE_URL", "http://<vllm-server-ip>:8000/v1"),
)

# Plain (non-f) strings keep the LaTeX braces intact. Inside the former f-string,
# "\\frac{1}{100}" was rendered as "\frac1100" because {1} and {100} were
# evaluated as replacement fields.
demo_problem = "What is the positive difference between $120\\%$ of 30 and $130\\%$ of 20?"
demo_step = (
    "Step 1: One hundred twenty percent of 30 is $120\\cdot30\\cdot\\frac{1}{100}=36$, "
    "and $130\\%$ of 20 is $ 130\\cdot 20\\cdot\\frac{1}{100}=26$"
)
demo_prompt = system_prompt("rollout") + " Problem: " + demo_problem + "\n" + demo_step

resp = client.chat.completions.create(
    model="Qwen/Qwen2.5-Math-14B",  # NOTE(review): must match the model id served by vLLM — confirm
    messages=[{"role": "user", "content": demo_prompt}],
    temperature=0.7,
    top_p=0.8,
)

print(resp.choices[0].message.content)

# PRMDataset

In [2]:
import random, re
from pathlib import Path
from typing import List, Tuple
import torch
from torch.utils.data import Dataset
from transformers import PreTrainedTokenizer

class StepwisePRMDataset(Dataset):
    """Expand reward-dataset entries into per-step (input_ids, scalar_reward) samples.

    Each entry (as produced by MCReward.build_datasets*) is a dict with at least
    ``question``, ``completion`` (list of "Step i: ..." strings) and
    ``ori_rewards``. It also needs ``contributions`` when ``use_contr`` is true.
    The reward lists hold one float per step. One entry becomes one sample per step:

        ("Problem: <q>\\nStep 1",          r1)
        ("Problem: <q>\\nStep 1\\nStep 2", r2)  ...

    The target is ``contributions`` when ``use_contr`` is true, else ``ori_rewards``.
    """
    def __init__(
        self,
        entries: List[dict],
        tokenizer: "PreTrainedTokenizer",
        max_length: int = 512,
        use_contr: bool = True,
        *,
        cache_encodings: bool = True,
        preprocess: bool = True,
    ):
        self.tokenizer   = tokenizer
        self.max_length  = max_length
        self.use_contri  = use_contr
        self.cache       = {} if cache_encodings else None
        self.samples: List[Tuple[str, float]] = []

        for e in entries:
            q_txt     = e["question"]
            steps     = e["completion"]
            o_rewards = e["ori_rewards"]
            rewards   = e["contributions"] if self.use_contri else o_rewards
            # Check both lists: zip() below would otherwise silently drop
            # steps when the selected reward list is shorter.
            if not (len(steps) == len(o_rewards) == len(rewards)):
                raise ValueError(
                    f"entry length mismatch: {len(steps)} steps, {len(o_rewards)} ori_rewards, "
                    f"{len(rewards)} selected rewards (question={q_txt[:60]!r})"
                )

            prefix_lines = [f"Problem: {q_txt}"]
            for step_txt, r in zip(steps, rewards):
                prefix_lines.append(step_txt)
                full_txt = "\n".join(prefix_lines)
                if preprocess:
                    full_txt = self._clean(full_txt)
                self.samples.append((full_txt, float(r)))   # (text, reward)

    # --------------------------------------------------------------------- utils
    @staticmethod
    def _clean(txt: str) -> str:
        """Collapse every whitespace run (including newlines) to one space and strip the ends."""
        txt = re.sub(r"\s+", " ", txt).strip()
        return txt

    # --------------------------------------------------------------------- dunder
    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return (input_ids padded/truncated to max_length, reward as float32 tensor)."""
        text, reward = self.samples[idx]

        if self.cache is not None and text in self.cache:
            ids = self.cache[text]
        else:
            # NOTE(review): default right-side truncation drops the *latest* steps of
            # long prefixes, i.e. the step being scored; consider the budget.
            ids = self.tokenizer(
                text,
                max_length=self.max_length,
                padding="max_length",
                truncation=True,
                return_tensors="pt",
            ).input_ids.squeeze(0)
            if self.cache is not None:
                self.cache[text] = ids

        return ids, torch.tensor(reward, dtype=torch.float32)
    

  from .autonotebook import tqdm as notebook_tqdm


# PRM Model

In [7]:
# version 2
import torch
import torch.nn as nn
from typing import Optional

class ProcessRewardModel(nn.Module):
    """Residual-MLP regression head mapping a backbone feature vector to a reward in (0, 1).

    Layout: input projection (Linear → LayerNorm → GELU → Dropout), then
    ``num_layers - 1`` pre-LayerNorm residual MLP blocks, then an output head
    (LayerNorm → Linear → Sigmoid).
    """

    def __init__(self, input_size: int, cfg: "PRMConfig"):
        """
        Args:
            input_size : feature dimension of the vectors passed to forward()
            cfg        : PRMConfig-like object providing hidden_size, num_layers and dropout
        """
        super().__init__()
        self.input_size = input_size

        width = cfg.hidden_size
        drop = cfg.dropout
        gelu = nn.GELU()  # parameter-free, so one instance can be shared

        # Input projection into the hidden width.
        self.in_proj = nn.Sequential(
            nn.Linear(input_size, width), nn.LayerNorm(width), gelu, nn.Dropout(drop),
        )

        # Residual MLP blocks, created in order so parameter initialisation
        # consumes the RNG in the same sequence.
        self.blocks = nn.ModuleList(
            nn.Sequential(
                nn.LayerNorm(width),
                nn.Linear(width, width),
                gelu,
                nn.Dropout(drop),
                nn.Linear(width, width),
                nn.Dropout(drop),
            )
            for _ in range(cfg.num_layers - 1)
        )

        # Scalar output squashed into (0, 1).
        self.out_proj = nn.Sequential(nn.LayerNorm(width), nn.Linear(width, 1), nn.Sigmoid())

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map features (..., input_size) to rewards (...)."""
        hidden = self.in_proj(x)
        for block in self.blocks:
            hidden = block(hidden) + hidden  # residual connection
        return self.out_proj(hidden).squeeze(-1)

    def get_complexity(self) -> int:
        """Total number of scalar parameters in the head."""
        return sum(map(torch.Tensor.numel, self.parameters()))
    

# PRMTrainer

In [8]:
import json
from pathlib import Path
from typing import Dict, List, Optional
import wandb
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import logging
import math
from torch.optim.lr_scheduler import LambdaLR

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('omega_prm.log'),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)

class PRMTrainer:
    """Train a ProcessRewardModel head on features from a frozen causal-LM backbone.

    (1) entries (list[dict]) → StepwisePRMDataset
    (2) frozen LLM encoder + PRM head fine-tuning
    """
    def __init__(self, cfg: "PRMConfig", model, tokenizer):
        self.cfg = cfg
        torch.manual_seed(cfg.seed)

        # ----------------------------- Backbone LLM, frozen and used only as a feature extractor
        self.tokenizer = tokenizer
        self.model  = model
        self.model.eval()
        for p in self.model.parameters():
            p.requires_grad_(False)

        feat_dim = self.model.config.hidden_size
        self.prm = ProcessRewardModel(feat_dim, cfg=cfg)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
        self.prm.to(self.device)

        self.opt  = optim.AdamW(self.prm.parameters(), lr=cfg.learning_rate, weight_decay=cfg.weight_decay)
        self.crit = nn.MSELoss()

        # Provisional horizon. fit() replaces it with epochs × real batches per epoch.
        # cfg.dataset_size counts entries (not step samples) and is 0 until fit() runs.
        self.total_steps = math.ceil(cfg.epochs * cfg.dataset_size / cfg.batch_size)
        self.scheduler = None
        if cfg.lr_scheduler == "cosine":
            def lr_lambda(step):
                # Linear warm-up, then cosine decay to 0. self.total_steps is read on
                # every call, so the value set in fit() takes effect.
                if step < cfg.warmup_steps:
                    return step / max(1, cfg.warmup_steps)
                progress = (step - cfg.warmup_steps) / max(1, self.total_steps - cfg.warmup_steps)
                # Clamp so an underestimated horizon cannot make the LR oscillate.
                return 0.5 * (1.0 + math.cos(math.pi * min(1.0, progress)))
            self.scheduler = LambdaLR(self.opt, lr_lambda)

        self.ckpt_dir = Path(cfg.checkpoint_dir)
        self.ckpt_dir.mkdir(exist_ok=True, parents=True)

        self.wandb_run = None
        if cfg.use_wandb:
            self.wandb_run = wandb.init(
                project=cfg.wandb_project,
                name=cfg.run_name,
                config=self._config_dict(cfg),   # log every hyper-parameter
            )

    # ----------------------------------------------------------------- config
    @staticmethod
    def _config_dict(cfg) -> Dict[str, object]:
        """Return the public, non-callable attributes of ``cfg``, including class-level defaults.

        PRMConfig keeps its defaults as class attributes, so vars(cfg) returns only
        instance overrides (often an empty dict).
        """
        return {
            name: getattr(cfg, name)
            for name in dir(cfg)
            if not name.startswith("_") and not callable(getattr(cfg, name))
        }

    # ----------------------------------------------------------------- features
    @torch.no_grad()
    def _encode(self, ids: torch.Tensor) -> torch.Tensor:
        """Map input_ids [B, T] to features [B, hidden] (float32).

        Uses the last-layer hidden state at the last non-padding token. In a causal
        LM, position 0 attends only to itself, so the former "first token" pooling
        gave the same vector for every sample that starts with "Problem".
        """
        pad_id = getattr(self.tokenizer, "pad_token_id", None)
        attn = torch.ones_like(ids) if pad_id is None else ids.ne(pad_id).long()
        out = self.model(input_ids=ids, attention_mask=attn, return_dict=True, output_hidden_states=True)
        hidden = out.hidden_states[-1]                                   # [B, T, H]
        positions = torch.arange(ids.size(1), device=ids.device)
        last_idx = (attn * positions).argmax(dim=1)                      # last real token (left or right padding)
        rows = torch.arange(ids.size(0), device=ids.device)
        return hidden[rows, last_idx].float()                            # head runs in fp32

    # ----------------------------------------------------------------- loop util
    def _run_epoch(self, loader: DataLoader, train: bool, epoch_idx: int) -> float:
        """Run one pass over ``loader`` and return the mean per-batch loss (0.0 if the loader is empty)."""
        self.prm.train(train)
        total = 0.0
        n_batches = len(loader)

        for step, (ids, reward) in enumerate(loader):
            ids, reward = ids.to(self.device), reward.to(self.device)
            grad_norm = None

            with torch.set_grad_enabled(train):
                feats = self._encode(ids)
                # prm() already returns [B]. Squeezing again collapsed a size-1 batch to a scalar.
                pred  = self.prm(feats)
                loss  = self.crit(pred, reward)
                if train:
                    self.opt.zero_grad()
                    loss.backward()
                    # Returns the global (pre-clipping) L2 norm of all gradients.
                    grad_norm = torch.nn.utils.clip_grad_norm_(self.prm.parameters(), self.cfg.grad_clip)
                    self.opt.step()
                    if self.scheduler:
                        self.scheduler.step()

            batch_loss = loss.item()
            total += batch_loss

            # -------- minibatch logging --------
            if self.wandb_run and train:
                wandb.log({
                    "batch_loss": batch_loss,
                    "epoch": epoch_idx + step / n_batches,
                    "lr": self.opt.param_groups[0]["lr"],
                    "grad_norm": float(grad_norm),
                })

        return total / max(1, n_batches)

    # ----------------------------------------------------------------- public
    def fit(self, train_entries: List[dict], val_entries: List[dict]) -> Dict[str, List[float]]:
        """Train with early stopping on validation loss.

        Saves best_prm.pt whenever the validation loss improves and last_prm.pt at
        the end. Patience is ``cfg.patience`` if present, else 5.

        Returns:
            {"train": [...], "val": [...]}: per-epoch mean losses.
        """
        self.cfg.dataset_size = len(train_entries)

        train_ds = StepwisePRMDataset(train_entries, self.tokenizer, self.cfg.max_new_tokens, self.cfg.use_contri)
        val_ds   = StepwisePRMDataset(val_entries,   self.tokenizer, self.cfg.max_new_tokens, self.cfg.use_contri)

        train_loader = DataLoader(
            train_ds, batch_size=self.cfg.batch_size, shuffle=True,
            num_workers=self.cfg.num_workers, pin_memory=True,
        )
        val_loader = DataLoader(
            val_ds, batch_size=self.cfg.batch_size, shuffle=False,
            num_workers=self.cfg.num_workers, pin_memory=True,
        )

        # The scheduler horizon is measured in optimizer steps (batches of step-level
        # samples). That count is only known now, not in __init__.
        self.total_steps = max(1, self.cfg.epochs * len(train_loader))

        history = {"train": [], "val": []}
        best_val, bad_epochs = float("inf"), 0
        patience = getattr(self.cfg, "patience", 5)
        last_epoch, vl_loss = -1, float("nan")

        for ep in range(self.cfg.epochs):
            tr_loss = self._run_epoch(train_loader, train=True,  epoch_idx=ep)
            vl_loss = self._run_epoch(val_loader,   train=False, epoch_idx=ep)
            last_epoch = ep

            history["train"].append(tr_loss)
            history["val"].append(vl_loss)
            print(f"[Epoch {ep+1}/{self.cfg.epochs}] train={tr_loss:.4f}  val={vl_loss:.4f}")

            # -------- epoch logging --------
            if self.wandb_run:
                wandb.log({"train_loss": tr_loss, "val_loss": vl_loss, "epoch": ep})

            # Save a checkpoint when validation improves; otherwise count towards early stopping.
            if vl_loss < best_val:
                best_val = vl_loss
                bad_epochs = 0
                self._save_checkpoint("best_prm.pt", epoch=ep, val_loss=vl_loss)
            else:
                bad_epochs += 1
                if bad_epochs >= patience:
                    print(f"[Early-Stopping] no improvement for {patience} epochs")
                    break

        # Record the epoch actually reached (it differs from epochs-1 after early stopping).
        self._save_checkpoint("last_prm.pt", epoch=last_epoch, val_loss=vl_loss)
        return history

    # ------------------------------------------------------------------
    # Checkpoint helpers
    def _save_checkpoint(self, filename: str, *, epoch: int, val_loss: float) -> None:
        """torch.save head, optimizer and scheduler state plus the config into ckpt_dir/filename."""
        path = self.ckpt_dir / filename
        save_dict = {
            "epoch": epoch,
            "val_loss": val_loss,
            "prm_state": self.prm.state_dict(),
            "scheduler_state": (self.scheduler.state_dict() if self.scheduler else None),
            "optimizer_state": self.opt.state_dict(),
            "config": self._config_dict(self.cfg),   # hyper-params for reproducibility
            "model_name_or_path": getattr(self.model, "name_or_path", None),
            "tokenizer_config": self.tokenizer.__dict__.get("init_kwargs", {}),
        }
        torch.save(save_dict, path)
        print(f"[CKPT] Saved ⇒ {path}")

    # ------------------------------------------------------------------
    # Simple inference helper
    @torch.no_grad()
    def predict_reward(self, text: str) -> float:
        """Score one "Problem: ... Step k: ..." prefix; returns the head output in (0, 1).

        The text gets the same whitespace normalisation and max-length truncation
        as the training samples. The head runs in eval mode (dropout off), and its
        previous mode is restored afterwards. The head already ends in a Sigmoid,
        so no extra sigmoid is applied.
        """
        was_training = self.prm.training
        self.prm.eval()
        try:
            clean = StepwisePRMDataset._clean(text)
            ids = self.tokenizer.encode(
                clean, max_length=self.cfg.max_new_tokens, truncation=True, return_tensors="pt",
            ).to(self.device)
            return float(self.prm(self._encode(ids)).item())
        finally:
            self.prm.train(was_training)


# Main

In [10]:
# NOTE(review): this loads the *Instruct* checkpoint, while PRMConfig.model_name
# (used by MCReward's vLLM engine) is "Qwen/Qwen2.5-Math-7B". Confirm the mismatch is intended.
# Other candidates: "Qwen/Qwen2.5-Math-7B", "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B", "meta-llama/Llama-3.1-8B"
model_name = "Qwen/Qwen2.5-Math-7B-Instruct"
cfg = PRMConfig()
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
print("Finish load model and config!")

# problems = [
#     {"question": "Each notebook costs $5. Sarah buys 4 notebooks and pays with a $50 bill. How much change does she get?", "gold_answer": "30"},
#     {"question": "Solve for y: 2y - 7 = 3(y - 4).", "gold_answer": "5"},
#     # Add more problems as needed...
# ]

Loading checkpoint shards: 100%|██████████| 4/4 [00:02<00:00,  1.99it/s]


Finish load model and config!


In [None]:
# Build MC reward entries for two shuffled GSM-8K training problems (rows 10-11).
cfg = PRMConfig()
mcr = MCReward(config=cfg, model=model, tokenizer=tokenizer)
gsm8k_raw = mcr.build_datasets_gsm8k(take=2, start=10, split="train")

# Show each entry followed by a separator line.
for entry in gsm8k_raw:
    print(entry, "-" * 80, sep="\n")
    print("-" * 80)


In [None]:

# mcr = MCReward(config=cfg , model=model, tokenizer=tokenizer)
# entries_raw = mcr.build_datasets(problems)

# random.shuffle(entries_raw)
# split_idx       = int(0.9 * len(entries_raw)) if len(entries_raw) > 1 else 1
# NOTE(review): `reward_ds_small` is not defined in any of the visible cells. It must
# come from an earlier cell; confirm before running Restart & Run All.
split_idx = 1
# When nothing is left for validation, fall back to the first entry (which then overlaps training).
train_entries, val_entries = (
    reward_ds_small[:split_idx],
    reward_ds_small[split_idx:] or reward_ds_small[:1],
)

trainer = PRMTrainer(cfg, model=model, tokenizer=tokenizer)
history = trainer.fit(train_entries, val_entries)
print("Training complete. Loss history:", history)

In [9]:
# JSON is UTF-8 by specification; without an explicit encoding, open() falls back
# to the platform locale encoding.
with open("gsm8k_train_0624_100.json", "r", encoding="utf-8") as file:
    gsm8k_raw = json.load(file)

# NOTE(review): this overwrites the `tokenizer` loaded in the model-loading cell
# with the base-model tokenizer.
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-Math-7B")
cfg = PRMConfig()
train_data = StepwisePRMDataset(gsm8k_raw, tokenizer, cfg.max_new_tokens, True)
print(len(train_data))

# Inspect one sample: its target reward and the decoded prefix text.
idx = 9
print(train_data[idx][1])
print(tokenizer.decode(train_data[idx][0], skip_special_tokens=True))

350
tensor(0.1250)
Problem: Olaf collects colorful toy cars. At first, his collection consisted of 150 cars. His family, knowing his hobby, decided to give him some toy cars. Grandpa gave Olaf twice as many toy cars as the uncle. Dad gave Olaf 10 toy cars, 5 less than Mum. Auntie gave Olaf 6 toy cars, 1 more than the uncle. How many toy cars does Olaf have in total, after receiving all these gifts? Step 1: Dad gave Olaf 10 toy cars, Step 2: Mom has given Olaf 5 more toy cars than Dad, so 10 + 5 = <<10+5=15>>15 toy cars Step 3: Auntie gave Olaf 6 toy cars,


In [None]:
import json, random
from pathlib import Path

def main():
    """End-to-end smoke run: load the backbone, build PRM entries for a few
    hand-written problems, split them, and train the PRM head.

    The shuffle before the split is seeded from ``cfg.seed`` with a local RNG,
    so the train/validation split is reproducible. Global ``random`` state is
    left untouched.

    NOTE(review): this loads a fresh copy of the backbone on every call, and in a
    notebook ``__name__ == "__main__"`` is always true, so running the cell runs it.
    """
    model_name = "Qwen/Qwen2.5-Math-7B"  # alternative: "Qwen/Qwen2.5-Math-7B-Instruct"
    model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    cfg = PRMConfig()
    print("Finish load model and config!")

    problems = [
        {"question": "A train travels 120 km in 2 hours and then 180 km in 3 hours. What is the average speed of the train?", "gold_answer": "60"},
        {"question": "Solve for y: 2y - 7 = 3(y - 4).", "gold_answer": "5"},
        # Add more problems as needed...
    ]

    mcr = MCReward(config=cfg, model=model, tokenizer=tokenizer)
    entries_raw = mcr.build_datasets(problems)

    # Print or inspect the dataset
    for entry in entries_raw:
        print(entry)
        print("-" * 80)

    # Seeded, local RNG: reproducible split without touching global random state.
    random.Random(cfg.seed).shuffle(entries_raw)
    # Proportional alternative: int(0.9 * len(entries_raw)) if len(entries_raw) > 1 else 1
    split_idx = 1
    train_entries = entries_raw[:split_idx]
    # Keep at least one validation entry (it overlaps training when there is only one entry).
    val_entries = entries_raw[split_idx:] or entries_raw[:1]

    trainer = PRMTrainer(cfg, model=model, tokenizer=tokenizer)
    history = trainer.fit(train_entries, val_entries)
    print("Training complete. Loss history:", history)

if __name__ == "__main__":
    main()

# Inference

In [None]:
import json, random, math, argparse
from pathlib import Path
from typing import List, Dict, Tuple
import torch
from torch.utils.data import DataLoader
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

# # ---------------- User-defined utilities (defined in earlier cells) ----------
# from prmtrainer import ProcessRewardModel        # PRM head class
# from utils import _sanitize, _numeric_equiv, _extract_answer  # helpers defined earlier

LLM_NAME      = "Qwen/Qwen2.5-Math-7B"
PRM_CKPT_PATH = "./checkpoints/gsm8k/ori_mse/best_prm.pt"  # NOTE: overridden by a later cell
N_ROLLOUTS    = 5                         # Best-of-N
MAX_NEW_TOK   = PRMConfig.max_new_tokens
SEED          = PRMConfig.seed
DEVICE        = "cuda" if torch.cuda.is_available() else "cpu"
HIDDEN_DIM    = PRMConfig.hidden_size

# Format instructions prepended to every sampling prompt. Adjacent string
# literals are concatenated implicitly, so every line must end with "\n";
# otherwise the last two lines run together ("...(Final result)Please ...").
SYSTEM_PROMPT_SAMPLE = (
    "Problem: (sample)\n"
    "Please provide a short and precise step-by-step solution, and a numerical answer in the end, for the question above in the following format, without any extra wording:\n"
    "Step 1: (logical step 1)\n"
    "Step 2: (logical step 2)\n"
    "...\n"
    "Step n: (logical last step)\n"
    "Answer: (Final result)\n"
    "Please strictly stick to the format above."
)

# Seed the Python and torch RNGs so Best-of-N sampling is repeatable.
random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

# Frozen backbone: samples candidate solutions and provides PRM features.
config = PRMConfig()
tokenizer = AutoTokenizer.from_pretrained(LLM_NAME)
base = AutoModelForCausalLM.from_pretrained(LLM_NAME).to(DEVICE).eval()
base.requires_grad_(False)  # freeze every backbone parameter
mcr = MCReward(config=config, tokenizer=tokenizer, model=base)

# PRM head sized to the backbone's hidden dimension.
# NOTE(review): weights_only=False unpickles arbitrary objects; only load trusted checkpoints.
feat_dim = base.config.hidden_size
prm = ProcessRewardModel(feat_dim, cfg=config)
ckpt = torch.load(PRM_CKPT_PATH, map_location="cpu", weights_only=False)
prm.load_state_dict(ckpt["prm_state"])
prm.to(DEVICE).eval()

In [None]:
# Swap in the PRM head from the contri_mse checkpoint (same architecture as above).
PRM_CKPT_PATH = "./checkpoints/gsm8k/contri_mse/best_prm.pt"
feat_dim = base.config.hidden_size
prm = ProcessRewardModel(feat_dim, cfg=config)
# NOTE(review): weights_only=False unpickles arbitrary objects; only load trusted checkpoints.
ckpt = torch.load(PRM_CKPT_PATH, map_location="cpu", weights_only=False)
prm.load_state_dict(ckpt["prm_state"])
prm.to(DEVICE).eval()

# GSM8K test problems used for evaluation (indices 2-3 only: a quick smoke test).
# NOTE(review): `load_dataset` (Hugging Face `datasets`) is not imported in any visible cell.
gsm8k_test = load_dataset("openai/gsm8k", "main")["test"]
small_gsm8k = gsm8k_test.select(range(2, 4))
problems = [{"q": row["question"], "gold": row["answer"]} for row in small_gsm8k]
print("Finish Loading model and dataset!")

# Evaluation utils
@torch.no_grad()
def generate_solutions(
    backbone,
    tokenizer,
    question: str,
    n: int,
) -> List[str]:
    """Sample ``n`` candidate step-by-step solutions for ``question``.

    The prompt is ``SYSTEM_PROMPT_SAMPLE`` followed by the problem text. Sampling
    uses temperature 0.8 / top-p 0.9 and at most ``MAX_NEW_TOK`` new tokens.

    Args:
        backbone: causal LM used for generation (already on ``DEVICE``).
        tokenizer: tokenizer matching ``backbone``.
        question: problem text.
        n: number of samples to draw.

    Returns:
        ``n`` decoded continuations (prompt tokens stripped, special tokens skipped).
    """
    prompt = f"{SYSTEM_PROMPT_SAMPLE}\n\nProblem: {question}\n"
    enc = tokenizer(prompt, return_tensors="pt").to(DEVICE)
    gen_cfg = GenerationConfig(
        max_new_tokens=MAX_NEW_TOK,
        do_sample=True, temperature=0.8, top_p=0.9,
        num_return_sequences=n,
        # Set EOS explicitly so sampling stops at end-of-text instead of always
        # running to max_new_tokens when this config replaces the model's own.
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.eos_token_id,
    )
    # One prompt row + num_return_sequences=n yields exactly n samples. Do not
    # also repeat the prompt n times, which would produce n * n samples.
    out = backbone.generate(
        input_ids=enc.input_ids,
        attention_mask=enc.attention_mask,
        generation_config=gen_cfg,
    )
    prompt_len = enc.input_ids.shape[-1]
    return [tokenizer.decode(seq[prompt_len:], skip_special_tokens=True) for seq in out]

def parse_steps(text: str) -> List[str]:
    """Return the whitespace-stripped lines of ``text`` that begin with "step"
    (case-insensitive), in their original order."""
    steps: List[str] = []
    for raw_line in text.splitlines():
        line = raw_line.strip()
        if line.lower().startswith("step"):
            steps.append(line)
    return steps

@torch.no_grad()
def prm_score(
    prm: ProcessRewardModel,
    backbone,
    tokenizer,
    question: str,
    steps: List[str],
    pooling: str = "first",
) -> float:
    """Score a step-by-step solution with the PRM head.

    For each prefix (Problem + Step 1..i) the backbone's last-layer hidden state
    is pooled to one vector and passed to ``prm``. The mean of the per-prefix
    scores is returned. A higher score means "better" reasoning.

    Args:
        prm: trained PRM head, called on a (1, hidden) feature tensor.
        backbone: causal LM used as a frozen feature extractor.
        tokenizer: tokenizer matching ``backbone``.
        question: problem text.
        steps: parsed step lines (e.g. from ``parse_steps``); must be non-empty.
        pooling: ``"first"`` (default, original behaviour) uses position 0;
            ``"last"`` uses the last non-padding token of each prefix.

    Returns:
        Mean PRM score over all prefixes.

    Raises:
        ValueError: if ``steps`` is empty or ``pooling`` is unknown.
    """
    if not steps:
        raise ValueError("prm_score needs at least one step")
    if pooling not in ("first", "last"):
        raise ValueError(f"unknown pooling {pooling!r}; expected 'first' or 'last'")

    # NOTE(review): in a causal LM the position-0 hidden state attends only to the
    # first token, which is "Problem" for every prefix here. "first" pooling
    # therefore gives identical features for every step and every solution.
    # "last" sees the whole prefix, but is only valid if the PRM was trained with
    # the same pooling. Confirm against PRMTrainer's feature extraction.
    # NOTE(review): the training sample decoded earlier in this notebook looks
    # space-separated ("...? Step 1: ... Step 2: ..."), whereas prefixes here are
    # joined with "\n". Confirm the separator matches StepwisePRMDataset.
    prefix_lines = [f"Problem: {question}"]
    scores: List[float] = []
    for step in steps:
        prefix_lines.append(step)
        enc = tokenizer(
            "\n".join(prefix_lines),
            max_length=MAX_NEW_TOK,  # same length as the PRM training data (cfg.max_new_tokens)
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        ).to(DEVICE)
        hidden = backbone(
            input_ids=enc.input_ids,
            attention_mask=enc.attention_mask,
            output_hidden_states=True,
            return_dict=True,
        ).hidden_states[-1]  # (batch=1, seq, hidden)
        if pooling == "first":
            feats = hidden[:, 0, :]
        else:
            # Index of the last attended position; works for left or right padding.
            positions = torch.arange(hidden.size(1), device=hidden.device)
            last_idx = (enc.attention_mask * positions).argmax(dim=-1)
            feats = hidden[torch.arange(hidden.size(0), device=hidden.device), last_idx]
        # Raw head output (not passed through sigmoid).
        scores.append(prm(feats).item())
    print("Scores by PRM:", scores, "\n")
    return sum(scores) / len(scores)

# Evaluation: Best-of-N selection with the PRM, compared against the gold answer.
# NOTE(review): GSM8K `answer` fields contain the full rationale ending in
# "#### <number>"; confirm `_sanitize` reduces that to the bare number.
correct = 0
total = len(problems)
for idx, item in enumerate(problems, 1):
    q, gold = item["q"], _sanitize(item["gold"])
    # (1) Sample N candidate solutions.
    sols = generate_solutions(base, tokenizer, q, N_ROLLOUTS)
    # (2) PRM-score every candidate that has at least one parsable step.
    scored: List[Tuple[float, str]] = []
    for sol_idx, sol in enumerate(sols, 1):
        steps = parse_steps(sol)
        if not steps:
            continue
        s = prm_score(prm, base, tokenizer, q, steps)
        scored.append((s, sol))
        print(f"[problem {idx}, solution {sol_idx}] Step:", steps, "")

    # (3) Take the final answer of the highest-scoring candidate.
    if not scored:
        pred_answer = "N/A"
    else:
        best_sol = max(scored, key=lambda t: t[0])[1]
        pred_answer = mcr._extract_answer(text=best_sol) or "N/A"

    is_correct = _numeric_equiv(pred_answer, gold)
    if is_correct:
        correct += 1

    print(f"[{idx}/{total}] pred={pred_answer} | gold={gold} | {'✓' if is_correct else '✗'}")

# Guard against an empty problem list (previously a ZeroDivisionError).
accuracy = correct / total * 100 if total else 0.0
print(f"\n=== GSM8K Test Accuracy (Best-of-{N_ROLLOUTS} w/ PRM) : {accuracy:.2f}% ===")

# Utils

In [None]:
import torch, platform, sys, subprocess, os
import shutil

# Report the PyTorch / CUDA setup this kernel actually sees.
print("PyTorch :", torch.__version__)
print("CUDA ver:", torch.version.cuda)
print("is_avail:", torch.cuda.is_available())
# nvidia-smi is absent on CPU-only hosts; calling it unguarded raises FileNotFoundError.
if shutil.which("nvidia-smi"):
    print("nvidia-smi output ↓")
    subprocess.run(["nvidia-smi"], check=False)
else:
    print("nvidia-smi not found on PATH")

In [None]:
import json

# Merge two 200-item PRM datasets into one 400-item file, then wrap it as a dataset.
# Base directory is configurable via the MC_PRM_DIR environment variable instead
# of being hard-coded in every path. The default keeps the original location.
MC_PRM_DIR = Path(os.environ.get("MC_PRM_DIR", "/home/leena/ccc_eval/mcts_prm/MC_PRM"))

with open(MC_PRM_DIR / "samples" / "math_gsm8k_200.json", encoding="utf-8") as file:
    f1 = json.load(file)

with open(MC_PRM_DIR / "gsm8k_train_0703_200.json", encoding="utf-8") as file:
    f2 = json.load(file)

print(len(f1))
print(len(f2))

merged = f1 + f2
print(len(merged))

with open(MC_PRM_DIR / "samples" / "math_gsm8k_400.json", "w", encoding="utf-8") as merged_file:
    json.dump(merged, merged_file)

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-Math-7B")
process = StepwisePRMDataset(merged, tokenizer)
print(len(process))

200
200
400


In [11]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

model_name = "Qwen/QwQ-32B"

# 4-bit NF4 weight quantisation (bitsandbytes) to cut weight memory; matmuls run in bfloat16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,          # also quantise the quantisation constants
    bnb_4bit_compute_dtype=torch.bfloat16,
)

# NOTE(review): device_map="auto" requires the `accelerate` package (the recorded
# run failed with exactly that ValueError), and 4-bit loading needs `bitsandbytes`.
# NOTE(review): this rebinds the notebook-wide `model` / `tokenizer` names used by earlier cells.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    device_map="auto",         # place/shard layers across the visible GPU(s)
    torch_dtype="auto",        # non-quantised modules use the checkpoint's dtype
    trust_remote_code=True,    # NOTE(review): probably unnecessary (Qwen2 arch ships with transformers); confirm and drop
)

tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)

ValueError: Using a `device_map`, `tp_plan`, `torch.device` context manager or setting `torch.set_default_device(device)` requires `accelerate`. You can install it with `pip install accelerate`