In [1]:
import torch
from torch import optim
from torch import nn
from torch.nn import functional as F
import torchvision
from torchvision import transforms
from torchvision import models
import numpy as np
from tqdm.notebook import tqdm
import matplotlib
from matplotlib import pyplot as plt
%matplotlib inline
from os import path

In [2]:
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(DEVICE)

BATCH_SIZE = 512
NUM_EPOCHS = 20
PRINT_EVERY = NUM_EPOCHS // 100 if NUM_EPOCHS > 100 else 1
TEACHER_PATH = "./teacher.pth"
LR = 0.01
NUM_WORKERS = 1

cuda


In [3]:
def get_acc(net, loader):
    with torch.no_grad():
        total = 0
        correct = 0
        for data in loader:
            images, labels = data
            
            images = images.to(DEVICE)
            labels = labels.to(DEVICE)

            outputs = net(images)

            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return 100 * correct / total

In [4]:
transform = transforms.Compose([transforms.ToTensor(), 
                                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

trainloader = torch.utils.data.DataLoader(trainset, 
                                          batch_size=BATCH_SIZE, 
                                          shuffle=True, 
                                          num_workers=NUM_WORKERS)
testloader = torch.utils.data.DataLoader(testset, 
                                         batch_size=BATCH_SIZE, 
                                         shuffle=False, 
                                         num_workers=NUM_WORKERS)

classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
print(len(classes))

Files already downloaded and verified
Files already downloaded and verified
10


In [5]:
teacher = models.vgg16(pretrained=True)
teacher.classifier[6] = nn.Linear(4096,10)

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(teacher.parameters(), lr=LR, momentum=0.9)
# optimizer = optim.Adam(net.parameters(), lr=LR)

In [6]:

if not path.exists(TEACHER_PATH):
    t = tqdm(range(NUM_EPOCHS))
    teacher.to(DEVICE)
    for epoch in t:
        running_loss = 0.0
        for i, data in enumerate(trainloader):
            inputs, labels = data
            inputs = inputs.to(DEVICE)
            labels = labels.to(DEVICE)
            
            optimizer.zero_grad()

            outputs = teacher(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        if (epoch + 1) % PRINT_EVERY == 0:
            print(f'[{epoch + 1}] loss: {running_loss / len(trainloader)}')
            running_loss = 0.0
            
            correct = 0
            total = 0

            acc = get_acc(teacher, testloader)
            t.set_description(f'Acc: {acc} %')

            # with torch.no_grad():
            #     for data in testloader:
            #         images, labels = data
                    
            #         images = images.to(DEVICE)
            #         labels = labels.to(DEVICE)

            #         outputs = teacher(images)

            #         _, predicted = torch.max(outputs.data, 1)
            #         total += labels.size(0)
            #         correct += (predicted == labels).sum().item()
            # t.set_description(f'Acc: {100 * correct / total} %')

    print("Finished Training")
    torch.save(teacher.state_dict(), TEACHER_PATH)
else:
    print("Loaded saved teacher model")
    teacher.load_state_dict(torch.load(TEACHER_PATH))
    teacher.to(DEVICE)

Loaded saved teacher model


In [7]:
TEACHER_NUM_PARAMS = sum(p.numel() for p in teacher.parameters())
print(TEACHER_NUM_PARAMS)

134301514


In [8]:
TEACHER_ACC = get_acc(teacher, testloader)
print(f"Accuracy: {TEACHER_ACC} %")

Accuracy: 88.21 %


In [9]:
correct_pred = {classname: 0 for classname in classes}
total_pred = {classname: 0 for classname in classes}

with torch.no_grad():
    for data in testloader:
        images, labels = data    
        outputs = teacher(images.to(DEVICE))    
        _, predictions = torch.max(outputs, 1)
        for label, prediction in zip(labels, predictions.cpu()):
            if label == prediction:
                correct_pred[classes[label]] += 1
            total_pred[classes[label]] += 1
  
for classname, correct_count in correct_pred.items():
    accuracy = 100 * float(correct_count) / total_pred[classname]
    print(f"Accuracy for class {classname} is: {accuracy}")

Accuracy for class plane is: 87.5
Accuracy for class car is: 95.6
Accuracy for class bird is: 83.3
Accuracy for class cat is: 74.2
Accuracy for class deer is: 86.4
Accuracy for class dog is: 86.8
Accuracy for class frog is: 92.6
Accuracy for class horse is: 87.9
Accuracy for class ship is: 96.2
Accuracy for class truck is: 92.2


In [10]:
class dVGG(nn.Module):
    def __init__(self, a=0, kind=1):
        super().__init__()
        self.one = nn.Sequential(
            nn.Conv2d(3, 6, 5),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(6, 16, 5),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Flatten(),
            nn.Linear(16 * 5 * 5, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 10)
        )

        self.two = nn.Sequential(
            nn.Conv2d(3, 6, 5),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(6, 16, 5),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(16 * 10 * 10, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 10)
        )
        
        self.criterion_mse = torch.nn.MSELoss()
        self.criterion_ce = torch.nn.CrossEntropyLoss()
        self.a = a
        self.kind = kind

    def forward(self, x):
        if self.kind == 1:
            out = self.one(x)
        elif self.kind == 2:
            out = self.two(x)
        else:
            raise ValueError("Unexpected `kind`")
        return out
    
    def loss(self, output, teacher_prob, real_label):
        return self.a * self.criterion_ce(output, real_label) + (1 - self.a) * self.criterion_mse(output, teacher_prob)


In [11]:
def distil(a=0, kind=1, opt="sgd", coef=1):
    print(f"=== {a} | {kind} | {opt} ===")
    t = tqdm(range(NUM_EPOCHS * coef))
    net = dVGG(a, kind).to(DEVICE)
    if opt.lower() == "sgd":
        optimizer = optim.SGD(net.parameters(), lr=LR, momentum=0.9)
    if opt.lower() == "adam":
        optimizer = optim.Adam(net.parameters(), lr=LR)

    for epoch in t:
        running_loss = 0.0
        for i, data in enumerate(trainloader):
            inputs, labels = data
            
            inputs = inputs.to(DEVICE)
            labels = labels.to(DEVICE)

            optimizer.zero_grad()

            outputs_teacher = teacher(inputs)
            outputs = net(inputs)
            loss = net.loss(outputs, outputs_teacher, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
        if (epoch + 1) % PRINT_EVERY == 0:
            print(f'[{epoch + 1}] loss: {running_loss / len(trainloader)}')
            running_loss = 0.0
            
            acc = get_acc(net, testloader)
            t.set_description(f'Acc: {acc} %')

    learner_acc = get_acc(net, testloader)
    torch.save(net.state_dict(), f"./distilled_{a}_{kind}_{opt}.pth")

    learner_num_params = sum(p.numel() for p in net.parameters())
    print(f"=== Finished: {a} | {kind} | {opt} | {learner_acc} ===")
    print("\tTotal number of teacher params:", TEACHER_NUM_PARAMS)
    print("\tTotal number of learner params:", learner_num_params)
    print("\tTotal reduction:", TEACHER_NUM_PARAMS / learner_num_params)
    print("\tTeacher accuracy on test:", TEACHER_ACC)
    print("\tLearner accuracy on test:", learner_acc)
    print("\tDiff:", TEACHER_ACC - learner_acc)
    print()


In [11]:
for a in (0, 0.1, 0.5, 0.7, 0.9):
    for kind in (1, 2):
        for opt in ("sgd", "adam"):
            distil(a, kind, opt)

=== 0 | 1 | sgd ===


  0%|          | 0/20 [00:00<?, ?it/s]

[1] loss: 37.32420335497175
[2] loss: 27.914160806305553
[3] loss: 24.34347592567911
[4] loss: 21.804948787299953
[5] loss: 20.056565868611237
[6] loss: 18.261655924271565
[7] loss: 17.34165014539446
[8] loss: 16.665676846796153
[9] loss: 15.721143732265551
[10] loss: 15.170251982552665
[11] loss: 14.602479311884666
[12] loss: 13.97068081096727
[13] loss: 13.727932647782929
[14] loss: 13.429940554560448
[15] loss: 12.95250855659952
[16] loss: 12.631445943092814
[17] loss: 12.263682404342962
[18] loss: 11.739936955121099
[19] loss: 11.826859104390048
[20] loss: 11.636767990735112
=== Finished: 0 | 1 | sgd | 64.86 ===
	Total number of teacher params: 134301514
	Total number of learner params: 279812
	Total reduction: 0.20834612482477302
	Teacher accuracy on test: 88.04
	Learner accuracy on test: 64.86
	Diff: 23.180000000000007
=== 0 | 1 | adam ===


  0%|          | 0/20 [00:00<?, ?it/s]

[1] loss: 34.961723969907176
[2] loss: 27.226333793328735
