In [1]:
import os
import sys
import pandas as pd
import numpy as np
import torch
from pickle import load as pkl_load

from transformers import Trainer, AutoConfig, GPT2ForSequenceClassification


from MiCoGPT.utils.mgm_utils import eval_and_save
from MiCoGPT.utils.corpus import SequenceClassificationDataset


# ====== label encoder 读取：兼容 joblib/pickle ======
def load_label_encoder(path):
    """Load a fitted label encoder saved with joblib or with plain pickle.

    joblib is tried first. If joblib is not installed, or raises for any reason
    while reading the file, the file is re-read with the standard-library pickle
    loader.

    Args:
        path: path to the serialized encoder file.

    Returns:
        The deserialized encoder object.

    Note:
        Both loaders unpickle arbitrary objects; only use this on trusted files.
    """
    try:
        from joblib import load as _joblib_load
        encoder = _joblib_load(path)
    except Exception:
        # joblib unavailable or unable to read the file: fall back to pickle.
        with open(path, "rb") as handle:
            encoder = pkl_load(handle)
    return encoder

def get_label_names_and_num(le):
    """Return the class names of a fitted encoder together with their count.

    Two encoder shapes are recognised (checked in this order):
      - LabelEncoder-like objects exposing ``classes_``
      - OneHotEncoder-like objects exposing ``categories_`` (first feature only)

    Returns:
        (names, count): names is a list of ``str`` in the encoder's order.

    Raises:
        TypeError: if the object has neither ``classes_`` nor ``categories_``.
    """
    if hasattr(le, "classes_"):
        source = le.classes_
    elif hasattr(le, "categories_"):
        source = le.categories_[0]
    else:
        raise TypeError("无法识别 label encoder 类型：既没有 classes_ 也没有 categories_")
    names = list(map(str, source))
    return names, len(names)

def encode_labels_to_ids(le, labels_series: pd.Series) -> np.ndarray:
    """Encode raw labels into integer class ids (0..C-1).

    Works with LabelEncoder-like encoders (``classes_``) and OneHotEncoder-like
    encoders (``categories_``).

    Args:
        le: fitted encoder.
        labels_series: raw label values, one per sample.

    Returns:
        1-D integer array of class ids, aligned with ``labels_series``.

    Raises:
        ValueError: if the OneHotEncoder yields an all-zero row, which happens
            for a label outside its categories under ``handle_unknown="ignore"``.
            argmax would otherwise silently map such a row to class 0.
        TypeError: if the encoder type is not recognised.
    """
    if hasattr(le, "classes_"):
        # LabelEncoder path: labels are passed as strings, so the encoder is
        # presumably fit on string labels — confirm against the training code.
        return le.transform(labels_series.astype(str).tolist())
    if hasattr(le, "categories_"):
        # OneHotEncoder path: one-hot (sparse or dense), then argmax -> id.
        onehot = le.transform(labels_series.values.reshape(-1, 1))
        if hasattr(onehot, "toarray"):  # sparse matrix output
            onehot = onehot.toarray()
        onehot = np.asarray(onehot)
        unknown_rows = ~onehot.any(axis=1)
        if unknown_rows.any():
            bad = pd.unique(labels_series.values[unknown_rows])[:10]
            raise ValueError(
                f"unknown label values (not in encoder categories): {list(bad)}"
            )
        return onehot.argmax(axis=1)
    raise TypeError("无法识别 label encoder 类型：既没有 classes_ 也没有 categories_")

# ====== v6/v9 gated 模型兼容加载器（关键） ======
import torch.nn as nn

