## 0. 导入依赖

In [24]:
import os
import warnings
import pandas as pd
from pickle import load
from torch.utils.data import random_split
from importlib.resources import files
from configparser import ConfigParser
from argparse import Namespace
from transformers import (
    GPT2LMHeadModel,
    GPT2Config,
    Trainer,
    TrainingArguments,
    DataCollatorForLanguageModeling,
)
from transformers.trainer_callback import EarlyStoppingCallback
from MiCoGPT.utils.corpus import MiCoGPTCorpus
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset, Subset


warnings.filterwarnings("ignore")

## 1. 基础设置

In [25]:
# Run-level paths, bundled in a Namespace so later cells can read `args.<name>`.
args = Namespace(
    **{
        # Pickled corpus that is loaded in the next step.
        "input": "../data/try2_withCC/ResMicroDB_90338.pkl",
        # Directory the final model is saved to.
        "output": "../models/pretrain_ResMicroDB_90338",
        # Root directory for logs and checkpoints.
        "log": "../logs/pretrain_ResMicroDB_90338",
    }
)

## 2. 加载 corpus 和 tokenizer

In [26]:
# Load the pickled corpus and keep only the samples whose Split_Group is "A".
# NOTE(review): unpickling can execute arbitrary code stored in the file, so
# only load corpus pickles that come from a trusted source.
with open(args.input, "rb") as corpus_file:  # closes the handle after loading
    all_corpus = load(corpus_file)

corpus = all_corpus.subset_by_metadata(
    lambda df: df["Split_Group"] == "A"
)
# The tokenizer is taken from the full corpus object.
tokenizer = all_corpus.tokenizer

# Summary of what was loaded.
print("Number of samples in all_corpus:", len(all_corpus))
print("Number of samples in corpus:", len(corpus))
print(all_corpus.metadata["Split_Group"].value_counts())
print("Tokenizer vocab size:", tokenizer.vocab_size)

Number of samples in all_corpus: 90338
Number of samples in corpus: 74557
Split_Group
A    74557
B    13901
C     1880
Name: count, dtype: int64
Tokenizer vocab size: 1121


## 3. 构建 GPT2Config

In [27]:
# Read the config.ini that ships inside the MiCoGPT package.
config_path = files("MiCoGPT")/"resources/config.ini"
cfg = ConfigParser()
# ConfigParser.read() silently skips files it cannot open and returns the list
# of files it actually parsed, so fail here with a clear message instead of
# hitting an opaque NoSectionError on the first cfg.get() below.
if not cfg.read(config_path):
    raise FileNotFoundError(f"Could not read MiCoGPT config file: {config_path}")

# Keyword arguments for GPT2Config. Architecture sizes come from the [GPT2]
# section of config.ini; vocabulary size and special-token ids come from the
# tokenizer. Concrete values quoted below are the ones recorded in the
# original notes for this run.
gpt2_config_dict = {
    # Model type ("gpt2").
    "model_type":   cfg.get("GPT2", "model_type"),
    # Tokenizer vocabulary size (1121).
    "vocab_size":   tokenizer.vocab_size,
    # Maximum supported sequence length, i.e. number of position embeddings (512).
    "n_positions":  cfg.getint("GPT2", "n_positions"),
    # Hidden size / embedding dimension (256).
    "n_embd":       cfg.getint("GPT2", "n_embd"),
    # Number of Transformer blocks (8).
    "n_layer":      cfg.getint("GPT2", "n_layer"),
    # Number of self-attention heads (8).
    "n_head":       cfg.getint("GPT2", "n_head"),
    # bos_token_id (1119).
    "bos_token_id": tokenizer.bos_token_id,
    # eos_token_id (1120).
    "eos_token_id": tokenizer.eos_token_id,
    # pad_token_id (0).
    "pad_token_id": tokenizer.pad_token_id,
}

