# Utils

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig, PreTrainedTokenizer
from collections import defaultdict, deque
import math
import logging
import json
from typing import Dict, List, Tuple, Optional, Union
from dataclasses import dataclass
import numpy as np
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm, trange
import wandb
import os
from pathlib import Path
from collections import deque
from torch.utils.data import Dataset
from typing import List, Tuple, Dict, Any, Optional
from collections import Counter
import random
import string
import re

os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"  # Arrange GPU devices starting from 0
os.environ["CUDA_VISIBLE_DEVICES"]= "1"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

  from .autonotebook import tqdm as notebook_tqdm


In [6]:
import torch
from transformers import AutoModel, AutoTokenizer
import torch.nn.functional as F

model_name = "Qwen/Qwen2.5-Math-PRM-7B"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModel.from_pretrained(model_name, device_map="auto",  torch_dtype=torch.bfloat16, trust_remote_code=True,).eval()

data = {
    "system": "Please reason step by step, and put your final answer within \\boxed{}.",
    "query": "Sue lives in a fun neighborhood.  One weekend, the neighbors decided to play a prank on Sue. On Friday morning, the neighbors placed 18 pink plastic flamingos out on Sue's front yard. On Saturday morning, the neighbors took back one third of the flamingos, painted them white, and put these newly painted white flamingos back out on Sue's front yard.  Then, on Sunday morning, they added another 18 pink plastic flamingos to the collection. At noon on Sunday, how many more pink plastic flamingos were out than white plastic flamingos?",
    "response": [
      "To find out how many more pink plastic flamingos were out than white plastic flamingos at noon on Sunday, we can break down the problem into steps. First, on Friday, the neighbors start with 18 pink plastic flamingos.",
      "On Saturday, they take back one third of the flamingos. Since there were 18 flamingos, (1/3 \\times 18 = 6) flamingos are taken back. So, they have (18 - 6 = 12) flamingos left in their possession. Then, they paint these 6 flamingos white and put them back out on Sue's front yard. Now, Sue has the original 12 pink flamingos plus the 6 new white ones. Thus, by the end of Saturday, Sue has (12 + 6 = 18) pink flamingos and 6 white flamingos.",
      "On Sunday, the neighbors add another 18 pink plastic flamingos to Sue's front yard. By the end of Sunday morning, Sue has (18 + 18 = 36) pink flamingos and still 6 white flamingos.",
      "To find the difference, subtract the number of white flamingos from the number of pink flamingos: (36 - 6 = 30). Therefore, at noon on Sunday, there were 30 more pink plastic flamingos out than white plastic flamingos. The answer is (\\boxed{30})."
    ]
}


Loading checkpoint shards: 100%|██████████| 4/4 [00:02<00:00,  1.40it/s]
Some weights of the model checkpoint at Qwen/Qwen2.5-Math-PRM-7B were not used when initializing Qwen2ForProcessRewardModel: ['lm_head.weight']
- This IS expected if you are initializing Qwen2ForProcessRewardModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing Qwen2ForProcessRewardModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [None]:
def make_step_rewards(logits, token_masks):
    probabilities = F.softmax(logits, dim=-1)
    probabilities = probabilities * token_masks.unsqueeze(-1) # bs, seq_len, num_labels

    all_scores_res = []
    for i in range(probabilities.size(0)):
        sample = probabilities[i] # seq_len, num_labels [452, 2]
        positive_probs = sample[sample != 0].view(-1, 2)[:, 1] # valid_tokens, num_labels [# of steps, label]
        non_zero_elements_list = positive_probs.cpu().tolist()
        all_scores_res.append(non_zero_elements_list)
    return all_scores_res


messages = [
    {"role": "system", "content": data['system']},
    {"role": "user", "content": data['query']},
    {"role": "assistant", "content": "<extra_0>".join(data['response']) + "<extra_0>"},
]
conversation_str = tokenizer.apply_chat_template(
    messages, 
    tokenize=False, 
    add_generation_prompt=False
)

input_ids = tokenizer.encode(conversation_str, return_tensors="pt").to(model.device)

outputs = model(input_ids=input_ids)

step_sep_id = tokenizer.encode("<extra_0>")[0]
token_masks = (input_ids == step_sep_id)
step_reward = make_step_rewards(outputs[0], token_masks)
print(step_reward)  # [[1.0, 0.1904296875, 0.9765625, 1.0]]

In [9]:
def check_token_learning():
    test_tokens = ["<extra_0>", "<extra_1>", "<extra_2>", "|STEP|"]
    for token in test_tokens:
        try:
            token_id = tokenizer.encode(token)[0]
            print(f"Token '{token}': ID {token_id}")
            if hasattr(model, 'get_input_embeddings'):
                embedding = model.get_input_embeddings()
                token_embedding = embedding(torch.tensor([token_id]))
                print(f"  Embedding shape: {token_embedding.shape}")
                
        except Exception as e:
            print(f"Token '{token}': Error - {e}")

check_token_learning()

Token '<extra_0>': ID 151651
Token '<extra_0>': Error - Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument index in method wrapper_CUDA__index_select)
Token '<extra_1>': ID 27
Token '<extra_1>': Error - Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument index in method wrapper_CUDA__index_select)
Token '<extra_2>': ID 27
Token '<extra_2>': Error - Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument index in method wrapper_CUDA__index_select)
Token '|STEP|': ID 91
Token '|STEP|': Error - Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument index in method wrapper_CUDA__index_select)


# Incorrect Step Dataset

In [24]:
import json
import random
from typing import List, Dict, Any, Tuple
import copy

def create_incorrect_steps(gold_steps: List[str], incorrect_type: str = "wrong_calculation") -> Tuple[List[str], int]:
    if not gold_steps:
        return gold_steps

    perturbed_steps = gold_steps.copy()
    insert_position = random.randint(1, len(gold_steps))   # 1 ≤ pos ≤ len
    
    if incorrect_type == "wrong_calculation":
        wrong_calculations = [
            f"Step {insert_position + 1}: 5 + 3 = 9",  # 5 + 3 = 8
            f"Step {insert_position + 1}: 10 * 2 = 15",  # 10 * 2 = 20
            f"Step {insert_position + 1}: 20 / 4 = 6",  # 20 / 4 = 5
            f"Step {insert_position + 1}: 7 - 3 = 5",  # 7 - 3 = 4
            f"Step {insert_position + 1}: 2^2 = 5",  # 2^2 = 4
            f"Step {insert_position + 1}: sqrt(16) = 3",  # sqrt(16) = 4
            f"Step {insert_position + 1}: 3 * 7 = 24",  # 3 * 7 = 21
            f"Step {insert_position + 1}: 15 / 3 = 6",  # 15 / 3 = 5
        ]
        insert_step = random.choice(wrong_calculations)
    elif incorrect_type == "logical_error":
        logical_errors = [
            f"Step {insert_position + 1}: Since we need to find the total, we should multiply by 0 instead of adding.",
            f"Step {insert_position + 1}: To solve this, we need to subtract the larger number from the smaller one.",
            f"Step {insert_position + 1}: The answer should be negative because we're dealing with positive numbers.",
            f"Step {insert_position + 1}: We can ignore the units since they don't affect the calculation.",
            f"Step {insert_position + 1}: The order of operations doesn't matter here, so we can do addition first.",
        ]
        insert_step = random.choice(logical_errors)
    elif incorrect_type == "irrelevant":
        irrelevant_steps = [
            f"Step {insert_position + 1}: The weather is nice today.",
            f"Step {insert_position + 1}: I like literatures very much.",
            f"Step {insert_position + 1}: This reminds me of my school days.",
            f"Step {insert_position + 1}: The sky is blue and beautiful.",
            f"Step {insert_position + 1}: I should drink more water.",
            f"Step {insert_position + 1}: Python is the language of the universe.",
        ]
        insert_step = random.choice(irrelevant_steps)
    elif incorrect_type == "repetition":
        if len(gold_steps) > 1:
            repeat_step = gold_steps[insert_position - 1]  # 이전 step
            step_num = insert_position + 1
            step_content = repeat_step.split(":", 1)[1] if ":" in repeat_step else ""
            insert_step = f"Step {step_num}:{step_content}"
        else:
            insert_position = 1
            step_num = insert_position + 1
            repeat_step = gold_steps[0]
            step_content = repeat_step.split(":", 1)[1] if ":" in repeat_step else ""
            insert_step = f"Step {step_num}:{step_content}"
    else:
        insert_step = f"Step {insert_position + 1}: 5 + 3 = 9"
    
    perturbed_steps.insert(insert_position, insert_step)

    for i in range(insert_position + 1, len(perturbed_steps)):
            if perturbed_steps[i].startswith("Step "):
                step_num = i + 1
                step_content = perturbed_steps[i].split(":", 1)[1] if ":" in perturbed_steps[i] else ""
                perturbed_steps[i] = f"Step {step_num}:{step_content}"
    
    return perturbed_steps, insert_position

