In [None]:
# Importing required libraries
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import os
import deepdish as dd

from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from keras.utils import np_utils, to_categorical, plot_model
from keras.callbacks import ReduceLROnPlateau
from keras.optimizers import Adam, RMSprop, SGD
from keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.layers.experimental import preprocessing
from tensorflow.keras.models import Sequential

In [None]:
# Function definition: balance_and_augmentation_data
## Given an unbalanced dataset, this function balances it by applying data
## augmentation. Given a target number of images per class, the function
## returns a dataset in which every category has that same number of samples.

def balance_and_augmentation_data(X, Y, num_muestras_clase):
  """Balance an image dataset so that every class has the same sample count.

  For each class found in ``Y``, up to ``num_muestras_clase`` original images
  are copied into the output. A class with fewer originals is topped up with
  augmented copies of originals of that class, drawn at random with
  replacement. A class with more originals is truncated to its first
  ``num_muestras_clase`` images (in dataset order).

  Args:
    X: Sequence or array of images, all sharing one (rows, cols, depth) shape.
      Images are stored as uint8 in the output, so pixel values are
      presumably in 0..255 -- TODO confirm against the caller.
    Y: pandas object exposing ``to_numpy()`` with one label per image. If it
      is two-dimensional (e.g. a single-column DataFrame) only the first
      column is used. Labels are stored as uint8 in the output, so they must
      be integers in 0..255.
    num_muestras_clase: Number of samples per class in the balanced dataset.

  Returns:
    Tuple ``(X_bal, Y_bal)`` where ``X_bal`` is a uint8 array of shape
    ``(num_classes * num_muestras_clase, rows, cols, depth)`` and ``Y_bal``
    is a uint8 array with the matching labels. Samples are grouped by class
    in ascending label order, originals first, augmented images after.
  """

  X = np.array(X)
  labels = Y.to_numpy()
  # A single-column DataFrame yields an (n, 1) array; keep the label column.
  if labels.ndim > 1:
    labels = labels[:, 0]

  # Basic information: image size and number of original samples per class,
  # ordered by label so the output layout is deterministic.
  rows, cols, depth = X[0].shape
  class_counts = pd.Series(labels).value_counts().sort_index()
  num_clases = len(class_counts)

  # Pre-allocate the balanced dataset.
  total = num_clases * num_muestras_clase
  X_bal = np.zeros((total, rows, cols, depth), dtype=np.uint8)
  Y_bal = np.zeros((total,), dtype=np.uint8)

  # Pixel-level augmentation (blur + noise). Geometric transforms are applied
  # by the Keras pipeline defined right after.
  # NOTE(review): `iaa` is presumably `imgaug.augmenters`; it is not imported
  # in the visible cells -- confirm it is imported before this runs.
  seq = iaa.Sequential([
      iaa.Sometimes(0.5, iaa.GaussianBlur(sigma=(0, 1.0))),
      iaa.AdditiveGaussianNoise(loc=0, scale=(0.0, 0.05*255), per_channel=0.5),
  ], random_order=True)

  # Geometric / contrast augmentation.
  img_augmentation = Sequential(
      [
          preprocessing.RandomRotation(factor=0.15),
          preprocessing.RandomTranslation(height_factor=0.1, width_factor=0.1),
          preprocessing.RandomFlip(),
          preprocessing.RandomContrast(factor=0.1),
      ],
      name="img_augmentation",
  )

  # Next free slot in the balanced dataset.
  next_idx = 0

  for cat, n_orig in class_counts.items():
    n_orig = int(n_orig)

    print('Data Balance for Class {} with {} original images'.format(cat,
                                                                     n_orig))

    # Positions of this class's original images in the input dataset.
    idx_imgs_cat = np.where(labels == cat)[0]

    # Copy the originals, truncating when the class is over-represented.
    n_keep = min(n_orig, num_muestras_clase)
    X_bal[next_idx:next_idx + n_keep] = X[idx_imgs_cat[:n_keep]]
    Y_bal[next_idx:next_idx + n_keep] = cat
    next_idx += n_keep

    # Top up under-represented classes with augmented images.
    n_imgs_generate = num_muestras_clase - n_orig
    if n_imgs_generate > 0:
      for src in np.random.randint(0, n_orig, n_imgs_generate):
        modified_image = seq.augment_image(X[idx_imgs_cat[src]])
        # training=True: the Keras random layers are documented as identity
        # at inference time, so request the augmentation explicitly.
        augmented = img_augmentation(tf.expand_dims(modified_image, axis=0),
                                     training=True)
        # The pipeline returns floats that may be fractional or fall outside
        # 0..255; round and clip so the uint8 cast cannot wrap around.
        X_bal[next_idx] = np.clip(np.rint(np.asarray(augmented)[0]), 0, 255)
        Y_bal[next_idx] = cat
        next_idx += 1

  return X_bal, Y_bal