# Knowledge Distillation on CIFAR-10

Presented by:
* Luca Zanchetta
  * Email: zanchettaluca.99@gmail.com
  * GitHub: https://github.com/luca-zanchetta

Official repository available at: [LINK]

## Description

In this notebook, I briefly present some fun experiments I ran in my free time at the beginning of my Master's thesis work. The aim of these experiments was to determine whether:
1. A model trained with Knowledge Distillation actually performs better than the same model trained from scratch without a teacher;
2. A student model trained with the help of a pruned version of the teacher performs better than the 'classic' student model.



## Import statements

In [None]:
%%capture
!pip install pytorch_lightning

In [None]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import pytorch_lightning as pl
import torch.nn.functional as F
import torch.nn.utils.prune as prune

from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader
from tqdm import tqdm
from sklearn.model_selection import train_test_split
from pytorch_lightning.callbacks import EarlyStopping

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## Hyperparameters

In [None]:
BATCH_SIZE = 32
NUM_EPOCHS = 30
LEARNING_RATE = 0.0001
PATIENCE = 2
PRUNE_RATE = 0.3

# KD
TEMPERATURE = 7.0
ALPHA = 0.3
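
To make the role of `TEMPERATURE` more concrete, here is a small sketch of how dividing logits by a temperature greater than 1 flattens the softmax distribution, which is what lets a student see how the teacher ranks the non-target classes. The helper `soften` and the logits are made-up illustrations, not outputs of any model in this notebook.

```python
import torch

def soften(logits: torch.Tensor, temperature: float) -> torch.Tensor:
    """Temperature-scaled softmax over the last dimension."""
    return torch.softmax(logits / temperature, dim=-1)

# Hypothetical logits for a 4-class problem (made-up values).
example_logits = torch.tensor([4.0, 2.0, 1.0, 0.5])

sharp = soften(example_logits, temperature=1.0)
soft = soften(example_logits, temperature=7.0)  # same value as TEMPERATURE

print("T=1:", sharp)
print("T=7:", soft)
```

With the higher temperature the largest probability shrinks and the smaller ones grow, while the ranking of the classes stays the same.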

## Load dataset

In [None]:
def get_cifar_ds():
    train_ds = CIFAR10(root='./data', train=True, download=True, transform=transforms.ToTensor())
    test_ds = CIFAR10(root='./data', train=False, download=True, transform=transforms.ToTensor())

    # Split the official 10,000-image test set into validation (40%) and test (60%)
    val_ds, test_ds = train_test_split(
        test_ds, test_size=0.6, random_state=42
    )

    return train_ds, val_ds, test_ds

In [None]:
train_ds, val_ds, test_ds = get_cifar_ds()

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz
Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified
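
As a side note, the same 40/60 split can also be expressed with PyTorch's own `random_split`, which returns lazy `Subset` views instead of materialised samples. The sketch below is only an illustration under assumptions: it uses a tiny dummy `TensorDataset` as a stand-in for the 10,000-image test set, and a seeded `torch.Generator` in place of `random_state`. It would not reproduce the exact split used in this notebook.

```python
import torch
from torch.utils.data import TensorDataset, random_split

# Lightweight stand-in for the 10,000-image CIFAR-10 test set (dummy tensors, not real data).
dummy_test_ds = TensorDataset(
    torch.zeros(10_000, 1),
    torch.randint(0, 10, (10_000,)),
)

# A seeded generator makes the split reproducible.
generator = torch.Generator().manual_seed(42)
val_subset, test_subset = random_split(
    dummy_test_ds, [4_000, 6_000], generator=generator
)

print(len(val_subset), len(test_subset))
```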


## Teacher model

### Class definition

In [None]:
class Teacher(pl.LightningModule):
  def __init__(self):
    super(Teacher, self).__init__()
    self.conv1 = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3)
    self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3)
    self.conv3 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3)
    self.max_pool = nn.MaxPool2d(kernel_size=2, stride=2)
    self.fc1 = nn.Linear(128*2*2, 512)  # 2x2 image size after 3 max pool operations
    self.fc2 = nn.Linear(512, 1024)
    self.fc3 = nn.Linear(1024, 10)   # 10 num_classes

    self.correct = 0
    self.total = 0
    self.loss = 0
    self.epoch = 0

  # Pruning function
  def prune_model(self):
    parameters_to_prune = [
        (self.conv1, 'weight'),
        (self.conv2, 'weight'),
        (self.conv3, 'weight'),
        (self.fc1, 'weight'),
        (self.fc2, 'weight'),
        (self.fc3, 'weight')
    ]

    for module, name in parameters_to_prune:
        prune.l1_unstructured(module, name=name, amount=PRUNE_RATE)

  def forward(self, x):
    x = self.max_pool(F.relu(self.conv1(x)))
    x = self.max_pool(F.relu(self.conv2(x)))
    x = self.max_pool(F.relu(self.conv3(x)))
    x = torch.flatten(x, 1)   # Flatten all dimensions except batch
    x = F.relu(self.fc1(x))
    x = F.relu(self.fc2(x))
    x = self.fc3(x)
    return x

  def training_step(self, batch, batch_idx):
    inputs, targets = batch
    outputs = self(inputs)
    loss = nn.CrossEntropyLoss()(outputs, targets)
    self.loss = loss
    return loss

  def on_train_epoch_end(self):
    print(f"[Training] Epoch {self.epoch}; Loss: {self.loss:.5f}")
    self.epoch += 1

  def validation_step(self, batch, batch_idx):
    inputs, targets = batch
    outputs = self(inputs)
    _, predicted = torch.max(outputs.data, 1)
    self.total += targets.size(0)
    self.correct += (predicted == targets).sum().item()

  def on_validation_epoch_start(self):
    self.total = 0
    self.correct = 0

  def on_validation_epoch_end(self):
    accuracy = (self.correct / self.total)*100
    self.log('accuracy', accuracy)
    print(f"[Validation] Accuracy: {accuracy}")

  def test_step(self, batch, batch_idx):
    inputs, targets = batch
    outputs = self(inputs)
    _, predicted = torch.max(outputs.data, 1)
    self.total += targets.size(0)
    self.correct += (predicted == targets).sum().item()

  def on_test_epoch_start(self):
    self.total = 0
    self.correct = 0

  def on_test_epoch_end(self):
    print(f"Accuracy of the Teacher: {(self.correct / self.total)*100}")

  def configure_optimizers(self):
    return torch.optim.Adam(self.parameters(), lr=LEARNING_RATE)
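
The `128*2*2` input size of `fc1` follows from the spatial arithmetic of the convolution and pooling layers. The sketch below checks it by hand; the helper names `conv_out` and `feature_map_side` are illustrative, and the formula assumes square inputs, no padding and no dilation, as in the layers above.

```python
def conv_out(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Output side length of a square conv/pool layer (no dilation)."""
    return (size + 2 * padding - kernel) // stride + 1

def feature_map_side(input_side: int, num_blocks: int) -> int:
    """Side length after `num_blocks` of (3x3 conv, no padding) + (2x2 max pool, stride 2)."""
    side = input_side
    for _ in range(num_blocks):
        side = conv_out(side, kernel=3)            # Conv2d(kernel_size=3)
        side = conv_out(side, kernel=2, stride=2)  # MaxPool2d(kernel_size=2, stride=2)
    return side

print(feature_map_side(32, num_blocks=3))  # → 2 (32 → 30 → 15 → 13 → 6 → 4 → 2)
print(feature_map_side(32, num_blocks=2))  # → 6 (32 → 30 → 15 → 13 → 6)
```

With three blocks and 128 output channels, the flattened feature vector therefore has 128 × 2 × 2 = 512 entries.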

### Training

In [None]:
early_stop_callback = EarlyStopping(
    monitor="accuracy",
    min_delta=0.00,
    patience=PATIENCE,
    verbose=False,
    mode="max"
)
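
To illustrate what `patience` means here, the following pure-Python sketch mimics patience-based early stopping in `max` mode with `min_delta=0`. It is a simplified model written for this explanation, not Lightning's actual implementation, and the accuracy values are made up.

```python
def early_stop_epoch(accuracies, patience):
    """Return the index of the epoch after which training would stop, or None.

    A check counts as an improvement only if it is strictly greater than
    the best value seen so far (max mode, min_delta = 0).
    """
    best = float("-inf")
    bad_checks = 0
    for epoch, acc in enumerate(accuracies):
        if acc > best:
            best = acc
            bad_checks = 0
        else:
            bad_checks += 1
            if bad_checks >= patience:
                return epoch
    return None

# Made-up validation accuracies (not taken from any run in this notebook).
history = [40.0, 48.0, 53.0, 52.5, 54.0, 53.9, 53.8, 55.0]
print(early_stop_epoch(history, patience=2))  # → 6
```

A single bad epoch (index 3) does not stop training because the next one improves; two bad epochs in a row (indices 5 and 6) do.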

In [None]:
train_dataloader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
val_dataloader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False)
teacher = Teacher()
trainer = pl.Trainer(max_epochs=NUM_EPOCHS, accelerator="auto", callbacks=[early_stop_callback])
trainer.fit(teacher, train_dataloader, val_dataloader)

INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name     | Type      | Params
---------------------------------------
0 | conv1    | Conv2d    | 896   
1 | conv2    | Conv2d    | 18.5 K
2 | conv3    | Conv2d    | 73.9 K
3 | max_pool | MaxPool2d | 0     
4 | fc1      | Linear    | 262 K 
5 | fc2      | Linear    | 525 K 
6 | fc3      | Linear    | 10.2 K
---------------------------------------
891 K     Trainable params
0         Non-trainable params
891 K     Total params
3.566     Total estimated model params size (MB)


Sanity check:
[Validation] Accuracy: 14.0625

Training:
[Validation] Accuracy: 40.550000000000004
[Training] Epoch 0; Loss: 1.84011
[Validation] Accuracy: 48.975
[Training] Epoch 1; Loss: 1.63656
[Validation] Accuracy: 53.25
[Training] Epoch 2; Loss: 1.41573
[Validation] Accuracy: 56.825
[Training] Epoch 3; Loss: 0.88559
[Validation] Accuracy: 59.050000000000004
[Training] Epoch 4; Loss: 1.32926
[Validation] Accuracy: 60.475
[Training] Epoch 5; Loss: 1.32124
[Validation] Accuracy: 61.6
[Training] Epoch 6; Loss: 0.76470
[Validation] Accuracy: 63.675000000000004
[Training] Epoch 7; Loss: 1.13752
[Validation] Accuracy: 63.625
[Training] Epoch 8; Loss: 1.42374
[Validation] Accuracy: 63.9
[Training] Epoch 9; Loss: 1.02403
[Validation] Accuracy: 65.25
[Training] Epoch 10; Loss: 0.76448
[Validation] Accuracy: 66.925
[Training] Epoch 11; Loss: 0.47997
[Validation] Accuracy: 67.7
[Training] Epoch 12; Loss: 0.90444
[Validation] Accuracy: 67.30000000000001
[Training] Epoch 13; Loss: 0.67812
[Validation] Accuracy: 69.025
[Training] Epoch 14; Loss: 0.96085
[Validation] Accuracy: 67.72500000000001
[Training] Epoch 15; Loss: 0.48125
[Validation] Accuracy: 68.525
[Training] Epoch 16; Loss: 0.66664


### Evaluation

In [None]:
test_dataloader = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False)
tester = pl.Trainer(max_epochs=1, accelerator="auto")
tester.test(teacher, test_dataloader)

Accuracy of the Teacher: 68.85

## Smaller model

### Class definition

In [None]:
class Smaller(pl.LightningModule):
  def __init__(self):
    super(Smaller, self).__init__()
    self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3)
    self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3)
    self.max_pool = nn.MaxPool2d(kernel_size=2, stride=2)
    self.fc1 = nn.Linear(32*6*6, 10)  # 6x6 output after 2 max pooling operations with stride 2

    self.correct = 0
    self.total = 0
    self.loss = 0
    self.epoch = 0

  def forward(self, x):
    x = self.max_pool(F.relu(self.conv1(x)))
    x = self.max_pool(F.relu(self.conv2(x)))
    x = torch.flatten(x, 1)   # Flatten all dimensions except batch
    x = self.fc1(x)
    return x

  def training_step(self, batch, batch_idx):
    inputs, targets = batch
    outputs = self(inputs)
    loss = nn.CrossEntropyLoss()(outputs, targets)
    self.loss = loss
    return loss

  def on_train_epoch_end(self):
    print(f"[Training] Epoch {self.epoch}; Loss: {self.loss:.5f}")
    self.epoch += 1

  def validation_step(self, batch, batch_idx):
    inputs, targets = batch
    outputs = self(inputs)
    _, predicted = torch.max(outputs.data, 1)
    self.total += targets.size(0)
    self.correct += (predicted == targets).sum().item()

  def on_validation_epoch_start(self):
    self.total = 0
    self.correct = 0

  def on_validation_epoch_end(self):
    accuracy = (self.correct / self.total)*100
    self.log('accuracy', accuracy)
    print(f"[Validation] Accuracy: {accuracy}")

  def test_step(self, batch, batch_idx):
    inputs, targets = batch
    outputs = self(inputs)
    _, predicted = torch.max(outputs.data, 1)
    self.total += targets.size(0)
    self.correct += (predicted == targets).sum().item()

  def on_test_epoch_start(self):
    self.total = 0
    self.correct = 0

  def on_test_epoch_end(self):
    print(f"Accuracy of the Smaller: {(self.correct / self.total)*100}")

  def configure_optimizers(self):
    return torch.optim.Adam(self.parameters(), lr=LEARNING_RATE)

### Training

In [None]:
early_stop_callback = EarlyStopping(
    monitor="accuracy",
    min_delta=0.00,
    patience=PATIENCE,
    verbose=False,
    mode="max"
)

In [None]:
train_dataloader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
val_dataloader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False)
smaller = Smaller()
trainer = pl.Trainer(max_epochs=NUM_EPOCHS, accelerator="auto", callbacks=[early_stop_callback])
trainer.fit(smaller, train_dataloader, val_dataloader)

INFO:pytorch_lightning.callbacks.model_summary:
  | Name     | Type      | Params
---------------------------------------
0 | conv1    | Conv2d    | 448   
1 | conv2    | Conv2d    | 4.6 K 
2 | max_pool | MaxPool2d | 0     
3 | fc1      | Linear    | 11.5 K
---------------------------------------
16.6 K    Trainable params
0         Non-trainable params
16.6 K    Total params
0.066     Total estimated model params size (MB)
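
The parameter counts in the summary above can be checked by hand: each layer has one weight per input-output connection (times the kernel area for convolutions) plus one bias per output. The sketch below rebuilds the three layers with plain `torch.nn` modules; `count_params` is an illustrative helper.

```python
import torch.nn as nn

def count_params(module: nn.Module) -> int:
    """Total number of scalar parameters (weights and biases) in a module."""
    return sum(p.numel() for p in module.parameters())

conv1 = nn.Conv2d(3, 16, kernel_size=3)
conv2 = nn.Conv2d(16, 32, kernel_size=3)
fc1 = nn.Linear(32 * 6 * 6, 10)

print(count_params(conv1))  # 3*16*3*3 + 16   = 448
print(count_params(conv2))  # 16*32*3*3 + 32  = 4640
print(count_params(fc1))    # 1152*10 + 10    = 11530

total = sum(count_params(m) for m in (conv1, conv2, fc1))
print(total)                # 16618, i.e. the 16.6 K reported above
```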


Sanity check:
[Validation] Accuracy: 7.8125

Training:
[Validation] Accuracy: 37.55
[Training] Epoch 0; Loss: 1.80738
[Validation] Accuracy: 42.325
[Training] Epoch 1; Loss: 1.85470
[Validation] Accuracy: 45.4
[Training] Epoch 2; Loss: 1.70343
[Validation] Accuracy: 46.85
[Training] Epoch 3; Loss: 1.42019
[Validation] Accuracy: 49.175000000000004
[Training] Epoch 4; Loss: 1.42353
[Validation] Accuracy: 51.1
[Training] Epoch 5; Loss: 1.72377
[Validation] Accuracy: 51.175000000000004
[Training] Epoch 6; Loss: 1.52037
[Validation] Accuracy: 52.325
[Training] Epoch 7; Loss: 1.50035
[Validation] Accuracy: 52.275000000000006
[Training] Epoch 8; Loss: 1.27770
[Validation] Accuracy: 53.625
[Training] Epoch 9; Loss: 1.40056
[Validation] Accuracy: 54.35
[Training] Epoch 10; Loss: 1.69006
[Validation] Accuracy: 53.825
[Training] Epoch 11; Loss: 1.60522
[Validation] Accuracy: 54.75
[Training] Epoch 12; Loss: 1.45609
[Validation] Accuracy: 55.574999999999996
[Training] Epoch 13; Loss: 1.26430
[Validation] Accuracy: 56.525000000000006
[Training] Epoch 14; Loss: 1.74920
[Validation] Accuracy: 56.25
[Training] Epoch 15; Loss: 1.23316
[Validation] Accuracy: 55.95
[Training] Epoch 16; Loss: 1.19677


### Evaluation

In [None]:
test_dataloader = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False)
tester = pl.Trainer(max_epochs=1, accelerator="auto")
tester.test(smaller, test_dataloader)

Accuracy of the Smaller: 55.766666666666666

## Student model

### Class definition

In [None]:
class Student(pl.LightningModule):
  def __init__(self, teacher):
    super(Student, self).__init__()
    self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3)
    self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3)
    self.max_pool = nn.MaxPool2d(kernel_size=2, stride=2)
    self.fc1 = nn.Linear(32*6*6, 10)  # 6x6 output after 2 max pooling operations with stride 2

    self.teacher = teacher

    self.correct = 0
    self.total = 0
    self.loss = 0
    self.epoch = 0

  def forward(self, x):
    x = self.max_pool(F.relu(self.conv1(x)))
    x = self.max_pool(F.relu(self.conv2(x)))
    x = torch.flatten(x, 1)   # Flatten all dimensions except batch
    x = self.fc1(x)
    return x

  def training_step(self, batch, batch_idx):
    inputs, targets = batch
    inputs = inputs.to(device)
    targets = targets.to(device)

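    # Output logits. The teacher is called with gradients enabled and is registered
    # as a submodule, so the optimizer updates its weights together with the student's
    student_logits = self(inputs)
    teacher_logits = self.teacher(inputs)

    # Temperature-scaled distributions: probabilities for the student,
    # log-probabilities for the teacher
    student_probs = torch.softmax(student_logits / TEMPERATURE, dim=1)
    teacher_log_probs = torch.log_softmax(teacher_logits / TEMPERATURE, dim=1)

    # Compute distillation loss. nn.KLDivLoss takes (log-probabilities, probabilities),
    # so with these arguments the soft term is KL(student || teacher)
    distillation_loss = nn.KLDivLoss(reduction='batchmean')(teacher_log_probs, student_probs)
    hard_target_loss = nn.CrossEntropyLoss()(student_logits, targets)
    loss = ALPHA * hard_target_loss + (1.0 - ALPHA) * (TEMPERATURE**2) * distillation_loss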

    self.loss = loss
    return loss

  def on_train_epoch_end(self):
    print(f"[Training] Epoch {self.epoch}; Loss: {self.loss:.5f}")
    self.epoch += 1

  def validation_step(self, batch, batch_idx):
    inputs, targets = batch
    outputs = self(inputs)
    _, predicted = torch.max(outputs.data, 1)
    self.total += targets.size(0)
    self.correct += (predicted == targets).sum().item()

  def on_validation_epoch_start(self):
    self.total = 0
    self.correct = 0

  def on_validation_epoch_end(self):
    accuracy = (self.correct / self.total)*100
    self.log('accuracy', accuracy)
    print(f"[Validation] Accuracy: {accuracy}")

  def test_step(self, batch, batch_idx):
    inputs, targets = batch
    outputs = self(inputs)
    _, predicted = torch.max(outputs.data, 1)
    self.total += targets.size(0)
    self.correct += (predicted == targets).sum().item()

  def on_test_epoch_start(self):
    self.total = 0
    self.correct = 0

  def on_test_epoch_end(self):
    print(f"Accuracy of the Student: {(self.correct / self.total)*100}")

  def configure_optimizers(self):
    return torch.optim.Adam(self.parameters(), lr=LEARNING_RATE)
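
For comparison with the `training_step` above, here is a minimal sketch of the distillation loss as it is commonly formulated, with the soft-target term taken as KL(teacher || student) and the teacher's logits detached so that no gradient reaches the teacher. This is not the loss that produced the results in this notebook; the function name `kd_loss` and the random tensors are illustrative assumptions.

```python
import torch
import torch.nn.functional as F

def kd_loss(student_logits, teacher_logits, targets, temperature=7.0, alpha=0.3):
    """alpha * CE(student, targets) + (1 - alpha) * T^2 * KL(teacher || student)."""
    # The teacher only provides targets, so it is detached from the graph.
    teacher_probs = F.softmax(teacher_logits.detach() / temperature, dim=1)
    student_log_probs = F.log_softmax(student_logits / temperature, dim=1)

    # F.kl_div takes log-probabilities as input and probabilities as target.
    soft_loss = F.kl_div(student_log_probs, teacher_probs, reduction="batchmean")
    hard_loss = F.cross_entropy(student_logits, targets)
    return alpha * hard_loss + (1.0 - alpha) * temperature**2 * soft_loss

# Random tensors stand in for real model outputs (batch of 8, 10 classes).
student_logits = torch.randn(8, 10, requires_grad=True)
teacher_logits = torch.randn(8, 10, requires_grad=True)
targets = torch.randint(0, 10, (8,))

loss = kd_loss(student_logits, teacher_logits, targets)
loss.backward()

# Only the student receives a gradient.
print(loss.item(), student_logits.grad is not None, teacher_logits.grad is None)
```

The `temperature**2` factor keeps the magnitude of the soft-target gradients comparable to that of the hard-target term when the temperature changes.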

### Training

In [None]:
early_stop_callback = EarlyStopping(
    monitor="accuracy",
    min_delta=0.00,
    patience=PATIENCE,
    verbose=False,
    mode="max"
)

In [None]:
train_dataloader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
val_dataloader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False)

teacher = teacher.to(device)
student = Student(teacher=teacher).to(device)

trainer = pl.Trainer(max_epochs=NUM_EPOCHS, accelerator="auto", callbacks=[early_stop_callback])
trainer.fit(student, train_dataloader, val_dataloader)

INFO:pytorch_lightning.callbacks.model_summary:
  | Name     | Type      | Params
---------------------------------------
0 | conv1    | Conv2d    | 448   
1 | conv2    | Conv2d    | 4.6 K 
2 | max_pool | MaxPool2d | 0     
3 | fc1      | Linear    | 11.5 K
4 | teacher  | Teacher   | 891 K 
---------------------------------------
908 K     Trainable params
0         Non-trainable params
908 K     Total params
3.632     Total estimated model params size (MB)
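
The summary above counts the teacher's 891 K parameters as trainable: assigning the teacher to `self.teacher` registers it as a submodule, and `configure_optimizers` passes `self.parameters()` to Adam. If the intent is to keep the teacher fixed during distillation, one option is to freeze it before building the student. The sketch below shows the effect on a toy example; `TinyStudent`, `toy_teacher` and `trainable_params` are hypothetical names, not part of this notebook.

```python
import torch.nn as nn

class TinyStudent(nn.Module):
    """Hypothetical stand-in for a student that holds a reference to its teacher."""
    def __init__(self, teacher: nn.Module):
        super().__init__()
        self.fc = nn.Linear(4, 2)
        self.teacher = teacher  # registered as a submodule, so its parameters are included

def trainable_params(module: nn.Module) -> int:
    """Number of scalar parameters that an optimizer would update."""
    return sum(p.numel() for p in module.parameters() if p.requires_grad)

toy_teacher = nn.Linear(4, 8)
toy_student = TinyStudent(toy_teacher)
before = trainable_params(toy_student)

# Freeze the teacher: no gradients are stored for it and it stays in eval mode.
toy_teacher.requires_grad_(False)
toy_teacher.eval()
after = trainable_params(toy_student)

print(before, after)  # → 50 10
```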


Sanity check:
[Validation] Accuracy: 9.375

Training:
[Validation] Accuracy: 39.925
[Training] Epoch 0; Loss: 0.55292
[Validation] Accuracy: 42.875
[Training] Epoch 1; Loss: 0.48789
[Validation] Accuracy: 45.1
[Training] Epoch 2; Loss: 0.51295
[Validation] Accuracy: 47.225
[Training] Epoch 3; Loss: 0.36877
[Validation] Accuracy: 48.325
[Training] Epoch 4; Loss: 0.44236
[Validation] Accuracy: 49.85
[Training] Epoch 5; Loss: 0.35475
[Validation] Accuracy: 50.6
[Training] Epoch 6; Loss: 0.45082
[Validation] Accuracy: 51.225
[Training] Epoch 7; Loss: 0.32454
[Validation] Accuracy: 52.025
[Training] Epoch 8; Loss: 0.32183
[Validation] Accuracy: 52.275000000000006
[Training] Epoch 9; Loss: 0.41732
[Validation] Accuracy: 53.075
[Training] Epoch 10; Loss: 0.57040
[Validation] Accuracy: 54.025
[Training] Epoch 11; Loss: 0.51565
[Validation] Accuracy: 53.949999999999996
[Training] Epoch 12; Loss: 0.39206
[Validation] Accuracy: 54.6
[Training] Epoch 13; Loss: 0.36904
[Validation] Accuracy: 55.275
[Training] Epoch 14; Loss: 0.40552
[Validation] Accuracy: 56.15
[Training] Epoch 15; Loss: 0.59355
[Validation] Accuracy: 55.875
[Training] Epoch 16; Loss: 0.38063
[Validation] Accuracy: 56.025000000000006
[Training] Epoch 17; Loss: 0.32454


### Evaluation

In [None]:
test_dataloader = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False)
tester = pl.Trainer(max_epochs=1, accelerator="auto")
tester.test(student, test_dataloader)

Accuracy of the Student: 55.93333333333334

## Student model with Pruned Teacher

### Prune Teacher Model

In [None]:
# Note: this is an alias, not a copy, so pruning modifies `teacher` in place
pruned_teacher = teacher
pruned_teacher.prune_model()
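
To show what `prune.l1_unstructured` does to a layer, here is a small sketch on a toy `nn.Linear` (not the teacher itself). It also uses `copy.deepcopy` to keep an unpruned original around, which is an assumption about what one might want, since `pruned_teacher = teacher` binds a second name to the same object rather than copying it.

```python
import copy
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)
layer = nn.Linear(10, 10)            # toy layer with 100 weights
pruned_layer = copy.deepcopy(layer)  # independent copy; `layer` stays untouched

# Mask out the 30% of weights with the smallest absolute value.
prune.l1_unstructured(pruned_layer, name="weight", amount=0.3)

# Pruning keeps the original values in `weight_orig` and stores a binary
# `weight_mask` buffer; `weight` is recomputed as their product.
mask = pruned_layer.weight_mask
sparsity = (mask == 0).float().mean().item()
param_names = sorted(name for name, _ in pruned_layer.named_parameters())

print(f"sparsity: {sparsity:.2f}")
print(param_names)
print(hasattr(layer, "weight_mask"))  # the original layer has no mask
```

Note that unstructured pruning only zeroes individual weights through a mask; the tensors keep their shape, so the pruned model is not smaller or faster by itself.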

### Training

In [None]:
early_stop_callback = EarlyStopping(
    monitor="accuracy",
    min_delta=0.00,
    patience=PATIENCE,
    verbose=False,
    mode="max"
)

In [None]:
train_dataloader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
val_dataloader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False)

pruned_teacher = pruned_teacher.to(device)
student_from_pruned = Student(teacher=pruned_teacher).to(device)

trainer = pl.Trainer(max_epochs=NUM_EPOCHS, accelerator="auto", callbacks=[early_stop_callback])
trainer.fit(student_from_pruned, train_dataloader, val_dataloader)

INFO:pytorch_lightning.callbacks.model_summary:
  | Name     | Type      | Params
---------------------------------------
0 | conv1    | Conv2d    | 448   
1 | conv2    | Conv2d    | 4.6 K 
2 | max_pool | MaxPool2d | 0     
3 | fc1      | Linear    | 11.5 K
4 | teacher  | Teacher   | 891 K 
---------------------------------------
908 K     Trainable params
0         Non-trainable params
908 K     Total params
3.632     Total estimated model params size (MB)


Sanity check:
[Validation] Accuracy: 14.0625

Training:
[Validation] Accuracy: 40.675
[Training] Epoch 0; Loss: 0.51096
[Validation] Accuracy: 43.75
[Training] Epoch 1; Loss: 0.45765
[Validation] Accuracy: 46.550000000000004
[Training] Epoch 2; Loss: 0.49822
[Validation] Accuracy: 47.775
[Training] Epoch 3; Loss: 0.54996
[Validation] Accuracy: 49.6
[Training] Epoch 4; Loss: 0.41987
[Validation] Accuracy: 50.975
[Training] Epoch 5; Loss: 0.65538
[Validation] Accuracy: 51.975
[Training] Epoch 6; Loss: 0.43112
[Validation] Accuracy: 52.625
[Training] Epoch 7; Loss: 0.38454
[Validation] Accuracy: 52.55
[Training] Epoch 8; Loss: 0.34205
[Validation] Accuracy: 53.949999999999996
[Training] Epoch 9; Loss: 0.28224
[Validation] Accuracy: 54.574999999999996
[Training] Epoch 10; Loss: 0.37946
[Validation] Accuracy: 54.725
[Training] Epoch 11; Loss: 0.41869
[Validation] Accuracy: 54.65
[Training] Epoch 12; Loss: 0.39437
[Validation] Accuracy: 55.05
[Training] Epoch 13; Loss: 0.39401
[Validation] Accuracy: 55.300000000000004
[Training] Epoch 14; Loss: 0.60454
[Validation] Accuracy: 55.175
[Training] Epoch 15; Loss: 0.35902
[Validation] Accuracy: 56.225
[Training] Epoch 16; Loss: 0.43600
[Validation] Accuracy: 56.525000000000006
[Training] Epoch 17; Loss: 0.27491
[Validation] Accuracy: 57.225
[Training] Epoch 18; Loss: 0.33315
[Validation] Accuracy: 57.550000000000004
[Training] Epoch 19; Loss: 0.51958
[Validation] Accuracy: 57.675
[Training] Epoch 20; Loss: 0.43274
[Validation] Accuracy: 57.75
[Training] Epoch 21; Loss: 0.42854
[Validation] Accuracy: 57.775
[Training] Epoch 22; Loss: 0.34441
[Validation] Accuracy: 58.050000000000004
[Training] Epoch 23; Loss: 0.47838
[Validation] Accuracy: 57.85
[Training] Epoch 24; Loss: 0.33920
[Validation] Accuracy: 58.9
[Training] Epoch 25; Loss: 0.29173
[Validation] Accuracy: 59.25
[Training] Epoch 26; Loss: 0.41513
[Validation] Accuracy: 58.550000000000004
[Training] Epoch 27; Loss: 0.36572
[Validation] Accuracy: 59.475
[Training] Epoch 28; Loss: 0.41102
[Validation] Accuracy: 59.724999999999994
[Training] Epoch 29; Loss: 0.42707
INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=30` reached.


### Evaluation

In [None]:
test_dataloader = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False)
tester = pl.Trainer(max_epochs=1, accelerator="auto")
tester.test(student_from_pruned, test_dataloader)

Accuracy of the Student: 59.63333333333334

## Conclusion

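Results show that:
1. The student trained with knowledge distillation reached 55.93% test accuracy, against 55.77% for the identical model trained without distillation (`Smaller`). The student is ahead, but only by a small margin and on a single run;
2. The student trained with the pruned teacher reached 59.63% test accuracy, outperforming the student trained with the classical knowledge distillation technique (55.93%). Note, however, that this run trained for all 30 epochs, whereas the classical student was stopped early after 18, and that in general the performance heavily depends on the adopted pruning strategy.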