## Batch Normalization vs Layer Normalization

Batch Normalization (taken from Karpathy's earlier lectures)

In [None]:
import torch

class BatchNorm1d:
    """Batch normalization: normalize each feature using statistics over dim 0.

    In training mode (``self.training = True``) every feature is centred and
    scaled with the mean/variance of the current batch, and exponential
    moving averages of those statistics are accumulated in
    ``running_mean`` / ``running_var``. In eval mode the running buffers are
    used instead, so the output for one example no longer depends on the
    rest of the batch.

    Args:
        dim: number of features; sets the shape of ``gamma``, ``beta`` and
            the running buffers.
        eps: added to the variance before the square root for numerical
            stability.
        momentum: weight given to the newest batch statistic when updating
            the running buffers.

    NOTE: batch variance here uses ``torch.var``'s default (Bessel-corrected)
    estimator for normalization as well as for the running buffer;
    ``torch.nn.BatchNorm1d`` normalizes with the biased estimator.
    """

    def __init__(self, dim, eps=1e-5, momentum=0.1):
        self.eps = eps
        self.momentum = momentum
        self.training = True

        # learnable scale/shift; created without requires_grad, so enable it
        # before training if gradients are needed
        self.gamma = torch.ones(dim)
        self.beta = torch.zeros(dim)

        # running statistics used in eval mode, shape (dim,)
        self.running_mean = torch.zeros(dim)
        self.running_var = torch.ones(dim)

    def __call__(self, x):
        """Normalize ``x`` over dim 0, scale and shift, and (in training) update buffers."""
        if self.training:
            # No keepdim: the (dim,)-shaped statistic broadcasts against x the
            # same way, and the running buffers keep their (dim,) shape instead
            # of drifting to (1, dim) after the first update.
            xmean = x.mean(0)
            xvar = x.var(0)
        else:
            xmean = self.running_mean
            xvar = self.running_var

        # normalize to zero mean / unit variance, then scale and shift
        xhat = (x - xmean) / torch.sqrt(xvar + self.eps)
        out = self.gamma * xhat + self.beta

        if self.training:
            # buffers are bookkeeping, not part of the autograd graph
            with torch.no_grad():
                self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * xmean
                self.running_var = (1 - self.momentum) * self.running_var + self.momentum * xvar

        return out

    def parameters(self):
        """Return the learnable tensors ``[gamma, beta]``."""
        return [self.gamma, self.beta]



# Seed first so the random batch below is reproducible.
torch.manual_seed(1337)
module = BatchNorm1d(100)
# Normalize a batch of 32 examples, each a 100-dimensional vector.
x = module(torch.randn(32, 100))
print(x.shape)  # expected: torch.Size([32, 100])


In [None]:
# mean and std of a single feature (column 0) across the batch;
# after batch norm these should be close to 0 and 1
x[:, 0].mean(), x[:, 0].std()

In [None]:
# mean and std of a single example (row 0) across its features;
# batch norm does not normalize along this axis
x[0, :].mean(), x[0, :].std()

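Layer Normalization
- normalize each example using the mean and variance computed across its own features (not across the batch)
- because the statistics never involve other examples, there are no running mean/variance buffers to maintain and no separate training vs. inference mode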

In [None]:
import torch

class LayerNorm1d:
    """Layer normalization over the last (feature) dimension of the input.

    Each example is normalized with the mean/variance of its own features,
    so the result does not depend on the other examples in the batch and
    no running buffers or train/eval modes are needed.

    Args:
        dim: size of the feature (last) dimension; sets the shape of the
            learnable ``gamma`` / ``beta``.
        eps: added to the variance before the square root for numerical
            stability.
        momentum: unused (layer norm keeps no running statistics); accepted
            only so existing callers that pass it keep working.
    """

    def __init__(self, dim, eps=1e-5, momentum=0.1):
        self.eps = eps
        # kept for backward compatibility; not read anywhere
        self.momentum = momentum

        # learnable per-feature scale and shift
        self.gamma = torch.ones(dim)
        self.beta = torch.zeros(dim)

    def __call__(self, x):
        """Normalize ``x`` across its last dimension, then scale and shift.

        For a 2-D ``(batch, dim)`` input this is the same as normalizing over
        dim 1. Using ``-1`` also keeps it correct for inputs with extra leading
        dimensions such as ``(batch, time, dim)``, where dim 1 would be time.
        """
        xmean = x.mean(-1, keepdim=True)
        # torch.var's default (Bessel-corrected) estimator; note that
        # torch.nn.LayerNorm uses the biased estimator instead
        xvar = x.var(-1, keepdim=True)

        xhat = (x - xmean) / torch.sqrt(xvar + self.eps)
        return self.gamma * xhat + self.beta

    def parameters(self):
        """Return the learnable tensors ``[gamma, beta]``."""
        return [self.gamma, self.beta]


# Seed first so the random batch below is reproducible.
torch.manual_seed(1337)
module = LayerNorm1d(100)
# Normalize a batch of 32 examples, each a 100-dimensional vector.
x = module(torch.randn(32, 100))
print(x.shape)  # expected: torch.Size([32, 100])

In [None]:
# mean and std of a single feature (column 0) across the batch;
# layer norm does not normalize along this axis
x[:, 0].mean(), x[:, 0].std()

In [None]:
# mean and std of a single example (row 0) across its features;
# after layer norm these should be close to 0 and 1
x[0, :].mean(), x[0, :].std()