# Batch Normalization

Training deep neural networks can be challenging, particularly when it comes to achieving convergence within a reasonable timeframe. In this section, we introduce batch normalization, a widely used technique that reliably speeds up the training process (Ioffe and Szegedy, 2015). Batch normalization has enabled practitioners to successfully train networks exceeding 100 layers in depth. An additional advantage of this method is its built-in regularization effect.

In [None]:
import torch
from torch import nn
from d2l import torch as d2l

## Implementation from Scratch

To see how batch normalization works in practice, we implement one from scratch below.
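As a reference for the code, here is what the `batch_norm` function computes in training mode for a minibatch $\mathcal{B}$. It uses the notation commonly seen for batch normalization. The symbols are our labels for the code's variables: $\epsilon$ is `eps`, and $\boldsymbol{\gamma}$ and $\boldsymbol{\beta}$ are `gamma` and `beta`. The variance is the biased estimate, matching `((X - mean) ** 2).mean(...)`:

$$\hat{\boldsymbol{\mu}}_\mathcal{B} = \frac{1}{|\mathcal{B}|}\sum_{\mathbf{x} \in \mathcal{B}} \mathbf{x}, \qquad \hat{\boldsymbol{\sigma}}_\mathcal{B}^2 = \frac{1}{|\mathcal{B}|}\sum_{\mathbf{x} \in \mathcal{B}} \left(\mathbf{x} - \hat{\boldsymbol{\mu}}_\mathcal{B}\right)^2,$$

$$\mathrm{BN}(\mathbf{x}) = \boldsymbol{\gamma} \odot \frac{\mathbf{x} - \hat{\boldsymbol{\mu}}_\mathcal{B}}{\sqrt{\hat{\boldsymbol{\sigma}}_\mathcal{B}^2 + \epsilon}} + \boldsymbol{\beta}.$$

All operations are elementwise per feature. For four-dimensional convolutional inputs, the averages run over all examples and spatial positions, separately for each channel. In prediction mode, the moving averages `moving_mean` and `moving_var` take the place of $\hat{\boldsymbol{\mu}}_\mathcal{B}$ and $\hat{\boldsymbol{\sigma}}_\mathcal{B}^2$.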

In [None]:
def batch_norm(X, gamma, beta, moving_mean, moving_var, eps, momentum):
    # Use is_grad_enabled to determine whether we are in training mode
    if not torch.is_grad_enabled():
        # In prediction mode, use mean and variance obtained by moving average
        X_hat = (X - moving_mean) / torch.sqrt(moving_var + eps)
    else:
        assert len(X.shape) in (2, 4)
        if len(X.shape) == 2:
            # Fully connected input of shape (batch, features): compute the
            # mean and variance of each feature over the batch dimension
            mean = X.mean(dim=0)
            var = ((X - mean) ** 2).mean(dim=0)
        else:
            # Two-dimensional convolutional input of shape (batch, channels,
            # height, width): compute one mean and variance per channel
            # (axis=1) by averaging over the batch and spatial axes.
            # keepdim=True preserves the shape of X so that broadcasting
            # works later
            mean = X.mean(dim=(0, 2, 3), keepdim=True)
            var = ((X - mean) ** 2).mean(dim=(0, 2, 3), keepdim=True)
        # In training mode, the current mean and variance are used
        X_hat = (X - mean) / torch.sqrt(var + eps)
        # Update the mean and variance using moving average
        moving_mean = (1.0 - momentum) * moving_mean + momentum * mean
        moving_var = (1.0 - momentum) * moving_var + momentum * var
    Y = gamma * X_hat + beta  # Scale and shift
    return Y, moving_mean.data, moving_var.data
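The following sketch checks the convolutional branch. It reproduces the training-mode normalization of `batch_norm` for a 4-D input, without `gamma` and `beta`, and compares the result with PyTorch's `torch.nn.functional.batch_norm`. The toy tensor's shape and values are arbitrary choices for illustration.

One caveat applies to the moving averages. PyTorch normalizes with the biased variance, as `batch_norm` does, but it updates its running variance with the unbiased estimate. The moving averages of the two implementations would therefore differ slightly.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
eps = 1e-5
# Toy convolutional activation: batch of 8, 3 channels, 4x4 spatial grid
X = torch.randn(8, 3, 4, 4) * 2.0 + 5.0

# Per-channel statistics: average over the batch and spatial axes (0, 2, 3),
# keeping dimensions so the results broadcast against X
mean = X.mean(dim=(0, 2, 3), keepdim=True)
var = ((X - mean) ** 2).mean(dim=(0, 2, 3), keepdim=True)  # biased variance
X_hat = (X - mean) / torch.sqrt(var + eps)

# Reference: PyTorch's functional batch norm in training mode, without
# running statistics and without the affine (gamma/beta) parameters
ref = F.batch_norm(X, None, None, training=True, eps=eps)
print(torch.allclose(X_hat, ref, atol=1e-4))  # expected: True

# Each channel of X_hat now has (approximately) zero mean and unit variance
ch_mean = X_hat.mean(dim=(0, 2, 3), keepdim=True)
print(ch_mean.flatten())
print(((X_hat - ch_mean) ** 2).mean(dim=(0, 2, 3)))
```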

We can now implement a proper `BatchNorm` layer. It maintains the trainable scale parameter `gamma` and shift parameter `beta`, both of which are learned during training.

The layer also keeps moving averages of the means and variances for later use during inference. Leaving aside the algorithmic details, note the design pattern behind the implementation. The mathematics of batch normalization is defined in a dedicated function (here `batch_norm`), which is then wrapped in a custom layer whose code mostly handles bookkeeping.

This bookkeeping includes moving data to the correct device, allocating and initializing variables, and tracking the moving averages. For simplicity, our layer requires the number of features to be specified explicitly rather than inferring it from the input shape. In practice, high-level BatchNorm APIs can infer it automatically; PyTorch's `LazyBatchNorm` layers, for example, do so.

In [None]:
class BatchNorm(nn.Module):
    # num_features: the number of outputs for a fully connected layer or the
    # number of output channels for a convolutional layer. num_dims: 2 for a
    # fully connected layer and 4 for a convolutional layer
    def __init__(self, num_features, num_dims):
        super().__init__()
        if num_dims == 2:
            shape = (1, num_features)
        else:
            shape = (1, num_features, 1, 1)
        # The scale parameter and the shift parameter (model parameters) are
        # initialized to 1 and 0, respectively
        self.gamma = nn.Parameter(torch.ones(shape))
        self.beta = nn.Parameter(torch.zeros(shape))
        # The variables that are not model parameters are initialized to 0 and
        # 1
        self.moving_mean = torch.zeros(shape)
        self.moving_var = torch.ones(shape)

    def forward(self, X):
        # If X is not on the main memory, copy moving_mean and moving_var to
        # the device where X is located
        if self.moving_mean.device != X.device:
            self.moving_mean = self.moving_mean.to(X.device)
            self.moving_var = self.moving_var.to(X.device)
        # Save the updated moving_mean and moving_var
        Y, self.moving_mean, self.moving_var = batch_norm(
            X, self.gamma, self.beta, self.moving_mean,
            self.moving_var, eps=1e-5, momentum=0.1)
        return Y
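The `BatchNorm` class above stores `moving_mean` and `moving_var` as plain tensor attributes, which is why `forward` copies them to the device of `X` by hand. Plain attributes are also absent from `state_dict()`, so these statistics would not be saved with the model.

A common PyTorch alternative is `nn.Module.register_buffer`. The toy module below contrasts the two approaches; `MovingStats` is a hypothetical name used only for illustration. Under this assumption, registering the moving statistics as buffers in `BatchNorm.__init__` would let `model.to(device)` move them and would include them in saved checkpoints. The reassignments in `forward` would still work, because assigning a tensor to a registered buffer's name updates that buffer.

```python
import torch
from torch import nn


class MovingStats(nn.Module):
    """Hypothetical toy module contrasting two ways to store running state."""

    def __init__(self, num_features):
        super().__init__()
        # Plain tensor attribute, as in BatchNorm above: not part of
        # state_dict() and not moved by .to(device)
        self.plain_mean = torch.zeros(1, num_features)
        # Registered buffer: included in state_dict() and moved by .to(device),
        # but not returned by parameters(), so an optimizer never updates it
        self.register_buffer('buffer_mean', torch.zeros(1, num_features))


m = MovingStats(4)
print(sorted(m.state_dict().keys()))  # ['buffer_mean']
print(len(list(m.parameters())))      # 0

# Reassigning a registered buffer (as BatchNorm.forward does with its moving
# statistics) replaces the buffer's value instead of creating a plain attribute
m.buffer_mean = torch.ones(1, 4)
print(m.state_dict()['buffer_mean'])
```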

We used the variable `momentum` to control how past estimates of the mean and variance are aggregated into the moving averages. The name is something of a misnomer, since it is unrelated to the momentum term used in optimization. It has nonetheless become the standard terminology, so for consistency with common API conventions we use the same variable name in our implementation.
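The update rule can be made concrete with plain Python arithmetic; the helper name `update_moving` is our own. Each step keeps a fraction $1 - \text{momentum}$ of the old estimate, so a minibatch statistic seen $k$ updates ago carries weight $\text{momentum} \cdot (1 - \text{momentum})^k$.

Take `momentum=0.1`, start from 0 as `BatchNorm` does, and suppose every minibatch mean equals 1. The moving mean after $n$ steps is then $1 - 0.9^n$, so it approaches the true value only gradually. PyTorch's built-in batch normalization layers document the same convention, with a default momentum of 0.1.

```python
def update_moving(moving, batch_stat, momentum=0.1):
    # Same rule as in batch_norm: keep (1 - momentum) of the old estimate
    # and blend in momentum times the current minibatch statistic
    return (1.0 - momentum) * moving + momentum * batch_stat


moving_mean = 0.0  # initial value, as in the BatchNorm layer
history = []
for step in range(3):
    # Pretend every minibatch has a mean of exactly 1.0
    moving_mean = update_moving(moving_mean, 1.0)
    history.append(round(moving_mean, 4))
print(history)  # [0.1, 0.19, 0.271]
```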

## LeNet with Batch Normalization

To illustrate the use of BatchNorm in practice, we incorporate it into a standard LeNet model. Note that batch normalization is placed after convolutional or fully connected layers, but before their associated activation functions.

In [None]:
class BNLeNetScratch(d2l.Classifier):
    def __init__(self, lr=0.1, num_classes=10):
        super().__init__()
        self.save_hyperparameters()
        self.net = nn.Sequential(
            nn.LazyConv2d(6, kernel_size=5), BatchNorm(6, num_dims=4),
            nn.Sigmoid(), nn.AvgPool2d(kernel_size=2, stride=2),
            nn.LazyConv2d(16, kernel_size=5), BatchNorm(16, num_dims=4),
            nn.Sigmoid(), nn.AvgPool2d(kernel_size=2, stride=2),
            nn.Flatten(), nn.LazyLinear(120),
            BatchNorm(120, num_dims=2), nn.Sigmoid(), nn.LazyLinear(84),
            BatchNorm(84, num_dims=2), nn.Sigmoid(),
            nn.LazyLinear(num_classes))

As before, we will train the network using the Fashion-MNIST dataset. The code is nearly the same as the implementation used earlier for training LeNet.

In [None]:
trainer = d2l.Trainer(max_epochs=10, num_gpus=1)
data = d2l.FashionMNIST(batch_size=128)
model = BNLeNetScratch(lr=0.1)
model.apply_init([next(iter(data.get_dataloader(True)))[0]], d2l.init_cnn)
trainer.fit(model, data)

Let’s examine the learned values of the scale parameter (gamma) and the shift parameter (beta) from the first batch normalization layer.

In [None]:
model.net[1].gamma.reshape((-1,)), model.net[1].beta.reshape((-1,))

## Concise Implementation

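Instead of using the custom `BatchNorm` class we implemented, we can directly use the batch normalization layers provided by the high-level API. In PyTorch these are `nn.LazyBatchNorm2d` for convolutional layers and `nn.LazyBatchNorm1d` for fully connected layers. The code remains almost the same, except that we no longer need to pass `num_features` or `num_dims`, because the lazy layers infer the number of features from their first input.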
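The lazy layers postpone fixing `num_features` until the first forward pass, when they read it from dimension 1 of the input. We still choose the 1d (fully connected) or 2d (convolutional) variant ourselves. The quick check below uses input shapes chosen to match LeNet's first convolutional layer and first fully connected layer; the batch size of 4 is arbitrary. Unlike our `(1, C, 1, 1)`-shaped `gamma`, the built-in layer stores its scale as a 1-D `weight` of length `num_features`.

```python
import torch
from torch import nn

bn2d = nn.LazyBatchNorm2d()  # num_features left unspecified
bn1d = nn.LazyBatchNorm1d()

# The first forward pass materializes parameters and running statistics,
# reading the number of channels/features from dimension 1 of the input
bn2d(torch.randn(4, 6, 28, 28))
bn1d(torch.randn(4, 120))

print(bn2d.num_features, bn2d.weight.shape)  # 6 torch.Size([6])
print(bn1d.num_features, bn1d.weight.shape)  # 120 torch.Size([120])
```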

In [None]:
class BNLeNet(d2l.Classifier):
    def __init__(self, lr=0.1, num_classes=10):
        super().__init__()
        self.save_hyperparameters()
        self.net = nn.Sequential(
            nn.LazyConv2d(6, kernel_size=5), nn.LazyBatchNorm2d(),
            nn.Sigmoid(), nn.AvgPool2d(kernel_size=2, stride=2),
            nn.LazyConv2d(16, kernel_size=5), nn.LazyBatchNorm2d(),
            nn.Sigmoid(), nn.AvgPool2d(kernel_size=2, stride=2),
            nn.Flatten(), nn.LazyLinear(120), nn.LazyBatchNorm1d(),
            nn.Sigmoid(), nn.LazyLinear(84), nn.LazyBatchNorm1d(),
            nn.Sigmoid(), nn.LazyLinear(num_classes))

Next, we train the model using the same hyperparameters. As expected, the high-level API version executes significantly faster, since its operations are compiled into C++ or CUDA, whereas our custom implementation runs through Python interpretation.

In [None]:
trainer = d2l.Trainer(max_epochs=10, num_gpus=1)
data = d2l.FashionMNIST(batch_size=128)
model = BNLeNet(lr=0.1)
model.apply_init([next(iter(data.get_dataloader(True)))[0]], d2l.init_cnn)
trainer.fit(model, data)

On a more practical note, there are a number of aspects worth remembering about batch normalization:

- During model training, batch normalization continuously adjusts the intermediate outputs of the network using the mean and standard deviation of the minibatch. As a result, the values of the intermediate outputs in each layer throughout the network are more stable.
- Batch normalization differs slightly between fully connected and convolutional layers. For fully connected layers, the statistics are computed per feature across the minibatch. For convolutional layers, they are computed per channel across the minibatch and all spatial locations. For convolutional layers, layer normalization can sometimes be used as an alternative.
- Like a dropout layer, a batch normalization layer behaves differently in training mode than in prediction mode.
- Batch normalization is useful for regularization and for improving convergence in optimization. By contrast, the original motivation of reducing internal covariate shift does not seem to be a valid explanation.
- For more robust models that are less sensitive to input perturbations, consider removing batch normalization (Wang et al., 2022).
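The difference between training and prediction mode can be observed directly with PyTorch's built-in `nn.BatchNorm1d`. This layer switches behavior through the module's `train()`/`eval()` mode, whereas our scratch `batch_norm` keys off `torch.is_grad_enabled()`. The data in this sketch are arbitrary.

After a single training step, the running mean has moved only 10% of the way toward the batch mean (momentum 0.1). Consequently, the prediction-mode output is still far from zero-centered. The check also reflects PyTorch's documented use of the unbiased variance for the running estimate.

```python
import torch
from torch import nn

torch.manual_seed(0)
bn = nn.BatchNorm1d(3)  # fresh layer: running_mean = 0, running_var = 1
X = torch.randn(16, 3) * 4.0 + 10.0

bn.train()              # training mode: normalize with minibatch statistics
y_train = bn(X)         # ...and update running_mean / running_var once
print(y_train.mean(dim=0))  # close to zero for every feature

bn.eval()               # prediction mode: normalize with the running estimates
with torch.no_grad():
    y_eval = bn(X)
print(y_eval.mean(dim=0))   # far from zero after only one update
print(bn.running_mean)      # 0.1 * X.mean(dim=0)
```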

## Exercises

1. Compare the learning rates for LeNet with and without batch normalization.
   - Plot the increase in validation accuracy.
   - How large can you make the learning rate before the optimization fails in both cases?

2. Do we need batch normalization in every layer? Experiment with it.

3. Can you replace dropout by batch normalization? How does the behavior change?