Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
200 changes: 145 additions & 55 deletions csrc/src/flash_bwd_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,9 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
const index_t row_offset_bias = binfo.attn_bias_offset(params.attn_bias_batch_stride, params.attn_bias_row_stride, params.attn_bias_col_stride, bidb)
+ (bidh / params.h_h_k_ratio) * params.attn_bias_head_stride + (m_block_max - 1) * kBlockM * params.attn_bias_row_stride
+ n_block * kBlockN * params.attn_bias_col_stride;
const index_t row_offset_dbias = binfo.attn_bias_offset(params.dbias_batch_stride, params.dbias_row_stride, params.dbias_col_stride, bidb)
+ (bidh / params.h_h_k_ratio) * params.dbias_head_stride + (m_block_max - 1) * kBlockM * params.dbias_row_stride
+ n_block * kBlockN * params.dbias_col_stride;
const index_t row_offset_do = binfo.q_offset(params.do_batch_stride, params.do_row_stride, bidb)
+ (m_block_max - 1) * kBlockM * params.do_row_stride + bidh * params.do_head_stride;
const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)
Expand All @@ -126,63 +129,150 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
// Regarding 128 * params.b see a comment in mha_varlen_bwd about padding of dq_accum and softmax_d
const index_t row_offset_dpsum = (params.unpadded_lse? bidh * (params.total_q + 128 * params.b) + binfo.q_offset(params.seqlen_q_rounded, 1, bidb) + 128 * bidb: (bidb * params.h + bidh) * params.seqlen_q_rounded) + (m_block_max - 1) * kBlockM;

Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.q_ptr) + row_offset_q),
Shape<Int<kBlockM>, Int<kHeadDim>>{},
make_stride(params.q_row_stride, _1{}));
Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.k_ptr) + row_offset_k),
Shape<Int<kBlockN>, Int<kHeadDim>>{},
make_stride(params.k_row_stride, _1{}));
Tensor gV = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.v_ptr) + row_offset_v),
Shape<Int<kBlockN>, Int<kHeadDim>>{},
make_stride(params.v_row_stride, _1{}));
Tensor gMask = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.attn_mask_ptr) + row_offset_mask),
Shape<Int<kBlockM>, Int<kBlockN>>{},
make_stride(params.attn_mask_row_stride, params.attn_mask_col_stride));
Tensor gBias = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.attn_bias_ptr) + row_offset_bias),
Shape<Int<kBlockM>, Int<kBlockN>>{},
make_stride(params.attn_bias_row_stride, params.attn_bias_col_stride));
Tensor gdBias = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.dbias_ptr) + row_offset_bias),
Shape<Int<kBlockM>, Int<kBlockN>>{},
make_stride(params.dbias_row_stride, params.dbias_col_stride));
Tensor gdO = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.do_ptr) + row_offset_do),
Shape<Int<kBlockM>, Int<kHeadDim>>{},
make_stride(params.do_row_stride, _1{}));
Tensor gO = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.o_ptr) + row_offset_o),
Shape<Int<kBlockM>, Int<kHeadDim>>{},
make_stride(params.o_row_stride, _1{}));
Tensor gdQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.dq_ptr) + row_offset_dq),
Shape<Int<kBlockM>, Int<kHeadDim>>{},
make_stride(params.dq_row_stride, _1{}));
Tensor gdQaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.dq_accum_ptr) + row_offset_dq_accum),
Shape<Int<kBlockM>, Int<kHeadDim>>{},
make_stride(params.h * params.d_rounded, _1{}));
Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lse_ptr) + row_offset_lse),
Shape<Int<kBlockM>>{}, Stride<_1>{});
Tensor gdPsum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.dsoftmax_sum) + row_offset_dpsum),
Shape<Int<kBlockM>>{}, Stride<_1>{});

