In [1]:
import torch
from torch import optim
from torch import nn
from torch.nn import functional as F
import torchvision
from torchvision import transforms
from torchvision import models
import numpy as np
from tqdm.notebook import tqdm
import matplotlib
from matplotlib import pyplot as plt
%matplotlib inline
from os import path

In [2]:
# Run configuration: everything a reader may want to tune lives in this cell.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(DEVICE)

BATCH_SIZE = 512
NUM_EPOCHS = 100
# Log/evaluate every epoch for short runs; for runs longer than 100 epochs, every NUM_EPOCHS // 100 epochs.
PRINT_EVERY = NUM_EPOCHS // 100 if NUM_EPOCHS > 100 else 1
TEACHER_PATH = "./teacher.pth"  # cached teacher weights; delete the file to retrain the teacher
LR = 0.01
NUM_WORKERS = 1

# Seed torch's RNG on all devices so weight initialisation and DataLoader shuffling
# repeat across runs (cuDNN kernels may still introduce small nondeterminism).
SEED = 0
torch.manual_seed(SEED)

cuda


In [3]:
def get_acc(net, loader, device=None):
    """Return the top-1 accuracy of ``net`` on ``loader``, in percent.

    The network runs in eval mode with autograd disabled. Its previous
    train/eval mode is restored afterwards (even if an exception is raised),
    so evaluating a frozen model does not silently switch it to training mode.

    Args:
        net: module whose output holds class scores along dim 1.
        loader: iterable of ``(images, labels)`` batches.
        device: device each batch is moved to; defaults to the global ``DEVICE``.

    Returns:
        ``100 * correct / total`` as a float.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    if device is None:
        device = DEVICE

    was_training = net.training
    net.eval()
    total = 0
    correct = 0
    try:
        with torch.no_grad():
            for images, labels in loader:
                images = images.to(device)
                labels = labels.to(device)
                # argmax returns the first maximal index on ties, same as torch.max(..., 1).
                predicted = net(images).argmax(dim=1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
    finally:
        net.train(was_training)

    if total == 0:
        raise ValueError("get_acc: loader yielded no samples")
    return 100 * correct / total

In [4]:
# Same per-channel normalisation statistics as the ImageNet-pretrained teacher expects.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Train split first, then test split.
trainset, testset = (
    torchvision.datasets.CIFAR10(root='./data', train=is_train, download=True, transform=transform)
    for is_train in (True, False)
)

# Both loaders share batch size and worker count; only shuffling differs.
loader_kwargs = dict(batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)
trainloader = torch.utils.data.DataLoader(trainset, shuffle=True, **loader_kwargs)
testloader = torch.utils.data.DataLoader(testset, shuffle=False, **loader_kwargs)

# CIFAR-10 class names, indexed by label id.
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
print(len(classes))

Files already downloaded and verified
Files already downloaded and verified
10


In [5]:
# Teacher: VGG16 with batch norm, ImageNet-pretrained and fine-tuned end to end
# (the convolutional features are not frozen) on CIFAR-10.
teacher = models.vgg16_bn(pretrained=True)

# Replace the final classifier layer (4096 inputs) with a 10-way CIFAR-10 head.
teacher.classifier[6] = nn.Linear(in_features=4096, out_features=10)

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(params=teacher.parameters(), lr=LR, momentum=0.9)

In [6]:
# Fine-tune the teacher once and cache its weights at TEACHER_PATH; later runs reload them.
if not path.exists(TEACHER_PATH):
    teacher.to(DEVICE)
    for epoch in tqdm(range(NUM_EPOCHS)):
        running_loss = 0.0
        for inputs, labels in trainloader:
            inputs = inputs.to(DEVICE)
            labels = labels.to(DEVICE)

            optimizer.zero_grad()
            loss = criterion(teacher(inputs), labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        if (epoch + 1) % PRINT_EVERY == 0:
            acc = get_acc(teacher, testloader)
            print(f'[{epoch + 1}] loss: {running_loss / len(trainloader):0.9f} | accuracy: {acc:0.2f}%')

    print("Finished Training")
    torch.save(teacher.state_dict(), TEACHER_PATH)
else:
    print("Loaded saved teacher model")
    # map_location lets a checkpoint written on a GPU machine load on a CPU-only one (and vice versa).
    teacher.load_state_dict(torch.load(TEACHER_PATH, map_location=DEVICE))
    teacher.to(DEVICE)

  0%|          | 0/100 [00:00<?, ?it/s]

[1] loss: 7.151553208 | accuracy: 28.00%
[2] loss: 2.321216430 | accuracy: 29.84%


In [None]:
# Total number of scalar parameters in the teacher (the reference size for the student reduction report).
TEACHER_NUM_PARAMS = sum(map(torch.numel, teacher.parameters()))
print(TEACHER_NUM_PARAMS)

In [None]:
# Overall teacher accuracy on the test split, in percent.
TEACHER_ACC = get_acc(teacher, testloader)
print("Accuracy: {} %".format(TEACHER_ACC))

In [None]:
# Per-class test accuracy of the teacher.
correct_pred = {classname: 0 for classname in classes}
total_pred = {classname: 0 for classname in classes}

# The teacher may still be in training mode here (e.g. right after the training loop),
# which would keep dropout active and use per-batch BatchNorm statistics. Evaluate in
# eval mode, then restore whatever mode it was in.
was_training = teacher.training
teacher.eval()
try:
    with torch.no_grad():
        for images, labels in testloader:
            predictions = teacher(images.to(DEVICE)).argmax(dim=1).cpu()
            for label, prediction in zip(labels.tolist(), predictions.tolist()):
                classname = classes[label]
                total_pred[classname] += 1
                if label == prediction:
                    correct_pred[classname] += 1
finally:
    teacher.train(was_training)

for classname, correct_count in correct_pred.items():
    class_total = total_pred[classname]
    # Guard against a class that never appears in the loader.
    accuracy = 100 * float(correct_count) / class_total if class_total else float("nan")
    print(f"Accuracy for class {classname} is: {accuracy}")

In [None]:
class dVGG(nn.Module):
    """Small student CNN for CIFAR-10, trained directly or by distillation.

    ``kind`` selects one of three architectures (flattened sizes assume 3x32x32 inputs):
      1 -- two 5x5 convolutions, each followed by 2x2 max-pooling, then a 3-layer MLP.
      2 -- like 1 but without the second pooling, so the first linear layer is wider.
      3 -- six 3x3 convolutions with dropout and batch normalisation, then a 2-layer MLP.

    Only the selected architecture is instantiated, so ``parameters()``,
    ``state_dict()`` and parameter counts describe the model that is actually used.
    The branch keeps its attribute name (``one``/``two``/``three``), so state_dict keys
    keep the same prefix. NOTE: checkpoints saved while all three branches were
    registered contain extra keys; load those with ``strict=False``.

    Args:
        a: weight of the hard-label cross-entropy term in :meth:`loss`;
            ``1 - a`` weights the MSE term against the teacher's outputs.
        kind: architecture to build (1, 2 or 3).

    Raises:
        ValueError: if ``kind`` is not 1, 2 or 3.
    """

    def __init__(self, a=0, kind=1):
        super().__init__()
        if kind == 1:
            self.one = self._build_one()
        elif kind == 2:
            self.two = self._build_two()
        elif kind == 3:
            self.three = self._build_three()
        else:
            raise ValueError("Unexpected `kind`")

        self.criterion_mse = torch.nn.MSELoss()
        self.criterion_ce = torch.nn.CrossEntropyLoss()
        self.a = a
        self.kind = kind

    @staticmethod
    def _build_one():
        """Kind 1: 32 -conv5-> 28 -pool-> 14 -conv5-> 10 -pool-> 5, so 16*5*5 features."""
        return nn.Sequential(
            nn.Conv2d(3, 6, 5),
            nn.ReLU(),

            nn.MaxPool2d(2, 2),

            nn.Conv2d(6, 16, 5),
            nn.ReLU(),

            nn.MaxPool2d(2, 2),

            nn.Flatten(),

            nn.Linear(16 * 5 * 5, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 10)
        )

    @staticmethod
    def _build_two():
        """Kind 2: 32 -conv5-> 28 -pool-> 14 -conv5-> 10, so 16*10*10 features."""
        return nn.Sequential(
            nn.Conv2d(3, 6, 5),
            nn.ReLU(),

            nn.MaxPool2d(2, 2),

            nn.Conv2d(6, 16, 5),
            nn.ReLU(),

            nn.Flatten(),

            nn.Linear(16 * 10 * 10, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 10)
        )

    @staticmethod
    def _build_three():
        """Kind 3: 3x3 convs 32->30->28, pool 14, ->12->10, pool 5, ->3->1, so 512 features."""
        return nn.Sequential(
            nn.Conv2d(3, 64, 3),
            nn.ReLU(),
            nn.Dropout(),

            nn.Conv2d(64, 128, 3),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Dropout(),
            nn.BatchNorm2d(128),

            nn.Conv2d(128, 256, 3),
            nn.ReLU(),
            nn.BatchNorm2d(256),
            nn.Conv2d(256, 256, 3),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Dropout(),
            nn.BatchNorm2d(256),

            nn.Conv2d(256, 512, 3),
            nn.ReLU(),
            nn.BatchNorm2d(512),
            nn.Conv2d(512, 512, 3),
            nn.ReLU(),
            nn.Dropout(),
            nn.BatchNorm2d(512),

            nn.Flatten(),

            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(),
            # No activation after the head: the outputs are logits that must be able to go
            # negative, both for cross-entropy and for matching the teacher's logits via MSE.
            nn.Linear(256, 10)
        )

    def forward(self, x):
        """Return unnormalised class scores of shape (batch, 10) for an image batch ``x``."""
        if self.kind == 1:
            return self.one(x)
        if self.kind == 2:
            return self.two(x)
        if self.kind == 3:
            return self.three(x)
        raise ValueError("Unexpected `kind`")

    def loss(self, output, teacher_prob, real_label):
        """Blend of hard-label cross-entropy and MSE to the teacher's outputs.

        Args:
            output: student logits.
            teacher_prob: teacher outputs of the same shape as ``output`` (in ``distil``
                these are the teacher's raw logits, not probabilities).
            real_label: ground-truth class indices.

        Returns:
            ``a * CE(output, real_label) + (1 - a) * MSE(output, teacher_prob)``.
        """
        return self.a * self.criterion_ce(output, real_label) + (1 - self.a) * self.criterion_mse(output, teacher_prob)


In [None]:
def train_baseline(kind=3, opt="sgd", coef=1):
    """Train a dVGG student on hard labels only (no teacher) as a reference point.

    Args:
        kind: student architecture passed to ``dVGG``.
        opt: optimizer name, ``"sgd"`` (momentum 0.9) or ``"adam"``; case-insensitive.
        coef: multiplier on ``NUM_EPOCHS``; the loop runs ``int(NUM_EPOCHS * coef)`` epochs.

    Side effects:
        Prints progress and saves the trained weights to ``./baseline_1_{kind}_{opt}.pth``.

    Raises:
        ValueError: if ``opt`` is not a supported optimizer name.
    """
    print(f"=== {kind} | {opt} ===")
    criterion = nn.CrossEntropyLoss()
    # a=1: the student is trained purely on ground-truth labels with `criterion` below.
    net = dVGG(1, kind).to(DEVICE)

    opt_name = opt.lower()
    if opt_name == "sgd":
        optimizer = optim.SGD(net.parameters(), lr=LR, momentum=0.9)
    elif opt_name == "adam":
        optimizer = optim.Adam(net.parameters(), lr=LR)
    else:
        raise ValueError(f"Unsupported optimizer: {opt!r}")

    for epoch in tqdm(range(int(NUM_EPOCHS * coef))):
        running_loss = 0.0
        for inputs, labels in trainloader:
            inputs = inputs.to(DEVICE)
            labels = labels.to(DEVICE)

            optimizer.zero_grad()
            loss = criterion(net(inputs), labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
        if (epoch + 1) % PRINT_EVERY == 0:
            acc = get_acc(net, testloader)
            print(f'[{epoch + 1}] loss: {running_loss / len(trainloader):0.5f} | accuracy: {acc:0.2f}%')

    baseline_acc = get_acc(net, testloader)
    # Name the checkpoint from the student's own `a` (always 1 here) rather than from a
    # free variable, which would depend on (or be missing from) the notebook's globals.
    torch.save(net.state_dict(), f"./baseline_{net.a}_{kind}_{opt}.pth")

    print(f"=== Finished: {kind} | {opt} ===")
    print("\tBaseline accuracy on test:", baseline_acc, "%")
    print()



In [None]:
def distil(a=0, kind=3, opt="sgd", coef=1):
    """Train a dVGG student to match the global ``teacher`` and report compression stats.

    Args:
        a: hard-label weight passed to ``dVGG``; the student loss is
            ``a * CE(labels) + (1 - a) * MSE(teacher logits)``.
        kind: student architecture passed to ``dVGG``.
        opt: optimizer name, ``"sgd"`` (momentum 0.9) or ``"adam"``; case-insensitive.
        coef: multiplier on ``NUM_EPOCHS``; the loop runs ``int(NUM_EPOCHS * coef)`` epochs.

    Side effects:
        Prints progress and a summary, and saves the student's weights to
        ``./distilled_{a}_{kind}_{opt}.pth``. The teacher's train/eval mode is restored
        on exit.

    Raises:
        ValueError: if ``opt`` is not a supported optimizer name.
    """
    print(f"=== {a} | {kind} | {opt} ===")
    net = dVGG(a, kind).to(DEVICE)

    opt_name = opt.lower()
    if opt_name == "sgd":
        optimizer = optim.SGD(net.parameters(), lr=LR, momentum=0.9)
    elif opt_name == "adam":
        optimizer = optim.Adam(net.parameters(), lr=LR)
    else:
        raise ValueError(f"Unsupported optimizer: {opt!r}")

    # The teacher is a fixed target. Eval mode stops its BatchNorm running statistics from
    # drifting and disables dropout noise in the targets; no_grad (below) keeps backward()
    # from computing gradients for the teacher's parameters.
    teacher_was_training = teacher.training
    teacher.eval()
    try:
        for epoch in tqdm(range(int(NUM_EPOCHS * coef))):
            running_loss = 0.0
            for inputs, labels in trainloader:
                inputs = inputs.to(DEVICE)
                labels = labels.to(DEVICE)

                optimizer.zero_grad()

                with torch.no_grad():
                    outputs_teacher = teacher(inputs)
                outputs = net(inputs)

                loss = net.loss(outputs, outputs_teacher, labels)
                loss.backward()
                optimizer.step()

                running_loss += loss.item()
            if (epoch + 1) % PRINT_EVERY == 0:
                acc = get_acc(net, testloader)
                print(f'[{epoch + 1}] loss: {running_loss / len(trainloader):0.5f} | accuracy: {acc:0.2f}%')
    finally:
        teacher.train(teacher_was_training)

    learner_acc = get_acc(net, testloader)
    torch.save(net.state_dict(), f"./distilled_{a}_{kind}_{opt}.pth")

    learner_num_params = sum(p.numel() for p in net.parameters())
    # Express the reduction as a percentage to match the "%" label.
    reduction_pct = 100 * (TEACHER_NUM_PARAMS - learner_num_params) / TEACHER_NUM_PARAMS
    print(f"=== Finished: {a} | {kind} | {opt} ===")
    print("\tTotal number of teacher params:", TEACHER_NUM_PARAMS)
    print("\tTotal number of learner params:", learner_num_params)
    print("\tTotal reduction:", reduction_pct, "%")
    print("\tTeacher accuracy on test:", TEACHER_ACC, "%")
    print("\tLearner accuracy on test:", learner_acc, "%")
    print("\tDiff:", TEACHER_ACC - learner_acc)
    print()

In [None]:
# Hard-label-only reference run: kind-3 student, SGD, full NUM_EPOCHS schedule.
train_baseline(kind=3, opt="sgd", coef=1)

In [None]:
# Distillation sweep on the kind-3 student: every hard-label weight `a`
# (0 = match the teacher only) is tried with both optimizers.
sweep_alphas = (0, 0.1, 0.5, 0.7, 0.9)
sweep_optimizers = ("sgd", "adam")
for a in sweep_alphas:
    for opt in sweep_optimizers:
        distil(a, kind=3, opt=opt, coef=1)