# Knowledge Distillation
- The concept of **knowledge distillation** is to utilize class probabilities of a higher-capacity model (teacher) as soft targets of a smaller model (student)
- The implement processes can be divided into several stages:
  1. Finish the `ResNet()` classes
  2. Train the teacher model (ResNet50) and the student model (ResNet18) from scratch, i.e. **without KD**
  3. Define the `Distiller()` class and `loss_re()`, `loss_fe()` functions
  4. Train the student model **with KD** from the teacher model in two different ways, response-based and feature based distillation
  5. Comparison of student models w/ & w/o KD

## Setup

In [1]:
!pip install torchinfo



In [2]:
import torch
from torch import nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms, models
from torch.utils.data import DataLoader, Dataset, random_split
from torchinfo import summary
from tqdm import tqdm
import sys
import numpy as np
import math
import matplotlib.pyplot as plt
import os
from PIL import Image

In [16]:
torch.manual_seed(0)

# 設置設備：支援多 GPU
if torch.cuda.is_available():
    device = torch.device('cuda')
    # 檢查可用的 GPU 數量
    gpu_count = torch.cuda.device_count()
    print(f'使用 GPU 訓練，可用 GPU 數量: {gpu_count}')
    for i in range(gpu_count):
        print(f'  GPU {i}: {torch.cuda.get_device_name(i)}')
    use_multi_gpu = gpu_count > 1
else:
    device = torch.device('cpu')
    use_multi_gpu = False
    print('使用 CPU 訓練')

torch.backends.cudnn.benchmark = True

使用 GPU 訓練，可用 GPU 數量: 2
  GPU 0: Tesla V100-SXM2-32GB
  GPU 1: Tesla V100-SXM2-32GB


## Download dataset

In [4]:
validation_split = 0.2
batch_size = 128

# data augmentation and normalization
transform_train = transforms.Compose([
                    transforms.RandomCrop(32, padding=4),
                    transforms.RandomHorizontalFlip(),
                    transforms.ToTensor(),
                    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])

transform_test = transforms.Compose([
                    transforms.ToTensor(),
                    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

# download dataset
train_and_val_dataset = torchvision.datasets.CIFAR10(
    root='dataset/',
    train=True,
    transform=transform_train,
    download=True
)

test_dataset = torchvision.datasets.CIFAR10(
    root='dataset/',
    train=False,
    transform=transform_test,
    download=True
)

# split train and validation dataset
train_size = int((1 - validation_split) * len(train_and_val_dataset))
val_size = len(train_and_val_dataset) - train_size
train_dataset, val_dataset = random_split(train_and_val_dataset, [train_size, val_size])

# create dataLoader
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(dataset=val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

test_num = len(test_dataset)
test_steps = len(test_loader)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to dataset/cifar-10-python.tar.gz


100%|██████████| 170M/170M [00:12<00:00, 13.9MB/s] 


Extracting dataset/cifar-10-python.tar.gz to dataset/
Files already downloaded and verified


## Create teacher and student models
### Define BottleNeck for ResNet50

In [5]:
class BottleNeck(nn.Module):
    expansion = 4

    def __init__(self, in_channel, out_channel, stride=1, downsample=None, **kwargs):
        super(BottleNeck, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=out_channel, kernel_size=1, stride=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channel)
        self.conv2 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channel)
        self.conv3 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel * self.expansion, kernel_size=1, stride=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_channel * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample

    def forward(self, x):
        identity = x
        if self.downsample is not None:
            identity = self.downsample(x)

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        out += identity
        out = self.relu(out)

        return out

### Define Resifual Block

In [6]:
class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_channel, out_channel, stride=1, downsample=None, **kwargs):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=out_channel, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channel)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channel)
        self.downsample = downsample

    def forward(self, x):
        identity = x
        if self.downsample is not None:
            identity = self.downsample(x)

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        out += identity
        out = self.relu(out)

        return out

### Define ResNet Model

