Merged
7 changes: 4 additions & 3 deletions examples/modeling/modeling_doge.py
@@ -353,11 +353,12 @@ def prepare_dynamic_mask(
     )
     attn_bias = attn_bias.masked_fill(attention_mask[:, :, :, : attn_bias.shape[-1]] != 0, min_dtype)
     if attn_bias.shape[-1] > keep_window_size:
-        topk_indices = torch.topk(
+        topk_values, topk_indices = torch.topk(
             attn_bias, keep_window_size, dim=-1, largest=True, sorted=False
-        ).indices
+        )
+        valid_topk = topk_values != min_dtype
         attn_mask = torch.zeros_like(attn_bias, dtype=dtype, device=attn_bias.device)
-        attn_mask = attn_mask.scatter(-1, topk_indices, 1.0)
+        attn_mask = attn_mask.scatter(-1, topk_indices, valid_topk.to(dtype))
Comment on lines 360 to +361
Copilot AI Aug 29, 2025

The scatter operation may produce incorrect results when valid_topk contains False values. Scattering 0.0 values (from valid_topk.to(dtype) where valid_topk is False) will explicitly set those positions to 0.0, but the attn_mask is already initialized with zeros. This could mask positions that should remain unmasked if there are fewer than keep_window_size valid values. Consider only scattering indices where valid_topk is True.

         attn_bias = attn_bias.masked_fill(attn_mask == 0.0, min_dtype)
     else:
         attn_mask = torch.ones_like(attn_bias, dtype=dtype, device=attn_bias.device)
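The scatter pattern this diff introduces can be exercised in isolation. The sketch below is a minimal toy, not the model's code: the tensor values, the single-row shape, and the float32 dtype are illustrative assumptions; only the `torch.topk`/`scatter` sequence mirrors the diff. When fewer than `keep_window_size` entries are unmasked, top-k necessarily picks some `min_dtype` slots; those slots are scattered with `0.0` from `valid_topk.to(dtype)`, which coincides with the zero-initialized mask, so they stay excluded.

```python
import torch

# Toy reproduction of the pattern from the diff (shapes simplified to one row).
min_dtype = torch.finfo(torch.float32).min
attn_bias = torch.tensor([[3.0, min_dtype, 2.0, min_dtype, 1.0]])
keep_window_size = 4  # more than the 3 valid (unmasked) entries

topk_values, topk_indices = torch.topk(
    attn_bias, keep_window_size, dim=-1, largest=True, sorted=False
)
# False wherever top-k was forced to pick a masked (min_dtype) slot.
valid_topk = topk_values != min_dtype

attn_mask = torch.zeros_like(attn_bias)
# Valid top-k slots get 1.0; masked slots that made it into the top-k get 0.0,
# which matches the zero initialization, so they remain masked.
attn_mask = attn_mask.scatter(-1, topk_indices, valid_topk.to(attn_bias.dtype))
print(attn_mask)  # only indices 0, 2, 4 are kept
```

Running this keeps exactly the three unmasked positions regardless of which `min_dtype` slot top-k happened to select, since the unselected one is already zero in the initialized mask.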