In [1]:
import re
import os
from datasets import load_dataset, concatenate_datasets

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def make_prefix(dp, template_type):
    target = dp['target']
    numbers = dp['nums']
    # NOTE: also need to change reward_score/countdown.py
    if template_type == 'base':
        """This works for any base model"""
        prefix = f"""A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer.
User: Using the numbers {numbers}, create an equation that equals {target}. You can use basic arithmetic operations (+, -, *, /) and each number can only be used once. Show your work in <think> </think> tags. And return the final answer in <answer> </answer> tags, for example <answer> (1 + 2) / 3 </answer>.
Assistant: Let me solve this step by step.
"""
    elif template_type == 'qwen-instruct':
        """This works for Qwen Instruct Models"""
        prefix = f"""<|im_start|>system\nYou are a helpful assistant. You first thinks about the reasoning process in the mind and then provides the user with the answer.<|im_end|>\n<|im_start|>user\n Using the numbers {numbers}, create an equation that equals {target}. You can use basic arithmetic operations (+, -, *, /) and each number can only be used once. Show your work in <think> </think> tags. And return the final answer in <answer> </answer> tags, for example <answer> (1 + 2) / 3 </answer>.<|im_end|>\n<|im_start|>assistant\nLet me solve this step by step.\n<think>"""
    return prefix

In [3]:
dataset = load_dataset('d1shs0ap/countdown-final', split='train')

def make_map_fn(split):
    def process_fn(example, idx):
        question = make_prefix(example, template_type='base')
        solution = {
            "target": example['target'],
            "numbers": example['nums']
        }
        data = {
            "data_source": 'countdown',
            "prompt": [{
                "role": "user",
                "content": question,
            }],
            "ability": "math",
            "reward_model": {
                "style": "rule",
                "ground_truth": solution
            },
            "extra_info": {
                'split': split,
                'index': idx,
            }
        }
        return data
    return process_fn

dataset = dataset.map(function=make_map_fn('train'), with_indices=True)

In [4]:
test_dataset = dataset.map(function=make_map_fn('test'), with_indices=True).filter(lambda x: len(x['nums']) in [3, 4, 7])
test_dataset.to_parquet(os.path.join('/home/myang4/countdown-curriculum/data/countdown', 'test-3-4-7.parquet'))

Creating parquet from Arrow format:   0%|          | 0/5 [00:00<?, ?ba/s]

Creating parquet from Arrow format: 100%|██████████| 5/5 [00:00<00:00, 18.24ba/s]


3591577

In [5]:
threes = dataset.filter(lambda x: len(x['nums']) == 3)
fours = dataset.filter(lambda x: len(x['nums']) == 4)
sixes = dataset.filter(lambda x: len(x['nums']) == 6)
sevens = dataset.filter(lambda x: len(x['nums']) == 7)
nines = dataset.filter(lambda x: len(x['nums']) == 9)

In [6]:
first = concatenate_datasets([threes.select(range(640)), fours.select(range(640))]).shuffle(seed=42)
first_three = threes.select(range(1280)).shuffle(seed=42)
first_four = fours.select(range(1280)).shuffle(seed=42)
second = concatenate_datasets([threes.select(range(640, 640 + 64)), fours.select(range(640, 640 + 64)), sevens.select(range(640 - 128))]).shuffle(seed=42)

In [9]:
real_second = concatenate_datasets([second, second]).shuffle(seed=42)
all = concatenate_datasets([first, real_second]).shuffle(seed=42)
real_all = concatenate_datasets([all, all.select(range(1280))]).shuffle(seed=42)

In [11]:
first_three.filter(lambda x: len(x['nums']) == 6)

Filter: 100%|██████████| 1280/1280 [00:00<00:00, 16731.62 examples/s]


Dataset({
    features: ['nums', 'expr', 'target', 'data_source', 'prompt', 'ability', 'reward_model', 'extra_info'],
    num_rows: 0
})

In [18]:
first.to_parquet(os.path.join('/home/myang4/countdown-curriculum/data/countdown', 'train-3-4.parquet'))
first_three.to_parquet(os.path.join('/home/myang4/countdown-curriculum/data/countdown', 'train-3.parquet'))
first_four.to_parquet(os.path.join('/home/myang4/countdown-curriculum/data/countdown', 'train-4.parquet'))
real_second.to_parquet(os.path.join('/home/myang4/countdown-curriculum/data/countdown', 'train-7.parquet'))
real_all.to_parquet(os.path.join('/home/myang4/countdown-curriculum/data/countdown', 'train-3-4-7.parquet'))

Creating parquet from Arrow format: 100%|██████████| 2/2 [00:00<00:00, 28.86ba/s]
Creating parquet from Arrow format: 100%|██████████| 2/2 [00:00<00:00, 29.20ba/s]
Creating parquet from Arrow format: 100%|██████████| 2/2 [00:00<00:00, 29.71ba/s]
Creating parquet from Arrow format: 100%|██████████| 2/2 [00:00<00:00, 29.54ba/s]
Creating parquet from Arrow format: 100%|██████████| 4/4 [00:00<00:00, 18.62ba/s]


2899534

### 3, 6 curriculum

In [4]:
three_six = dataset.filter(lambda x: len(x['nums']) == 3 or len(x['nums']) == 6).shuffle(seed=42)

