# 0 Imports

In [None]:
import os, warnings
import shutil
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import gridspec
import seaborn as sns
import pandas as pd

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.preprocessing import image_dataset_from_directory

def set_seed(seed=31415):
    """Seed the random number generators used by this notebook.

    Seeds Python's built-in ``random`` module, NumPy's global generator and
    TensorFlow's global generator with the same value, then exports
    ``PYTHONHASHSEED`` and ``TF_DETERMINISTIC_OPS`` to the process environment.

    Args:
        seed: Integer seed applied to every generator (default 31415).
    """
    # Local import: keeps the function self-contained without touching the
    # notebook's import cell. Aliased to avoid confusion with np.random / tf.random.
    import random as py_random

    py_random.seed(seed)
    np.random.seed(seed)
    tf.random.set_seed(seed)
    # PYTHONHASHSEED is read at interpreter start-up, so this only influences
    # interpreters launched after this point (e.g. subprocesses).
    os.environ['PYTHONHASHSEED'] = str(seed)
    os.environ['TF_DETERMINISTIC_OPS'] = '1'
# Seed every generator before anything stochastic runs.
set_seed()

# Global matplotlib defaults applied to every figure of the notebook.
plt.rc('figure', autolayout=True)
plt.rc(
    'axes',
    labelweight='bold',
    labelsize='large',
    titleweight='bold',
    titlesize=18,
    titlepad=10,
)
plt.rc('image', cmap='magma')
# Silence every Python warning for the rest of the session.
warnings.filterwarnings("ignore")

# Report whether TensorFlow found a GPU (empty device name means none).
banner = "-----------------------------------------"
print(banner)
gpu_name = tf.test.gpu_device_name()
print(f"GPU used: {gpu_name}" if gpu_name else "GPU not used")
print(banner)

# Dump the full description of every local device visible to TensorFlow.
from tensorflow.python.client import device_lib
print(device_lib.list_local_devices())

***
# 1 Config

In [None]:
# Feature flags: each key switches one notebook section on or off.
#   exploration -> section 2 plots, custom -> section 3, transfer -> section 4
config = dict(
    exploration=False,
    custom=True,
    transfer=False,
)

In [None]:
# Dataset locations, relative to the notebook's working directory.
# images_dir holds one sub-folder per breed; target_dir receives the subset
# of breed folders used for training.
annotations_dir, images_dir, target_dir = (
    "data/Annotations",
    "data/Images",
    "data/Targets",
)

# Number of images per batch fed to the network.
batch_size = 16

***
# 2 Dataset exploration

## 2.0 Utils

In [None]:
def convert_to_float(image, label):
    """Return ``(image, label)`` with the image converted to ``tf.float32``.

    The conversion is delegated to ``tf.image.convert_image_dtype``; per the
    TensorFlow documentation that call also rescales integer images to [0, 1].
    The label is passed through untouched, which makes the function usable
    directly in ``Dataset.map``.
    """
    return tf.image.convert_image_dtype(image, dtype=tf.float32), label

In [None]:
def breeds_distribution(dataset, figsize=(25, 5), display_mean=True):
    """Draw a bar chart of the population of each breed.

    Args:
        dataset: DataFrame with a "Breed" column (bar labels) and a numeric
            "Population" column (bar heights).
        figsize: Size of the matplotlib figure that is created.
        display_mean: When True, draw a dashed red horizontal line at the mean
            population and return that mean.

    Returns:
        The mean of the "Population" column when ``display_mean`` is True,
        otherwise None.
    """
    plt.figure(figsize=figsize)

    # Plot the frame that was passed in rather than a notebook-level global,
    # so the function works on any frame with the expected columns.
    sns.barplot(x=dataset["Breed"], y=dataset["Population"])

    mean = None
    if display_mean:
        # Average the numeric column explicitly: averaging the whole frame
        # trips over the string columns and relied on positional indexing.
        mean = dataset["Population"].mean()
        plt.axhline(mean, linestyle="--", linewidth=1, color="r")

    plt.title("Breeds distribution", size=20)
    plt.xticks(size=10, rotation=45, ha="right")
    plt.yticks(size=10)
    plt.ylabel("Count", size=16)

    return mean

***
## 2.1 Breeds distribution

