### Layer Normalization



In [1]:
import torch
from torch import nn

In [2]:
# Toy batch built batch-first: one sentence (B=1) of two tokens (S=2),
# each token a 3-dim embedding.
inputs = torch.Tensor([[[0.2, 0.1, 0.3], [0.5, 0.1, 0.1]]])
B, S, E = inputs.size()
# Switch to sequence-first layout (S, B, E) by swapping axes 0 and 1.
# `reshape(S, B, E)` would merely reinterpret the flat memory and, whenever
# B > 1, scatter one sentence's tokens across different batch entries.
inputs = inputs.transpose(0, 1)
inputs.size()

torch.Size([2, 1, 3])

In [3]:
# Learnable scale (gamma, initialised to ones) and shift (beta, initialised to
# zeros), one entry per element of the trailing dims being normalized.
# NOTE(review): with the (S, B, E) layout built above, size()[-2:] is (B, E),
# so these parameters -- and the mean/std computed over `dims` -- span the
# batch axis as well as the embedding axis. That coincides with standard
# layer norm here only because B == 1; the class-based example later uses
# size()[-1:] instead. Confirm which behaviour is intended.
parameter_shape = inputs.size()[-2:]
gamma = nn.Parameter(torch.ones(parameter_shape))
beta = nn.Parameter(torch.zeros(parameter_shape))
print(gamma.size(), beta.size())

torch.Size([1, 3]) torch.Size([1, 3])


In [4]:
# Negative axis indices of the normalized dims, innermost first: [-1, -2, ...].
dims = list(range(-1, -len(parameter_shape) - 1, -1))
dims

[-1, -2]

In [5]:
# Mean over the normalized axes; keepdim=True leaves size-1 axes in place so
# the result broadcasts back against `inputs`.
mean = torch.mean(inputs, dim=dims, keepdim=True)
mean.shape

torch.Size([2, 1, 1])

In [6]:
# Small constant added to the variance so the square root never hits zero.
epsilon = 1e-5
# Population variance (mean of squared deviations) over the same axes as `mean`.
var = (inputs - mean).pow(2).mean(dim=dims, keepdim=True)
std = torch.sqrt(var + epsilon)
std

tensor([[[0.0817]],

        [[0.1886]]])

In [7]:
# Standardize: centre on the mean, then scale by the (epsilon-stabilised) std.
y = torch.div(inputs - mean, std)
y

tensor([[[ 0.0000, -1.2238,  1.2238]],

        [[ 1.4140, -0.7070, -0.7070]]])

In [8]:
# Apply the learnable affine transform. Because gamma starts as ones and beta
# as zeros, `out` equals `y` numerically at this point; it differs only in
# carrying a grad_fn, since gamma and beta are nn.Parameters.
out = gamma * y + beta
out

tensor([[[ 0.0000, -1.2238,  1.2238]],

        [[ 1.4140, -0.7070, -0.7070]]], grad_fn=<AddBackward0>)

In [9]:
import torch
from torch import nn

class LayerNormalization():
    """Layer normalization over the trailing ``len(parameters_shape)`` dims.

    Each slice of the input spanning those trailing dims is standardized to
    zero mean and unit (population) variance. It is then scaled by ``gamma``
    and shifted by ``beta``, both learnable and shaped like
    ``parameters_shape``.

    Args:
        parameters_shape: Shape of the trailing dims to normalize over, e.g.
            ``inputs.size()[-1:]``.
        eps: Constant added to the variance before the square root, for
            numerical stability.

    NOTE(review): this class does not subclass ``nn.Module``. ``gamma`` and
    ``beta`` are therefore not registered with any module's ``parameters()``
    and must be handed to an optimizer explicitly if they are to be trained.
    """

    def __init__(self, parameters_shape, eps=1e-5):
        self.parameters_shape = parameters_shape
        self.eps = eps
        self.gamma = nn.Parameter(torch.ones(parameters_shape))
        self.beta = nn.Parameter(torch.zeros(parameters_shape))

    def forward(self, input):
        """Normalize ``input`` and return ``gamma * y + beta``.

        Prints the mean, standard deviation, normalized tensor and output as
        it goes (teaching aid).

        Args:
            input: Tensor whose trailing dims match ``parameters_shape``.

        Returns:
            Tensor with the same shape as ``input``.
        """
        # Work on the argument itself. Previously this read the notebook-global
        # `inputs`, silently ignoring what the caller passed in. The parameter
        # keeps its name (despite shadowing the builtin) for keyword callers.
        x = input
        # Negative indices of the trailing dims to reduce over: [-1, -2, ...].
        dims = [-(i + 1) for i in range(len(self.parameters_shape))]
        mean = x.mean(dim=dims, keepdim=True)
        print(f"Mean \n ({mean.size()}): \n {mean}")
        var = ((x - mean) ** 2).mean(dim=dims, keepdim=True)
        std = (var + self.eps).sqrt()
        print(f"Standard Deviation \n ({std.size()}): \n {std}")
        y = (x - mean) / std
        print(f"y \n ({y.size()}) = \n {y}")
        out = self.gamma * y + self.beta
        print(f"out \n ({out.size()}) = \n {out}")
        return out

