In [1]:
import torch
from torch import nn
import torch.nn.functional as F
import numpy as np
import pandas as pd

In [2]:
torch.manual_seed(42)

<torch._C.Generator at 0x75a9189179b0>

In [3]:
# Toy hyperparameters, kept tiny so that every tensor below can be printed in full.
T, B = 3, 2  # sequence length, batch size
d = 4        # model dimension (d_model)
h = 2        # number of attention heads
ff = 8       # hidden width of the feed-forward network

In [4]:
X = torch.randn(T, B, d)  # [seq_len, batch_size, d_model]
X.shape

torch.Size([3, 2, 4])

In [5]:
encoder = nn.TransformerEncoderLayer(d, h, ff, dropout=0.0)

In [6]:
W_in, b_in = encoder.self_attn.in_proj_weight, encoder.self_attn.in_proj_bias
# (3d, d), (3d, )
W_in.shape, b_in.shape

(torch.Size([12, 4]), torch.Size([12]))

In [7]:
encoder(X).shape

torch.Size([3, 2, 4])

### 张量内存 （tensor memory layout）

[[pytorch] Tensor shape 变化 view 与 reshape（contiguous 的理解）)](https://www.bilibili.com/video/BV1Zw411y7Ks/)

https://stackoverflow.com/questions/26998223/what-is-the-difference-between-contiguous-and-non-contiguous-arrays

- 张量内存 layout
    - 在大多数深度学习框架（如 PyTorch）中，（高维）张量的数据是以**一维数组**的形式在内存中**连续存储**的。对于多维张量，其高维结构是通过一维内存数组和步幅（strides）来实现的。
    - pytorch 存储顺序（Storage Order）是 Row-major，最后一个维度变化最快。
- Strides（步幅）
    - 对于一个形状为 $(D_0,D_1,D_2)$（`2*3*4`） 的 3D 张量，其步幅计算如下：
        - $\text{stride[2]} = 1$
        - $\text{stride[1]} = D_2\times \text{stride[2]}=D_2$ （4）
        - $\text{stride[0]} = D_1\times \text{stride[1]}=D_1\times D_2$（3*4=12）

In [45]:
A = torch.randint(0, 5, (2, 3, 4))
A

tensor([[[0, 0, 3, 0],
         [3, 3, 1, 1],
         [0, 3, 1, 4]],

        [[1, 1, 0, 2],
         [4, 1, 1, 0],
         [4, 1, 0, 3]]])

In [46]:
A[1]

tensor([[1, 1, 0, 2],
        [4, 1, 1, 0],
        [4, 1, 0, 3]])

- 当张量在内存中的数据排列不再符合其形状和步幅之间的默认关系时，张量就是非连续的（is not contiguous）。
    - 特征：`.is_contiguous()` 方法返回 False。
    - 影响：某些操作在非连续张量上可能性能较差，或者需要额外的内存拷贝。
    - 解决方法：使用 .contiguous() 方法，将张量拷贝为内存中连续的版本。
- 什么样的操作会导致内存的不连续
    - permute, transpose，以及带步长的切片（如 `x[:, ::2]`）；`view` 本身不会引入不连续，它只是按已有步幅重新解释形状。
        - transpose 是 permute 的特例，transpose 只允许交换两个维度。
- `reshape` 不改变内存中的数据顺序


In [11]:
A = torch.randn(3, 4)
A.shape, A.stride(), A.is_contiguous()

(torch.Size([3, 4]), (4, 1), True)

In [12]:
A = A.transpose(0, 1)
A.shape, A.stride(), A.is_contiguous()

(torch.Size([4, 3]), (1, 4), False)

In [13]:
A = A.contiguous()
A.shape, A.stride(), A.is_contiguous()

(torch.Size([4, 3]), (3, 1), True)

### view vs. reshape

- view（类比sql中的概念）
    - 不会复制数据：view 创建的是原始张量的一个新的视图，内存数据保持不变，只是重新解释内存中的数据。因为它依赖于张量的内存布局，所以只有当目标形状与现有步幅兼容时才能使用；连续张量总是可以 view，非连续张量则不一定（例如下文的 `A_t.view(-1, 2)` 可以成功，而 `A_t.view(-1, 4)` 会报错）。
    - 不连续内存，view 时有可能报错；
- reshape: 它会自动处理非连续张量，尽可能返回视图，如果无法返回视图，则会拷贝
    - 不要求内存连续：reshape 可以用于非连续的张量。如果目标形状无法以视图方式表达（步幅不兼容），reshape 会自动创建一个新的连续张量并复制数据，以确保能够完成形状转换。
    - 可能复制数据：当张量是内存不连续的，reshape 可能会进行数据复制，生成一个新的内存布局的张量。否则，它和 view 的行为是一样的，不复制数据。
    - 不改变数据在内存中的顺序，只改变张量的形状解释。

In [14]:
# Collapse the two leading axes of a freshly allocated tensor with view,
# then inspect the shape / stride / contiguity of that single view.
A = torch.randn(2, 3, 4)
A_2d = A.view(-1, 4)
A_2d.shape, A_2d.stride(), A_2d.is_contiguous()

(torch.Size([6, 4]), (4, 1), True)

In [49]:
# Permuting axes only reorders the shape/stride metadata, so the permuted
# tensor keeps A's strides in a shuffled order and stops being contiguous.
A = torch.randn(2, 3, 4)
A_t = torch.permute(A, (1, 2, 0))
A.stride(), A_t.shape, A_t.stride(), A_t.is_contiguous()

((12, 4, 1), torch.Size([3, 4, 2]), (4, 1, 12), False)

In [16]:
A_t

tensor([[[-1.3793, -0.5572],
         [ 0.6258, -0.9683],
         [-2.5850,  0.8713],
         [-0.0240, -0.0956]],

        [[-0.1222,  0.3463],
         [-0.7470, -0.5402],
         [ 1.7093,  0.8569],
         [ 0.0579, -0.6721]],

        [[ 0.5230,  1.0682],
         [ 0.9717, -0.2527],
         [-0.2779, -0.1882],
         [-0.6116, -0.7712]]])

In [17]:
# A_t is non-contiguous, yet this view succeeds: only dims 0 and 1 are merged, and
# with the sizes (3, 4) and strides (4, 1) shown for A_t above they already form one
# uniform axis of 12 elements with stride 1, so no data has to move.
A_t.view(-1, 2)

tensor([[-1.3793, -0.5572],
        [ 0.6258, -0.9683],
        [-2.5850,  0.8713],
        [-0.0240, -0.0956],
        [-0.1222,  0.3463],
        [-0.7470, -0.5402],
        [ 1.7093,  0.8569],
        [ 0.0579, -0.6721],
        [ 0.5230,  1.0682],
        [ 0.9717, -0.2527],
        [-0.2779, -0.1882],
        [-0.6116, -0.7712]])

In [18]:
A_t.view(-1, 2)

tensor([[-1.3793, -0.5572],
        [ 0.6258, -0.9683],
        [-2.5850,  0.8713],
        [-0.0240, -0.0956],
        [-0.1222,  0.3463],
        [-0.7470, -0.5402],
        [ 1.7093,  0.8569],
        [ 0.0579, -0.6721],
        [ 0.5230,  1.0682],
        [ 0.9717, -0.2527],
        [-0.2779, -0.1882],
        [-0.6116, -0.7712]])

In [19]:
A_t.reshape(-1, 4), A_t.reshape(-1, 4).is_contiguous()

(tensor([[-1.3793, -0.5572,  0.6258, -0.9683],
         [-2.5850,  0.8713, -0.0240, -0.0956],
         [-0.1222,  0.3463, -0.7470, -0.5402],
         [ 1.7093,  0.8569,  0.0579, -0.6721],
         [ 0.5230,  1.0682,  0.9717, -0.2527],
         [-0.2779, -0.1882, -0.6116, -0.7712]]),
 True)

### encoder recap

https://pytorch.org/docs/stable/generated/torch.nn.TransformerEncoderLayer.html

- input: $\mathbf{X} \in \mathbb{R}^{T \times B \times d_{\text{model}}}$ （`batch_first=False`）
- multihead selfattn
    - 线性变换（linear projection, 矩阵乘法）生成 Q、K、V矩阵
        - $\mathbf X_{\text{flat}}=\mathbf X.\text{reshape}(T\times B,d_{\text{model}})$（3 维张量 -> 2 维张量）
        - $\mathbf{QKV}=\mathbf X_{\text{flat}}\mathbf W_{in}^T+\mathbf b_{in}$（`encoder.self_attn.in_proj_weight`, `encoder.self_attn.in_proj_bias`）
            - $\mathbf{W}_{in} \in \mathbb{R}^{3d_{\text{model}} \times d_{\text{model}}}$，$\mathbf{b}_{in} \in \mathbb{R}^{3d_{\text{model}}}$
            - $\mathbf{QKV}\in \mathbb R^{(T\times B)\times 3d_{\text{model}}}$
    - 拆分 $\mathbf Q, \mathbf K,\mathbf V$
        - $\mathbf Q, \mathbf K,\mathbf V=\text{split}(\mathbf{QKV},d_{model})$（按列进行拆分）
        - $\mathbf Q, \mathbf K,\mathbf V\in \mathbb R^{(T \times B)\times d_{\text{model}}}$
    - 调整形状以适应多头注意力
        - $d_k = \frac{d_{\text{model}}}h$ （4/2 = 2）
        - `reshape_for_heads`
        $$
        \begin{align*}
            \mathbf{Q}_{\text{heads}} &= \mathbf{Q}.\text{reshape}(T, B, h, d_k).\text{permute}(1, 2, 0, 3).\text{reshape}(B \times h, T, d_k) \\
            \mathbf{K}_{\text{heads}} &= \mathbf{K}.\text{reshape}(T, B, h, d_k).\text{permute}(1, 2, 0, 3).\text{reshape}(B \times h, T, d_k) \\
            \mathbf{V}_{\text{heads}} &= \mathbf{V}.\text{reshape}(T, B, h, d_k).\text{permute}(1, 2, 0, 3).\text{reshape}(B \times h, T, d_k)
        \end{align*}
        $$
    - 计算注意力分数：$\text{Scores} = \frac{\mathbf{Q}_{\text{heads}} \mathbf{K}_{\text{heads}}^\top}{\sqrt{d_k}}$
        - $\mathbf{Q}_{\text{heads}} \in \mathbb{R}^{(B \times h) \times T \times d_k}$，$\mathbf{K}_{\text{heads}}^\top \in \mathbb{R}^{(B \times h) \times d_k \times T}$，因此 $\text{Scores} \in \mathbb{R}^{(B \times h) \times T \times T}$。
    - 计算注意力权重：$\text{AttentionWeights}=\text{softmax}(\text{Scores})$
    - 计算注意力输出：$\text{AttentionOutput}=\text{AttentionWeights}\times{\mathbf V_\text{heads}}$
        - $\mathbf{V}_{\text{heads}} \in \mathbb{R}^{(B \times h) \times T \times d_k}$，因此 $\text{AttentionOutput} \in \mathbb{R}^{(B \times h) \times T \times d_k}$。
    - 合并多头输出：$\text{AttentionOutput} = \text{AttentionOutput}.\text{reshape}(B, h, T, d_k).\text{permute}(2, 0, 1, 3).\text{reshape}(T, B, d_{\text{model}})$

### qkv, mhsa

- $X_{\text{flat}}=\mathbf X.\text{reshape}(T\times B,d_{model})$

In [20]:
X.shape, X

(torch.Size([3, 2, 4]),
 tensor([[[ 1.9269,  1.4873,  0.9007, -2.1055],
          [ 0.6784, -1.2345, -0.0431, -1.6047]],
 
         [[ 0.3559, -0.6866, -0.4934,  0.2415],
          [-1.1109,  0.0915, -2.3169, -0.2168]],
 
         [[-0.3097, -0.3957,  0.8034, -0.6216],
          [-0.5920, -0.0631, -0.8286,  0.3309]]]))

In [21]:
X_flat = X.reshape(-1, d)
# (T*B, d)
X_flat.shape

torch.Size([6, 4])

- $\mathbf{QKV}=\mathbf X_{\text{flat}}\mathbf W_{in}^T+\mathbf b_{in}$

In [22]:
QKV = F.linear(X_flat, W_in, b_in)
QKV.shape

torch.Size([6, 12])

In [23]:
QKV

tensor([[ 0.0668,  1.3935, -1.1805,  0.8455, -1.2987, -0.4164,  1.8174,  0.1275,
          1.6109,  1.8252,  0.4149, -0.3524],
        [ 0.8724,  0.8454, -0.1481,  1.8186, -1.4116,  0.6840,  1.1799, -0.1952,
          0.2774,  1.5957,  0.2900,  1.0896],
        [ 0.1566, -0.1838, -0.1016,  0.7407, -0.3022, -0.0460, -0.0557, -0.3397,
         -0.0158,  0.3054, -0.2973,  0.4672],
        [-0.4729, -1.3182, -0.9481,  0.8103, -0.1380, -0.3860, -1.1347,  0.0943,
          0.7324, -0.0104,  0.6872,  1.2226],
        [ 0.4362,  0.6260,  0.5733, -0.1727, -0.1107,  0.6780,  0.5196,  0.2229,
         -0.4313,  0.1020,  0.2595, -0.0467],
        [-0.2227, -0.6440, -0.1496,  0.0278,  0.2149, -0.1376, -0.6585,  0.0122,
          0.0174, -0.3351,  0.1178,  0.3368]], grad_fn=<AddmmBackward0>)

In [24]:
Q, K, V = QKV.split(d, dim=1)
Q.shape, K.shape, V.shape

(torch.Size([6, 4]), torch.Size([6, 4]), torch.Size([6, 4]))

In [25]:
K

tensor([[-1.2987, -0.4164,  1.8174,  0.1275],
        [-1.4116,  0.6840,  1.1799, -0.1952],
        [-0.3022, -0.0460, -0.0557, -0.3397],
        [-0.1380, -0.3860, -1.1347,  0.0943],
        [-0.1107,  0.6780,  0.5196,  0.2229],
        [ 0.2149, -0.1376, -0.6585,  0.0122]], grad_fn=<SplitBackward0>)

In [26]:
# Reshape Q, K and V so that each attention head gets its own batch entry.
d_k = d // h  # width of a single head
def reshape_for_heads(x):
    """Rearrange a flat (T*B, h*d_k) projection into per-head form (B*h, T, d_k).

    Relies on the notebook-level sizes T, B, h and d_k. The input's shape and
    contiguity flag, and the output's shape, are printed as a teaching aid.
    """
    print(x.shape, x.is_contiguous())
    # Unfold rows into (T, B) and the last axis into (h, d_k): (T*B, h*d_k) -> (T, B, h, d_k).
    # .contiguous() comes first so that view is applied to a contiguous tensor.
    unfolded = x.contiguous().view(T, B, h, d_k)
    # Move batch and head in front of time: (T, B, h, d_k) -> (B, h, T, d_k).
    heads_first = unfolded.permute(1, 2, 0, 3)
    # Fuse batch and head into one leading axis: (B*h, T, d_k).
    y = heads_first.reshape(B * h, T, d_k)
    print(y.shape)
    return y

In [27]:
Q

tensor([[ 0.0668,  1.3935, -1.1805,  0.8455],
        [ 0.8724,  0.8454, -0.1481,  1.8186],
        [ 0.1566, -0.1838, -0.1016,  0.7407],
        [-0.4729, -1.3182, -0.9481,  0.8103],
        [ 0.4362,  0.6260,  0.5733, -0.1727],
        [-0.2227, -0.6440, -0.1496,  0.0278]], grad_fn=<SplitBackward0>)

In [28]:
Q.contiguous().view(T, B, h, d_k)

tensor([[[[ 0.0668,  1.3935],
          [-1.1805,  0.8455]],

         [[ 0.8724,  0.8454],
          [-0.1481,  1.8186]]],


        [[[ 0.1566, -0.1838],
          [-0.1016,  0.7407]],

         [[-0.4729, -1.3182],
          [-0.9481,  0.8103]]],


        [[[ 0.4362,  0.6260],
          [ 0.5733, -0.1727]],

         [[-0.2227, -0.6440],
          [-0.1496,  0.0278]]]], grad_fn=<ViewBackward0>)

In [29]:
# T, B, h, d_k => (B, h, T, d_k)
Q.contiguous().view(T, B, h, d_k).permute(1, 2, 0, 3)

tensor([[[[ 0.0668,  1.3935],
          [ 0.1566, -0.1838],
          [ 0.4362,  0.6260]],

         [[-1.1805,  0.8455],
          [-0.1016,  0.7407],
          [ 0.5733, -0.1727]]],


        [[[ 0.8724,  0.8454],
          [-0.4729, -1.3182],
          [-0.2227, -0.6440]],

         [[-0.1481,  1.8186],
          [-0.9481,  0.8103],
          [-0.1496,  0.0278]]]], grad_fn=<PermuteBackward0>)

In [30]:
Q.contiguous().view(T, B, h, d_k).permute(1, 2, 0, 3).is_contiguous()

False

In [31]:
# Build the per-head tensors directly from QKV instead of from the current Q/K/V.
# The previous `Q = reshape_for_heads(Q)` form was not idempotent: running the cell a
# second time would push the already-reshaped (B*h, T, d_k) tensors through view again
# and silently scramble them. Splitting QKV here gives the same result on every run.
Q, K, V = (reshape_for_heads(part) for part in QKV.split(d, dim=1))

torch.Size([6, 4]) False
torch.Size([4, 3, 2])
torch.Size([6, 4]) False
torch.Size([4, 3, 2])
torch.Size([6, 4]) False
torch.Size([4, 3, 2])


In [32]:
Q

tensor([[[ 0.0668,  1.3935],
         [ 0.1566, -0.1838],
         [ 0.4362,  0.6260]],

        [[-1.1805,  0.8455],
         [-0.1016,  0.7407],
         [ 0.5733, -0.1727]],

        [[ 0.8724,  0.8454],
         [-0.4729, -1.3182],
         [-0.2227, -0.6440]],

        [[-0.1481,  1.8186],
         [-0.9481,  0.8103],
         [-0.1496,  0.0278]]], grad_fn=<ReshapeAliasBackward0>)

- 6, 4
```
[ 
 [ 0,  1,  2,  3],
 [ 4,  5,  6,  7],
 [ 8,  9, 10, 11],
 [12, 13, 14, 15],
 [16, 17, 18, 19],
 [20, 21, 22, 23]
]
```

- `(3*2, 2*2) => (3, 2, 2, 2)`

    ```
    [
      [  # 时间步 t=0
        [  # 批次 b=0
          [0, 1],    # 头 h=0
          [2, 3]     # 头 h=1
        ],
        [  # 批次 b=1
          [4, 5],
          [6, 7]
        ]
      ],
      [  # 时间步 t=1
        [
          [8, 9],
          [10,11]
        ],
        [
          [12,13],
          [14,15]
        ]
      ],
      [  # 时间步 t=2
        [
          [16,17],
          [18,19]
        ],
        [
          [20,21],
          [22,23]
        ]
      ]
    ]
    
    ```

In [33]:
A = torch.arange(24).reshape(6, 4)
A

tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11],
        [12, 13, 14, 15],
        [16, 17, 18, 19],
        [20, 21, 22, 23]])

