### Paper: Deep Residual Learning for Image Recognition (He et al., 2015) — https://arxiv.org/pdf/1512.03385.pdf

In [None]:
import tensorflow as tf

In [None]:
# Number of examples per mini-batch; also used to derive the per-epoch step counts.
BATCH_SIZE = 64

In [None]:
# Load CIFAR-10 and rescale the pixel values by 1/255.
cifar10 = tf.keras.datasets.cifar10
(train_images, train_labels), (test_images, test_labels) = cifar10.load_data()
# Cast to float32 before scaling: dividing the integer image arrays directly by
# a Python float promotes them to float64, which doubles the memory footprint
# while Keras computes in float32 by default anyway.
train_images = train_images.astype("float32") / 255.0
test_images = test_images.astype("float32") / 255.0

# Hold out the last VAL_SIZE training examples as the validation split.
VAL_SIZE = 10000
val_images = train_images[-VAL_SIZE:]
val_labels = train_labels[-VAL_SIZE:]
train_images = train_images[:-VAL_SIZE]
train_labels = train_labels[:-VAL_SIZE]

# Example counts, used later to compute steps per epoch.
TRAIN_IMG_COUNT = train_labels.shape[0]
VAL_IMG_COUNT = val_labels.shape[0]
print(train_labels.shape)
print(val_labels.shape)

In [None]:
AUTOTUNE = tf.data.experimental.AUTOTUNE ## Sentinel that lets tf.data choose the prefetch buffer size automatically

def prepare_for_training(ds, cache=False, shuffle_buffer_size=1000, batch_size=None):
    """Turn a dataset of single examples into an endless stream of batches.

    The pipeline is: optional cache -> shuffle -> repeat -> batch -> prefetch.

    Args:
        ds: A `tf.data.Dataset` yielding individual (unbatched) examples.
        cache: Falsy to skip caching, a string to cache to that file path
            (for data that does not fit in memory), or any other truthy
            value to cache in memory.
        shuffle_buffer_size: Size of the buffer used by `shuffle`.
        batch_size: Number of examples per batch. When None (the default)
            the module-level `BATCH_SIZE` is used, so existing callers are
            unaffected.

    Returns:
        The transformed dataset. It repeats forever, so consumers must bound
        iteration themselves (e.g. with `steps_per_epoch` in `fit`).
    """
    if batch_size is None:
        batch_size = BATCH_SIZE

    if cache:
        if isinstance(cache, str):
            # Cache to a file on disk.
            ds = ds.cache(cache)
        else:
            # Cache in memory.
            ds = ds.cache()

    # Shuffle *after* caching so that the shuffling itself is not frozen into
    # the cache and the example order is not fixed once and for all.
    ds = ds.shuffle(buffer_size=shuffle_buffer_size)

    # Repeat forever so the dataset never runs out between epochs.
    ds = ds.repeat()

    ds = ds.batch(batch_size)

    # `prefetch` lets the dataset fetch batches in the background while the
    # model is training.
    ds = ds.prefetch(buffer_size=AUTOTUNE)

    return ds

In [None]:
# One-hot encode the integer class labels into 10 classes, as expected by the
# categorical cross-entropy loss the model is compiled with.
train_labels_onehot = tf.keras.utils.to_categorical(train_labels, num_classes=10)
val_labels_onehot = tf.keras.utils.to_categorical(val_labels, num_classes=10)
print(train_labels.shape)

# Pair every image with its one-hot label.
train_ds_tensor = tf.data.Dataset.from_tensor_slices(
    (train_images, train_labels_onehot)
)
val_ds_tensor = tf.data.Dataset.from_tensor_slices(
    (val_images, val_labels_onehot)
)

# Both splits go through the same shuffle / repeat / batch / prefetch pipeline.
train_ds = prepare_for_training(train_ds_tensor)
val_ds = prepare_for_training(val_ds_tensor)

# Sanity check: show the label shape of the first training batch.
for _image_batch, label_batch in train_ds.take(1):
    print(label_batch.shape)

