diff --git a/benchmarks/benchmark_forward_equivalence.py b/benchmarks/benchmark_forward_equivalence.py
index 33a38e4..00dde63 100644
--- a/benchmarks/benchmark_forward_equivalence.py
+++ b/benchmarks/benchmark_forward_equivalence.py
@@ -73,7 +73,7 @@ def prepare_dynamic_mask(
         attn_mask = attn_mask.scatter(-1, topk_indices, 1.0)
         attn_bias = attn_bias.masked_fill(attn_mask == 0.0, min_dtype)
     else:
-        attn_mask = torch.ones_like(attn_mask, dtype=dtype, device=attn_mask.device)
+        attn_mask = torch.ones_like(attn_bias, dtype=dtype, device=attn_bias.device)
 
     return attn_bias, attn_mask
 
diff --git a/csrc/src/flash.h b/csrc/src/flash.h
index 64302af..1ccfd2a 100644
--- a/csrc/src/flash.h
+++ b/csrc/src/flash.h
@@ -167,18 +167,18 @@ struct Flash_fwd_params : public QKV_params, public Mask_params, public Bias_par
 
 struct Flash_bwd_params : public Flash_fwd_params {
 
-    // The dO and dQKV and dZeroHold matrices.
+    // The dO and dQKV and dBias matrices.
     void *__restrict__ do_ptr;
     void *__restrict__ dq_ptr;
     void *__restrict__ dk_ptr;
    void *__restrict__ dv_ptr;
-    void *__restrict__ dzoh_ptr;
+    void *__restrict__ dbias_ptr;
 
     // To accumulate dQ
     void *__restrict__ dq_accum_ptr;
     void *__restrict__ dk_accum_ptr;
     void *__restrict__ dv_accum_ptr;
-    void *__restrict__ dzoh_accum_ptr;
+    void *__restrict__ dbias_accum_ptr;
 
     // // To accumulate dK and dV in case we're splitting the bwd along seqlen_q
     // dimension void *__restrict__ dk_accum_ptr; void *__restrict__
@@ -199,10 +199,10 @@ struct Flash_bwd_params : public Flash_fwd_params {
     index_t dq_head_stride;
     index_t dk_head_stride;
     index_t dv_head_stride;
-    index_t dzoh_batch_stride;
-    index_t dzoh_head_stride;
-    index_t dzoh_row_stride;
-    index_t dzoh_col_stride;
+    index_t dbias_batch_stride;
+    index_t dbias_head_stride;
+    index_t dbias_row_stride;
+    index_t dbias_col_stride;
 
     // The pointer to the softmax d sum.
     void *__restrict__ dsoftmax_sum;
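A note on the benchmark fix above: in the dense fallback branch, the all-ones mask should take its shape, dtype, and device from `attn_bias`, which is always present at that point, rather than from `attn_mask`, which this branch is supposed to produce. A minimal sketch of the surrounding logic, reconstructed from the hunk's context (the `keep_window_size` gate and the exact shapes are assumptions, not the benchmark's verbatim code):

```python
import torch

def prepare_dynamic_mask_sketch(attn_bias, keep_window_size=None):
    # attn_bias: (batch, num_heads, seqlen_q, seqlen_k); its dtype/device are authoritative.
    min_dtype = torch.finfo(attn_bias.dtype).min
    if keep_window_size is not None and attn_bias.shape[-1] > keep_window_size:
        # Sparse branch: keep the top-k keys per query, push the rest to -inf.
        topk_indices = attn_bias.topk(keep_window_size, dim=-1).indices
        attn_mask = torch.zeros_like(attn_bias).scatter(-1, topk_indices, 1.0)
        attn_bias = attn_bias.masked_fill(attn_mask == 0.0, min_dtype)
    else:
        # Dense branch: derive the all-ones mask from attn_bias; the old code read
        # shape and device from attn_mask, the very tensor this branch is building.
        attn_mask = torch.ones_like(attn_bias)
    return attn_bias, attn_mask
```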
diff --git a/csrc/src/flash_bwd_kernel.h b/csrc/src/flash_bwd_kernel.h
index 50af5f6..9dee833 100644
--- a/csrc/src/flash_bwd_kernel.h
+++ b/csrc/src/flash_bwd_kernel.h
@@ -1,5 +1,5 @@
 /***************************************************************************************************
- * Copyright (c) 2024, Tri Dao.
+ * Copyright (c) 2025, Jingze Shi and Tri Dao.
 ******************************************************************************/
 
 #pragma once
@@ -18,8 +18,6 @@
 #include "mask.h"
 #include "dropout.h"
-#include "alibi.h"
-
 namespace FLASH_NAMESPACE {
 
 using namespace cute;
 
@@ -77,7 +75,7 @@ make_tiled_copy_C_warpcontiguousN(Copy_Atom<Args...> const& copy_atom,
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
-template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, typename Params>
+template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Is_softcap, typename Params>
 inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const int bidb, const int bidh, const int n_block) {
 
     using Element = typename Kernel_traits::Element;
@@ -101,9 +99,6 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
     if (n_block * kBlockN >= binfo.actual_seqlen_k) return;
 
     int m_block_max = cute::ceil_div(binfo.actual_seqlen_q, kBlockM);
-    if (Is_local) {
-        m_block_max = std::min(m_block_max, cute::ceil_div((n_block + 1) * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k + params.window_size_left, kBlockM));
-    }
 
     const index_t row_offset_q = binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb)
         + (m_block_max - 1) * kBlockM * params.q_row_stride + bidh * params.q_head_stride;
@@ -111,6 +106,12 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
         + n_block * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride;
     const index_t row_offset_v = binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb)
         + n_block * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride;
+    const index_t row_offset_mask = binfo.attn_mask_offset(params.attn_mask_batch_stride, params.attn_mask_row_stride, params.attn_mask_col_stride, bidb)
+        + (bidh / params.h_h_k_ratio) * params.attn_mask_head_stride + (m_block_max - 1) * kBlockM * params.attn_mask_row_stride
+        + n_block * kBlockN * params.attn_mask_col_stride;
+    const index_t row_offset_bias = binfo.attn_bias_offset(params.attn_bias_batch_stride, params.attn_bias_row_stride, params.attn_bias_col_stride, bidb)
+        + (bidh / params.h_h_k_ratio) * params.attn_bias_head_stride + (m_block_max - 1) * kBlockM * params.attn_bias_row_stride
+        + n_block * kBlockN * params.attn_bias_col_stride;
     const index_t row_offset_do = binfo.q_offset(params.do_batch_stride, params.do_row_stride, bidb)
         + (m_block_max - 1) * kBlockM * params.do_row_stride + bidh * params.do_head_stride;
     const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)
@@ -134,6 +135,15 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
     Tensor gV = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.v_ptr) + row_offset_v),
                             Shape<Int<kBlockN>, Int<kHeadDim>>{},
                             make_stride(params.v_row_stride, _1{}));
+    Tensor gMask = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.attn_mask_ptr) + row_offset_mask),
+                               Shape<Int<kBlockM>, Int<kBlockN>>{},
+                               make_stride(params.attn_mask_row_stride, params.attn_mask_col_stride));
+    Tensor gBias = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.attn_bias_ptr) + row_offset_bias),
+                               Shape<Int<kBlockM>, Int<kBlockN>>{},
+                               make_stride(params.attn_bias_row_stride, params.attn_bias_col_stride));
+    Tensor gdBias = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.dbias_ptr) + row_offset_bias),
+                                Shape<Int<kBlockM>, Int<kBlockN>>{},
+                                make_stride(params.dbias_row_stride, params.dbias_col_stride));
     Tensor gdO = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.do_ptr) + row_offset_do),
                              Shape<Int<kBlockM>, Int<kHeadDim>>{},
                              make_stride(params.do_row_stride, _1{}));
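The new `row_offset_mask` / `row_offset_bias` expressions are plain stride arithmetic over a `(batch, head, seqlen_q, seqlen_k)` tensor, with the query-head index folded onto a KV head via `h_h_k_ratio` for MQA/GQA. A quick sanity check of the same arithmetic in Python (tensor sizes and block indices are illustrative; the kernel actually starts from `m_block_max - 1` and walks backward):

```python
import torch

# Illustrative sizes; kBlockM/kBlockN are compile-time tile sizes in the kernel.
batch, heads, seqlen_q, seqlen_k = 2, 4, 128, 256
kBlockM, kBlockN = 64, 64
bias = torch.randn(batch, heads, seqlen_q, seqlen_k)

bidb, bidh, m_block, n_block = 1, 3, 1, 2
bs, hs, rs, cs = bias.stride()  # batch/head/row/col strides, as in Flash_bwd_params

# Linear offset of the (m_block, n_block) tile's first element, mirroring row_offset_bias.
offset = bidb * bs + bidh * hs + m_block * kBlockM * rs + n_block * kBlockN * cs
assert bias.flatten()[offset] == bias[bidb, bidh, m_block * kBlockM, n_block * kBlockN]
```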
@@ -313,9 +323,9 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
     tdQgdQaccum.data() = tdQgdQaccum.data() + kBlockM * params.h * params.d_rounded;
 
     int m_block = m_block_max - 1;
-    int m_block_min = (!Is_causal && !Is_local)
+    int m_block_min = (!Is_causal)
         ? 0
-        : std::max(0, (n_block * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k - params.window_size_right) / kBlockM);
+        : std::max(0, (n_block * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k) / kBlockM);
     // If not local, we're guaranteed that m_block_min <= m_block:
     // We checked earlier that n_block * kBlockN < actual_seqlen_k, so in the causal case,
     // n_block * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k < actual_seqlen_q.
@@ -328,7 +338,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
     // Otherwise we get wrong result for the case where we don't enter the for loop.
     // And we might read OOB elements from gQ and gdO.
     // This also covers the case where actual_seqlen_q == 0
-    if ((Is_local || !Is_even_MN) && m_block < m_block_min) {
+    if ((!Is_even_MN) && m_block < m_block_min) {
         const index_t row_offset_dk = binfo.k_offset(params.dk_batch_stride, params.dk_row_stride, bidb)
             + n_block * kBlockN * params.dk_row_stride + bidh * params.dk_head_stride;
         const index_t row_offset_dv = binfo.k_offset(params.dv_batch_stride, params.dv_row_stride, bidb)
@@ -410,8 +420,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
         }
         // We want LSE = inf if the row is OOB. In that case Q would be zero, K would be zero,
         // and scores would be zero. With LSE = 0, probs will be all 1's, and when we multiply
-        // with V (which would be zero), we're fine. However, with ALiBi, we might modify these
-        // scores, and probs can become NaN. Instead if we set LSE = inf for OOB rows, probs are always 0.
+        // with V (which would be zero), we're fine.
 
         // Tensor tKrK = make_fragment_like(tKsK);
         // // cute::copy(gmem_tiled_copy_QKV, tKgK(_, _, _, 0), tKrK);
@@ -449,9 +458,6 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
     clear(acc_dv);
     clear(acc_dk);
 
-    const float alibi_slope = !Has_alibi || params.alibi_slopes_ptr == nullptr ? 0.0f : reinterpret_cast<float *>(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax;
-    FLASH_NAMESPACE::Alibi<Is_causal> alibi(alibi_slope, binfo.actual_seqlen_k, binfo.actual_seqlen_q);
-
     for (; m_block >= m_block_min; --m_block) {
         Tensor acc_s = partition_fragment_C(tiled_mma_sdp, Shape<Int<kBlockM>, Int<kBlockN>>{});  // (MMA=4, MMA_N, MMA_N)
         clear(acc_s);
@@ -486,12 +492,6 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
             FLASH_NAMESPACE::calculate_dtanh(scores, dtanh, params.softcap);
         }
 
-        // Alibi
-        if (Has_alibi) {
-            alibi.apply_alibi(scores, n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16,
-                              m_block * kBlockM + get<0>(taccScS_row(0)), AtomLayoutMS * 16);
-        }
-
         // TD [2023-07-29]: I was thinking that we don't need to mask out the elements beyond
         // actual_seqlen_k, because acc_s would be some finite value for those indices.
        // In the end when we multiply with K to get dQ, the corresponding values of K would be 0,
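The simplified `m_block_min` keeps only the causal lower bound: with bottom-right alignment, query `i` can see key `j` iff `j <= i + seqlen_k - seqlen_q`, so the first query block that touches key block `n_block` starts at `(n_block * kBlockN + seqlen_q - seqlen_k) / kBlockM`. A brute-force check with hypothetical sizes (note C++ truncating division differs from Python's floor on negatives, but the `max(0, ...)` clamp makes that moot):

```python
# Hypothetical tile/sequence sizes; the kernel's values are template parameters.
kBlockM, kBlockN, seqlen_q, seqlen_k = 64, 64, 128, 256

def m_block_min(n_block):
    return max(0, (n_block * kBlockN + seqlen_q - seqlen_k) // kBlockM)

def m_block_min_bruteforce(n_block):
    shift = seqlen_k - seqlen_q  # bottom-right causal: query i sees key j iff j <= i + shift
    for m_block in range((seqlen_q + kBlockM - 1) // kBlockM):
        last_i = min((m_block + 1) * kBlockM, seqlen_q) - 1
        if last_i + shift >= n_block * kBlockN:  # tile has at least one unmasked entry
            return m_block
    return (seqlen_q + kBlockM - 1) // kBlockM

for n_block in range(seqlen_k // kBlockN):
    assert m_block_min(n_block) == m_block_min_bruteforce(n_block)
```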
@@ -499,7 +499,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
         // However, it's possible that the values in acc_s are so large that they overflow
         // when we multiply with dP and convert to fp16, resulting in Inf in dS and NaNs in dQ.
         // So we need to mask out the elements beyond actual_seqlen_k.
-        if (!Is_causal && !Is_local) {
+        if (!Is_causal) {
             if (!Is_even_MN && (n_block + 1) * kBlockN >= binfo.actual_seqlen_k) {
                 FLASH_NAMESPACE::apply_mask(scores, binfo.actual_seqlen_k,
                                             n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16);
@@ -517,16 +517,6 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
             //                                   binfo.actual_seqlen_k, m_block * kBlockM + (tidx / 32) % AtomLayoutMS * 16 + (tidx % 32) / 4, AtomLayoutMS * 16);
             }
-        } else if (Is_local) {
-            if (m_block * kBlockM < (n_block + 1) * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k - params.window_size_right
-                || (m_block + 1) * kBlockM >= n_block * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k + params.window_size_left
-                || (!Is_even_MN && (n_block + 1) * kBlockN >= binfo.actual_seqlen_k)) {
-                FLASH_NAMESPACE::apply_mask_local(scores, n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16,
-                                                  binfo.actual_seqlen_k, m_block * kBlockM + get<0>(taccScS_row(0)),
-                                                  binfo.actual_seqlen_q, AtomLayoutMS * 16,
-                                                  params.window_size_left, params.window_size_right);
-            }
-
         }
 
         // if (cute::thread(32, 0)) { print(scores); }
@@ -794,7 +784,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
-template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, typename Params>
+template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Is_softcap, typename Params>
 inline __device__ void compute_dq_dk_dv(const Params &params) {
 
     // The block index for the batch.
@@ -808,20 +798,20 @@ inline __device__ void compute_dq_dk_dv(const Params &params) {
     const int n_block_max = (params.seqlen_k + Kernel_traits::kBlockN - 1) / Kernel_traits::kBlockN;
     if (n_block_max == 1) {
-        compute_dq_dk_dv_1colblock<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap>(params, bidb, bidh, 0);
+        compute_dq_dk_dv_1colblock<Kernel_traits, Is_dropout, Is_causal, Is_even_MN, Is_even_K, Is_softcap>(params, bidb, bidh, 0);
     } else {
         // Iterating backward from n_block_max - 1 to 0 might save 1 register
-        compute_dq_dk_dv_1colblock<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap>(params, bidb, bidh, n_block_max - 1);
+        compute_dq_dk_dv_1colblock<Kernel_traits, Is_dropout, Is_causal, Is_even_MN, Is_even_K, Is_softcap>(params, bidb, bidh, n_block_max - 1);
         for (int n_block = n_block_max - 2; n_block > 0; n_block--) {
-            compute_dq_dk_dv_1colblock<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap>(params, bidb, bidh, n_block);
+            compute_dq_dk_dv_1colblock<Kernel_traits, Is_dropout, Is_causal, Is_even_MN, Is_even_K, Is_softcap>(params, bidb, bidh, n_block);
         }
-        compute_dq_dk_dv_1colblock<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap>(params, bidb, bidh, 0);
+        compute_dq_dk_dv_1colblock<Kernel_traits, Is_dropout, Is_causal, Is_even_MN, Is_even_K, Is_softcap>(params, bidb, bidh, 0);
     }
 }
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
-template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, typename Params>
+template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Is_softcap, typename Params>
 inline __device__ void compute_dq_dk_dv_seqk_parallel(const Params &params) {
 
     // The block index for the batch.
@@ -831,7 +821,7 @@ inline __device__ void compute_dq_dk_dv_seqk_parallel(const Params &params) {
 
     // If deterministic, each thread block will do atomicAdd to a different dQ_accum buffer.
     for (int n_block = blockIdx.x; n_block < (params.seqlen_k + Kernel_traits::kBlockN - 1) / Kernel_traits::kBlockN; n_block += gridDim.x) {
-        compute_dq_dk_dv_1colblock<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap>(params, bidb, bidh, n_block);
+        compute_dq_dk_dv_1colblock<Kernel_traits, Is_dropout, Is_causal, Is_even_MN, Is_even_K, Is_softcap>(params, bidb, bidh, n_block);
     }
 }
 
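With `dbias_ptr` and `gdBias` in place, the quantity the kernel writes there is just the pre-softmax score gradient dS, because the bias enters additively: S = QK^T * scale + bias implies dBias = dS. A PyTorch reference for that identity (my own formulation, not the repo's test code; shapes are illustrative):

```python
import torch

B, H, M, N, D = 2, 2, 16, 32, 16
q = torch.randn(B, H, M, D, requires_grad=True)
k = torch.randn(B, H, N, D, requires_grad=True)
v = torch.randn(B, H, N, D, requires_grad=True)
bias = torch.randn(B, H, M, N, requires_grad=True)

scores = q @ k.transpose(-1, -2) / D**0.5 + bias   # S = QK^T * scale + bias
out = scores.softmax(-1) @ v
out.sum().backward()

# Manual softmax backward: dS = P * (dP - rowsum(dP * P)), with dOut = all-ones here.
with torch.no_grad():
    P = scores.softmax(-1)
    dP = torch.ones_like(out) @ v.transpose(-1, -2)
    dS = P * (dP - (dP * P).sum(-1, keepdim=True))
assert torch.allclose(bias.grad, dS, atol=1e-5)     # dBias == dS
```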
diff --git a/csrc/src/kernel_traits.h b/csrc/src/kernel_traits.h
index 38b885d..281a2dd 100644
--- a/csrc/src/kernel_traits.h
+++ b/csrc/src/kernel_traits.h
@@ -261,6 +261,22 @@ struct Flash_bwd_kernel_traits : public Base {
         composition(SmemLayoutKV{}, make_layout(Shape<Int<kHeadDim>, Int<kBlockN>>{}, GenRowMajor{})));
     using SmemLayoutKtransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutKtransposed{}));
 
+    using SmemLayoutAtomMask = decltype(
+        composition(Swizzle<kSwizzle, 3, 3>{},
+                    Layout<Shape<_8, _8>,
+                           Stride<_8, _1>>{}));
+    using SmemLayoutMask = decltype(tile_to_shape(
+        SmemLayoutAtomMask{},
+        make_shape(Int<kBlockM>{}, Int<kBlockN>{})));
+
+    using SmemLayoutAtomBias = decltype(
+        composition(Swizzle<kSwizzle, 3, 3>{},
+                    Layout<Shape<_8, _8>,
+                           Stride<_8, _1>>{}));
+    using SmemLayoutBias = decltype(tile_to_shape(
+        SmemLayoutAtomBias{},
+        make_shape(Int<kBlockM>{}, Int<kBlockN>{})));
+
     // TODO: generalize to other values of kBlockN
     // TODO: what should be the Swizzle here? 3 is faster than 1, and 1 is faster than 2
     // static constexpr int kPBlockN = kBlockN;
@@ -306,6 +322,7 @@ struct Flash_bwd_kernel_traits : public Base {
         SmemLayoutAtomdQ{},
         make_shape(Int<kBlockM>{}, Int<kHeadDim>{})));
     using SmemCopyAtomdQ = Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, elem_type>;
+    using SmemCopyAtomBias = Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, elem_type>;
 
     // Double buffer for sQ
     static constexpr int kSmemQdOSize = size(SmemLayoutQdO{}) * (No_double_buffer ? 2 : 3) * sizeof(Element);
@@ -313,11 +330,13 @@ struct Flash_bwd_kernel_traits : public Base {
     static constexpr int kSmemdSSize = size(SmemLayoutPdS{}) * sizeof(Element);
     static constexpr int kSmemPSize = size(SmemLayoutPdS{}) * sizeof(Element);
     static constexpr int kSmemdQSize = size(SmemLayoutdQ{}) * sizeof(Element);
-    static constexpr int kSmemSize = kSmemQdOSize
+    static constexpr int kSmemMaskSize = size(SmemLayoutMask{}) * sizeof(Element);
+    static constexpr int kSmemBiasSize = size(SmemLayoutBias{}) * sizeof(Element);
+    static constexpr int kSmemSize = kSmemQdOSize + kSmemMaskSize + kSmemBiasSize
         + (!Is_V_in_regs
            ? kSmemKVSize + kSmemdSSize + std::max(kSmemPSize, kSmemdQSize)
           : std::max(kSmemKVSize, kSmemKVSize / 2 + kSmemdSSize + std::max(kSmemPSize, kSmemdQSize)));
-    static constexpr int kSmemSize1colblock = kSmemQdOSize
+    static constexpr int kSmemSize1colblock = kSmemQdOSize + kSmemMaskSize + kSmemBiasSize
         + (!Is_V_in_regs
            ? kSmemKVSize + kSmemdSSize + kSmemPSize
           : std::max(kSmemKVSize, kSmemKVSize / 2 + kSmemdSSize + kSmemPSize));
@@ -354,6 +373,14 @@ struct Flash_bwd_kernel_traits : public Base {
         make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, elem_type>{},
                         GmemLayoutAtom{},
                         Layout<Shape<_1, _8>>{}));  // Val layout, 8 vals per store
+    using GmemTiledCopyMask = decltype(
+        make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<64>, elem_type>{},
+                        GmemLayoutAtom{},
+                        Layout<Shape<_1, _4>>{}));  // Val layout, 4 vals per read
+    using GmemTiledCopyBias = decltype(
+        make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<64>, elem_type>{},
+                        GmemLayoutAtom{},
+                        Layout<Shape<_1, _4>>{}));  // Val layout, 4 vals per read
     using GmemLayoutAtomdQaccum = std::conditional_t<
         kBlockKSmem == 32,
         Layout<Shape<_32, _8>,  // Thread layout, 8 threads per row
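The new `kSmemMaskSize` / `kSmemBiasSize` terms are flat kBlockM x kBlockN tiles of `Element`, so they enlarge the shared-memory footprint noticeably. Back-of-envelope arithmetic (block sizes assumed, 2-byte elements):

```python
# Back-of-envelope smem cost of the new tiles (block sizes assumed, fp16/bf16).
kBlockM, kBlockN = 128, 128
sizeof_element = 2

kSmemMaskSize = kBlockM * kBlockN * sizeof_element   # size(SmemLayoutMask{}) * sizeof(Element)
kSmemBiasSize = kBlockM * kBlockN * sizeof_element   # size(SmemLayoutBias{}) * sizeof(Element)
print(kSmemMaskSize + kSmemBiasSize)                 # 65536 bytes on top of the QdO/KV/dS/P budget
```

Budgets past the 48 KB static limit require the launch code to raise `cudaFuncAttributeMaxDynamicSharedMemorySize` via `cudaFuncSetAttribute`, which flash-attention's launch templates already do for large configurations; nothing in this hunk changes that.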