diff --git a/csrc/src/flash_fwd_kernel.h b/csrc/src/flash_fwd_kernel.h index 723b41f..0e1a15e 100644 --- a/csrc/src/flash_fwd_kernel.h +++ b/csrc/src/flash_fwd_kernel.h @@ -180,9 +180,9 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi make_stride(params.seqlen_k_rounded, _1{}) ); Tensor mZOH = make_tensor( - make_gmem_ptr(reinterpret_cast(params.zoh_ptr) + binfo.zoh_offset(params.zoh_batch_stride, params.zoh_row_stride, bidb)), + make_gmem_ptr(reinterpret_cast(params.zoh_ptr) + binfo.zoh_offset(params.zoh_batch_stride, params.zoh_row_stride, params.zoh_col_stride, bidb)), make_shape(params.h_k, binfo.actual_seqlen_q, binfo.actual_seqlen_k), - make_stride(params.zoh_head_stride, params.zoh_row_stride, _1{}) + make_stride(params.zoh_head_stride, params.zoh_row_stride, params.zoh_col_stride) ); Tensor gZOH = local_tile( mZOH(bidh / params.h_h_k_ratio, _, _), @@ -190,9 +190,9 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi make_coord(m_block, _) ); // (kBlockM, kBlockN, nblocksN) Tensor mActiveMask = make_tensor( - make_gmem_ptr(reinterpret_cast(params.active_mask_ptr) + binfo.active_mask_offset(params.active_mask_batch_stride, params.active_mask_row_stride, bidb)), + make_gmem_ptr(reinterpret_cast(params.active_mask_ptr) + binfo.active_mask_offset(params.active_mask_batch_stride, params.active_mask_row_stride, params.active_mask_col_stride, bidb)), make_shape(params.h_k, binfo.actual_seqlen_q, binfo.actual_seqlen_k), - make_stride(params.active_mask_head_stride, params.active_mask_row_stride, _1{}) + make_stride(params.active_mask_head_stride, params.active_mask_row_stride, params.active_mask_col_stride) ); Tensor gActiveMask = local_tile( mActiveMask(bidh / params.h_h_k_ratio, _, _),