In [34]:
A.reshape(3, 2, 2, 2)

tensor([[[[ 0,  1],
          [ 2,  3]],

         [[ 4,  5],
          [ 6,  7]]],


        [[[ 8,  9],
          [10, 11]],

         [[12, 13],
          [14, 15]]],


        [[[16, 17],
          [18, 19]],

         [[20, 21],
          [22, 23]]]])

### einsum => einops

- einsum: 顾名思义，更多是求和约定；不太适合直接做 reshape

In [35]:
from einops import rearrange

$$
\begin{align*}
    \mathbf{Q}_{\text{heads}} &= \mathbf{Q}.\text{reshape}(T, B, h, d_k).\text{permute}(1, 2, 0, 3).\text{reshape}(B \times h, T, d_k) \\
    \mathbf{K}_{\text{heads}} &= \mathbf{K}.\text{reshape}(T, B, h, d_k).\text{permute}(1, 2, 0, 3).\text{reshape}(B \times h, T, d_k) \\
    \mathbf{V}_{\text{heads}} &= \mathbf{V}.\text{reshape}(T, B, h, d_k).\text{permute}(1, 2, 0, 3).\text{reshape}(B \times h, T, d_k)
\end{align*}
$$

In [36]:
Q, K, V = QKV.split(d, dim=1)
# (T*B, h*d_k)
Q.shape, K.shape, V.shape

