In [8]:
%reload_ext autoreload
%autoreload 2

import torch
import torchvision
import numpy as np

In [9]:
# Prefer the first CUDA GPU when one is visible to PyTorch; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = 'cuda:0'
else:
    device = 'cpu'

# Load Data 

In [10]:
batch_size = 64

# Convert each image to a tensor, then normalize with mean 0 and std 0.5,
# i.e. (x - 0) / 0.5.
transform = torchvision.transforms.Compose(
    [
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean=0, std=0.5),
    ]
)

# Training split: reshuffled every epoch.
trainset = torchvision.datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
trainloader = torch.utils.data.DataLoader(
    trainset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=2,
)

# Held-out split: fixed order.
testset = torchvision.datasets.MNIST(
    root='./data',
    train=False,
    download=True,
    transform=transform,
)
testloader = torch.utils.data.DataLoader(
    testset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=2,
)

# Build Models

In [11]:
class Teacher(torch.nn.Module):
    """Wide two-convolution classifier that acts as the distillation teacher.

    Takes single-channel images and returns 10 raw class scores (logits) per
    sample. The final linear layer expects 4x4 feature maps with 512 channels,
    which is what a 28x28 input produces after conv1 -> pool -> conv2.
    """

    def __init__(self):
        super().__init__()

        # Each stride-2 convolution halves the spatial resolution (rounded up).
        self.conv1 = torch.nn.Conv2d(1, 256, kernel_size=3, stride=2, padding=1)
        self.conv2 = torch.nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1)

        self.leakyrelu = torch.nn.LeakyReLU(0.2)
        # Note: MaxPool2d has no "same" padding option.
        self.max_pool = torch.nn.MaxPool2d(2)

        self.dense = torch.nn.Linear(512 * 4 * 4, 10)

    def forward(self, x):
        """Return class logits of shape (batch, 10) for the image batch ``x``."""
        features = self.max_pool(self.leakyrelu(self.conv1(x)))
        features = self.conv2(features)
        # Flatten everything except the batch dimension before the classifier.
        flat = features.view(features.size(0), -1)
        return self.dense(flat)
    
class Student(torch.nn.Module):
    """Narrow two-convolution classifier used as the distillation student.

    Same layer layout as the teacher but with 16 and 32 convolution channels.
    Takes single-channel images and returns 10 raw class scores (logits) per
    sample. The final linear layer expects 4x4 feature maps with 32 channels,
    which is what a 28x28 input produces after conv1 -> pool -> conv2.
    """

    def __init__(self):
        super().__init__()

        # Each stride-2 convolution halves the spatial resolution (rounded up).
        self.conv1 = torch.nn.Conv2d(1, 16, kernel_size=3, stride=2, padding=1)
        self.conv2 = torch.nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1)

        self.leakyrelu = torch.nn.LeakyReLU(0.2)
        # Note: MaxPool2d has no "same" padding option.
        self.max_pool = torch.nn.MaxPool2d(2)

        self.dense = torch.nn.Linear(32 * 4 * 4, 10)

    def forward(self, x):
        """Return class logits of shape (batch, 10) for the image batch ``x``."""
        features = self.max_pool(self.leakyrelu(self.conv1(x)))
        features = self.conv2(features)
        # Flatten everything except the batch dimension before the classifier.
        flat = features.view(features.size(0), -1)
        return self.dense(flat)

In [33]:
# Wide network that is trained first and later supplies the soft targets.
teacher = Teacher().to(device=device)
# Narrow network trained with the teacher's guidance.
student = Student().to(device=device)
# Same narrow architecture trained without a teacher, as the baseline.
student_scratch = Student().to(device=device)

In [34]:
# One independent Adam optimizer per network, all with the same learning rate.
teacher_opt = torch.optim.Adam(params=teacher.parameters(), lr=1e-3)
student_opt = torch.optim.Adam(params=student.parameters(), lr=1e-3)
student_scratch_opt = torch.optim.Adam(params=student_scratch.parameters(), lr=1e-3)

In [28]:
# Hard-label loss: cross-entropy between logits and integer class targets.
criterion = torch.nn.CrossEntropyLoss(reduction='mean')
# Soft-label loss: KL divergence; the first argument must be log-probabilities and
# the second plain probabilities. 'batchmean' divides the summed loss by the batch size.
distill_criterion = torch.nn.KLDivLoss(reduction='batchmean', log_target=False)

# Training

