In [2]:
%reload_ext autoreload
%autoreload 2

import torch
import torchvision
import numpy as np

In [3]:
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

# Load Data 

In [4]:
batch_size = 64

transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(0, 0.5),
])


trainset = torchvision.datasets.MNIST(
    root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=batch_size, shuffle=True, num_workers=2)

testset = torchvision.datasets.MNIST(
    root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=batch_size, shuffle=False, num_workers=2)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


9913344it [00:01, 8764583.47it/s]                              


Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


29696it [00:00, 6830493.64it/s]          


Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


1649664it [00:00, 4029193.11it/s]                            


Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


5120it [00:00, 12132675.98it/s]         

Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw






# Build Models

In [6]:
class Teacher(torch.nn.Module):
    def __init__(self):
        super(Teacher, self).__init__()
        
        self.conv1 = torch.nn.Conv2d(in_channels=1, out_channels=256, kernel_size=3, stride=2, padding=1)
        self.conv2 = torch.nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, stride=2, padding=1)
        
        self.leakyrelu = torch.nn.LeakyReLU(negative_slope=0.2)
        self.max_pool = torch.nn.MaxPool2d(kernel_size=2) # padding 'same' do not exit...
        
        self.dense = torch.nn.Linear(4*4*512, 10)
        
    def forward(self, x):
        output = self.conv1(x)
        output = self.leakyrelu(output)
        output = self.max_pool(output)
        output = self.conv2(output)
        output = self.dense(output.view(output.size(0), -1))
        
        return output
    
class Student(torch.nn.Module):
    def __init__(self):
        super(Student, self).__init__()
        
        self.conv1 = torch.nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3, stride=2, padding=1)
        self.conv2 = torch.nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=2, padding=1)
        
        self.leakyrelu = torch.nn.LeakyReLU(negative_slope=0.2)
        self.max_pool = torch.nn.MaxPool2d(kernel_size=2) # padding 'same' do not exit...
        
        self.dense = torch.nn.Linear(4*4*32, 10)
        
    def forward(self, x):
        output = self.conv1(x)
        output = self.leakyrelu(output)
        output = self.max_pool(output)
        output = self.conv2(output)
        output = self.dense(output.view(output.size(0), -1))
        
        return output

In [7]:
teacher = Teacher().to(device)
student = Student().to(device)
student_scratch = Student().to(device)

In [8]:
teacher_opt = torch.optim.Adam(teacher.parameters(), lr=0.001)
student_opt = torch.optim.Adam(student.parameters(), lr=0.001)
student_scratch_opt = torch.optim.Adam(student_scratch.parameters(), lr=0.001)

In [9]:
criterion = torch.nn.CrossEntropyLoss()
distill_criterion = torch.nn.KLDivLoss(reduction='batchmean')

# Training

Pytorch KL-Divergence 실험 (feat. log_softmax와 softmax 무엇이 맞는지)
- https://douglasrizzo.com.br/kl-div-pytorch/

<div align='center'>
    <img width="400" src="https://www.oreilly.com/library/view/generative-adversarial-networks/9781789136678/assets/3ed12abb-fb0d-4205-848a-928127ec92ca.png">
</div>

In [10]:
def train(
    model, dataloader, criterion, optimizer, teacher_model=None, 
    distill_criterion=None, alpha=0.1, temperature=10, device='cpu'
):
    if teacher_model!=None:
        teacher_model.eval()
        
    correct = 0 
    total = 0
    total_loss = 0
    for idx, (inputs, targets) in enumerate(dataloader):
        inputs, targets = inputs.to(device), targets.to(device)
        
        # predict
        outputs = model(inputs)
        
        if teacher_model != None:
            teacher_outputs = teacher_model(inputs)
            
        # loss and update
        optimizer.zero_grad()
                
        loss = criterion(outputs, targets)
        
        if teacher_model != None:
            student_outputs = torch.nn.functional.log_softmax(outputs / temperature, dim=1) # log softmax
            teacher_outputs = torch.nn.functional.softmax(teacher_outputs / temperature, dim=1) # softmax
            
            distill_loss = distill_criterion(student_outputs, teacher_outputs)
            
            loss = alpha * loss + (1 - alpha) * distill_loss 
            
        loss.backward()
        optimizer.step()
        
        # total loss and acc
        total_loss += loss.item()
        
        preds = outputs.argmax(dim=1) 
        correct += targets.eq(preds).sum().item()
        total += targets.size(0)
        
        if idx == (len(dataloader)-1):
            print('[%d/%d]: Loss: %.3f | Acc: %.3f%% [%d/%d]' % 
                  (idx+1, len(dataloader), total_loss/(idx+1), 100.*correct/total, correct, total),end='\n')
        else:
            print('[%d/%d]: Loss: %.3f | Acc: %.3f%% [%d/%d]' % 
                  (idx+1, len(dataloader), total_loss/(idx+1), 100.*correct/total, correct, total),end='\r')
        
        
