In [None]:
# !jupyter nbconvert PMLM_distill.ipynb --to python

**Imports**
---

In [None]:
import os

# 💥 Set this BEFORE model/accelerator is created
# Distributed-runtime knobs: disable MPI discovery in DeepSpeed, verbose NCCL/torch.distributed
# debugging, blocking NCCL waits (errors surface instead of hanging), and no InfiniBand transport.
os.environ["DEEPSPEED_USE_MPI"] = "false"
os.environ["NCCL_DEBUG"] = "INFO"
os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL"
os.environ["TORCH_NCCL_BLOCKING_WAIT"] = "1"
# os.environ["NCCL_P2P_DISABLE"] = "1"
os.environ["NCCL_IB_DISABLE"] = "1"

# Hugging Face cache location. huggingface_hub / transformers resolve these when they are first
# imported, so they must be set here, before the imports below, to actually take effect.
os.environ["HF_HOME"] = "../huggingface_cache"
os.environ["HF_HUB_CACHE"] = "../huggingface_cache"

import random

import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.distributed as dist
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, DistributedSampler, Subset
from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
)
from accelerate import Accelerator

from peft import (
    LoraConfig,
    get_peft_model,
    prepare_model_for_kbit_training,
    PeftModel,
)

from peft.tuners.lora import LoraModel, LoraLayer
from peft.utils import get_peft_model_state_dict
import deepspeed
from deepspeed.accelerator import get_accelerator

**Installations**
---

In [None]:
# import sys
# !{sys.executable} -m pip install --no-cache-dir --upgrade bitsandbytes triton

In [None]:
# !pip install git+https://github.com/huggingface/transformers.git

**Config**
---

In [None]:
# Reproducibility on GPU: turn off cuDNN autotuning, request deterministic cuDNN kernels,
# and seed the RNG of every visible CUDA device with the same value.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
torch.cuda.manual_seed_all(42)

In [None]:
# Point the Hugging Face home and hub cache at a shared directory one level above the notebook.
# NOTE(review): huggingface_hub / transformers appear to resolve these paths when first imported
# (the Imports cell runs earlier), so setting them here may not change where downloads land —
# set them before the imports or pass cache_dir explicitly; TODO confirm for your versions.
_HF_CACHE_DIR = "../huggingface_cache"
os.environ.update(HF_HOME=_HF_CACHE_DIR, HF_HUB_CACHE=_HF_CACHE_DIR)

In [None]:
from huggingface_hub import login

# Never hard-code a Hub token in the notebook: read it from the HF_TOKEN environment variable.
# When it is absent, login() is skipped and whatever credentials are already cached on this
# machine are relied upon. NOTE(review): gated repos (e.g. meta-llama/*) will fail to download
# if neither a token nor cached credentials are available.
_hf_token = os.environ.get("HF_TOKEN")
if _hf_token:
    login(token=_hf_token, add_to_git_credential=True)
else:
    print("HF_TOKEN not set; skipping huggingface_hub.login() and relying on cached credentials.")

In [None]:
# Seed the CPU-side RNGs (PyTorch, NumPy, Python's `random`) with one shared value so data
# shuffling and other stochastic steps repeat across runs.
_SEED = 42
for _seed_fn in (torch.manual_seed, np.random.seed, random.seed):
    _seed_fn(_SEED)

**Utils**
---

