# BatchNorm

In this demo, we use Batch Normalization (BN) to improve an MLP with two hidden layers. We start by loading the necessary packages and the data; we use the Fashion-MNIST dataset throughout.

In [None]:
import os
from timeit import default_timer as timer

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torchvision import datasets

# Fix the seeds of both random number generators used in this notebook.
np.random.seed(0)
torch.manual_seed(0)

# Every sample is converted to a torch.FloatTensor.
transform = transforms.ToTensor()

# Load the Fashion-MNIST training split, then the test split, from ../data
# (download=False: the files are expected to already be on disk).
trainset, testset = (
    datasets.FashionMNIST(root='../data', train=is_train, download=False, transform=transform)
    for is_train in (True, False)
)

We can visualize a few samples to get a feel for the dataset:

In [None]:
# Take the first 10 training images (shuffle=False, so always the same ones).
train_loader = torch.utils.data.DataLoader(trainset, batch_size=10, shuffle=False)
images, _ = next(iter(train_loader))

print(images.shape)

# Tile the 10 images into a single row. make_grid returns a channels-first
# (C, H, W) tensor, while imshow expects channels-last (H, W, C), so permute
# the axes explicitly instead of relying on np.transpose's fallback path.
grid = torchvision.utils.make_grid(images, nrow=10)
plt.figure(figsize=(15, 15))
plt.imshow(grid.permute(1, 2, 0).numpy())

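Below, we define our MLP. For fully connected (fc) layers we use the PyTorch layer `BatchNorm1d`; to apply BN to convolutional layers, use `BatchNorm2d` instead. Study the code below. Note that, following common practice, we place BN before the activation function. The linear layers followed by BN are created without a bias term: BN subtracts the batch mean and then adds its own learnable shift, which makes such a bias redundant.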

In [None]:
class MLP(nn.Module):
    """Fully connected classifier with two hidden layers and optional batch norm.

    Architecture:
        flatten -> fc1 (2*hidden_dim) -> [bn1] -> ReLU
                -> fc2 (hidden_dim)   -> [bn2] -> ReLU
                -> fc3 (num_classes)  -> logits

    Args:
        use_batch_norm: if True, apply BatchNorm1d to the outputs of fc1 and
            fc2 (before the ReLU) and build those two linear layers without bias.
        hidden_dim: width of the second hidden layer; the first is twice as wide.
        input_size: number of features per sample after flattening
            (default 784 = 28*28 images).
        num_classes: number of output logits (default 10).
    """

    def __init__(self, use_batch_norm=True, hidden_dim=256, input_size=784, num_classes=10):
        super().__init__()

        self.input_size = input_size
        self.hidden_dim = hidden_dim
        # Keep track of whether or not this network uses batch normalization.
        self.use_batch_norm = use_batch_norm

        # A bias before BN is redundant: BN subtracts the batch mean and then
        # adds its own learnable shift, so BN-followed layers drop the bias.
        linear_bias = not use_batch_norm

        # Modules are registered in the order fc1, bn1, fc2, bn2, fc3 so the
        # parameter order and state_dict keys are unchanged.
        self.fc1 = nn.Linear(input_size, hidden_dim * 2, bias=linear_bias)
        if use_batch_norm:
            self.bn1 = nn.BatchNorm1d(hidden_dim * 2)

        self.fc2 = nn.Linear(hidden_dim * 2, hidden_dim, bias=linear_bias)
        if use_batch_norm:
            self.bn2 = nn.BatchNorm1d(hidden_dim)

        # Output layer: no batch norm and no activation (raw logits).
        self.fc3 = nn.Linear(hidden_dim, num_classes)

    def forward(self, x):
        """Return logits of shape (N, num_classes).

        x is flattened to (N, input_size); its total number of elements must
        be a multiple of input_size, e.g. an (N, 1, 28, 28) image batch.
        Note: in training mode BatchNorm1d requires more than one sample per batch.
        """
        # reshape (unlike view) also accepts non-contiguous tensors.
        x = x.reshape(-1, self.input_size)

        # First hidden layer: linear -> optional BN -> ReLU.
        x = self.fc1(x)
        if self.use_batch_norm:
            x = self.bn1(x)
        x = F.relu(x)

        # Second hidden layer: linear -> optional BN -> ReLU.
        x = self.fc2(x)
        if self.use_batch_norm:
            x = self.bn2(x)
        x = F.relu(x)

        # Output layer, no batch norm or activation.
        return self.fc3(x)

The two functions below are used for training and evaluating the models.

In [None]:
def train_model(model, optimizer, criterion, train_loader, device="cpu"):
    """Train ``model`` for one pass over ``train_loader``.

    The model is put in training mode. For every (inputs, targets) batch the
    usual zero_grad -> forward -> loss -> backward -> step sequence runs, with
    both tensors moved to ``device`` first.

    Returns:
        The sum of per-batch loss values divided by ``len(train_loader)``.
    """
    model.train()
    loss_sum = 0
    for inputs, targets in train_loader:
        optimizer.zero_grad()
        batch_loss = criterion(model(inputs.to(device)), targets.to(device))
        batch_loss.backward()
        loss_sum += batch_loss.item()
        optimizer.step()
    return loss_sum / len(train_loader)




def evaluate_model(model, test_loader, device="cpu"):
    """Return the accuracy of ``model`` on ``test_loader`` as a percentage.

    The model is switched to eval mode (and left in it), and gradients are
    disabled. A sample counts as correct when the argmax of the model output
    along dim 1 equals its target.
    """
    n_correct, n_seen = 0, 0
    with torch.no_grad():
        model.eval()
        for inputs, targets in test_loader:
            predictions = model(inputs.to(device)).argmax(dim=1)
            n_correct += (predictions == targets.to(device)).sum().item()
            n_seen += targets.shape[0]
        return 100 * n_correct / n_seen

We instantiate two models below, one with and one without BN.

In [None]:
# Run on the GPU when CUDA is available, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)

# Two MLPs with identical sizes: the first without BN, the second with BN.
MLP_plain = MLP(use_batch_norm=False, hidden_dim=256).to(device)
MLP_BN = MLP(use_batch_norm=True, hidden_dim=256).to(device)

We first train the model without BN for 10 epochs and store the per-epoch results in a CSV file. This cell also creates the optimizer for the BN model, which is used in the next cell.

In [None]:
max_epoch = 10

# OPTIMISER PARAMETERS
# Both optimizers are created here; optimizer_BN is used by the BN training cell.
lr = 0.01
optimizer_plain = torch.optim.SGD(MLP_plain.parameters(), lr=lr)
optimizer_BN = torch.optim.SGD(MLP_BN.parameters(), lr=lr)
criterion = nn.CrossEntropyLoss()

result_file_plain = 'results/mlp_plain.csv'
model_file_plain = 'models/mlp_plain.pt'
# 'epoch' becomes the DataFrame index; the remaining names are the CSV columns.
cols = ['epoch', 'train_loss', 'train accuracy', 'test accuracy', 'total training time']
results_df_plain = pd.DataFrame(columns=cols).set_index('epoch')

# to_csv / torch.save raise if the target directories do not exist yet.
os.makedirs(os.path.dirname(result_file_plain), exist_ok=True)
os.makedirs(os.path.dirname(model_file_plain), exist_ok=True)

# prepare data loaders
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True)
testloader = torch.utils.data.DataLoader(testset, batch_size=128, shuffle=False)
training_time_plain = 0
best_test_acc_plain = 0.0
for epoch in range(max_epoch):
    # Only the optimisation pass is timed; evaluation below is excluded.
    start_time = timer()
    train_loss = train_model(MLP_plain, optimizer_plain, criterion, trainloader, device)
    end_time = timer()
    training_time_plain += (end_time - start_time)

    # Accuracy on the full training and test sets after this epoch.
    train_acc = evaluate_model(MLP_plain, trainloader, device)
    test_acc = evaluate_model(MLP_plain, testloader, device)
    print(f'Plain MLP - Epoch:{epoch+1:3}| training loss = {train_loss:.4f}|',
          f'Training Accuracy = {train_acc:.2f}| Test Accuracy = {test_acc:.2f}|',
          f'Training Time (sec) = {training_time_plain:.2f}|')

    # The CSV is rewritten every epoch, so results so far are kept if the run stops early.
    results_df_plain.loc[epoch] = [train_loss, train_acc, test_acc, training_time_plain]
    results_df_plain.to_csv(result_file_plain, float_format='%.2f')
    # Save best model (selected by test accuracy).
    if test_acc > best_test_acc_plain:
        torch.save(MLP_plain.state_dict(), model_file_plain)
        best_test_acc_plain = test_acc

We now train the model with BN for 10 epochs and store the results in a csv file.

In [None]:
result_file_bn = 'results/mlp_bn.csv'
model_file_bn = 'models/mlp_bn.pt'
# Same column layout as the plain-MLP results ('epoch' becomes the index).
results_df_bn = pd.DataFrame(columns=cols).set_index('epoch')

# to_csv / torch.save raise if the target directories do not exist yet.
os.makedirs(os.path.dirname(result_file_bn), exist_ok=True)
os.makedirs(os.path.dirname(model_file_bn), exist_ok=True)

# prepare data loaders
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True)
testloader = torch.utils.data.DataLoader(testset, batch_size=128, shuffle=False)

training_time_bn = 0
best_test_acc_bn = 0.0
for epoch in range(max_epoch):
    # Only the optimisation pass is timed; evaluation below is excluded.
    start_time = timer()
    train_loss = train_model(MLP_BN, optimizer_BN, criterion, trainloader, device)
    end_time = timer()
    training_time_bn += (end_time - start_time)

    # Accuracy on the full training and test sets after this epoch.
    train_acc = evaluate_model(MLP_BN, trainloader, device)
    test_acc = evaluate_model(MLP_BN, testloader, device)

    # Separate arguments (not adjacent f-strings, which would be glued together
    # with no space) so print separates the segments with spaces.
    print(f'MLP with BN - Epoch:{epoch+1:3}| training loss = {train_loss:.4f}|',
          f'Training Accuracy = {train_acc:.2f}| Test Accuracy = {test_acc:.2f}|',
          f'Training Time (sec) = {training_time_bn:.2f}|')

    # The CSV is rewritten every epoch, so results so far are kept if the run stops early.
    results_df_bn.loc[epoch] = [train_loss, train_acc, test_acc, training_time_bn]
    results_df_bn.to_csv(result_file_bn, float_format='%.2f')
    # Save best model (selected by test accuracy).
    if test_acc > best_test_acc_bn:
        torch.save(MLP_BN.state_dict(), model_file_bn)
        best_test_acc_bn = test_acc

Time to compare the two models by plotting their train and test accuracy per epoch.

In [None]:
# Compare both models using the per-epoch results written by the training cells.
# Each CSV has an 'epoch' column (the saved index) plus the accuracy columns.
plain_net_df = pd.read_csv('results/mlp_plain.csv')
BN_net_df = pd.read_csv('results/mlp_bn.csv')

f, (ax1, ax2) = plt.subplots(1, 2, sharey=True, figsize=(25, 7))

# Left panel: train accuracy; right panel: test accuracy.
for ax, split in ((ax1, 'train'), (ax2, 'test')):
    column = f'{split} accuracy'
    ax.plot(plain_net_df['epoch'], plain_net_df[column], color='yellowgreen',
            linestyle='--', label=f'MLP - {split}')
    ax.plot(BN_net_df['epoch'], BN_net_df[column], color='darkmagenta',
            linestyle='--', label=f'MLP w/ BN - {split}')
    ax.axis([0, 10, 70, 100])
    ax.set_title(f'{split.capitalize()} Accuracy')
    ax.legend(loc='lower left')
    ax.set_xlabel('epochs')
    ax.set_ylabel('accuracy (%)')

**Questions**
- Implement a CNN with and without BN and compare the results. 
- Check the ResNet block. Do we use BN with ResNet? 