In [1]:
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import math
import gc
from PIL import Image
import cv2
import keras.preprocessing.image as myKeras
from matplotlib import pyplot as plt

In [2]:
# Root directory read by flow_from_directory; one sub-folder per class name.
load_path = "downloaded_images/"
# Directory that save_batch writes augmented JPEGs into.
save_path = "pipeline_aug_images/"
# Running counter giving every saved image a unique numeric filename prefix.
generated_id = 0

def save_batch(batch, prefix):
    """Write every image in ``batch`` to ``save_path`` as a JPEG file.

    Files are named ``<generated_id>_<prefix>.jpeg``; the module-level
    ``generated_id`` counter is advanced once per written image, so names stay
    unique as long as the counter is not reset.

    The images produced by the Keras generators in this notebook are float
    arrays in RGB channel order (Keras's default ``color_mode``), whereas
    ``cv2.imwrite`` expects 8-bit data in BGR order. Each image is therefore
    rounded, clipped to [0, 255], cast to ``uint8`` and, for 3/4-channel
    images, reordered RGB(A) -> BGR(A) before writing.

    Args:
        batch: iterable of HxWxC (or HxW) image arrays.
        prefix: tag inserted into each filename after the counter.

    Raises:
        OSError: if OpenCV reports that a file could not be written
            (``cv2.imwrite`` signals failure only through its return value).
    """
    import os  # local import keeps this cell self-contained

    global generated_id
    if save_path:
        # cv2.imwrite does not create missing directories.
        os.makedirs(save_path, exist_ok=True)

    for image in batch:
        # Round-then-clip mirrors OpenCV's own saturating float -> uint8 cast.
        pixels = np.clip(np.rint(image), 0, 255).astype(np.uint8)
        if pixels.ndim == 3 and pixels.shape[-1] in (3, 4):
            # RGB(A) -> BGR(A); fancy indexing returns a contiguous copy,
            # which OpenCV requires (a negative-stride view may be rejected).
            order = [2, 1, 0] + list(range(3, pixels.shape[-1]))
            pixels = pixels[..., order]

        filename = save_path + str(generated_id) + "_" + prefix + ".jpeg"
        if not cv2.imwrite(filename, pixels):
            raise OSError("cv2.imwrite failed to write " + filename)
        generated_id = generated_id + 1

In [3]:
def _draw_augmented(images, n_out, **datagen_kwargs):
    """Return exactly ``n_out`` randomly augmented images drawn from ``images``.

    ``datagen_kwargs`` are forwarded to ``ImageDataGenerator``. A single batch
    from ``flow(...)`` never holds more images than ``flow`` was given, so this
    keeps pulling batches until ``n_out`` images have been collected.

    Returns an empty slice of ``images`` when ``n_out`` is not positive or
    ``images`` is empty, since a Keras iterator cannot be built with a batch
    size of 0.
    """
    if n_out <= 0 or len(images) == 0:
        return images[:0]

    datagen = myKeras.ImageDataGenerator(**datagen_kwargs)
    iterator = datagen.flow(images, batch_size=min(n_out, len(images)), shuffle=True)

    chunks = []
    remaining = n_out
    while remaining > 0:
        chunk = next(iterator)[:remaining]
        chunks.append(chunk)
        remaining -= len(chunk)
    return np.concatenate(chunks, axis=0)


