# Knowledge Distillation
- The concept of **knowledge distillation** is to utilize class probabilities of a higher-capacity model (teacher) as soft targets of a smaller model (student)
- The implement processes can be divided into several stages:
  1. Finish the `ResNet()` classes
  2. Train the teacher model (ResNet50) and the student model (ResNet18) from scratch, i.e. **without KD**
  3. Define the `Distiller()` class and `loss_re()`, `loss_fe()` functions
  4. Train the student model **with KD** from the teacher model in two different ways, response-based and feature based distillation
  5. Comparison of student models w/ & w/o KD

## Setup

In [17]:
# ! pip install torchinfo

In [18]:
import torch
from torch import nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms, models
from torch.utils.data import DataLoader, Dataset, random_split
from torchinfo import summary
from tqdm import tqdm
import sys
import numpy as np
import math
import matplotlib.pyplot as plt
import os
from PIL import Image

In [19]:
torch.manual_seed(0)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
torch.backends.cudnn.benchmark = True

## Download dataset

In [20]:
validation_split = 0.2
batch_size = 128

# data augmentation and normalization
transform_train = transforms.Compose([
                    transforms.RandomCrop(32, padding=4),
                    transforms.RandomHorizontalFlip(),
                    transforms.ToTensor(),
                    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])

transform_test = transforms.Compose([
                    transforms.ToTensor(),
                    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

# download dataset
train_and_val_dataset = torchvision.datasets.CIFAR10(
    root='dataset/',
    train=True,
    transform=transform_train,
    download=True
)

test_dataset = torchvision.datasets.CIFAR10(
    root='dataset/',
    train=False,
    transform=transform_test,
    download=True
)

# split train and validation dataset
train_size = int((1 - validation_split) * len(train_and_val_dataset))
val_size = len(train_and_val_dataset) - train_size
train_dataset, val_dataset = random_split(train_and_val_dataset, [train_size, val_size])

# create dataLoader
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(dataset=val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

test_num = len(test_dataset)
test_steps = len(test_loader)

## Create teacher and student models
### Define BottleNeck for ResNet50

In [21]:
class BottleNeck(nn.Module):
    expansion = 4

    def __init__(self, in_channel, out_channel, stride=1, downsample=None, **kwargs):
        super(BottleNeck, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=out_channel, kernel_size=1, stride=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channel)
        self.conv2 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channel)
        self.conv3 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel * self.expansion, kernel_size=1, stride=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_channel * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample

    def forward(self, x):
        identity = x
        if self.downsample is not None:
            identity = self.downsample(x)

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        out += identity
        out = self.relu(out)

        return out

### Define Resifual Block

In [22]:
class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_channel, out_channel, stride=1, downsample=None, **kwargs):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=out_channel, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channel)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channel)
        self.downsample = downsample

    def forward(self, x):
        identity = x
        if self.downsample is not None:
            identity = self.downsample(x)

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        out += identity
        out = self.relu(out)

        return out

### Define ResNet Model

In [23]:
class ResNet(nn.Module):

    def __init__(self, block, blocks_num, num_classes=1000):
        super(ResNet, self).__init__()
        self.in_channel = 64

        self.conv1 = nn.Conv2d(3, self.in_channel, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(self.in_channel)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, blocks_num[0])
        self.layer2 = self._make_layer(block, 128, blocks_num[1], stride=2)
        self.layer3 = self._make_layer(block, 256, blocks_num[2], stride=2)
        self.layer4 = self._make_layer(block, 512, blocks_num[3], stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')

    def _make_layer(self, block, channel, block_num, stride=1):
        downsample = None
        if stride != 1 or self.in_channel != channel * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.in_channel, channel * block.expansion, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(channel * block.expansion))

        layers = []
        layers.append(block(self.in_channel, channel, downsample=downsample, stride=stride))
        self.in_channel = channel * block.expansion

        for _ in range(1, block_num):
            layers.append(block(self.in_channel, channel))

        return nn.Sequential(*layers)

    def forward(self, x):
        # 1. Finish the forward pass and return the output layer as well as hidden features.
        # 2. The output layer and hidden features will be used later for distilling.
        # 3. You can refer to the ResNet structure illustration to finish it.
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        #layer output
        feature1 = self.layer1(x)
        feature2 = self.layer2(feature1)
        feature3 = self.layer3(feature2)
        feature4 = self.layer4(feature3)

        out = self.avgpool(feature4)
        out = torch.flatten(out, 1)
        x = self.fc(out)
        return x, [feature1, feature2, feature3, feature4]

