# Utils

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig, PreTrainedTokenizer
from collections import defaultdict, deque
import math
import logging
import json
from typing import Dict, List, Tuple, Optional, Union
from dataclasses import dataclass
import numpy as np
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm, trange
import wandb
import os
from pathlib import Path
from collections import deque
from torch.utils.data import Dataset
from typing import List, Tuple, Dict, Any, Optional
from collections import Counter
import random
import string
import re

os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"  # Arrange GPU devices starting from 0
os.environ["CUDA_VISIBLE_DEVICES"]= "1"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

  from .autonotebook import tqdm as notebook_tqdm


In [6]:
import torch
from transformers import AutoModel, AutoTokenizer
import torch.nn.functional as F

model_name = "Qwen/Qwen2.5-Math-PRM-7B"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModel.from_pretrained(model_name, device_map="auto",  torch_dtype=torch.bfloat16, trust_remote_code=True,).eval()

data = {
    "system": "Please reason step by step, and put your final answer within \\boxed{}.",
    "query": "Sue lives in a fun neighborhood.  One weekend, the neighbors decided to play a prank on Sue. On Friday morning, the neighbors placed 18 pink plastic flamingos out on Sue's front yard. On Saturday morning, the neighbors took back one third of the flamingos, painted them white, and put these newly painted white flamingos back out on Sue's front yard.  Then, on Sunday morning, they added another 18 pink plastic flamingos to the collection. At noon on Sunday, how many more pink plastic flamingos were out than white plastic flamingos?",
    "response": [
      "To find out how many more pink plastic flamingos were out than white plastic flamingos at noon on Sunday, we can break down the problem into steps. First, on Friday, the neighbors start with 18 pink plastic flamingos.",
      "On Saturday, they take back one third of the flamingos. Since there were 18 flamingos, (1/3 \\times 18 = 6) flamingos are taken back. So, they have (18 - 6 = 12) flamingos left in their possession. Then, they paint these 6 flamingos white and put them back out on Sue's front yard. Now, Sue has the original 12 pink flamingos plus the 6 new white ones. Thus, by the end of Saturday, Sue has (12 + 6 = 18) pink flamingos and 6 white flamingos.",
      "On Sunday, the neighbors add another 18 pink plastic flamingos to Sue's front yard. By the end of Sunday morning, Sue has (18 + 18 = 36) pink flamingos and still 6 white flamingos.",
      "To find the difference, subtract the number of white flamingos from the number of pink flamingos: (36 - 6 = 30). Therefore, at noon on Sunday, there were 30 more pink plastic flamingos out than white plastic flamingos. The answer is (\\boxed{30})."
    ]
}


Loading checkpoint shards: 100%|██████████| 4/4 [00:02<00:00,  1.40it/s]
Some weights of the model checkpoint at Qwen/Qwen2.5-Math-PRM-7B were not used when initializing Qwen2ForProcessRewardModel: ['lm_head.weight']
- This IS expected if you are initializing Qwen2ForProcessRewardModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing Qwen2ForProcessRewardModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [None]:
def make_step_rewards(logits, token_masks):
    probabilities = F.softmax(logits, dim=-1)
    probabilities = probabilities * token_masks.unsqueeze(-1) # bs, seq_len, num_labels

    all_scores_res = []
    for i in range(probabilities.size(0)):
        sample = probabilities[i] # seq_len, num_labels [452, 2]
        positive_probs = sample[sample != 0].view(-1, 2)[:, 1] # valid_tokens, num_labels [# of steps, label]
        non_zero_elements_list = positive_probs.cpu().tolist()
        all_scores_res.append(non_zero_elements_list)
    return all_scores_res


messages = [
    {"role": "system", "content": data['system']},
    {"role": "user", "content": data['query']},
    {"role": "assistant", "content": "<extra_0>".join(data['response']) + "<extra_0>"},
]
conversation_str = tokenizer.apply_chat_template(
    messages, 
    tokenize=False, 
    add_generation_prompt=False
)

input_ids = tokenizer.encode(conversation_str, return_tensors="pt").to(model.device)

outputs = model(input_ids=input_ids)

step_sep_id = tokenizer.encode("<extra_0>")[0]
token_masks = (input_ids == step_sep_id)
step_reward = make_step_rewards(outputs[0], token_masks)
print(step_reward)  # [[1.0, 0.1904296875, 0.9765625, 1.0]]

