<a href="https://colab.research.google.com/github/constantinpape/dl-teaching-resources/blob/main/exercises/classification/5_data_augmentation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Data Augmentation on CIFAR10

In this exercise we will use data augmentation to increase the effective amount of training data and thereby improve how well the trained network generalizes. We will start with the same network architecture as in the previous exercise.

## Preparation

In [None]:
# load tensorboard extension
%load_ext tensorboard

In [None]:
# import torch and other libraries
import os
import numpy as np
import sklearn.metrics as metrics
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.optim import Adam

In [None]:
!pip install cifar2png

In [None]:
# check if we have gpu support
# colab offers free gpus, however they are not activated by default.
# to activate the gpu, go to 'Runtime->Change runtime type'. 
# Then select 'GPU' in 'Hardware accelerator' and click 'Save'
have_gpu = torch.cuda.is_available()
# define the torch device that the model and the data will be moved to
if have_gpu:
    print("GPU is available")
    device = torch.device('cuda')
else:
    print("GPU is not available, training will run on the CPU")
    device = torch.device('cpu')

In [None]:
# run this in google colab to get the utils.py file
!wget https://raw.githubusercontent.com/constantinpape/training-deep-learning-models-for-vison/master/day1/utils.py 

In [None]:
# we will reuse the training function, validation function and
# data preparation from the previous notebook
import utils

In [None]:
cifar_dir = './cifar10'
!cifar2png cifar10 cifar10

In [None]:
# the category names are the sub-folder names of the training directory
categories = os.listdir(os.path.join(cifar_dir, 'train'))
categories.sort()

In [None]:
images, labels = utils.load_cifar(os.path.join(cifar_dir, 'train'))
(train_images, train_labels,
 val_images, val_labels) = utils.make_cifar_train_val_split(images, labels)

## Data Augmentation

The goal of data augmentation is to increase the amount of training data by transforming the input images such that they still look like realistic images. Popular transformations used in data augmentation include rotations, image flips, color jitter and additive noise.
Here, we will start with two transformations:
- random horizontal flips (mirroring the image along its vertical centerline)
- random color jitter
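
To make the flip concrete, here is a minimal NumPy sketch on a toy array. The array `toy_image` is a made-up example in channel-first (CHW) layout, not CIFAR10 data; it only illustrates that mirroring each channel left-to-right is the same as reversing the width axis.

```python
import numpy as np

# toy "image" in channel-first (CHW) layout: 2 channels, 2 rows, 3 columns
toy_image = np.arange(12).reshape(2, 2, 3)

# mirror every channel left-to-right (a horizontal flip)
flipped = np.array([np.fliplr(channel) for channel in toy_image])

# reversing the last (width) axis gives the same result in one step
flipped_by_slicing = toy_image[:, :, ::-1]

print(toy_image[0])  # first channel before the flip
print(flipped[0])    # first channel after the flip
```

The image content is unchanged by the flip, so the label can stay the same. This is what makes the transformation usable as an augmentation.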

In [None]:
# define random augmentations
import skimage.color as color

def random_flip(image, target, probability=.5):
    """ Randomly mirror the image across the vertical axis.
    """
    if np.random.rand() < probability:
        image = np.array([np.fliplr(im) for im in image])
    return image, target


def random_color_jitter(image, target, probability=.5):
    """ Randomly jitter the hue, saturation and brightness of the image.
    """
    # apply the jitter with the given probability (same convention as random_flip)
    if np.random.rand() < probability:
        # skimage expects HWC instead of CHW
        image = image.transpose((1, 2, 0))
        # transform image to hsv color space to apply jitter
        image = color.rgb2hsv(image)
        # draw one jitter factor per hsv channel in [0, 1.5) and
        # clip it to the range 0.66 - 1.5
        jitter_factors = 1.5 * np.random.rand(3)
        jitter_factors = np.clip(jitter_factors, 0.66, 1.5)
        # apply the jitter factors, making sure we stay in correct value range
        image *= jitter_factors
        image = np.clip(image, 0, 1)
        # transform back to rgb and CHW
        image = color.hsv2rgb(image)
        image = image.transpose((2, 0, 1))
    return image, target
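
The effect of scaling a channel in HSV space can be illustrated on a single pixel. The sketch below uses Python's built-in `colorsys` module instead of `skimage`, and the helper `scale_saturation` is a hypothetical name introduced only for this illustration. It scales the saturation only, whereas the function above jitters all three HSV channels.

```python
import colorsys

def scale_saturation(rgb, factor):
    """Scale the saturation of one RGB pixel (values in [0, 1]) by `factor`."""
    h, s, v = colorsys.rgb_to_hsv(*rgb)
    # clip to the valid range, as done for the full image above
    s = min(max(s * factor, 0.0), 1.0)
    return colorsys.hsv_to_rgb(h, s, v)

pure_red = (1.0, 0.0, 0.0)
washed_out_red = scale_saturation(pure_red, 0.5)
print(washed_out_red)  # -> (1.0, 0.5, 0.5)
```

Halving the saturation keeps the hue (still red) and the brightness, but moves the color towards white.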

In [None]:
# create training dataset with augmentations
from functools import partial
train_trafos = [
    utils.to_channel_first,
    utils.normalize,
    random_color_jitter,
    random_flip,
    utils.to_tensor
]
train_trafos = partial(utils.compose, transforms=train_trafos)

train_dataset = utils.DatasetWithTransform(train_images, train_labels,
                                            transform=train_trafos)

# we don't use data augmentations for the validation set
val_dataset = utils.DatasetWithTransform(val_images, val_labels,
                                          transform=utils.get_default_cifar_transform())
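
The order of the transformations in the list matters, because each one receives the output of the previous one; for example, `random_color_jitter` expects channel-first input and therefore comes after `utils.to_channel_first`. The following is a minimal sketch of such a chain, assuming that `utils.compose` applies functions with the signature `(image, target) -> (image, target)` one after the other. The names `compose_sketch`, `to_float` and `scale_to_unit_range` are hypothetical stand-ins; the actual implementation in `utils.py` may differ.

```python
import numpy as np

def compose_sketch(image, target, transforms):
    """Apply each (image, target) -> (image, target) transform in order."""
    for transform in transforms:
        image, target = transform(image, target)
    return image, target

# two hypothetical stand-in transforms, for illustration only
def to_float(image, target):
    return image.astype("float32"), target

def scale_to_unit_range(image, target):
    return image / 255.0, target

example_image = np.full((3, 4, 4), 255, dtype="uint8")
out_image, out_target = compose_sketch(
    example_image, 7, transforms=[to_float, scale_to_unit_range]
)
print(out_image.dtype, out_image.max(), out_target)
```

Binding the `transforms` argument with `functools.partial`, as done above, turns such a function into a single callable that only takes `image` and `target`.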

In [None]:
# sample augmentations
def show_image(ax, image):
    # need to go back to numpy array and HWC axis order
    image = image.numpy().transpose((1, 2, 0))
    ax.imshow(image)

n_samples = 8
image_id = 0
fig, ax = plt.subplots(1, n_samples, figsize=(18, 4))
for sample in range(n_samples):
    # load the same image several times; the random augmentations differ each time
    image, _ = train_dataset[image_id]
    show_image(ax[sample], image)

In [None]:
# we reuse the model from the previous exercise
# if you want you can also use a different CNN architecture that
# you have designed in the tasks part of that exercise
model = utils.SimpleCNN(10)
model = model.to(device)

In [None]:
# instantiate loaders and optimizer and start tensorboard
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=25)
optimizer = Adam(model.parameters(), lr=1.e-3)
%tensorboard --logdir runs

In [None]:
# we have moved all the boilerplate for the full training procedure to utils now
n_epochs = 10
utils.run_cifar_training(model, optimizer,
                         train_loader, val_loader,
                         device=device, name='da1', 
                         n_epochs=n_epochs)

In [None]:
# evaluate the model on test data
test_dataset = utils.make_cifar_test_dataset(cifar_dir)
test_loader = DataLoader(test_dataset, batch_size=25)
predictions, labels = utils.validate(model, test_loader, nn.NLLLoss(),
                                     device, step=0, tb_logger=None)

In [None]:
print("Test accuracy:")
accuracy = metrics.accuracy_score(labels, predictions)
print(accuracy)

fig, ax = plt.subplots(1, figsize=(8, 8))
utils.make_confusion_matrix(labels, predictions, categories, ax)
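
For reference, the two quantities reported above can be computed by hand. The sketch below uses made-up labels and predictions for a 3-class problem. It places the true classes on the rows and the predicted classes on the columns; this is an assumption for the illustration, and the convention used by `utils.make_confusion_matrix` may differ.

```python
import numpy as np

# hypothetical toy labels and predictions for a 3-class problem
toy_labels = np.array([0, 0, 1, 1, 2, 2])
toy_predictions = np.array([0, 1, 1, 1, 2, 0])

# accuracy: fraction of samples whose prediction equals the label
toy_accuracy = np.mean(toy_labels == toy_predictions)

# confusion matrix: rows are true classes, columns are predicted classes
n_classes = 3
confusion = np.zeros((n_classes, n_classes), dtype=int)
for true_class, predicted_class in zip(toy_labels, toy_predictions):
    confusion[true_class, predicted_class] += 1

print(toy_accuracy)
print(confusion)
```

The diagonal of the confusion matrix counts the correct predictions, so its sum divided by the number of samples gives the accuracy again. The off-diagonal entries show which classes are confused with each other.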

## Normalization layers

In addition to convolutional and pooling layers, normalization layers are another important building block of neural networks.

These layers normalize their input and then rescale it with learned parameters. One of the first widely used normalization layers is [BatchNorm](https://arxiv.org/abs/1502.03167), which we will now add to the CNN architecture from the previous exercise.
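
What a batch normalization layer computes during training can be sketched in a few lines of NumPy. This is a simplified illustration, not the PyTorch implementation: the array `features` is random stand-in data, and the scale `gamma` and shift `beta` are fixed to ones and zeros here, whereas in a real layer they are learned.

```python
import numpy as np

rng = np.random.default_rng(0)
# hypothetical batch of feature maps with layout (batch, channels, height, width)
features = rng.normal(loc=3.0, scale=2.0, size=(8, 12, 14, 14))

# statistics are computed per channel, over the batch and spatial axes
mean = features.mean(axis=(0, 2, 3), keepdims=True)
var = features.var(axis=(0, 2, 3), keepdims=True)

eps = 1e-5  # small constant for numerical stability
normalized = (features - mean) / np.sqrt(var + eps)

# per-channel scale (gamma) and shift (beta); fixed here, learned in a real layer
gamma = np.ones((1, 12, 1, 1))
beta = np.zeros((1, 12, 1, 1))
output = gamma * normalized + beta

# each channel now has approximately zero mean and unit standard deviation
print(output.mean(axis=(0, 2, 3)).round(3))
print(output.std(axis=(0, 2, 3)).round(3))
```

At inference time, BatchNorm uses running averages of the mean and variance collected during training instead of the statistics of the current batch.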

In [None]:
import torch.nn.functional as F

class CNNBatchNorm(nn.Module):
    def __init__(self, n_classes):
        super().__init__()
        self.n_classes = n_classes

        # the convolutions
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=12, kernel_size=5)
        self.conv2 = nn.Conv2d(in_channels=12, out_channels=24, kernel_size=3)
        # the pooling layer
        self.pool = nn.MaxPool2d(2, 2)
        # the normalization layers
        self.bn1 = nn.BatchNorm2d(12)
        self.bn2 = nn.BatchNorm2d(24)

        # the fully connected part of the network
        # after applying the convolutions and poolings, the tensor
        # has the shape 24 x 6 x 6, see below
        self.fc = nn.Sequential(
            nn.Linear(24 * 6 * 6, 120),
            nn.ReLU(),
            nn.Linear(120, 60),
            nn.ReLU(),
            nn.Linear(60, self.n_classes)
        )
        self.activation = nn.LogSoftmax(dim=1)

    def apply_convs(self, x):
        # input image has shape 3 x 32 x 32
        x = self.pool(F.relu(self.bn1(self.conv1(x))))
        # shape after conv: 12 x 28 x 28
        # shape after pooling: 12 x 14 x 14
        x = self.pool(F.relu(self.bn2(self.conv2(x))))
        # shape after conv: 24 x 12 x 12
        # shape after pooling: 24 x 6 x 6
        return x
    
    def forward(self, x):
        x = self.apply_convs(x)
        x = x.view(-1, 24 * 6 * 6)
        x = self.fc(x)
        x = self.activation(x)
        return x
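
The shapes given in the comments above follow from the usual output size formula for convolutions and poolings without dilation. The helper `conv_output_size` below is a hypothetical function written for this check; the kernel sizes and strides are the ones used in `CNNBatchNorm`.

```python
def conv_output_size(size, kernel_size, stride=1, padding=0):
    """Spatial output size of a convolution or pooling layer."""
    return (size + 2 * padding - kernel_size) // stride + 1

size = 32                                               # CIFAR10 images are 32 x 32
size = conv_output_size(size, kernel_size=5)            # conv1 -> 28
size = conv_output_size(size, kernel_size=2, stride=2)  # pool  -> 14
size = conv_output_size(size, kernel_size=3)            # conv2 -> 12
size = conv_output_size(size, kernel_size=2, stride=2)  # pool  -> 6

# number of input features of the first fully connected layer
n_features = 24 * size * size
print(size, n_features)
```

The BatchNorm layers do not change the shape of the tensor, so they do not appear in this calculation.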

In [None]:
# instantiate model and optimizer
model = CNNBatchNorm(10)
model = model.to(device)
optimizer = Adam(model.parameters(), lr=1.e-3)

In [None]:
n_epochs = 10
utils.run_cifar_training(model, optimizer,
                         train_loader, val_loader,
                         device=device, name='batch-norm', 
                         n_epochs=n_epochs)

In [None]:
# load the best checkpoint saved during training
model = utils.load_checkpoint("best_checkpoint_batch-norm.tar", model, optimizer)[0]

In [None]:
predictions, labels = utils.validate(model, test_loader, nn.NLLLoss(),
                                     device, step=0, tb_logger=None)

print("Test accuracy:")
accuracy = metrics.accuracy_score(labels, predictions)
print(accuracy)

fig, ax = plt.subplots(1, figsize=(8, 8))
utils.make_confusion_matrix(labels, predictions, categories, ax)

## Tasks and Questions

Tasks:
- Implement one or two additional augmentations and train the model again using these. You can use [the torchvision transformations](https://pytorch.org/docs/stable/torchvision/transforms.html) for inspiration.

Questions:
- Compare the results of the models trained in this exercise (with data augmentation only, and with data augmentation plus BatchNorm).
- Can you think of any transformations that make use of symmetries/invariances not present here but present in other kinds of images (e.g. biomedical images)?

Advanced:
- Check out the other [normalization layers available in pytorch](https://pytorch.org/docs/stable/nn.html#normalization-layers). Which of these layers could be a beneficial alternative to BatchNorm here? Try training with them and see if this improves performance further.