In [1]:
from datasets import load_dataset, Dataset
from collections import defaultdict
from tqdm import tqdm

ds = load_dataset("nvidia/OpenMathInstruct-1", split='train')

print("Grouping examples by dataset and question...")
groups = {'gsm8k': defaultdict(list), 'math': defaultdict(list)}

# Build, per benchmark, a mapping question -> list of row indices in `ds`.
for i, ex in enumerate(tqdm(ds, desc='Iterating dataset')):
    # Only rows flagged is_correct == True are considered.
    if ex.get('is_correct', False):
        dataset_name = ex.get('dataset')
        # Rows from any other source dataset are ignored.
        if dataset_name in ('gsm8k', 'math'):
            groups[dataset_name][ex.get('question')].append(i)

  from .autonotebook import tqdm as notebook_tqdm


Grouping examples by dataset and question...


Iterating dataset: 100%|██████████| 7321344/7321344 [02:29<00:00, 49084.93it/s]


In [73]:
import re
import sys
from typing import List, Optional

import torch
sys.path.append('/home/guest/AdvancedLLMReasoning/utils')
from prompt import PROMPT_V2

def extract_answer(text: str) -> Optional[str]:
    r"""Return the content of the last ``\boxed{...}`` group in ``text``.

    Nested braces inside the group are kept verbatim, so
    ``\boxed{\frac{1}{2}}`` yields ``\frac{1}{2}``. Surrounding whitespace of
    the extracted content is stripped.

    Args:
        text: Text that may contain one or more ``\boxed{...}`` groups.

    Returns:
        The stripped content of the last group, or ``None`` when ``text`` is
        empty/``None``, contains no ``\boxed{`` marker, or the last marker is
        never closed (e.g. a generation cut off mid-answer), so that a
        truncated fragment is never mistaken for a complete answer.
    """
    if not text:
        return None

    marker = "\\boxed{"
    marker_pos = text.rfind(marker)
    if marker_pos == -1:
        return None

    content_start = marker_pos + len(marker)
    depth = 1  # the opening brace of the marker itself
    for pos in range(content_start, len(text)):
        char = text[pos]
        if char == "{":
            depth += 1
        elif char == "}":
            depth -= 1
            if depth == 0:
                # Matching close brace found: slice once instead of
                # accumulating the content character by character.
                return text[content_start:pos].strip()

    # Ran off the end of the text without closing the group.
    return None

def parse_solution_into_steps(solution: str) -> List[str]:
    """
    Break a solution string into an ordered list of reasoning steps.

    - Each <llm-code>...</llm-code> block is one step (never split by line).
    - Each <llm-code-output>...</llm-code-output> block is one step.
    - Remaining plain text is split into sentences at whitespace that follows
      '.', '!' or '?'.

    Empty fragments are dropped and every step is whitespace-stripped.
    """
    block_openers = ('<llm-code>', '<llm-code-output>')
    block_splitter = re.compile(
        r'(<llm-code>.*?</llm-code>|<llm-code-output>.*?</llm-code-output>)',
        re.DOTALL,
    )
    sentence_boundary = re.compile(r'(?<=[.!?])\s+')

    steps: List[str] = []
    # The capturing group makes split() return the blocks as well as the text
    # between them, in document order.
    for chunk in (piece.strip() for piece in block_splitter.split(solution)):
        if not chunk:
            continue
        if chunk.startswith(block_openers):
            # Code / code-output block: keep intact as a single step.
            steps.append(chunk)
        else:
            # Plain text: one step per non-empty sentence.
            steps.extend(
                sentence
                for sentence in map(str.strip, sentence_boundary.split(chunk))
                if sentence
            )
    return steps

