# Custom Layers and Models via Subclassing

In [1]:
# Import packages
import tensorflow as tf
from tensorflow.keras import layers

tf.__version__ # 2.x

'2.3.0'

## The Layer Class

One of the central abstractions in TF is the Layer class.
A layer encapsulates both a state (the layer's "weights") and a transformation from inputs to outputs (a "call", the layer's forward pass).

Weights can be created with `tf.Variable`, passing it an initial value produced by an initializer.

In [2]:
# Linear dense (without any activation) layer implementation
class Linear(tf.keras.layers.Layer):
    """Dense layer without activation: ``inputs @ w + b``.

    Both weights are created eagerly in the constructor as plain
    ``tf.Variable`` objects, so the input size must be known up front.
    """

    def __init__(self, units=32, input_dim=32):
        super().__init__()
        # Kernel: uniformly random values of shape (input_dim, units).
        kernel_values = tf.random_uniform_initializer()(
            shape=(input_dim, units), dtype="float32"
        )
        self.w = tf.Variable(initial_value=kernel_values, trainable=True)
        # Bias: zeros, one entry per output unit.
        bias_values = tf.zeros_initializer()(shape=(units,), dtype="float32")
        self.b = tf.Variable(initial_value=bias_values, trainable=True)

    def call(self, inputs):
        """Forward pass: matrix-multiply by the kernel, then add the bias."""
        projected = tf.matmul(inputs, self.w)
        return projected + self.b

In [3]:
# Instantiate with the default sizes (units=32, input_dim=32); the notebook echoes the object's repr.
Linear()

<__main__.Linear at 0x7f719c2bbb80>

You would use a layer by calling it on some input tensor, much like a Python function.

In [4]:
# Build a layer for 2 input features and apply it to a (2, 2) batch of ones.
linear_layer = Linear(input_dim=2)
y = linear_layer(tf.ones((2, 2)))

# Both input rows are identical, so both output rows are identical too.
print("Output:", y)

Output: tf.Tensor(
[[ 0.0171706   0.0980462  -0.05334902 -0.04461349  0.01196626  0.07726953
   0.00890495 -0.05230092  0.06306716 -0.01405647  0.03983072 -0.00656961
   0.06502621 -0.07911715 -0.01575241  0.03177715 -0.04458169  0.02453979
   0.02974322 -0.02328845  0.03419514 -0.0200019  -0.00210092 -0.03813468
  -0.0244395   0.01576917  0.00891878  0.05301408 -0.04351387  0.0524676
   0.00835221 -0.09076227]
 [ 0.0171706   0.0980462  -0.05334902 -0.04461349  0.01196626  0.07726953
   0.00890495 -0.05230092  0.06306716 -0.01405647  0.03983072 -0.00656961
   0.06502621 -0.07911715 -0.01575241  0.03177715 -0.04458169  0.02453979
   0.02974322 -0.02328845  0.03419514 -0.0200019  -0.00210092 -0.03813468
  -0.0244395   0.01576917  0.00891878  0.05301408 -0.04351387  0.0524676
   0.00835221 -0.09076227]], shape=(2, 32), dtype=float32)


A better way to create the weights of a layer is to use the `add_weight` method.

In [5]:
class Linear(tf.keras.layers.Layer):
    """Linear layer whose weights are registered through ``add_weight``."""

    def __init__(self, units=32, input_dim=32):
        super().__init__()
        kernel_shape = (input_dim, units)
        # Kernel first, then bias, so the weight order matches w, b.
        self.w = self.add_weight(
            shape=kernel_shape,
            initializer=tf.random_uniform_initializer(),
            trainable=True,
        )
        self.b = self.add_weight(
            shape=(units,),
            initializer=tf.zeros_initializer(),
            trainable=True,
        )

    def call(self, inputs):
        """Return ``inputs @ w + b``."""
        product = tf.matmul(inputs, self.w)
        return product + self.b

In [6]:
# Construct the add_weight-based layer with its default arguments and show its repr.
Linear()

<__main__.Linear at 0x7f7127e81940>