def add_incorrect_completion(entry: Dict[str, Any], incorrect_type: str = "wrong_calculation", negative_reward: float = -1.0) -> Dict[str, Any]:
    new_entry = copy.deepcopy(entry)
    original_completion = entry.get("completion", [])
    # incorrect completion 생성
    incorrect_completion, insert_position = create_incorrect_steps(original_completion, incorrect_type)
    new_entry["completion"] = incorrect_completion
    # incorrect rewards 생성
    for key, val in entry.items():
        if isinstance(val, list) and len(val) == len(original_completion) and key != "completion":
            new_vec = val.copy()
            if key == "contributions":
                new_vec.insert(insert_position, negative_reward)
            elif key == "mi_filtered":
                new_vec.insert(insert_position, 0.0)
            else:
                new_vec.insert(insert_position, negative_reward)
            new_entry[key] = new_vec
    # meta data 추가
    new_entry["is_incorrect"] = True
    new_entry["incorrect_type"] = incorrect_type

    return new_entry

def extend_dataset_with_incorrect_steps(dataset: List[Dict[str, Any]], incorrect_types: List[str] = None,negative_rewards: List[float] = None, ratio: float = 0.5) -> List[Dict[str, Any]]:
    if incorrect_types is None:
        incorrect_types = ["wrong_calculation", "logical_error", "irrelevant_step", "repetition"]
    if negative_rewards is None:
        negative_rewards = [-1.0] * len(incorrect_types)
    
    # negative_rewards를 incorrect_types와 매칭
    if len(negative_rewards) != len(incorrect_types):
        negative_rewards = negative_rewards * (len(incorrect_types) // len(negative_rewards) + 1)
        negative_rewards = negative_rewards[:len(incorrect_types)]
    
    extended_dataset = []
    for entry in dataset:
        extended_dataset.append(entry)
        if random.random() < ratio:
            incorrect_type = random.choice(incorrect_types) # 랜덤하게 incorrect type 선택
            type_idx = incorrect_types.index(incorrect_type)
            negative_reward = negative_rewards[type_idx]
            incorrect_entry = add_incorrect_completion(entry, incorrect_type, negative_reward)
            extended_dataset.append(incorrect_entry)
    
    return extended_dataset



In [25]:
input_path = "/home/leena/ccc_eval/mcts_prm/cmi_samples/total_gsm8k_merge_mistral.json"
output_path = "/home/leena/ccc_eval/mcts_prm/cmi_samples/total_gsm8k_merge_mistral_incorrect.json"

with open(input_path, "r") as file:
    dataset = json.load(file)
print(f"Original dataset size: {len(dataset)}")

extended_dataset = extend_dataset_with_incorrect_steps(dataset=dataset, ratio=0.5)
print(f"Extended dataset size: {len(extended_dataset)}")
print(f"Added {len(extended_dataset) - len(dataset)} incorrect samples")

# 확장된 dataset 저장
with open(output_path, "w") as file2:
    json.dump(extended_dataset, file2, indent=2)

Original dataset size: 7473
Extended dataset size: 11141
Added 3668 incorrect samples


In [26]:
# 통계 출력
correct_count = sum(1 for entry in extended_dataset if not entry.get("is_incorrect", False))
incorrect_count = sum(1 for entry in extended_dataset if entry.get("is_incorrect", False))

print(f"Correct samples: {correct_count}")
print(f"Incorrect samples: {incorrect_count}")

# 타입별 통계
type_counts = {}
for entry in extended_dataset:
    if entry.get("is_incorrect", False):
        incorrect_type = entry.get("incorrect_type", "unknown")
        type_counts[incorrect_type] = type_counts.get(incorrect_type, 0) + 1

print("\nIncorrect types distribution:")
for incorrect_type, count in type_counts.items():
    print(f"  {incorrect_type}: {count}")


Correct samples: 7473
Incorrect samples: 3668

Incorrect types distribution:
  repetition: 911
  wrong_calculation: 928
  irrelevant_step: 923
  logical_error: 906


## Add incorrect MI

In [32]:
import json

def convert_jsonl_to_json(jsonl_file_path, json_file_path):
    json_objects = []
    
    try:
        with open(jsonl_file_path, 'r', encoding='utf-8') as infile:
            for line in infile:
                line = line.strip()
                if line:
                    json_objects.append(json.loads(line))
        
        with open(json_file_path, 'w', encoding='utf-8') as outfile:
            json.dump(json_objects, outfile, ensure_ascii=False, indent=4)
            
        print(f"성공: '{jsonl_file_path}' 파일이 '{json_file_path}' 파일로 성공적으로 변환되었습니다.")
    except FileNotFoundError:
        print(f"오류: 입력 파일 '{jsonl_file_path}'을(를) 찾을 수 없습니다.")
    except json.JSONDecodeError as e:
        print(f"오류: '{jsonl_file_path}' 파일 처리 중 JSON 파싱 오류가 발생했습니다. 오류 내용: {e}")
    except Exception as e:
        print(f"알 수 없는 오류가 발생했습니다: {e}")

input_file_name = '/home/leena/ccc_eval/mcts_prm/cmi_samples/math_mi_mistral_full.jsonl'
output_file_name = '/home/leena/ccc_eval/mcts_prm/cmi_samples/test_json.json'
convert_jsonl_to_json(input_file_name, output_file_name)


성공: '/home/leena/ccc_eval/mcts_prm/cmi_samples/math_mi_mistral_full.jsonl' 파일이 '/home/leena/ccc_eval/mcts_prm/cmi_samples/test_json.json' 파일로 성공적으로 변환되었습니다.


In [33]:
import json, copy, random, re
from typing import List, Dict, Any, Optional
import numpy as np
from collections import Counter

# ======================= Expanded Banks =======================
SELF_REFLECTION_BANK = [
    "Let me think about this problem carefully.",
    "I need to check my calculations.",
    "This step seems important for the solution.",
    "Let me verify the previous step.",
    "I should double-check my work.",
    "This is a crucial part of the solution.",
    "Let me organize my thoughts.",
    "I need to be careful with the math.",
    "I should confirm the units are consistent.",
    "Let me restate the given conditions precisely.",
    "I might have made an algebraic slip; let me re-derive.",
    "I should check edge cases and constraints.",
    "Let me simplify the expression before substituting values.",
    "I should verify if I used the correct formula.",
    "Let me check whether I applied the operation in the right order.",
    "I should compute a quick sanity check with approximations.",
    "Let me recompute using an alternative method to confirm.",
    "I should cross-check with the final requirement of the question.",
    "Let me verify that each transformation is logically valid.",
    "I should test with a small example to validate the pattern.",
]

WRONG_STEP_BANK = [
    # 산술/대수
    "5 + 3 = 9",
    "10 * 2 = 15",
    "20 / 4 = 6",
    "x/2 = 6, so x = 10",
    "2^2 = 5",
    "\\sqrt(36) = 5",
    "Perimeter of square side 5 = 15",
    # 추가: 분수/부호/분배/지수/로그/근사
    "1/3 + 1/6 = 1/9",
    "(-2)^2 = -4",
    "2(x + 3) = 2x + 3",
    "x^a * x^b = x^{a-b}",
    "\\log(ab) = \\log a - \\log b",
    "\\sqrt{a+b} = \\sqrt a + \\sqrt b",
    "1/0 = 0",
    "0^0 = 0",
    "7/10 ≈ 0.9",
    # 기하
    "Area of triangle = base + height",
    "Circumference of a circle with r=3 is 3r",
    "Area of a circle with r=3 is 2\\pi r^2",
    "Pythagorean theorem: a + b = c",
    # 확률/통계
    "P(A \\cap B) = P(A) + P(B)",
    "Variance of cX is Var(X) + c",
    "Mean of [2,4,9] is 6",
    # 미적분
    "d/dx (x^2) = 2",
    "∫ x dx = x^2 + C (missing 1/2)",
    "Derivative of sin x is -cos x",
    "Product rule: (fg)' = f' + g'",
    # 방정식 처리
    "From 2x = 6, x = 2 (dividing by 2 and subtracting 2)",
    "If x/y = 2/3, then x = y (cross-multiplication error)",
]

IRRELEVANT_BANK = [
    "The weather is nice today.",
    "I like mathematics very much.",
    "This reminds me of my school days.",
    "The sky is blue and beautiful.",
    "I should drink more water.",
    "Patterns exist in everything.",
    "I should make a grocery list.",
    "I wonder what to cook for dinner.",
    "This pencil needs sharpening.",
    "I should clean my desk later.",
    "The soundtrack from that movie is stuck in my head.",
    "My cat was very energetic this morning.",
    "I might take a walk after finishing this.",
    "I should reply to that email soon.",
    "I forgot to water the plants yesterday.",
    "I wonder if it's going to rain tomorrow.",
    "I should back up my files.",
    "This coffee tastes a bit strong.",
    "I need to charge my phone.",
]

In [34]:
def renumber_steps(steps: List[str], start_idx: int = 0) -> List[str]:
    new_steps: List[str] = []
    for i in range(len(steps)):
        s = steps[i]
        content = s.split(":", 1)[1] if ":" in s else s
        new_steps.append(f"Step {start_idx + i + 1}:{content}")
    return new_steps

def _is_numeric_list(v: Any, expected_len: int) -> bool:
    if not isinstance(v, list) or len(v) != expected_len:
        return False
    try:
        _ = [float(x) for x in v]
        return True
    except Exception:
        return False

def _sample_perturbation_type(perturbation_probs: Optional[Dict[str, float]] = None, rng: Optional[random.Random] = None,) -> str:
    R = rng or random
    types = ["wrong_step", "irrelevant", "self_reflection"]
    if perturbation_probs:
        # normalize
        items = [(t, max(0.0, float(perturbation_probs.get(t, 0.0)))) for t in types]
        total = sum(w for _, w in items)
        if total <= 0:
            weights = [1/3, 1/3, 1/3]
        else:
            weights = [w/total for _, w in items]
    else:
        weights = [1/3, 1/3, 1/3]
    return random.choices(types, weights=weights, k=1)[0]

def create_perturbed_steps(steps: List[str], typ: str, insert_pos: int, rng: Optional[random.Random] = None) -> (List[str], int):
    R = rng or random
    assert 0 <= insert_pos <= len(steps)
    new_steps = steps.copy()
    if typ == "wrong_step":
        ins = f"Step {insert_pos + 1}: {random.choice(WRONG_STEP_BANK)}"
    elif typ == "irrelevant":
        ins = f"Step {insert_pos + 1}: {random.choice(IRRELEVANT_BANK)}"
    else:
        ins = f"Step {insert_pos + 1}: {random.choice(SELF_REFLECTION_BANK)}"
    new_steps.insert(insert_pos, ins)
    new_steps = renumber_steps(new_steps)
    return new_steps, insert_pos

def _robust_z(x: np.ndarray, clip_z: float = 3.0) -> np.ndarray:
    x = np.asarray(x, dtype=float)
    med = np.median(x)
    mad = np.median(np.abs(x - med)) + 1e-8
    z = (x - med) / (1.4826 * mad)
    if clip_z is not None:
        z = np.clip(z, -clip_z, clip_z)
    return z

def _normalize_signed(x: np.ndarray, tau=1.5, clip_z=3.0, deadzone=0.2):
    z = _robust_z(x, clip_z=clip_z)
    s = np.tanh(z / max(tau, 1e-8))   # [-1,1]
    if deadzone and deadzone > 0.0:
        mag = np.maximum(0.0, np.abs(s) - deadzone) / (1.0 - deadzone)
        s = np.sign(s) * mag
    return s

def recompute_mi_norm_with_ignore(raw: List[float], incorrect_mask: List[int],*, mode: str = "unit",         # "signed" | "unit" | "minmax" | "relu"
    tau: float = 1.5, clip_z: float = 3.0, deadzone: float = 0.2, q_low: float = 5.0, q_high: float = 95.0, round_to: int = 4) -> List[float]:
    x = np.asarray(raw, dtype=float)
    mask = np.asarray(incorrect_mask, dtype=int)
    keep = (mask == 0)

    x_keep = x[keep] if np.any(keep) else x
    if mode == "relu":
        y = np.maximum(x_keep, 0.0)
        out = np.zeros_like(x, dtype=float)
        out[keep] = y
        out[~keep] = 0.0
    elif mode == "signed":
        y = _normalize_signed(x_keep, tau=tau, clip_z=clip_z, deadzone=deadzone)
        out = np.zeros_like(x, dtype=float)
        out[keep] = y
        out[~keep] = -1.0
    elif mode == "unit":
        y = _normalize_signed(x_keep, tau=tau, clip_z=clip_z, deadzone=deadzone)
        y = 0.5 * (y + 1.0)
        out = np.zeros_like(x, dtype=float)
        out[keep] = y
        out[~keep] = 0.0
    elif mode == "minmax":
        lo = np.percentile(x_keep, q_low) if len(x_keep) else np.min(x)
        hi = np.percentile(x_keep, q_high) if len(x_keep) else np.max(x)
        if hi <= lo:
            scale = np.zeros_like(x_keep)
        else:
            scale = np.clip((x_keep - lo) / (hi - lo), 0.0, 1.0)
        out = np.zeros_like(x, dtype=float)
        out[keep] = scale
        out[~keep] = 0.0
    else:
        raise ValueError(f"Unknown mode: {mode}")

    if round_to is not None:
        out = np.round(out.astype(float), round_to)
    return out.tolist()



In [35]:
def inject_incorrect_step( entry: Dict[str, Any],*, min_raw_value: float = -1e6, norm_mode_for_mi_norm: str = "signed", norm_kwargs: Optional[Dict[str, Any]] = None,
    perturbation_probs: Optional[Dict[str, float]] = None, rng: Optional[random.Random] = None,) -> Dict[str, Any]:
    
    R = rng or random
    e = copy.deepcopy(entry)
    steps = e.get("completion", [])
    if not isinstance(steps, list) or len(steps) == 0:
        return e

    L = len(steps)
    insert_pos = random.randint(0, L)
    ptype = _sample_perturbation_type(perturbation_probs)

    # 1) 텍스트 삽입
    new_steps, ins = create_perturbed_steps(steps, ptype, insert_pos)
    e["completion"] = new_steps
    e["perturbation"] = ptype
    e["perturbation_pos"] = ins
    e["incorrect_mask"] = [0] * len(new_steps)
    e["incorrect_mask"][ins] = 1
    if perturbation_probs:
        e["perturbation_probs_used"] = perturbation_probs

    # 2) 모든 수치형 벡터에 min_raw_value 삽입 (mi_norm 제외)
    numeric_keys = []
    for k, v in list(e.items()):
        if k == "mi_norm":
            continue
        if _is_numeric_list(v, L):
            numeric_keys.append(k)

    for k in numeric_keys:
        old = [float(x) for x in e[k]]
        e[k] = old[:ins] + [float(min_raw_value)] + old[ins:]

    # 3) mi_norm 재계산 (주입 스텝 무시하고 통계 산출 → 주입 스텝은 바닥값 강제)
    base_priority = ["mi_loo", "mi_shapley", "mi_margin"]  # "ori_rewards", "contributions", "cmi"
    base_key = next((k for k in base_priority if k in e and _is_numeric_list(e[k], len(new_steps))), None)

    if base_key is not None:
        nk = norm_kwargs or {}
        e["mi_norm"] = recompute_mi_norm_with_ignore(
            e[base_key],
            e["incorrect_mask"],
            mode=norm_mode_for_mi_norm,
            tau=nk.get("tau", 1.5),
            clip_z=nk.get("clip_z", 3.0),
            deadzone=nk.get("deadzone", 0.2),
            q_low=nk.get("q_low", 5.0),
            q_high=nk.get("q_high", 95.0),
            round_to=nk.get("round_to", 4),
        )
        e["norm_mode"] = norm_mode_for_mi_norm
        e["norm_kwargs"] = nk
    return e

def process_dataset_copy(data: List[Dict[str, Any]],*,min_raw_value: float = -1e6, norm_mode_for_mi_norm: str = "signed",
    norm_kwargs: Optional[Dict[str, Any]] = None, perturbation_probs: Optional[Dict[str, float]] = None, seed: int = 42,) -> (List[Dict[str, Any]], Dict[str, Any]):
    
    random.seed(seed)
    out = []
    stats = {"total": 0, "types": Counter(), "positions": Counter()}
    for e in data:
        ne = inject_incorrect_step( e,min_raw_value=min_raw_value, norm_mode_for_mi_norm=norm_mode_for_mi_norm, norm_kwargs=norm_kwargs, perturbation_probs=perturbation_probs,)
        out.append(ne)
        stats["total"] += 1
        stats["types"][ne["perturbation"]] += 1
        stats["positions"][ne["perturbation_pos"]] += 1
    # Counter -> dict
    stats["types"] = dict(stats["types"])
    stats["positions"] = dict(stats["positions"])
    return out, stats


In [36]:
# ======================= File I/O helpers =======================
def print_stats(stats: Dict[str, Any]):
    print(f"Processed entries: {stats.get('total', 0)}")
    print("Type distribution:")
    for k, v in sorted(stats.get("types", {}).items()):
        print(f"  - {k}: {v}")
    print("Insert position distribution (0-based after renumbering):")
    for k, v in sorted(stats.get("positions", {}).items(), key=lambda x: int(x[0])):
        print(f"  - pos {k}: {v}")

def process_file_copy(input_path: str, output_path: str, **kwargs):
    with open(input_path, "r", encoding="utf-8") as f:
        data = json.load(f)
    assert isinstance(data, list), "Top-level JSON must be a list"
    new_data, stats = process_dataset_copy(data, **kwargs)
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(new_data, f, ensure_ascii=False, indent=2)
    print_stats(stats)

In [31]:
process_file_copy(
    input_path="/home/leena/ccc_eval/mcts_prm/cmi_samples/test_json.json",
    output_path="/home/leena/ccc_eval/mcts_prm/cmi_samples/test_incorr.json",
    min_raw_value=-1e6,                          # -5 대신 더 강하게도 OK. (정규화 통계에서 무시하므로 안전)
    norm_mode_for_mi_norm="unit",                # "signed" | "unit" | "minmax" | "relu"
    norm_kwargs={"tau":1.5, "clip_z":3.0, "deadzone":0.2, "q_low":5, "q_high":95, "round_to":4},
    perturbation_probs={"wrong_step":0.4, "irrelevant":0.4, "self_reflection":0.2},
)

Processed entries: 2685
Type distribution:
  - irrelevant: 1116
  - self_reflection: 500
  - wrong_step: 1069
Insert position distribution (0-based after renumbering):
  - pos 0: 642
  - pos 1: 642
  - pos 2: 666
  - pos 3: 396
  - pos 4: 194
  - pos 5: 97
  - pos 6: 29
  - pos 7: 15
  - pos 8: 2
  - pos 9: 2


In [3]:
import os
import sys, json
from transformers import (
    AutoTokenizer, AutoModelForCausalLM,
)

module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)

