In [None]:
import tensorflow as tf
from tensorflow.keras import layers, Sequential
import numpy as np
import tqdm
import itertools
import random
import uuid
import json
import os

In [None]:
# --- Data / model dimensions -------------------------------------------------
BATCH_SIZE = 32                   # mini-batch size
IMG_HEIGHT, IMG_WIDTH = 128, 128  # spatial size of the model input
NUM_CLASSES = 716                 # units of the final Dense (logits) layer

# --- Reproducibility ---------------------------------------------------------
RANDOM_SEED = 0                   # fed to the random, NumPy and TensorFlow RNGs

# --- Filesystem locations ----------------------------------------------------
DATA_PATH = "../../data/derived_data/data_augmented/"  # image directory that is loaded
RESULTS_PATH = "./hps_search_results/"                 # root folder for search results

In [None]:
# Seed every random number generator used in this notebook with the same value
# so that a fresh "Restart & Run All" starts from the same RNG state.
np.random.seed(RANDOM_SEED)      # NumPy global RNG
random.seed(RANDOM_SEED)         # Python stdlib RNG (used to shuffle the grid)
tf.random.set_seed(RANDOM_SEED)  # TensorFlow global seed

In [None]:
# Hyperparameter search space: each key maps to the list of candidate values
# for one hyperparameter. The keys are looked up by name in define_model.
hps_grid = {
    # Three convolutional layers: number of filters and kernel size each.
    'conv2d_1_filters': [8, 16, 32, 64],
    'conv2d_1_kernel': [3, 5],
    'conv2d_2_filters': [8, 16, 32, 64],
    'conv2d_2_kernel': [3, 5],
    'conv2d_3_filters': [8, 16, 32, 64],
    'conv2d_3_kernel': [3, 5],
    # Width of the hidden dense layer.
    'dense_1_units': [64, 128, 256],
    # Pooling layer class placed after every convolution (stored as a class,
    # instantiated later without arguments).
    'pooling': [layers.MaxPooling2D, layers.AveragePooling2D],
    # Optimizer class (instantiated later with the chosen learning rate).
    'optimizer': [
        tf.keras.optimizers.Adam,
        tf.keras.optimizers.RMSprop,
        tf.keras.optimizers.SGD,
    ],
    'learning_rate': [0.01, 0.001, 0.0001],
}

In [None]:
# Full Cartesian product of the grid: one tuple of values per combination,
# ordered like the keys of hps_grid.
params = [*itertools.product(*hps_grid.values())]

# Re-attach the hyperparameter names so each combination is a plain dict.
hps_combs = [
    {name: value for name, value in zip(hps_grid, combination)}
    for combination in params
]

In [None]:
# Shuffle the combinations in place so the search visits them in random order.
# The order depends on the state of the stdlib `random` generator, so
# re-running this cell alone reshuffles the already shuffled list.
random.shuffle(hps_combs)

In [None]:
def define_model(hps_comb):
    """Build and compile a three-stage CNN classifier for one hyperparameter set.

    Args:
        hps_comb: Mapping providing ``conv2d_{1,2,3}_filters``,
            ``conv2d_{1,2,3}_kernel``, ``dense_1_units``, ``pooling`` (a layer
            class instantiated with no arguments), ``optimizer`` (a class
            called with ``learning_rate=...``) and ``learning_rate``.

    Returns:
        A compiled ``Sequential`` model whose last layer emits
        ``NUM_CLASSES`` logits.
    """
    pool_cls = hps_comb['pooling']

    # Pixel rescaling followed by three Conv2D -> pooling stages.
    stack = [layers.Rescaling(1./255, input_shape=(IMG_HEIGHT, IMG_WIDTH, 3))]
    for stage in (1, 2, 3):
        stack.append(
            layers.Conv2D(
                hps_comb[f'conv2d_{stage}_filters'],
                hps_comb[f'conv2d_{stage}_kernel'],
                padding='same',
                activation='relu',
            )
        )
        stack.append(pool_cls())

    # Classification head; the final Dense layer has no activation, which
    # matches the from_logits=True loss below.
    stack.extend([
        layers.Flatten(),
        layers.Dense(hps_comb['dense_1_units'], activation='relu'),
        layers.Dense(NUM_CLASSES),
    ])

    model = Sequential(stack)

    optimizer_cls = hps_comb['optimizer']
    model.compile(
        optimizer=optimizer_cls(learning_rate=hps_comb['learning_rate']),
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=['accuracy'],
    )

    return model


