# In [None]:
# -*- coding: utf-8 -*-
"""Untitled0.ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/1J0nRtRnRbNiPlDiVxOQcGqhBl0xtBTUX
"""

import torch
import json
import os
import matplotlib.pyplot as plt
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    # BitsAndBytesConfig, # <-- has been removed
    TrainingArguments,
    Trainer,
    DataCollatorForLanguageModeling,
)
from datasets import load_dataset, concatenate_datasets
from transformers.trainer_callback import TrainerCallback
# 【fixed】load PeftModel
from peft import LoraConfig, get_peft_model, PeftModel
from tqdm import tqdm
import warnings
from torch.utils.data import DataLoader
from torch.optim import AdamW

# Silence every warning emitted through Python's ``warnings`` module.
# NOTE(review): this also hides deprecation/numerical warnings from the
# training libraries -- consider narrowing the filter while debugging.
warnings.filterwarnings("ignore")

# --- 1. Configuration ---
# Identifiers of the base model and of the two task datasets.
MODEL_NAME = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
HOTPOT_DATASET_NAME = "hotpot_qa"
HOTPOT_DATASET_CONFIG = "distractor"
MATH_DATASET_NAME = "qwedsacf/competition_math"
# Root directory for adapters, Trainer output folders, history JSON and plots.
# presumably a mounted Google Drive folder in Colab -- TODO confirm
RESULTS_DIR = "./drive/MyDrive/"


# Checkpoint 1 (control group): adapter directory for joint training.
# NOTE(review): not referenced by any other code visible in this file.
JOINT_ADAPTER_PATH = os.path.join(RESULTS_DIR, "joint_adapter_llama_fp32")

# Checkpoint 2 (experiment group): where main() saves the adapter after the
# Phase 1 (MATH) training run.
TASK_A_ADAPTER_PATH = os.path.join(RESULTS_DIR, "math_adapter_llama_fp32")


# --- VRAM-Saving Config ---
# Token budget per example: filter_by_length() drops longer examples and
# preprocess() pads shorter ones up to this length.
MAX_SEQ_LENGTH = 2048
# Per-device batch size used for training and for the Phase 1 evaluation.
# NOTE(review): 64 rows padded to MAX_SEQ_LENGTH tokens is a large batch;
# confirm it fits the target GPU, otherwise lower it and raise GRAD_ACC_STEPS.
PER_DEVICE_BS = 64
GRAD_ACC_STEPS = 1  # gradient accumulation steps passed to TrainingArguments

# --- Experiment Config ---
# Number of raw examples sampled per task (before length filtering).
N_TRAIN_EXAMPLES = 4000
N_VAL_EXAMPLES = 400
# NOTE(review): JOINT_EPOCHS is not referenced by any code visible in this file.
JOINT_EPOCHS = 2
TASK_A_EPOCHS = 2 # For Math
TASK_B_EPOCHS = 2 # For HotpotQA
# Each value is passed to main() as ``ratio``: the share of MATH examples in
# the Phase 1 training mix and of HotpotQA examples in the Phase 2 mix.
# Add other ratios (e.g. 0.7, 0.3) to run more experiments.
MIX_RATIOS = [0.9,0.8, 0.5]

# --- 2. Utility Functions (Data Formatting - Llama Chat Style) ---
def format_hotpot_qa(example):
    """Render one HotpotQA record as a single Llama-chat training string.

    The sentences of each context paragraph are concatenated without a
    separator, the paragraphs are joined with single spaces, and the result
    is placed with the question inside an ``[INST]`` block that is followed
    by the reference answer.
    """
    paragraphs = ("".join(sentences) for sentences in example["context"]["sentences"])
    context = " ".join(paragraphs)
    question, answer = example["question"], example["answer"]

    prompt = (
        "<s>[INST] You are a helpful assistant. Use the following context to "
        f"answer the question. Context: {context}\n\nQuestion: {question} [/INST] "
    )
    return prompt + f"Answer: {answer}</s>"

def format_math(example):
    """Render one MATH record as a single Llama-chat training string.

    The problem statement goes inside the ``[INST]`` block and the reference
    solution follows it, terminated by ``</s>``.
    """
    problem, solution = example["problem"], example["solution"]

    instruction = (
        "<s>[INST] You are a math expert. Solve the following math problem. "
        f"Show your work.\nProblem: {problem} [/INST] "
    )
    return f"{instruction}Solution: {solution}</s>"

