# 0 Imports

In [None]:
import os, warnings
import shutil
from collections import namedtuple
import numpy as np
import math
%matplotlib inline
import matplotlib.pyplot as plt
from matplotlib import gridspec
import seaborn as sns
import pandas as pd
import json

import cv2
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.callbacks import EarlyStopping, TensorBoard, CSVLogger, ModelCheckpoint, ReduceLROnPlateau
from tensorflow.keras.preprocessing import image_dataset_from_directory
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.applications.xception import Xception
from tensorflow.keras.applications.vgg16 import VGG16

import ipywidgets as widgets
from ipywidgets import interact, interact_manual

from scripts import read_saves, write_saves, record_saves

def set_seed(seed=31415):
    """Seed every random number generator this notebook relies on.

    Args:
        seed: integer seed shared by Python's ``random``, NumPy and TensorFlow.
    """
    import random  # local import: keeps the top import cell unchanged

    # NOTE: per the Python docs, PYTHONHASHSEED is read at interpreter start-up,
    # so setting it here only affects processes spawned afterwards (e.g. worker
    # processes), not hash randomisation of the already-running kernel.
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    tf.random.set_seed(seed)
    # Uncomment for (slower) deterministic GPU kernels:
    # os.environ['TF_CUDNN_DETERMINISTIC'] = 'true'
    # os.environ['TF_DETERMINISTIC_OPS'] = '1'


set_seed()

plt.rc('figure', autolayout=True)
plt.rc('axes', labelweight='bold', labelsize='large',
       titleweight='bold', titlesize=18, titlepad=10)
plt.rc('image', cmap='magma')
# NOTE(review): this silences *all* warnings, including Keras deprecation and
# "not fit on any training data" notices; consider narrowing the filter.
warnings.filterwarnings("ignore")

# Query the GPU name once and reuse it.
gpu_name = tf.test.gpu_device_name()
print("-----------------------------------------")
if gpu_name:
    print(f"GPU used: {gpu_name}")
else:
    print("GPU not used")
print("-----------------------------------------")

# Public replacement for the private tensorflow.python.client.device_lib API.
print(tf.config.list_physical_devices())

***
# 1 Config

In [None]:
# Feature flags: each notebook section below only runs when its flag is True.
config = dict(
    exploration=False,        # section 2: dataset exploration plots
    data_augmentation=False,  # section 3: augmented (True) or plain (False) input pipeline
    custom=False,             # section 4: train the custom CNN
    transfer=dict(            # section 5: transfer-learning backbones
        vgg16=False,
        xception=False,
    ),
)

In [None]:
# Dataset locations.
images_dir = "data/Images"            # one sub-directory of images per breed
annotations_dir = "data/Annotations"  # not used further in this notebook
target_dir = "data/Targets"           # working copy of the selected breeds, read by the loaders
batch_size = 16

# Persistence of past training runs (used by the commented-out record_saves cells).
SAVE_PATH = "data/saves.pkl"
SAVE_NB = 5

***
# 2 Dataset exploration

## 2.0 Utils

In [None]:
def convert_to_float(image, label):
    """Cast an image to float32 and pass its label through unchanged.

    Intended for ``tf.data.Dataset.map`` over ``(image, label)`` pairs.

    NOTE(review): per the ``tf.image.convert_image_dtype`` docs, integer inputs
    are rescaled to [0, 1] but floating-point inputs are only cast. The
    ``image_dataset_from_directory`` pipelines that use this presumably yield
    float32 already, so pixel values likely stay in [0, 255] — confirm intent.
    """
    return tf.image.convert_image_dtype(image, tf.float32), label