How many breeds are present in the dataset?

In [None]:
# One sub-directory per breed. os.listdir() returns entries in arbitrary order
# and may include stray files, so keep directories only and sort them to make
# every downstream step reproducible.
breed_subdir = sorted(
    entry
    for entry in os.listdir(images_dir)
    if os.path.isdir(os.path.join(images_dir, entry))
)
print(f"There is {len(breed_subdir)} different dog breeds")

Construct a pandas DataFrame listing every breed together with its population (the number of files in its folder).

In [None]:
# Build one record per breed folder. Folder names look like
# "<id>-<breed name>", e.g. "n02085936-Maltese_dog".
breed_full: list = []
breed: list = []
population: list = []

for sub_d in breed_subdir:
    # Split on the FIRST hyphen only: breed names may themselves contain
    # hyphens and must not be truncated. A name without any hyphen is kept whole.
    _, _sep, _breed_name = sub_d.partition("-")
    breed_full.append(sub_d)
    breed.append(_breed_name if _sep else sub_d)
    # Population = number of entries in the breed folder.
    population.append(len(os.listdir(images_dir + "/" + sub_d)))

fre_df = pd.DataFrame(data={"Breed_full": breed_full, "Breed": breed, "Population": population})
fre_df.head()

Plot the breed populations and their mean.

In [None]:
# Bar chart of the populations in their current row order (exploration only).
if config["exploration"]:
    mean = breeds_distribution(dataset=fre_df)

Let's plot again, but this time after sorting on the Population field.

In [None]:
if config["exploration"]:
    # Rebind fre_df to its rows ordered from the most to the least populated
    # breed, then redraw the chart.
    fre_df = fre_df.sort_values(by="Population", ascending=False)
    mean = breeds_distribution(fre_df)

Let's take a look at the top three breeds to see whether they are sufficiently different.

In [None]:
if config["exploration"]:

    # Three sample pictures per breed; each breed fills one row of the grid.
    sample_images = {
        # Malteses
        "n02085936-Maltese_dog": [
            "n02085936_37.jpg", "n02085936_66.jpg", "n02085936_233.jpg",
        ],
        # Afghan
        "n02088094-Afghan_hound": [
            "n02088094_231.jpg", "n02088094_251.jpg", "n02088094_272.jpg",
        ],
        # Scottish deerhound
        "n02092002-Scottish_deerhound": [
            "n02092002_3.jpg", "n02092002_198.jpg", "n02092002_86.jpg",
        ],
    }

    # Paths are built from images_dir so the grid follows the configured dataset location.
    image_paths = [
        f"{images_dir}/{breed_dir}/{file_name}"
        for breed_dir, file_names in sample_images.items()
        for file_name in file_names
    ]

    fig, axs = plt.subplots(3, 3, figsize=(20, 10))
    plt.subplots_adjust(wspace=0.1, hspace=0.1)

    # axs.flat walks the grid row by row, matching the order of image_paths.
    for ax, image_path in zip(axs.flat, image_paths):
        image = tf.io.decode_jpeg(tf.io.read_file(image_path), channels=3)
        # decode_jpeg(channels=3) already yields a height x width x 3 tensor,
        # so no squeeze is needed before handing it to matplotlib.
        img = image.numpy()
        ax.imshow(img)
        ax.axis('off')

We can see that the images don't have the same size

In [None]:
# Keep the 10 most populated breeds. The sort is done here, not only in the
# optional exploration cell, so the selection is the same whatever the value
# of config["exploration"]. A stable sort keeps an already-sorted frame as is.
fre_df = fre_df.sort_values(by="Population", ascending=False, kind="mergesort").head(10)
fre_df

***
# 3 Custom CNN

## 3.0 Utils

In [None]:
def sync_dataset_directory(breeds: list, source_dir: str = "data/Images", target_dir: str = "data/Targets"):
    """Rebuild ``target_dir`` so that it contains exactly the given breed folders.

    Any existing ``target_dir`` is deleted first; each breed folder is then
    copied recursively from ``source_dir``.

    Args:
        breeds: Iterable of breed folder names located directly under ``source_dir``.
        source_dir: Directory holding one sub-folder per breed.
        target_dir: Directory to (re)create; its previous content is lost.

    Raises:
        OSError: Propagated from ``shutil``/``os`` when a source folder is
            missing or a directory cannot be created or removed.
    """
    if os.path.exists(target_dir):
        shutil.rmtree(target_dir)
    # makedirs instead of mkdir: also works when the parent of target_dir
    # does not exist yet.
    os.makedirs(target_dir)
    for breed in breeds:
        shutil.copytree(
            os.path.join(source_dir, breed),
            os.path.join(target_dir, breed),
        )