### Define ResNet50 and Resnet18

In [24]:
def resnet18(num_classes=10):
    return ResNet(BasicBlock, [2, 2, 2, 2], num_classes=num_classes)

def resnet50(num_classes=10):
    return ResNet(BottleNeck, [3, 4, 6, 3], num_classes=num_classes)

## Teacher Model (ResNet50)

In [99]:
# Teacher = resnet50(num_classes=10)  # commment out this line if loading trained teacher model
Teacher = torch.load('Teacher1.pt', weights_only=False)  # loading trained teacher model
Teacher = Teacher.to(device)

In [41]:
summary(Teacher)

Layer (type:depth-idx)                   Param #
ResNet                                   --
├─Conv2d: 1-1                            1,728
├─BatchNorm2d: 1-2                       128
├─ReLU: 1-3                              --
├─MaxPool2d: 1-4                         --
├─Sequential: 1-5                        --
│    └─BottleNeck: 2-1                   --
│    │    └─Conv2d: 3-1                  4,096
│    │    └─BatchNorm2d: 3-2             128
│    │    └─Conv2d: 3-3                  36,864
│    │    └─BatchNorm2d: 3-4             128
│    │    └─Conv2d: 3-5                  16,384
│    │    └─BatchNorm2d: 3-6             512
│    │    └─ReLU: 3-7                    --
│    │    └─Sequential: 3-8              16,896
│    └─BottleNeck: 2-2                   --
│    │    └─Conv2d: 3-9                  16,384
│    │    └─BatchNorm2d: 3-10            128
│    │    └─Conv2d: 3-11                 36,864
│    │    └─BatchNorm2d: 3-12            128
│    │    └─Conv2d: 3-13               

## Student Model (ResNet18)

In [100]:
# Student = resnet18(num_classes=10)  # commment out this line if loading trained student model
Student = torch.load('Student1.pt', weights_only=False)  # loading trained student model
Student = Student.to(device)

In [44]:
summary(Student)

Layer (type:depth-idx)                   Param #
ResNet                                   --
├─Conv2d: 1-1                            1,728
├─BatchNorm2d: 1-2                       128
├─ReLU: 1-3                              --
├─MaxPool2d: 1-4                         --
├─Sequential: 1-5                        --
│    └─BasicBlock: 2-1                   --
│    │    └─Conv2d: 3-1                  36,864
│    │    └─BatchNorm2d: 3-2             128
│    │    └─ReLU: 3-3                    --
│    │    └─Conv2d: 3-4                  36,864
│    │    └─BatchNorm2d: 3-5             128
│    └─BasicBlock: 2-2                   --
│    │    └─Conv2d: 3-6                  36,864
│    │    └─BatchNorm2d: 3-7             128
│    │    └─ReLU: 3-8                    --
│    │    └─Conv2d: 3-9                  36,864
│    │    └─BatchNorm2d: 3-10            128
├─Sequential: 1-6                        --
│    └─BasicBlock: 2-3                   --
│    │    └─Conv2d: 3-11                 73,728

## Define training function