In [9]:
def check_token_learning():
    test_tokens = ["<extra_0>", "<extra_1>", "<extra_2>", "|STEP|"]
    for token in test_tokens:
        try:
            token_id = tokenizer.encode(token)[0]
            print(f"Token '{token}': ID {token_id}")
            if hasattr(model, 'get_input_embeddings'):
                embedding = model.get_input_embeddings()
                token_embedding = embedding(torch.tensor([token_id]))
                print(f"  Embedding shape: {token_embedding.shape}")
                
        except Exception as e:
            print(f"Token '{token}': Error - {e}")

check_token_learning()

Token '<extra_0>': ID 151651
Token '<extra_0>': Error - Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument index in method wrapper_CUDA__index_select)
Token '<extra_1>': ID 27
Token '<extra_1>': Error - Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument index in method wrapper_CUDA__index_select)
Token '<extra_2>': ID 27
Token '<extra_2>': Error - Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument index in method wrapper_CUDA__index_select)
Token '|STEP|': ID 91
Token '|STEP|': Error - Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument index in method wrapper_CUDA__index_select)


# Incorrect Step Dataset

In [24]:
import json
import random
from typing import List, Dict, Any, Tuple
import copy

def create_incorrect_steps(gold_steps: List[str], incorrect_type: str = "wrong_calculation") -> Tuple[List[str], int]:
    if not gold_steps:
        return gold_steps

    perturbed_steps = gold_steps.copy()
    insert_position = random.randint(1, len(gold_steps))   # 1 ≤ pos ≤ len
    
    if incorrect_type == "wrong_calculation":
        wrong_calculations = [
            f"Step {insert_position + 1}: 5 + 3 = 9",  # 5 + 3 = 8
            f"Step {insert_position + 1}: 10 * 2 = 15",  # 10 * 2 = 20
            f"Step {insert_position + 1}: 20 / 4 = 6",  # 20 / 4 = 5
            f"Step {insert_position + 1}: 7 - 3 = 5",  # 7 - 3 = 4
            f"Step {insert_position + 1}: 2^2 = 5",  # 2^2 = 4
            f"Step {insert_position + 1}: sqrt(16) = 3",  # sqrt(16) = 4
            f"Step {insert_position + 1}: 3 * 7 = 24",  # 3 * 7 = 21
            f"Step {insert_position + 1}: 15 / 3 = 6",  # 15 / 3 = 5
        ]
        insert_step = random.choice(wrong_calculations)
    elif incorrect_type == "logical_error":
        logical_errors = [
            f"Step {insert_position + 1}: Since we need to find the total, we should multiply by 0 instead of adding.",
            f"Step {insert_position + 1}: To solve this, we need to subtract the larger number from the smaller one.",
            f"Step {insert_position + 1}: The answer should be negative because we're dealing with positive numbers.",
            f"Step {insert_position + 1}: We can ignore the units since they don't affect the calculation.",
            f"Step {insert_position + 1}: The order of operations doesn't matter here, so we can do addition first.",
        ]
        insert_step = random.choice(logical_errors)
    elif incorrect_type == "irrelevant":
        irrelevant_steps = [
            f"Step {insert_position + 1}: The weather is nice today.",
            f"Step {insert_position + 1}: I like literatures very much.",
            f"Step {insert_position + 1}: This reminds me of my school days.",
            f"Step {insert_position + 1}: The sky is blue and beautiful.",
            f"Step {insert_position + 1}: I should drink more water.",
            f"Step {insert_position + 1}: Python is the language of the universe.",
        ]
        insert_step = random.choice(irrelevant_steps)
    elif incorrect_type == "repetition":
        if len(gold_steps) > 1:
            repeat_step = gold_steps[insert_position - 1]  # 이전 step
            step_num = insert_position + 1
            step_content = repeat_step.split(":", 1)[1] if ":" in repeat_step else ""
            insert_step = f"Step {step_num}:{step_content}"
        else:
            insert_position = 1
            step_num = insert_position + 1
            repeat_step = gold_steps[0]
            step_content = repeat_step.split(":", 1)[1] if ":" in repeat_step else ""
            insert_step = f"Step {step_num}:{step_content}"
    else:
        insert_step = f"Step {insert_position + 1}: 5 + 3 = 9"
    
    perturbed_steps.insert(insert_position, insert_step)

    for i in range(insert_position + 1, len(perturbed_steps)):
            if perturbed_steps[i].startswith("Step "):
                step_num = i + 1
                step_content = perturbed_steps[i].split(":", 1)[1] if ":" in perturbed_steps[i] else ""
                perturbed_steps[i] = f"Step {step_num}:{step_content}"
    
    return perturbed_steps, insert_position

