In [None]:
# ==============================================================================
# 0. SETUP AND INSTALLATIONS
# ==============================================================================
# Make sure these are run in your environment
!pip install -q --upgrade bitsandbytes
!pip install -q --upgrade transformers peft accelerate datasets trl einops

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    Trainer,
    TrainingArguments,
    BitsAndBytesConfig,
    AutoConfig,
)
from datasets import load_dataset, Dataset
from peft import LoraConfig, get_peft_model, TaskType, PeftModel
from tqdm import tqdm
import os
import random
import numpy as np
import gc
from dataclasses import dataclass
from typing import Dict, List, Any

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m67.0/67.0 MB[0m [31m11.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m3.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.8/13.8 MB[0m [31m62.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.6/24.6 MB[0m [31m21.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m883.7/883.7 kB[0m [31m35.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m664.8/664.8 MB[0m [31m2.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m211.5/211.5 MB[0m [31m7.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m56.3/56.3 MB[0m [31m12.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [None]:
!nvidia-smi

Sun Jun 29 17:38:00 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  Tesla T4                       Off |   00000000:00:04.0 Off |                    0 |
| N/A   33C    P8             10W /   70W |       2MiB /  15360MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

In [None]:
!pip show torch

Name: torch
Version: 2.6.0+cu124
Summary: Tensors and Dynamic neural networks in Python with strong GPU acceleration
Home-page: https://pytorch.org/
Author: PyTorch Team
Author-email: packages@pytorch.org
License: BSD-3-Clause
Location: /usr/local/lib/python3.11/dist-packages
Requires: filelock, fsspec, jinja2, networkx, nvidia-cublas-cu12, nvidia-cuda-cupti-cu12, nvidia-cuda-nvrtc-cu12, nvidia-cuda-runtime-cu12, nvidia-cudnn-cu12, nvidia-cufft-cu12, nvidia-curand-cu12, nvidia-cusolver-cu12, nvidia-cusparse-cu12, nvidia-cusparselt-cu12, nvidia-nccl-cu12, nvidia-nvjitlink-cu12, nvidia-nvtx-cu12, sympy, triton, typing-extensions
Required-by: accelerate, bitsandbytes, fastai, peft, sentence-transformers, timm, torchaudio, torchdata, torchvision


In [None]:
# ==============================================================================
# 1. CONFIGURATION
# ==============================================================================
class Config:
    TEACHER_MODEL_NAME = "microsoft/Phi-3-mini-4k-instruct"
    STUDENT_MODEL_NAME = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
    QA_DATASET_NAME = "squad"
    DATASET_SLICE = "train"
    NUM_EXAMPLES_FOR_DISTILLATION = 1000
    NUM_TRAIN_EXAMPLES = 800
    NUM_EVAL_EXAMPLES = 200

    DISTILL_DATA_DIR = "./distillation_squad_data_lazy"
    LOGITS_DIR = os.path.join(DISTILL_DATA_DIR, "teacher_logits")
    HIDDEN_STATES_DIR = os.path.join(DISTILL_DATA_DIR, "teacher_hidden_states")

    BATCH_SIZE = 2
    GRADIENT_ACCUMULATION_STEPS = 8
    LEARNING_RATE = 5e-5
    NUM_TRAIN_EPOCHS = 3
    WARMUP_STEPS = 50
    LOGGING_STEPS = 10
    ALPHA_HARD_LABEL = 0.7
    BETA_HIDDEN_STATE = 0.3
    LORA_R = 16
    LORA_ALPHA = 32
    LORA_DROPOUT = 0.05
    LORA_TARGET_MODULES = ["q_proj", "v_proj", "k_proj", "o_proj"]
    USE_4BIT_TEACHER = True
    USE_GRADIENT_CHECKPOINTING = True
    NORMALIZE_HIDDEN_STATES = True
    OUTPUT_DIR = "./distillation_output_phi3_tinyllama"
    SEED = 42
    # Define mapping from a key to (student_layer_index, teacher_layer_index)
    LAYER_MAPPING = {'layer_11': (11, 16), 'layer_21': (21, 31)}

os.makedirs(Config.OUTPUT_DIR, exist_ok=True)
os.makedirs(Config.LOGITS_DIR, exist_ok=True)
os.makedirs(Config.HIDDEN_STATES_DIR, exist_ok=True)

# ==============================================================================
# 2. UTILITIES AND MODEL LOADING
# ==============================================================================
def set_seed(seed: int):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

set_seed(Config.SEED)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


In [None]:
def load_teacher_model():
    print("\n--- Loading Teacher Model ---")
    tokenizer = AutoTokenizer.from_pretrained(Config.TEACHER_MODEL_NAME, trust_remote_code=True)
    if Config.USE_4BIT_TEACHER:
        bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.bfloat16)
        model = AutoModelForCausalLM.from_pretrained(Config.TEACHER_MODEL_NAME, quantization_config=bnb_config, device_map="auto", trust_remote_code=True)
    else:
        model = AutoModelForCausalLM.from_pretrained(Config.TEACHER_MODEL_NAME, torch_dtype=torch.bfloat16, device_map="auto", trust_remote_code=True)
    model.eval()
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    return model, tokenizer

print("\n--- Loading Student Model & Teacher Config ---")
student_tokenizer = AutoTokenizer.from_pretrained(Config.STUDENT_MODEL_NAME, trust_remote_code=True)
student_model = AutoModelForCausalLM.from_pretrained(Config.STUDENT_MODEL_NAME, trust_remote_code=True, torch_dtype=torch.bfloat16).to(device)
if student_tokenizer.pad_token is None:
    student_tokenizer.pad_token = student_tokenizer.eos_token

teacher_config = AutoConfig.from_pretrained(Config.TEACHER_MODEL_NAME, trust_remote_code=True)


--- Loading Student Model & Teacher Config ---


In [None]:
# ==============================================================================
# 3. PREPARE STUDENT MODEL FOR LoRA
# ==============================================================================
print("\n--- Preparing Student Model for LoRA ---")
lora_config = LoraConfig(r=Config.LORA_R, lora_alpha=Config.LORA_ALPHA, lora_dropout=Config.LORA_DROPOUT, bias="none", task_type=TaskType.CAUSAL_LM, target_modules=Config.LORA_TARGET_MODULES)
student_model = get_peft_model(student_model, lora_config)
student_model.print_trainable_parameters()

# ==============================================================================
# 4. GENERATE TEACHER OUTPUTS (OFFLINE DISTILLATION)
# ==============================================================================
def create_chat_prompt(tokenizer, question, context):
    messages = [{"role": "user", "content": f"Based on the following context, please answer the question.\n\nContext: {context}\n\nQuestion: {question}"}]
    return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

def generate_and_save_teacher_outputs_lazy(dataset_subset, teacher_model, teacher_tokenizer):
    manifest = []
    print("Generating teacher outputs and saving to individual files...")
    teacher_model.eval()
    for i, example in tqdm(enumerate(dataset_subset), total=len(dataset_subset)):
        question = example['question'].strip()
        context = example['context'].strip()
        ground_truth_answer = example['answers']['text'][0] if example['answers']['text'] else ""
        prompt = create_chat_prompt(teacher_tokenizer, question, context)

        # We only need the teacher's outputs for the answer part
        full_text = prompt + ground_truth_answer + teacher_tokenizer.eos_token
        inputs = teacher_tokenizer(full_text, return_tensors="pt", max_length=1024, truncation=True).to(teacher_model.device)
        prompt_len_tokens = teacher_tokenizer(prompt, return_tensors='pt', max_length=1024, truncation=True).input_ids.shape[1]

        with torch.no_grad():
            # Generate outputs for the full sequence
            teacher_outputs = teacher_model(**inputs, output_hidden_states=True)

        if teacher_outputs.hidden_states is None:
            raise ValueError("Teacher model did not return hidden states.")

        logits_path = os.path.join(Config.LOGITS_DIR, f"logits_{i}.pt")
        torch.save(teacher_outputs.logits[:, prompt_len_tokens-1:-1, :].cpu(), logits_path)

        hs_paths = {}
        for key, (_, teacher_layer_idx) in Config.LAYER_MAPPING.items():
            hs_path = os.path.join(Config.HIDDEN_STATES_DIR, f"hs_{key}_{i}.pt")
            # Hidden states corresponding to the tokens that produce the answer logits
            torch.save(teacher_outputs.hidden_states[teacher_layer_idx][:, prompt_len_tokens-1:-1, :].cpu(), hs_path)
            hs_paths[key] = hs_path

        manifest.append({
            "prompt": prompt,
            "ground_truth_answer": ground_truth_answer,
            "teacher_logits_path": logits_path,
            "teacher_hs_paths": hs_paths
        })

    return Dataset.from_list(manifest)


--- Preparing Student Model for LoRA ---
trainable params: 4,505,600 || all params: 1,104,553,984 || trainable%: 0.4079




In [None]:
# Check for the manifest dataset directory
if os.path.exists(Config.DISTILL_DATA_DIR) and len(os.listdir(Config.LOGITS_DIR)) >= Config.NUM_EXAMPLES_FOR_DISTILLATION:
    print(f"\nLoading cached distillation manifest from {Config.DISTILL_DATA_DIR}...")
    distillation_dataset = Dataset.load_from_disk(Config.DISTILL_DATA_DIR)
else:
    print("\n--- Preparing SQuAD Dataset for Distillation ---")
    teacher_model_instance, teacher_tokenizer_instance = load_teacher_model()
    squad_dataset = load_dataset(Config.QA_DATASET_NAME, split=Config.DATASET_SLICE)
    squad_dataset_subset = squad_dataset.shuffle(seed=Config.SEED).select(range(Config.NUM_EXAMPLES_FOR_DISTILLATION))
    distillation_dataset = generate_and_save_teacher_outputs_lazy(squad_dataset_subset, teacher_model_instance, teacher_tokenizer_instance)
    print(f"\nSaving distillation manifest to {Config.DISTILL_DATA_DIR}...")
    distillation_dataset.save_to_disk(Config.DISTILL_DATA_DIR)
    del teacher_model_instance, teacher_tokenizer_instance
    gc.collect()
    torch.cuda.empty_cache()

train_dataset = distillation_dataset.select(range(Config.NUM_TRAIN_EXAMPLES))
eval_dataset = distillation_dataset.select(range(Config.NUM_TRAIN_EXAMPLES, Config.NUM_TRAIN_EXAMPLES + Config.NUM_EVAL_EXAMPLES))


--- Preparing SQuAD Dataset for Distillation ---

--- Loading Teacher Model ---


modeling_phi3.py: 0.00B [00:00, ?B/s]

A new version of the following files was downloaded from https://huggingface.co/microsoft/Phi-3-mini-4k-instruct:
- modeling_phi3.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.


model.safetensors.index.json: 0.00B [00:00, ?B/s]

Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/2.67G [00:00<?, ?B/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/4.97G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/181 [00:00<?, ?B/s]

README.md: 0.00B [00:00, ?B/s]

train-00000-of-00001.parquet:   0%|          | 0.00/14.5M [00:00<?, ?B/s]

validation-00000-of-00001.parquet:   0%|          | 0.00/1.82M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/87599 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/10570 [00:00<?, ? examples/s]

Generating teacher outputs and saving to individual files...


100%|██████████| 1000/1000 [16:10<00:00,  1.03it/s]


Saving distillation manifest to ./distillation_squad_data_lazy...





Saving the dataset (0/1 shards):   0%|          | 0/1000 [00:00<?, ? examples/s]

In [None]:
# ==============================================================================
# 5. CUSTOM DATA COLLATOR
# ==============================================================================
@dataclass
class DistillationDataCollator:
    tokenizer: AutoTokenizer
    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
        prompts = [f["prompt"] for f in features]
        answers = [f["ground_truth_answer"] for f in features]

        # Tokenize student inputs
        full_texts = [p + a + self.tokenizer.eos_token for p, a in zip(prompts, answers)]
        student_inputs = self.tokenizer(
            full_texts,
            padding=True,
            truncation=True,
            max_length=1024,
            return_tensors="pt"
        )

        # Create labels, ignoring the prompt part for the loss calculation
        labels = student_inputs["input_ids"].clone()
        prompt_tokens_list = self.tokenizer(prompts, padding=False, truncation=True, max_length=1024)["input_ids"]
        for i, prompt_tokens in enumerate(prompt_tokens_list):
            labels[i, :len(prompt_tokens)] = -100 # -100 is the ignore_index for CrossEntropyLoss

        # Load teacher outputs from disk
        teacher_logits = [torch.load(f["teacher_logits_path"]) for f in features]
        teacher_hidden_states = [
            {key: torch.load(path) for key, path in f["teacher_hs_paths"].items()}
            for f in features
        ]

        return {
            "input_ids": student_inputs["input_ids"],
            "attention_mask": student_inputs["attention_mask"],
            "labels": labels,
            "teacher_logits": teacher_logits,
            "teacher_hidden_states": teacher_hidden_states
        }

In [None]:
# ==============================================================================
# 6. CUSTOM DISTILLATION TRAINER
# ==============================================================================
class DistillationTrainer(Trainer):
    def __init__(self, *args, **kwargs):
        self.teacher_config = kwargs.pop("teacher_config")
        super().__init__(*args, **kwargs)

        # Regular setup
        self.alpha_hard = Config.ALPHA_HARD_LABEL
        self.beta_hidden = Config.BETA_HIDDEN_STATE
        self.normalize_hidden = Config.NORMALIZE_HIDDEN_STATES
        self.mse_loss_fn = nn.MSELoss(reduction='mean')

        # Initialize projection layers for matching hidden state dimensions
        self.projections = nn.ModuleDict()
        student_dim = self.model.config.hidden_size
        teacher_dim = self.teacher_config.hidden_size
        print(f"Student hidden dim: {student_dim}, Teacher hidden dim: {teacher_dim}")

        if student_dim != teacher_dim:
            for key in Config.LAYER_MAPPING.keys():
                print(f"Creating projection for {key} to map {student_dim} -> {teacher_dim}...")
                self.projections[key] = nn.Linear(student_dim, teacher_dim).to(
                    device=self.model.device,
                    dtype=torch.bfloat16  # The layer's weights are bfloat16
                )

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        # Pop the custom teacher data. They are not arguments for the student model.
        teacher_logits_list = inputs.pop("teacher_logits")
        teacher_hs_list_of_dicts = inputs.pop("teacher_hidden_states")

        # Get student outputs
        student_outputs = model(**inputs, output_hidden_states=True)

        # 1. Standard cross-entropy loss against ground truth labels
        hard_loss = student_outputs.loss

        # 2. NEW: Hidden State POOLING Loss
        batch_size = inputs["input_ids"].shape[0]
        labels = inputs.get("labels")
        student_hidden_states_all = student_outputs.hidden_states

        total_hidden_state_loss = 0.0
        num_samples_with_answers = 0

        # Loop through each sample in the batch
        for i in range(batch_size):
            # Find the tokens that correspond to the answer for this sample
            answer_mask = (labels[i] != -100)
            if answer_mask.sum() == 0:
                continue

            num_samples_with_answers += 1
            sample_layer_loss = 0.0

            # Loop over the layers we want to match
            for key, (student_layer_idx, _) in Config.LAYER_MAPPING.items():
                # a. Get student hidden states for the answer and POOL them
                student_hs_answer = student_hidden_states_all[student_layer_idx][i][answer_mask]
                student_pooled = student_hs_answer.mean(dim=0)  # Shape: [student_hidden_dim]

                # b. Get teacher hidden states for the answer and POOL them
                teacher_hs_answer = teacher_hs_list_of_dicts[i][key].squeeze(0).to(student_pooled.device)
                teacher_pooled = teacher_hs_answer.mean(dim=0)  # Shape: [teacher_hidden_dim]

                # c. Project the student's pooled representation if dimensions differ
                if key in self.projections:
                    student_pooled_projected = self.projections[key](student_pooled.to(torch.bfloat16))
                else:
                    student_pooled_projected = student_pooled

                # d. Cast to float32 for stable loss calculation
                student_for_loss = student_pooled_projected.to(torch.float32)
                teacher_for_loss = teacher_pooled.to(torch.float32)

                # e. Normalize if required (acts on the single pooled vector)
                if self.normalize_hidden:
                    student_for_loss = F.normalize(student_for_loss, p=2, dim=0)
                    teacher_for_loss = F.normalize(teacher_for_loss, p=2, dim=0)

                # f. Compute MSE loss for this layer
                sample_layer_loss += self.mse_loss_fn(student_for_loss, teacher_for_loss)

            # Average the loss over the matched layers for this sample
            if len(Config.LAYER_MAPPING) > 0:
                total_hidden_state_loss += (sample_layer_loss / len(Config.LAYER_MAPPING))

        # Average the hidden state loss over the samples in the batch
        final_hidden_state_loss = total_hidden_state_loss / num_samples_with_answers if num_samples_with_answers > 0 else torch.tensor(0.0, device=hard_loss.device)

        # 3. Final Combined Loss
        loss = (self.alpha_hard * hard_loss) + (self.beta_hidden * final_hidden_state_loss)

        return (loss, student_outputs) if return_outputs else loss

In [None]:
# ==============================================================================
# 7. SETUP AND RUN TRAINING
# ==============================================================================
training_args = TrainingArguments(
    output_dir=Config.OUTPUT_DIR,
    per_device_train_batch_size=Config.BATCH_SIZE,
    per_device_eval_batch_size=Config.BATCH_SIZE,
    gradient_accumulation_steps=Config.GRADIENT_ACCUMULATION_STEPS,
    learning_rate=Config.LEARNING_RATE,
    num_train_epochs=Config.NUM_TRAIN_EPOCHS,
    gradient_checkpointing=Config.USE_GRADIENT_CHECKPOINTING,
    gradient_checkpointing_kwargs={'use_reentrant': False},
    bf16=torch.cuda.is_available() and torch.cuda.is_bf16_supported(),
    fp16=False,
    logging_steps=Config.LOGGING_STEPS,
    save_strategy="epoch",
    eval_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    report_to="tensorboard",
    remove_unused_columns=False,
)

data_collator = DistillationDataCollator(tokenizer=student_tokenizer)

trainer = DistillationTrainer(
    model=student_model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    tokenizer=student_tokenizer,
    data_collator=data_collator,
    teacher_config=teacher_config,
)

  super().__init__(*args, **kwargs)
No label_names provided for model class `PeftModelForCausalLM`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.


Student hidden dim: 2048, Teacher hidden dim: 3072
Creating projection for layer_11 to map 2048 -> 3072...
Creating projection for layer_21 to map 2048 -> 3072...


In [None]:
print("\nStarting student model distillation training...")
trainer.train()
print("Training complete!")


Starting student model distillation training...


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.


Epoch,Training Loss,Validation Loss
1,0.7465,0.057105
2,0.4755,0.048409
3,0.4086,0.047822


Training complete!


In [None]:
# ==============================================================================
# 8. SAVE FINAL MODEL AND EVALUATE
# ==============================================================================

best_adapter_path = '/content/distillation_output_phi3_tinyllama/checkpoint-150'

# Load the base student model again to merge the adapter
base_model = AutoModelForCausalLM.from_pretrained(Config.STUDENT_MODEL_NAME, torch_dtype=torch.bfloat16, device_map="auto", trust_remote_code=True)
distilled_model = PeftModel.from_pretrained(base_model, best_adapter_path)
distilled_model = distilled_model.merge_and_unload()
distilled_model.eval()
print("LoRA adapter merged successfully.")

def run_inference(prompt, model, tokenizer):
    # For TinyLlama, the chat template might differ, this is a generic way
    messages = [{"role": "user", "content": prompt}]
    inference_prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(inference_prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(**inputs, max_new_tokens=50, num_beams=4, early_stopping=True, pad_token_id=tokenizer.eos_token_id)
    full_response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    # A simple way to extract just the new generation
    return full_response.split(prompt)[-1].strip()

test_prompts = [
    "Based on the following context, please answer the question.\n\nContext: The skin is the largest organ of the body, with a total area of about 20 square feet. The skin protects us from microbes and the elements, helps regulate body temperature, and permits the sensations of touch, heat, and cold.\n\nQuestion: what is the largest organ in the human body?",
    "Based on the following context, please answer the question.\n\nContext: Tim Berners-Lee, a British scientist, invented the World Wide Web (WWW) in 1989 while working at CERN. His vision of a global, interconnected information system revolutionized how we access and share information.\n\nQuestion: who is credited with inventing the world wide web?",
]

print("\n--- Running Inference for Comparison ---")
for prompt in test_prompts:
    print(f"\nPROMPT:\n{prompt}")
    print("-" * 30)
    distilled_response = run_inference(prompt, distilled_model, student_tokenizer)
    print(f"RESPONSE (Distilled Student):\n{distilled_response}")
    print("=" * 50)

LoRA adapter merged successfully.

--- Running Inference for Comparison ---

PROMPT:
Based on the following context, please answer the question.

Context: The skin is the largest organ of the body, with a total area of about 20 square feet. The skin protects us from microbes and the elements, helps regulate body temperature, and permits the sensations of touch, heat, and cold.

Question: what is the largest organ in the human body?
------------------------------
RESPONSE (Distilled Student):
<|assistant|>
skin

PROMPT:
Based on the following context, please answer the question.

Context: Tim Berners-Lee, a British scientist, invented the World Wide Web (WWW) in 1989 while working at CERN. His vision of a global, interconnected information system revolutionized how we access and share information.

Question: who is credited with inventing the world wide web?
------------------------------
RESPONSE (Distilled Student):
<|assistant|>
Tim Berners-Lee


In [None]:
!pip install -q evaluate

In [None]:
# ==============================================================================
# 9. COMPREHENSIVE EVALUATION
# ==============================================================================
import pandas as pd
import time
import evaluate
from datasets import load_dataset

# Ensure previous models are cleared from memory
# del base_model
gc.collect()
torch.cuda.empty_cache()

print("\n--- Starting Comprehensive Evaluation ---")

# --- Configuration for Evaluation ---
EVAL_DATASET_NAME = "squad"
EVAL_SPLIT = "validation[:100]" # Using a 100-example slice of the validation set for speed
EVAL_BATCH_SIZE = 4 # Batch size for performance measurement

# --- Helper Functions ---

def get_model_size_info(model, model_name):
    """Calculates model parameters and size on disk."""
    param_count = sum(p.numel() for p in model.parameters()) / 1e6  # in millions
    temp_dir = f"./temp_{model_name}"
    model.save_pretrained(temp_dir)
    disk_size = sum(f.stat().st_size for f in os.scandir(temp_dir) if f.is_file()) / (1024 ** 2) # in MB
    # Clean up
    import shutil
    shutil.rmtree(temp_dir)
    return param_count, disk_size

def measure_inference_metrics(model, tokenizer, dataset):
    """Measures latency, throughput, and peak GPU memory."""
    latencies = []
    model.eval()

    # Warm-up run
    prompt = create_chat_prompt(tokenizer, dataset[0]['question'], dataset[0]['context'])
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    # Add use_cache=False for models that might have issues with dynamic caching during evaluation
    _ = model.generate(**inputs, max_new_tokens=50, pad_token_id=tokenizer.eos_token_id, use_cache=False)

    # Measure memory
    torch.cuda.reset_peak_memory_stats(model.device)
    peak_memory_before = torch.cuda.max_memory_allocated(model.device)

    start_time = time.time()
    for example in tqdm(dataset, desc="Measuring speed"):
        prompt = create_chat_prompt(tokenizer, example['question'], example['context'])
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

        iter_start = time.time()
        # Add use_cache=False for models that might have issues with dynamic caching during evaluation
        _ = model.generate(**inputs, max_new_tokens=50, pad_token_id=tokenizer.eos_token_id, use_cache=False)
        iter_end = time.time()
        latencies.append(iter_end - iter_start)

    end_time = time.time()

    peak_memory_after = torch.cuda.max_memory_allocated(model.device)
    peak_memory_used = (peak_memory_after - peak_memory_before) / (1024 ** 2) # in MB

    avg_latency = np.mean(latencies) * 1000  # in ms
    throughput = len(dataset) / (end_time - start_time)  # examples/sec

    return avg_latency, throughput, peak_memory_used

def evaluate_qa_performance(model, tokenizer, dataset):
    """Computes Exact Match and F1 Score for Question Answering (Robust Version)."""
    squad_metric = evaluate.load("squad")
    predictions = []
    references = []

    for example in tqdm(dataset, desc="Evaluating QA Performance"):
        question = example['question']
        context = example['context']
        prompt = create_chat_prompt(tokenizer, question, context)

        inputs = tokenizer(prompt, return_tensors="pt", max_length=1024, truncation=True).to(model.device)
        input_token_len = inputs.input_ids.shape[1]

        # Add use_cache=False for models that might have issues with dynamic caching during evaluation
        outputs = model.generate(**inputs, max_new_tokens=50, pad_token_id=tokenizer.eos_token_id, use_cache=False)

        # More robust way to get the generated text, avoiding string splitting
        generated_tokens = outputs[0, input_token_len:]
        prediction_text = tokenizer.decode(generated_tokens, skip_special_tokens=True)

        # Handle cases where the model might not generate anything
        if not prediction_text:
            prediction_text = " " # SQuAD metric expects non-empty predictions

        predictions.append({'id': example['id'], 'prediction_text': prediction_text})
        references.append({'id': example['id'], 'answers': example['answers']})

    results = squad_metric.compute(predictions=predictions, references=references)
    return results['exact_match'], results['f1']


--- Starting Comprehensive Evaluation ---


In [None]:
# 1. Load evaluation dataset
eval_dataset = load_dataset(EVAL_DATASET_NAME, split=EVAL_SPLIT)

# 2. Evaluate the Distilled Student Model
print("\n--- Evaluating Distilled Student Model ---")
student_model = distilled_model # Use the merged model from the previous step
student_tokenizer = AutoTokenizer.from_pretrained(Config.STUDENT_MODEL_NAME)
if student_tokenizer.pad_token is None: student_tokenizer.pad_token = student_tokenizer.eos_token

student_params, student_disk_size = get_model_size_info(student_model, "student")


--- Evaluating Distilled Student Model ---


NameError: name 'create_chat_prompt' is not defined

In [None]:
student_latency, student_throughput, student_mem = measure_inference_metrics(student_model, student_tokenizer, eval_dataset)
student_em, student_f1 = evaluate_qa_performance(student_model, student_tokenizer, eval_dataset)

# 3. Clear memory and load the original Teacher model
del student_model
gc.collect()
torch.cuda.empty_cache()

Measuring speed: 100%|██████████| 100/100 [01:39<00:00,  1.01it/s]
Evaluating QA Performance: 100%|██████████| 100/100 [01:37<00:00,  1.02it/s]


In [None]:
print("\n--- Evaluating Original Teacher Model ---")
teacher_model, teacher_tokenizer = load_teacher_model() # This function loads the 4-bit quantized model

teacher_params, teacher_disk_size = get_model_size_info(teacher_model, "teacher")
teacher_latency, teacher_throughput, teacher_mem = measure_inference_metrics(teacher_model, teacher_tokenizer, eval_dataset)
teacher_em, teacher_f1 = evaluate_qa_performance(teacher_model, teacher_tokenizer, eval_dataset)


--- Evaluating Original Teacher Model ---

--- Loading Teacher Model ---




Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Measuring speed: 100%|██████████| 100/100 [24:54<00:00, 14.94s/it]
Evaluating QA Performance: 100%|██████████| 100/100 [24:53<00:00, 14.94s/it]


In [None]:
# 4. Compile and Display Results
results_data = {
    "Metric": [
        "Parameters (Millions)", "Disk Size (MB)", "Avg. Latency (ms/ex)",
        "Throughput (ex/sec)", "Peak VRAM (MB)", "Exact Match (%)", "F1 Score (%)"
    ],
    "Teacher (Phi-3-mini)": [
        f"{teacher_params:.2f}", f"{teacher_disk_size:.2f}", f"{teacher_latency:.2f}",
        f"{teacher_throughput:.2f}", f"{teacher_mem:.2f}", f"{teacher_em:.2f}", f"{teacher_f1:.2f}"
    ],
    "Student (Distilled TinyLlama)": [
        f"{student_params:.2f}", f"{student_disk_size:.2f}", f"{student_latency:.2f}",
        f"{student_throughput:.2f}", f"{student_mem:.2f}", f"{student_em:.2f}", f"{student_f1:.2f}"
    ]
}

results_df = pd.DataFrame(results_data)
print("\n\n" + "="*50)
print("          VIGOROUS EVALUATION RESULTS")
print("="*50)
print(results_df.to_string(index=False))

# Calculate and print optimization achieved
param_reduction = 100 * (1 - student_params / teacher_params)
speed_increase = 100 * (student_throughput / teacher_throughput - 1)
f1_retention = 100 * (student_f1 / teacher_f1)

print("\n--- Summary of Optimization ---")
print(f"Parameter Reduction: {param_reduction:.2f}%")
print(f"Inference Speed-up: {speed_increase:.2f}%")
print(f"F1 Score Retention: {f1_retention:.2f}% of teacher's performance")
print("="*50)



          VIGOROUS EVALUATION RESULTS
               Metric Teacher (Phi-3-mini) Student (Distilled TinyLlama)
Parameters (Millions)              2009.14                       1100.05
       Disk Size (MB)              2320.30                       2098.20
 Avg. Latency (ms/ex)             14940.72                        990.10
  Throughput (ex/sec)                 0.07                          1.01
       Peak VRAM (MB)               221.01                         33.32
      Exact Match (%)                 1.00                         81.00
         F1 Score (%)                34.04                         83.40

--- Summary of Optimization ---
Parameter Reduction: 45.25%
Inference Speed-up: 1403.05%
F1 Score Retention: 244.97% of teacher's performance
