In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig, PreTrainedTokenizer
from collections import defaultdict, deque
import math
import logging
import json
from typing import Dict, List, Tuple, Optional, Union
from dataclasses import dataclass
import numpy as np
from torch.utils.data import Dataset, DataLoader
import concurrent.futures
from tqdm import tqdm, trange
import wandb
import os
from pathlib import Path
from collections import deque
from torch.utils.data import Dataset
from typing import List, Tuple, Dict, Any, Optional
from collections import Counter
import random
import string
import re

# Enumerate GPUs by PCI bus ID so that the index used below is stable.
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"  # Arrange GPU devices starting from 0
# Expose only GPU index 2 to this process; inside the process it is addressed as "cuda" (cuda:0).
# NOTE(review): both variables only take effect if set before CUDA is first initialised in this kernel.
os.environ["CUDA_VISIBLE_DEVICES"]= "2"
# Notebook-wide default device; falls back to CPU when no GPU is visible.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


  from .autonotebook import tqdm as notebook_tqdm


# Config

In [57]:
@dataclass
class PRMConfig:
    """Hyper-parameters and settings for MC reward estimation and PRM training.

    Declared as a dataclass so that every setting is an instance field:
    ``vars(cfg)`` / ``dataclasses.asdict(cfg)`` then return the full
    configuration (used e.g. for experiment logging), and individual values can
    be overridden at construction time, e.g. ``PRMConfig(num_rollouts=8)``.
    """
    # MC config
    max_new_tokens: int = 386          # generation budget per sampled solution / rollout
    num_rollouts: int = 5              # rollouts per step prefix (reward = fraction correct)
    reward_threshold: float = 0.2
    samples_per_question: int = 1      # sampled solutions per question in build_datasets
    use_llm: bool = True               # mask steps with the LLM (True) or a regex (False)
    use_contri: bool = True            # train on contributions (ori - ptb) instead of ori rewards
    # PRMTrainer config
    batch_size: int = 32
    learning_rate: float = 5e-4
    hidden_size: int = 256             # width of the PRM head's first hidden layer
    num_workers: int = 4
    epochs: int = 2
    # Misc config
    use_wandb: bool = True
    wandb_project: str = "mc_prm"
    run_name: str = "test_gsm8k_0623"
    checkpoint_dir: str = "checkpoints"
    seed: int = 42

# MC Rewards

In [3]:
################################################################################
#                        UTILITY: ANSWER NORMALISATION                         #
################################################################################
import sympy as sp

def _strip_markup(ans: str) -> str:
    """Remove common LaTeX/markup & variable tags."""
    # # Remove LaTeX inline math wrappers \( … \) or \[ … \]
    # ans = re.sub(r"\\[\[(](.*?)[\\\])]", r"\1", ans)
    # # Remove \boxed{…}
    # ans = re.sub(r"\\boxed\{([^}]*)\}", r"\1", ans)
    ans = re.sub(r"\\\[.*?\\\]", "", ans)
    ans = re.sub(r"\$\$.*?\$\$", "", ans)
    # Remove inline LaTeX: \( ... \) and $...$
    ans = re.sub(r"\\\((.*?)\\\)", r"\1", ans)
    ans = re.sub(r"\$(.*?)\$", r"\1", ans)
    # Remove \boxed{...}
    ans = re.sub(r"\\boxed\s*{([^}]*)}", r"\1", ans)
    # Remove LaTeX commands like \text{...}, \frac{...}, etc.
    ans = re.sub(r"\\[a-zA-Z]+\s*(\{[^{}]*\})?", "", ans)
    # Remove variable assignments like "y =" or "x=" at start
    ans = re.sub(r"^[a-zA-Z]\s*=\s*", "", ans)
    # Trim outer $ … $ if present
    ans = ans.strip()
    if ans.startswith("$") and ans.endswith("$"):
        ans = ans[1:-1]
    return ans.strip()

def _sanitize(text: str) -> str:
    """Canonicalise an answer string so two answers can be compared as text.

    Markup is stripped via ``_strip_markup``; then trailing whitespace and
    punctuation (``. ; : ,``) are removed and every internal whitespace run is
    collapsed to a single space.
    """
    unwrapped = _strip_markup(text).strip()
    without_tail = re.sub(r"[\s\.;:,]+$", "", unwrapped)  # trailing punctuation
    return re.sub(r"\s+", " ", without_tail)              # collapse spaces

def _to_float(expr: str) -> Optional[float]:
    try:
        return float(eval(expr.replace("^", "**")))
    except Exception:
        return None

def _numeric_equiv(a: str, b: str) -> bool:
    """Decide whether two answer strings denote the same value.

    Checks, in order:
      1. exact equality after ``_sanitize``;
      2. if both sides evaluate to floats via ``_to_float``, ``math.isclose``
         with a relative tolerance of 1e-6 (the result is final either way);
      3. symbolic equality with SymPy (``simplify(a - b) == 0``).
    Any SymPy parsing/simplification failure counts as "not equivalent".
    """
    lhs, rhs = _sanitize(a), _sanitize(b)
    if lhs == rhs:
        return True

    # Attempt simple numeric evaluation
    lhs_num = _to_float(lhs)
    rhs_num = _to_float(rhs)
    both_numeric = lhs_num is not None and rhs_num is not None
    if both_numeric:
        return math.isclose(lhs_num, rhs_num, rel_tol=1e-6)

    if sp is None:
        return False
    # NOTE(review): sympify parses strings with eval internally (see SymPy docs)
    # and these strings include model generations; consider parse_expr with a
    # restricted namespace if inputs are not trusted.
    try:
        lhs_expr = sp.sympify(lhs.replace("^", "**"))
        rhs_expr = sp.sympify(rhs.replace("^", "**"))
        return sp.simplify(lhs_expr - rhs_expr) == 0
    except Exception:
        return False

