In [1]:
import torchvision
import torch
import torch.nn as nn
from torchvision import transforms
import torch.nn.functional as F
import torch.optim as optim
import numpy as np

### 1 下载数据

In [2]:
transform = transforms.Compose(
  [transforms.ToTensor(),
 transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
  download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=16,
  shuffle=True, num_workers=2)

Files already downloaded and verified


In [3]:
testset = torchvision.datasets.CIFAR10(root='./data',
  download=True, transform=transform)
testloader = torch.utils.data.DataLoader(trainset, batch_size=16,
  shuffle=False, num_workers=2)

Files already downloaded and verified


In [4]:
len(trainloader), len(testloader)

(3125, 3125)

In [5]:
x, y = next(iter(trainloader))

In [6]:
x.shape, y.shape

(torch.Size([16, 3, 32, 32]), torch.Size([16]))

### 2 创建学生和老师模型

In [7]:
class Teacher(nn.Module):
  def __init__(self, in_channel, out_channel, drop_rate):
    super(Teacher, self).__init__()
    self.in_channel = in_channel
    self.out_channel = out_channel
    self.conv1 = nn.Conv2d(in_channel, out_channel*4, kernel_size=2, stride=1)
    self.bn1 = nn.BatchNorm2d(out_channel*4)
    self.conv2 = nn.Conv2d(out_channel*4, out_channel*2, kernel_size=2, stride=1)
    self.bn2 = nn.BatchNorm2d(out_channel*2)
    self.conv3 = nn.Conv2d(out_channel*2, out_channel, kernel_size=2, stride=1)
    self.bn3 = nn.BatchNorm2d(out_channel)
    self.fc1 = nn.Linear(out_channel*29*29, 16*16)
    self.fc2 = nn.Linear(16*16,10)
    self.drop_rate = drop_rate
  
  def forward(self, x):
    out = self.bn3(self.conv3(self.bn2(self.conv2(self.bn1(self.conv1(x))))))
    out = out.view(-1, self.out_channel*29*29)
    return F.dropout(self.fc2(self.fc1(out)),p=self.drop_rate, training=self.training)

In [8]:
class Student(nn.Module):
  def __init__(self, in_channel, out_channel, drop_rate):
    super(Student, self).__init__()
    self.in_channel = in_channel
    self.out_channel = out_channel
    self.conv1 = nn.Conv2d(in_channel, out_channel*2, kernel_size=2, stride=1)
    self.bn1 = nn.BatchNorm2d(out_channel*2)
    self.conv2 = nn.Conv2d(out_channel*2, out_channel, kernel_size=2, stride=1)
    self.bn2 = nn.BatchNorm2d(out_channel)
    self.fc1 = nn.Linear(out_channel*30*30, 10)
    self.drop_rate = drop_rate
  
  def forward(self, x):
    out = self.bn2(self.conv2(self.bn1(self.conv1(x))))
    out = out.view(-1, self.out_channel*30*30)
    return F.dropout(self.fc1(out), p=self.drop_rate, training=self.training)

### 3 训练模型

In [9]:
def train_teacher(teacher, loader, epochs=20):
  loss_fn = nn.CrossEntropyLoss()
  optimizer = optim.Adam(teacher.parameters(), lr=1e-4)
  for epoch in range(epochs):
    loss_list = []
    for i, (x, y) in enumerate(loader):
      teacher_out = teacher(x.to(device))
      loss = loss_fn(teacher_out,y.to(device))
      optimizer.zero_grad()
      loss.backward()
      optimizer.step()
      loss_list.append(loss.item())
    print('epoch: ', epoch, ' loss: ', np.mean(loss_list))  

In [10]:
def evaluate(model, loader):
  output, target = [], []
  with torch.no_grad():
    for i, (x, y) in enumerate(loader):
      teacher_out = model(x.to(device))
      output.append(teacher_out.cpu().numpy())
      target.append(y.numpy())
    print(accuracy(np.array(output).reshape(-1,10), np.array(target).reshape(-1,1)))

In [11]:
def accuracy(outputs, labels):
  """
  outputs: (np.ndarray) output of the model
  labels: (np.ndarray) [0, 1, ..., num_classes-1]
  Returns: (float) accuracy in [0,1]
  """
  outputs = np.argmax(outputs, axis=1)
  return np.sum(outputs==labels.flatten())/len(labels.flatten())
  

In [12]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
teacher = Teacher(3,8,0.01).to(device)
train_teacher(teacher, trainloader, epochs=20)


epoch:  0  loss:  1.8812922985076905
epoch:  1  loss:  1.7353566838645935
epoch:  2  loss:  1.7084689106750488
epoch:  3  loss:  1.6895167400550841
epoch:  4  loss:  1.6714395731925964
epoch:  5  loss:  1.6568386020278931
epoch:  6  loss:  1.6460781346511841
epoch:  7  loss:  1.631511506462097
epoch:  8  loss:  1.6256449743843078
epoch:  9  loss:  1.6205993981361388
epoch:  10  loss:  1.6109432560729982
epoch:  11  loss:  1.6055876219558716
epoch:  12  loss:  1.6007826876449585
epoch:  13  loss:  1.594331014251709
epoch:  14  loss:  1.5883402904510497
epoch:  15  loss:  1.5821655068397522
epoch:  16  loss:  1.5767782825469971
epoch:  17  loss:  1.5731523121833801
epoch:  18  loss:  1.5698520596885681
epoch:  19  loss:  1.5617119542884828


In [13]:
evaluate(teacher, testloader)

0.48522


In [14]:
def loss_fn(outputs, teacher_outputs, labels, alpha=0.5, T=2):
    """
    Compute the knowledge-distillation (KD) loss given outputs, labels.
    "Hyperparameters": temperature and alpha
    NOTE: the KL Divergence for PyTorch comparing the softmaxs of teacher
    and student expects the input tensor to be log probabilities! See Issue #2
    """
    KD_loss = nn.KLDivLoss()(F.log_softmax(outputs/T, dim=1),
            F.softmax(teacher_outputs/T, dim=1)) * (alpha * T * T) + \
            F.cross_entropy(outputs, labels) * (1. - alpha)

    return KD_loss

In [15]:
def train_student(teacher, student, loader, epochs=20):  
  teacher.eval()
  student.train()
  optimizer = optim.Adam(student.parameters(), lr=1e-4)
  for epoch in range(epochs):
    loss_list = []
    for i, (x, y) in enumerate(loader):
      student_out = student(x.to(device))
      teacher_out = teacher(x.to(device))
      loss = loss_fn(student_out, teacher_out, y.to(device))
      optimizer.zero_grad()
      loss.backward()
      optimizer.step()
      loss_list.append(loss.item())
    print('epoch: ', epoch, ' loss: ', np.mean(loss_list))  

In [16]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
student = Student(3,8,0.01).to(device)
train_student(teacher,student,trainloader, epochs=20)
evaluate(student,testloader)

  "reduction: 'mean' divides the total loss by both the batch size and the support size."


epoch:  0  loss:  0.9423875627994537
epoch:  1  loss:  0.8951312295341491
epoch:  2  loss:  0.8772634046649933
epoch:  3  loss:  0.8665837880897522
epoch:  4  loss:  0.8573688481521606
epoch:  5  loss:  0.8484362033462525
epoch:  6  loss:  0.8420833777904511
epoch:  7  loss:  0.8351898818588257
epoch:  8  loss:  0.8292628209209442
epoch:  9  loss:  0.8236157040977478
epoch:  10  loss:  0.8187981503486633
epoch:  11  loss:  0.8155362488937378
epoch:  12  loss:  0.810830752248764
epoch:  13  loss:  0.8076176631832123
epoch:  14  loss:  0.8039064839172363
epoch:  15  loss:  0.8005043474006652
epoch:  16  loss:  0.796022368850708
epoch:  17  loss:  0.7930228970432281
epoch:  18  loss:  0.7911533143043518
epoch:  19  loss:  0.7881668622207642
0.48956
