diff --git a/csrc/src/flash.h b/csrc/src/flash.h index 1facc13..a91689f 100644 --- a/csrc/src/flash.h +++ b/csrc/src/flash.h @@ -15,11 +15,11 @@ namespace FLASH_NAMESPACE { constexpr int TOTAL_DIM = 0; constexpr int H_DIM = 1; constexpr int D_DIM = 2; +typedef int64_t index_t; //////////////////////////////////////////////////////////////////////////////////////////////////// struct QKV_params { - using index_t = int64_t; // The QKV matrices. void *__restrict__ q_ptr; // Query tensor [batch_size, num_heads, query_len, head_dim] void *__restrict__ k_ptr; // Key tensor [batch_size, num_kv_heads, key_len, head_dim] @@ -46,20 +46,12 @@ struct QKV_params { //////////////////////////////////////////////////////////////////////////////////////////////////// struct ZeroHold_params { - using index_t = int64_t; - - void *__restrict__ zero_hold_ptr; // Zero-hold states tensor [batch_size, num_kv_heads, query_len, key_len] + void *__restrict__ zero_hold_ptr; // Zero-hold states tensor [batch_size, num_kv_heads, query_len, key_len] // The stride of the zero-hold states tensor. - index_t zero_hold_batch_stride; // Stride between batches of zero-hold states - index_t zero_hold_head_stride; // Stride between heads of zero-hold states - index_t zero_hold_query_stride; // Stride for the third dimension (query_len) of zero-hold states - // Assuming last dim (key_len) has stride 1 for the zero_hold_states_ptr - - index_t causal_mask_batch_stride; // Stride between batches of causal_mask - index_t causal_mask_head_stride; // Stride for the second dimension (size 1) of causal_mask - index_t causal_mask_query_len_stride; // Stride for the third dimension (query_len) of causal_mask - // Assuming last dim (key_len) has stride 1 for the causal_mask_ptr + index_t zero_hold_batch_stride; // Stride between batches of zero-hold states + index_t zero_hold_head_stride; // Stride between heads of zero-hold states + index_t zero_hold_row_stride; // Stride for the third dimension (key_len) of zero-hold states // The keep window size. int keep_window_size; // Number of tokens to keep in top-k (0 means don't apply top-k) @@ -73,7 +65,6 @@ struct Flash_fwd_params : public QKV_params, public ZeroHold_params { void *k_ptr = nullptr; void *v_ptr = nullptr; void *zero_hold_ptr = nullptr; - void *causal_mask_ptr = nullptr; // Input tensor for the bias void *b_ptr = nullptr; @@ -207,7 +198,7 @@ struct Flash_bwd_params : public Flash_fwd_params { index_t dv_head_stride; index_t dzero_hold_batch_stride; index_t dzero_hold_head_stride; - index_t dzero_hold_query_stride; + index_t dzero_hold_row_stride; // The pointer to the softmax d sum. void *__restrict__ dsoftmax_sum; diff --git a/csrc/src/flash_attention_fwd_kernel.h b/csrc/src/flash_attention_fwd_kernel.h index a668b3c..0a8b9ae 100644 --- a/csrc/src/flash_attention_fwd_kernel.h +++ b/csrc/src/flash_attention_fwd_kernel.h @@ -1,5 +1,5 @@ /****************************************************************************** - * Copyright (c) 2024, Tri Dao. + * Copyright (c) 2025, Jingze Shi and Tri Dao. ******************************************************************************/ #pragma once @@ -203,7 +203,7 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi ); // Dynamic mask related shared memory. Use a running char* pointer for robust allocation. 
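// Reviewer note: a minimal sketch (hypothetical names and sizes, not the kernel's real
// traits) of the running char* scheme used below, and of the precedence fix in the next
// hunk. The byte offset must be added to the char* *after* the cast:
//     reinterpret_cast<char*>(ptr) + n_bytes   // advances by n_bytes bytes (intended)
//     reinterpret_cast<char*>(ptr + n_bytes)   // advances by n_bytes * sizeof(*ptr) bytes (the old bug)
__device__ void carve_dynamic_mask_smem_sketch(char* smem_after_sV) {
    constexpr int kBlockM = 64, kBlockN = 64;   // assumed tile sizes for the sketch
    char* p = smem_after_sV;
    float* mask_values  = reinterpret_cast<float*>(p); p += kBlockM * kBlockN * sizeof(float);
    float* sort_keys    = reinterpret_cast<float*>(p); p += kBlockM * kBlockN * sizeof(float);
    int*   sort_indices = reinterpret_cast<int*>(p);   p += kBlockM * kBlockN * sizeof(int);
    bool*  predicate_k  = reinterpret_cast<bool*>(p);  p += kBlockM * kBlockN * sizeof(bool);
    (void)mask_values; (void)sort_keys; (void)sort_indices; (void)predicate_k;
}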
- char* dynamic_smem_current_ptr = reinterpret_cast(sV.data().get() + size(sV) * sizeof(Element)); + char* dynamic_smem_current_ptr = reinterpret_cast(sV.data().get()) + size(sV) * sizeof(Element); Tensor sZeroHold = make_tensor( make_smem_ptr(reinterpret_cast(dynamic_smem_current_ptr)), // Element type typename Kernel_traits::SmemLayoutZeroHold{} @@ -235,7 +235,7 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi dynamic_smem_current_ptr += Kernel_traits::kSmemNonZeroIndicesSize; Tensor sPredicate = make_tensor( - make_smem_ptr(reinterpret_cast(dynamic_smem_current_ptr)), // Element type + make_smem_ptr(reinterpret_cast(dynamic_smem_current_ptr)), // bool type typename Kernel_traits::SmemLayoutPredicate{} ); @@ -266,7 +266,9 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi // Copy Atom retiling auto smem_tiled_copy_Q = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma); auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(tidx); + // if (cute::thread0()) {smem_thr_copy_Q.print_all();} Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ); + // if (cute::thread0()) {print(tSsQ.layout()); printf("\n");} auto smem_tiled_copy_K = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtom{}, tiled_mma); auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(tidx); @@ -281,12 +283,6 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi // auto smem_thr_copy_ZeroHold = smem_tiled_copy_ZeroHold.get_thread_slice(tidx); // Tensor tSsZeroHold = smem_thr_copy_ZeroHold.partition_S(sZeroHold); - // For sCausalMask -> registers (if needed) - // using CausalMaskSmemCopyAtom = typename Kernel_traits::SmemCopyAtom; // Assuming Element type - // auto smem_tiled_copy_CausalMask_smem = make_tiled_copy_B(CausalMaskSmemCopyAtom{}, tiled_mma); - // auto smem_thr_copy_CausalMask_smem = smem_tiled_copy_CausalMask_smem.get_thread_slice(tidx); - // Tensor tSsCausalMask = has_causal_mask ? smem_thr_copy_CausalMask_smem.partition_S(sCausalMask) : empty_smem_tensor_for_copy_D; - // PREDICATES Tensor cQ = make_identity_tensor(make_shape(size<0>(sQ), size<1>(sQ))); // (BLK_M,BLK_K) -> (blk_m,blk_k) Tensor cKV = make_identity_tensor(make_shape(size<0>(sK), size<1>(sK))); // (BLK_N,BLK_K) -> (blk_n,blk_k) @@ -319,7 +315,7 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi // Prologue // Init dynamic mask processor - DynamicMask dynamic_mask(params.keep_window_size); + DynamicMask dynamic_mask(params.keep_window_size); // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs FLASH_NAMESPACE::copy( gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ, @@ -368,442 +364,299 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi FLASH_NAMESPACE::Softmax<2 * size<1>(acc_o)> softmax; - // 处理需要掩码的块(通常是最后几个块) + // For performance reason, we separate out two kinds of iterations: + // those that need masking on S, and those that don't. + // We need masking on S for the very last block when K and V has length not multiple of kBlockN. + // We also need masking on S if it's causal, for the last ceil_div(kBlockM, kBlockN) blocks. + // We will have at least 1 "masking" iteration. + + // If not even_N, then seqlen_k might end in the middle of a block. In that case we need to + // mask 2 blocks (e.g. when kBlockM == kBlockN), not just 1. constexpr int n_masking_steps = (!Is_causal) ? 1 : ((Is_even_MN && Is_causal) ? 
cute::ceil_div(kBlockM, kBlockN) : cute::ceil_div(kBlockM, kBlockN) + 1); #pragma unroll for (int masking_step = 0; masking_step < n_masking_steps; ++masking_step, --n_block) { - // 等待K数据 + Tensor acc_s = partition_fragment_C(tiled_mma, Shape, Int>{}); // (MMA=4, MMA_M, MMA_N) + clear(acc_s); FLASH_NAMESPACE::cp_async_wait<0>(); __syncthreads(); - // 加载V块到共享内存 - FLASH_NAMESPACE::copy( - gmem_tiled_copy_QKV, tVgV(_, _, _, n_block), tVsV, tKVcKV, tKVpKV, - binfo.actual_seqlen_k - n_block * kBlockN - ); + // Advance gV + if (masking_step > 0) { + FLASH_NAMESPACE::copy(gmem_tiled_copy_QKV, tVgV(_, _, _, n_block), tVsV, tKVcKV, tKVpKV); + } else { + // Clear the smem tiles to account for predicated off loads + FLASH_NAMESPACE::copy( + gmem_tiled_copy_QKV, tVgV(_, _, _, n_block), tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN + ); + } cute::cp_async_fence(); - - // 计算块中实际键的数量 + + // Calculating the actual number of keys in the block const int block_key_len = min(kBlockN, binfo.actual_seqlen_k - n_block * kBlockN); - - // 为当前块内的每个查询行处理动态掩码 - const int queries_in_block = min(kBlockM, binfo.actual_seqlen_q - m_block * kBlockM); - for (int m_idx = 0; m_idx < queries_in_block; ++m_idx) { - // 获取当前查询的全局索引 + + // Process dynamic mask for each query row in the current block + for (int m_idx = 0; m_idx < kBlockM; ++m_idx) { + // Get the global index of the current query const int query_idx = m_block * kBlockM + m_idx; + if (query_idx >= binfo.actual_seqlen_q) { + continue; + } - // 获取当前查询行的动态掩码内存 - Tensor mask_values = sDynamicMaskValues(m_idx, _); - Tensor sort_keys = sDynamicMaskSortKeys(m_idx, _); - Tensor sort_indices = sDynamicMaskSortIndices(m_idx, _); - Tensor nonzero_indices = sNonZeroIndices(m_idx, _); - Tensor predicate_k = sPredicate(m_idx, _); - - // 获取当前查询行的zero_hold和causal_mask - const Element* zero_hold_row = &sZeroHold[m_idx][0]; - const Element* causal_mask_row = params.causal_mask_ptr != nullptr ? 
- &sCausalMask[m_idx][0] : nullptr; - - // 使用DynamicMask结构体来应用掩码 - dynamic_mask.apply_mask_1rowblock( - mask_values, + // Apply the dynamic mask to the current query row + auto mask_values_row = sDynamicMaskValues(m_idx, _); // float + auto zero_hold_row = sZeroHold(m_idx, _); // half/bfloat16 + auto sort_keys_row = sDynamicMaskSortKeys(m_idx, _); // float + auto sort_indices_row = sDynamicMaskSortIndices(m_idx, _); // int + dynamic_mask.template apply_mask_1rowblock< + typename decltype(mask_values_row)::engine_type, typename decltype(mask_values_row)::layout_type, + typename decltype(zero_hold_row)::engine_type, typename decltype(zero_hold_row)::layout_type, + typename decltype(sort_keys_row)::engine_type, typename decltype(sort_keys_row)::layout_type, + typename decltype(sort_indices_row)::engine_type, typename decltype(sort_indices_row)::layout_type, + Element, Is_causal + >( + mask_values_row, zero_hold_row, - causal_mask_row, + query_idx, block_key_len, - mask_values.data().get(), - sort_keys.data().get(), - reinterpret_cast(sort_indices.data().get()), + mask_values_row, + sort_keys_row, + sort_indices_row ); __syncthreads(); - - // 初始化键的活性状态谓词 - if (tidx == 0) { - // 只需一个线程来初始化整个谓词数组 - #pragma unroll - for (int k_idx = 0; k_idx < kBlockN; ++k_idx) { - predicate_k(k_idx) = false; - } - } - __syncthreads(); - - // 找出非零位置 - int nonzero_count = 0; - // 每个线程负责处理部分键位置 + // Find the non-zero positions + auto predicate_k_row = sPredicate(m_idx, _); // bool for (int k_idx = tidx; k_idx < block_key_len; k_idx += blockDim.x) { - if (mask_values(k_idx) != 0.0f) { - // 使用原子操作安全地增加计数并获取索引位置 - int idx = atomicAdd(&nonzero_count, 1); - if (idx < Kernel_traits::kMaxKeysPerBlock) { - nonzero_indices(idx) = k_idx; - // 标记该键为活跃状态 - predicate_k(k_idx) = true; - } - } + predicate_k_row(k_idx) = (mask_values_row(k_idx) != 0.0f); } __syncthreads(); + } - // 如果没有非零键,跳过当前查询行 - if (nonzero_count == 0) { - continue; - } - - // 处理多查询头情况 (MQA/GQA) - const int num_queries_per_kv = params.h_h_k_ratio; - - // 对于每个查询组内的查询头 - for (int q_group_idx = 0; q_group_idx < num_queries_per_kv; q_group_idx++) { - // 创建累加器用于注意力分数 - Tensor acc_s = partition_fragment_C(tiled_mma, Shape, Int>{}); - clear(acc_s); - - // 执行稀疏矩阵乘法 - FLASH_NAMESPACE::sparse_gemm_rs( - acc_s(_, m_idx, _), // 当前查询行的累加器 - tSrQ(_, m_idx, _), // 当前查询 - tSrK, // 键值 - tSsK, // 共享内存中的键值 - tiled_mma, - smem_tiled_copy_K, - smem_thr_copy_K, - predicate_k // 活跃键的谓词 - ); + // 执行稀疏矩阵乘法 + FLASH_NAMESPACE::sparse_gemm( + acc_s, + tSrQ, + tSrK, tSsQ, tSsK, + tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K, + smem_thr_copy_Q, smem_thr_copy_K, + sPredicate // 活跃键的谓词 + ); - // 应用掩码添加(zero_hold状态既是掩码也是要添加到注意力分数的值) - for (int s_idx = 0; s_idx < size(acc_s); ++s_idx) { - const int k_idx = get<2>(thr_mma.get_slice_idx(s_idx, acc_s)); - if (k_idx < block_key_len && predicate_k(k_idx)) { - acc_s(s_idx) += static_cast(mask_values[k_idx]); + // 应用掩码添加(zero_hold状态既是掩码也是要添加到注意力分数的值) + for (int mma = 0; mma < size<0>(acc_s); ++mma) { + for (int mi = 0; mi < size<1>(acc_s); ++mi) { + for (int ki = 0; ki < size<2>(acc_s); ++ki) { + int m_idx = mi; // 或者根据你的tile映射 + int k_idx = ki; + if (m_idx < kBlockM && k_idx < block_key_len) { + auto mask_values_row = sDynamicMaskValues(m_idx, _); + auto predicate_k_row = sPredicate(m_idx, _); + if (predicate_k_row(k_idx)) { + acc_s(mma, mi, ki) += static_cast(mask_values_row(k_idx)); + } } } - - // 执行softmax并更新输出累加器 - if (q_group_idx == 0 && n_block == n_block_max - 1) { - softmax.template softmax_rescale_o( - acc_s, acc_o, 
params.scale_softmax_log2); - } else { - softmax.template softmax_rescale_o( - acc_s, acc_o, params.scale_softmax_log2); - } - - // 将浮点分数转换为Element类型进行输出计算 - Tensor rP = FLASH_NAMESPACE::convert_type(acc_s); - Tensor tOrP = make_tensor( - rP.data(), - FLASH_NAMESPACE::convert_layout_acc_Aregs(rP.layout()) - ); - - // 计算该查询头的输出 - FLASH_NAMESPACE::sparse_gemm_rs( - acc_o, // 输出累加器 - tOrP, // 注意力权重 - tOrVt, // 值向量 - tOsVt, // 共享内存中的值向量 - tiled_mma, - smem_tiled_copy_V, - smem_thr_copy_V, - predicate_k // 应用相同的谓词来进行稀疏V矩阵乘法 - ); } - __syncthreads(); } - - // 等待V数据 + FLASH_NAMESPACE::cp_async_wait<0>(); __syncthreads(); - - // 准备加载下一个K块(如果有) if (n_block > n_block_min) { - FLASH_NAMESPACE::copy( - gmem_tiled_copy_QKV, tKgK(_, _, _, n_block-1), tKsK, tKVcKV, tKVpKV, - binfo.actual_seqlen_k - (n_block-1) * kBlockN - ); - cute::cp_async_fence(); - - // 加载下一个ZeroHold块到共享内存 - FLASH_NAMESPACE::copy( - gmem_tiled_copy_ZeroHold, tZeroHoldgZeroHold(_, _, _, n_block-1), tZeroHoldsZeroHold, tZeroHoldcZeroHold, tZeroHoldpZeroHold, - binfo.actual_seqlen_k - (n_block-1) * kBlockN - ); + FLASH_NAMESPACE::copy(gmem_tiled_copy_QKV, tKgK(_, _, _, n_block - 1), tKsK, tKVcKV, tKVpKV); + // This cp_async_fence needs to be in the if block, otherwise the synchronization + // isn't right and we get race conditions. cute::cp_async_fence(); - - // 加载下一个CausalMask块到共享内存(如果有) - if (params.causal_mask_ptr != nullptr) { - FLASH_NAMESPACE::copy( - gmem_tiled_copy_CausalMask, tCausalMaskgCausalMask(_, _, _, n_block-1), tCausalMasksCausalMask, tCausalMaskcCausalMask, tCausalMaskpCausalMask, - binfo.actual_seqlen_k - (n_block-1) * kBlockN - ); - cute::cp_async_fence(); - } } + + // TODO: when we have key_padding_mask we'll need to Check_inf + masking_step == 0 + ? softmax.template softmax_rescale_o(acc_s, acc_o, params.scale_softmax_log2) + : softmax.template softmax_rescale_o(acc_s, acc_o, params.scale_softmax_log2); - // 提前退出检查 + // Convert acc_s from fp32 to fp16/bf16 + Tensor rP = FLASH_NAMESPACE::convert_type(acc_s); + if (Return_softmax) { + tSgS.data() = tSgS.data() + (-kBlockN); + } + + // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2) + // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8. 
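// Reviewer sketch (scalar, one query row, hypothetical names) of what the
// softmax_rescale_o call above does on every K block: maintain a running raw-score max
// and unnormalized sum, and rescale the already-accumulated output whenever the max grows.
// scale_log2 is softmax_scale * log2(e), matching params.scale_softmax_log2.
__device__ void online_softmax_rescale_row_sketch(
    float* scores, int n,            // this block's scores for the row (mask bias already added)
    float& row_max, float& row_sum,  // running statistics, row_max starts at -INFINITY
    float* acc_o_row, int d,         // running output accumulator for the row
    float scale_log2
) {
    float block_max = -INFINITY;
    for (int j = 0; j < n; ++j) { block_max = fmaxf(block_max, scores[j]); }
    const float new_max = fmaxf(row_max, block_max);
    if (new_max == -INFINITY) { return; }        // fully masked block; the real code guards this via Check_inf
    // Rescale prior contributions by exp2((old_max - new_max) * scale_log2).
    const float correction = exp2f((row_max - new_max) * scale_log2);
    row_sum *= correction;
    for (int k = 0; k < d; ++k) { acc_o_row[k] *= correction; }
    // Exponentiate this block and accumulate the (still unnormalized) sum.
    for (int j = 0; j < n; ++j) {
        const float p = exp2f((scores[j] - new_max) * scale_log2);
        row_sum += p;
        scores[j] = p;               // reused as P for the subsequent P @ V product
    }
    row_max = new_max;
}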
+ Tensor tOrP = make_tensor(rP.data(), FLASH_NAMESPACE::convert_layout_acc_Aregs(rP.layout())); + // if (cute::thread0()) { print(tOrP); } + FLASH_NAMESPACE::sparse_gemm_rs( + acc_o, + tOrP, tOrVt, tOsVt, + tiled_mma, smem_tiled_copy_V, smem_thr_copy_V, + sPredicate // 应用相同的谓词来进行稀疏V矩阵乘法 + ); + // if (cute::thread0()) { print(scores); } + + // This check is at the end of the loop since we always have at least 1 iteration if (n_masking_steps > 1 && n_block <= n_block_min) { break; } } - // 处理不需要掩码的块 + // These are the iterations where we don't need masking on S for (; n_block >= n_block_min; --n_block) { - // 等待K数据 + Tensor acc_s = partition_fragment_C(tiled_mma, Shape, Int>{}); // (MMA=4, MMA_M, MMA_N) + clear(acc_s); FLASH_NAMESPACE::cp_async_wait<0>(); __syncthreads(); - - // 加载V块到共享内存 - FLASH_NAMESPACE::copy( - gmem_tiled_copy_QKV, tVgV(_, _, _, n_block), tVsV, tKVcKV, tKVpKV - ); + FLASH_NAMESPACE::copy(gmem_tiled_copy_QKV, tVgV(_, _, _, n_block), tVsV, tKVcKV, tKVpKV); cute::cp_async_fence(); - - // 计算块中实际键的数量 + + // calculate the actual number of keys in the block const int block_key_len = min(kBlockN, binfo.actual_seqlen_k - n_block * kBlockN); - const int queries_in_block = min(kBlockM, binfo.actual_seqlen_q - m_block * kBlockM); - - // 为当前块内的每个查询行处理动态掩码 - for (int m_idx = 0; m_idx < queries_in_block; ++m_idx) { - // 获取当前查询的零状态行 - Tensor mask_values = sDynamicMaskValues(m_idx, _); - Tensor sort_keys = sDynamicMaskSortKeys(m_idx, _); - Tensor sort_indices = sDynamicMaskSortIndices(m_idx, _); - Tensor nonzero_indices = sNonZeroIndices(m_idx, _); - Tensor predicate_k = sPredicate(m_idx, _); - - // 获取当前查询行的zero_hold - const Element* zero_hold_row = &sZeroHold[m_idx][0]; - - // 使用DynamicMask结构体来应用掩码,没有因果掩码 - dynamic_mask.apply_mask_1rowblock( - mask_values, + + // Process dynamic mask for each query row in the current block + for (int m_idx = 0; m_idx < kBlockM; ++m_idx) { + // Get the global index of the current query + const int query_idx = m_block * kBlockM + m_idx; + if (query_idx >= binfo.actual_seqlen_q) { + continue; + } + + // Apply the dynamic mask to the current query row + auto mask_values_row = sDynamicMaskValues(m_idx, _); // float + auto zero_hold_row = sZeroHold(m_idx, _); // half/bfloat16 + auto sort_keys_row = sDynamicMaskSortKeys(m_idx, _); // float + auto sort_indices_row = sDynamicMaskSortIndices(m_idx, _); // int + dynamic_mask.template apply_mask_1rowblock< + typename decltype(mask_values_row)::engine_type, typename decltype(mask_values_row)::layout_type, + typename decltype(zero_hold_row)::engine_type, typename decltype(zero_hold_row)::layout_type, + typename decltype(sort_keys_row)::engine_type, typename decltype(sort_keys_row)::layout_type, + typename decltype(sort_indices_row)::engine_type, typename decltype(sort_indices_row)::layout_type, + Element, /*Is_causal=*/false + >( + mask_values_row, zero_hold_row, - nullptr, // 无因果掩码 + query_idx, block_key_len, - mask_values.data().get(), - sort_keys.data().get(), - reinterpret_cast(sort_indices.data().get()) + mask_values_row, + sort_keys_row, + sort_indices_row ); __syncthreads(); - - // 初始化键的活性状态谓词 - if (tidx == 0) { - // 只需一个线程来初始化整个谓词数组 - #pragma unroll - for (int k_idx = 0; k_idx < kBlockN; ++k_idx) { - predicate_k(k_idx) = false; - } - } - __syncthreads(); - - // 找出非零位置 - int nonzero_count = 0; - // 每个线程负责处理部分键位置 + // Find the non-zero positions + auto predicate_k_row = sPredicate(m_idx, _); // bool for (int k_idx = tidx; k_idx < block_key_len; k_idx += blockDim.x) { - if (mask_values(k_idx) != 0.0f) { - // 
使用原子操作安全地增加计数并获取索引位置 - int idx = atomicAdd(&nonzero_count, 1); - if (idx < Kernel_traits::kMaxKeysPerBlock) { - nonzero_indices(idx) = k_idx; - // 标记该键为活跃状态 - predicate_k(k_idx) = true; - } - } - } - __syncthreads(); - - // 如果没有非零键,跳过当前查询行 - if (nonzero_count == 0) { - continue; - } - - // 处理多查询头情况 (MQA/GQA) - const int num_queries_per_kv = params.h_h_k_ratio; - - // 对于每个查询组内的查询头 - for (int q_group_idx = 0; q_group_idx < num_queries_per_kv; q_group_idx++) { - // 创建累加器用于注意力分数 - Tensor acc_s = partition_fragment_C(tiled_mma, Shape, Int>{}); - clear(acc_s); - - // 执行稀疏矩阵乘法 - FLASH_NAMESPACE::sparse_gemm_rs( - acc_s(_, m_idx, _), // 当前查询行的累加器 - tSrQ(_, m_idx, _), // 当前查询 - tSrK, // 键值 - tSsK, // 共享内存中的键值 - tiled_mma, - smem_tiled_copy_K, - smem_thr_copy_K, - predicate_k // 活跃键的谓词 - ); - - // 应用掩码添加 - for (int s_idx = 0; s_idx < size(acc_s); ++s_idx) { - const int k_idx = get<2>(thr_mma.get_slice_idx(s_idx, acc_s)); - if (k_idx < block_key_len && predicate_k(k_idx)) { - acc_s(s_idx) += static_cast(mask_values[k_idx]); - } - } - - // 执行softmax并更新输出累加器 - softmax.template softmax_rescale_o( - acc_s, acc_o, params.scale_softmax_log2); - - // 将浮点分数转换为Element类型进行输出计算 - Tensor rP = FLASH_NAMESPACE::convert_type(acc_s); - Tensor tOrP = make_tensor( - rP.data(), - FLASH_NAMESPACE::convert_layout_acc_Aregs(rP.layout()) - ); - - // 计算该查询头的输出 - FLASH_NAMESPACE::sparse_gemm_rs( - acc_o, // 输出累加器 - tOrP, // 注意力权重 - tOrVt, // 值向量 - tOsVt, // 共享内存中的值向量 - tiled_mma, - smem_tiled_copy_V, - smem_thr_copy_V, - predicate_k // 应用相同的谓词来进行稀疏V矩阵乘法 - ); + predicate_k_row(k_idx) = (mask_values_row(k_idx) != 0.0f); } __syncthreads(); } - - // 等待V数据 + + FLASH_NAMESPACE::sparse_gemm( + acc_s, + tSrQ, + tSrK, tSsQ, tSsK, + tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K, + smem_thr_copy_Q, smem_thr_copy_K, + sPredicate // 活跃键的谓词 + ); + FLASH_NAMESPACE::cp_async_wait<0>(); __syncthreads(); - if (n_block > n_block_min) { - // 准备加载下一个K块(如果有) - FLASH_NAMESPACE::copy( - gmem_tiled_copy_QKV, tKgK(_, _, _, n_block-1), tKsK, tKVcKV, tKVpKV, - binfo.actual_seqlen_k - (n_block-1) * kBlockN - ); - cute::cp_async_fence(); - - // 加载下一个ZeroHold块到共享内存 - FLASH_NAMESPACE::copy( - gmem_tiled_copy_ZeroHold, tZeroHoldgZeroHold(_, _, _, n_block-1), tZeroHoldsZeroHold, tZeroHoldcZeroHold, tZeroHoldpZeroHold, - binfo.actual_seqlen_k - (n_block-1) * kBlockN - ); + FLASH_NAMESPACE::copy(gmem_tiled_copy_QKV, tKgK(_, _, _, n_block - 1), tKsK, tKVcKV, tKVpKV); + // This cp_async_fence needs to be in the if block, otherwise the synchronization + // isn't right and we get race conditions. cute::cp_async_fence(); - - // 加载下一个CausalMask块到共享内存(如果有) - if (params.causal_mask_ptr != nullptr) { - FLASH_NAMESPACE::copy( - gmem_tiled_copy_CausalMask, tCausalMaskgCausalMask(_, _, _, n_block-1), tCausalMasksCausalMask, tCausalMaskcCausalMask, tCausalMaskpCausalMask, - binfo.actual_seqlen_k - (n_block-1) * kBlockN - ); - cute::cp_async_fence(); - } } + + softmax.template softmax_rescale_o(acc_s, acc_o, params.scale_softmax_log2); + + // Convert acc_s from fp32 to fp16/bf16 + Tensor rP = FLASH_NAMESPACE::convert_type(acc_s); + if (Return_softmax) { + tSgS.data() = tSgS.data() + (-kBlockN); + } + + // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2) + // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8. 
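// Reviewer sketch (scalar, hypothetical names) of the predicated P @ V product that the
// sparse_gemm_rs call below performs: V rows whose key was zeroed out by the dynamic
// top-k mask are skipped entirely rather than multiplied by zero.
__device__ void sparse_pv_row_sketch(
    const float* p_row,           // attention weights for one query row, length kBlockN
    const float* v,               // V tile, kBlockN x d, row-major
    const bool*  predicate_k,     // true where the key survived the mask (sPredicate analogue)
    float* acc_o_row,
    int kBlockN, int d
) {
    for (int j = 0; j < kBlockN; ++j) {
        if (!predicate_k[j]) { continue; }       // skip masked-out keys
        for (int k = 0; k < d; ++k) { acc_o_row[k] += p_row[j] * v[j * d + k]; }
    }
}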
+ Tensor tOrP = make_tensor(rP.data(), FLASH_NAMESPACE::convert_layout_acc_Aregs(rP.layout())); + + FLASH_NAMESPACE::sparse_gemm_rs( + acc_o, + tOrP, tOrVt, tOsVt, + tiled_mma, smem_tiled_copy_V, smem_thr_copy_V, + sPredicate // 应用相同的谓词来进行稀疏V矩阵乘法 + ); + } + + // Epilogue // 后处理和输出归一化 - Tensor lse = softmax.template normalize_softmax_lse( + Tensor lse = softmax.template normalize_softmax_lse( acc_o, params.scale_softmax, 1.0f ); - - // 转换acc_o到Element类型 + + // Convert acc_o from fp32 to fp16/bf16 Tensor rO = FLASH_NAMESPACE::convert_type(acc_o); - - // 准备共享内存用于输出 - Tensor sO = make_tensor( - sQ.data(), - typename Kernel_traits::SmemLayoutO{} - ); - - // 设置从累加器到共享内存的拷贝 - auto smem_tiled_copy_O = make_tiled_copy_C( - typename Kernel_traits::SmemCopyAtomO{}, - tiled_mma - ); + Tensor sO = make_tensor(sQ.data(), typename Kernel_traits::SmemLayoutO{}); // (SMEM_M,SMEM_N) + // Partition sO to match the accumulator partitioning + auto smem_tiled_copy_O = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma); auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice(tidx); - - Tensor taccOrO = smem_thr_copy_O.retile_S(rO); - Tensor taccOsO = smem_thr_copy_O.partition_D(sO); - - // 确保共享内存区域可以安全使用 - if (Kernel_traits::Share_Q_K_smem) { - __syncthreads(); - } - - // 拷贝输出到共享内存 + Tensor taccOrO = smem_thr_copy_O.retile_S(rO); // ((Atom,AtomNum), MMA_M, MMA_N) + Tensor taccOsO = smem_thr_copy_O.partition_D(sO); // ((Atom,AtomNum),PIPE_M,PIPE_N) + + // sO has the same size as sQ, so we don't need to sync here. + if (Kernel_traits::Share_Q_K_smem) { __syncthreads(); } + cute::copy(smem_tiled_copy_O, taccOrO, taccOsO); - - // 设置全局内存输出张量 + Tensor mO = make_tensor( - make_gmem_ptr(reinterpret_cast(params.o_ptr) - + binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)), + make_gmem_ptr(reinterpret_cast(params.o_ptr) + binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)), make_shape(binfo.actual_seqlen_q, params.h, params.d), make_stride(params.o_row_stride, params.o_head_stride, _1{}) ); - Tensor gO = local_tile( - mO(_, bidh, _), + mO(_, bidh, _), Shape, Int>{}, make_coord(m_block, 0) - ); - - Tensor gLSE = get_lse_tile( - params, bidb, bidh, m_block, binfo - ); - - // 设置从共享内存到全局内存的拷贝 + ); // (kBlockM, kHeadDim) + Tensor gLSE = get_lse_tile(params, bidb, bidh, m_block, binfo); + typename Kernel_traits::GmemTiledCopyO gmem_tiled_copy_O; auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx); - - Tensor tOsO = gmem_thr_copy_O.partition_S(sO); + Tensor tOsO = gmem_thr_copy_O.partition_S(sO); // ((Atom,AtomNum),ATOM_M,ATOM_N) Tensor tOgO = gmem_thr_copy_O.partition_D(gO); - + __syncthreads(); - - // 从共享内存拷贝到寄存器,准备写入全局内存 + Tensor tOrO = make_tensor(shape(tOgO)); cute::copy(gmem_tiled_copy_O, tOsO, tOrO); - - // 设置输出的谓词 - Tensor cO = make_identity_tensor(make_shape(size<0>(sO), size<1>(sO))); - Tensor tOcO = gmem_thr_copy_O.partition_D(cO); - Tensor tOpO = make_tensor(make_shape(size<2>(tOgO))); - - if (!Is_even_K) { - #pragma unroll - for (int k = 0; k < size(tOpO); ++k) { - tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; - } - } - - // 写入输出到全局内存 - FLASH_NAMESPACE::copy( - gmem_tiled_copy_O, - tOrO, - tOgO, - tOcO, - tOpO, - binfo.actual_seqlen_q - m_block * kBlockM - ); - - // 写入LSE值到全局内存 - Tensor caccO = make_identity_tensor(Shape, Int>{}); - Tensor taccOcO = thr_mma.partition_C(caccO); + + Tensor caccO = make_identity_tensor(Shape, Int>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k) + Tensor taccOcO = thr_mma.partition_C(caccO); // (MMA,MMA_M,MMA_K) 
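// Reviewer sketch (scalar, hypothetical names) of what normalize_softmax_lse above does per
// row before the write-out below: divide the accumulated output by the softmax denominator
// and record the log-sum-exp needed by the backward pass. row_max/row_sum are the raw-score
// running statistics from the rescale sketch earlier.
__device__ void finalize_output_row_sketch(
    float* acc_o_row, int d,
    float row_max, float row_sum,
    float softmax_scale,
    float& lse_out
) {
    const bool empty = (row_sum == 0.f || row_sum != row_sum);   // zero or NaN denominator
    const float inv_sum = empty ? 1.f : 1.f / row_sum;
    lse_out = empty ? INFINITY : row_max * softmax_scale + __logf(row_sum);
    for (int k = 0; k < d; ++k) { acc_o_row[k] *= inv_sum; }
}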
static_assert(decltype(size<0>(taccOcO))::value == 4); - - // 将张量转换为(2,2)形式,然后只获取行索引 + // Convert to ((2, 2), MMA_M, MMA_K) then take only the row indices. Tensor taccOcO_row = logical_divide(taccOcO, Shape<_2>{})(make_coord(0, _), _, 0); - CUTE_STATIC_ASSERT_V(size(lse) == size(taccOcO_row)); - - // 只有第一个线程写入LSE值 + CUTE_STATIC_ASSERT_V(size(lse) == size(taccOcO_row)); // MMA_M if (get<1>(taccOcO_row(0)) == 0) { #pragma unroll for (int mi = 0; mi < size(lse); ++mi) { - if (m_block * kBlockM + get<0>(taccOcO_row(mi)) < binfo.actual_seqlen_q) { - gLSE(get<0>(taccOcO_row(mi))) = lse(mi); - } + const int row = get<0>(taccOcO_row(mi)); + if (row < binfo.actual_seqlen_q - m_block * kBlockM) { gLSE(row) = lse(mi); } } } + + // Construct identity layout for sO + Tensor cO = make_identity_tensor(make_shape(size<0>(sO), size<1>(sO))); // (BLK_M,BLK_K) -> (blk_m,blk_k) + // Repeat the partitioning with identity layouts + Tensor tOcO = gmem_thr_copy_O.partition_D(cO); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) + Tensor tOpO = make_tensor(make_shape(size<2>(tOgO))); + if (!Is_even_K) { + #pragma unroll + for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; } + } + // Clear_OOB_K must be false since we don't want to write zeros to gmem + FLASH_NAMESPACE::copy( + gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM + ); } template diff --git a/csrc/src/flash_fwd_launch_template.h b/csrc/src/flash_fwd_launch_template.h index 3b5b215..133a839 100644 --- a/csrc/src/flash_fwd_launch_template.h +++ b/csrc/src/flash_fwd_launch_template.h @@ -38,19 +38,6 @@ DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_kernel, bool Is_causal, bool Is_even_MN, b #endif } -DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_splitkv_kernel, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Split, bool Append_KV) { - #if defined(ARCH_SUPPORTS_FLASH) - FLASH_NAMESPACE::compute_attn_splitkv(params); - #else - FLASH_UNSUPPORTED_ARCH - #endif -} - -DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_splitkv_combine_kernel, int kBlockM, int Log_max_splits, bool Is_even_K) { - static_assert(Log_max_splits >= 1); - FLASH_NAMESPACE::combine_attn_seqk_parallel(params); -} - template void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { const size_t smem_size = Kernel_traits::kSmemSizeWithMask; @@ -84,71 +71,6 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { }); } -template -void run_flash_splitkv_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { - static_assert(!Kernel_traits::Is_Q_in_regs, "SplitKV implementation does not support Is_Q_in_regs"); - static_assert(!Kernel_traits::Share_Q_K_smem, "SplitKV implementation does not support Share_Q_K_smem"); - - const size_t smem_size = Kernel_traits::kSmemSizeWithMask; - - const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM; - dim3 grid(num_m_block, params.num_splits > 1 ? params.num_splits : params.b, params.num_splits > 1 ? params.b * params.h : params.h); - const bool is_even_MN = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_k % Kernel_traits::kBlockN == 0 && params.seqlen_q % Kernel_traits::kBlockM == 0; - const bool is_even_K = params.d == Kernel_traits::kHeadDim; - - BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] { - EVENK_SWITCH(is_even_K, IsEvenKConst, [&] { - BOOL_SWITCH(params.num_splits > 1, Split, [&] { - BOOL_SWITCH(params.knew_ptr != nullptr, Append_KV, [&] { - // If Append_KV, then we must have seqlen_offsets, which means cu_seqlens_k != nullptr. 
- // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates. - auto kernel = &flash_fwd_splitkv_kernel; - if (smem_size >= 48 * 1024) { - C10_CUDA_CHECK(cudaFuncSetAttribute( - kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - } - kernel<<>>(params); - C10_CUDA_KERNEL_LAUNCH_CHECK(); - }); - }); - }); - }); - if (params.num_splits > 1) { - // We want kBlockM to be as small as possible for more parallelism. - // With 128 threads we can load 512 elements at a time, so if headdim is divisible by 128, kBlockM = 4. - // If headdim is divisible by 64, then we set kBlockM = 8, etc. - constexpr static int kBlockM = Kernel_traits::kHeadDim % 128 == 0 ? 4 : (Kernel_traits::kHeadDim % 64 == 0 ? 8 : 16); - dim3 grid_combine((params.b * params.h * params.seqlen_q + kBlockM - 1) / kBlockM); - EVENK_SWITCH(is_even_K, IsEvenKConst, [&] { - if (params.num_splits <= 2) { - flash_fwd_splitkv_combine_kernel<<>>(params); - } else if (params.num_splits <= 4) { - flash_fwd_splitkv_combine_kernel<<>>(params); - } else if (params.num_splits <= 8) { - flash_fwd_splitkv_combine_kernel<<>>(params); - } else if (params.num_splits <= 16) { - flash_fwd_splitkv_combine_kernel<<>>(params); - } else if (params.num_splits <= 32) { - flash_fwd_splitkv_combine_kernel<<>>(params); - } else if (params.num_splits <= 64) { - flash_fwd_splitkv_combine_kernel<<>>(params); - } else if (params.num_splits <= 128) { - flash_fwd_splitkv_combine_kernel<<>>(params); - } - C10_CUDA_KERNEL_LAUNCH_CHECK(); - }); - } -} - -template -void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream) { - constexpr static int kBlockM = 64; // Fixed for all head dimensions - // TD [2023-08-28]: nvcc segfaults for headdim 96 with block size 64 x 256, - // and for headdim 192 with block size 64 x 128. - constexpr static int kBlockN = Headdim <= 64 ? 256 : (Headdim <= 128 ? 
128 : 64); - run_flash_splitkv_fwd, Is_causal>(params, stream); -} - template void run_mha_fwd_hdim32(Flash_fwd_params ¶ms, cudaStream_t stream) { constexpr static int Headdim = 32; diff --git a/csrc/src/generate_kernels.py b/csrc/src/generate_kernels.py new file mode 100644 index 0000000..54160ec --- /dev/null +++ b/csrc/src/generate_kernels.py @@ -0,0 +1,111 @@ +import argparse +import itertools +import os +from dataclasses import dataclass +from pathlib import Path +from typing import List, Optional + +DTYPE_MAP = { + "fp16": "cutlass::half_t", + "bf16": "cutlass::bfloat16_t", +} + +SM = [80] # Sm80 kernels support up to +HEAD_DIMENSIONS = [32, 64, 96, 128, 192, 256] +IS_CAUSAL = ["false", "true"] +NAMESPACE_INCLUDE = '#include "namespace_config.h"\n' + +def get_fwd_template() -> str: + return NAMESPACE_INCLUDE + """#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE {{ + +template<> +void run_mha_fwd_<{DTYPE}, {HEAD_DIM}, {IS_CAUSAL}>(Flash_fwd_params ¶ms, cudaStream_t stream) {{ + run_mha_fwd_hdim{HEAD_DIM}<{DTYPE}, {IS_CAUSAL}>(params, stream); +}} + +}} // namespace FLASH_NAMESPACE""" + +def get_fwd_split_template() -> str: + return NAMESPACE_INCLUDE + """#include "flash_fwd_launch_template.h" + +namespace FLASH_NAMESPACE {{ + +template void run_mha_fwd_splitkv_dispatch<{DTYPE}, {HEAD_DIM}, {IS_CAUSAL}>(Flash_fwd_params ¶ms, cudaStream_t stream); + +}} // namespace FLASH_NAMESPACE""" + +def get_bwd_template() -> str: + return NAMESPACE_INCLUDE + """#include "flash_bwd_launch_template.h" + +namespace FLASH_NAMESPACE {{ + +template<> +void run_mha_bwd_<{DTYPE}, {HEAD_DIM}, {IS_CAUSAL}>(Flash_bwd_params ¶ms, cudaStream_t stream) {{ + run_mha_bwd_hdim{HEAD_DIM}<{DTYPE}, {IS_CAUSAL}>(params, stream); +}} + +}} // namespace FLASH_NAMESPACE""" + +@dataclass +class Kernel: + sm: int + dtype: str + head_dim: int + is_causal: str + direction: str + + @property + def template(self) -> str: + template_funcs = { + "fwd": get_fwd_template, + # "bwd": get_bwd_template, + # "fwd_split": get_fwd_split_template + } + template_func = template_funcs[self.direction] + return template_func().format( + DTYPE=DTYPE_MAP[self.dtype], + HEAD_DIM=self.head_dim, + IS_CAUSAL=self.is_causal + ) + + @property + def filename(self) -> str: + return f"flash_{self.direction}_hdim{self.head_dim}_{self.dtype}_{'causal_' if self.is_causal == 'true' else ''}sm{self.sm}.cu" + +def get_all_kernels() -> List[Kernel]: + for direction in ["fwd"]: #, "fwd_split", "bwd"]: + for dtype, head_dim, is_causal, sm in itertools.product(DTYPE_MAP.keys(), HEAD_DIMENSIONS, IS_CAUSAL, SM): + yield Kernel(sm=sm, dtype=dtype, head_dim=head_dim, is_causal=is_causal, direction=direction) + +def write_kernel(kernel: Kernel, autogen_dir: Path) -> None: + prelude = """// Copyright (c) 2025, Jingze Shi and Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py"\n""" + content = prelude + kernel.template + (autogen_dir / kernel.filename).write_text(content) + +def main(output_dir: Optional[str]) -> None: + if output_dir is None: + output_dir = Path(__file__).parent + else: + output_dir = Path(output_dir) + + for kernel in get_all_kernels(): + write_kernel(kernel, output_dir) + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + prog="generate_kernels", + description="Generate the flash_attention kernels template instantiations", + ) + parser.add_argument( + "-o", + "--output_dir", + required=False, + help="Where to generate the kernels " + " will default to the current directory ", + ) + args = parser.parse_args() + main(args.output_dir) diff --git a/csrc/src/kernel_traits.h b/csrc/src/kernel_traits.h index c2e28e3..774fc39 100644 --- a/csrc/src/kernel_traits.h +++ b/csrc/src/kernel_traits.h @@ -106,6 +106,7 @@ struct Flash_fwd_kernel_traits : public Base { Shape, Int>{})); // Transposed layouts for V matrix + // https://github.com/ColfaxResearch/cutlass-kernels/blob/a222587e6d59b93ba704853d3946fb686d8b8892/src/fmha/fmha_forward.cu#L434 using SmemLayoutVtransposed = decltype( composition(SmemLayoutKV{}, make_layout(Shape, Int>{}, GenRowMajor{}))); using SmemLayoutVtransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutVtransposed{})); @@ -124,73 +125,56 @@ struct Flash_fwd_kernel_traits : public Base { using SmemCopyAtomOaccum = Copy_Atom, ElementAccum>; // Dynamic mask related definitions - using SmemLayoutAtomZeroHold = decltype( + using SmemLayoutAtomMask = decltype( composition(Swizzle{}, Layout>, Stride, _1>>{})); - // Zero-hold states layout [kBlockM, kBlockN] - using SmemLayoutZeroHold = decltype(tile_to_shape( - SmemLayoutAtomZeroHold{}, + // layout [kBlockM, kBlockN] + using SmemLayoutZeroHold = decltype(tile_to_shape( // Used for sZeroHold (Element type) + SmemLayoutAtomMask{}, + Shape, Int>{})); + + using SmemLayoutDynamicMaskValues = decltype(tile_to_shape( // Used for sDynamicMaskValues (float type) + SmemLayoutAtomMask{}, // Layout is fine, type is float + Shape, Int>{})); + + using SmemLayoutDynamicMaskSortKeys = decltype(tile_to_shape( // Used for sDynamicMaskSortKeys (float type) + SmemLayoutAtomMask{}, // Layout is fine, type is float Shape, Int>{})); - - static constexpr int kSmemZeroHoldSize = size(SmemLayoutZeroHold{}) * sizeof(Element); - // Dynamic mask memory allocation constants - static constexpr int kMaxKeysPerBlock = kBlockN; - static constexpr int kMaskValuesSize = kMaxKeysPerBlock * sizeof(float); - static constexpr int kNonZeroIndicesSize = kMaxKeysPerBlock * sizeof(int); - static constexpr int kSortKeysSize = kMaxKeysPerBlock * sizeof(float); - static constexpr int kSortIndicesSize = kMaxKeysPerBlock * sizeof(int); - static constexpr int kDynamicMaskBufferPerQuery = kMaskValuesSize + kNonZeroIndicesSize + kSortKeysSize + kSortIndicesSize; - static constexpr int kTotalDynamicMaskBuffer = kBlockM * kDynamicMaskBufferPerQuery; - - // Dynamic mask shared memory layouts - using SmemLayoutDynamicMaskValues = decltype( - tile_to_shape( - composition(Swizzle{}, - Layout>, - Stride, _1>>{}), - Shape, Int>{} - ) - ); - - using SmemLayoutDynamicMaskSortKeys = decltype( - tile_to_shape( - composition(Swizzle{}, - Layout>, - Stride, _1>>{}), - Shape, Int>{} - ) - ); - - using SmemLayoutDynamicMaskSortIndices = decltype( - tile_to_shape( - composition(Swizzle{}, - Layout>, - Stride, _1>>{}), - Shape, Int>{} - ) - ); - - using SmemLayoutNonZeroIndices = decltype( - tile_to_shape( - 
composition(Swizzle{}, - Layout>, - Stride, _1>>{}), - Shape, Int>{} - ) - ); + using SmemLayoutDynamicMaskSortIndices = decltype(tile_to_shape( // Used for sDynamicMaskSortIndices (int type) + SmemLayoutAtomMask{}, // Layout is fine, type is int + Shape, Int>{})); + + using SmemLayoutNonZeroIndices = decltype(tile_to_shape( // Used for sNonZeroIndices (int type) + SmemLayoutAtomMask{}, // Layout is fine, type is int + Shape, Int>{})); + + using SmemLayoutPredicate = decltype(tile_to_shape( + SmemLayoutAtomMask{}, + Shape, Int>{})); // Used for sPredicate (bool type) // Shared memory size calculations static constexpr int kSmemQSize = size(SmemLayoutQ{}) * sizeof(Element); static constexpr int kSmemKVSize = size(SmemLayoutKV{}) * 2 * sizeof(Element); + static constexpr int kSmemZeroHoldSize = size(SmemLayoutZeroHold{}) * sizeof(Element); + static constexpr int kSmemMaskValuesSize = size(SmemLayoutDynamicMaskValues{}) * sizeof(float); + static constexpr int kSmemSortKeysSize = size(SmemLayoutDynamicMaskSortKeys{}) * sizeof(float); + static constexpr int kSmemSortIndicesSize = size(SmemLayoutDynamicMaskSortIndices{}) * sizeof(int); + static constexpr int kSmemNonZeroIndicesSize = size(SmemLayoutNonZeroIndices{}) * sizeof(int); + static constexpr int kSmemPredicateSize = size(SmemLayoutPredicate{}) * sizeof(bool); - // Base shared memory size without dynamic mask buffer - static constexpr int kSmemSize = Share_Q_K_smem ? std::max(kSmemQSize, kSmemKVSize) : kSmemQSize + kSmemKVSize; - - // Total shared memory size including dynamic mask buffer and nonzero indices - static constexpr int kSmemSizeWithMask = kSmemSize + kTotalDynamicMaskBuffer; + // Base shared memory size with Q and K/V matrices + static constexpr int kSmemSize = Share_Q_K_smem ? std::max(kSmemQSize, kSmemKVSize) + : kSmemQSize // For Q + + kSmemKVSize // For K and V + + kSmemZeroHoldSize // For sZeroHold + + kSmemMaskValuesSize // For sDynamicMaskValues + + kSmemSortKeysSize // For sDynamicMaskSortKeys + + kSmemSortIndicesSize // For sDynamicMaskSortIndices + + kSmemNonZeroIndicesSize // For sNonZeroIndices + + kSmemPredicateSize; // For sPredicate // Global memory access configuration static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); @@ -252,9 +236,9 @@ struct Flash_fwd_kernel_traits : public Base { Layout>{})); // Val layout, 8 vals per load // Zero hold global memory operations - using GmemLayoutAtomZeroHold = GmemLayoutAtom; + using GmemLayoutAtomZeroHold = GmemLayoutAtom; // Re-using GmemLayoutAtom for ZeroHold GMEM copies using GmemTiledCopyZeroHold = decltype( - make_tiled_copy(Copy_Atom{}, + make_tiled_copy(Copy_Atom{}, // Assuming Element type for ZeroHold in GMEM GmemLayoutAtomZeroHold{}, Layout>{})); // Val layout, 8 vals per read }; @@ -466,8 +450,8 @@ struct Flash_bwd_kernel_traits : public Base { static constexpr int kSmemZeroHoldSize = size(SmemLayoutZeroHold{}) * sizeof(Element); // Zero hold global memory operations - using GmemLayoutAtomZeroHold = GmemLayoutAtom; - using GmemTiledCopyZeroHold = decltype( + using GmemLayoutAtomZeroHold = GmemLayoutAtom; // Reusing fwd definition + using GmemTiledCopyZeroHold = decltype( // Reusing fwd definition make_tiled_copy(Copy_Atom{}, GmemLayoutAtomZeroHold{}, Layout>{})); // Val layout, 8 vals per read diff --git a/csrc/src/mask.h b/csrc/src/mask.h index d3ef098..7844ff6 100644 --- a/csrc/src/mask.h +++ b/csrc/src/mask.h @@ -26,16 +26,16 @@ using namespace cute; // Apply causal masking for dynamic mask with 1 row block template 
__forceinline__ __device__ void apply_causal_mask_1rowblock( - float* zero_hold_states, // Zero-hold states for one query row [key_len] - const Element* causal_mask_ptr, // Causal mask values for one query row [key_len] - int key_len // Key length + float* zero_hold_states, // Zero-hold states for one query row [key_len] + int query_idx, // Current query position (row index) + int key_len // Key length (sequence length for keys) ) { if constexpr (Is_causal) { - if (causal_mask_ptr != nullptr) { - #pragma unroll - for (int k_idx = 0; k_idx < key_len; ++k_idx) { - const bool is_masked = causal_mask_ptr[k_idx] != 0; - zero_hold_states[k_idx] = is_masked ? 0.0f : zero_hold_states[k_idx]; + #pragma unroll + for (int k_idx = 0; k_idx < key_len; ++k_idx) { + const bool is_masked = k_idx > query_idx; + if (is_masked) { + zero_hold_states[k_idx] = 0.0f; } } } @@ -119,18 +119,24 @@ __forceinline__ __device__ void apply_topk_window_selection_1rowblock( } // Apply dynamic mask with 1 row block -template +template < + typename EngineDst, typename LayoutDst, // float + typename EngineSrc, typename LayoutSrc, // half/bfloat16 + typename EngineSortKey, typename LayoutSortKey, // float + typename EngineSortIdx, typename LayoutSortIdx, // int + typename Element, bool Is_causal +> __forceinline__ __device__ void apply_dynamic_mask_1rowblock( - Tensor &tensor, // Output 1D tensor [key_len] - const Element* zero_hold_states, // Pre-calculated zero_hold states [key_len] - const Element* causal_mask_ptr, // Causal mask values [key_len] - const int key_len, // Sequence length for keys - const int keep_window_size, // Maximum window size to keep - float* row_vals, // Shared memory buffer for mask values [key_len] - float* sort_keys, // Shared memory buffer for sorting keys [key_len] - int* sort_indices // Shared memory buffer for sorting indices [key_len] + Tensor &tensor, // Output 1D tensor [key_len] + Tensor const &zero_hold_states, // Pre-calculated zero_hold states [key_len] + int query_idx, // Current query position (row index) + const int key_len, // Sequence length for keys + const int keep_window_size, // Maximum window size to keep + Tensor &row_vals, // Shared memory buffer for mask values [key_len] + Tensor &sort_keys, // Shared memory buffer for sorting keys [key_len] + Tensor &sort_indices // Shared memory buffer for sorting indices [key_len] ) { - static_assert(Layout::rank == 1, "Tensor must be 1D"); + static_assert(LayoutDst::rank == 1, "Tensor must be 1D"); int tid = threadIdx.x; // Load zero_hold and initialize row values @@ -141,11 +147,19 @@ __forceinline__ __device__ void apply_dynamic_mask_1rowblock( __syncthreads(); // Apply causal mask across the row - apply_causal_mask_1rowblock(row_vals, causal_mask_ptr, key_len); + apply_causal_mask_1rowblock( + row_vals.data().get(), + query_idx, key_len + ); __syncthreads(); // Top-k window selection - apply_topk_window_selection_1rowblock(row_vals, sort_keys, sort_indices, key_len, keep_window_size); + apply_topk_window_selection_1rowblock( + row_vals.data().get(), + sort_keys.data().get(), + sort_indices.data().get(), + key_len, keep_window_size + ); __syncthreads(); // Write back to tensor @@ -156,25 +170,36 @@ __forceinline__ __device__ void apply_dynamic_mask_1rowblock( } // Struct wrapper for dynamic mask application -template struct DynamicMask { const int keep_window_size; __forceinline__ __device__ DynamicMask(const int keep_window_size = 2048) : keep_window_size(keep_window_size) {} - template + template < + typename EngineDst, typename 
LayoutDst, + typename EngineSrc, typename LayoutSrc, + typename EngineSortKey, typename LayoutSortKey, + typename EngineSortIdx, typename LayoutSortIdx, + typename Element, bool Is_causal + > __forceinline__ __device__ void apply_mask_1rowblock( - Tensor &tensor, - const Element* zero_hold_states, - const Element* causal_mask_ptr, + Tensor &tensor, // float + Tensor const &zero_hold_states, // half/bfloat16 + int query_idx, const int key_len, - float* row_vals, - float* sort_keys, - int* sort_indices + Tensor &row_vals, + Tensor &sort_keys, + Tensor &sort_indices ) { - apply_dynamic_mask_1rowblock( - tensor, zero_hold_states, causal_mask_ptr, key_len, keep_window_size, + apply_dynamic_mask_1rowblock< + EngineDst, LayoutDst, + EngineSrc, LayoutSrc, + EngineSortKey, LayoutSortKey, + EngineSortIdx, LayoutSortIdx, + Element, Is_causal + >( + tensor, zero_hold_states, query_idx, key_len, keep_window_size, row_vals, sort_keys, sort_indices ); } diff --git a/csrc/src/philox.cuh b/csrc/src/philox.cuh new file mode 100644 index 0000000..5205f45 --- /dev/null +++ b/csrc/src/philox.cuh @@ -0,0 +1,53 @@ +// Pytorch also has an implementation of Philox RNG: https://github.com/pytorch/pytorch/blob/8ca3c881db3e3510fcb7725389f6a0633c9b992c/torch/csrc/jit/tensorexpr/cuda_random.h +#pragma once +// Philox CUDA. + +#include "namespace_config.h" + +namespace FLASH_NAMESPACE { + +struct ull2 { + unsigned long long x; + unsigned long long y; +}; + +__forceinline__ __device__ uint2 mulhilo32(const unsigned int a, const unsigned int b) { + uint2 *res; + unsigned long long tmp; + asm ("mul.wide.u32 %0, %1, %2;\n\t" + : "=l"(tmp) + : "r"(a), "r"(b)); + res = (uint2*)(&tmp); + return *res; +} + +__forceinline__ __device__ uint4 philox_single_round(const uint4 ctr, const uint2 key) { + constexpr unsigned long kPhiloxSA = 0xD2511F53; + constexpr unsigned long kPhiloxSB = 0xCD9E8D57; + uint2 res0 = mulhilo32(kPhiloxSA, ctr.x); + uint2 res1 = mulhilo32(kPhiloxSB, ctr.z); + uint4 ret = {res1.y ^ ctr.y ^ key.x, res1.x, res0.y ^ ctr.w ^ key.y, res0.x}; + return ret; +} + +__forceinline__ __device__ uint4 philox(unsigned long long seed, + unsigned long long subsequence, + unsigned long long offset) { + constexpr unsigned long kPhilox10A = 0x9E3779B9; + constexpr unsigned long kPhilox10B = 0xBB67AE85; + uint2 key = reinterpret_cast(seed); + uint4 counter; + ull2 *tmp = reinterpret_cast(&counter); + tmp->x = offset; + tmp->y = subsequence; + #pragma unroll + for (int i = 0; i < 6; i++) { + counter = philox_single_round(counter, key); + key.x += (kPhilox10A); + key.y += (kPhilox10B); + } + uint4 output = philox_single_round(counter, key); + return output; +} + +} // namespace FLASH_NAMESPACE diff --git a/csrc/src/softmax.h b/csrc/src/softmax.h new file mode 100644 index 0000000..01589ad --- /dev/null +++ b/csrc/src/softmax.h @@ -0,0 +1,189 @@ +/****************************************************************************** + * Copyright (c) 2024, Tri Dao. 
+ ******************************************************************************/ + +#pragma once + +#include + +#include + +#include + +#include "namespace_config.h" +#include "philox.cuh" +#include "utils.h" + +namespace FLASH_NAMESPACE { + +using namespace cute; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__device__ __forceinline__ void thread_reduce_(Tensor const &tensor, Tensor &summary, Operator &op) { + static_assert(Layout0::rank == 2, "Only support 2D Tensor"); + static_assert(Layout1::rank == 1, "Only support 1D Tensor"); + CUTE_STATIC_ASSERT_V(size<0>(summary) == size<0>(tensor)); + #pragma unroll + for (int mi = 0; mi < size<0>(tensor); mi++) { + summary(mi) = zero_init ? tensor(mi, 0) : op(summary(mi), tensor(mi, 0)); + #pragma unroll + for (int ni = 1; ni < size<1>(tensor); ni++) { + summary(mi) = op(summary(mi), tensor(mi, ni)); + } + } +} + +template +__device__ __forceinline__ void quad_allreduce_(Tensor &dst, Tensor &src, Operator &op) { + CUTE_STATIC_ASSERT_V(size(dst) == size(src)); + #pragma unroll + for (int i = 0; i < size(dst); i++){ + dst(i) = Allreduce<4>::run(src(i), op); + } +} + +template +__device__ __forceinline__ void reduce_(Tensor const& tensor, Tensor &summary, Operator &op) { + thread_reduce_(tensor, summary, op); + quad_allreduce_(summary, summary, op); +} + +template +__device__ __forceinline__ void reduce_max(Tensor const& tensor, Tensor &max){ + MaxOp max_op; + reduce_(tensor, max, max_op); +} + +template +__device__ __forceinline__ void reduce_sum(Tensor const& tensor, Tensor &sum){ + SumOp sum_op; + thread_reduce_(tensor, sum, sum_op); +} + +// Apply the exp to all the elements. +template +__forceinline__ __device__ void scale_apply_exp2(Tensor &tensor, Tensor const &max, const float scale) { + static_assert(Layout0::rank == 2, "Only support 2D Tensor"); + static_assert(Layout1::rank == 1, "Only support 1D Tensor"); + CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor)); + #pragma unroll + for (int mi = 0; mi < size<0>(tensor); ++mi) { + // If max is -inf, then all elements must have been -inf (possibly due to masking). + // We don't want (-inf - (-inf)) since that would give NaN. + // If we don't have float around M_LOG2E the multiplication is done in fp64. + const float max_scaled = max(mi) == -INFINITY ? 0.f : max(mi) * (Scale_max ? scale : float(M_LOG2E)); + #pragma unroll + for (int ni = 0; ni < size<1>(tensor); ++ni) { + // Instead of computing exp(x - max), we compute exp2(x * log_2(e) - + // max * log_2(e)) This allows the compiler to use the ffma + // instruction instead of fadd and fmul separately. + // The following macro will disable the use of fma. + // See: https://github.com/pytorch/pytorch/issues/121558 for more details + // This macro is set in PyTorch and not FlashAttention + #ifdef UNFUSE_FMA + tensor(mi, ni) = exp2f(__fmul_rn(tensor(mi, ni), scale) - max_scaled); + #else + tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled); + #endif + } + } +} + +// Apply the exp to all the elements. +template +__forceinline__ __device__ void max_scale_exp2_sum(Tensor &tensor, Tensor &max, Tensor &sum, const float scale) { + static_assert(Layout0::rank == 2, "Only support 2D Tensor"); + static_assert(Layout1::rank == 1, "Only support 1D Tensor"); + CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor)); + #pragma unroll + for (int mi = 0; mi < size<0>(tensor); ++mi) { + MaxOp max_op; + max(mi) = zero_init ? 
tensor(mi, 0) : max_op(max(mi), tensor(mi, 0)); + #pragma unroll + for (int ni = 1; ni < size<1>(tensor); ni++) { + max(mi) = max_op(max(mi), tensor(mi, ni)); + } + max(mi) = Allreduce<4>::run(max(mi), max_op); + // If max is -inf, then all elements must have been -inf (possibly due to masking). + // We don't want (-inf - (-inf)) since that would give NaN. + const float max_scaled = max(mi) == -INFINITY ? 0.f : max(mi) * scale; + sum(mi) = 0; + #pragma unroll + for (int ni = 0; ni < size<1>(tensor); ++ni) { + // Instead of computing exp(x - max), we compute exp2(x * log_2(e) - + // max * log_2(e)) This allows the compiler to use the ffma + // instruction instead of fadd and fmul separately. + tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled); + sum(mi) += tensor(mi, ni); + } + SumOp sum_op; + sum(mi) = Allreduce<4>::run(sum(mi), sum_op); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Softmax { + + using TensorT = decltype(make_tensor(Shape>{})); + TensorT row_max, row_sum; + + __forceinline__ __device__ Softmax() {}; + + template + __forceinline__ __device__ void softmax_rescale_o(Tensor0 &acc_s, Tensor1 &acc_o, float softmax_scale_log2) { + // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) + Tensor scores = make_tensor(acc_s.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(acc_s.layout())); + static_assert(decltype(size<0>(scores))::value == kNRows); + if (Is_first) { + FLASH_NAMESPACE::template reduce_max(scores, row_max); + FLASH_NAMESPACE::scale_apply_exp2(scores, row_max, softmax_scale_log2); + FLASH_NAMESPACE::reduce_sum(scores, row_sum); + } else { + Tensor scores_max_prev = make_fragment_like(row_max); + cute::copy(row_max, scores_max_prev); + FLASH_NAMESPACE::template reduce_max(scores, row_max); + // Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K)) + Tensor acc_o_rowcol = make_tensor(acc_o.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(acc_o.layout())); + static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows); + #pragma unroll + for (int mi = 0; mi < size(row_max); ++mi) { + float scores_max_cur = !Check_inf + ? row_max(mi) + : (row_max(mi) == -INFINITY ? 0.0f : row_max(mi)); + float scores_scale = exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2); + row_sum(mi) *= scores_scale; + #pragma unroll + for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scores_scale; } + } + FLASH_NAMESPACE::scale_apply_exp2(scores, row_max, softmax_scale_log2); + // We don't do the reduce across threads here since we don't need to use the row_sum. + // We do that reduce at the end when we need to normalize the softmax. + FLASH_NAMESPACE::reduce_sum(scores, row_sum); + } + }; + + template + __forceinline__ __device__ TensorT normalize_softmax_lse(Tensor0 &acc_o, float softmax_scale, float rp_dropout=1.0) { + SumOp sum_op; + quad_allreduce_(row_sum, row_sum, sum_op); + TensorT lse = make_fragment_like(row_sum); + Tensor acc_o_rowcol = make_tensor(acc_o.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(acc_o.layout())); + static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows); + #pragma unroll + for (int mi = 0; mi < size<0>(acc_o_rowcol); ++mi) { + float sum = row_sum(mi); + float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum; + lse(mi) = (sum == 0.f || sum != sum) ? (Split ? -INFINITY : INFINITY) : row_max(mi) * softmax_scale + __logf(sum); + float scale = !Is_dropout ? 
inv_sum : inv_sum * rp_dropout; + #pragma unroll + for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scale; } + } + return lse; + }; +}; + +} // namespace FLASH_NAMESPACE diff --git a/csrc/src/static_switch.h b/csrc/src/static_switch.h new file mode 100644 index 0000000..d912812 --- /dev/null +++ b/csrc/src/static_switch.h @@ -0,0 +1,81 @@ +// Inspired by +// https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h +// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h + +#pragma once + +/// @param COND - a boolean expression to switch by +/// @param CONST_NAME - a name given for the constexpr bool variable. +/// @param ... - code to execute for true and false +/// +/// Usage: +/// ``` +/// BOOL_SWITCH(flag, BoolConst, [&] { +/// some_function(...); +/// }); +/// ``` + +#define BOOL_SWITCH(COND, CONST_NAME, ...) \ + [&] { \ + if (COND) { \ + constexpr static bool CONST_NAME = true; \ + return __VA_ARGS__(); \ + } else { \ + constexpr static bool CONST_NAME = false; \ + return __VA_ARGS__(); \ + } \ + }() + +#ifdef FLASHATTENTION_DISABLE_DROPOUT + #define DROPOUT_SWITCH(COND, CONST_NAME, ...) \ + [&] { \ + constexpr static bool CONST_NAME = false; \ + return __VA_ARGS__(); \ + }() +#else + #define DROPOUT_SWITCH BOOL_SWITCH +#endif + +#ifdef FLASHATTENTION_DISABLE_UNEVEN_K + #define EVENK_SWITCH(COND, CONST_NAME, ...) \ + [&] { \ + constexpr static bool CONST_NAME = true; \ + return __VA_ARGS__(); \ + }() +#else + #define EVENK_SWITCH BOOL_SWITCH +#endif + +#define FP16_SWITCH(COND, ...) \ + [&] { \ + if (COND) { \ + using elem_type = cutlass::half_t; \ + return __VA_ARGS__(); \ + } else { \ + using elem_type = cutlass::bfloat16_t; \ + return __VA_ARGS__(); \ + } \ + }() + +#define HEADDIM_SWITCH(HEADDIM, ...) \ + [&] { \ + if (HEADDIM <= 32) { \ + constexpr static int kHeadDim = 32; \ + return __VA_ARGS__(); \ + } else if (HEADDIM <= 64) { \ + constexpr static int kHeadDim = 64; \ + return __VA_ARGS__(); \ + } else if (HEADDIM <= 96) { \ + constexpr static int kHeadDim = 96; \ + return __VA_ARGS__(); \ + } else if (HEADDIM <= 128) { \ + constexpr static int kHeadDim = 128; \ + return __VA_ARGS__(); \ + } else if (HEADDIM <= 192) { \ + constexpr static int kHeadDim = 192; \ + return __VA_ARGS__(); \ + } else if (HEADDIM <= 256) { \ + constexpr static int kHeadDim = 256; \ + return __VA_ARGS__(); \ + } \ + }() diff --git a/csrc/src/utils.h b/csrc/src/utils.h new file mode 100644 index 0000000..a52ef7f --- /dev/null +++ b/csrc/src/utils.h @@ -0,0 +1,519 @@ +/****************************************************************************** + * Copyright (c) 2025, Jingze Shi and Tri Dao. 
+ ******************************************************************************/
+
+#pragma once
+
+#include <assert.h>
+#include <stdint.h>
+#include <stdlib.h>
+
+#include <cuda_fp16.h>
+
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+#include <cuda_bf16.h>
+#endif
+
+#include <cute/tensor.hpp>
+
+#include <cutlass/array.h>
+#include <cutlass/cutlass.h>
+#include <cutlass/numeric_conversion.h>
+#include <cutlass/numeric_types.h>
+
+#include "namespace_config.h"
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+namespace FLASH_NAMESPACE {
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template<typename T>
+__forceinline__ __device__ uint32_t relu2(const uint32_t x);
+
+template<>
+__forceinline__ __device__ uint32_t relu2<cutlass::half_t>(const uint32_t x) {
+    uint32_t res;
+    const uint32_t zero = 0u;
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+    asm volatile("max.f16x2 %0, %1, %2;\n" : "=r"(res) : "r"(x), "r"(zero));
+#else
+    asm volatile( \
+        "{\n" \
+        "\t .reg .f16x2 sela;\n" \
+        "\t set.gtu.u32.f16x2 sela, %1, %2;\n" \
+        "\t and.b32 %0, sela, %1;\n"
+        "}\n" : "=r"(res) : "r"(x), "r"(zero));
+#endif
+    return res;
+}
+
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+template<>
+__forceinline__ __device__ uint32_t relu2<cutlass::bfloat16_t>(const uint32_t x) {
+    uint32_t res;
+    const uint32_t zero = 0u;
+    asm volatile("max.bf16x2 %0, %1, %2;\n" : "=r"(res) : "r"(x), "r"(zero));
+    return res;
+}
+#endif
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+
+template<typename T>
+__forceinline__ __device__ uint32_t convert_relu2(const float2 x);
+
+template<>
+__forceinline__ __device__ uint32_t convert_relu2<cutlass::half_t>(const float2 x) {
+    uint32_t res;
+    const uint32_t a = reinterpret_cast<const uint32_t&>(x.x);
+    const uint32_t b = reinterpret_cast<const uint32_t&>(x.y);
+    asm volatile("cvt.rn.relu.f16x2.f32 %0, %1, %2;\n" : "=r"(res) : "r"(b), "r"(a));
+    return res;
+}
+
+template<>
+__forceinline__ __device__ uint32_t convert_relu2<cutlass::bfloat16_t>(const float2 x) {
+    uint32_t res;
+    const uint32_t a = reinterpret_cast<const uint32_t&>(x.x);
+    const uint32_t b = reinterpret_cast<const uint32_t&>(x.y);
+    asm volatile("cvt.rn.relu.bf16x2.f32 %0, %1, %2;\n" : "=r"(res) : "r"(b), "r"(a));
+    return res;
+}
+
+#endif
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template<typename T>
+struct MaxOp {
+__device__ __forceinline__ T operator()(T const & x, T const & y) { return x > y ? x : y; }
+};
+
+template <>
+struct MaxOp<float> {
+// This is slightly faster
+__device__ __forceinline__ float operator()(float const &x, float const &y) { return max(x, y); }
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template<typename T>
+struct SumOp {
+__device__ __forceinline__ T operator()(T const & x, T const & y) { return x + y; }
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template<int THREADS>
+struct Allreduce {
+    static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4);
+    template<typename T, typename Operator>
+    static __device__ __forceinline__ T run(T x, Operator &op) {
+        constexpr int OFFSET = THREADS / 2;
+        x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET));
+        return Allreduce<OFFSET>::run(x, op);
+    }
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template<>
+struct Allreduce<2> {
+template<typename T, typename Operator>
+static __device__ __forceinline__ T run(T x, Operator &op) {
+    x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1));
+    return x;
+}
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template<bool A_in_regs=false, bool B_in_regs=false, typename Tensor0, typename Tensor1,
+         typename Tensor2, typename Tensor3, typename Tensor4,
+         typename TiledMma, typename TiledCopyA, typename TiledCopyB,
+         typename ThrCopyA, typename ThrCopyB>
+__forceinline__ __device__ void gemm(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsA,
+                                     Tensor4 const& tCsB, TiledMma tiled_mma,
+                                     TiledCopyA smem_tiled_copy_A, TiledCopyB smem_tiled_copy_B,
+                                     ThrCopyA smem_thr_copy_A, ThrCopyB smem_thr_copy_B) {
+    CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc));    // MMA_M
+    CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc));    // MMA_N
+    CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB));   // MMA_K
+    Tensor tCrA_copy_view = smem_thr_copy_A.retile_D(tCrA);
+    CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrA_copy_view));  // M
+    Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB);
+    CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view));  // N
+    if (!A_in_regs) { cute::copy(smem_tiled_copy_A, tCsA(_, _, _0{}), tCrA_copy_view(_, _, _0{})); }
+    if (!B_in_regs) { cute::copy(smem_tiled_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{})); }
+    #pragma unroll
+    for (int i = 0; i < size<2>(tCrA); ++i) {
+        if (i < size<2>(tCrA) - 1) {
+            if (!A_in_regs) { cute::copy(smem_tiled_copy_A, tCsA(_, _, i + 1), tCrA_copy_view(_, _, i + 1)); }
+            if (!B_in_regs) { cute::copy(smem_tiled_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1)); }
+        }
+        cute::gemm(tiled_mma, tCrA(_, _, i), tCrB(_, _, i), acc);
+    }
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// GEMM variant with a per-column predicate on B: columns where predicate_K(n) is false are
+// cleared in registers rather than loaded from shared memory.
+template<bool A_in_regs=false, bool B_in_regs=false, typename Tensor0, typename Tensor1,
+         typename Tensor2, typename Tensor3, typename Tensor4,
+         typename TiledMma, typename TiledCopyA, typename TiledCopyB,
+         typename ThrCopyA, typename ThrCopyB, typename PredicateTensor>
+__forceinline__ __device__ void sparse_gemm(
+    Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsA,
+    Tensor4 const& tCsB, TiledMma tiled_mma,
+    TiledCopyA smem_tiled_copy_A, TiledCopyB smem_tiled_copy_B,
+    ThrCopyA smem_thr_copy_A, ThrCopyB smem_thr_copy_B,
+    PredicateTensor const &predicate_K
+) {
+    CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc));    // MMA_M
+    CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc));    // MMA_N
+    CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB));   // MMA_K
+    auto tCrA_copy_view = smem_thr_copy_A.retile_D(tCrA);
+    CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrA_copy_view));  // M
+    auto tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB);
+    CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view));  // N
+    if (!A_in_regs) { cute::copy(smem_tiled_copy_A, tCsA(_, _, _0{}), tCrA_copy_view(_, _, _0{})); }
+    if (!B_in_regs) { cute::copy(smem_tiled_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{})); }
+
+    #pragma unroll
+    for (int i = 0; i < size<2>(tCrA); ++i) {
+        if (!A_in_regs) {
+            cute::copy(smem_tiled_copy_A, tCsA(_, _, i), tCrA_copy_view(_, _, i));
+        }
+        #pragma unroll
+        for (int n = 0; n < size<1>(tCrB_copy_view); ++n) {
+            if (!B_in_regs) {
+                if (!predicate_K(n)) {
+                    cute::clear(tCrB_copy_view(_, n, i));
+                }
+            }
+            if (!B_in_regs && i < size<2>(tCrA) - 1) {
+                if (predicate_K(n)) {
+                    cute::copy(smem_tiled_copy_B, tCsB(_, n, i + 1), tCrB_copy_view(_, n, i + 1));
+                } else {
+                    cute::clear(tCrB_copy_view(_, n, i + 1));
+                }
+            }
+        }
+        cute::gemm(tiled_mma, tCrA(_, _, i), tCrB(_, _, i), acc);
+    }
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template<typename Tensor0, typename Tensor1, typename Tensor2, typename Tensor3,
+         typename TiledMma, typename TiledCopy, typename ThrCopy>
+__forceinline__ __device__ void gemm_rs(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsB,
+                                        TiledMma tiled_mma, TiledCopy smem_tiled_copy_B,
+                                        ThrCopy smem_thr_copy_B) {
+    CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc));    // MMA_M
+    CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc));    // MMA_N
+    CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB));   // MMA_K
+    Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB);
+    CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view));  // N
+    cute::copy(smem_tiled_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{}));
+    #pragma unroll
+    for (int i = 0; i < size<2>(tCrA); ++i) {
+        if (i < size<2>(tCrA) - 1) {
+            cute::copy(smem_tiled_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1));
+        }
+        cute::gemm(tiled_mma, tCrA(_, _, i), tCrB(_, _, i), acc);
+    }
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// Same predicate-based handling of B as sparse_gemm, for the case where A is already in registers.
+template<typename Tensor0, typename Tensor1, typename Tensor2, typename Tensor3,
+         typename TiledMma, typename TiledCopy, typename ThrCopy, typename PredicateTensor>
+__forceinline__ __device__ void sparse_gemm_rs(
+    Tensor0 &acc,
+    Tensor1 &tCrA,
+    Tensor2 &tCrB,
+    Tensor3 const& tCsB,
+    TiledMma tiled_mma,
+    TiledCopy smem_tiled_copy_B,
+    ThrCopy smem_thr_copy_B,
+    PredicateTensor const &predicate_K
+) {
+    CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc));    // MMA_M
+    CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc));    // MMA_N
+    CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB));   // MMA_K
+    // Retile B for thread-wise copy from shared memory to registers
+    auto tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB);
+    CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view));  // N
+
+    #pragma unroll
+    for (int i = 0; i < size<2>(tCrA); ++i) {
+        #pragma unroll
+        for (int n = 0; n < size<1>(tCrB_copy_view); ++n) {
+            if (!predicate_K(n)) {
+                cute::clear(tCrB_copy_view(_, n, i));
+            }
+            if (i < size<2>(tCrA) - 1) {
+                if (predicate_K(n)) {
+                    cute::copy(smem_tiled_copy_B, tCsB(_, n, i + 1), tCrB_copy_view(_, n, i + 1));
+                } else {
+                    cute::clear(tCrB_copy_view(_, n, i + 1));
+                }
+            }
+        }
+        cute::gemm(tiled_mma, tCrA(_, _, i), tCrB(_, _, i), acc);
+    }
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// Convert acc_layout from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
+template<typename Layout>
+__forceinline__ __device__ auto convert_layout_acc_rowcol(Layout acc_layout) {
+    static_assert(decltype(size<0>(acc_layout))::value == 4);
+    static_assert(decltype(rank(acc_layout))::value == 3);
+    auto l = logical_divide(acc_layout, Shape<_2>{});  // ((2, 2), MMA_M, MMA_N)
+    return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<2>(l)));
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// Convert acc_layout from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2)
+// if using m16n8k16, or to (4, MMA_M, MMA_N) if using m16n8k8.
+template<typename MMA_traits, typename Layout>
+__forceinline__ __device__ auto convert_layout_acc_Aregs(Layout acc_layout) {
+    using X = Underscore;
+    static_assert(decltype(size<0>(acc_layout))::value == 4);
+    static_assert(decltype(rank(acc_layout))::value == 3);
+    constexpr int mma_shape_K = get<2>(typename MMA_traits::Shape_MNK{});
+    static_assert(mma_shape_K == 8 || mma_shape_K == 16);
+    if constexpr (mma_shape_K == 8) {
+        return acc_layout;
+    } else {
+        auto l = logical_divide(acc_layout, Shape<X, X, _2>{});  // (4, MMA_M, (2, MMA_N / 2)))
+        return make_layout(make_layout(get<0>(l), get<2, 0>(l)), get<1>(l), get<2, 1>(l));
+    }
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// Convert acc_layout from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2)
+template<typename Layout>
+__forceinline__ __device__ auto convert_layout_acc_dropout(Layout acc_layout) {
+    using X = Underscore;
+    static_assert(decltype(size<0>(acc_layout))::value == 4);
+    static_assert(decltype(rank(acc_layout))::value == 3);
+    auto l = logical_divide(acc_layout, Shape<X, X, _2>{});  // (4, MMA_M, (2, MMA_N / 2)))
+    return make_layout(make_layout(get<0>(l), get<2, 0>(l)), get<1>(l), get<2, 1>(l));
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <typename To_type, typename Engine, typename Layout>
+__forceinline__ __device__ auto convert_type(Tensor<Engine, Layout> const &tensor) {
+    using From_type = typename Engine::value_type;
+    constexpr int numel = decltype(size(tensor))::value;
+    cutlass::NumericArrayConverter<To_type, From_type, numel> convert_op;
+    // HACK: this requires tensor to be "contiguous"
+    auto frag = convert_op(*reinterpret_cast<const cutlass::Array<From_type, numel> *>(tensor.data()));
+    return make_tensor(make_rmem_ptr<To_type>(&frag), tensor.layout());
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <typename Engine, typename Layout>
+__forceinline__ __device__ void relu_(Tensor<Engine, Layout> &tensor) {
+    constexpr int numel = decltype(size(tensor))::value;
+    static_assert(numel % 2 == 0);
+    using value_t = typename Engine::value_type;
+    // HACK: this requires tensor to be "contiguous"
+    Tensor tensor_uint32 = recast<uint32_t>(tensor);
+    #pragma unroll
+    for (int i = 0; i < size(tensor_uint32); ++i) {
+        tensor_uint32(i) = relu2<value_t>(tensor_uint32(i));
+    }
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// On SM80 and above, we can fuse fp32 -> fp16/bf16 conversion and relu into 1 instruction
+template <typename To_type, typename Engine, typename Layout>
+__forceinline__ __device__ auto convert_type_relu(Tensor<Engine, Layout> const &tensor) {
+    using From_type = typename Engine::value_type;
+    static_assert(std::is_same_v<To_type, cutlass::half_t> || std::is_same_v<To_type, cutlass::bfloat16_t>);
+    static_assert(std::is_same_v<From_type, float>);
+    constexpr int numel = decltype(size(tensor))::value;
+    static_assert(numel % 2 == 0);
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+    // HACK: this requires tensor to be "contiguous"
+    Tensor tensor_float2 = recast<float2>(tensor);
+    Tensor out_uint32 = make_tensor<uint32_t>(tensor_float2.layout());
+    #pragma unroll
+    for (int i = 0; i < size(out_uint32); ++i) {
+        out_uint32(i) = convert_relu2<To_type>(tensor_float2(i));
+    }
+    Tensor out = make_tensor(make_rmem_ptr<To_type>(out_uint32.data()), tensor.layout());
+#else
+    Tensor out = FLASH_NAMESPACE::convert_type<To_type>(tensor);
+    FLASH_NAMESPACE::relu_(out);
+#endif
+    return out;
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// Blocks until all but N previous cp.async.commit_group operations have committed.
+// This differs from cute::cp_async_wait in that when N = 0 we don't call cp.async.wait_all
+// (which is equivalent to commit_group then wait_group 0).
+// Instead we just call cp.async.wait_group 0, which is slightly faster.
+// https://github.com/NVIDIA/cutlass/blob/master/include/cute/arch/copy_sm80.hpp#L113
+template <int N>
+CUTE_HOST_DEVICE
+void cp_async_wait() {
+#if defined(CUTE_ARCH_CP_ASYNC_SM80_ENABLED)
+    asm volatile("cp.async.wait_group %0;\n" :: "n"(N));
+#endif
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <bool Is_even_MN=true, bool Is_even_K=true, bool Clear_OOB_MN=false, bool Clear_OOB_K=true,
+          typename TiledCopy, typename Engine0, typename Layout0, typename Engine1, typename Layout1,
+          typename Engine2, typename Layout2, typename Engine3, typename Layout3>
+__forceinline__ __device__ void copy(TiledCopy tiled_copy, Tensor<Engine0, Layout0> const &S,
+                                     Tensor<Engine1, Layout1> &D, Tensor<Engine2, Layout2> const &identity_MN,
+                                     Tensor<Engine3, Layout3> const &predicate_K, const int max_MN=0) {
+    CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});
+    CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{});
+    CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D));  // MMA
+    CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D));  // MMA_M
+    CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D));  // MMA_K
+    // There's no case where !Clear_OOB_K && Clear_OOB_MN
+    static_assert(!(Clear_OOB_MN && !Clear_OOB_K));
+    #pragma unroll
+    for (int m = 0; m < size<1>(S); ++m) {
+        if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) {
+            #pragma unroll
+            for (int k = 0; k < size<2>(S); ++k) {
+                if (Is_even_K || predicate_K(k)) {
+                    cute::copy(tiled_copy, S(_, m, k), D(_, m, k));
+                } else if (Clear_OOB_K) {
+                    cute::clear(D(_, m, k));
+                }
+            }
+        } else if (Clear_OOB_MN) {
+            cute::clear(D(_, m, _));
+        }
+    }
+    // TD [2023-04-13]: Strange that the code below can cause race condition.
+    // I think it's because the copies are under an if statement.
+    // if (Is_even_K) {
+    //     #pragma unroll
+    //     for (int m = 0; m < size<1>(S); ++m) {
+    //         if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) {
+    //             copy(tiled_copy, S(_, m, _), D(_, m, _));
+    //         } else if (Clear_OOB_MN) {
+    //             clear(D(_, m, _));
+    //         }
+    //     }
+    // } else {  // It's slightly faster in this case if iterate over K first
+    //     #pragma unroll
+    //     for (int k = 0; k < size<2>(S); ++k) {
+    //         if (predicate_K(k)) {
+    //             #pragma unroll
+    //             for (int m = 0; m < size<1>(S); ++m) {
+    //                 if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) {
+    //                     copy(tiled_copy, S(_, m, k), D(_, m, k));
+    //                 } else if (Clear_OOB_MN) {
+    //                     clear(D(_, m, k));
+    //                 }
+    //             }
+    //         } else if (Clear_OOB_K) {  // There's no case where !Clear_OOB_K && Clear_OOB_MN
+    //             if (Clear_OOB_MN || Is_even_MN) {
+    //                 clear(D(_, _, k));
+    //             } else {
+    //                 #pragma unroll
+    //                 for (int m = 0; m < size<1>(S); ++m) {
+    //                     if (!(Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN)) {
+    //                         clear(D(_, m, k));
+    //                     }
+    //                 }
+    //             }
+    //         }
+    //     }
+    // }
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <bool Is_even_K=true,
+          typename Engine0, typename Layout0, typename Engine1, typename Layout1,
+          typename Engine2, typename Layout2, typename Engine3, typename Layout3>
+__forceinline__ __device__ void copy_w_min_idx(Tensor<Engine0, Layout0> const &S,
+                                               Tensor<Engine1, Layout1> &D, Tensor<Engine2, Layout2> const &identity_MN,
+                                               Tensor<Engine3, Layout3> const &predicate_K,
+                                               const int max_MN=0, const int min_MN=0) {
+    CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});
+    CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{});
+    CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D));  // MMA
+    CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D));  // MMA_M
+    CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D));  // MMA_K
+    // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("blockIdx.y = %d, max_MN = %d, min_MN = %d\n", blockIdx.y, max_MN, min_MN); }
+    #pragma unroll
+    for (int m = 0; m < size<1>(S); ++m) {
+        // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("blockIdx.y = %d, m = %d\n", blockIdx.y, get<0>(identity_MN(0, m, 0))); }
+        if (get<0>(identity_MN(0, m, 0)) >= min_MN && get<0>(identity_MN(0, m, 0)) < max_MN) {
+            // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("Inner loop, blockIdx.y = %d, m = %d\n", blockIdx.y, get<0>(identity_MN(0, m, 0))); }
+            #pragma unroll
+            for (int k = 0; k < size<2>(S); ++k) {
+                if (Is_even_K || predicate_K(k)) {
+                    cute::copy(S(_, m, k), D(_, m, k));
+                }
+            }
+        }
+    }
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <typename Engine, typename Layout>
+__forceinline__ __device__ void apply_softcap(Tensor<Engine, Layout> &tensor, const float softcap){
+    #pragma unroll
+    for (int i = 0; i < size(tensor); ++i) {
+        tensor(i) = cutlass::fast_tanh(tensor(i) * softcap);
+    }
+}
+
+template <typename Engine0, typename Layout0, typename Engine1, typename Layout1>
+__forceinline__ __device__ void calculate_dtanh(Tensor<Engine0, Layout0> &src_tensor, Tensor<Engine1, Layout1> &dst_tensor, const float softcap){
+    #pragma unroll
+    for (int i = 0; i < size(src_tensor); ++i) {
+        dst_tensor(i) = (1.f - (src_tensor(i) * src_tensor(i))) * softcap;
+    }
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// Collect the column indices of the non-zero entries in one row of the dynamic mask.
+// Returns the number of indices written to non_zero_indices.
+template <typename Element>
+__forceinline__ __device__ int non_zero_mask_indices(
+    const Element* dynamic_mask_1rowblock,  // Dynamic mask one row block [key_len]
+    int* non_zero_indices,                  // Non-zero indices [key_len]
+    int key_len                             // Key length
+) {
+    int non_zero_count = 0;
+    #pragma unroll
+    for (int idx = 0; idx < key_len; ++idx) {
+        if (dynamic_mask_1rowblock[idx] != static_cast<Element>(0)) {
+            non_zero_indices[non_zero_count++] = idx;
+        }
+    }
+    return non_zero_count;
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+}  // namespace FLASH_NAMESPACE
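As a usage note for the dispatch macros added in `static_switch.h` above: they are intended to be nested on the host side so that runtime values (dtype, head dimension, boolean flags) become compile-time template arguments before a kernel launcher is instantiated. The sketch below only illustrates that pattern; `run_fwd`, `run_fwd_`, and the `Flash_fwd_params` argument are hypothetical stand-ins, not declarations from this patch.

```cpp
// Minimal sketch (assumed names): composing FP16_SWITCH / HEADDIM_SWITCH / BOOL_SWITCH
// so that the innermost call sees constexpr template parameters.
#include <cuda_runtime.h>
#include <cutlass/numeric_types.h>
#include "static_switch.h"

struct Flash_fwd_params;  // hypothetical parameter struct, defined elsewhere

// Hypothetical templated launcher: one instantiation per (dtype, head dim, causal) combination.
template<typename elem_type, int kHeadDim, bool Is_causal>
void run_fwd_(Flash_fwd_params &params, cudaStream_t stream);

void run_fwd(Flash_fwd_params &params, cudaStream_t stream,
             bool is_fp16, int head_dim, bool is_causal) {
    FP16_SWITCH(is_fp16, [&] {            // defines `elem_type` as half_t or bfloat16_t
        HEADDIM_SWITCH(head_dim, [&] {    // defines `kHeadDim` in {32, 64, 96, 128, 192, 256}
            BOOL_SWITCH(is_causal, Is_causal, [&] {
                run_fwd_<elem_type, kHeadDim, Is_causal>(params, stream);
            });
        });
    });
}
```

Each macro wraps its body in an immediately invoked lambda, so the `constexpr` variable or `using` alias it introduces is visible to the nested callable passed as `__VA_ARGS__`.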