In [1]:
import torch
from torch import nn
import torch.nn.functional as F
import numpy as np
import pandas as pd

In [2]:
torch.manual_seed(42)

<torch._C.Generator at 0x75ba13f13a10>

In [3]:
d = 4  # 模型维度
B = 2
T = 3
h = 2    # 多头注意力中的头数
ff = 8  # 前馈网络的维度

In [4]:
X = torch.randn(T, B, d)  # [seq_len, batch_size, d_model]

In [17]:
encoder = nn.TransformerEncoderLayer(d, h, ff, dropout=0.0)

In [21]:
W_in, b_in = encoder.self_attn.in_proj_weight, encoder.self_attn.in_proj_bias
# (3d, d), (3d, )
W_in.shape, b_in.shape

(torch.Size([12, 4]), torch.Size([12]))

### encoder recap

- input: $\mathbf{X} \in \mathbb{R}^{T \times B \times d_{\text{model}}}$
- 1. multihead selfattn
    - 线性变换（linear projection, 矩阵乘法）生成 Q、K、V矩阵
    - $X_{\text{flat}}=\mathbf X.\text{reshape}(T\times B,d_{model})$
    - $\mathbf{QKV}=\mathbf X\mathbf W_{in}^T+\mathbf b_{in}$（`encoder_layer.self_attn.in_proj_weight`, `encoder_layer.self_attn.in_proj_bias`）
        - $\mathbf{W}_{in} \in \mathbb{R}^{3d_{\text{model}} \times d_{\text{model}}}$，$\mathbf{b}_{in} \in \mathbb{R}^{3d_{\text{model}}}$
        - $\mathbf{QKV}\in \mathbb R^{T\times B,3d_{model}}$
    - 拆分 $\mathbf Q, \mathbf K,\mathbf V$
        - $\mathbf Q, \mathbf K,\mathbf V=\text{split}(\mathbf{QKV},d_{model})$（按列进行拆分）
        - $\mathbf Q, \mathbf K,\mathbf V\in \mathbb R^{T \times B, d_{\text{model}}}$
    - 调整形状以适应多头注意力
        - $d_k = \frac{d_{\text{model}}}h$
        - `reshape_for_heads`
        $$
        \begin{align*}
            \mathbf{Q}_{\text{heads}} &= \mathbf{Q}.\text{reshape}(T, B, h, d_k).\text{permute}(1, 2, 0, 3).\text{reshape}(B \times h, T, d_k) \\
            \mathbf{K}_{\text{heads}} &= \mathbf{K}.\text{reshape}(T, B, h, d_k).\text{permute}(1, 2, 0, 3).\text{reshape}(B \times h, T, d_k) \\
            \mathbf{V}_{\text{heads}} &= \mathbf{V}.\text{reshape}(T, B, h, d_k).\text{permute}(1, 2, 0, 3).\text{reshape}(B \times h, T, d_k)
        \end{align*}
        $$
    - 计算注意力分数：$\text{Scores} = \frac{\mathbf{Q}_{\text{heads}} \mathbf{K}_{\text{heads}}^\top}{\sqrt{d_k}}$
        - $\mathbf{Q}_{\text{heads}} \in \mathbb{R}^{(B \times h) \times T \times d_k}$，$\mathbf{K}_{\text{heads}}^\top \in \mathbb{R}^{(B \times h) \times d_k \times T}$，因此 $\text{Scores} \in \mathbb{R}^{(B \times h) \times T \times T}$。
    - 计算注意力权重：$\text{AttentionWeights}=\text{softmax}(\text{Scores})$
    - 计算注意力输出：$\text{AttentionOutput}=\text{AttentionWeights}\times{\mathbf V_\text{heads}}$
        - $\mathbf{V}_{\text{heads}} \in \mathbb{R}^{(B \times h) \times T \times d_k}$，因此 $\text{AttentionOutput} \in \mathbb{R}^{(B \times h) \times T \times d_k}$。
    - 合并多头输出：$\text{AttentionOutput} = \text{AttentionOutput}.\text{reshape}(B, h, T, d_k).\text{permute}(2, 0, 1, 3).\text{reshape}(T, B, d_{\text{model}})$

### 张量内存 layout

- 张量内存 layout
    - 在大多数深度学习框架（如 PyTorch）中，张量的数据是以**一维数组**的形式在内存中**连续存储**的。对于多维张量，其高维结构是通过一维内存数组和步幅（strides）来实现的。
    - pytorch 存储顺序（Storage Order）是 Row-major
