In [1]:
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, BatchNormalization, ReLU, MaxPooling2D, GlobalAvgPool2D, Dense
from tensorflow.keras.activations import relu

In [2]:
class Residual(tf.keras.Model):
    """Residual unit: two 3x3 conv + batch-norm stages with a skip connection.

    Args:
        num_channels: Number of filters used by every convolution in the unit.
        use_1x1_conv: When True, the skip path goes through a 1x1 convolution
            (with the same ``strides``) before being added to the main path.
            When False, the input is added as-is, so it has to be
            shape-compatible with the main-path output.
        strides: Stride of the first 3x3 convolution (and of the optional
            1x1 skip convolution).
    """

    def __init__(self, num_channels, use_1x1_conv=False, strides=1):
        super().__init__()
        # Main path: only the first convolution applies `strides`.
        self.conv1 = Conv2D(num_channels, kernel_size=3, strides=strides, padding='same')
        self.conv2 = Conv2D(num_channels, kernel_size=3, padding='same')
        # Optional projection on the skip path; None means identity skip.
        self.conv3 = (
            Conv2D(num_channels, kernel_size=1, strides=strides)
            if use_1x1_conv
            else None
        )
        self.bn1 = BatchNormalization()
        self.bn2 = BatchNormalization()

    def call(self, X):
        """Return relu(main_path(X) + skip(X))."""
        # First stage: conv -> batch norm -> ReLU.
        main_path = self.conv1(X)
        main_path = relu(self.bn1(main_path))
        # Second stage: conv -> batch norm (ReLU is applied after the addition).
        main_path = self.conv2(main_path)
        main_path = self.bn2(main_path)
        # The skip path is evaluated after the main path, as in the original order.
        shortcut = X if self.conv3 is None else self.conv3(X)
        return relu(main_path + shortcut)

In [3]:
class ResnetBlock(tf.keras.layers.Layer):
    """A stage made of ``num_residuals`` consecutive ``Residual`` units.

    Unless ``first_block`` is True, the first unit of the stage is created
    with ``use_1x1_conv=True`` and ``strides=2``; every other unit uses the
    ``Residual`` defaults.

    Args:
        num_channels: Number of filters passed to every ``Residual`` unit.
        num_residuals: How many ``Residual`` units the stage contains.
        first_block: When True, no unit in this stage uses stride 2 or the
            1x1 skip convolution.
        **kwargs: Forwarded to ``tf.keras.layers.Layer.__init__``.
    """

    def __init__(self, num_channels, num_residuals, first_block=False, **kwargs):
        super().__init__(**kwargs)
        # Assigned as an attribute before appending so Keras tracks each unit.
        self.residual_layers = []
        for i in range(num_residuals):
            downsample = i == 0 and not first_block
            if downsample:
                self.residual_layers.append(Residual(num_channels, use_1x1_conv=True, strides=2))
            else:
                self.residual_layers.append(Residual(num_channels))

    def call(self, X):
        """Apply the residual units one after another, in creation order."""
        # Iterate the list itself rather than `.layers`: that attribute is not
        # part of a Python list and only exists on some Keras list wrappers.
        for layer in self.residual_layers:
            X = layer(X)
        return X

In [15]:
def _build_resnet(input_shape, residuals_per_stage, num_classes):
    """Assemble the stem, four ``ResnetBlock`` stages and the classifier head.

    Args:
        input_shape: Shape passed as ``input_shape`` to the first ``Conv2D``.
        residuals_per_stage: Four integers, the number of ``Residual`` units
            in the 64-, 128-, 256- and 512-channel stages respectively.
        num_classes: Number of units in the final ``Dense`` layer.

    Returns:
        An uncompiled ``Sequential`` model. No activation is applied to the
        final ``Dense`` output.
    """
    # Stem: 7x7 stride-2 convolution followed by a 3x3 stride-2 max-pool.
    stack = [
        Conv2D(64, kernel_size=7, strides=2, padding='same', input_shape=input_shape),
        BatchNormalization(),
        ReLU(),
        MaxPooling2D(pool_size=3, strides=2, padding='same'),
    ]

    # Only the first stage is built with first_block=True; the stem has
    # already reduced the resolution, so that stage keeps stride 1.
    stage_channels = (64, 128, 256, 512)
    for stage_index, (channels, count) in enumerate(zip(stage_channels, residuals_per_stage)):
        stack.append(ResnetBlock(channels, count, first_block=(stage_index == 0)))

    # Head: global average pooling and a linear classifier.
    stack.append(GlobalAvgPool2D())
    stack.append(Dense(units=num_classes))
    return Sequential(stack)