class GatedPriorEmbeddingCompat(nn.Module):
    """Token embedding plus a gated prior embedding, for v6/v9 checkpoints.

        out = base(ids) + w * prior_matrix[ids]
        w   = g_min + (1 - g_min) * sigmoid(gate_logits[ids])

    Gate layout is chosen by ``gate_rank``:
      - 1 (v6): ``gate_logits`` has shape [vocab_size]; one scalar gate per token
      - 2 (v9): ``gate_logits`` has shape [vocab_size, n_embd]; one gate per token and dim

    ``prior_matrix`` is registered as a buffer initialised to zeros, so its real
    values presumably have to come from the loaded checkpoint.
    """

    def __init__(self, base: nn.Embedding, vocab_size: int, n_embd: int, gate_rank: int, g_min: float = 0.0):
        super().__init__()
        dtype = base.weight.dtype
        self.base = base
        self.g_min = float(g_min)
        self.register_buffer("prior_matrix", torch.zeros(vocab_size, n_embd, dtype=dtype))

        gate_shapes = {1: (vocab_size,), 2: (vocab_size, n_embd)}
        if gate_rank not in gate_shapes:
            raise ValueError(f"gate_rank must be 1 or 2, got {gate_rank}")
        self.gate_logits = nn.Parameter(torch.zeros(*gate_shapes[gate_rank], dtype=dtype))

    def forward(self, input_ids: torch.LongTensor):
        # Shapes noted for input_ids of shape [B, T].
        embedded = self.base(input_ids)                              # [B, T, D]
        prior = self.prior_matrix[input_ids].to(embedded.dtype)      # [B, T, D]
        gate = torch.sigmoid(self.gate_logits[input_ids])            # [B, T] or [B, T, D]
        weight = self.g_min + (1.0 - self.g_min) * gate
        if self.gate_logits.dim() == 1:
            # One scalar per token: broadcast it across the embedding dim.
            weight = weight.unsqueeze(-1)
        return embedded + weight * prior

def _find_checkpoint_files(model_dir: str):
    bin_path = os.path.join(model_dir, "pytorch_model.bin")
    bin_index = os.path.join(model_dir, "pytorch_model.bin.index.json")
    st_path = os.path.join(model_dir, "model.safetensors")
    st_index = os.path.join(model_dir, "model.safetensors.index.json")

    if os.path.isfile(bin_index):
        return ("bin_sharded", bin_index)
    if os.path.isfile(st_index):
        return ("st_sharded", st_index)
    if os.path.isfile(st_path):
        return ("st_single", st_path)
    if os.path.isfile(bin_path):
        return ("bin_single", bin_path)
    return (None, None)

def _read_index_json(index_path: str):
    import json
    with open(index_path, "r", encoding="utf-8") as f:
        return json.load(f)

def _list_checkpoint_keys(model_dir: str) -> set:
    """Return the set of tensor names stored in the checkpoint in ``model_dir``.

    - single .bin: the state dict is loaded on CPU with ``torch.load`` and its
      keys are returned
    - single safetensors: keys are read with ``safe_open`` from the file header,
      without loading tensor data
    - sharded (either format): keys come from the index file's ``weight_map``

    Returns an empty set when no checkpoint file is found.
    """
    ckpt_type, ckpt_ref = _find_checkpoint_files(model_dir)
    if ckpt_type is None:
        return set()

    if ckpt_type == "bin_single":
        # NOTE(review): torch.load unpickles; only use on trusted checkpoints.
        sd = torch.load(ckpt_ref, map_location="cpu")
        return set(sd.keys())

    if ckpt_type == "st_single":
        from safetensors import safe_open
        with safe_open(ckpt_ref, framework="pt", device="cpu") as f:
            return set(f.keys())

    index = _read_index_json(ckpt_ref)
    return set(index.get("weight_map", {}).keys())

