# IMPORTS #

In [None]:
# import the necessary packages
from tensorflow.keras.layers import GlobalAveragePooling2D
from tensorflow.keras.layers import GlobalMaxPooling2D
from tensorflow.keras.layers import AveragePooling2D
from tensorflow.keras.layers import BatchNormalization
from tensorflow.keras.layers import MaxPooling2D
from tensorflow.keras.layers import concatenate
from tensorflow.keras.layers import Activation
from tensorflow.keras.layers import multiply
from tensorflow.keras.layers import Flatten
from tensorflow.keras.layers import Reshape
from tensorflow.keras.layers import Permute
from tensorflow.keras.layers import Conv2D
from tensorflow.keras.layers import Layer
from tensorflow.keras.layers import Dense
from tensorflow.keras.layers import Input
from tensorflow.keras.layers import add

from tensorflow.keras.preprocessing.image import img_to_array
from tensorflow.keras.callbacks import *
from tensorflow.keras.regularizers import l2, Regularizer
from tensorflow.keras.models import Model
from tensorflow.keras.models import load_model
from tensorflow.keras.utils import plot_model
from tensorflow.keras.datasets import cifar10
from tensorflow.keras import backend as K
from tensorflow import keras

from tensorflow.keras.losses import CategoricalCrossentropy, sparse_categorical_crossentropy
from tensorflow.keras.optimizers import SGD

from sklearn.feature_extraction.image import extract_patches_2d
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelBinarizer

import matplotlib.pyplot as plt
import tensorflow as tf
import numpy as np
import tempfile
import json
import time
import cv2
import os

In [None]:
! mkdir models

# CUSTOM LAYERS #

In [None]:
class Mish(Layer):
    '''
    Keras layer applying the Mish activation element-wise.
    .. math::
        mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + e^{x}))
    The layer holds no weights and returns a tensor with the same shape
    as its input.
    Examples:
        >>> X_input = Input(input_shape)
        >>> X = Mish()(X_input)
    '''

    def __init__(self, **kwargs):
        # forward every keyword argument (name, dtype, ...) to the base Layer
        super().__init__(**kwargs)
        # flag the layer as mask-compatible
        self.supports_masking = True

    def call(self, inputs):
        # gate the input with tanh(softplus(x))
        gate = K.tanh(K.softplus(inputs))
        return inputs * gate

    def get_config(self):
        # nothing layer-specific to serialize, so the base config is returned as is
        return super().get_config()

    def compute_output_shape(self, input_shape):
        # element-wise op: the shape is passed through unchanged
        return input_shape

# PREPROCESSORS #

In [None]:
class PadPreprocessor:
    """Zero-pads the first two axes of a 3-axis image by a fixed amount."""

    def __init__(self, pad):
        # number of pixels added on every side of the first two axes
        self.pad = pad

    def preprocess(self, img):
        """Return a copy of `img` with `pad` zeros added before and after axes 0 and 1."""
        border = (self.pad, self.pad)

        # the third axis is left untouched
        return np.pad(img, (border, border, (0, 0)))

In [None]:
class ReflectionPadPreprocessor:
    """Pads the first two axes of a 3-axis image by mirroring its content."""

    def __init__(self, pad):
        # number of pixels added on every side of the first two axes
        self.pad = pad

    def preprocess(self, img):
        """
        Return a reflection-padded copy of `img`.

        The border is mirrored around the edge pixels without repeating them
        (numpy's "reflect" mode). This is the same result the previous
        column/row copy loop produced, but it also stays correct when `pad`
        is as large as (or larger than) the image side, where the manual loop
        read from rows/columns that were themselves padding.

        Args:
            img: array with three axes; only axes 0 and 1 are padded.

        Returns:
            A new array whose first two axes are each `2 * pad` longer.
        """
        border = (self.pad, self.pad)

        # the third axis is left untouched
        return np.pad(img, (border, border, (0, 0)), mode = "reflect")

In [None]:
class FlipPreprocessor:
    """Randomly mirrors an image with `cv2.flip(img, 1)`."""

    def __init__(self, prob):
        # threshold the random draw is compared against
        self.prob = prob

    def preprocess(self, img):
        """Return `img` flipped when a uniform draw is <= `prob`, otherwise `img` itself."""
        # one uniform sample per call
        draw = np.random.uniform(size = (1,))

        # leave the image untouched when the draw is above the threshold
        if not (draw <= self.prob):
            return img

        return cv2.flip(img, 1)

In [None]:
class PatchPreprocessor:
    """Extracts a single random patch of a fixed size from an image."""

    def __init__(self, height, width):
        # target size of the extracted patch
        self.height = height
        self.width = width

    def preprocess(self, img):
        """Return one `(height, width)` patch sampled by `extract_patches_2d`."""
        crop_size = (self.height, self.width)

        # ask for exactly one patch and unwrap it from the returned batch
        patches = extract_patches_2d(img, crop_size, max_patches = 1)
        return patches[0]

In [None]:
class MeanPreprocessor:
    """Standardizes an image: optional scaling by 1/255, then (img - mean) / std."""

    def __init__(self, mean, std, normalize = True):
        # statistics used for standardization and the scaling switch
        self.mean = mean
        self.std = std
        self.normalize = normalize

    def preprocess(self, img):
        """Return the standardized image; `img` is first cast to float and divided by 255 when `normalize` is set."""
        scaled = img.astype("float") / 255.0 if self.normalize else img

        # subtract the mean and divide by the standard deviation
        return (scaled - self.mean) / self.std

In [None]:
class ImageToArrayPreprocessor:
    """Thin wrapper around keras' `img_to_array`."""

    def __init__(self, data_format = None):
        # data format forwarded to img_to_array
        self.data_format = data_format

    def preprocess(self, img):
        """Return `img` converted by `img_to_array` using the stored data format."""
        fmt = self.data_format
        return img_to_array(img, data_format = fmt)

# DATA GENERATOR #

