## 0. 导入依赖

In [2]:
import os
import torch
import warnings
from pickle import load
from argparse import Namespace
from configparser import ConfigParser
from importlib.resources import files

import pandas as pd
from torch.utils.data import  Subset

from transformers import (
    GPT2LMHeadModel,
    GPT2Config,
    Trainer,
    TrainingArguments,
    DataCollatorForLanguageModeling,
)
from transformers.trainer_callback import EarlyStoppingCallback

from MiCoGPT.utils.pretrain import attach_gated_prior_to_gpt2,attach_gated_prior_lm_head
from MiCoGPT.utils.tools import split_train_val_by_project_stratified
from MiCoGPT.utils.callback import PriorGateStatsOnEvalLogCallback
from MiCoGPT.utils.freeze import UnfreezeWteBaseAtStepCallback, freeze_wte_base, build_optimizer_no_filter, compute_num_training_steps,get_scheduler

warnings.filterwarnings("ignore")


## 1. 基本参数设置

In [3]:
# Run identifier shared by the model output directory and the log directory,
# so the two paths cannot drift apart when the run is renamed.
RUN_NAME = "pretrain_ResMicroDB_90338_GATED_v6_freeze"

args = Namespace(
    input="../data/try2_withCC/ResMicroDB_90338.pkl",  # pickled corpus, loaded in §2
    output=f"../models/{RUN_NAME}",
    log=f"../logs/{RUN_NAME}",
)

G_MIN = 0.0       # minimum weight of the gated prior
INIT_W = 0.5      # initial prior weight (must lie between G_MIN and 1)
VAL_RATIO = 0.10  # fraction of samples held out for validation

# Fail fast on inconsistent hyper-parameters instead of deep inside model construction.
assert G_MIN <= INIT_W <= 1.0, f"expected G_MIN <= INIT_W <= 1, got G_MIN={G_MIN}, INIT_W={INIT_W}"
assert 0.0 < VAL_RATIO < 1.0, f"VAL_RATIO must be in (0, 1), got {VAL_RATIO}"

## 2. 载入语料库

In [4]:
# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only load corpora produced by a trusted source.
# The context manager guarantees the file handle is closed after loading.
with open(args.input, "rb") as corpus_file:
    all_corpus = load(corpus_file)

# Train only on samples whose Split_Group is "A".
corpus = all_corpus.subset_by_metadata(lambda df: df["Split_Group"] == "A")
tokenizer = all_corpus.tokenizer

print("Number of samples in all_corpus:", len(all_corpus))
print("Number of samples in corpus:", len(corpus))
print(all_corpus.metadata["Split_Group"].value_counts())
print("Tokenizer vocab size:", tokenizer.vocab_size)

Number of samples in all_corpus: 90338
Number of samples in corpus: 74557
Split_Group
A    74557
B    13901
C     1880
Name: count, dtype: int64
Tokenizer vocab size: 1121


## 3. 构建 GPT2Config

In [5]:
# Read the packaged config.ini through importlib.resources.
# ConfigParser.read() silently skips files it cannot open, which would only surface
# later as a confusing NoSectionError. Traversable.read_text() raises immediately if
# the resource is missing, and also works when the package is not a plain directory
# on disk. The encoding is pinned to UTF-8 instead of the platform locale.
cfg = ConfigParser()
cfg.read_string(
    (files("MiCoGPT") / "resources/config.ini").read_text(encoding="utf-8"),
    source="MiCoGPT/resources/config.ini",
)

# Architecture sizes come from the [GPT2] section. Vocabulary size and special-token
# ids come from the corpus tokenizer so the embedding table matches the tokenizer.
gpt2_config_dict = {
    "model_type":   cfg.get("GPT2", "model_type"),
    "vocab_size":   tokenizer.vocab_size,
    "n_positions":  cfg.getint("GPT2", "n_positions"),
    "n_embd":       cfg.getint("GPT2", "n_embd"),
    "n_layer":      cfg.getint("GPT2", "n_layer"),
    "n_head":       cfg.getint("GPT2", "n_head"),
    "bos_token_id": tokenizer.bos_token_id,
    "eos_token_id": tokenizer.eos_token_id,
    "pad_token_id": tokenizer.pad_token_id,
}
config = GPT2Config(**gpt2_config_dict)
config