def train_model(
        model,
        train_ds,
        val_ds,
        test_ds,
        checkpoint_filepath=None,
        epochs=1000
    ):
    """Fit ``model`` with early stopping, then evaluate it on ``test_ds``.

    Args:
        model: Compiled Keras model; it is trained in place.
        train_ds: Dataset passed to ``model.fit`` as training data.
        val_ds: Dataset passed to ``model.fit`` as ``validation_data``.
        test_ds: Dataset passed to ``model.evaluate`` after training.
        checkpoint_filepath: Optional path. When given, a ``ModelCheckpoint``
            callback saves the weights of the epoch with the best
            ``val_accuracy`` there. When ``None`` (default) nothing is saved.
        epochs: Maximum number of epochs; early stopping may end training
            earlier.

    Returns:
        dict: The per-epoch metric lists recorded by ``fit`` plus the scalar
        entries ``'test_loss'`` and ``'test_acc'``.
    """
    # Stop once val_loss has not improved for 5 epochs and roll the model
    # back to the best weights seen.
    callbacks = [
        tf.keras.callbacks.EarlyStopping(
            monitor='val_loss',
            patience=5,
            restore_best_weights=True,
        )
    ]

    # Checkpointing is opt-in: previously the argument was accepted but
    # silently ignored.
    # NOTE(review): some Keras versions restrict the file suffix allowed with
    # save_weights_only=True — confirm against the installed version.
    if checkpoint_filepath is not None:
        callbacks.append(
            tf.keras.callbacks.ModelCheckpoint(
                filepath=checkpoint_filepath,
                save_weights_only=True,
                monitor='val_accuracy',
                save_best_only=True,
            )
        )

    history = model.fit(
        train_ds,
        validation_data=val_ds,
        epochs=epochs,
        callbacks=callbacks,
        verbose='auto',
    )

    # Copy so the History object attached to the model is not mutated below.
    results = dict(history.history)

    # Unpacking relies on the model being compiled with exactly one metric
    # (evaluate then returns [loss, metric]).
    test_loss, test_acc = model.evaluate(test_ds)

    results['test_loss'] = test_loss
    results['test_acc'] = test_acc

    return results


def _json_default(value):
    """Fallback JSON encoder for values the ``json`` module cannot serialize."""
    if isinstance(value, type):
        # Layer / optimizer classes from the grid are stored by class name.
        return value.__name__
    if isinstance(value, np.generic):
        return value.item()
    if isinstance(value, np.ndarray):
        return value.tolist()
    raise TypeError(
        f'Object of type {type(value).__name__} is not JSON serializable'
    )


def run_hps_search(train_ds, val_ds, test_ds, output_path, max_trials=None):
    """Train one model per hyperparameter combination and save each result.

    Iterates over the module-level ``hps_combs`` list. For every combination a
    fresh model is built with ``define_model``, trained and evaluated with
    ``train_model``, and the metrics together with the hyperparameters are
    written to ``<output_path>/<search uuid>/<combination uuid>.json``.

    Args:
        train_ds: Training dataset forwarded to ``train_model``.
        val_ds: Validation dataset forwarded to ``train_model``.
        test_ds: Test dataset forwarded to ``train_model``.
        output_path: Root directory for results; a sub-directory named after
            a random UUID is created inside it for this search.
        max_trials: Optional non-negative cap on the number of combinations
            to run (taken from the front of ``hps_combs``). ``None``
            (default) runs all of them.
    """
    hps_search_id = uuid.uuid4()

    output_path = os.path.join(output_path, str(hps_search_id))
    os.makedirs(output_path, exist_ok=True)

    trials = hps_combs if max_trials is None else hps_combs[:max_trials]

    for hps_comb in tqdm.tqdm(trials):
        hps_comb_id = uuid.uuid4()

        # Drop Keras global state left by the previous trial; many models are
        # created in this loop and that state otherwise keeps accumulating.
        tf.keras.backend.clear_session()

        model = define_model(hps_comb)

        results = train_model(
            model,
            train_ds=train_ds,
            val_ds=val_ds,
            test_ds=test_ds,
        )

        # Record the hyperparameters next to the metrics. (Iterating the dict
        # directly yields only keys, so the old `for k, v in hps_comb` raised.)
        results.update(hps_comb)

        result_file = os.path.join(output_path, f'{hps_comb_id}.json')
        with open(result_file, 'w') as f:
            # The grid contains classes (pooling layer, optimizer) that are
            # not JSON serializable; _json_default stores them by name.
            json.dump(results, f, default=_json_default)


In [None]:
# Arguments shared by the training and validation loaders. Both calls must use
# the same validation_split, shuffle and seed so the two subsets come from the
# same split of the files.
dataset_kwargs = dict(
    labels="inferred",
    label_mode='int',
    class_names=None,
    color_mode='rgb',
    batch_size=BATCH_SIZE,
    validation_split=0.3,
    # image_size is (height, width); it was previously passed as
    # (width, height), which only worked because both are equal.
    image_size=(IMG_HEIGHT, IMG_WIDTH),
    shuffle=True,
    seed=RANDOM_SEED,
)

train_ds = tf.keras.utils.image_dataset_from_directory(
    DATA_PATH,
    subset="training",
    **dataset_kwargs,
)

val_ds = tf.keras.utils.image_dataset_from_directory(
    DATA_PATH,
    subset="validation",
    **dataset_kwargs,
)

# Split the held-out 30% in two: the first half of its batches becomes the
# test set, the remainder stays as the validation set.
# NOTE(review): val_ds is built with shuffle=True; if the dataset reshuffles on
# every iteration, take/skip may not select the same batches each epoch, so
# validation and test examples could mix — confirm.
val_batches = tf.data.experimental.cardinality(val_ds)
test_ds = val_ds.take(val_batches // 2)
val_ds = val_ds.skip(val_batches // 2)

In [None]:
# Launch the search; one JSON file per trained combination is written below
# RESULTS_PATH.
run_hps_search(
    train_ds=train_ds,
    val_ds=val_ds,
    test_ds=test_ds,
    output_path=RESULTS_PATH,
)