# iGSM 生成中文数据的版本

# In [None]:
import json
from data_gen.pretrain.id_gen import IdGen
from tools.tools import tokenizer, fix_seed
from typing import Literal
import math

def get_prob_sol_ans_triple(tpy: Literal["med", "hard"]):
    """
    Build an IdGen and keep calling ``gen_prob`` on it until one call succeeds.

    :param tpy: Difficulty selector, either 'med' or 'hard'. It picks the
        ``max_op`` / ``max_edge`` limits handed to ``IdGen``
        ('med' -> 15 / 20, 'hard' -> 21 / 28).
    :return: The ``IdGen`` instance after a successful ``gen_prob`` call.
        Callers in this file read its ``prob_token``, ``sol_token`` and
        ``ans_token`` attributes.
    :raises AssertionError: If ``tpy`` is not 'med' or 'hard'.
    """
    assert tpy in ("med", "hard"), "Invalid type: Choose 'med' or 'hard'"

    # Size limits depend on the requested difficulty.
    if tpy == "med":
        op_limit, edge_limit = 15, 20
    else:
        op_limit, edge_limit = 21, 28

    generator = IdGen(
        max_op=op_limit,
        max_edge=edge_limit,
        perm_level=5,
        detail_level=0,
    )

    # Generation can fail; any Exception is swallowed and the same generator
    # is retried, with no upper bound on the number of attempts.
    while True:
        try:
            generator.gen_prob([0], p_format="pq")
        except Exception:
            continue
        return generator

def generate_and_save_to_jsonl(num_problems: int, difficulty: Literal["med", "hard"], output_file: str, seed: int = 42, batch_size: int = 1000):
    """
    Generate problems and write them to a JSONL file in batches.

    Each output line is one JSON object with the keys ``id``, ``problem``,
    ``solution`` and ``answer``. The three text fields are the decoded and
    stripped ``prob_token`` / ``sol_token`` / ``ans_token`` of a generated
    ``IdGen``.

    :param num_problems: Total number of problems to generate.
    :param difficulty: Problem difficulty ('med' or 'hard'), forwarded to
        ``get_prob_sol_ans_triple``.
    :param output_file: Path of the .jsonl file to write. Any existing
        content is overwritten.
    :param seed: Value passed to ``fix_seed`` before generation starts.
    :param batch_size: Number of problems buffered in memory before they are
        written out and a progress line is printed. The last, possibly
        partial, batch is always written. A value below 1 never matches the
        buffer length, so everything is written in one go at the end.
    """
    fix_seed(seed)

    data_buffer = []  # problems waiting to be written
    total_generated = 0

    # One handle in 'w' mode is sufficient: the file is truncated once when
    # it is opened, and every later write on the same handle appends after
    # the previous one.
    with open(output_file, 'w', encoding='utf-8') as f:
        for i in range(num_problems):
            # 1. Generate one problem instance.
            id_gen = get_prob_sol_ans_triple(difficulty)

            # 2. Decode the token sequences to text and buffer the record.
            data_buffer.append({
                "id": f"{difficulty}_{i}",
                "problem": tokenizer.decode(id_gen.prob_token, skip_special_tokens=True).strip(),
                "solution": tokenizer.decode(id_gen.sol_token, skip_special_tokens=True).strip(),
                "answer": tokenizer.decode(id_gen.ans_token, skip_special_tokens=True).strip(),
            })

            # 3. Write out when the batch is full or this was the last problem.
            if len(data_buffer) == batch_size or (i + 1) == num_problems:
                f.writelines(
                    json.dumps(item, ensure_ascii=False) + '\n'
                    for item in data_buffer
                )
                # Push the batch out of Python's buffer so the "saved"
                # progress message below is accurate even if a later
                # generation step is interrupted.
                f.flush()

                total_generated += len(data_buffer)
                print(f"Generated and saved {total_generated}/{num_problems} problems...")

                data_buffer = []

    print(f"\nSuccessfully generated {num_problems} problems and saved to '{output_file}'.")


# ==============================
#       Script entry point
# ==============================
if __name__ == "__main__":
    # --- Run configuration ---
    OUTPUT_FILENAME = "gsm_dataset_large.jsonl"  # destination JSONL file
    DIFFICULTY_LEVEL = "med"  # either 'med' or 'hard'
    NUM_TO_GENERATE = 5000  # total number of problems to produce
    BATCH_SIZE = 1000  # problems per write / progress message

    # --- Run generation (seed is left at the function's default) ---
    generate_and_save_to_jsonl(
        NUM_TO_GENERATE,
        DIFFICULTY_LEVEL,
        OUTPUT_FILENAME,
        batch_size=BATCH_SIZE,
    )