In [1]:
import tensorflow as tf

In [2]:
class Identity_Block(tf.keras.layers.Layer):
    """Residual block whose shortcut is the unchanged input.

    Args:
        n_filters: filter counts of the residual branch. Three entries build a
            bottleneck branch (1x1 -> 3x3 -> 1x1 convolutions); two entries
            build a basic branch (3x3 -> 3x3). Every convolution uses stride 1
            and 'same' padding, so the input must already have
            `n_filters[-1]` channels for the final addition to be valid.
    """
    def __init__(self, n_filters):
        super().__init__()
        self.n_filters = n_filters
        self.Deep = len(n_filters) == 3

        # The first kernel is 1x1 in a bottleneck branch and 3x3 in a basic one.
        first_kernel = (1,1) if self.Deep else (3,3)
        self.Conv1 = tf.keras.layers.Conv2D(filters=self.n_filters[0], kernel_size=first_kernel, padding='same')
        self.BN1 = tf.keras.layers.BatchNormalization()
        self.Conv2 = tf.keras.layers.Conv2D(filters=self.n_filters[1], kernel_size=(3,3), padding='same')
        self.BN2 = tf.keras.layers.BatchNormalization()

        # The third conv/BN pair only exists for the bottleneck variant.
        if self.Deep:
            self.Conv3 = tf.keras.layers.Conv2D(filters=self.n_filters[2], kernel_size=(1,1), padding='same')
            self.BN3 = tf.keras.layers.BatchNormalization()
        else:
            self.Conv3 = None
            self.BN3 = None

    def call(self, X):
        """Return relu(residual_branch(X) + X)."""
        y = tf.nn.relu(self.BN1(self.Conv1(X)))
        y = self.BN2(self.Conv2(y))
        if self.Deep:
            y = self.BN3(self.Conv3(tf.nn.relu(y)))
        return tf.nn.relu(y + X)
        
        
class Convolutional_Block(tf.keras.layers.Layer):
    """Residual block with a projection shortcut (1x1 conv + BN on the skip path).

    `Stage` uses it as its first block, where the stride and/or the channel
    count may change, so the input cannot be added to the residual branch as-is.

    Args:
        n_filters: filter counts of the residual branch. Three entries select
            the bottleneck variant (1x1 -> 3x3 -> 1x1 convolutions); two entries
            select the basic variant (3x3 -> 3x3).
        strides: stride of the first residual-branch convolution and of the
            shortcut convolution.
        **kwargs: forwarded to `tf.keras.layers.Layer` (e.g. `name`).
    """
    def __init__(self, n_filters, strides:int, **kwargs):
        super(Convolutional_Block, self).__init__(**kwargs)
        self.n_filters = n_filters
        self.strides = strides
        self.Deep = len(self.n_filters) == 3
        
        # padding='same' gives an output size of ceil(H / strides), which is
        # what the 1x1 shortcut convolution (ConvS) produces. For a 1x1 kernel,
        # 'same' and 'valid' are identical. For the basic block's 3x3 kernel,
        # 'valid' shrank the map (e.g. 56 -> 27 instead of 28 at stride 2), so
        # `y + X` in `call` failed with a shape mismatch.
        self.Conv1 = tf.keras.layers.Conv2D(filters = self.n_filters[0],
                                            kernel_size = (1,1) if self.Deep else (3,3),
                                            strides = (self.strides,self.strides),
                                            padding = 'same'
                                           )
        self.BN1 = tf.keras.layers.BatchNormalization()
        self.Conv2 = tf.keras.layers.Conv2D(filters = self.n_filters[1],
                                            kernel_size = (3,3),
                                            padding = 'same'
                                           )
        self.BN2 = tf.keras.layers.BatchNormalization()
        # Bottleneck only: 1x1 convolution producing n_filters[2] channels.
        self.Conv3 = tf.keras.layers.Conv2D(filters = self.n_filters[2],
                                            kernel_size = (1,1),
                                           ) if self.Deep else None
        self.BN3 = tf.keras.layers.BatchNormalization() if self.Deep else None
        # Projection shortcut: matches the residual branch's output channel
        # count (its last filter count) and its stride.
        self.ConvS = tf.keras.layers.Conv2D(filters = self.n_filters[2] if self.Deep else self.n_filters[1],
                                            kernel_size = (1,1),
                                            strides = (self.strides,self.strides)
                                           )
        self.BNS = tf.keras.layers.BatchNormalization()
        
    def call(self, X):
        """Return relu(residual_branch(X) + BNS(ConvS(X)))."""
        y = self.Conv1(X)
        y = self.BN1(y)
        y = tf.nn.relu(y)
        
        y = self.Conv2(y)
        y = self.BN2(y)
        if self.Deep:
            y = tf.nn.relu(y)
            
            y = self.Conv3(y)
            y = self.BN3(y)
            
        X = self.ConvS(X)
        X = self.BNS(X)
        
        y = y + X
        y = tf.nn.relu(y)
        return y
    
        
