Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 17 additions & 17 deletions csrc/src/flash_fwd_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
typename Kernel_traits::SmemLayoutBias{}
);

// Golobal to Shared Memory operation
// Global to Shared Memory operation
Copy link

Copilot AI Jul 14, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed spelling error from 'Golobal' to 'Global' in the comment.

Copilot uses AI. Check for mistakes.
typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV;
auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx);
typename Kernel_traits::GmemTiledCopyMask gmem_tiled_copy_Mask;
Expand Down Expand Up @@ -273,10 +273,10 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
auto smem_tiled_copy_V = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma);
auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx);
Tensor tOsVt = smem_thr_copy_V.partition_S(sVt);
auto smem_tiled_copy_Mask = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma);
auto smem_tiled_copy_Mask = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomMask{}, tiled_mma);
auto smem_thr_copy_Mask = smem_tiled_copy_Mask.get_thread_slice(tidx);
Tensor tSsMask = smem_thr_copy_Mask.partition_S(sMask);
auto smem_tiled_copy_Bias = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma);
auto smem_tiled_copy_Bias = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomBias{}, tiled_mma);
auto smem_thr_copy_Bias = smem_tiled_copy_Bias.get_thread_slice(tidx);
Tensor tSsBias = smem_thr_copy_Bias.partition_S(sBias);

