# **Load Packages**

In [None]:
# Install tensorflow-addons to have access for some optimizers like AdamW
! pip install -U tensorflow-addons

In [None]:
import tensorflow as tf
import keras
from keras import layers
import tensorflow_addons as tfa
import matplotlib.pyplot as plt
import numpy as np
from tensorflow.keras import regularizers
from tensorflow.keras.datasets import cifar100, cifar10
import tensorflow_datasets as tfds
from IPython.display import clear_output
from tensorflow.keras.utils import to_categorical
from scipy.stats import norm
import random
from matplotlib import pyplot
from mpl_toolkits.mplot3d import Axes3D
from numpy.random import rand
from pylab import figure
import tensorflow_hub as hub
from sklearn.utils import shuffle
import cv2
from tensorflow.keras import backend as K
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import Callback
from IPython.display import Image, display
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from sklearn import datasets
from sklearn import manifold
from sklearn.decomposition import PCA
import pickle
from tensorflow.keras.metrics import SparseCategoricalAccuracy, SparseTopKCategoricalAccuracy, Precision, Recall
from tensorflow.keras.callbacks import ModelCheckpoint
from sklearn.metrics import precision_recall_fscore_support
import numpy as np
from keras.models import load_model
from sklearn.metrics import accuracy_score, recall_score, f1_score, precision_score
from sklearn.metrics import precision_recall_fscore_support
# Import the context normalization layer
import os
import sys
package_dir = os.getcwd()
root_dir = os.path.dirname(package_dir)
sys.path.append(root_dir)
from normalization.layers import ContextNormalization


# **Define some functions**

In [None]:
# Per-channel mean and standard deviation of an image dataset
def compute_mean_std(dataset):
    """Return the per-channel mean and standard deviation of an image dataset.

    Args:
        dataset: sized, indexable collection of images; every image is read as
            ``image[:, :, c]`` for the channels ``c = 0, 1, 2``.

    Returns:
        ``(mean, std)``: two 3-tuples with the statistic of channel 0, 1 and 2,
        each taken over every pixel of every image.
    """
    means, stds = [], []
    for channel in range(3):
        # Stack this channel of every image along a third axis, then reduce it.
        stacked = np.dstack([image[:, :, channel] for image in dataset])
        means.append(np.mean(stacked))
        stds.append(np.std(stacked))
    return tuple(means), tuple(stds)

In [None]:
# Augmentation pipeline applied to the image input of every model built in this notebook:
# resize to 72x72, then random horizontal flip, slight rotation and zoom.
data_augmentation = tf.keras.Sequential(
    layers=[
        tf.keras.layers.Resizing(height=72, width=72),
        tf.keras.layers.RandomFlip(mode="horizontal"),
        tf.keras.layers.RandomRotation(factor=0.02),
        tf.keras.layers.RandomZoom(height_factor=0.2, width_factor=0.2),
    ],
    name="data_augmentation",
)

In [None]:
# Helpers to save a list to disk and load it back (pickle format)
def write_list(a_list, file_name):
    """Pickle ``a_list`` into the binary file ``file_name`` and print a confirmation.

    Args:
        a_list: object to serialise.
        file_name: path of the file to create or overwrite.
    """
    with open(file_name, mode='wb') as out_file:
        pickle.Pickler(out_file).dump(a_list)
        print('Done writing list into a binary file')

def read_list(file_name):
    """Load and return the object pickled in ``file_name``.

    Unpickling can execute arbitrary code, so only use this on files you trust.

    Args:
        file_name: path of a pickle file.

    Returns:
        The unpickled object.
    """
    with open(file_name, mode='rb') as in_file:
        return pickle.Unpickler(in_file).load()

# **CIFAR-10**

In [None]:
# Hyper-parameters of the CIFAR-10 experiment
class CFG:
    """Constants read by the CIFAR-10 cells that follow."""
    # Data
    num_classes = 10       # width of the one-hot class labels
    num_contexts = 3       # width of the one-hot context vectors
    # Optimisation
    batch_size = 256
    num_epochs = 100
    learning_rate = 0.001
    weight_decay = 1e-4

In [None]:
# Load the CIFAR-10 train/test splits
(x_train, y_train), (x_test, y_test) = cifar10.load_data()

# One-hot encode the labels (despite the `_sparse` suffix, these arrays are one-hot)
y_train_sparse, y_test_sparse = (
    to_categorical(labels, num_classes=CFG.num_classes) for labels in (y_train, y_test)
)

# Standardise both splits with the per-channel statistics from compute_mean_std,
# estimated on the raw training images only
mean, std = compute_mean_std(x_train)
x_train = (x_train.astype('float32') - mean) / std
x_test = (x_test.astype('float32') - mean) / std

In [None]:
# Load the context labels generated by the GMM and one-hot encode them
# (assumes each file holds a flat sequence of integer labels in [0, CFG.num_contexts) — TODO confirm)
gmm_train_labels = read_list("gmm/gmm_cifar10_tr_labels")
gmm_test_labels = read_list("gmm/gmm_cifar10_ts_labels")

# Row i gets a 1 in column labels[i]. Indexing replaces the former loops, whose
# loop variable `iter` shadowed the builtin for every later cell of the notebook.
context_train = np.zeros((len(gmm_train_labels), CFG.num_contexts), dtype=int)
context_train[np.arange(len(gmm_train_labels)), gmm_train_labels] = 1