def filter_by_length(example, tokenizer, formatter, max_length=None):
    """Tell whether the formatted example fits into the token budget.

    Args:
        example: Raw dataset record, passed unchanged to ``formatter``.
        tokenizer: Callable that tokenizes a string and returns a mapping
            with an ``"input_ids"`` sequence.
        formatter: Callable turning ``example`` into the prompt text.
        max_length: Maximum number of tokens allowed. Defaults to the
            module-level ``MAX_SEQ_LENGTH`` when omitted.

    Returns:
        True (keep) when the untruncated token count is at most
        ``max_length``, False (discard) otherwise.
    """
    if max_length is None:
        max_length = MAX_SEQ_LENGTH

    text = formatter(example)
    # Truncation is disabled so the real length is measured. The limit handed
    # to the tokenizer is one above the budget -- presumably so that an
    # over-long example is still rejected should a tokenizer truncate anyway.
    encoded = tokenizer(text, max_length=max_length + 1, truncation=False, padding=False)
    return len(encoded["input_ids"]) <= max_length

def preprocess(example, tokenizer, formatter, max_length=None):
    """Tokenize one example and build labels that cover only the response.

    The text produced by ``formatter`` is tokenized and padded to
    ``max_length``. Every position belonging to the prompt (everything up to
    and including the ``[/INST]`` marker) and every padding position gets the
    label ``-100``; the remaining positions keep their token id.

    The prompt length is measured by tokenizing the prompt prefix on its own.
    Searching the token ids for the last ``]`` instead would move the boundary
    into the response whenever the answer itself contains a ``]`` token.
    Padding is detected through the attention mask rather than the token id,
    because the pad token may be the same id as the end-of-sequence token.

    Args:
        example: Raw dataset record, passed unchanged to ``formatter``.
        tokenizer: Callable tokenizer returning a mutable mapping with
            ``"input_ids"`` and ``"attention_mask"``.
        formatter: Callable turning ``example`` into the training text.
        max_length: Sequence length to truncate/pad to. Defaults to the
            module-level ``MAX_SEQ_LENGTH`` when omitted.

    Returns:
        The tokenizer output with an added ``"labels"`` list, or an empty
        dict when the text contains no ``[/INST]`` marker.
    """
    if max_length is None:
        max_length = MAX_SEQ_LENGTH

    response_marker = "[/INST]"
    text = formatter(example)

    marker_start = text.rfind(response_marker)
    if marker_start == -1:
        # No prompt/response boundary to mask at.
        return {}
    prompt_text = text[: marker_start + len(response_marker)]

    tokenized = tokenizer(
        text,
        max_length=max_length,
        truncation=True,
        padding="max_length",  # fixed-length rows so batches can be stacked
        return_attention_mask=True,
    )
    input_ids = tokenized["input_ids"]
    attention_mask = tokenized["attention_mask"]

    prompt_ids = tokenizer(
        prompt_text,
        max_length=max_length,
        truncation=True,
        padding=False,
    )["input_ids"]

    # With left padding the prompt starts after the pad run, so offset by it.
    left_pad = next(
        (index for index, is_real in enumerate(attention_mask) if is_real),
        len(attention_mask),
    )
    response_start = left_pad + len(prompt_ids)

    tokenized["labels"] = [
        token_id if (is_real and position >= response_start) else -100
        for position, (token_id, is_real) in enumerate(zip(input_ids, attention_mask))
    ]
    return tokenized

# --- 3. Model Loading (【refactored】) ---

def get_model_and_tokenizer_base():
    """Load the base causal LM named by ``MODEL_NAME`` and its tokenizer.

    The model weights are loaded as ``torch.float16`` with
    ``device_map="auto"`` and gradient checkpointing is switched on. The
    tokenizer reuses its end-of-sequence token as the padding token and pads
    on the right.

    Returns:
        tuple: ``(model, tokenizer)``.
    """
    load_kwargs = {
        "torch_dtype": torch.float16,
        "device_map": "auto",
        "trust_remote_code": True,
    }
    model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, **load_kwargs)
    # Gradient checkpointing is enabled on the base model itself.
    model.gradient_checkpointing_enable()

    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    tokenizer.padding_side = "right"
    tokenizer.pad_token = tokenizer.eos_token

    return model, tokenizer

