# 1. 准备数据

In [1]:
# IPython shell escape: create the output directory used by the data paths in
# this notebook; `-p` makes the command succeed when it already exists.
!mkdir -p data

In [2]:
# Number of samples in the training and test splits.
TRAIN_DATA_NUM = 1000
TEST_DATA_NUM = 100

# Output files (JSON Lines). The "with steps" variant of the training file
# additionally keeps the per-step solution of every sample.
TRAIN_DATA_PATH = "data/train.jsonl"
TRAIN_DATA_WITH_STEPS_PATH = "data/train_with_steps.jsonl"
TEST_DATA_PATH = "data/test.jsonl"

# Output file and sample count for the RL split.
RL_DATA_PATH = "data/rl_data.jsonl"
RL_DATA_NUM = 1000

## 1.1 合成 Countdown 数据集

限制：

1. 仅生成按顺序从左到右依次运算的表达式（形如 ((A op B) op C) op D），无法生成 (A+B)/(C-D) 这种需要先分别计算两个子表达式的表达式
2. 每个数字均被使用，且每个数字只使用一次。
3. 中间计算结果必须是整数。
4. 必须有解。

In [3]:
import json
import random
from typing import List, Optional, Tuple

import tqdm


In [4]:
def convert_solution_to_expression(solution):
    """Render a chain of calculation steps as one infix expression string.

    Every step is indexed as ``(left, operator, right, ...)``. The first step
    contributes ``left operator right``; each later step contributes only its
    operator and right operand, applied to the parenthesised expression built
    so far. The outermost expression is left without surrounding parentheses,
    e.g. ``((29 - 10) * 27) + 11``.

    Returns an empty string when ``solution`` is empty.
    """
    if not solution:
        return ""

    opening = solution[0]
    rendered = f"{opening[0]} {opening[1]} {opening[2]}"

    # Wrap what has been built so far before applying the next operator, so the
    # left-to-right evaluation order stays explicit in the output.
    for later in solution[1:]:
        rendered = f"({rendered}) {later[1]} {later[2]}"

    return rendered

def _apply_op(left: int, op: str, right: int) -> Optional[int]:
    """Return ``left op right``, or None if ``op`` is unknown or the division is not exact."""
    if op == '+':
        return left + right
    if op == '-':
        return left - right
    if op == '*':
        return left * right
    if op == '/' and right != 0 and left % right == 0:
        return left // right
    return None


def _build_chain(numbers: List[int], op_pool: List[str]) -> Optional[list]:
    """Fold ``numbers`` left to right with randomly chosen operators.

    Returns the list of ``(left, op, right, result)`` steps, or None as soon
    as no operator from ``op_pool`` yields a valid result for some operand.
    """
    steps = []
    running = numbers[0]
    for operand in numbers[1:]:
        # Shuffling a pool that holds each operator proportionally to its
        # weight and popping until an unseen operator appears yields a
        # weighted random order in which the distinct operators are tried.
        candidates = list(op_pool)
        random.shuffle(candidates)
        tried = set()
        result = None
        while candidates and result is None:
            op = candidates.pop()
            if op in tried:
                continue
            tried.add(op)
            result = _apply_op(running, op, operand)

        if result is None:
            # Every distinct operator was rejected for this operand.
            return None

        steps.append((running, op, operand, result))
        running = result
    return steps


