# Training a ShallowNet model

__Objective:__ train a ShallowNet CNN to classify image data.

We experiment with:
- Loading data into tensors.
- Loading data into TensorFlow `Dataset`s.
- Defining the model via Keras' `Sequential` API.
- Defining the model as a custom Keras `Layer` and wrapping it in a `Model` object via the functional API.

## Load data

In [None]:
# Root directories of the training and test splits; each is expected to hold
# one sub-directory per numeric class label.
training_data_dir, test_data_dir = '../data/Dataset/', '../data/Dataset_test/'

In [None]:
import os
import tensorflow as tf
from tensorflow.keras.preprocessing.image import load_img, img_to_array

In [None]:
def load_images_from_dir(dir_path, color_mode='grayscale', target_size=(32, 32)):
    """
    Given the path to a directory containing images, loads the images
    in a tensor using the provided `color_mode` and `target_size` options.

    Files are read in sorted filename order, so the resulting tensor has a
    deterministic ordering (`os.listdir` gives no ordering guarantee).
    Every entry of `dir_path` is passed to `load_img`, so the directory is
    expected to contain image files only.

    Returns a tensor with the images stacked along a new leading axis 0.

    Raises `ValueError` if `dir_path` has no entries.
    """
    image_names = sorted(os.listdir(dir_path))

    if not image_names:
        raise ValueError(f'No images found in directory: {dir_path}')

    image_arrays = [
        img_to_array(load_img(
            os.path.join(dir_path, image_name),
            color_mode=color_mode,
            target_size=target_size,
            # Avoid aspect-ratio distortion whenever a resize is requested.
            keep_aspect_ratio=target_size is not None
        ))
        for image_name in image_names
    ]

    # Stacking adds the leading image axis (equivalent to concatenating
    # `[tf.newaxis]`-expanded arrays along axis 0).
    return tf.stack(image_arrays, axis=0)


def load_image_dataset(dataset_dir, color_mode='grayscale', target_size=(32, 32), shuffle=True,
                       pixel_scale=255.):
    """
    Loads a dataset of images from the given directory. The data directory is
    assumed to be structured in sub-directories, each named as a numeric class
    label and containing all the images belonging to the corresponding class.
    Entries whose name is not a decimal number are ignored.

    Pixel values are divided by `pixel_scale`. The default (255) is a fixed
    constant, so every split (training, test, ...) gets the same scaling, and
    it matches the `/ 255.` normalization used for the `Dataset` pipeline.
    Pass `pixel_scale=None` to divide by the maximum pixel value found in this
    particular dataset instead (data-dependent, so it may differ between splits).

    Returns a tuple `(x, y)`: `x` holds the images stacked along axis 0 and
    `y` the corresponding class labels (as floats), in matching order. When
    `shuffle` is true, `x` and `y` are permuted with the same random permutation.

    Raises `ValueError` if no numeric class sub-directory is found.
    """
    # Keep the on-disk directory name next to its integer label, so that
    # names such as '01' are still found when loading the images.
    class_dirs = sorted(
        (int(name), name)
        for name in os.listdir(dataset_dir)
        if name.isdecimal()
    )

    if not class_dirs:
        raise ValueError(f'No numeric class sub-directories found in: {dataset_dir}')

    x = []
    y = []

    for label, dir_name in class_dirs:
        class_images = load_images_from_dir(
            os.path.join(dataset_dir, dir_name),
            color_mode=color_mode,
            target_size=target_size
        )

        x.append(class_images)
        y.append(label * tf.ones(shape=class_images.shape[0]))

    x = tf.concat(x, axis=0)
    y = tf.concat(y, axis=0)

    # Normalize pixel values.
    if pixel_scale is None:
        pixel_scale = tf.reduce_max(x)

    x /= pixel_scale

    # Shuffle data if required, applying the same permutation to images and labels.
    if shuffle:
        shuffled_indices = tf.random.shuffle(tf.range(x.shape[0]))

        x = tf.gather(x, shuffled_indices, axis=0)
        y = tf.gather(y, shuffled_indices, axis=0)

    return x, y

Load training and test data.

In [None]:
x_train, y_train = load_image_dataset(training_data_dir)
x_test, y_test = load_image_dataset(test_data_dir)

Plot some randomly chosen training images together with their class labels.

In [None]:
# `plt` is otherwise only imported in a later cell; import it here so the
# notebook runs top to bottom.
import matplotlib.pyplot as plt

n_random_images = 3

# Indices are drawn independently, so the same image may appear twice.
random_indices = tf.random.uniform(
    shape=(n_random_images,), minval=0, maxval=x_train.shape[0], dtype=tf.int32
)

random_images = tf.gather(x_train, random_indices)
random_images_classes = tf.gather(y_train, random_indices)

fig, axs = plt.subplots(nrows=1, ncols=n_random_images, figsize=(14, 6))