def _load_tensor_shape_from_checkpoint(model_dir: str, key: str):
    """Return the shape of tensor ``key`` in the checkpoint, or None if absent.

    For sharded checkpoints only the shard named in the index's ``weight_map``
    is opened. safetensors files are queried with
    ``safe_open(...).get_slice(key).get_shape()``, which reads metadata only.
    .bin files have to be loaded fully with ``torch.load``.

    Returns:
        tuple of ints, or None if there is no checkpoint or ``key`` is missing.
    """
    ckpt_type, ckpt_ref = _find_checkpoint_files(model_dir)
    if ckpt_type is None:
        return None

    if ckpt_type in ("bin_sharded", "st_sharded"):
        weight_map = _read_index_json(ckpt_ref).get("weight_map", {})
        if key not in weight_map:
            return None
        file_path = os.path.join(model_dir, weight_map[key])
    else:
        file_path = ckpt_ref

    if ckpt_type.startswith("st_"):
        from safetensors import safe_open
        with safe_open(file_path, framework="pt", device="cpu") as f:
            if key not in f.keys():
                return None
            return tuple(f.get_slice(key).get_shape())

    # NOTE(review): torch.load unpickles; only use on trusted checkpoints.
    sd = torch.load(file_path, map_location="cpu")
    return tuple(sd[key].shape) if key in sd else None

def _coerce_state_dict_dtypes(sd: dict, model: nn.Module):
    model_sd = model.state_dict()
    for k, v in list(sd.items()):
        if isinstance(v, torch.Tensor) and k in model_sd and model_sd[k].dtype != v.dtype:
            sd[k] = v.to(dtype=model_sd[k].dtype)
    return sd

def _checkpoint_weight_files(model_dir: str, ckpt_type: str, ckpt_ref: str):
    """Return the weight-file paths to load for an already-located checkpoint."""
    if ckpt_type in ("bin_single", "st_single"):
        return [ckpt_ref]
    index = _read_index_json(ckpt_ref)
    shard_names = sorted(set(index.get("weight_map", {}).values()))
    return [os.path.join(model_dir, name) for name in shard_names]


def _load_weights_into_model(model: nn.Module, model_dir: str):
    """Load all weight files of the checkpoint in ``model_dir`` into ``model``.

    Each file is dtype-coerced to the model (``_coerce_state_dict_dtypes``) and
    loaded with ``load_state_dict(strict=False)``.

    Returns:
        (missing, unexpected):
          missing    - model state-dict keys not found in ANY loaded file
          unexpected - checkpoint keys the model has no slot for
                       (deduplicated, first-seen order)

    Raises:
        FileNotFoundError: if ``model_dir`` contains no recognised checkpoint.
    """
    ckpt_type, ckpt_ref = _find_checkpoint_files(model_dir)
    if ckpt_type is None:
        raise FileNotFoundError(f"Cannot find checkpoint in: {model_dir}")

    use_safetensors = ckpt_type.startswith("st_")
    if use_safetensors:
        from safetensors.torch import load_file

    loaded_keys = set()
    unexpected_all = []
    for path in _checkpoint_weight_files(model_dir, ckpt_type, ckpt_ref):
        # NOTE(review): torch.load unpickles; only use on trusted checkpoints.
        sd = load_file(path, device="cpu") if use_safetensors else torch.load(path, map_location="cpu")
        sd = _coerce_state_dict_dtypes(sd, model)
        result = model.load_state_dict(sd, strict=False)
        loaded_keys.update(sd.keys())
        for key in result.unexpected_keys:
            if key not in unexpected_all:
                unexpected_all.append(key)
        del sd  # release this shard before reading the next one

    # A per-shard strict=False load reports every key stored in *other* shards
    # as missing, so compute missing keys once against the union of all shards.
    missing_all = [key for key in model.state_dict() if key not in loaded_keys]
    return missing_all, unexpected_all

