### Load Libraries

In [1]:
import torch
from torch import nn, optim
from torch.nn import functional
import os

from teacher import CNN_Teacher
from student import CNN_Student
from common import MODEL_DIRECTORY, accuracy, evaluate, get_loaders

In [2]:
def distillation_loss(logits_stu, logits_base, t):
    logits_stu = logits_stu / t
    logits_base = logits_base / t
    pred_stu = functional.log_softmax(logits_stu,  dim=1)
    prop_base = nn.Softmax(dim=1)(logits_base)
    pred_base = torch.argmax(prop_base, dim=1)
    return functional.nll_loss(pred_stu, pred_base, reduction='sum')
    
def train(model, iterator, optimizer, teacher, t, alpha):    
    epoch_loss = 0
    epoch_acc = 0

    model = model.train()
    teacher = teacher.eval()

    for (x, y) in iterator:    
        optimizer.zero_grad()

        if teacher is not None:
            y_pred, logits_pred = model(x)        
            _, logits_teacher = teacher(x)

            dist_loss = distillation_loss(logits_pred, logits_teacher, t)            
            stu_loss = functional.nll_loss(y_pred, y, reduction='sum')     
            loss = alpha * dist_loss + (1 - alpha) * stu_loss
        else:
            y_pred, _ = model(x)        
            loss = functional.nll_loss(y_pred, y, reduction='sum')   

        acc = accuracy(y_pred, y)        
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()
        epoch_acc += acc.item()

    return epoch_loss / len(iterator), epoch_acc / len(iterator)

In [3]:
teacher = CNN_Teacher()
teacher.load_state_dict(torch.load(os.path.join(MODEL_DIRECTORY, 'teacher.pt')))
teacher = teacher.eval()

In [4]:
t = 6.0
alpha = 0.95

dist_model = CNN_Student()
best_test_acc  = 0
train_loader, test_loader = get_loaders()
optimizer = optim.SGD(dist_model.parameters(), lr=0.01)

for epoch in range(50):

    train_loss, train_acc = train(dist_model, train_loader, optimizer, teacher, t, alpha)
    test_loss, test_acc = evaluate(dist_model, test_loader)

    if test_acc > best_test_acc:
        file_name = os.path.join(MODEL_DIRECTORY, f'distilled-{t}.pt')
        print(f'Saving file {file_name}, test_accuracy ({test_acc}) > best_test_accuracy({best_test_acc})')
        torch.save(dist_model.state_dict(), file_name)

        best_test_acc = test_acc

    if epoch % 10 == 9:
        print(
            f'{epoch:02}: train: loss {train_loss:.3f}, acc {train_acc * 100:.2f}%, ' +
                f'test: loss {test_loss:.3f}, acc {test_acc * 100:.2f}%, best test acc: {best_test_acc * 100:.2f}%'
        )
    else:
        print(f'{epoch:02}, ', end='')

Loaded data: 50000 train and 10000 test examples
Saving file models/distilled-6.0.pt, test_accuracy (0.5200039808917197) > best_test_accuracy(0)
00, Saving file models/distilled-6.0.pt, test_accuracy (0.6472929936305732) > best_test_accuracy(0.5200039808917197)
01, 02, Saving file models/distilled-6.0.pt, test_accuracy (0.6899880573248408) > best_test_accuracy(0.6472929936305732)
03, Saving file models/distilled-6.0.pt, test_accuracy (0.7127786624203821) > best_test_accuracy(0.6899880573248408)
04, Saving file models/distilled-6.0.pt, test_accuracy (0.7390525477707006) > best_test_accuracy(0.7127786624203821)
05, Saving file models/distilled-6.0.pt, test_accuracy (0.7418391719745223) > best_test_accuracy(0.7390525477707006)
06, 07, Saving file models/distilled-6.0.pt, test_accuracy (0.7625398089171974) > best_test_accuracy(0.7418391719745223)
08, 09: train: loss 51.390, acc 73.69%, test: loss 148.142, acc 76.07%, best test acc: 76.25%
10, 11, Saving file models/distilled-6.0.pt, test_a