In [1]:
from datasets import Dataset, load_dataset
from tqdm import tqdm
from collections import defaultdict
import random
import math
import itertools
import operator
import pandas as pd
from countdown import evaluate_equation, compute_score

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def generate_numbers(difficulty):
    """Return a list of `difficulty` random integers.

    Each value is drawn independently with ``random.randint`` from the
    inclusive range [min_number, max_number]; both bounds are module-level
    settings read at call time.
    """
    drawn = []
    for _ in range(difficulty):
        drawn.append(random.randint(min_number, max_number))
    return drawn

def generate_operations(num_operations):
    """Return `num_operations` operator symbols sampled with replacement.

    Symbols come from the module-level `operations` list (read at call time),
    one ``random.choice`` per slot.
    """
    return list(map(lambda _slot: random.choice(operations), range(num_operations)))

In [3]:
def generate_expression(numbers, operations):
    """Build an infix expression from numbers and operators, optionally parenthesizing one pair.

    Produces ``n0 op0 n1 op1 n2 ...`` with tokens separated by single spaces.
    When there are at least 3 numbers, with probability just under 1/2 one
    adjacent pair ``(n_i op_i n_{i+1})`` -- position chosen uniformly at random --
    is wrapped in parentheses.

    Args:
        numbers: the operands, in order.
        operations: operator strings; ``operations[i]`` is placed between
            ``numbers[i]`` and ``numbers[i + 1]`` (needs ``len(numbers) - 1`` items).

    Returns:
        The expression string, e.g. ``"12 * (3 + 4) - 5"``.
    """
    # Interleave operands and operators: [n0, op0, n1, op1, ..., n_last].
    expression_parts = []
    for i, number in enumerate(numbers):
        if i > 0:
            expression_parts.append(operations[i - 1])
        expression_parts.append(str(number))

    # Parentheses only change anything with at least 3 numbers.
    if len(numbers) >= 3 and random.random() > 0.5:
        start_pos = random.randint(0, len(numbers) - 2)
        # Number k lives at index 2*k, so the pair (start_pos, start_pos + 1)
        # spans indices 2*start_pos .. 2*start_pos + 2. The closing parenthesis
        # must go on the second NUMBER (index +2), not on the following
        # operator (index +3), which would yield invalid syntax like "(1 + 2 -) 3".
        open_idx = 2 * start_pos
        close_idx = open_idx + 2
        expression_parts[open_idx] = "(" + expression_parts[open_idx]
        expression_parts[close_idx] = expression_parts[close_idx] + ")"

    return " ".join(expression_parts)

In [4]:
def generate_problem(difficulty, max_attempts=100000):
    """Generate one random problem by rejection sampling.

    Repeatedly draws `difficulty` numbers and `difficulty - 1` operators,
    builds an expression with `generate_expression`, and evaluates it with
    `evaluate_equation`, until the result is an integer within
    [min_result, max_result] (module-level settings).

    Args:
        difficulty: how many numbers the problem uses (must be >= 1).
        max_attempts: upper bound on sampling attempts before giving up;
            prevents an endless loop when the settings make acceptance impossible.

    Returns:
        (numbers, expression, result) with `result` converted to int.

    Raises:
        ValueError: if `difficulty` < 1.
        RuntimeError: if no acceptable problem is found within `max_attempts`.
    """
    if difficulty < 1:
        raise ValueError(f"difficulty must be >= 1, got {difficulty}")

    # Loop rather than recurse so a long run of rejections cannot exhaust the
    # call stack.
    for _ in range(max_attempts):
        numbers = generate_numbers(difficulty)
        ops = generate_operations(len(numbers) - 1)
        expression = generate_expression(numbers, ops)
        result = evaluate_equation(expression)

        # None means the evaluation failed. The range check runs before int()
        # so a NaN/inf result is rejected instead of raising inside int().
        if result is None or not (min_result <= result <= max_result):
            continue
        if result != int(result):
            continue
        return numbers, expression, int(result)

    raise RuntimeError(
        f"No valid problem with difficulty {difficulty} after {max_attempts} attempts"
    )

def generate_problems(num_problems, difficulty):
    """Build `num_problems` problems at the given difficulty, with a tqdm progress bar.

    Returns:
        A list of ``(numbers, expression, result)`` tuples, one per call to
        `generate_problem`, in generation order.
    """
    progress = tqdm(range(num_problems))
    return [
        (nums, expr, target)
        for nums, expr, target in (generate_problem(difficulty) for _ in progress)
    ]