In [7]:
class ResNet(nn.Module):

    def __init__(self, block, blocks_num, num_classes=1000):
        super(ResNet, self).__init__()
        self.in_channel = 64

        self.conv1 = nn.Conv2d(3, self.in_channel, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(self.in_channel)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, blocks_num[0])
        self.layer2 = self._make_layer(block, 128, blocks_num[1], stride=2)
        self.layer3 = self._make_layer(block, 256, blocks_num[2], stride=2)
        self.layer4 = self._make_layer(block, 512, blocks_num[3], stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')

    def _make_layer(self, block, channel, block_num, stride=1):
        downsample = None
        if stride != 1 or self.in_channel != channel * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.in_channel, channel * block.expansion, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(channel * block.expansion))

        layers = []
        layers.append(block(self.in_channel, channel, downsample=downsample, stride=stride))
        self.in_channel = channel * block.expansion

        for _ in range(1, block_num):
            layers.append(block(self.in_channel, channel))

        return nn.Sequential(*layers)

    def forward(self, x):
        # 1. Finish the forward pass and return the output layer as well as hidden features.
        # 2. The output layer and hidden features will be used later for distilling.
        # 3. You can refer to the ResNet structure illustration to finish it.

        # 初始卷積層
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        
        # 通過四個 ResNet 層，並保存每層的輸出作為特徵
        feature1 = self.layer1(x)
        feature2 = self.layer2(feature1)
        feature3 = self.layer3(feature2)
        feature4 = self.layer4(feature3)
        
        # 全局平均池化
        x = self.avgpool(feature4)
        # 展平為一維向量
        x = torch.flatten(x, 1)
        # 全連接層得到最終輸出
        x = self.fc(x)

        return x, [feature1, feature2, feature3, feature4]

### Define ResNet50 and Resnet18

In [8]:
def resnet18(num_classes=10):
    return ResNet(BasicBlock, [2, 2, 2, 2], num_classes=num_classes)

def resnet50(num_classes=10):
    return ResNet(BottleNeck, [3, 4, 6, 3], num_classes=num_classes)

## Teacher Model (ResNet50)

In [18]:
Teacher = resnet50(num_classes=10)  # commment out this line if loading trained teacher model
# Teacher = torch.load('Teacher.pt', weights_only=False)  # loading trained teacher model
Teacher = Teacher.to(device)

# 如果有多個 GPU，使用 DataParallel 進行並行訓練
if use_multi_gpu:
    Teacher = nn.DataParallel(Teacher)
    print(f'Teacher 模型使用 {torch.cuda.device_count()} 個 GPU 進行訓練')

Teacher 模型使用 2 個 GPU 進行訓練


In [19]:
summary(Teacher)

Layer (type:depth-idx)                        Param #
DataParallel                                  --
├─ResNet: 1-1                                 --
│    └─Conv2d: 2-1                            1,728
│    └─BatchNorm2d: 2-2                       128
│    └─ReLU: 2-3                              --
│    └─MaxPool2d: 2-4                         --
│    └─Sequential: 2-5                        --
│    │    └─BottleNeck: 3-1                   75,008
│    │    └─BottleNeck: 3-2                   70,400
│    │    └─BottleNeck: 3-3                   70,400
│    └─Sequential: 2-6                        --
│    │    └─BottleNeck: 3-4                   379,392
│    │    └─BottleNeck: 3-5                   280,064
│    │    └─BottleNeck: 3-6                   280,064
│    │    └─BottleNeck: 3-7                   280,064
│    └─Sequential: 2-7                        --
│    │    └─BottleNeck: 3-8                   1,512,448
│    │    └─BottleNeck: 3-9                   1,117,184
│    │    └─Bo

## Student Model (ResNet18)

In [20]:
Student = resnet18(num_classes=10)  # commment out this line if loading trained student model
# Student = torch.load('Student.pt', weights_only=False)  # loading trained student model
Student = Student.to(device)

# 如果有多個 GPU，使用 DataParallel 進行並行訓練
if use_multi_gpu:
    Student = nn.DataParallel(Student)
    print(f'Student 模型使用 {torch.cuda.device_count()} 個 GPU 進行訓練')

Student 模型使用 2 個 GPU 進行訓練


In [12]:
summary(Student)

