# 剔除重复数据->数据分类

In [1]:
"""
剔除重复数据->数据分类
"""

import json
import re
import random
import os

random.seed(42)

file_path = "/home/ma-user/work/train.json"
data_list = []
with open(file_path, 'r', encoding='utf-8') as file:
    for line in file:
        data = json.loads(line.strip())
        problem = data['problem']
        solution = data['solution']
        data_list.append(problem + "#分隔符#" + solution)

data_list = set(data_list)  # 剔除重复数据
print(len(data_list))  # 548942
dataset = []
for data in data_list:
    problem, solution = data.split("#分隔符#")
    dataset.append({"problem": problem, "solution": solution})

en_data = []  # en
calculate_data = []  # cn_计算
calculate_add_data = []
calculate_sub_data = []
calculate_mul_data = []
calculate_div_data = []
calculate_sqrt_data = []
calculate_pow_data = []
equation_data = []  # cn_解方程
product_data = []  # cn_商品
average_data = []  # cn_求平均值
quality_data = []  # cn_计算物体质量
area_data = []  # cn_计算面积
sales_data = []  # cn_计算销售额
simplify_data = []  # cn_进行简化
function_data = [] # cn_求函数

# 匹配加法、减法、乘法、除法、平方根和乘方
add_pattern = re.compile(r'(\d+\.?\d*)\+(-?\d+\.?\d*)')
sub_pattern = re.compile(r'(\d+\.?\d*)-(-?\d+\.?\d*)')
mul_pattern = re.compile(r'(\d+\.?\d*)\*(-?\d+\.?\d*)')
div_pattern = re.compile(r'(\d+\.?\d*)/(-?\d+\.?\d*)')
sqrt_pattern = re.compile(r'(\d+\.?\d*)的平方根')
pow_pattern = re.compile(r'(\d+\.?\d*)的(\d+)次方')

for item in dataset:
    question = item['problem']
    if not re.match(r'^[\u4e00-\u9fff]', question):
        en_data.append(item)
    elif question.startswith('计算'):
        calculate_data.append(item)
        new_item = item["problem"].replace(' ', '')
        if match := add_pattern.search(new_item):
            calculate_add_data.append(item)
        if match := sub_pattern.search(new_item):
            calculate_sub_data.append(item)
        if match := mul_pattern.search(new_item):
            calculate_mul_data.append(item)
        if match := div_pattern.search(new_item):
            calculate_div_data.append(item)
        if match := sqrt_pattern.search(new_item):
            calculate_sqrt_data.append(item)
        if match := pow_pattern.search(new_item):
            calculate_pow_data.append(item)

    elif question.startswith('解方程'):
        equation_data.append(item)
    elif question.startswith('商品'):
        product_data.append(item)
    elif question.startswith('求以下数据的平均值'):
        average_data.append(item)
    elif '请计算该物体的质量。' in question:
        quality_data.append(item)
    elif '请计算其面积。' in question:
        area_data.append(item)
    elif '请计算今年的销售额。' in question:
        sales_data.append(item)
    elif '进行简化。' in question:
        simplify_data.append(item)
    elif re.match(r'^当.*时，求函数.*', question):
        function_data.append(item)

print(len(en_data))  # 9993
print(len(calculate_data))  # 437781
print(len(calculate_add_data))  # 90113
print(len(calculate_sub_data))  # 89366
print(len(calculate_mul_data))  # 89902
print(len(calculate_div_data))  # 89966
print(len(calculate_sqrt_data))  # 39227
print(len(calculate_pow_data))  # 39207
print(len(equation_data))  # 40080
print(len(product_data))  # 6275
print(len(average_data))  # 19709
print(len(quality_data))  # 100
print(len(area_data))  # 8665
print(len(sales_data))  # 6266
print(len(simplify_data))  # 45
print(len(function_data))  # 20028

