diff --git a/csrc/flash_dmattn/flash_api.cpp b/csrc/flash_dmattn/flash_api.cpp
index 0782b90..4a67ec1 100644
--- a/csrc/flash_dmattn/flash_api.cpp
+++ b/csrc/flash_dmattn/flash_api.cpp
@@ -872,14 +872,12 @@ mha_bwd(
     const int num_heads_k = k.size(2);
     int num_heads_mask = has_mask ? mask.size(1) : 1;
     int num_heads_bias = has_bias ? bias.size(1) : 1;
-    int batch_size_mask = has_mask ? mask.size(0) : batch_size;
-    int batch_size_bias = has_bias ? bias.size(0) : batch_size;
-    int seqlen_q_mask = has_mask ? mask.size(2) : seqlen_q;
-    int seqlen_q_bias = has_bias ? bias.size(2) : seqlen_q;
     auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
     const int head_size_rounded = round_multiple(head_size, head_size <= 128 ? 32 : 64);
     const int seqlen_q_rounded = round_multiple(seqlen_q, 128);
     const int seqlen_k_rounded = round_multiple(seqlen_k, 128);
+    int batch_size_dbias = has_bias ? bias.size(0) : batch_size;
+    int seqlen_q_dbias = has_bias ? bias.size(2) : seqlen_q;
 
     TORCH_CHECK(batch_size > 0, "batch size must be positive");
     TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8");
@@ -945,7 +943,7 @@ mha_bwd(
             TORCH_CHECK(dbias.size(2) == seqlen_q || dbias.size(2) == 1, "Query length dimension in dbias must be 1 or equal to seqlen_q");
             TORCH_CHECK(dbias.size(3) == seqlen_k_rounded, "Key length dimension in dbias must be seqlen_k_rounded");
         } else {
-            dbias = torch::empty({batch_size_bias, num_heads_bias, seqlen_q_bias, seqlen_k_rounded}, opts);
+            dbias = torch::empty({batch_size_dbias, num_heads_bias, seqlen_q_dbias, seqlen_k_rounded}, opts);
         }
     } else {
         dbias = torch::empty({0}, opts);
@@ -977,8 +975,8 @@ mha_bwd(
         ? torch::empty({batch_size, seqlen_k, num_heads, head_size}, opts)
         : dv;
     dbias_expanded = has_bias
-        ? (num_heads_bias != num_heads || batch_size_bias == 1 || seqlen_q_bias == 1)    // MQA / GQA or dbias has different batch size or seqlen_q
-            ? (seqlen_q_bias == 1)
+        ? (num_heads_bias != num_heads || batch_size_dbias == 1 || seqlen_q_dbias == 1)    // MQA / GQA or dbias has different batch size or seqlen_q
+            ? (seqlen_q_dbias == 1)
                 ? torch::zeros({batch_size, num_heads, 1, seqlen_k_rounded}, opts)
                 : torch::zeros({batch_size, num_heads, seqlen_q, seqlen_k_rounded}, opts)
             : dbias
@@ -1033,15 +1031,15 @@ mha_bwd(
     }
     // For MQA/GQA or dbias has different batch size or seqlen_q, we need to sum dbias across the groups, batch and seqlen_q
     if (has_bias) {
-        if (num_heads_bias != num_heads && batch_size_bias == batch_size && seqlen_q_bias == seqlen_q) {
+        if (num_heads_bias != num_heads && batch_size_dbias == batch_size && seqlen_q_dbias == seqlen_q) {
             at::sum_out(dbias, at::reshape(dbias_expanded, {batch_size, num_heads_bias, num_heads / num_heads_bias, seqlen_q, seqlen_k_rounded}), {2});
         } else {
-            if (seqlen_q_bias == 1) {
+            if (seqlen_q_dbias == 1) {
                 dbias_expanded = at::sum(at::reshape(dbias_expanded, {batch_size, num_heads_bias, num_heads / num_heads_bias, 1, seqlen_k_rounded}), {2});
             } else {
                 dbias_expanded = at::sum(at::reshape(dbias_expanded, {batch_size, num_heads_bias, num_heads / num_heads_bias, seqlen_q, seqlen_k_rounded}), {2});
             }
-            if (batch_size_bias == 1) {
+            if (batch_size_dbias == 1) {
                 dbias_expanded = at::sum(dbias_expanded, {0}, true);
             }
             dbias.copy_(dbias_expanded);
diff --git a/examples/modeling/modeling_doge.py b/examples/modeling/modeling_doge.py
index 1a0f6e0..d350f71 100644
--- a/examples/modeling/modeling_doge.py
+++ b/examples/modeling/modeling_doge.py
@@ -218,7 +218,7 @@ def forward(
             value_states.transpose(1, 2).reshape(value_states.shape[0], value_states.shape[-2], -1)
         )
         # original formula is exp(A * softplus(delta V)), but for numerical stability, it is changed to A * softplus(delta V)
-        attn_bias = self.A * F.softplus(dt_states).transpose(-1, -2).unsqueeze(-2).to(hidden_states.dtype)
+        attn_bias = (self.A * F.softplus(dt_states)).transpose(-1, -2).unsqueeze(-2).to(hidden_states.dtype)
 
         attention_interface: Callable = flash_dynamic_mask_attention_forward
 
@@ -230,6 +230,7 @@ def forward(
             attention_mask=attention_mask,
             attention_bias=attn_bias,
             scale=self.scaling,
+            window_size=self.window_size,
         )
 
         attn_output = attn_output.reshape(*input_shape, -1).contiguous()
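
Note on the modeling_doge.py change: the added parentheses fix operator precedence so that self.A scales the head axis before the transpose. A minimal PyTorch sketch of the difference, with shapes assumed for illustration (A is per-head, dt_states is [batch, seq_len, num_heads] as produced by dt_proj):

    import torch
    import torch.nn.functional as F

    batch, seq_len, num_heads = 2, 16, 4          # assumed sizes, illustration only
    A = torch.randn(num_heads)                    # per-head scale
    dt_states = torch.randn(batch, seq_len, num_heads)

    # Fixed expression: scale while heads are still the last axis, then transpose.
    attn_bias = (A * F.softplus(dt_states)).transpose(-1, -2).unsqueeze(-2)
    print(attn_bias.shape)                        # torch.Size([2, 4, 1, 16])

    # The previous expression transposed first:
    #     A * F.softplus(dt_states).transpose(-1, -2)   # [batch, num_heads, seq_len]
    # so A ([num_heads]) would broadcast against the sequence axis, which raises a
    # size mismatch whenever seq_len != num_heads and scales the wrong axis when
    # the two happen to be equal.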
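
The flash_api.cpp change renames the bias broadcast-tracking variables to batch_size_dbias / seqlen_q_dbias and drops their unused mask counterparts; the reduction they guard is easiest to read in PyTorch terms. A minimal sketch of that reduction, with assumed sizes, mirroring the at::reshape + at::sum calls in the diff:

    import torch

    batch, num_heads, num_heads_bias = 2, 8, 2    # assumed: GQA, 4 query heads per bias head
    seqlen_q, seqlen_k_rounded = 16, 128

    # The kernel writes the bias gradient at full size ...
    dbias_expanded = torch.randn(batch, num_heads, seqlen_q, seqlen_k_rounded)

    # ... and when bias was broadcast over head groups, the gradient is summed back
    # onto the bias head axis (the at::reshape + at::sum(..., {2}) in the diff).
    dbias = dbias_expanded.reshape(
        batch, num_heads_bias, num_heads // num_heads_bias, seqlen_q, seqlen_k_rounded
    ).sum(dim=2)

    # If bias was also broadcast over batch (batch_size_dbias == 1), that axis is
    # reduced as well, keeping the dimension so the copy back into dbias lines up.
    dbias = dbias.sum(dim=0, keepdim=True)
    print(dbias.shape)                            # torch.Size([1, 2, 16, 128])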