In [5]:
three_six.to_parquet(os.path.join('/home/cmu/countdown-curriculum/data/countdown', '3-and-6.parquet'))

Creating parquet from Arrow format: 100%|██████████| 4/4 [00:00<00:00, 24.09ba/s]


2384525

In [6]:
three_six_times_two_and_a_half = concatenate_datasets([three_six, three_six, three_six]).select(range(0, 7680))

In [7]:
three_six_times_two_and_a_half.to_parquet(os.path.join('/home/cmu/countdown-curriculum/data/countdown', '3-and-6-x2.5.parquet'))

Creating parquet from Arrow format: 100%|██████████| 8/8 [00:00<00:00, 20.59ba/s]


5724120

In [None]:
zero.to_parquet(os.path.join('/home/cmu/countdown-curriculum/data/countdown', '1-easy.parquet'))

In [4]:
easy = dataset.filter(lambda example: 3 <= len(example['nums']) <= 4)
medium = dataset.filter(lambda example: 5 <= len(example['nums']) <= 6)
hard = dataset.filter(lambda example: 7 <= len(example['nums']) <= 8)

In [5]:
half_size = len(easy) // 2
easy_first_half = easy.select(range(half_size))
easy_second_half = easy.select(range(half_size, len(easy)))

medium_first_half = medium.select(range(half_size))
medium_second_half = medium.select(range(half_size, len(medium)))

### 1 easy, 2 hard | 1 easy -> 2 hard

In [6]:
zero = easy.shuffle(seed=42)
zero.to_parquet(os.path.join('/home/cmu/countdown-curriculum/data/countdown', '1-easy.parquet'))

Creating parquet from Arrow format: 100%|██████████| 4/4 [00:00<00:00, 24.43ba/s]


2305746

In [10]:
one = concatenate_datasets([easy, hard, hard]).shuffle(seed=42)
one.to_parquet(os.path.join('/home/cmu/countdown-curriculum/data/countdown', '1-easy-2-hard.parquet'))

Creating parquet from Arrow format:   0%|          | 0/10 [00:00<?, ?ba/s]

Creating parquet from Arrow format: 100%|██████████| 10/10 [00:00<00:00, 21.09ba/s]


7547388

In [9]:
one = concatenate_datasets([zero, concatenate_datasets([easy, hard, hard, hard, hard]).shuffle(seed=42)])
one.to_parquet(os.path.join('/home/cmu/countdown-curriculum/data/countdown', '1-easy-4-hard.parquet'))

Creating parquet from Arrow format:   0%|          | 0/20 [00:00<?, ?ba/s]

Creating parquet from Arrow format: 100%|██████████| 20/20 [00:00<00:00, 20.91ba/s]


15094776

### 1 easy, 2 medium | 1 easy, 1 medium, 3 hard

In [7]:
zero = easy.shuffle(seed=42)
zero.to_parquet(os.path.join('/home/cmu/countdown-curriculum/data/countdown', '1-easy.parquet'))

Creating parquet from Arrow format: 100%|██████████| 4/4 [00:00<00:00, 24.56ba/s]


2305746

In [13]:
one = concatenate_datasets([easy, concatenate_datasets([medium, medium, easy]).shuffle(seed=42)])
one.to_parquet(os.path.join('/home/cmu/countdown-curriculum/data/countdown', '1-easy-2-medium.parquet'))


Creating parquet from Arrow format:   0%|          | 0/13 [00:00<?, ?ba/s]

Creating parquet from Arrow format: 100%|██████████| 13/13 [00:00<00:00, 21.76ba/s]


9537894

In [15]:
for i in range(len(easy)):
    assert one[i] == easy[i], f"Mismatch at index {i}: {one[i]} != {easy[i]}"
print("All checks passed. The datasets are identical.")

All checks passed. The datasets are identical.


In [16]:
two = concatenate_datasets([one, concatenate_datasets([hard, hard, hard, medium, easy]).shuffle(seed=42)])
two.to_parquet(os.path.join('/home/cmu/countdown-curriculum/data/countdown', '1-easy-1-medium-3-hard.parquet'))

Creating parquet from Arrow format:   0%|          | 0/29 [00:00<?, ?ba/s]

Creating parquet from Arrow format: 100%|██████████| 29/29 [00:01<00:00, 21.34ba/s]


22169304

In [17]:
for i in range(len(one)):
    assert one[i] == two[i], f"Mismatch at index {i}: {one[i]} != {two[i]}"
print("All checks passed. The datasets are identical.")

All checks passed. The datasets are identical.


### 0.5 easy, 2.5 medium | 0.5 easy, 0.5 medium, 3 hard

In [41]:
tmp = concatenate_datasets([easy_first_half, medium, medium, medium_first_half])
tmp.to_parquet(os.path.join('/home/cmu/countdown-curriculum/data/countdown', '0.5-easy-2.5-medium.parquet'))

Creating parquet from Arrow format: 100%|██████████| 10/10 [00:00<00:00, 22.38ba/s]


7310821

In [42]:
tmp = concatenate_datasets([easy_second_half, medium_second_half, hard, hard, hard])
tmp.to_parquet(os.path.join('/home/cmu/countdown-curriculum/data/countdown', '0.5-easy-0.5-medium-3-hard.parquet'))

Creating parquet from Arrow format: 100%|██████████| 13/13 [00:00<00:00, 19.77ba/s]


10246991