In [29]:
import torch
from torch import nn


In [30]:

# Toy batch: 1 sentence of 2 words, each a 3-dim embedding -> shape (B, W, E).
inputs = torch.tensor([[[0.2, 0.1, 0.3], [0.5, 0.1, 0.1]]])
B, W, E = inputs.size()  # batch size, no. of words, embedding size
# Move to the (W, B, E) layout. reshape() only reinterprets the memory and would
# interleave words from different sentences whenever B > 1; transpose() swaps axes.
inputs = inputs.transpose(0, 1)
inputs.size()

torch.Size([2, 1, 3])

In [31]:
# Learnable affine parameters sized like the two trailing dims of `inputs`:
# gamma (scale) starts at all-ones, beta (shift) at all-zeros.
parameter_shape = inputs.shape[-2:]
gamma, beta = nn.Parameter(torch.ones(parameter_shape)), nn.Parameter(torch.zeros(parameter_shape))

In [32]:
# Trailing axes to normalize over, one per entry of parameter_shape: [-1, -2]
# here, i.e. embedding and batch for the (W, B, E) layout.
# NOTE(review): standard layer norm (e.g. nn.LayerNorm) normalizes over the
# embedding axis only; including the batch axis makes no difference here only
# because the toy batch has B == 1.
dims = list(range(-1, -len(parameter_shape) - 1, -1))

In [33]:
# Mean over the normalized axes; keepdim=True keeps the result broadcastable
# against `inputs` for the subtraction that follows.
mean = torch.mean(inputs, dim=dims, keepdim=True)
mean.shape

torch.Size([2, 1, 1])

In [34]:
epsilon = 1e-5  # keeps the denominator non-zero when a slice is constant
var = (inputs - mean).pow(2).mean(dim=dims, keepdim=True)  # biased variance, as nn.LayerNorm uses
std = torch.sqrt(var + epsilon)
y = (inputs - mean) / std  # ~zero mean, ~unit variance along `dims`
y

tensor([[[ 0.0000, -1.2238,  1.2238]],

        [[ 1.4140, -0.7070, -0.7070]]])

In [35]:
# Learnable affine transform. At initialization (gamma = 1, beta = 0) this equals
# y; it only gains a grad_fn because gamma and beta are nn.Parameters.
output = beta + y * gamma
output

tensor([[[ 0.0000, -1.2238,  1.2238]],

        [[ 1.4140, -0.7070, -0.7070]]], grad_fn=<AddBackward0>)

In [36]:
import torch
from torch import nn

class Normalization(nn.Module):
    """Layer normalization over the trailing ``parameters_shape`` dimensions.

    For an input ``x`` this computes ``y = (x - mean) / sqrt(var + eps)`` over the
    last ``len(parameters_shape)`` axes (biased variance), then applies the
    learnable elementwise affine transform ``gamma * y + beta``.

    Args:
        parameters_shape: Shape of the trailing dimensions to normalize over,
            e.g. ``inputs.size()[-1:]`` for the embedding axis only. A plain
            ``int`` is treated as a 1-D shape, as ``nn.LayerNorm`` does.
        eps: Constant added to the variance for numerical stability.
        verbose: If True (the default, matching the original behavior), print
            the normalized tensor and the output on every forward pass.
    """

    def __init__(self, parameters_shape, eps=1e-5, verbose=True):
        # nn.Module must be initialized before nn.Parameter attributes are set,
        # so gamma/beta get registered and appear in .parameters().
        super().__init__()
        if isinstance(parameters_shape, int):
            parameters_shape = (parameters_shape,)
        self.parameters_shape = parameters_shape
        self.eps = eps
        self.verbose = verbose
        self.gamma = nn.Parameter(torch.ones(parameters_shape))
        self.beta = nn.Parameter(torch.zeros(parameters_shape))

    def forward(self, input):
        """Normalize ``input`` over its trailing dims and apply gamma/beta.

        Args:
            input: Tensor whose trailing dimensions match ``parameters_shape``.

        Returns:
            Tensor with the same shape as ``input``.
        """
        # Axes to reduce over: [-1, -2, ...], one per normalized dimension.
        dims = [-(i + 1) for i in range(len(self.parameters_shape))]
        # Operate on the argument itself (previously a notebook-global was read).
        mean = input.mean(dim=dims, keepdim=True)
        var = ((input - mean) ** 2).mean(dim=dims, keepdim=True)
        std = (var + self.eps).sqrt()
        y = (input - mean) / std
        if self.verbose:
            print(f"y \n ({y.size()}) = \n {y}")
        output = self.gamma * y + self.beta
        if self.verbose:
            print(f"out \n ({output.size()}) = \n {output}")
        return output

