# Imports

In [None]:
import sys
sys.path.append('../')
import tensorflow as tf
import tensorflow_datasets as tfds
from helper import visualize_classification_image_samples, visualize_classification_predictions, plot_confusion_matrix
from helper import fast_benchmark, set_model_config
from helper import plot_loss
from tensorflow.keras.callbacks import EarlyStopping
import matplotlib.pyplot as plt

# Global variables
model_config = set_model_config('cifar_10')

# Seeded generator (seed 123, Philox algorithm) that hands out the seeds
# consumed by the randomized TF ops of the augmentation step.
rng = tf.random.Generator.from_seed(seed=123, alg='philox')

# Load dataset and show information

In [None]:
# Fetch the CIFAR-10 train and test splits as (image, label) pairs, together
# with the dataset metadata object.
splits, ds_info = tfds.load(
    'cifar10',
    split=['train', 'test'],
    shuffle_files=True,
    as_supervised=True,
    with_info=True,
)
ds_train, ds_test = splits

# Assemble the dataset summary first, then print it line by line
dataset_summary = [
    "CIFAR-10 dataset information:",
    f"Number of classes: {ds_info.features['label'].num_classes}",
    f"Class names: {ds_info.features['label'].names}",
    f"Number of training examples: {ds_info.splits['train'].num_examples}",
    f"Dataset splits: {list(ds_info.splits.keys())}",
    f"Dataset description: {ds_info.description}",
]
for summary_line in dataset_summary:
    print(summary_line)

In [None]:
# Bare expression: as the last statement of a notebook cell it displays the
# dataset metadata object returned by tfds.load(..., with_info=True).
ds_info

# Iterate the dataset and visualize some samples

In [None]:
# Peek at the first three (image, label) pairs of the training split.
# Looping over the `take(3)` dataset directly ends cleanly when fewer than
# three samples exist, whereas a manual next() would raise StopIteration.
for i, (image, label) in enumerate(ds_train.take(3)):
    print(f'Sample {i} tensor shape: {image.shape}')
    print(f'Sample {i} class label: {label}')

In [None]:
# Draw a grid of training samples with matplotlib's dark theme. The helper
# receives 10 and a (2, 5) grid shape, presumably 10 images in 2 rows of 5.
grid_shape = (2, 5)
with plt.style.context('dark_background'):
    visualize_classification_image_samples(ds_train, 10, ds_info, g_shape=grid_shape)

# Apply pre-processing and augmentations before training

In [None]:
# Pre-processing
def normalize_image(image: tf.Tensor, label: tf.Tensor) -> (tf.Tensor, tf.Tensor):
    '''Rescale image pixel values and pass the label through.

    The image is cast to float32 and divided by 255, which maps uint8
    pixel values in [0, 255] onto floats in [0, 1]. The label is returned
    exactly as received.
    '''
    as_float = tf.cast(image, tf.float32)
    rescaled = as_float / 255.
    return rescaled, label

# Augmentations
def augment_image(image_label: tuple, seed) -> (tf.Tensor, tf.Tensor):
    '''Apply random augmentations to a single training sample.

    Two augmentations are applied, each driven by its own seed derived from
    ``seed`` so that the flip decision and the crop offset are independent:

    1. Random horizontal flip.
    2. Random translation: the image is zero-padded by 4 pixels on every
       side (32x32 -> 40x40) and a random 32x32 window is cropped back out.
       Cropping without padding would always return the full image, because
       the crop size equals the image size.

    No rotation is applied. ``tf.image.rot90`` only rotates by multiples of
    90 degrees, so a small-angle rotation (e.g. +/-45 degrees) cannot be
    expressed with it; use a dedicated rotation op/layer if that is wanted.

    Args:
        image_label: ``(image, label)`` pair. The image is expected to be one
            32x32x3 sample, since the crop size is fixed to that shape.
        seed: shape ``[2]`` integer seed tensor for the stateless random ops.

    Returns:
        ``(augmented_image, label)``; the label is returned unchanged.
    '''
    pad = 4                 # zero padding added on each side before cropping
    height, width, channels = 32, 32, 3

    image, label = image_label

    # One independent seed per random op
    seeds = tf.random.split(seed, num=2)
    flip_seed = seeds[0, :]
    crop_seed = seeds[1, :]

    image = tf.image.stateless_random_flip_left_right(image, flip_seed)

    # Pad first so that the random crop has room to shift the content
    image = tf.image.resize_with_crop_or_pad(image, height + 2 * pad, width + 2 * pad)
    image = tf.image.stateless_random_crop(
        value=image, size=(height, width, channels), seed=crop_seed
    )

    return image, label


