In [1]:
import tensorflow as tf

In [2]:
"""
implementing custom layers
1. extend tf.keras.Layer
2. implement __init__, where you can do all 'input independent' initialization
3. implement build, where you know the shapes of the input tensors and can do the rest of the initialization
4. implement call, where you can do the forward computation
"""

"\nimplementing custom layers\n1. extend tf.keras.Layer\n2. implement __init__, where you can do all 'input independent' initialization\n3. implement build, where you know the shapes of the input tensors and can do the rest of the initialization\n4. implement call, where you can do the forward computation\n"

In [5]:
class MyLayer(tf.keras.layers.Layer):
  """Linear layer with no bias: returns ``tf.matmul(input, kernel)``.

  The kernel has shape ``[input_dim, num_outputs]``. It is created lazily in
  ``build`` once the last dimension of the input is known.

  Args:
    num_outputs: number of output units (columns of the kernel).
    **kwargs: forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``, ``dtype``).
  """

  def __init__(self, num_outputs, **kwargs):
    super(MyLayer, self).__init__(**kwargs)
    self.num_outputs = num_outputs

  def build(self, input_shape):
    """Create the ``kernel`` weight from the input's last dimension."""
    # Pass the weight name by keyword: in Keras 3 the first positional
    # parameter of add_weight is `shape`, not `name`.
    self.kernel = self.add_weight(
        name="kernel",
        shape=[int(input_shape[-1]), self.num_outputs])

  def call(self, input):
    """Return ``input @ kernel`` (last dim of the result is ``num_outputs``)."""
    return tf.matmul(input, self.kernel)

In [6]:
# Ten output units; no weights exist yet because build() has not run.
layer = MyLayer(num_outputs=10)

In [7]:
# Calling the layer for the first time triggers build() with the input's shape,
# which creates a [5, 10] kernel. A named result avoids rebinding IPython's `_`
# output-history variable.
first_output = layer(tf.zeros([10, 5]))
assert first_output.shape == (10, 10)

In [8]:
# Names of the layer's trainable weights (build() adds only the kernel).
trainable_names = list(map(lambda variable: variable.name, layer.trainable_variables))
print(trainable_names)

['my_layer/kernel:0']


In [9]:
# Models: composing layers — a tf.keras.Model subclass holds other layers as attributes and chains them in call()


In [10]:
class ResnetIdentityBlock(tf.keras.Model):
  """ResNet-style identity block: three conv/batch-norm stages plus a skip connection.

  Forward pass:
    1. 1x1 conv -> BN -> ReLU.
    2. ``kernel_size`` conv with 'same' padding -> BN -> ReLU.
    3. 1x1 conv -> BN.
    4. Add the block's input to the result, then apply a final ReLU.

  Args:
    kernel_size: kernel size of the middle convolution.
    filters: sequence of exactly three ints ``(filters1, filters2, filters3)``
      giving the output channels of the three convolutions.
  """

  def __init__(self, kernel_size, filters):
    super(ResnetIdentityBlock, self).__init__(name='')
    filters1, filters2, filters3 = filters

    self.conv2a = tf.keras.layers.Conv2D(filters1, (1, 1))
    self.bn2a = tf.keras.layers.BatchNormalization()

    # Attribute name must match the `self.conv2b` read in call().
    self.conv2b = tf.keras.layers.Conv2D(filters2, kernel_size, padding='same')
    self.bn2b = tf.keras.layers.BatchNormalization()

    self.conv2c = tf.keras.layers.Conv2D(filters3, (1, 1))
    self.bn2c = tf.keras.layers.BatchNormalization()

  def call(self, input_tensor, training=False):
    """Run the block on ``input_tensor``.

    Args:
      input_tensor: image batch accepted by Conv2D. NOTE(review): for the
        residual add below, its channel count presumably must match
        ``filters3`` (or broadcast with it) — confirm against callers.
      training: forwarded to every BatchNormalization layer.

    Returns:
      ``relu(conv/BN stack(input_tensor) + input_tensor)``.
    """
    x = self.conv2a(input_tensor)
    x = self.bn2a(x, training=training)
    x = tf.nn.relu(x)

    x = self.conv2b(x)
    x = self.bn2b(x, training=training)
    x = tf.nn.relu(x)

    x = self.conv2c(x)
    x = self.bn2c(x, training=training)

    # Skip connection: add the unmodified input back before the final ReLU.
    x += input_tensor
    return tf.nn.relu(x)

In [11]:
# NOTE(review): filters3 == 3, so the residual add presumably expects 3-channel inputs.
block = ResnetIdentityBlock(kernel_size=1, filters=[1, 2, 3])