def get_lora_config():
    """Build the LoRA configuration applied to the base model.

    Rank-8 adapters (alpha 16, dropout 0.05, no bias terms) are attached to
    the four attention projection modules for a causal-LM task.
    """
    attention_projections = ["q_proj", "k_proj", "v_proj", "o_proj"]
    settings = {
        "r": 8,
        "lora_alpha": 16,
        "target_modules": attention_projections,
        "lora_dropout": 0.05,
        "bias": "none",
        "task_type": "CAUSAL_LM",
    }
    return LoraConfig(**settings)
def manual_evaluate(model, dataloader, device):
    """Return the mean per-batch loss of ``model`` over ``dataloader``.

    The model is put into evaluation mode with gradients disabled while the
    batches are processed and is always switched back to training mode
    afterwards, even when a batch raises.

    Args:
        model: Module whose forward pass returns an object with a ``loss``
            attribute when called with ``input_ids``, ``attention_mask`` and
            ``labels``.
        dataloader: Iterable of dict batches; keys other than the three
            model inputs are ignored.
        device: Device the batch tensors are moved to before the forward pass.

    Returns:
        float: Sum of the batch losses divided by the number of batches.

    Raises:
        ValueError: If ``dataloader`` yields no batches.
    """
    model_inputs = ("input_ids", "attention_mask", "labels")
    total_loss = 0.0
    total_steps = 0

    model.eval()
    try:
        with torch.no_grad():
            for batch in tqdm(dataloader, desc="Evaluating", leave=False):
                # Move only the tensors the model consumes to its device.
                batch = {k: v.to(device) for k, v in batch.items() if k in model_inputs}
                total_loss += model(**batch).loss.item()
                total_steps += 1
    finally:
        # Callers rely on getting the model back in training mode.
        model.train()

    if total_steps == 0:
        raise ValueError("manual_evaluate() received an empty dataloader")
    return total_loss / total_steps

# --- 4. Main Experiment Logic ---
def _collate_with_labels(features):
    """Batch pre-padded rows while keeping the labels written by ``preprocess``.

    ``DataCollatorForLanguageModeling(mlm=False)`` rebuilds ``labels`` from
    ``input_ids``, which discards the prompt mask and, because the pad token
    equals the EOS token here, also masks the real end-of-sequence token.
    This collator stacks the rows unchanged and only forces padded positions
    (``attention_mask == 0``) to ``-100`` so they never contribute to the loss.
    """
    # Local import: this name is not part of the file-level transformers import.
    from transformers import default_data_collator

    batch = default_data_collator(features)
    batch["labels"] = batch["labels"].masked_fill(batch["attention_mask"] == 0, -100)
    return batch


def _tokenize_split(split, tokenizer, formatter):
    """Drop over-long rows from ``split``, then tokenize and label the rest."""
    fitting = split.filter(
        lambda x: filter_by_length(x, tokenizer, formatter),
        batched=False,
    )
    encoded = fitting.map(
        lambda x: preprocess(x, tokenizer, formatter),
        batched=False,
    )
    # Kept from the original pipeline: discards rows that come back empty.
    return encoded.filter(lambda example: len(example) > 0)


def _take_head(dataset, count):
    """Return the first ``count`` rows of ``dataset`` (all rows if shorter)."""
    return dataset.select(range(min(len(dataset), count)))