# Other GPT2Config fields that are not set here (values as recorded in the
# original notes):
# attn_pdrop: 0.1, dropout ratio applied to attention during training
# embd_pdrop: 0.1, dropout ratio applied to embeddings during training
# resid_pdrop: 0.1, dropout ratio applied to residual connections during training
# layer_norm_epsilon: 1e-05, layer-norm epsilon, guards against division by zero
# initializer_range: 0.02, range used when initialising weights
# activation_function: "gelu_new", activation function (the "new" GELU variant)
# scale_attn_weights: true, whether attention weights are scaled
# scale_attn_by_inverse_layer_idx: false, whether attention weights are scaled by the inverse layer index
# reorder_and_upcast_attn: false, whether attention is reordered and upcast to float32
# summary_type: "cls_index", sequence-summary type
# summary_use_proj: true, whether the sequence summary is projected
# summary_activation: null, activation applied to the sequence summary
# summary_first_dropout: 0.1, dropout ratio applied to the sequence summary during training
# summary_proj_to_labels: true, whether the sequence summary is projected to label space
# use_cache: true, whether caching is used
# transformers_version: "4.33.3"

config = GPT2Config(**gpt2_config_dict)
# config

## 4. 构建 TrainingArguments（从 cfg 的 [pretrain] 段读取）

In [28]:
# Keyword arguments for TrainingArguments. Tunable numbers are read from the
# [pretrain] section of config.ini; values quoted in the comments are the ones
# recorded in the original notes for this run.
training_args_dict = dict(
    # --- train / eval / checkpoint schedule ---
    # Run the training loop when trainer.train() is called.
    do_train=True,
    # Run validation during training, following evaluation_strategy.
    do_eval=True,
    # Evaluate on a step schedule ...
    evaluation_strategy="steps",
    # ... every eval_steps training steps (500).
    eval_steps=cfg.getint("pretrain", "eval_steps"),
    # Save checkpoints on a step schedule ...
    save_strategy="steps",
    # ... every save_steps training steps (500).
    save_steps=cfg.getint("pretrain", "save_steps"),

    # --- batching ---
    # Group samples of similar length into the same batch.
    group_by_length=True,
    # Name of the dataset column holding the sequence length.
    # NOTE(review): confirm the dataset actually exposes a "length" column.
    length_column_name="length",

    # Keep the progress bar visible.
    disable_tqdm=False,

    # --- optimisation ---
    # Learning rate (1e-3).
    learning_rate=cfg.getfloat("pretrain", "learning_rate"),
    # Linear learning-rate schedule.
    lr_scheduler_type="linear",
    # Warm-up steps (1000): the learning rate ramps from 0 to its target.
    warmup_steps=cfg.getint("pretrain", "warmup_steps"),
    # Weight-decay coefficient (0.001), used against overfitting.
    weight_decay=cfg.getfloat("pretrain", "weight_decay"),
    # Training batch size per GPU / CPU (32).
    per_device_train_batch_size=cfg.getint(
        "pretrain", "per_device_train_batch_size"
    ),
    # Number of passes over the full training set (50).
    num_train_epochs=cfg.getint("pretrain", "num_train_epochs"),

    # --- logging and output ---
    # Log loss, learning rate, etc. every logging_steps steps (100).
    logging_steps=cfg.getint("pretrain", "logging_steps"),
    # Checkpoints go under args.log/pretrain_checkpoints; logs go to args.log.
    output_dir=f"{args.log}/pretrain_checkpoints",
    logging_dir=args.log,
    # After training, reload the checkpoint that did best on the validation set.
    load_best_model_at_end=True,
)

# Other TrainingArguments fields that are not set here (defaults as recorded
# in the original notes):
# adam_beta1=0.9,  Adam beta1 hyper-parameter
# adam_beta2=0.999,  Adam beta2 hyper-parameter
# adam_epsilon=1e-08,  Adam epsilon hyper-parameter
# optim=adamw_torch,  AdamW optimizer
# max_grad_norm=1.0,  maximum gradient norm
# fp16=False,  fp16 mixed-precision training
# bf16=False,  bf16 mixed-precision training
# no_cuda=False,  disable CUDA
# use_cpu=False,  train on CPU
# use_mps_device=False,  train on the MPS device (Apple Silicon)
# per_device_eval_batch_size=8,  evaluation batch size per GPU / CPU
# gradient_accumulation_steps=1,  gradient accumulation steps
# dataloader_num_workers=0,  number of data-loader workers
# dataloader_pin_memory=True,  load data into pinned memory
# dataloader_drop_last=False,  drop the last incomplete batch
# seed=42,  random seed
# skip_memory_metrics=True,  skip memory-metric computation
# ddp_backend=None,  DDP backend (large multi-GPU training)
# fsdp=[],  FSDP configuration (large multi-GPU training)
# deepspeed=None,  DeepSpeed configuration (large multi-GPU training)
# sharded_ddp=[],  Sharded DDP configuration (large multi-GPU training)
# push_to_hub=False,  upload the model to the Hugging Face Hub
# hub_strategy=every_save,  upload strategy
# report_to=[],  services that metrics are reported to