def load_model_compat(model_dir: str, num_labels: int, g_min: float = 0.0, *, strict: bool = False):
    """Build a GPT2ForSequenceClassification from ``model_dir`` and load its weights.

    Handles two kinds of local fine-tuned output directories (the prediction
    step usually points at a directory written by ``trainer.save_model``):
      - plain GPT-2 classification checkpoints
      - v6/v9 gated checkpoints, detected by the presence of both
        ``transformer.wte.base.weight`` and ``transformer.wte.gate_logits``.
        For these, ``transformer.wte`` is replaced by
        ``GatedPriorEmbeddingCompat`` before loading. The gate rank is taken
        from the stored ``gate_logits`` shape (1-D -> v6, otherwise v9).

    The model is built from the directory's config (with ``num_labels``
    overridden), and weights are loaded non-strictly. Afterwards, model keys
    absent from the checkpoint and checkpoint keys the model does not use are
    printed as warnings.

    Args:
        model_dir: local fine-tuned model directory.
        num_labels: number of classes of the classification head.
        g_min: lower bound of the gate weight; should match the training value.
        strict: if True, raise instead of only warning on any key mismatch.

    Returns:
        The model with checkpoint weights loaded.

    Raises:
        FileNotFoundError: if ``model_dir`` contains no checkpoint file.
        RuntimeError: if ``strict`` is True and the key sets differ.
    """
    config = AutoConfig.from_pretrained(model_dir)
    config.num_labels = num_labels
    model = GPT2ForSequenceClassification(config)

    keys = _list_checkpoint_keys(model_dir)
    gated = ("transformer.wte.base.weight" in keys) and ("transformer.wte.gate_logits" in keys)

    if gated:
        gate_shape = _load_tensor_shape_from_checkpoint(model_dir, "transformer.wte.gate_logits")
        gate_rank = 1 if (gate_shape is None or len(gate_shape) == 1) else 2

        base = model.transformer.wte
        model.transformer.wte = GatedPriorEmbeddingCompat(
            base=base,
            vocab_size=model.config.vocab_size,
            n_embd=model.config.n_embd,
            gate_rank=gate_rank,
            g_min=g_min,
        )
        print(f"[predict] Detected gated model. gate_rank={gate_rank} (v6=1, v9=2)")

    _load_weights_into_model(model, model_dir)

    # The load is non-strict, so anything absent from the checkpoint silently keeps
    # its freshly-initialised value (e.g. the gated prior_matrix buffer starts as
    # zeros). Compare against the checkpoint key list and report mismatches.
    model_keys = list(model.state_dict().keys())
    missing = [k for k in model_keys if k not in keys]
    unexpected = sorted(keys.difference(model_keys))
    if missing:
        print(f"[predict] WARNING: {len(missing)} model keys not found in checkpoint: {missing[:10]}")
    if unexpected:
        print(f"[predict] WARNING: {len(unexpected)} checkpoint keys unused by the model: {unexpected[:10]}")
    if strict and (missing or unexpected):
        raise RuntimeError(
            f"Checkpoint/model key mismatch in {model_dir}: "
            f"{len(missing)} missing, {len(unexpected)} unexpected"
        )
    return model


In [10]:
from argparse import Namespace
from configparser import ConfigParser

cfg = ConfigParser()
cfg.read("config.ini")  # NOTE(review): `cfg` is not read by any later cell shown here.

# Hand-built equivalent of the command-line arguments.
# The output directory is derived from the model directory name, so switching
# runs only requires changing MODEL_NAME and the two paths cannot drift apart.
#
# Fine-tuned runs for label_col="Sample_Site":
#   finetuned_v5_pretrain_ResMicroDB_90338_GATED_base_wte_sampleSite
#   finetuned_v5_pretrain_ResMicroDB_90338_GATED_v6_sampleSite
#   finetuned_v5_pretrain_ResMicroDB_90338_GATED_v9_5000_sampleSite
# Earlier runs for label_col="Is_Healthy":
#   finetuned_v5_pretrain_ResMicroDB_90338_GATED_base_wte
#   finetuned_v5_pretrain_ResMicroDB_90338_GATED_v6
#   finetuned_v4_pretrain_ResMicroDB_90338_GATED_v9
#   finetuned_v5_pretrain_ResMicroDB_90338_GATED_v9_5000
MODEL_NAME = "finetuned_v5_pretrain_ResMicroDB_90338_GATED_v6_sampleSite"

