In [1]:
from datasets import Dataset, load_dataset
from tqdm import tqdm
from collections import defaultdict
import random
import math
import itertools
import operator
import pandas as pd
from countdown import evaluate_equation, compute_score

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def generate_numbers(difficulty):
    """Draw `difficulty` random integers for one problem.

    Each value comes from ``random.randint(min_number, max_number)``, so both
    module-level bounds are inclusive.

    Args:
        difficulty: How many numbers to draw.

    Returns:
        A list of `difficulty` ints, in the order they were drawn.
    """
    drawn = []
    for _ in range(difficulty):
        drawn.append(random.randint(min_number, max_number))
    return drawn

def generate_operations(num_operations):
    """Pick `num_operations` random operators.

    Each operator is chosen independently with ``random.choice`` from the
    module-level ``operations`` list, so repeats are allowed.

    Args:
        num_operations: How many operators to pick.

    Returns:
        A list of `num_operations` entries taken from ``operations``.
    """
    picked = []
    for _ in range(num_operations):
        picked.append(random.choice(operations))
    return picked

In [3]:
def generate_expression(numbers, operations):
    """Build an infix expression string, optionally parenthesising one pair.

    Numbers and operators are interleaved as ``n0 op0 n1 op1 ...`` and joined
    with single spaces. When there are at least three numbers, one adjacent
    pair ``(a op b)`` is wrapped in parentheses with 50% probability; the pair
    is chosen uniformly among all adjacent pairs.

    Args:
        numbers: Operands, in the order they should appear.
        operations: Operator strings; ``operations[i]`` is placed between
            ``numbers[i]`` and ``numbers[i + 1]``. Entries beyond
            ``len(numbers) - 1`` are ignored.

    Returns:
        The expression as a string, e.g. ``"3 * (4 + 5) - 1"``. An empty
        ``numbers`` list yields ``""``.

    Raises:
        IndexError: If fewer than ``len(numbers) - 1`` operations are given.
    """
    # Layout of expression_parts: operands sit at even indices, operators at odd ones.
    expression_parts = []
    for index, number in enumerate(numbers):
        if index > 0:
            expression_parts.append(operations[index - 1])
        expression_parts.append(str(number))

    # Only group when there are enough operands for grouping to matter,
    # and then only half of the time.
    if len(numbers) >= 3 and random.random() > 0.5:
        # Index (in `numbers`) of the left operand of the pair to wrap.
        start_pos = random.randint(0, len(numbers) - 2)

        left = start_pos * 2
        # The right operand is two slots further on (left operand, operator, right operand).
        # Closing at `left + 3` would land on the *next operator* and produce
        # malformed strings such as "(1 + 2 *) 3".
        right = left + 2

        expression_parts[left] = "(" + expression_parts[left]
        expression_parts[right] = expression_parts[right] + ")"

    return " ".join(expression_parts)

In [4]:
def generate_problem(difficulty):
    """Generate one random problem whose answer is an in-range integer.

    Candidates are drawn repeatedly (rejection sampling) until one is
    accepted. A candidate is rejected when ``evaluate_equation`` returns
    ``None``, when the value is not integer-valued, or when it lies outside
    ``[min_result, max_result]``.

    Args:
        difficulty: Number of operands in the problem.

    Returns:
        A tuple ``(numbers, expression, result)`` where ``numbers`` is the list
        of operands, ``expression`` is the generated expression string and
        ``result`` is its value as an ``int``.
    """
    # Retry with a loop rather than by calling ourselves: a long streak of
    # rejected candidates would otherwise grow the call stack by one frame per
    # attempt and could end in RecursionError.
    while True:
        numbers = generate_numbers(difficulty)
        operations = generate_operations(len(numbers) - 1)
        expression = generate_expression(numbers, operations)
        result = evaluate_equation(expression)

        # None is treated as "could not be evaluated" — presumably an error
        # inside evaluate_equation; confirm against its implementation.
        if result is None:
            continue
        # Reject values with a fractional part; integer-valued floats are kept
        # and converted to int on return.
        if result != int(result):
            continue
        if min_result <= result <= max_result:
            return numbers, expression, int(result)

def generate_problems(num_problems, difficulty):
    """Generate `num_problems` random problems of the given difficulty.

    Progress is reported through a ``tqdm`` bar over the requested count.

    Args:
        num_problems: How many problems to generate.
        difficulty: Passed through to ``generate_problem`` for every problem.

    Returns:
        A list of ``(numbers, expression, result)`` tuples, in generation order.
    """
    return [generate_problem(difficulty) for _ in tqdm(range(num_problems))]

