In [1]:
%matplotlib inline
%load_ext autoreload
%autoreload 2

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

import torch
from torch.nn import functional as F
from torchvision.utils import make_grid

from swadist.data.loader import get_dataloaders
from swadist.utils import Trainer, show_imgs, LinearPolyLR
from swadist.models.resnet import ResNet

plt.rcParams["savefig.bbox"] = 'tight'
torch.multiprocessing.set_sharing_strategy('file_system')

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'Using {device}')

# load CIFAR-10
train_loader, valid_loader = get_dataloaders(
    dataset="cifar10", num_workers=4, batch_size=16, cuda=device=='cuda', test=False
)

Using cuda
Files already downloaded and verified


### SGD training

We train ResNet-8, using the optimal hyperparameters given in [Shallue et al. 2019](http://arxiv.org/abs/1811.03600) via SGD with Nesterov momentum and a linearly-decreasing learning rate schedule.

In [2]:
# initial lr, scaling factor, momentum
lr0, alpha, gamma = 2**-8.5, 0.25, 0.97

# epochs, lr scaling epochs
epochs, decay_epochs = 15, 8

# model
resnet8 = ResNet(in_kernel_size=3, stack_sizes=[1, 1, 1], n_classes=10, batch_norm=False)

optimizer = torch.optim.SGD(resnet8.parameters(), lr=lr0, momentum=gamma, nesterov=True)
scheduler = LinearPolyLR(optimizer, alpha=alpha, decay_epochs=decay_epochs, verbose=True)

trainer = Trainer(resnet8, train_loader, valid_loader, F.cross_entropy, optimizer, scheduler=scheduler, device=device, log=True, name='sgd')

# begin training
trainer(epochs=epochs, validations_per_epoch=4)

Adjusting learning rate of group 0 to 2.7621e-03.
Starting 15-epoch training loop...

Train epoch: 1 -- Accuracy: 0.357825 -- Avg. loss (cross_entropy): 1.702656 -- Batch: 2813/2813 (100%) -- Total steps: 2813
Validation accuracy: 0.487021 -- Avg. loss (cross_entropy): 1.389354 -- Batch: 313/313 (100%)

Adjusting learning rate of group 0 to 2.5032e-03.
Train epoch: 2 -- Accuracy: 0.534838 -- Avg. loss (cross_entropy): 1.281296 -- Batch: 2813/2813 (100%) -- Total steps: 5626
Validation accuracy: 0.601238 -- Avg. loss (cross_entropy): 1.117795 -- Batch: 313/313 (100%)

Adjusting learning rate of group 0 to 2.2442e-03.
Train epoch: 3 -- Accuracy: 0.611425 -- Avg. loss (cross_entropy): 1.084190 -- Batch: 2813/2813 (100%) -- Total steps: 8439
Validation accuracy: 0.588458 -- Avg. loss (cross_entropy): 1.135230 -- Batch: 313/313 (100%)

Adjusting learning rate of group 0 to 1.9853e-03.
Train epoch: 4 -- Accuracy: 0.659061 -- Avg. loss (cross_entropy): 0.962223 -- Batch: 2813/2813 (100%) -- T

In [9]:
classes = np.array(['plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck'])

correct = 0
total = 0

correct_pred = {classname: 0 for classname in classes}
total_pred = {classname: 0 for classname in classes}

labels = []
preds = []

with torch.inference_mode():
    for data in valid_loader:
        images, target = data
        labels.append(target.numpy())
        outputs = resnet8(images.to(device)).cpu()
        _, predictions = torch.max(outputs, 1)
        preds.append(predictions)
        total += target.size(0)
        correct += (predictions == target).sum().item()
        # collect the correct predictions for each class
        for label, prediction in zip(target, predictions):
            if label == prediction:
                correct_pred[classes[label]] += 1
            total_pred[classes[label]] += 1

print(f'Validation accuracy: {100 * correct // total}%')
            
# print accuracy for each class
for classname, correct_count in correct_pred.items():
    accuracy = 100 * float(correct_count) / total_pred[classname]
    print(f'Accuracy for class: {classname:5s} is {accuracy:.1f} %')

labels = pd.Series(np.hstack(labels).astype(int), name="Labels")
preds = pd.Series(np.hstack(preds).astype(int), name="Preds")
df_confusion = pd.crosstab(classes[labels], classes[preds])
df_confusion

Validation accuracy: 75%
Accuracy for class: plane is 78.9 %
Accuracy for class: car   is 89.5 %
Accuracy for class: bird  is 66.6 %
Accuracy for class: cat   is 60.5 %
Accuracy for class: deer  is 65.3 %
Accuracy for class: dog   is 60.1 %
Accuracy for class: frog  is 78.7 %
Accuracy for class: horse is 84.1 %
Accuracy for class: ship  is 88.0 %
Accuracy for class: truck is 83.6 %


col_0,bird,car,cat,deer,dog,frog,horse,plane,ship,truck
row_0,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
bird,319,4,24,30,21,24,22,29,3,3
car,1,437,4,2,2,1,2,12,6,21
cat,25,2,314,23,76,35,22,16,2,4
deer,39,2,31,310,11,15,51,12,2,2
dog,24,2,93,19,318,19,48,4,1,1
frog,26,0,36,19,8,384,6,3,1,5
horse,12,2,23,18,14,3,408,3,0,2
plane,29,3,8,7,2,6,8,389,27,14
ship,5,9,7,1,1,2,3,28,468,8
truck,1,34,11,2,2,3,7,14,10,428