Layer (type:depth-idx)                   Param #
ResNet                                   --
├─Conv2d: 1-1                            1,728
├─BatchNorm2d: 1-2                       128
├─ReLU: 1-3                              --
├─MaxPool2d: 1-4                         --
├─Sequential: 1-5                        --
│    └─BasicBlock: 2-1                   --
│    │    └─Conv2d: 3-1                  36,864
│    │    └─BatchNorm2d: 3-2             128
│    │    └─ReLU: 3-3                    --
│    │    └─Conv2d: 3-4                  36,864
│    │    └─BatchNorm2d: 3-5             128
│    └─BasicBlock: 2-2                   --
│    │    └─Conv2d: 3-6                  36,864
│    │    └─BatchNorm2d: 3-7             128
│    │    └─ReLU: 3-8                    --
│    │    └─Conv2d: 3-9                  36,864
│    │    └─BatchNorm2d: 3-10            128
├─Sequential: 1-6                        --
│    └─BasicBlock: 2-3                   --
│    │    └─Conv2d: 3-11                 73,728

## Define training function

In [13]:
def train_from_scratch(model, train_loader, val_loader, epochs, learning_rate, device, model_name):
    criterion = nn.CrossEntropyLoss()
    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.Adam(params, lr=learning_rate)

    loss = []
    train_error=[]
    val_error = []
    valdation_error = []
    train_loss = []
    valdation_loss = []
    train_accuraacy = []
    valdation_accuracy= []

    for epoch in range(epochs):
        train_loss = 0.0
        valid_loss = 0.0
        train_acc = 0.0
        valid_acc = 0.0
        correct = 0.
        total = 0.
        V_correct = 0.
        V_total = 0.

        model.train()
        train_bar = tqdm(train_loader, file=sys.stdout)
        for step, data in enumerate(train_bar):
            images, labels = data
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            logits, hidden = model(images)
            loss = criterion(logits, labels)
            loss.backward()
            optimizer.step()
            train_loss += loss.item() * images.size(0)
            pred = logits.data.max(1, keepdim=True)[1]
            correct += np.sum(np.squeeze(pred.eq(labels.data.view_as(pred))).cpu().numpy())
            total += images.size(0)
            train_acc =  correct/total
            train_bar.desc = "train epoch[{}/{}]".format(epoch + 1, epochs)

        model.eval()
        with torch.no_grad():
            val_bar = tqdm(val_loader, file=sys.stdout)
            for val_data in val_bar:
                val_images, val_labels = val_data
                val_images, val_labels = val_images.to(device), val_labels.to(device)
                outputs, hidden_outputs = model(val_images)
                loss = criterion(outputs, val_labels)
                valid_loss += loss.item() * val_images.size(0)
                pred = outputs.data.max(1, keepdim=True)[1]
                V_correct += np.sum(np.squeeze(pred.eq(val_labels.data.view_as(pred))).cpu().numpy())
                V_total += val_images.size(0)
                val_bar.desc = "valid epoch[{}/{}]".format(epoch + 1, epochs)

        train_loss = train_loss / len(train_loader.dataset)
        train_error.append(train_loss)
        valid_loss = valid_loss / len(val_loader.dataset)
        val_error.append(valid_loss)
        train_accuraacy.append( correct / total)
        valdation_accuracy.append(V_correct / V_total)

        print('\tTraining Loss: {:.6f} \tValidation Loss: {:.6f}'.format(train_loss, valid_loss))
        print('\tTrain Accuracy: %.3fd%% (%2d/%2d)\tValdation Accuracy: %.3fd%% (%2d/%2d) '% (100. * correct / total, correct, total, 100. * V_correct / V_total, V_correct, V_total))

    torch.save(model, f'{model_name}.pt')
    print(f'{model_name}.pt is saved')

    print('Finished Training')

## Define testing function

In [14]:
def test(model, test_loader ,device, type=None):
    criterion = nn.CrossEntropyLoss()
    acc = 0.0
    test_loss = 0.0

    if type == None:
        model.eval()
    elif type == 'distiller':
        model.eval()
        model.teacher.eval()
        model.student.eval()
    else:
       raise ValueError(f'Error: only support response-based and feature-based distillation')

    with torch.no_grad():
        test_bar = tqdm(test_loader, file=sys.stdout)
        for test_data in test_bar:
            test_images, test_labels = test_data
            test_images, test_labels = test_images.to(device), test_labels.to(device)
            if type == None:
                outputs, features = model(test_images)
                loss = criterion(outputs, test_labels)
            elif type == 'distiller':
                outputs, loss = model(test_images, test_labels)
            else:
                raise ValueError(f'Error: only support response-based and feature-based distillation')

            predict_y = torch.max(outputs, dim=1)[1]
            acc += torch.eq(predict_y, test_labels.to(device)).sum().item()
            test_loss += loss.item()
            test_bar.desc = "test"

    test_accurate = acc / test_num
    print('test_loss: %.3f  test_accuracy: %.3f' %(test_loss / test_steps, test_accurate * 100))
    return test_loss / test_steps, test_accurate * 100.

