In [1]:
import transformers
import torch
import torch.nn as nn
import torch.nn.functional as F
from IPython.display import display, Markdown, Latex
from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import load_dataset
from tqdm.notebook import tqdm

from util import reconstruct_prompt

device = "cuda"
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x795138924830>

In [2]:
model = AutoModelForCausalLM.from_pretrained("deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B").to(device)
tokenizer = AutoTokenizer.from_pretrained("deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B")

dataset = load_dataset("immindich/reasoning-steps-labeled")

Sliding Window Attention is enabled but not implemented for `sdpa`; unexpected results may be encountered.


In [3]:
import random

def get_step_data(key):
    sample_size = 80
    step_prompts = []
    steps = []
    step_lengths = []

    for problem in dataset["train"].shuffle(seed=42):
        if problem[key]:
            step = random.choice(problem[key])
            prompt = tokenizer.encode(reconstruct_prompt(tokenizer, problem, step), return_tensors="pt", add_special_tokens=False).to(device)
            step_prompts.append(prompt)
            step_lengths.append(step)
            steps.append(tokenizer.encode(problem["steps"][step]))

        if len(step_prompts) >= sample_size:
            break

    return step_prompts, steps, step_lengths

conclusion_step_prompts, conclusion_steps, conclusion_step_lengths = get_step_data("conclusion_steps")
verification_step_prompts, verification_steps, verification_step_lengths = get_step_data("verification_steps")

print(len(conclusion_step_prompts))
print(len(verification_step_prompts))

80
80


In [6]:
def generate_completions(prompts, batch_size=2, num_batches=5):
    completions = []
    for prompt in tqdm(prompts):
        # Generate 15 sequences in 3 batches of 5
        max_length = 10

        prompt_completions = []
        
        for batch in range(num_batches):
            outputs = model.generate(
                input_ids=prompt,
                max_new_tokens=max_length,
                num_return_sequences=batch_size,
                pad_token_id=tokenizer.eos_token_id
            )
            
            for output in outputs:
                generated_text = tokenizer.decode(output[prompt.shape[1]:], skip_special_tokens=True)
                prompt_completions.append(generated_text)

        completions.append(prompt_completions)

    return completions


In [7]:
verification_completions = generate_completions(verification_step_prompts)
conclusion_completions = generate_completions(conclusion_step_prompts, batch_size=1, num_batches=10)

  0%|          | 0/80 [00:00<?, ?it/s]

The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.


  0%|          | 0/80 [00:00<?, ?it/s]

In [10]:
print(conclusion_completions)

[['Hmm, that seems correct. Let me try another', "Yep, that seems correct. I don't think", "I think that's correct. It's always good", "Yes, that seems correct. I don't think", "It all seems to add up. I don't", "Yes, that all makes sense. I don't", "Yeah, that seems correct. I don't think", "I think that's correct. It's always good", "So, I think that's correct. It's", "Yes, that all seems correct. I don't"], ["I think I'm confident that the slope is \\(", "I think I'm confident with this answer now.\n\n", "I think that's thorough. I can't find", 'Just to be thorough, let me plug in another', "Therefore, I'm confident that the ordered pair is", "I think I'm confident that the slope-intercept", 'Wait, but let me just confirm the arithmetic one', "I think I've verified it multiple ways, so", "I think I'm confident now that the slope is", 'Just to be thorough, let me plug in another'], ['**Final Answer**\nThe number is \\boxed{', '**Final Answer**\nThe number is \\boxed{', '**Final Answe

In [11]:
from anthropic.types.message_create_params import MessageCreateParamsNonStreaming
from anthropic.types.messages.batch_create_params import Request
def create_batch_requests(prompt_completions):
    # create batch of requests
    # Create a batch of requests for all problems in the dataset
    batch_requests = []
    for i, completions in enumerate(prompt_completions):
        for j, completion in enumerate(completions):            
            user_prompt = f"""
            Attached to this prompt is the output of a language model which reasons step by step by step. The output given to you is the beginning of one of the reasoning steps. Classify the output into one of the following categories.

            1. The model is concluding that it has figured out the final answer. In this case, the output should start with something like "I think I'm confident" or "Therefore the answer is" or "I don't see any mistakes here".

            2. The model is verifying its work. In this case, the output should start with something like "Wait, but let me make sure," or "Let me double-check," or "Wait, to make sure".

            3. The output does not seem to fit into the first two categories, or it is not clear what the model is doing.
            
            Do not respond with anything other than the number of the category. Note that the output may not be a complete sentence, because the model was only asked to generate the first 10 tokens.
            
            Here is the output of the language model:
            
            {completion}
            """
            
            batch_requests.append(Request(
                custom_id = f"{i}_{j}",
                params=MessageCreateParamsNonStreaming(
                    model="claude-3-7-sonnet-20250219",
                    max_tokens=100,
                    temperature=1,
                    system="You are a helpful assistant labeling data generated by a language model",
                    messages=[
                        {
                            "role": "user",
                            "content": [
                                {
                                    "type": "text",
                                    "text": user_prompt
                                }
                            ]
                        }
                    ]
                )
            ))

    return batch_requests

requests_verification = create_batch_requests(verification_completions)
requests_conclusion = create_batch_requests(conclusion_completions)



In [12]:
import dotenv
import anthropic

dotenv.load_dotenv()
import os

client = anthropic.Anthropic(
    # defaults to os.environ.get("ANTHROPIC_API_KEY")
    api_key=os.environ.get("ANTHROPIC_API_KEY"),
)

In [13]:
batch_verification = client.messages.batches.create(requests=requests_verification)
batch_conclusion = client.messages.batches.create(requests=requests_conclusion)

print(batch_verification)
print(batch_conclusion)

MessageBatch(id='msgbatch_013iB5mmeyicWvqu1pSzEqRs', archived_at=None, cancel_initiated_at=None, created_at=datetime.datetime(2025, 3, 9, 5, 26, 9, 164344, tzinfo=datetime.timezone.utc), ended_at=None, expires_at=datetime.datetime(2025, 3, 10, 5, 26, 9, 164344, tzinfo=datetime.timezone.utc), processing_status='in_progress', request_counts=MessageBatchRequestCounts(canceled=0, errored=0, expired=0, processing=800, succeeded=0), results_url=None, type='message_batch')
MessageBatch(id='msgbatch_01493aRvkd99cxA6WgmyKCxj', archived_at=None, cancel_initiated_at=None, created_at=datetime.datetime(2025, 3, 9, 5, 26, 9, 874414, tzinfo=datetime.timezone.utc), ended_at=None, expires_at=datetime.datetime(2025, 3, 10, 5, 26, 9, 874414, tzinfo=datetime.timezone.utc), processing_status='in_progress', request_counts=MessageBatchRequestCounts(canceled=0, errored=0, expired=0, processing=800, succeeded=0), results_url=None, type='message_batch')


In [14]:
results_conclusion = list(client.messages.batches.results(
    batch_conclusion.id,
))

results_verification = list(client.messages.batches.results(
    batch_verification.id,
))


In [18]:
conclusion_numbers = list(map(lambda x: int(x.result.message.content[0].text), results_conclusion))
verification_numbers = list(map(lambda x: int(x.result.message.content[0].text), results_verification))

def proportion_equal(numbers, value):
    # return the proportion of numbers that are equal to value
    return sum(1 for number in numbers if number == value) / len(numbers)

print(proportion_equal(conclusion_numbers, 1))
print(proportion_equal(verification_numbers, 1))

0.5475
0.15125