- Strides（步幅）
    - 对于一个形状为 $(D_0,D_1,D_2)$ 的 3D 张量，其步幅计算如下：
        - $\text{stride[2]} = 1$
        - $\text{stride[1]} = D_2\times \text{stride[2]}=D_2$
        - $\text{stride[0]} = D_1\times \text{stride[1]}=D_1\times D_2$

- `reshape` 不改变内存中的数据顺序
- 什么样的操作会导致内存的不连续
    - permute, transpose, view;
        - transpose 是 permute 的特例，transpose 只允许交换两个维度。
- 当张量在内存中的数据排列不再符合其形状和步幅之间的默认关系时，张量就是非连续的。
    - 特征：.is_contiguous() 方法返回 False。
    - 影响：某些操作在非连续张量上可能性能较差，或者需要额外的内存拷贝。
    - 解决方法：使用 .contiguous() 方法，将张量拷贝为内存中连续的版本。

In [5]:
A = torch.randn(3, 4)
A.shape, A.stride(), A.is_contiguous()

(torch.Size([3, 4]), (4, 1), True)

In [6]:
A = A.transpose(0, 1)
A.shape, A.stride(), A.is_contiguous()

(torch.Size([4, 3]), (1, 4), False)

In [7]:
A = A.contiguous()
A.shape, A.stride(), A.is_contiguous()

(torch.Size([4, 3]), (3, 1), True)

### view vs. reshape

- view
    - 不会复制数据：view 创建的是原始张量的一个新的视图，内存数据保持不变，只是重新解释内存中的数据。因为它依赖于张量的内存布局，所以无法对非连续的张量使用。
    - 不连续内存，view 时有可能报错；
- reshape: 它会自动处理非连续张量，尽可能返回视图，如果无法返回视图
    - 不要求内存连续：reshape 可以用于非连续的张量。如果张量不连续，reshape 会自动尝试创建一个新的连续张量并复制数据，以确保能够完成形状转换。
    - 可能复制数据：当张量是内存不连续的，reshape 可能会进行数据复制，生成一个新的内存布局的张量。否则，它和 view 的行为是一样的，不复制数据。
    - 不改变数据在内存中的顺序，只改变张量的形状解释。

In [8]:
A = torch.randn(2, 3, 4)
A.view(-1, 4).is_contiguous()

True

In [9]:
A = torch.randn(2, 3, 4)
A_t = A.permute(1, 2, 0)
A_t.shape, A_t.stride(), A_t.is_contiguous()

(torch.Size([3, 4, 2]), (4, 1, 12), False)

In [10]:
A_t

tensor([[[-0.6788, -0.1360],
         [ 0.5743,  1.6354],
         [ 0.1877,  0.6547],
         [-0.3576,  0.5760]],

        [[-0.3165,  1.1415],
         [ 0.5886,  0.0186],
         [-0.8905, -1.8058],
         [ 0.4098,  0.9254]],

        [[ 1.9312, -0.3753],
         [ 1.0119,  1.0331],
         [-1.4364, -0.6867],
         [-1.1299,  0.6368]]])

In [11]:
A_t.view(-1, 2)

tensor([[-0.6788, -0.1360],
        [ 0.5743,  1.6354],
        [ 0.1877,  0.6547],
        [-0.3576,  0.5760],
        [-0.3165,  1.1415],
        [ 0.5886,  0.0186],
        [-0.8905, -1.8058],
        [ 0.4098,  0.9254],
        [ 1.9312, -0.3753],
        [ 1.0119,  1.0331],
        [-1.4364, -0.6867],
        [-1.1299,  0.6368]])

In [12]:
A_t.view(-1, 2)

tensor([[-0.6788, -0.1360],
        [ 0.5743,  1.6354],
        [ 0.1877,  0.6547],
        [-0.3576,  0.5760],
        [-0.3165,  1.1415],
        [ 0.5886,  0.0186],
        [-0.8905, -1.8058],
        [ 0.4098,  0.9254],
        [ 1.9312, -0.3753],
        [ 1.0119,  1.0331],
        [-1.4364, -0.6867],
        [-1.1299,  0.6368]])

In [13]:
A_t.reshape(-1, 4), A_t.reshape(-1, 4).is_contiguous()