***
## 3.1 Dataset preparation

In [None]:
if config["custom"]:

    # Copy only the selected breeds from images_dir into target_dir, which is
    # the directory the loaders below read from.
    sync_dataset_directory(fre_df["Breed_full"], source_dir=images_dir, target_dir=target_dir)

    # The "training" and "validation" subsets must be requested with the SAME
    # validation_split and seed: the fraction tells the loader where to cut the
    # shuffled file list. 0.2 => 80% of the images for training, 20% for validation.
    validation_split = 0.2
    loader_options = dict(
        labels="inferred",
        label_mode="categorical",
        image_size=[224, 224],
        interpolation="nearest",
        batch_size=batch_size,
        seed=0,
        shuffle=True,
        validation_split=validation_split,
    )

    ds_train_ = image_dataset_from_directory(
        target_dir,
        subset="training",
        **loader_options
    )

    ds_valid_ = image_dataset_from_directory(
        target_dir,
        subset="validation",
        **loader_options
    )

    # Convert images to float32, cache the decoded batches and prefetch so
    # that loading the next batch overlaps with training.
    AUTOTUNE = tf.data.experimental.AUTOTUNE
    ds_train = (
        ds_train_
        .map(convert_to_float)
        .cache()
        .prefetch(buffer_size=AUTOTUNE)
    )
    ds_valid = (
        ds_valid_
        .map(convert_to_float)
        .cache()
        .prefetch(buffer_size=AUTOTUNE)
    )

***
## 3.2 Neural Network

In [None]:
if config["custom"]:

    # One softmax output per breed kept in fre_df.
    n_classes = fre_df.shape[0]

    # >>> base <<<
    # Single convolution + pooling stage applied to 224x224 RGB inputs.
    base_layers = [
        layers.Conv2D(
            filters=32,
            kernel_size=3,
            padding="same",
            activation="relu",
            input_shape=[224, 224, 3],
        ),
        layers.MaxPool2D(),
    ]

    # >>> head <<<
    # Flattened features -> small dense layer -> class probabilities.
    head_layers = [
        layers.Flatten(),
        # layers.Dropout(rate=0.3),
        # layers.BatchNormalization(),
        layers.Dense(units=6, activation="relu"),
        layers.Dense(units=n_classes, activation="softmax"),
    ]

    model = keras.Sequential(base_layers + head_layers)

    model.summary()

In [None]:
if config["custom"]:
    # Multi-class setup matching the one-hot ("categorical") labels of the datasets.
    compile_options = dict(
        optimizer="adam",
        loss="categorical_crossentropy",
        metrics=["categorical_accuracy"],
    )
    model.compile(**compile_options)

    # NOTE(review): this callback is built but NOT passed to fit() below (the
    # callbacks argument is commented out), so training always runs all epochs.
    # Confirm whether that is intended.
    early_stopping = EarlyStopping(min_delta=0.001, patience=5, restore_best_weights=True)

    fit_options = dict(
        validation_data=ds_valid,
        epochs=20,
        # callbacks=[early_stopping],
        # verbose=0,
    )
    history = model.fit(ds_train, **fit_options)

***
## 3.3 Results

In [None]:
if config["custom"]:
    # One row per epoch, one column per metric recorded by fit().
    history_frame = pd.DataFrame(history.history)

    # Two separate figures: training vs validation loss, then accuracy.
    curve_groups = (
        ["loss", "val_loss"],
        ["categorical_accuracy", "val_categorical_accuracy"],
    )
    for columns in curve_groups:
        history_frame[columns].plot()

***
# 4 Transfer learning CNN

In [None]:
# Placeholder for the transfer-learning section: nothing runs yet, even when
# config["transfer"] is True.
if config["transfer"]:

    # TODO: implement the transfer-learning pipeline.
    pass