GPT2Config {
  "activation_function": "gelu_new",
  "attn_pdrop": 0.1,
  "bos_token_id": 2,
  "embd_pdrop": 0.1,
  "eos_token_id": 3,
  "initializer_range": 0.02,
  "layer_norm_epsilon": 1e-05,
  "model_type": "gpt2",
  "n_embd": 256,
  "n_head": 8,
  "n_inner": null,
  "n_layer": 8,
  "n_positions": 512,
  "pad_token_id": 0,
  "reorder_and_upcast_attn": false,
  "resid_pdrop": 0.1,
  "scale_attn_by_inverse_layer_idx": false,
  "scale_attn_weights": true,
  "summary_activation": null,
  "summary_first_dropout": 0.1,
  "summary_proj_to_labels": true,
  "summary_type": "cls_index",
  "summary_use_proj": true,
  "transformers_version": "4.33.3",
  "use_cache": true,
  "vocab_size": 1121
}

## 4. 构建 TrainingArguments

In [6]:
# Trainer settings. Numeric hyper-parameters are read from the [pretrain] section of
# config.ini; output locations derive from args.log.
# Note: §9 builds the optimizer/scheduler manually but still reads lr_scheduler_type
# and the warmup setting from these arguments.
training_args_dict = dict(
    # --- output / logging locations ---
    output_dir=f"{args.log}/pretrain_checkpoints",
    logging_dir=args.log,
    logging_steps=cfg.getint("pretrain", "logging_steps"),
    disable_tqdm=False,
    # --- phases to run ---
    do_train=True,
    do_eval=True,
    # --- evaluation / checkpoint cadence ---
    evaluation_strategy="steps",
    eval_steps=cfg.getint("pretrain", "eval_steps"),
    save_strategy="steps",
    save_steps=cfg.getint("pretrain", "save_steps"),
    # --- best-model selection: lower eval_loss is better ---
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    # --- optimisation ---
    learning_rate=cfg.getfloat("pretrain", "learning_rate"),
    lr_scheduler_type="linear",
    warmup_steps=cfg.getint("pretrain", "warmup_steps"),
    weight_decay=cfg.getfloat("pretrain", "weight_decay"),
    per_device_train_batch_size=cfg.getint("pretrain", "per_device_train_batch_size"),
    num_train_epochs=cfg.getint("pretrain", "num_train_epochs"),
    # --- length-based batching (disabled) ---
    group_by_length=False,
    length_column_name="length",
)
training_args = TrainingArguments(**training_args_dict)
training_args

TrainingArguments(
_n_gpu=1,
adafactor=False,
adam_beta1=0.9,
adam_beta2=0.999,
adam_epsilon=1e-08,
auto_find_batch_size=False,
bf16=False,
bf16_full_eval=False,
data_seed=None,
dataloader_drop_last=False,
dataloader_num_workers=0,
dataloader_pin_memory=True,
ddp_backend=None,
ddp_broadcast_buffers=None,
ddp_bucket_cap_mb=None,
ddp_find_unused_parameters=None,
ddp_timeout=1800,
debug=[],
deepspeed=None,
disable_tqdm=False,
dispatch_batches=None,
do_eval=True,
do_predict=False,
do_train=True,
eval_accumulation_steps=None,
eval_delay=0,
eval_steps=500,
evaluation_strategy=steps,
fp16=False,
fp16_backend=auto,
fp16_full_eval=False,
fp16_opt_level=O1,
fsdp=[],
fsdp_config={'min_num_params': 0, 'xla': False, 'xla_fsdp_grad_ckpt': False},
fsdp_min_num_params=0,
fsdp_transformer_layer_cls_to_wrap=None,
full_determinism=False,
gradient_accumulation_steps=1,
gradient_checkpointing=False,
greater_is_better=False,
group_by_length=False,
half_precision_backend=auto,
hub_always_push=False,
hub_mode

## 5. 构建 collator

In [7]:
# Causal-LM batching: mlm=False disables random token masking. Per the transformers
# docs, labels are then copied from input_ids, with padding positions excluded from the loss.
data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)

