In [1]:
%load_ext autoreload
%autoreload 2

import datetime

import numpy as np
import pandas as pd

import torch
from torch.nn import functional as F
from torchvision.utils import make_grid

from swadist.data import get_dataloaders
from swadist.train import Trainer
from swadist.optim import LinearPolyLR
from swadist.models import ResNet

# Pick the compute device: use the GPU whenever CUDA is available.
cuda = torch.cuda.is_available()
if cuda:
    device = 'cuda'
else:
    device = 'cpu'
print(f'Using {device}')

# Release unused memory held by PyTorch's CUDA caching allocator.
if cuda:
    torch.cuda.empty_cache()

# The seed is the number of seconds between 2022-04-11 and today's date, so it
# stays fixed within a calendar day and changes from one day to the next.
days_elapsed = datetime.date.today() - datetime.date(2022, 4, 11)
seed = int(days_elapsed.total_seconds())
torch.manual_seed(seed)
print(f'seed: {seed}')

Using cuda
seed: 777600


### SGD training

We train ResNet-8 with SGD using Nesterov momentum and a linearly decreasing learning-rate schedule, with the optimal hyperparameters reported in [Shallue et al. 2019](http://arxiv.org/abs/1811.03600).

In [3]:
%%time

# switch for logging the training run to Tensorboard
log = False

# mini-batch size handed to the data loaders
batch_size = 256

# optimizer hyperparameters: initial learning rate and momentum
lr0 = 2**-5.
momentum = 0.975

# scheduler hyperparameters (forwarded to LinearPolyLR)
alpha = 0.25
decay_epochs = 15

# number of SGD training epochs
epochs_sgd = 15

# training and validation loaders for CIFAR-10
train_loader, valid_loader = get_dataloaders(
    dataset="cifar10",
    batch_size=batch_size,
    num_workers=4,
    test=False,
    pin_memory=cuda,
)

# re-seed right before building the model so that the starting parameters
# are the same on every run of this cell
torch.manual_seed(seed)

# model (constructed without batch norm, stack sizes [1, 1, 1])
resnet8 = ResNet(
    in_kernel_size=3,
    stack_sizes=[1, 1, 1],
    n_classes=10,
    batch_norm=False,
)

# SGD with Nesterov momentum
optimizer = torch.optim.SGD(
    resnet8.parameters(),
    lr=lr0,
    momentum=momentum,
    nesterov=True,
)

# learning-rate schedule attached to the optimizer
scheduler = LinearPolyLR(
    optimizer,
    alpha=alpha,
    decay_epochs=decay_epochs,
)

# trainer: bundles model, loaders, loss function, optimizer and scheduler
trainer = Trainer(
    resnet8,
    train_loader,
    valid_loader,
    F.cross_entropy,
    optimizer,
    scheduler=scheduler,
    device=device,
    name='sgd',
    log=log,
)

# run the training loop
trainer(epochs_sgd=epochs_sgd)

Using RandomSampler
Number of training samples: 45000
Number of training batches: 176

Starting 15-epoch training loop...
Random seed: 4991803466732119831

SGD epochs: 15 | Codistillation epochs: 0 | SWA epochs: 0
DistributedDataParallel: False
Stopping accuracy: None

Train epoch: 1 | Metrics (epoch mean): cross_entropy=1.924011 <> acc=0.286610 | Batch (size 200): 176/176 (100%) | Total steps: 176
Validation | cross_entropy=1.663748 <> accuracy=0.402309 | Batch: 20/20 (100%)

Train epoch: 2 | Metrics (epoch mean): cross_entropy=1.462216 <> acc=0.467975 | Batch (size 200): 176/176 (100%) | Total steps: 352
Validation | cross_entropy=1.563266 <> accuracy=0.463971 | Batch: 20/20 (100%)

Train epoch: 3 | Metrics (epoch mean): cross_entropy=1.210520 <> acc=0.565708 | Batch (size 200): 176/176 (100%) | Total steps: 528
Validation | cross_entropy=1.096468 <> accuracy=0.612672 | Batch: 20/20 (100%)

Train epoch: 4 | Metrics (epoch mean): cross_entropy=1.034033 <> acc=0.632164 | Batch (size 20

### Validation accuracy by target class

In [3]:
resnet8.to(device)

classes = np.array(['plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck'])

# Score the model in eval mode and restore the previous mode afterwards, so
# this cell does not change how later cells see the model.
# NOTE(review): assumes ResNet is a torch.nn.Module (it already exposes
# .parameters() and .to()) -- confirm.
was_training = resnet8.training
resnet8.eval()

# One numpy array of class indices per validation batch.
label_batches = []
pred_batches = []

try:
    with torch.inference_mode():
        for images, target in valid_loader:
            outputs = resnet8(images.to(device)).cpu()
            # index of the largest output along dim 1 is the predicted class
            predictions = torch.max(outputs, 1).indices
            label_batches.append(target.numpy())
            pred_batches.append(predictions.numpy())
finally:
    resnet8.train(was_training)

labels = pd.Series(np.hstack(label_batches).astype(int), name="Labels")
preds = pd.Series(np.hstack(pred_batches).astype(int), name="Preds")

# Overall counts.
is_correct = (labels == preds).to_numpy()
total = len(labels)
correct = int(is_correct.sum())

# Per-class counts, computed in one pass instead of a per-sample Python loop.
# minlength keeps an entry for classes that never occur in the validation set.
label_ids = labels.to_numpy()
per_class_total = np.bincount(label_ids, minlength=len(classes))
per_class_correct = np.bincount(label_ids[is_correct], minlength=len(classes))
total_pred = {classname: int(n) for classname, n in zip(classes, per_class_total)}
correct_pred = {classname: int(n) for classname, n in zip(classes, per_class_correct)}

# True division with one decimal, matching the per-class lines below
# (floor division would silently truncate the percentage).
print(f'Validation accuracy: {100 * correct / total:.1f}%')

# print accuracy for each class
for classname, correct_count in correct_pred.items():
    n_seen = total_pred[classname]
    # a class with no validation samples has no defined accuracy
    accuracy = 100 * correct_count / n_seen if n_seen else float('nan')
    print(f'Accuracy for class: {classname:5s} is {accuracy:.1f} %')

# Indexing `classes` yields plain ndarrays, which drops the Series names, so
# the axis names are passed explicitly (otherwise they show as row_0 / col_0).
df_confusion = pd.crosstab(classes[label_ids],
                           classes[preds.to_numpy()],
                           rownames=[labels.name],
                           colnames=[preds.name])
df_confusion

Validation accuracy: 73%
Accuracy for class: plane is 77.7 %
Accuracy for class: car   is 83.2 %
Accuracy for class: bird  is 54.5 %
Accuracy for class: cat   is 58.4 %
Accuracy for class: deer  is 68.8 %
Accuracy for class: dog   is 64.8 %
Accuracy for class: frog  is 77.7 %
Accuracy for class: horse is 78.1 %
Accuracy for class: ship  is 88.2 %
Accuracy for class: truck is 86.3 %


col_0,bird,car,cat,deer,dog,frog,horse,plane,ship,truck
row_0,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
bird,261,1,28,63,30,27,18,38,6,7
car,0,406,4,2,6,0,3,13,14,40
cat,17,4,303,28,100,21,15,17,6,8
deer,20,1,35,327,22,17,36,12,2,3
dog,16,0,106,14,343,11,30,6,1,2
frog,19,0,28,25,25,379,4,2,4,2
horse,8,0,25,37,22,0,379,6,4,4
plane,14,4,9,8,4,3,7,383,42,19
ship,3,9,6,2,2,3,2,27,469,9
truck,2,23,10,2,3,2,2,12,14,442