(torch.Size([6, 4]), torch.Size([6, 4]), torch.Size([6, 4]))

In [37]:
# torch.einsum only accepts single-letter subscripts and has no axis-grouping syntax
# such as "(b h)", so this call is expected to fail. The failure is the point of the
# demo, so catch it and print the message instead of aborting a Restart & Run All.
try:
    torch.einsum('t b h k->(b h) t k', Q.contiguous().reshape(T, B, h, d_k))
except RuntimeError as err:
    print(f"{type(err).__name__}: {err}")

RuntimeError: einsum(): invalid subscript given at index 9 in the equation string, subscripts must be in [a-zA-Z]

In [38]:
rearrange(Q, '(T B) (h d_k) -> T B h d_k', T=T, B=B, h=h, d_k=d_k)

tensor([[[[ 0.0668,  1.3935],
          [-1.1805,  0.8455]],

         [[ 0.8724,  0.8454],
          [-0.1481,  1.8186]]],


        [[[ 0.1566, -0.1838],
          [-0.1016,  0.7407]],

         [[-0.4729, -1.3182],
          [-0.9481,  0.8103]]],


        [[[ 0.4362,  0.6260],
          [ 0.5733, -0.1727]],

         [[-0.2227, -0.6440],
          [-0.1496,  0.0278]]]], grad_fn=<ReshapeAliasBackward0>)

