# PRM Statistics

In [None]:
from vllm import LLM
import os
import pickle
import re
from tqdm import tqdm
import pandas as pd
from collections import Counter
import random
import json
from transformers import AutoTokenizer
import torch.nn.functional as F
from datasets import load_dataset

# Load the NuminaMath-CoT dataset; only its "train" split is used below
# (as `filtered_train`, indexed by sample position).
save_path = "NuminaMath-CoT"
dataset = load_dataset(save_path)
filtered_train = dataset["train"]
# Bridging-model outputs. Later cells read loaded_results[i] for every
# i in range(len(filtered_train)) and pass it to process_text (regex parsing),
# so each entry is expected to be a string aligned with filtered_train[i].
# NOTE(review): pickle.load can execute arbitrary code — only load trusted files.
file_path = 'bridge_results.pkl'
with open(file_path, 'rb') as file:
    loaded_results = pickle.load(file)


In [None]:
# One "Missing Step" entry of the bridging model's output. The colon after
# "Missing Step N" may be the full-width colon (U+FF1A) or an ASCII colon;
# accepting only the full-width form silently merged ASCII-colon entries into
# the previous entry's text.
_MISSING_STEP_RE = re.compile(
    r'Missing Step (\d+)[：:]\s*'
    r'The missing step should be placed between Step (\d+) and Step (\d+)\.\s*'
    r'The missing step is:\s*(.*?)'
    r'(?=\s*Missing Step \d+[：:]|\Z)',
    re.DOTALL,
)


def extract_missing_steps(text):
    """Parse every "Missing Step" entry out of a bridging-model response.

    Each entry has the form::

        Missing Step <a>：
        The missing step should be placed between Step <x> and Step <y>.
        The missing step is: <z>

    where <z> runs (possibly across lines) up to the next "Missing Step N"
    header or the end of *text*.

    Args:
        text: The raw model response.

    Returns:
        A list of dicts with string values under keys ``'a'`` (entry number),
        ``'x'`` and ``'y'`` (the surrounding step numbers) and ``'z'`` (the
        missing step text), all whitespace-stripped, in order of appearance.
    """
    return [
        {'a': a.strip(), 'x': x.strip(), 'y': y.strip(), 'z': z.strip()}
        for a, x, y, z in _MISSING_STEP_RE.findall(text)
    ]


def filter_results(results):
    """Keep only entries that bridge two consecutive steps with real content.

    An entry survives when its integer ``'y'`` is exactly one more than its
    integer ``'x'`` and its ``'z'`` text does not begin with ``'####'``.
    The relative order of surviving entries is unchanged.
    """
    def _keep(entry):
        start, end, body = int(entry['x']), int(entry['y']), entry['z']
        return end - start == 1 and not body.startswith('####')

    return [entry for entry in results if _keep(entry)]


def sort_results_by_x(results):
    """Return a new list of *results* ordered by the integer value of ``'x'``.

    The sort is stable, so entries with equal ``'x'`` keep their input order.
    """
    def _x_key(entry):
        return int(entry['x'])

    ordered = list(results)
    ordered.sort(key=_x_key)
    return ordered


def process_text(text):
    """Parse, filter and order the missing-step entries found in *text*.

    Pipeline: ``extract_missing_steps`` -> ``filter_results`` (adjacent,
    non-'####' entries only) -> ``sort_results_by_x`` (ascending ``'x'``).
    """
    return sort_results_by_x(filter_results(extract_missing_steps(text)))

def process(i):
    """Rebuild sample ``i``'s solution with its bridged (missing) steps inserted.

    The reference solution ``filtered_train[i]['solution']`` is split into
    steps on blank lines (``"\\n\\n"``). The missing steps parsed from
    ``loaded_results[i]`` by ``process_text`` (adjacent gaps only, sorted by
    ascending ``x``) are then inserted after step ``x`` (1-based).

    Args:
        i: Index into ``filtered_train`` and ``loaded_results``.

    Returns:
        ``(query, steps, insert_pos)``: the problem text, the list of steps
        after insertion, and the 0-based index in ``steps`` of each inserted
        step, in insertion order.
    """
    query = filtered_train[i]['problem']
    steps = filtered_train[i]['solution'].split("\n\n")
    insert_pos = []
    for offset, missing in enumerate(process_text(loaded_results[i])):
        # Step x (1-based) sits at index x - 1, so "between x and x+1" is
        # index x; entries are sorted by x, so each earlier insertion shifts
        # this one right by one. list.insert clamps an out-of-range index to
        # the end, so clamp the recorded position the same way; otherwise a
        # hallucinated step number would record an index past the step.
        pos = min(int(missing['x']) + offset, len(steps))
        steps.insert(pos, missing['z'])
        insert_pos.append(pos)
    return query, steps, insert_pos

In [None]:
# Per-sample outputs of process(): problem text, steps with bridged steps
# inserted, and the indices of those inserted steps.
query, m, insert_pos = [], [], []
for i in tqdm(range(len(filtered_train))):
    sample_query, sample_steps, sample_positions = process(i)
    query.append(sample_query)
    m.append(sample_steps)
    insert_pos.append(sample_positions)