def check_sample(sample, model, tokenizer, N=8):
    """Estimate how hard a sample's question is for ``model`` via N sampled rollouts.

    The model is prompted with the sample's question only, N completions are
    sampled, and each completion's boxed answer is compared (after
    normalization) with the sample's expected answer.

    Args:
        sample: Mapping with 'question', 'expected_answer' and (optionally)
            'generated_solution' keys.
        model: Object exposing ``.device`` and ``.generate(...)``. Assumes a
            decoder-only model whose generate() output starts with the prompt
            tokens (the original code made the same assumption) - TODO confirm
            if another model family is ever passed in.
        tokenizer: Tokenizer matching ``model``.
        N: Number of rollouts to sample; must be >= 1.

    Returns:
        ``(is_good, sample_dict)``. ``is_good`` is True when at least one but
        not all rollouts were correct (``1 <= correct_count < N``).
        ``sample_dict`` holds the question, the parsed steps of the sample's
        reference solution, the expected answer and the rollout statistics.
        When the question or expected answer is missing, ``(False, None)`` is
        returned without running the model.

    Raises:
        ValueError: If ``N`` is smaller than 1 (previously this surfaced as a
            ZeroDivisionError, or a negative score, after the fact).
    """
    if N < 1:
        raise ValueError(f"N must be a positive integer, got {N}")

    question = sample.get('question')
    expected_answer = sample.get('expected_answer')

    if not question or not expected_answer:
        return False, None

    prompt = f"### Question:\n{question}\n\n### Instruction:\nSolve the problem step by step. You can use Python code if needed.\nIf you write code, wrap it inside <llm-code> ... </llm-code>.\nOutput ONLY the final number inside \\boxed{{}}. Example: \\boxed{{42}}.\n\n### Solution:\n"

    # Loop-invariant work: the prompt and the reference answer are the same
    # for every rollout, so encode / normalize them once.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    prompt_token_count = inputs["input_ids"].shape[1]
    expected_normalized = normalize_answer(expected_answer)

    correct_count = 0
    for _ in range(N):
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=512,
                do_sample=True,
                temperature=0.6,
                top_p=0.95,
                pad_token_id=tokenizer.eos_token_id
            )

        # Decode only the newly generated tokens. Cutting the decoded text at
        # len(prompt) characters is fragile because decoding the prompt tokens
        # is not guaranteed to reproduce the prompt string exactly.
        new_tokens = outputs[0][prompt_token_count:]
        solution = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

        predicted_answer = extract_answer(solution)
        if predicted_answer and normalize_answer(predicted_answer) == expected_normalized:
            correct_count += 1

    # Difficulty is appropriate when the model is sometimes, but not always, right.
    is_good = 1 <= correct_count < N

    # Steps come from the dataset's reference solution, not from the rollouts.
    # `or ''` guards against an explicit None value in the sample.
    reference_solution = sample.get('generated_solution', '') or ''
    steps = parse_solution_into_steps(reference_solution)

    sample_dict = {
        'question': question,
        'steps': steps,
        'expected_answer': expected_answer,
        'correct_count': correct_count,
        'total_rollouts': N,
        'difficulty_score': correct_count / N
    }

    return is_good, sample_dict

def normalize_answer(answer: str) -> str:
    """Canonicalize an answer string so two answers can be compared textually.

    Falsy input (None or "") maps to "". Otherwise the text is trimmed,
    lower-cased, and every backslash, brace, comma and space is removed.
    """
    if not answer:
        return ""
    # Deletion table for LaTeX punctuation (\ { }) plus thousands separators
    # and spaces.
    removed_chars = str.maketrans('', '', '\\{}, ')
    return answer.strip().lower().translate(removed_chars)


In [57]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import PeftModel

# Load the quantized base model, then attach the SFT adapter on top of it.
print("Loading SFT model...")
# Hard-coded absolute path to the fine-tuned adapter checkpoint.
ADAPTER_PATH = "/home/guest/AdvancedLLMReasoning/math_tutor_model/math_sft_adapter/v2/final_checkpoint" 
BASE_MODEL_ID = "meta-llama/Llama-3.2-1B"

# bitsandbytes settings: 4-bit NF4 weights with double quantization and
# bfloat16 as the compute dtype.
bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16
)

# Base causal LM, loaded with the quantization config above.
base_model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL_ID,
    quantization_config=bnb_config,
    device_map="auto",
    torch_dtype=torch.bfloat16
)

# The tokenizer is taken from the base model id, not from the adapter path.
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID)
# Reuse the EOS token as the padding token.
tokenizer.pad_token = tokenizer.eos_token
# Wrap the base model with the adapter stored at ADAPTER_PATH.
sft_model = PeftModel.from_pretrained(base_model, ADAPTER_PATH)
# Switch to evaluation mode; as the cell's last expression, the returned
# model repr is what the notebook displays.
sft_model.eval()

Loading SFT model...