training_args = TrainingArguments(**training_args_dict)
# training_args

## 5. 构建数据 collator + 初始化模型

In [29]:
print("Start training...")

# Build the GPT-2 language model straight from `config`; no pretrained
# weights are loaded in this cell.
model = GPT2LMHeadModel(config)
print("Training from scratch.")

# Collator for causal language modelling: mlm=False switches off masked
# language modelling (the BERT-style objective).
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

# Switch the model to training mode.
model.train()
# model

Start training...
Training from scratch.


GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(1121, 256)
    (wpe): Embedding(512, 256)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-7): 8 x GPT2Block(
        (ln_1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=256, out_features=1121, bias=False)
)

## 5.1 加入 DNABERT-S 的 embedding 后 PCA 的结果

In [23]:
import numpy as np
import torch

def load_genus_embeddings(npz_path: str):
    """Load genus names and their embedding matrix from an ``.npz`` archive.

    The archive must contain the entries ``"genus"`` and ``"embeddings"``.
    The number of genera and the embedding shape are printed as a summary.

    Args:
        npz_path: Path of the ``.npz`` file to read.

    Returns:
        A ``(genus, emb)`` tuple. ``genus`` is the ``"genus"`` entry converted
        to a NumPy array of ``str``; ``emb`` is the ``"embeddings"`` entry as
        stored.

    Raises:
        KeyError: If either entry is missing from the archive.
    """
    # NOTE(review): allow_pickle=True lets NumPy unpickle object arrays, which
    # can execute arbitrary code; only load archives from a trusted source.
    # The context manager closes the underlying archive, which a bare
    # np.load() call leaves open.
    with np.load(npz_path, allow_pickle=True) as data:
        genus = np.array(data["genus"], dtype=str)
        emb = data["embeddings"]
    print(f"[load_genus_embeddings] genus 数量: {genus.shape[0]}")
    print(f"[load_genus_embeddings] embeddings 形状: {emb.shape}")
    return genus, emb

def build_genus_to_id(genus_array, tokenizer):
    """Map each genus name to its token id by asking the tokenizer directly.

    No assumption is made about the order of the vocabulary: every genus is
    looked up individually through ``tokenizer.convert_tokens_to_ids``.

    A genus counts as missing when the lookup raises ``KeyError``, returns
    ``None``, or falls back to the tokenizer's unknown-token id (for a genus
    that is not itself the unknown token). Without the last two checks, a
    tokenizer that maps unknown tokens to its unk id would report every genus
    as found.

    Args:
        genus_array: Iterable of genus names.
        tokenizer: Object providing ``convert_tokens_to_ids(token)``. The
            attributes ``unk_token`` and ``unk_token_id`` are used when present.

    Returns:
        A ``(genus_to_id, missing)`` tuple: a dict from genus name to token id,
        and the list of genus names that have no token of their own.
    """
    # Read once; both are optional so that minimal tokenizers keep working.
    unk_token = getattr(tokenizer, "unk_token", None)
    unk_token_id = getattr(tokenizer, "unk_token_id", None)

    genus_to_id = {}
    missing = []
    for g_str in genus_array:
        try:
            token_id = tokenizer.convert_tokens_to_ids(g_str)
        except KeyError:
            missing.append(g_str)
            continue

        fell_back_to_unk = (
            unk_token_id is not None
            and token_id == unk_token_id
            and g_str != unk_token
        )
        if token_id is None or fell_back_to_unk:
            missing.append(g_str)
        else:
            genus_to_id[g_str] = token_id

    print(f"[build_genus_to_id] 成功找到 {len(genus_to_id)} 个 genus 的 token_id")
    if missing:
        print(f"[build_genus_to_id] 有 {len(missing)} 个 genus 在 tokenizer 中找不到，例如: {missing[:5]} ...")
    else:
        print("[build_genus_to_id] 所有 genus 都能在 tokenizer 中找到 id")
    return genus_to_id, missing