def gen_dataset(
    num_samples: int,
    num_operands: int = 4,
    max_target: int = 999,
    min_number: int = 1,
    max_number: int = 999,
    operations: Optional[List[str]] = None,
    op_weights: Optional[dict] = None,
    small_number_ratio: float = 0.8,  # share of operands drawn from the small range
    small_range_ratio: float = 0.1,   # small range = first 10% of the full range
    seed_value: int = 42,
) -> List[dict]:
    """Generate solvable Countdown-style samples built from sequential operations.

    Each sample is produced by drawing ``num_operands`` random integers and
    combining them strictly left to right, so only expressions of the shape
    ``((a op b) op c) op d`` can be generated (never ``(a+b)/(c-d)``). Division
    is only used when it is exact, so every intermediate result is an integer.
    A candidate is kept when its final value lies in ``(0, max_target]``;
    otherwise a fresh candidate is drawn, so generation may take long if the
    constraints are hard to satisfy.

    Args:
        num_samples: Number of samples to return.
        num_operands: How many numbers each sample uses (at least 1).
        max_target: Largest accepted target value.
        min_number: Smallest operand value (inclusive).
        max_number: Largest operand value (inclusive).
        operations: Operators to choose from; defaults to ``+ - * /``.
        op_weights: Relative weight per operator; defaults to
            ``{'*': 0.2, '/': 0.7, '+': 0.05, '-': 0.05}``. Operators without
            an entry get ``1 / len(operations)``. A weight is turned into
            ``int(weight * 100)`` pool entries, so weights below 0.01
            effectively disable the operator.
        small_number_ratio: Probability that an operand is drawn from the
            small range instead of the full range.
        small_range_ratio: Fraction of ``[min_number, max_number]`` (counted
            from ``min_number``) that forms the small range.
        seed_value: Seed applied to the module-level ``random`` generator.
            Note that this reseeds the global generator.

    Returns:
        A list of dicts with the keys ``numbers`` (operands, shuffled so their
        order does not reveal the solution), ``target``,
        ``ground_truth_solution`` (expression string) and ``solution_steps``
        (list of ``(left, op, right, result)`` tuples).

    Raises:
        ValueError: If ``num_operands`` is smaller than 1, or if more than one
            operand is requested but the weighted operator pool contains no
            supported operator (generation could never succeed).
    """
    if operations is None:
        operations = ['+', '-', '*', '/']
    if op_weights is None:
        op_weights = {'*': 0.2, '/': 0.7, '+': 0.05, '-': 0.05}

    if num_operands < 1:
        raise ValueError("num_operands must be at least 1")

    # The weighted pool depends only on the arguments, so build it once.
    op_pool: List[str] = []
    for op in operations:
        weight = op_weights.get(op, 1.0 / len(operations))
        op_pool.extend([op] * int(weight * 100))

    if num_operands > 1 and not any(op in ('+', '-', '*', '/') for op in op_pool):
        raise ValueError(
            "operations/op_weights yield no usable operator; "
            "no expression could ever be generated"
        )

    random.seed(seed_value)
    samples = []

    # Upper bound of the "small" range operands are preferably drawn from.
    small_range_upper = min_number + int((max_number - min_number) * small_range_ratio)

    for _ in tqdm.tqdm(range(num_samples)):
        # Keep drawing candidates until one yields a valid, in-range target.
        while True:
            numbers = [
                random.randint(min_number, small_range_upper)
                if random.random() < small_number_ratio
                else random.randint(min_number, max_number)
                for _ in range(num_operands)
            ]

            steps = _build_chain(numbers, op_pool)
            if steps is None:
                continue

            # With a single operand there are no steps; the number is the target.
            target = steps[-1][3] if steps else numbers[0]
            if 0 < target <= max_target:
                # Shuffle so the operand order does not give away the solution.
                random.shuffle(numbers)
                samples.append({
                    "numbers": numbers,
                    "target": target,
                    "ground_truth_solution": convert_solution_to_expression(steps),
                    "solution_steps": steps,
                })
                break

    return samples

In [5]:
# Generate a single pool that is large enough to be sliced into the train,
# test and RL subsets.
samples = gen_dataset(TRAIN_DATA_NUM + TEST_DATA_NUM + RL_DATA_NUM)
# Preview the first three samples as the cell output.
samples[:3]

100%|███████████████████████████████████████████████████████████████████████████████| 2100/2100 [00:03<00:00, 622.10it/s]


[{'numbers': [29, 11, 27, 10],
  'target': 524,
  'ground_truth_solution': '((29 - 10) * 27) + 11',
  'solution_steps': [(29, '-', 10, 19),
   (19, '*', 27, 513),
   (513, '+', 11, 524)]},
 {'numbers': [34, 35, 5, 56],
  'target': 272,
  'ground_truth_solution': '((56 * 5) / 35) * 34',
  'solution_steps': [(56, '*', 5, 280), (280, '/', 35, 8), (8, '*', 34, 272)]},
 {'numbers': [34, 5, 8, 15],
  'target': 816,
  'ground_truth_solution': '((8 * 15) * 34) / 5',
  'solution_steps': [(8, '*', 15, 120),
   (120, '*', 34, 4080),
   (4080, '/', 5, 816)]}]

