# Model Distillation for Vision Models

**Model distillation** is a knowledge transfer technique in which a small "student" model is trained to reproduce the outputs of a larger, pre-trained "teacher" model. The goal is a compact, fast student whose accuracy comes close to the teacher's at a fraction of the size and inference cost.

We will use:
- **Teacher Model**: A `ResNet-18` from `torchvision`, pre-trained on ImageNet, with its final layer replaced for CIFAR-10's 10 classes.
- **Student Model**: A simple, custom-built Convolutional Neural Network (CNN).
- **Dataset**: `CIFAR-10`.
- **Monitoring**: `TensorBoard` for logging training progress.

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

# Set device for training
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


## Define the Student Model

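This is the small, lightweight model we want to train: a CNN with two convolutional layers and two fully connected layers, far smaller than the ResNet-18 teacher. The `32 * 8 * 8` input size of `fc1` assumes 32x32 images (the CIFAR-10 size), which the two 2x2 max-pooling layers reduce to 8x8.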

In [2]:
class StudentNet(nn.Module):
    def __init__(self):
        super(StudentNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, 3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, 3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(32 * 8 * 8, 128)
        self.fc2 = nn.Linear(128, 10) # 10 classes for CIFAR-10

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1)  # (N, 32, 8, 8) -> (N, 2048)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

## Load Data and Models

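We'll load the CIFAR-10 dataset, convert the images to tensors, and normalize each channel to the range [-1, 1]. Then we'll load the ResNet-18 teacher, replace its final layer with a 10-class head, and instantiate our StudentNet. Note that the new head is randomly initialized: only the backbone is pre-trained (on ImageNet), so the teacher has to be fine-tuned on CIFAR-10 before its outputs are worth distilling.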

In [3]:
# Data transformations
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# Load CIFAR-10 dataset
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=128, shuffle=False, num_workers=2)

# Load the pre-trained teacher model (ResNet-18)
teacher_model = torchvision.models.resnet18(weights=torchvision.models.ResNet18_Weights.IMAGENET1K_V1)
# Modify the final layer for CIFAR-10 (10 classes)
num_ftrs = teacher_model.fc.in_features
teacher_model.fc = nn.Linear(num_ftrs, 10)
teacher_model = teacher_model.to(device)
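# NOTE: the new fc layer is randomly initialized, so as written the teacher has
# NOT been trained on CIFAR-10 and its logits are essentially untrained for this task.
# In a real scenario, fine-tune the teacher on CIFAR-10 before distillation.
teacher_model.eval() # Set teacher to evaluation mode (fixed BatchNorm statistics, no dropout)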

# Instantiate the student model
student_model = StudentNet().to(device)

## Define the Distillation Loss Function

The loss is a weighted average of two components:

1.  **Student Loss**: The standard cross-entropy loss between the student's predictions and the true labels. This teaches the student to predict the correct class.
2.  **Distillation Loss**: The Kullback-Leibler (KL) Divergence loss between the softened outputs of the teacher and the student. This encourages the student to mimic the teacher's probability distribution over classes, learning not just *what* the correct class is, but also *how* the teacher relates different classes to each other.
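
Written out, with $z_s$ and $z_t$ the student and teacher logits, $y$ the true label, $T$ the `temperature`, $\alpha$ the weight `alpha`, and $\sigma$ the softmax, the loss computed by `distillation_loss` below is:

$$
\mathcal{L} = \alpha\, T^2\, \mathrm{KL}\!\left(\sigma(z_t/T) \,\|\, \sigma(z_s/T)\right) + (1-\alpha)\, \mathrm{CE}\!\left(y,\, \sigma(z_s)\right),
\qquad
\mathrm{KL}(p \,\|\, q) = \sum_{c} p_c \log\frac{p_c}{q_c}.
$$

`F.kl_div` takes the student's log-probabilities as its input and the teacher's probabilities as its target, which yields $\mathrm{KL}(\text{teacher} \,\|\, \text{student})$; `reduction='batchmean'` averages it over the batch, matching the mean taken by `F.cross_entropy`.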

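The `temperature` parameter T divides the logits before the softmax. With T > 1 the distributions become less peaked, so the small probabilities the teacher assigns to the incorrect classes carry more signal for the student. Because softening shrinks the gradients of the KL term by roughly a factor of 1/T², the code multiplies that term by `temperature ** 2` to keep its scale comparable to the cross-entropy term.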
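
A small illustration of the softening, using made-up logits for a single image over four classes (the numbers are arbitrary and chosen only for this example); `entropy` is a helper defined here to quantify how flat each distribution is.

```python
import torch
import torch.nn.functional as F

def entropy(probs):
    """Shannon entropy (in nats) of each row of a probability matrix."""
    return -(probs * probs.log()).sum(dim=1)

# Illustrative logits for one image over 4 classes (made-up numbers).
logits = torch.tensor([[8.0, 4.0, 2.0, 0.0]])

results = {}
for T in (1.0, 5.0):
    probs = F.softmax(logits / T, dim=1)
    results[T] = probs
    print(f"T={T}: probs={[round(p, 3) for p in probs.squeeze().tolist()]}, "
          f"entropy={entropy(probs).item():.3f} nats")
```

At `T=1` nearly all the probability sits on the top class; at `T=5` the class ranking is unchanged, but the runner-up classes receive noticeable probability, which is the extra information the KL term passes to the student.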

In [4]:
def distillation_loss(student_logits, teacher_logits, true_labels, temperature, alpha):
    """
    Calculates the distillation loss.
    
    Args:
        student_logits: Logits from the student model.
        teacher_logits: Logits from the teacher model.
        true_labels: The ground truth labels.
        temperature (float): Softening parameter for the probability distributions.
        alpha (float): Weight for the distillation (KL) term; the cross-entropy
            term is weighted by (1 - alpha).

    Returns:
        A scalar tensor with the combined loss.
    """
    # Distillation loss (KL divergence)
    soft_teacher_probs = F.softmax(teacher_logits / temperature, dim=1)
    soft_student_log_probs = F.log_softmax(student_logits / temperature, dim=1)
    distill_loss = F.kl_div(soft_student_log_probs, soft_teacher_probs, reduction='batchmean') * (temperature ** 2)

    # Standard cross-entropy loss with true labels
    student_loss = F.cross_entropy(student_logits, true_labels)

    # Combine the two losses
    total_loss = alpha * distill_loss + (1. - alpha) * student_loss
    return total_loss
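
A few quick checks can confirm the function behaves as the formula says before it is used in training. The sketch copies `distillation_loss` so it runs on its own and uses random logits: with `alpha=0` it should reduce to plain cross-entropy, with `alpha=1` and identical student and teacher logits the KL term should vanish, and for arbitrary logits the KL term should be positive. Setting `alpha=0` is also a convenient way to train a non-distilled baseline student for comparison.

```python
import torch
import torch.nn.functional as F

# Copy of distillation_loss from the cell above so this check is self-contained.
def distillation_loss(student_logits, teacher_logits, true_labels, temperature, alpha):
    soft_teacher_probs = F.softmax(teacher_logits / temperature, dim=1)
    soft_student_log_probs = F.log_softmax(student_logits / temperature, dim=1)
    distill_loss = F.kl_div(soft_student_log_probs, soft_teacher_probs,
                            reduction='batchmean') * (temperature ** 2)
    student_loss = F.cross_entropy(student_logits, true_labels)
    return alpha * distill_loss + (1. - alpha) * student_loss

torch.manual_seed(0)
student_logits = torch.randn(8, 10)
teacher_logits = torch.randn(8, 10)
labels = torch.randint(0, 10, (8,))

# 1) alpha = 0 reduces to ordinary cross-entropy on the true labels.
ce_only = distillation_loss(student_logits, teacher_logits, labels, temperature=5.0, alpha=0.0)
print(torch.allclose(ce_only, F.cross_entropy(student_logits, labels)))  # True

# 2) alpha = 1 and a student identical to the teacher: the KL term is ~0.
kl_same = distillation_loss(teacher_logits, teacher_logits, labels, temperature=5.0, alpha=1.0)
print(kl_same.item())  # ~0, up to floating-point error

# 3) For different logits the KL term is strictly positive.
kl_diff = distillation_loss(student_logits, teacher_logits, labels, temperature=5.0, alpha=1.0)
print(kl_diff.item() > 0)  # True
```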

## Training the Student Model

Now we'll write the training loop. In each step we run both the teacher and the student on the same batch, compute the `distillation_loss`, and update the student's weights. The teacher stays fixed: its forward pass runs under `torch.no_grad()`, and the optimizer holds only the student's parameters.

We will also evaluate the model on the test set at the end of each epoch and log the results to TensorBoard.
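
The claim that only the student changes can be checked in isolation. The sketch below factors one loop iteration into a hypothetical `distill_step` helper and runs it on tiny linear stand-in models (not the notebook's ResNet-18 and StudentNet) with random inputs, then compares parameters before and after the update.

```python
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F

def distill_step(student, teacher, optimizer, inputs, labels, temperature, alpha):
    """Hypothetical helper: one distillation update, mirroring the loop body."""
    student.train()
    teacher.eval()
    optimizer.zero_grad()
    student_logits = student(inputs)
    with torch.no_grad():  # no autograd graph is built for the teacher
        teacher_logits = teacher(inputs)
    soft_teacher = F.softmax(teacher_logits / temperature, dim=1)
    soft_student = F.log_softmax(student_logits / temperature, dim=1)
    kd = F.kl_div(soft_student, soft_teacher, reduction='batchmean') * temperature ** 2
    loss = alpha * kd + (1. - alpha) * F.cross_entropy(student_logits, labels)
    loss.backward()
    optimizer.step()  # only the student's parameters are registered here
    return loss.item()

torch.manual_seed(0)
# Tiny stand-in models on flattened 3x32x32 inputs, just to exercise the update.
teacher = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 10))
student = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 10))
opt = torch.optim.Adam(student.parameters(), lr=1e-3)

teacher_before = copy.deepcopy(teacher.state_dict())
student_before = copy.deepcopy(student.state_dict())

inputs = torch.randn(16, 3, 32, 32)
labels = torch.randint(0, 10, (16,))
step_loss = distill_step(student, teacher, opt, inputs, labels, temperature=5.0, alpha=0.7)

teacher_same = all(torch.equal(teacher_before[k], v) for k, v in teacher.state_dict().items())
student_changed = any(not torch.equal(student_before[k], v) for k, v in student.state_dict().items())
print(f"loss={step_loss:.3f}, teacher unchanged: {teacher_same}, student updated: {student_changed}")
print(all(p.grad is None for p in teacher.parameters()))  # teacher never received gradients
```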

In [5]:
# Hyperparameters
epochs = 15
learning_rate = 0.001
temperature = 5.0
alpha = 0.7 # Weight for distillation loss

optimizer = optim.Adam(student_model.parameters(), lr=learning_rate)
writer = SummaryWriter('./logs/distillation') # TensorBoard writer

print("Starting distillation training...")
global_step = 0
for epoch in range(epochs):
    running_loss = 0.0
    
    student_model.train() # Set student to training mode
    progress_bar = tqdm(trainloader, desc=f"Epoch {epoch+1}/{epochs}")
    for inputs, labels in progress_bar:
        inputs, labels = inputs.to(device), labels.to(device)

        # Zero the parameter gradients
        optimizer.zero_grad()

        # Forward pass for the student
        student_logits = student_model(inputs)

        # Forward pass for the teacher (no gradient calculation needed)
        with torch.no_grad():
            teacher_logits = teacher_model(inputs)

        # Calculate the distillation loss
        loss = distillation_loss(student_logits, teacher_logits, labels, temperature, alpha)

        # Backward pass and optimize
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        progress_bar.set_postfix({'loss': running_loss / (progress_bar.n + 1)})
        
        # Log training loss to TensorBoard
        writer.add_scalar('Loss/train_step', loss.item(), global_step)
        global_step += 1
    
    # --- Evaluation at the end of each epoch ---
    student_model.eval() # Set student to evaluation mode
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in testloader:
            images, labels = images.to(device), labels.to(device)
            outputs = student_model(images)
            predicted = outputs.argmax(dim=1)  # index of the highest logit
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    
    accuracy = 100 * correct / total
    print(f'\nEpoch {epoch+1} Test Accuracy: {accuracy:.2f} %\n')
    
    # Log accuracy and epoch loss to TensorBoard
    writer.add_scalar('Accuracy/test', accuracy, epoch)
    writer.add_scalar('Loss/train_epoch', running_loss / len(trainloader), epoch)


writer.close()
print("Finished Training")

Starting distillation training...

Epoch 1/15: 100%|██████████| 391/391 [00:03<00:00, 111.18it/s, loss=0.65]
Epoch 1 Test Accuracy: 53.03 %

Epoch 2/15: 100%|██████████| 391/391 [00:03<00:00, 118.96it/s, loss=0.611]
Epoch 2 Test Accuracy: 58.94 %

Epoch 3/15: 100%|██████████| 391/391 [00:03<00:00, 118.58it/s, loss=0.589]
Epoch 3 Test Accuracy: 62.85 %

Epoch 4/15: 100%|██████████| 391/391 [00:03<00:00, 120.05it/s, loss=0.579]
Epoch 4 Test Accuracy: 64.61 %

Epoch 5/15: 100%|██████████| 391/391 [00:03<00:00, 116.00it/s, loss=0.561]
Epoch 5 Test Accuracy: 66.03 %

Epoch 6/15: 100%|██████████| 391/391 [00:03<00:00, 117.47it/s, loss=0.546]
Epoch 6 Test Accuracy: 65.93 %

Epoch 7/15: 100%|██████████| 391/391 [00:03<00:00, 118.77it/s, loss=0.54]
Epoch 7 Test Accuracy: 66.95 %

Epoch 8/15: 100%|██████████| 391/391 [00:03<00:00, 119.14it/s, loss=0.547]
Epoch 8 Test Accuracy: 66.79 %

Epoch 9/15: 100%|██████████| 391/391 [00:03<00:00, 117.78it/s, loss=0.528]
Epoch 9 Test Accuracy: 67.99 %

Epoch 10/15: 100%|██████████| 391/391 [00:03<00:00, 117.72it/s, loss=0.536]
Epoch 10 Test Accuracy: 66.63 %

Epoch 11/15: 100%|██████████| 391/391 [00:03<00:00, 120.18it/s, loss=0.531]
Epoch 11 Test Accuracy: 68.56 %

Epoch 12/15: 100%|██████████| 391/391 [00:03<00:00, 118.65it/s, loss=0.532]
Epoch 12 Test Accuracy: 68.83 %

Epoch 13/15: 100%|██████████| 391/391 [00:03<00:00, 117.09it/s, loss=0.518]
Epoch 13 Test Accuracy: 68.24 %

Epoch 14/15: 100%|██████████| 391/391 [00:03<00:00, 118.12it/s, loss=0.524]
Epoch 14 Test Accuracy: 69.71 %

Epoch 15/15: 100%|██████████| 391/391 [00:03<00:00, 117.65it/s, loss=0.508]
Epoch 15 Test Accuracy: 68.72 %

Finished Training
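
Because the teacher in this notebook was never fine-tuned on CIFAR-10, the student's accuracy above says little about how much the teacher contributed. One way to check is to measure the teacher's own test accuracy, and to compare against a baseline student trained with `alpha=0`. The evaluation loop can be factored into a reusable helper for this; `evaluate` below is a hypothetical name, and the synthetic check (labels defined as the index of the largest feature) only verifies the helper itself.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

@torch.no_grad()
def evaluate(model, loader, device):
    """Hypothetical helper: top-1 accuracy (%) of `model` on `loader`."""
    model.eval()
    correct, total = 0, 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        predicted = model(images).argmax(dim=1)
        correct += (predicted == labels).sum().item()
        total += labels.size(0)
    return 100.0 * correct / total

# Synthetic check: the label is the index of the largest feature.
torch.manual_seed(0)
features = torch.randn(100, 10)
labels = features.argmax(dim=1)
loader = DataLoader(TensorDataset(features, labels), batch_size=32)
cpu = torch.device("cpu")

identity_acc = evaluate(nn.Identity(), loader, cpu)  # logits == features: always right

constant = nn.Linear(10, 10)
with torch.no_grad():
    constant.weight.zero_()
    constant.bias.zero_()
    constant.bias[0] = 1.0  # always predicts class 0
constant_acc = evaluate(constant, loader, cpu)
expected = 100.0 * (labels == 0).sum().item() / len(labels)
print(identity_acc, constant_acc, expected)
```

In the notebook this would be called as `evaluate(teacher_model, testloader, device)` and `evaluate(student_model, testloader, device)`.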


In [6]:
%load_ext tensorboard
%tensorboard --logdir logs --port 6006 --bind_all