In [None]:
class CifarGenerator:
    def __init__(self, x_train, y_train, batch_size, preprocessors = None, aug = None):
        # initialize the cifar data
        self.x_train = x_train
        self.y_train = y_train

        # initialize the instance variables
        self.bs = batch_size
        self.preprocessors = preprocessors
        self.aug = aug
        self.num_images = self.x_train.shape[0]
        self.lb = LabelBinarizer()
        self.lb.fit(y_train)
    
    def generator(self, passes = np.inf):
        # initialize a variable to keep a count on the epochs
        epochs = 0

        # loop through the dataset indefinitely
        while(epochs < passes):
            # initialize the indices
            indices = list(range(self.num_images))
            np.random.shuffle(indices)

            # loop through the dataset in batches
            for i in range(0, self.num_images, self.bs):
                # extract the current indices
                cur_indices = sorted(indices[i : i + self.bs])

                # grab the current batch
                x, y = self.x_train[cur_indices], self.y_train[cur_indices]

                # if any preprocessors are supplied, apply them
                if self.preprocessors is not None:
                    # loop through the images
                    proc_x = []
                    for img in x:
                        # loop through the preprocessors
                        for p in self.preprocessors:
                            img = p.preprocess(img)

                        proc_x.append(img)
                
                    # update the images
                    x = np.array(proc_x)
                
                # preprocess the labels
                y = self.lb.transform(y)

                # if any augmentation is supplied, apply it
                if self.aug is not None:
                    x, y = next(self.aug.flow(x, y, batch_size = bs))
                
                # yield the current batch
                yield x, y

In [None]:
class MixUpCifarGenerator:
    """
    Batch generator that blends every image with a second image (mixup on the inputs).

    For each batch a shuffled set of samples (`x1`) is combined with the
    samples stored at positions `i .. i + batch_size - 1` of `x_train` (`x2`)
    as `lamb * x1 + (1 - lamb) * x2`, with one `lamb` per sample.

    NOTE(review): only the labels of `x1` are returned - the labels of `x2`
    are not blended in. Since `lamb >= 0.5`, `x1` is always the dominant
    image; confirm this input-only mixup is intended.
    """

    def __init__(self, x_train, y_train, batch_size, alpha = 0.4, preprocessors = None, aug = None):
        # initialize the cifar data
        self.x_train = x_train
        self.y_train = y_train

        # initialize the instance variables
        self.bs = batch_size
        self.preprocessors = preprocessors
        self.aug = aug
        self.alpha = alpha
        self.num_images = self.x_train.shape[0]
        self.lb = LabelBinarizer()
        self.lb.fit(y_train)

    def generator(self, passes = np.inf):
        """
        Yield `(x, y)` batches of mixed images and binarized labels.

        Args:
            passes: number of full passes over the data before the generator
                stops; the default (`np.inf`) loops forever.

        Yields:
            Tuple of the mixed (and optionally augmented) image batch and the
            binarized labels of the shuffled samples. The last batch of a
            pass may be smaller than `batch_size`.
        """
        # number of completed passes over the data
        epochs = 0

        while epochs < passes:
            # new random order for every pass
            indices = list(range(self.num_images))
            np.random.shuffle(indices)

            # loop through the dataset in batches
            for i in range(0, self.num_images, self.bs):
                # shuffled samples of the current batch
                cur_indices = sorted(indices[i : i + self.bs])

                # mixing partners: the consecutive samples at the same offsets,
                # clipped at the end of the data so both batches have equal length
                oth_indices = list(range(i, min(i + self.bs, self.num_images)))

                # grab the data batches
                x1, y = self.x_train[cur_indices], self.y_train[cur_indices]
                x2 = self.x_train[oth_indices]

                # if any preprocessors are supplied, apply them in order to both images
                if self.preprocessors is not None:
                    proc_x1 = []
                    proc_x2 = []
                    for img1, img2 in zip(x1, x2):
                        for p in self.preprocessors:
                            img1 = p.preprocess(img1)
                            img2 = p.preprocess(img2)

                        proc_x1.append(img1)
                        proc_x2.append(img2)

                    # update the images
                    x1 = np.array(proc_x1)
                    x2 = np.array(proc_x2)

                # one mixing weight per sample, drawn from Beta(alpha + 1, alpha)
                lamb = np.random.beta(self.alpha + 1, self.alpha, x1.shape[0])

                # keep the weight of x1 at 0.5 or above
                lamb = np.maximum(lamb, 1 - lamb)

                # broadcast the weights over the three image axes
                xlamb = lamb.reshape(-1, 1, 1, 1)

                # perform the mixup
                x = (xlamb * x1) + ((1 - xlamb) * x2)

                # preprocess the labels
                y = self.lb.transform(y)

                # if any augmentation is supplied, apply it
                # (was `batch_size = bs`, an undefined name)
                if self.aug is not None:
                    x, y = next(self.aug.flow(x, y, batch_size = self.bs))

                # yield the current batch
                yield x, y

            # count the finished pass so that `passes` actually bounds the loop
            epochs += 1

# MODELS #