def random_wrapper(image: tf.Tensor, label: tf.Tensor) -> (tf.Tensor, tf.Tensor):
    '''Augment one sample using a seed freshly drawn from the global generator.

    Every call asks ``rng`` for new seeds, so each sample passed through
    ``augment_image`` gets its own random augmentation parameters.
    '''
    fresh_seed = rng.make_seeds(2)[0]
    return augment_image((image, label), fresh_seed)

# Generate the validation set from the training set
num_train_examples = ds_info.splits['train'].num_examples
validation_size = int(model_config['val_size'] * num_train_examples)
train_size = num_train_examples - validation_size

# The first `validation_size` samples become the validation set; the remaining
# samples stay in the training set.
# NOTE(review): take/skip only yields disjoint sets if the source dataset
# produces samples in the same order on every pass. The dataset is loaded with
# shuffle_files=True; confirm this holds for the splits used here.
ds_val = ds_train.take(validation_size)
ds_train = ds_train.skip(validation_size)

# Print the number of examples in the training and validation sets
print("Number of images in training set:", train_size)
print("Number of images in validation set:", validation_size)
print("Number of images in test set:", ds_info.splits['test'].num_examples)

# Training pipeline. Order of things:
#   1. Cache the raw samples in memory
#   2. Shuffle (buffer covers the whole training set) to get a new order
#      on every training epoch
#   3. Apply the augmentations per sample using parallel mapping
#   4. Batch, then normalize whole batches at once (vectorized mapping)
#   5. Prefetch so input preparation overlaps with training
ds_train = ds_train.cache()
ds_train = ds_train.shuffle(train_size, reshuffle_each_iteration=True)
ds_train = ds_train.map(random_wrapper,
                        num_parallel_calls=tf.data.AUTOTUNE)
ds_train = ds_train.batch(model_config['batch_size'])
ds_train = ds_train.map(normalize_image,
                        num_parallel_calls=tf.data.AUTOTUNE)
ds_train = ds_train.prefetch(tf.data.AUTOTUNE)

# Validation pipeline. No shuffling: the validation metrics do not depend on
# the sample order, so a shuffle buffer would only cost memory and time.
ds_val = ds_val.cache()
ds_val = ds_val.batch(model_config['batch_size'])
ds_val = ds_val.map(normalize_image,
                    num_parallel_calls=tf.data.AUTOTUNE)
ds_val = ds_val.prefetch(tf.data.AUTOTUNE)


# Test pipeline. Deliberately NOT shuffled: a shuffled test set is reordered
# on every iteration, so any consumer that iterates it more than once (e.g.
# once for predictions and once for labels) would pair predictions with the
# wrong labels. A fixed order keeps evaluation reproducible.
ds_test = ds_test.batch(model_config['batch_size'])
ds_test = ds_test.map(normalize_image,
                      num_parallel_calls=tf.data.AUTOTUNE)
ds_test = ds_test.cache()
ds_test = ds_test.prefetch(tf.data.AUTOTUNE)


# Run benchmarks on the data pipeline

In [None]:
""" Benchmark our training dataset for 2 epochs to test our input pipeline's efficiency 
    using parallel mapping """ 
# Two back-to-back benchmark passes over the training pipeline; each pass is
# announced with its own banner before the helper runs.
for banner in ('Run before caching training dataset...',
               'Second run after caching training dataset...'):
    print(banner)
    fast_benchmark(ds_train)

In [None]:
""" Benchmark our val dataset for 2 epochs to test our input pipeline's efficiency 
    using vectorized mapping """ 
# Two back-to-back benchmark passes over the validation pipeline; each pass
# is announced with its own banner before the helper runs.
for banner in ('Run before caching validation dataset...',
               'Second run after caching validation dataset...'):
    print(banner)
    fast_benchmark(ds_val)

# Create a simple image classification model using the Sequential API from Keras

In [None]:
# Create a model using the Keras Sequential API
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropout
from tensorflow.keras.optimizers import Adam

# Define our model's architecture: three Conv2D(3x3, ReLU) + 2x2 max-pooling
# stages with 32/64/128 filters, followed by a dense head with dropout.
# With the default 'valid' padding the feature map shrinks
# 32 -> 30 -> 15 -> 13 -> 6 -> 4 -> 2 across the three stages.
model = Sequential([
    Conv2D(32, (3, 3), activation='relu',
           data_format='channels_last',
           input_shape=(32, 32, 3)),
    MaxPooling2D((2, 2)),
    Conv2D(64, (3, 3), activation='relu'),
    MaxPooling2D((2, 2)),
    Conv2D(128, (3, 3), activation='relu'),
    MaxPooling2D((2, 2)),
    # Optional extra stage for experiments. The feature map is only 2x2 at
    # this point, so a 3x3 convolution needs padding='same' to be valid and
    # the pooling layer should be left out or also padded:
    # Conv2D(256, (3, 3), activation='relu', padding='same'),
    Flatten(),
    Dense(256, activation='relu'),
    Dropout(0.5),
    # Optional extra dense layer for experiments:
    # Dense(128, activation='relu'),
    # No activation: the model outputs raw logits (the loss uses from_logits=True)
    Dense(model_config['n_classes'], activation=None)
], name='cifar10_model_1.0')