## 6. 构建模型

In [8]:
# Seed torch before building the model so the random weight initialisation is
# reproducible across Restart & Run All. The same seed the Trainer uses is reused here.
# NOTE(review): transformers' Trainer calls set_seed(args.seed) in its constructor (§9),
# which runs only after these weights already exist — confirm for the installed version.
torch.manual_seed(training_args.seed)

model = GPT2LMHeadModel(config)  # fresh, randomly initialised GPT-2 (no pretrained weights)
model.train()
print("Training from scratch.")

Training from scratch.


## 7. 挂载 gated prior

In [9]:
# Gated-prior hyper-parameters, named so they are visible next to the call that uses them.
PRIOR_SCALE = 2            # manual scale applied to the prior matrix
PRIOR_LOGITS_SCALE = 1.0   # scale passed to the gated-prior LM head
SHUFFLE_PRIOR = False      # whether to shuffle the prior
SHUFFLE_SEED = 42          # fixed seed for the shuffle, for reproducibility

# Genus embedding table shipped with the package, attached as a gated prior on the input
# embeddings. G_MIN / INIT_W come from §1.
npz_path = files("MiCoGPT")/"resources"/"genus_embeddings_256.npz"
genus_token_ids, missing = attach_gated_prior_to_gpt2(
    model=model,
    tokenizer=tokenizer,
    npz_path=npz_path,
    g_min=G_MIN,
    init_w=INIT_W,
    shuffle_prior=SHUFFLE_PRIOR,
    shuffle_seed=SHUFFLE_SEED,
    prior_scale=PRIOR_SCALE,
)
print(f"[gated prior] genus_token_ids={len(genus_token_ids)}, missing={len(missing)}")
if len(missing) > 0:
    # Warnings are globally silenced in §0, so report missing entries via print.
    print(f"[gated prior] first missing entries: {list(missing)[:10]}")

# Also attach the gated prior to the output head.
attach_gated_prior_lm_head(model, prior_logits_scale=PRIOR_LOGITS_SCALE)

[prior] npz genus: 1117
[prior] prior unique token_id: 1117
[prior] missing genus: 0
[prior] applied MANUAL scale s=2.0000 to prior_matrix
[gated prior] genus_token_ids=1117, missing=0


## 8. 划分 train/val

In [10]:
# Hold out about VAL_RATIO of the samples for validation, stratified by Project_ID.
# Projects with fewer than min_project_samples samples are not eligible for validation.
train_set, val_set = split_train_val_by_project_stratified(
    corpus,
    project_col="Project_ID",
    val_ratio=VAL_RATIO,  # single source of truth: the constant defined in §1
    min_project_samples=20,
    min_val_per_project=2,
    random_state=42,
)

# Sanity check: the split log reports train + val == total_samples, so no sample is dropped.
assert len(train_set) + len(val_set) == len(corpus), (
    f"split lost samples: train={len(train_set)} + val={len(val_set)} != corpus={len(corpus)}"
)

[split] total_samples=74557, target_val~7456
[split] eligible_projects=304, eligible_samples=74367
[split] ineligible_projects=16, ineligible_samples=190
[split] actual_val=7456 (target~7456), train=67101


## 9. 训练并保存（模型与训练日志）

In [11]:
# Global step at which the frozen base token embeddings (wte.base.weight) are unfrozen.
FREEZE_STEPS = 5000

# Start with the base token-embedding weights frozen; UnfreezeWteBaseAtStepCallback
# (added below) unfreezes them at FREEZE_STEPS.
freeze_wte_base(model, freeze=True)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_set,
    eval_dataset=val_set,
    data_collator=data_collator,
    # Stop after 10 consecutive evaluations without improvement in eval_loss.
    callbacks=[EarlyStoppingCallback(early_stopping_patience=10)],
)

