In [None]:
import tensorflow as tf
from tensorflow.keras import layers, Sequential
import numpy as np
import tqdm
import itertools
import random
import uuid
import json
import os

In [None]:
# List the GPU devices TensorFlow can see; the cell output shows the result
tf.config.list_physical_devices(device_type='GPU')

In [None]:
# --- Data / model dimensions ---
BATCH_SIZE = 32                     # samples per batch
IMG_HEIGHT, IMG_WIDTH = 128, 128    # input image size in pixels
NUM_CLASSES = 716                   # number of units in the model's output layer
RANDOM_SEED = 0                     # seed used for the random number generators

# --- Filesystem locations ---
DATA_PATH = "../data/derived_data/data_augmented/"   # image directory read by the dataset loaders
RESULTS_PATH = "./hps_search_results/"               # root directory for search outputs

In [None]:
# Seed the Python, TensorFlow and NumPy random number generators (in that order)
# with the same value.
for _seed_rng in (random.seed, tf.random.set_seed, np.random.seed):
    _seed_rng(RANDOM_SEED)

In [None]:
# Hyper-parameter search space: each key maps to the list of candidate values
# to try. Key order matters because combinations are built from it later.
hps_grid = {
    'conv2d_1_filters': [16],
    'conv2d_1_kernel': [3],
    'conv2d_2_filters': [32],
    'conv2d_2_kernel': [3],
    'conv2d_3_filters': [64],
    'conv2d_3_kernel': [3],
    'dense_1_units': [128],
    'pooling': [layers.MaxPooling2D],
    'optimizer': [
        tf.keras.optimizers.Adam,
    ],
    'learning_rate': [0.001],
}

In [None]:
# Expand the grid into its Cartesian product, then turn every value tuple
# into a {hyper-parameter name: value} dict.
hps_names = tuple(hps_grid)
params = list(itertools.product(*hps_grid.values()))
hps_combs = [dict(zip(hps_names, combo)) for combo in params]

In [None]:
# Randomise the order in which combinations are tried. The shuffle is in place
# and draws from the `random` module RNG, which is seeded with RANDOM_SEED in
# an earlier cell.
random.shuffle(hps_combs)

In [None]:
def define_model(hps_comb):
    """Build and compile a small CNN classifier from one hyper-parameter combination.

    Architecture: a Rescaling layer (multiplies inputs by 1/255), three
    Conv2D + ReLU blocks each followed by a pooling layer, then
    Flatten -> Dense(ReLU) -> Dense(NUM_CLASSES). The last layer has no
    activation, so the model outputs raw logits (the loss is configured with
    ``from_logits=True`` accordingly).

    Args:
        hps_comb: Mapping with the keys ``conv2d_1_filters``, ``conv2d_1_kernel``,
            ``conv2d_2_filters``, ``conv2d_2_kernel``, ``conv2d_3_filters``,
            ``conv2d_3_kernel``, ``dense_1_units``, ``optimizer`` (an optimizer
            class accepting a ``learning_rate`` keyword) and ``learning_rate``.
            The optional ``pooling`` key holds the pooling layer class that is
            instantiated (without arguments) after every convolution; when it
            is missing, ``layers.MaxPooling2D`` is used.

    Returns:
        The compiled ``Sequential`` model.
    """
    # Honour the 'pooling' hyper-parameter instead of hard-coding max pooling;
    # the fallback keeps combinations without that key working as before.
    pooling_cls = hps_comb.get('pooling', layers.MaxPooling2D)

    model_layers = [
        layers.Rescaling(1./255, input_shape=(IMG_HEIGHT, IMG_WIDTH, 3)),
    ]

    # Three convolutional blocks: Conv2D (ReLU, 'same' padding) + pooling.
    for block_idx in (1, 2, 3):
        model_layers.append(
            layers.Conv2D(
                hps_comb[f'conv2d_{block_idx}_filters'],
                hps_comb[f'conv2d_{block_idx}_kernel'],
                padding='same',
                activation='relu',
            )
        )
        model_layers.append(pooling_cls())

    # Classification head.
    model_layers.extend([
        layers.Flatten(),
        layers.Dense(hps_comb['dense_1_units'], activation='relu'),
        layers.Dense(NUM_CLASSES),
    ])

    model = Sequential(model_layers)

    model.compile(
        optimizer=hps_comb['optimizer'](learning_rate=hps_comb['learning_rate']),
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=['accuracy'],
    )

    return model


