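# Data Augmentation

Enlarges the paired image/mask dataset in `256x256/` roughly tenfold with random geometric transforms. Each image and its mask receive the same transform, and the results are written to `aug/train/` and `aug/val/`.

- Keras documentation — `ImageDataGenerator` class: https://keras.io/preprocessing/image/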

In [1]:
import glob
import os

import numpy as np
from keras.preprocessing.image import ImageDataGenerator
from keras.preprocessing.image import img_to_array, load_img, save_img


# Seed handed to both the image flow and the mask flow so their shuffling and
# random transforms stay in step.
SEED = 1
# Fraction held out for validation. It is applied both to the source pairs
# (ImageDataGenerator subset split) and to the number of generated pairs.
VALIDATION_SPLIT = 0.1

Using TensorFlow backend.


In [2]:
# Identical augmentation arguments for images and masks. Together with the shared
# SEED passed to .flow() below, this keeps each image and its mask transformed in
# lockstep (the pattern the Keras ImageDataGenerator docs show for paired inputs).
data_gen_args = dict(
    rotation_range=20,
    zoom_range=0.15,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.15,
    horizontal_flip=True,
    fill_mode='constant',
    validation_split=VALIDATION_SPLIT
)
image_datagen = ImageDataGenerator(**data_gen_args)
mask_datagen = ImageDataGenerator(**data_gen_args)

INPUT_DIR = '256x256'
IMAGE_SUFFIX = '_image.jpg'
MASK_SUFFIX = '_mask.jpg'
AUGMENTATION_FACTOR = 10  # enlarge the dataset by this factor

# Derive each mask path from its image path instead of trusting two independently
# sorted globs to line up. A missing, extra or oddly named file would otherwise
# silently pair an image with the wrong mask.
image_paths = sorted(glob.glob(os.path.join(INPUT_DIR, '*' + IMAGE_SUFFIX)))
mask_paths = [path[:-len(IMAGE_SUFFIX)] + MASK_SUFFIX for path in image_paths]

assert image_paths, "No images found in %s" % INPUT_DIR
missing_masks = [path for path in mask_paths if not os.path.isfile(path)]
assert not missing_masks, \
    "No mask for %d image(s), e.g. %s" % (len(missing_masks), missing_masks[:3])
assert len(glob.glob(os.path.join(INPUT_DIR, '*' + MASK_SUFFIX))) == len(mask_paths), \
    "Found masks without a matching image"

input_length = len(image_paths)
print("Found %d images and masks" % input_length)

images = np.array([img_to_array(load_img(path)) for path in image_paths])
masks = np.array([img_to_array(load_img(path)) for path in mask_paths])

train_generator = zip(
    image_datagen.flow(images, batch_size=1, seed=SEED, subset='training'),
    mask_datagen.flow(masks, batch_size=1, seed=SEED, subset='training')
)
val_generator = zip(
    image_datagen.flow(images, batch_size=1, seed=SEED, subset='validation'),
    mask_datagen.flow(masks, batch_size=1, seed=SEED, subset='validation')
)

# Enlarge the dataset by AUGMENTATION_FACTOR. The validation share of the output
# mirrors VALIDATION_SPLIT.
output_length = input_length * AUGMENTATION_FACTOR
val_length = int(output_length * VALIDATION_SPLIT)
train_length = output_length - val_length


def save_augmented_pairs(pair_generator, out_dir, n_pairs):
    """Write the first `n_pairs` (image, mask) pairs from `pair_generator` to `out_dir`.

    Files are named ``<index>_image.jpg`` / ``<index>_mask.jpg`` with index starting
    at 0. `out_dir` is created if needed.

    Args:
        pair_generator: iterable of (image_batch, mask_batch) tuples.
        out_dir: destination directory.
        n_pairs: number of pairs to write; nothing is written when it is <= 0.

    Returns:
        The number of pairs actually written.
    """
    os.makedirs(out_dir, exist_ok=True)
    written = 0
    if n_pairs <= 0:
        return written
    # NOTE(review): masks go through the same interpolating transforms and are saved
    # as lossy JPEG, so mask pixels may not be strictly binary. Consider PNG and/or
    # thresholding if downstream code needs exact labels.
    for image_batch, mask_batch in pair_generator:
        for image, mask in zip(image_batch, mask_batch):
            save_img(os.path.join(out_dir, '%d_image.jpg' % written), image, quality=95)
            save_img(os.path.join(out_dir, '%d_mask.jpg' % written), mask, quality=95)
            written += 1
            # Stop exactly at n_pairs, even if batches hold more than one sample.
            if written >= n_pairs:
                return written
    return written


train_written = save_augmented_pairs(train_generator, 'aug/train', train_length)
print("Created %d training images" % train_written)

val_written = save_augmented_pairs(val_generator, 'aug/val', val_length)
print("Created %d validation images" % val_written)

Found 62 images and masks
Created 558 training images
Created 62 validation images
