# CNN image classification (CIFAR-10)

In [None]:
# Reload edited imported modules (e.g. torchutils) automatically before each cell runs.
%load_ext autoreload
%autoreload 2
# Interactive matplotlib backend used for the figures in this notebook.
# NOTE(review): the `notebook` (nbagg) backend does not work in JupyterLab /
# Notebook 7; `%matplotlib widget` (requires ipympl) is the usual replacement.
%matplotlib notebook
import sys
# Make modules in the parent directory importable (presumably where torchutils lives).
sys.path.append('..')

In [None]:
import pathlib
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchutils import \
    mean_std_over_dataset, tensor2image, \
    Classification, confusion_matrix

## Data import

In [None]:
#%% preliminary import
# Un-normalized copy of the training set (ToTensor only). It is used solely to
# estimate the per-channel mean/std in the next cell; `train_set` is rebuilt
# with the normalizing preprocessor afterwards.
data_path = pathlib.Path.home().joinpath('Data')
train_set = datasets.CIFAR10(
    root=data_path,
    train=True,
    download=True,
    transform=transforms.ToTensor(),
)

In [None]:
#%% mean and std.
# Per-channel statistics of the raw (ToTensor-only) training images.
mean, std = mean_std_over_dataset(train_set, channel_wise=True)

#%% transformations
# Tensor conversion followed by per-channel standardization with the stats above.
preprocessor = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ]
)

In [None]:
#%% data sets
# Load both splits again, now with the normalizing preprocessor attached
# (this replaces the un-normalized `train_set` from the preliminary import).
train_set, test_set = (
    datasets.CIFAR10(data_path, train=is_train, transform=preprocessor, download=True)
    for is_train in (True, False)
)
print(f'No. train images: {len(train_set)}')
print(f'No. test images: {len(test_set)}')

In [None]:
#%% data loaders
batch_size = 128
# Training batches are reshuffled every epoch; the test order stays fixed.
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False)
print(f'No. train batches: {len(train_loader)}')
print(f'No. test batches: {len(test_loader)}')

In [None]:
#%% example images
# Pull one (shuffled) training batch for a quick visual sanity check.
images, labels = next(iter(train_loader))
print(f'Images shape: {images.shape}')
print(f'Labels shape: {labels.shape}')

In [None]:
#%% plot: example images
fig, axes = plt.subplots(nrows=2, ncols=4, figsize=(5,3))
# One subplot per image; zip stops after the 8 axes of the 2x4 grid.
for ax, image, label in zip(axes.flat, images, labels):
    # Invert Normalize (x * std + mean) so colors display as in the raw data.
    pixels = tensor2image(image) * std + mean
    ax.imshow(pixels.clip(0, 1))
    ax.set_title(train_set.classes[label])
    ax.set(xticks=[], yticks=[], xlabel='', ylabel='')
fig.tight_layout()
fig.show()

## Model training

In [None]:
#%% model (small CNN)
# Seed torch's global RNG so the weight initialization below is reproducible.
# The later shuffling in train_loader draws from the same global generator,
# so it also becomes deterministic on Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

# Each conv uses a 5x5 kernel with padding=2, which preserves the spatial size.
# Each MaxPool2d(2) halves it: 32x32 -> 16x16 -> 8x8. The flattened feature map
# therefore has 8*8*32 entries (32 = channels of the second conv).
model = nn.Sequential(
    nn.Conv2d(in_channels=3, out_channels=16, kernel_size=(5,5), padding=2),
    nn.LeakyReLU(),
    nn.MaxPool2d(2),
    nn.Conv2d(in_channels=16, out_channels=32, kernel_size=(5,5), padding=2),
    nn.LeakyReLU(),
    nn.MaxPool2d(2),
    nn.Flatten(),
    nn.Linear(in_features=8*8*32, out_features=512),
    nn.LeakyReLU(),
    nn.Linear(in_features=512, out_features=10)  # one logit per CIFAR-10 class
)
print(model)

In [None]:
#%% problem specification
# Cross-entropy on the logits (averaged over each batch), optimized with Adam.
criterion = nn.CrossEntropyLoss(reduction='mean')
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)
# Bundle model, loss, optimizer and both loaders in the torchutils helper.
classifier = Classification(
    model, criterion, optimizer, train_loader, test_loader
)

In [None]:
#%% training
# The returned history is dict-like; the plotting cell below reads its
# 'train_loss', 'test_loss' and 'no_epochs' entries.
history = classifier.training(
    no_epochs=10,
    # NOTE(review): logging frequency, presumably in batches; confirm in torchutils
    log_interval=10,
)

In [None]:
#%% plot: training history
fig, ax = plt.subplots(figsize=(6,4))
ax.plot(np.array(history['train_loss']), label='training', alpha=0.7)
ax.plot(np.array(history['test_loss']), label='testing', alpha=0.7)
ax.set(xlabel='epoch', ylabel='loss', title='Training history')
ax.set_xlim([0, history['no_epochs']])
ax.legend()
# Pass the on/off flag positionally. The `b` keyword was renamed to `visible`
# in Matplotlib 3.5 and later removed, so `b=True` fails on current releases.
ax.grid(True, which='both', color='lightgray', linestyle='-')
ax.set_axisbelow(True)  # draw the grid behind the loss curves
fig.tight_layout()
fig.show()

In [None]:
#%% final loss/accuracy
# Evaluate the trained model on both splits; classifier.test returns (loss, accuracy).
train_loss, train_acc = classifier.test(train_loader)
test_loss, test_acc = classifier.test(test_loader)
print(f'Train loss: {train_loss:.4f}')
print(f'Test loss: {test_loss:.4f}')
print(f'Train acc.: {train_acc:.4f}')
print(f'Test acc.: {test_acc:.4f}')

In [None]:
#%% confusion matrix
# NOTE(review): rows/columns presumably index train_set.classes; confirm in torchutils.
confmat = confusion_matrix(classifier, test_loader)
print(f'Confusion matrix:\n{confmat}')