<a href="https://colab.research.google.com/github/bhattarai-aavash/deep_learning/blob/main/Notebooks/Chap11/11_3_Batch_Normalization.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **Notebook 11.3: Batch normalization**

This notebook investigates the use of batch normalization in residual networks.

Work through the cells below, running each cell in turn. In various places you will see the words "TODO". Follow the instructions at these places and make predictions about what is going to happen or write code to complete the functions.

Contact me at udlbookmail@gmail.com if you find any mistakes or have any suggestions.



In [1]:
# Run this if you're in a Colab to install MNIST 1D repository
!pip install git+https://github.com/greydanus/mnist1d

Collecting git+https://github.com/greydanus/mnist1d
  Cloning https://github.com/greydanus/mnist1d to /tmp/pip-req-build-kle2hcqg
  Running command git clone --filter=blob:none --quiet https://github.com/greydanus/mnist1d /tmp/pip-req-build-kle2hcqg
  Resolved https://github.com/greydanus/mnist1d to commit 7878d96082abd200c546a07a4101fa90b30fdf7e
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Building wheels for collected packages: mnist1d
  Building wheel for mnist1d (pyproject.toml) ... [?25l[?25hdone
  Created wheel for mnist1d: filename=mnist1d-0.0.2.post16-py3-none-any.whl size=14665 sha256=6e0829de345c2c95ab9b9c6716cd833e70501c9f6b55a8a3cb8a7ab29b68af20
  Stored in directory: /tmp/pip-ephem-wheel-cache-3220o1yf/wheels/d6/38/42/3d2112bc7d915f6195254ac85eb761d922d1b18f52817aa8e2
Successfully built mnist1d
Installing collected packages: mnist1d
Successfully i

In [2]:
import numpy as np
import os
import torch, torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader
from torch.optim.lr_scheduler import StepLR
import matplotlib.pyplot as plt
import mnist1d
import random

In [3]:
args = mnist1d.data.get_dataset_args()
data = mnist1d.data.get_dataset(args, path='./mnist1d_data.pkl', download=False, regenerate=False)

# The training and test input and outputs are in
# data['x'], data['y'], data['x_test'], and data['y_test']
print("Examples in training set: {}".format(len(data['y'])))
print("Examples in test set: {}".format(len(data['y_test'])))
print("Length of each example: {}".format(data['x'].shape[-1]))

Did or could not load data from ./mnist1d_data.pkl. Rebuilding dataset...
Examples in training set: 4000
Examples in test set: 1000
Length of each example: 40


In [4]:
# Load in the data
train_data_x = data['x'].transpose()
train_data_y = data['y']
val_data_x = data['x_test'].transpose()
val_data_y = data['y_test']
# Print out sizes
print("Train data: %d examples (columns), each of which has %d dimensions (rows)"%((train_data_x.shape[1],train_data_x.shape[0])))
print("Validation data: %d examples (columns), each of which has %d dimensions (rows)"%((val_data_x.shape[1],val_data_x.shape[0])))

Train data: 4000 examples (columns), each of which has 40 dimensions (rows)
Validation data: 1000 examples (columns), each of which has 40 dimensions (rows)


In [5]:
def print_variance(name, data):
  # First dimension(rows) is batch elements
  # Second dimension(columns) is neurons.
  np_data = data.detach().numpy()
  # Compute variance across neurons and average these variances over members of the batch
  neuron_variance = np.mean(np.var(np_data, axis=0))
  # Print out the name and the variance
  print("%s variance=%f"%(name,neuron_variance))

In [6]:
# He initialization of weights
def weights_init(layer_in):
  if isinstance(layer_in, nn.Linear):
    nn.init.kaiming_uniform_(layer_in.weight)
    layer_in.bias.data.fill_(0.0)

In [7]:
def run_one_step_of_model(model, x_train, y_train):
  # choose cross entropy loss function (equation 5.24 in the loss notes)
  loss_function = nn.CrossEntropyLoss()
  # construct SGD optimizer and initialize learning rate and momentum
  optimizer = torch.optim.SGD(model.parameters(), lr = 0.05, momentum=0.9)

  # load the data into a class that creates the batches
  data_loader = DataLoader(TensorDataset(x_train,y_train), batch_size=200, shuffle=True, worker_init_fn=np.random.seed(1))

  # Initialize model weights
  model.apply(weights_init)

  # Get a batch
  for i, data in enumerate(data_loader):
    # retrieve inputs and labels for this batch
    x_batch, y_batch = data
    # zero the parameter gradients
    optimizer.zero_grad()
    # forward pass -- calculate model output
    pred = model(x_batch)
    # compute the loss
    loss = loss_function(pred, y_batch)
    # backward pass
    loss.backward()
    # SGD update
    optimizer.step()
    # Break out of this loop -- we just want to see the first
    # iteration, but usually we would continue
    break

In [8]:
# convert training data to torch tensors
x_train = torch.tensor(train_data_x.transpose().astype('float32'))
y_train = torch.tensor(train_data_y.astype('long'))

In [9]:
# This is a simple residual model with 5 residual branches in a row
class ResidualNetwork(torch.nn.Module):
  def __init__(self, input_size, output_size, hidden_size=100):
    super(ResidualNetwork, self).__init__()
    self.linear1 = nn.Linear(input_size, hidden_size)
    self.linear2 = nn.Linear(hidden_size, hidden_size)
    self.linear3 = nn.Linear(hidden_size, hidden_size)
    self.linear4 = nn.Linear(hidden_size, hidden_size)
    self.linear5 = nn.Linear(hidden_size, hidden_size)
    self.linear6 = nn.Linear(hidden_size, hidden_size)
    self.linear7 = nn.Linear(hidden_size, output_size)

  def count_params(self):
    return sum([p.view(-1).shape[0] for p in self.parameters()])

  def forward(self, x):
    print_variance("Input",x)
    f = self.linear1(x)
    print_variance("First preactivation",f)
    res1 = f+ self.linear2(f.relu())
    print_variance("After first residual connection",res1)
    res2 = res1 + self.linear3(res1.relu())
    print_variance("After second residual connection",res2)
    res3 = res2 + self.linear4(res2.relu())
    print_variance("After third residual connection",res3)
    res4 = res3 + self.linear5(res3.relu())
    print_variance("After fourth residual connection",res4)
    res5 = res4 + self.linear6(res4.relu())
    print_variance("After fifth residual connection",res5)
    return self.linear7(res5)

In [10]:
# Define the model and run for one step
# Monitoring the variance at each point in the network
n_hidden = 100
n_input = 40
n_output = 10
model = ResidualNetwork(n_input, n_output, n_hidden)
run_one_step_of_model(model, x_train, y_train)

Input variance=1.016311
First preactivation variance=1.912032
After first residual connection variance=3.384397
After second residual connection variance=5.635843
After third residual connection variance=9.055282
After fourth residual connection variance=15.182500
After fifth residual connection variance=27.545427


Notice that the variance roughly doubles at each step so it increases exponentially as in figure 11.6b in the book.

In [15]:
# TODO Adapt the residual network below to add a batch norm operation
# before the contents of each residual link as in figure 11.6c in the book
# Use the torch function nn.BatchNorm1d
class ResidualNetworkWithBatchNorm(torch.nn.Module):
  def __init__(self, input_size, output_size, hidden_size=100):
    super(ResidualNetworkWithBatchNorm, self).__init__()
    self.linear1 = nn.Linear(input_size, hidden_size)
    self.linear2 = nn.Linear(hidden_size, hidden_size)
    self.linear3 = nn.Linear(hidden_size, hidden_size)
    self.linear4 = nn.Linear(hidden_size, hidden_size)
    self.linear5 = nn.Linear(hidden_size, hidden_size)
    self.linear6 = nn.Linear(hidden_size, hidden_size)
    self.linear7 = nn.Linear(hidden_size, output_size)


  def count_params(self):
    return sum([p.view(-1).shape[0] for p in self.parameters()])

  def forward(self, x):
    print_variance("Input", x)

    # Initial transformation
    x = self.linear1(x)
    print_variance("First preactivation", x)

    # First residual block
    x = nn.BatchNorm1d(x.shape[1])(x)  # Proper BatchNorm initialization
    residual = self.linear2(x.relu())
    res1 = x + residual
    print_variance("After first residual connection", res1)

    # Second residual block
    res1 = nn.BatchNorm1d(res1.shape[1])(res1)
    residual = self.linear3(res1.relu())
    res2 = res1 + residual
    print_variance("After second residual connection", res2)

    # Third residual block
    res2 = nn.BatchNorm1d(res2.shape[1])(res2)
    residual = self.linear4(res2.relu())
    res3 = res2 + residual
    print_variance("After third residual connection", res3)

    # Fourth residual block
    res3 = nn.BatchNorm1d(res3.shape[1])(res3)
    residual = self.linear5(res3.relu())
    res4 = res3 + residual
    print_variance("After fourth residual connection", res4)

    # Fifth residual block
    res4 = nn.BatchNorm1d(res4.shape[1])(res4)
    residual = self.linear6(res4.relu())
    res5 = res4 + residual
    print_variance("After fifth residual connection", res5)

    return self.linear7(res5)

In [16]:
# Define the model
n_hidden = 100
n_input = 40
n_output = 10
model = ResidualNetworkWithBatchNorm(n_input, n_output, n_hidden)
run_one_step_of_model(model, x_train, y_train)

Input variance=1.000050
First preactivation variance=1.860212
After first residual connection variance=1.747567
After second residual connection variance=1.670213
After third residual connection variance=1.716545
After fourth residual connection variance=1.681681
After fifth residual connection variance=1.748113


Note that the variance now increases linearly as in figure 11.6c.