In [7]:
# Same usage as before: 2 input features, applied to a (2, 2) batch of ones.
linear_layer = Linear(input_dim=2)
y = linear_layer(tf.ones((2, 2)))

# The default units=32 gives 32 output values per input row.
print("Output:", y)

Output: tf.Tensor(
[[ 7.2281502e-02  2.8301883e-02 -1.1361875e-02  7.1414031e-02
  -3.5958383e-02 -8.3475634e-02  3.4645200e-05 -1.0341788e-02
   7.7435069e-02 -5.4266773e-02  7.4835859e-02 -2.7254235e-02
   7.8199305e-02 -2.4382040e-02  1.5154637e-02  5.2757360e-02
  -5.1789571e-02 -2.0522142e-02  1.8788723e-02 -1.3652921e-02
   1.1162795e-03  5.7604589e-02 -7.6549388e-03 -3.4927223e-03
  -2.8593633e-02  6.5219291e-03 -2.7138244e-02  4.5478206e-02
  -5.9064329e-02  2.7465262e-03  1.2195241e-02 -7.5143687e-02]
 [ 7.2281502e-02  2.8301883e-02 -1.1361875e-02  7.1414031e-02
  -3.5958383e-02 -8.3475634e-02  3.4645200e-05 -1.0341788e-02
   7.7435069e-02 -5.4266773e-02  7.4835859e-02 -2.7254235e-02
   7.8199305e-02 -2.4382040e-02  1.5154637e-02  5.2757360e-02
  -5.1789571e-02 -2.0522142e-02  1.8788723e-02 -1.3652921e-02
   1.1162795e-03  5.7604589e-02 -7.6549388e-03 -3.4927223e-03
  -2.8593633e-02  6.5219291e-03 -2.7138244e-02  4.5478206e-02
  -5.9064329e-02  2.7465262e-03  1.2195241e-02 -7.

Layers can have non-trainable weights as well. Set `trainable` to False for such weights.

Such weights are ignored during backpropagation, when you are training the layer.

In [8]:
class ComputeSum(tf.keras.layers.Layer):
    """Keeps a running, non-trainable sum of the inputs it is called on."""

    def __init__(self, input_dim=32):
        super().__init__()
        # Pure state: trainable=False keeps it out of the trainable weights.
        self.total = self.add_weight(
            shape=(input_dim,),
            initializer=tf.zeros_initializer(),
            trainable=False,
        )

    def call(self, inputs):
        """Add the sum of ``inputs`` over axis 0 to the accumulator and return it."""
        batch_sum = tf.reduce_sum(inputs, axis=0)
        self.total.assign_add(batch_sum)
        return self.total

In [9]:
# With the default input_dim=32 the accumulator has 32 entries; the notebook shows the repr.
ComputeSum()

<__main__.ComputeSum at 0x7f719c2b5d30>

In [10]:
# Feed a (2, 2) batch of ones; summing over axis 0 yields [2., 2.].
sum_layer = ComputeSum(input_dim=2)
y = sum_layer(tf.ones((2, 2)))

# call() returns the accumulator variable itself, so a tf.Variable is printed.
print("Output:", y)

Output: <tf.Variable 'Variable:0' shape=(2,) dtype=float32, numpy=array([2., 2.], dtype=float32)>


Like the trainable weights, it is part of `layer.weights`, but it is categorized as a non-trainable weight:

In [11]:
# The accumulator is the layer's only weight and it is non-trainable,
# so the trainable list comes back empty.
print("Weights:", sum_layer.weights)
print("Non-trainable weights:", sum_layer.non_trainable_weights)
print("Trainable_weights:", sum_layer.trainable_weights)

Weights: [<tf.Variable 'Variable:0' shape=(2,) dtype=float32, numpy=array([2., 2.], dtype=float32)>]
Non-trainable weights: [<tf.Variable 'Variable:0' shape=(2,) dtype=float32, numpy=array([2., 2.], dtype=float32)>]
Trainable_weights: []


In [12]:
# Same inspection for linear_layer (created in an earlier cell), for comparison
# with the non-trainable accumulator.
print("Weights:", linear_layer.weights)
print("Non-trainable weights:", linear_layer.non_trainable_weights)
print("Trainable_weights:", linear_layer.trainable_weights)

Weights: [<tf.Variable 'Variable:0' shape=(2, 32) dtype=float32, numpy=
array([[ 0.02798979,  0.03203218,  0.00520792,  0.04720721,  0.01008133,
        -0.04932186,  0.01481643,  0.01806558,  0.0388497 , -0.03178346,
         0.04331701, -0.03406661,  0.02953893, -0.03907155,  0.00883643,
         0.02043456, -0.04409995,  0.00569988, -0.02441934, -0.02273005,
         0.01470264,  0.01518423,  0.02338674, -0.02884055, -0.01572325,
         0.04448576,  0.00829986,  0.03743643, -0.0178116 , -0.03102672,
         0.00658418, -0.02935393],
       [ 0.04429171, -0.0037303 , -0.01656979,  0.02420682, -0.04603971,
        -0.03415378, -0.01478178, -0.02840737,  0.03858537, -0.02248331,
         0.03151885,  0.00681237,  0.04866037,  0.01468951,  0.00631821,
         0.0323228 , -0.00768962, -0.02622203,  0.04320807,  0.00907713,
        -0.01358636,  0.04242035, -0.03104168,  0.02534783, -0.01287038,
        -0.03796383, -0.03543811,  0.00804178, -0.04125273,  0.03377325,
         0.005611

In many cases, you may not know in advance the size of your inputs, and you would like to lazily create weights when that value becomes known, some time after instantiating the layer.

Use the `build(self, input_shape)` method in such scenarios.

In [13]:
class Linear(tf.keras.layers.Layer):
    """Linear layer that defers weight creation until the input size is known.

    Args:
        units: Number of output features.
    """

    def __init__(self, units=32):
        super(Linear, self).__init__()
        self.units = units

    def build(self, input_shape):
        """Create the kernel and bias from the last dimension of ``input_shape``."""
        # Initialize weights here rather than in __init__
        self.w = self.add_weight(
            shape=(input_shape[-1], self.units),
            initializer="random_uniform",
            trainable=True,
        )
        # The bias must be trainable as well: a frozen, randomly initialised
        # bias would add a fixed random offset that training can never correct.
        self.b = self.add_weight(
            shape=(self.units,), initializer="random_uniform", trainable=True
        )

    def call(self, inputs):
        """Return ``inputs @ w + b``."""
        return tf.matmul(inputs, self.w) + self.b

In [14]:
# No input size is needed at construction time; __init__ only stores `units`.
Linear()

<__main__.Linear at 0x7f7127e99070>

The __call__() method of your layer will automatically run build the first time it is called.

In [15]:
# At instantiation, we don't know on what inputs this is going to get called
linear_layer = Linear(units=32)

# The layer's weights are created dynamically the first time the layer is called;
# here build() receives an input shape whose last dimension is 10.
y = linear_layer(tf.ones((100, 10)))

print(y)

tf.Tensor(
[[ 0.04997818  0.10885429  0.09778702 ... -0.08879324 -0.0251091
   0.03213795]
 [ 0.04997818  0.10885429  0.09778702 ... -0.08879324 -0.0251091
   0.03213795]
 [ 0.04997818  0.10885429  0.09778702 ... -0.08879324 -0.0251091
   0.03213795]
 ...
 [ 0.04997818  0.10885429  0.09778702 ... -0.08879324 -0.0251091
   0.03213795]
 [ 0.04997818  0.10885429  0.09778702 ... -0.08879324 -0.0251091
   0.03213795]
 [ 0.04997818  0.10885429  0.09778702 ... -0.08879324 -0.0251091
   0.03213795]], shape=(100, 32), dtype=float32)


Layers are recursively composable:

- If you assign a Layer instance as attribute of another Layer, the outer layer will start tracking the weights of the inner layer.

We recommend creating such sublayers in the __init__() method (since the sublayers will typically have a build method, they will be built when the outer layer gets built).

In [16]:
class MLPBlock(tf.keras.layers.Layer):
    """Three stacked ``Linear`` layers with a ReLU after the first two."""

    def __init__(self):
        super().__init__()
        # Sub-layers are created here and stored as attributes of the block.
        self.linear_1 = Linear(32)
        self.linear_2 = Linear(32)
        self.linear_3 = Linear(1)

    def call(self, inputs):
        """Run ``inputs`` through the stack and return the final projection."""
        hidden = tf.nn.relu(self.linear_1(inputs))
        hidden = tf.nn.relu(self.linear_2(hidden))
        return self.linear_3(hidden)

In [17]:
# Calling the block once on a (100, 64) batch runs all three inner layers.
perceptron = MLPBlock()
y = perceptron(tf.ones(shape=(100, 64)))

# The outer layer's `weights` is printed to show what it collected from its sub-layers.
print("Output:", y)
print("Weights:", perceptron.weights)

Output: tf.Tensor(
[[-0.02349304]
 [-0.02349304]
 [-0.02349304]
 [-0.02349304]
 [-0.02349304]
 [-0.02349304]
 [-0.02349304]
 [-0.02349304]
 [-0.02349304]
 [-0.02349304]
 [-0.02349304]
 [-0.02349304]
 [-0.02349304]
 [-0.02349304]
 [-0.02349304]
 [-0.02349304]
 [-0.02349304]
 [-0.02349304]
 [-0.02349304]
 [-0.02349304]
 [-0.02349304]
 [-0.02349304]
 [-0.02349304]
 [-0.02349304]
 [-0.02349304]
 [-0.02349304]
 [-0.02349304]
 [-0.02349304]
 [-0.02349304]
 [-0.02349304]
 [-0.02349304]
 [-0.02349304]
 [-0.02349304]
 [-0.02349304]
 [-0.02349304]
 [-0.02349304]
 [-0.02349304]
 [-0.02349304]
 [-0.02349304]
 [-0.02349304]
 [-0.02349304]
 [-0.02349304]
 [-0.02349304]
 [-0.02349304]
 [-0.02349304]
 [-0.02349304]
 [-0.02349304]
 [-0.02349304]
 [-0.02349304]
 [-0.02349304]
 [-0.02349304]
 [-0.02349304]
 [-0.02349304]
 [-0.02349304]
 [-0.02349304]
 [-0.02349304]
 [-0.02349304]
 [-0.02349304]
 [-0.02349304]
 [-0.02349304]
 [-0.02349304]
 [-0.02349304]
 [-0.02349304]
 [-0.02349304]
 [-0.02349304]
 [-0.0

Layers developed in such fashion (via subclassing) are not serializable.

If you need your custom layers to be serializable, you can optionally implement a `get_config()` method.

In [18]:
class Linear(tf.keras.layers.Layer):
    """Lazily built linear layer that can be re-created from its config."""

    def __init__(self, units=32, **kwargs):
        # Extra keyword arguments are handed to the base Layer constructor.
        super().__init__(**kwargs)
        self.units = units

    def build(self, input_shape):
        """Create normally-initialised kernel and bias for ``input_shape``."""
        in_features = input_shape[-1]
        self.w = self.add_weight(
            shape=(in_features, self.units),
            initializer="random_normal",
            trainable=True,
        )
        self.b = self.add_weight(
            shape=(self.units,),
            initializer="random_normal",
            trainable=True,
        )

    def call(self, inputs):
        """Return ``inputs @ w + b``."""
        product = tf.matmul(inputs, self.w)
        return product + self.b

    def get_config(self):
        """Return the base layer config extended with this layer's ``units``."""
        base_config = super().get_config()
        base_config.update({"units": self.units})
        return base_config

In [19]:
layer = Linear(64)
print("Layers:", layer)

# get_config() returns a plain dict: the base-layer entries plus `units`.
config = layer.get_config()
print("Config:", config)

# from_config() constructs a separate Linear object from that dict.
new_layer = Linear.from_config(config)
print("New Layer from config:", new_layer)

Layers: <__main__.Linear object at 0x7f712454b5e0>
Config: {'name': 'linear_9', 'trainable': True, 'dtype': 'float32', 'units': 64}
New Layer from config: <__main__.Linear object at 0x7f7127e99730>


### Differently Behaving Layers

Some layers, in particular the `BatchNormalization` layer and the `Dropout` layer, have different behaviors during training and inference.

For such layers, it is standard practice to expose a `training` (boolean) argument in the `call()` method.

In [20]:
# Sample dropout wrapper
class CustomDropout(tf.keras.layers.Layer):
    """Dropout wrapper that is only active while training.

    Args:
        rate: Dropout rate, forwarded unchanged to ``tf.nn.dropout``.
        **kwargs: Keyword arguments passed on to the base ``Layer`` constructor.
    """

    def __init__(self, rate, **kwargs):
        super(CustomDropout, self).__init__(**kwargs)
        self.rate = rate

    def call(self, inputs, training=None):
        """Apply dropout when ``training`` is truthy; otherwise return ``inputs`` as is."""
        if training:
            return tf.nn.dropout(inputs, rate=self.rate)
        return inputs

    def get_config(self):
        """Return the base config extended with ``rate``.

        ``rate`` is a required constructor argument, so it has to be part of
        the config for ``CustomDropout.from_config(layer.get_config())`` to be
        able to call ``__init__``.
        """
        config = super(CustomDropout, self).get_config()
        config.update({"rate": self.rate})
        return config

In [21]:
class MLPBlock(tf.keras.layers.Layer):
    """Three ``Linear`` layers with ReLU activations and dropout before the output layer."""

    def __init__(self):
        super(MLPBlock, self).__init__()
        self.linear_1 = Linear(32)
        self.linear_2 = Linear(32)
        self.linear_3 = Linear(1)
        self.dp = CustomDropout(rate=0.2)

    def call(self, inputs, training=None):
        """Run the forward pass.

        Args:
            inputs: Input tensor fed to the first linear layer.
            training: Optional boolean forwarded to the dropout sub-layer.
                NOTE(review): when left as ``None`` Keras is expected to
                resolve the mode from the enclosing call, as it did before
                this argument was exposed -- confirm against the Layer docs.

        Returns:
            The output of the last linear layer.
        """
        x = self.linear_1(inputs)
        x = tf.nn.relu(x)
        x = self.linear_2(x)
        x = tf.nn.relu(x)
        # Dropout is the only mode-dependent step, so it receives the flag explicitly.
        x = self.dp(x, training=training)
        y = self.linear_3(x)
        return y

In [22]:
# Run the dropout-equipped block once on a (100, 64) batch of ones.
perceptron = MLPBlock()
y = perceptron(tf.ones(shape=(100, 64)))

# CustomDropout defines no variables of its own, so the list holds Linear weights.
print("Weights:", perceptron.weights)

Weights: [<tf.Variable 'mlp_block_1/linear_10/Variable:0' shape=(64, 32) dtype=float32, numpy=
array([[-0.02030033,  0.00489099,  0.04979159, ...,  0.00241364,
        -0.01420754,  0.01192173],
       [ 0.07205828,  0.02961976,  0.01500473, ..., -0.00526279,
        -0.05345894, -0.03853558],
       [ 0.02431008,  0.00453468, -0.03738853, ..., -0.00299568,
        -0.03084896,  0.02869447],
       ...,
       [ 0.00225633,  0.03113948,  0.00864202, ...,  0.04274569,
         0.02017399,  0.00202116],
       [ 0.08807379,  0.04904206, -0.02548023, ...,  0.03532203,
         0.00554532, -0.10428514],
       [ 0.03652779, -0.02834969, -0.04760692, ...,  0.03837006,
        -0.03646067, -0.06823771]], dtype=float32)>, <tf.Variable 'mlp_block_1/linear_10/Variable:0' shape=(32,) dtype=float32, numpy=
array([-0.0288556 , -0.05529082,  0.10748927, -0.07141435,  0.02683858,
        0.06418931, -0.00532317, -0.0473565 ,  0.03618989, -0.00610597,
       -0.03847222,  0.00894331,  0.07503864,  0.

The other privileged argument supported by `call()` is the `mask` argument.

TF automatically passes the correct boolean mask in such cases to `__call__()` for layers that support it.

## The Model Class

The Model class has the same API as Layer, with the following differences:

- It exposes built-in training, evaluation, and prediction loops (model.fit(), model.evaluate(), model.predict()).
- It exposes the list of its inner layers, via the model.layers property.
- It exposes saving and serialization APIs (save(), save_weights()...)

In general, you will use the Layer class to define inner computation blocks, and will use the Model class to define the outer model -- the object you will train.

In [23]:
class SampleModel(tf.keras.Model):
    """Two ``MLPBlock`` stages followed by a dense classification head."""

    def __init__(self, num_classes):
        super().__init__()
        self.b1 = MLPBlock()
        self.b2 = MLPBlock()
        # Dense head with no activation argument: one output per class.
        self.classifier = tf.keras.layers.Dense(num_classes)

    def call(self, model_input):
        """Chain both blocks and return the classifier output."""
        features = self.b2(self.b1(model_input))
        return self.classifier(features)

In [24]:
# Instantiate with two output classes; print() shows the default object repr.
model = SampleModel(num_classes=2)
print(model)

<__main__.SampleModel object at 0x7f7127e99700>


## Putting it all together: an end-to-end example

- A Layer encapsulates a state (created in `__init__()` or `build()`) and some computation (defined in `call()`).
- Layers can be recursively nested to create new, bigger computation blocks.
- Layers can create and track losses (typically regularization losses) as well as metrics, via add_loss() and add_metric()
- The outer container, the thing you want to train, is a Model. A Model is just like a Layer, but with added training and serialization utilities.

In [25]:
from tensorflow.keras import layers


class Sampling(layers.Layer):
    """Uses (z_mean, z_log_var) to sample z, the vector encoding a digit."""

    def call(self, inputs):
        """Return ``z_mean + exp(0.5 * z_log_var) * epsilon`` with normal noise."""
        z_mean, z_log_var = inputs
        mean_shape = tf.shape(z_mean)
        # One noise value per entry along the first two axes of z_mean.
        noise = tf.keras.backend.random_normal(shape=(mean_shape[0], mean_shape[1]))
        scale = tf.exp(0.5 * z_log_var)
        return z_mean + scale * noise


class Encoder(layers.Layer):
    """Maps MNIST digits to a triplet (z_mean, z_log_var, z)."""

    def __init__(self, latent_dim=32, intermediate_dim=64, name="encoder", **kwargs):
        super().__init__(name=name, **kwargs)
        # Shared hidden projection feeding two parallel heads.
        self.dense_proj = layers.Dense(intermediate_dim, activation="relu")
        self.dense_mean = layers.Dense(latent_dim)
        self.dense_log_var = layers.Dense(latent_dim)
        self.sampling = Sampling()

    def call(self, inputs):
        """Project the inputs, compute both heads, then sample ``z`` from them."""
        hidden = self.dense_proj(inputs)
        z_mean, z_log_var = self.dense_mean(hidden), self.dense_log_var(hidden)
        return z_mean, z_log_var, self.sampling((z_mean, z_log_var))


class Decoder(layers.Layer):
    """Converts z, the encoded digit vector, back into a readable digit."""

    def __init__(self, original_dim, intermediate_dim=64, name="decoder", **kwargs):
        super().__init__(name=name, **kwargs)
        self.dense_proj = layers.Dense(intermediate_dim, activation="relu")
        # Sigmoid output layer with `original_dim` units.
        self.dense_output = layers.Dense(original_dim, activation="sigmoid")

    def call(self, inputs):
        """Apply the hidden projection followed by the output layer."""
        hidden = self.dense_proj(inputs)
        return self.dense_output(hidden)

In [26]:
class VariationalAutoEncoder(tf.keras.Model):
    """Combines the encoder and decoder into an end-to-end model for training."""

    def __init__(
        self,
        original_dim,
        intermediate_dim=64,
        latent_dim=32,
        name="autoencoder",
        **kwargs
    ):
        super().__init__(name=name, **kwargs)
        self.original_dim = original_dim
        self.encoder = Encoder(latent_dim=latent_dim, intermediate_dim=intermediate_dim)
        self.decoder = Decoder(original_dim, intermediate_dim=intermediate_dim)

    def call(self, inputs):
        """Encode, decode, and register the KL term through ``add_loss``."""
        z_mean, z_log_var, z = self.encoder(inputs)
        reconstructed = self.decoder(z)
        # Add KL divergence regularization loss.
        kl_terms = z_log_var - tf.square(z_mean) - tf.exp(z_log_var) + 1
        self.add_loss(-0.5 * tf.reduce_mean(kl_terms))
        return reconstructed

In [27]:
# Simple training loop on MNIST data
original_dim = 784
vae = VariationalAutoEncoder(original_dim, 64, 32)

optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)
mse_loss_fn = tf.keras.losses.MeanSquaredError()

# Created once and never reset below, so the printed value is a running mean
# that carries over from one epoch to the next.
loss_metric = tf.keras.metrics.Mean()

# Only the training images are used; labels and the test split are discarded.
(x_train, _), _ = tf.keras.datasets.mnist.load_data()
# Flatten each image into a 784-vector and scale the values by 1/255.
x_train = x_train.reshape(60000, 784).astype("float32") / 255

train_dataset = tf.data.Dataset.from_tensor_slices(x_train)
train_dataset = train_dataset.shuffle(buffer_size=1024).batch(64)

epochs = 2

# Iterate over epochs.
for epoch in range(epochs):
    print("Start of epoch %d" % (epoch,))

    # Iterate over the batches of the dataset.
    for step, x_batch_train in enumerate(train_dataset):
        # The forward pass and loss are computed inside the tape so gradients can be taken.
        with tf.GradientTape() as tape:
            reconstructed = vae(x_batch_train)
            # Compute reconstruction loss
            loss = mse_loss_fn(x_batch_train, reconstructed)
            # vae.losses holds the KL term registered via add_loss() in call().
            loss += sum(vae.losses)  # Add KLD regularization loss

        grads = tape.gradient(loss, vae.trainable_weights)
        optimizer.apply_gradients(zip(grads, vae.trainable_weights))

        # Fold this step's loss into the running mean.
        loss_metric(loss)

        if step % 100 == 0:
            print("step %d: mean loss = %.4f" % (step, loss_metric.result()))

Start of epoch 0
step 0: mean loss = 0.3147
step 100: mean loss = 0.1235
step 200: mean loss = 0.0982
step 300: mean loss = 0.0885
step 400: mean loss = 0.0838
step 500: mean loss = 0.0805
step 600: mean loss = 0.0784
step 700: mean loss = 0.0769
step 800: mean loss = 0.0757
step 900: mean loss = 0.0747
Start of epoch 1
step 0: mean loss = 0.0745
step 100: mean loss = 0.0738
step 200: mean loss = 0.0733
step 300: mean loss = 0.0729
step 400: mean loss = 0.0726
step 500: mean loss = 0.0722
step 600: mean loss = 0.0719
step 700: mean loss = 0.0716
step 800: mean loss = 0.0714
step 900: mean loss = 0.0711


Also, it can be trained using the built-in training loops like any other model.

In [28]:
# Fresh, untrained model; x_train is reused from the previous cell.
vae = VariationalAutoEncoder(784, 64, 32)

optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)

# Inputs double as targets: the model is fit to reproduce x_train.
# NOTE(review): the KL term registered via add_loss() in call() is presumably
# added to the compiled MSE loss by fit() -- confirm against the Keras docs.
vae.compile(optimizer, loss=tf.keras.losses.MeanSquaredError())
vae.fit(x_train, x_train, epochs=2, batch_size=64)

Epoch 1/2
Epoch 2/2


<tensorflow.python.keras.callbacks.History at 0x7f7124530940>