diff --git a/csrc/src/kernel_traits.h b/csrc/src/kernel_traits.h new file mode 100644 index 0000000..7fd8272 --- /dev/null +++ b/csrc/src/kernel_traits.h @@ -0,0 +1,435 @@ +/****************************************************************************** + * Copyright (c) 2025, Jingze Shi and Tri Dao. + ******************************************************************************/ + +#pragma once + +#include "cute/tensor.hpp" + +#include "cutlass/cutlass.h" +#include "cutlass/layout/layout.h" +#include + +using namespace cute; + +/** + * Base traits class for Flash Attention kernels + * Contains common type definitions and architecture-specific settings + */ +template +struct Flash_kernel_traits { + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + using Element = elem_type; + static constexpr bool Has_cp_async = true; +#else + using Element = cutlass::half_t; + static constexpr bool Has_cp_async = false; +#endif + + using ElementAccum = float; + using index_t = int64_t; + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + using MMA_Atom_Arch = std::conditional_t< + std::is_same_v, + MMA_Atom, + MMA_Atom + >; +#else + using MMA_Atom_Arch = MMA_Atom; +#endif + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750 + using SmemCopyAtom = Copy_Atom; + using SmemCopyAtomTransposed = Copy_Atom; +#else + using SmemCopyAtom = Copy_Atom; + using SmemCopyAtomTransposed = Copy_Atom; +#endif +}; + +/** + * Forward pass kernel traits + * Specializes the base traits for forward propagation with additional settings + * Supports dynamic mask attention + */ +template > +struct Flash_fwd_kernel_traits : public Base { + using Element = typename Base::Element; + using ElementAccum = typename Base::ElementAccum; + using index_t = typename Base::index_t; + static constexpr bool Has_cp_async = Base::Has_cp_async; + using SmemCopyAtom = typename Base::SmemCopyAtom; + using SmemCopyAtomTransposed = typename Base::SmemCopyAtomTransposed; + + // If Share_Q_K_smem is true, that forces Is_Q_in_regs to be true + static constexpr bool Share_Q_K_smem = Share_Q_K_smem_; + static constexpr bool Is_Q_in_regs = Is_Q_in_regs_ || Share_Q_K_smem; + + // Thread configuration + static constexpr int kNWarps = kNWarps_; + static constexpr int kNThreads = kNWarps * 32; + + // Block dimensions + static constexpr int kBlockM = kBlockM_; + static constexpr int kBlockN = kBlockN_; + static constexpr int kHeadDim = kHeadDim_; + + static_assert(kHeadDim % 32 == 0); + + // Memory layout constants + static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32; + static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 64 : 32); + static constexpr int kSwizzle = kBlockKSmem == 32 ? 
2 : 3; + + // Define the MMA (matrix-multiply-accumulate) tiled structure + using TiledMma = TiledMMA< + typename Base::MMA_Atom_Arch, + Layout,_1,_1>>, // Thread group layout + Tile, _16, _16>>; + + // Shared memory layout for Q matrix + using SmemLayoutAtomQ = decltype( + composition(Swizzle{}, + // This has to be kBlockKSmem, using kHeadDim gives wrong results for d=128 + Layout>, + Stride, _1>>{})); + using SmemLayoutQ = decltype(tile_to_shape( + SmemLayoutAtomQ{}, + Shape, Int>{})); + + // Shared memory layout for K and V matrices + using SmemLayoutKV = decltype(tile_to_shape( + SmemLayoutAtomQ{}, + Shape, Int>{})); + + // Transposed layouts for V matrix + using SmemLayoutVtransposed = decltype( + composition(SmemLayoutKV{}, make_layout(Shape, Int>{}, GenRowMajor{}))); + using SmemLayoutVtransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutVtransposed{})); + + // Shared memory layout for output + using SmemLayoutAtomO = decltype( + composition(Swizzle{}, + Layout, Int>, + Stride, _1>>{})); + using SmemLayoutO = decltype(tile_to_shape( + SmemLayoutAtomO{}, + Shape, Int>{})); + + // Copy atoms for output + using SmemCopyAtomO = Copy_Atom, Element>; + using SmemCopyAtomOaccum = Copy_Atom, ElementAccum>; + + // Dynamic mask related definitions + using SmemLayoutAtomZeroHold = decltype( + composition(Swizzle{}, + Layout>, + Stride, _1>>{})); + + // Zero-hold states layout [kBlockM, kBlockN] + using SmemLayoutZeroHold = decltype(tile_to_shape( + SmemLayoutAtomZeroHold{}, + Shape, Int>{})); + + static constexpr int kSmemZeroHoldSize = size(SmemLayoutZeroHold{}) * sizeof(Element); + + // Dynamic mask memory allocation constants + static constexpr int kMaxKeysPerBlock = kBlockN; + static constexpr int kDynamicMaskBufferPerQuery = kMaxKeysPerBlock * (2 * sizeof(float) + sizeof(int)); + static constexpr int kTotalDynamicMaskBuffer = kBlockM * kDynamicMaskBufferPerQuery; + + // Shared memory size calculations + static constexpr int kSmemQSize = size(SmemLayoutQ{}) * sizeof(Element); + static constexpr int kSmemKVSize = size(SmemLayoutKV{}) * 2 * sizeof(Element); + + // Base shared memory size without dynamic mask buffer + static constexpr int kSmemSize = Share_Q_K_smem ? 
std::max(kSmemQSize, kSmemKVSize) : kSmemQSize + kSmemKVSize; + + // Total shared memory size including dynamic mask buffer + static constexpr int kSmemSizeWithMask = kSmemSize + kTotalDynamicMaskBuffer; + + // Global memory access configuration + static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); + static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad"); + + // Using kBlockKSmem here is 6-10% faster than kBlockKGmem for d=128 because of bank conflicts + static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad; + static_assert(kNThreads % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow"); + + using GmemLayoutAtom = Layout, Int>, + Stride, _1>>; + + // Global memory copy structures + // Using CACHEGLOBAL instead of CACHEALWAYS for both Q and K/V, since we won't be reading + // from the same address by the same threadblock + using Gmem_copy_struct = std::conditional_t< + Has_cp_async, + SM80_CP_ASYNC_CACHEGLOBAL, + AutoVectorizingCopyWithAssumedAlignment<128> + >; + + // Tiled copy for QKV matrices + using GmemTiledCopyQKV = decltype( + make_tiled_copy(Copy_Atom{}, + GmemLayoutAtom{}, + Layout>{})); // Val layout, 8 vals per read + + // Tiled copy for output + using GmemTiledCopyO = decltype( + make_tiled_copy(Copy_Atom, Element>{}, + GmemLayoutAtom{}, + Layout>{})); // Val layout, 8 vals per store + + // Accumulator layout for output + using GmemLayoutAtomOaccum = std::conditional_t< + kBlockKSmem == 32, + Layout, // Thread layout, 8 threads per row + Stride< _8, _1>>, + Layout, // Thread layout, 16 threads per row + Stride< _16, _1>> + >; + + // Tiled copy for output accumulator + using GmemTiledCopyOaccum = decltype( + make_tiled_copy(Copy_Atom, ElementAccum>{}, + GmemLayoutAtomOaccum{}, + Layout>{})); // Val layout, 4 vals per store + + // Rotary embedding related definitions + using GmemLayoutAtomRotcossin = GmemLayoutAtom; + using GmemTiledCopyRotcossin = decltype( + make_tiled_copy(Copy_Atom, Element>{}, + GmemLayoutAtomRotcossin{}, + Layout>{})); // Val layout, 4 vals per load + + using GmemTiledCopyRotcossinCont = decltype( + make_tiled_copy(Copy_Atom, Element>{}, + GmemLayoutAtomRotcossin{}, + Layout>{})); // Val layout, 8 vals per load + + // Zero hold global memory operations + using GmemLayoutAtomZeroHold = GmemLayoutAtom; + using GmemTiledCopyZeroHold = decltype( + make_tiled_copy(Copy_Atom{}, + GmemLayoutAtomZeroHold{}, + Layout>{})); // Val layout, 8 vals per read +}; + +/** + * Backward pass kernel traits + * Specializes the base traits for backward propagation + * Contains additional memory layout definitions for gradients + */ +template > +struct Flash_bwd_kernel_traits : public Base { + using Element = typename Base::Element; + using ElementAccum = typename Base::ElementAccum; + using index_t = typename Base::index_t; + static constexpr bool Has_cp_async = Base::Has_cp_async; + using SmemCopyAtom = typename Base::SmemCopyAtom; + using SmemCopyAtomTransposed = typename Base::SmemCopyAtomTransposed; + + // Memory optimization flags + static constexpr bool Is_V_in_regs = Is_V_in_regs_; + static constexpr bool No_double_buffer = No_double_buffer_; + + // Thread configuration + static constexpr int kNWarps = kNWarps_; + static constexpr int kNThreads = kNWarps * 32; + + // Block dimensions + static constexpr int kBlockM = kBlockM_; + static constexpr int kBlockN = kBlockN_; + static constexpr int kHeadDim = kHeadDim_; + + static_assert(kHeadDim % 32 == 
0); + + // Memory layout constants + static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32; + static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 64 : 32); + static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3; + + // Atom layout configuration + static constexpr int AtomLayoutMSdP = AtomLayoutMSdP_; + static_assert(kNWarps % AtomLayoutMSdP == 0); + static_assert(kNWarps % AtomLayoutNdKV == 0); + static_assert(kNWarps % AtomLayoutMdQ == 0); + + // Define the MMA tiled structures for different computations + using TiledMmaSdP = TiledMMA< + typename Base::MMA_Atom_Arch, + Layout, Int, _1>>, + Tile, Int<16 * kNWarps / AtomLayoutMSdP>, _16>>; + + using TiledMmadKV = TiledMMA< + typename Base::MMA_Atom_Arch, + Layout, Int, _1>>, + Tile, Int<16 * kNWarps / AtomLayoutNdKV>, _16>>; + + using TiledMmadQ = TiledMMA< + typename Base::MMA_Atom_Arch, + Layout, Int, _1>>, // Thread group layout + Tile, Int<16 * kNWarps / AtomLayoutMdQ>, _16>>; + + // Shared memory layouts + using SmemLayoutAtomQdO = decltype( + composition(Swizzle{}, + Layout>, + Stride, _1>>{})); + using SmemLayoutQdO = decltype(tile_to_shape( + SmemLayoutAtomQdO{}, + make_shape(Int{}, Int{}))); + + using SmemLayoutAtomKV = decltype( + composition(Swizzle{}, + Layout, Int>, + Stride, _1>>{})); + using SmemLayoutKV = decltype(tile_to_shape( + SmemLayoutAtomKV{}, + make_shape(Int{}, Int{}))); + + using SmemLayoutKtransposed = decltype( + composition(SmemLayoutKV{}, make_layout(Shape, Int>{}, GenRowMajor{}))); + using SmemLayoutKtransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutKtransposed{})); + + // PdS layout settings + static_assert(kBlockN >= 32); + static constexpr int kPBlockN = kBlockN >= 64 ? 64 : 32; + static_assert(kPBlockN == 16 || kPBlockN == 32 || kPBlockN == 64); + static constexpr int kSwizzlePdS = 3; + + using SmemLayoutAtomPdS = decltype( + composition(Swizzle{}, + Layout, Int>, + Stride, _1>>{})); + using SmemLayoutPdS = decltype(tile_to_shape( + SmemLayoutAtomPdS{}, + make_shape(Int{}, Int{}))); + using SmemLayoutPdStransposed = decltype( + composition(SmemLayoutPdS{}, make_layout(Shape, Int>{}, GenRowMajor{}))); + using SmemLayoutPdStransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutPdStransposed{})); + + using SmemCopyAtomPdS = Copy_Atom, elem_type>; + + using SmemLayoutQdOtransposed = decltype( + composition(SmemLayoutQdO{}, make_layout(Shape, Int>{}, GenRowMajor{}))); + using SmemLayoutQdOtransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutQdOtransposed{})); + + using SmemLayoutAtomdKV = decltype( + composition(Swizzle{}, + Layout>, + Stride, _1>>{})); + using SmemLayoutdKV = decltype(tile_to_shape( + SmemLayoutAtomdKV{}, + make_shape(Int{}, Int{}))); + using SmemCopyAtomdKV = Copy_Atom, elem_type>; + + using SmemLayoutAtomdQ = decltype( + composition(Swizzle{}, + Layout>, + Stride, _1>>{})); + using SmemLayoutdQ = decltype(tile_to_shape( + SmemLayoutAtomdQ{}, + make_shape(Int{}, Int{}))); + using SmemCopyAtomdQ = Copy_Atom, elem_type>; + + // Shared memory size calculations + static constexpr int kSmemQdOSize = size(SmemLayoutQdO{}) * (No_double_buffer ? 
2 : 3) * sizeof(Element); + static constexpr int kSmemKVSize = size(SmemLayoutKV{}) * 2 * sizeof(Element); + static constexpr int kSmemdSSize = size(SmemLayoutPdS{}) * sizeof(Element); + static constexpr int kSmemPSize = size(SmemLayoutPdS{}) * sizeof(Element); + static constexpr int kSmemdQSize = size(SmemLayoutdQ{}) * sizeof(Element); + + static constexpr int kSmemSize = kSmemQdOSize + + (!Is_V_in_regs + ? kSmemKVSize + kSmemdSSize + std::max(kSmemPSize, kSmemdQSize) + : std::max(kSmemKVSize, kSmemKVSize / 2 + kSmemdSSize + std::max(kSmemPSize, kSmemdQSize))); + static constexpr int kSmemSize1colblock = kSmemQdOSize + + (!Is_V_in_regs + ? kSmemKVSize + kSmemdSSize + kSmemPSize + : std::max(kSmemKVSize, kSmemKVSize / 2 + kSmemdSSize + kSmemPSize)); + + // Global memory access configuration + static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); + static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad"); + + static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad; + static_assert(kNThreads % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow"); + + using GmemLayoutAtom = Layout, Int>, + Stride, _1>>; + + // Global memory copy structures + using Gmem_copy_struct = std::conditional_t< + Has_cp_async, + SM80_CP_ASYNC_CACHEGLOBAL, + AutoVectorizingCopyWithAssumedAlignment<128> + >; + + using GmemTiledCopyQKV = decltype( + make_tiled_copy(Copy_Atom{}, + GmemLayoutAtom{}, + Layout>{})); // Val layout, 8 vals per read + + using GmemTiledCopydO = decltype( + make_tiled_copy(Copy_Atom, elem_type>{}, + GmemLayoutAtom{}, + Layout>{})); // Val layout, 8 vals per store + + using GmemTiledCopydKV = decltype( + make_tiled_copy(Copy_Atom, elem_type>{}, + GmemLayoutAtom{}, + Layout>{})); // Val layout, 8 vals per store + + using GmemTiledCopydQ = decltype( + make_tiled_copy(Copy_Atom, elem_type>{}, + GmemLayoutAtom{}, + Layout>{})); // Val layout, 8 vals per store + + using GmemLayoutAtomdQaccum = std::conditional_t< + kBlockKSmem == 32, + Layout, // Thread layout, 8 threads per row + Stride< _8, _1>>, + Layout, // Thread layout, 16 threads per row + Stride< _16, _1>> + >; + + using GmemTiledCopydQaccum = decltype( + make_tiled_copy(Copy_Atom, ElementAccum>{}, + GmemLayoutAtomdQaccum{}, + Layout>{})); // Val layout, 4 vals per store + + using GmemTiledCopydQaccumAtomicAdd = decltype( + make_tiled_copy(Copy_Atom, ElementAccum>{}, + Layout, // Thread layout, 8 threads per row + Stride<_32, _1>>{}, + Layout>{})); // Val layout, 1 val per store + + // Dynamic mask related definitions for backward pass + // Note: These are primarily handled in forward pass but kept for consistency + using SmemLayoutAtomZeroHold = decltype( + composition(Swizzle{}, + Layout>, + Stride, _1>>{})); + + using SmemLayoutZeroHold = decltype(tile_to_shape( + SmemLayoutAtomZeroHold{}, + Shape, Int>{})); + + static constexpr int kSmemZeroHoldSize = size(SmemLayoutZeroHold{}) * sizeof(Element); + + // Zero hold global memory operations + using GmemLayoutAtomZeroHold = GmemLayoutAtom; + using GmemTiledCopyZeroHold = decltype( + make_tiled_copy(Copy_Atom{}, + GmemLayoutAtomZeroHold{}, + Layout>{})); // Val layout, 8 vals per read +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/csrc/src/mask.h b/csrc/src/mask.h index 4508f25..d3ef098 100644 --- a/csrc/src/mask.h +++ b/csrc/src/mask.h @@ -125,13 +125,12 @@ __forceinline__ __device__ void 
apply_dynamic_mask_1rowblock(
     const Element* zero_hold_states,   // Pre-calculated zero_hold states [key_len]
     const Element* causal_mask_ptr,    // Causal mask values [key_len]
     const int key_len,                 // Sequence length for keys
-    const int keep_window_size         // Maximum window size to keep
+    const int keep_window_size,        // Maximum window size to keep
+    float* row_vals,                   // Shared memory buffer for mask values [key_len]
+    float* sort_keys,                  // Shared memory buffer for sorting keys [key_len]
+    int* sort_indices                  // Shared memory buffer for sorting indices [key_len]
 ) {
     static_assert(Layout::rank == 1, "Tensor must be 1D");
-    extern __shared__ float shared_mem[];
-    float* row_vals = shared_mem;                                      // [key_len]
-    float* sort_keys = row_vals + key_len;                             // [key_len]
-    int* sort_indices = reinterpret_cast<int*>(sort_keys + key_len);   // [key_len]
     int tid = threadIdx.x;

     // Load zero_hold and initialize row values
@@ -169,10 +168,14 @@ struct DynamicMask {
         Tensor &tensor,
         const Element* zero_hold_states,
         const Element* causal_mask_ptr,
-        const int key_len
+        const int key_len,
+        float* row_vals,
+        float* sort_keys,
+        int* sort_indices
     ) {
         apply_dynamic_mask_1rowblock(
-            tensor, zero_hold_states, causal_mask_ptr, key_len, keep_window_size
+            tensor, zero_hold_states, causal_mask_ptr, key_len, keep_window_size,
+            row_vals, sort_keys, sort_indices
         );
     }
 };
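
Reviewer note: the new constants in Flash_fwd_kernel_traits size the dynamic-mask scratch space as three per-key arrays (two float, one int) for every query row in the tile, on top of the usual Q/K/V shared memory. The fragment below is a minimal standalone sketch of that arithmetic only; the tile sizes (kBlockM = 64, kBlockN = 32) are illustrative assumptions, not values taken from this patch.

// smem_budget_sketch.cpp -- plain C++, mirrors the buffer-sizing constants
// added to Flash_fwd_kernel_traits. Tile sizes below are assumed for illustration.
#include <cstdio>

constexpr int kBlockM = 64;   // query rows per tile (assumed)
constexpr int kBlockN = 32;   // key columns per tile (assumed)

// Per query row the mask needs row_vals (float), sort_keys (float) and
// sort_indices (int), each with one entry per key in the tile.
constexpr int kMaxKeysPerBlock = kBlockN;
constexpr int kDynamicMaskBufferPerQuery =
    kMaxKeysPerBlock * (2 * sizeof(float) + sizeof(int));
constexpr int kTotalDynamicMaskBuffer = kBlockM * kDynamicMaskBufferPerQuery;

static_assert(kDynamicMaskBufferPerQuery == 384,  "12 bytes per key per query row");
static_assert(kTotalDynamicMaskBuffer == 24576,   "24 KiB of extra smem for a 64x32 tile");

int main() {
    std::printf("per-query buffer: %d B, per-block buffer: %d B\n",
                kDynamicMaskBufferPerQuery, kTotalDynamicMaskBuffer);
}

For these assumed sizes the mask scratch adds 24 KiB per thread block on top of the Q/K/V tiles, which is exactly the difference the traits expose between kSmemSize and kSmemSizeWithMask.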
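
The mask.h change moves the extern __shared__ carve-up out of apply_dynamic_mask_1rowblock and into its callers, which now pass row_vals, sort_keys, and sort_indices explicitly. The kernel below is a minimal, self-contained sketch of that caller-side partitioning under assumed names (partition_mask_smem, smem); it is not code from this patch, and the commented call only restates the argument order introduced by the new signature.

// mask_smem_partition_sketch.cu -- hypothetical caller-side partitioning;
// names and launch configuration are illustrative, not taken from this patch.
#include <cstdio>
#include <cuda_runtime.h>

__global__ void partition_mask_smem(int key_len) {
    extern __shared__ char smem[];   // the caller now owns this region

    // Same layout the function used to set up internally:
    // [key_len] floats of row values, [key_len] floats of sort keys,
    // [key_len] ints of sort indices.
    float* row_vals     = reinterpret_cast<float*>(smem);
    float* sort_keys    = row_vals + key_len;
    int*   sort_indices = reinterpret_cast<int*>(sort_keys + key_len);

    // In the real kernel these pointers become the new trailing arguments:
    // apply_dynamic_mask_1rowblock(tensor, zero_hold_states, causal_mask_ptr,
    //     key_len, keep_window_size, row_vals, sort_keys, sort_indices);
    if (threadIdx.x == 0) {
        row_vals[0] = sort_keys[0] = 0.f;  // touch the buffers so they are kept
        sort_indices[0] = 0;
    }
}

int main() {
    const int key_len = 128;  // assumed key length for the demo
    const size_t smem_bytes = key_len * (2 * sizeof(float) + sizeof(int));
    std::printf("dynamic-mask smem per block: %zu bytes\n", smem_bytes);
    partition_mask_smem<<<1, 32, smem_bytes>>>(key_len);
    return cudaDeviceSynchronize() == cudaSuccess ? 0 : 1;
}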