# Knowledge Distillation for LLMs - Single GPU Lab

This notebook demonstrates **Knowledge Distillation** using NVIDIA TensorRT Model Optimizer on a single GPU.

## What is Knowledge Distillation?

Knowledge Distillation is a model compression technique where:
- A smaller **student model** learns to mimic a larger **teacher model**
- The student learns from both:
  1. Ground truth labels (standard supervised learning)
  2. Soft predictions from the teacher (knowledge transfer)

### In this lab:
- **Teacher**: Llama-3.2-3B-Instruct (3.2B parameters)
- **Student**: Llama-3.2-1B (1.2B parameters)
- **Goal**: Compress the teacher's knowledge into a 2.6√ó smaller model

---

## üìã Lab Setup

### Prerequisites
- Single GPU with ~40GB VRAM (e.g., A100, A6000)
- Python 3.10+
- CUDA 11.8+

### Learning Objectives
1. Understand knowledge distillation process
2. Implement dataset preprocessing for LLM training
3. Configure and run distributed training
4. Evaluate distilled model performance

---

## üîß Step 1: GPU Configuration

**IMPORTANT**: Configure which GPU to use BEFORE importing PyTorch!

Check available GPUs with `nvidia-smi` or `nvitop` and select a free GPU.

In [1]:
# For faster library installation
!pip install uv nvitop
# Install TensorRT Model Optimizer with HuggingFace support
!uv pip install -U nvidia-modelopt[hf]

!uv pip uninstall numpy transformers
# Install additional dependencies
!uv pip install pyarrow 'transformers<5.0' 'trl>=0.23.0' 'numpy<2.0' bitsandbytes accelerate