context_test = np.zeros((len(gmm_test_labels), CFG.num_contexts), dtype=int)
context_test[np.arange(len(gmm_test_labels)), gmm_test_labels] = 1

In [None]:
# Build the ConvNet model
def build_cnn(num_classes=CFG.num_classes, num_contexts=CFG.num_contexts, learning_rate=CFG.learning_rate, weight_decay=CFG.weight_decay):
    """Build and compile the CNN whose third convolution is followed by ContextNormalization.

    The default argument values are read from ``CFG`` once, when this cell is executed.

    Args:
        num_classes: number of units of the final Dense layer.
        num_contexts: length of the per-sample context vector fed to the second input.
        learning_rate: learning rate passed to AdamW.
        weight_decay: weight decay passed to AdamW.

    Returns:
        A compiled model taking ``[image, context]`` and returning ``num_classes``
        logits (the Dense head has no activation; the loss uses ``from_logits=True``).
    """
    # Two inputs: a 32x32 RGB image and an int32 context vector.
    input_image = tf.keras.layers.Input(shape=(32,32,3))
    context_id = tf.keras.layers.Input(shape=(num_contexts,), dtype='int32')
    # The shared augmentation pipeline also resizes the image to 72x72.
    augmented = data_augmentation(input_image)

    # Block 1: conv -> batch norm -> ReLU
    conv1 = tf.keras.layers.Conv2D(filters=64, kernel_size=(5,5), strides=(1,1),  padding="same")(augmented)
    conv1 = tf.keras.layers.BatchNormalization()(conv1)
    conv1 = tf.keras.layers.ReLU()(conv1)

    pool1 = tf.keras.layers.MaxPool2D(pool_size=(3,3), strides=(2,2), padding="same")(conv1)

    # Block 2: conv -> batch norm -> ReLU
    conv2 = tf.keras.layers.Conv2D(filters=128, kernel_size=(5,5), strides=(1,1),  padding="same")(pool1)
    conv2 = tf.keras.layers.BatchNormalization()(conv2)
    conv2 = tf.keras.layers.ReLU()(conv2)

    pool2 = tf.keras.layers.MaxPool2D(pool_size=(3,3), strides=(2,2), padding="same")(conv2)

    # Block 3: the only block normalised by ContextNormalization, which receives
    # the feature map together with the context vector.
    conv3 = tf.keras.layers.Conv2D(filters=128, kernel_size=(5,5), strides=(1,1),  padding="same")(pool2)
    conv3 = ContextNormalization()([conv3, context_id])
    conv3 = tf.keras.layers.ReLU()(conv3)

    pool3 = tf.keras.layers.MaxPool2D(pool_size=(3,3), strides=(2, 2), padding="same")(conv3)

    # Block 4: conv -> batch norm -> ReLU
    conv4 = tf.keras.layers.Conv2D(filters=256, kernel_size=(5, 5), strides=(1, 1),  padding="same")(pool3)
    conv4 = tf.keras.layers.BatchNormalization()(conv4)
    conv4 = tf.keras.layers.ReLU()(conv4)

    # Classification head: average pooling, flatten, linear layer producing logits.
    pool4 = tf.keras.layers.AveragePooling2D(pool_size=(4, 4), strides=(1, 1))(conv4)
    flattened = tf.keras.layers.Flatten()(pool4)
    outputs = tf.keras.layers.Dense(num_classes)(flattened)

    model = tf.keras.models.Model([input_image, context_id], outputs)

    optimizer = tfa.optimizers.AdamW(
        learning_rate=learning_rate, weight_decay=weight_decay
    )

    # The metric order below is the order in which `evaluate` returns values after the loss.
    # NOTE(review): Precision and Recall are built with default arguments while the model
    # outputs logits — confirm their default thresholding is meaningful on un-normalised outputs.
    model.compile(
        optimizer=optimizer,
        loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),
        metrics=[
            tf.keras.metrics.CategoricalAccuracy(name="accuracy"),
            tf.keras.metrics.TopKCategoricalAccuracy(5, name="top-5-accuracy"),
            Precision(name="precision"),
            Recall(name="recall"),
            tfa.metrics.F1Score(num_classes=num_classes, name="f1-score")

        ],
    )
    return model

In [None]:
# Train our model and evaluate the best checkpoint on the test set every 25 epochs
class CustomValidationCallback(tf.keras.callbacks.Callback):
    """Periodically evaluate the best checkpoint on the CIFAR-10 test set.

    Every ``eval_every`` epochs the weights stored at ``checkpoint_path`` are loaded,
    the model is evaluated on ``[x_test, context_test]`` / ``y_test_sparse`` and the
    metrics are printed. The in-training weights are restored afterwards, so the
    evaluation no longer rewinds training to the checkpointed epoch.

    Args:
        checkpoint_path: weights file to evaluate; defaults to the path used by this notebook.
        eval_every: evaluation period in epochs.
    """

    def __init__(self, checkpoint_path="model.ckpt", eval_every=25):
        super().__init__()
        self.checkpoint_path = checkpoint_path
        self.eval_every = eval_every

    def on_epoch_end(self, epoch, logs=None):
        if (epoch + 1) % self.eval_every != 0:
            return

        # Snapshot the current weights: load_weights overwrites them in place.
        training_weights = self.model.get_weights()
        try:
            self.model.load_weights(self.checkpoint_path)
            # Values come back as the loss followed by the metrics in compile order.
            _, accuracy, top_5_accuracy, precision, recall, f1 = self.model.evaluate(
                [x_test, context_test], y_test_sparse
            )
        finally:
            self.model.set_weights(training_weights)

        print(f"Test accuracy at epoch {epoch + 1}: {round(accuracy * 100, 2)}%")
        print(f"Test top 5 accuracy at epoch {epoch + 1}: {round(top_5_accuracy * 100, 2)}%")
        print(f"Precision at epoch {epoch + 1}: {round(precision * 100, 2)}%")
        print(f"Recall at epoch {epoch + 1}: {round(recall * 100, 2)}%")
        # Scale F1 to percent like the other metrics (it was printed as a raw fraction with a % sign).
        print(f"F1-score at epoch {epoch + 1}: {np.round(np.asarray(f1) * 100, 2)}%")

