This involves normalizing the inputs across the feature dimension(s) of each example instead of across the batch dimension. Doing so helps training remain stable.

In [23]:
import torch
import torch.nn as nn
import numpy as np

## Initializing Inputs

In [5]:
# One batch holding two tokens, each with a 3-dimensional embedding:
# shape (batch=1, seq_len=2, emb_dim=3).
inputs = torch.tensor([[[0.2, 0.1, 0.3], [0.5, 0.1, 0.1]]])
batch, seq_len, emb_dim = inputs.size()
# Swap the batch and sequence axes -> (seq_len, batch, emb_dim).
# `permute` really moves the axes. `view` only reinterprets the flat memory,
# which matches a transpose solely because batch == 1 here and would mix
# tokens from different sequences for any larger batch.
inputs = inputs.permute(1, 0, 2)
inputs.size()

torch.Size([2, 1, 3])

## Creating learnable parameters

In [10]:
# The scale (gamma) and shift (beta) span the last two axes of `inputs`.
# gamma starts at 1 and beta at 0, so the affine step is initially the identity.
parameter_shape = inputs.shape[-2:]
gamma = nn.Parameter(torch.ones(*parameter_shape))
beta = nn.Parameter(torch.zeros(*parameter_shape))
parameter_shape

torch.Size([1, 3])

## Computing the mean and standard deviation

In [16]:
# Reduce over the trailing axes covered by the parameters, counted from the
# end: one negative index per entry of `parameter_shape` (-1, -2, ...).
dims = list(range(-1, -len(parameter_shape) - 1, -1))
means = inputs.mean(dim=dims, keepdim=True)
# Population variance: the mean of the squared deviations from the mean.
var = (inputs - means).pow(2).mean(dim=dims, keepdim=True)
# epsilon keeps the square root (and the later division) away from zero.
epsilon = 1e-5
std = torch.sqrt(var + epsilon)

In [18]:
# Display the means; keepdim=True left the reduced axes in place with size 1.
means

tensor([[[0.2000]],

        [[0.2333]]])

In [19]:
# Display the standard deviations, i.e. sqrt(var + epsilon).
std

tensor([[[0.0817]],

        [[0.1886]]])

## Normalizing the input

In [20]:
# Standardize: subtract the mean, then divide by the standard deviation.
y = torch.div(inputs - means, std)

In [22]:
# Learnable affine step: rescale by gamma, then shift by beta.
out = y * gamma + beta

## Putting this in a class

In [39]:
class LayerNormalization(nn.Module):
    """Layer normalization over the trailing ``parameter_shape`` dimensions.

    Each slice of the input spanning the last ``len(parameter_shape)`` axes is
    shifted to zero mean and scaled to unit (population) variance, then passed
    through a learnable element-wise affine map ``gamma * y + beta``.

    Args:
        parameter_shape: Shape of the trailing dimensions to normalize over
            (e.g. ``inputs.size()[-1:]``). A single ``int`` is treated as a
            one-dimensional shape.
        eps: Small constant added to the variance before the square root so
            the division never hits zero.

    Attributes:
        gamma: Learnable scale, initialised to ones, shape ``parameter_shape``.
        beta: Learnable shift, initialised to zeros, shape ``parameter_shape``.
    """

    def __init__(self, parameter_shape, eps = 1e-5):
        super().__init__()
        if isinstance(parameter_shape, int):
            parameter_shape = (parameter_shape,)
        self.parameter_shape = parameter_shape
        self.eps = eps
        # Registered as nn.Parameter so they are trainable, show up in
        # .parameters() and follow the module across devices and dtypes.
        self.gamma = nn.Parameter(torch.ones(tuple(parameter_shape)))
        self.beta = nn.Parameter(torch.zeros(tuple(parameter_shape)))

    def forward(self, input):
        """Normalize ``input`` and apply the learnable scale and shift.

        Prints the intermediate mean, standard deviation, normalized tensor
        and final output for inspection.

        Args:
            input: Tensor whose trailing dimensions equal ``parameter_shape``.

        Returns:
            Tensor with the same shape as ``input``.

        Raises:
            ValueError: If the trailing dimensions of ``input`` do not match
                ``parameter_shape``.
        """
        n_norm = len(self.parameter_shape)
        expected = tuple(self.parameter_shape)
        # Fail loudly on a mismatch instead of letting gamma/beta broadcast
        # silently into a tensor with the wrong shape or statistics.
        if input.dim() < n_norm or tuple(input.shape[input.dim() - n_norm:]) != expected:
            raise ValueError(
                f"Expected input with trailing shape {expected}, "
                f"got input of shape {tuple(input.shape)}"
            )
        # Trailing axes counted from the end: -1, -2, ...
        dims = [-(i+1) for i in range(n_norm)]
        mean = input.mean(dim = dims, keepdim = True)
        print(f"Mean \n ({mean.size()}): \n {mean}")
        # Population variance: the mean of the squared deviations.
        var = ((input-mean)**2).mean(dim = dims, keepdim = True)
        std = (var+self.eps).sqrt()
        print(f"Standard Deviation \n ({std.size()}): \n {std}")
        y = (input-mean)/std
        print(f"y \n ({y.size()}) = \n {y}")
        out = self.gamma*y + self.beta
        print(f"out \n ({out.size()}) = \n {out}")
        return out

