# IMAGE DATABASE AUGMENTATION OF CLAHE DATABASE
### Code to build the augmented image database
#### by Luis Soenksen
#### Last Update: 01/08/2018

----------------------------

## IMAGE AUGMENTATION

In [None]:
"""
 PRO IMAGE DATABASE AUGMENTATION
 ---------------------------------
 by Luis R Soenksen
 Last Update: 2017/04/23
 Adapted from script for offline image augmentation using Keras
"""

import glob
import cv2
import os
import shutil
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from keras.preprocessing.image import ImageDataGenerator 


#IMAGE RANDOMIZATION AND AUGMENTATION HELPER FUNCTIONS
# Print iterations progress
def printProgressBar(iteration, total, prefix='', suffix='', decimals=1, length=100, fill='█'):
    """
    Call in a loop to draw a single-line terminal progress bar.

    The line is written with a leading and a trailing carriage return so that
    successive calls overwrite each other. Once ``iteration == total`` a
    newline is printed so later output starts on a fresh line.

    @params:
        iteration   - Required  : current iteration (Int)
        total       - Required  : total iterations (Int); 0 is rendered as a completed bar
        prefix      - Optional  : prefix string (Str)
        suffix      - Optional  : suffix string (Str)
        decimals    - Optional  : positive number of decimals in percent complete (Int)
        length      - Optional  : character length of bar (Int)
        fill        - Optional  : bar fill character (Str)
    """
    if total:
        fraction = iteration / float(total)
        filledLength = int(length * iteration // total)
    else:
        # Nothing to do counts as "done"; this avoids a ZeroDivisionError
        # when the caller loops over an empty collection.
        fraction = 1.0
        filledLength = length
    percent = ("{0:." + str(decimals) + "f}").format(100 * fraction)
    bar = fill * filledLength + '-' * (length - filledLength)
    print('\r%s |%s| %s%% %s' % (prefix, bar, percent, suffix), end='\r')
    # Print New Line on Complete
    if iteration == total:
        print()

        
#Definition of folder tree structure for converted files
inputpath = './data/single_lesion_database/clahe_data_randomized/'
outputpath ='./data/single_lesion_database/augmented_clahe_data_randomized/'

#Creation of required folders
# makedirs (unlike mkdir) also creates missing parent folders, and exist_ok
# makes re-running the cell harmless when the folders are already there.
os.makedirs(outputpath, exist_ok=True)

# Mirror every folder of the input tree under the output folder.
# relpath is used instead of slicing by len(inputpath) so the result does not
# depend on inputpath being written with a trailing slash.
for dirpath, dirnames, filenames in os.walk(inputpath):
    structure = os.path.join(outputpath, os.path.relpath(dirpath, inputpath))
    os.makedirs(structure, exist_ok=True)
        
# Augmentation for training
# Every option handed to the Keras generator is listed once here, so the
# augmentation recipe can be read (and tweaked) in a single place.
augmentation_kwargs = dict(
    rescale=1./255,              # multiply pixel values by 1/255
    rotation_range=30,           # random rotations
    width_shift_range=0.1,       # random horizontal shifts
    height_shift_range=0.1,      # random vertical shifts
    channel_shift_range=0.1,     # random channel shifts
    shear_range=0.2,             # random shear
    zoom_range=0.2,              # random zoom
    horizontal_flip=True,        # random left/right flips
    vertical_flip=True,          # random up/down flips
    #preprocessing_function=preprocess_img, # Other desired transformations are possible (more CLAHE for example)
    fill_mode='reflect',         # 'nearest' copies last pixels and extends them
)
train_datagen = ImageDataGenerator(**augmentation_kwargs)


#Specify Augmentation setup
# Size of the generated images; change based on the shape/structure of your images
img_width = 299
img_height = 299
# Desired number of images per class (sum of train, validation and test for each class)
augm_idx = 50000

#Obtain total number of images in randomized database
# Every file found anywhere below inputpath is counted.
orig_db_n = sum(len(names) for _, _, names in os.walk(inputpath))
print('Original Database: {} has {} images'.format(inputpath, orig_db_n))
print('Selected per class augmentation (train + val + test): {} images'.format(augm_idx))
print('')
print('Starting Balanced Augmentation...')
# Run through train, test and validation folders to augment all classes in a balanced way

# Fail early with a readable message: with an empty input tree the share
# computation below would otherwise end in a ZeroDivisionError.
if orig_db_n == 0:
    raise RuntimeError('No images found in ' + inputpath + ' - nothing to augment')

# Defined up front so the final summary still works when no split folder is found
class_n = 0

for directory in glob.iglob(os.path.join(inputpath, '*')):
    # Stray files lying next to the split folders are not splits
    if not os.path.isdir(directory):
        continue
    split_id = os.path.basename(directory)

    # Only sub-folders are classes; the reported class count is the number of
    # folders that are actually augmented below.
    class_dirs = [entry for entry in glob.iglob(os.path.join(directory, '*'))
                  if os.path.isdir(entry)]
    class_n = len(class_dirs)
    file_n = sum(len(files) for _, _, files in os.walk(directory))

    #Obtain the percentage of images that this directory contains of the full database
    dir_p = file_n/orig_db_n
    # Each class of this split receives the split's share of the per-class target
    class_augm_idx = round(augm_idx*dir_p)
    print('')
    print('*' + split_id + ' dir has ' + str(file_n) + ' Images (' + str(dir_p*100) + '% of original database) divided into ' + str(class_n) + ' Classes')

    #Augment Loop
    for subdirectory in class_dirs:
        # basename/join instead of replacing a hard-coded '/' keeps the output
        # path correct whatever separator glob returns.
        class_id = os.path.basename(subdirectory)
        print('-->Augmenting class ' + class_id + ': ', end='')

        # A target of 0 must not generate anything (the break test below only
        # runs after the first image has already been produced).
        if class_augm_idx < 1:
            print('skipped (target is 0 images)')
            continue

        save_dir = os.path.join(outputpath, split_id, class_id)
        os.makedirs(save_dir, exist_ok=True)  # do not rely on the folder having been created earlier

        batches = train_datagen.flow_from_directory(
            directory,
            batch_size=1,
            save_to_dir=save_dir,
            # Keras documents target_size as (height, width)
            target_size=(img_height, img_width),
            classes=[class_id],
            save_prefix='AUG_',
            save_format='png')

        # The generator keeps yielding batches, so stop explicitly once the
        # requested number of batches (one image each) has been drawn.
        for n, _batch in enumerate(batches, start=1):
            if n >= class_augm_idx:
                break

print('')
print('AUGMENTATION HAS FINISHED!')
# NOTE(review): the total below is the nominal target (classes of the last
# processed split times augm_idx), not a count of the files actually written.
print('Augmented/Balanced Database now has a total of ' + str(class_n*augm_idx) + ' images into ' + str(class_n) + ' classes, stored in folder:')
print(outputpath )
print('')

------------------