In [1]:
import tensorflow as tf

In [2]:
class Conv(tf.keras.layers.Layer):
    """BatchNormalization -> ReLU -> Conv2D, applied in that order.

    Args:
        n_channels: number of filters of the convolution.
        kernel_size: kernel size forwarded to the convolution.
        Downsampling: when truthy the convolution uses stride (2, 2) with
            'valid' padding; otherwise stride (1, 1) with 'same' padding.
    """

    def __init__(self, n_channels, kernel_size, Downsampling = False):
        super().__init__()
        self.n_channels = n_channels
        self.kernel_size = kernel_size
        self.Downsampling = Downsampling

        # Stride and padding are chosen together from the downsampling flag.
        if self.Downsampling:
            conv_strides, conv_padding = (2, 2), 'valid'
        else:
            conv_strides, conv_padding = (1, 1), 'same'

        self.BN = tf.keras.layers.BatchNormalization()
        # No bias and no activation on the convolution itself: the
        # non-linearity is applied before it, inside `call`.
        self.Conv1 = tf.keras.layers.Conv2D(
            filters = n_channels,
            kernel_size = self.kernel_size,
            strides = conv_strides,
            padding = conv_padding,
            activation = 'linear',
            use_bias = False,
        )

    def call(self, X):
        """Return ``Conv1(relu(BN(X)))``."""
        normalized = self.BN(X)
        activated = tf.nn.relu(normalized)
        return self.Conv1(activated)
        
class ResdualUnit(tf.keras.layers.Layer):
    """Three-convolution residual unit computing ``F(X) + h(X)``.

    ``F`` is a stack of three ``Conv`` blocks with kernel sizes 1x1, 3x3 and
    1x1; only the first one may downsample. ``h`` is the shortcut branch.

    Args:
        n_channels: indexable with three entries; the filter counts of the
            three ``Conv`` blocks, in order.
        increase_channel: when truthy the shortcut is a 1x1 convolution
            producing ``n_channels[2]`` channels; otherwise the shortcut is
            the identity.
        Downsampling: when truthy the first ``Conv`` block (and the 1x1
            shortcut convolution, if present) uses stride (2, 2).

    Note:
        With ``increase_channel`` falsy the shortcut stays an identity even if
        ``Downsampling`` is truthy, so the addition in ``call`` only works
        when the output shape of ``F`` matches the shape of the input.
    """

    def __init__(self, n_channels, increase_channel = False, Downsampling = False):
        super(ResdualUnit, self).__init__()
        self.n_channels = n_channels
        self.increase_channel = increase_channel
        self.Downsampling = Downsampling
        
        # Residual branch.
        self.F = tf.keras.Sequential([
            Conv(self.n_channels[0], (1, 1), Downsampling = self.Downsampling),
            Conv(self.n_channels[1], (3, 3)),
            Conv(self.n_channels[2], (1, 1))
        ])
        if self.increase_channel:
            # Projection shortcut: matches the channel count (and the stride)
            # of the residual branch.
            self.h = tf.keras.layers.Conv2D(filters = self.n_channels[2],
                                            kernel_size = (1,1),
                                            strides = (2, 2) if self.Downsampling else (1, 1),
                                            activation = 'linear'
                                           )
        else:
            # Explicit identity shortcut. A bare `tf.keras.layers.Layer()` only
            # acts as a pass-through where the base class implements `call`
            # that way; a linear Activation is a parameter-free identity that
            # does not depend on that detail.
            self.h = tf.keras.layers.Activation('linear')
        
    def call(self, X):
        """Return the sum of the residual branch and the shortcut branch."""
        y = self.F(X)
        X = self.h(X)
        return y + X
    
