diff --git a/benchmarks/benchmark_forward_equivalence.py b/benchmarks/benchmark_forward_equivalence.py index 5d89419..3f4a86b 100644 --- a/benchmarks/benchmark_forward_equivalence.py +++ b/benchmarks/benchmark_forward_equivalence.py @@ -232,20 +232,20 @@ def dynamic_mask_attention_cuda( # Call the CUDA implementation using the mha_fwd function signature out_tensor = None # Let the function allocate the output tensor - result = flash_dma_cuda.fwd( # type: ignore - query_states, # q: [batch, seqlen_q, num_heads, head_dim] - key_states, # k: [batch, seqlen_k, num_kv_heads, head_dim] - value_states, # v: [batch, seqlen_k, num_kv_heads, head_dim] - zero_hold_states, # zoh: [batch, num_kv_heads, seqlen_q, seqlen_k] - processed attention mask - active_mask, # active_mask: [batch, num_kv_heads, seqlen_q, seqlen_k] - out_tensor, # out: None to auto-allocate - 0.0, # p_dropout - scaling, # softmax_scale - is_causal, # is_causal - keep_window_size, # keep_window_size - 0.0, # softcap - return_softmax, # return_softmax - None # gen (generator) + result = flash_dma_cuda.fwd( # type: ignore + query_states, # q: [batch, seqlen_q, num_heads, head_dim] + key_states, # k: [batch, seqlen_k, num_kv_heads, head_dim] + value_states, # v: [batch, seqlen_k, num_kv_heads, head_dim] + attn_mask, # attn_mask: [batch, num_kv_heads, seqlen_q, seqlen_k] + active_mask, # attn_bias: [batch, num_kv_heads, seqlen_q, seqlen_k] + out_tensor, # out: None to auto-allocate + 0.0, # p_dropout + scaling, # softmax_scale + is_causal, # is_causal + keep_window_size, # keep_window_size + 0.0, # softcap + return_softmax, # return_softmax + None # gen (generator) ) attn_outputs = result[0] # [batch, query_len, num_heads, head_dim] diff --git a/csrc/flash_api.cpp b/csrc/flash_api.cpp index 1d948ea..2db3e99 100644 --- a/csrc/flash_api.cpp +++ b/csrc/flash_api.cpp @@ -40,8 +40,8 @@ void set_params_fprop( const at::Tensor q, const at::Tensor k, const at::Tensor v, - const at::Tensor zoh, - const at::Tensor active_mask, + const at::Tensor attn_mask, + const at::Tensor attn_bias, at::Tensor out, void *cu_seqlens_q_d, void *cu_seqlens_k_d, @@ -65,32 +65,32 @@ void set_params_fprop( params.q_ptr = q.data_ptr(); params.k_ptr = k.data_ptr(); params.v_ptr = v.data_ptr(); - params.zoh_ptr = zoh.data_ptr(); - params.active_mask_ptr = active_mask.data_ptr(); + params.attn_mask_ptr = attn_mask.data_ptr(); + params.attn_bias_ptr = attn_bias.data_ptr(); params.o_ptr = out.data_ptr(); // All stride are in elements, not bytes. 
params.q_row_stride = q.stride(-3); params.k_row_stride = k.stride(-3); params.v_row_stride = v.stride(-3); - params.zoh_row_stride = zoh.stride(-2); - params.active_mask_row_stride = active_mask.stride(-2); + params.attn_mask_row_stride = attn_mask.stride(-2); + params.attn_bias_row_stride = attn_bias.stride(-2); params.o_row_stride = out.stride(-3); params.q_head_stride = q.stride(-2); params.k_head_stride = k.stride(-2); params.v_head_stride = v.stride(-2); - params.zoh_head_stride = zoh.stride(-3); - params.active_mask_head_stride = active_mask.stride(-3); + params.attn_mask_head_stride = attn_mask.stride(-3); + params.attn_bias_head_stride = attn_bias.stride(-3); params.o_head_stride = out.stride(-2); - params.zoh_col_stride = zoh.stride(-1); - params.active_mask_col_stride = active_mask.stride(-1); + params.attn_mask_col_stride = attn_mask.stride(-1); + params.attn_bias_col_stride = attn_bias.stride(-1); if (cu_seqlens_q_d == nullptr) { params.q_batch_stride = q.stride(0); params.k_batch_stride = k.stride(0); params.v_batch_stride = v.stride(0); - params.zoh_batch_stride = zoh.stride(0); - params.active_mask_batch_stride = active_mask.stride(0); + params.attn_mask_batch_stride = attn_mask.stride(0); + params.attn_bias_batch_stride = attn_bias.stride(0); params.o_batch_stride = out.stride(0); if (seqlenq_ngroups_swapped) { params.q_batch_stride *= seqlen_q; @@ -271,8 +271,8 @@ mha_fwd( at::Tensor &q, // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8) const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8) const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8) - const at::Tensor &zoh, // batch_size x num_heads_k x seqlen_q x seqlen_k - const at::Tensor &active_mask, // batch_size x num_heads_k x seqlen_q x seqlen_k + const at::Tensor &attn_mask, // batch_size x num_heads_k x seqlen_q x seqlen_k + const at::Tensor &attn_bias, // batch_size x num_heads_k x seqlen_q x seqlen_k std::optional &out_, // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8) const float p_dropout, const float softmax_scale, @@ -295,10 +295,10 @@ mha_fwd( "FlashAttention only support fp16 and bf16 data type"); TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype"); TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype"); - TORCH_CHECK(zoh.dtype() == q_dtype, "zoh must have the same dtype as inputs"); - TORCH_CHECK(active_mask.dtype() == q_dtype, "active_mask must have the same dtype as inputs"); + TORCH_CHECK(attn_mask.dtype() == q_dtype, "attn_mask must have the same dtype as inputs"); + TORCH_CHECK(attn_bias.dtype() == q_dtype, "attn_bias must have the same dtype as inputs"); - CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); CHECK_DEVICE(zoh); CHECK_DEVICE(active_mask); + CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); CHECK_DEVICE(attn_mask); CHECK_DEVICE(attn_bias); TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension"); TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension"); @@ -335,8 +335,8 @@ mha_fwd( CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size); CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size); CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size); - CHECK_SHAPE(zoh, batch_size, num_heads_k, seqlen_q, seqlen_k); - CHECK_SHAPE(active_mask, batch_size, num_heads_k, seqlen_q, seqlen_k); + CHECK_SHAPE(attn_mask, batch_size, num_heads_k, seqlen_q, seqlen_k); + 
CHECK_SHAPE(attn_bias, batch_size, num_heads_k, seqlen_q, seqlen_k); at::Tensor out; if (out_.has_value()) { @@ -379,7 +379,7 @@ mha_fwd( num_heads, num_heads_k, head_size, head_size_rounded, keep_window_size, - q, k, v, zoh, active_mask, out, + q, k, v, attn_mask, attn_bias, out, /*cu_seqlens_q_d=*/nullptr, /*cu_seqlens_k_d=*/nullptr, /*seqused_k=*/nullptr, diff --git a/csrc/src/block_info.h b/csrc/src/block_info.h index b0ce7c5..d0758eb 100644 --- a/csrc/src/block_info.h +++ b/csrc/src/block_info.h @@ -36,15 +36,15 @@ struct BlockInfo { } template - __forceinline__ __device__ index_t zoh_offset(const index_t batch_stride, const int row_stride, const int col_stride, const int bidb - ) const { + __forceinline__ __device__ index_t attn_mask_offset(const index_t batch_stride, int row_stride, const int col_stride, const int bidb) const { index_t offset = sum_s_q == -1 ? bidb * batch_stride : uint32_t(sum_s_q) * row_stride; sum_s_k == -1 ? offset += leftpad_k * col_stride : offset += uint32_t(sum_s_k + leftpad_k) * col_stride; return offset; } template - __forceinline__ __device__ index_t active_mask_offset(const index_t batch_stride, int row_stride, const int col_stride, const int bidb) const { + __forceinline__ __device__ index_t attn_bias_offset(const index_t batch_stride, const int row_stride, const int col_stride, const int bidb + ) const { index_t offset = sum_s_q == -1 ? bidb * batch_stride : uint32_t(sum_s_q) * row_stride; sum_s_k == -1 ? offset += leftpad_k * col_stride : offset += uint32_t(sum_s_k + leftpad_k) * col_stride; return offset; diff --git a/csrc/src/flash.h b/csrc/src/flash.h index 277d5a3..64302af 100644 --- a/csrc/src/flash.h +++ b/csrc/src/flash.h @@ -45,27 +45,34 @@ struct QKV_params { //////////////////////////////////////////////////////////////////////////////////////////////////// -struct ZOH_params { - void *__restrict__ zoh_ptr; // ZOH states tensor [batch_size, num_kv_heads, query_len, key_len] - void * __restrict__ active_mask_ptr; // Active mask tensor [batch_size, num_kv_heads, query_len, key_len] - - // The stride of the zero-hold states and active mask tensors. - index_t zoh_batch_stride; // Stride between batches of ZOH states - index_t active_mask_batch_stride; // Stride between batches of active mask - index_t zoh_head_stride; // Stride between heads of ZOH states - index_t active_mask_head_stride; // Stride between heads of active mask - index_t zoh_row_stride; // Stride between rows of ZOH states - index_t active_mask_row_stride; // Stride between rows of active mask - index_t zoh_col_stride; // Stride between columns of ZOH states - index_t active_mask_col_stride; // Stride between columns of active mask +struct Mask_params { + void * __restrict__ attn_mask_ptr; // Attention mask tensor [batch_size, num_kv_heads, query_len, key_len] + + // The stride of the attention mask tensors. + index_t attn_mask_batch_stride; // Stride between batches of attention mask + index_t attn_mask_head_stride; // Stride between heads of attention mask + index_t attn_mask_row_stride; // Stride between rows of attention mask + index_t attn_mask_col_stride; // Stride between columns of attention mask // The keep window size. 
- int keep_window_size; // Number of tokens to keep in top-k (0 means don't apply top-k) + int keep_window_size; // Number of tokens to keep in top-k +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct Bias_params { + void *__restrict__ attn_bias_ptr; // Attention bias tensor [batch_size, num_kv_heads, query_len, key_len] + + // The stride of the attention bias tensor. + index_t attn_bias_batch_stride; // Stride between batches of attention bias + index_t attn_bias_head_stride; // Stride between heads of attention bias + index_t attn_bias_row_stride; // Stride between rows of attention bias + index_t attn_bias_col_stride; // Stride between columns of attention bias }; //////////////////////////////////////////////////////////////////////////////////////////////////// -struct Flash_fwd_params : public QKV_params, public ZOH_params { +struct Flash_fwd_params : public QKV_params, public Mask_params, public Bias_params { // The O matrix (output). void * __restrict__ o_ptr; diff --git a/csrc/src/flash_fwd_kernel.h b/csrc/src/flash_fwd_kernel.h index 42b78e0..f1823a1 100644 --- a/csrc/src/flash_fwd_kernel.h +++ b/csrc/src/flash_fwd_kernel.h @@ -179,23 +179,23 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi Shape, Int>{}, make_stride(params.seqlen_k_rounded, _1{}) ); - Tensor mZOH = make_tensor( - make_gmem_ptr(reinterpret_cast(params.zoh_ptr) + binfo.zoh_offset(params.zoh_batch_stride, params.zoh_row_stride, params.zoh_col_stride, bidb)), + Tensor mMask = make_tensor( + make_gmem_ptr(reinterpret_cast(params.attn_mask_ptr) + binfo.attn_mask_offset(params.attn_mask_batch_stride, params.attn_mask_row_stride, params.attn_mask_col_stride, bidb)), make_shape(params.h_k, binfo.actual_seqlen_q, binfo.actual_seqlen_k), - make_stride(params.zoh_head_stride, params.zoh_row_stride, params.zoh_col_stride) + make_stride(params.attn_mask_head_stride, params.attn_mask_row_stride, params.attn_mask_col_stride) ); - Tensor gZOH = local_tile( - mZOH(bidh / params.h_h_k_ratio, _, _), + Tensor gMask = local_tile( + mMask(bidh / params.h_h_k_ratio, _, _), Shape, Int>{}, make_coord(m_block, _) ); // (kBlockM, kBlockN, nblocksN) - Tensor mActiveMask = make_tensor( - make_gmem_ptr(reinterpret_cast(params.active_mask_ptr) + binfo.active_mask_offset(params.active_mask_batch_stride, params.active_mask_row_stride, params.active_mask_col_stride, bidb)), + Tensor mBias = make_tensor( + make_gmem_ptr(reinterpret_cast(params.attn_bias_ptr) + binfo.attn_bias_offset(params.attn_bias_batch_stride, params.attn_bias_row_stride, params.attn_bias_col_stride, bidb)), make_shape(params.h_k, binfo.actual_seqlen_q, binfo.actual_seqlen_k), - make_stride(params.active_mask_head_stride, params.active_mask_row_stride, params.active_mask_col_stride) + make_stride(params.attn_bias_head_stride, params.attn_bias_row_stride, params.attn_bias_col_stride) ); - Tensor gActiveMask = local_tile( - mActiveMask(bidh / params.h_h_k_ratio, _, _), + Tensor gBias = local_tile( + mBias(bidh / params.h_h_k_ratio, _, _), Shape, Int>{}, make_coord(m_block, _) ); // (kBlockM, kBlockN, nblocksN) @@ -222,22 +222,22 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi sV.data().get(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{} ); - Tensor sZOH = make_tensor( + Tensor sMask = make_tensor( sV.data() + size(sV), - typename Kernel_traits::SmemLayoutZOH{} + typename Kernel_traits::SmemLayoutMask{} ); - Tensor sActiveMask = make_tensor( 
- sZOH.data() + size(sZOH), - typename Kernel_traits::SmemLayoutActiveMask{} + Tensor sBias = make_tensor( + sMask.data() + size(sMask), + typename Kernel_traits::SmemLayoutBias{} ); // Global to Shared Memory operation typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV; auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx); - typename Kernel_traits::GmemTiledCopyZOH gmem_tiled_copy_ZOH; - auto gmem_thr_copy_ZOH = gmem_tiled_copy_ZOH.get_thread_slice(tidx); - typename Kernel_traits::GmemTiledCopyActiveMask gmem_tiled_copy_AM; - auto gmem_thr_copy_AM = gmem_tiled_copy_AM.get_thread_slice(tidx); + typename Kernel_traits::GmemTiledCopyMask gmem_tiled_copy_Mask; + auto gmem_thr_copy_Mask = gmem_tiled_copy_Mask.get_thread_slice(tidx); + typename Kernel_traits::GmemTiledCopyBias gmem_tiled_copy_Bias; + auto gmem_thr_copy_Bias = gmem_tiled_copy_Bias.get_thread_slice(tidx); Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ); Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ); @@ -245,19 +245,19 @@ Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK); Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K, nblocksN) Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV); - Tensor tZOHgZOH = gmem_thr_copy_ZOH.partition_S(gZOH); // (ZOHCPY, ZOHCPY_M, ZOHCPY_N, nblocksN) - Tensor tZOHsZOH = gmem_thr_copy_ZOH.partition_D(sZOH); - Tensor tAMgAM = gmem_thr_copy_AM.partition_S(gActiveMask); // (AMCPY, AMCPY_M, AMCPY_N, nblocksN) - Tensor tAMsAM = gmem_thr_copy_AM.partition_D(sActiveMask); + Tensor tMaskgMask = gmem_thr_copy_Mask.partition_S(gMask); // (MaskCPY, MaskCPY_M, MaskCPY_N, nblocksN) + Tensor tMasksMask = gmem_thr_copy_Mask.partition_D(sMask); + Tensor tBiasgBias = gmem_thr_copy_Bias.partition_S(gBias); // (BiasCPY, BiasCPY_M, BiasCPY_N, nblocksN) + Tensor tBiassBias = gmem_thr_copy_Bias.partition_D(sBias); // Matrix Multiply Accumulate typename Kernel_traits::TiledMma tiled_mma; auto thr_mma = tiled_mma.get_thread_slice(tidx); Tensor tSrQ = thr_mma.partition_fragment_A(sQ); // (MMA, MMA_M, MMA_K) Tensor tSrK = thr_mma.partition_fragment_B(sK); // (MMA, MMA_N, MMA_K) - Tensor tSrZOH = partition_fragment_C(tiled_mma, Shape, Int>{}); // (MMA, MMA_M, MMA_N) - Tensor tSrAM = partition_fragment_C(tiled_mma, Shape, Int>{}); // (MMA, MMA_M, MMA_N) Tensor tOrVt = thr_mma.partition_fragment_B(sVtNoSwizzle); // (MMA, MMA_K, MMA_N) + Tensor tSrMask = partition_fragment_C(tiled_mma, Shape, Int>{}); // (MMA, MMA_M, MMA_N) + Tensor tSrBias = partition_fragment_C(tiled_mma, Shape, Int>{}); // (MMA, MMA_M, MMA_N) Tensor tSgS = thr_mma.partition_C(gP); Tensor acc_o = partition_fragment_C(tiled_mma, Shape, Int>{}); // (MMA, MMA_M, MMA_K) @@ -270,26 +270,26 @@ auto smem_tiled_copy_K = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtom{}, tiled_mma); auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(tidx); Tensor tSsK = smem_thr_copy_K.partition_S(sK); - auto smem_tiled_copy_ZOH = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma); - auto smem_thr_copy_ZOH = smem_tiled_copy_ZOH.get_thread_slice(tidx); - Tensor tSsZOH = smem_thr_copy_ZOH.partition_S(sZOH); - auto smem_tiled_copy_AM = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma); - auto smem_thr_copy_AM = smem_tiled_copy_AM.get_thread_slice(tidx); - Tensor tSsAM = smem_thr_copy_AM.partition_S(sActiveMask); auto smem_tiled_copy_V =
make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma); auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx); Tensor tOsVt = smem_thr_copy_V.partition_S(sVt); + auto smem_tiled_copy_Bias = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma); + auto smem_thr_copy_Bias = smem_tiled_copy_Bias.get_thread_slice(tidx); + Tensor tSsBias = smem_thr_copy_Bias.partition_S(sBias); + auto smem_tiled_copy_Mask = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma); + auto smem_thr_copy_Mask = smem_tiled_copy_Mask.get_thread_slice(tidx); + Tensor tSsMask = smem_thr_copy_Mask.partition_S(sMask); // PREDICATES // // Allocate predicate tensors for m and n // Tensor tQpQ = make_tensor(make_shape(size<1>(tQsQ), size<2>(tQsQ)), Stride<_1,_0>{}); // Tensor tKVpKV = make_tensor(make_shape(size<1>(tKsK), size<2>(tKsK)), Stride<_1,_0>{}); // Construct identity layout for sQ and sK - Tensor cQ = make_identity_tensor(make_shape(size<0>(sQ), size<1>(sQ))); // (BLK_M,BLK_K) -> (blk_m,blk_k) - Tensor cKV = make_identity_tensor(make_shape(size<0>(sK), size<1>(sK))); // (BLK_N,BLK_K) -> (blk_n,blk_k) - Tensor cZOH = make_identity_tensor(make_shape(size<0>(sZOH), size<1>(sZOH))); // (BLK_M,BLK_N) -> (blk_m,blk_n) - Tensor cAM = make_identity_tensor(make_shape(size<0>(sActiveMask), size<1>(sActiveMask))); // (BLK_M,BLK_N) -> (blk_m,blk_n) - // Tensor tScQ = thr_mma.partition_A(cQ); // (MMA,MMA_M,MMA_K) + Tensor cQ = make_identity_tensor(make_shape(size<0>(sQ), size<1>(sQ))); // (BLK_M, BLK_K) -> (blk_m, blk_k) + Tensor cKV = make_identity_tensor(make_shape(size<0>(sK), size<1>(sK))); // (BLK_N, BLK_K) -> (blk_n, blk_k) + Tensor cMask = make_identity_tensor(make_shape(size<0>(sMask), size<1>(sMask))); // (BLK_M, BLK_N) -> (blk_m, blk_n) + Tensor cBias = make_identity_tensor(make_shape(size<0>(sBias), size<1>(sBias))); // (BLK_M, BLK_N) -> (blk_m, blk_n) + // Tensor tScQ = thr_mma.partition_A(cQ); // (MMA, MMA_M, MMA_K) // if (cute::thread0()) { // print(tScQ.layout()); printf("\n"); // for (int i = 0; i < size(tScQ); ++i) { @@ -302,10 +302,10 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi // printf("\n"); // } // Repeat the partitioning with identity layouts - Tensor tQcQ = gmem_thr_copy_QKV.partition_S(cQ); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) - Tensor tKVcKV = gmem_thr_copy_QKV.partition_S(cKV); // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k) - Tensor tZOHcZOH = gmem_thr_copy_ZOH.partition_S(cZOH); // (ZOHCPY, ZOHCPY_M, ZOHCPY_N) -> (blk_m, blk_n) - Tensor tAMcAM = gmem_thr_copy_AM.partition_S(cAM); // (AMCPY, AMCPY_M, AMCPY_N) -> (blk_m, blk_n) + Tensor tQcQ = gmem_thr_copy_QKV.partition_S(cQ); // (ACPY, ACPY_M, ACPY_K) -> (blk_m, blk_k) + Tensor tKVcKV = gmem_thr_copy_QKV.partition_S(cKV); // (BCPY, BCPY_N, BCPY_K) -> (blk_n, blk_k) + Tensor tMaskcMask = gmem_thr_copy_Mask.partition_S(cMask); // (MaskCPY, MaskCPY_M, MaskCPY_N) -> (blk_m, blk_n) + Tensor tBiascBias = gmem_thr_copy_Bias.partition_S(cBias); // (BiasCPY, BiasCPY_M, BiasCPY_N) -> (blk_m, blk_n) // Allocate predicate tensors for k Tensor tQpQ = make_tensor(make_shape(size<2>(tQsQ))); Tensor tKVpKV = make_tensor(make_shape(size<2>(tKsK))); @@ -351,19 +351,19 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi tKsK, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN ); - FLASH_NAMESPACE::copy_ZOH( - gmem_tiled_copy_ZOH, - tZOHgZOH(_, _, _, n_block), - tZOHsZOH, - tZOHcZOH, + FLASH_NAMESPACE::copy_Mask( + gmem_tiled_copy_Mask, + 
tMaskgMask(_, _, _, n_block), + tMasksMask, + tMaskcMask, binfo.actual_seqlen_q - m_block * kBlockM, binfo.actual_seqlen_k - n_block * kBlockN ); - FLASH_NAMESPACE::copy_ZOH( - gmem_tiled_copy_AM, - tAMgAM(_, _, _, n_block), - tAMsAM, - tAMcAM, + FLASH_NAMESPACE::copy_Mask( + gmem_tiled_copy_Bias, + tBiasgBias(_, _, _, n_block), + tBiassBias, + tBiascBias, binfo.actual_seqlen_q - m_block * kBlockM, binfo.actual_seqlen_k - n_block * kBlockN ); @@ -384,7 +384,7 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi FLASH_NAMESPACE::Softmax<2 * size<1>(acc_o)> softmax; // Init dynamic mask processor - FLASH_NAMESPACE::DynamicMask dynamic_mask( + FLASH_NAMESPACE::Mask mask( binfo.actual_seqlen_k, binfo.actual_seqlen_q, params.keep_window_size ); @@ -407,11 +407,11 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi FLASH_NAMESPACE::cp_async_wait<0>(); __syncthreads(); - // Copy ZOH and ActiveMask from smem to registers - Tensor tSrZOH_copy_view = smem_thr_copy_ZOH.retile_D(tSrZOH); - cute::copy(smem_tiled_copy_ZOH, tSsZOH, tSrZOH_copy_view); - Tensor tSrAM_copy_view = smem_thr_copy_AM.retile_D(tSrAM); - cute::copy(smem_tiled_copy_AM, tSsAM, tSrAM_copy_view); + // Copy Mask and Bias from smem to registers + Tensor tSrMask_copy_view = smem_thr_copy_Mask.retile_D(tSrMask); + cute::copy(smem_tiled_copy_Mask, tSsMask, tSrMask_copy_view); + Tensor tSrBias_copy_view = smem_thr_copy_Bias.retile_D(tSrBias); + cute::copy(smem_tiled_copy_Bias, tSsBias, tSrBias_copy_view); // Advance gV if (masking_step > 0) { @@ -428,7 +428,7 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi FLASH_NAMESPACE::sparse_gemm( acc_s, tSrQ, - tSrK, tSsQ, tSsK, tSrAM, // Active key indices for sparse K matrix multiplication + tSrK, tSsQ, tSsK, tSrMask, // Active key mask for sparse K matrix multiplication tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K, smem_thr_copy_Q, smem_thr_copy_K ); @@ -437,9 +437,9 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi FLASH_NAMESPACE::apply_softcap(acc_s, params.softcap); } - // Scale attention scores and apply dynamic mask - dynamic_mask.template apply_mask( - acc_s, tSrZOH, tSrAM, params.scale_softmax, + // Scale attention scores and apply mask/bias + mask.template apply_mask( + acc_s, tSrMask, tSrBias, params.scale_softmax, n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16 ); @@ -449,19 +449,19 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi FLASH_NAMESPACE::copy(gmem_tiled_copy_QKV, tKgK(_, _, _, n_block - 1), tKsK, tKVcKV, tKVpKV); // This cp_async_fence needs to be in the if block, otherwise the synchronization // isn't right and we get race conditions. 
- FLASH_NAMESPACE::copy_ZOH( - gmem_tiled_copy_ZOH, - tZOHgZOH(_, _, _, n_block - 1), - tZOHsZOH, - tZOHcZOH, + FLASH_NAMESPACE::copy_Mask( + gmem_tiled_copy_Mask, + tMaskgMask(_, _, _, n_block - 1), + tMasksMask, + tMaskcMask, binfo.actual_seqlen_q - m_block * kBlockM, binfo.actual_seqlen_k - (n_block - 1) * kBlockN ); - FLASH_NAMESPACE::copy_ZOH( - gmem_tiled_copy_AM, - tAMgAM(_, _, _, n_block - 1), - tAMsAM, - tAMcAM, + FLASH_NAMESPACE::copy_Mask( + gmem_tiled_copy_Bias, + tBiasgBias(_, _, _, n_block - 1), + tBiassBias, + tBiascBias, binfo.actual_seqlen_q - m_block * kBlockM, binfo.actual_seqlen_k - (n_block - 1) * kBlockN ); @@ -500,7 +500,7 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi // Use sparse general matrix multiplication with register accumulation for V as well FLASH_NAMESPACE::sparse_gemm_rs( acc_o, - tOrP, tOrVt, tOsVt, tSrAM, // Apply the same mask for sparse V matrix multiplication + tOrP, tOrVt, tOsVt, tSrMask, // Apply the same mask for sparse V matrix multiplication tiled_mma, smem_tiled_copy_V, smem_thr_copy_V ); // if (cute::thread0()) { print(scores); } @@ -519,11 +519,11 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi FLASH_NAMESPACE::cp_async_wait<0>(); __syncthreads(); - // Copy ZOH and ActiveMask from smem to registers - Tensor tSrZOH_copy_view = smem_thr_copy_ZOH.retile_D(tSrZOH); - cute::copy(smem_tiled_copy_ZOH, tSsZOH, tSrZOH_copy_view); - Tensor tSrAM_copy_view = smem_thr_copy_AM.retile_D(tSrAM); - cute::copy(smem_tiled_copy_AM, tSsAM, tSrAM_copy_view); + // Copy Mask and Bias from smem to registers + Tensor tSrMask_copy_view = smem_thr_copy_Mask.retile_D(tSrMask); + cute::copy(smem_tiled_copy_Mask, tSsMask, tSrMask_copy_view); + Tensor tSrBias_copy_view = smem_thr_copy_Bias.retile_D(tSrBias); + cute::copy(smem_tiled_copy_Bias, tSsBias, tSrBias_copy_view); FLASH_NAMESPACE::copy(gmem_tiled_copy_QKV, tVgV(_, _, _, n_block), tVsV, tKVcKV, tKVpKV); cute::cp_async_fence(); @@ -531,7 +531,7 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi FLASH_NAMESPACE::sparse_gemm( acc_s, tSrQ, - tSrK, tSsQ, tSsK, tSrAM, // Active key indices for sparse K matrix multiplication + tSrK, tSsQ, tSsK, tSrMask, // Active key mask for sparse K matrix multiplication tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K, smem_thr_copy_Q, smem_thr_copy_K ); @@ -540,8 +540,8 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi } // Scale attention scores and apply dynamic mask - dynamic_mask.template apply_mask( - acc_s, tSrZOH, tSrAM, params.scale_softmax, + mask.template apply_mask( + acc_s, tSrMask, tSrBias, params.scale_softmax, n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16 ); @@ -549,19 +549,19 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi __syncthreads(); if (n_block > n_block_min) { FLASH_NAMESPACE::copy(gmem_tiled_copy_QKV, tKgK(_, _, _, n_block - 1), tKsK, tKVcKV, tKVpKV); - FLASH_NAMESPACE::copy_ZOH( - gmem_tiled_copy_ZOH, - tZOHgZOH(_, _, _, n_block - 1), - tZOHsZOH, - tZOHcZOH, + FLASH_NAMESPACE::copy_Mask( + gmem_tiled_copy_Mask, + tMaskgMask(_, _, _, n_block - 1), + tMasksMask, + tMaskcMask, binfo.actual_seqlen_q - m_block * kBlockM, binfo.actual_seqlen_k - (n_block - 1) * kBlockN ); - FLASH_NAMESPACE::copy_ZOH( - gmem_tiled_copy_AM, - tAMgAM(_, _, _, n_block - 1), - tAMsAM, - tAMcAM, + FLASH_NAMESPACE::copy_Mask( + gmem_tiled_copy_Bias, + tBiasgBias(_, _, _, n_block - 1), + 
tBiassBias, + tBiascBias, binfo.actual_seqlen_q - m_block * kBlockM, binfo.actual_seqlen_k - (n_block - 1) * kBlockN ); @@ -597,7 +597,7 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi // Use sparse general matrix multiplication with register accumulation for V as well FLASH_NAMESPACE::sparse_gemm_rs( acc_o, - tOrP, tOrVt, tOsVt, tSrAM, // Apply the same mask for sparse V matrix multiplication + tOrP, tOrVt, tOsVt, tSrMask, // Apply the same mask for sparse V matrix multiplication tiled_mma, smem_tiled_copy_V, smem_thr_copy_V ); } @@ -608,12 +608,12 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi // Convert acc_o from fp32 to fp16/bf16 Tensor rO = FLASH_NAMESPACE::convert_type(acc_o); - Tensor sO = make_tensor(sQ.data(), typename Kernel_traits::SmemLayoutO{}); // (SMEM_M,SMEM_N) + Tensor sO = make_tensor(sQ.data(), typename Kernel_traits::SmemLayoutO{}); // (SMEM_M, SMEM_N) // Partition sO to match the accumulator partitioning auto smem_tiled_copy_O = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma); auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice(tidx); - Tensor taccOrO = smem_thr_copy_O.retile_S(rO); // ((Atom,AtomNum), MMA_M, MMA_N) - Tensor taccOsO = smem_thr_copy_O.partition_D(sO); // ((Atom,AtomNum),PIPE_M,PIPE_N) + Tensor taccOrO = smem_thr_copy_O.retile_S(rO); // ((Atom, AtomNum), MMA_M, MMA_N) + Tensor taccOsO = smem_thr_copy_O.partition_D(sO); // ((Atom, AtomNum), PIPE_M, PIPE_N) // sO has the same size as sQ, so we don't need to sync here. if (Kernel_traits::Share_Q_K_smem) { __syncthreads(); } @@ -634,7 +634,7 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi typename Kernel_traits::GmemTiledCopyO gmem_tiled_copy_O; auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx); - Tensor tOsO = gmem_thr_copy_O.partition_S(sO); // ((Atom,AtomNum),ATOM_M,ATOM_N) + Tensor tOsO = gmem_thr_copy_O.partition_S(sO); // ((Atom, AtomNum), ATOM_M, ATOM_N) Tensor tOgO = gmem_thr_copy_O.partition_D(gO); __syncthreads(); @@ -642,12 +642,12 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi Tensor tOrO = make_tensor(shape(tOgO)); cute::copy(gmem_tiled_copy_O, tOsO, tOrO); - Tensor caccO = make_identity_tensor(Shape, Int>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k) - Tensor taccOcO = thr_mma.partition_C(caccO); // (MMA,MMA_M,MMA_K) + Tensor caccO = make_identity_tensor(Shape, Int>{}); // (BLK_M, BLK_K) -> (blk_m, blk_k) + Tensor taccOcO = thr_mma.partition_C(caccO); // (MMA, MMA_M, MMA_K) static_assert(decltype(size<0>(taccOcO))::value == 4); // Convert to ((2, 2), MMA_M, MMA_K) then take only the row indices. 
Tensor taccOcO_row = logical_divide(taccOcO, Shape<_2>{})(make_coord(0, _), _, 0); - CUTE_STATIC_ASSERT_V(size(lse) == size(taccOcO_row)); // MMA_M + CUTE_STATIC_ASSERT_V(size(lse) == size(taccOcO_row)); // MMA_M if (get<1>(taccOcO_row(0)) == 0) { #pragma unroll for (int mi = 0; mi < size(lse); ++mi) { @@ -657,9 +657,9 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi } // Construct identity layout for sO - Tensor cO = make_identity_tensor(make_shape(size<0>(sO), size<1>(sO))); // (BLK_M,BLK_K) -> (blk_m,blk_k) + Tensor cO = make_identity_tensor(make_shape(size<0>(sO), size<1>(sO))); // (BLK_M, BLK_K) -> (blk_m, blk_k) // Repeat the partitioning with identity layouts - Tensor tOcO = gmem_thr_copy_O.partition_D(cO); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) + Tensor tOcO = gmem_thr_copy_O.partition_D(cO); // (ACPY, ACPY_M, ACPY_K) -> (blk_m, blk_k) Tensor tOpO = make_tensor(make_shape(size<2>(tOgO))); if (!Is_even_K) { #pragma unroll diff --git a/csrc/src/kernel_traits.h b/csrc/src/kernel_traits.h index 19dc319..38b885d 100644 --- a/csrc/src/kernel_traits.h +++ b/csrc/src/kernel_traits.h @@ -82,11 +82,11 @@ struct Flash_fwd_kernel_traits : public Base { // This has to be kBlockKSmem, using kHeadDim gives wrong results for d=128 Layout>, Stride, _1>>{})); - using SmemLayoutAtomZOH = decltype( + using SmemLayoutAtomMask = decltype( composition(Swizzle{}, Layout, Stride<_8, _1>>{})); - using SmemLayoutAtomActiveMask = decltype( + using SmemLayoutAtomBias = decltype( composition(Swizzle{}, Layout, Stride<_8, _1>>{})); @@ -104,11 +104,11 @@ struct Flash_fwd_kernel_traits : public Base { composition(SmemLayoutKV{}, make_layout(Shape, Int>{}, GenRowMajor{}))); using SmemLayoutVtransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutVtransposed{})); - using SmemLayoutZOH = decltype(tile_to_shape( - SmemLayoutAtomZOH{}, + using SmemLayoutMask = decltype(tile_to_shape( + SmemLayoutAtomMask{}, Shape, Int>{})); - using SmemLayoutActiveMask = decltype(tile_to_shape( - SmemLayoutAtomActiveMask{}, + using SmemLayoutBias = decltype(tile_to_shape( + SmemLayoutAtomBias{}, Shape, Int>{})); // Shared memory layout for output @@ -125,10 +125,11 @@ struct Flash_fwd_kernel_traits : public Base { // Shared memory size calculations static constexpr int kSmemQSize = size(SmemLayoutQ{}) * sizeof(Element); static constexpr int kSmemKVSize = size(SmemLayoutKV{}) * 2 * sizeof(Element); - static constexpr int kSmemMaskSize = size(SmemLayoutZOH{}) * sizeof(Element) + size(SmemLayoutActiveMask{}) * sizeof(Element); + static constexpr int kSmemMaskSize = size(SmemLayoutMask{}) * sizeof(Element); + static constexpr int kSmemBiasSize = size(SmemLayoutBias{}) * sizeof(Element); - // Shared memory size with QKV matrices - static constexpr int kSmemSize = (Share_Q_K_smem ? std::max(kSmemQSize, kSmemKVSize) : kSmemQSize + kSmemKVSize) + kSmemMaskSize; + // Shared memory size with QKV matrices and mask/bias matrices + static constexpr int kSmemSize = (Share_Q_K_smem ? 
std::max(kSmemQSize, kSmemKVSize) : kSmemQSize + kSmemKVSize) + kSmemMaskSize + kSmemBiasSize; static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad"); @@ -153,11 +154,11 @@ struct Flash_fwd_kernel_traits : public Base { make_tiled_copy(Copy_Atom{}, GmemLayoutAtom{}, Layout>{})); // Val layout, 8 vals per read - using GmemTiledCopyZOH = decltype( + using GmemTiledCopyMask = decltype( make_tiled_copy(Copy_Atom, Element>{}, GmemLayoutAtom{}, Layout>{})); // Val layout, 4 vals per read - using GmemTiledCopyActiveMask = decltype( + using GmemTiledCopyBias = decltype( make_tiled_copy(Copy_Atom, Element>{}, GmemLayoutAtom{}, Layout>{})); // Val layout, 4 vals per read diff --git a/csrc/src/mask.h b/csrc/src/mask.h index 849d6e1..8a59444 100644 --- a/csrc/src/mask.h +++ b/csrc/src/mask.h @@ -21,11 +21,11 @@ namespace FLASH_NAMESPACE { using namespace cute; template -struct DynamicMask { +struct Mask { const int max_seqlen_k, max_seqlen_q; const int keep_window_size; - __forceinline__ __device__ DynamicMask( + __forceinline__ __device__ Mask( const int max_seqlen_k, const int max_seqlen_q, const int keep_window_size @@ -35,25 +35,25 @@ struct DynamicMask { , keep_window_size(keep_window_size) { }; - template + template __forceinline__ __device__ void apply_mask( TensorType &tensor_, // acc_s (attention scores, MMA=4, MMA_M, MMA_N) - ZOHType &tSrZOH, // ZOH states (MMA=4, MMA_M, MMA_N) - ActiveMaskType &tSrAM, // Active Mask (MMA=4, MMA_M, MMA_N) + MaskType &Mask, // Attention Mask (MMA=4, MMA_M, MMA_N) + BiasType &Bias, // Attention Bias (MMA=4, MMA_M, MMA_N) const float scale_softmax, // Scale for softmax const int col_idx_offset_, // Column index offset const int row_idx_offset, // Row index offset const int warp_row_stride // Warp row stride ) { static_assert(TensorType::rank == 3, "tensor_ must be 3D Tensor"); - static_assert(ZOHType::rank == 3, "tZOH must be 3D Tensor"); - static_assert(ActiveMaskType::rank == 3, "tActiveMask must be 3D Tensor"); + static_assert(MaskType::rank == 3, "Mask must be 3D Tensor"); + static_assert(BiasType::rank == 3, "Bias must be 3D Tensor"); static_assert(decltype(size<0>(tensor_))::value == 4, "First dimension must be 4"); const bool Need_masking = Causal_mask || !Is_even_MN || (keep_window_size < max_seqlen_k); // Reshape tensors from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) Tensor tensor = make_tensor(tensor_.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(tensor_.layout())); - Tensor zoh = make_tensor(tSrZOH.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(tSrZOH.layout())); - Tensor active_mask = make_tensor(tSrAM.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(tSrAM.layout())); + Tensor mask = make_tensor(Mask.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(Mask.layout())); + Tensor bias = make_tensor(Bias.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(Bias.layout())); const int lane_id = threadIdx.x % 32; const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2; @@ -72,19 +72,19 @@ struct DynamicMask { for (int j = 0; j < size<1, 0>(tensor); ++j) { const int col_idx = col_idx_base + j; auto coord = make_coord(make_coord(i, mi), make_coord(j, nj)); - bool inactive = (col_idx >= col_idx_limit) || (active_mask(coord) <= 0.0f); + bool inactive = (col_idx >= col_idx_limit) || (mask(coord) == 0.0f); if (inactive) { tensor(coord) = -INFINITY; } else { - // Apply scaling and zoh - tensor(coord) = 
tensor(coord) * scale_softmax + zoh(coord); + // Apply scaling and bias + tensor(coord) = tensor(coord) * scale_softmax + bias(coord); } } } } } } else { - // If no masking is needed, just scale the tensor and add zoh + // If no masking is needed, just scale the tensor and add bias #pragma unroll for (int mi = 0; mi < size<0, 1>(tensor); ++mi) { // const int row_idx_base = row_idx_offset + mi * warp_row_stride; @@ -98,7 +98,7 @@ struct DynamicMask { for (int j = 0; j < size<1, 0>(tensor); ++j) { // const int col_idx = col_idx_base + j; auto coord = make_coord(make_coord(i, mi), make_coord(j, nj)); - tensor(coord) = tensor(coord) * scale_softmax + zoh(coord); + tensor(coord) = tensor(coord) * scale_softmax + bias(coord); } } } diff --git a/csrc/src/utils.h b/csrc/src/utils.h index 2e2df1b..3203b23 100644 --- a/csrc/src/utils.h +++ b/csrc/src/utils.h @@ -500,7 +500,7 @@ __forceinline__ __device__ void copy( template -__forceinline__ __device__ void copy_ZOH( +__forceinline__ __device__ void copy_Mask( TiledCopy tiled_copy, Tensor const &S, Tensor &D, Tensor const &identity_MN, const int max_M=0, const int max_N=0
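
Usage sketch: a minimal example of driving the renamed forward entry point after this change, mirroring the updated benchmark hunk at the top of the diff. The argument order and the [batch, num_kv_heads, seqlen_q, seqlen_k] mask/bias shapes are taken from that hunk and from the mha_fwd comments in flash_api.cpp; the concrete sizes, dtype, softmax scale, and keep_window_size value below are arbitrary placeholders, and the flash_dma_cuda extension is assumed to be built and importable. Per mask.h, positions where attn_mask is 0 are dropped (their scores become -inf before softmax), while attn_bias is added to the scaled scores of the surviving positions.

import torch
import flash_dma_cuda  # extension built from csrc/ in this patch

batch, seqlen_q, seqlen_k = 2, 256, 256
num_heads, num_kv_heads, head_dim = 8, 2, 128   # num_heads must be a multiple of num_kv_heads
dtype = torch.bfloat16                          # mha_fwd only accepts fp16/bf16

q = torch.randn(batch, seqlen_q, num_heads, head_dim, device="cuda", dtype=dtype)
k = torch.randn(batch, seqlen_k, num_kv_heads, head_dim, device="cuda", dtype=dtype)
v = torch.randn(batch, seqlen_k, num_kv_heads, head_dim, device="cuda", dtype=dtype)

# attn_mask: 0.0 drops a position entirely; attn_bias is added to the scaled scores that remain.
attn_mask = torch.ones(batch, num_kv_heads, seqlen_q, seqlen_k, device="cuda", dtype=dtype)
attn_bias = torch.zeros(batch, num_kv_heads, seqlen_q, seqlen_k, device="cuda", dtype=dtype)

result = flash_dma_cuda.fwd(
    q, k, v,
    attn_mask,         # attn_mask: [batch, num_kv_heads, seqlen_q, seqlen_k]
    attn_bias,         # attn_bias: [batch, num_kv_heads, seqlen_q, seqlen_k]
    None,              # out: auto-allocate
    0.0,               # p_dropout
    head_dim ** -0.5,  # softmax_scale
    True,              # is_causal
    seqlen_k,          # keep_window_size (placeholder value)
    0.0,               # softcap
    False,             # return_softmax
    None,              # gen (generator)
)
out = result[0]        # [batch, seqlen_q, num_heads, head_dim]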