<a href="https://colab.research.google.com/github/dominiksakic/NETworkingMay/blob/main/11_residual_connections.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [10]:
# Example of residual connections, a technique for fighting vanishing gradients

from tensorflow import keras
from tensorflow.keras import layers

# Residual pattern: keep a copy of the block input, run the block, then add
# the copy back onto the block output before passing the sum forward.
inputs = keras.Input(shape=(32, 32, 3))
residual = layers.Conv2D(filters=32, kernel_size=3, activation="relu")(inputs)  # tensor set aside as the shortcut
# The block itself: "same" padding keeps the spatial size, while the number of
# channels goes from 32 to 64.
x = layers.Conv2D(filters=64, kernel_size=3, padding="same", activation="relu")(residual)
# A 1x1 convolution without activation projects the shortcut to 64 channels,
# so both tensors have the same shape when they are added.
residual = layers.Conv2D(filters=64, kernel_size=1)(residual)
x = layers.add([x, residual])

In [11]:
# Residual connection around a block that ends with max pooling
inputs = keras.Input(shape=(32, 32, 3))
residual = layers.Conv2D(filters=32, kernel_size=3, activation="relu")(inputs)  # shortcut tensor, set aside

# Main path: a convolution to 64 channels, then a 2x2 max pooling that halves
# the height and the width.
x = layers.Conv2D(filters=64, kernel_size=3, padding="same", activation="relu")(residual)
x = layers.MaxPooling2D(pool_size=2, padding="same")(x)

# Shortcut path: the 1x1 convolution uses strides=2 so that its output is
# downsampled to the size of the pooled tensor (with 64 channels) before the
# two tensors are added.
residual = layers.Conv2D(filters=64, kernel_size=1, strides=2)(residual)
x = layers.add([x, residual])

In [12]:
# simple convnet implementation

# Helper that builds one residual block and projects the shortcut whenever the
# block changes the spatial size or the number of channels
def residual_block(x, filters, pooling=False):
  """Build one residual block on top of ``x`` and return its output tensor.

  The main path applies two 3x3 ReLU convolutions with "same" padding,
  optionally followed by a 2x2 max pooling. The block input is then added
  back onto the result, passing through a 1x1 convolution first whenever
  its shape would otherwise not match the main path.

  Args:
    x: Input tensor of the block. Only its last axis (``x.shape[-1]``) is
      inspected here; everything else is left to the layers.
    filters: Number of filters used by every convolution in the block.
    pooling: If true, max-pool the main path and give the shortcut
      convolution a stride of 2.

  Returns:
    The sum of the main path and the (possibly projected) shortcut.
  """
  shortcut = x
  for _ in range(2):
    x = layers.Conv2D(filters=filters, kernel_size=3, padding="same", activation="relu")(x)

  if pooling:
    x = layers.MaxPooling2D(pool_size=2, padding="same")(x)
  # With pooling the shortcut is always projected (stride 2); without it, a
  # projection is only needed when the number of channels changes.
  if pooling or filters != shortcut.shape[-1]:
    shortcut = layers.Conv2D(filters=filters, kernel_size=1, strides=2 if pooling else 1)(shortcut)
  return layers.add([x, shortcut])

inputs = keras.Input(shape=(28, 28, 1))
# Multiply the inputs by 1/255 (presumably to map 0-255 pixel values into [0, 1]).
x = layers.Rescaling(scale=1 / 255)(inputs)

# (filters, pooling) for each residual block, applied in this order.
for n_filters, downsample in ((28, True), (32, True), (64, False)):
  x = residual_block(x, filters=n_filters, pooling=downsample)

# Classification head: average every feature map down to one value, then a
# 10-way softmax.
x = layers.GlobalAveragePooling2D()(x)
outputs = layers.Dense(units=10, activation="softmax")(x)

model = keras.Model(inputs=inputs, outputs=outputs)
model.summary()