# Custom Layers and Models via Subclassing

In [1]:
# Import packages
import tensorflow as tf

from tensorflow.keras import layers

tf.__version__ # 2.x

'2.3.0'

## The Layer Class

One of the central abstraction in TF is the Layer class.
A layer encapsulates both a state (the layer's "weights") and a transformation from inputs to outputs (a "call", the layer's forward pass).

Weights are created using `Variable` method and then passing a initializer to it.

In [2]:
# Linear dense (without any activation) layer implementation
class Linear(tf.keras.layers.Layer):
    
    def __init__(self, units=32, input_dim=32):
        super(Linear, self).__init__()
        # Initialize weights
        w_init = tf.random_uniform_initializer()
        self.w = tf.Variable(
            initial_value=w_init(shape=(input_dim, units), dtype="float32"), trainable=True,
        )
        # Initialize bias
        b_init = tf.zeros_initializer()
        self.b = tf.Variable(
            initial_value=b_init(shape=(units,), dtype="float32"), trainable=True
        )
    
    def call(self, inputs):
        # Forward pass
        return tf.matmul(inputs, self.w) + self.b

In [3]:
Linear()

<__main__.Linear at 0x7f4b985e2790>

You would use a layer by calling it on some input tensor, much like a Python function.

In [4]:
linear_layer = Linear(input_dim=2)
y = linear_layer(tf.ones((2, 2)))

print("Output:", y)

Output: tf.Tensor(
[[ 0.01658548  0.05650462 -0.00680071  0.01357207  0.01159811 -0.04614431
   0.00019255 -0.01403505 -0.05581309 -0.03841535 -0.06635642 -0.01357264
  -0.01232324 -0.06170757 -0.08078814  0.00687404 -0.01288105 -0.00847041
  -0.00017538  0.02897899 -0.0028989   0.01902173  0.02617068  0.01602283
  -0.00087059 -0.05656181 -0.06385028 -0.06543048  0.01861737 -0.00904342
   0.04498125 -0.04483186]
 [ 0.01658548  0.05650462 -0.00680071  0.01357207  0.01159811 -0.04614431
   0.00019255 -0.01403505 -0.05581309 -0.03841535 -0.06635642 -0.01357264
  -0.01232324 -0.06170757 -0.08078814  0.00687404 -0.01288105 -0.00847041
  -0.00017538  0.02897899 -0.0028989   0.01902173  0.02617068  0.01602283
  -0.00087059 -0.05656181 -0.06385028 -0.06543048  0.01861737 -0.00904342
   0.04498125 -0.04483186]], shape=(2, 32), dtype=float32)


Better way to create weights for each layer is by using `add_weight`.

In [5]:
class Linear(tf.keras.layers.Layer):
    def __init__(self, units=32, input_dim=32):
        super(Linear, self).__init__()
        # Add weights in constructor
        self.w = self.add_weight(
            shape=(input_dim, units), initializer=tf.random_uniform_initializer(), trainable=True
        )
        self.b = self.add_weight(shape=(units,), initializer=tf.zeros_initializer(), trainable=True)

    def call(self, inputs):
        return tf.matmul(inputs, self.w) + self.b

In [6]:
Linear()

<__main__.Linear at 0x7f4b231c8a00>

In [7]:
linear_layer = Linear(input_dim=2)
y = linear_layer(tf.ones((2, 2)))

print("Output:", y)

Output: tf.Tensor(
[[-0.06047176 -0.04816655 -0.03154886 -0.06251568  0.00958148 -0.01282692
  -0.02240032  0.07374001 -0.0739876  -0.05883969  0.01308106 -0.01288871
   0.04340095  0.03366411 -0.04783353 -0.04165464  0.01183216  0.03819682
  -0.00598559  0.03709728 -0.06509309  0.00408354 -0.03959733 -0.07756019
   0.02486961 -0.04697967 -0.05941094 -0.00075205 -0.00912856  0.01380936
  -0.01735049  0.04662926]
 [-0.06047176 -0.04816655 -0.03154886 -0.06251568  0.00958148 -0.01282692
  -0.02240032  0.07374001 -0.0739876  -0.05883969  0.01308106 -0.01288871
   0.04340095  0.03366411 -0.04783353 -0.04165464  0.01183216  0.03819682
  -0.00598559  0.03709728 -0.06509309  0.00408354 -0.03959733 -0.07756019
   0.02486961 -0.04697967 -0.05941094 -0.00075205 -0.00912856  0.01380936
  -0.01735049  0.04662926]], shape=(2, 32), dtype=float32)