args = Namespace(
    input="../data/try2_withCC/ResMicroDB_90338.pkl",
    model=f"../models/{MODEL_NAME}",
    output=f"../outputs/predict_{MODEL_NAME}",
    evaluate=True,
    # Which Split_Group to predict on
    split_group="B",
    # Label column name (used for evaluation)
    label_col="Sample_Site",
    # g_min of the gated prior; must match the value used during training
    g_min=0.0,
)


args


Namespace(input='../data/try2_withCC/ResMicroDB_90338.pkl', model='../models/finetuned_v5_pretrain_ResMicroDB_90338_GATED_v6_sampleSite', output='../outputs/predict_finetuned_v5_pretrain_ResMicroDB_90338_GATED_v6_sampleSite', evaluate=True, split_group='B', label_col='Sample_Site', g_min=0.0)

In [11]:
# 1. Load the pickled corpus.
# NOTE(review): pickle can execute arbitrary code on load — only open trusted corpus files.
with open(args.input, "rb") as fh:
    corpus = pkl_load(fh)
tokenizer = corpus.tokenizer

print("样本数量（整个 corpus）:", len(corpus))
display(corpus.data.head())

meta = corpus.metadata

# Rows are selected positionally below: metadata row i, corpus.data row i and
# corpus.tokens row i are assumed to describe the same sample. Fail loudly if the
# row counts already disagree, instead of silently misaligning samples.
if not (len(meta) == len(corpus.data) == len(corpus.tokens)):
    raise ValueError(
        f"row count mismatch: metadata={len(meta)}, data={len(corpus.data)}, "
        f"tokens={len(corpus.tokens)}; positional selection would misalign samples"
    )

# GPT-2 tokenizers may have no pad_token_id; fall back to the EOS token.
if getattr(tokenizer, "pad_token_id", None) is None:
    tokenizer.pad_token = tokenizer.eos_token
pad_id = tokenizer.pad_token_id

# Keep only the requested Split_Group (all rows if the column is absent).
if "Split_Group" in meta.columns:
    group_mask = (meta["Split_Group"] == args.split_group)
else:
    group_mask = pd.Series(True, index=meta.index)

if args.evaluate:
    # Evaluation: only samples whose label column (args.label_col) is not NA.
    mask = group_mask & meta[args.label_col].notna()
    print("用于评估的样本数:", int(mask.sum()))
else:
    # Prediction only: every sample of the group, labelled or not.
    mask = group_mask
    print("用于预测的样本数:", int(mask.sum()))

# 2. Positional row indices of the selected samples.
idx = np.where(mask.to_numpy())[0]

# Matching sample IDs (used as the index of y_score.csv).
sample_ids = corpus.data.index[idx]
print("前几个样本 ID:", sample_ids[:5].tolist())

# 3. input_ids and attention_mask; corpus.tokens may be a torch.Tensor or a numpy array.
input_ids = corpus.tokens[idx]         # [N_subset, max_len]
attention_mask = (torch.as_tensor(input_ids) != pad_id).long()

# 4. Label encoder: needed even without evaluation, to fix the class order.
le = load_label_encoder(os.path.join(args.model, "label_encoder.pkl"))
label_names, num_labels = get_label_names_and_num(le)
print("类别数:", num_labels)
print("类别名称:", label_names)

# 5. Build labels_tensor with an explicit name -> id map (no le.transform().argmax()).
if args.evaluate:
    labels_series = meta.loc[mask, args.label_col]

    # Normalise to stripped strings (avoids mixed True/False vs "True"/"False" types).
    labels_norm = labels_series.astype(str).str.strip()

    # Class id = position in label_names (the encoder's order).
    label2id = {str(name): i for i, name in enumerate(label_names)}

    # Reject labels the model was not trained on instead of mis-encoding them.
    unknown = labels_norm[~labels_norm.isin(label2id.keys())]
    if len(unknown) > 0:
        raise ValueError(f"发现未知标签值: {unknown.unique()[:10]} ; 训练类别: {label_names}")

    y_ids = labels_norm.map(label2id).astype(int).to_numpy()
    labels_tensor = torch.tensor(y_ids, dtype=torch.long)

    print("[y_true after map] counts:", np.unique(y_ids, return_counts=True))
