# In [None]:
# -*- coding: utf-8 -*-
"""Untitled0.ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/1J0nRtRnRbNiPlDiVxOQcGqhBl0xtBTUX
"""

import torch
import json
import os
import matplotlib.pyplot as plt
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    # BitsAndBytesConfig, # <-- 已移除
    TrainingArguments,
    Trainer,
    DataCollatorForLanguageModeling,
)
from datasets import load_dataset, concatenate_datasets
from transformers.trainer_callback import TrainerCallback
# 【已修复】导入 PeftModel
from peft import LoraConfig, get_peft_model, PeftModel
from tqdm import tqdm
import warnings
from torch.utils.data import DataLoader
from torch.optim import AdamW

# Suppress warnings
warnings.filterwarnings("ignore")

# --- 1. Configuration ---
MODEL_NAME = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
HOTPOT_DATASET_NAME = "hotpot_qa"
HOTPOT_DATASET_CONFIG = "distractor"
MATH_DATASET_NAME = "qwedsacf/competition_math"
RESULTS_DIR = "./drive/MyDrive/"

# Checkpoint path 1 - control group (joint training on both tasks).
JOINT_ADAPTER_PATH = os.path.join(RESULTS_DIR, "joint_adapter_llama_fp32")

# Checkpoint path 2 - experiment group, Phase 1 (MATH-only adapter).
TASK_B_ADAPTER_PATH = os.path.join(RESULTS_DIR, "math_adapter_llama_fp32")


# --- VRAM-Saving Config ---
MAX_SEQ_LENGTH = 2048
# Batch settings. NOTE: a previous comment here claimed the batch had been
# lowered to 8 * 4 = 32 for FP32 training; the values actually in effect are
# 64 * 1. Every example is padded to MAX_SEQ_LENGTH, so activation memory
# grows linearly with PER_DEVICE_BS -- if you hit OOM, lower PER_DEVICE_BS
# and raise GRAD_ACC_STEPS to keep the effective batch size.
PER_DEVICE_BS = 64
GRAD_ACC_STEPS = 1
# Effective optimizer batch size per device (currently 64 * 1 = 64).
EFFECTIVE_BATCH_SIZE = PER_DEVICE_BS * GRAD_ACC_STEPS

# --- Experiment Config ---
N_TRAIN_EXAMPLES = 4000
N_VAL_EXAMPLES = 400
JOINT_EPOCHS = 2
TASK_A_EPOCHS = 2  # Phase 2 (HotpotQA)
TASK_B_EPOCHS = 1  # Phase 1 (MATH)

# --- 2. Utility Functions (Data Formatting - Llama Chat Style) ---
def format_hotpot_qa(example):
    """Render one HotpotQA record as a single Llama-chat-style training string.

    Each paragraph's sentence list is concatenated without separators, and the
    paragraphs are then joined with single spaces to form the context. The
    prompt ends with ``[/INST] `` and the target is ``Answer: <answer></s>``.

    Args:
        example: record with ``context["sentences"]`` (list of sentence
            lists), ``question`` and ``answer``.

    Returns:
        str: the formatted prompt + answer text.
    """
    paragraphs = [
        "".join(sentence_list) for sentence_list in example["context"]["sentences"]
    ]
    template = (
        "<s>[INST] You are a helpful assistant. Use the following context to "
        "answer the question. Context: {context}\n\nQuestion: {question} [/INST] "
        "Answer: {answer}</s>"
    )
    return template.format(
        context=" ".join(paragraphs),
        question=example["question"],
        answer=example["answer"],
    )

def format_math(example):
    """Render one MATH record as a single Llama-chat-style training string.

    The prompt is a fixed instruction followed by ``Problem: <problem>`` and
    closed with ``[/INST] ``; the target is ``Solution: <solution></s>``.

    Args:
        example: record with ``problem`` and ``solution`` fields.

    Returns:
        str: the formatted prompt + solution text.
    """
    template = (
        "<s>[INST] You are a math expert. Solve the following math problem. "
        "Show your work.\nProblem: {problem} [/INST] "
        "Solution: {solution}</s>"
    )
    return template.format(problem=example["problem"], solution=example["solution"])

