In [1]:
%matplotlib inline
%load_ext autoreload
%autoreload 2

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

import torch
from torch.nn import functional as F
from torchvision.utils import make_grid
from torch.optim.swa_utils import SWALR

from swadist.data.loader import get_dataloaders
from swadist.utils import Trainer, show_imgs, LinearPolyLR
from swadist.models.resnet import ResNet

plt.rcParams["savefig.bbox"] = 'tight'
torch.multiprocessing.set_sharing_strategy('file_system')

torch.cuda.device_count()

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'Using {device}')

# load CIFAR-10
train_loader, valid_loader = get_dataloaders(
    dataset="cifar10", num_workers=4, batch_size=16, cuda=device=='cuda', test=False
)

Using cuda
Files already downloaded and verified


### Two-phase training: SGD + SWA

We train ResNet-8, using the optimal hyperparameters given in [Shallue et al. 2019](http://arxiv.org/abs/1811.03600) for phase 1 (SGD w/ Nesterov momentum). 

For the SWA phase, we follow [Izmailov et al.](http://arxiv.org/abs/1803.05407) and use a constant learning rate that is 10x smaller than the initial learning rate in the first phase.

In [2]:
# initial lr, scaling factor, momentum, SWA lr
lr0, alpha, gamma = 2**-8.5, 0.25, 0.97
swa_lr = lr0 / 10

# epochs, lr scaling epochs
epochs, decay_epochs, epochs_swa = 10, 8, 5

# model
resnet8 = ResNet(in_kernel_size=3, stack_sizes=[1, 1, 1], n_classes=10, batch_norm=False)

optimizer = torch.optim.SGD(resnet8.parameters(), lr=lr0, momentum=gamma, nesterov=True)
scheduler = LinearPolyLR(optimizer, alpha=alpha, decay_epochs=decay_epochs, verbose=True)
swa_scheduler = SWALR(optimizer, anneal_strategy='linear', anneal_epochs=0, swa_lr=swa_lr)

trainer = Trainer(resnet8, train_loader, valid_loader, F.cross_entropy, optimizer, scheduler=scheduler, swa_scheduler=swa_scheduler, 
                  device=device, log=True, name='swa')

# begin training
trainer(epochs=epochs, epochs_swa=epochs_swa, validations_per_epoch=4)

Adjusting learning rate of group 0 to 2.7621e-03.
Starting 15-epoch training loop...

Train epoch: 1 -- Accuracy: 0.300347 -- Avg. loss (cross_entropy): 1.857402 -- Batch: 2813/2813 (100%) -- Total steps: 2813
Validation accuracy: 0.365815 -- Avg. loss (cross_entropy): 1.682510 -- Batch: 313/313 (100%)

Adjusting learning rate of group 0 to 2.5032e-03.
Train epoch: 2 -- Accuracy: 0.432723 -- Avg. loss (cross_entropy): 1.540084 -- Batch: 2813/2813 (100%) -- Total steps: 5626
Validation accuracy: 0.542732 -- Avg. loss (cross_entropy): 1.289733 -- Batch: 313/313 (100%)

Adjusting learning rate of group 0 to 2.2442e-03.
Train epoch: 3 -- Accuracy: 0.570165 -- Avg. loss (cross_entropy): 1.193615 -- Batch: 2813/2813 (100%) -- Total steps: 8439
Validation accuracy: 0.603834 -- Avg. loss (cross_entropy): 1.101158 -- Batch: 313/313 (100%)

Adjusting learning rate of group 0 to 1.9853e-03.
Train epoch: 4 -- Accuracy: 0.635265 -- Avg. loss (cross_entropy): 1.032278 -- Batch: 2813/2813 (100%) -- T

In [3]:
# difference in validation accuracy between end of SWA and SGD phases
print(f'SWA validation accuracy gain: {trainer.valid_acc - trainer.valid_accs[epochs - 1]:.4f}')

SWA validation accuracy gain: 0.0258


In [5]:
classes = np.array(['plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck'])

correct = 0
total = 0

correct_pred = {classname: 0 for classname in classes}
total_pred = {classname: 0 for classname in classes}

labels = []
preds = []

with torch.inference_mode():
    for data in valid_loader:
        images, target = data
        labels.append(target.numpy())
        outputs = resnet8(images.to(device)).cpu()
        _, predictions = torch.max(outputs, 1)
        preds.append(predictions)
        total += target.size(0)
        correct += (predictions == target).sum().item()
        # collect the correct predictions for each class
        for label, prediction in zip(target, predictions):
            if label == prediction:
                correct_pred[classes[label]] += 1
            total_pred[classes[label]] += 1

print(f'Validation accuracy: {100 * correct // total}%')
            
# print accuracy for each class
for classname, correct_count in correct_pred.items():
    accuracy = 100 * float(correct_count) / total_pred[classname]
    print(f'Accuracy for class: {classname:5s} is {accuracy:.1f} %')

labels = pd.Series(np.hstack(labels).astype(int), name="Labels")
preds = pd.Series(np.hstack(preds).astype(int), name="Preds")
df_confusion = pd.crosstab(classes[labels], classes[preds])
df_confusion

Validation accuracy: 75%
Accuracy for class: plane is 83.2 %
Accuracy for class: car   is 87.9 %
Accuracy for class: bird  is 61.2 %
Accuracy for class: cat   is 54.9 %
Accuracy for class: deer  is 74.3 %
Accuracy for class: dog   is 65.6 %
Accuracy for class: frog  is 82.0 %
Accuracy for class: horse is 71.5 %
Accuracy for class: ship  is 89.1 %
Accuracy for class: truck is 88.3 %


col_0,bird,car,cat,deer,dog,frog,horse,plane,ship,truck
row_0,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
bird,293,2,22,38,23,36,12,38,12,3
car,1,429,2,0,5,2,1,8,7,33
cat,22,1,285,26,88,57,8,11,10,11
deer,25,1,19,353,8,21,24,16,7,1
dog,27,0,75,21,347,21,28,6,3,1
frog,20,2,22,20,17,400,3,0,2,2
horse,16,1,25,54,25,3,347,6,1,7
plane,13,1,6,9,2,6,2,410,27,17
ship,5,9,2,3,0,2,1,27,474,9
truck,1,32,3,4,0,2,2,10,6,452
