diff --git a/examples/modeling/modeling_doge.py b/examples/modeling/modeling_doge.py
index 0779d50..377dfb2 100644
--- a/examples/modeling/modeling_doge.py
+++ b/examples/modeling/modeling_doge.py
@@ -241,16 +241,12 @@ def forward(
         dt_states = self.dt_proj(
             value_states.transpose(1, 2).reshape(value_states.shape[0], value_states.shape[-2], -1)
         )
-        dt_states = torch.exp(self.A * F.softplus(dt_states)).transpose(-1, -2)
-        attn_bias = dt_states[:, :, None, :].expand(
-            -1, -1, hidden_states.shape[1], -1
-        ).to(hidden_states.dtype)  # [batch_size, num_heads, query_len, key_len]
+        attn_bias = torch.exp(self.A * F.softplus(dt_states)).transpose(-1, -2).to(hidden_states.dtype)
 
         attention_interface: Callable = eager_attention_forward
         if flash_dynamic_mask_attention_forward is not None:
             attention_interface = flash_dynamic_mask_attention_forward
 
-        attention_mask = attention_mask.expand(-1, attn_bias.shape[1], -1, -1) if attention_mask is not None else None  # attention_mask: batch, num_kv_heads, query_len, key_len
         attn_output, attn_weights = attention_interface(
             self,
             query_states,
@@ -414,7 +410,7 @@ def _init_weights(self, module):
         super()._init_weights(module)
         if isinstance(module, DogeAttention):
             if hasattr(module, "A"):
-                module.A.data.zero_()
+                module.A.data.normal_(mean=0.0, std=self.config.initializer_range)
         elif isinstance(module, DogeCDMoE):
             if hasattr(module, "router_gate"):
                 module.router_gate.weight.data.zero_()
diff --git a/flash_dmattn/integrations/modeling_flash_dynamic_mask_attention_utils.py b/flash_dmattn/integrations/modeling_flash_dynamic_mask_attention_utils.py
index 2c361bf..cae5b3e 100644
--- a/flash_dmattn/integrations/modeling_flash_dynamic_mask_attention_utils.py
+++ b/flash_dmattn/integrations/modeling_flash_dynamic_mask_attention_utils.py
@@ -93,13 +93,14 @@ def _flash_dynamic_mask_attention_forward(
             min_dtype
         )
 
-    if keep_window_size is not None:
-        if key_length > keep_window_size:
-            topk_values, topk_indices = torch.topk(
-                attention_bias, keep_window_size, dim=-1, largest=True, sorted=False
-            )
-            attention_mask = torch.zeros_like(attention_bias, dtype=torch.bool, device=attention_bias.device)
-            attention_mask = attention_mask.scatter(-1, topk_indices, topk_values != min_dtype)
+    if keep_window_size is not None and key_length > keep_window_size:
+        topk_values, topk_indices = torch.topk(
+            attention_bias, keep_window_size, dim=-1, largest=True, sorted=False
+        )
+        attention_mask = torch.zeros_like(attention_bias, dtype=torch.bool, device=attention_bias.device)
+        attention_mask = attention_mask.scatter(-1, topk_indices, topk_values != min_dtype)
+    else:
+        attention_mask = None
 
     out = flash_fn(
         query_states, key_states, value_states, attn_mask=attention_mask, attn_bias=attention_bias, scale=softmax_scale, is_causal=is_causal
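
For reviewers, below is a minimal, self-contained sketch of the top-k masking path introduced in the second file's hunk. The helper name `build_topk_attention_mask`, the example tensor shapes, and the way `min_dtype` is obtained are illustrative assumptions, not part of the patch; the branch bodies mirror the `+` lines above.

```python
import torch

def build_topk_attention_mask(attention_bias, keep_window_size, min_dtype):
    # attention_bias: [batch, num_heads, query_len, key_len] (example layout)
    key_length = attention_bias.shape[-1]
    if keep_window_size is not None and key_length > keep_window_size:
        # Select the keep_window_size keys with the largest bias per query position.
        topk_values, topk_indices = torch.topk(
            attention_bias, keep_window_size, dim=-1, largest=True, sorted=False
        )
        attention_mask = torch.zeros_like(attention_bias, dtype=torch.bool, device=attention_bias.device)
        # Mark the selected keys, skipping any whose bias is already at the min_dtype sentinel.
        attention_mask = attention_mask.scatter(-1, topk_indices, topk_values != min_dtype)
    else:
        # Short key sequences (or keep_window_size=None) pass no sparsity mask at all.
        attention_mask = None
    return attention_mask

# Example usage with hypothetical shapes: 16 keys, keep the 8 largest-bias ones per query.
bias = torch.randn(1, 2, 4, 16)
mask = build_topk_attention_mask(bias, keep_window_size=8, min_dtype=torch.finfo(bias.dtype).min)
print(mask.shape, mask.sum(dim=-1))  # at most 8 keys kept per query position
```

Note the explicit `else` branch: when no top-k selection is needed, `attention_mask` is now set to `None` rather than left as whatever was computed earlier in the function.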