<img src="https://d2l.ai/_images/resnet-block.svg"/>

In [None]:
# Residual block (identity shortcut)
def identity_block(x, filters, kernel_size, kernel_regularizer, kernel_initializer):
    """Bottleneck residual block whose shortcut is the unmodified input.

    Main path: 1x1 conv -> BN -> ReLU -> (kernel_size x kernel_size) conv with
    'same' padding -> BN -> ReLU -> 1x1 conv -> BN. The result is added to the
    input and passed through a final ReLU.

    Args:
        x: Input feature map. It is added back unchanged, so its shape must be
            compatible with the main-path output (its channel count has to
            agree with the third entry of `filters`).
        filters: Three filter counts, one per convolution, in order.
        kernel_size: Side length of the square kernel of the middle convolution.
        kernel_regularizer: Regularizer passed to every convolution kernel.
        kernel_initializer: Initializer passed to every convolution kernel.

    Returns:
        The output tensor of the block.
    """
    layers = tf.keras.layers
    relu = tf.keras.activations.relu
    reduce_filters, mid_filters, expand_filters = filters

    # Options shared by all three convolutions; everything else is left at the
    # Conv2D defaults (stride 1, 'valid' padding, bias enabled, ...).
    conv_kwargs = dict(
        kernel_initializer=kernel_initializer,
        kernel_regularizer=kernel_regularizer,
    )

    shortcut = x

    # 1x1 convolution on the way in.
    y = layers.Conv2D(reduce_filters, (1, 1), **conv_kwargs)(x)
    y = layers.BatchNormalization()(y)
    y = relu(y)

    # Spatial convolution; 'same' padding keeps height and width unchanged.
    y = layers.Conv2D(
        mid_filters, (kernel_size, kernel_size), padding='same', **conv_kwargs
    )(y)
    y = layers.BatchNormalization()(y)
    y = relu(y)

    # 1x1 convolution on the way out; no ReLU until after the addition.
    y = layers.Conv2D(expand_filters, (1, 1), **conv_kwargs)(y)
    y = layers.BatchNormalization()(y)

    y = layers.Add()([y, shortcut])
    return relu(y)

In [None]:
# Residual block (projection shortcut)
def conv_block(x, filters, kernel_size, kernel_regularizer, kernel_initializer, strides=(2,2)):
    """Bottleneck residual block whose shortcut is a strided 1x1 projection.

    Unlike an identity shortcut, the shortcut here is itself a 1x1 convolution
    with `filters[2]` output channels and the same `strides` as the first
    main-path convolution, so the two branches can be added even when the
    block changes the channel count and/or the spatial resolution.

    Main path: 1x1 conv (strided) -> BN -> ReLU -> (kernel_size x kernel_size)
    conv with 'same' padding -> BN -> ReLU -> 1x1 conv -> BN. The result is
    added to the projected shortcut and passed through a final ReLU.

    Args:
        x: Input feature map.
        filters: Three filter counts, one per main-path convolution, in order;
            the third one is also used for the shortcut projection.
        kernel_size: Side length of the square kernel of the middle convolution.
        kernel_regularizer: Regularizer passed to every convolution kernel.
        kernel_initializer: Initializer passed to every convolution kernel.
        strides: Strides of the shortcut projection and of the first
            main-path convolution.

    Returns:
        The output tensor of the block.
    """
    filter1, filter2, filter3 = filters # unpack the list of filters

    # Shortcut branch: project the input to the output shape of the main path.
    x_id = tf.keras.layers.Conv2D(
                                filter3, (1,1), strides=strides,
                                kernel_initializer=kernel_initializer,
                                kernel_regularizer=kernel_regularizer,
                            )(x)
    # Normalize the projection like every other convolution in the block, so
    # both branches entering the addition are batch-normalized (previously the
    # shortcut was the only convolution without BatchNormalization).
    x_id = tf.keras.layers.BatchNormalization()(x_id)

    # Main branch: strided 1x1 convolution on the way in.
    x = tf.keras.layers.Conv2D(
                                filter1, (1,1), strides=strides,
                                kernel_initializer=kernel_initializer,
                                kernel_regularizer=kernel_regularizer,
                            )(x)
    x = tf.keras.layers.BatchNormalization()(x)
    x = tf.keras.activations.relu(x)

    # Spatial convolution; 'same' padding keeps height and width unchanged.
    x = tf.keras.layers.Conv2D(
                                filter2, (kernel_size, kernel_size), padding='same',
                                kernel_initializer=kernel_initializer,
                                kernel_regularizer=kernel_regularizer
                            )(x)
    x = tf.keras.layers.BatchNormalization()(x)
    x = tf.keras.activations.relu(x)

    # 1x1 convolution on the way out; no ReLU until after the addition.
    x = tf.keras.layers.Conv2D(
                                filter3, (1, 1),
                                kernel_initializer=kernel_initializer,
                                kernel_regularizer=kernel_regularizer
                            )(x)
    x = tf.keras.layers.BatchNormalization()(x)

    x = tf.keras.layers.Add()([x, x_id])
    x = tf.keras.activations.relu(x)

    return x