In [29]:
def train_from_scratch(model, train_loader, val_loader, epochs, learning_rate, device, model_name):
    criterion = nn.CrossEntropyLoss()
    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.Adam(params, lr=learning_rate)

    loss = []
    train_error=[]
    val_error = []
    valdation_error = []
    train_loss = []
    valdation_loss = []
    train_accuraacy = []
    valdation_accuracy= []

    for epoch in range(epochs):
        train_loss = 0.0
        valid_loss = 0.0
        train_acc = 0.0
        valid_acc = 0.0
        correct = 0.
        total = 0.
        V_correct = 0.
        V_total = 0.

        model.train()
        train_bar = tqdm(train_loader, file=sys.stdout)
        for step, data in enumerate(train_bar):
            images, labels = data
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            logits, hidden = model(images)
            loss = criterion(logits, labels)
            loss.backward()
            optimizer.step()
            train_loss += loss.item() * images.size(0)
            pred = logits.data.max(1, keepdim=True)[1]
            correct += np.sum(np.squeeze(pred.eq(labels.data.view_as(pred))).cpu().numpy())
            total += images.size(0)
            train_acc =  correct/total
            train_bar.desc = "train epoch[{}/{}]".format(epoch + 1, epochs)

        model.eval()
        with torch.no_grad():
            val_bar = tqdm(val_loader, file=sys.stdout)
            for val_data in val_bar:
                val_images, val_labels = val_data
                val_images, val_labels = val_images.to(device), val_labels.to(device)
                outputs, hidden_outputs = model(val_images)
                loss = criterion(outputs, val_labels)
                valid_loss += loss.item() * val_images.size(0)
                pred = outputs.data.max(1, keepdim=True)[1]
                V_correct += np.sum(np.squeeze(pred.eq(val_labels.data.view_as(pred))).cpu().numpy())
                V_total += val_images.size(0)
                val_bar.desc = "valid epoch[{}/{}]".format(epoch + 1, epochs)

        train_loss = train_loss / len(train_loader.dataset)
        train_error.append(train_loss)
        valid_loss = valid_loss / len(val_loader.dataset)
        val_error.append(valid_loss)
        train_accuraacy.append( correct / total)
        valdation_accuracy.append(V_correct / V_total)

        print('\tTraining Loss: {:.6f} \tValidation Loss: {:.6f}'.format(train_loss, valid_loss))
        print('\tTrain Accuracy: %.3fd%% (%2d/%2d)\tValdation Accuracy: %.3fd%% (%2d/%2d) '% (100. * correct / total, correct, total, 100. * V_correct / V_total, V_correct, V_total))

    torch.save(model, f'{model_name}.pt')
    print(f'{model_name}.pt is saved')

    print('Finished Training')

## Define testing function

In [30]:
def test(model, test_loader ,device, type=None):
    criterion = nn.CrossEntropyLoss()
    acc = 0.0
    test_loss = 0.0

    if type == None:
        model.eval()
    elif type == 'distiller':
        model.eval()
        model.teacher.eval()
        model.student.eval()
    else:
       raise ValueError(f'Error: only support response-based and feature-based distillation')

    with torch.no_grad():
        test_bar = tqdm(test_loader, file=sys.stdout)
        for test_data in test_bar:
            test_images, test_labels = test_data
            test_images, test_labels = test_images.to(device), test_labels.to(device)
            if type == None:
                outputs, features = model(test_images)
                loss = criterion(outputs, test_labels)
            elif type == 'distiller':
                outputs, loss = model(test_images, test_labels)
            else:
                raise ValueError(f'Error: only support response-based and feature-based distillation')

            predict_y = torch.max(outputs, dim=1)[1]
            acc += torch.eq(predict_y, test_labels.to(device)).sum().item()
            test_loss += loss.item()
            test_bar.desc = "test"

    test_accurate = acc / test_num
    print('test_loss: %.3f  test_accuracy: %.3f' %(test_loss / test_steps, test_accurate * 100))
    return test_loss / test_steps, test_accurate * 100.

## Train Teacher and Student model from scratch

In [35]:
# Decide the epochs and learning rate
train_from_scratch(Teacher, train_loader, val_loader, epochs=50 , learning_rate= 0.01, device=device, model_name="Teacher")

train epoch[1/50]: 100%|██████████| 313/313 [00:27<00:00, 11.22it/s]
valid epoch[1/50]: 100%|██████████| 79/79 [00:04<00:00, 17.34it/s]
	Training Loss: 0.360963 	Validation Loss: 0.449705
	Train Accuracy: 87.595d% (35038/40000)	Valdation Accuracy: 85.120d% (8512/10000) 
train epoch[2/50]: 100%|██████████| 313/313 [00:29<00:00, 10.78it/s]
valid epoch[2/50]: 100%|██████████| 79/79 [00:04<00:00, 17.72it/s]
	Training Loss: 0.404853 	Validation Loss: 0.837649
	Train Accuracy: 86.470d% (34588/40000)	Valdation Accuracy: 73.000d% (7300/10000) 
train epoch[3/50]: 100%|██████████| 313/313 [00:28<00:00, 11.17it/s]
valid epoch[3/50]: 100%|██████████| 79/79 [00:04<00:00, 17.80it/s]
	Training Loss: 0.457830 	Validation Loss: 0.486907
	Train Accuracy: 84.567d% (33827/40000)	Valdation Accuracy: 84.020d% (8402/10000) 