In [6]:
# Persist the generated pool as JSON Lines files (one JSON object per line).
# `samples` is sliced in order into: train | test | RL.
train_samples = samples[:TRAIN_DATA_NUM]
test_samples = samples[TRAIN_DATA_NUM:TRAIN_DATA_NUM + TEST_DATA_NUM]
rl_samples = samples[TRAIN_DATA_NUM + TEST_DATA_NUM:]

# Training data including the step-by-step solution.
with open(TRAIN_DATA_WITH_STEPS_PATH, "w", encoding="utf-8") as f:
    for sample in train_samples:
        f.write(json.dumps(sample) + "\n")

# All other files omit "solution_steps". The key is filtered out of a copy
# instead of being deleted from the sample, so `samples` stays intact and this
# cell can be re-run without raising KeyError.
for path, split in (
    (TRAIN_DATA_PATH, train_samples),
    (TEST_DATA_PATH, test_samples),
    (RL_DATA_PATH, rl_samples),
):
    with open(path, "w", encoding="utf-8") as f:
        for sample in split:
            record = {key: value for key, value in sample.items() if key != "solution_steps"}
            f.write(json.dumps(record) + "\n")

### 小作业

目前只能生成顺序的表达式，无法生成比如 (A+B)/(C-D) 这种表达式，请修改为可以生成这种表达式。

## 2. 生成简单一些的数据集

In [7]:
# Easier variant: three operands per sample instead of the default four.
# The pool is sized to be split into an RL part and a test part.
samples_simple = gen_dataset(RL_DATA_NUM + TEST_DATA_NUM, num_operands=3)
# Preview the first three samples as the cell output.
samples_simple[:3]

100%|██████████████████████████████████████████████████████████████████████████████| 1100/1100 [00:00<00:00, 2091.09it/s]


[{'numbers': [77, 7, 77],
  'target': 7,
  'ground_truth_solution': '(77 / 77) * 7',
  'solution_steps': [(77, '/', 77, 1), (1, '*', 7, 7)]},
 {'numbers': [57, 63, 11],
  'target': 131,
  'ground_truth_solution': '(63 + 57) + 11',
  'solution_steps': [(63, '+', 57, 120), (120, '+', 11, 131)]},
 {'numbers': [11, 15, 22],
  'target': 17,
  'ground_truth_solution': '(22 / 11) + 15',
  'solution_steps': [(22, '/', 11, 2), (2, '+', 15, 17)]}]

In [8]:
# Write the 3-operand pool as JSON Lines: the first RL_DATA_NUM samples form
# the RL file, the remainder the test file. "solution_steps" is filtered out
# of a copy rather than deleted from the sample, so `samples_simple` stays
# intact and this cell can be re-run without raising KeyError.
for path, split in (
    ("data/rl_data_simple.jsonl", samples_simple[:RL_DATA_NUM]),
    ("data/test_simple.jsonl", samples_simple[RL_DATA_NUM:]),
):
    with open(path, "w", encoding="utf-8") as f:
        for sample in split:
            record = {key: value for key, value in sample.items() if key != "solution_steps"}
            f.write(json.dumps(record) + "\n")

In [11]:
# Larger 3-operand RL set, generated with a different seed than the default.
samples_simple_10k = gen_dataset(10000, num_operands=3, seed_value=888)
with open("data/rl_data_simple_10k.jsonl", "w", encoding="utf-8") as f:
    # gen_dataset returns exactly the requested number of samples, so no
    # slicing is needed. "solution_steps" is filtered out of a copy rather
    # than deleted, which keeps `samples_simple_10k` intact.
    for sample in samples_simple_10k:
        record = {key: value for key, value in sample.items() if key != "solution_steps"}
        f.write(json.dumps(record) + "\n")

100%|████████████████████████████████████████████████████████████████████████████| 10000/10000 [00:05<00:00, 1859.08it/s]


## 3. （可选）使用现成的数据集

从 Jiayi-Pan/Countdown-Tasks-3to4 或者类似的数据集中抽取，这些数据集确保结果可解，所以无需计算过程。

此处略