In [None]:
def breeds_distribution(dataset, figsize=(25, 5), display_mean=True):
    """Bar-plot the number of images per breed.

    Args:
        dataset: DataFrame with a ``Breed`` column (x axis) and a
            ``Population`` column (bar height), such as ``fre_df``.
        figsize: matplotlib figure size.
        display_mean: if True, draw the mean population as a dashed red
            horizontal line and return it.

    Returns:
        The mean of ``dataset["Population"]`` when ``display_mean`` is True,
        otherwise None.
    """
    plt.figure(figsize=figsize)

    # Plot the frame that was passed in, not a global.
    sns.barplot(x=dataset["Breed"], y=dataset["Population"])

    mean = None
    if display_mean:
        # Average only the numeric column: DataFrame.mean() over the string
        # columns raises TypeError on pandas >= 2.0, and ``[0]`` on a
        # label-indexed Series depends on deprecated positional fallback.
        mean = dataset["Population"].mean()
        plt.axhline(mean, linestyle="--", linewidth=1, color="r")

    plt.title("Breeds distribution", size=20)
    plt.xticks(size=10, rotation=45, ha="right")
    plt.yticks(size=10)
    plt.ylabel("Count", size=16)

    return mean

***
## 2.1 Breeds distribution

How many breeds are present in the dataset?

In [None]:
# Keep only breed sub-directories (skipping stray files such as .DS_Store,
# which the next cell could not list) and sort them: os.listdir order is
# arbitrary, which would make the row order of every later table
# platform-dependent.
breed_subdir = sorted(
    entry for entry in os.listdir(images_dir)
    if os.path.isdir(os.path.join(images_dir, entry))
)
print(f"There is {len(breed_subdir)} different dog breeds")

Build a pandas DataFrame listing every breed together with its population (number of images).

In [None]:
breed_full: list = []
breed: list = []
population: list = []

for sub_d in breed_subdir:
    breed_full.append(sub_d)
    # Directory names look like "<id>-<breed name>" (e.g. "n02085936-Maltese_dog").
    # Split only on the first hyphen so a breed name that itself contains a
    # hyphen is kept whole; [-1] also tolerates a name with no hyphen.
    breed.append(sub_d.split("-", 1)[-1])
    # Population = number of entries (images) in the breed's directory.
    population.append(len(os.listdir(os.path.join(images_dir, sub_d))))

fre_df = pd.DataFrame(data={"Breed_full": breed_full, "Breed": breed, "Population": population})
fre_df.head()

Plot each breed's population along with the mean population.

In [None]:
# Only draw the exploratory plot when exploration is enabled in the config.
if config["exploration"]:
    mean = breeds_distribution(dataset=fre_df)

Let's plot again, this time sorted by the `Population` field.

In [None]:
if config["exploration"]:
    # Rebind instead of sort_values(..., inplace=True): same result, but no
    # hidden in-place mutation of a frame an earlier cell displayed.
    fre_df = fre_df.sort_values(by="Population", ascending=False)
    mean = breeds_distribution(fre_df)

Let's take a look at the three most populated breeds to see whether they are visually different enough.

In [None]:
if config["exploration"]:

    # Three sample images for each of the three most populated breeds,
    # one breed per row (paths relative to images_dir).
    sample_images = [
        # Malteses
        "n02085936-Maltese_dog/n02085936_37.jpg",
        "n02085936-Maltese_dog/n02085936_66.jpg",
        "n02085936-Maltese_dog/n02085936_233.jpg",
        # Afghan
        "n02088094-Afghan_hound/n02088094_231.jpg",
        "n02088094-Afghan_hound/n02088094_251.jpg",
        "n02088094-Afghan_hound/n02088094_272.jpg",
        # Scottish deerhound
        "n02092002-Scottish_deerhound/n02092002_3.jpg",
        "n02092002-Scottish_deerhound/n02092002_198.jpg",
        "n02092002-Scottish_deerhound/n02092002_86.jpg",
    ]

    fig, axs = plt.subplots(3, 3, figsize=(20, 10))
    plt.subplots_adjust(wspace=0.1, hspace=0.1)

    # axs.flat walks the grid row by row, matching subplot indices 1..9.
    for ax, rel_path in zip(axs.flat, sample_images):
        raw = tf.io.read_file(os.path.join(images_dir, rel_path))
        image = tf.io.decode_jpeg(raw, channels=3)
        ax.imshow(image.numpy())
        ax.axis('off')

