In [1]:
import os,torch
import numpy as np
import pandas as pd
from configparser import ConfigParser
from argparse import Namespace
from pickle import load as pkl_load
from importlib.resources import files
from joblib import dump
from torch.utils.data import Subset
from transformers import Trainer,TrainingArguments,default_data_collator
from transformers.trainer_callback import EarlyStoppingCallback

from MiCoGPT.utils.finetune_v2 import prepare_labels_for_subset,get_raw_labels_from_subset,load_model_compat,SubsetWithLabels,FinetuneDataset
from MiCoGPT.utils.finetune import split_train_val_by_project_stratified_with_labels

In [2]:
cfg = ConfigParser()
cfg.read("../MiCoGPT/resources/config.ini")

input_corpus_path      = "../data/try2_withCC/ResMicroDB_90338_new.pkl"
pretrained_model_path  = "../models/pretrain_ResMicroDB_90338_GATED_base_wte"
output_model_dir       = "../models/finetuned_v5_pretrain_ResMicroDB_90338_GATED_base_wte_sampleSite"
log_dir                = "../logs/finetuned_v5_pretrain_ResMicroDB_90338_GATED_base_wte_sampleSite"
val_split              = 0.2

# ========= 新增：与你当前任务强相关的参数 =========
label_col              = "Sample_Site"   # 你以前用的标签列
subset_split_group     = "A"            # 只用 Split_Group == A
drop_na_label          = True           # Is_Healthy is NA 的样本：直接丢弃（不训练也不预测）

# gated-prior 相关（v6/v9 才用得到；普通 GPT2 无影响）
g_min                  = 0.0            # 必须与 pretraining 时一致

# 训练超参（你可以先用默认，后面再调）
batch_size             = 64
grad_accum             = 1
lr                     = 1e-5
epochs                 = 1000
patience               = 5

# （可选）按组划分，防止同一 project 泄漏
use_group_split        = False
group_col              = "Project_ID"

args = Namespace(
    input=input_corpus_path,
    model=pretrained_model_path,
    output=output_model_dir,
    log=log_dir,
    val_split=val_split,

    label_col=label_col,
    split_group=subset_split_group,
    drop_na_label=drop_na_label,

    g_min=g_min,

    batch_size=batch_size,
    grad_accum=grad_accum,
    lr=lr,
    epochs=epochs,
    patience=patience,

    use_group_split=use_group_split,
    group_col=group_col,
)

args


Namespace(input='../data/try2_withCC/ResMicroDB_90338_new.pkl', model='../models/pretrain_ResMicroDB_90338_GATED_base_wte', output='../models/finetuned_v5_pretrain_ResMicroDB_90338_GATED_base_wte_sampleSite', log='../logs/finetuned_v5_pretrain_ResMicroDB_90338_GATED_base_wte_sampleSite', val_split=0.2, label_col='Sample_Site', split_group='A', drop_na_label=True, g_min=0.0, batch_size=64, grad_accum=1, lr=1e-05, epochs=1000, patience=5, use_group_split=False, group_col='Project_ID')

In [3]:


# 读取你保存的 MiCoGPTCorpus（all_corpus）
with open(args.input, "rb") as f:
    all_corpus = pkl_load(f)

tokenizer = all_corpus.tokenizer

# ====== 可选但强烈建议：确保 tokenizer 有 pad_token（GPT2 常见没有）======
# 你的 corpus 通常已经有固定长度+attention_mask，但某些 collator/Trainer 仍可能关心 pad_token_id
if getattr(tokenizer, "pad_token_id", None) is None:
    # 常见做法：把 eos 当 pad（只要与你构建语料时的策略一致即可）
    tokenizer.pad_token = tokenizer.eos_token

print("[Tokenizer]")
print("  vocab_size:", getattr(tokenizer, "vocab_size", None))
print("  pad_token_id:", getattr(tokenizer, "pad_token_id", None))
print("  eos_token_id:", getattr(tokenizer, "eos_token_id", None))

# 你想作为微调集合的样本（Split_Group == A 且 Is_Healthy 非空）
finetune_subset = all_corpus.subset_by_metadata(
    lambda df: (df["Split_Group"] == args.split_group) & df[args.label_col].notna()
)

print("Number of samples in all_corpus:", len(all_corpus))
print("Number of samples in finetune_subset:", len(finetune_subset))
print(all_corpus.metadata["Split_Group"].value_counts(dropna=False))