## Train Teacher and Student model from scratch

In [22]:
# Decide the epochs and learning rate
train_from_scratch(Teacher, train_loader, val_loader, epochs=50, learning_rate=0.001, device=device, model_name="Teacher")

train epoch[1/50]: 100%|██████████| 313/313 [02:11<00:00,  2.38it/s]
valid epoch[1/50]: 100%|██████████| 79/79 [00:29<00:00,  2.71it/s]
	Training Loss: 0.744452 	Validation Loss: 0.881200
	Train Accuracy: 74.222d% (29689/40000)	Valdation Accuracy: 70.990d% (7099/10000) 
train epoch[2/50]: 100%|██████████| 313/313 [02:11<00:00,  2.39it/s]
valid epoch[2/50]: 100%|██████████| 79/79 [00:29<00:00,  2.71it/s]
	Training Loss: 0.615066 	Validation Loss: 0.769146
	Train Accuracy: 78.755d% (31502/40000)	Valdation Accuracy: 75.020d% (7502/10000) 
train epoch[3/50]: 100%|██████████| 313/313 [02:11<00:00,  2.39it/s]
valid epoch[3/50]: 100%|██████████| 79/79 [00:28<00:00,  2.73it/s]
	Training Loss: 0.542589 	Validation Loss: 0.582035
	Train Accuracy: 81.448d% (32579/40000)	Valdation Accuracy: 80.380d% (8038/10000) 
train epoch[4/50]: 100%|██████████| 313/313 [02:10<00:00,  2.39it/s]
valid epoch[4/50]: 100%|██████████| 79/79 [00:28<00:00,  2.73it/s]
	Training Loss: 0.487883 	Validation Loss: 0.570289

In [23]:
T_loss, T_accuracy = test(Teacher, test_loader, device=device)

test: 100%|██████████| 79/79 [00:29<00:00,  2.72it/s]
test_loss: 0.433  test_accuracy: 90.900


In [25]:
# Decide the epochs and learning rate
train_from_scratch(Student, train_loader, val_loader, epochs= 50, learning_rate= 0.001, device=device, model_name="Student")

train epoch[1/50]: 100%|██████████| 313/313 [00:54<00:00,  5.72it/s]
valid epoch[1/50]: 100%|██████████| 79/79 [00:12<00:00,  6.52it/s]
	Training Loss: 1.294300 	Validation Loss: 1.203104
	Train Accuracy: 53.032d% (21213/40000)	Valdation Accuracy: 56.790d% (5679/10000) 
train epoch[2/50]: 100%|██████████| 313/313 [00:54<00:00,  5.75it/s]
valid epoch[2/50]: 100%|██████████| 79/79 [00:12<00:00,  6.50it/s]
	Training Loss: 0.935933 	Validation Loss: 0.889185
	Train Accuracy: 66.418d% (26567/40000)	Valdation Accuracy: 68.800d% (6880/10000) 
train epoch[3/50]: 100%|██████████| 313/313 [00:54<00:00,  5.74it/s]
valid epoch[3/50]: 100%|██████████| 79/79 [00:12<00:00,  6.50it/s]
	Training Loss: 0.748574 	Validation Loss: 0.727477
	Train Accuracy: 73.770d% (29508/40000)	Valdation Accuracy: 74.540d% (7454/10000) 
train epoch[4/50]: 100%|██████████| 313/313 [00:54<00:00,  5.78it/s]
valid epoch[4/50]: 100%|██████████| 79/79 [00:12<00:00,  6.49it/s]
	Training Loss: 0.630760 	Validation Loss: 0.694219

In [26]:
S_loss, S_accuracy = test(Student, test_loader, device=device)

test: 100%|██████████| 79/79 [00:10<00:00,  7.24it/s]
test_loss: 0.489  test_accuracy: 90.390


## Define distillation

### Define the loss functions

