### Results
- VGG-19: 85.48%
- VGG-16: 85.71%
- VGG-13: 85.52%
- VGG-11: 82.70%
- DistilVGG: 79.68%

Distill into DistillVGG T=5, alpha=0.8
- VGG-13 -> 81.09%
- VGG-16 -> 80.50%
- VGG-16 -> VGG-11 -> 80.25%
- VGG-19 -> VGG-16 -> VGG-13 -> VGG-11 -> 79.47%

T6, alpha = .8
- VGG-19 -> VGG-16 -> VGG-13 -> VGG-11 -> DistillVGG
- 85.48 (Original), 85.38, 86.01, 83.41, 79.49

T10, alpha = .8
- VGG-16 -> VGG-13 -> VGG-11 -> DistillVGG
- 85.71, 85.23, 83.11, 80.14

T15, alpha = .8
- VGG-16 -> VGG-13 -> VGG-11 -> DistillVGG
- 85.71, 85.73, 83.54, 80.27

T15, alpha = .8
- VGG-16 -> VGG-13 -> VGG-11 -> DistillVGG
- 85.71 -> 85.30, 83.10, 80.46

VGG-13 -> DistilVGG
- T=5: 81.09
- T=10: 81.03
- T=15: 81.41
- T=20: 80.94

In [9]:
pytorch_total_params = lambda model: sum(p.numel() for p in model.parameters())

In [12]:
VGG('VGG13')

VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace)
    (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace)
    (6): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (7): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (9): ReLU(inplace)
    (10): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (12): ReLU(inplace)
    (13): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (14): Conv2d(128, 256, kernel_size=(3, 3)

In [13]:
VGG('DistilVGG')

VGG(
  (features): Sequential(
    (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace)
    (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (4): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (5): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): ReLU(inplace)
    (7): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (8): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): ReLU(inplace)
    (11): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (12): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_runn

In [10]:
pytorch_total_params(VGG('VGG13'))

9416010

In [11]:
pytorch_total_params(VGG('DistilVGG'))

1575690

In [1]:
import os
if os.getcwd().split('/')[-1] == "notebooks":
    os.chdir('..')

In [2]:
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch
import torchvision
import torchvision.transforms as transforms
from src.train import train
from src.kd import extract_logits, kd_ce_loss
from src.vgg import VGG
torch.manual_seed(0)

data = torch.load("./data/cifar10_training_data.pt")
labels = torch.load("./data/cifar10_training_labels.pt")

batch_size = 128

transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

trainset = torchvision.datasets.CIFAR10(root='./data/', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                          shuffle=False, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data/', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                         shuffle=False, num_workers=2)

Files already downloaded and verified
Files already downloaded and verified


### Compress VGG13 into DistilVGG

In [3]:
torch.manual_seed(0)
teacher = VGG('VGG13')
teacher.load_state_dict(torch.load("./models/vgg13.pt"))
student = VGG('DistilVGG')
if torch.cuda.is_available():
    teacher.cuda()
    student.cuda()

logits = extract_logits(teacher, trainloader).cpu()
kdtrain = torch.utils.data.TensorDataset(data, labels, logits)
kdloader = torch.utils.data.DataLoader(kdtrain, batch_size=batch_size, shuffle=False, num_workers=2)

criterion = kd_ce_loss
optimizer = optim.Adam(student.parameters(), lr=0.01)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [25, 70, 105], gamma=0.1)
train(student, kdloader, testloader, optimizer, criterion, 120, writer=None, scheduler=scheduler)
torch.save(student.state_dict(), "./models/distilvgg_vgg13d.pt")
del teacher, kdtrain, kdloader

HBox(children=(FloatProgress(value=0.0, max=120.0), HTML(value='')))

Epoch 1 accuracy = 60.92%
Epoch 2 accuracy = 69.21%
Epoch 3 accuracy = 71.25%
Epoch 4 accuracy = 70.79%
Epoch 5 accuracy = 71.79%
Epoch 6 accuracy = 70.03%
Epoch 7 accuracy = 73.29%
Epoch 8 accuracy = 74.34%
Epoch 9 accuracy = 74.77%
Epoch 10 accuracy = 76.44%
Epoch 11 accuracy = 76.32%
Epoch 12 accuracy = 76.17%
Epoch 13 accuracy = 76.20%
Epoch 14 accuracy = 77.83%
Epoch 15 accuracy = 77.35%
Epoch 16 accuracy = 78.51%
Epoch 17 accuracy = 77.79%
Epoch 18 accuracy = 77.59%
Epoch 19 accuracy = 76.71%
Epoch 20 accuracy = 76.93%
Epoch 21 accuracy = 77.15%
Epoch 22 accuracy = 78.06%
Epoch 23 accuracy = 78.30%
Epoch 24 accuracy = 79.12%
Epoch 25 accuracy = 80.26%
Epoch 26 accuracy = 80.98%
Epoch 27 accuracy = 81.07%
Epoch 28 accuracy = 80.94%
Epoch 29 accuracy = 81.01%
Epoch 30 accuracy = 81.02%
Epoch 31 accuracy = 81.13%
Epoch 32 accuracy = 81.21%
Epoch 33 accuracy = 81.27%
Epoch 34 accuracy = 81.36%
Epoch 35 accuracy = 81.39%
Epoch 36 accuracy = 81.29%
Epoch 37 accuracy = 81.31%
Epoch 38 a

In [30]:
torch.manual_seed(0)
T = 13
teacher = VGG('VGG13')
teacher.load_state_dict(torch.load("./models/vgg13.pt"))
student = VGG('DistilVGG')
if torch.cuda.is_available():
    teacher.cuda()
    student.cuda()

logits = extract_logits(teacher, trainloader).cpu()
kdtrain = torch.utils.data.TensorDataset(data, labels, logits)
kdloader = torch.utils.data.DataLoader(kdtrain, batch_size=batch_size, shuffle=False, num_workers=2)

criterion = kd_ce_loss(temperature = T, alpha = 0.8)
optimizer = optim.Adam(student.parameters(), lr=0.01)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [25, 70], gamma=0.1)
train(student, kdloader, testloader, optimizer, criterion, 75, writer=None, scheduler=scheduler)
torch.save(student.state_dict(), "./models/distilvgg_vgg13d_t{}.pt".format(T))
del teacher, kdtrain, kdloader

HBox(children=(FloatProgress(value=0.0, max=75.0), HTML(value='')))

Epoch 1 accuracy = 56.07%
Epoch 2 accuracy = 67.58%
Epoch 3 accuracy = 66.16%
Epoch 4 accuracy = 72.54%
Epoch 5 accuracy = 73.43%
Epoch 6 accuracy = 74.20%
Epoch 7 accuracy = 76.43%
Epoch 8 accuracy = 76.08%
Epoch 9 accuracy = 76.01%
Epoch 10 accuracy = 75.80%
Epoch 11 accuracy = 77.30%
Epoch 12 accuracy = 77.84%
Epoch 13 accuracy = 77.96%
Epoch 14 accuracy = 78.74%
Epoch 15 accuracy = 79.32%
Epoch 16 accuracy = 79.60%
Epoch 17 accuracy = 80.29%
Epoch 18 accuracy = 80.48%
Epoch 19 accuracy = 80.29%
Epoch 20 accuracy = 80.55%
Epoch 21 accuracy = 80.21%
Epoch 22 accuracy = 80.24%
Epoch 23 accuracy = 80.18%
Epoch 24 accuracy = 80.12%
Epoch 25 accuracy = 79.78%
Epoch 26 accuracy = 81.33%
Epoch 27 accuracy = 81.30%
Epoch 28 accuracy = 81.38%
Epoch 29 accuracy = 81.40%
Epoch 30 accuracy = 81.46%
Epoch 31 accuracy = 81.42%
Epoch 32 accuracy = 81.41%
Epoch 33 accuracy = 81.37%
Epoch 34 accuracy = 81.39%
Epoch 35 accuracy = 81.37%
Epoch 36 accuracy = 81.43%
Epoch 37 accuracy = 81.40%
Epoch 38 a

### Compress VGG13 into DistilVGG T=17

In [31]:
torch.manual_seed(0)
T = 17
teacher = VGG('VGG13')
teacher.load_state_dict(torch.load("./models/vgg13.pt"))
student = VGG('DistilVGG')
if torch.cuda.is_available():
    teacher.cuda()
    student.cuda()

logits = extract_logits(teacher, trainloader).cpu()
kdtrain = torch.utils.data.TensorDataset(data, labels, logits)
kdloader = torch.utils.data.DataLoader(kdtrain, batch_size=batch_size, shuffle=False, num_workers=2)

criterion = kd_ce_loss(temperature = T, alpha = 0.8)
optimizer = optim.Adam(student.parameters(), lr=0.01)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [25, 70], gamma=0.1)
train(student, kdloader, testloader, optimizer, criterion, 75, writer=None, scheduler=scheduler)
torch.save(student.state_dict(), "./models/distilvgg_vgg13d_t{}.pt".format(T))
del teacher, kdtrain, kdloader

