In [2]:
%load_ext autoreload
%autoreload 2

import datetime

import numpy as np
import pandas as pd

import torch
from torch.nn import functional as F
from torchvision.utils import make_grid
from torch.optim.swa_utils import SWALR

from swadist.data import get_dataloaders
from swadist.train import Trainer
from swadist.optim import LinearPolyLR
from swadist.models import ResNet

cuda = torch.cuda.is_available()
device = 'cuda' if cuda else 'cpu'
print(f'Using {device}')

if cuda:
    torch.cuda.empty_cache()

seed = int((datetime.date.today() - datetime.date(2022, 4, 11)).total_seconds())
print(f'seed: {seed}')

Using cuda
seed: 691200


### Two-phase training: SGD + SWA

We train ResNet-8, using the optimal hyperparameters given in [Shallue et al. 2019](http://arxiv.org/abs/1811.03600) for phase 1 (SGD w/ Nesterov momentum). 

For the SWA phase, we follow [Izmailov et al.](http://arxiv.org/abs/1803.05407) and use a constant learning rate that is 10x smaller than the initial learning rate in the first phase.

In [3]:
%%time

# whether to log training to Tensorboard
log = False

batch_size = 256

# optimizer
lr0, momentum,  = 2**-5., 0.975

# scheduler
alpha, decay_epochs = 0.25, 15

# swa_scheduler
swa_lr = lr0 / 10

# training epochs
epochs_sgd, epochs_swa = 10, 5

# loaders
train_loader, valid_loader = get_dataloaders(dataset="cifar10", 
                                             batch_size=batch_size,
                                             num_workers=4,  
                                             test=False,
                                             pin_memory=cuda)

# model
resnet8 = ResNet(in_kernel_size=3, 
                 stack_sizes=[1, 1, 1], 
                 n_classes=10, 
                 batch_norm=False).to(device)

# optimizer and schedulers
optimizer = torch.optim.SGD(resnet8.parameters(), 
                            lr=lr0, 
                            momentum=momentum, 
                            nesterov=True)

scheduler = LinearPolyLR(optimizer, 
                         alpha=alpha, 
                         decay_epochs=decay_epochs)

# instantiating the SWALR too early seems to have a negative effect on SGD phase (?)
# swa_scheduler = SWALR(optimizer, 
#                       swa_lr=swa_lr, 
#                       anneal_strategy='linear', 
#                       anneal_epochs=0)

swalr_kwargs = {
    'swa_lr': swa_lr, 
    'anneal_strategy': 'cos', 
    'anneal_epochs': 3
}

trainer = Trainer(resnet8, 
                  train_loader, 
                  valid_loader, 
                  F.cross_entropy, 
                  optimizer,
                  scheduler=scheduler,
                  # swa_scheduler=swa_scheduler,
                  swa_scheduler=swalr_kwargs,
                  device=device,
                  name='swa',
                  log=log)

# begin training
trainer(epochs_sgd=epochs_sgd, 
        epochs_swa=epochs_swa)

Using RandomSampler
Number of training samples: 45000
Number of training batches: 176

Starting 15-epoch training loop...
Random seed: 1210021667759833501

SGD epochs: 10 | Codistillation epochs: 0 | SWA epochs: 5
DistributedDataParallel: False
Stopping accuracy: None

Train epoch: 1 | Metrics (epoch mean): cross_entropy=1.956578 <> acc=0.269387 | Batch (size 200): 176/176 (100%) | Total steps: 176
Validation | cross_entropy=1.658363 <> accuracy=0.383479 | Batch: 20/20 (100%)

Train epoch: 2 | Metrics (epoch mean): cross_entropy=1.497479 <> acc=0.449078 | Batch (size 200): 176/176 (100%) | Total steps: 352
Validation | cross_entropy=1.308733 <> accuracy=0.518359 | Batch: 20/20 (100%)

Train epoch: 3 | Metrics (epoch mean): cross_entropy=1.245030 <> acc=0.549643 | Batch (size 200): 176/176 (100%) | Total steps: 528
Validation | cross_entropy=1.159701 <> accuracy=0.591625 | Batch: 20/20 (100%)

Train epoch: 4 | Metrics (epoch mean): cross_entropy=1.095957 <> acc=0.611321 | Batch (size 20

In [4]:
# difference in validation accuracy between end of SWA and SGD phases
print(f'SWA validation accuracy gain: {trainer.valid_acc - trainer.valid_accs[epochs_sgd - 1]:.4f}')

SWA validation accuracy gain: 0.0150


### Validation accuracy by target class

In [5]:
classes = np.array(['plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck'])

correct = 0
total = 0

correct_pred = {classname: 0 for classname in classes}
total_pred = {classname: 0 for classname in classes}

labels = []
preds = []

with torch.inference_mode():
    for data in valid_loader:
        images, target = data
        labels.append(target.numpy())
        outputs = resnet8(images.to(device)).cpu()
        _, predictions = torch.max(outputs, 1)
        preds.append(predictions)
        total += target.size(0)
        correct += (predictions == target).sum().item()
        # collect the correct predictions for each class
        for label, prediction in zip(target, predictions):
            if label == prediction:
                correct_pred[classes[label]] += 1
            total_pred[classes[label]] += 1

print(f'Validation accuracy: {100 * correct // total}%')
            
# print accuracy for each class
for classname, correct_count in correct_pred.items():
    accuracy = 100 * float(correct_count) / total_pred[classname]
    print(f'Accuracy for class: {classname:5s} is {accuracy:.1f} %')

labels = pd.Series(np.hstack(labels).astype(int), name="Labels")
preds = pd.Series(np.hstack(preds).astype(int), name="Preds")
df_confusion = pd.crosstab(classes[labels], classes[preds])
df_confusion

Validation accuracy: 74%
Accuracy for class: plane is 76.7 %
Accuracy for class: car   is 88.3 %
Accuracy for class: bird  is 67.0 %
Accuracy for class: cat   is 54.7 %
Accuracy for class: deer  is 68.6 %
Accuracy for class: dog   is 65.6 %
Accuracy for class: frog  is 75.8 %
Accuracy for class: horse is 74.4 %
Accuracy for class: ship  is 86.8 %
Accuracy for class: truck is 85.2 %


col_0,bird,car,cat,deer,dog,frog,horse,plane,ship,truck
row_0,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
bird,321,2,24,35,23,28,8,29,6,3
car,3,431,2,2,2,2,1,8,12,25
cat,35,1,284,23,95,33,18,18,3,9
deer,32,2,20,326,27,19,27,16,4,2
dog,24,2,95,19,347,14,21,4,2,1
frog,31,4,30,25,17,370,2,2,5,2
horse,17,0,25,40,25,3,361,4,2,8
plane,25,9,3,12,3,3,9,378,40,11
ship,5,8,6,1,0,2,2,34,462,12
truck,3,34,6,3,7,1,3,11,8,436
