# Implementing Batch Normalization from Scratch

In [1]:
import torch
import torch.nn as nn


## Logic behind the batch normalization layer
* Check whether the mode is training or inference
* In inference mode, use the previously accumulated moving-average mean and variance
* In training mode, detect the type of the layer from the rank of the input
* For a fully connected layer (2-D input), calculate the mean and variance of each feature over the batch dimension
* For a conv layer (4-D input), calculate the mean and variance of each channel over the batch and spatial dimensions

In [2]:

def batch_norm(X, gamma, beta, moving_mean, moving_var, eps, momentum):
    """Apply batch normalization to ``X`` and update the running statistics.

    The mode is inferred from autograd state: when gradients are disabled
    (e.g. inside ``torch.no_grad()``) the function runs in inference mode and
    normalizes with the supplied moving statistics; otherwise it runs in
    training mode and normalizes with the statistics of the current batch.

    Args:
        X: Input tensor. In training mode it must be 2-D (statistics are
            reduced over dim 0) or 4-D (statistics are reduced over dims
            0, 2 and 3 with ``keepdim=True``, i.e. one value per index of
            dim 1).
        gamma: Scale applied to the normalized input; must broadcast
            against ``X``.
        beta: Shift added after scaling; must broadcast against ``X``.
        moving_mean: Running mean used for inference; must broadcast
            against the batch mean.
        moving_var: Running variance used for inference; must broadcast
            against the batch variance.
        eps: Small constant added to the variance before the square root
            to avoid division by zero.
        momentum: Weight given to the *previous* moving statistic; the new
            batch statistic gets ``1 - momentum``. NOTE: this is the
            opposite of the ``momentum`` convention of ``torch.nn.BatchNorm*d``.

    Returns:
        A tuple ``(Y, moving_mean, moving_var)`` where ``Y`` is the scaled
        and shifted normalized input. In inference mode the moving
        statistics are returned unchanged; in training mode they are the
        updated values, detached from the autograd graph.

    Raises:
        AssertionError: In training mode, if ``X`` is neither 2-D nor 4-D.
    """
    if not torch.is_grad_enabled():
        # Inference mode: normalize with the accumulated running statistics.
        X_hat = (X - moving_mean) / torch.sqrt(moving_var + eps)
    else:
        assert len(X.shape) in (2, 4), (
            "batch_norm expects a 2-D or 4-D input in training mode"
        )
        if len(X.shape) == 2:
            # Fully connected layer: per-feature statistics over the batch.
            mean = X.mean(dim=0)
            var = ((X - mean) ** 2).mean(dim=0)
        else:
            # 2-D convolutional layer: per-channel statistics over the batch
            # and both spatial dimensions. keepdim=True keeps the result
            # broadcastable against X.
            mean = X.mean(dim=(0, 2, 3), keepdim=True)
            var = ((X - mean) ** 2).mean(dim=(0, 2, 3), keepdim=True)
        # Training mode: standardize with the current batch statistics
        # (biased variance, i.e. divided by the number of elements).
        X_hat = (X - mean) / torch.sqrt(var + eps)
        # Update the running statistics outside of autograd. `mean` and `var`
        # are differentiable functions of X; without no_grad the returned
        # running statistics would keep a reference to this step's graph, and
        # feeding them back in on the next step would chain every iteration's
        # graph together.
        with torch.no_grad():
            moving_mean = momentum * moving_mean + (1.0 - momentum) * mean
            moving_var = momentum * moving_var + (1.0 - momentum) * var
    Y = gamma * X_hat + beta  # Scale and shift
    return Y, moving_mean, moving_var