# A closer look at the practice of Data Augmentation

First, some preparation (imports and data loading).

In [None]:
%matplotlib inline

# General imports
import matplotlib.pyplot as plt
import pathlib
import skimage
import skimage.transform
import skimage.viewer
import pandas as pd
import numpy as np
import time
import keras.utils.np_utils
from tqdm import tqdm

# Setup to show interactive jupyter widgets
from IPython.display import Image, display
from ipywidgets import interact, fixed
import ipywidgets as widgets
def imgplotList(i,data):
    """Display element ``i`` of ``data`` in a new 10x10 inch figure.

    Nearest-neighbour interpolation is used so individual pixels stay visible.
    Intended as a callback for ``interact`` (``data`` passed via ``fixed``).
    """
    plt.figure(figsize=(10, 10))
    selected = data[i]
    plt.imshow(selected, interpolation="nearest")
    plt.show()

In [None]:
# Define where datasets are located
dataset_directory = pathlib.Path("..")/"datasets"/"final split"

# Define which datasets we should consider.
# Each dataset is a directory within dataset_directory
# and must contain three subdirectories: (c0, c1, c2) for (rock, paper, scissors).
# dnames = ["testing.small","Dpolimi","D1.small","D2.small","D3.small","D4.small","D5.small","D6.small"]
dnames = ["Dpolimi"]

# Class index -> human-readable name; index i is stored in subdirectory c<i>.
names = ["rock", "paper", "scissors"]

# Now check the data: map every class directory to its integer label.
ddirs=[dataset_directory/dn for dn in dnames] # directories of the dataset
cdirs={ddir/"c{}".format(label): label for ddir in ddirs for label in range(len(names))}
for cdir,cdir_class in cdirs.items():
    # Fail early, and say which directory is missing.
    assert cdir.exists(), "missing class directory {}".format(cdir)
    print("Found directory {} containing class {}".format(cdir,names[cdir_class]))

In [None]:
import warnings

import skimage.io  # `import skimage` alone does not guarantee the io submodule is loaded

imagesize = 200  # side length (pixels) of the square images stored in `dataset`

dataset=[]