from prm_training.dataset import PRMDataset, PRMPackCollator
from prm_training.wrapper import PRMRewardWrapper

In [4]:
def debug_dataset(dataset: PRMDataset, tokenizer, num_samples=3):
    print("--- Dataset Debug Start ---")
    for i in range(min(num_samples, len(dataset))):
        sample = dataset[i]
        print(f"\n--- Sample {i} ---")
        
        decoded_text = tokenizer.decode(sample['input_ids'], skip_special_tokens=False)
        print(f"Decoded Text:\n{decoded_text}")
        
        rw_positions = sample['rw_positions']
        targets = sample['targets']
        
        print(f"\nRW Positions: {rw_positions}")
        print(f"Targets: {targets}")
        
        # 각 RW 토큰 위치와 보상 값을 확인
        tokens = tokenizer.convert_ids_to_tokens(sample['input_ids'])
        for pos, target in zip(rw_positions, targets):
            print(f"  - Pos {pos}, Token '{tokens[pos]}', Target Reward: {target:.4f}")
    print("\n--- Dataset Debug End ---")

def debug_schema(ds, nprint=3):
        n = len(ds)
        bad = []
        for i in range(n):
            s = ds[i]
            if "rw_positions" not in s or "targets" not in s:
                bad.append(i)
        print(f"[schema] total={n}, pack_ok={n-len(bad)}, missing_pack_keys={len(bad)}")
        if bad[:nprint]:
            print("examples of bad idx:", bad[:nprint])
            for j in bad[:nprint]:
                print(ds[j].keys())

