From 33a38fe5b1cda5ed3b42d2725bcd096882c29f86 Mon Sep 17 00:00:00 2001 From: Loser Cheems Date: Mon, 30 Jun 2025 18:54:47 +0800 Subject: [PATCH] Adds column stride support to tensor memory layouts Updates tensor creation for ZOH and ActiveMask to include column stride parameters in offset calculations and stride configurations. Enables more flexible memory layout patterns by allowing non-unit column strides in addition to existing batch and row strides. --- csrc/src/flash_fwd_kernel.h | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/csrc/src/flash_fwd_kernel.h b/csrc/src/flash_fwd_kernel.h index 723b41f..0e1a15e 100644 --- a/csrc/src/flash_fwd_kernel.h +++ b/csrc/src/flash_fwd_kernel.h @@ -180,9 +180,9 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi make_stride(params.seqlen_k_rounded, _1{}) ); Tensor mZOH = make_tensor( - make_gmem_ptr(reinterpret_cast(params.zoh_ptr) + binfo.zoh_offset(params.zoh_batch_stride, params.zoh_row_stride, bidb)), + make_gmem_ptr(reinterpret_cast(params.zoh_ptr) + binfo.zoh_offset(params.zoh_batch_stride, params.zoh_row_stride, params.zoh_col_stride, bidb)), make_shape(params.h_k, binfo.actual_seqlen_q, binfo.actual_seqlen_k), - make_stride(params.zoh_head_stride, params.zoh_row_stride, _1{}) + make_stride(params.zoh_head_stride, params.zoh_row_stride, params.zoh_col_stride) ); Tensor gZOH = local_tile( mZOH(bidh / params.h_h_k_ratio, _, _), @@ -190,9 +190,9 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi make_coord(m_block, _) ); // (kBlockM, kBlockN, nblocksN) Tensor mActiveMask = make_tensor( - make_gmem_ptr(reinterpret_cast(params.active_mask_ptr) + binfo.active_mask_offset(params.active_mask_batch_stride, params.active_mask_row_stride, bidb)), + make_gmem_ptr(reinterpret_cast(params.active_mask_ptr) + binfo.active_mask_offset(params.active_mask_batch_stride, params.active_mask_row_stride, params.active_mask_col_stride, bidb)), make_shape(params.h_k, binfo.actual_seqlen_q, binfo.actual_seqlen_k), - make_stride(params.active_mask_head_stride, params.active_mask_row_stride, _1{}) + make_stride(params.active_mask_head_stride, params.active_mask_row_stride, params.active_mask_col_stride) ); Tensor gActiveMask = local_tile( mActiveMask(bidh / params.h_h_k_ratio, _, _),