In [5]:
# Inclusive range for each operand (passed to random.randint).
min_number = 0
max_number = 99

# Inclusive range a problem's target value must fall in to be accepted.
min_result = 0
max_result = 999

number_range = range(3, 11)  # Difficulties 3..10: each problem uses 3 to 10 numbers
operations = ['+', '-', '*', '/']

# Problems generated per difficulty: 1600 train + 160 test. Keep num_train in
# sync with the train/test cut-off used when building the splits.
num_train = 1600
num_test = 160
num_problems = num_train + num_test

# Seed the global RNG so Restart & Run All regenerates the same dataset.
random_seed = 42
random.seed(random_seed)

In [6]:
import sys

# Headroom for deep recursion (e.g. recursive retry loops during problem
# generation); CPython's default limit is 1000.
new_limit = 10_000
sys.setrecursionlimit(new_limit)

In [7]:
# Generate num_problems problems per difficulty and re-validate every one with
# countdown.compute_score before keeping it.
difficulty_to_dataset = {}
for difficulty in number_range:
    print(f"Generating problems with difficulty {difficulty}...")

    # Generate problems for the current difficulty
    generated_problems = generate_problems(num_problems, difficulty)

    # Sanity check: the reference expression, wrapped in the answer format
    # compute_score extracts from, must score exactly 1 against its own
    # target and numbers.
    valid_problems = []
    num_invalid = 0
    for numbers, expression, result in generated_problems:
        ground_truth = {'target': result, 'numbers': numbers}
        if compute_score("Assistant: <answer>" + expression + '</answer>', ground_truth) == 1:
            valid_problems.append((numbers, expression, result))
        else:
            print(f"Invalid problem: {numbers}, {expression}, {result}")
            num_invalid += 1
    print(f"Invalid problems: {num_invalid}")

    if num_invalid:
        # Stop Run All here with a clear message instead of leaving
        # difficulty_to_dataset partially filled; later cells would otherwise
        # fail with a KeyError on the missing difficulty.
        raise RuntimeError(
            f"{num_invalid} generated problem(s) at difficulty {difficulty} "
            "failed compute_score validation"
        )
    difficulty_to_dataset[difficulty] = valid_problems

Generating problems with difficulty 3...


100%|██████████| 1760/1760 [00:00<00:00, 24425.99it/s]

--------------------------------
Target: 224 | Numbers: [79, 52, 93] | Length: 18
Extracted equation: 79 + 52 + 93
Solution string: Assistant: <answer>79 + 52 + 93</answer>
Correct equation: 79 + 52 + 93 = 224
--------------------------------
Target: 55 | Numbers: [52, 35, 32] | Length: 18
Extracted equation: 52 + 35 - 32
Solution string: Assistant: <answer>52 + 35 - 32</answer>
Correct equation: 52 + 35 - 32 = 55





--------------------------------
Target: 194 | Numbers: [76, 70, 48] | Length: 18
Extracted equation: 76 + 70 + 48
Solution string: Assistant: <answer>76 + 70 + 48</answer>
Correct equation: 76 + 70 + 48 = 194
--------------------------------
Target: 167 | Numbers: [40, 32, 95] | Length: 18
Extracted equation: 40 + 32 + 95
Solution string: Assistant: <answer>40 + 32 + 95</answer>
Correct equation: 40 + 32 + 95 = 167
--------------------------------
Target: 120 | Numbers: [72, 8, 40] | Length: 17
Extracted equation: 72 + 8 + 40
Solution string: Assistant: <answer>72 + 8 + 40</answer>
Correct equation: 72 + 8 + 40 = 120
--------------------------------
Target: 46 | Numbers: [60, 59, 45] | Length: 18
Extracted equation: 60 - 59 + 45
Solution string: Assistant: <answer>60 - 59 + 45</answer>
Correct equation: 60 - 59 + 45 = 46
--------------------------------
Target: 0 | Numbers: [11, 17, 0] | Length: 17
Extracted equation: 11 * 17 * 0
Solution string: Assistant: <answer>11 * 17 * 0</answer

100%|██████████| 1760/1760 [00:00<00:00, 12826.33it/s]


