benchmarks/benchmark_forward_equivalence.py (2 changes: 1 addition & 1 deletion)
@@ -73,7 +73,7 @@ def prepare_dynamic_mask(
attn_mask = attn_mask.scatter(-1, topk_indices, 1.0)
attn_bias = attn_bias.masked_fill(attn_mask == 0.0, min_dtype)
else:
- attn_mask = torch.ones_like(attn_mask, dtype=dtype, device=attn_mask.device)
+ attn_mask = torch.ones_like(attn_bias, dtype=dtype, device=attn_bias.device)
Copilot AI (Jul 3, 2025):

Initializing attn_mask from attn_bias may yield an incorrect shape if the two tensors differ. Consider using torch.ones_like(attn_mask) or explicitly matching dimensions to ensure the mask has the intended shape.

Suggested change
- attn_mask = torch.ones_like(attn_bias, dtype=dtype, device=attn_bias.device)
+ attn_mask = torch.ones(
+     attn_bias.shape, dtype=dtype, device=attn_bias.device
+ )

return attn_bias, attn_mask
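A minimal sketch of the shape concern raised in the review comment above. The tensor dimensions and dtype here are assumed purely for illustration and are not taken from the benchmark script:

```python
import torch

# Assumed toy dimensions; the real benchmark uses its own shapes.
batch, heads, q_len, k_len = 1, 2, 4, 8
dtype = torch.float32
attn_bias = torch.zeros(batch, heads, q_len, k_len, dtype=dtype)

# Patched line: the mask's shape silently follows attn_bias.
attn_mask = torch.ones_like(attn_bias, dtype=dtype, device=attn_bias.device)

# Reviewer's suggestion: the shape is stated explicitly at the call site,
# so any mismatch with the intended mask shape is visible where it is built.
attn_mask_explicit = torch.ones(attn_bias.shape, dtype=dtype, device=attn_bias.device)

assert attn_mask.shape == attn_mask_explicit.shape == (batch, heads, q_len, k_len)
```

Both forms are equivalent as long as attn_bias already has the intended mask shape; the suggestion only makes that assumption explicit.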


csrc/src/flash.h (14 changes: 7 additions & 7 deletions)
@@ -167,18 +167,18 @@ struct Flash_fwd_params : public QKV_params, public Mask_params, public Bias_par

struct Flash_bwd_params : public Flash_fwd_params {

- // The dO and dQKV and dZeroHold matrices.
+ // The dO and dQKV and dBias matrices.
void *__restrict__ do_ptr;
void *__restrict__ dq_ptr;
void *__restrict__ dk_ptr;
void *__restrict__ dv_ptr;
- void *__restrict__ dzoh_ptr;
+ void *__restrict__ dbias_ptr;

// To accumulate dQ
void *__restrict__ dq_accum_ptr;
void *__restrict__ dk_accum_ptr;
void *__restrict__ dv_accum_ptr;
- void *__restrict__ dzoh_accum_ptr;
+ void *__restrict__ dbias_accum_ptr;

// // To accumulate dK and dV in case we're splitting the bwd along seqlen_q
// dimension void *__restrict__ dk_accum_ptr; void *__restrict__
@@ -199,10 +199,10 @@ struct Flash_bwd_params : public Flash_fwd_params {
index_t dq_head_stride;
index_t dk_head_stride;
index_t dv_head_stride;
- index_t dzoh_batch_stride;
- index_t dzoh_head_stride;
- index_t dzoh_row_stride;
- index_t dzoh_col_stride;
+ index_t dbias_batch_stride;
+ index_t dbias_head_stride;
+ index_t dbias_row_stride;
+ index_t dbias_col_stride;

// The pointer to the softmax d sum.
void *__restrict__ dsoftmax_sum;
csrc/src/flash_bwd_kernel.h (68 changes: 29 additions & 39 deletions)
@@ -1,5 +1,5 @@
/***************************************************************************************************
- * Copyright (c) 2024, Tri Dao.
+ * Copyright (c) 2025, Jingze Shi and Tri Dao.
******************************************************************************/

#pragma once
@@ -18,8 +18,6 @@
#include "mask.h"
#include "dropout.h"

#include "alibi.h"

namespace FLASH_NAMESPACE {

using namespace cute;
@@ -77,7 +75,7 @@ make_tiled_copy_C_warpcontiguousN(Copy_Atom<Args...> const& copy_atom,

////////////////////////////////////////////////////////////////////////////////////////////////////

- template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Is_first, bool Is_last, bool Seq_parallel=false, typename Params>
+ template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Is_first, bool Is_last, bool Seq_parallel=false, typename Params>
inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const int bidb, const int bidh, const int n_block) {

using Element = typename Kernel_traits::Element;
@@ -101,16 +99,19 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
if (n_block * kBlockN >= binfo.actual_seqlen_k) return;

int m_block_max = cute::ceil_div(binfo.actual_seqlen_q, kBlockM);
- if (Is_local) {
- m_block_max = std::min(m_block_max, cute::ceil_div((n_block + 1) * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k + params.window_size_left, kBlockM));
- }

const index_t row_offset_q = binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb)
+ (m_block_max - 1) * kBlockM * params.q_row_stride + bidh * params.q_head_stride;
const index_t row_offset_k = binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb)
+ n_block * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride;
const index_t row_offset_v = binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb)
+ n_block * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride;
+ const index_t row_offset_mask = binfo.attn_mask_offset(params.attn_mask_batch_stride, params.attn_mask_row_stride, params.attn_mask_col_stride, bidb)
+ + (bidh / params.h_h_k_ratio) * params.attn_mask_head_stride + (m_block_max - 1) * kBlockM * params.attn_mask_row_stride
+ + n_block * kBlockN * params.attn_mask_col_stride;
+ const index_t row_offset_bias = binfo.attn_bias_offset(params.attn_bias_batch_stride, params.attn_bias_row_stride, params.attn_bias_col_stride, bidb)
+ + (bidh / params.h_h_k_ratio) * params.attn_bias_head_stride + (m_block_max - 1) * kBlockM * params.attn_bias_row_stride
+ + n_block * kBlockN * params.attn_bias_col_stride;
const index_t row_offset_do = binfo.q_offset(params.do_batch_stride, params.do_row_stride, bidb)
+ (m_block_max - 1) * kBlockM * params.do_row_stride + bidh * params.do_head_stride;
const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)
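A small sketch of the addressing scheme used for the new mask and bias tiles above: the offset of a tile's top-left element is a sum of batch, head, row, and column components. The helper name, the flat [batch, heads, seqlen_q, seqlen_k] layout, and the concrete strides below are assumptions for illustration, not code from the kernel:

```python
# Hypothetical helper mirroring the structure of row_offset_bias above
# (binfo.attn_bias_offset is approximated here as bidb * batch_stride).
def bias_tile_offset(bidb, bidh, m_block, n_block,
                     batch_stride, head_stride, row_stride, col_stride,
                     kBlockM, kBlockN, h_h_k_ratio=1):
    return (bidb * batch_stride
            + (bidh // h_h_k_ratio) * head_stride
            + m_block * kBlockM * row_stride   # the kernel starts at (m_block_max - 1)
            + n_block * kBlockN * col_stride)

# Example with an assumed contiguous [batch, heads, 256, 256] bias tensor.
heads, sq, sk = 8, 256, 256
offset = bias_tile_offset(bidb=0, bidh=1, m_block=1, n_block=2,
                          batch_stride=heads * sq * sk, head_stride=sq * sk,
                          row_stride=sk, col_stride=1, kBlockM=64, kBlockN=128)
print(offset)  # 82176 elements from the start of the tensor
```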
@@ -134,6 +135,15 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
Tensor gV = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.v_ptr) + row_offset_v),
Shape<Int<kBlockN>, Int<kHeadDim>>{},
make_stride(params.v_row_stride, _1{}));
+ Tensor gMask = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.attn_mask_ptr) + row_offset_mask),
+ Shape<Int<kBlockM>, Int<kBlockN>>{},
+ make_stride(params.attn_mask_row_stride, params.attn_mask_col_stride));
+ Tensor gBias = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.attn_bias_ptr) + row_offset_bias),
+ Shape<Int<kBlockM>, Int<kBlockN>>{},
+ make_stride(params.attn_bias_row_stride, params.attn_bias_col_stride));
+ Tensor gdBias = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.dbias_ptr) + row_offset_bias),
+ Shape<Int<kBlockM>, Int<kBlockN>>{},
+ make_stride(params.dbias_row_stride, params.dbias_col_stride));
Tensor gdO = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.do_ptr) + row_offset_do),
Shape<Int<kBlockM>, Int<kHeadDim>>{},
make_stride(params.do_row_stride, _1{}));
@@ -313,9 +323,9 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
tdQgdQaccum.data() = tdQgdQaccum.data() + kBlockM * params.h * params.d_rounded;

int m_block = m_block_max - 1;
- int m_block_min = (!Is_causal && !Is_local)
+ int m_block_min = (!Is_causal)
? 0
- : std::max(0, (n_block * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k - params.window_size_right) / kBlockM);
+ : std::max(0, (n_block * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k) / kBlockM);
// If not local, we're guaranteed that m_block_min <= m_block:
// We checked earlier that n_block * kBlockN < actual_seqlen_k, so in the causal case,
// n_block * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k < actual_seqlen_q.
@@ -328,7 +338,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
// Otherwise we get wrong result for the case where we don't enter the for loop.
// And we might read OOB elements from gQ and gdO.
// This also covers the case where actual_seqlen_q == 0
- if ((Is_local || !Is_even_MN) && m_block < m_block_min) {
+ if ((!Is_even_MN) && m_block < m_block_min) {
const index_t row_offset_dk = binfo.k_offset(params.dk_batch_stride, params.dk_row_stride, bidb)
+ n_block * kBlockN * params.dk_row_stride + bidh * params.dk_head_stride;
const index_t row_offset_dv = binfo.k_offset(params.dv_batch_stride, params.dv_row_stride, bidb)
@@ -410,8 +420,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
}
// We want LSE = inf if the row is OOB. In that case Q would be zero, K would be zero,
// and scores would be zero. With LSE = 0, probs will be all 1's, and when we multiply
- // with V (which would be zero), we're fine. However, with ALiBi, we might modify these
- // scores, and probs can become NaN. Instead if we set LSE = inf for OOB rows, probs are always 0.
+ // with V (which would be zero), we're fine.

// Tensor tKrK = make_fragment_like(tKsK);
// // cute::copy(gmem_tiled_copy_QKV, tKgK(_, _, _, 0), tKrK);
@@ -449,9 +458,6 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
clear(acc_dv);
clear(acc_dk);

- const float alibi_slope = !Has_alibi || params.alibi_slopes_ptr == nullptr ? 0.0f : reinterpret_cast<float *>(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax;
- FLASH_NAMESPACE::Alibi<Is_causal> alibi(alibi_slope, binfo.actual_seqlen_k, binfo.actual_seqlen_q);

for (; m_block >= m_block_min; --m_block) {
Tensor acc_s = partition_fragment_C(tiled_mma_sdp, Shape<Int<kBlockM>, Int<kBlockN>>{}); // (MMA=4, MMA_N, MMA_N)
clear(acc_s);
@@ -486,20 +492,14 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
FLASH_NAMESPACE::calculate_dtanh(scores, dtanh, params.softcap);
}

- // Alibi
- if (Has_alibi) {
- alibi.apply_alibi(scores, n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16,
- m_block * kBlockM + get<0>(taccScS_row(0)), AtomLayoutMS * 16);
- }

// TD [2023-07-29]: I was thinking that we don't need to mask out the elements beyond
// actual_seqlen_k, because acc_s would be some finite value for those indices.
// In the end when we multiply with K to get dQ, the corresponding values of K would be 0,
// so the result would still be correct.
// However, it's possible that the values in acc_s are so large that they overflow
// when we multiply with dP and convert to fp16, resulting in Inf in dS and NaNs in dQ.
// So we need to mask out the elements beyond actual_seqlen_k.
- if (!Is_causal && !Is_local) {
+ if (!Is_causal) {
if (!Is_even_MN && (n_block + 1) * kBlockN >= binfo.actual_seqlen_k) {
FLASH_NAMESPACE::apply_mask(scores, binfo.actual_seqlen_k,
n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16);
@@ -517,16 +517,6 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
// binfo.actual_seqlen_k, m_block * kBlockM + (tidx / 32) % AtomLayoutMS * 16 + (tidx % 32) / 4,
AtomLayoutMS * 16);
}
- } else if (Is_local) {
- if (m_block * kBlockM < (n_block + 1) * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k - params.window_size_right
- || (m_block + 1) * kBlockM >= n_block * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k + params.window_size_left
- || (!Is_even_MN && (n_block + 1) * kBlockN >= binfo.actual_seqlen_k)) {
- FLASH_NAMESPACE::apply_mask_local(scores, n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16,
- binfo.actual_seqlen_k, m_block * kBlockM + get<0>(taccScS_row(0)),
- binfo.actual_seqlen_q, AtomLayoutMS * 16,
- params.window_size_left, params.window_size_right);
- }

}

// if (cute::thread(32, 0)) { print(scores); }
@@ -794,7 +784,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in

////////////////////////////////////////////////////////////////////////////////////////////////////

- template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Has_alibi, bool Is_even_M, bool Is_even_K, typename Params>
+ template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_M, bool Is_even_K, typename Params>
inline __device__ void compute_dq_dk_dv(const Params &params) {

// The block index for the batch.
@@ -808,20 +798,20 @@ inline __device__ void compute_dq_dk_dv(const Params &params) {

const int n_block_max = (params.seqlen_k + Kernel_traits::kBlockN - 1) / Kernel_traits::kBlockN;
if (n_block_max == 1) {
- compute_dq_dk_dv_1colblock<Kernel_traits, Is_dropout, Is_causal, Has_alibi, Is_even_M, Is_even_K, true, true>(params, bidb, bidh, 0);
+ compute_dq_dk_dv_1colblock<Kernel_traits, Is_dropout, Is_causal, Is_even_M, Is_even_K, true, true>(params, bidb, bidh, 0);
} else {
// Iterating backward from n_block_max - 1 to 0 might save 1 register
- compute_dq_dk_dv_1colblock<Kernel_traits, Is_dropout, Is_causal, Has_alibi, Is_even_M, Is_even_K, true, false>(params, bidb, bidh, n_block_max - 1);
+ compute_dq_dk_dv_1colblock<Kernel_traits, Is_dropout, Is_causal, Is_even_M, Is_even_K, true, false>(params, bidb, bidh, n_block_max - 1);
for (int n_block = n_block_max - 2; n_block > 0; n_block--) {
- compute_dq_dk_dv_1colblock<Kernel_traits, Is_dropout, Is_causal, Has_alibi, Is_even_M, Is_even_K, false, false>(params, bidb, bidh, n_block);
+ compute_dq_dk_dv_1colblock<Kernel_traits, Is_dropout, Is_causal, Is_even_M, Is_even_K, false, false>(params, bidb, bidh, n_block);
}
- compute_dq_dk_dv_1colblock<Kernel_traits, Is_dropout, Is_causal, Has_alibi, Is_even_M, Is_even_K, false, true>(params, bidb, bidh, 0);
+ compute_dq_dk_dv_1colblock<Kernel_traits, Is_dropout, Is_causal, Is_even_M, Is_even_K, false, true>(params, bidb, bidh, 0);
}
}

////////////////////////////////////////////////////////////////////////////////////////////////////

- template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, typename Params>
+ template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Is_softcap, typename Params>
inline __device__ void compute_dq_dk_dv_seqk_parallel(const Params &params) {

// The block index for the batch.
@@ -831,7 +821,7 @@ inline __device__ void compute_dq_dk_dv_seqk_parallel(const Params &params) {

// If deterministic, each thread block will do atomicAdd to a different dQ_accum buffer.
for (int n_block = blockIdx.x; n_block < (params.seqlen_k + Kernel_traits::kBlockN - 1) / Kernel_traits::kBlockN; n_block += gridDim.x) {
- compute_dq_dk_dv_1colblock<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap, false, false, /*Seq_parallel=*/true>(params, bidb, bidh, n_block);
+ compute_dq_dk_dv_1colblock<Kernel_traits, Is_dropout, Is_causal, Is_even_MN, Is_even_K, Is_softcap, false, false, /*Seq_parallel=*/true>(params, bidb, bidh, n_block);
}
}

csrc/src/kernel_traits.h (31 changes: 29 additions & 2 deletions)
@@ -261,6 +261,22 @@ struct Flash_bwd_kernel_traits : public Base {
composition(SmemLayoutKV{}, make_layout(Shape<Int<kHeadDim>, Int<kBlockN>>{}, GenRowMajor{})));
using SmemLayoutKtransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutKtransposed{}));

+ using SmemLayoutAtomMask = decltype(
+ composition(Swizzle<kSwizzle, 3, 3>{},
+ Layout<Shape<_8, _8>,
+ Stride<_8, _1>>{}));
+ using SmemLayoutMask = decltype(tile_to_shape(
+ SmemLayoutAtomMask{},
+ make_shape(Int<kBlockM>{}, Int<kBlockN>{})));
+
+ using SmemLayoutAtomBias = decltype(
+ composition(Swizzle<kSwizzle, 3, 3>{},
+ Layout<Shape<_8, _8>,
+ Stride<_8, _1>>{}));
+ using SmemLayoutBias = decltype(tile_to_shape(
+ SmemLayoutAtomBias{},
+ make_shape(Int<kBlockM>{}, Int<kBlockN>{})));

// TODO: generalize to other values of kBlockN
// TODO: what should be the Swizzle here? 3 is faster than 1, and 1 is faster than 2
// static constexpr int kPBlockN = kBlockN;
@@ -306,18 +322,21 @@ struct Flash_bwd_kernel_traits : public Base {
SmemLayoutAtomdQ{},
make_shape(Int<kBlockM>{}, Int<kHeadDim>{})));
using SmemCopyAtomdQ = Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, elem_type>;
+ using SmemCopyAtomBias = Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<64>, elem_type>;

// Double buffer for sQ
static constexpr int kSmemQdOSize = size(SmemLayoutQdO{}) * (No_double_buffer ? 2 : 3) * sizeof(Element);
static constexpr int kSmemKVSize = size(SmemLayoutKV{}) * 2 * sizeof(Element);
static constexpr int kSmemdSSize = size(SmemLayoutPdS{}) * sizeof(Element);
static constexpr int kSmemPSize = size(SmemLayoutPdS{}) * sizeof(Element);
static constexpr int kSmemdQSize = size(SmemLayoutdQ{}) * sizeof(Element);
- static constexpr int kSmemSize = kSmemQdOSize
+ static constexpr int kSmemMaskSize = size(SmemLayoutMask{}) * sizeof(Element);
+ static constexpr int kSmemBiasSize = size(SmemLayoutBias{}) * sizeof(Element);
+ static constexpr int kSmemSize = kSmemQdOSize + kSmemMaskSize + kSmemBiasSize
+ (!Is_V_in_regs
? kSmemKVSize + kSmemdSSize + std::max(kSmemPSize, kSmemdQSize)
: std::max(kSmemKVSize, kSmemKVSize / 2 + kSmemdSSize + std::max(kSmemPSize, kSmemdQSize)));
- static constexpr int kSmemSize1colblock = kSmemQdOSize
+ static constexpr int kSmemSize1colblock = kSmemQdOSize + kSmemMaskSize + kSmemBiasSize
+ (!Is_V_in_regs
? kSmemKVSize + kSmemdSSize + kSmemPSize
: std::max(kSmemKVSize, kSmemKVSize / 2 + kSmemdSSize + kSmemPSize));
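A rough illustration of what the two new kSmemMaskSize and kSmemBiasSize terms add to the per-block shared-memory budget. The tile sizes and element width below are assumed for the sketch and are not read from the kernel traits:

```python
# Each of kSmemMaskSize and kSmemBiasSize is a kBlockM x kBlockN tile of Element.
kBlockM, kBlockN, elem_bytes = 64, 128, 2   # assumed tile sizes, fp16/bf16 Element
smem_mask = kBlockM * kBlockN * elem_bytes  # 16384 bytes
smem_bias = kBlockM * kBlockN * elem_bytes  # 16384 bytes
print(f"extra smem per block: {smem_mask + smem_bias} bytes")  # 32768 bytes = 32 KiB
```

Under these assumed sizes, the mask and bias tiles together add 32 KiB of shared memory per thread block on top of the existing Q/dO/K/V/dS buffers.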
@@ -354,6 +373,14 @@ struct Flash_bwd_kernel_traits : public Base {
make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, elem_type>{},
GmemLayoutAtom{},
Layout<Shape < _1, _8>>{})); // Val layout, 8 vals per store
+ using GmemTiledCopyMask = decltype(
+ make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<64>, elem_type>{},
+ GmemLayoutAtom{},
+ Layout<Shape<_1, _4>>{})); // Val layout, 4 vals per read
+ using GmemTiledCopyBias = decltype(
+ make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<64>, elem_type>{},
+ GmemLayoutAtom{},
+ Layout<Shape<_1, _4>>{})); // Val layout, 4 vals per read
using GmemLayoutAtomdQaccum = std::conditional_t<
kBlockKSmem == 32,
Layout<Shape <_32, _8>, // Thread layout, 8 threads per row