**Writing Custom Layers and Models**<br/>
In Keras, we can create custom layers and models by subclassing the `Layer` class and the `Model` class.

**The Layer Class** <br/>

One of the central abstractions in Keras is the `Layer` class.
1. It encapsulates both state (the layer's weights) and a computation (the `call()` method, i.e. the layer's forward pass).
2. A layer can have trainable as well as non-trainable weights.
3. Layers can be composed recursively: a layer can hold other layers as attributes, and the outer layer then tracks their weights.
4. Some layers, in particular `BatchNormalization` and `Dropout`, behave differently during training and inference. For such layers, it is standard practice to expose a `training` (boolean) argument in the `call()` method.
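To make the four points above concrete, here is a minimal sketch of a custom layer. The class `DenseWithDropout` and its `call_count` weight are hypothetical, chosen only for illustration: the layer creates trainable weights in `build()`, keeps a non-trainable counter, holds a nested `Dropout` layer (recursive composition), and forwards the `training` flag to it.

```python
import tensorflow as tf


class DenseWithDropout(tf.keras.layers.Layer):
    """Hypothetical example layer: a dense projection followed by dropout."""

    def __init__(self, units, rate=0.5, **kwargs):
        super().__init__(**kwargs)
        self.units = units
        # A layer stored as an attribute is tracked automatically (recursive composition).
        self.dropout = tf.keras.layers.Dropout(rate)

    def build(self, input_shape):
        # Trainable weights, created lazily once the input's last dimension is known.
        self.w = self.add_weight(name="w", shape=(int(input_shape[-1]), self.units),
                                 initializer="glorot_uniform", trainable=True)
        self.b = self.add_weight(name="b", shape=(self.units,),
                                 initializer="zeros", trainable=True)
        # Non-trainable weight: updated by hand in call(), never touched by an optimizer.
        self.call_count = self.add_weight(name="call_count", shape=(),
                                          initializer="zeros", trainable=False)

    def call(self, inputs, training=False):
        self.call_count.assign_add(1.0)
        x = tf.matmul(inputs, self.w) + self.b
        # Dropout zeroes activations only when training=True; at inference it is the identity.
        return self.dropout(x, training=training)


layer = DenseWithDropout(4)
x = tf.ones((2, 3))
y_infer = layer(x, training=False)
y_train = layer(x, training=True)
print(y_infer.shape)                                                   # (2, 4)
print(len(layer.trainable_weights), len(layer.non_trainable_weights))  # 2 1
print(float(layer.call_count.numpy()))                                 # 2.0
```

Creating the weights in `build()` rather than `__init__()` means the layer does not need to know its input size in advance.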

**The Model Class**<br/>
Typically, a layer defines an inner computation block, and the model defines the overall architecture, i.e. the object that we train. Here, a single `ResNetBlock` class subclasses (inherits from) `Layer`, and `ResNetModel` subclasses `Model` and contains three `ResNetBlock` instances. The `Model` class has the same API as `Layer`, with added training, evaluation, and saving methods such as `model.compile()`, `model.fit()`, `model.evaluate()`, and `model.save()`.


Refer to https://www.tensorflow.org/guide/keras/custom_layers_and_models#the_model_class for more details
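As a hedged illustration of what `Model` adds on top of `Layer`, the sketch below defines a hypothetical two-layer `TinyClassifier` and trains it with the built-in `compile()`, `fit()`, `evaluate()`, and `predict()` methods. The data is random and synthetic, so the resulting accuracy is meaningless; only the workflow matters.

```python
import numpy as np
import tensorflow as tf


class TinyClassifier(tf.keras.Model):
    """Hypothetical model: the layers define the computation, Model adds fit/evaluate/save."""

    def __init__(self, num_classes=3, **kwargs):
        super().__init__(**kwargs)
        self.hidden = tf.keras.layers.Dense(16, activation="relu")
        self.classifier = tf.keras.layers.Dense(num_classes, activation="softmax")

    def call(self, inputs, training=False):
        return self.classifier(self.hidden(inputs))


# Random synthetic data, for illustration only.
rng = np.random.default_rng(0)
x = rng.normal(size=(64, 8)).astype("float32")
y = rng.integers(0, 3, size=(64,))

model = TinyClassifier()
model.compile(optimizer="adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"])
history = model.fit(x, y, epochs=2, batch_size=16, verbose=0)  # built-in training loop
loss, acc = model.evaluate(x, y, verbose=0)                     # built-in evaluation
probs = model.predict(x, verbose=0)                             # built-in inference
print(probs.shape)  # (64, 3)
```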

In [None]:
import tensorflow as tf
import numpy as np
import pandas as pd

In [None]:
class ResNetBlock(tf.keras.layers.Layer): # A ResNet block: two convolution layers, optionally followed by max pooling (used in the first block only).

  def __init__(self, out_channels, kernel_size, padding = "valid", max_pool = False, **kwargs):
    # out_channels - list with the number of filters for each of the two convolutions.
    # kernel_size - the size of the convolution kernel.
    # padding - "valid" (no padding) or "same" (output keeps the input's spatial size).
    # max_pool - True for the first block, False for the second and third. Enables max pooling after the convolutions.
    super(ResNetBlock, self).__init__(**kwargs) # Forward optional arguments such as name; by default Keras assigns each block a unique name.
    self.conv1 = tf.keras.layers.Conv2D(filters = out_channels[0], kernel_size = kernel_size, padding = padding, activation = "relu")
    self.conv2 = tf.keras.layers.Conv2D(filters = out_channels[1], kernel_size = kernel_size, padding = padding, activation = "relu")
    # Create the pooling layer once here instead of building a new layer on every call().
    self.max_pool = tf.keras.layers.MaxPooling2D(3) if max_pool else None

  def call(self, inputs):
    x = self.conv1(inputs)
    x = self.conv2(x)
    if self.max_pool is not None:
      x = self.max_pool(x)
    return x

In [None]:
class ResNetModel(tf.keras.Model): # The ResNet model: three ResNet blocks plus the remaining operations, such as the skip connections and the classification head.

  def __init__(self):
    super(ResNetModel, self).__init__()
    self.block1 = ResNetBlock(out_channels = [32, 64], kernel_size = 3, max_pool = True)
    self.block2 = ResNetBlock(out_channels = [64, 64], kernel_size = 3, padding = "same")
    self.block3 = ResNetBlock(out_channels = [64, 64], kernel_size = 3, padding = "same")
    self.add = tf.keras.layers.Add()
    self.conv = tf.keras.layers.Conv2D(filters = 64, kernel_size = 3, activation = "relu")
    self.global_average_pooling = tf.keras.layers.GlobalAveragePooling2D()
    self.dense1 = tf.keras.layers.Dense(256, activation = "relu")
    self.dropout = tf.keras.layers.Dropout(0.5)
    self.dense2 = tf.keras.layers.Dense(10, activation= "softmax")
  
  def call(self, inputs, training = False):
    block1_output = self.block1(inputs)
    block2_output = self.block2(block1_output)
    block2_output = self.add([block1_output, block2_output])
    block3_output = self.block3(block2_output)
    block3_output = self.add([block2_output, block3_output])
    x = self.conv(block3_output)
    x = self.global_average_pooling(x)
    x = self.dense1(x)
    x = self.dropout(x, training = training)
    x = self.dense2(x)
    return x
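The skip connections in `call()` only work if the two tensors passed to `Add()` have the same shape. The pure-Python sketch below traces the spatial size of a 32x32 CIFAR-10 image through `ResNetModel`, assuming stride 1 for every convolution and the `MaxPooling2D` defaults (strides equal to the pool size, `"valid"` padding). The helpers `conv_out` and `pool_out` are hypothetical names for illustration. Channel counts match as well: `block1` ends with 64 filters, and `block2`/`block3` keep 64.

```python
def conv_out(size, kernel, padding="valid"):
    """Spatial output size of a stride-1 Conv2D."""
    return size if padding == "same" else size - kernel + 1


def pool_out(size, pool):
    """Spatial output size of MaxPooling2D(pool) with default strides=pool and 'valid' padding."""
    return (size - pool) // pool + 1


size = 32  # CIFAR-10 images are 32x32
# block1: two 'valid' 3x3 convolutions (32 -> 30 -> 28), then 3x3 max pooling (28 -> 9)
b1 = pool_out(conv_out(conv_out(size, 3), 3), 3)
# block2 and block3: 'same' padding keeps the size, so Add() sees matching shapes
b2 = conv_out(conv_out(b1, 3, "same"), 3, "same")
b3 = conv_out(conv_out(b2, 3, "same"), 3, "same")
# the model's final Conv2D uses the default 'valid' padding
final = conv_out(b3, 3)
print(b1, b2, b3, final)  # 9 9 9 7
```

If `block2` used `"valid"` padding instead, its output would shrink to 5x5 and the first `Add()` would fail with a shape mismatch.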

In [None]:
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()

In [None]:
x_train = x_train.astype("float32") / 255.0
x_test = x_test.astype("float32") / 255.0
y_train = tf.keras.utils.to_categorical(y_train, 10) # Convert the labels to one-hot vectors, since we use categorical cross-entropy.
y_test = tf.keras.utils.to_categorical(y_test, 10)

In [None]:
lr = 1e-3
optimizer = tf.keras.optimizers.RMSprop(learning_rate = lr)
loss_fn = tf.keras.losses.CategoricalCrossentropy() # Expects one-hot labels; SparseCategoricalCrossentropy expects integer class indices instead.
batch_size = 64
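To illustrate the difference between the two losses, the sketch below evaluates both on the same two hypothetical softmax outputs. `CategoricalCrossentropy` takes one-hot targets (as produced by `to_categorical`), `SparseCategoricalCrossentropy` takes integer class indices, and for equivalent labels both reduce to the mean of -log(probability assigned to the true class).

```python
import math

import numpy as np
import tensorflow as tf

labels = np.array([0, 2])                             # integer class indices
one_hot = tf.keras.utils.to_categorical(labels, 3)    # [[1, 0, 0], [0, 0, 1]]
probs = np.array([[0.7, 0.2, 0.1],
                  [0.1, 0.3, 0.6]], dtype="float32")  # hypothetical softmax outputs

cce = tf.keras.losses.CategoricalCrossentropy()(one_hot, probs)
scce = tf.keras.losses.SparseCategoricalCrossentropy()(labels, probs)
manual = -(math.log(0.7) + math.log(0.6)) / 2         # mean -log p(true class)
print(float(cce), float(scce), manual)                # each is about 0.4338
```

Using integer labels with `SparseCategoricalCrossentropy` would let us skip the `to_categorical` step entirely.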

In [None]:
model = ResNetModel()
model.compile(optimizer = optimizer, loss = loss_fn, metrics = ["acc"])
history = model.fit(x_train[:15000], y_train[:15000], batch_size = batch_size, epochs=20, validation_split = 0.3) #Also add saving complete model, model checkpointing only stores the

In [None]:
import matplotlib.pyplot as plt
print(history.history.keys())
#  "Accuracy"
plt.plot(history.history['acc'])
plt.plot(history.history['val_acc'])
plt.title('model accuracy')
plt.ylabel('accuracy')
plt.xlabel('epoch')
plt.legend(['train', 'validation'], loc='upper left')
plt.show()
# "Loss"
plt.plot(history.history['loss'])
plt.plot(history.history['val_loss'])
plt.title('model loss')
plt.ylabel('loss')
plt.xlabel('epoch')
plt.legend(['train', 'validation'], loc='upper left')
plt.show()

**Writing a training loop from scratch**<br/>
Writing the training loop yourself gives low-level control over training and validation. <br/>
A typical training loop looks like this:
1. Open a loop that iterates over the epochs.
2. For each epoch, open a loop that iterates over the batches (steps).
3. Open a `GradientTape` scope, call the model with `training=True` inside it, and compute the loss.
4. Outside the scope, retrieve the gradients of the loss with respect to the trainable weights, and use the optimizer to update the weights.
5. Optionally, track metrics: update them inside the loop and reset them at the end of each epoch.
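The five steps above can be sketched on a toy problem before applying them to the ResNet. The regression task (learning y = 2x + 1 with a single `Dense` unit), the SGD learning rate, and the epoch count are all hypothetical choices made only to keep the example small and fast; each numbered comment matches a step in the list.

```python
import numpy as np
import tensorflow as tf

tf.random.set_seed(0)

# Hypothetical toy regression problem: learn y = 2x + 1 with a single Dense unit.
x = np.linspace(-1.0, 1.0, 64, dtype="float32").reshape(-1, 1)
y = 2.0 * x + 1.0
dataset = tf.data.Dataset.from_tensor_slices((x, y)).shuffle(64, seed=0).batch(16)

model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)
loss_fn = tf.keras.losses.MeanSquaredError()
mae_metric = tf.keras.metrics.MeanAbsoluteError()

epoch_mae = []
for epoch in range(20):                                  # 1. loop over epochs
    for x_batch, y_batch in dataset:                     # 2. loop over batches
        with tf.GradientTape() as tape:                  # 3. forward pass recorded by the tape
            preds = model(x_batch, training=True)
            loss_value = loss_fn(y_batch, preds)
        grads = tape.gradient(loss_value, model.trainable_weights)      # 4. gradients...
        optimizer.apply_gradients(zip(grads, model.trainable_weights))  # ...and weight update
        mae_metric.update_state(y_batch, preds)          # 5. update the metric
    epoch_mae.append(float(mae_metric.result()))
    mae_metric.reset_state()                             # start each epoch with a fresh metric

print(epoch_mae[0], epoch_mae[-1])  # the error should shrink over the epochs
```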


In [None]:
from sklearn.model_selection import train_test_split
x_train, x_val, y_train, y_val = train_test_split(x_train[:15000], y_train[:15000], test_size = 0.3)
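model = ResNetModel() # Get a fresh, untrained model.
optimizer = tf.keras.optimizers.RMSprop(learning_rate = lr) # Re-create the optimizer so it does not carry state over from the first model.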
x_train.shape

(10500, 32, 32, 3)

In [None]:
#Create tensorflow Datasets.
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train)).shuffle(buffer_size = 1024).batch(batch_size)
val_dataset = tf.data.Dataset.from_tensor_slices((x_val, y_val)).batch(batch_size)
# Prepare the metrics.
train_acc_metric = tf.keras.metrics.CategoricalAccuracy()
val_acc_metric = tf.keras.metrics.CategoricalAccuracy()
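`batch()` keeps the final, smaller batch by default, so the number of steps per epoch is the ceiling of `num_samples / batch_size`. The sketch below shows this on hypothetical toy arrays of 10 samples, compares it with `drop_remainder=True`, and applies the same arithmetic to the 10,500 training images used here.

```python
import math

import numpy as np
import tensorflow as tf

# Hypothetical toy data: 10 samples with 2 features each.
features = np.arange(20, dtype="float32").reshape(10, 2)
labels = np.arange(10)

ds = tf.data.Dataset.from_tensor_slices((features, labels)).shuffle(buffer_size=10).batch(4)
batch_sizes = [int(xb.shape[0]) for xb, _ in ds]
print(batch_sizes)  # [4, 4, 2]: the last, partial batch is kept

ds_drop = tf.data.Dataset.from_tensor_slices((features, labels)).batch(4, drop_remainder=True)
dropped = [int(xb.shape[0]) for xb, _ in ds_drop]
print(dropped)      # [4, 4]: the partial batch is discarded

# Steps per epoch for the training split used in this notebook.
print(math.ceil(10500 / 64), 10500 % 64)  # 165 4
```

So each epoch here runs 165 steps, and the last step sees only 4 images.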

In [None]:
import time

epochs = 20
# Iterate over the epochs.
for epoch in range(epochs):
    print("\nStart of epoch %d" % (epoch,))
    start_time = time.time()

    # Iterate over the batches of the dataset.
    for step, (x_batch_train, y_batch_train) in enumerate(train_dataset):
        # Calling the model inside a GradientTape scope records the operations of the forward pass,
        # which enables automatic differentiation: afterwards you can retrieve the gradients of the
        # loss with respect to the trainable weights (model.trainable_weights), and the optimizer
        # uses these gradients to update those weights.
        with tf.GradientTape() as tape:
            preds = model(x_batch_train, training=True) # Predictions (softmax probabilities) for the mini-batch.
            # Compute the loss for this mini-batch.
            loss_value = loss_fn(y_batch_train, preds)
        # Retrieve the gradients of the loss with respect to the trainable variables.
        grads = tape.gradient(loss_value, model.trainable_weights)

        # Update the weights by running one step of the optimizer.
        optimizer.apply_gradients(zip(grads, model.trainable_weights))

        # Update the training metric.
        train_acc_metric.update_state(y_batch_train, preds)

        # Log every 200 batches.
        if step % 200 == 0:
            print(
                "Training loss (for one batch) at step %d: %.4f"
                % (step, float(loss_value))
            )
            print("Seen so far: %d samples" % ((step + 1) * batch_size))

    # Display metrics at the end of each epoch.
    train_acc = train_acc_metric.result()
    print("Training acc over epoch: %.4f" % (float(train_acc),))

    # Reset training metrics at the end of each epoch.
    train_acc_metric.reset_state()

    # Run a validation loop at the end of each epoch.
    for x_batch_val, y_batch_val in val_dataset:
        val_preds = model(x_batch_val, training=False)
        # Update val metrics.
        val_acc_metric.update_state(y_batch_val, val_preds)
    val_acc = val_acc_metric.result()
    val_acc_metric.reset_state()
    print("Validation acc: %.4f" % (float(val_acc),))
    print("Time taken: %.2fs" % (time.time() - start_time))

In [None]:
lr = 1e-3
optimizer = tf.keras.optimizers.RMSprop(learning_rate = lr) # A fresh optimizer for the new model trained below.
loss_fn = tf.keras.losses.CategoricalCrossentropy() # Expects one-hot labels; SparseCategoricalCrossentropy expects integer class indices instead.
batch_size = 64

In [None]:
# Decorating the step functions with @tf.function compiles them into a TensorFlow graph, which
# typically makes each step noticeably faster than running the same Python code eagerly. See
# https://www.tensorflow.org/guide/keras/writing_a_training_loop_from_scratch for details.
@tf.function
def train_step(x, y):
    with tf.GradientTape() as tape:
        preds = model(x, training=True) # Softmax probabilities (the model's last layer applies softmax).
        loss_value = loss_fn(y, preds)
    grads = tape.gradient(loss_value, model.trainable_weights)
    optimizer.apply_gradients(zip(grads, model.trainable_weights))
    train_acc_metric.update_state(y, preds)
    return loss_value

@tf.function
def test_step(x, y):
    val_logits = model(x, training=False)
    val_acc_metric.update_state(y, val_logits)
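A small, hedged illustration of how `@tf.function` works: the Python body of the decorated function runs only while TensorFlow traces it into a graph; later calls with the same input shape and dtype reuse that graph, while a new input shape triggers a retrace. The function `sum_of_squares` is hypothetical. In the loop below, the smaller final batch of each dataset is one such new shape, so `train_step` and `test_step` are typically retraced for it.

```python
import tensorflow as tf

traced_shapes = []  # filled by a Python side effect, which only runs while tracing


@tf.function
def sum_of_squares(x):
    traced_shapes.append(tuple(x.shape))  # executes at trace time, not on every call
    return tf.reduce_sum(x * x)


a = sum_of_squares(tf.ones((2, 3)))  # first call: traces the function into a graph
b = sum_of_squares(tf.ones((2, 3)))  # same shape and dtype: reuses the traced graph
c = sum_of_squares(tf.ones((4, 3)))  # new shape: triggers a retrace
print(traced_shapes)                 # [(2, 3), (4, 3)]
print(float(a), float(c))            # 6.0 12.0
```

Because Python side effects run only during tracing, printing or appending inside `train_step` is not a reliable way to log every step; return tensors (like `loss_value`) and log them outside instead.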



In [None]:
import time

model = ResNetModel()

epochs = 20
for epoch in range(epochs):
    print("\nStart of epoch %d" % (epoch,))
    start_time = time.time()

    # Iterate over the batches of the dataset.
    for step, (x_batch_train, y_batch_train) in enumerate(train_dataset):
        loss_value = train_step(x_batch_train, y_batch_train) #Run the train step.

        # Log every 200 batches.
        if step % 200 == 0:
            print(
                "Training loss (for one batch) at step %d: %.4f"
                % (step, float(loss_value))
            )
            print("Seen so far: %d samples" % ((step + 1) * batch_size))

    # Display metrics at the end of each epoch.
    train_acc = train_acc_metric.result()
    print("Training acc over epoch: %.4f" % (float(train_acc),))

    # Reset training metrics at the end of each epoch.
    train_acc_metric.reset_state()

    # Run a validation loop at the end of each epoch.
    for x_batch_val, y_batch_val in val_dataset:
        test_step(x_batch_val, y_batch_val) #Run the validation step.

    val_acc = val_acc_metric.result()
    val_acc_metric.reset_state()
    print("Validation acc: %.4f" % (float(val_acc),))
    print("Time taken: %.2fs" % (time.time() - start_time))