$$
\begin{split}
x_1 &= x_0 - \text{mean}(x_0)\\
x_2 &= \frac{x_1}{\sqrt{\text{mean}(x_1^2) + \epsilon}}\\
x_3 &= x_2 \cdot w\\
x_4 &= x_3 + b\\
\end{split}
$$

- Unlike BatchNorm, it cannot be turned off at inference time, as it significantly alters the mathematical function implemented by the transformer.
- 归一化定义在 feature 维度上，作用于单个样本甚至单个 token 级别，而非像 BatchNorm 那样跨样本进行统计。

In [41]:
import torch
from torch import nn
# Seed PyTorch's global RNG so the random tensor drawn later is reproducible.
# As the last expression of the cell, the returned Generator object is displayed.
torch.manual_seed(42)

<torch._C.Generator at 0x774eeaaa5690>

In [42]:
# Toy dimensions for the demo: number of sequences, tokens per sequence,
# and features per token.
batch = 2
sentence_length = 3
embedding_dim = 4

In [43]:
# Random integer-valued "embeddings" in [0, 5), cast to float32 so that
# LayerNorm can operate on them.
embedding = torch.randint(
    low=0,
    high=5,
    size=(batch, sentence_length, embedding_dim),
).to(torch.float32)
embedding

tensor([[[2., 2., 1., 4.],
         [1., 0., 0., 4.],
         [0., 3., 3., 4.]],

        [[0., 4., 1., 2.],
         [0., 0., 2., 1.],
         [4., 1., 3., 1.]]])

In [44]:
# LayerNorm that normalises over the trailing feature axis of size embedding_dim.
ln = nn.LayerNorm(normalized_shape=embedding_dim)

In [45]:
# Materialise the module's learnable parameters so they are displayed.
[*ln.parameters()]

[Parameter containing:
 tensor([1., 1., 1., 1.], requires_grad=True),
 Parameter containing:
 tensor([0., 0., 0., 0.], requires_grad=True)]

In [50]:
# Show the same two parameters by name: the per-feature scale (weight)
# followed by the per-feature shift (bias).
ln.weight, ln.bias

(Parameter containing:
 tensor([1., 1., 1., 1.], requires_grad=True),
 Parameter containing:
 tensor([0., 0., 0., 0.], requires_grad=True))

In [54]:
# Reference result: apply the built-in LayerNorm module to the whole
# (batch, sentence_length, embedding_dim) tensor.
ln(embedding)

tensor([[[-0.2294, -0.2294, -1.1471,  1.6059],
         [-0.1525, -0.7625, -0.7625,  1.6775],
         [-1.6667,  0.3333,  0.3333,  1.0000]],

        [[-1.1832,  1.5213, -0.5071,  0.1690],
         [-0.9045, -0.9045,  1.5075,  0.3015],
         [ 1.3471, -0.9622,  0.5773, -0.9622]]],
       grad_fn=<NativeLayerNormBackward0>)

In [56]:
# Reproduce LayerNorm by hand for a single token (batch 0, position 0).
x = embedding[0, 0, :]
# Step 1: centre the features around zero.
x1 = x - torch.mean(x)
# Step 2: scale by the biased standard deviation. x1 is zero-mean, so
# var(x1) equals mean(x1 ** 2). ln.eps is added inside the square root, as
# nn.LayerNorm does, so a constant-valued token gives zeros instead of 0/0 = NaN.
x2 = x1 / torch.sqrt(torch.var(x1, unbiased=False) + ln.eps)
# Steps 3-4: apply the learnable per-feature scale and shift.
x3 = x2 * ln.weight + ln.bias
x3

tensor([-0.2294, -0.2294, -1.1471,  1.6059], grad_fn=<AddBackward0>)

In [61]:
# Alias the affine parameters with their conventional names:
# gamma is the scale, beta is the shift.
gamma, beta = ln.weight, ln.bias

In [62]:
# Vectorised manual LayerNorm: statistics are taken over the last (feature)
# axis, and keepdim=True keeps them broadcastable against the input.
centered = embedding - embedding.mean(dim=-1, keepdim=True)
variance = embedding.var(dim=-1, unbiased=False, keepdim=True)
normalized = centered / (variance + ln.eps).sqrt()
# Apply the learnable scale and shift.
gamma * normalized + beta

tensor([[[-0.2294, -0.2294, -1.1471,  1.6059],
         [-0.1525, -0.7625, -0.7625,  1.6775],
         [-1.6667,  0.3333,  0.3333,  1.0000]],

        [[-1.1832,  1.5213, -0.5071,  0.1690],
         [-0.9045, -0.9045,  1.5075,  0.3015],
         [ 1.3471, -0.9622,  0.5773, -0.9622]]], grad_fn=<AddBackward0>)