def run_model(model, filepath, batch_size=CFG.batch_size, num_epochs=CFG.num_epochs):
    """Fit ``model`` on the CIFAR-10 training arrays and return the Keras history.

    Training data is read from the notebook globals ``x_train``, ``context_train``
    and ``y_train_sparse``; 10% of it is held out for validation.

    Args:
        model: compiled model accepting ``[image, context]`` inputs.
        filepath: where the weights with the best ``val_accuracy`` are saved.
        batch_size: mini-batch size (default read from ``CFG`` when this cell runs).
        num_epochs: number of training epochs (default read from ``CFG`` when this cell runs).

    Returns:
        The object returned by ``model.fit``.
    """
    # Keeps only the weights of the epoch with the best validation accuracy.
    best_weights_saver = ModelCheckpoint(
        filepath,
        monitor="val_accuracy",
        save_best_only=True,
        save_weights_only=True,
    )
    periodic_test_eval = CustomValidationCallback()

    return model.fit(
        x=[x_train, context_train],
        y=y_train_sparse,
        batch_size=batch_size,
        epochs=num_epochs,
        validation_split=0.1,
        callbacks=[best_weights_saver, periodic_test_eval],
    )


# Build the model and train it on the first GPU; `history` is used by the cells below
with tf.device('/device:GPU:0'):
    model = build_cnn()
    history = run_model(model=model, filepath="model.ckpt")

In [None]:
# Save each metric curve to its own file: (history key, output file), in write order.
# NOTE(review): the CIFAR-100 and Tiny ImageNet sections write the same file names,
# so running the whole notebook overwrites these files.
for history_key, file_name in (
    ('accuracy', 'accuracy'),
    ('val_accuracy', 'val_accuracy'),
    ('loss', 'loss'),
    ('val_loss', 'val_loss'),
    ('precision', 'precision'),
    ('val_precision', 'val_precision'),
    ('recall', 'recall'),
    ('val_recall', 'val_recall'),
    ('val_f1-score', 'val_f1'),
    ('f1-score', 'f1'),
):
    write_list(history.history[history_key], file_name)

