# Custom Layers

In [1]:
import tensorflow as tf

  from ._conv import register_converters as _register_converters


In [2]:
# Short alias for the TF 1.x eager utilities (tf.contrib exists only in TF 1.x).
# NOTE(review): not referenced by any of the cells shown here.
tfe = tf.contrib.eager

In [3]:
# Run ops eagerly so the layer calls below return concrete tensors that print
# with their values. TF 1.x API: call it at program startup, before any
# graph/ops are created.
tf.enable_eager_execution()

In [15]:
class MyDenseLayer(tf.keras.layers.Layer):
    """Minimal fully-connected layer without bias: ``output = input @ kernel``.

    Args:
        num_outputs: Number of output units, i.e. the size of the kernel's
            last dimension.
    """

    def __init__(self, num_outputs):
        super(MyDenseLayer, self).__init__()
        self.num_outputs = num_outputs

    def build(self, input_shape):
        """Create the ``(input_dim, num_outputs)`` kernel on the first call.

        ``input_dim`` is read from the last axis of ``input_shape``.

        Raises:
            ValueError: if the last dimension of ``input_shape`` is unknown.
        """
        last_dim = tf.TensorShape(input_shape)[-1]
        # TF 1.x yields a ``Dimension`` (value in ``.value``); TF 2.x yields a
        # plain int or None. Accept both instead of hard-coding ``.value``.
        input_dim = getattr(last_dim, "value", last_dim)
        if input_dim is None:
            raise ValueError(
                "MyDenseLayer needs a known last input dimension, got shape %s"
                % (input_shape,))
        # ``add_weight`` is the supported API; ``add_variable`` is a
        # deprecated alias for it, so the created variable is the same.
        self.kernel = self.add_weight(name="kernel",
                                      shape=[input_dim, self.num_outputs])

    def call(self, input):
        """Return ``input @ kernel``.

        The input's last axis must match the kernel's first axis.
        """
        return tf.matmul(input, self.kernel)

In [16]:
# Dense layer with 10 output units; its kernel is only created in build(),
# i.e. when the layer is first called.
layer = MyDenseLayer(num_outputs=10)

In [17]:
# First call builds a (5, 10) kernel from the input's last axis. An all-zero
# input gives an all-zero output because the layer is a bias-free matmul.
print(layer(tf.zeros(shape=(10, 5))))

tf.Tensor(
[[0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]], shape=(10, 10), dtype=float32)


In [18]:
# Inspect the weights created by build(): a single (5, 10) kernel, i.e.
# (last input dim, num_outputs). No seed is set, so the values differ per run.
print(layer.variables)

[<tf.Variable 'my_dense_layer/kernel:0' shape=(5, 10) dtype=float32, numpy=
array([[ 0.22661197,  0.07442695, -0.39846274, -0.49072683,  0.14715093,
         0.06607729,  0.31235397, -0.60834205,  0.06002074, -0.26701155],
       [ 0.43488556, -0.5957981 ,  0.5126317 , -0.3249684 ,  0.39465827,
        -0.2896839 ,  0.5001288 ,  0.08098161,  0.15348208,  0.12532854],
       [ 0.14555377,  0.5120905 ,  0.39332265,  0.51066417,  0.34448308,
         0.38136417,  0.6095267 ,  0.15809715, -0.6105145 , -0.24795422],
       [ 0.4501869 ,  0.5596408 ,  0.0637238 ,  0.15206844,  0.52343315,
         0.02190334, -0.17782897,  0.42002374,  0.17947423,  0.08077306],
       [-0.52035165,  0.4257428 , -0.40859246,  0.3379144 ,  0.3591876 ,
         0.32245475, -0.4137683 , -0.34311965,  0.27081972,  0.2636705 ]],
      dtype=float32)>]


In [23]:
class ResnetIdentityBlock(tf.keras.Model):
    """ResNet identity block: three conv + batch-norm stages plus a skip connection.

    The stages are a 1x1 conv, a ``kernel_size`` conv with ``'same'`` padding,
    and another 1x1 conv, each followed by batch normalisation. ReLU is applied
    after the first two stages and after adding the input back.

    Args:
        kernel_size: Kernel size of the middle convolution.
        filters: Exactly three filter counts, one per convolution. The last
            one should match the input's channel count so the residual
            addition lines up (assumes channels-last input -- TODO confirm).
    """

    def __init__(self, kernel_size, filters):
        super(ResnetIdentityBlock, self).__init__(name='')
        reduce_filters, mid_filters, out_filters = filters

        # Sub-layers are created in stage order. Keras derives the
        # auto-generated names (conv2d, conv2d_1, ...) from creation order.
        self.conv2a = tf.keras.layers.Conv2D(reduce_filters, (1, 1))
        self.bn2a = tf.keras.layers.BatchNormalization()
        self.conv2b = tf.keras.layers.Conv2D(mid_filters, kernel_size,
                                             padding='same')
        self.bn2b = tf.keras.layers.BatchNormalization()
        self.conv2c = tf.keras.layers.Conv2D(out_filters, (1, 1))
        self.bn2c = tf.keras.layers.BatchNormalization()

    def call(self, input_tensor, training=False):
        """Run the three stages, add ``input_tensor`` back, and ReLU the sum.

        ``training`` is forwarded to every BatchNormalization layer.
        """
        stages = [(self.conv2a, self.bn2a),
                  (self.conv2b, self.bn2b),
                  (self.conv2c, self.bn2c)]
        final_idx = len(stages) - 1
        out = input_tensor
        for idx, (conv, norm) in enumerate(stages):
            out = norm(conv(out), training=training)
            if idx != final_idx:  # last stage feeds the residual add un-activated
                out = tf.nn.relu(out)
        return tf.nn.relu(out + input_tensor)
    

In [24]:
# 1x1 middle kernel; the last filter count (3) matches the 3-channel input used below.
block = ResnetIdentityBlock(kernel_size=1, filters=[1, 2, 3])

In [25]:
# Inference-mode call (training defaults to False); the output shape equals the input shape.
print(block(tf.zeros(shape=(1, 2, 3, 3))))

tf.Tensor(
[[[[0. 0. 0.]
   [0. 0. 0.]
   [0. 0. 0.]]

  [[0. 0. 0.]
   [0. 0. 0.]
   [0. 0. 0.]]]], shape=(1, 2, 3, 3), dtype=float32)


In [26]:
# List every variable the block owns. In the output, the conv/BN weights come
# first and the BN moving statistics come last.
print([var.name for var in block.variables])

['resnet_identity_block/conv2d/kernel:0', 'resnet_identity_block/conv2d/bias:0', 'resnet_identity_block/batch_normalization/gamma:0', 'resnet_identity_block/batch_normalization/beta:0', 'resnet_identity_block/conv2d_1/kernel:0', 'resnet_identity_block/conv2d_1/bias:0', 'resnet_identity_block/batch_normalization_1/gamma:0', 'resnet_identity_block/batch_normalization_1/beta:0', 'resnet_identity_block/conv2d_2/kernel:0', 'resnet_identity_block/conv2d_2/bias:0', 'resnet_identity_block/batch_normalization_2/gamma:0', 'resnet_identity_block/batch_normalization_2/beta:0', 'resnet_identity_block/batch_normalization/moving_mean:0', 'resnet_identity_block/batch_normalization/moving_variance:0', 'resnet_identity_block/batch_normalization_1/moving_mean:0', 'resnet_identity_block/batch_normalization_1/moving_variance:0', 'resnet_identity_block/batch_normalization_2/moving_mean:0', 'resnet_identity_block/batch_normalization_2/moving_variance:0']