[Tokenizer]
  vocab_size: 1121
  pad_token_id: 0
  eos_token_id: 3
Number of samples in all_corpus: 90338
Number of samples in finetune_subset: 74557
Split_Group
A    74557
B    13901
C     1880
Name: count, dtype: int64


In [4]:
labels_tensor, all_labels, le, num_labels = prepare_labels_for_subset(
    all_corpus=all_corpus,
    subset=finetune_subset,
    label_col="Sample_Site",
    verbose=True,
)

labels_tensor[:10], num_labels

[labels] label_col = Sample_Site
[labels] num_labels = 10
[labels] label2id: {'BALF': 0, 'Bronchus': 1, 'Lung Tissue': 2, 'Nasal': 3, 'Nasopharynx': 4, 'Oropharynx': 5, 'Pharynx': 6, 'Sputum': 7, 'Throat': 8, 'Trachea': 9}
[labels] label counts:
Nasopharynx    18976
Nasal          14183
Sputum         12230
Oropharynx      8407
Trachea         6349
BALF            5599
Pharynx         5277
Throat          2025
Lung Tissue      782
Bronchus         729
Name: count, dtype: int64


(tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0]), 10)

In [5]:
npz_path = files("MiCoGPT") / "resources" / "genus_embeddings_256.npz"  # 可留可删
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 加载器：
#    - 如果 args.model 是普通 GPT2（或者你 base 的本地目录），直接加载 GPT2ForSequenceClassification
#    - 如果 args.model 是 v6/v9 gated checkpoint，本地权重里会有 "transformer.wte.base.weight" 和 "transformer.wte.gate_logits"
#      加载器会自动 patch wte，并识别 v6(1D gate) / v9(2D gate) 然后再把权重灌进去
model = load_model_compat(
    model_name_or_path=args.model,
    num_labels=num_labels,
    g_min=0.0,   # 这里务必与你预训练时 g_min 一致（你之前写的是 0.0 就保持 0.0）
)

model.to(device)
model.train()


model


[compat] missing keys (first 20): ['score.weight']
[compat] unexpected keys (first 20): ['lm_head.weight']


