In [1]:
import os
import math
import torch
import numpy as np
import pandas as pd
from pickle import load
from torch.utils.data import DataLoader, SequentialSampler, Subset
import torch.nn.functional as F
from transformers import GPT2Config, GPT2LMHeadModel, DataCollatorForLanguageModeling
from importlib.resources import files
from tqdm.auto import tqdm
from argparse import Namespace

from MiCoGPT.utils.pretrain import attach_gated_prior_to_gpt2


def load_gated_model(model_dir: str, tokenizer, npz_path, g_min=0.0, init_w=0.1, device=None):
    """Rebuild a GPT-2 model with the gated embedding prior and load its trained weights.

    The gated structure must be attached *before* the state dict is loaded;
    otherwise the checkpoint keys do not match the module tree and the
    ``strict=True`` load fails.

    Args:
        model_dir: Directory holding the GPT-2 config plus either
            ``model.safetensors`` or ``pytorch_model.bin``.
        tokenizer: Forwarded to ``attach_gated_prior_to_gpt2``.
        npz_path: Forwarded to ``attach_gated_prior_to_gpt2``.
        g_min: Forwarded to ``attach_gated_prior_to_gpt2``.
        init_w: Forwarded to ``attach_gated_prior_to_gpt2``.
        device: Target device; defaults to CUDA when available, else CPU.

    Returns:
        ``(model, device)`` with the model moved to ``device`` and in eval mode.

    Raises:
        FileNotFoundError: if neither weight file exists in ``model_dir``.
        RuntimeError: from ``load_state_dict`` when keys or shapes do not match.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Locate the weights first so a bad directory fails fast, before the model
    # is built and the prior (which reads npz_path) is attached.
    st_path = os.path.join(model_dir, "model.safetensors")
    bin_path = os.path.join(model_dir, "pytorch_model.bin")
    if not (os.path.exists(st_path) or os.path.exists(bin_path)):
        raise FileNotFoundError(
            f"No model weights found in {model_dir!r}: expected "
            f"'model.safetensors' or 'pytorch_model.bin'."
        )

    # 1) Plain GPT-2 skeleton from the saved config.
    config = GPT2Config.from_pretrained(model_dir)
    model = GPT2LMHeadModel(config)

    # 2) Swap in the gated structure first (critical: otherwise state_dict keys won't match).
    attach_gated_prior_to_gpt2(
        model=model,
        tokenizer=tokenizer,
        npz_path=npz_path,
        g_min=g_min,
        init_w=init_w,
    )

    # 3) Load weights, preferring .safetensors and falling back to .bin.
    if os.path.exists(st_path):
        from safetensors.torch import load_file
        state = load_file(st_path, device="cpu")
    else:
        # NOTE: torch.load unpickles the file -- only use with trusted checkpoints.
        state = torch.load(bin_path, map_location="cpu")

    model.load_state_dict(state, strict=True)
    model.to(device)
    model.eval()
    return model, device


@torch.no_grad()
def eval_loss_and_ppl_by_project(
    model,
    tokenizer,
    test_subset: Subset,
    project_col="Project_ID",
    batch_size=32,
    num_workers=0,
):
    """Evaluate causal-LM loss and perplexity on ``test_subset``, overall and per project.

    Loss is the token-weighted mean cross-entropy over every next-token
    position whose (shifted) label is not ``-100``; perplexity is ``exp(loss)``.

    Args:
        model: Causal LM whose forward returns an object with ``.logits``
            (expected ``[B, T, V]``).
        tokenizer: Tokenizer used by ``DataCollatorForLanguageModeling``
            (``mlm=False``) to pad batches and build ``labels``.
        test_subset: ``Subset`` whose ``.dataset`` exposes a ``metadata``
            DataFrame that is positionally row-aligned with the dataset.
        project_col: Metadata column used to group samples.
        batch_size: Evaluation batch size.
        num_workers: ``DataLoader`` worker processes.

    Returns:
        ``(df_overall, df_by_project)``. Both have columns
        ``Project_ID, n_samples, n_tokens, loss, ppl``. ``df_overall`` is a
        single ``"ALL"`` row whose ``n_samples`` is ``len(test_subset)``.
        ``df_by_project`` has one row per project, sorted by id, and ignores
        samples that have no scored tokens.

    Raises:
        RuntimeError: if the batches and the metadata rows fall out of step.
    """
    # Recover the base corpus + positional indices so each sample maps to its metadata row.
    base = test_subset.dataset
    base_indices = np.array(test_subset.indices)
    meta = base.metadata.iloc[base_indices]
    proj_ids_all = meta[project_col].astype(str).to_numpy()

    collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

    dl = DataLoader(
        test_subset,
        sampler=SequentialSampler(test_subset),  # keeps batch order aligned with test_subset.indices
        batch_size=batch_size,
        num_workers=num_workers,
        collate_fn=collator,
        pin_memory=torch.cuda.is_available(),
    )

    total_loss_sum = 0.0
    total_tok_sum = 0

    # Per-project accumulators.
    loss_sum = {}
    tok_sum = {}
    n_samples = {}

    pos = 0
    device = next(model.parameters()).device

    for batch in tqdm(dl, desc="Eval on test", total=len(dl)):
        # Collator output: input_ids / attention_mask / labels (it adds labels and sets pad to -100).
        batch = {k: v.to(device) for k, v in batch.items()}
        labels = batch["labels"]
        bsz = labels.size(0)

        # Project ids for this batch; relies on SequentialSampler preserving order.
        proj_ids = proj_ids_all[pos:pos + bsz]
        if len(proj_ids) != bsz:
            raise RuntimeError(
                f"Batch/metadata misalignment at position {pos}: batch has {bsz} "
                f"samples but only {len(proj_ids)} metadata rows remain."
            )
        pos += bsz

        out = model(input_ids=batch["input_ids"], attention_mask=batch.get("attention_mask", None))
        logits = out.logits  # [B, T, V]

        # Causal LM shift: position t predicts token t + 1.
        shift_logits = logits[:, :-1, :].contiguous()
        shift_labels = labels[:, 1:].contiguous()

        vocab_size = shift_logits.size(-1)
        loss_tok = F.cross_entropy(
            shift_logits.view(-1, vocab_size),
            shift_labels.view(-1),
            ignore_index=-100,
            reduction="none",
        ).view_as(shift_labels)  # view_as also handles T == 1 (zero shifted positions)

        mask = shift_labels != -100
        # One device->host transfer per batch instead of an .item() sync per sample.
        loss_per_sample = (loss_tok * mask).sum(dim=1).tolist()  # summed loss per sample
        tok_per_sample = mask.sum(dim=1).tolist()                # scored tokens per sample

        # Overall totals.
        total_loss_sum += sum(loss_per_sample)
        total_tok_sum += sum(tok_per_sample)

        # Per-project totals (samples with no scored tokens are skipped).
        for pid, ls, nt in zip(proj_ids, loss_per_sample, tok_per_sample):
            if nt == 0:
                continue
            loss_sum[pid] = loss_sum.get(pid, 0.0) + ls
            tok_sum[pid] = tok_sum.get(pid, 0) + nt
            n_samples[pid] = n_samples.get(pid, 0) + 1

    if pos != len(proj_ids_all):
        raise RuntimeError(
            f"Evaluated {pos} samples but metadata has {len(proj_ids_all)} rows."
        )

    overall_loss = total_loss_sum / max(total_tok_sum, 1)
    overall_ppl = math.exp(overall_loss)

    # Explicit columns so an empty result still yields a CSV header.
    columns = ["Project_ID", "n_samples", "n_tokens", "loss", "ppl"]

    # Project-level table.
    rows = []
    for pid in sorted(loss_sum):
        ploss = loss_sum[pid] / max(tok_sum[pid], 1)
        rows.append(
            {
                "Project_ID": pid,
                "n_samples": n_samples.get(pid, 0),
                "n_tokens": tok_sum.get(pid, 0),
                "loss": ploss,
                "ppl": math.exp(ploss),
            }
        )
    df_by_project = pd.DataFrame(rows, columns=columns)

    # Overall summary as a single row.
    df_overall = pd.DataFrame(
        [
            {
                "Project_ID": "ALL",
                "n_samples": len(test_subset),
                "n_tokens": total_tok_sum,
                "loss": overall_loss,
                "ppl": overall_ppl,
            }
        ],
        columns=columns,
    )

    return df_overall, df_by_project

def load_base_model(model_dir: str, device=None, strict: bool = False):
    """Load a checkpoint as a plain ``GPT2LMHeadModel`` for evaluation.

    Args:
        model_dir: Directory written by ``save_pretrained`` / ``trainer.save_model``.
        device: Target device; defaults to CUDA when available, else CPU.
        strict: If True, raise when the checkpoint keys do not match the plain
            GPT-2 architecture. If False (default), emit a ``UserWarning``
            instead and continue.

    Returns:
        ``(model, device)`` with the model moved to ``device`` and in eval mode.

    Raises:
        RuntimeError: only when ``strict`` is True and keys are missing or unexpected.
    """
    import warnings

    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model, loading_info = GPT2LMHeadModel.from_pretrained(model_dir, output_loading_info=True)

    # Mismatched keys mean some weights were randomly initialised (missing) or
    # discarded (unexpected), e.g. a gated checkpoint loaded as plain GPT-2.
    missing = list(loading_info.get("missing_keys", []))
    unexpected = list(loading_info.get("unexpected_keys", []))
    if missing or unexpected:
        msg = (
            f"Checkpoint {model_dir!r} does not match GPT2LMHeadModel: "
            f"missing (newly initialised) keys={missing}, "
            f"unexpected (ignored) keys={unexpected}. "
            "Evaluation may not reflect the trained model."
        )
        if strict:
            raise RuntimeError(msg)
        warnings.warn(msg, stacklevel=2)

    model.to(device)
    model.eval()
    return model, device

In [7]:
args = Namespace(
    input="../data/try2_withCC/ResMicroDB_90338.pkl",
    models="../models/pretrain_ResMicroDB_90338_GATED_high_init_high_scale_random_group",
    test_dir="../outputs/test_eval/pretrain_ResMicroDB_90338_GATED_high_init_high_scale_random_group_base/",
)


# Load the pickled corpus (the same object used by the training script).
# NOTE: pickle.load can execute arbitrary code -- only load trusted files.
with open(args.input, "rb") as corpus_file:
    all_corpus = load(corpus_file)
tokenizer = all_corpus.tokenizer

# Test split: Split_Group == "B".
test_set = all_corpus.subset_by_metadata(lambda df: df["Split_Group"] == "B")

# Trained checkpoint directory (as written by trainer.save_model).
model_dir = args.models
npz_path = files("MiCoGPT") / "resources" / "genus_embeddings_256.npz"

# Gated-prior hyper-parameters (only used when USE_GATED is True).
G_MIN = 0.00
INIT_W = 0.50

# False: load the checkpoint as a plain GPT2LMHeadModel (the "_base" run in args.test_dir).
# True:  rebuild the gated-prior architecture before loading the weights
#        (also point args.test_dir at a non-"_base" directory).
# NOTE(review): args.models is a GATED checkpoint. The plain-GPT-2 path reported that
# 'transformer.wte.gate_logits', 'transformer.wte.prior_matrix' and
# 'transformer.wte.base.weight' were not used, so the plain model's token embedding is
# not loaded from the trained embedding -- confirm this "_base" evaluation is intended.
USE_GATED = False

if USE_GATED:
    model, device = load_gated_model(
        model_dir=model_dir,
        tokenizer=tokenizer,
        npz_path=npz_path,
        g_min=G_MIN,
        init_w=INIT_W,
    )
else:
    model, device = load_base_model(model_dir)

# Evaluate, then write the overall and per-project results to CSV.
df_all, df_proj = eval_loss_and_ppl_by_project(
    model=model,
    tokenizer=tokenizer,
    test_subset=test_set,
    project_col="Project_ID",
    batch_size=32,
)


os.makedirs(args.test_dir, exist_ok=True)
overall_path = os.path.join(args.test_dir, "test_overall_loss_ppl.csv")
project_path = os.path.join(args.test_dir, "test_project_loss_ppl.csv")

df_all.to_csv(overall_path, index=False)
df_proj.to_csv(project_path, index=False)

print(df_all)
print("Saved:", overall_path, project_path)



Some weights of the model checkpoint at ../models/pretrain_ResMicroDB_90338_GATED_high_init_high_scale_random_group were not used when initializing GPT2LMHeadModel: ['transformer.wte.gate_logits', 'transformer.wte.prior_matrix', 'transformer.wte.base.weight']
- This IS expected if you are initializing GPT2LMHeadModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing GPT2LMHeadModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Eval on test:   0%|          | 0/435 [00:00<?, ?it/s]

  return {'input_ids': torch.tensor(tokens),


  Project_ID  n_samples  n_tokens      loss        ppl
0        ALL      13901    926639  4.096864  60.151336
Saved: ../outputs/test_eval/pretrain_ResMicroDB_90338_GATED_high_init_high_scale_random_group_base/test_overall_loss_ppl.csv ../outputs/test_eval/pretrain_ResMicroDB_90338_GATED_high_init_high_scale_random_group_base/test_project_loss_ppl.csv
