Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 44 additions & 0 deletions csrc/src/flash_bwd_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -659,6 +659,48 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in

// Reshape acc_dp from (MMA=4, MMA_M, MMA_N) to (row=(2, MMA_M), col=(2, MMA_N))
Tensor dS = make_tensor(acc_dp.data(), scores.layout());

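// Compute dBias before dS is modified in place.
// Following the Triton logic: dbias = p * (dp - Di[:, None])
// NOTE(review): pointwise_mult below treats p < 0 specially when Is_dropout is set;
// this loop applies p * (dp - d) unconditionally -- confirm that is intended with dropout.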
Tensor dBias = make_tensor_like(scores);
#pragma unroll
for (int mi = 0; mi < size<0>(dS); ++mi) {
#pragma unroll
for (int ni = 0; ni < size<1>(dS); ++ni) {
dBias(mi, ni) = scores(mi, ni) * (dS(mi, ni) - dP_sum(mi));
}
}

// Copy dBias to shared memory and then to global memory
Tensor dBias_reshaped = make_tensor(dBias.data(), acc_dp.layout());
// Convert dBias from fp32 to fp16/bf16
Tensor rdBias = FLASH_NAMESPACE::convert_type<Element>(dBias_reshaped);
Tensor tadBiasdBias = smem_thr_copy_PdS.retile_S(rdBias);
cute::copy(smem_tiled_copy_PdS, tadBiasdBias, smem_thr_copy_PdS.partition_D(sdBias));

__syncthreads(); // Synchronize before copying to global memory

// Copy dBias from shared memory to global memory
typename Kernel_traits::GmemTiledCopyBias gmem_tiled_copy_bias;
auto gmem_thr_copy_bias = gmem_tiled_copy_bias.get_thread_slice(tidx);
Tensor tBiasgdBias = gmem_thr_copy_bias.partition_D(gdBias);
Tensor tBiassdBias = gmem_thr_copy_bias.partition_S(sdBias);
Tensor tBiasrdBias = make_tensor<Element>(shape(tBiasgdBias));
cute::copy(gmem_tiled_copy_bias, tBiassdBias, tBiasrdBias);

// Write to global memory with bounds checking
Tensor cBias = make_identity_tensor(Shape<Int<kBlockM>, Int<kBlockN>>{});
Tensor tBiascBias = gmem_thr_copy_bias.partition_D(cBias);
Tensor tBiaspBias = make_tensor<bool>(make_shape(size<1>(tBiasgdBias)));
#pragma unroll
for (int n = 0; n < size(tBiaspBias); ++n) {
tBiaspBias(n) = get<1>(tBiascBias(0, 0, n)) < binfo.actual_seqlen_k - n_block * kBlockN;
}
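// Copy registers to global memory, following the dK/dV write-back pattern.
// The row limit passed is binfo.actual_seqlen_q - m_block * kBlockM.
// NOTE(review): Is_even_K is hard-coded to true here -- confirm that the column
// predicate tBiaspBias is still honored by copy<> in that configuration.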
FLASH_NAMESPACE::copy<Is_even_MN, /*Is_even_K=*/true, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
gmem_tiled_copy_bias, tBiasrdBias, tBiasgdBias, tBiascBias, tBiaspBias, binfo.actual_seqlen_q - m_block * kBlockM
);

auto pointwise_mult = [](float p, float dp, float d) {
return p * (!Is_dropout || p >= 0 ? dp - d : d);
};
Expand Down Expand Up @@ -739,6 +781,8 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
#pragma unroll
for (int mi = 0; mi < size(lse); ++mi) { lse(mi) = gLSE(get<0>(taccScS_row(mi))); }
gdPsum.data() = gdPsum.data() + (-int(kBlockM));
// Move the gdBias pointer back by kBlockM * params.dbias_row_stride elements for the next iteration
gdBias.data() = gdBias.data() + (-int(kBlockM * params.dbias_row_stride));
}

if (!Is_last) {
Expand Down