### Necessary Imports 

In [23]:
import numpy as np
import tensorflow as tf
from tensorflow.keras.utils import to_categorical, plot_model
from tensorflow.keras.models import Model
from tensorflow.keras.datasets import cifar10
from tensorflow.keras import Input
from tensorflow.keras.layers import Conv2D, BatchNormalization, Activation
from tensorflow.keras.regularizers import l2
from tensorflow.keras.layers import Flatten, add
from tensorflow.keras.layers import AveragePooling2D, Dense
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ModelCheckpoint, LearningRateScheduler
from tensorflow.keras.callbacks import ReduceLROnPlateau
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import os
import math

### Model Parameters

In [52]:
batch_size = 32
epochs = 10
data_augmentation = True
num_classes = 10
n = 3
depth = n * 6 + 2
subtract_pixel_mean = True

### Loadind Data 

In [49]:
(X_train, y_train), (X_test, y_test) = cifar10.load_data()

X_train = X_train.reshape(50000, 32, 32, 3)
X_test = X_test.reshape(10000, 32, 32, 3)
X_train = X_train.astype('float32')
X_test = X_test.astype('float32')

if subtract_pixel_mean:
    X_train_mean = np.mean(X_train, axis=0)
    X_train -= X_train_mean
    X_test -= X_train_mean

num_classes = len(np.unique(y_train))
y_train = to_categorical(y_train)
y_test = to_categorical(y_test)

input_shape = X_train.shape[1:]

In [50]:
X_train.shape
#y_train.shape

(50000, 32, 32, 3)

### Model Architecture 

In [36]:
def resnet_layer(inputs,
                 kernel_size=3,
                 num_filters=16,
                 strides=1,
                 activation='relu',
                 batch_normalization=True,
                 conv_first=True):
    """2D Convolution-BatchNormalization-Activation Residual Block"""
    
    conv = Conv2D(num_filters,
                 kernel_size=kernel_size,
                 strides=strides,
                 padding='same',
                 kernel_initializer='he_normal',
                 kernel_regularizer=l2())
    x = inputs
    
    if conv_first:
        x = conv(x)
        if batch_normalization:
            x = BatchNormalization()(x)
        if activation is not None:
            x = Activation(activation)(x)
    
    else:
        if batch_normalization:
            x = BatchNormalization()(x)
        if activation is not None:
            x = Activation(activation)(x)
        x = conv(x)
    return x
    

def resnet_v1(input_shape, depth, num_classes):
    if (depth-2)%6 !=0:
        raise ValueError('Depth shoutld be 6n+2: eg-20, 32, 44 etc')
    
    # Model Definition
    num_filters = 16
    num_res_blocks = int((depth-2)/6)
    
    inputs = Input(shape=input_shape)
    x = resnet_layer(inputs=inputs)
    # Instance the stack of residual units
    for stack in range(3):
        for res_block in range(num_res_blocks):
            strides = 1
            if stack>0 and res_block==0:
                strides=2 # Downsampling
            y = resnet_layer(inputs=x,
                             num_filters=num_filters,
                             strides=strides)
            y = resnet_layer(inputs=y,
                             num_filters=num_filters,
                             activation=None)
            
            if stack>0 and res_block==0:
                # linear projection residual shortcut
                # connection to match changed dims
                x = resnet_layer(inputs=x,
                                 num_filters=num_filters,
                                 strides=strides,
                                 activation=None,
                                 batch_normalization=False)
            x = add([x, y])
            x =  Activation('relu')(x)
            
        num_filters *= 2
        
    x = AveragePooling2D(pool_size=8)(x)
    y = Flatten()(x)
    outputs = Dense(num_classes,
                    activation='softmax',
                    kernel_initializer='he_normal')(y)
    model = Model(inputs=inputs, outputs=outputs)
    return model
            

In [37]:
model = resnet_v1(input_shape=input_shape, depth = depth, num_classes = num_classes)
model.summary()

Model: "model_2"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_6 (InputLayer)            [(None, 32, 32, 3)]  0                                            
__________________________________________________________________________________________________
conv2d_66 (Conv2D)              (None, 32, 32, 16)   448         input_6[0][0]                    
__________________________________________________________________________________________________
batch_normalization_60 (BatchNo (None, 32, 32, 16)   64          conv2d_66[0][0]                  
__________________________________________________________________________________________________
activation_59 (Activation)      (None, 32, 32, 16)   0           batch_normalization_60[0][0]     
____________________________________________________________________________________________

### Model Compile

In [43]:
model.compile(loss='categorical_crossentropy',
              optimizer=Adam(lr=0.0001),
              metrics=['acc'])

#plot_model(model, to_file="resnet_v1.png", show_shapes=True)



### Model saving directory 

In [32]:
save_dir = os.path.join(os.getcwd(), 'saved_moddels')
model_name = 'cifar_10_resnet_v1_model.{epoch:03d}.h5'
if not os.path.isdir(save_dir):
    os.makedirs(save_dir)
filepath = os.path.join(save_dir, model_name)


### Training

In [53]:
checkpoint = ModelCheckpoint(filepath=filepath,
                             monitor='val_accuracy',
                             verbose=1,
                             save_best_only=True)

lr_reducer = ReduceLROnPlateau(factor=np.sqrt(0.1),
                               cooldown=0,
                               patience=5,
                               min_lr=0.5e-6)

callbacks = [checkpoint, lr_reducer]

if not data_augmentation:
    print('Not using data augmentation.')
    model.fit(X_train, y_train,
              batch_size=batch_size,
              epochs=epochs,
              validation_data=(X_test, y_test),
              shuffle=True,
              callbacks=callbacks)
else:
    print('Using real-time data augmentation.')
    # this will do preprocessing and realtime data augmentation:
    datagen = ImageDataGenerator(
        # set input mean to 0 over the dataset
        featurewise_center=False,
        # set each sample mean to 0
        samplewise_center=False,
        # divide inputs by std of dataset
        featurewise_std_normalization=False,
        # divide each input by its std
        samplewise_std_normalization=False,
        # apply ZCA whitening
        zca_whitening=False,
        # randomly rotate images in the range (deg 0 to 180)
        rotation_range=0,
        # randomly shift images horizontally
        width_shift_range=0.1,
        # randomly shift images vertically
        height_shift_range=0.1,
        # randomly flip images
        horizontal_flip=True,
        # randomly flip images
        vertical_flip=False)
    
  

    # compute quantities required for featurewise normalization
    # (std, mean, and principal components if ZCA whitening is applied).
    datagen.fit(X_train)

    steps_per_epoch =  math.ceil(len(X_train) / batch_size)
    # fit the model on the batches generated by datagen.flow().
    model.fit(x=datagen.flow(X_train, y_train, batch_size=batch_size),
              verbose=1,
              epochs=epochs,
              validation_data=(X_test, y_test),
              steps_per_epoch=steps_per_epoch,
              callbacks=callbacks)

Using real-time data augmentation.
  ...
    to  
  ['...']
Train for 1563 steps, validate on 10000 samples
Epoch 1/10


KeyboardInterrupt: 