train epoch[4/50]: 100%|██████████| 313/313 [00:27<00:00, 11.28it/s]
valid epoch[4/50]: 100%|██████████| 79/79 [00:04<00:00, 17.67it/s]
	Training Loss: 0.314713 	Validation Loss: 0.451118

In [36]:
T_loss, T_accuracy = test(Teacher, test_loader, device=device)

test: 100%|██████████| 79/79 [00:03<00:00, 25.64it/s]
test_loss: 0.486  test_accuracy: 87.980


In [37]:
# Decide the epochs and learning rate
train_from_scratch(Student, train_loader, val_loader, epochs=50 , learning_rate= 0.01, device=device, model_name="Student")

train epoch[1/50]: 100%|██████████| 313/313 [00:20<00:00, 15.21it/s]
valid epoch[1/50]: 100%|██████████| 79/79 [00:03<00:00, 20.14it/s]
	Training Loss: 1.795088 	Validation Loss: 1.594286
	Train Accuracy: 34.275d% (13710/40000)	Valdation Accuracy: 41.180d% (4118/10000) 
train epoch[2/50]: 100%|██████████| 313/313 [00:20<00:00, 15.45it/s]
valid epoch[2/50]: 100%|██████████| 79/79 [00:03<00:00, 20.55it/s]
	Training Loss: 1.307280 	Validation Loss: 1.252743
	Train Accuracy: 52.710d% (21084/40000)	Valdation Accuracy: 55.250d% (5525/10000) 
train epoch[3/50]: 100%|██████████| 313/313 [00:20<00:00, 15.47it/s]
valid epoch[3/50]: 100%|██████████| 79/79 [00:03<00:00, 20.48it/s]
	Training Loss: 1.086648 	Validation Loss: 1.104006
	Train Accuracy: 61.085d% (24434/40000)	Valdation Accuracy: 60.710d% (6071/10000) 
train epoch[4/50]: 100%|██████████| 313/313 [00:20<00:00, 15.47it/s]
valid epoch[4/50]: 100%|██████████| 79/79 [00:03<00:00, 20.64it/s]
	Training Loss: 0.956561 	Validation Loss: 0.966722

In [38]:
S_loss, S_accuracy = test(Student, test_loader, device=device)

test: 100%|██████████| 79/79 [00:02<00:00, 32.26it/s]
test_loss: 0.471  test_accuracy: 88.210


## Define distillation

### Define the loss functions

In [None]:
# Finish the loss function for response-based distillation.
def loss_re(student_logits, teacher_logits, labels):
    T = 5# Set temperature parameter
    alpha = 0.5 # Set weighting parameter

    # Implement loss calculation
    Student_prob = F.log_softmax(student_logits / T, dim=1)
    teacher_prob = F.softmax(teacher_logits / T, dim=1)
    kd_loss = F.kl_div(Student_prob, teacher_prob, reduction='batchmean') * (T * T)

    #caculate 標準分類損失（直接用原始分數算 cross-entropy，和 ground truth label 做對比。）
    ce_loss = F.cross_entropy(student_logits, labels)
    loss = alpha * kd_loss + (1 - alpha) * ce_loss #(1 - alpha) * ce_loss：偏向 ground-truth（教材）的部份。alpha * kd_loss：偏向模仿老師的部份。
    return loss

In [103]:
# Finish the loss function for feature-based distillation.
def loss_fe(student_features, teacher_features, adapters, student_logits, labels, alpha=0.5):
    feature_loss = 0
    n = len(student_features)
    for i, (s_feat, t_feat) in enumerate(zip(student_features, teacher_features)):
        s_feat_adapted = adapters[i](s_feat)
        feature_loss += F.mse_loss(s_feat_adapted, t_feat)
    feature_loss = feature_loss / n
    ce_loss = F.cross_entropy(student_logits, labels)
    loss = alpha * feature_loss + (1 - alpha) * ce_loss
    return loss


### Define Distillation Framework

