In [1]:
import torch
from torch import nn

In [3]:
# Toy batch laid out as (B, S, E): 1 sentence, 2 tokens, 3 embedding features.
inputs = torch.tensor([[[0.2, 0.1, 0.3], [0.5, 0.1, 0.1]]])
B, S, E = inputs.size()
# Swap the batch and sequence axes -> (S, B, E). transpose() moves whole axes;
# reshape() would only reinterpret the flat buffer, which scrambles tokens
# across batch elements as soon as B > 1 (it only "works" here because B == 1).
inputs = inputs.transpose(0, 1)
inputs.size()

torch.Size([2, 1, 3])

In [4]:
# Normalize over the last two axes of `inputs`. Here inputs is (S, B, E), so
# these are the batch and embedding axes.
# NOTE(review): including the batch axis shares statistics across batch
# elements; with B == 1 this equals per-token layer norm. The class further
# down is instantiated with inputs.size()[-1:] (embedding axis only).
parameter_shape = inputs.shape[-2:]
gamma = nn.Parameter(torch.ones(parameter_shape))   # learnable scale, initialised to 1
beta = nn.Parameter(torch.zeros(parameter_shape))   # learnable shift, initialised to 0

In [5]:
(gamma.shape, beta.shape)  # both match parameter_shape

(torch.Size([1, 3]), torch.Size([1, 3]))

In [6]:
dims = list(range(-1, -len(parameter_shape) - 1, -1))  # trailing axes to reduce over: -1, -2, ...

In [8]:
dims  # reduction axes, innermost (last) axis first

[-1, -2]

In [10]:
# One mean per sequence position; keepdim=True keeps the reduced axes (size 1)
# so `mean` broadcasts back against `inputs`.
mean = inputs.mean(dim=dims, keepdim=True)
mean.shape

torch.Size([2, 1, 1])

In [11]:
mean  # per-position means over the `dims` axes

tensor([[[0.2000]],

        [[0.2333]]])

In [12]:
# Population (biased) variance over the same axes, then a standard deviation
# stabilised by epsilon so the later division never hits zero.
epsilon = 1e-5
var = (inputs - mean).pow(2).mean(dim=dims, keepdim=True)
std = torch.sqrt(var + epsilon)
std

tensor([[[0.0817]],

        [[0.1886]]])

In [14]:
y = (inputs - mean).div(std)  # zero-mean, unit-variance values per position
y

tensor([[[ 0.0000, -1.2238,  1.2238]],

        [[ 1.4140, -0.7070, -0.7070]]])

In [15]:
out = beta + y * gamma  # learnable affine transform of the normalized values

In [16]:
out  # equals y here, because gamma starts at ones and beta at zeros

tensor([[[ 0.0000, -1.2238,  1.2238]],

        [[ 1.4140, -0.7070, -0.7070]]], grad_fn=<AddBackward0>)

# Final coded class

The step-by-step computation above, packaged as a reusable `LayerNormalization` class.

In [21]:
import torch
from torch import nn


class LayerNormalization(nn.Module):
    """Layer normalization written out step by step, printing each intermediate.

    Normalizes ``input`` over its trailing ``len(parameters_shape)`` dimensions
    and applies a learnable element-wise scale (``gamma``) and shift (``beta``)::

        out = gamma * (input - mean) / sqrt(var + eps) + beta

    where ``mean`` and ``var`` (population / biased variance) are taken over
    those trailing dimensions.

    Args:
        parameters_shape: shape of the trailing dimensions to normalize over,
            e.g. ``inputs.size()[-1:]`` for the last (embedding) axis. A plain
            ``int`` is treated as a one-element shape.
        eps: constant added to the variance for numerical stability.
        verbose: if True (the default), ``forward`` prints the mean, standard
            deviation, normalized values and output.
    """

    def __init__(self, parameters_shape, eps=1e-5, verbose=True):
        super().__init__()
        if isinstance(parameters_shape, int):
            parameters_shape = (parameters_shape,)
        self.parameters_shape = parameters_shape
        self.eps = eps
        self.verbose = verbose
        # Assigned as nn.Parameters on an nn.Module, so they are registered and
        # appear in .parameters() (e.g. for an optimizer).
        self.gamma = nn.Parameter(torch.ones(parameters_shape))
        self.beta = nn.Parameter(torch.zeros(parameters_shape))

    def forward(self, input):
        """Normalize ``input`` over its trailing dims, then apply gamma and beta.

        Args:
            input: tensor whose trailing dimensions match ``parameters_shape``.

        Returns:
            Tensor with the same shape as ``input``.
        """
        # Reduce over the last len(parameters_shape) axes: [-1, -2, ...].
        dims = [-(i + 1) for i in range(len(self.parameters_shape))]
        # Use the argument, not a global: the original read the notebook-level
        # `inputs`, silently normalizing the wrong tensor for any other input.
        mean = input.mean(dim=dims, keepdim=True)
        if self.verbose:
            print(f"Mean\n({mean.size()}): \n{mean}")
        var = ((input - mean) ** 2).mean(dim=dims, keepdim=True)
        std = (var + self.eps).sqrt()
        if self.verbose:
            print(f"Standard Deviation \n({std.size()}): \n {std}")
        y = (input - mean) / std
        if self.verbose:
            print(f"y \n ({y.size()}) = \n {y}")
        out = self.gamma * y + self.beta
        if self.verbose:
            print(f"out \n ({out.size()}) = \n {out}")
        return out