In [None]:
# See model generalization (training loss vs. validation loss)
loss = history.history['loss']
validation_loss = history.history['val_loss']
# `epochs` was never defined in this notebook (NameError); number the epochs from 1.
epochs = range(1, len(loss) + 1)
plt.figure(figsize=(10, 6))
plt.plot(epochs, loss, 'bo', label='Training Loss')
plt.plot(epochs, validation_loss, 'r', label='Validation Loss')
plt.title('Training and Validation Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.grid(True)
plt.show()

# **CIFAR-100**

In [None]:
# Hyper-parameters of the CIFAR-100 experiment
class CFG:
    """Constants read by the CIFAR-100 cells that follow."""
    # Data
    num_classes = 100      # width of the one-hot class labels
    num_contexts = 3       # width of the one-hot context vectors
    # Optimisation
    batch_size = 256
    num_epochs = 100
    learning_rate = 0.001
    weight_decay = 1e-4

In [None]:
# Load the CIFAR-100 train/test splits
(x_train, y_train), (x_test, y_test) = cifar100.load_data()

# One-hot encode the labels (despite the `_sparse` suffix, these arrays are one-hot)
y_train_sparse, y_test_sparse = (
    to_categorical(labels, num_classes=CFG.num_classes) for labels in (y_train, y_test)
)

# Standardise both splits with the per-channel statistics from compute_mean_std,
# estimated on the raw training images only
mean, std = compute_mean_std(x_train)
x_train = (x_train.astype('float32') - mean) / std
x_test = (x_test.astype('float32') - mean) / std

In [None]:
# Load the context labels generated by the GMM and one-hot encode them
# (assumes each file holds a flat sequence of integer labels in [0, CFG.num_contexts) — TODO confirm)
gmm_train_labels = read_list("gmm/gmm_cifar100_tr_labels")
gmm_test_labels = read_list("gmm/gmm_cifar100_ts_labels")

# Row i gets a 1 in column labels[i]. Indexing replaces the former loops, whose
# loop variable `iter` shadowed the builtin for every later cell of the notebook.
context_train = np.zeros((len(gmm_train_labels), CFG.num_contexts), dtype=int)
context_train[np.arange(len(gmm_train_labels)), gmm_train_labels] = 1

context_test = np.zeros((len(gmm_test_labels), CFG.num_contexts), dtype=int)
context_test[np.arange(len(gmm_test_labels)), gmm_test_labels] = 1

In [None]:
# Build the ConvNet model
def build_cnn(num_classes=CFG.num_classes, num_contexts=CFG.num_contexts, learning_rate=CFG.learning_rate, weight_decay=CFG.weight_decay):
    """Build and compile the CNN whose third convolution is followed by ContextNormalization.

    The default argument values are read from ``CFG`` once, when this cell is executed.

    Args:
        num_classes: number of units of the final Dense layer.
        num_contexts: length of the per-sample context vector fed to the second input.
        learning_rate: learning rate passed to AdamW.
        weight_decay: weight decay passed to AdamW.

    Returns:
        A compiled model taking ``[image, context]`` and returning ``num_classes``
        logits (the Dense head has no activation; the loss uses ``from_logits=True``).
    """
    # Two inputs: a 32x32 RGB image and an int32 context vector.
    input_image = tf.keras.layers.Input(shape=(32,32,3))
    context_id = tf.keras.layers.Input(shape=(num_contexts,), dtype='int32')
    # The shared augmentation pipeline also resizes the image to 72x72.
    augmented = data_augmentation(input_image)

    # Block 1: conv -> batch norm -> ReLU
    conv1 = tf.keras.layers.Conv2D(filters=64, kernel_size=(5,5), strides=(1,1),  padding="same")(augmented)
    conv1 = tf.keras.layers.BatchNormalization()(conv1)
    conv1 = tf.keras.layers.ReLU()(conv1)

    pool1 = tf.keras.layers.MaxPool2D(pool_size=(3,3), strides=(2,2), padding="same")(conv1)

    # Block 2: conv -> batch norm -> ReLU
    conv2 = tf.keras.layers.Conv2D(filters=128, kernel_size=(5,5), strides=(1,1),  padding="same")(pool1)
    conv2 = tf.keras.layers.BatchNormalization()(conv2)
    conv2 = tf.keras.layers.ReLU()(conv2)

    pool2 = tf.keras.layers.MaxPool2D(pool_size=(3,3), strides=(2,2), padding="same")(conv2)

    # Block 3: the only block normalised by ContextNormalization, which receives
    # the feature map together with the context vector.
    conv3 = tf.keras.layers.Conv2D(filters=128, kernel_size=(5,5), strides=(1,1),  padding="same")(pool2)
    conv3 = ContextNormalization()([conv3, context_id])
    conv3 = tf.keras.layers.ReLU()(conv3)

    pool3 = tf.keras.layers.MaxPool2D(pool_size=(3,3), strides=(2, 2), padding="same")(conv3)

    # Block 4: conv -> batch norm -> ReLU
    conv4 = tf.keras.layers.Conv2D(filters=256, kernel_size=(5, 5), strides=(1, 1),  padding="same")(pool3)
    conv4 = tf.keras.layers.BatchNormalization()(conv4)
    conv4 = tf.keras.layers.ReLU()(conv4)

    # Classification head: average pooling, flatten, linear layer producing logits.
    pool4 = tf.keras.layers.AveragePooling2D(pool_size=(4, 4), strides=(1, 1))(conv4)
    flattened = tf.keras.layers.Flatten()(pool4)
    outputs = tf.keras.layers.Dense(num_classes)(flattened)

    model = tf.keras.models.Model([input_image, context_id], outputs)

    optimizer = tfa.optimizers.AdamW(
        learning_rate=learning_rate, weight_decay=weight_decay
    )

    # The metric order below is the order in which `evaluate` returns values after the loss.
    # NOTE(review): Precision and Recall are built with default arguments while the model
    # outputs logits — confirm their default thresholding is meaningful on un-normalised outputs.
    model.compile(
        optimizer=optimizer,
        loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),
        metrics=[
            tf.keras.metrics.CategoricalAccuracy(name="accuracy"),
            tf.keras.metrics.TopKCategoricalAccuracy(5, name="top-5-accuracy"),
            Precision(name="precision"),
            Recall(name="recall"),
            tfa.metrics.F1Score(num_classes=num_classes, name="f1-score")

        ],
    )
    return model

In [None]:
# Train our model and evaluate the best checkpoint on the test set every 25 epochs
class CustomValidationCallback(tf.keras.callbacks.Callback):
    """Periodically evaluate the best checkpoint on the CIFAR-100 test set.

    Every ``eval_every`` epochs the weights stored at ``checkpoint_path`` are loaded,
    the model is evaluated on ``[x_test, context_test]`` / ``y_test_sparse`` and the
    metrics are printed. The in-training weights are restored afterwards, so the
    evaluation no longer rewinds training to the checkpointed epoch.

    Args:
        checkpoint_path: weights file to evaluate; defaults to the path used by this notebook.
        eval_every: evaluation period in epochs.
    """

    def __init__(self, checkpoint_path="model.ckpt", eval_every=25):
        super().__init__()
        self.checkpoint_path = checkpoint_path
        self.eval_every = eval_every

    def on_epoch_end(self, epoch, logs=None):
        if (epoch + 1) % self.eval_every != 0:
            return

        # Snapshot the current weights: load_weights overwrites them in place.
        training_weights = self.model.get_weights()
        try:
            self.model.load_weights(self.checkpoint_path)
            # Values come back as the loss followed by the metrics in compile order.
            _, accuracy, top_5_accuracy, precision, recall, f1 = self.model.evaluate(
                [x_test, context_test], y_test_sparse
            )
        finally:
            self.model.set_weights(training_weights)

        print(f"Test accuracy at epoch {epoch + 1}: {round(accuracy * 100, 2)}%")
        print(f"Test top 5 accuracy at epoch {epoch + 1}: {round(top_5_accuracy * 100, 2)}%")
        print(f"Precision at epoch {epoch + 1}: {round(precision * 100, 2)}%")
        print(f"Recall at epoch {epoch + 1}: {round(recall * 100, 2)}%")
        # Scale F1 to percent like the other metrics (it was printed as a raw fraction with a % sign).
        print(f"F1-score at epoch {epoch + 1}: {np.round(np.asarray(f1) * 100, 2)}%")

