In [1]:
from datasets import Dataset, load_dataset
from tqdm import tqdm
from collections import defaultdict
import random
import math
import itertools
import operator
import pandas as pd
from countdown import evaluate_equation, compute_score

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def generate_numbers(difficulty, low=None, high=None):
    """Draw `difficulty` random integers.

    Args:
        difficulty: How many numbers to draw.
        low: Inclusive lower bound. When None, the module-level
            `min_number` is used (looked up at call time).
        high: Inclusive upper bound. When None, the module-level
            `max_number` is used (looked up at call time).

    Returns:
        A list of `difficulty` ints, each drawn with `random.randint(low, high)`.
    """
    # Resolve the defaults lazily: the config globals live in a later cell, so
    # they only need to exist by the time the function is actually called.
    if low is None:
        low = min_number
    if high is None:
        high = max_number
    return [random.randint(low, high) for _ in range(difficulty)]

def generate_operations(num_operations, choices=None):
    """Draw `num_operations` random operator symbols.

    Args:
        num_operations: How many operators to draw.
        choices: Sequence to sample from (with replacement). When None, the
            module-level `operations` list is used (looked up at call time).

    Returns:
        A list of `num_operations` items, each picked with `random.choice`.
    """
    # Resolve the default lazily so the function does not depend on the config
    # cell having run before this definition.
    if choices is None:
        choices = operations
    return [random.choice(choices) for _ in range(num_operations)]

In [3]:
def generate_expression(numbers, operations):
    """Build an infix expression string from numbers and operators.

    The numbers are interleaved with the operators in order
    (`numbers[0] operations[0] numbers[1] ...`) and joined with single spaces.
    When there are at least three numbers, one adjacent pair of operands is
    wrapped in parentheses with probability ~50%.

    Args:
        numbers: Operands, in the order they appear in the expression.
        operations: Operator symbols; `operations[i]` is placed between
            `numbers[i]` and `numbers[i + 1]`, so at least
            `len(numbers) - 1` entries are required.

    Returns:
        The expression as a string, e.g. `"8 + (4 * 2) - 1"`. An empty
        `numbers` list yields an empty string.
    """
    # Layout of expression_parts: operand i sits at even index 2*i, the
    # operator that follows it at odd index 2*i + 1.
    expression_parts = []
    for i, number in enumerate(numbers):
        if i > 0:
            expression_parts.append(operations[i - 1])
        expression_parts.append(str(number))

    # Optionally parenthesize one "a op b" pair (needs at least 3 numbers,
    # otherwise the parentheses would wrap the whole expression).
    if len(numbers) >= 3 and random.random() > 0.5:
        # Index of the left operand of the pair to wrap.
        start_pos = random.randint(0, len(numbers) - 2)
        open_idx = 2 * start_pos
        # The right operand is two slots further on. The previous code used
        # +3 for every pair except the last, which is an operator slot and
        # produced malformed strings such as "(8 + 4 *) 2 - 1".
        close_idx = open_idx + 2

        expression_parts[open_idx] = "(" + expression_parts[open_idx]
        expression_parts[close_idx] = expression_parts[close_idx] + ")"

    return " ".join(expression_parts)

In [4]:
def generate_problem(difficulty, max_attempts=None):
    """Generate one random problem whose target is an in-range integer.

    Candidates are drawn repeatedly until one is accepted. A candidate is
    rejected when `evaluate_equation` returns None (treated as a failed
    evaluation), when the result is not a whole number, or when it falls
    outside `[min_result, max_result]`.

    Retrying is done with a loop rather than by self-recursion, so a long run
    of rejected candidates cannot raise RecursionError.

    Args:
        difficulty: Number of operands in the problem.
        max_attempts: Optional cap on the number of candidates to try.
            None (the default) retries indefinitely.

    Returns:
        A tuple `(numbers, expression, target)` where `target` is an int.

    Raises:
        RuntimeError: If `max_attempts` is given and no candidate was accepted
            within that many attempts.
    """
    attempts = 0
    while max_attempts is None or attempts < max_attempts:
        attempts += 1

        numbers = generate_numbers(difficulty)
        # One operator between each pair of adjacent operands.
        operations = generate_operations(len(numbers) - 1)
        expression = generate_expression(numbers, operations)
        result = evaluate_equation(expression)

        # Evaluation failed: draw a fresh candidate.
        if result is None:
            continue
        # Division can give fractional values; only whole-number targets
        # inside the configured range are kept.
        if result != int(result) or not (min_result <= result <= max_result):
            continue

        return numbers, expression, int(result)

    raise RuntimeError(
        f"No valid problem found for difficulty {difficulty} "
        f"after {max_attempts} attempts"
    )

