<a href="https://colab.research.google.com/github/fatemeh-ict/NLp/blob/main/layer_normalization.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
from torch import nn

In [2]:
# Toy input: a single group holding two rows of three values each.
inputs = torch.Tensor(
    [
        [
            [0.2, 0.1, 0.3],
            [0.5, 0.1, 0.1],
        ]
    ]
)
# Keep the three axis sizes around; the names suggest batch, sequence and
# embedding sizes (presumably -- the later cells use them in that role).
b, s, e = inputs.shape
print(b, s, e)

1 2 3


In [3]:
# Swap the first two axes: (b, s, e) -> (s, b, e).
# transpose() moves whole rows along with their axis, whereas reshape(s, b, e)
# merely re-chunks the flat data: it gives the same values only while b == 1
# and would interleave rows of different groups for any larger b.
# NOTE: transposing twice restores the original layout, so run this cell once
# per freshly created `inputs`.
inputs = inputs.transpose(0, 1)
print(inputs.size())

torch.Size([2, 1, 3])


In [4]:
# The learnable scale (gamma, all ones) and shift (beta, all zeros) take the
# shape of the last two axes of `inputs`.
parameter_shape = inputs.shape[-2:]
gamma = nn.Parameter(torch.ones(parameter_shape))
beta = nn.Parameter(torch.zeros(parameter_shape))
print(gamma.size(), beta.size(), sep="\n")

torch.Size([1, 3])
torch.Size([1, 3])


In [5]:
# Negative indices of the trailing axes, one per entry of parameter_shape:
# -1, -2, ... counted from the last axis backwards.
dims = list(range(-1, -len(parameter_shape) - 1, -1))
print(dims)

[-1, -2]


In [6]:
# Average over the axes listed in `dims`; keepdim=True keeps them as size-1
# axes so the result broadcasts against `inputs` in the next steps.
means = torch.mean(inputs, dim=dims, keepdim=True)
print(means, means.size(), sep="\n")

tensor([[[0.2000]],

        [[0.2333]]])
torch.Size([2, 1, 1])


In [8]:
# Variance as the mean squared deviation from `means` over the same axes.
var = torch.mean((inputs - means).pow(2), dim=dims, keepdim=True)
# epsilon is added to the variance before taking the square root.
epsilon = 1e-5
std = torch.sqrt(var + epsilon)
print(std)

tensor([[[0.0817]],

        [[0.1886]]])


In [10]:
# Centre the values on their mean and divide by the standard deviation.
y = torch.div(inputs - means, std)
print(y)

tensor([[[ 0.0000, -1.2238,  1.2238]],

        [[ 1.4140, -0.7070, -0.7070]]])


In [11]:
# Apply the learnable scale and shift to the normalised values.
out = torch.add(gamma * y, beta)
print(out)

tensor([[[ 0.0000, -1.2238,  1.2238]],

        [[ 1.4140, -0.7070, -0.7070]]], grad_fn=<AddBackward0>)


In [12]:
import torch
from torch import nn

class LayerNormalization(nn.Module):
    """Layer normalization over the trailing axes of a tensor.

    The tensor passed to ``forward`` is normalised to zero mean and unit
    variance over its last ``len(parameters_shape)`` axes, then scaled by
    ``gamma`` and shifted by ``beta``.

    Subclassing ``nn.Module`` registers ``gamma`` and ``beta`` as module
    parameters, so ``.parameters()``, ``.to(device)`` and optimizers see them.

    Args:
        parameters_shape: Shape of the trailing axes to normalise over
            (a sequence of ints such as ``tensor.size()[-1:]``, or a single
            int for one trailing axis).
        eps: Constant added to the variance before the square root.
    """

    def __init__(self, parameters_shape, eps=1e-5):
        super().__init__()
        # Accept a bare int for the common single-axis case; forward() needs
        # len(parameters_shape) to work.
        if isinstance(parameters_shape, int):
            parameters_shape = (parameters_shape,)
        self.parameters_shape = parameters_shape
        self.eps = eps
        self.gamma = nn.Parameter(torch.ones(parameters_shape))
        self.beta = nn.Parameter(torch.zeros(parameters_shape))

    def forward(self, input):
        """Normalise ``input`` and return the scaled, shifted result.

        Prints the intermediate means, standard deviations, normalised values
        and final output as it goes.

        Args:
            input: Tensor whose trailing axes broadcast against
                ``parameters_shape``.

        Returns:
            Tensor of the same shape as ``input``.
        """
        # Negative indices of the trailing axes: -1, -2, ...
        dims = [-(i + 1) for i in range(len(self.parameters_shape))]
        # Use the argument, not a notebook-level global, so the layer works on
        # whatever tensor the caller passes in.
        means = input.mean(dim=dims, keepdim=True)
        print('means:', means)
        var = ((input - means) ** 2).mean(dim=dims, keepdim=True)
        std = (var + self.eps).sqrt()
        print('std:', std)
        y = (input - means) / std
        print('y:', y)
        out = self.gamma * y + self.beta
        print('out:', out)
        return out

