# Getting started with TensorFlow's `Dataset` API

A very simple input pipeline for data augmentation

In [1]:
import numpy as np
import tensorflow as tf

In [2]:
def data_generator(num_samples=10, image_shape=(24, 24, 3), num_classes=10):
    """Generate fake image data with fake labels.

    Every parameter has a default, so the function can still be handed to
    `tf.data.Dataset.from_generator` as a zero-argument callable.

    Args:
        num_samples: Number of (image, label) pairs to yield.
        image_shape: Shape of each generated image array.
        num_classes: Exclusive upper bound for the random integer label.

    Yields:
        Tuple `(image, label)`: `image` is an array of shape `image_shape`
        filled by `np.random.random` (floats in [0, 1)); `label` is a
        1-tuple holding one integer drawn from `[0, num_classes)`.
    """
    for _ in range(num_samples):
        image = np.random.random(image_shape)
        # The label is wrapped in a 1-tuple so each label has shape (1,).
        label = (np.random.randint(num_classes),)
        yield image, label
        

def data_augmentation(image, label):
    """Augment a batch of images and repeat the labels to match.

    The incoming batch is concatenated along axis 0 with four distorted
    copies of itself (vertical flip, horizontal flip, random brightness,
    random contrast), so the result holds five times as many images as the
    input.

    Args:
        image: Image tensor with a leading batch axis. The pipeline in this
            notebook batches before mapping, so it receives
            (batch, height, width, channels).
        label: Label tensor whose axis 0 lines up with axis 0 of `image`.

    Returns:
        Tuple `(image_batch, label_batch)`. `image_batch` is the original
        batch followed by each distorted batch; `label_batch` is `label`
        repeated once per variant, so row i of the labels belongs to row i
        of the images.
    """
    # Original first, then each distortion, in a fixed order.
    variants = [
        image,
        tf.image.flip_up_down(image),
        tf.image.flip_left_right(image),
        tf.image.random_brightness(image, max_delta=63/255.0),
        tf.image.random_contrast(image, lower=0.2, upper=1.8),
    ]
    image_batch = tf.concat(variants, axis=0)

    # Repeat the labels along axis 0 only, once per variant. The previous
    # code used the label's own width as the tile multiple for axis 1, which
    # was only correct for width-1 labels; wider labels (e.g. one-hot) were
    # also replicated across columns.
    label_batch = tf.concat([label] * len(variants), axis=0)
    return image_batch, label_batch

In [3]:
# Assemble the input pipeline as a single chain:
#   1. wrap the Python generator, declaring (image, label) dtypes,
#   2. group samples in pairs so the map function receives batched tensors,
#   3. expand every pair into the original plus its distorted copies.
dataset = (
    tf.data.Dataset
    .from_generator(data_generator, output_types=(tf.float32, tf.int32))
    .batch(2)
    .map(data_augmentation)
)

# A one-shot iterator needs no explicit initialisation; `next_item` is the
# (images, labels) pair of tensors that yields one augmented batch per run.
iterator = dataset.make_one_shot_iterator()
next_item = iterator.get_next()

In [4]:
# Pull one augmented batch out of the pipeline. `sess.run` returns the
# fetched values, so they remain usable after the session has closed.
with tf.Session() as sess:
    x, y = sess.run(next_item)

# Show the shapes of the image batch and the label batch.
print(x.shape, y.shape)

(10, 24, 24, 3) (10, 1)