random.shuffle(en_data)
random.shuffle(calculate_data)
random.shuffle(calculate_add_data)
random.shuffle(calculate_sub_data)
random.shuffle(calculate_mul_data)
random.shuffle(calculate_div_data)
random.shuffle(calculate_sqrt_data)
random.shuffle(calculate_pow_data)
random.shuffle(equation_data)
random.shuffle(product_data)
random.shuffle(average_data)
random.shuffle(quality_data)
random.shuffle(area_data)
random.shuffle(sales_data)
random.shuffle(simplify_data)
random.shuffle(function_data)

os.makedirs('/home/ma-user/work/dataset/', exist_ok=True)

with open('/home/ma-user/work/dataset/en.json', 'w', encoding='utf-8') as file:
    for item in en_data:
        json.dump(item, file, ensure_ascii=False)
        file.write('\n')

with open('/home/ma-user/work/dataset/calculate.json', 'w', encoding='utf-8') as file:
    for item in calculate_data:
        json.dump(item, file, ensure_ascii=False)
        file.write('\n')

with open('/home/ma-user/work/dataset/calculate_add.json', 'w', encoding='utf-8') as file:
    for item in calculate_add_data:
        json.dump(item, file, ensure_ascii=False)
        file.write('\n')

with open('/home/ma-user/work/dataset/calculate_sub.json', 'w', encoding='utf-8') as file:
    for item in calculate_sub_data:
        json.dump(item, file, ensure_ascii=False)
        file.write('\n')

with open('/home/ma-user/work/dataset/calculate_mul.json', 'w', encoding='utf-8') as file:
    for item in calculate_mul_data:
        json.dump(item, file, ensure_ascii=False)
        file.write('\n')

with open('/home/ma-user/work/dataset/calculate_div.json', 'w', encoding='utf-8') as file:
    for item in calculate_div_data:
        json.dump(item, file, ensure_ascii=False)
        file.write('\n')

with open('/home/ma-user/work/dataset/calculate_sqrt.json', 'w', encoding='utf-8') as file:
    for item in calculate_sqrt_data:
        json.dump(item, file, ensure_ascii=False)
        file.write('\n')

with open('/home/ma-user/work/dataset/calculate_pow.json', 'w', encoding='utf-8') as file:
    for item in calculate_pow_data:
        json.dump(item, file, ensure_ascii=False)
        file.write('\n')

with open('/home/ma-user/work/dataset/equation.json', 'w', encoding='utf-8') as file:
    for item in equation_data:
        json.dump(item, file, ensure_ascii=False)
        file.write('\n')

with open('/home/ma-user/work/dataset/product.json', 'w', encoding='utf-8') as file:
    for item in product_data:
        json.dump(item, file, ensure_ascii=False)
        file.write('\n')

with open('/home/ma-user/work/dataset/average.json', 'w', encoding='utf-8') as file:
    for item in average_data:
        json.dump(item, file, ensure_ascii=False)
        file.write('\n')

with open('/home/ma-user/work/dataset/quality.json', 'w', encoding='utf-8') as file:
    for item in quality_data:
        json.dump(item, file, ensure_ascii=False)
        file.write('\n')

with open('/home/ma-user/work/dataset/area.json', 'w', encoding='utf-8') as file:
    for item in area_data:
        json.dump(item, file, ensure_ascii=False)
        file.write('\n')

with open('/home/ma-user/work/dataset/sales.json', 'w', encoding='utf-8') as file:
    for item in sales_data:
        json.dump(item, file, ensure_ascii=False)
        file.write('\n')

with open('/home/ma-user/work/dataset/simplify.json', 'w', encoding='utf-8') as file:
    for item in simplify_data:
        json.dump(item, file, ensure_ascii=False)
        file.write('\n')

with open('/home/ma-user/work/dataset/function.json', 'w', encoding='utf-8') as file:
    for item in function_data:
        json.dump(item, file, ensure_ascii=False)
        file.write('\n')


548942
9993
437781
90113
89366
89902
89966
39227
39207
40080
6275
19709
100
8665
6266
45
20028


# 分类构建 CoT 数据集

In [2]:
"""
area_dataset
"""

def nums_extract(text):
    # 使用正则表达式提取所有数字
    numbers_text = re.findall(r'\d+', text)
    # 转换为整数列表
    numbers = [int(num) for num in numbers_text]
    return numbers_text, numbers


