In [1]:
%load_ext autoreload
%autoreload 2

import datetime

import numpy as np
import pandas as pd

import torch
from torch.nn import functional as F
from torchvision.utils import make_grid
from torch.optim.swa_utils import SWALR

from swadist.data import get_dataloaders
from swadist.train import Trainer
from swadist.optim import LinearPolyLR
from swadist.models import ResNet

cuda = torch.cuda.is_available()
device = 'cuda' if cuda else 'cpu'
print(f'Using {device}')

if cuda:
    torch.cuda.empty_cache()
    
seed = int(datetime.date.today() - datetime.date(2022, 4, 11)).total_seconds())
torch.manual_seed(seed)
print(f'seed: {seed}')

Using cuda
seed: 172800


### Two-phase training: SGD + SWA

We train ResNet-8, using the optimal hyperparameters given in [Shallue et al. 2019](http://arxiv.org/abs/1811.03600) for phase 1 (SGD w/ Nesterov momentum). 

For the SWA phase, we follow [Izmailov et al.](http://arxiv.org/abs/1803.05407) and use a constant learning rate that is 10x smaller than the initial learning rate in the first phase.

In [2]:
%%time

# whether to log training to Tensorboard
log = False

batch_size = 256

# optimizer
lr0, momentum,  = 2**-5., 0.975

# scheduler
alpha, decay_epochs = 0.25, 10

# swa_scheduler
swa_lr = lr0 / 10

# training epochs
epochs_sgd, epochs_swa = 10, 5

# loaders
train_loader, valid_loader = get_dataloaders(dataset="cifar10", 
                                             batch_size=batch_size,
                                             num_workers=4,  
                                             test=False,
                                             pin_memory=cuda)

# model
resnet8 = ResNet(in_kernel_size=3, 
                 stack_sizes=[1, 1, 1], 
                 n_classes=10, 
                 batch_norm=False).to(device)

# optimizer and schedulers
optimizer = torch.optim.SGD(resnet8.parameters(), 
                            lr=lr0, 
                            momentum=momentum, 
                            nesterov=True)

scheduler = LinearPolyLR(optimizer, 
                         alpha=alpha, 
                         decay_epochs=decay_epochs)

# instantiating the SWALR too early seems to have a negative effect on SGD phase (?)
# swa_scheduler = SWALR(optimizer, 
#                       swa_lr=swa_lr, 
#                       anneal_strategy='linear', 
#                       anneal_epochs=0)

swalr_kwargs = {
    'swa_lr': swa_lr, 
    'anneal_strategy': 'linear', 
    'anneal_epochs': 0
}

trainer = Trainer(resnet8, 
                  train_loader, 
                  valid_loader, 
                  F.cross_entropy, 
                  optimizer,
                  scheduler=scheduler,
                  # swa_scheduler=swa_scheduler,
                  swa_scheduler=swalr_kwargs,
                  device=device,
                  name='swa',
                  log=log)

# begin training
trainer(epochs_sgd=epochs_sgd, 
        epochs_swa=epochs_swa)

Files already downloaded and verified
Using RandomSampler
Number of training batches: 176

tensor([[[-0.0474,  0.0416,  0.0171],
         [ 0.0755,  0.0522, -0.0476],
         [ 0.0778,  0.2243,  0.0656]],

        [[ 0.0514, -0.0160,  0.1759],
         [ 0.1962, -0.1677,  0.1313],
         [ 0.1330,  0.0804, -0.0392]],

        [[ 0.1299,  0.0665,  0.1465],
         [ 0.0598,  0.0056, -0.0290],
         [ 0.0085, -0.2033,  0.0930]]], grad_fn=<ToCopyBackward0>)
Starting 15-epoch training loop...
Random seed: 16738633956411264589

SGD epochs: 10 | Codistillation epochs: 0 | SWA epochs: 5
DistributedDataParallel: False
Stopping accuracy: None

Train epoch: 1 -- Accuracy: 0.282267 -- Avg. loss (cross_entropy): 1.919128 -- Batch: 176/176 (100%) -- Total steps: 176
Validation accuracy: 0.415407 -- Avg. loss (cross_entropy): 1.609617 -- Batch: 20/20 (100%)

Train epoch: 2 -- Accuracy: 0.478725 -- Avg. loss (cross_entropy): 1.445861 -- Batch: 176/176 (100%) -- Total steps: 352
Validation accu

In [3]:
# difference in validation accuracy between end of SWA and SGD phases
print(f'SWA validation accuracy gain: {trainer.valid_acc - trainer.valid_accs[epochs_sgd - 1]:.4f}')

SWA validation accuracy gain: 0.0158


### Validation accuracy by target class

In [4]:
classes = np.array(['plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck'])

correct = 0
total = 0

correct_pred = {classname: 0 for classname in classes}
total_pred = {classname: 0 for classname in classes}

labels = []
preds = []

with torch.inference_mode():
    for data in valid_loader:
        images, target = data
        labels.append(target.numpy())
        outputs = resnet8(images.to(device)).cpu()
        _, predictions = torch.max(outputs, 1)
        preds.append(predictions)
        total += target.size(0)
        correct += (predictions == target).sum().item()
        # collect the correct predictions for each class
        for label, prediction in zip(target, predictions):
            if label == prediction:
                correct_pred[classes[label]] += 1
            total_pred[classes[label]] += 1

print(f'Validation accuracy: {100 * correct // total}%')
            
# print accuracy for each class
for classname, correct_count in correct_pred.items():
    accuracy = 100 * float(correct_count) / total_pred[classname]
    print(f'Accuracy for class: {classname:5s} is {accuracy:.1f} %')

labels = pd.Series(np.hstack(labels).astype(int), name="Labels")
preds = pd.Series(np.hstack(preds).astype(int), name="Preds")
df_confusion = pd.crosstab(classes[labels], classes[preds])
df_confusion

Validation accuracy: 74%
Accuracy for class: plane is 76.9 %
Accuracy for class: car   is 89.3 %
Accuracy for class: bird  is 64.7 %
Accuracy for class: cat   is 54.1 %
Accuracy for class: deer  is 72.4 %
Accuracy for class: dog   is 65.2 %
Accuracy for class: frog  is 80.7 %
Accuracy for class: horse is 76.7 %
Accuracy for class: ship  is 84.0 %
Accuracy for class: truck is 85.0 %


col_0,bird,car,cat,deer,dog,frog,horse,plane,ship,truck
row_0,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
bird,310,2,24,44,24,25,13,29,4,4
car,2,436,4,1,4,0,5,9,2,25
cat,33,6,281,41,89,31,13,13,2,10
deer,35,1,22,344,14,22,24,11,1,1
dog,16,0,86,23,345,15,32,6,3,3
frog,25,3,21,25,16,394,1,2,0,1
horse,15,2,20,38,24,5,372,3,2,4
plane,27,12,9,10,2,2,4,379,31,17
ship,2,11,5,3,1,2,3,44,447,14
truck,3,34,7,6,4,1,4,11,7,435