In [92]:
class Distiller(nn.Module):
    def __init__(self, teacher, student, type):
        super(Distiller, self).__init__()

        # 1. Finish the __init__ method.
        self.teacher = teacher
        self.student = student
        self.type = type

        if type == 'feature':
            device = next(student.parameters()).device
            with torch.no_grad():
                dummy = torch.randn(1, 3, 32, 32, device=device)
                _, s_features = self.student(dummy)
                _, t_features = self.teacher(dummy)
                s_channels = [f.shape[1] for f in s_features]
                t_channels = [f.shape[1] for f in t_features]
            self.adapters = nn.ModuleList([
                nn.Conv2d(s_c, t_c, kernel_size=1).to(device) if s_c != t_c else nn.Identity().to(device)
                for s_c, t_c in zip(s_channels, t_channels)
            ])
        else:
            self.adapters = None

    def forward(self, x, target):
        # 2. Finish the forward pass.
        student_logits, student_features = self.student(x)
        with torch.no_grad():
            teacher_logits, teacher_features = self.teacher(x)
        if self.type == 'response':
            loss_distill =  loss_re(student_logits,teacher_logits,target)# call the loss_re()
        elif self.type == 'feature':
            loss_distill = loss_fe(student_features,teacher_features,self.adapters,student_logits,target,alpha=0.5)# call the loss_re()
        else:
            raise ValueError(f'Error: only support response-based and feature-based distillation')

        return student_logits, loss_distill

### Training function

