# Normalization methods in deep learning (PyTorch)
In this notebook, we write code to verify the PyTorch APIs for several normalization methods.

## 1. Batch normalization (per channel over one mini-batch)
Applies batch normalization over a 4D input (a mini-batch of 2D inputs with an additional channel dimension), as described in the paper "Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift".
$$
y = \frac{x-\mathrm{E}(x)}{\sqrt{\mathrm{Var}(x)+\epsilon}}\cdot \gamma + \beta 
$$
The mean and variance are computed **per channel over the mini-batch** (that is, over the $N$, $H$ and $W$ dimensions), and $\gamma$ and $\beta$ are learnable parameter vectors of size $C$ (where $C$ is the number of channels).

The PyTorch API for BN is

`torch.nn.BatchNorm2d(num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, device=None, dtype=None)`
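The `momentum` and `track_running_stats` arguments in the signature above control the running statistics that the module uses in evaluation mode. The following is a minimal sketch, assuming the default `momentum=0.1` and freshly initialized buffers (running mean 0, running variance 1); the variable names are illustrative only.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
C = 3
x = torch.randn(8, C, 4, 4)  # N * C * H * W

bn = nn.BatchNorm2d(C, affine=False)  # momentum=0.1, track_running_stats=True
bn.train()
_ = bn(x)  # one forward pass in training mode updates the running buffers

batch_mean = x.mean(dim=(0, 2, 3))
# the running variance is updated with the unbiased batch variance
batch_var_unbiased = x.var(dim=(0, 2, 3), unbiased=True)

# exponential moving average starting from mean=0 and var=1
expected_mean = (1 - bn.momentum) * torch.zeros(C) + bn.momentum * batch_mean
expected_var = (1 - bn.momentum) * torch.ones(C) + bn.momentum * batch_var_unbiased
running_ok = (torch.allclose(bn.running_mean, expected_mean, atol=1e-6)
              and torch.allclose(bn.running_var, expected_var, atol=1e-6))

# in eval mode the running statistics replace the batch statistics
bn.eval()
eval_out = bn(x)
manual_eval = (x - bn.running_mean.view(1, C, 1, 1)) / torch.sqrt(
    bn.running_var.view(1, C, 1, 1) + bn.eps)
eval_ok = torch.allclose(eval_out, manual_eval, atol=1e-6)
print(running_ok, eval_ok)
```

This is why the same input can produce different outputs from a BN layer before and after calling `eval()`.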

### Statistics 
NLP: $N\times L\times C\rightarrow C$

CV: $N\times C\times H\times W \rightarrow C$
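For the NLP shape listed above, one way to check the $N\times L\times C\rightarrow C$ statistics is with `nn.BatchNorm1d`, which expects the channel dimension second. The sketch below assumes a small random `(N, L, C)` tensor and transposes it to `(N, C, L)` and back; the sizes are arbitrary.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
N, L, C = 2, 3, 4
seq = torch.randn(N, L, C)  # N * L * C

bn1d = nn.BatchNorm1d(C, affine=False)  # expects input of shape (N, C, L)
out = bn1d(seq.transpose(1, 2)).transpose(1, 2)

# one mean and one variance per channel, computed over the batch and sequence dimensions
mean = seq.mean(dim=(0, 1), keepdim=True)  # shape (1, 1, C)
var = seq.var(dim=(0, 1), unbiased=False, keepdim=True)
manual = (seq - mean) / torch.sqrt(var + bn1d.eps)

nlp_bn_ok = torch.allclose(out, manual, atol=1e-6)
print(mean.shape, nlp_bn_ok)
```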

In [19]:
import torch
import torch.nn as nn

# We write code to verify the PyTorch normalization APIs.
# generate data
num_channels = 3
data = torch.randn(2,num_channels,4,4) # N * C * H * W
# With learnable parameters (affine=True is the default); shown for reference only
torch_BN_affine = nn.BatchNorm2d(num_channels)
# Without learnable parameters; this module is compared against the manual formula
torch_BN = nn.BatchNorm2d(num_channels, affine=False)

# apply the PyTorch BN module (training mode, so batch statistics are used)
output_torch_BN = torch_BN(data)

In [20]:
# implement BN based on the formula in the original BN paper
channel_mean = torch.mean(data, keepdim=True, dim=(0,2,3)) # shape: (1, C, 1, 1), one mean per channel
channel_var = torch.var(data, dim=(0,2,3), unbiased=False, keepdim=True) # biased variance, shape: (1, C, 1, 1)
our_BN = (data - channel_mean)/torch.sqrt(channel_var+1e-5) # 1e-5 matches the default eps

torch.allclose(our_BN, output_torch_BN)

True
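The check above uses `affine=False`, so it does not exercise $\gamma$ and $\beta$. As a rough sketch of the full formula, the snippet below sets the module's `weight` ($\gamma$) and `bias` ($\beta$) to arbitrary illustrative values and compares against a manual computation.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
C = 3
x = torch.randn(2, C, 4, 4)  # N * C * H * W

bn = nn.BatchNorm2d(C)  # affine=True by default; a new module starts in training mode
with torch.no_grad():
    # arbitrary values chosen only for illustration
    bn.weight.copy_(torch.tensor([0.5, 1.0, 2.0]))   # gamma
    bn.bias.copy_(torch.tensor([-1.0, 0.0, 1.0]))    # beta

mean = x.mean(dim=(0, 2, 3), keepdim=True)
var = x.var(dim=(0, 2, 3), unbiased=False, keepdim=True)
gamma = bn.weight.view(1, C, 1, 1)  # broadcast one scale per channel
beta = bn.bias.view(1, C, 1, 1)     # broadcast one shift per channel
manual = (x - mean) / torch.sqrt(var + bn.eps) * gamma + beta

affine_ok = torch.allclose(bn(x), manual, atol=1e-6)
print(bn.weight.shape, bn.bias.shape, affine_ok)
```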

## 2. Layer normalization (per sample, per layer)
### Statistics 
NLP: $N\times L\times C\rightarrow N\times L$

CV: $N\times C\times H\times W \rightarrow N$

In [58]:
batch_size = 2
time_steps = 3
embedding_dim = 4
input_data = torch.randn(batch_size, time_steps, embedding_dim) # N*L*C
# call the PyTorch API nn.LayerNorm
pytorch_api_layerNorm = nn.LayerNorm(embedding_dim, elementwise_affine=False)
pytorch_api_layerNorm_output = pytorch_api_layerNorm(input_data)
# print(pytorch_api_layerNorm_output)
# write code based on the formula
input_mean = torch.mean(input_data, dim=-1, keepdim=True)
input_var = torch.var(input_data, dim=-1, keepdim=True, unbiased=False)
layerNorm_output = (input_data-input_mean)/torch.sqrt(input_var+1e-5)
print(input_mean.shape)
torch.allclose(layerNorm_output, pytorch_api_layerNorm_output) # verify if the results from two methods are same

torch.Size([2, 3, 1])


True

In [56]:
# Image Example
N, C, H, W = 20, 5, 10, 10
input_images = torch.randn(N, C, H, W) # avoid shadowing the built-in `input`
# Normalize over the last three dimensions (i.e. the channel and spatial dimensions),
# so each sample gets a single mean and variance
layer_norm = nn.LayerNorm([C, H, W])
output = layer_norm(input_images)
output.shape

torch.Size([20, 5, 10, 10])
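The image example above only inspects the output shape. A possible manual check of the $N\times C\times H\times W \rightarrow N$ statistics is sketched below; it assumes `elementwise_affine=False` so that only the normalization itself is compared.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
N, C, H, W = 20, 5, 10, 10
images = torch.randn(N, C, H, W)

layer_norm = nn.LayerNorm([C, H, W], elementwise_affine=False)
api_out = layer_norm(images)

# one mean and one variance per sample, computed over C, H and W
mean = images.mean(dim=(1, 2, 3), keepdim=True)  # shape (N, 1, 1, 1)
var = images.var(dim=(1, 2, 3), unbiased=False, keepdim=True)
manual = (images - mean) / torch.sqrt(var + layer_norm.eps)

image_ln_ok = torch.allclose(api_out, manual, atol=1e-5)
print(mean.shape, image_ln_ok)
```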

## 3. Instance normalization (per sample, per channel)
### Statistics 
NLP: $N\times L\times C\rightarrow N\times C$

CV: $N\times C\times H\times W \rightarrow N\times C$

In [59]:
# call pytorch api
pytorch_instance_normalization = nn.InstanceNorm1d(embedding_dim) # Input: (N, C, L) or (C, L), output: N*C*L or C*L
pytorch_instance_normalization_output = pytorch_instance_normalization(input_data.transpose(-1,-2)).transpose(-1,-2)

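# implement the formula: statistics are computed over the sequence dimension L (dim=1),
# giving one mean and one variance per sample and per channel
input_mean = torch.mean(input_data, dim=1, keepdim=True)
input_var = torch.var(input_data, dim=1, keepdim=True, unbiased=False)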
instance_normalization_math = (input_data-input_mean)/torch.sqrt(input_var+1e-5)
print(input_mean.shape) # statistics
print(torch.allclose(pytorch_instance_normalization_output, instance_normalization_math))

torch.Size([2, 1, 4])
True
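The same check can be written for the CV case, $N\times C\times H\times W \rightarrow N\times C$. The sketch below assumes `nn.InstanceNorm2d` with its defaults (no affine parameters, no running statistics) and arbitrary tensor sizes.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
N, C, H, W = 2, 3, 4, 4
images = torch.randn(N, C, H, W)

instance_norm = nn.InstanceNorm2d(C)  # affine=False, track_running_stats=False by default
api_out = instance_norm(images)

# one mean and one variance per sample and per channel, computed over H and W
mean = images.mean(dim=(2, 3), keepdim=True)  # shape (N, C, 1, 1)
var = images.var(dim=(2, 3), unbiased=False, keepdim=True)
manual = (images - mean) / torch.sqrt(var + instance_norm.eps)

image_in_ok = torch.allclose(api_out, manual, atol=1e-6)
print(mean.shape, image_in_ok)
```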


## 4. Group normalization (per sample, per group)
### Statistics 
NLP: $N\times L\times G\times (C//G)\rightarrow N\times G$

CV: $N\times G\times (C//G)\times H\times W \rightarrow N\times G$

In [38]:
num_groups = 2
# call the PyTorch API
pytorch_group_norm = nn.GroupNorm(num_groups, embedding_dim, affine=False) # arguments: num_groups, num_channels
pytorch_group_norm_output = pytorch_group_norm(input_data.transpose(-1,-2)).transpose(-1,-2) # GroupNorm expects (N, C, *)
# print(pytorch_group_norm_output)

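# implement the formula
# torch.split takes the size of each chunk, not the number of chunks,
# so each group must hold embedding_dim // num_groups channels
grouped_input = torch.split(input_data, embedding_dim // num_groups, dim=-1)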
output_list = []
for group in grouped_input:
    input_mean = torch.mean(group, dim=(1,2), keepdim=True)
    input_var = torch.var(group, dim=(1,2), keepdim=True, unbiased=False)
    group_norm_output = (group-input_mean)/torch.sqrt(input_var+1e-5)
    output_list.append(group_norm_output)
    
group_norm_output = torch.cat(output_list, dim=-1)
# print(group_norm_output)

torch.allclose(pytorch_group_norm_output, group_norm_output)

True
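For the CV case, the grouping can also be expressed with a reshape to $N\times G\times (C//G)\times H\times W$. The sketch below assumes arbitrary sizes with `C` divisible by `G`, and additionally checks two well-known special cases: one group behaves like layer normalization over `(C, H, W)`, and `G = C` behaves like instance normalization (all without affine parameters).

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
N, C, H, W = 2, 6, 4, 4
G = 3  # number of groups; C must be divisible by G
x = torch.randn(N, C, H, W)

group_norm = nn.GroupNorm(G, C, affine=False)
api_out = group_norm(x)

# consecutive channels form a group: (N, C, H, W) -> (N, G, C // G, H, W)
grouped = x.view(N, G, C // G, H, W)
mean = grouped.mean(dim=(2, 3, 4), keepdim=True)  # shape (N, G, 1, 1, 1)
var = grouped.var(dim=(2, 3, 4), unbiased=False, keepdim=True)
manual = ((grouped - mean) / torch.sqrt(var + group_norm.eps)).view(N, C, H, W)
image_gn_ok = torch.allclose(api_out, manual, atol=1e-6)

# special case 1: a single group normalizes over (C, H, W), like LayerNorm
one_group = nn.GroupNorm(1, C, affine=False)(x)
layer_norm = nn.LayerNorm([C, H, W], elementwise_affine=False)(x)
matches_layer_norm = torch.allclose(one_group, layer_norm, atol=1e-5)

# special case 2: one channel per group normalizes over (H, W), like InstanceNorm
per_channel = nn.GroupNorm(C, C, affine=False)(x)
instance_norm = nn.InstanceNorm2d(C)(x)
matches_instance_norm = torch.allclose(per_channel, instance_norm, atol=1e-5)

print(image_gn_ok, matches_layer_norm, matches_instance_norm)
```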

## 5. Weight normalization (decompose weight into magnitude and direction)
Reference: https://pytorch.org/docs/stable/generated/torch.nn.utils.weight_norm.html

`torch.nn.utils.weight_norm(module, name='weight', dim=0)` is a function rather than a class.
Its formula is
$$
\mathbf{w} = g\frac{\mathbf{v}}{\|\mathbf{v}\|}
$$
Weight normalization is a reparameterization that decouples the magnitude of a weight tensor from its direction. This replaces the parameter specified by name (e.g. `weight`) with two parameters: one specifying the magnitude (e.g. `weight_g`) and one specifying the direction (e.g. `weight_v`). 

By default, with `dim=0`, the norm is computed independently per output channel/plane.
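To make the effect of `dim` concrete, the sketch below inspects the shapes of `weight_g` and `weight_v` for a small linear layer, once with `dim=0` and once with `dim=None` (a single norm over the whole tensor). The layer sizes are arbitrary, and the warning filter is only there because newer PyTorch releases may emit a deprecation warning for this function.

```python
import warnings

import torch
import torch.nn as nn

torch.manual_seed(0)

with warnings.catch_warnings():
    warnings.simplefilter("ignore")  # newer releases may warn that this function is deprecated
    wn = nn.utils.weight_norm(nn.Linear(4, 3, bias=False), name="weight", dim=0)
    wn_all = nn.utils.weight_norm(nn.Linear(4, 3, bias=False), name="weight", dim=None)

g, v = wn.weight_g, wn.weight_v
print(g.shape, v.shape)  # one magnitude per output row, direction has the weight's shape

# with dim=0 the norm is taken over every dimension except dim 0, i.e. per output row
per_row_norm = v.norm(dim=1, keepdim=True)
rebuilt = g * v / per_row_norm  # w = g * v / ||v||

# with dim=None a single magnitude is shared by the whole weight tensor
print(wn_all.weight_g.numel())
```

Right after wrapping, `weight_g` equals the per-row norm of `weight_v`, so the effective weight is unchanged until training updates the two parameters separately.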

In [54]:
# define a module and apply weight normalization to it
linear = nn.Linear(embedding_dim, 3, bias=False)
pytorch_wn_linear = nn.utils.weight_norm(linear) # the output is still a module
pytorch_wn_linear_output = pytorch_wn_linear(input_data)
# print(pytorch_wn_linear_output)

# write weight normalization according to its definition: w = g * v / ||v||
# print(linear.weight_v.shape)
direction = linear.weight_v/torch.norm(linear.weight_v, dim=1, keepdim=True) # unit-norm rows
magnitude = linear.weight_g # initialized to the per-row norm of the original weight, not randomly
wn_weight = magnitude * direction
wn_linear_output = input_data @ wn_weight.transpose(-1,-2)
# print(wn_linear_output)
torch.allclose(pytorch_wn_linear_output, wn_linear_output)

True

In [69]:
# a more detailed look at weight normalization
batch_size = 2
feat_dim = 3
hid_dim = 4
data = torch.randn(batch_size, feat_dim)
linear = nn.Linear(feat_dim, hid_dim, bias=False)
wn_linear = nn.utils.weight_norm(linear) # modifies `linear` in place and returns it, so wn_linear is linear
print(linear.weight.shape)
print(linear.weight)

torch.Size([4, 3])
tensor([[ 0.3019, -0.1327, -0.1910],
        [-0.1707,  0.0177, -0.1221],
        [-0.0801, -0.4307, -0.1502],
        [ 0.0355, -0.3091, -0.3491]], grad_fn=<WeightNormInterfaceBackward0>)


In [73]:
# per-row L2 norm of the effective weight, shape (out_features, 1); detached so it is a plain tensor
weight_magnitude = linear.weight.detach().norm(dim=1, keepdim=True)
print("weight_magnitude", weight_magnitude.shape)
print(weight_magnitude)
weight_direction = linear.weight/weight_magnitude
print("weight_direction", weight_direction.shape)
print(weight_direction)
print("norm of weight direction:", (weight_direction**2).sum(dim=-1)) # squared L2 norm of each row; 1 for unit vectors

weight_magnitude torch.Size([4, 1])
tensor([[0.3811],
        [0.2106],
        [0.4631],
        [0.4676]])
weight_direction torch.Size([4, 3])
tensor([[ 0.7923, -0.3481, -0.5011],
        [-0.8105,  0.0842, -0.5796],
        [-0.1730, -0.9300, -0.3242],
        [ 0.0760, -0.6610, -0.7465]], grad_fn=<DivBackward0>)
norm of weight direction: tensor([1.0000, 1.0000, 1.0000, 1.0000], grad_fn=<SumBackward1>)


In [74]:
weight_direction*weight_magnitude

tensor([[ 0.3019, -0.1327, -0.1910],
        [-0.1707,  0.0177, -0.1221],
        [-0.0801, -0.4307, -0.1502],
        [ 0.0355, -0.3091, -0.3491]], grad_fn=<MulBackward0>)

In [75]:
linear(data)

tensor([[-0.2420,  0.0019, -0.2964, -0.3439],
        [-0.3751,  0.3547,  0.1179,  0.1043]], grad_fn=<MmBackward0>)

In [76]:
wn_linear(data)

tensor([[-0.2420,  0.0019, -0.2964, -0.3439],
        [-0.3751,  0.3547,  0.1179,  0.1043]], grad_fn=<MmBackward0>)

In [83]:
data @ (weight_direction*weight_magnitude).T

tensor([[-0.2420,  0.0019, -0.2964, -0.3439],
        [-0.3751,  0.3547,  0.1179,  0.1043]], grad_fn=<MmBackward0>)

In [84]:
# parameters of wn_linear
for n,p in wn_linear.named_parameters():
    print(n,p)

weight_g Parameter containing:
tensor([[0.3811],
        [0.2106],
        [0.4631],
        [0.4676]], requires_grad=True)
weight_v Parameter containing:
tensor([[ 0.3019, -0.1327, -0.1910],
        [-0.1707,  0.0177, -0.1221],
        [-0.0801, -0.4307, -0.1502],
        [ 0.0355, -0.3091, -0.3491]], requires_grad=True)


In [87]:
# reconstruct the effective weight of the linear layer from its two parameters: w = g * v / ||v||
weight_v_norm = wn_linear.weight_v.detach().norm(dim=1, keepdim=True) # per-row L2 norm, shape (4, 1)
wn_linear.weight_g * wn_linear.weight_v/weight_v_norm

tensor([[ 0.3019, -0.1327, -0.1910],
        [-0.1707,  0.0177, -0.1221],
        [-0.0801, -0.4307, -0.1502],
        [ 0.0355, -0.3091, -0.3491]], grad_fn=<DivBackward0>)

In [88]:
weight_v_norm

tensor([[0.3811],
        [0.2106],
        [0.4631],
        [0.4676]])

In [89]:
wn_linear.weight_g

Parameter containing:
tensor([[0.3811],
        [0.2106],
        [0.4631],
        [0.4676]], requires_grad=True)