28 changes: 14 additions & 14 deletions benchmarks/benchmark_forward_equivalence.py
@@ -232,20 +232,20 @@ def dynamic_mask_attention_cuda(

# Call the CUDA implementation using the mha_fwd function signature
out_tensor = None # Let the function allocate the output tensor
- result = flash_dma_cuda.fwd( # type: ignore
- query_states, # q: [batch, seqlen_q, num_heads, head_dim]
- key_states, # k: [batch, seqlen_k, num_kv_heads, head_dim]
- value_states, # v: [batch, seqlen_k, num_kv_heads, head_dim]
- zero_hold_states, # zoh: [batch, num_kv_heads, seqlen_q, seqlen_k] - processed attention mask
- active_mask, # active_mask: [batch, num_kv_heads, seqlen_q, seqlen_k]
- out_tensor, # out: None to auto-allocate
- 0.0, # p_dropout
- scaling, # softmax_scale
- is_causal, # is_causal
- keep_window_size, # keep_window_size
- 0.0, # softcap
- return_softmax, # return_softmax
- None # gen (generator)
+ result = flash_dma_cuda.fwd( # type: ignore
+ query_states, # q: [batch, seqlen_q, num_heads, head_dim]
+ key_states, # k: [batch, seqlen_k, num_kv_heads, head_dim]
+ value_states, # v: [batch, seqlen_k, num_kv_heads, head_dim]
+ attn_mask, # attn_mask: [batch, num_kv_heads, seqlen_q, seqlen_k]
+ active_mask, # attn_bias: [batch, num_kv_heads, seqlen_q, seqlen_k]
+ out_tensor, # out: None to auto-allocate
+ 0.0, # p_dropout
+ scaling, # softmax_scale
+ is_causal, # is_causal
+ keep_window_size, # keep_window_size
+ 0.0, # softcap
+ return_softmax, # return_softmax
+ None # gen (generator)
)

attn_outputs = result[0] # [batch, query_len, num_heads, head_dim]
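For reference, a minimal usage sketch of the renamed call (assuming the flash_dma_cuda extension from this repo is built and importable; the tensor sizes, scaling, and keep_window_size values below are hypothetical, chosen only to satisfy the shape and dtype checks in mha_fwd):

import torch
import flash_dma_cuda  # assumes the CUDA extension in this repo has been built

batch, seqlen_q, seqlen_k = 1, 128, 128          # hypothetical sizes
num_heads, num_kv_heads, head_dim = 8, 2, 64     # head_dim a multiple of 8

device, dtype = "cuda", torch.float16            # fp16/bf16 on CUDA, per the checks in mha_fwd
q = torch.randn(batch, seqlen_q, num_heads, head_dim, device=device, dtype=dtype)
k = torch.randn(batch, seqlen_k, num_kv_heads, head_dim, device=device, dtype=dtype)
v = torch.randn(batch, seqlen_k, num_kv_heads, head_dim, device=device, dtype=dtype)
attn_mask = torch.ones(batch, num_kv_heads, seqlen_q, seqlen_k, device=device, dtype=dtype)
attn_bias = torch.zeros(batch, num_kv_heads, seqlen_q, seqlen_k, device=device, dtype=dtype)

out = flash_dma_cuda.fwd(
    q, k, v,
    attn_mask,          # attn_mask: [batch, num_kv_heads, seqlen_q, seqlen_k]
    attn_bias,          # attn_bias: [batch, num_kv_heads, seqlen_q, seqlen_k]
    None,               # out: auto-allocate
    0.0,                # p_dropout
    head_dim ** -0.5,   # softmax_scale
    True,               # is_causal
    128,                # keep_window_size
    0.0,                # softcap
    False,              # return_softmax
    None,               # gen (generator)
)[0]                    # [batch, seqlen_q, num_heads, head_dim]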
40 changes: 20 additions & 20 deletions csrc/flash_api.cpp
@@ -40,8 +40,8 @@ void set_params_fprop(
const at::Tensor q,
const at::Tensor k,
const at::Tensor v,
- const at::Tensor zoh,
- const at::Tensor active_mask,
+ const at::Tensor attn_mask,
+ const at::Tensor attn_bias,
at::Tensor out,
void *cu_seqlens_q_d,
void *cu_seqlens_k_d,
@@ -65,32 +65,32 @@
params.q_ptr = q.data_ptr();
params.k_ptr = k.data_ptr();
params.v_ptr = v.data_ptr();
- params.zoh_ptr = zoh.data_ptr();
- params.active_mask_ptr = active_mask.data_ptr();
+ params.attn_mask_ptr = attn_mask.data_ptr();
+ params.attn_bias_ptr = attn_bias.data_ptr();
params.o_ptr = out.data_ptr();

// All stride are in elements, not bytes.
params.q_row_stride = q.stride(-3);
params.k_row_stride = k.stride(-3);
params.v_row_stride = v.stride(-3);
- params.zoh_row_stride = zoh.stride(-2);
- params.active_mask_row_stride = active_mask.stride(-2);
+ params.attn_mask_row_stride = attn_mask.stride(-2);
+ params.attn_bias_row_stride = attn_bias.stride(-2);
params.o_row_stride = out.stride(-3);
params.q_head_stride = q.stride(-2);
params.k_head_stride = k.stride(-2);
params.v_head_stride = v.stride(-2);
- params.zoh_head_stride = zoh.stride(-3);
- params.active_mask_head_stride = active_mask.stride(-3);
+ params.attn_mask_head_stride = attn_mask.stride(-3);
+ params.attn_bias_head_stride = attn_bias.stride(-3);
params.o_head_stride = out.stride(-2);
- params.zoh_col_stride = zoh.stride(-1);
- params.active_mask_col_stride = active_mask.stride(-1);
+ params.attn_mask_col_stride = attn_mask.stride(-1);
+ params.attn_bias_col_stride = attn_bias.stride(-1);

if (cu_seqlens_q_d == nullptr) {
params.q_batch_stride = q.stride(0);
params.k_batch_stride = k.stride(0);
params.v_batch_stride = v.stride(0);
- params.zoh_batch_stride = zoh.stride(0);
- params.active_mask_batch_stride = active_mask.stride(0);
+ params.attn_mask_batch_stride = attn_mask.stride(0);
+ params.attn_bias_batch_stride = attn_bias.stride(0);
params.o_batch_stride = out.stride(0);
if (seqlenq_ngroups_swapped) {
params.q_batch_stride *= seqlen_q;
@@ -271,8 +271,8 @@ mha_fwd(
at::Tensor &q, // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8)
const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8)
const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8)
- const at::Tensor &zoh, // batch_size x num_heads_k x seqlen_q x seqlen_k
- const at::Tensor &active_mask, // batch_size x num_heads_k x seqlen_q x seqlen_k
+ const at::Tensor &attn_mask, // batch_size x num_heads_k x seqlen_q x seqlen_k
+ const at::Tensor &attn_bias, // batch_size x num_heads_k x seqlen_q x seqlen_k
std::optional<at::Tensor> &out_, // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8)
const float p_dropout,
const float softmax_scale,
@@ -295,10 +295,10 @@ mha_fwd(
"FlashAttention only support fp16 and bf16 data type");
TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
- TORCH_CHECK(zoh.dtype() == q_dtype, "zoh must have the same dtype as inputs");
- TORCH_CHECK(active_mask.dtype() == q_dtype, "active_mask must have the same dtype as inputs");
+ TORCH_CHECK(attn_mask.dtype() == q_dtype, "attn_mask must have the same dtype as inputs");
+ TORCH_CHECK(attn_bias.dtype() == q_dtype, "attn_bias must have the same dtype as inputs");

- CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); CHECK_DEVICE(zoh); CHECK_DEVICE(active_mask);
+ CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); CHECK_DEVICE(attn_mask); CHECK_DEVICE(attn_bias);

TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
@@ -335,8 +335,8 @@ mha_fwd(
CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size);
CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size);
CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size);
- CHECK_SHAPE(zoh, batch_size, num_heads_k, seqlen_q, seqlen_k);
- CHECK_SHAPE(active_mask, batch_size, num_heads_k, seqlen_q, seqlen_k);
+ CHECK_SHAPE(attn_mask, batch_size, num_heads_k, seqlen_q, seqlen_k);
+ CHECK_SHAPE(attn_bias, batch_size, num_heads_k, seqlen_q, seqlen_k);

at::Tensor out;
if (out_.has_value()) {
@@ -379,7 +379,7 @@ mha_fwd(
num_heads, num_heads_k,
head_size, head_size_rounded,
keep_window_size,
- q, k, v, zoh, active_mask, out,
+ q, k, v, attn_mask, attn_bias, out,
/*cu_seqlens_q_d=*/nullptr,
/*cu_seqlens_k_d=*/nullptr,
/*seqused_k=*/nullptr,
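The TORCH_CHECK / CHECK_SHAPE calls above imply the following constraints on the two renamed tensors; a small Python pre-check that mirrors them (a sketch only — the helper name is hypothetical and not part of the repo):

def check_mask_bias(q, k, attn_mask, attn_bias):
    # Shapes follow the comments in mha_fwd: q is [batch, seqlen_q, num_heads, head_dim],
    # k is [batch, seqlen_k, num_heads_k, head_dim].
    batch, seqlen_q, num_heads, head_dim = q.shape
    _, seqlen_k, num_heads_k, _ = k.shape
    expected = (batch, num_heads_k, seqlen_q, seqlen_k)
    for name, t in (("attn_mask", attn_mask), ("attn_bias", attn_bias)):
        assert t.dtype == q.dtype, f"{name} must have the same dtype as inputs"
        assert t.is_cuda, f"{name} must be on a CUDA device"
        assert tuple(t.shape) == expected, f"{name} must have shape {expected}"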
6 changes: 3 additions & 3 deletions csrc/src/block_info.h
@@ -36,15 +36,15 @@ struct BlockInfo {
}

template <typename index_t>
- __forceinline__ __device__ index_t zoh_offset(const index_t batch_stride, const int row_stride, const int col_stride, const int bidb
- ) const {
+ __forceinline__ __device__ index_t attn_mask_offset(const index_t batch_stride, int row_stride, const int col_stride, const int bidb) const {
index_t offset = sum_s_q == -1 ? bidb * batch_stride : uint32_t(sum_s_q) * row_stride;
sum_s_k == -1 ? offset += leftpad_k * col_stride : offset += uint32_t(sum_s_k + leftpad_k) * col_stride;
return offset;
}

template <typename index_t>
- __forceinline__ __device__ index_t active_mask_offset(const index_t batch_stride, int row_stride, const int col_stride, const int bidb) const {
+ __forceinline__ __device__ index_t attn_bias_offset(const index_t batch_stride, const int row_stride, const int col_stride, const int bidb
+ ) const {
index_t offset = sum_s_q == -1 ? bidb * batch_stride : uint32_t(sum_s_q) * row_stride;
sum_s_k == -1 ? offset += leftpad_k * col_stride : offset += uint32_t(sum_s_k + leftpad_k) * col_stride;
return offset;
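Both renamed helpers compute the same offset; a Python sketch of the arithmetic (the function name is hypothetical, the parameter names follow the C++ above, and sum_s_q / sum_s_k are -1 in the fixed-length batched case):

def mask_or_bias_offset(batch_stride, row_stride, col_stride, bidb,
                        sum_s_q, sum_s_k, leftpad_k):
    # Fixed-length batched layout when sum_s_q == -1, packed varlen layout otherwise.
    offset = bidb * batch_stride if sum_s_q == -1 else sum_s_q * row_stride
    # Advance past any left padding of the key sequence.
    offset += leftpad_k * col_stride if sum_s_k == -1 else (sum_s_k + leftpad_k) * col_stride
    return offset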
37 changes: 22 additions & 15 deletions csrc/src/flash.h
@@ -45,27 +45,34 @@ struct QKV_params {

////////////////////////////////////////////////////////////////////////////////////////////////////

- struct ZOH_params {
- void *__restrict__ zoh_ptr; // ZOH states tensor [batch_size, num_kv_heads, query_len, key_len]
- void * __restrict__ active_mask_ptr; // Active mask tensor [batch_size, num_kv_heads, query_len, key_len]
-
- // The stride of the zero-hold states and active mask tensors.
- index_t zoh_batch_stride; // Stride between batches of ZOH states
- index_t active_mask_batch_stride; // Stride between batches of active mask
- index_t zoh_head_stride; // Stride between heads of ZOH states
- index_t active_mask_head_stride; // Stride between heads of active mask
- index_t zoh_row_stride; // Stride between rows of ZOH states
- index_t active_mask_row_stride; // Stride between rows of active mask
- index_t zoh_col_stride; // Stride between columns of ZOH states
- index_t active_mask_col_stride; // Stride between columns of active mask
+ struct Mask_params {
+ void * __restrict__ attn_mask_ptr; // Attention mask tensor [batch_size, num_kv_heads, query_len, key_len]
+
+ // The stride of the attention mask tensors.
+ index_t attn_mask_batch_stride; // Stride between batches of attention mask
+ index_t attn_mask_head_stride; // Stride between heads of attention mask
+ index_t attn_mask_row_stride; // Stride between rows of attention mask
+ index_t attn_mask_col_stride; // Stride between columns of attention mask

// The keep window size.
- int keep_window_size; // Number of tokens to keep in top-k (0 means don't apply top-k)
+ int keep_window_size; // Number of tokens to keep in top-k
};

////////////////////////////////////////////////////////////////////////////////////////////////////

+ struct Bias_params {
+ void *__restrict__ attn_bias_ptr; // Attention bias tensor [batch_size, num_kv_heads, query_len, key_len]
+
+ // The stride of the attention bias tensor.
+ index_t attn_bias_batch_stride; // Stride between batches of attention bias
+ index_t attn_bias_head_stride; // Stride between heads of attention bias
+ index_t attn_bias_row_stride; // Stride between rows of attention bias
+ index_t attn_bias_col_stride; // Stride between columns of attention bias
+ };
+
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
+
- struct Flash_fwd_params : public QKV_params, public ZOH_params {
+ struct Flash_fwd_params : public QKV_params, public Mask_params, public Bias_params {

// The O matrix (output).
void * __restrict__ o_ptr;
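For orientation, how set_params_fprop (see flash_api.cpp above) fills the new stride fields from a [batch, num_heads_k, seqlen_q, seqlen_k] tensor, restated as a Python sketch (the helper name is hypothetical; strides are in elements, not bytes):

def mask_strides(attn_mask):
    # attn_mask: [batch, num_heads_k, seqlen_q, seqlen_k]
    return dict(
        batch_stride=attn_mask.stride(0),    # between batches
        head_stride=attn_mask.stride(-3),    # between KV heads
        row_stride=attn_mask.stride(-2),     # between query rows
        col_stride=attn_mask.stride(-1),     # between key columns
    )
# The attn_bias_* fields in Bias_params are populated the same way from attn_bias.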