path = "/home/ma-user/work/dataset/area.json"
output_path = "/home/ma-user/work/dataset/cot_area.json"
data_list = []
with open(path, 'r', encoding='utf-8') as file:
    for line in file:
        data = json.loads(line.strip())
        problem = data['problem']
        solution = data['solution']

        problem_numbers_text, problem_numbers = nums_extract(problem)

        solution_number_text, solution_number = nums_extract(solution)
        
        multiply_text = " * ".join(problem_numbers_text) + " = "
        multiply_result = problem_numbers[0] * problem_numbers[1]

        
        if multiply_result != solution_number[0]:
            print(problem)

        cot_solution = "面积 = 长 * 宽 = " + multiply_text + solution_number_text[0] + "\n因此，" + solution
        # print(cot_solution)

        cot_data = {"problem": problem, "solution": cot_solution}
        data_list.append(cot_data)

with open(output_path, 'w', encoding='utf-8') as file:
    for item in data_list:
        json.dump(item, file, ensure_ascii=False)
        file.write('\n')



In [3]:
"""
average_dataset
结果保留两位小数
"""

def nums_extract(text):
    # 使用正则表达式提取所有数字
    numbers_text = re.findall(r'\d+', text)
    # 转换为整数列表
    numbers = [int(num) for num in numbers_text]
    return numbers_text, numbers


path = "/home/ma-user/work/dataset/average.json"
output_path = "/home/ma-user/work/dataset/cot_average.json"
data_list = []
with open(path, 'r', encoding='utf-8') as file:
    for line in file:
        data = json.loads(line.strip())
        problem = data['problem']
        solution = data['solution']

        problem_numbers_text, problem_numbers = nums_extract(problem)
        problem_numbers_len = len(problem_numbers)

        solution_number_text = solution.split("平均值为 ")[1]
        solution_number = round(float(solution_number_text), 2)
        
        sum_text = " + ".join(problem_numbers_text) + " = "
        sum_result = str(sum(problem_numbers))

        if round(sum(problem_numbers) / problem_numbers_len, 2) != round(solution_number, 2):
            print(problem)

        cot_solution = "平均值 = 数据之和 / 数据个数\n首先，计算所有数据的总和：" + sum_text + sum_result + "\n" + "一共有 " + str(problem_numbers_len) + " 个数据\n" + "数据之和 / 数据个数：" + sum_result + " / " + str(problem_numbers_len) + " = " + str(solution_number) + "\n因此，平均值为 " + str(solution_number)
        # print(cot_solution)

        cot_data = {"problem": problem, "solution": cot_solution}
        data_list.append(cot_data)

with open(output_path, 'w', encoding='utf-8') as file:
    for item in data_list:
        json.dump(item, file, ensure_ascii=False)
        file.write('\n')


In [4]:
"""
equation_dataset
结果保留两位小数
"""

path = "/home/ma-user/work/dataset/equation.json"
output_path = "/home/ma-user/work/dataset/cot_equation.json"
data_list = []
with open(path, 'r', encoding='utf-8') as file:
    for line in file:
        data = json.loads(line.strip())
        problem = data['problem']
        solution = data['solution']

        s1 = int(problem.split("解方程 ")[1].split("x")[0])
        s2 = int(problem.split("x + ")[1].split(" = ")[0])
        result = round(float(solution.split("方程的解为：")[1]), 2)

        if round(-s2 / s1, 2) != round(result, 2):
            print(problem)

        cot_solution = problem.split("解方程 ")[1] + "\n" + str(s1) + "x = " + str(-s2) + "\nx = " + str(-s2) + " / " + str(s1) + " = " + str(result) + "\n因此，方程的解为：" + str(result)
        # print(cot_solution)

        cot_data = {"problem": problem, "solution": cot_solution}
        data_list.append(cot_data)

with open(output_path, 'w', encoding='utf-8') as file:
    for item in data_list:
        json.dump(item, file, ensure_ascii=False)
        file.write('\n')

