In [1]:
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from math_verify import parse, verify
from tqdm import tqdm

model_name = "a-m-team/AM-Thinking-v1"

tokenizer = AutoTokenizer.from_pretrained(model_name)

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    dtype="auto",        # FP16 or BF16 depending on hardware
    device_map=None,          # distributes across available GPUs/CPU
    do_sample=False
).cuda()

Loading checkpoint shards:   0%|          | 0/14 [00:00<?, ?it/s]

The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


In [1]:
from datasets import load_dataset

data_files = {"train": "/home/ehab02/datasets/math.jsonl"}   # use pattern for all shards
ds = load_dataset("json", data_files=data_files, split="train")

Loading dataset shards:   0%|          | 0/23 [00:00<?, ?it/s]

In [58]:
print(ds[84]['conversations'])

[{'from': 'human', 'value': '22) Three friends are dividing a basket of apples equally and notice they can do so without cutting the apples. A fourth friend arrives and they decide to redistribute the apples equally again, and they manage to do so without cutting the apples. One of the four goes home and eats 2 apples on the way. Upon arriving home, he realizes he can divide the remaining apples in half with his girlfriend without cutting them. What is the minimum number of apples that could have been in the basket originally?\n(A) 12\n(B) 16\n(C) 24\n(D) 36\n(E) 50.', 'info': {'source': 'NuminaMath', 'category': 'math', 'ground_truth': '24', 'test_case': None, 'instruction_constrain': None, 'think_content': None, 'answer_content': None, 'verify_score': None, 'model_name': None, 'ppl': None}}, {'from': 'assistant', 'value': '<think>Okay, let\'s try to solve this problem step by step. The question is about finding the minimum number of apples in the basket originally. Let\'s break down 

In [55]:
def get_reward(datapoint):

    messages = []

    if datapoint['system'] is None:
        datapoint['system'] = ""
    datapoint['system'] += "The answer should just be your final answer in \\boxed{} without any explanation or reasoning between the <answer> and </answer> tags. Even if you're not sure, only give me the final answer."

    if datapoint['system'] is not None:
        messages.append({'role': 'system', 'content': datapoint['system']})

    messages.append({'role': 'user', 'content': datapoint['conversations'][0]['value']})
    
    CoT = datapoint['conversations'][1]['value']
    CoT = CoT[7:CoT.find('<answer>')]
    partial_CoT = ""
    
    rewards = []
    outputs = []
    
    for idx, CoT_paragraph in enumerate(tqdm(["<think>"]+CoT.split('\n\n'))):

        partial_CoT += CoT_paragraph

        message_CoT = partial_CoT + ("" if "</think>" in partial_CoT else "</think>\n") + "<answer>\\boxed{"
        
        text = tokenizer.apply_chat_template(
            messages+[{'role': 'assistant', 'content': message_CoT}],
            tokenize=False,
            add_generation_prompt=True
        )
        text=text[:text.rfind('<|im_end|>')]
        model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
        
        out = model.generate(
            **model_inputs,
            max_new_tokens = 60,
            do_sample = False
        )
        out = out[:, model_inputs.input_ids.shape[1]:]
        out = tokenizer.decode(out[0])

        # torch.cuda.empty_cache()

        bracket_count = 1
        idx = 0

        while idx < len(out):

            if out[idx] == '{':
                bracket_count += 1
            if out[idx] == '}':
                bracket_count -= 1
                if bracket_count == 0:
                    break

            idx += 1
        
        answer = '\('+out[:idx]+'\)'
        gt = '\('+datapoint['conversations'][0]['info']['ground_truth']+'\)'
        
        rewards.append(verify(parse(answer),parse(gt)))
        # print(answer, rewards[-1])
        outputs.append(out)

        partial_CoT += '\n\n'

        #print({"answer":answer, "reward": rewards[-1]})

    #print(outputs)
    #print(rewards)

    return outputs, rewards

In [57]:
zipped = get_reward(ds[98])

100%|███████████████████████████████████████████████████████████████████████████████████| 66/66 [02:15<00:00,  2.05s/it]


In [None]:
print(zipped)

In [19]:
from copy import deepcopy
import time

def get_reward_KV_cache(datapoint):

    messages = []

    if datapoint['system'] is None:
        datapoint['system'] = ""
    datapoint['system'] += "The answer should just be your final answer in \\boxed{} without any explanation or reasoning between the <answer> and </answer> tags. Even if you're not sure, only give me the final answer."

    if datapoint['system'] is not None:
        messages.append({'role': 'system', 'content': datapoint['system']})

    messages.append({'role': 'user', 'content': datapoint['conversations'][0]['value']})
    
    CoT = datapoint['conversations'][1]['value']
    CoT = CoT[7:CoT.find('</think>')]

    text = tokenizer.apply_chat_template(
        messages+[{'role': 'assistant', 'content': "<think>"}],
        tokenize=False,
        add_generation_prompt=True
    )
    text=text[:text.rfind('<|im_end|>')]
    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)

    with torch.no_grad():
        past_key_values = model(**model_inputs, use_cache = True).past_key_values
    
    rewards = []
    outputs = []

    tot = 0
    
    for idx, CoT_paragraph in enumerate(tqdm(CoT.split('\n\n')+[""])):

        message_CoT = "</think>\n" + "<answer>\\boxed{"
        new_inputs = tokenizer([message_CoT], return_tensors="pt").to(model.device)

        st = time.perf_counter()
        copied_cache = deepcopy(past_key_values)
        tot += time.perf_counter() - st
        
        out = model.generate(
            input_ids = torch.hstack((model_inputs.input_ids,new_inputs.input_ids)),
            attention_mask = torch.ones((1,past_key_values.layers[0].keys.shape[2]+new_inputs.attention_mask.shape[1]), device=model.device),
            past_key_values = copied_cache,
            max_new_tokens = 40,
            use_cache = True,
            do_sample = False,
        )
        out = out[:, model_inputs.input_ids.shape[1]:]
        out = tokenizer.decode(out[0])
        
        answer = '\('+out[24:out[:out.find('</answer>')].rfind('}')]+'\)'
        gt = '\('+datapoint['conversations'][0]['info']['ground_truth']+'\)'
        
        rewards.append(verify(parse(answer),parse(gt)))
        if '\\boxed' in answer:
            rewards[-1] = False
        outputs.append(out)

        new_text = ("\n\n" if idx else "") + CoT_paragraph

        new_inputs = tokenizer([new_text], return_tensors="pt").to(model.device)

        with torch.no_grad():
            past_key_values = model(**new_inputs, past_key_values = past_key_values, use_cache = True).past_key_values

        model_inputs.input_ids = torch.hstack((model_inputs.input_ids, new_inputs.input_ids))
        text += new_text

        #print({"answer":answer, "reward": rewards[-1]})

    #print(outputs)
    #print(rewards)
    print(tot)
    
    return zip(outputs,rewards)

In [20]:
get_reward_KV_cache(ds[1])

100%|█████████████████████████████████████████████████████████████████████████████████| 177/177 [02:55<00:00,  1.01it/s]

0.8691285038366914





<zip at 0x145489b2a940>

In [49]:
def get_reward_batched(datapoint):

    messages = []

    if datapoint['system'] is None:
        datapoint['system'] = ""
    datapoint['system'] += "The answer should just be your final answer in \\boxed{} without any explanation or reasoning between the <answer> and </answer> tags. Even if you're not sure, only give me the final answer."

    if datapoint['system'] is not None:
        messages.append({'role': 'system', 'content': datapoint['system']})

    messages.append({'role': 'user', 'content': datapoint['conversations'][0]['value']})
    
    CoT = datapoint['conversations'][1]['value']
    CoT = CoT[7:CoT.find('<answer>')]
    partial_CoT = ""
    
    rewards = []
    outputs = []

    texts = []
    
    for idx, CoT_paragraph in enumerate(tqdm(["<think>"]+CoT.split('\n\n')+[""])):

        partial_CoT += CoT_paragraph

        message_CoT = partial_CoT + ("" if "</think>" in partial_CoT else "</think>\n") + "<answer>\\boxed{"
        
        text = tokenizer.apply_chat_template(
            messages+[{'role': 'assistant', 'content': message_CoT}],
            tokenize=False,
            add_generation_prompt=True
        )
        text=text[:text.rfind('<|im_end|>')]
        model_inputs = tokenizer([text], return_tensors="pt").to(model.device)

        if len(texts) == 4 or idx == len(CoT.split('\n\n')) + 1:

            model_inputs = tokenizer(texts, return_tensors="pt", padding = True, padding_side='left').to(model.device)
            
            bout = model.generate(
                **model_inputs,
                max_new_tokens = 60,
                do_sample = False
            )
            bout = bout[:, model_inputs.input_ids.shape[1]:]
    
            for i in range(len(bout)):
                out = tokenizer.decode(bout[i])
        
                bracket_count = 1
                idx = 0
        
                while idx < len(out):
        
                    if out[idx] == '{':
                        bracket_count += 1
                    if out[idx] == '}':
                        bracket_count -= 1
                        if bracket_count == 0:
                            break
        
                    idx += 1
                
                answer = '\('+out[:idx]+'\)'
                gt = '\('+datapoint['conversations'][0]['info']['ground_truth']+'\)'
            
                rewards.append(verify(parse(answer),parse(gt)))
                print(answer, rewards[-1])
                outputs.append(out)

            texts = []

        texts.append(text)

        if idx:
            partial_CoT += '\n\n'

        torch.cuda.empty_cache()

        #print({"answer":answer, "reward": rewards[-1]})

    #print(outputs)
    #print(rewards)

    return rewards

In [50]:
print(get_reward_batched(ds[1]))

  3%|██▎                                                                                | 5/178 [00:00<00:16, 10.55it/s]

\(20\) False
\(25\) True
\(25\) True
\(25\) True


  5%|████▏                                                                              | 9/178 [00:01<00:19,  8.54it/s]

\(25\) True
\(25\) True
\(25\) True
\(25\) True


  7%|█████▉                                                                            | 13/178 [00:01<00:21,  7.58it/s]

\(35\) False
\(35\) False
\(35\) False
\(35\) False


 10%|███████▊                                                                          | 17/178 [00:02<00:23,  6.83it/s]

\(35\) False
\(35\) False
\(25\) True
\(25\) True


 12%|█████████▋                                                                        | 21/178 [00:03<00:26,  5.87it/s]

\(25\) True
\(25\) True
\(25\) True
\(25\) True


 14%|███████████▌                                                                      | 25/178 [00:04<00:31,  4.92it/s]

\(25\) True
\(25\) True
\(25\) True
\(25\) True


 16%|█████████████▎                                                                    | 29/178 [00:05<00:34,  4.32it/s]

\(25\) True
\(25\) True
\(25\) True
\(25\) True


 19%|███████████████▏                                                                  | 33/178 [00:06<00:36,  3.96it/s]

\(25\) True
\(25\) True
\(25\) True
\(25\) True


 21%|█████████████████                                                                 | 37/178 [00:07<00:39,  3.59it/s]

\(25\) True
\(25\) True
\(25\) True
\(25\) True


 23%|██████████████████▉                                                               | 41/178 [00:09<00:41,  3.29it/s]

\(25\) True
\(25\) True
\(25\) True
\(25\) True


 25%|████████████████████▋                                                             | 45/178 [00:11<00:44,  2.98it/s]

\(25\) True
\(25\) True
\(25\) True
\(25\) True


 28%|██████████████████████▌                                                           | 49/178 [00:12<00:46,  2.78it/s]

\(25\) True
\(25\) True
\(25\) True
\(25\) True


 30%|████████████████████████▍                                                         | 53/178 [00:14<00:47,  2.63it/s]

\(25\) True
\(25\) True
\(25\) True
\(25\) True


 32%|██████████████████████████▎                                                       | 57/178 [00:16<00:48,  2.50it/s]

\(25\) True
\(25\) True
\(25\) True
\(25\) True


 34%|████████████████████████████                                                      | 61/178 [00:18<00:49,  2.35it/s]

\(25\) True
\(25\) True
\(25\) True
\(25\) True


 37%|█████████████████████████████▉                                                    | 65/178 [00:20<00:49,  2.26it/s]

\(25\) True
\(25\) True
\(25\) True
\(25\) True


 39%|███████████████████████████████▊                                                  | 69/178 [00:21<00:49,  2.19it/s]

\(25\) True
\(25\) True
\(25\) True
\(25\) True


 41%|█████████████████████████████████▋                                                | 73/178 [00:24<00:50,  2.09it/s]

\(25\) True
\(25\) True
\(25\) True
\(25\) True


 43%|███████████████████████████████████▍                                              | 77/178 [00:26<00:50,  1.98it/s]

\(25\) True
\(25\) True
\(25\) True
\(25\) True


 46%|█████████████████████████████████████▎                                            | 81/178 [00:28<00:50,  1.90it/s]

\(25\) True
\(25\) True
\(25\) True
\(25\) True


 48%|███████████████████████████████████████▏                                          | 85/178 [00:31<00:50,  1.83it/s]

\(25\) True
\(25\) True
\(25\) True
\(25\) True


 50%|█████████████████████████████████████████                                         | 89/178 [00:33<00:50,  1.75it/s]

\(25\) True
\(25\) True
\(25\) True
\(25\) True


 52%|██████████████████████████████████████████▊                                       | 93/178 [00:36<00:50,  1.70it/s]

\(25\) True
\(25\) True
\(25\) True
\(25\) True


 54%|████████████████████████████████████████████▋                                     | 97/178 [00:38<00:49,  1.65it/s]

\(25\) True
\(25\) True
\(25\) True
\(25\) True


 57%|█████████████████████████████████████████████▉                                   | 101/178 [00:41<00:47,  1.63it/s]

\(25\) True
\(25\) True
\(25\) True
\(25\) True


 59%|███████████████████████████████████████████████▊                                 | 105/178 [00:43<00:45,  1.60it/s]

\(25\) True
\(25\) True
\(25\) True
\(25\) True


 61%|█████████████████████████████████████████████████▌                               | 109/178 [00:46<00:44,  1.54it/s]

\(25\) True
\(25\) True
\(25\) True
\(25\) True


 63%|███████████████████████████████████████████████████▍                             | 113/178 [00:49<00:43,  1.49it/s]

\(25\) True
\(25\) True
\(25\) True
\(25\) True


 66%|█████████████████████████████████████████████████████▏                           | 117/178 [00:52<00:41,  1.45it/s]

\(25\) True
\(25\) True
\(25\) True
\(25\) True


 68%|███████████████████████████████████████████████████████                          | 121/178 [00:55<00:40,  1.41it/s]

\(25\) True
\(25\) True
\(25\) True
\(25\) True


 70%|████████████████████████████████████████████████████████▉                        | 125/178 [00:58<00:39,  1.35it/s]

\(25\) True
\(25\) True
\(25\) True
\(25\) True


 72%|██████████████████████████████████████████████████████████▋                      | 129/178 [01:02<00:38,  1.26it/s]

\(25\) True
\(25\) True
\(25\) True
\(25\) True


 75%|████████████████████████████████████████████████████████████▌                    | 133/178 [01:06<00:37,  1.20it/s]

\(25\) True
\(25\) True
\(25\) True
\(25\) True


 77%|██████████████████████████████████████████████████████████████▎                  | 137/178 [01:10<00:36,  1.13it/s]

\(25\) True
\(25\) True
\(25\) True
\(25\) True


 79%|████████████████████████████████████████████████████████████████▏                | 141/178 [01:14<00:34,  1.08it/s]

\(25\) True
\(25\) True
\(25\) True
\(25\) True


 81%|█████████████████████████████████████████████████████████████████▉               | 145/178 [01:18<00:32,  1.02it/s]

\(25\) True
\(25\) True
\(25\) True
\(25\) True


 84%|███████████████████████████████████████████████████████████████████▊             | 149/178 [01:23<00:29,  1.03s/it]

\(25\) True
\(25\) True
\(25\) True
\(25\) True


 86%|█████████████████████████████████████████████████████████████████████▌           | 153/178 [01:28<00:27,  1.09s/it]

\(25\) True
\(25\) True
\(25\) True
\(25\) True


 88%|███████████████████████████████████████████████████████████████████████▍         | 157/178 [01:33<00:24,  1.15s/it]

\(25\) True
\(25\) True
\(25\) True
\(25\) True


 90%|█████████████████████████████████████████████████████████████████████████▎       | 161/178 [01:38<00:20,  1.19s/it]

\(25\) True
\(25\) True
\(25\) True
\(25\) True


 93%|███████████████████████████████████████████████████████████████████████████      | 165/178 [01:43<00:15,  1.23s/it]

\(25\) True
\(25\) True
\(25\) True
\(25\) True


 95%|████████████████████████████████████████████████████████████████████████████▉    | 169/178 [01:48<00:11,  1.26s/it]

\(25\) True
\(25\) True
\(25\) True
\(25\) True


 97%|██████████████████████████████████████████████████████████████████████████████▋  | 173/178 [01:54<00:06,  1.29s/it]

\(25\) True
\(25\) True
\(25\) True
\(25\) True


 99%|████████████████████████████████████████████████████████████████████████████████▌| 177/178 [02:00<00:01,  1.33s/it]

\(25\) True
\(25\) True
\(25\) True
\(25\) True


100%|█████████████████████████████████████████████████████████████████████████████████| 178/178 [02:01<00:00,  1.47it/s]

\(25\) True
[False, True, True, True, True, True, True, True, False, False, False, False, False, False, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, Tr




In [None]:
inc = 0
for i in tqdm(range(len(ds))):
    get_reward(ds[i])
    #if not get_reward(ds[i]):
    #    inc += 1
    #    print(i)
    #if i%100==99:
    #    print("current ratio:",inc/(i+1))

  0%|                                                                            | 1/558129 [00:02<458:57:08,  2.96s/it]

In [5]:
print("tokenize <answer>:", tokenizer.tokenize("<answer>"))
print("ids for <answer>:", tokenizer.convert_tokens_to_ids(tokenizer.tokenize("<answer>")))
print("single-id? ->", tokenizer.convert_tokens_to_ids("<answer>") if "<answer>" in tokenizer.get_vocab() else None)

tokenize <answer>: ['<', 'answer', '>']
ids for <answer>: [27, 9217, 29]
single-id? -> None


In [70]:
print({"text":get_reward(ds[14])})

{'text': '<|im_start|>system\nYou are a helpful assistant. To answer the user\'s question, you first think about the reasoning process and then provide the user with the answer. The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think> <answer> answer here </answer>.<|im_end|>\n<|im_start|>user\n5. All three-digit numbers from 100 to 999 are written in a row without spaces. Kostya underlined \\( k \\) consecutive digits in this sequence, and Andrey underlined other \\( k \\) consecutive digits in this sequence. It turned out that the \\( k \\)-digit numbers underlined by the boys are equal. For what largest \\( k \\) could this have happened?<|im_end|>\n<|im_start|>assistant\n<think>Okay, let\'s try to figure out this problem. The question is about three-digit numbers from 100 to 999 written in a row, so the sequence starts with 100101102...999. We need to find the largest k such that t

In [9]:
def compute_lengths(datapoint, stopping_point):

    CoT = datapoint['conversations'][1]['value']
    CoT = CoT[7:CoT.find('</think>')]
    partial_CoT = ""

    stopping_length = None
    total_length = None

    for idx, CoT_paragraph in enumerate(["<think>"]+CoT.split('\n\n')):

        partial_CoT += CoT_paragraph

        if idx == stopping_point:
            stopping_length = tokenizer([partial_CoT + "</think>\n"], return_tensors = "pt").input_ids.shape[1]

        if idx == len(CoT.split('\n\n')):
            total_length = tokenizer([partial_CoT + "</think>\n"], return_tensors = "pt").input_ids.shape[1]

        partial_CoT += '\n\n'

    return stopping_length, total_length

import json

tot_shortened = 0
tot_orig = 0

with open("rewards.jsonl", "r") as f:
    for line in tqdm(f):
        if not line.strip():
            continue
        entry = json.loads(line)
        rewards = entry["rewards"]
        if not rewards[-1]:
            continue
        stopping_length, total_length = compute_lengths(ds[entry["id"]], rewards.index(True))
        tot_shortened += stopping_length
        tot_orig += total_length

print(tot_shortened / tot_orig)

1156it [00:07, 164.29it/s]

0.3407935263992418





In [16]:
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from math_verify import parse, verify
from tqdm import tqdm
import torch.multiprocessing as mp
from worker_utils import worker_process

mp.set_start_method('spawn', force=True)

num_gpus = torch.cuda.device_count()

task_queue = mp.Queue()
results_queue = mp.Queue()
processes = []

for gpu_id in range(num_gpus):
    p = mp.Process(target=worker_process, args=(gpu_id, task_queue, results_queue))
    processes.append(p)
    p.start()

print(f"Main process: Adding {len(dataset)} tasks to the queue...")
for idx,item in enumerate(ds):
    task_queue.put((idx,item))
    if idx>=4:
        break

for _ in range(num_gpus):
    task_queue.put(None)

for i in range(len(dataset)):
    idx, rewards = results_queue.get()

    print(rewards)

here
Main process: Adding 2 tasks to the queue...
Worker process started for cuda:1
Worker process started for cuda:0


Loading checkpoint shards: 100%|██████████| 14/14 [00:01<00:00, 12.07it/s]
Loading checkpoint shards: 100%|██████████| 14/14 [00:01<00:00, 11.95it/s]
The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
100%|██████████| 11/11 [00:12<00:00,  1.11s/it]
  0%|          | 0/9 [00:00<?, ?it/s]  1.04s/it]

(['25}</answer><|im_end|>', '25}</answer><|im_end|>', '25}</answer><|im_end|>', '25}\\end{answer><|im_end|>', '25}</answer><|im_end|>', '25}</answer><|im_end|>', '25}\\</answer><|im_end|>', '25}</answer><|im_end|>', '25}</answer><|im_end|>', '25}</answer><|im_end|>', '25}</answer><|im_end|>'], [True, True, True, True, True, True, True, True, True, True, True])


100%|██████████| 9/9 [00:10<00:00,  1.12s/it]it]


(['18}</answer><|im_end|>', '18}\\</answer><|im_end|>', '16}</answer><|im_end|>', '18}\\</answer><|im_end|>', '20}</answer><|im_end|>', '20}</answer><|im_end|>', '20}</answer><|im_end|>', '20}</answer><|im_end|>', '20}</answer><|im_end|>'], [False, False, False, False, True, True, True, True, True])


 62%|██████▏   | 58/94 [03:05<02:05,  3.49s/it]t]