In [None]:
import tensorflow as tf

### ResNet-34 for image classification

In [None]:
class ResidualUnit(tf.keras.layers.Layer):
    """Residual unit made of two 3x3 convolutions plus a skip connection.

    The main path is Conv -> BatchNorm -> activation -> Conv -> BatchNorm.
    The skip path is the identity when ``strides == 1``; otherwise it is a
    1x1 strided convolution followed by batch normalization. The outputs of
    both paths are summed and passed through the activation.

    Args:
        filters: Number of output filters of every convolution in the unit.
        strides: Stride of the first main-path convolution and of the
            skip-path 1x1 convolution. Defaults to 1.
        activation: Anything accepted by ``tf.keras.activations.get``
            (e.g. ``"relu"``). Defaults to ``"relu"``.
        **kwargs: Forwarded to ``tf.keras.layers.Layer``.
    """

    def __init__(self, filters, strides=1, activation="relu", **kwargs):
        super().__init__(**kwargs)
        self.activation = tf.keras.activations.get(activation)
        self.main_layers = [
            # Only this first convolution downsamples.
            tf.keras.layers.Conv2D(
                filters,
                3,
                strides=strides,
                padding="same",
                use_bias=False
            ),
            tf.keras.layers.BatchNormalization(),
            self.activation,
            # The second convolution must keep the resolution (stride 1);
            # striding here as well would downsample the main path twice
            # while the skip path is downsampled once, so the two tensors
            # could not be added.
            tf.keras.layers.Conv2D(
                filters,
                3,
                strides=1,
                padding="same",
                use_bias=False
            ),
            tf.keras.layers.BatchNormalization()
        ]
        # The skip path only needs layers when the main path downsamples:
        # a 1x1 strided convolution brings the input to the same shape as
        # the main-path output.
        self.skip_layers = []
        if strides > 1:
            self.skip_layers = [
                tf.keras.layers.Conv2D(
                    filters,
                    1,
                    strides=strides,
                    padding="same",
                    use_bias=False
                ),
                tf.keras.layers.BatchNormalization()
            ]

    def call(self, inputs):
        """Run both paths on ``inputs`` and return ``activation(main + skip)``."""
        Z = inputs
        for layer in self.main_layers:
            Z = layer(Z)

        skip_Z = inputs
        for layer in self.skip_layers:
            skip_Z = layer(skip_Z)

        # The original referenced an undefined name (Z_skip) here.
        return self.activation(Z + skip_Z)

In [None]:
# Build ResNet-34 as a plain stack of layers.
# `tf.models` does not exist; the Sequential class lives under `tf.keras`.
model = tf.keras.Sequential()

# Stem: 7x7 strided convolution -> batch norm -> ReLU -> 3x3 strided max-pool.
model.add(
    tf.keras.layers.Conv2D(
        64,
        7,
        strides=2,
        input_shape=[224, 224, 3],
        padding="same",
        use_bias=False
    )
)
model.add(tf.keras.layers.BatchNormalization())
# The ResNet stem applies a non-linearity between batch norm and pooling.
model.add(tf.keras.layers.Activation("relu"))
model.add(
    tf.keras.layers.MaxPool2D(
        pool_size=3,
        strides=2,
        padding="same"
    )
)

# Residual stages: 3, 4, 6 and 3 units with 64, 128, 256 and 512 filters.
# The first unit of every stage after the first sees a change in the filter
# count and downsamples with stride 2; all other units use stride 1.
prev_filters = 64
for filters in [64] * 3 + [128] * 4 + [256] * 6 + [512] * 3:
    strides = 1 if filters == prev_filters else 2
    model.add(
        ResidualUnit(
            filters,
            strides=strides
        )
    )
    prev_filters = filters

# Classification head: global average pooling followed by a 10-way softmax.
model.add(tf.keras.layers.GlobalAveragePooling2D())
# Global pooling already yields a (batch, channels) tensor, so this Flatten
# is a no-op kept from the original layer stack.
model.add(tf.keras.layers.Flatten())
model.add(
    tf.keras.layers.Dense(
        10,
        activation="softmax"
    )
)