In [27]:
# Finish the loss function for response-based distillation.
def loss_re(student_logits, teacher_logits, target):
    """
    Response-Based Knowledge Distillation 損失函數
    
    Args:
        student_logits: 學生模型的輸出 logits
        teacher_logits: 教師模型的輸出 logits  
        target: 真實標籤（硬標籤）
    
    Returns:
        loss: 總損失（蒸餾損失 + 硬標籤損失）
    """
    T = 4.0  # Set temperature parameter (溫度參數，用於軟化 softmax 分佈)
    alpha = 0.7  # Set weighting parameter (權重參數，平衡蒸餾損失和硬標籤損失)

    # Implement loss calculation
    
    # 1. 蒸餾損失（Distillation Loss）- 使用 KL 散度衡量學生和教師的軟標籤差異
    # 教師的軟標籤：使用溫度 T 軟化的 softmax
    soft_teacher = F.softmax(teacher_logits / T, dim=1)
    # 學生的軟標籤：使用溫度 T 軟化的 log_softmax（KL散度要求）
    soft_student = F.log_softmax(student_logits / T, dim=1)
    # KL 散度損失，乘以 T^2 來補償溫度縮放對梯度的影響
    distillation_loss = F.kl_div(soft_student, soft_teacher, reduction='batchmean') * (T * T)
    
    # 2. 硬標籤損失（Hard Label Loss）- 使用交叉熵確保學生能正確分類
    student_loss = F.cross_entropy(student_logits, target)
    
    # 3. 總損失：加權組合蒸餾損失和硬標籤損失
    loss = alpha * distillation_loss + (1 - alpha) * student_loss

    return loss

In [28]:
# Finish the loss function for feature-based distillation.
def loss_fe(student_features, teacher_features, student_logits, target):
    """
    Feature-Based Knowledge Distillation 損失函數
    使用中間層特徵進行知識蒸餾
    
    Args:
        student_features: 學生模型的中間層特徵列表 [feature1, feature2, feature3, feature4]
        teacher_features: 教師模型的中間層特徵列表 [feature1, feature2, feature3, feature4]
        student_logits: 學生模型的輸出 logits
        target: 真實標籤
    
    Returns:
        loss: 總損失（特徵蒸餾損失 + 分類損失）
    """
    # Implement loss calculation whatever you prefer
    
    # 1. 分類損失：確保學生模型能正確分類
    ce_loss = F.cross_entropy(student_logits, target)
    
    # 2. 特徵蒸餾損失：使用 MSE 衡量學生和教師特徵的差異
    feature_loss = 0.0
    
    # 對每一層的特徵計算損失
    for student_feat, teacher_feat in zip(student_features, teacher_features):
        # 因為 ResNet50 和 ResNet18 的特徵通道數不同，我們使用 Spatial Attention
        # 計算每個位置的特徵重要性（沿通道維度求和）
        
        # Student spatial attention: [B, C, H, W] -> [B, 1, H, W]
        student_attention = torch.mean(student_feat, dim=1, keepdim=True)
        student_attention = F.normalize(student_attention.view(student_attention.size(0), -1))
        
        # Teacher spatial attention: [B, C, H, W] -> [B, 1, H, W]  
        teacher_attention = torch.mean(teacher_feat, dim=1, keepdim=True)
        teacher_attention = F.normalize(teacher_attention.view(teacher_attention.size(0), -1))
        
        # 計算注意力圖的 MSE 損失
        feature_loss += F.mse_loss(student_attention, teacher_attention)
    
    # 3. 總損失：分類損失 + 特徵損失
    # 使用權重平衡兩種損失
    beta = 1000.0  # 特徵損失的權重（因為特徵損失通常比較小）
    loss = ce_loss + beta * feature_loss

    return loss

### Define Distillation Framework

In [29]:
class Distiller(nn.Module):
    def __init__(self, teacher, student, type):
        super(Distiller, self).__init__()

        # 1. Finish the __init__ method.
        self.teacher = teacher
        self.student = student
        self.type = type

    def forward(self, x, target):
        # 2. Finish the forward pass.
        
        # 教師模型前向傳播（不計算梯度，只獲取輸出）
        with torch.no_grad():
            teacher_logits, teacher_features = self.teacher(x)
        
        # 學生模型前向傳播（需要計算梯度）
        student_logits, student_features = self.student(x)
        
        # 根據蒸餾類型計算損失
        if self.type == 'response':
            # Response-based: 使用輸出 logits 進行蒸餾
            loss_distill = loss_re(student_logits, teacher_logits, target)
        elif self.type == 'feature':
            # Feature-based: 使用中間層特徵進行蒸餾
            loss_distill = loss_fe(student_features, teacher_features, student_logits, target)
        else:
            raise ValueError(f'Error: only support response-based and feature-based distillation')

        return student_logits, loss_distill