In [37]:

# Seed the RNG so torch.randn below yields the same tensor on every Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

batch_size = 3
sentence_length = 4
embedding_dim = 8
# Shape: (sentence_length, batch_size, embedding_dim).
inputs = torch.randn(sentence_length, batch_size, embedding_dim)

print(f"input \n ({inputs.size()}) = \n {inputs}")

input 
 (torch.Size([4, 3, 8])) = 
 tensor([[[ 0.8732,  0.8488,  0.1720, -0.1793,  0.7333, -0.9600, -1.0085,
           1.3496],
         [-0.1007, -0.6165,  0.8324, -2.2993,  0.6018,  0.4077, -0.2467,
          -0.5008],
         [ 0.0466,  0.2815, -0.2698, -0.6876, -0.4209, -0.1700, -0.8521,
           0.9155]],

        [[ 0.0042,  1.0113,  0.1568,  0.4589, -0.7188, -1.2711,  0.5005,
           1.5250],
         [-0.0371,  0.3279,  2.3972,  0.5790,  0.5150,  0.6903,  0.8016,
          -1.5472],
         [-1.3012,  1.4258, -0.7413, -1.5327, -0.2613, -0.6357, -1.8437,
           1.6475]],

        [[ 0.3957, -0.3621,  0.1139,  0.4707,  0.5018, -1.0613,  0.3563,
          -0.1848],
         [ 0.1086,  0.7377, -0.0471,  0.2584,  0.6483, -2.6424,  0.0730,
           0.7366],
         [ 0.5479,  0.9551, -1.5742, -0.1763,  1.3702, -0.4154,  0.0106,
          -0.8223]],

        [[ 0.8055,  1.7578,  0.7961, -0.8395,  0.5807,  1.2985, -1.5671,
          -1.0510],
         [ 0.0862, -0.8279, 

In [38]:

# Normalize over the last axis only (the embedding dim), i.e. standard layer norm.
normlzn = Normalization(inputs.shape[-1:])

In [39]:
# Run the normalization layer on the random batch (prints y and the output).
output = normlzn.forward(input=inputs)

y 
 (torch.Size([4, 3, 8])) = 
 tensor([[[ 0.7828,  0.7531, -0.0688, -0.4954,  0.6129, -1.4435, -1.5024,
           1.3614],
         [ 0.1520, -0.4097,  1.1679, -2.2419,  0.9169,  0.7055, -0.0070,
          -0.2837],
         [ 0.3622,  0.8070, -0.2372, -1.0285, -0.5232, -0.0481, -1.3401,
           2.0079]],

        [[-0.2436,  0.9580, -0.0615,  0.2989, -1.1063, -1.7652,  0.3486,
           1.5710],
         [-0.4958, -0.1360,  1.9038,  0.1116,  0.0485,  0.2213,  0.3310,
          -1.9844],
         [-0.7347,  1.5016, -0.2755, -0.9245,  0.1181, -0.1889, -1.1795,
           1.6834]],

        [[ 0.7240, -0.7712,  0.1680,  0.8719,  0.9334, -2.1508,  0.6461,
          -0.4214],
         [ 0.1203,  0.7279, -0.0301,  0.2649,  0.6416, -2.5372,  0.0858,
           0.7268],
         [ 0.6249,  1.0785, -1.7391, -0.1819,  1.5409, -0.4482,  0.0264,
          -0.9015]],

        [[ 0.5146,  1.3553,  0.5062, -0.9377,  0.3161,  0.9499, -1.5800,
          -1.1244],
         [ 0.3489, -0.9556, -0.1