Skip to content

Commit 134fe9a

Browse files
authored
Fix some nits for layout
2 parents 5c3c4f7 + 71e70af commit 134fe9a

File tree

1 file changed

+31
-93
lines changed

1 file changed

+31
-93
lines changed

csrc/src/flash_attention_fwd_kernel.h

Lines changed: 31 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -169,40 +169,15 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
169169
);
170170

171171
Tensor mZeroHold = make_tensor(
172-
make_gmem_ptr(reinterpret_cast<Element*>(params.zero_hold_ptr) + bidb * params.zero_hold_batch_stride),
172+
make_gmem_ptr(reinterpret_cast<Element*>(params.zero_hold_ptr) + binfo.q_offset(params.zero_hold_batch_stride, params.zero_hold_row_stride, bidb)),
173173
make_shape(params.h_k, binfo.actual_seqlen_q, binfo.actual_seqlen_k),
174-
make_stride(params.zero_hold_head_stride, params.zero_hold_query_stride, _1{})
174+
make_stride(params.zero_hold_head_stride, params.zero_hold_row_stride, _1{})
175175
);
176176
Tensor gZeroHold = local_tile(
177177
mZeroHold(bidh / params.h_h_k_ratio, _, _),
178178
Shape<Int<kBlockM>, Int<kBlockN>>{},
179-
make_coord(m_block, 0)
180-
);
181-
182-
auto mCausalMask = has_causal_mask ?
183-
make_tensor(
184-
make_gmem_ptr(reinterpret_cast<Element*>(params.causal_mask_ptr) + bidb * params.causal_mask_batch_stride),
185-
make_shape(1, binfo.actual_seqlen_q, binfo.actual_seqlen_k),
186-
make_stride(params.causal_mask_head_stride, params.causal_mask_query_len_stride, _1{})
187-
) :
188-
make_tensor(
189-
make_gmem_ptr(static_cast<Element*>(nullptr)),
190-
make_shape(1, 1, 1),
191-
make_stride(static_cast<flash::index_t>(0), static_cast<flash::index_t>(0), _1{})
192-
);
193-
194-
auto gCausalMask = has_causal_mask ?
195-
local_tile(
196-
mCausalMask(0, _, _),
197-
Shape<Int<kBlockM>, Int<kBlockN>>{},
198-
make_coord(m_block, 0)
199-
) :
200-
make_tensor(
201-
make_gmem_ptr(static_cast<Element*>(nullptr)),
202-
make_layout(
203-
Shape<Int<kBlockM>, Int<kBlockN>>{},
204-
make_stride(static_cast<flash::index_t>(0), _1{}))
205-
);
179+
make_coord(m_block, n_block_max - 1)
180+
); // (kBlockM, kBlockN)
206181

207182
// Shared memory layout configuration
208183
Tensor sQ = make_tensor(
@@ -230,22 +205,11 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
230205
// Dynamic mask related shared memory. Use a running char* pointer for robust allocation.
231206
char* dynamic_smem_current_ptr = reinterpret_cast<char*>(sV.data().get() + size(sV) * sizeof(Element));
232207
Tensor sZeroHold = make_tensor(
233-
make_smem_ptr(reinterpret_cast<Element*>(dynamic_smem_current_ptr)),
208+
make_smem_ptr(reinterpret_cast<Element*>(dynamic_smem_current_ptr)), // Element type
234209
typename Kernel_traits::SmemLayoutZeroHold{}
235210
);
236211

237212
dynamic_smem_current_ptr += Kernel_traits::kSmemZeroHoldSize;
238-
auto causal_mask_smem_ptr = has_causal_mask ?
239-
make_smem_ptr(reinterpret_cast<Element*>(dynamic_smem_current_ptr)) :
240-
make_smem_ptr(static_cast<Element*>(nullptr));
241-
Tensor sCausalMask = make_tensor(
242-
causal_mask_smem_ptr,
243-
typename Kernel_traits::SmemLayoutCausalMask{}
244-
);
245-
246-
if (has_causal_mask) {
247-
dynamic_smem_current_ptr += Kernel_traits::kSmemCausalMaskSize;
248-
}
249213
Tensor sDynamicMaskValues = make_tensor(
250214
make_smem_ptr(reinterpret_cast<float*>(dynamic_smem_current_ptr)), // float type
251215
typename Kernel_traits::SmemLayoutDynamicMaskValues{}
@@ -280,8 +244,6 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
280244
auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx);
281245
typename Kernel_traits::GmemTiledCopyZeroHold gmem_tiled_copy_ZeroHold;
282246
auto gmem_thr_copy_ZeroHold = gmem_tiled_copy_ZeroHold.get_thread_slice(tidx);
283-
typename Kernel_traits::GmemTiledCopyZeroHold gmem_tiled_copy_CausalMask;
284-
auto gmem_thr_copy_CausalMask = gmem_tiled_copy_CausalMask.get_thread_slice(tidx);
285247

286248
Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ);
287249
Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ);
@@ -291,12 +253,6 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
291253
Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV);
292254
Tensor tZeroHoldgZeroHold = gmem_thr_copy_ZeroHold.partition_S(gZeroHold);
293255
Tensor tZeroHoldsZeroHold = gmem_thr_copy_ZeroHold.partition_D(sZeroHold);
294-
decltype(gmem_thr_copy_CausalMask.partition_S(gCausalMask)) tCausalMaskgCausalMask;
295-
decltype(gmem_thr_copy_CausalMask.partition_D(sCausalMask)) tCausalMasksCausalMask;
296-
if (has_causal_mask) {
297-
tCausalMaskgCausalMask = gmem_thr_copy_CausalMask.partition_S(gCausalMask);
298-
tCausalMasksCausalMask = gmem_thr_copy_CausalMask.partition_D(sCausalMask);
299-
}
300256