tok = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-Math-7B", trust_remote_code=True, use_fast=False)
data_path = "/home/leena/ccc_eval/mcts_prm/cmi_samples/test_incorr.json"
with open(data_path, "r", encoding="utf-8") as f:
    entries = json.load(f)

full_ds = PRMDataset(entries, tok, reward_key="mi_loo", max_length=512)
debug_dataset(full_ds, tok, num_samples=3)

--- Dataset Debug Start ---

--- Sample 0 ---
Decoded Text:
You are a math expert. Solve the problem step by step.

Problem: Let \[f(x) = \left\{
\begin{array}{cl} ax+3, &\text{ if }x>2, \\
x-5 &\text{ if } -2 \le x \le 2, \\
2x-b &\text{ if } x <-2.
\end{array}
\right.\]Find $a+b$ if the piecewise function is continuous (which means that its graph can be drawn without lifting your pencil from the paper).
Step 1: For the piecewise function to be continuous, the cases must "meet" at $2$ and $-2$
<RW>
Step 2: For example, $ax+3$ and $x-5$ must be equal when $x=2$
<RW>
Step 3: This implies $a(2)+3=2-5$, which we solve to get $2a=-6 \Rightarrow a=-3$
<RW>
Step 4: Similarly, $x-5$ and $2x-b$ must be equal when $x=-2$
<RW>
Step 5: Substituting, we get $-2-5=2(-2)-b$, which implies $b=3$
<RW>
Step 6: d/dx (x^2) = 2
<RW>
Step 7: So $a+b=-3+3=$
<RW><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|end

In [5]:
full_ds[0]

{'input_ids': tensor([  2610,    525,    264,   6888,   6203,     13,  63284,    279,   3491,
           3019,    553,   3019,    382,  31198,     25,   6771,   1124,     58,
             69,   2075,      8,    284,   1124,   2359,     59,    515,     59,
           7265,     90,   1653,  15170,    564,     92,   3859,     10,     18,
             11,    609,     59,   1318,     90,    421,    335,     87,     29,
             17,     11,  90155,     87,     12,     20,    609,     59,   1318,
             90,    421,    335,    481,     17,   1124,    273,    856,   1124,
            273,    220,     17,     11,  90155,     17,     87,   1455,    609,
             59,   1318,     90,    421,    335,    856,   9119,     17,    624,
             59,    408,     90,   1653,    532,     59,   1291,   7110,     60,
           9885,    400,     64,  35093,      3,    421,    279,   6573,   4482,
            729,    374,  19259,    320,   8206,   3363,    429,   1181,   4771,
            646

In [None]:
from transformers import AutoTokenizer
import torch

model_dir = "/path/to/final_model"
tok = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True, use_fast=False)

prm = PRMRewardWrapper.from_pretrained(
    model_dir,
    tokenizer=tok,                 # 있으면 rw_token_id 자동 설정
    device_map="auto",
    torch_dtype=torch.bfloat16,
)

# 예시 입력(이미 각 step 뒤에 <RW>가 삽입된 텍스트라고 가정)
text = "Problem: ...\nStep 1: ...\n<RW>\nStep 2: ...\n<RW>\n"
enc = tok(text, return_tensors="pt")
for k in enc:
    enc[k] = enc[k].to(prm.backbone.device)

rewards = prm.predict_rewards_at_rw(enc["input_ids"], enc["attention_mask"])  # list[Tensor]
print([r.detach().float().cpu().tolist() for r in rewards])  # 각 <RW> 위치의 스칼라 보상들

# Trainer Fintuning

In [1]:
import argparse
import json
import math
import os
from dataclasses import dataclass
from typing import List, Tuple, Dict, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from tqdm.auto import tqdm
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    AutoModel,
    get_linear_schedule_with_warmup,
    BitsAndBytesConfig,
    TrainingArguments,
    Trainer,
)
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
import random
import string
import re

os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"  # Arrange GPU devices starting from 0
os.environ["CUDA_VISIBLE_DEVICES"]= "3"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

  from .autonotebook import tqdm as notebook_tqdm


## Load pretrained model

In [2]:
class FTPRM(nn.Module):
    def __init__(self, base_model_name: str, lora_rank: int = 16, lora_alpha: int = 32):
        super().__init__()
        
        self.backbone = AutoModel.from_pretrained(
            base_model_name,
            device_map="auto",
            trust_remote_code=True,
            quantization_config=BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_use_double_quant=True,
                bnb_4bit_compute_dtype=torch.bfloat16
            ),
        )

        if hasattr(self.backbone, "score"):
            # For Qwen2.5-Math-PRM-7B
            in_feat = self.backbone.score[0].in_features
            self.backbone.score = nn.Sequential(
                nn.Linear(in_feat, in_feat),
                nn.ReLU(),
                nn.Linear(in_feat, 1, bias=True)  # 2 → 1
            )
            self.reg_head = None
        else:
            # Other AutoModel(Causal LM)
            hidden = self.backbone.config.hidden_size
            self.reg_head = nn.Sequential(
                nn.Linear(hidden, hidden // 4),
                nn.ReLU(),
                nn.Dropout(0.1),
                nn.Linear(hidden // 4, 1)
            )

        # Add Lora Adapter
        self.backbone = prepare_model_for_kbit_training(self.backbone)
        lora_cfg = LoraConfig(
            r=lora_rank,
            lora_alpha=lora_alpha,
            target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
            lora_dropout=0.05,
            bias="none",
        )
        self.backbone = get_peft_model(self.backbone, lora_cfg)
        self._activate_head_params()   

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, labels=None):
        out = self.backbone(input_ids=input_ids,
                            attention_mask=attention_mask,
                            output_hidden_states=True,
                            return_dict=True)
        hidden = out.hidden_states[-1]                     # (B, L, H)

        # Last token vector
        if attention_mask is None:   
            rep = hidden[:, -1, :]
        else:
            seq_len = attention_mask.sum(1) - 1           # (B,)
            rep = hidden[torch.arange(hidden.size(0), device=hidden.device), seq_len, :]

        # head 통과
        if self.reg_head is None:
            pred = self.backbone.score(rep).squeeze(-1)
        else:
            pred = self.reg_head(rep).squeeze(-1)

        if labels is not None:              # training / eval
            loss = F.mse_loss(pred, labels.float())
            return loss, pred   
        else:                               # pure inference
            return pred

    def _activate_head_params(self):
        if self.reg_head is not None:
            for p in self.reg_head.parameters():
                p.requires_grad_(True)
        else:
            for p in self.backbone.score.parameters():
                p.requires_grad_(True)

    def get_trainable_parameters(self):
        return [p for p in self.parameters() if p.requires_grad]
    
    def get_parameter_stats(self):
        trainable_params = 0
        all_param = 0
        module_stats = {}
        
        for name, param in self.named_parameters():
            all_param += param.numel()
            if param.requires_grad:
                trainable_params += param.numel()
                
                module_name = name.split('.')[0]
                if module_name not in module_stats:
                    module_stats[module_name] = {'trainable': 0, 'total': 0}
                module_stats[module_name]['trainable'] += param.numel()
                module_stats[module_name]['total'] += param.numel()
            else:
                module_name = name.split('.')[0]
                if module_name not in module_stats:
                    module_stats[module_name] = {'trainable': 0, 'total': 0}
                module_stats[module_name]['total'] += param.numel()
        
        return {
            'total_params': all_param,
            'trainable_params': trainable_params,
            'trainable_ratio': trainable_params / all_param * 100,
            'module_stats': module_stats
        }


In [3]:
model_name = "Qwen/Qwen2.5-Math-PRM-7B"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id

model = FTPRM(base_model_name=model_name)
model.to(device)
# model.get_trainable_parameters()

Loading checkpoint shards: 100%|██████████| 4/4 [00:22<00:00,  5.72s/it]
Some weights of the model checkpoint at Qwen/Qwen2.5-Math-PRM-7B were not used when initializing Qwen2ForProcessRewardModel: ['lm_head.weight']
- This IS expected if you are initializing Qwen2ForProcessRewardModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing Qwen2ForProcessRewardModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


FTPRM(
  (backbone): PeftModel(
    (base_model): LoraModel(
      (model): Qwen2ForProcessRewardModel(
        (model): Qwen2Model(
          (embed_tokens): Embedding(152064, 3584)
          (layers): ModuleList(
            (0-27): 28 x Qwen2DecoderLayer(
              (self_attn): Qwen2SdpaAttention(
                (q_proj): lora.Linear4bit(
                  (base_layer): Linear4bit(in_features=3584, out_features=3584, bias=True)
                  (lora_dropout): ModuleDict(
                    (default): Dropout(p=0.05, inplace=False)
                  )
                  (lora_A): ModuleDict(
                    (default): Linear(in_features=3584, out_features=16, bias=False)
                  )
                  (lora_B): ModuleDict(
                    (default): Linear(in_features=16, out_features=3584, bias=False)
                  )
                  (lora_embedding_A): ParameterDict()
                  (lora_embedding_B): ParameterDict()
                  (lora_magnitude_

In [6]:
from transformers import BitsAndBytesConfig
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

# PRM Custom Class
class FTLM(nn.Module):
    def __init__(self, base_model_name: str, lora_rank: int = 16, lora_alpha: int = 32, mlp_ratio: int = 4, value_head_prefix: str = "value_head", 
                 normalize_reward: bool = False):
        super().__init__()
        
        # Use AutoModelForCausalLM for pure CausalLM fine-tuning
        self.backbone = AutoModelForCausalLM.from_pretrained(
            base_model_name,
            device_map="auto",
            trust_remote_code=True,
            quantization_config=BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_use_double_quant=True,
                bnb_4bit_compute_dtype=torch.bfloat16
            ),
        )

        # Add Lora Adapter for efficient fine-tuning
        self.backbone = prepare_model_for_kbit_training(self.backbone)
        lora_cfg = LoraConfig(
            r=lora_rank,
            lora_alpha=lora_alpha,
            target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
            lora_dropout=0.05,
            bias="none",
        )
        self.backbone = get_peft_model(self.backbone, lora_cfg)

        # Value head for reward prediction
        hidden = self.backbone.config.hidden_size
        mlp_hidden = hidden // mlp_ratio
        head = nn.Sequential(
            nn.Linear(hidden, mlp_hidden, bias=False),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(mlp_hidden, 1, bias=False),
        )
        self.value_head_prefix = value_head_prefix
        setattr(self, value_head_prefix, head)

        # head 가중치 학습 가능하도록 보장
        for p in head.parameters():
            p.requires_grad_(True)
        
        self.normalize_reward = normalize_reward
        self.register_buffer("mean", torch.zeros(1), persistent=False)
        self.register_buffer("std",  torch.ones(1),  persistent=False)

        # 캐시 비활성 (gradient checkpointing 호환)
        self.backbone.config.use_cache = False

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, labels=None, return_hidden: bool = False):
        if attention_mask is None:
            attention_mask = (input_ids != self.backbone.config.pad_token_id).long()

        # position_ids = cumulative mask
        position_ids = attention_mask.long().cumsum(-1) - 1
        position_ids.masked_fill_(attention_mask == 0, 0)

        outputs = self.backbone(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_hidden_states=True,
            use_cache=False,
        )

        last_hidden = outputs.hidden_states[-1]        # (B, T, H)
        # Index of last non-pad token  → (B, 1)
        eos_idx = attention_mask.size(1) - 1 - attention_mask.long().fliplr().argmax(-1, keepdim=True)

        values = getattr(self, self.value_head_prefix)(last_hidden).squeeze(-1)   # (B, T)
        reward = values.gather(1, eos_idx).squeeze(1)                             # (B,)

        if labels is not None:
            loss = F.mse_loss(reward, labels.float())
            return loss, reward
        else:
            # if (not self.training) and self.normalize_reward:
            #     reward = (reward - self.mean) / (self.std + 1e-8)
            return (reward, last_hidden) if return_hidden else reward

    def get_trainable_parameters(self):
        return [p for p in self.parameters() if p.requires_grad]
    
    def get_parameter_stats(self):
        trainable_params = 0
        all_param = 0
        module_stats = {}
        
        for name, param in self.named_parameters():
            all_param += param.numel()
            if param.requires_grad:
                trainable_params += param.numel()
                
                module_name = name.split('.')[0]
                if module_name not in module_stats:
                    module_stats[module_name] = {'trainable': 0, 'total': 0}
                module_stats[module_name]['trainable'] += param.numel()
                module_stats[module_name]['total'] += param.numel()
            else:
                module_name = name.split('.')[0]
                if module_name not in module_stats:
                    module_stats[module_name] = {'trainable': 0, 'total': 0}
                module_stats[module_name]['total'] += param.numel()
        
        return {
            'total_params': all_param,
            'trainable_params': trainable_params,
            'trainable_ratio': trainable_params / all_param * 100,
            'module_stats': module_stats
        }


In [7]:
model_name = "Qwen/Qwen2.5-Math-7B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
if tokenizer.pad_token_id is None:
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.pad_token_id = tokenizer.eos_token_id

model = FTLM(base_model_name=model_name, value_head_prefix="value_head")
model.to(device)

stats = model.get_parameter_stats()
print(f"Total parameters: {stats['total_params']:,}")
print(f"Trainable parameters: {stats['trainable_params']:,}")
print(f"Trainable ratio: {stats['trainable_ratio']:.2f}%")
print(model)

Loading checkpoint shards: 100%|██████████| 4/4 [00:15<00:00,  3.93s/it]


Total parameters: 4,366,276,992
Trainable parameters: 13,304,704
Trainable ratio: 0.30%
FTLM(
  (backbone): PeftModel(
    (base_model): LoraModel(
      (model): Qwen2ForCausalLM(
        (model): Qwen2Model(
          (embed_tokens): Embedding(152064, 3584)
          (layers): ModuleList(
            (0-27): 28 x Qwen2DecoderLayer(
              (self_attn): Qwen2Attention(
                (q_proj): lora.Linear4bit(
                  (base_layer): Linear4bit(in_features=3584, out_features=3584, bias=True)
                  (lora_dropout): ModuleDict(
                    (default): Dropout(p=0.05, inplace=False)
                  )
                  (lora_A): ModuleDict(
                    (default): Linear(in_features=3584, out_features=16, bias=False)
                  )
                  (lora_B): ModuleDict(
                    (default): Linear(in_features=16, out_features=3584, bias=False)
                  )
                  (lora_embedding_A): ParameterDict()
               

## Prepare dataset

In [5]:
@dataclass
class PRMCollator:
    tokenizer: AutoTokenizer
    pad_to_multiple_of: Optional[int] = 8

    def __call__(self, batch):
        input_ids, rewards = zip(*batch)
        lengths = [len(ids) for ids in input_ids]
        max_len = max(lengths)
        if self.pad_to_multiple_of:
            max_len = int(math.ceil(max_len / self.pad_to_multiple_of) * self.pad_to_multiple_of)

        padded = [
            torch.cat([ids, ids.new_full((max_len - len(ids),), self.tokenizer.pad_token_id)])
            for ids in input_ids
        ]
        input_ids = torch.stack(padded)
        attention_mask = (input_ids != self.tokenizer.pad_token_id).long()
        rewards = torch.tensor(rewards, dtype=torch.float)
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": rewards,
        }

In [6]:
from torch.utils.data import DataLoader
from torch.utils.data import Subset
from prm_dataset import StepwisePRMDataset
from config import PRMConfig

with open("/home/leena/ccc_eval/mcts_prm/cmi_samples/total_gsm8k_merge_mistral.json", "r") as file:
    gsm8k_raw = json.load(file)

cfg = PRMConfig()
full_ds = StepwisePRMDataset(gsm8k_raw, tokenizer, cfg.max_new_tokens, reward_type="cmi")
print(f"Full dataset size: {len(full_ds)}") 

indices = list(range(len(full_ds)))
split_idx = int(0.9 * len(full_ds)) if len(full_ds) > 1 else 1
train_indices = indices[:split_idx]
val_indices = indices[split_idx:] if len(full_ds) > 1 else indices[:1]

train_ds = Subset(full_ds, train_indices)
valid_ds = Subset(full_ds, val_indices)

collate = PRMCollator(tokenizer)
train_loader = DataLoader(train_ds, batch_size=16, shuffle=True, collate_fn=collate)
valid_loader = DataLoader(valid_ds, batch_size=16, shuffle=False, collate_fn=collate)
print("Finish Loading Dataset")

output_dir = "/home/leena/ccc_eval/mcts_prm/prm_training/checkpoints/pt_prm"
training_args = TrainingArguments(
    output_dir=output_dir,
    per_device_train_batch_size=4,
    num_train_epochs=3,
    learning_rate=2e-4,
    fp16=True,
    logging_steps=50,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=valid_ds,       # ← 변수명 교정
    data_collator=collate        # ← 반드시 추가
)
print("Finish Loading Trainer")

with torch.no_grad():
    b = collate([train_ds[i] for i in range(4)])
    out = model(**{k: v.to(device) for k, v in b.items()})
    print("loss:", out[0].item(), "pred shape:", out[1].shape)

model.get_parameter_stats()

Reward type: cmi
Full dataset size: 26720
Finish Loading Dataset


We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)


Finish Loading Trainer
loss: 0.6680973768234253 pred shape: torch.Size([4])


{'total_params': 3830919681,
 'trainable_params': 22944769,
 'trainable_ratio': 0.598936310614861,
 'module_stats': {'backbone': {'trainable': 22944769, 'total': 3830919681}}}

In [None]:
def predict_step_rewards(model, tokenizer, problem, steps):
    """
    problem: str
    steps  : List[str]  # ["Step 1: ...", "Step 2: ...", ...]
    """
    model.eval()
    prefix = [f"Problem: {problem}"]
    rewards = []
    for s in steps:
        prefix.append(s)
        txt = "\n".join(prefix)
        enc = tokenizer(txt, return_tensors="pt").to(model.device)
        with torch.no_grad():
            reward = model(**enc).item()      # forward returns pred
        rewards.append(reward)
    return rewards

# QwQ prompting

In [2]:
import math
import re, sys
from typing import List, Optional, Tuple, Dict
import torch
from datasets import load_dataset
from tqdm import tqdm
from vllm import LLM, SamplingParams
from pathlib import Path

# Project-level helpers
current_dir = Path(__file__).parent
project_root = current_dir.parent
sys.path.insert(0, str(project_root))

os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"  # Arrange GPU devices starting from 0
os.environ["CUDA_VISIBLE_DEVICES"]= "3"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

from config import PRMConfig
from utils import _sanitize_enhanced, _numeric_equiv_enhanced, _extract_boxed_answer, system_prompt
from inference.answer_extractor import AnswerExtractor
from inference.answer_matcher import MathAnswerScorer

def _format_system_prompt() -> str:
    return (
        "You are an **expert mathematical-reasoning assistant**.\n\n"
        "## Format rules\n"
        "1. Begin *every* reasoning line with the exact prefix `Step k:` where `k = 1, 2, …`. No other prefix is allowed.\n"
        "2. Show *all* intermediate calculations using standard symbols (×, ÷, ±, √).\n"
        "3. Put your final answer within `Answer: \boxed{}`. and **stop immediately** — no extra text after the answer.\n"
        "4. Each step must be concise *yet mathematically rigorous*.\n"
        "5. Do not generate any text or reflection if you reach the final answer.\n\n"
        "Follow these rules exactly — evaluations are case- and format‑sensitive.\n"
        "Respond *only* in the specified format."
    )

def build_chat_messages_qwq(*, question: str, tokenizer, dataset: str, shots: Optional[List[Tuple[str, str, str]]] = None, prefix_context: Optional[str] = None, next_label: Optional[str] = None,) -> str:
    system_prompt = _format_system_prompt()
    default_shots: List[Tuple[str, str, str]] = [
        (
            "gsm8k, math, olympiad, omni",
            "Problem: What is the next number in the sequence 2, 4, 8, 16?",
            "Step 1: Identify the pattern; each term is multiplied by 2.\n"
            "Step 2: 16 × 2 = 32\n"
            "Answer: 32",
        ),
        (
            "gsm8k, math",
            "Problem: Solve for x: 3x + 7 = 22",
            "Step 1: Subtract 7 from both sides: 3x = 15\n"
            "Step 2: Divide by 3: x = 5\n"
            "Answer: 5",
        ),
        (
            "olympiad, omni",
            "Problem: Determine whether v₁ = [1,2] and v₂ = [3,6] are linearly independent.",
            "Step 1: Observe v₂ = 3 · v₁, so v₂ is a scalar multiple of v₁.\n"
            "Step 2: Therefore the vectors are linearly dependent.\n"
            "Answer: Dependent",
        ),
    ]

    if shots is None:
        shots = default_shots
    user_lines: List[str] = []
    
    if prefix_context:
        user_lines.append(prefix_context.rstrip())
    
    user_lines.append(f"Problem: {question}".rstrip())
    if next_label:
        user_lines.append(next_label.rstrip())
    user_content = "\n".join([ln for ln in user_lines if ln])

    messages: List[Dict[str, str]] = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})

    for tag, q, a in shots:
        if dataset.lower() in tag.lower():
            messages.append({"role": "user", "content": q})
            messages.append({"role": "assistant", "content": a})

    messages.append({"role": "user", "content": user_content})
    return tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)

def build_masking_chat_messages_qwq(tokenizer, sentence: str) -> str:
    masking_system = (
        "In the sentence below, mask any word or expression that seems crucial "
        "(such as a variable, a number or an operator, etc.) for solving the math problem "
        "by replacing it with '[MASKED]'."
    )
    user_content = f"Sentence: \"{sentence}\"\nRewritten:"
    messages = [
        {"role": "system", "content": masking_system},
        {"role": "user", "content": user_content},
    ]
    return tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)

class ContriRewardvLLM:
    def __init__(self, config: "PRMConfig", model_name: str = "mistralai/Mathstral-7B-v0.1"):
        self.config = config
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.llm = LLM(
            model=model_name,
            trust_remote_code=True,
            dtype="bfloat16",
            gpu_memory_utilization=0.8,
            max_model_len=4096,
            quantization="bitsandbytes",
        )
        self.tokenizer = self.llm.get_tokenizer()

        self.scorer = MathAnswerScorer()
        self.extractor = AnswerExtractor()
        
        self.rollout_params = SamplingParams(
            temperature=0.6,
            top_p=0.95,
            max_tokens=self.config.max_new_tokens,
            n=self.config.num_rollouts,
            repetition_penalty=1.1,
        )
        self.masking_params = SamplingParams(
            temperature=0.6,
            top_p=0.95,
            max_tokens=self.config.max_new_tokens,
            n=self.config.num_rollouts,
            repetition_penalty=1.1,
        )
        print(f"vLLM model loaded: {model_name}")

    def _batched_generate(self, prompts: List[str], params: SamplingParams):
        return self.llm.generate(prompts, params)

    def _score_batch(self, outputs, gold_answer: str) -> List[float]:
        rewards = []
        for result in outputs:
            correct = 0
            for comp in result.outputs:
                text = comp.text or ""
                tail = text.rsplit("Answer:", 1)[-1] if "Answer:" in text else text
                pred = self.extractor.extract_pred_answer(tail)
                print("Prediction Answer:", pred)
                if self.scorer.answers_match(pred, gold_answer):
                    correct += 1
            rewards.append(correct / float(self.config.num_rollouts))
        return rewards

    def _make_prompt(self, *, question: str, staged_steps: List[str], next_label: str, dataset: str) -> str:
        prefix_context = "\n".join(staged_steps)
        return build_chat_messages_qwq(question=question, tokenizer=self.tokenizer, dataset=dataset, prefix_context=prefix_context, next_label=next_label)
    
    def compute_step_rewards_batch(self, question: str, dataset: str, steps: List[str], gold_answer: str) -> List[float]:
        prompts: List[str] = []
        for i in range(len(steps)):
            next_label = f"Step {i + 2}:" if i < len(steps) - 1 else "Answer:"
            staged_steps = steps[: i + 1]
            prompts.append(self._make_prompt(question=question, staged_steps=staged_steps, next_label=next_label, dataset=dataset))
        outputs = self._batched_generate(prompts, self.rollout_params)
        print("Generated rollout outputs:")
        for i, output in enumerate(outputs):
            print(f"Output {i}: {output.outputs[0].text}")
        return self._score_batch(outputs, gold_answer)
        
    def model_masking_batch(self, texts: List[str]) -> List[str]:
        mask_prompts = [build_masking_chat_messages_qwq(self.tokenizer, t) for t in texts]
        outputs = self._batched_generate(mask_prompts, self.masking_params)
        print("Mask Generation:")
        for i, output in enumerate(outputs):
            print(f"Output {i}: {output.outputs[0].text}")
        return [out.outputs[0].text.strip() for out in outputs]

    def perturb_step_rewards_batch(self, question: str, dataset: str, steps: List[str], gold_answer: str, use_llm: bool = True) -> List[float]:
        bodies = []
        prefixes = []
        for step in steps:
            m = re.match(r"^[\s>#*\-]*Step\s*\d+\s*[:.\-]\s*", step, flags=re.I)
            prefixes.append(m.group(0) if m else "")
            bodies.append(step[len(prefixes[-1]):])

        if use_llm:
            masked_bodies = self.model_masking_batch(bodies)
        else:
            masked_bodies = [self._MASK_PATTERN.sub("[MASKED]", b) for b in bodies]
        print("Masked Body:", masked_bodies)   
        prompts = []
        for i in range(len(steps)):
            masked_step = prefixes[i] + masked_bodies[i]
            staged_steps = steps[:i] + [masked_step]
            next_label = f"Step {i + 2}:" if i < len(steps) - 1 else "Answer:"
            prompts.append(self._make_prompt(question=question, staged_steps=staged_steps, next_label=next_label, dataset=dataset))
        
        outputs = self._batched_generate(prompts, self.rollout_params)
        return self._score_batch(outputs, gold_answer)

    def gsm8k_reward_dataset_vllm(self, *, split: str = "train", start: int = 0, take: Optional[int] = None):
        ds = load_dataset("openai/gsm8k", "main", split=split)
        ds = ds.select(range(start, start + take)) if take else ds
        # ds = ds.select(range(start, len(ds)))
        # print("Generated dataset size: ", len(ds))

        for sample in tqdm(ds, desc="Building GSM8K contri reward-dataset"):
            q_txt, g_sol = sample["question"], sample["answer"]
            lines, gold_ans = [], None
            
            gold_ans = self.extractor.extract_gold_answer(g_sol, "gsm8k")
            if gold_ans is None:
                raise ValueError("gold answer not found for sample")
            
            lines = [ln.strip() for ln in g_sol.splitlines() if ln.strip()]
            steps = [f"Step {i+1}: {t}" for i, t in enumerate(lines)]
            # steps = [f"Step {i+1}: {t}" for i, t in enumerate(lines) if not t.lower().startswith("answer")] 
            print("Steps Split:", steps)

            ori = self.compute_step_rewards_batch(q_txt, "gsm8k", steps, gold_ans)
            ptb = self.perturb_step_rewards_batch(q_txt, "gsm8k", steps, gold_ans, self.config.use_llm)
            print("Original Rewards:", ori)
            print("Masked Rewards:", ptb)
            contrib = [round(o - p, 4) for o, p in zip(ori, ptb)]

            entry = {
                "question": q_txt,
                "completion": steps,
                "ori_rewards": ori,
                "ptb_rewards": ptb,
                "contributions": contrib,
                "gold_answer": gold_ans,
            }
            yield entry

    def math_reward_dataset_vllm(self, *, split: str = "train", start: int = 0, take: Optional[int] = None):
        sent_split = re.compile(r'\.(?!\d)(?=\s|$)')
        ds = load_dataset("HuggingFaceTB/MATH", "all", split=split)
        ds = ds.select(range(start, start + take)) if take else ds
        # ds = ds.select(range(start, len(ds)))
        # print("Generated dataset size: ", len(ds))
        
        for sample in tqdm(ds, desc="Building MATH contri reward-dataset"):
            full_sol = sample["solution"]
            
            gold_ans = self.extractor.extract_gold_answer(full_sol, "math")
            if gold_ans is None:
                lines = [line.strip() for line in full_sol.splitlines() if line.strip()]
                for line in reversed(lines):
                    if re.search(r'[\d\-+*/()=]', line):
                        gold_ans = _sanitize_enhanced(line)
                        break
            
            sol_wo_box = re.sub(r'\\boxed\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\}', '', full_sol)
            raw_steps = [s.strip() for s in sent_split.split(sol_wo_box) if s.strip()]
            steps = [f"Step {i+1}: {s}" for i, s in enumerate(raw_steps)]
            print("Steps Split:", steps)

            ori = self.compute_step_rewards_batch(sample["problem"],"math", steps, gold_ans)
            ptb = self.perturb_step_rewards_batch(sample["problem"], "math", steps, gold_ans, self.config.use_llm)
            print("Original Rewards:", ori)
            print("Masked Rewards:", ptb)
            contrib = [round(o - p, 4) for o, p in zip(ori, ptb)]

            entry = {
                "question": sample["problem"],
                "completion": steps,
                "ori_rewards": ori,
                "ptb_rewards": ptb,
                "contributions": contrib,
                "gold_answer": gold_ans,
            }
            yield entry


NameError: name '__file__' is not defined

# BCE training

In [None]:
import torch, torch.nn as nn
from dataclasses import dataclass
from typing import Any
from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments
# (옵션) LoRA를 쓰려면 peft import 후 주석 해제
# from peft import get_peft_model, LoraConfig

class PRMTokenClassifier(nn.Module):
    """
    LM 마지막 히든(토큰 단위)에 1-logit 헤드를 붙여,
    <RW> 위치들만 선택해 BCEWithLogitsLoss(logit, r) 최적화.
    """
    def __init__(self, base_model_name: str, tokenizer: AutoTokenizer,
                 use_lora: bool = False, freeze_backbone: bool = False):
        super().__init__()
        self.tok = tokenizer
        self.lm  = AutoModelForCausalLM.from_pretrained(
            base_model_name, torch_dtype=torch.bfloat16,
            device_map="auto", trust_remote_code=True
        )
        # 스페셜 토큰 추가 후 임베딩 리사이즈
        self.lm.resize_token_embeddings(len(self.tok))

        # (옵션) LoRA
        # if use_lora:
        #     lora = LoraConfig(
        #         r=8, lora_alpha=16, lora_dropout=0.05,
        #         target_modules=["q_proj","k_proj","v_proj","o_proj"]
        #     )
        #     self.lm = get_peft_model(self.lm, lora)

        if freeze_backbone:
            for p in self.lm.parameters(): p.requires_grad = False

        h = self.lm.config.hidden_size
        self.head = nn.Linear(h, 1)  # 1-logit

        self.loss_fn = nn.BCEWithLogitsLoss()

    def forward(self, input_ids=None, attention_mask=None,
                rw_mask=None, rewards=None):
        out = self.lm(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
            return_dict=True,
        )
        H = out.hidden_states[-1]       # (bs, seq, h)
        logit = self.head(H).squeeze(-1)  # (bs, seq)

        loss = None
        if rewards is not None and rw_mask is not None:
            m = rw_mask.bool()
            if m.ndim == 2:   # (bs, seq)
                sel_logit = logit[m]
                sel_label = rewards[m].float()
            else:             # (seq,) 단일 배치 보호
                sel_logit = logit[m]
                sel_label = rewards[m].float()
            if sel_logit.numel() > 0:
                loss = self.loss_fn(sel_logit, sel_label)
            else:
                # <RW>가 잘려서 없는 샘플은 로스 없음(=0으로 무시)
                loss = torch.zeros([], dtype=logit.dtype, device=logit.device)

        return {"loss": loss, "logits_prm": logit, "hidden_states": out.hidden_states}

@dataclass
class CollatedBatch:
    input_ids: torch.Tensor
    attention_mask: torch.Tensor
    rw_mask: torch.Tensor
    rewards: torch.Tensor

def default_collate(batch: list[dict]) -> CollatedBatch:
    # 모두 동일 길이(padding="max_length")라면 stack만 해도 됨
    keys = ["input_ids","attention_mask","rw_mask","rewards"]
    out: Dict[str, Any] = {}
    for k in keys:
        out[k] = torch.stack([b[k] for b in batch], dim=0)
    return CollatedBatch(**out)


In [None]:
# 1) 토크나이저/엔트리 준비
base = "mistral-community/Mistral-7B-v0.2"   # 또는 Qwen/Mistral base
tok  = AutoTokenizer.from_pretrained(base, trust_remote_code=True)

# (중요) pad_token 보정
if tok.pad_token is None:
    tok.pad_token = tok.eos_token

# (예시) entries 골격: 각 entry는 step list와 같은 개수의 보상 리스트 포함
entries = [
    {
        "question": "Janet’s ducks lay 16 eggs per day ...",
        "completion": [
            "Step 1: ...",
            "Step 2: ...",
            "Step 3: ...",
            "Step 4: ...",
        ],
        "mi_loo": [0.9, 0.8, 0.95, 0.1],
        "contributions": [0.7, 0.6, 0.8, 0.2],
        "mi_filtered": [0.1, 0.1, 0.1, 0.0],
        "ori_rewards": [1,1,1,0],
        "gold_answer": "18",
    },
    # ...
]

# 2) 데이터셋 생성
train_ds = StepwisePRMWithRW(
    entries, tokenizer=tok,
    max_length=2048,
    reward_type="mi_loo",    # 네 파이프라인에 맞게 선택
    preprocess=True,
    apply_norm=True, norm_mode="unit",
    norm_kwargs={"tau":1.5, "clip_z":3.0, "deadzone":0.2, "q_low":5, "q_high":95, "round_to":4},
)

# 3) 모델/트레이너
model = PRMTokenClassifier(base, tokenizer=tok, use_lora=False, freeze_backbone=False)

args = TrainingArguments(
    output_dir="prm_rw_out",
    learning_rate=1e-5,                 # 7B 기준 5e-6 ~ 2e-5 탐색 추천
    per_device_train_batch_size=2,
    gradient_accumulation_steps=8,
    num_train_epochs=1,
    bf16=True, fp16=False,
    logging_steps=10,
    save_steps=1000,
    report_to=[],
)

class PRMTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False):
        out = model(**inputs.__dict__)
        loss = out["loss"]
        return (loss, out) if return_outputs else loss

trainer = PRMTrainer(
    model=model, args=args,
    train_dataset=train_ds,
    data_collator=default_collate,
)
trainer.train()


In [None]:
import torch

def score_sample(model: PRMTokenClassifier, tok: AutoTokenizer, text: str, max_length=2048):
    enc = tok(text + f"\n{RW_TOKEN}", return_tensors="pt", truncation=True, max_length=max_length)
    input_ids = enc["input_ids"].to(model.lm.device)
    attn      = enc["attention_mask"].to(model.lm.device)
    with torch.no_grad():
        out = model(input_ids=input_ids, attention_mask=attn)
        logit = out["logits_prm"]  # (1, seq)
    rw_id = tok.convert_tokens_to_ids(RW_TOKEN)
    rw_mask = (input_ids[0] == rw_id)
    s = logit[0][rw_mask]
    return torch.sigmoid(s).cpu().tolist()

# 예시
# prob = score_sample(model, tok, "System...\nProblem: ...\nStep 1 ...\nStep 2 ...")
# print(prob)  # [0.83]  (해당 샘플의 스텝 보상)
