# Implementing Batchnorm




In this notebook, I implement [Batchnorm](https://arxiv.org/pdf/1502.03167)

## Imports

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch import Tensor
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from torch.nn.parameter import Parameter
import math

import time
device = 'cuda' if torch.cuda.is_available() else 'cpu'

## Get Data

In [2]:
training_data = datasets.MNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor()
)

test_data = datasets.MNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor()
)
train_dataloader = DataLoader(training_data, batch_size=128, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=128, shuffle=True)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 265518678.56it/s]

Extracting data/MNIST/raw/train-images-idx3-ubyte.gz to data/MNIST/raw






Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 101284024.94it/s]


Extracting data/MNIST/raw/train-labels-idx1-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 188336139.99it/s]

Extracting data/MNIST/raw/t10k-images-idx3-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz





Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 3725900.40it/s]

Extracting data/MNIST/raw/t10k-labels-idx1-ubyte.gz to data/MNIST/raw






## Create Network


In [3]:
class BatchNorm(nn.Module):
  def __init__(self):
        super(BatchNorm, self).__init__()
        self.beta = Parameter(torch.empty(1,device=device))
        self.gamma = Parameter(torch.empty(1,device=device))
        torch.nn.init.uniform_(self.beta)
        torch.nn.init.uniform_(self.gamma)

  def forward(self, x):
    mean = torch.mean(x)
    diff = x - mean
    stdev = torch.mean(diff*diff)
    z = (x-mean)/torch.sqrt(stdev+0.000001)
    return self.gamma*z + self.beta

In [4]:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        #we reduce the standard deviation of initialised weights to reduce chance of gradient explosion

        self.p1 = nn.Linear(784,512,device=device)
        self.bn1 = BatchNorm()
        self.p2 = nn.Linear(512,512,device=device)
        self.bn2 = BatchNorm()
        self.p3 = nn.Linear(512,256,device=device)
        self.bn3 = BatchNorm()
        self.p4 = nn.Linear(256,10,device=device)

    def forward(self, x):

        x = F.relu(self.p1(x))
        x = self.bn1(x)
        x = F.relu(self.p2(x))
        x = self.bn2(x)
        x = F.relu(self.p3(x))
        x = self.bn3(x)
        x = self.p4(x)

        return x


## Train

In [5]:

def train_one_epoch(model):
    loss_fn = nn.CrossEntropyLoss()
    model.to(device)

    optimizer = optim.Adam(model.parameters(),lr = 0.001)

    for inputs, labels in train_dataloader:
      optimizer.zero_grad()
      inputs = inputs.to(device)
      labels = labels.to(device)
      outputs = model.forward(inputs.reshape(-1,784))
      loss = loss_fn(outputs, labels)
      loss.backward()
      optimizer.step()



## Test

In [6]:
def test(model, dataloader = test_dataloader):

    with torch.no_grad():
        correct = 0
        total = 0
        for inputs, labels in dataloader:
            inputs = inputs.to(device)
            labels=labels.to(device)
            outputs = model.forward(inputs.reshape(-1,784))
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return correct/total

## Test

In [7]:
model = Net()
model.to(device)


for _ in range(10):
  train_one_epoch(model)
  test_acc = test(model)
  print(f"test acc : {test_acc}")




test acc : 0.9674
test acc : 0.9737
test acc : 0.975
test acc : 0.9799
test acc : 0.979
test acc : 0.9801
test acc : 0.983
test acc : 0.9818
test acc : 0.9812
test acc : 0.9835
