## Layer Normalization Class

In [1]:
import torch
from torch import nn

class LayerNormalization(nn.Module):
    """Layer normalization over the trailing dimensions of a tensor.

    Each slice spanned by the last ``len(parameters_shape)`` dimensions is
    shifted to zero mean and scaled to unit (biased) variance, then passed
    through a learnable element-wise affine transform ``gamma * y + beta``.

    Subclassing ``nn.Module`` is what registers ``gamma`` and ``beta`` as
    parameters, so they appear in ``.parameters()`` / ``.state_dict()`` and
    follow ``.to(device)``; it also makes the instance callable.

    Args:
        parameters_shape: Shape of the trailing dimensions to normalize over
            (e.g. ``inputs.size()[-1:]``). A single ``int`` is treated as a
            one-dimensional shape.
        eps: Constant added to the variance before the square root to avoid
            division by zero.
    """

    def __init__(self, parameters_shape, eps=1e-5):
        super().__init__()
        # Accept a bare int for the common "normalize the last dim" case.
        if isinstance(parameters_shape, int):
            parameters_shape = (parameters_shape,)
        self.parameters_shape = tuple(parameters_shape)
        self.eps = eps
        # Learnable scale starts at 1 and shift at 0, i.e. the identity transform.
        self.gamma = nn.Parameter(torch.ones(self.parameters_shape))
        self.beta = nn.Parameter(torch.zeros(self.parameters_shape))

    def forward(self, inputs):
        """Normalize ``inputs`` over its trailing dimensions.

        Args:
            inputs: Tensor whose trailing dimensions broadcast against
                ``parameters_shape``.

        Returns:
            Tensor of the same shape as ``inputs``.
        """
        # Negative indices of the trailing dimensions, e.g. (-2, -1).
        dims = tuple(range(-len(self.parameters_shape), 0))
        mean = inputs.mean(dim=dims, keepdim=True)
        centered = inputs - mean
        # Biased variance (divide by N), matching the usual layer-norm definition.
        var = centered.pow(2).mean(dim=dims, keepdim=True)
        std = (var + self.eps).sqrt()
        return self.gamma * (centered / std) + self.beta

In [2]:
# Toy dimensions used to exercise the layer.
batch_size = 3
sentence_length = 5
embedding_dim = 8

# Random sample laid out as (sentence_length, batch_size, embedding_dim).
inputs = torch.randn((sentence_length, batch_size, embedding_dim))

print(f"input \n ({inputs.size()}) = \n {inputs}")

input 
 (torch.Size([5, 3, 8])) = 
 tensor([[[ 1.0184, -2.2429, -1.3283,  0.6622, -1.6736,  0.3939,  0.0571,
          -0.7898],
         [ 0.5892, -0.8378,  0.2522, -0.0825, -0.8788,  0.0326,  0.2211,
           0.7102],
         [-0.1947, -0.1960, -0.4763,  0.8812,  1.6012,  1.2519, -0.7151,
          -1.1179]],

        [[ 0.4329,  1.3957, -0.8784,  1.9706,  1.7568,  0.3309, -0.2608,
           2.1087],
         [-1.1645, -1.5364, -2.0194, -1.4861, -0.3010,  0.9819,  1.3854,
           0.6967],
         [-0.5434, -0.7746,  0.1574,  1.0256,  0.7936, -0.1484, -0.9630,
          -0.8112]],

        [[-0.0872, -0.8679,  0.7315,  0.6990, -0.1256, -0.3359,  0.2288,
           0.2636],
         [-0.3613, -0.4230, -0.5867,  0.7410,  0.0177, -0.0479,  1.7613,
           1.2320],
         [ 0.5863, -0.3138,  1.0993,  1.2000,  0.1440,  1.2297, -1.2679,
           0.9455]],

        [[ 0.6666,  0.0877,  0.1341,  0.5566, -2.3257,  2.5611, -0.7490,
           0.0487],
         [ 2.5620,  0.6886, 

In [3]:
# Normalize over the last dimension only: inputs.size()[-1:] is torch.Size([8])
# (the embedding dimension), so gamma and beta each hold 8 values.
layer_norm = LayerNormalization(inputs.size()[-1:])

In [4]:
# Apply the normalization. `out` has the same shape as `inputs` because the
# statistics are computed with keepdim=True and broadcast back.
out = layer_norm.forward(inputs)

In [5]:
# Sanity check on the first slice along dim 0 (shape [3, 8], 24 values).
# The mean is ~0. The std is ~1.02 rather than exactly 1 because Tensor.std()
# applies Bessel's correction by default, while the layer normalizes each row
# with the biased variance: sqrt(24 / 23) ≈ 1.0215.
out[0].mean(), out[0].std()

(tensor(2.7319e-08, grad_fn=<MeanBackward0>),
 tensor(1.0215, grad_fn=<StdBackward0>))