# Flash Attention Usage Example

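This notebook demonstrates how to use the official `flash-attn` implementation in a project.

First, install `flash-attn` by following the instructions in the official documentation.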

In [None]:
%pip install torch packaging ninja wheel
%pip install flash-attn --no-build-isolation
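
Before running the remaining cells, it can help to confirm that the installation succeeded and that a CUDA device is visible, since the cells below place their tensors on `cuda`. The following is a minimal sketch; the helper name `check_flash_attn_env` is hypothetical and not part of either library.

```python
import importlib.util


def check_flash_attn_env():
    """Report whether torch and flash_attn are importable and CUDA is visible.

    Hypothetical helper for this notebook, not part of torch or flash-attn.
    """
    report = {
        "torch_installed": importlib.util.find_spec("torch") is not None,
        "flash_attn_installed": importlib.util.find_spec("flash_attn") is not None,
        "cuda_available": False,
    }
    if report["torch_installed"]:
        import torch

        # torch.cuda.is_available() returns False on CPU-only machines
        report["cuda_available"] = torch.cuda.is_available()
    return report


print(check_flash_attn_env())
```

If any entry is `False`, the later cells are likely to fail at import time or when allocating tensors on the GPU.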

## PyTorch Implementation of Self-Attention

We implement a self-attention function in plain PyTorch as a reference:
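
For reference, the function computes standard scaled dot-product attention. Writing $d$ for the size of the last dimension (named `hidden_size` in the code), the unmasked case is:

$$
\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{Q K^{\top}}{\sqrt{d}}\right) V
$$

When a mask is supplied, the scores at positions where the mask equals 0 are replaced by a large negative value before the softmax, so those positions receive a weight close to zero.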

In [4]:
import torch

def torch_attention(q, k, v, mask=None):
    '''
    PyTorch implementation of the scaled dot-product attention mechanism.
    Parameters:
        q: query tensor, [batch_size, n_heads, seq_len, hidden_size]
        k: key tensor, [batch_size, n_heads, seq_len, hidden_size]
        v: value tensor, [batch_size, n_heads, seq_len, hidden_size]
        mask: mask tensor, [batch_size, n_heads, seq_len, seq_len]
    Returns:
        attention output: output tensor, [batch_size, n_heads, seq_len, hidden_size]
    '''
    hidden_size = q.size(-1)
    scores = torch.matmul(q, k.transpose(-2, -1)) / torch.sqrt(torch.tensor(hidden_size, dtype=torch.float16)) # [batch_size, n_heads, seq_len, seq_len]
    if mask is not None:
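        # -1e9 overflows float16, so fill with the dtype's most negative finite value
        scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)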
    weights = torch.nn.functional.softmax(scores, dim=-1)
    return torch.matmul(weights, v) # [batch_size, n_heads, seq_len, hidden_size]

Next, we create some random tensors to use as inputs:

In [5]:
import torch

batch_size = 4
n_heads = 2
seq_len = [2, 3, 8, 4]
hidden_size = 32

q = torch.randn(batch_size, n_heads, max(seq_len), hidden_size, dtype=torch.float16, device='cuda')
k = torch.randn(batch_size, n_heads, max(seq_len), hidden_size, dtype=torch.float16, device='cuda')
v = torch.randn(batch_size, n_heads, max(seq_len), hidden_size, dtype=torch.float16, device='cuda')
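
Note that `seq_len` is only used through `max(seq_len)`, so every sample is treated as a full-length sequence of 8 tokens. If the list is meant to give the number of valid tokens per sample (an assumption; the notebook does not say), a padding mask could be built as sketched below. The names `lengths`, `key_mask` and `mask` are illustrative, and the example runs on the CPU.

```python
import torch

seq_len = [2, 3, 8, 4]          # assumed: valid token count per sample
max_len = max(seq_len)

lengths = torch.tensor(seq_len)             # [batch_size]
positions = torch.arange(max_len)           # [max_len]

# key_mask[b, j] is True when key position j holds a real token in sample b
key_mask = positions.unsqueeze(0) < lengths.unsqueeze(1)   # [batch_size, max_len]

# Add singleton head and query dimensions: [batch_size, 1, 1, max_len]
mask = key_mask[:, None, None, :]

print(mask.shape)
print(mask[0, 0, 0])
```

A mask of this shape broadcasts against the `[batch_size, n_heads, seq_len, seq_len]` score tensor inside `masked_fill`, so padded key positions are excluded for every head and every query position.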

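We compute the attention output with both the PyTorch implementation and the attention function provided by the `flash_attn` library, then compare the two results. `flash_attn_func` expects inputs laid out as `[batch_size, seq_len, n_heads, head_dim]`, which is why the tensors are transposed before the call and the output is transposed back.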

In [6]:
import torch
from flash_attn import flash_attn_func

torch_attention_output = torch_attention(q, k, v)
flash_attn_output = flash_attn_func(
    q.transpose(1, 2), 
    k.transpose(1, 2), 
    v.transpose(1, 2)).transpose(1, 2)

def compare_tensors(t1, t2):
    return torch.allclose(t1, t2, atol=1e-3)

print(compare_tensors(torch_attention_output, flash_attn_output))
print(torch_attention_output[0][0][0])
print(flash_attn_output[0][0][0])

True
tensor([-0.9575,  2.1211, -0.0684, -0.7085, -0.3413, -0.1348, -0.7061, -0.9697,
         0.2690,  0.6123, -0.9048, -0.1287, -0.2888,  0.5713, -0.3391,  1.0098,
         0.8198, -0.0552, -0.4216,  0.5845, -0.6831,  0.3074,  0.5024, -0.5537,
         0.8457, -0.6489, -0.6060, -1.1484, -0.3499, -0.2136,  0.6978,  0.2974],
       device='cuda:0', dtype=torch.float16)
tensor([-0.9580,  2.1211, -0.0685, -0.7090, -0.3416, -0.1349, -0.7061, -0.9707,
         0.2690,  0.6123, -0.9058, -0.1292, -0.2888,  0.5718, -0.3398,  1.0107,
         0.8198, -0.0551, -0.4216,  0.5850, -0.6831,  0.3079,  0.5029, -0.5537,
         0.8457, -0.6494, -0.6060, -1.1494, -0.3499, -0.2140,  0.6982,  0.2974],
       device='cuda:0', dtype=torch.float16)


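The results above show that, in half precision (float16), the two implementations produce matching attention outputs to within an absolute tolerance of 1e-3.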
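
To put the small differences in the printed values into perspective, the sketch below measures a few of the printed pairs in units of float16 spacing (ULPs). It uses only the standard library; the helper names are hypothetical, and it assumes that rounding the four-decimal printed values back to float16 recovers the original half-precision numbers.

```python
import math
import struct


def to_fp16(x):
    """Round a Python float to the nearest float16 value."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def fp16_ulp(x):
    """Spacing between adjacent float16 values near x (normal range only)."""
    _, exponent = math.frexp(x)
    # float16 has 11 significant bits (10 stored plus the implicit one)
    return 2.0 ** (exponent - 11)


def ulps_apart(a, b):
    """Distance between a and b in float16 ULPs, after rounding both to float16."""
    a16, b16 = to_fp16(a), to_fp16(b)
    return abs(a16 - b16) / fp16_ulp(max(abs(a16), abs(b16)))


# Pairs copied from the two printed tensors above (PyTorch value, flash-attn value)
pairs = [(-0.9575, -0.9580), (-0.7085, -0.7090), (-1.1484, -1.1494)]
for a, b in pairs:
    print(a, b, ulps_apart(a, b))
```

Each of these pairs is one float16 step apart, which is the size of difference one would expect from two implementations that accumulate and round in a different order.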