def filter_by_length(example, tokenizer, formatter):
    """Decide whether a formatted example fits into ``MAX_SEQ_LENGTH`` tokens.

    The example is rendered with ``formatter`` and tokenized without
    truncation or padding, so the true token count is measured.

    Returns:
        bool: True to keep the example, False to drop it.
    """
    encoded = tokenizer(
        formatter(example),
        max_length=MAX_SEQ_LENGTH + 1,
        truncation=False,
        padding=False,
    )
    token_count = len(encoded["input_ids"])
    too_long = token_count > MAX_SEQ_LENGTH
    return not too_long

# 【BUG 修复】这是修复了 1.7 Loss 问题 和 ValueError 的 Preprocess 函数
def preprocess(example, tokenizer, formatter, max_length=None):
    """Format, tokenize and pad one example, supervising only the answer.

    The text produced by ``formatter`` must contain the ``[/INST]`` delimiter.
    Every token up to and including that delimiter (the prompt) and every
    padding position (``attention_mask == 0``) gets label ``-100`` so the loss
    ignores it. The tokens after the delimiter (answer/solution, including the
    trailing ``</s>``) keep their ids as labels.

    The prompt length is measured by tokenizing the text up to the first
    ``[/INST]`` on its own. Searching the token ids for the last ``"]"`` token
    is not safe: answers and MATH solutions often contain ``]`` (LaTeX
    display-math closers, intervals), which would move the boundary into the
    target and mask part of it.

    Args:
        example: one raw dataset record, passed to ``formatter``.
        tokenizer: callable tokenizer returning ``input_ids`` and
            ``attention_mask``.
        formatter: function mapping ``example`` to the full training string.
        max_length: pad/truncate length; defaults to ``MAX_SEQ_LENGTH``.

    Returns:
        The tokenizer output with an added ``labels`` list, or ``{}`` if the
        formatted text contains no ``[/INST]`` delimiter.
    """
    if max_length is None:
        max_length = MAX_SEQ_LENGTH

    text = formatter(example)
    marker = "[/INST]"
    marker_pos = text.find(marker)
    if marker_pos == -1:
        return {}

    tokenized = tokenizer(
        text,
        max_length=max_length,
        truncation=True,
        padding="max_length",
    )
    # Token count of the prompt alone. This assumes the tokenizer splits the
    # shared prefix identically in both calls (the prompt ends right before a
    # space, so no token straddles the boundary) -- verify for new tokenizers.
    prompt_ids = tokenizer(
        text[: marker_pos + len(marker)],
        max_length=max_length,
        truncation=True,
        padding=False,
    )["input_ids"]
    prompt_len = len(prompt_ids)

    # Mask the prompt and the padding. Padding must be masked via the attention
    # mask, not the token id, because the pad id may equal a real token id.
    tokenized["labels"] = [
        -100 if (position < prompt_len or attended == 0) else token_id
        for position, (token_id, attended) in enumerate(
            zip(tokenized["input_ids"], tokenized["attention_mask"])
        )
    ]
    return tokenized

# --- 3. Model Loading (【重构】) ---

def get_model_and_tokenizer_base():
    """Load the TinyLlama base model (float16 weights) and its tokenizer.

    The model is placed with ``device_map="auto"`` and has gradient
    checkpointing switched on before any LoRA wrapping. The tokenizer reuses
    its EOS token as the padding token and pads on the right.

    NOTE(review): the base weights stay float16. "FP32" elsewhere in this
    script appears to mean training without ``fp16=True`` in
    TrainingArguments. Whether the LoRA adapter weights end up float32 depends
    on PEFT's adapter-dtype handling -- confirm for the installed version.

    Returns:
        tuple: ``(model, tokenizer)``.
    """
    hub_options = {
        "torch_dtype": torch.float16,
        "device_map": "auto",
        "trust_remote_code": True,
    }
    lm = AutoModelForCausalLM.from_pretrained(MODEL_NAME, **hub_options)

    tok = AutoTokenizer.from_pretrained(MODEL_NAME)
    tok.pad_token = tok.eos_token
    tok.padding_side = "right"

    # Checkpointing is enabled on the bare base model.
    lm.gradient_checkpointing_enable()
    return lm, tok

