# CNN image classification (CIFAR10)

In [None]:
# Notebook setup: reload edited modules before each cell executes and draw plots inline.
%load_ext autoreload
%autoreload 2
%matplotlib inline

import sys
# Add the parent directory to the import path, presumably so the local `torchutils`
# package can be imported.
# NOTE(review): '..' is resolved against the kernel's working directory, so this
# only works when the notebook is started from its own folder.
sys.path.append('..')

In [None]:
import pathlib

import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

from torchutils import (
    mean_std_over_dataset,
    tensor2image,
    Classification,
    confusion_matrix
)

## Data import

In [None]:
data_path = pathlib.Path.home() / 'Data'

train_set = datasets.CIFAR10(
    data_path,
    train=True,
    transform=transforms.ToTensor(),
    download=True
)

In [None]:
mean, std = mean_std_over_dataset(train_set, channel_wise=True)

preprocessor = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean, std)
])

In [None]:
# Train and validation splits, both passed through the same normalizing
# `preprocessor`; `train=False` selects the held-out test split used for validation.
shared_dataset_kwargs = dict(transform=preprocessor, download=True)
train_set = datasets.CIFAR10(data_path, train=True, **shared_dataset_kwargs)
val_set = datasets.CIFAR10(data_path, train=False, **shared_dataset_kwargs)

print('No. train images:', len(train_set))
print('No. val. images:', len(val_set))

In [None]:
batch_size = 128

# Training: reshuffled every epoch, and the incomplete final batch is dropped.
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, drop_last=True)

# Validation: fixed order, and every image is kept.
val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False, drop_last=False)

print('No. train batches:', len(train_loader))
print('No. val. batches:', len(val_loader))

In [None]:
# Take the first batch from a fresh (shuffled) pass over the training data,
# to sanity-check tensor shapes and to feed the preview grid.
batch_iterator = iter(train_loader)
images, labels = next(batch_iterator)

print('Images shape:', images.shape)
print('Labels shape:', labels.shape)

In [None]:
# Preview the first eight images of the batch, titled with their class names.
# The normalization is undone (x * std + mean), and the result is clipped to [0, 1]
# before display.
fig, axes = plt.subplots(nrows=2, ncols=4, figsize=(5, 3))
for idx, ax in enumerate(axes.flat):
    image, label = images[idx], labels[idx]
    pixels = tensor2image(image) * std + mean
    ax.imshow(pixels.clip(0, 1))
    ax.set_title(train_set.classes[label])
    ax.set(xticks=[], yticks=[], xlabel='', ylabel='')
fig.tight_layout()

## Model training

In [None]:
# Seed torch's global RNG so the weight initialization below is reproducible.
# Shuffling that happens after this cell presumably becomes reproducible too,
# because DataLoader's default random sampler draws its seed from the same
# global generator.
SEED = 0
torch.manual_seed(SEED)

# Two conv blocks followed by a two-layer MLP head.
# Each 5x5 convolution uses padding=2, which keeps the spatial size unchanged;
# each MaxPool2d(2) halves it. For 32x32 CIFAR10 inputs: 32 -> 16 -> 8, with
# 32 channels after the second conv, hence 8*8*32 flattened features.
# The head outputs 10 logits, one per CIFAR10 class.
model = nn.Sequential(
    nn.Conv2d(in_channels=3, out_channels=16, kernel_size=(5, 5), padding=2),
    nn.LeakyReLU(),
    nn.MaxPool2d(2),
    nn.Conv2d(in_channels=16, out_channels=32, kernel_size=(5, 5), padding=2),
    nn.LeakyReLU(),
    nn.MaxPool2d(2),
    nn.Flatten(),
    nn.Linear(in_features=8*8*32, out_features=512),
    nn.LeakyReLU(),
    nn.Linear(in_features=512, out_features=10)
)

In [None]:
# Batch-averaged cross-entropy on the logits, optimized with Adam.
learning_rate = 1e-3
criterion = nn.CrossEntropyLoss(reduction='mean')
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# Hand model, loss, optimizer and both loaders to the project's training helper.
classifier = Classification(model, criterion, optimizer, train_loader, val_loader)

In [None]:
# Train for a fixed number of epochs. `log_interval` is passed through to the
# helper (presumably a logging frequency in batches; confirm in torchutils).
num_epochs = 100
history = classifier.training(num_epochs=num_epochs, log_interval=50)

In [None]:
# Training and validation loss curves recorded in `history`.
fig, ax = plt.subplots(figsize=(6, 4))
for history_key, curve_label in (('train_loss', 'train'), ('val_loss', 'val.')):
    ax.plot(np.array(history[history_key]), label=curve_label, alpha=0.7)
ax.set(xlabel='epoch', ylabel='loss', xlim=(0, history['num_epochs']))
ax.legend()
# Light grid drawn underneath the curves.
ax.grid(visible=True, which='both', color='lightgray', linestyle='-')
ax.set_axisbelow(True)
fig.tight_layout()

In [None]:
# Evaluate on the *complete* training set. `train_loader` uses shuffle=True and
# drop_last=True, so evaluating through it would skip the incomplete last batch:
# len(train_set) % batch_size images, chosen at random on every call. That makes
# the reported train metrics both incomplete and non-deterministic. A dedicated
# loader with fixed order and no dropping covers every image exactly once.
train_eval_loader = DataLoader(
    train_set,
    batch_size=batch_size,
    shuffle=False,
    drop_last=False
)

train_loss, train_acc = classifier.test(train_eval_loader)
val_loss, val_acc = classifier.test(val_loader)

print('Train loss: {:.4f}'.format(train_loss))
print('Val. loss: {:.4f}'.format(val_loss))
print('Train acc.: {:.4f}'.format(train_acc))
print('Val. acc.: {:.4f}'.format(val_acc))

In [None]:
# Confusion matrix on the validation set, computed by the project helper.
# NOTE(review): whether rows are true or predicted labels depends on
# torchutils.confusion_matrix; confirm before interpreting.
confmat = confusion_matrix(classifier, val_loader)
confmat_report = 'Confusion matrix:\n{}'.format(confmat)
print(confmat_report)