# Knowledge Distillation for HuggingFace Models - Single GPU

This notebook demonstrates how to perform **Knowledge Distillation** with TensorRT Model Optimizer on a single GPU, which makes it suitable for Google Colab.

## What is Knowledge Distillation?

Knowledge Distillation is a technique where a smaller "student" model learns to mimic a larger "teacher" model. The student learns from both:
1. **Ground truth labels** (standard training)
2. **Soft predictions from the teacher** (distillation)
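
The two signals above are often combined into a single weighted loss. The following is a minimal, dependency-free sketch of that idea for one token position. It is not ModelOpt's `LMLogitsLoss` implementation, and the `alpha` and `temperature` values are illustrative assumptions.

```python
import math


def softmax(logits, temperature=1.0):
    """Convert logits to probabilities, softened by a temperature."""
    scaled = [x / temperature for x in logits]
    peak = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def kl_divergence(p, q):
    """KL(p || q) for two discrete distributions over the same classes."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def distillation_loss(student_logits, teacher_logits, label, alpha=0.5, temperature=2.0):
    """Blend hard-label cross-entropy with soft-target KL divergence."""
    # 1. Ground truth term: cross-entropy of the student against the label
    hard = -math.log(softmax(student_logits)[label])
    # 2. Distillation term: KL between softened teacher and student distributions
    soft = kl_divergence(
        softmax(teacher_logits, temperature),
        softmax(student_logits, temperature),
    )
    # The temperature**2 factor keeps the soft term's gradient scale comparable
    return alpha * hard + (1 - alpha) * temperature**2 * soft


teacher_logits = [4.0, 1.0, 0.5]
student_logits = [2.0, 1.5, 0.2]
loss = distillation_loss(student_logits, teacher_logits, label=0)
print(f"combined loss: {loss:.4f}")
```

When the student's logits equal the teacher's, the distillation term is zero and only the hard-label term remains.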

### Example in this notebook:
- **Teacher**: Llama-3.2-3B-Instruct (3 billion parameters)
- **Student**: Llama-3.2-1B (1 billion parameters)
- **Result**: A smaller, faster model that carries over the teacher's knowledge

---

## 📦 Step 1: Install Dependencies

First, we need to install TensorRT Model Optimizer and other required packages.

In [None]:
# Install uv (a faster package installer) and nvitop (a GPU monitoring tool)
!pip install uv nvitop
# Install TensorRT Model Optimizer with HuggingFace support
!uv pip install -U nvidia-modelopt[hf]

# Remove the numpy and transformers versions installed above so they can be re-pinned
!uv pip uninstall numpy transformers
# Install additional dependencies with pinned version ranges
!uv pip install pyarrow 'transformers<5.0' 'trl>=0.23.0' 'numpy<2.0' bitsandbytes accelerate

Using Python 3.12.12 environment at: /usr
Resolved 91 packages in 607ms
Prepared 2 packages in 0.53ms
Uninstalled 2 packages in 343ms
Installed 2 packages in 253ms
 - numpy==1.26.4
 + numpy==2.3.4
 - transformers==4.57.1
 + transformers==4.56.2
Using Python 3.12.12 environment at: /usr
Uninstalled 2 packages in 263ms
 - numpy==2.3.4
 - transformers==4.56.2
Using Python 3.12.12 environment at: /usr
Resolved 68 packages in 71ms
Installed 2 packages in 105ms
 + numpy==1.26.4
 + transformers==4.57.1


## 🔧 Step 2: Check GPU Availability

Let's verify that we have a GPU available for training.

In [None]:
!nvidia-smi

Mon Oct 27 12:00:35 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  Tesla T4                       Off |   00000000:00:04.0 Off |                    0 |
| N/A   51C    P8             12W /   70W |       0MiB /  15360MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

## 📚 Step 3: Import Libraries

Import all necessary libraries for knowledge distillation.

In [None]:
import os
from dataclasses import dataclass

# Optimize CUDA memory allocation
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

import datasets
import torch
import transformers
from transformers import AutoTokenizer, AutoModelForCausalLM
from trl import SFTTrainer

# TensorRT Model Optimizer imports
import modelopt.torch.distill as mtd
import modelopt.torch.opt as mto
from modelopt.torch.distill.plugins.huggingface import KDTrainer, LMLogitsLoss

print("✓ All libraries imported successfully!")
print(f"==> PyTorch version: {torch.__version__}")
print(f"==> Transformers version: {transformers.__version__}")
print(f"==> CUDA available: {torch.cuda.is_available()}")

✓ All libraries imported successfully!
==> PyTorch version: 2.9.0+cu128
==> Transformers version: 4.57.1
==> CUDA available: True




## ⚙️ Step 4: Configuration

Set up the configuration for models and training hyperparameters.

### 📝 You can modify these settings:
- **Models**: Change teacher/student models
- **Batch size**: Adjust based on your GPU memory
- **Training steps**: Increase for better results (will take longer)
- **Learning rate**: Fine-tune the learning process
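
To see how batch size and training steps interact, the sketch below computes the effective batch size and the fraction of an epoch that a run covers. The helper names are hypothetical. The numbers are this notebook's defaults (batch size 1, no gradient accumulation, 200 steps) and its 12,800-sample training split.

```python
def effective_batch_size(per_device_batch, grad_accum_steps, num_devices=1):
    """Number of samples that contribute to each optimizer update."""
    return per_device_batch * grad_accum_steps * num_devices


def epochs_covered(max_steps, batch, dataset_size):
    """Fraction of the training set seen after max_steps optimizer updates."""
    return max_steps * batch / dataset_size


batch = effective_batch_size(per_device_batch=1, grad_accum_steps=1)
print(batch)                               # 1
print(epochs_covered(200, batch, 12_800))  # 0.015625
```

With these defaults the student sees under 2% of the training split, so raising `max_steps` or `gradient_accumulation_steps` exposes it to considerably more data.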

In [None]:
@dataclass
class ModelArguments:
    """Model Configuration"""
    # Teacher: Larger model we distill FROM
    teacher_name_or_path: str = "meta-llama/Llama-3.2-3B-Instruct"

    # Student: Smaller model we distill TO
    student_name_or_path: str = "meta-llama/Llama-3.2-1B"


@dataclass
class TrainingArguments(transformers.TrainingArguments):
    """Training Configuration"""
    output_dir: str = "./llama3.2-1b-distilled"
    do_train: bool = True
    do_eval: bool = True
    save_strategy: str = "steps"
    save_steps: int = 100
    max_length: int = 512

    # Optimizer settings
    optim: str = "adamw_torch"
    learning_rate: float = 1e-5
    lr_scheduler_type: str = "cosine"

    # Data processing
    dataloader_drop_last: bool = True
    dataset_num_proc: int = 4

    # Mixed precision (faster training, less memory)
    bf16: bool = True
    tf32: bool = False

    # Batch size - ADJUST based on your GPU memory!
    per_device_train_batch_size: int = 1
    per_device_eval_batch_size: int = 1
    gradient_accumulation_steps: int = 1  # Effective batch size = 1 * 1 = 1; raise this to simulate a larger batch

    # Training duration
    max_steps: int = 200  # Increase for better results (e.g., 500, 1000)

    # Logging
    logging_steps: int = 5
    eval_steps: int = 50
    warmup_steps: int = 10
    report_to: str = "none" # Disable wandb reporting


# Create configuration instances
model_args = ModelArguments()
training_args = TrainingArguments(output_dir="./llama3.2-1b-distilled")

print("Configuration:")
print(f"  Teacher: {model_args.teacher_name_or_path}")
print(f"  Student: {model_args.student_name_or_path}")
print(f"  Batch size: {training_args.per_device_train_batch_size}")
print(f"  Gradient accumulation: {training_args.gradient_accumulation_steps}")
print(f"  Effective batch size: {training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps}")
print(f"  Max steps: {training_args.max_steps}")
print(f"  Learning rate: {training_args.learning_rate}")

Configuration:
  Teacher: meta-llama/Llama-3.2-3B-Instruct
  Student: meta-llama/Llama-3.2-1B
  Batch size: 1
  Gradient accumulation: 1
  Effective batch size: 1
  Max steps: 200
  Learning rate: 1e-05


## 📊 Step 5: Load Dataset

We'll use the **smol-smoltalk-Interaction-SFT** dataset, which contains conversational query-answer pairs.

In [None]:
print("Loading dataset...")

# Load the dataset from HuggingFace
dset = datasets.load_dataset("ReactiveAI/smol-smoltalk-Interaction-SFT", split="train")

# Split into training and evaluation sets
dset_splits = dset.train_test_split(train_size=12800, test_size=1280, seed=420)
dset_train, dset_eval = dset_splits["train"], dset_splits["test"]

print("✓ Dataset loaded!")
print(f"  Training samples: {len(dset_train):,}")
print(f"  Evaluation samples: {len(dset_eval):,}")
print(f"\nSample data:")
print(dset_train[0])

Loading dataset...


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


✓ Dataset loaded!
  Training samples: 12,800
  Evaluation samples: 1,280

Sample data:
{'query': 'What are Data visualization types.', 'answer': 'Data visualization types are diverse and can be categorized based on their purpose, structure, and functionality. Here are some common data visualization types:\n\n**Basic Visualization Types:**\n\n1. Bar charts: Used to compare categorical data across different groups.\n2. Line charts: Used to show trends and patterns over time.\n3. Pie charts: Used to represent proportional data.\n4. Histograms: Used to display the distribution of continuous data.\n5. Scatter plots: Used to visualize relationships between two variables.\n\n**Advanced Visualization Types:**\n\n1. Heat maps: Used to display complex relationships between two variables.\n2. Tree maps: Used to display hierarchical data.\n3. Network diagrams: Used to show relationships between entities.\n4. Sankey diagrams: Used to display flows and relationships between variables.\n5. Gauge ch

## 🔤 Step 6: Load Tokenizer

Load the tokenizer to convert text into tokens that the model can understand.
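
The cell below also sets `padding_side = "right"`. As a rough illustration of what right padding means for a batch, here is a dependency-free sketch. `pad_right` is a hypothetical helper and the token ids are made up. The real tokenizer does this internally.

```python
def pad_right(sequences, pad_id):
    """Pad every sequence on the right to the length of the longest one."""
    longest = max(len(seq) for seq in sequences)
    input_ids, attention_mask = [], []
    for seq in sequences:
        n_pad = longest - len(seq)
        input_ids.append(seq + [pad_id] * n_pad)
        # 1 marks real tokens, 0 marks padding the model should ignore
        attention_mask.append([1] * len(seq) + [0] * n_pad)
    return input_ids, attention_mask


ids, mask = pad_right([[11, 12, 13], [21]], pad_id=0)
print(ids)   # [[11, 12, 13], [21, 0, 0]]
print(mask)  # [[1, 1, 1], [1, 0, 0]]
```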

In [None]:
print("Loading tokenizer...")

model_path = model_args.teacher_name_or_path

# Log in to the Hugging Face Hub with the token stored in Colab secrets
from huggingface_hub import login
from google.colab import userdata

try:
    hf_token = userdata.get("HF_TOKEN")  # Read the Colab secret named HF_TOKEN
    login(token=hf_token)
    print("✓ Successfully logged in to Hugging Face Hub!")
except Exception as e:
    print(f"Error logging in to Hugging Face Hub: {e}")
    print("Please make sure you have added your HF_TOKEN to Colab secrets.")


tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)

# Configure padding
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

print(f"✓ Tokenizer loaded from {model_path}")
print(f"  Vocab size: {len(tokenizer):,}")
print(f"  Pad token: '{tokenizer.pad_token}'")

Loading tokenizer...
✓ Successfully logged in to Hugging Face Hub!
✓ Tokenizer loaded from meta-llama/Llama-3.2-3B-Instruct
  Vocab size: 128,256
  Pad token: '<|eot_id|>'


## 🎯 Step 7: Define Data Formatting Function

This function formats our dataset samples into the chat template format.
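
To give a sense of what a rendered conversation looks like, here is a simplified, hypothetical stand-in for `tokenizer.apply_chat_template(..., tokenize=False)` that uses Llama 3 style special tokens. The real template ships with the tokenizer and may add more, such as the system header visible in the generation output at the end of this notebook. Treat this as an approximation only.

```python
def render_llama3_style(messages):
    """Render a list of chat messages as one Llama 3 style string (simplified)."""
    text = "<|begin_of_text|>"
    for message in messages:
        # Each turn: a role header, a blank line, the content, then an end-of-turn token
        text += f"<|start_header_id|>{message['role']}<|end_header_id|>\n\n"
        text += f"{message['content']}<|eot_id|>"
    return text


sample = {"query": "What is a bar chart?", "answer": "A chart that compares categories."}
rendered = render_llama3_style(
    [
        {"role": "user", "content": sample["query"]},
        {"role": "assistant", "content": sample["answer"]},
    ]
)
print(rendered)
```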

In [None]:
def _format_smoltalk_chat_template(sample, tokenizer):
    """
    Convert dataset sample into chat format.

    Args:
        sample: Dataset sample with 'query' and 'answer' fields
        tokenizer: Tokenizer with chat template

    Returns:
        Formatted conversation string
    """
    messages = [
        {"role": "user", "content": sample["query"]},
        {"role": "assistant", "content": sample["answer"]},
    ]
    return tokenizer.apply_chat_template(messages, tokenize=False)

print("✓ Data formatting function defined")

✓ Data formatting function defined


## 🤖 Step 8: Load Student Model

Load the smaller student model that will learn from the teacher.
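
Before loading, it can help to estimate how much GPU memory the weights alone will occupy. The sketch below uses a hypothetical helper and the student parameter count reported by this notebook's run. Training needs additional memory for gradients, optimizer state, and activations, so treat the result as a lower bound.

```python
def weight_memory_gib(num_params, bits_per_param):
    """Memory needed to store the weights alone, in GiB."""
    return num_params * bits_per_param / 8 / 1024**3


student_params = 1_235_814_400  # parameter count reported by this notebook's run
for label, bits in [("float32", 32), ("bfloat16", 16)]:
    print(f"{label}: {weight_memory_gib(student_params, bits):.2f} GiB")
```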

In [None]:
print(f"Loading student model: {model_args.student_name_or_path}")
print("This may take a few minutes...")

student_model = AutoModelForCausalLM.from_pretrained(
    model_args.student_name_or_path,
    # `torch_dtype` still works here but triggers a deprecation warning in favor of `dtype`
    torch_dtype=torch.bfloat16 if training_args.bf16 else torch.float32,
    device_map="auto"
)

student_params = sum(p.numel() for p in student_model.parameters())
print("\n✓ Student model loaded!")
print(f"  Parameters: {student_params:,} ({student_params/1e9:.2f}B)")
print(f"  Device: {next(student_model.parameters()).device}")

Loading student model: meta-llama/Llama-3.2-1B
This may take a few minutes...


`torch_dtype` is deprecated! Use `dtype` instead!



✓ Student model loaded!
  Parameters: 1,235,814,400 (1.24B)
  Device: cuda:0


## 👨‍🏫 Step 9: Load Teacher Model & Configure Distillation

Load the larger teacher model and set up knowledge distillation.
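
The recorded output for this step reports 1,803,463,680 teacher parameters and includes a `load_in_4bit` deprecation warning, which suggests the run was made with 4-bit loading enabled. The sketch below checks that reading. It assumes the published Llama-3.2-3B dimensions listed in the code, tied input and output embeddings, and that 4-bit loading packs two weights per stored element in the linear layers while leaving embeddings and norms unquantized.

```python
# Assumed Llama-3.2-3B dimensions (taken from the published model config)
VOCAB, HIDDEN, INTERMEDIATE, LAYERS = 128_256, 3_072, 8_192, 28
KV_DIM = 8 * 128  # 8 key/value heads of dimension 128

attention = 2 * HIDDEN * HIDDEN + 2 * HIDDEN * KV_DIM  # q and o, plus k and v projections
mlp = 3 * HIDDEN * INTERMEDIATE                        # gate, up, and down projections
linear = LAYERS * (attention + mlp)
norms = LAYERS * 2 * HIDDEN + HIDDEN                   # two norms per layer plus the final norm
embeddings = VOCAB * HIDDEN                            # shared with the output head

full_count = linear + norms + embeddings
packed_count = linear // 2 + norms + embeddings        # two 4-bit weights per stored element

print(f"full precision: {full_count:,}")
print(f"4-bit packed:   {packed_count:,}")
```

Under these assumptions the packed count matches the 1,803,463,680 shown in the output, so the printed "Compression ratio: 1.46x" reflects packed storage. Measured against the unpacked teacher, the student is about 2.60x smaller.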

In [None]:
print(f"Loading teacher model: {model_args.teacher_name_or_path}")
print("This may take a few minutes...")

teacher_model = AutoModelForCausalLM.from_pretrained(
    model_args.teacher_name_or_path,
    torch_dtype=torch.bfloat16 if training_args.bf16 else torch.float32,
    # load_in_4bit=True,  # Uncomment for 4-bit quantization (deprecated; prefer passing a BitsAndBytesConfig via quantization_config)
    device_map="auto"
)

teacher_params = sum(p.numel() for p in teacher_model.parameters())
print("\n✓ Teacher model loaded!")
print(f"  Parameters: {teacher_params:,} ({teacher_params/1e9:.2f}B)")
print(f"  Device: {next(teacher_model.parameters()).device}")
print(f"  Compression ratio: {teacher_params/student_params:.2f}x")

# Configure Knowledge Distillation
print("\nConfiguring Knowledge Distillation...")
kd_config = {
    "teacher_model": teacher_model,
    "criterion": LMLogitsLoss(),  # KL-divergence on logits
}

# Enable ModelOpt checkpointing
mto.enable_huggingface_checkpointing()

# Convert student to distillation model
model = mtd.convert(student_model, mode=[("kd_loss", kd_config)])

# Fix generation config warnings
model.generation_config.temperature = None
model.generation_config.top_p = None

print("✓ Distillation configured!")
print("  Loss function: LMLogitsLoss (KL-divergence)")
print("  Student will learn from:")
print("    1. Ground truth labels")
print("    2. Teacher's predictions")

# Check memory usage
if torch.cuda.is_available():
    allocated = torch.cuda.memory_allocated(0) / 1024**3
    reserved = torch.cuda.memory_reserved(0) / 1024**3
    print(f"\nGPU Memory:")
    print(f"  Allocated: {allocated:.2f} GB")
    print(f"  Reserved: {reserved:.2f} GB")

Loading teacher model: meta-llama/Llama-3.2-3B-Instruct
This may take a few minutes...


The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead.




✓ Teacher model loaded!
  Parameters: 1,803,463,680 (1.80B)
  Device: cuda:0
  Compression ratio: 1.46x

Configuring Knowledge Distillation...
ModelOpt save/restore enabled for `transformers` library.
ModelOpt save/restore enabled for `diffusers` library.
ModelOpt save/restore enabled for `peft` library.
✓ Distillation configured!
  Loss function: LMLogitsLoss (KL-divergence)
  Student will learn from:
    1. Ground truth labels
    2. Teacher's predictions

GPU Memory:
  Allocated: 4.51 GB
  Reserved: 5.34 GB


## 🏋️ Step 10: Create Custom Trainer

Define a custom trainer that combines supervised fine-tuning with knowledge distillation.
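
The combination relies on Python's multiple inheritance and method resolution order (MRO). The toy classes below are hypothetical stand-ins, not the real trainers. They only illustrate how listing two parents determines which class's methods are looked up first.

```python
class BaseTrainer:
    def compute_loss(self):
        return ["base"]


class SFTLike(BaseTrainer):
    def compute_loss(self):
        # super() hands control to the next class in the MRO
        return ["sft"] + super().compute_loss()


class KDLike(BaseTrainer):
    def compute_loss(self):
        return ["kd"] + super().compute_loss()


class Combined(SFTLike, KDLike):
    """Listed first, SFTLike takes precedence over KDLike."""


print([cls.__name__ for cls in Combined.__mro__])
# ['Combined', 'SFTLike', 'KDLike', 'BaseTrainer', 'object']
print(Combined().compute_loss())
# ['sft', 'kd', 'base']
```

Whether the real trainers chain their methods through `super()` in this way depends on their implementations, which this sketch does not inspect.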

In [None]:
class KDSFTTrainer(SFTTrainer, KDTrainer):
    """
    Combined Knowledge Distillation + Supervised Fine-Tuning Trainer.

    Inherits from:
    - SFTTrainer: Supervised fine-tuning logic
    - KDTrainer: Knowledge distillation logic
    """
    pass

## 🎓 Step 11: Initialize Trainer

Set up the trainer with our models, datasets, and configuration.

In [None]:
print("Initializing trainer...")

trainer = KDSFTTrainer(
    model,
    training_args,
    train_dataset=dset_train,
    eval_dataset=dset_eval,
    formatting_func=lambda sample: _format_smoltalk_chat_template(sample, tokenizer),
    processing_class=tokenizer,
)

print("✓ Trainer initialized!")
print(f"  Training steps: {training_args.max_steps}")
print(f"  Checkpoints: {training_args.output_dir}")

Initializing trainer...



ModelOpt save/restore enabled for `transformers` library.
ModelOpt save/restore enabled for `diffusers` library.
ModelOpt save/restore enabled for `peft` library.


The model is already on multiple devices. Skipping the move to device specified in `args`.


✓ Trainer initialized!
  Training steps: 200
  Checkpoints: ./llama3.2-1b-distilled


## 🚀 Step 12: Start Training!

Now we train the student model with knowledge distillation.

**This will take some time!** Monitor the loss values:
- **loss**: Combined loss (should decrease)
- Lower loss = better learning
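
With a batch size of 1 the logged loss is noisy, so a moving average makes the trend easier to judge. The sketch below applies a hypothetical `moving_average` helper to the ten loss values recorded in this notebook's run (steps 5 through 50).

```python
def moving_average(values, window):
    """Average each value with the (window - 1) values that precede it."""
    return [
        sum(values[i - window + 1 : i + 1]) / window
        for i in range(window - 1, len(values))
    ]


# Training losses logged every 5 steps in this notebook's recorded run
losses = [1.3413, 1.2372, 1.2168, 1.0857, 0.8748, 1.1062, 0.9413, 0.8246, 0.7382, 0.834]
smoothed = moving_average(losses, window=5)
print([round(value, 4) for value in smoothed])
```

The smoothed series falls from roughly 1.15 to roughly 0.89, even though individual steps (for example step 30) move upward.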

In [None]:
print("="*80)
print("STARTING TRAINING")
print("="*80)
print("The student is now learning from the teacher...\n")

# Train!
trainer.train()

print("\n✓ Training completed!")

The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'eos_token_id': 128009, 'pad_token_id': 128009}.


STARTING TRAINING
The student is now learning from the teacher...



| Step | Training Loss |
|------|---------------|
| 5    | 1.3413        |
| 10   | 1.2372        |
| 15   | 1.2168        |
| 20   | 1.0857        |
| 25   | 0.8748        |
| 30   | 1.1062        |
| 35   | 0.9413        |
| 40   | 0.8246        |
| 45   | 0.7382        |
| 50   | 0.834         |


Memory usage at training step 1, device=0: memory (MB) | allocated:  9.35e+03 | max_allocated:  1.41e+04 | reserved:  1.42e+04 | max_reserved:  1.42e+04
Saved ModelOpt state to ./llama3.2-1b-distilled/checkpoint-100/modelopt_state.pth
Saved ModelOpt state to ./llama3.2-1b-distilled/checkpoint-200/modelopt_state.pth

✓ Training completed!


## 📈 Step 13: Evaluate the Model

Evaluate the trained student model on the test set.
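
One common way to read an evaluation loss is to convert it to perplexity. The sketch below does so for the `eval_loss` recorded in this notebook's run. It assumes that value is a mean per-token cross-entropy in nats, which this notebook does not confirm.

```python
import math


def perplexity(mean_cross_entropy):
    """Perplexity is the exponential of the mean per-token cross-entropy (in nats)."""
    return math.exp(mean_cross_entropy)


eval_loss = 2.0978  # value recorded in this notebook's evaluation output
ppl = perplexity(eval_loss)
print(f"perplexity: {ppl:.2f}")
```

Lower perplexity means the model assigns higher probability to the reference tokens, so comparing this value before and after distillation indicates whether training helped.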

In [None]:
print("="*80)
print("RUNNING EVALUATION")
print("="*80)

eval_results = trainer.evaluate()

print("\nEvaluation Results:")
for key, value in eval_results.items():
    if isinstance(value, float):
        print(f"  {key}: {value:.4f}")
    else:
        print(f"  {key}: {value}")

print("\n✓ Evaluation complete!")

RUNNING EVALUATION





Evaluation Results:
  eval_loss: 2.0978
  eval_runtime: 2116.9910
  eval_samples_per_second: 0.6050
  eval_steps_per_second: 0.6050
  eval_entropy: 1.8538
  eval_num_tokens: 70365.0000
  eval_mean_token_accuracy: 0.5848
  epoch: 0.0156

✓ Evaluation complete!


## 💾 Step 14: Save the Distilled Model

Save the trained student model for later use.

In [None]:
print("="*80)
print("SAVING MODEL")
print("="*80)

# Save training state
trainer.save_state()

# Save the student model (removes distillation wrapper)
trainer.save_model(training_args.output_dir, export_student=True)

print(f"✓ Model saved to: {training_args.output_dir}")
print("\nYou can now load the model with:")
print(f"  model = AutoModelForCausalLM.from_pretrained('{training_args.output_dir}')")

SAVING MODEL
Saved ModelOpt state to ./llama3.2-1b-distilled/modelopt_state.pth
✓ Model saved to: ./llama3.2-1b-distilled

You can now load the model with:
  model = AutoModelForCausalLM.from_pretrained('./llama3.2-1b-distilled')


## 🎉 Step 15: Test the Distilled Model (Optional)

Let's try generating some text with our newly trained model!
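
For decoder-only models, `generate` returns the prompt tokens followed by the newly generated ones, so decoding the full sequence echoes the prompt. The sketch below shows the slicing idea on plain lists with made-up token ids. With real tensors the equivalent would be along the lines of `outputs[0][inputs["input_ids"].shape[1]:]`.

```python
def strip_prompt(output_ids, prompt_length):
    """Keep only the tokens generated after the prompt."""
    return output_ids[prompt_length:]


prompt_ids = [101, 7, 8, 9]            # made-up ids standing in for the tokenized prompt
output_ids = prompt_ids + [42, 43, 44]  # generate() output: prompt followed by new tokens
new_ids = strip_prompt(output_ids, len(prompt_ids))
print(new_ids)  # [42, 43, 44]
```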

In [None]:
print("Testing the distilled model...\n")

# Load the saved student checkpoint from the output directory.
# (Calling mtd.export(model) on the in-memory model raised an error here, so we reload from disk instead.)
inference_model = AutoModelForCausalLM.from_pretrained(training_args.output_dir)

# Prepare a test prompt
test_messages = [
    {"role": "user", "content": "What is knowledge distillation?"}
]
test_prompt = tokenizer.apply_chat_template(test_messages, tokenize=False, add_generation_prompt=True)

# Tokenize
inputs = tokenizer(test_prompt, return_tensors="pt").to(inference_model.device)

# Generate
print("Generating response...\n")
with torch.no_grad():
    outputs = inference_model.generate(
        **inputs,
        max_new_tokens=150,
        temperature=0.7,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id
    )

# Decode and print
response = tokenizer.decode(outputs[0], skip_special_tokens=True)
print("Model Response:")
print("="*80)
print(response)
print("="*80)
print("\n✓ Inference test complete!")

Testing the distilled model...

Restored ModelOpt state from ./llama3.2-1b-distilled/modelopt_state.pth
Generating response...

Model Response:
system

Cutting Knowledge Date: December 2023
Today Date: 27 Oct 2025

user

What is knowledge distillation?assistant

Knowledge distillation is a process of taking raw data and transforming it into a distilled, distilled, distilled version of the original information. It's a method of creating a concise, focused, and accurate summary of a larger body of knowledge. The goal of knowledge distillation is to extract the most valuable lessons from the original data, which can be used to improve decision-making, problem-solving, and overall understanding.

Here's a step-by-step guide on how knowledge distillation works:

1. **Data collection:** This is where the raw, original data is collected. This could be from a survey, a report, or a dataset.
2. **Analysis:** The data is analyzed to identify the most important insights, patterns, and relationshi

## 🎊 Congratulations!

You've successfully completed knowledge distillation! 🎉

### What you've accomplished:
✅ Loaded a large teacher model (3B parameters)  
✅ Loaded a small student model (1B parameters)  
✅ Configured knowledge distillation with TensorRT Model Optimizer  
✅ Trained the student to learn from the teacher  
✅ Saved a smaller, faster model with the teacher's knowledge  

### Next Steps:
- **Fine-tune further**: Increase `max_steps` for better results
- **Try different models**: Change teacher/student in the configuration
- **Use your own data**: Replace the dataset with your own
- **Deploy the model**: Use the saved model for inference

### Resources:
- [TensorRT Model Optimizer Docs](https://nvidia.github.io/TensorRT-Model-Optimizer/)
- [HuggingFace Transformers](https://huggingface.co/docs/transformers)
- [TRL Library](https://huggingface.co/docs/trl)

---
**Happy distilling! 🚀**