In [1]:
%load_ext autoreload
%autoreload 2

import datetime

import numpy as np
import pandas as pd

import torch
from torch.nn import functional as F
from torchvision.utils import make_grid
from torch.optim.swa_utils import SWALR

from swadist.data import get_dataloaders
from swadist.train import Trainer
from swadist.optim import LinearPolyLR
from swadist.models import ResNet

# Select the compute device: use the GPU when CUDA is available, otherwise the CPU.
cuda = torch.cuda.is_available()
device = 'cuda' if cuda else 'cpu'
print(f'Using {device}')

if cuda:
    # Release cached GPU memory left over from earlier runs in this kernel.
    torch.cuda.empty_cache()

# Date-derived seed: number of seconds between today and 2022-04-11, so the
# value is stable within a calendar day and changes from one day to the next.
seed = int((datetime.date.today() - datetime.date(2022, 4, 11)).total_seconds())
print(f'seed: {seed}')

# Actually apply the seed; previously it was only computed and printed, so
# nothing in this notebook was seeded by it.
# NOTE(review): swadist's Trainer reports its own "Random seed" - confirm
# whether it re-seeds torch when training starts.
torch.manual_seed(seed)
np.random.seed(seed)

Using cuda
seed: 2246400


### Two-phase training: SGD + SWA

