In [20]:
import torch
from torch import nn
import matplotlib.pyplot as plt


In [2]:
# Toy batch: one sequence (B=1) of two tokens (S=2) with three features each (E=3).
inputs = torch.tensor([[[0.2, 0.1, 0.3], [0.5, 0.1, 0.1]]])
B, S, E = inputs.size()
# Bring the sequence axis to the front by permuting axes. reshape() would only
# reinterpret the flat buffer, which coincides with a real axis swap only
# while B == 1 and would silently mix sequences for larger batches.
inputs = inputs.permute(1, 0, 2)
inputs.size()

torch.Size([2, 1, 3])

In [3]:
# The learnable scale (gamma) and shift (beta) cover the trailing
# (batch, embedding) axes of `inputs`; they start as the identity transform.
parameter_shape = inputs.shape[-2:]
gamma = nn.Parameter(torch.ones(*parameter_shape))
beta = nn.Parameter(torch.zeros(*parameter_shape))

In [4]:
# Both parameters should share the shape chosen in `parameter_shape`.
gamma.shape, beta.shape

(torch.Size([1, 3]), torch.Size([1, 3]))

In [5]:
# Axes to normalize over, counted from the end: -1, -2, ... one per parameter axis.
dims = list(range(-1, -len(parameter_shape) - 1, -1))

In [6]:
# Show the reduction axes: one negative index per parameter dimension.
dims

[-1, -2]

In [7]:
# Average over every axis listed in `dims`; keepdim=True leaves size-1 axes
# behind so the result broadcasts against `inputs`.
mean = torch.mean(inputs, dim=dims, keepdim=True)
mean.shape

torch.Size([2, 1, 1])

In [8]:
# One mean per sequence position, kept broadcastable against `inputs`.
mean

tensor([[[0.2000]],

        [[0.2333]]])

In [9]:
# Small constant added under the square root so a zero variance cannot cause
# a division by zero in the next step.
epsilon = 1e-5
# Population variance: mean squared deviation over the same axes as `mean`.
var = (inputs - mean).pow(2).mean(dim=dims, keepdim=True)
std = torch.sqrt(var + epsilon)
std

tensor([[[0.0817]],

        [[0.1886]]])

In [10]:
# Standardize: subtract the per-position mean, then divide by the stabilized std.
y = torch.div(inputs - mean, std)
y

tensor([[[ 0.0000, -1.2238,  1.2238]],

        [[ 1.4140, -0.7070, -0.7070]]])

In [11]:
# Learnable affine step: scale the standardized values by gamma, then shift by beta.
out = torch.add(gamma * y, beta)

In [12]:
# gamma is all ones and beta all zeros, so the affine step leaves y's values unchanged.
out

tensor([[[ 0.0000, -1.2238,  1.2238]],

        [[ 1.4140, -0.7070, -0.7070]]], grad_fn=<AddBackward0>)

## Layer Normalization Class

In [13]:
class LayerNormalization(nn.Module):
    """Layer normalization over the trailing ``len(parameters_shape)`` axes.

    Each slice spanned by those axes is shifted to zero mean and scaled to
    unit (population) variance, then passed through a learnable affine
    transform ``gamma * y + beta``.

    Subclassing ``nn.Module`` registers ``gamma`` and ``beta`` as parameters,
    so they show up in ``.parameters()`` and follow ``.to(device)``.

    Args:
        parameters_shape: Shape of the trailing axes to normalize over; also
            the shape of ``gamma`` and ``beta``.
        eps: Constant added to the variance before the square root to avoid
            division by zero.
    """

    def __init__(self, parameters_shape, eps=1e-5):
        super().__init__()
        self.parameters_shape = parameters_shape
        self.eps = eps
        # Identity affine transform at initialization: scale 1, shift 0.
        self.gamma = nn.Parameter(torch.ones(parameters_shape))
        self.beta = nn.Parameter(torch.zeros(parameters_shape))

    def forward(self, input):
        """Normalize ``input`` and print every intermediate tensor.

        Args:
            input: Tensor whose trailing axes match ``parameters_shape``.

        Returns:
            Tensor with the same shape as ``input``.
        """
        # Reduce over the last len(parameters_shape) axes: -1, -2, ...
        dims = [-(i + 1) for i in range(len(self.parameters_shape))]
        # Use the argument, not a notebook-level global, so the layer
        # normalizes whatever tensor it is actually given.
        mean = input.mean(dim=dims, keepdim=True)
        print(f"Mean \n ({mean.size()}): \n {mean}")
        var = ((input - mean) ** 2).mean(dim=dims, keepdim=True)
        std = (var + self.eps).sqrt()
        print(f"Standard Deviation \n ({std.size()}): \n {std}")
        y = (input - mean) / std
        print(f"y \n ({y.size()}) = \n {y}")
        out = self.gamma * y + self.beta
        print(f"out \n ({out.size()}) = \n {out}")
        return out

