# Lab 5 – Distilling a Pre-Trained LLM with Unsloth (SQuAD)

> **⚠️ IMPORTANT**: This lab requires **Google Colab with GPU enabled**
> - Go to Runtime → Change runtime type → GPU (T4 or better)
> - Unsloth requires CUDA and will not work on Mac/Windows locally
> - See `COLAB_SETUP.md` for detailed setup instructions

In this lab, you will perform **model distillation** using Unsloth. Distillation allows you to compress a large "teacher" model into a smaller "student" model while retaining much of the original model's performance. We'll use the SQuAD dataset for a question-answering task to illustrate this process.

## Why Distillation? The Knowledge Transfer Problem

**The Challenge:**
- 🏫 **Large Models**: GPT-4, LLaMA-70B, Claude-3 are incredibly powerful but HUGE
- 💰 **Deployment Costs**: Large models = expensive inference, high memory requirements
- 📱 **Edge Deployment**: Can't run 70B models on phones, edge devices, or in real-time
- ⚡ **Speed Requirements**: Production systems need fast, responsive models

**The Solution - Knowledge Distillation:**
- 🎓 **Teacher Model**: Large, powerful model (e.g., 7B parameters)
- 🎓 **Student Model**: Smaller, faster model (e.g., 1B parameters)
- 🧠 **Knowledge Transfer**: Student learns from teacher's "soft" predictions
- ⚖️ **Trade-off**: Slight accuracy loss for massive speed/memory gains

**Real-World Applications:**
- 📱 **Mobile Apps**: On-device features (keyboards, assistants, offline translation) need small models, which are often distilled from larger ones
- 🚗 **Autonomous Vehicles**: Real-time decision making requires fast models
- 💬 **Customer Service**: Chatbots need to respond quickly
- 🔍 **Search Engines**: Instant results require optimized models

## The Distillation Process

**Step 1: Teacher Knowledge**
- Large model makes predictions with "soft" probabilities
- Example: [0.7, 0.2, 0.1] instead of [1, 0, 0] (hard labels)

**Step 2: Student Learning**
- Small model learns to mimic teacher's soft predictions
- Uses temperature scaling to soften both distributions, so low-probability classes still carry signal
- Combines teacher knowledge with ground truth labels

**Step 3: Deployment**
- Student model is much smaller and faster
- Retains much of the teacher's task performance, usually at the cost of some accuracy
- Perfect for production deployment
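
To make the "soft" predictions and temperature scaling above concrete, here is a minimal sketch in plain Python. The logits are made-up values chosen so that the distribution at temperature 1 is close to the `[0.7, 0.2, 0.1]` example; the function name `softmax_with_temperature` is illustrative and not part of Unsloth or PyTorch.

```python
import math

def softmax_with_temperature(logits, temperature=1.0):
    """Convert raw scores to probabilities, softened by `temperature`."""
    scaled = [x / temperature for x in logits]
    peak = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]

# Made-up teacher scores for three candidate tokens
teacher_logits = [2.0, 0.75, 0.0]

# Higher temperatures flatten the distribution, exposing how the teacher
# ranks the "wrong" options instead of only reporting the top choice
for t in (1.0, 2.0, 4.0):
    probs = softmax_with_temperature(teacher_logits, t)
    print(f"T={t}: {[round(p, 3) for p in probs]}")
```

A hard label would assign all probability to the first option; the softened distribution is what the student is asked to imitate.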

## Objectives

- **Understand the distillation process** and why it's valuable
- **Evaluate baseline performance** of teacher and student models
- Load a pre-trained teacher model and prepare a smaller student model
- Load and preprocess the SQuAD dataset for question answering
- **Implement knowledge distillation** with proper temperature scaling
- Fine-tune the student model with LoRA/QLoRA adapters using Unsloth
- **Compare performance** of the teacher and student after distillation (accuracy and inference speed)
- **Analyze the trade-offs**: How much knowledge is transferred vs lost?

**Note:** Distillation requires significant compute resources. Use Google Colab Pro for faster training, or reduce the dataset size if using free tier.

In [None]:
# Print the pip command Unsloth recommends for this environment
# (the helper script prints the command; it does not install anything itself)
!wget -qO- https://raw.githubusercontent.com/unslothai/unsloth/main/unsloth/_auto_install.py | python -

# Install Unsloth and its companion package from source
!pip install --upgrade pip
!pip install "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"
!pip install "unsloth_zoo @ git+https://github.com/unslothai/unsloth-zoo.git"

# Sanity check that the package imports in this runtime
from unsloth import FastLanguageModel
print("✅ Unsloth installation complete! Restart the runtime before proceeding.")
print("⚠️ IMPORTANT: Use a GPU runtime, not TPU! Unsloth requires a CUDA GPU.")

### Step 1: Load SQuAD dataset

**Documentation:**
- Hugging Face Datasets: https://huggingface.co/docs/datasets/
- Loading datasets: https://huggingface.co/docs/datasets/loading
- SQuAD dataset: https://huggingface.co/datasets/squad
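
As a quick orientation before loading the data, the sketch below shows the field layout of a SQuAD v1.1 record and one way to turn it into a question-answering prompt. The record is made up for illustration (it is not taken from the dataset), and `build_qa_prompt` is a hypothetical helper, not a `datasets` or Unsloth function.

```python
def build_qa_prompt(example, include_answer=False):
    """Format one SQuAD-style record as a 'Question / Context / Answer' prompt."""
    prompt = (
        f"Question: {example['question'].strip()}\n"
        f"Context: {example['context'].strip()}\n"
        "Answer:"
    )
    if include_answer:
        # Training text needs the gold answer; evaluation prompts stop at "Answer:"
        answers = example["answers"]["text"]
        prompt += " " + (answers[0] if answers else "")
    return prompt

# Made-up record that mirrors the SQuAD v1.1 fields
example = {
    "id": "demo-0",
    "title": "Demo",
    "context": "Unsloth is a library for fine-tuning language models.",
    "question": "What is Unsloth?",
    "answers": {
        "text": ["a library for fine-tuning language models"],
        "answer_start": [11],  # character offset of the answer inside the context
    },
}

print(build_qa_prompt(example))
print("---")
print(build_qa_prompt(example, include_answer=True))
```

Keeping one prompt layout for both training and evaluation means the student is tested on the same format it was trained on.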


In [None]:
# 1️⃣ Load SQuAD dataset using the `datasets` library

from datasets import load_dataset

# Load the train and validation splits of SQuAD v1.1 (use only a subset for quicker experiments)
dataset = load_dataset('squad', split='train[:10%]')
dataset_val = load_dataset('squad', split='validation[:5%]')

# Inspect a sample
print(dataset[0])

# Tokenization function for question-answering tasks
from transformers import AutoTokenizer

teacher_model_name = "unsloth/Qwen2.5-7B-Instruct-bnb-4bit"
tokenizer = AutoTokenizer.from_pretrained(teacher_model_name)

max_length = 512

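def preprocess_function(examples):
    # Use the same prompt layout as the evaluation cells and append the gold answer,
    # so the language-modeling loss also covers the answer tokens
    texts = []
    for q, c, a in zip(examples['question'], examples['context'], examples['answers']):
        answer = a['text'][0] if a['text'] else ""
        texts.append(f"Question: {q.strip()}\nContext: {c.strip()}\nAnswer: {answer}")
    return tokenizer(texts, max_length=max_length, padding='max_length', truncation=True)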

# Apply tokenization
train_dataset = dataset.map(preprocess_function, batched=True)
val_dataset = dataset_val.map(preprocess_function, batched=True)

print("Tokenized dataset ready for training.")


### Step 2: Setup teacher and student models for distillation

**Documentation:**
- Unsloth docs: https://docs.unsloth.ai
- **Example Notebooks**:
  - [Qwen 2.5 (7B) Fine-tuning with LoRA](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Qwen2.5_(7B)-Alpaca.ipynb)
  - [Qwen 2.5 Conversational Style](https://colab.research.google.com/drive/1qN1CEalC70EO1wGKhNxs1go1W9So61R5?usp=sharing)
  - [All Unsloth notebooks](https://github.com/unslothai/notebooks)
- PEFT LoRA: https://huggingface.co/docs/peft/conceptual_guides/lora
- LoraConfig: https://huggingface.co/docs/peft/package_reference/lora
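
To see why LoRA trains so few parameters, the arithmetic can be sketched directly. For a linear layer with `d_in` inputs and `d_out` outputs, LoRA adds two low-rank matrices of shapes `r × d_in` and `d_out × r`, and scales their product by `lora_alpha / r` (PEFT's default scaling). The hidden size below is a hypothetical value for illustration, not the actual dimension of either Qwen model.

```python
def lora_param_count(d_in, d_out, r):
    """Trainable parameters LoRA adds to one linear layer (A: r x d_in, B: d_out x r)."""
    return r * (d_in + d_out)

def lora_scaling(lora_alpha, r):
    """Factor applied to the low-rank update before it is added to the frozen weight."""
    return lora_alpha / r

hidden = 2048            # hypothetical hidden size, for illustration only
r, lora_alpha = 16, 32   # same rank and alpha as the LoraConfig in this lab

full = hidden * hidden                      # parameters in one square projection
added = lora_param_count(hidden, hidden, r)

print(f"Full projection: {full:,} parameters")
print(f"LoRA adapter:    {added:,} parameters ({added / full:.2%} of the layer)")
print(f"Update scaling:  {lora_scaling(lora_alpha, r)}")
```

Raising `r` increases adapter capacity linearly, while `lora_alpha` only changes how strongly the learned update is weighted.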


In [None]:
# 2️⃣ Setup teacher and student models for distillation

# CRITICAL: Import unsloth before other Hugging Face libraries so its patches are applied
from unsloth import FastLanguageModel
import torch

# Load the teacher model
teacher_model, _ = FastLanguageModel.from_pretrained(
    model_name=teacher_model_name,
    dtype=torch.float16,
    device_map="auto"
)

# Define your student model architecture; choose a smaller model
student_model_name = "unsloth/Qwen2.5-3B-Instruct-bnb-4bit"  # Smaller Qwen model for student
student_model, _ = FastLanguageModel.from_pretrained(
    model_name=student_model_name,
    dtype=torch.float16,
    device_map="auto"
)

# Apply LoRA adapters to the student model (this cell uses PEFT directly;
# Unsloth also provides its own wrapper, FastLanguageModel.get_peft_model)
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

# Prepare student model for efficient training
student_model = prepare_model_for_kbit_training(student_model)

# Configure LoRA
lora_config = LoraConfig(
    r=16,                          # LoRA rank
    lora_alpha=32,                 # LoRA scaling factor
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],  # Attention layers
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM"
)

# Apply LoRA to student model
student_model = get_peft_model(student_model, lora_config)
student_model.print_trainable_parameters()

# Collate function that also builds labels for the language-modeling loss
def collate_fn(batch):
    input_ids = torch.tensor([item['input_ids'] for item in batch])
    attention_mask = torch.tensor([item['attention_mask'] for item in batch])
    # For language modeling, labels are the input_ids (shifted internally by the model);
    # padding positions are set to -100 so the loss ignores them
    labels = input_ids.clone()
    labels[attention_mask == 0] = -100
    return {
        'input_ids': input_ids,
        'attention_mask': attention_mask,
        'labels': labels
    }

# Create dataloaders
from torch.utils.data import DataLoader
train_dataloader = DataLoader(train_dataset, batch_size=4, shuffle=True, collate_fn=collate_fn)

print("✅ Teacher and student models loaded and configured!")
print(f"Teacher: {teacher_model_name}")
print(f"Student: {student_model_name}")

# Configure the student model for training
student_model.config.use_cache = False  # The KV cache is not needed during training
student_model.gradient_checkpointing_enable()  # Trade extra compute for lower memory use

# Reuse the EOS token id as the padding id for generation
student_model.config.pad_token_id = student_model.config.eos_token_id

# Set model to training mode
student_model.train()

# Verify model configuration
print("✅ Student model configured for distillation training")
print(f"Model device: {next(student_model.parameters()).device}")
print(f"Model dtype: {next(student_model.parameters()).dtype}")

In [None]:
# 3️⃣ Evaluate baseline performance BEFORE distillation

import time

print("📊 BASELINE EVALUATION (Before Distillation)")
print("=" * 50)

# Function to evaluate model performance on SQuAD with generation
def evaluate_squad_performance(model, dataset, tokenizer, model_name, num_samples=50):
    """Evaluate model performance on SQuAD question answering using generation"""
    model.eval()
    correct = 0
    total = 0
    inference_times = []
    total_tokens_generated = 0
    
    print(f"\n🔍 Evaluating {model_name}...")
    print(f"   Evaluating on {num_samples} validation samples...")
    
    with torch.no_grad():
        for idx in range(min(num_samples, len(dataset))):
            example = dataset[idx]
            
            # Get question, context, and answer
            question = example['question']
            context = example['context']
            ground_truth = example['answers']['text'][0] if example['answers']['text'] else ""
            
            # Create prompt for QA
            prompt = f"Question: {question}\nContext: {context}\nAnswer:"
            inputs = tokenizer(prompt, return_tensors="pt", max_length=512, truncation=True)
            inputs = {k: v.to(model.device) for k, v in inputs.items()}
            
            # Measure inference time
            start_time = time.time()
            
            # Generate answer
            outputs = model.generate(
                **inputs,
                max_new_tokens=50,
                do_sample=False,
                pad_token_id=tokenizer.eos_token_id
            )
            
            elapsed = time.time() - start_time
            inference_times.append(elapsed)
            
            # Count tokens generated
            tokens_generated = outputs.shape[1] - inputs['input_ids'].shape[1]
            total_tokens_generated += tokens_generated
            
            # Decode only the newly generated tokens (the prompt may have been truncated)
            prompt_len = inputs['input_ids'].shape[1]
            generated_answer = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True).strip().lower()

            # Loose substring match (a rough proxy, not the official SQuAD F1);
            # an empty prediction or empty reference never counts as correct
            truth = ground_truth.strip().lower()
            if truth and generated_answer and (truth in generated_answer or generated_answer in truth):
                correct += 1
            
            total += 1
    
    accuracy = (correct / total * 100) if total > 0 else 0
    avg_inference_time = sum(inference_times) / len(inference_times) if inference_times else 0
    total_time = sum(inference_times)
    tokens_per_sec = total_tokens_generated / total_time if total_time > 0 else 0
    
    print(f"   ✓ Evaluated {total} samples: {correct} correct ({accuracy:.1f}%)")
    
    return {
        'accuracy': accuracy,
        'avg_inference_time': avg_inference_time,
        'tokens_per_sec': tokens_per_sec,
        'samples_per_sec': total / total_time if total_time > 0 else 0,
        'total_samples': total,
        'correct': correct
    }

# Create validation dataloader for evaluation
val_dataloader = DataLoader(val_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn)

# Evaluate teacher model baseline
print("🎓 Evaluating Teacher Model (Large, Powerful)...")
teacher_results = evaluate_squad_performance(teacher_model, dataset_val, tokenizer, "Teacher Model (Before Distillation)", num_samples=50)

print(f"\n📈 TEACHER MODEL RESULTS:")
print(f"  - Accuracy: {teacher_results['accuracy']:.1f}% ({teacher_results['correct']}/{teacher_results['total_samples']} correct)")
print(f"  - Avg inference time: {teacher_results['avg_inference_time']*1000:.1f}ms per sample")
print(f"  - Throughput: {teacher_results['tokens_per_sec']:.1f} tokens/sec")
print(f"  - Samples/second: {teacher_results['samples_per_sec']:.2f}")

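# Calculate teacher model size (use base model, not PEFT wrapper)
# Caveat: 4-bit bitsandbytes weights are stored packed, so numel() under-counts the
# logical parameters of a quantized model; treat the numbers below as rough estimates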
base_teacher = teacher_model.base_model if hasattr(teacher_model, 'base_model') else teacher_model
teacher_params = sum(p.numel() for p in base_teacher.parameters())
teacher_trainable = sum(p.numel() for p in base_teacher.parameters() if p.requires_grad)

print(f"\n💾 TEACHER MODEL SIZE:")
print(f"  - Total parameters: {teacher_params:,}")
print(f"  - Trainable parameters: {teacher_trainable:,}")
print(f"  - Model size: ~{teacher_params * 2 / 1024**2:.1f} MB (FP16)")

# Evaluate student model baseline (before distillation)
print("\n🎓 Evaluating Student Model (Small, Fast)...")
student_results = evaluate_squad_performance(student_model, dataset_val, tokenizer, "Student Model (Before Distillation)", num_samples=50)

print(f"\n📈 STUDENT MODEL RESULTS (Before Distillation):")
print(f"  - Accuracy: {student_results['accuracy']:.1f}% ({student_results['correct']}/{student_results['total_samples']} correct)")
print(f"  - Avg inference time: {student_results['avg_inference_time']*1000:.1f}ms per sample")
print(f"  - Throughput: {student_results['tokens_per_sec']:.1f} tokens/sec")
print(f"  - Samples/second: {student_results['samples_per_sec']:.2f}")

# Calculate student model size (use base model to get actual size)
# Caveat: as with the teacher, packed 4-bit weights make numel() an under-estimate
base_student = student_model.base_model if hasattr(student_model, 'base_model') else student_model
student_total_params = sum(p.numel() for p in base_student.parameters())
student_trainable = sum(p.numel() for p in student_model.parameters() if p.requires_grad)

print(f"\n💾 STUDENT MODEL SIZE:")
print(f"  - Total parameters: {student_total_params:,}")
print(f"  - Trainable parameters (LoRA): {student_trainable:,}")
print(f"  - Model size: ~{student_total_params * 2 / 1024**2:.1f} MB (FP16)")

# Calculate size and speed differences
size_ratio = student_total_params / teacher_params
speed_ratio = teacher_results['avg_inference_time'] / student_results['avg_inference_time']
accuracy_diff = teacher_results['accuracy'] - student_results['accuracy']

print(f"\n⚖️ BASELINE COMPARISON:")
print(f"  - Size ratio: {size_ratio:.2f}x ({student_total_params/1e9:.1f}B vs {teacher_params/1e9:.1f}B params)")
print(f"  - Size reduction: {(1-size_ratio)*100:.1f}%")
print(f"  - Speed ratio: {speed_ratio:.2f}x {'faster' if speed_ratio > 1 else 'slower'}")
print(f"  - Accuracy difference: {accuracy_diff:.1f} percentage points")

if speed_ratio < 1:
    print(f"\n⚠️ NOTE: Student is slower than teacher!")
    print(f"   This can happen because:")
    print(f"   - LoRA adapters add computational overhead")
    print(f"   - 4-bit quantization has decode overhead")
    print(f"   - First-run/warmup effects")
    print(f"   - Memory bandwidth limitations")
    print(f"   💡 Real speedup comes after: (1) proper warmup, (2) merging LoRA weights")

print("\n🎯 Now we'll distill knowledge from teacher to student...")
print("   The goal: Keep student's speed advantage while improving accuracy!")
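
The warmup effect mentioned in the note above can be separated from steady-state latency with a small timing helper. This is a minimal sketch in plain Python: `benchmark` and `fake_generate` are hypothetical names, and the workload is a stand-in for a real `model.generate(...)` call. On a GPU you would also call `torch.cuda.synchronize()` before reading the clock so that queued kernels are included in the measurement.

```python
import statistics
import time

def benchmark(fn, warmup=2, repeats=5):
    """Time `fn` after discarding a few warmup calls."""
    for _ in range(warmup):
        fn()  # warmup runs are executed but not timed
    timings = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn()
        timings.append(time.perf_counter() - start)
    return {
        "mean_s": statistics.mean(timings),
        "median_s": statistics.median(timings),
        "runs": len(timings),
    }

def fake_generate():
    # Stand-in workload; replace with a call such as model.generate(**inputs)
    sum(i * i for i in range(50_000))

stats = benchmark(fake_generate)
print(f"median latency: {stats['median_s'] * 1000:.2f} ms over {stats['runs']} runs")
```

Reporting the median rather than the mean makes the comparison less sensitive to a single slow run.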

In [None]:
# 3️⃣ Knowledge Distillation Training Loop

# Import required modules for training
from tqdm import tqdm
import torch.nn.functional as F

# Put the frozen teacher in eval mode and the student in training mode
teacher_model.eval()
student_model.train()

# Disable the KV cache and enable gradient checkpointing to save memory while training
student_model.config.use_cache = False
student_model.gradient_checkpointing_enable()

print("✅ Models configured for distillation training")

# Set up optimizer
optimizer = torch.optim.AdamW(student_model.parameters(), lr=2e-5)
num_epochs = 2  # Keep it small for demo purposes

# Temperature for distillation
temperature = 2.0
alpha = 0.5  # Weight for distillation loss vs hard target loss

print(f"🎓 Starting knowledge distillation training...")
print(f"Teacher: {teacher_model_name}")
print(f"Student: {student_model_name}")
print(f"Epochs: {num_epochs}, Batch size: 4, Temperature: {temperature}")

# Training loop
for epoch in range(num_epochs):
    epoch_loss = 0
    progress_bar = tqdm(train_dataloader, desc=f"Epoch {epoch + 1}/{num_epochs}")
    
    for batch_idx, batch in enumerate(progress_bar):
        # Move batch to device
        input_ids = batch['input_ids'].to(student_model.device)
        attention_mask = batch['attention_mask'].to(student_model.device)
        
        # Get teacher predictions (no gradient)
        # NOTE: these logits are not used in the loss below; a full distillation
        # objective would compare them with the student's logits (see `temperature` and `alpha`)
        with torch.no_grad():
            teacher_outputs = teacher_model(
                input_ids=input_ids,
                attention_mask=attention_mask
            )
            teacher_logits = teacher_outputs.logits
        
        # Get student predictions
        # NOTE: when labels are passed, Unsloth may return an EmptyLogits placeholder instead of
        # real logits, so this loop falls back to supervised fine-tuning on the dataset text.
        # That is plain fine-tuning on hard labels, not logit-level distillation.
        student_outputs = student_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=batch['labels'].to(student_model.device)  # Labels built by collate_fn
        )
        
        # Supervised fine-tuning (cross-entropy) loss
        loss = student_outputs.loss
        
        # Backpropagation
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        
        epoch_loss += loss.item()
        progress_bar.set_postfix({'loss': f'{loss.item():.4f}'})
        
        # Limit batches for demo
        if batch_idx >= 50:
            break
    
    avg_loss = epoch_loss / min(len(train_dataloader), 51)
    print(f"Epoch {epoch + 1} completed. Average loss: {avg_loss:.4f}")

print("✓ Distillation training complete!")
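
For reference, the classic distillation objective that `temperature` and `alpha` are meant for combines a soft-target term (KL divergence between temperature-softened teacher and student distributions, scaled by `temperature ** 2`) with the usual hard-label cross-entropy. The sketch below computes it for a single token position in plain Python so the arithmetic is easy to follow; the logits are made-up values, and `kd_loss` is an illustrative name rather than an Unsloth or PyTorch function.

```python
import math

def softmax(logits, temperature=1.0):
    scaled = [x / temperature for x in logits]
    peak = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def kd_loss(student_logits, teacher_logits, label, temperature=2.0, alpha=0.5):
    """alpha * T^2 * KL(teacher || student) + (1 - alpha) * cross-entropy(student, label)."""
    p_teacher = softmax(teacher_logits, temperature)
    p_student = softmax(student_logits, temperature)
    # Soft-target term: how far the student's softened distribution is from the teacher's
    kl = sum(pt * (math.log(pt) - math.log(ps)) for pt, ps in zip(p_teacher, p_student))
    soft = kl * temperature ** 2  # T^2 keeps gradient magnitudes comparable across temperatures
    # Hard-target term: ordinary cross-entropy at temperature 1
    hard = -math.log(softmax(student_logits)[label])
    return alpha * soft + (1 - alpha) * hard

teacher = [2.0, 0.75, 0.0]        # made-up teacher logits
student_close = [1.8, 0.9, 0.1]   # made-up student that roughly agrees
student_far = [0.0, 0.75, 2.0]    # made-up student that disagrees

print(f"close student: {kd_loss(student_close, teacher, label=0):.4f}")
print(f"far student:   {kd_loss(student_far, teacher, label=0):.4f}")
```

With PyTorch tensors, the soft-target term is commonly written with `F.kl_div` on `F.log_softmax(student_logits / T, dim=-1)` and `F.softmax(teacher_logits / T, dim=-1)`. That only works when the student forward pass returns real logits and both models share a vocabulary dimension.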

### Step 3: Evaluate and compare teacher and student models

**Documentation:**
- SQuAD evaluation metrics: https://huggingface.co/metrics/squad
- Evaluation with Hugging Face: https://huggingface.co/docs/evaluate/


In [None]:
# 3️⃣ Evaluate and compare teacher and student models (Generation Quality)

import time

print("\n📊 Evaluating Teacher and Student Models (Generation Quality)...")

# Prepare evaluation function
def evaluate_model_generation(model, dataloader, model_name, num_samples=20):
    """Evaluate model on a subset of validation data for generation quality"""
    model.eval()
    total_time = 0
    responses = []
    
    print(f"\nEvaluating {model_name}...")
    
    with torch.no_grad():
        for idx, batch in enumerate(dataloader):
            if idx >= num_samples:
                break
                
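            # Batch size is 1: keep only the real (non-padding) tokens so that
            # generation continues from the end of the prompt, not from padding
            keep = batch['attention_mask'][0].bool()
            input_ids = batch['input_ids'][:1, keep].to(model.device)
            attention_mask = torch.ones_like(input_ids)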
            
            # Measure inference time
            start_time = time.time()
            
            outputs = model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_new_tokens=50,
                do_sample=False  # Deterministic for comparison
            )
            
            elapsed = time.time() - start_time
            total_time += elapsed
            
            # Decode response
            response = tokenizer.decode(outputs[0], skip_special_tokens=True)
            responses.append(response)
    
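    num_evaluated = len(responses)  # May be smaller than num_samples if the dataloader is short
    avg_time = total_time / num_evaluated if num_evaluated else 0.0
    # Approximate: assumes every sample generated the full 50 new tokens
    tokens_per_sec = (num_evaluated * 50) / total_time if total_time > 0 else 0.0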
    
    return {
        'avg_inference_time': avg_time,
        'tokens_per_sec': tokens_per_sec,
        'responses': responses
    }

# Create evaluation dataloader
val_dataloader = DataLoader(val_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn)

# Evaluate teacher model generation
teacher_gen_results = evaluate_model_generation(teacher_model, val_dataloader, "Teacher (Qwen 2.5-7B)")

# Evaluate student model generation
student_model.eval()
student_gen_results = evaluate_model_generation(student_model, val_dataloader, "Student (Qwen 2.5-3B + LoRA)")

# Print comparison
print("\n" + "="*60)
print("📈 GENERATION EVALUATION RESULTS")
print("="*60)

print(f"\n🎓 Teacher Model (Qwen 2.5-7B):")
print(f"  - Average inference time: {teacher_gen_results['avg_inference_time']:.3f}s")
print(f"  - Tokens/second: {teacher_gen_results['tokens_per_sec']:.1f}")

print(f"\n🎯 Student Model (Qwen 2.5-3B + LoRA):")
print(f"  - Average inference time: {student_gen_results['avg_inference_time']:.3f}s")
print(f"  - Tokens/second: {student_gen_results['tokens_per_sec']:.1f}")

# Calculate speedup
speedup = teacher_gen_results['avg_inference_time'] / student_gen_results['avg_inference_time']

# Model size comparison
teacher_params = sum(p.numel() for p in teacher_model.parameters())
student_params = sum(p.numel() for p in student_model.parameters() if p.requires_grad)
student_total = sum(p.numel() for p in student_model.parameters())

print(f"\n💾 Model Size:")
print(f"  - Teacher: {teacher_params/1e9:.2f}B parameters")
print(f"  - Student (total): {student_total/1e9:.2f}B parameters")
print(f"  - Student (trainable LoRA): {student_params/1e6:.2f}M parameters")
print(f"  - Size reduction: {(1 - student_total/teacher_params)*100:.1f}%")

print(f"\n⚡ Speed Comparison:")
if speedup > 1:
    print(f"  - Student is {speedup:.2f}x FASTER than teacher")
else:
    print(f"  - Student is {1/speedup:.2f}x SLOWER than teacher")
    print(f"  - ⚠️ NOTE: LoRA adapters add overhead during inference")
    print(f"  - To improve speed, merge LoRA weights into base model")
    print(f"  - Or use model.merge_and_unload() before deployment")

print(f"\n📊 Key Takeaways:")
print(f"  ✓ Student model is {(1 - student_total/teacher_params)*100:.1f}% smaller")
print(f"  ✓ Only {student_params/1e6:.2f}M parameters were trained (efficient!)")
print(f"  ⚠️ LoRA adds inference overhead (merge weights for production)")
print(f"  💡 Real speedup comes after merging LoRA weights")

print("\n✓ Generation evaluation complete!")

In [None]:
# 4️⃣ Evaluate Accuracy on SQuAD and MMLU

print("\n" + "="*60)
print("📊 ACCURACY EVALUATION")
print("="*60)

# ============================================
# Part 1: SQuAD Accuracy (substring match)
# ============================================

print("\n🎯 Evaluating on SQuAD (Question Answering)...")

def evaluate_squad_accuracy(model, dataset, tokenizer, num_samples=50):
    """Evaluate model on SQuAD using simple accuracy metric"""
    model.eval()
    correct = 0
    total = 0
    
    with torch.no_grad():
        for idx in range(min(num_samples, len(dataset))):
            example = dataset[idx]
            
            # Prepare input
            question = example['question']
            context = example['context']
            answer = example['answers']['text'][0] if example['answers']['text'] else ""
            
            # Create prompt
            prompt = f"Question: {question}\nContext: {context}\nAnswer:"
            inputs = tokenizer(prompt, return_tensors="pt", max_length=512, truncation=True)
            inputs = {k: v.to(model.device) for k, v in inputs.items()}
            
            # Generate answer
            outputs = model.generate(
                **inputs,
                max_new_tokens=30,
                do_sample=False,
                pad_token_id=tokenizer.eos_token_id
            )
            
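            # Decode only the newly generated tokens (the prompt may have been truncated)
            prompt_len = inputs['input_ids'].shape[1]
            generated_answer = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True).strip().lower()
            
            # Loose substring match (a rough proxy, not the official SQuAD F1);
            # an empty prediction or empty reference never counts as correct
            truth = answer.strip().lower()
            if truth and generated_answer and (truth in generated_answer or generated_answer in truth):
                correct += 1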
            total += 1
    
    accuracy = (correct / total * 100) if total > 0 else 0
    return accuracy, correct, total

# Evaluate both models on SQuAD
teacher_squad_acc, teacher_correct, teacher_total = evaluate_squad_accuracy(
    teacher_model, dataset_val, tokenizer, num_samples=50
)

student_squad_acc, student_correct, student_total = evaluate_squad_accuracy(
    student_model, dataset_val, tokenizer, num_samples=50
)

print(f"\n📈 SQuAD Results:")
print(f"  Teacher: {teacher_squad_acc:.1f}% ({teacher_correct}/{teacher_total} correct)")
print(f"  Student: {student_squad_acc:.1f}% ({student_correct}/{student_total} correct)")
print(f"  Accuracy gap: {abs(teacher_squad_acc - student_squad_acc):.1f} percentage points")

# ============================================
# Part 2: MMLU Evaluation (zero-shot)
# ============================================

print("\n🧠 Evaluating on MMLU (General Knowledge)...")
print("   Loading MMLU dataset (this may take a moment)...")

try:
    # Load MMLU dataset (no custom loading script is needed for this repository)
    from datasets import load_dataset
    mmlu_dataset = load_dataset("cais/mmlu", "all", split="test")
    
    def evaluate_mmlu(model, tokenizer, num_samples=100):
        """Evaluate model on MMLU multiple choice questions"""
        model.eval()
        correct = 0
        total = 0
        
        with torch.no_grad():
            for idx in range(min(num_samples, len(mmlu_dataset))):
                example = mmlu_dataset[idx]
                
                # Format MMLU question
                question = example['question']
                choices = example['choices']
                correct_answer = example['answer']  # Index of correct answer (0-3)
                
                # Create multiple choice prompt
                prompt = f"Question: {question}\n"
                for i, choice in enumerate(choices):
                    prompt += f"{chr(65+i)}. {choice}\n"
                prompt += "Answer (A/B/C/D):"
                
                # Tokenize and generate
                inputs = tokenizer(prompt, return_tensors="pt", max_length=512, truncation=True)
                inputs = {k: v.to(model.device) for k, v in inputs.items()}
                
                outputs = model.generate(
                    **inputs,
                    max_new_tokens=5,
                    do_sample=False,
                    pad_token_id=tokenizer.eos_token_id
                )
                
                # Extract the answer from the newly generated tokens only; the prompt itself
                # contains "A/B/C/D", which would otherwise always be read as "A"
                prompt_len = inputs['input_ids'].shape[1]
                generated_answer = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True).strip().upper()
                
                # Take the first A, B, C, or D that appears in the first 3 characters
                predicted_letter = next((ch for ch in generated_answer[:3] if ch in "ABCD"), None)
                
                if predicted_letter:
                    predicted_idx = ord(predicted_letter) - ord('A')
                    if predicted_idx == correct_answer:
                        correct += 1
                
                total += 1
        
        accuracy = (correct / total * 100) if total > 0 else 0
        return accuracy, correct, total
    
    # Evaluate both models on MMLU
    teacher_mmlu_acc, teacher_mmlu_correct, teacher_mmlu_total = evaluate_mmlu(
        teacher_model, tokenizer, num_samples=100
    )
    
    student_mmlu_acc, student_mmlu_correct, student_mmlu_total = evaluate_mmlu(
        student_model, tokenizer, num_samples=100
    )
    
    print(f"\n📈 MMLU Results:")
    print(f"  Teacher: {teacher_mmlu_acc:.1f}% ({teacher_mmlu_correct}/{teacher_mmlu_total} correct)")
    print(f"  Student: {student_mmlu_acc:.1f}% ({student_mmlu_correct}/{student_mmlu_total} correct)")
    print(f"  Accuracy gap: {abs(teacher_mmlu_acc - student_mmlu_acc):.1f} percentage points")
    
    mmlu_available = True
    
except Exception as e:
    print(f"\n⚠️ MMLU evaluation skipped: {str(e)}")
    print("   (MMLU requires additional setup or may not be available)")
    mmlu_available = False
    teacher_mmlu_acc = 0
    student_mmlu_acc = 0

# ============================================
# Part 3: Comprehensive Summary
# ============================================

print("\n" + "="*60)
print("📊 COMPREHENSIVE ACCURACY SUMMARY")
print("="*60)

print(f"\n{'Metric':<30} {'Teacher':<15} {'Student':<15} {'Gap':<10}")
print("-" * 70)
print(f"{'SQuAD Accuracy':<30} {teacher_squad_acc:>6.1f}%{'':<8} {student_squad_acc:>6.1f}%{'':<8} {abs(teacher_squad_acc - student_squad_acc):>5.1f}%")

if mmlu_available:
    print(f"{'MMLU Accuracy':<30} {teacher_mmlu_acc:>6.1f}%{'':<8} {student_mmlu_acc:>6.1f}%{'':<8} {abs(teacher_mmlu_acc - student_mmlu_acc):>5.1f}%")
    avg_teacher = (teacher_squad_acc + teacher_mmlu_acc) / 2
    avg_student = (student_squad_acc + student_mmlu_acc) / 2
    print(f"{'Average Accuracy':<30} {avg_teacher:>6.1f}%{'':<8} {avg_student:>6.1f}%{'':<8} {abs(avg_teacher - avg_student):>5.1f}%")

print("-" * 70)

# Knowledge retention calculation
if teacher_squad_acc > 0:
    knowledge_retention = (student_squad_acc / teacher_squad_acc) * 100
    print(f"\n📈 Knowledge Retention: {knowledge_retention:.1f}%")
    print(f"   (Student retained {knowledge_retention:.1f}% of teacher's SQuAD performance)")

print(f"\n💡 Key Insights:")
if student_squad_acc >= teacher_squad_acc * 0.9:
    print(f"  ✅ Excellent: Student retained ≥90% of teacher accuracy")
elif student_squad_acc >= teacher_squad_acc * 0.8:
    print(f"  ✓ Good: Student retained ≥80% of teacher accuracy")
elif student_squad_acc >= teacher_squad_acc * 0.7:
    print(f"  ⚠️ Fair: Student retained ≥70% of teacher accuracy")
else:
    print(f"  ⚠️ Low: Student retained <70% of teacher accuracy")
    print(f"     Consider: More training epochs, higher learning rate, or different architecture")

print(f"\n🎯 Distillation Trade-off Analysis:")
print(f"  • Model size: {(1 - student_total/teacher_params)*100:.1f}% reduction")
print(f"  • Accuracy loss: {abs(teacher_squad_acc - student_squad_acc):.1f} percentage points")
print(f"  • Training efficiency: Only {student_params/1e6:.2f}M parameters trained")

if abs(teacher_squad_acc - student_squad_acc) < 5:
    print(f"\n✅ SUCCESS: Minimal accuracy loss (<5%) with significant size reduction!")
elif abs(teacher_squad_acc - student_squad_acc) < 10:
    print(f"\n✓ GOOD: Acceptable accuracy loss (<10%) for {(1 - student_total/teacher_params)*100:.1f}% size reduction")
else:
    print(f"\n⚠️ CAUTION: Significant accuracy loss (>10%) - may need more training")

print("\n✓ Accuracy evaluation complete!")


In [None]:
# 5️⃣ Evaluate performance AFTER distillation

print("📊 POST-DISTILLATION EVALUATION")
print("=" * 50)

# Evaluate student model after distillation
print("🎓 Evaluating Student Model (After Distillation)...")
student_after_results = evaluate_squad_performance(student_model, dataset_val, tokenizer, "Student Model (After Distillation)", num_samples=50)

print(f"\n📈 STUDENT MODEL RESULTS (After Distillation):")
print(f"  - Accuracy: {student_after_results['accuracy']:.1f}% ({student_after_results['correct']}/{student_after_results['total_samples']} correct)")
print(f"  - Avg inference time: {student_after_results['avg_inference_time']*1000:.1f}ms per sample")
print(f"  - Throughput: {student_after_results['tokens_per_sec']:.1f} tokens/sec")
print(f"  - Samples/second: {student_after_results['samples_per_sec']:.2f}")

# Calculate improvements from distillation
accuracy_improvement = student_after_results['accuracy'] - student_results['accuracy']
speed_ratio = student_after_results['avg_inference_time'] / student_results['avg_inference_time']

print(f"\n🎯 DISTILLATION IMPACT:")
print(f"  - Accuracy change: {'+' if accuracy_improvement >= 0 else ''}{accuracy_improvement:.1f} percentage points")
print(f"  - Speed maintained: {1/speed_ratio:.2f}x (should be ~1.0x, meaning speed unchanged)")
print(f"  - Size unchanged: {student_trainable:,} trainable parameters")

# Comprehensive comparison
print(f"\n📊 COMPREHENSIVE COMPARISON:")
print("=" * 80)
print(f"{'Model':<30} {'Accuracy':<12} {'Latency (ms)':<15} {'Tokens/sec':<12} {'Notes':<20}")
print("=" * 80)
print(f"{'Teacher (7B)':<30} {teacher_results['accuracy']:<12.1f} {teacher_results['avg_inference_time']*1000:<15.1f} {teacher_results['tokens_per_sec']:<12.1f} {'Baseline':<20}")
print(f"{'Student Before (3B)':<30} {student_results['accuracy']:<12.1f} {student_results['avg_inference_time']*1000:<15.1f} {student_results['tokens_per_sec']:<12.1f} {'Before distillation':<20}")
print(f"{'Student After (3B)':<30} {student_after_results['accuracy']:<12.1f} {student_after_results['avg_inference_time']*1000:<15.1f} {student_after_results['tokens_per_sec']:<12.1f} {'After Distillation':<20}")
print("=" * 80)

# Calculate efficiency metrics
teacher_efficiency = teacher_results['accuracy'] / (teacher_results['avg_inference_time'] * 1000)  # accuracy per ms
student_before_efficiency = student_results['accuracy'] / (student_results['avg_inference_time'] * 1000)
student_after_efficiency = student_after_results['accuracy'] / (student_after_results['avg_inference_time'] * 1000)

print(f"\n⚡ EFFICIENCY ANALYSIS:")
print(f"  - Teacher efficiency: {teacher_efficiency:.3f} accuracy points per ms")
print(f"  - Student (before): {student_before_efficiency:.3f} accuracy points per ms")
print(f"  - Student (after): {student_after_efficiency:.3f} accuracy points per ms")
if student_after_efficiency > student_before_efficiency:
    improvement_pct = (student_after_efficiency / student_before_efficiency - 1) * 100
    print(f"  - Efficiency improvement: {improvement_pct:.1f}%")
else:
    print(f"  - Efficiency change: {(student_after_efficiency / student_before_efficiency - 1) * 100:.1f}%")

# Knowledge transfer analysis
if teacher_results['accuracy'] > student_results['accuracy']:
    knowledge_gap = teacher_results['accuracy'] - student_results['accuracy']
    knowledge_transferred = student_after_results['accuracy'] - student_results['accuracy']
    knowledge_retention = (knowledge_transferred / knowledge_gap) * 100 if knowledge_gap > 0 else 0
    print(f"  - Knowledge retention: {knowledge_retention:.1f}% of teacher's advantage captured")
else:
    print(f"  - Note: Student baseline already matches/exceeds teacher")

# Speed comparison
student_speedup = teacher_results['avg_inference_time'] / student_after_results['avg_inference_time']
print(f"  - Student vs Teacher speed: {student_speedup:.2f}x {'faster' if student_speedup > 1 else 'slower'}")

print(f"\n🎯 DISTILLATION SUCCESS METRICS:")
print("=" * 50)
if accuracy_improvement > 0:
    print(f"✅ SUCCESS: Student accuracy improved by {accuracy_improvement:.1f} points")
elif accuracy_improvement > -2:
    print(f"✓ ACCEPTABLE: Student accuracy changed by {accuracy_improvement:.1f} points (minimal change)")
else:
    print(f"⚠️ WARNING: Student accuracy decreased by {abs(accuracy_improvement):.1f} points")

if 0.9 < speed_ratio < 1.1:
    print(f"✅ SUCCESS: Speed maintained ({1/speed_ratio:.2f}x)")
else:
    print(f"⚠️ NOTE: Speed changed to {1/speed_ratio:.2f}x (evaluation variance)")

if student_after_results['accuracy'] >= teacher_results['accuracy'] * 0.8:
    print(f"✅ SUCCESS: Student retained ≥80% of teacher accuracy")
elif student_after_results['accuracy'] >= teacher_results['accuracy'] * 0.7:
    print(f"✓ GOOD: Student retained ≥70% of teacher accuracy")
else:
    print(f"⚠️ NOTE: Student at {(student_after_results['accuracy']/teacher_results['accuracy'])*100:.1f}% of teacher accuracy")

print(f"\n💡 KEY INSIGHTS:")
print("=" * 50)
print("• Distillation transfers teacher knowledge to student")
print("• Student keeps size advantage while gaining accuracy")
print("• Trade-off: Some knowledge may be lost in compression")
print("• Goal: Fast deployment with near-teacher performance")
print("• Real-world: Distilled models are a common choice for on-device and low-latency deployments")

if student_speedup < 1:
    print(f"\n⚠️ SPEED OBSERVATION:")
    print(f"   Student is currently {1/student_speedup:.2f}x SLOWER than teacher.")
    print(f"   This can happen due to:")
    print(f"   • LoRA adapter overhead during inference")
    print(f"   • 4-bit dequantization overhead at inference time")
    print(f"   • Memory bandwidth constraints")
    print(f"   • Lack of production optimizations (e.g., vLLM, TensorRT)")
    print(f"   ")
    print(f"   💡 In production, smaller models are typically faster once you:")
    print(f"   • Merge LoRA weights (removes adapter overhead)")
    print(f"   • Use optimized inference engines (vLLM, TensorRT-LLM)")
    print(f"   • Apply proper batching and caching strategies")
    print(f"   • Enable hardware-specific optimizations")

print(f"\n🚀 DEPLOYMENT RECOMMENDATIONS:")
print("=" * 50)
print("• Use teacher model for: High-accuracy requirements, batch processing")
print("• Use student model for: Real-time inference, mobile apps, edge devices")
print("• Consider distillation when: Speed > perfect accuracy")
print("• Monitor performance: Validate on your specific use case")
print("• Next step: Merge LoRA weights for additional speedup!")
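
**A note on sampling noise.** The evaluations above use `num_samples=50`, so each accuracy figure carries noticeable statistical uncertainty. The sketch below estimates a 95% interval with the normal approximation to the binomial. It assumes each question is an independent correct/incorrect trial, which is a simplification, and `accuracy_interval` is a hypothetical helper, not part of the lab code.

```python
import math

def accuracy_interval(correct: int, total: int, z: float = 1.96):
    """Return (accuracy_pct, half_width_pct) via the normal approximation.

    Hypothetical helper; assumes independent samples.
    """
    if total <= 0:
        raise ValueError("total must be positive")
    p = correct / total
    half_width = z * math.sqrt(p * (1 - p) / total)
    return p * 100, half_width * 100

# Illustrative numbers only: 35 of 50 answers judged correct
acc, hw = accuracy_interval(35, 50)
print(f"Accuracy: {acc:.1f}% +/- {hw:.1f} points (95% interval, n=50)")
```

At this sample size the interval is on the order of ten points wide in each direction, so small differences between models (for example, the 2-point threshold used above) are best read as indicative rather than conclusive. Increasing `num_samples` narrows the interval.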

In [None]:
# 6️⃣ Visualize the distillation process and results

import matplotlib.pyplot as plt

print("📊 CREATING DISTILLATION VISUALIZATIONS")
print("=" * 45)

# Prepare data for visualization
models = ['Teacher\n(7B)', 'Student Before\n(3B)', 'Student After\n(3B)']
accuracies = [teacher_results['accuracy'], student_results['accuracy'], student_after_results['accuracy']]
latencies_ms = [
    teacher_results['avg_inference_time'] * 1000,
    student_results['avg_inference_time'] * 1000,
    student_after_results['avg_inference_time'] * 1000
]
throughputs = [
    teacher_results['tokens_per_sec'],
    student_results['tokens_per_sec'],
    student_after_results['tokens_per_sec']
]
sizes_mb = [teacher_params * 2 / 1024**2, student_total_params * 2 / 1024**2, student_total_params * 2 / 1024**2]  # FP16

# Calculate efficiencies
efficiencies = [
    accuracies[i] / latencies_ms[i] for i in range(3)
]

# Create subplots
fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(16, 12))

# Color scheme: Red (teacher), Orange (student before), Green (student after)
colors = ['#e74c3c', '#f39c12', '#27ae60']

# 1. Accuracy comparison
bars1 = ax1.bar(models, accuracies, color=colors, alpha=0.8, edgecolor='black', linewidth=1.5)
ax1.set_ylabel('Accuracy (%)', fontsize=12, fontweight='bold')
ax1.set_title('Question Answering Accuracy', fontsize=14, fontweight='bold')
ax1.set_ylim(0, max(accuracies) * 1.2)
ax1.grid(True, alpha=0.3, axis='y')
ax1.axhline(y=teacher_results['accuracy'], color='red', linestyle='--', alpha=0.5, label='Teacher baseline')

# Add value labels on bars
for bar, acc in zip(bars1, accuracies):
    height = bar.get_height()
    ax1.text(bar.get_x() + bar.get_width()/2, height + max(accuracies)*0.02,
             f'{acc:.1f}%', ha='center', va='bottom', fontweight='bold', fontsize=11)

# 2. Latency comparison (lower is better)
bars2 = ax2.bar(models, latencies_ms, color=colors, alpha=0.8, edgecolor='black', linewidth=1.5)
ax2.set_ylabel('Latency (ms per sample)', fontsize=12, fontweight='bold')
ax2.set_title('Inference Latency (Lower = Better)', fontsize=14, fontweight='bold')
ax2.set_ylim(0, max(latencies_ms) * 1.2)
ax2.grid(True, alpha=0.3, axis='y')

# Add value labels on bars
for bar, lat in zip(bars2, latencies_ms):
    height = bar.get_height()
    ax2.text(bar.get_x() + bar.get_width()/2, height + max(latencies_ms)*0.02,
             f'{lat:.0f}ms', ha='center', va='bottom', fontweight='bold', fontsize=11)

# Add speedup annotations
teacher_lat = latencies_ms[0]
for i in range(1, 3):
    speedup = teacher_lat / latencies_ms[i]
    if speedup > 1:
        label_text = f'{speedup:.1f}x\nfaster'
        color = 'green'
    else:
        label_text = f'{speedup:.2f}x\n(slower!)'
        color = 'orange'
    ax2.text(i, latencies_ms[i] * 0.5, label_text,
             ha='center', va='center', fontsize=10, fontweight='bold',
             bbox=dict(boxstyle='round', facecolor='white', alpha=0.8, edgecolor=color, linewidth=2))

# 3. Throughput comparison (higher is better)
bars3 = ax3.bar(models, throughputs, color=colors, alpha=0.8, edgecolor='black', linewidth=1.5)
ax3.set_ylabel('Throughput (tokens/sec)', fontsize=12, fontweight='bold')
ax3.set_title('Generation Throughput (Higher = Better)', fontsize=14, fontweight='bold')
ax3.set_ylim(0, max(throughputs) * 1.2)
ax3.grid(True, alpha=0.3, axis='y')

# Add value labels on bars
for bar, thr in zip(bars3, throughputs):
    height = bar.get_height()
    ax3.text(bar.get_x() + bar.get_width()/2, height + max(throughputs)*0.02,
             f'{thr:.1f}', ha='center', va='bottom', fontweight='bold', fontsize=11)

# 4. Efficiency comparison (accuracy per ms)
bars4 = ax4.bar(models, efficiencies, color=colors, alpha=0.8, edgecolor='black', linewidth=1.5)
ax4.set_ylabel('Efficiency (accuracy/ms)', fontsize=12, fontweight='bold')
ax4.set_title('Overall Efficiency (Higher = Better)', fontsize=14, fontweight='bold')
ax4.set_ylim(0, max(efficiencies) * 1.2)
ax4.grid(True, alpha=0.3, axis='y')

# Add value labels on bars
for bar, eff in zip(bars4, efficiencies):
    height = bar.get_height()
    ax4.text(bar.get_x() + bar.get_width()/2, height + max(efficiencies)*0.02,
             f'{eff:.3f}', ha='center', va='bottom', fontweight='bold', fontsize=11)

# Add improvement annotation
if efficiencies[2] > efficiencies[1]:
    improvement = (efficiencies[2] / efficiencies[1] - 1) * 100
    ax4.text(2, efficiencies[2] * 0.5, f'+{improvement:.0f}%\nimprovement',
             ha='center', va='center', fontsize=10, fontweight='bold',
             bbox=dict(boxstyle='round', facecolor='lightgreen', alpha=0.8))
elif efficiencies[2] < efficiencies[1]:
    change = (efficiencies[2] / efficiencies[1] - 1) * 100
    ax4.text(2, efficiencies[2] * 0.5, f'{change:.0f}%\nchange',
             ha='center', va='center', fontsize=10, fontweight='bold',
             bbox=dict(boxstyle='round', facecolor='lightyellow', alpha=0.8))

plt.tight_layout()
plt.show()

# Distillation process visualization
print("\n🎓 DISTILLATION PROCESS VISUALIZATION:")
print("=" * 45)

# Create a process flow diagram
fig, ax = plt.subplots(1, 1, figsize=(14, 8))

# Draw the distillation process boxes
teacher_box = dict(boxstyle="round,pad=0.5", facecolor="#e74c3c", alpha=0.3, edgecolor='black', linewidth=2)
process_box = dict(boxstyle="round,pad=0.5", facecolor="#f39c12", alpha=0.3, edgecolor='black', linewidth=2)
student_box = dict(boxstyle="round,pad=0.5", facecolor="#27ae60", alpha=0.3, edgecolor='black', linewidth=2)

ax.text(0.15, 0.7, '🎓 Teacher Model\n(Large, Accurate)', ha='center', va='center',
         fontsize=14, fontweight='bold', bbox=teacher_box)
ax.text(0.5, 0.7, '🧠 Knowledge\nTransfer\n(Distillation)', ha='center', va='center',
         fontsize=14, fontweight='bold', bbox=process_box)
ax.text(0.85, 0.7, '🎓 Student Model\n(Small, Fast)', ha='center', va='center',
         fontsize=14, fontweight='bold', bbox=student_box)

# Add arrows with labels
ax.annotate('', xy=(0.38, 0.7), xytext=(0.24, 0.7),
            arrowprops=dict(arrowstyle='->', lw=4, color='blue'))
ax.text(0.31, 0.75, 'Soft labels', ha='center', fontsize=10, style='italic')

ax.annotate('', xy=(0.76, 0.7), xytext=(0.62, 0.7),
            arrowprops=dict(arrowstyle='->', lw=4, color='green'))
ax.text(0.69, 0.75, 'Learns', ha='center', fontsize=10, style='italic')

# Add detailed metrics below each stage
teacher_text = f'Size: {sizes_mb[0]:.0f}MB\nLatency: {latencies_ms[0]:.0f}ms\nAccuracy: {accuracies[0]:.1f}%'
ax.text(0.15, 0.45, teacher_text, ha='center', va='center', fontsize=11,
         bbox=dict(boxstyle="round,pad=0.4", facecolor="white", alpha=0.8, edgecolor='gray'))

process_text = f'Method: LoRA\nEpochs: {num_epochs}\nDataset: SQuAD\nSamples: {len(train_dataset)}'
ax.text(0.5, 0.45, process_text, ha='center', va='center', fontsize=11,
         bbox=dict(boxstyle="round,pad=0.4", facecolor="white", alpha=0.8, edgecolor='gray'))

student_text = f'Size: {sizes_mb[2]:.0f}MB\nLatency: {latencies_ms[2]:.0f}ms\nAccuracy: {accuracies[2]:.1f}%'
ax.text(0.85, 0.45, student_text, ha='center', va='center', fontsize=11,
         bbox=dict(boxstyle="round,pad=0.4", facecolor="white", alpha=0.8, edgecolor='gray'))

# Add improvement summary at bottom
accuracy_change = accuracies[2] - accuracies[1]
speedup = latencies_ms[0] / latencies_ms[2]
size_reduction = (1 - sizes_mb[2] / sizes_mb[0]) * 100

summary_text = (f'📈 Distillation Results:\n'
                f'✓ Accuracy change: {accuracy_change:+.1f} points\n'
                f'✓ Speed vs teacher: {speedup:.2f}x {"faster" if speedup > 1 else "slower"}\n'
                f'✓ Size reduction: {size_reduction:.1f}% smaller than teacher')

ax.text(0.5, 0.12, summary_text, ha='center', va='center', fontsize=12,
         bbox=dict(boxstyle="round,pad=0.6", facecolor="lightyellow", alpha=0.9,
                  edgecolor='orange', linewidth=2), fontweight='bold')

ax.set_xlim(0, 1)
ax.set_ylim(0, 1)
ax.axis('off')
ax.set_title('Knowledge Distillation Process', fontsize=18, fontweight='bold', pad=20)

plt.tight_layout()
plt.show()

print("\n✅ Visualization complete!")
print("\n💡 Key Takeaways from Visualizations:")
print("=" * 50)
print(f"1. Accuracy: Student achieved {accuracies[2]:.1f}% (vs Teacher {accuracies[0]:.1f}%)")
if speedup > 1:
    print(f"2. Speed: Student is {speedup:.1f}x faster than teacher")
else:
    print(f"2. Speed: Student runs at {speedup:.2f}x the teacher's speed (slower, likely due to LoRA/4-bit overhead)")
print(f"3. Size: Student is {size_reduction:.1f}% smaller than teacher ({student_total_params/1e9:.1f}B vs {teacher_params/1e9:.1f}B)")
if efficiencies[2] > efficiencies[1]:
    eff_improvement = (efficiencies[2]/efficiencies[1]-1)*100
    print(f"4. Efficiency: Student improved {eff_improvement:.0f}% after distillation")
else:
    print(f"4. Efficiency: Student efficiency at {(efficiencies[2]/efficiencies[0])*100:.0f}% of teacher")
print("\n🎯 Distillation goal: smaller size with accuracy close to the teacher. Compare the numbers above against that goal!")
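
**A note on the size estimates.** The `sizes_mb` values in the visualization cell assume 2 bytes per parameter (FP16), while this lab loads its models in 4-bit. The sketch below gives a rough weight-only footprint at several precisions. `estimate_weight_mb` is a hypothetical helper, the parameter counts are round illustrative numbers rather than measurements, and real 4-bit checkpoints carry extra overhead (quantization constants, layers kept in higher precision) that is ignored here.

```python
# Approximate bytes per parameter; weight-only, ignores quantization metadata.
BYTES_PER_PARAM = {"fp32": 4.0, "fp16": 2.0, "int8": 1.0, "4bit": 0.5}

def estimate_weight_mb(num_params: int, precision: str = "fp16") -> float:
    """Hypothetical helper: weight-only footprint in MiB."""
    return num_params * BYTES_PER_PARAM[precision] / 1024**2

# Round illustrative parameter counts, not measured from the lab's models.
example_models = {"teacher (7B)": 7_000_000_000, "student (3B)": 3_000_000_000}

for name, n_params in example_models.items():
    fp16_mb = estimate_weight_mb(n_params, "fp16")
    four_bit_mb = estimate_weight_mb(n_params, "4bit")
    print(f"{name}: ~{fp16_mb:,.0f} MB in FP16, ~{four_bit_mb:,.0f} MB in 4-bit")

# The relative reduction is the same at any precision shared by both models.
reduction_pct = (1 - example_models["student (3B)"] / example_models["teacher (7B)"]) * 100
print(f"Size reduction: {reduction_pct:.1f}%")
```

Because both models are scaled by the same bytes-per-parameter factor, the percentage reduction reported above does not depend on whether FP16 or 4-bit is assumed; only the absolute megabyte figures change.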

### Step 4: Merge LoRA Weights for Production Deployment

**Why Merge LoRA Weights?**

During training, LoRA adapters add trainable parameters to the base model without modifying the original weights. This is efficient for training, but adds computational overhead during inference:

- **LoRA inference**: Base model forward pass + LoRA adapter forward pass = **slower**
- **Merged model**: Single forward pass with combined weights = **faster**

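**Production Best Practice**: Merge LoRA weights before deployment to eliminate the adapter overhead and get the real speedup from your smaller model. Merge into an FP16 base model where possible, since merging into a 4-bit quantized base can degrade accuracy.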
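
To make the two inference paths concrete, here is a minimal sketch in plain Python (no dependencies) using the standard LoRA formulation, where the adapted layer computes `W x + (alpha / r) * B (A x)` and the merged layer uses `W_merged = W + (alpha / r) * B A`. The tiny matrices and the values of `alpha` and `r` are made up for illustration; real layers are far larger and live on the GPU.

```python
def matvec(m, v):
    """Multiply matrix m (list of rows) by vector v."""
    return [sum(m_ij * v_j for m_ij, v_j in zip(row, v)) for row in m]

def matmul(a, b):
    """Multiply matrix a (n x k) by matrix b (k x m)."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b)))
             for j in range(len(b[0]))] for i in range(len(a))]

# Toy layer: 3x3 base weight with a rank-1 LoRA adapter (illustrative values)
W = [[1.0, 0.0, 2.0], [0.5, 1.0, 0.0], [0.0, -1.0, 1.0]]
B = [[0.1], [0.2], [-0.1]]   # d_out x r
A = [[1.0, 0.5, -1.0]]       # r x d_in
alpha, r = 2.0, 1
scale = alpha / r
x = [1.0, 2.0, 3.0]

# Adapter path: base forward pass plus a separate low-rank forward pass
base_out = matvec(W, x)
lora_out = matvec(B, matvec(A, x))
y_adapter = [b + scale * l for b, l in zip(base_out, lora_out)]

# Merged path: fold the low-rank update into W once, then do a single pass
BA = matmul(B, A)
W_merged = [[w + scale * d for w, d in zip(w_row, d_row)]
            for w_row, d_row in zip(W, BA)]
y_merged = matvec(W_merged, x)

max_diff = max(abs(a - b) for a, b in zip(y_adapter, y_merged))
print(f"Adapter output: {y_adapter}")
print(f"Merged output:  {y_merged}")
print(f"Max difference: {max_diff:.2e}")
```

In full precision the two paths agree up to floating-point rounding, which is why merging removes overhead without changing predictions. That equivalence no longer holds exactly when the merged weights have to be stored in a low-bit format.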

**Documentation:**
- PEFT merge and unload: https://huggingface.co/docs/peft/package_reference/lora#peft.LoraModel.merge_and_unload
- Unsloth save methods: https://docs.unsloth.ai/basics/saving-and-loading

In [None]:
# 7️⃣ Merge LoRA weights into base model for production deployment

print("🔀 MERGING LORA WEIGHTS INTO BASE MODEL")
print("=" * 50)

print("\n⚠️ IMPORTANT NOTE:")
print("   Merging LoRA weights with 4-bit quantized models can cause numerical")
print("   instability and degraded accuracy due to rounding errors.")
print("   For production, you would:")
print("   1. Load the base model in FP16 (not 4-bit)")
print("   2. Merge LoRA weights")
print("   3. Then optionally quantize to 4-bit")
print("")
print("   This demo shows the merging process, but the 4-bit limitation")
print("   means the merged model may not work properly.")

print("\n📊 Current Model State:")
print(f"  - Model type: {type(student_model).__name__}")
print(f"  - Has LoRA adapters: {hasattr(student_model, 'merge_and_unload')}")
print(f"  - Base model quantization: 4-bit (bnb)")

# Merge LoRA weights into the base model
print("\n🔄 Merging LoRA weights into base model...")
print("   This combines the LoRA adapter weights with the base model weights")
print("   Result: Single model with no adapter overhead\n")

# Merge and unload LoRA adapters
student_model_merged = student_model.merge_and_unload()

print("✅ LoRA weights merged!")
print(f"  - New model type: {type(student_model_merged).__name__}")
print(f"  - Has LoRA adapters: {hasattr(student_model_merged, 'merge_and_unload')}")

# Set merged model to evaluation mode
student_model_merged.eval()

# Evaluate merged model performance
print("\n📊 Evaluating merged model performance...")
print("   (NOTE: May have degraded accuracy due to 4-bit merge issues)\n")

# Evaluate merged model
student_merged_results = evaluate_squad_performance(student_model_merged, dataset_val, tokenizer, "Student (Merged LoRA)", num_samples=50)

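# Check if merge caused accuracy issues: near-zero accuracy, or a large drop relative to the LoRA model
merge_accuracy_drop = student_after_results['accuracy'] - student_merged_results['accuracy']
merge_failed = student_merged_results['accuracy'] < 10 or merge_accuracy_drop > 10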

if merge_failed:
    print("\n" + "="*70)
    print("⚠️ MERGE QUALITY WARNING")
    print("="*70)
    print(f"\n❌ Merged model accuracy: {student_merged_results['accuracy']:.1f}%")
    print(f"   Original LoRA model: {student_after_results['accuracy']:.1f}%")
    print(f"   Accuracy drop: {student_after_results['accuracy'] - student_merged_results['accuracy']:.1f} points")
    print("\n🔍 ROOT CAUSE:")
    print("   Merging LoRA weights with 4-bit quantized base models causes")
    print("   numerical precision loss and rounding errors.")
    print("\n💡 PRODUCTION SOLUTION:")
    print("   1. Load base model in FP16 (not 4-bit):")
    print("      student_model, _ = FastLanguageModel.from_pretrained(")
    print("          model_name='unsloth/Qwen2.5-3B-Instruct',")
    print("          dtype=torch.float16,  # FP16 weights")
    print("          load_in_4bit=False,   # Unsloth loads in 4-bit by default")
    print("          device_map='auto'")
    print("      )")
    print("   2. Apply LoRA and train")
    print("   3. Merge with: merged = student_model.merge_and_unload()")
    print("   4. Save merged FP16 model")
    print("   5. Optionally quantize to 4-bit AFTER merging")
    print("\n📖 LEARNING POINT:")
    print("   This lab uses 4-bit models for memory efficiency on Colab,")
    print("   but production deployments would use FP16 for merging.")
    print("   The speedup benefit of merging still applies - we just can't")
    print("   demonstrate it properly with 4-bit base models.")
    
    # Use the LoRA model results for comparisons since merge failed
    print("\n📊 Using Student + LoRA results for comparisons (merge not viable)")
    effective_student_results = student_after_results
    merge_viable = False
else:
    print("\n✅ Merge succeeded without significant accuracy loss!")
    effective_student_results = student_merged_results
    merge_viable = True

# Compare with LoRA model and teacher
print("\n" + "="*70)
print("⚡ COMPREHENSIVE PERFORMANCE COMPARISON")
print("="*70)

print(f"\n🔸 Teacher Model (7B):")
print(f"  - Accuracy: {teacher_results['accuracy']:.1f}% ({teacher_results['correct']}/{teacher_results['total_samples']})")
print(f"  - Average inference time: {teacher_results['avg_inference_time']*1000:.1f}ms")
print(f"  - Throughput: {teacher_results['tokens_per_sec']:.1f} tokens/sec")

print(f"\n🔸 Student Model with LoRA Adapters (3B):")
print(f"  - Accuracy: {student_after_results['accuracy']:.1f}% ({student_after_results['correct']}/{student_after_results['total_samples']})")
print(f"  - Average inference time: {student_after_results['avg_inference_time']*1000:.1f}ms")
print(f"  - Throughput: {student_after_results['tokens_per_sec']:.1f} tokens/sec")

if merge_viable:
    print(f"\n🔹 Student Model (LoRA Merged) (3B):")
    print(f"  - Accuracy: {student_merged_results['accuracy']:.1f}% ({student_merged_results['correct']}/{student_merged_results['total_samples']})")
    print(f"  - Average inference time: {student_merged_results['avg_inference_time']*1000:.1f}ms")
    print(f"  - Throughput: {student_merged_results['tokens_per_sec']:.1f} tokens/sec")
    
    # Calculate speedups
    merge_speedup = student_after_results['avg_inference_time'] / student_merged_results['avg_inference_time']
    vs_teacher_speedup = teacher_results['avg_inference_time'] / student_merged_results['avg_inference_time']
    throughput_increase = ((student_merged_results['tokens_per_sec'] - student_after_results['tokens_per_sec']) 
                           / student_after_results['tokens_per_sec'] * 100)
    
    print(f"\n📈 SPEEDUP FROM MERGING LORA:")
    print("=" * 50)
    print(f"  - Merge speedup: {merge_speedup:.2f}x faster than LoRA version")
    print(f"  - Throughput increase: {throughput_increase:+.1f}%")
    print(f"  - Latency reduction: {(1 - 1/merge_speedup)*100:.1f}%")
    print(f"  - Overall speedup vs teacher: {vs_teacher_speedup:.2f}x")
else:
    print(f"\n🔹 Student Model (LoRA Merged):")
    print(f"  - ❌ Merge failed due to 4-bit quantization")
    print(f"  - Accuracy degraded to {student_merged_results['accuracy']:.1f}%")
    print(f"  - Not viable for production use")
    
    # Calculate what the speedup would be (just for demonstration)
    merge_speedup = student_after_results['avg_inference_time'] / student_merged_results['avg_inference_time']
    vs_teacher_speedup = teacher_results['avg_inference_time'] / student_after_results['avg_inference_time']
    
    print(f"\n📈 THEORETICAL SPEEDUP (if merge worked properly):")
    print("=" * 50)
    print(f"  - Measured merge speedup: {merge_speedup:.2f}x (but accuracy broken)")
    print(f"  - Student + LoRA vs teacher: {vs_teacher_speedup:.2f}x")
    print(f"  - With a proper FP16 merge: speedup may differ from the {merge_speedup:.2f}x measured here")

# Final comprehensive comparison table
print("\n" + "="*85)
print("📊 FINAL COMPREHENSIVE COMPARISON")
print("="*85)

print(f"\n{'Model':<35} {'Accuracy':<12} {'Latency (ms)':<15} {'Throughput':<15} {'vs Teacher':<15}")
print("-" * 85)

teacher_latency = teacher_results['avg_inference_time'] * 1000
student_lora_latency = student_after_results['avg_inference_time'] * 1000
student_merged_latency = student_merged_results['avg_inference_time'] * 1000

print(f"{'Teacher (7B)':<35} {teacher_results['accuracy']:>6.1f}%{'':<5} {teacher_latency:>8.1f}ms{'':<6} {teacher_results['tokens_per_sec']:>8.1f} tok/s{'':<5} {'1.0x (baseline)':<15}")
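# "vs Teacher" for this row must use the LoRA student's own latency, not the merged model's
lora_vs_teacher = teacher_latency / student_lora_latency
lora_vs_teacher_label = f"{lora_vs_teacher:.2f}x" if lora_vs_teacher > 1 else f"{lora_vs_teacher:.2f}x (slower!)"
print(f"{'Student + LoRA (3B)':<35} {student_after_results['accuracy']:>6.1f}%{'':<5} {student_lora_latency:>8.1f}ms{'':<6} {student_after_results['tokens_per_sec']:>8.1f} tok/s{'':<5} {lora_vs_teacher_label:<15}")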

if merge_viable:
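    final_speedup = teacher_latency / student_merged_latency
    # Only claim "faster" when the merged student actually beats the teacher
    final_speedup_label = f"{final_speedup:.2f}x faster" if final_speedup > 1 else f"{final_speedup:.2f}x (slower!)"
    print(f"{'Student Merged (3B)':<35} {student_merged_results['accuracy']:>6.1f}%{'':<5} {student_merged_latency:>8.1f}ms{'':<6} {student_merged_results['tokens_per_sec']:>8.1f} tok/s{'':<5} {final_speedup_label:<15}")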
else:
    print(f"{'Student Merged (3B) [BROKEN]':<35} {student_merged_results['accuracy']:>6.1f}%{'':<5} {student_merged_latency:>8.1f}ms{'':<6} {student_merged_results['tokens_per_sec']:>8.1f} tok/s{'':<5} {'N/A (4-bit issue)':<15}")

print("-" * 85)

# Calculate total gains using the viable student model
size_reduction = (1 - student_total_params / teacher_params) * 100
accuracy_retention = (effective_student_results['accuracy'] / teacher_results['accuracy']) * 100

print(f"\n🎯 KEY TAKEAWAYS:")
print("=" * 50)
if vs_teacher_speedup > 1:
    print(f"✓ Student model is {vs_teacher_speedup:.2f}x faster than teacher")
else:
    print(f"⚠️ Student runs at {vs_teacher_speedup:.2f}x the teacher's speed (slower)")
    print(f"   Reasons: LoRA overhead + 4-bit quantization + measurement variance")
print(f"✓ Accuracy: {effective_student_results['accuracy']:.1f}% ({accuracy_retention:.1f}% of teacher)")
print(f"✓ Model size: {size_reduction:.1f}% smaller ({student_total_params/1e9:.1f}B vs {teacher_params/1e9:.1f}B)")
if merge_viable:
    print(f"✓ Merging LoRA gave {merge_speedup:.2f}x additional speedup")
    print(f"✓ Production-ready for deployment!")
else:
    print(f"⚠️ 4-bit merge not viable - use FP16 base model in production")

print(f"\n🚀 PRODUCTION DEPLOYMENT:")
print("=" * 50)
print("• For production with LoRA merging:")
print("  1. Use FP16 base models (not 4-bit)")
print("  2. Train with LoRA")
print("  3. Merge with merge_and_unload()")
print("  4. Optionally quantize AFTER merging")
print("• Benefits of merging:")
print("  - No adapter overhead")
print("  - Faster inference")
print("  - Same accuracy as LoRA version")
print("• This lab uses 4-bit for Colab memory limits,")
print("  but production would use FP16 for proper merging")

# Optional: Save merged model
print(f"\n💾 To save models for deployment:")
print("   # Save LoRA version (what works in this demo)")
print("   student_model.save_pretrained('./student_model_lora')")
print("   ")
print("   # For production: Save FP16 merged version")
print("   # (after training with FP16 base model)")
print("   student_model_merged.save_pretrained('./student_model_merged')")
print("   tokenizer.save_pretrained('./student_model_merged')")

print("\n✓ LoRA merging demonstration complete!")
print("\n📚 LESSON LEARNED:")
print("   Merging LoRA with 4-bit models has limitations.")
print("   Always use FP16 base models for production LoRA merging!")
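
**Why 4-bit merging loses information.** The rounding problem described in the cell above can be illustrated with a toy example. The sketch uses a simple uniform 16-level quantizer over `[-1, 1]`, which is an assumption made for clarity: the 4-bit format used by bitsandbytes is non-uniform and applies block-wise scaling, so this is not its actual algorithm. The weights and LoRA deltas are invented values.

```python
def quantize_4bit(values, lo=-1.0, hi=1.0):
    """Toy uniform 4-bit quantizer: snap each value to one of 16 levels."""
    steps = 15  # 16 representable values -> 15 intervals
    step = (hi - lo) / steps
    out = []
    for v in values:
        clipped = min(max(v, lo), hi)
        out.append(lo + round((clipped - lo) / step) * step)
    return out

weights = [0.30, -0.52, 0.85, 0.10]   # illustrative base weights
delta = [0.02, -0.03, 0.01, 0.04]     # illustrative small LoRA updates

# Path 1: base is already 4-bit; add the update, then store back in 4-bit
w_q = quantize_4bit(weights)
merged_requantized = quantize_4bit([w + d for w, d in zip(w_q, delta)])

# Path 2: merge in full precision (the update is preserved exactly)
merged_full = [w + d for w, d in zip(weights, delta)]

updates_lost = sum(
    1 for before, after in zip(w_q, merged_requantized)
    if abs(before - after) < 1e-12
)
print(f"4-bit base:          {[round(v, 3) for v in w_q]}")
print(f"Merged + requantized: {[round(v, 3) for v in merged_requantized]}")
print(f"Merged in FP:         {[round(v, 3) for v in merged_full]}")
print(f"Updates rounded away: {updates_lost}/{len(delta)}")
```

Every update here is smaller than half a quantization step, so each one is rounded away when the merged weight is stored back on the 4-bit grid. Merging in FP16 first and quantizing afterwards, as recommended above, lets the quantizer see the combined weight instead of discarding the adapter's contribution piece by piece.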

## Reflection

- Summarize the differences in accuracy and inference speed between the teacher model and the distilled student model.
- Discuss how LoRA/QLoRA and other parameter-efficient techniques impacted training time and resource usage.
- Consider scenarios where a slightly lower accuracy from the student model might be acceptable given significant gains in speed and memory efficiency.