# Build the optimizer and LR schedule manually and assign them to the Trainer,
# which only creates its own when these attributes are unset.
# NOTE(review): presumably build_optimizer_no_filter keeps the currently frozen params
# in the optimizer so that unfreezing them later takes effect — confirm in MiCoGPT.utils.freeze.
num_training_steps = compute_num_training_steps(trainer)

trainer.optimizer = build_optimizer_no_filter(trainer.model, training_args)
trainer.lr_scheduler = get_scheduler(
    name=training_args.lr_scheduler_type,
    optimizer=trainer.optimizer,
    num_warmup_steps=training_args.get_warmup_steps(num_training_steps),
    num_training_steps=num_training_steps,
)

trainer.add_callback(UnfreezeWteBaseAtStepCallback(unfreeze_step=FREEZE_STEPS))
trainer.add_callback(PriorGateStatsOnEvalLogCallback(token_ids=genus_token_ids, prefix="gp"))

print("Start training...")
# Sanity check that the gated-prior modules from §7 replaced the stock embedding and head.
print(type(model.transformer.wte))
print(type(model.lm_head))

try:
    trainer.train()
finally:
    # ========= Save training log =========
    # Runs even if training is interrupted (e.g. KeyboardInterrupt) or raises,
    # so the partial log history is not lost.
    logs = pd.DataFrame(trainer.state.log_history)
    os.makedirs(args.log, exist_ok=True)
    log_path = os.path.join(args.log, "pretrain_log.csv")
    logs.to_csv(log_path, index=False)
    print("Logs saved to:", log_path)

# ========= Save model =========
# Only reached when training finished normally. With load_best_model_at_end=True,
# transformers restores the best checkpoint (by eval_loss) before this point.
os.makedirs(args.output, exist_ok=True)
trainer.save_model(args.output)
print("Model saved to:", args.output)


Start training...
<class 'MiCoGPT.utils.pretrain.GatedPriorEmbedding'>
<class 'MiCoGPT.utils.pretrain.GatedPriorLMHead'>


Step,Training Loss,Validation Loss,Base Req,Bdelta50,Bdmax,W50,W90,Wm,Bn50,Pn50,Gn50,R50,Cos50
500,5.3094,4.984934,0.0,0.0,0.0,0.491218,0.521055,0.492452,0.319097,2.171735,1.072374,3.366964,0.000983
1000,4.8307,4.549277,0.0,0.0,0.0,0.4607,0.571678,0.46938,0.319097,2.171735,1.027142,3.216477,0.000983
1500,4.6388,4.32366,0.0,0.0,0.0,0.437875,0.609882,0.449466,0.319097,2.171735,0.975529,3.042498,0.000983
2000,4.5075,4.202103,0.0,0.0,0.0,0.428341,0.637755,0.43675,0.319097,2.171735,0.945052,2.94251,0.000983
2500,4.4201,4.12562,0.0,0.0,0.0,0.41979,0.660049,0.428162,0.319097,2.171735,0.927126,2.903378,0.000983
3000,4.3816,4.062508,0.0,0.0,0.0,0.413169,0.673465,0.421823,0.319097,2.171735,0.915294,2.843853,0.000983
3500,4.2913,4.020043,0.0,0.0,0.0,0.405929,0.684989,0.417097,0.319097,2.171735,0.907108,2.810335,0.000983
4000,4.2148,3.980726,0.0,0.0,0.0,0.40097,0.694897,0.413525,0.319097,2.171735,0.891733,2.781589,0.000983
4500,4.2225,3.94893,0.0,0.0,0.0,0.397141,0.702833,0.410795,0.319097,2.171735,0.886614,2.759792,0.000983
5000,4.1746,3.921289,0.0,0.0,0.0,0.393191,0.708501,0.408754,0.319097,2.171735,0.874168,2.761177,0.000983



[freeze] Unfroze wte.base.weight at global_step=5000

Model saved to: ../models/pretrain_ResMicroDB_90338_GATED_v6_freeze
Logs saved to: ../logs/pretrain_ResMicroDB_90338_GATED_v6_freeze/pretrain_log.csv
