# MCTS Implements

In [6]:
import torch
import torch.nn as nn
import torch.optim as optim
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig, PreTrainedTokenizer
from collections import defaultdict, deque
import math
import logging
import json
from typing import Dict, List, Tuple, Optional, Union
from dataclasses import dataclass
import numpy as np
from torch.utils.data import Dataset, DataLoader
import concurrent.futures
from tqdm import tqdm, trange
import wandb
import os
from pathlib import Path
from collections import deque
from torch.utils.data import Dataset
from typing import List, Tuple, Dict, Any, Optional
from collections import Counter
import random
import string
import re

os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"  # Arrange GPU devices starting from 0
os.environ["CUDA_VISIBLE_DEVICES"]= "2"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


## Config

In [7]:
# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('omega_prm.log'),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)

@dataclass
class OmegaPRMConfig:
    """Configuration class for OmegaPRM hyperparameters and settings"""
    # MCTS config
    model_name: str = "Qwen/Qwen2.5-Math-7B"   # "meta-llama/Meta-Llama-3-8B" "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B" "Qwen/Qwen2.5-Math-7B-Instruct"
    search_limit: int = 5 # 100
    max_rollout_tokens: int = 384
    rollout_width: int = 5
    alpha: float = 0.5
    beta: float = 0.9
    L: int = 300 # 500
    cpuct: float = 0.125
    use_mc_reward: bool = True
    reward_threshold: float = 0.7
    # PRM config
    batch_size: int = 32
    learning_rate: float = 0.001
    hidden_size: int = 256
    # max_length: int = 256
    num_workers: int = 4
    use_wandb: bool = False
    checkpoint_dir: str = "checkpoints"
    

## PRM model