PeftModelForCausalLM(
  (base_model): LoraModel(
    (model): LlamaForCausalLM(
      (model): LlamaModel(
        (embed_tokens): Embedding(128256, 2048)
        (layers): ModuleList(
          (0-15): 16 x LlamaDecoderLayer(
            (self_attn): LlamaAttention(
              (q_proj): lora.Linear4bit(
                (base_layer): Linear4bit(in_features=2048, out_features=2048, bias=False)
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.1, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=2048, out_features=16, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=16, out_features=2048, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
                (lora_magnitude_vector): ModuleDict()
              )
              (k_proj): lora

In [None]:
# Sanity check: display how the first dataset row's solution is split into steps.
parse_solution_into_steps(ds[0]['generated_solution'])

["Let's solve this problem using Python code.",
 '<llm-code>\namount_of_lost_crayons = 18 / 2\namount_of_new_crayons = 20\ntotal_amount = amount_of_lost_crayons + amount_of_new_crayons\ntotal_amount\n</llm-code>',
 '<llm-code-output>\n29.0\n</llm-code-output>',
 'Thus Martha has \\boxed{29} crayons in total.']

In [78]:
from collections import deque
import random
seed = 42

prm_dataset = []
ran = random.Random(seed)

# Question order is randomized once per benchmark (gsm8k first, then math,
# so the seeded RNG is consumed in a fixed order).
gsm8k_questions = list(groups['gsm8k'])
math_questions = list(groups['math'])

ran.shuffle(gsm8k_questions)
ran.shuffle(math_questions)

result = []
q_deques = defaultdict(dict)

# gsm8k: every row index of a question is a candidate, in shuffled order.
for q in gsm8k_questions:
    shuffled_indices = list(groups['gsm8k'][q])  # copy; `groups` stays untouched
    ran.shuffle(shuffled_indices)
    q_deques['gsm8k'][q] = deque(shuffled_indices)

# math: split a question's rows by whether error_message equals
# '<not_executed>'; rows where it does not are treated as code-based.
# Code-based rows are used when any exist, otherwise the remaining rows.
for q in math_questions:
    code_indices, text_indices = [], []
    for idx in groups['math'][q]:
        if ds[idx].get('error_message') != '<not_executed>':
            code_indices.append(idx)
        else:
            text_indices.append(idx)
    chosen_indices = code_indices if code_indices else text_indices
    ran.shuffle(chosen_indices)
    q_deques['math'][q] = deque(chosen_indices)

# Round-robin state: one queue of questions per benchmark, plus a queue that
# alternates between the benchmarks themselves.
q_cycles = {
    'gsm8k': deque(gsm8k_questions),
    'math': deque(math_questions)
}
g_cycle = deque(['gsm8k', 'math'])

In [79]:
results = []
TARGET_SIZE = 50
N_ROLLOUTS = 8  # rollouts per tested sample; also the denominator shown in the progress bar

total_steps = 0
total_questions = 0
samples_tested = 0

# Round-robin over benchmarks and questions, collecting samples whose question
# passes the difficulty filter until TARGET_SIZE reasoning steps are gathered.
#
# NOTE(review): check_sample builds its prompt from the question only, and
# every index in a question's deque shares that question, so retrying further
# samples after a rejection re-runs the same rollouts. Consider dropping the
# whole question after its first rejection.
with tqdm(total=TARGET_SIZE, desc="Steps collected") as pbar:
    # `and g_cycle`: stop cleanly once every benchmark is exhausted instead of
    # raising IndexError from popleft() on an empty deque.
    while total_steps < TARGET_SIZE and g_cycle:
        g = g_cycle.popleft()
        if not q_cycles[g]:
            # Benchmark has no questions left (or never had any): drop it from
            # the rotation.
            continue
        q = q_cycles[g].popleft()
        dq = q_deques[g][q]

        while dq:
            sample_idx = dq.popleft()
            samples_tested += 1

            # Score the sample's question with the SFT model.
            is_good, sample_data = check_sample(ds[sample_idx], sft_model, tokenizer, N=N_ROLLOUTS)

            if is_good:
                num_steps = len(sample_data['steps'])
                results.append(sample_data)
                total_questions += 1
                total_steps += num_steps

                pbar.update(num_steps)
                pbar.set_postfix({
                    'questions': total_questions,
                    'tested': samples_tested,
                    'skip_rate': f"{(samples_tested-total_questions)/samples_tested*100:.1f}%",
                    'difficulty': f"{sample_data['correct_count']}/{N_ROLLOUTS}"
                })
                break

            if sample_data and sample_data['correct_count'] == N_ROLLOUTS:
                reason = 'too easy'
            elif sample_data and sample_data['correct_count'] == 0:
                reason = 'too hard'
            else:
                reason = 'missing data'
            # write() prints above the bar instead of breaking it, unlike print().
            pbar.write(f"Reason for skipping sample: {reason}")

        # Questions with samples left go to the back of their benchmark queue;
        # benchmarks with questions left go to the back of the rotation.
        if dq:
            q_cycles[g].append(q)
        if q_cycles[g]:
            g_cycle.append(g)

if total_steps < TARGET_SIZE:
    print(f"Sample pool exhausted after {samples_tested} tested samples: collected {total_steps}/{TARGET_SIZE} steps.")

Steps collected:   8%|▊         | 4/50 [00:16<03:13,  4.20s/it, questions=1, tested=1, skip_rate=0.0%, difficulty=3/8]

Reason for skipping sample: too hard
Reason for skipping sample: too hard
Reason for skipping sample: too hard
Reason for skipping sample: too hard
Reason for skipping sample: too hard
Reason for skipping sample: too hard
Reason for skipping sample: too hard
Reason for skipping sample: too hard


Steps collected:   8%|▊         | 4/50 [03:07<35:52, 46.78s/it, questions=1, tested=1, skip_rate=0.0%, difficulty=3/8]


KeyboardInterrupt: 

=> Nhận định: phương pháp lựa chọn mẫu chất lượng cho PRM được cho là mang lại hiệu suất tốt hơn. Tuy nhiên, mô hình 1B sau khi train SFT mới chỉ học được cách sinh chuỗi, khả năng suy luận còn rất yếu; vì vậy, việc chỉ giữ lại các mẫu có số lần trả lời đúng nằm trong khoảng từ 1 đến N-1 trên N lần thử sẽ loại bỏ hầu hết dữ liệu.
=> Vì vậy, chúng tôi quyết định không thực hiện bước question filtering trong pipeline.