In [None]:
class XResNet:
    """Builder for pre-activation (BN => ReLU => Conv) residual networks using the Keras functional API."""

    @staticmethod
    def residual_module(data, K, stride, chan_dim, red = False, reg = 1e-4, bn_eps = 2e-5, bn_mom = 0.9, bottleneck = True, name = "res_block"):
        """
        Append one residual module to `data` and return the resulting tensor.

        Args:
            data: input tensor of the module.
            K: number of output filters of the module (inside this method the
                name shadows the keras backend alias `K`).
            stride: strides of the 3x3 convolution that may reduce the spatial
                size; also used for the average pooling on the shortcut.
            chan_dim: axis passed to the BatchNormalization layers.
            red: project the shortcut with avg-pool => 1x1 conv => BN. In the
                non-bottleneck variant the projection is only applied when
                `stride` differs from `(1, 1)`.
            reg: L2 factor applied to every convolution kernel.
            bn_eps: epsilon of the main-branch BatchNormalization layers.
            bn_mom: momentum of the main-branch BatchNormalization layers.
            bottleneck: use the 1x1 => 3x3 => 1x1 variant instead of two 3x3 convolutions.
            name: prefix of every layer name created by the module.

        Returns:
            The tensor produced by adding the main branch and the shortcut.
        """
        # shortcut branch
        shortcut = data

        if bottleneck:
            # first bottleneck block - 1x1
            bn1 = BatchNormalization(axis = chan_dim, epsilon = bn_eps, momentum = bn_mom, name = name + "_bn1")(data)
            act1 = Activation("relu", name = name + "_relu1")(bn1)
            conv1 = Conv2D(int(K * 0.25), (1, 1), use_bias = False, kernel_regularizer = l2(reg),
                           kernel_initializer = "he_normal", name = name + "_conv1")(act1)

            # conv block - 3x3
            bn2 = BatchNormalization(axis = chan_dim, epsilon = bn_eps, momentum = bn_mom, name = name + "_bn2")(conv1)
            act2 = Activation("relu", name = name + "_relu2")(bn2)
            conv2 = Conv2D(int(K * 0.25), (3, 3), strides = stride, padding = "same", use_bias = False,
                           kernel_initializer = "he_normal", kernel_regularizer = l2(reg), name = name + "_conv2")(act2)

            # second bottleneck block - 1x1
            bn3 = BatchNormalization(axis = chan_dim, epsilon = bn_eps, momentum = bn_mom, name = name + "_bn3")(conv2)
            act3 = Activation("relu", name = name + "_relu3")(bn3)
            conv3 = Conv2D(K, (1, 1), use_bias = False, kernel_regularizer = l2(reg),
                           kernel_initializer = "he_normal", name = name + "_conv3")(act3)

            # if dimensions are to be reduced, project the shortcut
            if red:
                shortcut = AveragePooling2D(pool_size = (2, 2), strides = stride, padding = "same",
                                            name = name + "_avg_pool")(act1)
                shortcut = Conv2D(K, (1, 1), strides = (1, 1), use_bias = False, kernel_initializer = "he_normal",
                                  kernel_regularizer = l2(reg), name = name + "_red")(shortcut)
                # normalize over the channel axis (the default axis is only right for channels_last)
                shortcut = BatchNormalization(axis = chan_dim, name = name + "_red_bn")(shortcut)

            # add the shortcut and final conv
            x = add([conv3, shortcut], name = name + "_add")

        else:
            # conv block 1 - 3x3
            bn1 = BatchNormalization(axis = chan_dim, epsilon = bn_eps, momentum = bn_mom, name = name + "_bn1")(data)
            act1 = Activation("relu", name = name + "_relu1")(bn1)
            conv1 = Conv2D(K, (3, 3), strides = stride, padding = "same", use_bias = False,
                           kernel_initializer = "he_normal", kernel_regularizer = l2(reg), name = name + "_conv1")(act1)

            # conv block 2 - 3x3
            bn2 = BatchNormalization(axis = chan_dim, epsilon = bn_eps, momentum = bn_mom, name = name + "_bn2")(conv1)
            act2 = Activation("relu", name = name + "_relu2")(bn2)
            conv2 = Conv2D(K, (3, 3), padding = "same", use_bias = False,
                           kernel_initializer = "he_normal", kernel_regularizer = l2(reg), name = name + "_conv2")(act2)

            # if dimensions are to be reduced, project the shortcut
            if red and stride != (1, 1):
                shortcut = AveragePooling2D(pool_size = (2, 2), strides = stride, padding = "same",
                                            name = name + "_avg_pool")(act1)
                shortcut = Conv2D(K, (1, 1), strides = (1, 1), use_bias = False, kernel_initializer = "he_normal",
                                  kernel_regularizer = l2(reg), name = name + "_red")(shortcut)
                # normalize over the channel axis (the default axis is only right for channels_last)
                shortcut = BatchNormalization(axis = chan_dim, name = name + "_red_bn")(shortcut)

            # add the shortcut and final conv
            x = add([conv2, shortcut], name = name + "_add")

        # return the addition as the output of the residual block
        return x

    @staticmethod
    def build(height, width, depth, classes, stages, filters, stem_type = "imagenet", bottleneck = True, reg = 1e-4, bn_eps = 2e-5, bn_mom = 0.9):
        """
        Assemble the network and return it as a keras `Model`.

        Args:
            height, width, depth: size of the input images.
            classes: number of units of the softmax classifier.
            stages: per-stage count of convolution layers; each entry is
                floor-divided by the convolutions per residual module (3 with
                `bottleneck`, otherwise 2) to get the number of modules.
            filters: `filters[0]` is used by the stem, `filters[i + 1]` by stage `i`.
            stem_type: "imagenet" (three 3x3 convs, the first strided, plus
                max pooling) or "cifar" (a single 3x3 conv).
            bottleneck: forwarded to `residual_module`.
            reg, bn_eps, bn_mom: forwarded to the layers / `residual_module`.

        Returns:
            Model named `XResNet<n>`, where `n` is the layer count tracked
            while building (stem + residual convolutions + classifier).

        Raises:
            ValueError: if `stem_type` is neither "imagenet" nor "cifar".
        """
        # fail early with a clear message instead of a NameError further down
        if stem_type not in ("imagenet", "cifar"):
            raise ValueError(f"Unsupported stem_type {stem_type!r}; expected 'imagenet' or 'cifar'")

        # set the input shape
        if K.image_data_format() == "channels_last":
            input_shape = (height, width, depth)
            chan_dim = -1
        else:
            input_shape = (depth, height, width)
            chan_dim = 1

        # initialize a counter to keep count of the total number of layers in the model
        n_layers = 0

        # input block
        inputs = Input(shape = input_shape)

        # stem (strings are compared by value, not by identity)
        if stem_type == "imagenet":
            x = Conv2D(filters[0], (3, 3), strides = (2, 2), use_bias = False, padding = "same",
                       kernel_initializer = "he_normal", kernel_regularizer = l2(reg), name = "stem_conv1")(inputs)
            x = Conv2D(filters[0], (3, 3), strides = (1, 1), use_bias = False, padding = "same",
                       kernel_initializer = "he_normal", kernel_regularizer = l2(reg), name = "stem_conv2")(x)
            x = Conv2D(filters[0], (3, 3), strides = (1, 1), use_bias = False, padding = "same",
                       kernel_initializer = "he_normal", kernel_regularizer = l2(reg), name = "stem_conv3")(x)
            x = MaxPooling2D(pool_size = (3, 3), strides = (2, 2), padding = "same", name = "stem_max_pool")(x)
        else:
            x = Conv2D(filters[0], (3, 3), use_bias = False, padding = "same", kernel_initializer = "he_normal",
                       kernel_regularizer = l2(reg), name = "stem_conv")(inputs)

        # the stem is counted as a single layer
        n_layers += 1

        # convert the per-stage layer counts into per-stage module counts
        convs_per_module = 3 if bottleneck else 2
        stages = [int(np.floor(st / convs_per_module)) for st in stages]

        # loop through the stages
        for i in range(0, len(stages)):
            # only the first stage keeps the spatial size
            stride = (1, 1) if i == 0 else (2, 2)

            # the first module of a stage may reduce the dimensions
            name = f"stage{i + 1}_res_block1"
            x = XResNet.residual_module(x, filters[i + 1], stride, chan_dim, reg = reg, red = True,
                                        bn_eps = bn_eps, bn_mom = bn_mom, bottleneck = bottleneck, name = name)

            # remaining modules of the stage
            for j in range(0, stages[i] - 1):
                name = f"stage{i + 1}_res_block{j + 2}"
                x = XResNet.residual_module(x, filters[i + 1], (1, 1), chan_dim, reg = reg,
                                            bn_eps = bn_eps, bn_mom = bn_mom, bottleneck = bottleneck, name = name)

            # increment the number of layers
            n_layers += convs_per_module * stages[i]

        # BN => RELU => POOL (average and max pooling are concatenated)
        x = BatchNormalization(axis = chan_dim, epsilon = bn_eps, momentum = bn_mom, name = "final_bn")(x)
        x = Activation("relu", name = "final_relu")(x)
        x1 = GlobalAveragePooling2D(name = "global_avg_pooling")(x)
        x2 = GlobalMaxPooling2D(name = "global_max_pooling")(x)
        x = concatenate([x1, x2], axis = -1, name = "concatenate")

        # softmax classifier
        sc = Dense(classes, kernel_initializer = "he_normal", kernel_regularizer = l2(reg), name = "classifier")(x)
        sc = Activation("softmax", name = "softmax")(sc)

        # the classifier is counted as a single layer
        n_layers += 1

        print(f"[INFO] {__class__.__name__}{n_layers} built successfully!")

        # return the constructed network architecture
        return Model(inputs = inputs, outputs = sc, name = f"{__class__.__name__}{n_layers}")

