In [2]:
import torch
import torch.nn as nn


class RMSNorm(nn.Module):
    """Root-mean-square normalization over the last dimension.

    Each vector along the last dimension is divided by its root mean square,
    then multiplied by a learnable per-feature gain. Unlike LayerNorm, no mean
    is subtracted and there is no bias.

    Args:
        dim: Size of the last dimension of the input (number of features);
            also the size of the learnable ``weight``.
        eps: Small constant added to the mean square *before* the square root
            so the division stays well-defined for (near-)zero inputs.
    """

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        # Per-feature gain, initialised to 1 so the layer starts out as a
        # pure normalization.
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        """Normalize ``x`` by its RMS over the last dimension and apply the gain.

        Args:
            x: Tensor whose last dimension has size ``dim``; any leading
                dimensions are treated independently.

        Returns:
            Tensor of the same shape as ``x``.
        """
        # eps is added inside the square root (rsqrt(mean(x^2) + eps)), the
        # same placement LayerNorm uses for sqrt(var + eps). Adding eps to the
        # RMS itself only floors the denominator at eps instead of sqrt(eps),
        # so tiny rows get blown up to roughly unit scale.
        mean_square = x.pow(2).mean(dim=-1, keepdim=True)
        x_normed = x * torch.rsqrt(mean_square + self.eps)
        return self.weight * x_normed


class LayerNorm(nn.Module):
    """Layer normalization over the trailing ``normalized_shape`` dimensions.

    Each sample is shifted to zero mean and scaled to unit (population,
    i.e. biased) variance over its last ``len(normalized_shape)``
    dimensions, then an elementwise affine transform ``weight * x + bias``
    is applied.

    Args:
        normalized_shape: Either an int (normalize over the last dimension
            of that size) or a sequence of ints giving the shape of the
            trailing dimensions to normalize over. ``weight`` and ``bias``
            have this shape.
        eps: Small constant added to the variance before the square root.
    """

    def __init__(self, normalized_shape, eps=1e-5):
        super().__init__()
        # Accept a single int (or int-like) size as well as a shape sequence.
        try:
            shape = tuple(normalized_shape)
        except TypeError:
            shape = (normalized_shape,)
        self.normalized_shape = shape
        self.weight = nn.Parameter(torch.ones(shape))
        self.bias = nn.Parameter(torch.zeros(shape))
        self.eps = eps

    def forward(self, x):
        """Normalize ``x`` over its trailing dims and apply the affine transform.

        Args:
            x: Tensor whose trailing dimensions match ``normalized_shape``.

        Returns:
            Tensor of the same shape as ``x``.
        """
        # Reduce over exactly the dims that weight/bias cover: (-1,) for an
        # int shape, (-2, -1) for a 2-D shape, etc. Reducing over only the
        # last dim would disagree with the parameter shape.
        dims = tuple(range(-len(self.normalized_shape), 0))
        mean = x.mean(dim=dims, keepdim=True)
        var = x.var(dim=dims, unbiased=False, keepdim=True)
        x_normed = (x - mean) / torch.sqrt(var + self.eps)
        return self.weight * x_normed + self.bias

# Sanity check: push one random 4-feature row through both normalizers and
# show the input followed by each layer's output, one tensor per line.
rms_norm, layer_norm = RMSNorm(4), LayerNorm(4)
x = torch.randn(1, 4)
print(x, rms_norm(x), layer_norm(x), sep="\n")

tensor([[ 0.1806, -1.3203,  1.3125, -0.6420]])
tensor([[ 0.1826, -1.3353,  1.3274, -0.6493]], grad_fn=<MulBackward0>)
tensor([[ 0.3034, -1.2253,  1.4563, -0.5344]], grad_fn=<AddBackward0>)
