# Layer Norm

In [1]:
import torch
import torch.nn as nn


# Stabiliser added to the variance before the square root; the same value is
# handed to both nn.LayerNorm and custom_layer_norm in the comparison below.
eps = 1e-10

In [2]:
def custom_layer_norm(input_tensor, eps):
    """Normalize every sample of a batch over all of its non-batch dimensions.

    Each slice ``input_tensor[i]`` is shifted by its own mean and divided by
    ``sqrt(var + eps)``, where the biased variance is taken over all elements
    of that slice. No learnable scale/shift is applied, which is equivalent
    to gamma = 1 and beta = 0.

    Args:
        input_tensor: Floating-point tensor with at least two dimensions;
            dimension 0 is treated as the batch dimension.
        eps: Small constant added to the variance before the square root.

    Returns:
        A tensor with the same shape, dtype and device as ``input_tensor``.
    """
    # Collapse everything after the batch axis so that the statistics of one
    # sample live on a single axis, whatever the rank of the input is.
    flat = input_tensor.flatten(start_dim=1)

    mean = flat.mean(dim=1, keepdim=True)
    var = flat.var(dim=1, unbiased=False, keepdim=True)

    # Deriving the result from the input itself (rather than writing into a
    # freshly allocated default-dtype buffer) preserves dtype and device.
    normed = (flat - mean) / torch.sqrt(var + eps)
    return normed.reshape(input_tensor.shape)

In [3]:
# Compare custom_layer_norm against nn.LayerNorm (without affine parameters)
# on random inputs of rank 3 through 8; the input shapes are
# (3, 4, ..., dim_count + 1).
all_correct = True
for dim_count in range(3, 9):
    input_tensor = torch.randn(tuple(range(3, dim_count + 2)), dtype=torch.float32)
    layer_norm = nn.LayerNorm(input_tensor.shape[1:], eps=eps, elementwise_affine=False)

    custom_output = custom_layer_norm(input_tensor, eps)
    norm_output = layer_norm(input_tensor)

    # Both the values (relative tolerance 1e-2) and the shapes must agree.
    all_correct = all((
        all_correct,
        torch.allclose(norm_output, custom_output, rtol=1e-2),
        norm_output.shape == custom_output.shape,
    ))
print(all_correct)

True


# Instance Norm

In [4]:
import torch
import torch.nn as nn

In [5]:
# Dimensions of the random test batch: (batch_size, input_channels, input_length).
batch_size = 5
input_channels = 2
input_length = 30

# Stabiliser added to the variance before the square root.
eps = 1e-3

In [6]:
# Reference layer without learnable affine parameters, using the shared eps.
instance_norm = nn.InstanceNorm1d(input_channels, eps=eps, affine=False)

# Random float32 batch of shape (batch_size, input_channels, input_length).
input_tensor = torch.randn((batch_size, input_channels, input_length), dtype=torch.float32)

In [7]:
def custom_instance_norm1d(input_tensor, eps):
    """Normalize every (sample, channel) slice of a batch independently.

    For each pair of indices ``(n, c)`` the slice ``input_tensor[n, c]`` is
    shifted by its own mean and divided by ``sqrt(var + eps)``, using the
    biased variance of that slice. No learnable scale/shift is applied,
    which is equivalent to gamma = 1 and beta = 0.

    Args:
        input_tensor: Floating-point tensor of shape (batch, channels, length).
            Additional trailing dimensions are accepted and are normalized
            together with ``length``.
        eps: Small constant added to the variance before the square root.

    Returns:
        A tensor with the same shape, dtype and device as ``input_tensor``.
    """
    # Put everything belonging to one (sample, channel) pair on a single axis.
    flat = input_tensor.flatten(start_dim=2)

    mean = flat.mean(dim=2, keepdim=True)
    var = flat.var(dim=2, unbiased=False, keepdim=True)

    # Deriving the result from the input itself (rather than writing into a
    # freshly allocated default-dtype buffer) preserves dtype and device.
    normed = (flat - mean) / torch.sqrt(var + eps)
    return normed.reshape(input_tensor.shape)

In [8]:
# Run both implementations on the same batch and report whether they agree
# in values (default allclose tolerances) and in shape.
custom_output = custom_instance_norm1d(input_tensor, eps)
norm_output = instance_norm(input_tensor)
print(all((
    torch.allclose(norm_output, custom_output),
    norm_output.shape == custom_output.shape,
)))

True


# Group Norm

In [15]:
import torch
import torch.nn as nn

In [16]:
# Stabiliser added to the variance before the square root.
eps = 1e-3

# Dimensions of the random test batch: (batch_size, channel_count, input_size).
batch_size = 20
channel_count = 6
input_size = 2

input_tensor = torch.randn((batch_size, channel_count, input_size))

In [17]:
groups = 3
# Pass eps explicitly: without it nn.GroupNorm falls back to its own default,
# which differs from the eps the custom implementation is called with.
group_norm = nn.GroupNorm(groups, channel_count, eps=eps, affine=False)
input_tensor.shape

torch.Size([20, 6, 2])

In [23]:
def custom_group_norm(input_tensor, groups, eps):
    """Normalize each sample over groups of consecutive channels.

    The channel axis (dimension 1) is split into ``groups`` equally sized
    runs of consecutive channels. Within every sample, each run is shifted
    by its own mean and divided by ``sqrt(var + eps)``, using the biased
    variance over all elements of that run. No learnable scale/shift is
    applied, which is equivalent to gamma = 1 and beta = 0.

    Args:
        input_tensor: Floating-point tensor of shape (batch, channels, ...).
        groups: Number of channel groups; must be a positive divisor of the
            number of channels.
        eps: Small constant added to the variance before the square root.

    Returns:
        A tensor with the same shape, dtype and device as ``input_tensor``.

    Raises:
        ValueError: If ``groups`` is not a positive divisor of the number of
            channels.
    """
    batch, channels = input_tensor.shape[0], input_tensor.shape[1]
    # An uneven split would silently produce groups of the wrong size, so
    # reject it up front.
    if groups <= 0 or channels % groups != 0:
        raise ValueError(
            f"groups ({groups}) must be a positive divisor of the number "
            f"of channels ({channels})"
        )

    # (batch, groups, channels_per_group, ...) -> (batch, groups, elements),
    # so that every group's statistics live on the last axis.
    grouped = input_tensor.reshape(
        batch, groups, channels // groups, *input_tensor.shape[2:]
    ).flatten(start_dim=2)

    mean = grouped.mean(dim=2, keepdim=True)
    var = grouped.var(dim=2, unbiased=False, keepdim=True)

    # Deriving the result from the input itself (rather than writing into a
    # freshly allocated default-dtype buffer) preserves dtype and device.
    normed = (grouped - mean) / torch.sqrt(var + eps)
    return normed.reshape(input_tensor.shape)

In [24]:
# Compare custom_group_norm against nn.GroupNorm (without affine parameters)
# for every group count that evenly divides the six channels.
all_correct = True
for groups in (1, 2, 3, 6):
    group_norm = nn.GroupNorm(num_groups=groups, num_channels=channel_count, eps=eps, affine=False)
    all_correct = all_correct & torch.allclose(
        group_norm(input_tensor),
        custom_group_norm(input_tensor, groups, eps),
        rtol=1e-3,
    )
print(all_correct)

True
