csrc/src/flash_attention_fwd_kernel.h: 124 changes (31 additions, 93 deletions)
@@ -169,40 +169,15 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
);

Tensor mZeroHold = make_tensor(
make_gmem_ptr(reinterpret_cast<Element*>(params.zero_hold_ptr) + bidb * params.zero_hold_batch_stride),
make_gmem_ptr(reinterpret_cast<Element*>(params.zero_hold_ptr) + binfo.q_offset(params.zero_hold_batch_stride, params.zero_hold_row_stride, bidb)),
make_shape(params.h_k, binfo.actual_seqlen_q, binfo.actual_seqlen_k),
make_stride(params.zero_hold_head_stride, params.zero_hold_query_stride, _1{})
make_stride(params.zero_hold_head_stride, params.zero_hold_row_stride, _1{})
);
Tensor gZeroHold = local_tile(
mZeroHold(bidh / params.h_h_k_ratio, _, _),
Shape<Int<kBlockM>, Int<kBlockN>>{},
make_coord(m_block, 0)
);

auto mCausalMask = has_causal_mask ?
make_tensor(
make_gmem_ptr(reinterpret_cast<Element*>(params.causal_mask_ptr) + bidb * params.causal_mask_batch_stride),
make_shape(1, binfo.actual_seqlen_q, binfo.actual_seqlen_k),
make_stride(params.causal_mask_head_stride, params.causal_mask_query_len_stride, _1{})
) :
make_tensor(
make_gmem_ptr(static_cast<Element*>(nullptr)),
make_shape(1, 1, 1),
make_stride(static_cast<flash::index_t>(0), static_cast<flash::index_t>(0), _1{})
);

auto gCausalMask = has_causal_mask ?
local_tile(
mCausalMask(0, _, _),
Shape<Int<kBlockM>, Int<kBlockN>>{},
make_coord(m_block, 0)
) :
make_tensor(
make_gmem_ptr(static_cast<Element*>(nullptr)),
make_layout(
Shape<Int<kBlockM>, Int<kBlockN>>{},
make_stride(static_cast<flash::index_t>(0), _1{}))
);
make_coord(m_block, n_block_max - 1)
); // (kBlockM, kBlockN)
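Note on the coordinate change: `local_tile` carves a `(kBlockM, kBlockN)` tile out of the full `(seqlen_q, seqlen_k)` ZeroHold view, and switching the N coordinate from `0` to `n_block_max - 1` makes the selected tile the last key block, matching the reverse N iteration later in this function. A minimal sketch of the equivalent pointer arithmetic, with a hypothetical `tile_origin` helper that is not part of the kernel:

```cuda
// Hypothetical helper mirroring the addressing local_tile performs here,
// assuming a row-major (seqlen_q, seqlen_k) matrix whose row stride is
// `row_stride` elements. kBlockM/kBlockN match the kernel's tile shape.
template <int kBlockM, int kBlockN, typename Element>
__device__ const Element* tile_origin(const Element* zero_hold,
                                      long row_stride,
                                      int m_block, int n_block) {
    // Element (0, 0) of tile (m_block, n_block) lives at this offset.
    return zero_hold + (long)m_block * kBlockM * row_stride
                     + (long)n_block * kBlockN;
}
```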

// Shared memory layout configuration
Tensor sQ = make_tensor(
@@ -230,22 +205,11 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
// Dynamic mask related shared memory. Use a running char* pointer for robust allocation.
char* dynamic_smem_current_ptr = reinterpret_cast<char*>(sV.data().get() + size(sV) * sizeof(Element));
Tensor sZeroHold = make_tensor(
make_smem_ptr(reinterpret_cast<Element*>(dynamic_smem_current_ptr)),
make_smem_ptr(reinterpret_cast<Element*>(dynamic_smem_current_ptr)), // Element type
typename Kernel_traits::SmemLayoutZeroHold{}
);

dynamic_smem_current_ptr += Kernel_traits::kSmemZeroHoldSize;
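The running `char*` here is a bump allocator over dynamic shared memory: each tensor claims the region at the current pointer, then the pointer advances by that region's byte size, so regions never overlap and optional regions (like the causal mask buffer this change removes) simply skip their advance. A standalone sketch of the pattern with made-up sizes:

```cuda
#include <cuda_fp16.h>

extern __shared__ char smem_[];  // dynamic shared memory, sized at launch

__device__ void bump_allocation_sketch() {
    char* p = smem_;
    // Claim a half-precision tile, then bump the pointer past it.
    __half* zero_hold = reinterpret_cast<__half*>(p);
    p += 128 * 64 * sizeof(__half);                  // stand-in for kSmemZeroHoldSize
    // The next region starts exactly where the previous one ended.
    float* mask_values = reinterpret_cast<float*>(p);
    p += 128 * 64 * sizeof(float);
    // Each region's size must keep the next region's type suitably aligned.
    (void)zero_hold; (void)mask_values;
}
```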
auto causal_mask_smem_ptr = has_causal_mask ?
make_smem_ptr(reinterpret_cast<Element*>(dynamic_smem_current_ptr)) :
make_smem_ptr(static_cast<Element*>(nullptr));
Tensor sCausalMask = make_tensor(
causal_mask_smem_ptr,
typename Kernel_traits::SmemLayoutCausalMask{}
);

if (has_causal_mask) {
dynamic_smem_current_ptr += Kernel_traits::kSmemCausalMaskSize;
}
Tensor sDynamicMaskValues = make_tensor(
make_smem_ptr(reinterpret_cast<float*>(dynamic_smem_current_ptr)), // float type
typename Kernel_traits::SmemLayoutDynamicMaskValues{}
@@ -280,8 +244,6 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx);
typename Kernel_traits::GmemTiledCopyZeroHold gmem_tiled_copy_ZeroHold;
auto gmem_thr_copy_ZeroHold = gmem_tiled_copy_ZeroHold.get_thread_slice(tidx);
typename Kernel_traits::GmemTiledCopyZeroHold gmem_tiled_copy_CausalMask;
auto gmem_thr_copy_CausalMask = gmem_tiled_copy_CausalMask.get_thread_slice(tidx);

Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ);
Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ);
@@ -291,12 +253,6 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV);
Tensor tZeroHoldgZeroHold = gmem_thr_copy_ZeroHold.partition_S(gZeroHold);
Tensor tZeroHoldsZeroHold = gmem_thr_copy_ZeroHold.partition_D(sZeroHold);
decltype(gmem_thr_copy_CausalMask.partition_S(gCausalMask)) tCausalMaskgCausalMask;
decltype(gmem_thr_copy_CausalMask.partition_D(sCausalMask)) tCausalMasksCausalMask;
if (has_causal_mask) {
tCausalMaskgCausalMask = gmem_thr_copy_CausalMask.partition_S(gCausalMask);
tCausalMasksCausalMask = gmem_thr_copy_CausalMask.partition_D(sCausalMask);
}

// Matrix Multiply Accumulate
typename Kernel_traits::TiledMma tiled_mma;
@@ -336,23 +292,15 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
Tensor cKV = make_identity_tensor(make_shape(size<0>(sK), size<1>(sK))); // (BLK_N,BLK_K) -> (blk_n,blk_k)
// Identity tensor for gZeroHold -> sZeroHold copy
Tensor cZeroHold = make_identity_tensor(make_shape(size<0>(sZeroHold), size<1>(sZeroHold)));
// Identity tensor for gCausalMask -> sCausalMask copy, use dummy 1×1 when no mask
Tensor cCausalMask = make_identity_tensor(make_shape(
has_causal_mask ? size<0>(sCausalMask) : Int<1>{},
has_causal_mask ? size<1>(sCausalMask) : Int<1>{}
));
// Repeat the partitioning with identity layouts
Tensor tQcQ = gmem_thr_copy_QKV.partition_S(cQ); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)
Tensor tKVcKV = gmem_thr_copy_QKV.partition_S(cKV); // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k)
// Predicate for ZeroHold GMEM copy
Tensor tZeroHoldcZeroHold = gmem_thr_copy_ZeroHold.partition_S(cZeroHold);
// Predicate for CausalMask GMEM copy
Tensor tCausalMaskcCausalMask = gmem_thr_copy_CausalMask.partition_S(cCausalMask);
// Allocate predicate tensors for k
Tensor tQpQ = make_tensor<bool>(make_shape(size<2>(tQsQ)));
Tensor tKVpKV = make_tensor<bool>(make_shape(size<2>(tKsK)));
Tensor tZeroHoldpZeroHold = make_tensor<bool>(make_shape(size<2>(tZeroHoldsZeroHold))); // N-dim predicate for ZeroHold
Tensor tCausalMaskpCausalMask = make_tensor<bool>(make_shape(size<2>(tCausalMasksCausalMask))); // N-dim predicate for CausalMask (always allocate; only used when has_causal_mask)
// Set predicates for k bounds
if (!Is_even_K) {
#pragma unroll
@@ -363,71 +311,61 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
for (int k = 0; k < size(tKVpKV); ++k) {
tKVpKV(k) = get<1>(tKVcKV(0, 0, k)) < params.d;
}
#pragma unroll
for (int k = 0; k < size(tZeroHoldpZeroHold); ++k) {
tZeroHoldpZeroHold(k) = true; // All elements are valid for the moment
}
}
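The identity tensors (`cQ`, `cKV`, `cZeroHold`) give each copy thread the logical (row, column) coordinate of every element it touches, and the boolean `t*p*` tensors cache the per-column bounds test `get<1>(...) < params.d` once per K-slice so the copy does not re-derive it per element. The same idea in scalar form, as a sketch with names of my own:

```cuda
// Illustrative scalar version of a predicated tile copy: `d` plays the role
// of params.d, and the column predicate is computed once per column, much
// like the tKVpKV tensor does per K-slice. Strides are in elements.
template <typename Element>
__device__ void predicated_tile_copy(const Element* src, Element* dst,
                                     int rows, int cols, int d,
                                     long src_stride, int dst_stride) {
    for (int r = 0; r < rows; ++r) {
        for (int c = 0; c < cols; ++c) {
            bool pred = (c < d);  // analogue of tKVpKV(k)
            // Out-of-bounds columns are zero-filled rather than read.
            dst[r * (long)dst_stride + c] =
                pred ? src[r * src_stride + c] : Element(0);
        }
    }
}
```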

// Initialize the dynamic mask processor
// Prologue
// Init dynamic mask processor
DynamicMask<Is_causal> dynamic_mask(params.keep_window_size);

// Load Q into shared memory
// We don't need to clear the sQ smem tiles since we'll only write out the valid outputs
FLASH_NAMESPACE::copy<Is_even_MN, Is_even_K>(
gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ,
binfo.actual_seqlen_q - m_block * kBlockM
);

if (Kernel_traits::Is_Q_in_regs) {
cute::cp_async_fence();
}
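`FLASH_NAMESPACE::copy` issues the asynchronous global-to-shared copies, and the trailing argument (`binfo.actual_seqlen_q - m_block * kBlockM`) tells it how many rows of the tile are real; `cp_async_fence()` then closes this batch of copies so it can be waited on later. A sketch of what that boundary argument encodes:

```cuda
// Number of valid rows in query block m_block, mirroring the
// `binfo.actual_seqlen_q - m_block * kBlockM` argument above.
// A sketch of the bound, not the library routine itself.
__host__ __device__ int valid_rows_in_block(int actual_seqlen_q,
                                            int m_block, int kBlockM) {
    int remaining = actual_seqlen_q - m_block * kBlockM;
    if (remaining >= kBlockM) return kBlockM;  // full interior block
    return remaining > 0 ? remaining : 0;      // partial (or empty) last block
}
```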

// If Q and K share smem, we need to wait and sync
// If share Q and K smem, wait and sync
if (Kernel_traits::Share_Q_K_smem) {
FLASH_NAMESPACE::cp_async_wait<0>();
__syncthreads();
Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ);
CUTE_STATIC_ASSERT_V(size<1>(tSsQ) == size<1>(tSrQ_copy_view));
CUTE_STATIC_ASSERT_V(size<1>(tSsQ) == size<1>(tSrQ_copy_view)); // M
cute::copy(smem_tiled_copy_Q, tSsQ, tSrQ_copy_view);
__syncthreads();
}
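When `Share_Q_K_smem` is set, Q and K occupy the same shared-memory buffer, so the kernel must fully drain the async Q copy, sync, pull Q into registers, and sync again before K's copy may land in the same space. The same protocol sketched with the CUDA pipeline primitives (the kernel itself uses its own cp.async wrappers):

```cuda
#include <cuda_pipeline.h>  // __pipeline_wait_prior (CUDA 11+, sm_80+)

// Sketch of the drain-then-reuse protocol for a shared Q/K buffer.
__device__ void drain_then_reuse_smem(const float* sQ_smem, float* q_regs,
                                      int n_per_thread, int tid) {
    __pipeline_wait_prior(0);  // analogue of cp_async_wait<0>: all copies done
    __syncthreads();           // every thread now sees the complete Q tile
    for (int i = 0; i < n_per_thread; ++i) {
        q_regs[i] = sQ_smem[tid * n_per_thread + i];  // Q -> registers
    }
    __syncthreads();           // only now may K's copy overwrite the buffer
}
```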

// Iterate over the N blocks in reverse
// Reverse iteration over N blocks
int n_block = n_block_max - 1;
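Starting at `n_block_max - 1` and walking down to 0 means the kernel meets the key blocks that intersect the causal diagonal first; once those are masked, the remaining iterations can run without masking. `n_block_max` itself is computed outside this hunk; a hypothetical derivation, assuming a masking convention where query row m attends to key columns at indices <= m:

```cuda
__host__ __device__ int ceil_div(int a, int b) { return (a + b - 1) / b; }

// Hypothetical bound for query block m_block; the real kernel computes this
// elsewhere and may also offset for seqlen_q != seqlen_k.
__host__ __device__ int n_block_max_for(int m_block, int kBlockM, int kBlockN,
                                        int actual_seqlen_k, bool is_causal) {
    int n_max = ceil_div(actual_seqlen_k, kBlockN);
    if (is_causal) {
        // Highest key column any row of this query block may attend to.
        int last_query_row = (m_block + 1) * kBlockM - 1;
        int cap = ceil_div(last_query_row + 1, kBlockN);
        if (cap < n_max) n_max = cap;
    }
    return n_max;  // the loop then runs n_block = n_max - 1, ..., 0
}
```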

// Load the first K block into shared memory
// We don't need to clear the sK smem tiles since we'll mask out the scores anyway.
FLASH_NAMESPACE::copy<Is_even_MN, Is_even_K>(
gmem_tiled_copy_QKV, tKgK(_, _, _, n_block), tKsK, tKVcKV, tKVpKV,
gmem_tiled_copy_QKV,
tKgK(_, _, _, n_block),
tKsK, tKVcKV, tKVpKV,
binfo.actual_seqlen_k - n_block * kBlockN
);
cute::cp_async_fence();

// Load the first ZeroHold block into shared memory
FLASH_NAMESPACE::copy<Is_even_MN, Is_even_K>(
gmem_tiled_copy_ZeroHold, tZeroHoldgZeroHold(_, _, _, n_block), tZeroHoldsZeroHold, tZeroHoldcZeroHold, tZeroHoldpZeroHold,
binfo.actual_seqlen_k - n_block * kBlockN
);
cute::cp_async_fence();

// Load the first CausalMask block into shared memory (if present)
if (params.causal_mask_ptr != nullptr) {
FLASH_NAMESPACE::copy<Is_even_MN, Is_even_K>(
gmem_tiled_copy_CausalMask, tCausalMaskgCausalMask(_, _, _, n_block), tCausalMasksCausalMask, tCausalMaskcCausalMask, tCausalMaskpCausalMask,
binfo.actual_seqlen_k - n_block * kBlockN
);
cute::cp_async_fence();
}

// Load Q from shared memory into registers (if needed)
if (Kernel_traits::Is_Q_in_regs && !Kernel_traits::Share_Q_K_smem) {
FLASH_NAMESPACE::cp_async_wait<1>();
__syncthreads();
Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ);
CUTE_STATIC_ASSERT_V(size<1>(tSsQ) == size<1>(tSrQ_copy_view));
CUTE_STATIC_ASSERT_V(size<1>(tSsQ) == size<1>(tSrQ_copy_view)); // M
cute::copy(smem_tiled_copy_Q, tSsQ, tSrQ_copy_view);
}

// Initialize the output accumulator
Tensor acc_o = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kHeadDim>>{});
// For ZeroHold, Is_even_K in copy refers to the kBlockN dimension alignment for vectorization,
// which is generally true. The boundary is handled by the length argument.
FLASH_NAMESPACE::copy<Is_even_MN, Is_even_K>(
gmem_tiled_copy_ZeroHold,
tZeroHoldgZeroHold,
tZeroHoldsZeroHold, tZeroHoldcZeroHold, tZeroHoldpZeroHold,
binfo.actual_seqlen_k - n_block * kBlockN
);
cute::cp_async_fence();
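The comment above separates two different "evenness" questions: `Is_even_K` asserts that the contiguous dimension is vector-width aligned, so the copy can use wide loads, while the ragged end of the sequence is handled dynamically by the length argument. The standard shape of that trade-off, as a sketch:

```cuda
// Vectorized copy with a scalar tail: the 128-bit fast path is only legal
// when src/dst are 16-byte aligned (the property an evenness flag vouches
// for along the contiguous dimension); `n` clamps the boundary, like the
// seqlen argument above. Illustrative only; the kernel's real copies go
// through CuTe tiled-copy atoms.
__device__ void copy_vec_with_tail(const float* __restrict__ src,
                                   float* __restrict__ dst, int n) {
    int n_vec = (n / 4) * 4;  // largest multiple of the vector width
    for (int i = 0; i < n_vec; i += 4) {
        *reinterpret_cast<float4*>(dst + i) =
            *reinterpret_cast<const float4*>(src + i);
    }
    for (int i = n_vec; i < n; ++i) dst[i] = src[i];  // scalar tail
}
```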

clear(acc_o);

// Create the softmax calculator

FLASH_NAMESPACE::Softmax<2 * size<1>(acc_o)> softmax;
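The `Softmax` helper keeps a running maximum and a running sum per accumulator row (`2 * size<1>(acc_o)` of them per thread in this fragment layout) so probabilities can be renormalized tile by tile instead of materializing a full attention row. A minimal single-row sketch of that online recurrence; the real class works on MMA fragments and also rescales `acc_o`:

```cuda
#include <math.h>

// One online-softmax step for a single row. Initialize row_max = -INFINITY
// and row_sum = 0 before the first tile; after the last tile, divide the
// accumulated (unnormalized) outputs by row_sum.
__device__ void online_softmax_step(const float* scores, int n,
                                    float& row_max, float& row_sum,
                                    float* probs /* out, length n */) {
    float new_max = row_max;
    for (int j = 0; j < n; ++j) new_max = fmaxf(new_max, scores[j]);
    row_sum *= expf(row_max - new_max);  // rescale old sum to the new max
    for (int j = 0; j < n; ++j) {
        probs[j] = expf(scores[j] - new_max);
        row_sum += probs[j];
    }
    row_max = new_max;
}
```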

// Process the blocks that need masking (usually the last few blocks)