<a href="https://colab.research.google.com/github/dorzv/ComputerVision/blob/master/knowledge_distillation/knowledge_distillation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Knowledge Distillation
A notebook for the blog post [Knowledge Distillation — How Networks Can Teach](https://dzdata.medium.com/knowledge-distillation-how-networks-can-teach-a3e287d28eea)

In [None]:
from copy import deepcopy
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision.transforms.v2 as transforms
import torchvision.datasets as datasets

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Seed PyTorch's global random number generator so repeated runs match.
torch.manual_seed(42)

<torch._C.Generator at 0x7aa6ba054170>

Create CIFAR-10 datasets

In [None]:
# Convert every PIL image to a float32 tensor scaled to [0, 1].
# `ToImage` + `ToDtype(..., scale=True)` is the documented v2 replacement for the
# deprecated `ToTensor`; `ToPureTensor` strips the tv_tensor wrapper so each
# sample is a plain `torch.Tensor`, exactly as `ToTensor` used to return.
transform = transforms.Compose([
    transforms.ToImage(),
    transforms.ToDtype(torch.float32, scale=True),
    transforms.ToPureTensor(),
])

# Both splits share the same transform; the archive is downloaded into `data/` if missing.
train_dataset = datasets.CIFAR10(root='data', train=True, download=True, transform=transform)
test_dataset = datasets.CIFAR10(root='data', train=False, download=True, transform=transform)



Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to data/cifar-10-python.tar.gz


100%|██████████| 170M/170M [00:04<00:00, 41.8MB/s]


Extracting data/cifar-10-python.tar.gz to data
Files already downloaded and verified


Create DataLoaders

In [None]:
# Only the training split is shuffled; the test split keeps a fixed order.
train_loader = DataLoader(
    train_dataset,
    batch_size=128,
    shuffle=True,
    num_workers=2,
)
test_loader = DataLoader(
    test_dataset,
    batch_size=128,
    shuffle=False,
    num_workers=2,
)

## Define Models

Teacher model

In [None]:
class TeacherModel(nn.Module):
    """Larger CNN classifier that plays the teacher role.

    The feature extractor stacks four 3x3 conv -> batch-norm -> ReLU stages
    (128, 64, 64 and 32 output channels) with a 2x2 max-pool after the second
    and after the fourth stage. The classifier is a two-layer fully connected
    head with dropout between the layers.

    The first linear layer takes 2048 flattened features, i.e. 32 channels of
    8x8 maps, which is what 3x32x32 inputs produce after the two poolings.
    """

    @staticmethod
    def _conv_stage(in_channels, out_channels):
        """Return the layers of one stage: 3x3 conv (padding 1), batch-norm, ReLU."""
        return [
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        ]

    def __init__(self, num_classes=10):
        super().__init__()
        stage = self._conv_stage

        # Layers are created in the same order and end up at the same
        # Sequential indices as a hand-written list, so parameter names
        # (e.g. `features_extractor.10.weight`) are unchanged.
        self.features_extractor = nn.Sequential(
            *stage(3, 128),
            *stage(128, 64),
            nn.MaxPool2d(kernel_size=2, stride=2),
            *stage(64, 64),
            *stage(64, 32),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )

        hidden_units = 512
        self.classifier = nn.Sequential(
            nn.Linear(2048, hidden_units),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_units, num_classes),
        )

    def forward(self, x):
        """Return the raw class logits for a batch of images ``x``."""
        feature_maps = self.features_extractor(x)
        flat_features = feature_maps.flatten(start_dim=1)  # keep the batch dimension
        return self.classifier(flat_features)

Student model

In [None]:
class StudentModel(nn.Module):
    """Small CNN classifier that plays the student role.

    The feature extractor has two 3x3 conv -> batch-norm -> ReLU -> 2x2
    max-pool stages with 16 channels each. The classifier is a two-layer
    fully connected head with dropout between the layers.

    The first linear layer takes 1024 flattened features, i.e. 16 channels of
    8x8 maps, which is what 3x32x32 inputs produce after the two poolings.
    """

    @staticmethod
    def _conv_pool_stage(in_channels, out_channels):
        """Return the layers of one stage: 3x3 conv, batch-norm, ReLU, 2x2 max-pool."""
        return [
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        ]

    def __init__(self, num_classes=10):
        super().__init__()
        stage = self._conv_pool_stage

        # Same creation order and Sequential indices as a hand-written layer
        # list, so parameter names stay the same.
        self.features_extractor = nn.Sequential(
            *stage(3, 16),
            *stage(16, 16),
        )

        hidden_units = 256
        self.classifier = nn.Sequential(
            nn.Linear(1024, hidden_units),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_units, num_classes),
        )

    def forward(self, x):
        """Return the raw class logits for a batch of images ``x``."""
        feature_maps = self.features_extractor(x)
        flat_features = feature_maps.flatten(start_dim=1)  # keep the batch dimension
        return self.classifier(flat_features)

## Training

In [None]:
def train(model, dataloader, epochs, learning_rate, device):
    """Train ``model`` with cross-entropy loss and the Adam optimizer.

    Args:
        model: Network to optimise. It is moved to ``device`` and left in
            training mode when the function returns.
        dataloader: Iterable yielding ``(images, labels)`` batches.
        epochs: Number of full passes over ``dataloader``.
        learning_rate: Learning rate passed to Adam.
        device: Device on which the model and every batch are placed.

    Prints the mean per-batch loss after each epoch and returns nothing.
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    model.to(device)
    model.train()

    for epoch in range(epochs):
        running_loss = 0.0
        num_batches = 0
        for images, labels in dataloader:
            images = images.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)

            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            num_batches += 1

        # Average over the batches this function actually iterated, instead of
        # the length of a notebook-level loader that may not be the one passed in.
        mean_loss = running_loss / num_batches if num_batches else float("nan")
        print(f"Epoch {epoch+1}/{epochs}, Mean Loss: {mean_loss}")

In [None]:
def test(model, data_loader, device):
    model.to(device)
    model.eval()

    correct = 0
    total = 0

    with torch.no_grad():
        for images, labels in data_loader:
            images = images.to(device)
            labels = labels.to(device)

            outputs = model(images)
            predicted_class = torch.argmax(outputs, dim=1)

            total += labels.size(0)
            correct += (predicted_class == labels).sum().item()

    accuracy = 100 * correct / total
    print(f"Test Accuracy: {accuracy:.2f}%")
    return accuracy

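We first train the teacher and a baseline student in the regular way, i.e. with cross-entropy on the true labels only (no knowledge distillation). Before training the student we keep an untrained copy of it, so that the distilled student later starts from the same initial weights.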

In [None]:
# Baseline 1: train the large teacher on the true labels only and record its test accuracy.
print('Teacher')
teacher = TeacherModel(num_classes=10)
train(teacher, train_loader, epochs=30, learning_rate=0.001, device=device)
test_accuracy_teacher = test(teacher, test_loader, device)

# Baseline 2: train the small student the same way.
print('\nStudent')
student1 = StudentModel(num_classes=10)
# Copy the student *before* it is trained: student2 therefore holds the same
# initial weights as student1 and stays untrained at this point.
student2 = deepcopy(student1)
train(student1, train_loader, epochs=30, learning_rate=0.001, device=device)
test_accuracy_student1 = test(student1, test_loader, device)

Teacher
Epoch 1/30, Mean Loss: 1.227013551975455
Epoch 2/30, Mean Loss: 0.8006297277520075
Epoch 3/30, Mean Loss: 0.6526896884984068
Epoch 4/30, Mean Loss: 0.5564498979112377
Epoch 5/30, Mean Loss: 0.46989234474004077
Epoch 6/30, Mean Loss: 0.4065409217725324
Epoch 7/30, Mean Loss: 0.33846967425340274
Epoch 8/30, Mean Loss: 0.28724067671524595
Epoch 9/30, Mean Loss: 0.23811305247609268
Epoch 10/30, Mean Loss: 0.19084547246661027
Epoch 11/30, Mean Loss: 0.16286258813936996
Epoch 12/30, Mean Loss: 0.13428122205350101
Epoch 13/30, Mean Loss: 0.12302607766655095
Epoch 14/30, Mean Loss: 0.10258829252094107
Epoch 15/30, Mean Loss: 0.08792296682229585
Epoch 16/30, Mean Loss: 0.08880708702480244
Epoch 17/30, Mean Loss: 0.07689892817669741
Epoch 18/30, Mean Loss: 0.07379904136900096
Epoch 19/30, Mean Loss: 0.0673288441503711
Epoch 20/30, Mean Loss: 0.06766464117952549
Epoch 21/30, Mean Loss: 0.061599307644712116
Epoch 22/30, Mean Loss: 0.05611306975078781
Epoch 23/30, Mean Loss: 0.0592158445242

## Knowledge Distillation

In [None]:
def train_knowledge_distillation(teacher, student, data_loader, epochs, learning_rate, temperature, alpha, device):
    """Train ``student`` on the true labels and on the teacher's softened outputs.

    The per-batch loss is::

        alpha * CE(student_logits, labels)
        + (1 - alpha) * temperature**2 * KL(softmax(teacher / T) || softmax(student / T))

    Only the student's parameters are given to the optimizer. The teacher is
    put in evaluation mode and its forward pass runs without gradients.

    Args:
        teacher: Network providing the soft targets.
        student: Network being trained; left in training mode on return.
        data_loader: Iterable yielding ``(images, labels)`` batches.
        epochs: Number of full passes over ``data_loader``.
        learning_rate: Learning rate passed to Adam.
        temperature: Divisor applied to both sets of logits before the softmax.
        alpha: Weight of the true-label loss; ``1 - alpha`` weights the
            distillation loss.
        device: Device on which both models and every batch are placed.

    Prints the mean per-batch loss after each epoch and returns nothing.
    """
    criterion = nn.CrossEntropyLoss()
    # KLDivLoss expects log-probabilities as input and probabilities as target.
    criterion_distill = nn.KLDivLoss(reduction='batchmean')
    optimizer = optim.Adam(student.parameters(), lr=learning_rate)
    teacher.to(device)
    student.to(device)
    teacher.eval()
    student.train()

    for epoch in range(epochs):
        running_loss = 0.0
        num_batches = 0
        for images, labels in data_loader:
            images = images.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()

            # Forward pass with the teacher model; its soft targets need no gradients
            with torch.no_grad():
                teacher_logits = teacher(images)
                soft_targets = torch.softmax(teacher_logits / temperature, dim=1)

            # Forward pass with the student model
            student_logits = student(images)

            # Calculate the soft label loss
            soft_probs = torch.log_softmax(student_logits / temperature, dim=1)
            distillation_loss = criterion_distill(soft_probs, soft_targets)

            # Calculate the true label loss
            label_loss = criterion(student_logits, labels)

            # Weighted sum of the two losses; the soft term is scaled by temperature**2
            loss = alpha * label_loss + (1.0 - alpha) * temperature**2 * distillation_loss
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            num_batches += 1

        # Average over the batches this function actually iterated, instead of
        # the length of a notebook-level loader that may not be the one passed in.
        mean_loss = running_loss / num_batches if num_batches else float("nan")
        print(f"Epoch {epoch+1}/{epochs}, Loss: {mean_loss}")

In [None]:
# Distil the trained teacher into student2, the student copy that has not been trained yet.
# alpha=0.8 weights the true-label loss; the remaining 0.2 weights the softened teacher targets.
train_knowledge_distillation(
    teacher,
    student2,
    train_loader,
    epochs=30,
    learning_rate=0.001,
    temperature=2,
    alpha=0.8,
    device=device,
)
test_accuracy_student_with_distillation = test(student2, test_loader, device)

Epoch 1/30, Loss: 2.1964448251382773
Epoch 2/30, Loss: 1.6716872932356033
Epoch 3/30, Loss: 1.4725313601286516
Epoch 4/30, Loss: 1.3435371722406744
Epoch 5/30, Loss: 1.2465182867501399
Epoch 6/30, Loss: 1.1615051953384028
Epoch 7/30, Loss: 1.0832165539112237
Epoch 8/30, Loss: 1.0170706502921747
Epoch 9/30, Loss: 0.948769456590228
Epoch 10/30, Loss: 0.8996146485933563
Epoch 11/30, Loss: 0.8429149076761797
Epoch 12/30, Loss: 0.7944906306693621
Epoch 13/30, Loss: 0.7432063121320037
Epoch 14/30, Loss: 0.708498257657756
Epoch 15/30, Loss: 0.6670365183402205
Epoch 16/30, Loss: 0.6346628576745768
Epoch 17/30, Loss: 0.5962860049189204
Epoch 18/30, Loss: 0.5657400330314246
Epoch 19/30, Loss: 0.5407315236528206
Epoch 20/30, Loss: 0.5148001837608455
Epoch 21/30, Loss: 0.4852442342759398
Epoch 22/30, Loss: 0.46830802027831603
Epoch 23/30, Loss: 0.43950596619445037
Epoch 24/30, Loss: 0.4295593956318658
Epoch 25/30, Loss: 0.4140379125504847
Epoch 26/30, Loss: 0.4031776143690509
Epoch 27/30, Loss: 0.