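- The Transformer is at the core of modern AI, and attention is its most distinctive mechanism, so any amount of time spent exploring Transformers and attention is well worth it.
- Optimizations that speed up attention's compute or reduce its memory traffic all work on how attention is computed:
    - paged attention in vLLM;
    - flash attention, the topic of this notebook.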

In [7]:
from IPython.display import Image


https://github.com/Dao-AILab/flash-attention

- install
    
    ```
    # pip
    pip install flash-attn --no-build-isolation
    pip install flash_attn -U --force-reinstall
    
    # source code compile
    python setup.py install
    ```

## basics

> Algorithm (software) and hardware co-optimization.

- paper
    - [FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness](https://arxiv.org/pdf/2205.14135)
    - [FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning](https://tridao.me/publications/flash2/flash2.pdf)
- Memory
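    - speed (bandwidth): SRAM > HBM > DRAM. On-chip SRAM is the fastest but smallest, GPU HBM sits in the middle, and CPU DRAM is the slowest but largest.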
- attention QKV computation
    - split the matrices into blocks (tiling), then loop over the blocks (an outer loop and an inner loop, which correspond to the optimized GPU CUDA kernel);

### review of Attention

In [14]:
Image(url='./imgs/attention_steps.png', width=600)

$$
\mathbf{O} = \text{Dropout}(\text{Softmax}(\text{Mask}(\mathbf{QK}^T)))\mathbf{V}
$$

- formula
    - $(\mathbf {Q, K, V})\in\mathbb R^{N\times d}$
    - $\mathbf{QK}^T\in\mathbb R^{N\times N}$
    - $\mathbf A=\text{Dropout}(\text{Softmax}(\text{Mask}(\mathbf{QK}^T)))$
        - mask: m, softmax: sm, dropout: do
    - $\mathbf O=\mathbf {AV}\in \mathbb R^{N\times d}$
- notes
    - attention is bottlenecked by memory reads/writes
    - naive implementation requires repeated R/W from slow GPU HBM
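
To make the memory bottleneck concrete, here is a minimal NumPy sketch of the naive computation. It is an illustration under assumptions, not the notebook's own code: the function name `naive_attention` and the sizes `N = 1024`, `d = 64` are made up for the example, dropout is omitted, and the $1/\sqrt{d}$ scaling follows the demo later in this notebook rather than the formula above.

```python
import numpy as np

def naive_attention(Q, K, V, causal=False):
    """Standard attention that materializes the full N x N matrices."""
    N, d = Q.shape
    S = Q @ K.T / np.sqrt(d)               # scores, shape (N, N)
    if causal:
        # Mask: position i may only attend to positions j <= i.
        S = np.where(np.tril(np.ones((N, N), dtype=bool)), S, -np.inf)
    S = S - S.max(axis=1, keepdims=True)   # subtract the row max for stability
    P = np.exp(S)
    A = P / P.sum(axis=1, keepdims=True)   # attention weights, shape (N, N)
    return A @ V, A                        # output has shape (N, d)

rng = np.random.default_rng(0)
N, d = 1024, 64                            # hypothetical sizes
Q, K, V = (rng.standard_normal((N, d)) for _ in range(3))
O, A = naive_attention(Q, K, V, causal=True)

print(O.shape, A.shape)       # (1024, 64) (1024, 1024)
print(A.nbytes / Q.nbytes)    # N / d = 16.0
```

The score and weight matrices each hold $N\times N$ entries, which is $N/d$ times more than $\mathbf Q$ itself. On a GPU, a naive implementation writes each of them to HBM and reads it back.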

### fused kernel ?

- Fuse several operations into a single operation (kernel) to reduce the cost of memory reads and writes.
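
The idea can be sketched in plain NumPy. This is only an analogy for the data-movement pattern, not a GPU kernel: the function names and the `shift`/`scale` parameters are hypothetical, and in NumPy the explicit loop is actually slower than the vectorized version.

```python
import numpy as np

def shifted_exp_unfused(x, shift, scale):
    """Three separate operations; each writes a full-size temporary array."""
    t1 = x - shift        # op 1: read x, write t1
    t2 = np.exp(t1)       # op 2: read t1, write t2
    return t2 * scale     # op 3: read t2, write the result

def shifted_exp_fused(x, shift, scale):
    """One pass: every element is read once and written once."""
    out = np.empty_like(x)
    for i in range(x.size):
        # The intermediate values never leave local variables
        # (registers, on a real GPU).
        out[i] = np.exp(x[i] - shift) * scale
    return out

x = np.linspace(-1.0, 1.0, 8)
same = np.allclose(shifted_exp_unfused(x, 0.5, 2.0),
                   shifted_exp_fused(x, 0.5, 2.0))
print(same)   # True
```

Both versions compute the same result. The fused one avoids the two full-size temporaries, which on a GPU would each be a round trip to HBM.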

### flash attention

In [10]:
Image(url='./imgs/flash-attn.png', width=700)
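
What makes block-by-block processing possible is that softmax can be computed incrementally ("online softmax"): keep a running maximum and a running denominator, and rescale the denominator whenever the maximum grows. Below is a minimal sketch for a single row of scores; the function name `online_softmax` and the test vector are assumptions made for illustration.

```python
import numpy as np

def online_softmax(x, block_size):
    """Softmax of a 1-D vector, visiting the input one block at a time."""
    m = -np.inf   # running maximum of the elements seen so far
    l = 0.0       # running sum of exp(x_j - m) over the elements seen so far
    for start in range(0, len(x), block_size):
        blk = x[start:start + block_size]
        m_new = max(m, blk.max())
        # Rescale the old sum to the new maximum, then add this block.
        l = l * np.exp(m - m_new) + np.exp(blk - m_new).sum()
        m = m_new
    # Final normalization (a second pass over x in this simplified version).
    return np.exp(x - m) / l

x = np.array([1.0, 3.0, 2.0, 5.0, 4.0, 0.5])
reference = np.exp(x - x.max()) / np.exp(x - x.max()).sum()
print(np.allclose(online_softmax(x, block_size=2), reference))   # True
```

The normalizer is obtained without ever holding all exponentials at once. Flash attention applies the same rescaling to the partially accumulated output, so the final normalization does not need a second pass over the scores.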

### SDPA (PyTorch)
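
A minimal usage sketch of `torch.nn.functional.scaled_dot_product_attention`, assuming PyTorch 2.0 or later. The tensor shapes and the step-by-step reference are chosen only for illustration, and which backend PyTorch dispatches to (a FlashAttention kernel, a memory-efficient kernel, or the plain math path) depends on the device, dtype and arguments.

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
# Shape convention: (batch, heads, sequence length, head dimension).
q, k, v = (torch.randn(2, 4, 16, 8) for _ in range(3))

# Fused attention with a causal mask.
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# Reference: the same computation written out step by step.
scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
causal = torch.tril(torch.ones(16, 16, dtype=torch.bool))
scores = scores.masked_fill(~causal, float("-inf"))
ref = torch.softmax(scores, dim=-1) @ v

print(out.shape)                             # torch.Size([2, 4, 16, 8])
print(torch.allclose(out, ref, atol=1e-4))
```

The fused call and the explicit version should agree up to floating-point tolerance; only the explicit version builds the full score matrix as a separate tensor.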

## simple demo

In [5]:
import numpy as np

# Input matrix
X = np.array([[1, 2, 3, 4],
              [5, 6, 7, 8],
              [9, 10, 11, 12],
              [13, 14, 15, 16]])

# Weight matrices (identity, so Q = K = V = X)
W_Q = W_K = W_V = np.eye(4)

# Classic self-attention
Q = np.dot(X, W_Q)
K = np.dot(X, W_K)
V = np.dot(X, W_V)
attention_scores = np.dot(Q, K.T) / np.sqrt(4)
attention_weights = np.exp(attention_scores) / np.sum(np.exp(attention_scores), axis=1, keepdims=True)
output_classic = np.dot(attention_weights, V)

output_classic

array([[12.99999999, 13.99999999, 14.99999999, 15.99999999],
       [13.        , 14.        , 15.        , 16.        ],
       [13.        , 14.        , 15.        , 16.        ],
       [13.        , 14.        , 15.        , 16.        ]])

In [6]:
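# Simplified "flash attention": only the query rows are tiled here.
# Each query block still attends over the full K and V; the real
# algorithm also tiles K and V and merges the tiles with an online softmax.
b = 2  # number of query blocks
m = 2  # rows per block
output_flash = np.zeros((4, 4))
for i in range(b):
    X_block = X[i * m: (i + 1) * m]
    Q_block = np.dot(X_block, W_Q)
    
    # Attention scores between this query block and all keys
    attention_scores_block = np.dot(Q_block, K.T) / np.sqrt(4)
    attention_weights_block = np.exp(attention_scores_block) / np.sum(np.exp(attention_scores_block), axis=1, keepdims=True)
    
    # Write this block's rows of the output
    output_flash[i * m: (i + 1) * m] = np.dot(attention_weights_block, V)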

output_flash

array([[12.99999999, 13.99999999, 14.99999999, 15.99999999],
       [13.        , 14.        , 15.        , 16.        ],
       [13.        , 14.        , 15.        , 16.        ],
       [13.        , 14.        , 15.        , 16.        ]])
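
The demo above splits only the queries into blocks and still uses the full `K` and `V` for every block. The sketch below also tiles `K` and `V` and merges the tiles with an online softmax, which is closer to the actual algorithm. It is an illustrative NumPy version under assumptions, not the CUDA kernel: the function name and block sizes are hypothetical, masking and dropout are omitted, and the loop order (queries outside, keys/values inside) is the one I understand FlashAttention-2 to use.

```python
import numpy as np

def flash_attention_forward(Q, K, V, block_q=2, block_k=2):
    """Tiled attention forward pass with an online softmax (NumPy sketch)."""
    N, d = Q.shape
    O = np.zeros((N, V.shape[1]))
    for qs in range(0, N, block_q):                # outer loop: query blocks
        Qi = Q[qs:qs + block_q]
        m = np.full(Qi.shape[0], -np.inf)          # running row maxima
        l = np.zeros(Qi.shape[0])                  # running softmax denominators
        acc = np.zeros((Qi.shape[0], V.shape[1]))  # unnormalized output
        for ks in range(0, K.shape[0], block_k):   # inner loop: key/value blocks
            Kj, Vj = K[ks:ks + block_k], V[ks:ks + block_k]
            S = Qi @ Kj.T / np.sqrt(d)             # only a block_q x block_k tile
            m_new = np.maximum(m, S.max(axis=1))
            scale = np.exp(m - m_new)              # rescales earlier partial results
            P = np.exp(S - m_new[:, None])
            l = l * scale + P.sum(axis=1)
            acc = acc * scale[:, None] + P @ Vj
            m = m_new
        O[qs:qs + block_q] = acc / l[:, None]
    return O

X = np.arange(1, 17, dtype=float).reshape(4, 4)    # same input as the demo above
S = X @ X.T / np.sqrt(4)
W = np.exp(S - S.max(axis=1, keepdims=True))
reference = (W / W.sum(axis=1, keepdims=True)) @ X

print(np.allclose(flash_attention_forward(X, X, X), reference))   # True
```

At no point does the function hold more than one `block_q x block_k` tile of scores, yet the result matches the classic computation. That is why flash attention is described as exact rather than approximate.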