# Preprocessing

In [1]:
# Target side length in pixels: format_image resizes every image to
# IMG_SIZE x IMG_SIZE.
IMG_SIZE = 160

def format_image(image, label):
    """Turn one raw image into a rescaled, fixed-size float32 tensor.

    The image is cast to float32, rescaled with ``x / 127.5 - 1`` (for 8-bit
    pixel values in 0..255 this lands in [-1, 1]) and resized to
    IMG_SIZE x IMG_SIZE. The label is returned unchanged so the function can
    be applied to (image, label) pairs.
    """
    as_float = tf.cast(image, tf.float32)  # float32 so the division is not integer math
    rescaled = (as_float / 127.5) - 1
    resized = tf.image.resize(rescaled, (IMG_SIZE, IMG_SIZE))
    return resized, label


def formatting(dataset, img_size=160, batch_size=32, shuffle_buffer_size=1000, log=False):
    """Format an image dataset and return its first shuffled batch.

    Every (image, label) element is cast to float32, rescaled with
    ``x / 127.5 - 1`` and resized to ``img_size`` x ``img_size``. The
    formatted dataset is then shuffled and batched, and only the first batch
    is returned.

    Arguments
    ---------
    dataset             : tf.data.Dataset
        Dataset yielding (image, label) pairs.
    img_size            : int
        Side length in pixels of the square output images.
    batch_size          : int
        Batch size passed to ``batch()``.
    shuffle_buffer_size : int
        Buffer size passed to ``shuffle()``.
    log                 : bool
        When True, print the intermediate datasets and the shapes/types of
        the returned batch.

    Return
    ------
    image_batch : class tensorflow.python.framework.ops.EagerTensor
        First batch of formatted images.
    label_batch : class tensorflow.python.framework.ops.EagerTensor
        Labels matching ``image_batch``.

    Raises
    ------
    ValueError
        If the dataset yields no batch at all (e.g. it is empty).

    Example
    -------
    (raw_train, raw_validation, raw_test), metadata = tfds.load(
        'cats_vs_dogs',
        split=['train[:80%]', 'train[80%:90%]', 'train[90%:]'],
        with_info=True,
        as_supervised=True,
    )
    formatting(dataset=raw_train, img_size=160, batch_size=32, shuffle_buffer_size=1000)
    """

    # Each element is resized to img_size x img_size; the default of 160 is
    # the input size this notebook uses with tf.keras.applications.MobileNetV2.
    def format_example(image, label):
        image = tf.cast(image, tf.float32)
        image = (image / 127.5) - 1
        image = tf.image.resize(image, (img_size, img_size))
        return image, label

    # map() applies format_example to every element of the dataset.
    train = dataset.map(format_example)
    if log:
        print("train : ", train)

    # Shuffle, then slice into batches. After batch() each element gains a
    # leading batch dimension, e.g. (160, 160, 3) -> (None, 160, 160, 3).
    train_batches = train.shuffle(shuffle_buffer_size).batch(batch_size)
    if log:
        print("train_batches : ", train_batches)

    # Pull out exactly one batch. The None default lets us detect a dataset
    # that yields nothing instead of failing later on an unbound name.
    first_batch = next(iter(train_batches.take(1)), None)
    if first_batch is None:
        raise ValueError("dataset is empty: no batch could be extracted")
    image_batch, label_batch = first_batch

    if log:
        print("image_batch : ", image_batch.shape)
        print("label_batch : ", label_batch.shape)
        print(type(image_batch), type(label_batch))

    return image_batch, label_batch
