## Residual Connection

In [1]:
import tensorflow as tf
from tensorflow import keras

In [11]:
class Residual(keras.Model):
    """Residual block: two 3x3 conv + batch-norm stages plus a shortcut path.

    Computes ``relu(bn2(conv2(relu(bn1(conv1(X))))) + shortcut(X))``, where
    ``shortcut`` is the identity, or a 1x1 convolution when ``use_1x1conv``
    is True.

    Args:
        num_channels: Number of filters of every convolution in the block,
            i.e. the channel count of the output.
        use_1x1conv: If True, route the shortcut through a 1x1 convolution
            with ``num_channels`` filters and the same ``strides`` as the
            first 3x3 convolution, so the shortcut matches the main path's
            shape. Needed whenever ``strides != 1`` or the input channel
            count differs from ``num_channels``.
        strides: Stride of the first 3x3 convolution (and of the 1x1
            shortcut convolution, if present).
        **kwargs: Forwarded unchanged to ``keras.Model``.

    Note:
        With the identity shortcut, ``Y + X`` uses TensorFlow broadcasting.
        A mismatched input usually fails at call time, but a 1-channel input
        would silently broadcast across ``num_channels``. Use
        ``use_1x1conv=True`` in that case.
    """

    def __init__(self, num_channels, use_1x1conv=False, strides=1, **kwargs):
        super().__init__(**kwargs)
        # 'same' padding: only the stride changes the spatial size.
        self.conv1 = keras.layers.Conv2D(num_channels,
                                         padding='same',
                                         kernel_size=3,
                                         strides=strides)
        self.conv2 = keras.layers.Conv2D(num_channels,
                                         padding='same',
                                         kernel_size=3)

        # Optional projection shortcut; None means an identity shortcut.
        # Assignment order (conv1, conv2, conv3, bn1, bn2) is kept as-is so
        # Keras layer/weight tracking order is unchanged.
        if use_1x1conv:
            self.conv3 = keras.layers.Conv2D(num_channels,
                                             kernel_size=1,
                                             strides=strides)
        else:
            self.conv3 = None

        self.bn1 = keras.layers.BatchNormalization()
        self.bn2 = keras.layers.BatchNormalization()

    def call(self, X, training=None):
        """Apply the residual block to ``X``.

        Args:
            X: Input tensor. Conv2D's default data_format is channels_last,
                and the example cells use (batch, height, width, channels).
                TODO confirm if a different data_format is configured.
            training: Optional bool forwarded to both BatchNormalization
                layers (batch statistics when True, moving averages when
                False). ``None`` lets Keras resolve it from the call context,
                which matches the previous behavior.

        Returns:
            ``relu(main_path + shortcut)``, with ``num_channels`` channels.
        """
        Y = tf.nn.relu(self.bn1(self.conv1(X), training=training))
        Y = self.bn2(self.conv2(Y), training=training)

        # Explicit None check rather than relying on a layer's truthiness.
        if self.conv3 is not None:
            X = self.conv3(X)

        return tf.nn.relu(Y + X)

In [12]:
# Identity shortcut (use_1x1conv=False, strides=1): the input must already
# have num_channels (3) channels so Y + X lines up; output keeps the input shape.
block = Residual(3)
# Random batch of 2 inputs laid out as (batch, height, width, channels).
X = tf.random.uniform((2, 224, 224, 3))
block(X).shape  # last expression of the cell, so the notebook displays it

TensorShape([2, 224, 224, 3])

In [13]:
# Projection shortcut: the 1x1 conv (same stride 2 as conv1) maps the input to
# 6 channels at half the spatial size so it can be added to the main path.
block = Residual(6, use_1x1conv=True, strides=2)
block(X).shape  # reuses X from the previous cell; displayed as the cell output

TensorShape([2, 112, 112, 6])