In [1]:
import torch
import torch.nn as nn

In [2]:
# Draw a 3x6 tensor of standard-normal samples to use as example activations.
outputs = torch.randn((3, 6))
outputs

tensor([[-0.1869,  1.3040, -0.0892,  1.4863,  1.0535,  0.1018],
        [ 0.6618, -0.5140,  0.3077,  0.4659,  2.7684, -1.6090],
        [-0.6748,  1.8086,  1.4294,  1.0300, -0.1191, -0.4336]])

In [3]:
# Per-row mean over the last dimension; keepdim=True keeps a trailing axis of size 1
# so the result broadcasts against `outputs`.
mean = torch.mean(outputs, dim=-1, keepdim=True)
mean

tensor([[0.6116],
        [0.3468],
        [0.5067]])

In [4]:
# Per-row standard deviation over the last dimension, using torch.std's default settings.
sd = torch.std(outputs, dim=-1, keepdim=True)
sd

tensor([[0.7521],
        [1.4522],
        [1.0480]])

In [5]:
# Standardize each row: subtract its mean, then divide by its standard deviation.
normalized_outputs = (outputs - mean).div(sd)
normalized_outputs

tensor([[-1.0617,  0.9206, -0.9318,  1.1630,  0.5876, -0.6778],
        [ 0.2169, -0.5927, -0.0269,  0.0820,  1.6675, -1.3467],
        [-1.1274,  1.2422,  0.8804,  0.4993, -0.5972, -0.8972]])

In [6]:
# Sanity check: every row's mean should now be zero up to floating-point error.
torch.mean(normalized_outputs, dim=-1, keepdim=True)

tensor([[1.9868e-08],
        [1.8626e-08],
        [3.4769e-08]])

In [7]:
# Sanity check: every row's standard deviation should now be one.
torch.std(normalized_outputs, dim=-1, keepdim=True)

tensor([[1.0000],
        [1.0000],
        [1.0000]])

In [8]:
# Turn off scientific notation in tensor printouts for the cells that follow.
torch.set_printoptions( sci_mode=False )

In [9]:
class LayerNorm(nn.Module):
    """Layer normalization over the last dimension with a learnable affine transform.

    Each vector along the last axis of the input is shifted to zero mean and
    scaled to unit variance (using the biased, divide-by-N variance estimate),
    then multiplied element-wise by ``scale`` and offset by ``shift``.

    Args:
        emb_dim: Size of the last dimension of the inputs; determines the
            length of the ``scale`` and ``shift`` parameter vectors.
        eps: Small constant added to the variance before taking the square
            root so the division stays finite when the variance is zero.
            Defaults to ``1e-5``.

    Attributes:
        eps: The numerical-stability constant described above.
        scale: Learnable per-feature multiplier, initialised to ones.
        shift: Learnable per-feature offset, initialised to zeros.
    """

    def __init__(self, emb_dim, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.scale = nn.Parameter(torch.ones(emb_dim))
        self.shift = nn.Parameter(torch.zeros(emb_dim))

    def forward(self, x):
        """Normalize ``x`` along its last dimension and apply scale and shift.

        Args:
            x: Input tensor whose last dimension is broadcast-compatible with
                the ``scale`` and ``shift`` vectors.

        Returns:
            A tensor holding the normalized, scaled and shifted values.
        """
        # keepdim=True keeps a trailing axis of size 1 so the statistics
        # broadcast against x.
        mean = x.mean(dim=-1, keepdim=True)
        # unbiased=False selects the divide-by-N variance (no Bessel correction).
        var = x.var(dim=-1, keepdim=True, unbiased=False)
        # eps guards against division by zero for constant rows.
        norm_x = (x - mean) / torch.sqrt(var + self.eps)
        return self.scale * norm_x + self.shift
