@@ -169,40 +169,15 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
     );
 
     Tensor mZeroHold = make_tensor(
-        make_gmem_ptr(reinterpret_cast<Element*>(params.zero_hold_ptr) + bidb * params.zero_hold_batch_stride),
+        make_gmem_ptr(reinterpret_cast<Element*>(params.zero_hold_ptr) + binfo.q_offset(params.zero_hold_batch_stride, params.zero_hold_row_stride, bidb)),
         make_shape(params.h_k, binfo.actual_seqlen_q, binfo.actual_seqlen_k),
-        make_stride(params.zero_hold_head_stride, params.zero_hold_query_stride, _1{})
+        make_stride(params.zero_hold_head_stride, params.zero_hold_row_stride, _1{})
     );
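+    // Note on the change above: binfo.q_offset(batch_stride, row_stride, bidb) is the same
+    // BlockInfo helper used for the Q pointer. In the upstream FlashAttention convention it
+    // reduces to bidb * batch_stride for fixed-length batches and switches to a cu_seqlens-based
+    // row offset for varlen inputs (stated from the upstream convention, not verified in this diff).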
     Tensor gZeroHold = local_tile(
         mZeroHold(bidh / params.h_h_k_ratio, _, _),
         Shape<Int<kBlockM>, Int<kBlockN>>{},
-        make_coord(m_block, 0)
-    );
-
-    auto mCausalMask = has_causal_mask ?
-        make_tensor(
-            make_gmem_ptr(reinterpret_cast<Element*>(params.causal_mask_ptr) + bidb * params.causal_mask_batch_stride),
-            make_shape(1, binfo.actual_seqlen_q, binfo.actual_seqlen_k),
-            make_stride(params.causal_mask_head_stride, params.causal_mask_query_len_stride, _1{})
-        ) :
-        make_tensor(
-            make_gmem_ptr(static_cast<Element*>(nullptr)),
-            make_shape(1, 1, 1),
-            make_stride(static_cast<flash::index_t>(0), static_cast<flash::index_t>(0), _1{})
-        );
-
-    auto gCausalMask = has_causal_mask ?
-        local_tile(
-            mCausalMask(0, _, _),
-            Shape<Int<kBlockM>, Int<kBlockN>>{},
-            make_coord(m_block, 0)
-        ) :
-        make_tensor(
-            make_gmem_ptr(static_cast<Element*>(nullptr)),
-            make_layout(
-                Shape<Int<kBlockM>, Int<kBlockN>>{},
-                make_stride(static_cast<flash::index_t>(0), _1{}))
-        );
+        make_coord(m_block, n_block_max - 1)
+    );  // (kBlockM, kBlockN)
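+    // The N tile coordinate is pinned at n_block_max - 1 to match the reverse iteration over
+    // N blocks below: the prologue copies this single tile (note that tZeroHoldgZeroHold carries
+    // no n_block index), and earlier tiles are presumably reached in the main loop beyond this hunk.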
 
     // Shared memory layout configuration
     Tensor sQ = make_tensor(
@@ -230,22 +205,11 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
     // Dynamic mask related shared memory. Use a running char* pointer for robust allocation.
     char *dynamic_smem_current_ptr = reinterpret_cast<char *>(sV.data().get()) + size(sV) * sizeof(Element);
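     // Bump-allocation sketch of the dynamic smem region, with sizes taken from Kernel_traits:
     //   base -> [ sZeroHold : kSmemZeroHoldSize bytes ][ sDynamicMaskValues : float buffer ]
     // Each make_tensor consumes the current pointer, which then advances past that buffer.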
     Tensor sZeroHold = make_tensor(
-        make_smem_ptr(reinterpret_cast<Element*>(dynamic_smem_current_ptr)),
+        make_smem_ptr(reinterpret_cast<Element*>(dynamic_smem_current_ptr)),  // Element type
         typename Kernel_traits::SmemLayoutZeroHold{}
     );
 
     dynamic_smem_current_ptr += Kernel_traits::kSmemZeroHoldSize;
-    auto causal_mask_smem_ptr = has_causal_mask ?
-        make_smem_ptr(reinterpret_cast<Element*>(dynamic_smem_current_ptr)) :
-        make_smem_ptr(static_cast<Element*>(nullptr));
-    Tensor sCausalMask = make_tensor(
-        causal_mask_smem_ptr,
-        typename Kernel_traits::SmemLayoutCausalMask{}
-    );
-
-    if (has_causal_mask) {
-        dynamic_smem_current_ptr += Kernel_traits::kSmemCausalMaskSize;
-    }
     Tensor sDynamicMaskValues = make_tensor(
         make_smem_ptr(reinterpret_cast<float *>(dynamic_smem_current_ptr)),  // float type
         typename Kernel_traits::SmemLayoutDynamicMaskValues{}
@@ -280,8 +244,6 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
     auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx);
     typename Kernel_traits::GmemTiledCopyZeroHold gmem_tiled_copy_ZeroHold;
     auto gmem_thr_copy_ZeroHold = gmem_tiled_copy_ZeroHold.get_thread_slice(tidx);
-    typename Kernel_traits::GmemTiledCopyZeroHold gmem_tiled_copy_CausalMask;
-    auto gmem_thr_copy_CausalMask = gmem_tiled_copy_CausalMask.get_thread_slice(tidx);
 
     Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ);
     Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ);
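     // partition_S/partition_D slice the source (gmem) and destination (smem) tensors into this
     // thread's copy fragments as prescribed by the tiled-copy layout, so each thread moves its
     // own vectors.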
@@ -291,12 +253,6 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
     Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV);
     Tensor tZeroHoldgZeroHold = gmem_thr_copy_ZeroHold.partition_S(gZeroHold);
     Tensor tZeroHoldsZeroHold = gmem_thr_copy_ZeroHold.partition_D(sZeroHold);
-    decltype(gmem_thr_copy_CausalMask.partition_S(gCausalMask)) tCausalMaskgCausalMask;
-    decltype(gmem_thr_copy_CausalMask.partition_D(sCausalMask)) tCausalMasksCausalMask;
-    if (has_causal_mask) {
-        tCausalMaskgCausalMask = gmem_thr_copy_CausalMask.partition_S(gCausalMask);
-        tCausalMasksCausalMask = gmem_thr_copy_CausalMask.partition_D(sCausalMask);
-    }
 
     // Matrix Multiply Accumulate
     typename Kernel_traits::TiledMma tiled_mma;
@@ -336,23 +292,15 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
     Tensor cKV = make_identity_tensor(make_shape(size<0>(sK), size<1>(sK)));  // (BLK_N,BLK_K) -> (blk_n,blk_k)
     // Identity tensor for gZeroHold -> sZeroHold copy
    Tensor cZeroHold = make_identity_tensor(make_shape(size<0>(sZeroHold), size<1>(sZeroHold)));
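     // An identity tensor maps each coordinate to itself; partitioning it exactly like the data
     // tensor hands every thread the (row, col) coordinates of its copy fragment, which the
     // predicate loops below compare against the real bounds (e.g. params.d).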
-    // Identity tensor for gCausalMask -> sCausalMask copy, use dummy 1×1 when no mask
-    Tensor cCausalMask = make_identity_tensor(make_shape(
-        has_causal_mask ? size<0>(sCausalMask) : Int<1>{},
-        has_causal_mask ? size<1>(sCausalMask) : Int<1>{}
-    ));
     // Repeat the partitioning with identity layouts
     Tensor tQcQ = gmem_thr_copy_QKV.partition_S(cQ);    // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)
     Tensor tKVcKV = gmem_thr_copy_QKV.partition_S(cKV); // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k)
     // Predicate for ZeroHold GMEM copy
     Tensor tZeroHoldcZeroHold = gmem_thr_copy_ZeroHold.partition_S(cZeroHold);
-    // Predicate for CausalMask GMEM copy
-    Tensor tCausalMaskcCausalMask = gmem_thr_copy_CausalMask.partition_S(cCausalMask);
     // Allocate predicate tensors for k
     Tensor tQpQ = make_tensor<bool>(make_shape(size<2>(tQsQ)));
     Tensor tKVpKV = make_tensor<bool>(make_shape(size<2>(tKsK)));
     Tensor tZeroHoldpZeroHold = make_tensor<bool>(make_shape(size<2>(tZeroHoldsZeroHold)));  // N-dim predicate for ZeroHold
-    Tensor tCausalMaskpCausalMask = make_tensor<bool>(make_shape(size<2>(tCausalMasksCausalMask)));  // N-dim predicate for CausalMask (always allocate; only used when has_causal_mask)
     // Set predicates for k bounds
     if (!Is_even_K) {
         #pragma unroll
@@ -363,71 +311,61 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
         for (int k = 0; k < size(tKVpKV); ++k) {
             tKVpKV(k) = get<1>(tKVcKV(0, 0, k)) < params.d;
         }
+        #pragma unroll
+        for (int k = 0; k < size(tZeroHoldpZeroHold); ++k) {
+            tZeroHoldpZeroHold(k) = true;  // All elements are valid for the moment
+        }
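+        // No head-dim bound applies to ZeroHold: its trailing dimension is seqlen_k, and
+        // out-of-range key columns are clipped by the length argument passed to
+        // FLASH_NAMESPACE::copy in the prologue.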
     }
 
-    // Initialize the dynamic mask processor
+    // Prologue
+    // Init dynamic mask processor
     DynamicMask<Is_causal> dynamic_mask(params.keep_window_size);
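     // keep_window_size presumably bounds how many key positions per query row the dynamic mask
     // retains (inferred from the parameter name; DynamicMask is defined outside this diff).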
-
-    // Load Q into shared memory
+    // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs
     FLASH_NAMESPACE::copy<Is_even_MN, Is_even_K>(
         gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ,
         binfo.actual_seqlen_q - m_block * kBlockM
     );
-
     if (Kernel_traits::Is_Q_in_regs) {
         cute::cp_async_fence();
     }
-
-    // If Q and K share smem, we need to wait and sync
+    // If share Q and K smem, wait and sync
     if (Kernel_traits::Share_Q_K_smem) {
         FLASH_NAMESPACE::cp_async_wait<0>();
         __syncthreads();
         Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ);
-        CUTE_STATIC_ASSERT_V(size<1>(tSsQ) == size<1>(tSrQ_copy_view));
+        CUTE_STATIC_ASSERT_V(size<1>(tSsQ) == size<1>(tSrQ_copy_view));  // M
         cute::copy(smem_tiled_copy_Q, tSsQ, tSrQ_copy_view);
         __syncthreads();
     }
-
-    // Iterate over N blocks in reverse
+    // Reverse iteration over N blocks
     int n_block = n_block_max - 1;
-
-    // Load the first K block into shared memory
+    // We don't need to clear the sK smem tiles since we'll mask out the scores anyway.
     FLASH_NAMESPACE::copy<Is_even_MN, Is_even_K>(
-        gmem_tiled_copy_QKV, tKgK(_, _, _, n_block), tKsK, tKVcKV, tKVpKV,
+        gmem_tiled_copy_QKV,
+        tKgK(_, _, _, n_block),
+        tKsK, tKVcKV, tKVpKV,
         binfo.actual_seqlen_k - n_block * kBlockN
     );
     cute::cp_async_fence();
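     // cp_async_fence() commits the cp.async copies issued so far as one group; a later
     // cp_async_wait<N>() blocks until at most N committed groups remain in flight
     // (cp_async_wait<0>() waits for all of them), letting the K and ZeroHold loads overlap compute.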
-
-    // Load the first ZeroHold block into shared memory
-    FLASH_NAMESPACE::copy<Is_even_MN, Is_even_K>(
-        gmem_tiled_copy_ZeroHold, tZeroHoldgZeroHold(_, _, _, n_block), tZeroHoldsZeroHold, tZeroHoldcZeroHold, tZeroHoldpZeroHold,
-        binfo.actual_seqlen_k - n_block * kBlockN
-    );
-    cute::cp_async_fence();
-
-    // Load the first CausalMask block into shared memory (if present)
-    if (params.causal_mask_ptr != nullptr) {
-        FLASH_NAMESPACE::copy<Is_even_MN, Is_even_K>(
-            gmem_tiled_copy_CausalMask, tCausalMaskgCausalMask(_, _, _, n_block), tCausalMasksCausalMask, tCausalMaskcCausalMask, tCausalMaskpCausalMask,
-            binfo.actual_seqlen_k - n_block * kBlockN
-        );
-        cute::cp_async_fence();
-    }
-
-    // Load Q from shared memory into registers (if needed)
     if (Kernel_traits::Is_Q_in_regs && !Kernel_traits::Share_Q_K_smem) {
         FLASH_NAMESPACE::cp_async_wait<1>();
         __syncthreads();
         Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ);
-        CUTE_STATIC_ASSERT_V(size<1>(tSsQ) == size<1>(tSrQ_copy_view));
+        CUTE_STATIC_ASSERT_V(size<1>(tSsQ) == size<1>(tSrQ_copy_view));  // M
         cute::copy(smem_tiled_copy_Q, tSsQ, tSrQ_copy_view);
     }
-
-    // Initialize the output accumulator
-    Tensor acc_o = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kHeadDim>>{});
+    // For ZeroHold, Is_even_K in copy refers to the kBlockN dimension alignment for vectorization,
+    // which is generally true. The boundary is handled by the length argument.
+    FLASH_NAMESPACE::copy<Is_even_MN, Is_even_K>(
+        gmem_tiled_copy_ZeroHold,
+        tZeroHoldgZeroHold,
+        tZeroHoldsZeroHold, tZeroHoldcZeroHold, tZeroHoldpZeroHold,
+        binfo.actual_seqlen_k - n_block * kBlockN
+    );
+    cute::cp_async_fence();
+
     clear(acc_o);
-
-    // Create the softmax calculator
+
     FLASH_NAMESPACE::Softmax<2 * size<1>(acc_o)> softmax;
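     // size<1>(acc_o) is the number of MMA M-tiles this thread owns; with the SM80 MMA accumulator
     // layout each tile contributes 2 rows per thread, hence 2 * size<1>(acc_o) running max/sum
     // statistics (assuming the upstream FlashAttention accumulator layout).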
 
     // Process the blocks that need masking (usually the last few blocks)