# Data Augmentation with Keras and TensorFlow

In [None]:
%%capture
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds

from keras import layers
import keras

In [None]:
%%capture
# Load the 'cats_vs_dogs' dataset through TensorFlow Datasets. The three
# requested slices of its 'train' split become the training (first 80%),
# validation (next 10%) and test (last 10%) sets. `metadata` is the dataset
# info object requested via with_info=True.
(train_ds, val_ds, test_ds), metadata = tfds.load(
    'cats_vs_dogs',
    split=['train[:80%]', 'train[80%:90%]', 'train[90%:]'],
    with_info=True,
    as_supervised=True,
)

In [None]:
# Number of label classes as reported by the dataset metadata.
num_classes = metadata.features['label'].num_classes
print(num_classes)

In [None]:
# Metadata helper that converts an integer label into its class name.
get_label_name = metadata.features['label'].int2str
train_iter = iter(train_ds)
fig = plt.figure(figsize=(7, 8))
# Plot the first four training examples in one row, each titled with its label.
for x in range(4):
    image, label = next(train_iter)
    fig.add_subplot(1, 4, x+1)
    plt.imshow(image)
    plt.axis('off')
    plt.title(get_label_name(label))
# NOTE: `image` and `label` keep the last example drawn; later cells reuse `image`.

## Resize and rescale

In [None]:
# Side length, in pixels, that every image is resized to.
IMG_SIZE = 180

# Preprocessing pipeline: resize to IMG_SIZE x IMG_SIZE, then scale pixel
# values by 1/255.
resize_and_rescale = keras.Sequential([
  layers.Resizing(IMG_SIZE, IMG_SIZE),
  layers.Rescaling(1./255)
])

# Run the pipeline on the last sample image loaded above and display the result.
result = resize_and_rescale(image)
plt.axis('off')
plt.imshow(result)

## Random rotate and flip

In [None]:
# Random augmentation pipeline: a random horizontal and/or vertical flip
# followed by a random rotation with factor 0.4.
data_augmentation = keras.Sequential([
  layers.RandomFlip("horizontal_and_vertical"),
  layers.RandomRotation(0.4),
])


# Draw six augmented versions of the same sample image in a 2 x 3 grid.
plt.figure(figsize=(8, 7))
for i in range(6):
    augmented_image = data_augmentation(image)
    ax = plt.subplot(2, 3, i + 1)
    # Divide by 255 before plotting; assumes `image` holds 0-255 pixel
    # values (it has not been passed through resize_and_rescale) - confirm.
    plt.imshow(augmented_image.numpy()/255)
    plt.axis("off")

# Option 1:

## Directly adding to the model layer 

In [None]:
# Option 1: make resizing/rescaling and augmentation part of the model itself,
# ahead of a small convolutional classifier with a single sigmoid output.
model = keras.Sequential([
    # Add the preprocessing layers you created earlier.
    resize_and_rescale,
    data_augmentation,
    # Add the model layers
    layers.Conv2D(16, 3, padding='same', activation='relu'),
    layers.MaxPooling2D(),
    layers.Flatten(),
    layers.Dense(128, activation='relu'),
    layers.Dense(64, activation='relu'),
    layers.Dense(1,activation='sigmoid')
])

# Option 2:

## Applying the augmentation function using .map

In [None]:
# Option 2: apply the augmentation pipeline to every training image with
# Dataset.map, leaving labels untouched. training=True is passed explicitly;
# presumably the random layers are inactive otherwise - confirm in Keras docs.
aug_ds = train_ds.map(lambda x, y: (data_augmentation(x, training=True), y))

## Data pre-processing 

The function will:

1. Apply resize and rescale to the entire dataset.
2. If shuffle is True, it will shuffle the dataset.
3. Convert the data into batches of 32. 
4. If augment is True, it will apply the data augmentation function to the dataset (used for the training set only). 
5. Finally, use Dataset.prefetch to overlap the training of your model on the GPU with data processing.

In [None]:
# Number of examples per batch produced by the input pipelines below.
batch_size = 32
# tf.data's AUTOTUNE value, passed below as num_parallel_calls / buffer_size.
AUTOTUNE = tf.data.AUTOTUNE

def prepare(ds, shuffle=False, augment=False):
    """Build the input pipeline for one dataset split.

    Args:
        ds: Dataset whose elements are `(image, label)` pairs.
        shuffle: If True, shuffle the examples with a buffer of 1000 before
            batching.
        augment: If True, run `data_augmentation` on each batch. Intended for
            the training split only.

    Returns:
        The resized, rescaled, batched and prefetched dataset.
    """
    # Resize and rescale all datasets.
    ds = ds.map(lambda x, y: (resize_and_rescale(x), y),
                num_parallel_calls=AUTOTUNE)

    # Shuffle before batching so individual examples, not whole batches,
    # are reordered.
    if shuffle:
        ds = ds.shuffle(1000)

    # Batch all datasets.
    ds = ds.batch(batch_size)

    # Use data augmentation only on the training set. training=True is passed
    # explicitly so the random layers are applied inside Dataset.map.
    if augment:
        ds = ds.map(lambda x, y: (data_augmentation(x, training=True), y),
                    num_parallel_calls=AUTOTUNE)

    # Use buffered prefetching on all datasets.
    return ds.prefetch(buffer_size=AUTOTUNE)


# Only the training split is shuffled and augmented; the validation and test
# splits use prepare()'s defaults (no shuffling, no augmentation).
train_ds = prepare(train_ds, shuffle=True, augment=True)
val_ds = prepare(val_ds)
test_ds = prepare(test_ds)

## Model building