In [40]:
batch_size = 3
sentence_length = 5
embedding_dim = 8

# Seed the generator so Restart & Run All reproduces the same random input.
torch.manual_seed(0)
# Sequence-first layout: (sentence_length, batch_size, embedding_dim).
inputs = torch.randn(sentence_length, batch_size, embedding_dim)
print(f"input \n ({inputs.size()}) = \n {inputs}")

input 
 (torch.Size([5, 3, 8])) = 
 tensor([[[-5.5680e-01, -1.1699e+00,  1.1403e+00,  4.0501e-01, -1.3402e+00,
           1.8886e-01, -1.8179e-01,  2.3194e+00],
         [ 1.2173e-01, -1.2331e+00,  3.5195e-01, -1.2103e+00, -1.2475e+00,
          -4.6125e-01, -1.4173e-01,  6.9999e-01],
         [-5.4593e-01, -9.4886e-01,  4.0205e-04,  1.0998e+00, -4.0626e-01,
          -2.1154e+00,  7.7490e-01, -1.4720e+00]],

        [[ 3.7019e-01, -5.8448e-01, -1.4146e+00, -6.5040e-01,  1.6822e+00,
          -9.6505e-01, -9.2295e-01,  2.5152e-01],
         [-7.5900e-01, -1.2131e+00,  1.3730e+00, -1.7024e-01, -2.7530e-01,
          -5.8667e-02,  2.3098e+00,  6.9863e-01],
         [ 1.5008e-02,  1.6094e+00,  2.9958e-01, -4.1022e-01, -4.3226e-01,
           1.1714e-01,  7.9998e-02,  2.4723e+00]],

        [[-5.7140e-01,  1.6102e+00, -4.2432e-01, -2.3240e+00, -1.0647e+00,
          -1.1813e+00, -1.6541e+00, -1.0832e+00],
         [ 2.2693e-01,  1.0672e+00,  1.3820e+00, -7.8603e-01,  5.0952e-01,
          

In [41]:
# Shape of the last axis only; this is what gets passed to the layer below.
inputs.shape[-1:]

torch.Size([8])

In [42]:
# Normalize only across the embedding dimension (the last axis).
layer_norm = LayerNormalization(parameter_shape=inputs.shape[-1:])
out = layer_norm.forward(input=inputs)

Mean 
 (torch.Size([5, 3, 1])): 
 tensor([[[ 0.1006],
         [-0.3900],
         [-0.4517]],

        [[-0.2792],
         [ 0.2381],
         [ 0.4689]],

        [[-0.8366],
         [ 0.5096],
         [-0.2611]],

        [[ 0.1934],
         [ 0.6887],
         [ 0.2054]],

        [[-0.2870],
         [-0.0373],
         [ 0.4659]]])
Standard Deviation 
 (torch.Size([5, 3, 1])): 
 tensor([[[1.1351],
         [0.7234],
         [1.0114]],

        [[0.9324],
         [1.0846],
         [0.9624]],

        [[1.0805],
         [0.6199],
         [1.1094]],

        [[0.7478],
         [1.1591],
         [1.1932]],

        [[0.7404],
         [0.7177],
         [0.7749]]])
y 
 (torch.Size([5, 3, 8])) = 
 tensor([[[-5.7919e-01, -1.1194e+00,  9.1597e-01,  2.6818e-01, -1.2694e+00,
           7.7748e-02, -2.4879e-01,  1.9548e+00],
         [ 7.0746e-01, -1.1654e+00,  1.0257e+00, -1.1339e+00, -1.1854e+00,
          -9.8447e-02,  3.4324e-01,  1.5068e+00],
         [-9.3195e-02, -4.9159e