--------------------------------
Target: 115 | Numbers: [99, 44, 79, 19] | Length: 22
Extracted equation: 99 - 44 + 79 - 19
Solution string: Assistant: <answer>99 - 44 + 79 - 19</answer>
Correct equation: 99 - 44 + 79 - 19 = 115
--------------------------------
Target: 36 | Numbers: [10, 8, 7, 30] | Length: 20
Extracted equation: 10 + 8 * 7 - 30
Solution string: Assistant: <answer>10 + 8 * 7 - 30</answer>
Correct equation: 10 + 8 * 7 - 30 = 36
--------------------------------
Target: 295 | Numbers: [10, 14, 27, 93] | Length: 22
Extracted equation: 10 + 14 * 27 - 93
Solution string: Assistant: <answer>10 + 14 * 27 - 93</answer>
Correct equation: 10 + 14 * 27 - 93 = 295
--------------------------------
Target: 998 | Numbers: [40, 23, 91, 13] | Length: 22
Extracted equation: 40 * 23 + 91 - 13
Solution string: Assistant: <answer>40 * 23 + 91 - 13</answer>
Correct equation: 40 * 23 + 91 - 13 = 998
--------------------------------
Target: 161 | Numbers: [72, 74, 36, 21] | Length: 22
Extracte

100%|██████████| 1760/1760 [00:00<00:00, 6610.39it/s]


--------------------------------
Target: 87 | Numbers: [96, 36, 28, 31, 42] | Length: 26
Extracted equation: 96 + 36 + 28 - 31 - 42
Solution string: Assistant: <answer>96 + 36 + 28 - 31 - 42</answer>
Correct equation: 96 + 36 + 28 - 31 - 42 = 87
--------------------------------
Target: 56 | Numbers: [26, 52, 73, 77, 78] | Length: 26
Extracted equation: 26 - 52 - 73 + 77 + 78
Solution string: Assistant: <answer>26 - 52 - 73 + 77 + 78</answer>
Correct equation: 26 - 52 - 73 + 77 + 78 = 56
--------------------------------
Target: 357 | Numbers: [16, 23, 91, 23, 79] | Length: 26
Extracted equation: 16 * 23 + 91 - (23 + 79)
Solution string: Assistant: <answer>16 * 23 + 91 - (23 + 79)</answer>
Correct equation: 16 * 23 + 91 - (23 + 79) = 357
--------------------------------
Target: 163 | Numbers: [4, 25, 67, 16, 8] | Length: 24
Extracted equation: 4 + 25 + 67 * 16 / 8
Solution string: Assistant: <answer>4 + 25 + 67 * 16 / 8</answer>
Correct equation: 4 + 25 + 67 * 16 / 8 = 163.0
------------

100%|██████████| 1760/1760 [00:00<00:00, 3544.35it/s]


--------------------------------
Target: 101 | Numbers: [12, 34, 23, 13, 51, 96] | Length: 30
Extracted equation: 12 + 34 + 23 - 13 - 51 + 96
Solution string: Assistant: <answer>12 + 34 + 23 - 13 - 51 + 96</answer>
Correct equation: 12 + 34 + 23 - 13 - 51 + 96 = 101
--------------------------------
Target: 229 | Numbers: [80, 87, 82, 90, 60, 96] | Length: 30
Extracted equation: 80 + 87 - 82 + 90 / 60 * 96
Solution string: Assistant: <answer>80 + 87 - 82 + 90 / 60 * 96</answer>
Correct equation: 80 + 87 - 82 + 90 / 60 * 96 = 229.0
--------------------------------
Target: 59 | Numbers: [15, 69, 55, 41, 90, 51] | Length: 30
Extracted equation: 15 + 69 + 55 - 41 - 90 + 51
Solution string: Assistant: <answer>15 + 69 + 55 - 41 - 90 + 51</answer>
Correct equation: 15 + 69 + 55 - 41 - 90 + 51 = 59
--------------------------------
Target: 83 | Numbers: [38, 24, 14, 38, 85, 88] | Length: 30
Extracted equation: 38 + 24 - 14 + 38 + 85 - 88
Solution string: Assistant: <answer>38 + 24 - 14 + 38 + 85

100%|██████████| 1760/1760 [00:00<00:00, 1832.61it/s]


