In [195]:
import os
from PIL import Image
import matplotlib.pyplot as plt
import tensorflow as tf
tf.__version__

'2.3.1'

In [196]:
# Random geometric augmentations applied to whole batches of images.
# Per the Keras preprocessing docs, a single float factor is used as a
# symmetric range: RandomRotation(0.2) rotates by up to +/-20% of a full turn,
# RandomZoom(0.7) zooms in or out by up to 70% on each axis.
_preprocessing = tf.keras.layers.experimental.preprocessing
data_augmentation = tf.keras.Sequential(
    [
        _preprocessing.RandomFlip("horizontal_and_vertical"),
        _preprocessing.RandomRotation(0.2),
        _preprocessing.RandomZoom(height_factor=0.7, width_factor=0.7),
    ]
)

In [197]:
def load_img(img_path, size=(224, 224)):
    """Read a JPEG file and resize it for augmentation.

    Args:
        img_path: Path of a JPEG file (str or scalar string tensor, as
            produced by a tf.data pipeline of file names).
        size: (height, width) of the returned image. Defaults to (224, 224),
            the value this notebook originally hard-coded.

    Returns:
        A float32 tensor of shape (height, width, 3). Pixel values stay on the
        0-255 scale of the decoded JPEG (they are NOT normalized to [0, 1]).
        The image is stretched to `size`; aspect ratio is not preserved.
    """
    raw_bytes = tf.io.read_file(img_path)
    # channels=3 forces RGB output so grayscale and colour JPEGs share a shape
    # and can be batched together.
    img = tf.image.decode_jpeg(raw_bytes, channels=3)
    return tf.image.resize(img, size=size)

# Source folders to augment; each gets a sibling "<dir>-augmented" output folder.
directories = [
    "./dataset-raw/Black & White",
    "./dataset-raw/Color",
]

AUGMENTATIONS_PER_IMAGE = 9  # augmented copies written for every source image


def _encode_for_disk(aug_img):
    """Convert a float image with values on the 0-255 scale to JPEG bytes."""
    # convert_image_dtype expects floats in [0, 1]; saturate=True clips any
    # interpolation overshoot instead of letting it wrap around in uint8.
    aug_img = tf.image.convert_image_dtype(aug_img / 255.0, dtype=tf.uint8, saturate=True)
    return tf.image.encode_jpeg(aug_img)


for directory in directories:
    glob_pattern = "{}/*.jpg".format(directory)
    # Sorted glob instead of Dataset.list_files, which shuffles file names by
    # default: output "NNNN-*.jpg" now always maps to the same source image.
    file_paths = sorted(tf.io.gfile.glob(glob_pattern))
    if not file_paths:
        # list_files / batch(0) would raise here; skip the folder explicitly.
        print('No images matching {}, skipping.'.format(glob_pattern))
        continue
    images = tf.data.Dataset.from_tensor_slices(file_paths).map(load_img)

    print('Augmenting {} images in {}..'.format(len(file_paths), directory))

    # Whole folder in one batch (as before); safe to lower for large folders
    # because file names below use a global image offset.
    batch_size = len(file_paths)
    output_dir = '{}-augmented'.format(directory)
    os.makedirs(output_dir, exist_ok=True)

    img_offset = 0  # index of the first image of the current batch
    for index, batch in enumerate(images.batch(batch_size)):
        print('Batch {} {}'.format(index + 1, batch.shape))

        for aug_index in range(AUGMENTATIONS_PER_IMAGE):
            # training=True: Keras random preprocessing layers only augment in
            # training mode; make that explicit rather than rely on defaults.
            augmented_batch = data_augmentation(batch, training=True)

            for offset, aug_img in enumerate(augmented_batch):
                aug_img_path = '{}/{:04}-{:02}.jpg'.format(
                    output_dir, img_offset + offset, aug_index)
                tf.io.write_file(aug_img_path, _encode_for_disk(aug_img))

        img_offset += int(batch.shape[0])
        
!open {output_dir}

Augmenting 30 images in ./dataset-raw/Black & White..
Batch 1 (30, 224, 224, 3)
Augmenting 32 images in ./dataset-raw/Color..
Batch 1 (32, 224, 224, 3)


In [198]:
# amount = 9
# for i in range(amount):
#     augmented_image = data_augmentation(input)
#     ax = plt.subplot(3, 3, i + 1)
#     plt.imshow(augmented_image[0], cmap='gray')
#     plt.axis('off')