From d01b4897fca271a78acf1b6174f249766c76f140 Mon Sep 17 00:00:00 2001
From: LoserCheems <3314685395@qq.com>
Date: Sun, 21 Sep 2025 17:34:01 +0800
Subject: [PATCH 1/2] Refactor attention mask handling in
 _flash_dynamic_mask_attention_forward to simplify logic and ensure proper
 initialization

---
 ...modeling_flash_dynamic_mask_attention_utils.py | 15 ++++++++-------
 1 file changed, 8 insertions(+), 7 deletions(-)

diff --git a/flash_dmattn/integrations/modeling_flash_dynamic_mask_attention_utils.py b/flash_dmattn/integrations/modeling_flash_dynamic_mask_attention_utils.py
index 2c361bf..cae5b3e 100644
--- a/flash_dmattn/integrations/modeling_flash_dynamic_mask_attention_utils.py
+++ b/flash_dmattn/integrations/modeling_flash_dynamic_mask_attention_utils.py
@@ -93,13 +93,14 @@ def _flash_dynamic_mask_attention_forward(
             min_dtype
         )
 
-        if keep_window_size is not None:
-            if key_length > keep_window_size:
-                topk_values, topk_indices = torch.topk(
-                    attention_bias, keep_window_size, dim=-1, largest=True, sorted=False
-                )
-                attention_mask = torch.zeros_like(attention_bias, dtype=torch.bool, device=attention_bias.device)
-                attention_mask = attention_mask.scatter(-1, topk_indices, topk_values != min_dtype)
+        if keep_window_size is not None and key_length > keep_window_size:
+            topk_values, topk_indices = torch.topk(
+                attention_bias, keep_window_size, dim=-1, largest=True, sorted=False
+            )
+            attention_mask = torch.zeros_like(attention_bias, dtype=torch.bool, device=attention_bias.device)
+            attention_mask = attention_mask.scatter(-1, topk_indices, topk_values != min_dtype)
+        else:
+            attention_mask = None
 
     out = flash_fn(
         query_states, key_states, value_states, attn_mask=attention_mask, attn_bias=attention_bias, scale=softmax_scale, is_causal=is_causal

From 80b25f7906bb4c54acf282b1174636e5ccb5fc3a Mon Sep 17 00:00:00 2001
From: LoserCheems <3314685395@qq.com>
Date: Sun, 21 Sep 2025 17:37:12 +0800
Subject: [PATCH 2/2] Simplifies attention bias computation and improves
 weight initialization

Removes unnecessary tensor expansion operations in attention bias
calculation to improve memory efficiency and computational performance.
Changes weight initialization from zero to normal distribution for better
training dynamics and gradient flow.
---
 examples/modeling/modeling_doge.py | 8 ++------
 1 file changed, 2 insertions(+), 6 deletions(-)

diff --git a/examples/modeling/modeling_doge.py b/examples/modeling/modeling_doge.py
index 0779d50..377dfb2 100644
--- a/examples/modeling/modeling_doge.py
+++ b/examples/modeling/modeling_doge.py
@@ -241,16 +241,12 @@ def forward(
         dt_states = self.dt_proj(
             value_states.transpose(1, 2).reshape(value_states.shape[0], value_states.shape[-2], -1)
         )
-        dt_states = torch.exp(self.A * F.softplus(dt_states)).transpose(-1, -2)
-        attn_bias = dt_states[:, :, None, :].expand(
-            -1, -1, hidden_states.shape[1], -1
-        ).to(hidden_states.dtype) # [batch_size, num_heads, query_len, key_len]
+        attn_bias = torch.exp(self.A * F.softplus(dt_states)).transpose(-1, -2).to(hidden_states.dtype)
 
         attention_interface: Callable = eager_attention_forward
         if flash_dynamic_mask_attention_forward is not None:
             attention_interface = flash_dynamic_mask_attention_forward
-            attention_mask = attention_mask.expand(-1, attn_bias.shape[1], -1, -1) if attention_mask is not None else None # attention_mask: batch, num_kv_heads, query_len, key_len
 
         attn_output, attn_weights = attention_interface(
             self,
             query_states,
@@ -414,7 +410,7 @@ def _init_weights(self, module):
         super()._init_weights(module)
         if isinstance(module, DogeAttention):
             if hasattr(module, "A"):
-                module.A.data.zero_()
+                module.A.data.normal_(mean=0.0, std=self.config.initializer_range)
         elif isinstance(module, DogeCDMoE):
             if hasattr(module, "router_gate"):
                 module.router_gate.weight.data.zero_()
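
The refactor in PATCH 1/2 makes the fallback path explicit: attention_mask is
built via top-k selection only when key_length exceeds keep_window_size, and is
otherwise set to None before reaching flash_fn. The following standalone sketch
mirrors that control flow outside the library; the tensor shapes, the random
attention_bias, and the min_dtype sentinel are illustrative assumptions, not
the values the integration actually receives.

import torch

# Illustrative shapes; in the integration the bias comes from the model's dt_proj path.
batch, num_heads, query_len, key_length = 2, 4, 8, 128
keep_window_size = 32

attention_bias = torch.randn(batch, num_heads, query_len, key_length)
min_dtype = torch.finfo(attention_bias.dtype).min

if keep_window_size is not None and key_length > keep_window_size:
    # Keep only the keep_window_size highest-scoring keys per query.
    topk_values, topk_indices = torch.topk(
        attention_bias, keep_window_size, dim=-1, largest=True, sorted=False
    )
    attention_mask = torch.zeros_like(attention_bias, dtype=torch.bool)
    # A selected position stays active only if its bias was not already the masked-out value.
    attention_mask = attention_mask.scatter(-1, topk_indices, topk_values != min_dtype)
else:
    # Short sequences skip the top-k path entirely; the kernel receives no mask.
    attention_mask = None

# Each query now attends to at most keep_window_size keys.
print(None if attention_mask is None else attention_mask.sum(dim=-1).unique())

PATCH 2/2 is independent of this sketch: it keeps the bias unexpanded (the
removed comment documented the old [batch_size, num_heads, query_len, key_len]
expansion) and replaces module.A.data.zero_() with the standard in-place
initializer module.A.data.normal_(mean=0.0, std=self.config.initializer_range).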