In [14]:
batch_size = 3
sentence_length = 5
embedding_dim = 8
# Seed the global RNG so the random batch, and everything derived from it,
# is identical on every Restart & Run All.
torch.manual_seed(0)
inputs = torch.randn(sentence_length, batch_size, embedding_dim)

print(f"input \n ({inputs.size()}) = \n {inputs}")

input 
 (torch.Size([5, 3, 8])) = 
 tensor([[[-1.2165, -1.2031,  1.1671, -0.6439,  0.0029,  1.7320,  1.8665,
          -0.9005],
         [-2.0346, -1.1776,  0.9851,  0.8818,  0.0410, -1.3186, -2.0149,
          -0.2498],
         [ 1.5747, -0.1300, -0.4144,  1.1730,  0.0966, -0.8541, -1.0799,
           0.1837]],

        [[ 0.7905, -1.9945,  0.3453, -0.1782,  0.9445, -0.1497,  1.1303,
           0.0070],
         [ 0.2251, -0.5068,  0.4586, -0.6954,  2.4347, -0.1075, -0.5950,
          -1.6831],
         [ 0.7575,  0.2112,  1.6075, -0.2902,  0.3255,  0.4440, -0.7570,
          -2.0835]],

        [[-0.9890,  1.5902, -1.3114, -0.4992, -0.8564, -0.4182,  0.5670,
          -0.2791],
         [-1.3975,  0.1773, -0.5528, -0.3466,  0.8521, -1.2662,  0.5178,
           0.7615],
         [-0.9814, -0.0427, -0.4114, -0.5481,  0.1414, -0.8279,  0.7115,
           0.4007]],

        [[-0.1078, -0.0194,  0.2184, -0.6093,  0.2187,  0.4831,  0.7040,
          -0.6380],
         [-0.0390, -0.6313, 

In [15]:
# Normalize over the embedding axis only: the parameter shape is the last axis of `inputs`.
layer_norm = LayerNormalization(parameters_shape=inputs.shape[-1:])

In [16]:
# Run the forward pass on the random batch; forward() also prints each intermediate tensor.
out = layer_norm.forward(inputs)

Mean 
 (torch.Size([5, 3, 1])): 
 tensor([[[ 0.1006],
         [-0.6110],
         [ 0.0687]],

        [[ 0.1119],
         [-0.0587],
         [ 0.0269]],

        [[-0.2745],
         [-0.1568],
         [-0.1947]],

        [[ 0.0312],
         [-0.3674],
         [-0.0653]],

        [[ 0.4449],
         [ 0.3310],
         [-0.0167]]])
Standard Deviation 
 (torch.Size([5, 3, 1])): 
 tensor([[[1.2205],
         [1.1264],
         [0.8621]],

        [[0.9251],
         [1.1236],
         [1.0306]],

        [[0.8790],
         [0.8209],
         [0.5594]],

        [[0.4482],
         [0.3909],
         [0.6832]],

        [[0.5851],
         [0.7331],
         [0.9507]]])
y 
 (torch.Size([5, 3, 8])) = 
 tensor([[[-1.0791, -1.0681,  0.8738, -0.6099, -0.0800,  1.3367,  1.4468,
          -0.8202],
         [-1.2638, -0.5030,  1.4169,  1.3252,  0.5787, -0.6282, -1.2464,
           0.3206],
         [ 1.7470, -0.2305, -0.5604,  1.2810,  0.0323, -1.0704, -1.3324,
           0.1334]],



In [17]:
# Sanity check on the first sequence position. out[0] pools three batch rows
# that were each normalized separately over their 8 features, so the pooled
# mean should be ~0 and the pooled std close to 1.
# NOTE(review): Tensor.std() presumably applies Bessel's correction by default,
# while the layer divides by the population std — that would explain a value
# slightly above 1; confirm against the torch docs.
out[0].mean(), out[0].std()

(tensor(-2.2352e-08, grad_fn=<MeanBackward0>),
 tensor(1.0215, grad_fn=<StdBackward0>))