def test(model, dataloader, criterion, device='cpu'):
    correct = 0
    total = 0
    total_loss = 0
    
    with torch.no_grad():
        for idx, (inputs, targets) in enumerate(dataloader):
            inputs, targets = inputs.to(device), targets.to(device)
            
            # predict
            outputs = model(inputs)
            
            # loss 
            loss = criterion(outputs, targets)
            
            # total loss and acc
            total_loss += loss.item()
            preds = outputs.argmax(dim=1)
            correct += targets.eq(preds).sum().item()
            total += targets.size(0)
            
            
            if idx == (len(dataloader)-1):
                print('[%d/%d]: Loss: %.3f | Acc: %.3f%% [%d/%d]' % 
                      (idx+1, len(dataloader), total_loss/(idx+1), 100.*correct/total, correct, total),end='\n')
            else:
                print('[%d/%d]: Loss: %.3f | Acc: %.3f%% [%d/%d]' % 
                      (idx+1, len(dataloader), total_loss/(idx+1), 100.*correct/total, correct, total),end='\r')
            

# Teacher Training

In [11]:
epochs = 5

for epoch in range(epochs):
    print(f'Epoch: {epoch+1}/{epochs}')
    train(teacher, trainloader, criterion, teacher_opt, device=device)
    test(teacher, testloader, criterion, device)

Epoch: 1/5
[938/938]: Loss: 0.181 | Acc: 94.542% [56725/60000]
[157/157]: Loss: 0.090 | Acc: 97.130% [9713/10000]
Epoch: 2/5
[938/938]: Loss: 0.097 | Acc: 97.147% [58288/60000]
[157/157]: Loss: 0.099 | Acc: 97.110% [9711/10000]
Epoch: 3/5
[938/938]: Loss: 0.079 | Acc: 97.600% [58560/60000]
[157/157]: Loss: 0.081 | Acc: 97.650% [9765/10000]
Epoch: 4/5
[938/938]: Loss: 0.072 | Acc: 97.795% [58677/60000]
[157/157]: Loss: 0.081 | Acc: 97.560% [9756/10000]
Epoch: 5/5
[938/938]: Loss: 0.066 | Acc: 98.015% [58809/60000]
[157/157]: Loss: 0.102 | Acc: 97.310% [9731/10000]


# Knowledge Distillation

In [12]:
epochs = 3

for epoch in range(epochs):
    print(f'Epoch: {epoch+1}/{epochs}')
    train(student, trainloader, criterion, student_opt,
          teacher, distill_criterion, alpha=0.1, temperature=10, device=device)
    test(student, testloader, criterion, device=device)

Epoch: 1/3
[938/938]: Loss: 0.178 | Acc: 87.552% [52531/60000]
[157/157]: Loss: 0.197 | Acc: 94.770% [9477/10000]
Epoch: 2/3
[938/938]: Loss: 0.062 | Acc: 95.352% [57211/60000]
[157/157]: Loss: 0.130 | Acc: 96.470% [9647/10000]
Epoch: 3/3
[938/938]: Loss: 0.046 | Acc: 96.475% [57885/60000]
[157/157]: Loss: 0.107 | Acc: 97.010% [9701/10000]


# Student Scratch Training

In [13]:
epochs = 3

for epoch in range(epochs):
    print(f'Epoch: {epoch+1}/{epochs}')
    train(student_scratch, trainloader, criterion, student_scratch_opt, device=device)
    test(student_scratch, testloader, criterion, device=device)

Epoch: 1/3
[938/938]: Loss: 0.349 | Acc: 89.823% [53894/60000]
[157/157]: Loss: 0.142 | Acc: 95.770% [9577/10000]
Epoch: 2/3
[938/938]: Loss: 0.142 | Acc: 95.652% [57391/60000]
[157/157]: Loss: 0.107 | Acc: 96.710% [9671/10000]
Epoch: 3/3
[938/938]: Loss: 0.115 | Acc: 96.517% [57910/60000]
[157/157]: Loss: 0.098 | Acc: 96.760% [9676/10000]
