In [1]:
import torch
from torch import nn

In [5]:
inputs = torch.Tensor([[[0.2, 0.1, 0.3], [0.5, 0.1, 0.1]]])
B, S, E = inputs.size()
print(inputs.size())
inputs = inputs.reshape(S, B, E)
inputs.size()

torch.Size([1, 2, 3])


torch.Size([2, 1, 3])

In [8]:
parameter_shape = inputs.size()[-2:]
gamma = nn.Parameter(torch.ones(parameter_shape))
beta =  nn.Parameter(torch.zeros(parameter_shape))

In [9]:
gamma.size(), beta.size()

(torch.Size([1, 3]), torch.Size([1, 3]))

In [10]:
dims = [-(i + 1) for i in range(len(parameter_shape))]

In [11]:
dims

[-1, -2]

In [12]:
mean = inputs.mean(dim=dims, keepdim=True)
mean.size()

torch.Size([2, 1, 1])

In [13]:
mean

tensor([[[0.2000]],

        [[0.2333]]])

In [14]:
var = ((inputs - mean) ** 2).mean(dim=dims, keepdim=True)
epsilon = 1e-5
std = (var + epsilon).sqrt()
std

tensor([[[0.0817]],

        [[0.1886]]])

In [15]:
y = (inputs - mean) / std
y

tensor([[[ 0.0000, -1.2238,  1.2238]],

        [[ 1.4140, -0.7070, -0.7070]]])

In [16]:
out = gamma * y + beta

In [17]:
out

tensor([[[ 0.0000, -1.2238,  1.2238]],

        [[ 1.4140, -0.7070, -0.7070]]], grad_fn=<AddBackward0>)

## Class

In [18]:
import torch
from torch import nn

class LayerNormalization():
    def __init__(self, parameters_shape, eps=1e-5):
        self.parameters_shape=parameters_shape
        self.eps=eps
        self.gamma = nn.Parameter(torch.ones(parameters_shape))
        self.beta =  nn.Parameter(torch.zeros(parameters_shape))

    def forward(self, input):
        dims = [-(i + 1) for i in range(len(self.parameters_shape))]
        mean = inputs.mean(dim=dims, keepdim=True)
        print(f"Mean \n ({mean.size()}): \n {mean}")
        var = ((inputs - mean) ** 2).mean(dim=dims, keepdim=True)
        std = (var + self.eps).sqrt()
        print(f"Standard Deviation \n ({std.size()}): \n {std}")
        y = (inputs - mean) / std
        print(f"y \n ({y.size()}) = \n {y}")
        out = self.gamma * y  + self.beta
        print(f"out \n ({out.size()}) = \n {out}")
        return out

In [19]:
batch_size = 3
sentence_length = 5
embedding_dim = 8 
inputs = torch.randn(sentence_length, batch_size, embedding_dim)

print(f"input \n ({inputs.size()}) = \n {inputs}")

input 
 (torch.Size([5, 3, 8])) = 
 tensor([[[-0.4023,  1.5926, -0.1526, -0.3596,  0.5150, -1.1845,  1.1122,
           0.4230],
         [-0.3525, -0.6960, -0.1742, -0.4869,  0.1198, -0.0162, -0.7351,
           0.1737],
         [-0.1398,  2.0746, -1.2026, -0.6559, -1.9888,  0.6903,  0.4831,
           2.1596]],

        [[ 0.2647,  0.8568,  0.6843,  0.8339, -1.4749,  1.3304,  0.0936,
          -0.8010],
         [-0.3198, -0.2293, -1.0938, -0.5439, -0.6727, -1.2386,  2.2249,
          -0.0895],
         [ 0.0045,  1.3078, -1.8217, -1.3068, -0.3689,  0.6626,  0.3573,
           0.1493]],

        [[ 0.5129,  1.2459, -1.3325, -1.2644, -0.7750, -2.0900, -0.2529,
           1.7253],
         [ 1.9686, -0.5160,  0.7849, -0.0605, -0.1311, -2.0453, -0.1969,
          -0.5052],
         [-0.0689, -0.1974, -1.7485,  0.2388, -0.5273, -0.2787,  0.0732,
          -1.2348]],

        [[-0.3295, -2.3304,  0.8456,  0.4713, -0.6211, -0.9053,  1.0010,
           0.0509],
         [-0.4747, -0.9568, 

In [20]:
layer_norm = LayerNormalization(inputs.size()[-1:])

In [21]:
out = layer_norm.forward(inputs)

Mean 
 (torch.Size([5, 3, 1])): 
 tensor([[[ 0.1930],
         [-0.2709],
         [ 0.1775]],

        [[ 0.2235],
         [-0.2454],
         [-0.1270]],

        [[-0.2788],
         [-0.0877],
         [-0.4680]],

        [[-0.2272],
         [ 0.1399],
         [ 0.2817]],

        [[ 0.1381],
         [ 0.1228],
         [ 0.0478]]])
Standard Deviation 
 (torch.Size([5, 3, 1])): 
 tensor([[[0.8394],
         [0.3306],
         [1.3819]],

        [[0.8778],
         [1.0069],
         [0.9574]],

        [[1.2539],
         [1.0720],
         [0.6415]],

        [[1.0170],
         [0.6299],
         [0.5658]],

        [[1.0579],
         [0.6418],
         [0.5734]]])
y 
 (torch.Size([5, 3, 8])) = 
 tensor([[[-0.7092,  1.6674, -0.4117, -0.6582,  0.3836, -1.6410,  1.0951,
           0.2740],
         [-0.2468, -1.2861,  0.2925, -0.6534,  1.1822,  0.7706, -1.4041,
           1.3450],
         [-0.2296,  1.3728, -0.9988, -0.6031, -1.5677,  0.3710,  0.2211,
           1.4343]],



In [22]:
out[0].mean(), out[0].std()

(tensor(1.9868e-08, grad_fn=<MeanBackward0>),
 tensor(1.0215, grad_fn=<StdBackward0>))