In [1]:
import tensorflow as tf

In [2]:
class Identity_Block(tf.keras.layers.Layer):
    def __init__(self, n_filters):
        super(Identity_Block, self).__init__()
        self.n_filters = n_filters
        self.kernel_size = [(1,1),(3,3),(1,1)] if len(self.n_filters) == 3 else [(3,3),(3,3)]
        
        self.Residual = tf.keras.Sequential([])
        for i in range(len(n_filters)):
            self.Residual.add(tf.keras.layers.Conv2D(filters = self.n_filters[i],
                                                     kernel_size = self.kernel_size[i],
                                                     use_bias = False,
                                                     padding = 'same'
                                                    )
                             )
            self.Residual.add(tf.keras.layers.BatchNormalization())
            self.Residual.add(tf.keras.layers.ReLU())
        
    def call(self, X):
        y = self.Residual(X) + X
        y = tf.nn.relu(y)
        return y
        
        
class Convolutional_Block(tf.keras.layers.Layer):
    def __init__(self, n_filters, downsampling):
        super(Convolutional_Block, self).__init__()
        self.n_filters = n_filters
        self.downsampling = downsampling
        self.kernel_size = [(1,1),(3,3),(1,1)] if len(self.n_filters) == 3 else [(3,3),(3,3)]
        
        self.Residual = tf.keras.Sequential([])
        for i in range(len(self.n_filters)):
            self.Residual.add(tf.keras.layers.Conv2D(filters = self.n_filters[i],
                                                     kernel_size  = self.kernel_size[i],
                                                     strides = (2,2) if (i == 0)&(self.downsampling) else (1,1),
                                                     padding = 'valid' if (i == 0)&(self.downsampling) else 'same',
                                                     use_bias = False
                                                    ))
            self.Residual.add(tf.keras.layers.BatchNormalization())
            self.Residual.add(tf.keras.layers.ReLU())
            
        self.linear_projection = tf.keras.Sequential([
            tf.keras.layers.Conv2D(filters = self.n_filters[2],
                                   kernel_size = (1,1),
                                   strides = (2,2) if self.downsampling else (1,1),
                                   padding = 'valid',
                                   use_bias = False
                                  ),
            tf.keras.layers.BatchNormalization()
        ])
        
    def call(self, X):
        y = self.Residual(X) + self.linear_projection(X)
        y = tf.nn.relu(y)
        return y
    
        
class Stage(tf.keras.layers.Layer):
    def __init__(self, n_filters, n_layers:int, downsampling):
        super(Stage, self).__init__()
        self.downsampling = downsampling
        self.n_filters = n_filters
        self.n_layers = n_layers
        
        self.Blocks = tf.keras.Sequential([
            Convolutional_Block(self.n_filters, downsampling) if (self.n_filters[0] != self.n_filters[-1])|(self.downsampling) else Identity_Block(self.n_filters)
        ] + [
            Identity_Block(self.n_filters) for _ in range(1, self.n_layers - 1)
        ])
               
    def call(self, X):
        y = self.Blocks(X)
        return y

In [3]:
class ResNet(tf.keras.models.Model):
    def __init__(self, n_filters, n_layers, n_labels:int, last_activation:str):
        super(ResNet, self).__init__()
        self.n_filters = n_filters
        self.n_layers = n_layers
        self.n_labels = n_labels
        self.last_activation = last_activation
        
        self.Conv1 = tf.keras.layers.Conv2D(filters = 64,
                                            kernel_size = (7,7),
                                            strides = 2,
                                            padding = 'same',
                                            use_bias = False
                                           )
        self.BN1 = tf.keras.layers.BatchNormalization()
        self.MP1 = tf.keras.layers.MaxPool2D(pool_size = (3,3),
                                             strides = (2,2),
                                             padding = 'same'
                                            )
        self.Stages = tf.keras.Sequential([
            Stage([x * (2 ** i) for x in self.n_filters], self.n_layers[i], False if i == 0 else True) for i in range(4)
        ])
        self.Classifier = tf.keras.Sequential([
            tf.keras.layers.GlobalAveragePooling2D(), # No dropout in resnet
            tf.keras.layers.Dense(self.n_labels, activation = self.last_activation)
        ])
        
    def call(self, X):
        y = self.Conv1(X)
        y = self.BN1(y)
        y = tf.nn.relu(y)
        
        y = self.MP1(y)
        y = self.Stages(y)
        
        y = self.Classifier(y)
        return y

In [4]:
resnet = ResNet([64,64,256], [3,4,6,3], 1000, 'softmax')

In [5]:
resnet.build([16,224,224,3])

In [6]:
resnet.summary()

Model: "res_net"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
conv2d (Conv2D)              multiple                  9408      
_________________________________________________________________
batch_normalization (BatchNo multiple                  256       
_________________________________________________________________
max_pooling2d (MaxPooling2D) multiple                  0         
_________________________________________________________________
sequential_20 (Sequential)   (16, 7, 7, 2048)          17609728  
_________________________________________________________________
sequential_21 (Sequential)   (16, 1000)                2049000   
Total params: 19,668,392
Trainable params: 19,626,792
Non-trainable params: 41,600
_________________________________________________________________



|<div style = "width:200px">ModelSize</div>|<div style = "width:200px">n_filters</div>|<div style = "width:200px">n_layers</div>|
|:-|:-:|:-:|
|ResNet18|[64,64]|[2,2,2,2]|
|ResNet34|[64,64]|[2,4,6,3]|
|ResNet50|[64,64,256]|[3,4,6,3]|
|ResNet101|[64,64,256]|[3,4,23,3]|
|ResNet152|[64,64,256]|[3,8,36,3]|