else:
    # Placeholder labels so the dataset still yields a "labels" field.
    labels_tensor = torch.zeros(len(idx), dtype=torch.long)

# 6. Build the SequenceClassificationDataset.
dataset = SequenceClassificationDataset(
    input_ids,
    attention_mask,
    labels_tensor,
)

print("Dataset 大小:", len(dataset))

# ====== Diagnostics: raw label distribution of the selected samples ======
labels_raw = meta.loc[mask, args.label_col]
print("[raw] dtype =", labels_raw.dtype)
print("[raw] value_counts:\n", labels_raw.value_counts(dropna=False))

# ====== Classes known to the encoder (label_names from step 4) ======
print("[encoder] label_names =", label_names)



样本数量（整个 corpus）: 90338


Taxon,g__Stenotrophomonas,g__Bacteriovorax,g__Idiomarina,g__Eubacterium,g__Methylobacillus,g__Larkinella,g__Fonticella,g__Klebsiella,g__Merdibacter,g__Fibrobacter,...,g__Chujaibacter,g__Papillibacter,g__Tannerellaceae,g__Sporichthya,g__Sphingosinicella,g__Salinivibrio,g__Aquaspirillum,g__Methylibium,g__Austwickia,g__Oceanotoga
CRR768228,-0.048917,-0.066191,-0.014152,-0.014765,-0.03292,-0.01434,-0.018529,-0.044503,0.311284,-0.016866,...,-0.015618,-0.021123,-0.014295,-0.017187,-0.048861,-0.043702,-0.021391,-0.032084,-0.023196,-0.020422
CRR768229,11.297346,-0.066191,-0.014152,-0.014765,-0.03292,-0.01434,-0.018529,-0.02461,-0.031506,-0.016866,...,-0.015618,-0.021123,-0.014295,-0.017187,-0.048861,-0.043702,-0.021391,-0.032084,-0.023196,-0.020422
CRR768230,0.216301,-0.066191,2.658264,-0.014765,-0.03292,-0.01434,-0.018529,1.618438,0.32014,-0.016866,...,-0.015618,-0.021123,-0.014295,-0.017187,-0.048861,-0.043702,-0.021391,-0.032084,-0.023196,-0.020422
CRR768231,-0.045746,-0.066191,-0.014152,-0.014765,-0.03292,-0.01434,-0.018529,1.241796,17.916046,-0.016866,...,-0.015618,-0.021123,-0.014295,-0.017187,-0.048861,-0.043702,-0.021391,-0.032084,-0.023196,-0.020422
CRR768232,0.067157,-0.066191,-0.014152,-0.014765,-0.03292,-0.01434,-0.018529,-0.043309,0.174857,-0.016866,...,-0.015618,-0.021123,-0.014295,-0.017187,-0.048861,-0.043702,-0.021391,-0.032084,-0.023196,-0.020422


用于评估的样本数: 13901
前几个样本 ID: ['DRR452457', 'DRR452458', 'DRR452459', 'DRR452460', 'DRR452461']
类别数: 10
类别名称: ['BALF', 'Bronchus', 'Lung Tissue', 'Nasal', 'Nasopharynx', 'Oropharynx', 'Pharynx', 'Sputum', 'Throat', 'Trachea']
[y_true after map] counts: (array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]), array([ 733,  206,  579,  475, 3774, 4580,  101, 3169,  169,  115]))
Dataset 大小: 13901
[raw] dtype = object
[raw] value_counts:
 Sample_Site
Oropharynx     4580
Nasopharynx    3774
Sputum         3169
BALF            733
Lung Tissue     579
Nasal           475
Bronchus        206
Throat          169
Trachea         115
Pharynx         101
Name: count, dtype: int64
[encoder] label_names = ['BALF', 'Bronchus', 'Lung Tissue', 'Nasal', 'Nasopharynx', 'Oropharynx', 'Pharynx', 'Sputum', 'Throat', 'Trachea']


