# Supervised Fine-Tuning Cifar10

Code: [github:lucasdavid/experiments/.../supervised/fine-tuning/cifar10](https://github.com/lucasdavid/experiments/blob/main/notebooks/supervised/fine-tuning/cifar10/cifar10.ipynb)  
Dataset: Cifar10  
Docker image: `tensorflow/tensorflow:latest-gpu-jupyter`  

In [None]:
from time import time
import tensorflow as tf

class RC:
    """Run-level settings shared by every cell of the notebook."""
    # Seed value; it is embedded in the log directory name by `LogConfig`.
    seed = 5131
    # Sentinel handed to `Dataset.prefetch` as its buffer size in `prepare`.
    AUTOTUNE = tf.data.experimental.AUTOTUNE

class DC:
    """Data settings: batch size and the shape of the model inputs."""
    batch_size = 64
    image_size = (32, 32)  # spatial size of one image
    channels = 3
    # Full batch shape: (batch, *image_size, channels).
    input_shape = (batch_size,) + image_size + (channels,)

class TC:
    """Training hyper-parameters for both training stages."""
    # Stage 1: only the layers outside the frozen encoder are trained.
    epochs = 200
    learning_rate = 1e-3

    # Stage 2: fine-tuning of all layers; 0 epochs disables the stage.
    epochs_fine_tuning = 0
    learning_rate_fine_tuning = 5e-4

    # Multiplier applied by the ReduceLROnPlateau callback.
    reduce_lr_on_plateau_factor = .5

    # Whether the augmentation layers are inserted into the model.
    augment = False

    # Leading share of the original 'train' split held out for validation.
    validation_split = '30%'
    # tfds split specs, in the order (train, validation, test).
    splits = [
        f'train[{validation_split}:]',
        f'train[:{validation_split}]',
        'test',
    ]

class LogConfig:
    """Logging settings.

    `tensorboard` is the directory used for this run. It encodes the main
    hyper-parameters and ends with the integer Unix time at which this class
    body was evaluated, so re-running the cell yields a fresh directory.
    """
    # NOTE(review): the path is tagged `d:cifar100` while the notebook title
    # says Cifar10 -- confirm which dataset is intended.
    tensorboard = (
        '/tf/logs/'
        f'd:cifar100 e:{TC.epochs} b:{DC.batch_size} v:{TC.validation_split} '
        f'm:mobilenetv2 aug:{TC.augment} sd:{RC.seed}'
        f'/{int(time())}'
    )
    
class Config:
    """Single entry point grouping the run, data, training and log settings."""
    run, data, training, log = RC, DC, TC, LogConfig

## Setup

In [None]:
import os
import pathlib
from math import ceil

import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
import seaborn as sns

from tensorflow.keras import Model, Sequential, Input
from tensorflow.keras.layers import (Conv2D, Dense, Dropout, BatchNormalization,
                                     Activation, Lambda)

In [None]:
def plot(y, titles=None, rows=1, i0=0):
    """Draw a sequence of images on a grid of subplots of the current figure.

    Args:
        y: sized sequence of images accepted by `plt.imshow`. A `None` entry
            leaves its grid cell empty.
        titles: optional sequence holding one title per entry of `y` (list or
            NumPy array). `None` or an empty sequence disables titles.
        rows: number of grid rows; the number of columns is derived from it.
        i0: offset added to every subplot index.
    """
    # The grid geometry does not change while drawing, so compute it once.
    columns = ceil(len(y) / rows)
    # Explicit `is not None`/`len` test: the truth value of a multi-element
    # NumPy array is ambiguous and would raise.
    has_titles = titles is not None and len(titles) > 0

    for i, image in enumerate(y):
        position = i0 + i + 1
        if image is None:
            plt.subplot(rows, columns, position)
        else:
            plt.subplot(rows, columns, position,
                        title=titles[i] if has_titles else None)
            plt.imshow(image)
        plt.axis('off')

In [None]:
# Switch on seaborn's default theme for the figures drawn below.
sns.set()

## Loading Dataset

In [None]:
import tensorflow_datasets as tfds

In [None]:
class Data:
    """Namespace for the raw (unbatched) dataset splits and their metadata.

    `tfds.load` is called while the class body is evaluated, i.e. as soon as
    this cell runs.
    """
    # NOTE(review): this loads 'cifar100' although the notebook title says
    # Cifar10 -- confirm which dataset is intended.
    (train_ds, val_ds, test_ds), info = tfds.load(
        'cifar100',
        split=Config.training.splits,
        with_info=True,
        as_supervised=True,  # elements are (image, label) pairs
        shuffle_files=True,
    )

    # Class names as an array, so integer label arrays can index it directly.
    class_names = np.array(info.features['label'].names)

In [None]:
# Show the citation text shipped with the dataset metadata.
print(Data.info.citation)

## Augmentation Policy

In [None]:
# Random zoom, flip and rotation, chained into a single model named
# 'batch_aug'. `discriminator` inserts it when `Config.training.augment` is set.
batchwise_augmentation = Sequential(
    name='batch_aug',
    layers=[
        tf.keras.layers.experimental.preprocessing.RandomZoom(height_factor=(-.3, .3)),
        tf.keras.layers.experimental.preprocessing.RandomFlip(),
        tf.keras.layers.experimental.preprocessing.RandomRotation(factor=0.2),
    ],
)

def augment_fn(image, label):
    """Augment `image` and clip it back to the [0, 255] range.

    The previous body called `samplewise_augmentation`, which is not defined
    anywhere in this notebook, so any call raised NameError. It now uses
    `batchwise_augmentation`, the only augmentation pipeline defined here.

    NOTE(review): this function is not called in the visible cells. Whether
    `image` is a single sample or a batch depends on where it gets mapped;
    `batchwise_augmentation` presumably expects batched input -- confirm.

    Args:
        image: image tensor to augment.
        label: passed through unchanged.

    Returns:
        The `(augmented_image, label)` pair.
    """
    # training=True so the random layers actually transform their input.
    image = batchwise_augmentation(image, training=True)
    # Keep pixel values inside the [0, 255] range after the transformations.
    image = tf.clip_by_value(image, 0, 255)
    return image, label

def prepare(ds):
    """Batch `ds` with the configured batch size and enable prefetching.

    The last, incomplete batch is dropped (`drop_remainder=True`), so up to
    `batch_size - 1` trailing samples of `ds` are never yielded.
    """
    batched = ds.batch(Config.data.batch_size, drop_remainder=True)
    return batched.prefetch(buffer_size=Config.run.AUTOTUNE)

In [None]:
# Batched + prefetched versions of the three raw splits.
train_ds, val_ds, test_ds = map(
    prepare, (Data.train_ds, Data.val_ds, Data.test_ds))

In [None]:
# Sanity check: print the shapes and labels of the first training batch and
# show its images on a 4-row grid.
for x, y in train_ds.take(1):
    print('Shapes:', x.shape, 'and', y.shape)
    print("Labels: ", y.numpy())

    plt.figure(figsize=(16, 9))
    plot(x.numpy().astype(int), rows=4)
    plt.tight_layout()

## Model Definition

In [None]:
from tensorflow.keras.applications import mobilenet_v2

# Build MobileNetV2 without its classification head for the configured input
# shape, then keep only the part of the network up to layer 'block_9_add'.
encoder = mobilenet_v2.MobileNetV2(include_top=False, pooling='avg',
                                   input_shape=Config.data.input_shape[1:])
# Other cells retrieve this sub-model with `disc.get_layer('model')`. Name it
# explicitly instead of relying on Keras' auto-generated name, which depends
# on how many models were already created in the session and on the Keras
# version in use.
encoder = Model(encoder.input, encoder.get_layer('block_9_add').output,
                name='model')

In [None]:
def encoder_pre(x):
    """Run `x` through MobileNetV2's `preprocess_input`, wrapped in a Lambda layer."""
    preprocessing = Lambda(mobilenet_v2.preprocess_input, name='pre_inception')
    return preprocessing(x)

In [None]:
from tensorflow.keras.layers import GlobalAveragePooling2D

def dense_block(x, units, activation='relu', name=None):
    """Apply a bias-free Dense layer, then BatchNormalization and an activation.

    Args:
        x: input tensor.
        units: output size of the Dense layer.
        activation: activation passed to the final `Activation` layer.
        name: optional prefix for the three layer names (`<name>_fc`,
            `<name>_bn`, `<name>_relu`). When omitted, Keras picks the names;
            previously the layers were literally called `None_fc` etc., so a
            second un-named block clashed with the first one.

    Returns:
        The output tensor of the activation layer.
    """
    def scoped(suffix):
        # The `_relu` suffix is kept for every activation so that existing
        # layer names do not change.
        return None if name is None else f'{name}_{suffix}'

    y = Dense(units, name=scoped('fc'), use_bias=False)(x)
    y = BatchNormalization(name=scoped('bn'))(y)
    y = Activation(activation, name=scoped('relu'))(y)
    return y
    
def discriminator():
    """Assemble the classifier around the module-level `encoder`.

    Pipeline: inputs -> (optional augmentation) -> MobileNetV2 preprocessing
    -> encoder -> global average pooling -> Dense layer with one unit per
    class and no activation.

    Returns:
        A `tf.keras.Model` named 'author_disc'.
    """
    inputs = Input(shape=Config.data.input_shape[1:], name='inputs')

    # Augmentation is part of the model and only added when enabled.
    features = batchwise_augmentation(inputs) if Config.training.augment else inputs
    features = encoder(encoder_pre(features))

    pooled = GlobalAveragePooling2D(name='avg')(features)
    outputs = Dense(len(Data.class_names), name='predictions')(pooled)
    return tf.keras.Model(inputs, outputs, name='author_disc')

# Build the classifier and print its layer-by-layer summary.
disc = discriminator()
disc.summary()

In [None]:
# Freeze the sub-model registered under the layer name 'model' (the encoder)
# for the first training stage.
disc.get_layer('model').trainable = False

In [None]:
# Render the model graph, annotated with tensor shapes and dtypes.
tf.keras.utils.plot_model(disc, show_shapes=True, show_dtype=True)

In [None]:
from tensorflow.keras import losses, metrics, optimizers

# Compile for the first training stage. The model outputs raw scores (its
# last Dense layer has no activation), hence `from_logits=True`.
disc.compile(
    # `learning_rate` is the supported keyword; the old `lr` alias is
    # deprecated and rejected by recent Keras releases.
    optimizer=optimizers.Adam(learning_rate=Config.training.learning_rate),
    loss=losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[
        metrics.SparseCategoricalAccuracy(),
        metrics.SparseTopKCategoricalAccuracy()
    ]
)

## Training

### Initial Training for Final Classification Layer

The final layer --- currently containing random values --- must first be adjusted to match the current state of the encoder's layers.

In [None]:
from tensorflow.keras import callbacks

# Callbacks shared by both training stages.
# NOTE: this assignment rebinds `callbacks` from the imported Keras module to
# a list, so the import above must run again before this statement is
# re-executed.
# NOTE(review): with Config.training.epochs = 200, ReduceLROnPlateau waits 100
# stagnant epochs while EarlyStopping stops after 66, so the learning-rate
# reduction will rarely, if ever, trigger -- confirm the intended patiences.
callbacks = [
    callbacks.TerminateOnNaN(),
    # Keeps only the best weights seen so far, stored inside the log directory.
    callbacks.ModelCheckpoint(Config.log.tensorboard + '/weights.h5',
                              save_best_only=True,
                              save_weights_only=True,
                              verbose=1),
    callbacks.ReduceLROnPlateau(patience=Config.training.epochs // 2,
                                factor=Config.training.reduce_lr_on_plateau_factor),
    callbacks.EarlyStopping(patience=Config.training.epochs // 3),
    callbacks.TensorBoard(Config.log.tensorboard, histogram_freq=1)
]

In [None]:
# First stage: train with the encoder frozen, validating on `val_ds`.
# The trailing `;` suppresses the notebook's echo of the returned object.
disc.fit(
    train_ds,
    validation_data=val_ds,
    epochs=Config.training.epochs,
    initial_epoch=0,
    callbacks=callbacks,
);

### Fine-Tuning All Layers

In [None]:
# Second stage (skipped when `epochs_fine_tuning` is 0): unfreeze the encoder
# and continue training all layers from where the first stage stopped.
if Config.training.epochs_fine_tuning:
    # Read the epoch bookkeeping of the first stage before touching the model.
    first_epoch = disc.history.epoch[-1] + 1
    last_epoch = len(disc.history.epoch) + Config.training.epochs_fine_tuning

    disc.get_layer('model').trainable = True

    # A change of `trainable` only takes effect once the model is compiled
    # again. Re-compiling is also where the dedicated fine-tuning learning
    # rate gets applied; it was previously never used.
    disc.compile(
        optimizer=optimizers.Adam(
            learning_rate=Config.training.learning_rate_fine_tuning),
        loss=losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=[
            metrics.SparseCategoricalAccuracy(),
            metrics.SparseTopKCategoricalAccuracy()
        ]
    )

    disc.fit(
        train_ds,
        validation_data=val_ds,
        initial_epoch=first_epoch,
        epochs=last_epoch,
        callbacks=callbacks,
    )

## Testing

In [None]:
# Freeze the encoder again, then restore the checkpoint written by the
# ModelCheckpoint callback (best weights only) for evaluation.
disc.get_layer('model').trainable = False

disc.load_weights(Config.log.tensorboard + '/weights.h5')

In [None]:
from sklearn import metrics as skmetrics

def labels_and_predictions(model, ds):
    """Collect true and predicted class names over every batch of `ds`.

    Args:
        model: callable returning one score per class for a batch of inputs.
        ds: iterable of `(inputs, integer_labels)` batches.

    Returns:
        Two aligned arrays `(labels, predictions)` of class names taken from
        `Data.class_names`.
    """
    true_ids, predicted_ids = [], []

    for inputs, targets in ds:
        scores = model(inputs).numpy()
        true_ids.append(targets.numpy())
        # The predicted class is the one with the highest score.
        predicted_ids.append(np.argmax(scores, axis=1))

    true_ids = np.concatenate(true_ids)
    predicted_ids = np.concatenate(predicted_ids)
    return Data.class_names[true_ids], Data.class_names[predicted_ids]

def evaluate(model, ds):
    """Print balanced accuracy, accuracy and a classification report for `model` on `ds`."""
    y_true, y_pred = labels_and_predictions(model, ds)

    balanced = skmetrics.balanced_accuracy_score(y_true, y_pred)
    plain = skmetrics.accuracy_score(y_true, y_pred)
    report = skmetrics.classification_report(y_true, y_pred)

    print('balanced acc:', balanced)
    print('accuracy    :', plain)
    print('Classification report:')
    print(report)

#### Training Report

In [None]:
# Metrics on the batched training split.
evaluate(disc, train_ds)

#### Validation Report

In [None]:
# Metrics on the batched validation split.
evaluate(disc, val_ds)

#### Test Report

In [None]:
# Metrics on the batched test split.
evaluate(disc, test_ds)

In [None]:
# Class-name arrays for the test split, reused by the confusion matrix below.
labels, predictions = labels_and_predictions(disc, test_ds)

In [None]:
# Confusion matrix over the test predictions, with classes reordered from the
# most to the least accurately predicted one.
# `labels=` pins row/column k to Data.class_names[k]; without it the order is
# whatever sorted set of names occurs in the data, which the tick labels below
# would silently have to match.
cm = skmetrics.confusion_matrix(labels, predictions, labels=Data.class_names)
# Diagonal of the row-normalised matrix = per-class recall. np.maximum guards
# against 0/0 for a class that has no true samples.
sorted_by_most_accurate = (
    cm / np.maximum(cm.sum(axis=1, keepdims=True), 1)
).diagonal().argsort()[::-1]
cm = cm[sorted_by_most_accurate][:, sorted_by_most_accurate]

plt.figure(figsize=(12, 12))
sns.heatmap(cm,
            cmap='RdPu', annot=False, cbar=False,
            yticklabels=Data.class_names[sorted_by_most_accurate],
            xticklabels=False);

In [None]:
def plot_predictions(model, ds, take=1):
    """Show the images of the first `take` batches of `ds` with their labels.

    Every image is titled '<true label> <predicted label>', both given as
    integer class indices.

    Args:
        model: model whose `predict` returns one score per class.
        ds: dataset of `(images, integer_labels)` batches.
        take: number of batches to draw.
    """
    images, titles = [], []

    plt.figure(figsize=(16, 12))
    for x, y in ds.take(take):
        p = model.predict(x).argmax(axis=-1)

        images.append(x.numpy())
        # Convert the labels to NumPy first: formatting eager tensors directly
        # would put their full `tf.Tensor(...)` text into every title.
        titles.extend(f'{a} {b}' for a, b in zip(y.numpy(), p))

    if not images:
        # Nothing was drawn from `ds`; np.concatenate would fail on an empty list.
        return

    plot(np.concatenate(images), titles=titles, rows=6)
    plt.tight_layout()

# Visual check on the first training batch.
plot_predictions(disc, train_ds)