In [None]:
def format_prompt(model_key: str, prompt: str) -> str:
    """Wrap ``prompt`` in the chat markup of the model family named by ``model_key``.

    The family is detected by substring match on ``model_key`` (a short key such as
    ``"qwen2.5_7b_instruct"``); families are tested in the order below and the first match wins.

    Args:
        model_key: Short model identifier (e.g. a key of ``MODEL_NAMES``).
        prompt: Raw user text, inserted verbatim.

    Returns:
        The formatted prompt (``prompt`` unchanged for GPT-2 style models).

    Raises:
        ValueError: if ``model_key`` matches none of the supported families.
    """
    if "gpt2" in model_key:
        return prompt  # plain input, no special formatting

    # "mistral" is not a substring of "ministral", hence the explicit second check.
    if "mistral" in model_key or "ministral" in model_key:
        return f"<s>[INST]{prompt}[/INST]"

    if "llama3" in model_key:
        return f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n{prompt}<|eot_id|>\n<|start_header_id|>assistant<|end_header_id|>\n"

    if "qwen" in model_key:
        # NOTE(review): Qwen's own chat template ends with "assistant\n"; this omits the newline — confirm intent.
        return f"<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant"

    if "gemma" in model_key:
        return f"<start_of_turn>user\n{prompt}\n<end_of_turn>\n<start_of_turn>model\n"

    if "internlm" in model_key:
        return f"<|User|>:{prompt}\n<|Bot|>:"

    if "deepseek" in model_key:
        return f"### Instruction:\n{prompt}\n### Response:"

    if "glm" in model_key:
        # GLM's round-based Chinese chat markup ("Q: ... A:").
        return f"[Round 1]\n\n问：{prompt}\n\n答："

    # Checked last so keys that matched an earlier family keep their previous formatting.
    if "tinyllama" in model_key:
        # TinyLlama-1.1B-Chat uses the Zephyr-style template.
        return f"<|user|>\n{prompt}</s>\n<|assistant|>\n"

    raise ValueError(f"Unknown model_key '{model_key}' in format_prompt.")

**Hparams**
---

In [None]:
# Training hyperparameters: DataLoader batch size, padded/truncated token length per example,
# number of epochs, optimizer learning rate, and micro-batches accumulated per optimizer step.
batch_size, max_length, num_epochs = 4, 512, 1
learning_rate, gradient_accumulation_steps = 2e-5, 4

# Instruction text asking a model to generate a [SYSTEM INSTRUCTION] that describes
# content-adaptive watermarking strategies for a given [USER REQUEST].
# NOTE(review): not referenced anywhere else in this notebook; presumably consumed by the
# data-generation step that produced the CSV datasets loaded below — confirm before editing.
# The literal keeps its leading newline and per-line indentation; dedent at the call site if needed.
strategy_hint = (
        """
        Generate a [SYSTEM INSTRUCTION] based on the provided [USER REQUEST]. This [SYSTEM INSTRUCTION] will be combined 
        with the [USER REQUEST] and input into another language model to produce a watermarked output. 
        The [SYSTEM INSTRUCTION] should specify watermarking strategies that adapt dynamically to the content of the [USER REQUEST].
        Example [SYSTEM INSTRUCTION]: 'Use specific strategies to embed watermarks such as including special tokens or phrases that fit naturally with the content. The watermark should be later detectable by a classifier.'
        Example watermarking strategies:
        • Lexical Strategy: Incorporate specific rare or uncommon tokens as watermarks.
        • Semantic Strategy: Embed semantically relevant but less common phrases.
        • Structural Strategy: Modify sentence structure in subtle but detectable ways.
        • <You can add Strategies if necessary>
        Ensure watermarks are evenly distributed throughout the output.
        Your task is to output ONLY the [SYSTEM INSTRUCTION] that specifies the concrete watermarking strategy.
        """
    )

