# Network Compression (Knowledge Distillation)

In [None]:
# Download dataset
!gdown --id '1wCdNcClcd2p5UeDi6XxNdiBedNaVn3fP' --output food-11.zip
# Unzip the files
!unzip food-11.zip

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
  inflating: food-11/training/4_165.jpg  
  inflating: food-11/training/5_376.jpg  
  inflating: food-11/training/2_691.jpg  
  inflating: food-11/training/0_541.jpg  
  inflating: food-11/training/3_482.jpg  
  inflating: food-11/training/0_227.jpg  
  inflating: food-11/training/5_410.jpg  
  inflating: food-11/training/4_603.jpg  
  inflating: food-11/training/8_341.jpg  
  inflating: food-11/training/5_1154.jpg  
  inflating: food-11/training/9_37.jpg  
  inflating: food-11/training/9_152.jpg  
  inflating: food-11/training/5_438.jpg  
  inflating: food-11/training/9_1287.jpg  
  inflating: food-11/training/8_369.jpg  
  inflating: food-11/training/2_1455.jpg  
  inflating: food-11/training/10_247.jpg  
  inflating: food-11/training/7_32.jpg  
  inflating: food-11/training/10_521.jpg  
  inflating: food-11/training/2_1333.jpg  
  inflating: food-11/training/2_861.jpg  
  inflating: food-11/training/0_569.jpg  
  infla

# Readme


任務是模型壓縮 - Neural Network Compression。

Compression有很多種門派，在這裡我們會介紹上課出現過的其中四種，分別是:

* 知識蒸餾 Knowledge Distillation
* 網路剪枝 Network Pruning
* 用少量參數來做CNN Architecture Design
* 參數量化 Weight Quantization


In [None]:
import torch
import os
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.models as models
# Load our model architecture (defined in hw7_Architecture_Design.ipynb).
# Presumably this is where StudentNet, used by later cells, comes from — it is not defined in this notebook.
# NOTE(review): %run executes the downloaded notebook inside this kernel, so whatever code
# the Drive file currently contains will run here.

!gdown --id '1T5c89HvLflLCeGz3UWrvjGm0ss9TC160' --output "hw7_Architecture_Design.ipynb" # modified architecture
%run "hw7_Architecture_Design.ipynb"

Downloading...
From: https://drive.google.com/uc?id=1T5c89HvLflLCeGz3UWrvjGm0ss9TC160
To: /content/hw7_Architecture_Design.ipynb
  0% 0.00/9.52k [00:00<?, ?B/s]100% 9.52k/9.52k [00:00<00:00, 17.5MB/s]


Knowledge Distillation
===

<img src="https://i.imgur.com/H2aF7Rv.png" width="500px">

簡單上來說就是讓已經做得很好的大model們去告訴小model"如何"學習。
而我們如何做到這件事情呢? 就是利用大model預測的logits給小model當作標準就可以了。

## 為甚麼這會work?
* 例如當data不是很乾淨的時候，對一般的model來說他是個noise，只會干擾學習。透過去學習其他大model預測的logits會比較好。
* label和label之間可能有關連，這可以引導小model去學習。例如數字8可能就和6,9,0有關係。
* 弱化已經學習不錯的target(?)，避免讓其gradient干擾其他還沒學好的task。


## 要怎麼實作?
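* $Loss = \alpha T^2 \times KL\left(\text{softmax}\left(\frac{\text{Teacher's Logits}}{T}\right) \middle\| \text{softmax}\left(\frac{\text{Student's Logits}}{T}\right)\right) + (1-\alpha)(\text{原本的Loss})$
  （KL 是在兩個經溫度 $T$ 軟化後的機率分佈之間計算；「原本的Loss」指 student 對真實 label 的 Cross Entropy。）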


* 以下code為甚麼要對student使用log_softmax: https://github.com/peterliht/knowledge-distillation-pytorch/issues/2
  => 因為 `nn.KLDivLoss` 的 input 需求：input 必須是 log-probability distribution（target 則是一般的機率分佈）。
* reference: [Distilling the Knowledge in a Neural Network](https://arxiv.org/abs/1503.02531)

In [None]:
def loss_fn_kd(outputs, labels, teacher_outputs, T=20, alpha=0.5):  # alpha is supplied by the caller, not tuned here
    """Knowledge-distillation loss (Hinton et al., 2015).

    Blends ordinary hard-label cross entropy with a temperature-softened KL
    divergence towards the teacher's predictions:

        loss = (1 - alpha) * CE(outputs, labels)
             + alpha * T^2 * KL(softmax(teacher_outputs / T) || softmax(outputs / T))

    Args:
        outputs: student logits (pre-softmax scores); classes along dim 1.
        labels: integer class targets for ``F.cross_entropy``.
        teacher_outputs: teacher logits, same layout as ``outputs``.
        T: softmax temperature; larger values give softer distributions.
        alpha: weight of the soft (teacher) term; ``1 - alpha`` weights the
            hard-label term. The author notes larger alpha usually worked better.

    Returns:
        Scalar loss tensor.
    """
    # KL divergence wants log-probabilities for the student and plain
    # probabilities for the teacher. 'batchmean' sums the per-sample KL and
    # divides by the batch size.
    student_log_probs = F.log_softmax(outputs / T, dim=1)
    teacher_probs = F.softmax(teacher_outputs / T, dim=1)
    # Dividing logits by T shrinks their gradients by roughly 1/T^2, so the soft
    # term is multiplied back by T^2 to keep both terms on a comparable scale.
    soft_loss = F.kl_div(student_log_probs, teacher_probs, reduction='batchmean') * (alpha * T * T)
    hard_loss = F.cross_entropy(outputs, labels) * (1. - alpha)
    # Soft-target cross entropy equals this KL plus the teacher's entropy, a
    # constant w.r.t. the student, so either yields the same gradients.
    return hard_loss + soft_loss

# Data Processing

我們的Dataset使用的是跟Hw3 - CNN同樣的Dataset，因此這個區塊的Augmentation / Read Image大家參考或直接抄就好。

如果有不會的話可以回去看Hw3的colab。

需要注意的是如果要自己寫的話，Augment的方法最好使用我們的方法，避免輸入有差異導致Teacher Net預測不好。

In [None]:
import re # 正則表達式
import torch
from glob import glob
from PIL import Image
import torchvision.transforms as transforms

class MyDataset(torch.utils.data.Dataset):
    """In-memory image dataset for the food-11 folders.

    Every ``*.jpg`` directly under ``folderName`` is opened and fully loaded at
    construction time (in sorted path order). Labels come from the
    ``<class>_<id>.jpg`` file-name convention (e.g. ``6_163.jpg`` -> class 6).
    Files that do not follow it, such as unlabeled test images, get label 0.
    """

    # Class prefix of a labeled file name such as "6_163.jpg".
    _LABEL_PATTERN = re.compile(r'^(\d+)_')

    def __init__(self, folderName, transform=None):
        self.transform = transform
        self.data = []
        self.label = []

        for img_path in sorted(glob(folderName + '/*.jpg')):
            class_idx = self._parse_label(img_path)

            image = Image.open(img_path)
            # Keep the file object so it can be closed explicitly after decoding;
            # otherwise loading thousands of images can hit the OS open-file limit.
            image_fp = image.fp
            image.load()
            image_fp.close()

            self.data.append(image)
            self.label.append(class_idx)

    @staticmethod
    def _parse_label(img_path):
        """Return the class index encoded in ``img_path``'s file name, or 0 if none.

        Only the base name is inspected. Digits in parent directories (such as
        the "11" in "food-11") can therefore never be mistaken for the label.
        An unlabeled name like "0005.jpg" yields 0 rather than its file index.
        """
        match = MyDataset._LABEL_PATTERN.match(os.path.basename(img_path))
        return int(match.group(1)) if match else 0

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        image = self.data[idx]
        if self.transform:
            image = self.transform(image)
        return image, self.label[idx]


# Training-time augmentation: a random 256x256 crop (symmetric-padding images that
# are smaller than the crop), a random horizontal flip, and a random rotation of
# up to +/-15 degrees, then conversion to a tensor.
trainTransform = transforms.Compose([
    transforms.RandomCrop(size=256, pad_if_needed=True, padding_mode='symmetric'),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(degrees=15),
    transforms.ToTensor(),
])
# Evaluation-time preprocessing is deterministic: a center crop, then a tensor.
testTransform = transforms.Compose([
    transforms.CenterCrop(size=256),
    transforms.ToTensor(),
])

def get_dataloader(mode='training', batch_size=32):
    """Build a DataLoader over the images in ``./food-11/<mode>``.

    The 'training' split uses ``trainTransform`` and is shuffled every epoch.
    The 'validation' and 'testing' splits use ``testTransform`` and keep the
    dataset's order.
    """
    assert mode in ['training', 'testing', 'validation']

    is_training = mode == 'training'
    transform = trainTransform if is_training else testTransform
    dataset = MyDataset(f'./food-11/{mode}', transform=transform)
    return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=is_training)

# Pre-processing


In [None]:
# Build the loaders used by the training cells: training batches are augmented
# and shuffled, validation batches are not.
# (Author's note: larger batch sizes seemed to overfit more easily.)
train_dataloader = get_dataloader(mode='training', batch_size=16)
valid_dataloader = get_dataloader(mode='validation', batch_size=16)

In [None]:
!gdown --id '1B8ljdrxYXJsZv2vmTequdPOofp3VF3NN' --output teacher_resnet18.bin
# Teacher: ResNet-18 with an 11-class head; the downloaded weights reportedly reach 88.41% accuracy.
teacher_net = models.resnet18(pretrained=False, num_classes=11).cuda()
# Student: the compact StudentNet architecture (base=16; see its definition in the %run notebook).
student_net = StudentNet(base=16).cuda()

# NOTE(review): torch.load unpickles the file; only load checkpoints from a trusted source.
teacher_net.load_state_dict(torch.load(f'./teacher_resnet18.bin'))
# NOTE(review): the training cell below rebuilds the optimizer at epoch 0, so this one is replaced before any step.
optimizer = optim.AdamW(student_net.parameters(), lr=1e-3)

Downloading...
From: https://drive.google.com/uc?id=1B8ljdrxYXJsZv2vmTequdPOofp3VF3NN
To: /content/teacher_resnet18.bin
44.8MB [00:00, 206MB/s]


# Start Training


## 小提醒

* `torch.no_grad()` 表示在其範圍內的運算不需要計算（也不會記錄）gradient，可節省記憶體。
* `model.eval()` 與 `model.train()` 的差別在於 BatchNorm 使用哪種統計量（train 時用當前 batch 的統計量並更新 running mean/var，eval 時改用已記錄的值），以及是否啟用 Dropout。



In [None]:
def run_epoch(dataloader, update=True, alpha=0.5, T=20):
    """Run one pass over ``dataloader`` with the KD loss, optionally training the student.

    Uses the notebook-level ``student_net``, ``teacher_net`` and ``optimizer``.

    Args:
        dataloader: yields ``(inputs, hard_labels)`` batches.
        update: if True, back-propagate and step ``optimizer`` (training pass);
            otherwise evaluate only, with gradients disabled.
        alpha: weight of the soft (teacher) term passed to ``loss_fn_kd``.
        T: distillation temperature; 20 follows the original paper's setting.

    Returns:
        ``(mean_loss, accuracy)`` averaged over all samples in the pass.
    """
    # Follow the student's device instead of hard-coding .cuda(), so the loop
    # also runs on CPU-only machines (unchanged behavior when on GPU).
    device = next(student_net.parameters()).device
    total_num, total_hit, total_loss = 0, 0, 0
    for inputs, hard_labels in dataloader:
        inputs = inputs.to(device)
        # Replaces the legacy torch.LongTensor(...) constructor; no copy when the
        # DataLoader already produced an int64 tensor.
        hard_labels = torch.as_tensor(hard_labels, dtype=torch.long).to(device)
        # The teacher never back-propagates, so do not build its autograd graph.
        with torch.no_grad():
            soft_labels = teacher_net(inputs)

        if update:
            optimizer.zero_grad()
            logits = student_net(inputs)
            # Blend the teacher's soft targets with the hard labels.
            loss = loss_fn_kd(logits, hard_labels, soft_labels, T, alpha)
            loss.backward()
            optimizer.step()
        else:
            # Validation only needs accuracy/loss, so skip gradients entirely.
            with torch.no_grad():
                logits = student_net(inputs)
                loss = loss_fn_kd(logits, hard_labels, soft_labels, T, alpha)

        total_hit += torch.sum(torch.argmax(logits, dim=1) == hard_labels).item()
        total_num += len(inputs)
        total_loss += loss.item() * len(inputs)
    return total_loss / total_num, total_hit / total_num

from google.colab import drive
drive.mount('/content/gdrive')
model_save_name = 'student_model.bin'
path = f"/content/gdrive/My Drive/{model_save_name}"

# The teacher only supplies soft targets, so it always stays in eval mode.
teacher_net.eval()
now_best_acc = 0
learningrate = 1e-3

for epoch in range(200):
    # Step schedule: 1e-3 for the first 50 epochs, then divide by 10 every 50 epochs.
    # Decay only after the first stage; decaying at epoch 0 too would start training
    # at 1e-4 and end at 1e-7. A fresh optimizer is built at every stage (which also
    # resets AdamW's moment estimates).
    if epoch % 50 == 0:
        if epoch > 0:
            learningrate /= 10
        optimizer = optim.AdamW(student_net.parameters(), lr=learningrate)
    student_net.train()
    train_loss, train_acc = run_epoch(train_dataloader, update=True)
    student_net.eval()
    valid_loss, valid_acc = run_epoch(valid_dataloader, update=False)

    # Keep the checkpoint with the best validation accuracy seen so far.
    if valid_acc > now_best_acc:
        now_best_acc = valid_acc
        torch.save(student_net.state_dict(), path)
    print('epoch {:>3d}: train loss: {:6.4f}, acc {:6.4f} valid loss: {:6.4f}, acc {:6.4f}'.format(
        epoch, train_loss, train_acc, valid_loss, valid_acc))

Go to this URL in a browser: https://accounts.google.com/o/oauth2/auth?client_id=947318989803-6bn6qk8qdgf4n4g3pfee6491hc0brc4i.apps.googleusercontent.com&redirect_uri=urn%3aietf%3awg%3aoauth%3a2.0%3aoob&response_type=code&scope=email%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdocs.test%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive.photos.readonly%20https%3a%2f%2fwww.googleapis.com%2fauth%2fpeopleapi.readonly

Enter your authorization code:
··········
Mounted at /content/gdrive
epoch   0: train loss: 17.3037, acc 0.2028 valid loss: 18.7363, acc 0.2329
epoch   1: train loss: 16.3081, acc 0.2512 valid loss: 17.7105, acc 0.2671
epoch   2: train loss: 15.6450, acc 0.2936 valid loss: 16.9654, acc 0.3149
epoch   3: train loss: 15.0290, acc 0.3342 valid loss: 16.1987, acc 0.3397
epoch   4: train loss: 14.5445, acc 0.3511 valid loss: 15.6840, acc 0.3665
epoch   5: train loss: 14.1608, acc 0.3717 valid loss: 15.4330, acc 0.3889
epoch   6: tr

KeyboardInterrupt: ignored

# Saving Model

In [None]:
# Mount Google Drive at /content/gdrive so the checkpoint can be written there.
from google.colab import drive
drive.mount('/content/gdrive')

Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount("/content/gdrive", force_remount=True).


In [None]:
# Save the student's *current* weights to Drive.
# NOTE(review): this is the same file the training loops use for their
# best-validation checkpoint. Running this after an interrupted run therefore
# replaces the best weights with the latest ones.
model_save_name = 'student_model.bin'
path = f"/content/gdrive/My Drive/{model_save_name}"
torch.save(obj=student_net.state_dict(), f=path)

# Loading Model

In [None]:
# Mount Google Drive at /content/gdrive so the saved checkpoint can be read back.
from google.colab import drive
drive.mount('/content/gdrive')

Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount("/content/gdrive", force_remount=True).


Trained Model

In [None]:
# Restore previously saved student weights (from an earlier run) into student_net.
model_save_name = 'student_model.bin'
path = f"/content/gdrive/My Drive/{model_save_name}"
student_net.load_state_dict(torch.load(f=path))

<All keys matched successfully>

In [None]:
!gdown --id '12wtIa0WVRcpboQzhgRUJOpcXe23tgWUL' --output student_custom_small.bin
# Replace the student with a downloaded pre-trained checkpoint, to train it further.
# NOTE(review): StudentNet() is built with its default arguments here, unlike the
# StudentNet(base=16) used elsewhere. Confirm the default matches this checkpoint
# (the "All keys matched" output suggests it does).
# NOTE(review): torch.load unpickles the downloaded file; only use trusted checkpoints.
student_net = StudentNet().cuda() 
student_net.load_state_dict(torch.load('student_custom_small.bin'))

Downloading...
From: https://drive.google.com/uc?id=12wtIa0WVRcpboQzhgRUJOpcXe23tgWUL
To: /content/student_custom_small.bin
  0% 0.00/1.05M [00:00<?, ?B/s]100% 1.05M/1.05M [00:00<00:00, 69.4MB/s]


<All keys matched successfully>

# Keep Training

In [None]:
def run_epoch(dataloader, update=True, alpha=0.5, T=20):  # tune the alpha default here
    """Run one pass over ``dataloader`` with the KD loss, optionally training the student.

    Uses the notebook-level ``student_net``, ``teacher_net`` and ``optimizer``.
    NOTE: this cell redefines run_epoch; the most recently executed definition
    is the one the training loops call.

    Args:
        dataloader: yields ``(inputs, hard_labels)`` batches.
        update: if True, back-propagate and step ``optimizer`` (training pass);
            otherwise evaluate only, with gradients disabled.
        alpha: weight of the soft (teacher) term passed to ``loss_fn_kd``.
        T: distillation temperature; 20 follows the original paper's setting.

    Returns:
        ``(mean_loss, accuracy)`` averaged over all samples in the pass.
    """
    # Follow the student's device instead of hard-coding .cuda(), so the loop
    # also runs on CPU-only machines (unchanged behavior when on GPU).
    device = next(student_net.parameters()).device
    total_num, total_hit, total_loss = 0, 0, 0
    for inputs, hard_labels in dataloader:
        inputs = inputs.to(device)
        # Replaces the legacy torch.LongTensor(...) constructor; no copy when the
        # DataLoader already produced an int64 tensor.
        hard_labels = torch.as_tensor(hard_labels, dtype=torch.long).to(device)
        # The teacher never back-propagates, so do not build its autograd graph.
        with torch.no_grad():
            soft_labels = teacher_net(inputs)

        if update:
            optimizer.zero_grad()
            # "logits" here are the raw, unbounded pre-softmax scores (softmax maps
            # them into [0, 1]); not exactly the statistical logit (log-odds) function.
            logits = student_net(inputs)
            # Blend the teacher's soft targets with the hard labels.
            loss = loss_fn_kd(logits, hard_labels, soft_labels, T, alpha)
            loss.backward()
            optimizer.step()
        else:
            # Validation only needs accuracy/loss, so skip gradients entirely.
            with torch.no_grad():
                logits = student_net(inputs)
                loss = loss_fn_kd(logits, hard_labels, soft_labels, T, alpha)

        total_hit += torch.sum(torch.argmax(logits, dim=1) == hard_labels).item()
        total_num += len(inputs)
        total_loss += loss.item() * len(inputs)
    return total_loss / total_num, total_hit / total_num

model_save_name = 'student_model.bin'  # checkpoint file name
path = f"/content/gdrive/My Drive/{model_save_name}"  # checkpoint location on Drive
teacher_net.eval()  # the teacher only provides soft targets
optimizer = optim.AdamW(params=student_net.parameters(), lr=1e-3)


# Continue training the loaded student for another 50 epochs.
# NOTE(review): now_best_acc is inherited from the earlier training cell, so a
# checkpoint is written only once this run beats that earlier best.
for epoch in range(50):
    # One training pass (augmented, shuffled) followed by one evaluation pass.
    student_net.train()
    train_loss, train_acc = run_epoch(dataloader=train_dataloader, update=True)
    student_net.eval()
    valid_loss, valid_acc = run_epoch(dataloader=valid_dataloader, update=False)

    # Overwrite the checkpoint only when validation accuracy improves.
    if valid_acc > now_best_acc:
        now_best_acc = valid_acc
        torch.save(obj=student_net.state_dict(), f=path)
    print('epoch {:>3d}: train loss: {:6.4f}, acc {:6.4f} valid loss: {:6.4f}, acc {:6.4f}'.format(
        epoch, train_loss, train_acc, valid_loss, valid_acc))

epoch   0: train loss: 4.0862, acc 0.8347 valid loss: 5.3863, acc 0.7781
epoch   1: train loss: 4.0514, acc 0.8339 valid loss: 4.5861, acc 0.8015
epoch   2: train loss: 4.0860, acc 0.8333 valid loss: 4.8169, acc 0.7977
epoch   3: train loss: 4.0110, acc 0.8370 valid loss: 5.1090, acc 0.7793
epoch   4: train loss: 3.9681, acc 0.8391 valid loss: 4.8267, acc 0.7802
epoch   5: train loss: 3.9815, acc 0.8357 valid loss: 4.9262, acc 0.7918
epoch   6: train loss: 3.9472, acc 0.8409 valid loss: 4.4520, acc 0.7927
epoch   7: train loss: 3.9380, acc 0.8418 valid loss: 5.2788, acc 0.7784
epoch   8: train loss: 3.8865, acc 0.8443 valid loss: 5.9261, acc 0.7682
epoch   9: train loss: 3.8733, acc 0.8463 valid loss: 5.0433, acc 0.7927
epoch  10: train loss: 3.8771, acc 0.8430 valid loss: 4.8997, acc 0.7959
epoch  11: train loss: 3.8766, acc 0.8510 valid loss: 4.8258, acc 0.7948
epoch  12: train loss: 3.7820, acc 0.8522 valid loss: 4.8018, acc 0.7901
epoch  13: train loss: 3.9058, acc 0.8432 valid los

In [None]:
# Fine-tune with SGD + momentum at a small learning rate. In the author's runs this
# converged to a better validation accuracy than AdamW (see the logged outputs).
optimizer = optim.SGD(params=student_net.parameters(), lr=1e-4, momentum=0.9)
for epoch in range(50):
    # One training pass (augmented, shuffled) followed by one evaluation pass.
    student_net.train()
    train_loss, train_acc = run_epoch(dataloader=train_dataloader, update=True)
    student_net.eval()
    valid_loss, valid_acc = run_epoch(dataloader=valid_dataloader, update=False)

    # Overwrite the checkpoint only when validation accuracy beats the best so far
    # (now_best_acc and path carry over from the previous cells).
    if valid_acc > now_best_acc:
        now_best_acc = valid_acc
        torch.save(obj=student_net.state_dict(), f=path)
    print('epoch {:>3d}: train loss: {:6.4f}, acc {:6.4f} valid loss: {:6.4f}, acc {:6.4f}'.format(
        epoch, train_loss, train_acc, valid_loss, valid_acc))

epoch   0: train loss: 2.8171, acc 0.9070 valid loss: 3.7391, acc 0.8271
epoch   1: train loss: 2.6889, acc 0.9062 valid loss: 3.7458, acc 0.8297
epoch   2: train loss: 2.6315, acc 0.9114 valid loss: 3.6469, acc 0.8324
epoch   3: train loss: 2.6440, acc 0.9119 valid loss: 3.7081, acc 0.8286
epoch   4: train loss: 2.6182, acc 0.9107 valid loss: 3.6513, acc 0.8356
epoch   5: train loss: 2.6172, acc 0.9105 valid loss: 3.6008, acc 0.8353
epoch   6: train loss: 2.6285, acc 0.9101 valid loss: 3.7299, acc 0.8251
epoch   7: train loss: 2.6091, acc 0.9114 valid loss: 3.5906, acc 0.8306
epoch   8: train loss: 2.5528, acc 0.9150 valid loss: 3.5717, acc 0.8312
epoch   9: train loss: 2.6032, acc 0.9137 valid loss: 3.7122, acc 0.8335
epoch  10: train loss: 2.5672, acc 0.9128 valid loss: 3.5096, acc 0.8312
epoch  11: train loss: 2.5699, acc 0.9159 valid loss: 3.6366, acc 0.8341
epoch  12: train loss: 2.5506, acc 0.9134 valid loss: 3.6266, acc 0.8312
epoch  13: train loss: 2.5548, acc 0.9175 valid los

# Try a worse teacher

In [None]:
!gdown --id '1m-phKfGYn3yKK-_p1BOijfSIQz8tcFtc' --output teacher_resnet18.bin
# Weaker teacher: reportedly 80.09% accuracy. Note it overwrites the teacher_resnet18.bin downloaded earlier.
teacher_net = models.resnet18(pretrained=False, num_classes=11).cuda()
# A freshly initialised student, so this experiment starts from scratch.
student_net = StudentNet(base=16).cuda()

# NOTE(review): torch.load unpickles the file; only load checkpoints from a trusted source.
teacher_net.load_state_dict(torch.load(f'./teacher_resnet18.bin'))
optimizer = optim.AdamW(student_net.parameters(), lr=1e-3)

Downloading...
From: https://drive.google.com/uc?id=1m-phKfGYn3yKK-_p1BOijfSIQz8tcFtc
To: /content/teacher_resnet18.bin
44.8MB [00:00, 210MB/s]


In [None]:
def run_epoch(dataloader, update=True, alpha=0.5, T=20):
    """Run one pass over ``dataloader`` with the KD loss, optionally training the student.

    Uses the notebook-level ``student_net``, ``teacher_net`` and ``optimizer``.
    NOTE: this cell redefines run_epoch; the most recently executed definition
    is the one the training loops call.

    Args:
        dataloader: yields ``(inputs, hard_labels)`` batches.
        update: if True, back-propagate and step ``optimizer`` (training pass);
            otherwise evaluate only, with gradients disabled.
        alpha: weight of the soft (teacher) term passed to ``loss_fn_kd``.
        T: distillation temperature; 20 follows the original paper's setting.

    Returns:
        ``(mean_loss, accuracy)`` averaged over all samples in the pass.
    """
    # Follow the student's device instead of hard-coding .cuda(), so the loop
    # also runs on CPU-only machines (unchanged behavior when on GPU).
    device = next(student_net.parameters()).device
    total_num, total_hit, total_loss = 0, 0, 0
    for inputs, hard_labels in dataloader:
        inputs = inputs.to(device)
        # Replaces the legacy torch.LongTensor(...) constructor; no copy when the
        # DataLoader already produced an int64 tensor.
        hard_labels = torch.as_tensor(hard_labels, dtype=torch.long).to(device)
        # The teacher never back-propagates, so do not build its autograd graph.
        with torch.no_grad():
            soft_labels = teacher_net(inputs)

        if update:
            optimizer.zero_grad()
            logits = student_net(inputs)
            # Blend the teacher's soft targets with the hard labels.
            loss = loss_fn_kd(logits, hard_labels, soft_labels, T, alpha)
            loss.backward()
            optimizer.step()
        else:
            # Validation only needs accuracy/loss, so skip gradients entirely.
            with torch.no_grad():
                logits = student_net(inputs)
                loss = loss_fn_kd(logits, hard_labels, soft_labels, T, alpha)

        total_hit += torch.sum(torch.argmax(logits, dim=1) == hard_labels).item()
        total_num += len(inputs)
        total_loss += loss.item() * len(inputs)
    return total_loss / total_num, total_hit / total_num

from google.colab import drive
drive.mount('/content/gdrive')
model_save_name = 'worse_teacher.bin'
path = f"/content/gdrive/My Drive/{model_save_name}"

# The teacher only supplies soft targets, so it always stays in eval mode.
teacher_net.eval()
now_best_acc = 0
learningrate = 1e-3

for epoch in range(200):
    # Step schedule: 1e-3 for the first 50 epochs, then divide by 10 every 50 epochs.
    # Decay only after the first stage; decaying at epoch 0 too would start training
    # at 1e-4 and end at 1e-7. A fresh optimizer is built at every stage (which also
    # resets AdamW's moment estimates).
    if epoch % 50 == 0:
        if epoch > 0:
            learningrate /= 10
        optimizer = optim.AdamW(student_net.parameters(), lr=learningrate)
    student_net.train()
    train_loss, train_acc = run_epoch(train_dataloader, update=True)
    student_net.eval()
    valid_loss, valid_acc = run_epoch(valid_dataloader, update=False)

    # Keep the checkpoint with the best validation accuracy seen so far.
    if valid_acc > now_best_acc:
        now_best_acc = valid_acc
        torch.save(student_net.state_dict(), path)
    print('epoch {:>3d}: train loss: {:6.4f}, acc {:6.4f} valid loss: {:6.4f}, acc {:6.4f}'.format(
        epoch, train_loss, train_acc, valid_loss, valid_acc))

Mounted at /content/gdrive
epoch   0: train loss: 9.5415, acc 0.2354 valid loss: 9.0052, acc 0.3397
epoch   1: train loss: 8.5107, acc 0.3297 valid loss: 8.3105, acc 0.3627
epoch   2: train loss: 8.0756, acc 0.3640 valid loss: 7.8423, acc 0.3971
epoch   3: train loss: 7.7952, acc 0.3903 valid loss: 7.3945, acc 0.4481
epoch   4: train loss: 7.4865, acc 0.4168 valid loss: 7.1739, acc 0.4659
epoch   5: train loss: 7.2349, acc 0.4263 valid loss: 6.9024, acc 0.4781
epoch   6: train loss: 7.0978, acc 0.4523 valid loss: 6.6157, acc 0.5044
epoch   7: train loss: 6.9108, acc 0.4604 valid loss: 6.5461, acc 0.5073
epoch   8: train loss: 6.7645, acc 0.4778 valid loss: 6.2696, acc 0.5219
epoch   9: train loss: 6.6018, acc 0.4891 valid loss: 6.2957, acc 0.5429
epoch  10: train loss: 6.4585, acc 0.4867 valid loss: 5.9931, acc 0.5452
epoch  11: train loss: 6.3122, acc 0.5029 valid loss: 5.8897, acc 0.5752
epoch  12: train loss: 6.2657, acc 0.5111 valid loss: 5.7585, acc 0.5706
epoch  13: train loss: 6

KeyboardInterrupt: ignored

# Try no teacher

In [None]:
!gdown --id '1m-phKfGYn3yKK-_p1BOijfSIQz8tcFtc' --output teacher_resnet18.bin
# Same weaker teacher as above (reportedly 80.09% accuracy).
# NOTE(review): with alpha=0 in the run_epoch below, the teacher's outputs are weighted
# by zero in the loss; it is still loaded because run_epoch always runs a teacher forward pass.
teacher_net = models.resnet18(pretrained=False, num_classes=11).cuda()
# A freshly initialised student, so this baseline starts from scratch.
student_net = StudentNet(base=16).cuda()

# NOTE(review): torch.load unpickles the file; only load checkpoints from a trusted source.
teacher_net.load_state_dict(torch.load(f'./teacher_resnet18.bin'))
optimizer = optim.AdamW(student_net.parameters(), lr=1e-3)

Downloading...
From: https://drive.google.com/uc?id=1m-phKfGYn3yKK-_p1BOijfSIQz8tcFtc
To: /content/teacher_resnet18.bin
44.8MB [00:00, 108MB/s] 


In [None]:
def run_epoch(dataloader, update=True, alpha=0, T=20):
    """Run one pass over ``dataloader``, optionally training the student ("no teacher" baseline).

    With the default ``alpha=0`` the soft term of ``loss_fn_kd`` is weighted by
    zero, so the loss reduces to plain cross entropy on the hard labels.
    Uses the notebook-level ``student_net``, ``teacher_net`` and ``optimizer``.
    NOTE: this cell redefines run_epoch; the most recently executed definition
    is the one the training loops call.

    Args:
        dataloader: yields ``(inputs, hard_labels)`` batches.
        update: if True, back-propagate and step ``optimizer`` (training pass);
            otherwise evaluate only, with gradients disabled.
        alpha: weight of the soft (teacher) term passed to ``loss_fn_kd``.
        T: distillation temperature (only matters when alpha > 0).

    Returns:
        ``(mean_loss, accuracy)`` averaged over all samples in the pass.
    """
    # Follow the student's device instead of hard-coding .cuda(), so the loop
    # also runs on CPU-only machines (unchanged behavior when on GPU).
    device = next(student_net.parameters()).device
    total_num, total_hit, total_loss = 0, 0, 0
    for inputs, hard_labels in dataloader:
        inputs = inputs.to(device)
        # Replaces the legacy torch.LongTensor(...) constructor; no copy when the
        # DataLoader already produced an int64 tensor.
        hard_labels = torch.as_tensor(hard_labels, dtype=torch.long).to(device)
        # The teacher never back-propagates, so do not build its autograd graph.
        with torch.no_grad():
            soft_labels = teacher_net(inputs)

        if update:
            optimizer.zero_grad()
            logits = student_net(inputs)
            loss = loss_fn_kd(logits, hard_labels, soft_labels, T, alpha)
            loss.backward()
            optimizer.step()
        else:
            # Validation only needs accuracy/loss, so skip gradients entirely.
            with torch.no_grad():
                logits = student_net(inputs)
                loss = loss_fn_kd(logits, hard_labels, soft_labels, T, alpha)

        total_hit += torch.sum(torch.argmax(logits, dim=1) == hard_labels).item()
        total_num += len(inputs)
        total_loss += loss.item() * len(inputs)
    return total_loss / total_num, total_hit / total_num

from google.colab import drive
drive.mount('/content/gdrive')
# Use a separate file for this baseline. Reusing 'worse_teacher.bin' would overwrite
# the best checkpoint from the worse-teacher experiment.
model_save_name = 'no_teacher.bin'
path = f"/content/gdrive/My Drive/{model_save_name}"

# The teacher stays in eval mode (with alpha=0 its outputs do not affect the loss anyway).
teacher_net.eval()
now_best_acc = 0
learningrate = 1e-3

for epoch in range(200):
    # Step schedule: 1e-3 for the first 50 epochs, then divide by 10 every 50 epochs.
    # Decay only after the first stage; decaying at epoch 0 too would start training
    # at 1e-4 and end at 1e-7. A fresh optimizer is built at every stage (which also
    # resets AdamW's moment estimates).
    if epoch % 50 == 0:
        if epoch > 0:
            learningrate /= 10
        optimizer = optim.AdamW(student_net.parameters(), lr=learningrate)
    student_net.train()
    train_loss, train_acc = run_epoch(train_dataloader, update=True)
    student_net.eval()
    valid_loss, valid_acc = run_epoch(valid_dataloader, update=False)

    # Keep the checkpoint with the best validation accuracy seen so far.
    if valid_acc > now_best_acc:
        now_best_acc = valid_acc
        torch.save(student_net.state_dict(), path)
    print('epoch {:>3d}: train loss: {:6.4f}, acc {:6.4f} valid loss: {:6.4f}, acc {:6.4f}'.format(
        epoch, train_loss, train_acc, valid_loss, valid_acc))

Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount("/content/gdrive", force_remount=True).
epoch   0: train loss: 2.0893, acc 0.2588 valid loss: 1.8665, acc 0.3315
epoch   1: train loss: 1.8859, acc 0.3316 valid loss: 1.7378, acc 0.4087
epoch   2: train loss: 1.7972, acc 0.3743 valid loss: 1.6558, acc 0.4283
epoch   3: train loss: 1.7421, acc 0.3917 valid loss: 1.6244, acc 0.4367
epoch   4: train loss: 1.7004, acc 0.4091 valid loss: 1.5507, acc 0.4813
epoch   5: train loss: 1.6483, acc 0.4298 valid loss: 1.4989, acc 0.4921
epoch   6: train loss: 1.6100, acc 0.4462 valid loss: 1.5402, acc 0.4880
epoch   7: train loss: 1.5723, acc 0.4579 valid loss: 1.4434, acc 0.5152
epoch   8: train loss: 1.5370, acc 0.4718 valid loss: 1.3863, acc 0.5469
epoch   9: train loss: 1.5176, acc 0.4841 valid loss: 1.3851, acc 0.5353
epoch  10: train loss: 1.4795, acc 0.4969 valid loss: 1.3450, acc 0.5472
epoch  11: train loss: 1.4637, acc 0.5017 valid loss: 1.3439, acc 