Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 29 additions & 29 deletions csrc/src/flash_attention_fwd_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -235,9 +235,9 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
);

dynamic_smem_current_ptr += Kernel_traits::kSmemZeroHoldSize;
auto causal_mask_smem_ptr = has_causal_mask
? make_smem_ptr(reinterpret_cast<Element*>(dynamic_smem_current_ptr))
: make_smem_ptr(static_cast<Element*>(nullptr));
auto causal_mask_smem_ptr = has_causal_mask ?
make_smem_ptr(reinterpret_cast<Element*>(dynamic_smem_current_ptr)) :
make_smem_ptr(static_cast<Element*>(nullptr));
Tensor sCausalMask = make_tensor(
causal_mask_smem_ptr,
typename Kernel_traits::SmemLayoutCausalMask{}
Expand All @@ -247,31 +247,31 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
dynamic_smem_current_ptr += Kernel_traits::kSmemCausalMaskSize;
}
Tensor sDynamicMaskValues = make_tensor(
make_smem_ptr(reinterpret_cast<float*>(dynamic_smem_current_ptr)), // float type
make_smem_ptr(reinterpret_cast<float*>(dynamic_smem_current_ptr)), // float type
typename Kernel_traits::SmemLayoutDynamicMaskValues{}
);

dynamic_smem_current_ptr += Kernel_traits::kSmemMaskValuesSize;
Tensor sDynamicMaskSortKeys = make_tensor(
make_smem_ptr(reinterpret_cast<float*>(dynamic_smem_current_ptr)), // float type
make_smem_ptr(reinterpret_cast<float*>(dynamic_smem_current_ptr)), // float type
typename Kernel_traits::SmemLayoutDynamicMaskSortKeys{}
);

dynamic_smem_current_ptr += Kernel_traits::kSmemSortKeysSize;
Tensor sDynamicMaskSortIndices = make_tensor(
make_smem_ptr(reinterpret_cast<int*>(dynamic_smem_current_ptr)), // int type
make_smem_ptr(reinterpret_cast<int*>(dynamic_smem_current_ptr)), // int type
typename Kernel_traits::SmemLayoutDynamicMaskSortIndices{}
);

dynamic_smem_current_ptr += Kernel_traits::kSmemSortIndicesSize;
Tensor sNonZeroIndices = make_tensor(
make_smem_ptr(reinterpret_cast<int*>(dynamic_smem_current_ptr)), // int type
make_smem_ptr(reinterpret_cast<int*>(dynamic_smem_current_ptr)), // int type
typename Kernel_traits::SmemLayoutNonZeroIndices{}
);

dynamic_smem_current_ptr += Kernel_traits::kSmemNonZeroIndicesSize;
Tensor sPredicate = make_tensor(
make_smem_ptr(reinterpret_cast<Element*>(dynamic_smem_current_ptr)), // Element type
make_smem_ptr(reinterpret_cast<Element*>(dynamic_smem_current_ptr)), // Element type
typename Kernel_traits::SmemLayoutPredicate{}
);

Expand Down Expand Up @@ -331,29 +331,29 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
// auto smem_thr_copy_CausalMask_smem = smem_tiled_copy_CausalMask_smem.get_thread_slice(tidx);
// Tensor tSsCausalMask = has_causal_mask ? smem_thr_copy_CausalMask_smem.partition_S(sCausalMask) : empty_smem_tensor_for_copy_D;

// Set up predicates
Tensor cQ = make_identity_tensor(make_shape(size<0>(sQ), size<1>(sQ)));
Tensor cKV = make_identity_tensor(make_shape(size<0>(sK), size<1>(sK)));
Tensor cZeroHold = make_identity_tensor(make_shape(size<0>(sZeroHold), size<1>(sZeroHold)));
Tensor cCausalMask = params.causal_mask_ptr != nullptr
? make_identity_tensor(make_shape(size<0>(sCausalMask), size<1>(sCausalMask)))
: Tensor();

Tensor tQcQ = gmem_thr_copy_QKV.partition_S(cQ);
Tensor tKVcKV = gmem_thr_copy_QKV.partition_S(cKV);
Tensor tZeroHoldcZeroHold = gmem_thr_copy_ZeroHold.partition_S(cZeroHold);
Tensor tCausalMaskcCausalMask = params.causal_mask_ptr != nullptr
? gmem_thr_copy_CausalMask.partition_S(cCausalMask)
: Tensor();

// PREDICATES
Tensor cQ = make_identity_tensor(make_shape(size<0>(sQ), size<1>(sQ))); // (BLK_M,BLK_K) -> (blk_m,blk_k)
Tensor cKV = make_identity_tensor(make_shape(size<0>(sK), size<1>(sK))); // (BLK_N,BLK_K) -> (blk_n,blk_k)
// Identity tensor for gZeroHold -> sZeroHold copy
Tensor cZeroHold = make_identity_tensor(make_shape(size<0>(sZeroHold), size<1>(sZeroHold)));
// Identity tensor for gCausalMask -> sCausalMask copy, use dummy 1×1 when no mask
Tensor cCausalMask = make_identity_tensor(make_shape(
has_causal_mask ? size<0>(sCausalMask) : Int<1>{},
has_causal_mask ? size<1>(sCausalMask) : Int<1>{}
));
// Repeat the partitioning with identity layouts
Tensor tQcQ = gmem_thr_copy_QKV.partition_S(cQ); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)
Tensor tKVcKV = gmem_thr_copy_QKV.partition_S(cKV); // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k)
// Predicate for ZeroHold GMEM copy
Tensor tZeroHoldcZeroHold = gmem_thr_copy_ZeroHold.partition_S(cZeroHold);
// Predicate for CausalMask GMEM copy
Tensor tCausalMaskcCausalMask = gmem_thr_copy_CausalMask.partition_S(cCausalMask);
// Allocate predicate tensors for k
Tensor tQpQ = make_tensor<bool>(make_shape(size<2>(tQsQ)));
Tensor tKVpKV = make_tensor<bool>(make_shape(size<2>(tKsK)));
Tensor tZeroHoldpZeroHold = make_tensor<bool>(make_shape(size<2>(tZeroHoldsZeroHold)));
Tensor tCausalMaskpCausalMask = params.causal_mask_ptr != nullptr
? make_tensor<bool>(make_shape(size<2>(tCausalMasksCausalMask)))
: Tensor();

// Set predicates for the K dimension
Tensor tZeroHoldpZeroHold = make_tensor<bool>(make_shape(size<2>(tZeroHoldsZeroHold))); // N-dim predicate for ZeroHold
Tensor tCausalMaskpCausalMask = make_tensor<bool>(make_shape(size<2>(tCausalMasksCausalMask))); // N-dim predicate for CausalMask (always allocate; only used when has_causal_mask)
// Set predicates for k bounds
if (!Is_even_K) {
#pragma unroll
for (int k = 0; k < size(tQpQ); ++k) {
Expand Down