In [None]:
# External package imports
import tensorflow as tf
import glob
import os
import numpy as np
import matplotlib.pyplot as plt
import pickle
import json

In [None]:
# GPUs visible to TensorFlow; an empty list means TensorFlow found no GPU.
visible_gpus = tf.config.list_physical_devices('GPU')
visible_gpus

In [None]:
# Mount Google Drive: the dataset files and the pickled training results
# used below live under /content/drive/MyDrive.
from google.colab import drive

drive.mount(mountpoint='/content/drive')

In [None]:
# Import the custom data loader
from data_loader import create_data_loader

# Import the VGG model creator
from vgg_initializer import initialize_vgg, initialize_vgg_3d, initialize_vgg_lstm

In [None]:
# Dataset location on the mounted Drive: one TFRecord file holding the
# examples, plus the JSON metadata file that accompanies it.
tfrecord_file, metadata_file = (
    "/content/drive/MyDrive/ratsi_data.tfrecord",
    "/content/drive/MyDrive/ratsi_data.metadata.json",
)

In [None]:
class ResetStatesCallback(tf.keras.callbacks.Callback):
    """Keras callback that calls ``reset_states()`` on the model at every epoch start.

    ``Model.reset_states()`` resets stateful layers, so any state carried over
    from the previous epoch is discarded before the next one begins.
    """

    def on_epoch_begin(self, epoch, logs=None):
        # Keras attaches the model being trained to `self.model` before fit starts.
        # NOTE(review): `Model.reset_states` is a tf.keras (Keras 2) API — confirm
        # it is available if this notebook is run under Keras 3.
        trained_model = self.model
        trained_model.reset_states()

In [None]:
def _plot_history(name, history):
    """Plot validation vs. training loss and accuracy curves for one run.

    Args:
        name: Run label, used as the figure title.
        history: The ``History.history`` dict from ``model.fit``; must contain
            ``loss``, ``val_loss``, ``accuracy`` and ``val_accuracy``.
    """
    fig, (ax_loss, ax_acc) = plt.subplots(ncols=2, figsize=(10, 3))
    for ax, metric, ylabel in ((ax_loss, "loss", "Loss"), (ax_acc, "accuracy", "Accuracy")):
        ax.plot(history[f"val_{metric}"], color="tab:red", label="Validation")
        ax.plot(history[metric], color="tab:blue", label="Training")
        ax.legend()
        ax.set_xlabel("Epoch")
        ax.set_ylabel(ylabel)
    # Several runs are plotted from one cell; without a title they are indistinguishable.
    fig.suptitle(name)
    fig.tight_layout()
    # Render now so each run's curves appear right after its training log.
    plt.show()


def train_model(name, model, batch_size, seq_size=1, lr=1e-3, *,
                epochs=40, patience=4, valid_size=0.5,
                results_dir="/content/drive/MyDrive"):
    """Train ``model`` on the TFRecord dataset, plot its learning curves and pickle the results.

    The model is compiled with sparse categorical cross-entropy, Adam and an
    accuracy metric, then fit with early stopping on ``val_accuracy`` (best
    weights restored) and a ``ResetStatesCallback``.

    Args:
        name: Run label; printed, used as the figure title and in the result file name.
        model: Keras model to compile and train.
        batch_size: Passed to ``create_data_loader``.
        seq_size: Passed to ``create_data_loader``.
        lr: Adam learning rate.
        epochs: Maximum number of training epochs.
        patience: Epochs without ``val_accuracy`` improvement before stopping early.
        valid_size: Passed to ``create_data_loader`` as ``valid_size``.
        results_dir: Existing directory that receives ``{name}_result.dict``.
            Absolute by default so the file lands on Drive regardless of the
            kernel's current working directory.

    Returns:
        The pickled result dict with keys ``name``, ``history`` (the Keras
        history dict) and ``n_parameters``.
    """
    print(name)

    dataset_train, dataset_valid = create_data_loader(
        tfrecord_file,
        metadata_file,
        valid_size=valid_size,
        batch_size=batch_size,
        n_channels=3,
        seq_size=seq_size,
    )

    # NOTE(review): `model` is built by the caller before this point, so this
    # call cannot release that model's memory — confirm it is still needed here.
    tf.keras.backend.clear_session()

    model.compile(
        loss="sparse_categorical_crossentropy",
        optimizer=tf.keras.optimizers.Adam(learning_rate=lr),
        metrics=["accuracy"],
    )

    callbacks = [
        tf.keras.callbacks.EarlyStopping(
            monitor="val_accuracy",
            mode="max",
            patience=patience,
            verbose=0,
            restore_best_weights=True,
        ),
        ResetStatesCallback(),
    ]

    history = model.fit(
        x=dataset_train,
        epochs=epochs,
        validation_data=dataset_valid,
        callbacks=callbacks,
    )

    _plot_history(name, history.history)

    res = {
        "name": name,
        "history": history.history,
        # Covers trainable and non-trainable weights and is always a plain int.
        "n_parameters": model.count_params(),
    }

    result_path = os.path.join(results_dir, f"{name}_result.dict")
    with open(result_path, "wb") as f:
        pickle.dump(res, f)

    return res