We train ResNet-8, using the optimal hyperparameters given in [Shallue et al. 2019](http://arxiv.org/abs/1811.03600) for phase 1 (SGD w/ Nesterov momentum). 

For the SWA phase, we follow [Izmailov et al. 2018](http://arxiv.org/abs/1803.05407): the learning rate is annealed (cosine schedule, over 3 epochs) to a constant value that is 10x smaller than the initial learning rate of the first phase.

In [2]:
%%time

# --- configuration ----------------------------------------------------------

# set to True to log training to Tensorboard
log = False

batch_size = 256

# SGD optimizer: initial learning rate (2^-5 = 0.03125) and momentum
lr0 = 2 ** -5.0
momentum = 0.975

# settings for the LinearPolyLR scheduler used in the SGD phase
alpha = 0.25
decay_epochs = 15

# SWA phase learning rate: one tenth of the initial SGD learning rate
swa_lr = lr0 / 10

# number of epochs spent in each phase
epochs_sgd = 10
epochs_swa = 5

# --- data -------------------------------------------------------------------

# CIFAR-10 train/validation loaders; memory is pinned only when running on GPU
train_loader, valid_loader = get_dataloaders(
    dataset="cifar10",
    batch_size=batch_size,
    num_workers=4,
    test=False,
    pin_memory=cuda,
)

# --- model ------------------------------------------------------------------

# ResNet-8: three stacks of one block each, without batch normalization
resnet8 = ResNet(
    in_kernel_size=3,
    stack_sizes=[1, 1, 1],
    n_classes=10,
    batch_norm=False,
).to(device)

# --- optimizer and schedulers -----------------------------------------------

# SGD with Nesterov momentum for phase 1
optimizer = torch.optim.SGD(
    resnet8.parameters(),
    lr=lr0,
    momentum=momentum,
    nesterov=True,
)

scheduler = LinearPolyLR(
    optimizer,
    alpha=alpha,
    decay_epochs=decay_epochs,
)

# The SWALR settings are handed to the Trainer as plain keyword arguments
# instead of a ready-made SWALR instance: instantiating SWALR this early
# seemed to have a negative effect on the SGD phase (?).
# NOTE(review): presumably the Trainer builds the SWALR scheduler from these
# kwargs when the SWA phase starts - confirm in swadist.train.
swalr_kwargs = dict(
    swa_lr=swa_lr,
    anneal_strategy='cos',
    anneal_epochs=3,
)

# --- training ---------------------------------------------------------------

trainer = Trainer(
    resnet8,
    train_loader,
    valid_loader,
    F.cross_entropy,
    optimizer,
    scheduler=scheduler,
    swa_scheduler=swalr_kwargs,
    device=device,
    name='swa',
    log=log,
)

# run both phases: SGD epochs first, then SWA epochs
trainer(epochs_sgd=epochs_sgd, epochs_swa=epochs_swa)

Using RandomSampler
Number of training samples: 45000
Number of training batches: 176

Starting 15-epoch training loop...
Random seed: 10951562873798627444

SGD epochs: 10 | Codistillation epochs: 0 | SWA epochs: 5
DistributedDataParallel: False
Stopping accuracy: None

Train epoch: 1 | Metrics (epoch mean): cross_entropy=1.989805 <> acc=0.255346 | Batch (size 200): 176/176 (100%) | Total steps: 176
Validation (batch mean) |  cross_entropy=1.670629 <> accuracy=0.379216 | Batch: 20/20 (100%)

Train epoch: 2 | Metrics (epoch mean): cross_entropy=1.528005 <> acc=0.440542 | Batch (size 200): 176/176 (100%) | Total steps: 352
Validation (batch mean) |  cross_entropy=1.420598 <> accuracy=0.486340 | Batch: 20/20 (100%)

Train epoch: 3 | Metrics (epoch mean): cross_entropy=1.252594 <> acc=0.553341 | Batch (size 200): 176/176 (100%) | Total steps: 528
Validation (batch mean) |  cross_entropy=1.137977 <> accuracy=0.596381 | Batch: 20/20 (100%)

Train epoch: 4 | Metrics (epoch mean): cross_entrop

### Validation accuracy by target class

In [None]:
# CIFAR-10 class names, indexed by integer target label.
classes = np.array(['plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck'])

# Run inference in eval mode and restore the previous mode afterwards, so that
# layers which behave differently during training (dropout, batch norm - if the
# model has any) do not distort the validation numbers.
was_training = resnet8.training
resnet8.eval()

label_batches = []
pred_batches = []

try:
    with torch.inference_mode():
        for images, target in valid_loader:
            outputs = resnet8(images.to(device)).cpu()
            # predicted class = index of the largest output along dim 1
            label_batches.append(target.numpy())
            pred_batches.append(outputs.argmax(dim=1).numpy())
finally:
    resnet8.train(was_training)

# Flatten the per-batch results into one integer array each.
label_ids = np.concatenate(label_batches).astype(int)
pred_ids = np.concatenate(pred_batches).astype(int)
hits = label_ids == pred_ids

total = int(label_ids.size)
correct = int(hits.sum())

# Per-class totals and correct counts, computed in one vectorized pass.
n_classes = len(classes)
total_counts = np.bincount(label_ids, minlength=n_classes)
correct_counts = np.bincount(label_ids[hits], minlength=n_classes)

total_pred = {classname: int(count) for classname, count in zip(classes, total_counts)}
correct_pred = {classname: int(count) for classname, count in zip(classes, correct_counts)}

# True division: the previous floor division dropped the fractional part of
# the overall accuracy.
print(f'Validation accuracy: {100 * correct / total:.1f}%')

# print accuracy for each class
for classname, correct_count in correct_pred.items():
    n_samples = total_pred[classname]
    # a class with no validation samples has no defined accuracy
    accuracy = 100 * correct_count / n_samples if n_samples else float('nan')
    print(f'Accuracy for class: {classname:5s} is {accuracy:.1f} %')

labels = pd.Series(label_ids, name="Labels")
preds = pd.Series(pred_ids, name="Preds")

# Confusion matrix: rows are true labels, columns are predictions.
# Reindexing keeps rows/columns in class-index order and keeps classes that
# never occur (filled with 0); rename_axis labels both axes explicitly, since
# crosstab receives plain arrays and would otherwise name them row_0 / col_0.
df_confusion = (
    pd.crosstab(classes[label_ids], classes[pred_ids])
    .reindex(index=classes, columns=classes, fill_value=0)
    .rename_axis(index='Labels', columns='Preds')
)
df_confusion