def add_incorrect_completion(entry: Dict[str, Any], incorrect_type: str = "wrong_calculation", negative_reward: float = -1.0) -> Dict[str, Any]:
    new_entry = copy.deepcopy(entry)
    original_completion = entry.get("completion", [])
    # incorrect completion 생성
    incorrect_completion, insert_position = create_incorrect_steps(original_completion, incorrect_type)
    new_entry["completion"] = incorrect_completion
    # incorrect rewards 생성
    for key, val in entry.items():
        if isinstance(val, list) and len(val) == len(original_completion) and key != "completion":
            new_vec = val.copy()
            if key == "contributions":
                new_vec.insert(insert_position, negative_reward)
            elif key == "mi_filtered":
                new_vec.insert(insert_position, 0.0)
            else:
                new_vec.insert(insert_position, negative_reward)
            new_entry[key] = new_vec
    # meta data 추가
    new_entry["is_incorrect"] = True
    new_entry["incorrect_type"] = incorrect_type

    return new_entry

def extend_dataset_with_incorrect_steps(dataset: List[Dict[str, Any]], incorrect_types: List[str] = None,negative_rewards: List[float] = None, ratio: float = 0.5) -> List[Dict[str, Any]]:
    if incorrect_types is None:
        incorrect_types = ["wrong_calculation", "logical_error", "irrelevant_step", "repetition"]
    if negative_rewards is None:
        negative_rewards = [-1.0] * len(incorrect_types)
    
    # negative_rewards를 incorrect_types와 매칭
    if len(negative_rewards) != len(incorrect_types):
        negative_rewards = negative_rewards * (len(incorrect_types) // len(negative_rewards) + 1)
        negative_rewards = negative_rewards[:len(incorrect_types)]
    
    extended_dataset = []
    for entry in dataset:
        extended_dataset.append(entry)
        if random.random() < ratio:
            incorrect_type = random.choice(incorrect_types) # 랜덤하게 incorrect type 선택
            type_idx = incorrect_types.index(incorrect_type)
            negative_reward = negative_rewards[type_idx]
            incorrect_entry = add_incorrect_completion(entry, incorrect_type, negative_reward)
            extended_dataset.append(incorrect_entry)
    
    return extended_dataset



In [25]:
input_path = "/home/leena/ccc_eval/mcts_prm/cmi_samples/total_gsm8k_merge_mistral.json"
output_path = "/home/leena/ccc_eval/mcts_prm/cmi_samples/total_gsm8k_merge_mistral_incorrect.json"

with open(input_path, "r") as file:
    dataset = json.load(file)
print(f"Original dataset size: {len(dataset)}")

extended_dataset = extend_dataset_with_incorrect_steps(dataset=dataset, ratio=0.5)
print(f"Extended dataset size: {len(extended_dataset)}")
print(f"Added {len(extended_dataset) - len(dataset)} incorrect samples")

# 확장된 dataset 저장
with open(output_path, "w") as file2:
    json.dump(extended_dataset, file2, indent=2)

Original dataset size: 7473
Extended dataset size: 11141
Added 3668 incorrect samples


In [26]:
# 통계 출력
correct_count = sum(1 for entry in extended_dataset if not entry.get("is_incorrect", False))
incorrect_count = sum(1 for entry in extended_dataset if entry.get("is_incorrect", False))

print(f"Correct samples: {correct_count}")
print(f"Incorrect samples: {incorrect_count}")

# 타입별 통계
type_counts = {}
for entry in extended_dataset:
    if entry.get("is_incorrect", False):
        incorrect_type = entry.get("incorrect_type", "unknown")
        type_counts[incorrect_type] = type_counts.get(incorrect_type, 0) + 1

print("\nIncorrect types distribution:")
for incorrect_type, count in type_counts.items():
    print(f"  {incorrect_type}: {count}")


Correct samples: 7473
Incorrect samples: 3668

Incorrect types distribution:
  repetition: 911
  wrong_calculation: 928
  irrelevant_step: 923
  logical_error: 906


# Trainer Fintuning

In [1]:
import argparse
import json
import math
import os
from dataclasses import dataclass
from typing import List, Tuple, Dict, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from tqdm.auto import tqdm
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    AutoModel,
    get_linear_schedule_with_warmup,
    BitsAndBytesConfig,
    TrainingArguments,
    Trainer,
)
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
import random
import string
import re

os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"  # Arrange GPU devices starting from 0
os.environ["CUDA_VISIBLE_DEVICES"]= "3"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

  from .autonotebook import tqdm as notebook_tqdm


## Load pretrained model