def system_prompt(type):
    """Return the system prompt for one generation mode.

    Args:
        type: ``"sample"`` to write a complete step-by-step solution from
            scratch, or ``"rollout"`` to continue a partially written one.
            (The name shadows the builtin ``type``; kept for compatibility.)

    Returns:
        The prompt text, or ``""`` for any other value.
    """
    if type == "sample":
        return """You are a math-problem expert. Your task is to complete the step-by-step solution for the problem provided. Write each reasoning step on its own line in the exact form \"Step k: [your reasoning step]\n\", numbering start from Step 1. When the final answer is obtained, write exactly one final line, \"Answer: [Final answer]\". Do NOT add explanations, extra steps, or any text after the "Answer:" line.

**Format Guide**: (You MUST write "Step " before numbering the step.)
Step 1: [Step 1 reasoning]\n
Step 2: [Step 2 reasoning]\n
...
Step k: [Step k reasoning]\n
...
Answer: [Final answer]

Format Guide with Examples:
<Example 1>
Problem: Find the sum of the first 8 positive even integers.
Step 1: The first 8 even integers are 2, 4, 6, 8, 10, 12, 14, 16.
Step 2: Use the formula for an arithmetic series: S = n·(first + last)/2.
Step 3: Substitute n=8, first=2, last=16 to get S = 8·(2+16)/2 = 8·9 = 72.
Answer: 72

<Example 2>
Problem: Determine the next number in the sequence 2, 4, 8, 16.
Step 1: Notice each term is obtained by multiplying the previous term by 2.
Step 2: Multiply 16 by 2, 16 * 2 = 32.
Answer: 32

Follow the FORMAT GUIDE structure exactly. Generate rationales step-by-step, not directly to the final answer. **Do NOT** write anything after the final 'Answer:' line. Always start stepwise reasoning with "Step {i-th}: " form."""
    if type == "rollout":
        return """You are a math problem-solving expert. Continue solving the given problem step by step, strictly following the required format. Each new step must begin with \"Step k+1: ...\", \"Step k+2:...\", and so on, continuing from the last given step number. When the final answer is reached, write only one final line starting with: \"Answer: [Final Answer]\". Do not add any explanations, extra commentary, or additional text after the "Answer:" line. Your output must follow this exact step-by-step format with no deviations.

**Format Guide**: (You MUST write "Step " before numbering the step.)
Step 1: [Step 1 reasoning]\n
Step 2: [Step 2 reasoning]\n
...
Step k: [Step k reasoning]\n
Continue and finish the solution:
Step k+1: [Step k+1 reasoning]\n
...
Answer: [Final answer]

Format Guide with Examples:
<Example 1>
Current solution steps:
Problem: Find the sum of the first 8 positive even integers.
Step 1: The first 8 even integers are 2, 4, 6, 8, 10, 12, 14, 16.
Step 2: Use the formula for an arithmetic series: S = n·(first + last)/2.
Continue and finish the solution:
Step 3: Substitute n=8, first=2, last=16 to get S = 8·(2+16)/2 = 8·9 = 72.
Answer: 72

<Example 2>
Current solution steps:
Problem: Determine the next number in the sequence 2, 4, 8, 16.
Step 1: Notice each term is obtained by multiplying the previous term by 2.
Continue and finish the solution:
Step 2: Multiply 16 by 2, 16 * 2 = 32.
Answer: 32

Keep the reasoning steps precise and factual and complete the solution. Follow the FORMAT GUIDE structure exactly. **Do NOT** write anything after the final 'Answer:' line. Always start stepwise reasoning with "Step {i-th}: " form."""
    return ""


In [63]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import load_dataset
from tqdm import tqdm
# STEP_PATTERN = re.compile(r"Step\s+\d+:")
# ANSWER_PATTERN = re.compile(r"Answer\s*:\s*(.+?)\s*(?:$|\n)")