In [None]:
# PRM scores per sample. The next cell indexes results[i] with insert_pos[i]
# (indices into m[i]), so results[i] is presumably one score per step of
# m[i] — TODO confirm against the scoring script.
# NOTE(review): pickle.load can execute arbitrary code — only load trusted files.
with open("prm_result.pkl", 'rb') as file:
    results = pickle.load(file)

In [None]:
# Running count of inserted steps whose scores could not be looked up; later
# cells print it and add it to the lowest score bucket.
temp = 0


def get_values_by_indices(values, indices):
    """Return ``[values[k] for k in indices]``, or ``"error data"`` on failure.

    Used to pick out the PRM scores of one sample's inserted steps. If any
    index cannot be looked up (out of range, missing key, or *values* not
    subscriptable), the whole sample is marked ``"error data"`` and
    ``len(indices)`` is added to the module-level counter ``temp``.
    """
    global temp
    try:
        return [values[k] for k in indices]
    except (IndexError, KeyError, TypeError):
        # Narrowed from a bare ``except:`` so KeyboardInterrupt and unrelated
        # bugs are no longer silently counted as bad data.
        temp += len(indices)
        return "error data"
# PRM scores of each sample's inserted steps ("error data" when lookup failed).
insert_score = [
    get_values_by_indices(results[i], insert_pos[i])
    for i in range(len(results))
]
print(temp)

In [None]:
# Histogram of inserted-step PRM scores in 10 equal-width buckets; bucket k
# covers [k/10, (k+1)/10), with 1.0 folded into the top bucket.
num_buckets = 10
distribution = [0] * num_buckets
insert_context = [[] for _ in range(num_buckets)]  # step texts per bucket
total_sum = 0.0
count = 0
min_val = 1.0
max_val = 0.0

for i, scores in enumerate(insert_score):
    if scores == "error data":
        continue
    for j, score in enumerate(scores):
        # Clamp to [0, num_buckets - 1]: a negative score would otherwise
        # produce a negative index and land in a high bucket.
        idx = min(max(int(score * num_buckets), 0), num_buckets - 1)
        distribution[idx] += 1
        insert_context[idx].append(m[i][insert_pos[i][j]])
        total_sum += score
        count += 1
        min_val = min(min_val, score)
        max_val = max(max_val, score)

# Steps of samples whose scores could not be looked up are counted in the
# lowest bucket. process_prm keeps the original solution for those samples,
# so their bridged steps are dropped just like steps scoring below 0.1.
distribution[0] += temp
mean = total_sum / count if count > 0 else 0.0
print(f"error: {temp}")
print(f"distribution: {distribution}")
print(f"avg: {mean:.4f}")
print(f"min: {min_val:.4f}, max: {max_val:.4f}")
total = sum(distribution)
for bucket_count in distribution:
    # Fraction of all steps per bucket; guard against an empty histogram.
    print(bucket_count / total if total else 0.0)

# Construct denoising data (remove bridged steps with low PRM scores)

In [None]:
def _chat_record(system, user, assistant):
    """Return one chat-format record with system, user and assistant turns."""
    return {
        "messages": [
            {"role": "system", "content": system},
            {"role": "user", "content": user},
            {"role": "assistant", "content": assistant},
        ]
    }


def process_prm(i, threshold=0.1):
    """Build the denoised chat training record for sample ``i``.

    If the sample's inserted-step scores could not be looked up
    (``insert_score[i] == "error data"``), the original solution is used
    verbatim. Otherwise every inserted step (position ``insert_pos[i][j]`` in
    ``m[i]``) whose score is below *threshold* is removed. The remaining steps
    are joined with newlines.

    Args:
        i: Sample index into ``filtered_train``, ``insert_score``,
            ``insert_pos`` and ``m``.
        threshold: Inserted steps scoring strictly below this are dropped.

    Returns:
        A dict ``{"messages": [system, user, assistant]}``.
    """
    query = filtered_train[i]['problem']
    sy = "You are a math problem solver. You should think step by step."
    scores = insert_score[i]
    if scores == "error data":
        return _chat_record(sy, query, filtered_train[i]['solution'])
    # Set for O(1) membership tests while filtering the steps below.
    delete_pos = {
        pos for pos, score in zip(insert_pos[i], scores) if score < threshold
    }
    kept = [step for j, step in enumerate(m[i]) if j not in delete_pos]
    # NOTE(review): steps were split on "\n\n" but are re-joined with "\n",
    # unlike the "error data" branch, which keeps the solution verbatim —
    # confirm this formatting difference is intended.
    return _chat_record(sy, query, "\n".join(kept))

In [None]:
# Build one chat record per sample and write them all to a single JSON file.
output_path = 'numina-math-multi-fill-prm-0.1.json'
process_data = [process_prm(i) for i in tqdm(range(len(insert_score)))]
# ensure_ascii=False writes raw non-ASCII text, so the encoding must be
# explicit rather than the platform default.
with open(output_path, 'w', encoding='utf-8') as f:
    json.dump(process_data, f, ensure_ascii=False, indent=4)

print(len(process_data))
print(f"process_data has been saved to {output_path}.")