# Define the model names
# Short model key -> Hugging Face Hub repo id. The keys are looked up when loading models and
# are embedded in the dataset/output paths built further down; every entry was marked
# "✅ Works" by the original author.
# NOTE(review): the key "internlm2.5_7b_chat" points at internlm/internlm2-chat-7b (InternLM2,
# not 2.5) — confirm which checkpoint was intended.
MODEL_NAMES = dict(
    [
        # Working PLM models
        ("mistral_7b_v03_instruct", "mistralai/Mistral-7B-Instruct-v0.3"),
        # MLM models (teacher)
        ("deepseek_llm_chat", "deepseek-ai/deepseek-llm-7b-chat"),
        ("qwen2.5_7b_instruct", "Qwen/Qwen2.5-7B-Instruct"),
        ("llama3_8b_instruct", "meta-llama/Meta-Llama-3-8B-Instruct"),
        ("gemma_7b_it", "google/gemma-7b-it"),
        ("ministral_8b_instruct", "mistralai/Ministral-8B-Instruct-2410"),
        ("glm_4_9b_chat", "THUDM/glm-4-9b-chat"),
        ("internlm2.5_7b_chat", "internlm/internlm2-chat-7b"),
        # Student models
        ("mistral_7b_v02_instruct", "mistralai/Mistral-7B-Instruct-v0.2"),
        ("qwen1.5_1.8b_instruct", "Qwen/Qwen1.5-1.8B-Chat"),
        ("tinyllama_1.1b_chat", "TinyLlama/TinyLlama-1.1B-Chat-v1.0"),
        ("gemma_1.1_2b_it", "google/gemma-1.1-2b-it"),
    ]
)


# Tier	Model
# ⭐ Top-tier	LLaMA-3 8B Instruct
# ⭐ Top-tier	Qwen2.5-7B-Instruct
# ⭐ Top-tier	DeepSeek v2 7B-Chat
# ⭐ Mid-tier	Mistral-7B-Instruct-v0.3
# ⭐ Mid-tier	GEMMA-7B-IT
# ✅ Bonus	InternLM2.5-7B-Chat
# ✅ Bonus	GLM-4-9B-Chat

In [None]:
# Mistral-7B-Instruct v0.3 acts as both the PLM and the teacher;
# the student being distilled is the older v0.2 instruct checkpoint.
PLM_MODEL_KEY = TEACHER_MODEL_KEY = "mistral_7b_v03_instruct"
STUDENT_MODEL_KEY = "mistral_7b_v02_instruct"

In [None]:
# Initial list:
# ---
# GPT3.5-turbo-0125
# QWEN-plus
# LLAMA3-8B
# QWEN2.5-7B
# QWEN2-1.5B
# vicuna_7b_v1_3
# vicuna_7b_v1_5
# open_llama_3b
# open_llama_7b
# mistral_7b_v03
# mistral_7b_v03_instruct
# baize_v2_7b
# GLM-4-plus
# GLM-3-Turbo
# LLAMA3-8B
# GEMMA-7B
# GPT4o-mini
# GPT-4o
# DEEPSEEK v2
# CLAUDE-3.5-sonnet
# INTERNLM2.5-7B

**DeepSpeed + Accelerate setup**
---

In [None]:
from accelerate.utils import DeepSpeedPlugin

# DeepSpeed ZeRO stage 3 with both optimizer state and parameters offloaded to CPU,
# and gradients clipped to a global norm of 1.0.
_zero3_offload_kwargs = dict(
    zero_stage=3,
    offload_optimizer_device="cpu",
    offload_param_device="cpu",
    gradient_clipping=1.0,
)
deepspeed_plugin = DeepSpeedPlugin(**_zero3_offload_kwargs)

# bf16 mixed precision; accumulation length comes from the Hparams cell.
accelerator = Accelerator(
    deepspeed_plugin=deepspeed_plugin,
    gradient_accumulation_steps=gradient_accumulation_steps,
    mixed_precision="bf16",
)

In [None]:
# Report whether Accelerate picked up the DeepSpeed plugin configured above.
print(
    "DeepSpeed is enabled. Adjust configurations accordingly."
    if accelerator.state.deepspeed_plugin is not None
    else "DeepSpeed is not being used."
)

# Hand cached-but-unused blocks from PyTorch's CUDA caching allocator back to the driver.
torch.cuda.empty_cache()

