Copyright 2019 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

# Example of folding keras conv layer with batch norm

## Imports

In [147]:
import tensorflow as tf
import numpy as np
np.random.seed(123)

## Model with convolution and batch norm layer definition

In [148]:

epsilon=0.001
inputs = tf.keras.Input(shape=(50, 32, 5), batch_size=4)
net = inputs
net = tf.keras.layers.Conv2D(filters=2, kernel_size=(3,3))(net)
net = tf.keras.layers.BatchNormalization(epsilon=epsilon)(net)
net = tf.keras.layers.ReLU()(net)
net = tf.keras.layers.Flatten()(net)
model = tf.keras.Model(inputs, net)
model.summary()

Model: "functional_27"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_18 (InputLayer)        [(4, 50, 32, 5)]          0         
_________________________________________________________________
conv2d_16 (Conv2D)           (4, 48, 30, 2)            92        
_________________________________________________________________
batch_normalization_13 (Batc (4, 48, 30, 2)            8         
_________________________________________________________________
re_lu_1 (ReLU)               (4, 48, 30, 2)            0         
_________________________________________________________________
flatten_14 (Flatten)         (4, 2880)                 0         
Total params: 100
Trainable params: 96
Non-trainable params: 4
_________________________________________________________________


In [149]:
model.layers

[<tensorflow.python.keras.engine.input_layer.InputLayer at 0x7f76b40d9cc0>,
 <tensorflow.python.keras.layers.convolutional.Conv2D at 0x7f76b40d9dd8>,
 <tensorflow.python.keras.layers.normalization_v2.BatchNormalization at 0x7f76b40db048>,
 <tensorflow.python.keras.layers.advanced_activations.ReLU at 0x7f76b40db7b8>,
 <tensorflow.python.keras.layers.core.Flatten at 0x7f76b40db470>]

## Initialize all model weights with random numbers

In [150]:
# we will set all weights to randmom numbers so that even bias will be non zero 
# it will help to validated numerical correctness of conv and batch norm fusion
all_weights = model.get_weights()
for i in range(len(all_weights)):
  all_weights[i] = np.random.random(all_weights[i].shape)
model.set_weights(all_weights)

## Dims of conv layer weights

In [151]:
ind_conv_layer = 1
assert(isinstance(model.layers[ind_conv_layer], tf.keras.layers.Conv2D))
conv_weights = model.layers[ind_conv_layer].get_weights()
print("conv weights shape " + str(conv_weights[0].shape))
print("conv bias shape " + str(conv_weights[1].shape))

conv weights shape (3, 3, 5, 2)
conv bias shape (2,)


## Dims of batch norm layer weights

In [152]:
ind_batch_norm_layer = 2
assert(isinstance(model.layers[ind_batch_norm_layer], tf.keras.layers.BatchNormalization))
bn_weights = model.layers[ind_batch_norm_layer].get_weights()

In [153]:
gamma = bn_weights[0]
print("gamma shape " + str(gamma.shape))
betta = bn_weights[1]
print("betta shape " + str(gamma.shape))
mean = bn_weights[2]
print("mean shape " + str(gamma.shape))
variance = bn_weights[3]
print("variance shape " + str(gamma.shape))

gamma shape (2,)
betta shape (2,)
mean shape (2,)
variance shape (2,)


## Fuse conv and batch norm weights

In [135]:
new_conv_weights = np.multiply(conv_weights[0], gamma) / np.sqrt(variance + epsilon)
new_bias = betta + np.multiply((conv_weights[1] - mean), gamma) / np.sqrt(variance + epsilon)

## Model with folded/fused convolution and batch norm layers

In [136]:
inputs = tf.keras.Input(shape=(50, 32, 5), batch_size=4)
net = inputs
net = tf.keras.layers.Conv2D(filters=2, kernel_size=(3,3))(net)
net = tf.keras.layers.Flatten()(net)
model_fused = tf.keras.Model(inputs, net)
model_fused.summary()

Model: "functional_25"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_17 (InputLayer)        [(4, 50, 32, 5)]          0         
_________________________________________________________________
conv2d_15 (Conv2D)           (4, 48, 30, 2)            92        
_________________________________________________________________
flatten_13 (Flatten)         (4, 2880)                 0         
Total params: 92
Trainable params: 92
Non-trainable params: 0
_________________________________________________________________


## Initialize model_fused with fused weights

In [142]:
all_weights_fused = model_fused.get_weights()
all_weights_fused[0] = new_conv_weights
all_weights_fused[1] = new_bias
model_fused.set_weights(all_weights_fused)

## Validate that model and model_fused produce the same outputs

In [143]:
input_data = np.random.random(inputs.shape)

In [144]:
outputs = model.predict(input_data)

In [145]:
outputs_fused = model_fused.predict(input_data)

In [146]:
np.allclose(outputs, outputs_fused)

True