def get_lora_config():
    """Return the LoRA adapter configuration shared by all experiments.

    Rank 8, alpha 16, dropout 0.05, applied to the ``q_proj``, ``k_proj``,
    ``v_proj`` and ``o_proj`` modules. Biases are not trained and the task
    type is causal language modelling.
    """
    projection_modules = ["q_proj", "k_proj", "v_proj", "o_proj"]
    settings = {
        "r": 8,
        "lora_alpha": 16,
        "target_modules": projection_modules,
        "lora_dropout": 0.05,
        "bias": "none",
        "task_type": "CAUSAL_LM",
    }
    return LoraConfig(**settings)
def manual_evaluate(model, dataloader, device):
    """Compute the mean validation loss of ``model`` over ``dataloader``.

    Each batch is filtered to ``input_ids``, ``attention_mask`` and ``labels``,
    moved to ``device`` and run without gradients. Each batch needs ``labels``
    so that ``outputs.loss`` is populated.

    The average is weighted by the number of examples in each batch, so a
    smaller final batch does not count as much as a full one. The model's
    previous train/eval mode is restored afterwards, even if a forward pass
    raises.

    Args:
        model: module whose forward returns an object with a ``loss`` tensor.
        dataloader: iterable of dict batches of tensors.
        device: target device for the batch tensors.

    Returns:
        float: example-weighted mean loss, or ``float("nan")`` if the
        dataloader yields no batches.
    """
    was_training = model.training
    model.eval()
    total_loss = 0.0
    total_examples = 0
    try:
        with torch.no_grad():
            for batch in tqdm(dataloader, desc="Evaluating", leave=False):
                # Only forward the keys the model's loss computation needs.
                batch = {
                    k: v.to(device)
                    for k, v in batch.items()
                    if k in ("input_ids", "attention_mask", "labels")
                }
                loss = model(**batch).loss
                batch_size = batch["labels"].size(0)
                total_loss += loss.item() * batch_size
                total_examples += batch_size
    finally:
        # Restore the caller's mode (train mode when called mid-training).
        model.train(was_training)

    if total_examples == 0:
        return float("nan")
    return total_loss / total_examples

# --- 4. Main Experiment Logic ---
def collate_padded_features(features):
    """Stack already-padded examples into a batch, keeping their labels.

    ``preprocess`` pads every example to the same length and writes a
    ``labels`` list in which the prompt is already ``-100``.
    ``DataCollatorForLanguageModeling(mlm=False)`` would discard those labels
    and rebuild them from ``input_ids``, so the prompt mask would never reach
    the loss. This collator keeps the provided labels and also sets every
    padding position (``attention_mask == 0``) to ``-100``.

    Args:
        features: sequence of mappings with equal-length ``input_ids``,
            ``attention_mask`` and ``labels`` lists; other keys are ignored.

    Returns:
        dict of ``torch.long`` tensors of shape (batch, seq_len) for those
        three keys.
    """
    batch = {
        key: torch.tensor([feature[key] for feature in features], dtype=torch.long)
        for key in ("input_ids", "attention_mask", "labels")
    }
    # The tokenizer pads with its EOS id, so mask padding via the attention
    # mask; the genuine trailing EOS (attention_mask == 1) stays supervised.
    batch["labels"] = batch["labels"].masked_fill(batch["attention_mask"] == 0, -100)
    return batch


def _tokenize_split(dataset, tokenizer, formatter):
    """Drop over-long examples, then tokenize and loss-mask the remainder.

    Applies ``filter_by_length`` and then ``preprocess`` (both with
    ``formatter``) to every example of ``dataset``.
    """
    return (
        dataset.filter(
            lambda x: filter_by_length(x, tokenizer, formatter),
            batched=False,
        )
        .map(
            lambda x: preprocess(x, tokenizer, formatter),
            batched=False,
        )
        .filter(lambda example: len(example) > 0)
    )