In [22]:
# Seed the RNG so the random inputs (and every printout derived from them)
# are reproducible on Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

batch_size = 3
sentence_len = 5
embedding_dim = 8
# Sequence-first layout: (sentence_len, batch_size, embedding_dim).
inputs = torch.randn(sentence_len, batch_size, embedding_dim)

print(f"input \n ({inputs.size()}) = \n {inputs}")

input 
 (torch.Size([5, 3, 8])) = 
 tensor([[[ 0.4483, -0.8446,  0.4163,  0.4035,  0.7656, -0.4665, -0.0477,
           1.6926],
         [ 1.0695, -1.0304, -1.4541, -0.1808,  1.3723, -0.4448,  0.3324,
          -0.1096],
         [-1.4080, -0.7032,  0.8333, -0.5489, -1.0976,  1.3816, -0.5604,
           1.1932]],

        [[ 0.1203,  0.3701, -0.0965, -0.3001, -0.9296,  1.9759,  1.4141,
          -0.1827],
         [-0.6961,  0.0852, -0.4433,  1.8251, -0.0195, -0.6754, -1.4085,
           0.2618],
         [-0.5855, -0.9787,  0.9353, -0.7325,  0.1625, -0.8249, -1.0062,
           0.8623]],

        [[ 1.3271,  0.3044, -1.2154, -0.7595, -0.3714,  1.0765,  0.8129,
          -0.5579],
         [-1.7047,  0.1477,  0.3521, -1.1187, -0.1087, -1.6481, -0.6429,
           1.0010],
         [ 0.1528, -0.1740,  1.5574, -1.3735,  0.0333,  0.7055, -0.8190,
           1.9306]],

        [[-0.8772, -0.0215,  0.0136,  0.4636,  1.4950,  0.1716, -1.0654,
           2.3146],
         [-1.6381, -2.1887, 

In [25]:
layer_norm = LayerNormalization(parameters_shape=inputs.shape[-1:])  # normalize over the last (embedding) axis only

In [26]:
out = layer_norm.forward(input=inputs)  # prints mean, std, y and out as it runs

Mean
(torch.Size([5, 3, 1])): 
tensor([[[ 0.2959],
         [-0.0557],
         [-0.1138]],

        [[ 0.2965],
         [-0.1339],
         [-0.2709]],

        [[ 0.0771],
         [-0.4653],
         [ 0.2516]],

        [[ 0.3118],
         [-0.7740],
         [ 0.2517]],

        [[-0.5515],
         [-0.4047],
         [-0.2845]]])
Standard Deviation 
(torch.Size([5, 3, 1])): 
 tensor([[[0.7263],
         [0.9036],
         [1.0137]],

        [[0.8909],
         [0.8934],
         [0.7573]],

        [[0.8753],
         [0.9162],
         [1.0462]],

        [[1.0586],
         [0.9005],
         [0.8297]],

        [[1.1091],
         [1.0116],
         [1.0187]]])
y 
 (torch.Size([5, 3, 8])) = 
 tensor([[[ 0.2097, -1.5702,  0.1657,  0.1481,  0.6466, -1.0497, -0.4731,
           1.9228],
         [ 1.2452, -1.0787, -1.5477, -0.1385,  1.5804, -0.4306,  0.4295,
          -0.0597],
         [-1.2768, -0.5815,  0.9342, -0.4293, -0.9706,  1.4751, -0.4406,
           1.2893]],

    