In [1]:
# VGG-19: 85.48%
# VGG-16: 85.71%
# VGG-13: 85.52%
# VGG-11: 82.70%
# DistilVGG: 79.68%

In [2]:
# import argparse
# parser = argparse.ArgumentParser()
# parser.add_argument("model", help="VGG Model to train",
#                     type=str)
# parser.add_argument("device", help="Device to train on",
#                     type=str)
# args = parser.parse_args()

import os
if os.getcwd().split('/')[-1] == "notebooks":
    os.chdir('..')

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim
from src.train import train
from src.vgg import VGG
torch.manual_seed(0)
batch_size = 128

transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

trainset = torchvision.datasets.CIFAR10(root='./data/', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                          shuffle=False, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data/', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                         shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

# vgg_model = args.model
for vgg_model in ["DistilVGG", "VGG13"]:
    model=VGG(vgg_model)
    if torch.cuda.is_available():
        model.cuda()
    #     model.to('cuda:{}'.format(args.device))
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.01)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [25, 70, 105], gamma=0.1)
    train(model, trainloader, testloader, optimizer, criterion, 120, writer=None, scheduler=scheduler)
    torch.save(model.state_dict(), "./models/{}.pt".format(vgg_model.lower()))

Files already downloaded and verified
Files already downloaded and verified


HBox(children=(FloatProgress(value=0.0, max=120.0), HTML(value='')))

Epoch 1 accuracy = 58.92%
Epoch 2 accuracy = 70.06%
Epoch 3 accuracy = 70.69%
Epoch 4 accuracy = 71.48%
Epoch 5 accuracy = 71.87%
Epoch 6 accuracy = 70.78%
Epoch 7 accuracy = 70.82%
Epoch 8 accuracy = 71.99%
Epoch 9 accuracy = 74.12%
Epoch 10 accuracy = 74.08%
Epoch 11 accuracy = 72.63%
Epoch 12 accuracy = 75.25%
Epoch 13 accuracy = 75.11%
Epoch 14 accuracy = 73.93%
Epoch 15 accuracy = 74.70%
Epoch 16 accuracy = 75.86%
Epoch 17 accuracy = 75.09%
Epoch 18 accuracy = 74.83%
Epoch 19 accuracy = 75.46%
Epoch 20 accuracy = 75.56%
Epoch 21 accuracy = 76.75%
Epoch 22 accuracy = 75.87%
Epoch 23 accuracy = 76.53%
Epoch 24 accuracy = 75.60%
Epoch 25 accuracy = 76.74%
Epoch 26 accuracy = 79.13%
Epoch 27 accuracy = 79.18%
Epoch 28 accuracy = 79.20%
Epoch 29 accuracy = 79.19%
Epoch 30 accuracy = 79.26%
Epoch 31 accuracy = 79.27%
Epoch 32 accuracy = 79.26%
Epoch 33 accuracy = 79.36%
Epoch 34 accuracy = 79.49%
Epoch 35 accuracy = 79.48%
Epoch 36 accuracy = 79.55%
Epoch 37 accuracy = 79.54%
Epoch 38 a

HBox(children=(FloatProgress(value=0.0, max=120.0), HTML(value='')))

Epoch 1 accuracy = 38.60%
Epoch 2 accuracy = 53.34%
Epoch 3 accuracy = 64.56%
Epoch 4 accuracy = 74.03%
Epoch 5 accuracy = 73.92%
Epoch 6 accuracy = 76.80%
Epoch 7 accuracy = 77.47%
Epoch 8 accuracy = 77.42%
Epoch 9 accuracy = 76.69%
Epoch 10 accuracy = 80.59%
Epoch 11 accuracy = 80.74%
Epoch 12 accuracy = 78.49%
Epoch 13 accuracy = 80.62%
Epoch 14 accuracy = 81.20%
Epoch 15 accuracy = 81.45%
Epoch 16 accuracy = 81.59%
Epoch 17 accuracy = 82.20%
Epoch 18 accuracy = 82.42%
Epoch 19 accuracy = 82.36%
Epoch 20 accuracy = 82.27%
Epoch 21 accuracy = 82.22%
Epoch 22 accuracy = 80.91%
Epoch 23 accuracy = 82.71%
Epoch 24 accuracy = 83.48%
Epoch 25 accuracy = 82.00%
Epoch 26 accuracy = 85.15%
Epoch 27 accuracy = 85.30%
Epoch 28 accuracy = 85.35%
Epoch 29 accuracy = 85.36%
Epoch 30 accuracy = 85.45%
Epoch 31 accuracy = 85.51%
Epoch 32 accuracy = 85.48%
Epoch 33 accuracy = 85.47%
Epoch 34 accuracy = 85.56%
Epoch 35 accuracy = 85.56%
Epoch 36 accuracy = 85.57%
Epoch 37 accuracy = 85.59%
Epoch 38 a