def train_model(
        model,
        train_ds,
        val_ds,
        test_ds,
        checkpoint_filepath=None,
        save_model_path=None,
        epochs=15,
        early_stopping_patience=None,
    ):
    """Fit ``model``, optionally save it, and evaluate it on the test set.

    Args:
        model: Compiled model exposing ``fit``, ``evaluate`` and ``save``.
            ``evaluate`` must return exactly two values (loss, accuracy).
        train_ds: Training data passed to ``model.fit``.
        val_ds: Validation data passed to ``model.fit``.
        test_ds: Data passed to ``model.evaluate`` after training.
        checkpoint_filepath: When given, a ``ModelCheckpoint`` callback writes
            the weights of the epoch with the best ``val_accuracy`` to this
            path. When ``None`` (default) no checkpoint is written.
        save_model_path: When given, the trained model is saved to this path.
        epochs: Maximum number of training epochs.
        early_stopping_patience: When given, an ``EarlyStopping`` callback
            monitors ``val_loss`` with this patience and restores the best
            weights. When ``None`` (default) training always runs for
            ``epochs`` epochs.

    Returns:
        A new dict holding the per-epoch training history plus the scalar
        entries ``'test_loss'`` and ``'test_acc'``.
    """
    # Callbacks are opt-in so the default run is a plain fixed-length fit.
    callbacks = []

    if early_stopping_patience is not None:
        callbacks.append(
            tf.keras.callbacks.EarlyStopping(
                monitor='val_loss',
                patience=early_stopping_patience,
                restore_best_weights=True,
            )
        )

    if checkpoint_filepath:
        callbacks.append(
            tf.keras.callbacks.ModelCheckpoint(
                filepath=checkpoint_filepath,
                save_weights_only=True,
                monitor='val_accuracy',
                save_best_only=True,
            )
        )

    # Train model
    history = model.fit(
        train_ds,
        validation_data=val_ds,
        epochs=epochs,
        callbacks=callbacks,
        verbose='auto',
    )

    if save_model_path:
        model.save(save_model_path)

    # Copy the history so adding the test metrics below does not mutate the
    # History object owned by the model.
    results = dict(history.history)

    test_loss, test_acc = model.evaluate(test_ds)

    # Plain floats keep the result JSON-serialisable.
    results['test_loss'] = float(test_loss)
    results['test_acc'] = float(test_acc)

    return results


def run_hps_search(train_ds, val_ds, test_ds, output_path):
    """Train one model per hyper-parameter combination and persist the results.

    Iterates over the module-level ``hps_combs`` list. All artefacts of one
    search go into ``<output_path>/<search uuid>/``; each combination gets a
    fresh uuid and produces two files there: ``<uuid>.keras`` (the trained
    model) and ``<uuid>.json`` (training history, test metrics and the
    stringified hyper-parameters).

    Args:
        train_ds: Training dataset forwarded to ``train_model``.
        val_ds: Validation dataset forwarded to ``train_model``.
        test_ds: Test dataset forwarded to ``train_model``.
        output_path: Root directory under which the search directory is created.
    """
    hps_search_id = uuid.uuid4()

    # os.path.join avoids the doubled slashes that string concatenation gives
    # when output_path already ends with a separator.
    search_dir = os.path.join(output_path, str(hps_search_id))
    os.makedirs(search_dir, exist_ok=True)

    for hps_comb in tqdm.tqdm(hps_combs):
        print(hps_comb)
        hps_comb_id = uuid.uuid4()

        # Release Keras global state left over from the previous trial so
        # memory use does not grow with the number of models built in the loop.
        tf.keras.backend.clear_session()

        model = define_model(hps_comb)

        results = train_model(
            model,
            train_ds=train_ds,
            val_ds=val_ds,
            test_ds=test_ds,
            save_model_path=os.path.join(search_dir, f'{hps_comb_id}.keras'),
        )

        # Hyper-parameter values include classes, which are not
        # JSON-serialisable, so every value is stored as its string form.
        for k, v in hps_comb.items():
            results[k] = str(v)

        with open(os.path.join(search_dir, f'{hps_comb_id}.json'), 'w') as f:
            json.dump(results, f)


In [None]:
# Arguments shared by the training and validation loaders. Both calls must use
# the same validation_split, shuffle and seed settings, otherwise the two
# subsets are not guaranteed to be disjoint. Keeping them in one place
# prevents the two calls from drifting apart.
dataset_kwargs = dict(
    labels="inferred",
    label_mode='int',
    class_names=None,
    color_mode='rgb',
    batch_size=BATCH_SIZE,
    validation_split=0.3,
    image_size=(IMG_HEIGHT, IMG_WIDTH),  # Keras expects (height, width)
    shuffle=True,
    seed=RANDOM_SEED,
    interpolation='nearest',
)