GPT2ForSequenceClassification(
  (transformer): GPT2Model(
    (wte): Embedding(1121, 256)
    (wpe): Embedding(512, 256)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-7): 8 x GPT2Block(
        (ln_1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
  )
  (score): Linear(in_features=256, out_features=10, bias=False)
)

In [6]:
from transformers import TrainingArguments

training_args_dict = {
    # 学习率等超参从 config.ini 读（保留你的习惯）
    "learning_rate": cfg.getfloat("finetune", "learning_rate"),
    "warmup_steps": cfg.getint("finetune", "warmup_steps"),
    "weight_decay": cfg.getfloat("finetune", "weight_decay"),

    # 训练/评估开关
    "do_train": True,
    "do_eval": True,

    # 你的数据是定长 tensor（MiCoGPTCorpus 直接给 input_ids/attention_mask），不需要按长度分桶
    "group_by_length": False,
    "disable_tqdm": False,

    # scheduler
    "lr_scheduler_type": "linear",

    # batch & epoch
    "per_device_train_batch_size": cfg.getint("finetune", "per_device_train_batch_size"),
    "per_device_eval_batch_size": cfg.getint("finetune", "per_device_train_batch_size"),  # ✅ 建议显式设置，避免默认值不一致
    "num_train_epochs": cfg.getint("finetune", "num_train_epochs"),

    # 保存与评估策略（保留你的 epoch 级别习惯）
    "save_strategy": "epoch",
    "evaluation_strategy": "epoch",
    "logging_steps": cfg.getint("finetune", "logging_steps"),

    # 输出目录：注意 output_dir 是 checkpoint 的保存位置（你原来就这么写）
    "output_dir": f"{args.log}/finetune_checkpoints",
    "logging_dir": args.log,

    # 选最优模型（用 eval_loss 最稳，不依赖自定义 metrics）
    "load_best_model_at_end": True,
    "metric_for_best_model": "eval_loss",
    "greater_is_better": False,
}

training_args = TrainingArguments(**training_args_dict)
training_args


TrainingArguments(
_n_gpu=1,
adafactor=False,
adam_beta1=0.9,
adam_beta2=0.999,
adam_epsilon=1e-08,
auto_find_batch_size=False,
bf16=False,
bf16_full_eval=False,
data_seed=None,
dataloader_drop_last=False,
dataloader_num_workers=0,
dataloader_pin_memory=True,
ddp_backend=None,
ddp_broadcast_buffers=None,
ddp_bucket_cap_mb=None,
ddp_find_unused_parameters=None,
ddp_timeout=1800,
debug=[],
deepspeed=None,
disable_tqdm=False,
dispatch_batches=None,
do_eval=True,
do_predict=False,
do_train=True,
eval_accumulation_steps=None,
eval_delay=0,
eval_steps=None,
evaluation_strategy=epoch,
fp16=False,
fp16_backend=auto,
fp16_full_eval=False,
fp16_opt_level=O1,
fsdp=[],
fsdp_config={'min_num_params': 0, 'xla': False, 'xla_fsdp_grad_ckpt': False},
fsdp_min_num_params=0,
fsdp_transformer_layer_cls_to_wrap=None,
full_determinism=False,
gradient_accumulation_steps=1,
gradient_checkpointing=False,
greater_is_better=False,
group_by_length=False,
half_precision_backend=auto,
hub_always_push=False,
hub_mod

In [7]:
train_subset, val_subset = split_train_val_by_project_stratified_with_labels(
    finetune_subset,
    label_col="Sample_Site",
    project_col="Project_ID",
    val_ratio=args.val_split,
    min_project_samples=20,
    min_val_per_project=2,
    random_state=42,
    label_balance_strength=0,  # 先用 1.0；想更强拉平就 2.0；不管标签就 0
)

print("train_subset:", len(train_subset))
print("val_subset:", len(val_subset))

[split] total_samples=74557, target_val~14911
[split] eligible_projects=304, eligible_samples=74367
[split] ineligible_projects=16, ineligible_samples=190
[split] label_dist (overall):
Sample_Site
Nasopharynx    18976
Nasal          14183
Sputum         12230
Oropharynx      8407
Trachea         6349
BALF            5599
Pharynx         5277
Throat          2025
Lung Tissue      782
Bronchus         729
Name: count, dtype: int64
[split] actual_val=14911 (target~14911), train=59646
[split] label_dist (val):
Sample_Site
Nasopharynx    3809
Nasal          2831
Sputum         2444
Oropharynx     1687
Trachea        1264
BALF           1116
Pharynx        1062
Throat          393
Bronchus        157
Lung Tissue     148
Name: count, dtype: int64
train_subset: 59646
val_subset: 14911


In [8]:
print("Start training...")

# ========= 1) 解决：train_subset / val_subset 可能是“嵌套 Subset”，需要解析到最底层 base dataset =========
def resolve_subset_to_base_and_indices(ds):
    """
    把 Dataset/Subset 递归展开，返回：
      base_dataset: 最底层数据集（通常应是 all_corpus）
      base_indices: ds 在 base_dataset 上的绝对索引
    """
    if not isinstance(ds, Subset):
        return ds, np.arange(len(ds), dtype=int)

    base, base_idx = resolve_subset_to_base_and_indices(ds.dataset)
    cur_idx = np.asarray(ds.indices, dtype=int)
    return base, base_idx[cur_idx]

# 解析 train/val 到最底层 dataset + 绝对索引
base_train, train_idx = resolve_subset_to_base_and_indices(train_subset)
base_val,   val_idx   = resolve_subset_to_base_and_indices(val_subset)

# 一般情况下两者都会指向同一个 base（通常是 all_corpus）
assert base_train is base_val, "train/val 的 base dataset 不一致，这通常不应该发生。"
base_corpus = base_train

print("train_idx range sample:", train_idx[:10])
print("val_idx range sample:", val_idx[:10])

# ========= 2) 取原始标签（从 metadata 里拿），再用同一个 le 编码成 0..C-1 =========
train_raw = base_corpus.metadata.iloc[train_idx][args.label_col]
val_raw   = base_corpus.metadata.iloc[val_idx][args.label_col]

# 你已经在 finetune_subset 里过滤过 notna，这里再兜底检查一次
assert train_raw.notna().all(), "train 中仍存在 NA 标签，请检查你的 subset 条件。"
assert val_raw.notna().all(),   "val 中仍存在 NA 标签，请检查你的 subset 条件。"

# 用同一个 label encoder（le）做 transform，保证 train/val 编码一致
train_labels = le.transform(train_raw.tolist())
val_labels   = le.transform(val_raw.tolist())

# 你之前的断言（现在标签不再是 -1，而是 0..C-1）
assert (train_labels >= 0).all()
assert (val_labels >= 0).all()

# ========= 3) 构建 Trainer 可用的 Dataset（沿用你原来的 FinetuneDataset 风格） =========
train_dataset = FinetuneDataset(base_corpus, train_idx, train_labels)
val_dataset   = FinetuneDataset(base_corpus, val_idx,   val_labels)

# ========= 4) callbacks + Trainer =========
callbacks = [EarlyStoppingCallback(early_stopping_patience=10)]

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    callbacks=callbacks,

    # ⚠️ 关键：不要传 tokenizer（否则可能触发 tokenizer.save_pretrained -> NotImplementedError）
    tokenizer=None,

    # 你的样本本身就是 tensor + 定长，默认 collator 最稳
    data_collator=default_data_collator,
)

trainer.train()

# ========= 5) 保存模型 =========
os.makedirs(args.output, exist_ok=True)
trainer.save_model(args.output)

# 保存 label encoder（你原来怎么做就怎么做）
dump(le, open(os.path.join(args.output, "label_encoder.pkl"), "wb"))
print(f"Model and label encoder saved to: {args.output}")

# ========= 6) 保存日志 =========
logs = trainer.state.log_history
logs_df = pd.DataFrame(logs)

os.makedirs(args.log, exist_ok=True)
log_path = os.path.join(args.log, "finetune_log.csv")
logs_df.to_csv(log_path, index=False)

print(f"Training logs saved to: {log_path}")
logs_df.tail()


Start training...
train_idx range sample: [ 0  1  3  4  5  6  7  8  9 10]
val_idx range sample: [ 2 18 24 26 27 46 50 53 57 74]


  return {'input_ids': torch.tensor(tokens),


Epoch,Training Loss,Validation Loss
1,1.1661,1.110276
2,0.8381,0.787602
3,0.6828,0.658984
4,0.6643,0.581294
5,0.567,0.52604
6,0.5419,0.483082
7,0.4366,0.45343
8,0.434,0.428294
9,0.4251,0.407238
10,0.4568,0.390238


  return {'input_ids': torch.tensor(tokens),
  return {'input_ids': torch.tensor(tokens),
  return {'input_ids': torch.tensor(tokens),
  return {'input_ids': torch.tensor(tokens),
  return {'input_ids': torch.tensor(tokens),
  return {'input_ids': torch.tensor(tokens),
  return {'input_ids': torch.tensor(tokens),
  return {'input_ids': torch.tensor(tokens),
  return {'input_ids': torch.tensor(tokens),
  return {'input_ids': torch.tensor(tokens),
  return {'input_ids': torch.tensor(tokens),
  return {'input_ids': torch.tensor(tokens),
  return {'input_ids': torch.tensor(tokens),
  return {'input_ids': torch.tensor(tokens),
  return {'input_ids': torch.tensor(tokens),
  return {'input_ids': torch.tensor(tokens),
  return {'input_ids': torch.tensor(tokens),
  return {'input_ids': torch.tensor(tokens),
  return {'input_ids': torch.tensor(tokens),
  return {'input_ids': torch.tensor(tokens),
  return {'input_ids': torch.tensor(tokens),
  return {'input_ids': torch.tensor(tokens),
  return {

Model and label encoder saved to: ../models/finetuned_v5_pretrain_ResMicroDB_90338_GATED_base_wte_sampleSite
Training logs saved to: ../logs/finetuned_v5_pretrain_ResMicroDB_90338_GATED_base_wte_sampleSite/finetune_log.csv


Unnamed: 0,loss,learning_rate,epoch,step,eval_loss,eval_runtime,eval_samples_per_second,eval_steps_per_second,train_runtime,train_samples_per_second,train_steps_per_second,total_flos,train_loss
8991,0.1287,1e-05,47.99,44725,,,,,,,,,
8992,0.0854,1e-05,47.99,44730,,,,,,,,,
8993,0.102,1e-05,48.0,44735,,,,,,,,,
8994,,,48.0,44736,0.281545,22.4564,663.996,10.376,,,,,
8995,,,48.0,44736,,,,,14153.7224,4214.156,65.848,5.559555e+16,0.30096