In [86]:
def train_distillation(distiller, student, train_loader, val_loader, epochs, learning_rate, device):
    ce_loss = nn.CrossEntropyLoss()
    # define the parameter the optimizer used
    optimizer = torch.optim.Adam(student.parameters(), lr=learning_rate)

    loss = []
    train_error=[]
    val_error = []
    valdation_error = []
    train_loss = []
    valdation_loss = []
    train_accuraacy = []
    valdation_accuracy= []

    for epoch in range(epochs):
        distiller.train()
        distiller.teacher.train()
        distiller.student.train()

        train_loss = 0.0
        valid_loss = 0.0
        train_acc = 0.0
        valid_acc  = 0.0
        correct = 0.
        total = 0.
        V_correct = 0.
        V_total = 0.
        train_bar = tqdm(train_loader, file=sys.stdout)
        for step, data in enumerate(train_bar):
            images, labels = data
            images, labels = images.to(device), labels.to(device)

            outputs, loss = distiller(images, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_loss += loss.item() * images.size(0)
            pred = outputs.data.max(1, keepdim=True)[1]
            result = pred.eq(labels.data.view_as(pred))
            result = np.squeeze(result.cpu().numpy())
            correct += np.sum(result)
            total += images.size(0)
            train_bar.desc = "train epoch[{}/{}]".format(epoch + 1, epochs)

        distiller.eval()
        distiller.teacher.eval()
        distiller.student.eval()

        with torch.no_grad():
            val_bar = tqdm(val_loader, file=sys.stdout)
            for val_data in val_bar:

                val_images, val_labels = val_data
                val_images, val_labels = val_images.to(device), val_labels.to(device)

                outputs, loss = distiller(val_images, val_labels)

                valid_loss += loss.item() * val_images.size(0)
                pred = outputs.max(1, keepdim=True)[1]
                V_correct += np.sum(np.squeeze(pred.eq(val_labels.data.view_as(pred))).cpu().numpy())
                V_total += val_images.size(0)
                val_bar.desc = "valid epoch[{}/{}]".format(epoch + 1, epochs)

        train_loss = train_loss / len(train_loader.dataset)
        train_error.append(train_loss)
        valid_loss = valid_loss / len(val_loader.dataset)
        val_error.append(valid_loss)
        train_accuraacy.append( correct / total)
        valdation_accuracy.append(V_correct / V_total)

        print('\tTraining Loss: {:.6f} \tValidation Loss: {:.6f}'.format(train_loss, valid_loss))
        print('\tTrain Accuracy: %.3fd%% (%2d/%2d)\tValdation Accuracy: %.3fd%% (%2d/%2d) '% (100. * correct / total, correct, total, 100. * V_correct / V_total, V_correct, V_total))

    print('Finished Distilling')

## Response-based distillation

In [None]:
# Decide the epochs and learning rate
Student_re = resnet18(num_classes=10)
Student_re = Student_re.to(device)
distiller_re = Distiller(Teacher, Student_re, type='response')
train_distillation(distiller_re, Student_re, train_loader, val_loader, epochs= 50, learning_rate= 0.01, device=device)

train epoch[1/20]: 100%|██████████| 313/313 [00:22<00:00, 13.71it/s]
valid epoch[1/20]: 100%|██████████| 79/79 [00:04<00:00, 16.67it/s]
	Training Loss: 8.980219 	Validation Loss: 9.373297
	Train Accuracy: 39.990d% (15996/40000)	Valdation Accuracy: 41.790d% (4179/10000) 
train epoch[2/20]: 100%|██████████| 313/313 [00:22<00:00, 13.69it/s]
valid epoch[2/20]: 100%|██████████| 79/79 [00:04<00:00, 16.65it/s]
	Training Loss: 5.792336 	Validation Loss: 5.118157
	Train Accuracy: 58.335d% (23334/40000)	Valdation Accuracy: 60.880d% (6088/10000) 
train epoch[3/20]: 100%|██████████| 313/313 [00:22<00:00, 13.61it/s]
valid epoch[3/20]: 100%|██████████| 79/79 [00:04<00:00, 16.75it/s]
	Training Loss: 4.449553 	Validation Loss: 4.231401
	Train Accuracy: 66.270d% (26508/40000)	Valdation Accuracy: 65.740d% (6574/10000) 
train epoch[4/20]: 100%|██████████| 313/313 [00:22<00:00, 13.62it/s]
valid epoch[4/20]: 100%|██████████| 79/79 [00:04<00:00, 16.76it/s]
	Training Loss: 3.553749 	Validation Loss: 3.316957

In [55]:
reS_loss, reS_accuracy = test(distiller_re, test_loader, type='distiller', device=device)

test: 100%|██████████| 79/79 [00:03<00:00, 23.90it/s]
test_loss: 1.481  test_accuracy: 85.250


## Feature-based distillation

In [None]:
# Decide the epochs and learning rate
Student_fe = resnet18(num_classes=10)
Student_fe = Student_fe.to(device)
distiller_fe = Distiller(Teacher, Student_fe, type='feature')
train_distillation(distiller_fe, Student_fe, train_loader, val_loader, epochs= 50, learning_rate=0.01 , device=device)

train epoch[1/30]: 100%|██████████| 313/313 [00:23<00:00, 13.05it/s]
valid epoch[1/30]: 100%|██████████| 79/79 [00:04<00:00, 16.17it/s]
	Training Loss: 4.078429 	Validation Loss: 3.893579
	Train Accuracy: 35.270d% (14108/40000)	Valdation Accuracy: 43.550d% (4355/10000) 
train epoch[2/30]: 100%|██████████| 313/313 [00:23<00:00, 13.09it/s]
valid epoch[2/30]: 100%|██████████| 79/79 [00:04<00:00, 16.19it/s]
	Training Loss: 3.739371 	Validation Loss: 3.935662
	Train Accuracy: 53.930d% (21572/40000)	Valdation Accuracy: 52.250d% (5225/10000) 
train epoch[3/30]: 100%|██████████| 313/313 [00:23<00:00, 13.11it/s]
valid epoch[3/30]: 100%|██████████| 79/79 [00:04<00:00, 16.65it/s]
	Training Loss: 3.619781 	Validation Loss: 3.719208
	Train Accuracy: 62.435d% (24974/40000)	Valdation Accuracy: 56.740d% (5674/10000) 
train epoch[4/30]: 100%|██████████| 313/313 [00:23<00:00, 13.17it/s]
valid epoch[4/30]: 100%|██████████| 79/79 [00:04<00:00, 16.50it/s]
	Training Loss: 3.542092 	Validation Loss: 3.599402

In [97]:
ftS_loss, ftS_accuracy = test(distiller_fe, test_loader, type='distiller', device=device)

test: 100%|██████████| 79/79 [00:03<00:00, 23.06it/s]
test_loss: 3.447  test_accuracy: 85.850


## Result and Comparison

In [98]:
print(f'Teacher from scratch: loss = {T_loss:.2f}, accuracy = {T_accuracy:.2f}')
print(f'Student from scratch: loss = {S_loss:.2f}, accuracy = {S_accuracy:.2f}')
print(f'Response-based student: loss = {reS_loss:.2f}, accuracy = {reS_accuracy:.2f}')
print(f'Featured-based student: loss = {ftS_loss:.2f}, accuracy = {ftS_accuracy:.2f}')

Teacher from scratch: loss = 0.49, accuracy = 87.98
Student from scratch: loss = 0.47, accuracy = 88.21
Response-based student: loss = 1.48, accuracy = 85.25
Featured-based student: loss = 3.45, accuracy = 85.85