301257
// Matrix Multiply Accumulate
302258
typename Kernel_traits::TiledMma tiled_mma;
@@ -336,23 +292,15 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
336292
Tensor cKV = make_identity_tensor(make_shape(size<0>(sK), size<1>(sK))); // (BLK_N,BLK_K) -> (blk_n,blk_k)
337293
// Identity tensor for gZeroHold -> sZeroHold copy
338294
Tensor cZeroHold = make_identity_tensor(make_shape(size<0>(sZeroHold), size<1>(sZeroHold)));
339-
// Identity tensor for gCausalMask -> sCausalMask copy, use dummy 1×1 when no mask
340-
Tensor cCausalMask = make_identity_tensor(make_shape(
341-
has_causal_mask ? size<0>(sCausalMask) : Int<1>{},
342-
has_causal_mask ? size<1>(sCausalMask) : Int<1>{}
343-
));
344295
// Repeat the partitioning with identity layouts
345296
Tensor tQcQ = gmem_thr_copy_QKV.partition_S(cQ); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)
346297
Tensor tKVcKV = gmem_thr_copy_QKV.partition_S(cKV); // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k)
347298
// Predicate for ZeroHold GMEM copy
348299
Tensor tZeroHoldcZeroHold = gmem_thr_copy_ZeroHold.partition_S(cZeroHold);
349-
// Predicate for CausalMask GMEM copy
350-
Tensor tCausalMaskcCausalMask = gmem_thr_copy_CausalMask.partition_S(cCausalMask);
351300
// Allocate predicate tensors for k
352301
Tensor tQpQ = make_tensor<bool>(make_shape(size<2>(tQsQ)));
353302
Tensor tKVpKV = make_tensor<bool>(make_shape(size<2>(tKsK)));
354303
Tensor tZeroHoldpZeroHold = make_tensor<bool>(make_shape(size<2>(tZeroHoldsZeroHold))); // N-dim predicate for ZeroHold
355-
Tensor tCausalMaskpCausalMask = make_tensor<bool>(make_shape(size<2>(tCausalMasksCausalMask))); // N-dim predicate for CausalMask (always allocate; only used when has_causal_mask)
356304
// Set predicates for k bounds
357305
if (!Is_even_K) {
358306
#pragma unroll
@@ -363,71 +311,61 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
363311
for (int k = 0; k < size(tKVpKV); ++k) {
364312
tKVpKV(k) = get<1>(tKVcKV(0, 0, k)) < params.d;
365313
}
314+
#pragma unroll
315+
for (int k = 0; k < size(tZeroHoldpZeroHold); ++k) {
316+
tZeroHoldpZeroHold(k) = true; // All elements are valid for the moment
317+
}
366318
}
367319