def _make_forgetting_tracker(hotpot_val, math_val, history_log, start_metrics):
    """Create the callback that evaluates both tasks at every logging event.

    The class is defined inside this function so that ``TrainerCallback`` is
    only looked up when an experiment actually runs.
    """

    class ForgettingTrackerCallback(TrainerCallback):
        """Record HotpotQA and MATH validation loss during Phase 2."""

        def __init__(self, hotpot_val, math_val, history_log, start_metrics):
            super().__init__()
            self.hotpot_eval_dataset = hotpot_val
            self.math_eval_dataset = math_val
            self.history = history_log
            self.trainer = None
            # Re-entrancy lock: trainer.evaluate() logs its metrics, which
            # fires on_log again from inside on_log.
            self.is_evaluating = False
            # Record the state before Phase 2 starts (step 0).
            self.history["steps"].append(0)
            self.history["hotpot_loss"].append(start_metrics['hotpot_loss'])
            self.history["math_loss"].append(start_metrics['math_loss'])
            print("Initializing ForgettingTrackerCallback with starting metrics.")

        def set_trainer(self, trainer):
            """Inject the Trainer reference once the Trainer has been built."""
            self.trainer = trainer
            print("Trainer reference set in callback.")

        def on_log(self, args, state, control, **kwargs):
            """Evaluate both validation sets whenever the Trainer logs."""
            if self.is_evaluating:
                return
            if not self.trainer:
                print("WARNING: Trainer reference not set in callback, skipping eval.")
                return
            # The end-of-training summary log can arrive at a step that was
            # already evaluated; do not record the same step twice.
            if self.history["steps"][-1] == state.global_step:
                return

            self.is_evaluating = True
            try:
                print(f"\n--- Custom Eval at Step {state.global_step} ---")
                print("Evaluating on Task B (HotpotQA)...")
                hotpot_metrics = self.trainer.evaluate(eval_dataset=self.hotpot_eval_dataset)
                hotpot_loss = hotpot_metrics['eval_loss']
                print(f"  > Step {state.global_step} - HotpotQA Val Loss: {hotpot_loss:.4f} (LEARNING?)")
                print("Evaluating on Task A (MATH)...")
                math_metrics = self.trainer.evaluate(eval_dataset=self.math_eval_dataset)
                math_loss = math_metrics['eval_loss']
                print(f"  > Step {state.global_step} - MATH Val Loss: {math_loss:.4f} (FORGETTING?)")
                self.history["steps"].append(state.global_step)
                self.history["hotpot_loss"].append(hotpot_loss)
                self.history["math_loss"].append(math_loss)
            finally:
                # Always release the lock and resume training mode, even if an
                # evaluation raised.
                self.is_evaluating = False
                self.trainer.model.train()

    return ForgettingTrackerCallback(hotpot_val, math_val, history_log, start_metrics)


def _save_history_and_plot(history, ratio_pct, phase2_hotpot, phase2_math):
    """Write the loss history to JSON and draw the forgetting-curve plot.

    ``phase2_hotpot`` and ``phase2_math`` are the row counts of the Phase 2
    training mix; they are shown in the plot title.
    """
    history_filename = os.path.join(RESULTS_DIR, f"forgetting_curve_MATH_mix_{ratio_pct}.json")
    try:
        with open(history_filename, 'w', encoding="utf-8") as f:
            json.dump(history, f, indent=4)
        print(f"History data saved to {history_filename}")
    except (OSError, TypeError, ValueError) as e:
        # Best effort: a failed dump must not prevent the plot from being drawn.
        print(f"Error saving history to JSON: {e}")

    fig = plt.figure(figsize=(12, 6))
    plt.plot(history["steps"], history["hotpot_loss"], 'o-', label="Task B (HotpotQA) Loss", color="blue")
    plt.plot(history["steps"], history["math_loss"], 'o-', label="Task A (MATH) Loss", color="red")

    # The title reports the actual Phase 2 composition (the ratio is the
    # HotpotQA share in Phase 2, not the MATH share).
    plt.title(
        f"Forgetting Curve (ratio {ratio_pct}%; Phase 2 mix: "
        f"{phase2_hotpot} HotpotQA / {phase2_math} MATH)"
    )
    plt.xlabel(f"Training Steps on Task B (HotpotQA) (Total Epochs: {TASK_B_EPOCHS})")
    plt.ylabel("Validation Loss")
    plt.legend()
    plt.grid(True)
    plt.yscale('log')
    plt.tight_layout()

    plot_filename = os.path.join(RESULTS_DIR, f"forgetting_curve_MATH_mix_{ratio_pct}.png")
    plt.savefig(plot_filename)
    print(f"Plot saved to {plot_filename}")

    try:
        import google.colab  # noqa: F401  (only used to detect the Colab runtime)
        plt.show()
    except ImportError:
        print("Not in Colab, plot saved to file.")
    finally:
        # Release the figure so repeated runs do not accumulate open figures.
        plt.close(fig)