In [None]:
with open(metadata_file, "r") as f:
    metadata = json.load(f)

# 3-channel input at the first two entries of metadata["img_size"]
# (presumably height, width — confirm against the data loader).
input_shape = (*metadata["img_size"][:2], 3)

# Each entry holds train_model keyword arguments, except "build": a
# zero-argument factory that constructs the model. Models are built one per
# loop iteration rather than all at once when this list is evaluated, so
# earlier models can be freed before the next one is created.
# Uncomment the entries to train.
models = [
    #{
    #    "name": "VGG11",
    #    "build": lambda: initialize_vgg(2, None, input_shape, dropout=0.01),
    #    "batch_size": 128
    #},
    #{
    #    "name": "VGG14",
    #    "build": lambda: initialize_vgg(3, None, input_shape, dropout=0.005),
    #    "batch_size": 128
    #},
    #{
    #    "name": "VGG17",
    #    "build": lambda: initialize_vgg(4, None, input_shape, dropout=0.005),
    #    "batch_size": 64
    #},
    #{
    #    "name": "VGG20",
    #    "build": lambda: initialize_vgg(5, None, input_shape, dropout=0.005),
    #    "batch_size": 64
    #},
    #{
    #    "name": "VGG11-3D",
    #    "build": lambda: initialize_vgg_3d(2, None, input_shape, seq_size=128, filter_reduction_fac=3, dropout=0.01),
    #    "batch_size": 1,
    #    "seq_size": 128
    #},
    #{
    #    "name": "VGG14-3D",
    #    "build": lambda: initialize_vgg_3d(3, None, input_shape, seq_size=128, filter_reduction_fac=3, dropout=0.01),
    #    "batch_size": 1,
    #    "seq_size": 128
    #},
    #{
    #    "name": "VGG17-3D",
    #    "build": lambda: initialize_vgg_3d(4, None, input_shape, seq_size=64, filter_reduction_fac=3, dropout=0.1),
    #    "batch_size": 1,
    #    "seq_size": 64
    #},
    #{
    #    "name": "VGG20-3D",
    #    "build": lambda: initialize_vgg_3d(5, None, input_shape, seq_size=128, filter_reduction_fac=3, dropout=0.01),
    #    "batch_size": 1,
    #    "seq_size": 128
    #},
    #{
    #    "name": "VGG11-LSTM",
    #    "build": lambda: initialize_vgg_lstm(2, 64, input_shape, seq_size=2, filter_reduction_fac=8, dropout=0.05),
    #    "batch_size": 64,
    #    "seq_size": 2,
    #},
    #{
    #    "name": "VGG14-LSTM",
    #    "build": lambda: initialize_vgg_lstm(3, 64, input_shape, seq_size=2, filter_reduction_fac=8, dropout=0.001),
    #    "batch_size": 64,
    #    "seq_size": 2,
    #},
    #{
    #    "name": "VGG17-LSTM",
    #    "build": lambda: initialize_vgg_lstm(4, 32, input_shape, seq_size=2, filter_reduction_fac=8, dropout=0.0001),
    #    "batch_size": 32,
    #    "seq_size": 2,
    #},
    #{
    #    "name": "VGG20-LSTM",
    #    "build": lambda: initialize_vgg_lstm(5, 32, input_shape, seq_size=2, filter_reduction_fac=8, dropout=0.05),
    #    "batch_size": 32,
    #    "seq_size": 2,
    #},
]
for spec in models:
    run_kwargs = dict(spec)
    build_model = run_kwargs.pop("build")
    # Reset Keras' global state from the previous run *before* building the
    # next model, which is the order the clear_session documentation recommends.
    tf.keras.backend.clear_session()
    train_model(model=build_model(), **run_kwargs)