From 97b6cdb2da1425795e1d26111a7a0a45eb8cf74e Mon Sep 17 00:00:00 2001 From: Loser Cheems Date: Mon, 30 Jun 2025 18:34:04 +0800 Subject: [PATCH] Reorders stride parameter assignments Reorganizes the order of stride parameter assignments to group related parameters together for better code readability and maintainability. Adds missing column stride assignments for zoh and active_mask tensors. --- csrc/flash_api.cpp | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/csrc/flash_api.cpp b/csrc/flash_api.cpp index bfeb420..2e04840 100644 --- a/csrc/flash_api.cpp +++ b/csrc/flash_api.cpp @@ -75,13 +75,15 @@ void set_params_fprop( params.v_row_stride = v.stride(-3); params.zoh_row_stride = zoh.stride(-2); params.active_mask_row_stride = active_mask.stride(-2); + params.o_row_stride = out.stride(-3); params.q_head_stride = q.stride(-2); params.k_head_stride = k.stride(-2); params.v_head_stride = v.stride(-2); params.zoh_head_stride = zoh.stride(-3); params.active_mask_head_stride = active_mask.stride(-3); - params.o_row_stride = out.stride(-3); params.o_head_stride = out.stride(-2); + params.zoh_col_stride = zoh.stride(-1); + params.active_mask_col_stride = active_mask.stride(-1); if (cu_seqlens_q_d == nullptr) { params.q_batch_stride = q.stride(0);