In [5]:
"""
function_dataset
"""

path = "/home/ma-user/work/dataset/function.json"
output_path = "/home/ma-user/work/dataset/cot_function.json"
data_list = []
with open(path, 'r', encoding='utf-8') as file:
    for line in file:
        data = json.loads(line.strip())
        problem = data['problem']
        solution = data['solution']

        x = float(problem.split("当 x = ")[1].split(" 时")[0])
        s1 = int(problem.split("求函数 y = ")[1].split("x^")[0])
        s2 = int(problem.split("x^")[1].split(" 的值")[0])
        result = float(solution.split("函数的值为：")[1])


        cot_solution = problem.split("，")[0] + "\n" + problem.split("求函数 ")[1].split(" 的值")[0] + " = " + str(s1) + " * " + str(x) + "^" + str(s2) + " = " + solution.split("函数的值为：")[1] + "\n因此，" + solution

        cot_data = {"problem": problem, "solution": cot_solution}
        data_list.append(cot_data)

with open(output_path, 'w', encoding='utf-8') as file:
    for item in data_list:
        json.dump(item, file, ensure_ascii=False)
        file.write('\n')

In [6]:
"""
product_dataset
保留两位小数
"""

def nums_extract(text):
    # 使用正则表达式提取所有数字
    numbers_text = re.findall(r'\d+', text)
    # 转换为整数列表
    numbers = [int(num) for num in numbers_text]
    return numbers_text, numbers

path = "/home/ma-user/work/dataset/product.json"
output_path = "/home/ma-user/work/dataset/cot_product.json"
data_list = []
with open(path, 'r', encoding='utf-8') as file:
    for line in file:
        data = json.loads(line.strip())
        problem = data['problem']
        solution = data['solution']
        result = round(float(solution), 2)

        problem_numbers_text, problem_numbers = nums_extract(problem)
        dec_result = problem_numbers[0] - problem_numbers[1]

        if round(dec_result / problem_numbers[0] * 100, 2) != round(float(solution), 2):
            print(problem)

        cot_solution = "折扣比例 = (1 - 价格 / 原价) * 100% = (1 - " + problem_numbers_text[1] + " / " + problem_numbers_text[0] + ") * 100% = " + str(result) + "%\n因此，折扣比例为 " + str(result)

        cot_data = {"problem": problem, "solution": cot_solution}
        data_list.append(cot_data)

with open(output_path, 'w', encoding='utf-8') as file:
    for item in data_list:
        json.dump(item, file, ensure_ascii=False)
        file.write('\n')

In [7]:
"""
quality_dataset
"""

def nums_extract(text):
    # 使用正则表达式提取所有数字
    numbers_text = re.findall(r'\d+', text)
    # 转换为整数列表
    numbers = [int(num) for num in numbers_text]
    return numbers_text, numbers

path = "/home/ma-user/work/dataset/quality.json"
output_path = "/home/ma-user/work/dataset/cot_quality.json"
data_list = []
with open(path, 'r', encoding='utf-8') as file:
    for line in file:
        data = json.loads(line.strip())
        problem = data['problem']
        solution = data['solution']

        problem_numbers_text, problem_numbers = nums_extract(problem)
        solution_numbers_text, solution_numbers = nums_extract(solution)

        if problem_numbers[0] * problem_numbers[1] != solution_numbers[0]:
            print(problem)

        cot_solution = "质量 = 密度 * 体积 = " + problem_numbers_text[0] + " * " + problem_numbers_text[1] + " = " + solution_numbers_text[0] + "\n因此，该物体的质量为 " + solution

        cot_data = {"problem": problem, "solution": cot_solution}
        data_list.append(cot_data)

with open(output_path, 'w', encoding='utf-8') as file:
    for item in data_list:
        json.dump(item, file, ensure_ascii=False)
        file.write('\n')

In [8]:
"""
sales_dataset
保留两位小数
"""

def nums_extract(text):
    # 使用正则表达式提取所有数字
    numbers_text = re.findall(r'\d+', text)
    # 转换为整数列表
    numbers = [int(num) for num in numbers_text]
    return numbers_text, numbers

