# Image augmentations

In this notebook, we will apply popular image augmentation techniques to reduce overfitting.

**Goal.** The goal of this notebook is to develop basic image processing skills.

You need the following extra libraries beyond PyTorch:
* torchvision
* opencv-python (cv2)
* Pillow (PIL)
* (optional) albumentations

In [None]:
# Uncomment to install PyTorch Lightning.
# ! pip install pytorch_lightning

In [None]:
import random
import pytorch_lightning as pl
import torch
import torchvision
import numpy as np

import cv2

from matplotlib import pyplot as plt
from PIL import Image

# Below are helper tools. You can skip this block.

DATA_ROOT = "cifar10"

def show_images_dataset(dataset, n=5, collate_fn=lambda pair: pair[0]):
    images = [collate_fn(random.choice(dataset)) for _ in range(n)]
    grid = torchvision.utils.make_grid(images)
    grid -= grid.min()
    grid /= grid.max()
    plt.imshow(grid.permute(1, 2, 0))
    plt.show()

def show_augmenter_results(augmenter):
    trainset_notransform = torchvision.datasets.CIFAR10(DATA_ROOT, train=True, download=True)
    pil_image = random.choice(trainset_notransform)[0]
    transform = torchvision.transforms.Compose([
        augmenter,
        torchvision.transforms.ToTensor()
    ])
    show_images_dataset([pil_image], collate_fn=transform)

# Simple training

In [None]:
class Data(pl.LightningDataModule):
    def __init__(self, num_workers=4, batch_size=32, augmenter=None):
        super().__init__()
        self.num_workers = num_workers
        self.batch_size = batch_size
        basic_transforms = [
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
        ]
        # Apply augmentation to the training set only.
        train_transforms = [augmenter] if augmenter is not None else []
        self.train_transform = torchvision.transforms.Compose(train_transforms + basic_transforms)
        self.test_transform = torchvision.transforms.Compose(basic_transforms)
        
    def train_dataloader(self):
        dataset = torchvision.datasets.CIFAR10(root=DATA_ROOT, train=True, download=True,
                                               transform=self.train_transform)
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=self.batch_size,  # The number of images in the batch.
            num_workers=self.num_workers,  # The number of concurrent readers and preprocessors.
            drop_last=True,  # Drop the truncated last batch during training.
            pin_memory=torch.cuda.is_available(),  # Optimize CUDA data transfer.
        )

    # Validate on test.
    def val_dataloader(self):
        return self.test_dataloader()

    def test_dataloader(self):
        dataset = torchvision.datasets.CIFAR10(root=DATA_ROOT, train=False, download=True,
                                               transform=self.test_transform)
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=self.batch_size,  # The number of images in the batch.
            num_workers=self.num_workers,  # The number of concurrent readers and preprocessors.
            pin_memory=torch.cuda.is_available(),  # Optimize CUDA data transfer.
        )

data_module = Data()
x, y = next(iter(data_module.test_dataloader()))  # Test loader.

**Model.** We will use a standard ResNet implementation from torchvision. Other popular repositories with vision models include:
1. [transformers](https://huggingface.co/docs/transformers/index)
2. [pretrained-models](https://github.com/cadene/pretrained-models.pytorch)

In [None]:
from torch import nn
from collections import defaultdict

class Module(pl.LightningModule):
    def __init__(self, num_classes=10):
        super().__init__()
        self.model = torchvision.models.resnet18()
        self.model.fc = torch.nn.Linear(self.model.fc.in_features, num_classes)
        self.reset_metrics()
        
    def reset_metrics(self):
        self.metrics = defaultdict(lambda: defaultdict(list))

    def forward(self, x):
        return self.model(x)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.001)

    def common_step(self, batch, split="train"):
        x, y = batch
        logits = self(x)
        loss = torch.nn.functional.cross_entropy(logits, y)
        self.metrics["loss"][split].append(loss.item())
        with torch.no_grad():
            predictions = logits.argmax(1)  # (B).
            correct = (predictions == y).sum().item()
            accuracy = correct / y.numel()
            self.metrics["accuracy"][split].append(accuracy)
        return loss

    # Process batch and compute the loss.
    def training_step(self, batch):
        return self.common_step(batch)

    def validation_step(self, batch):
        self.common_step(batch, split="val")

    # Logging tools.
    def common_log(self, split="train"):
        for metric in self.metrics:
            self.logger.experiment.add_scalars(metric, {split: np.mean(self.metrics[metric][split])}, self.global_step)
            del self.metrics[metric][split]

    def on_train_epoch_end(self):
        self.common_log()

    def on_validation_epoch_end(self):
        self.common_log(split="val")


model = Module()
data = Data()
trainer = pl.Trainer(max_epochs=10,
                     logger=pl.loggers.TensorBoardLogger("lightning_logs", name="default"))
trainer.fit(model, data)

In [None]:
%load_ext tensorboard
%tensorboard --logdir ./lightning_logs/

**Analysis**
1. Validation accuracy reaches a plateau, while training accuracy continues to grow.
2. Validation loss starts to grow after a few epochs.
3. Both are signs of *overfitting*.

# Augmentation with OpenCV

OpenCV is an open-source library with a variety of tools for image processing and feature extraction. We will use its image processing routines.

We will manually implement the following augmentations:
1. Random horizontal flip
2. Rotation / Scale / Offset
3. Random crop
4. Brightness / Contrast
5. Blur

# Assignment 1: Random horizontal flip

Implement a transformation that flips an input image horizontally with a probability of $0.5$.

In [None]:
class FlipAugmenter(object):
    def __call__(self, image):
        image = np.array(image)  # PIL -> Numpy.
        h, w, c = image.shape
        assert c == 3

        # Flip the original image horizontally with a probability of 50%.
        #
        # The beginning of your code.
        new_image = ...
        # The end of your code.
        
        return Image.fromarray(new_image)  # Numpy -> PIL.

show_augmenter_results(FlipAugmenter())

# Random rotation

**Affine transformation**

An affine transformation maps each pixel coordinate $x$ to a new coordinate $y$ using a linear map followed by a shift. It is defined as:
```
y = Ax + b,
```
where $A$ is a square 2 x 2 matrix and $b$ is an offset vector. The matrix $A$ defines rotation, scaling, shear, and reflection. The vector $b$ defines a translation.

For example, the following transform:

$
A = \begin{bmatrix} 0 & 1 \\ 1 & 0 \end{bmatrix},
$

$
b = \begin{bmatrix} 2 \\ 2 \end{bmatrix},
$

swaps the $x$ and $y$ coordinates of each pixel, i.e., reflects the image over the diagonal, and shifts it by 2 pixels right and down.

Another example is scaling. The following operator scales an image by a factor of 2:

$
A = \begin{bmatrix} 2 & 0 \\ 0 & 2 \end{bmatrix}.
$
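
As a quick numerical check of the two examples above, the sketch below applies $y = Ax + b$ to a few pixel coordinates with NumPy. The helper name `apply_affine` and the sample points are illustrative assumptions, not part of the notebook.

```python
import numpy as np

def apply_affine(A, b, points):
    """Apply y = A x + b to an (N, 2) array of (x, y) points (hypothetical helper)."""
    return points @ A.T + b

points = np.array([[0.0, 0.0], [3.0, 1.0], [5.0, 7.0]])

# Example 1: swap x and y, then shift by 2 pixels along both axes.
A_swap = np.array([[0.0, 1.0], [1.0, 0.0]])
b_shift = np.array([2.0, 2.0])
swapped = apply_affine(A_swap, b_shift, points)
print(swapped)

# Example 2: scale all coordinates by a factor of 2 with no offset.
A_scale = np.array([[2.0, 0.0], [0.0, 2.0]])
scaled = apply_affine(A_scale, np.zeros(2), points)
print(scaled)
```

Each row of the printed arrays is the new position of the corresponding input point.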

**Homogeneous coordinates**

In practice, both the matrix $A$ and the vector $b$ are stored in a single 2 x 3 matrix:

$
A' = \begin{bmatrix} A & b \end{bmatrix},
$

and it is assumed that an input vector contains an extra element equal to $1$:

$
x' = \begin{bmatrix} x \\ 1 \end{bmatrix}.
$

This way, both transforms are identical:

$
A x + b = A' x'
$
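
The identity above can be verified numerically. This is a minimal sketch that reuses the swap-and-shift example; the test point `x` is an arbitrary assumption.

```python
import numpy as np

A = np.array([[0.0, 1.0], [1.0, 0.0]])
b = np.array([2.0, 2.0])

# Stack A and b side by side into the 2 x 3 matrix A'.
A_prime = np.hstack([A, b[:, None]])

x = np.array([3.0, 1.0])
x_prime = np.append(x, 1.0)  # Append the extra element equal to 1.

direct = A @ x + b
stacked = A_prime @ x_prime
print(direct, stacked)
```

Both expressions give the same point, so a single 2 x 3 matrix is enough to describe the whole transform.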

**Example transform with OpenCV**

We use the `cv2.warpAffine` function to apply the affine transform to the image.

In [None]:
A = np.array([
    [0, 1, 2],
    [1, 0, 2]
]).astype(np.float32)
image = data_module.test_dataloader().dataset.data[10]
image2 = cv2.warpAffine(image, A, (32, 32))
plt.imshow(np.concatenate([image, image2], 1))

# Assignment 2: Random rotation

Fill in the following block to implement a random rotation augmenter.

Cheat sheet:
```python
cv2.getRotationMatrix2D(center, angle, scale)  # center: (x, y), angle: degrees, scale: scalar.

cv2.warpAffine(src, M, dsize[, dst[, flags[, borderMode[, borderValue]]]])
```

Here, `dsize` is a tuple (w, h) with the desired output image size.
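
For intuition, the 2 x 3 matrix described in the OpenCV documentation for `cv2.getRotationMatrix2D` can be reproduced in NumPy. The function name `rotation_matrix_2d` and the sample values below are hypothetical; this is a sketch of the documented formula, not the library implementation.

```python
import math
import numpy as np

def rotation_matrix_2d(center, angle, scale):
    """NumPy sketch of the matrix documented for cv2.getRotationMatrix2D.

    center: (x, y) rotation center, angle: degrees, scale: isotropic scale factor.
    """
    cx, cy = center
    alpha = scale * math.cos(math.radians(angle))
    beta = scale * math.sin(math.radians(angle))
    return np.array([
        [alpha, beta, (1 - alpha) * cx - beta * cy],
        [-beta, alpha, beta * cx + (1 - alpha) * cy],
    ], dtype=np.float32)

M = rotation_matrix_2d(center=(16, 16), angle=30, scale=1.1)

# The rotation center is a fixed point of the transform.
center_h = np.array([16.0, 16.0, 1.0])
print(M @ center_h)

# An extra translation can be added to the last column of the matrix.
M_shifted = M.copy()
M_shifted[:, 2] += (3, -2)  # Hypothetical (x_offset, y_offset) in pixels.
```

The last column plays the role of the offset $b$, which is why a translation can be added there after the rotation matrix is built.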

In [None]:
class AffineAugmenter(object):
    def __init__(self, max_angle=45, min_scale=0.9, max_scale=1.1, max_offset=0.1):
        self._max_angle = max_angle
        self._min_scale = min_scale
        self._max_scale = max_scale
        self._max_offset = max_offset
    
    def __call__(self, image):
        image = np.array(image)  # PIL -> Numpy.
        h, w, c = image.shape
        assert c == 3
        
        angle = random.random() * 2 * self._max_angle - self._max_angle
        scale = self._min_scale + random.random() * (self._max_scale - self._min_scale)
        x_offset = random.randint(-int(self._max_offset * w), int(self._max_offset * w))
        y_offset = random.randint(-int(self._max_offset * h), int(self._max_offset * h))

        # Use OpenCV to transform the image with the angle, scale, and offset defined above.
        # A gray background is preferred.
        #
        # The beginning of your code.

        new_image = ...
        
        # The end of your code.
        
        return Image.fromarray(new_image)  # Numpy -> PIL.


show_augmenter_results(AffineAugmenter())

# Assignment 3: Random crop

Random crop extracts a small fragment of the original image. Fill in the following code.

Cheat sheet:
```python
cv2.resize(src, dsize[, dst[, fx[, fy[, interpolation]]]])
```

Here, `dsize` is a tuple (w, h) with the desired output image size.

In [None]:
class CropAugmenter(object):
    def __init__(self, min_scale=0.8):
        self._min_scale = min_scale
    
    def __call__(self, image):
        image = np.array(image)  # PIL -> Numpy.
        h, w, c = image.shape
        assert c == 3
        scale = self._min_scale + random.random() * (1 - self._min_scale)
        new_w = int(scale * w)
        new_h = int(scale * h)
        x = random.randint(0, w - new_w)
        y = random.randint(0, h - new_h)

        # Create a fragment of the original image,
        # defined by the offset (x, y) and size (new_w, new_h).
        #
        # The beginning of your code.

        new_image = ...

        # The end of your code.
        
        return Image.fromarray(new_image)  # Numpy -> PIL.

show_augmenter_results(CropAugmenter())

# Assignment 4: Changing brightness & contrast

Apply the following transform (as in the lecture slides):

$
f(x) = c (x - 128) + 128 + b
$

Here $c$ is the contrast factor and $b$ is the brightness offset. Keep the ranges of the brightness and contrast parameters in mind, and pay attention to data types.
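
The data type warning matters because `uint8` arithmetic wraps around instead of saturating. The sketch below shows the pitfall and one common workaround on a toy array. The values of `c` and `b` are hypothetical, and `b` is given directly in pixel units, which is an assumption about the intended scale.

```python
import numpy as np

pixels = np.array([10, 128, 250], dtype=np.uint8)

# Pitfall: uint8 arithmetic wraps around modulo 256 instead of saturating.
wrapped = pixels + np.uint8(20)
print(wrapped)

# Safer: compute in float, clip to the valid range, then cast back to uint8.
c, b = 1.2, 20.0  # Hypothetical contrast factor and brightness offset (pixel units).
result = c * (pixels.astype(np.float32) - 128) + 128 + b
result = np.clip(result, 0, 255).astype(np.uint8)
print(result)
```

In the wrapped result the brightest pixel becomes one of the darkest, which is the kind of artifact the clipping step avoids.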

In [None]:
class BrightnessContrastAugmenter(object):
    def __init__(self, brightness=0.3, contrast=0.3):
        self._brightness = brightness
        self._contrast = contrast
    
    def __call__(self, image):
        image = np.array(image)  # PIL -> Numpy.
        h, w, c = image.shape
        assert c == 3
        brightness = 2 * (random.random() - 0.5) * self._brightness  # Relative offset in [-self._brightness, self._brightness], within [-1, 1].
        contrast = 1 + 2 * (random.random() - 0.5) * self._contrast  # Factor in [1 - self._contrast, 1 + self._contrast], within [0, 2].

        # Apply the brightness and contrast defined above.
        #
        # The beginning of your code.
        
        new_image = ...
        
        # The end of your code.
        
        assert new_image.dtype == np.uint8
        return Image.fromarray(new_image)  # Numpy -> PIL.
    
show_augmenter_results(BrightnessContrastAugmenter())

# Assignment 5: Gaussian blur

Cheat sheet:
```python
cv2.GaussianBlur(src, ksize, sigmaX[, dst[, sigmaY[, borderType]]])

# ksize: (w, h)
# sigmaX: scalar
```
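
To see what `ksize` and `sigmaX` control, the sketch below builds Gaussian weights by hand in NumPy. The helper `gaussian_kernel_1d` is a hypothetical illustration; OpenCV's own kernel computation may differ in details. Note that OpenCV expects odd kernel sizes.

```python
import numpy as np

def gaussian_kernel_1d(ksize, sigma):
    """Normalized 1D Gaussian weights for an odd kernel size (illustrative helper)."""
    assert ksize % 2 == 1, "Kernel size must be odd."
    x = np.arange(ksize) - ksize // 2  # Offsets from the kernel center.
    weights = np.exp(-(x ** 2) / (2 * sigma ** 2))
    return weights / weights.sum()

k1d = gaussian_kernel_1d(ksize=5, sigma=1.0)
k2d = np.outer(k1d, k1d)  # A 2D Gaussian kernel is separable.
print(k1d.round(3))
print(k2d.shape, k2d.sum())
```

A larger `sigma` spreads the weights more evenly and blurs more strongly, while `ksize` limits how far from the center the weights reach.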

In [None]:
class BlurAugmenter(object):
    def __init__(self, max_kernel=5):
        self._max_kernel = max_kernel
    
    def __call__(self, image):
        kernel = random.randint(0, self._max_kernel // 2) * 2 + 1
        if kernel == 1:
            return image
        image = np.array(image)  # PIL -> Numpy.
        h, w, c = image.shape
        assert c == 3
        
        # The beginning of your code.
        
        new_image = ...
        
        # The end of your code.
        
        return Image.fromarray(new_image)  # Numpy -> PIL.
    
show_augmenter_results(BlurAugmenter())

# Augmented training

In [None]:
class RandomAugmentation(object):
    def __init__(self, *augmenters):
        self._augmenters = list(augmenters)
        
    def __call__(self, image):
        augmenter = random.choice(self._augmenters)
        return augmenter(image)
    
augmenter = RandomAugmentation(FlipAugmenter(),
                               AffineAugmenter(),
                               CropAugmenter(),
                               BrightnessContrastAugmenter(),
                               BlurAugmenter())

show_augmenter_results(augmenter)

In [None]:
model = Module()
data = Data(augmenter=augmenter)
trainer = pl.Trainer(max_epochs=10,
                     logger=pl.loggers.TensorBoardLogger("lightning_logs", name="augmented"))
trainer.fit(model, data)

In [None]:
%load_ext tensorboard
%tensorboard --logdir ./lightning_logs/

**Discussion**
* Training accuracy decreased and became similar to validation accuracy.
* There is no overfitting. Validation loss and accuracy are even better than the training ones, because augmentations are applied only during training and BatchNorm behaves differently in evaluation mode.
* Validation accuracy increased compared to the default training. Augmentations improve the final accuracy.

# Simple augmentation with Albumentations

The Albumentations library provides a ready-to-use set of popular image augmentations.

Albumentations architecture slightly differs from torchvision:
* In torchvision each transformation accepts an image and returns an image.
* In albumentations each transformation accepts kwargs and returns a dictionary.
* This way, albumentations can also process labels, e.g., segmentation maps.
* Albumentations works with NumPy arrays rather than PIL images.
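
The dictionary-based calling convention can be illustrated without the library itself. The class `ToyVerticalFlip` below is a hypothetical stand-in that only mimics the kwargs-in, dict-out interface with `image` and `mask` keys; it always flips and is not an Albumentations transform.

```python
import numpy as np

class ToyVerticalFlip:
    """Hypothetical stand-in that mimics the kwargs-in, dict-out convention."""

    def __call__(self, **data):
        result = dict(data)
        # Flip every provided target the same way, so image and mask stay aligned.
        for key in ("image", "mask"):
            if key in result:
                result[key] = np.ascontiguousarray(result[key][::-1])
        return result

image = np.arange(2 * 3 * 3, dtype=np.uint8).reshape(2, 3, 3)  # (H, W, C).
mask = np.array([[0, 1, 2], [3, 4, 5]], dtype=np.uint8)  # (H, W).

out = ToyVerticalFlip()(image=image, mask=mask)
print(out["mask"])

# Adapter to the torchvision convention: image in, image out.
as_torchvision = lambda img: ToyVerticalFlip()(image=np.asarray(img))["image"]
```

The adapter on the last line follows the same pattern that is needed to plug a dictionary-based augmenter into a torchvision-style pipeline.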

In [None]:
import albumentations

augmenter = albumentations.Compose([
    albumentations.ShiftScaleRotate(rotate_limit=0.25, p=0.7),  # Note: rotate_limit is in degrees.
    albumentations.RandomBrightnessContrast(p=0.4),
    albumentations.RandomGamma(p=0.4),
    albumentations.Blur(blur_limit=3, p=0.1),
    albumentations.GaussNoise((10, 100), p=0.2),
    albumentations.HorizontalFlip(p=0.5)
])

show_augmenter_results(lambda image: augmenter(image=np.array(image))["image"])

# Homework (optional)
* Try adding a gamma correction augmenter and an additive noise augmenter.
* Adjust the hyperparameters to achieve better test set quality.