# TensorFlow: CNN with data augmentation

In [None]:
import os

# TensorFlow's native runtime reads TF_CPP_MIN_LOG_LEVEL when it sets up
# logging, so the variable must be set *before* `import tensorflow`; setting it
# afterwards may have no effect. '1' filters out INFO-level messages.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'

import matplotlib.pyplot as plt
import tensorflow as tf
import numpy as np

In [None]:
from common import CV_DATA_DIR

# Root of the cats-and-dogs image set; presumably one sub-directory per class,
# as `image_dataset_from_directory` infers labels from folder names — TODO confirm.
# CV_DATA_DIR is assumed to be a pathlib.Path (it supports `/` and `.is_dir()`).
DATA_DIR = CV_DATA_DIR / 'animals' / 'cats-and-dogs'
assert DATA_DIR.is_dir(), \
    f'Dir "{DATA_DIR}" does not exist'

In [None]:
# Report every physical device TensorFlow can see in this kernel, to confirm
# whether an accelerator is available before training.
all_physical_devices = tf.config.list_physical_devices(device_type=None)
print("All physical devices:", all_physical_devices)

# Prepare dataset

In [None]:
# Load images from DATA_DIR and split them 90/10 into training and validation
# subsets in a single call (subset='both' returns both datasets). The fixed
# seed keeps the split identical across runs; images are resized to 150x150
# and grouped into batches of 64 with binary (0/1) labels.
train_ds, valid_ds = tf.keras.utils.image_dataset_from_directory(
    DATA_DIR,
    subset='both',
    validation_split=0.1,
    seed=1,
    label_mode='binary',
    batch_size=64,
    image_size=(150, 150),
)

In [None]:
BUFFER_SIZE = 1000

# Training pipeline: cache loaded batches, reshuffle them each epoch, and
# prefetch so input preparation overlaps with training. train_ds already
# yields batches (batch_size=64 above), so the shuffle reorders whole batches.
train_ds = (train_ds
    .cache()
    .shuffle(BUFFER_SIZE)
    .prefetch(tf.data.AUTOTUNE))

# Validation pipeline: must be derived from `valid_ds` itself — building it
# from `train_ds` would evaluate on training images and inflate val metrics.
# No shuffling: evaluation order does not matter.
valid_ds = (valid_ds
    .cache()
    .prefetch(tf.data.AUTOTUNE))

In [None]:
# Pull a single batch to sanity-check the raw pixel range (no rescaling is
# applied in the dataset; the model's Rescaling layer does that) and the
# tensor shapes the model will receive.
one_batch_ds = train_ds.take(1)
batch_images, batch_labels = next(iter(one_batch_ds))

for message in (
    f"Maximum pixel value of images: {np.max(batch_images)}\n",
    f"Shape of batch of images: {batch_images.shape}",
    f"Shape of batch of labels: {batch_labels.shape}",
):
    print(message)

# Define models

In [None]:
# Define train model: a convolutional feature extractor followed by a dense
# binary classifier. Inputs are multiplied by 1/255 inside the model, so the
# dataset can feed raw pixel values.
train_layers = tf.keras.models.Sequential([
    tf.keras.Input(shape=(150, 150, 3)),
    tf.keras.layers.Rescaling(scale=1./255),
    # Four Conv2D(3x3, ReLU) + MaxPool2D stages with 32, 64, 128, 128 filters.
    *[
        layer
        for filters in (32, 64, 128, 128)
        for layer in (
            tf.keras.layers.Conv2D(filters, kernel_size=(3, 3), activation='relu'),
            tf.keras.layers.MaxPool2D(),
        )
    ],
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(512, activation='relu'),
    # Single sigmoid unit, matching label_mode='binary' and binary cross-entropy.
    tf.keras.layers.Dense(1, activation='sigmoid'),
])

In [None]:
FILL_MODE = 'nearest'

# Define augmentation layers: random horizontal flips plus rotation,
# translation and zoom by up to 20%. Pixels exposed by a transform are filled
# with FILL_MODE. Keras random-augmentation layers are only active in training
# mode; at inference they pass images through unchanged.
augme_layers = tf.keras.Sequential(
    [
        tf.keras.Input(shape=(150, 150, 3)),
        tf.keras.layers.RandomFlip(mode='horizontal'),
        tf.keras.layers.RandomRotation(factor=0.2, fill_mode=FILL_MODE),
        tf.keras.layers.RandomTranslation(
            height_factor=0.2, width_factor=0.2, fill_mode=FILL_MODE),
        tf.keras.layers.RandomZoom(height_factor=0.2, fill_mode=FILL_MODE),
    ]
)

