
https://github.com/Dao-AILab/flash-attention

- install
    - pip
    ```
    pip install flash-attn --no-build-isolation
    pip install flash_attn -U --force-reinstall
    ```
    - build from source (run inside a clone of the repository)

    ```
    python setup.py install
    ```

In [2]:
import numpy as np

# Demo: compare classic self-attention with a tiled, online-softmax
# ("flash"-style) evaluation of the same attention on a tiny input.

# Input matrix: one row per token, one column per feature.
X = np.array([[1, 2, 3, 4],
              [5, 6, 7, 8],
              [9, 10, 11, 12],
              [13, 14, 15, 16]])

# Projection weights. All three are the identity, so Q = K = V = X here.
W_Q = W_K = W_V = np.eye(4)

# ---- Classic self-attention -------------------------------------------
Q = np.dot(X, W_Q)
K = np.dot(X, W_K)
V = np.dot(X, W_V)
# Scale by sqrt(d_k), derived from the data instead of a hard-coded 4.
scale = np.sqrt(Q.shape[-1])
attention_scores = np.dot(Q, K.T) / scale
# Subtract the per-row maximum before exponentiating: the softmax value is
# unchanged, but np.exp can no longer overflow for large scores.
exp_scores = np.exp(attention_scores - attention_scores.max(axis=1, keepdims=True))
attention_weights = exp_scores / exp_scores.sum(axis=1, keepdims=True)
output_classic = np.dot(attention_weights, V)
print("经典自注意力机制输出：")
print(output_classic)

# ---- Flash Attention (tiled, online softmax) --------------------------
b = 2  # number of blocks along the sequence axis
m = 2  # rows per block (b * m is expected to cover every row of X)
output_flash = np.zeros((X.shape[0], V.shape[1]))
for i in range(b):
    # Query tile: rows i*m .. (i+1)*m - 1.
    X_block = X[i * m: (i + 1) * m]
    Q_block = np.dot(X_block, W_Q)

    # Running statistics for this query tile, updated one key/value tile
    # at a time so the full score matrix is never materialised.
    row_max = np.full((Q_block.shape[0], 1), -np.inf)    # running max score
    row_sum = np.zeros((Q_block.shape[0], 1))            # running softmax denominator
    acc = np.zeros((Q_block.shape[0], V.shape[1]))       # running unnormalised output

    for j in range(b):
        # Key/value tile: rows j*m .. (j+1)*m - 1.
        KV_rows = X[j * m: (j + 1) * m]
        K_block = np.dot(KV_rows, W_K)
        V_block = np.dot(KV_rows, W_V)

        # Scores between this query tile and this key tile only.
        attention_scores_block = np.dot(Q_block, K_block.T) / scale

        # Online-softmax update: move to the new running max and rescale
        # everything accumulated so far by exp(old_max - new_max).
        new_max = np.maximum(row_max, attention_scores_block.max(axis=1, keepdims=True))
        correction = np.exp(row_max - new_max)  # exp(-inf) == 0 on the first tile
        exp_block = np.exp(attention_scores_block - new_max)

        row_sum = row_sum * correction + exp_block.sum(axis=1, keepdims=True)
        acc = acc * correction + np.dot(exp_block, V_block)
        row_max = new_max

    # Normalise once, after every key/value tile has been folded in.
    output_flash[i * m: (i + 1) * m] = acc / row_sum

print("Flash Attention 输出：")
print(output_flash)

# Compare the two results
print("两者结果是否相同：", np.allclose(output_classic, output_flash))


经典自注意力机制输出：
[[12.99999999 13.99999999 14.99999999 15.99999999]
 [13.         14.         15.         16.        ]
 [13.         14.         15.         16.        ]
 [13.         14.         15.         16.        ]]
Flash Attention 输出：
[[12.99999999 13.99999999 14.99999999 15.99999999]
 [13.         14.         15.         16.        ]
 [13.         14.         15.         16.        ]
 [13.         14.         15.         16.        ]]
两者结果是否相同： True
