<a href="https://colab.research.google.com/github/ayushksingh28/transformers_scratch/blob/main/Layer_Normalization.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
from torch import nn

In [None]:
# Toy batch: B=1 sequence of S=2 tokens, each an E=3 dimensional embedding.
inputs = torch.Tensor([[[0.2, 0.1, 0.3], [0.5, 0.1, 0.1]]])
B, S, E = inputs.size()
# Switch to a sequence-first (S, B, E) layout by swapping the first two axes.
# transpose moves whole token vectors; reshape(S, B, E) would keep the flat
# element order and therefore mix tokens across sequences whenever B > 1.
inputs = inputs.transpose(0, 1)
inputs.size()

torch.Size([2, 1, 3])

In [None]:
# gamma (scale) and beta (shift) are learnable and shaped like the trailing
# two axes of `inputs`; they start as the identity transform (ones / zeros).
parameter_shape = inputs.shape[-2:]
gamma = nn.Parameter(torch.ones(size=parameter_shape))
beta = nn.Parameter(torch.zeros(size=parameter_shape))

In [None]:
# Display the shapes of the scale and shift parameters.
gamma.shape, beta.shape

(torch.Size([1, 3]), torch.Size([1, 3]))

In [None]:
# Negative indices of the trailing axes covered by parameter_shape,
# innermost first: -1, -2, ..., -len(parameter_shape).
dims = list(range(-1, -len(parameter_shape) - 1, -1))

In [None]:
# Display the list of axes computed above.
dims

[-1, -2]

In [None]:
# Average over the axes in `dims`; keepdim=True keeps them as size-1 axes so
# the result broadcasts against `inputs` later.
mean = torch.mean(inputs, dim=dims, keepdim=True)
mean.shape

torch.Size([2, 1, 1])

In [None]:
# Display the computed means.
mean

tensor([[[0.2000]],

        [[0.2333]]])

In [None]:
# Small constant added to the variance before the square root.
epsilon = 1e-5
# Biased variance: mean of squared deviations over the same axes as `mean`.
var = (inputs - mean).pow(2).mean(dim=dims, keepdim=True)
std = torch.sqrt(var + epsilon)
std

tensor([[[0.0817]],

        [[0.1886]]])

In [None]:
# Standardize: subtract the mean and divide by the standard deviation.
y = torch.div(inputs - mean, std)
y

tensor([[[ 0.0000, -1.2238,  1.2238]],

        [[ 1.4140, -0.7070, -0.7070]]])

In [None]:
# Apply the learnable element-wise scale (gamma) and shift (beta).
out = torch.add(torch.mul(gamma, y), beta)

In [None]:
# Display the scaled-and-shifted result.
out

tensor([[[ 0.0000, -1.2238,  1.2238]],

        [[ 1.4140, -0.7070, -0.7070]]], grad_fn=<AddBackward0>)

# Class

In [None]:
import torch
from torch import nn

class LayerNormalization(nn.Module):
    """Layer normalization over the trailing ``len(parameters_shape)`` axes.

    Every slice spanned by those axes is shifted to zero mean and scaled to
    unit (biased) variance, then transformed element-wise by the learnable
    ``gamma`` (scale) and ``beta`` (shift).

    Args:
        parameters_shape: Shape of the trailing axes to normalize over.
            ``gamma`` and ``beta`` are created with exactly this shape.
        eps: Constant added to the variance before taking the square root,
            so the division stays finite when the variance is zero.
    """

    def __init__(self, parameters_shape, eps=1e-5):
        # Being an nn.Module registers gamma/beta, so they show up in
        # .parameters(), state_dict() and follow .to(device).
        super().__init__()
        self.parameters_shape = parameters_shape
        self.eps = eps
        self.gamma = nn.Parameter(torch.ones(parameters_shape))
        self.beta = nn.Parameter(torch.zeros(parameters_shape))

    def forward(self, input):
        """Normalize ``input`` and apply the learned scale and shift.

        Prints the intermediate mean, standard deviation, normalized tensor
        and final output for inspection.

        Args:
            input: Tensor whose trailing axes match ``parameters_shape``.
                (The name shadows the builtin but is kept so keyword
                callers keep working.)

        Returns:
            Tensor with the same shape as ``input``.
        """
        # Trailing axes to reduce over: -1, -2, ..., -len(parameters_shape).
        dims = [-(i + 1) for i in range(len(self.parameters_shape))]
        # Use the argument itself; reading a global here would normalize
        # whatever tensor happens to be bound to that name in the notebook.
        mean = input.mean(dim=dims, keepdim=True)
        print(f"Mean \n ({mean.size()}): \n {mean}")
        var = ((input - mean) ** 2).mean(dim=dims, keepdim=True)
        std = (var + self.eps).sqrt()
        print(f"Standard Deviation \n ({std.size()}): \n {std}")
        y = (input - mean) / std
        print(f"y \n ({y.size()}) = \n {y}")
        out = self.gamma * y + self.beta
        print(f"out \n ({out.size()}) = \n {out}")
        return out

