Reference: https://d2l.ai/chapter_convolutional-modern/resnet.html

<img src="images/resnet.png">

In [21]:
import tensorflow as tf

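# If use_1x1conv is False, num_channels must equal X.shape[-1] and strides must be 1,
# because the input is added to the main branch unchanged.
# In this case, output.shape = input.shape.

# If use_1x1conv is True, num_channels need not equal X.shape[-1]: a 1x1 convolution
# projects the shortcut to num_channels channels.
# In this case (with strides=1), X.shape = (b, h, w, c) and output.shape = (b, h, w, num_channels).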

class Residual(tf.keras.Model):  #@save
    """The Residual block of ResNet.

    Computes ``relu(bn2(conv2(relu(bn1(conv1(X))))) + shortcut(X))``, where the
    shortcut is the identity, or a 1x1 convolution when ``use_1x1conv`` is True.

    Args:
        num_channels: Number of filters in every convolution of the block, i.e.
            the channel count of the output.
        use_1x1conv: If True, the shortcut is a 1x1 convolution with stride
            ``strides`` so it matches the main branch in channels and spatial
            size. If False, the shortcut is the identity, so the input must
            already have ``num_channels`` channels and ``strides`` should be 1
            for the final addition to line up.
        strides: Stride of the first 3x3 convolution and of the 1x1 shortcut.
        **kwargs: Forwarded to ``tf.keras.Model`` (e.g. ``name``).
    """
    def __init__(self, num_channels, use_1x1conv=False, strides=1, **kwargs):
        super().__init__(**kwargs)
        self.conv1 = tf.keras.layers.Conv2D(num_channels, 3, strides, 'same')
        self.conv2 = tf.keras.layers.Conv2D(num_channels, 3, padding='same')
        # Optional projection shortcut; None means the identity shortcut is used.
        self.conv3 = None
        if use_1x1conv:
            self.conv3 = tf.keras.layers.Conv2D(num_channels, 1, strides)
        self.bn1 = tf.keras.layers.BatchNormalization()
        self.bn2 = tf.keras.layers.BatchNormalization()

    def call(self, X, training=None):
        """Apply the block to a batch ``X``.

        Assuming Keras' default channels_last format, X has shape (b, h, w, c)
        and the result has shape
        (b, ceil(h / strides), ceil(w / strides), num_channels).

        ``training`` is forwarded to both BatchNormalization layers; when it is
        None, Keras infers the mode from the calling context, as before.
        """
        Y = tf.keras.activations.relu(self.bn1(self.conv1(X), training=training))
        Y = self.bn2(self.conv2(Y), training=training)
        if self.conv3 is not None:
            # Project the shortcut so it matches Y's channels and spatial size.
            X = self.conv3(X)
        return tf.keras.activations.relu(Y + X)
    
# Generally (channels_last), if X.shape = (b, h, w, c), then
# output.shape = (b, ceil(h/strides), ceil(w/strides), num_channels)

<img src="images/resnet18.png">

In [24]:
class ResnetBlock(tf.keras.layers.Layer):
    """One ResNet stage: ``num_residuals`` Residual blocks applied in sequence.

    Args:
        num_channels: Output channel count of every Residual block in the stage.
        num_residuals: Number of Residual blocks in the stage.
        first_block: If False, the first Residual block uses a 1x1 projection
            shortcut with stride 2, halving height/width and mapping the
            channels to ``num_channels``. If True, every block uses an identity
            shortcut with stride 1, so the input must already have
            ``num_channels`` channels.
        **kwargs: Forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``).
    """
    def __init__(self, num_channels, num_residuals, first_block=False, **kwargs):
        super().__init__(**kwargs)
        self.residual_layers = []
        for i in range(num_residuals):
            if i == 0 and not first_block:
                self.residual_layers.append(Residual(num_channels, use_1x1conv=True, strides=2))
            else:
                self.residual_layers.append(Residual(num_channels))

    def call(self, X):
        """Run ``X`` through each Residual block in order and return the result."""
        # Iterate the tracked list directly. The `.layers` attribute existed only
        # on tf.keras 2's internal ListWrapper; Keras 3's TrackedList lacks it.
        for layer in self.residual_layers:
            X = layer(X)
        return X

In [23]:
def net(num_classes=10):
    """Build a ResNet-18 classifier as a Keras Sequential model.

    Args:
        num_classes: Number of units in the final Dense layer (defaults to 10).

    Returns:
        An unbuilt ``tf.keras.Sequential``. The final Dense layer has no
        activation, so the model outputs raw logits.
    """
    return tf.keras.Sequential([
        # Stem: 7x7 stride-2 convolution, BN, ReLU, then 3x3 stride-2 max-pool.
        tf.keras.layers.Conv2D(64, kernel_size=7, strides=2, padding='same'),
        tf.keras.layers.BatchNormalization(),
        tf.keras.layers.Activation('relu'),
        tf.keras.layers.MaxPool2D(pool_size=3, strides=2, padding='same'),
        # Four stages of two Residual blocks each. Every stage after the first
        # halves height/width and sets the channel count given here.
        ResnetBlock(64, 2, first_block=True),
        ResnetBlock(128, 2),
        ResnetBlock(256, 2),
        ResnetBlock(512, 2),
        # Head: average over the spatial axes, then a linear classifier.
        tf.keras.layers.GlobalAvgPool2D(),
        tf.keras.layers.Dense(units=num_classes)])

In [25]:
# Feed one random 224x224 3-channel image through a fresh model, layer by layer,
# printing each layer's class name and the shape of its output.
X = tf.random.uniform(shape=(1, 224, 224, 3))
resnet = net()
for lyr in resnet.layers:
    X = lyr(X)
    layer_name = type(lyr).__name__
    print(layer_name, 'output shape:\t', X.shape)

Conv2D output shape:	 (1, 112, 112, 64)
BatchNormalization output shape:	 (1, 112, 112, 64)
Activation output shape:	 (1, 112, 112, 64)
MaxPooling2D output shape:	 (1, 56, 56, 64)
ResnetBlock output shape:	 (1, 56, 56, 64)
ResnetBlock output shape:	 (1, 28, 28, 128)
ResnetBlock output shape:	 (1, 14, 14, 256)
ResnetBlock output shape:	 (1, 7, 7, 512)
GlobalAveragePooling2D output shape:	 (1, 512)
Dense output shape:	 (1, 10)