def generate_problems(num_problems, difficulty):
    """Build `num_problems` random problems of the given difficulty.

    Progress is reported through `tqdm`. Each entry of the returned list is a
    `(numbers, expression, result)` tuple as produced by `generate_problem`.
    """
    def _draw():
        # Unpack and re-pack so every entry is a plain 3-tuple.
        nums, expr, target = generate_problem(difficulty)
        return (nums, expr, target)

    return [_draw() for _ in tqdm(range(num_problems))]

In [5]:
# Inclusive bounds for every operand drawn by generate_numbers.
min_number = 0
max_number = 99

# Inclusive bounds a problem's target must fall within to be kept.
min_result = 0
max_result = 999

number_range = range(5, 7)  # Difficulties to generate: problems with 5 and 6 numbers
operations = ['+', '-', '*', '/']
num_problems = 3840 + 320  # Per difficulty: 3840 train rows + 320 test rows

# Seed the RNG so that Restart & Run All regenerates the same problems.
random_seed = 42
random.seed(random_seed)

In [6]:
import sys

# Raise the interpreter's recursion ceiling for the rest of the kernel session.
# NOTE(review): presumably headroom for deeply recursive retries in the
# generation code -- confirm whether it is still required.
new_limit = 10_000
sys.setrecursionlimit(new_limit)

In [7]:
# Maps each difficulty (number of operands) to its list of validated
# (numbers, expression, result) tuples.
difficulty_to_dataset = {}
for difficulty in number_range:
    print(f"Generating problems with difficulty {difficulty}...")

    generated_problems = generate_problems(num_problems, difficulty)

    # Re-score every generated expression, wrapped in the answer tags, against
    # its own numbers and target; only a score of exactly 1 counts as valid.
    valid_problems = []
    num_invalid = 0
    for numbers, expression, result in generated_problems:
        ground_truth = {'target': result, 'numbers': numbers}
        score = compute_score("Assistant: <answer>" + expression + '</answer>', ground_truth)
        if score != 1:
            print(f"Invalid problem: {numbers}, {expression}, {result}")
            num_invalid += 1
            continue
        valid_problems.append((numbers, expression, result))
    print(f"Invalid problems: {num_invalid}")

    # A single invalid problem aborts the whole run: this difficulty is not
    # stored and the remaining difficulties are skipped.
    if num_invalid != 0:
        print("RIP")
        break
    difficulty_to_dataset[difficulty] = valid_problems

Generating problems with difficulty 5...


100%|██████████| 4160/4160 [00:00<00:00, 5254.10it/s]


--------------------------------
Target: 46 | Numbers: [50, 90, 27, 34, 3] | Length: 25
Extracted equation: 50 / 90 * 27 + 34 - 3
Solution string: Assistant: <answer>50 / 90 * 27 + 34 - 3</answer>
Correct equation: 50 / 90 * 27 + 34 - 3 = 46.0
--------------------------------
Target: 204 | Numbers: [94, 53, 54, 70, 41] | Length: 26
Extracted equation: 94 + 53 - 54 + 70 + 41
Solution string: Assistant: <answer>94 + 53 - 54 + 70 + 41</answer>
Correct equation: 94 + 53 - 54 + 70 + 41 = 204
--------------------------------
Target: 47 | Numbers: [22, 25, 62, 85, 23] | Length: 26
Extracted equation: 22 + 25 + 62 - 85 + 23
Solution string: Assistant: <answer>22 + 25 + 62 - 85 + 23</answer>
Correct equation: 22 + 25 + 62 - 85 + 23 = 47
--------------------------------
Target: 86 | Numbers: [16, 95, 16, 99, 11] | Length: 26
Extracted equation: 16 * 95 / 16 - 99 / 11
Solution string: Assistant: <answer>16 * 95 / 16 - 99 / 11</answer>
Correct equation: 16 * 95 / 16 - 99 / 11 = 86.0
--------------