def _save_history_and_plot(history):
    """Save the Phase-2 loss history to JSON and plot it as a forgetting curve.

    Both files are written to ``RESULTS_DIR``. A failed JSON write is reported
    and skipped so the plot is still produced. The figure is shown inline only
    when ``google.colab`` is importable.
    """
    print("\n--- Saving History Data and Generating Plot ---")

    history_filename = os.path.join(RESULTS_DIR, "forgetting_history_MATH_to_HotpotQA_fp32.json")
    try:
        with open(history_filename, 'w') as f:
            json.dump(history, f, indent=4)
        print(f"History data saved to {history_filename}")
    except Exception as e:
        # Best effort at the end of a long run: report and keep going so the
        # plot below is still written.
        print(f"Error saving history to JSON: {e}")

    plt.figure(figsize=(12, 6))
    plt.plot(history["steps"], history["hotpot_loss"], 'o-', label="Task A (HotpotQA) Loss", color="blue")
    plt.plot(history["steps"], history["math_loss"], 'o-', label="Task B (MATH) Loss", color="red")

    plt.title(f"Catastrophic Forgetting: MATH -> HotpotQA (Model: {MODEL_NAME} FP32 LoRA)")
    plt.xlabel(f"Training Steps on Task A (HotpotQA) (Total Epochs: {TASK_A_EPOCHS})")
    plt.ylabel("Validation Loss")
    plt.legend()
    plt.grid(True)
    plt.yscale('log')
    plt.tight_layout()

    plot_filename = os.path.join(RESULTS_DIR, "sequential_forgetting_curve_MATH_to_HotpotQA_fp32.png")
    plt.savefig(plot_filename)
    print(f"Plot saved to {plot_filename}")

    try:
        from google.colab import files  # noqa: F401 -- import only detects Colab
        plt.show()
    except ImportError:
        print("Not in Colab, plot saved to file.")


