This is exactly what happens in the widespread case of FlashAttention version 2.
The FlashAttention v2 Forward pass algorithm in pseudo-code is:

```python
 1  # Inputs: Q, K and V are 2D matrices in Global Memory
 2  def FlashAttention2_forward(Q, K, V):
 3      O = torch.zeros_like(Q, requires_grad=True)
 4      L = torch.zeros(Q.shape[:-1])[..., None]
 5
 6      Q_BLOCKS = torch.split(Q, BLOCK_SHAPE)
 7      K_BLOCKS = torch.split(K, BLOCK_SHAPE)
 8      V_BLOCKS = torch.split(V, BLOCK_SHAPE)
 9
10      Tr = len(Q_BLOCKS)
11      Tc = len(K_BLOCKS)
12
13      for i in range(Tr):
14          Qi = load(Q_BLOCKS[i])         # Load data from Global Memory to SRAM
15          Oi = torch.zeros(BLOCK_SHAPE)  # No load required, initialized on chip
16          li = torch.zeros(BLOCK_SHAPE)  # No load required, initialized on chip
17          mi = NEG_INF                   # No load required, initialized on chip
18
19          for j in range(Tc):
20              Kj = load(K_BLOCKS[j])     # Load data from Global Memory to SRAM
21              Vj = load(V_BLOCKS[j])     # Load data from Global Memory to SRAM
22
23              KTj = Kj.transpose()
24              S_ij = matmul(Qi, KTj)
25
26              P_ij, m_block_ij, mi_new, li_new = online_softmax(S_ij, mi, li)
27
28              P_ij_Vj = matmul(P_ij, Vj)
29              Oij = (li / li_new) * torch.exp(mi - mi_new) * Oi + (torch.exp(m_block_ij - mi_new) / li_new) * P_ij_Vj
30
31              # update li and mi
32              li = li_new
33              mi = mi_new
34
35          Oi = Oij / diag(li)
36          O.store(Oi, i)                 # Store data to Global Memory as the i-th block of O
37          L.store(li, i)                 # Store data to Global Memory as the i-th block of L
38
39      return O, L
```

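The `online_softmax` helper called on line 26 is not shown in this excerpt. A minimal PyTorch sketch of what it is assumed to compute (the block-wise probabilities together with the updated running row-wise maximum `mi` and normalizer `li`) could look like this:

```python
import torch

# Assumed shapes: S_ij is (Br, Bc); mi and li are running (Br, 1) tensors,
# initialized to -inf and 0 respectively before the inner loop starts.
def online_softmax(S_ij, mi, li):
    m_block_ij = S_ij.max(dim=-1, keepdim=True).values  # row-wise max of this block
    mi_new = torch.maximum(mi, m_block_ij)               # updated running max
    P_ij = torch.exp(S_ij - m_block_ij)                  # unnormalized probabilities, shifted by the block max
    # rescale the previous normalizer and add this block's contribution
    li_new = torch.exp(mi - mi_new) * li + torch.exp(m_block_ij - mi_new) * P_ij.sum(dim=-1, keepdim=True)
    return P_ij, m_block_ij, mi_new, li_new
```

With this definition, `torch.exp(m_block_ij - mi_new) * P_ij` equals `exp(S_ij - mi_new)`, which is exactly what the rescaling on line 29 relies on.
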
In this second version of FlashAttention, the loop order is reversed with respect to v1: the outer loop now iterates over the blocks of Q and the inner loop over the blocks of K and V, which promotes data locality.
As long as there is enough local memory (or enough registers) to hold all the data a block needs, this algorithm works well and provides significant performance improvements over FlashAttention v1: in the paper, the authors report about 2× faster for the CUTLASS implementation and 1.3-1.5× faster for the Triton implementation on an NVIDIA Ampere A100 GPU.
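
Schematically, the reversal looks as follows (a structural sketch only, with dummy block counts; the v1 loop body is paraphrased from the paper's description rather than taken from an implementation):

```python
Tr, Tc = 4, 4  # dummy block counts, just to make the sketch executable

# FlashAttention v1: K/V blocks in the outer loop. Every inner iteration must
# re-load Qi, Oi, li and mi from Global Memory and write Oi, li and mi back,
# so each output block travels through HBM Tc times.
for j in range(Tc):
    for i in range(Tr):
        pass  # load Qi, Oi, li, mi; update; store Oi, li, mi

# FlashAttention v2: Q blocks in the outer loop (as in the pseudo-code above).
# Oi, li and mi stay on chip for the whole inner loop and are written back once.
for i in range(Tr):
    for j in range(Tc):
        pass  # load Kj, Vj; update the on-chip Oi, li, mi
```
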
Deployed on a GPU target, lines 13-37 constitute the computing kernel that is dispatched to a Thread Block/Work-Group (i.e. an SM/XeCore).
As you can see, the block Qi is loaded before the inner loop (line 14) and remains *live* across the entire loop.
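
To illustrate that mapping (a hypothetical skeleton with illustrative names, not the kernel from this repository): the outer loop over `i` disappears from the device code and becomes the launch grid, so each program instance owns a single Qi block.

```python
import triton
import triton.language as tl


# Skeleton of the per-Thread-Block view; the body is elided and the kernel is
# never launched here.
@triton.jit
def fa2_block_sketch(Tc):
    i = tl.program_id(0)  # which Qi block this Thread Block / Work-Group owns
    # Qi would be loaded once here and stay live in registers/SRAM for the whole loop
    for j in range(Tc):
        pass  # load Kj, Vj and update the on-chip Oi, li, mi
    # Oi and li would be stored back to Global Memory here (cf. lines 36-37)
```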

The long lifespan of the Qi block is even more problematic in the causal variation of the FlashAttention implementation.
In the causal variation, each query position may only attend to key positions at or before its own position: the entries of the score matrix S whose column index exceeds their row index are masked to −∞ before the softmax.

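Concretely, the mask can be applied to one score block as in the following sketch (the absolute offsets `q_start` and `k_start` are illustrative assumptions, not names from the repository):

```python
import torch

def apply_causal_mask(S_ij, q_start, k_start):
    Br, Bc = S_ij.shape
    row = q_start + torch.arange(Br)[:, None]  # absolute query positions of this block
    col = k_start + torch.arange(Bc)[None, :]  # absolute key positions of this block
    # a query may only attend to keys at or before its own position
    return S_ij.masked_fill(col > row, float('-inf'))
```

Blocks that lie entirely above the diagonal contribute nothing to the output, so a blocked implementation can skip them altogether.
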
The Triton implementation of FlashAttention v2 with a causal mask is as follows:

```python
@triton.jit
def _attn_fwd(Q_block_ptr, K_block_ptr, V_block_ptr, sm_scale, M, N_CTX: tl.constexpr, #
              BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,  #