In [5]:
import torch
import torch.nn as nn

In [2]:
# A batch of 4 samples; every sample is a 5x3 matrix of whole-number floats in [0, 100)
X = torch.randint(low=0, high=100, size=(4, 5, 3)).float()

# The trailing dimensions that are normalized together: the full 5x3 matrix
normalized_shape = (5, 3)

# How many trailing dimensions normalized_shape spans
D = len(normalized_shape)

# Default parameters: a small epsilon, a shift (beta) of all zeros
# and a scale (gamma) of all ones, both shaped like normalized_shape
eps = 1e-5
beta = torch.zeros(normalized_shape)
gamma = torch.ones(normalized_shape)

X

tensor([[[51., 89., 14.],
         [ 5., 18., 69.],
         [45.,  9., 75.],
         [26., 21., 29.],
         [74., 44., 74.]],

        [[51., 58., 36.],
         [99., 57., 71.],
         [53., 92., 25.],
         [88., 16.,  0.],
         [61., 73., 14.]],

        [[57., 51., 14.],
         [83., 20., 70.],
         [75., 95., 58.],
         [46., 75., 62.],
         [97., 93., 71.]],

        [[77., 98., 79.],
         [33., 27., 91.],
         [79., 56., 51.],
         [20., 20., 32.],
         [84.,  3., 83.]]])

In [4]:
# gamma holds one scale factor per element of normalized_shape, so the two shapes match
gamma.shape

torch.Size([5, 3])

In [3]:
# Manual layer normalization.
# Every sample along the leading (batch) dimension is normalized with its own
# statistics, computed over all elements of that sample, then scaled by gamma
# and shifted by beta.
# The sample count is read from X instead of being hard-coded, so the cell keeps
# working when X has a different batch size.
for i in range(X.shape[0]):                 # loop through each sample
    mean = X[i].mean()                      # scalar mean over the whole sample
    var = X[i].var(unbiased=False)          # biased variance (divides by N, not N-1)
    # eps is added under the square root so a zero variance cannot divide by zero
    layer_norm = (X[i] - mean) / torch.sqrt(var + eps) * gamma + beta

    print(f"μ = {mean:.4f}")
    print(f"σ^2 = {var:.4f}")               # literal "^2"; no replacement field needed
    print(layer_norm)
    print("=" * 50)

μ = 42.8667
σ^2 = 722.7822
tensor([[ 0.3025,  1.7160, -1.0737],
        [-1.4085, -0.9249,  0.9721],
        [ 0.0794, -1.2597,  1.1952],
        [-0.6274, -0.8134, -0.5158],
        [ 1.1580,  0.0422,  1.1580]])
μ = 52.9333
σ^2 = 831.1289
tensor([[-0.0671,  0.1757, -0.5874],
        [ 1.5979,  0.1411,  0.6267],
        [ 0.0023,  1.3551, -0.9689],
        [ 1.2164, -1.2811, -1.8361],
        [ 0.2798,  0.6961, -1.3505]])
μ = 64.4667
σ^2 = 572.9155
tensor([[-0.3119, -0.5626, -2.1084],
        [ 0.7743, -1.8578,  0.2312],
        [ 0.4401,  1.2756, -0.2702],
        [-0.7715,  0.4401, -0.1031],
        [ 1.3592,  1.1921,  0.2730]])
μ = 55.5333
σ^2 = 891.3156
tensor([[ 0.7190,  1.4224,  0.7860],
        [-0.7548, -0.9557,  1.1880],
        [ 0.7860,  0.0156, -0.1518],
        [-1.1902, -1.1902, -0.7883],
        [ 0.9535, -1.7596,  0.9200]])


In [6]:
# Same normalization through the built-in module, i.e. nn.LayerNorm((5, 3))
layer_normalization = nn.LayerNorm(normalized_shape=normalized_shape)
layer_normalization(X)

tensor([[[ 0.3025,  1.7160, -1.0737],
         [-1.4085, -0.9249,  0.9721],
         [ 0.0794, -1.2597,  1.1952],
         [-0.6274, -0.8134, -0.5158],
         [ 1.1580,  0.0422,  1.1580]],

        [[-0.0671,  0.1757, -0.5874],
         [ 1.5979,  0.1411,  0.6267],
         [ 0.0023,  1.3551, -0.9689],
         [ 1.2164, -1.2811, -1.8361],
         [ 0.2798,  0.6961, -1.3505]],

        [[-0.3119, -0.5626, -2.1084],
         [ 0.7743, -1.8578,  0.2312],
         [ 0.4401,  1.2756, -0.2702],
         [-0.7715,  0.4401, -0.1031],
         [ 1.3592,  1.1921,  0.2730]],

        [[ 0.7190,  1.4224,  0.7860],
         [-0.7548, -0.9557,  1.1880],
         [ 0.7860,  0.0156, -0.1518],
         [-1.1902, -1.1902, -0.7883],
         [ 0.9535, -1.7596,  0.9200]]], grad_fn=<NativeLayerNormBackward0>)

In [9]:
# Input tensor: 2 sequences of 3 tokens, each token a 5-dimensional embedding
X = torch.randint(0, 100, (2, 3, 5)).float()

# Shape to be normalized: only the last (embedding) dimension, of size 5
normalized_shape = (5,)

# Number of dimensions in the shape to be normalized
D = len(normalized_shape)  # 1

# Create the LayerNorm module
layer_normalization = nn.LayerNorm(normalized_shape)

# View the learnable parameters: 'weight' is gamma (scale), 'bias' is beta (shift)
layer_normalization.state_dict()

OrderedDict([('weight', tensor([1., 1., 1., 1., 1.])),
             ('bias', tensor([0., 0., 0., 0., 0.]))])

In [10]:
# Display the current input tensor to inspect its values
X

tensor([[[89., 28., 76., 75.,  3.],
         [94.,  3., 64., 42., 73.],
         [52., 24., 42., 86., 70.]],

        [[ 6., 30., 22., 54., 15.],
         [60., 14., 42., 45.,  8.],
         [55., 75., 55., 55., 62.]]])

In [11]:
# Mean over axis 2; keepdim=True keeps the reduced axis with size 1,
# so the result has the same number of dimensions as X
X.mean(dim=2, keepdim=True)

tensor([[[54.2000],
         [55.2000],
         [54.8000]],

        [[25.4000],
         [33.8000],
         [60.4000]]])