for ax, image, image_class in zip(axs, random_images, random_images_classes):
    ax.imshow(image.numpy(), cmap='gray')
    ax.set_title(f'Class: {int(image_class)}')

### Alternative data loading: TensorFlow `Dataset`s

In [None]:
from tensorflow.keras.utils import image_dataset_from_directory

In [None]:
def _load_normalized_dataset(data_dir, shuffle):
    """
    Loads a batched dataset of grayscale 32x32 images from `data_dir` with
    `image_dataset_from_directory` (labels inferred from the sub-directory
    structure, as integers) and divides pixel values by 255.
    """
    dataset = image_dataset_from_directory(
        data_dir,
        labels="inferred",
        label_mode="int",
        color_mode="grayscale",
        batch_size=32,
        image_size=(32, 32),
        shuffle=shuffle,
        crop_to_aspect_ratio=True
    )

    # Normalize pixel values.
    return dataset.map(lambda x, y: (x / 255., y))


training_dataset = _load_normalized_dataset(training_data_dir, shuffle=True)

# The test split is only used for evaluation: keep a fixed order so that,
# e.g., predictions from `model.predict(test_dataset)` line up with labels
# read from a separate pass over the same dataset (a shuffled dataset
# would be reshuffled between the two passes).
test_dataset = _load_normalized_dataset(test_data_dir, shuffle=False)

## Define and train a model

### Model definition via the Sequential API

In [None]:
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, Activation, Flatten, Dense
import matplotlib.pyplot as plt
import seaborn as sns

# Apply seaborn's default theme (spelled out explicitly) to all following plots.
sns.set_theme(context='notebook', style='darkgrid', palette='deep')

In [None]:
def _plot_history_series(history, key, ax, title):
    """Draws the training curve for `key` (and `val_{key}`, if recorded) on `ax`."""
    sns.lineplot(
        x=range(len(history[key])),
        y=history[key],
        ax=ax,
        label='Training'
    )

    val_key = f'val_{key}'

    if val_key in history:
        sns.lineplot(
            x=range(len(history[val_key])),
            y=history[val_key],
            ax=ax,
            label='Validation'
        )

    ax.set_title(title)


def plot_training_history(training_history, metrics=None):
    """
    Plots loss and additional metrics (if any) across training
    epochs. Additional metrics must be provided as a list of
    strings.

    `training_history` is the object returned by `Model.fit`; its
    `.history` dict is read for 'loss', each requested metric and, when
    present, their 'val_'-prefixed counterparts. One panel is drawn per
    quantity (loss first), stacked vertically with a shared x axis.
    """
    metrics = list(metrics) if metrics is not None else []
    history = training_history.history

    # `squeeze=False` always yields a 2D array of axes, so indexing works
    # even when only the loss panel is drawn (a single row).
    _, axs = plt.subplots(
        nrows=len(metrics) + 1, ncols=1, figsize=(14, 10), sharex=True, squeeze=False
    )
    axs = axs[:, 0]

    _plot_history_series(history, 'loss', axs[0], 'Loss')

    for ax, metric in zip(axs[1:], metrics):
        _plot_history_series(history, metric, ax, metric)

    # The x axis is shared: label it once, on the bottom panel.
    axs[-1].set_xlabel('Epoch')

In [None]:
class ShallowNet:
    @staticmethod
    def build(width, height, depth, n_classes):
        """
        Builds an uncompiled ShallowNet classifier as a Keras `Sequential` model.

        Architecture: 3x3 convolution with 32 filters and 'same' padding,
        ReLU, flatten, dense layer with `n_classes` units, softmax.

        Parameters:
            width, height, depth: input image dimensions; the model's input
                shape is (height, width, depth), i.e. channels last.
            n_classes: number of output classes.

        Returns:
            The `Sequential` model (not compiled).
        """
        return Sequential([
            Conv2D(
                filters=32,
                kernel_size=(3, 3),
                padding='same',
                input_shape=(height, width, depth)
            ),
            Activation('relu'),
            Flatten(),
            Dense(units=n_classes),
            Activation('softmax'),
        ])

In [None]:
# `build` is a static method, so no ShallowNet instance is needed.
sn_model = ShallowNet.build(width=32, height=32, depth=1, n_classes=10)

sn_model

In [None]:
# Smoke test: forward pass on a batch containing only the first training image.
sn_model(x_train[0:1])

In [None]:
# Plain SGD with a fixed learning rate.
optimizer = tf.keras.optimizers.SGD(learning_rate=5e-3)

sn_model.compile(
    optimizer=optimizer,
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
)

# NOTE: the test split doubles as validation data, so the "validation"
# curves are test-set curves monitored during training.
training_history = sn_model.fit(
    x=x_train,
    y=y_train,
    batch_size=32,
    epochs=100,
    validation_data=(x_test, y_test),
)