Layers can have non-trainable weights as well. Set `trainable` to False for such weights.

Such weights are ignored during backpropagation, when you are training the layer.

In [8]:
class ComputeSum(tf.keras.layers.Layer):
    def __init__(self, input_dim=32):
        super(ComputeSum, self).__init__()
        # Create non-trainable weights
        self.total = self.add_weight(
            shape=(input_dim,), initializer=tf.zeros_initializer(), trainable=False
        )
    
    def call(self, inputs):
        self.total.assign_add(tf.reduce_sum(inputs, axis=0))
        return self.total

In [9]:
ComputeSum()

<__main__.ComputeSum at 0x7f4b2330cca0>

In [10]:
sum_layer = ComputeSum(input_dim=2)
y = sum_layer(tf.ones((2, 2)))

print("Output:", y)

Output: <tf.Variable 'Variable:0' shape=(2,) dtype=float32, numpy=array([2., 2.], dtype=float32)>


Like trainable weights it is part of layer.weights, but it gets categorized as a non-trainable weight:

In [11]:
print("Weights:", sum_layer.weights)
print("Non-trainable weights:", sum_layer.non_trainable_weights)
print("Trainable_weights:", sum_layer.trainable_weights)

Weights: [<tf.Variable 'Variable:0' shape=(2,) dtype=float32, numpy=array([2., 2.], dtype=float32)>]
Non-trainable weights: [<tf.Variable 'Variable:0' shape=(2,) dtype=float32, numpy=array([2., 2.], dtype=float32)>]
Trainable_weights: []


In [12]:
print("Weights:", linear_layer.weights)
print("Non-trainable weights:", linear_layer.non_trainable_weights)
print("Trainable_weights:", linear_layer.trainable_weights)