def run_model(model, filepath, batch_size=CFG.batch_size, num_epochs=CFG.num_epochs):
    """Fit ``model`` on the CIFAR-100 training arrays and return the Keras history.

    Training data is read from the notebook globals ``x_train``, ``context_train``
    and ``y_train_sparse``; 10% of it is held out for validation.

    Args:
        model: compiled model accepting ``[image, context]`` inputs.
        filepath: where the weights with the best ``val_accuracy`` are saved.
        batch_size: mini-batch size (default read from ``CFG`` when this cell runs).
        num_epochs: number of training epochs (default read from ``CFG`` when this cell runs).

    Returns:
        The object returned by ``model.fit``.
    """
    # Keeps only the weights of the epoch with the best validation accuracy.
    best_weights_saver = ModelCheckpoint(
        filepath,
        monitor="val_accuracy",
        save_best_only=True,
        save_weights_only=True,
    )
    periodic_test_eval = CustomValidationCallback()

    return model.fit(
        x=[x_train, context_train],
        y=y_train_sparse,
        batch_size=batch_size,
        epochs=num_epochs,
        validation_split=0.1,
        callbacks=[best_weights_saver, periodic_test_eval],
    )


# Build the model and train it on the first GPU; `history` is used by the cells below
with tf.device('/device:GPU:0'):
    model = build_cnn()
    history = run_model(model=model, filepath="model.ckpt")

In [None]:
# Save each metric curve to its own file: (history key, output file), in write order.
# NOTE(review): the CIFAR-10 and Tiny ImageNet sections write the same file names,
# so running the whole notebook overwrites these files.
for history_key, file_name in (
    ('accuracy', 'accuracy'),
    ('val_accuracy', 'val_accuracy'),
    ('loss', 'loss'),
    ('val_loss', 'val_loss'),
    ('precision', 'precision'),
    ('val_precision', 'val_precision'),
    ('recall', 'recall'),
    ('val_recall', 'val_recall'),
    ('val_f1-score', 'val_f1'),
    ('f1-score', 'f1'),
):
    write_list(history.history[history_key], file_name)