def ResNet_50(input_shape, num_classes=10):
    """Build a residual network with stage depths (3, 4, 6, 3).

    NOTE(review): every stage is built from two-convolution ``Residual``
    units, not the three-convolution bottleneck units used by the ResNet-50
    of He et al.; despite the name, the layout corresponds to the shallower
    basic-block configuration. Confirm the intended architecture.

    Args:
        input_shape: Shape of one input sample, e.g. ``(224, 224, 3)``.
        num_classes: Size of the final ``Dense`` layer (default 10, the
            previously hard-coded value).

    Returns:
        An uncompiled ``Sequential`` model.
    """
    return _build_resnet(input_shape, (3, 4, 6, 3), num_classes)


def ResNet_101(input_shape, num_classes=10):
    """Build a residual network with stage depths (3, 4, 23, 3).

    NOTE(review): as with ``ResNet_50``, the stages use two-convolution
    ``Residual`` units rather than bottleneck units, so the depth does not
    match the ResNet-101 of He et al. Confirm the intended architecture.

    Args:
        input_shape: Shape of one input sample, e.g. ``(224, 224, 3)``.
        num_classes: Size of the final ``Dense`` layer (default 10, the
            previously hard-coded value).

    Returns:
        An uncompiled ``Sequential`` model.
    """
    return _build_resnet(input_shape, (3, 4, 23, 3), num_classes)

In [16]:
# Instantiate the ResNet_50 variant for 224x224 inputs with 3 channels
# (channels-last, as the summary output shapes show) and print its layers.
# NOTE(review): `input_shape` and `model` are notebook-global names that a
# later cell rebinds; re-run this cell before relying on `model`.
input_shape = (224, 224, 3)
model = ResNet_50(input_shape)
model.summary()

Model: "sequential_5"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
conv2d_148 (Conv2D)          (None, 112, 112, 64)      9472      
_________________________________________________________________
batch_normalization_133 (Bat (None, 112, 112, 64)      256       
_________________________________________________________________
re_lu_5 (ReLU)               (None, 112, 112, 64)      0         
_________________________________________________________________
max_pooling2d_5 (MaxPooling2 (None, 56, 56, 64)        0         
_________________________________________________________________
resnet_block_20 (ResnetBlock (None, 56, 56, 64)        223104    
_________________________________________________________________
resnet_block_21 (ResnetBlock (None, 28, 28, 128)       1119360   
_________________________________________________________________
resnet_block_22 (ResnetBlock (None, 14, 14, 256)      

In [17]:
# Instantiate the ResNet_101 variant for the same 224x224, 3-channel input
# and print its layers.
# NOTE(review): this rebinds the notebook-global `model`, replacing the
# network built in the previous cell.
input_shape = (224, 224, 3)
model = ResNet_101(input_shape)
model.summary()

Model: "sequential_6"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
conv2d_184 (Conv2D)          (None, 112, 112, 64)      9472      
_________________________________________________________________
batch_normalization_166 (Bat (None, 112, 112, 64)      256       
_________________________________________________________________
re_lu_6 (ReLU)               (None, 112, 112, 64)      0         
_________________________________________________________________
max_pooling2d_6 (MaxPooling2 (None, 56, 56, 64)        0         
_________________________________________________________________
resnet_block_24 (ResnetBlock (None, 56, 56, 64)        223104    
_________________________________________________________________
resnet_block_25 (ResnetBlock (None, 28, 28, 128)       1119360   
_________________________________________________________________
resnet_block_26 (ResnetBlock (None, 14, 14, 256)      