Weights: [<tf.Variable 'Variable:0' shape=(2, 32) dtype=float32, numpy=
array([[-2.3279607e-02, -2.1544887e-02, -1.7158844e-02, -4.3387759e-02,
        -9.9650621e-03,  2.9152036e-03, -2.1465087e-02,  4.6870921e-02,
        -4.3387868e-02, -1.5940737e-02,  4.7146980e-02, -2.7294362e-02,
        -3.2730326e-03, -1.1640072e-02, -4.0566910e-02, -2.0389259e-02,
         7.3510781e-03,  4.2893652e-02,  2.2132624e-02, -5.2142739e-03,
        -2.0865202e-02,  3.8359765e-02, -2.0445168e-02, -3.7553620e-02,
         8.0000386e-03, -3.0260539e-02, -2.0804977e-02, -3.7090946e-02,
        -6.4313412e-05,  3.5211667e-03,  2.5303986e-02,  4.0511042e-04],
       [-3.7192155e-02, -2.6621664e-02, -1.4390014e-02, -1.9127918e-02,
         1.9546542e-02, -1.5742123e-02, -9.3523413e-04,  2.6869085e-02,
        -3.0599738e-02, -4.2898953e-02, -3.4065917e-02,  1.4405657e-02,
         4.6673987e-02,  4.5304183e-02, -7.2666183e-03, -2.1265376e-02,
         4.4810772e-03, -4.6968348e-03, -2.8118217e-02,  4.2311

In many cases, you may not know in advance the size of your inputs, and you would like to lazily create weights when that value becomes known, some time after instantiating the layer.

Use `build(self, inputs_shape)` method in such scenarios.

In [13]:
class Linear(tf.keras.layers.Layer):
    def __init__(self, units=32):
        super(Linear, self).__init__()
        self.units = units

    def build(self, input_shape):
        # Initialize weights here rather than in __init__
        self.w = self.add_weight(
            shape=(input_shape[-1], self.units),
            initializer="random_uniform",
            trainable=True,
        )
        self.b = self.add_weight(
            shape=(self.units,), initializer="random_uniform", trainable=False
        )

    def call(self, inputs):
        return tf.matmul(inputs, self.w) + self.b

In [14]:
Linear()

<__main__.Linear at 0x7f4b231e2f70>

The __call__() method of your layer will automatically run build the first time it is called.

In [15]:
# At instantiation, we don't know on what inputs this is going to get called
linear_layer = Linear(units=32)

# The layer's weights are created dynamically the first time the layer is called
y = linear_layer(tf.ones((100, 10)))

print(y)

tf.Tensor(
[[ 0.12824091 -0.00382189  0.03500267 ...  0.0037626  -0.1274958
  -0.03381208]
 [ 0.12824091 -0.00382189  0.03500267 ...  0.0037626  -0.1274958
  -0.03381208]
 [ 0.12824091 -0.00382189  0.03500267 ...  0.0037626  -0.1274958
  -0.03381208]
 ...
 [ 0.12824091 -0.00382189  0.03500267 ...  0.0037626  -0.1274958
  -0.03381208]
 [ 0.12824091 -0.00382189  0.03500267 ...  0.0037626  -0.1274958
  -0.03381208]
 [ 0.12824091 -0.00382189  0.03500267 ...  0.0037626  -0.1274958
  -0.03381208]], shape=(100, 32), dtype=float32)


Layers are recursively composable:

- If you assign a Layer instance as attribute of another Layer, the outer layer will start tracking the weights of the inner layer.

We recommend creating such sublayers in the __init__() method (since the sublayers will typically have a build method, they will be built when the outer layer gets built).

In [16]:
class MLPBlock(tf.keras.layers.Layer):
    def __init__(self):
        super(MLPBlock, self).__init__()
        self.linear_1 = Linear(32)
        self.linear_2 = Linear(32)
        self.linear_3 = Linear(1)

    def call(self, inputs):
        x = self.linear_1(inputs)
        x = tf.nn.relu(x)
        x = self.linear_2(x)
        x = tf.nn.relu(x)
        y = self.linear_3(x)
        return y

In [17]:
perceptron = MLPBlock()
y = perceptron(tf.ones(shape=(100, 64)))

print("Output:", y.shape)
print("Weights:", perceptron.weights)

Output: (100, 1)
Weights: [<tf.Variable 'mlp_block/linear_6/Variable:0' shape=(64, 32) dtype=float32, numpy=
array([[-0.04294861, -0.0325537 , -0.0386548 , ...,  0.04671087,
         0.04081596, -0.01135286],
       [-0.04653119,  0.03761765,  0.01638753, ..., -0.03778819,
         0.01035241, -0.001336  ],
       [ 0.02587212,  0.02146473, -0.02824861, ...,  0.01182736,
         0.02844635,  0.00086957],
       ...,
       [ 0.02222217,  0.00842669,  0.02813211, ...,  0.00685203,
        -0.04439139,  0.04384978],
       [ 0.03513259,  0.00355319, -0.03237949, ..., -0.0266817 ,
         0.00431416, -0.03424772],
       [-0.04397074, -0.01329132, -0.02974968, ...,  0.02046299,
         0.04075642, -0.0346027 ]], dtype=float32)>, <tf.Variable 'mlp_block/linear_7/Variable:0' shape=(32, 32) dtype=float32, numpy=
array([[-0.00786608,  0.02778982, -0.04551524, ..., -0.01123388,
        -0.04614855,  0.00516076],
       [ 0.01132749,  0.00665132, -0.04467681, ...,  0.02743456,
        -0.001

Layers developed in such fashion (via subclassing) are not serializable.

If you need your custom layers to be serializable, you can optionally implement a `get_config()` method.

In [18]:
class Linear(tf.keras.layers.Layer):
    def __init__(self, units=32, **kwargs):
        super(Linear, self).__init__(**kwargs)
        self.units = units

    def build(self, input_shape):
        self.w = self.add_weight(
            shape=(input_shape[-1], self.units),
            initializer="random_normal",
            trainable=True,
        )
        self.b = self.add_weight(
            shape=(self.units,), initializer="random_normal", trainable=True
        )

    def call(self, inputs):
        return tf.matmul(inputs, self.w) + self.b

    def get_config(self):
        config = super(Linear, self).get_config()
        config.update({"units": self.units})
        return config

In [19]:
layer = Linear(64)
print("Layers:", layer)

config = layer.get_config()
print("Config:", config)

new_layer = Linear.from_config(config)
print("New Layer from config:", new_layer)

Layers: <__main__.Linear object at 0x7f4b20037190>
Config: {'name': 'linear_9', 'trainable': True, 'dtype': 'float32', 'units': 64}
New Layer from config: <__main__.Linear object at 0x7f4b20037be0>


### Differently Behaving Layers

Some layers, in particular the `BatchNormalization` layer and the `Dropout` layer, have different behaviors during training and inference.

For such layers, it is standard practice to expose a `training` (boolean) argument in the `call()` method.

In [20]:
# Sample dropout wrapper
class CustomDropout(tf.keras.layers.Layer):
    def __init__(self, rate, **kwargs):
        super(CustomDropout, self).__init__(**kwargs)
        self.rate = rate

    def call(self, inputs, training=None):
        if training:
            return tf.nn.dropout(inputs, rate=self.rate)
        return inputs

In [21]:
class MLPBlock(tf.keras.layers.Layer):
    def __init__(self):
        super(MLPBlock, self).__init__()
        self.linear_1 = Linear(32)
        self.linear_2 = Linear(32)
        self.linear_3 = Linear(1)
        self.dp = CustomDropout(rate=0.2)

    def call(self, inputs):
        x = self.linear_1(inputs)
        x = tf.nn.relu(x)
        x = self.linear_2(x)
        x = tf.nn.relu(x)
        x = self.dp(x)
        y = self.linear_3(x)
        return y

In [22]:
perceptron = MLPBlock()
y = perceptron(tf.ones(shape=(100, 64)))

print("Weights:", perceptron.weights)

Weights: [<tf.Variable 'mlp_block_1/linear_10/Variable:0' shape=(64, 32) dtype=float32, numpy=
array([[ 0.00480832, -0.04449831, -0.01868737, ..., -0.07568541,
         0.01739467,  0.01585306],
       [ 0.07346726, -0.05165542, -0.10224849, ...,  0.0011067 ,
         0.02950659,  0.04581891],
       [ 0.06077056, -0.03540228,  0.03681133, ...,  0.06726784,
        -0.02320555, -0.00682611],
       ...,
       [ 0.02116216, -0.0135638 , -0.0766438 , ...,  0.03039101,
        -0.08805507, -0.02326634],
       [-0.00771108,  0.00853711, -0.04889629, ..., -0.00316886,
         0.02050136,  0.0009126 ],
       [ 0.0104132 , -0.06640428, -0.03596489, ..., -0.0627687 ,
         0.01927026,  0.00334143]], dtype=float32)>, <tf.Variable 'mlp_block_1/linear_10/Variable:0' shape=(32,) dtype=float32, numpy=
array([ 0.01551276, -0.09641021,  0.05902372, -0.01977053,  0.04370142,
       -0.07039469,  0.01610832, -0.0211281 , -0.0957595 ,  0.03623869,
       -0.1146992 , -0.05096183,  0.01770232,  0.

The other privileged argument supported by call() is the `mask_zero` argument.

TF automatically passes the correct boolean in suhch cases to __call__() for layers that support it.

## The Model Class

The Model class has the same API as Layer, with the following differences:

- It exposes built-in training, evaluation, and prediction loops (model.fit(), model.evaluate(), model.predict()).
- It exposes the list of its inner layers, via the model.layers property.
- It exposes saving and serialization APIs (save(), save_weights()...)

In general, you will use the Layer class to define inner computation blocks, and will use the Model class to define the outer model -- the object you will train.

In [23]:
class SampleModel(tf.keras.Model):
    
    def __init__(self, num_classes):
        super(SampleModel, self).__init__()
        self.b1 = MLPBlock()
        self.b2 = MLPBlock()
        self.classifier = tf.keras.layers.Dense(num_classes)
    
    def call(self, model_input):
        x = self.b1(model_input)
        x = self.b2(x)
        return self.classifier(x)

In [24]:
model = SampleModel(num_classes=2)
print(model)

<__main__.SampleModel object at 0x7f4b231c8790>


## Putting it all together: an end-to-end example

- A Layer encapsulate a state (created in __init__() or build()) and some computation (defined in call()).
- Layers can be recursively nested to create new, bigger computation blocks.
- Layers can create and track losses (typically regularization losses) as well as metrics, via add_loss() and add_metric()
- The outer container, the thing you want to train, is a Model. A Model is just like a Layer, but with added training and serialization utilities.

In [25]:
from tensorflow.keras import layers


class Sampling(layers.Layer):
    """Uses (z_mean, z_log_var) to sample z, the vector encoding a digit."""

    def call(self, inputs):
        z_mean, z_log_var = inputs
        batch = tf.shape(z_mean)[0]
        dim = tf.shape(z_mean)[1]
        epsilon = tf.keras.backend.random_normal(shape=(batch, dim))
        return z_mean + tf.exp(0.5 * z_log_var) * epsilon


class Encoder(layers.Layer):
    """Maps MNIST digits to a triplet (z_mean, z_log_var, z)."""

    def __init__(self, latent_dim=32, intermediate_dim=64, name="encoder", **kwargs):
        super(Encoder, self).__init__(name=name, **kwargs)
        self.dense_proj = layers.Dense(intermediate_dim, activation="relu")
        self.dense_mean = layers.Dense(latent_dim)
        self.dense_log_var = layers.Dense(latent_dim)
        self.sampling = Sampling()

    def call(self, inputs):
        x = self.dense_proj(inputs)
        z_mean = self.dense_mean(x)
        z_log_var = self.dense_log_var(x)
        z = self.sampling((z_mean, z_log_var))
        return z_mean, z_log_var, z


class Decoder(layers.Layer):
    """Converts z, the encoded digit vector, back into a readable digit."""

    def __init__(self, original_dim, intermediate_dim=64, name="decoder", **kwargs):
        super(Decoder, self).__init__(name=name, **kwargs)
        self.dense_proj = layers.Dense(intermediate_dim, activation="relu")
        self.dense_output = layers.Dense(original_dim, activation="sigmoid")

    def call(self, inputs):
        x = self.dense_proj(inputs)
        return self.dense_output(x)

In [26]:
class VariationalAutoEncoder(tf.keras.Model):
    """Combines the encoder and decoder into an end-to-end model for training."""

    def __init__(
        self,
        original_dim,
        intermediate_dim=64,
        latent_dim=32,
        name="autoencoder",
        **kwargs
    ):
        super(VariationalAutoEncoder, self).__init__(name=name, **kwargs)
        self.original_dim = original_dim
        self.encoder = Encoder(latent_dim=latent_dim, intermediate_dim=intermediate_dim)
        self.decoder = Decoder(original_dim, intermediate_dim=intermediate_dim)

    def call(self, inputs):
        z_mean, z_log_var, z = self.encoder(inputs)
        reconstructed = self.decoder(z)
        # Add KL divergence regularization loss.
        kl_loss = -0.5 * tf.reduce_mean(
            z_log_var - tf.square(z_mean) - tf.exp(z_log_var) + 1
        )
        self.add_loss(kl_loss)
        return reconstructed

In [27]:
# Simple training loop on MNIST data
original_dim = 784
vae = VariationalAutoEncoder(original_dim, 64, 32)

optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)
mse_loss_fn = tf.keras.losses.MeanSquaredError()

loss_metric = tf.keras.metrics.Mean()

(x_train, _), _ = tf.keras.datasets.mnist.load_data()
x_train = x_train.reshape(60000, 784).astype("float32") / 255

train_dataset = tf.data.Dataset.from_tensor_slices(x_train)
train_dataset = train_dataset.shuffle(buffer_size=1024).batch(64)

epochs = 2

# Iterate over epochs.
for epoch in range(epochs):
    print("Start of epoch %d" % (epoch,))

    # Iterate over the batches of the dataset.
    for step, x_batch_train in enumerate(train_dataset):
        with tf.GradientTape() as tape:
            reconstructed = vae(x_batch_train)
            # Compute reconstruction loss
            loss = mse_loss_fn(x_batch_train, reconstructed)
            loss += sum(vae.losses)  # Add KLD regularization loss

        grads = tape.gradient(loss, vae.trainable_weights)
        optimizer.apply_gradients(zip(grads, vae.trainable_weights))

        loss_metric(loss)

        if step % 100 == 0:
            print("step %d: mean loss = %.4f" % (step, loss_metric.result()))

Start of epoch 0
step 0: mean loss = 0.3371
step 100: mean loss = 0.1247
step 200: mean loss = 0.0987
step 300: mean loss = 0.0888
step 400: mean loss = 0.0840
step 500: mean loss = 0.0807
step 600: mean loss = 0.0786
step 700: mean loss = 0.0770
step 800: mean loss = 0.0758
step 900: mean loss = 0.0748
Start of epoch 1
step 0: mean loss = 0.0746
step 100: mean loss = 0.0739
step 200: mean loss = 0.0734
step 300: mean loss = 0.0730
step 400: mean loss = 0.0726
step 500: mean loss = 0.0722
step 600: mean loss = 0.0719
step 700: mean loss = 0.0716
step 800: mean loss = 0.0714
step 900: mean loss = 0.0712


Also, it can be trained using the built-in training loops like any other model.

In [28]:
vae = VariationalAutoEncoder(784, 64, 32)

optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)

vae.compile(optimizer, loss=tf.keras.losses.MeanSquaredError())
vae.fit(x_train, x_train, epochs=2, batch_size=64)

Epoch 1/2
Epoch 2/2


<tensorflow.python.keras.callbacks.History at 0x7f4b105347f0>