368-
// 初始化动态掩码处理器
320+
// Prologue
321+
// Init dynamic mask processor
369322
DynamicMask<Is_causal> dynamic_mask(params.keep_window_size);
370-
371-
// 加载Q到共享内存
323+
// We don't need to clear the sQ smem tiles since we'll only write out the valid outputs
372324
FLASH_NAMESPACE::copy<Is_even_MN, Is_even_K>(
373325
gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ,
374326
binfo.actual_seqlen_q - m_block * kBlockM
375327
);
376-
377328
if (Kernel_traits::Is_Q_in_regs) {
378329
cute::cp_async_fence();
379330
}
380-
381-
// 如果共享Q和K的内存,需要等待并同步
331+
// If share Q and K smem, wait and sync
382332
if (Kernel_traits::Share_Q_K_smem) {
383333
FLASH_NAMESPACE::cp_async_wait<0>();
384334
__syncthreads();
385335
Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ);
386-
CUTE_STATIC_ASSERT_V(size<1>(tSsQ) == size<1>(tSrQ_copy_view));
336+
CUTE_STATIC_ASSERT_V(size<1>(tSsQ) == size<1>(tSrQ_copy_view)); // M
387337
cute::copy(smem_tiled_copy_Q, tSsQ, tSrQ_copy_view);
388338
__syncthreads();
389339
}
390-
391-
// 反向迭代N块
340+
// Reverse iteration over N blocks
392341
int n_block = n_block_max - 1;
393-
394-
// 加载第一个K块到共享内存
342+
// We don't need to clear the sK smem tiles since we'll mask out the scores anyway.
395343
FLASH_NAMESPACE::copy<Is_even_MN, Is_even_K>(
396-
gmem_tiled_copy_QKV, tKgK(_, _, _, n_block), tKsK, tKVcKV, tKVpKV,
344+
gmem_tiled_copy_QKV,
345+
tKgK(_, _, _, n_block),
346+
tKsK, tKVcKV, tKVpKV,
397347
binfo.actual_seqlen_k - n_block * kBlockN
398348
);
399349
cute::cp_async_fence();
400-
401-
// 加载第一个ZeroHold块到共享内存
402-
FLASH_NAMESPACE::copy<Is_even_MN, Is_even_K>(
403-
gmem_tiled_copy_ZeroHold, tZeroHoldgZeroHold(_, _, _, n_block), tZeroHoldsZeroHold, tZeroHoldcZeroHold, tZeroHoldpZeroHold,
404-
binfo.actual_seqlen_k - n_block * kBlockN
405-
);
406-
cute::cp_async_fence();
407-
408-
// 加载第一个CausalMask块到共享内存(如果有)
409-
if (params.causal_mask_ptr != nullptr) {
410-
FLASH_NAMESPACE::copy<Is_even_MN, Is_even_K>(
411-
gmem_tiled_copy_CausalMask, tCausalMaskgCausalMask(_, _, _, n_block), tCausalMasksCausalMask, tCausalMaskcCausalMask, tCausalMaskpCausalMask,
412-
binfo.actual_seqlen_k - n_block * kBlockN
413-
);
414-
cute::cp_async_fence();
415-
}
416-
417-
// 将Q从共享内存加载到寄存器(如果需要)
418350
if (Kernel_traits::Is_Q_in_regs && !Kernel_traits::Share_Q_K_smem) {
419351
FLASH_NAMESPACE::cp_async_wait<1>();
420352
__syncthreads();
421353
Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ);
422-
CUTE_STATIC_ASSERT_V(size<1>(tSsQ) == size<1>(tSrQ_copy_view));
354+
CUTE_STATIC_ASSERT_V(size<1>(tSsQ) == size<1>(tSrQ_copy_view)); // M
423355
cute::copy(smem_tiled_copy_Q, tSsQ, tSrQ_copy_view);
424356
}
425-
426-
// 初始化输出累加器
427-
Tensor acc_o = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kHeadDim>>{});
357+
// For ZeroHold, Is_even_K in copy refers to the kBlockN dimension alignment for vectorization,
358+
// which is generally true. The boundary is handled by the length argument.
359+
FLASH_NAMESPACE::copy<Is_even_MN, Is_even_K>(
360+
gmem_tiled_copy_ZeroHold,
361+
tZeroHoldgZeroHold,
362+
tZeroHoldsZeroHold, tZeroHoldcZeroHold, tZeroHoldpZeroHold,
363+
binfo.actual_seqlen_k - n_block * kBlockN
364+
);
365+
cute::cp_async_fence();
366+
428367
clear(acc_o);
429-
430-
// 创建softmax计算器
368+
431369
FLASH_NAMESPACE::Softmax<2 * size<1>(acc_o)> softmax;
432370

433371
// Process the blocks that require masking (usually the last few blocks)

0 commit comments

Comments
 (0)