In [1]:
import tensorflow as tf

In [2]:
# `tf.test.is_gpu_available()` is deprecated (its own warning says so);
# ask for the visible GPU devices instead. True when at least one GPU is visible.
len(tf.config.list_physical_devices('GPU')) > 0

Instructions for updating:
Use `tf.config.list_physical_devices('GPU')` instead.


True

## Implementing custom layers

- `__init__`, where you can do all input-independent initialization
- `build`, where you know the shapes of the input tensors and can do the rest of the initialization
- `call`, where you do the forward computation

In [31]:
class MyDenseLayer(tf.keras.layers.Layer):
    """Minimal fully connected layer computing ``inputs @ kernel + bias``.

    Weights are created in ``build`` rather than ``__init__`` because the
    kernel's first dimension depends on the size of the last input axis,
    which is only known once the layer sees its input shape.

    Args:
        num_outputs: Number of output units (size of the output's last axis).
        **kwargs: Standard ``tf.keras.layers.Layer`` arguments (``name``,
            ``dtype``, ``trainable``, ...), forwarded to the base class.
    """

    def __init__(self, num_outputs, **kwargs):
        super(MyDenseLayer, self).__init__(**kwargs)
        self.num_outputs = num_outputs

    def build(self, input_shape):
        # Pass ``name`` by keyword: in Keras 3 the first positional parameter
        # of ``add_weight`` is ``shape``, so a positional name breaks there.
        self.kernel = self.add_weight(
            name='kernel',
            shape=[input_shape[-1], self.num_outputs],
            initializer='glorot_uniform',
        )
        # Start the bias at zero, as tf.keras.layers.Dense does; without an
        # explicit initializer it would be drawn at random like the kernel.
        self.bias = self.add_weight(
            name='bias',
            shape=[self.num_outputs],
            initializer='zeros',
        )
        super(MyDenseLayer, self).build(input_shape)

    def call(self, inputs):
        # The last axis of ``inputs`` must match the kernel's first dimension.
        return inputs @ self.kernel + self.bias

    def get_config(self):
        """Return the constructor config so the layer can be re-created/serialized."""
        config = super(MyDenseLayer, self).get_config()
        config.update({'num_outputs': self.num_outputs})
        return config

In [35]:
# Push a batch of 10 all-zero samples with 5 features through a 10-unit layer;
# the displayed result is the output shape.
dummy_inputs = tf.zeros([10, 5])
layer = MyDenseLayer(10)
layer(dummy_inputs).shape

TensorShape([10, 10])

## Models: Composing layers

In [38]:
class ResnetIdentityBlock(tf.keras.Model):
    """ResNet-style identity block: three conv/batch-norm stages plus a shortcut.

    The stages are a 1x1 conv, a ``kernel_size`` conv with 'same' padding and a
    final 1x1 conv, each followed by batch normalization. The first two stages
    are also followed by a ReLU. The block input is then added to the result
    and a final ReLU is applied. Because the shortcut is the unchanged input,
    the last stage must produce exactly as many channels as the input has.

    Args:
        kernel_size: Kernel size of the middle convolution.
        filters: Sequence of exactly three ints ``(filters1, filters2, filters3)``
            giving each stage's channel count; ``filters3`` must equal the
            channel count of the block's input.
        **kwargs: Forwarded to ``tf.keras.Model``. ``name`` defaults to
            'ResnetIdentityBlock'; pass a distinct ``name`` when several blocks
            live in one model so their names do not collide.
    """

    def __init__(self, kernel_size, filters, **kwargs):
        kwargs.setdefault('name', 'ResnetIdentityBlock')
        super(ResnetIdentityBlock, self).__init__(**kwargs)
        filters1, filters2, filters3 = filters

        self.conv2a = tf.keras.layers.Conv2D(filters1, (1, 1))
        self.bn2a = tf.keras.layers.BatchNormalization()

        self.conv2b = tf.keras.layers.Conv2D(filters2, kernel_size, padding='same')
        self.bn2b = tf.keras.layers.BatchNormalization()

        self.conv2c = tf.keras.layers.Conv2D(filters3, (1, 1))
        self.bn2c = tf.keras.layers.BatchNormalization()

    def call(self, inputs, training=False):
        # ``training`` switches BatchNormalization between batch statistics
        # (True) and its moving averages (False).
        x = self.bn2a(self.conv2a(inputs), training=training)
        x = tf.nn.relu(x)

        x = self.bn2b(self.conv2b(x), training=training)
        x = tf.nn.relu(x)

        x = self.bn2c(self.conv2c(x), training=training)

        # The identity shortcut only works when channel counts agree; fail
        # with a clear message instead of an opaque shape error from the add.
        in_channels = inputs.shape[-1]
        out_channels = x.shape[-1]
        if (in_channels is not None and out_channels is not None
                and in_channels != out_channels):
            raise ValueError(
                f"ResnetIdentityBlock: last stage produces {out_channels} "
                f"channels but the input has {in_channels}; filters[2] must "
                f"equal the input channel count for the identity shortcut."
            )

        x += inputs
        return tf.nn.relu(x)

In [39]:
block = ResnetIdentityBlock(
    kernel_size=(1, 1),
    # The last entry must match the input channel count for the residual add.
    filters=[64, 128, 64],
)

In [45]:
# One all-zero 500x500 input with 64 channels (channels-last, Conv2D's default).
dummy_images = tf.zeros([1, 500, 500, 64])
block(dummy_images).shape

TensorShape([1, 500, 500, 64])