# 3.3.2 Image Augmentation: Cutout, MixUp and CutMix
By Zac Todd

This tutorial covers the image augmentations introduced in DeVries and Taylor's [Cutout](https://arxiv.org/abs/1708.04552), Zhang et al.'s [MixUp](https://arxiv.org/abs/1710.09412) and Yun et al.'s [CutMix](https://arxiv.org/abs/1905.04899).

In [None]:
import os
import numpy as np
import cv2
from PIL import Image, ImageOps, ImageEnhance

IMAGES_DIR = f'{os.getcwd()}/resources'
IMAGE_1 = f'{IMAGES_DIR}/dog.jpg'
IMAGE_2 = f'{IMAGES_DIR}/cat.jpg'
IMAGE_3 = f'{IMAGES_DIR}/cat_on_dog.jpg'

A decorator that lets functions written for `np.ndarray` inputs accept `PIL.Image` inputs, and converts their output back to a `PIL.Image`.

In [None]:
import functools

def _PIL_NUMPY(func):
    """Convert PIL.Image arguments to np.ndarray, call func, and return a PIL.Image."""
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        new_args = [np.asarray(arg) if isinstance(arg, Image.Image) else arg for arg in args]
        new_kwargs = {k: (np.asarray(arg) if isinstance(arg, Image.Image) else arg) for k, arg in kwargs.items()}
        out_array = func(*new_args, **new_kwargs)
        # Cast back to 8-bit pixels; float results (e.g. from mixup) are truncated.
        out_image = Image.fromarray(np.uint8(out_array))
        return out_image
    return wrapper

## Cutout
Cutout removes randomly placed square regions of an image by setting their pixels to zero (black).

In [None]:
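@_PIL_NUMPY
def cutout(image, holes, length):
    """Zero out `holes` square patches of side `length` at random positions."""
    output = image.copy()  # np.asarray of a PIL image is read-only, so work on a copy
    h, w, _ = output.shape
    for _ in range(holes):
        # Top-left corner chosen so the whole patch lies inside the image (+1 allows touching the border).
        x0, y0 = np.random.randint(w - length + 1), np.random.randint(h - length + 1)
        output[y0:y0 + length, x0:x0 + length] = 0
    return output

img = Image.open(IMAGE_1)
cutout_image = cutout(img, 10, 500)
cutout_image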
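The cell above keeps every square fully inside the image. As we understand the reference implementation accompanying the Cutout paper, the centre of each square is instead sampled anywhere in the image and the square is clipped at the border, so a patch can lie partly outside the image. The sketch below shows that variant on a plain `np.ndarray`; `cutout_clipped` and its `rng` parameter are names introduced here for illustration, and the synthetic white image stands in for a real photo.

```python
import numpy as np

def cutout_clipped(image, holes, length, rng=None):
    """Hypothetical Cutout variant: each square's centre is sampled anywhere in the
    image and the square is clipped at the border, so it may be partly outside."""
    rng = np.random.default_rng() if rng is None else rng
    output = image.copy()
    h, w = output.shape[:2]
    for _ in range(holes):
        cy, cx = rng.integers(h), rng.integers(w)  # centre anywhere in the image
        # Clip the square to the image bounds.
        y0, y1 = max(cy - length // 2, 0), min(cy + length // 2, h)
        x0, x1 = max(cx - length // 2, 0), min(cx + length // 2, w)
        output[y0:y1, x0:x1] = 0
    return output

# Usage on a synthetic all-white image: at most length * length pixels are zeroed.
img = np.full((64, 64, 3), 255, dtype=np.uint8)
out = cutout_clipped(img, holes=1, length=16, rng=np.random.default_rng(0))
print((out[..., 0] == 0).mean())  # fraction of pixels that were cut out
```

Because clipping can shrink a patch, the amount of the image removed varies from sample to sample, unlike the fixed-size patches in the cell above.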

Now, instead of filling the cut-out regions with black, try filling them with uniform noise.
Hint: look at *np.random.randint*.

In [None]:
def noisy_cutout(image, holes, length):
    output = image.copy()
    h, w, _ = output.shape
    for _ in range(holes):
        x0, y0 = np.random.randint(w - length + 1), np.random.randint(h - length + 1)
        output[y0:y0 + length, x0:x0 + length] = ...
    return output
    
img = Image.open(IMAGE_2)
img = np.asarray(img)
noisy_cutout_image = noisy_cutout(img, 10, 500)
Image.fromarray(noisy_cutout_image)

## Mixup
In practice, Mixup is applied to both the images and their one-hot encoded labels. In this tutorial we only look at mixing the images, but the same weighted sum is applied to the labels.
Mixup samples a mixing weight $\lambda \sim \mathrm{Beta}(\alpha, \alpha)$ and uses it to form a weighted sum of two images and of their labels: $\tilde{x} = \lambda x_1 + (1 - \lambda) x_2$ and $\tilde{y} = \lambda y_1 + (1 - \lambda) y_2$.
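As a small worked example of Mixup applied to both an image and its label, the sketch below mixes two toy "images" and their one-hot labels with the same λ-weighted sum. The helper `mixup_pair` and the toy data are hypothetical, and the sketch assumes labels are one-hot float vectors.

```python
import numpy as np

def mixup_pair(x1, y1, x2, y2, alpha, rng=None):
    """Hypothetical helper: mix two samples and their one-hot labels with lam ~ Beta(alpha, alpha)."""
    rng = np.random.default_rng() if rng is None else rng
    lam = rng.beta(alpha, alpha) if alpha > 0 else 1.0
    x = lam * x1.astype(np.float64) + (1 - lam) * x2.astype(np.float64)
    y = lam * y1 + (1 - lam) * y2  # labels use the same weight as the images
    return x, y, lam

# Toy data: two 2x2 grey "images" and one-hot labels over 3 classes.
x_dog, y_dog = np.full((2, 2), 200.0), np.array([1.0, 0.0, 0.0])
x_cat, y_cat = np.full((2, 2), 100.0), np.array([0.0, 1.0, 0.0])
x_mix, y_mix, lam = mixup_pair(x_dog, y_dog, x_cat, y_cat, alpha=0.5, rng=np.random.default_rng(0))
print(lam, y_mix)  # the mixed label is [lam, 1 - lam, 0]
```

The mixed label still sums to 1, so it can be used directly as a soft target with a cross-entropy loss.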

Run the cell below a few times and experiment with different values of alpha.

In [None]:
@_PIL_NUMPY
def mixup(image1, image2, alpha):
    lam = np.random.beta(alpha, alpha) if alpha > 0 else 1
    output = lam * image1 + (1 - lam) * image2
    return output

img1 = Image.open(IMAGE_1)
img2 = Image.open(IMAGE_2)
mixup_image = mixup(img1, img2, 0.5)
mixup_image
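To see what changing alpha does, it helps to look at the distribution of λ directly. The sketch below draws samples from Beta(α, α) for a few α values and reports how often λ lands near 0 or 1, i.e. how often one image almost completely dominates the mix. Small α (below 1) pushes λ towards 0 or 1, α = 1 gives a uniform λ, and larger α concentrates λ around 0.5. The α values and the 0.1/0.9 thresholds are arbitrary choices for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
extreme_fraction = {}
for alpha in (0.1, 0.5, 1.0, 4.0):
    lams = rng.beta(alpha, alpha, size=10_000)
    # Share of samples where one image dominates the mix (lam near 0 or 1).
    extreme_fraction[alpha] = float(np.mean((lams < 0.1) | (lams > 0.9)))
    print(f"alpha={alpha}: mean lam={lams.mean():.2f}, near 0 or 1: {extreme_fraction[alpha]:.2f}")
```

The mean of λ stays around 0.5 for every α because Beta(α, α) is symmetric; only the spread changes.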

The `mixup` function above only works on images of the same size. Rewrite it so that it works on images of different sizes.
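One way to make two images the same size is to resize one of them to the other's height and width. In practice you would probably use `cv2.resize` (note that its `dsize` argument is `(width, height)`) or `PIL.Image.resize`; the dependency-free sketch below, with the hypothetical helper `resize_nearest`, only illustrates the idea with nearest-neighbour sampling on NumPy arrays.

```python
import numpy as np

def resize_nearest(image, height, width):
    """Hypothetical helper: nearest-neighbour resize of an (H, W, ...) array to (height, width, ...)."""
    h, w = image.shape[:2]
    rows = np.arange(height) * h // height  # source row for each output row
    cols = np.arange(width) * w // width    # source column for each output column
    # Broadcast the index arrays to (height, width); trailing channel axes are kept.
    return image[rows[:, None], cols[None, :]]

a = np.zeros((4, 6, 3), dtype=np.uint8)
b = np.ones((2, 3, 3), dtype=np.uint8)
b_big = resize_nearest(b, *a.shape[:2])
print(b_big.shape)  # (4, 6, 3)
```

After resizing, `a` and `b_big` have identical shapes, so the λ-weighted sum used by `mixup` broadcasts element by element.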

In [None]:
@_PIL_NUMPY
def resized_mixup(image1, image2, alpha):
    resized_image1 = ...
    resized_image2 = ...
    lam = np.random.beta(alpha, alpha) if alpha > 0 else 1
    output = lam * resized_image1 + (1 - lam) * resized_image2
    return output

img1 = Image.open(IMAGE_1)
img3 = Image.open(IMAGE_3)
resized_mixup_image = resized_mixup(img1, img3, 0.5)
resized_mixup_image

## CutMix
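CutMix is like Cutout in that it removes a region from an image, and like Mixup in that it combines two images (and their labels) into a new sample. The removed region of the first image is filled with the corresponding region of the second image, and the labels are mixed in proportion to the area each image occupies in the result.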
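In the CutMix paper the label weight is not the sampled λ itself but the fraction of pixels that actually come from each image. The sketch below, modelled on our reading of the paper's approach, samples the box centre anywhere in the image, clips the box at the border and recomputes λ from the clipped area before mixing two one-hot labels. `cutmix_box` is a hypothetical helper, and λ is drawn from U(0, 1) as in the cell below (the paper draws it from Beta(α, α), which is uniform when α = 1).

```python
import numpy as np

def cutmix_box(h, w, lam, rng):
    """Hypothetical helper: sample a box covering about (1 - lam) of an h x w image.

    The centre is sampled anywhere in the image and the box is clipped at the border,
    so the returned lam is recomputed from the area that is actually pasted."""
    cut_ratio = np.sqrt(1.0 - lam)
    cut_h, cut_w = int(h * cut_ratio), int(w * cut_ratio)
    cy, cx = rng.integers(h), rng.integers(w)
    y0, y1 = max(cy - cut_h // 2, 0), min(cy + cut_h // 2, h)
    x0, x1 = max(cx - cut_w // 2, 0), min(cx + cut_w // 2, w)
    adjusted_lam = 1.0 - (y1 - y0) * (x1 - x0) / (h * w)  # share of pixels kept from image 1
    return (y0, y1, x0, x1), adjusted_lam

rng = np.random.default_rng(0)
(y0, y1, x0, x1), lam = cutmix_box(32, 32, lam=rng.uniform(), rng=rng)
y_dog, y_cat = np.array([1.0, 0.0]), np.array([0.0, 1.0])
y_mix = lam * y_dog + (1 - lam) * y_cat  # label weights follow the pasted area
print((y0, y1, x0, x1), y_mix)
```

Clipping and integer rounding can only shrink the box, so the adjusted λ is never smaller than the sampled one.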

In [None]:
@_PIL_NUMPY
def cutmix(image1, image2):
    resized_image1 = ...
    resized_image2 = ...

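    h, w, _ = resized_image1.shape

    lam = np.random.uniform()
    width_factor = np.sqrt(1 - lam)
    # Patch size must be a whole number of pixels.
    xl, yl = int(w * width_factor), int(h * width_factor)
    # +1 lets the patch touch the right/bottom border and keeps randint's bound positive.
    x0, y0 = np.random.randint(w - xl + 1), np.random.randint(h - yl + 1)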
    
    output = ...
    output[...] = ...
    return output

img1 = Image.open(IMAGE_1)
img2 = Image.open(IMAGE_2)
cutmix_img = cutmix(img1, img2)
cutmix_img

Make sure that your implementation works for images of different sizes.

In [None]:
img1 = Image.open(IMAGE_1)
img3 = Image.open(IMAGE_3)
cutmix_img = cutmix(img1, img3)
cutmix_img