From f21e404585722ab2a669236feebde51055eed7d1 Mon Sep 17 00:00:00 2001 From: Loser Cheems Date: Sat, 5 Jul 2025 11:05:20 +0800 Subject: [PATCH 01/13] Adds default mask and bias initialization in forward pass Ensures mask and bias tensors are always properly initialized when not provided by the caller. Converts boolean masks to float tensors and creates default all-ones mask when none is specified. Initializes zero bias tensor when bias parameter is None. Updates contiguity check to include mask and bias tensors for consistent memory layout. --- flash_dmattn/flash_dmattn_triton.py | 21 ++++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/flash_dmattn/flash_dmattn_triton.py b/flash_dmattn/flash_dmattn_triton.py index f273451..55cb377 100644 --- a/flash_dmattn/flash_dmattn_triton.py +++ b/flash_dmattn/flash_dmattn_triton.py @@ -353,16 +353,27 @@ class FlashAttnFunc(torch.autograd.Function): def forward(ctx, q, k, v, mask=None, bias=None, causal=False, softmax_scale=None): """ q: (batch_size, seqlen_q, nheads, headdim) - k, v: (batch_size, seqlen_k, nheads, headdim) - mask: optional, shape (batch, nheads, seqlen_q, seqlen_k), dynamic attention mask - bias: optional, shape must be exactly (batch, nheads, seqlen_q, seqlen_k), attention bias matrix + k: (batch_size, seqlen_k, nheads, headdim) + v: (batch_size, seqlen_k, nheads, headdim) + mask: optional, (batch, nheads, seqlen_q, seqlen_k) + bias: optional, (batch, nheads, seqlen_q, seqlen_k) causal: bool, whether to apply causal masking softmax_scale: float, scaling factor for attention scores """ + batch, seqlen_q, nheads, _ = q.shape + _, seqlen_k, _, _ = k.shape + if mask is not None: + if mask.dtype == torch.bool: + mask = torch.where(mask, 1.0, 0.0) + else: + mask = torch.ones((batch, nheads, seqlen_q, seqlen_k), device=q.device) + if bias is None: + bias = torch.zeros((batch, nheads, seqlen_q, seqlen_k), device=q.device) + # Make sure that the last dimension is contiguous - q, k, v = [x if x.stride(-1) == 1 else x.contiguous() for x in [q, k, v]] + q, k, v, mask, bias = [x if x.stride(-1) == 1 else x.contiguous() for x in [q, k, v, mask, bias]] o, lse, ctx.softmax_scale = _flash_attn_forward( - q, k, v, mask=mask, bias=bias, causal=causal, softmax_scale=softmax_scale + q, k, v, mask, bias, causal=causal, softmax_scale=softmax_scale ) ctx.save_for_backward(q, k, v, o, lse, mask, bias) ctx.causal = causal From 99239b2ea7364a3cd315c203ed3dadd6b42f7379 Mon Sep 17 00:00:00 2001 From: Loser Cheems Date: Sat, 5 Jul 2025 11:09:55 +0800 Subject: [PATCH 02/13] Enables backward pass for FlashDynamicMaskAttention Uncomments and updates the backward method to support gradient computation. Adds bias gradient computation with dbias tensor allocation and updates error message to reflect FlashDynamicMaskAttention functionality. Reorders function parameters to include mask and bias in the backward call. --- flash_dmattn/flash_dmattn_triton.py | 54 +++++++++++++++-------------- 1 file changed, 28 insertions(+), 26 deletions(-) diff --git a/flash_dmattn/flash_dmattn_triton.py b/flash_dmattn/flash_dmattn_triton.py index 55cb377..b538b51 100644 --- a/flash_dmattn/flash_dmattn_triton.py +++ b/flash_dmattn/flash_dmattn_triton.py @@ -379,32 +379,34 @@ def forward(ctx, q, k, v, mask=None, bias=None, causal=False, softmax_scale=None ctx.causal = causal return o - # @staticmethod - # def backward(ctx, do): - # q, k, v, o, lse, mask, bias = ctx.saved_tensors - # assert not ctx.needs_input_grad[3], "FlashAttention does not support mask gradient yet" - # assert not ctx.needs_input_grad[4], "FlashAttention does not support bias gradient yet" - # # Triton's autotune causes the Tensor._version to change, and so Pytorch autograd - # # does a memcpy. To avoid this we run in inference_mode, which doesn't track the version. - # with torch.inference_mode(): - # dq = torch.empty_like(q) - # dk = torch.empty_like(k) - # dv = torch.empty_like(v) - # _flash_attn_backward( - # do, - # q, - # k, - # v, - # o, - # lse, - # dq, - # dk, - # dv, - # bias=bias, - # causal=ctx.causal, - # softmax_scale=ctx.softmax_scale, - # ) - # return dq, dk, dv, None, None, None, None + @staticmethod + def backward(ctx, do): + q, k, v, o, lse, mask, bias = ctx.saved_tensors + assert not ctx.needs_input_grad[3], "FlashDynamicMaskAttention does not support mask gradient yet" + # Triton's autotune causes the Tensor._version to change, and so Pytorch autograd + # does a memcpy. To avoid this we run in inference_mode, which doesn't track the version. + with torch.inference_mode(): + dq = torch.empty_like(q) + dk = torch.empty_like(k) + dv = torch.empty_like(v) + dbias = torch.empty_like(bias) + _flash_attn_backward( + do, + q, + k, + v, + mask, + bias, + o, + lse, + dq, + dk, + dv, + dbias, + causal=ctx.causal, + softmax_scale=ctx.softmax_scale, + ) + return dq, dk, dv, None, None, None, None flash_dmattn_func = FlashAttnFunc.apply From 224bd6450c17f6ff90b6c519189507aafe70a490 Mon Sep 17 00:00:00 2001 From: Loser Cheems Date: Sat, 5 Jul 2025 11:11:30 +0800 Subject: [PATCH 03/13] Adds backward pass implementation for flash attention Implements the backward propagation function for the flash attention mechanism, enabling gradient computation through the attention layers. The backward pass handles gradient computation for queries, keys, values, bias, and mask tensors with proper memory layout validation and stride checking. Includes preprocessing step for output gradients and uses Triton kernels for efficient backward computation with support for causal masking and custom softmax scaling. --- flash_dmattn/flash_dmattn_triton.py | 121 ++++++++++++++++++++++++++++ 1 file changed, 121 insertions(+) diff --git a/flash_dmattn/flash_dmattn_triton.py b/flash_dmattn/flash_dmattn_triton.py index b538b51..d1533c1 100644 --- a/flash_dmattn/flash_dmattn_triton.py +++ b/flash_dmattn/flash_dmattn_triton.py @@ -348,6 +348,127 @@ def _flash_attn_forward(q, k, v, mask=None, bias=None, causal=False, softmax_sca return o, lse, softmax_scale # softmax_scale could have been updated +def _flash_attn_backward( + do, q, k, v, mask, bias, o, lse, dq, dk, dv, dbias, causal=False, softmax_scale=None +): + # Make sure that the last dimension is contiguous + if do.stride(-1) != 1: + do = do.contiguous() + batch, seqlen_q, nheads, d = q.shape + _, seqlen_k, _, _ = k.shape + # assert d in {16, 32, 64, 128} + assert d <= 128 + seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128 + assert lse.shape == (batch, nheads, seqlen_q_rounded) + assert q.stride(-1) == k.stride(-1) == v.stride(-1) == o.stride(-1) == 1 + assert dq.stride(-1) == dk.stride(-1) == dv.stride(-1) == 1 + + assert mask.dtype in [q.dtype, torch.float] + assert mask.is_cuda + assert mask.dim() == 4 + assert mask.stride(-1) == 1 + + assert bias.dtype in [q.dtype, torch.float] + assert bias.is_cuda + assert bias.dim() == 4 + assert bias.stride(-1) == 1 + + softmax_scale = softmax_scale or 1.0 / math.sqrt(d) + # dq_accum = torch.zeros_like(q, dtype=torch.float32) + dq_accum = torch.empty_like(q, dtype=torch.float32) + delta = torch.empty_like(lse) + # delta = torch.zeros_like(lse) + + BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16) + grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), batch * nheads) + _bwd_preprocess_do_o_dot[grid]( + o, + do, + delta, + o.stride(0), + o.stride(2), + o.stride(1), + do.stride(0), + do.stride(2), + do.stride(1), + nheads, + seqlen_q, + seqlen_q_rounded, + d, + BLOCK_M=64, + BLOCK_HEADDIM=BLOCK_HEADDIM, + ) + + # BLOCK_M = 128 + # BLOCK_N = 64 + # num_warps = 4 + grid = lambda META: ( + triton.cdiv(seqlen_k, META["BLOCK_N"]) if META["SEQUENCE_PARALLEL"] else 1, + batch * nheads, + ) + _bwd_kernel[grid]( + q, + k, + v, + mask, + bias, + do, + dq_accum, + dk, + dv, + dbias, + lse, + delta, + softmax_scale, + q.stride(0), + q.stride(2), + q.stride(1), + k.stride(0), + k.stride(2), + k.stride(1), + v.stride(0), + v.stride(2), + v.stride(1), + mask.stride(0), + mask.stride(1), + mask.stride(2), + bias.stride(0), + bias.stride(1), + bias.stride(2), + do.stride(0), + do.stride(2), + do.stride(1), + dq_accum.stride(0), + dq_accum.stride(2), + dq_accum.stride(1), + dk.stride(0), + dk.stride(2), + dk.stride(1), + dv.stride(0), + dv.stride(2), + dv.stride(1), + dbias.stride(0), + dbias.stride(1), + dbias.stride(2), + nheads, + seqlen_q, + seqlen_k, + seqlen_q_rounded, + d, + seqlen_q // 32, + seqlen_k // 32, # key for triton cache (limit number of compilations) + # Can't use kwargs here because triton autotune expects key to be args, not kwargs + # IS_CAUSAL=causal, BLOCK_HEADDIM=d, + causal, + BLOCK_HEADDIM, + # SEQUENCE_PARALLEL=False, + # BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, + # num_warps=num_warps, + # num_stages=1, + ) + dq.copy_(dq_accum) + + class FlashAttnFunc(torch.autograd.Function): @staticmethod def forward(ctx, q, k, v, mask=None, bias=None, causal=False, softmax_scale=None): From 80ee701c178f965eee00b04bc9f452948c7bb209 Mon Sep 17 00:00:00 2001 From: Loser Cheems Date: Sat, 5 Jul 2025 11:12:56 +0800 Subject: [PATCH 04/13] Makes mask and bias parameters required in flash attention Removes optional default values for mask and bias parameters to enforce explicit passing of these tensors. Eliminates automatic creation of default all-ones mask and zero bias tensors, requiring callers to provide these inputs explicitly. Simplifies parameter validation logic by removing conditional null checks. --- flash_dmattn/flash_dmattn_triton.py | 32 +++++++++++------------------ 1 file changed, 12 insertions(+), 20 deletions(-) diff --git a/flash_dmattn/flash_dmattn_triton.py b/flash_dmattn/flash_dmattn_triton.py index d1533c1..00d0470 100644 --- a/flash_dmattn/flash_dmattn_triton.py +++ b/flash_dmattn/flash_dmattn_triton.py @@ -258,7 +258,7 @@ def _fwd_kernel( ) -def _flash_attn_forward(q, k, v, mask=None, bias=None, causal=False, softmax_scale=None): +def _flash_attn_forward(q, k, v, mask, bias, causal=False, softmax_scale=None): # shape constraints batch, seqlen_q, nheads, d = q.shape _, seqlen_k, _, _ = k.shape @@ -269,26 +269,18 @@ def _flash_attn_forward(q, k, v, mask=None, bias=None, causal=False, softmax_sca assert q.dtype in [torch.float16, torch.bfloat16], "Only support fp16 and bf16" assert q.is_cuda and k.is_cuda and v.is_cuda - if mask is not None: - assert mask.shape == (batch, nheads, seqlen_q, seqlen_k), f"mask shape {mask.shape} does not match expected shape {(batch, nheads, seqlen_q, seqlen_k)}" - assert mask.dtype in [torch.float16, torch.bfloat16, torch.float32], "mask must be fp16, bf16, or fp32" - assert mask.is_cuda, "mask must be on CUDA" - if mask.stride(-1) != 1: - mask = mask.contiguous() - else: - # Create a default mask of all ones - mask = torch.ones((batch, nheads, seqlen_q, seqlen_k), device=q.device, dtype=q.dtype) + assert mask.shape == (batch, nheads, seqlen_q, seqlen_k), f"mask shape {mask.shape} does not match expected shape {(batch, nheads, seqlen_q, seqlen_k)}" + assert mask.dtype in [torch.float16, torch.bfloat16, torch.float32], "mask must be fp16, bf16, or fp32" + assert mask.is_cuda, "mask must be on CUDA" + if mask.stride(-1) != 1: + mask = mask.contiguous() - if bias is not None: - assert bias.dtype in [q.dtype, torch.float], f"bias dtype {bias.dtype} must match q dtype {q.dtype} or be float" - assert bias.is_cuda, "bias must be on CUDA" - assert bias.dim() == 4, f"bias must be 4D, got {bias.dim()}D" - assert bias.shape == (batch, nheads, seqlen_q, seqlen_k), f"bias shape {bias.shape} must be (batch={batch}, nheads={nheads}, seqlen_q={seqlen_q}, seqlen_k={seqlen_k})" - if bias.stride(-1) != 1: - bias = bias.contiguous() - else: - # Create zero bias if none provided - bias = torch.zeros((batch, nheads, seqlen_q, seqlen_k), device=q.device, dtype=q.dtype) + assert bias.dtype in [q.dtype, torch.float], f"bias dtype {bias.dtype} must match q dtype {q.dtype} or be float" + assert bias.is_cuda, "bias must be on CUDA" + assert bias.dim() == 4, f"bias must be 4D, got {bias.dim()}D" + assert bias.shape == (batch, nheads, seqlen_q, seqlen_k), f"bias shape {bias.shape} must be (batch={batch}, nheads={nheads}, seqlen_q={seqlen_q}, seqlen_k={seqlen_k})" + if bias.stride(-1) != 1: + bias = bias.contiguous() softmax_scale = softmax_scale or 1.0 / math.sqrt(d) From 11d0a18b8eb4521606d2308eb460c65a6892aa3d Mon Sep 17 00:00:00 2001 From: Loser Cheems Date: Sat, 5 Jul 2025 11:20:37 +0800 Subject: [PATCH 05/13] Adds backward pass kernel for flash attention Implements the backward pass computation with Triton autotuning support. Includes configuration options for block sizes and sequence parallelism with optimized settings for different scenarios. The kernel supports both sequential and parallel execution modes based on the SEQUENCE_PARALLEL flag. Adds proper memory stride handling and atomic operations for gradient accumulation in parallel mode. --- flash_dmattn/flash_dmattn_triton.py | 189 ++++++++++++++++++++++++++++ 1 file changed, 189 insertions(+) diff --git a/flash_dmattn/flash_dmattn_triton.py b/flash_dmattn/flash_dmattn_triton.py index 00d0470..ae56882 100644 --- a/flash_dmattn/flash_dmattn_triton.py +++ b/flash_dmattn/flash_dmattn_triton.py @@ -258,6 +258,195 @@ def _fwd_kernel( ) +@triton.autotune( + configs=[ + triton.Config( + {"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False}, + num_warps=8, + num_stages=1, + pre_hook=init_to_zero("DQ"), + ), + triton.Config( + {"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True}, + num_warps=8, + num_stages=1, + pre_hook=init_to_zero("DQ"), + ), + # Other configs seem to give wrong results when seqlen_q % 128 != 0, disabling them for now + # # Kernel is buggy (give wrong result) if we set BLOCK_m=128, BLOCK_n=64, num_warps=*4* + # triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False}, num_warps=8, num_stages=1, pre_hook=init_to_zero('DQ')), + # triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True}, num_warps=8, num_stages=1, pre_hook=init_to_zero('DQ')), + # triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False}, num_warps=4, num_stages=1, pre_hook=init_to_zero('DQ')), + # triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True}, num_warps=4, num_stages=1, pre_hook=init_to_zero('DQ')), + ], + key=["CACHE_KEY_SEQLEN_Q", "CACHE_KEY_SEQLEN_K", "IS_CAUSAL", "BLOCK_HEADDIM"], +) +@triton.heuristics( + { + "EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0, + "EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0, + "EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"], + } +) +@triton.jit +def _bwd_kernel( + Q, + K, + V, + Mask, + Bias, + DO, + DQ, + DK, + DV, + DBias, + LSE, + D, + softmax_scale, + stride_qb, + stride_qh, + stride_qm, + stride_kb, + stride_kh, + stride_kn, + stride_vb, + stride_vh, + stride_vn, + stride_mb, + stride_mh, + stride_mm, + stride_bb, + stride_bh, + stride_bm, + stride_dob, + stride_doh, + stride_dom, + stride_dqb, + stride_dqh, + stride_dqm, + stride_dkb, + stride_dkh, + stride_dkn, + stride_dvb, + stride_dvh, + stride_dvn, + stride_dbiasb, + stride_dbiash, + stride_dbiasm, + nheads, + seqlen_q, + seqlen_k, + seqlen_q_rounded, + headdim, + CACHE_KEY_SEQLEN_Q, + CACHE_KEY_SEQLEN_K, + IS_CAUSAL: tl.constexpr, + BLOCK_HEADDIM: tl.constexpr, + SEQUENCE_PARALLEL: tl.constexpr, + EVEN_M: tl.constexpr, + EVEN_N: tl.constexpr, + EVEN_HEADDIM: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): + off_hb = tl.program_id(1) + off_b = off_hb // nheads + off_h = off_hb % nheads + # offset pointers for batch/head + Q += off_b * stride_qb + off_h * stride_qh + K += off_b * stride_kb + off_h * stride_kh + V += off_b * stride_vb + off_h * stride_vh + Mask += off_b * stride_mb + off_h * stride_mh + Bias += off_b * stride_bb + off_h * stride_bh + DO += off_b * stride_dob + off_h * stride_doh + DQ += off_b * stride_dqb + off_h * stride_dqh + DK += off_b * stride_dkb + off_h * stride_dkh + DV += off_b * stride_dvb + off_h * stride_dvh + DBias += off_b * stride_dbiasb + off_h * stride_dbiash + # pointer to row-wise quantities in value-like data + D += off_hb * seqlen_q_rounded + LSE += off_hb * seqlen_q_rounded + if not SEQUENCE_PARALLEL: + num_block_n = tl.cdiv(seqlen_k, BLOCK_N) + for start_n in range(0, num_block_n): + _bwd_kernel_one_col_block( + start_n, + Q, + K, + V, + Mask, + Bias, + DO, + DQ, + DK, + DV, + DBias, + LSE, + D, + softmax_scale, + stride_qm, + stride_kn, + stride_vn, + stride_mm, + stride_bm, + stride_dom, + stride_dqm, + stride_dkn, + stride_dvn, + stride_dbiasm, + seqlen_q, + seqlen_k, + headdim, + ATOMIC_ADD=False, + IS_CAUSAL=IS_CAUSAL, + BLOCK_HEADDIM=BLOCK_HEADDIM, + EVEN_M=EVEN_M, + EVEN_N=EVEN_N, + EVEN_HEADDIM=EVEN_HEADDIM, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + ) + else: + start_n = tl.program_id(0) + _bwd_kernel_one_col_block( + start_n, + Q, + K, + V, + Mask, + Bias, + DO, + DQ, + DK, + DV, + DBias, + LSE, + D, + softmax_scale, + stride_qm, + stride_kn, + stride_vn, + stride_mm, + stride_bm, + stride_dom, + stride_dqm, + stride_dkn, + stride_dvn, + stride_dbiasm, + seqlen_q, + seqlen_k, + headdim, + ATOMIC_ADD=True, + IS_CAUSAL=IS_CAUSAL, + BLOCK_HEADDIM=BLOCK_HEADDIM, + EVEN_M=EVEN_M, + EVEN_N=EVEN_N, + EVEN_HEADDIM=EVEN_HEADDIM, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + ) + + def _flash_attn_forward(q, k, v, mask, bias, causal=False, softmax_scale=None): # shape constraints batch, seqlen_q, nheads, d = q.shape From 97b155658164f63d3e5e17faf559c076966ce372 Mon Sep 17 00:00:00 2001 From: Loser Cheems Date: Sat, 5 Jul 2025 11:21:55 +0800 Subject: [PATCH 06/13] Adds utility function for tensor initialization Introduces a helper function that returns a lambda for zeroing tensors by name. This utility simplifies tensor initialization patterns in the codebase by providing a reusable function that can zero out named tensor arguments. --- flash_dmattn/flash_dmattn_triton.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/flash_dmattn/flash_dmattn_triton.py b/flash_dmattn/flash_dmattn_triton.py index ae56882..7e773f5 100644 --- a/flash_dmattn/flash_dmattn_triton.py +++ b/flash_dmattn/flash_dmattn_triton.py @@ -258,6 +258,10 @@ def _fwd_kernel( ) +def init_to_zero(name): + return lambda nargs: nargs[name].zero_() + + @triton.autotune( configs=[ triton.Config( From 31ca7802002b76c4775022291bd070a4f179f530 Mon Sep 17 00:00:00 2001 From: Loser Cheems Date: Sat, 5 Jul 2025 11:22:34 +0800 Subject: [PATCH 07/13] Adds backward preprocessing kernel for dot product computation Implements a new Triton kernel for the backward pass preprocessing step that computes the dot product between output and output gradients. This kernel calculates the delta values needed for the backward pass by performing element-wise multiplication and reduction along the head dimension, which is a common operation in attention mechanism gradients. --- flash_dmattn/flash_dmattn_triton.py | 45 +++++++++++++++++++++++++++++ 1 file changed, 45 insertions(+) diff --git a/flash_dmattn/flash_dmattn_triton.py b/flash_dmattn/flash_dmattn_triton.py index 7e773f5..49891ce 100644 --- a/flash_dmattn/flash_dmattn_triton.py +++ b/flash_dmattn/flash_dmattn_triton.py @@ -258,6 +258,51 @@ def _fwd_kernel( ) +@triton.jit +def _bwd_preprocess_do_o_dot( + Out, + DO, + Delta, + stride_ob, + stride_oh, + stride_om, + stride_dob, + stride_doh, + stride_dom, + nheads, + seqlen_q, + seqlen_q_rounded, + headdim, + BLOCK_M: tl.constexpr, + BLOCK_HEADDIM: tl.constexpr, +): + start_m = tl.program_id(0) + off_hb = tl.program_id(1) + off_b = off_hb // nheads + off_h = off_hb % nheads + # initialize offsets + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_d = tl.arange(0, BLOCK_HEADDIM) + # load + o = tl.load( + Out + off_b * stride_ob + off_h * stride_oh + offs_m[:, None] * stride_om + offs_d[None, :], + mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), + other=0.0, + ).to(tl.float32) + do = tl.load( + DO + + off_b * stride_dob + + off_h * stride_doh + + offs_m[:, None] * stride_dom + + offs_d[None, :], + mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), + other=0.0, + ).to(tl.float32) + delta = tl.sum(o * do, axis=1) + # write-back + tl.store(Delta + off_hb * seqlen_q_rounded + offs_m, delta) + + def init_to_zero(name): return lambda nargs: nargs[name].zero_() From 3e9b96684b0ce4848562a93f35760613578b693c Mon Sep 17 00:00:00 2001 From: Loser Cheems Date: Sat, 5 Jul 2025 11:23:09 +0800 Subject: [PATCH 08/13] Adds helper function for storing gradients safely Introduces a dedicated function to handle storing dk and dv gradients with proper masking logic to prevent race conditions. The function handles different combinations of EVEN_M, EVEN_N, and EVEN_HEADDIM flags to apply appropriate masks during tensor stores, addressing a known race condition bug when certain dimension conditions are met. --- flash_dmattn/flash_dmattn_triton.py | 32 +++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/flash_dmattn/flash_dmattn_triton.py b/flash_dmattn/flash_dmattn_triton.py index 49891ce..9e80674 100644 --- a/flash_dmattn/flash_dmattn_triton.py +++ b/flash_dmattn/flash_dmattn_triton.py @@ -303,6 +303,38 @@ def _bwd_preprocess_do_o_dot( tl.store(Delta + off_hb * seqlen_q_rounded + offs_m, delta) +@triton.jit +def _bwd_store_dk_dv( + dk_ptrs, + dv_ptrs, + dk, + dv, + offs_n, + offs_d, + seqlen_k, + headdim, + EVEN_M: tl.constexpr, + EVEN_N: tl.constexpr, + EVEN_HEADDIM: tl.constexpr, +): + # [2022-11-01] TD: Same bug. In the case of EVEN_N=True and EVEN_M=False, + # if we just call tl.store(dv_ptrs), there's a race condition + if EVEN_N & EVEN_M: + if EVEN_HEADDIM: + tl.store(dv_ptrs, dv) + tl.store(dk_ptrs, dk) + else: + tl.store(dv_ptrs, dv, mask=offs_d[None, :] < headdim) + tl.store(dk_ptrs, dk, mask=offs_d[None, :] < headdim) + else: + if EVEN_HEADDIM: + tl.store(dv_ptrs, dv, mask=offs_n[:, None] < seqlen_k) + tl.store(dk_ptrs, dk, mask=offs_n[:, None] < seqlen_k) + else: + tl.store(dv_ptrs, dv, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim)) + tl.store(dk_ptrs, dk, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim)) + + def init_to_zero(name): return lambda nargs: nargs[name].zero_() From 5be7ee8c190e17c30f57142eb2d2661064c25d01 Mon Sep 17 00:00:00 2001 From: Loser Cheems Date: Sat, 5 Jul 2025 21:00:41 +0800 Subject: [PATCH 09/13] Fixes dtype consistency for mask and bias tensors Ensures mask and bias tensors match the dtype of query tensor when created as defaults. Prevents potential type mismatches that could cause runtime errors or unexpected behavior in attention computations. --- flash_dmattn/flash_dmattn_triton.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/flash_dmattn/flash_dmattn_triton.py b/flash_dmattn/flash_dmattn_triton.py index 9e80674..7fa32b3 100644 --- a/flash_dmattn/flash_dmattn_triton.py +++ b/flash_dmattn/flash_dmattn_triton.py @@ -749,9 +749,9 @@ def forward(ctx, q, k, v, mask=None, bias=None, causal=False, softmax_scale=None if mask.dtype == torch.bool: mask = torch.where(mask, 1.0, 0.0) else: - mask = torch.ones((batch, nheads, seqlen_q, seqlen_k), device=q.device) + mask = torch.ones((batch, nheads, seqlen_q, seqlen_k), device=q.device, dtype=q.dtype) if bias is None: - bias = torch.zeros((batch, nheads, seqlen_q, seqlen_k), device=q.device) + bias = torch.zeros((batch, nheads, seqlen_q, seqlen_k), device=q.device, dtype=q.dtype) # Make sure that the last dimension is contiguous q, k, v, mask, bias = [x if x.stride(-1) == 1 else x.contiguous() for x in [q, k, v, mask, bias]] From 1cb291ab7a85819cb160c785ed649bc37f655c08 Mon Sep 17 00:00:00 2001 From: Loser Cheems Date: Sat, 5 Jul 2025 21:03:31 +0800 Subject: [PATCH 10/13] Renames bias stride parameters for consistency Standardizes parameter naming convention by shortening bias-related stride variable names from `stride_dbias*` to `stride_db*` format. Improves code readability and maintains consistency with other stride parameter naming patterns throughout the kernel function. --- flash_dmattn/flash_dmattn_triton.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/flash_dmattn/flash_dmattn_triton.py b/flash_dmattn/flash_dmattn_triton.py index 7fa32b3..90d3d6b 100644 --- a/flash_dmattn/flash_dmattn_triton.py +++ b/flash_dmattn/flash_dmattn_triton.py @@ -411,9 +411,9 @@ def _bwd_kernel( stride_dvb, stride_dvh, stride_dvn, - stride_dbiasb, - stride_dbiash, - stride_dbiasm, + stride_dbb, + stride_dbh, + stride_dbm, nheads, seqlen_q, seqlen_k, @@ -443,7 +443,7 @@ def _bwd_kernel( DQ += off_b * stride_dqb + off_h * stride_dqh DK += off_b * stride_dkb + off_h * stride_dkh DV += off_b * stride_dvb + off_h * stride_dvh - DBias += off_b * stride_dbiasb + off_h * stride_dbiash + DBias += off_b * stride_dbb + off_h * stride_dbh # pointer to row-wise quantities in value-like data D += off_hb * seqlen_q_rounded LSE += off_hb * seqlen_q_rounded @@ -474,7 +474,7 @@ def _bwd_kernel( stride_dqm, stride_dkn, stride_dvn, - stride_dbiasm, + stride_dbm, seqlen_q, seqlen_k, headdim, @@ -513,7 +513,7 @@ def _bwd_kernel( stride_dqm, stride_dkn, stride_dvn, - stride_dbiasm, + stride_dbm, seqlen_q, seqlen_k, headdim, From 65210fa19b79a8f5bbbc48dbfc3b9bbb375a57e0 Mon Sep 17 00:00:00 2001 From: Loser Cheems Date: Sat, 5 Jul 2025 21:04:11 +0800 Subject: [PATCH 11/13] Removes conditional computation to fix Triton compilation Eliminates runtime conditional checks that prevented Triton from properly optimizing control flow at compile time. Moves key and value loading outside of conditional blocks to ensure consistent execution paths, which allows the compiler to make better optimization decisions. Removes the any_active check that was causing dynamic branching issues and simplifies the masking logic to use compile-time determinable conditions. --- flash_dmattn/flash_dmattn_triton.py | 113 ++++++++++++++-------------- 1 file changed, 57 insertions(+), 56 deletions(-) diff --git a/flash_dmattn/flash_dmattn_triton.py b/flash_dmattn/flash_dmattn_triton.py index 90d3d6b..02fb401 100644 --- a/flash_dmattn/flash_dmattn_triton.py +++ b/flash_dmattn/flash_dmattn_triton.py @@ -12,7 +12,7 @@ # # This config has a race condition when EVEN_M == False, disabling it for now. # # triton.Config({"BLOCK_M": 64, "BLOCK_N": 64}, num_warps=4, num_stages=1), # ], -# key=['CACHE_KEY_SEQLEN_Q', 'CACHE_KEY_SEQLEN_K', 'BIAS_TYPE', 'IS_CAUSAL', 'BLOCK_HEADDIM'] +# key=['CACHE_KEY_SEQLEN_Q', 'CACHE_KEY_SEQLEN_K', 'IS_CAUSAL', 'BLOCK_HEADDIM'] # ) @triton.heuristics( { @@ -120,6 +120,38 @@ def _fwd_kernel( end_n = seqlen_k if not IS_CAUSAL else tl.minimum((start_m + 1) * BLOCK_M, seqlen_k) for start_n in range(0, end_n, BLOCK_N): start_n = tl.multiple_of(start_n, BLOCK_N) + + # Load k + if EVEN_N & EVEN_M: # If we just do "if EVEN_N", there seems to be some race condition + if EVEN_HEADDIM: + k = tl.load(k_ptrs + start_n * stride_kn) + else: + k = tl.load(k_ptrs + start_n * stride_kn, mask=offs_d[None, :] < headdim, other=0.0) + else: + if EVEN_HEADDIM: + k = tl.load( + k_ptrs + start_n * stride_kn, + mask=(start_n + offs_n)[:, None] < seqlen_k, + other=0.0, + ) + else: + k = tl.load( + k_ptrs + start_n * stride_kn, + mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim), + other=0.0, + ) + + # compute acc_s + acc_s = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + acc_s += tl.dot(q, tl.trans(k)) + + # Trying to combine the two masks seem to make the result wrong + # Apply sequence length mask + if not EVEN_N: # Need to mask out otherwise the softmax is wrong + acc_s += tl.where((start_n + offs_n)[None, :] < seqlen_k, 0, float("-inf")) + # Apply causal mask + if IS_CAUSAL: + acc_s += tl.where(offs_m[:, None] >= (start_n + offs_n)[None, :], 0, float("-inf")) # Load mask if EVEN_M & EVEN_N: @@ -132,41 +164,11 @@ def _fwd_kernel( ) # Check if any element in mask is non-zero - any_active = tl.sum(mask) > 0 - - # compute acc_s - acc_s = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - if any_active: - # Load k - if EVEN_N & EVEN_M: # If we just do "if EVEN_N", there seems to be some race condition - if EVEN_HEADDIM: - k = tl.load(k_ptrs + start_n * stride_kn) - else: - k = tl.load(k_ptrs + start_n * stride_kn, mask=offs_d[None, :] < headdim, other=0.0) - else: - if EVEN_HEADDIM: - k = tl.load( - k_ptrs + start_n * stride_kn, - mask=(start_n + offs_n)[:, None] < seqlen_k, - other=0.0, - ) - else: - k = tl.load( - k_ptrs + start_n * stride_kn, - mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim), - other=0.0, - ) - acc_s += tl.dot(q, tl.trans(k)) - - # Trying to combine the two masks seem to make the result wrong - # Apply sequence length mask - if not EVEN_N: # Need to mask out otherwise the softmax is wrong - acc_s += tl.where((start_n + offs_n)[None, :] < seqlen_k, 0, float("-inf")) - # Apply causal mask - if IS_CAUSAL: - acc_s += tl.where(offs_m[:, None] >= (start_n + offs_n)[None, :], 0, float("-inf")) + # BUG: Triton needs to determine the control flow at compile time. + # Dynamic conditions at runtime can undermine this optimization. + # any_active = tl.sum(mask) != 0 # Apply dynamic mask - acc_s += tl.where(mask > 0, 0, float("-inf")) + acc_s += tl.where(mask > 0.0, 0.0, float("-inf")) # Load bias if EVEN_M & EVEN_N: @@ -184,7 +186,7 @@ def _fwd_kernel( # can then fuse the mult and add into an fma instruction. But if we have bias we need to # to multiply with softmax_scale here. acc_s = acc_s * softmax_scale + bias - acc_s = tl.where(acc_s != float("-inf"), acc_s * softmax_scale + bias, acc_s) + # acc_s = tl.where(acc_s != float("-inf"), acc_s * softmax_scale + bias, acc_s) m_ij = tl.maximum(tl.max(acc_s, 1), lse_i) p = tl.exp(acc_s - m_ij[:, None]) l_ij = tl.sum(p, 1) @@ -198,27 +200,26 @@ def _fwd_kernel( acc_o_scale = tl.load(t_ptrs) acc_o = acc_o * acc_o_scale[:, None] - if any_active: - # load v - if EVEN_N & EVEN_M: # If we just do "if EVEN_N", there seems to be some race condition - if EVEN_HEADDIM: - v = tl.load(v_ptrs + start_n * stride_vn) - else: - v = tl.load(v_ptrs + start_n * stride_vn, mask=offs_d[None, :] < headdim, other=0.0) + # load v + if EVEN_N & EVEN_M: # If we just do "if EVEN_N", there seems to be some race condition + if EVEN_HEADDIM: + v = tl.load(v_ptrs + start_n * stride_vn) + else: + v = tl.load(v_ptrs + start_n * stride_vn, mask=offs_d[None, :] < headdim, other=0.0) + else: + if EVEN_HEADDIM: + v = tl.load( + v_ptrs + start_n * stride_vn, + mask=(start_n + offs_n)[:, None] < seqlen_k, + other=0.0, + ) else: - if EVEN_HEADDIM: - v = tl.load( - v_ptrs + start_n * stride_vn, - mask=(start_n + offs_n)[:, None] < seqlen_k, - other=0.0, - ) - else: - v = tl.load( - v_ptrs + start_n * stride_vn, - mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim), - other=0.0, - ) - acc_o += tl.dot(p.to(v.dtype), v) + v = tl.load( + v_ptrs + start_n * stride_vn, + mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim), + other=0.0, + ) + acc_o += tl.dot(p.to(v.dtype), v) # update statistics m_i = m_ij From 231efeed2cc8708d91cb07b16628388f8ccf7922 Mon Sep 17 00:00:00 2001 From: Loser Cheems Date: Sat, 5 Jul 2025 21:10:39 +0800 Subject: [PATCH 12/13] Returns dbias gradient in backward pass Fixes the backward function to properly return the bias gradient instead of None, ensuring gradient computation flows correctly through the bias parameter during backpropagation. --- flash_dmattn/flash_dmattn_triton.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flash_dmattn/flash_dmattn_triton.py b/flash_dmattn/flash_dmattn_triton.py index 02fb401..254b242 100644 --- a/flash_dmattn/flash_dmattn_triton.py +++ b/flash_dmattn/flash_dmattn_triton.py @@ -790,7 +790,7 @@ def backward(ctx, do): causal=ctx.causal, softmax_scale=ctx.softmax_scale, ) - return dq, dk, dv, None, None, None, None + return dq, dk, dv, None, dbias, None, None flash_dmattn_func = FlashAttnFunc.apply From 143e97647376e6a07c52561bfb4b31d4798a1adf Mon Sep 17 00:00:00 2001 From: Loser Cheems Date: Sat, 5 Jul 2025 21:15:02 +0800 Subject: [PATCH 13/13] Renames class to match function naming convention Updates class name from FlashAttnFunc to FlashDMAttnFunc for consistency with the flash_dmattn_func function name and module purpose. Also updates corresponding error message text to use the shortened "FlashDMAttn" naming convention. --- flash_dmattn/flash_dmattn_triton.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/flash_dmattn/flash_dmattn_triton.py b/flash_dmattn/flash_dmattn_triton.py index 254b242..41bfd28 100644 --- a/flash_dmattn/flash_dmattn_triton.py +++ b/flash_dmattn/flash_dmattn_triton.py @@ -732,7 +732,7 @@ def _flash_attn_backward( dq.copy_(dq_accum) -class FlashAttnFunc(torch.autograd.Function): +class FlashDMAttnFunc(torch.autograd.Function): @staticmethod def forward(ctx, q, k, v, mask=None, bias=None, causal=False, softmax_scale=None): """ @@ -766,7 +766,7 @@ def forward(ctx, q, k, v, mask=None, bias=None, causal=False, softmax_scale=None @staticmethod def backward(ctx, do): q, k, v, o, lse, mask, bias = ctx.saved_tensors - assert not ctx.needs_input_grad[3], "FlashDynamicMaskAttention does not support mask gradient yet" + assert not ctx.needs_input_grad[3], "FlashDMAttn does not support mask gradient yet" # Triton's autotune causes the Tensor._version to change, and so Pytorch autograd # does a memcpy. To avoid this we run in inference_mode, which doesn't track the version. with torch.inference_mode(): @@ -793,4 +793,4 @@ def backward(ctx, do): return dq, dk, dv, None, dbias, None, None -flash_dmattn_func = FlashAttnFunc.apply +flash_dmattn_func = FlashDMAttnFunc.apply