In [None]:
import torch
from torch import nn

In [None]:
# Toy batch laid out as (batch, sequence, embedding) = (1, 2, 3).
inputs = torch.tensor([[[0.2, 0.1, 0.3], [0.5, 0.1, 0.1]]])
B, S, E = inputs.size()
# Switch to a sequence-first layout (S, B, E). permute swaps the axes while
# keeping every token's embedding intact; reshape would only reinterpret the
# flat buffer and mix tokens across batch entries as soon as B > 1.
inputs = inputs.permute(1, 0, 2)
inputs.size()

torch.Size([2, 1, 3])

In [None]:
# The learnable scale (gamma) and shift (beta) span the two trailing axes of
# `inputs`; gamma starts at one and beta at zero, so the affine step is
# initially the identity.
parameter_shape = inputs.shape[-2:]
gamma = nn.Parameter(torch.ones(*parameter_shape))
beta = nn.Parameter(torch.zeros(*parameter_shape))

In [None]:
# Both parameters should share the trailing-axes shape chosen above.
gamma.shape, beta.shape

(torch.Size([1, 3]), torch.Size([1, 3]))

In [None]:
# Negative indices of the axes to normalize over: one per parameter dimension,
# counted from the last axis backwards (-1, -2, ...).
dims = [-axis for axis in range(1, len(parameter_shape) + 1)]

In [None]:
# Show the reduction axes derived from parameter_shape.
dims

[-1, -2]

In [None]:
# Average over the normalized axes; keepdim leaves them as size-1 axes so the
# result broadcasts against `inputs` in the steps below.
mean = torch.mean(inputs, dim=dims, keepdim=True)
mean.shape

torch.Size([2, 1, 1])

In [None]:
# Show the per-position means computed over the normalized axes.
mean

tensor([[[0.2000]],

        [[0.2333]]])

In [None]:
# Small constant added to the variance so the square root (and the later
# division by std) stays well-defined when the variance is zero.
epsilon = 1e-5
# Biased (population) variance: mean of squared deviations over the same axes.
var = (inputs - mean).pow(2).mean(dim=dims, keepdim=True)
std = torch.sqrt(var + epsilon)
std

tensor([[[0.0817]],

        [[0.1886]]])

In [None]:
# Standardize: subtract the mean and divide by the (epsilon-stabilized) std.
y = torch.div(inputs - mean, std)
y

tensor([[[ 0.0000, -1.2238,  1.2238]],

        [[ 1.4140, -0.7070, -0.7070]]])

In [None]:
# Learnable affine transform applied to the normalized values.
out = y * gamma + beta

In [None]:
# Show the result of the affine step (gamma * y + beta).
out

tensor([[[ 0.0000, -1.2238,  1.2238]],

        [[ 1.4140, -0.7070, -0.7070]]], grad_fn=<AddBackward0>)

## Layer normalization as a reusable class

In [None]:
import torch
from torch import nn

class LayerNormalization(nn.Module):
    """Layer normalization over the trailing ``len(parameters_shape)`` axes.

    Each slice spanning the trailing axes is shifted to zero mean and scaled
    to unit (biased) variance, then passed through a learnable element-wise
    affine transform ``gamma * y + beta``.

    Subclassing ``nn.Module`` registers ``gamma`` and ``beta`` as module
    parameters, so they are visible to ``.parameters()``, optimizers,
    ``state_dict()`` and ``.to(device)``.

    Args:
        parameters_shape: Shape of the trailing axes to normalize over; also
            the shape of ``gamma`` and ``beta``.
        eps: Constant added to the variance before the square root to avoid
            division by zero.
    """

    def __init__(self, parameters_shape, eps=1e-5):
        super().__init__()
        self.parameters_shape = parameters_shape
        self.eps = eps
        # Identity affine transform at initialization: scale 1, shift 0.
        self.gamma = nn.Parameter(torch.ones(parameters_shape))
        self.beta = nn.Parameter(torch.zeros(parameters_shape))

    def forward(self, input):
        """Normalize ``input`` and print every intermediate tensor.

        Args:
            input: Tensor whose trailing axes are broadcast-compatible with
                ``parameters_shape``.

        Returns:
            Tensor of the same shape as ``input`` holding
            ``gamma * (input - mean) / sqrt(var + eps) + beta``.
        """
        # Reduce over the last len(parameters_shape) axes: -1, -2, ...
        dims = [-(i + 1) for i in range(len(self.parameters_shape))]
        # Use the argument, not a notebook-global tensor, so the layer works
        # on whatever is passed in.
        mean = input.mean(dim=dims, keepdim=True)
        print(f"Mean \n ({mean.size()}): \n {mean}")
        # Biased variance (divide by N), stabilized with eps before the sqrt.
        var = ((input - mean) ** 2).mean(dim=dims, keepdim=True)
        std = (var + self.eps).sqrt()
        print(f"Standard Deviation \n ({std.size()}): \n {std}")
        y = (input - mean) / std
        print(f"y \n ({y.size()}) = \n {y}")
        out = self.gamma * y + self.beta
        print(f"out \n ({out.size()}) = \n {out}")
        return out