In [5]:
# Seed Python's global `random` generator so that re-running the notebook
# from this cell onward regenerates the same problems.
random_seed = 42
random.seed(random_seed)

# Inclusive bounds for every operand drawn by generate_numbers.
min_number = 0
max_number = 99

# Inclusive bounds a problem's result must fall within to be accepted.
min_result = 0
max_result = 999

number_range = [3, 6, 9]  # Problem sizes: one batch each with 3, 6 and 9 numbers
operations = ['+', '-', '*', '/']
# Problems generated per size; `gen` sends the first 1280 to train and the rest to test.
num_problems = 1280 + 128

In [6]:
import sys

# Raise the interpreter's maximum call-stack depth for the rest of the session.
# NOTE(review): presumably needed for recursive retries during problem
# generation — revisit if nothing in the notebook recurses deeply.
new_limit = 10_000
sys.setrecursionlimit(new_limit)

In [7]:
# Maps each problem size to its list of (numbers, expression, result) tuples.
# A size is only recorded when every generated problem passes the scorer.
difficulty_to_dataset = {}
for difficulty in number_range:
    print(f"Generating problems with difficulty {difficulty}...")

    generated_problems = generate_problems(num_problems, difficulty)

    # Re-check each problem with compute_score by submitting the generating
    # expression as the answer; a score of exactly 1 counts as valid.
    valid_problems = []
    num_invalid = 0
    for numbers, expression, result in generated_problems:
        ground_truth = {'target': result, 'numbers': numbers}
        score = compute_score("Assistant: <answer>" + expression + '</answer>', ground_truth)
        if score != 1:
            print(f"Invalid problem: {numbers}, {expression}, {result}")
            num_invalid += 1
            continue
        valid_problems.append((numbers, expression, result))
    print(f"Invalid problems: {num_invalid}")

    # Any invalid problem stops the whole run; this size and all later sizes
    # are then left out of difficulty_to_dataset.
    if num_invalid != 0:
        print("RIP")
        break
    difficulty_to_dataset[difficulty] = valid_problems

Generating problems with difficulty 3...


100%|██████████| 1408/1408 [00:00<00:00, 16870.77it/s]


--------------------------------
Target: 85 | Numbers: [53, 31, 1] | Length: 17
Extracted equation: 53 + 31 + 1
Solution string: Assistant: <answer>53 + 31 + 1</answer>
Correct equation: 53 + 31 + 1 = 85
--------------------------------
Target: 98 | Numbers: [99, 78, 77] | Length: 18
Extracted equation: 99 - 78 + 77
Solution string: Assistant: <answer>99 - 78 + 77</answer>
Correct equation: 99 - 78 + 77 = 98
--------------------------------
Target: 86 | Numbers: [86, 0, 79] | Length: 17
Extracted equation: 86 - 0 * 79
Solution string: Assistant: <answer>86 - 0 * 79</answer>
Correct equation: 86 - 0 * 79 = 86
--------------------------------
Target: 571 | Numbers: [14, 42, 17] | Length: 18
Extracted equation: 14 * 42 - 17
Solution string: Assistant: <answer>14 * 42 - 17</answer>
Correct equation: 14 * 42 - 17 = 571
--------------------------------
Target: 827 | Numbers: [23, 33, 68] | Length: 18
Extracted equation: 23 * 33 + 68
Solution string: Assistant: <answer>23 * 33 + 68</answer>
C

100%|██████████| 1408/1408 [00:00<00:00, 1992.55it/s]


--------------------------------
Target: 99 | Numbers: [98, 0, 22, 7, 60, 10] | Length: 28
Extracted equation: 98 - 0 / 22 + 7 - (60 / 10)
Solution string: Assistant: <answer>98 - 0 / 22 + 7 - (60 / 10)</answer>
Correct equation: 98 - 0 / 22 + 7 - (60 / 10) = 99.0
--------------------------------
Target: 543 | Numbers: [78, 11, 45, 6, 35, 10] | Length: 29
Extracted equation: 78 * 11 - 45 - 6 * (35 + 10)
Solution string: Assistant: <answer>78 * 11 - 45 - 6 * (35 + 10)</answer>
Correct equation: 78 * 11 - 45 - 6 * (35 + 10) = 543
--------------------------------
Target: 120 | Numbers: [89, 45, 19, 80, 64, 4] | Length: 29
Extracted equation: 89 + 45 - 19 + 80 / 64 * 4
Solution string: Assistant: <answer>89 + 45 - 19 + 80 / 64 * 4</answer>
Correct equation: 89 + 45 - 19 + 80 / 64 * 4 = 120.0
--------------------------------
Target: 125 | Numbers: [64, 4, 68, 49, 16, 76] | Length: 29
Extracted equation: 64 + 4 + 68 + 49 + 16 - 76
Solution string: Assistant: <answer>64 + 4 + 68 + 49 + 16 - 7

