In [1]:
# Load the converted QuantQA dataset
from datasets import load_dataset, Dataset
import pandas as pd
import re
import json

# The converted QuantQA problems were saved to disk as a Hugging Face `Dataset`
# by the previous conversion step; read them back from that folder.
QUANTQA_DATASET_DIR = "quant_qa_dataset"
dataset = Dataset.load_from_disk(QUANTQA_DATASET_DIR)


  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Report how many problems were loaded, then show one raw record.
print(f"Dataset size: {len(dataset)} problems", "\nExample structure:", sep="\n")
print(dataset[0])


Dataset size: 78 problems

Example structure:
{'id': '3 Die Minimum', 'Problem': 'Suppose you roll three fair 100-sided die. What is the expected value of the lowest roll?', 'Rationale': "Let's call the value of the minimum of the three die X. Using expectation by summation of survival we get:\n\nE[X] = ∑(k=1 to 100) P(X ≥ k)\n\nThe probability that X is at least k is ((100-k+1)^3)/(100^3) so,\n\nE[X] = (100^3)/(100^3) + (99^3)/(100^3) + (98^3)/(100^3) + ... + (1^3)/(100^3)\n     = (1)/(100^3) · (100^3 + 99^3 + 98^3 + ... + 1^3)\n\nThe sum of cubes from 1 to n can be rewritten to the square of the sum from 1 to n thus,\n\nE[X] = ((1 + 2 + 3 + ... + 100)^2)/(100^3)\n    ⟹ ∑(k=1 to 100) k = 100 · (1 + 100)/2 = 5050\n    ⟹ E[X] = (5050^2)/(100^3) = 25.5025", 'correct': '25.5025', 'annotated_formula': '', 'linear_formula': '', 'category': 'probability'}


In [5]:
# Tags the model must emit around its reasoning and around its final answer.
reasoning_start, reasoning_end = "<start_working_out>", "<end_working_out>"
solution_start, solution_end = "<SOLUTION>", "</SOLUTION>"

# System message shared by every training prompt; one instruction per line,
# with the tag names spliced in so the prompt always agrees with the tags above.
system_prompt = "\n".join([
    "You are given a statistical reasoning problem.",
    "Think about the problem and provide your working out step by step.",
    f"Place your reasoning between {reasoning_start} and {reasoning_end}.",
    f"Then, provide your final answer between {solution_start} and {solution_end}.",
    "Make sure your final answer is clear and concise.",
])

# Prepare data for GRPO training
def prepare_data(example):
    """Turn one QuantQA record into the ``{"prompt", "answer"}`` pair used for GRPO.

    The prompt is a two-message chat: the shared ``system_prompt`` as the system
    turn, then the record's ``'Problem'`` text as the user turn. The record's
    ``'correct'`` field is passed through unchanged as ``answer``.
    """
    problem_text, reference_answer = example['Problem'], example['correct']
    chat = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": problem_text},
    ]
    return {"prompt": chat, "answer": reference_answer}

# Convert every record with prepare_data; the original columns are kept next to
# the new "prompt" and "answer" fields.
prepared_dataset = dataset.map(prepare_data)

# Show the first converted record as a sanity check.
preview_title = "\nExample of prepared data:"
print(preview_title)
print(prepared_dataset[0])

# Regex behind the "exact format" reward: the whole completion, ignoring
# leading/trailing whitespace, must be the reasoning block followed by the
# solution block. Group 1 captures the text between the solution tags.
# - re.escape keeps the tags literal even if they contain regex metacharacters.
# - Only re.DOTALL is used (so `.` spans newlines). re.MULTILINE is deliberately
#   NOT used: with it, ^ and $ also match at internal line breaks, so extra text
#   on lines before the reasoning block or after the solution block would still
#   be accepted as an "exact" match.
match_format = re.compile(
    rf"^\s*"
    rf"{re.escape(reasoning_start)}.+?{re.escape(reasoning_end)}.*?"
    rf"{re.escape(solution_start)}(.+?){re.escape(solution_end)}"
    rf"\s*$",
    flags=re.DOTALL,
)


Map: 100%|██████████| 78/78 [00:00<00:00, 8456.26 examples/s]