def imageAugmentation(class_name, aug_factor=5, batch_size=10, hshift = True, hflip = True, 
                      vShift = False, vFlip = False, brightness = True, scale = True, 
                      rotation = False):
    """Generate augmented copies of one class's images and save them to disk.

    Images are read from ``load_path/<class_name>/`` with
    ``flow_from_directory`` in shuffled batches of ``batch_size``. For a raw
    batch of ``k`` images, each enabled stage produces a share of
    ``aug_factor * k`` images and writes them with ``save_batch`` using the
    stage name as the filename prefix:

    * ``hflip``      30%  random horizontal flips of the raw batch
                          (Keras flips each image at random, not always)
    * ``hshift``     10%  random horizontal shifts of up to 40 px, taken from
                          the hflip output (the raw batch if hflip is off)
    * ``scale``      30%  random zoom in [0.5, 1.0], same source as hshift
    * ``brightness`` 30%  random brightness factor in [1.0, 2.0], taken from
                          the hshift + scale output (raw batch if both are off)
    * ``vShift``     10%  random vertical shifts of up to 40 px of the raw batch
    * ``vFlip``      10%  random vertical flips of the raw batch
    * ``rotation``   30%  random rotations of up to 90 degrees of the raw batch

    The four default stages together produce ``aug_factor * k`` images.

    Prints the number of images found and returns ``None``. Returns early,
    after printing a message, if ``batch_size`` exceeds that number.
    """
    raw_datagen = myKeras.ImageDataGenerator()

    # Reading input images in batches; class_mode=None yields image arrays only.
    raw_iterator = raw_datagen.flow_from_directory(
        load_path,
        shuffle=True,
        batch_size=batch_size,
        class_mode=None,
        classes=[class_name],
    )

    print(raw_iterator.samples)

    if batch_size > raw_iterator.samples:
        print(" batch_size can not be greater than sample size")
        return

    # One pass over the class: ceil(samples / batch_size) batches.
    n_batches = int(math.ceil(raw_iterator.samples / batch_size))

    for _ in range(n_batches):
        raw_batch = next(raw_iterator)

        # Size the output from *this* batch (the last one may be smaller),
        # not from the whole class, so each batch grows by aug_factor.
        aug_batch_size = aug_factor * len(raw_batch)

        ############# hflip #####################
        # hshift and scale build on the flipped images when they exist.
        flip_source = raw_batch
        if hflip:
            hflip_data = _draw_augmented(
                raw_batch, int(0.3 * aug_batch_size), horizontal_flip=True)
            save_batch(hflip_data, "hflip")
            if len(hflip_data):
                flip_source = hflip_data

        # Outputs of hshift/scale are the input pool for brightness.
        shifted_and_scaled = []

        ###################### hshift #######################
        if hshift:
            # An int range samples integer offsets in (-40, 40); a list such as
            # [-40, 40] would only ever pick exactly -40 or +40 pixels.
            hshift_data = _draw_augmented(
                flip_source, int(0.1 * aug_batch_size), width_shift_range=40)
            save_batch(hshift_data, "hshift")
            shifted_and_scaled.append(hshift_data)

        ###################### scaling ##############
        if scale:
            # Keras expects zoom_range as [lower, upper].
            scale_data = _draw_augmented(
                flip_source, int(0.3 * aug_batch_size), zoom_range=[0.5, 1.0])
            save_batch(scale_data, "scale")
            shifted_and_scaled.append(scale_data)

        ############### brightness ###################
        if brightness:
            non_empty = [data for data in shifted_and_scaled if len(data)]
            bright_source = np.concatenate(non_empty, axis=0) if non_empty else raw_batch
            bright_data = _draw_augmented(
                bright_source, int(0.3 * aug_batch_size), brightness_range=[1.0, 2.0])
            save_batch(bright_data, "bright")

        ############ Vertical Shift #################
        if vShift:
            vshift_data = _draw_augmented(
                raw_batch, int(0.1 * aug_batch_size), height_shift_range=40)
            save_batch(vshift_data, "vShift")

        ############ Vertical Flip #################
        if vFlip:
            vflip_data = _draw_augmented(
                raw_batch, int(0.1 * aug_batch_size), vertical_flip=True)
            save_batch(vflip_data, "vFlip")

        ############ Rotation #################
        if rotation:
            rotated_data = _draw_augmented(
                raw_batch, int(0.3 * aug_batch_size), rotation_range=90)
            save_batch(rotated_data, "rotated")


In [4]:
# Augment the "nature" class, reading 6 source images per batch.
imageAugmentation(class_name="nature", batch_size=6)

Found 42 images belonging to 1 classes.
42
