In [5]:
from transformers import AutoTokenizer
from transformers.models.auto.tokenization_auto import TOKENIZER_MAPPING

# 查看所有注册的分词器映射
for model_config, tokenizer_class in TOKENIZER_MAPPING.items():
    print(f"{model_config.__name__}: {tokenizer_class}")

AlbertConfig: (None, <class 'transformers.models.albert.tokenization_albert_fast.AlbertTokenizerFast'>)
AlignConfig: (<class 'transformers.models.bert.tokenization_bert.BertTokenizer'>, <class 'transformers.models.bert.tokenization_bert_fast.BertTokenizerFast'>)
AriaConfig: (<class 'transformers.utils.dummy_sentencepiece_objects.LlamaTokenizer'>, <class 'transformers.models.llama.tokenization_llama_fast.LlamaTokenizerFast'>)
AyaVisionConfig: (None, <class 'transformers.models.cohere.tokenization_cohere_fast.CohereTokenizerFast'>)
BarkConfig: (<class 'transformers.models.bert.tokenization_bert.BertTokenizer'>, <class 'transformers.models.bert.tokenization_bert_fast.BertTokenizerFast'>)
BartConfig: (<class 'transformers.models.bart.tokenization_bart.BartTokenizer'>, <class 'transformers.models.bart.tokenization_bart_fast.BartTokenizerFast'>)
BertConfig: (<class 'transformers.models.bert.tokenization_bert.BertTokenizer'>, <class 'transformers.models.bert.tokenization_bert_fast.BertTokeniz

In [3]:
import json
import random
import re
import os
from typing import Literal

def set_global_mod_mode(use_mod: bool):
    """
    直接读取、修改并重写 const/params.py 文件来设置全局模式。
    """
    params_path = os.path.join('const', 'params.py')
    
    try:
        with open(params_path, 'r', encoding='utf-8') as f:
            lines = f.readlines()
    except FileNotFoundError:
        print(f"错误: 无法找到 '{params_path}'。请确保脚本在 iGSM 项目根目录下运行。")
        exit(1)

    found_line = False
    for i, line in enumerate(lines):
        if re.match(r'^\s*USE_MOD\s*=', line):
            lines[i] = f"USE_MOD = {use_mod}\n"
            found_line = True
            break
            
    if not found_line:
        lines.append(f"\nUSE_MOD = {use_mod}\n")

    with open(params_path, 'w', encoding='utf-8') as f:
        f.writelines(lines)


def run_generation(num_problems, difficulty, output_file, seed=42, batch_size=1000):
    """
    封装了导入和生成逻辑的主函数。
    """
    from data_gen.pretrain.id_gen import IdGen
    from tools.tools import tokenizer, fix_seed
    from const import params

    # fix_seed(seed)
    data_buffer, total_generated = [], 0

    print(f"开始生成 {num_problems} 个问题...")
    print(f"难度: {difficulty}")
    print(f"文件名: {output_file}")
    print(f"验证: 当前运行时模式 USE_MOD = {params.USE_MOD}")
    print("-" * 30)
    
    # vvvvvvvv 关键修正点 vvvvvvvv
        # vvvvvvvv 关键修正点 vvvvvvvv
    def get_prob_sol_ans_triple(tpy: Literal["easy", "med", "hard"]):
        assert tpy in ["easy", "med", "hard"], "Invalid type: Choose 'easy', 'med' or 'hard'"
        
        if tpy == "easy":
            max_op, max_edge = (9, 12)
        elif tpy == "med":
            max_op, max_edge = (15, 20)
        else:  # hard
            max_op, max_edge = (21, 28)

        id_gen = IdGen(max_op=max_op, max_edge=max_edge, perm_level=5, detail_level=0)
        
        # 根据当前模式决定传递给 gen_prob 的哈希列表
        if params.USE_MOD:
            valid_hashes = list(range(23))  # 取模模式
        else:
            valid_hashes = [0]  # 非取模模式

        while True:
            try:
                id_gen.gen_prob(valid_hashes, p_format="pq")
                break 
            except Exception:
                continue
        return id_gen
    # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^


    with open(output_file, 'w', encoding='utf-8') as f_out:
        for i in range(num_problems):
            id_gen = get_prob_sol_ans_triple(difficulty)
            problem_text = tokenizer.decode(id_gen.prob_token, skip_special_tokens=True).strip()
            solution_text = tokenizer.decode(id_gen.sol_token, skip_special_tokens=True).strip()
            answer_text = tokenizer.decode(id_gen.ans_token, skip_special_tokens=True).strip()
            
            data_buffer.append({
                "id": f"{difficulty}_{seed}_{i}",
                "problem": problem_text, "solution": solution_text, "answer": answer_text
            })
            
            if len(data_buffer) >= batch_size or (i + 1) == num_problems:
                for item in data_buffer:
                    f_out.write(json.dumps(item, ensure_ascii=False) + '\n')
                total_generated += len(data_buffer)
                print(f"已生成并保存 {total_generated}/{num_problems} 个问题...")
                data_buffer = []
            
    print(f"\n成功完成: {output_file}")


# ===================================================================
#                      主程序入口
# ===================================================================
if __name__ == "__main__":
    
    # --- 所有可调参数都集中在这里 ---
    USE_MOD = False
    NUM_TO_GENERATE = 1000
    DIFFICULTY = "easy"
    OUTPUT_FILE = "GSM_data_easy_1000.jsonl"
    BATCH_SIZE = 1000
    RANDOM_SEED = 42
    
    # --- 执行 ---
    set_global_mod_mode(USE_MOD)
    run_generation(
        num_problems=NUM_TO_GENERATE,
        difficulty=DIFFICULTY,
        output_file=OUTPUT_FILE,
        seed=RANDOM_SEED,
        batch_size=BATCH_SIZE
    )

开始生成 1000 个问题...
难度: easy
文件名: GSM_data_easy_1000.jsonl
验证: 当前运行时模式 USE_MOD = False
------------------------------
已生成并保存 1000/1000 个问题...

成功完成: GSM_data_easy_1000.jsonl


# 😺🤖直出

In [12]:
# -*- coding: utf-8 -*-
import json
import sys
import os
import re
import random
from typing import Literal
from tqdm import tqdm

# ==============================================================
#                   数据生成部分
# ==============================================================

def set_global_mod_mode(use_mod: bool):
    """
    直接读取、修改并重写 const/params.py 文件来设置全局模式。
    """
    params_path = os.path.join('const', 'params.py')
    
    try:
        with open(params_path, 'r', encoding='utf-8') as f:
            lines = f.readlines()
    except FileNotFoundError:
        print(f"错误: 无法找到 '{params_path}'。请确保脚本在 iGSM 项目根目录下运行。")
        sys.exit(1)

    found_line = False
    for i, line in enumerate(lines):
        if re.match(r'^\s*USE_MOD\s*=', line):
            lines[i] = f"USE_MOD = {use_mod}\n"
            found_line = True
            break
            
    if not found_line:
        lines.append(f"\nUSE_MOD = {use_mod}\n")

    with open(params_path, 'w', encoding='utf-8') as f:
        f.writelines(lines)


def run_generation(num_problems, difficulty, output_file, seed=42, batch_size=1000):
    """
    封装了导入和生成逻辑的主函数。
    """
    from data_gen.pretrain.id_gen import IdGen
    from tools.tools import tokenizer
    from const import params

    data_buffer, total_generated = [], 0

    print(f"开始生成 {num_problems} 个问题...")
    print(f"难度: {difficulty}")
    print(f"文件名: {output_file}")
    print(f"验证: 当前运行时模式 USE_MOD = {params.USE_MOD}")
    print("-" * 30)
    
    # vvvvvvvv 支持 easy/med/hard vvvvvvvv
    def get_prob_sol_ans_triple(tpy: Literal["easy", "med", "hard"]):
        assert tpy in ["veasy","easy", "med", "hard"], "Invalid type: Choose  'veasy','easy', 'med' or 'hard'"
        
        if tpy == "veasy":
            max_op, max_edge = (5, 8)
        elif tpy == "easy":
            max_op, max_edge = (9, 12)
        elif tpy == "med":
            max_op, max_edge = (15, 20)
        else:  # hard
            max_op, max_edge = (21, 28)

        id_gen = IdGen(max_op=max_op, max_edge=max_edge, perm_level=5, detail_level=0)
        
        # 根据当前模式决定传递给 gen_prob 的哈希列表
        valid_hashes = list(range(23)) if params.USE_MOD else [0]

        while True:
            try:
                id_gen.gen_prob(valid_hashes, p_format="pq")
                break 
            except Exception:
                continue
        return id_gen
    # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

    with open(output_file, 'w', encoding='utf-8') as f_out:
        for i in range(num_problems):
            id_gen = get_prob_sol_ans_triple(difficulty)
            problem_text = tokenizer.decode(id_gen.prob_token, skip_special_tokens=True).strip()
            solution_text = tokenizer.decode(id_gen.sol_token, skip_special_tokens=True).strip()
            answer_text = tokenizer.decode(id_gen.ans_token, skip_special_tokens=True).strip()
            
            data_buffer.append({
                "id": f"{difficulty}_{seed}_{i}",
                "problem": problem_text, "solution": solution_text, "answer": answer_text
            })
            
            if len(data_buffer) >= batch_size or (i + 1) == num_problems:
                for item in data_buffer:
                    f_out.write(json.dumps(item, ensure_ascii=False) + '\n')
                total_generated += len(data_buffer)
                print(f"已生成并保存 {total_generated}/{num_problems} 个问题...")
                data_buffer = []
            
    print(f"\n成功完成: {output_file}")


# ==============================================================
#                   转换为 😺🤖 对话格式
# ==============================================================

def convert_to_conversation(data: dict) -> list | None:
    """
    将单行原始JSON对象转换为目标对话格式 [{"😺": ...}, {"🤖": ...}]。
    """
    problem_text = data.get("problem")
    solution_text = data.get("solution")
    answer_text = data.get("answer")

    if not all([problem_text, solution_text, answer_text]):
        return None

    assistant_content = solution_text.strip() + "\n" + answer_text.strip()
    conversation = [
        {"😺": problem_text.strip()},
        {"🤖": assistant_content}
    ]
    return conversation


def convert_file(input_file: str, output_file: str):
    """
    将生成的数据文件转换为对话格式。
    """
    print(f"[*] 开始转换文件...")
    print(f"    输入: {input_file}")
    print(f"    输出: {output_file}")

    lines_processed = 0
    lines_converted = 0
    lines_skipped = 0
    
    try:
        with open(input_file, 'r', encoding='utf-8') as f_in, \
             open(output_file, 'w', encoding='utf-8') as f_out:
            
            pbar = tqdm(f_in, desc="转换中", unit="行", file=sys.stdout, bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}]')

            for line in pbar:
                lines_processed += 1
                line = line.strip()
                if not line:
                    lines_skipped += 1
                    continue

                try:
                    original_data = json.loads(line)
                    converted_conversation = convert_to_conversation(original_data)

                    if converted_conversation:
                        output_line = json.dumps(
                            converted_conversation, 
                            ensure_ascii=False, 
                            separators=(',', ':')
                        )
                        f_out.write(output_line + '\n')
                        lines_converted += 1
                    else:
                        lines_skipped += 1

                except json.JSONDecodeError:
                    print(f"\n[!] 警告: 第 {lines_processed} 行 JSON 解析失败，已跳过。", file=sys.stderr)
                    lines_skipped += 1
                except Exception as e:
                    print(f"\n[!] 错误: 处理第 {lines_processed} 行时发生未知错误: {e}", file=sys.stderr)
                    lines_skipped += 1

    except FileNotFoundError:
        print(f"\n[X] 错误: 输入文件 '{input_file}' 未找到。", file=sys.stderr)
        sys.exit(1)

    print("\n" + "="*30)
    print(" 转换完成 ")
    print("="*30)
    print(f"[*] 共处理输入文件行数: {lines_processed:,}")
    print(f"[*] 成功转换并写入行数: {lines_converted:,}")
    print(f"[*] 跳过或转换失败行数: {lines_skipped:,}")
    print(f"[*] 结果已保存至: {output_file}")
    print("="*30)