In [None]:
# See model generalization (training loss vs. validation loss)
loss = history.history['loss']
validation_loss = history.history['val_loss']
# `epochs` was never defined in this notebook (NameError); number the epochs from 1.
epochs = range(1, len(loss) + 1)
plt.figure(figsize=(10, 6))
plt.plot(epochs, loss, 'bo', label='Training Loss')
plt.plot(epochs, validation_loss, 'r', label='Validation Loss')
plt.title('Training and Validation Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.grid(True)
plt.show()

# **Tiny ImageNet**

In [None]:
# Clone the repository that contains the tiny-imagenet-200 folder, then list that folder
! git clone https://github.com/seshuad/IMagenet
! ls 'IMagenet/tiny-imagenet-200/'

In [None]:
# Hyper-parameters of the Tiny ImageNet experiment
class CFG:
    """Constants read by the Tiny ImageNet cells that follow."""
    # Data
    image_size = 64        # side length of the square input images
    num_classes = 200      # width of the one-hot class labels
    num_contexts = 3       # width of the context vectors
    # Optimisation
    batch_size = 256
    num_epochs = 100
    learning_rate = 0.001
    weight_decay = 1e-4
    # Not referenced by any cell visible in this notebook
    projection_dims = 256
    epochs = 100

In [None]:
# Load dataset and split into train and test
# Root folder of the cloned dataset; every file path below is built by appending to it.
path = 'IMagenet/tiny-imagenet-200/'

def get_id_dictionary():
    """Map every id listed in ``wnids.txt`` to its 0-based line index.

    Returns:
        dict whose keys are the lines of ``path + 'wnids.txt'`` (newline removed)
        and whose values are the corresponding line numbers, starting at 0.
    """
    # The context manager closes the file; the original left the handle open.
    with open(path + 'wnids.txt', 'r') as wnids_file:
        return {line.replace('\n', ''): index for index, line in enumerate(wnids_file)}

def get_class_to_id_dict():
    """Map every class index to ``(id, description)``.

    Ids and indices come from ``get_id_dictionary()``; descriptions are the second
    tab-separated field of the matching line in ``path + 'words.txt'``.

    Returns:
        dict ``{class_index: (id, description)}``.

    Raises:
        KeyError: if an id from ``wnids.txt`` has no entry in ``words.txt``.
    """
    id_dict = get_id_dictionary()
    all_classes = {}
    # The context manager closes the file; the original left the handle open.
    with open(path + 'words.txt', 'r') as words_file:
        for line in words_file:
            # Drop the line break first so the description does not end with '\n'.
            n_id, word = line.rstrip('\n').split('\t')[:2]
            all_classes[n_id] = word
    return {value: (key, all_classes[key]) for key, value in id_dict.items()}

def get_data(id_dict):
    """Load the Tiny ImageNet train and validation images with one-hot labels.

    For every id in ``id_dict`` the 500 files ``train/<id>/images/<id>_<i>.JPEG``
    (``i`` in 0..499) are read; validation images and their class ids are taken from
    ``val/val_annotations.txt``. Labels are one-hot rows of width 200.

    Args:
        id_dict: mapping from class id to class index, e.g. from ``get_id_dictionary()``.

    Returns:
        ``(train_data, train_labels, test_data, test_labels)`` as NumPy arrays.

    Raises:
        FileNotFoundError: if an image file cannot be read.
    """
    # `time` is used below but is not imported at the top of the notebook.
    import time

    def read_rgb(image_path):
        # The original called `nd.imread(..., pilmode='RGB')`, but `nd` is never defined
        # in this notebook. cv2 (already imported) is used instead: it decodes to BGR and
        # returns None rather than raising when a file is unreadable.
        # NOTE(review): decoded pixel values may differ marginally from the reader `nd` stood for.
        image = cv2.imread(image_path, cv2.IMREAD_COLOR)
        if image is None:
            raise FileNotFoundError('could not read image: {}'.format(image_path))
        return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    num_classes = 200
    images_per_class = 500

    print('starting loading data')
    train_data, test_data = [], []
    train_labels, test_labels = [], []
    t = time.time()

    for key, value in id_dict.items():
        train_data += [
            read_rgb(path + 'train/{}/images/{}_{}.JPEG'.format(key, key, str(i)))
            for i in range(images_per_class)
        ]
        # One block of identical one-hot rows for the whole class.
        class_rows = np.zeros((images_per_class, num_classes), dtype=int)
        class_rows[:, value] = 1
        train_labels.extend(class_rows)

    # The context manager closes the annotation file; the original left it open.
    with open(path + 'val/val_annotations.txt') as annotations:
        for line in annotations:
            img_name, class_id = line.split('\t')[:2]
            test_data.append(read_rgb(path + 'val/images/{}'.format(img_name)))
            label_row = np.zeros(num_classes, dtype=int)
            label_row[id_dict[class_id]] = 1
            test_labels.append(label_row)

    print('finished loading data, in {} seconds'.format(time.time() - t))
    return np.array(train_data), np.array(train_labels), np.array(test_data), np.array(test_labels)

# Load all images and labels into memory, then report the array shapes
train_data, train_labels, test_data, test_labels = get_data(get_id_dictionary())
for caption, loaded_array in (
    ("train data shape: ", train_data),
    ("train label shape: ", train_labels),
    ("test data shape: ", test_data),
    ("test_labels.shape: ", test_labels),
):
    print(caption, loaded_array.shape)

In [None]:
# Create one-hot context vectors from the GMM predictions
# (assumes each file holds a flat sequence of integer component indices in
# [0, CFG.num_contexts), in the same order as train_data / test_data — TODO confirm)
component_train = read_list("gmm_imagenet_train")
component_test = read_list("gmm_imagenet_test")

# The original cell allocated the matrices but never set the 1s, so every context
# vector was all zeros. Row i now gets a 1 in column component[i], as in the CIFAR cells.
context_train = np.zeros((len(component_train), CFG.num_contexts), dtype=int)
context_train[np.arange(len(component_train)), component_train] = 1

context_test = np.zeros((len(component_test), CFG.num_contexts), dtype=int)
context_test[np.arange(len(component_test)), component_test] = 1

In [None]:
# Create our data generator
class DataGenerator(keras.utils.Sequence):
    """Serve batches of normalised images, context vectors and targets as NumPy arrays.

    Args:
        batch_size: number of samples per batch (the last batch may be smaller).
        img_size: ``(height, width)`` tuple of the images.
        input_img: indexable collection of 3-channel images.
        target_img: indexable collection of target rows of length ``number_classes``.
        context_img: indexable collection of context rows of length ``num_contexts``.
        mean: value subtracted from the images after division by 255.
        std: value the centred images are divided by.
        number_classes: width of a target row.
        num_contexts: width of a context row.
    """

    def __init__(self, batch_size, img_size, input_img, target_img, context_img, mean, std, number_classes, num_contexts):
        self.batch_size = batch_size
        self.img_size = img_size
        self.input_img = input_img
        self.target_img = target_img
        self.context_img = context_img
        self.mean = mean
        self.number_classes = number_classes
        self.std = std
        self.num_contexts = num_contexts

    def __len__(self):
        """Number of batches, counting a final partial batch.

        Rounding up keeps the trailing samples: with floor division they were
        silently skipped during training and evaluation.
        """
        return (len(self.target_img) + self.batch_size - 1) // self.batch_size

    def __getitem__(self, idx):
        """Return ``((images, contexts), targets)`` for batch number ``idx``."""
        start = idx * self.batch_size
        stop = start + self.batch_size
        batch_input_img = self.input_img[start:stop]
        batch_target_img = self.target_img[start:stop]
        batch_context_img = self.context_img[start:stop]
        # Size the arrays by the real batch length so a short last batch is not zero-padded.
        n_samples = len(batch_input_img)

        # Images: cast to float32, divide by 255, then standardise with mean/std.
        x = np.zeros((n_samples,) + self.img_size + (3,), dtype="float32")
        for j, image in enumerate(batch_input_img):
            x[j] = image
        x = x / 255.0
        x = (x - self.mean) / self.std

        # Labels
        y = np.zeros((n_samples, self.number_classes))
        for j, target in enumerate(batch_target_img):
            y[j] = target

        # Contexts
        contexts = np.zeros((n_samples, self.num_contexts), dtype="int32")
        for j, context in enumerate(batch_context_img):
            contexts[j] = context

        return (x, contexts), y

# Shuffle each split with one permutation so images, labels and contexts stay aligned.
# NOTE(review): `context_train` / `context_test` are overwritten in place, so running this
# cell twice without re-running the context cell misaligns contexts and images.
idx = np.random.permutation(len(train_data))
x_train, y_train = train_data[idx], train_labels[idx]
context_train = context_train[idx]
idx = np.random.permutation(len(test_data))
x_test, y_test = test_data[idx], test_labels[idx]
context_test = context_test[idx]

# Per-channel constants used by DataGenerator after dividing the images by 255
mean = (0.485, 0.456, 0.406)
std = (0.229, 0.224, 0.225)

# The first `num_train` shuffled samples train the model; the rest is the validation split.
# (A stray no-op expression statement that stood here was removed.)
num_train = 92000
x_val = x_train[num_train:]
y_val = y_train[num_train:]
context_val = context_train[num_train:]

train_generator = DataGenerator(
    CFG.batch_size, (CFG.image_size, CFG.image_size),
    x_train[:num_train], y_train[:num_train], context_train[:num_train],
    mean, std, CFG.num_classes, CFG.num_contexts,
)
test_generator = DataGenerator(
    CFG.batch_size, (CFG.image_size, CFG.image_size),
    x_test, y_test, context_test,
    mean, std, CFG.num_classes, CFG.num_contexts,
)
val_generator = DataGenerator(
    CFG.batch_size, (CFG.image_size, CFG.image_size),
    x_val, y_val, context_val,
    mean, std, CFG.num_classes, CFG.num_contexts,
)

In [None]:
# Build our model
def build_cnn(num_classes=CFG.num_classes, num_contexts=CFG.num_contexts, learning_rate=CFG.learning_rate, weight_decay=CFG.weight_decay):
    """Build and compile the Tiny ImageNet CNN with ContextNormalization after the third convolution.

    The default argument values are read from ``CFG`` once, when this cell is executed.

    Args:
        num_classes: number of units of the final Dense layer.
        num_contexts: length of the per-sample context vector fed to the second input.
        learning_rate: learning rate passed to AdamW.
        weight_decay: weight decay passed to AdamW.

    Returns:
        A compiled model taking ``[image, context]`` and returning ``num_classes``
        logits (the Dense head has no activation; the loss uses ``from_logits=True``).
    """
    # Two inputs: a 64x64 RGB image and a context vector.
    # NOTE(review): unlike the CIFAR models, the context input does not set dtype='int32' — confirm intended.
    input_image = tf.keras.layers.Input(shape=(64,64,3))
    context_id = tf.keras.layers.Input(shape=(num_contexts,))
    # The shared augmentation pipeline also resizes the image to 72x72.
    augmented = data_augmentation(input_image)

    # Block 1: conv -> batch norm -> ReLU
    conv1 = tf.keras.layers.Conv2D(filters=64, kernel_size=(5,5), strides=(1,1),  padding="same")(augmented)
    conv1 = tf.keras.layers.BatchNormalization()(conv1)
    conv1 = tf.keras.layers.ReLU()(conv1)

    pool1 = tf.keras.layers.MaxPool2D(pool_size=(3,3), strides=(2,2), padding="same")(conv1)

    # Block 2: conv -> batch norm -> ReLU
    conv2 = tf.keras.layers.Conv2D(filters=128, kernel_size=(5,5), strides=(1,1),  padding="same")(pool1)
    conv2 = tf.keras.layers.BatchNormalization()(conv2)
    conv2 = tf.keras.layers.ReLU()(conv2)

    pool2 = tf.keras.layers.MaxPool2D(pool_size=(3,3), strides=(2,2), padding="same")(conv2)

    # Block 3: normalise the conv3 feature map with ContextNormalization.
    # The original passed `input_image` here, which cut conv1-conv3 (and the
    # augmentation) out of the graph; the CIFAR models pass `conv3`.
    conv3 = tf.keras.layers.Conv2D(filters=128, kernel_size=(5,5), strides=(1,1),  padding="same")(pool2)
    conv3 = ContextNormalization()([conv3, context_id])
    conv3 = tf.keras.layers.ReLU()(conv3)

    pool3 = tf.keras.layers.MaxPool2D(pool_size=(3,3), strides=(2, 2), padding="same")(conv3)

    # Block 4: conv -> batch norm -> ReLU
    conv4 = tf.keras.layers.Conv2D(filters=256, kernel_size=(5, 5), strides=(1, 1),  padding="same")(pool3)
    conv4 = tf.keras.layers.BatchNormalization()(conv4)
    conv4 = tf.keras.layers.ReLU()(conv4)

    # Classification head: average pooling, flatten, linear layer producing logits.
    pool4 = tf.keras.layers.AveragePooling2D(pool_size=(4, 4), strides=(1, 1))(conv4)
    flattened = tf.keras.layers.Flatten()(pool4)
    outputs = tf.keras.layers.Dense(num_classes, name="logits")(flattened)

    model = tf.keras.models.Model([input_image, context_id], outputs)

    optimizer = tfa.optimizers.AdamW(
        learning_rate=learning_rate, weight_decay=weight_decay
    )

    # The metric order below is the order in which `evaluate` returns values after the loss.
    # NOTE(review): Precision and Recall are built with default arguments while the model
    # outputs logits — confirm their default thresholding is meaningful on un-normalised outputs.
    model.compile(
        optimizer=optimizer,
        loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),
        metrics=[
            tf.keras.metrics.CategoricalAccuracy(name="accuracy"),
            tf.keras.metrics.TopKCategoricalAccuracy(5, name="top-5-accuracy"),
            Precision(name="precision"),
            Recall(name="recall"),
            tfa.metrics.F1Score(num_classes=num_classes, name="f1-score")

        ],
    )
    return model

