Merged. Changes from all commits.
66 changes: 30 additions & 36 deletions csrc/flash_api.cpp
@@ -35,13 +35,12 @@ void set_params_fprop(
     const size_t h_k,
     const size_t d,
     const size_t d_rounded,
-    const size_t keep_window_size,
     // device pointers
     const at::Tensor q,
     const at::Tensor k,
     const at::Tensor v,
-    const at::Tensor attn_mask,
-    const at::Tensor attn_bias,
+    const at::Tensor mask,
+    const at::Tensor bias,
     at::Tensor out,
     void *cu_seqlens_q_d,
     void *cu_seqlens_k_d,
@@ -65,32 +64,32 @@ void set_params_fprop(
     params.q_ptr = q.data_ptr();
     params.k_ptr = k.data_ptr();
     params.v_ptr = v.data_ptr();
-    params.attn_mask_ptr = attn_mask.data_ptr();
-    params.attn_bias_ptr = attn_bias.data_ptr();
+    params.mask_ptr = mask.data_ptr();
+    params.bias_ptr = bias.data_ptr();
     params.o_ptr = out.data_ptr();

     // All stride are in elements, not bytes.
     params.q_row_stride = q.stride(-3);
     params.k_row_stride = k.stride(-3);
     params.v_row_stride = v.stride(-3);
-    params.attn_mask_row_stride = attn_mask.stride(-2);
-    params.attn_bias_row_stride = attn_bias.stride(-2);
+    params.mask_row_stride = mask.stride(-2);
+    params.bias_row_stride = bias.stride(-2);
     params.o_row_stride = out.stride(-3);
     params.q_head_stride = q.stride(-2);
     params.k_head_stride = k.stride(-2);
     params.v_head_stride = v.stride(-2);
-    params.attn_mask_head_stride = attn_mask.stride(-3);
-    params.attn_bias_head_stride = attn_bias.stride(-3);
+    params.mask_head_stride = mask.stride(-3);
+    params.bias_head_stride = bias.stride(-3);
     params.o_head_stride = out.stride(-2);
-    params.attn_mask_col_stride = attn_mask.stride(-1);
-    params.attn_bias_col_stride = attn_bias.stride(-1);
+    params.mask_col_stride = mask.stride(-1);
+    params.bias_col_stride = bias.stride(-1);

     if (cu_seqlens_q_d == nullptr) {
         params.q_batch_stride = q.stride(0);
         params.k_batch_stride = k.stride(0);
         params.v_batch_stride = v.stride(0);
-        params.attn_mask_batch_stride = attn_mask.stride(0);
-        params.attn_bias_batch_stride = attn_bias.stride(0);
+        params.mask_batch_stride = mask.stride(0);
+        params.bias_batch_stride = bias.stride(0);
         params.o_batch_stride = out.stride(0);
         if (seqlenq_ngroups_swapped) {
             params.q_batch_stride *= seqlen_q;
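
For orientation, these four strides map a logical mask/bias index (b, h, i, j) in the [batch_size, num_heads_k, seqlen_q, seqlen_k] layout to a flat element offset. A minimal host-side sketch of that mapping (the function name is hypothetical, not part of this PR):

#include <cstdint>

// Flat element offset of mask[b][h][i][j] from the four strides that
// set_params_fprop reads off the at::Tensor. For a contiguous
// [B, H_k, Sq, Sk] tensor these collapse to batch_stride = H_k*Sq*Sk,
// head_stride = Sq*Sk, row_stride = Sk, col_stride = 1.
int64_t mask_element_offset(int64_t b, int64_t h, int64_t i, int64_t j,
                            int64_t batch_stride, int64_t head_stride,
                            int64_t row_stride, int64_t col_stride) {
    return b * batch_stride + h * head_stride + i * row_stride + j * col_stride;
}

Carrying explicit strides rather than assuming contiguity is presumably what lets the kernels accept sliced or expanded mask/bias tensors without a copy, as long as the last dimension stays contiguous.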
@@ -119,7 +118,6 @@ void set_params_fprop(
     params.seqlen_k_rounded = seqlen_k_rounded;
     params.d = d;
     params.d_rounded = d_rounded;
-    params.keep_window_size = keep_window_size;

     // Set the different scale values.
 #ifdef FLASHATTENTION_DISABLE_SOFTCAP
@@ -271,13 +269,12 @@ mha_fwd(
     at::Tensor &q, // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8)
     const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8)
     const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8)
-    const at::Tensor &attn_mask, // batch_size x num_heads_k x seqlen_q x seqlen_k
-    const at::Tensor &attn_bias, // batch_size x num_heads_k x seqlen_q x seqlen_k
+    const at::Tensor &mask, // batch_size x num_heads_k x seqlen_q x seqlen_k
+    const at::Tensor &bias, // batch_size x num_heads_k x seqlen_q x seqlen_k
     std::optional<at::Tensor> &out_, // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8)
     const float p_dropout,
     const float softmax_scale,
     bool is_causal,
-    const int keep_window_size,
     const float softcap,
     const bool return_softmax,
     std::optional<at::Generator> gen_
@@ -294,10 +291,10 @@ mha_fwd(
     TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16, "FlashDynamicMaskAttention only support fp16 and bf16 data type");
     TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
     TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
-    TORCH_CHECK(attn_mask.dtype() == q_dtype, "attn_mask must have the same dtype as inputs");
-    TORCH_CHECK(attn_bias.dtype() == q_dtype, "attn_bias must have the same dtype as inputs");
+    TORCH_CHECK(mask.dtype() == q_dtype, "mask must have the same dtype as inputs");
+    TORCH_CHECK(bias.dtype() == q_dtype, "bias must have the same dtype as inputs");

-    CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); CHECK_DEVICE(attn_mask); CHECK_DEVICE(attn_bias);
+    CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); CHECK_DEVICE(mask); CHECK_DEVICE(bias);

     TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
     TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
@@ -334,8 +331,8 @@ mha_fwd(
     CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size);
     CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size);
     CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size);
-    CHECK_SHAPE(attn_mask, batch_size, num_heads_k, seqlen_q, seqlen_k);
-    CHECK_SHAPE(attn_bias, batch_size, num_heads_k, seqlen_q, seqlen_k);
+    CHECK_SHAPE(mask, batch_size, num_heads_k, seqlen_q, seqlen_k);
+    CHECK_SHAPE(bias, batch_size, num_heads_k, seqlen_q, seqlen_k);

     at::Tensor out;
     if (out_.has_value()) {
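
Taken together, the checks above spell out the calling convention. A minimal libtorch sketch that satisfies the dtype, device, contiguity, and shape requirements (all dimensions are hypothetical placeholders):

#include <torch/torch.h>

int main() {
    const int64_t B = 2, Sq = 128, Sk = 128, H = 8, Hk = 8, D = 64; // D a multiple of 8
    auto opts = torch::dtype(torch::kFloat16).device(torch::kCUDA);

    auto q    = torch::randn({B, Sq, H,  D}, opts); // contiguous last dim by construction
    auto k    = torch::randn({B, Sk, Hk, D}, opts);
    auto v    = torch::randn({B, Sk, Hk, D}, opts);
    auto mask = torch::ones({B, Hk, Sq, Sk}, opts); // same dtype/device as q, per the checks
    auto bias = torch::zeros({B, Hk, Sq, Sk}, opts);
    // q, k, v, mask, bias are now valid arguments for the mha_fwd binding.
    return 0;
}

Note that mask and bias are shaped per KV head (num_heads_k), so under grouped-query attention each group of query heads shares one mask/bias slice.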
@@ -377,8 +374,7 @@ mha_fwd(
         seqlen_q_rounded, seqlen_k_rounded,
         num_heads, num_heads_k,
         head_size, head_size_rounded,
-        keep_window_size,
-        q, k, v, attn_mask, attn_bias, out,
+        q, k, v, mask, bias, out,
         /*cu_seqlens_q_d=*/nullptr,
         /*cu_seqlens_k_d=*/nullptr,
         /*seqused_k=*/nullptr,
@@ -436,8 +432,8 @@ mha_varlen_fwd(
     at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
     const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
     const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
-    const at::Tensor &attn_mask, // total_q x num_heads_k x max_seqlen_k
-    const at::Tensor &attn_bias, // total_q x num_heads_k x max_seqlen_k
+    const at::Tensor &mask, // total_q x num_heads_k x max_seqlen_k
+    const at::Tensor &bias, // total_q x num_heads_k x max_seqlen_k
     std::optional<at::Tensor> &out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
     const at::Tensor &cu_seqlens_q, // b+1
     const at::Tensor &cu_seqlens_k, // b+1
@@ -450,7 +446,6 @@ mha_varlen_fwd(
     const float softmax_scale,
     const bool zero_tensors,
     bool is_causal,
-    const int keep_window_size,
     const float softcap,
     const bool return_softmax,
     std::optional<at::Generator> gen_
@@ -465,12 +460,12 @@ mha_varlen_fwd(
     TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16, "FlashDynamicMaskAttention only support fp16 and bf16 data type");
     TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
     TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
-    TORCH_CHECK(attn_mask.dtype() == q_dtype, "attn_mask must have the same dtype as inputs");
-    TORCH_CHECK(attn_bias.dtype() == q_dtype, "attn_bias must have the same dtype as inputs");
+    TORCH_CHECK(mask.dtype() == q_dtype, "mask must have the same dtype as inputs");
+    TORCH_CHECK(bias.dtype() == q_dtype, "bias must have the same dtype as inputs");
     TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must have dtype int32");
     TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, "cu_seqlens_k must have dtype int32");

-    CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); CHECK_DEVICE(attn_mask); CHECK_DEVICE(attn_bias);
+    CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); CHECK_DEVICE(mask); CHECK_DEVICE(bias);
     CHECK_DEVICE(cu_seqlens_q);
     CHECK_DEVICE(cu_seqlens_k);

@@ -487,8 +482,8 @@ mha_varlen_fwd(
     TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
     TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
     TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
-    TORCH_CHECK(attn_mask.stride(-1) == 1, "Input tensor must have contiguous last dimension");
-    TORCH_CHECK(attn_bias.stride(-1) == 1, "Input tensor must have contiguous last dimension");
+    TORCH_CHECK(mask.stride(-1) == 1, "Input tensor must have contiguous last dimension");
+    TORCH_CHECK(bias.stride(-1) == 1, "Input tensor must have contiguous last dimension");
     CHECK_CONTIGUOUS(cu_seqlens_q);
     CHECK_CONTIGUOUS(cu_seqlens_k);

@@ -533,8 +528,8 @@ mha_varlen_fwd(
         const int total_k = k.size(0);
         CHECK_SHAPE(k, total_k, num_heads_k, head_size);
         CHECK_SHAPE(v, total_k, num_heads_k, head_size);
-        CHECK_SHAPE(attn_mask, total_q, num_heads_k, max_seqlen_k);
-        CHECK_SHAPE(attn_bias, total_q, num_heads_k, max_seqlen_k);
+        CHECK_SHAPE(mask, total_q, num_heads_k, max_seqlen_k);
+        CHECK_SHAPE(bias, total_q, num_heads_k, max_seqlen_k);
     } else {
         CHECK_SHAPE(k, num_blocks, page_block_size, num_heads_k, head_size);
         CHECK_SHAPE(v, num_blocks, page_block_size, num_heads_k, head_size);
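
In the varlen path the mask and bias lose their batch dimension: query rows for all sequences are packed into total_q, the key axis is padded to max_seqlen_k, and cu_seqlens_q/cu_seqlens_k record where each sequence starts. A hedged sketch of that bookkeeping (the concrete numbers are illustrative only):

#include <torch/torch.h>

int main() {
    // Two sequences of lengths 3 and 5 packed into one varlen batch;
    // cu_seqlens_q holds the b+1 exclusive prefix sums {0, 3, 8}.
    const int64_t total_q = 8, max_seqlen_k = 5, Hk = 8;
    auto opts = torch::dtype(torch::kFloat16).device(torch::kCUDA);

    auto mask = torch::ones({total_q, Hk, max_seqlen_k}, opts);  // total_q x num_heads_k x max_seqlen_k
    auto bias = torch::zeros({total_q, Hk, max_seqlen_k}, opts);
    auto cu_seqlens_q = torch::tensor({0, 3, 8},
        torch::dtype(torch::kInt32).device(torch::kCUDA));       // int32, as the checks require
    return 0;
}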
@@ -596,8 +591,7 @@ mha_varlen_fwd(
         seqlen_q_rounded, seqlen_k_rounded,
         num_heads, num_heads_k,
         head_size, head_size_rounded,
-        keep_window_size,
-        q, k, v, attn_mask, attn_bias, out,
+        q, k, v, mask, bias, out,
         cu_seqlens_q_d,
         cu_seqlens_k.data_ptr(),
         seqused_k.has_value() ? seqused_k.value().data_ptr() : nullptr,
5 changes: 2 additions & 3 deletions csrc/src/block_info.h
@@ -36,15 +36,14 @@ struct BlockInfo {
     }

     template <typename index_t>
-    __forceinline__ __device__ index_t attn_mask_offset(const index_t batch_stride, int row_stride, const int col_stride, const int bidb) const {
+    __forceinline__ __device__ index_t mask_offset(const index_t batch_stride, const index_t row_stride, const index_t col_stride, const int bidb) const {
         index_t offset = sum_s_q == -1 ? bidb * batch_stride : uint32_t(sum_s_q) * row_stride;
         sum_s_k == -1 ? offset += leftpad_k * col_stride : offset += uint32_t(sum_s_k + leftpad_k) * col_stride;
         return offset;
     }

     template <typename index_t>
-    __forceinline__ __device__ index_t attn_bias_offset(const index_t batch_stride, const int row_stride, const int col_stride, const int bidb
-    ) const {
+    __forceinline__ __device__ index_t bias_offset(const index_t batch_stride, const index_t row_stride, const index_t col_stride, const int bidb) const {
Comment on lines +39 to +46

Copilot AI (Jul 28, 2025): The parameter type for row_stride was changed from int to index_t, but this change is inconsistent with the col_stride parameter, which remains int. For consistency, all stride parameters should have the same type.

Copilot AI (Jul 28, 2025): The parameter types for row_stride and col_stride were changed from int to index_t, but this creates an inconsistency: bidb remains int while the other parameters use index_t. Consider using consistent types for all index-related parameters.
         index_t offset = sum_s_q == -1 ? bidb * batch_stride : uint32_t(sum_s_q) * row_stride;
         sum_s_k == -1 ? offset += leftpad_k * col_stride : offset += uint32_t(sum_s_k + leftpad_k) * col_stride;
         return offset;
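
The two offset helpers fold the batched and varlen layouts into one expression: sum_s_q == -1 marks the batched layout, where the batch index bidb selects a slab via batch_stride; otherwise the packed token offset sum_s_q drives the row term. The key-side term switches on sum_s_k the same way and folds in leftpad_k. A host-side mirror with concrete numbers (all values illustrative; 64-bit arithmetic used throughout for clarity):

#include <cstdint>
#include <cstdio>

// Host-side mirror of BlockInfo::mask_offset / bias_offset.
int64_t mask_offset(int64_t batch_stride, int64_t row_stride, int64_t col_stride,
                    int bidb, int sum_s_q, int sum_s_k, int leftpad_k) {
    int64_t offset = sum_s_q == -1 ? bidb * batch_stride : int64_t(sum_s_q) * row_stride;
    offset += sum_s_k == -1 ? int64_t(leftpad_k) * col_stride
                            : int64_t(sum_s_k + leftpad_k) * col_stride;
    return offset;
}

int main() {
    const int64_t Sk = 4096, row = Sk, col = 1, head = 4096 * Sk, batch = 8 * head;
    // Batched layout: sentinel -1 for both prefix sums, offset = bidb * batch_stride.
    std::printf("batched, bidb=1: %lld\n", (long long)mask_offset(batch, row, col, 1, -1, -1, 0));
    // Varlen layout: this sequence's queries start at packed token 300, keys at 512.
    std::printf("varlen: %lld\n", (long long)mask_offset(0, row, col, 0, 300, 512, 0));
    return 0;
}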
23 changes: 10 additions & 13 deletions csrc/src/flash.h
@@ -46,28 +46,25 @@ struct QKV_params {
 ////////////////////////////////////////////////////////////////////////////////////////////////////

 struct Mask_params {
-    void * __restrict__ attn_mask_ptr; // Attention mask tensor [batch_size, num_kv_heads, query_len, key_len]
+    void * __restrict__ mask_ptr; // Attention mask tensor [batch_size, num_kv_heads, query_len, key_len]

     // The stride of the attention mask tensors.
-    index_t attn_mask_batch_stride; // Stride between batches of attention mask
-    index_t attn_mask_head_stride;  // Stride between heads of attention mask
-    index_t attn_mask_row_stride;   // Stride between rows of attention mask
-    index_t attn_mask_col_stride;   // Stride between columns of attention mask
-
-    // The keep window size.
-    int keep_window_size; // Number of tokens to keep in top-k
+    index_t mask_batch_stride; // Stride between batches of attention mask
+    index_t mask_head_stride;  // Stride between heads of attention mask
+    index_t mask_row_stride;   // Stride between rows of attention mask
+    index_t mask_col_stride;   // Stride between columns of attention mask
 };

 ////////////////////////////////////////////////////////////////////////////////////////////////////

 struct Bias_params {
-    void *__restrict__ attn_bias_ptr; // Attention bias tensor [batch_size, num_kv_heads, query_len, key_len]
+    void *__restrict__ bias_ptr; // Attention bias tensor [batch_size, num_kv_heads, query_len, key_len]

     // The stride of the attention bias tensor.
-    index_t attn_bias_batch_stride; // Stride between batches of attention bias
-    index_t attn_bias_head_stride;  // Stride between heads of attention bias
-    index_t attn_bias_row_stride;   // Stride between rows of attention bias
-    index_t attn_bias_col_stride;   // Stride between columns of attention bias
+    index_t bias_batch_stride; // Stride between batches of attention bias
+    index_t bias_head_stride;  // Stride between heads of attention bias
+    index_t bias_row_stride;   // Stride between rows of attention bias
+    index_t bias_col_stride;   // Stride between columns of attention bias
 };

 ////////////////////////////////////////////////////////////////////////////////////////////////////
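
On the type-consistency point raised in the review comments on block_info.h: widening the stride parameters to index_t is not just cosmetic. For the [batch_size, num_kv_heads, query_len, key_len] layouts above, batch_stride grows with the product of the three trailing dimensions and crosses INT32_MAX at sequence lengths that are routine today. A quick check, assuming index_t is a 64-bit integer type as in recent upstream FlashAttention:

#include <cstdint>
#include <cstdio>

int main() {
    // Contiguous mask [B, H_k, Sq, Sk] with H_k = 8 and a 16k context.
    const int64_t Hk = 8, Sq = 16384, Sk = 16384;
    const int64_t head_stride  = Sq * Sk;          // 268,435,456: still fits in int32
    const int64_t batch_stride = Hk * head_stride; // 2,147,483,648: overflows int32 by 1
    std::printf("batch_stride = %lld, INT32_MAX = %d\n",
                (long long)batch_stride, INT32_MAX);
    return 0;
}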