In [1]:
import torch
from torch import nn

In [2]:
inputs = torch.Tensor([[[0.2, 0.1, 0.3], [0.5, 0.1, 0.1]]])

B, S, E = inputs.size()
inputs = inputs.reshape(S, B, E)
inputs.size()

torch.Size([2, 1, 3])

In [4]:
parameter_shape = inputs.size()[-2:]
gamma = nn.Parameter(torch.ones(parameter_shape))
beta = nn.Parameter(torch.zeros(parameter_shape))

In [5]:
gamma.size(), beta.size()

(torch.Size([1, 3]), torch.Size([1, 3]))

In [6]:
dims = [-(i + 1) for i in range(len(parameter_shape))]

In [7]:
dims

[-1, -2]

In [8]:
mean = inputs.mean(dim=dims, keepdim=True)
mean.size()

torch.Size([2, 1, 1])

In [9]:
mean

tensor([[[0.2000]],

        [[0.2333]]])

In [10]:
var = ((inputs - mean) ** 2).mean(dim=dims, keepdim=True)
epsilon = 1e-5
std = (var + epsilon).sqrt()
std

tensor([[[0.0817]],

        [[0.1886]]])

In [11]:
y = (inputs - mean) / std
y

tensor([[[ 0.0000, -1.2238,  1.2238]],

        [[ 1.4140, -0.7070, -0.7070]]])

In [12]:
out = gamma * y + beta
out

tensor([[[ 0.0000, -1.2238,  1.2238]],

        [[ 1.4140, -0.7070, -0.7070]]], grad_fn=<AddBackward0>)

In [17]:
class LayerNormalization():
    def __init__(self, parameters_shape, eps=1e-5):
        self.parameters_shape = parameters_shape
        self.eps = eps
        self.gamma = nn.Parameter(torch.ones(parameters_shape))
        self.beta = nn.Parameter(torch.zeros(parameters_shape))

    def forward(self, input):
        dims = [-(i+1) for i in range(len(self.parameters_shape))]
        mean = inputs.mean(dim=dims, keepdim=True)
        print(f"Mean \n ({mean.size()}): \n {mean}")
        var = ((inputs - mean) ** 2).mean(dim=dims, keepdim=True)
        std = (var + self.eps).sqrt()
        print(f"Standard Deviation \n ({std.size()}): \n {std}")
        y = (inputs - mean) / std
        print(f"y \n ({y.size()}) = \n {y}")
        out = self.gamma * y + self.beta
        print(f"out \n ({out.size()}) = \n {out}")
        return out

In [15]:
batch_size = 3
sentence_length = 5
embedding_dim = 8

inputs = torch.randn(sentence_length, batch_size, embedding_dim)

print(f"Input \n ({inputs.size()}) = \n {inputs}")

Input 
 (torch.Size([5, 3, 8])) = 
 tensor([[[ 0.6663, -0.6031, -0.1954, -0.4783, -0.1913, -2.1234,  1.4013,
          -1.8660],
         [-2.6459, -0.9321, -1.2232,  0.3879, -0.8330, -0.1635, -0.8674,
           0.6379],
         [-0.0296, -1.5482, -1.4638,  1.8920,  0.8234,  1.9997,  0.9654,
          -0.6988]],

        [[ 0.3853,  1.2258, -1.1657, -0.7474, -0.4420,  0.3952,  0.2892,
           1.7438],
         [ 0.4357,  0.2910, -0.0870, -0.8109, -0.4364, -0.5306, -1.1887,
           0.5579],
         [-0.1653,  1.5606, -0.2312, -0.3046, -0.9664, -0.9021, -0.1178,
           0.5464]],

        [[-0.5126, -0.6931, -1.3968, -0.2932,  0.4205, -1.1616, -0.7999,
           0.6077],
         [ 0.9497,  2.3231, -0.1199, -0.6458, -0.0829,  1.1362,  2.1374,
           0.1906],
         [ 0.1895,  0.4245,  0.7102, -0.8329, -1.7412, -0.5330, -0.4730,
           0.3305]],

        [[ 0.3882,  0.6321,  0.2389, -2.3705,  0.3904, -0.0395, -0.6019,
          -0.1565],
         [ 1.9020,  0.7406, 

In [18]:
layer_norm = LayerNormalization(inputs.size()[-1:])

out = layer_norm.forward(inputs)

Mean 
 (torch.Size([5, 3, 1])): 
 tensor([[[-0.4237],
         [-0.7049],
         [ 0.2425]],

        [[ 0.2105],
         [-0.2211],
         [-0.0726]],

        [[-0.4786],
         [ 0.7360],
         [-0.2407]],

        [[-0.1898],
         [ 0.3555],
         [ 0.0687]],

        [[ 0.3476],
         [ 0.4420],
         [ 0.0394]]])
Standard Deviation 
 (torch.Size([5, 3, 1])): 
 tensor([[[1.0971],
         [0.9631],
         [1.3109]],

        [[0.9145],
         [0.5857],
         [0.7605]],

        [[0.6598],
         [1.0185],
         [0.7583]],

        [[0.9000],
         [1.1142],
         [0.7581]],

        [[1.1249],
         [0.7564],
         [0.8698]]])
y 
 (torch.Size([5, 3, 8])) = 
 tensor([[[ 0.9936, -0.1635,  0.2081, -0.0497,  0.2119, -1.5492,  1.6635,
          -1.3146],
         [-2.0153, -0.2358, -0.5381,  1.1346, -0.1330,  0.5622, -0.1688,
           1.3942],
         [-0.2076, -1.3660, -1.3017,  1.2583,  0.4431,  1.3405,  0.5515,
          -0.7181]],



In [19]:
out[0].mean(), out[0].std()

(tensor(1.4901e-08, grad_fn=<MeanBackward0>),
 tensor(1.0215, grad_fn=<StdBackward0>))