class MCReward:
    """Monte-Carlo (MC) step rewards for step-by-step math solutions.

    For every prefix ``Step 1 .. Step i`` of a solution, the policy model is
    asked to finish the solution ``config.num_rollouts`` times; the fraction of
    rollouts whose extracted answer matches the gold answer is that step's
    reward (``compute_step_rewards``).  ``perturbed_step_rewards`` repeats the
    estimate after masking the key tokens of step ``i``; the difference
    ``ori - ptb`` is stored as the step's contribution.
    """

    # A step header such as "Step 3:", "**Step 3.**" or "- step 3 -".
    STEP_PATTERN = re.compile(
        r"""^[\s>#*\-]*          # optional markdown/bullet symbols
            Step\s*              # word 'Step' (case-insensitive)
            (\d+)                # capture step number
            \s*[:.\-]            # separator (: . or -)
        """,
        re.IGNORECASE | re.VERBOSE,
    )
    # A final-answer line such as "Answer: 72"; group 1 is the answer text.
    ANSWER_PATTERN = re.compile(
        r"""^[\s>#*\-]*          # optional markdown/bullet symbols
            Answer               # word 'Answer'
            \s*[:.\-]\s*         # separator
            (.+?)\s*$            # capture everything after
        """,
        re.IGNORECASE | re.MULTILINE | re.VERBOSE,
    )
    ## Masked rewards ##
    # Operation words; currently unused (the operator alternative of
    # _MASK_PATTERN is disabled).
    OP_TOKENS = ["add", "plus", "sum", "subtract", "minus",
                 "multiply", "times", "product", "divide", "quotient"]
    # Regex masking used when use_llm=False.  Fractions must come first: with
    # the plain-number alternative first, "1/2" matched as "1" and "2"
    # separately and the fraction alternative could never fire.
    _MASK_PATTERN = re.compile(
        r"""
        (?:
            \b\d+/\d+\b                 # simple fractions
          | \b\d+(?:\.\d+)?\b           # integers / decimals
        )
        """,
        re.VERBOSE,
    )
    # Step header kept intact when masking a step's body.
    _STEP_PREFIX_RE = re.compile(r"^[\s>#*\-]*Step\s*\d+\s*[:.\-]\s*", re.IGNORECASE)
    # Final-answer marker line of GSM8K reference solutions ("#### 72").
    _GSM8K_ANSWER_RE = re.compile(r"####\s*(.+?)\s*$")
    # A line starting with the word "answer" but lacking the separator that
    # ANSWER_PATTERN requires, e.g. "Answer 42" or "ANSWER 42".
    _LOOSE_ANSWER_RE = re.compile(r"\s*answer\b\s*(.*)", re.IGNORECASE)
    # Sampling parameters shared by solution sampling and MC rollouts.
    SAMPLING_TEMPERATURE = 0.8
    SAMPLING_TOP_P = 0.8

    def __init__(self, config: "PRMConfig", model, tokenizer):
        """
        Args:
            config: reads max_new_tokens, num_rollouts, samples_per_question, use_llm.
            model: causal LM exposing ``generate`` and ``device``.
            tokenizer: tokenizer paired with ``model``.
        """
        self.config = config
        self.model = model
        self.tokenizer = tokenizer
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # ------------------------------------------------------------ generation
    def _tokenize(self, text: str) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encode ``text`` as ONE sequence on the model's device.

        Returns ``(input_ids, attention_mask)``.  The whole prompt is encoded
        in one call so tokenizer special tokens (e.g. a BOS) appear only at the
        very start, never between concatenated prompt pieces.
        """
        input_ids = self.tokenizer.encode(text, return_tensors="pt").to(self.model.device)
        # Single unpadded sequence: every position is a real token.
        attention_mask = torch.ones_like(input_ids)
        return input_ids, attention_mask

    def _sample(self, prompt_text: str, num_return_sequences: int) -> List[str]:
        """Sample continuations of ``prompt_text``; return only the new text of each."""
        input_ids, attention_mask = self._tokenize(prompt_text)
        outputs = self.model.generate(
            input_ids,
            attention_mask=attention_mask,
            max_new_tokens=self.config.max_new_tokens,
            do_sample=True,
            num_return_sequences=num_return_sequences,
            temperature=self.SAMPLING_TEMPERATURE,   # sampling temperature for diversity
            top_p=self.SAMPLING_TOP_P,               # top-p sampling for diversity
            pad_token_id=self.tokenizer.eos_token_id,  # avoid the missing-pad-token warning
        )
        prompt_len = input_ids.shape[-1]
        # Each returned row is prompt + completion; keep only the completion.
        return [
            self.tokenizer.decode(seq[prompt_len:], skip_special_tokens=True)
            for seq in outputs
        ]

    def generate_solutions(self, question: str, sys_prompt: str, num_solutions: int):
        """Sample ``num_solutions`` step-by-step solutions for ``question``.

        The question is introduced with ``Problem:``, matching the rollout
        prompts and the examples in the system prompts.  Each sampled solution
        is printed and returned (newly generated text only).
        """
        prompt = f"{sys_prompt}\n\nProblem: {question}\n"
        solutions = self._sample(prompt, num_solutions)
        for i, text in enumerate(solutions):
            print(f"{i}-th Sampled Solutions:", text)
        return solutions

    # --------------------------------------------------------------- parsing
    def gsm8k_solutions(self, question: str, gold_solution: str):
        """Convert a GSM8K reference solution into the "Step i:" format.

        Non-empty lines before the ``#### <answer>`` marker become steps
        ("Step 1: ...", "Step 2: ..."); the text after ``####`` is the gold
        answer.  Lines after the marker are ignored.  Any calculator
        annotations in the text (e.g. ``<<48/2=24>>``) are kept verbatim.

        Raises:
            ValueError: if no ``#### <answer>`` line is found.
        """
        lines: List[str] = []
        gold_answer = ""
        for raw_ln in gold_solution.splitlines():
            ln = raw_ln.strip()
            if not ln:
                continue  # skip empty
            ans_match = self._GSM8K_ANSWER_RE.match(ln)
            if ans_match:
                gold_answer = ans_match.group(1).strip()
                break  # everything after #### is ignored
            lines.append(ln)

        if not gold_answer:
            raise ValueError("Could not find final answer marker '#### <answer>' in gold_solution.")

        solution_steps = [f"Step {i}: {txt}" for i, txt in enumerate(lines, start=1)]
        return {
            "question": question,
            "solution": solution_steps,
            "gold_answer": gold_answer,
        }

    def _extract_answer(self, text: str) -> Optional[str]:
        """Pull a sanitized final answer out of generated text, or None.

        Tried in order: an ``Answer:`` line (ANSWER_PATTERN); the last line
        starting with the word "answer" without a separator; the last
        non-empty line if it contains a digit.
        """
        match = self.ANSWER_PATTERN.search(text)
        if match:
            return _sanitize(match.group(1))

        lines = [ln.strip() for ln in text.splitlines() if ln.strip()]

        # Fallback 1: "Answer 42" / "ANSWER 42" (case-insensitive, word boundary
        # so "Answering ..." is not mistaken for an answer line).
        for line in reversed(lines):
            loose = self._LOOSE_ANSWER_RE.match(line)
            if loose and loose.group(1).strip():
                return _sanitize(loose.group(1))

        # Fallback 2: last non-empty line if it contains a digit
        if lines and re.search(r"\d", lines[-1]):
            return _sanitize(lines[-1])
        return None

    def parse_solution(self, solution_text: str):
        """Split generated text into ``"Step k: ..."`` lines and a final answer.

        Returns:
            ``(steps, answer)``: steps with leading markdown/bullet symbols
            removed, and the extracted answer (None if none was found).
        """
        steps = []
        for raw_line in solution_text.splitlines():
            line = raw_line.strip()
            if line and self.STEP_PATTERN.match(line):
                steps.append(re.sub(r"^[\s>#*\-]+", "", line))
        # Extracted once, outside the loop: the old in-loop call was repeated
        # per line and left `answer` unbound for blank input.
        return steps, self._extract_answer(solution_text)

    # ----------------------------------------------------------- MC rewards
    @staticmethod
    def _next_label(step_idx: int, total_steps: int) -> str:
        """Label that prompts the continuation after step ``step_idx`` (0-based)."""
        return f"Step {step_idx + 2}:" if step_idx < total_steps - 1 else "Answer:"

    @staticmethod
    def _rollout_prompt(sys_prompt: str, question: str, prefix_steps: List[str], next_label: str) -> str:
        """System prompt + problem + the given steps, ending with ``next_label``."""
        return f"{sys_prompt}\n\nProblem: {question}\n" + "\n".join(prefix_steps) + "\n" + next_label

    def _rollout_accuracy(self, prompt_text: str, gold_answer: str) -> float:
        """Fraction of ``config.num_rollouts`` continuations of ``prompt_text``
        whose extracted answer is equivalent to ``gold_answer``."""
        correct_count = 0
        for completion in self._sample(prompt_text, self.config.num_rollouts):
            pred_answer = self._extract_answer(completion)
            if pred_answer is not None and _numeric_equiv(pred_answer, gold_answer):
                correct_count += 1
        return correct_count / float(self.config.num_rollouts)

    def compute_step_rewards(self, question, sys_prompt, steps, gold_answer):
        """
        For each prefix ending at a given step in 'steps', generate rollouts and compute the reward
        (fraction of rollouts ending in the correct answer). Returns a list of reward values corresponding to each step.
        """
        total_steps = len(steps)
        rewards = []
        for i in range(total_steps):
            prompt_text = self._rollout_prompt(
                sys_prompt, question, steps[: i + 1], self._next_label(i, total_steps)
            )
            rewards.append(self._rollout_accuracy(prompt_text, gold_answer))
        return rewards

    # -------------------------------------------------------------- masking
    def model_masking(self, text: str, *, max_new_tokens: int = 64) -> str:
        """Ask the model to replace the crucial parts of ``text`` with ``[MASKED]``.

        Decoding is greedy (``do_sample=False``); passing ``temperature=0.0``
        instead is rejected by ``generate`` when the model's generation config
        enables sampling.  Only the first line of the reply is returned, as the
        prompt requests, so extra chatter cannot add lines to rollout prompts.
        """
        prompt = "In the sentence below, mask any word or expression that seems crucial for solving the math step. This may include key numbers, variables, or action words (like operations), but you should decide what matters. Replace each important item with '[MASKED]'. Keep everything else unchanged. Return ONE line.\n\nSentence: \"{sent}\"\nRewritten:".format(sent=text)
        input_ids, attention_mask = self._tokenize(prompt)
        out_ids = self.model.generate(
            input_ids,
            attention_mask=attention_mask,
            max_new_tokens=max_new_tokens,
            do_sample=False,
            pad_token_id=self.tokenizer.eos_token_id,
        )
        rewritten = self.tokenizer.decode(
            out_ids[0][input_ids.shape[-1]:], skip_special_tokens=True
        ).strip()
        return rewritten.splitlines()[0].strip() if rewritten else ""

    def _mask_step(self, step: str, use_llm: bool) -> str:
        """Mask the body of one step while keeping its ``Step k:`` header."""
        header_match = self._STEP_PREFIX_RE.match(step)
        header = header_match.group(0) if header_match else ""
        body = step[len(header):]  # only the text after the header is masked
        if use_llm:
            masked_body = self.model_masking(body)
        else:
            masked_body = self._MASK_PATTERN.sub("[MASKED]", body)
        return header + masked_body

    def perturbed_step_rewards(self, question: str, sys_prompt: str, steps: List[str], gold_answer: str, use_llm: bool = True) -> List[float]:
        """Compute MC correctness rates *after masking* the current step.

        For each step ``i``, steps ``0..i-1`` are kept intact and step ``i`` is
        replaced by a masked version: its important parts are replaced by
        the literal ``[MASKED]``, chosen by the model (``use_llm=True``) or by
        ``_MASK_PATTERN`` (numbers and simple fractions).  Rollouts then
        continue from the masked prefix exactly as in ``compute_step_rewards``.
        """
        total_steps = len(steps)
        ptb_rewards: List[float] = []
        for i in range(total_steps):
            ptb_prefix_steps = steps[:i] + [self._mask_step(steps[i], use_llm)]
            print("perturbed step:", ptb_prefix_steps)
            prompt_text = self._rollout_prompt(
                sys_prompt, question, ptb_prefix_steps, self._next_label(i, total_steps)
            )
            ptb_rewards.append(self._rollout_accuracy(prompt_text, gold_answer))
        return ptb_rewards

    # ------------------------------------------------------------- datasets
    @staticmethod
    def _make_entry(question, steps, ori_rewards, ptb_rewards, answer, gold_answer, ndigits=None):
        """Assemble one dataset entry; contributions are ``ori - ptb`` per step."""
        ptb_rewards = ptb_rewards[: len(ori_rewards)]  # align lengths (robustness)
        contributions = [o - p for o, p in zip(ori_rewards, ptb_rewards)]
        if ndigits is not None:
            contributions = [round(c, ndigits) for c in contributions]
        return {
            "question": question,
            "completion": steps,             # list[str] (Step i: ...)
            "ori_rewards": ori_rewards,      # list[float]
            "ptb_rewards": ptb_rewards,      # list[float]
            "contributions": contributions,  # ori − ptb
            "answer": answer,
            "gold_answer": gold_answer,
        }

    def build_datasets(self, problems: List):
        """Sample solutions for each problem and score every step.

        Args:
            problems: dicts with ``question`` and ``gold_answer`` keys.

        Sampled solutions without a final answer or without any
        ``Step k:`` line are skipped.  Masking follows ``config.use_llm``.
        """
        dataset = []
        sample_prompt = system_prompt("sample")
        rollout_prompt = system_prompt("rollout")
        for problem in problems:
            question = problem["question"]
            gold_answer = _sanitize(problem["gold_answer"])
            solutions = self.generate_solutions(
                question, sys_prompt=sample_prompt, num_solutions=self.config.samples_per_question
            )
            for sol_text in solutions:
                steps, answer = self.parse_solution(sol_text)
                if answer is None or not steps:
                    continue
                ori_rewards = self.compute_step_rewards(
                    question=question,
                    sys_prompt=rollout_prompt,
                    steps=steps,
                    gold_answer=gold_answer,
                )
                ptb_rewards = self.perturbed_step_rewards(
                    question=question,
                    sys_prompt=rollout_prompt,
                    steps=steps,
                    gold_answer=gold_answer,
                    use_llm=self.config.use_llm,
                )
                dataset.append(
                    self._make_entry(question, steps, ori_rewards, ptb_rewards, answer, gold_answer)
                )
        return dataset

    def build_datasets_gsm8k(self, split: Optional[str] = None, num_problems: Optional[int] = 2):
        """Score the reference solutions of GSM8K problems step by step.

        Args:
            split: dataset split to read (e.g. "train" or "test"); None means "train".
            num_problems: score only the first ``num_problems`` problems
                (default 2, the previous hard-coded limit); None scores the
                whole split.
        """
        dataset = []
        rollout_prompt = system_prompt("rollout")
        ds_full = load_dataset("openai/gsm8k", "main")
        problems = ds_full[split or "train"]
        if num_problems is not None:
            problems = problems.select(range(min(num_problems, len(problems))))

        for problem in tqdm(problems):
            parsed = self.gsm8k_solutions(problem["question"], problem["answer"])
            question = parsed["question"]
            steps = parsed["solution"]
            gold_answer = _sanitize(parsed["gold_answer"])

            ori_rewards = self.compute_step_rewards(question, rollout_prompt, steps, gold_answer)
            ptb_rewards = self.perturbed_step_rewards(question, rollout_prompt, steps, gold_answer, self.config.use_llm)
            print("original rewards:", ori_rewards)
            print("perturbed rewards:", ptb_rewards)
            entry = self._make_entry(
                question, steps, ori_rewards, ptb_rewards, gold_answer, gold_answer, ndigits=4
            )
            print("contributions:", entry["contributions"])
            dataset.append(entry)
        return dataset


# PRMDataset

In [64]:
import random, re
from pathlib import Path
from typing import List, Tuple
import torch
from torch.utils.data import Dataset
from transformers import PreTrainedTokenizer

class StepwisePRMDataset(Dataset):
    """Flatten MC-reward entries into per-step (input_ids, scalar_reward) samples.

    Each entry ``{question, completion: [steps], ori_rewards: [...],
    contributions: [...], ...}`` (as produced by ``MCReward.build_datasets*``)
    yields one sample per step, whose text is the problem followed by every
    step up to and including that step:

        (Problem + Step1,            r1)
        (Problem + Step1 \\n Step2,   r2) ...

    The target ``r_i`` is ``contributions[i]`` when ``use_contr`` is True,
    otherwise ``ori_rewards[i]``.
    NOTE(review): contributions are differences of two success rates and can
    therefore be negative.
    """
    def __init__(
        self,
        entries: List[dict],
        tokenizer: "PreTrainedTokenizer",
        max_length: int = 512,
        use_contr: bool = True,
        *,
        cache_encodings: bool = True,
        preprocess: bool = True,
    ):
        """
        Args:
            entries: dicts with ``question``, ``completion`` (list of step
                strings) and ``contributions`` / ``ori_rewards`` (one float per
                step) depending on ``use_contr``.
            tokenizer: used lazily in ``__getitem__``.
            max_length: every sample is padded / truncated to this many tokens.
            use_contr: train on ``contributions`` (True) or ``ori_rewards`` (False).
            cache_encodings: memoise token ids per text.
            preprocess: collapse all whitespace (newlines included) first.

        Raises:
            ValueError: if an entry's reward list and step list differ in length.
        """
        self.tokenizer   = tokenizer
        self.max_length  = max_length
        self.use_contri  = use_contr
        self.cache       = {} if cache_encodings else None
        self.samples: List[Tuple[str, float]] = []

        reward_key = "contributions" if self.use_contri else "ori_rewards"
        for entry_idx, entry in enumerate(entries):
            steps = entry["completion"]
            rewards = entry[reward_key]
            # zip() below would otherwise silently drop unmatched steps/rewards.
            if len(steps) != len(rewards):
                raise ValueError(
                    f"entry {entry_idx}: {len(steps)} steps but {len(rewards)} {reward_key}"
                )

            prefix_lines = [f"Problem: {entry['question']}"]
            for step_txt, reward in zip(steps, rewards):
                prefix_lines.append(step_txt)
                full_txt = "\n".join(prefix_lines)
                if preprocess:
                    full_txt = self._clean(full_txt)
                self.samples.append((full_txt, float(reward)))   # (text, reward)

    # --------------------------------------------------------------------- utils
    @staticmethod
    def _clean(txt: str) -> str:
        """Collapse every whitespace run (including newlines) to one space and strip."""
        return re.sub(r"\s+", " ", txt).strip()

    # --------------------------------------------------------------------- dunder
    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(input_ids [max_length], reward scalar float32)`` for sample ``idx``."""
        text, reward = self.samples[idx]

        if self.cache is not None and text in self.cache:
            ids = self.cache[text]
        else:
            # NOTE(review): truncation keeps the tokenizer's truncation_side;
            # with right truncation, over-long prefixes lose their newest step.
            ids = self.tokenizer(
                text,
                max_length=self.max_length,
                padding="max_length",
                truncation=True,
                return_tensors="pt",
            ).input_ids.squeeze(0)
            if self.cache is not None:
                self.cache[text] = ids

        return ids, torch.tensor(reward, dtype=torch.float32)
    