We can see that the images don't all have the same size.

In [None]:
# Keep the 10 most populated breeds. Sort explicitly here: the descending sort
# in the exploration section only runs when config["exploration"] is True, so
# without this the slice would take the first 10 rows in directory order
# rather than the top 10 by population. mergesort is stable, so ties keep
# their previous order.
TOP_N_BREEDS = 10
fre_df = (
    fre_df
    .sort_values(by="Population", ascending=False, kind="mergesort")
    .iloc[:TOP_N_BREEDS]
)
fre_df

***
# 3 Dataset preparation

## 3.0 Utils

In [None]:
def sync_dataset_directory(breeds: list, source_dir: str = "data/Images", target_dir: str = "data/Targets"):
    """Rebuild ``target_dir`` so it contains a copy of exactly the given breeds.

    Any existing ``target_dir`` is deleted first, so breeds from a previous
    selection cannot leak into the new dataset.

    Args:
        breeds: iterable of breed sub-directory names present under
            ``source_dir`` (e.g. ``fre_df["Breed_full"]``).
        source_dir: directory containing one sub-directory per breed.
        target_dir: destination directory. WARNING: removed recursively if it
            already exists.
    """
    if os.path.exists(target_dir):
        shutil.rmtree(target_dir)
    # makedirs (not mkdir) so missing parent directories are created too.
    os.makedirs(target_dir)
    for breed in breeds:
        # os.path.join instead of "/" concatenation for portable paths.
        shutil.copytree(os.path.join(source_dir, breed), os.path.join(target_dir, breed))

***
## 3.1 Without augmentation

In [None]:
if not config["data_augmentation"]:

    sync_dataset_directory(fre_df["Breed_full"])

    # Both subsets must use the SAME validation_split and seed so that
    # "training" and "validation" are complementary partitions of the files.
    # (A split of 0.8 on the training call would keep only 20% of the images
    # for training and leave most of the dataset unused.)
    validation_split = 0.2
    loader_kwargs = dict(
        labels="inferred",
        label_mode="categorical",
        image_size=[224, 224],
        interpolation="nearest",
        batch_size=batch_size,
        seed=0,
        shuffle=True,
        validation_split=validation_split,
    )
    ds_train_ = image_dataset_from_directory(target_dir, subset="training", **loader_kwargs)
    ds_valid_ = image_dataset_from_directory(target_dir, subset="validation", **loader_kwargs)

    AUTOTUNE = tf.data.experimental.AUTOTUNE
    # NOTE(review): .cache() stores the elements of the first pass, so the
    # upstream shuffle is presumably not repeated on later epochs — confirm
    # whether per-epoch reshuffling is wanted.
    ds_train = (
        ds_train_
        .map(convert_to_float)
        .cache()
        .prefetch(buffer_size=AUTOTUNE)
    )
    ds_valid = (
        ds_valid_
        .map(convert_to_float)
        .cache()
        .prefetch(buffer_size=AUTOTUNE)
    )

***
## 3.2 With augmentation

