- basics
    - dnn 的标准组件，稳定和加速训练过程
    - `batch_size, seq_len, hidden_size = 2, 3, 4`
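    - Batch Norm: reduce across the batch dimension
        - mini-batch dimension；对 (batch, seq, hidden) 的序列输入，是对每个 hidden feature 在 (batch, seq) 两个维度上求 mean/var
        - 一般用于图像，不涉及到 padding 的问题；
    - Layer Norm: reduce across the hidden dimension
        - 一般用于序列，一个 batch 内存在 padding（样本长度不同），跨 batch 求统计量会混入 padding 位置；
        - across the feature dimension：每个 token 独立地在 hidden 维上求 mean/var。
        - RMSNorm: LN 的一种变体（不减均值、没有 bias），llama 使用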
    - residual connection + norm
        - https://spaces.ac.cn/archives/9009
        - Pre LN: `llama`
        - Post LN: `attention is all you need`
- 理解高维 tensor 的 shape （axis/dim）处理
    - 最终的输出都不改变 shape
- 理解这三个 norm 的计算过程

In [1]:
import torch
from torch import nn
import transformers
from transformers import AutoModelForCausalLM

from IPython.display import Image
import os
# Route HTTP(S) traffic (e.g. the Hugging Face download below) through a local proxy.
# Override the address with NOTEBOOK_PROXY, or set NOTEBOOK_PROXY='' to disable it;
# proxy variables already present in the environment are respected, not overwritten.
PROXY = os.environ.get('NOTEBOOK_PROXY', 'http://127.0.0.1:7890')
if PROXY:
    os.environ.setdefault('http_proxy', PROXY)
    os.environ.setdefault('https_proxy', PROXY)

SEED = 42
torch.manual_seed(SEED)

<torch._C.Generator at 0x736c404e39f0>

In [2]:
# Toy 3-D activation tensor laid out as (batch_size, seq_len, hidden_size).
batch_size = 2
seq_len = 3
hidden_size = 4
x = torch.randn(batch_size, seq_len, hidden_size)

In [3]:
# Display the (batch_size, seq_len, hidden_size) = (2, 3, 4) input tensor.
x

tensor([[[ 1.9269,  1.4873,  0.9007, -2.1055],
         [ 0.6784, -1.2345, -0.0431, -1.6047],
         [ 0.3559, -0.6866, -0.4934,  0.2415]],

        [[-1.1109,  0.0915, -2.3169, -0.2168],
         [-0.3097, -0.3957,  0.8034, -0.6216],
         [-0.5920, -0.0631, -0.8286,  0.3309]]])

### BN

x: shape 为 (a, b, c)，即 dim 0/1/2 的大小分别为 a/b/c
- mean/var 都是 reduce 的过程；
- mean(dim=0) => 返回的 shape 是 (b, c)
    - mean(dim=0, keepdim=True) => (1, b, c): 方便 broadcast
    - einsum: `ijk->jk`（einsum 只做求和，需再除以 a 才是均值）
- mean(dim=(0, 1)) => 返回的 shape 是 (c,)
    - mean(dim=(0, 1), keepdim=True) => (1, 1, c): 方便 broadcast
    - einsum: `ijk->k`（再除以 a*b 才是均值；BN 对 (batch, seq, hidden) 输入就是这种情况）

In [4]:
# One BN channel per hidden feature. A fresh module is in training mode,
# so it normalises with the statistics of the batch it is given.
bn = nn.BatchNorm1d(num_features=hidden_size)

In [5]:
# BatchNorm1d expects (N, C) or (N, C, L) with channels on dim 1. x is (batch, seq, hidden),
# so dim 1 is seq_len (3), which does not match the hidden_size (4) running stats.
# The failure is the point of this cell, so catch and show it instead of aborting Run All.
try:
    bn(x)
except RuntimeError as err:
    print(f"RuntimeError: {err}")

RuntimeError: running_mean should contain 3 elements not 4

In [5]:
# Move hidden to the channel dim: (B, S, H) -> (B, H, S), normalise, then swap back.
# In training mode this call also updates bn.running_mean / bn.running_var.
bn(x.transpose(-2, -1)).transpose(-2, -1)

tensor([[[ 1.7943,  1.9214,  1.1306, -1.5846],
         [ 0.5278, -1.3052,  0.2633, -1.0345],
         [ 0.2006, -0.6557, -0.1505,  0.9930]],

        [[-1.2873,  0.2668, -1.8263,  0.4897],
         [-0.4746, -0.3108,  1.0412,  0.0451],
         [-0.7609,  0.0835, -0.4585,  1.0912]]], grad_fn=<TransposeBackward0>)

In [6]:
# (batch, seq, hidden) = 2x3x4  ->  (batch, hidden, seq) = 2x4x3
x.swapaxes(1, 2)

tensor([[[ 1.9269,  0.6784,  0.3559],
         [ 1.4873, -1.2345, -0.6866],
         [ 0.9007, -0.0431, -0.4934],
         [-2.1055, -1.6047,  0.2415]],

        [[-1.1109, -0.3097, -0.5920],
         [ 0.0915, -0.3957, -0.0631],
         [-2.3169,  0.8034, -0.8286],
         [-0.2168, -0.6216,  0.3309]]])

In [7]:
# Per-feature mean over batch and seq. On the (B, H, S) view those are dims (0, 2);
# keepdim=True leaves shape (1, H, 1) so it broadcasts against the view.
mean = torch.mean(x.transpose(1, 2), dim=(0, 2), keepdim=True)
mean

tensor([[[ 0.1581],
         [-0.1335],
         [-0.3296],
         [-0.6627]]])

In [37]:
# Mean of each hidden feature, taken over batch and seq: 'ijk->k' sums dims 0 and 1
# (einsum only sums), so divide by batch_size * seq_len to get the mean.
torch.einsum('ijk->k', x) / (x.shape[0] * x.shape[1])

tensor([ 0.1581, -0.1335, -0.3296, -0.6627])

In [8]:
# Reducing dim 0 and then dim 2 gives the same means (up to float rounding) as reducing
# (0, 2) at once. keepdim=True preserves the (1, H, 1) layout; dropping it (and then
# reducing dim 1 instead of 2) would give a 1-D result of length H.
x.transpose(1, 2).mean(0, keepdim=True).mean(2, keepdim=True)

tensor([[[ 0.1581],
         [-0.1335],
         [-0.3296],
         [-0.6627]]])

In [32]:
# (batch, seq, hidden) = 2x3x4 -> (hidden, batch, seq) = 4x2x3 -> 4x6:
# one row per hidden feature containing every (batch, seq) value it takes.
# Use hidden_size instead of a literal 4 so the cell tracks the tensor defined above.
x.permute(2, 0, 1).reshape(hidden_size, -1)

tensor([[ 1.9269,  0.6784,  0.3559, -1.1109, -0.3097, -0.5920],
        [ 1.4873, -1.2345, -0.6866,  0.0915, -0.3957, -0.0631],
        [ 0.9007, -0.0431, -0.4934, -2.3169,  0.8034, -0.8286],
        [-2.1055, -1.6047,  0.2415, -0.2168, -0.6216,  0.3309]])

In [33]:
# Row-wise mean of the (hidden_size, batch*seq) matrix equals the BN per-feature mean.
x.permute(2, 0, 1).reshape(hidden_size, -1).mean(dim=-1)

tensor([ 0.1581, -0.1335, -0.3296, -0.6627])

In [8]:
# Biased (population) per-feature variance over batch and seq, shape (1, H, 1);
# this is the variance BN normalises with in training mode.
var = torch.var(x.transpose(1, 2), dim=(0, 2), unbiased=False, keepdim=True)
var

tensor([[[0.9717],
         [0.7116],
         [1.1841],
         [0.8291]]])

In [9]:
# Manual BN with the batch statistics above. bn's weight/bias are still at their initial
# 1/0, so omitting the affine step reproduces bn's output.
(x.transpose(1, 2) - mean).div(torch.sqrt(var + bn.eps)).transpose(1, 2)

tensor([[[ 1.7943,  1.9214,  1.1306, -1.5846],
         [ 0.5278, -1.3052,  0.2633, -1.0345],
         [ 0.2006, -0.6557, -0.1505,  0.9930]],

        [[-1.2873,  0.2668, -1.8263,  0.4897],
         [-0.4746, -0.3108,  1.0412,  0.0451],
         [-0.7609,  0.0835, -0.4585,  1.0912]]])

In [10]:
# Same result without transposing: reduce the (batch, seq) dims (0, 1) and keep hidden.
(x - x.mean(dim=(0, 1), keepdim=True)) / (
    x.var(dim=(0, 1), unbiased=False, keepdim=True) + bn.eps
).sqrt()

tensor([[[ 1.7943,  1.9214,  1.1306, -1.5846],
         [ 0.5278, -1.3052,  0.2633, -1.0345],
         [ 0.2006, -0.6557, -0.1505,  0.9930]],

        [[-1.2873,  0.2668, -1.8263,  0.4897],
         [-0.4746, -0.3108,  1.0412,  0.0451],
         [-0.7609,  0.0835, -0.4585,  1.0912]]])

### LN

- 对每个 token 独立地在最后一维（hidden）上归一化；$\gamma$、$\beta$ 是逐元素（per-feature）的可学习仿射参数（`elementwise_affine=True`）

    $$
    y = \frac{x - \mathbb{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
    $$

- 关于 $\mathrm{Var}[x]$：使用有偏（总体）方差，即除以 $n$ 而不是 $n-1$，对应代码里的 `unbiased=False`
  
    $$
    \mathrm{Var}[x] = \mathbb{E}[(x - \mathbb{E}[x])^2]
    $$

In [11]:
# LayerNorm over the trailing hidden dim: every token is normalised on its own.
ln = nn.LayerNorm(normalized_shape=(hidden_size,))
ln_out = ln(x)
ln_out

tensor([[[ 0.8716,  0.5928,  0.2209, -1.6853],
         [ 1.3440, -0.7473,  0.5552, -1.1519],
         [ 1.1111, -1.1985, -0.7703,  0.8577]],

        [[-0.2380,  1.0472, -1.5270,  0.7177],
         [-0.3243, -0.4803,  1.6947, -0.8900],
         [-0.6717,  0.4977, -1.1947,  1.3688]]],
       grad_fn=<NativeLayerNormBackward0>)

In [12]:
# Show the module config: normalized_shape, eps, and whether the affine (gamma/beta) params are on.
ln

LayerNorm((4,), eps=1e-05, elementwise_affine=True)

In [13]:
# Hidden vector of the first token of the first sample: the unit LayerNorm normalises.
x[0, 0, :]

tensor([ 1.9269,  1.4873,  0.9007, -2.1055])

In [14]:
# Manual LayerNorm for one token (biased variance; ln's affine params are still 1/0).
(x[0, 0] - x[0, 0].mean()) / (x[0, 0].var(unbiased=False) + ln.eps).sqrt()

tensor([ 0.8716,  0.5928,  0.2209, -1.6853])

In [15]:
# Vectorised over all tokens: reduce only the last dim; keepdim=True for broadcasting.
x.sub(x.mean(dim=-1, keepdim=True)).div(
    (x.var(dim=-1, unbiased=False, keepdim=True) + ln.eps).sqrt()
)

tensor([[[ 0.8716,  0.5928,  0.2209, -1.6853],
         [ 1.3440, -0.7473,  0.5552, -1.1519],
         [ 1.1111, -1.1985, -0.7703,  0.8577]],

        [[-0.2380,  1.0472, -1.5270,  0.7177],
         [-0.3243, -0.4803,  1.6947, -0.8900],
         [-0.6717,  0.4977, -1.1947,  1.3688]]])

In [38]:
# Centre each token: subtract its own mean over the hidden dim.
x - x.mean(-1, keepdim=True)

tensor([[[ 1.3746,  0.9349,  0.3484, -2.6579],
         [ 1.2294, -0.6836,  0.5079, -1.0537],
         [ 0.5015, -0.5410, -0.3477,  0.3871]],

        [[-0.2226,  0.9798, -1.4287,  0.6715],
         [-0.1788, -0.2648,  0.9343, -0.4907],
         [-0.3038,  0.2251, -0.5404,  0.6191]]])

In [42]:
# Same centring via einsum: 'ijk->ij' sums over hidden; divide by hidden_size and
# append a trailing axis so the per-token mean broadcasts over the hidden dim.
x - torch.einsum('ijk->ij', x)[..., None] / x.shape[-1]

tensor([[[ 1.3746,  0.9349,  0.3484, -2.6579],
         [ 1.2294, -0.6836,  0.5079, -1.0537],
         [ 0.5015, -0.5410, -0.3477,  0.3871]],

        [[-0.2226,  0.9798, -1.4287,  0.6715],
         [-0.1788, -0.2648,  0.9343, -0.4907],
         [-0.3038,  0.2251, -0.5404,  0.6191]]])

### RMSNorm


$$
y = \frac{x}{\sqrt{\frac1n\sum_ix_i^2+\epsilon}} * \gamma
$$

- https://github.com/meta-llama/llama3/blob/main/llama/model.py#L35C7-L46
- 存在 learnable parameter $\gamma$（代码中的 `weight`，初始化为全 1）；与 LN 相比不减均值，也没有 $\beta$（bias）

In [16]:
class RMSNorm(torch.nn.Module):
    """RMS normalisation over the last dimension, following the Llama 3 reference code.

    Computes ``x / sqrt(mean(x**2, dim=-1) + eps) * weight``. Unlike LayerNorm there is
    no mean subtraction and no bias; ``weight`` is the only learnable parameter.

    Args:
        dim: size of the last (normalised) dimension; length of ``weight``.
        eps: added to the mean square before taking the reciprocal square root.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        # Learnable per-feature gain, initialised to ones (identity scaling).
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def _norm(self, x):
        """Scale ``x`` by the reciprocal RMS of its last dimension."""
        mean_square = x.pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        return x * inv_rms

    def forward(self, x):
        # Normalise in float32, cast back to the input dtype, then apply the gain
        # (ordinary type promotion with ``weight`` decides the final dtype).
        normed = self._norm(x.float()).type_as(x)
        return normed * self.weight

In [17]:
# Size the gain from hidden_size rather than a literal 4 so it tracks the input tensor.
rms = RMSNorm(dim=hidden_size)
rms(x)

tensor([[[ 1.1531,  0.8900,  0.5390, -1.2600],
         [ 0.6353, -1.1561, -0.0403, -1.5027],
         [ 0.7503, -1.4477, -1.0402,  0.5092]],

        [[-0.8611,  0.0710, -1.7959, -0.1681],
         [-0.5466, -0.6983,  1.4178, -1.0970],
         [-1.1039, -0.1176, -1.5450,  0.6170]]], grad_fn=<MulBackward0>)

In [18]:
# Manual RMSNorm; rms.weight is still all ones, so the gain is omitted.
# Reuse rms.eps (as the BN/LN cells reuse bn.eps / ln.eps) instead of a literal 1e-6.
x / torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + rms.eps)

tensor([[[ 1.1531,  0.8900,  0.5390, -1.2600],
         [ 0.6353, -1.1561, -0.0403, -1.5027],
         [ 0.7503, -1.4477, -1.0402,  0.5092]],

        [[-0.8611,  0.0710, -1.7959, -0.1681],
         [-0.5466, -0.6983,  1.4178, -1.0970],
         [-1.1039, -0.1176, -1.5450,  0.6170]]])

#### llama3 rmsnorm

In [19]:
# Llama 3 8B checkpoint loaded in bfloat16; device_map='auto' lets accelerate
# place the checkpoint shards on the available devices.
llama3_id = "meta-llama/Meta-Llama-3-8B"
llama3 = AutoModelForCausalLM.from_pretrained(
    llama3_id,
    device_map='auto',
    torch_dtype=torch.bfloat16,
)

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

We've detected an older driver with an RTX 4000 series GPU. These drivers have issues with P2P. This can affect the multi-gpu inference when using accelerate device_map.Please make sure to update your driver to the latest version which resolves this.


In [20]:
# Module tree: each decoder layer holds two LlamaRMSNorm modules
# (input_layernorm and post_attention_layernorm).
llama3

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 4096)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (up_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (down_proj): Linear(in_features=14336, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
      )
    )
    (n

In [21]:
# Hyper-parameters; rms_norm_eps matches the eps=1e-05 shown on the RMSNorm modules above.
llama3.config

LlamaConfig {
  "_name_or_path": "meta-llama/Meta-Llama-3-8B",
  "architectures": [
    "LlamaForCausalLM"
  ],
  "attention_bias": false,
  "attention_dropout": 0.0,
  "bos_token_id": 128000,
  "eos_token_id": 128001,
  "head_dim": 128,
  "hidden_act": "silu",
  "hidden_size": 4096,
  "initializer_range": 0.02,
  "intermediate_size": 14336,
  "max_position_embeddings": 8192,
  "mlp_bias": false,
  "model_type": "llama",
  "num_attention_heads": 32,
  "num_hidden_layers": 32,
  "num_key_value_heads": 8,
  "pretraining_tp": 1,
  "rms_norm_eps": 1e-05,
  "rope_scaling": null,
  "rope_theta": 500000.0,
  "tie_word_embeddings": false,
  "torch_dtype": "bfloat16",
  "transformers_version": "4.45.0.dev0",
  "use_cache": true,
  "vocab_size": 128256
}

```
class TransformerBlock(nn.Module):
    """Llama decoder layer (presumably excerpted from meta-llama/llama3 model.py).

    Pre-LN layout: each sub-layer (attention, feed-forward) receives an
    RMSNorm-normalised input, while the residual stream itself is not normalised.
    """
    def __init__(self, layer_id: int, args: ModelArgs):
        super().__init__()
        self.n_heads = args.n_heads
        self.dim = args.dim
        self.head_dim = args.dim // args.n_heads
        self.attention = Attention(args)
        # hidden_dim starts at 4 * dim; presumably FeedForward rescales it with
        # multiple_of / ffn_dim_multiplier (the HF config above reports
        # intermediate_size=14336 rather than 4 * 4096) -- TODO confirm in model.py.
        self.feed_forward = FeedForward(
            dim=args.dim,
            hidden_dim=4 * args.dim,
            multiple_of=args.multiple_of,
            ffn_dim_multiplier=args.ffn_dim_multiplier,
        )
        self.layer_id = layer_id
        # Two RMSNorms per block: one in front of attention, one in front of the FFN
        # (compare input_layernorm / post_attention_layernorm in the HF module tree).
        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
        self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)

    def forward(
        self,
        x: torch.Tensor,
        start_pos: int,
        freqs_cis: torch.Tensor,
        mask: Optional[torch.Tensor],
    ):
        # Pre-LN residuals: normalise only the sub-layer input and add the
        # sub-layer output back onto the un-normalised stream.
        h = x + self.attention(self.attention_norm(x), start_pos, freqs_cis, mask)
        out = h + self.feed_forward(self.ffn_norm(h))
        return out
```

In [22]:
# Trained RMSNorm gain (gamma) of layer 0's input_layernorm, stored in bfloat16.
# Unlike the all-ones initialisation, these values are learned. input_layernorm is
# presumably the HF counterpart of attention_norm in the reference TransformerBlock.
llama3.model.layers[0].input_layernorm.weight

Parameter containing:
tensor([0.0537, 0.2090, 0.4492,  ..., 0.0859, 0.0437, 0.0292], device='cuda:0',
       dtype=torch.bfloat16, requires_grad=True)