From 8f196ebe19d56d2685dd8c7bd87b35f76e540147 Mon Sep 17 00:00:00 2001 From: Loser Cheems Date: Mon, 7 Jul 2025 16:44:46 +0800 Subject: [PATCH 1/3] Adds bias gradient computation to backward kernel Implements bias gradient calculation in the backward pass by adding a new column-block kernel that computes DBias alongside existing DQ, DK, and DV gradients. Updates initialization function to support multiple tensor names and modifies autotuning configurations to initialize both DQ and DBias tensors. Includes extensive masking logic and memory access patterns to handle various sequence length and head dimension configurations while maintaining numerical stability. --- flash_dmattn/flash_dmattn_triton.py | 333 +++++++++++++++++++++++++++- 1 file changed, 325 insertions(+), 8 deletions(-) diff --git a/flash_dmattn/flash_dmattn_triton.py b/flash_dmattn/flash_dmattn_triton.py index 41bfd28..2f7fc3a 100644 --- a/flash_dmattn/flash_dmattn_triton.py +++ b/flash_dmattn/flash_dmattn_triton.py @@ -336,8 +336,325 @@ def _bwd_store_dk_dv( tl.store(dk_ptrs, dk, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim)) -def init_to_zero(name): - return lambda nargs: nargs[name].zero_() +@triton.jit +def _bwd_kernel_one_col_block( + start_n, + Q, + K, + V, + Mask, + Bias, + DO, + DQ, + DK, + DV, + DBias, + LSE, + D, + softmax_scale, + stride_qm, + stride_kn, + stride_vn, + stride_mm, + stride_bm, + stride_dom, + stride_dqm, + stride_dkn, + stride_dvn, + stride_dbm, + seqlen_q, + seqlen_k, + headdim, + ATOMIC_ADD: tl.constexpr, + IS_CAUSAL: tl.constexpr, + BLOCK_HEADDIM: tl.constexpr, + EVEN_M: tl.constexpr, + EVEN_N: tl.constexpr, + EVEN_HEADDIM: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): + # We need to make sure begin_m is a multiple of BLOCK_M (not BLOCK_N) + begin_m = 0 if not IS_CAUSAL else ((start_n * BLOCK_N) // BLOCK_M) * BLOCK_M + # initialize row/col offsets + offs_qm = begin_m + tl.arange(0, BLOCK_M) + offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N) + offs_m = tl.arange(0, BLOCK_M) + offs_d = tl.arange(0, BLOCK_HEADDIM) + # initialize pointers to value-like data + q_ptrs = Q + (offs_qm[:, None] * stride_qm + offs_d[None, :]) + k_ptrs = K + (offs_n[:, None] * stride_kn + offs_d[None, :]) + v_ptrs = V + (offs_n[:, None] * stride_vn + offs_d[None, :]) + m_ptrs = Mask + (offs_qm[:, None] * stride_mm + offs_n[None, :]) + b_ptrs = Bias + (offs_qm[:, None] * stride_bm + offs_n[None, :]) + do_ptrs = DO + (offs_qm[:, None] * stride_dom + offs_d[None, :]) + dq_ptrs = DQ + (offs_qm[:, None] * stride_dqm + offs_d[None, :]) + db_ptrs = DBias + (offs_qm[:, None] * stride_dbm + offs_n[None, :]) + # initialize dv and dk + dv = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32) + dk = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32) + # There seems to be some problem with Triton pipelining that makes results wrong for + # headdim=64, seqlen=(113, 255). In this case the for loop may have zero step, + # and pipelining with the bias matrix could screw it up. So we just exit early. + if begin_m >= seqlen_q: + dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_d[None, :]) + dk_ptrs = DK + (offs_n[:, None] * stride_dkn + offs_d[None, :]) + _bwd_store_dk_dv( + dk_ptrs, + dv_ptrs, + dk, + dv, + offs_n, + offs_d, + seqlen_k, + headdim, + EVEN_M=EVEN_M, + EVEN_N=EVEN_N, + EVEN_HEADDIM=EVEN_HEADDIM, + ) + return + # k and v stay in SRAM throughout + # [2022-10-30] TD: Same bug as the fwd. In the case of EVEN_N=True and EVEN_M=False, + # if we just call tl.load(k_ptrs), we get the wrong output! + if EVEN_N & EVEN_M: + if EVEN_HEADDIM: + k = tl.load(k_ptrs) + v = tl.load(v_ptrs) + else: + k = tl.load(k_ptrs, mask=offs_d[None, :] < headdim, other=0.0) + v = tl.load(v_ptrs, mask=offs_d[None, :] < headdim, other=0.0) + else: + if EVEN_HEADDIM: + k = tl.load(k_ptrs, mask=offs_n[:, None] < seqlen_k, other=0.0) + v = tl.load(v_ptrs, mask=offs_n[:, None] < seqlen_k, other=0.0) + else: + k = tl.load( + k_ptrs, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim), other=0.0 + ) + v = tl.load( + v_ptrs, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim), other=0.0 + ) + # loop over rows + num_block_m = tl.cdiv(seqlen_q, BLOCK_M) + for start_m in range(begin_m, num_block_m * BLOCK_M, BLOCK_M): + start_m = tl.multiple_of(start_m, BLOCK_M) + offs_m_curr = start_m + offs_m + # load q, k, v, do on-chip + # Same bug as below. Otherwise gives wrong result for headdim=40, seqlen=(128, 117) + if EVEN_M & EVEN_HEADDIM: + q = tl.load(q_ptrs) + else: + if EVEN_HEADDIM: + q = tl.load(q_ptrs, mask=offs_m_curr[:, None] < seqlen_q, other=0.0) + else: + q = tl.load( + q_ptrs, + mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim), + other=0.0, + ) + # recompute p = softmax(acc_s, dim=-1).T + acc_s = tl.dot(q, tl.trans(k)) + + tl.debug_barrier() + # Load mask + if EVEN_M & EVEN_N: + mask = tl.load(m_ptrs) + else: + mask = tl.load( + m_ptrs, + mask=(offs_m_curr[:, None] < seqlen_q) & (offs_n[None, :] < seqlen_k), + other=0.0, + ) + + # Trying to combine the two masks seem to make the result wrong + # Apply sequence length mask + if not EVEN_N: # Need to mask out otherwise the softmax is wrong + acc_s = tl.where(offs_n[None, :] < seqlen_k, acc_s, float("-inf")) + # Apply causal mask + if IS_CAUSAL: + acc_s = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), acc_s, float("-inf")) + # Apply dynamic mask + acc_s = tl.where(mask > 0.0, acc_s, float("-inf")) + + tl.debug_barrier() # Race condition otherwise + # Load bias + if EVEN_M & EVEN_N: + bias = tl.load( + b_ptrs, + mask=(mask > 0.0), + other=0.0, + ).to(tl.float32) + else: + bias = tl.load( + b_ptrs, + mask=(offs_m_curr[:, None] < seqlen_q) & (offs_n[None, :] < seqlen_k) & (mask > 0.0), + other=0.0, + ).to(tl.float32) + acc_s = acc_s * softmax_scale + bias + # There seems to be a race condition when headdim=48/96, and dq, dk, dv are wrong. + # Also wrong for headdim=64. + if not (EVEN_M & EVEN_HEADDIM): + tl.debug_barrier() + lse_i = tl.load(LSE + offs_m_curr) + p = tl.exp(acc_s - lse_i[:, None]) + # compute dv + # [2022-10-30] TD: A Triton bug: if EVEN_M=True and EVEN_HEADDIM=False, if we call + # do = tl.load(do_ptrs, mask=offs_d[None, :] < headdim, other=0.0), we get wrong outputs + # in the case of headdim=48/96, seqlen_q & seqlen_k >= 512. If headdim=40 or seqlen < 512, + # the output is correct. + if EVEN_M & EVEN_HEADDIM: + do = tl.load(do_ptrs) + else: + # [2022-11-01] TD: Triton bug, there's a race condition if we just use m_mask and not d_mask. + do = tl.load( + do_ptrs, + mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim), + other=0.0, + ) + # if EVEN_M: + # if EVEN_HEADDIM: + # do = tl.load(do_ptrs) + # else: + # do = tl.load(do_ptrs, mask=offs_d[None, :] < headdim, other=0.0) + # else: + # if EVEN_HEADDIM: + # do = tl.load(do_ptrs, mask=offs_m_curr[:, None] < seqlen_q, other=0.0) + # else: + # do = tl.load(do_ptrs, mask=(offs_m_curr[:, None] < seqlen_q) + # & (offs_d[None, :] < headdim), other=0.0) + dv += tl.dot(tl.trans(p.to(do.dtype)), do) + # compute dp = dot(v, do) + # There seems to be a race condition when headdim=48/96, and dq, dk are wrong. + # Also wrong for headdim=128, seqlen=(108, 256), and ATOMIC_ADD=True + # Also wrong for headdim=64, seqlen=(1023, 1024), and ATOMIC_ADD=False + if not (EVEN_M & EVEN_HEADDIM): + tl.debug_barrier() + dp = tl.dot(do, tl.trans(v)) + # There's a race condition for headdim=48 + if not EVEN_HEADDIM: + tl.debug_barrier() + # compute dbias = p * (dp - delta[:, None]) and ds = dbias * softmax_scale + # Putting the subtraction after the dp matmul (instead of before) is slightly faster + Di = tl.load(D + offs_m_curr) + # Converting ds to q.dtype here reduces register pressure and makes it much faster + # for BLOCK_HEADDIM=128 + dbias = (p * (dp - Di[:, None])) + ds = (dbias * softmax_scale).to(q.dtype) + # dbias = tl.where(mask > 0.0, dbias, 0.0) + # ds = (p * (dp - Di[:, None]) * softmax_scale).to(q.dtype) + if not (EVEN_M & EVEN_N): + tl.debug_barrier() + if not ATOMIC_ADD: + if EVEN_M & EVEN_N: + tl.store( + db_ptrs, + dbias + ) + else: + tl.store( + db_ptrs, + dbias, + mask=(offs_m_curr[:, None] < seqlen_q) & (offs_n[None, :] < seqlen_k) + ) + else: + if EVEN_M & EVEN_N: + tl.atomic_add( + db_ptrs, + dbias + ) + else: + tl.atomic_add( + db_ptrs, + dbias, + mask=(offs_m_curr[:, None] < seqlen_q) & (offs_n[None, :] < seqlen_k) + ) + # compute dk = dot(ds.T, q) + dk += tl.dot(tl.trans(ds), q) + # compute dq + if not ( + EVEN_M & EVEN_HEADDIM + ): # Otherewise there's a race condition + tl.debug_barrier() + if not ATOMIC_ADD: + if EVEN_M & EVEN_HEADDIM: # Race condition if we just do EVEN_M + dq = tl.load(dq_ptrs, eviction_policy="evict_last") + dq += tl.dot(ds, k) + tl.store(dq_ptrs, dq, eviction_policy="evict_last") + else: + if EVEN_HEADDIM: + dq = tl.load( + dq_ptrs, + mask=offs_m_curr[:, None] < seqlen_q, + other=0.0, + eviction_policy="evict_last", + ) + dq += tl.dot(ds, k) + tl.store( + dq_ptrs, + dq, + mask=offs_m_curr[:, None] < seqlen_q, + eviction_policy="evict_last", + ) + else: + dq = tl.load( + dq_ptrs, + mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim), + other=0.0, + eviction_policy="evict_last", + ) + dq += tl.dot(ds, k) + tl.store( + dq_ptrs, + dq, + mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim), + eviction_policy="evict_last", + ) + else: # If we're parallelizing across the seqlen_k dimension + dq = tl.dot(ds, k) + if EVEN_M & EVEN_HEADDIM: # Race condition if we just do EVEN_M + tl.atomic_add(dq_ptrs, dq) + else: + if EVEN_HEADDIM: + tl.atomic_add(dq_ptrs, dq, mask=offs_m_curr[:, None] < seqlen_q) + else: + tl.atomic_add( + dq_ptrs, + dq, + mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim), + ) + # increment pointers + do_ptrs += BLOCK_M * stride_dom + dq_ptrs += BLOCK_M * stride_dqm + db_ptrs += BLOCK_M * stride_dbm + q_ptrs += BLOCK_M * stride_qm + m_ptrs += BLOCK_M * stride_mm + b_ptrs += BLOCK_M * stride_bm + + # write-back + dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_d[None, :]) + dk_ptrs = DK + (offs_n[:, None] * stride_dkn + offs_d[None, :]) + _bwd_store_dk_dv( + dk_ptrs, + dv_ptrs, + dk, + dv, + offs_n, + offs_d, + seqlen_k, + headdim, + EVEN_M=EVEN_M, + EVEN_N=EVEN_N, + EVEN_HEADDIM=EVEN_HEADDIM, + ) + + +def init_to_zero(names): + if isinstance(names, str): + names = [names] + def init_func(nargs): + for name in names: + nargs[name].zero_() + return init_func @triton.autotune( @@ -346,20 +663,20 @@ def init_to_zero(name): {"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False}, num_warps=8, num_stages=1, - pre_hook=init_to_zero("DQ"), + pre_hook=init_to_zero(["DQ", "DBias"]), ), triton.Config( {"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True}, num_warps=8, num_stages=1, - pre_hook=init_to_zero("DQ"), + pre_hook=init_to_zero(["DQ", "DBias"]), ), # Other configs seem to give wrong results when seqlen_q % 128 != 0, disabling them for now # # Kernel is buggy (give wrong result) if we set BLOCK_m=128, BLOCK_n=64, num_warps=*4* - # triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False}, num_warps=8, num_stages=1, pre_hook=init_to_zero('DQ')), - # triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True}, num_warps=8, num_stages=1, pre_hook=init_to_zero('DQ')), - # triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False}, num_warps=4, num_stages=1, pre_hook=init_to_zero('DQ')), - # triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True}, num_warps=4, num_stages=1, pre_hook=init_to_zero('DQ')), + # triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False}, num_warps=8, num_stages=1, pre_hook=init_to_zero(['DQ', 'DBias'])), + # triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True}, num_warps=8, num_stages=1, pre_hook=init_to_zero(['DQ', 'DBias'])), + # triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False}, num_warps=4, num_stages=1, pre_hook=init_to_zero(['DQ', 'DBias'])), + # triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True}, num_warps=4, num_stages=1, pre_hook=init_to_zero(['DQ', 'DBias'])), ], key=["CACHE_KEY_SEQLEN_Q", "CACHE_KEY_SEQLEN_K", "IS_CAUSAL", "BLOCK_HEADDIM"], ) From 6ca3bae6c0e831ca33848708f1b41af17e7d6bcc Mon Sep 17 00:00:00 2001 From: Loser Cheems Date: Mon, 7 Jul 2025 16:49:15 +0800 Subject: [PATCH 2/3] Renames function alias for clarity Updates the function alias to better reflect its Triton-based implementation, improving code readability and making the backend technology more explicit for developers. --- flash_dmattn/flash_dmattn_triton.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flash_dmattn/flash_dmattn_triton.py b/flash_dmattn/flash_dmattn_triton.py index 2f7fc3a..7d68c64 100644 --- a/flash_dmattn/flash_dmattn_triton.py +++ b/flash_dmattn/flash_dmattn_triton.py @@ -1110,4 +1110,4 @@ def backward(ctx, do): return dq, dk, dv, None, dbias, None, None -flash_dmattn_func = FlashDMAttnFunc.apply +triton_dmattn_func = FlashDMAttnFunc.apply From 6af00930eef977c52777ee41e73f0ce45af61848 Mon Sep 17 00:00:00 2001 From: Loser Cheems Date: Mon, 7 Jul 2025 16:49:32 +0800 Subject: [PATCH 3/3] Adds alias for flex attention forward function Creates a more convenient function name that follows the module's naming convention and improves code readability --- flash_dmattn/flash_dmattn_flex.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/flash_dmattn/flash_dmattn_flex.py b/flash_dmattn/flash_dmattn_flex.py index c5ce05b..a7b7e2a 100644 --- a/flash_dmattn/flash_dmattn_flex.py +++ b/flash_dmattn/flash_dmattn_flex.py @@ -61,4 +61,6 @@ def causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx): attention_weights = attention_weights.to(value.dtype) attn_output = attn_output.transpose(1, 2).contiguous() - return attn_output, attention_weights \ No newline at end of file + return attn_output, attention_weights + +flex_dmattn_func = flex_attention_forward \ No newline at end of file