# Fusing convolutional and batchnorm layers in TensorFlow

- Batchnorm layers can be fused with preceding 2D convolutional layers, into a single convolutional layer.
- Batchnorm layers *cannot* be fused with subsequent 2D convolutional layers, into a single convolutional layer, *if* that convolutional layer adds padding to its input.
- This script demos both types of fusion, and highlights how the fusion can fail in the latter case when padding is used.
- Thus, in some models implemented on specialized hardware (e.g., Edge TPU), it is desirable (necessary!) to have an API that implements a batchnorm layer, or at least the broadcasted scaling and shifting that can mimic a static batchnorm layer after training. Moreover, in the case of a sequence of (Conv2D->RELU->BN) layers, the batchnorm layer clearly cannot be fused with the convolutional layer.

Batchnorm layers scale and shift the input data, providing normalized data to the subsequent layer. In the original batchnorm [publication](https://arxiv.org/abs/1502.03167), batchnorm layers follow convolutional layers, and precede RELU layers (Conv2D->BN->RELU). Since that publication, many researchers have used batchnorm layers that follow the RELU layer (Conv2D->RELU->BN), claiming superior results.

After training, if the moving mean and moving standard deviation of the batchnorm layer are kept fixed, then the scaling and shifting parameters *might* be "fused" with a preceding or subsequent convolutional layer (by scaling and shifting the parameters of the convolutional layer), thereby allowing the explicit batchnorm layer to be removed (as it will be implicitly implemented by the convolutional layer).

For a nice summary of how this fusing can be accomplished when the batchnorm layer follows the convolutional layer, see here: [https://tehnokv.com/posts/fusing-batchnorm-and-conv/](https://tehnokv.com/posts/fusing-batchnorm-and-conv/).

*However*, if the batchnorm precedes a convolutional layer, the fusion can only be accomplished if the convolutional layer does not add any padding to the input. If padding is added, attempts to fuse the batchnorm layer will result in errors at the edges of the output tensors, compared to outputs of the unfused pair of layers. This discrepancy will be demo'd in this script.

In [None]:
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Layer, Input, Conv2D, BatchNormalization
from tensorflow.keras import Model

# %matplotlib notebook
%matplotlib inline
import matplotlib.pyplot as plt

### Define the architectures that will be used...

In [None]:
def build_conv_bn(input_shape, n_chan_out, padding='same'):
    """Build a Conv2D -> BatchNormalization model.

    Args:
        input_shape: Shape of one input sample, passed to ``Input``.
        n_chan_out: Number of filters of the 3x3 convolution.
        padding: Padding mode forwarded to ``Conv2D``.

    Returns:
        A Keras ``Model`` mapping the input to the batchnorm output.  The
        batchnorm parameters are drawn from random initializers rather than
        the Keras defaults.
    """
    inputs = Input(shape=input_shape)

    conv = Conv2D(n_chan_out, (3,3), padding=padding, name='bnconv_conv_1')
    features = conv(inputs)

    # The batchnorm layer is created only after the convolution has been
    # applied, so layers are constructed in the same order as they execute.
    batchnorm = BatchNormalization(
        axis=-1,
        name='bnconv_bn_1',
        fused=False,
        beta_initializer=tf.initializers.RandomUniform(minval=0, maxval=1),
        gamma_initializer=tf.initializers.RandomNormal(mean=0, stddev=0.5),
        moving_mean_initializer=tf.initializers.RandomNormal(mean=0, stddev=0.5),
        moving_variance_initializer=tf.initializers.RandomUniform(minval=0, maxval=1),
    )
    outputs = batchnorm(features)

    return Model(inputs, outputs)

def build_bn_conv(input_shape, n_chan_out, padding='same'):
    """Build a BatchNormalization -> Conv2D model.

    Args:
        input_shape: Shape of one input sample, passed to ``Input``.
        n_chan_out: Number of filters of the 3x3 convolution.
        padding: Padding mode forwarded to ``Conv2D``.

    Returns:
        A Keras ``Model`` mapping the input to the convolution output.  The
        batchnorm parameters are drawn from random initializers rather than
        the Keras defaults.
    """
    inputs = Input(shape=input_shape)

    batchnorm = BatchNormalization(
        axis=-1,
        name='bnconv_bn_1',
        fused=False,
        beta_initializer=tf.initializers.RandomUniform(minval=0, maxval=1),
        gamma_initializer=tf.initializers.RandomNormal(mean=0, stddev=0.5),
        moving_mean_initializer=tf.initializers.RandomNormal(mean=0, stddev=0.5),
        moving_variance_initializer=tf.initializers.RandomUniform(minval=0, maxval=1),
    )
    normalized = batchnorm(inputs)

    # The convolution is created only after the batchnorm has been applied,
    # so layers are constructed in the same order as they execute.
    conv = Conv2D(n_chan_out, (3,3), padding=padding, name='bnconv_conv_1')
    outputs = conv(normalized)

    return Model(inputs, outputs)

def build_fused(input_shape, n_chan_out, padding='same'):
    """Build the single-convolution model that receives fused parameters.

    Args:
        input_shape: Shape of one input sample, passed to ``Input``.
        n_chan_out: Number of filters of the 3x3 convolution.
        padding: Padding mode forwarded to ``Conv2D``.

    Returns:
        A Keras ``Model`` holding one ``Conv2D`` layer named
        ``'bnconv_comp_1'`` and no explicit batchnorm layer.
    """
    inputs = Input(shape=input_shape)
    composite_conv = Conv2D(n_chan_out, (3,3), padding=padding, name='bnconv_comp_1')
    return Model(inputs, composite_conv(inputs))

## Define methods for setting parameters of the fused models...

In [None]:
def _bn_scale_and_shift(layer_bn):
    """Return the per-channel ``(scale, shift)`` of a frozen batchnorm layer.

    With fixed moving statistics the layer computes ``scale * x + shift``.
    """
    weights = list(layer_bn.get_weights())
    # Keras orders the weights [gamma, beta, moving_mean, moving_variance];
    # gamma is absent when scale=False and beta is absent when center=False.
    # Layers that do not expose these flags are treated as having both.
    moving_mean, moving_variance = weights[-2], weights[-1]
    has_gamma = getattr(layer_bn, 'scale', True)
    has_beta = getattr(layer_bn, 'center', True)
    gamma = weights.pop(0) if has_gamma else np.ones_like(moving_mean)
    beta = weights.pop(0) if has_beta else np.zeros_like(moving_mean)

    scale = gamma / np.sqrt(moving_variance + layer_bn.epsilon)
    shift = beta - scale * moving_mean
    return scale, shift


def fuse_conv_bn_params(model_source, model_target):
    """Fold each batchnorm layer into the convolution that precedes it.

    Layers are paired by name: for every source layer named
    ``'bnconv_bn<suffix>'`` the source layer ``'bnconv_conv<suffix>'`` supplies
    the convolution weights, and the fused kernel and bias are written to the
    target layer ``'bnconv_comp<suffix>'``.

    This handles Conv2D -> BN only.  When the batchnorm comes *before* the
    convolution the shift term has to pass through the kernel, which needs a
    different bias computation that this function does not perform.

    Args:
        model_source: Model containing the convolution and batchnorm layers.
        model_target: Model containing the composite convolution layers.  Its
            layers are updated in place via ``set_weights([kernel, bias])``.

    Returns:
        None.
    """
    bn_prefix = 'bnconv_bn'
    for layer_bn in model_source.layers:
        if not layer_bn.name.startswith(bn_prefix):
            continue
        suffix = layer_bn.name[len(bn_prefix):]

        m_bn, b_bn = _bn_scale_and_shift(layer_bn)

        layer_conv = model_source.get_layer(name='bnconv_conv' + suffix)
        params_conv = layer_conv.get_weights()
        w_conv = params_conv[0]
        if len(params_conv) > 1:
            b_conv = params_conv[1]
        else:
            # Convolution built with use_bias=False: behave as a zero bias.
            b_conv = np.zeros(w_conv.shape[-1], dtype=w_conv.dtype)

        # BN(conv(x)) = m_bn * (W * x + b_conv) + b_bn, so the batchnorm scale
        # multiplies both the kernel and the bias, and the shift adds to the bias.
        b_comp = m_bn * b_conv + b_bn
        # Output channels are the kernel's last axis; scale each one.
        w_comp = w_conv * np.reshape(m_bn, (1, 1, 1, m_bn.size))

        layer_comp = model_target.get_layer(name='bnconv_comp' + suffix)
        layer_comp.set_weights([w_comp, b_comp])


def fuse_bn_conv_params(source_model, target_model):
    """Fold each batchnorm layer into the convolution that follows it.

    Layers are paired by name: for every source layer named
    ``'bnconv_bn<suffix>'`` the source layer ``'bnconv_conv<suffix>'`` supplies
    the convolution weights, and the fused kernel and bias are written to the
    target layer ``'bnconv_comp<suffix>'``.

    The fusion is exact only where the convolution window lies entirely inside
    the input (i.e. no padding).  With padding, the unfused model pads the
    batchnorm *output* with zeros, whereas the fused bias assumes every kernel
    tap sees the batchnorm shift, so outputs differ at the spatial edges.

    The batchnorm layer is expected to carry all four weights
    ``[gamma, beta, moving_mean, moving_variance]``.

    Args:
        source_model: Model containing the batchnorm and convolution layers.
        target_model: Model containing the composite convolution layers.  Its
            layers are updated in place via ``set_weights([kernel, bias])``.

    Returns:
        None.
    """
    bn_prefix = 'bnconv_bn'
    for layer_bn in source_model.layers:
        if not layer_bn.name.startswith(bn_prefix):
            continue
        suffix = layer_bn.name[len(bn_prefix):]

        # Frozen batchnorm is the per-channel affine map m_bn * x + b_bn.
        gamma, beta, moving_mean, moving_variance = layer_bn.get_weights()
        m_bn = gamma / np.sqrt(moving_variance + layer_bn.epsilon)
        b_bn = beta - m_bn * moving_mean

        layer_conv = source_model.get_layer(name='bnconv_conv' + suffix)
        params_conv = layer_conv.get_weights()
        w_conv = params_conv[0]
        if len(params_conv) > 1:
            b_conv = params_conv[1]
        else:
            # Convolution built with use_bias=False: behave as a zero bias.
            b_conv = np.zeros(w_conv.shape[-1], dtype=w_conv.dtype)

        # conv(m_bn * x + b_bn) = conv(m_bn * x) + conv(b_bn) + b_conv.
        # The scale acts on the kernel's input-channel axis (second to last).
        w_comp = w_conv * np.reshape(m_bn, (1, 1, m_bn.size, 1))
        # b_bn is constant over x/y, so convolving it reduces to summing the
        # kernel spatially and contracting over the input channels.
        w_conv_sum = np.sum(w_conv, axis=(0, 1))
        b_comp = b_conv + b_bn @ w_conv_sum

        layer_comp = target_model.get_layer(name='bnconv_comp' + suffix)
        layer_comp.set_weights([w_comp, b_comp])

## Build and test fused models: Convolution followed by batchnorm...

In [None]:
# Fix the TensorFlow seed so that repeated runs draw the same random numbers
tf.random.set_seed(0)

input_shape = (100, 100, 3)
n_chan_out = 16

# Draw a random input batch that is fed to both models below
n_samples = 1
x = tf.random.normal([n_samples, *input_shape])

# Build the Conv2D->BN reference model and the single-Conv2D model,
# both with randomly initialized parameters
padding = 'same' # same or valid
model_conv_bn = build_conv_bn(input_shape, n_chan_out, padding=padding)
model_fused = build_fused(input_shape, n_chan_out, padding=padding)

# Fuse the conv+batchnorm parameters into the parameters of the single conv
fuse_conv_bn_params(model_conv_bn, model_fused)

# Run the same input through both models
y_conv_bn = model_conv_bn.predict(x)
y_fused = model_fused.predict(x)

# Report the mean absolute difference between the two outputs
delta = y_conv_bn - y_fused
mae = np.abs(delta).mean()
print('Mean Absolute Error: %0.2e' % mae)

# Collapse the difference of the first sample over its channel axis,
# leaving a 2D map that can be displayed as an image
delta = delta[0].sum(axis=2)

plt.figure(figsize=(8, 8))
plt.subplot()
plt.imshow(delta)
plt.colorbar()
plt.title('Difference between Conv2D->BN layers, versus fusion into single Conv2D layer')