diff --git a/csrc/src/flash.h b/csrc/src/flash.h index 0923185..277d5a3 100644 --- a/csrc/src/flash.h +++ b/csrc/src/flash.h @@ -165,13 +165,13 @@ struct Flash_bwd_params : public Flash_fwd_params { void *__restrict__ dq_ptr; void *__restrict__ dk_ptr; void *__restrict__ dv_ptr; - void *__restrict__ dzero_hold_ptr; + void *__restrict__ dzoh_ptr; // To accumulate dQ void *__restrict__ dq_accum_ptr; void *__restrict__ dk_accum_ptr; void *__restrict__ dv_accum_ptr; - void *__restrict__ dzero_hold_accum_ptr; + void *__restrict__ dzoh_accum_ptr; // // To accumulate dK and dV in case we're splitting the bwd along seqlen_q // dimension void *__restrict__ dk_accum_ptr; void *__restrict__ @@ -192,9 +192,10 @@ struct Flash_bwd_params : public Flash_fwd_params { index_t dq_head_stride; index_t dk_head_stride; index_t dv_head_stride; - index_t dzero_hold_batch_stride; - index_t dzero_hold_head_stride; - index_t dzero_hold_row_stride; + index_t dzoh_batch_stride; + index_t dzoh_head_stride; + index_t dzoh_row_stride; + index_t dzoh_col_stride; // The pointer to the softmax d sum. void *__restrict__ dsoftmax_sum;