Example of prepared data:
{'id': '3 Die Minimum', 'Problem': 'Suppose you roll three fair 100-sided die. What is the expected value of the lowest roll?', 'Rationale': "Let's call the value of the minimum of the three die X. Using expectation by summation of survival we get:\n\nE[X] = ∑(k=1 to 100) P(X ≥ k)\n\nThe probability that X is at least k is ((100-k+1)^3)/(100^3) so,\n\nE[X] = (100^3)/(100^3) + (99^3)/(100^3) + (98^3)/(100^3) + ... + (1^3)/(100^3)\n     = (1)/(100^3) · (100^3 + 99^3 + 98^3 + ... + 1^3)\n\nThe sum of cubes from 1 to n can be rewritten to the square of the sum from 1 to n thus,\n\nE[X] = ((1 + 2 + 3 + ... + 100)^2)/(100^3)\n    ⟹ ∑(k=1 to 100) k = 100 · (1 + 100)/2 = 5050\n    ⟹ E[X] = (5050^2)/(100^3) = 25.5025", 'correct': '25.5025', 'annotated_formula': '', 'linear_formula': '', 'category': 'probability', 'prompt': [{'content': 'You are given a statistical reasoning problem.\nThink about the problem and provide your working out step by step.\nPlace your reason




In [7]:
# Define reward functions for GRPO

# Function to check if the format is exactly matched
def match_format_exactly(completions, **kwargs):
    """Reward 3.0 for each completion whose text matches ``match_format``, else 0.

    Each completion's text is read from ``completion[0]["content"]``; extra
    keyword arguments are accepted and ignored. Returns one score per
    completion, in input order.
    """
    return [
        3.0 if match_format.search(item[0]["content"]) is not None else 0
        for item in completions
    ]

# Function to check if the format is approximately matched
def match_format_approximately(completions, **kwargs):
    """Partial-credit format reward based on how often each special tag appears.

    For every completion, each of the four tags (reasoning start/end, solution
    start/end) adds +0.5 when it occurs exactly once and -0.5 otherwise
    (missing or repeated), so each score lies between -2.0 and 2.0.
    Extra keyword arguments are accepted and ignored.
    """
    rewards = []
    for item in completions:
        text = item[0]["content"]
        tag_scores = (
            0.5 if text.count(tag) == 1 else -0.5
            for tag in (reasoning_start, reasoning_end, solution_start, solution_end)
        )
        rewards.append(sum(tag_scores))
    return rewards


Saving the dataset (1/1 shards): 100%|██████████| 78/78 [00:00<00:00, 18757.85 examples/s]


Prepared dataset saved to quantqa_grpo_dataset/

Testing format regex:
Match found: True
Extracted answer: 42





In [1]:
def _try_convert_to_float(s):
    """Parse ``s`` as a float, treating a trailing ``%`` as a percentage.

    Returns ``None`` when ``s`` is not numeric: ``"25%"`` -> ``0.25``,
    ``"0.25"`` -> ``0.25``, ``"1/4"`` -> ``None``.
    """
    s = s.strip()
    try:
        if s.endswith('%'):
            return float(s.rstrip('%')) / 100
        return float(s)
    except ValueError:
        return None


def _contains_answer_token(guess, alt):
    """Return True if ``alt`` occurs in ``guess`` as a standalone token (case-insensitive).

    Unlike a plain substring test, "2" does not match inside "25" or "0.2",
    "0.5" does not match inside "-0.5", and "no" does not match inside "not",
    while "25.5025" still matches inside "E[X] = 25.5025" and "1/4" inside "1/4.".
    """
    # Not preceded by a word char, '.' or '-'; not followed by a word char or by
    # a decimal continuation like ".5" (a plain sentence-ending '.' is fine).
    pattern = rf"(?<![\w.\-]){re.escape(alt.lower())}(?!\w|\.\d)"
    return re.search(pattern, guess.lower()) is not None


# Function to check the answer
def check_answer(prompts, completions, answer, **kwargs):
    """GRPO reward: 3.0 if the solution section matches the reference, else -1.0.

    Args:
        prompts: Batch of chat prompts; only ``prompts[0][-1]['content']`` is
            read, for the debug print.
        completions: Batch of completions; each text is read from
            ``completion[0]["content"]``.
        answer: Reference answers aligned with ``completions``. A reference may
            list comma-separated alternatives (e.g. "0.25, 25%"), any of which
            counts as correct.
        **kwargs: Any extra arguments; ignored.

    Returns:
        One float per completion: -1.0 when no solution section is found or the
        answer is wrong, 3.0 when it matches.

    A guess matches an alternative when the alternative appears in it as a
    standalone token (case-insensitive), or when both parse as numbers
    (``%`` suffix = percentage) within an absolute tolerance of 0.01.
    """
    responses = [completion[0]["content"] for completion in completions]

    # Text between the solution tags; re.escape keeps the tags literal.
    match_solution = re.compile(
        rf"{re.escape(solution_start)}(.*?){re.escape(solution_end)}",
        flags = re.MULTILINE | re.DOTALL
    )

    extracted_responses = [
        solution.group(1).strip()
        if (solution := match_solution.search(r)) is not None else None
        for r in responses
    ]

    # Print first example to help debug
    if len(responses) > 0:
        print('*'*20, f"\nQuestion:\n{prompts[0][-1]['content']}", 
              f"\nAnswer:\n{answer[0]}", 
              f"\nResponse:\n{responses[0]}", 
              f"\nExtracted:\n{extracted_responses[0]}")

    scores = []
    for guess, true_answer in zip(extracted_responses, answer):
        if guess is None:
            scores.append(-1.0)  # No answer found
            continue

        # Split the reference by comma to allow multiple correct formats.
        # Empty alternatives (from "", "a," or "a,,b") are dropped: the empty
        # string is a substring of everything and would mark any guess correct.
        alternative_answers = [
            alt for alt in (part.strip() for part in str(true_answer).split(','))
            if alt
        ]

        guess_value = _try_convert_to_float(guess)

        correct = False
        for alt in alternative_answers:
            # Textual match: the alternative must appear as a standalone token,
            # not merely as a substring (prevents "2" matching "25.5").
            if _contains_answer_token(guess, alt):
                correct = True
                break

            # Numerical match with an absolute tolerance.
            alt_value = _try_convert_to_float(alt)
            if guess_value is not None and alt_value is not None:
                if abs(guess_value - alt_value) <= 0.01:
                    correct = True
                    break

        scores.append(3.0 if correct else -1.0)

    return scores

In [8]:

# Smoke-test match_format on a hand-written, well-formed completion.
test_response = (
    f"{reasoning_start}Let me calculate this step by step...{reasoning_end}\n"
    f"{solution_start}42{solution_end}"
)
print("\nTesting format regex:")
match = match_format.search(test_response)
print(f"Match found: {match is not None}")
if match is not None:
    # Group 1 of match_format captures the text between the solution tags.
    print(f"Extracted answer: {match.group(1)}")


Testing format regex:
Match found: True
Extracted answer: 42