In [None]:
# Random demo tensor laid out as (sentence_length, batch_size, embedding_dim).
batch_size, sentence_length, embedding_dim = 3, 5, 8
inputs = torch.randn((sentence_length, batch_size, embedding_dim))

print(f"input \n ({inputs.size()}) = \n {inputs}")

input 
 (torch.Size([5, 3, 8])) = 
 tensor([[[-0.5132, -0.9555, -0.3850,  1.4630, -0.0091,  1.3365,  0.7020,
           1.2015],
         [ 1.0270,  1.4306, -3.1770, -0.0610,  3.5120, -1.4810, -0.5528,
          -1.5328],
         [-2.0176,  0.4930, -0.2202,  0.9158,  0.6347, -1.1098,  1.5739,
           1.1471]],

        [[-0.4023,  0.8491,  1.5738, -1.0746,  0.7685, -0.9458,  0.6143,
          -0.1341],
         [-1.3037,  0.2526,  0.0357,  0.1983,  0.1162,  0.0347,  0.0358,
          -0.3344],
         [ 0.0660,  0.3007, -0.2052,  2.9540,  0.1326,  0.5147,  0.8439,
          -0.5573]],

        [[ 1.1804,  0.4248,  1.4208,  2.0492, -0.7309,  1.6413, -1.2175,
           0.8234],
         [ 0.3516,  1.0422,  0.4247,  0.3434, -0.2263, -2.0540, -0.0428,
           0.1775],
         [ 0.0206,  0.6769, -0.2768, -1.3291,  0.2528, -1.1714,  0.6798,
           1.0746]],

        [[ 1.2992,  0.7476, -0.3813,  1.2765, -0.1336,  0.1039, -1.4050,
          -0.7291],
         [-0.1854, -0.0963, 

In [None]:
# Normalize over the last axis only (the embedding dimension of `inputs`).
layer_norm = LayerNormalization(parameters_shape=inputs.shape[-1:])

In [None]:
# Run the layer; forward prints the intermediate tensors along the way.
out = layer_norm.forward(input=inputs)

Mean 
 (torch.Size([5, 3, 1])): 
 tensor([[[ 0.3550],
         [-0.1044],
         [ 0.1771]],

        [[ 0.1561],
         [-0.1206],
         [ 0.5062]],

        [[ 0.6989],
         [ 0.0020],
         [-0.0091]],

        [[ 0.0973],
         [-0.0308],
         [-0.4872]],

        [[-0.2696],
         [ 0.4218],
         [ 0.0180]]])
Standard Deviation 
 (torch.Size([5, 3, 1])): 
 tensor([[[0.8789],
         [1.9460],
         [1.1390]],

        [[0.8803],
         [0.4764],
         [1.0074]],

        [[1.0766],
         [0.8523],
         [0.8181]],

        [[0.9000],
         [0.5809],
         [0.7405]],

        [[0.8176],
         [0.8589],
         [0.8164]]])
y 
 (torch.Size([5, 3, 8])) = 
 tensor([[[-0.9879, -1.4912, -0.8420,  1.2606, -0.4143,  1.1167,  0.3948,
           0.9632],
         [ 0.5814,  0.7888, -1.5790,  0.0223,  1.8584, -0.7074, -0.2304,
          -0.7341],
         [-1.9269,  0.2774, -0.3489,  0.6485,  0.4018, -1.1299,  1.2264,
           0.8517]],

