From 74918cfaf9a21d069d0d8adaa24b87a8aa233918 Mon Sep 17 00:00:00 2001
From: Loser Cheems
Date: Tue, 1 Jul 2025 21:32:16 +0800
Subject: [PATCH] Renames dzero_hold to dzoh and adds column stride

Shortens variable names for better readability and consistency.
Adds missing column stride parameter to support proper memory layout handling.
---
 csrc/src/flash.h | 11 ++++++-----
 1 file changed, 6 insertions(+), 5 deletions(-)

diff --git a/csrc/src/flash.h b/csrc/src/flash.h
index 0923185..277d5a3 100644
--- a/csrc/src/flash.h
+++ b/csrc/src/flash.h
@@ -165,13 +165,13 @@ struct Flash_bwd_params : public Flash_fwd_params {
     void *__restrict__ dq_ptr;
     void *__restrict__ dk_ptr;
     void *__restrict__ dv_ptr;
-    void *__restrict__ dzero_hold_ptr;
+    void *__restrict__ dzoh_ptr;
 
     // To accumulate dQ
     void *__restrict__ dq_accum_ptr;
     void *__restrict__ dk_accum_ptr;
     void *__restrict__ dv_accum_ptr;
-    void *__restrict__ dzero_hold_accum_ptr;
+    void *__restrict__ dzoh_accum_ptr;
 
     // // To accumulate dK and dV in case we're splitting the bwd along seqlen_q
     // dimension void *__restrict__ dk_accum_ptr; void *__restrict__
@@ -192,9 +192,10 @@ struct Flash_bwd_params : public Flash_fwd_params {
     index_t dq_head_stride;
     index_t dk_head_stride;
     index_t dv_head_stride;
-    index_t dzero_hold_batch_stride;
-    index_t dzero_hold_head_stride;
-    index_t dzero_hold_row_stride;
+    index_t dzoh_batch_stride;
+    index_t dzoh_head_stride;
+    index_t dzoh_row_stride;
+    index_t dzoh_col_stride;
 
     // The pointer to the softmax d sum.
     void *__restrict__ dsoftmax_sum;