def init_token_embeddings_from_genus(model, tokenizer, npz_path: str):
    """Initialise GPT-2 token embeddings from DNABERT+PCA genus embeddings.

    Loads the genus embeddings from ``npz_path``, looks up the token id of
    each genus through the tokenizer, and copies each genus vector into the
    matching row of ``model.transformer.wte.weight``. Rows whose token has no
    genus vector are left untouched.

    Args:
        model: Model exposing ``model.transformer.wte.weight`` with shape
            ``(vocab_size, n_embd)``; it is modified in place.
        tokenizer: Tokenizer used to resolve genus names to token ids.
        npz_path: Path of the ``.npz`` archive with the genus embeddings.

    Returns:
        None.

    Raises:
        ValueError: If the embedding dimension differs from the model's
            embedding dimension, or if the number of genus names differs from
            the number of embedding rows.
    """
    print(">>> 开始用 genus_embeddings_256.npz 初始化 GPT2 embedding")

    genus_array, emb = load_genus_embeddings(npz_path)
    genus_to_id, missing = build_genus_to_id(genus_array, tokenizer)

    # Sanity checks. These raise instead of using `assert`, because asserts
    # are stripped when Python runs with -O.
    wte = model.transformer.wte.weight  # (vocab_size, n_embd)
    n_genus, emb_dim = emb.shape
    vocab_size, model_dim = wte.shape

    if emb_dim != model_dim:
        raise ValueError(
            f"embedding 维度不匹配: embeddings 是 {emb_dim}，"
            f"但 model 的 n_embd 是 {model_dim}"
        )
    # Rows are paired with genus names by position below, so the two must line up.
    if len(genus_array) != n_genus:
        raise ValueError(
            f"genus 数量 ({len(genus_array)}) 与 embeddings 行数 ({n_genus}) 不一致"
        )

    print(f"[init] 模型 vocab_size = {vocab_size}, n_embd = {model_dim}")
    print(f"[init] 将对齐写入 {len(genus_to_id)} 个 genus 的向量")

    # Copy the vectors in, matching the device and dtype of the embedding table.
    # NOTE(review): if the model ties lm_head to wte, these writes also change
    # the output projection. Confirm whether that is intended.
    device = wte.device
    dtype = wte.dtype
    written = 0

    with torch.no_grad():  # in-place write into a parameter, not tracked by autograd
        for i, g_str in enumerate(genus_array):
            if g_str not in genus_to_id:
                continue
            token_id = genus_to_id[g_str]
            vec = torch.from_numpy(emb[i]).to(device=device, dtype=dtype)
            wte[token_id].copy_(vec)
            written += 1

    print(f"[init] 实际写入 {written} 个 token 的 embedding")
    if missing:
        print(f"[init] {len(missing)} 个 genus 没有对应 token，保持原始随机初始化")
    print(">>> GPT2 embedding 初始化完成（特殊 token 如 <pad>/<bos>/<eos> 也保持原始初始化）")

# ==== Run the embedding initialisation ====
# The .npz file is looked up inside the MiCoGPT package, under resources/
# (same location as used previously).
npz_path = files("MiCoGPT").joinpath("resources").joinpath("genus_embeddings_256.npz")
init_token_embeddings_from_genus(model, tokenizer, npz_path)


>>> 开始用 genus_embeddings_256.npz 初始化 GPT2 embedding
[load_genus_embeddings] genus 数量: 1117
[load_genus_embeddings] embeddings 形状: (1117, 256)
[build_genus_to_id] 成功找到 1117 个 genus 的 token_id
[build_genus_to_id] 所有 genus 都能在 tokenizer 中找到 id
[init] 模型 vocab_size = 1121, n_embd = 256
[init] 将对齐写入 1117 个 genus 的向量
[init] 实际写入 1117 个 token 的 embedding
>>> GPT2 embedding 初始化完成（特殊 token 如 <pad>/<bos>/<eos> 也保持原始初始化）


## 6. 划分 train/val、构建 Trainer 并训练 + 保存模型和日志

In [30]:
import numpy as np
from torch.utils.data import Subset