### Training function

In [30]:
def train_distillation(distiller, student, train_loader, val_loader, epochs, learning_rate, device):
    ce_loss = nn.CrossEntropyLoss()
    # define the parameter the optimizer used
    optimizer = torch.optim.Adam(student.parameters(), lr=learning_rate)

    loss = []
    train_error=[]
    val_error = []
    valdation_error = []
    train_loss = []
    valdation_loss = []
    train_accuraacy = []
    valdation_accuracy= []

    for epoch in range(epochs):
        distiller.train()
        distiller.teacher.train()
        distiller.student.train()

        train_loss = 0.0
        valid_loss = 0.0
        train_acc = 0.0
        valid_acc  = 0.0
        correct = 0.
        total = 0.
        V_correct = 0.
        V_total = 0.
        train_bar = tqdm(train_loader, file=sys.stdout)
        for step, data in enumerate(train_bar):
            images, labels = data
            images, labels = images.to(device), labels.to(device)

            outputs, loss = distiller(images, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_loss += loss.item() * images.size(0)
            pred = outputs.data.max(1, keepdim=True)[1]
            result = pred.eq(labels.data.view_as(pred))
            result = np.squeeze(result.cpu().numpy())
            correct += np.sum(result)
            total += images.size(0)
            train_bar.desc = "train epoch[{}/{}]".format(epoch + 1, epochs)

        distiller.eval()
        distiller.teacher.eval()
        distiller.student.eval()

        with torch.no_grad():
            val_bar = tqdm(val_loader, file=sys.stdout)
            for val_data in val_bar:

                val_images, val_labels = val_data
                val_images, val_labels = val_images.to(device), val_labels.to(device)

                outputs, loss = distiller(val_images, val_labels)

                valid_loss += loss.item() * val_images.size(0)
                pred = outputs.max(1, keepdim=True)[1]
                V_correct += np.sum(np.squeeze(pred.eq(val_labels.data.view_as(pred))).cpu().numpy())
                V_total += val_images.size(0)
                val_bar.desc = "valid epoch[{}/{}]".format(epoch + 1, epochs)

        train_loss = train_loss / len(train_loader.dataset)
        train_error.append(train_loss)
        valid_loss = valid_loss / len(val_loader.dataset)
        val_error.append(valid_loss)
        train_accuraacy.append( correct / total)
        valdation_accuracy.append(V_correct / V_total)

        print('\tTraining Loss: {:.6f} \tValidation Loss: {:.6f}'.format(train_loss, valid_loss))
        print('\tTrain Accuracy: %.3fd%% (%2d/%2d)\tValdation Accuracy: %.3fd%% (%2d/%2d) '% (100. * correct / total, correct, total, 100. * V_correct / V_total, V_correct, V_total))

    print('Finished Distilling')

## Response-based distillation

In [31]:
# Decide the epochs and learning rate
Student_re = resnet18(num_classes=10)
Student_re = Student_re.to(device)

# 如果有多個 GPU，使用 DataParallel
if use_multi_gpu:
    Student_re = nn.DataParallel(Student_re)
    print(f'Response-based Student 模型使用 {torch.cuda.device_count()} 個 GPU 進行訓練')

distiller_re = Distiller(Teacher, Student_re, type='response')
train_distillation(distiller_re, Student_re, train_loader, val_loader, epochs=50, learning_rate=0.001, device=device)

Response-based Student 模型使用 2 個 GPU 進行訓練
train epoch[1/50]: 100%|██████████| 313/313 [02:29<00:00,  2.09it/s]
valid epoch[1/50]: 100%|██████████| 79/79 [00:37<00:00,  2.13it/s]
	Training Loss: 11.569585 	Validation Loss: 8.798748
	Train Accuracy: 46.445d% (18578/40000)	Valdation Accuracy: 56.750d% (5675/10000) 
train epoch[2/50]: 100%|██████████| 313/313 [02:29<00:00,  2.09it/s]
valid epoch[2/50]: 100%|██████████| 79/79 [00:37<00:00,  2.12it/s]
	Training Loss: 7.018527 	Validation Loss: 6.187397
	Train Accuracy: 65.438d% (26175/40000)	Valdation Accuracy: 67.800d% (6780/10000) 
train epoch[3/50]: 100%|██████████| 313/313 [02:29<00:00,  2.09it/s]
valid epoch[3/50]: 100%|██████████| 79/79 [00:37<00:00,  2.12it/s]
	Training Loss: 5.113157 	Validation Loss: 4.591733
	Train Accuracy: 73.780d% (29512/40000)	Valdation Accuracy: 74.360d% (7436/10000) 
train epoch[4/50]: 100%|██████████| 313/313 [02:29<00:00,  2.09it/s]
valid epoch[4/50]: 100%|██████████| 79/79 [00:37<00:00,  2.13it/s]
	Training

In [32]:
reS_loss, reS_accuracy = test(distiller_re, test_loader, type='distiller', device=device)

test: 100%|██████████| 79/79 [00:37<00:00,  2.13it/s]
test_loss: 0.841  test_accuracy: 91.350


## Feature-based distillation

In [33]:
# Decide the epochs and learning rate
Student_fe = resnet18(num_classes=10)
Student_fe = Student_fe.to(device)

# 如果有多個 GPU，使用 DataParallel
if use_multi_gpu:
    Student_fe = nn.DataParallel(Student_fe)
    print(f'Feature-based Student 模型使用 {torch.cuda.device_count()} 個 GPU 進行訓練')

distiller_fe = Distiller(Teacher, Student_fe, type='feature')
train_distillation(distiller_fe, Student_fe, train_loader, val_loader, epochs=50, learning_rate=0.001, device=device)

Feature-based Student 模型使用 2 個 GPU 進行訓練
train epoch[1/50]: 100%|██████████| 313/313 [02:31<00:00,  2.07it/s]
valid epoch[1/50]: 100%|██████████| 79/79 [00:36<00:00,  2.14it/s]
	Training Loss: 10.391021 	Validation Loss: 7.842277
	Train Accuracy: 41.355d% (16542/40000)	Valdation Accuracy: 52.250d% (5225/10000) 
train epoch[2/50]: 100%|██████████| 313/313 [02:31<00:00,  2.07it/s]
valid epoch[2/50]: 100%|██████████| 79/79 [00:36<00:00,  2.16it/s]
	Training Loss: 6.756028 	Validation Loss: 5.852428
	Train Accuracy: 62.983d% (25193/40000)	Valdation Accuracy: 68.350d% (6835/10000) 
train epoch[3/50]: 100%|██████████| 313/313 [02:31<00:00,  2.07it/s]
valid epoch[3/50]: 100%|██████████| 79/79 [00:36<00:00,  2.14it/s]
	Training Loss: 5.414731 	Validation Loss: 5.143411
	Train Accuracy: 72.513d% (29005/40000)	Valdation Accuracy: 71.160d% (7116/10000) 
train epoch[4/50]: 100%|██████████| 313/313 [02:31<00:00,  2.07it/s]
valid epoch[4/50]: 100%|██████████| 79/79 [00:36<00:00,  2.15it/s]
	Training 

In [34]:
ftS_loss, ftS_accuracy = test(distiller_fe, test_loader, type='distiller', device=device)

test: 100%|██████████| 79/79 [00:36<00:00,  2.15it/s]
test_loss: 1.820  test_accuracy: 91.510


## Result and Comparison

In [35]:
print(f'Teacher from scratch: loss = {T_loss:.2f}, accuracy = {T_accuracy:.2f}')
print(f'Student from scratch: loss = {S_loss:.2f}, accuracy = {S_accuracy:.2f}')
print(f'Response-based student: loss = {reS_loss:.2f}, accuracy = {reS_accuracy:.2f}')
print(f'Featured-based student: loss = {ftS_loss:.2f}, accuracy = {ftS_accuracy:.2f}')

Teacher from scratch: loss = 0.43, accuracy = 90.90
Student from scratch: loss = 0.49, accuracy = 90.39
Response-based student: loss = 0.84, accuracy = 91.35
Featured-based student: loss = 1.82, accuracy = 91.51