PyTorch KL-Divergence 실험 (feat. log_softmax와 softmax 중 무엇이 맞는지)
- https://douglasrizzo.com.br/kl-div-pytorch/

<div align='center'>
    <img width="400" src="https://www.oreilly.com/library/view/generative-adversarial-networks/9781789136678/assets/3ed12abb-fb0d-4205-848a-928127ec92ca.png">
</div>

In [37]:
def train(
    model, dataloader, criterion, optimizer, teacher_model=None,
    distill_criterion=None, alpha=0.1, temperature=10, device='cpu'
):
    """Train ``model`` for one epoch, optionally distilling from ``teacher_model``.

    A running loss/accuracy line is printed after every batch (overwritten in
    place with a carriage return, finished with a newline on the last batch).

    Args:
        model: Network to optimize; called as ``model(inputs)`` and expected to
            return per-class scores with classes along dim 1.
        dataloader: Sized iterable yielding ``(inputs, targets)`` tensor pairs.
        criterion: Hard-label loss, called as ``criterion(outputs, targets)``.
        optimizer: Optimizer bound to ``model``'s parameters.
        teacher_model: Optional frozen network providing soft targets. When
            given, it is switched to eval mode and evaluated without gradients.
        distill_criterion: Soft-label loss, called as
            ``distill_criterion(student_log_probs, teacher_probs)``. Required
            when ``teacher_model`` is given.
        alpha: Weight of the hard-label loss; the distillation loss is weighted
            by ``1 - alpha``.
        temperature: Divisor applied to both sets of logits before the softmax.
        device: Device the batches are moved to before the forward pass.

    Returns:
        tuple: ``(mean loss per batch, accuracy in percent)`` for the epoch, or
        ``(0.0, 0.0)`` if ``dataloader`` yielded no samples.

    Raises:
        ValueError: If ``teacher_model`` is given without ``distill_criterion``.
    """
    distilling = teacher_model is not None
    if distilling:
        if distill_criterion is None:
            raise ValueError('distill_criterion is required when teacher_model is given')
        teacher_model.eval()

    # Make sure layers such as dropout / batch-norm are in training mode.
    model.train()

    correct = 0
    total = 0
    total_loss = 0.0
    num_batches = len(dataloader)

    # Iterate over the loader that was passed in, not a notebook-level global.
    for idx, (inputs, targets) in enumerate(dataloader):
        inputs, targets = inputs.to(device), targets.to(device)

        optimizer.zero_grad()

        # predict
        outputs = model(inputs)
        loss = criterion(outputs, targets)

        if distilling:
            # The teacher only supplies targets: skip building its autograd graph
            # so no gradients are accumulated on its parameters.
            with torch.no_grad():
                teacher_logits = teacher_model(inputs)

            student_log_probs = torch.nn.functional.log_softmax(outputs / temperature, dim=1)
            teacher_probs = torch.nn.functional.softmax(teacher_logits / temperature, dim=1)
            distill_loss = distill_criterion(student_log_probs, teacher_probs)

            # NOTE(review): Hinton et al. suggest multiplying the soft-target term
            # by temperature ** 2 so its gradient scale matches the hard loss;
            # left unscaled here to keep the original loss definition.
            loss = alpha * loss + (1 - alpha) * distill_loss

        loss.backward()
        optimizer.step()

        # running loss and accuracy
        total_loss += loss.item()
        preds = outputs.argmax(dim=1)
        correct += targets.eq(preds).sum().item()
        total += targets.size(0)

        # Overwrite the progress line in place; end it with a newline on the last batch.
        line_end = '\n' if idx == num_batches - 1 else '\r'
        print('[%d/%d]: Loss: %.3f | Acc: %.3f%% [%d/%d]' %
              (idx+1, num_batches, total_loss/(idx+1), 100.*correct/total, correct, total),
              end=line_end)

    if total == 0:
        return 0.0, 0.0
    return total_loss / num_batches, 100. * correct / total
        
        