In [None]:
# Smoke test of both block types: feed a random batch of one 32x32 feature map
# through each block and print the resulting shape. The identity block needs
# 256 input channels to match its last filter count; the conv block takes a
# 3-channel input.
for block_fn, in_channels in ((identity_block, 256), (conv_block, 3)):
    inp = tf.random.normal((1, 32, 32, in_channels))
    out = block_fn(
        inp,
        [64, 64, 256],
        3,
        tf.keras.regularizers.L2(),
        tf.keras.initializers.HeNormal(),
    )
    print(out.shape)

In [None]:
def resnet50(input_shape,
             kernel_regularizer=tf.keras.regularizers.L2,
             kernel_initializer=tf.keras.initializers.HeNormal,
             classes=10):
    """Build a ResNet-50-style classifier from `conv_block` / `identity_block`.

    Layout: a 7x7 stride-2 convolution stem with BN, ReLU and 3x3 stride-2 max
    pooling, followed by four stages of bottleneck blocks (3, 4, 6 and 3
    blocks), global average pooling and a softmax classification layer.
    Inputs are used at their native resolution; no resizing is applied.

    Args:
        input_shape: Shape of one input example, without the batch dimension.
        kernel_regularizer: Regularizer for every convolution kernel. Either
            an instance or a class; a class is instantiated with its default
            arguments.
        kernel_initializer: Initializer for every convolution kernel. Either
            an instance or a class; a class is instantiated with its default
            arguments.
        classes: Number of output classes.

    Returns:
        An uncompiled `tf.keras.Model` mapping inputs to class probabilities.
    """
    # The signature defaults are classes, not instances. Instantiate them here
    # so the layers always receive ready-to-use objects instead of relying on
    # how the installed Keras version treats a class passed as an identifier.
    if isinstance(kernel_regularizer, type):
        kernel_regularizer = kernel_regularizer()
    if isinstance(kernel_initializer, type):
        kernel_initializer = kernel_initializer()

    # Input layer
    inputs = tf.keras.layers.Input(input_shape)

    # First level of feature extraction (stem)
    x = tf.keras.layers.Conv2D(
                                64, (7, 7), strides=(2, 2),
                                kernel_initializer=kernel_initializer,
                                kernel_regularizer=kernel_regularizer
                            )(inputs)
    x = tf.keras.layers.BatchNormalization()(x)
    x = tf.keras.activations.relu(x)
    x = tf.keras.layers.MaxPooling2D((3, 3), strides=(2, 2))(x)

    # Resnet stages: (filters, number of identity blocks, strides of the
    # leading conv_block). The first stage keeps the resolution because the
    # stem has already downsampled; every later stage halves it.
    stages = (
        ([64, 64, 256], 2, (1, 1)),
        ([128, 128, 512], 3, (2, 2)),
        ([256, 256, 1024], 5, (2, 2)),
        ([512, 512, 2048], 2, (2, 2)),
    )
    for stage_filters, identity_count, stage_strides in stages:
        x = conv_block(x, stage_filters, 3, kernel_regularizer, kernel_initializer,
                       strides=stage_strides)
        for _ in range(identity_count):
            x = identity_block(x, stage_filters, 3, kernel_regularizer, kernel_initializer)

    # Global average pooling collapses whatever spatial extent is left into one
    # vector per example. For 32x32 inputs the feature map is already 1x1 at
    # this point, so this matches the previous Flatten; for larger inputs it
    # keeps the classifier input at a fixed length.
    x = tf.keras.layers.GlobalAveragePooling2D()(x)

    x = tf.keras.layers.Dense(classes, activation="softmax")(x)

    model = tf.keras.Model(inputs, x)

    return model