def main(ratio):
    """Run one sequential-training experiment (MATH first, then HotpotQA).

    Phase 1 trains a LoRA adapter on a mix of ``round(N_TRAIN_EXAMPLES *
    ratio)`` MATH rows plus the remainder (at least one) HotpotQA rows.
    Phase 2 continues training the same adapter on the mirrored mix (the
    larger share from HotpotQA, the smaller from MATH) while a callback
    records the validation loss of both tasks at every logging event. The
    loss history is written to a JSON file and a plot under ``RESULTS_DIR``.

    Args:
        ratio: Share (0..1) of the phase's own task in its training mix.
    """
    os.makedirs(RESULTS_DIR, exist_ok=True)
    ratio_pct = int(round(ratio * 100))

    print(f"--- Loading Base Model & Tokenizer ---")
    base_model, tokenizer = get_model_and_tokenizer_base()

    # Keep the prompt-masked labels produced by preprocess().
    data_collator = _collate_with_labels

    # --- Load and Process Datasets ---
    print(f"\n--- Loading and Preprocessing Datasets (This may take a while) ---")

    # Task B: HotpotQA
    raw_hotpot = load_dataset(HOTPOT_DATASET_NAME, HOTPOT_DATASET_CONFIG)
    hotpot_train = raw_hotpot["train"].shuffle(seed=42).select(range(N_TRAIN_EXAMPLES))
    hotpot_val = raw_hotpot["validation"].shuffle(seed=42).select(range(N_VAL_EXAMPLES))

    print(f"Tokenizing and filtering HotpotQA...")
    hotpot_train_tokenized = _tokenize_split(hotpot_train, tokenizer, format_hotpot_qa)
    hotpot_val_tokenized = _tokenize_split(hotpot_val, tokenizer, format_hotpot_qa)
    print(f"HotpotQA: {len(hotpot_train_tokenized)} train, {len(hotpot_val_tokenized)} val (after filtering)")

    # Task A: MATH (a train/validation split is carved out of its "train" split)
    raw_math = load_dataset(MATH_DATASET_NAME)
    total_math_samples_needed = N_TRAIN_EXAMPLES + N_VAL_EXAMPLES
    math_subset = raw_math["train"].shuffle(seed=42).select(range(total_math_samples_needed))
    val_size_fraction = N_VAL_EXAMPLES / total_math_samples_needed
    math_splits = math_subset.train_test_split(test_size=val_size_fraction, seed=42)

    print(f"Tokenizing and filtering MATH...")
    math_train_tokenized = _tokenize_split(math_splits["train"], tokenizer, format_math)
    math_val_tokenized = _tokenize_split(math_splits["test"], tokenizer, format_math)
    print(f"MATH: {len(math_train_tokenized)} train, {len(math_val_tokenized)} val (after filtering)")

    # --- Build the mixed training sets (nominal total = N_TRAIN_EXAMPLES) ---
    primary_count = int(round(N_TRAIN_EXAMPLES * ratio))
    secondary_count = max(1, N_TRAIN_EXAMPLES - primary_count)  # at least one row of the other task

    # Both phases draw from the same seeded shuffle of each task.
    hotpot_pool = hotpot_train_tokenized.shuffle(seed=42)
    math_pool = math_train_tokenized.shuffle(seed=42)

    phase1_hotpot = _take_head(hotpot_pool, secondary_count)
    phase1_math = _take_head(math_pool, primary_count)
    phase2_hotpot = _take_head(hotpot_pool, primary_count)
    phase2_math = _take_head(math_pool, secondary_count)

    phase1_train = concatenate_datasets([phase1_hotpot, phase1_math]).shuffle(seed=42)
    phase2_train = concatenate_datasets([phase2_hotpot, phase2_math]).shuffle(seed=42)

    print(f"  Phase1 training: {len(phase1_train)} samples (Hotpot: {len(phase1_hotpot)}, MATH: {len(phase1_math)})")
    print(f"  Phase2 training: {len(phase2_train)} samples (Hotpot: {len(phase2_hotpot)}, MATH: {len(phase2_math)})")

    # --- Experiment 2: Sequential Training (CF) [MATH -> HotpotQA] ---
    print(f"\n--- Starting Experiment 2: Sequential Training (CF) [MATH -> HotpotQA] ---")

    # --- Phase 1: Train on the MATH-dominated mix ---
    # Phase 1 always trains a fresh adapter. It is saved to TASK_A_ADAPTER_PATH
    # below, so a later run could instead load it with
    # PeftModel.from_pretrained(base_model, TASK_A_ADAPTER_PATH).
    print(f"--- Starting Phase 1: Training on Task A (MATH) ---")
    lora_config = get_lora_config()
    seq_model = get_peft_model(base_model, lora_config)
    seq_model.print_trainable_parameters()

    seq_args_a = TrainingArguments(
        output_dir=os.path.join(RESULTS_DIR, "seq_training_A"),
        per_device_train_batch_size=PER_DEVICE_BS,
        gradient_accumulation_steps=GRAD_ACC_STEPS,
        num_train_epochs=TASK_A_EPOCHS,
        learning_rate=2e-4,
        logging_steps=10,
        save_strategy="no",
        report_to="none",
        gradient_checkpointing=True,
    )

    seq_trainer_a = Trainer(
        model=seq_model,
        args=seq_args_a,
        train_dataset=phase1_train,
        eval_dataset=math_val_tokenized,
        data_collator=data_collator,
    )

    seq_trainer_a.train()

    print(f"--- Phase 1 (MATH) training complete. Saving adapter to {TASK_A_ADAPTER_PATH} ---")
    seq_model.save_pretrained(TASK_A_ADAPTER_PATH)
    print("Adapter saved.")

    del seq_trainer_a
    torch.cuda.empty_cache()

    # --- Evaluate the "Task A Expert" model on both validation sets ---
    print("\n--- Evaluating Model after Phase 1 (Task A Expert) ---")
    eval_args = TrainingArguments(
        output_dir=os.path.join(RESULTS_DIR, "eval_temp"),
        per_device_eval_batch_size=PER_DEVICE_BS,
        report_to="none",
        gradient_checkpointing=True,
    )

    eval_trainer = Trainer(
        model=seq_model,
        args=eval_args,
        data_collator=data_collator,
    )

    eval_hotpot_phase1 = eval_trainer.evaluate(eval_dataset=hotpot_val_tokenized)
    print(f"  > Task A Expert - HotpotQA Val Loss: {eval_hotpot_phase1['eval_loss']:.4f}")
    eval_math_phase1 = eval_trainer.evaluate(eval_dataset=math_val_tokenized)
    print(f"  > Task A Expert - MATH Val Loss: {eval_math_phase1['eval_loss']:.4f}")
    del eval_trainer, eval_args
    torch.cuda.empty_cache()

    # --- Phase 2: Train on the HotpotQA-dominated mix (forgetting MATH happens here) ---
    print(f"\n  --- Phase 2: Training on Task B (HotpotQA) ---")
    history = {"steps": [], "hotpot_loss": [], "math_loss": []}

    seq_args_b = TrainingArguments(
        output_dir=os.path.join(RESULTS_DIR, "seq_training_B"),
        per_device_train_batch_size=PER_DEVICE_BS,
        gradient_accumulation_steps=GRAD_ACC_STEPS,
        num_train_epochs=TASK_B_EPOCHS,
        learning_rate=7e-5,
        logging_steps=10,
        save_strategy="no",
        report_to=[],
        gradient_checkpointing=True,
    )
    seq_model.enable_input_require_grads()

    tracker_callback = _make_forgetting_tracker(
        hotpot_val=hotpot_val_tokenized,
        math_val=math_val_tokenized,
        history_log=history,
        start_metrics={
            'hotpot_loss': eval_hotpot_phase1['eval_loss'],
            'math_loss': eval_math_phase1['eval_loss'],
        },
    )
    seq_trainer_b = Trainer(
        model=seq_model,
        args=seq_args_b,
        train_dataset=phase2_train,
        eval_dataset=hotpot_val_tokenized,
        data_collator=data_collator,
        callbacks=[tracker_callback],
    )
    # The callback needs the trainer to call trainer.evaluate().
    tracker_callback.set_trainer(seq_trainer_b)
    seq_trainer_b.train()

    # --- 5. Save and plot results ---
    print("\n--- Saving History Data and Generating Plot ---")
    _save_history_and_plot(history, ratio_pct, len(phase2_hotpot), len(phase2_math))

    # Clean up GPU memory between runs; drop the callback -> trainer reference
    # first so the trainer is not kept alive by a reference cycle.
    tracker_callback.trainer = None
    del seq_trainer_b
    torch.cuda.empty_cache()

if __name__ == "__main__":
    # Script entry point: run one experiment per configured mix ratio, but
    # only when a CUDA device is present (the message below states that the
    # experiment requires one).
    if not torch.cuda.is_available():
        print("ERROR: This experiment requires a GPU. Check Colab runtime type.")
    else:
        vram_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
        print(f"INFO: Running on GPU. VRAM: {vram_gb:.2f} GB")
        if vram_gb < 11:
            print("WARNING: VRAM is less than 11GB. You may hit OOM errors. Try lowering MAX_SEQ_LENGTH.")
        for ratio in MIX_RATIOS:
            main(ratio)