path = "/home/ma-user/work/dataset/sales.json"
output_path = "/home/ma-user/work/dataset/cot_sales.json"
data_list = []
with open(path, 'r', encoding='utf-8') as file:
    for line in file:
        data = json.loads(line.strip())
        problem = data['problem']
        solution = data['solution']
        result = round(float(solution), 2)

        problem_numbers_text, problem_numbers = nums_extract(problem)
        mid_result = round(float(1 + problem_numbers[1] / 100), 2)

        if round(problem_numbers[0] * mid_result, 2) != round(float(solution), 2):
            print(problem)

        cot_solution = "今年销售额 = 去年销售额 * (1 + 增加比例) = " + problem_numbers_text[0] + " * (1 + " + problem_numbers_text[1] + "%) = " + problem_numbers_text[0] + " * " + str(mid_result) + " = " + str(result) + "\n因此，今年的销售额为 " + str(result) + " 万元"

        cot_data = {"problem": problem, "solution": cot_solution}
        data_list.append(cot_data)

with open(output_path, 'w', encoding='utf-8') as file:
    for item in data_list:
        json.dump(item, file, ensure_ascii=False)
        file.write('\n')

In [9]:
"""
simplify_dataset
"""

def nums_extract(text):
    # 使用正则表达式提取所有数字
    numbers_text = re.findall(r'\d+', text)
    # 转换为整数列表
    numbers = [int(num) for num in numbers_text]
    return numbers_text, numbers

path = "/home/ma-user/work/dataset/simplify.json"
output_path = "/home/ma-user/work/dataset/cot_simplify.json"
data_list = []
with open(path, 'r', encoding='utf-8') as file:
    for line in file:
        data = json.loads(line.strip())
        problem = data['problem']
        solution = data['solution']

        problem_text = problem.split("将分数 ")[1].split(" 进行简化。")[0]
        result = solution.split("最简化的形式为：")[1]

        cot_solution = problem_text + " = " + problem_text + "，" + solution

        cot_data = {"problem": problem, "solution": cot_solution}
        data_list.append(cot_data)

with open(output_path, 'w', encoding='utf-8') as file:
    for item in data_list:
        json.dump(item, file, ensure_ascii=False)
        file.write('\n')

# 后续手动改几个化简结果

In [10]:
"""
calculate_add_dataset
"""

def nums_extract(text):
    # 使用正则表达式提取所有数字
    pattern = r'-?\d+\.?\d*'
    numbers_text = re.findall(pattern, text)
    numbers = [float(num) for num in numbers_text]
    return numbers

path = "/home/ma-user/work/dataset/calculate_add.json"
output_path = "/home/ma-user/work/dataset/calculate_add.json"
data_list = []
with open(path, 'r', encoding='utf-8') as file:
    for line in file:
        data = json.loads(line.strip())
        problem = data['problem']
        solution = data['solution']
        result = solution.split('= ')[1]

        s1, s2 = nums_extract(problem)
        if s1 >= 0 and s2 >= 0:
            cot_data = {"problem": problem, "solution": solution}
        elif s1 >= 0 and s2 < 0:
            cot_solution = str(s1) + " + " + str(s2) + " = " + str(s1) + " - " + str(-s2) + " = " + result
            cot_data = {"problem": problem, "solution": cot_solution}
        elif s1 < 0 and s2 >= 0:
            cot_solution = str(s1) + " + " + str(s2) + " = " + str(s2) + " - " + str(-s1) + " = " + result
            cot_data = {"problem": problem, "solution": cot_solution}
        elif s1 < 0 and s2 < 0:
            cot_solution = str(s1) + " + " + str(s2) + " = -(" + str(-s1) + " + " + str(-s2) + ") = " + result
            cot_data = {"problem": problem, "solution": cot_solution}

        data_list.append(cot_data)

with open(output_path, 'w', encoding='utf-8') as file:
    for item in data_list:
        json.dump(item, file, ensure_ascii=False)
        file.write('\n')