HBox(children=(FloatProgress(value=0.0, max=75.0), HTML(value='')))

Epoch 1 accuracy = 57.73%
Epoch 2 accuracy = 63.05%
Epoch 3 accuracy = 69.56%
Epoch 4 accuracy = 73.32%
Epoch 5 accuracy = 73.68%
Epoch 6 accuracy = 75.74%
Epoch 7 accuracy = 76.31%
Epoch 8 accuracy = 75.88%
Epoch 9 accuracy = 75.45%
Epoch 10 accuracy = 76.67%
Epoch 11 accuracy = 75.61%
Epoch 12 accuracy = 75.56%
Epoch 13 accuracy = 76.87%
Epoch 14 accuracy = 76.04%
Epoch 15 accuracy = 77.34%
Epoch 16 accuracy = 77.46%
Epoch 17 accuracy = 79.00%
Epoch 18 accuracy = 78.93%
Epoch 19 accuracy = 78.52%
Epoch 20 accuracy = 78.59%
Epoch 21 accuracy = 78.51%
Epoch 22 accuracy = 78.26%
Epoch 23 accuracy = 78.06%
Epoch 24 accuracy = 79.27%
Epoch 25 accuracy = 80.23%
Epoch 26 accuracy = 81.16%
Epoch 27 accuracy = 81.38%
Epoch 28 accuracy = 81.43%
Epoch 29 accuracy = 81.55%
Epoch 30 accuracy = 81.65%
Epoch 31 accuracy = 81.74%
Epoch 32 accuracy = 81.70%
Epoch 33 accuracy = 81.73%
Epoch 34 accuracy = 81.78%
Epoch 35 accuracy = 81.77%
Epoch 36 accuracy = 81.79%
Epoch 37 accuracy = 81.80%
Epoch 38 a

### Compress VGG16 into VGG13 into VGG11 into DistilVGG

In [22]:
def distill_models(teacher_model, student_model, teacher_dir, student_dir, temperature):
    torch.manual_seed(0)
    teacher = VGG(teacher_model)
    teacher.load_state_dict(torch.load("./models/{}.pt".format(teacher_dir)))
    student = VGG(student_model)
    if torch.cuda.is_available():
        teacher.cuda()
        student.cuda()

    logits = extract_logits(teacher, trainloader).cpu()
    kdtrain = torch.utils.data.TensorDataset(data, labels, logits)
    kdloader = torch.utils.data.DataLoader(kdtrain, batch_size=batch_size, shuffle=False, num_workers=2)

    criterion = kd_ce_loss(temperature = T, alpha = 0.8)
    optimizer = optim.Adam(student.parameters(), lr=0.01)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [25, 70], gamma=0.1)
    train(student, kdloader, testloader, optimizer, criterion, 75, writer=None, scheduler=scheduler)
    torch.save(student.state_dict(), "./models/{}.pt".format(student_dir))
    del teacher, kdtrain, kdloader
    return student

T = 20

In [23]:
student = distill_models("VGG16", "VGG13", "vgg16", "vgg13_vgg16d_t{}".format(T), T)

HBox(children=(FloatProgress(value=0.0, max=75.0), HTML(value='')))

Epoch 1 accuracy = 35.77%
Epoch 2 accuracy = 56.03%
Epoch 3 accuracy = 63.57%
Epoch 4 accuracy = 67.64%
Epoch 5 accuracy = 70.22%
Epoch 6 accuracy = 71.03%
Epoch 7 accuracy = 76.10%
Epoch 8 accuracy = 76.54%
Epoch 9 accuracy = 79.04%
Epoch 10 accuracy = 77.40%
Epoch 11 accuracy = 78.66%
Epoch 12 accuracy = 81.06%
Epoch 13 accuracy = 81.81%
Epoch 14 accuracy = 81.50%
Epoch 15 accuracy = 80.50%
Epoch 16 accuracy = 82.16%
Epoch 17 accuracy = 82.67%
Epoch 18 accuracy = 81.10%
Epoch 19 accuracy = 81.68%
Epoch 20 accuracy = 81.40%
Epoch 21 accuracy = 82.71%
Epoch 22 accuracy = 81.73%
Epoch 23 accuracy = 80.48%
Epoch 24 accuracy = 83.06%
Epoch 25 accuracy = 83.31%
Epoch 26 accuracy = 84.83%
Epoch 27 accuracy = 84.94%
Epoch 28 accuracy = 85.03%
Epoch 29 accuracy = 85.19%
Epoch 30 accuracy = 85.31%
Epoch 31 accuracy = 85.39%
Epoch 32 accuracy = 85.45%
Epoch 33 accuracy = 85.48%
Epoch 34 accuracy = 85.48%
Epoch 35 accuracy = 85.49%
Epoch 36 accuracy = 85.48%
Epoch 37 accuracy = 85.56%
Epoch 38 a