In [2]:
class FTPRM(nn.Module):
    def __init__(self, base_model_name: str, lora_rank: int = 16, lora_alpha: int = 32):
        super().__init__()
        
        self.backbone = AutoModel.from_pretrained(
            base_model_name,
            device_map="auto",
            trust_remote_code=True,
            quantization_config=BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_use_double_quant=True,
                bnb_4bit_compute_dtype=torch.bfloat16
            ),
        )

        if hasattr(self.backbone, "score"):
            # For Qwen2.5-Math-PRM-7B
            in_feat = self.backbone.score[0].in_features
            self.backbone.score = nn.Sequential(
                nn.Linear(in_feat, in_feat),
                nn.ReLU(),
                nn.Linear(in_feat, 1, bias=True)  # 2 → 1
            )
            self.reg_head = None
        else:
            # Other AutoModel(Causal LM)
            hidden = self.backbone.config.hidden_size
            self.reg_head = nn.Sequential(
                nn.Linear(hidden, hidden // 4),
                nn.ReLU(),
                nn.Dropout(0.1),
                nn.Linear(hidden // 4, 1)
            )

        # Add Lora Adapter
        self.backbone = prepare_model_for_kbit_training(self.backbone)
        lora_cfg = LoraConfig(
            r=lora_rank,
            lora_alpha=lora_alpha,
            target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
            lora_dropout=0.05,
            bias="none",
        )
        self.backbone = get_peft_model(self.backbone, lora_cfg)
        self._activate_head_params()   

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, labels=None):
        out = self.backbone(input_ids=input_ids,
                            attention_mask=attention_mask,
                            output_hidden_states=True,
                            return_dict=True)
        hidden = out.hidden_states[-1]                     # (B, L, H)

        # Last token vector
        if attention_mask is None:   
            rep = hidden[:, -1, :]
        else:
            seq_len = attention_mask.sum(1) - 1           # (B,)
            rep = hidden[torch.arange(hidden.size(0), device=hidden.device), seq_len, :]

        # head 통과
        if self.reg_head is None:
            pred = self.backbone.score(rep).squeeze(-1)
        else:
            pred = self.reg_head(rep).squeeze(-1)

        if labels is not None:              # training / eval
            loss = F.mse_loss(pred, labels.float())
            return loss, pred   
        else:                               # pure inference
            return pred

    def _activate_head_params(self):
        if self.reg_head is not None:
            for p in self.reg_head.parameters():
                p.requires_grad_(True)
        else:
            for p in self.backbone.score.parameters():
                p.requires_grad_(True)

    def get_trainable_parameters(self):
        return [p for p in self.parameters() if p.requires_grad]
    
    def get_parameter_stats(self):
        trainable_params = 0
        all_param = 0
        module_stats = {}
        
        for name, param in self.named_parameters():
            all_param += param.numel()
            if param.requires_grad:
                trainable_params += param.numel()
                
                module_name = name.split('.')[0]
                if module_name not in module_stats:
                    module_stats[module_name] = {'trainable': 0, 'total': 0}
                module_stats[module_name]['trainable'] += param.numel()
                module_stats[module_name]['total'] += param.numel()
            else:
                module_name = name.split('.')[0]
                if module_name not in module_stats:
                    module_stats[module_name] = {'trainable': 0, 'total': 0}
                module_stats[module_name]['total'] += param.numel()
        
        return {
            'total_params': all_param,
            'trainable_params': trainable_params,
            'trainable_ratio': trainable_params / all_param * 100,
            'module_stats': module_stats
        }


In [3]:
model_name = "Qwen/Qwen2.5-Math-PRM-7B"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id

model = FTPRM(base_model_name=model_name)
model.to(device)
# model.get_trainable_parameters()

Loading checkpoint shards: 100%|██████████| 4/4 [00:22<00:00,  5.72s/it]
Some weights of the model checkpoint at Qwen/Qwen2.5-Math-PRM-7B were not used when initializing Qwen2ForProcessRewardModel: ['lm_head.weight']
- This IS expected if you are initializing Qwen2ForProcessRewardModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing Qwen2ForProcessRewardModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


FTPRM(
  (backbone): PeftModel(
    (base_model): LoraModel(
      (model): Qwen2ForProcessRewardModel(
        (model): Qwen2Model(
          (embed_tokens): Embedding(152064, 3584)
          (layers): ModuleList(
            (0-27): 28 x Qwen2DecoderLayer(
              (self_attn): Qwen2SdpaAttention(
                (q_proj): lora.Linear4bit(
                  (base_layer): Linear4bit(in_features=3584, out_features=3584, bias=True)
                  (lora_dropout): ModuleDict(
                    (default): Dropout(p=0.05, inplace=False)
                  )
                  (lora_A): ModuleDict(
                    (default): Linear(in_features=3584, out_features=16, bias=False)
                  )
                  (lora_B): ModuleDict(
                    (default): Linear(in_features=16, out_features=3584, bias=False)
                  )
                  (lora_embedding_A): ParameterDict()
                  (lora_embedding_B): ParameterDict()
                  (lora_magnitude_

In [6]:
from transformers import BitsAndBytesConfig
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

# PRM Custom Class
class FTLM(nn.Module):
    def __init__(self, base_model_name: str, lora_rank: int = 16, lora_alpha: int = 32, mlp_ratio: int = 4, value_head_prefix: str = "value_head", 
                 normalize_reward: bool = False):
        super().__init__()
        
        # Use AutoModelForCausalLM for pure CausalLM fine-tuning
        self.backbone = AutoModelForCausalLM.from_pretrained(
            base_model_name,
            device_map="auto",
            trust_remote_code=True,
            quantization_config=BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_use_double_quant=True,
                bnb_4bit_compute_dtype=torch.bfloat16
            ),
        )

        # Add Lora Adapter for efficient fine-tuning
        self.backbone = prepare_model_for_kbit_training(self.backbone)
        lora_cfg = LoraConfig(
            r=lora_rank,
            lora_alpha=lora_alpha,
            target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
            lora_dropout=0.05,
            bias="none",
        )
        self.backbone = get_peft_model(self.backbone, lora_cfg)

        # Value head for reward prediction
        hidden = self.backbone.config.hidden_size
        mlp_hidden = hidden // mlp_ratio
        head = nn.Sequential(
            nn.Linear(hidden, mlp_hidden, bias=False),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(mlp_hidden, 1, bias=False),
        )
        self.value_head_prefix = value_head_prefix
        setattr(self, value_head_prefix, head)

        # head 가중치 학습 가능하도록 보장
        for p in head.parameters():
            p.requires_grad_(True)
        
        self.normalize_reward = normalize_reward
        self.register_buffer("mean", torch.zeros(1), persistent=False)
        self.register_buffer("std",  torch.ones(1),  persistent=False)

        # 캐시 비활성 (gradient checkpointing 호환)
        self.backbone.config.use_cache = False

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, labels=None, return_hidden: bool = False):
        if attention_mask is None:
            attention_mask = (input_ids != self.backbone.config.pad_token_id).long()

        # position_ids = cumulative mask
        position_ids = attention_mask.long().cumsum(-1) - 1
        position_ids.masked_fill_(attention_mask == 0, 0)

        outputs = self.backbone(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_hidden_states=True,
            use_cache=False,
        )

        last_hidden = outputs.hidden_states[-1]        # (B, T, H)
        # Index of last non-pad token  → (B, 1)
        eos_idx = attention_mask.size(1) - 1 - attention_mask.long().fliplr().argmax(-1, keepdim=True)

        values = getattr(self, self.value_head_prefix)(last_hidden).squeeze(-1)   # (B, T)
        reward = values.gather(1, eos_idx).squeeze(1)                             # (B,)

        if labels is not None:
            loss = F.mse_loss(reward, labels.float())
            return loss, reward
        else:
            # if (not self.training) and self.normalize_reward:
            #     reward = (reward - self.mean) / (self.std + 1e-8)
            return (reward, last_hidden) if return_hidden else reward

    def get_trainable_parameters(self):
        return [p for p in self.parameters() if p.requires_grad]
    
    def get_parameter_stats(self):
        trainable_params = 0
        all_param = 0
        module_stats = {}
        
        for name, param in self.named_parameters():
            all_param += param.numel()
            if param.requires_grad:
                trainable_params += param.numel()
                
                module_name = name.split('.')[0]
                if module_name not in module_stats:
                    module_stats[module_name] = {'trainable': 0, 'total': 0}
                module_stats[module_name]['trainable'] += param.numel()
                module_stats[module_name]['total'] += param.numel()
            else:
                module_name = name.split('.')[0]
                if module_name not in module_stats:
                    module_stats[module_name] = {'trainable': 0, 'total': 0}
                module_stats[module_name]['total'] += param.numel()
        
        return {
            'total_params': all_param,
            'trainable_params': trainable_params,
            'trainable_ratio': trainable_params / all_param * 100,
            'module_stats': module_stats
        }


In [7]:
model_name = "Qwen/Qwen2.5-Math-7B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
if tokenizer.pad_token_id is None:
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.pad_token_id = tokenizer.eos_token_id

model = FTLM(base_model_name=model_name, value_head_prefix="value_head")
model.to(device)

stats = model.get_parameter_stats()
print(f"Total parameters: {stats['total_params']:,}")
print(f"Trainable parameters: {stats['trainable_params']:,}")
print(f"Trainable ratio: {stats['trainable_ratio']:.2f}%")
print(model)

Loading checkpoint shards: 100%|██████████| 4/4 [00:15<00:00,  3.93s/it]


Total parameters: 4,366,276,992
Trainable parameters: 13,304,704
Trainable ratio: 0.30%
FTLM(
  (backbone): PeftModel(
    (base_model): LoraModel(
      (model): Qwen2ForCausalLM(
        (model): Qwen2Model(
          (embed_tokens): Embedding(152064, 3584)
          (layers): ModuleList(
            (0-27): 28 x Qwen2DecoderLayer(
              (self_attn): Qwen2Attention(
                (q_proj): lora.Linear4bit(
                  (base_layer): Linear4bit(in_features=3584, out_features=3584, bias=True)
                  (lora_dropout): ModuleDict(
                    (default): Dropout(p=0.05, inplace=False)
                  )
                  (lora_A): ModuleDict(
                    (default): Linear(in_features=3584, out_features=16, bias=False)
                  )
                  (lora_B): ModuleDict(
                    (default): Linear(in_features=16, out_features=3584, bias=False)
                  )
                  (lora_embedding_A): ParameterDict()
               

## Prepare dataset

In [5]:
@dataclass
class PRMCollator:
    tokenizer: AutoTokenizer
    pad_to_multiple_of: Optional[int] = 8

    def __call__(self, batch):
        input_ids, rewards = zip(*batch)
        lengths = [len(ids) for ids in input_ids]
        max_len = max(lengths)
        if self.pad_to_multiple_of:
            max_len = int(math.ceil(max_len / self.pad_to_multiple_of) * self.pad_to_multiple_of)

        padded = [
            torch.cat([ids, ids.new_full((max_len - len(ids),), self.tokenizer.pad_token_id)])
            for ids in input_ids
        ]
        input_ids = torch.stack(padded)
        attention_mask = (input_ids != self.tokenizer.pad_token_id).long()
        rewards = torch.tensor(rewards, dtype=torch.float)
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": rewards,
        }

In [6]:
from torch.utils.data import DataLoader
from torch.utils.data import Subset
from prm_dataset import StepwisePRMDataset
from config import PRMConfig

with open("/home/leena/ccc_eval/mcts_prm/cmi_samples/total_gsm8k_merge_mistral.json", "r") as file:
    gsm8k_raw = json.load(file)

cfg = PRMConfig()
full_ds = StepwisePRMDataset(gsm8k_raw, tokenizer, cfg.max_new_tokens, reward_type="cmi")
print(f"Full dataset size: {len(full_ds)}") 

indices = list(range(len(full_ds)))
split_idx = int(0.9 * len(full_ds)) if len(full_ds) > 1 else 1
train_indices = indices[:split_idx]
val_indices = indices[split_idx:] if len(full_ds) > 1 else indices[:1]

train_ds = Subset(full_ds, train_indices)
valid_ds = Subset(full_ds, val_indices)

collate = PRMCollator(tokenizer)
train_loader = DataLoader(train_ds, batch_size=16, shuffle=True, collate_fn=collate)
valid_loader = DataLoader(valid_ds, batch_size=16, shuffle=False, collate_fn=collate)
print("Finish Loading Dataset")

output_dir = "/home/leena/ccc_eval/mcts_prm/prm_training/checkpoints/pt_prm"
training_args = TrainingArguments(
    output_dir=output_dir,
    per_device_train_batch_size=4,
    num_train_epochs=3,
    learning_rate=2e-4,
    fp16=True,
    logging_steps=50,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=valid_ds,       # ← 변수명 교정
    data_collator=collate        # ← 반드시 추가
)
print("Finish Loading Trainer")

with torch.no_grad():
    b = collate([train_ds[i] for i in range(4)])
    out = model(**{k: v.to(device) for k, v in b.items()})
    print("loss:", out[0].item(), "pred shape:", out[1].shape)

model.get_parameter_stats()

Reward type: cmi
Full dataset size: 26720
Finish Loading Dataset


We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)


Finish Loading Trainer
loss: 0.6680973768234253 pred shape: torch.Size([4])


{'total_params': 3830919681,
 'trainable_params': 22944769,
 'trainable_ratio': 0.598936310614861,
 'module_stats': {'backbone': {'trainable': 22944769, 'total': 3830919681}}}

In [None]:
def predict_step_rewards(model, tokenizer, problem, steps):
    """
    problem: str
    steps  : List[str]  # ["Step 1: ...", "Step 2: ...", ...]
    """
    model.eval()
    prefix = [f"Problem: {problem}"]
    rewards = []
    for s in steps:
        prefix.append(s)
        txt = "\n".join(prefix)
        enc = tokenizer(txt, return_tensors="pt").to(model.device)
        with torch.no_grad():
            reward = model(**enc).item()      # forward returns pred
        rewards.append(reward)
    return rewards

# Utils

In [None]:

# --------------------------- Training utils ------------------------------- #
def train(model: RewardRegressionModel, loader: DataLoader, optimizer, scheduler, device, loss_fn):
    model.train()
    total_loss = 0.0
    for batch in tqdm(loader, desc="train"):
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["labels"].to(device)

        preds = model(input_ids=input_ids, attention_mask=attention_mask)
        loss = loss_fn(preds, labels)

        loss.backward()
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad(set_to_none=True)

        total_loss += loss.item() * input_ids.size(0)
    return total_loss / len(loader.dataset)


def evaluate(model: RewardRegressionModel, loader: DataLoader, device, loss_fn):
    model.eval()
    total_loss = 0.0
    with torch.no_grad():
        for batch in tqdm(loader, desc="eval"):
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)
            preds = model(input_ids=input_ids, attention_mask=attention_mask)
            loss = loss_fn(preds, labels)
            total_loss += loss.item() * input_ids.size(0)
    return total_loss / len(loader.dataset)


# Model ---------------------------------------------------------------
model = RewardRegressionModel(
    base_model_name=args.base_model,
    lora_rank=args.lora_rank,
)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model.to(device)

# Optimiser & scheduler ----------------------------------------------
no_decay = ["bias", "LayerNorm.weight"]
optim_groups = [
    {
        "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
        "weight_decay": 0.01,
    },
    {
        "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
        "weight_decay": 0.0,
    },
]
optimizer = torch.optim.AdamW(optim_groups, lr=args.lr, betas=(0.9, 0.999), eps=1e-8)

total_steps = len(train_loader) * args.epochs // args.gradient_accumulation
warmup_steps = int(total_steps * args.warmup_ratio)
scheduler = get_linear_schedule_with_warmup(optimizer, warmup_steps, total_steps)

# Loss function -------------------------------------------------------
if args.loss == "mse":
    loss_fn = nn.MSELoss()
elif args.loss == "huber":
    loss_fn = nn.SmoothL1Loss()
else:  # logcosh
    loss_fn = lambda pred, tgt: torch.mean(torch.log(torch.cosh(pred - tgt)))

# Gradient accumulation ----------------------------------------------
scaler = torch.cuda.amp.GradScaler() if args.mixed_precision in ("bf16", "fp16") else None

best_val = float("inf")
for epoch in range(1, args.epochs + 1):
    print(f"\nEpoch {epoch}/{args.epochs}")

    train_loss = 0.0
    model.train()
    pbar = tqdm(train_loader, desc="train")
    for step, batch in enumerate(pbar, 1):
        with torch.cuda.amp.autocast(enabled=args.mixed_precision in ("bf16", "fp16")):
            preds = model(
                input_ids=batch["input_ids"].to(device),
                attention_mask=batch["attention_mask"].to(device),
            )
            loss = loss_fn(preds, batch["labels"].to(device)) / args.gradient_accumulation

        if scaler is not None:
            scaler.scale(loss).backward()
        else:
            loss.backward()

        if step % args.gradient_accumulation == 0:
            if scaler is not None:
                scaler.step(optimizer)
                scaler.update()
            else:
                optimizer.step()
            scheduler.step()
            optimizer.zero_grad(set_to_none=True)

        train_loss += loss.item() * args.gradient_accumulation * batch["input_ids"].size(0)
        pbar.set_postfix(loss=loss.item() * args.gradient_accumulation)

    train_loss /= len(train_ds)
    val_loss = evaluate(model, valid_loader, device, loss_fn)
    print(f"Epoch {epoch} | train MSE {train_loss:.4f} | valid MSE {val_loss:.4f}")

    if val_loss < best_val:
        best_val = val_loss
        ckpt_path = os.path.join(args.output_dir, f"best_epoch{epoch}_loss{val_loss:.4f}.pt")
        torch.save({
            "model_state": model.state_dict(),
            "tokenizer": tokenizer.__dict__,
            "args": vars(args),
        }, ckpt_path)
        print(f"Saved new best checkpoint → {ckpt_path}")

print("Training finished. Best validation MSE:", best_val)


In [None]:
import argparse
import json
import logging
from pathlib import Path
from typing import List, Dict
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Project imports
from prm_trainer_mse import PRMTrainerMSE
from config import PRMConfig
from data_generation.contri_reward import ContriRewardvLLM

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

def load_or_generate_data(config: PRMConfig, dataset_name: str = "olympiad") -> List[Dict]:
    """
    Load or generate training data
    """
    if dataset_name == "olympiad":
        # Generate OlympiadBench data
        logger.info("Generating OlympiadBench dataset...")
        reward_generator = ContriRewardvLLM(config)
        
        # Collect data from different splits
        train_entries = []
        val_entries = []
        
        # Generate training data
        for entry in reward_generator.olympiad_reward_dataset_vllm(
            split="train", 
            start=0, 
            take=config.dataset_size if config.dataset_size > 0 else None
        ):
            train_entries.append(entry)
        
        # Generate validation data (smaller subset)
        val_size = min(100, len(train_entries) // 5)  # 20% or max 100
        for entry in reward_generator.olympiad_reward_dataset_vllm(
            split="validation", 
            start=0, 
            take=val_size
        ):
            val_entries.append(entry)
        
        logger.info(f"Generated {len(train_entries)} training samples and {len(val_entries)} validation samples")
        
        return train_entries, val_entries
    
    elif dataset_name == "gsm8k":
        # Generate GSM8K data
        logger.info("Generating GSM8K dataset...")
        reward_generator = ContriRewardvLLM(config)
        
        train_entries = []
        val_entries = []
        
        for entry in reward_generator.gsm8k_reward_dataset_vllm(
            split="train",
            start=0,
            take=config.dataset_size if config.dataset_size > 0 else None
        ):
            train_entries.append(entry)
        
        val_size = min(100, len(train_entries) // 5)
        for entry in reward_generator.gsm8k_reward_dataset_vllm(
            split="test",
            start=0,
            take=val_size
        ):
            val_entries.append(entry)
        
        logger.info(f"Generated {len(train_entries)} training samples and {len(val_entries)} validation samples")
        
        return train_entries, val_entries
    
    elif dataset_name == "math":
        # Generate MATH data
        logger.info("Generating MATH dataset...")
        reward_generator = ContriRewardvLLM(config)
        
        train_entries = []
        val_entries = []
        
        for entry in reward_generator.math_reward_dataset_vllm(
            split="train",
            start=0,
            take=config.dataset_size if config.dataset_size > 0 else None
        ):
            train_entries.append(entry)
        
        val_size = min(100, len(train_entries) // 5)
        for entry in reward_generator.math_reward_dataset_vllm(
            split="validation",
            start=0,
            take=val_size
        ):
            val_entries.append(entry)
        
        logger.info(f"Generated {len(train_entries)} training samples and {len(val_entries)} validation samples")
        
        return train_entries, val_entries
    
    else:
        raise ValueError(f"Unknown dataset: {dataset_name}")

def load_pretrained_model(model_name: str):
    """
    Load pretrained model and tokenizer
    """
    logger.info(f"Loading model: {model_name}")
    
    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    
    # Load model
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
        device_map="auto" if torch.cuda.is_available() else None,
        trust_remote_code=True
    )
    
    return model, tokenizer

def main():
    parser = argparse.ArgumentParser(description="Train PRM with MSE loss")
    parser.add_argument("--config", type=str, default="config.py", help="Path to config file")
    parser.add_argument("--dataset", type=str, default="olympiad", 
                       choices=["olympiad", "gsm8k", "math"], help="Dataset to use")
    parser.add_argument("--from-scratch", action="store_true", help="Train from scratch")
    parser.add_argument("--checkpoint", type=str, help="Path to checkpoint to resume from")
    parser.add_argument("--output-dir", type=str, default="./checkpoints", help="Output directory")
    
    args = parser.parse_args()
    
    # Load configuration
    config = PRMConfig()
    
    # Override config with command line args
    if args.output_dir:
        config.checkpoint_dir = args.output_dir
    
    # Set dataset size if not specified
    if config.dataset_size == 0:
        config.dataset_size = 1000  # Default dataset size
    
    # Load model and tokenizer
    model, tokenizer = load_pretrained_model(config.model_name)
    
    # Initialize trainer
    trainer = PRMTrainerMSE(
        cfg=config,
        model=model,
        tokenizer=tokenizer,
        from_scratch=args.from_scratch
    )
    
    # Load checkpoint if specified
    if args.checkpoint:
        trainer.load_checkpoint(args.checkpoint)
        logger.info(f"Resumed training from checkpoint: {args.checkpoint}")
    
    # Load or generate data
    train_entries, val_entries = load_or_generate_data(config, args.dataset)
    
    # Start training
    logger.info("Starting PRM training with MSE loss...")
    trainer.train(train_entries, val_entries)
    
    logger.info("Training completed successfully!")

if __name__ == "__main__":
    main() 