def test(model, dataloader, criterion, device='cpu'):
    correct = 0
    total = 0
    total_loss = 0
    
    with torch.no_grad():
        for idx, (inputs, targets) in enumerate(dataloader):
            inputs, targets = inputs.to(device), targets.to(device)
            
            # predict
            outputs = model(inputs)
            
            # loss 
            loss = criterion(outputs, targets)
            
            # total loss and acc
            total_loss += loss.item()
            preds = outputs.argmax(dim=1)
            correct += targets.eq(preds).sum().item()
            total += targets.size(0)
            
            
            if idx == (len(dataloader)-1):
                print('[%d/%d]: Loss: %.3f | Acc: %.3f%% [%d/%d]' % 
                      (idx+1, len(dataloader), total_loss/(idx+1), 100.*correct/total, correct, total),end='\n')
            else:
                print('[%d/%d]: Loss: %.3f | Acc: %.3f%% [%d/%d]' % 
                      (idx+1, len(dataloader), total_loss/(idx+1), 100.*correct/total, correct, total),end='\r')
            

# Teacher Training

In [30]:
epochs = 5

# Plain supervised training of the teacher: one training pass followed by one
# evaluation pass on the test split per epoch.
for epoch in range(epochs):
    print(f'Epoch: {epoch+1}/{epochs}')
    train(
        model=teacher,
        dataloader=trainloader,
        criterion=criterion,
        optimizer=teacher_opt,
        device=device,
    )
    test(model=teacher, dataloader=testloader, criterion=criterion, device=device)

Epoch: 1/5
[938/938]: Loss: 0.175 | Acc: 94.737% [56842/60000]
[157/157]: Loss: 0.097 | Acc: 96.890% [9689/10000]
Epoch: 2/5
[938/938]: Loss: 0.098 | Acc: 97.027% [58216/60000]
[157/157]: Loss: 0.075 | Acc: 97.570% [9757/10000]
Epoch: 3/5
[938/938]: Loss: 0.080 | Acc: 97.602% [58561/60000]
[157/157]: Loss: 0.104 | Acc: 96.870% [9687/10000]
Epoch: 4/5
[938/938]: Loss: 0.071 | Acc: 97.805% [58683/60000]
[157/157]: Loss: 0.097 | Acc: 97.350% [9735/10000]
Epoch: 5/5
[938/938]: Loss: 0.067 | Acc: 98.065% [58839/60000]
[157/157]: Loss: 0.107 | Acc: 97.160% [9716/10000]


# Knowledge Distillation

In [35]:
epochs = 3

# NOTE(review): the saved execution counts show the models being (re)created in
# In [33], after the teacher-training cell (In [30]) had run. Make sure the
# teacher has been trained in the current kernel session before running this
# cell; otherwise the student distills from randomly initialised weights.
for epoch in range(epochs):
    print(f'Epoch: {epoch+1}/{epochs}')
    # Student loss = alpha * hard-label loss + (1 - alpha) * soft-label loss.
    train(
        model=student,
        dataloader=trainloader,
        criterion=criterion,
        optimizer=student_opt,
        teacher_model=teacher,
        distill_criterion=distill_criterion,
        alpha=0.1,
        temperature=10,
        device=device,
    )
    test(model=student, dataloader=testloader, criterion=criterion, device=device)

Epoch: 1/3
[938/938]: Loss: 0.168 | Acc: 87.383% [52430/60000]
[157/157]: Loss: 0.197 | Acc: 95.000% [9500/10000]
Epoch: 2/3
[938/938]: Loss: 0.057 | Acc: 95.397% [57238/60000]
[157/157]: Loss: 0.131 | Acc: 96.550% [9655/10000]
Epoch: 3/3
[938/938]: Loss: 0.042 | Acc: 96.530% [57918/60000]
[157/157]: Loss: 0.105 | Acc: 97.200% [9720/10000]


# Student Scratch Training

In [36]:
epochs = 3

# Baseline: the same student architecture trained on hard labels only (no teacher).
for epoch in range(epochs):
    print(f'Epoch: {epoch+1}/{epochs}')
    train(
        model=student_scratch,
        dataloader=trainloader,
        criterion=criterion,
        optimizer=student_scratch_opt,
        device=device,
    )
    test(model=student_scratch, dataloader=testloader, criterion=criterion, device=device)

Epoch: 1/3
[938/938]: Loss: 0.387 | Acc: 88.713% [53228/60000]
[157/157]: Loss: 0.163 | Acc: 95.100% [9510/10000]
Epoch: 2/3
[938/938]: Loss: 0.143 | Acc: 95.658% [57395/60000]
[157/157]: Loss: 0.107 | Acc: 96.720% [9672/10000]
Epoch: 3/3
[938/938]: Loss: 0.111 | Acc: 96.578% [57947/60000]
[157/157]: Loss: 0.096 | Acc: 96.880% [9688/10000]
