In [2]:
# The mean and standard-deviation are calculated over the last D dimensions, 
# where D is the dimension of normalized_shape. For example, if normalized_shape is (3, 5) 
# (a 2-dimensional shape), the mean and standard-deviation are computed over the last 2 dimensions of the input 
# (i.e. input.mean((-2, -1)))
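# Note the main difference from batchnorm: batchnorm normalizes each feature across the batch dimension (dim 0),
# whereas layernorm normalizes each individual input across its own feature dimension(s).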
# Refer https://pytorch.org/docs/stable/generated/torch.nn.LayerNorm.html
import torch
class LayerNorm1d:
  """Layer normalization over the last (feature) dimension of the input.

  Every vector along the last axis is shifted to zero mean and scaled to unit
  variance, then passed through the elementwise affine map gamma * xhat + beta.
  Statistics are computed per input, never across the batch, so there is no
  train/eval distinction and no running buffers.

  Note: the variance uses Tensor.var's default (unbiased, Bessel-corrected)
  estimator, whereas torch.nn.LayerNorm documents the biased estimator, so the
  two differ slightly when `dim` is small.
  """

  def __init__(self, dim, eps=1e-5, momentum=0.1):
    """Create the layer.

    Args:
      dim: size of the last dimension of the inputs (length of gamma and beta).
      eps: constant added to the variance before the square root, for
        numerical stability.
      momentum: unused by this layer; accepted only so that existing callers
        passing it keep working.
    """
    self.eps = eps
    self.gamma = torch.ones(dim)   # scale, initialised to the identity
    self.beta = torch.zeros(dim)   # shift, initialised to zero

  def __call__(self, x):
    """Normalize `x` over its last dimension and apply the affine transform.

    Args:
      x: tensor whose last dimension matches `dim`; any number of leading
        dimensions is accepted (e.g. (B, C) or (B, T, C)).

    Returns:
      Tensor of the same shape as `x`. It is also stored on `self.out`.
    """
    # Reduce over the last axis rather than a hard-coded axis 1: for 2-D input
    # the two coincide, but for (B, T, C) input axis 1 would be T, not features.
    xmean = x.mean(-1, keepdim=True) # per-input mean over features
    xvar = x.var(-1, keepdim=True) # per-input variance over features (unbiased)
    xhat = (x - xmean) / torch.sqrt(xvar + self.eps) # normalize to unit variance
    self.out = self.gamma * xhat + self.beta
    return self.out

  def parameters(self):
    """Return the layer's parameters as the list [gamma, beta]."""
    return [self.gamma, self.beta]

# Reproducible demo: push a batch of 32 random 100-dimensional vectors through the layer.
torch.manual_seed(1337)
module = LayerNorm1d(100)
x = module(torch.randn(32, 100)) # normalized activations, same shape as the input batch
x.shape

torch.Size([32, 100])

In [5]:
# Statistics of feature 0 across all inputs in the batch. Layer norm does not
# normalize along this axis, so these are not expected to be 0 mean / unit std.
feature0 = x[:, 0]
feature0.mean(), feature0.std()

(tensor(0.1469), tensor(0.8803))

In [6]:
# Statistics over the features of the first input in the batch. This is the axis
# layer norm normalizes, so the mean should be ~0 and the std ~1.
sample0 = x[0, :]
sample0.mean(), sample0.std()

(tensor(-9.5367e-09), tensor(1.0000))