# Allow TF32 tensor-core matmuls (faster fp32 matmul on supporting GPUs, slightly reduced precision);
# the cuDNN flags repeat the determinism settings from the Config section.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

In [None]:
print("Using {} GPUs.".format(accelerator.num_processes))  # one launched process per GPU is assumed

**Dataset**
---

In [None]:
class DistillationDataset(Dataset):
    """Supervised fine-tuning dataset built from teacher-generated CSV files.

    Each row yields one causal-LM example: the "USER REQUEST" prompt, a newline, then the
    "NON-WATERMARKED RESPONSE". The loss is restricted to response tokens: prompt tokens and
    padding are set to -100 in ``labels`` (the ignore index of Hugging Face causal-LM losses).
    Masking is computed from ``attention_mask``, so it works for left- and right-padding tokenizers.

    Args:
        csv_dir: Directory scanned (non-recursively) for ``*.csv`` files.
        tokenizer: Hugging Face tokenizer; ``padding="max_length"`` requires a pad token.
    """

    PROMPT_COLUMN = "USER REQUEST"
    RESPONSE_COLUMN = "NON-WATERMARKED RESPONSE"

    def __init__(self, csv_dir, tokenizer):
        self.data = self.load_all_csvs(csv_dir)
        self.tokenizer = tokenizer
        self.max_length = max_length  # notebook-level hyperparameter (Hparams cell)

    def load_all_csvs(self, csv_dir):
        """Concatenate every ``.csv`` in ``csv_dir`` into one DataFrame.

        Files are read in sorted name order so row indices are reproducible across machines.
        Rows with a missing prompt or response are dropped (NaN cells would otherwise crash
        ``__getitem__``), and the index is reset to 0..n-1.

        Raises:
            ValueError: if ``csv_dir`` contains no ``.csv`` files.
        """
        csv_files = sorted(name for name in os.listdir(csv_dir) if name.endswith(".csv"))
        if not csv_files:
            raise ValueError(f"No CSVs found in: {csv_dir}")
        frames = [pd.read_csv(os.path.join(csv_dir, name)) for name in csv_files]
        combined = pd.concat(frames, ignore_index=True)
        return combined.dropna(subset=[self.PROMPT_COLUMN, self.RESPONSE_COLUMN]).reset_index(drop=True)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``input_ids``, ``attention_mask`` and ``labels`` (1-D tensors of ``max_length``)."""
        row = self.data.iloc[idx]
        prompt = str(row[self.PROMPT_COLUMN]).strip()
        response = str(row[self.RESPONSE_COLUMN]).strip()

        tokenized = self.tokenizer(
            prompt + "\n" + response, padding="max_length", truncation=True,
            max_length=self.max_length, return_tensors="pt"
        )
        input_ids = tokenized["input_ids"].squeeze(0)
        attention_mask = tokenized["attention_mask"].squeeze(0)

        # Tokenize the same stripped prompt that starts the full text so its length matches
        # the prompt prefix inside input_ids.
        prompt_len = len(
            self.tokenizer(prompt, truncation=True, max_length=self.max_length)["input_ids"]
        )

        labels = input_ids.clone()
        # Never train on padding: the pad token is often EOS, and unmasked padding would teach
        # the model to emit long runs of EOS while dominating the loss.
        labels[attention_mask == 0] = -100
        # Mask the prompt relative to the first real token, not position 0, so left padding works.
        real_positions = attention_mask.nonzero(as_tuple=True)[0]
        if real_positions.numel() > 0:
            start = int(real_positions[0])
            labels[start:start + prompt_len] = -100

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels
        }

def prepare_distillation_dataloader(csv_dir, tokenizer, max_samples=None):
    """Build a shuffled DataLoader over ``DistillationDataset(csv_dir, tokenizer)``.

    Args:
        csv_dir: Directory of teacher-generated CSV files.
        tokenizer: Student tokenizer used to encode prompt/response pairs.
        max_samples: Optional cap on the number of examples (the first N rows), e.g. 200 for a
            quick smoke test. ``None`` (default) uses the whole dataset.

    Returns:
        A ``DataLoader`` with the notebook-level ``batch_size`` and ``shuffle=True``.
    """
    dataset = DistillationDataset(csv_dir, tokenizer)
    if max_samples is not None:
        # torch Datasets have no ``.select()``; Subset gives the same "first N rows" view.
        dataset = Subset(dataset, range(min(max_samples, len(dataset))))

    return DataLoader(dataset, batch_size=batch_size, shuffle=True)

**Model Distillation**
---

In [None]:
def distill_with_lora(student_model, dataloader, accelerator, compile_model=False):
    """Fine-tune ``student_model`` on teacher-generated data with LoRA adapters.

    Base weights are frozen and rank-8 LoRA adapters are attached to the q/k/v/o attention
    projections. Training runs for ``num_epochs`` epochs with AdamW at ``learning_rate``.
    Gradient accumulation is delegated to ``accelerator.accumulate``; under DeepSpeed,
    Accelerate hands accumulation and the optimizer step to the DeepSpeed engine. Training
    stops early once the mean epoch loss fails to improve for ``patience`` consecutive epochs.

    Args:
        student_model: Hugging Face causal-LM to adapt.
        dataloader: Yields dicts with ``input_ids``, ``attention_mask`` and ``labels``.
        accelerator: A configured ``accelerate.Accelerator``.
        compile_model: Wrap the prepared model in ``torch.compile``. Off by default.
            NOTE(review): compiling a ZeRO-3 DeepSpeed engine is not a generally supported path.

    Returns:
        The model as returned by ``accelerator.prepare`` (possibly wrapped, e.g. in a DeepSpeed
        engine); use ``accelerator.unwrap_model`` to reach the PEFT model.
    """
    # k-bit preparation only makes sense for 4/8-bit quantized checkpoints. On a bf16 model it
    # upcasts every frozen weight to fp32, roughly doubling memory for no benefit.
    if getattr(student_model, "is_loaded_in_8bit", False) or getattr(student_model, "is_loaded_in_4bit", False):
        student_model = prepare_model_for_kbit_training(student_model)
    student_model.gradient_checkpointing_enable()
    # With frozen embeddings + checkpointing, inputs must require grad for backprop to reach LoRA.
    student_model.enable_input_require_grads()
    student_model.config.use_cache = False  # KV cache is not used during training

    if hasattr(student_model.config, "use_flash_attention_2"):
        student_model.config.use_flash_attention_2 = True

    lora_config = LoraConfig(
        r=8,
        lora_alpha=16,
        lora_dropout=0.1,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj"]
    )

    student_model = get_peft_model(student_model, lora_config)
    if accelerator.is_main_process:
        # print_trainable_parameters() prints on its own and returns None.
        student_model.print_trainable_parameters()

    # After get_peft_model only the adapter weights require grad; optimize just those.
    trainable_params = [p for p in student_model.parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW(trainable_params, lr=learning_rate)

    student_model, optimizer, dataloader = accelerator.prepare(student_model, optimizer, dataloader)
    if compile_model:
        student_model = torch.compile(student_model)
    # Train mode keeps LoRA dropout active and lets HF gradient checkpointing run
    # (it is skipped when the model is in eval mode).
    student_model.train()

    best_loss = float("inf")
    patience = 2
    wait = 0
    num_batches = max(len(dataloader), 1)  # avoid ZeroDivisionError on an empty loader

    for epoch in range(num_epochs):
        total_loss = 0.0

        # The prepared dataloader already places batches on accelerator.device.
        for step, batch in enumerate(dataloader):
            with accelerator.accumulate(student_model):
                outputs = student_model(**batch)
                # Unscaled loss: Accelerate (or DeepSpeed) normalizes for gradient accumulation itself.
                loss = outputs.loss
                accelerator.backward(loss)
                # Inside accumulate(), these only take effect on gradient-sync boundaries.
                optimizer.step()
                optimizer.zero_grad()

            total_loss += loss.detach().item()

            if accelerator.sync_gradients:
                get_accelerator().empty_cache()

            if accelerator.is_main_process and (step + 1) % 10 == 0:
                print(f"Epoch {epoch+1} | Step {step + 1}/{len(dataloader)} | Loss: {loss.item():.4f}")

        avg_loss = total_loss / num_batches
        if avg_loss < best_loss:
            best_loss = avg_loss
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                accelerator.print(f"Early stopping at epoch {epoch+1}. Best loss: {best_loss:.4f}")
                break

    return student_model

**Training: Distillation**
---

In [None]:
# BREAKPOINT_0

In [None]:
# Destination for the distilled adapter and tokenizer; created up front (no error if it already exists).
distilled_path = "tuned_models/distilled_{}_to_{}".format(TEACHER_MODEL_KEY, STUDENT_MODEL_KEY)
os.makedirs(distilled_path, exist_ok=True)

In [None]:
# Pass cache_dir explicitly so downloads honour HF_HUB_CACHE even when it was set after
# huggingface_hub / transformers were imported (they resolve their default cache path at
# import time). None (variable unset) falls back to the library default.
hf_cache_dir = os.environ.get("HF_HUB_CACHE")

# Load tokenizers.
# NOTE(review): teacher_tokenizer is not used by the cells below — confirm it is still needed.
teacher_tokenizer = AutoTokenizer.from_pretrained(
    MODEL_NAMES[TEACHER_MODEL_KEY], trust_remote_code=True, cache_dir=hf_cache_dir
)
student_tokenizer = AutoTokenizer.from_pretrained(
    MODEL_NAMES[STUDENT_MODEL_KEY], trust_remote_code=True, cache_dir=hf_cache_dir
)

# Padding to max_length needs a pad token; reuse EOS when the tokenizer defines none.
if student_tokenizer.pad_token is None:
    student_tokenizer.pad_token = student_tokenizer.eos_token

# Load the full bf16 student on CPU (no device_map); placement and sharding are left to
# accelerator.prepare inside distill_with_lora.
student_model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAMES[STUDENT_MODEL_KEY],
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
    device_map=None,
    low_cpu_mem_usage=False,
    cache_dir=hf_cache_dir,
)

In [None]:
# Teacher outputs for this PLM/teacher pair: a directory that must hold at least one CSV file.
csv_dir = f"Datasets/10k/PLM_{PLM_MODEL_KEY}/{TEACHER_MODEL_KEY}_data/"
distillation_dataloader = prepare_distillation_dataloader(csv_dir=csv_dir, tokenizer=student_tokenizer)

In [None]:
distilled_model = distill_with_lora(student_model=student_model, dataloader=distillation_dataloader, accelerator=accelerator)  # returns the accelerator-wrapped model

**Saving**
---

In [None]:
# Save just the LoRA adapter weights (plus the tokenizer).
# Under ZeRO-3 every rank holds only a shard of each parameter, so saving from the main process
# alone would write empty adapter tensors. The trainable (adapter) parameters are therefore
# gathered first. Gathering is collective, so EVERY rank must enter GatheredParameters, and only
# the main process writes files. For parameters that are not ZeRO-partitioned the context is a no-op.
accelerator.wait_for_everyone()
unwrapped_student = accelerator.unwrap_model(distilled_model)
adapter_params = [p for p in unwrapped_student.parameters() if p.requires_grad]
with deepspeed.zero.GatheredParameters(adapter_params, modifier_rank=None):
    if accelerator.is_main_process:
        unwrapped_student.save_pretrained(distilled_path)
        student_tokenizer.save_pretrained(distilled_path)
        print(f"Distilled model saved to: {distilled_path}")
accelerator.wait_for_everyone()