From b6d2d5af8c24c1367a68ff66132a96592f1845b0 Mon Sep 17 00:00:00 2001 From: LoserCheems Date: Mon, 3 Nov 2025 23:59:10 +0800 Subject: [PATCH 01/11] Scale Q by softmax factor before streaming K/V; add sync for smem path --- csrc/flash_dmattn/src/flash_fwd_kernel.h | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/csrc/flash_dmattn/src/flash_fwd_kernel.h b/csrc/flash_dmattn/src/flash_fwd_kernel.h index 576de90..50d11dc 100644 --- a/csrc/flash_dmattn/src/flash_fwd_kernel.h +++ b/csrc/flash_dmattn/src/flash_fwd_kernel.h @@ -422,6 +422,20 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi binfo.actual_seqlen_k, binfo.actual_seqlen_q ); + // Scale Q once before streaming loop KV + if constexpr (Kernel_traits::Is_Q_in_regs) { + #pragma unroll + for (int i = 0; i < size(tSrQ); ++i) { + tSrQ(i) = static_cast(tSrQ(i) * params.scale_softmax); + } + } else { + #pragma unroll + for (int i = 0; i < size(tSsQ); ++i) { + tSsQ(i) = static_cast(tSsQ(i) * params.scale_softmax); + } + __syncthreads(); + } + // For performance reason, we separate out two kinds of iterations: // those that need masking on S, and those that don't. // We need masking on S for the very last block when K and V has length not multiple of kBlockN. From 7840038deace044bf5d858cbe1d8bb3219a9bc8b Mon Sep 17 00:00:00 2001 From: LoserCheems Date: Mon, 3 Nov 2025 23:59:58 +0800 Subject: [PATCH 02/11] Simplifies mask application Removes bias and scaling handling from the mask helper to reduce specialization paths and rely solely on masking behavior. --- csrc/flash_dmattn/src/mask.h | 75 ++++++------------------------------ 1 file changed, 11 insertions(+), 64 deletions(-) diff --git a/csrc/flash_dmattn/src/mask.h b/csrc/flash_dmattn/src/mask.h index 9f81cae..62c02c1 100644 --- a/csrc/flash_dmattn/src/mask.h +++ b/csrc/flash_dmattn/src/mask.h @@ -143,12 +143,10 @@ struct Mask { , max_seqlen_q(max_seqlen_q) { }; - template + template __forceinline__ __device__ void apply_mask( TensorType &tensor_, // acc_s (attention scores, MMA=4, MMA_M, MMA_N) const MaskType &mask_, // Attention Mask (MMA=4, MMA_M, MMA_N) - const BiasType &bias_, // Attention Bias (MMA=4, MMA_M, MMA_N) - const float scale_softmax, // Scale for softmax const int col_idx_offset_, // Column index offset const int row_idx_offset, // Row index offset const int warp_row_stride // Warp row stride @@ -156,8 +154,6 @@ struct Mask { static_assert(TensorType::rank == 3, "tensor_ must be 3D Tensor"); if constexpr (Has_mask) static_assert(MaskType::rank == 3, "mask_ must be 3D Tensor"); - if constexpr (Has_bias) - static_assert(BiasType::rank == 3, "Bias must be 3D Tensor"); static_assert(decltype(size<0>(tensor_))::value == 4, "First dimension must be 4"); // Reshape tensors from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) @@ -166,9 +162,8 @@ struct Mask { const int lane_id = threadIdx.x % 32; const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2; - if constexpr (Has_mask && Has_bias) { + if constexpr (Has_mask) { Tensor mask = make_tensor(mask_.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(mask_.layout())); - Tensor bias = make_tensor(bias_.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(bias_.layout())); #pragma unroll for (int mi = 0; mi < size<0, 1>(tensor); ++mi) { const int row_idx_base = row_idx_offset + mi * warp_row_stride; @@ -183,63 +178,15 @@ struct Mask { for (int j = 0; j < size<1, 0>(tensor); ++j) { const int col_idx = col_idx_base + j; auto coord = make_coord(make_coord(i, mi), make_coord(j, 
nj)); - // Apply scaling and bias or masking - tensor(coord) = (col_idx >= col_idx_limit) || (!mask(coord)) - ? -INFINITY - : tensor(coord) * scale_softmax + bias(coord); + // Apply masking + if (col_idx >= col_idx_limit || !mask(coord)) { + tensor(coord) = -INFINITY; + } } } } } - } else if constexpr (Has_mask && !Has_bias) { - Tensor mask = make_tensor(mask_.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(mask_.layout())); - #pragma unroll - for (int mi = 0; mi < size<0, 1>(tensor); ++mi) { - const int row_idx_base = row_idx_offset + mi * warp_row_stride; - #pragma unroll - for (int i = 0; i < size<0, 0>(tensor); ++i) { - const int row_idx = row_idx_base + i * 8; - const int col_idx_limit = Causal_mask ? std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q) : max_seqlen_k; - #pragma unroll - for (int nj = 0; nj < size<1, 1>(tensor); ++nj) { - const int col_idx_base = col_idx_offset + nj * 8; - #pragma unroll - for (int j = 0; j < size<1, 0>(tensor); ++j) { - const int col_idx = col_idx_base + j; - auto coord = make_coord(make_coord(i, mi), make_coord(j, nj)); - // Apply scaling or masking - tensor(coord) = (col_idx >= col_idx_limit) || (!mask(coord)) - ? -INFINITY - : tensor(coord) * scale_softmax; - } - } - } - } - } else if constexpr (!Has_mask && Has_bias) { - Tensor bias = make_tensor(bias_.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(bias_.layout())); - #pragma unroll - for (int mi = 0; mi < size<0, 1>(tensor); ++mi) { - const int row_idx_base = row_idx_offset + mi * warp_row_stride; - #pragma unroll - for (int i = 0; i < size<0, 0>(tensor); ++i) { - const int row_idx = row_idx_base + i * 8; - const int col_idx_limit = Causal_mask ? std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q) : max_seqlen_k; - #pragma unroll - for (int nj = 0; nj < size<1, 1>(tensor); ++nj) { - const int col_idx_base = col_idx_offset + nj * 8; - #pragma unroll - for (int j = 0; j < size<1, 0>(tensor); ++j) { - const int col_idx = col_idx_base + j; - auto coord = make_coord(make_coord(i, mi), make_coord(j, nj)); - // Apply scaling and bias - tensor(coord) = (col_idx >= col_idx_limit) - ? -INFINITY - : tensor(coord) * scale_softmax + bias(coord); - } - } - } - } - } else { // !Has_mask && !Has_bias + } else { #pragma unroll for (int mi = 0; mi < size<0, 1>(tensor); ++mi) { const int row_idx_base = row_idx_offset + mi * warp_row_stride; @@ -254,10 +201,10 @@ struct Mask { for (int j = 0; j < size<1, 0>(tensor); ++j) { const int col_idx = col_idx_base + j; auto coord = make_coord(make_coord(i, mi), make_coord(j, nj)); - // Apply scaling - tensor(coord) = (col_idx >= col_idx_limit) - ? -INFINITY - : tensor(coord) * scale_softmax; + // Apply masking + if (col_idx >= col_idx_limit) { + tensor(coord) = -INFINITY; + } } } } From ce06d5fc0e94bbe3d0e7984b0a24eb54c18aa38c Mon Sep 17 00:00:00 2001 From: LoserCheems Date: Tue, 4 Nov 2025 00:37:18 +0800 Subject: [PATCH 03/11] Streamlines attention init Rationalizes accumulator setup so bias kernels reuse shared-memory bias instead of clearing registers, trimming sync overhead. Simplifies mask application templates to drop unused bias handling, tightening specialization footprint. 
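A quick host-side check of the accumulator change described above (toy sizes, illustrative names; the kernel works on CuTe register fragments): because the MMA only ever adds onto whatever the accumulator already holds, seeding acc_s with the bias and streaming a pre-scaled Q recovers scale * (Q K^T) + bias up to floating-point rounding.

    #include <array>
    #include <cassert>
    #include <cmath>
    #include <cstdio>

    int main() {
        constexpr int M = 2, N = 3, D = 4;
        const float scale = 0.5f;
        std::array<std::array<float, D>, M> Q = {{{0.1f, -0.2f, 0.3f, 0.4f},
                                                  {0.5f, 0.6f, -0.7f, 0.8f}}};
        std::array<std::array<float, D>, N> K = {{{1.0f, 0.0f, -1.0f, 2.0f},
                                                  {0.3f, 0.1f, 0.4f, -0.2f},
                                                  {-0.5f, 0.7f, 0.2f, 0.9f}}};
        std::array<std::array<float, N>, M> bias = {{{0.2f, -0.1f, 0.0f},
                                                     {0.05f, 0.3f, -0.4f}}};

        for (int m = 0; m < M; ++m) {
            for (int n = 0; n < N; ++n) {
                // Old path: clear the accumulator, then scale and add the bias afterwards.
                float ref = 0.0f;
                for (int d = 0; d < D; ++d) ref += Q[m][d] * K[n][d];
                ref = ref * scale + bias[m][n];

                // New path: the accumulator starts at the bias, Q was scaled once up front.
                float acc = bias[m][n];
                for (int d = 0; d < D; ++d) acc += (Q[m][d] * scale) * K[n][d];

                assert(std::fabs(ref - acc) < 1e-6f);
            }
        }
        std::printf("bias-seeded accumulator matches scale * QK^T + bias\n");
        return 0;
    }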
--- csrc/flash_dmattn/src/flash_fwd_kernel.h | 103 ++++++++--------------- 1 file changed, 36 insertions(+), 67 deletions(-) diff --git a/csrc/flash_dmattn/src/flash_fwd_kernel.h b/csrc/flash_dmattn/src/flash_fwd_kernel.h index 50d11dc..074dfc9 100644 --- a/csrc/flash_dmattn/src/flash_fwd_kernel.h +++ b/csrc/flash_dmattn/src/flash_fwd_kernel.h @@ -433,7 +433,6 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi for (int i = 0; i < size(tSsQ); ++i) { tSsQ(i) = static_cast(tSsQ(i) * params.scale_softmax); } - __syncthreads(); } // For performance reason, we separate out two kinds of iterations: @@ -451,10 +450,22 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi #pragma unroll for (int masking_step = 0; masking_step < n_masking_steps; ++masking_step, --n_block) { Tensor acc_s = partition_fragment_C(tiled_mma, Shape, Int>{}); // (MMA=4, MMA_M, MMA_N) - clear(acc_s); FLASH_NAMESPACE::cp_async_wait<0>(); __syncthreads(); + if (any_active) { + if constexpr (Has_bias) { + // Copy bias from smem to acc_s registers + Tensor tSrBias = make_tensor(shape(acc_s)); + Tensor tSrBias_copy_view = smem_thr_copy_Bias.retile_D(tSrBias); + cute::copy(smem_tiled_copy_Bias, tSsBias, tSrBias_copy_view); + #pragma unroll + for (int i = 0; i < size(acc_s); ++i) { acc_s(i) = tSrBias(i); } + } else { + clear(acc_s); + } + } + // Advance gV if (masking_step > 0 && any_active) { FLASH_NAMESPACE::copy( @@ -487,46 +498,19 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi FLASH_NAMESPACE::apply_softcap(acc_s, params.softcap); } - if constexpr (Has_mask && Has_bias) { - // Copy mask and bias from smem to registers - Tensor tSrMask = make_tensor(shape(acc_s)); - Tensor tSrMask_copy_view = smem_thr_copy_Mask.retile_D(tSrMask); - cute::copy(smem_tiled_copy_Mask, tSsMask, tSrMask_copy_view); - Tensor tSrBias = make_tensor(shape(acc_s)); - Tensor tSrBias_copy_view = smem_thr_copy_Bias.retile_D(tSrBias); - cute::copy(smem_tiled_copy_Bias, tSsBias, tSrBias_copy_view); - - // Scale attention scores and apply mask and add bias - mask.template apply_mask( - acc_s, tSrMask, tSrBias, params.scale_softmax, - n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16 - ); - } else if constexpr (Has_mask && !Has_bias) { + if constexpr (Has_mask) { // Copy mask from smem to registers Tensor tSrMask = make_tensor(shape(acc_s)); Tensor tSrMask_copy_view = smem_thr_copy_Mask.retile_D(tSrMask); cute::copy(smem_tiled_copy_Mask, tSsMask, tSrMask_copy_view); - // Scale attention scores and apply mask - mask.template apply_mask( - acc_s, tSrMask, /*bias=*/nullptr, params.scale_softmax, - n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16 - ); - } else if constexpr (!Has_mask && Has_bias) { - // Copy bias from smem to registers - Tensor tSrBias = make_tensor(shape(acc_s)); - Tensor tSrBias_copy_view = smem_thr_copy_Bias.retile_D(tSrBias); - cute::copy(smem_tiled_copy_Bias, tSsBias, tSrBias_copy_view); - - // Scale attention scores and add bias - mask.template apply_mask( - acc_s, /*mask=*/nullptr, tSrBias, params.scale_softmax, + mask.template apply_mask( + acc_s, tSrMask, n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16 ); } else { - // Scale attention scores only - mask.template apply_mask( - acc_s, /*mask=*/nullptr, /*bias=*/nullptr, params.scale_softmax, + mask.template apply_mask( + acc_s, /*mask=*/nullptr, n_block * kBlockN, m_block * kBlockM + (tidx / 
32) * 16 + (tidx % 32) / 4, kNWarps * 16 ); } @@ -616,10 +600,22 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi // These are the iterations where we don't need masking on S for (; n_block >= n_block_min; --n_block) { Tensor acc_s = partition_fragment_C(tiled_mma, Shape, Int>{}); // (MMA=4, MMA_M, MMA_N) - clear(acc_s); FLASH_NAMESPACE::cp_async_wait<0>(); __syncthreads(); + if (any_active) { + if constexpr (Has_bias) { + // Copy bias from smem to acc_s registers + Tensor tSrBias = make_tensor(shape(acc_s)); + Tensor tSrBias_copy_view = smem_thr_copy_Bias.retile_D(tSrBias); + cute::copy(smem_tiled_copy_Bias, tSsBias, tSrBias_copy_view); + #pragma unroll + for (int i = 0; i < size(acc_s); ++i) { acc_s(i) = tSrBias(i); } + } else { + clear(acc_s); + } + } + // Advance gV if (any_active) { FLASH_NAMESPACE::copy( @@ -642,46 +638,19 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi FLASH_NAMESPACE::apply_softcap(acc_s, params.softcap); } - if constexpr (Has_mask && Has_bias) { - // Copy mask and bias from smem to registers - Tensor tSrMask = make_tensor(shape(acc_s)); - Tensor tSrMask_copy_view = smem_thr_copy_Mask.retile_D(tSrMask); - cute::copy(smem_tiled_copy_Mask, tSsMask, tSrMask_copy_view); - Tensor tSrBias = make_tensor(shape(acc_s)); - Tensor tSrBias_copy_view = smem_thr_copy_Bias.retile_D(tSrBias); - cute::copy(smem_tiled_copy_Bias, tSsBias, tSrBias_copy_view); - - // Scale attention scores and apply mask and add bias - mask.template apply_mask( - acc_s, tSrMask, tSrBias, params.scale_softmax, - n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16 - ); - } else if constexpr (Has_mask && !Has_bias) { + if constexpr (Has_mask) { // Copy mask from smem to registers Tensor tSrMask = make_tensor(shape(acc_s)); Tensor tSrMask_copy_view = smem_thr_copy_Mask.retile_D(tSrMask); cute::copy(smem_tiled_copy_Mask, tSsMask, tSrMask_copy_view); - // Scale attention scores and apply mask - mask.template apply_mask( - acc_s, tSrMask, /*bias=*/nullptr, params.scale_softmax, - n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16 - ); - } else if constexpr (!Has_mask && Has_bias) { - // Copy bias from smem to registers - Tensor tSrBias = make_tensor(shape(acc_s)); - Tensor tSrBias_copy_view = smem_thr_copy_Bias.retile_D(tSrBias); - cute::copy(smem_tiled_copy_Bias, tSsBias, tSrBias_copy_view); - - // Scale attention scores and apply bias - mask.template apply_mask( - acc_s, /*mask=*/nullptr, tSrBias, params.scale_softmax, + mask.template apply_mask( + acc_s, tSrMask, n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16 ); } else { - // Scale attention scores only - mask.template apply_mask( - acc_s, /*mask=*/nullptr, /*bias=*/nullptr, params.scale_softmax, + mask.template apply_mask( + acc_s, /*mask=*/nullptr, n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16 ); } From 2e87f4ad111c97639512c90b7f7363fa089b1377 Mon Sep 17 00:00:00 2001 From: LoserCheems Date: Tue, 4 Nov 2025 00:38:15 +0800 Subject: [PATCH 04/11] Improves attn block setup Pre-scales query tiles before streaming to cut redundant softmax multiplications. Initializes accumulators from shared bias when active so mask paths can skip extra clears. Simplifies mask application by dropping per-iteration bias scaling logic. 
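With the softmax scale folded into Q and the bias folded into the accumulator, the masking helper's only remaining job is to write -INFINITY wherever the column limit or the boolean mask rules an entry out. A host-side sketch of that reduced contract (illustrative names and a dense row-major layout; the real helper walks MMA fragment coordinates):

    #include <cmath>
    #include <cstdio>
    #include <vector>

    void apply_mask_ref(std::vector<float>& scores,      // row-major [rows x cols]
                        const std::vector<bool>& keep,   // same shape; true = keep
                        int rows, int cols,
                        bool causal, int seqlen_k, int seqlen_q) {
        for (int r = 0; r < rows; ++r) {
            // Same column limit the kernel uses for the causal case.
            const int col_limit = causal ? std::min(seqlen_k, r + 1 + seqlen_k - seqlen_q)
                                         : seqlen_k;
            for (int c = 0; c < cols; ++c) {
                if (c >= col_limit || !keep[r * cols + c]) {
                    scores[r * cols + c] = -INFINITY;   // masking is the only mutation left
                }
            }
        }
    }

    int main() {
        const int rows = 2, cols = 4;
        std::vector<float> s = {1, 2, 3, 4, 5, 6, 7, 8};
        std::vector<bool> keep = {true, true, false, true, true, true, true, true};
        apply_mask_ref(s, keep, rows, cols, /*causal=*/true, /*seqlen_k=*/4, /*seqlen_q=*/2);
        for (int r = 0; r < rows; ++r) {
            for (int c = 0; c < cols; ++c) std::printf("%6.1f ", s[r * cols + c]);
            std::printf("\n");
        }
        return 0;
    }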
--- csrc/flash_dmattn/src/flash_fwd_kernel.h | 108 +++++++++-------------- 1 file changed, 42 insertions(+), 66 deletions(-) diff --git a/csrc/flash_dmattn/src/flash_fwd_kernel.h b/csrc/flash_dmattn/src/flash_fwd_kernel.h index 074dfc9..83e878a 100644 --- a/csrc/flash_dmattn/src/flash_fwd_kernel.h +++ b/csrc/flash_dmattn/src/flash_fwd_kernel.h @@ -1140,6 +1140,12 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons binfo.actual_seqlen_k, binfo.actual_seqlen_q ); + // Scale Q once before streaming loop KV + #pragma unroll + for (int i = 0; i < size(tSrQ); ++i) { + tSsQ(i) = static_cast(tSsQ(i) * params.scale_softmax); + } + // For performance reason, we separate out two kinds of iterations: // those that need masking on S, and those that don't. // We need masking on S for the very last block when K and V has length not multiple of kBlockN. @@ -1155,10 +1161,22 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons #pragma unroll for (int masking_step = 0; masking_step < n_masking_steps; ++masking_step, --n_block) { Tensor acc_s = partition_fragment_C(tiled_mma, Shape, Int>{}); // (MMA=4, MMA_M, MMA_N) - clear(acc_s); FLASH_NAMESPACE::cp_async_wait<0>(); __syncthreads(); + if (any_active) { + if constexpr (Has_bias) { + // Copy bias from smem to acc_s registers + Tensor tSrBias = make_tensor(shape(acc_s)); + Tensor tSrBias_copy_view = smem_thr_copy_Bias.retile_D(tSrBias); + cute::copy(smem_tiled_copy_Bias, tSsBias, tSrBias_copy_view); + #pragma unroll + for (int i = 0; i < size(acc_s); ++i) { acc_s(i) = tSrBias(i); } + } else { + clear(acc_s); + } + } + // Advance gV if (masking_step > 0) { if (block_table == nullptr) { @@ -1202,46 +1220,19 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons FLASH_NAMESPACE::apply_softcap(acc_s, params.softcap); } - if constexpr (Has_mask && Has_bias) { - // Copy mask and bias from smem to registers - Tensor tSrMask = make_tensor(shape(acc_s)); - Tensor tSrMask_copy_view = smem_thr_copy_Mask.retile_D(tSrMask); - cute::copy(smem_tiled_copy_Mask, tSsMask, tSrMask_copy_view); - Tensor tSrBias = make_tensor(shape(acc_s)); - Tensor tSrBias_copy_view = smem_thr_copy_Bias.retile_D(tSrBias); - cute::copy(smem_tiled_copy_Bias, tSsBias, tSrBias_copy_view); - - // Scale attention scores and apply mask and bias - mask.template apply_mask( - acc_s, tSrMask, tSrBias, params.scale_softmax, - n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16 - ); - } else if constexpr (Has_mask && !Has_bias) { + if constexpr (Has_mask) { // Copy mask from smem to registers Tensor tSrMask = make_tensor(shape(acc_s)); Tensor tSrMask_copy_view = smem_thr_copy_Mask.retile_D(tSrMask); cute::copy(smem_tiled_copy_Mask, tSsMask, tSrMask_copy_view); - // Scale attention scores and apply mask - mask.template apply_mask( - acc_s, tSrMask, /*bias=*/nullptr, params.scale_softmax, - n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16 - ); - } else if constexpr (!Has_mask && Has_bias) { - // Copy bias from smem to registers - Tensor tSrBias = make_tensor(shape(acc_s)); - Tensor tSrBias_copy_view = smem_thr_copy_Bias.retile_D(tSrBias); - cute::copy(smem_tiled_copy_Bias, tSsBias, tSrBias_copy_view); - - // Scale attention scores and add bias - mask.template apply_mask( - acc_s, /*mask=*/nullptr, tSrBias, params.scale_softmax, + mask.template apply_mask( + acc_s, tSrMask, n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 
16 ); } else { - // Scale attention scores only - mask.template apply_mask( - acc_s, /*mask=*/nullptr, /*bias=*/nullptr, params.scale_softmax, + mask.template apply_mask( + acc_s, /*mask=*/nullptr, n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16 ); } @@ -1351,10 +1342,22 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons // These are the iterations where we don't need masking on S for (; n_block >= n_block_min; --n_block) { Tensor acc_s = partition_fragment_C(tiled_mma, Shape, Int>{}); // (MMA=4, MMA_M, MMA_N) - clear(acc_s); FLASH_NAMESPACE::cp_async_wait<0>(); __syncthreads(); + if (any_active) { + if constexpr (Has_bias) { + // Copy bias from smem to acc_s registers + Tensor tSrBias = make_tensor(shape(acc_s)); + Tensor tSrBias_copy_view = smem_thr_copy_Bias.retile_D(tSrBias); + cute::copy(smem_tiled_copy_Bias, tSsBias, tSrBias_copy_view); + #pragma unroll + for (int i = 0; i < size(acc_s); ++i) { acc_s(i) = tSrBias(i); } + } else { + clear(acc_s); + } + } + // Advance gV if (block_table == nullptr) { tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride)); @@ -1386,46 +1389,19 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons FLASH_NAMESPACE::apply_softcap(acc_s, params.softcap); } - if constexpr (Has_mask && Has_bias) { - // Copy mask and bias from smem to registers - Tensor tSrMask = make_tensor(shape(acc_s)); - Tensor tSrMask_copy_view = smem_thr_copy_Mask.retile_D(tSrMask); - cute::copy(smem_tiled_copy_Mask, tSsMask, tSrMask_copy_view); - Tensor tSrBias = make_tensor(shape(acc_s)); - Tensor tSrBias_copy_view = smem_thr_copy_Bias.retile_D(tSrBias); - cute::copy(smem_tiled_copy_Bias, tSsBias, tSrBias_copy_view); - - // Scale attention scores and apply mask and bias - mask.template apply_mask( - acc_s, tSrMask, tSrBias, params.scale_softmax, - n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16 - ); - } else if constexpr (Has_mask && !Has_bias) { + if constexpr (Has_mask) { // Copy mask from smem to registers Tensor tSrMask = make_tensor(shape(acc_s)); Tensor tSrMask_copy_view = smem_thr_copy_Mask.retile_D(tSrMask); cute::copy(smem_tiled_copy_Mask, tSsMask, tSrMask_copy_view); - // Scale attention scores and apply mask - mask.template apply_mask( - acc_s, tSrMask, /*bias=*/nullptr, params.scale_softmax, - n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16 - ); - } else if constexpr (!Has_mask && Has_bias) { - // Copy bias from smem to registers - Tensor tSrBias = make_tensor(shape(acc_s)); - Tensor tSrBias_copy_view = smem_thr_copy_Bias.retile_D(tSrBias); - cute::copy(smem_tiled_copy_Bias, tSsBias, tSrBias_copy_view); - - // Scale attention scores and add bias - mask.template apply_mask( - acc_s, /*mask=*/nullptr, tSrBias, params.scale_softmax, + mask.template apply_mask( + acc_s, tSrMask, n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16 ); } else { - // Scale attention scores only - mask.template apply_mask( - acc_s, /*mask=*/nullptr, /*bias=*/nullptr, params.scale_softmax, + mask.template apply_mask( + acc_s, /*mask=*/nullptr, n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16 ); } From deffed77e3a7dcafb882df56932ea9c19dc0bb1a Mon Sep 17 00:00:00 2001 From: LoserCheems Date: Tue, 4 Nov 2025 01:00:32 +0800 Subject: [PATCH 05/11] Adds softmax unscale parameter Provides an inverse scaling factor so kernels can reuse precomputed 
softmax adjustments instead of recomputing them --- csrc/flash_dmattn/src/flash.h | 1 + 1 file changed, 1 insertion(+) diff --git a/csrc/flash_dmattn/src/flash.h b/csrc/flash_dmattn/src/flash.h index a1c9bf1..1a2740a 100644 --- a/csrc/flash_dmattn/src/flash.h +++ b/csrc/flash_dmattn/src/flash.h @@ -101,6 +101,7 @@ struct Flash_fwd_params : public QKV_params, public Mask_params, public Bias_par // The scaling factors for the kernel. float scale_softmax; float scale_softmax_log2; + float unscale_softmax; float softcap; // array of length b+1 holding starting offset of each sequence. From c65c642f5302e4714227bf994189575d2ab5d291 Mon Sep 17 00:00:00 2001 From: LoserCheems Date: Tue, 4 Nov 2025 01:00:52 +0800 Subject: [PATCH 06/11] Sets inverse softmax scaling Ensures the reciprocal scale is always populated so downstream kernels can undo the softmax amplification without branching. --- csrc/flash_dmattn/flash_api.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/csrc/flash_dmattn/flash_api.cpp b/csrc/flash_dmattn/flash_api.cpp index 4a67ec1..fa5a9b1 100644 --- a/csrc/flash_dmattn/flash_api.cpp +++ b/csrc/flash_dmattn/flash_api.cpp @@ -132,11 +132,13 @@ void set_params_fprop( params.softcap = softmax_scale / softcap; params.scale_softmax = softcap; params.scale_softmax_log2 = softcap * M_LOG2E; + params.unscale_softmax = 1.0f / softmax_scale; } else{ // Remove potential NaN params.softcap = 0.0; params.scale_softmax = softmax_scale; params.scale_softmax_log2 = softmax_scale * M_LOG2E; + params.unscale_softmax = 1.0f / softmax_scale; } params.is_causal = is_causal; From df74af5e150a21a2e067b05576c5cfbdd4b287db Mon Sep 17 00:00:00 2001 From: LoserCheems Date: Tue, 4 Nov 2025 01:01:28 +0800 Subject: [PATCH 07/11] Simplifies mask application logic Drops bias and scale handling from the masking helper so upstream code owns those adjustments, preventing duplicated math. --- csrc/flash_dmattn/src/mask.h | 74 ++++++------------------------------ 1 file changed, 11 insertions(+), 63 deletions(-) diff --git a/csrc/flash_dmattn/src/mask.h b/csrc/flash_dmattn/src/mask.h index 62c02c1..9b0d81c 100644 --- a/csrc/flash_dmattn/src/mask.h +++ b/csrc/flash_dmattn/src/mask.h @@ -11,12 +11,10 @@ namespace FLASH_NAMESPACE { using namespace cute; -template +template __forceinline__ __device__ void apply_mask( TensorType &tensor, const MaskType &mask, - const BiasType &bias, - const float scale_softmax, const int col_idx_offset_, const int max_seqlen_k, const int row_idx_offset, @@ -27,13 +25,11 @@ __forceinline__ __device__ void apply_mask( static_assert(TensorType::rank == 2, "Only support 2D Tensor"); if constexpr (Has_mask) static_assert(MaskType::rank == 2, "Only support 2D Mask"); - if constexpr (Has_bias) - static_assert(BiasType::rank == 2, "Only support 2D Bias"); const int lane_id = threadIdx.x % 32; const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2; - if constexpr (Has_mask && Has_bias) { + if constexpr (Has_mask) { #pragma unroll for (int mi = 0; mi < size<0, 1>(tensor); ++mi) { const int row_idx_base = row_idx_offset + mi * warp_row_stride; @@ -49,63 +45,15 @@ __forceinline__ __device__ void apply_mask( const int col_idx = col_idx_base + j; // Without the "make_coord" we get wrong results auto coord = make_coord(make_coord(i, mi), make_coord(j, nj)); - // Apply scaling and bias or masking - tensor(coord) = (col_idx >= col_idx_limit) || (!mask(coord)) - ? 
-INFINITY - : tensor(coord) * scale_softmax + bias(coord); - } - } - } - } - } else if constexpr (Has_mask && !Has_bias) { - #pragma unroll - for (int mi = 0; mi < size<0, 1>(tensor); ++mi) { - const int row_idx_base = row_idx_offset + mi * warp_row_stride; - #pragma unroll - for (int i = 0; i < size<0, 0>(tensor); ++i) { - const int row_idx = row_idx_base + i * 8; - const int col_idx_limit = Causal_mask ? std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q) : max_seqlen_k; - #pragma unroll - for (int nj = 0; nj < size<1, 1>(tensor); ++nj) { - const int col_idx_base = col_idx_offset + nj * 8; - #pragma unroll - for (int j = 0; j < size<1, 0>(tensor); ++j) { - const int col_idx = col_idx_base + j; - // Without the "make_coord" we get wrong results - auto coord = make_coord(make_coord(i, mi), make_coord(j, nj)); - // Apply scaling or masking - tensor(coord) = (col_idx >= col_idx_limit) || (!mask(coord)) - ? -INFINITY - : tensor(coord) * scale_softmax; - } - } - } - } - } else if constexpr (!Has_mask && Has_bias) { - #pragma unroll - for (int mi = 0; mi < size<0, 1>(tensor); ++mi) { - const int row_idx_base = row_idx_offset + mi * warp_row_stride; - #pragma unroll - for (int i = 0; i < size<0, 0>(tensor); ++i) { - const int row_idx = row_idx_base + i * 8; - const int col_idx_limit = Causal_mask ? std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q) : max_seqlen_k; - #pragma unroll - for (int nj = 0; nj < size<1, 1>(tensor); ++nj) { - const int col_idx_base = col_idx_offset + nj * 8; - #pragma unroll - for (int j = 0; j < size<1, 0>(tensor); ++j) { - const int col_idx = col_idx_base + j; - // Without the "make_coord" we get wrong results - auto coord = make_coord(make_coord(i, mi), make_coord(j, nj)); - // Apply scaling and bias - tensor(coord) = (col_idx >= col_idx_limit) - ? -INFINITY - : tensor(coord) * scale_softmax + bias(coord); + // Apply masking + if (col_idx >= col_idx_limit || !mask(coord)) { + tensor(coord) = -INFINITY; + } } } } } - } else { // !Has_mask && !Has_bias + } else { #pragma unroll for (int mi = 0; mi < size<0, 1>(tensor); ++mi) { const int row_idx_base = row_idx_offset + mi * warp_row_stride; @@ -121,10 +69,10 @@ __forceinline__ __device__ void apply_mask( const int col_idx = col_idx_base + j; // Without the "make_coord" we get wrong results auto coord = make_coord(make_coord(i, mi), make_coord(j, nj)); - // Apply scaling - tensor(coord) = (col_idx >= col_idx_limit) - ? -INFINITY - : tensor(coord) * scale_softmax; + // Apply masking + if (col_idx >= col_idx_limit) { + tensor(coord) = -INFINITY; + } } } } From fefb7a96114ee33f2362c6319fc2a6390ce57e7e Mon Sep 17 00:00:00 2001 From: LoserCheems Date: Tue, 4 Nov 2025 14:44:58 +0800 Subject: [PATCH 08/11] Synchronizes async preload before Q scaling Waits for outstanding async loads and syncs threads so Q scaling never races ahead of shared-memory tiles. 
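A stripped-down CUDA sketch of the ordering this patch enforces (plain loads stand in for the cp.async pipeline; names are illustrative): the tile must be fully staged in shared memory before any thread scales it in place, because the element a thread scales is generally not the one it staged.

    #include <cstdio>
    #include <cuda_runtime.h>

    constexpr int TILE = 128;

    __global__ void scale_q_in_smem(const float* q, float* out, float scale) {
        __shared__ float sQ[TILE];
        const int t = threadIdx.x;

        // Stage the tile with a permuted pattern, standing in for the async-copy
        // layout: the element a thread later scales was usually staged by another thread.
        const int src = (t * 33) % TILE;
        sQ[src] = q[src];
        __syncthreads();        // staging must complete before the in-place scale

        sQ[t] *= scale;         // scale Q exactly once, in shared memory
        __syncthreads();        // barrier before the consumer (here a trivial copy;
                                // in the kernel, the Q*K^T GEMM reads sQ across threads)

        out[t] = sQ[t];
    }

    int main() {
        float h_q[TILE], h_out[TILE];
        for (int i = 0; i < TILE; ++i) h_q[i] = float(i);

        float *d_q = nullptr, *d_out = nullptr;
        cudaMalloc(&d_q, sizeof(h_q));
        cudaMalloc(&d_out, sizeof(h_out));
        cudaMemcpy(d_q, h_q, sizeof(h_q), cudaMemcpyHostToDevice);

        scale_q_in_smem<<<1, TILE>>>(d_q, d_out, 0.125f);
        cudaMemcpy(h_out, d_out, sizeof(h_out), cudaMemcpyDeviceToHost);

        std::printf("out[10] = %g (expect %g)\n", h_out[10], 10.0f * 0.125f);
        cudaFree(d_q);
        cudaFree(d_out);
        return 0;
    }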
--- csrc/flash_dmattn/src/flash_fwd_kernel.h | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/csrc/flash_dmattn/src/flash_fwd_kernel.h b/csrc/flash_dmattn/src/flash_fwd_kernel.h index 83e878a..9f7ca6b 100644 --- a/csrc/flash_dmattn/src/flash_fwd_kernel.h +++ b/csrc/flash_dmattn/src/flash_fwd_kernel.h @@ -422,6 +422,9 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi binfo.actual_seqlen_k, binfo.actual_seqlen_q ); + FLASH_NAMESPACE::cp_async_wait<0>(); + __syncthreads(); + // Scale Q once before streaming loop KV if constexpr (Kernel_traits::Is_Q_in_regs) { #pragma unroll @@ -1140,6 +1143,9 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons binfo.actual_seqlen_k, binfo.actual_seqlen_q ); + FLASH_NAMESPACE::cp_async_wait<0>(); + __syncthreads(); + // Scale Q once before streaming loop KV #pragma unroll for (int i = 0; i < size(tSrQ); ++i) { From f3102e198f808635feee212db68c54150ee39942 Mon Sep 17 00:00:00 2001 From: LoserCheems Date: Tue, 4 Nov 2025 16:08:35 +0800 Subject: [PATCH 09/11] Stops redundant dQ scaling Prevents applying the softmax factor twice in the backward preprocessing so downstream gradients stay correctly scaled. --- csrc/flash_dmattn/src/flash_bwd_preprocess_kernel.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/csrc/flash_dmattn/src/flash_bwd_preprocess_kernel.h b/csrc/flash_dmattn/src/flash_bwd_preprocess_kernel.h index d6710b2..4e550ff 100644 --- a/csrc/flash_dmattn/src/flash_bwd_preprocess_kernel.h +++ b/csrc/flash_dmattn/src/flash_bwd_preprocess_kernel.h @@ -279,8 +279,8 @@ inline __device__ void convert_dQ( for (int i = 0; i < size(acc_dq); ++i) { acc_dq(i) += tdQrdQaccum(i); } tdQgdQaccum.data() = tdQgdQaccum.data() + params.dq_accum_split_stride; } - #pragma unroll - for (int i = 0; i < size(acc_dq); ++i) { acc_dq(i) *= params.scale_softmax; } + // #pragma unroll + // for (int i = 0; i < size(acc_dq); ++i) { acc_dq(i) *= params.scale_softmax; } // Convert acc_dq from fp32 to fp16 Tensor rdQ = FLASH_NAMESPACE::convert_type(acc_dq); Tensor taccdQrdQ = smem_thr_copy_dQ.retile_S(rdQ); // ((Atom, AtomNum), MMA_N, MMA_N) From 77c9e80a271aa652965aa0b225aabbbf2ccc7b5f Mon Sep 17 00:00:00 2001 From: LoserCheems Date: Tue, 4 Nov 2025 16:10:00 +0800 Subject: [PATCH 10/11] Refines bias/mask prep in flash bwd Pre-scales the keys right after synchronization so later matmul steps reuse the scaled values and hide latency. Unifies the mask and bias hydration before streaming to keep accumulators coherent and drops the now redundant gradient scaling. 
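The dropped gradient scaling relies on linearity: dQ accumulates dS * K, so multiplying K by the softmax scale once after it lands in shared memory yields the same dQ as scaling the accumulator at the end, and the same scaled K also serves the recomputed scores. A host-side sketch for a single query row (toy sizes, illustrative names):

    #include <cassert>
    #include <cmath>
    #include <cstdio>

    int main() {
        constexpr int N = 3, D = 4;          // kv length, head dim (toy sizes)
        const float scale = 0.7071f;
        float dS[N] = {0.2f, -0.5f, 0.3f};   // one row of dS for a single query
        float K[N][D] = {{1.0f, 2.0f, 3.0f, 4.0f},
                         {-1.0f, 0.5f, 0.0f, 2.0f},
                         {0.3f, -0.2f, 1.5f, -1.0f}};

        for (int d = 0; d < D; ++d) {
            // Old path: accumulate dS * K, then multiply the accumulator by scale.
            float dq_post = 0.0f;
            for (int n = 0; n < N; ++n) dq_post += dS[n] * K[n][d];
            dq_post *= scale;

            // New path: K was scaled once after synchronization, accumulator left as-is.
            float dq_pre = 0.0f;
            for (int n = 0; n < N; ++n) dq_pre += dS[n] * (K[n][d] * scale);

            assert(std::fabs(dq_post - dq_pre) < 1e-5f);
        }
        std::printf("pre-scaled K gives the same dQ as post-scaling the accumulator\n");
        return 0;
    }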
--- csrc/flash_dmattn/src/flash_bwd_kernel.h | 87 +++++++++--------------- 1 file changed, 31 insertions(+), 56 deletions(-) diff --git a/csrc/flash_dmattn/src/flash_bwd_kernel.h b/csrc/flash_dmattn/src/flash_bwd_kernel.h index 9a8b534..94f4470 100644 --- a/csrc/flash_dmattn/src/flash_bwd_kernel.h +++ b/csrc/flash_dmattn/src/flash_bwd_kernel.h @@ -644,17 +644,37 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in clear(acc_dv); if constexpr (Has_bias) { if (accum_dbias) { clear(acc_dbias); } } + cute::cp_async_wait<0>(); + __syncthreads(); + + // Scale K once before streaming loop Q + #pragma unroll + for (int k = 0; k < size(tKsK); ++k) { + tKsK(k) = static_cast(tKsK(k) * params.scale_softmax); + } + for (; m_block >= m_block_min; --m_block) { Tensor acc_s = partition_fragment_C(tiled_mma_sdp, Shape, Int>{}); // (MMA=4, MMA_M, MMA_N) - cute::cp_async_wait<0>(); - __syncthreads(); - Tensor acc_dp = partition_fragment_C(tiled_mma_sdp, Shape, Int>{}); // (MMA=4, MMA_M, MMA_N) Tensor acc_dq = partition_fragment_C(tiled_mma_dq, Shape, Int>{}); // (MMA=4, MMA_M, MMA_K) + cute::cp_async_wait<0>(); + __syncthreads(); if (any_active) { - clear(acc_s); + if constexpr (Has_bias) { + // Copy bias from smem to registers + Tensor tSrBias = make_tensor(shape(acc_s)); + Tensor tSrBias_copy_view = smem_thr_copy_PdS.retile_D(tSrBias); + cute::copy(smem_tiled_copy_PdS, tSsBias, tSrBias_copy_view); + #pragma unroll + for (int i = 0; i < size(acc_s); ++i) { acc_s(i) = tSrBias(i); } + } else { + clear(acc_s); + } + } + + if (any_active) { Tensor dP_sum = make_fragment_like(lse); #pragma unroll for (int mi = 0; mi < size(lse); ++mi) { dP_sum(mi) = gdPsum(get<0>(taccScS_row(mi))); } @@ -686,35 +706,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in FLASH_NAMESPACE::calculate_dtanh(scores, dtanh, params.softcap); } - if constexpr (Has_mask && Has_bias) { - // Copy mask and bias from smem to registers - Tensor tSrMask = make_tensor(shape(acc_s)); - Tensor tSrMask_copy_view = smem_thr_copy_PdS.retile_D(tSrMask); - cute::copy(smem_tiled_copy_PdS, tSsMask, tSrMask_copy_view); - Tensor tSrBias = make_tensor(shape(acc_s)); - Tensor tSrBias_copy_view = smem_thr_copy_PdS.retile_D(tSrBias); - cute::copy(smem_tiled_copy_PdS, tSsBias, tSrBias_copy_view); - - // Reshape mask, bias from (MMA=4, MMA_M, MMA_N) to (row=(2, MMA_M), col=(2, MMA_N)) - Tensor mask = make_tensor(tSrMask.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(tSrMask.layout())); - Tensor bias = make_tensor(tSrBias.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(tSrBias.layout())); - - // TD [2023-07-29]: I was thinking that we don't need to mask out the elements beyond - // actual_seqlen_k, because acc_s would be some finite value for those indices. - // In the end when we multiply with K to get dQ, the corresponding values of K would be 0, - // so the result would still be correct. - // However, it's possible that the values in acc_s are so large that they overflow - // when we multiply with dP and convert to fp16, resulting in Inf in dS and NaNs in dQ. - // So we need to mask out the elements beyond actual_seqlen_k. 
- FLASH_NAMESPACE::apply_mask( - scores, mask, bias, params.scale_softmax, - n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16, - binfo.actual_seqlen_k, - m_block * kBlockM + get<0>(taccScS_row(0)), - binfo.actual_seqlen_q, - AtomLayoutMS * 16 - ); - } else if constexpr (Has_mask && !Has_bias) { + if constexpr (Has_mask) { // Copy mask from smem to registers Tensor tSrMask = make_tensor(shape(acc_s)); Tensor tSrMask_copy_view = smem_thr_copy_PdS.retile_D(tSrMask); @@ -722,26 +714,9 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in // Reshape mask from (MMA=4, MMA_M, MMA_N) to (row=(2, MMA_M), col=(2, MMA_N)) Tensor mask = make_tensor(tSrMask.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(tSrMask.layout())); - - FLASH_NAMESPACE::apply_mask( - scores, mask, /*bias=*/nullptr, params.scale_softmax, - n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16, - binfo.actual_seqlen_k, - m_block * kBlockM + get<0>(taccScS_row(0)), - binfo.actual_seqlen_q, - AtomLayoutMS * 16 - ); - } else if constexpr (!Has_mask && Has_bias) { - // Copy bias from smem to registers - Tensor tSrBias = make_tensor(shape(acc_s)); - Tensor tSrBias_copy_view = smem_thr_copy_PdS.retile_D(tSrBias); - cute::copy(smem_tiled_copy_PdS, tSsBias, tSrBias_copy_view); - // Reshape bias from (MMA=4, MMA_M, MMA_N) to (row=(2, MMA_M), col=(2, MMA_N)) - Tensor bias = make_tensor(tSrBias.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(tSrBias.layout())); - - FLASH_NAMESPACE::apply_mask( - scores, /*mask=*/nullptr, bias, params.scale_softmax, + FLASH_NAMESPACE::apply_mask( + scores, mask, n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16, binfo.actual_seqlen_k, m_block * kBlockM + get<0>(taccScS_row(0)), @@ -749,8 +724,8 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in AtomLayoutMS * 16 ); } else { - FLASH_NAMESPACE::apply_mask( - scores, /*mask=*/nullptr, /*bias=*/nullptr, params.scale_softmax, + FLASH_NAMESPACE::apply_mask( + scores, /*mask=*/nullptr, n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16, binfo.actual_seqlen_k, m_block * kBlockM + get<0>(taccScS_row(0)), @@ -965,8 +940,8 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in for (int i = 0; i < size(acc_dq); ++i) { atomicAdd(&tdQgdQaccum(i), acc_dq(i)); } } } else { - #pragma unroll - for (int i = 0; i < size(acc_dq); ++i) { acc_dq(i) *= params.scale_softmax; } + // #pragma unroll + // for (int i = 0; i < size(acc_dq); ++i) { acc_dq(i) *= params.scale_softmax; } // Convert acc_dq from fp32 to fp16 Tensor rdQ = FLASH_NAMESPACE::convert_type(acc_dq); Tensor taccdQrdQ = smem_thr_copy_dQ.retile_S(rdQ); // ((Atom, AtomNum), MMA_M, MMA_K) From 2c35c89f05510d70d427cb53f6f77adda21b00ca Mon Sep 17 00:00:00 2001 From: LoserCheems Date: Tue, 4 Nov 2025 16:11:23 +0800 Subject: [PATCH 11/11] Drops redundant softmax unscale Removes the unused reverse scaling parameter from the forward configuration to avoid stale values when softcap toggles. 
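A distilled version of the two branches in the hunk below (host-side, illustrative values) shows why the removed field was easy to leave stale: under softcap the applied score scale becomes softcap, while unscale_softmax kept tracking 1 / softmax_scale, so the stored "inverse" no longer inverted the scale actually in use.

    #include <cmath>
    #include <cstdio>
    #include <initializer_list>

    int main() {
        const float softmax_scale = 0.125f;

        for (float softcap : {0.0f, 30.0f}) {
            float scale_softmax, unscale_softmax;
            if (softcap > 0.0f) {
                scale_softmax = softcap;                 // scores are scaled by softcap here
                unscale_softmax = 1.0f / softmax_scale;  // ...but the reciprocal tracked softmax_scale
            } else {
                scale_softmax = softmax_scale;
                unscale_softmax = 1.0f / softmax_scale;
            }
            const bool inverts = std::fabs(scale_softmax * unscale_softmax - 1.0f) < 1e-6f;
            std::printf("softcap=%g: unscale inverts scale? %s\n",
                        softcap, inverts ? "yes" : "no");
        }
        return 0;
    }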
--- csrc/flash_dmattn/flash_api.cpp | 2 -- csrc/flash_dmattn/src/flash.h | 1 - 2 files changed, 3 deletions(-) diff --git a/csrc/flash_dmattn/flash_api.cpp b/csrc/flash_dmattn/flash_api.cpp index fa5a9b1..4a67ec1 100644 --- a/csrc/flash_dmattn/flash_api.cpp +++ b/csrc/flash_dmattn/flash_api.cpp @@ -132,13 +132,11 @@ void set_params_fprop( params.softcap = softmax_scale / softcap; params.scale_softmax = softcap; params.scale_softmax_log2 = softcap * M_LOG2E; - params.unscale_softmax = 1.0f / softmax_scale; } else{ // Remove potential NaN params.softcap = 0.0; params.scale_softmax = softmax_scale; params.scale_softmax_log2 = softmax_scale * M_LOG2E; - params.unscale_softmax = 1.0f / softmax_scale; } params.is_causal = is_causal; diff --git a/csrc/flash_dmattn/src/flash.h b/csrc/flash_dmattn/src/flash.h index 1a2740a..a1c9bf1 100644 --- a/csrc/flash_dmattn/src/flash.h +++ b/csrc/flash_dmattn/src/flash.h @@ -101,7 +101,6 @@ struct Flash_fwd_params : public QKV_params, public Mask_params, public Bias_par // The scaling factors for the kernel. float scale_softmax; float scale_softmax_log2; - float unscale_softmax; float softcap; // array of length b+1 holding starting offset of each sequence.