In [1]:
%load_ext autoreload
%autoreload 2

import datetime

import numpy as np
import pandas as pd

import torch
from torch.nn import functional as F
from torchvision.utils import make_grid

from swadist.data import get_dataloaders
from swadist.train import Trainer
from swadist.optim import LinearPolyLR
from swadist.models import ResNet

# Select the compute device: the GPU when CUDA is available, otherwise the CPU.
cuda = torch.cuda.is_available()
device = 'cuda' if cuda else 'cpu'
print(f'Using {device}')

if cuda:
    # start from a clean allocator cache on the GPU
    torch.cuda.empty_cache()

# Seed derived from the number of whole days elapsed since 2022-04-11, expressed
# in seconds. NOTE: the value changes every calendar day, so re-running the
# notebook on a different day will not reproduce the recorded outputs.
# `total_seconds()` returns a float; convert once here so that every later
# `torch.manual_seed(seed)` call receives the integer the API documents.
seed = int((datetime.date.today() - datetime.date(2022, 4, 11)).total_seconds())
torch.manual_seed(seed)
print(f'seed: {seed}')

Using cuda
seed: 172800


### SGD training

We train ResNet-8 via SGD with Nesterov momentum and a linearly decreasing learning rate schedule, using the optimal hyperparameters given in [Shallue et al. 2019](http://arxiv.org/abs/1811.03600).

In [2]:
%%time

# --- configuration -----------------------------------------------------------

# Tensorboard logging switch, forwarded to the Trainer below
log = False

# mini-batch size handed to the data loaders
batch_size = 256

# SGD optimizer: initial learning rate and (Nesterov) momentum
lr0 = 2**-5.
momentum = 0.975

# LinearPolyLR scheduler settings
# NOTE(review): the exact meaning of `alpha` is defined in swadist.optim — confirm there
alpha = 0.25
decay_epochs = 10

# learning rate intended for the SWA scheduler (not consumed in this cell)
swa_lr = lr0 / 10

# number of SGD epochs to train for
epochs_sgd = 15

# --- data --------------------------------------------------------------------

# CIFAR-10 loaders; memory is pinned only when running on CUDA
train_loader, valid_loader = get_dataloaders(
    dataset="cifar10",
    batch_size=batch_size,
    num_workers=4,
    test=False,
    pin_memory=cuda,
)

# --- model -------------------------------------------------------------------

# re-seed right before building the model so that the initial parameters
# are consistent across runs
torch.manual_seed(seed)

resnet8 = ResNet(
    in_kernel_size=3,
    stack_sizes=[1, 1, 1],
    n_classes=10,
    batch_norm=False,
)

# --- optimization ------------------------------------------------------------

# SGD with Nesterov momentum over all model parameters
optimizer = torch.optim.SGD(
    resnet8.parameters(),
    lr=lr0,
    momentum=momentum,
    nesterov=True,
)

# learning-rate schedule wrapped around the optimizer
scheduler = LinearPolyLR(
    optimizer,
    alpha=alpha,
    decay_epochs=decay_epochs,
)

# --- training ----------------------------------------------------------------

# the Trainer bundles model, loaders, loss function, optimizer and scheduler
trainer = Trainer(
    resnet8,
    train_loader,
    valid_loader,
    F.cross_entropy,
    optimizer,
    scheduler=scheduler,
    device=device,
    name='sgd',
    log=log,
)

# run the SGD training loop
trainer(epochs_sgd=epochs_sgd)

Files already downloaded and verified
Using RandomSampler
Number of training batches: 176

tensor([[[-0.0474,  0.0416,  0.0171],
         [ 0.0755,  0.0522, -0.0476],
         [ 0.0778,  0.2243,  0.0656]],

        [[ 0.0514, -0.0160,  0.1759],
         [ 0.1962, -0.1677,  0.1313],
         [ 0.1330,  0.0804, -0.0392]],

        [[ 0.1299,  0.0665,  0.1465],
         [ 0.0598,  0.0056, -0.0290],
         [ 0.0085, -0.2033,  0.0930]]], grad_fn=<SelectBackward0>)
Starting 15-epoch training loop...
Random seed: 1129665589277747793

SGD epochs: 15 | Codistillation epochs: 0 | SWA epochs: 0
DistributedDataParallel: False
Stopping accuracy: None

Train epoch: 1 -- Accuracy: 0.292054 -- Avg. loss (cross_entropy): 1.902774 -- Batch: 176/176 (100%) -- Total steps: 176
Validation accuracy: 0.400954 -- Avg. loss (cross_entropy): 1.589182 -- Batch: 20/20 (100%)

Train epoch: 2 -- Accuracy: 0.470129 -- Avg. loss (cross_entropy): 1.453287 -- Batch: 176/176 (100%) -- Total steps: 352
Validation accur

### Validation accuracy by target class

In [3]:
# Evaluate the trained model on the validation loader, report overall and
# per-class accuracy, and build a confusion matrix (rows: true labels,
# columns: predicted labels).
resnet8.to(device)

classes = np.array(['plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck'])

# per-batch true labels and predictions, both collected as numpy arrays
label_batches = []
pred_batches = []

# Run inference in eval mode and restore the previous mode afterwards so that
# later cells see the model in the state they left it.
# NOTE(review): assumes ResNet is a torch.nn.Module (it exposes .parameters()/.to()) — confirm
was_training = resnet8.training
resnet8.eval()
try:
    with torch.inference_mode():
        for images, target in valid_loader:
            outputs = resnet8(images.to(device)).cpu()
            # index of the largest output along dim 1 is the predicted class
            _, predictions = torch.max(outputs, 1)
            label_batches.append(target.numpy())
            pred_batches.append(predictions.numpy())
finally:
    resnet8.train(was_training)

labels = pd.Series(np.hstack(label_batches).astype(int), name="Labels")
preds = pd.Series(np.hstack(pred_batches).astype(int), name="Preds")

y_true = labels.to_numpy()
y_pred = preds.to_numpy()
hits = y_true == y_pred

# overall counts
total = int(hits.size)
correct = int(hits.sum())

# per-class counts, keyed by class name (vectorized instead of a per-sample loop)
n_classes = len(classes)
total_pred = dict(zip(classes, np.bincount(y_true, minlength=n_classes).tolist()))
correct_pred = dict(zip(classes, np.bincount(y_true[hits], minlength=n_classes).tolist()))

print(f'Validation accuracy: {100 * correct // total}%')

# print accuracy for each class
for classname, correct_count in correct_pred.items():
    class_total = total_pred[classname]
    if class_total == 0:
        # class absent from the validation data: avoid dividing by zero
        print(f'Accuracy for class: {classname:5s} is n/a (no samples)')
        continue
    accuracy = 100 * float(correct_count) / class_total
    print(f'Accuracy for class: {classname:5s} is {accuracy:.1f} %')

# Indexing `classes` yields plain arrays, which drops the Series names, so the
# axis names are passed explicitly; otherwise they show up as row_0 / col_0.
df_confusion = pd.crosstab(classes[y_true],
                           classes[y_pred],
                           rownames=[labels.name],
                           colnames=[preds.name])
df_confusion

Validation accuracy: 75%
Accuracy for class: plane is 78.5 %
Accuracy for class: car   is 86.3 %
Accuracy for class: bird  is 64.5 %
Accuracy for class: cat   is 48.7 %
Accuracy for class: deer  is 71.4 %
Accuracy for class: dog   is 72.6 %
Accuracy for class: frog  is 76.6 %
Accuracy for class: horse is 79.6 %
Accuracy for class: ship  is 87.6 %
Accuracy for class: truck is 88.3 %


col_0,bird,car,cat,deer,dog,frog,horse,plane,ship,truck
row_0,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
bird,309,0,26,39,31,23,16,26,8,1
car,0,421,3,0,1,3,3,6,14,37
cat,34,1,253,26,126,34,19,8,9,9
deer,33,1,17,339,23,11,39,8,1,3
dog,27,1,63,18,384,12,17,2,3,2
frog,29,3,19,29,19,374,3,3,4,5
horse,14,0,16,30,28,1,386,2,4,4
plane,23,2,8,12,3,8,7,387,29,14
ship,5,8,7,2,1,1,3,26,466,13
truck,2,24,4,5,2,1,4,11,7,452