for cdir,cn in tqdm(list(cdirs.items())):
    # sorted() makes the load order (and thus the row order of `dataset`) filesystem-independent
    for f in sorted(cdir.glob("*")):
        try:
            im=skimage.io.imread(f)
        except (OSError, ValueError) as e:
            warnings.warn("ignoring {} due to exception {}".format(f,str(e)))
            continue

        if im.ndim == 2:
            # Grayscale file: replicate the single channel so it gets 3 channels
            # (presumably matching the RGB images in the dataset — verify).
            im = np.dstack([im, im, im])

        h,w=im.shape[0:2] # height, width
        sz=min(h,w)
        # Central square crop; only the first two axes are sliced so any channel axis is kept.
        im=im[(h//2-sz//2):(h//2+sz//2),(w//2-sz//2):(w//2+sz//2)]
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            # Resize to imagesize x imagesize, whatever the original resolution, and convert back to uint8.
            im=skimage.img_as_ubyte(skimage.transform.resize(im,(imagesize,imagesize)))

        dataset.append({
            "file":f,
            "label":cn,
            "image":im})

print("Done")

We make a pandas dataframe for the dataset, and create a "dn" field containing the name of the dataset each image comes from (i.e. the name of the dataset directory that contains the image's class subdirectory).

In [None]:
dataset = pd.DataFrame(dataset)
# Files live in <dataset dir>/<class dir>/<file>, so the second-to-last part of the
# file's parent directory is the dataset name.
dataset["dn"] = dataset["file"].map(lambda path: path.parent.parts[-2])

## A closer look at the practice of data augmentation

In [None]:
# Pick one random image from the dataset and display it.
im = dataset.sample(1)["image"].iloc[0]
plt.imshow(im)

In [None]:
ims = dataset.sample(3)["image"].tolist()  # three random images used by the widgets below

In [None]:
def prepare(im, sz, normalize, rot):
    """Return a resized, rotated and optionally normalized copy of ``im``.

    Args:
        im: input image array.
        sz: side length (pixels) of the square output.
        normalize: if true, standardize intensities, then map them to
            mean 0.5 / std 0.2 and clip to [0, 1].
        rot: number of 90-degree rotations (``k`` of ``np.rot90``).
    """
    imr = skimage.transform.resize(im, (sz, sz))
    imr = np.rot90(imr, k=rot)
    if normalize:
        std = np.std(imr)
        # A constant image has zero std; skip the scaling instead of producing NaNs.
        imr = np.clip((imr - np.mean(imr)) / (std if std > 0 else 1) * 0.2 + 0.5, 0, 1)
    return imr

# interact() only maps keyword arguments that are parameters of f, so a bare
# `continuous_update=False` is ignored; it must be set on each slider instead.
@interact(sz=widgets.IntSlider(min=2,max=32,continuous_update=False),
          normalize = widgets.ToggleButton(),
          # k=4 is a full turn (same as k=0), so stop at 3
          rot=widgets.IntSlider(min=0,max=3,continuous_update=False))
def f(sz,normalize,rot):
    """Show every image of ``ims`` side by side after ``prepare``."""
    # squeeze=False keeps a 2-D axes array even when there is a single image
    fig,axs = plt.subplots(ncols=len(ims),figsize=(18,6),squeeze=False)
    for im,ax in zip(ims,axs[0]):
        ax.imshow(prepare(im,sz,normalize,rot))
        

## Let's now edit the cell above to make some basic augmentations / preprocessing
First, we visualize the effect of normalization.

Then, we implement:
- resizing
- 90 deg rotations
- flipping

In [None]:
def transform(im,sz,rot90times,flip,normalize):
    """Resize, optionally mirror, rotate and optionally normalize an image.

    Args:
        im: input image array.
        sz: side length (pixels) of the square output.
        rot90times: number of 90-degree rotations (``k`` of ``np.rot90``).
        flip: if true, mirror left-right (before rotating).
        normalize: if true, standardize intensities, then map them to
            mean 0.5 / std 0.2 and clip to [0, 1].

    Returns:
        the transformed image.
    """
    imr = skimage.transform.resize(im,(sz,sz))
    if flip:
        imr = np.fliplr(imr)
    imr = np.rot90(imr, k=rot90times)
    if normalize:
        std = np.std(imr)
        # A constant image has zero std; avoid 0/0 -> NaN.
        imr = np.clip((imr-np.mean(imr))/(std if std > 0 else 1) * 0.2 + 0.5,0,1)
    return imr

# interact() only maps keyword arguments that are parameters of f, so a bare
# `continuous_update=False` is ignored; it must be set on each slider instead.
@interact(sz = widgets.IntSlider(min=1,max=60,value=30,continuous_update=False),
          # k=4 is a full turn (same as k=0), so stop at 3
          rot90times = widgets.IntSlider(min=0,max=3,continuous_update=False),
          flip = widgets.ToggleButton(),
          normalize = widgets.ToggleButton())
def f(sz,rot90times,flip,normalize):
    """Show every image of ``ims`` side by side after ``transform``."""
    # squeeze=False keeps a 2-D axes array even when there is a single image
    fig,axs = plt.subplots(ncols=len(ims),figsize=(18,6),squeeze=False)
    for im,ax in zip(ims,axs[0]):
        ax.imshow(transform(im,sz,rot90times,flip,normalize))

## Geometric transformations

Geometric transformations on images require manipulating transformation objects.

In [None]:
# Chain: shift by (1, 2), rotate by 60 degrees, shift by (1, 2) again.
shift = skimage.transform.SimilarityTransform(translation=(1, 2))
spin = skimage.transform.SimilarityTransform(rotation=np.pi / 3)
tf = shift + spin + shift
tf([[100, 0]])

Then, we can call the transform object as a function on a list of (x, y) points.

Transformations can also be chained.

In [None]:
# Translate, shrink 10x, translate again; the chain is stored as a single matrix.
move_first = skimage.transform.SimilarityTransform(translation=(20, 30))
shrink = skimage.transform.SimilarityTransform(scale=0.1)
move_last = skimage.transform.SimilarityTransform(translation=(5, 5))
tf = move_first + shrink + move_last
tf.params

In [None]:
probe_points = [[0, 0], [1, 1]]  # map (0, 0) and (1, 1) through the chained transform
tf(probe_points)

In [None]:
# NOTE(review): this is a copy of the `transform` cell in the preprocessing section;
# consider keeping a single definition.
def transform(im,sz,rot90times,flip,normalize):
    """Resize, optionally mirror, rotate and optionally normalize an image.

    Args:
        im: input image array.
        sz: side length (pixels) of the square output.
        rot90times: number of 90-degree rotations (``k`` of ``np.rot90``).
        flip: if true, mirror left-right (before rotating).
        normalize: if true, standardize intensities, then map them to
            mean 0.5 / std 0.2 and clip to [0, 1].

    Returns:
        the transformed image.
    """
    imr = skimage.transform.resize(im,(sz,sz))
    if flip:
        imr = np.fliplr(imr)
    imr = np.rot90(imr, k=rot90times)
    if normalize:
        std = np.std(imr)
        # A constant image has zero std; avoid 0/0 -> NaN.
        imr = np.clip((imr-np.mean(imr))/(std if std > 0 else 1) * 0.2 + 0.5,0,1)
    return imr

# interact() only maps keyword arguments that are parameters of f, so a bare
# `continuous_update=False` is ignored; it must be set on each slider instead.
@interact(sz = widgets.IntSlider(min=1,max=60,value=30,continuous_update=False),
          # k=4 is a full turn (same as k=0), so stop at 3
          rot90times = widgets.IntSlider(min=0,max=3,continuous_update=False),
          flip = widgets.ToggleButton(),
          normalize = widgets.ToggleButton())
def f(sz,rot90times,flip,normalize):
    """Show every image of ``ims`` side by side after ``transform``."""
    # squeeze=False keeps a 2-D axes array even when there is a single image
    fig,axs = plt.subplots(ncols=len(ims),figsize=(18,6),squeeze=False)
    for im,ax in zip(ims,axs[0]):
        ax.imshow(transform(im,sz,rot90times,flip,normalize))

## Applying a transform to an image

```skimage.transform.warp(image, inverse_map, map_args={}, output_shape=None, order=1, mode='constant', cval=0.0, clip=True, preserve_range=False)```

Try zooming and rotating an image

In [None]:
h,w = im.shape[0], im.shape[1]  # image height (rows) and width (columns)

# interact() only maps keyword arguments that are parameters of f, so a bare
# `continuous_update=False` is ignored; it must be set on each slider instead.
@interact(scale = widgets.FloatSlider(min=0.1,max=11,value=1,continuous_update=False),
          rot = widgets.FloatSlider(min=-np.pi,max=+np.pi,value=0,continuous_update=False))
def f(scale,rot):
    """Zoom by ``scale`` and rotate by ``rot`` radians about the image center; show before/after."""
    # SimilarityTransform works in (x, y) = (column, row) coordinates, so the center is
    # (w/2, h/2): move it to the origin, scale and rotate, then move it back.
    tf = (skimage.transform.SimilarityTransform(translation=(-w/2, -h/2)) +
          skimage.transform.SimilarityTransform(scale=scale) +
          skimage.transform.SimilarityTransform(rotation=rot) +
          skimage.transform.SimilarityTransform(translation=(+w/2, +h/2)))
    # warp() expects the map from output to input coordinates, hence the inverse
    imt = skimage.transform.warp(im, tf.inverse)
    fig,(ax0,ax1) = plt.subplots(ncols=2,figsize=(18,8))
    ax0.imshow(im)
    ax1.imshow(imt)

In [None]:
# interact() only maps keyword arguments that are parameters of f, so a bare
# `continuous_update=False` is ignored; it must be set on each slider instead.
@interact(scale = widgets.FloatSlider(min=0.1,max=3,value=1,continuous_update=False),
          rotation = widgets.FloatSlider(min=-np.pi,max=+np.pi,value=0,continuous_update=False))
def f(scale,rotation):
    """Scale by ``scale`` and rotate by ``rotation`` radians about the image center; show before/after."""
    h, w = im.shape[:2]  # rows, columns of the displayed image
    # SimilarityTransform uses (x, y) = (column, row) coordinates, so the center is
    # (w/2, h/2), not (h/2, w/2) — the two only coincide for square images.
    tf = (  skimage.transform.SimilarityTransform(translation = [-w/2,-h/2]) +
            skimage.transform.SimilarityTransform(scale = scale) +
            skimage.transform.SimilarityTransform(rotation = rotation) +
            skimage.transform.SimilarityTransform(translation = [+w/2,+h/2]))
    # warp() expects the map from output to input coordinates, hence the inverse
    imt = skimage.transform.warp(im, tf.inverse)

    fig,(ax0,ax1) = plt.subplots(ncols=2,figsize=(18,8))
    ax0.imshow(im)
    ax1.imshow(imt)

## The piecewise affine transform for elastic deformations
![image.png](attachment:image.png)

In [None]:
# interact() only maps keyword arguments that are parameters of f, so a bare
# `continuous_update=False` is ignored; it must be set on each slider instead.
@interact(xdisp = widgets.FloatSlider(min=0.1,max=0.9,value=0.5,continuous_update=False),
          ydisp = widgets.FloatSlider(min=0.1,max=0.9,value=0.5,continuous_update=False))
def f(xdisp,ydisp):
    """Move the image center to (xdisp, ydisp), given as fractions of width/height, with a
    piecewise affine warp that keeps the four corners fixed; show before/after."""
    h, w = im.shape[:2]  # rows, columns of the displayed image
    tf = skimage.transform.PiecewiseAffineTransform()
    # Control points as (x, y) fractions of the image size: four corners plus the center.
    points =  np.array([[0,0],    [0,1],    [1,1],    [1,0],    [0.5,0.5]])
    tpoints=  np.array([[0,0],    [0,1],    [1,1],    [1,0],    [xdisp,ydisp]])
    # Scale to pixel coordinates: x by width, y by height.
    tf.estimate(points * np.array([[w,h]]), tpoints * np.array([[w,h]]))
    # warp() expects the map from output to input coordinates, hence the inverse
    imt = skimage.transform.warp(im, tf.inverse)

    fig,(ax0,ax1) = plt.subplots(ncols=2,figsize=(18,8))
    ax0.imshow(im)
    ax1.imshow(imt)

## Assignment 1

In [None]:
# 1. Download mnist
from keras.datasets import mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Only a cell's last expression is displayed, so report all shapes in one tuple.
x_train.shape, y_train.shape, x_test.shape, y_test.shape

In [None]:
# 2. Sample the first 5 training elements from each digit
for digit in range(10):
    first_five = np.flatnonzero(y_train == digit)[:5]  # indices of the first 5 samples of this digit
    print("Use only these elements for digit " + str(digit), first_five)

In [None]:
# 3. Train a decent CNN with a simple lenet-like architecture.  Augmentation will help.

## Assignment 2
Apply everything you have learned to the new packages dataset (see the course website), which includes 2000+ images in 4 classes. Compare a feature-based approach (from the first week) to a CNN-based approach.

The key to CNN success in this context is doing proper data augmentation!