## 0. 导入依赖

In [12]:
import os
import warnings
from argparse import Namespace
from configparser import ConfigParser
from importlib.resources import files
from pickle import load

import pandas as pd
from torch.utils.data import Subset

from transformers import (
    GPT2LMHeadModel,
    GPT2Config,
    Trainer,
    TrainingArguments,
    DataCollatorForLanguageModeling,
    set_seed,
)
from transformers.trainer_callback import EarlyStoppingCallback

from MiCoGPT.utils.pretrain import attach_gated_prior_to_gpt2, summarize_gate, PriorDiagnosticsCallback
from MiCoGPT.utils.tools import split_train_val_by_project_stratified, ProjectAggregatedEvalCallback


warnings.filterwarnings("ignore")


## 1. 基本参数设置

In [13]:
# Paths used by this run: the pickled input corpus, the directory the trained
# model is saved to, and the directory that receives logs and checkpoints.
args = Namespace(
    **{
        "input": "../data/try2_withCC/ResMicroDB_90338.pkl",
        "output": "../models/pretrain_ResMicroDB_90338_GATED",
        "log": "../logs/pretrain_ResMicroDB_90338_GATED",
    }
)

# Gated-prior and data-split hyper-parameters.
G_MIN = 0.10      # minimum weight of the gated prior (at least 10% of the prior is injected)
INIT_W = 0.50     # initial prior weight (between G_MIN and 1)
VAL_RATIO = 0.10  # fraction of samples held out for validation

## 2. 载入语料库

In [14]:
# Load the pickled corpus. The file is opened in a `with` block so the handle
# is closed as soon as unpickling finishes.
# NOTE(review): pickle.load can execute arbitrary code; only load corpus files
# that come from a trusted source.
with open(args.input, "rb") as corpus_file:
    all_corpus = load(corpus_file)

# Train only on the samples whose Split_Group metadata value is "A".
corpus = all_corpus.subset_by_metadata(lambda df: df["Split_Group"] == "A")
tokenizer = all_corpus.tokenizer

print("Number of samples in all_corpus:", len(all_corpus))
print("Number of samples in corpus:", len(corpus))
print(all_corpus.metadata["Split_Group"].value_counts())
print("Tokenizer vocab size:", tokenizer.vocab_size)

Number of samples in all_corpus: 90338
Number of samples in corpus: 74557
Split_Group
A    74557
B    13901
C     1880
Name: count, dtype: int64
Tokenizer vocab size: 1121


## 3. 构建 GPT2Config

In [15]:
# Read the configuration file shipped inside the MiCoGPT package.
# ConfigParser.read() silently skips files it cannot open and returns the list
# of files it actually parsed, so an empty result means the config is missing.
# Fail here with the path instead of a later, less obvious NoSectionError.
cfg = ConfigParser()
config_path = files("MiCoGPT")/"resources/config.ini"
if not cfg.read(config_path):
    raise FileNotFoundError(f"Cannot read MiCoGPT config file: {config_path}")

# Architecture sizes come from the [GPT2] section of the config; the vocabulary
# size and the special-token ids come from the corpus tokenizer.
gpt2_config_dict = {
    "model_type":   cfg.get("GPT2", "model_type"),
    "vocab_size":   tokenizer.vocab_size,
    "n_positions":  cfg.getint("GPT2", "n_positions"),
    "n_embd":       cfg.getint("GPT2", "n_embd"),
    "n_layer":      cfg.getint("GPT2", "n_layer"),
    "n_head":       cfg.getint("GPT2", "n_head"),
    "bos_token_id": tokenizer.bos_token_id,
    "eos_token_id": tokenizer.eos_token_id,
    "pad_token_id": tokenizer.pad_token_id,
}
config = GPT2Config(**gpt2_config_dict)
config

GPT2Config {
  "activation_function": "gelu_new",
  "attn_pdrop": 0.1,
  "bos_token_id": 2,
  "embd_pdrop": 0.1,
  "eos_token_id": 3,
  "initializer_range": 0.02,
  "layer_norm_epsilon": 1e-05,
  "model_type": "gpt2",
  "n_embd": 256,
  "n_head": 8,
  "n_inner": null,
  "n_layer": 8,
  "n_positions": 512,
  "pad_token_id": 0,
  "reorder_and_upcast_attn": false,
  "resid_pdrop": 0.1,
  "scale_attn_by_inverse_layer_idx": false,
  "scale_attn_weights": true,
  "summary_activation": null,
  "summary_first_dropout": 0.1,
  "summary_proj_to_labels": true,
  "summary_type": "cls_index",
  "summary_use_proj": true,
  "transformers_version": "4.33.3",
  "use_cache": true,
  "vocab_size": 1121
}

## 4. 构建 TrainingArguments

