# Knowledge Distillation Approach Using PyTorch

The general workflow:
<p>
  <img alt="Knowledge Distillation Workflow" src="distillation_workflow.png" width="450" height="200"/>
</p>

[img source: knowledge distillation section](https://www.linkedin.com/learning/ai-model-compression-techniques-building-cheaper-faster-and-greener-ai)

## Load Modules

In [12]:
import copy
from datetime import datetime
import os
import random
import time
import warnings
warnings.filterwarnings('ignore')

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as transforms

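### Set Random Seeds for Reproducibility

In [13]:
# Set random seeds for reproducibility (Python, NumPy, and PyTorch RNGs)
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)  # seeds the CPU generator and, when present, the CUDA generators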

### Check Whether CUDA (Compute Unified Device Architecture) Is Available

Checking whether CUDA is available is an important first step in deep learning and other high-performance computing code, for several reasons:

* Enabling GPU Acceleration:

    CUDA (Compute Unified Device Architecture) is NVIDIA's parallel computing platform and API that allows software to leverage the power of NVIDIA GPUs for general-purpose computing. Checking for its availability determines whether your program can offload computationally intensive tasks to the GPU, leading to significant speedups compared to CPU-only execution.

* Conditional Code Execution:

    By checking for CUDA availability, you can write code that dynamically adapts to the hardware environment. If a CUDA-enabled GPU is present, your program can utilize GPU-specific operations and data structures. If not, it can gracefully fall back to CPU implementations or inform the user about the lack of GPU support. This prevents errors and ensures your application can run on various systems.

* Resource Management:

    Knowing if CUDA is available allows you to manage resources effectively. If a GPU is present, you can allocate memory on the device and perform computations there. If not, you avoid attempting to access non-existent GPU resources, which would lead to errors.

* Error Prevention and Debugging:

    Explicitly checking for CUDA availability helps in identifying and preventing issues related to missing or improperly configured CUDA installations or incompatible GPU drivers. If the check fails, it provides an immediate indication that GPU acceleration is not possible, guiding troubleshooting efforts.

* Optimized Performance:

    Many deep learning frameworks and libraries are designed to leverage CUDA for optimal performance. Verifying CUDA availability ensures that these frameworks can utilize the intended hardware acceleration, leading to faster training times and inference speeds for machine learning models.

In [14]:
cuda_available = torch.cuda.is_available()
device = torch.device("cuda" if cuda_available else "cpu")
print(f"Using device for training: {device}")

Using device for training: cpu


---

# Building the Teacher and Student Models: Model Definition

---

<p>
  <img alt="The Teacher - Student Architecture" src="teacher_student_architecture.png" width="450" height="300"/>
</p>

[img source: knowledge distillation section](https://www.linkedin.com/learning/ai-model-compression-techniques-building-cheaper-faster-and-greener-ai)

In [15]:
# Define teacher and student models
class TeacherModel(nn.Module):
    """A larger model to act as the teacher"""
    def __init__(self):
        super(TeacherModel, self).__init__()
        # Convolutional layers
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)

        # Pooling and dropout
        self.pool = nn.MaxPool2d(2, 2)
        self.dropout = nn.Dropout(0.25)

        # Fully connected layers
        self.fc1 = nn.Linear(128 * 3 * 3, 512)
        self.fc2 = nn.Linear(512, 10)

    def forward(self, x):
        # Feature extraction
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = self.pool(F.relu(self.conv3(x)))

        # Flatten
        x = x.view(-1, 128 * 3 * 3)

        # Classification
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)

        return x

    def get_features(self, x):
        """Get intermediate features for additional distillation"""
        features = []

        # Extract features from each layer
        x = F.relu(self.conv1(x))
        features.append(x)
        x = self.pool(x)

        x = F.relu(self.conv2(x))
        features.append(x)
        x = self.pool(x)

        x = F.relu(self.conv3(x))
        features.append(x)
        x = self.pool(x)

        x = x.view(-1, 128 * 3 * 3)
        x = F.relu(self.fc1(x))
        features.append(x)

        return features

class StudentModel(nn.Module):
    """A smaller model to be trained via knowledge distillation"""
    def __init__(self):
        super(StudentModel, self).__init__()
        # Convolutional layers (fewer filters)
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)

        # Pooling
        self.pool = nn.MaxPool2d(2, 2)

        # Fully connected layers (smaller)
        self.fc1 = nn.Linear(32 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        # Feature extraction
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))

        # Flatten
        x = x.view(-1, 32 * 7 * 7)

        # Classification
        x = F.relu(self.fc1(x))
        x = self.fc2(x)

        return x

    def get_features(self, x):
        """Get intermediate features for additional distillation"""
        features = []

        # Extract features from each layer
        x = F.relu(self.conv1(x))
        features.append(x)
        x = self.pool(x)

        x = F.relu(self.conv2(x))
        features.append(x)
        x = self.pool(x)

        x = x.view(-1, 32 * 7 * 7)
        x = F.relu(self.fc1(x))
        features.append(x)

        return features
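
The flatten sizes hard-coded in the two models (`128 * 3 * 3` and `32 * 7 * 7`) and their parameter counts can be checked by hand. The following is a minimal pure-Python sketch, assuming 28x28 single-channel inputs, 3x3 convolutions with `padding=1` (which preserve spatial size), and 2x2 max pooling that floors odd sizes. The helper names `conv_params`, `linear_params`, and `pooled_size` are illustrative and not part of the notebook.

```python
def conv_params(in_ch, out_ch, kernel=3):
    """Weights plus biases of a Conv2d layer with a square kernel."""
    return in_ch * out_ch * kernel * kernel + out_ch


def linear_params(in_features, out_features):
    """Weights plus biases of a Linear layer."""
    return in_features * out_features + out_features


def pooled_size(size, num_pools):
    """Spatial size after repeated 2x2 max pooling (floor division)."""
    for _ in range(num_pools):
        size //= 2
    return size


# Teacher: three conv+pool stages, 28 -> 14 -> 7 -> 3
teacher_side = pooled_size(28, 3)
teacher_flat = 128 * teacher_side * teacher_side
teacher_total = (conv_params(1, 32) + conv_params(32, 64) + conv_params(64, 128)
                 + linear_params(teacher_flat, 512) + linear_params(512, 10))

# Student: two conv+pool stages, 28 -> 14 -> 7
student_side = pooled_size(28, 2)
student_flat = 32 * student_side * student_side
student_total = (conv_params(1, 16) + conv_params(16, 32)
                 + linear_params(student_flat, 128) + linear_params(128, 10))

print(teacher_flat, teacher_total)  # 1152 688138
print(student_flat, student_total)  # 1568 206922
```

These totals agree with the parameter counts reported in the run output at the end of this notebook (688,138 and 206,922). Most of the weights in both models sit in the first fully connected layer.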

---

# Knowledge Distillation Loss

---

<p>
  <img alt="Feature Distillation" src="feature_distillation.png" width="450" height="300"/>
</p>

[img source: knowledge distillation section](https://www.linkedin.com/learning/ai-model-compression-techniques-building-cheaper-faster-and-greener-ai)

In [16]:
# Knowledge distillation loss
class DistillationLoss(nn.Module):
    def __init__(self, alpha=0.5, temperature=2.0):
        super(DistillationLoss, self).__init__()
        self.alpha = alpha  # Weight for distillation loss vs standard loss
        self.temperature = temperature  # Temperature for softening probability distributions

    def forward(self, student_logits, teacher_logits, labels):
        # Standard cross-entropy loss
        hard_loss = F.cross_entropy(student_logits, labels)

        # Distillation loss: KL-divergence between soft targets from teacher and student
        soft_targets = F.softmax(teacher_logits / self.temperature, dim=1)
        soft_prob = F.log_softmax(student_logits / self.temperature, dim=1)
        soft_loss = F.kl_div(soft_prob, soft_targets, reduction='batchmean') * (self.temperature ** 2)

        # Combine the two losses
        loss = (1 - self.alpha) * hard_loss + self.alpha * soft_loss

        return loss

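# Feature distillation loss - optional enhancement
class FeatureDistillationLoss(nn.Module):
    def __init__(self, beta=0.1):
        super(FeatureDistillationLoss, self).__init__()
        self.beta = beta  # Weight for feature distillation

    def forward(self, student_features, teacher_features):
        # Simple L2 distance for feature matching.
        # For simplicity, only the last feature vector (fc1 output) of each model is used.
        student_last_feature = student_features[-1]
        teacher_last_feature = teacher_features[-1]

        # The two vectors differ in width (student: 128, teacher: 512), so F.mse_loss
        # cannot compare them directly. As one simple, parameter-free option, average-pool
        # the teacher features down to the student's width; a learned projection layer
        # would be a common alternative.
        if student_last_feature.shape != teacher_last_feature.shape:
            teacher_last_feature = F.adaptive_avg_pool1d(
                teacher_last_feature.unsqueeze(1), student_last_feature.size(1)
            ).squeeze(1)

        # Compute the mean squared error loss
        feat_loss = F.mse_loss(student_last_feature, teacher_last_feature)

        return self.beta * feat_loss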
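
To see what the temperature does in `DistillationLoss`, the following pure-Python sketch softens one set of teacher logits at T=1 and T=4 and evaluates the same weighted combination of hard and soft loss on a single example. The logits, the label, and the helper names are invented for illustration; the notebook itself uses `F.softmax`, `F.log_softmax`, and `F.kl_div` on whole batches.

```python
import math


def softmax(logits, temperature=1.0):
    """Temperature-scaled softmax over a list of logits."""
    scaled = [z / temperature for z in logits]
    peak = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(z - peak) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def kl_divergence(p, q):
    """KL(p || q) for two discrete distributions with positive entries."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q))


def distillation_loss(student_logits, teacher_logits, label, alpha=0.5, temperature=2.0):
    """Single-example (1 - alpha) * hard + alpha * T^2 * KL(teacher || student)."""
    hard_loss = -math.log(softmax(student_logits)[label])
    soft_targets = softmax(teacher_logits, temperature)
    soft_student = softmax(student_logits, temperature)
    soft_loss = kl_divergence(soft_targets, soft_student) * temperature ** 2
    return (1 - alpha) * hard_loss + alpha * soft_loss


# Hypothetical 3-class logits, chosen only for illustration
teacher_logits = [6.0, 2.0, 1.0]
student_logits = [3.0, 1.5, 0.5]

sharp = softmax(teacher_logits, temperature=1.0)
soft = softmax(teacher_logits, temperature=4.0)
print([round(p, 3) for p in sharp])
print([round(p, 3) for p in soft])
print(round(distillation_loss(student_logits, teacher_logits, label=0), 4))
```

A higher temperature moves probability mass onto the non-argmax classes, which is the extra signal the student learns from. The `temperature ** 2` factor keeps the soft-loss gradients on a comparable scale as the temperature changes.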

### Loading the Modified National Institute of Standards and Technology (MNIST) Dataset

The MNIST (Modified National Institute of Standards and Technology) dataset is a widely used dataset in the field of machine learning, particularly for image recognition and classification tasks.

#### Key characteristics of the MNIST dataset:

* **Handwritten Digits:**
    - It consists of a large collection of grayscale images of handwritten digits (0-9).

* **Image Dimensions:**
    - Each image is a 28x28 pixel grayscale image.

* **Dataset Size:**
    - It comprises a training set of 60,000 examples and a test set of 10,000 examples. 

In [17]:
# Load MNIST dataset
def load_data(batch_size=64):
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])

    trainset = torchvision.datasets.MNIST(root='./data', train=True,
                                         download=True, transform=transform)
    trainloader = DataLoader(trainset, batch_size=batch_size,
                                              shuffle=True, num_workers=2)

    testset = torchvision.datasets.MNIST(root='./data', train=False,
                                        download=True, transform=transform)
    testloader = DataLoader(testset, batch_size=batch_size,
                                             shuffle=False, num_workers=2)
    return trainloader, testloader
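
As a quick check on `load_data`, the sketch below reproduces by hand what the transform pipeline does to a single pixel and how many batches each loader yields at the default `batch_size=64`. It assumes `ToTensor` scales 8-bit pixel values to [0, 1] and `Normalize` then applies `(x - mean) / std`; the helper name `normalize_pixel` is illustrative.

```python
import math

MNIST_MEAN, MNIST_STD = 0.1307, 0.3081  # constants passed to transforms.Normalize


def normalize_pixel(raw_value):
    """Map an 8-bit pixel (0-255) through ToTensor-style scaling and normalization."""
    scaled = raw_value / 255.0
    return (scaled - MNIST_MEAN) / MNIST_STD


black = normalize_pixel(0)
white = normalize_pixel(255)
print(round(black, 4), round(white, 4))

# Batches per epoch; the last batch is smaller because drop_last defaults to False
train_batches = math.ceil(60000 / 64)
test_batches = math.ceil(10000 / 64)
print(train_batches, test_batches)  # 938 157
```

With 938 training batches per epoch, the `i % 100 == 99` logging condition used in the training loops prints nine progress lines per epoch.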

# Train the Teacher (Standard Training) and the Student (Knowledge Distillation)



In [18]:
# Train teacher model (standard training)
def train_teacher(model, trainloader, epochs=3):
    model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    print("Training teacher model...")
    model.train()
    for epoch in range(epochs):
        running_loss = 0.0
        for i, data in enumerate(trainloader, 0):
            inputs, labels = data[0].to(device), data[1].to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            if i % 100 == 99:
                print(f'Epoch {epoch+1}, Batch {i+1}, Loss: {running_loss/100:.3f}')
                running_loss = 0.0

    print('Finished training teacher model')
    return model

# Train student model with knowledge distillation
def train_student_with_distillation(student_model, teacher_model, trainloader,
                                   epochs=3, alpha=0.5, temperature=2.0, beta=0.0):
    student_model.to(device)
    teacher_model.to(device)
    teacher_model.eval()  # Teacher model is fixed

    distill_criterion = DistillationLoss(alpha=alpha, temperature=temperature)
    feature_criterion = FeatureDistillationLoss(beta=beta) if beta > 0 else None
    optimizer = optim.Adam(student_model.parameters(), lr=0.001)

    print("Training student model with distillation...")
    student_model.train()

    for epoch in range(epochs):
        running_loss = 0.0
        for i, data in enumerate(trainloader, 0):
            inputs, labels = data[0].to(device), data[1].to(device)

            optimizer.zero_grad()

            # Get outputs from both models
            with torch.no_grad():
                teacher_outputs = teacher_model(inputs)
                if beta > 0:
                    teacher_features = teacher_model.get_features(inputs)

            student_outputs = student_model(inputs)
            if beta > 0:
                student_features = student_model.get_features(inputs)

            # Compute distillation loss
            loss = distill_criterion(student_outputs, teacher_outputs, labels)

            # Add feature matching loss if requested
            if beta > 0:
                feature_loss = feature_criterion(student_features, teacher_features)
                loss += feature_loss

            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            if i % 100 == 99:
                print(f'Epoch {epoch+1}, Batch {i+1}, Loss: {running_loss/100:.3f}')
                running_loss = 0.0

    print('Finished training student model')
    return student_model

---

## Fine-Tuning Student Model

---

In [19]:
# Fine-tune student model on task-specific data
def fine_tune_student(model, trainloader, epochs=2):
    model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.0005)  # Lower learning rate for fine-tuning

    print("Fine-tuning student model...")
    model.train()
    for epoch in range(epochs):
        running_loss = 0.0
        for i, data in enumerate(trainloader, 0):
            inputs, labels = data[0].to(device), data[1].to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            if i % 100 == 99:
                print(f'Fine-tuning Epoch {epoch+1}, Batch {i+1}, Loss: {running_loss/100:.3f}')
                running_loss = 0.0

    print('Finished fine-tuning student model')
    return model

## Utils for Evaluation 

In [20]:
# Evaluate model accuracy
def evaluate_model(model, testloader):
    model.to(device)
    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for data in testloader:
            images, labels = data[0].to(device), data[1].to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)  # index of the highest logit per sample
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = 100 * correct / total
    return accuracy

# Count parameters
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

# Measure inference time
def measure_inference_time(model, testloader, num_batches=10):
    model.to(device)
    model.eval()

    # Warm-up (not timed)
    with torch.no_grad():
        for i, (images, _) in enumerate(testloader):
            if i > 5:
                break
            _ = model(images.to(device))

    # On a GPU, kernels run asynchronously, so synchronize before reading the clock
    if device.type == "cuda":
        torch.cuda.synchronize()

    # Measure time; note that this also includes data loading and host-to-device transfer
    start_time = time.perf_counter()
    batch_count = 0

    with torch.no_grad():
        for i, (images, _) in enumerate(testloader):
            if i >= num_batches:
                break
            images = images.to(device)
            _ = model(images)
            batch_count += 1

    if device.type == "cuda":
        torch.cuda.synchronize()
    end_time = time.perf_counter()
    avg_time = (end_time - start_time) / batch_count

    return avg_time

# Get model size
def get_model_size(model):
    torch.save(model.state_dict(), "temp_model.pt")
    size_mb = os.path.getsize("temp_model.pt") / (1024 * 1024)
    os.remove("temp_model.pt")
    return size_mb

# Save model predictions for further analysis
def save_predictions(model, testloader, filename):
    model.to(device)
    model.eval()

    all_preds = []
    all_labels = []

    with torch.no_grad():
        for data in testloader:
            images, labels = data[0].to(device), data[1].to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)  # index of the highest logit per sample

            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    np.savez(filename, predictions=np.array(all_preds), labels=np.array(all_labels))

# Visualize predictions
def plot_confusion_matrix(model_name, predictions_file):
    data = np.load(predictions_file)
    preds = data['predictions']
    labels = data['labels']

    from sklearn.metrics import confusion_matrix
    import seaborn as sns

    cm = confusion_matrix(labels, preds)
    plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
    plt.title(f'Confusion Matrix - {model_name}')
    plt.ylabel('True Label')
    plt.xlabel('Predicted Label')
    plt.savefig(f'{model_name}_confusion_matrix.png')
    plt.close()
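
The value returned by `get_model_size` can be estimated without touching the disk: a float32 parameter takes 4 bytes, so the checkpoint is roughly `4 * num_parameters` bytes plus a small serialization overhead. The sketch below assumes all parameters are float32 (PyTorch's default) and uses the parameter counts reported in the run output; `estimate_size_mb` is a hypothetical helper, not part of the notebook.

```python
BYTES_PER_FLOAT32 = 4


def estimate_size_mb(num_parameters, bytes_per_param=BYTES_PER_FLOAT32):
    """Approximate checkpoint size in MiB, ignoring serialization overhead."""
    return num_parameters * bytes_per_param / (1024 * 1024)


teacher_mb = estimate_size_mb(688_138)
student_mb = estimate_size_mb(206_922)
print(f"Teacher: ~{teacher_mb:.3f} MB, student: ~{student_mb:.3f} MB")
```

Both estimates land within rounding of the 2.63 MB and 0.79 MB reported in the run output, which suggests the size reduction comes almost entirely from the lower parameter count.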

---

## Main Function 

---

Run the experiments.

In [21]:
# Main function
def main():
    # Load data
    trainloader, testloader = load_data()

    # Create and train teacher model
    teacher_model = TeacherModel()
    teacher_params = count_parameters(teacher_model)
    print(f"Teacher model has {teacher_params:,} parameters")

    # Check if a pretrained model exists to save time
    if os.path.exists('teacher_model.pt'):
        print("Loading pre-trained teacher model...")
        teacher_model.load_state_dict(torch.load('teacher_model.pt', map_location=device))
    else:
        teacher_model = train_teacher(teacher_model, trainloader, epochs=3)
        torch.save(teacher_model.state_dict(), 'teacher_model.pt')

    # Evaluate teacher model
    teacher_accuracy = evaluate_model(teacher_model, testloader)
    teacher_inference_time = measure_inference_time(teacher_model, testloader)
    teacher_size = get_model_size(teacher_model)

    print("\n--- Teacher Model Metrics ---")
    print(f"Accuracy: {teacher_accuracy:.2f}%")
    print(f"Parameters: {teacher_params:,}")
    print(f"Inference Time: {teacher_inference_time*1000:.2f} ms per batch")
    print(f"Model Size: {teacher_size:.2f} MB")

    # Create student model
    student_model = StudentModel()
    student_params = count_parameters(student_model)
    print(f"\nStudent model has {student_params:,} parameters")
    print(f"Parameter reduction: {(1 - student_params/teacher_params)*100:.1f}%")

    # Train student model without distillation (for comparison).
    # train_teacher is reused here as a generic supervised training loop.
    standard_student = copy.deepcopy(student_model)
    if os.path.exists('standard_student.pt'):
        print("Loading pre-trained standard student model...")
        standard_student.load_state_dict(torch.load('standard_student.pt', map_location=device))
    else:
        standard_student = train_teacher(standard_student, trainloader, epochs=3)
        torch.save(standard_student.state_dict(), 'standard_student.pt')

    # Evaluate standard student
    standard_student_accuracy = evaluate_model(standard_student, testloader)
    standard_student_time = measure_inference_time(standard_student, testloader)
    standard_student_size = get_model_size(standard_student)

    print("\n--- Standard Student Model Metrics ---")
    print(f"Accuracy: {standard_student_accuracy:.2f}%")
    print(f"Parameters: {student_params:,}")
    print(f"Inference Time: {standard_student_time*1000:.2f} ms per batch")
    print(f"Model Size: {standard_student_size:.2f} MB")

    # Train student with knowledge distillation
    distilled_student = copy.deepcopy(student_model)

    if os.path.exists('distilled_student.pt'):
        print("Loading pre-trained distilled student model...")
        distilled_student.load_state_dict(torch.load('distilled_student.pt', map_location=device))
    else:
        distilled_student = train_student_with_distillation(
            distilled_student, teacher_model, trainloader,
            epochs=3, alpha=0.5, temperature=4.0)
        torch.save(distilled_student.state_dict(), 'distilled_student.pt')

    # Evaluate distilled student
    distilled_accuracy = evaluate_model(distilled_student, testloader)
    distilled_time = measure_inference_time(distilled_student, testloader)
    distilled_size = get_model_size(distilled_student)

    print("\n--- Distilled Student Model Metrics ---")
    print(f"Accuracy: {distilled_accuracy:.2f}%")
    print(f"Parameters: {student_params:,}")
    print(f"Inference Time: {distilled_time*1000:.2f} ms per batch")
    print(f"Model Size: {distilled_size:.2f} MB")

    # Fine-tune the distilled student
    fine_tuned_student = copy.deepcopy(distilled_student)

    if os.path.exists('fine_tuned_student.pt'):
        print("Loading pre-trained fine-tuned student model...")
        fine_tuned_student.load_state_dict(torch.load('fine_tuned_student.pt', map_location=device))
    else:
        fine_tuned_student = fine_tune_student(fine_tuned_student, trainloader, epochs=2)
        torch.save(fine_tuned_student.state_dict(), 'fine_tuned_student.pt')

    # Evaluate fine-tuned student
    fine_tuned_accuracy = evaluate_model(fine_tuned_student, testloader)
    fine_tuned_time = measure_inference_time(fine_tuned_student, testloader)

    print("\n--- Fine-tuned Student Model Metrics ---")
    print(f"Accuracy: {fine_tuned_accuracy:.2f}%")
    print(f"Accuracy Improvement from Distillation: {distilled_accuracy - standard_student_accuracy:.2f}%")
    print(f"Accuracy Improvement from Fine-tuning: {fine_tuned_accuracy - distilled_accuracy:.2f}%")
    print(f"Inference Time: {fine_tuned_time*1000:.2f} ms per batch")

    # Save predictions for analysis
    save_predictions(teacher_model, testloader, 'teacher_preds.npz')
    save_predictions(standard_student, testloader, 'standard_student_preds.npz')
    save_predictions(distilled_student, testloader, 'distilled_student_preds.npz')
    save_predictions(fine_tuned_student, testloader, 'fine_tuned_student_preds.npz')

    # Comparison summary
    print("\n" + "="*50)
    print("KNOWLEDGE DISTILLATION SUMMARY")
    print("="*50)
    print(f"{'Model':<25} {'Accuracy':<10} {'Size (MB)':<12} {'Inference (ms)':<15} {'Parameters':<12}")
    print("-" * 75)
    print(f"{'Teacher':<25} {teacher_accuracy:<10.2f} {teacher_size:<12.2f} {teacher_inference_time*1000:<15.2f} {teacher_params:,}")
    print(f"{'Student (Standard)':<25} {standard_student_accuracy:<10.2f} {standard_student_size:<12.2f} {standard_student_time*1000:<15.2f} {student_params:,}")
    print(f"{'Student (Distilled)':<25} {distilled_accuracy:<10.2f} {distilled_size:<12.2f} {distilled_time*1000:<15.2f} {student_params:,}")
    print(f"{'Student (Fine-tuned)':<25} {fine_tuned_accuracy:<10.2f} {distilled_size:<12.2f} {fine_tuned_time*1000:<15.2f} {student_params:,}")

    # Visualization
    models = ['Teacher', 'Student\nStandard', 'Student\nDistilled', 'Student\nFine-tuned']
    accuracies = [teacher_accuracy, standard_student_accuracy, distilled_accuracy, fine_tuned_accuracy]
    params = [teacher_params, student_params, student_params, student_params]
    inference_times = [teacher_inference_time*1000, standard_student_time*1000,
                       distilled_time*1000, fine_tuned_time*1000]

    # Create bar charts
    plt.figure(figsize=(15, 10))

    # Accuracy comparison
    plt.subplot(2, 2, 1)
    plt.bar(models, accuracies, color=['blue', 'orange', 'green', 'red'])
    plt.title('Model Accuracy (%)')
    plt.ylabel('Accuracy')

    # Parameter comparison
    plt.subplot(2, 2, 2)
    plt.bar(models, params, color=['blue', 'orange', 'green', 'red'])
    plt.title('Model Parameters')
    plt.ylabel('Parameters')

    # Inference time comparison
    plt.subplot(2, 2, 3)
    plt.bar(models, inference_times, color=['blue', 'orange', 'green', 'red'])
    plt.title('Inference Time (ms)')
    plt.ylabel('Time (ms)')

    # Size comparison
    sizes = [teacher_size, standard_student_size, distilled_size, distilled_size]
    plt.subplot(2, 2, 4)
    plt.bar(models, sizes, color=['blue', 'orange', 'green', 'red'])
    plt.title('Model Size (MB)')
    plt.ylabel('Size (MB)')

    plt.tight_layout()
    plt.savefig('knowledge_distillation_comparison.png')
    plt.close()

    # Plot confusion matrices
    plot_confusion_matrix('Teacher', 'teacher_preds.npz')
    plot_confusion_matrix('Standard_Student', 'standard_student_preds.npz')
    plot_confusion_matrix('Distilled_Student', 'distilled_student_preds.npz')
    plot_confusion_matrix('Fine_Tuned_Student', 'fine_tuned_student_preds.npz')

    print("\nVisualization saved as 'knowledge_distillation_comparison.png'")
    print("Confusion matrices saved for each model.")


In [22]:
if __name__ == "__main__":
    main()

Teacher model has 688,138 parameters
Loading pre-trained teacher model...

--- Teacher Model Metrics ---
Accuracy: 98.97%
Parameters: 688,138
Inference Time: 118.97 ms per batch
Model Size: 2.63 MB

Student model has 206,922 parameters
Parameter reduction: 69.9%
Loading pre-trained standard student model...

--- Standard Student Model Metrics ---
Accuracy: 98.77%
Parameters: 206,922
Inference Time: 112.56 ms per batch
Model Size: 0.79 MB
Loading pre-trained distilled student model...

--- Distilled Student Model Metrics ---
Accuracy: 98.65%
Parameters: 206,922
Inference Time: 111.10 ms per batch
Model Size: 0.79 MB
Loading pre-trained fine-tuned student model...

--- Fine-tuned Student Model Metrics ---
Accuracy: 99.12%
Accuracy Improvement from Distillation: -0.12%
Accuracy Improvement from Fine-tuning: 0.47%
Inference Time: 112.09 ms per batch

KNOWLEDGE DISTILLATION SUMMARY
Model                     Accuracy   Size (MB)    Inference (ms)  Parameters  
-------------------------------
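
The headline ratios can be recomputed directly from the values printed above. This sketch only re-derives them from the reported figures; nothing here is a new measurement, and the variable names are chosen for illustration.

```python
# Values copied from the run output above
teacher_params, student_params = 688_138, 206_922
teacher_ms, distilled_ms = 118.97, 111.10
standard_acc, distilled_acc, fine_tuned_acc = 98.77, 98.65, 99.12

param_reduction = (1 - student_params / teacher_params) * 100
compression_ratio = teacher_params / student_params
speedup = teacher_ms / distilled_ms

print(f"Parameter reduction: {param_reduction:.1f}%")
print(f"Compression ratio: {compression_ratio:.2f}x")
print(f"Speedup per batch: {speedup:.2f}x")
print(f"Distillation vs. standard: {distilled_acc - standard_acc:+.2f} points")
print(f"Fine-tuning vs. distilled: {fine_tuned_acc - distilled_acc:+.2f} points")
```

In this CPU run the student has under a third of the teacher's parameters but is only about 7% faster per batch. The timing loop also includes data loading, which likely accounts for much of the per-batch time.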