In [1]:
from tensorflow.keras.models import Model, Sequential 
from tensorflow.keras.layers import *

In [2]:
class ResNetBlock(Layer):
    def __init__(self, out_channels, first_stride=1):
        super().__init__()

        first_padding = 'same'
        if first_stride != 1:
            first_padding = 'valid'

        self.conv_sequence = Sequential([
            Conv2D(out_channels, 3, first_stride, padding=first_padding),
            BatchNormalization(),
            ReLU(),

            Conv2D(out_channels,3,1,padding='same'),
            BatchNormalization(),
            ReLU(),
        ])

    def call(self, inputs):
        x = self.conv_sequence(inputs) 

        if x.shape == inputs.shape:
            x = x + inputs  # Skip Connection 
        
        return x 

layer = ResNetBlock(4)
print(layer)
        


<__main__.ResNetBlock object at 0x0000020BC5CC0890>


In [25]:
class ResNet(Model):
    def __init__(self):
        super(ResNet, self).__init__()

        self.conv_1 = Sequential([
            Conv2D(64,7,2),
            ReLU(),
            MaxPooling2D(3,2),
        ])

        self.resnet_chains = Sequential(
            [
                ResNetBlock(64), 
                ResNetBlock(64),
            ]
            +
            [
                ResNetBlock(128,2),
                ResNetBlock(128),
            ]
            +
            [
                ResNetBlock(256,2),
                ResNetBlock(256),
            ]
            +
            [
                ResNetBlock(512,2),
                ResNetBlock(512),
            ])
        
        self.out = Sequential([
            GlobalAveragePooling2D(),
            Dense(1, activation='sigmoid')
        ])

    def call(self,x):
        x = self.conv_1(x)
        x = self.resnet_chains(x)
        x = self.out(x)

        return x
    
model = ResNet()
print(model)

<__main__.ResNet object at 0x0000020BD117E9D0>


In [26]:
model.build(input_shape=(1, 224,224,3))

In [27]:
model.summary()

Model: "res_net_4"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 sequential_45 (Sequential)  (1, 54, 54, 64)           9472      
                                                                 
 sequential_54 (Sequential)  (1, 5, 5, 512)            11004672  
                                                                 
 sequential_55 (Sequential)  (1, 1)                    513       
                                                                 
Total params: 11014657 (42.02 MB)
Trainable params: 11006977 (41.99 MB)
Non-trainable params: 7680 (30.00 KB)
_________________________________________________________________