In [11]:
"""
calculate_sub_dataset
"""

def nums_extract(text):
    # 使用正则表达式提取所有数字
    pattern = r'-?\d+\.?\d*'
    numbers_text = re.findall(pattern, text)
    numbers = [float(num) for num in numbers_text]
    return numbers

path = "/home/ma-user/work/dataset/calculate_sub.json"
output_path = "/home/ma-user/work/dataset/calculate_sub.json"
data_list = []
with open(path, 'r', encoding='utf-8') as file:
    for line in file:
        data = json.loads(line.strip())
        problem = data['problem']
        solution = data['solution']
        result = solution.split('= ')[1]

        s1, s2 = nums_extract(problem)
        if s1 >= 0 and s2 >= 0:
            cot_data = {"problem": problem, "solution": solution}
        elif s1 >= 0 and s2 < 0:
            cot_solution = str(s1) + " - " + str(s2) + " = " + str(s1) + " + " + str(-s2) + " = " + result
            cot_data = {"problem": problem, "solution": cot_solution}
        elif s1 < 0 and s2 >= 0:
            cot_solution = str(s1) + " - " + str(s2) + " = -(" + str(-s1) + " + " + str(s2) + ") = " + result
            cot_data = {"problem": problem, "solution": cot_solution}
        elif s1 < 0 and s2 < 0:
            cot_solution = str(s1) + " - " + str(s2) + " = " + str(-s2) + " - " + str(-s1) + " = " + result
            cot_data = {"problem": problem, "solution": cot_solution}

        data_list.append(cot_data)

with open(output_path, 'w', encoding='utf-8') as file:
    for item in data_list:
        json.dump(item, file, ensure_ascii=False)
        file.write('\n')

In [12]:
"""
calculate_mul_dataset
保留两位小数
"""

path = "/home/ma-user/work/dataset/calculate_mul.json"
output_path = "/home/ma-user/work/dataset/calculate_mul.json"
data_list = []
with open(path, 'r', encoding='utf-8') as file:
    for line in file:
        data = json.loads(line.strip())
        problem = data['problem']
        solution = data['solution']
        result = round(float(solution.split('= ')[1]), 2)

        cot_solution = solution.split('= ')[0] + "= " + str(result)
        cot_data = {"problem": problem, "solution": cot_solution}
        data_list.append(cot_data)

with open(output_path, 'w', encoding='utf-8') as file:
    for item in data_list:
        json.dump(item, file, ensure_ascii=False)
        file.write('\n')

In [13]:
"""
calculate_div_dataset
保留两位小数
"""

path = "/home/ma-user/work/dataset/calculate_div.json"
output_path = "/home/ma-user/work/dataset/calculate_div.json"
data_list = []
with open(path, 'r', encoding='utf-8') as file:
    for line in file:
        data = json.loads(line.strip())
        problem = data['problem']
        solution = data['solution']
        result = round(float(solution.split('= ')[1]), 2)

        cot_solution = solution.split('= ')[0] + "= " + str(result)
        cot_data = {"problem": problem, "solution": cot_solution}
        data_list.append(cot_data)

with open(output_path, 'w', encoding='utf-8') as file:
    for item in data_list:
        json.dump(item, file, ensure_ascii=False)
        file.write('\n')

In [14]:
"""
calculate_sqrt_dataset
保留两位小数
"""

path = "/home/ma-user/work/dataset/calculate_sqrt.json"
output_path = "/home/ma-user/work/dataset/calculate_sqrt.json"
data_list = []
with open(path, 'r', encoding='utf-8') as file:
    for line in file:
        data = json.loads(line.strip())
        problem = data['problem']
        solution = data['solution']
        result = round(float(solution.split('= ')[1]), 2)

        cot_solution = solution.split('= ')[0] + "= " + str(result)
        cot_data = {"problem": problem, "solution": cot_solution}
        data_list.append(cot_data)

with open(output_path, 'w', encoding='utf-8') as file:
    for item in data_list:
        json.dump(item, file, ensure_ascii=False)
        file.write('\n')