Defaulting to user installation because normal site-packages is not writeable
[2K[2mResolved [1m90 packages[0m [2min 904ms[0m[0m                                        [0m
[2K[37m‚†ô[0m [2mPreparing packages...[0m (0/5)                                                   
[2K[1A[37m‚†ô[0m [2mPreparing packages...[0m (0/5)-------------------[0m[0m     0 B/858.14 MiB          [1A
[2K[1A[37m‚†ô[0m [2mPreparing packages...[0m (0/5)-------------------[0m[0m 14.88 KiB/858.14 MiB        [1A
[2K[1A[37m‚†ô[0m [2mPreparing packages...[0m (0/5)-------------------[0m[0m 30.88 KiB/858.14 MiB        [1A
[2K[1A[37m‚†ô[0m [2mPreparing packages...[0m (0/5)-------------------[0m[0m 46.88 KiB/858.14 MiB        [1A
[2K[1A[37m‚†ô[0m [2mPreparing packages...[0m (0/5)-------------------[0m[0m 62.88 KiB/858.14 MiB        [1A
[2K[1A[37m‚†ô[0m [2mPreparing packages...[0m (0/5)-------------------[0m[0m 78.88 KiB/858.14 MiB        [1A
[2K[1A[37m‚

In [2]:
import os
from dataclasses import dataclass

GPU_ID = 4  # Change this to your available GPU (0-7)

os.environ["CUDA_VISIBLE_DEVICES"] = str(GPU_ID)
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

print(f"üéØ GPU Configuration:")
print(f"   Selected GPU: {GPU_ID}")
print(f"   CUDA_VISIBLE_DEVICES = {os.environ['CUDA_VISIBLE_DEVICES']}")
print(f"\n‚ö†Ô∏è  After this setting, GPU {GPU_ID} will appear as 'cuda:0' in PyTorch")
print(f"   This is normal and expected!\n")

üéØ GPU Configuration:
   Selected GPU: 4
   CUDA_VISIBLE_DEVICES = 4

‚ö†Ô∏è  After this setting, GPU 4 will appear as 'cuda:0' in PyTorch
   This is normal and expected!



## üì¶ Step 2: Import Libraries

Now that GPU is configured, we can import PyTorch and other libraries.

In [3]:
import torch
import transformers
from transformers import AutoTokenizer, AutoModelForCausalLM
import datasets

# TensorRT Model Optimizer imports
import modelopt.torch.distill as mtd
import modelopt.torch.opt as mto
from modelopt.torch.distill.plugins.huggingface import KDTrainer, LMLogitsLoss

print("‚úÖ Libraries imported successfully!")
print(f"==> PyTorch version: {torch.__version__}")
print(f"==> Transformers version: {transformers.__version__}")
print(f"==> CUDA available: {torch.cuda.is_available()}")

if torch.cuda.is_available():
    print(f"==> Visible GPU count: {torch.cuda.device_count()}")
    print(f"==> Device 0 name: {torch.cuda.get_device_name(0)}")
    print(f"==> Total memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB")
    print(f"\n‚úÖ SUCCESS: Using GPU {GPU_ID} as 'cuda:0'")
else:
    print("‚ùå ERROR: CUDA not available!")

  from .autonotebook import tqdm as notebook_tqdm
/usr/bin/ld: cannot find -laio: No such file or directory
collect2: error: ld returned 1 exit status
/usr/bin/ld: cannot find -laio: No such file or directory
collect2: error: ld returned 1 exit status


‚úÖ Libraries imported successfully!
==> PyTorch version: 2.9.1+cu128
==> Transformers version: 4.57.1
==> CUDA available: True
==> Visible GPU count: 1
==> Device 0 name: NVIDIA A100-SXM4-40GB
==> Total memory: 39.39 GB

‚úÖ SUCCESS: Using GPU 4 as 'cuda:0'




## ‚öôÔ∏è Step 3: Training Configuration

Configure the model paths and training hyperparameters.

In [3]:
@dataclass
class ModelArguments:
    """Model Configuration"""
    # Teacher: Larger model we distill FROM
    teacher_name_or_path: str = "meta-llama/Llama-3.2-3B-Instruct"
    
    # Student: Smaller model we distill TO
    student_name_or_path: str = "meta-llama/Llama-3.2-1B"


# Create model configuration
model_args = ModelArguments()

# Configure training arguments
training_args = transformers.TrainingArguments(
    output_dir="./llama3.2-1b-distilled-1gpu",
    do_train=True,
    do_eval=True,
    
    # Training duration
    max_steps=3200,  # ~2 epochs with our dataset
    
    # Checkpointing
    save_strategy="steps",
    save_steps=500,
    
    # Batch size - adjust based on GPU memory
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    gradient_accumulation_steps=16,  # Effective batch = 2 * 16 = 32
    
    # Optimizer settings
    optim="adamw_torch",
    learning_rate=2e-5,
    lr_scheduler_type="cosine",
    warmup_steps=100,
    
    # Mixed precision (faster training, less memory)
    bf16=True,
    tf32=False,
    
    # Logging and evaluation
    logging_steps=50,
    eval_steps=400,
    report_to="none",
    
    # Data processing
    dataloader_drop_last=True,
)

print("‚úÖ Configuration:")
print("="*80)
print(f"  Teacher: {model_args.teacher_name_or_path}")
print(f"  Student: {model_args.student_name_or_path}")
print(f"\nüìä Training Parameters:")
print(f"  Batch size per device:      {training_args.per_device_train_batch_size}")
print(f"  Gradient accumulation:      {training_args.gradient_accumulation_steps}")
print(f"  Effective batch size:       {training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps}")
print(f"  Learning rate:              {training_args.learning_rate}")
print(f"  Max steps:                  {training_args.max_steps}")
print(f"\nüíæ Output: {training_args.output_dir}")
print("="*80)

‚úÖ Configuration:
  Teacher: meta-llama/Llama-3.2-3B-Instruct
  Student: meta-llama/Llama-3.2-1B

üìä Training Parameters:
  Batch size per device:      2
  Gradient accumulation:      16
  Effective batch size:       32
  Learning rate:              2e-05
  Max steps:                  3200

üíæ Output: ./llama3.2-1b-distilled-1gpu


  self.setter(val)


## üìä Step 4: Load and Prepare Dataset

We'll use the **smol-smoltalk-Interaction-SFT** dataset, which contains conversational query-answer pairs.

### Dataset Preprocessing
The raw dataset has columns: `query`, `answer`, `source`

We need to:
1. Format into chat template (user/assistant messages)
2. Tokenize the text
3. Create `input_ids`, `attention_mask`, and `labels`

In [4]:
print("Loading dataset...")

# Load the dataset from HuggingFace
dset = datasets.load_dataset("ReactiveAI/smol-smoltalk-Interaction-SFT", split="train")

# Split into training and evaluation sets
dset_splits = dset.train_test_split(train_size=12800, test_size=1280, seed=420)
dset_train, dset_eval = dset_splits["train"], dset_splits["test"]

print(f"‚úÖ Dataset loaded!")
print(f"  Training samples: {len(dset_train):,}")
print(f"  Evaluation samples: {len(dset_eval):,}")
print(f"\nüìù Sample data:")
print(f"  Query: {dset_train[0]['query'][:100]}...")
print(f"  Answer: {dset_train[0]['answer'][:100]}...")

Loading dataset...
‚úÖ Dataset loaded!
  Training samples: 12,800
  Evaluation samples: 1,280

üìù Sample data:
  Query: What are Data visualization types....
  Answer: Data visualization types are diverse and can be categorized based on their purpose, structure, and f...


## üî§ Step 5: Load Tokenizer and Preprocess Dataset

The tokenizer converts text into tokens that the model can understand.

In [5]:
print("Loading tokenizer...")

tokenizer = AutoTokenizer.from_pretrained(model_args.teacher_name_or_path, use_fast=True)

# Configure padding
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

print(f"‚úÖ Tokenizer loaded")
print(f"  Vocab size: {len(tokenizer):,}")
print(f"  Pad token: '{tokenizer.pad_token}'")
print(f"  EOS token: '{tokenizer.eos_token}'")

Loading tokenizer...
‚úÖ Tokenizer loaded
  Vocab size: 128,256
  Pad token: '<|eot_id|>'
  EOS token: '<|eot_id|>'


### Define Preprocessing Function

This function:
1. Converts query/answer into chat format
2. Applies the chat template
3. Tokenizes with truncation
4. Creates labels for language modeling

In [6]:
def format_sample(sample):
    """
    Format and tokenize a dataset sample.
    
    Args:
        sample: Dict with 'query' and 'answer' keys
        
    Returns:
        Dict with 'input_ids', 'attention_mask', and 'labels'
    """
    # Create chat messages
    messages = [
        {"role": "user", "content": sample["query"]},
        {"role": "assistant", "content": sample["answer"]},
    ]
    
    # Apply chat template
    text = tokenizer.apply_chat_template(messages, tokenize=False)
    
    # Tokenize with truncation
    tokenized = tokenizer(
        text, 
        truncation=True, 
        max_length=512,  # Limit sequence length
        padding=False    # Dynamic padding in data collator
    )
    
    # Create labels (copy of input_ids for language modeling)
    tokenized["labels"] = list(tokenized["input_ids"])
    
    return tokenized

print("‚úÖ Preprocessing function defined")

# Test the function on one sample
test_sample = format_sample(dset_train[0])
print(f"\nüìä Test output:")
print(f"  Input IDs length: {len(test_sample['input_ids'])}")
print(f"  Attention mask length: {len(test_sample['attention_mask'])}")
print(f"  Labels length: {len(test_sample['labels'])}")

‚úÖ Preprocessing function defined

üìä Test output:
  Input IDs length: 486
  Attention mask length: 486
  Labels length: 486


### Apply Preprocessing to Dataset

This will tokenize all samples in the dataset.

In [7]:
print("Tokenizing datasets...")
print("This may take a few minutes...\n")

# Apply to training set
dset_train = dset_train.map(
    format_sample, 
    remove_columns=dset_train.column_names,  # Remove original columns
    num_proc=4,  # Parallel processing
    desc="Tokenizing train set"
)

# Apply to evaluation set
dset_eval = dset_eval.map(
    format_sample,
    remove_columns=dset_eval.column_names,
    num_proc=4,
    desc="Tokenizing eval set"
)

print(f"\n‚úÖ Tokenization complete!")
print(f"  Train set columns: {dset_train.column_names}")
print(f"  Train set features: {dset_train.features}")

Tokenizing datasets...
This may take a few minutes...



Tokenizing train set (num_proc=4): 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 12800/12800 [00:04<00:00, 3149.60 examples/s]
Tokenizing eval set (num_proc=4): 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1280/1280 [00:00<00:00, 1480.15 examples/s]


‚úÖ Tokenization complete!
  Train set columns: ['input_ids', 'attention_mask', 'labels']
  Train set features: {'input_ids': List(Value('int32')), 'attention_mask': List(Value('int8')), 'labels': List(Value('int64'))}





## ü§ñ Step 6: Load Student Model

Load the smaller student model that will learn from the teacher.

In [8]:
print(f"Loading student model: {model_args.student_name_or_path}")
print("This may take a few minutes...\n")

student_model = AutoModelForCausalLM.from_pretrained(
    model_args.student_name_or_path,
    torch_dtype=torch.bfloat16 if training_args.bf16 else torch.float32,
    device_map={"":0}  # Load on first GPU
)

student_params = sum(p.numel() for p in student_model.parameters())

print(f"‚úÖ Student model loaded!")
print(f"  Parameters: {student_params:,} ({student_params/1e9:.2f}B)")
print(f"  Device: {next(student_model.parameters()).device}")
print(f"  Memory: {torch.cuda.memory_allocated(0) / 1024**3:.2f} GB")

Loading student model: meta-llama/Llama-3.2-1B
This may take a few minutes...



`torch_dtype` is deprecated! Use `dtype` instead!


‚úÖ Student model loaded!
  Parameters: 1,235,814,400 (1.24B)
  Device: cuda:0
  Memory: 2.30 GB


## üë®‚Äçüè´ Step 7: Load Teacher Model & Configure Distillation

Load the larger teacher model and set up knowledge distillation.

In [9]:
print(f"Loading teacher model: {model_args.teacher_name_or_path}")
print("This may take a few minutes...\n")

teacher_model = AutoModelForCausalLM.from_pretrained(
    model_args.teacher_name_or_path,
    torch_dtype=torch.bfloat16 if training_args.bf16 else torch.float32,
    device_map={"":0}  # Load on first GPU
)

teacher_params = sum(p.numel() for p in teacher_model.parameters())

print(f"‚úÖ Teacher model loaded!")
print(f"  Parameters: {teacher_params:,} ({teacher_params/1e9:.2f}B)")
print(f"  Device: {next(teacher_model.parameters()).device}")
print(f"  Compression ratio: {teacher_params/student_params:.2f}x")
print(f"  Total GPU memory: {torch.cuda.memory_allocated(0) / 1024**3:.2f} GB")

Loading teacher model: meta-llama/Llama-3.2-3B-Instruct
This may take a few minutes...



Loading checkpoint shards: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2/2 [00:01<00:00,  1.20it/s]


‚úÖ Teacher model loaded!
  Parameters: 3,212,749,824 (3.21B)
  Device: cuda:0
  Compression ratio: 2.60x
  Total GPU memory: 8.29 GB


### Configure Knowledge Distillation

Set up the distillation loss function and convert the student model.

In [10]:
print("Configuring knowledge distillation...\n")

# Configure KD loss
kd_config = {
    "teacher_model": teacher_model,
    "criterion": LMLogitsLoss(),  # KL-divergence on logits
}

# Enable ModelOpt checkpointing for HuggingFace
mto.enable_huggingface_checkpointing()

# Convert student to distillation model
model = mtd.convert(student_model, mode=[("kd_loss", kd_config)])

# Fix generation config warnings
model.generation_config.temperature = None
model.generation_config.top_p = None

print("‚úÖ Distillation configured!")
print(f"  Loss function: LMLogitsLoss (KL-divergence)")
print(f"  Student will learn from:")
print(f"    1. Ground truth labels (standard loss)")
print(f"    2. Teacher's soft predictions (distillation loss)")

Configuring knowledge distillation...

ModelOpt save/restore enabled for `transformers` library.
ModelOpt save/restore enabled for `diffusers` library.
ModelOpt save/restore enabled for `peft` library.
‚úÖ Distillation configured!
  Loss function: LMLogitsLoss (KL-divergence)
  Student will learn from:
    1. Ground truth labels (standard loss)
    2. Teacher's soft predictions (distillation loss)


## üöÄ Step 8: Setup Trainer and Start Training

Create the trainer with the appropriate data collator.

In [11]:
print("Setting up trainer...\n")

# Use DataCollatorForSeq2Seq which handles labels properly
data_collator = transformers.DataCollatorForSeq2Seq(
    tokenizer=tokenizer,
    model=model,
    padding=True,
    return_tensors="pt"
)

# Create KD Trainer
trainer = KDTrainer(
    model,
    training_args,
    train_dataset=dset_train,
    eval_dataset=dset_eval,
    data_collator=data_collator,
    processing_class=tokenizer,
)

print("‚úÖ Trainer ready!")
print(f"\nüìä Training schedule:")
print(f"  Total samples: {len(dset_train):,}")
print(f"  Batch size: {training_args.per_device_train_batch_size}")
print(f"  Gradient accumulation: {training_args.gradient_accumulation_steps}")
print(f"  Effective batch size: {training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps}")
print(f"  Total steps: {training_args.max_steps:,}")
print(f"  Estimated time: ~{training_args.max_steps * 1.5 / 60:.0f} minutes")

Setting up trainer...

ModelOpt save/restore enabled for `transformers` library.
ModelOpt save/restore enabled for `diffusers` library.
ModelOpt save/restore enabled for `peft` library.


The model is already on multiple devices. Skipping the move to device specified in `args`.


‚úÖ Trainer ready!

üìä Training schedule:
  Total samples: 12,800
  Batch size: 2
  Gradient accumulation: 16
  Effective batch size: 32
  Total steps: 3,200
  Estimated time: ~80 minutes


### Start Training

This will train for 3,200 steps (~2 epochs).

**Note**: Training will take approximately 50-60 minutes on a single A100 GPU.

In [12]:
print("="*80)
print("STARTING TRAINING")
print("="*80)
print("\n‚è±Ô∏è  This will take approximately 50-60 minutes...\n")

# Start training
trainer.train()

print("\n" + "="*80)
print("‚úÖ TRAINING COMPLETE!")
print("="*80)

The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'eos_token_id': 128009, 'pad_token_id': 128009}.


STARTING TRAINING

‚è±Ô∏è  This will take approximately 50-60 minutes...



Step,Training Loss
50,32.5987
100,15.7592
150,7.1449
200,6.2123
250,5.8827
300,5.6601
350,5.5585
400,5.3847
450,4.3933
500,4.4146


Memory usage at training step 1, device=0: memory (MB) | allocated:  1.32e+04 | max_allocated:  1.79e+04 | reserved:  1.81e+04 | max_reserved:  1.81e+04
Saved ModelOpt state to ./llama3.2-1b-distilled-1gpu/checkpoint-500/modelopt_state.pth
Saved ModelOpt state to ./llama3.2-1b-distilled-1gpu/checkpoint-1000/modelopt_state.pth
Saved ModelOpt state to ./llama3.2-1b-distilled-1gpu/checkpoint-1500/modelopt_state.pth
Saved ModelOpt state to ./llama3.2-1b-distilled-1gpu/checkpoint-2000/modelopt_state.pth
Saved ModelOpt state to ./llama3.2-1b-distilled-1gpu/checkpoint-2500/modelopt_state.pth
Saved ModelOpt state to ./llama3.2-1b-distilled-1gpu/checkpoint-3000/modelopt_state.pth
Saved ModelOpt state to ./llama3.2-1b-distilled-1gpu/checkpoint-3200/modelopt_state.pth

‚úÖ TRAINING COMPLETE!


## üíæ Step 9: Save the Distilled Model

In [21]:
print("Saving model...\n")

# Save training state
trainer.save_state()

# Save the distilled student model
# Note: export_student parameter may not work with all ModelOpt versions
# If it fails, the model is already saved in checkpoints
try:
    # Try to export student model
    if hasattr(model, 'export'):
        exported_model = model.export()
        trainer.model = exported_model
        trainer.save_model(training_args.output_dir)
    else:
        # Fallback: save without export
        # The model checkpoints already contain the distilled weights
        print("‚ö†Ô∏è  Model doesn't have export() method")
        print("‚úÖ Using checkpoint-3200 as final model (already saved during training)")
        
        # Copy the latest checkpoint to output directory
        import shutil
        checkpoint_dir = f"{training_args.output_dir}/checkpoint-3200"
        if os.path.exists(checkpoint_dir):
            # The checkpoint is already there, just note it
            print(f"‚úÖ Final model available at: {checkpoint_dir}")
except Exception as e:
    print(f"‚ö†Ô∏è  Export failed: {e}")
    print("‚úÖ Model weights saved in checkpoints during training")

print(f"\n‚úÖ Model saved successfully!")
print(f"\nüìÇ Available checkpoints:")
for item in sorted(os.listdir(training_args.output_dir)):
    item_path = os.path.join(training_args.output_dir, item)
    if os.path.isdir(item_path) and item.startswith('checkpoint'):
        print(f"  - {item}")
        
print(f"\nüí° To use the model, load from the latest checkpoint:")
print(f"   AutoModelForCausalLM.from_pretrained('{training_args.output_dir}/checkpoint-3200')")

Saving model...

‚ö†Ô∏è  Model doesn't have export() method
‚úÖ Using checkpoint-3200 as final model (already saved during training)
‚úÖ Final model available at: ./llama3.2-1b-distilled-1gpu/checkpoint-3200

‚úÖ Model saved successfully!

üìÇ Available checkpoints:
  - checkpoint-1000
  - checkpoint-1500
  - checkpoint-2000
  - checkpoint-2500
  - checkpoint-3000
  - checkpoint-3200
  - checkpoint-500

üí° To use the model, load from the latest checkpoint:
   AutoModelForCausalLM.from_pretrained('./llama3.2-1b-distilled-1gpu/checkpoint-3200')


## üìà Step 10: Comprehensive Evaluation - Before & After Distillation

We'll evaluate three models to demonstrate the effectiveness of knowledge distillation:
1. **Teacher Model** (3B params) - Our target/reference
2. **Baseline Student** (1B params) - Before distillation (untrained)
3. **Distilled Student** (1B params) - After distillation (trained)

This comparison will show:
- How much the baseline student lags behind the teacher
- How much knowledge distillation helps close the gap
- The effectiveness of our training process

In [22]:
import math
import pandas as pd

print("="*80)
print("COMPREHENSIVE MODEL EVALUATION")
print("="*80)
print()

# Helper function to evaluate a model
def evaluate_model(model, model_name):
    """Evaluate a model and return metrics"""
    print(f"Evaluating {model_name}...")
    
    eval_trainer = transformers.Trainer(
        model,
        transformers.TrainingArguments(
            output_dir=training_args.output_dir,
            per_device_eval_batch_size=2,
            bf16=True,
        ),
        eval_dataset=dset_eval,
        data_collator=data_collator,
        processing_class=tokenizer,
    )
    
    results = eval_trainer.evaluate()
    perplexity = math.exp(results['eval_loss'])
    
    return {
        'loss': results['eval_loss'],
        'perplexity': perplexity,
        'runtime': results['eval_runtime'],
        'samples_per_sec': results['eval_samples_per_second']
    }

# Store results
results_dict = {}

# 1. Evaluate Teacher Model
print("\n" + "="*80)
print("1Ô∏è‚É£  TEACHER MODEL (3B - Target Performance)")
print("="*80)
results_dict['Teacher (3B)'] = evaluate_model(teacher_model, "Teacher")
print(f"‚úÖ Teacher Loss: {results_dict['Teacher (3B)']['loss']:.4f}")
print(f"‚úÖ Teacher Perplexity: {results_dict['Teacher (3B)']['perplexity']:.2f}")

# 2. Evaluate Baseline Student (before distillation)
print("\n" + "="*80)
print("2Ô∏è‚É£  BASELINE STUDENT (1B - Before Distillation)")
print("="*80)
print("Loading baseline student model (untrained)...\n")

baseline_student = AutoModelForCausalLM.from_pretrained(
    model_args.student_name_or_path,
    torch_dtype=torch.bfloat16,
    device_map={"":0}
)

results_dict['Baseline Student (1B)'] = evaluate_model(baseline_student, "Baseline Student")
print(f"‚úÖ Baseline Loss: {results_dict['Baseline Student (1B)']['loss']:.4f}")
print(f"‚úÖ Baseline Perplexity: {results_dict['Baseline Student (1B)']['perplexity']:.2f}")

# Free memory
del baseline_student
torch.cuda.empty_cache()

# 3. Evaluate Distilled Student (after training)
print("\n" + "="*80)
print("3Ô∏è‚É£  DISTILLED STUDENT (1B - After Distillation)")
print("="*80)
print("Loading distilled model...\n")

# Load from the latest checkpoint
distilled_model_path = f"{training_args.output_dir}/checkpoint-{training_args.max_steps}"
if not os.path.exists(distilled_model_path):
    # Fallback to the output directory itself
    distilled_model_path = training_args.output_dir
    
print(f"Loading from: {distilled_model_path}")

distilled_model = AutoModelForCausalLM.from_pretrained(
    distilled_model_path,
    torch_dtype=torch.bfloat16,
    device_map={"":0}
)

results_dict['Distilled Student (1B)'] = evaluate_model(distilled_model, "Distilled Student")
print(f"‚úÖ Distilled Loss: {results_dict['Distilled Student (1B)']['loss']:.4f}")
print(f"‚úÖ Distilled Perplexity: {results_dict['Distilled Student (1B)']['perplexity']:.2f}")

# Calculate improvements
print("\n" + "="*80)
print("üìä COMPARISON SUMMARY")
print("="*80)

# Create comparison table
df = pd.DataFrame(results_dict).T
df.index.name = 'Model'
df = df[['loss', 'perplexity', 'samples_per_sec']]
df.columns = ['Eval Loss', 'Perplexity', 'Samples/sec']

print("\n" + df.to_string())

# Calculate improvements
baseline_ppl = results_dict['Baseline Student (1B)']['perplexity']
distilled_ppl = results_dict['Distilled Student (1B)']['perplexity']
teacher_ppl = results_dict['Teacher (3B)']['perplexity']

improvement = ((baseline_ppl - distilled_ppl) / baseline_ppl) * 100
gap_closed = ((baseline_ppl - teacher_ppl) - (distilled_ppl - teacher_ppl)) / (baseline_ppl - teacher_ppl) * 100

print("\n" + "="*80)
print("üéØ KEY INSIGHTS")
print("="*80)
print(f"\n1. Performance Improvement:")
print(f"   Distillation improved perplexity by {improvement:.1f}%")
print(f"   (from {baseline_ppl:.2f} to {distilled_ppl:.2f})")

print(f"\n2. Gap to Teacher:")
print(f"   Baseline gap: {baseline_ppl - teacher_ppl:.2f} perplexity points")
print(f"   Distilled gap: {distilled_ppl - teacher_ppl:.2f} perplexity points")
print(f"   Knowledge distillation closed {gap_closed:.1f}% of the gap!")

print(f"\n3. Efficiency:")
print(f"   Distilled model has {student_params/1e9:.2f}B parameters")
print(f"   Teacher has {teacher_params/1e9:.2f}B parameters")
print(f"   Achieved {(1 - (distilled_ppl - teacher_ppl)/(baseline_ppl - teacher_ppl)) * 100:.1f}% of teacher quality")
print(f"   with only {(student_params/teacher_params)*100:.1f}% of parameters!")

print("\n" + "="*80)
print("‚úÖ Evaluation Complete!")
print("="*80)

The model is already on multiple devices. Skipping the move to device specified in `args`.


COMPREHENSIVE MODEL EVALUATION


1Ô∏è‚É£  TEACHER MODEL (3B - Target Performance)
Evaluating Teacher...




‚úÖ Teacher Loss: 1.6820
‚úÖ Teacher Perplexity: 5.38

2Ô∏è‚É£  BASELINE STUDENT (1B - Before Distillation)
Loading baseline student model (untrained)...

Evaluating Baseline Student...


The model is already on multiple devices. Skipping the move to device specified in `args`.


‚úÖ Baseline Loss: 2.3376
‚úÖ Baseline Perplexity: 10.36

3Ô∏è‚É£  DISTILLED STUDENT (1B - After Distillation)
Loading distilled model...

Loading from: ./llama3.2-1b-distilled-1gpu/checkpoint-3200
Restored ModelOpt state from ./llama3.2-1b-distilled-1gpu/checkpoint-3200/modelopt_state.pth


The model is already on multiple devices. Skipping the move to device specified in `args`.


Evaluating Distilled Student...


‚úÖ Distilled Loss: 1.8450
‚úÖ Distilled Perplexity: 6.33

üìä COMPARISON SUMMARY

                        Eval Loss  Perplexity  Samples/sec
Model                                                     
Teacher (3B)             1.682030    5.376461       21.560
Baseline Student (1B)    2.337583   10.356179       42.929
Distilled Student (1B)   1.844975    6.327943       43.541

üéØ KEY INSIGHTS

1. Performance Improvement:
   Distillation improved perplexity by 38.9%
   (from 10.36 to 6.33)

2. Gap to Teacher:
   Baseline gap: 4.98 perplexity points
   Distilled gap: 0.95 perplexity points
   Knowledge distillation closed 80.9% of the gap!

3. Efficiency:
   Distilled model has 1.24B parameters
   Teacher has 3.21B parameters
   Achieved 80.9% of teacher quality
   with only 38.5% of parameters!

‚úÖ Evaluation Complete!


## üß™ Step 11: Test Inference

Test the distilled model with sample prompts.

In [23]:
print("Testing distilled model inference...\n")

# Test prompt
test_messages = [
    {"role": "user", "content": "What is knowledge distillation and why is it useful?"}
]

test_prompt = tokenizer.apply_chat_template(
    test_messages, 
    tokenize=False, 
    add_generation_prompt=True
)

print("üìù Prompt:")
print(test_prompt)
print("\n" + "="*80)

# Tokenize input
inputs = tokenizer(test_prompt, return_tensors="pt").to(distilled_model.device)

# Generate response
print("Generating response...\n")
with torch.no_grad():
    outputs = distilled_model.generate(
        **inputs,
        max_new_tokens=200,
        temperature=0.7,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )

# Decode response
response = tokenizer.decode(outputs[0], skip_special_tokens=True)

print("ü§ñ Response:")
print("="*80)
print(response)
print("="*80)
print("\n‚úÖ Inference complete!")

Testing distilled model inference...

üìù Prompt:
<|begin_of_text|><|start_header_id|>system<|end_header_id|>

Cutting Knowledge Date: December 2023
Today Date: 13 Nov 2025

<|eot_id|><|start_header_id|>user<|end_header_id|>

What is knowledge distillation and why is it useful?<|eot_id|><|start_header_id|>assistant<|end_header_id|>



Generating response...

ü§ñ Response:
system

Cutting Knowledge Date: December 2023
Today Date: 13 Nov 2025

user

What is knowledge distillation and why is it useful?assistant

Knowledge distillation is the process of selectively extracting the most valuable and relevant information from a vast amount of knowledge sources, such as books, articles, research papers, and expert opinions, and distilling it into a concise and actionable form. The goal of knowledge distillation is to identify the most important insights, patterns, and principles that can be applied to a specific problem or issue.

The benefits of knowledge distillation include:

1. **Reducin

## üéâ Lab Complete!

### Summary

You have successfully:
1. ‚úÖ Configured a single GPU environment
2. ‚úÖ Loaded and preprocessed a conversational dataset
3. ‚úÖ Set up knowledge distillation between teacher and student models
4. ‚úÖ Trained a distilled 1B parameter model from a 3B teacher
5. ‚úÖ Evaluated the distilled model's performance
6. ‚úÖ Tested inference with the distilled model

### Key Takeaways

- **Model Compression**: Reduced from 3.2B to 1.2B parameters (2.6√ó smaller)
- **Knowledge Transfer**: Student learned from both ground truth and teacher predictions
- **Practical Skills**: Dataset preprocessing, distributed training setup, model evaluation

### Next Steps

1. **Compare Performance**: Benchmark the distilled model against the teacher
2. **Optimize Further**: Try different hyperparameters or longer training
3. **Deploy**: Use the smaller model for efficient inference
4. **Experiment**: Try different teacher-student pairs or datasets

---

**Happy distilling! üöÄ**