In [None]:
class MXResNet:
    """Builder for pre-activation (BN => Mish => Conv) residual networks using the Keras functional API."""

    @staticmethod
    def residual_module(data, K, stride, chan_dim, red = False, reg = 1e-4, bn_eps = 2e-5, bn_mom = 0.9,
                        bottleneck = True, name = "res_block"):
        """
        Append one residual module to `data` and return the resulting tensor.

        Args:
            data: input tensor of the module.
            K: number of output filters of the module (inside this method the
                name shadows the keras backend alias `K`).
            stride: strides of the 3x3 convolution that may reduce the spatial
                size; also used for the average pooling on the shortcut.
            chan_dim: axis passed to the BatchNormalization layers.
            red: project the shortcut with avg-pool => 1x1 conv => BN. In the
                non-bottleneck variant the projection is only applied when
                `stride` differs from `(1, 1)`.
            reg: L2 factor applied to every convolution kernel.
            bn_eps: epsilon of the main-branch BatchNormalization layers.
            bn_mom: momentum of the main-branch BatchNormalization layers.
            bottleneck: use the 1x1 => 3x3 => 1x1 variant instead of two 3x3 convolutions.
            name: prefix of every layer name created by the module.

        Returns:
            The tensor produced by adding the main branch and the shortcut.
        """
        # shortcut branch
        shortcut = data

        if bottleneck:
            # first bottleneck block - 1x1
            bn1 = BatchNormalization(axis = chan_dim, epsilon = bn_eps, momentum = bn_mom, name = name + "_bn1")(data)
            act1 = Mish(name = name + "_mish1")(bn1)
            conv1 = Conv2D(int(K * 0.25), (1, 1), use_bias = False, kernel_regularizer = l2(reg),
                           kernel_initializer = "he_normal", name = name + "_conv1")(act1)

            # conv block - 3x3
            bn2 = BatchNormalization(axis = chan_dim, epsilon = bn_eps, momentum = bn_mom, name = name + "_bn2")(conv1)
            act2 = Mish(name = name + "_mish2")(bn2)
            conv2 = Conv2D(int(K * 0.25), (3, 3), strides = stride, padding = "same", use_bias = False,
                           kernel_initializer = "he_normal", kernel_regularizer = l2(reg), name = name + "_conv2")(act2)

            # second bottleneck block - 1x1
            bn3 = BatchNormalization(axis = chan_dim, epsilon = bn_eps, momentum = bn_mom, name = name + "_bn3")(conv2)
            act3 = Mish(name = name + "_mish3")(bn3)
            conv3 = Conv2D(K, (1, 1), use_bias = False, kernel_regularizer = l2(reg),
                           kernel_initializer = "he_normal", name = name + "_conv3")(act3)

            # if dimensions are to be reduced, project the shortcut
            if red:
                shortcut = AveragePooling2D(pool_size = (2, 2), strides = stride,
                                            padding = "same", name = name + "_avg_pool")(act1)
                shortcut = Conv2D(K, (1, 1), strides = (1, 1), use_bias = False, kernel_initializer = "he_normal",
                                  kernel_regularizer = l2(reg), name = name + "_red")(shortcut)
                # normalize over the channel axis (the default axis is only right for channels_last)
                shortcut = BatchNormalization(axis = chan_dim, name = name + "_red_bn")(shortcut)

            # add the shortcut and final conv
            x = add([conv3, shortcut], name = name + "_add")

        else:
            # conv block 1 - 3x3
            bn1 = BatchNormalization(axis = chan_dim, epsilon = bn_eps, momentum = bn_mom, name = name + "_bn1")(data)
            act1 = Mish(name = name + "_mish1")(bn1)
            conv1 = Conv2D(K, (3, 3), strides = stride, padding = "same", use_bias = False,
                           kernel_initializer = "he_normal", kernel_regularizer = l2(reg), name = name + "_conv1")(act1)

            # conv block 2 - 3x3
            bn2 = BatchNormalization(axis = chan_dim, epsilon = bn_eps, momentum = bn_mom, name = name + "_bn2")(conv1)
            act2 = Mish(name = name + "_mish2")(bn2)
            conv2 = Conv2D(K, (3, 3), padding = "same", use_bias = False,
                           kernel_initializer = "he_normal", kernel_regularizer = l2(reg), name = name + "_conv2")(act2)

            # if dimensions are to be reduced, project the shortcut
            if red and stride != (1, 1):
                shortcut = AveragePooling2D(pool_size = (2, 2), strides = stride,
                                            padding = "same", name = name + "_avg_pool")(act1)
                shortcut = Conv2D(K, (1, 1), strides = (1, 1), use_bias = False, kernel_initializer = "he_normal",
                                  kernel_regularizer = l2(reg), name = name + "_red")(shortcut)
                # normalize over the channel axis (the default axis is only right for channels_last)
                shortcut = BatchNormalization(axis = chan_dim, name = name + "_red_bn")(shortcut)

            # add the shortcut and final conv
            x = add([conv2, shortcut], name = name + "_add")

        # return the addition as the output of the residual block
        return x

    @staticmethod
    def build(height, width, depth, classes, stages, filters, stem_type = "imagenet", bottleneck = True,
              reg = 1e-4, bn_eps = 2e-5, bn_mom = 0.9):
        """
        Assemble the network and return it as a keras `Model`.

        Args:
            height, width, depth: size of the input images.
            classes: number of units of the softmax classifier.
            stages: number of residual modules per stage.
                NOTE(review): unlike `XResNet.build`, the entries are not
                divided by the convolutions per module, so the same `stages`
                list yields a deeper network here - confirm this is intended.
            filters: `filters[0]` is used by the stem, `filters[i + 1]` by stage `i`.
            stem_type: "imagenet" (three 3x3 convs, the first strided, plus
                max pooling) or "cifar" (a single 3x3 conv).
            bottleneck: forwarded to `residual_module`.
            reg, bn_eps, bn_mom: forwarded to the layers / `residual_module`.

        Returns:
            Model named `MXResNet<n>`, where `n` is the layer count tracked
            while building (stem + residual convolutions + classifier).

        Raises:
            ValueError: if `stem_type` is neither "imagenet" nor "cifar".
        """
        # fail early with a clear message instead of a NameError further down
        if stem_type not in ("imagenet", "cifar"):
            raise ValueError(f"Unsupported stem_type {stem_type!r}; expected 'imagenet' or 'cifar'")

        # set the input shape
        if K.image_data_format() == "channels_last":
            input_shape = (height, width, depth)
            chan_dim = -1
        else:
            input_shape = (depth, height, width)
            chan_dim = 1

        # initialize a counter to keep count of the total number of layers in the model
        n_layers = 0

        # input block
        inputs = Input(shape = input_shape)

        # stem
        if stem_type == "imagenet":
            x = Conv2D(filters[0], (3, 3), strides = (2, 2), use_bias = False, padding = "same",
                       kernel_initializer = "he_normal", kernel_regularizer = l2(reg), name = "stem_conv1")(inputs)
            x = Conv2D(filters[0], (3, 3), strides = (1, 1), use_bias = False, padding = "same",
                       kernel_initializer = "he_normal", kernel_regularizer = l2(reg), name = "stem_conv2")(x)
            x = Conv2D(filters[0], (3, 3), strides = (1, 1), use_bias = False, padding = "same",
                       kernel_initializer = "he_normal", kernel_regularizer = l2(reg), name = "stem_conv3")(x)
            x = MaxPooling2D(pool_size = (3, 3), strides = (2, 2), padding = "same", name = "stem_max_pool")(x)
        else:
            x = Conv2D(filters[0], (3, 3), use_bias = False, padding = "same", kernel_initializer = "he_normal",
                       kernel_regularizer = l2(reg), name = "stem_conv")(inputs)

        # the stem is counted as a single layer
        n_layers += 1

        # number of convolutions contributed by each residual module
        convs_per_module = 3 if bottleneck else 2

        # loop through the stages
        for i in range(0, len(stages)):
            # only the first stage keeps the spatial size
            stride = (1, 1) if i == 0 else (2, 2)

            # the first module of a stage may reduce the dimensions
            name = f"stage{i + 1}_res_block1"
            x = MXResNet.residual_module(x, filters[i + 1], stride, chan_dim, reg = reg, red = True,
                                         bn_eps = bn_eps, bn_mom = bn_mom, bottleneck = bottleneck, name = name)

            # remaining modules of the stage
            for j in range(0, stages[i] - 1):
                name = f"stage{i + 1}_res_block{j + 2}"
                x = MXResNet.residual_module(x, filters[i + 1], (1, 1), chan_dim, reg = reg,
                                             bn_eps = bn_eps, bn_mom = bn_mom, bottleneck = bottleneck, name = name)

            # increment the number of layers
            n_layers += convs_per_module * stages[i]

        # BN => MISH => POOL (average and max pooling are concatenated)
        x = BatchNormalization(axis = chan_dim, epsilon = bn_eps, momentum = bn_mom, name = "final_bn")(x)
        x = Mish(name = "final_mish")(x)
        x1 = GlobalAveragePooling2D(name = "global_avg_pooling")(x)
        x2 = GlobalMaxPooling2D(name = "global_max_pooling")(x)
        x = concatenate([x1, x2], axis = -1, name = "concatenate")

        # softmax classifier
        sc = Dense(classes, kernel_initializer = "he_normal", kernel_regularizer = l2(reg), name = "classifier")(x)
        sc = Activation("softmax", name = "softmax")(sc)

        # the classifier is counted as a single layer
        n_layers += 1

        print(f"[INFO] {__class__.__name__}{n_layers} built successfully!")

        # return the constructed network architecture
        return Model(inputs = inputs, outputs = sc, name = f"{__class__.__name__}{n_layers}")

