From bf26a6801c8cdfb68459f3839aa2c70acb0066e3 Mon Sep 17 00:00:00 2001 From: LoserCheems Date: Mon, 27 Oct 2025 17:55:32 +0800 Subject: [PATCH 01/17] Adds GQA forward and boolean mask/bias Adds forward support for GQA/MQA (different Q vs KV heads) with optional boolean mask and bias, including broadcasting across batch/head/seq dims and per-head routing. Switches to compile-time mask/bias flags, removes the scratchpad workaround, simplifies scaling, and indexes LSE/Out by Q heads. Skips masked-out tiles, tightens mask semantics (True = keep), and fixes backward mask handling. Introduces a contiguity helper, bumps pipeline stages, and errors out on Triton backward for GQA/MQA until implemented. --- flash_dmattn/flash_dmattn_triton.py | 396 +++++++++++++++------------- 1 file changed, 219 insertions(+), 177 deletions(-) diff --git a/flash_dmattn/flash_dmattn_triton.py b/flash_dmattn/flash_dmattn_triton.py index c94500b..b4b56f0 100644 --- a/flash_dmattn/flash_dmattn_triton.py +++ b/flash_dmattn/flash_dmattn_triton.py @@ -1,3 +1,4 @@ +from typing import Optional import math import torch @@ -30,7 +31,6 @@ def _fwd_kernel( Bias, Out, Lse, - TMP, # NOTE: TMP is a scratchpad buffer to workaround a compiler bug softmax_scale, stride_qb, stride_qh, @@ -50,14 +50,20 @@ def _fwd_kernel( stride_ob, stride_oh, stride_om, - nheads, + nheads_q, + nheads_k, + nheads_mask, + nheads_bias, + h_h_k_ratio, seqlen_q, seqlen_k, seqlen_q_rounded, headdim, - CACHE_KEY_SEQLEN_Q, - CACHE_KEY_SEQLEN_K, + CACHE_KEY_SEQLEN_Q: tl.constexpr, + CACHE_KEY_SEQLEN_K: tl.constexpr, IS_CAUSAL: tl.constexpr, + HAS_MASK: tl.constexpr, + HAS_BIAS: tl.constexpr, BLOCK_HEADDIM: tl.constexpr, EVEN_M: tl.constexpr, EVEN_N: tl.constexpr, @@ -67,44 +73,56 @@ def _fwd_kernel( ): start_m = tl.program_id(0) off_hb = tl.program_id(1) - off_b = off_hb // nheads - off_h = off_hb % nheads + off_b = off_hb // nheads_q + off_hq = off_hb % nheads_q + off_hk = off_hq // h_h_k_ratio + if HAS_MASK: + if nheads_mask == 1: + off_hmask = 0 + elif nheads_mask == nheads_k: + off_hmask = off_hk + else: + off_hmask = off_hq + if HAS_BIAS: + if nheads_bias == 1: + off_hbbias = 0 + elif nheads_bias == nheads_k: + off_hbbias = off_hk + else: + off_hbbias = off_hq # off_b = tl.program_id(1) # off_h = tl.program_id(2) # off_hb = off_b * nheads + off_h - # initialize offsets + + # Initialize offsets offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) offs_n = tl.arange(0, BLOCK_N) offs_d = tl.arange(0, BLOCK_HEADDIM) + # Initialize pointers to Q, K, V, Mask, Bias - # Adding parenthesis around indexing might use int32 math instead of int64 math? - # https://github.com/openai/triton/issues/741 - # I'm seeing a tiny bit of difference (5-7us) q_ptrs = ( - Q + off_b * stride_qb + off_h * stride_qh + (offs_m[:, None] * stride_qm + offs_d[None, :]) + Q + off_b * stride_qb + off_hq * stride_qh + (offs_m[:, None] * stride_qm + offs_d[None, :]) ) k_ptrs = ( - K + off_b * stride_kb + off_h * stride_kh + (offs_n[:, None] * stride_kn + offs_d[None, :]) + K + off_b * stride_kb + off_hk * stride_kh + (offs_n[:, None] * stride_kn + offs_d[None, :]) ) v_ptrs = ( - V + off_b * stride_vb + off_h * stride_vh + (offs_n[:, None] * stride_vn + offs_d[None, :]) + V + off_b * stride_vb + off_hk * stride_vh + (offs_n[:, None] * stride_vn + offs_d[None, :]) ) m_ptrs = ( - Mask + off_b * stride_mb + off_h * stride_mh + (offs_m[:, None] * stride_mm + offs_n[None, :]) - ) + Mask + off_b * stride_mb + off_hmask * stride_mh + (offs_m[:, None] * stride_mm + offs_n[None, :]) + ) if HAS_MASK else None b_ptrs = ( - Bias + off_b * stride_bb + off_h * stride_bh + (offs_m[:, None] * stride_bm + offs_n[None, :]) - ) - - # initialize pointer to m and l - t_ptrs = TMP + off_hb * seqlen_q_rounded + offs_m - lse_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") - m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + Bias + off_b * stride_bb + off_hbbias * stride_bh + (offs_m[:, None] * stride_bm + offs_n[None, :]) + ) if HAS_BIAS else None + + # Initialize pointer to m and l + lse_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) + m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) acc_o = tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32) - # load q: it will stay in SRAM throughout - # [2022-10-30] TD: Triton bug - in the case of EVEN_M=True and EVEN_N=False, if we just call - # tl.load(q_ptrs), we get the wrong output! - if EVEN_M & EVEN_N: + + # Load q: it will stay in SRAM throughout + if EVEN_M: if EVEN_HEADDIM: q = tl.load(q_ptrs) else: @@ -116,133 +134,134 @@ def _fwd_kernel( q = tl.load( q_ptrs, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0 ) - # loop over k, v and update accumulator + + # Loop over k, v and update accumulator end_n = seqlen_k if not IS_CAUSAL else tl.minimum((start_m + 1) * BLOCK_M, seqlen_k) for start_n in range(0, end_n, BLOCK_N): start_n = tl.multiple_of(start_n, BLOCK_N) - # Load k - if EVEN_N & EVEN_M: # If we just do "if EVEN_N", there seems to be some race condition - if EVEN_HEADDIM: - k = tl.load(k_ptrs + start_n * stride_kn) - else: - k = tl.load(k_ptrs + start_n * stride_kn, mask=offs_d[None, :] < headdim, other=0.0) - else: - if EVEN_HEADDIM: - k = tl.load( - k_ptrs + start_n * stride_kn, - mask=(start_n + offs_n)[:, None] < seqlen_k, - other=0.0, - ) + if HAS_MASK: + # Load mask + if EVEN_M & EVEN_N: + mask = tl.load(m_ptrs + start_n) else: - k = tl.load( - k_ptrs + start_n * stride_kn, - mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim), - other=0.0, + mask = tl.load( + m_ptrs + start_n, + mask=(offs_m[:, None] < seqlen_q) & ((start_n + offs_n)[None, :] < seqlen_k), + other=False ) - # compute acc_s - acc_s = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - acc_s += tl.dot(q, tl.trans(k)) - - # Trying to combine the two masks seem to make the result wrong - # Apply sequence length mask - if not EVEN_N: # Need to mask out otherwise the softmax is wrong - acc_s += tl.where((start_n + offs_n)[None, :] < seqlen_k, 0, float("-inf")) - # Apply causal mask - if IS_CAUSAL: - acc_s += tl.where(offs_m[:, None] >= (start_n + offs_n)[None, :], 0, float("-inf")) - - # Load mask - if EVEN_M & EVEN_N: - mask = tl.load(m_ptrs + start_n) + # Check if any element in mask is non-zero + any_active = tl.reduce_or(mask, axis=None) else: - mask = tl.load( - m_ptrs + start_n, - mask=(offs_m[:, None] < seqlen_q) & ((start_n + offs_n)[None, :] < seqlen_k), - other=0.0 - ) - - # Check if any element in mask is non-zero - # BUG: Triton needs to determine the control flow at compile time. - # Dynamic conditions at runtime can undermine this optimization. - # any_active = tl.sum(mask) != 0 - # Apply dynamic mask - acc_s += tl.where(mask > 0.0, 0.0, float("-inf")) + any_active = True - # Load bias - if EVEN_M & EVEN_N: - bias = tl.load(b_ptrs + start_n).to(tl.float32) - else: - bias = tl.load( - b_ptrs + start_n, - mask=(offs_m[:, None] < seqlen_q) - & ((start_n + offs_n)[None, :] < seqlen_k), - other=0.0, - ).to(tl.float32) - - # Apply scaling and bias - # Slightly faster to multiply the softmax_scale in the tl.exp below since the compiler - # can then fuse the mult and add into an fma instruction. But if we have bias we need to - # to multiply with softmax_scale here. - acc_s = acc_s * softmax_scale + bias - # acc_s = tl.where(acc_s != float("-inf"), acc_s * softmax_scale + bias, acc_s) - m_ij = tl.maximum(tl.max(acc_s, 1), lse_i) - p = tl.exp(acc_s - m_ij[:, None]) - l_ij = tl.sum(p, 1) + # Skip this iteration if no active elements + if any_active: - # scale acc_o - acc_o_scale = tl.exp(m_i - m_ij) + # Load k + if EVEN_N: + if EVEN_HEADDIM: + k = tl.load(k_ptrs + start_n * stride_kn) + else: + k = tl.load(k_ptrs + start_n * stride_kn, mask=offs_d[None, :] < headdim, other=0.0) + else: + if EVEN_HEADDIM: + k = tl.load( + k_ptrs + start_n * stride_kn, + mask=(start_n + offs_n)[:, None] < seqlen_k, + other=0.0, + ) + else: + k = tl.load( + k_ptrs + start_n * stride_kn, + mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim), + other=0.0, + ) - # update output accumulator - # BUG: have to store and immediately load - tl.store(t_ptrs, acc_o_scale) - acc_o_scale = tl.load(t_ptrs) - acc_o = acc_o * acc_o_scale[:, None] + # Compute acc_s + acc_s = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + acc_s += tl.dot(q, tl.trans(k)) - # load v - if EVEN_N & EVEN_M: # If we just do "if EVEN_N", there seems to be some race condition - if EVEN_HEADDIM: - v = tl.load(v_ptrs + start_n * stride_vn) + # Apply masks + # Trying to combine the three masks seem to make the result wrong + if not EVEN_N: # Need to mask out otherwise the softmax is wrong + acc_s += tl.where((start_n + offs_n)[None, :] < seqlen_k, 0, float("-inf")) + if IS_CAUSAL: + acc_s += tl.where(offs_m[:, None] >= (start_n + offs_n)[None, :], 0, float("-inf")) + if HAS_MASK: + acc_s += tl.where(mask, 0.0, float("-inf")) + + if HAS_BIAS: + # Load bias + if EVEN_M & EVEN_N: + bias = tl.load(b_ptrs + start_n).to(tl.float32) + else: + bias = tl.load( + b_ptrs + start_n, + mask=(offs_m[:, None] < seqlen_q) + & ((start_n + offs_n)[None, :] < seqlen_k), + other=0.0, + ).to(tl.float32) + + # Apply scaling and bias + acc_s = acc_s * softmax_scale + bias else: - v = tl.load(v_ptrs + start_n * stride_vn, mask=offs_d[None, :] < headdim, other=0.0) - else: - if EVEN_HEADDIM: - v = tl.load( - v_ptrs + start_n * stride_vn, - mask=(start_n + offs_n)[:, None] < seqlen_k, - other=0.0, - ) + # Apply scaling + acc_s = acc_s * softmax_scale + + m_ij = tl.maximum(tl.max(acc_s, 1), lse_i) + p = tl.exp(acc_s - m_ij[:, None]) + l_ij = tl.sum(p, 1) + + # Scale acc_o + acc_o_scale = tl.exp(m_i - m_ij) + + # Update output accumulator + acc_o = acc_o * acc_o_scale[:, None] + + # Load v + if EVEN_N: + if EVEN_HEADDIM: + v = tl.load(v_ptrs + start_n * stride_vn) + else: + v = tl.load(v_ptrs + start_n * stride_vn, mask=offs_d[None, :] < headdim, other=0.0) else: - v = tl.load( - v_ptrs + start_n * stride_vn, - mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim), - other=0.0, - ) - acc_o += tl.dot(p.to(v.dtype), v) + if EVEN_HEADDIM: + v = tl.load( + v_ptrs + start_n * stride_vn, + mask=(start_n + offs_n)[:, None] < seqlen_k, + other=0.0, + ) + else: + v = tl.load( + v_ptrs + start_n * stride_vn, + mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim), + other=0.0, + ) + + # Compute acc_o + acc_o += tl.dot(p.to(v.dtype), v) - # update statistics - m_i = m_ij - l_i_new = tl.exp(lse_i - m_ij) + l_ij - lse_i = m_ij + tl.log(l_i_new) + # Update statistics + m_i = m_ij + l_i_new = tl.exp(lse_i - m_ij) + l_ij + lse_i = m_ij + tl.log(l_i_new) o_scale = tl.exp(m_i - lse_i) - # BUG: have to store and immediately load - tl.store(t_ptrs, o_scale) - o_scale = tl.load(t_ptrs) acc_o = acc_o * o_scale[:, None] - # rematerialize offsets to save registers + # Rematerialize offsets to save registers start_m = tl.program_id(0) offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) - # write back l and m + # Write back l and m lse_ptrs = Lse + off_hb * seqlen_q_rounded + offs_m tl.store(lse_ptrs, lse_i) - # initialize pointers to output + # Initialize pointers to output offs_d = tl.arange(0, BLOCK_HEADDIM) out_ptrs = ( Out + off_b * stride_ob - + off_h * stride_oh + + off_hq * stride_oh + (offs_m[:, None] * stride_om + offs_d[None, :]) ) if EVEN_M: @@ -281,10 +300,10 @@ def _bwd_preprocess_do_o_dot( off_hb = tl.program_id(1) off_b = off_hb // nheads off_h = off_hb % nheads - # initialize offsets + # Initialize offsets offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) offs_d = tl.arange(0, BLOCK_HEADDIM) - # load + # Load o o = tl.load( Out + off_b * stride_ob + off_h * stride_oh + offs_m[:, None] * stride_om + offs_d[None, :], mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), @@ -300,7 +319,7 @@ def _bwd_preprocess_do_o_dot( other=0.0, ).to(tl.float32) delta = tl.sum(o * do, axis=1) - # write-back + # Write back tl.store(Delta + off_hb * seqlen_q_rounded + offs_m, delta) @@ -463,7 +482,7 @@ def _bwd_kernel_one_col_block( mask = tl.load( m_ptrs, mask=(offs_m_curr[:, None] < seqlen_q) & (offs_n[None, :] < seqlen_k), - other=0.0, + other=False, ) # Trying to combine the two masks seem to make the result wrong @@ -473,21 +492,21 @@ def _bwd_kernel_one_col_block( # Apply causal mask if IS_CAUSAL: acc_s = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), acc_s, float("-inf")) - # Apply dynamic mask - acc_s = tl.where(mask > 0.0, acc_s, float("-inf")) + # Apply dynamic mask (boolean mask: True = keep, False = mask-out) + acc_s = tl.where(mask, acc_s, float("-inf")) tl.debug_barrier() # Race condition otherwise # Load bias if EVEN_M & EVEN_N: bias = tl.load( b_ptrs, - mask=(mask > 0.0), + mask=mask, other=0.0, ).to(tl.float32) else: bias = tl.load( b_ptrs, - mask=(offs_m_curr[:, None] < seqlen_q) & (offs_n[None, :] < seqlen_k) & (mask > 0.0), + mask=(offs_m_curr[:, None] < seqlen_q) & (offs_n[None, :] < seqlen_k) & mask, other=0.0, ).to(tl.float32) acc_s = acc_s * softmax_scale + bias @@ -846,51 +865,66 @@ def _bwd_kernel( ) -def _flash_attn_forward(q, k, v, mask, bias, softmax_scale=None, is_causal=False): +def _flash_attn_forward(q, k, v, mask=None, bias=None, softmax_scale=None, is_causal=False): # shape constraints - batch, seqlen_q, nheads, d = q.shape - _, seqlen_k, _, _ = k.shape - assert k.shape == (batch, seqlen_k, nheads, d) - assert v.shape == (batch, seqlen_k, nheads, d) - assert d <= 128, "FlashAttention only support head dimensions up to 128" + batch, seqlen_q, nheads_q, d = q.shape + _, seqlen_k, nheads_k, _ = k.shape + + assert nheads_q % nheads_k == 0, "Number of Q heads must be divisible by KV heads for GQA/MQA" + assert d <= 128, "FlashDynamicMaskAttention only support head dimensions up to 128" assert q.dtype == k.dtype == v.dtype, "All tensors must have the same type" assert q.dtype in [torch.float16, torch.bfloat16], "Only support fp16 and bf16" assert q.is_cuda and k.is_cuda and v.is_cuda - assert mask.shape == (batch, nheads, seqlen_q, seqlen_k), f"mask shape {mask.shape} does not match expected shape {(batch, nheads, seqlen_q, seqlen_k)}" - assert mask.dtype in [torch.float16, torch.bfloat16, torch.float32], "mask must be fp16, bf16, or fp32" - assert mask.is_cuda, "mask must be on CUDA" - if mask.stride(-1) != 1: - mask = mask.contiguous() + HAS_MASK = mask is not None + if HAS_MASK: + assert mask.dtype == torch.bool, "Only support bool mask" + assert mask.is_cuda + assert mask.dim() == 4, "mask must be 4D" + mb, hm, mq, nk = mask.shape + assert mb in (1, batch), "mask batch dim must be 1 or batch" + assert hm in (1, nheads_k, nheads_q), "mask head dim must be 1, nheads_k, or nheads_q" + assert mq in (1, seqlen_q), "mask query dim must be 1 or seqlen_q" + assert nk in (1, seqlen_k), "mask key dim must be 1 or seqlen_k" + mask = mask.expand(batch, hm, seqlen_q, seqlen_k) + nheads_mask = hm + else: + nheads_mask = 1 - assert bias.dtype in [q.dtype, torch.float], f"bias dtype {bias.dtype} must match q dtype {q.dtype} or be float" - assert bias.is_cuda, "bias must be on CUDA" - assert bias.dim() == 4, f"bias must be 4D, got {bias.dim()}D" - assert bias.shape == (batch, nheads, seqlen_q, seqlen_k), f"bias shape {bias.shape} must be (batch={batch}, nheads={nheads}, seqlen_q={seqlen_q}, seqlen_k={seqlen_k})" - if bias.stride(-1) != 1: - bias = bias.contiguous() + HAS_BIAS = bias is not None + if HAS_BIAS: + assert bias.dtype in [q.dtype, torch.float] + assert bias.is_cuda + assert bias.dim() == 4, "bias must be 4D" + bb, hb, bq, bk_ = bias.shape + assert bb in (1, batch), "bias batch dim must be 1 or batch" + assert hb in (1, nheads_k, nheads_q), "bias head dim must be 1, nheads_k, or nheads_q" + assert bq in (1, seqlen_q), "bias query dim must be 1 or seqlen_q" + assert bk_ in (1, seqlen_k), "bias key dim must be 1 or seqlen_k" + bias = bias.expand(batch, hb, seqlen_q, seqlen_k) + nheads_bias = hb + else: + nheads_bias = 1 softmax_scale = softmax_scale or 1.0 / math.sqrt(d) seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128 - lse = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32) - tmp = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32) + lse = torch.empty((batch, nheads_q, seqlen_q_rounded), device=q.device, dtype=torch.float32) o = torch.empty_like(q) BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16) BLOCK_M = 128 BLOCK_N = 64 num_warps = 4 if d <= 64 else 8 - grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), batch * nheads) + grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), batch * nheads_q) _fwd_kernel[grid]( q, k, v, - mask, - bias, + mask if HAS_MASK else torch.empty(0, device=q.device, dtype=torch.bool), + bias if HAS_BIAS else torch.empty(0, device=q.device, dtype=q.dtype), o, lse, - tmp, softmax_scale, q.stride(0), q.stride(2), @@ -901,16 +935,20 @@ def _flash_attn_forward(q, k, v, mask, bias, softmax_scale=None, is_causal=False v.stride(0), v.stride(2), v.stride(1), - mask.stride(0), - mask.stride(1), - mask.stride(2), - bias.stride(0), - bias.stride(1), - bias.stride(2), + mask.stride(0) if HAS_MASK else 0, + mask.stride(1) if HAS_MASK else 0, + mask.stride(2) if HAS_MASK else 0, + bias.stride(0) if HAS_BIAS else 0, + bias.stride(1) if HAS_BIAS else 0, + bias.stride(2) if HAS_BIAS else 0, o.stride(0), o.stride(2), o.stride(1), - nheads, + nheads_q, + nheads_k, + nheads_mask, + nheads_bias, + nheads_q // nheads_k, seqlen_q, seqlen_k, seqlen_q_rounded, @@ -920,11 +958,13 @@ def _flash_attn_forward(q, k, v, mask, bias, softmax_scale=None, is_causal=False # Can't use kwargs here because triton autotune expects key to be args, not kwargs # IS_CAUSAL=causal, BLOCK_HEADDIM=d, is_causal, + HAS_MASK, + HAS_BIAS, BLOCK_HEADDIM, BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, num_warps=num_warps, - num_stages=1, + num_stages=2, ) return o, lse, softmax_scale # softmax_scale could have been updated @@ -942,7 +982,8 @@ def _flash_attn_backward( seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128 assert lse.shape == (batch, nheads, seqlen_q_rounded) - assert mask.dtype in [q.dtype, torch.float] + # Use boolean mask consistently (True = keep). Ensure GPU + layout. + assert mask.dtype == torch.bool, "Only support bool mask" assert mask.is_cuda assert mask.dim() == 4 assert mask.stride(-1) == 1 @@ -1052,6 +1093,10 @@ def _flash_attn_backward( return dq, dk, dv, dbias +def maybe_contiguous(x: Optional[torch.Tensor]) -> Optional[torch.Tensor]: + return x.contiguous() if x is not None and x.stride(-1) != 1 else x + + class FlashDMAttnFunc(torch.autograd.Function): @staticmethod def forward(ctx, query, key, value, attn_mask=None, attn_bias=None, is_causal=False, softmax_scale=None): @@ -1064,18 +1109,9 @@ def forward(ctx, query, key, value, attn_mask=None, attn_bias=None, is_causal=Fa is_causal: bool, whether to apply causal masking softmax_scale: float, scaling factor for attention scores """ - batch, seqlen_q, nheads, _ = query.shape - _, seqlen_k, _, _ = key.shape - if attn_mask is not None: - if attn_mask.dtype == torch.bool: - attn_mask = torch.where(attn_mask, 1.0, 0.0) - else: - attn_mask = torch.ones((batch, nheads, seqlen_q, seqlen_k), device=query.device, dtype=query.dtype) - if attn_bias is None: - attn_bias = torch.zeros((batch, nheads, seqlen_q, seqlen_k), device=query.device, dtype=query.dtype) # Make sure that the last dimension is contiguous - query, key, value, attn_mask, attn_bias = [x if x.stride(-1) == 1 else x.contiguous() for x in [query, key, value, attn_mask, attn_bias]] + query, key, value, attn_mask, attn_bias = [maybe_contiguous(x) for x in [query, key, value, attn_mask, attn_bias]] o, lse, ctx.softmax_scale = _flash_attn_forward( query, key, @@ -1093,6 +1129,12 @@ def forward(ctx, query, key, value, attn_mask=None, attn_bias=None, is_causal=Fa def backward(ctx, do): query, key, value, o, lse, attn_mask, attn_bias = ctx.saved_tensors assert not ctx.needs_input_grad[3], "FlashDMAttn does not support mask gradient yet" + # Backward for GQA/MQA (nheads_q != nheads_k) is not implemented for Triton + if query.shape[2] != key.shape[2]: + raise RuntimeError( + "Triton backward for GQA/MQA (nheads_q != nheads_k) is not implemented yet. " + "Use the CUDA backend for training or disable grad for Triton path." + ) dq, dk, dv, dbias = _flash_attn_backward( do, query, From be550d006a4f9ca3b0d8d3bcdce9eba8fac0dce4 Mon Sep 17 00:00:00 2001 From: LoserCheems Date: Wed, 29 Oct 2025 12:14:40 +0800 Subject: [PATCH 02/17] Simplifies mask/bias logic; aligns kernel flags Removes rigid 4D asserts and explicit expands for mask/bias to support broadcast-friendly inputs and avoid unnecessary memory overhead. Renames flags to lower-case and passes them through consistently to the kernel, updating conditional strides and placeholder tensors accordingly. Improves readability and aligns runtime flags with kernel expectations. --- flash_dmattn/flash_dmattn_triton.py | 48 ++++++++++------------------- 1 file changed, 17 insertions(+), 31 deletions(-) diff --git a/flash_dmattn/flash_dmattn_triton.py b/flash_dmattn/flash_dmattn_triton.py index b4b56f0..bb0b8e1 100644 --- a/flash_dmattn/flash_dmattn_triton.py +++ b/flash_dmattn/flash_dmattn_triton.py @@ -876,33 +876,19 @@ def _flash_attn_forward(q, k, v, mask=None, bias=None, softmax_scale=None, is_ca assert q.dtype in [torch.float16, torch.bfloat16], "Only support fp16 and bf16" assert q.is_cuda and k.is_cuda and v.is_cuda - HAS_MASK = mask is not None - if HAS_MASK: + has_mask = mask is not None + if has_mask: assert mask.dtype == torch.bool, "Only support bool mask" assert mask.is_cuda - assert mask.dim() == 4, "mask must be 4D" - mb, hm, mq, nk = mask.shape - assert mb in (1, batch), "mask batch dim must be 1 or batch" - assert hm in (1, nheads_k, nheads_q), "mask head dim must be 1, nheads_k, or nheads_q" - assert mq in (1, seqlen_q), "mask query dim must be 1 or seqlen_q" - assert nk in (1, seqlen_k), "mask key dim must be 1 or seqlen_k" - mask = mask.expand(batch, hm, seqlen_q, seqlen_k) - nheads_mask = hm + nheads_mask = mask.shape[1] else: nheads_mask = 1 - HAS_BIAS = bias is not None - if HAS_BIAS: + has_bias = bias is not None + if has_bias: assert bias.dtype in [q.dtype, torch.float] assert bias.is_cuda - assert bias.dim() == 4, "bias must be 4D" - bb, hb, bq, bk_ = bias.shape - assert bb in (1, batch), "bias batch dim must be 1 or batch" - assert hb in (1, nheads_k, nheads_q), "bias head dim must be 1, nheads_k, or nheads_q" - assert bq in (1, seqlen_q), "bias query dim must be 1 or seqlen_q" - assert bk_ in (1, seqlen_k), "bias key dim must be 1 or seqlen_k" - bias = bias.expand(batch, hb, seqlen_q, seqlen_k) - nheads_bias = hb + nheads_bias = bias.shape[1] else: nheads_bias = 1 @@ -921,8 +907,8 @@ def _flash_attn_forward(q, k, v, mask=None, bias=None, softmax_scale=None, is_ca q, k, v, - mask if HAS_MASK else torch.empty(0, device=q.device, dtype=torch.bool), - bias if HAS_BIAS else torch.empty(0, device=q.device, dtype=q.dtype), + mask if has_mask else torch.empty(0, device=q.device, dtype=torch.bool), + bias if has_bias else torch.empty(0, device=q.device, dtype=q.dtype), o, lse, softmax_scale, @@ -935,12 +921,12 @@ def _flash_attn_forward(q, k, v, mask=None, bias=None, softmax_scale=None, is_ca v.stride(0), v.stride(2), v.stride(1), - mask.stride(0) if HAS_MASK else 0, - mask.stride(1) if HAS_MASK else 0, - mask.stride(2) if HAS_MASK else 0, - bias.stride(0) if HAS_BIAS else 0, - bias.stride(1) if HAS_BIAS else 0, - bias.stride(2) if HAS_BIAS else 0, + mask.stride(0) if has_mask else 0, + mask.stride(1) if has_mask else 0, + mask.stride(2) if has_mask else 0, + bias.stride(0) if has_bias else 0, + bias.stride(1) if has_bias else 0, + bias.stride(2) if has_bias else 0, o.stride(0), o.stride(2), o.stride(1), @@ -956,10 +942,10 @@ def _flash_attn_forward(q, k, v, mask=None, bias=None, softmax_scale=None, is_ca seqlen_q // 32, seqlen_k // 32, # key for triton cache (limit number of compilations) # Can't use kwargs here because triton autotune expects key to be args, not kwargs - # IS_CAUSAL=causal, BLOCK_HEADDIM=d, + # IS_CAUSAL=is_causal, HAS_MASK=has_mask, HAS_BIAS=has_bias, BLOCK_HEADDIM=d, is_causal, - HAS_MASK, - HAS_BIAS, + has_mask, + has_bias, BLOCK_HEADDIM, BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, From 46f3a97f55974cbdcd2fb1265932dfd3349797e1 Mon Sep 17 00:00:00 2001 From: LoserCheems Date: Wed, 29 Oct 2025 12:17:30 +0800 Subject: [PATCH 03/17] Enables mask grad and GQA in Triton backward Drops backward-path guards that blocked attention mask gradients and GQA/MQA head configs with the Triton backend. Expands training support without forcing CUDA fallback; relies on the underlying kernel for validation. --- flash_dmattn/flash_dmattn_triton.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/flash_dmattn/flash_dmattn_triton.py b/flash_dmattn/flash_dmattn_triton.py index bb0b8e1..d8d14e3 100644 --- a/flash_dmattn/flash_dmattn_triton.py +++ b/flash_dmattn/flash_dmattn_triton.py @@ -1114,13 +1114,6 @@ def forward(ctx, query, key, value, attn_mask=None, attn_bias=None, is_causal=Fa @staticmethod def backward(ctx, do): query, key, value, o, lse, attn_mask, attn_bias = ctx.saved_tensors - assert not ctx.needs_input_grad[3], "FlashDMAttn does not support mask gradient yet" - # Backward for GQA/MQA (nheads_q != nheads_k) is not implemented for Triton - if query.shape[2] != key.shape[2]: - raise RuntimeError( - "Triton backward for GQA/MQA (nheads_q != nheads_k) is not implemented yet. " - "Use the CUDA backend for training or disable grad for Triton path." - ) dq, dk, dv, dbias = _flash_attn_backward( do, query, From 05a49898f53ffd6cd65e1030e548e6e39ecf0109 Mon Sep 17 00:00:00 2001 From: LoserCheems Date: Wed, 29 Oct 2025 13:18:11 +0800 Subject: [PATCH 04/17] Enable Triton autotune; default mask/bias tensors Enables kernel autotuning with multiple tile configs and warp counts, keyed by sequence lengths, causality, mask/bias presence, and head dim for better performance and correct cache separation. Defaults missing mask/bias to empty tensors up front to simplify the call path and stabilize the kernel signature. Removes hardcoded launch params to defer selection to the autotuner; standardizes num_stages in configs. --- flash_dmattn/flash_dmattn_triton.py | 64 +++++++++++++++++++++-------- 1 file changed, 46 insertions(+), 18 deletions(-) diff --git a/flash_dmattn/flash_dmattn_triton.py b/flash_dmattn/flash_dmattn_triton.py index d8d14e3..f3f0726 100644 --- a/flash_dmattn/flash_dmattn_triton.py +++ b/flash_dmattn/flash_dmattn_triton.py @@ -6,15 +6,41 @@ import triton.language as tl -# Disabling autotune for now, set num_warps=4 if headdim=64 and num_warps=8 if headdim=128 -# @triton.autotune( -# configs=[ -# triton.Config({"BLOCK_M": 128, "BLOCK_N": 128}, num_warps=4, num_stages=1), -# # This config has a race condition when EVEN_M == False, disabling it for now. -# # triton.Config({"BLOCK_M": 64, "BLOCK_N": 64}, num_warps=4, num_stages=1), -# ], -# key=['CACHE_KEY_SEQLEN_Q', 'CACHE_KEY_SEQLEN_K', 'IS_CAUSAL', 'BLOCK_HEADDIM'] -# ) +@triton.autotune( + configs=[ + triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 128}, + num_warps=4, + num_stages=1, + ), + triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 64}, + num_warps=4, + num_stages=1, + ), + triton.Config( + {"BLOCK_M": 64, "BLOCK_N": 64}, + num_warps=4, + num_stages=1, + ), + triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 128}, + num_warps=8, + num_stages=1, + ), + triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 64}, + num_warps=8, + num_stages=1, + ), + triton.Config( + {"BLOCK_M": 64, "BLOCK_N": 64}, + num_warps=8, + num_stages=1, + ), + ], + key=['CACHE_KEY_SEQLEN_Q', 'CACHE_KEY_SEQLEN_K', 'IS_CAUSAL', 'HAS_MASK', 'HAS_BIAS', 'BLOCK_HEADDIM'] +) @triton.heuristics( { "EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0, @@ -883,6 +909,7 @@ def _flash_attn_forward(q, k, v, mask=None, bias=None, softmax_scale=None, is_ca nheads_mask = mask.shape[1] else: nheads_mask = 1 + mask = torch.empty(0, device=q.device, dtype=torch.bool) has_bias = bias is not None if has_bias: @@ -891,6 +918,7 @@ def _flash_attn_forward(q, k, v, mask=None, bias=None, softmax_scale=None, is_ca nheads_bias = bias.shape[1] else: nheads_bias = 1 + bias = torch.empty(0, device=q.device, dtype=q.dtype) softmax_scale = softmax_scale or 1.0 / math.sqrt(d) @@ -899,16 +927,16 @@ def _flash_attn_forward(q, k, v, mask=None, bias=None, softmax_scale=None, is_ca o = torch.empty_like(q) BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16) - BLOCK_M = 128 - BLOCK_N = 64 - num_warps = 4 if d <= 64 else 8 + # BLOCK_M = 128 + # BLOCK_N = 64 + # num_warps = 4 if d <= 64 else 8 grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), batch * nheads_q) _fwd_kernel[grid]( q, k, v, - mask if has_mask else torch.empty(0, device=q.device, dtype=torch.bool), - bias if has_bias else torch.empty(0, device=q.device, dtype=q.dtype), + mask, + bias, o, lse, softmax_scale, @@ -947,10 +975,10 @@ def _flash_attn_forward(q, k, v, mask=None, bias=None, softmax_scale=None, is_ca has_mask, has_bias, BLOCK_HEADDIM, - BLOCK_M=BLOCK_M, - BLOCK_N=BLOCK_N, - num_warps=num_warps, - num_stages=2, + # BLOCK_M=BLOCK_M, + # BLOCK_N=BLOCK_N, + # num_warps=num_warps, + # num_stages=1, ) return o, lse, softmax_scale # softmax_scale could have been updated From 7475bdedfa0dd8d929de532f6f6852db1081d69f Mon Sep 17 00:00:00 2001 From: LoserCheems Date: Wed, 29 Oct 2025 13:27:20 +0800 Subject: [PATCH 05/17] Adds GQA/MQA and optional mask/bias handling Supports mixed Q/KV head counts (GQA/MQA) by validating divisibility, updating grid/shape logic to query heads, and passing head-count ratios to the kernel. Makes mask and bias truly optional: supplies empty tensors with zero strides and passes has_mask/has_bias flags to the kernel, removing strict stride/layout assumptions. Improves robustness with clearer assertions and compiles keys (is_causal/has_mask/has_bias/head dim), and standardizes forward to require explicit mask/bias args. --- flash_dmattn/flash_dmattn_triton.py | 77 +++++++++++++++++------------ 1 file changed, 45 insertions(+), 32 deletions(-) diff --git a/flash_dmattn/flash_dmattn_triton.py b/flash_dmattn/flash_dmattn_triton.py index f3f0726..47578c8 100644 --- a/flash_dmattn/flash_dmattn_triton.py +++ b/flash_dmattn/flash_dmattn_triton.py @@ -891,7 +891,7 @@ def _bwd_kernel( ) -def _flash_attn_forward(q, k, v, mask=None, bias=None, softmax_scale=None, is_causal=False): +def _flash_attn_forward(q, k, v, mask, bias, softmax_scale=None, is_causal=False): # shape constraints batch, seqlen_q, nheads_q, d = q.shape _, seqlen_k, nheads_k, _ = k.shape @@ -989,23 +989,29 @@ def _flash_attn_backward( # Make sure that the last dimension is contiguous if do.stride(-1) != 1: do = do.contiguous() - batch, seqlen_q, nheads, d = q.shape - _, seqlen_k, _, _ = k.shape - # assert d in {16, 32, 64, 128} - assert d <= 128 - seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128 - assert lse.shape == (batch, nheads, seqlen_q_rounded) + batch, seqlen_q, nheads_q, d = q.shape + _, seqlen_k, nheads_k, dk = k.shape - # Use boolean mask consistently (True = keep). Ensure GPU + layout. - assert mask.dtype == torch.bool, "Only support bool mask" - assert mask.is_cuda - assert mask.dim() == 4 - assert mask.stride(-1) == 1 + assert nheads_q % nheads_k == 0, "Number of Q heads must be divisible by KV heads for GQA/MQA" + assert d <= 128, "FlashDynamicMaskAttention only support head dimensions up to 128" + seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128 + assert lse.shape == (batch, nheads_q, seqlen_q_rounded) + + has_mask = mask is not None + if has_mask: + assert mask.dtype == torch.bool, "Only support bool mask" + nheads_mask = mask.shape[1] + else: + nheads_mask = 1 + mask = torch.empty(0, device=q.device, dtype=torch.bool) - assert bias.dtype in [q.dtype, torch.float] - assert bias.is_cuda - assert bias.dim() == 4 - assert bias.stride(-1) == 1 + has_bias = bias is not None + if has_bias: + assert bias.dtype in [q.dtype, torch.float] + nheads_bias = bias.shape[1] + else: + nheads_bias = 1 + bias = torch.empty(0, device=q.device, dtype=q.dtype) softmax_scale = softmax_scale or 1.0 / math.sqrt(d) # dq_accum = torch.zeros_like(q, dtype=torch.float32) @@ -1014,10 +1020,10 @@ def _flash_attn_backward( # delta = torch.zeros_like(lse) dk = torch.empty_like(k) dv = torch.empty_like(v) - dbias = torch.empty_like(bias) + dbias = torch.empty_like(bias) if has_bias else torch.empty(0, device=q.device, dtype=q.dtype) BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16) - grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), batch * nheads) + grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), batch * nheads_q) _bwd_preprocess_do_o_dot[grid]( o, do, @@ -1028,7 +1034,7 @@ def _flash_attn_backward( do.stride(0), do.stride(2), do.stride(1), - nheads, + nheads_q, seqlen_q, seqlen_q_rounded, d, @@ -1041,7 +1047,7 @@ def _flash_attn_backward( # num_warps = 4 grid = lambda META: ( triton.cdiv(seqlen_k, META["BLOCK_N"]) if META["SEQUENCE_PARALLEL"] else 1, - batch * nheads, + batch * nheads_q, ) _bwd_kernel[grid]( q, @@ -1066,12 +1072,12 @@ def _flash_attn_backward( v.stride(0), v.stride(2), v.stride(1), - mask.stride(0), - mask.stride(1), - mask.stride(2), - bias.stride(0), - bias.stride(1), - bias.stride(2), + mask.stride(0) if has_mask else 0, + mask.stride(1) if has_mask else 0, + mask.stride(2) if has_mask else 0, + bias.stride(0) if has_bias else 0, + bias.stride(1) if has_bias else 0, + bias.stride(2) if has_bias else 0, do.stride(0), do.stride(2), do.stride(1), @@ -1084,10 +1090,14 @@ def _flash_attn_backward( dv.stride(0), dv.stride(2), dv.stride(1), - dbias.stride(0), - dbias.stride(1), - dbias.stride(2), - nheads, + dbias.stride(0) if has_bias else 0, + dbias.stride(1) if has_bias else 0, + dbias.stride(2) if has_bias else 0, + nheads_q, + nheads_k, + nheads_mask, + nheads_bias, + nheads_q // nheads_k, seqlen_q, seqlen_k, seqlen_q_rounded, @@ -1095,11 +1105,14 @@ def _flash_attn_backward( seqlen_q // 32, seqlen_k // 32, # key for triton cache (limit number of compilations) # Can't use kwargs here because triton autotune expects key to be args, not kwargs - # IS_CAUSAL=causal, BLOCK_HEADDIM=d, + # IS_CAUSAL=is_causal, HAS_MASK=has_mask, HAS_BIAS=has_bias, BLOCK_HEADDIM=BLOCK_HEADDIM, is_causal, + has_mask, + has_bias, BLOCK_HEADDIM, # SEQUENCE_PARALLEL=False, - # BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, + # BLOCK_M=BLOCK_M, + # BLOCK_N=BLOCK_N, # num_warps=num_warps, # num_stages=1, ) From 2f6ef9c54ae0d8ebd5baa451c4cc0489f0deeb60 Mon Sep 17 00:00:00 2001 From: LoserCheems Date: Wed, 29 Oct 2025 15:33:54 +0800 Subject: [PATCH 06/17] Support GQA heads and bias/mask broadcast in bwd Enables backward pass with differing Q/K head counts (GQA/MQA) and broadcasting of mask/bias across 1, KV, or Q heads by remapping head offsets and conditioning pointer advances on feature presence. Introduces a compile-time switch for atomic accumulation that triggers under sequence parallelism or head-count mismatches to prevent write races, while avoiding atomics otherwise for performance. Improves correctness and flexibility across broader attention configurations. --- flash_dmattn/flash_dmattn_triton.py | 66 +++++++++++++++++++++-------- 1 file changed, 49 insertions(+), 17 deletions(-) diff --git a/flash_dmattn/flash_dmattn_triton.py b/flash_dmattn/flash_dmattn_triton.py index 47578c8..9bfa3ef 100644 --- a/flash_dmattn/flash_dmattn_triton.py +++ b/flash_dmattn/flash_dmattn_triton.py @@ -730,6 +730,7 @@ def init_func(nargs): "EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0, "EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0, "EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"], + "ATOMIC_ADD": lambda args: args["SEQUENCE_PARALLEL"] or (args["nheads_q"] != args["nheads_k"] and args["nheads_q"] != args["nheads_bias"]), } ) @triton.jit @@ -777,7 +778,11 @@ def _bwd_kernel( stride_dbb, stride_dbh, stride_dbm, - nheads, + nheads_q, + nheads_k, + nheads_mask, + nheads_bias, + h_h_k_ratio, seqlen_q, seqlen_k, seqlen_q_rounded, @@ -785,31 +790,54 @@ def _bwd_kernel( CACHE_KEY_SEQLEN_Q, CACHE_KEY_SEQLEN_K, IS_CAUSAL: tl.constexpr, + HAS_MASK: tl.constexpr, + HAS_BIAS: tl.constexpr, BLOCK_HEADDIM: tl.constexpr, SEQUENCE_PARALLEL: tl.constexpr, EVEN_M: tl.constexpr, EVEN_N: tl.constexpr, EVEN_HEADDIM: tl.constexpr, + ATOMIC_ADD: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, ): off_hb = tl.program_id(1) - off_b = off_hb // nheads - off_h = off_hb % nheads - # offset pointers for batch/head - Q += off_b * stride_qb + off_h * stride_qh - K += off_b * stride_kb + off_h * stride_kh - V += off_b * stride_vb + off_h * stride_vh - Mask += off_b * stride_mb + off_h * stride_mh - Bias += off_b * stride_bb + off_h * stride_bh - DO += off_b * stride_dob + off_h * stride_doh - DQ += off_b * stride_dqb + off_h * stride_dqh - DK += off_b * stride_dkb + off_h * stride_dkh - DV += off_b * stride_dvb + off_h * stride_dvh - DBias += off_b * stride_dbb + off_h * stride_dbh - # pointer to row-wise quantities in value-like data + off_b = off_hb // nheads_q + off_hq = off_hb % nheads_q + off_hk = off_hq // h_h_k_ratio + if HAS_MASK: + if nheads_mask == 1: + off_hmask = 0 + elif nheads_mask == nheads_k: + off_hmask = off_hk + else: + off_hmask = off_hq + if HAS_BIAS: + if nheads_bias == 1: + off_hbbias = 0 + elif nheads_bias == nheads_k: + off_hbbias = off_hk + else: + off_hbbias = off_hq + + # Advance offset pointers for batch and head + Q += off_b * stride_qb + off_hq * stride_qh + K += off_b * stride_kb + off_hk * stride_kh + V += off_b * stride_vb + off_hk * stride_vh + if HAS_MASK: + Mask += off_b * stride_mb + off_hmask * stride_mh + if HAS_BIAS: + Bias += off_b * stride_bb + off_hbbias * stride_bh + DO += off_b * stride_dob + off_hq * stride_doh + DQ += off_b * stride_dqb + off_hq * stride_dqh + DK += off_b * stride_dkb + off_hk * stride_dkh + DV += off_b * stride_dvb + off_hk * stride_dvh + if HAS_BIAS: + DBias += off_b * stride_dbb + off_hbbias * stride_dbh + # Advance pointer to row-wise quantities in value-like data D += off_hb * seqlen_q_rounded LSE += off_hb * seqlen_q_rounded + if not SEQUENCE_PARALLEL: num_block_n = tl.cdiv(seqlen_k, BLOCK_N) for start_n in range(0, num_block_n): @@ -841,12 +869,14 @@ def _bwd_kernel( seqlen_q, seqlen_k, headdim, - ATOMIC_ADD=False, IS_CAUSAL=IS_CAUSAL, + HAS_MASK=HAS_MASK, + HAS_BIAS=HAS_BIAS, BLOCK_HEADDIM=BLOCK_HEADDIM, EVEN_M=EVEN_M, EVEN_N=EVEN_N, EVEN_HEADDIM=EVEN_HEADDIM, + ATOMIC_ADD=ATOMIC_ADD, BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, ) @@ -880,12 +910,14 @@ def _bwd_kernel( seqlen_q, seqlen_k, headdim, - ATOMIC_ADD=True, IS_CAUSAL=IS_CAUSAL, + HAS_MASK=HAS_MASK, + HAS_BIAS=HAS_BIAS, BLOCK_HEADDIM=BLOCK_HEADDIM, EVEN_M=EVEN_M, EVEN_N=EVEN_N, EVEN_HEADDIM=EVEN_HEADDIM, + ATOMIC_ADD=ATOMIC_ADD, BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, ) From 4fe748eaa72709adf26adbd0478f84d84935e564 Mon Sep 17 00:00:00 2001 From: LoserCheems Date: Wed, 29 Oct 2025 15:39:49 +0800 Subject: [PATCH 07/17] Adds atomic grad updates to fix bwd races Introduces a flag to perform atomic accumulation of gradients when tiles may contend, eliminating race conditions in the backward path. Retains fast masked stores for safe even/fully covered cases; applies atomic adds with appropriate masks otherwise. Improves numerical correctness for uneven M/N and variable head dims while preserving performance where possible. --- flash_dmattn/flash_dmattn_triton.py | 39 +++++++++++++++++++++-------- 1 file changed, 28 insertions(+), 11 deletions(-) diff --git a/flash_dmattn/flash_dmattn_triton.py b/flash_dmattn/flash_dmattn_triton.py index 9bfa3ef..c1aa020 100644 --- a/flash_dmattn/flash_dmattn_triton.py +++ b/flash_dmattn/flash_dmattn_triton.py @@ -362,23 +362,40 @@ def _bwd_store_dk_dv( EVEN_M: tl.constexpr, EVEN_N: tl.constexpr, EVEN_HEADDIM: tl.constexpr, + ATOMIC_ADD: tl.constexpr, ): # [2022-11-01] TD: Same bug. In the case of EVEN_N=True and EVEN_M=False, # if we just call tl.store(dv_ptrs), there's a race condition - if EVEN_N & EVEN_M: - if EVEN_HEADDIM: - tl.store(dv_ptrs, dv) - tl.store(dk_ptrs, dk) + if not ATOMIC_ADD: + if EVEN_N & EVEN_M: + if EVEN_HEADDIM: + tl.store(dv_ptrs, dv) + tl.store(dk_ptrs, dk) + else: + tl.store(dv_ptrs, dv, mask=offs_d[None, :] < headdim) + tl.store(dk_ptrs, dk, mask=offs_d[None, :] < headdim) else: - tl.store(dv_ptrs, dv, mask=offs_d[None, :] < headdim) - tl.store(dk_ptrs, dk, mask=offs_d[None, :] < headdim) + if EVEN_HEADDIM: + tl.store(dv_ptrs, dv, mask=offs_n[:, None] < seqlen_k) + tl.store(dk_ptrs, dk, mask=offs_n[:, None] < seqlen_k) + else: + tl.store(dv_ptrs, dv, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim)) + tl.store(dk_ptrs, dk, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim)) else: - if EVEN_HEADDIM: - tl.store(dv_ptrs, dv, mask=offs_n[:, None] < seqlen_k) - tl.store(dk_ptrs, dk, mask=offs_n[:, None] < seqlen_k) + if EVEN_N & EVEN_M: + if EVEN_HEADDIM: + tl.atomic_add(dv_ptrs, dv) + tl.atomic_add(dk_ptrs, dk) + else: + tl.atomic_add(dv_ptrs, dv, mask=offs_d[None, :] < headdim) + tl.atomic_add(dk_ptrs, dk, mask=offs_d[None, :] < headdim) else: - tl.store(dv_ptrs, dv, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim)) - tl.store(dk_ptrs, dk, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim)) + if EVEN_HEADDIM: + tl.atomic_add(dv_ptrs, dv, mask=offs_n[:, None] < seqlen_k) + tl.atomic_add(dk_ptrs, dk, mask=offs_n[:, None] < seqlen_k) + else: + tl.atomic_add(dv_ptrs, dv, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim)) + tl.atomic_add(dk_ptrs, dk, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim)) @triton.jit From d734df9740f2546d2fe9d6176200f292b2d413c9 Mon Sep 17 00:00:00 2001 From: LoserCheems Date: Wed, 29 Oct 2025 16:32:07 +0800 Subject: [PATCH 08/17] Fixes head indexing and mask dtype handling Uses query head count for backward offsets to correct head/batch mapping, enabling GQA/MQA configurations without misindexing. Aligns masked score accumulation with the tensor dtype by using a zero literal that avoids unintended type promotion, improving numerical stability and performance across fp16/bf16. --- flash_dmattn/flash_dmattn_triton.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/flash_dmattn/flash_dmattn_triton.py b/flash_dmattn/flash_dmattn_triton.py index c1aa020..28a2afa 100644 --- a/flash_dmattn/flash_dmattn_triton.py +++ b/flash_dmattn/flash_dmattn_triton.py @@ -216,7 +216,7 @@ def _fwd_kernel( if IS_CAUSAL: acc_s += tl.where(offs_m[:, None] >= (start_n + offs_n)[None, :], 0, float("-inf")) if HAS_MASK: - acc_s += tl.where(mask, 0.0, float("-inf")) + acc_s += tl.where(mask, 0, float("-inf")) if HAS_BIAS: # Load bias @@ -315,7 +315,7 @@ def _bwd_preprocess_do_o_dot( stride_dob, stride_doh, stride_dom, - nheads, + nheads_q, seqlen_q, seqlen_q_rounded, headdim, @@ -324,8 +324,8 @@ def _bwd_preprocess_do_o_dot( ): start_m = tl.program_id(0) off_hb = tl.program_id(1) - off_b = off_hb // nheads - off_h = off_hb % nheads + off_b = off_hb // nheads_q + off_h = off_hb % nheads_q # Initialize offsets offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) offs_d = tl.arange(0, BLOCK_HEADDIM) From f464664aaa44e2038f9be8504ce638e7e5b2044c Mon Sep 17 00:00:00 2001 From: LoserCheems Date: Wed, 29 Oct 2025 19:13:16 +0800 Subject: [PATCH 09/17] Add Triton backward tests; drop contiguous copies Adds a comprehensive backward equivalence test suite validating Triton gradients against the Python prototype across many shapes and head dims, with accuracy and speed reporting. Enables the Triton test path via the test_type flag. Removes redundant .contiguous() calls before the Triton attention invocation to avoid extra copies and rely on stride-aware kernels, improving memory use and potential performance. Skips when Triton is unavailable and performs GPU memory cleanup between runs. --- benchmarks/backward_equivalence.py | 247 ++++++++++++++++++++++++++++- 1 file changed, 239 insertions(+), 8 deletions(-) diff --git a/benchmarks/backward_equivalence.py b/benchmarks/backward_equivalence.py index 2149a8e..7c34ba2 100644 --- a/benchmarks/backward_equivalence.py +++ b/benchmarks/backward_equivalence.py @@ -283,11 +283,9 @@ def dynamic_mask_attention_triton( attn_bias = repeat_kv(attn_bias_leaf, num_queries_per_kv) # Ensure correct data types and memory layout for Triton function - query_states = query_states.transpose(1, 2).contiguous() # [batch, query_len, num_heads, head_dim] - key_states = key_states.transpose(1, 2).contiguous() # [batch, key_len, num_heads, head_dim] - value_states = value_states.transpose(1, 2).contiguous() # [batch, key_len, num_heads, head_dim] - attn_mask = attn_mask.contiguous() # [batch, num_heads, seqlen_q, seqlen_k] - attn_bias = attn_bias.contiguous() # [batch, num_heads, seqlen_q, seqlen_k] + query_states = query_states.transpose(1, 2) # [batch, query_len, num_heads, head_dim] + key_states = key_states.transpose(1, 2) # [batch, key_len, num_heads, head_dim] + value_states = value_states.transpose(1, 2) # [batch, key_len, num_heads, head_dim] # Call the Triton implementation attn_outputs = triton_dmattn_func( @@ -729,6 +727,239 @@ def test_cuda_backward_equivalence(accuracy_threshold=0.95): return all_passed +def test_triton_backward_equivalence(accuracy_threshold=0.95): + """Test backward pass equivalence between Python prototype and Triton implementation.""" + print("\n" + "πŸš€" + "=" * 76 + "πŸš€") + print("πŸ”¬ Testing backward Pass Equivalence: Python Prototype vs Triton Implementation") + print("πŸš€" + "=" * 76 + "πŸš€") + + # Check if Triton implementation is available + if triton_dmattn_func is None: + print("❌ Triton implementation not available, skipping test.") + return False + + # Set random seed for reproducibility + torch.manual_seed(0) + + # Test different parameter configurations + # If you encounter NAN issues when running multiple configurations, try running a single configuration + # (batch_size, num_heads, num_kv_heads, query_len, key_len, head_dim, is_causal) + test_configs = [ + # Head dim 32 + (1, 2, 1, 128, 128, 32, False), + (1, 2, 1, 128, 128, 32, True), + (1, 2, 1, 256, 256, 32, False), + (1, 2, 1, 256, 256, 32, True), + (1, 2, 1, 512, 512, 32, False), + (1, 2, 1, 512, 512, 32, True), + (1, 2, 1, 1024, 1024, 32, False), + (1, 2, 1, 1024, 1024, 32, True), + (1, 2, 1, 2048, 2048, 32, False), + (1, 2, 1, 2048, 2048, 32, True), + (1, 2, 1, 4096, 4096, 32, False), + (1, 2, 1, 4096, 4096, 32, True), + + # Head dim 64 + (1, 2, 1, 128, 128, 64, False), + (1, 2, 1, 128, 128, 64, True), + (1, 2, 1, 256, 256, 64, False), + (1, 2, 1, 256, 256, 64, True), + (1, 2, 1, 512, 512, 64, False), + (1, 2, 1, 512, 512, 64, True), + (1, 2, 1, 1024, 1024, 64, False), + (1, 2, 1, 1024, 1024, 64, True), + (1, 2, 1, 2048, 2048, 64, False), + (1, 2, 1, 2048, 2048, 64, True), + (1, 2, 1, 4096, 4096, 64, False), + (1, 2, 1, 4096, 4096, 64, True), + + # Head dim 96 + (1, 2, 1, 128, 128, 96, False), + (1, 2, 1, 128, 128, 96, True), + (1, 2, 1, 256, 256, 96, False), + (1, 2, 1, 256, 256, 96, True), + (1, 2, 1, 512, 512, 96, False), + (1, 2, 1, 512, 512, 96, True), + (1, 2, 1, 1024, 1024, 96, False), + (1, 2, 1, 1024, 1024, 96, True), + (1, 2, 1, 2048, 2048, 96, False), + (1, 2, 1, 2048, 2048, 96, True), + (1, 2, 1, 4096, 4096, 96, False), + (1, 2, 1, 4096, 4096, 96, True), + + # Head dim 128 + (1, 2, 1, 128, 128, 128, False), + (1, 2, 1, 128, 128, 128, True), + (1, 2, 1, 256, 256, 128, False), + (1, 2, 1, 256, 256, 128, True), + (1, 2, 1, 512, 512, 128, False), + (1, 2, 1, 512, 512, 128, True), + (1, 2, 1, 1024, 1024, 128, False), + (1, 2, 1, 1024, 1024, 128, True), + (1, 2, 1, 2048, 2048, 128, False), + (1, 2, 1, 2048, 2048, 128, True), + (1, 2, 1, 4096, 4096, 128, False), + (1, 2, 1, 4096, 4096, 128, True), + + # triton currently supports up to head dim 128 + ] + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + dtype = torch.bfloat16 + device_icon = "πŸ”₯" if device.type == "cuda" else "πŸ’»" + print(f"{device_icon} Using device: {device}") + + all_passed = True + + for i, config in enumerate(test_configs): + torch.cuda.empty_cache() + gc.collect() + torch.cuda.synchronize() + + batch_size, num_heads, num_kv_heads, query_len, key_len, head_dim, is_causal = config + + # Progress indicator + progress_filled = "β–ˆ" * (i + 1) + progress_empty = "β–‘" * (len(test_configs) - i - 1) + progress_bar = f"[{progress_filled}{progress_empty}]" + + print(f"\nπŸ§ͺ Test configuration {i+1}/{len(test_configs)} {progress_bar}") + print(f" πŸ“Š batch_size={batch_size}, num_heads={num_heads}, num_kv_heads={num_kv_heads}") + print(f" πŸ“ query_len={query_len}, key_len={key_len}, head_dim={head_dim}") + print(f" πŸ”’ is_causal={is_causal}") + print(f" 🎯 Accuracy threshold: {accuracy_threshold*100:.1f}%") + + # Create random input data + query_states = torch.randn( + batch_size, num_heads, query_len, head_dim, + device=device, dtype=dtype, requires_grad=True + ) + key_states = torch.randn( + batch_size, num_kv_heads, key_len, head_dim, + device=device, dtype=dtype, requires_grad=True + ) + value_states = torch.randn( + batch_size, num_kv_heads, key_len, head_dim, + device=device, dtype=dtype, requires_grad=True + ) + attn_bias = torch.randn( + batch_size, num_kv_heads, query_len, key_len, + device=device, dtype=torch.bfloat16 + ) + cache_position = torch.arange(key_len - query_len, key_len, device=device) + causal_mask = torch.arange(key_len, device=device) <= cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + + # Set scaling factor and keep window size + scaling = head_dim ** -0.5 + window_size = 10240 + + # Clone inputs for Python implementation + query_python = query_states.clone().detach().requires_grad_(True) + key_python = key_states.clone().detach().requires_grad_(True) + value_python = value_states.clone().detach().requires_grad_(True) + attn_bias_python = attn_bias.clone().detach().requires_grad_(True) + causal_mask_python = causal_mask.clone().detach() + + # Run Python implementation + start_time = time.time() + attn_outputs_python, dq_python, dk_python, dv_python, dbias_python = dynamic_mask_attention_python( + query_python, key_python, value_python, + attn_bias_python, causal_mask_python, + scaling, window_size, is_causal + ) + torch.cuda.synchronize() + py_time = time.time() - start_time + + # Clone inputs for Triton implementation + query_triton = query_states.clone().detach().requires_grad_(True) + key_triton = key_states.clone().detach().requires_grad_(True) + value_triton = value_states.clone().detach().requires_grad_(True) + attn_bias_triton = attn_bias.clone().detach().requires_grad_(True) + causal_mask_triton = causal_mask.clone().detach() + + # Run Triton implementation + start_time = time.time() + attn_outputs_triton, dq_triton, dk_triton, dv_triton, dbias_triton = dynamic_mask_attention_triton( + query_triton, key_triton, value_triton, + attn_bias_triton, causal_mask_triton, + scaling, window_size, is_causal + ) + torch.cuda.synchronize() + triton_time = time.time() - start_time + + # Analyze outputs + print(f"\nπŸ” Analyzing differences between Python and Triton outputs:") + is_attn_output_close, max_attn_output_diff, mean_attn_output_diff = analyze_differences( + attn_outputs_python, attn_outputs_triton, accuracy_threshold + ) + + # Analyze dQ gradients + print(f"\nπŸ” Analyzing dQ gradients:") + is_dq_close, max_dq_diff, mean_dq_diff = analyze_differences( + dq_python, dq_triton, accuracy_threshold + ) + + # Analyze dK gradients + print(f"\nπŸ” Analyzing dK gradients:") + is_dk_close, max_dk_diff, mean_dk_diff = analyze_differences( + dk_python, dk_triton, accuracy_threshold + ) + + # Analyze dV gradients + print(f"\nπŸ” Analyzing dV gradients:") + is_dv_close, max_dv_diff, mean_dv_diff = analyze_differences( + dv_python, dv_triton, accuracy_threshold + ) + + # Analyze dBias gradients + print(f"\nπŸ” Analyzing dBias gradients:") + is_dbias_close, max_dbias_diff, mean_dbias_diff = analyze_differences( + dbias_python, dbias_triton, accuracy_threshold + ) + + # Report performance difference + speedup = py_time / triton_time if triton_time > 0 else float('inf') + print(f"\n⚑ Performance comparison:") + print(f" 🐍 Python implementation: {py_time*1000:.2f} ms") + print(f" πŸš€ Triton implementation: {triton_time*1000:.2f} ms") + print(f" πŸ“ˆ Speedup: {speedup:.2f}x") + + # Check if all gradients pass + is_close = (is_attn_output_close and is_dq_close and is_dk_close and is_dv_close and is_dbias_close) + test_result = "Passed" if is_close else "Failed" + result_icon = "βœ…" if is_close else "❌" + all_passed = all_passed and is_close + print(f"\n{result_icon} Test result: {test_result}") + + # If test fails with large difference, can exit early + if not is_close and max_attn_output_diff > 1e-2: + print(" ⚠️ Difference too large, stopping subsequent tests.") + break + if not is_close and max_dq_diff > 1e-2: + print(" ⚠️ Difference too large, stopping subsequent tests.") + break + if not is_close and max_dk_diff > 1e-2: + print(" ⚠️ Difference too large, stopping subsequent tests.") + break + if not is_close and max_dv_diff > 1e-2: + print(" ⚠️ Difference too large, stopping subsequent tests.") + break + if not is_close and max_dbias_diff > 1e-2: + print(" ⚠️ Difference too large, stopping subsequent tests.") + break + del query_states, key_states, value_states, attn_bias, causal_mask, cache_position, dq_python, dk_python, dv_python, dbias_python, dq_triton, dk_triton, dv_triton, dbias_triton + torch.cuda.empty_cache() + gc.collect() + torch.cuda.synchronize() + + print("\n" + "🏁" + "=" * 76 + "🏁") + summary_icon = "πŸŽ‰" if all_passed else "😞" + print(f"{summary_icon} Backward Equivalence Test Summary: {'All Passed' if all_passed else 'Some Tests Failed'}") + print("🏁" + "=" * 76 + "🏁") + + return all_passed + def main(): """ Test backward pass equivalence between Python prototype and various implementations @@ -782,9 +1013,9 @@ def main(): print("\n" + "πŸ“" + " Starting Python vs CUDA Backward Tests " + "πŸ“") test_results['cuda'] = test_cuda_backward_equivalence(args.accuracy_threshold) - # if args.test_type in ['all', 'triton']: - # print("\n" + "πŸ”₯" + " Starting Python vs Triton Backward Tests " + "πŸ”₯") - # test_results['triton'] = test_triton_backward_equivalence(args.accuracy_threshold) + if args.test_type in ['all', 'triton']: + print("\n" + "πŸ”₯" + " Starting Python vs Triton Backward Tests " + "πŸ”₯") + test_results['triton'] = test_triton_backward_equivalence(args.accuracy_threshold) # if args.test_type in ['all', 'flex']: # print("\n" + "🌟" + " Starting Python vs Flex Attention Backward Tests " + "🌟") From caa433527caf17fe39ad84032e704c11b706bc1c Mon Sep 17 00:00:00 2001 From: LoserCheems Date: Wed, 29 Oct 2025 23:57:39 +0800 Subject: [PATCH 10/17] Fixes GQA/broadcast bwd; adds bias grad + padding MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Improves correctness and stability of the backward path with GQA/MQA and broadcasted mask/bias: - Unifies head indexing and fixes DK/DV/DBias head offsets; expands and reduces grads when head counts differ - Adds broadcast-aware strides and rounded K length; pads head dim to 8 for aligned 16‑bit storage and crops after - Reworks bwd column kernel to handle optional mask/bias, accumulate bias grad to avoid atomics, and inline safe stores - Enables sequence-parallel config and introduces dbias accumulation heuristic for better performance Also refines masking/bias application and removes race-prone barriers, addressing correctness for uneven shapes and broadcasts. --- flash_dmattn/flash_dmattn_triton.py | 687 ++++++++++++++-------------- 1 file changed, 354 insertions(+), 333 deletions(-) diff --git a/flash_dmattn/flash_dmattn_triton.py b/flash_dmattn/flash_dmattn_triton.py index 28a2afa..adcf9ba 100644 --- a/flash_dmattn/flash_dmattn_triton.py +++ b/flash_dmattn/flash_dmattn_triton.py @@ -76,7 +76,7 @@ def _fwd_kernel( stride_ob, stride_oh, stride_om, - nheads_q, + nheads, nheads_k, nheads_mask, nheads_bias, @@ -99,8 +99,8 @@ def _fwd_kernel( ): start_m = tl.program_id(0) off_hb = tl.program_id(1) - off_b = off_hb // nheads_q - off_hq = off_hb % nheads_q + off_b = off_hb // nheads + off_hq = off_hb % nheads off_hk = off_hq // h_h_k_ratio if HAS_MASK: if nheads_mask == 1: @@ -315,7 +315,7 @@ def _bwd_preprocess_do_o_dot( stride_dob, stride_doh, stride_dom, - nheads_q, + nheads, seqlen_q, seqlen_q_rounded, headdim, @@ -324,8 +324,8 @@ def _bwd_preprocess_do_o_dot( ): start_m = tl.program_id(0) off_hb = tl.program_id(1) - off_b = off_hb // nheads_q - off_h = off_hb % nheads_q + off_b = off_hb // nheads + off_h = off_hb % nheads # Initialize offsets offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) offs_d = tl.arange(0, BLOCK_HEADDIM) @@ -349,55 +349,6 @@ def _bwd_preprocess_do_o_dot( tl.store(Delta + off_hb * seqlen_q_rounded + offs_m, delta) -@triton.jit -def _bwd_store_dk_dv( - dk_ptrs, - dv_ptrs, - dk, - dv, - offs_n, - offs_d, - seqlen_k, - headdim, - EVEN_M: tl.constexpr, - EVEN_N: tl.constexpr, - EVEN_HEADDIM: tl.constexpr, - ATOMIC_ADD: tl.constexpr, -): - # [2022-11-01] TD: Same bug. In the case of EVEN_N=True and EVEN_M=False, - # if we just call tl.store(dv_ptrs), there's a race condition - if not ATOMIC_ADD: - if EVEN_N & EVEN_M: - if EVEN_HEADDIM: - tl.store(dv_ptrs, dv) - tl.store(dk_ptrs, dk) - else: - tl.store(dv_ptrs, dv, mask=offs_d[None, :] < headdim) - tl.store(dk_ptrs, dk, mask=offs_d[None, :] < headdim) - else: - if EVEN_HEADDIM: - tl.store(dv_ptrs, dv, mask=offs_n[:, None] < seqlen_k) - tl.store(dk_ptrs, dk, mask=offs_n[:, None] < seqlen_k) - else: - tl.store(dv_ptrs, dv, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim)) - tl.store(dk_ptrs, dk, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim)) - else: - if EVEN_N & EVEN_M: - if EVEN_HEADDIM: - tl.atomic_add(dv_ptrs, dv) - tl.atomic_add(dk_ptrs, dk) - else: - tl.atomic_add(dv_ptrs, dv, mask=offs_d[None, :] < headdim) - tl.atomic_add(dk_ptrs, dk, mask=offs_d[None, :] < headdim) - else: - if EVEN_HEADDIM: - tl.atomic_add(dv_ptrs, dv, mask=offs_n[:, None] < seqlen_k) - tl.atomic_add(dk_ptrs, dk, mask=offs_n[:, None] < seqlen_k) - else: - tl.atomic_add(dv_ptrs, dv, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim)) - tl.atomic_add(dk_ptrs, dk, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim)) - - @triton.jit def _bwd_kernel_one_col_block( start_n, @@ -427,32 +378,37 @@ def _bwd_kernel_one_col_block( seqlen_q, seqlen_k, headdim, - ATOMIC_ADD: tl.constexpr, IS_CAUSAL: tl.constexpr, + HAS_MASK: tl.constexpr, + HAS_BIAS: tl.constexpr, BLOCK_HEADDIM: tl.constexpr, EVEN_M: tl.constexpr, EVEN_N: tl.constexpr, EVEN_HEADDIM: tl.constexpr, + ATOMIC_ADD: tl.constexpr, + ACCUM_DBIAS: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, ): # We need to make sure begin_m is a multiple of BLOCK_M (not BLOCK_N) begin_m = 0 if not IS_CAUSAL else ((start_n * BLOCK_N) // BLOCK_M) * BLOCK_M - # initialize row/col offsets + # Initialize row/col offsets offs_qm = begin_m + tl.arange(0, BLOCK_M) offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N) offs_m = tl.arange(0, BLOCK_M) offs_d = tl.arange(0, BLOCK_HEADDIM) - # initialize pointers to value-like data + # Initialize pointers to value-like data q_ptrs = Q + (offs_qm[:, None] * stride_qm + offs_d[None, :]) k_ptrs = K + (offs_n[:, None] * stride_kn + offs_d[None, :]) v_ptrs = V + (offs_n[:, None] * stride_vn + offs_d[None, :]) - m_ptrs = Mask + (offs_qm[:, None] * stride_mm + offs_n[None, :]) - b_ptrs = Bias + (offs_qm[:, None] * stride_bm + offs_n[None, :]) + if HAS_MASK: + m_ptrs = Mask + (offs_qm[:, None] * stride_mm + offs_n[None, :]) + if HAS_BIAS: + b_ptrs = Bias + (offs_qm[:, None] * stride_bm + offs_n[None, :]) do_ptrs = DO + (offs_qm[:, None] * stride_dom + offs_d[None, :]) dq_ptrs = DQ + (offs_qm[:, None] * stride_dqm + offs_d[None, :]) db_ptrs = DBias + (offs_qm[:, None] * stride_dbm + offs_n[None, :]) - # initialize dv and dk + # Initialize dv and dk dv = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32) dk = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32) # There seems to be some problem with Triton pipelining that makes results wrong for @@ -461,24 +417,25 @@ def _bwd_kernel_one_col_block( if begin_m >= seqlen_q: dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_d[None, :]) dk_ptrs = DK + (offs_n[:, None] * stride_dkn + offs_d[None, :]) - _bwd_store_dk_dv( - dk_ptrs, - dv_ptrs, - dk, - dv, - offs_n, - offs_d, - seqlen_k, - headdim, - EVEN_M=EVEN_M, - EVEN_N=EVEN_N, - EVEN_HEADDIM=EVEN_HEADDIM, - ) + + if EVEN_N: + if EVEN_HEADDIM: + tl.store(dv_ptrs, dv) + tl.store(dk_ptrs, dk) + else: + tl.store(dv_ptrs, dv, mask=offs_d[None, :] < headdim) + tl.store(dk_ptrs, dk, mask=offs_d[None, :] < headdim) + else: + if EVEN_HEADDIM: + tl.store(dv_ptrs, dv, mask=offs_n[:, None] < seqlen_k) + tl.store(dk_ptrs, dk, mask=offs_n[:, None] < seqlen_k) + else: + tl.store(dv_ptrs, dv, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim)) + tl.store(dk_ptrs, dk, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim)) return - # k and v stay in SRAM throughout - # [2022-10-30] TD: Same bug as the fwd. In the case of EVEN_N=True and EVEN_M=False, - # if we just call tl.load(k_ptrs), we get the wrong output! - if EVEN_N & EVEN_M: + + # Load k and v, them will stay in SRAM throughout + if EVEN_N: if EVEN_HEADDIM: k = tl.load(k_ptrs) v = tl.load(v_ptrs) @@ -496,218 +453,211 @@ def _bwd_kernel_one_col_block( v = tl.load( v_ptrs, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim), other=0.0 ) - # loop over rows + acc_dbias = tl.zeros([BLOCK_N], dtype=tl.float32) if (HAS_BIAS and ACCUM_DBIAS) else None + + # Loop over q and update accumulators num_block_m = tl.cdiv(seqlen_q, BLOCK_M) for start_m in range(begin_m, num_block_m * BLOCK_M, BLOCK_M): start_m = tl.multiple_of(start_m, BLOCK_M) offs_m_curr = start_m + offs_m - # load q, k, v, do on-chip - # Same bug as below. Otherwise gives wrong result for headdim=40, seqlen=(128, 117) - if EVEN_M & EVEN_HEADDIM: - q = tl.load(q_ptrs) - else: - if EVEN_HEADDIM: - q = tl.load(q_ptrs, mask=offs_m_curr[:, None] < seqlen_q, other=0.0) - else: - q = tl.load( - q_ptrs, - mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim), - other=0.0, - ) - # recompute p = softmax(acc_s, dim=-1).T - acc_s = tl.dot(q, tl.trans(k)) - - tl.debug_barrier() - # Load mask - if EVEN_M & EVEN_N: - mask = tl.load(m_ptrs) - else: - mask = tl.load( - m_ptrs, - mask=(offs_m_curr[:, None] < seqlen_q) & (offs_n[None, :] < seqlen_k), - other=False, - ) - - # Trying to combine the two masks seem to make the result wrong - # Apply sequence length mask - if not EVEN_N: # Need to mask out otherwise the softmax is wrong - acc_s = tl.where(offs_n[None, :] < seqlen_k, acc_s, float("-inf")) - # Apply causal mask - if IS_CAUSAL: - acc_s = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), acc_s, float("-inf")) - # Apply dynamic mask (boolean mask: True = keep, False = mask-out) - acc_s = tl.where(mask, acc_s, float("-inf")) - - tl.debug_barrier() # Race condition otherwise - # Load bias - if EVEN_M & EVEN_N: - bias = tl.load( - b_ptrs, - mask=mask, - other=0.0, - ).to(tl.float32) - else: - bias = tl.load( - b_ptrs, - mask=(offs_m_curr[:, None] < seqlen_q) & (offs_n[None, :] < seqlen_k) & mask, - other=0.0, - ).to(tl.float32) - acc_s = acc_s * softmax_scale + bias - # There seems to be a race condition when headdim=48/96, and dq, dk, dv are wrong. - # Also wrong for headdim=64. - if not (EVEN_M & EVEN_HEADDIM): - tl.debug_barrier() - lse_i = tl.load(LSE + offs_m_curr) - p = tl.exp(acc_s - lse_i[:, None]) - # compute dv - # [2022-10-30] TD: A Triton bug: if EVEN_M=True and EVEN_HEADDIM=False, if we call - # do = tl.load(do_ptrs, mask=offs_d[None, :] < headdim, other=0.0), we get wrong outputs - # in the case of headdim=48/96, seqlen_q & seqlen_k >= 512. If headdim=40 or seqlen < 512, - # the output is correct. - if EVEN_M & EVEN_HEADDIM: - do = tl.load(do_ptrs) - else: - # [2022-11-01] TD: Triton bug, there's a race condition if we just use m_mask and not d_mask. - do = tl.load( - do_ptrs, - mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim), - other=0.0, - ) - # if EVEN_M: - # if EVEN_HEADDIM: - # do = tl.load(do_ptrs) - # else: - # do = tl.load(do_ptrs, mask=offs_d[None, :] < headdim, other=0.0) - # else: - # if EVEN_HEADDIM: - # do = tl.load(do_ptrs, mask=offs_m_curr[:, None] < seqlen_q, other=0.0) - # else: - # do = tl.load(do_ptrs, mask=(offs_m_curr[:, None] < seqlen_q) - # & (offs_d[None, :] < headdim), other=0.0) - dv += tl.dot(tl.trans(p.to(do.dtype)), do) - # compute dp = dot(v, do) - # There seems to be a race condition when headdim=48/96, and dq, dk are wrong. - # Also wrong for headdim=128, seqlen=(108, 256), and ATOMIC_ADD=True - # Also wrong for headdim=64, seqlen=(1023, 1024), and ATOMIC_ADD=False - if not (EVEN_M & EVEN_HEADDIM): - tl.debug_barrier() - dp = tl.dot(do, tl.trans(v)) - # There's a race condition for headdim=48 - if not EVEN_HEADDIM: - tl.debug_barrier() - # compute dbias = p * (dp - delta[:, None]) and ds = dbias * softmax_scale - # Putting the subtraction after the dp matmul (instead of before) is slightly faster - Di = tl.load(D + offs_m_curr) - # Converting ds to q.dtype here reduces register pressure and makes it much faster - # for BLOCK_HEADDIM=128 - dbias = (p * (dp - Di[:, None])) - ds = (dbias * softmax_scale).to(q.dtype) - # dbias = tl.where(mask > 0.0, dbias, 0.0) - # ds = (p * (dp - Di[:, None]) * softmax_scale).to(q.dtype) - if not (EVEN_M & EVEN_N): - tl.debug_barrier() - if not ATOMIC_ADD: + + if HAS_MASK: + # Load mask if EVEN_M & EVEN_N: - tl.store( - db_ptrs, - dbias - ) + mask = tl.load(m_ptrs) else: - tl.store( - db_ptrs, - dbias, - mask=(offs_m_curr[:, None] < seqlen_q) & (offs_n[None, :] < seqlen_k) + mask = tl.load( + m_ptrs, + mask=(offs_m_curr[:, None] < seqlen_q) & (offs_n[None, :] < seqlen_k), + other=False, ) + + # Check if any element in mask is non-zero + any_active = tl.reduce_or(mask, axis=None) else: - if EVEN_M & EVEN_N: - tl.atomic_add( - db_ptrs, - dbias - ) - else: - tl.atomic_add( - db_ptrs, - dbias, - mask=(offs_m_curr[:, None] < seqlen_q) & (offs_n[None, :] < seqlen_k) - ) - # compute dk = dot(ds.T, q) - dk += tl.dot(tl.trans(ds), q) - # compute dq - if not ( - EVEN_M & EVEN_HEADDIM - ): # Otherewise there's a race condition - tl.debug_barrier() - if not ATOMIC_ADD: - if EVEN_M & EVEN_HEADDIM: # Race condition if we just do EVEN_M - dq = tl.load(dq_ptrs, eviction_policy="evict_last") - dq += tl.dot(ds, k) - tl.store(dq_ptrs, dq, eviction_policy="evict_last") + any_active = True + + # Skip this iteration if no active elements + if any_active: + # Load q + if EVEN_M & EVEN_HEADDIM: + q = tl.load(q_ptrs) else: if EVEN_HEADDIM: - dq = tl.load( - dq_ptrs, - mask=offs_m_curr[:, None] < seqlen_q, - other=0.0, - eviction_policy="evict_last", - ) - dq += tl.dot(ds, k) - tl.store( - dq_ptrs, - dq, - mask=offs_m_curr[:, None] < seqlen_q, - eviction_policy="evict_last", - ) + q = tl.load(q_ptrs, mask=offs_m_curr[:, None] < seqlen_q, other=0.0) else: - dq = tl.load( - dq_ptrs, + q = tl.load( + q_ptrs, mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0, - eviction_policy="evict_last", ) - dq += tl.dot(ds, k) - tl.store( - dq_ptrs, - dq, - mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim), - eviction_policy="evict_last", - ) - else: # If we're parallelizing across the seqlen_k dimension - dq = tl.dot(ds, k) - if EVEN_M & EVEN_HEADDIM: # Race condition if we just do EVEN_M - tl.atomic_add(dq_ptrs, dq) + + # Compute acc_s + acc_s = tl.dot(q, tl.trans(k)) + + # Apply masks + # Trying to combine the three masks seem to make the result wrong + if not EVEN_N: # Need to mask out otherwise the softmax is wrong + acc_s += tl.where(offs_n[None, :] < seqlen_k, 0, float("-inf")) + if IS_CAUSAL: + acc_s += tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), 0, float("-inf")) + if HAS_MASK: + acc_s += tl.where(mask, 0, float("-inf")) + + if HAS_BIAS: + # Load bias + if EVEN_M & EVEN_N: + bias = tl.load(b_ptrs).to(tl.float32) + else: + bias = tl.load( + b_ptrs, + mask=(offs_m_curr[:, None] < seqlen_q) & (offs_n[None, :] < seqlen_k), + other=0.0, + ).to(tl.float32) + + # Apply scaling and bias + acc_s = acc_s * softmax_scale + bias else: - if EVEN_HEADDIM: - tl.atomic_add(dq_ptrs, dq, mask=offs_m_curr[:, None] < seqlen_q) + # Apply scaling + acc_s = acc_s * softmax_scale + + lse_i = tl.load(LSE + offs_m_curr) + # p = tl.exp(acc_s - lse_i[:, None]) + p = tl.exp(acc_s - tl.where(lse_i > float("-inf"), lse_i, 0.0)[:, None]) + + # Load do + if EVEN_M & EVEN_HEADDIM: + do = tl.load(do_ptrs) + else: + # There's a race condition if we just use m_mask and not d_mask. + do = tl.load( + do_ptrs, + mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim), + other=0.0, + ) + + # Compute dv + dv += tl.dot(tl.trans(p.to(do.dtype)), do) + + # Compute dp + dp = tl.dot(do, tl.trans(v)) + + # Putting the subtraction after the dp matmul (instead of before) is slightly faster + Di = tl.load(D + offs_m_curr) + + # Compute dbias + dbias = (p * (dp - Di[:, None])).to(q.dtype) + + # Write back + if not (EVEN_M & EVEN_N): + tl.debug_barrier() + if HAS_BIAS: + if ACCUM_DBIAS: + acc_dbias += tl.sum(dbias, axis=0) else: - tl.atomic_add( - dq_ptrs, - dq, - mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim), - ) - # increment pointers - do_ptrs += BLOCK_M * stride_dom - dq_ptrs += BLOCK_M * stride_dqm - db_ptrs += BLOCK_M * stride_dbm - q_ptrs += BLOCK_M * stride_qm - m_ptrs += BLOCK_M * stride_mm - b_ptrs += BLOCK_M * stride_bm + if EVEN_M & EVEN_N: + tl.store( + db_ptrs, + dbias, + ) + else: + tl.store( + db_ptrs, + dbias, + mask=(offs_m_curr[:, None] < seqlen_q) & (offs_n[None, :] < seqlen_k), + ) + + # Compute ds + # Converting ds to q.dtype here reduces register pressure and makes it much faster + # for BLOCK_HEADDIM=128 + ds = (dbias * softmax_scale).to(q.dtype) + + # Compute dk + dk += tl.dot(tl.trans(ds), q) + + # Compute dq + if not ATOMIC_ADD: + if EVEN_M & EVEN_HEADDIM: # Race condition if we just do EVEN_M + dq = tl.load(dq_ptrs, eviction_policy="evict_last") + dq += tl.dot(ds, k) + tl.store(dq_ptrs, dq, eviction_policy="evict_last") + else: + if EVEN_HEADDIM: + dq = tl.load( + dq_ptrs, + mask=offs_m_curr[:, None] < seqlen_q, + other=0.0, + eviction_policy="evict_last", + ) + dq += tl.dot(ds, k) + tl.store( + dq_ptrs, + dq, + mask=offs_m_curr[:, None] < seqlen_q, + eviction_policy="evict_last", + ) + else: + dq = tl.load( + dq_ptrs, + mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim), + other=0.0, + eviction_policy="evict_last", + ) + dq += tl.dot(ds, k) + tl.store( + dq_ptrs, + dq, + mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim), + eviction_policy="evict_last", + ) + else: # If we're parallelizing across the seqlen_k dimension + dq = tl.dot(ds, k) + if EVEN_M & EVEN_HEADDIM: # Race condition if we just do EVEN_M + tl.atomic_add(dq_ptrs, dq) + else: + if EVEN_HEADDIM: + tl.atomic_add(dq_ptrs, dq, mask=offs_m_curr[:, None] < seqlen_q) + else: + tl.atomic_add( + dq_ptrs, + dq, + mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim), + ) + + # Increment pointers + do_ptrs += BLOCK_M * stride_dom + dq_ptrs += BLOCK_M * stride_dqm + if HAS_BIAS: + db_ptrs += BLOCK_M * stride_dbm + q_ptrs += BLOCK_M * stride_qm + if HAS_MASK: + m_ptrs += BLOCK_M * stride_mm + if HAS_BIAS: + b_ptrs += BLOCK_M * stride_bm - # write-back + # Write back dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_d[None, :]) dk_ptrs = DK + (offs_n[:, None] * stride_dkn + offs_d[None, :]) - _bwd_store_dk_dv( - dk_ptrs, - dv_ptrs, - dk, - dv, - offs_n, - offs_d, - seqlen_k, - headdim, - EVEN_M=EVEN_M, - EVEN_N=EVEN_N, - EVEN_HEADDIM=EVEN_HEADDIM, - ) + if HAS_BIAS and ACCUM_DBIAS: + if EVEN_N: + tl.store(DBias + offs_n, acc_dbias) + else: + tl.store(DBias + offs_n, acc_dbias, mask=(offs_n < seqlen_k)) + + if EVEN_N: + if EVEN_HEADDIM: + tl.store(dv_ptrs, dv) + tl.store(dk_ptrs, dk) + else: + tl.store(dv_ptrs, dv, mask=offs_d[None, :] < headdim) + tl.store(dk_ptrs, dk, mask=offs_d[None, :] < headdim) + else: + if EVEN_HEADDIM: + tl.store(dv_ptrs, dv, mask=offs_n[:, None] < seqlen_k) + tl.store(dk_ptrs, dk, mask=offs_n[:, None] < seqlen_k) + else: + tl.store(dv_ptrs, dv, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim)) + tl.store(dk_ptrs, dk, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim)) def init_to_zero(names): @@ -727,18 +677,12 @@ def init_func(nargs): num_stages=1, pre_hook=init_to_zero(["DQ", "DBias"]), ), - # triton.Config( - # {"BLOCK_M": 64, "BLOCK_N": 128, "SEQUENCE_PARALLEL": True}, - # num_warps=8, - # num_stages=1, - # pre_hook=init_to_zero(["DQ", "DBias"]), - # ), - # Other configs seem to give wrong results when seqlen_q % 128 != 0, disabling them for now - # # Kernel is buggy (give wrong result) if we set BLOCK_m=128, BLOCK_n=64, num_warps=*4* - # triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False}, num_warps=8, num_stages=1, pre_hook=init_to_zero(['DQ', 'DBias'])), - # triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True}, num_warps=8, num_stages=1, pre_hook=init_to_zero(['DQ', 'DBias'])), - # triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False}, num_warps=4, num_stages=1, pre_hook=init_to_zero(['DQ', 'DBias'])), - # triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True}, num_warps=4, num_stages=1, pre_hook=init_to_zero(['DQ', 'DBias'])), + triton.Config( + {"BLOCK_M": 64, "BLOCK_N": 128, "SEQUENCE_PARALLEL": True}, + num_warps=8, + num_stages=1, + pre_hook=init_to_zero(["DQ", "DBias"]), + ), ], key=["CACHE_KEY_SEQLEN_Q", "CACHE_KEY_SEQLEN_K", "IS_CAUSAL", "BLOCK_HEADDIM"], ) @@ -747,7 +691,7 @@ def init_func(nargs): "EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0, "EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0, "EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"], - "ATOMIC_ADD": lambda args: args["SEQUENCE_PARALLEL"] or (args["nheads_q"] != args["nheads_k"] and args["nheads_q"] != args["nheads_bias"]), + "ACCUM_DBIAS": lambda args: args["HAS_BIAS"] and (args["stride_dbm"] == 0) and (args["seqlen_q"] > 1), } ) @triton.jit @@ -795,7 +739,7 @@ def _bwd_kernel( stride_dbb, stride_dbh, stride_dbm, - nheads_q, + nheads, nheads_k, nheads_mask, nheads_bias, @@ -814,13 +758,13 @@ def _bwd_kernel( EVEN_M: tl.constexpr, EVEN_N: tl.constexpr, EVEN_HEADDIM: tl.constexpr, - ATOMIC_ADD: tl.constexpr, + ACCUM_DBIAS: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, ): off_hb = tl.program_id(1) - off_b = off_hb // nheads_q - off_hq = off_hb % nheads_q + off_b = off_hb // nheads + off_hq = off_hb % nheads off_hk = off_hq // h_h_k_ratio if HAS_MASK: if nheads_mask == 1: @@ -847,10 +791,10 @@ def _bwd_kernel( Bias += off_b * stride_bb + off_hbbias * stride_bh DO += off_b * stride_dob + off_hq * stride_doh DQ += off_b * stride_dqb + off_hq * stride_dqh - DK += off_b * stride_dkb + off_hk * stride_dkh - DV += off_b * stride_dvb + off_hk * stride_dvh + DK += off_b * stride_dkb + off_hq * stride_dkh + DV += off_b * stride_dvb + off_hq * stride_dvh if HAS_BIAS: - DBias += off_b * stride_dbb + off_hbbias * stride_dbh + DBias += off_b * stride_dbb + off_hq * stride_dbh # Advance pointer to row-wise quantities in value-like data D += off_hb * seqlen_q_rounded LSE += off_hb * seqlen_q_rounded @@ -893,7 +837,8 @@ def _bwd_kernel( EVEN_M=EVEN_M, EVEN_N=EVEN_N, EVEN_HEADDIM=EVEN_HEADDIM, - ATOMIC_ADD=ATOMIC_ADD, + ATOMIC_ADD=False, + ACCUM_DBIAS=ACCUM_DBIAS, BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, ) @@ -934,7 +879,8 @@ def _bwd_kernel( EVEN_M=EVEN_M, EVEN_N=EVEN_N, EVEN_HEADDIM=EVEN_HEADDIM, - ATOMIC_ADD=ATOMIC_ADD, + ATOMIC_ADD=True, + ACCUM_DBIAS=ACCUM_DBIAS, BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, ) @@ -942,10 +888,10 @@ def _bwd_kernel( def _flash_attn_forward(q, k, v, mask, bias, softmax_scale=None, is_causal=False): # shape constraints - batch, seqlen_q, nheads_q, d = q.shape + batch, seqlen_q, nheads, d = q.shape _, seqlen_k, nheads_k, _ = k.shape - assert nheads_q % nheads_k == 0, "Number of Q heads must be divisible by KV heads for GQA/MQA" + assert nheads % nheads_k == 0, "Number of Q heads must be divisible by KV heads for GQA/MQA" assert d <= 128, "FlashDynamicMaskAttention only support head dimensions up to 128" assert q.dtype == k.dtype == v.dtype, "All tensors must have the same type" assert q.dtype in [torch.float16, torch.bfloat16], "Only support fp16 and bf16" @@ -972,14 +918,14 @@ def _flash_attn_forward(q, k, v, mask, bias, softmax_scale=None, is_causal=False softmax_scale = softmax_scale or 1.0 / math.sqrt(d) seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128 - lse = torch.empty((batch, nheads_q, seqlen_q_rounded), device=q.device, dtype=torch.float32) + lse = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32) o = torch.empty_like(q) BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16) # BLOCK_M = 128 # BLOCK_N = 64 # num_warps = 4 if d <= 64 else 8 - grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), batch * nheads_q) + grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), batch * nheads) _fwd_kernel[grid]( q, k, @@ -998,20 +944,20 @@ def _flash_attn_forward(q, k, v, mask, bias, softmax_scale=None, is_causal=False v.stride(0), v.stride(2), v.stride(1), - mask.stride(0) if has_mask else 0, - mask.stride(1) if has_mask else 0, - mask.stride(2) if has_mask else 0, - bias.stride(0) if has_bias else 0, - bias.stride(1) if has_bias else 0, - bias.stride(2) if has_bias else 0, + ((0 if (has_mask and mask.shape[0] == 1) else (mask.stride(0) if has_mask else 0))), + ((0 if (has_mask and mask.shape[1] == 1) else (mask.stride(1) if has_mask else 0))), + ((0 if (has_mask and mask.shape[2] == 1) else (mask.stride(2) if has_mask else 0))), + ((0 if (has_bias and bias.shape[0] == 1) else (bias.stride(0) if has_bias else 0))), + ((0 if (has_bias and bias.shape[1] == 1) else (bias.stride(1) if has_bias else 0))), + ((0 if (has_bias and bias.shape[2] == 1) else (bias.stride(2) if has_bias else 0))), o.stride(0), o.stride(2), o.stride(1), - nheads_q, + nheads, nheads_k, nheads_mask, nheads_bias, - nheads_q // nheads_k, + nheads // nheads_k, seqlen_q, seqlen_k, seqlen_q_rounded, @@ -1038,14 +984,15 @@ def _flash_attn_backward( # Make sure that the last dimension is contiguous if do.stride(-1) != 1: do = do.contiguous() - batch, seqlen_q, nheads_q, d = q.shape + batch, seqlen_q, nheads, d = q.shape _, seqlen_k, nheads_k, dk = k.shape - assert nheads_q % nheads_k == 0, "Number of Q heads must be divisible by KV heads for GQA/MQA" + assert nheads % nheads_k == 0, "Number of Q heads must be divisible by KV heads for GQA/MQA" assert d <= 128, "FlashDynamicMaskAttention only support head dimensions up to 128" seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128 - assert lse.shape == (batch, nheads_q, seqlen_q_rounded) - + seqlen_k_rounded = math.ceil(seqlen_k / 128) * 128 + assert lse.shape == (batch, nheads, seqlen_q_rounded) + has_mask = mask is not None if has_mask: assert mask.dtype == torch.bool, "Only support bool mask" @@ -1071,8 +1018,25 @@ def _flash_attn_backward( dv = torch.empty_like(v) dbias = torch.empty_like(bias) if has_bias else torch.empty(0, device=q.device, dtype=q.dtype) + dk_expanded = torch.empty(batch, seqlen_k, nheads, d, device=q.device, dtype=q.dtype) if nheads != nheads_k else dk + dv_expanded = torch.empty(batch, seqlen_k, nheads, d, device=q.device, dtype=q.dtype) if nheads != nheads_k else dv + if has_bias: + if ( + nheads_bias != nheads + or ((bias.shape[0] == 1) and (batch > 1)) + or ((bias.shape[-2] == 1) and (seqlen_q > 1)) + ): + if bias.shape[-2] == 1: + dbias_expanded = torch.zeros(batch, nheads, 1, seqlen_k_rounded, device=q.device, dtype=dbias.dtype) + else: + dbias_expanded = torch.zeros(batch, nheads, seqlen_q, seqlen_k_rounded, device=q.device, dtype=dbias.dtype) + else: + dbias_expanded = dbias + else: + dbias_expanded = dbias + BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16) - grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), batch * nheads_q) + grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), batch * nheads) _bwd_preprocess_do_o_dot[grid]( o, do, @@ -1083,7 +1047,7 @@ def _flash_attn_backward( do.stride(0), do.stride(2), do.stride(1), - nheads_q, + nheads, seqlen_q, seqlen_q_rounded, d, @@ -1096,7 +1060,7 @@ def _flash_attn_backward( # num_warps = 4 grid = lambda META: ( triton.cdiv(seqlen_k, META["BLOCK_N"]) if META["SEQUENCE_PARALLEL"] else 1, - batch * nheads_q, + batch * nheads, ) _bwd_kernel[grid]( q, @@ -1106,9 +1070,9 @@ def _flash_attn_backward( bias, do, dq_accum, - dk, - dv, - dbias, + dk_expanded, + dv_expanded, + dbias_expanded, lse, delta, softmax_scale, @@ -1121,32 +1085,32 @@ def _flash_attn_backward( v.stride(0), v.stride(2), v.stride(1), - mask.stride(0) if has_mask else 0, - mask.stride(1) if has_mask else 0, - mask.stride(2) if has_mask else 0, - bias.stride(0) if has_bias else 0, - bias.stride(1) if has_bias else 0, - bias.stride(2) if has_bias else 0, + ((0 if (has_mask and mask.shape[0] == 1) else (mask.stride(0) if has_mask else 0))), + ((0 if (has_mask and mask.shape[1] == 1) else (mask.stride(1) if has_mask else 0))), + ((0 if (has_mask and mask.shape[2] == 1) else (mask.stride(2) if has_mask else 0))), + ((0 if (has_bias and bias.shape[0] == 1) else (bias.stride(0) if has_bias else 0))), + ((0 if (has_bias and bias.shape[1] == 1) else (bias.stride(1) if has_bias else 0))), + ((0 if (has_bias and bias.shape[2] == 1) else (bias.stride(2) if has_bias else 0))), do.stride(0), do.stride(2), do.stride(1), dq_accum.stride(0), dq_accum.stride(2), dq_accum.stride(1), - dk.stride(0), - dk.stride(2), - dk.stride(1), - dv.stride(0), - dv.stride(2), - dv.stride(1), - dbias.stride(0) if has_bias else 0, - dbias.stride(1) if has_bias else 0, - dbias.stride(2) if has_bias else 0, - nheads_q, + dk_expanded.stride(0), + dk_expanded.stride(2), + dk_expanded.stride(1), + dv_expanded.stride(0), + dv_expanded.stride(2), + dv_expanded.stride(1), + (dbias_expanded.stride(0) if has_bias else 0), + (dbias_expanded.stride(1) if has_bias else 0), + ((0 if (has_bias and bias.shape[-2] == 1) else (dbias_expanded.stride(2) if has_bias else 0))), + nheads, nheads_k, nheads_mask, nheads_bias, - nheads_q // nheads_k, + nheads // nheads_k, seqlen_q, seqlen_k, seqlen_q_rounded, @@ -1166,6 +1130,24 @@ def _flash_attn_backward( # num_stages=1, ) dq = dq_accum.to(q.dtype) + if nheads != nheads_k: + dk = dk_expanded.view(batch, seqlen_k, nheads_k, nheads // nheads_k, d).sum(dim=3) + dv = dv_expanded.view(batch, seqlen_k, nheads_k, nheads // nheads_k, d).sum(dim=3) + if has_bias: + if ( + nheads_bias != nheads + and bias.shape[0] == batch + and bias.shape[-2] == seqlen_q + ): + dbias = dbias_expanded.view(batch, nheads_bias, nheads // nheads_bias, seqlen_q, seqlen_k_rounded).sum(dim=2) + else: + if bias.shape[-2] == 1: + dbias_expanded = dbias_expanded.view(batch, nheads_bias, nheads // nheads_bias, 1, seqlen_k_rounded).sum(dim=2) + else: + dbias_expanded = dbias_expanded.view(batch, nheads_bias, nheads // nheads_bias, seqlen_q, seqlen_k_rounded).sum(dim=2) + if bias.shape[0] == 1: + dbias_expanded = dbias_expanded.sum(dim=0, keepdim=True) + dbias.copy_(dbias_expanded) return dq, dk, dv, dbias @@ -1173,6 +1155,10 @@ def maybe_contiguous(x: Optional[torch.Tensor]) -> Optional[torch.Tensor]: return x.contiguous() if x is not None and x.stride(-1) != 1 else x +def round_multiple(x, m): + return (x + m - 1) // m * m + + class FlashDMAttnFunc(torch.autograd.Function): @staticmethod def forward(ctx, query, key, value, attn_mask=None, attn_bias=None, is_causal=False, softmax_scale=None): @@ -1188,6 +1174,25 @@ def forward(ctx, query, key, value, attn_mask=None, attn_bias=None, is_causal=Fa # Make sure that the last dimension is contiguous query, key, value, attn_mask, attn_bias = [maybe_contiguous(x) for x in [query, key, value, attn_mask, attn_bias]] + + # Padding to multiple of 8 for 16-bit memory allocations + head_size_og = query.size(3) + if head_size_og % 8 != 0: + query = torch.nn.functional.pad(query, [0, 8 - head_size_og % 8]) + key = torch.nn.functional.pad(key, [0, 8 - head_size_og % 8]) + value = torch.nn.functional.pad(value, [0, 8 - head_size_og % 8]) + seqlen_k_rounded = round_multiple(key.shape[1], 128) + if attn_mask is not None and attn_mask.shape[-1] != seqlen_k_rounded: + if attn_mask.shape[-1] == 1: + attn_mask = attn_mask.expand(*attn_mask.shape[:-1], seqlen_k_rounded) + else: + attn_mask = torch.nn.functional.pad(attn_mask, [0, seqlen_k_rounded - attn_mask.shape[-1]]) + if attn_bias is not None and attn_bias.shape[-1] != seqlen_k_rounded: + if attn_bias.shape[-1] == 1: + attn_bias = attn_bias.expand(*attn_bias.shape[:-1], seqlen_k_rounded) + else: + attn_bias = torch.nn.functional.pad(attn_bias, [0, seqlen_k_rounded - attn_bias.shape[-1]]) + o, lse, ctx.softmax_scale = _flash_attn_forward( query, key, @@ -1199,13 +1204,20 @@ def forward(ctx, query, key, value, attn_mask=None, attn_bias=None, is_causal=Fa ) ctx.save_for_backward(query, key, value, o, lse, attn_mask, attn_bias) ctx.is_causal = is_causal + ctx.seqlen_k_bias_og = attn_bias.shape[-1] if attn_bias is not None else 0 return o @staticmethod def backward(ctx, do): query, key, value, o, lse, attn_mask, attn_bias = ctx.saved_tensors + + head_size_og = do.size(3) + do_padded = do + if head_size_og % 8 != 0: + do_padded = torch.nn.functional.pad(do, [0, 8 - head_size_og % 8]) + dq, dk, dv, dbias = _flash_attn_backward( - do, + do_padded, query, key, value, @@ -1216,6 +1228,15 @@ def backward(ctx, do): softmax_scale=ctx.softmax_scale, is_causal=ctx.is_causal, ) + + # We could have padded the head dimension + dq = dq[..., : do.shape[-1]] + dk = dk[..., : do.shape[-1]] + dv = dv[..., : do.shape[-1]] + + if dbias is not None: + dbias = dbias[..., :key.shape[1]].sum(dim=-1, keepdim=True) if ctx.seqlen_k_bias_og == 1 else dbias[..., : key.shape[1]] + return dq, dk, dv, None, dbias, None, None From 35d5edba34a47a751626e7c5720a1b71bbabb9d9 Mon Sep 17 00:00:00 2001 From: LoserCheems Date: Fri, 31 Oct 2025 19:08:24 +0800 Subject: [PATCH 11/17] Pre-scales queries; refactors logits accumulation Moves softmax scaling from logits to queries to avoid per-block scaling and keep bias unscaled. Initializes logits with bias, then adds the dot product and applies masks afterward. Improves numerical stability and likely reduces instruction count in the forward path. --- flash_dmattn/flash_dmattn_triton.py | 34 ++++++++++++++--------------- 1 file changed, 16 insertions(+), 18 deletions(-) diff --git a/flash_dmattn/flash_dmattn_triton.py b/flash_dmattn/flash_dmattn_triton.py index adcf9ba..139d070 100644 --- a/flash_dmattn/flash_dmattn_triton.py +++ b/flash_dmattn/flash_dmattn_triton.py @@ -161,6 +161,9 @@ def _fwd_kernel( q_ptrs, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0 ) + # Scale q + q = (q * softmax_scale).to(q.dtype) + # Loop over k, v and update accumulator end_n = seqlen_k if not IS_CAUSAL else tl.minimum((start_m + 1) * BLOCK_M, seqlen_k) for start_n in range(0, end_n, BLOCK_N): @@ -205,19 +208,6 @@ def _fwd_kernel( other=0.0, ) - # Compute acc_s - acc_s = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - acc_s += tl.dot(q, tl.trans(k)) - - # Apply masks - # Trying to combine the three masks seem to make the result wrong - if not EVEN_N: # Need to mask out otherwise the softmax is wrong - acc_s += tl.where((start_n + offs_n)[None, :] < seqlen_k, 0, float("-inf")) - if IS_CAUSAL: - acc_s += tl.where(offs_m[:, None] >= (start_n + offs_n)[None, :], 0, float("-inf")) - if HAS_MASK: - acc_s += tl.where(mask, 0, float("-inf")) - if HAS_BIAS: # Load bias if EVEN_M & EVEN_N: @@ -230,12 +220,20 @@ def _fwd_kernel( other=0.0, ).to(tl.float32) - # Apply scaling and bias - acc_s = acc_s * softmax_scale + bias - else: - # Apply scaling - acc_s = acc_s * softmax_scale + # Compute acc_s + acc_s = bias if HAS_BIAS else tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + acc_s += tl.dot(q, tl.trans(k)) + + # Apply masks + # Trying to combine the three masks seem to make the result wrong + if not EVEN_N: # Need to mask out otherwise the softmax is wrong + acc_s += tl.where((start_n + offs_n)[None, :] < seqlen_k, 0, float("-inf")) + if IS_CAUSAL: + acc_s += tl.where(offs_m[:, None] >= (start_n + offs_n)[None, :], 0, float("-inf")) + if HAS_MASK: + acc_s += tl.where(mask, 0, float("-inf")) + # Compute p m_ij = tl.maximum(tl.max(acc_s, 1), lse_i) p = tl.exp(acc_s - m_ij[:, None]) l_ij = tl.sum(p, 1) From 4291d01a07bce3c8e6cc7ff8460344f3170e6ee5 Mon Sep 17 00:00:00 2001 From: LoserCheems Date: Fri, 31 Oct 2025 19:47:13 +0800 Subject: [PATCH 12/17] Fuses softmax scale into K; fixes dQ Moves softmax scaling from logits to keys to reduce redundant multiplies and improve numerical stability/perf. Reworks score accumulation to start from bias, then matmul, then masking, preserving masking semantics while simplifying flow. Introduces the reciprocal scale and uses it when accumulating the query gradients so gradients are computed with unscaled keys. --- flash_dmattn/flash_dmattn_triton.py | 44 +++++++++++++++-------------- 1 file changed, 23 insertions(+), 21 deletions(-) diff --git a/flash_dmattn/flash_dmattn_triton.py b/flash_dmattn/flash_dmattn_triton.py index 139d070..e12fdc3 100644 --- a/flash_dmattn/flash_dmattn_triton.py +++ b/flash_dmattn/flash_dmattn_triton.py @@ -409,6 +409,8 @@ def _bwd_kernel_one_col_block( # Initialize dv and dk dv = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32) dk = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32) + # Initialize softmax unscale factor + softmax_unscale = 1.0 / softmax_scale # There seems to be some problem with Triton pipelining that makes results wrong for # headdim=64, seqlen=(113, 255). In this case the for loop may have zero step, # and pipelining with the bias matrix could screw it up. So we just exit early. @@ -451,6 +453,11 @@ def _bwd_kernel_one_col_block( v = tl.load( v_ptrs, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim), other=0.0 ) + + # Scale k + k = (k * softmax_scale).to(k.dtype) + + # Initialize accumulator for dbias if needed acc_dbias = tl.zeros([BLOCK_N], dtype=tl.float32) if (HAS_BIAS and ACCUM_DBIAS) else None # Loop over q and update accumulators @@ -490,18 +497,6 @@ def _bwd_kernel_one_col_block( other=0.0, ) - # Compute acc_s - acc_s = tl.dot(q, tl.trans(k)) - - # Apply masks - # Trying to combine the three masks seem to make the result wrong - if not EVEN_N: # Need to mask out otherwise the softmax is wrong - acc_s += tl.where(offs_n[None, :] < seqlen_k, 0, float("-inf")) - if IS_CAUSAL: - acc_s += tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), 0, float("-inf")) - if HAS_MASK: - acc_s += tl.where(mask, 0, float("-inf")) - if HAS_BIAS: # Load bias if EVEN_M & EVEN_N: @@ -513,11 +508,18 @@ def _bwd_kernel_one_col_block( other=0.0, ).to(tl.float32) - # Apply scaling and bias - acc_s = acc_s * softmax_scale + bias - else: - # Apply scaling - acc_s = acc_s * softmax_scale + # Compute acc_s + acc_s = bias if HAS_BIAS else tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + acc_s += tl.dot(q, tl.trans(k)) + + # Apply masks + # Trying to combine the three masks seem to make the result wrong + if not EVEN_N: # Need to mask out otherwise the softmax is wrong + acc_s += tl.where(offs_n[None, :] < seqlen_k, 0, float("-inf")) + if IS_CAUSAL: + acc_s += tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), 0, float("-inf")) + if HAS_MASK: + acc_s += tl.where(mask, 0, float("-inf")) lse_i = tl.load(LSE + offs_m_curr) # p = tl.exp(acc_s - lse_i[:, None]) @@ -577,7 +579,7 @@ def _bwd_kernel_one_col_block( if not ATOMIC_ADD: if EVEN_M & EVEN_HEADDIM: # Race condition if we just do EVEN_M dq = tl.load(dq_ptrs, eviction_policy="evict_last") - dq += tl.dot(ds, k) + dq += tl.dot(ds, (k * softmax_unscale).to(ds.dtype)) tl.store(dq_ptrs, dq, eviction_policy="evict_last") else: if EVEN_HEADDIM: @@ -587,7 +589,7 @@ def _bwd_kernel_one_col_block( other=0.0, eviction_policy="evict_last", ) - dq += tl.dot(ds, k) + dq += tl.dot(ds, (k * softmax_unscale).to(ds.dtype)) tl.store( dq_ptrs, dq, @@ -601,7 +603,7 @@ def _bwd_kernel_one_col_block( other=0.0, eviction_policy="evict_last", ) - dq += tl.dot(ds, k) + dq += tl.dot(ds, (k * softmax_unscale).to(ds.dtype)) tl.store( dq_ptrs, dq, @@ -609,7 +611,7 @@ def _bwd_kernel_one_col_block( eviction_policy="evict_last", ) else: # If we're parallelizing across the seqlen_k dimension - dq = tl.dot(ds, k) + dq = tl.dot(ds, (k * softmax_unscale).to(ds.dtype)) if EVEN_M & EVEN_HEADDIM: # Race condition if we just do EVEN_M tl.atomic_add(dq_ptrs, dq) else: From 730f40eb16f55d5227c11c2320a6c1b614386a45 Mon Sep 17 00:00:00 2001 From: LoserCheems Date: Fri, 31 Oct 2025 20:00:50 +0800 Subject: [PATCH 13/17] Fix accumulator init with optional bias Initializes the score accumulator within the bias branch and zero-initializes otherwise. Prevents referencing an undefined bias when disabled and improves compilation stability in both forward and backward kernels. --- flash_dmattn/flash_dmattn_triton.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/flash_dmattn/flash_dmattn_triton.py b/flash_dmattn/flash_dmattn_triton.py index e12fdc3..9430563 100644 --- a/flash_dmattn/flash_dmattn_triton.py +++ b/flash_dmattn/flash_dmattn_triton.py @@ -219,9 +219,11 @@ def _fwd_kernel( & ((start_n + offs_n)[None, :] < seqlen_k), other=0.0, ).to(tl.float32) + acc_s = bias + else: + acc_s = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) # Compute acc_s - acc_s = bias if HAS_BIAS else tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) acc_s += tl.dot(q, tl.trans(k)) # Apply masks @@ -507,9 +509,11 @@ def _bwd_kernel_one_col_block( mask=(offs_m_curr[:, None] < seqlen_q) & (offs_n[None, :] < seqlen_k), other=0.0, ).to(tl.float32) + acc_s = bias + else: + acc_s = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) # Compute acc_s - acc_s = bias if HAS_BIAS else tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) acc_s += tl.dot(q, tl.trans(k)) # Apply masks From cd60a833fbb325b915bc14334d8ece5f4b6ab84b Mon Sep 17 00:00:00 2001 From: LoserCheems Date: Mon, 3 Nov 2025 15:20:24 +0800 Subject: [PATCH 14/17] Include HAS_MASK/HAS_BIAS/HAS_INDICE in autotune key; tighten mask/bias dtype checks Add HAS_MASK, HAS_BIAS and HAS_INDICE to the autotune key to ensure different kernel configs are cached per mask/bias/indice usage. Also enforce bias dtype to match query dtype (only fp16/bf16) and standardize the mask dtype assert message. --- flash_dmattn/flash_dmattn_triton.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/flash_dmattn/flash_dmattn_triton.py b/flash_dmattn/flash_dmattn_triton.py index 9430563..dbb74d1 100644 --- a/flash_dmattn/flash_dmattn_triton.py +++ b/flash_dmattn/flash_dmattn_triton.py @@ -688,7 +688,7 @@ def init_func(nargs): pre_hook=init_to_zero(["DQ", "DBias"]), ), ], - key=["CACHE_KEY_SEQLEN_Q", "CACHE_KEY_SEQLEN_K", "IS_CAUSAL", "BLOCK_HEADDIM"], + key=["CACHE_KEY_SEQLEN_Q", "CACHE_KEY_SEQLEN_K", "IS_CAUSAL", "HAS_MASK", "HAS_BIAS", "HAS_INDICE", "BLOCK_HEADDIM"], ) @triton.heuristics( { @@ -903,7 +903,7 @@ def _flash_attn_forward(q, k, v, mask, bias, softmax_scale=None, is_causal=False has_mask = mask is not None if has_mask: - assert mask.dtype == torch.bool, "Only support bool mask" + assert mask.dtype == torch.bool, "Only support bool" assert mask.is_cuda nheads_mask = mask.shape[1] else: @@ -912,7 +912,7 @@ def _flash_attn_forward(q, k, v, mask, bias, softmax_scale=None, is_causal=False has_bias = bias is not None if has_bias: - assert bias.dtype in [q.dtype, torch.float] + assert bias.dtype == q.dtype, "Only support fp16 and bf16" assert bias.is_cuda nheads_bias = bias.shape[1] else: @@ -999,7 +999,7 @@ def _flash_attn_backward( has_mask = mask is not None if has_mask: - assert mask.dtype == torch.bool, "Only support bool mask" + assert mask.dtype == torch.bool, "Only support bool" nheads_mask = mask.shape[1] else: nheads_mask = 1 @@ -1007,7 +1007,7 @@ def _flash_attn_backward( has_bias = bias is not None if has_bias: - assert bias.dtype in [q.dtype, torch.float] + assert bias.dtype == q.dtype, "Only support fp16 and bf16" nheads_bias = bias.shape[1] else: nheads_bias = 1 From 53c34fa92541f50cace5f08d5c12d85fbcc9e531 Mon Sep 17 00:00:00 2001 From: LoserCheems Date: Tue, 4 Nov 2025 13:39:26 +0800 Subject: [PATCH 15/17] Rebalances backward scaling Moves softmax scaling to the final dk update to cut register pressure and simplify accumulation. Aligns dq accumulation with unscaled k for more stable gradients. --- flash_dmattn/flash_dmattn_triton.py | 34 ++++++++++++++--------------- 1 file changed, 16 insertions(+), 18 deletions(-) diff --git a/flash_dmattn/flash_dmattn_triton.py b/flash_dmattn/flash_dmattn_triton.py index dbb74d1..5f83d63 100644 --- a/flash_dmattn/flash_dmattn_triton.py +++ b/flash_dmattn/flash_dmattn_triton.py @@ -411,8 +411,6 @@ def _bwd_kernel_one_col_block( # Initialize dv and dk dv = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32) dk = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32) - # Initialize softmax unscale factor - softmax_unscale = 1.0 / softmax_scale # There seems to be some problem with Triton pipelining that makes results wrong for # headdim=64, seqlen=(113, 255). In this case the for loop may have zero step, # and pipelining with the bias matrix could screw it up. So we just exit early. @@ -549,33 +547,30 @@ def _bwd_kernel_one_col_block( # Putting the subtraction after the dp matmul (instead of before) is slightly faster Di = tl.load(D + offs_m_curr) - # Compute dbias - dbias = (p * (dp - Di[:, None])).to(q.dtype) - + # Compute ds + # Converting ds to q.dtype here reduces register pressure and makes it much faster + # for BLOCK_HEADDIM=128 + ds = (p * (dp - Di[:, None])).to(q.dtype) + # Write back if not (EVEN_M & EVEN_N): tl.debug_barrier() if HAS_BIAS: if ACCUM_DBIAS: - acc_dbias += tl.sum(dbias, axis=0) + acc_dbias += tl.sum(ds, axis=0) else: if EVEN_M & EVEN_N: tl.store( db_ptrs, - dbias, + ds, ) else: tl.store( db_ptrs, - dbias, + ds, mask=(offs_m_curr[:, None] < seqlen_q) & (offs_n[None, :] < seqlen_k), ) - # Compute ds - # Converting ds to q.dtype here reduces register pressure and makes it much faster - # for BLOCK_HEADDIM=128 - ds = (dbias * softmax_scale).to(q.dtype) - # Compute dk dk += tl.dot(tl.trans(ds), q) @@ -583,7 +578,7 @@ def _bwd_kernel_one_col_block( if not ATOMIC_ADD: if EVEN_M & EVEN_HEADDIM: # Race condition if we just do EVEN_M dq = tl.load(dq_ptrs, eviction_policy="evict_last") - dq += tl.dot(ds, (k * softmax_unscale).to(ds.dtype)) + dq += tl.dot(ds, k).to(ds.dtype) tl.store(dq_ptrs, dq, eviction_policy="evict_last") else: if EVEN_HEADDIM: @@ -593,7 +588,7 @@ def _bwd_kernel_one_col_block( other=0.0, eviction_policy="evict_last", ) - dq += tl.dot(ds, (k * softmax_unscale).to(ds.dtype)) + dq += tl.dot(ds, k).to(ds.dtype) tl.store( dq_ptrs, dq, @@ -607,7 +602,7 @@ def _bwd_kernel_one_col_block( other=0.0, eviction_policy="evict_last", ) - dq += tl.dot(ds, (k * softmax_unscale).to(ds.dtype)) + dq += tl.dot(ds, k).to(ds.dtype) tl.store( dq_ptrs, dq, @@ -615,7 +610,7 @@ def _bwd_kernel_one_col_block( eviction_policy="evict_last", ) else: # If we're parallelizing across the seqlen_k dimension - dq = tl.dot(ds, (k * softmax_unscale).to(ds.dtype)) + dq = tl.dot(ds, k).to(ds.dtype) if EVEN_M & EVEN_HEADDIM: # Race condition if we just do EVEN_M tl.atomic_add(dq_ptrs, dq) else: @@ -638,7 +633,10 @@ def _bwd_kernel_one_col_block( m_ptrs += BLOCK_M * stride_mm if HAS_BIAS: b_ptrs += BLOCK_M * stride_bm - + + # Scale dk + dk = (dk * softmax_scale).to(dk.dtype) + # Write back dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_d[None, :]) dk_ptrs = DK + (offs_n[:, None] * stride_dkn + offs_d[None, :]) From b9e3eaa5d044010d445e47065117c5df76e5c12f Mon Sep 17 00:00:00 2001 From: LoserCheems Date: Wed, 5 Nov 2025 11:46:30 +0800 Subject: [PATCH 16/17] Rename _flash_attn_forward/_flash_attn_backward to _flash_dmattn_forward/_flash_dmattn_backward and update call sites --- flash_dmattn/flash_dmattn_triton.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/flash_dmattn/flash_dmattn_triton.py b/flash_dmattn/flash_dmattn_triton.py index 5f83d63..d89c125 100644 --- a/flash_dmattn/flash_dmattn_triton.py +++ b/flash_dmattn/flash_dmattn_triton.py @@ -888,7 +888,7 @@ def _bwd_kernel( ) -def _flash_attn_forward(q, k, v, mask, bias, softmax_scale=None, is_causal=False): +def _flash_dmattn_forward(q, k, v, mask, bias, softmax_scale=None, is_causal=False): # shape constraints batch, seqlen_q, nheads, d = q.shape _, seqlen_k, nheads_k, _ = k.shape @@ -980,7 +980,7 @@ def _flash_attn_forward(q, k, v, mask, bias, softmax_scale=None, is_causal=False return o, lse, softmax_scale # softmax_scale could have been updated -def _flash_attn_backward( +def _flash_dmattn_backward( do, q, k, v, mask, bias, o, lse, softmax_scale=None, is_causal=False ): # Make sure that the last dimension is contiguous @@ -1195,7 +1195,7 @@ def forward(ctx, query, key, value, attn_mask=None, attn_bias=None, is_causal=Fa else: attn_bias = torch.nn.functional.pad(attn_bias, [0, seqlen_k_rounded - attn_bias.shape[-1]]) - o, lse, ctx.softmax_scale = _flash_attn_forward( + o, lse, ctx.softmax_scale = _flash_dmattn_forward( query, key, value, @@ -1218,7 +1218,7 @@ def backward(ctx, do): if head_size_og % 8 != 0: do_padded = torch.nn.functional.pad(do, [0, 8 - head_size_og % 8]) - dq, dk, dv, dbias = _flash_attn_backward( + dq, dk, dv, dbias = _flash_dmattn_backward( do_padded, query, key, From b7deeba48b7b3850140fe63af557f7706176f74a Mon Sep 17 00:00:00 2001 From: LoserCheems Date: Thu, 6 Nov 2025 15:11:28 +0800 Subject: [PATCH 17/17] Guards bias grad return Ensures backward only returns a bias gradient when bias exists, keeping the signature consistent for biasless calls. --- flash_dmattn/flash_dmattn_triton.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flash_dmattn/flash_dmattn_triton.py b/flash_dmattn/flash_dmattn_triton.py index d89c125..66141cb 100644 --- a/flash_dmattn/flash_dmattn_triton.py +++ b/flash_dmattn/flash_dmattn_triton.py @@ -1150,7 +1150,7 @@ def _flash_dmattn_backward( if bias.shape[0] == 1: dbias_expanded = dbias_expanded.sum(dim=0, keepdim=True) dbias.copy_(dbias_expanded) - return dq, dk, dv, dbias + return dq, dk, dv, dbias if has_bias else None def maybe_contiguous(x: Optional[torch.Tensor]) -> Optional[torch.Tensor]: