In [1]:
import json

In [2]:
def load_results(json_path, temperature="0.6",
                 system_prompt="Please reason step by step, and put your final answer within \\boxed{}."):
    """Load sampled responses from a results JSON and format them for PRM scoring.

    Every response is split into reasoning steps on blank lines (two consecutive
    newlines). Empty or whitespace-only fragments are discarded. The remaining
    steps are joined with the PRM step separator ``<extra_0>``, with one separator
    after every step including the last, so each step gets exactly one score.

    Args:
        json_path: Path to a UTF-8 JSON file mapping each query string to a dict
            where ``details['responses'][temperature]`` is a list of records
            carrying at least ``'content'`` and ``'correctness'``.
        temperature: Key under ``'responses'`` selecting which sampling run to
            load. Defaults to ``"0.6"``, the value previously hard-coded.
        system_prompt: System message attached to every returned record.

    Returns:
        A list of dicts, one per query, with keys ``'query'``, ``'system'``,
        ``'outputs'`` (separator-formatted responses) and ``'correctness'``
        (labels parallel to ``'outputs'``).
    """
    # JSON is UTF-8 by specification; don't depend on the platform locale.
    with open(json_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    new_data = []
    for key, details in data.items():
        records = details['responses'][temperature]
        outputs = []
        for record in records:
            # Drop empty fragments (e.g. from runs of 3+ newlines or a trailing
            # blank line): they would produce back-to-back separators that the PRM
            # scores as bogus "steps", skewing the per-response minimum reward.
            steps = [step for step in record['content'].split("\n\n") if step.strip()]
            outputs.append("<extra_0>".join(steps) + "<extra_0>")
        new_data.append({
            'query': key,
            'system': system_prompt,
            'outputs': outputs,
            'correctness': [record['correctness'] for record in records],
        })
    return new_data
    

In [3]:
import torch
from transformers import AutoModel, AutoTokenizer
import torch.nn.functional as F


def make_step_rewards(logits, token_masks):
    """Return the label-1 ("positive") probability at every masked token position.

    Args:
        logits: Tensor laid out as (batch, seq_len, num_labels).
        token_masks: (batch, seq_len) mask that is truthy at the step-separator
            positions to read. Bool or 0/1 integer masks are both accepted.

    Returns:
        list[list[float]]: for each batch row, the softmax probability of label
        index 1 at each masked position, in sequence order.
    """
    # Detach first so the returned values keep no autograd graph alive.
    probabilities = F.softmax(logits.detach(), dim=-1)
    masks = token_masks.bool()

    all_scores_res = []
    for i in range(probabilities.size(0)):
        # Select the masked rows directly instead of zeroing unmasked rows and then
        # filtering non-zeros: a genuine probability that underflows to exactly 0
        # would otherwise be dropped, misaligning (or breaking) the label pairs.
        step_probs = probabilities[i][masks[i]]  # (num_steps, num_labels)
        all_scores_res.append(step_probs[:, 1].cpu().tolist())

    return all_scores_res


# Process reward model used to score each reasoning step. device_map="auto"
# delegates device placement of the checkpoint shards to transformers.
model_name = "Qwen/Qwen2.5-Math-PRM-7B"
device = "auto"

# trust_remote_code lets transformers use the checkpoint's own modelling code
# (the load log reports it as Qwen2ForProcessRewardModel).
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

# Weights in bfloat16; eval() switches to inference mode and returns the model
# itself (trailing ';' keeps the notebook from printing the model repr).
model = AutoModel.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map=device,
    trust_remote_code=True,
)
model.eval();


  from .autonotebook import tqdm as notebook_tqdm
Loading checkpoint shards: 100%|██████████| 4/4 [00:02<00:00,  1.59it/s]
Some weights of the model checkpoint at Qwen/Qwen2.5-Math-PRM-7B were not used when initializing Qwen2ForProcessRewardModel: {'lm_head.weight'}
- This IS expected if you are initializing Qwen2ForProcessRewardModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing Qwen2ForProcessRewardModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [12]:
import gc

# Order matters: gc.collect() first frees unreachable tensors back to PyTorch's
# caching allocator; empty_cache() then hands those cached blocks back to the
# device. In the reverse order, anything freed by the collector stays cached.
n_collected = gc.collect()
torch.cuda.empty_cache()

n_collected  # number of unreachable objects found by the collector

0

: 

In [24]:
import numpy as np

def calculate_reward(model, tokenizer, input, outputs, correctness):
    """Best-of-N selection: pick the response whose weakest step scores highest.

    Every candidate is scored by the PRM at each ``<extra_0>`` separator. A
    candidate's score is the minimum of its step rewards, and the candidate with
    the largest score is selected; ties keep the earliest candidate.

    Args:
        model: PRM whose forward output's first element is passed to
            ``make_step_rewards`` as logits.
        tokenizer: Tokenizer providing a chat template and the ``<extra_0>`` token.
        input: Record with ``'system'`` and ``'query'`` strings (as built by
            ``load_results``). The name is kept for caller compatibility even
            though it shadows the builtin.
        outputs: Candidate assistant responses, already step-separated with
            ``<extra_0>``.
        correctness: Correctness labels parallel to ``outputs``.

    Returns:
        tuple: ``(max_min_reward, Answer_correct, all_rewards)``.
        ``max_min_reward`` is the selected candidate's minimum step reward, or 0
        if no candidate produced any step score. ``Answer_correct`` is that
        candidate's correctness label (False if none was selected).
        ``all_rewards`` holds the per-candidate step-reward lists in input order;
        empty lists are included so indices line up with ``outputs``.
    """
    # The separator id is the same for every candidate, so look it up once.
    step_sep_id = tokenizer.encode("<extra_0>")[0]

    best_score = None
    best_correct = False
    all_rewards = []
    for response, correct in zip(outputs, correctness):
        messages = [
            {"role": "system", "content": input['system']},
            {"role": "user", "content": input['query']},
            {"role": "assistant", "content": response},
        ]
        conversation_str = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=False
        )
        input_ids = tokenizer.encode(
            conversation_str,
            return_tensors="pt",
        ).to(model.device)

        # Named model_out (not `outputs`) so it does not shadow the argument.
        with torch.no_grad():
            model_out = model(input_ids=input_ids)

        token_masks = (input_ids == step_sep_id)
        step_reward = make_step_rewards(model_out[0], token_masks)[0]
        all_rewards.append(step_reward)

        # np.min raises on an empty list; a candidate with no scored step cannot
        # be ranked, so it is skipped (but still recorded in all_rewards).
        if step_reward:
            min_reward = np.min(step_reward)
            # A None sentinel (rather than 0) ensures a candidate whose weakest
            # step scores exactly 0 can still be selected.
            if best_score is None or min_reward > best_score:
                best_score = min_reward
                best_correct = correct

        # Return cached allocator blocks to the device between candidates
        # (kept from the original; trades some speed for lower reserved memory).
        torch.cuda.empty_cache()

    max_min_reward = 0 if best_score is None else best_score
    return max_min_reward, best_correct, all_rewards

In [25]:
# Score every sampled response of every problem with the PRM and report which
# response the "max of per-response minimum step reward" rule selects.
RESULTS_PATH = "results/pass_at_64/no_budget_forcing/aime24/no_thinking_r1/temp_0.6/Pass_at_64/DeepSeek-R1-Distill-Qwen-32B_s0_e-1.json"

data = load_results(RESULTS_PATH)

# all_rewards[i][j] holds the step rewards of response j for problem i.
all_rewards = []

for item in data:
    best_min_reward, selected_correct, step_rewards = calculate_reward(
        model, tokenizer, item, item['outputs'], item['correctness']
    )
    print("Maximum minimum reward: ", best_min_reward)
    print("Answer correctness: ", selected_correct)
    all_rewards.append(step_rewards)

Maximum minimum reward:  0.9921875
Answer correctness:  True
Maximum minimum reward:  0.43359375
Answer correctness:  False
Maximum minimum reward:  0.546875
Answer correctness:  False
Maximum minimum reward:  0.5625
Answer correctness:  False
Maximum minimum reward:  0.416015625
Answer correctness:  False
Maximum minimum reward:  0.42578125
Answer correctness:  False
Maximum minimum reward:  0.88671875
Answer correctness:  False
Maximum minimum reward:  1.0
Answer correctness:  True
Maximum minimum reward:  0.7578125
Answer correctness:  True
Maximum minimum reward:  0.9921875
Answer correctness:  True
Maximum minimum reward:  0.5625
Answer correctness:  False
Maximum minimum reward:  0.75
Answer correctness:  False
Maximum minimum reward:  0.9921875
Answer correctness:  True
Maximum minimum reward:  0.384765625
Answer correctness:  False
Maximum minimum reward:  0.75390625
Answer correctness:  False
Maximum minimum reward:  0.94921875
Answer correctness:  True
Maximum minimum reward:

In [38]:
# Re-derive the best-of-N selection from the cached step rewards (no model calls):
# for each problem, select the response with the highest minimum step reward and
# count how many selected responses are correct.
total_correct = 0

for index, item in enumerate(data):
    print("Logging reward for item: ", index)
    rewards_for_item = all_rewards[index]
    # Rewards and labels must be parallel, or the selected label is meaningless.
    assert len(rewards_for_item) == len(item['correctness']), f"length mismatch for item {index}"

    best_min = None
    correctness = False
    for step_rewards, correct in zip(rewards_for_item, item['correctness']):
        if not step_rewards:
            continue  # no scored step -> cannot be ranked
        min_reward = np.min(step_rewards)
        # None sentinel so a response whose weakest step scores exactly 0 can still win.
        if best_min is None or min_reward > best_min:
            best_min = min_reward
            correctness = correct

    max_reward = 0 if best_min is None else best_min
    print("Maximum reward: ", max_reward)
    print("Correctness: ", correctness)
    total_correct += correctness
    print("--------------------------------\n")

print("Total correct: ", total_correct)
print("Total problems: ", len(data))

Logging reward for item:  0
Maximum reward:  0.9921875
Correctness:  True
--------------------------------

Logging reward for item:  1
Maximum reward:  0.43359375
Correctness:  False
--------------------------------

Logging reward for item:  2
Maximum reward:  0.546875
Correctness:  False
--------------------------------

Logging reward for item:  3
Maximum reward:  0.5625
Correctness:  False
--------------------------------

Logging reward for item:  4
Maximum reward:  0.416015625
Correctness:  False
--------------------------------

Logging reward for item:  5
Maximum reward:  0.42578125
Correctness:  False
--------------------------------

Logging reward for item:  6
Maximum reward:  0.88671875
Correctness:  False
--------------------------------

Logging reward for item:  7
Maximum reward:  1.0
Correctness:  True
--------------------------------

Logging reward for item:  8
Maximum reward:  0.7578125
Correctness:  True
--------------------------------

Logging reward for item:  9

: 