<a href="https://colab.research.google.com/github/o-beckley/machine_learning/blob/main/resnet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import tensorflow as tf
from tensorflow.keras import layers
from functools import partial

# Conv2D factory with the defaults shared by every convolution in this file:
# 3x3 kernel, stride 1, 'same' padding, He-normal init, and no bias.
# Call sites override individual keywords (e.g. kernel_size, strides) as needed.
DefaultConv2D = partial(
    layers.Conv2D,
    kernel_size=3,
    strides=1,
    padding='same',
    kernel_initializer='he_normal',
    use_bias=False,
)
class ResidualUnit(tf.keras.layers.Layer):
  """Residual unit whose main path is re-weighted by a squeeze-and-excitation gate.

  Data flow for an input ``x``::

      z    = activation(BatchNorm(Conv3x3(x)))            # main path
      gate = sigmoid(Dense(relu(Dense(GlobalAvgPool(z))))) # one scale per channel
      skip = x                          if strides == 1
             BatchNorm(Conv1x1(x))      if strides > 1
      out  = activation(z * gate + skip)

  Args:
    filters: Number of output channels of the unit.
    strides: Stride of the main-path convolution. When greater than 1 the
      skip path uses a strided 1x1 convolution + batch norm so both paths
      have the same shape. When 1 the skip path is the identity, so the
      input must already have ``filters`` channels for the addition to work.
    activation: Anything accepted by ``tf.keras.activations.get``; applied
      after the main-path batch norm and again after the residual addition.
    **kwargs: Forwarded to ``tf.keras.layers.Layer``.
  """

  def __init__(self, filters, strides=1, activation='relu', **kwargs):
    super().__init__(**kwargs)
    self.filters = filters
    self.strides = strides
    self.activation = tf.keras.activations.get(activation)

    # Main path: convolution -> batch norm -> activation.
    self.base = [
        DefaultConv2D(filters, strides=strides),
        layers.BatchNormalization(),
        self.activation,
    ]

    # Squeeze-and-excitation gate.
    #  * The bottleneck keeps at least one unit so small `filters` values
    #    (< 16) do not produce a zero-width Dense layer.
    #  * The gate uses sigmoid, not softmax: each channel gets an independent
    #    scale in (0, 1) instead of scales that are forced to sum to 1.
    #  * The final Reshape turns (batch, filters) into (batch, 1, 1, filters)
    #    so the gate broadcasts over height and width when multiplied with
    #    the (batch, H, W, filters) feature map.
    self.se_block = [
        layers.GlobalAvgPool2D(),
        layers.Dense(max(1, filters // 16), activation='relu'),
        layers.Dense(filters, activation='sigmoid'),
        layers.Reshape((1, 1, filters)),
    ]

    # Skip path: identity unless the main path downsamples.
    self.skip_layers = []
    if strides > 1:
      self.skip_layers = [
          DefaultConv2D(filters, kernel_size=1, strides=strides),
          layers.BatchNormalization(),
      ]

  def call(self, inputs):
    """Apply the unit to ``inputs`` and return the activated residual sum."""
    z = inputs
    for layer in self.base:
      z = layer(z)

    # Per-channel scaling factors, shape (batch, 1, 1, filters).
    se = z
    for layer in self.se_block:
      se = layer(se)

    skip = inputs
    for layer in self.skip_layers:
      skip = layer(skip)

    return self.activation(z * se + skip)

  def get_config(self):
    """Return the constructor arguments so the layer can be re-created on load."""
    config = super().get_config()
    config.update({
        'filters': self.filters,
        'strides': self.strides,
        'activation': tf.keras.activations.serialize(self.activation),
    })
    return config

In [None]:
# Stem: 7x7 stride-2 convolution, batch norm, ReLU, then a 3x3 stride-2 max-pool.
model = tf.keras.Sequential()
model.add(DefaultConv2D(filters=64, kernel_size=7, strides=2, input_shape=(224, 224, 3)))
model.add(layers.BatchNormalization())
model.add(layers.Activation('relu'))
model.add(layers.MaxPool2D(pool_size=3, strides=2))

# Residual stages laid out as 3/4/6/3 units of 64/128/256/512 filters.
# The first unit of each wider stage uses strides=2 (which also switches the
# unit to a projected skip path); every other unit keeps the default stride.
prev_filters = 64
for filters in [64] * 3 + [128] * 4 + [256] * 6 + [512] * 3:
  model.add(ResidualUnit(filters, strides=1 if filters == prev_filters else 2))
  prev_filters = filters

# Classification head: global average pooling followed by a 1000-way softmax.
model.add(layers.GlobalAvgPool2D())
model.add(layers.Dense(1000, activation='softmax'))
model.summary()

Model: "sequential_8"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 conv2d_28 (Conv2D)          (None, 112, 112, 64)      9408      
                                                                 
 batch_normalization_28 (Bat  (None, 112, 112, 64)     256       
 chNormalization)                                                
                                                                 
 activation_4 (Activation)   (None, 112, 112, 64)      0         
                                                                 
 max_pooling2d_4 (MaxPooling  (None, 55, 55, 64)       0         
 2D)                                                             
                                                                 
 residual_unit_21 (ResidualU  (None, 55, 55, 64)       37700     
 nit)                                                            
                                                      

<__main__.ResidualUnit at 0x7aa822d92680>