In [1]:
import pandas as pd
import numpy as np
import json
import random
import math
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset
from datasets import load_dataset
from transformers import AutoTokenizer
import re
import multiprocessing
# import torch
def seed_everything(seed=0):
    random.seed(seed)
    np.random.seed(seed)
    # torch.manual_seed(seed)
    # torch.cuda.manual_seed(seed)
    # torch.cuda.manual_seed_all(seed)
    # torch.backends.cudnn.deterministic = True
    # torch.backends.cudnn.benchmark = False
seed_everything(0)

def instruction_format(s):
    return f'[INST] {s} [/INST]'

seed_everything(0)

# get the processor count
num_processes = multiprocessing.cpu_count()
print(f"Number of processes: {num_processes}")

  from .autonotebook import tqdm as notebook_tqdm


Number of processes: 48


In [2]:
dataset_path = '/storage/group/renkan/luao/reward_datasets/math-shephered/'
model_path = '/storage/group/renkan/luao/pretrain/deepseek-math-7b-base'
prm_token = '[PRM]'
tokenizer = AutoTokenizer.from_pretrained(model_path,trust_remote_code=True)
if not tokenizer.pad_token:
    tokenizer.pad_token = tokenizer.eos_token
tokenizer.add_special_tokens({'additional_special_tokens':[prm_token]})
prm_token_id = tokenizer.encode(prm_token, add_special_tokens=False)[-1]
max_length = 512
longer_max_length = 1024

In [None]:
ds = load_dataset("json", data_dir=dataset_path, num_proc=num_processes)['train']
# ds = ds.select(range(1000))

def process_example(example, prm_token, tokenizer, max_length, longer_max_length):
    # 分割输入文本
    input_text = example['input']
    steps = re.split(r'Step \d+:', input_text)
    steps = [s for s in steps if s.strip() != '']
    
    # 跳过只有一个步骤的样本
    if len(steps) <= 1:
        return None
    
    # 提取问题和步骤
    question = steps[0]
    steps = [
        f'Step {i + 1}: ' + step.strip().replace('ки', '').strip()
        for i, step in enumerate(steps[1:])
        if step.strip() != ''
    ]
    
    # 处理标签
    label_steps = re.split(r'Step \d+:', example['label'])
    label_steps = [s.strip() for s in label_steps[1:] if s.strip() != '']
    
    # 验证标签以 '+' 或 '-' 结尾
    try:
        for s in label_steps:
            assert s[-1] in ['+', '-'], f"Invalid label format: {label_steps}"
    except AssertionError:
        return None
    
    # 提取步骤标签
    step_labels = [1 if l[-1] == '+' else 0 for l in label_steps]
    
    # 验证步骤和标签数量匹配
    try:
        assert len(steps) == len(step_labels)
    except AssertionError:
        return None
    
    # 构造查询
    query = {
        "query": instruction_format(question), 
        "answer": f" {prm_token}\n".join(steps) + f" {prm_token}",
        "labels": step_labels,
    }
    
    # 计算编码长度
    encoded = tokenizer.encode(query['query'] + query['answer'])
    query_length = len(encoded)
    query['length'] = query_length
    
    return query

processed_ds = ds.map(
    process_example,
    fn_kwargs={
        "prm_token": prm_token,
        "tokenizer": tokenizer,
        "max_length": max_length,
        "longer_max_length": longer_max_length
    },
    batched=False, 
    num_proc=num_processes,    
    remove_columns=ds.column_names,
    desc="Processing dataset"
)

# 过滤掉无效样本（返回 None 的行）
processed_ds = processed_ds.filter(lambda x: x is not None, num_proc=num_processes)

# 主进程打印调试信息
print(f"Dataset Lengths: Normal={len(processed_ds)}")

# Turn these queries into a list of dictionaries
# queries = [{"query": q["query"], "answer": q["answer"], "labels": q["labels"], "length"} for q in processed_ds]

Processing dataset (num_proc=48): 100%|██████████| 444655/444655 [03:07<00:00, 2373.25 examples/s]
Filter (num_proc=48): 100%|██████████| 444558/444558 [00:01<00:00, 300745.08 examples/s]


Dataset Lengths: Normal=444558


In [6]:
from tqdm import tqdm
# iterate the queries
# calculate the mean and std of the query length
# count how many queries are longer than max_length
# count how many queries are longer than longer_max_length
# count how many last steps are false
longer_than_max = 0
longer_than_longer_max = 0
total_length = 0
final_label_false = 0
continuous_false = [0]
continuous_true = [0]
for i, q in tqdm(enumerate(processed_ds), total=len(processed_ds)):    
    total_length += q['length']
    
    if q['length'] > max_length:
        longer_than_max += 1
        if q['length'] > longer_max_length:
            longer_than_longer_max += 1
    
    if q['labels'][-1] == 0:
        final_label_false += 1
        continuous_false[-1] += 1
        if continuous_true[-1] != 0:
            continuous_true.append(0)
    else:
        continuous_true[-1] += 1
        if continuous_false[-1] != 0:
            continuous_false.append(0)

print(f"Mean Length: {total_length / len(processed_ds)}")
# print(f"Std Length: {math.sqrt(sum([(q['length'] - total_length / len(processed_ds)) ** 2 for q in processed_ds]) / len(processed_ds))}")
print(f"Longer than max length: {longer_than_max} / {len(processed_ds)}")
print(f"Longer than longer max length: {longer_than_longer_max} / {len(processed_ds)}")
print(f"Final label false: {final_label_false} / {len(processed_ds)}")
print(f"Continuous false: {continuous_false}")
print(f"Continuous true: {continuous_true}")

100%|██████████| 444558/444558 [00:40<00:00, 10858.29it/s]

Mean Length: 310.6981743664494
Longer than max length: 51776 / 444558
Longer than longer max length: 2543 / 444558
Final label false: 281959 / 444558
Continuous false: [121526, 68, 32, 19, 29, 8, 171, 5, 75, 9, 220, 121, 11, 30, 42, 32, 104, 22, 18, 13, 2, 1, 23, 93, 5, 36, 45, 7, 48, 8, 17, 3, 17, 6, 3, 64, 16, 19, 41, 19, 40, 37, 77, 70, 12, 52, 28, 3, 2, 28, 31, 17, 24, 108, 10, 82, 9, 14, 1, 67, 17, 5, 26, 55, 1, 105, 57, 48, 62, 6, 22, 37, 1, 34, 34, 31, 44, 8, 92, 11, 36, 3, 15, 10, 27, 9, 62, 21, 7, 15, 37, 13, 2, 2, 13, 12, 12, 3, 8, 4, 10, 21, 10, 20, 29, 14, 20, 89, 2, 5, 12, 26, 4, 14, 43, 2, 16, 47, 15, 4, 23, 24, 49, 57, 20, 36, 49, 1, 46, 25, 36, 23, 68, 121, 52, 72, 75, 3, 56, 97, 19, 31, 16, 151, 32, 23, 6, 2, 14, 37, 10, 40, 3, 6, 101, 7, 6, 8, 24, 75, 19, 122, 15, 41, 22, 32, 34, 24, 123, 49, 3, 14, 32, 2, 32, 41, 1, 5, 12, 64, 25, 7, 13, 48, 67, 31, 15, 91, 3, 31, 1, 38, 50, 56, 16, 33, 1, 49, 46, 12, 22, 3, 51, 7, 16, 1, 4, 23, 17, 12, 12, 38, 28, 11, 120, 20, 52, 2