# Compile and configure the model for training

# NOTE: Feel free to replace this with your own optimizer
optimizer = Adam(learning_rate=model_config['learning_rate'])

# Stop training after 5 epochs without improvement of the validation loss.
# restore_best_weights=True rolls the model back to the epoch with the best
# val_loss; without it the model keeps the weights of the last epoch, i.e.
# up to `patience` epochs past the best one, and those would be evaluated.
callback = tf.keras.callbacks.EarlyStopping(monitor='val_loss',
                                            patience=5,
                                            restore_best_weights=True)


model.compile(optimizer=optimizer,
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=[tf.keras.metrics.SparseCategoricalAccuracy(name='Accuracy'),
                       tf.keras.metrics.SparseTopKCategoricalAccuracy(name='TopKAccuracy')]
              )

# Get a detailed view of how the defined model looks
model.summary()

# Train the model

In [None]:
""" Train the model, inspect the loss curve during training and experiment 
    with different architectures. Take inspiration from the idea below or 
    try something one of your own...!
    
    Experiment idea: 1. Redifine the model's architecture with an extra convolutional
                     or dense layers and compare the models 
                     2. Try setting a different learning rate and optimizer in the model's settings
                     """
# Fit the model on the training pipeline, validating on the validation
# pipeline after each epoch; the early-stopping callback may end training
# before the configured number of epochs is reached.
n_epochs = int(model_config['training_epochs'])
history = model.fit(
    ds_train,
    epochs=n_epochs,
    validation_data=ds_val,
    callbacks=[callback],
)

# Plot the loss curves with a dark background
with plt.style.context('dark_background'):
    plot_loss(history, model_type='classification')

# Log metrics and print confusion matrix

In [None]:
"""Evaluate the model on the test set
   and inspect the results. 
   How reliable is our model? How do the test metrics compare with your training 
   or validation metrics ? """

from prettytable import PrettyTable

# Evaluate silently on the test pipeline. Everything after the first entry
# of the result is treated as the metric values, in the order they are
# labelled below (Accuracy, TopKAccuracy).
evaluation_result = model.evaluate(ds_test, verbose=0)
test_metrics = evaluation_result[1:]

# Two-column table with one row per metric, values shown with 4 decimals
table = PrettyTable()
table.field_names = ["Metric", "Value"]
accuracy_cell = f"{test_metrics[0]:.4f}"
top_k_cell = f"{test_metrics[1]:.4f}"
table.add_row(["Accuracy", accuracy_cell])
table.add_row(["TopKAccuracy", top_k_cell])

# Header, table and footer are printed one after another
for report_part in ("--------Test dataset metrics--------",
                    table,
                    "----------------------------"):
    print(report_part)

# Plot confusion matrix with default background
class_names = ds_info.features['label'].names
with plt.style.context('default'):
    plot_confusion_matrix(model, ds_test, class_names)

# Load the trained model and visualize predictions

In [None]:
""" Run inference on the test set for 10 samples and visualize 
    predictions versus true labels """

# Import from tensorflow.keras like the rest of this notebook; mixing the
# standalone `keras` package with `tf.keras` can pick up a different Keras
# installation than the one the model was built with.
from tensorflow.keras.models import load_model

# Upper bound on the number of test samples to visualize.
# NOTE: this is 15, although the intro string of this cell mentions 10.
max_samples = 15

# Take a single batch from the test set and keep at most `max_samples` of it
single_batch = next(iter(ds_test.take(1)))
images = single_batch[0][:max_samples]
true_labels = list(single_batch[1][:max_samples].numpy())

# Load a trained model and visualize predictions
# NOTE(review): the path is relative to the current working directory and no
# visible part of this notebook saves a model there; confirm that it exists.
trained_model = load_model('computer_vision/trained_models/cifar10_model')
predictions = trained_model.predict(images, verbose=0)

# Plot with dark background. The sample count is taken from the data actually
# sliced above, so a batch smaller than `max_samples` does not cause a mismatch.
with plt.style.context('dark_background'):
    visualize_classification_predictions(images, true_labels, predictions, ds_info,
                                         num_samples=len(true_labels))