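## 0. Import Dependencies

**[Ablation Study: Value + Condition]**
This is the advanced variant of the ablation experiment.
In this configuration, we:
1. **Keep Value Embeddings** (abundance information)
2. **Keep Condition Embeddings** (environmental metadata)
3. **Disable Cross-Attention** (no prior knowledge base)
4. Measure how much improvement the environmental context alone provides when Cross-Attention is absent.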
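
The four points above amount to a small set of on/off switches. As a minimal sketch, the variant can be written down as a frozen dataclass. `AblationVariant` and its field names are hypothetical and are not part of MiCoGPT; they only make the design of this run explicit.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class AblationVariant:
    """Hypothetical record of which components one ablation run keeps."""
    name: str
    use_value: bool            # keep Value Embeddings (abundance)
    use_condition: bool        # keep Condition Embeddings (environmental metadata)
    use_cross_attention: bool  # attend to a prior knowledge base

    def enabled(self) -> list:
        """Return the names of the enabled components, in a fixed order."""
        flags = {
            "value": self.use_value,
            "condition": self.use_condition,
            "cross_attention": self.use_cross_attention,
        }
        return [key for key, on in flags.items() if on]


# The variant trained in this notebook: Value + Condition, no Cross-Attention.
value_condition = AblationVariant(
    name="value_condition",
    use_value=True,
    use_condition=True,
    use_cross_attention=False,
)
print(value_condition.name, value_condition.enabled())
```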

In [None]:
import os
import torch
import warnings
from pickle import load
from argparse import Namespace
from configparser import ConfigParser
from importlib.resources import files

import pandas as pd
from torch.utils.data import Subset

from transformers import (
    Trainer,
    TrainingArguments,
)
from transformers.trainer_callback import EarlyStoppingCallback

from MiCoGPT.utils_vCross.model_vCross import MiCoGPTConfig, MiCoGPTForCausalLM
from MiCoGPT.utils_vCross.collator_vCross import MiCoGPTDataCollator
from MiCoGPT.utils.tools import split_train_val_by_project_stratified

warnings.filterwarnings("ignore")

## 1. Basic Parameter Settings

In [None]:
args = Namespace(
    input="../data/vCross/ResMicroDB_90338_vCross.pkl",
    # [Ablation] Output path: Value + Condition
    output="../models/pretrain_vCross_value_condition",
    log="../logs/pretrain_vCross_value_condition",
    prior_npz=None
)
VAL_RATIO = 0.10

## 2. Load the Corpus

In [None]:
print(f"Loading corpus from {args.input} ...")
# Use a context manager so the file handle is closed after unpickling.
with open(args.input, "rb") as f:
    all_corpus = load(f)

if all_corpus.metadata is not None and "Split_Group" in all_corpus.metadata.columns:
    print("Subsetting corpus by Split_Group == 'A'...")
    corpus = all_corpus.subset_by_metadata(lambda df: df["Split_Group"] == "A")
else:
    print("Using full corpus (no Split_Group found or metadata missing).")
    corpus = all_corpus
    
tokenizer = all_corpus.tokenizer

## 3. Extract Environmental Metadata Information

Here we extract `condition_vocab_sizes` as usual, because this configuration uses Condition Embeddings.

In [None]:
if isinstance(corpus, Subset):
    base_corpus = corpus.dataset
else:
    base_corpus = corpus

condition_cols = list(base_corpus.meta_encoders.keys())
condition_vocab_sizes = [len(le.classes_) + 1 for le in base_corpus.meta_encoders.values()]

print("Condition Columns:", condition_cols)
print("Condition Vocab Sizes:", condition_vocab_sizes)
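
To make the `+ 1` above concrete, here is a minimal sketch with a stand-in encoder. The encoders in `meta_encoders` expose `classes_`, which matches the scikit-learn `LabelEncoder` interface. `TinyLabelEncoder`, the column names, and the sample values below are hypothetical. The extra slot is presumably reserved for a missing or unknown value. The notebook does not say so, which makes that an assumption.

```python
class TinyLabelEncoder:
    """Hypothetical stand-in exposing `classes_` like a fitted label encoder."""

    def fit(self, values):
        # Sorted unique labels, one integer id per label.
        self.classes_ = sorted(set(values))
        return self


# Hypothetical metadata columns (not taken from the real corpus).
demo_metadata = {
    "Host": ["human", "mouse", "human", "pig"],
    "Body_Site": ["gut", "skin", "gut", "gut"],
}

demo_encoders = {col: TinyLabelEncoder().fit(vals) for col, vals in demo_metadata.items()}

demo_cols = list(demo_encoders.keys())
# One embedding row per class plus one extra slot (assumed: missing/unknown).
demo_vocab_sizes = [len(le.classes_) + 1 for le in demo_encoders.values()]

print(demo_cols)         # ['Host', 'Body_Site']
print(demo_vocab_sizes)  # [4, 3]
```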

## 4. Build the Model (Value + Condition)

Configuration:
1. `num_bins=52` (keep Value)
2. `condition_vocab_sizes=[...]` (keep Condition)
3. `prior_matrix_path=None` & `add_cross_attention=False` (disable Cross-Attention)

In [None]:
cfg = ConfigParser()
cfg.read(files("MiCoGPT")/"resources/config.ini")

gpt2_config_dict = {
    "vocab_size":   tokenizer.vocab_size,
    "n_positions":  cfg.getint("GPT2", "n_positions"),
    "n_embd":       cfg.getint("GPT2", "n_embd"),
    "n_layer":      cfg.getint("GPT2", "n_layer"),
    "n_head":       cfg.getint("GPT2", "n_head"),
    "bos_token_id": tokenizer.bos_token_id,
    "eos_token_id": tokenizer.eos_token_id,
    "pad_token_id": tokenizer.pad_token_id,
}

config = MiCoGPTConfig(
    num_bins=52,                        # keep Value
    condition_vocab_sizes=condition_vocab_sizes, # keep Condition
    prior_matrix_path=None,             # no prior matrix
    add_cross_attention=False,          # disable Cross-Attention
    **gpt2_config_dict
)

model = MiCoGPTForCausalLM(config)
print("Model Config:", config)
print("Model Architecture:", model)
print("\n[Check] Cross Attention is:", "ENABLED" if config.add_cross_attention else "DISABLED")
print("[Check] Condition Embeddings:", len(model.condition_embeddings))

## 5. Initialize the Data Collator (Standard Mode)

The standard `MiCoGPTDataCollator` is sufficient here; it returns `input_ids`, `value_ids`, and `condition_ids` as usual.

In [None]:
collator = MiCoGPTDataCollator(
    tokenizer=tokenizer,
    max_length=config.n_positions
)
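
As a rough illustration of what such a collator has to do, the sketch below pads `input_ids` and `value_ids` to a common length and passes `condition_ids` through per sample. It is not the `MiCoGPTDataCollator` implementation: `pad_batch`, the pad ids, the `attention_mask` field, and the sample values are hypothetical, and a real collator would return tensors rather than lists.

```python
def pad_batch(samples, pad_token_id=0, pad_value_id=0, max_length=8):
    """Hypothetical collate function: truncate, then right-pad to the longest sample."""
    longest = min(max(len(s["input_ids"]) for s in samples), max_length)
    batch = {"input_ids": [], "value_ids": [], "attention_mask": [], "condition_ids": []}
    for s in samples:
        ids = s["input_ids"][:longest]
        vals = s["value_ids"][:longest]
        n_pad = longest - len(ids)
        batch["input_ids"].append(ids + [pad_token_id] * n_pad)
        batch["value_ids"].append(vals + [pad_value_id] * n_pad)
        # The mask is an addition of this sketch: 1 for real tokens, 0 for padding.
        batch["attention_mask"].append([1] * len(ids) + [0] * n_pad)
        # One id per metadata column, so no padding is needed.
        batch["condition_ids"].append(list(s["condition_ids"]))
    return batch


samples = [
    {"input_ids": [2, 11, 12, 3], "value_ids": [0, 5, 9, 0], "condition_ids": [1, 2]},
    {"input_ids": [2, 13, 3], "value_ids": [0, 7, 0], "condition_ids": [3, 1]},
]
batch = pad_batch(samples)
print(batch["input_ids"])       # [[2, 11, 12, 3], [2, 13, 3, 0]]
print(batch["attention_mask"])  # [[1, 1, 1, 1], [1, 1, 1, 0]]
```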

In [None]:
if isinstance(corpus, Subset):
    metadata = corpus.dataset.metadata
else:
    metadata = corpus.metadata

if metadata is not None and "Project_ID" in metadata.columns:
    print("Using stratified split by Project_ID...")
    train_dataset, val_dataset = split_train_val_by_project_stratified(
        corpus,
        val_ratio=VAL_RATIO,
        project_col="Project_ID"
    )
else:
    train_dataset, val_dataset = torch.utils.data.random_split(
        corpus, 
        [len(corpus)-int(len(corpus)*VAL_RATIO), int(len(corpus)*VAL_RATIO)], 
        generator=torch.Generator().manual_seed(42)
    )
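
One plausible reading of a project-stratified split is that every project contributes roughly `val_ratio` of its samples to validation, so that large projects do not end up entirely on one side. The sketch below implements that reading with the standard library only. It is not the code of `split_train_val_by_project_stratified`, and `stratified_indices` plus the toy project labels are hypothetical.

```python
import random
from collections import defaultdict


def stratified_indices(project_ids, val_ratio=0.10, seed=42):
    """Hypothetical helper: each project gives ~val_ratio of its indices to validation."""
    by_project = defaultdict(list)
    for idx, project in enumerate(project_ids):
        by_project[project].append(idx)

    rng = random.Random(seed)
    train_idx, val_idx = [], []
    for project in sorted(by_project):
        indices = by_project[project]
        rng.shuffle(indices)
        n_val = int(len(indices) * val_ratio)  # rounds down; tiny projects give 0
        val_idx.extend(indices[:n_val])
        train_idx.extend(indices[n_val:])
    return sorted(train_idx), sorted(val_idx)


# Toy labels: 20 samples from project "P1", 10 from "P2".
project_ids = ["P1"] * 20 + ["P2"] * 10
train_idx, val_idx = stratified_indices(project_ids)
print(len(train_idx), len(val_idx))  # 27 3
```

Index lists like these could then be wrapped with `torch.utils.data.Subset(corpus, train_idx)` to obtain dataset objects.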

In [None]:
training_args = TrainingArguments(
    output_dir=f"{args.output}/checkpoints",
    overwrite_output_dir=True,
    num_train_epochs=50,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=8,
    learning_rate=1e-3,
    weight_decay=0.01,
    logging_dir=args.log,
    logging_steps=100,
    save_steps=500,
    eval_steps=500,
    evaluation_strategy="steps",
    save_strategy="steps",
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    dataloader_num_workers=0,
    dataloader_pin_memory=True,
    fp16=torch.cuda.is_available(),
    no_cuda=not torch.cuda.is_available(),
    report_to=["tensorboard"],
)

In [None]:
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=collator,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=10)],
)
trainer.train()
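
With `eval_steps=500` and `early_stopping_patience=10`, training stops once 10 consecutive evaluations (5,000 steps) fail to improve on the best `eval_loss`. The sketch below is a simplified model of that patience counter, not the `EarlyStoppingCallback` source. `stop_step` and the loss values are hypothetical, and the improvement threshold is assumed to be 0.

```python
def stop_step(eval_losses, patience, eval_steps):
    """Return the step at which training would stop, or None if patience is never exhausted."""
    best = None
    bad_evals = 0
    for i, loss in enumerate(eval_losses, start=1):
        if best is None or loss < best:
            best = loss      # new best eval_loss: reset the counter
            bad_evals = 0
        else:
            bad_evals += 1   # no improvement at this evaluation
        if bad_evals >= patience:
            return i * eval_steps
    return None


losses = [1.00, 0.80, 0.90, 0.85, 0.81]
print(stop_step(losses, patience=3, eval_steps=500))   # 2500
print(stop_step(losses, patience=10, eval_steps=500))  # None
```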

In [None]:
import json
import matplotlib.pyplot as plt

trainer.save_model(args.output)
tokenizer.save_pretrained(args.output)

log_history = trainer.state.log_history
with open(f"{args.output}/training_logs.json", "w") as f:
    json.dump(log_history, f, indent=2)

train_steps = [x["step"] for x in log_history if "loss" in x]
train_loss = [x["loss"] for x in log_history if "loss" in x]
eval_steps = [x["step"] for x in log_history if "eval_loss" in x]
eval_loss = [x["eval_loss"] for x in log_history if "eval_loss" in x]

plt.figure(figsize=(10, 6))
if train_steps: plt.plot(train_steps, train_loss, label="Training Loss", alpha=0.7)
if eval_steps: plt.plot(eval_steps, eval_loss, label="Validation Loss", marker="o", linestyle="--")
plt.title("Ablation: Value + Condition (No Cross-Attn)")
plt.legend()
plt.savefig(f"{args.output}/loss_curve.png")
plt.show()