In [8]:
%load_ext autoreload
%autoreload 2

import datetime
import os

import numpy as np
import pandas as pd

import torch
from torch import nn
from torch.nn import functional as F
from torchvision.utils import make_grid

from swadist.data import get_dataloaders
from swadist.train import Trainer
from swadist.optim import LinearPolyLR
from swadist.models import ResNet

cuda = torch.cuda.is_available()
device = 'cuda' if cuda else 'cpu'
print(f'Using {device}')

if cuda:
    torch.cuda.empty_cache()

seed = int((datetime.date.today() - datetime.date(2022, 4, 11)).total_seconds())
torch.manual_seed(seed)
print(f'seed: {seed}')

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
Using cuda
seed: 2419200


### SGD training

We train ResNet-8, using the optimal hyperparameters given in [Shallue et al. 2019](http://arxiv.org/abs/1811.03600) via SGD with Nesterov momentum and a linearly-decreasing learning rate schedule.

In [9]:
class SimpleCNN(nn.Module):

    def __init__(self):
        super().__init__()
        self.conv_1 = nn.Conv2d(1, 64, kernel_size=3)
        self.conv_2 = nn.Conv2d(64, 64, kernel_size=3)
        self.bn_1 = nn.BatchNorm2d(64)
        self.conv_3 = nn.Conv2d(64, 128, kernel_size=3)
        self.conv_4 = nn.Conv2d(128, 128, kernel_size=3)
        self.bn_2 = nn.BatchNorm2d(128)
        self.conv_5 = nn.Conv2d(128, 256, kernel_size=3)
        self.bn_3 = nn.BatchNorm2d(256)
        self.fc_1 = nn.Linear(256, 512)
        self.fc_2 = nn.Linear(512, 10)
    
    def forward(self, x):
        x = F.relu(self.conv_1(x))
        x = F.relu(self.conv_2(x))
        x = F.max_pool2d(x, kernel_size=2)
        # x = self.bn_1(x)

        x = F.relu(self.conv_3(x))
        x = F.relu(self.conv_4(x))
        x = F.max_pool2d(x, kernel_size=2)
        # x = self.bn_2(x)

        x = F.relu(self.conv_5(x))
        x = F.max_pool2d(x, kernel_size=2)
        # x = self.bn_3(x)
        x = torch.flatten(x, start_dim=1)
        # print(x.size())

        x = F.relu(self.fc_1(x))
        x = F.relu(self.fc_2(x))
        x = F.softmax(x, dim=-1)

        return x

In [10]:
%%time

# whether to log training to Tensorboard
log = False

batch_size = 256

# optimizer
lr0, momentum,  = 2**-8., 0.975

# scheduler
alpha, decay_epochs = 0.25, 15

# training epochs
epochs_sgd = 25

# loaders
train_loader, valid_loader = get_dataloaders(dataset="fashionmnist",
                                             batch_size=batch_size, 
                                             num_workers=4, 
                                             test=False,
                                             pin_memory=cuda)

# keep starting params consistent across runs
torch.manual_seed(seed)

# model
simplecnn = SimpleCNN()

# optimizer and scheduler
optimizer = torch.optim.SGD(simplecnn.parameters(), 
                            lr=lr0, 
                            momentum=momentum, 
                            nesterov=False)



# trainer
trainer = Trainer(simplecnn, 
                  train_loader, 
                  valid_loader, 
                  F.cross_entropy, 
                  optimizer, 
                #   scheduler=scheduler, 
                  device=device,
                  name='sgd',
                  log=log)

# begin training
trainer(epochs_sgd=epochs_sgd)

Using RandomSampler
Number of training samples: 54000
Number of training batches: 211

Starting 25-epoch training loop...
Random seed: 6498985492641146865

SGD epochs: 25 | Codistillation epochs: 0 | SWA epochs: 0
DistributedDataParallel: False
Stopping accuracy: None

Train epoch: 1 | Metrics (epoch mean): cross_entropy=2.302523 <> acc=0.099705 | Batch (size 240): 211/211 (100%) | Total steps: 211
Validation (batch mean) |  cross_entropy=2.302389 <> accuracy=0.102562 | Batch: 24/24 (100%)

Train epoch: 2 | Metrics (epoch mean): cross_entropy=2.302165 <> acc=0.116089 | Batch (size 240): 211/211 (100%) | Total steps: 422
Validation (batch mean) |  cross_entropy=2.301881 <> accuracy=0.191964 | Batch: 24/24 (100%)

Train epoch: 3 | Metrics (epoch mean): cross_entropy=2.300951 <> acc=0.296127 | Batch (size 240): 211/211 (100%) | Total steps: 633
Validation (batch mean) |  cross_entropy=2.299179 <> accuracy=0.390625 | Batch: 24/24 (100%)

Train epoch: 4 | Metrics (epoch mean): cross_entropy

### Validation accuracy by target class

In [11]:
simplecnn.to(device)

classes = np.array([
    'shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot'])

correct = 0
total = 0

correct_pred = {classname: 0 for classname in classes}
total_pred = {classname: 0 for classname in classes}

labels = []
preds = []

with torch.inference_mode():
    for data in valid_loader:
        images, target = data
        labels.append(target.numpy())
        outputs = simplecnn(images.to(device)).cpu()
        _, predictions = torch.max(outputs, 1)
        preds.append(predictions)
        total += target.size(0)
        correct += (predictions == target).sum().item()
        # collect the correct predictions for each class
        for label, prediction in zip(target, predictions):
            if label == prediction:
                correct_pred[classes[label]] += 1
            total_pred[classes[label]] += 1

print(f'Validation accuracy: {100 * correct // total}%')
            
# print accuracy for each class
for classname, correct_count in correct_pred.items():
    accuracy = 100 * float(correct_count) / total_pred[classname]
    print(f'Accuracy for class: {classname:5s} is {accuracy:.1f} %')

labels = pd.Series(np.hstack(labels).astype(int), name="Labels")
preds = pd.Series(np.hstack(preds).astype(int), name="Preds")
df_confusion = pd.crosstab(classes[labels], classes[preds])
df_confusion

Validation accuracy: 86%
Accuracy for class: shirt/top is 85.9 %
Accuracy for class: Trouser is 97.9 %
Accuracy for class: Pullover is 76.2 %
Accuracy for class: Dress is 82.1 %
Accuracy for class: Coat  is 77.4 %
Accuracy for class: Sandal is 95.8 %
Accuracy for class: Shirt is 62.8 %
Accuracy for class: Sneaker is 97.5 %
Accuracy for class: Bag   is 97.6 %
Accuracy for class: Ankle boot is 95.5 %


col_0,Ankle boot,Bag,Coat,Dress,Pullover,Sandal,Shirt,Sneaker,Trouser,shirt/top
row_0,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
Ankle boot,591,0,0,0,0,6,0,22,0,0
Bag,2,601,5,0,3,0,4,0,0,1
Coat,0,4,490,14,66,0,55,0,3,1
Dress,0,1,24,497,11,0,17,0,18,37
Pullover,0,12,73,2,459,1,52,0,0,3
Sandal,5,3,0,0,0,566,0,17,0,0
Shirt,0,5,40,8,67,0,355,0,3,87
Sneaker,9,1,0,0,0,4,0,541,0,0
Trouser,0,2,2,6,1,0,0,0,572,1
shirt/top,0,8,0,11,18,1,51,0,0,541