Tensor sQ = make_tensor(make_smem_ptr(reinterpret_cast<Element *>(smem_)),
typename Kernel_traits::SmemLayoutQdO{});
Tensor sQt = make_tensor(sQ.data(), typename Kernel_traits::SmemLayoutQdOtransposed{});
Tensor sQtNoSwizzle = make_tensor(sQ.data(), typename Kernel_traits::SmemLayoutQdOtransposedNoSwizzle{});
// Global memory tensor configuration
Tensor gQ = make_tensor(
make_gmem_ptr(reinterpret_cast<Element *>(params.q_ptr) + row_offset_q),
Shape<Int<kBlockM>, Int<kHeadDim>>{},
make_stride(params.q_row_stride, _1{})
);
Tensor gK = make_tensor(
make_gmem_ptr(reinterpret_cast<Element *>(params.k_ptr) + row_offset_k),
Shape<Int<kBlockN>, Int<kHeadDim>>{},
make_stride(params.k_row_stride, _1{})
);
Tensor gV = make_tensor(
make_gmem_ptr(reinterpret_cast<Element *>(params.v_ptr) + row_offset_v),
Shape<Int<kBlockN>, Int<kHeadDim>>{},
make_stride(params.v_row_stride, _1{})
);
Tensor gMask = make_tensor(
make_gmem_ptr(reinterpret_cast<Element *>(params.attn_mask_ptr) + row_offset_mask),
Shape<Int<kBlockM>, Int<kBlockN>>{},
make_stride(params.attn_mask_row_stride, params.attn_mask_col_stride)
);
Tensor gBias = make_tensor(
make_gmem_ptr(reinterpret_cast<Element *>(params.attn_bias_ptr) + row_offset_bias),
Shape<Int<kBlockM>, Int<kBlockN>>{},
make_stride(params.attn_bias_row_stride, params.attn_bias_col_stride)
);
Tensor gdO = make_tensor(
make_gmem_ptr(reinterpret_cast<Element *>(params.do_ptr) + row_offset_do),
Shape<Int<kBlockM>, Int<kHeadDim>>{},
make_stride(params.do_row_stride, _1{})
);
Tensor gO = make_tensor(
make_gmem_ptr(reinterpret_cast<Element *>(params.o_ptr) + row_offset_o),
Shape<Int<kBlockM>, Int<kHeadDim>>{},
make_stride(params.o_row_stride, _1{})
);
Tensor gdQ = make_tensor(
make_gmem_ptr(reinterpret_cast<Element *>(params.dq_ptr) + row_offset_dq),
Shape<Int<kBlockM>, Int<kHeadDim>>{},
make_stride(params.dq_row_stride, _1{})
);
Tensor gdQaccum = make_tensor(
make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.dq_accum_ptr) + row_offset_dq_accum),
Shape<Int<kBlockM>, Int<kHeadDim>>{},
make_stride(params.h * params.d_rounded, _1{})
);
Tensor gdBias = make_tensor(
make_gmem_ptr(reinterpret_cast<Element *>(params.dbias_ptr) + row_offset_dbias),
Shape<Int<kBlockM>, Int<kBlockN>>{},
make_stride(params.dbias_row_stride, params.dbias_col_stride)
);
Tensor gLSE = make_tensor(
make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lse_ptr) + row_offset_lse),
Shape<Int<kBlockM>>{}, Stride<_1>{}
);
Tensor gdPsum = make_tensor(
make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.dsoftmax_sum) + row_offset_dpsum),
Shape<Int<kBlockM>>{}, Stride<_1>{}
);