In [39]:
rearrange(Q, '(T B) (h d_k) -> B h T d_k', T=T, B=B, h=h, d_k=d_k)

tensor([[[[ 0.0668,  1.3935],
          [ 0.1566, -0.1838],
          [ 0.4362,  0.6260]],

         [[-1.1805,  0.8455],
          [-0.1016,  0.7407],
          [ 0.5733, -0.1727]]],


        [[[ 0.8724,  0.8454],
          [-0.4729, -1.3182],
          [-0.2227, -0.6440]],

         [[-0.1481,  1.8186],
          [-0.9481,  0.8103],
          [-0.1496,  0.0278]]]], grad_fn=<PermuteBackward0>)

In [40]:
rearrange(rearrange(Q, '(T B) (h d_k) -> B h T d_k', T=T, B=B, h=h, d_k=d_k), 'B h T d_k -> (B h) T d_k', T=T, B=B, h=h, d_k=d_k)

tensor([[[ 0.0668,  1.3935],
         [ 0.1566, -0.1838],
         [ 0.4362,  0.6260]],

        [[-1.1805,  0.8455],
         [-0.1016,  0.7407],
         [ 0.5733, -0.1727]],

        [[ 0.8724,  0.8454],
         [-0.4729, -1.3182],
         [-0.2227, -0.6440]],

        [[-0.1481,  1.8186],
         [-0.9481,  0.8103],
         [-0.1496,  0.0278]]], grad_fn=<UnsafeViewBackward0>)

In [41]:
rearrange(Q, '(T B) (h d_k) -> (B h) T d_k', T=T, B=B, h=h, d_k=d_k)

tensor([[[ 0.0668,  1.3935],
         [ 0.1566, -0.1838],
         [ 0.4362,  0.6260]],

        [[-1.1805,  0.8455],
         [-0.1016,  0.7407],
         [ 0.5733, -0.1727]],

        [[ 0.8724,  0.8454],
         [-0.4729, -1.3182],
         [-0.2227, -0.6440]],

        [[-0.1481,  1.8186],
         [-0.9481,  0.8103],
         [-0.1496,  0.0278]]], grad_fn=<UnsafeViewBackward0>)