In [16]:
# Keyword arguments for TrainingArguments. Tunable values are read from the
# "pretrain" section of `cfg`; output locations are derived from `args.log`.
training_args_dict = dict(
    # --- best-model selection (lower metric value is better) ---
    # Alternatives: eval_proj_sqrt / eval_proj_shrink / eval_proj_worst10_shrink
    metric_for_best_model="eval_proj_worst10_shrink",
    greater_is_better=False,
    # --- evaluation / checkpoint schedule ---
    do_train=True,
    do_eval=True,
    evaluation_strategy="steps",
    eval_steps=cfg.getint("pretrain", "eval_steps"),
    save_strategy="steps",
    save_steps=cfg.getint("pretrain", "save_steps"),
    # --- batching / progress display ---
    group_by_length=False,
    length_column_name="length",
    disable_tqdm=False,
    # --- optimisation ---
    learning_rate=cfg.getfloat("pretrain", "learning_rate"),
    lr_scheduler_type="linear",
    warmup_steps=cfg.getint("pretrain", "warmup_steps"),
    weight_decay=cfg.getfloat("pretrain", "weight_decay"),
    per_device_train_batch_size=cfg.getint("pretrain", "per_device_train_batch_size"),
    num_train_epochs=cfg.getint("pretrain", "num_train_epochs"),
    # --- logging and outputs ---
    logging_steps=cfg.getint("pretrain", "logging_steps"),
    output_dir=f"{args.log}/pretrain_checkpoints",
    logging_dir=args.log,
    load_best_model_at_end=True,
)
training_args = TrainingArguments(**training_args_dict)
training_args

TrainingArguments(
_n_gpu=1,
adafactor=False,
adam_beta1=0.9,
adam_beta2=0.999,
adam_epsilon=1e-08,
auto_find_batch_size=False,
bf16=False,
bf16_full_eval=False,
data_seed=None,
dataloader_drop_last=False,
dataloader_num_workers=0,
dataloader_pin_memory=True,
ddp_backend=None,
ddp_broadcast_buffers=None,
ddp_bucket_cap_mb=None,
ddp_find_unused_parameters=None,
ddp_timeout=1800,
debug=[],
deepspeed=None,
disable_tqdm=False,
dispatch_batches=None,
do_eval=True,
do_predict=False,
do_train=True,
eval_accumulation_steps=None,
eval_delay=0,
eval_steps=500,
evaluation_strategy=steps,
fp16=False,
fp16_backend=auto,
fp16_full_eval=False,
fp16_opt_level=O1,
fsdp=[],
fsdp_config={'min_num_params': 0, 'xla': False, 'xla_fsdp_grad_ckpt': False},
fsdp_min_num_params=0,
fsdp_transformer_layer_cls_to_wrap=None,
full_determinism=False,
gradient_accumulation_steps=1,
gradient_checkpointing=False,
greater_is_better=False,
group_by_length=False,
half_precision_backend=auto,
hub_always_push=False,
hub_mode

## 5. 构建 collator

In [17]:
# Collator for causal language modelling: mlm=False disables masked-LM masking.
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

## 6. 构建模型

In [18]:
# Seed the Python / NumPy / PyTorch RNGs before the weights are drawn, so the
# from-scratch initialisation is reproducible across kernel restarts.
# NOTE(review): the Trainer also applies training_args.seed, but it is only
# constructed in a later cell, after these weights already exist — confirm.
set_seed(training_args.seed)

# Randomly initialised GPT-2 (no pretrained checkpoint), put into training mode.
model = GPT2LMHeadModel(config)
model.train()
print("Training from scratch.")

Training from scratch.


## 7. 挂载 gated prior

In [19]:
# Genus embedding file (.npz) shipped inside the MiCoGPT package resources.
npz_path = files("MiCoGPT").joinpath("resources").joinpath("genus_embeddings_256.npz")

# Attach the gated prior to the model using the MiCoGPT helper. It returns the
# token ids it handled and the entries it reported as missing.
genus_token_ids, missing = attach_gated_prior_to_gpt2(
    model=model, tokenizer=tokenizer, npz_path=npz_path, g_min=G_MIN, init_w=INIT_W
)
print(f"[gated prior] genus_token_ids={len(genus_token_ids)}, missing={len(missing)}")

# Print a summary of the initial gate state.
summarize_gate(model, topk=10)

[prior] npz genus 总数: 1117
[prior] 写入 prior 的 unique token_id 数: 1117
[prior] missing/unk genus 数: 0
[norm] base  p10/p50/p90 = 0.3006 / 0.3197 / 0.3378
[norm] prior p10/p50/p90 = 0.8385 / 1.0859 / 1.6267
[norm] suggested prior_scale (p50 align) = 0.2944
[prior] applied global scale s=0.2944 to prior_matrix (p50 align)
[gated prior] genus_token_ids=1117, missing=0
[gate] g_min=0.1, prior_nonzero_tokens=1117
[gate] w mean=0.5000, std=0.0000, min=0.5000, max=0.5000
[gate] top using prior token_ids: [12, 13, 8, 11, 9, 7, 5, 4, 6, 10]
[gate] top weights: [0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5]


## 8. 划分 train/val

In [20]:
# Split the training corpus into train / validation subsets with the MiCoGPT
# helper, stratified by the "Project_ID" column. The validation fraction comes
# from the VAL_RATIO constant in the parameter cell rather than a duplicated
# literal, so changing it there actually takes effect.
train_set, val_set = split_train_val_by_project_stratified(
    corpus,
    project_col="Project_ID",
    val_ratio=VAL_RATIO,
    min_project_samples=20,
    min_val_per_project=2,
    random_state=42,  # fixed seed so the split is reproducible
)

[split] total_samples=74557, target_val~7456
[split] eligible_projects=304, eligible_samples=74367
[split] ineligible_projects=16, ineligible_samples=190
[split] actual_val=7456 (target~7456), train=67101


## 9. train

In [21]:
# Build the Trainer without callbacks first: two of the callbacks below take a
# reference to the trainer itself, so they can only be created afterwards.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_set,
    eval_dataset=val_set,
    data_collator=data_collator,
    callbacks=[],
)

