From 330f8f5ce76552cceeba294bbdb8974808059933 Mon Sep 17 00:00:00 2001
From: Loser Cheems
Date: Fri, 29 Aug 2025 16:23:03 +0800
Subject: [PATCH] Fixes attention mask for invalid topk values

Ensures that the attention mask correctly handles cases where topk values
equal the minimum dtype value by checking validity before scattering.
Previously, a constant 1.0 was scattered for all topk indices regardless
of their validity, which could incorrectly unmask positions that were
already fully masked out (their scores equal the dtype minimum), letting
masked positions leak back into the attention computation.
---
 examples/modeling/modeling_doge.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/examples/modeling/modeling_doge.py b/examples/modeling/modeling_doge.py
index 8810539..e67cca2 100644
--- a/examples/modeling/modeling_doge.py
+++ b/examples/modeling/modeling_doge.py
@@ -353,11 +353,12 @@ def prepare_dynamic_mask(
         )
         attn_bias = attn_bias.masked_fill(attention_mask[:, :, :, : attn_bias.shape[-1]] != 0, min_dtype)
     if attn_bias.shape[-1] > keep_window_size:
-        topk_indices = torch.topk(
+        topk_values, topk_indices = torch.topk(
             attn_bias, keep_window_size, dim=-1, largest=True, sorted=False
-        ).indices
+        )
+        valid_topk = topk_values != min_dtype
         attn_mask = torch.zeros_like(attn_bias, dtype=dtype, device=attn_bias.device)
-        attn_mask = attn_mask.scatter(-1, topk_indices, 1.0)
+        attn_mask = attn_mask.scatter(-1, topk_indices, valid_topk.to(dtype))
         attn_bias = attn_bias.masked_fill(attn_mask == 0.0, min_dtype)
     else:
         attn_mask = torch.ones_like(attn_bias, dtype=dtype, device=attn_bias.device)