# Normalization



In this notebook, I implement weight norm, batch norm and layer norm.

[Weight normalization](https://arxiv.org/pdf/1602.07868.pdf) is a method of reparameterizing the weight vector `w` in the standard linear neural network layer

```
y = φ(w · x + b)
```

 as

```
w = (v*g)/(||v||)
```
where `v` is the learnable weight vector and `g` is a learnable scalar. This reparameterization has the effect of decoupling the magnitude and the direction of the weight vector.

Both [layer](https://arxiv.org/abs/1607.06450) and [batch](https://arxiv.org/abs/1502.03167) normalization are methods of reducing covariate shift (change in input distribution) by normalizing the input to each layer. They involve normalizing the input

```
input_normalized = (input-input_mean)/input_std
```

which is then transformed using learnable parameters a and b

```
final_input = a * input_normalized + b
```
This results in each layer in the network receiving input with a consistent distribution, allowing it to train faster and become more accurate.

The difference between batch norm and layer norm lies in how `input_mean` and `input_std` are calculated. In batch norm, they are typically calculated across each minibatch, while in layer norm they are calculated individually for each input vector to each layer.

Layer norm was created to remove the relationship between batch size and the effect of normalization.




## Imports

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch import Tensor
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from torch.nn.parameter import Parameter
import math

import time
device = 'cuda' if torch.cuda.is_available() else 'cpu'

## Get Data

In [2]:
training_data = datasets.MNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor()
)

test_data = datasets.MNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor()
)
train_dataloader = DataLoader(training_data, batch_size=128, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=128, shuffle=True)

## Weight Norm


The below code is an edited version of PyTorch's `Linear` layer with weight normalization added.

In [3]:
class WeightNormalizedLinear(nn.Module):
    __constants__ = ['in_features', 'out_features']
    in_features: int
    out_features: int
    weight: Tensor

    def __init__(self, in_features: int, out_features: int, bias: bool = True,
                 device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.g = Parameter(torch.empty((out_features,in_features),**factory_kwargs))
        self.weight = Parameter(torch.empty((out_features, in_features), **factory_kwargs))
        if bias:
            self.bias = Parameter(torch.empty(out_features, **factory_kwargs))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self) -> None:
        nn.init.kaiming_uniform_(self.g, a=math.sqrt(5))
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
            nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, input: Tensor) -> Tensor:
        return F.linear(input, self.g*(self.weight)/torch.norm(self.weight,dim=0), self.bias)

    def extra_repr(self) -> str:
        return 'in_features={}, out_features={}, bias={}'.format(
            self.in_features, self.out_features, self.bias is not None
        )

## Batch Norm

In [4]:
class BatchNorm(nn.Module):
  def __init__(self):
        super(BatchNorm, self).__init__()
        self.beta = Parameter(torch.empty(1,device=device))
        self.gamma = Parameter(torch.empty(1,device=device))
        torch.nn.init.uniform_(self.beta)
        torch.nn.init.uniform_(self.gamma)

  def forward(self, x):
    mean = torch.mean(x)
    diff = x - mean
    stdev = torch.mean(diff*diff)
    z = (x-mean)/torch.sqrt(stdev+0.000001)
    return self.gamma*z + self.beta

## Layer Norm

In [5]:
class LayerNorm(nn.Module):
  def __init__(self):
        super(LayerNorm, self).__init__()
        self.beta = Parameter(torch.empty(1,device=device))
        self.gamma = Parameter(torch.empty(1,device=device))
        torch.nn.init.uniform_(self.beta)
        torch.nn.init.uniform_(self.gamma)

  def forward(self, x):
    mean = torch.mean(x,dim=1)
    diff = x - mean.reshape(-1,1)
    stdev = torch.mean(diff*diff,dim=1)
    z = (diff)/torch.sqrt(stdev.reshape(-1,1)+0.000001)
    return self.gamma*z + self.beta

## Model

In [6]:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()

        self.p1 = WeightNormalizedLinear(784,512,device=device)
        self.ln1 = LayerNorm()
        self.p2 = WeightNormalizedLinear(512,512,device=device)
        self.bn2 = BatchNorm()
        self.p3 = WeightNormalizedLinear(512,256,device=device)
        self.ln3 = LayerNorm()
        self.p4 = WeightNormalizedLinear(256,10,device=device)

    def forward(self, x):

        x = F.relu(self.p1(x))
        x = self.ln1(x)
        x = F.relu(self.p2(x))
        x = self.bn2(x)
        x = F.relu(self.p3(x))
        x = self.ln3(x)
        x = self.p4(x)

        return x

In [7]:
model = Net()
model.to(device)
optimizer = optim.Adam(model.parameters(),lr = 0.001)

## Train/Test

In [8]:
def train_one_epoch(model):
    loss_fn = nn.CrossEntropyLoss()
    model.to(device)
    for inputs, labels in train_dataloader:
      optimizer.zero_grad()
      outputs = model.forward(inputs.to(device).reshape(-1,784))
      loss = loss_fn(outputs, labels.to(device))
      loss.backward()
      optimizer.step()

In [9]:
def test(model, dataloader = test_dataloader):

    with torch.no_grad():
        correct = 0
        total = 0
        for inputs, labels in dataloader:
            outputs = model.forward(inputs.to(device).reshape(-1,784))
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels.to(device)).sum().item()
    return correct/total

In [11]:
for epoch in range(5):
  train_one_epoch(model)
  test_acc = test(model)
  print(f"epoch : {epoch+1}, test acc : {test_acc}")

epoch : 1, test acc : 0.9806
epoch : 2, test acc : 0.9805
epoch : 3, test acc : 0.9815
epoch : 4, test acc : 0.9799
epoch : 5, test acc : 0.9814