--------------------------------
Target: 6 | Numbers: [60, 15, 35, 25, 48, 13, 38] | Length: 34
Extracted equation: 60 - 15 * 35 + 25 - 48 + 13 * 38
Solution string: Assistant: <answer>60 - 15 * 35 + 25 - 48 + 13 * 38</answer>
Correct equation: 60 - 15 * 35 + 25 - 48 + 13 * 38 = 6
--------------------------------
Target: 211 | Numbers: [57, 56, 3, 96, 1, 99, 94] | Length: 32
Extracted equation: 57 + 56 - 3 + 96 * 1 + 99 - 94
Solution string: Assistant: <answer>57 + 56 - 3 + 96 * 1 + 99 - 94</answer>
Correct equation: 57 + 56 - 3 + 96 * 1 + 99 - 94 = 211
--------------------------------
Target: 584 | Numbers: [89, 12, 27, 37, 24, 79, 79] | Length: 34
Extracted equation: 89 + 12 * 27 + 37 - 24 + 79 + 79
Solution string: Assistant: <answer>89 + 12 * 27 + 37 - 24 + 79 + 79</answer>
Correct equation: 89 + 12 * 27 + 37 - 24 + 79 + 79 = 584
--------------------------------
Target: 221 | Numbers: [71, 17, 71, 2, 84, 78, 28] | Length: 33
Extracted equation: 71 * 17 / 71 / 2 / 84 * 78 * 28
Solut

100%|██████████| 1760/1760 [00:01<00:00, 1066.53it/s]


--------------------------------
Target: 269 | Numbers: [4, 16, 30, 64, 92, 59, 7, 3] | Length: 35
Extracted equation: 4 + 16 + 30 + 64 + 92 + 59 + 7 - 3
Solution string: Assistant: <answer>4 + 16 + 30 + 64 + 92 + 59 + 7 - 3</answer>
Correct equation: 4 + 16 + 30 + 64 + 92 + 59 + 7 - 3 = 269
--------------------------------
Target: 83 | Numbers: [72, 52, 36, 0, 8, 91, 0, 95] | Length: 35
Extracted equation: 72 - 52 - 36 + 0 + 8 + 91 - 0 / 95
Solution string: Assistant: <answer>72 - 52 - 36 + 0 + 8 + 91 - 0 / 95</answer>
Correct equation: 72 - 52 - 36 + 0 + 8 + 91 - 0 / 95 = 83.0
--------------------------------
Target: 75 | Numbers: [24, 64, 5, 35, 28, 80, 24, 60] | Length: 37
Extracted equation: 24 / 64 * 5 * 35 / 28 * 80 * 24 / 60
Solution string: Assistant: <answer>24 / 64 * 5 * 35 / 28 * 80 * 24 / 60</answer>
Correct equation: 24 / 64 * 5 * 35 / 28 * 80 * 24 / 60 = 75.0
--------------------------------
Target: 47 | Numbers: [79, 0, 9, 29, 23, 4, 41, 17] | Length: 35
Extracted equat

100%|██████████| 1760/1760 [00:02<00:00, 618.66it/s]


--------------------------------
Target: 59 | Numbers: [75, 28, 12, 0, 75, 68, 22, 18, 21] | Length: 41
Extracted equation: 75 - 28 + 12 + 0 * 75 / 68 / 22 / (18 / 21)
Solution string: Assistant: <answer>75 - 28 + 12 + 0 * 75 / 68 / 22 / (18 / 21)</answer>
Correct equation: 75 - 28 + 12 + 0 * 75 / 68 / 22 / (18 / 21) = 59.0
--------------------------------
Target: 79 | Numbers: [79, 56, 73, 5, 69, 9, 0, 44, 49] | Length: 39
Extracted equation: 79 - 56 * 73 * 5 / 69 / 9 * 0 / (44 + 49)
Solution string: Assistant: <answer>79 - 56 * 73 * 5 / 69 / 9 * 0 / (44 + 49)</answer>
Correct equation: 79 - 56 * 73 * 5 / 69 / 9 * 0 / (44 + 49) = 79.0
--------------------------------
Target: 10 | Numbers: [0, 61, 82, 32, 63, 82, 40, 71, 21] | Length: 41
Extracted equation: 0 / 61 * 82 / 32 / 63 * 82 - 40 + 71 - 21
Solution string: Assistant: <answer>0 / 61 * 82 / 32 / 63 * 82 - 40 + 71 - 21</answer>
Correct equation: 0 / 61 * 82 / 32 / 63 * 82 - 40 + 71 - 21 = 10.0
--------------------------------
Tar

100%|██████████| 1760/1760 [00:04<00:00, 372.24it/s]