In [None]:
# Build the network for 32x32 RGB inputs and push one random image through it
# as a smoke test.
inp = tf.random.normal((1, 32, 32, 3))
# NOTE(review): this uses an L1 penalty of 1e-2, whereas the block smoke test
# above uses the default L2 regularizer — confirm the stronger L1 term is
# intentional.
model = resnet50((32, 32, 3), tf.keras.regularizers.L1(1e-2), tf.keras.initializers.HeNormal())
out = model(inp)
print(out.shape)
# summary() prints the table itself and returns None, so wrapping it in
# print() only appended a stray "None" line to the output.
model.summary()

In [None]:
# Metrics reported during training and validation. The names given here are
# the keys under which the values appear in the training history (validation
# values get a 'val_' prefix).
METRICS = [
    tf.keras.metrics.CategoricalAccuracy(name='categorical_accuracy'),
    tf.keras.metrics.Precision(name='precision'),
    tf.keras.metrics.Recall(name='recall'),
]

In [None]:
# Plain stochastic gradient descent with a fixed learning rate.
opt = tf.keras.optimizers.SGD(learning_rate=1e-3)

# from_logits=False: the loss receives probabilities, which matches the
# softmax activation on the model's output layer.
loss_fn = tf.keras.losses.CategoricalCrossentropy(from_logits=False)

# Prepare the model for training.
model.compile(
    optimizer=opt,
    loss=loss_fn,
    metrics=METRICS,
)

In [None]:
## Early stopping ends training once the monitored quantity has not improved
## for `patience` epochs, before the model gets worse (e.g. overfits);
## restore_best_weights=True then rolls back to the best weights seen.
early_stopping_cb = tf.keras.callbacks.EarlyStopping(
    patience=10,
    restore_best_weights=True,
)

## Write training logs for TensorBoard to ./logs.
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir="./logs")

## Both datasets repeat forever, so every epoch and every validation pass is
## bounded by an explicit number of steps (one pass over the respective split).
train_steps = TRAIN_IMG_COUNT // BATCH_SIZE
val_steps = VAL_IMG_COUNT // BATCH_SIZE

## NOTE: early_stopping_cb is currently not passed to fit(); only the
## TensorBoard callback is active, so training always runs all 100 epochs.
history = model.fit(
    train_ds,
    epochs=100,
    steps_per_epoch=train_steps,
    validation_data=val_ds,
    validation_steps=val_steps,
    callbacks=[tensorboard_callback],
)

In [None]:
import matplotlib.pyplot as plt
## Visualize the performance: one panel per tracked quantity, each showing the
## training curve and the matching validation curve over the epochs.
metric_names = ['precision', 'recall', 'categorical_accuracy', 'loss']
fig, ax = plt.subplots(1, len(metric_names), figsize=(20, 3))
ax = ax.ravel()
print(history.history.keys())

for axis, met in zip(ax, metric_names):
    train_curve = history.history[met]
    val_curve = history.history['val_' + met]
    axis.plot(train_curve)
    axis.plot(val_curve)
    axis.set_title(f'Model {met}')
    axis.set_xlabel('epochs')
    axis.set_ylabel(met)
    axis.legend(['train', 'val'])