In [None]:
if config["data_augmentation"]:

    sync_dataset_directory(fre_df["Breed_full"])

    # 8-bit pixels range over 0..255, so 1/255 maps them to [0, 1]
    # (1/224 confused the pixel range with the image side length).
    rescale = 1. / 255

    # featurewise_center / featurewise_std_normalization / zca_whitening are
    # deliberately not used: they only take effect after
    # ImageDataGenerator.fit(sample_data), which this notebook never calls, so
    # Keras skipped them (its warnings are hidden by the global filter). ZCA on
    # 224x224x3 inputs would also need a 150528 x 150528 matrix.
    train_datagen = ImageDataGenerator(
        rotation_range=90,
        width_shift_range=0.2,
        height_shift_range=0.2,
        horizontal_flip=True,
        vertical_flip=True,
        validation_split=0.2,
        rescale=rescale
        )

    # Validation images are only rescaled, never augmented.
    valid_datagen = ImageDataGenerator(
        validation_split=0.2,
        rescale=rescale
        )

    train_aug = train_datagen.flow_from_directory(
        target_dir,
        target_size=(224, 224),
        subset="training",
        batch_size=batch_size,
        shuffle=True
        )

    val_aug = valid_datagen.flow_from_directory(
        target_dir,
        target_size=(224, 224),
        subset="validation",
        batch_size=4,
        shuffle=True
        )

***
# 4 Custom CNN

## 4.0 Utils

In [None]:
def visualize_history(history, figsize=(20, 10), metrics: str = "categorical_accuracy"):
    """Plot training vs. validation loss (top) and a metric (bottom).

    Args:
        history: DataFrame of per-epoch values, e.g.
            ``pd.DataFrame(history.history)`` from a Keras ``fit`` call; it
            must contain ``loss``, ``val_loss``, ``<metrics>`` and
            ``val_<metrics>`` columns. Its index is used as the x axis.
        figsize: matplotlib figure size.
        metrics: name of the metric column plotted in the lower panel.
    """
    # Draw on explicit Axes instead of re-selecting subplots via plt.subplot.
    fig, (ax_loss, ax_metric) = plt.subplots(2, 1, figsize=figsize, sharex=True)

    panels = (
        (ax_loss, "Loss", "loss"),
        (ax_metric, "Accuracy", metrics),
    )
    for ax, title, column in panels:
        ax.set_title(title)
        sns.lineplot(data=history, x=history.index, y=column, label=column, ax=ax)
        sns.lineplot(data=history, x=history.index, y="val_" + column,
                     label="val_" + column, ax=ax)
        ax.set_xlabel("epochs")
        ax.tick_params(labelright=True)
        ax.legend()
        ax.grid()

***
## 4.1 Neural Network

In [None]:
if config["custom"]:

    def _conv_bn_relu(filters):
        """3x3 'same' conv without bias -> BatchNorm (no scale) -> ReLU."""
        return [
            layers.Conv2D(
                filters=filters,
                kernel_size=3,
                strides=1,
                padding="same",
                use_bias=False
            ),
            layers.BatchNormalization(scale=False),
            layers.Activation(activation="relu"),
        ]

    # >>> base <<<
    # Blocks 1 and 2: conv stack followed by 4x4 max-pooling and dropout.
    base_layers = []
    for n_filters in (16, 32):
        base_layers += _conv_bn_relu(n_filters)
        base_layers += [
            layers.MaxPool2D(pool_size=4, padding="same"),
            layers.Dropout(rate=0.2),
        ]
    # Block 3: conv stack collapsed to one vector per image.
    base_layers += _conv_bn_relu(64)
    base_layers.append(layers.GlobalAveragePooling2D())

    # >>> head <<<
    # One softmax output per selected breed.
    head_layers = [
        layers.Dense(units=128, activation="relu"),
        layers.Dense(units=fre_df.shape[0], activation="softmax"),
    ]

    model = keras.Sequential(
        [layers.InputLayer(input_shape=[224, 224, 3])] + base_layers + head_layers
    )

    # model.summary()

