diff --git a/csrc/flash_dmattn/flash_api.cpp b/csrc/flash_dmattn/flash_api.cpp
index 040eba1..0782b90 100644
--- a/csrc/flash_dmattn/flash_api.cpp
+++ b/csrc/flash_dmattn/flash_api.cpp
@@ -190,7 +190,6 @@ void set_params_dgrad(
     const float softcap,
     bool has_mask,
     bool has_bias,
-    bool accum_dbias,
     bool deterministic,
     const bool unpadded_lse
 ) {
@@ -246,8 +245,6 @@ void set_params_dgrad(
     // Softmax sum
     params.dsoftmax_sum = dsoftmax_sum_d;
 
-    params.accum_dbias = accum_dbias;
-
     params.deterministic = deterministic;
 }
@@ -982,11 +979,10 @@ mha_bwd(
     dbias_expanded = has_bias
         ? (num_heads_bias != num_heads || batch_size_bias == 1 || seqlen_q_bias == 1)  // MQA / GQA or dbias has different batch size or seqlen_q
             ? (seqlen_q_bias == 1)
-                ? torch::zeros({batch_size, num_heads, 1, seqlen_k_rounded}, opts.dtype(at::kFloat))
+                ? torch::zeros({batch_size, num_heads, 1, seqlen_k_rounded}, opts)
                 : torch::zeros({batch_size, num_heads, seqlen_q, seqlen_k_rounded}, opts)
             : dbias
         : torch::empty({0}, opts);
-    bool accum_dbias = has_bias && (seqlen_q_bias == 1 && seqlen_q != 1);
 
     Flash_bwd_params params;
@@ -1013,7 +1009,6 @@ mha_bwd(
         softcap,
         has_mask,
         has_bias,
-        accum_dbias,
         deterministic,
         /*unpadded_lse*/false
     );
@@ -1041,7 +1036,7 @@ mha_bwd(
         if (num_heads_bias != num_heads && batch_size_bias == batch_size && seqlen_q_bias == seqlen_q) {
             at::sum_out(dbias, at::reshape(dbias_expanded, {batch_size, num_heads_bias, num_heads / num_heads_bias, seqlen_q, seqlen_k_rounded}), {2});
         } else {
-            if (accum_dbias) {
+            if (seqlen_q_bias == 1) {
                 dbias_expanded = at::sum(at::reshape(dbias_expanded, {batch_size, num_heads_bias, num_heads / num_heads_bias, 1, seqlen_k_rounded}), {2});
             } else {
                 dbias_expanded = at::sum(at::reshape(dbias_expanded, {batch_size, num_heads_bias, num_heads / num_heads_bias, seqlen_q, seqlen_k_rounded}), {2});
@@ -1244,7 +1239,6 @@ mha_varlen_bwd(
         softcap,
         has_mask,
         has_bias,
-        /*accum_dbias*/false,
         deterministic,
         /*unpadded_lse*/true
     );
diff --git a/csrc/flash_dmattn/src/flash.h b/csrc/flash_dmattn/src/flash.h
index 29c342f..a1c9bf1 100644
--- a/csrc/flash_dmattn/src/flash.h
+++ b/csrc/flash_dmattn/src/flash.h
@@ -195,8 +195,6 @@ struct Flash_bwd_params : public Flash_fwd_params {
     bool deterministic;
     index_t dq_accum_split_stride;
-
-    bool accum_dbias;
 };
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
diff --git a/csrc/flash_dmattn/src/flash_bwd_kernel.h b/csrc/flash_dmattn/src/flash_bwd_kernel.h
index 7c4056e..9a8b534 100644
--- a/csrc/flash_dmattn/src/flash_bwd_kernel.h
+++ b/csrc/flash_dmattn/src/flash_bwd_kernel.h
@@ -101,6 +101,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
     if (n_block * kBlockN >= binfo.actual_seqlen_k) return;
 
     int m_block_max = cute::ceil_div(binfo.actual_seqlen_q, kBlockM);
+    bool accum_dbias = Has_bias && (params.dbias_row_stride == 0) && (binfo.actual_seqlen_q > 1);
 
     const index_t row_offset_q = binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb)
         + (m_block_max - 1) * kBlockM * params.q_row_stride + bidh * params.q_head_stride;
@@ -159,10 +160,6 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
         Shape<Int<kBlockM>, Int<kBlockN>>{},
         make_stride(params.dbias_row_stride, _1{})
     );
-    [[maybe_unused]] ElementAccum *gdBias_accum_ptr = nullptr;
-    if constexpr (Has_bias) {
-        gdBias_accum_ptr = reinterpret_cast<ElementAccum *>(params.dbias_ptr) + row_offset_dbias;
-    }
     Tensor gdO = make_tensor(
         make_gmem_ptr(reinterpret_cast<Element *>(params.do_ptr) + row_offset_do),
         Shape<Int<kBlockM>, Int<kHeadDim>>{},
         make_stride(params.do_row_stride, _1{})
     );
@@ -287,8 +284,6 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
     typename Kernel_traits::GmemTiledCopydO gmem_tiled_copy_dO;
     auto gmem_thr_copy_dO = gmem_tiled_copy_dO.get_thread_slice(tidx);
     typename Kernel_traits::GmemTiledCopydQ gmem_tiled_copy_dQ;
-    typename Kernel_traits::GmemTiledCopydBias gmem_tiled_copy_dBias;
-    auto gmem_thr_copy_dBias = gmem_tiled_copy_dBias.get_thread_slice(tidx);
     auto gmem_thr_copy_dQ = gmem_tiled_copy_dQ.get_thread_slice(tidx);
     using GmemLayoutAtomdQaccum = std::conditional_t<
         !Seq_parallel,
@@ -297,6 +292,8 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
     >;
     GmemLayoutAtomdQaccum gmem_tiled_copy_dQaccum;
     auto gmem_thr_copy_dQaccum = gmem_tiled_copy_dQaccum.get_thread_slice(tidx);
+    typename Kernel_traits::GmemTiledCopydBias gmem_tiled_copy_dBias;
+    auto gmem_thr_copy_dBias = gmem_tiled_copy_dBias.get_thread_slice(tidx);
 
     Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ);
     Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ);
@@ -346,6 +343,8 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
     Tensor acc_dk = partition_fragment_C(tiled_mma_dkv, Shape<Int<kBlockN>, Int<kHeadDim>>{});  // (MMA, MMA_N, MMA_K)
     Tensor acc_dv = partition_fragment_C(tiled_mma_dkv, Shape<Int<kBlockN>, Int<kHeadDim>>{});  // (MMA, MMA_N, MMA_K)
+    [[maybe_unused]] auto acc_dbias = partition_fragment_C(tiled_mma_sdp, Shape<Int<kBlockM>, Int<kBlockN>>{});
+    [[maybe_unused]] auto acc_dbias_rowcol = make_tensor(acc_dbias.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(acc_dbias.layout()));
 
     // Copy Atom retiling
     auto smem_tiled_copy_QdO = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma_sdp);
@@ -641,8 +640,9 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
         cute::copy(smem_tiled_copy_KV, tdPsV, tdPrV_copy_view);
     }
 
-    clear(acc_dv);
     clear(acc_dk);
+    clear(acc_dv);
+    if constexpr (Has_bias) { if (accum_dbias) { clear(acc_dbias); } }
 
     for (; m_block >= m_block_min; --m_block) {
         Tensor acc_s = partition_fragment_C(tiled_mma_sdp, Shape<Int<kBlockM>, Int<kBlockN>>{});  // (MMA=4, MMA_M, MMA_N)
@@ -806,6 +806,11 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
                 float scaled_ds = pointwise_mult(scores(mi, ni), dS(mi, ni), dP_sum(mi));
                 if constexpr (Is_softcap) { scaled_ds *= dtanh(mi, ni); }
                 dS(mi, ni) = scaled_ds;
+                if constexpr (Has_bias) {
+                    if (accum_dbias) {
+                        acc_dbias_rowcol(mi, ni) += scaled_ds;
+                    }
+                }
             }
         }
         // if (cute::thread0()) { print(dS); }
@@ -852,36 +857,13 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
         __syncthreads();
         if constexpr (Has_bias) {
             // Write dS to dBias
-            if (!params.accum_dbias) {
+            if (!accum_dbias) {
                 FLASH_NAMESPACE::copy_MN(
                     gmem_tiled_copy_dBias,
                     tBiassBias, tdBiasgdBias,
                     tBiascBias, tBiaspBias,
                     binfo.actual_seqlen_q - m_block * kBlockM
                 );
-            } else {
-                #pragma unroll
-                for (int m = 0; m < size<1>(tBiassBias); ++m) {
-                    if (Is_even_MN || get<0>(tBiascBias(0, m, 0)) < binfo.actual_seqlen_q - m_block * kBlockM) {
-                        #pragma unroll
-                        for (int n = 0; n < size<2>(tBiassBias); ++n) {
-                            if (Is_even_MN || tBiaspBias(n)) {
-                                #pragma unroll
-                                for (int i = 0; i < size<0>(tBiassBias); ++i) {
-                                    const auto coord = tBiascBias(i, m, n);
-                                    const int row = get<0>(coord);
-                                    const int col = get<1>(coord);
-                                    if (Is_even_MN || row < binfo.actual_seqlen_q - m_block * kBlockM) {
-                                        atomicAdd(
-                                            gdBias_accum_ptr + row * params.dbias_row_stride + col,
-                                            static_cast<ElementAccum>(tBiassBias(i, m, n))
-                                        );
-                                    }
-                                }
-                            }
-                        }
-                    }
-                }
             }
         }
@@ -1023,9 +1005,6 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
             // Advance gBias and gdBias
             tBiasgBias.data() = tBiasgBias.data() + (-int(kBlockM * params.bias_row_stride));
             tdBiasgdBias.data() = tdBiasgdBias.data() + (-int(kBlockM * params.dbias_row_stride));
-            if (params.accum_dbias) {
-                gdBias_accum_ptr -= int(kBlockM * params.dbias_row_stride);
-            }
             if (any_active_next) {
                 FLASH_NAMESPACE::copy_MN(
                     gmem_tiled_copy_Bias,
@@ -1069,10 +1048,53 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
     // Epilogue
 
+    if constexpr (Has_bias) {
+        if (accum_dbias) {
+            const int actual_block_n = Is_even_MN ? kBlockN : std::max(0, std::min(kBlockN, binfo.actual_seqlen_k - n_block * kBlockN));
+
+            // Convert acc_dbias from fp32 to fp16
+            Tensor tdBiasrdBias = FLASH_NAMESPACE::convert_type<Element>(acc_dbias);
+
+            // Partition sBias to match the accumulator partitioning
+            Tensor tdBiasadBias = smem_thr_copy_Bias.retile_S(tdBiasrdBias);    // ((Atom, AtomNum), MMA_M, MMA_N)
+
+            // We need syncthreads here since we're writing to the same location as sBias.
+            // Without syncthreads, some thread might modify the location of sBias while another thread
+            // is reading it for dQ gemm, leading to a race condition.
+            // If Is_last, there's already a __syncthreads() at the end of the loop.
+            if (!Is_last) { __syncthreads(); }
+
+            cute::copy(smem_tiled_copy_PdS, tdBiasadBias, tdSsdS);
+
+            __syncthreads();
+            for (int col = threadIdx.x; col < kBlockN; col += blockDim.x) {
+                if (col < actual_block_n) {
+                    ElementAccum rowsum = 0.f;
+                    #pragma unroll
+                    for (int row = 0; row < kBlockM; ++row) {
+                        rowsum += static_cast<ElementAccum>(sdS(row, col));
+                    }
+                    sdS(0, col) = static_cast<Element>(rowsum);
+                }
+            }
+            __syncthreads();
+
+            #pragma unroll
+            for (int ni = 0; ni < size(tBiaspBias); ++ni) { tBiaspBias(ni) = ni < actual_block_n; }
+            // Clear_OOB_K must be false since we don't want to write zeros to gmem
+            FLASH_NAMESPACE::copy_MN(
+                gmem_tiled_copy_dBias,
+                tBiassBias, tdBiasgdBias,
+                tBiascBias, tBiaspBias,
+                /*max_M=*/1
+            );
+        }
+    }
+
     #pragma unroll
     for (int i = 0; i < size(acc_dk); ++i) { acc_dk(i) *= params.scale_softmax; }
 
-    // Convert acc_dv from fp32 to fp16
+    // Convert acc_dk, acc_dv from fp32 to fp16
     Tensor rdK = FLASH_NAMESPACE::convert_type<Element>(acc_dk);
     Tensor rdV = FLASH_NAMESPACE::convert_type<Element>(acc_dv);
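
For reference, a minimal standalone CUDA sketch (not code from this PR) of the reduction semantics the new epilogue implements when accum_dbias is true: with a broadcast query dimension (dbias_row_stride == 0), each K-tile contributes the column-wise sum of dS over its kBlockM rows to dBias. The kernel name, tile sizes, and plain float buffers below are illustrative assumptions; the actual kernel stages the row-sums in shared memory and writes them through the gmem tiled copy shown above rather than via per-column atomics.

// reduce_dbias_tile.cu -- illustrative sketch only; names and layout are assumptions.
#include <cuda_runtime.h>
#include <cstdio>
#include <vector>

constexpr int kBlockM = 64;    // query rows per tile (assumed)
constexpr int kBlockN = 128;   // key columns per tile (assumed)

// Sum a (kBlockM x kBlockN) row-major dS tile over its rows and accumulate the
// per-column result into dbias, mirroring the "for (int col = threadIdx.x; ...)"
// loop in the new epilogue. Columns at or past actual_block_n are out of range.
__global__ void reduce_dbias_tile(const float *dS, float *dbias, int actual_block_n) {
    for (int col = threadIdx.x; col < kBlockN; col += blockDim.x) {
        if (col < actual_block_n) {
            float rowsum = 0.f;
            #pragma unroll
            for (int row = 0; row < kBlockM; ++row) {
                rowsum += dS[row * kBlockN + col];
            }
            // Accumulate across K tiles; one atomic per column per tile, instead of
            // the per-element atomicAdd loop that this PR removes from the main loop.
            atomicAdd(&dbias[col], rowsum);
        }
    }
}

int main() {
    std::vector<float> h_dS(kBlockM * kBlockN, 1.0f);   // dS == 1 everywhere
    float *d_dS = nullptr, *d_dbias = nullptr;
    cudaMalloc(&d_dS, h_dS.size() * sizeof(float));
    cudaMalloc(&d_dbias, kBlockN * sizeof(float));
    cudaMemcpy(d_dS, h_dS.data(), h_dS.size() * sizeof(float), cudaMemcpyHostToDevice);
    cudaMemset(d_dbias, 0, kBlockN * sizeof(float));

    const int actual_block_n = 100;                     // pretend the tile is partially out of range
    reduce_dbias_tile<<<1, 128>>>(d_dS, d_dbias, actual_block_n);

    std::vector<float> h_dbias(kBlockN);
    cudaMemcpy(h_dbias.data(), d_dbias, kBlockN * sizeof(float), cudaMemcpyDeviceToHost);
    printf("dbias[0] = %.0f (expect %d), dbias[%d] = %.0f (expect 0)\n",
           h_dbias[0], kBlockM, actual_block_n, h_dbias[actual_block_n]);

    cudaFree(d_dS);
    cudaFree(d_dbias);
    return 0;
}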