In [None]:
# Fix the RNG seed so the random batch below is identical on every
# Restart & Run All.
torch.manual_seed(0)

batch_size = 3
sentence_length = 5
embedding_dim = 8
# Random batch in sequence-first layout: (sentence_length, batch_size, embedding_dim).
inputs = torch.randn(sentence_length, batch_size, embedding_dim)

print(f"input \n ({inputs.size()}) = \n {inputs}")

input 
 (torch.Size([5, 3, 8])) = 
 tensor([[[ 0.4989, -1.5183,  0.9349,  1.2598, -0.4364, -1.8070, -0.8136,
          -1.0286],
         [-0.8853, -0.4556, -0.0162,  1.5420,  0.4248,  1.4558, -1.0326,
           0.3080],
         [-0.5790,  1.7504,  0.6980,  0.0048, -2.1282, -0.7690,  1.0542,
          -1.3146]],

        [[ 1.0202,  0.5731,  0.5530,  0.0186, -0.8504,  0.9823,  0.7989,
           1.2140],
         [-0.6105, -1.0026, -1.0181, -0.1076,  2.1379, -0.3287, -0.5917,
           0.3563],
         [-0.3975, -0.4420,  1.0853,  0.2781,  0.9457,  1.0476,  0.4669,
          -0.6695]],

        [[-0.1404, -0.1451,  0.1926, -2.3987, -0.1363,  0.0039,  0.4464,
          -1.8458],
         [ 1.4095, -0.1229, -1.2031, -0.2409,  0.5317,  0.1864, -0.1956,
          -0.4548],
         [-0.5604,  0.1777,  0.8957,  1.6596, -0.1630,  1.0839,  0.0349,
          -0.2990]],

        [[ 1.4716,  0.4145,  2.3406,  0.7430,  0.2245, -0.9234, -0.5451,
          -0.1811],
         [ 1.5420, -0.6677, 

In [None]:
# Normalize over the last axis only (the embedding dimension of `inputs`).
layer_norm = LayerNormalization(parameters_shape=inputs.shape[-1:])

In [None]:
# Run the layer on the random batch; forward() prints the mean, standard
# deviation, normalized values and final output along the way.
out = layer_norm.forward(inputs)

Mean 
 (torch.Size([5, 3, 1])): 
 tensor([[[-0.3638],
         [ 0.1676],
         [-0.1604]],

        [[ 0.5387],
         [-0.1456],
         [ 0.2893]],

        [[-0.5029],
         [-0.0112],
         [ 0.3537]],

        [[ 0.4431],
         [-0.1600],
         [-0.0611]],

        [[ 0.1583],
         [-0.6578],
         [-0.2939]]])
Standard Deviation 
 (torch.Size([5, 3, 1])): 
 tensor([[[1.0684],
         [0.9087],
         [1.2063]],

        [[0.6279],
         [0.9626],
         [0.6705]],

        [[0.9639],
         [0.7132],
         [0.7240]],

        [[1.0019],
         [0.8310],
         [1.3570]],

        [[0.7389],
         [1.2349],
         [0.9045]]])
y 
 (torch.Size([5, 3, 8])) = 
 tensor([[[ 0.8075, -1.0806,  1.2155,  1.5196, -0.0680, -1.3508, -0.4210,
          -0.6222],
         [-1.1587, -0.6859, -0.2022,  1.5125,  0.2831,  1.4177, -1.3208,
           0.1545],
         [-0.3470,  1.5841,  0.7116,  0.1370, -1.6313, -0.5046,  1.0069,
          -0.9568]],



In [None]:
# Sanity check on the first sequence position: out[0] pools all batch entries
# and embedding features at that position, so its mean should be ~0.
# NOTE(review): the std shows slightly above 1 presumably because Tensor.std()
# applies Bessel's correction by default while the layer uses the biased
# variance — confirm against the torch.std documentation.
out[0].mean(), out[0].std()

(tensor(9.9341e-09, grad_fn=<MeanBackward0>),
 tensor(1.0215, grad_fn=<StdBackward0>))