// Shared memory layout configuration
Tensor sQ = make_tensor(
make_smem_ptr(reinterpret_cast<Element *>(smem_)),
typename Kernel_traits::SmemLayoutQdO{}
);
Tensor sQt = make_tensor(
sQ.data(),
typename Kernel_traits::SmemLayoutQdOtransposed{}
);
Tensor sQtNoSwizzle = make_tensor(
sQ.data(),
typename Kernel_traits::SmemLayoutQdOtransposedNoSwizzle{}
);
// Double buffer for sQ
Tensor sdO = make_tensor(sQ.data() + (Double_buffer ? 2 : 1) * size(sQ), typename Kernel_traits::SmemLayoutQdO{});
Tensor sdOt = make_tensor(sdO.data(), typename Kernel_traits::SmemLayoutQdOtransposed{});
Tensor sdOtransposedNoSwizzle = make_tensor(sdO.data(),
typename Kernel_traits::SmemLayoutQdOtransposedNoSwizzle{});
Tensor sK = make_tensor(sdO.data() + size(sdO), typename Kernel_traits::SmemLayoutKV{});
Tensor sV = make_tensor(sK.data() + size(sK), typename Kernel_traits::SmemLayoutKV{});
Tensor sKt = make_tensor(sK.data(), typename Kernel_traits::SmemLayoutKtransposed{});
Tensor sKtNoSwizzle = make_tensor(sK.data(), typename Kernel_traits::SmemLayoutKtransposedNoSwizzle{});
Tensor sdS = make_tensor(!Kernel_traits::Is_V_in_regs ? sV.data() + size(sV) : sK.data() + size(sK),
typename Kernel_traits::SmemLayoutPdS{});
Tensor sdSt = make_tensor(sdS.data(), typename Kernel_traits::SmemLayoutPdStransposed{});
Tensor sdStNoSwizzle = make_tensor(sdS.data(), typename Kernel_traits::SmemLayoutPdStransposedNoSwizzle{});
Tensor sP = make_tensor(sdS.data() + size(sdS), typename Kernel_traits::SmemLayoutPdS{});
Tensor sPt = make_tensor(sP.data(), typename Kernel_traits::SmemLayoutPdStransposed{});
Tensor sPtNoSwizzle = make_tensor(sP.data(), typename Kernel_traits::SmemLayoutPdStransposedNoSwizzle{});
Tensor sdO = make_tensor(
sQ.data() + (Double_buffer ? 2 : 1) * size(sQ),
typename Kernel_traits::SmemLayoutQdO{}
);
Tensor sdOt = make_tensor(
sdO.data(),
typename Kernel_traits::SmemLayoutQdOtransposed{}
);
Tensor sdOtransposedNoSwizzle = make_tensor(
sdO.data(),
typename Kernel_traits::SmemLayoutQdOtransposedNoSwizzle{}
);
Tensor sK = make_tensor(
sdO.data() + size(sdO),
typename Kernel_traits::SmemLayoutKV{}
);
Tensor sKt = make_tensor(
sK.data(),
typename Kernel_traits::SmemLayoutKtransposed{}
);
Tensor sKtNoSwizzle = make_tensor(
sK.data(),
typename Kernel_traits::SmemLayoutKtransposedNoSwizzle{}
);
Tensor sV = make_tensor(
sK.data() + size(sK),
typename Kernel_traits::SmemLayoutKV{}
);
Tensor sMask = make_tensor(
!Kernel_traits::Is_V_in_regs ? sV.data() + size(sV) : sK.data() + size(sK),
typename Kernel_traits::SmemLayoutMask{}
);
Tensor sBias = make_tensor(
sMask.data() + size(sMask),
typename Kernel_traits::SmemLayoutBias{}
);
Tensor sdBias = make_tensor(
sBias.data(),
typename Kernel_traits::SmemLayoutBias{}
);
Tensor sdS = make_tensor(
sBias.data() + size(sBias),
typename Kernel_traits::SmemLayoutPdS{}
);
Tensor sdSt = make_tensor(
sdS.data(),
typename Kernel_traits::SmemLayoutPdStransposed{}
);
Tensor sdStNoSwizzle = make_tensor(
sdS.data(),
typename Kernel_traits::SmemLayoutPdStransposedNoSwizzle{}
);
Tensor sP = make_tensor(
sdS.data() + size(sdS),
typename Kernel_traits::SmemLayoutPdS{}
);
Tensor sPt = make_tensor(
sP.data(),
typename Kernel_traits::SmemLayoutPdStransposed{}
);
Tensor sPtNoSwizzle = make_tensor(
sP.data(),
typename Kernel_traits::SmemLayoutPdStransposedNoSwizzle{}
);
// sP and sdQ share the same memory so be careful
Tensor sdQ = make_tensor(sP.data(), typename Kernel_traits::SmemLayoutdQ{});
Tensor sdQ = make_tensor(
sP.data(),
typename Kernel_traits::SmemLayoutdQ{}
);


typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV;
auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx);
Expand Down
2 changes: 1 addition & 1 deletion csrc/src/flash_fwd_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
const index_t row_offset_p = ((bidb * params.h + bidh) * params.seqlen_q_rounded
+ m_block * kBlockM) * params.seqlen_k_rounded + (n_block_max - 1) * kBlockN;

// Golobal memory tensor configuration
// Global memory tensor configuration
Tensor mQ = make_tensor(
make_gmem_ptr(reinterpret_cast<Element*>(params.q_ptr) + binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb)),
make_shape(binfo.actual_seqlen_q, params.h, params.d),
Expand Down