In [None]:
# Train our model with context normalization; evaluate the best checkpoint every 25 epochs
class CustomValidationCallback(tf.keras.callbacks.Callback):
    """Periodically evaluate the best checkpoint on the Tiny ImageNet test generator.

    Every ``eval_every`` epochs the weights stored at ``checkpoint_path`` are loaded,
    the model is evaluated on ``test_generator`` and the metrics are printed. The
    in-training weights are restored afterwards, so the evaluation no longer rewinds
    training to the checkpointed epoch.

    Args:
        checkpoint_path: weights file to evaluate; defaults to the path used by this notebook.
        eval_every: evaluation period in epochs.
    """

    def __init__(self, checkpoint_path="model.ckpt", eval_every=25):
        super().__init__()
        self.checkpoint_path = checkpoint_path
        self.eval_every = eval_every

    def on_epoch_end(self, epoch, logs=None):
        if (epoch + 1) % self.eval_every != 0:
            return

        # Snapshot the current weights: load_weights overwrites them in place.
        training_weights = self.model.get_weights()
        try:
            self.model.load_weights(self.checkpoint_path)
            # Values come back as the loss followed by the metrics in compile order.
            _, accuracy, top_5_accuracy, precision, recall, f1 = self.model.evaluate(test_generator)
        finally:
            self.model.set_weights(training_weights)

        print(f"Test accuracy at epoch {epoch + 1}: {round(accuracy * 100, 2)}%")
        print(f"Test top 5 accuracy at epoch {epoch + 1}: {round(top_5_accuracy * 100, 2)}%")
        print(f"Precision at epoch {epoch + 1}: {round(precision * 100, 2)}%")
        print(f"Recall at epoch {epoch + 1}: {round(recall * 100, 2)}%")
        # Scale F1 to percent like the other metrics (it was printed as a raw fraction with a % sign).
        print(f"F1-score at epoch {epoch + 1}: {np.round(np.asarray(f1) * 100, 2)}%")