100%|██████████| 1408/1408 [00:03<00:00, 385.33it/s]


--------------------------------
Target: 15 | Numbers: [37, 46, 60, 44, 39, 4, 45, 41, 33] | Length: 41
Extracted equation: 37 + 46 + 60 - 44 + 39 - 4 - 45 - 41 - 33
Solution string: Assistant: <answer>37 + 46 + 60 - 44 + 39 - 4 - 45 - 41 - 33</answer>
Correct equation: 37 + 46 + 60 - 44 + 39 - 4 - 45 - 41 - 33 = 15
--------------------------------
Target: 138 | Numbers: [87, 2, 65, 87, 0, 10, 10, 49, 31] | Length: 40
Extracted equation: 87 + 2 * 65 - 87 - 0 * 10 - 10 + 49 - 31
Solution string: Assistant: <answer>87 + 2 * 65 - 87 - 0 * 10 - 10 + 49 - 31</answer>
Correct equation: 87 + 2 * 65 - 87 - 0 * 10 - 10 + 49 - 31 = 138
--------------------------------
Target: 598 | Numbers: [10, 52, 40, 74, 15, 56, 14, 30, 5] | Length: 41
Extracted equation: 10 * 52 + 40 + 74 - 15 + 56 / 14 - 30 + 5
Solution string: Assistant: <answer>10 * 52 + 40 + 74 - 15 + 56 / 14 - 30 + 5</answer>
Correct equation: 10 * 52 + 40 + 74 - 15 + 56 / 14 - 30 + 5 = 598.0
--------------------------------
Target: 594

In [8]:
def gen(split):
    """Yield the examples of one dataset split.

    For every size in ``number_range`` (in order), the first 1280 problems in
    ``difficulty_to_dataset`` go to ``'train'`` and everything after them goes
    to ``'test'``. Any other split name yields nothing.

    Args:
        split: ``'train'`` or ``'test'``.

    Yields:
        Dicts with keys ``'nums'``, ``'expr'`` and ``'target'``.
    """
    if split == 'train':
        window = slice(None, 1280)
    elif split == 'test':
        window = slice(1280, None)
    else:
        return

    for difficulty in number_range:
        for numbers, expression, result in difficulty_to_dataset[difficulty][window]:
            yield {'nums': numbers, 'expr': expression, 'target': result}

In [9]:
# Materialise each split as a Dataset by replaying `gen` with the matching split name.
train = Dataset.from_generator(generator=gen, gen_kwargs=dict(split='train'))
test = Dataset.from_generator(generator=gen, gen_kwargs=dict(split='test'))

Generating train split: 3840 examples [00:00, 26168.87 examples/s]
Generating train split: 384 examples [00:00, 45301.74 examples/s]


In [11]:
# Target dataset repository on the Hugging Face Hub.
hub_dataset_name = "d1shs0ap/countdown-3-6-9"

# Settings shared by both uploads: write to the main revision of a private repo.
push_options = {"revision": 'main', "private": True}

train.push_to_hub(hub_dataset_name, split='train', **push_options)

# Kept as the cell's final expression so its return value is displayed as the cell output.
test.push_to_hub(hub_dataset_name, split='test', **push_options)

Creating parquet from Arrow format: 100%|██████████| 4/4 [00:00<00:00, 517.02ba/s]
Uploading the dataset shards: 100%|██████████| 1/1 [00:00<00:00,  1.98it/s]
Creating parquet from Arrow format: 100%|██████████| 1/1 [00:00<00:00, 1313.18ba/s]
Uploading the dataset shards: 100%|██████████| 1/1 [00:01<00:00,  1.60s/it]


CommitInfo(commit_url='https://huggingface.co/datasets/d1shs0ap/countdown-3-6-9/commit/d936f12bf5d072627cae35d76a9a241ba59a8358', commit_message='Upload dataset', commit_description='', oid='d936f12bf5d072627cae35d76a9a241ba59a8358', pr_url=None, repo_url=RepoUrl('https://huggingface.co/datasets/d1shs0ap/countdown-3-6-9', endpoint='https://huggingface.co', repo_type='dataset', repo_id='d1shs0ap/countdown-3-6-9'), pr_revision=None, pr_num=None)