# Custom Batch Normalization:

In [None]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

import numpy as np
import matplotlib.pyplot as plt

class MyBatchNormalization(layers.Layer):
    """Normalize each channel with the statistics of the current batch.

    The mean and variance are computed over every axis except the last
    (channel) axis, the input is standardized with them, and a learnable
    per-channel scale (``alpha``) and shift (``beta``) are applied.

    NOTE: no moving averages are tracked, so the layer uses the statistics
    of whatever batch it is given, both during training and at inference.

    Args:
        epsilon: Small constant added to the variance before taking the
            square root, so that a channel with zero variance does not
            cause a division by zero (NaN outputs and gradients).
        **kwargs: Standard ``keras.layers.Layer`` keyword arguments
            (e.g. ``name``, ``dtype``).
    """

    def __init__(self, epsilon=1e-3, **kwargs):
        super().__init__(**kwargs)
        self.epsilon = epsilon

    def build(self, input_shape):
        """Create the per-channel scale and shift weights."""
        channels = input_shape[-1]
        if channels is None:
            raise ValueError(
                "The last (channel) dimension of the input must be defined, "
                f"got input shape {input_shape}."
            )
        # The shape must be a real tuple: ``(channels)`` is just an int.
        self.alpha = self.add_weight(
            name="alpha", shape=(channels,), initializer="ones", trainable=True
        )
        self.beta = self.add_weight(
            name="beta", shape=(channels,), initializer="zeros", trainable=True
        )

    def call(self, inputs):
        """Return the normalized, scaled and shifted ``inputs``."""
        # Reduce over all axes but the channel axis, whatever the input rank.
        axes = list(range(len(inputs.shape) - 1))
        mean = tf.math.reduce_mean(inputs, axis=axes, keepdims=True)
        variance = tf.math.reduce_variance(inputs, axis=axes, keepdims=True)
        # epsilon keeps the denominator strictly positive for constant channels.
        normalized = (inputs - mean) * tf.math.rsqrt(variance + self.epsilon)
        return self.alpha * normalized + self.beta

    def get_config(self):
        """Return the layer configuration so the layer can be re-created."""
        config = super().get_config()
        config.update({"epsilon": self.epsilon})
        return config


def define_model(input_shape=(28, 28, 1), num_classes=10, num_blocks=3,
                 convs_per_block=20, filters=32):
    """Build a deep CNN that applies MyBatchNormalization before every conv.

    The network consists of ``num_blocks`` blocks. Each block stacks
    ``convs_per_block`` pairs of (MyBatchNormalization, 3x3 ReLU Conv2D with
    "same" padding) and ends with a 3x3 max pooling. A global max pooling and
    a softmax dense layer form the classification head.

    The defaults reproduce the original hard-coded architecture
    (28x28x1 input, 3 blocks of 20 convolutions, 32 filters, 10 classes).

    Args:
        input_shape: Shape of one input sample, without the batch dimension.
        num_classes: Number of units in the softmax output layer.
        num_blocks: Number of blocks; each block ends with a 3x3 max pooling,
            so the spatial size must be large enough to be pooled that often.
        convs_per_block: Number of normalization + convolution pairs per block.
        filters: Number of filters of every convolution layer.

    Returns:
        The uncompiled ``keras.Model``. Its summary is printed as a side
        effect.
    """
    inputs = keras.Input(shape=input_shape)

    x = inputs
    for _ in range(num_blocks):
        for _ in range(convs_per_block):
            x = MyBatchNormalization()(x)
            x = layers.Conv2D(filters, 3, activation="relu", padding="same")(x)
        x = layers.MaxPooling2D(3)(x)
    x = layers.GlobalMaxPooling2D()(x)
    outputs = layers.Dense(num_classes, activation="softmax")(x)

    model = keras.Model(inputs, outputs)
    model.summary()  # show model overview
    return model






In [None]:
# Load and preprocess training data (Fashion-MNIST)
(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.fashion_mnist.load_data()
# Scale pixel values to [0, 1]. Casting to float32 first avoids the float64
# arrays that dividing integer images by 255.0 would produce.
# The explicit trailing channel axis makes the data match the model's
# (28, 28, 1) input shape instead of relying on implicit reshaping in Keras.
train_images = (train_images.astype("float32") / 255.0)[..., np.newaxis]
test_images = (test_images.astype("float32") / 255.0)[..., np.newaxis]
# One-hot encode the labels, as required by CategoricalCrossentropy.
train_labels = tf.keras.utils.to_categorical(train_labels)
test_labels = tf.keras.utils.to_categorical(test_labels)

# Define and train model
model = define_model()
model.compile(
    loss=keras.losses.CategoricalCrossentropy(),
    optimizer=keras.optimizers.Adam(),
    metrics=["accuracy"],
)
model.fit(train_images, train_labels, batch_size=64, epochs=100)


Model: "model_1"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_2 (InputLayer)        [(None, 28, 28, 1)]       0         
                                                                 
 my_batch_normalization_60 (  (None, 28, 28, 1)        2         
 MyBatchNormalization)                                           
                                                                 
 conv2d_60 (Conv2D)          (None, 28, 28, 32)        320       
                                                                 
 my_batch_normalization_61 (  (None, 28, 28, 32)       64        
 MyBatchNormalization)                                           
                                                                 
 conv2d_61 (Conv2D)          (None, 28, 28, 32)        9248      
                                                                 
 my_batch_normalization_62 (  (None, 28, 28, 32)       64  

<keras.callbacks.History at 0x7f7af97df9a0>