In [1]:
import torch
from torch import nn

In [4]:
"""
inputs has shape [1, 2, 3] → (Batch=1, Seq=2, Embedding=3)
You reshape it to [2, 1, 3] because in transformer convention:
S → sequence length
B → batch size
E → embedding size
So now each token (S=2) has an embedding vector of 3 numbers.
"""

inputs = torch.Tensor([[[0.2, 0.1, 0.3], [0.5, 0.1, 0.1]]])
B, S, E = inputs.size()
inputs = inputs.reshape(S, B, E)
inputs.size()

torch.Size([2, 1, 3])

In [7]:
"""
You create two learnable parameters:
gamma (scale) → starts with ones
beta (shift) → starts with zeros
Both have the same shape as the embedding: [1, 3]
These are the learnable weights used to rescale and shift normalized outputs.
"""

parameter_shape = inputs.size()[-2:]
gamma = nn.Parameter(torch.ones(parameter_shape))
beta = nn.Parameter(torch.zeros(parameter_shape))

In [6]:
gamma.size(), beta.size()

(torch.Size([1, 3]), torch.Size([1, 3]))

In [10]:
"""
This computes the dimensions along which you’ll normalize.
Since you want to normalize over the embedding dimensions, you take the last 2 axes.
"""

dims = [-(i + 1) for i in range(len(parameter_shape))]
dims

[-1, -2]

In [13]:
"""
Compute the mean for each token embedding.
For every token (S=2), you average its 3 embedding values.
Example: [0.2, 0.1, 0.3] → mean = 0.2.
"""

mean = inputs.mean(dim=dims, keepdim=True)
mean.size()

torch.Size([2, 1, 1])

In [14]:
"""
Compute variance → average of squared differences from the mean.
Then take square root → standard deviation (std).
Add small epsilon to avoid division by zero.
"""

epsilon = 1e-5
var = ((inputs - mean) ** 2).mean(dim=dims, keepdim=True)
std = (var + epsilon).sqrt()

In [18]:
"""
This normalizes the inputs:
𝑦 = 𝑥 − mean / std
Now each embedding vector has:
Mean ≈ 0
Std ≈ 1
"""

y = (inputs - mean) / std

In [19]:
"""
Multiply by gamma (scale)
Add beta (shift)
These make normalization learnable, so model can later adjust the normalized data.
"""

out = gamma * y + beta

In [20]:
import torch
from torch import nn

class LayerNormalization():
    def __init__(self, parameters_shape, eps=1e-5):
        self.parameters_shape=parameters_shape
        self.eps=eps
        self.gamma = nn.Parameter(torch.ones(parameters_shape))
        self.beta =  nn.Parameter(torch.zeros(parameters_shape))

    def forward(self, input):
        dims = [-(i + 1) for i in range(len(self.parameters_shape))]
        mean = inputs.mean(dim=dims, keepdim=True)
        print(f"Mean \n ({mean.size()}): \n {mean}")
        var = ((inputs - mean) ** 2).mean(dim=dims, keepdim=True)
        std = (var + self.eps).sqrt()
        print(f"Standard Deviation \n ({std.size()}): \n {std}")
        y = (inputs - mean) / std
        print(f"y \n ({y.size()}) = \n {y}")
        out = self.gamma * y  + self.beta
        print(f"out \n ({out.size()}) = \n {out}")
        return out

In [21]:
batch_size = 3
sentence_length = 5
embedding_dim = 8
inputs = torch.randn(sentence_length, batch_size, embedding_dim)

print(f"input \n ({inputs.size()}) = \n {inputs}")

input 
 (torch.Size([5, 3, 8])) = 
 tensor([[[ 1.5531e+00, -1.3203e+00, -8.9852e-01, -3.8968e-01,  1.1432e+00,
          -1.4133e-01, -2.1831e+00, -8.0777e-01],
         [ 1.0404e-01,  3.7712e-01, -5.0668e-01, -1.5099e+00, -6.1259e-01,
          -3.5395e-01,  2.1064e-01, -3.5194e-01],
         [-1.0980e-01,  1.6165e+00,  1.1620e+00, -1.8649e+00, -2.9045e-01,
          -1.5426e+00,  1.3990e+00,  1.2064e+00]],

        [[-4.8764e-01,  1.1285e+00,  1.6624e-02,  9.1445e-01, -8.3786e-01,
           1.0709e+00,  7.7118e-01, -2.6322e+00],
         [-1.4271e+00,  2.1251e-01,  8.6352e-01, -5.5403e-01, -1.0171e+00,
           7.4149e-01, -6.4174e-01, -1.2790e+00],
         [ 1.3430e+00,  2.0475e+00,  2.3823e-01,  2.7290e-01,  8.6858e-01,
           2.1184e-02, -1.1507e+00,  1.7641e-01]],

        [[-9.3440e-01, -2.8756e-01, -3.2899e-01,  1.3814e+00,  1.0341e-03,
          -7.5865e-01,  6.4666e-02,  8.6306e-01],
         [ 4.6873e-01,  1.9462e+00, -4.6005e-01, -2.1770e-01, -9.4085e-01,
          

In [23]:
layer_norm = LayerNormalization(inputs.size()[-1:])

out = layer_norm.forward(inputs)

Mean 
 (torch.Size([5, 3, 1])): 
 tensor([[[-3.8054e-01],
         [-3.3040e-01],
         [ 1.9702e-01]],

        [[-6.9945e-03],
         [-3.8768e-01],
         [ 4.7713e-01]],

        [[ 6.4559e-05],
         [ 2.0969e-02],
         [-4.3356e-02]],

        [[-3.4355e-01],
         [-8.6600e-02],
         [-2.2516e-01]],

        [[ 1.6503e-01],
         [ 4.4278e-01],
         [-4.4705e-01]]])
Standard Deviation 
 (torch.Size([5, 3, 1])): 
 tensor([[[1.1566],
         [0.5573],
         [1.2757]],

        [[1.2103],
         [0.8340],
         [0.8955]],

        [[0.7329],
         [1.3986],
         [0.7239]],

        [[0.7276],
         [1.0137],
         [0.6666]],

        [[1.0722],
         [0.8589],
         [1.1389]]])
y 
 (torch.Size([5, 3, 8])) = 
 tensor([[[ 1.6718e+00, -8.1251e-01, -4.4786e-01, -7.8967e-03,  1.3175e+00,
           2.0683e-01, -1.5585e+00, -3.6939e-01],
         [ 7.7950e-01,  1.2695e+00, -3.1628e-01, -2.1162e+00, -5.0632e-01,
          -4.2255e-02

In [24]:
out[0].mean(), out[0].std()

(tensor(6.2088e-09, grad_fn=<MeanBackward0>),
 tensor(1.0215, grad_fn=<StdBackward0>))