In [None]:
import tensorflow_datasets as tfds
import tensorflow as tf

from tensorflow.keras.models import Sequential, load_model
from tensorflow.keras.layers import Conv2D, AveragePooling2D, Dense, Dropout, Input, Flatten, MaxPooling2D
from tensorflow.keras.callbacks import LearningRateScheduler, ModelCheckpoint
from tensorflow.keras.optimizers import Adam

import matplotlib.pyplot as plt

import numpy as np

from functools import partial

from albumentations import (Compose, HorizontalFlip, Rotate, OneOf, ToGray, VerticalFlip,
    RandomScale, ChannelShuffle, ChannelDropout, ToSepia, RandomBrightnessContrast
)

import random
import os

from os.path import isfile
import pickle
import lzma

# Sentinel that lets tf.data choose parallelism / prefetch buffer sizes at runtime
# (used for num_parallel_calls and prefetch in the input pipelines).
AUTOTUNE = tf.data.experimental.AUTOTUNE


In [None]:
def get_mkl_enabled_flag():
    """Report whether this TensorFlow build uses Intel MKL / oneDNN kernels.

    Returns:
        bool: True if the TensorFlow binary reports MKL support, or (TF >= 2.5)
        if oneDNN optimisations were switched on via ``TF_ENABLE_ONEDNN_OPTS=1``.

    NOTE(review): recent TF releases may enable oneDNN by default when the
    environment variable is unset; this function treats "unset" as disabled.
    """
    # Compare (major, minor) as a tuple so that e.g. a 3.x release is not
    # mistaken for an old 2.<5 release just because its minor number is small.
    version = tuple(int(part) for part in tf.__version__.split(".")[:2])

    if version < (2, 0):
        # TF 1.x exposes the check directly on pywrap_tensorflow.
        return bool(tf.pywrap_tensorflow.IsMklEnabled())

    if version < (2, 5):
        # Before 2.5 the helper lives in tensorflow.python and there is no
        # oneDNN opt-in variable to consult.
        from tensorflow.python import _pywrap_util_port
        return bool(_pywrap_util_port.IsMklEnabled())

    from tensorflow.python.util import _pywrap_util_port
    # Defined on every path that uses it (previously TF 2.0-2.4 hit a NameError).
    onednn_enabled = int(os.environ.get('TF_ENABLE_ONEDNN_OPTS', '0'))
    return bool(_pywrap_util_port.IsMklEnabled() or onednn_enabled == 1)


print("We are using Tensorflow version", tf.__version__)
print("MKL enabled :", get_mkl_enabled_flag())


In [None]:
# CIFAR-100 fine-label names; position i is the name of integer label i
# (the list is in alphabetical order).
cifar100_labels = """
    apple aquarium_fish baby bear beaver bed bee beetle bicycle bottle
    bowl boy bridge bus butterfly camel can castle caterpillar cattle chair
    chimpanzee clock cloud cockroach couch crab crocodile cup dinosaur dolphin
    elephant flatfish forest fox girl hamster house kangaroo keyboard lamp
    lawn_mower leopard lion lizard lobster man maple_tree motorcycle mountain
    mouse mushroom oak_tree orange orchid otter palm_tree pear pickup_truck
    pine_tree plain plate poppy porcupine possum rabbit raccoon ray road
    rocket rose sea seal shark shrew skunk skyscraper snail snake spider
    squirrel streetcar sunflower sweet_pepper table tank telephone television tiger
    tractor train trout tulip turtle wardrobe whale willow_tree wolf woman worm
""".split()


In [None]:
# Seed the global RNGs (Python, NumPy, TensorFlow) so the augmentation probabilities
# drawn below and the model weight initialisation are repeatable between runs
# (parallel tf.data work can still introduce some nondeterminism).
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
tf.random.set_seed(SEED)

IMAGE_SHAPE = (32, 32, 3)   # height, width, channels of the CIFAR images
# Probability for one augmentation step, drawn uniformly from [0.35, 0.75] on every
# call; the argument is unused.
AUG_PROBA = lambda x=0: random.uniform(0.35, 0.75)
NUM_EPOCHS = 64             # upper bound; EarlyStopping in train_model may stop earlier
BATCH_SIZE = 256            # training batch size
VAL_BATCH_SIZE = 32         # evaluation (test split) batch size
NUM_CLASSES = 10            # CIFAR-10; reassigned to 100 before the CIFAR-100 runs
INITIAL_LR = 0.01           # Adam learning rate at epoch 0 of the exponential decay schedule