def split_train_val_by_project(dataset, val_ratio=0.1, project_col="Project_ID", random_state=42):
    """Split a dataset into train / validation subsets along project boundaries.

    Whole projects are moved into the validation set, in a seeded random
    order, until the validation set holds at least ``int(n * val_ratio)``
    samples, where ``n`` is the number of samples in ``dataset``. A project
    therefore never appears in both subsets. Samples whose project id is
    missing always stay in the training subset.

    Args:
        dataset: Either a corpus object exposing ``.metadata`` (a DataFrame)
            and ``len()``, or a ``Subset`` wrapping such a corpus. For a
            ``Subset`` only the samples it contains are split.
        val_ratio: Target fraction of samples for the validation subset.
        project_col: Metadata column holding the project id.
        random_state: Seed for the project shuffle.

    Returns:
        ``(train_set, val_set)``, both ``Subset`` objects over the underlying
        corpus.

    Raises:
        ValueError: If the corpus metadata is ``None`` or lacks ``project_col``.
    """
    # Resolve the underlying corpus and the row positions, within that corpus,
    # of the samples being split.
    wraps_subset = isinstance(dataset, Subset)
    base_corpus = dataset.dataset if wraps_subset else dataset
    base_indices = np.array(dataset.indices) if wraps_subset else np.arange(len(dataset))

    meta_full = base_corpus.metadata
    if meta_full is None:
        raise ValueError("base_corpus.metadata 为空，无法按 Project_ID 划分。")
    if project_col not in meta_full.columns:
        raise ValueError(f"metadata 中没有列 '{project_col}'，请检查列名。")

    # Restrict the metadata to the samples being split.
    meta = meta_full.iloc[base_indices].copy()
    n_samples = len(meta)
    target_val = int(n_samples * val_ratio)

    # Distinct project ids (missing values dropped), shuffled with the seed.
    project_ids = meta[project_col].to_numpy()
    shuffled_projects = np.unique(project_ids[pd.notna(project_ids)])
    np.random.default_rng(random_state).shuffle(shuffled_projects)

    # is_val[i] is True when local sample i goes to the validation subset.
    is_val = np.zeros(n_samples, dtype=bool)
    val_projects = []
    collected = 0
    for pid in shuffled_projects:
        if collected >= target_val:
            break

        members = project_ids == pid
        member_count = members.sum()
        if member_count == 0:
            continue  # not expected, kept as a defensive guard

        # The whole project goes to validation.
        is_val |= members
        val_projects.append(pid)
        collected += member_count

    # Translate the local mask back to positions in the underlying corpus.
    train_set = Subset(base_corpus, base_indices[~is_val].tolist())
    val_set = Subset(base_corpus, base_indices[is_val].tolist())

    # Report the actual subset sizes.
    print(
        f"[split_by_project] 选中 {len(val_projects)} 个 project 作为验证集，"
        f"验证样本数 {len(val_set)}，目标约 {target_val}，当前 dataset 样本数 {n_samples}"
    )
    print(f"[split_by_project] Train samples: {len(train_set)}, Val samples: {len(val_set)}")

    return train_set, val_set


In [31]:
# Project-level train / validation split with a 0.1 validation ratio.
train_set, val_set = split_train_val_by_project(
    corpus, val_ratio=0.1, project_col="Project_ID", random_state=42
)
print(f"Train samples: {len(train_set)}, Val samples: {len(val_set)}")

# Early-stopping callback with a patience of 3.
callbacks = [EarlyStoppingCallback(early_stopping_patience=3)]

# Assemble the Trainer from the pieces built in the previous steps.
trainer = Trainer(
    args=training_args,
    model=model,
    data_collator=data_collator,
    train_dataset=train_set,
    eval_dataset=val_set,
    callbacks=callbacks,
)

# Run training.
trainer.train()

# Save the final model, creating the output directory if needed.
os.makedirs(args.output, exist_ok=True)
trainer.save_model(args.output)
print("Model saved to:", args.output)

# Dump the trainer's log history to a CSV file under the log directory.
logs = pd.DataFrame(trainer.state.log_history)
os.makedirs(args.log, exist_ok=True)
log_path = os.path.join(args.log, "pretrain_log.csv")
logs.to_csv(log_path, index=False)
print("Logs saved to:", log_path)

[split_by_project] 选中 43 个 project 作为验证集，验证样本数 7555，目标约 7455，当前 dataset 样本数 74557
[split_by_project] Train samples: 67002, Val samples: 7555
Train samples: 67002, Val samples: 7555


Step,Training Loss,Validation Loss
500,4.4191,4.581467
1000,4.2735,4.421101
1500,4.1807,4.337917
2000,4.0935,4.293813
2500,4.0351,4.274549
3000,3.9834,4.238392
3500,3.9896,4.220387
4000,3.8945,4.20572
4500,3.9261,4.194244
5000,3.8949,4.165286


Model saved to: ../models/pretrain_ResMicroDB_90338
Logs saved to: ../logs/pretrain_log.csv