# ==============================================================
#                      主程序入口
# ==============================================================

if __name__ == "__main__":
    # --- 生成参数 ---
    USE_MOD = False
    NUM_TO_GENERATE = 1000000
    DIFFICULTY = "easy"   # 可选: "veasy" | "easy" | "med" | "hard"
    GEN_OUTPUT_FILE = f"GSM_data_{DIFFICULTY}_{NUM_TO_GENERATE}2.jsonl"
    RANDOM_SEED = 42

    # --- 转换参数 ---
    FINAL_OUTPUT_FILE = f"GSM_data_{DIFFICULTY}_{NUM_TO_GENERATE}_😺🤖2.jsonl"

    # --- 执行数据生成 ---
    set_global_mod_mode(USE_MOD)
    run_generation(
        num_problems=NUM_TO_GENERATE,
        difficulty=DIFFICULTY,
        output_file=GEN_OUTPUT_FILE,
        seed=RANDOM_SEED,
        batch_size=10000
    )

    # --- 执行文件转换 ---
    convert_file(GEN_OUTPUT_FILE, FINAL_OUTPUT_FILE)


开始生成 1000000 个问题...
难度: easy
文件名: GSM_data_easy_10000002.jsonl
验证: 当前运行时模式 USE_MOD = False
------------------------------
已生成并保存 10000/1000000 个问题...
已生成并保存 20000/1000000 个问题...
已生成并保存 30000/1000000 个问题...
已生成并保存 40000/1000000 个问题...
已生成并保存 50000/1000000 个问题...
已生成并保存 60000/1000000 个问题...
已生成并保存 70000/1000000 个问题...
已生成并保存 80000/1000000 个问题...
已生成并保存 90000/1000000 个问题...
已生成并保存 100000/1000000 个问题...
已生成并保存 110000/1000000 个问题...
已生成并保存 120000/1000000 个问题...
已生成并保存 130000/1000000 个问题...
已生成并保存 140000/1000000 个问题...
已生成并保存 150000/1000000 个问题...
已生成并保存 160000/1000000 个问题...
已生成并保存 170000/1000000 个问题...
已生成并保存 180000/1000000 个问题...
已生成并保存 190000/1000000 个问题...
已生成并保存 200000/1000000 个问题...
已生成并保存 210000/1000000 个问题...
已生成并保存 220000/1000000 个问题...
已生成并保存 230000/1000000 个问题...
已生成并保存 240000/1000000 个问题...
已生成并保存 250000/1000000 个问题...
已生成并保存 260000/1000000 个问题...
已生成并保存 270000/1000000 个问题...
已生成并保存 280000/1000000 个问题...
已生成并保存 290000/1000000 个问题...
已生成并保存 300000/1000000 个问题...
已生成并保存 310000/100

In [8]:
import os

def avg_line_length(file_path, max_lines=1000):
    total_len = 0
    count = 0
    with open(file_path, "r", encoding="utf-8") as f:
        for i, line in enumerate(f):
            if i >= max_lines:
                break
            total_len += len(line.strip())
            count += 1
    return total_len / count if count > 0 else 0

def main():
    results = []
    for fname in os.listdir("."):
        if "😺🤖" in fname and fname.endswith(".jsonl"):
            avg_len = avg_line_length(fname, max_lines=1000)
            results.append((fname, avg_len))
    
    # 排序（从短到长）
    results.sort(key=lambda x: x[1])
    
    for fname, avg_len in results:
        print(f"{fname} : {avg_len:.2f}")

if __name__ == "__main__":
    main()


GSM_data_veasy__1000_😺🤖.jsonl : 650.41
GSM_data_easy__1000_😺🤖.jsonl : 902.75
GSM_data_med_500000_😺🤖.jsonl : 1258.97
GSM_data_hard_500000_😺🤖.jsonl : 1556.80