100%|██████████| 4160/4160 [00:01<00:00, 2700.71it/s]


--------------------------------
Target: 520 | Numbers: [30, 21, 91, 24, 26, 17] | Length: 30
Extracted equation: 30 * 21 - 91 + 24 - 26 - 17
Solution string: Assistant: <answer>30 * 21 - 91 + 24 - 26 - 17</answer>
Correct equation: 30 * 21 - 91 + 24 - 26 - 17 = 520
--------------------------------
Target: 341 | Numbers: [87, 2, 4, 31, 27, 33] | Length: 28
Extracted equation: 87 + 2 * 4 * 31 - 27 + 33
Solution string: Assistant: <answer>87 + 2 * 4 * 31 - 27 + 33</answer>
Correct equation: 87 + 2 * 4 * 31 - 27 + 33 = 341
--------------------------------
Target: 61 | Numbers: [63, 36, 3, 64, 85, 73] | Length: 29
Extracted equation: 63 - 36 + 3 * 64 - 85 - 73
Solution string: Assistant: <answer>63 - 36 + 3 * 64 - 85 - 73</answer>
Correct equation: 63 - 36 + 3 * 64 - 85 - 73 = 61
--------------------------------
Target: 297 | Numbers: [91, 4, 23, 60, 90, 74] | Length: 29
Extracted equation: 91 * 4 - 23 - 60 + 90 - 74
Solution string: Assistant: <answer>91 * 4 - 23 - 60 + 90 - 74</answer>
C

In [8]:
def gen(split):
    """Yield dataset rows for the requested split.

    For every difficulty in `number_range`, the first 3840 problems stored in
    `difficulty_to_dataset` form the 'train' split and everything after them
    forms the 'test' split. Any other `split` value yields nothing.

    Each row is a dict with keys 'nums', 'expr' and 'target'.
    """
    if split == 'train':
        window = slice(None, 3840)
    elif split == 'test':
        window = slice(3840, None)
    else:
        # Unknown split name: produce an empty stream.
        return

    for level in number_range:
        for nums, expr, target in difficulty_to_dataset[level][window]:
            yield {'nums': nums, 'expr': expr, 'target': target}

In [9]:
# Materialize both splits from the `gen` generator, train first, then test.
train, test = (
    Dataset.from_generator(gen, gen_kwargs={"split": split_name})
    for split_name in ('train', 'test')
)

Generating train split: 7680 examples [00:00, 283468.75 examples/s]
Generating train split: 640 examples [00:00, 212554.80 examples/s]


In [10]:
# Target dataset repository on the Hugging Face Hub.
# NOTE(review): the f-prefix has no placeholders; a plain string would do.
hub_dataset_name = f"d1shs0ap/countdown-5-6"

# Upload the train split to the 'main' revision, passing private=True.
# NOTE(review): presumably requires an authenticated Hub session -- confirm.
train.push_to_hub(
    hub_dataset_name,
    revision='main',
    split='train',
    private=True,
)

# Upload the test split to the same repository and revision. As the last
# expression in the cell, its return value is what the notebook displays.
test.push_to_hub(
    hub_dataset_name,
    revision='main',
    split='test',
    private=True,
)

Creating parquet from Arrow format: 100%|██████████| 8/8 [00:00<00:00, 2029.30ba/s]
Uploading the dataset shards: 100%|██████████| 1/1 [00:00<00:00,  1.51it/s]
Creating parquet from Arrow format: 100%|██████████| 1/1 [00:00<00:00, 1223.19ba/s]
Uploading the dataset shards: 100%|██████████| 1/1 [00:00<00:00,  1.91it/s]


CommitInfo(commit_url='https://huggingface.co/datasets/d1shs0ap/countdown-5-6/commit/a0e6af47ec6676661fc0ce97486de08a3e4183b4', commit_message='Upload dataset', commit_description='', oid='a0e6af47ec6676661fc0ce97486de08a3e4183b4', pr_url=None, repo_url=RepoUrl('https://huggingface.co/datasets/d1shs0ap/countdown-5-6', endpoint='https://huggingface.co', repo_type='dataset', repo_id='d1shs0ap/countdown-5-6'), pr_revision=None, pr_num=None)