From 5ba150249b943ab6e71fd1e8423e835b15f9492b Mon Sep 17 00:00:00 2001 From: LoserCheems <3314685395@qq.com> Date: Wed, 14 May 2025 10:41:59 +0800 Subject: [PATCH 1/3] Add kernel_traits.h to csrc/src --- csrc/src/kernel_traits.h | 384 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 384 insertions(+) create mode 100644 csrc/src/kernel_traits.h diff --git a/csrc/src/kernel_traits.h b/csrc/src/kernel_traits.h new file mode 100644 index 0000000..9f6ebcc --- /dev/null +++ b/csrc/src/kernel_traits.h @@ -0,0 +1,384 @@ +/****************************************************************************** + * Copyright (c) 2025, Jingze Shi and Tri Dao. + ******************************************************************************/ + +#pragma once + +#include "cute/tensor.hpp" + +#include "cutlass/cutlass.h" +#include "cutlass/layout/layout.h" +#include + +using namespace cute; + +template +struct Flash_kernel_traits { + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + using Element = elem_type; + static constexpr bool Has_cp_async = true; +#else + using Element = cutlass::half_t; + static constexpr bool Has_cp_async = false; +#endif + + using ElementAccum = float; + using index_t = int64_t; + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + using MMA_Atom_Arch = std::conditional_t< + std::is_same_v, + MMA_Atom, + MMA_Atom + >; +#else + using MMA_Atom_Arch = MMA_Atom; +#endif + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750 + using SmemCopyAtom = Copy_Atom; + using SmemCopyAtomTransposed = Copy_Atom; +#else + using SmemCopyAtom = Copy_Atom; + using SmemCopyAtomTransposed = Copy_Atom; +#endif +}; + +// If Share_Q_K_smem is true, that forces Is_Q_in_regs to be true +template > +struct Flash_fwd_kernel_traits : public Base { + using Element = typename Base::Element; + using ElementAccum = typename Base::ElementAccum; + using index_t = typename Base::index_t; + static constexpr bool Has_cp_async = Base::Has_cp_async; + using SmemCopyAtom = typename Base::SmemCopyAtom; + using SmemCopyAtomTransposed = typename Base::SmemCopyAtomTransposed; + + static constexpr bool Share_Q_K_smem = Share_Q_K_smem_; + static constexpr bool Is_Q_in_regs = Is_Q_in_regs_ || Share_Q_K_smem; + + // The number of threads. + static constexpr int kNWarps = kNWarps_; + static constexpr int kNThreads = kNWarps * 32; + + static constexpr int kBlockM = kBlockM_; + static constexpr int kBlockN = kBlockN_; + static constexpr int kHeadDim = kHeadDim_; + static_assert(kHeadDim % 32 == 0); + static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32; + static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 64 : 32); + static constexpr int kSwizzle = kBlockKSmem == 32 ? 
2 : 3; + + using TiledMma = TiledMMA< + typename Base::MMA_Atom_Arch, + Layout,_1,_1>>, // 4x1x1 or 8x1x1 thread group + Tile, _16, _16>>; + + using SmemLayoutAtomQ = decltype( + composition(Swizzle{}, + // This has to be kBlockKSmem, using kHeadDim gives wrong results for d=128 + Layout>, + Stride, _1>>{})); + using SmemLayoutQ = decltype(tile_to_shape( + SmemLayoutAtomQ{}, + Shape, Int>{})); + + using SmemLayoutKV = decltype(tile_to_shape( + SmemLayoutAtomQ{}, + Shape, Int>{})); + + // https://github.com/ColfaxResearch/cutlass-kernels/blob/a222587e6d59b93ba704853d3946fb686d8b8892/src/fmha/fmha_forward.cu#L434 + using SmemLayoutVtransposed = decltype( + composition(SmemLayoutKV{}, make_layout(Shape, Int>{}, GenRowMajor{}))); + using SmemLayoutVtransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutVtransposed{})); + + using SmemLayoutAtomO = decltype( + composition(Swizzle{}, + Layout, Int>, + Stride, _1>>{})); + using SmemLayoutO = decltype(tile_to_shape( + SmemLayoutAtomO{}, + Shape, Int>{})); + using SmemCopyAtomO = Copy_Atom, Element>; + using SmemCopyAtomOaccum = Copy_Atom, ElementAccum>; + + static constexpr int kSmemQSize = size(SmemLayoutQ{}) * sizeof(Element); + static constexpr int kSmemKVSize = size(SmemLayoutKV{}) * 2 * sizeof(Element); + static constexpr int kSmemSize = Share_Q_K_smem ? std::max(kSmemQSize, kSmemKVSize) : kSmemQSize + kSmemKVSize; + + static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); + static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad"); + // Using kBlockKSmem here is 6-10% faster than kBlockKGmem for d=128 because of bank conflicts. + // For example, for d=128, smem is split into 2 "pages", each page takes care of columns + // 0-63 and 64-127. If we have 16 threads per row for gmem read, when we write to smem, + // thread 0 - 7 will write to the first page and thread 8 - 15 will write to the second page, + // to the same banks. + static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad; + static_assert(kNThreads % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow"); + using GmemLayoutAtom = Layout, Int>, + Stride, _1>>; + + // We use CACHEGLOBAL instead of CACHEALWAYS for both Q and K/V, since we won't be reading + // from the same address by the same threadblock. This is slightly faster. 
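As a rough sanity check on the shared-memory budget computed above (kSmemQSize, kSmemKVSize, kSmemSize), the following minimal standalone sketch works the arithmetic out for one hypothetical configuration; the tile sizes are illustrative and not taken from any launch configuration in this patch.

#include <algorithm>
#include <cstdio>

int main() {
    // Hypothetical tile: kBlockM=128, kBlockN=64, kHeadDim=64, fp16 elements (2 bytes each).
    constexpr int kBlockM = 128, kBlockN = 64, kHeadDim = 64, kElemSize = 2;
    constexpr bool Share_Q_K_smem = false;
    constexpr int kSmemQSize  = kBlockM * kHeadDim * kElemSize;       // Q tile: 16 KiB
    constexpr int kSmemKVSize = kBlockN * kHeadDim * 2 * kElemSize;   // K and V tiles: 16 KiB
    // Q either shares smem with K/V or sits alongside it, as in the traits above.
    constexpr int kSmemSize   = Share_Q_K_smem ? std::max(kSmemQSize, kSmemKVSize)
                                               : kSmemQSize + kSmemKVSize;  // 32 KiB
    std::printf("Q %d B, KV %d B, total %d B\n", kSmemQSize, kSmemKVSize, kSmemSize);
    return 0;
}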
+ using Gmem_copy_struct = std::conditional_t< + Has_cp_async, + SM80_CP_ASYNC_CACHEGLOBAL, + AutoVectorizingCopyWithAssumedAlignment<128> + >; + using GmemTiledCopyQKV = decltype( + make_tiled_copy(Copy_Atom{}, + GmemLayoutAtom{}, + Layout>{})); // Val layout, 8 vals per read + using GmemTiledCopyO = decltype( + make_tiled_copy(Copy_Atom, Element>{}, + GmemLayoutAtom{}, + Layout>{})); // Val layout, 8 vals per store + + using GmemLayoutAtomOaccum = std::conditional_t< + kBlockKSmem == 32, + Layout, // Thread layout, 8 threads per row + Stride< _8, _1>>, + Layout, // Thread layout, 16 threads per row + Stride< _16, _1>> + >; + using GmemTiledCopyOaccum = decltype( + make_tiled_copy(Copy_Atom, ElementAccum>{}, + GmemLayoutAtomOaccum{}, + Layout>{})); // Val layout, 4 vals per store + using GmemLayoutAtomRotcossin = GmemLayoutAtom; + using GmemTiledCopyRotcossin = decltype( + make_tiled_copy(Copy_Atom, Element>{}, + GmemLayoutAtomRotcossin{}, + Layout>{})); // Val layout, 4 vals per load using GmemTiledCopyRotcossinCont = decltype( + make_tiled_copy(Copy_Atom, Element>{}, + GmemLayoutAtomRotcossin{}, + Layout>{})); // Val layout, 8 vals per load + + // Zero hold + using SmemLayoutAtomZeroHold = decltype( + composition(Swizzle{}, + Layout>, + Stride, _1>>{})); + + using SmemLayoutZeroHold = decltype(tile_to_shape( + SmemLayoutAtomZeroHold{}, + Shape, Int<1>>{})); // One zero-hold per query + + static constexpr int kSmemZeroHoldSize = size(SmemLayoutZeroHold{}) * sizeof(Element); + + // The overall shared memory size needs to consider the zero-hold for dynamic mask + static constexpr int kSmemSizeWithZeroHold = kSmemSize + kSmemZeroHoldSize; + + using GmemLayoutAtomZeroHold = GmemLayoutAtom; + using GmemTiledCopyZeroHold = decltype( + make_tiled_copy(Copy_Atom{}, + GmemLayoutAtomZeroHold{}, + Layout>{})); // Val layout, 8 vals per read +}; + +// Is_V_in_regs is an option to reduce smem usage, but will increase register pressue. +// No_double_buffer is another option to reduce smem usage, but will slow things down. +template > +struct Flash_bwd_kernel_traits : public Base { + using Element = typename Base::Element; + using ElementAccum = typename Base::ElementAccum; + using index_t = typename Base::index_t; + static constexpr bool Has_cp_async = Base::Has_cp_async; + using SmemCopyAtom = typename Base::SmemCopyAtom; + using SmemCopyAtomTransposed = typename Base::SmemCopyAtomTransposed; + + static constexpr bool Is_V_in_regs = Is_V_in_regs_; + static constexpr bool No_double_buffer = No_double_buffer_; + static constexpr bool Is_dynamic_mask = Is_dynamic_mask_; + + // The number of threads. + static constexpr int kNWarps = kNWarps_; + static constexpr int kNThreads = kNWarps * 32; + + static constexpr int kBlockM = kBlockM_; + static constexpr int kBlockN = kBlockN_; + static constexpr int kHeadDim = kHeadDim_; + static_assert(kHeadDim % 32 == 0); + static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32; + static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 64 : 32); + static constexpr int kSwizzle = kBlockKSmem == 32 ? 
2 : 3; + + static constexpr int AtomLayoutMSdP = AtomLayoutMSdP_; + static_assert(kNWarps % AtomLayoutMSdP == 0); + static_assert(kNWarps % AtomLayoutNdKV == 0); + static_assert(kNWarps % AtomLayoutMdQ == 0); + + using TiledMmaSdP = TiledMMA< + typename Base::MMA_Atom_Arch, + Layout, Int, _1>>, + Tile, Int<16 * kNWarps / AtomLayoutMSdP>, _16>>; + + using TiledMmadKV = TiledMMA< + typename Base::MMA_Atom_Arch, + Layout, Int, _1>>, + Tile, Int<16 * kNWarps / AtomLayoutNdKV>, _16>>; + + using TiledMmadQ = TiledMMA< + typename Base::MMA_Atom_Arch, + Layout, Int, _1>>, // 2x4x1 or 4x2x1 thread group + Tile, Int<16 * kNWarps / AtomLayoutMdQ>, _16>>; + + using SmemLayoutAtomQdO = decltype( + composition(Swizzle{}, + Layout>, + Stride, _1>>{})); + using SmemLayoutQdO = decltype(tile_to_shape( + SmemLayoutAtomQdO{}, + make_shape(Int{}, Int{}))); + + using SmemLayoutAtomKV = decltype( + composition(Swizzle{}, + Layout, Int>, + Stride, _1>>{})); + using SmemLayoutKV = decltype(tile_to_shape( + // SmemLayoutAtomQdO{}, + SmemLayoutAtomKV{}, + make_shape(Int{}, Int{}))); + + using SmemLayoutKtransposed = decltype( + composition(SmemLayoutKV{}, make_layout(Shape, Int>{}, GenRowMajor{}))); + using SmemLayoutKtransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutKtransposed{})); + + // TODO: generalize to other values of kBlockN + // TODO: what should be the Swizzle here? 3 is faster than 1, and 1 is faster than 2 + // static constexpr int kPBlockN = kBlockN; + // Temporarily disabling this for hdim 256 on sm86 and sm89 + // static_assert(kBlockN >= 64); + static_assert(kBlockN >= 32); + // TD [2023-03-19]: Idk why kPBlockN = 16 and kSwizzlePdS=3 is the fastest. + static constexpr int kPBlockN = kBlockN >= 64 ? 64 : 32; + static_assert(kPBlockN == 16 || kPBlockN == 32 || kPBlockN == 64); + // static constexpr int kSwizzlePdS = kPBlockN == 16 ? 1 : (kPBlockN == 32 ? 2 : 3); + static constexpr int kSwizzlePdS = 3; + using SmemLayoutAtomPdS = decltype( + composition(Swizzle{}, + Layout, Int>, + Stride, _1>>{})); + using SmemLayoutPdS = decltype(tile_to_shape( + SmemLayoutAtomPdS{}, + make_shape(Int{}, Int{}))); + using SmemLayoutPdStransposed = decltype( + composition(SmemLayoutPdS{}, make_layout(Shape, Int>{}, GenRowMajor{}))); + using SmemLayoutPdStransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutPdStransposed{})); + + using SmemCopyAtomPdS = Copy_Atom, elem_type>; + + using SmemLayoutQdOtransposed = decltype( + composition(SmemLayoutQdO{}, make_layout(Shape, Int>{}, GenRowMajor{}))); + using SmemLayoutQdOtransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutQdOtransposed{})); + + using SmemLayoutAtomdKV = decltype( + composition(Swizzle{}, + Layout>, + Stride, _1>>{})); + using SmemLayoutdKV = decltype(tile_to_shape( + SmemLayoutAtomdKV{}, + make_shape(Int{}, Int{}))); + using SmemCopyAtomdKV = Copy_Atom, elem_type>; + + using SmemLayoutAtomdQ = decltype( + composition(Swizzle{}, + Layout>, + Stride, _1>>{})); + using SmemLayoutdQ = decltype(tile_to_shape( + SmemLayoutAtomdQ{}, + make_shape(Int{}, Int{}))); + using SmemCopyAtomdQ = Copy_Atom, elem_type>; + + // Double buffer for sQ + static constexpr int kSmemQdOSize = size(SmemLayoutQdO{}) * (No_double_buffer ? 
2 : 3) * sizeof(Element); + static constexpr int kSmemKVSize = size(SmemLayoutKV{}) * 2 * sizeof(Element); + static constexpr int kSmemdSSize = size(SmemLayoutPdS{}) * sizeof(Element); + static constexpr int kSmemPSize = size(SmemLayoutPdS{}) * sizeof(Element); + static constexpr int kSmemdQSize = size(SmemLayoutdQ{}) * sizeof(Element); + static constexpr int kSmemSize = kSmemQdOSize + + (!Is_V_in_regs + ? kSmemKVSize + kSmemdSSize + std::max(kSmemPSize, kSmemdQSize) + : std::max(kSmemKVSize, kSmemKVSize / 2 + kSmemdSSize + std::max(kSmemPSize, kSmemdQSize))); + static constexpr int kSmemSize1colblock = kSmemQdOSize + + (!Is_V_in_regs + ? kSmemKVSize + kSmemdSSize + kSmemPSize + : std::max(kSmemKVSize, kSmemKVSize / 2 + kSmemdSSize + kSmemPSize)); + + static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); + static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad"); + // Using kBlockKSmem instead of kHeadDim here to avoid bank conflicts, but doesn't seem + // to affect speed in practice. + static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad; + static_assert(kNThreads % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow"); + using GmemLayoutAtom = Layout, Int>, + Stride, _1>>; + + // We use CACHEGLOBAL instead of CACHEALWAYS for both Q and K/V, since we won't be reading + // from the same address by the same threadblock. This is slightly faster. + using Gmem_copy_struct = std::conditional_t< + Has_cp_async, + SM80_CP_ASYNC_CACHEGLOBAL, + AutoVectorizingCopyWithAssumedAlignment<128> + >; + using GmemTiledCopyQKV = decltype( + make_tiled_copy(Copy_Atom{}, + GmemLayoutAtom{}, + Layout>{})); // Val layout, 8 vals per read + using GmemTiledCopydO = decltype( + make_tiled_copy(Copy_Atom, elem_type>{}, + GmemLayoutAtom{}, + Layout>{})); // Val layout, 8 vals per store + using GmemTiledCopydKV = decltype( + make_tiled_copy(Copy_Atom, elem_type>{}, + GmemLayoutAtom{}, + Layout>{})); // Val layout, 8 vals per store + using GmemTiledCopydQ = decltype( + make_tiled_copy(Copy_Atom, elem_type>{}, + GmemLayoutAtom{}, + Layout>{})); // Val layout, 8 vals per store + using GmemLayoutAtomdQaccum = std::conditional_t< + kBlockKSmem == 32, + Layout, // Thread layout, 8 threads per row + Stride< _8, _1>>, + Layout, // Thread layout, 16 threads per row + Stride< _16, _1>> + >; + using GmemTiledCopydQaccum = decltype( + make_tiled_copy(Copy_Atom, ElementAccum>{}, + GmemLayoutAtomdQaccum{}, + Layout>{})); // Val layout, 4 vals per store + + using GmemTiledCopydQaccumAtomicAdd = decltype( + make_tiled_copy(Copy_Atom, ElementAccum>{}, + Layout, // Thread layout, 8 threads per row + Stride<_32, _1>>{}, + Layout>{})); // Val layout, 1 val per store + using SmemLayoutAtomZeroHold = decltype( + composition(Swizzle{}, + Layout>, + Stride, _1>>{})); + + using SmemLayoutZeroHold = decltype(tile_to_shape( + SmemLayoutAtomZeroHold{}, + Shape, Int<1>>{})); // One zero-hold per query + + static constexpr int kSmemZeroHoldSize = size(SmemLayoutZeroHold{}) * sizeof(Element); + + // The overall shared memory size needs to consider the zero-hold + static constexpr int kSmemSizeWithZeroHold = kSmemSize + + (Is_dynamic_mask ? 
kSmemZeroHoldSize : 0); + + using GmemLayoutAtomZeroHold = GmemLayoutAtom; + using GmemTiledCopyZeroHold = decltype( + make_tiled_copy(Copy_Atom{}, + GmemLayoutAtomZeroHold{}, + Layout>{})); // Val layout, 8 vals per read +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// From 0594716aad2a21b7056d68b81e50524a1c5ac3c4 Mon Sep 17 00:00:00 2001 From: LoserCheems <3314685395@qq.com> Date: Wed, 14 May 2025 11:57:50 +0800 Subject: [PATCH 2/3] Update mask.h --- csrc/src/mask.h | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/csrc/src/mask.h b/csrc/src/mask.h index 4508f25..d3ef098 100644 --- a/csrc/src/mask.h +++ b/csrc/src/mask.h @@ -125,13 +125,12 @@ __forceinline__ __device__ void apply_dynamic_mask_1rowblock( const Element* zero_hold_states, // Pre-calculated zero_hold states [key_len] const Element* causal_mask_ptr, // Causal mask values [key_len] const int key_len, // Sequence length for keys - const int keep_window_size // Maximum window size to keep + const int keep_window_size, // Maximum window size to keep + float* row_vals, // Shared memory buffer for mask values [key_len] + float* sort_keys, // Shared memory buffer for sorting keys [key_len] + int* sort_indices // Shared memory buffer for sorting indices [key_len] ) { static_assert(Layout::rank == 1, "Tensor must be 1D"); - extern __shared__ float shared_mem[]; - float* row_vals = shared_mem; // [key_len] - float* sort_keys = row_vals + key_len; // [key_len] - int* sort_indices = reinterpret_cast(sort_keys + key_len); // [key_len] int tid = threadIdx.x; // Load zero_hold and initialize row values @@ -169,10 +168,14 @@ struct DynamicMask { Tensor &tensor, const Element* zero_hold_states, const Element* causal_mask_ptr, - const int key_len + const int key_len, + float* row_vals, + float* sort_keys, + int* sort_indices ) { apply_dynamic_mask_1rowblock( - tensor, zero_hold_states, causal_mask_ptr, key_len, keep_window_size + tensor, zero_hold_states, causal_mask_ptr, key_len, keep_window_size, + row_vals, sort_keys, sort_indices ); } }; From 0aa1e4a2e078b1cacbbaf7671f8e725faefc0a56 Mon Sep 17 00:00:00 2001 From: LoserCheems <3314685395@qq.com> Date: Wed, 14 May 2025 11:59:23 +0800 Subject: [PATCH 3/3] Update kernel_traits.h --- csrc/src/kernel_traits.h | 153 ++++++++++++++++++++++++++------------- 1 file changed, 102 insertions(+), 51 deletions(-) diff --git a/csrc/src/kernel_traits.h b/csrc/src/kernel_traits.h index 9f6ebcc..7fd8272 100644 --- a/csrc/src/kernel_traits.h +++ b/csrc/src/kernel_traits.h @@ -12,6 +12,10 @@ using namespace cute; +/** + * Base traits class for Flash Attention kernels + * Contains common type definitions and architecture-specific settings + */ template struct Flash_kernel_traits { @@ -45,7 +49,11 @@ struct Flash_kernel_traits { #endif }; -// If Share_Q_K_smem is true, that forces Is_Q_in_regs to be true +/** + * Forward pass kernel traits + * Specializes the base traits for forward propagation with additional settings + * Supports dynamic mask attention + */ template > struct Flash_fwd_kernel_traits : public Base { @@ -56,26 +64,33 @@ struct Flash_fwd_kernel_traits : public Base { using SmemCopyAtom = typename Base::SmemCopyAtom; using SmemCopyAtomTransposed = typename Base::SmemCopyAtomTransposed; + // If Share_Q_K_smem is true, that forces Is_Q_in_regs to be true static constexpr bool Share_Q_K_smem = Share_Q_K_smem_; static constexpr bool Is_Q_in_regs = Is_Q_in_regs_ || Share_Q_K_smem; - // The number of threads. 
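The mask.h change above stops carving the scratch buffers out of extern __shared__ inside apply_dynamic_mask_1rowblock and instead takes row_vals, sort_keys and sort_indices from the caller. A minimal sketch of the corresponding caller-side partitioning is below; the helper name and the host-side setup are assumptions for illustration, while the buffer names, types and carving order follow the patch.

#include <cstdio>
#include <vector>

// Hypothetical helper: partition one contiguous scratch region into the three
// per-query buffers the new apply_dynamic_mask_1rowblock signature expects.
void partition_mask_scratch(char* scratch, int key_len,
                            float*& row_vals, float*& sort_keys, int*& sort_indices) {
    row_vals     = reinterpret_cast<float*>(scratch);            // [key_len] mask values
    sort_keys    = row_vals + key_len;                           // [key_len] sorting keys
    sort_indices = reinterpret_cast<int*>(sort_keys + key_len);  // [key_len] sorting indices
}

int main() {
    constexpr int key_len = 64;
    // Matches kDynamicMaskBufferPerQuery: two float buffers plus one int buffer per key.
    std::vector<char> scratch(key_len * (2 * sizeof(float) + sizeof(int)));
    float *row_vals, *sort_keys;
    int *sort_indices;
    partition_mask_scratch(scratch.data(), key_len, row_vals, sort_keys, sort_indices);
    std::printf("scratch bytes per query: %zu\n", scratch.size());
    return 0;
}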
+ // Thread configuration static constexpr int kNWarps = kNWarps_; static constexpr int kNThreads = kNWarps * 32; + // Block dimensions static constexpr int kBlockM = kBlockM_; static constexpr int kBlockN = kBlockN_; static constexpr int kHeadDim = kHeadDim_; + static_assert(kHeadDim % 32 == 0); + + // Memory layout constants static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32; static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 64 : 32); static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3; + // Define the MMA (matrix-multiply-accumulate) tiled structure using TiledMma = TiledMMA< typename Base::MMA_Atom_Arch, - Layout,_1,_1>>, // 4x1x1 or 8x1x1 thread group + Layout,_1,_1>>, // Thread group layout Tile, _16, _16>>; + // Shared memory layout for Q matrix using SmemLayoutAtomQ = decltype( composition(Swizzle{}, // This has to be kBlockKSmem, using kHeadDim gives wrong results for d=128 @@ -85,15 +100,17 @@ struct Flash_fwd_kernel_traits : public Base { SmemLayoutAtomQ{}, Shape, Int>{})); + // Shared memory layout for K and V matrices using SmemLayoutKV = decltype(tile_to_shape( SmemLayoutAtomQ{}, Shape, Int>{})); - // https://github.com/ColfaxResearch/cutlass-kernels/blob/a222587e6d59b93ba704853d3946fb686d8b8892/src/fmha/fmha_forward.cu#L434 + // Transposed layouts for V matrix using SmemLayoutVtransposed = decltype( composition(SmemLayoutKV{}, make_layout(Shape, Int>{}, GenRowMajor{}))); using SmemLayoutVtransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutVtransposed{})); + // Shared memory layout for output using SmemLayoutAtomO = decltype( composition(Swizzle{}, Layout, Int>, @@ -101,41 +118,72 @@ struct Flash_fwd_kernel_traits : public Base { using SmemLayoutO = decltype(tile_to_shape( SmemLayoutAtomO{}, Shape, Int>{})); + + // Copy atoms for output using SmemCopyAtomO = Copy_Atom, Element>; using SmemCopyAtomOaccum = Copy_Atom, ElementAccum>; + // Dynamic mask related definitions + using SmemLayoutAtomZeroHold = decltype( + composition(Swizzle{}, + Layout>, + Stride, _1>>{})); + + // Zero-hold states layout [kBlockM, kBlockN] + using SmemLayoutZeroHold = decltype(tile_to_shape( + SmemLayoutAtomZeroHold{}, + Shape, Int>{})); + + static constexpr int kSmemZeroHoldSize = size(SmemLayoutZeroHold{}) * sizeof(Element); + + // Dynamic mask memory allocation constants + static constexpr int kMaxKeysPerBlock = kBlockN; + static constexpr int kDynamicMaskBufferPerQuery = kMaxKeysPerBlock * (2 * sizeof(float) + sizeof(int)); + static constexpr int kTotalDynamicMaskBuffer = kBlockM * kDynamicMaskBufferPerQuery; + + // Shared memory size calculations static constexpr int kSmemQSize = size(SmemLayoutQ{}) * sizeof(Element); static constexpr int kSmemKVSize = size(SmemLayoutKV{}) * 2 * sizeof(Element); + + // Base shared memory size without dynamic mask buffer static constexpr int kSmemSize = Share_Q_K_smem ? std::max(kSmemQSize, kSmemKVSize) : kSmemQSize + kSmemKVSize; + + // Total shared memory size including dynamic mask buffer + static constexpr int kSmemSizeWithMask = kSmemSize + kTotalDynamicMaskBuffer; + // Global memory access configuration static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad"); - // Using kBlockKSmem here is 6-10% faster than kBlockKGmem for d=128 because of bank conflicts. - // For example, for d=128, smem is split into 2 "pages", each page takes care of columns - // 0-63 and 64-127. 
If we have 16 threads per row for gmem read, when we write to smem, - // thread 0 - 7 will write to the first page and thread 8 - 15 will write to the second page, - // to the same banks. + + // Using kBlockKSmem here is 6-10% faster than kBlockKGmem for d=128 because of bank conflicts static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad; static_assert(kNThreads % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow"); + using GmemLayoutAtom = Layout, Int>, Stride, _1>>; - // We use CACHEGLOBAL instead of CACHEALWAYS for both Q and K/V, since we won't be reading - // from the same address by the same threadblock. This is slightly faster. + // Global memory copy structures + // Using CACHEGLOBAL instead of CACHEALWAYS for both Q and K/V, since we won't be reading + // from the same address by the same threadblock using Gmem_copy_struct = std::conditional_t< Has_cp_async, SM80_CP_ASYNC_CACHEGLOBAL, AutoVectorizingCopyWithAssumedAlignment<128> >; + + // Tiled copy for QKV matrices using GmemTiledCopyQKV = decltype( make_tiled_copy(Copy_Atom{}, GmemLayoutAtom{}, Layout>{})); // Val layout, 8 vals per read + + // Tiled copy for output using GmemTiledCopyO = decltype( make_tiled_copy(Copy_Atom, Element>{}, GmemLayoutAtom{}, Layout>{})); // Val layout, 8 vals per store + // Accumulator layout for output using GmemLayoutAtomOaccum = std::conditional_t< kBlockKSmem == 32, Layout, // Thread layout, 8 threads per row @@ -143,34 +191,26 @@ struct Flash_fwd_kernel_traits : public Base { Layout, // Thread layout, 16 threads per row Stride< _16, _1>> >; + + // Tiled copy for output accumulator using GmemTiledCopyOaccum = decltype( make_tiled_copy(Copy_Atom, ElementAccum>{}, GmemLayoutAtomOaccum{}, Layout>{})); // Val layout, 4 vals per store + + // Rotary embedding related definitions using GmemLayoutAtomRotcossin = GmemLayoutAtom; using GmemTiledCopyRotcossin = decltype( make_tiled_copy(Copy_Atom, Element>{}, GmemLayoutAtomRotcossin{}, - Layout>{})); // Val layout, 4 vals per load using GmemTiledCopyRotcossinCont = decltype( + Layout>{})); // Val layout, 4 vals per load + + using GmemTiledCopyRotcossinCont = decltype( make_tiled_copy(Copy_Atom, Element>{}, GmemLayoutAtomRotcossin{}, Layout>{})); // Val layout, 8 vals per load - // Zero hold - using SmemLayoutAtomZeroHold = decltype( - composition(Swizzle{}, - Layout>, - Stride, _1>>{})); - - using SmemLayoutZeroHold = decltype(tile_to_shape( - SmemLayoutAtomZeroHold{}, - Shape, Int<1>>{})); // One zero-hold per query - - static constexpr int kSmemZeroHoldSize = size(SmemLayoutZeroHold{}) * sizeof(Element); - - // The overall shared memory size needs to consider the zero-hold for dynamic mask - static constexpr int kSmemSizeWithZeroHold = kSmemSize + kSmemZeroHoldSize; - + // Zero hold global memory operations using GmemLayoutAtomZeroHold = GmemLayoutAtom; using GmemTiledCopyZeroHold = decltype( make_tiled_copy(Copy_Atom{}, @@ -178,11 +218,14 @@ struct Flash_fwd_kernel_traits : public Base { Layout>{})); // Val layout, 8 vals per read }; -// Is_V_in_regs is an option to reduce smem usage, but will increase register pressue. -// No_double_buffer is another option to reduce smem usage, but will slow things down. 
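The forward-pass kSmemSizeWithMask above folds the per-query sorting scratch (kTotalDynamicMaskBuffer) into the shared-memory request. A standalone sketch of that arithmetic for one hypothetical tile follows; note that a budget past the default 48 KiB per block would typically require the launcher to raise cudaFuncAttributeMaxDynamicSharedMemorySize via cudaFuncSetAttribute on SM80-class hardware.

#include <cstdio>

int main() {
    // Hypothetical forward tile: kBlockM=128, kBlockN=64; floats and ints are 4 bytes each.
    constexpr int kBlockM = 128, kBlockN = 64;
    constexpr int kMaxKeysPerBlock = kBlockN;
    // Two float buffers (mask values, sort keys) plus one int buffer (sort indices) per key.
    constexpr int kDynamicMaskBufferPerQuery = kMaxKeysPerBlock * int(2 * sizeof(float) + sizeof(int));  // 768 B
    constexpr int kTotalDynamicMaskBuffer    = kBlockM * kDynamicMaskBufferPerQuery;                     // 96 KiB
    std::printf("per query %d B, per block %d B\n", kDynamicMaskBufferPerQuery, kTotalDynamicMaskBuffer);
    return 0;
}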
+/** + * Backward pass kernel traits + * Specializes the base traits for backward propagation + * Contains additional memory layout definitions for gradients + */ template > struct Flash_bwd_kernel_traits : public Base { using Element = typename Base::Element; @@ -192,27 +235,33 @@ struct Flash_bwd_kernel_traits : public Base { using SmemCopyAtom = typename Base::SmemCopyAtom; using SmemCopyAtomTransposed = typename Base::SmemCopyAtomTransposed; + // Memory optimization flags static constexpr bool Is_V_in_regs = Is_V_in_regs_; static constexpr bool No_double_buffer = No_double_buffer_; - static constexpr bool Is_dynamic_mask = Is_dynamic_mask_; - // The number of threads. + // Thread configuration static constexpr int kNWarps = kNWarps_; static constexpr int kNThreads = kNWarps * 32; + // Block dimensions static constexpr int kBlockM = kBlockM_; static constexpr int kBlockN = kBlockN_; static constexpr int kHeadDim = kHeadDim_; + static_assert(kHeadDim % 32 == 0); + + // Memory layout constants static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32; static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 64 : 32); static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3; + // Atom layout configuration static constexpr int AtomLayoutMSdP = AtomLayoutMSdP_; static_assert(kNWarps % AtomLayoutMSdP == 0); static_assert(kNWarps % AtomLayoutNdKV == 0); static_assert(kNWarps % AtomLayoutMdQ == 0); + // Define the MMA tiled structures for different computations using TiledMmaSdP = TiledMMA< typename Base::MMA_Atom_Arch, Layout, Int, _1>>, @@ -225,9 +274,10 @@ struct Flash_bwd_kernel_traits : public Base { using TiledMmadQ = TiledMMA< typename Base::MMA_Atom_Arch, - Layout, Int, _1>>, // 2x4x1 or 4x2x1 thread group + Layout, Int, _1>>, // Thread group layout Tile, Int<16 * kNWarps / AtomLayoutMdQ>, _16>>; + // Shared memory layouts using SmemLayoutAtomQdO = decltype( composition(Swizzle{}, Layout>, @@ -241,7 +291,6 @@ struct Flash_bwd_kernel_traits : public Base { Layout, Int>, Stride, _1>>{})); using SmemLayoutKV = decltype(tile_to_shape( - // SmemLayoutAtomQdO{}, SmemLayoutAtomKV{}, make_shape(Int{}, Int{}))); @@ -249,17 +298,12 @@ struct Flash_bwd_kernel_traits : public Base { composition(SmemLayoutKV{}, make_layout(Shape, Int>{}, GenRowMajor{}))); using SmemLayoutKtransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutKtransposed{})); - // TODO: generalize to other values of kBlockN - // TODO: what should be the Swizzle here? 3 is faster than 1, and 1 is faster than 2 - // static constexpr int kPBlockN = kBlockN; - // Temporarily disabling this for hdim 256 on sm86 and sm89 - // static_assert(kBlockN >= 64); + // PdS layout settings static_assert(kBlockN >= 32); - // TD [2023-03-19]: Idk why kPBlockN = 16 and kSwizzlePdS=3 is the fastest. static constexpr int kPBlockN = kBlockN >= 64 ? 64 : 32; static_assert(kPBlockN == 16 || kPBlockN == 32 || kPBlockN == 64); - // static constexpr int kSwizzlePdS = kPBlockN == 16 ? 1 : (kPBlockN == 32 ? 2 : 3); static constexpr int kSwizzlePdS = 3; + using SmemLayoutAtomPdS = decltype( composition(Swizzle{}, Layout, Int>, @@ -295,12 +339,13 @@ struct Flash_bwd_kernel_traits : public Base { make_shape(Int{}, Int{}))); using SmemCopyAtomdQ = Copy_Atom, elem_type>; - // Double buffer for sQ + // Shared memory size calculations static constexpr int kSmemQdOSize = size(SmemLayoutQdO{}) * (No_double_buffer ? 
2 : 3) * sizeof(Element); static constexpr int kSmemKVSize = size(SmemLayoutKV{}) * 2 * sizeof(Element); static constexpr int kSmemdSSize = size(SmemLayoutPdS{}) * sizeof(Element); static constexpr int kSmemPSize = size(SmemLayoutPdS{}) * sizeof(Element); static constexpr int kSmemdQSize = size(SmemLayoutdQ{}) * sizeof(Element); + static constexpr int kSmemSize = kSmemQdOSize + (!Is_V_in_regs ? kSmemKVSize + kSmemdSSize + std::max(kSmemPSize, kSmemdQSize) @@ -310,38 +355,43 @@ struct Flash_bwd_kernel_traits : public Base { ? kSmemKVSize + kSmemdSSize + kSmemPSize : std::max(kSmemKVSize, kSmemKVSize / 2 + kSmemdSSize + kSmemPSize)); + // Global memory access configuration static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad"); - // Using kBlockKSmem instead of kHeadDim here to avoid bank conflicts, but doesn't seem - // to affect speed in practice. + static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad; static_assert(kNThreads % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow"); + using GmemLayoutAtom = Layout, Int>, Stride, _1>>; - // We use CACHEGLOBAL instead of CACHEALWAYS for both Q and K/V, since we won't be reading - // from the same address by the same threadblock. This is slightly faster. + // Global memory copy structures using Gmem_copy_struct = std::conditional_t< Has_cp_async, SM80_CP_ASYNC_CACHEGLOBAL, AutoVectorizingCopyWithAssumedAlignment<128> >; + using GmemTiledCopyQKV = decltype( make_tiled_copy(Copy_Atom{}, GmemLayoutAtom{}, Layout>{})); // Val layout, 8 vals per read + using GmemTiledCopydO = decltype( make_tiled_copy(Copy_Atom, elem_type>{}, GmemLayoutAtom{}, Layout>{})); // Val layout, 8 vals per store + using GmemTiledCopydKV = decltype( make_tiled_copy(Copy_Atom, elem_type>{}, GmemLayoutAtom{}, Layout>{})); // Val layout, 8 vals per store + using GmemTiledCopydQ = decltype( make_tiled_copy(Copy_Atom, elem_type>{}, GmemLayoutAtom{}, Layout>{})); // Val layout, 8 vals per store + using GmemLayoutAtomdQaccum = std::conditional_t< kBlockKSmem == 32, Layout, // Thread layout, 8 threads per row @@ -349,6 +399,7 @@ struct Flash_bwd_kernel_traits : public Base { Layout, // Thread layout, 16 threads per row Stride< _16, _1>> >; + using GmemTiledCopydQaccum = decltype( make_tiled_copy(Copy_Atom, ElementAccum>{}, GmemLayoutAtomdQaccum{}, @@ -359,6 +410,9 @@ struct Flash_bwd_kernel_traits : public Base { Layout, // Thread layout, 8 threads per row Stride<_32, _1>>{}, Layout>{})); // Val layout, 1 val per store + + // Dynamic mask related definitions for backward pass + // Note: These are primarily handled in forward pass but kept for consistency using SmemLayoutAtomZeroHold = decltype( composition(Swizzle{}, Layout>, @@ -366,14 +420,11 @@ struct Flash_bwd_kernel_traits : public Base { using SmemLayoutZeroHold = decltype(tile_to_shape( SmemLayoutAtomZeroHold{}, - Shape, Int<1>>{})); // One zero-hold per query + Shape, Int>{})); static constexpr int kSmemZeroHoldSize = size(SmemLayoutZeroHold{}) * sizeof(Element); - // The overall shared memory size needs to consider the zero-hold - static constexpr int kSmemSizeWithZeroHold = kSmemSize + - (Is_dynamic_mask ? kSmemZeroHoldSize : 0); - + // Zero hold global memory operations using GmemLayoutAtomZeroHold = GmemLayoutAtom; using GmemTiledCopyZeroHold = decltype( make_tiled_copy(Copy_Atom{},
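For the backward pass, the kSmemSize expression above combines the Q/dO, K/V, dS, P and dQ tiles, with the Is_V_in_regs branch trading K/V residency for registers. The following minimal sketch evaluates that formula for one hypothetical configuration; it mirrors the accounting in the traits but is not a drop-in for the header, and the tile sizes are illustrative only.

#include <algorithm>
#include <cstdio>

int main() {
    constexpr int kBlockM = 64, kBlockN = 128, kHeadDim = 64, kElemSize = 2;  // fp16 elements
    constexpr bool Is_V_in_regs = false, No_double_buffer = false;
    constexpr int kSmemQdOSize = kBlockM * kHeadDim * (No_double_buffer ? 2 : 3) * kElemSize;  // Q, dO (+ double-buffered Q)
    constexpr int kSmemKVSize  = kBlockN * kHeadDim * 2 * kElemSize;                           // K and V
    constexpr int kSmemdSSize  = kBlockM * kBlockN * kElemSize;                                // dS tile
    constexpr int kSmemPSize   = kBlockM * kBlockN * kElemSize;                                // P tile
    constexpr int kSmemdQSize  = kBlockM * kHeadDim * kElemSize;                               // dQ tile
    constexpr int kSmemSize = kSmemQdOSize
        + (!Is_V_in_regs
            ? kSmemKVSize + kSmemdSSize + std::max(kSmemPSize, kSmemdQSize)
            : std::max(kSmemKVSize, kSmemKVSize / 2 + kSmemdSSize + std::max(kSmemPSize, kSmemdQSize)));
    std::printf("bwd smem budget: %d B\n", kSmemSize);
    return 0;
}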