# 1) Project-aggregated evaluation. Per the original notes it writes the
#    eval_proj_* metrics and the tok_sum suggestion.
trainer.add_callback(
    ProjectAggregatedEvalCallback(
        trainer=trainer,
        eval_subset=val_set,
        project_col="Project_ID",
        shrink_k=5000.0,   # suggested starting value for a long-tailed project distribution
        worst_frac=0.10,
    )
)

# 2) Prior diagnostics. Per the original notes it writes gate / cancel /
#    ablation statistics.
trainer.add_callback(
    PriorDiagnosticsCallback(
        tokenizer=tokenizer,
        topk=10,
        do_ablation=True,
        trainer=trainer,
    )
)

# 3) Early stopping is added last so the custom eval_proj_* metrics (including
#    the configured metric_for_best_model) are already present when it runs.
trainer.add_callback(EarlyStoppingCallback(early_stopping_patience=10))

# Experiment-level settings recorded alongside the training log.
init_info = {
    "init_g_min": float(G_MIN),
    "init_init_w": float(INIT_W),
    "init_prior_token_ids": float(len(genus_token_ids)),
    "init_prior_missing": float(len(missing)),
}
# Emit the record through the Trainer's logging path before training starts.
# A copy is passed so `init_info` itself is never modified by the Trainer.
trainer.log(dict(init_info))

print("Start training...")
trainer.train()

# ========= Save the model =========
os.makedirs(args.output, exist_ok=True)
trainer.save_model(args.output)
print("Model saved to:", args.output)

# ========= Save the training log =========
# NOTE(review): Trainer.train() appears to replace trainer.state with a fresh
# TrainerState (transformers 4.33) — confirm. If so, the record logged before
# training is no longer in log_history. Re-insert it only when it is missing,
# so the CSV always contains the initial experiment settings exactly once.
history = list(trainer.state.log_history)
if not any("init_g_min" in record for record in history):
    history.insert(0, {**init_info, "step": 0})
logs = pd.DataFrame(history)
os.makedirs(args.log, exist_ok=True)
log_path = os.path.join(args.log, "pretrain_log.csv")
logs.to_csv(log_path, index=False)
print("Logs saved to:", log_path)


Start training...
[PriorDiagnostics] cached initial base embedding for delta stats.


Step,Training Loss,Validation Loss
500,4.5641,4.257278
1000,4.3231,4.079933
1500,4.2477,3.969246
2000,4.1774,3.90911
2500,4.0792,3.870778
3000,4.0827,3.822397
3500,4.0012,3.798882
4000,4.0319,3.768195
4500,3.9362,3.749916
5000,3.9041,3.72419


[tok_sum] min=23, p10=155, p50=620, p90=2979, max=30504
[suggest_k] p10@lam=0.1 -> 1395 | p10@lam=0.2 -> 620 | p50@lam=0.5 -> 620
[ProjEval] projects=304 | k=5000 | sqrt=4.4959 | shrink=4.5055 | worst10(shrink)=4.8068 | micro=4.5019

[PriorDiagnostics] step=500
[ProjEval] projects=304 | k=5000 | sqrt=4.3407 | shrink=4.3383 | worst10(shrink)=4.6544 | micro=4.3330

[PriorDiagnostics] step=1000
[ProjEval] projects=304 | k=5000 | sqrt=4.2291 | shrink=4.2237 | worst10(shrink)=4.5304 | micro=4.2183

[PriorDiagnostics] step=1500
[ProjEval] projects=304 | k=5000 | sqrt=4.1702 | shrink=4.1595 | worst10(shrink)=4.4605 | micro=4.1539

[PriorDiagnostics] step=2000
[ProjEval] projects=304 | k=5000 | sqrt=4.1300 | shrink=4.1207 | worst10(shrink)=4.4240 | micro=4.1154

[PriorDiagnostics] step=2500
[ProjEval] projects=304 | k=5000 | sqrt=4.0829 | shrink=4.0719 | worst10(shrink)=4.3746 | micro=4.0665

[PriorDiagnostics] step=3000
[ProjEval] projects=304 | k=5000 | sqrt=4.0576 | shrink=4.0459 | worst10(