In [14]:
batch_size = 3
sentence_length = 5
embedding_dim = 8
# Draw the sample from a seeded, local generator so that Restart & Run All
# reproduces the same tensor without touching the global RNG state.
inputs = torch.randn(
    sentence_length,
    batch_size,
    embedding_dim,
    generator=torch.Generator().manual_seed(0),
)

print('inputs: ',inputs)
print(inputs.size())

inputs:  tensor([[[ 0.0874,  1.1189,  0.2708, -1.0413, -0.3140, -0.0485,  0.0634,
           0.0640],
         [-2.8542,  1.3534,  1.8117, -1.7698,  1.0013, -1.1117, -0.7677,
           1.2499],
         [-0.4510,  1.1840, -0.6403,  0.2035, -0.9923, -0.3421,  0.8990,
          -1.3165]],

        [[ 0.3243,  0.0428, -0.8750, -0.3396,  0.4660,  0.4735, -1.4888,
          -1.6862],
         [-0.7080,  1.6126, -0.6237,  0.4729, -0.4081,  0.9999, -2.7812,
          -0.5007],
         [ 2.0901, -0.9178,  1.3806,  0.5554, -1.8190,  0.1140, -0.5563,
           1.6962]],

        [[ 0.2977,  0.3517, -0.9572,  1.4232,  0.4644,  1.1520,  0.5046,
           1.8658],
         [-1.1295,  1.5063,  1.1001,  0.0392, -0.9724, -0.9391, -0.8830,
          -0.2965],
         [ 0.7376, -1.0246,  0.3616, -1.1200, -0.0225,  1.4643, -1.4382,
          -2.1789]],

        [[ 0.3788,  0.3097,  0.9465, -0.5448,  0.5773,  0.2146,  0.7069,
          -0.8370],
         [ 0.0252,  0.2673,  0.3239,  0.3168, -0.7451, 

In [15]:
# Build the layer for the last axis of `inputs` only: the shape handed over is
# the one-element trailing slice of its size.
layer_norm = LayerNormalization(parameters_shape=inputs.shape[-1:])

In [16]:
# Run the layer on `inputs`; forward() prints its intermediate tensors
# (means, std, y, out) and returns the final result.
out = layer_norm.forward(inputs)

means: tensor([[[ 0.0251],
         [-0.1359],
         [-0.1820]],

        [[-0.3854],
         [-0.2420],
         [ 0.3179]],

        [[ 0.6378],
         [-0.1968],
         [-0.4026]],

        [[ 0.2190],
         [ 0.0170],
         [ 0.2964]],

        [[ 0.2423],
         [ 0.4175],
         [ 0.4860]]])
std: tensor([[[0.5614],
         [1.6060],
         [0.8242]],

        [[0.8140],
         [1.2449],
         [1.2809]],

        [[0.8046],
         [0.9448],
         [1.1534]],

        [[0.5731],
         [0.4971],
         [0.9659]],

        [[0.9619],
         [0.6967],
         [1.2140]]])
y: tensor([[[ 0.1110,  1.9484,  0.4377, -1.8995, -0.6041, -0.1311,  0.0681,
           0.0693],
         [-1.6925,  0.9273,  1.2127, -1.0173,  0.7081, -0.6076, -0.3934,
           0.8629],
         [-0.3264,  1.6572, -0.5561,  0.4677, -0.9831, -0.1943,  1.3115,
          -1.3765]],

        [[ 0.8718,  0.5260, -0.6015,  0.0562,  1.0459,  1.0551, -1.3555,
          -1.5980],
      