def run_model(model, filepath, batch_size=CFG.batch_size, num_epochs=CFG.num_epochs):
    """Fit ``model`` on ``train_generator`` and return the Keras history.

    Batches come from the notebook globals ``train_generator`` (training) and
    ``val_generator`` (validation).

    Args:
        model: compiled model accepting ``[image, context]`` inputs.
        filepath: where the weights with the best ``val_accuracy`` are saved.
        batch_size: forwarded to ``model.fit`` (default read from ``CFG`` when this cell runs).
        num_epochs: number of training epochs (default read from ``CFG`` when this cell runs).

    Returns:
        The object returned by ``model.fit``.
    """
    # Keeps only the weights of the epoch with the best validation accuracy.
    best_weights_saver = tf.keras.callbacks.ModelCheckpoint(
        filepath,
        monitor="val_accuracy",
        save_best_only=True,
        save_weights_only=True,
    )
    periodic_test_eval = CustomValidationCallback()

    # NOTE(review): the generator already produces batches; confirm that also passing
    # `batch_size` to fit is accepted by the installed Keras version.
    return model.fit(
        train_generator,
        batch_size=batch_size,
        epochs=num_epochs,
        validation_data=val_generator,
        callbacks=[best_weights_saver, periodic_test_eval],
    )


# Build the model and train it on the first GPU; `history` is used by the cells below
with tf.device('/device:GPU:0'):
    model = build_cnn()
    history = run_model(model=model, filepath="model.ckpt")

In [None]:
# Save each metric curve to its own file: (history key, output file), in write order.
# NOTE(review): the CIFAR-10 and CIFAR-100 sections write the same file names,
# so running the whole notebook overwrites their files here.
for history_key, file_name in (
    ('accuracy', 'accuracy'),
    ('val_accuracy', 'val_accuracy'),
    ('loss', 'loss'),
    ('val_loss', 'val_loss'),
    ('precision', 'precision'),
    ('val_precision', 'val_precision'),
    ('recall', 'recall'),
    ('val_recall', 'val_recall'),
    ('val_f1-score', 'val_f1'),
    ('f1-score', 'f1'),
):
    write_list(history.history[history_key], file_name)

In [None]:
# See model generalization (training loss vs. validation loss)
loss = history.history['loss']
validation_loss = history.history['val_loss']
# `epochs` was never defined in this notebook (NameError); number the epochs from 1.
epochs = range(1, len(loss) + 1)
plt.figure(figsize=(10, 6))
plt.plot(epochs, loss, 'bo', label='Training Loss')
plt.plot(epochs, validation_loss, 'r', label='Validation Loss')
plt.title('Training and Validation Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.grid(True)
plt.show()