# PRM Model

In [65]:
import torch
import torch.nn as nn
from typing import Optional

class ProcessRewardModel(nn.Module):
    """MLP head mapping a feature vector to ``output_size`` scores in (0, 1).

    Architecture: Linear(input_size, hidden_size) -> LayerNorm -> ReLU ->
    Dropout, then ``num_layers - 1`` blocks (Linear -> LayerNorm -> ReLU ->
    Dropout) that each halve the width, then a final Linear and a sigmoid.
    """
    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        output_size: int = 1,
        dropout: float = 0.1,
        num_layers: Optional[int] = None
    ):
        """
        Args:
            input_size (int): Size of input features
            hidden_size (int): Width of the first hidden layer
            output_size (int): Size of output
            dropout (float): Dropout rate
            num_layers (Optional[int]): Number of hidden blocks including the
                input block; None (or 0) means 2.

        Raises:
            ValueError: if ``num_layers`` is negative.
        """
        super().__init__()

        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.dropout_rate = dropout
        self.num_layers = num_layers or 2
        if self.num_layers < 1:
            raise ValueError(f"num_layers must be >= 1, got {num_layers}")

        # Block widths hidden_size, hidden_size/2, hidden_size/4, ... clamped to
        # at least 1 so deep stacks never create zero-width Linear/LayerNorm.
        widths = [max(1, hidden_size // (2 ** k)) for k in range(self.num_layers)]

        # Input layer
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.ln1 = nn.LayerNorm(hidden_size)

        # Hidden layers (same module layout as before, so state_dict keys match)
        hidden_layers = []
        for in_features, out_features in zip(widths[:-1], widths[1:]):
            hidden_layers.extend([
                nn.Linear(in_features, out_features),
                nn.LayerNorm(out_features),
                nn.ReLU(),
                nn.Dropout(dropout)
            ])
        self.hidden_layers = nn.Sequential(*hidden_layers)

        # Output layer
        self.fc_out = nn.Linear(widths[-1], output_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map features ``[..., input_size]`` to sigmoid scores ``[..., output_size]``."""
        # Input layer
        x = self.dropout(torch.relu(self.ln1(self.fc1(x))))
        # Hidden layers
        x = self.hidden_layers(x)
        # Output layer
        return torch.sigmoid(self.fc_out(x))

    def get_complexity(self) -> int:
        """Total number of parameters."""
        return sum(p.numel() for p in self.parameters())
    


# PRMTrainer

In [66]:
from pathlib import Path
from typing import List, Dict
import wandb
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

# Configure logging
# Root logger at INFO level; every record goes both to 'omega_prm.log'
# (relative to the current working directory) and to a StreamHandler.
# NOTE(review): basicConfig does nothing if the root logger already has
# handlers, so re-running this cell will not reconfigure logging.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('omega_prm.log'),
        logging.StreamHandler()
    ]
)
# Module-level logger for this notebook.
logger = logging.getLogger(__name__)

class PRMTrainer:
    """Train a Process Reward Model (PRM) head on top of a frozen LLM backbone.

    Pipeline:
      (1) entries (list[dict]) -> ``StepwisePRMDataset`` -> ``DataLoader``
      (2) frozen backbone -> pooled feature -> ``ProcessRewardModel`` head,
          trained with ``BCELoss`` against the per-step reward targets.

    Only the PRM head is optimised; every backbone parameter has
    ``requires_grad`` disabled.
    """

    def __init__(self, cfg: "PRMConfig", model, tokenizer):
        """
        Args:
            cfg: hyper-parameter container. Reads ``seed``, ``hidden_size``,
                ``learning_rate``, ``checkpoint_dir``, ``use_wandb``,
                ``wandb_project`` and ``run_name``; ``fit`` reads more fields.
            model: backbone LLM used as a frozen feature extractor. It must
                expose ``config.hidden_size``.
            tokenizer: tokenizer matching ``model``.
        """
        self.cfg = cfg
        torch.manual_seed(cfg.seed)

        # ---- Backbone LLM: frozen, used only as a feature extractor
        self.tokenizer = tokenizer
        self.model = model
        self.model.eval()
        for p in self.model.parameters():
            p.requires_grad_(False)

        feat_dim = self.model.config.hidden_size
        self.prm = ProcessRewardModel(feat_dim, hidden_size=cfg.hidden_size, output_size=1)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
        self.prm.to(self.device)

        self.opt = optim.AdamW(self.prm.parameters(), lr=cfg.learning_rate)
        # ProcessRewardModel.forward already ends in a sigmoid, so use plain BCE.
        self.crit = nn.BCELoss()

        self.ckpt_dir = Path(cfg.checkpoint_dir)
        self.ckpt_dir.mkdir(exist_ok=True, parents=True)

        self.wandb_run = None
        if cfg.use_wandb:
            self.wandb_run = wandb.init(
                project=cfg.wandb_project,     # e.g. "omega-prm"
                name=cfg.run_name,             # e.g. "qwen7b-prm"
                # vars(cfg) is empty for class-attribute configs such as PRMConfig,
                # so collect the hyper-parameters explicitly.
                config=self._config_dict(cfg),
            )

    # ----------------------------------------------------------------- config
    @staticmethod
    def _config_dict(cfg) -> Dict[str, Any]:
        """Return every public, non-callable attribute of ``cfg``.

        Works for class-level attributes as well as instance attributes and
        dataclass fields.
        """
        return {
            name: getattr(cfg, name)
            for name in dir(cfg)
            if not name.startswith("_") and not callable(getattr(cfg, name))
        }

    # ----------------------------------------------------------------- features
    @torch.no_grad()
    def _encode(self, ids: torch.Tensor) -> torch.Tensor:
        """Encode token ids ``[B, T]`` into one feature vector per row ``[B, feat_dim]``.

        In a causal LM the hidden state at position 0 only sees the first token.
        The last *non-padding* token is the one that has attended to the whole
        sequence, so that position is pooled here.

        Padding is detected through ``tokenizer.pad_token_id`` when it is set;
        otherwise every token is treated as real. The index computation works for
        both left- and right-padded batches.
        """
        pad_id = getattr(self.tokenizer, "pad_token_id", None)
        if pad_id is None:
            attention_mask = torch.ones_like(ids)
        else:
            attention_mask = (ids != pad_id).long()

        out = self.model(
            input_ids=ids,
            attention_mask=attention_mask,
            return_dict=True,
            output_hidden_states=True,
        )
        last_hidden = out.hidden_states[-1]                       # [B, T, H]
        positions = torch.arange(ids.size(1), device=ids.device)  # [T]
        # Position of the last non-pad token per row (0 for an all-pad row).
        last_idx = (attention_mask * positions).argmax(dim=-1)    # [B]
        batch_idx = torch.arange(ids.size(0), device=ids.device)
        pooled = last_hidden[batch_idx, last_idx]                 # [B, H]
        # Cast so a half-precision backbone can still feed the float32 PRM head.
        return pooled.float()

    # ----------------------------------------------------------------- loop util
    def _run_epoch(self, loader: DataLoader, train: bool, epoch_idx: int) -> float:
        """Run one pass over ``loader``; optimise the PRM head when ``train`` is True.

        Returns:
            The mean per-batch loss, or NaN when ``loader`` yields no batches.
        """
        self.prm.train(train)
        num_batches = len(loader)
        if num_batches == 0:
            logger.warning("Empty %s loader; skipping epoch %d",
                           "train" if train else "val", epoch_idx)
            return float("nan")

        total = 0.0
        for step, (ids, reward) in enumerate(loader):
            ids = ids.to(self.device)
            grad_norm = None

            with torch.set_grad_enabled(train):
                feats = self._encode(ids)
                pred = self.prm(feats).squeeze(-1)
                # BCELoss requires the target dtype to match the prediction dtype.
                target = reward.to(self.device, dtype=pred.dtype)
                loss = self.crit(pred, target)
                if train:
                    self.opt.zero_grad()
                    loss.backward()
                    # clip_grad_norm_ returns the total (pre-clipping) L2 gradient norm.
                    grad_norm = torch.nn.utils.clip_grad_norm_(self.prm.parameters(), 1.0)
                    self.opt.step()

            batch_loss = loss.item()
            total += batch_loss

            # -------- minibatch logging --------
            if self.wandb_run is not None and train:
                self.wandb_run.log({
                    "batch_loss": batch_loss,
                    "epoch": epoch_idx + step / num_batches,
                    "lr": self.opt.param_groups[0]["lr"],
                    "grad_norm": float(grad_norm),
                })

        return total / num_batches

    # ----------------------------------------------------------------- public
    def fit(self, train_entries: List[dict], val_entries: List[dict]) -> Dict[str, List[float]]:
        """Train the PRM head for ``cfg.epochs`` epochs.

        Saves ``best_prm.pt`` whenever the validation loss improves, and saves
        ``last_prm.pt`` after the loop.

        Returns:
            ``{"train": [...], "val": [...]}``: mean losses, one value per epoch.
        """
        ds_args = (self.tokenizer, self.cfg.max_new_tokens, self.cfg.use_contri)
        train_ds = StepwisePRMDataset(train_entries, *ds_args)
        val_ds = StepwisePRMDataset(val_entries, *ds_args)

        loader_kwargs = dict(
            batch_size=self.cfg.batch_size,
            num_workers=self.cfg.num_workers,
            # Pinned host memory is only useful when copying batches to a CUDA device.
            pin_memory=self.device.type == "cuda",
        )
        train_loader = DataLoader(train_ds, shuffle=True, **loader_kwargs)
        val_loader = DataLoader(val_ds, shuffle=False, **loader_kwargs)

        history = {"train": [], "val": []}
        best_val = float("inf")
        vl_loss = float("nan")  # defined even when cfg.epochs == 0

        for ep in range(self.cfg.epochs):
            tr_loss = self._run_epoch(train_loader, train=True, epoch_idx=ep)
            vl_loss = self._run_epoch(val_loader, train=False, epoch_idx=ep)

            # -------- epoch logging --------
            if self.wandb_run is not None:
                self.wandb_run.log({"train_loss": tr_loss, "val_loss": vl_loss, "epoch": ep})

            history["train"].append(tr_loss)
            history["val"].append(vl_loss)
            print(f"[Epoch {ep+1}/{self.cfg.epochs}] train={tr_loss:.4f}  val={vl_loss:.4f}")

            # Keep the checkpoint with the lowest validation loss seen so far.
            if vl_loss < best_val:
                best_val = vl_loss
                self._save_checkpoint("best_prm.pt", epoch=ep, val_loss=vl_loss)

        self._save_checkpoint("last_prm.pt", epoch=self.cfg.epochs - 1, val_loss=vl_loss)
        return history

    # ------------------------------------------------------------------
    # Checkpoint helpers
    def _save_checkpoint(self, filename: str, *, epoch: int, val_loss: float) -> None:
        """Save the PRM head, optimiser state and run metadata to ``ckpt_dir / filename``."""
        path = self.ckpt_dir / filename
        save_dict = {
            "epoch": epoch,
            "val_loss": val_loss,
            "prm_state": self.prm.state_dict(),
            "optimizer_state": self.opt.state_dict(),
            "config": self._config_dict(self.cfg),   # hyper-params for reproducibility
            "model_name_or_path": getattr(self.model, "name_or_path", None),
            "tokenizer_config": self.tokenizer.__dict__.get("init_kwargs", {}),
        }
        torch.save(save_dict, path)
        print(f"[CKPT] Saved ⇒ {path}")

    # ------------------------------------------------------------------
    # @classmethod
    # def load_for_inference(
    #     cls,
    #     checkpoint_path: str | Path,
    #     *,
    #     device: Optional[torch.device] = None,
    #     base_model: Optional[str | "AutoModelForCausalLM"] = None,
    #     tokenizer: Optional["AutoTokenizer"] = None,
    # ) -> "PRMTrainer":
    #     """Instantiate **frozen** backbone + PRM head from a checkpoint for
    #     *inference* (no optimiser).
    #     """
    #     ckpt = torch.load(checkpoint_path, map_location="cpu")
    #     cfg_dict = ckpt.get("config", {})

    #     # rebuild cfg object (simple Namespace‑style fallback)
    #     from types import SimpleNamespace

    #     cfg = SimpleNamespace(**cfg_dict)

    #     # load / reuse backbone + tokenizer
    #     if isinstance(base_model, str) or base_model is None:
    #         base_model_name = base_model or ckpt.get("model_name_or_path")
    #         backbone = AutoModelForCausalLM.from_pretrained(base_model_name)
    #     else:
    #         backbone = base_model

    #     if tokenizer is None:
    #         tokenizer = AutoTokenizer.from_pretrained(backbone.name_or_path)

    #     trainer = cls(cfg, backbone, tokenizer, device=device)
    #     trainer.prm.load_state_dict(ckpt["prm_state"])
    #     trainer.prm.eval()
    #     print(f"[CKPT] Loaded PRM weights from {checkpoint_path}")
    #     return trainer

    # ------------------------------------------------------------------
    # Simple inference helper
    @torch.no_grad()
    def predict_reward(self, text: str) -> float:
        """Score a single text with the PRM head and return a value in [0, 1].

        ``ProcessRewardModel.forward`` already applies a sigmoid, so its output is
        returned as-is. A second sigmoid would squash every score into [0.5, 0.73].
        Dropout is disabled for the call, and the head's previous train/eval mode
        is restored afterwards.
        """
        ids = self.tokenizer.encode(text, return_tensors="pt").to(self.device)
        was_training = self.prm.training
        self.prm.eval()
        try:
            feat = self._encode(ids)
            return float(self.prm(feat).item())
        finally:
            self.prm.train(was_training)


# Main

In [8]:
# Backbone LLM. Other checkpoints tried: "Qwen/Qwen2.5-Math-7B-Instruct",
# "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B", "meta-llama/Llama-3.1-8B".
model_name = "Qwen/Qwen2.5-Math-7B"
model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
tokenizer = AutoTokenizer.from_pretrained(model_name)
cfg = PRMConfig()
print("Finish load model and config!")

# Optional hand-written problems for MCReward.build_datasets:
# problems = [
#     {"question": "Each notebook costs $5. Sarah buys 4 notebooks and pays with a $50 bill. How much change does she get?", "gold_answer": "30"},
#     {"question": "Solve for y: 2y - 7 = 3(y - 4).", "gold_answer": "5"},
#     # Add more problems as needed...
# ]

Sliding Window Attention is enabled but not implemented for `sdpa`; unexpected results may be encountered.
Loading checkpoint shards: 100%|██████████| 4/4 [00:03<00:00,  1.19it/s]


Finish load model and config!


In [68]:
cfg = PRMConfig()
mcr = MCReward(config=cfg, model=model, tokenizer=tokenizer)

# Monte-Carlo reward entries built from the GSM8K training split.
reward_ds_small = mcr.build_datasets_gsm8k(split="train")

# Dump each entry followed by a separator line for manual inspection.
for entry in reward_ds_small:
    print(entry)
    print("-" * 80)


  0%|          | 0/2 [00:00<?, ?it/s]

perturbed step: ['Step 1: Natalia sold 48/2 = [MASKED] clips in May.']
perturbed step: ['Step 1: Natalia sold 48/2 = <<48/2=24>>24 clips in May.', 'Step 2: Natalia sold [MASKED] clips altogether in April and May.']


 50%|█████     | 1/2 [00:35<00:35, 35.32s/it]

original rewards: [1.0, 1.0]
perturbed rewards: [0.0, 0.4]
contributions: [1.0, 0.6]
perturbed step: ['Step 1: "Weng earns [MASKED] = $[MASKED] per minute."']
perturbed step: ['Step 1: Weng earns 12/60 = $<<12/60=0.2>>0.2 per minute.', 'Step 2: "Working 50 minutes, she earned [MASKED] x [MASKED] = [MASKED]."']


100%|██████████| 2/2 [01:12<00:00, 36.06s/it]

original rewards: [0.2, 0.2]
perturbed rewards: [0.0, 0.0]
contributions: [0.2, 0.2]
{'question': 'Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?', 'completion': ['Step 1: Natalia sold 48/2 = <<48/2=24>>24 clips in May.', 'Step 2: Natalia sold 48+24 = <<48+24=72>>72 clips altogether in April and May.'], 'ori_rewards': [1.0, 1.0], 'ptb_rewards': [0.0, 0.4], 'contributions': [1.0, 0.6], 'answer': '72', 'gold_answer': '72'}
--------------------------------------------------------------------------------
{'question': 'Weng earns $12 an hour for babysitting. Yesterday, she just did 50 minutes of babysitting. How much did she earn?', 'completion': ['Step 1: Weng earns 12/60 = $<<12/60=0.2>>0.2 per minute.', 'Step 2: Working 50 minutes, she earned 0.2 x 50 = $<<0.2*50=10>>10.'], 'ori_rewards': [0.2, 0.2], 'ptb_rewards': [0.0, 0.0], 'contributions': [0.2, 0.2], 'answer': '10', 'gold_ans




In [None]:

# Train/validation split of the GSM8K reward entries.
# Hold out ~10% for validation but always keep at least one training entry.
# With a single entry, validation falls back to that same entry, so the
# val loss is then NOT a held-out estimate.
TRAIN_FRACTION = 0.9
n_entries = len(reward_ds_small)
split_idx = min(max(1, int(TRAIN_FRACTION * n_entries)), n_entries)
train_entries = reward_ds_small[:split_idx]
val_entries = reward_ds_small[split_idx:] or reward_ds_small[:1]   # keep at least one val entry

trainer = PRMTrainer(cfg, model=model, tokenizer=tokenizer)
history = trainer.fit(train_entries, val_entries)
print("Training complete. Loss history:", history)

In [None]:
import json, random
from pathlib import Path

def main(model_name: str = "Qwen/Qwen2.5-Math-7B", problems: Optional[List[Dict[str, str]]] = None):
    """Standalone end-to-end run: load the backbone, build MC-reward entries, train the PRM.

    NOTE: inside a Jupyter kernel ``__name__`` is normally ``"__main__"``, so the
    guard below runs this as soon as the cell executes. That loads another copy
    of the backbone next to any already held by the notebook.

    Args:
        model_name: Hugging Face id of the backbone. Also tried:
            "Qwen/Qwen2.5-Math-7B-Instruct".
        problems: list of ``{"question", "gold_answer"}`` dicts. Defaults to
            two toy problems.

    Returns:
        The loss history returned by ``PRMTrainer.fit``.
    """
    model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    cfg = PRMConfig()
    print("Finish load model and config!")

    if problems is None:
        problems = [
            {"question": "A train travels 120 km in 2 hours and then 180 km in 3 hours. What is the average speed of the train?", "gold_answer": "60"},
            {"question": "Solve for y: 2y - 7 = 3(y - 4).", "gold_answer": "5"},
            # Add more problems as needed...
        ]

    mcr = MCReward(config=cfg, model=model, tokenizer=tokenizer)
    entries_raw = mcr.build_datasets(problems)

    # Print or inspect the dataset
    for entry in entries_raw:
        print(entry)
        print("-" * 80)

    # Seeded shuffle so the train/val split is reproducible across runs.
    random.Random(cfg.seed).shuffle(entries_raw)
    # Hold out ~10% for validation but always keep at least one training entry.
    # With a single entry, validation falls back to that same entry.
    n_entries = len(entries_raw)
    split_idx = min(max(1, int(0.9 * n_entries)), n_entries)
    train_entries = entries_raw[:split_idx]
    val_entries = entries_raw[split_idx:] or entries_raw[:1]   # keep at least one val entry

    trainer = PRMTrainer(cfg, model=model, tokenizer=tokenizer)
    history = trainer.fit(train_entries, val_entries)
    print("Training complete. Loss history:", history)
    return history

if __name__ == "__main__":
    main()