In [None]:
if config["custom"]:

    model.compile(
        optimizer="adam",
        loss="categorical_crossentropy",
        metrics=["categorical_accuracy"]
    )

    # The callbacks below are built but NOT passed to fit(); enable one by
    # adding it to a ``callbacks=[...]`` argument.
    checkpoint = ModelCheckpoint(
        './base.model',
        monitor='val_loss',
        verbose=1,
        save_best_only=True,
        mode='min',
        save_weights_only=False,
        # save_freq='epoch' is the supported spelling of the deprecated period=1.
        save_freq='epoch'
    )

    early_stopping = EarlyStopping(
        min_delta=0.001,
        patience=5,
        restore_best_weights=True
    )

    tensorboard = TensorBoard(
        log_dir="./logs",
        histogram_freq=0,
        write_graph=True,
        write_images=False
    )

    csv_logger = CSVLogger(
        filename="training_csv.log",
        separator=",",
        append=False
    )

    reduce = ReduceLROnPlateau(
        monitor='val_loss',
        factor=0.1,
        patience=3,
        verbose=1,
        mode='auto'
    )

    if not config["data_augmentation"]:
        print("Fit without aug")
        history = model.fit(
            ds_train,
            validation_data=ds_valid,
            epochs=50,
            workers=8,
            # callbacks=[early_stopping],
            # verbose=0
        )
    else:
        print("Fit with aug")
        # Model.fit accepts Keras generators directly; fit_generator is
        # deprecated and removed from recent Keras releases.
        history = model.fit(
            train_aug,
            validation_data=val_aug,
            epochs=50,
            steps_per_epoch=100,
            workers=8,
            # callbacks=[early_stopping],
            # verbose=0
        )

***
## 4.2 Results

In [None]:
if config["custom"]:

    # Per-epoch Keras metrics as a DataFrame, then loss/accuracy curves.
    history_df = pd.DataFrame.from_dict(history.history)
    visualize_history(history_df)

In [None]:
# saves = record_saves(SAVE_PATH, SAVE_NB, history, model)

In [None]:
# for i, (key, (model_json, history)) in enumerate(saves.items()):
#     if i == 0:
#         visualize_history(history)
#     else:
#         visualize_history(history, figsize=(10, 6))
#         # json_parsed = json.loads(model_json)
#         # print(json.dumps(json_parsed, indent=4, sort_keys=True))

In [None]:
# tab_contents = saves.keys()
# children = []

# for i, (key, (model_json, history)) in enumerate(saves.items()):
#     if i == 0:
#         visualize_history(history)
#     else:
#         out = widgets.Output()
#         children.append(out)
#         with out:
#             visualize_history(history, figsize=(10, 6))
#             json_parsed = json.loads(model_json)
#             print(json.dumps(json_parsed, indent=4, sort_keys=True))

# tab = widgets.Tab(children=children)
# for i, child in enumerate(children):
#     tab.set_title(i, "tab"+str(i))

# display(tab)

***
# 5 Transfer learning CNN

## 5.1 VGG16

In [None]:
if config["transfer"]["vgg16"]:

    # ImageNet-pretrained VGG16 convolutional base, without its classifier.
    vgg16 = VGG16(
        weights="imagenet",
        include_top=False,
        input_shape=(224, 224, 3)
    )

    # Freeze every pretrained layer so only the new head gets trained.
    for frozen_layer in vgg16.layers:
        frozen_layer.trainable = False

In [None]:
if config["transfer"]["vgg16"]:

    # NOTE(review): inputs are not passed through
    # keras.applications.vgg16.preprocess_input; the ImageNet weights
    # presumably expect that preprocessing — confirm.
    model_vgg16 = keras.Sequential()
    # >>> base <<<
    model_vgg16.add(vgg16)
    # >>> head <<<
    model_vgg16.add(layers.GlobalAveragePooling2D())
    model_vgg16.add(layers.Dropout(rate=0.5))
    # One softmax output per selected breed.
    model_vgg16.add(layers.Dense(units=fre_df.shape[0], activation="softmax"))

    model_vgg16.summary()

