In [5]:
from transformers import AutoTokenizer
from transformers.models.auto.tokenization_auto import TOKENIZER_MAPPING

# Print every model config registered in TOKENIZER_MAPPING together with the
# tuple of tokenizer classes mapped to it (slow class or None, then fast class).
for config_cls, tokenizer_pair in TOKENIZER_MAPPING.items():
    print(f"{config_cls.__name__}: {tokenizer_pair}")

AlbertConfig: (None, <class 'transformers.models.albert.tokenization_albert_fast.AlbertTokenizerFast'>)
AlignConfig: (<class 'transformers.models.bert.tokenization_bert.BertTokenizer'>, <class 'transformers.models.bert.tokenization_bert_fast.BertTokenizerFast'>)
AriaConfig: (<class 'transformers.utils.dummy_sentencepiece_objects.LlamaTokenizer'>, <class 'transformers.models.llama.tokenization_llama_fast.LlamaTokenizerFast'>)
AyaVisionConfig: (None, <class 'transformers.models.cohere.tokenization_cohere_fast.CohereTokenizerFast'>)
BarkConfig: (<class 'transformers.models.bert.tokenization_bert.BertTokenizer'>, <class 'transformers.models.bert.tokenization_bert_fast.BertTokenizerFast'>)
BartConfig: (<class 'transformers.models.bart.tokenization_bart.BartTokenizer'>, <class 'transformers.models.bart.tokenization_bart_fast.BartTokenizerFast'>)
BertConfig: (<class 'transformers.models.bert.tokenization_bert.BertTokenizer'>, <class 'transformers.models.bert.tokenization_bert_fast.BertTokeniz

In [1]:
import json
import random
import re
import os
from typing import Literal

def set_global_mod_mode(use_mod: bool):
    """
    Persist the global USE_MOD switch by rewriting ``const/params.py`` in place.

    The path is resolved relative to the current working directory, so this must be
    called from the iGSM project root. The first line that assigns ``USE_MOD`` is
    replaced, keeping that line's indentation. If no such line exists, a new
    assignment is appended to the end of the file.

    If ``const.params`` is already imported in this interpreter (for example from
    an earlier run in the same notebook kernel), its ``USE_MOD`` attribute is updated
    as well. Otherwise a later ``from const import params`` would return the cached
    module with the stale value. Modules that copied the value with
    ``from const.params import USE_MOD`` are not refreshed; restart the kernel for those.

    Args:
        use_mod: value written verbatim as ``USE_MOD = <use_mod>``; intended to be a bool.

    Raises:
        FileNotFoundError: if ``const/params.py`` does not exist under the cwd.
    """
    import sys  # only needed to sync an already-imported const.params below

    params_path = os.path.join('const', 'params.py')

    try:
        with open(params_path, 'r', encoding='utf-8') as f:
            lines = f.readlines()
    except FileNotFoundError as err:
        # Raise rather than call exit(): exit() is an interactive-shell helper injected
        # by the `site` module, and a helper function should let its caller (or the
        # notebook cell) see a normal traceback instead of terminating the process.
        raise FileNotFoundError(
            f"错误: 无法找到 '{params_path}'。请确保脚本在 iGSM 项目根目录下运行。"
        ) from err

    # Match a real assignment (the `(?!=)` rejects `USE_MOD == ...`) and capture its
    # indentation, so rewriting an indented assignment cannot break params.py.
    assign_re = re.compile(r'^(\s*)USE_MOD\s*=(?!=)')
    for i, line in enumerate(lines):
        match = assign_re.match(line)
        if match:
            lines[i] = f"{match.group(1)}USE_MOD = {use_mod}\n"
            break
    else:
        # No existing assignment: append one (leading newline guards against a
        # last line that lacks a trailing newline).
        lines.append(f"\nUSE_MOD = {use_mod}\n")

    with open(params_path, 'w', encoding='utf-8') as f:
        f.writelines(lines)

    # Keep an already-imported const.params in sync with the file just written.
    cached_params = sys.modules.get('const.params')
    if cached_params is not None:
        cached_params.USE_MOD = use_mod


def run_generation(num_problems, difficulty, output_file, seed=42, batch_size=1000,
                   max_attempts=10_000):
    """
    Generate ``num_problems`` iGSM problems and write them to ``output_file`` as JSON Lines.

    Each line is ``{"id", "problem", "solution", "answer"}``. The id is
    ``f"{difficulty}_{seed}_{i}"``, and the other fields are the decoded, stripped
    token sequences produced by ``IdGen.gen_prob``.

    Project modules are imported inside this function, so they load only when it
    runs (in the ``__main__`` block that is after ``set_global_mod_mode``). In a
    kernel where they were already imported, Python reuses the cached modules.

    Args:
        num_problems: number of problems to generate.
        difficulty: ``"med"`` (max_op=15, max_edge=20) or ``"hard"`` (max_op=21, max_edge=28).
        output_file: destination path; it is overwritten.
        seed: passed to ``fix_seed`` and embedded in every record id.
        batch_size: number of records buffered before they are written and a
            progress line is printed.
        max_attempts: maximum consecutive ``gen_prob`` attempts per problem before
            giving up; ``None`` retries forever.

    Raises:
        ValueError: for an unknown ``difficulty`` or a non-positive ``max_attempts``
            (checked before ``output_file`` is opened, so an existing file is untouched).
        RuntimeError: if ``gen_prob`` fails ``max_attempts`` times in a row; the last
            underlying exception is chained as ``__cause__``.
    """
    # Validate before any side effect: opening output_file with 'w' truncates it.
    if difficulty not in ("med", "hard"):
        raise ValueError(f"Invalid difficulty {difficulty!r}: choose 'med' or 'hard'")
    if max_attempts is not None and max_attempts < 1:
        raise ValueError(f"max_attempts must be >= 1 or None, got {max_attempts!r}")

    from data_gen.pretrain.id_gen import IdGen
    from tools.tools import tokenizer, fix_seed
    from const import params

    fix_seed(seed)
    data_buffer, total_generated = [], 0

    print(f"开始生成 {num_problems} 个问题...")
    print(f"难度: {difficulty}")
    print(f"文件名: {output_file}")
    print(f"验证: 当前运行时模式 USE_MOD = {params.USE_MOD}")
    print("-" * 30)

    def get_prob_sol_ans_triple(tpy: Literal["med", "hard"]):
        """Build an IdGen for difficulty ``tpy`` and retry ``gen_prob`` until it succeeds."""
        max_op, max_edge = (15, 20) if tpy == "med" else (21, 28)
        id_gen = IdGen(max_op=max_op, max_edge=max_edge, perm_level=5, detail_level=0)

        # Hash-bucket whitelist passed to gen_prob. Per the original author's notes:
        # in mod mode every bucket 0..22 is allowed so no generated problem is
        # filtered out; in integer mode filtering is disabled but the list must be
        # non-empty. TODO(review): confirm against IdGen.gen_prob.
        valid_hashes = list(range(23)) if params.USE_MOD else [0]

        attempts, last_error = 0, None
        while max_attempts is None or attempts < max_attempts:
            attempts += 1
            try:
                id_gen.gen_prob(valid_hashes, p_format="pq")
                return id_gen
            except Exception as err:  # gen_prob may reject a sample; retry with bounded budget
                last_error = err
        raise RuntimeError(
            f"gen_prob failed {attempts} times in a row for difficulty {tpy!r}"
        ) from last_error

    with open(output_file, 'w', encoding='utf-8') as f_out:
        for i in range(num_problems):
            id_gen = get_prob_sol_ans_triple(difficulty)
            problem_text = tokenizer.decode(id_gen.prob_token, skip_special_tokens=True).strip()
            solution_text = tokenizer.decode(id_gen.sol_token, skip_special_tokens=True).strip()
            answer_text = tokenizer.decode(id_gen.ans_token, skip_special_tokens=True).strip()

            data_buffer.append({
                "id": f"{difficulty}_{seed}_{i}",
                "problem": problem_text, "solution": solution_text, "answer": answer_text
            })

            # Flush when the buffer is full or after the last problem.
            if len(data_buffer) >= batch_size or (i + 1) == num_problems:
                for item in data_buffer:
                    f_out.write(json.dumps(item, ensure_ascii=False) + '\n')
                total_generated += len(data_buffer)
                print(f"已生成并保存 {total_generated}/{num_problems} 个问题...")
                data_buffer = []

    print(f"\n成功完成: {output_file}")


# ===================================================================
#                      主程序入口
# ===================================================================
if __name__ == "__main__":
    # ---- Tunable parameters: edit here, then re-run the cell ----
    USE_MOD = False                    # True -> modular (mod) mode, False -> integer mode
    NUM_TO_GENERATE = 10               # number of problems to write
    DIFFICULTY = "med"                 # "med" or "hard"
    OUTPUT_FILE = "gsm_dataset.jsonl"  # JSON Lines output; overwritten on each run
    BATCH_SIZE = 10                    # records buffered per write / progress line
    RANDOM_SEED = 42

    # ---- Run: persist the mode into const/params.py first, then generate ----
    set_global_mod_mode(USE_MOD)
    run_generation(
        NUM_TO_GENERATE,
        DIFFICULTY,
        OUTPUT_FILE,
        seed=RANDOM_SEED,
        batch_size=BATCH_SIZE,
    )

开始生成 10 个问题...
难度: med
文件名: gsm_dataset.jsonl
验证: 当前运行时模式 USE_MOD = False
------------------------------
DEBUG: Num created with value 'None' and use_mod=False
DEBUG: Num created with value 'None' and use_mod=False
DEBUG: Num created with value 'None' and use_mod=False
DEBUG: Num created with value '0' and use_mod=False
DEBUG: Num created with value '4' and use_mod=False
DEBUG: Num created with value 'None' and use_mod=False
DEBUG: Num created with value '0' and use_mod=False
DEBUG: Num created with value '8' and use_mod=False
DEBUG: Num created with value '8' and use_mod=False
DEBUG: Num created with value '16' and use_mod=False
DEBUG: Num created with value 'None' and use_mod=False
DEBUG: Num created with value '0' and use_mod=False
DEBUG: Num created with value '20' and use_mod=False
DEBUG: Num created with value '20' and use_mod=False
DEBUG: Num created with value '26' and use_mod=False
DEBUG: Num created with value '0' and use_mod=False
DEBUG: Num created with value '16' and use_m