In [12]:
# num_labels was obtained from the label encoder in the previous cell.
from transformers import TrainingArguments

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Do not load v6/v9 checkpoints with GPT2ForSequenceClassification.from_pretrained:
# their gated wte has extra tensors. The compat loader handles both plain GPT-2
# and v6/v9 gated checkpoints.
model = load_model_compat(
    model_dir=args.model,
    num_labels=num_labels,
    g_min=args.g_min,
)

model.to(device)
model.eval()

# Pass explicit arguments: without them Trainer falls back to
# output_dir="tmp_trainer" in the working directory and the default set of
# reporting integrations. Other settings keep TrainingArguments defaults.
predict_args = TrainingArguments(
    output_dir=args.output,
    report_to="none",
)
trainer = Trainer(model=model, args=predict_args)
model


[predict] Detected gated model. gate_rank=1 (v6=1, v9=2)


GPT2ForSequenceClassification(
  (transformer): GPT2Model(
    (wte): GatedPriorEmbeddingCompat(
      (base): Embedding(1121, 256)
    )
    (wpe): Embedding(512, 256)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-7): 8 x GPT2Block(
        (ln_1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
  )
  (score): Linear(in_features=256, out_features=10, bias=False)
)

In [13]:
# Run inference on the selected subset.
predictions = trainer.predict(dataset)

# Model outputs for each sample, shape [N_subset, num_labels]. These are the
# classification head's outputs as returned by Trainer; no softmax is applied here.
y_score = predictions.predictions

# Diagnostic: label_ids returned by Trainer should equal the labels we built.
if args.evaluate:
    trainer_labels = predictions.label_ids
    dataset_labels = labels_tensor.cpu().numpy()
    print("[check] same y_true?", np.array_equal(trainer_labels, dataset_labels))
    print("[check] trainer y_true counts:", np.unique(trainer_labels, return_counts=True))


os.makedirs(args.output, exist_ok=True)

# Rows: the selected sample IDs; columns: class names from the label encoder.
score_path = os.path.join(args.output, "y_score.csv")
score_frame = pd.DataFrame(y_score, index=sample_ids, columns=label_names)
score_frame.to_csv(score_path)

print("y_score 已保存到:", score_path)

if not args.evaluate:
    print("只做预测，没有 evaluation。记得根据输出得分设置你自己的阈值或判定规则。")
else:
    y_true = predictions.label_ids  # shape: [N_subset]
    eval_dir = os.path.join(args.output, "evaluation")
    eval_and_save(y_score, y_true, label_names, eval_dir)
    print("evaluation 结果已保存到:", eval_dir)


  "input_ids": torch.tensor(self.seq[idx]),
  "attention_mask": torch.tensor(self.mask[idx]),
  "labels": torch.tensor(self.labels[idx])


[check] same y_true? True
[check] trainer y_true counts: (array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]), array([ 733,  206,  579,  475, 3774, 4580,  101, 3169,  169,  115]))
y_score 已保存到: ../outputs/predict_finetuned_v5_pretrain_ResMicroDB_90338_GATED_v6_sampleSite/y_score.csv
Evaluating biome source: BALF
          TN     FP   FN   TP     Acc      Sn      Sp     TPR     FPR      Rc  \
t                                                                               
0.000      0  13168    0  733  0.0527  1.0000  0.0000  1.0000  1.0000  1.0000   
0.001   6785   6383   39  694  0.5380  0.9468  0.5153  0.9468  0.4847  0.9468   
0.002   7645   5523   64  669  0.5981  0.9127  0.5806  0.9127  0.4194  0.9127   
0.003   8099   5069   83  650  0.6294  0.8868  0.6151  0.8868  0.3849  0.8868   
0.004   8407   4761   94  639  0.6507  0.8718  0.6384  0.8718  0.3616  0.8718   
...      ...    ...  ...  ...     ...     ...     ...     ...     ...     ...   
0.997  13158     10  721   12  0.9474  0.0164  0.999