# BatchNorm

BatchNorm, short for batch normalization, is a simple layer with one task in life: keeping the activations that flow through the network under control. For each minibatch, it normalizes every channel to a mean of zero and a variance of 1. It then applies two learned parameters, a scale and a shift, which are trained along with the rest of the network.
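Here is what the layer computes in training mode. This is a minimal sketch under illustrative assumptions: a random minibatch of shape [8, 3, 4, 4], `eps=1e-5`, and the scale and shift at their usual initial values of 1 and 0. It takes the per-channel mean and variance over the batch and spatial dimensions, normalizes, and then applies the scale and shift. The result is checked against PyTorch's `F.batch_norm`.

```python
import torch
from torch.nn import functional as F

torch.manual_seed(0)
x = torch.randn(8, 3, 4, 4) * 5 + 2  # minibatch: N=8, C=3, H=W=4, deliberately off-center

gamma = torch.ones(3)   # learned scale, initialized to 1
beta = torch.zeros(3)   # learned shift, initialized to 0
eps = 1e-5              # guards against division by zero

# Per-channel statistics over the batch and spatial dimensions (biased variance)
mean = x.mean(dim=(0, 2, 3), keepdim=True)
var = x.var(dim=(0, 2, 3), unbiased=False, keepdim=True)

x_hat = (x - mean) / torch.sqrt(var + eps)                          # zero mean, unit variance
y = gamma[None, :, None, None] * x_hat + beta[None, :, None, None]  # learned scale and shift

# PyTorch's functional batch norm in training mode performs the same computation
ref = F.batch_norm(x, running_mean=None, running_var=None,
                   weight=gamma, bias=beta, training=True, eps=eps)
print(torch.allclose(y, ref, atol=1e-5))
```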

You might ask why we need this when we’ve already normalized our input with the transform. For smaller networks, BatchNorm is indeed less useful. As networks get larger, though, one layer can have a vast effect on another layer, say, 20 layers further down, because of repeated multiplication. You may then end up with either vanishing or exploding gradients, and both are fatal to the training process.
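The next example gives a rough illustration of the repeated-multiplication problem. It uses a hypothetical stack of 20 bias-free linear layers whose weights are deliberately initialized about 1.5 times too large. Without normalization, the standard deviation of the activations grows roughly like 1.5^20. With an `nn.BatchNorm1d` after each layer, it stays near 1.

This sketch tracks activations rather than gradients. Gradients flowing back through the same weights compound in the same way, and a too-small initialization shrinks the signal toward zero instead.

```python
import torch
from torch import nn

torch.manual_seed(0)
depth, width = 20, 256

def make_stack(use_bn: bool) -> nn.Sequential:
    layers = []
    for _ in range(depth):
        lin = nn.Linear(width, width, bias=False)
        # Deliberately oversized init: each layer scales the signal's std by roughly 1.5
        nn.init.normal_(lin.weight, std=1.5 / width ** 0.5)
        layers.append(lin)
        if use_bn:
            layers.append(nn.BatchNorm1d(width))
    return nn.Sequential(*layers)

x = torch.randn(64, width)  # minibatch of 64 already-normalized inputs
with torch.no_grad():
    plain = make_stack(use_bn=False)(x)
    normed = make_stack(use_bn=True)(x)  # training mode: BatchNorm uses batch statistics

print(f"std without BatchNorm: {plain.std():.1f}")
print(f"std with BatchNorm:    {normed.std():.2f}")
```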

The BatchNorm layers make sure that even if you use a model such as ResNet-152, the multiplications inside your network don’t get out of hand.

You might be wondering: if we have BatchNorm in our network, why are we normalizing the input at all in the training loop’s transformation chain? After all, shouldn’t BatchNorm do the work for us? The answer is yes, you could do that! But the network then has to discover the initial transform on its own. It takes longer to get the inputs under control, and training slows down as a result.
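To make the trade-off concrete, here is a minimal sketch that starts from a hypothetical batch of unnormalized pixel values in [0, 255]. Option 1 applies the fixed per-channel formula `(x - mean) / std` that torchvision's `transforms.Normalize` uses. For illustration, it computes the statistics from the batch itself. Option 2 skips the transform and puts an `nn.BatchNorm2d(3)` at the front of the network.

```python
import torch
from torch import nn

torch.manual_seed(0)
# Hypothetical raw image batch with pixel values in [0, 255], never normalized
raw = torch.randint(0, 256, (16, 3, 32, 32)).float()

# Option 1: fixed per-channel normalization, (x - mean) / std
mean = raw.mean(dim=(0, 2, 3), keepdim=True)
std = raw.std(dim=(0, 2, 3), keepdim=True)
fixed = (raw - mean) / std

# Option 2: no transform; a BatchNorm layer at the front does the normalizing
bn = nn.BatchNorm2d(3)
bn.train()
learned = bn(raw)

print(fixed.mean(dim=(0, 2, 3)), fixed.std(dim=(0, 2, 3)))
print(learned.mean(dim=(0, 2, 3)), learned.std(dim=(0, 2, 3)))
# The running estimates used in eval mode have only moved 10% of the way (momentum 0.1)
print(bn.running_mean, mean.flatten())
```

In training mode, both outputs end up with roughly zero mean and unit standard deviation per channel. The catch lies elsewhere. The layer's scale and shift still have to be learned, and in eval mode the layer normalizes with `running_mean` and `running_var`. These start at 0 and 1 and, under the default momentum of 0.1, move only a tenth of the way toward each batch's statistics. After one batch, `running_mean` is about a tenth of the actual pixel mean.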

```python
from torch import nn
from torch.nn import functional as F
import torch
from utils.metrics import calculate_acc  # project-local helper module, not part of PyTorch
```

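```python
class BnLayer(nn.Module):
    """Conv -> ReLU -> hand-written batch normalization with a learned scale and shift."""
    def __init__(self, ni, nf, stride=2, kernel_size=3, eps=1e-5):
        super().__init__()
        self.conv = nn.Conv2d(ni, nf, kernel_size=kernel_size, stride=stride, bias=False, padding=1)
        self.eps = eps  # keeps a dead (all-zero after ReLU) channel from dividing by zero

        # One learned shift (a) and scale (m) per output channel, broadcast over H and W
        self.a = nn.Parameter(torch.zeros(nf, 1, 1))
        self.m = nn.Parameter(torch.ones(nf, 1, 1))

    def forward(self, x):
        x = F.relu(self.conv(x))
        # Channels first, then flatten batch and spatial dims: one row per channel
        x_chan = x.transpose(0, 1).contiguous().view(x.size(1), -1)
        # x_chan -> torch.Size([3, 12544]) for a [1, 3, 224, 224] input (conv output is 112x112)

        if self.training:
            # Per-channel statistics of the current minibatch, shaped [nf, 1, 1] for broadcasting.
            # In eval mode, the statistics from the last training-mode batch are reused.
            self.means = x_chan.mean(1)[:, None, None]
            self.stds = x_chan.std(1)[:, None, None]

        return (x - self.means) / (self.stds + self.eps) * self.m + self.a
```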
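`BnLayer` stores only the statistics of its most recent training-mode batch. In `eval()` mode, every input is therefore normalized with whichever batch happened to come last, and calling the layer in eval mode before any training-mode pass raises an `AttributeError`. PyTorch's built-in `nn.BatchNorm2d` instead keeps exponential running averages and uses them at eval time. The sketch below shows this with hypothetical data of mean 5 and standard deviation 2.

```python
import torch
from torch import nn

torch.manual_seed(0)
bn = nn.BatchNorm2d(3, momentum=0.1)  # 0.1 is also PyTorch's default momentum

with torch.no_grad():
    # Training mode: each batch is normalized with its own statistics,
    # and the running estimates are nudged 10% of the way toward them
    bn.train()
    for _ in range(200):
        batch = torch.randn(32, 3, 8, 8) * 2 + 5  # hypothetical data: mean 5, std 2
        bn(batch)

    print(bn.running_mean)  # each channel close to 5
    print(bn.running_var)   # each channel close to 4

    # Eval mode: the running estimates replace batch statistics, so even a single
    # example is normalized consistently, regardless of what else is in its batch
    bn.eval()
    single = torch.full((1, 3, 8, 8), 5.0)
    out = bn(single)
    print(out.abs().max())  # small, because 5 is approximately the running mean
```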

```python
k = torch.rand(1, 3, 224, 224)
```
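The shape in the `x_chan` comment can be checked directly. This standalone sketch uses a fresh `nn.Conv2d` with the same settings as `BnLayer(3, 3)`. A 3×3 kernel with stride 2 and padding 1 gives an output size of floor((224 + 2·1 − 3) / 2) + 1 = 112 in each spatial dimension, so every channel row of `x_chan` holds 112 · 112 = 12544 values. Averaging each row is the same as averaging over the batch, height and width dimensions.

```python
import torch
from torch import nn

k = torch.rand(1, 3, 224, 224)
conv = nn.Conv2d(3, 3, kernel_size=3, stride=2, padding=1, bias=False)  # same settings as BnLayer(3, 3)

# Output size per spatial dim: floor((224 + 2*1 - 3) / 2) + 1
out_size = (224 + 2 * 1 - 3) // 2 + 1
x = conv(k)
print(out_size, x.shape)  # 112 torch.Size([1, 3, 112, 112])

# BnLayer's reshape: channels first, then batch and spatial dims flattened together
x_chan = x.transpose(0, 1).contiguous().view(x.size(1), -1)
print(x_chan.shape)  # torch.Size([3, 12544])

# Row means of x_chan equal per-channel means over dims (0, 2, 3)
print(torch.allclose(x_chan.mean(1), x.mean(dim=(0, 2, 3)), atol=1e-5))
```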

```python
BN = BnLayer(3, 3)
```

```python
BN(k)
```

tensor([[[[-3.1043, -0.7105, -0.1559,  ..., -1.4714, -1.1432, -2.0205],
          [-2.7250, -0.9292,  0.9126,  ...,  1.2126,  0.1625, -1.3427],
          [-3.2150, -0.9238,  1.9144,  ...,  0.1427, -0.3079,  0.1637],
          ...,
          [-2.1894,  0.0657, -0.3283,  ...,  0.7765, -0.2901, -0.9284],
          [-0.1304,  1.0492, -0.4338,  ...,  1.3788,  1.4572,  0.0774],
          [-2.5184,  1.5575,  0.1685,  ...,  0.4032, -0.4163,  0.3095]],

         [[ 1.9979,  1.6444,  2.2727,  ...,  2.6831,  0.7054,  3.2051],
          [-0.8301,  2.2846, -0.9196,  ..., -0.3816, -0.9196, -0.8902],
          [-0.2964,  1.3931, -0.9196,  ...,  0.1738,  0.8101, -0.9196],
          ...,
          [-0.7531, -0.3444, -0.9196,  ..., -0.9196, -0.8844, -0.9196],
          [ 2.7149,  0.0839, -0.9196,  ..., -0.1110, -0.0543, -0.9191],
          [ 0.0519,  0.0141,  0.8943,  ..., -0.9196, -0.9196, -0.9196]],

         [[ 0.0881,  2.0197, -0.0351,  ...,  1.2356,  0.4678,  1.3606],
          [-0.4601,  0.5780,  