In [None]:
if config["transfer"]["vgg16"]:

    model_vgg16.compile(
        optimizer="adam",
        loss="categorical_crossentropy",
        metrics=["categorical_accuracy"]
    )

    # Built but not passed to fit(); add callbacks=[early_stopping] to enable.
    early_stopping = EarlyStopping(
        min_delta=0.001,
        patience=5,
        restore_best_weights=True
    )

    if not config["data_augmentation"]:
        print("Fit without aug")
        history_vgg16 = model_vgg16.fit(
            ds_train,
            validation_data=ds_valid,
            epochs=20,
            workers=8,
            # callbacks=[early_stopping],
            # verbose=0
        )
    else:
        print("Fit with aug")
        # Model.fit accepts Keras generators directly; fit_generator is
        # deprecated and removed from recent Keras releases.
        history_vgg16 = model_vgg16.fit(
            train_aug,
            validation_data=val_aug,
            epochs=20,
            steps_per_epoch=100,
            workers=8,
            # callbacks=[early_stopping],
            # verbose=0
        )

In [None]:
if config["transfer"]["vgg16"]:

    # Per-epoch Keras metrics as a DataFrame, then loss/accuracy curves.
    history_vgg16_df = pd.DataFrame.from_dict(history_vgg16.history)
    visualize_history(history_vgg16_df)

***
### 5.1.1 Result without aug

<img src="plots/transfer_vgg16_without_aug.png" style="background-color:white">

***
### 5.1.2 Result with aug

<img src="plots/transfer_vgg16_with_aug.png" style="background-color:white">

***
## 5.2 Xception

In [None]:
if config["transfer"]["xception"]:

    # ImageNet-pretrained Xception without its classifier; pooling="avg"
    # makes it output one pooled feature vector per image.
    xception = Xception(
        include_top=False,
        weights="imagenet",
        input_shape=(224, 224, 3),
        pooling="avg"
    )

    # Freeze the backbone so only the new head gets trained.
    for frozen_layer in xception.layers:
        frozen_layer.trainable = False

In [None]:
if config["transfer"]["xception"]:

    # NOTE(review): inputs are not passed through
    # keras.applications.xception.preprocess_input; the ImageNet weights
    # presumably expect that preprocessing — confirm.
    model_xception = keras.Sequential()
    # >>> base <<<
    model_xception.add(xception)
    # >>> head <<<
    model_xception.add(layers.Dense(units=128, activation="relu"))
    model_xception.add(layers.Dropout(rate=0.2))
    # One softmax output per selected breed.
    model_xception.add(layers.Dense(units=fre_df.shape[0], activation="softmax"))

    model_xception.summary()

In [None]:
if config["transfer"]["xception"]:

    model_xception.compile(
        optimizer="adam",
        loss="categorical_crossentropy",
        metrics=["categorical_accuracy"]
    )

    # Built but not passed to fit(); add callbacks=[early_stopping] to enable.
    early_stopping = EarlyStopping(
        min_delta=0.001,
        patience=5,
        restore_best_weights=True
    )

    if not config["data_augmentation"]:
        print("Fit without aug")
        history_xception = model_xception.fit(
            ds_train,
            validation_data=ds_valid,
            epochs=20,
            workers=8,
            # callbacks=[early_stopping],
            # verbose=0
        )
    else:
        print("Fit with aug")
        # Model.fit accepts Keras generators directly; fit_generator is
        # deprecated and removed from recent Keras releases.
        history_xception = model_xception.fit(
            train_aug,
            validation_data=val_aug,
            epochs=20,
            steps_per_epoch=100,
            workers=8,
            # callbacks=[early_stopping],
            # verbose=0
        )

In [None]:
if config["transfer"]["xception"]:

    # Per-epoch Keras metrics as a DataFrame, then loss/accuracy curves.
    history_xception_df = pd.DataFrame.from_dict(history_xception.history)
    visualize_history(history_xception_df)

***
### 5.2.1 Result without aug

<img src="plots/transfer_xception_without_aug.png" style="background-color:white">

***
### 5.2.2 Result with aug

<img src="plots/transfer_xception_with_aug.png" style="background-color:white">