In [24]:
student = distill_models("VGG13", "VGG11", "vgg13_vgg16d_t{}".format(T), "vgg11_vgg13_vgg16d_t{}".format(T), T)

HBox(children=(FloatProgress(value=0.0, max=75.0), HTML(value='')))

Epoch 1 accuracy = 43.94%
Epoch 2 accuracy = 58.39%
Epoch 3 accuracy = 65.74%
Epoch 4 accuracy = 68.89%
Epoch 5 accuracy = 71.39%
Epoch 6 accuracy = 70.19%
Epoch 7 accuracy = 75.72%
Epoch 8 accuracy = 70.01%
Epoch 9 accuracy = 75.26%
Epoch 10 accuracy = 77.73%
Epoch 11 accuracy = 74.14%
Epoch 12 accuracy = 77.09%
Epoch 13 accuracy = 79.22%
Epoch 14 accuracy = 77.84%
Epoch 15 accuracy = 78.76%
Epoch 16 accuracy = 78.93%
Epoch 17 accuracy = 79.21%
Epoch 18 accuracy = 79.83%
Epoch 19 accuracy = 78.33%
Epoch 20 accuracy = 79.52%
Epoch 21 accuracy = 79.87%
Epoch 22 accuracy = 77.69%
Epoch 23 accuracy = 78.86%
Epoch 24 accuracy = 80.35%
Epoch 25 accuracy = 80.27%
Epoch 26 accuracy = 81.94%
Epoch 27 accuracy = 82.19%
Epoch 28 accuracy = 82.30%
Epoch 29 accuracy = 82.37%
Epoch 30 accuracy = 82.43%
Epoch 31 accuracy = 82.53%
Epoch 32 accuracy = 82.59%
Epoch 33 accuracy = 82.67%
Epoch 34 accuracy = 82.69%
Epoch 35 accuracy = 82.79%
Epoch 36 accuracy = 82.87%
Epoch 37 accuracy = 83.12%
Epoch 38 a

In [25]:
student = distill_models("VGG11", "DistilVGG", "vgg11_vgg13_vgg16d_t{}".format(T), "distillvgg_vgg11_vgg13_vgg16d_t{}".format(T), T)

HBox(children=(FloatProgress(value=0.0, max=75.0), HTML(value='')))

Epoch 1 accuracy = 59.01%
Epoch 2 accuracy = 65.24%
Epoch 3 accuracy = 68.02%
Epoch 4 accuracy = 72.48%
Epoch 5 accuracy = 67.12%
Epoch 6 accuracy = 74.01%
Epoch 7 accuracy = 73.25%
Epoch 8 accuracy = 72.41%
Epoch 9 accuracy = 73.50%
Epoch 10 accuracy = 77.66%
Epoch 11 accuracy = 77.15%
Epoch 12 accuracy = 75.55%
Epoch 13 accuracy = 77.62%
Epoch 14 accuracy = 77.34%
Epoch 15 accuracy = 77.43%
Epoch 16 accuracy = 78.16%
Epoch 17 accuracy = 77.89%
Epoch 18 accuracy = 77.67%
Epoch 19 accuracy = 76.18%
Epoch 20 accuracy = 76.86%
Epoch 21 accuracy = 77.20%
Epoch 22 accuracy = 78.01%
Epoch 23 accuracy = 77.42%
Epoch 24 accuracy = 78.29%
Epoch 25 accuracy = 77.75%
Epoch 26 accuracy = 80.11%
Epoch 27 accuracy = 80.15%
Epoch 28 accuracy = 80.23%
Epoch 29 accuracy = 80.30%
Epoch 30 accuracy = 80.35%
Epoch 31 accuracy = 80.35%
Epoch 32 accuracy = 80.47%
Epoch 33 accuracy = 80.54%
Epoch 34 accuracy = 80.54%
Epoch 35 accuracy = 80.63%
Epoch 36 accuracy = 80.62%
Epoch 37 accuracy = 80.69%
Epoch 38 a