In [None]:
class SEMXResNet:
    """Builder for pre-activation (BN => Mish => Conv) residual networks with squeeze-and-excite blocks."""

    @staticmethod
    def squeeze_excite_block(tensor, ratio = 16, name = "se_block"):
        """
        Rescale the channels of `tensor` with a squeeze-and-excite gate.

        Args:
            tensor: input feature map.
            ratio: the squeeze Dense layer has `filters // ratio` units.
            name: prefix of the layer names created by the block.

        Returns:
            `tensor` multiplied by the sigmoid gate computed from its
            globally average-pooled channels.
        """
        init = tensor
        channel_axis = 1 if K.image_data_format() == "channels_first" else -1
        filters = init.shape[channel_axis]
        se_shape = (1, 1, filters)

        # squeeze: global average => Dense(filters // ratio) => ReLU
        se = GlobalAveragePooling2D(name = name + "_gap")(init)
        se = Reshape(se_shape, name = name + "_reshape")(se)
        se = Dense(filters // ratio, kernel_initializer = 'he_normal', use_bias = False, name = name + "_squeeze")(se)
        se = Activation("relu", name = name + "_squeeze_relu")(se)

        # excite: Dense(filters) => sigmoid
        se = Dense(filters, kernel_initializer = 'he_normal', use_bias = False, name = name + "_excite")(se)
        se = Activation("sigmoid", name = name + "_excite_sigmoid")(se)

        # move the channel entries in front for channels_first tensors
        if K.image_data_format() == 'channels_first':
            se = Permute((3, 1, 2))(se)

        x = multiply([init, se], name = name + "_scale")
        return x

    @staticmethod
    def residual_module(data, K, stride, chan_dim, red = False, reg = 1e-4, bn_eps = 2e-5, bn_mom = 0.9,
                        bottleneck = True, name = "res_block"):
        """
        Append one residual module (with a squeeze-and-excite block on the main branch) to `data`.

        Args:
            data: input tensor of the module.
            K: number of output filters of the module (inside this method the
                name shadows the keras backend alias `K`).
            stride: strides of the 3x3 convolution that may reduce the spatial
                size; also used for the average pooling on the shortcut.
            chan_dim: axis passed to the BatchNormalization layers.
            red: project the shortcut with avg-pool => 1x1 conv => BN. In the
                non-bottleneck variant the projection is only applied when
                `stride` differs from `(1, 1)`.
            reg: L2 factor applied to every convolution kernel.
            bn_eps: epsilon of the main-branch BatchNormalization layers.
            bn_mom: momentum of the main-branch BatchNormalization layers.
            bottleneck: use the 1x1 => 3x3 => 1x1 variant instead of two 3x3 convolutions.
            name: prefix of every layer name created by the module.

        Returns:
            The tensor produced by adding the main branch and the shortcut.
        """
        # shortcut branch
        shortcut = data

        if bottleneck:
            # first bottleneck block - 1x1
            bn1 = BatchNormalization(axis = chan_dim, epsilon = bn_eps, momentum = bn_mom, name = name + "_bn1")(data)
            act1 = Mish(name = name + "_mish1")(bn1)
            conv1 = Conv2D(int(K * 0.25), (1, 1), use_bias = False, kernel_regularizer = l2(reg),
                           kernel_initializer = "he_normal", name = name + "_conv1")(act1)

            # conv block - 3x3
            bn2 = BatchNormalization(axis = chan_dim, epsilon = bn_eps, momentum = bn_mom, name = name + "_bn2")(conv1)
            act2 = Mish(name = name + "_mish2")(bn2)
            conv2 = Conv2D(int(K * 0.25), (3, 3), strides = stride, padding = "same", use_bias = False,
                           kernel_initializer = "he_normal", kernel_regularizer = l2(reg), name = name + "_conv2")(act2)

            # second bottleneck block - 1x1
            bn3 = BatchNormalization(axis = chan_dim, epsilon = bn_eps, momentum = bn_mom, name = name + "_bn3")(conv2)
            act3 = Mish(name = name + "_mish3")(bn3)
            conv3 = Conv2D(K, (1, 1), use_bias = False, kernel_regularizer = l2(reg),
                           kernel_initializer = "he_normal", name = name + "_conv3")(act3)

            # se module
            conv3 = SEMXResNet.squeeze_excite_block(conv3, name = name + "_se_block")

            # if dimensions are to be reduced, project the shortcut
            if red:
                shortcut = AveragePooling2D(pool_size = (2, 2), strides = stride,
                                            padding = "same", name = name + "_avg_pool")(act1)
                shortcut = Conv2D(K, (1, 1), strides = (1, 1), use_bias = False, kernel_initializer = "he_normal",
                                  kernel_regularizer = l2(reg), name = name + "_red")(shortcut)
                # normalize over the channel axis (the default axis is only right for channels_last)
                shortcut = BatchNormalization(axis = chan_dim, name = name + "_red_bn")(shortcut)

            # add the shortcut and final conv
            x = add([conv3, shortcut], name = name + "_add")

        else:
            # conv block 1 - 3x3
            bn1 = BatchNormalization(axis = chan_dim, epsilon = bn_eps, momentum = bn_mom, name = name + "_bn1")(data)
            act1 = Mish(name = name + "_mish1")(bn1)
            conv1 = Conv2D(K, (3, 3), strides = stride, padding = "same", use_bias = False,
                           kernel_initializer = "he_normal", kernel_regularizer = l2(reg), name = name + "_conv1")(act1)

            # conv block 2 - 3x3
            bn2 = BatchNormalization(axis = chan_dim, epsilon = bn_eps, momentum = bn_mom, name = name + "_bn2")(conv1)
            act2 = Mish(name = name + "_mish2")(bn2)
            conv2 = Conv2D(K, (3, 3), padding = "same", use_bias = False,
                           kernel_initializer = "he_normal", kernel_regularizer = l2(reg), name = name + "_conv2")(act2)

            # se module
            conv2 = SEMXResNet.squeeze_excite_block(conv2, name = name + "_se_block")

            # if dimensions are to be reduced, project the shortcut
            if red and stride != (1, 1):
                shortcut = AveragePooling2D(pool_size = (2, 2), strides = stride,
                                            padding = "same", name = name + "_avg_pool")(act1)
                shortcut = Conv2D(K, (1, 1), strides = (1, 1), use_bias = False, kernel_initializer = "he_normal",
                                  kernel_regularizer = l2(reg), name = name + "_red")(shortcut)
                # normalize over the channel axis (the default axis is only right for channels_last)
                shortcut = BatchNormalization(axis = chan_dim, name = name + "_red_bn")(shortcut)

            # add the shortcut and final conv
            x = add([conv2, shortcut], name = name + "_add")

        # return the addition as the output of the residual block
        return x

    @staticmethod
    def build(height, width, depth, classes, stages, filters, stem_type = "imagenet", bottleneck = True,
              reg = 1e-4, bn_eps = 2e-5, bn_mom = 0.9):
        """
        Assemble the network and return it as a keras `Model`.

        Args:
            height, width, depth: size of the input images.
            classes: number of units of the softmax classifier.
            stages: number of residual modules per stage.
                NOTE(review): unlike `XResNet.build`, the entries are not
                divided by the convolutions per module, so the same `stages`
                list yields a deeper network here - confirm this is intended.
            filters: `filters[0]` is used by the stem, `filters[i + 1]` by stage `i`.
            stem_type: "imagenet" (three 3x3 convs, the first strided, plus
                max pooling) or "cifar" (a single 3x3 conv).
            bottleneck: forwarded to `residual_module`.
            reg, bn_eps, bn_mom: forwarded to the layers / `residual_module`.

        Returns:
            Model named `SEMXResNet<n>`, where `n` is the layer count tracked
            while building (stem + residual convolutions + classifier).

        Raises:
            ValueError: if `stem_type` is neither "imagenet" nor "cifar".
        """
        # fail early with a clear message instead of a NameError further down
        if stem_type not in ("imagenet", "cifar"):
            raise ValueError(f"Unsupported stem_type {stem_type!r}; expected 'imagenet' or 'cifar'")

        # set the input shape
        if K.image_data_format() == "channels_last":
            input_shape = (height, width, depth)
            chan_dim = -1
        else:
            input_shape = (depth, height, width)
            chan_dim = 1

        # initialize a counter to keep count of the total number of layers in the model
        n_layers = 0

        # input block
        inputs = Input(shape = input_shape)

        # stem
        if stem_type == "imagenet":
            x = Conv2D(filters[0], (3, 3), strides = (2, 2), use_bias = False, padding = "same",
                       kernel_initializer = "he_normal", kernel_regularizer = l2(reg), name = "stem_conv1")(inputs)
            x = Conv2D(filters[0], (3, 3), strides = (1, 1), use_bias = False, padding = "same",
                       kernel_initializer = "he_normal", kernel_regularizer = l2(reg), name = "stem_conv2")(x)
            x = Conv2D(filters[0], (3, 3), strides = (1, 1), use_bias = False, padding = "same",
                       kernel_initializer = "he_normal", kernel_regularizer = l2(reg), name = "stem_conv3")(x)
            x = MaxPooling2D(pool_size = (3, 3), strides = (2, 2), padding = "same", name = "stem_max_pool")(x)
        else:
            x = Conv2D(filters[0], (3, 3), use_bias = False, padding = "same", kernel_initializer = "he_normal",
                       kernel_regularizer = l2(reg), name = "stem_conv")(inputs)

        # the stem is counted as a single layer
        n_layers += 1

        # number of convolutions contributed by each residual module
        convs_per_module = 3 if bottleneck else 2

        # loop through the stages
        for i in range(0, len(stages)):
            # only the first stage keeps the spatial size
            stride = (1, 1) if i == 0 else (2, 2)

            # the first module of a stage may reduce the dimensions
            name = f"stage{i + 1}_res_block1"
            x = SEMXResNet.residual_module(x, filters[i + 1], stride, chan_dim, reg = reg, red = True,
                                           bn_eps = bn_eps, bn_mom = bn_mom, bottleneck = bottleneck, name = name)

            # remaining modules of the stage
            for j in range(0, stages[i] - 1):
                name = f"stage{i + 1}_res_block{j + 2}"
                x = SEMXResNet.residual_module(x, filters[i + 1], (1, 1), chan_dim, reg = reg,
                                               bn_eps = bn_eps, bn_mom = bn_mom, bottleneck = bottleneck, name = name)

            # increment the number of layers
            n_layers += convs_per_module * stages[i]

        # BN => MISH => POOL (average and max pooling are concatenated)
        x = BatchNormalization(axis = chan_dim, epsilon = bn_eps, momentum = bn_mom, name = "final_bn")(x)
        x = Mish(name = "final_mish")(x)
        x1 = GlobalAveragePooling2D(name = "global_avg_pooling")(x)
        x2 = GlobalMaxPooling2D(name = "global_max_pooling")(x)
        x = concatenate([x1, x2], axis = -1, name = "concatenate")

        # softmax classifier
        sc = Dense(classes, kernel_initializer = "he_normal", kernel_regularizer = l2(reg), name = "classifier")(x)
        sc = Activation("softmax", name = "softmax")(sc)

        # the classifier is counted as a single layer
        n_layers += 1

        print(f"[INFO] {__class__.__name__}{n_layers} built successfully!")

        # return the constructed network architecture
        return Model(inputs = inputs, outputs = sc, name = f"{__class__.__name__}{n_layers}")

In [None]:
# registry mapping a model key to an already-built network
MODELS = {}

# 32x32x3 inputs, 10 classes, cifar stem, plain (non-bottleneck) residual modules
MODELS["xresnet20"] = XResNet.build(
    height = 32, width = 32, depth = 3, classes = 10,
    stages = [6, 6, 6], filters = [16, 16, 32, 64],
    stem_type = "cifar", bottleneck = False,
)

# other variants, currently disabled:
# MODELS["mxresnet20"] = MXResNet.build(32, 32, 3, 10, [6, 6, 6], [16, 16, 32, 64], stem_type = "cifar", bottleneck = False)
# MODELS["se-mxresnet20"] = SEMXResNet.build(32, 32, 3, 10, [6, 6, 6], [16, 16, 32, 64], stem_type = "cifar", bottleneck = False, reg = 1e-4)
# MODELS["xresnet44"] = XResNet.build(32, 32, 3, 10, [14, 14, 14], [16, 16, 32, 64], stem_type = "cifar", bottleneck = False)
# MODELS["xresnet56"] = XResNet.build(32, 32, 3, 10, [18, 18, 18], [16, 16, 32, 64], stem_type = "cifar", bottleneck = False)

# CALLBACKS #

In [None]:
class TrainingMonitor(BaseLogger):
    '''
    Callback that accumulates the per-epoch logs, optionally mirrors them to
    a JSON file and redraws a loss/accuracy figure after every epoch.

    Arguments:
        - fig_path: file the figure is saved to (overwritten each epoch).
        - json_path: optional JSON file used to persist the history; when the
        file already exists its contents seed the history at the start of
        training.
        - start_at: when greater than zero, every series loaded from the JSON
        file is truncated to its first `start_at` entries so that training
        can be resumed from that epoch.
    '''

    # (key in the history, legend label) for each curve drawn on the figure
    _CURVES = (
        ("loss", "train_loss"),
        ("val_loss", "val_loss"),
        ("accuracy", "acc"),
        ("val_accuracy", "val_acc"),
    )

    def __init__(self, fig_path, json_path = None, start_at = 0):
        # store the output path for the figure, the path to the JSON serialized file, and the starting epoch
        super().__init__()
        self.fig_path = fig_path
        self.json_path = json_path
        self.start_at = start_at

    def on_train_begin(self, logs = None):
        '''
        Reset the history and, if a JSON history file exists, reload it
        (trimmed to `start_at` entries when a starting epoch was supplied).
        '''
        # initialize the history dictionary
        self.H = {}

        # nothing to restore without an existing JSON history file
        if self.json_path is None or not os.path.exists(self.json_path):
            return

        # load the previous history; the context manager closes the file
        with open(self.json_path) as f:
            self.H = json.load(f)

        # drop any entries recorded past the epoch training restarts from
        if self.start_at > 0:
            self.H = {key: values[:self.start_at] for (key, values) in self.H.items()}

    def on_epoch_end(self, epoch, logs = None):
        '''
        Append this epoch's logs to the history, write the history to the
        JSON file (if configured) and redraw the figure once at least two
        loss values have been recorded.
        '''
        # update the loss, accuracy etc. for the entire training process; the
        # cast to a built-in float keeps the history JSON serializable, since
        # NumPy scalars such as float32 are rejected by the json module
        for (k, v) in (logs or {}).items():
            self.H.setdefault(k, []).append(float(v))

        # check to see if the training history should be serialized to file
        if self.json_path is not None:
            with open(self.json_path, "w") as f:
                f.write(json.dumps(self.H))

        # ensure at least two epochs have passed before plotting
        losses = self.H.get("loss", [])
        if len(losses) > 1:
            # select the style before the figure is created so that it
            # applies to the whole figure
            plt.style.use("ggplot")
            plt.figure()

            # draw only the curves that were actually logged (for example,
            # the validation curves are absent without validation data)
            for (key, label) in self._CURVES:
                if key in self.H:
                    values = self.H[key]
                    plt.plot(np.arange(0, len(values)), values, label = label)

            plt.title("Training Loss [Epoch {}]".format(len(losses)))
            plt.xlabel("Epoch")
            plt.ylabel("Loss")
            plt.legend()

            # save the figure
            plt.savefig(self.fig_path)
            plt.close()

# TRAINING #

In [None]:
# fetch CIFAR-10 as (train, test) pairs of image and label arrays
(x_train, y_train), (x_test, y_test) = cifar10.load_data()

# carve a stratified 10% validation subset out of the training data; the
# fixed seed keeps the partition identical between runs
x_train, x_val, y_train, y_val = train_test_split(
    x_train,
    y_train,
    test_size = 0.1,
    random_state = 42,
    stratify = y_train,
)

In [None]:
# define training constants
epochs = 180
bs = 128

# number of batches needed to cover each split once; cast to int because
# np.ceil returns a float while Keras documents step counts as integers
steps_per_epoch = int(np.ceil(x_train.shape[0] / bs))
validation_steps = int(np.ceil(x_val.shape[0] / bs))
test_steps = int(np.ceil(x_test.shape[0] / bs))

# key into MODELS and prefix of the checkpoint, plot and history file names
model_name = "xresnet20"

# base learning rate used by the optimizer and the step schedule
init_lr = 0.1

In [None]:
# build the individual preprocessing steps used by the data generators
# rpp = ReflectionPadPreprocessor(4)
pp = PadPreprocessor(4)
fp = FlipPreprocessor(0.5)
patchp = PatchPreprocessor(32, 32)
# NOTE(review): the two lists are presumably per-channel mean and std -- confirm
mp = MeanPreprocessor(
    [0.4914, 0.4822, 0.4465],
    [0.247, 0.2435, 0.2616],
)
iap = ImageToArrayPreprocessor()

# training batches go through the full chain: pad, flip, patch, mean, array
train_datagen = CifarGenerator(
    x_train, y_train, bs,
    preprocessors = [pp, fp, patchp, mp, iap],
).generator()

# validation batches only get the mean and array steps
val_datagen = CifarGenerator(
    x_val, y_val, bs,
    preprocessors = [mp, iap],
).generator()

In [None]:
# step decay learning rate scheduler
def lr_sched(epoch):
    '''
    Piecewise-constant learning rate derived from the global `init_lr`.

    Arguments:
        - epoch: epoch index the rate is requested for.
    Returns:
        - `init_lr / 10` for epochs below 1,
        - `init_lr` for epochs below 90,
        - `init_lr / 10` for epochs below 135,
        - `init_lr / 100` for every later epoch.
    '''
    # first epoch runs at a tenth of the base rate
    if epoch < 1:
        return init_lr / 10

    # main phase at the base rate
    if epoch < 90:
        return init_lr

    # first decay step
    if epoch < 135:
        return init_lr / 10

    # final decay step
    return init_lr / 100

# # cosine decay learning rate scheduler
# class CosineScheduler(Callback):
#     def __init__(self, max_lr, steps_per_epoch, tot_epochs, warmup = 5):
#         # parent class constructor
#         super(CosineScheduler, self).__init__()

#         # initialize the instance variables
#         self.max_lr = max_lr
#         self.warm_steps = steps_per_epoch * warmup
#         self.reg_steps = steps_per_epoch * (tot_epochs - warmup)
#         self.history = {"lrs" : []}
    
#     def on_train_begin(self, logs = None):
#         # initialize a counter to keep track of the number of batches seen
#         self.iterations = 0
    
#     def on_batch_begin(self, batch, logs = None):
#         # increment the number of iterations
#         self.iterations += 1

#         # calculate the learning rate
#         if self.iterations <= self.warm_steps:
#             lr = (self.iterations / self.warm_steps) * self.max_lr
#         else:
#             lr = (self.max_lr / 2.0) * (1 + np.cos(((self.iterations - self.warm_steps) / self.reg_steps) * np.pi))
        
#         # update the learning rate
#         K.set_value(self.model.optimizer.lr, lr)

#         # add the current learning rate to the history dictionary
#         self.history["lrs"].append(lr)

In [None]:
# checkpoint written under "models", with the zero-padded epoch number in
# the file name
mc = ModelCheckpoint(
    filepath = os.path.sep.join(["models", model_name + "_{epoch:03d}.h5"]),
)

# history plot and JSON log, both named after the model
tm = TrainingMonitor(
    fig_path = f"{model_name}.png",
    json_path = f"{model_name}.json",
)

# step decay schedule defined by lr_sched
lr = LearningRateScheduler(schedule = lr_sched)
# cs = CosineScheduler(init_lr, steps_per_epoch, epochs)

# callbacks handed to the training call
callbacks = [mc, tm, lr]

In [None]:
# pick the network registered under the configured name
model = MODELS[model_name]

# SGD with momentum; `learning_rate` is the supported argument name, while
# the `lr` alias is deprecated and rejected by newer Keras releases
opt = SGD(learning_rate = init_lr, momentum = 0.9)

# plain categorical cross-entropy (label smoothing variant kept for reference)
# loss = CategoricalCrossentropy(label_smoothing = 0.1)
loss = CategoricalCrossentropy()

# compile the model and track accuracy alongside the loss
model.compile(optimizer = opt, loss = loss, metrics = ["accuracy"])

In [None]:
# train the model; Model.fit accepts Python generators directly, whereas
# fit_generator is deprecated. The step counts are cast to int because Keras
# documents them as integers and the generators give no end-of-epoch signal
model.fit(train_datagen, steps_per_epoch = int(steps_per_epoch), epochs = epochs,
          validation_data = val_datagen, validation_steps = int(validation_steps),
          callbacks = callbacks)

# INFERENCE #

In [None]:
# load the checkpoint written after the final epoch; the epoch number is
# zero-padded to three digits to match the "_{epoch:03d}.h5" naming used
# by the ModelCheckpoint callback
test_model = load_model(f"models/{model_name}_{epochs:03d}.h5",
                        custom_objects = {"Downsample" : Downsample, "Mish" : Mish})

# initialize the data generator (same preprocessing as the validation data)
test_gen = CifarGenerator(x_test, y_test, bs, preprocessors = [mp, iap]).generator()

# evaluate the model; Model.evaluate accepts generators directly, whereas
# evaluate_generator is deprecated, and the step count must be an integer
H = test_model.evaluate(test_gen, steps = int(test_steps))
print(H)