In [None]:
# Compose two models into one. Order matters: augmentation consumes the raw
# 150x150x3 images and must run first; the CNN then classifies the (possibly
# augmented) images. The reverse order would feed the CNN's (batch, 1) output
# into the augmentation layers, which expect (150, 150, 3) images.
model = tf.keras.models.Sequential([
    augme_layers,
    train_layers
])

In [None]:
# RMSprop with a small learning rate; binary cross-entropy pairs with the
# single sigmoid output unit and the 0/1 labels.
optimizer = tf.keras.optimizers.RMSprop(learning_rate=1e-4)
model.compile(
    loss='binary_crossentropy',
    optimizer=optimizer,
    metrics=['accuracy'],
)

In [None]:
# Check input images size for correctness: push one real batch through the
# composed model before starting the long training run.
try:
    model.evaluate(batch_images, batch_labels, verbose=False)
except Exception as exc:
    # Report the actual cause instead of swallowing it; catching Exception
    # (not a bare `except:`) also lets KeyboardInterrupt/SystemExit through.
    print('Model is not compatible with dataset')
    print(f'{type(exc).__name__}: {exc}')
else:
    predictions = model.predict(batch_images, verbose=False)
    print(f'Predictions have shape: {predictions.shape}')

# Train model

In [None]:
class EarlyStoppingCallback(tf.keras.callbacks.Callback):
    """Stop training once both train and validation accuracy reach targets.

    Args:
        train_accuracy: minimum ``logs['accuracy']`` required to stop (default 0.85).
        val_accuracy: minimum ``logs['val_accuracy']`` required to stop (default 0.8).
    """

    def __init__(self, train_accuracy=0.85, val_accuracy=0.8):
        super().__init__()
        self.train_accuracy = train_accuracy
        self.val_accuracy = val_accuracy

    def on_epoch_end(self, epoch, logs=None):
        """Set ``stop_training`` when both accuracy thresholds are met."""
        logs = logs or {}
        # .get() tolerates missing metrics (e.g. fit() called without
        # validation_data) instead of raising KeyError mid-training.
        acc = logs.get('accuracy')
        val_acc = logs.get('val_accuracy')
        if acc is None or val_acc is None:
            return
        if acc >= self.train_accuracy and val_acc >= self.val_accuracy:
            self.model.stop_training = True
            # Derive the message from the thresholds so it cannot drift from them.
            print(f'Reached {self.train_accuracy:.0%} train accuracy and '
                  f'{self.val_accuracy:.0%} validation accuracy')

# Train for at most 30 epochs, validating after each one; the early-stopping
# callback may end training sooner. verbose=2 prints one line per epoch.
history = model.fit(
    train_ds,
    validation_data=valid_ds,
    epochs=30,
    callbacks=[EarlyStoppingCallback()],
    verbose=2,
)

In [None]:
# Per-epoch metrics recorded by `model.fit`; one entry per completed epoch,
# so the lists can be shorter than 30 if early stopping fired.
acc = history.history['accuracy']
val_acc = history.history['val_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']

# 1-based epoch numbers for the x-axis.
epochs = range(1, len(acc) + 1)

fig, ax = plt.subplots(1, 2, figsize=(10, 5))
# The figure holds an accuracy panel and a loss panel, so the overall title
# names both.
fig.suptitle('Training and validation accuracy and loss')

ax[0].plot(epochs, acc, 'r', label='Training accuracy')
ax[0].plot(epochs, val_acc, 'b', label='Validation accuracy')
ax[0].set_title('Training and validation accuracy')
ax[0].set_xlabel('epochs')
ax[0].set_ylabel('accuracy')
ax[0].legend()

ax[1].plot(epochs, loss, 'r', label='Training Loss')
ax[1].plot(epochs, val_loss, 'b', label='Validation Loss')
ax[1].set_title('Training and validation loss')
ax[1].set_xlabel('epochs')
ax[1].set_ylabel('loss')
ax[1].legend()

plt.show()