# Batch Normalization

In this lab session we are going to familiarize ourselves with the usage of *Batch Normalization (BN) layers* in our networks. BN is designed to **standardize** each feature within a mini-batch so that it has zero mean and unit variance. It then scales and shifts the standardized activations with learnable parameters. BN is known for:

*   faster convergence properties
*   improved performance

More details can be found in the original [paper](https://arxiv.org/abs/1502.03167).

$BN(x_{i, k}) = \gamma_{k} \frac{x_{i, k} - \mu_{B, k}}{\sqrt{\sigma^{2}_{B,k} + \epsilon}} + \beta_{k}$
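In the formula above, $i$ indexes the $m$ samples of the mini-batch $B$ and $k$ indexes the features. Following the original paper, the per-feature batch statistics use the biased ($1/m$) variance estimate, $\gamma_k$ and $\beta_k$ are learnable, and $\epsilon$ is a small constant for numerical stability:

$\mu_{B,k} = \frac{1}{m}\sum_{i=1}^{m} x_{i,k}, \qquad \sigma^{2}_{B,k} = \frac{1}{m}\sum_{i=1}^{m}\left(x_{i,k} - \mu_{B,k}\right)^{2}$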

The intuitive idea behind BN is as follows: a neural network is trained using mini-batches, and the distribution of the inputs to each layer **varies** from one batch to the other (and, during training, with the changing parameters of the preceding layers, which the original paper calls *internal covariate shift*). Differences in distribution between mini-batches can make training **unstable** and heavily **dependent on the initial weights** of the network. Transforming the inputs to have zero mean and unit variance keeps the first two moments of each layer's input distribution **fixed across mini-batches**, before the learnable scale $\gamma_k$ and shift $\beta_k$ are applied.
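To make this concrete, the sketch below (synthetic data, not part of the lab) standardizes two mini-batches drawn from very different distributions. After the transformation both have per-feature mean close to 0 and variance close to 1; note that only these two moments are matched, while the shape of each distribution is unchanged. The helper `standardize` is a hypothetical name used only for illustration.

```python
import torch

torch.manual_seed(0)

# Two hypothetical mini-batches (64 samples, 3 features) from clearly different
# distributions: the second one is scaled by 5 and shifted by 10.
batch_a = torch.randn(64, 3)
batch_b = 5.0 * torch.randn(64, 3) + 10.0


def standardize(x, eps=1e-5):
    """Standardize each feature (column) using the statistics of this batch only."""
    mean = x.mean(dim=0, keepdim=True)
    var = ((x - mean) ** 2).mean(dim=0, keepdim=True)  # biased variance, as in BN
    return (x - mean) / torch.sqrt(var + eps)


z_a = standardize(batch_a)
z_b = standardize(batch_b)

for name, raw, z in [("a", batch_a, z_a), ("b", batch_b, z_b)]:
    z_var = ((z - z.mean(0)) ** 2).mean(0)
    print(f"batch {name}: raw mean {raw.mean(0)}")
    print(f"batch {name}: standardized mean {z.mean(0)}, variance {z_var}")
```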

More interestingly, we will learn how to implement BN layers from scratch in PyTorch. Let's start by importing the necessary libraries, as usual.


In [None]:
import torch
import torchvision
import torchvision.transforms as T
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter

## BatchNorm1D
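This is our implementation of batch normalization for fully connected hidden layers: each of the C features is normalized with the mean and standard deviation computed over the N samples of the mini-batch.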
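Before the full module, the core training-mode computation can be sketched as a plain function. The function name `batch_norm_1d_train` is ours, for illustration only. The sketch normalizes over the batch dimension and compares the result with PyTorch's built-in `torch.nn.BatchNorm1d` in training mode, which also uses the biased variance with `eps` inside the square root.

```python
import torch

torch.manual_seed(0)


def batch_norm_1d_train(x, gamma, beta, eps=1e-5):
    """Training-mode BN for an (N, C) input: statistics are taken over dim 0 (the batch)."""
    mean = x.mean(dim=0)                         # per-feature mean, shape (C,)
    var = ((x - mean) ** 2).mean(dim=0)          # per-feature biased variance, shape (C,)
    x_hat = (x - mean) / torch.sqrt(var + eps)   # standardized activations
    return gamma * x_hat + beta                  # learnable scale and shift


x = torch.rand(10, 4)
gamma = torch.full((4,), 2.0)
beta = torch.full((4,), -1.0)

out = batch_norm_1d_train(x, gamma, beta)

# Reference: PyTorch's built-in layer in training mode with the same gamma and beta
ref_layer = torch.nn.BatchNorm1d(4, eps=1e-5)
with torch.no_grad():
    ref_layer.weight.fill_(2.0)
    ref_layer.bias.fill_(-1.0)
ref_layer.train()
ref = ref_layer(x).detach()

print(torch.allclose(out, ref, atol=1e-5))
```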

In [None]:
"""
Applies Batch Normalization over a 1D input (or 2D tensor)

Shape:
  Input: (N, C)
  Output: (N, C)

Input Parameters:
  in_features: number of features of the input activations
  track_running_stats: whether to keep track of running mean and std. (default: True)
  affine: whether to scale and shift the normalized activations. (default: True)
  momentum: weight given to the previous running statistics in the moving average,
    i.e. running = momentum * running + (1 - momentum) * batch; this is the opposite
    of the torch.nn.BatchNorm1d convention. (default: 0.9)
  eps: value added to the variance for numerical stability. (default: 1e-5)

Usage:
  >>> # with learnable parameters
  >>> bn = BatchNorm1d(4)
  >>> # without learnable parameters
  >>> bn = BatchNorm1d(4, affine=False)
  >>> input = torch.rand(10, 4)
  >>> out = bn(input)
"""


class BatchNorm1d(torch.nn.Module):
    def __init__(
        self, in_features, track_running_stats=True, affine=True, momentum=0.9, eps=1e-5
    ):
        super().__init__()

        self.in_features = in_features
        self.track_running_stats = track_running_stats
        self.affine = affine
        self.momentum = momentum
        self.eps = eps

        if self.affine:
            self.gamma = torch.nn.Parameter(torch.ones(self.in_features, 1))
            self.beta = torch.nn.Parameter(torch.zeros(self.in_features, 1))

        if self.track_running_stats:
            # register_buffer registers a tensor that is saved as part of the model (state_dict)
            # but, unlike an nn.Parameter, is not trained by the optimizer
            self.register_buffer("running_mean", torch.zeros(self.in_features, 1))
            self.register_buffer("running_std", torch.ones(self.in_features, 1))

    def forward(self, x):
        # transpose (N, C) to (C, N)
        x = x.transpose(0, 1).contiguous().view(x.shape[1], -1)

        # calculate batch mean
        mean = x.mean(dim=1).view(-1, 1)

        # calculate batch std from the biased variance (as in the BN formula),
        # adding eps to avoid a division by zero for constant features
        var = ((x - mean) ** 2).mean(dim=1, keepdim=True)
        std = torch.sqrt(var + self.eps)

        # during training keep running statistics (moving average of mean and std)
        if self.training and self.track_running_stats:
            # no computational graph is necessary to be built for this computation
            with torch.no_grad():
                self.running_mean = (
                    self.momentum * self.running_mean + (1 - self.momentum) * mean
                )
                self.running_std = (
                    self.momentum * self.running_std + (1 - self.momentum) * std
                )

        # during inference time
        if not self.training and self.track_running_stats:
            mean = self.running_mean
            std = self.running_std

        # normalize the input activations
        x = (x - mean) / std

        # scale and shift the normalized activations
        if self.affine:
            x = x * self.gamma + self.beta

        return x.transpose(0, 1)
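A note on the `momentum` argument: in the classes of this lab it weights the *previous* running statistic, whereas PyTorch's `torch.nn.BatchNorm1d`/`BatchNorm2d` use the opposite convention (their `momentum`, default 0.1, weights the *new* batch statistic). The sketch below, on a synthetic batch, checks that the lab's `momentum=0.9` corresponds to PyTorch's `momentum=0.1` for the running mean. Only the means are compared, because PyTorch tracks a running *variance* (updated with the unbiased estimator) rather than a running std.

```python
import torch

torch.manual_seed(0)
x = torch.randn(32, 4) + 3.0  # hypothetical batch whose features have mean close to 3
batch_mean = x.mean(dim=0)

# Convention used in this lab: momentum weights the OLD running value
momentum_lab = 0.9
running_mean_lab = torch.zeros(4)
running_mean_lab = momentum_lab * running_mean_lab + (1 - momentum_lab) * batch_mean

# PyTorch convention: running = (1 - momentum) * running + momentum * batch
bn = torch.nn.BatchNorm1d(4, momentum=0.1)
bn.train()
_ = bn(x)  # a forward pass in training mode updates bn.running_mean in place

print(running_mean_lab)
print(bn.running_mean)
```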

## BatchNorm2D
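BN module for convolutional layers. Each channel is normalized with statistics computed over the mini-batch and both spatial dimensions (N x H x W values per channel), so a single pair $(\gamma_k, \beta_k)$ is shared by all spatial locations of channel $k$.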
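The sketch below (random data, shapes chosen arbitrarily) checks two things: that collapsing (N, C, H, W) into (C, N * H * W), as the class does, gives the same per-channel mean as reducing directly over dimensions 0, 2 and 3, and that the resulting normalization matches PyTorch's built-in `torch.nn.BatchNorm2d` in training mode (without affine parameters).

```python
import torch

torch.manual_seed(0)
N, C, H, W = 8, 3, 5, 5
x = torch.randn(N, C, H, W)

# (a) reshape approach used in the BatchNorm2d class: (N, C, H, W) -> (C, N * H * W)
flat = x.transpose(0, 1).contiguous().view(C, -1)
mean_a = flat.mean(dim=1)

# (b) equivalent: reduce directly over the batch and spatial dimensions
mean_b = x.mean(dim=(0, 2, 3))
var_b = ((x - mean_b.view(1, C, 1, 1)) ** 2).mean(dim=(0, 2, 3))  # biased variance

eps = 1e-5
out = (x - mean_b.view(1, C, 1, 1)) / torch.sqrt(var_b.view(1, C, 1, 1) + eps)

# reference: built-in layer in training mode, no learnable scale/shift
ref = torch.nn.BatchNorm2d(C, affine=False, eps=eps).train()(x)

print(torch.allclose(mean_a, mean_b, atol=1e-6), torch.allclose(out, ref, atol=1e-5))
```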

In [None]:
"""
Applies Batch Normalization over a 4D input (a mini-batch of 2D feature maps with a channel dimension)

Shape:
  Input: (N, C, H, W)
  Output: (N, C, H, W)

Input Parameters:
  in_features: number of channels of the input activations
  track_running_stats: whether to keep track of running mean and std. (default: True)
  affine: whether to scale and shift the normalized activations. (default: True)
  momentum: weight given to the previous running statistics in the moving average,
    i.e. running = momentum * running + (1 - momentum) * batch; this is the opposite
    of the torch.nn.BatchNorm2d convention. (default: 0.9)
  eps: value added to the variance for numerical stability. (default: 1e-5)

Usage:
  >>> # with learnable parameters
  >>> bn = BatchNorm2d(4)
  >>> # without learnable parameters
  >>> bn = BatchNorm2d(4, affine=False)
  >>> input = torch.rand(10, 4, 5, 5)
  >>> out = bn(input)
"""


class BatchNorm2d(torch.nn.Module):
    def __init__(
        self, in_features, track_running_stats=True, affine=True, momentum=0.9, eps=1e-5
    ):
        super().__init__()

        self.in_features = in_features
        self.track_running_stats = track_running_stats
        self.affine = affine
        self.momentum = momentum
        self.eps = eps

        if self.affine:
            self.gamma = torch.nn.Parameter(torch.ones(self.in_features, 1))
            self.beta = torch.nn.Parameter(torch.zeros(self.in_features, 1))

        if self.track_running_stats:
            # register_buffer registers a tensor that is saved as part of the model (state_dict)
            # but, unlike an nn.Parameter, is not trained by the optimizer
            self.register_buffer("running_mean", torch.zeros(self.in_features, 1))
            self.register_buffer("running_std", torch.ones(self.in_features, 1))

    def forward(self, x):
        # transpose (N, C, H, W) to (C, N, H, W)
        x = x.transpose(0, 1)

        # store the shape
        c, bs, h, w = x.shape

        # collapse all dimensions except the 'channel' dimension: (C, N * H * W)
        x = x.contiguous().view(c, -1)

        # calculate batch mean
        mean = x.mean(dim=1).view(-1, 1)

        # calculate batch std from the biased variance (as in the BN formula),
        # adding eps to avoid a division by zero for constant channels
        var = ((x - mean) ** 2).mean(dim=1, keepdim=True)
        std = torch.sqrt(var + self.eps)

        # keep running statistics (moving average of mean and std)
        if self.training and self.track_running_stats:
            with torch.no_grad():
                self.running_mean = (
                    self.momentum * self.running_mean + (1 - self.momentum) * mean
                )
                self.running_std = (
                    self.momentum * self.running_std + (1 - self.momentum) * std
                )

        # during inference time
        if not self.training and self.track_running_stats:
            mean = self.running_mean
            std = self.running_std

        # normalize the input activations
        x = (x - mean) / std

        # scale and shift the normalized activations
        if self.affine:
            x = x * self.gamma + self.beta

        return x.view(c, bs, h, w).transpose(0, 1)

## LeNet-5 with batch normalization
Here we use BN layers in the LeNet-5 network: a BN layer is inserted right after each convolutional and fully connected layer (before the ReLU activation), except for the output layer.
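One consequence of placing BN right after a convolution is worth knowing: in training mode BN subtracts the per-channel batch mean, so a constant per-channel bias added by the convolution is cancelled, and its role is taken over by $\beta_k$. Keeping the default biases, as the LeNet below does, is harmless; for this reason convolutions followed by BN are often created with `bias=False`. The sketch uses a randomly initialized `Conv2d` (not the lab's network) and PyTorch's built-in `BatchNorm2d` only to check the cancellation numerically.

```python
import torch

torch.manual_seed(0)
x = torch.randn(4, 3, 32, 32)

# a conv layer with bias, and a copy sharing its weights but without bias
conv_bias = torch.nn.Conv2d(3, 6, kernel_size=5, bias=True)
conv_nobias = torch.nn.Conv2d(3, 6, kernel_size=5, bias=False)
with torch.no_grad():
    conv_nobias.weight.copy_(conv_bias.weight)
    conv_bias.bias.fill_(3.0)  # a large, clearly non-zero bias

bn = torch.nn.BatchNorm2d(6).train()  # training mode: uses mini-batch statistics

out_bias = bn(conv_bias(x))
out_nobias = bn(conv_nobias(x))

# the per-channel bias shifts the batch mean by the same amount, so it cancels out
print(torch.allclose(out_bias, out_nobias, atol=1e-4))
```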

In [None]:
class LeNet(torch.nn.Module):
    def __init__(self, norm=False):
        super().__init__()
        self.norm = norm

        # input channel = 3, output channels = 6, kernel size = 5
        # input image size = (32, 32), image output size = (28, 28)
        self.conv1 = torch.nn.Conv2d(in_channels=3, out_channels=6, kernel_size=(5, 5))
        if self.norm:
            self.bn1 = BatchNorm2d(6)

        # input channel = 6, output channels = 16, kernel size = 5
        # input image size = (14, 14), output image size = (10, 10)
        self.conv2 = torch.nn.Conv2d(in_channels=6, out_channels=16, kernel_size=(5, 5))
        if self.norm:
            self.bn2 = BatchNorm2d(16)

        # input dim = 5 * 5 * 16 (flattened C x H x W feature maps), output dim = 120
        self.fc3 = torch.nn.Linear(in_features=5 * 5 * 16, out_features=120)
        if self.norm:
            self.bn3 = BatchNorm1d(120)

        # input dim = 120, output dim = 84
        self.fc4 = torch.nn.Linear(in_features=120, out_features=84)
        if self.norm:
            self.bn4 = BatchNorm1d(84)

        # input dim = 84, output dim = 10
        self.fc5 = torch.nn.Linear(in_features=84, out_features=10)

    def forward(self, x):
        x = self.conv1(x)
        if self.norm:
            x = self.bn1(x)
        x = F.relu(x)
        # Max Pooling with kernel size = 2
        # output size = (14, 14)
        x = F.max_pool2d(x, kernel_size=2)

        x = self.conv2(x)
        if self.norm:
            x = self.bn2(x)
        x = F.relu(x)
        # Max Pooling with kernel size = 2
        # output size = (5, 5)
        x = F.max_pool2d(x, kernel_size=2)

        # flatten the feature maps into a long vector
        x = x.view(x.shape[0], -1)

        x = self.fc3(x)
        if self.norm:
            x = self.bn3(x)
        x = F.relu(x)

        x = self.fc4(x)
        if self.norm:
            x = self.bn4(x)
        x = F.relu(x)

        x = self.fc5(x)

        return x

## Cost function and optimizer
As in the cases seen so far, we employ the cross-entropy loss and an SGD optimizer.
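Since the optimizer receives `net.parameters()`, it updates the BN scale and shift ($\gamma_k$, $\beta_k$, stored as `nn.Parameter`) together with the convolutional and linear weights, and weight decay applies to them as well. The running statistics, registered with `register_buffer`, are not returned by `.parameters()`, so SGD never touches them, yet they are saved in the `state_dict`. The sketch below uses a hypothetical `TinyBN` module that stores its state in the same way as the lab's BN classes.

```python
import torch


class TinyBN(torch.nn.Module):
    """Hypothetical minimal module storing its state like the lab's BN classes."""

    def __init__(self, num_features):
        super().__init__()
        # learnable: returned by .parameters(), hence updated by the optimizer
        self.gamma = torch.nn.Parameter(torch.ones(num_features, 1))
        self.beta = torch.nn.Parameter(torch.zeros(num_features, 1))
        # buffers: saved in state_dict() and moved by .to(device), but not trained
        self.register_buffer("running_mean", torch.zeros(num_features, 1))
        self.register_buffer("running_std", torch.ones(num_features, 1))


m = TinyBN(4)
param_names = [name for name, _ in m.named_parameters()]
buffer_names = [name for name, _ in m.named_buffers()]
optimizer = torch.optim.SGD(m.parameters(), lr=0.01, weight_decay=1e-6, momentum=0.9)

print(param_names)                  # ['gamma', 'beta']
print(buffer_names)                 # ['running_mean', 'running_std']
print(list(m.state_dict().keys()))  # ['gamma', 'beta', 'running_mean', 'running_std']
```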

In [None]:
def get_cost_function():
    cost_function = torch.nn.CrossEntropyLoss()
    return cost_function


def get_optimizer(net, lr, wd, momentum):
    optimizer = torch.optim.SGD(
        net.parameters(), lr=lr, weight_decay=wd, momentum=momentum
    )
    return optimizer

## Training and test steps
Let's define our loops for the training and evaluation procedures.
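The calls to `net.train()` and `net.eval()` in the functions below matter for BN: they set the `self.training` flag checked in our `forward` methods. The following sketch, using PyTorch's built-in `torch.nn.BatchNorm1d` on random data, shows that in training mode the output for a sample depends on the other samples in the mini-batch, while in evaluation mode the running statistics make each sample's output independent of them.

```python
import torch

torch.manual_seed(0)
bn = torch.nn.BatchNorm1d(4)
batch = torch.randn(16, 4)
other = torch.randn(16, 4) * 10.0  # very different companions...
other[0] = batch[0]                # ...but the same first sample

# training mode: the output for sample 0 depends on the rest of the mini-batch
bn.train()
out_train_1 = bn(batch)[0]
out_train_2 = bn(other)[0]

# evaluation mode: running statistics are used, so samples are processed independently
bn.eval()
out_eval_1 = bn(batch)[0]
out_eval_2 = bn(other)[0]

print(torch.allclose(out_train_1, out_train_2))  # False
print(torch.allclose(out_eval_1, out_eval_2))    # True
```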

In [None]:
def test_step(net, data_loader, cost_function, device="cuda:0"):
    samples = 0.0
    cumulative_loss = 0.0
    cumulative_accuracy = 0.0

    # switch to evaluation mode: strictly needed if the network contains layers that behave
    # differently at training and test time (e.g. BN, which then uses its running statistics)
    net.eval()
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(data_loader):
            # load data into GPU
            inputs = inputs.to(device)
            targets = targets.to(device)

            # forward pass
            outputs = net(inputs)

            # apply the loss
            loss = cost_function(outputs, targets)

            # accumulate statistics
            samples += inputs.shape[0]
            # CrossEntropyLoss averages over the batch, so weight it by the batch size
            # (.item() extracts a Python scalar from a one-element tensor)
            cumulative_loss += loss.item() * inputs.shape[0]
            _, predicted = outputs.max(1)
            cumulative_accuracy += predicted.eq(targets).sum().item()

    return cumulative_loss / samples, cumulative_accuracy / samples * 100


def training_step(net, data_loader, optimizer, cost_function, device="cuda:0"):
    samples = 0.0
    cumulative_loss = 0.0
    cumulative_accuracy = 0.0

    # switch to training mode: BN layers then normalize with mini-batch statistics
    # and update their running statistics
    net.train()

    for batch_idx, (inputs, targets) in enumerate(data_loader):
        # load data into GPU
        inputs = inputs.to(device)
        targets = targets.to(device)

        # forward pass
        outputs = net(inputs)

        # apply the loss
        loss = cost_function(outputs, targets)

        # backward pass
        loss.backward()

        # update parameters
        optimizer.step()

        # zero the gradient
        optimizer.zero_grad()

        # accumulate statistics
        samples += inputs.shape[0]
        # CrossEntropyLoss averages over the batch, so weight it by the batch size
        cumulative_loss += loss.item() * inputs.shape[0]
        _, predicted = outputs.max(1)
        cumulative_accuracy += predicted.eq(targets).sum().item()

    return cumulative_loss / samples, cumulative_accuracy / samples * 100
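`torch.nn.CrossEntropyLoss` averages the loss over the mini-batch by default (`reduction="mean"`), so to report the average loss per sample over an epoch, each batch loss has to be multiplied by its batch size before dividing by the total number of samples. A small check on random logits (not the lab's data), with two batches of different sizes:

```python
import torch

torch.manual_seed(0)
cost_function = torch.nn.CrossEntropyLoss()            # reduction="mean" by default
sum_loss = torch.nn.CrossEntropyLoss(reduction="sum")  # reference: plain sum

# two hypothetical batches of different sizes (e.g. the last, smaller batch of an epoch)
batches = [
    (torch.randn(128, 10), torch.randint(0, 10, (128,))),
    (torch.randn(32, 10), torch.randint(0, 10, (32,))),
]

samples, weighted, unweighted, exact = 0, 0.0, 0.0, 0.0
for outputs, targets in batches:
    loss = cost_function(outputs, targets)      # mean over this batch
    samples += targets.shape[0]
    weighted += loss.item() * targets.shape[0]  # undo the mean: batch sum
    unweighted += loss.item()                   # sum of batch means
    exact += sum_loss(outputs, targets).item()

print(weighted / samples, exact / samples)  # equal up to float rounding
print(unweighted / samples)                 # much smaller: not a per-sample average
```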

## Data loading
Let's now define our data loading utility.

In [None]:
def get_data(batch_size, test_batch_size=256, dataset="mnist"):
    # prepare data transformations and then combine them sequentially
    if dataset == "mnist":
        transform = list()
        transform.append(T.ToTensor())  # convert the PIL image to a PyTorch tensor in [0, 1]
        transform.append(
            T.Lambda(lambda x: F.pad(x, (2, 2, 2, 2), "constant", 0))
        )  # pad zeros to make MNIST 32 x 32
        transform.append(
            T.Lambda(lambda x: x.repeat(3, 1, 1))
        )  # make MNIST RGB instead of grayscale
        transform.append(
            T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        )  # normalize the Tensors between [-1, 1]
        transform = T.Compose(transform)  # compose the above transformations into one
    elif dataset == "svhn":
        transform = list()
        transform.append(T.ToTensor())  # convert the PIL image to a PyTorch tensor in [0, 1]
        transform.append(
            T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        )  # normalize the Tensors between [-1, 1]
        transform = T.Compose(transform)  # compose the above transformations into one
    else:
        raise ValueError(f"Unknown dataset: {dataset}")

    # prepare dataset
    if dataset == "mnist":
        full_training_data = torchvision.datasets.MNIST(
            "./data/mnist", train=True, transform=transform, download=True
        )
        test_data = torchvision.datasets.MNIST(
            "./data/mnist", train=False, transform=transform, download=True
        )
    elif dataset == "svhn":
        full_training_data = torchvision.datasets.SVHN(
            "./data/svhn", split="train", transform=transform, download=True
        )
        test_data = torchvision.datasets.SVHN(
            "./data/svhn", split="test", transform=transform, download=True
        )

    # create train and validation splits
    num_samples = len(full_training_data)
    training_samples = int(num_samples * 0.8 + 1)
    validation_samples = num_samples - training_samples

    training_data, validation_data = torch.utils.data.random_split(
        full_training_data, [training_samples, validation_samples]
    )

    # initialize dataloaders
    train_loader = torch.utils.data.DataLoader(
        training_data, batch_size, shuffle=True, drop_last=True
    )
    val_loader = torch.utils.data.DataLoader(
        validation_data, test_batch_size, shuffle=False
    )
    test_loader = torch.utils.data.DataLoader(test_data, test_batch_size, shuffle=False)

    return train_loader, val_loader, test_loader

## Wrap it up
Let's now put everything together into a training procedure!

In [None]:
"""
Input arguments
  batch_size: Size of a mini-batch
  device: GPU where you want to train your network
  weight_decay: Weight decay co-efficient for regularization of weights
  momentum: Momentum for SGD optimizer
  epochs: Number of epochs for training the network
  visualization_name: name of the tensorboard folder
  dataset: which dataset to train
  norm: whether to use batch normalization
"""


def main(
    batch_size=128,
    device="cuda:0",
    learning_rate=0.01,
    weight_decay=0.000001,
    momentum=0.9,
    epochs=50,
    visualization_name="mnist",
    dataset="mnist",
    norm=False,
):
    # creates a logger for the experiment
    writer = SummaryWriter(log_dir=f"runs/{visualization_name}")

    train_loader, val_loader, test_loader = get_data(
        batch_size=batch_size, test_batch_size=batch_size, dataset=dataset
    )

    net = LeNet(norm=norm).to(device)

    optimizer = get_optimizer(net, learning_rate, weight_decay, momentum)

    cost_function = get_cost_function()

    print("Before training:")
    train_loss, train_accuracy = test_step(net, train_loader, cost_function, device=device)
    val_loss, val_accuracy = test_step(net, val_loader, cost_function, device=device)
    test_loss, test_accuracy = test_step(net, test_loader, cost_function, device=device)

    print(
        "\tTraining loss {:.5f}, Training accuracy {:.2f}".format(
            train_loss, train_accuracy
        )
    )
    print(
        "\tValidation loss {:.5f}, Validation accuracy {:.2f}".format(
            val_loss, val_accuracy
        )
    )
    print("\tTest loss {:.5f}, Test accuracy {:.2f}".format(test_loss, test_accuracy))
    print("-----------------------------------------------------")

    # add values to plots
    writer.add_scalar("train/loss", train_loss, 0)
    writer.add_scalar("val/loss", val_loss, 0)
    writer.add_scalar("train/accuracy", train_accuracy, 0)
    writer.add_scalar("val/accuracy", val_accuracy, 0)

    for e in range(epochs):
        train_loss, train_accuracy = training_step(
            net, train_loader, optimizer, cost_function, device=device
        )
        val_loss, val_accuracy = test_step(net, val_loader, cost_function, device=device)
        print("Epoch: {:d}".format(e + 1))
        print(
            "\tTraining loss {:.5f}, Training accuracy {:.2f}".format(
                train_loss, train_accuracy
            )
        )
        print(
            "\tValidation loss {:.5f}, Validation accuracy {:.2f}".format(
                val_loss, val_accuracy
            )
        )
        print("-----------------------------------------------------")

        # Add values to plots
        writer.add_scalar("train/loss", train_loss, e + 1)
        writer.add_scalar("val/loss", val_loss, e + 1)
        writer.add_scalar("train/accuracy", train_accuracy, e + 1)
        writer.add_scalar("val/accuracy", val_accuracy, e + 1)

    print("After training:")
    train_loss, train_accuracy = test_step(net, train_loader, cost_function, device=device)
    val_loss, val_accuracy = test_step(net, val_loader, cost_function, device=device)
    test_loss, test_accuracy = test_step(net, test_loader, cost_function, device=device)

    print(
        "\tTraining loss {:.5f}, Training accuracy {:.2f}".format(
            train_loss, train_accuracy
        )
    )
    print(
        "\tValidation loss {:.5f}, Validation accuracy {:.2f}".format(
            val_loss, val_accuracy
        )
    )
    print("\tTest loss {:.5f}, Test accuracy {:.2f}".format(test_loss, test_accuracy))
    print("-----------------------------------------------------")

    # Closes the logger
    writer.close()

## Run
Let's make it happen! First on MNIST without BN

In [None]:
main(visualization_name="mnist", dataset="mnist")

Now on MNIST with BN layers

In [None]:
main(visualization_name="mnist_bn", dataset="mnist", norm=True)

SVHN without BN

In [None]:
main(visualization_name="svhn_bn", dataset="svhn", norm=True)

SVHN with BN

In [None]:
main(visualization_name="svhn_bn", dataset="svhn", norm=True)

In [None]:
%load_ext tensorboard
%tensorboard --logdir=runs