In [9]:
import torch
import torch.nn as nn

In [2]:
# Toy 3-d embeddings for the six tokens of "Your journey starts with one step".
# Each (token, vector) pair contributes one row x_i of a (6, 3) tensor in torch's
# default float dtype; the token strings are labels only and are discarded by the
# comprehension (its loop variables do not leak into the notebook namespace).
input_embeddings = torch.tensor([
    vector
    for _token, vector in (
        ("Your",    [0.43, 0.15, 0.89]),  # x_0
        ("journey", [0.55, 0.87, 0.66]),  # x_1
        ("starts",  [0.57, 0.85, 0.64]),  # x_2
        ("with",    [0.22, 0.58, 0.33]),  # x_3
        ("one",     [0.77, 0.25, 0.10]),  # x_4
        ("step",    [0.05, 0.80, 0.55]),  # x_5
    )
])

In [4]:
# Raw embeddings carry arbitrary per-token statistics: every row has its own mean
# and variance. Activations with drifting scale like this, fed through deep stacks
# of layers, are one source of vanishing/exploding gradients, which is what layer
# normalization is meant to tame.
# Statistics are taken over the last (embedding) axis; keepdim=True keeps them as
# (6, 1) columns so they broadcast against the (6, 3) embeddings.
# Note: Tensor.var defaults to the unbiased estimator (divides by n - 1), unlike
# the LayerNorm class further down, which passes unbiased=False.
mean, var = (
    input_embeddings.mean(dim=-1, keepdim=True),
    input_embeddings.var(dim=-1, keepdim=True),
)
display(mean, var)

tensor([[0.4900],
        [0.6933],
        [0.6867],
        [0.3767],
        [0.3733],
        [0.4667]])

tensor([[0.1396],
        [0.0264],
        [0.0212],
        [0.0340],
        [0.1236],
        [0.1458]])

In [5]:
# Standardize each token vector: subtract its mean and divide by its standard
# deviation (from the unbiased `var` above), so every row ends up with mean 0 and
# unbiased variance 1. No epsilon is added here, so a zero-variance row would
# divide by zero; the LayerNorm class below guards against that with `eps`.
normalized_embeddings = input_embeddings.sub(mean).div(var.sqrt())

In [7]:
# Sanity check of the standardization above: every row should now have mean ~0 and
# variance ~1 (unbiased, computed the same way as `var` in the statistics cell).
# The tiny non-zero means are float32 rounding error.
# Distinct names are used on purpose: reusing `mean`/`var` would overwrite the values
# the normalization cell reads, and re-running that cell afterwards would silently
# "normalize" with these ~0/~1 statistics, returning roughly the raw embeddings.
norm_mean = normalized_embeddings.mean(dim=-1, keepdim=True)
norm_var = normalized_embeddings.var(dim=-1, keepdim=True)
display(norm_mean, norm_var)
assert torch.allclose(norm_mean, torch.zeros_like(norm_mean), atol=1e-5)
assert torch.allclose(norm_var, torch.ones_like(norm_var), atol=1e-5)

tensor([[ 0.0000e+00],
        [-2.5332e-07],
        [-9.9341e-09],
        [ 9.9341e-09],
        [-3.9736e-08],
        [-6.4572e-08]])

tensor([[1.0000],
        [1.0000],
        [1.0000],
        [1.0000],
        [1.0000],
        [1.0000]])

In [10]:
class LayerNorm(nn.Module):
    """Layer normalization over the last dimension of the input.

    Each vector along the last axis is standardized to zero mean and unit
    (biased) variance, then passed through a learnable per-feature affine map
    ``scale * x_hat + shift``. With the default ``eps`` this is intended to
    match ``torch.nn.LayerNorm(emb_dim)`` with its elementwise affine enabled.

    Args:
        emb_dim: Size of the input's last dimension; sets the length of the
            ``scale`` and ``shift`` parameters.
        eps: Small constant added to the variance inside the square root so a
            constant (zero-variance) vector does not cause division by zero.
            Defaults to 1e-5.
    """

    def __init__(self, emb_dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        # Initialized to the identity affine map: scale = 1, shift = 0.
        self.scale = nn.Parameter(torch.ones(emb_dim))
        self.shift = nn.Parameter(torch.zeros(emb_dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize ``x`` over its last dimension and apply ``scale``/``shift``.

        Args:
            x: Tensor whose last dimension has size ``emb_dim``; any leading
                dimensions are kept and broadcast over.

        Returns:
            Tensor with the same shape as ``x``.
        """
        mean = x.mean(dim=-1, keepdim=True)
        # Biased (population) variance, as in standard layer normalization.
        var = x.var(dim=-1, keepdim=True, unbiased=False)
        norm_x = (x - mean) / torch.sqrt(var + self.eps)
        return self.scale * norm_x + self.shift

In [12]:
# Apply the custom LayerNorm and inspect per-row statistics of its output. At
# initialization scale = 1 and shift = 0, so the output is just the eps-stabilized
# standardization of each token vector.
ln = LayerNorm(emb_dim=input_embeddings.shape[-1])
out_ln = ln(input_embeddings)
# Biased variance (unbiased=False) to match what LayerNorm itself normalizes by.
# It lands slightly below 1 because forward() divides by sqrt(var + eps), so the
# output variance is var / (var + eps); this is most visible for low-variance rows.
# Distinct names keep this cell from overwriting the `mean`/`var` that the earlier
# normalization cell reads.
ln_mean = out_ln.mean(dim=-1, keepdim=True)
ln_var = out_ln.var(dim=-1, unbiased=False, keepdim=True)
display(ln_mean, ln_var)
assert torch.allclose(ln_mean, torch.zeros_like(ln_mean), atol=1e-5)
# Loose tolerance: eps pulls the variance slightly below 1 (see above).
assert torch.allclose(ln_var, torch.ones_like(ln_var), atol=1e-2)

tensor([[-3.9736e-08],
        [-3.1789e-07],
        [ 0.0000e+00],
        [ 9.9341e-09],
        [-1.9868e-08],
        [-3.9736e-08]], grad_fn=<MeanBackward1>)

tensor([[0.9999],
        [0.9994],
        [0.9993],
        [0.9996],
        [0.9999],
        [0.9999]], grad_fn=<VarBackward0>)