In [None]:
def view_image(ds, class_names=None, num_images=20):
    """Plot a grid of images, titled with their class names, from the first batch of ``ds``.

    Args:
        ds: batched dataset yielding ``(images, one_hot_labels)`` tensors.
        class_names: sequence mapping class index -> display name. When omitted it is
            chosen from the one-hot width: 10 -> CIFAR-10 names, 100 ->
            ``cifar100_labels``, anything else -> the numeric index as a string.
        num_images: maximum number of images to draw, laid out 5 per row
            (the default of 20 gives a 4 x 5 grid). Smaller batches draw fewer images.
    """
    images, labels = next(iter(ds))
    images = images.numpy()
    labels = labels.numpy()

    if class_names is None:
        depth = labels.shape[-1]
        if depth == 10:
            # CIFAR-10 names in TFDS label order; previously CIFAR-10 batches were
            # titled with CIFAR-100 names.
            class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer',
                           'dog', 'frog', 'horse', 'ship', 'truck']
        elif depth == len(cifar100_labels):
            class_names = cifar100_labels
        else:
            class_names = [str(idx) for idx in range(depth)]

    # Never index past the end of a short batch.
    count = min(num_images, len(images))
    n_cols = 5
    n_rows = max(1, -(-count // n_cols))  # ceiling division

    fig = plt.figure(figsize=(16, 16))

    for i in range(count):
        ax = fig.add_subplot(n_rows, n_cols, i + 1, xticks=[], yticks=[])
        ax.imshow(images[i])
        label_idx = int(np.argmax(labels[i], axis=0))
        ax.set_title(f"Label: {class_names[label_idx]}")


In [None]:
def _build_aug_transforms():
    """Assemble the training-time (geometric) augmentation pipeline.

    Each ``AUG_PROBA()`` call draws its probability once, when the pipeline is
    built, and the draws always happen in the same fixed order.
    """
    # One of three rotations, capped at 15, 45 or 90 degrees.
    rotate_any = OneOf(
        [Rotate(limit=limit, p=AUG_PROBA()) for limit in (15, 45, 90)],
        p=AUG_PROBA(),
    )
    # Either a horizontal or a vertical flip.
    flip_any = OneOf(
        [HorizontalFlip(p=AUG_PROBA()), VerticalFlip(p=AUG_PROBA())],
        p=AUG_PROBA(),
    )
    # Small scale jitter; get_dataset resizes the result back to IMAGE_SHAPE.
    rescale = RandomScale(scale_limit=0.05, p=AUG_PROBA())

    # Photometric transforms (brightness/contrast, channel shuffle/dropout,
    # grey, sepia) are imported but intentionally not applied.
    return Compose([rotate_any, flip_any, rescale])


aug_transforms = _build_aug_transforms()


In [None]:
def get_dataset(ds_name: str, num_classes=None, shuffle_buffer_size: int = 10_000):
    """Load a TFDS image-classification dataset and build the train/test pipelines.

    Train: the first 85% of the ``train`` split is shuffled each epoch and augmented
    with ``aug_transforms``. It is then resized back to ``IMAGE_SHAPE``, scaled to
    [0, 1], one-hot encoded and batched by ``BATCH_SIZE``.
    Test: the full ``test`` split is scaled to [0, 1], one-hot encoded and batched by
    ``VAL_BATCH_SIZE``, with no augmentation.

    Args:
        ds_name: TFDS dataset name, e.g. ``'cifar10'`` or ``'cifar100'``.
        num_classes: one-hot depth; defaults to the global ``NUM_CLASSES`` read at
            call time.
        shuffle_buffer_size: number of training examples buffered for shuffling.

    Returns:
        ``(train_ds, test_ds)``: batched ``tf.data.Dataset`` objects yielding
        ``(float32 images, float32 one-hot labels)``.

    NOTE(review): the remaining 15% of ``train`` is never used, and the test split
    doubles as the validation set for early stopping / checkpoint selection.
    """
    if num_classes is None:
        num_classes = NUM_CLASSES
    height, width, channels = IMAGE_SHAPE

    def aug_fn(image):
        # Runs in Python via tf.numpy_function: albumentations works on numpy arrays.
        augmented = aug_transforms(image=image)["image"]
        # Hand back plain numpy; resizing happens in the TF graph below.
        return np.asarray(augmented, dtype=np.float32)

    def augment(image, label):
        image = tf.numpy_function(func=aug_fn, inp=[image], Tout=tf.float32)
        # numpy_function drops static shape info. Restore rank/channels, then resize
        # (RandomScale may change H/W) so the output has a fully known shape.
        image.set_shape([None, None, channels])
        image = tf.image.resize(image, size=[height, width])
        return image, label

    def normalize(image, label):
        # Plain TF ops: no need to round-trip through Python for a rescale.
        image = tf.cast(image, tf.float32) / 255.0
        return image, tf.one_hot(label, depth=num_classes)

    train_ds, test_ds = tfds.load(name=ds_name, split=['train[:85%]', 'test'], as_supervised=True, with_info=False, shuffle_files=True)

    train_ds = (
        train_ds
        # Shuffle the raw examples so every epoch sees a different order.
        .shuffle(shuffle_buffer_size)
        .map(augment, num_parallel_calls=AUTOTUNE)
        .map(normalize, num_parallel_calls=AUTOTUNE)
        .batch(BATCH_SIZE)
        .prefetch(AUTOTUNE)
    )

    test_ds = (
        test_ds
        .map(normalize, num_parallel_calls=AUTOTUNE)
        .batch(VAL_BATCH_SIZE)
        .prefetch(AUTOTUNE)
    )

    return train_ds, test_ds

In [None]:
def lr_exp_decay(epoch, lr, initial_lr=None, decay_rate=0.1):
    """Exponential learning-rate schedule for ``LearningRateScheduler``.

    Returns ``initial_lr * exp(-decay_rate * epoch)``. The current ``lr`` passed in
    by Keras is ignored, so the schedule depends only on the epoch index.

    Args:
        epoch: epoch index supplied by the scheduler.
        lr: current learning rate (unused; required by the scheduler's call signature).
        initial_lr: learning rate at epoch 0; defaults to the global ``INITIAL_LR``.
        decay_rate: per-epoch exponential decay constant (formerly hard-coded 0.1).

    Returns:
        float: the learning rate for ``epoch``.
    """
    if initial_lr is None:
        initial_lr = INITIAL_LR
    return float(initial_lr * np.exp(-decay_rate * epoch))


In [None]:
def build_lenet_5(activation_f: str, num_classes=None, learning_rate=None):
    """Build and compile a LeNet-5 style CNN for inputs of shape ``IMAGE_SHAPE``.

    Shapes for a 32x32x3 input:
    conv5x5(6) -> 28x28x6 -> avgpool -> 14x14x6 -> conv5x5(16) -> 10x10x16
    -> avgpool -> 5x5x16 -> flatten (400) -> dense 120 -> dense 84
    -> softmax(num_classes).

    Args:
        activation_f: Keras activation name for the hidden layers (e.g. 'relu', 'tanh').
        num_classes: width of the softmax output. Defaults to the global
            ``NUM_CLASSES`` read at call time, and must match the labels' one-hot depth.
        learning_rate: Adam learning rate; defaults to the global ``INITIAL_LR``.

    Returns:
        A ``Sequential`` model compiled with Adam, categorical cross-entropy and accuracy.
    """
    if num_classes is None:
        num_classes = NUM_CLASSES
    if learning_rate is None:
        learning_rate = INITIAL_LR

    model = Sequential(
        [
            Conv2D(filters=6, kernel_size=(5, 5), activation=activation_f, strides=(1, 1), padding='valid', input_shape=IMAGE_SHAPE),
            AveragePooling2D(pool_size=(2, 2), strides=(2, 2), padding='valid'),

            Conv2D(filters=16, kernel_size=(5, 5), activation=activation_f, strides=(1, 1), padding='valid'),
            AveragePooling2D(pool_size=(2, 2), strides=(2, 2), padding='valid'),

            Flatten(),

            Dense(units=120, activation=activation_f),
            Dense(units=84, activation=activation_f),

            Dense(units=num_classes, activation='softmax')
        ]
    )

    model.compile(
        optimizer=Adam(learning_rate=learning_rate),
        loss='categorical_crossentropy',
        metrics=['accuracy'],
    )

    return model

In [None]:
def _as_history(cached):
    """Turn a value loaded from a history cache file into a History-like object.

    New cache files store the plain ``History.history`` metrics dict, which is
    wrapped in a fresh ``History``. Older files may hold a pickled ``History``
    object, which is returned unchanged.
    """
    if isinstance(cached, dict):
        history = tf.keras.callbacks.History()
        history.history = cached
        return history
    return cached


def train_model(model_name: str, model=None, train=None, test=None, models_dir: str = "./data/models"):
    """Train ``model``, or load cached results, and return ``(history, best_model)``.

    Cache files in ``models_dir``:
      * ``<model_name>-history.xz``: lzma-compressed pickle of the per-epoch metrics.
      * ``<model_name>-best_model.hdf5``: checkpoint with the lowest ``val_loss``.

    If both files exist, training is skipped and they are loaded. Otherwise the model
    is fitted with early stopping on ``val_loss`` (patience 2), the ``lr_exp_decay``
    schedule and best-only checkpointing, and then both files are written.

    Args:
        model_name: prefix for the cache file names.
        model: compiled Keras model; required only when no cache exists.
        train: batched training dataset; required only when no cache exists.
        test: batched validation dataset used for ``val_loss``.
        models_dir: directory holding the cache files (created if missing).

    Returns:
        ``(history, model)``. ``history.history`` maps metric names to per-epoch
        values. ``model`` is the best (lowest ``val_loss``) checkpoint when one was
        saved, on both the cached and the freshly trained path.

    Raises:
        ValueError: no cache exists and ``model`` or ``train`` was not given.
    """
    model_history_fn = os.path.join(models_dir, model_name + "-history.xz")
    model_fn = os.path.join(models_dir, model_name + "-best_model.hdf5")

    if isfile(model_history_fn) and isfile(model_fn):
        # NOTE: pickle can execute arbitrary code on load; only read cache files
        # that this notebook produced.
        with lzma.open(model_history_fn, "rb") as m_file:
            cached = pickle.load(m_file)
        return _as_history(cached), load_model(model_fn)

    if model is None or train is None:
        raise ValueError(
            f"No cached results for {model_name!r} in {models_dir!r}; "
            "pass `model` and `train` to train it."
        )

    os.makedirs(models_dir, exist_ok=True)

    # Stop once val_loss has not improved for 2 consecutive epochs.
    early_stop = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=2, mode='min', verbose=1)

    # Keep only the checkpoint with the lowest val_loss.
    checkpoint = ModelCheckpoint(
        filepath=model_fn,
        monitor='val_loss',
        verbose=1,
        save_best_only=True,
        mode='min',
    )

    # `train`/`test` are already-batched tf.data pipelines. The Keras `fit` docs say
    # not to pass batch_size for dataset inputs, so batch sizes are not repeated here.
    history = model.fit(
        train,
        epochs=NUM_EPOCHS,
        validation_data=test,
        callbacks=[early_stop, LearningRateScheduler(lr_exp_decay, verbose=1), checkpoint],
    )

    # Persist only the metrics dict; the History object also holds a reference to the model.
    with lzma.open(model_history_fn, "wb") as m_file:
        pickle.dump(history.history, m_file)

    # Return the best checkpoint, matching the cached path, rather than the weights
    # from the last epoch (which with early stopping are past the best epoch).
    best_model = load_model(model_fn) if isfile(model_fn) else model
    return history, best_model


In [None]:
# CIFAR-10 input pipelines, plus a preview grid of augmented training images.
train_ds, test_ds = get_dataset(ds_name='cifar10')
view_image(ds=train_ds)

In [None]:
# CIFAR-10, ReLU activations (results are cached under the model_name below).
relu_lenet5_model = build_lenet_5(activation_f='relu')
relu_lenet5_model_history, relu_lenet5_model = train_model(
    model_name='cifar_10-relu_lenet_5',
    model=relu_lenet5_model,
    train=train_ds,
    test=test_ds,
)

In [None]:
# CIFAR-10, tanh activations. The cache name used to be misspelled
# 'cifar_10-than_lenet_5'; rename any existing cache files in ./data/models
# to the new name to reuse them instead of retraining.
tanh_lenet5_model = build_lenet_5(activation_f='tanh')
tanh_lenet5_model_history, tanh_lenet5_model = train_model(
    model_name='cifar_10-tanh_lenet_5',
    model=tanh_lenet5_model,
    train=train_ds,
    test=test_ds,
)

In [None]:
# Switch to CIFAR-100. get_dataset and build_lenet_5 read NUM_CLASSES when they are
# called, so this assignment must run before the CIFAR-100 cells below.
NUM_CLASSES = len(cifar100_labels)  # 100 fine-grained classes
train_ds, test_ds = get_dataset(ds_name='cifar100')

In [None]:
# CIFAR-100, ReLU activations.
relu_cifar_100_lenet5_model = build_lenet_5(activation_f='relu')
relu_cifar_100_lenet5_model_history, relu_cifar_100_lenet5_model = train_model(
    model_name='cifar_100-relu_lenet_5',
    model=relu_cifar_100_lenet5_model,
    train=train_ds,
    test=test_ds,
)

In [None]:
# CIFAR-100, tanh activations.
tanh_cifar_100_lenet5_model = build_lenet_5(activation_f='tanh')
tanh_cifar_100_lenet5_model_history, tanh_cifar_100_lenet5_model = train_model(
    model_name='cifar_100-tanh_lenet_5',
    model=tanh_cifar_100_lenet5_model,
    train=train_ds,
    test=test_ds,
)