--------------------------------
Target: 424 | Numbers: [51, 11, 43, 34, 62, 69, 81, 22, 0, 77] | Length: 45
Extracted equation: 51 * 11 + 43 - 34 * 62 + 69 + 81 * 22 + 0 + 77
Solution string: Assistant: <answer>51 * 11 + 43 - 34 * 62 + 69 + 81 * 22 + 0 + 77</answer>
Correct equation: 51 * 11 + 43 - 34 * 62 + 69 + 81 * 22 + 0 + 77 = 424
--------------------------------
Target: 96 | Numbers: [45, 1, 51, 61, 51, 48, 58, 59, 9, 0] | Length: 43
Extracted equation: 45 - 1 * 51 + 61 + 51 + 48 - 58 + 59 / 9 * 0
Solution string: Assistant: <answer>45 - 1 * 51 + 61 + 51 + 48 - 58 + 59 / 9 * 0</answer>
Correct equation: 45 - 1 * 51 + 61 + 51 + 48 - 58 + 59 / 9 * 0 = 96.0
--------------------------------
Target: 609 | Numbers: [9, 87, 99, 1, 29, 69, 8, 35, 23, 50] | Length: 43
Extracted equation: 9 + 87 + 99 - 1 - 29 + 69 * 8 - 35 - (23 + 50)
Solution string: Assistant: <answer>9 + 87 + 99 - 1 - 29 + 69 * 8 - 35 - (23 + 50)</answer>
Correct equation: 9 + 87 + 99 - 1 - 29 + 69 * 8 - 35 - (23 + 50)




In [8]:
def gen(split, num_train=1600):
    """Yield dataset rows for one split, walking difficulties in `number_range` order.

    For each difficulty, the first `num_train` problems in
    ``difficulty_to_dataset[difficulty]`` form the train split and the
    remainder form the test split.

    Args:
        split: 'train' or 'test'.
        num_train: per-difficulty train/test cut-off (default 1600).

    Yields:
        dicts with keys 'nums' (operand list), 'expr' (reference expression)
        and 'target' (integer result).

    Raises:
        ValueError: if `split` is neither 'train' nor 'test' (previously an
            unknown split silently produced an empty dataset).
    """
    if split == 'train':
        rows = slice(None, num_train)
    elif split == 'test':
        rows = slice(num_train, None)
    else:
        raise ValueError(f"Unknown split {split!r}; expected 'train' or 'test'")

    for difficulty in number_range:
        for numbers, expression, result in difficulty_to_dataset[difficulty][rows]:
            yield {'nums': numbers, 'expr': expression, 'target': result}

In [9]:
# Materialize each split as its own Dataset by running `gen` with that split
# name. Note: from_generator's progress bar says "Generating train split" for
# both calls; the second one is the test data.
train, test = (
    Dataset.from_generator(gen, gen_kwargs={"split": split_name})
    for split_name in ('train', 'test')
)

Generating train split: 12800 examples [00:00, 72931.34 examples/s]
Generating train split: 1280 examples [00:00, 230535.43 examples/s]


In [10]:
# Private Hub dataset repo that receives both splits (plain string: the
# original f-string had no placeholders).
hub_dataset_name = "d1shs0ap/countdown-final"

# Each call uploads one split into the same repo on the 'main' revision.
train.push_to_hub(
    hub_dataset_name,
    revision='main',
    split='train',
    private=True,
)

test.push_to_hub(
    hub_dataset_name,
    revision='main',
    split='test',
    private=True,
)

Creating parquet from Arrow format: 100%|██████████| 13/13 [00:00<00:00, 2096.67ba/s]
Uploading the dataset shards: 100%|██████████| 1/1 [00:00<00:00,  1.31it/s]
Creating parquet from Arrow format: 100%|██████████| 2/2 [00:00<00:00, 2170.96ba/s]
Uploading the dataset shards: 100%|██████████| 1/1 [00:09<00:00,  9.67s/it]


CommitInfo(commit_url='https://huggingface.co/datasets/d1shs0ap/countdown-final/commit/cd47ff733c8a84f1fce7ed4e8892be97304cf1a6', commit_message='Upload dataset', commit_description='', oid='cd47ff733c8a84f1fce7ed4e8892be97304cf1a6', pr_url=None, repo_url=RepoUrl('https://huggingface.co/datasets/d1shs0ap/countdown-final', endpoint='https://huggingface.co', repo_type='dataset', repo_id='d1shs0ap/countdown-final'), pr_revision=None, pr_num=None)