In [None]:
# Small CNN for the binary cats-vs-dogs task: one convolution + pooling stage
# followed by a dense head. It expects 180x180 RGB inputs, i.e. images already
# resized and rescaled by the input pipeline.
model = keras.Sequential([
    layers.Conv2D(32, (3, 3), input_shape=(180,180,3), padding='same', activation='relu'),
    layers.MaxPooling2D(pool_size=(2, 2)),
    layers.Flatten(),
    layers.Dense(32, activation='relu'),
    # A single output unit must use sigmoid: softmax over one unit always
    # yields 1.0, so the model would predict the same class for every input
    # and binary cross-entropy could not train it.
    layers.Dense(1, activation='sigmoid')
])

## Training and evaluation

In [None]:
# Number of passes over the training data.
epochs = 1

# Configure the model for binary classification with the Adam optimizer and
# track accuracy during training.
model.compile(
    loss='binary_crossentropy',
    metrics=['accuracy'],
    optimizer='adam',
)

# Train on the prepared training split, validating on the validation split
# after each epoch; the returned object is kept in `history`.
history = model.fit(train_ds, epochs=epochs, validation_data=val_ds)

In [None]:
# Evaluate the model on the held-out test split; evaluate() returns the loss
# followed by the compiled metric (accuracy).
loss, acc = model.evaluate(test_ds)

# Option 3

## Data Augmentation using tf.image

### Data Loading

In [None]:
%%capture
# Reload the raw 80/10/10 splits of 'cats_vs_dogs'. The earlier train_ds /
# val_ds / test_ds names were rebound to the outputs of prepare(), so this
# restores unprocessed (image, label) examples for the tf.image demos.
(train_ds, val_ds, test_ds), metadata = tfds.load(
    'cats_vs_dogs',
    split=['train[:80%]', 'train[80%:90%]', 'train[90%:]'],
    with_info=True,
    as_supervised=True,
)

In [None]:
# Grab the first training example and display it, titled with its class name
# (get_label_name is defined in an earlier cell).
image, label = next(iter(train_ds))
plt.imshow(image)
plt.title(get_label_name(label));

### Flip left to right

In [None]:
def visualize(original, augmented):
    """Plot an image and its augmented counterpart side by side.

    Args:
        original: Image shown in the left panel, titled 'Original image'.
        augmented: Image shown in the right panel, titled 'Augmented image'.
    """
    plt.figure()
    panels = (
        ('Original image', original),
        ('Augmented image', augmented),
    )
    # One row, two columns; subplot positions are 1-based.
    for position, (caption, picture) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.title(caption)
        plt.imshow(picture)
        plt.axis("off")

In [None]:
# Mirror the sample image left-to-right and compare it with the original.
flipped = tf.image.flip_left_right(image)
visualize(image, flipped)

### Grayscale

In [None]:
# Convert the sample image to grayscale. tf.squeeze removes size-1 dimensions
# before plotting (presumably the trailing channel axis - confirm).
grayscaled = tf.image.rgb_to_grayscale(image)
visualize(image,  tf.squeeze(grayscaled))

### Adjusting the saturation

In [None]:
# Adjust the sample image's saturation with a factor of 3.
saturated = tf.image.adjust_saturation(image, 3)
visualize(image, saturated)

### Adjusting the brightness

In [None]:
# Adjust the sample image's brightness with a delta of 0.4.
bright = tf.image.adjust_brightness(image, 0.4)
visualize(image, bright)

### Central Crop

In [None]:
# Crop the sample image around its centre using a central fraction of 0.5.
cropped = tf.image.central_crop(image, central_fraction=0.5)
visualize(image, cropped)

### 90-degree rotation

In [None]:
# Rotate the sample image by 90 degrees.
rotated = tf.image.rot90(image)
visualize(image, rotated)

### Applying random brightness

In [None]:
# Produce three brightness variations of the sample image, one per seed, and
# plot each next to the original. A different seed is used on every iteration.
for i in range(3):
    seed = (i, 0)  # tuple of size (2,)
    stateless_random_brightness = tf.image.stateless_random_brightness(
      image, max_delta=0.95, seed=seed)
    visualize(image, stateless_random_brightness)

### Applying the augmentation function

In [None]:
def augment(image, label):
    """Resize, rescale and randomly augment one `(image, label)` example.

    Args:
        image: Input image; it is cast to float32 and divided by 255, so it
            is assumed to hold 0-255 pixel values - confirm against caller.
        label: Label passed through unchanged.

    Returns:
        A tuple `(image, label)` where `image` has spatial size
        IMG_SIZE x IMG_SIZE and values clipped to [0, 1].
    """
    image = tf.cast(image, tf.float32)
    image = tf.image.resize(image, [IMG_SIZE, IMG_SIZE])
    image = (image / 255.0)
    # Pad by 6 pixels first so the random crop below actually shifts the
    # image; cropping IMG_SIZE out of an IMG_SIZE image is a no-op.
    image = tf.image.resize_with_crop_or_pad(image, IMG_SIZE + 6, IMG_SIZE + 6)
    image = tf.image.random_crop(image, size=[IMG_SIZE, IMG_SIZE, 3])
    image = tf.image.random_brightness(image, max_delta=0.5)
    # The brightness shift can push values outside [0, 1]; clip them back.
    image = tf.clip_by_value(image, 0.0, 1.0)
    return image, label


# Build the training input pipeline step by step: shuffle individual examples
# with a buffer of 1000, augment them in parallel, then batch and prefetch.
train_ds = train_ds.shuffle(1000)
train_ds = train_ds.map(augment, num_parallel_calls=AUTOTUNE)
train_ds = train_ds.batch(batch_size).prefetch(AUTOTUNE)