From f0aad8614f96f76a2f3d5d5e7f068c2078e27817 Mon Sep 17 00:00:00 2001 From: Loser Cheems Date: Wed, 2 Jul 2025 17:45:08 +0800 Subject: [PATCH 1/5] Removes local attention and ALiBi support from backward kernel Simplifies the backward pass implementation by removing support for local attention patterns and Additive Linear Bias (ALiBi) features. Updates copyright to include new contributor and removes unnecessary template parameters and conditional logic blocks that handled these specialized attention mechanisms. Streamlines the kernel to focus on core functionality without the complexity of windowed attention and positional bias computation. --- csrc/src/flash_bwd_kernel.h | 53 ++++++++++--------------------------- 1 file changed, 14 insertions(+), 39 deletions(-) diff --git a/csrc/src/flash_bwd_kernel.h b/csrc/src/flash_bwd_kernel.h index 50af5f6..d251047 100644 --- a/csrc/src/flash_bwd_kernel.h +++ b/csrc/src/flash_bwd_kernel.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2024, Tri Dao. + * Copyright (c) 2025, Jingze Shi and Tri Dao. ******************************************************************************/ #pragma once @@ -18,8 +18,6 @@ #include "mask.h" #include "dropout.h" -#include "alibi.h" - namespace FLASH_NAMESPACE { using namespace cute; @@ -77,7 +75,7 @@ make_tiled_copy_C_warpcontiguousN(Copy_Atom const& copy_atom, //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const int bidb, const int bidh, const int n_block) { using Element = typename Kernel_traits::Element; @@ -101,9 +99,6 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in if (n_block * kBlockN >= binfo.actual_seqlen_k) return; int m_block_max = cute::ceil_div(binfo.actual_seqlen_q, kBlockM); - if (Is_local) { - m_block_max = std::min(m_block_max, cute::ceil_div((n_block + 1) * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k + params.window_size_left, kBlockM)); - } const index_t row_offset_q = binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb) + (m_block_max - 1) * kBlockM * params.q_row_stride + bidh * params.q_head_stride; @@ -313,9 +308,9 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in tdQgdQaccum.data() = tdQgdQaccum.data() + kBlockM * params.h * params.d_rounded; int m_block = m_block_max - 1; - int m_block_min = (!Is_causal && !Is_local) + int m_block_min = (!Is_causal) ? 0 - : std::max(0, (n_block * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k - params.window_size_right) / kBlockM); + : std::max(0, (n_block * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k) / kBlockM); // If not local, we're guaranteed that m_block_min <= m_block: // We checked earlier that n_block * kBlockN < actual_seqlen_k, so in the causal case, // n_block * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k < actual_seqlen_q. @@ -328,7 +323,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in // Otherwise we get wrong result for the case where we don't enter the for loop. // And we might read OOB elements from gQ and gdO. 
// This also covers the case where actual_seqlen_q == 0 - if ((Is_local || !Is_even_MN) && m_block < m_block_min) { + if ((!Is_even_MN) && m_block < m_block_min) { const index_t row_offset_dk = binfo.k_offset(params.dk_batch_stride, params.dk_row_stride, bidb) + n_block * kBlockN * params.dk_row_stride + bidh * params.dk_head_stride; const index_t row_offset_dv = binfo.k_offset(params.dv_batch_stride, params.dv_row_stride, bidb) @@ -410,8 +405,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in } // We want LSE = inf if the row is OOB. In that case Q would be zero, K would be zero, // and scores would be zero. With LSE = 0, probs will be all 1's, and when we multiply - // with V (which would be zero), we're fine. However, with ALiBi, we might modify these - // scores, and probs can become NaN. Instead if we set LSE = inf for OOB rows, probs are always 0. + // with V (which would be zero), we're fine. // Tensor tKrK = make_fragment_like(tKsK); // // cute::copy(gmem_tiled_copy_QKV, tKgK(_, _, _, 0), tKrK); @@ -449,9 +443,6 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in clear(acc_dv); clear(acc_dk); - const float alibi_slope = !Has_alibi || params.alibi_slopes_ptr == nullptr ? 0.0f : reinterpret_cast(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax; - FLASH_NAMESPACE::Alibi alibi(alibi_slope, binfo.actual_seqlen_k, binfo.actual_seqlen_q); - for (; m_block >= m_block_min; --m_block) { Tensor acc_s = partition_fragment_C(tiled_mma_sdp, Shape, Int>{}); // (MMA=4, MMA_N, MMA_N) clear(acc_s); @@ -486,12 +477,6 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in FLASH_NAMESPACE::calculate_dtanh(scores, dtanh, params.softcap); } - // Alibi - if (Has_alibi) { - alibi.apply_alibi(scores, n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16, - m_block * kBlockM + get<0>(taccScS_row(0)), AtomLayoutMS * 16); - } - // TD [2023-07-29]: I was thinking that we don't need to mask out the elements beyond // actual_seqlen_k, because acc_s would be some finite value for those indices. // In the end when we multiply with K to get dQ, the corresponding values of K would be 0, @@ -499,7 +484,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in // However, it's possible that the values in acc_s are so large that they overflow // when we multiply with dP and convert to fp16, resulting in Inf in dS and NaNs in dQ. // So we need to mask out the elements beyond actual_seqlen_k. 
- if (!Is_causal && !Is_local) { + if (!Is_causal) { if (!Is_even_MN && (n_block + 1) * kBlockN >= binfo.actual_seqlen_k) { FLASH_NAMESPACE::apply_mask(scores, binfo.actual_seqlen_k, n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16); @@ -517,16 +502,6 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in // binfo.actual_seqlen_k, m_block * kBlockM + (tidx / 32) % AtomLayoutMS * 16 + (tidx % 32) / 4, AtomLayoutMS * 16); } - } else if (Is_local) { - if (m_block * kBlockM < (n_block + 1) * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k - params.window_size_right - || (m_block + 1) * kBlockM >= n_block * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k + params.window_size_left - || (!Is_even_MN && (n_block + 1) * kBlockN >= binfo.actual_seqlen_k)) { - FLASH_NAMESPACE::apply_mask_local(scores, n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16, - binfo.actual_seqlen_k, m_block * kBlockM + get<0>(taccScS_row(0)), - binfo.actual_seqlen_q, AtomLayoutMS * 16, - params.window_size_left, params.window_size_right); - } - } // if (cute::thread(32, 0)) { print(scores); } @@ -794,7 +769,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template inline __device__ void compute_dq_dk_dv(const Params ¶ms) { // The block index for the batch. @@ -808,20 +783,20 @@ inline __device__ void compute_dq_dk_dv(const Params ¶ms) { const int n_block_max = (params.seqlen_k + Kernel_traits::kBlockN - 1) / Kernel_traits::kBlockN; if (n_block_max == 1) { - compute_dq_dk_dv_1colblock(params, bidb, bidh, 0); + compute_dq_dk_dv_1colblock(params, bidb, bidh, 0); } else { // Iterating backward from n_block_max - 1 to 0 might save 1 register - compute_dq_dk_dv_1colblock(params, bidb, bidh, n_block_max - 1); + compute_dq_dk_dv_1colblock(params, bidb, bidh, n_block_max - 1); for (int n_block = n_block_max - 2; n_block > 0; n_block--) { - compute_dq_dk_dv_1colblock(params, bidb, bidh, n_block); + compute_dq_dk_dv_1colblock(params, bidb, bidh, n_block); } - compute_dq_dk_dv_1colblock(params, bidb, bidh, 0); + compute_dq_dk_dv_1colblock(params, bidb, bidh, 0); } } //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template inline __device__ void compute_dq_dk_dv_seqk_parallel(const Params ¶ms) { // The block index for the batch. @@ -831,7 +806,7 @@ inline __device__ void compute_dq_dk_dv_seqk_parallel(const Params ¶ms) { // If deterministic, each thread block will do atomicAdd to a different dQ_accum buffer. for (int n_block = blockIdx.x; n_block < (params.seqlen_k + Kernel_traits::kBlockN - 1) / Kernel_traits::kBlockN; n_block += gridDim.x) { - compute_dq_dk_dv_1colblock(params, bidb, bidh, n_block); + compute_dq_dk_dv_1colblock(params, bidb, bidh, n_block); } } From b2d12ac8da1177e8b7c3147ad2a036bea4293b04 Mon Sep 17 00:00:00 2001 From: Loser Cheems Date: Wed, 2 Jul 2025 19:42:55 +0800 Subject: [PATCH 2/5] Adds mask and bias support to backward kernel traits Introduces shared memory layouts and copy operations for mask and bias tensors in the backward kernel configuration. Updates memory size calculations to account for the additional mask and bias storage requirements. Adds specialized copy atoms with 64-byte alignment for efficient mask and bias data transfers. 
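For rough sizing, a minimal sketch of the extra shared-memory budget this introduces, assuming 2-byte elements (fp16/bf16) and a 128x128 backward tile; the tile shape and element size here are assumptions for illustration, not values read from the kernel traits:

    // Illustrative only: mask and bias each occupy a kBlockM x kBlockN tile of Element.
    constexpr int kBlockM = 128, kBlockN = 128;   // assumed backward tile shape
    constexpr int kElemBytes = 2;                 // fp16 / bf16
    constexpr int kSmemMaskSize = kBlockM * kBlockN * kElemBytes;  // 32 KiB
    constexpr int kSmemBiasSize = kBlockM * kBlockN * kElemBytes;  // 32 KiB
    // Both terms are added on top of the existing Q/dO/K/V/dS/P/dQ budget, so
    // kSmemSize and kSmemSize1colblock each grow by kSmemMaskSize + kSmemBiasSize.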
--- csrc/src/kernel_traits.h | 31 +++++++++++++++++++++++++++++-- 1 file changed, 29 insertions(+), 2 deletions(-) diff --git a/csrc/src/kernel_traits.h b/csrc/src/kernel_traits.h index 38b885d..281a2dd 100644 --- a/csrc/src/kernel_traits.h +++ b/csrc/src/kernel_traits.h @@ -261,6 +261,22 @@ struct Flash_bwd_kernel_traits : public Base { composition(SmemLayoutKV{}, make_layout(Shape, Int>{}, GenRowMajor{}))); using SmemLayoutKtransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutKtransposed{})); + using SmemLayoutAtomMask = decltype( + composition(Swizzle{}, + Layout, + Stride<_8, _1>>{})); + using SmemLayoutMask = decltype(tile_to_shape( + SmemLayoutAtomMask{}, + make_shape(Int{}, Int{}))); + + using SmemLayoutAtomBias = decltype( + composition(Swizzle{}, + Layout, + Stride<_8, _1>>{})); + using SmemLayoutBias = decltype(tile_to_shape( + SmemLayoutAtomBias{}, + make_shape(Int{}, Int{}))); + // TODO: generalize to other values of kBlockN // TODO: what should be the Swizzle here? 3 is faster than 1, and 1 is faster than 2 // static constexpr int kPBlockN = kBlockN; @@ -306,6 +322,7 @@ struct Flash_bwd_kernel_traits : public Base { SmemLayoutAtomdQ{}, make_shape(Int{}, Int{}))); using SmemCopyAtomdQ = Copy_Atom, elem_type>; + using SmemCopyAtomBias = Copy_Atom, elem_type>; // Double buffer for sQ static constexpr int kSmemQdOSize = size(SmemLayoutQdO{}) * (No_double_buffer ? 2 : 3) * sizeof(Element); @@ -313,11 +330,13 @@ struct Flash_bwd_kernel_traits : public Base { static constexpr int kSmemdSSize = size(SmemLayoutPdS{}) * sizeof(Element); static constexpr int kSmemPSize = size(SmemLayoutPdS{}) * sizeof(Element); static constexpr int kSmemdQSize = size(SmemLayoutdQ{}) * sizeof(Element); - static constexpr int kSmemSize = kSmemQdOSize + static constexpr int kSmemMaskSize = size(SmemLayoutMask{}) * sizeof(Element); + static constexpr int kSmemBiasSize = size(SmemLayoutBias{}) * sizeof(Element); + static constexpr int kSmemSize = kSmemQdOSize + kSmemMaskSize + kSmemBiasSize + (!Is_V_in_regs ? kSmemKVSize + kSmemdSSize + std::max(kSmemPSize, kSmemdQSize) : std::max(kSmemKVSize, kSmemKVSize / 2 + kSmemdSSize + std::max(kSmemPSize, kSmemdQSize))); - static constexpr int kSmemSize1colblock = kSmemQdOSize + static constexpr int kSmemSize1colblock = kSmemQdOSize + kSmemMaskSize + kSmemBiasSize + (!Is_V_in_regs ? kSmemKVSize + kSmemdSSize + kSmemPSize : std::max(kSmemKVSize, kSmemKVSize / 2 + kSmemdSSize + kSmemPSize)); @@ -354,6 +373,14 @@ struct Flash_bwd_kernel_traits : public Base { make_tiled_copy(Copy_Atom, elem_type>{}, GmemLayoutAtom{}, Layout>{})); // Val layout, 8 vals per store + using GmemTiledCopyMask = decltype( + make_tiled_copy(Copy_Atom, elem_type>{}, + GmemLayoutAtom{}, + Layout>{})); // Val layout, 4 vals per read + using GmemTiledCopyBias = decltype( + make_tiled_copy(Copy_Atom, elem_type>{}, + GmemLayoutAtom{}, + Layout>{})); // Val layout, 4 vals per read using GmemLayoutAtomdQaccum = std::conditional_t< kBlockKSmem == 32, Layout, // Thread layout, 8 threads per row From b08c51394b8b7e21fc1164e273d2409bef3b9375 Mon Sep 17 00:00:00 2001 From: Loser Cheems Date: Wed, 2 Jul 2025 19:43:36 +0800 Subject: [PATCH 3/5] Renames dzoh variables to dbias for clarity Updates variable and parameter names from "dzoh" (dZeroHold) to "dbias" throughout the Flash backward parameters structure to better reflect their actual purpose as bias gradients. Improves code readability and maintains consistency with naming conventions. 
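As a hypothetical illustration of how a host-side launcher would populate the renamed fields (the helper below is not part of this patch; the setter name and the assumed contiguous (batch, heads, seqlen_q, seqlen_k) tensor layout are assumptions):

    // Hypothetical wiring of the renamed dbias fields; not actual repo code.
    #include <ATen/ATen.h>
    #include "flash.h"  // Flash_bwd_params

    void set_dbias_params(Flash_bwd_params &params, const at::Tensor &dbias) {
        params.dbias_ptr          = dbias.data_ptr();
        params.dbias_batch_stride = dbias.stride(0);
        params.dbias_head_stride  = dbias.stride(1);
        params.dbias_row_stride   = dbias.stride(2);
        params.dbias_col_stride   = dbias.stride(3);
    }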
--- csrc/src/flash.h | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/csrc/src/flash.h b/csrc/src/flash.h index 64302af..1ccfd2a 100644 --- a/csrc/src/flash.h +++ b/csrc/src/flash.h @@ -167,18 +167,18 @@ struct Flash_fwd_params : public QKV_params, public Mask_params, public Bias_par struct Flash_bwd_params : public Flash_fwd_params { - // The dO and dQKV and dZeroHold matrices. + // The dO and dQKV and dBias matrices. void *__restrict__ do_ptr; void *__restrict__ dq_ptr; void *__restrict__ dk_ptr; void *__restrict__ dv_ptr; - void *__restrict__ dzoh_ptr; + void *__restrict__ dbias_ptr; // To accumulate dQ void *__restrict__ dq_accum_ptr; void *__restrict__ dk_accum_ptr; void *__restrict__ dv_accum_ptr; - void *__restrict__ dzoh_accum_ptr; + void *__restrict__ dbias_accum_ptr; // // To accumulate dK and dV in case we're splitting the bwd along seqlen_q // dimension void *__restrict__ dk_accum_ptr; void *__restrict__ @@ -199,10 +199,10 @@ struct Flash_bwd_params : public Flash_fwd_params { index_t dq_head_stride; index_t dk_head_stride; index_t dv_head_stride; - index_t dzoh_batch_stride; - index_t dzoh_head_stride; - index_t dzoh_row_stride; - index_t dzoh_col_stride; + index_t dbias_batch_stride; + index_t dbias_head_stride; + index_t dbias_row_stride; + index_t dbias_col_stride; // The pointer to the softmax d sum. void *__restrict__ dsoftmax_sum; From 124dff1da52f6be4960adb0ba53398626f6dab17 Mon Sep 17 00:00:00 2001 From: Loser Cheems Date: Wed, 2 Jul 2025 19:44:03 +0800 Subject: [PATCH 4/5] Adds attention mask and bias tensor support Introduces tensor definitions for attention mask, bias, and bias gradient to enable masked attention and bias computations in the backward kernel. Calculates proper memory offsets for mask and bias tensors based on batch, head, and block dimensions to ensure correct data access patterns during gradient computation. 
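A minimal sketch of the offset arithmetic, assuming the usual (batch, head, row, col) stride layout; the helper is illustrative only, while the kernel computes the same expression inline, starting from the last M block (m_block_max - 1) and the KV head index bidh / h_h_k_ratio:

    // Illustrative only: linear element offset of the mask/bias tile read by one
    // (bidb, bidh, m_block, n_block) work unit, mirroring row_offset_mask / row_offset_bias.
    template <typename index_t>
    inline __host__ __device__ index_t mask_bias_tile_offset(
            index_t batch_offset,                        // e.g. binfo.attn_mask_offset(...)
            index_t head_stride, index_t row_stride, index_t col_stride,
            int head, int m_block, int kBlockM, int n_block, int kBlockN) {
        return batch_offset
             + head * head_stride                        // select the (grouped) head
             + index_t(m_block) * kBlockM * row_stride   // first query row of this M tile
             + index_t(n_block) * kBlockN * col_stride;  // first key column of this N tile
    }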
--- csrc/src/flash_bwd_kernel.h | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/csrc/src/flash_bwd_kernel.h b/csrc/src/flash_bwd_kernel.h index d251047..9dee833 100644 --- a/csrc/src/flash_bwd_kernel.h +++ b/csrc/src/flash_bwd_kernel.h @@ -106,6 +106,12 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in + n_block * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride; const index_t row_offset_v = binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb) + n_block * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride; + const index_t row_offset_mask = binfo.attn_mask_offset(params.attn_mask_batch_stride, params.attn_mask_row_stride, params.attn_mask_col_stride, bidb) + + (bidh / params.h_h_k_ratio) * params.attn_mask_head_stride + (m_block_max - 1) * kBlockM * params.attn_mask_row_stride + + n_block * kBlockN * params.attn_mask_col_stride; + const index_t row_offset_bias = binfo.attn_bias_offset(params.attn_bias_batch_stride, params.attn_bias_row_stride, params.attn_bias_col_stride, bidb) + + (bidh / params.h_h_k_ratio) * params.attn_bias_head_stride + (m_block_max - 1) * kBlockM * params.attn_bias_row_stride + + n_block * kBlockN * params.attn_bias_col_stride; const index_t row_offset_do = binfo.q_offset(params.do_batch_stride, params.do_row_stride, bidb) + (m_block_max - 1) * kBlockM * params.do_row_stride + bidh * params.do_head_stride; const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb) @@ -129,6 +135,15 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in Tensor gV = make_tensor(make_gmem_ptr(reinterpret_cast(params.v_ptr) + row_offset_v), Shape, Int>{}, make_stride(params.v_row_stride, _1{})); + Tensor gMask = make_tensor(make_gmem_ptr(reinterpret_cast(params.attn_mask_ptr) + row_offset_mask), + Shape, Int>{}, + make_stride(params.attn_mask_row_stride, params.attn_mask_col_stride)); + Tensor gBias = make_tensor(make_gmem_ptr(reinterpret_cast(params.attn_bias_ptr) + row_offset_bias), + Shape, Int>{}, + make_stride(params.attn_bias_row_stride, params.attn_bias_col_stride)); + Tensor gdBias = make_tensor(make_gmem_ptr(reinterpret_cast(params.dbias_ptr) + row_offset_bias), + Shape, Int>{}, + make_stride(params.dbias_row_stride, params.dbias_col_stride)); Tensor gdO = make_tensor(make_gmem_ptr(reinterpret_cast(params.do_ptr) + row_offset_do), Shape, Int>{}, make_stride(params.do_row_stride, _1{})); From 543843c010c0660be8110e5e5c5523982c5ae9c7 Mon Sep 17 00:00:00 2001 From: Loser Cheems Date: Wed, 2 Jul 2025 19:44:15 +0800 Subject: [PATCH 5/5] Fixes tensor reference in attention mask creation Corrects the tensor reference used for creating the attention mask when not using dynamic masking, ensuring consistent device and shape alignment with the attention bias tensor. 
--- benchmarks/benchmark_forward_equivalence.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/benchmarks/benchmark_forward_equivalence.py b/benchmarks/benchmark_forward_equivalence.py index 33a38e4..00dde63 100644 --- a/benchmarks/benchmark_forward_equivalence.py +++ b/benchmarks/benchmark_forward_equivalence.py @@ -73,7 +73,7 @@ def prepare_dynamic_mask( attn_mask = attn_mask.scatter(-1, topk_indices, 1.0) attn_bias = attn_bias.masked_fill(attn_mask == 0.0, min_dtype) else: - attn_mask = torch.ones_like(attn_mask, dtype=dtype, device=attn_mask.device) + attn_mask = torch.ones_like(attn_bias, dtype=dtype, device=attn_bias.device) return attn_bias, attn_mask