In [10]:
sentence_length = 5
batch_size = 3
embedding_dim = 8
# Random activations laid out as (sentence_length, batch_size, embedding_dim).
inputs = torch.randn((sentence_length, batch_size, embedding_dim))

print(f"input \n ({inputs.size()}) = \n {inputs}")

input 
 (torch.Size([5, 3, 8])) = 
 tensor([[[ 0.4238, -0.8494, -0.2418, -0.1314,  1.6148, -1.1439, -0.4703,
          -0.0906],
         [ 0.5904, -0.9541,  0.7285,  0.9892,  0.9337,  1.2070, -0.8387,
           0.8825],
         [-0.7253, -0.2724, -0.8787, -0.0711, -1.0140, -1.2031, -1.6363,
           0.9726]],

        [[-0.3037, -0.1229,  0.1466, -0.1652,  0.7835, -0.7846, -0.2860,
          -1.7230],
         [-0.4247,  0.9897, -1.3756, -1.1338, -0.8413,  0.9969,  0.5889,
           1.6285],
         [ 0.7884, -0.5712,  0.1942,  0.3323,  0.3311, -1.0986,  0.7959,
          -0.6507]],

        [[-0.5853, -0.2893, -0.1234,  0.1845,  1.0803, -2.1561,  0.8917,
          -0.5371],
         [ 0.5074, -1.8032, -0.0088, -0.6137, -1.4536, -0.5079, -0.0224,
           0.3738],
         [ 0.3341,  1.1888, -0.0369,  1.8104,  0.7141,  0.2760, -0.4513,
          -0.1409]],

        [[-0.6605, -2.6758,  0.2186, -2.3132,  0.6068,  0.8918, -1.1461,
           0.6550],
         [ 1.1131, -0.9203, 

In [11]:
# Normalize over the last dim only (embedding_dim), so gamma and beta have
# shape torch.Size([embedding_dim]).
layer_norm = LayerNormalization(inputs.size()[-1:])

In [12]:
# Run the normalization; forward prints each intermediate (mean, std, y, out).
out = layer_norm.forward(inputs)

Mean 
 (torch.Size([5, 3, 1])): 
 tensor([[[-0.1111],
         [ 0.4423],
         [-0.6035]],

        [[-0.3069],
         [ 0.0536],
         [ 0.0152]],

        [[-0.1918],
         [-0.4410],
         [ 0.4618]],

        [[-0.5529],
         [-0.1447],
         [-0.0875]],

        [[-0.0570],
         [-0.2846],
         [ 0.3390]]])
Standard Deviation 
 (torch.Size([5, 3, 1])): 
 tensor([[[0.7924],
         [0.7917],
         [0.7552]],

        [[0.6785],
         [1.0616],
         [0.6584]],

        [[0.9408],
         [0.7779],
         [0.6989]],

        [[1.2969],
         [0.8618],
         [0.6656]],

        [[0.9550],
         [0.7716],
         [0.9658]]])
y 
 (torch.Size([5, 3, 8])) = 
 tensor([[[ 0.6751, -0.9316, -0.1650, -0.0257,  2.1780, -1.3033, -0.4533,
           0.0259],
         [ 0.1870, -1.7639,  0.3614,  0.6908,  0.6207,  0.9659, -1.6181,
           0.5560],
         [-0.1612,  0.4385, -0.3644,  0.7051, -0.5435, -0.7939, -1.3675,
           2.0871]],