class Stage(tf.keras.layers.Layer):
    """A sequence of ``ResdualUnit`` layers sharing one channel configuration.

    The first unit is always built with ``increase_channel = True`` and
    receives the ``Downsampling`` flag; it is followed by ``n_layers - 1``
    units built with the default flags.

    Args:
        n_channels: channel configuration forwarded to every ``ResdualUnit``.
        n_layers: total number of units requested for the stage.
        Downsampling: forwarded to the first unit only.
    """

    def __init__(self, n_channels, n_layers, Downsampling=True):
        super().__init__()
        self.n_channels = n_channels
        self.n_layers = n_layers
        self.Downsampling = Downsampling

        # Entry unit first, then the remaining plain units, in creation order.
        units = [
            ResdualUnit(self.n_channels, increase_channel = True, Downsampling = self.Downsampling)
        ]
        for _ in range(self.n_layers - 1):
            units.append(ResdualUnit(self.n_channels))

        self.ResidualUnits = tf.keras.Sequential(units)

    def call(self, X):
        """Run the input through every unit of the stage in order."""
        stage_output = self.ResidualUnits(X)
        return stage_output

In [3]:
class ResNet1k(tf.keras.models.Model):
    """Residual network: stem convolution, max-pool, four stages, classifier.

    The stem is a 7x7 stride-2 convolution with ReLU followed by a 3x3
    stride-2 max-pool. The four stages use the channel configurations
    [64, 64, 256], [128, 128, 512], [256, 256, 1024] and [512, 512, 2048];
    the first stage is built with ``Downsampling = False`` and the others
    with the ``Stage`` default. The classifier is global average pooling
    followed by a softmax ``Dense`` layer.

    Args:
        n_layers: number of residual units per stage. Either a single int
            used for all four stages, or a sequence of four ints (one per
            stage). Defaults to 111, the value previously hard-coded.
        n_classes: number of units of the final softmax layer. Defaults to
            1000, the value previously hard-coded.

    Raises:
        ValueError: if ``n_layers`` is a sequence whose length is not four.
    """

    def __init__(self, n_layers = 111, n_classes = 1000):
        super(ResNet1k, self).__init__()

        stage_channels = [
            [64, 64, 256],
            [128, 128, 512],
            [256, 256, 1024],
            [512, 512, 2048],
        ]
        # Normalise `n_layers` to one depth per stage.
        if isinstance(n_layers, int):
            stage_depths = (n_layers,) * len(stage_channels)
        else:
            stage_depths = tuple(n_layers)
            if len(stage_depths) != len(stage_channels):
                raise ValueError(
                    "n_layers must be an int or a sequence of %d ints, got %r"
                    % (len(stage_channels), n_layers)
                )
        
        self.Conv = tf.keras.layers.Conv2D(filters = 64,
                                            kernel_size = (7,7),
                                            strides = (2,2),
                                            activation  = 'relu'
                                           )
        self.MP = tf.keras.layers.MaxPool2D(pool_size = (3,3),
                                            strides = (2,2)
                                           )
        # Only the first stage keeps the spatial resolution; every later
        # stage is created with downsampling enabled.
        self.Stages = tf.keras.Sequential([
            Stage(channels, depth, Downsampling = (index > 0))
            for index, (channels, depth) in enumerate(zip(stage_channels, stage_depths))
        ])
        
        self.Clasifier = tf.keras.Sequential([
            tf.keras.layers.GlobalAvgPool2D(),
            tf.keras.layers.Dense(n_classes, activation = 'softmax')
        ])
        
    def call(self, X):
        """Return the softmax class scores for the input batch ``X``."""
        y = self.Conv(X)
        y = self.MP(y)
        y = self.Stages(y)
        y = self.Clasifier(y)
        return y

In [4]:
# Instantiate the network defined in the previous cell.
resnet1k = ResNet1k()

In [5]:
# Build the model for 224x224 inputs with 3 channels and an unspecified batch
# size (channels-last layout assumed -- confirm against the Keras config).
resnet1k.build(input_shape = (None, 224, 224, 3))

In [6]:
# Print each top-level layer with its output shape and parameter count.
resnet1k.summary()

Model: "res_net1k"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
conv2d (Conv2D)              multiple                  9472      
_________________________________________________________________
max_pooling2d (MaxPooling2D) multiple                  0         
_________________________________________________________________
sequential_448 (Sequential)  (None, 7, 7, 2048)        661599744 
_________________________________________________________________
sequential_449 (Sequential)  (None, 1000)              2049000   
Total params: 663,658,216
Trainable params: 662,383,464
Non-trainable params: 1,274,752
_________________________________________________________________