def main():
    """Run the MATH -> HotpotQA sequential-training (forgetting) experiment.

    1. Load the base model/tokenizer and build tokenized, loss-masked
       train/validation splits for HotpotQA (Task A) and MATH (Task B).
    2. Phase 1: train a LoRA adapter on MATH, or reload it from
       ``TASK_B_ADAPTER_PATH`` if a saved adapter exists. Evaluate it on both
       validation sets.
    3. Phase 2: keep training the same adapter on HotpotQA while a callback
       re-evaluates both validation sets at every log event.
    4. Save the recorded losses to JSON and plot them.

    The joint-training control group is kept below but commented out.
    """
    os.makedirs(RESULTS_DIR, exist_ok=True)

    print("--- Loading Base Model & Tokenizer ---")
    base_model, tokenizer = get_model_and_tokenizer_base()

    # Keep the labels (prompt mask) written by ``preprocess``.
    # DataCollatorForLanguageModeling(mlm=False) would rebuild labels from
    # input_ids and silently train and evaluate on the prompt as well.
    data_collator = collate_padded_features

    # --- Load and Process Datasets ---
    print("\n--- Loading and Preprocessing Datasets (This may take a while) ---")

    # Task A: HotpotQA
    raw_hotpot = load_dataset(HOTPOT_DATASET_NAME, HOTPOT_DATASET_CONFIG)
    hotpot_train = raw_hotpot["train"].shuffle(seed=42).select(range(N_TRAIN_EXAMPLES))
    hotpot_val = raw_hotpot["validation"].shuffle(seed=42).select(range(N_VAL_EXAMPLES))

    print("Tokenizing and filtering HotpotQA...")
    hotpot_train_tokenized = _tokenize_split(hotpot_train, tokenizer, format_hotpot_qa)
    hotpot_val_tokenized = _tokenize_split(hotpot_val, tokenizer, format_hotpot_qa)

    print(f"HotpotQA: {len(hotpot_train_tokenized)} train, {len(hotpot_val_tokenized)} val (after filtering)")

    # Task B: MATH -- draw train + val from the "train" split, then split it.
    raw_math = load_dataset(MATH_DATASET_NAME)
    total_math_samples_needed = N_TRAIN_EXAMPLES + N_VAL_EXAMPLES
    math_subset = raw_math["train"].shuffle(seed=42).select(range(total_math_samples_needed))
    val_size_fraction = N_VAL_EXAMPLES / total_math_samples_needed
    math_splits = math_subset.train_test_split(test_size=val_size_fraction, seed=42)
    math_train = math_splits["train"]
    math_val = math_splits["test"]

    print("Tokenizing and filtering MATH...")
    math_train_tokenized = _tokenize_split(math_train, tokenizer, format_math)
    math_val_tokenized = _tokenize_split(math_val, tokenizer, format_math)

    print(f"MATH: {len(math_train_tokenized)} train, {len(math_val_tokenized)} val (after filtering)")

    # # --- Experiment 1: Joint Training (Control Group) ---
    # (Disabled. Uncomment to train/evaluate the joint baseline.)
    # print(f"\n--- Starting Experiment 1: Joint Training ---")

    # if os.path.exists(os.path.join(JOINT_ADAPTER_PATH, "adapter_model.safetensors")):
    #     print(f"--- Found existing Joint adapter. Loading from {JOINT_ADAPTER_PATH} ---")
    #     joint_model = PeftModel.from_pretrained(base_model, JOINT_ADAPTER_PATH)
    #     print("Adapter loaded successfully.")

    # else:
    #     print(f"--- No Joint adapter found. Starting Joint Training ---")
    #     lora_config = get_lora_config()
    #     joint_model = get_peft_model(base_model, lora_config)
    #     joint_model.print_trainable_parameters()

    #     joint_train_dataset = concatenate_datasets([hotpot_train_tokenized, math_train_tokenized]).shuffle(seed=42)

    #     joint_training_args = TrainingArguments(
    #         output_dir=os.path.join(RESULTS_DIR, "joint_training_temp"),
    #         per_device_train_batch_size=PER_DEVICE_BS,
    #         gradient_accumulation_steps=GRAD_ACC_STEPS,
    #         num_train_epochs=JOINT_EPOCHS,
    #         learning_rate=2e-4,
    #         # fp16=True,  # <-- removed on purpose (bug fix)
    #         logging_steps=50,
    #         save_strategy="no",
    #         report_to="none",
    #         gradient_checkpointing=True,
    #     )

    #     joint_trainer = Trainer(
    #         model=joint_model,
    #         args=joint_training_args,
    #         train_dataset=joint_train_dataset,
    #         data_collator=data_collator,
    #     )

    #     joint_trainer.train()

    #     print(f"--- Joint training complete. Saving adapter to {JOINT_ADAPTER_PATH} ---")
    #     joint_model.save_pretrained(JOINT_ADAPTER_PATH)
    #     print("Adapter saved.")

    #     del joint_trainer
    #     torch.cuda.empty_cache()

    # # --- Evaluate the "Joint" model (whether trained or loaded) ---
    # print("\n--- Evaluating Joint Model ---")

    # eval_args_joint = TrainingArguments(
    #     output_dir=os.path.join(RESULTS_DIR, "eval_temp_joint"),
    #     per_device_eval_batch_size=PER_DEVICE_BS,
    #     # fp16=True,  # <-- removed on purpose (bug fix)
    #     report_to="none",
    #     gradient_checkpointing=True,
    # )

    # eval_trainer_joint = Trainer(
    #     model=joint_model,
    #     args=eval_args_joint,
    #     data_collator=data_collator,
    # )

    # eval_hotpot_joint = eval_trainer_joint.evaluate(eval_dataset=hotpot_val_tokenized)
    # print(f"  > Joint Model - HotpotQA Val Loss: {eval_hotpot_joint['eval_loss']:.4f}")

    # eval_math_joint = eval_trainer_joint.evaluate(eval_dataset=math_val_tokenized)
    # print(f"  > Joint Model - MATH Val Loss: {eval_math_joint['eval_loss']:.4f}")

    # del joint_model, eval_trainer_joint, eval_args_joint
    # torch.cuda.empty_cache()

    # --- Experiment 2: Sequential Training (CF) [MATH -> HotpotQA] ---
    print("\n--- Starting Experiment 2: Sequential Training (CF) [MATH -> HotpotQA] ---")

    # --- Phase 1: Train on MATH (or load from checkpoint) ---
    if os.path.exists(os.path.join(TASK_B_ADAPTER_PATH, "adapter_model.safetensors")):
        print(f"--- Found existing Task B (MATH) adapter. Loading from {TASK_B_ADAPTER_PATH} ---")
        # is_trainable=True is required: PeftModel.from_pretrained loads the
        # adapter frozen (inference-only) by default, so Phase 2 would
        # otherwise update no parameters at all.
        seq_model = PeftModel.from_pretrained(base_model, TASK_B_ADAPTER_PATH, is_trainable=True)
        print("Adapter loaded successfully.")
    else:
        print("--- No adapter found. Starting Phase 1: Training on Task B (MATH) ---")
        lora_config = get_lora_config()
        seq_model = get_peft_model(base_model, lora_config)
        seq_model.print_trainable_parameters()

        seq_args_b = TrainingArguments(
            output_dir=os.path.join(RESULTS_DIR, "seq_training_B_temp"),
            per_device_train_batch_size=PER_DEVICE_BS,
            gradient_accumulation_steps=GRAD_ACC_STEPS,
            num_train_epochs=TASK_B_EPOCHS,
            learning_rate=2e-4,
            # fp16=True was removed on purpose (bug fix): no AMP here.
            logging_steps=10,
            save_strategy="no",
            report_to="none",
            gradient_checkpointing=True,
        )

        seq_trainer_b = Trainer(
            model=seq_model,
            args=seq_args_b,
            train_dataset=math_train_tokenized,  # Phase 1 trains on MATH
            eval_dataset=math_val_tokenized,
            data_collator=data_collator,
        )

        seq_trainer_b.train()

        print(f"--- Phase 1 (MATH) training complete. Saving adapter to {TASK_B_ADAPTER_PATH} ---")
        seq_model.save_pretrained(TASK_B_ADAPTER_PATH)
        print("Adapter saved.")

        del seq_trainer_b
        torch.cuda.empty_cache()

    # --- Evaluate the "Task B Expert" model (whether trained or loaded) ---
    print("\n--- Evaluating Model after Phase 1 (Task B Expert) ---")
    eval_args = TrainingArguments(
        output_dir=os.path.join(RESULTS_DIR, "eval_temp"),
        per_device_eval_batch_size=PER_DEVICE_BS,
        report_to="none",
        gradient_checkpointing=True,
    )

    eval_trainer = Trainer(
        model=seq_model,
        args=eval_args,
        data_collator=data_collator,
    )

    eval_hotpot_phase1 = eval_trainer.evaluate(eval_dataset=hotpot_val_tokenized)
    print(f"  > Task B Expert - HotpotQA Val Loss: {eval_hotpot_phase1['eval_loss']:.4f}")
    eval_math_phase1 = eval_trainer.evaluate(eval_dataset=math_val_tokenized)
    print(f"  > Task B Expert - MATH Val Loss: {eval_math_phase1['eval_loss']:.4f}")
    del eval_trainer, eval_args
    torch.cuda.empty_cache()

    # --- Phase 2: Train on HotpotQA (Forgetting MATH happens here) ---
    print("\n  --- Phase 2: Training on Task A (HotpotQA) ---")
    history = {"steps": [], "hotpot_loss": [], "math_loss": []}

    class ForgettingTrackerCallback(TrainerCallback):
        """Evaluate both tasks on every Trainer log event and record the losses.

        ``history_log`` is mutated in place: each evaluation appends to its
        "steps", "hotpot_loss" and "math_loss" lists, starting with the
        Phase-1 metrics recorded as step 0.
        """

        def __init__(self, hotpot_val, math_val, history_log, start_metrics):
            super().__init__()
            self.hotpot_eval_dataset = hotpot_val
            self.math_eval_dataset = math_val
            self.history = history_log
            self.trainer = None
            # Re-entrancy guard: trainer.evaluate() logs its own metrics,
            # which calls on_log again while we are still inside on_log.
            self.is_evaluating = False
            # Record the starting state (step 0).
            self.history["steps"].append(0)
            self.history["hotpot_loss"].append(start_metrics['hotpot_loss'])
            self.history["math_loss"].append(start_metrics['math_loss'])
            print("Initializing ForgettingTrackerCallback with starting metrics.")

        def set_trainer(self, trainer):
            """Inject the Trainer reference once the Trainer is constructed."""
            self.trainer = trainer
            print("Trainer reference set in callback.")

        def on_log(self, args, state, control, **kwargs):
            """Evaluate both validation sets whenever the Trainer logs."""
            if self.is_evaluating:
                return
            if not self.trainer:
                print("WARNING: Trainer reference not set in callback, skipping eval.")
                return
            # Skip log events for a step that was already evaluated (presumably
            # the end-of-training summary log), avoiding duplicate points and
            # two redundant full evaluations.
            if self.history["steps"] and self.history["steps"][-1] == state.global_step:
                return

            self.is_evaluating = True
            try:
                print(f"\n--- Custom Eval at Step {state.global_step} ---")
                print("Evaluating on Task A (HotpotQA)...")
                hotpot_metrics = self.trainer.evaluate(eval_dataset=self.hotpot_eval_dataset)
                hotpot_loss = hotpot_metrics['eval_loss']
                print(f"  > Step {state.global_step} - HotpotQA Val Loss: {hotpot_loss:.4f} (LEARNING?)")
                print("Evaluating on Task B (MATH)...")
                math_metrics = self.trainer.evaluate(eval_dataset=self.math_eval_dataset)
                math_loss = math_metrics['eval_loss']
                print(f"  > Step {state.global_step} - MATH Val Loss: {math_loss:.4f} (FORGETTING?)")
                self.history["steps"].append(state.global_step)
                self.history["hotpot_loss"].append(hotpot_loss)
                self.history["math_loss"].append(math_loss)
            finally:
                # Always release the guard and return to training mode, even
                # if an evaluation raised.
                self.is_evaluating = False
                self.trainer.model.train()

    seq_args_a = TrainingArguments(
        output_dir=os.path.join(RESULTS_DIR, "seq_training_A"),
        per_device_train_batch_size=PER_DEVICE_BS,
        gradient_accumulation_steps=GRAD_ACC_STEPS,
        num_train_epochs=TASK_A_EPOCHS,
        learning_rate=7e-5,
        logging_steps=10,
        save_strategy="no",
        report_to=[],
        gradient_checkpointing=True,
    )
    seq_model.enable_input_require_grads()

    tracker_callback = ForgettingTrackerCallback(
        hotpot_val=hotpot_val_tokenized,
        math_val=math_val_tokenized,
        history_log=history,
        start_metrics={
            'hotpot_loss': eval_hotpot_phase1['eval_loss'],
            'math_loss': eval_math_phase1['eval_loss'],
        },
    )
    seq_trainer_a = Trainer(
        model=seq_model,
        args=seq_args_a,
        train_dataset=hotpot_train_tokenized,
        eval_dataset=hotpot_val_tokenized,
        data_collator=data_collator,
        callbacks=[tracker_callback],
    )
    # The callback needs this reference to call self.trainer.evaluate().
    tracker_callback.set_trainer(seq_trainer_a)
    seq_trainer_a.train()

    # --- 5. Plot Results ---
    _save_history_and_plot(history)

if __name__ == "__main__":
    if not torch.cuda.is_available():
        # Do not fall through to main() on a CPU-only runtime.
        print("ERROR: This experiment requires a GPU. Check Colab runtime type.")
    else:
        vram_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
        print(f"INFO: Running on GPU. VRAM: {vram_gb:.2f} GB")
        if vram_gb < 11:
            print("WARNING: VRAM is less than 11GB. You may hit OOM errors. Try lowering MAX_SEQ_LENGTH.")
        main()