In [1]:
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import re  # I see re is used in check_answer but not imported yet

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from datasets import load_dataset, concatenate_datasets
# Load the "train" split of the Duruo/quant_qa question-answering dataset.
dataset = load_dataset(
    path="Duruo/quant_qa",
    split="train",
)

In [17]:
# Delimiter tags the model is told to wrap its reasoning and its final answer in.
reasoning_start, reasoning_end = "<start_working_out>", "<end_working_out>"
solution_start, solution_end = "<SOLUTION>", "</SOLUTION>"

# System prompt: one instruction per line, joined with newlines.
system_prompt = "\n".join((
    "You are given a statistical reasoning problem.",
    f"Keep your reasoning concise and focused between {reasoning_start} and {reasoning_end}.",
    f"IMPORTANT: Always end with your final answer between {solution_start} and {solution_end}.",
    "Your answer should be a single number, probability, or short phrase.",
    "Do not exceed 300 words total in your response.",
))

In [18]:
# Prepare data for GRPO training
def prepare_data(example):
    """Turn one raw dataset row into a chat-style prompt plus its reference answer.

    Args:
        example: Row mapping that provides the question under 'Problem' and the
            reference answer under 'correct'.

    Returns:
        dict with two keys: 'prompt', a two-message list (the system prompt
        followed by the user question), and 'answer', the reference answer
        passed through unchanged.
    """
    question, reference = example['Problem'], example['correct']

    system_message = {"role": "system", "content": system_prompt}
    user_message = {"role": "user", "content": question}

    return {"prompt": [system_message, user_message], "answer": reference}

# Run prepare_data over every row of the loaded dataset.
prepared_dataset = dataset.map(function=prepare_data)

Map: 100%|██████████| 79/79 [00:00<00:00, 5806.74 examples/s]


In [19]:
# Checkpoint name shared by the tokenizer and the model.
model_id = "meta-llama/Llama-3.1-8B-Instruct"

tokenizer = AutoTokenizer.from_pretrained(model_id)

# bfloat16 weights for efficiency; device_map="auto" uses whatever GPUs are available.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)

Loading checkpoint shards: 100%|██████████| 4/4 [00:00<00:00, 93.72it/s]


In [20]:
def generate_response(prompt_messages, max_new_tokens=512, temperature=0.1):
    """Generate the model's reply to a chat-style prompt.

    Args:
        prompt_messages: List of {"role": ..., "content": ...} messages, rendered
            with the tokenizer's chat template.
        max_new_tokens: Upper bound on the number of generated tokens.
        temperature: Sampling temperature. A falsy value (0 or None) switches
            to greedy decoding instead of sampling.

    Returns:
        str: The decoded text of the newly generated tokens only (the prompt is
        sliced off), with special tokens stripped.
    """
    # Render the conversation to text and append the assistant header.
    formatted_prompt = tokenizer.apply_chat_template(
        prompt_messages,
        tokenize=False,
        add_generation_prompt=True,
    )

    # The rendered template already carries its own special tokens, so they must
    # not be added a second time here (otherwise the prompt starts with two BOS
    # tokens) -- see the Hugging Face chat templating guide.
    inputs = tokenizer(
        formatted_prompt,
        return_tensors="pt",
        add_special_tokens=False,
    ).to(model.device)

    # Fall back to EOS as the padding token when the tokenizer defines none.
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token_id = tokenizer.eos_token_id

    generation_kwargs = {
        "max_new_tokens": max_new_tokens,
        "pad_token_id": tokenizer.pad_token_id,
    }
    if temperature:
        generation_kwargs.update(do_sample=True, temperature=temperature)
    else:
        # Sampling with a zero temperature is invalid; decode greedily instead.
        generation_kwargs["do_sample"] = False

    with torch.no_grad():
        outputs = model.generate(
            inputs.input_ids,
            attention_mask=inputs.attention_mask,
            **generation_kwargs,
        )

    # generate() returns prompt + continuation; keep only the continuation.
    prompt_length = inputs.input_ids.shape[1]
    return tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)