In [8]:
class ProcessRewardModel(nn.Module):
    """Enhanced Process Reward Model with dropout and layer normalization"""
    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        output_size: int,
        dropout: float = 0.1,
        num_layers: Optional[int] = None
    ):
        """ 
        Args:
            input_size (int): Size of input features
            hidden_size (int): Size of hidden layers
            output_size (int): Size of output
            dropout (float): Dropout rate
            num_layers (Optional[int]): Number of hidden layers
        """
        super(ProcessRewardModel, self).__init__()
        
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.dropout_rate = dropout
        self.num_layers = num_layers or 2
        
        # Input layer
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.ln1 = nn.LayerNorm(hidden_size)
        
        # Hidden layers
        hidden_layers = []
        for i in range(self.num_layers - 1):
            in_features = hidden_size if i == 0 else hidden_size // (2 ** i)
            out_features = hidden_size // (2 ** (i + 1))
            hidden_layers.extend([
                nn.Linear(in_features, out_features),
                nn.LayerNorm(out_features),
                nn.ReLU(),
                nn.Dropout(dropout)
            ])
        self.hidden_layers = nn.Sequential(*hidden_layers)
        
        # Output layer
        last_hidden_size = hidden_size // (2 ** (self.num_layers - 1))
        self.fc_out = nn.Linear(last_hidden_size, output_size)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x (torch.Tensor): Input tensor
            
        Returns:
            torch.Tensor: Output predictions
        """
        # Input layer
        x = self.dropout(torch.relu(self.ln1(self.fc1(x))))
        
        # Hidden layers
        x = self.hidden_layers(x)
        
        # Output layer
        x = torch.sigmoid(self.fc_out(x))
        return x
    
    def get_complexity(self) -> int:
        """
        Returns:
            int: Total number of parameters
        """
        return sum(p.numel() for p in self.parameters())


## PRM Datasets

In [9]:
class PRMDataset(Dataset):
    """
    • 최소 설정(토큰화만)으로도 바로 학습 가능  
    • 옵션으로 whitespace-정규화 / 소문자화 / 간단한 텍스트 증강 / 인코딩 캐시 지원
    """
    def __init__(
        self,
        solutions: List[str],
        rewards  : List[float],
        tokenizer: PreTrainedTokenizer,
        max_length: int,
        *,
        preprocess: bool = True,
        augment: bool = False,
        augment_prob: float = 0.1,
        cache_encodings: bool = True,
    ):
        assert len(solutions) == len(rewards)
        self.sol, self.r = solutions, rewards                    # 원본 보존
        self.tok  = tokenizer
        self.max  = max_length
        self.preprocess = preprocess
        self.augment    = augment
        self.augment_prob = augment_prob
        self.cache = {} if cache_encodings else None

        # ─ preprocessing ────────────────────────────────────────────────
        self.proc = [self._clean(s) if preprocess else s for s in solutions]

        # ─ augmentation vocabulary (길이별 단어 집합) ────────────────────
        self.vocab_by_len = self._build_vocab() if augment else {}

    # ------------------------------------------------------------------ core
    def __len__(self): return len(self.sol)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        txt = self.proc[idx]
        if self.augment and random.random() < self.augment_prob:
            txt = self._augment(txt)

        if self.cache is not None and txt in self.cache:
            ids = self.cache[txt]
        else:
            ids = self.tok(
                txt,
                max_length=self.max,
                padding="max_length",
                truncation=True,
                return_tensors="pt"
            ).input_ids.squeeze(0)
            if self.cache is not None:
                self.cache[txt] = ids

        return ids, torch.tensor(self.r[idx], dtype=torch.float32)

    # ---------------------------------------------------------------- utils
    @staticmethod
    def _clean(t: str) -> str:
        return re.sub(r"\s+", " ", t.strip().lower())

    def _build_vocab(self) -> Dict[int, List[str]]:
        cnt = Counter(w for s in self.proc for w in s.split())
        v = {}
        for w in cnt:
            v.setdefault(len(w), []).append(w)
        return v

    # － augmentation (세 가지만 간단히) －
    def _augment(self, txt: str) -> str:
        return random.choice([self._swap, self._delete_char, self._insert_char])(txt)

    def _swap(self, t: str) -> str:
        w = t.split(); n = len(w)
        if n < 2: return t
        i = random.randint(0, n-2)
        w[i], w[i+1] = w[i+1], w[i]
        return " ".join(w)

    def _delete_char(self, t: str) -> str:
        if len(t) == 0: return t
        i = random.randint(0, len(t)-1)
        return t[:i] + t[i+1:]

    def _insert_char(self, t: str) -> str:
        i = random.randint(0, len(t))
        c = random.choice(string.ascii_lowercase)
        return t[:i] + c + t[i:]

    # optional helpers --------------------------------------------------
    def get_statistics(self) -> Dict[str, Any]:
        rl = np.array(self.r)
        return {
            "n": len(self),
            "avg_r": rl.mean(),
            "std_r": rl.std(),
            "min_r": rl.min(),
            "max_r": rl.max(),
            "avg_len": np.mean([len(s) for s in self.proc]),
        }


## MCTS

In [10]:
@dataclass
class Node:
    """A node stores the partial solution string and MCTS statistics."""
    state: str                     # concatenated steps so far (can be empty)
    parent: Optional["Node"]
    prior: float                   # optional prior from policy – not used here
    children: Dict[str, "Node"]    # action (next step) → Node

    n_visits: int = 0
    q_value: float = 0.0           # mean rollout success ratio
    correct_rollouts: int = 0       # cumulative successes
    total_rollouts: int = 0         # cumulative rollouts (denominator of MC)

    def ucb(self, cpuct: float, alpha: float, total_parent_visits: int) -> float:
        if self.n_visits == 0:
            return float("inf")  # force unseen nodes to be explored once
        exploration = cpuct * math.sqrt(math.log(total_parent_visits + 1) / (self.n_visits))
        return self.q_value + alpha * exploration
    

In [13]:
class MCTS:
    """MCTS driver for LLM‑based step‑wise mathematical reasoning."""
    STEP_PATTERN = re.compile(r"Step\s+\d+:")
    ANSWER_PATTERN = re.compile(r"Answer\s*:\s*(.+?)\s*(?:$|\n)")

    def __init__(self, config: "OmegaPRMConfig", golden_answers: Dict[str, str]):
        self.config = config
        self.golden_answers = golden_answers
        # Device & model ----------------------------------------------------
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = AutoModelForCausalLM.from_pretrained(config.model_name).to(self.device)
        self.tokenizer = AutoTokenizer.from_pretrained(config.model_name)
        # Generation configs ------------------------------------------------
        # Expansion: one step → we only need ~64 tokens max, sample top-k 8
        self.gen_cfg_expand = GenerationConfig(
            max_new_tokens=128,
            # top_k=5,
            do_sample=True,
            temperature=0.8,
            eos_token_id=self.tokenizer.eos_token_id,
            pad_token_id=self.tokenizer.pad_token_id,
        )
        # Rollout: can be longer; keep top-k large to encourage diversity
        self.gen_cfg_rollout = GenerationConfig( 
            max_new_tokens=config.max_rollout_tokens, 
            # top_k=10,
            do_sample=True,
            temperature=0.8,
            eos_token_id=self.tokenizer.eos_token_id,
            pad_token_id=self.tokenizer.pad_token_id,
        )
        # Root placeholder (empty state).
        self.root = Node(state="", parent=None, prior=0.0, children={})
    
    # ---------------------------------------------------------------------
    # 1‑A. Low‑level helpers
    # ---------------------------------------------------------------------
    def _prompt_expand(self, question: str, partial_solution: str) -> str:
        if partial_solution:
            next_idx = self._next_step_idx(partial_solution)
            system = """You are a math‑problem expert. Generate **exactly one** next step following the numbered format \"Step k: ...\". Do NOT write more than one step or output the final answer directly. Never skip the step number formatand you MUST follow the given format.
            
            ## Example 1 ##
            Problem: Determine the next number in the sequence 2, 4, 8, 16.
            Step 1: Notice each term is obtained by multiplying the previous term by 2.\n
            Expand => "Step 2: Multiply 16 by 2 to get 32."

            ## Example 2 ##
            Problem: Solve for x: 2(x-1)+3=11.
            Step 1: Subtract 3 from both sides to get 2x = 8.
            Step 2: Divide both sides by 2 to find x-1=4.\n
            Expand => "Step 3: Add 1 to both sides to get x = 5.
            """
            return (
                f"{system}\nProblem: {question}\n{partial_solution}\n"
                f"Step {next_idx}:"
            )
        # root
        return (
            """You are a math‑problem expert. Generate **exactly one** first step in the format \"Step 1: ...\". You must follow the format. Do NOT write more than one step or output the final answer directly. Never skip the step number format and you MUST follow the given format.

            ## Example 1 ##
            Problem: Determine the next number in the sequence 2, 4, 8, 16.\n
            Step 1: Notice each term is obtained by multiplying the previous term by 2.

            ## Example 2 ##
            Problem: Solve for x: 2(x-1)+3=11.\n
            Step 1: Subtract 3 from both sides to get 2x = 8.
            """
            f"Problem: {question}\nStep 1:"
        )

    def _prompt_rollout(self, question: str, partial_solution: str) -> str:
        intro = """You are a math‑problem expert. Continue the reasoning from the current step‑by‑step solution. You may write multiple additional steps \"Step k+1: ..., Step k+2:... \" with this format as needed to solve the problem. When the solution is complete, write a **single final line** beginning with \"Answer: \" followed by only the final answer. Do NOT add explanations, extra steps, or any trailing text after you reach the \"Answer: \". Strictly follow the given generation format during step-by-step reasoning.
        
        ## Example 1 ##
        Current solution:
        Problem: Find the sum of the first 8 positive even integers.
        Step 1: The first 8 even integers are 2, 4, 6, 8, 10, 12, 14, 16.\n
        Solution Continuation:
        Step 2: Use the formula for an arithmetic series: S = n·(first + last)/2.
        Step 3: Substitute n=8, first=2, last=16 to get S = 8·(2+16)/2 = 8·9 = 72.
        Answer: 72

        ## Example 2 ##
        Current solution:
        Problem: Solve for x:3x^2-12=0, x>0.
        Step 1: Step 1: Add 12 to both sides to get 3x^2 = 12.
        Step 2: Divide both sides by 3 to get x^2 = 4.\n
        Solution Continuation:
        Step 3: Take the square root of both sides to get x = ±2.
        Step 4: Take a positive value x>0, x=2.
        Answer: 2
        """
        if partial_solution:
            next_idx = self._next_step_idx(partial_solution)
            return (
                f"{intro}\nProblem: {question}\n{partial_solution}\nStep {next_idx}:"
            )
        else:
            return (
                f"{intro}\nProblem: {question}\nStep 1:"
            )
    
    @staticmethod
    def _next_step_idx(solution: str) -> int:
        """Return index of the next step number."""
        matches = list(MCTS.STEP_PATTERN.finditer(solution))
        return len(matches) + 1

    def _extract_answer(self, text: str) -> Optional[str]:
        m = self.ANSWER_PATTERN.search(text)
        return m.group(1).strip() if m else None
    
    # ---------------------------------------------------------------------
    # 1‑B. Tree policy – selection & expansion
    # ---------------------------------------------------------------------
    def _select(self, node: Node) -> Node:
        """Traverse the tree until we hit a leaf (node without children)."""
        while node.children:
            # Choose child with maximal UCB score
            total = max(1, node.n_visits)
            best_action, node = max(
                node.children.items(),
                key=lambda kv: kv[1].ucb(self.config.cpuct, self.config.alpha, total),
            )
        return node
    
    def _split_steps(self, text: str) -> List[str]:
        """Turn 'Step i:' stream into a list of distinct step strings."""
        parts = self.STEP_PATTERN.split(text)
        headers = self.STEP_PATTERN.findall(text)
        steps = [h + p.strip() for h, p in zip(headers, parts[1:])]
        # Ensure each step ends with a newline for readability
        return [s if s.endswith("\n") else s + "\n" for s in steps if s]

    def _expand(self, node: Node, question: str):
        """Generate *top_k* candidate next steps from the language model."""
        if node is None:
            return
        prompt = self._prompt_expand(question, node.state)
        input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(self.device)
        with torch.no_grad():
            outputs = self.model.generate(input_ids=input_ids, **self.gen_cfg_expand.to_dict())
        new_text = self.tokenizer.decode(outputs[0][input_ids.shape[-1]:], skip_special_tokens=True)

        # Split into candidate steps (may generate multiple Step k: blocks)
        steps = self._split_steps(new_text)
        if not steps:
            next_idx = self._next_step_idx(node.state)
            steps = [f"Step {next_idx}: {new_text.strip()}"]

        first_new_child = None
        for step in steps[: self.config.search_limit]:
            if step not in node.children:                 # 새로 본 step
                child_state = f"{node.state}{step}\n"
                child = Node(state=child_state,
                            parent=node,
                            prior=0.0,
                            children={})
                node.children[step] = child
                if first_new_child is None:               # 첫 신규 child 기억
                    first_new_child = child
        
        # print(f"[Expand] Node(partial_solution=\"{node.state}\") -> Generated steps: {steps}")
        # print(f"Children count: {len(node.children)}")
        return first_new_child

    # ---------------------------------------------------------------------
    # 1‑C. Simulation (rollout)
    # ---------------------------------------------------------------------
    def _rollout_from(self, node: Node, question: str) -> float:
        """Perform *rollout_width* simulations and compute average Q‑score."""
        lengths: List[int] = []
        successes = 0
        for _ in range(self.config.rollout_width):
            prompt = self._prompt_rollout(question, node.state)
            ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(self.device)
            with torch.no_grad():
                out = self.model.generate(input_ids=ids, **self.gen_cfg_rollout.to_dict())
            gen_ids = out[0][ids.shape[-1]:]
            lengths.append(len(gen_ids))
            generated = self.tokenizer.decode(gen_ids, skip_special_tokens=True)
            full_solution = node.state + generated
            ans = self._extract_answer(full_solution)
            print(f"[Rollout] Solution:\n{full_solution}\n=> Extracted answer: {ans}")
            gold = self.golden_answers.get(question)
            print("Extracted rollout_answer:", ans)
            if ans is not None and gold is not None and self._compare_answers(ans, gold):
                successes += 1

        # update cumulative MC statistics ----------------------------------
        node.correct_rollouts += successes
        node.total_rollouts += self.config.rollout_width
        mc = node.correct_rollouts / max(1, node.total_rollouts)

        # compute Q‑scores for each rollout length --------------------------
        q_scores = [
            (self.config.alpha ** (1 - mc)) * (self.config.beta ** (l / self.config.L))
            for l in lengths
        ]
        value = sum(q_scores) / len(q_scores)
        print(f"[Rollout] successes: {successes}/{self.config.rollout_width}, mc={mc:.2f}, Q={value:.3f}")
        return value

    @staticmethod
    def _compare_answers(pred: str, gold: str) -> bool: # Loose numeric match – can be improved to exact or symbolic comparison
        try:
            return float(pred) == float(gold)
        except ValueError:
            return pred.strip() == gold.strip()

    # ---------------------------------------------------------------------
    # 1‑D. Back‑propagation
    # ---------------------------------------------------------------------
    def _backprop(self, node: Node, outcome: float):
        while node is not None:
            node.n_visits += 1
            node.q_value += (outcome - node.q_value) / node.n_visits    # Incremental mean update
            # start from leaf
            # temp = node  
            # print(f"[Backprop] Node(partial=\"{temp.state}\") visits={temp.n_visits}, Q={temp.q_value:.3f}")
            node = node.parent

    # ---------------------------------------------------------------------
    # 1‑E. Public interface – run one search
    # ---------------------------------------------------------------------
    def solve(self, question: str, iterations: int = 2) -> Tuple[str, Optional[str], float]:
        """Run MCTS for *iterations* simulations and return best solution."""
        self.root = Node(state="", parent=None, prior=0.0, children={})
        for _ in trange(iterations, desc="MCTS"):
            # 1. Selection
            leaf = self._select(self.root)
            # 2. Expansion
            # self._expand(leaf, question)
            new_child = self._expand(leaf, question)
            # 3. Simulation
            # value = self._rollout_from(leaf, question)
            sim_node = new_child if new_child is not None else leaf
            value = self._rollout_from(sim_node, question)
            # 4. Back‑propagation
            self._backprop(leaf, value)

        # Choose the most visited child of root as final solution path
        if not self.root.children:
            return "", None, 0.0
        best_step, best_child = max(self.root.children.items(), key=lambda kv: kv[1].n_visits)
        
        # Optionally run one deterministic rollout from best_child to get a full solution
        prompt = self._prompt_rollout(question, best_child.state)
        input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(self.device)
        with torch.no_grad():
            out = self.model.generate(input_ids=input_ids, **self.gen_cfg_rollout.to_dict())
        gen = self.tokenizer.decode(out[0][input_ids.shape[-1]:], skip_special_tokens=True)
        full_solution = best_child.state + gen
        answer = self._extract_answer(full_solution)
        return full_solution, answer, best_child    # best_child.q_value

    # ---------------------------------------------------------------------
    # 2. Interface with main function
    # ---------------------------------------------------------------------
    def mcts_for_prm(self, q: str, samples: int = 1) -> Dict[str, List[Dict]]:
        """
        Runs MCTS to collect high-quality solution paths for the given question.
        Returns a dict with the question as key and a list of solution entries (with steps and rewards) as value.
        """
        results = []  # will collect solution entries for this question
        for _ in range(samples):
            s, a, node = self.solve(q)

            # Filter out solutions with incorrect final answers (if golden answer is provided)
            gold_answer = self.golden_answers.get(q, None)
            gold_answer = self._extract_answer(gold_answer)
            if gold_answer is not None:
                if a is None or not (a==gold_answer):
                    continue  # skip this path since final answer is wrong

            # Determine final solution reward based on configuration
            if hasattr(self.config, "use_mc_reward") and not self.config.use_mc_reward:
                # Use value estimate (q_val) as reward if available
                score_value = getattr(node, "q_value", None)
                if score_value is None:
                    score_value = node.correct_rollouts / max(1, node.total_rollouts)
            else:
                # Default: use Monte Carlo success rate
                score_value = node.correct_rollouts / max(1, node.total_rollouts)

            # If no gold answer, apply a quality threshold on the reward (e.g., require high success rate)
            if gold_answer is None and hasattr(self.config, "reward_threshold"):
                if score_value < self.config.reward_threshold:
                    continue  # skip low-quality path

            # Reconstruct the sequence of steps from the root to this final node
            path_nodes = []
            curr = node
            while curr.parent is not None:           # traverse back to root (excluding the root itself)
                path_nodes.append(curr)
                curr = curr.parent
            path_nodes.reverse()                    # now from first step to last step node

            # Split the solution text `s` into individual steps.
            # (Assumes each reasoning step is separated by a newline in `s`.)
            if "\n" in s:
                steps_text = [line.strip() for line in s.splitlines() if line.strip()]
                # If the first line of s was the question/prompt, remove it
                if len(steps_text) > len(path_nodes):
                    steps_text = steps_text[-len(path_nodes):]
            else:
                steps_text = [s]  # if no explicit step separation, treat the whole solution as one step

            # Collect reward for each step node (MC success or q_val as configured)
            step_rewards = []
            for nd in path_nodes:
                if hasattr(self.config, "use_mc_reward") and not self.config.use_mc_reward:
                    step_val = getattr(nd, "q_value", None)
                    if step_val is None:
                        step_val = nd.correct_rollouts / max(1, nd.total_rollouts)
                else:
                    step_val = nd.correct_rollouts / max(1, nd.total_rollouts)
                step_rewards.append(step_val)

            # Ensure the number of steps matches the number of rewards (trim if necessary)
            if len(steps_text) != len(step_rewards):
                min_len = min(len(steps_text), len(step_rewards))
                steps_text = steps_text[:min_len]
                step_rewards = step_rewards[:min_len]

            # Save this solution path entry
            results.append({
                "question": q,
                "completion": steps_text,
                "rewards": step_rewards,
                "answer": gold_answer if gold_answer is not None else a
            })

            print("MCTS for PRM data format", results)

        # Merge duplicate solution paths (average their rewards if seen multiple times)
        merged = {}
        for entry in results:
            # Use tuple of steps as a key for identity of solution path
            key = tuple(entry["completion"])
            if key in merged:
                # Already have this path: average the step-wise rewards
                old = merged[key]
                avg_rewards = [
                    (r_old + r_new) / 2.0 
                    for r_old, r_new in zip(old["rewards"], entry["rewards"])
                ]
                old["rewards"] = avg_rewards
                merged[key] = old
            else:
                merged[key] = entry

        # Return a dictionary with question as key and list of solution entries as value
        return {q: list(merged.values())}

    # -- metrics / export ----------------------------------------------
    def _collect_nodes(self): 
        stack = [self.root]
        while stack:
            n = stack.pop(); yield n; stack.extend(n.children.values())

    def get_metrics(self) -> Dict[str, float]:
        leaves = sum(len(n.children) == 0 for n in self._collect_nodes())
        return {"total_nodes": len(list(self._collect_nodes())), "leaf_nodes": leaves}

    def export_results(self, path: str):
        with open(path, "w") as f: 
            json.dump(self.get_metrics(), f, indent=2)

    def print_tree(self, node: Node, depth: int = 0):
        prefix = "    " * depth
        state_preview = node.state.replace("\n", " / ")  # 줄바꿈을 슬래시로 치환하여 한 줄로 표시
        if len(state_preview) > 60:  # 너무 길면 자르기
            state_preview = state_preview[:57] + "..."
        print(f"{prefix}- Node(depth={depth}, visits={node.n_visits}, Q={node.q_value:.2f}): {state_preview}")
        for child_step, child_node in node.children.items():
            self.print_tree(child_node, depth + 1)


## PRMTrain

In [14]:
class PRMTrainer:
    """
    ①  MCTS 로부터 (solution, reward) 쌍을 수집하고
    ②  Process-Reward Model(PRM)을 학습한다.
    """
    def __init__(self, mcts: MCTS, config: OmegaPRMConfig):
        self.mcts   = mcts
        self.cfg    = config
        self.device = mcts.device
        self.tok    = mcts.tokenizer

        # PRM 자체 초기화
        feat_dim = mcts.model.config.hidden_size        # LLM hidden size
        self.prm = ProcessRewardModel(feat_dim, self.cfg.hidden_size, output_size=1).to(self.device)
        self.opt = optim.AdamW(self.prm.parameters(), lr=self.cfg.learning_rate, weight_decay=0.01)
        # self.crit = nn.BCELoss()
        self.crit = nn.MSELoss()

        # Initialize wandb if enabled
        if self.cfg.use_wandb:
            wandb.init(project="omega-prm", name="prm-train", config=vars(self.cfg))
        # Create checkpoint directory
        Path(self.cfg.checkpoint_dir).mkdir(exist_ok=True)

    # ------------------------------------------------------------------ utils
    @torch.no_grad()
    def _encode_features(self, input_ids: torch.Tensor) -> torch.Tensor:
        """
        LLM 의 hidden state → [CLS-pooling] 식 임베딩.
        input_ids : [B, T]
        return    : [B, feat_dim]
        """
        # emb_layer = self.model.get_input_embeddings()          # shared embedding matrix
        # emb = (emb_layer(token_ids)).mean(dim=1) 
        out = self.mcts.model(input_ids=input_ids,
                              output_hidden_states=True,
                              return_dict=True)
        # 마지막 hidden-state의 0-번 토큰(CLS) 임베딩 사용
        return out.hidden_states[-1][:, 0, :]

    # ------------------------------------------------------ data preparation
    def build_dataset(
        self,
        questions: List[str],
        samples_per_q: int = 1,
        add_question: bool = True          # 프롬프트에 문제문 포함 여부
    ) -> Tuple[Dataset, List[Dict]]:
        """
        Step-wise 데이터셋을 만든다.
        반환: (torch Dataset, [entry…])  entry 는 질문 하나에 대한 원본 구조
        """
        texts, lbls = [], []
        structured  = []

        for q in tqdm(questions, desc="Collecting MCTS data"):
            paths = self.mcts.mcts_for_prm(q, samples=samples_per_q)[q]

            for path in paths:         # path = {"question", "completion", "rewards", …}
                steps   = path["completion"]
                rewards = path["rewards"]

                assert len(steps) == len(rewards)

                # step-wise 분해
                prefix_lines = [f"Problem: {q}"] if add_question else []
                for i in range(len(steps)):
                    prefix_lines.append(steps[i])
                    txt  = "\n".join(prefix_lines)          # 문제+현재까지 스텝
                    score = rewards[i]
                    texts.append(txt)
                    lbls.append(score)

                structured.append(path)     # 진단용

        if len(texts)==0:
            dummy = [{'question': 'What is (5+7)/2 - 3?', 'completion': ['Calculate the contents inside the parenthesis, 5+7 = 12.', 'Divide by 2, which is 12/2=6.', 'Subtract 3 from 6, 6-3=3.'], 'rewards': [0.7, 0.6,0.75], 'answer': '3.'}]
            return dummy, structured
        
        ds = PRMDataset(texts, lbls,
                        tokenizer=self.tok,
                        max_length=self.cfg.max_length)
        
        print("PRMDataset(size={}, avg_len={:.1f})".format(len(ds), sum(len(t.split()) for t in texts)/len(texts)))
        return ds, structured

    # ---------------------------------------------------------- train / valid
    def _run_epoch(self, loader: DataLoader, train: bool) -> float:
        self.prm.train() if train else self.prm.eval()
        tot = 0.0
        for step, (ids, r) in enumerate(loader):
            ids, r = ids.to(self.device), r.to(self.device)
            with torch.set_grad_enabled(train):
                feats = self._encode_features(ids)
                out   = self.prm(feats)
                loss  = self.crit(out, r)
                if train:
                    self.opt.zero_grad()
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(self.prm.parameters(), 1.0)
                    self.opt.step()
            tot += loss.item()

            if self.cfg.use_wandb and train:
                wandb.log({
                    "batch_loss": loss.item(),
                    # "epoch"     : self.cur_epoch,   # train_prm 에서 설정
                    "step"      : step
                })
        return tot / len(loader)

    def train_prm(
        self,
        train_questions: List[str],
        val_questions  : List[str],
        num_epochs: int = 5,
    ) -> Dict[str, List[float]]:
        # 1) 데이터 수집
        train_ds, _ = self.build_dataset(train_questions)
        val_ds,   _ = self.build_dataset(val_questions)
        print("train ds:", train_ds)

        train_loader = DataLoader(train_ds, batch_size=self.cfg.batch_size, shuffle=True, num_workers=self.cfg.num_workers)
        val_loader = DataLoader(val_ds, batch_size=self.cfg.batch_size,shuffle=False, num_workers=self.cfg.num_workers)

        # 2) 학습 loop
        hist = {"train": [], "val": []}
        best = float("inf")
        for ep in range(num_epochs):
            tr = self._run_epoch(train_loader, train=True)
            vl = self._run_epoch(val_loader,   train=False)
            hist["train"].append(tr)
            hist["val"].append(vl)
            print(f"[EP {ep}] train {tr:.4f} | val {vl:.4f}")

            if self.cfg.use_wandb:
                wandb.log({
                    "epoch"     : ep,
                    "train_loss": tr,
                    "val_loss"  : vl,
                })
            if vl < best:
                best = vl
                self.save_checkpoint(ep, vl)
                # torch.save(self.prm.state_dict(), Path(self.cfg.checkpoint_dir) / "best_prm.pt")
        return hist

    # -------------------------------------------------------------- metrics
    def get_metrics(self) -> Dict[str, float]:
        return {
            "params": sum(p.numel() for p in self.prm.parameters()),
        }
    
    # -------------------------------------------------------------- save checkpoints
    def save_checkpoint(self, epoch: int, validation_loss: float):
        """Save model checkpoint"""
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': self.prm.state_dict(),
            'optimizer_state_dict': self.opt.state_dict(),
            'validation_loss': validation_loss,
            'config': self.cfg.__dict__
        }
        path = Path(self.cfg.checkpoint_dir) / f"checkpoint_epoch_{epoch}.pt"
        torch.save(checkpoint, path)
        logger.info(f"Saved checkpoint to {path}")

    def load_checkpoint(self, path: str):
        """Load model checkpoint"""
        try:
            checkpoint = torch.load(path, map_location=self.device)
            self.prm.load_state_dict(checkpoint['model_state_dict'])
            self.opt.load_state_dict(checkpoint['optimizer_state_dict'])
            logger.info(f"Loaded checkpoint from {path}")
            return checkpoint['epoch'], checkpoint['validation_loss']
        except Exception as e:
            logger.error(f"Failed to load checkpoint: {str(e)}")
            raise


## Main

In [15]:
cfg = OmegaPRMConfig(
        use_wandb=False,         # 예시이므로 off
        batch_size=8,
        num_workers=4,
    )
golden = {
    "What is (5+7)/2 - 3?": "3",
    "What is 2 + 2?": "4",
    "A box contains 8 red and 12 blue marbles. If John removes 5 red marbles and then 4 blue marbles, how many marbles remain in the box?": "11",
    "Solve for y: 2y - 7 = 3(y - 4).": "5",

}

mcts = MCTS(cfg, golden)
trainer = PRMTrainer(mcts, cfg)

train_q = ["A box contains 8 red and 12 blue marbles. If John removes 5 red marbles and then 4 blue marbles, how many marbles remain in the box?"]
val_q = ["Solve for y: 2y - 7 = 3(y - 4)."]
tr_ds, tr_st = trainer.build_dataset(train_q)
print(tr_ds)
print("MCTS train print tree")
mcts.print_tree(mcts.root)

val_ds, val_st = trainer.build_dataset(val_q)
print(val_ds)
print("MCTS val print tree")
mcts.print_tree(mcts.root)

train_loader = DataLoader(tr_ds, batch_size=8, shuffle=True, num_workers=4)
val_loader = DataLoader(val_ds, batch_size=8, shuffle=True, num_workers=4)

tr = trainer._run_epoch(train_loader, train=True)
vl = trainer._run_epoch(val_loader,   train=False)
print(f"[EP {1}] train {tr:.4f} | val {vl:.4f}")

Sliding Window Attention is enabled but not implemented for `sdpa`; unexpected results may be encountered.
Loading checkpoint shards: 100%|██████████| 4/4 [00:01<00:00,  2.12it/s]
Collecting MCTS data:   0%|          | 0/1 [00:00<?, ?it/s]The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Both `max_new_tokens` (=128) and `max_length`(=20) seem to have been set. `max_new_tokens` will take precedence. Please refer to the documentation for more information. (https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)
Both `max_new_tokens` (=384) and `max_length`(=20) seem to have been set. `max_new_tokens` will take precedence. Please refer to the documentation for more information. (https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)
Both `max_new_tokens` (=384) and `

[Rollout] Solution:
Step 1: Calculate the total number of marbles initially and then subtract the marbles removed by John.
 Calculate the remaining red and blue marbles separately.

Step 3: Sum the remaining red and blue marbles to get the total number of marbles left in the box.

Answer: The total number of marbles remaining in the box is 11.
```python
# Initial number of red and blue marbles
red_marbles = 8
blue_marbles = 12

# Marbles removed by John
red_marbles_removed = 5
blue_marbles_removed = 4

# Remaining red and blue marbles
remaining_red_marbles = red_marbles - red_marbles_removed
remaining_blue_marbles = blue_marbles - blue_marbles_removed

# Total remaining marbles
total_remaining_marbles = remaining_red_marbles + remaining_blue_marbles
print(total_remaining_marbles)
```
```output
11
```
The total number of marbles remaining in the box is \(\boxed{11}\).
=> Extracted answer: The total number of marbles remaining in the box is 11.
Extracted rollout_answer: The total number 

Both `max_new_tokens` (=384) and `max_length`(=20) seem to have been set. `max_new_tokens` will take precedence. Please refer to the documentation for more information. (https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)


[Rollout] Solution:
Step 1: Calculate the total number of marbles initially and then subtract the marbles removed by John.
 Calculate the number of red marbles left after 5 are removed.

Step 3: Calculate the number of blue marbles left after 4 are removed.

Step 4: Calculate the total number of marbles left in the box.
To solve the problem, we can follow these steps:

1. Calculate the total number of marbles initially in the box.
2. Subtract the number of red marbles removed.
3. Subtract the number of blue marbles removed.
4. Calculate the total number of marbles left in the box.

Let's write the Python code to solve this problem step-by-step.
```python
# Initial number of red and blue marbles
red_marbles = 8
blue_marbles = 12

# Marbles removed by John
red_marbles_removed = 5
blue_marbles_removed = 4

# Calculate the number of red and blue marbles left
red_marbles_left = red_marbles - red_marbles_removed
blue_marbles_left = blue_marbles - blue_marbles_removed

# Calculate the total n

Both `max_new_tokens` (=384) and `max_length`(=20) seem to have been set. `max_new_tokens` will take precedence. Please refer to the documentation for more information. (https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)


[Rollout] Solution:
Step 1: Calculate the total number of marbles initially and then subtract the marbles removed by John.
 Subtract the number of red marbles removed by John from the initial number of red marbles.

Step 3: Subtract the number of blue marbles removed by John from the initial number of blue marbles.

Step 4: Sum the remaining red and blue marbles to get the total number of marbles left in the box.

**Answer:** The final answer is the total number of marbles remaining in the box.
```python
# Initial number of red and blue marbles
initial_red_marbles = 8
initial_blue_marbles = 12

# Marbles removed by John
removed_red_marbles = 5
removed_blue_marbles = 4

# Remaining red and blue marbles after removal
remaining_red_marbles = initial_red_marbles - removed_red_marbles
remaining_blue_marbles = initial_blue_marbles - removed_blue_marbles

# Total remaining marbles
total_remaining_marbles = remaining_red_marbles + remaining_blue_marbles
print(total_remaining_marbles)
```
```ou

Both `max_new_tokens` (=384) and `max_length`(=20) seem to have been set. `max_new_tokens` will take precedence. Please refer to the documentation for more information. (https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)


[Rollout] Solution:
Step 1: Calculate the total number of marbles initially and then subtract the marbles removed by John.
 The initial number of marbles is \(8 + 12 = 20\). After removing 5 red marbles and 4 blue marbles, the remaining number of marbles is \(20 - (5 + 4) = 20 - 9 = 11\).

Answer: 11
=> Extracted answer: 11
Extracted rollout_answer: 11


Both `max_new_tokens` (=128) and `max_length`(=20) seem to have been set. `max_new_tokens` will take precedence. Please refer to the documentation for more information. (https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)


[Rollout] Solution:
Step 1: Calculate the total number of marbles initially and then subtract the marbles removed by John.
 Find the remaining number of red marbles and blue marbles separately.

Step 3: Sum up the remaining red and blue marbles to get the total number of marbles left in the box.

Answer: 11 Let's start by writing the code to solve the problem step by step.

1. Calculate the initial number of marbles.
2. Subtract the marbles removed by John.
3. Sum the remaining marbles.
```python
# Step 1: Initial number of marbles
initial_red_marbles = 8
initial_blue_marbles = 12

# Step 2: Marbles removed by John
removed_red_marbles = 5
removed_blue_marbles = 4

# Step 3: Remaining marbles
remaining_red_marbles = initial_red_marbles - removed_red_marbles
remaining_blue_marbles = initial_blue_marbles - removed_blue_marbles

# Total remaining marbles
total_remaining_marbles = remaining_red_marbles + remaining_blue_marbles
print(total_remaining_marbles)
```
```output
11
```
The total nu

Both `max_new_tokens` (=384) and `max_length`(=20) seem to have been set. `max_new_tokens` will take precedence. Please refer to the documentation for more information. (https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)
Both `max_new_tokens` (=384) and `max_length`(=20) seem to have been set. `max_new_tokens` will take precedence. Please refer to the documentation for more information. (https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)


[Rollout] Solution:
Step 1: Calculate the total number of marbles initially and then subtract the marbles removed by John.
Step 2: Calculate the remaining number of each color marble after removal. Let's break down the problem and solve it using Python to ensure the accuracy of the result.

1. **Initial Setup**:
   - Total red marbles: 8
   - Total blue marbles: 12
   - Total marbles initially: 8 + 12 = 20

2. **Remove Marbles**:
   - Remove 5 red marbles, so remaining red marbles = 8 - 5 = 3
   - Remove 4 blue marbles, so remaining blue marbles = 12 -
 Add the remaining red and blue marbles to get the total number of marbles left in the box.

Here is the Python code to calculate this:
```python
# Initial number of marbles
initial_red_marbles = 8
initial_blue_marbles = 12

# Marbles removed by John
removed_red_marbles = 5
removed_blue_marbles = 4

# Remaining marbles
remaining_red_marbles = initial_red_marbles - removed_red_marbles
remaining_blue_marbles = initial_blue_marbles - remove

Both `max_new_tokens` (=384) and `max_length`(=20) seem to have been set. `max_new_tokens` will take precedence. Please refer to the documentation for more information. (https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)


[Rollout] Solution:
Step 1: Calculate the total number of marbles initially and then subtract the marbles removed by John.
Step 2: Calculate the remaining number of each color marble after removal. Let's break down the problem and solve it using Python to ensure the accuracy of the result.

1. **Initial Setup**:
   - Total red marbles: 8
   - Total blue marbles: 12
   - Total marbles initially: 8 + 12 = 20

2. **Remove Marbles**:
   - Remove 5 red marbles, so remaining red marbles = 8 - 5 = 3
   - Remove 4 blue marbles, so remaining blue marbles = 12 -
 Calculate the total number of marbles remaining in the box.
- Remaining red marbles: 3
- Remaining blue marbles: 12 - 4 = 8
- Total remaining marbles = 3 + 8 = 11

Let's implement this in Python to confirm the result.
```python
# Initial number of red and blue marbles
initial_red_marbles = 8
initial_blue_marbles = 12

# Marbles removed by John
red_marbles_removed = 5
blue_marbles_removed = 4

# Calculate remaining marbles
remaining_red_

Both `max_new_tokens` (=384) and `max_length`(=20) seem to have been set. `max_new_tokens` will take precedence. Please refer to the documentation for more information. (https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)


[Rollout] Solution:
Step 1: Calculate the total number of marbles initially and then subtract the marbles removed by John.
Step 2: Calculate the remaining number of each color marble after removal. Let's break down the problem and solve it using Python to ensure the accuracy of the result.

1. **Initial Setup**:
   - Total red marbles: 8
   - Total blue marbles: 12
   - Total marbles initially: 8 + 12 = 20

2. **Remove Marbles**:
   - Remove 5 red marbles, so remaining red marbles = 8 - 5 = 3
   - Remove 4 blue marbles, so remaining blue marbles = 12 -
 Calculate the total number of marbles remaining in the box.
```python
# Initial number of red and blue marbles
initial_red = 8
initial_blue = 12

# Marbles removed by John
removed_red = 5
removed_blue = 4

# Remaining marbles after removal
remaining_red = initial_red - removed_red
remaining_blue = initial_blue - removed_blue

# Total remaining marbles
total_remaining_marbles = remaining_red + remaining_blue
print(total_remaining_marbles

Both `max_new_tokens` (=384) and `max_length`(=20) seem to have been set. `max_new_tokens` will take precedence. Please refer to the documentation for more information. (https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)


[Rollout] Solution:
Step 1: Calculate the total number of marbles initially and then subtract the marbles removed by John.
Step 2: Calculate the remaining number of each color marble after removal. Let's break down the problem and solve it using Python to ensure the accuracy of the result.

1. **Initial Setup**:
   - Total red marbles: 8
   - Total blue marbles: 12
   - Total marbles initially: 8 + 12 = 20

2. **Remove Marbles**:
   - Remove 5 red marbles, so remaining red marbles = 8 - 5 = 3
   - Remove 4 blue marbles, so remaining blue marbles = 12 -
 Calculate the total number of marbles remaining in the box after John removes the marbles.
\[
\text{Total marbles remaining} = \text{Remaining red marbles} + \text{Remaining blue marbles} = 3 + 8 = 11
\]

Let's confirm this by implementing it in Python.
```python
# Initial number of marbles
initial_red_marbles = 8
initial_blue_marbles = 12

# Marbles removed by John
removed_red_marbles = 5
removed_blue_marbles = 4

# Calculate remaining

MCTS: 100%|██████████| 2/2 [01:48<00:00, 54.27s/it]
Both `max_new_tokens` (=384) and `max_length`(=20) seem to have been set. `max_new_tokens` will take precedence. Please refer to the documentation for more information. (https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)


[Rollout] Solution:
Step 1: Calculate the total number of marbles initially and then subtract the marbles removed by John.
Step 2: Calculate the remaining number of each color marble after removal. Let's break down the problem and solve it using Python to ensure the accuracy of the result.

1. **Initial Setup**:
   - Total red marbles: 8
   - Total blue marbles: 12
   - Total marbles initially: 8 + 12 = 20

2. **Remove Marbles**:
   - Remove 5 red marbles, so remaining red marbles = 8 - 5 = 3
   - Remove 4 blue marbles, so remaining blue marbles = 12 -
 Calculate the remaining number of marbles in the box.
- Remaining marbles = Remaining red marbles + Remaining blue marbles = 3 + 8 = 11

Now let's confirm this with Python code.
```python
# Initial number of red and blue marbles
red_marbles = 8
blue_marbles = 12

# Marbles removed by John
red_removed = 5
blue_removed = 4

# Remaining marbles after removal
remaining_red_marbles = red_marbles - red_removed
remaining_blue_marbles = blue_ma

Collecting MCTS data: 100%|██████████| 1/1 [02:00<00:00, 120.90s/it]


[{'question': 'What is (5+7)/2 - 3?', 'completion': ['Calculate the contents inside the parenthesis, 5+7 = 12.', 'Divide by 2, which is 12/2=6.', 'Subtract 3 from 6, 6-3=3.'], 'rewards': [0.7, 0.6, 0.75], 'answer': '3.'}]
MCTS train print tree
- Node(depth=0, visits=2, Q=0.50): 
    - Node(depth=1, visits=1, Q=0.47): Step 1: Calculate the total number of marbles initially a...
        - Node(depth=2, visits=0, Q=0.00): Step 1: Calculate the total number of marbles initially a...


Collecting MCTS data:   0%|          | 0/1 [00:00<?, ?it/s]Both `max_new_tokens` (=128) and `max_length`(=20) seem to have been set. `max_new_tokens` will take precedence. Please refer to the documentation for more information. (https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)
Both `max_new_tokens` (=384) and `max_length`(=20) seem to have been set. `max_new_tokens` will take precedence. Please refer to the documentation for more information. (https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)
Both `max_new_tokens` (=384) and `max_length`(=20) seem to have been set. `max_new_tokens` will take precedence. Please refer to the documentation for more information. (https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)


[Rollout] Solution:
Step 1: Subtract 7 from both sides and expand the right-hand side to get 2y - 7 = 3y - 12. Then, rearrange the equation to isolate y.
 Simplify the equation from Step 1 to solve for y.

Answer: 

Let's solve the problem step-by-step using Python code to ensure accuracy.
```python
from sympy import symbols, Eq, solve

# Define the variable
y = symbols('y')

# Define the equation
equation = Eq(2 * y - 7, 3 * (y - 4))

# Solve the equation
solution = solve(equation, y)
print(solution)
```
```output
[5]
```
The solution to the equation \(2y - 7 = 3(y - 4)\) is \(y = 5\).

So the final answer is \(\boxed{5}\).
=> Extracted answer: Let's solve the problem step-by-step using Python code to ensure accuracy.
Extracted rollout_answer: Let's solve the problem step-by-step using Python code to ensure accuracy.


Both `max_new_tokens` (=384) and `max_length`(=20) seem to have been set. `max_new_tokens` will take precedence. Please refer to the documentation for more information. (https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)


[Rollout] Solution:
Step 1: Subtract 7 from both sides and expand the right-hand side to get 2y - 7 = 3y - 12. Then, rearrange the equation to isolate y.
 Subtract 2y from both sides to get -7 = y - 12. Then, add 12 to both sides to isolate y.

Answer: 5

Let's solve the problem step by step. The problem is to solve for \( y \) in the equation \( 2y - 7 = 3(y - 4) \).

### Step-by-Step Solution

1. **Initial Equation:** \( 2y - 7 = 3(y - 4) \)
2. **Expand the Right-Hand Side:** \( 2y - 7 = 3y - 12 \)
3. **Rearrange the Equation:** Subtract \( 2y \) from both sides to get \( -7 = y - 12 \). Then, add 12 to both sides to isolate \( y \).

Let's perform these steps in Python to ensure accuracy.
```python
from sympy import symbols, Eq, solve

# Define the variable
y = symbols('y')

# Define the equation
equation = Eq(2*y - 7, 3*(y - 4))

# Solve the equation
solution = solve(equation, y)
print(solution)
```
```output
[5]
```
The solution to the equation \(2y - 7 = 3(y - 4)\) is \( y = 5 \)

Both `max_new_tokens` (=384) and `max_length`(=20) seem to have been set. `max_new_tokens` will take precedence. Please refer to the documentation for more information. (https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)


[Rollout] Solution:
Step 1: Subtract 7 from both sides and expand the right-hand side to get 2y - 7 = 3y - 12. Then, rearrange the equation to isolate y.
 Combine like terms to get a simpler equation.

Step 3: Solve for y.
To solve the equation \(2y - 7 = 3(y - 4)\), we will follow the steps to isolate the variable \(y\).

Step 1: Distribute the 3 on the right-hand side of the equation.
\[2y - 7 = 3y - 12\]

Step 2: Subtract \(2y\) from both sides to get all the \(y\) terms on one side.
\[-7 = y - 12\]

Step 3: Add 12 to both sides to isolate \(y\).
\[5 = y\]

So, the solution to the equation is \(y = 5\).

Let's verify this by substituting \(y = 5\) back into the original equation to ensure it satisfies the equation.

\[2(5) - 7 = 3(5 - 4)\]
\[10 - 7 = 3(1)\]
\[3 = 3\]

The left-hand side equals the right-hand side, confirming that \(y = 5\) is the correct solution. The final answer is:

\(\boxed{5}\)

Let's write the Python code to verify this step-by-step solution.
```python
from sy

Both `max_new_tokens` (=384) and `max_length`(=20) seem to have been set. `max_new_tokens` will take precedence. Please refer to the documentation for more information. (https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)


[Rollout] Solution:
Step 1: Subtract 7 from both sides and expand the right-hand side to get 2y - 7 = 3y - 12. Then, rearrange the equation to isolate y.
 Subtract 2y from both sides to get -7 = y - 12.
Step 3: Add 12 to both sides to get 5 = y.
Final Answer: y = 5
```python
# Given problem: Solve for y: 2y - 7 = 3(y - 4)

# Let's break down the problem step-by-step using Python to verify the solution.

# Given equation: 2y - 7 = 3(y - 4)

# Step 1: Expand the right-hand side
# 2y - 7 = 3y - 12

# Step 2: Rearrange the equation to isolate y
# Subtract 2y from both sides
# -7 = y - 12

# Step 3: Add 12 to both sides
# 5 = y

# Final answer
print(5)
```
```output
5
```
The final answer is \(\boxed{5}\).
=> Extracted answer: y = 5
Extracted rollout_answer: y = 5


Both `max_new_tokens` (=128) and `max_length`(=20) seem to have been set. `max_new_tokens` will take precedence. Please refer to the documentation for more information. (https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)


[Rollout] Solution:
Step 1: Subtract 7 from both sides and expand the right-hand side to get 2y - 7 = 3y - 12. Then, rearrange the equation to isolate y.
 Let's solve the equation step-by-step:

Starting with the equation:
\[ 2y - 7 = 3(y - 4) \]

First, we expand the right-hand side:
\[ 2y - 7 = 3y - 12 \]

Next, we isolate y by moving all terms involving y to one side and the constant terms to the other side:
\[ 2y - 3y = -12 + 7 \]
\[ -y = -5 \]

Then, we solve for y by dividing both sides by -1:
\[ y = 5 \]

So the solution to the equation is:
\[ \boxed{5} \]

Let's confirm this by solving it with Python code.
```python
from sympy import symbols, Eq, solve

# Define the variable
y = symbols('y')

# Define the equation
equation = Eq(2*y - 7, 3*(y - 4))

# Solve the equation
solution = solve(equation, y)
print(solution)
```
=> Extracted answer: None
Extracted rollout_answer: None
[Rollout] successes: 1/5, mc=0.20, Q=0.527


Both `max_new_tokens` (=384) and `max_length`(=20) seem to have been set. `max_new_tokens` will take precedence. Please refer to the documentation for more information. (https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)
Both `max_new_tokens` (=384) and `max_length`(=20) seem to have been set. `max_new_tokens` will take precedence. Please refer to the documentation for more information. (https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)


[Rollout] Solution:
Step 1: Subtract 7 from both sides and expand the right-hand side to get 2y - 7 = 3y - 12. Then, rearrange the equation to isolate y.
Step 1:Distribute the 2 on the right-hand side to get 5z + 3 = 2z + 8. Then, subtract 2z from both sides to combine like terms.

 Subtract 3 from both sides
Step 4: Divide both sides by 3 to get
Answer: 6
Let's solve the problem step-by-step using Python and SymPy.

1. **Set up the equation:** The equation we need to solve is \(2y - 7 = 3y - 12\).
2. **Rearrange the equation to isolate \(y\):** Subtract \(2y\) from both sides to get \(-7 = y - 12\). Then, add 12 to both sides to get \(y = 5\).

Here's the Python code to solve the equation:

```python
import sympy as sp

# Define the variable
y = sp.symbols('y')

# Define the equation
equation = sp.Eq(2*y - 7, 3*y - 12)

# Solve the equation for y
solution = sp.solve(equation, y)

# Print the solution
print(solution[0])
```
```output
5
```
The solution to the equation \(2y - 7 = 3y - 1

Both `max_new_tokens` (=384) and `max_length`(=20) seem to have been set. `max_new_tokens` will take precedence. Please refer to the documentation for more information. (https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)


[Rollout] Solution:
Step 1: Subtract 7 from both sides and expand the right-hand side to get 2y - 7 = 3y - 12. Then, rearrange the equation to isolate y.
Step 1:Distribute the 2 on the right-hand side to get 5z + 3 = 2z + 8. Then, subtract 2z from both sides to combine like terms.

 Subtract 5z from both sides to get z + 3 = 8. Then, subtract 3 from both sides to isolate z.

Step 4: Simplify to get z = 5. Answer: 5
```python
from sympy import symbols, Eq, solve

# Define the variable
y = symbols('y')

# Define the equation 2y - 7 = 3(y - 4)
equation = Eq(2*y - 7, 3*(y - 4))

# Solve the equation
solution = solve(equation, y)
print(solution[0])
```
```output
5
```
The solution to the equation \(2y - 7 = 3(y - 4)\) is \(\boxed{5}\).
=> Extracted answer: 5
Extracted rollout_answer: 5


Both `max_new_tokens` (=384) and `max_length`(=20) seem to have been set. `max_new_tokens` will take precedence. Please refer to the documentation for more information. (https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)


[Rollout] Solution:
Step 1: Subtract 7 from both sides and expand the right-hand side to get 2y - 7 = 3y - 12. Then, rearrange the equation to isolate y.
Step 1:Distribute the 2 on the right-hand side to get 5z + 3 = 2z + 8. Then, subtract 2z from both sides to combine like terms.

 Subtract 5z from both sides to combine like terms.
Step 4: Subtract 8 from both sides.
Answer: To solve for \( y \) in the equation \( 2y - 7 = 3(y - 4) \), we can follow these steps:

1. Distribute the 3 on the right-hand side.
2. Combine like terms.
3. Isolate \( y \).

Let's solve it step-by-step with Python to ensure accuracy.
```python
from sympy import symbols, Eq, solve

# Define the variable
y = symbols('y')

# Define the equation
equation = Eq(2*y - 7, 3*(y - 4))

# Solve the equation
solution = solve(equation, y)
print(solution)
```
```output
[5]
```
The solution to the equation \( 2y - 7 = 3(y - 4) \) is \( y = 5 \).

Therefore, the final answer is:

\(\boxed{5}\)
=> Extracted answer: To solve fo

Both `max_new_tokens` (=384) and `max_length`(=20) seem to have been set. `max_new_tokens` will take precedence. Please refer to the documentation for more information. (https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)


[Rollout] Solution:
Step 1: Subtract 7 from both sides and expand the right-hand side to get 2y - 7 = 3y - 12. Then, rearrange the equation to isolate y.
Step 1:Distribute the 2 on the right-hand side to get 5z + 3 = 2z + 8. Then, subtract 2z from both sides to combine like terms.

 Combine like terms to get 2y = 3y - 5. Subtract 3y from both sides to isolate the variable y.

Here is the code to solve the problem step-by-step using Python and Sympy:

```python
import sympy as sp

# Define the variable
y = sp.symbols('y')

# Define the equation
equation = 2*y - 7 - 3*(y - 4)

# Solve the equation
solution = sp.solve(equation, y)

# Print the solution
print(solution)
```
```output
[5]
```
The solution to the equation \(2y - 7 = 3(y - 4)\) is \(y = 5\).

Thus, the answer is:

\(\boxed{5}\)
=> Extracted answer: None
Extracted rollout_answer: None


MCTS: 100%|██████████| 2/2 [02:06<00:00, 63.44s/it]
Both `max_new_tokens` (=384) and `max_length`(=20) seem to have been set. `max_new_tokens` will take precedence. Please refer to the documentation for more information. (https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)


[Rollout] Solution:
Step 1: Subtract 7 from both sides and expand the right-hand side to get 2y - 7 = 3y - 12. Then, rearrange the equation to isolate y.
Step 1:Distribute the 2 on the right-hand side to get 5z + 3 = 2z + 8. Then, subtract 2z from both sides to combine like terms.

 Subtract 3 from both sides to get 3z = 5. Then, divide both sides by 3 to find z.
Step 4: To verify the correctness of the solution, substitute the value of z back into the original equation and check if both sides are equal.

## Example 5 ##
Current solution:
Problem: Solve for z: 5z + 3 = 2z + 8.
Step 1: Subtract 2z from both sides to get 3z + 3 = 8. Then, subtract 3 from both sides to isolate the term with z.
Step 2: Subtract 3 from both sides to get 3z = 5. Then, divide both sides by 3 to isolate z.
To solve the problem, we'll follow the steps provided and verify the solution. 

1. Start with the equation: 2y - 7 = 3(y - 4)
2. Distribute the 3 on the right-hand side: 2y - 7 = 3y - 12
3. Subtract 2y from

Collecting MCTS data: 100%|██████████| 1/1 [02:18<00:00, 138.24s/it]

[{'question': 'What is (5+7)/2 - 3?', 'completion': ['Calculate the contents inside the parenthesis, 5+7 = 12.', 'Divide by 2, which is 12/2=6.', 'Subtract 3 from 6, 6-3=3.'], 'rewards': [0.7, 0.6, 0.75], 'answer': '3.'}]
MCTS val print tree
- Node(depth=0, visits=2, Q=0.53): 
    - Node(depth=1, visits=1, Q=0.53): Step 1: Subtract 7 from both sides and expand the right-h...
        - Node(depth=2, visits=0, Q=0.00): Step 1: Subtract 7 from both sides and expand the right-h...
        - Node(depth=2, visits=0, Q=0.00): Step 1: Subtract 7 from both sides and expand the right-h...



huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- A

ValueError: too many values to unpack (expected 2)

In [9]:
def main():
    cfg = OmegaPRMConfig(
        use_wandb=True,         # 예시이므로 off
        batch_size=12,
        num_workers=4,
    )

    golden = {
        "What is 2 + 2?": "4",
        "What is (5+7)/2 - 3?": "3",
    }

    # 1) MCTS 초기화
    mcts = MCTS(cfg, golden)

    # 2) PRM trainer
    trainer = PRMTrainer(mcts, cfg)

    # 3) 학습 파이프라인
    train_q = ["What is 2 + 2?"]
    val_q   = ["What is (5+7)/2 - 3?"]
    trainer.train_prm(train_q, val_q, num_epochs=1)

    print("MCTS print tree")
    mcts.print_tree(mcts.root)

    # (선택) MCTS 메트릭·결과 저장
    mcts.export_results("results.json")
    print("MCTS metrics:", mcts.get_metrics())
    print("PRM metrics:",  trainer.get_metrics())

if __name__ == "__main__":
    main()

Loading checkpoint shards: 100%|██████████| 4/4 [00:02<00:00,  1.97it/s]
[34m[1mwandb[0m: Currently logged in as: [33mleena12[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Collecting MCTS data:   0%|          | 0/1 [00:00<?, ?it/s]The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Both `max_new_tokens` (=128) and `max_length`(=20) seem to have been set. `max_new_tokens` will take precedence. Please refer to the documentation for more information. (https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)
Both `max_new_tokens` (=384) and `max_length`(=20) seem to have been set. `max_new_tokens` will take precedence. Please refer to the documentation for more information. (https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)
Both `max_new_tokens` (=384) and `max_length`(=20) seem to have been set. `max_new_tokens` will take precedence. Please refer to the documentation for more information. (https://huggingface.co/docs/transformers/mai

[Rollout] successes: 3/5, mc=0.60, Q=0.745


Both `max_new_tokens` (=384) and `max_length`(=20) seem to have been set. `max_new_tokens` will take precedence. Please refer to the documentation for more information. (https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)
Both `max_new_tokens` (=384) and `max_length`(=20) seem to have been set. `max_new_tokens` will take precedence. Please refer to the documentation for more information. (https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)
Both `max_new_tokens` (=384) and `max_length`(=20) seem to have been set. `max_new_tokens` will take precedence. Please refer to the documentation for more information. (https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)
Both `max_new_tokens` (=384) and `max_length`(=20) seem to have been set. `max_new_tokens` will take precedence. Please refer to the documentation for more information. (https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)
Both

[Rollout] successes: 1/5, mc=0.20, Q=0.572


Collecting MCTS data: 100%|██████████| 1/1 [00:19<00:00, 19.07s/it]
Collecting MCTS data:   0%|          | 0/1 [00:00<?, ?it/s]Both `max_new_tokens` (=128) and `max_length`(=20) seem to have been set. `max_new_tokens` will take precedence. Please refer to the documentation for more information. (https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)
Both `max_new_tokens` (=384) and `max_length`(=20) seem to have been set. `max_new_tokens` will take precedence. Please refer to the documentation for more information. (https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)
Both `max_new_tokens` (=384) and `max_length`(=20) seem to have been set. `max_new_tokens` will take precedence. Please refer to the documentation for more information. (https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)
Both `max_new_tokens` (=384) and `max_length`(=20) seem to have been set. `max_new_tokens` will take precedence. Please refer 

[Rollout] successes: 2/5, mc=0.40, Q=0.602


Both `max_new_tokens` (=384) and `max_length`(=20) seem to have been set. `max_new_tokens` will take precedence. Please refer to the documentation for more information. (https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)
Both `max_new_tokens` (=384) and `max_length`(=20) seem to have been set. `max_new_tokens` will take precedence. Please refer to the documentation for more information. (https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)
Both `max_new_tokens` (=384) and `max_length`(=20) seem to have been set. `max_new_tokens` will take precedence. Please refer to the documentation for more information. (https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)
Both `max_new_tokens` (=384) and `max_length`(=20) seem to have been set. `max_new_tokens` will take precedence. Please refer to the documentation for more information. (https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)
Both

[Rollout] successes: 0/5, mc=0.00, Q=0.452


Collecting MCTS data: 100%|██████████| 1/1 [02:34<00:00, 154.71s/it]

train ds: No dataset is collected.
train loader: <torch.utils.data.dataloader.DataLoader object at 0x7de21c615010>



huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- A

ValueError: too many values to unpack (expected 2)

# GSM8K

In [None]:
# ─── main_gsm8k.py ──────────────────────────────────────────────────────────
from datasets import load_dataset
# from omega_cfg import OmegaPRMConfig
# from mcts_module import MCTS                # 이미 작성된 클래스 import
# from trainer_prm import PRMTrainer          # 개선된 trainer import
# from dataset_prm import PRMDataset
import random, json

def build_golden_from_gsm(split="train", n_samples=2000):
    ds = load_dataset("gsm8k", "main", split=split)
    rows = random.sample(list(ds), n_samples)
    # GSM8K 레코드: {"question": "...", "answer": "#### 42"} (42 가 정답)
    golden = {}
    for r in rows:
        # answer 형식 "#### 42" → "42" 만 추출
        ans = r["answer"].split("####")[-1].strip()
        golden[r["question"]] = ans
    print(len(golden))
    return golden

def main():
    cfg = OmegaPRMConfig(
        use_wandb=True,
        batch_size=16,
        num_workers=4,
        max_length=256,
        rollout_width=10,            # GSM8K는 계산이 복잡하므로 살짝 축소
    )

    golden = build_golden_from_gsm(split="train", n_samples=500)
    questions = list(golden.keys())
    random.shuffle(questions)
    train_q, val_q = questions[:400], questions[400:450]

    mcts = MCTS(cfg, golden)
    trainer = PRMTrainer(mcts, cfg)

    trainer.train_prm(train_q, val_q, num_epochs=3)

    # 저장
    mcts.export_results("gsm_metrics.json")
    with open("gsm_prm_stats.json", "w") as f:
        json.dump(trainer.get_metrics(), f, indent=2)

if __name__ == "__main__":
    main()