In [None]:
plot_training_history(training_history=training_history, metrics=['accuracy'])

### Model definition via the functional API

Build a model starting from Keras `Layer` and `Model` objects and using the functional API.

In [None]:
from tensorflow.keras.layers import Layer
from tensorflow.keras import Input, Model

In [None]:
class ShallowNetLayer(Layer):
    """
    The ShallowNet architecture packaged as a single Keras `Layer`:
    3x3 convolution with 32 filters and 'same' padding -> ReLU -> flatten
    -> dense layer with `n_classes` units -> softmax.

    Extra keyword arguments (e.g. `name`, `dtype`) are forwarded to `Layer`.
    """

    def __init__(self, n_classes, **kwargs):
        super().__init__(**kwargs)

        # Stored so that `get_config` can report it.
        self.n_classes = n_classes

        self.conv = Conv2D(
            filters=32,
            kernel_size=(3, 3),
            padding='same'
        )
        self.relu = Activation('relu')
        self.flatten = Flatten()
        self.dense = Dense(units=n_classes)
        self.softmax = Activation('softmax')

    def call(self, x):
        x = self.conv(x)
        x = self.relu(x)
        x = self.flatten(x)
        x = self.dense(x)
        output = self.softmax(x)

        return output

    def get_config(self):
        """
        Returns the layer configuration, including `n_classes`, so the layer
        can be re-created from it (required to save and reload models that
        contain it, since `__init__` takes an extra argument).
        """
        config = super().get_config()
        config.update({'n_classes': self.n_classes})

        return config

In [None]:
sn_layer = ShallowNetLayer(n_classes=10)

# Wrap the custom layer in a functional `Model`: symbolic input -> layer -> output.
inputs = Input(shape=(32, 32, 1))
outputs = sn_layer(inputs)

sn_model_from_layer = Model(inputs=inputs, outputs=outputs)

# A separate optimizer instance for this second model.
optimizer_2 = tf.keras.optimizers.SGD(learning_rate=5e-3)

sn_model_from_layer.compile(
    optimizer=optimizer_2,
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
)

# Train from the tf.data pipeline; the Dataset already yields batches, so no
# `batch_size` is passed to `fit`.
# To train on the in-memory tensors instead, call
# `sn_model_from_layer.fit(x_train, y_train, batch_size=32, epochs=100,
# validation_data=(x_test, y_test))`.
training_history_2 = sn_model_from_layer.fit(
    training_dataset,
    epochs=100,
    validation_data=test_dataset,
)

In [None]:
plot_training_history(training_history=training_history_2, metrics=['accuracy'])

## Accessing a model's intermediate output

Extract the intermediate output (from hidden layers) of a trained model by defining another model built from the trained model's layers.

__Note:__ this is easy with the Sequential API when the model is sequential and we want its first N layers; more complex architectures require the functional API.

Collect, in order, the names of the trained model's feature-extraction layers (the convolution and its activation).

In [None]:
# Print the layers of the trained Sequential model, in order.
sn_model.summary(line_length=None)

In [None]:
# `sn_model.layers` lists the model's layers in order and excludes the
# implicit input layer (which `get_config()['layers']` includes as its first
# entry and therefore had to be skipped). The first two layers are the
# convolution and its ReLU activation, i.e. the feature-extraction part.
features_extraction_layer_names = [layer.name for layer in sn_model.layers[:2]]

features_extraction_layer_names

Define a sequential model from the original model's layers, retrieved by name.

In [None]:
# `get_layer` returns the layer objects of `sn_model` themselves, so the new
# model reuses their trained weights rather than copying them.
feature_extraction_model = Sequential(
    layers=[sn_model.get_layer(name) for name in features_extraction_layer_names]
)

In [None]:
# Run the whole test set through the feature-extraction model in one call.
intermediate_output = feature_extraction_model(x_test, training=False)

Check the intermediate output from the convolutional layer against the original image.

In [None]:
fig, axs = plt.subplots(ncols=3, nrows=3, figsize=(14, 14))

# Index of the test image to inspect.
image_index = 0
image_features = intermediate_output[image_index]

# First row: the original image (first column only).
axs[0, 0].imshow(x_test[image_index, ...].numpy(), cmap='gray')
axs[0, 0].set_title('Original image')

# Second row: the first three channels of the intermediate output.
for channel, ax in enumerate(axs[1]):
    ax.imshow(image_features[..., channel].numpy(), cmap='gray')
    ax.set_title(f'Intermediate output: channel {channel}')

# Third row: the intermediate output averaged over all channels (first column only).
axs[2, 0].imshow(tf.reduce_mean(image_features, axis=-1).numpy(), cmap='gray')
axs[2, 0].set_title('Intermediate output: average over all channels')

# Hide the axes left empty in the first and third rows.
for ax in (*axs[0, 1:], *axs[2, 1:]):
    ax.axis('off')