In [21]:
# Function to check the answer
def try_convert_to_float(s):
    """Parse a numeric answer into a float, returning None when it is not numeric.

    Accepted forms:
        * plain numbers, e.g. "0.25", "-3", "1e-3"
        * percentages, e.g. "25%" -> 0.25
        * simple fractions, e.g. "121/441" (also "50/100%")

    Args:
        s: The text to parse. Non-string values are passed to float() directly.

    Returns:
        float | None: The parsed value, or None if the input cannot be
        interpreted as a number (including division by zero in a fraction).
    """
    if not isinstance(s, str):
        try:
            return float(s)
        except (TypeError, ValueError):
            return None

    text = s.strip()

    # A trailing '%' means the value is expressed in hundredths.
    divisor = 1.0
    if text.endswith('%'):
        text = text.rstrip('%')
        divisor = 100.0

    # partition() leaves `slash` empty when the text holds no fraction bar.
    numerator, slash, denominator = text.partition('/')
    try:
        value = float(numerator)
        if slash:
            value /= float(denominator)
    except (ValueError, ZeroDivisionError):
        return None

    return value / divisor
    
def _contains_answer(text, candidate):
    """Return True if `candidate` occurs in `text` and is not a fragment of a longer number."""
    if not candidate:
        return False
    # Guard only the edges that are digits: "1" must not match inside "0.15" or
    # "121/441", while textual answers keep plain containment semantics.
    lead = r"(?<![\d./])" if candidate[0].isdigit() else ""
    trail = r"(?!\d|[./]\d)" if candidate[-1].isdigit() else ""
    return re.search(lead + re.escape(candidate) + trail, text) is not None


def _matches_reference(guess, true_answer):
    """Return True if one extracted guess agrees with any comma-separated reference alternative."""
    if guess is None:
        return False  # no solution section was found in the response

    guess_lower = guess.lower()
    guess_value = try_convert_to_float(guess)

    # The reference may list several acceptable formats separated by commas.
    for alt in (part.strip() for part in str(true_answer).split(',')):
        if not alt:
            continue  # an empty alternative would otherwise match every guess

        # Textual match (case-insensitive).
        if _contains_answer(guess_lower, alt.lower()):
            return True

        # Numerical match with an absolute tolerance of 0.01.
        alt_value = try_convert_to_float(alt)
        if guess_value is not None and alt_value is not None:
            if abs(guess_value - alt_value) <= 0.01:
                return True

    return False


def check_answer(prompts, completions, answer, **kwargs):
    """Check model completions against their reference answers.

    The final answer is taken from the first section delimited by
    `solution_start` / `solution_end` in each completion and compared with the
    reference, first textually and then numerically.

    Args:
        prompts: List of chat message lists; only the last message of the first
            prompt is used, for the debug print.
        completions: List of completions; each is a list whose first element is
            a dict holding the generated text under "content".
        answer: List of reference answers, aligned with `completions`. A
            reference may contain several comma-separated alternatives.
        **kwargs: Accepted and ignored.

    Returns:
        bool: True when every completion matches its reference; False when any
        does not, when no solution section is found, or when there is nothing
        to check.
    """
    responses = [completion[0]["content"] for completion in completions]

    # Escape the tags so they are always matched literally.
    match_solution = re.compile(
        rf"{re.escape(solution_start)}(.*?){re.escape(solution_end)}",
        flags=re.DOTALL,
    )

    extracted_responses = [
        solution.group(1).strip()
        if (solution := match_solution.search(r)) is not None else None
        for r in responses
    ]

    # Print first example to help debug
    if len(responses) > 0:
        print('*'*20, f"\nQuestion:\n{prompts[0][-1]['content']}", 
              f"\nAnswer:\n{answer[0]}", 
              f"\nResponse:\n{responses[0]}", 
              f"\nExtracted:\n{extracted_responses[0]}")

    # Judge every item independently so one item's verdict cannot leak into the next.
    verdicts = [
        _matches_reference(guess, true_answer)
        for guess, true_answer in zip(extracted_responses, answer)
    ]
    return bool(verdicts) and all(verdicts)

In [22]:
# Evaluate a subset of the dataset for testing
num_samples = 20  # Adjust as needed

# Clamp to the dataset size so select() never receives out-of-range indices.
test_subset = prepared_dataset.select(
    range(min(num_samples, len(prepared_dataset)))
)

