examples/modeling/modeling_doge.py (8 changes: 2 additions & 6 deletions)

@@ -241,16 +241,12 @@ def forward(
         dt_states = self.dt_proj(
             value_states.transpose(1, 2).reshape(value_states.shape[0], value_states.shape[-2], -1)
         )
-        dt_states = torch.exp(self.A * F.softplus(dt_states)).transpose(-1, -2)
-        attn_bias = dt_states[:, :, None, :].expand(
-            -1, -1, hidden_states.shape[1], -1
-        ).to(hidden_states.dtype)  # [batch_size, num_heads, query_len, key_len]
+        attn_bias = torch.exp(self.A * F.softplus(dt_states)).transpose(-1, -2).to(hidden_states.dtype)
Copilot AI commented (Sep 21, 2025):

The removal of the tensor expansion operation may cause shape mismatch issues. The original code expanded attn_bias to the expected [batch_size, num_heads, query_len, key_len], but the new line only produces the three dimensions left after the transpose, [batch_size, num_heads, key_len]. This could lead to broadcasting errors in subsequent attention computations.

Suggested change:

    attn_bias = torch.exp(self.A * F.softplus(dt_states)).transpose(-1, -2).to(hidden_states.dtype)
    # Ensure attn_bias has shape [batch_size, num_heads, query_len, key_len]
    # query_states: [batch_size, num_heads, query_len, head_dim]
    # key_states: [batch_size, num_heads, key_len, head_dim]
    # attn_bias: [batch_size, query_len, key_len] or similar
    if attn_bias.dim() == 3:
        attn_bias = attn_bias.unsqueeze(1)  # [batch_size, 1, query_len, key_len]
    attn_bias = attn_bias.expand(-1, query_states.shape[1], -1, -1)  # [batch_size, num_heads, query_len, key_len]

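To make the shape concern concrete, here is a minimal, self-contained sketch with made-up sizes (not code from this PR); whether the 3-D bias is actually a problem depends on what the selected attention_interface accepts:

```python
import torch

# Illustrative sizes only; names mirror the diff but this is not the
# repository's eager_attention_forward.
batch_size, num_heads, query_len, key_len, head_dim = 2, 4, 8, 8, 16

query_states = torch.randn(batch_size, num_heads, query_len, head_dim)
key_states = torch.randn(batch_size, num_heads, key_len, head_dim)
scores = query_states @ key_states.transpose(-1, -2)      # [batch_size, num_heads, query_len, key_len]

attn_bias = torch.randn(batch_size, num_heads, key_len)   # 3-D shape produced by the new line

# scores + attn_bias would broadcast [batch, heads, keys] against
# [batch, heads, queries, keys] from the right, which fails (or silently
# misaligns axes) unless num_heads happens to equal query_len.
# Re-inserting an explicit query axis restores the intended broadcast:
biased = scores + attn_bias[:, :, None, :]                # [batch_size, num_heads, query_len, key_len]
print(biased.shape)                                       # torch.Size([2, 4, 8, 8])
```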
         attention_interface: Callable = eager_attention_forward
         if flash_dynamic_mask_attention_forward is not None:
             attention_interface = flash_dynamic_mask_attention_forward
 
-        attention_mask = attention_mask.expand(-1, attn_bias.shape[1], -1, -1) if attention_mask is not None else None  # attention_mask: batch, num_kv_heads, query_len, key_len
         attn_output, attn_weights = attention_interface(
             self,
             query_states,
@@ -414,7 +410,7 @@ def _init_weights(self, module):
         super()._init_weights(module)
         if isinstance(module, DogeAttention):
             if hasattr(module, "A"):
Copilot AI commented (Sep 21, 2025):

Changing the initialization of module.A from zeros to a normal distribution is a significant change that could affect model convergence and performance. It should be documented or justified, as zero initialization might have been intentional for stability reasons in the attention mechanism.

Suggested change:

    if hasattr(module, "A"):
        # Initialize module.A with a normal distribution for better convergence.
        # Zero initialization was considered, but normal initialization empirically improves stability and performance in this attention mechanism.
        # See: [Add reference or empirical result if available]


Copilot uses AI. Check for mistakes.
-                module.A.data.zero_()
+                module.A.data.normal_(mean=0.0, std=self.config.initializer_range)
         elif isinstance(module, DogeCDMoE):
             if hasattr(module, "router_gate"):
                 module.router_gate.weight.data.zero_()
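For context on the reviewer's concern above, the sketch below shows how the two initializations of A differ at step zero, given the attn_bias = exp(A * softplus(dt_states)) expression from the first hunk. Shapes are illustrative and initializer_range = 0.02 is an assumed stand-in for config.initializer_range:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch_size, key_len, num_heads = 2, 8, 4
initializer_range = 0.02  # assumption, mirroring a typical config.initializer_range

dt_states = torch.randn(batch_size, key_len, num_heads)

# Old init: A == 0, so exp(0 * softplus(dt)) == 1 everywhere; the bias carries
# no per-position signal at initialization.
A_zero = torch.zeros(num_heads)
bias_zero = torch.exp(A_zero * F.softplus(dt_states)).transpose(-1, -2)
print(bias_zero.unique())                    # tensor([1.])

# New init: A ~ N(0, initializer_range), so the bias varies slightly per head
# and key position from the first forward pass onward.
A_normal = torch.empty(num_heads).normal_(mean=0.0, std=initializer_range)
bias_normal = torch.exp(A_normal * F.softplus(dt_states)).transpose(-1, -2)
print(bias_normal.min(), bias_normal.max())  # close to, but not exactly, 1.0
```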
@@ -93,13 +93,14 @@ def _flash_dynamic_mask_attention_forward(
         min_dtype
     )
 
-    if keep_window_size is not None:
-        if key_length > keep_window_size:
-            topk_values, topk_indices = torch.topk(
-                attention_bias, keep_window_size, dim=-1, largest=True, sorted=False
-            )
-            attention_mask = torch.zeros_like(attention_bias, dtype=torch.bool, device=attention_bias.device)
-            attention_mask = attention_mask.scatter(-1, topk_indices, topk_values != min_dtype)
+    if keep_window_size is not None and key_length > keep_window_size:
+        topk_values, topk_indices = torch.topk(
+            attention_bias, keep_window_size, dim=-1, largest=True, sorted=False
+        )
+        attention_mask = torch.zeros_like(attention_bias, dtype=torch.bool, device=attention_bias.device)
+        attention_mask = attention_mask.scatter(-1, topk_indices, topk_values != min_dtype)
+    else:
+        attention_mask = None
 
     out = flash_fn(
         query_states, key_states, value_states, attn_mask=attention_mask, attn_bias=attention_bias, scale=softmax_scale, is_causal=is_causal
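A standalone sketch of the refactored windowing logic above, with assumed shapes rather than the library's actual _flash_dynamic_mask_attention_forward arguments:

```python
import torch

# When key_length exceeds keep_window_size, only the keep_window_size keys
# with the largest bias are kept per (head, query) row; otherwise no boolean
# mask is built at all.
batch_size, num_heads, query_len, key_length = 1, 2, 4, 16
keep_window_size = 8
min_dtype = torch.finfo(torch.float32).min

attention_bias = torch.randn(batch_size, num_heads, query_len, key_length)

if keep_window_size is not None and key_length > keep_window_size:
    topk_values, topk_indices = torch.topk(
        attention_bias, keep_window_size, dim=-1, largest=True, sorted=False
    )
    attention_mask = torch.zeros_like(attention_bias, dtype=torch.bool)
    # Selected keys become attendable unless their bias already equals the
    # masked-out sentinel (min_dtype).
    attention_mask = attention_mask.scatter(-1, topk_indices, topk_values != min_dtype)
else:
    attention_mask = None

print(attention_mask.sum(dim=-1))  # at most keep_window_size True entries per row
```

The else branch makes the no-window case explicit (attention_mask = None) instead of leaving whatever value attention_mask held before the block.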