train_ds = tf.keras.utils.image_dataset_from_directory(
    DATA_PATH,
    subset="training",
    **dataset_kwargs,
)

val_ds = tf.keras.utils.image_dataset_from_directory(
    DATA_PATH,
    subset="validation",
    **dataset_kwargs,
)

# Split the validation batches in two: the first half becomes the test set,
# the remainder stays as the validation set.
# NOTE: with shuffle=True the loader reshuffles on every pass over the dataset,
# so two separate iterations of test_ds / val_ds yield samples in different
# orders. Labels and predictions must therefore be collected in the same pass.
val_batches = tf.data.experimental.cardinality(val_ds)
test_ds = val_ds.take(val_batches // 2)
val_ds = val_ds.skip(val_batches // 2)

In [None]:
import matplotlib.pyplot as plt

# Preview the first nine images of one training batch on a 3x3 grid
fig = plt.figure(figsize=(10, 10))
for images, labels in train_ds.take(1):
    for i in range(9):
        ax = fig.add_subplot(3, 3, i + 1)
        # Cast to uint8 so imshow renders the pixel values as 0-255 colours
        ax.imshow(images[i].numpy().astype("uint8"))
        ax.axis("off")

In [None]:
# Launch the hyper-parameter search; models and metrics are written under RESULTS_PATH
run_hps_search(
    train_ds=train_ds,
    val_ds=val_ds,
    test_ds=test_ds,
    output_path=RESULTS_PATH,
)

# Analyse predictions

In [None]:
import matplotlib.pyplot as plt
from sklearn.metrics import classification_report, confusion_matrix, ConfusionMatrixDisplay

In [None]:
# Load a previously trained model from the search results.
# The path points at one specific search run / combination; update it to the
# ids produced by your own run.
model_path = 'hps_search_results/2da1ec8e-7a2f-4150-80bc-c4745fdb8c86/a01928f8-7a00-4e7a-a249-c0edacefecf4.keras'
model = tf.keras.models.load_model(model_path)

In [None]:
# Sanity check: re-evaluate the loaded model on the test set and show the metrics
sanity_metrics = model.evaluate(test_ds)
sanity_metrics

In [None]:
# Run the model over the whole test set; the model's last layer has no
# activation, so these are raw logits.
y_pred = model.predict(x=test_ds)
y_pred.shape

In [None]:
# Normalise the logits into a probability distribution per sample by applying
# softmax along the last (class) axis.
y_pred = tf.nn.softmax(logits=y_pred, axis=-1)
y_pred.shape

In [None]:
# Predicted class of each sample = index of its highest score
y_pred = np.asarray(y_pred).argmax(axis=1)
y_pred.shape

In [None]:
# Get true classes AND predictions in a single pass over the Dataset object.
# test_ds derives from a dataset built with shuffle=True, so every iteration
# yields the samples in a different order: labels gathered in a separate pass
# would not line up with predictions from an earlier model.predict(test_ds).
# y_pred is therefore recomputed here, batch by batch, next to its labels.
y_true_batches, y_pred_batches = [], []
for batch_images, batch_labels in test_ds:
    y_true_batches.append(batch_labels.numpy())
    # argmax over the logits equals argmax over the softmax probabilities
    batch_logits = model.predict_on_batch(batch_images)
    y_pred_batches.append(np.argmax(batch_logits, axis=1))

# concatenate (rather than np.array(...).flatten()) also copes with a smaller
# final batch
y_true = np.concatenate(y_true_batches)
y_pred = np.concatenate(y_pred_batches)
y_true.shape

In [None]:
# Computing Confusion Matrix to evaluate accuracy of classification.
# The full label range is passed explicitly: by default scikit-learn only keeps
# the classes that occur in y_true or y_pred, so a class missing from the test
# set would shift rows/columns and matrix index i would no longer be class i.
c_m = confusion_matrix(y_true, y_pred, labels=np.arange(NUM_CLASSES))

In [None]:
# Class names exposed by the training dataset, used as axis labels when plotting
labels = list(train_ds.class_names)

In [None]:
# Enlarge the default figure and font size so the tick labels stay readable
plt.rcParams.update({
    'figure.figsize': (35.0, 35.0),
    'font.size': 20,
})

# Plot only the top-left 50x50 corner of the confusion matrix, labelled with
# the first 50 class names
display_c_m = ConfusionMatrixDisplay(
    confusion_matrix=c_m[:50, :50],
    display_labels=labels[:50],
)
display_c_m.plot(cmap='OrRd', xticks_rotation=90)
plt.show()