Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions csrc/src/flash.h
Original file line number Diff line number Diff line change
Expand Up @@ -165,13 +165,13 @@ struct Flash_bwd_params : public Flash_fwd_params {
void *__restrict__ dq_ptr;
void *__restrict__ dk_ptr;
void *__restrict__ dv_ptr;
void *__restrict__ dzero_hold_ptr;
void *__restrict__ dzoh_ptr;

// To accumulate dQ
void *__restrict__ dq_accum_ptr;
void *__restrict__ dk_accum_ptr;
void *__restrict__ dv_accum_ptr;
void *__restrict__ dzero_hold_accum_ptr;
void *__restrict__ dzoh_accum_ptr;

// // To accumulate dK and dV in case we're splitting the bwd along seqlen_q
// dimension void *__restrict__ dk_accum_ptr; void *__restrict__
Expand All @@ -192,9 +192,10 @@ struct Flash_bwd_params : public Flash_fwd_params {
index_t dq_head_stride;
index_t dk_head_stride;
index_t dv_head_stride;
index_t dzero_hold_batch_stride;
index_t dzero_hold_head_stride;
index_t dzero_hold_row_stride;
index_t dzoh_batch_stride;
index_t dzoh_head_stride;
index_t dzoh_row_stride;
index_t dzoh_col_stride;

// The pointer to the softmax d sum.
void *__restrict__ dsoftmax_sum;
Expand Down