(tensor([[-0.6788, -0.1360,  0.5743,  1.6354],
         [ 0.1877,  0.6547, -0.3576,  0.5760],
         [-0.3165,  1.1415,  0.5886,  0.0186],
         [-0.8905, -1.8058,  0.4098,  0.9254],
         [ 1.9312, -0.3753,  1.0119,  1.0331],
         [-1.4364, -0.6867, -1.1299,  0.6368]]),
 True)

### qkv, mhsa

- $X_{\text{flat}}=\mathbf X.\text{reshape}(T\times B,d_{model})$

In [22]:
X.shape, X

(torch.Size([3, 2, 4]),
 tensor([[[ 1.9269,  1.4873,  0.9007, -2.1055],
          [ 0.6784, -1.2345, -0.0431, -1.6047]],
 
         [[ 0.3559, -0.6866, -0.4934,  0.2415],
          [-1.1109,  0.0915, -2.3169, -0.2168]],
 
         [[-0.3097, -0.3957,  0.8034, -0.6216],
          [-0.5920, -0.0631, -0.8286,  0.3309]]]))

In [26]:
X_flat = X.reshape(-1, d)
# (T*B, d)
X_flat.shape

torch.Size([6, 4])

- $\mathbf{QKV}=\mathbf X_{\text{flat}}\mathbf W_{in}^T+\mathbf b_{in}$

In [24]:
QKV = F.linear(X_flat, W_in, b_in)
QKV.shape

torch.Size([6, 12])

In [27]:
QKV

tensor([[-1.0015, -0.4584,  0.5529,  1.0841,  1.3037,  1.5273,  0.0718,  1.7711,
         -0.2659, -1.4456,  0.2097, -0.3207],
        [-0.5442,  0.5712,  0.7767,  0.3247,  0.0511, -0.0176,  0.1066, -0.1615,
         -0.2305, -0.5033,  1.3315, -0.8612],
        [ 0.4200,  0.4271,  0.3509, -0.6858, -0.2740, -0.2462, -0.0934, -0.2253,
         -0.0029, -0.2396,  0.2950, -0.3374],
        [ 0.2420,  0.5418,  0.4608, -0.4450, -0.3390, -0.8328,  0.4763, -0.2165,
         -0.3916, -0.6494,  0.9870, -0.7253],
        [-0.5956, -0.1391, -0.1943,  0.7945,  0.0555,  0.0245,  0.0508, -0.2930,
          0.0233,  0.5277,  0.1955,  0.1179],
        [ 0.2601,  0.1877,  0.0276, -0.3264, -0.2683, -0.4348,  0.1211, -0.2601,
         -0.0682, -0.0048,  0.1702, -0.1248]], grad_fn=<AddmmBackward0>)

In [28]:
Q, K, V = QKV.split(d, dim=1)
Q.shape, K.shape, V.shape

(torch.Size([6, 4]), torch.Size([6, 4]), torch.Size([6, 4]))

In [29]:
K

tensor([[ 1.3037,  1.5273,  0.0718,  1.7711],
        [ 0.0511, -0.0176,  0.1066, -0.1615],
        [-0.2740, -0.2462, -0.0934, -0.2253],
        [-0.3390, -0.8328,  0.4763, -0.2165],
        [ 0.0555,  0.0245,  0.0508, -0.2930],
        [-0.2683, -0.4348,  0.1211, -0.2601]], grad_fn=<SplitBackward0>)

In [33]:
# 调整Q、K、V的形状以适应多头注意力
d_k = d // h  # 每个头的维度
def reshape_for_heads(x):
    # x.shape: (T*B, d)
    # 最末尾的维度上展开，d => h * d_k
    # (T*B, h, d_k) => (T, B, h, d_k)
    # permute(1, 2, 0, 3) => (B, h, T, d_k)
    print(x.shape, x.is_contiguous())
    y = x.contiguous().view(T, B, h, d_k).permute(1, 2, 0, 3).reshape(B * h, T, d_k)
    print(y.shape)
    return y

In [34]:
Q = reshape_for_heads(Q)
K = reshape_for_heads(K)
V = reshape_for_heads(V)

torch.Size([4, 3, 2]) False
torch.Size([4, 3, 2])
torch.Size([4, 3, 2]) False
torch.Size([4, 3, 2])
torch.Size([4, 3, 2]) False
torch.Size([4, 3, 2])


### einsum