# Collect one record per evaluated example.
results = []
for example in test_subset:
    # Generate response
    response = generate_response(example["prompt"])

    # check_answer takes batched arguments, so wrap the single example in lists.
    is_correct = check_answer(
        [example["prompt"]],
        [[{"content": response}]],
        [example["answer"]],
    )

    results.append({
        "question": example["prompt"][-1]["content"],
        "true_answer": example["answer"],
        "model_response": response,
        "is_correct": is_correct,
    })

# Accuracy over the examples actually evaluated (0.0 when there are none).
correct_count = sum(1 for record in results if record["is_correct"])
num_evaluated = len(results)
accuracy = correct_count / num_evaluated if num_evaluated else 0.0
print(f"Accuracy: {accuracy:.2%} ({correct_count}/{num_evaluated})")

******************** 
Question:
Suppose that two integers a and b are uniformly at random selected from S={-10, -9, ..., 9, 10}. Find the probability that max(0,a) = min(0,b) 
Answer:
121/441 
Response:
To find the probability that max(0,a) = min(0,b), we need to consider the cases where both a and b are non-negative or both are non-positive.

There are 21 integers in the set S = {-10, -9,..., 9, 10}. 

Case 1: Both a and b are non-negative. 
There are 11 non-negative integers in the set S. The number of ways to choose two non-negative integers is 11 * 11 = 121.

Case 2: Both a and b are non-positive. 
There are 10 non-positive integers in the set S (excluding 0). The number of ways to choose two non-positive integers is 10 * 10 = 100.

However, we counted the case where both a and b are 0 twice. We need to subtract this case once. There is only 1 way to choose both a and b as 0.

Total number of favorable outcomes = 121 + 100 - 1 = 220.

Total number of possible outcomes = 21 * 21 = 4

In [None]:
# tqdm.auto picks the notebook widget when available and falls back to the
# console progress bar otherwise.
from tqdm.auto import tqdm

# Evaluate a subset of the dataset for testing
num_samples = 20  # Adjust as needed

# Clamp to the dataset size so select() never receives out-of-range indices.
test_subset = prepared_dataset.select(
    range(min(num_samples, len(prepared_dataset)))
)

# Collect one record per evaluated example, with a progress bar around the loop.
results = []
for example in tqdm(test_subset, desc="Evaluating examples", total=len(test_subset)):
    # Generate response
    response = generate_response(example["prompt"])

    # check_answer takes batched arguments, so wrap the single example in lists.
    is_correct = check_answer(
        [example["prompt"]],
        [[{"content": response}]],
        [example["answer"]],
    )

    results.append({
        "question": example["prompt"][-1]["content"],
        "true_answer": example["answer"],
        "model_response": response,
        "is_correct": is_correct,
    })

# Accuracy over the examples actually evaluated (0.0 when there are none).
correct_count = sum(1 for record in results if record["is_correct"])
num_evaluated = len(results)
accuracy = correct_count / num_evaluated if num_evaluated else 0.0
print(f"Accuracy: {accuracy:.2%} ({correct_count}/{num_evaluated})")

In [2]:
import numpy as np
# Sanity check: a flat list of three scalars becomes a 1-D array of length 3.
np.asarray([1, 2, 3]).shape

(3,)

In [7]:
import numpy as np

# Placeholder inputs -- replace these with your actual gold labels and preds.
gold = np.array([0] * 5)                  # class index per example, shape (5,)
preds = np.array([
    [0.60, 0.40],
    [0.65, 0.35],
    [0.60, 0.40],
    [0.75, 0.25],
    [0.75, 0.25],
])                                        # predicted probabilities, shape (5, 2)

# Indexing the 2x2 identity matrix by label yields one one-hot row per example.
gold_one_hot = np.eye(2)[gold]

# Brier score: squared error summed over the classes, then averaged over the examples.
brier = np.square(preds - gold_one_hot).sum(axis=1).mean()
print(brier)  # should match the harness’s 0.227


0.22700000000000004


In [8]:
# Show the one-hot encoded gold labels (one row per example).
print(gold_one_hot)

[[1. 0.]
 [1. 0.]
 [1. 0.]
 [1. 0.]
 [1. 0.]]
