In [None]:
import tensorflow as tf
from keras.preprocessing.image import ImageDataGenerator
import os
import scipy
from skimage import io
import multiprocessing as mp
from itertools import repeat

In [None]:
# Random augmentation pipeline applied to every source image.
datagen = ImageDataGenerator(
    # geometric transforms
    rotation_range=20,            # random rotation of up to +/-20 degrees
    width_shift_range=0.1,        # horizontal shift, fraction of image width
    height_shift_range=0.1,       # vertical shift, fraction of image height
    zoom_range=0.2,               # random zoom factor in [0.8, 1.2]
    horizontal_flip=True,         # random left/right mirroring
    # photometric transform
    brightness_range=[0.7, 1.0],  # brightness factor <= 1.0, i.e. only darkens
    # how pixels uncovered by a transform are filled
    fill_mode='nearest',          # alternatives: constant, reflect, wrap
    cval=125,                     # only used when fill_mode='constant'
)



In [None]:
# Source images live in one sub-directory per class. Keep the trailing slashes:
# paths are built by plain string concatenation elsewhere in this notebook.
# NOTE: `dir` shadows the builtin dir(); kept because other cells reference it.
dir = 'downscaledImg/'
augmentedPath = 'augmentedImg/'

# how many augmented images should be produced per source image
iterations = 5

# read all classes; sorted so every run processes (and logs) them in the same order
classes = sorted(os.listdir(dir))

In [None]:
# process an image
def processImage(srcDir, cls, img, targetPath):
    """Write `iterations` randomly augmented copies of one source image.

    Reads ``srcDir/cls/img``, feeds it through the module-level ``datagen`` and
    lets Keras save every generated sample as a PNG in ``targetPath`` (file
    names start with ``'aug-' + img``). Uses the module-level ``iterations``
    as the number of copies; a value <= 0 writes nothing.

    Args:
        srcDir: directory holding one sub-directory per class.
        cls: class sub-directory name.
        img: file name of the image inside ``srcDir/cls``.
        targetPath: existing directory the augmented PNGs are written to.

    Returns:
        None. Ordinary errors for a single image are printed and swallowed so
        one bad file does not abort the whole pool run.

    Raises:
        Exception: with message "stopped process" when interrupted (Ctrl-C);
            the original KeyboardInterrupt is attached as ``__cause__``.
    """
    try:
        path = os.path.join(srcDir, cls, img)
        print(path)
        x = io.imread(path)
        if x.ndim == 2:
            # 2-D (grayscale) images have no channel axis; flow() needs one.
            x = x.reshape(x.shape + (1,))
        # Add the batch axis: flow() expects a rank-4 (batch, H, W, C) array.
        x = x.reshape((1,) + x.shape)
        batches = datagen.flow(x, batch_size=16,
                               save_to_dir=targetPath,
                               save_prefix='aug-' + img,
                               save_format='png')
        # Each next() draws one random augmentation and saves it to disk; the
        # iterator never ends on its own, so stop after `iterations` draws.
        for _ in range(iterations):
            next(batches)
    except KeyboardInterrupt as err:
        # multiprocessing.Pool workers only report `Exception` subclasses back
        # to the parent, so surface Ctrl-C as a plain Exception (keeping the cause).
        raise Exception("stopped process") from err
    except Exception as err:
        print(f"Unexpected {err=}, {type(err)=}")

def processImageClass(cls, subThreads):
    """Augment every entry of one class directory in parallel.

    Lists ``dir + cls``, makes sure ``augmentedPath + cls`` exists and runs
    ``processImage`` on each entry through a process pool of ``subThreads``
    workers. Returns only after every image of the class has been processed.

    Args:
        cls: class sub-directory name inside the module-level ``dir``.
        subThreads: number of worker processes in the pool.
    """
    print("starting with: "+cls)
    targetPath = augmentedPath + cls
    images = os.listdir(os.path.join(dir, cls))
    os.makedirs(targetPath, exist_ok=True)
    # starmap blocks until all tasks return; leaving the `with` block then
    # terminates the (now idle) workers, also when starmap raises.
    with mp.Pool(subThreads) as pool:
        pool.starmap(processImage,
                     zip(repeat(dir), repeat(cls), images, repeat(targetPath)))
    print("done with: "+cls)

In [None]:
# Create the output root, then augment each class directory in turn.
# NOTE(review): processImage/processImageClass are defined in this notebook, so
# pool workers can presumably only find them with the "fork" start method
# (Linux default); under "spawn" (Windows, macOS) children cannot import the
# notebook's __main__ — move the helpers into a .py module there.
processors = mp.cpu_count()  # one worker process per CPU core
os.makedirs(augmentedPath, exist_ok=True)
for cls in classes:
    if os.path.isdir(os.path.join(dir, cls)):
        processImageClass(cls, processors)

print('DONE!!!')
