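<b>Knowledge Distillation</b>

This notebook trains a small CNN student on CIFAR-10 in two ways and compares their train/test loss and accuracy:
- with ordinary cross-entropy on the true labels;
- by matching the softmax outputs of a teacher CNN that is trained first and then frozen (knowledge distillation).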

In [1]:
import numpy as np
import torch

seed = 100
# Seed NumPy, PyTorch's CPU generator and the current CUDA device's generator
# (in this order) so weight init and data shuffling are repeatable.
for seed_fn in (np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
  seed_fn(seed)

In [2]:
# Dataset

from torchvision import datasets, transforms
import torch.utils as utils

# ToTensor yields floats in [0, 1]; Normalize with mean = std = 0.5 per RGB
# channel then maps each value v to (v - 0.5) / 0.5, i.e. into [-1, 1].
channel_stats = (0.5, 0.5, 0.5)
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(channel_stats, channel_stats),
])

# Both splits share the download directory and preprocessing.
cifar_kwargs = dict(download=True, transform=transform)
dataset_train = datasets.CIFAR10('./data', train=True, **cifar_kwargs)
dataset_test = datasets.CIFAR10('./data', train=False, **cifar_kwargs)

for split_dataset in (dataset_train, dataset_test):
  print(len(split_dataset))

batch_size = 100

# Only the training split is shuffled; the test order stays fixed.
loader_kwargs = dict(batch_size=batch_size, num_workers=4)
dataloader_train = utils.data.DataLoader(dataset_train, shuffle=True, **loader_kwargs)
dataloader_test = utils.data.DataLoader(dataset_test, shuffle=False, **loader_kwargs)

Files already downloaded and verified
Files already downloaded and verified
50000
10000


In [3]:
# Network : Teacher

import torch.nn as nn

class TeacherNet(nn.Module):
  """Teacher CNN: three conv/ReLU/max-pool stages plus a two-layer MLP head.

  Uses 5x5 kernels (padding 2) with channel widths 3 -> 4 -> 8 -> 16. Each
  2x2 max-pool halves the spatial size, so a 32x32 (CIFAR-10) input reaches
  the head as 16 * 4 * 4 = 256 features. The output is 10 raw class scores
  (logits); no softmax is applied here.
  """

  def __init__(self):
    super().__init__()
    layers = []
    # Modules are created in the same order as a hand-written Sequential,
    # so seeded weight initialisation is unchanged.
    for c_in, c_out in ((3, 4), (4, 8), (8, 16)):
      layers += [
          nn.Conv2d(c_in, c_out, kernel_size=5, padding=2),
          nn.ReLU(),
          nn.MaxPool2d(2),
      ]
    self.conv = nn.Sequential(*layers)
    self.fc = nn.Sequential(
        nn.Linear(16 * 4 * 4, 100),
        nn.ReLU(),
        nn.Linear(100, 10),
    )

  def forward(self, x1):
    features = self.conv(x1)
    # Flatten everything except the batch dimension for the linear head.
    flat = features.view(features.size(0), -1)
    return self.fc(flat)

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model_t = TeacherNet()
model_t = model_t.to(device)
print(model_t)

TeacherNet(
  (conv): Sequential(
    (0): Conv2d(3, 4, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(4, 8, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (4): ReLU()
    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(8, 16, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (7): ReLU()
    (8): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (fc): Sequential(
    (0): Linear(in_features=256, out_features=100, bias=True)
    (1): ReLU()
    (2): Linear(in_features=100, out_features=10, bias=True)
  )
)


In [4]:
# Training : Teacher Model

from torch import optim

# Plain supervised training of the teacher with cross-entropy. After each
# epoch the teacher is scored on the full train and test splits.
optimizer = optim.SGD(model_t.parameters(), lr=0.01, weight_decay=0.00001)
criterion_t = nn.CrossEntropyLoss()

nepoch = 100

for i in range(nepoch):
  print(f"EPOCH: {i+1}")

  ### Train ###
  model_t.train()
  for x, t in dataloader_train:
    x = x.to(device)
    t = t.to(device)
    optimizer.zero_grad()
    y = model_t(x)
    loss = criterion_t(y, t)
    loss.backward()
    optimizer.step()

  ### Evaluate on the train and test splits ###
  model_t.eval()
  # Split labels are padded ("test ") so the printed lines keep their alignment.
  for split, loader in (("train", dataloader_train), ("test ", dataloader_test)):
    sum_loss = 0.0
    sum_correct = 0
    n_seen = 0
    # Evaluation needs no gradients; no_grad avoids building autograd graphs.
    with torch.no_grad():
      for x, t in loader:
        x = x.to(device)
        t = t.to(device)
        y = model_t(x)
        n = t.size(0)
        # The criterion returns a batch mean. Weighting it by the actual batch
        # size keeps the average exact even when the last batch is short.
        sum_loss += criterion_t(y, t).item() * n
        _, predicted = y.max(1)
        sum_correct += (predicted == t).sum().item()
        n_seen += n
    print(f"  {split} loss: {sum_loss/n_seen}")
    print(f"  {split} acc : {sum_correct/n_seen}")

EPOCH: 1
  train loss: 2.2931676821708677
  train acc : 0.11062
  test  loss: 2.293211588859558
  test  acc : 0.1096
EPOCH: 2
  train loss: 2.10754123711586
  train acc : 0.23526
  test  loss: 2.1074262726306916
  test  acc : 0.2402
EPOCH: 3
  train loss: 1.925419793844223
  train acc : 0.3194
  test  loss: 1.9195134139060974
  test  acc : 0.3231
EPOCH: 4
  train loss: 1.7611073186397552
  train acc : 0.3714
  test  loss: 1.7557154047489165
  test  acc : 0.3706
EPOCH: 5
  train loss: 1.685886886358261
  train acc : 0.39968
  test  loss: 1.680221792459488
  test  acc : 0.3964
EPOCH: 6
  train loss: 1.5804206295013428
  train acc : 0.43392
  test  loss: 1.5767684173583985
  test  acc : 0.4283
EPOCH: 7
  train loss: 1.604973829984665
  train acc : 0.42878
  test  loss: 1.6015375423431397
  test  acc : 0.4257
EPOCH: 8
  train loss: 1.5612058963775635
  train acc : 0.44518
  test  loss: 1.5616987907886506
  test  acc : 0.4465
EPOCH: 9
  train loss: 1.4764125971794129
  train acc : 0.47216
 

In [5]:
# Freeze Teacher Model Parameter

# Turn off gradient tracking for every teacher parameter. From here on the
# teacher only supplies targets and autograd never records its weights.
# (Trailing ';' suppresses the module repr that requires_grad_ would display.)
model_t.requires_grad_(False);

In [6]:
# Network : Student

import torch.nn as nn

class StudentNet(nn.Module):
  """Student CNN: two conv/ReLU/max-pool stages plus a two-layer MLP head.

  Uses 3x3 kernels (padding 1) with channel widths 3 -> 4 -> 8. The two 2x2
  max-pools shrink a 32x32 (CIFAR-10) input to 8 * 8 * 8 = 512 features for
  the head. The output is 10 raw class scores (logits).
  """

  def __init__(self):
    super().__init__()
    stages = []
    # Same module creation order as a hand-written Sequential, so seeded
    # weight initialisation is unchanged.
    for c_in, c_out in ((3, 4), (4, 8)):
      stages.extend([
          nn.Conv2d(c_in, c_out, kernel_size=3, padding=1),
          nn.ReLU(),
          nn.MaxPool2d(2),
      ])
    self.conv = nn.Sequential(*stages)
    self.fc = nn.Sequential(
        nn.Linear(8 * 8 * 8, 100),
        nn.ReLU(),
        nn.Linear(100, 10),
    )

  def forward(self, x1):
    feature_map = self.conv(x1)
    # Keep the batch dimension and flatten the rest for the linear head.
    return self.fc(feature_map.view(feature_map.size(0), -1))

# Baseline student, trained later with plain cross-entropy.
# Runs on the GPU when available, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = StudentNet()
model = model.to(device)
print(model)

StudentNet(
  (conv): Sequential(
    (0): Conv2d(3, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(4, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): ReLU()
    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (fc): Sequential(
    (0): Linear(in_features=512, out_features=100, bias=True)
    (1): ReLU()
    (2): Linear(in_features=100, out_features=10, bias=True)
  )
)


In [7]:
# Training : Student Model by the standard CrossEntropyLoss

from torch import optim

# Baseline: train the student directly on the labels with cross-entropy, so
# its results can be compared with the distilled student below.
optimizer = optim.SGD(model.parameters(), lr=0.01, weight_decay=0.00001)
criterion = nn.CrossEntropyLoss()

nepoch = 100

for i in range(nepoch):
  print(f"EPOCH: {i+1}")

  ### Train ###
  # Switch the *student* back to training mode (the evaluation below leaves it
  # in eval mode). Toggling model_t here would change the wrong network.
  model.train()
  for x, t in dataloader_train:
    x = x.to(device)
    t = t.to(device)
    optimizer.zero_grad()
    y = model(x)
    loss = criterion(y, t)
    loss.backward()
    optimizer.step()

  ### Evaluate on the train and test splits ###
  model.eval()
  # Split labels are padded ("test ") so the printed lines keep their alignment.
  for split, loader in (("train", dataloader_train), ("test ", dataloader_test)):
    sum_loss = 0.0
    sum_correct = 0
    n_seen = 0
    # Evaluation needs no gradients; no_grad avoids building autograd graphs.
    with torch.no_grad():
      for x, t in loader:
        x = x.to(device)
        t = t.to(device)
        y = model(x)
        n = t.size(0)
        # The criterion returns a batch mean; weight it by the real batch size.
        sum_loss += criterion(y, t).item() * n
        _, predicted = y.max(1)
        sum_correct += (predicted == t).sum().item()
        n_seen += n
    print(f"  {split} loss: {sum_loss/n_seen}")
    print(f"  {split} acc : {sum_correct/n_seen}")

EPOCH: 1
  train loss: 2.1246299471855163
  train acc : 0.24438
  test  loss: 2.1187277603149415
  test  acc : 0.2507
EPOCH: 2
  train loss: 1.9269846847057344
  train acc : 0.31846
  test  loss: 1.9150080978870392
  test  acc : 0.3256
EPOCH: 3
  train loss: 1.7830498971939086
  train acc : 0.36746
  test  loss: 1.775259416103363
  test  acc : 0.372
EPOCH: 4
  train loss: 1.6747831597328187
  train acc : 0.40774
  test  loss: 1.6701966845989227
  test  acc : 0.4117
EPOCH: 5
  train loss: 1.60275314950943
  train acc : 0.43146
  test  loss: 1.6032016015052795
  test  acc : 0.4368
EPOCH: 6
  train loss: 1.5313728458881377
  train acc : 0.45858
  test  loss: 1.5387340426445006
  test  acc : 0.4577
EPOCH: 7
  train loss: 1.4773691787719727
  train acc : 0.47872
  test  loss: 1.4908929932117463
  test  acc : 0.4767
EPOCH: 8
  train loss: 1.432268519639969
  train acc : 0.49384
  test  loss: 1.4552681756019592
  test  acc : 0.4841
EPOCH: 9
  train loss: 1.4079386060237884
  train acc : 0.498

In [8]:
# Network : Student for Distillation
# A fresh StudentNet, trained below against the frozen teacher's outputs.
# NOTE(review): by the layer definitions, StudentNet has ~52.7k parameters and
# TeacherNet ~31.0k. The "teacher" here is deeper, not larger.

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model_s = StudentNet()
model_s = model_s.to(device)
print(model_s)

StudentNet(
  (conv): Sequential(
    (0): Conv2d(3, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(4, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): ReLU()
    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (fc): Sequential(
    (0): Linear(in_features=512, out_features=100, bias=True)
    (1): ReLU()
    (2): Linear(in_features=100, out_features=10, bias=True)
  )
)


In [9]:
# Training : Student Model

from torch import optim
import torch.nn.functional as F

# Knowledge distillation: the student learns to match the frozen teacher's
# softmax outputs. Accuracy is still measured against the true labels.
optimizer = optim.SGD(model_s.parameters(), lr=0.01, weight_decay=0.00001)
criterion_s = nn.KLDivLoss(reduction="batchmean")

nepoch = 100

# Distillation hyper-parameters (Hinton et al., 2015):
# - temperature > 1 softens both distributions, exposing the teacher's
#   relative scores for the wrong classes;
# - alpha weights the soft-target term against cross-entropy on the labels.
# temperature = 1.0 and alpha = 1.0 give pure KL(teacher || student) matching.
temperature = 1.0
alpha = 1.0


def distillation_loss(y_s, y_t, t, temperature=1.0, alpha=1.0):
  """Return the knowledge-distillation loss for one batch.

  Args:
    y_s: student logits.
    y_t: teacher logits (same shape as y_s).
    t: integer class labels; only used when alpha < 1.
    temperature: softmax temperature T applied to both sets of logits.
    alpha: weight of the soft-target term; (1 - alpha) weights hard-label CE.

  Returns:
    alpha * T^2 * KLDiv(log_softmax(y_s / T), softmax(y_t / T)) (batch mean)
    plus (1 - alpha) * cross_entropy(y_s, t) when alpha < 1.
  """
  soft = criterion_s(F.log_softmax(y_s / temperature, dim=1),
                     F.softmax(y_t / temperature, dim=1))
  # Scaling by T^2 keeps soft-target gradient magnitudes comparable across T.
  loss = soft * (alpha * temperature * temperature)
  if alpha < 1.0:
    loss = loss + (1.0 - alpha) * F.cross_entropy(y_s, t)
  return loss


model_t.eval()

for i in range(nepoch):
  print(f"EPOCH: {i+1}")

  ### Train ###
  model_s.train()
  for x, t in dataloader_train:
    x = x.to(device)
    t = t.to(device)
    optimizer.zero_grad()
    y_s = model_s(x)
    # The teacher only provides targets, so no gradients flow into it.
    with torch.no_grad():
      y_t = model_t(x)
    loss = distillation_loss(y_s, y_t, t, temperature, alpha)
    loss.backward()
    optimizer.step()

  ### Evaluate on the train and test splits ###
  model_s.eval()
  # Split labels are padded ("test ") so the printed lines keep their alignment.
  for split, loader in (("train", dataloader_train), ("test ", dataloader_test)):
    sum_loss = 0.0
    sum_correct = 0
    n_seen = 0
    # Evaluation needs no gradients; no_grad avoids building autograd graphs.
    with torch.no_grad():
      for x, t in loader:
        x = x.to(device)
        t = t.to(device)
        y_s = model_s(x)
        y_t = model_t(x)
        n = t.size(0)
        # The loss is a per-sample batch mean; weight it by the real batch size.
        sum_loss += distillation_loss(y_s, y_t, t, temperature, alpha).item() * n
        _, predicted = y_s.max(1)
        sum_correct += (predicted == t).sum().item()
        n_seen += n
    print(f"  {split} loss: {sum_loss/n_seen}")
    print(f"  {split} acc : {sum_correct/n_seen}")

EPOCH: 1
  train loss: 1.635397979259491
  train acc : 0.14438
  test  loss: 1.6134192073345184
  test  acc : 0.1457
EPOCH: 2
  train loss: 1.3672885711193086
  train acc : 0.29066
  test  loss: 1.342571724653244
  test  acc : 0.2938
EPOCH: 3
  train loss: 1.2431135723590852
  train acc : 0.33892
  test  loss: 1.2191564106941224
  test  acc : 0.3463
EPOCH: 4
  train loss: 1.1555605547428132
  train acc : 0.37314
  test  loss: 1.1344581580162048
  test  acc : 0.377
EPOCH: 5
  train loss: 1.0533302186727524
  train acc : 0.40612
  test  loss: 1.037122082710266
  test  acc : 0.4104
EPOCH: 6
  train loss: 0.9701670808792114
  train acc : 0.4291
  test  loss: 0.9553786396980286
  test  acc : 0.4349
EPOCH: 7
  train loss: 0.9479205652475357
  train acc : 0.44286
  test  loss: 0.9323837625980377
  test  acc : 0.4479
EPOCH: 8
  train loss: 0.8690963395833969
  train acc : 0.47078
  test  loss: 0.8580413711071014
  test  acc : 0.4691
EPOCH: 9
  train loss: 0.825521358370781
  train acc : 0.4863

参考<br>
https://github.com/peterliht/knowledge-distillation-pytorch<br>
http://codecrafthouse.jp/p/2018/01/knowledge-distillation/