18 changes: 8 additions & 10 deletions csrc/flash_dmattn/flash_api.cpp
@@ -872,14 +872,12 @@ mha_bwd(
const int num_heads_k = k.size(2);
int num_heads_mask = has_mask ? mask.size(1) : 1;
int num_heads_bias = has_bias ? bias.size(1) : 1;
-int batch_size_mask = has_mask ? mask.size(0) : batch_size;
-int batch_size_bias = has_bias ? bias.size(0) : batch_size;
-int seqlen_q_mask = has_mask ? mask.size(2) : seqlen_q;
-int seqlen_q_bias = has_bias ? bias.size(2) : seqlen_q;
auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
const int head_size_rounded = round_multiple(head_size, head_size <= 128 ? 32 : 64);
const int seqlen_q_rounded = round_multiple(seqlen_q, 128);
const int seqlen_k_rounded = round_multiple(seqlen_k, 128);
+int batch_size_dbias = has_bias ? bias.size(0) : batch_size;
+int seqlen_q_dbias = has_bias ? bias.size(2) : seqlen_q;

TORCH_CHECK(batch_size > 0, "batch size must be positive");
TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8");
@@ -945,7 +943,7 @@ mha_bwd(
TORCH_CHECK(dbias.size(2) == seqlen_q || dbias.size(2) == 1, "Query length dimension in dbias must be 1 or equal to seqlen_q");
TORCH_CHECK(dbias.size(3) == seqlen_k_rounded, "Key length dimension in dbias must be seqlen_k_rounded");
} else {
-dbias = torch::empty({batch_size_bias, num_heads_bias, seqlen_q_bias, seqlen_k_rounded}, opts);
+dbias = torch::empty({batch_size_dbias, num_heads_bias, seqlen_q_dbias, seqlen_k_rounded}, opts);
}
} else {
dbias = torch::empty({0}, opts);
@@ -977,8 +975,8 @@ mha_bwd(
? torch::empty({batch_size, seqlen_k, num_heads, head_size}, opts)
: dv;
dbias_expanded = has_bias
-? (num_heads_bias != num_heads || batch_size_bias == 1 || seqlen_q_bias == 1) // MQA / GQA or dbias has different batch size or seqlen_q
-? (seqlen_q_bias == 1)
+? (num_heads_bias != num_heads || batch_size_dbias == 1 || seqlen_q_dbias == 1) // MQA / GQA or dbias has different batch size or seqlen_q
+? (seqlen_q_dbias == 1)
? torch::zeros({batch_size, num_heads, 1, seqlen_k_rounded}, opts)
: torch::zeros({batch_size, num_heads, seqlen_q, seqlen_k_rounded}, opts)
: dbias
@@ -1033,15 +1031,15 @@ mha_bwd(
}
// For MQA/GQA or dbias has different batch size or seqlen_q, we need to sum dbias across the groups, batch and seqlen_q
if (has_bias) {
-if (num_heads_bias != num_heads && batch_size_bias == batch_size && seqlen_q_bias == seqlen_q) {
+if (num_heads_bias != num_heads && batch_size_dbias == batch_size && seqlen_q_dbias == seqlen_q) {
at::sum_out(dbias, at::reshape(dbias_expanded, {batch_size, num_heads_bias, num_heads / num_heads_bias, seqlen_q, seqlen_k_rounded}), {2});
} else {
-if (seqlen_q_bias == 1) {
+if (seqlen_q_dbias == 1) {
dbias_expanded = at::sum(at::reshape(dbias_expanded, {batch_size, num_heads_bias, num_heads / num_heads_bias, 1, seqlen_k_rounded}), {2});
} else {
dbias_expanded = at::sum(at::reshape(dbias_expanded, {batch_size, num_heads_bias, num_heads / num_heads_bias, seqlen_q, seqlen_k_rounded}), {2});
}
-if (batch_size_bias == 1) {
+if (batch_size_dbias == 1) {
dbias_expanded = at::sum(dbias_expanded, {0}, true);
}
dbias.copy_(dbias_expanded);
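Note on the reduction logic above: when the bias was broadcast over heads (MQA/GQA), batch, or query length in the forward pass, its gradient is obtained by summing the full-size gradient over each broadcast dimension; the renamed `batch_size_dbias`/`seqlen_q_dbias` variables track which dimensions were broadcast. A minimal PyTorch sketch of that end result, with illustrative names and a full-size starting tensor (in the extension, part of this accumulation happens inside the kernel and in the `at::sum` calls above):

```python
import torch

def reduce_dbias(dbias_full, num_heads_bias, batch_size_dbias, seqlen_q_dbias):
    """Collapse a full-size bias gradient of shape
    (batch_size, num_heads, seqlen_q, seqlen_k) back to the broadcast
    bias shape (batch_size_dbias, num_heads_bias, seqlen_q_dbias, seqlen_k).
    Illustrative only; names follow the variables in the C++ above."""
    batch_size, num_heads, seqlen_q, seqlen_k = dbias_full.shape
    # Sum over the GQA/MQA group dimension: each bias head was shared by
    # num_heads // num_heads_bias query heads in the forward pass.
    d = dbias_full.reshape(
        batch_size, num_heads_bias, num_heads // num_heads_bias, seqlen_q, seqlen_k
    ).sum(dim=2)
    if seqlen_q_dbias == 1:      # bias was broadcast over the query length
        d = d.sum(dim=2, keepdim=True)
    if batch_size_dbias == 1:    # bias was broadcast over the batch
        d = d.sum(dim=0, keepdim=True)
    return d

# Example: 8 query heads sharing 2 bias heads, bias also broadcast over batch.
g = torch.randn(4, 8, 128, 128)
print(reduce_dbias(g, num_heads_bias=2, batch_size_dbias=1, seqlen_q_dbias=128).shape)
# torch.Size([1, 2, 128, 128])
```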
3 changes: 2 additions & 1 deletion examples/modeling/modeling_doge.py
@@ -218,7 +218,7 @@ def forward(
value_states.transpose(1, 2).reshape(value_states.shape[0], value_states.shape[-2], -1)
)
# original formula is exp(A * softplus(delta V)), but for numerical stability, it is changed to A * softplus(delta V)
-attn_bias = self.A * F.softplus(dt_states).transpose(-1, -2).unsqueeze(-2).to(hidden_states.dtype)
+attn_bias = (self.A * F.softplus(dt_states)).transpose(-1, -2).unsqueeze(-2).to(hidden_states.dtype)

attention_interface: Callable = flash_dynamic_mask_attention_forward

@@ -230,6 +230,7 @@ def forward(
attention_mask=attention_mask,
attention_bias=attn_bias,
scale=self.scaling,
+window_size=self.window_size,
)

attn_output = attn_output.reshape(*input_shape, -1).contiguous()
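The `modeling_doge.py` change is an operator-precedence fix: in the old line, `.transpose(-1, -2)` binds before the multiplication, so the per-head parameter `self.A` is broadcast against the sequence axis rather than the head axis. A small sketch of the difference, assuming illustrative shapes of `(num_heads,)` for `A` and `(batch, seq_len, num_heads)` for `dt_states`:

```python
import torch
import torch.nn.functional as F

batch, seq_len, num_heads = 2, 16, 4          # illustrative sizes
A = torch.randn(num_heads)                    # assumed per-head parameter, like self.A
dt_states = torch.randn(batch, seq_len, num_heads)

# Old line: .transpose() binds before *, so A (num_heads,) would broadcast
# against the trailing seq_len axis of the (batch, num_heads, seq_len) tensor,
# which is the wrong axis (and a shape error whenever seq_len != num_heads).
# bad = A * F.softplus(dt_states).transpose(-1, -2)

# Fixed line: scale each head by its own A first, then rearrange the layout.
attn_bias = (A * F.softplus(dt_states)).transpose(-1, -2).unsqueeze(-2)
print(attn_bias.shape)  # torch.Size([2, 4, 1, 16]) == (batch, num_heads, 1, seq_len)
```

With the parentheses, the per-head scaling happens while heads are still the last axis, and only then is the tensor rearranged into the `(batch, num_heads, 1, seq_len)` layout used as `attention_bias`.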