In [15]:
"""
calculate_pow_dataset
保留两位小数
"""

path = "/home/ma-user/work/dataset/calculate_pow.json"
output_path = "/home/ma-user/work/dataset/calculate_pow.json"
data_list = []
with open(path, 'r', encoding='utf-8') as file:
    for line in file:
        data = json.loads(line.strip())
        problem = data['problem']
        solution = data['solution']
        result = round(float(solution.split('= ')[1]), 2)

        cot_solution = solution.split('= ')[0] + "= " + str(result)
        cot_data = {"problem": problem, "solution": cot_solution}
        data_list.append(cot_data)

with open(output_path, 'w', encoding='utf-8') as file:
    for item in data_list:
        json.dump(item, file, ensure_ascii=False)
        file.write('\n')

# 采样构建训练集&验证集

In [17]:
import json
import re
import random

random.seed(42)

train_path = "/home/ma-user/work/train_dataset.json"
val_path = "/home/ma-user/work/val_dataset.json"
area_path = "/home/ma-user/work/dataset/cot_area.json"
average_path = "/home/ma-user/work/dataset/cot_average.json"
calculate_add_path = "/home/ma-user/work/dataset/cot_calculate_add.json"
calculate_sub_path = "/home/ma-user/work/dataset/cot_calculate_sub.json"
calculate_mul_path = "/home/ma-user/work/dataset/cot_calculate_mul.json"
calculate_div_path = "/home/ma-user/work/dataset/cot_calculate_div.json"
calculate_sqrt_path = "/home/ma-user/work/dataset/cot_calculate_sqrt.json"
calculate_pow_path = "/home/ma-user/work/dataset/cot_calculate_pow.json"
en_path = "/home/ma-user/work/dataset/en.json"
equation_path = "/home/ma-user/work/dataset/cot_equation.json"
function_path = "/home/ma-user/work/dataset/cot_function.json"
product_path = "/home/ma-user/work/dataset/cot_product.json"
quality_path = "/home/ma-user/work/dataset/cot_quality.json"
sales_path = "/home/ma-user/work/dataset/cot_sales.json"
simplify_path = "/home/ma-user/work/dataset/cot_simplify.json"

# Sampled with a 6000-line train budget (the next 100 lines go to val).
dataset_path_list_1 = [area_path, calculate_add_path, calculate_sub_path, calculate_mul_path, calculate_div_path, calculate_sqrt_path, calculate_pow_path, en_path, product_path, sales_path]

# Sampled with a 12000-line train budget (the next 100 lines go to val).
dataset_path_list_2 = [average_path, equation_path, function_path]

train_list = []
val_list = []

# Each file was already shuffled upstream, so taking its head is a random sample.
for path_group, train_quota in ((dataset_path_list_1, 6000), (dataset_path_list_2, 12000)):
    val_end = train_quota + 100
    for source_path in path_group:
        with open(source_path, 'r', encoding='utf-8') as file:
            for line_no, line in enumerate(file, start=1):
                record = json.loads(line.strip())
                if line_no <= train_quota:
                    train_list.append(record)
                elif line_no <= val_end:
                    val_list.append(record)

# Tiny categories: every record is repeated 10x in train and also put in val.
# NOTE(review): val overlaps train for these two categories, so their
# validation scores measure memorisation rather than generalisation.
for source_path in (quality_path, simplify_path):
    with open(source_path, 'r', encoding='utf-8') as file:
        for line in file:
            record = json.loads(line.strip())
            train_list.extend([record] * 10)
            val_list.append(record)

# NOTE(review): with the counts printed by the classification step
# (13 x 100 + 100 quality + 45 simplify = 1445), this cut silently drops the
# last 5 simplify records from val. Val keeps file order (it is not shuffled).
val_list = val_list[:1440]
random.shuffle(train_list)

# One JSON object per line.
for out_path, records in ((train_path, train_list), (val_path, val_list)):
    with open(out_path, 'w', encoding='utf-8') as file:
        for item in records:
            json.dump(item, file, ensure_ascii=False)
            file.write('\n')