In [4]:
import torch
from torch import nn
from labml_helpers.module import Module

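- Batch normalization (BN) builds on a basic concept called **whitening**.
    - Whitening is known to speed up training and improve convergence.
    - Whitening linearly transforms the inputs so that they have zero mean and unit variance and are uncorrelated with each other.
    - But full whitening can be computationally expensive, because you need to de-correlate the inputs and the gradients must flow through the full whitening calculation.
    - BN is a cheaper simplification: it normalizes each feature (channel) independently to zero mean and unit variance and skips the de-correlation step.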

In [5]:
class BatchNorm(Module):
    """
    Batch normalization for inputs of shape `[batch_size, channels, *]`.

    Every channel is normalized with the mean and (biased) variance taken over the
    batch dimension and all trailing dimensions, then optionally scaled and shifted
    by learned per-channel parameters.

        channels: is the num of features in the input
        eps: is added to the variance for numerical stability (avoids division by zero)
        momentum: is the momentum in taking the exponential moving average
        affine: is whether to scale and shift the normalized value
        track_running_stats: is whether to calculate the moving averages of mean and variance
    """
    def __init__(self, channels: int, *, eps: float=1e-5, momentum: float=0.1, affine: bool=True, track_running_stats: bool=True):
        super().__init__()
        self.channels = channels
        self.eps = eps
        self.momentum = momentum
        self.affine = affine
        self.track_running_stats = track_running_stats

        # create params gamma (`scale`) and beta (`shift`) for scale and shift;
        # initialised to 1 and 0 so the layer starts out as a pure normalization
        if self.affine:
            self.scale = nn.Parameter(torch.ones(channels))
            self.shift = nn.Parameter(torch.zeros(channels))

        # create buffers to store exponential moving averages of mean and variance
        if self.track_running_stats:
            self.register_buffer('exp_mean', torch.zeros(channels))
            self.register_buffer('exp_var', torch.ones(channels))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
            x is a tensor of shape [batch_size, channels, *].
            * denotes any number of (possibly 0) dimensions.
            E.g., in an image (2D) convolution this will be [batch_size, channels, height, width]

            Returns the normalized tensor with the same shape as x.

            Batch statistics are used while training, or whenever running statistics
            are not tracked; otherwise the stored moving averages are used.
        """
        x_shape = x.shape
        batch_size = x_shape[0]
        assert self.channels == x_shape[1], \
            f"expected {self.channels} channels but got input with {x_shape[1]}"
        # NOTE: reshape into [batch_size, channels, n]; `reshape` (unlike `view`)
        # also accepts non-contiguous inputs such as transposed tensors
        x = x.reshape(batch_size, self.channels, -1)

        if self.training or not self.track_running_stats:

            mean = x.mean(dim=[0, 2])
            # NOTE: average the squared deviations instead of using E[x^2] - E[x]^2;
            # the latter suffers catastrophic cancellation when |mean| is large
            # relative to the spread and can even produce a negative variance
            var = ((x - mean.view(1, -1, 1)) ** 2).mean(dim=[0, 2])

            if self.training and self.track_running_stats:
                # NOTE: detach so the buffers stay plain tensors and do not keep
                # the autograd graph of every training step alive
                self.exp_mean = (1 - self.momentum) * self.exp_mean + self.momentum * mean.detach()
                self.exp_var = (1 - self.momentum) * self.exp_var + self.momentum * var.detach()

        else:
            mean = self.exp_mean
            var = self.exp_var

        x_norm = (x - mean.view(1, -1, 1)) / torch.sqrt(var + self.eps).view(1, -1, 1) # NOTE: normalize

        if self.affine:
            x_norm = self.scale.view(1, -1, 1) * x_norm + self.shift.view(1, -1, 1) # NOTE: scale and shift

        # restore the caller's original shape
        return x_norm.reshape(x_shape)

In [6]:
def _test():
    """
    Smoke test for `BatchNorm`.

    Runs an all-zero `[2, 3, 2, 4]` batch through `BatchNorm(3)`, logs the input
    shape, the output shape and the running-variance shape, and asserts that the
    output keeps the input shape and that the running variance has one entry
    per channel.
    """
    from labml.logger import inspect

    x = torch.zeros([2, 3, 2, 4])
    inspect(x.shape)
    bn = BatchNorm(3)

    y = bn(x)
    inspect(y.shape)
    inspect(bn.exp_var.shape)

    # Actually check the results instead of only printing them
    assert y.shape == x.shape, f"output shape {tuple(y.shape)} != input shape {tuple(x.shape)}"
    assert bn.exp_var.shape == (3,), f"unexpected running variance shape {tuple(bn.exp_var.shape)}"

In [None]:
# Run the BatchNorm smoke test when this cell is executed.
_test()