In [7]:
import torch
from torch import nn

In [8]:
x = torch.rand(32, 128)

print(f"Mean of Tensor: {torch.mean(x[0], dim=-1, keepdim=True)}")
print(f"Variance of Tensor: {torch.var(x[0], dim=-1, keepdim=True)}")

Mean of Tensor: tensor([0.4423])
Variance of Tensor: tensor([0.0820])


In [9]:
class LayerNorm(nn.Module):
  def __init__(self, emb_size):
    super().__init__()

    self.scale = nn.Parameter(torch.ones(emb_size))
    self.shift = nn.Parameter(torch.zeros(emb_size))

  def forward(self, x, epsilon = 1e-7):
    x_mean = torch.mean(x, dim=-1, keepdim=True)
    x_var = torch.var(x, dim=-1, keepdim=True, unbiased=False)
    x_normalized = (x - x_mean) / torch.sqrt(x_var + epsilon)

    return (x_normalized + self.shift)*self.scale

In [10]:
layernorm  = LayerNorm(128)
x_normalized = layernorm(x)

print(f"Mean of Normalized Tensor: {torch.mean(x_normalized[0], dim=-1, keepdim=True)}")
print(f"Variance of Normalized Tensor: {torch.var(x_normalized[0], dim=-1, keepdim=True)}")

Mean of Normalized Tensor: tensor([-1.1176e-08], grad_fn=<MeanBackward1>)
Variance of Normalized Tensor: tensor([1.0079], grad_fn=<VarBackward0>)


In [17]:
class BatchNorm(nn.Module):
  def __init__(self, emb_size):
    super().__init__()

    self.scale = nn.Parameter(torch.ones(emb_size))
    self.shift = nn.Parameter(torch.zeros(emb_size))

  def forward(self, x, epsilon = 1e-7):
    x_mean = torch.mean(x, dim=0, keepdim=True)
    x_var = torch.var(x, dim=0, keepdim=True, unbiased=False)
    x_normalized = (x - x_mean) / torch.sqrt(x_var + epsilon)

    return (x_normalized + self.shift)*self.scale

In [22]:
batchnorm  = BatchNorm(128)
x_normalized = batchnorm(x)

print(f"Mean of Normalized Tensor: {torch.mean(x_normalized, dim=0, keepdim=True)[0][0]}")
print(f"Variance of Normalized Tensor: {torch.var(x_normalized, dim=0, keepdim=True)[0][0]}")

Mean of Normalized Tensor: 8.568167686462402e-08
Variance of Normalized Tensor: 1.032257080078125
