2 changes: 1 addition & 1 deletion benchmarks/backward_equivalence.py
@@ -191,7 +191,7 @@ def dynamic_mask_attention_python(
value_states = repeat_kv(value_states, num_queries_per_kv)
attn_mask = repeat_kv(attn_mask, num_queries_per_kv)
attn_bias = repeat_kv(attn_bias_leaf, num_queries_per_kv)

attn_weights = torch.matmul(query_states, key_states.transpose(-2, -1))
attn_weights = attn_weights * scaling + attn_bias # Apply scaling and zoh
attn_weights = F.softmax(attn_weights, dim=-1) # Softmax normalization
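
For readers comparing against the reference path above: the Python baseline expands K, V, the mask, and the bias to the full query-head count with repeat_kv before the dense matmul. A minimal sketch of such a helper, assuming the usual interleaved-repetition (GQA/MQA) semantics; the benchmark's own repeat_kv may differ in detail.

import torch

def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    # Expand (batch, num_kv_heads, seq_len, dim) to (batch, num_kv_heads * n_rep, seq_len, dim)
    # by repeating each KV head n_rep times, matching the grouping used for GQA/MQA.
    batch, num_kv_heads, seq_len, dim = x.shape
    if n_rep == 1:
        return x
    x = x[:, :, None, :, :].expand(batch, num_kv_heads, n_rep, seq_len, dim)
    return x.reshape(batch, num_kv_heads * n_rep, seq_len, dim)
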
1,002 changes: 542 additions & 460 deletions csrc/flash_dmattn/flash_api.cpp

Large diffs are not rendered by default.

6 changes: 6 additions & 0 deletions csrc/flash_dmattn/src/flash.h
@@ -50,6 +50,9 @@ struct Mask_params {
index_t mask_batch_stride; // Stride between batches of attention mask
index_t mask_head_stride; // Stride between heads of attention mask
index_t mask_row_stride; // Stride between rows of attention mask

+ // The number of heads in the mask.
+ int h_mask;
};

////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -61,6 +64,9 @@ struct Bias_params {
index_t bias_batch_stride; // Stride between batches of attention bias
index_t bias_head_stride; // Stride between heads of attention bias
index_t bias_row_stride; // Stride between rows of attention bias

+ // The number of heads in the bias.
+ int h_bias;
};

////////////////////////////////////////////////////////////////////////////////////////////////////
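
The new h_mask/h_bias fields record how many heads the mask and bias tensors actually carry, so the kernels can handle a single broadcast head, a per-KV-group head, or a per-query-head layout. A hedged Python sketch of the corresponding host-side check; the names num_heads and num_kv_heads and the error text are illustrative, not the actual flash_api.cpp code.

def resolve_attn_heads(h_attn: int, num_heads: int, num_kv_heads: int) -> int:
    # The mask/bias head dimension may be 1 (broadcast over all heads),
    # num_kv_heads (one head per KV group), or num_heads (one per query head).
    if h_attn not in (1, num_kv_heads, num_heads):
        raise ValueError(
            f"mask/bias head dim must be 1, {num_kv_heads}, or {num_heads}, got {h_attn}"
        )
    return h_attn
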
6 changes: 4 additions & 2 deletions csrc/flash_dmattn/src/flash_bwd_kernel.h
@@ -107,10 +107,12 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
+ n_block * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride;
const index_t row_offset_v = binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb)
+ n_block * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride;
+ const int h_idx_mask = (params.h_mask == 1) ? 0 : ((params.h_mask == params.h_k) ? (bidh / params.h_h_k_ratio) : bidh);
const index_t row_offset_mask = binfo.mask_offset(params.mask_batch_stride, params.mask_row_stride, bidb)
- + (bidh / params.h_h_k_ratio) * params.mask_head_stride + (m_block_max - 1) * kBlockM * params.mask_row_stride + n_block * kBlockN;
+ + h_idx_mask * params.mask_head_stride + (m_block_max - 1) * kBlockM * params.mask_row_stride + n_block * kBlockN;
+ const int h_idx_bias = (params.h_bias == 1) ? 0 : ((params.h_bias == params.h_k) ? (bidh / params.h_h_k_ratio) : bidh);
const index_t row_offset_bias = binfo.bias_offset(params.bias_batch_stride, params.bias_row_stride, bidb)
- + (bidh / params.h_h_k_ratio) * params.bias_head_stride + (m_block_max - 1) * kBlockM * params.bias_row_stride + n_block * kBlockN;
+ + h_idx_bias * params.bias_head_stride + (m_block_max - 1) * kBlockM * params.bias_row_stride + n_block * kBlockN;
const index_t row_offset_dbias = binfo.bias_offset(params.dbias_batch_stride, params.dbias_row_stride, bidb)
+ bidh * params.dbias_head_stride + (m_block_max - 1) * kBlockM * params.dbias_row_stride + n_block * kBlockN;
const index_t row_offset_do = binfo.q_offset(params.do_batch_stride, params.do_row_stride, bidb)
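
Both the backward and forward kernels resolve the mask/bias head index with the same ternary: head 0 when the tensor has a single head, the KV-head index when it has h_k heads, and the query-head index otherwise. A plain-Python restatement of that rule, with the kernel's names treated as ordinary integers.

def mask_bias_head_index(bidh: int, h_attn: int, h_k: int, h_h_k_ratio: int) -> int:
    # Mirrors: (h_attn == 1) ? 0 : ((h_attn == h_k) ? (bidh / h_h_k_ratio) : bidh)
    if h_attn == 1:
        return 0                    # single head: shared by every query head
    if h_attn == h_k:
        return bidh // h_h_k_ratio  # one head per KV group (MQA/GQA)
    return bidh                     # one head per query head
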
20 changes: 12 additions & 8 deletions csrc/flash_dmattn/src/flash_fwd_kernel.h
@@ -136,6 +136,8 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
// might save us 1 register (we just need n_block instead of both n_block and n_block_max).

const index_t row_offset_p = ((bidb * params.h + bidh) * params.seqlen_q_rounded + m_block * kBlockM) * params.seqlen_k_rounded + (n_block_max - 1) * kBlockN;
+ const int h_idx_mask = (params.h_mask == 1) ? 0 : ((params.h_mask == params.h_k) ? (bidh / params.h_h_k_ratio) : bidh);
+ const int h_idx_bias = (params.h_bias == 1) ? 0 : ((params.h_bias == params.h_k) ? (bidh / params.h_h_k_ratio) : bidh);

// Global memory tensor configuration
Tensor mQ = make_tensor(
@@ -170,21 +172,21 @@
); // (kBlockN, kHeadDim, nblocksN)
Tensor mMask = make_tensor(
make_gmem_ptr(reinterpret_cast<const bool*>(params.mask_ptr) + binfo.mask_offset(params.mask_batch_stride, params.mask_row_stride, bidb)),
- make_shape(params.h_k, binfo.actual_seqlen_q, binfo.actual_seqlen_k),
+ make_shape(params.h_mask, binfo.actual_seqlen_q, binfo.actual_seqlen_k),
make_stride(params.mask_head_stride, params.mask_row_stride, _1{})
);
Tensor gMask = local_tile(
- mMask(bidh / params.h_h_k_ratio, _, _),
+ mMask(h_idx_mask, _, _),
Shape<Int<kBlockM>, Int<kBlockN>>{},
make_coord(m_block, _)
); // (kBlockM, kBlockN, nblocksN)
Tensor mBias = make_tensor(
make_gmem_ptr(reinterpret_cast<Element*>(params.bias_ptr) + binfo.bias_offset(params.bias_batch_stride, params.bias_row_stride, bidb)),
- make_shape(params.h_k, binfo.actual_seqlen_q, binfo.actual_seqlen_k),
+ make_shape(params.h_bias, binfo.actual_seqlen_q, binfo.actual_seqlen_k),
make_stride(params.bias_head_stride, params.bias_row_stride, _1{})
);
Tensor gBias = local_tile(
- mBias(bidh / params.h_h_k_ratio, _, _),
+ mBias(h_idx_bias, _, _),
Shape<Int<kBlockM>, Int<kBlockN>>{},
make_coord(m_block, _)
); // (kBlockM, kBlockN, nblocksN)
@@ -840,16 +842,18 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
? binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb_cache)
+ (n_block_max - 1) * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride
: block_table[block_table_idx] * params.v_batch_stride + block_table_offset * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride;
+ const int h_idx_mask = (params.h_mask == 1) ? 0 : ((params.h_mask == params.h_k) ? (bidh / params.h_h_k_ratio) : bidh);
const index_t col_offset_mask = (block_table == nullptr)
? binfo.mask_offset(params.mask_batch_stride, params.mask_row_stride, bidb_cache)
- + (bidh / params.h_h_k_ratio) * params.mask_head_stride + m_block * kBlockM * params.mask_row_stride + (n_block_max - 1) * kBlockN
+ + h_idx_mask * params.mask_head_stride + m_block * kBlockM * params.mask_row_stride + (n_block_max - 1) * kBlockN
: binfo.q_offset(/*batch_stride=*/index_t(0), params.mask_row_stride, bidb_cache)
- + (bidh / params.h_h_k_ratio) * params.mask_head_stride + m_block * kBlockM * params.mask_row_stride + block_table[block_table_idx] * params.mask_batch_stride + block_table_offset;
+ + h_idx_mask * params.mask_head_stride + m_block * kBlockM * params.mask_row_stride + block_table[block_table_idx] * params.mask_batch_stride + block_table_offset;
+ const int h_idx_bias = (params.h_bias == 1) ? 0 : ((params.h_bias == params.h_k) ? (bidh / params.h_h_k_ratio) : bidh);
const index_t col_offset_bias = (block_table == nullptr)
? binfo.bias_offset(params.bias_batch_stride, params.bias_row_stride, bidb_cache)
- + (bidh / params.h_h_k_ratio) * params.bias_head_stride + m_block * kBlockM * params.bias_row_stride + (n_block_max - 1) * kBlockN
+ + h_idx_bias * params.bias_head_stride + m_block * kBlockM * params.bias_row_stride + (n_block_max - 1) * kBlockN
: binfo.q_offset(/*batch_stride=*/index_t(0), params.bias_row_stride, bidb_cache)
- + (bidh / params.h_h_k_ratio) * params.bias_head_stride + m_block * kBlockM * params.bias_row_stride + block_table[block_table_idx] * params.bias_batch_stride + block_table_offset;
+ + h_idx_bias * params.bias_head_stride + m_block * kBlockM * params.bias_row_stride + block_table[block_table_idx] * params.bias_batch_stride + block_table_offset;

// Global memory tensor configuration
Tensor mQ = make_tensor(
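
In the split-KV path the resolved head index simply replaces the old (bidh / h_h_k_ratio) term in the mask/bias offset arithmetic. A sketch of the non-paged (block_table == nullptr) mask offset with the kernel's quantities treated as plain integers; batch_offset stands in for the binfo.mask_offset(...) term.

def mask_col_offset(batch_offset: int, h_idx_mask: int, mask_head_stride: int,
                    m_block: int, kBlockM: int, mask_row_stride: int,
                    n_block_max: int, kBlockN: int) -> int:
    # Start of this batch's mask, plus the selected mask head, plus the first
    # query row of this M block, plus the last key-column block of this split.
    return (batch_offset
            + h_idx_mask * mask_head_stride
            + m_block * kBlockM * mask_row_stride
            + (n_block_max - 1) * kBlockN)
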
8 changes: 6 additions & 2 deletions flash_dmattn/flash_dmattn_interface.py
@@ -361,10 +361,14 @@ def flash_dmattn_func(
key: torch.Tensor. The key tensor of shape (batch_size, seqlen, nheads_k, headdim)
value: torch.Tensor. The value tensor of shape (batch_size, seqlen, nheads_k, headdim)
attn_mask: torch.Tensor, optional. The attention mask boolean tensor of
- shape (batch_size, nheads_k, seqlen_q, seqlen_k) to apply to the attention scores.
+ shape (batch_size, nheads, seqlen_q, seqlen_k) to apply to the attention scores.
+ Also supports shape (batch_size, nheads_k, seqlen_q, seqlen_k) or
+ (batch_size, 1, seqlen_q, seqlen_k) for MQA/GQA.
If None, no mask is applied.
attn_bias: torch.Tensor, optional. The attention bias float tensor of
- shape (batch_size, nheads_k, seqlen_q, seqlen_k) to add to the attention scores.
+ shape (batch_size, nheads, seqlen_q, seqlen_k) to add to the attention scores.
+ Also supports shape (batch_size, nheads_k, seqlen_q, seqlen_k) or
+ (batch_size, 1, seqlen_q, seqlen_k) for MQA/GQA.
If None, no bias is applied.
is_causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
scale: float. The scaling of QK^T before applying softmax.
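
A usage sketch for the relaxed mask/bias head shapes documented above. The import path and the bfloat16/CUDA setup are assumptions for illustration; the keyword names follow the docstring.

import torch
from flash_dmattn import flash_dmattn_func  # import path assumed

B, Lq, Lk, H, Hk, D = 2, 128, 128, 8, 2, 64
q = torch.randn(B, Lq, H, D, device="cuda", dtype=torch.bfloat16)
k = torch.randn(B, Lk, Hk, D, device="cuda", dtype=torch.bfloat16)
v = torch.randn(B, Lk, Hk, D, device="cuda", dtype=torch.bfloat16)

# The mask/bias head dimension may now be nheads (H), nheads_k (Hk), or 1.
attn_mask = torch.ones(B, 1, Lq, Lk, device="cuda", dtype=torch.bool)
attn_bias = torch.zeros(B, Hk, Lq, Lk, device="cuda", dtype=torch.bfloat16)

out = flash_dmattn_func(q, k, v, attn_mask=attn_mask, attn_bias=attn_bias, is_causal=True)
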
4 changes: 2 additions & 2 deletions flash_dmattn/integrations/flash_dynamic_mask_attention.py
@@ -29,8 +29,8 @@ def flash_dynamic_mask_attention_forward(
query (torch.Tensor): The query tensor of shape (batch_size, num_heads, query_len, head_dim).
key (torch.Tensor): The key tensor of shape (batch_size, num_kv_heads, key_len, head_dim).
value (torch.Tensor): The value tensor of shape (batch_size, num_kv_heads, key_len, head_dim).
- attention_mask (Optional[torch.Tensor]): The attention mask boolean tensor of shape (batch_size, num_kv_heads, query_len, key_len).
- attention_bias (Optional[torch.Tensor]): The attention bias float tensor of shape (batch_size, num_kv_heads, query_len, key_len).
+ attention_mask (Optional[torch.Tensor]): The attention mask boolean tensor of shape (batch_size, num_heads, query_len, key_len), also supports (batch_size, num_kv_heads, query_len, key_len) or (batch_size, 1, query_len, key_len) for MQA/GQA.
+ attention_bias (Optional[torch.Tensor]): The attention bias float tensor of shape (batch_size, num_heads, query_len, key_len), also supports (batch_size, num_kv_heads, query_len, key_len) or (batch_size, 1, query_len, key_len) for MQA/GQA.
scaling (Optional[float]): The scaling factor for the attention scores.
softcap (Optional[float]): The softcap value for the attention scores.
**kwargs: Additional keyword arguments.