In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

### LayerNorm

In [None]:

class MyLayerNorm(nn.Module):
    """Layer normalization over the trailing ``features`` dimensions.

    Computes ``a_2 * (x - mean) / sqrt(var + eps) + b_2`` where ``mean`` and
    ``var`` (biased, i.e. divided by the element count) are taken over the
    last ``len(features)`` dimensions of the input. This is the same formula
    ``torch.nn.LayerNorm`` uses.

    Args:
        features: Size of the normalized dimensions. Either a single int
            (normalize over the last dimension) or a sequence of ints /
            ``torch.Size`` (normalize over that many trailing dimensions).
        eps: Value added to the variance inside the square root for
            numerical stability.
        elementwise_affine: If true, ``a_2`` (scale) and ``b_2`` (shift) are
            learnable parameters; otherwise they are constant ones / zeros.
    """

    def __init__(self, features, eps=1e-5, elementwise_affine=True):
        super().__init__()
        if isinstance(features, int):
            features = (features,)
        self.normalized_shape = tuple(features)
        self.eps = eps
        self.elementwise_affine = elementwise_affine
        if elementwise_affine:
            self.a_2 = nn.Parameter(torch.ones(self.normalized_shape))
            self.b_2 = nn.Parameter(torch.zeros(self.normalized_shape))
        else:
            # Buffers (rather than plain tensor attributes) follow the module
            # through .to()/.cuda(); non-persistent keeps state_dict unchanged.
            self.register_buffer(
                "a_2", torch.ones(self.normalized_shape), persistent=False
            )
            self.register_buffer(
                "b_2", torch.zeros(self.normalized_shape), persistent=False
            )
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Reset the affine parameters to scale 1 and shift 0 (if learnable)."""
        if self.elementwise_affine:
            nn.init.ones_(self.a_2)
            nn.init.zeros_(self.b_2)

    def forward(self, x):
        """Normalize ``x`` over its trailing ``normalized_shape`` dimensions.

        Args:
            x: Tensor whose trailing dimensions match ``features``.

        Returns:
            Tensor of the same shape as ``x``.
        """
        # Reduce over every normalized axis, not only the last one.
        dims = tuple(range(-len(self.normalized_shape), 0))
        mean = x.mean(dim=dims, keepdim=True)
        # Biased variance: the unbiased estimator would divide by n - 1 and
        # produce NaN when the normalized axes hold a single element.
        var = x.var(dim=dims, unbiased=False, keepdim=True)
        # eps goes inside the square root, not on top of the std.
        return self.a_2 * (x - mean) / torch.sqrt(var + self.eps) + self.b_2


# Sanity check: push one random batch through both the built-in layer norm
# and MyLayerNorm, then print the two outputs one after the other.
N, S, H = 2, 2, 10
input = torch.randn(N, S, H)  # NOTE: this name shadows the builtin input()

# Reference output from the built-in module, configured with the full shape.
layer_norm_op = nn.LayerNorm([N, S, H], elementwise_affine=True)
ln_y = layer_norm_op(input)

# Output of the custom module, sized from the input tensor's shape.
my_layer_norm = MyLayerNorm(input.shape, elementwise_affine=True)
verify_ln_y = my_layer_norm(input)

print(ln_y, verify_ln_y, sep="\n")