Expand Down Expand Up @@ -351,15 +351,15 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
tKsK, tKVcKV, tKVpKV,
binfo.actual_seqlen_k - n_block * kBlockN
);
FLASH_NAMESPACE::copy_Mask<Is_even_MN>(
FLASH_NAMESPACE::copy_MN<Is_even_MN>(
gmem_tiled_copy_Mask,
tMaskgMask(_, _, _, n_block),
tMasksMask,
tMaskcMask,
binfo.actual_seqlen_q - m_block * kBlockM,
binfo.actual_seqlen_k - n_block * kBlockN
);
FLASH_NAMESPACE::copy_Mask<Is_even_MN>(
FLASH_NAMESPACE::copy_MN<Is_even_MN>(
gmem_tiled_copy_Bias,
tBiasgBias(_, _, _, n_block),
tBiassBias,
Expand Down Expand Up @@ -449,15 +449,15 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
FLASH_NAMESPACE::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tKgK(_, _, _, n_block - 1), tKsK, tKVcKV, tKVpKV);
// This cp_async_fence needs to be in the if block, otherwise the synchronization
// isn't right and we get race conditions.
FLASH_NAMESPACE::copy_Mask</*Is_even_MN=*/true>(
FLASH_NAMESPACE::copy_MN</*Is_even_MN=*/true>(
gmem_tiled_copy_Mask,
tMaskgMask(_, _, _, n_block - 1),
tMasksMask,
tMaskcMask,
binfo.actual_seqlen_q - m_block * kBlockM,
binfo.actual_seqlen_k - (n_block - 1) * kBlockN
);
FLASH_NAMESPACE::copy_Mask</*Is_even_MN=*/true>(
FLASH_NAMESPACE::copy_MN</*Is_even_MN=*/true>(
gmem_tiled_copy_Bias,
tBiasgBias(_, _, _, n_block - 1),
tBiassBias,
Expand Down Expand Up @@ -549,15 +549,15 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
__syncthreads();
if (n_block > n_block_min) {
FLASH_NAMESPACE::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tKgK(_, _, _, n_block - 1), tKsK, tKVcKV, tKVpKV);
FLASH_NAMESPACE::copy_Mask</*Is_even_MN=*/true>(
FLASH_NAMESPACE::copy_MN</*Is_even_MN=*/true>(
gmem_tiled_copy_Mask,
tMaskgMask(_, _, _, n_block - 1),
tMasksMask,
tMaskcMask,
binfo.actual_seqlen_q - m_block * kBlockM,
binfo.actual_seqlen_k - (n_block - 1) * kBlockN
);
FLASH_NAMESPACE::copy_Mask</*Is_even_MN=*/true>(
FLASH_NAMESPACE::copy_MN</*Is_even_MN=*/true>(
gmem_tiled_copy_Bias,
tBiasgBias(_, _, _, n_block - 1),
tBiassBias,
Expand Down Expand Up @@ -883,10 +883,10 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
auto smem_tiled_copy_V = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma);
auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx);
Tensor tOsVt = smem_thr_copy_V.partition_S(sVt);
auto smem_tiled_copy_Mask = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma);
auto smem_tiled_copy_Mask = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomMask{}, tiled_mma);
auto smem_thr_copy_Mask = smem_tiled_copy_Mask.get_thread_slice(tidx);
Tensor tSsMask = smem_thr_copy_Mask.partition_S(sMask);
auto smem_tiled_copy_Bias = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma);
auto smem_tiled_copy_Bias = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomBias{}, tiled_mma);
auto smem_thr_copy_Bias = smem_tiled_copy_Bias.get_thread_slice(tidx);
Tensor tSsBias = smem_thr_copy_Bias.partition_S(sBias);

Expand Down Expand Up @@ -932,15 +932,15 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
tKgK, tKsK, tKVcKV, tKVpKV,
binfo.actual_seqlen_k - n_block * kBlockN
);
FLASH_NAMESPACE::copy_Mask<Is_even_MN>(
FLASH_NAMESPACE::copy_MN<Is_even_MN>(
gmem_tiled_copy_Mask,
tMaskgMask,
tMasksMask,
tMaskcMask,
binfo.actual_seqlen_q - m_block * kBlockM,
binfo.actual_seqlen_k - n_block * kBlockN
);
FLASH_NAMESPACE::copy_Mask<Is_even_MN>(
FLASH_NAMESPACE::copy_MN<Is_even_MN>(
gmem_tiled_copy_Bias,
tBiasgBias,
tBiassBias,
Expand Down Expand Up @@ -1051,15 +1051,15 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
FLASH_NAMESPACE::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV);
// This cp_async_fence needs to be in the if block, otherwise the synchronization
// isn't right and we get race conditions.
FLASH_NAMESPACE::copy_Mask</*Is_even_MN=*/true>(
FLASH_NAMESPACE::copy_MN</*Is_even_MN=*/true>(
gmem_tiled_copy_Mask,
tMaskgMask,
tMasksMask,
tMaskcMask,
binfo.actual_seqlen_q - m_block * kBlockM,
binfo.actual_seqlen_k - (n_block - 1) * kBlockN
);
FLASH_NAMESPACE::copy_Mask</*Is_even_MN=*/true>(
FLASH_NAMESPACE::copy_MN</*Is_even_MN=*/true>(
gmem_tiled_copy_Bias,
tBiasgBias,
tBiassBias,
Expand Down Expand Up @@ -1159,15 +1159,15 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
tBiasgBias.data() = tBiasgBias.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.attn_bias_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.attn_bias_col_stride;
}
FLASH_NAMESPACE::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV);
FLASH_NAMESPACE::copy_Mask</*Is_even_MN=*/true>(
FLASH_NAMESPACE::copy_MN</*Is_even_MN=*/true>(
gmem_tiled_copy_Mask,
tMaskgMask,
tMasksMask,
tMaskcMask,
binfo.actual_seqlen_q - m_block * kBlockM,
binfo.actual_seqlen_k - (n_block - 1) * kBlockN
);
FLASH_NAMESPACE::copy_Mask</*Is_even_MN=*/true>(
FLASH_NAMESPACE::copy_MN</*Is_even_MN=*/true>(
gmem_tiled_copy_Bias,
tBiasgBias,
tBiassBias,
Expand Down
5 changes: 4 additions & 1 deletion csrc/src/kernel_traits.h
Original file line number Diff line number Diff line change
Expand Up @@ -107,9 +107,11 @@ struct Flash_fwd_kernel_traits : public Base {
using SmemLayoutMask = decltype(tile_to_shape(
SmemLayoutAtomMask{},
Shape<Int<kBlockM>, Int<kBlockN>>{}));
using SmemCopyAtomMask = Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, Element>;
using SmemLayoutBias = decltype(tile_to_shape(
SmemLayoutAtomBias{},
Shape<Int<kBlockM>, Int<kBlockN>>{}));
using SmemCopyAtomBias = Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, Element>;

// Shared memory layout for output
using SmemLayoutAtomO = decltype(
Expand Down Expand Up @@ -268,6 +270,7 @@ struct Flash_bwd_kernel_traits : public Base {
using SmemLayoutMask = decltype(tile_to_shape(
SmemLayoutAtomMask{},
make_shape(Int<kBlockM>{}, Int<kBlockN>{})));
using SmemCopyAtomMask = Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, elem_type>;

using SmemLayoutAtomBias = decltype(
composition(Swizzle<kSwizzle, 3, 3>{},
Expand All @@ -276,6 +279,7 @@ struct Flash_bwd_kernel_traits : public Base {
using SmemLayoutBias = decltype(tile_to_shape(
SmemLayoutAtomBias{},
make_shape(Int<kBlockM>{}, Int<kBlockN>{})));
using SmemCopyAtomBias = Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, elem_type>;

// TODO: generalize to other values of kBlockN
// TODO: what should be the Swizzle here? 3 is faster than 1, and 1 is faster than 2
Expand Down Expand Up @@ -322,7 +326,6 @@ struct Flash_bwd_kernel_traits : public Base {
SmemLayoutAtomdQ{},
make_shape(Int<kBlockM>{}, Int<kHeadDim>{})));
using SmemCopyAtomdQ = Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, elem_type>;
using SmemCopyAtomBias = Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<64>, elem_type>;

// Double buffer for sQ
static constexpr int kSmemQdOSize = size(SmemLayoutQdO{}) * (No_double_buffer ? 2 : 3) * sizeof(Element);
Expand Down
2 changes: 1 addition & 1 deletion csrc/src/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -495,7 +495,7 @@ __forceinline__ __device__ void copy(
template <bool Is_even_MN=true, bool Clear_OOB_MN=true,
typename TiledCopy, typename Engine0, typename Layout0, typename Engine1, typename Layout1,
typename Engine2, typename Layout2>
__forceinline__ __device__ void copy_Mask(
__forceinline__ __device__ void copy_MN(
TiledCopy tiled_copy, Tensor<Engine0, Layout0> const &S,
Tensor<Engine1, Layout1> &D, Tensor<Engine2, Layout2> const &identity_MN,
const int max_M=0, const int max_N=0
Expand Down