[FEATURE SUPPORT] Robust dBias accumulation for seqlen_q_bias == 1 #194
Changes from all commits
```diff
@@ -101,6 +101,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
     if (n_block * kBlockN >= binfo.actual_seqlen_k) return;

     int m_block_max = cute::ceil_div(binfo.actual_seqlen_q, kBlockM);
+    bool accum_dbias = Has_bias && (params.dbias_row_stride == 0) && (binfo.actual_seqlen_q > 1);

     const index_t row_offset_q = binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb)
         + (m_block_max - 1) * kBlockM * params.q_row_stride + bidh * params.q_head_stride;
```
```diff
@@ -159,10 +160,6 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
         Shape<Int<kBlockM>, Int<kBlockN>>{},
         make_stride(params.dbias_row_stride, _1{})
     );
-    [[maybe_unused]] ElementAccum *gdBias_accum_ptr = nullptr;
-    if constexpr (Has_bias) {
-        gdBias_accum_ptr = reinterpret_cast<ElementAccum *>(params.dbias_ptr) + row_offset_dbias;
-    }
     Tensor gdO = make_tensor(
         make_gmem_ptr(reinterpret_cast<Element *>(params.do_ptr) + row_offset_do),
         Shape<Int<kBlockM>, Int<kHeadDim>>{},
```
```diff
@@ -287,8 +284,6 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
     GmemTiledCopydO gmem_tiled_copy_dO;
     auto gmem_thr_copy_dO = gmem_tiled_copy_dO.get_thread_slice(tidx);
     typename Kernel_traits::GmemTiledCopydQ gmem_tiled_copy_dQ;
-    typename Kernel_traits::GmemTiledCopydBias gmem_tiled_copy_dBias;
-    auto gmem_thr_copy_dBias = gmem_tiled_copy_dBias.get_thread_slice(tidx);
     auto gmem_thr_copy_dQ = gmem_tiled_copy_dQ.get_thread_slice(tidx);
     using GmemLayoutAtomdQaccum = std::conditional_t<
         !Seq_parallel,
```
```diff
@@ -297,6 +292,8 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
     >;
     GmemLayoutAtomdQaccum gmem_tiled_copy_dQaccum;
     auto gmem_thr_copy_dQaccum = gmem_tiled_copy_dQaccum.get_thread_slice(tidx);
+    typename Kernel_traits::GmemTiledCopydBias gmem_tiled_copy_dBias;
+    auto gmem_thr_copy_dBias = gmem_tiled_copy_dBias.get_thread_slice(tidx);

     Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ);
     Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ);
```
```diff
@@ -346,6 +343,8 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in

     Tensor acc_dk = partition_fragment_C(tiled_mma_dkv, Shape<Int<kBlockN>, Int<kHeadDim>>{});  // (MMA, MMA_N, MMA_K)
     Tensor acc_dv = partition_fragment_C(tiled_mma_dkv, Shape<Int<kBlockN>, Int<kHeadDim>>{});  // (MMA, MMA_N, MMA_K)
+    [[maybe_unused]] auto acc_dbias = partition_fragment_C(tiled_mma_sdp, Shape<Int<kBlockM>, Int<kBlockN>>{});
+    [[maybe_unused]] auto acc_dbias_rowcol = make_tensor(acc_dbias.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(acc_dbias.layout()));

     // Copy Atom retiling
     auto smem_tiled_copy_QdO = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma_sdp);
```
```diff
@@ -641,8 +640,9 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
         cute::copy(smem_tiled_copy_KV, tdPsV, tdPrV_copy_view);
     }

-    clear(acc_dv);
     clear(acc_dk);
+    clear(acc_dv);
+    if constexpr (Has_bias) { if (accum_dbias) { clear(acc_dbias); } }

     for (; m_block >= m_block_min; --m_block) {
         Tensor acc_s = partition_fragment_C(tiled_mma_sdp, Shape<Int<kBlockM>, Int<kBlockN>>{});  // (MMA=4, MMA_M, MMA_N)
```
```diff
@@ -806,6 +806,11 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
                 float scaled_ds = pointwise_mult(scores(mi, ni), dS(mi, ni), dP_sum(mi));
                 if constexpr (Is_softcap) { scaled_ds *= dtanh(mi, ni); }
                 dS(mi, ni) = scaled_ds;
+                if constexpr (Has_bias) {
+                    if (accum_dbias) {
+                        acc_dbias_rowcol(mi, ni) += scaled_ds;
+                    }
+                }
             }
         }
         // if (cute::thread0()) { print(dS); }
```
```diff
@@ -852,36 +857,13 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
        __syncthreads();
        if constexpr (Has_bias) {
            // Write dS to dBias
-            if (!params.accum_dbias) {
+            if (!accum_dbias) {
                FLASH_NAMESPACE::copy_MN<Is_even_MN, /*Clear_OOB_MN=*/false>(
                    gmem_tiled_copy_dBias,
                    tBiassBias, tdBiasgdBias,
                    tBiascBias, tBiaspBias,
                    binfo.actual_seqlen_q - m_block * kBlockM
                );
-            } else {
-                #pragma unroll
-                for (int m = 0; m < size<1>(tBiassBias); ++m) {
-                    if (Is_even_MN || get<0>(tBiascBias(0, m, 0)) < binfo.actual_seqlen_q - m_block * kBlockM) {
-                        #pragma unroll
-                        for (int n = 0; n < size<2>(tBiassBias); ++n) {
-                            if (Is_even_MN || tBiaspBias(n)) {
-                                #pragma unroll
-                                for (int i = 0; i < size<0>(tBiassBias); ++i) {
-                                    const auto coord = tBiascBias(i, m, n);
-                                    const int row = get<0>(coord);
-                                    const int col = get<1>(coord);
-                                    if (Is_even_MN || row < binfo.actual_seqlen_q - m_block * kBlockM) {
-                                        atomicAdd(
-                                            gdBias_accum_ptr + row * params.dbias_row_stride + col,
-                                            static_cast<ElementAccum>(tBiassBias(i, m, n))
-                                        );
-                                    }
-                                }
-                            }
-                        }
-                    }
-                }
            }
        }

```
```diff
@@ -1023,9 +1005,6 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
            // Advance gBias and gdBias
            tBiasgBias.data() = tBiasgBias.data() + (-int(kBlockM * params.bias_row_stride));
            tdBiasgdBias.data() = tdBiasgdBias.data() + (-int(kBlockM * params.dbias_row_stride));
-            if (params.accum_dbias) {
-                gdBias_accum_ptr -= int(kBlockM * params.dbias_row_stride);
-            }
            if (any_active_next) {
                FLASH_NAMESPACE::copy_MN<Is_even_MN, /*Clear_OOB_MN=*/true>(
                    gmem_tiled_copy_Bias,
```
```diff
@@ -1069,10 +1048,53 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in

    // Epilogue

+    if constexpr (Has_bias) {
+        if (accum_dbias) {
+            const int actual_block_n = Is_even_MN ? kBlockN : std::max(0, std::min(kBlockN, binfo.actual_seqlen_k - n_block * kBlockN));
+
+            // Convert acc_dbias from fp32 to fp16
+            Tensor tdBiasrdBias = FLASH_NAMESPACE::convert_type<Element>(acc_dbias);
+
+            // Partition sBias to match the accumulator partitioning
+            Tensor tdBiasadBias = smem_thr_copy_Bias.retile_S(tdBiasrdBias);    // ((Atom, AtomNum), MMA_M, MMA_N)
+
+            // We need syncthreads here since we're writing to the same location as sBias.
+            // Without syncthreads, some thread might modify the location of sBias while another thread
+            // is reading it for dQ gemm, leading to a race condition.
+            // If Is_last, there's already a __syncthreads() at the end of the loop.
+            if (!Is_last) { __syncthreads(); }
+
+            cute::copy(smem_tiled_copy_PdS, tdBiasadBias, tdSsdS);
+
+            __syncthreads();
+
+            // Perform a column-wise sum across all M rows into row 0 of shared memory.
+            // This is the critical step for correct dBias accumulation when dbias_row_stride == 0.
```
(The remaining lines of this hunk are not shown in this excerpt.)
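Taken together with the comments added in this last hunk, the change works as follows: when the bias is broadcast along the query dimension (dbias_row_stride == 0 and actual_seqlen_q > 1), dBias is the column-wise sum of dS over every query row, so the kernel now keeps that sum in an fp32 register accumulator (acc_dbias) across the m_block loop and, per the new epilogue comments, reduces it into a single row before writing it out, replacing the removed per-element atomicAdd path. A deliberately simplified, framework-free CUDA sketch of that reduction pattern (not the actual CUTE implementation; names, shapes, and the launch layout here are illustrative):

```cuda
// Toy illustration only: dS is a [seqlen_q, seqlen_k] matrix of attention-score
// gradients, and dbias is a single broadcast row of length seqlen_k
// (seqlen_q_bias == 1). Each thread owns one key column, sums that column over
// all query rows in an fp32 register, and issues one atomicAdd per column so
// that, if several blocks contribute to the same column, their partial sums
// combine safely.
__global__ void dbias_broadcast_row_kernel(const float* dS, float* dbias,
                                           int seqlen_q, int seqlen_k) {
    int col = blockIdx.x * blockDim.x + threadIdx.x;
    if (col >= seqlen_k) return;

    float acc = 0.0f;                        // plays the role of acc_dbias
    for (int row = 0; row < seqlen_q; ++row) {
        acc += dS[row * seqlen_k + col];     // analogous to acc_dbias_rowcol(mi, ni) += scaled_ds
    }
    atomicAdd(&dbias[col], acc);             // one reduced-row contribution per column
}
```

The real kernel performs the same reduction per tile with CUTE tensors, staging the values through shared memory (the sBias tile) before the global write, as described in the epilogue comments above.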
Copilot AI (Oct 22, 2025):
The max_M=1 parameter is critical to the correctness of this implementation, as it ensures that only the reduced row is written. This should have a comment explaining that we're writing only row 0, which contains the sum across all M rows.
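For illustration only, the comment and bounded copy the reviewer is describing might look roughly like the fragment below. The identifiers are taken from the hunks above, but the actual call site with max_M=1 is not part of this excerpt, so treat this as a hypothetical sketch rather than the patch itself:

```cuda
// Row 0 of sBias now holds the column-wise sum of dS over all M rows of this tile,
// so the copy is bounded to a single row; writing the remaining rows would emit
// unreduced partial values into the broadcast dBias row.
FLASH_NAMESPACE::copy_MN<Is_even_MN, /*Clear_OOB_MN=*/false>(
    gmem_tiled_copy_dBias,
    tBiassBias, tdBiasgdBias,
    tBiascBias, tBiaspBias,
    /*max_M=*/1
);
```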
Copilot AI (Oct 22, 2025):
The dtype for the dbias_expanded allocation has changed from at::kFloat to the default opts dtype. This may cause precision loss if opts is not Float32, since the accumulation in the kernel uses ElementAccum (fp32). Consider explicitly preserving opts.dtype(at::kFloat) to match the kernel's accumulation precision.
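A minimal, self-contained sketch of the suggested fix, assuming the host side allocates dbias_expanded from a TensorOptions object named opts; the shape and surrounding code here are illustrative, not taken from the PR:

```cpp
#include <torch/torch.h>

int main() {
    // `opts` typically mirrors the input dtype (fp16/bf16). Forcing the
    // accumulation buffer to fp32 keeps it consistent with the kernel's
    // ElementAccum, as the review suggests.
    auto opts = torch::TensorOptions().dtype(torch::kHalf);

    // Hypothetical broadcast-bias shape: (batch, heads, seqlen_q_bias == 1, seqlen_k).
    auto dbias_expanded = torch::zeros({2, 8, 1, 128}, opts.dtype(torch::kFloat));

    TORCH_CHECK(dbias_expanded.scalar_type() == torch::kFloat);
    return 0;
}
```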