In [5]:
import tensorflow as tf

class ResnetIdentityBlock(tf.keras.Model):
  """Residual identity block: three Conv2D/BatchNorm stages plus a skip add.

  Stages, applied in order:
    1. 1x1 Conv2D with ``filters[0]`` output channels -> BatchNorm -> ReLU
    2. ``kernel_size`` Conv2D with ``filters[1]`` channels and 'same'
       padding -> BatchNorm -> ReLU
    3. 1x1 Conv2D with ``filters[2]`` channels -> BatchNorm (no ReLU yet)
  The original input is then added to the stage output and a final ReLU is
  applied.

  Args:
    kernel_size: kernel size of the middle convolution; passed straight to
      ``tf.keras.layers.Conv2D``.
    filters: sequence of exactly three ints, the output channel counts of
      the three convolutions.

  NOTE(review): the skip addition needs the last stage's output to be
  shape-compatible with the input. Presumably this means ``filters[2]``
  equals the input's channel count (channels-last by Keras default). Confirm
  for the intended inputs.
  """

  def __init__(self, kernel_size, filters):
    super().__init__(name='')
    squeeze_ch, mid_ch, expand_ch = filters

    # Attribute names and creation order are kept unchanged on purpose:
    # auto-generated layer names and any saved weights may depend on them.
    self.conv2a = tf.keras.layers.Conv2D(squeeze_ch, (1, 1))
    self.bn2a = tf.keras.layers.BatchNormalization()

    self.conv2b = tf.keras.layers.Conv2D(mid_ch, kernel_size, padding='same')
    self.bn2b = tf.keras.layers.BatchNormalization()

    self.conv2c = tf.keras.layers.Conv2D(expand_ch, (1, 1))
    self.bn2c = tf.keras.layers.BatchNormalization()

  def call(self, input_tensor, training=False):
    """Run the three conv/BN stages, add the skip connection, apply ReLU.

    Args:
      input_tensor: input batch; also used unchanged as the skip branch.
      training: forwarded to every BatchNormalization layer.

    Returns:
      ``relu(stages(input_tensor) + input_tensor)``.
    """
    # (convolution, batch-norm, apply ReLU afterwards?) for each stage.
    stages = (
        (self.conv2a, self.bn2a, True),
        (self.conv2b, self.bn2b, True),
        (self.conv2c, self.bn2c, False),  # ReLU deferred until after the add
    )
    out = input_tensor
    for conv, norm, use_relu in stages:
      out = norm(conv(out), training=training)
      if use_relu:
        out = tf.nn.relu(out)
    return tf.nn.relu(out + input_tensor)


# Call the block once on a dummy all-zeros batch before summarizing it.
# Presumably this is needed so the layers have built their weights before
# summary() runs. The input's last dim (3) matches filters[2], so the skip
# addition lines up.
block = ResnetIdentityBlock(kernel_size=1, filters=[1, 2, 3])
_ = block(tf.zeros([1, 2, 3, 3]))
block.summary()

Model: "resnet_identity_block_4"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
conv2d_12 (Conv2D)           multiple                  4         
_________________________________________________________________
batch_normalization_12 (Batc multiple                  4         
_________________________________________________________________
conv2d_13 (Conv2D)           multiple                  4         
_________________________________________________________________
batch_normalization_13 (Batc multiple                  8         
_________________________________________________________________
conv2d_14 (Conv2D)           multiple                  9         
_________________________________________________________________
batch_normalization_14 (Batc multiple                  12        
Total params: 41
Trainable params: 29
Non-trainable params: 12
______________________________________________

In [None]:
import math

# Convert two angles from radians to degrees and print them space-separated.
# NOTE(review): the source/meaning of these two radian values is not shown
# here. TODO: document where they come from.
print(*map(math.degrees, (2.118172, 2.010722)))