class Stage(tf.keras.layers.Layer):
    """A group of residual blocks that share the same filter counts.

    The stage begins with one `Convolutional_Block` (projection shortcut) and
    continues with `n_layers - 1` `Identity_Block`s.

    Args:
        n_filters: filter counts handed to every block in the stage.
        n_layers: total number of blocks, including the projection block.
        double_channel: when truthy, the projection block uses stride 2
            (halving height and width); otherwise it uses stride 1.
    """
    def __init__(self, n_filters, n_layers:int, double_channel:bool=True):
        super().__init__()
        self.double_channel = double_channel
        if self.double_channel:
            self.strides = 2
        else:
            self.strides = 1
        self.n_filters = n_filters
        self.n_layers = n_layers

        self.Convolution_Block_ = Convolutional_Block(self.n_filters, self.strides)
        identity_blocks = []
        for _ in range(self.n_layers - 1):
            identity_blocks.append(Identity_Block(self.n_filters))
        self.Identity_Blocks = identity_blocks

    def call(self, X):
        """Run X through the projection block, then through each identity block in order."""
        out = self.Convolution_Block_(X)
        for block in self.Identity_Blocks:
            out = block(out)
        return out

In [3]:
class ResNet(tf.keras.models.Model):
    """ResNet classifier: conv stem -> four residual stages -> pooled Dense head.

    Args:
        n_filters: filter counts of the blocks in the first stage. Stage i
            (0-based) uses every count multiplied by 2**i. Two entries build
            basic blocks; three entries build bottleneck blocks.
        n_layers: number of residual blocks in each of the four stages (only
            the first four entries are used).
        n_labels: number of output units of the final Dense layer.
        last_activation: activation of the final Dense layer (e.g. 'softmax').
        **kwargs: forwarded to `tf.keras.models.Model` (e.g. `name`).
    """
    def __init__(self, n_filters, n_layers, n_labels:int, last_activation:str, **kwargs):
        super(ResNet, self).__init__(**kwargs)
        self.n_filters = n_filters
        # Per-stage downsampling flags (True -> stride 2 in the stage's first
        # block). The first stage always keeps the resolution produced by the
        # stem's max-pool, and only the later three stages halve it. This holds
        # for basic- and bottleneck-block models alike; keying it on
        # n_filters[-1] == 64 made basic-block models downsample in stage one too.
        self.double = [False] + [True] * 3
        self.n_layers = n_layers
        self.n_labels = n_labels
        self.last_activation = last_activation
        
        # Stem: 7x7 convolution with stride 2 (224 -> 112 for a 224x224 input).
        # The bias is dropped, presumably because BatchNormalization follows directly.
        self.Conv1 = tf.keras.layers.Conv2D(filters = 64,
                                            kernel_size = (7,7),
                                            strides = 2,
                                            padding = 'same',
                                            use_bias = False
                                           )
        self.BN1 = tf.keras.layers.BatchNormalization()
        # 'same' padding maps 112x112 to 56x56, the conv2_x resolution of the
        # ResNet paper; the previous 'valid' padding produced 55x55.
        self.MP1 = tf.keras.layers.MaxPool2D(pool_size = (3,3),
                                             strides = (2,2),
                                             padding = 'same'
                                            )
        self.Stages = tf.keras.Sequential()
        for i in range(4):
            self.Stages.add(Stage([x * (2 ** i) for x in self.n_filters], self.n_layers[i], self.double[i])) 
        self.Classifier = tf.keras.Sequential([
            tf.keras.layers.GlobalAveragePooling2D(),
            tf.keras.layers.Dense(self.n_labels, activation = self.last_activation)
        ])
        
    def call(self, X):
        """Stem (conv -> BN -> ReLU -> max-pool), then the four stages, then the classifier head."""
        y = self.Conv1(X)
        y = self.BN1(y)
        y = tf.nn.relu(y)
        
        y = self.MP1(y)
        y = self.Stages(y)
        
        y = self.Classifier(y)
        return y

In [4]:
# ResNet-50 configuration: bottleneck blocks, (3, 4, 6, 3) blocks per stage, 1000-way softmax head.
resnet = ResNet(n_filters=[64, 64, 256], n_layers=[3, 4, 6, 3], n_labels=1000, last_activation='softmax')

In [5]:
# Create all weights for a batch of 16 images of size 224x224 with 3 channels.
batch_input_shape = [16, 224, 224, 3]
resnet.build(batch_input_shape)

In [6]:
# Layer-by-layer output shapes and parameter counts; requires the model to be built first.
resnet.summary()

Model: "res_net"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
conv2d (Conv2D)              multiple                  9408      
_________________________________________________________________
batch_normalization (BatchNo multiple                  256       
_________________________________________________________________
max_pooling2d (MaxPooling2D) multiple                  0         
_________________________________________________________________
sequential (Sequential)      (16, 7, 7, 2048)          23577984  
_________________________________________________________________
sequential_1 (Sequential)    (16, 1000)                2049000   
Total params: 25,636,648
Trainable params: 25,583,528
Non-trainable params: 53,120
_________________________________________________________________



|<div style = "width:200px">**ModelSize**</div>|<div style = "width:200px">**n_filters**</div>|<div style = "width:200px">**n_layers**</div>|
|:-|:-:|:-:|
|ResNet18|[64,64]|[2,2,2,2]|
|ResNet34|[64,64]|[3,4,6,3]|
|ResNet50|[64,64,256]|[3,4,6,3]|
|ResNet101|[64,64,256]|[3,4,23,3]|
|ResNet152|[64,64,256]|[3,8,36,3]|