From c557885f7d9a7f624756ea9f5c73bf2ec5c0db43 Mon Sep 17 00:00:00 2001
From: Maxime France-Pillois
Date: Wed, 17 Sep 2025 11:49:23 +0100
Subject: [PATCH] Fix Line Number issue.

---
 ...flashattention-performance-on-intel-gpu.md | 84 +++++++++----------
 1 file changed, 42 insertions(+), 42 deletions(-)

diff --git a/_collections/_portal_posts/2025-09-02-improving-triton-flashattention-performance-on-intel-gpu.md b/_collections/_portal_posts/2025-09-02-improving-triton-flashattention-performance-on-intel-gpu.md
index 882f1f6..33df973 100644
--- a/_collections/_portal_posts/2025-09-02-improving-triton-flashattention-performance-on-intel-gpu.md
+++ b/_collections/_portal_posts/2025-09-02-improving-triton-flashattention-performance-on-intel-gpu.md
@@ -207,45 +207,45 @@ This is exactly what happens in the widespread case of [FlashAttention version 2
 The FlashAttention v2 Forward pass algorithm in pseudo-code is:
 
 ```python
-# Inputs : Q, K and V are 2D Matrices in Global Memory
-def FlashAttention2_forward(Q, K, V):
-    O = torch.zeros_like(Q, requires_grad=True)
-    L = torch.zeros(Q.shape[:-1])[...,None]
-
-    Q_BLOCKS = torch.split(Q, BLOCK_SHAPE)
-    K_BLOCKS = torch.split(K, BLOCK_SHAPE)
-    V_BLOCKS = torch.split(V, BLOCK_SHAPE)
-
-    Tr = len(Q_BLOCKS)
-    Tc = len(K_BLOCKS)
-
-    for i in range(Tr):
-        Qi = load(Q_BLOCKS[i])         # Load data from Global Memory to SRAM
-        Oi = torch.zeros(BLOCK_SHAPE)  # No load required, Initialized on chip
-        li = torch.zeros(BLOCK_SHAPE)  # No load required, Initialized on chip
-        mi = NEG_INF                   # No load required, Initialized on chip
-
-        for j in range(Tc):
-            Kj = load(K_BLOCKS[j])     # Load data from Global Memory to SRAM
-            Vj = load(V_BLOCKS[j])     # Load data from Global Memory to SRAM
-
-            KTj = Kj.transpose()
-            S_ij = matmul(Qi, KTj)
-
-            P_ij, m_block_ij, mi_new, li_new = online_softmax(S_ij, mi, li)
-
-            P_ij_Vj = matmul(P_ij, Vj)
-            Oij = (li/li_new) * torch.exp(mi - mi_new) * Oi + (torch.exp(m_block_ij - mi_new) / li_new) * P_ij_Vj
-
-            # update li and mi
-            li = li_new
-            mi = mi_new
-
-        Oi = Oij / diag(li)
-        O.store(Oi, i)                 # Store data to Global Memory as the i-th block of O
-        L.store(li, i)                 # Store data to Global Memory as the i-th block of L
-
-    return O, L
+1  # Inputs : Q, K and V are 2D Matrices in Global Memory
+2  def FlashAttention2_forward(Q, K, V):
+3      O = torch.zeros_like(Q, requires_grad=True)
+4      L = torch.zeros(Q.shape[:-1])[...,None]
+5
+6      Q_BLOCKS = torch.split(Q, BLOCK_SHAPE)
+7      K_BLOCKS = torch.split(K, BLOCK_SHAPE)
+8      V_BLOCKS = torch.split(V, BLOCK_SHAPE)
+9
+10     Tr = len(Q_BLOCKS)
+11     Tc = len(K_BLOCKS)
+12
+13     for i in range(Tr):
+14         Qi = load(Q_BLOCKS[i])         # Load data from Global Memory to SRAM
+15         Oi = torch.zeros(BLOCK_SHAPE)  # No load required, Initialized on chip
+16         li = torch.zeros(BLOCK_SHAPE)  # No load required, Initialized on chip
+17         mi = NEG_INF                   # No load required, Initialized on chip
+18
+19         for j in range(Tc):
+20             Kj = load(K_BLOCKS[j])     # Load data from Global Memory to SRAM
+21             Vj = load(V_BLOCKS[j])     # Load data from Global Memory to SRAM
+22
+23             KTj = Kj.transpose()
+24             S_ij = matmul(Qi, KTj)
+25
+26             P_ij, m_block_ij, mi_new, li_new = online_softmax(S_ij, mi, li)
+27
+28             P_ij_Vj = matmul(P_ij, Vj)
+29             Oij = (li/li_new) * torch.exp(mi - mi_new) * Oi + (torch.exp(m_block_ij - mi_new) / li_new) * P_ij_Vj
+30
+31             # update li and mi
+32             li = li_new
+33             mi = mi_new
+34
+35         Oi = Oij / diag(li)
+36         O.store(Oi, i)                 # Store data to Global Memory as the i-th block of O
+37         L.store(li, i)                 # Store data to Global Memory as the i-th block of L
+38
+39     return O, L
 ```
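The pseudo-code above calls an `online_softmax` helper without defining it. As a point of reference only, here is a minimal PyTorch sketch of what that helper is assumed to compute — the unnormalised softmax of the current score block together with the updated running row-max/row-sum statistics — chosen so that it is consistent with how its results are consumed on lines 26 and 29 of the numbered listing:

```python
import torch

# Assumed semantics of the `online_softmax` helper used in the pseudo-code above:
# it returns the unnormalised block softmax plus the merged running statistics.
def online_softmax(S_ij, mi, li):
    # Row-wise maximum of the current score block.
    m_block_ij = S_ij.max(dim=-1, keepdim=True).values
    # Unnormalised softmax of the block, shifted by its own row maximum.
    P_ij = torch.exp(S_ij - m_block_ij)
    l_block_ij = P_ij.sum(dim=-1, keepdim=True)
    # Merge the block statistics with the running maximum (mi) and running sum (li)
    # carried across the inner loop over K/V blocks.
    mi_new = torch.maximum(mi, m_block_ij)
    li_new = torch.exp(mi - mi_new) * li + torch.exp(m_block_ij - mi_new) * l_block_ij
    return P_ij, m_block_ij, mi_new, li_new


# Example usage with a single 4x4 score block and "empty" running statistics.
S = torch.randn(4, 4)
mi0 = torch.full((4, 1), float("-inf"))
li0 = torch.zeros(4, 1)
P, m_blk, mi1, li1 = online_softmax(S, mi0, li0)
assert torch.allclose(P / li1, torch.softmax(S, dim=-1))
```

Returning `m_block_ij` separately matters because the rescaling on line 29 needs the block's own maximum as well as the merged maximum `mi_new`.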
 
 In the second version of the implementation of the FlashAttention model, the loop order has been reversed to promote
@@ -253,9 +253,9 @@
 data locality. As long as there is enough local memory (or registers) to contain all the needed data, this algorithm
 works fine and provides significant performance improvements compared to FlashAttention v1 (in the paper, the authors
 mention 2x faster for the Cutlass implementation and 1.3-1.5× faster in Triton on an Nvidia Ampere GPU A100).
-Deployed on a GPU target, line 4-10 constitutes the computing kernel that is dispatched to a Thread Block/Work-Group (
+Deployed on a GPU target, lines 13-37 constitute the computing kernel that is dispatched to a Thread Block/Work-Group (
 i.e. a SM/XeCore).
-But as you can see, variable Q is loaded before the loop (line 4) and remains *live* across the loop.
+As you can see, variable Q is loaded before the inner loop (line 14) and remains *live* across that loop.
 The long lifespan of variable Q is even more problematic in the causal variation of the FlashAttention implementation.
 
 The causal variation is defined in the paper as :
@@ -264,7 +264,7 @@ The causal variation is defined in the paper as :
 
 The Triton implementation of FlashAttention v2 with causal mask is as follow:
 
-```python {.line-numbers}
+```python
 @triton.jit
 def _attn_fwd(Q_block_ptr, K_block_ptr, V_block_ptr, sm_scale, M, N_CTX: tl.constexpr,  #
               BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,  #