From e126d72cd761136556dcf79ee2a7d21776059629 Mon Sep 17 00:00:00 2001
From: Loser Cheems
Date: Wed, 30 Jul 2025 21:30:38 +0800
Subject: [PATCH] Refactors dynamic mask function to improve clarity

Renames dt_states parameter to zoh_states for better semantic meaning.
Updates variable names from attn_mask to attn_bias to more accurately
reflect the additive bias nature of the operation.

Improves code organization by moving attn_mask creation logic and adding
explicit handling for sequences shorter than keep_window_size.

Enhances documentation with clearer parameter descriptions and adds
return type specification.
---
 benchmarks/benchmark_grad.py | 61 +++++++++++++++++++++---------------
 1 file changed, 35 insertions(+), 26 deletions(-)

diff --git a/benchmarks/benchmark_grad.py b/benchmarks/benchmark_grad.py
index 795282d..0ae8f40 100644
--- a/benchmarks/benchmark_grad.py
+++ b/benchmarks/benchmark_grad.py
@@ -5,40 +5,51 @@ def prepare_dynamic_mask(
     hidden_states: torch.Tensor,
-    dt_states: torch.Tensor,
+    zoh_states: torch.Tensor,
     keep_window_size: int = 2048,
     attention_mask: torch.Tensor | None = None,
 ):
     """
-    The core idea of DMA is to calculate the dynamic attention mask to mask the tokens that should be masked, so as to form sparse attention.
+    Calculate dynamic attention mask to mask tokens for sparse attention.
 
-    Combine `dt_states` with `attention_mask` to generate the final `attn_mask`.
+    Combine `zoh_states` with `attention_mask` to generate the final `attn_mask`.
 
     Args:
-        hidden_states (`torch.Tensor`): The input hidden_states, used to determine the minimum value of the current input precision.
-        dt_states (`torch.Tensor`): dt_states of shape `(batch_size, num_heads, key_sequence_length)`.
-        keep_window_size (`int`): The window size of tokens that are not dynamically masked, and dynamic masking is only performed when the sequence length exceeds this value.
-        attention_mask (`torch.Tensor`, *optional*): attention mask of shape `(batch_size, 1, query_sequence_length, key_sequence_length)`.
+        hidden_states: Input hidden states to determine dtype minimum value
+        zoh_states: zoh_states of shape (batch_size, num_kv_heads, key_sequence_length)
+        keep_window_size: Window size of tokens not dynamically masked
+        attention_mask: Optional attention mask of shape (batch_size, 1, query_len, key_len)
+
+    Returns:
+        tuple: (attn_bias, attn_mask)
     """
     min_dtype = torch.finfo(hidden_states.dtype).min
     dtype = hidden_states.dtype
-    attn_mask = dt_states[:, :, None, :].expand(
+    attn_bias = zoh_states[:, :, None, :].expand(
         -1, -1, hidden_states.shape[2], -1
-    )  # [batch_size, num_heads, query_len, key_len]
-    active_mask = torch.zeros_like(attn_mask, dtype=dtype, device=attn_mask.device)
+    )  # [batch_size, num_kv_heads, query_len, key_len]
+
     if attention_mask is not None:
         if attention_mask.dtype == torch.bool:
             attention_mask = torch.where(
-                attention_mask, torch.tensor(0.0, device=attention_mask.device, dtype=dtype), min_dtype
+                attention_mask,
+                torch.tensor(0.0, device=attention_mask.device, dtype=dtype),
+                min_dtype
             )
-        attn_mask = attn_mask.masked_fill(attention_mask[:, :, :, : attn_mask.shape[-1]] != 0, min_dtype)
-    if attn_mask.shape[-1] > keep_window_size:
+        attn_bias = attn_bias.masked_fill(
+            attention_mask[:, :, :, : attn_bias.shape[-1]] != 0, min_dtype
+        )
+
+    if attn_bias.shape[-1] > keep_window_size:
         topk_indices = torch.topk(
-            attn_mask, keep_window_size, dim=-1, largest=True, sorted=False
+            attn_bias, keep_window_size, dim=-1, largest=True, sorted=False
         ).indices
-        active_mask = active_mask.scatter(-1, topk_indices, 1.0)
-        attn_mask = attn_mask.masked_fill(active_mask == 0.0, min_dtype)
-    return attn_mask, active_mask
+        attn_mask = torch.zeros_like(attn_bias, dtype=dtype, device=attn_bias.device)
+        attn_mask = attn_mask.scatter(-1, topk_indices, 1.0)
+        attn_bias = attn_bias.masked_fill(attn_mask == 0.0, min_dtype)
+    else:
+        attn_mask = torch.ones_like(attn_bias, dtype=dtype, device=attn_bias.device)
+    return attn_bias, attn_mask
 
 
 def dynamic_mask_attention_cuda(
@@ -60,7 +71,7 @@ def dynamic_mask_attention_cuda(
     dt_states = torch.matmul(value_states.transpose(-2, -3).reshape(batch_size, key_len, -1), dt_proj.T)
     dt_states = torch.exp(A * F.softplus(dt_states)).transpose(-1, -2)
 
-    attn_mask, _ = prepare_dynamic_mask(
+    attn_bias, attn_mask = prepare_dynamic_mask(
         query_states, dt_states, keep_window_size=keep_window_size, attention_mask=causal_mask
     )  # [batch_size, num_kv_heads, query_len, key_len]
 
@@ -76,16 +87,15 @@ def dynamic_mask_attention_cuda(
             if len(non_mask_indices) == 0:
                 continue
 
-            k_vecs = key_states[b_idx, h_idx, non_mask_indices, :]  # [keep_window_size, head_dim]
-            v_vecs = value_states[b_idx, h_idx, non_mask_indices, :]  # [keep_window_size, head_dim]
-
-            q_vec = query_states[b_idx, h_idx, q_idx, :]  # [head_dim]
+            q_vec = query_states[b_idx, h_idx, q_idx, :]  # [head_dim]
+            k_vecs = key_states[b_idx, h_idx, non_mask_indices, :]  # [keep_window_size, head_dim]
+            v_vecs = value_states[b_idx, h_idx, non_mask_indices, :]  # [keep_window_size, head_dim]
 
             # QK dot product
             attn_weight = torch.sum(q_vec.unsqueeze(0) * k_vecs, dim=-1)
 
             # Apply scaling and dynamic_mask
-            attn_weight = attn_weight * scaling + attn_mask[b_idx, h_idx, q_idx, non_mask_indices]
+            attn_weight = attn_weight * scaling + attn_bias[b_idx, h_idx, q_idx, non_mask_indices]
 
             # Softmax
             attn_weight = F.softmax(attn_weight, dim=-1)
@@ -118,12 +128,11 @@ def dynamic_mask_attention_python(
     dt_states = torch.matmul(value_states.transpose(-2, -3).reshape(batch_size, key_len, -1), dt_proj.T)
     dt_states = torch.exp(A * F.softplus(dt_states)).transpose(-1, -2)
 
-    attn_mask, _ = prepare_dynamic_mask(
+    attn_bias, _ = prepare_dynamic_mask(
         query_states, dt_states, keep_window_size=keep_window_size, attention_mask=causal_mask
     )  # [batch_size, num_kv_heads, query_len, key_len]
 
-    attn_weights = torch.matmul(query_states, key_states.transpose(-2, -1))  # [batch_size, num_heads, query_len, key_len]
-    attn_weights = attn_weights * scaling + attn_mask  # Apply scaling and dynamic_mask
+    attn_weights = attn_weights * scaling + attn_bias  # Apply scaling and dynamic_mask
     attn_weights = F.softmax(attn_weights, dim=-1)  # Softmax normalization
     attn_outputs = torch.matmul(attn_weights, value_states)  # [batch_size, num_heads, query_len, head_dim]
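
Note: the snippet below is a minimal, self-contained sketch of how the refactored helper can be exercised outside the benchmark. The prepare_dynamic_mask body reproduces the "+" side of this patch; the tensor shapes, the randomly sampled zoh_states, and the manual dense-attention check are illustrative assumptions only (in benchmark_grad.py, zoh_states is derived from value_states, dt_proj and A, and a causal mask is passed in).

import torch
import torch.nn.functional as F


def prepare_dynamic_mask(hidden_states, zoh_states, keep_window_size=2048, attention_mask=None):
    # Mirrors the "+" side of the patch: expand zoh_states into a per-query additive
    # bias, fold in an optional attention mask, then keep only the top-k keys per query.
    min_dtype = torch.finfo(hidden_states.dtype).min
    dtype = hidden_states.dtype
    attn_bias = zoh_states[:, :, None, :].expand(-1, -1, hidden_states.shape[2], -1)
    if attention_mask is not None:
        if attention_mask.dtype == torch.bool:
            attention_mask = torch.where(
                attention_mask,
                torch.tensor(0.0, device=attention_mask.device, dtype=dtype),
                min_dtype,
            )
        attn_bias = attn_bias.masked_fill(
            attention_mask[:, :, :, : attn_bias.shape[-1]] != 0, min_dtype
        )
    if attn_bias.shape[-1] > keep_window_size:
        topk_indices = torch.topk(
            attn_bias, keep_window_size, dim=-1, largest=True, sorted=False
        ).indices
        attn_mask = torch.zeros_like(attn_bias, dtype=dtype, device=attn_bias.device)
        attn_mask = attn_mask.scatter(-1, topk_indices, 1.0)
        attn_bias = attn_bias.masked_fill(attn_mask == 0.0, min_dtype)
    else:
        attn_mask = torch.ones_like(attn_bias, dtype=dtype, device=attn_bias.device)
    return attn_bias, attn_mask


if __name__ == "__main__":
    # Illustrative shapes only; the benchmark derives zoh_states from value_states,
    # dt_proj and A instead of sampling it at random.
    batch, num_heads, q_len, k_len, head_dim = 1, 2, 8, 8, 16
    keep_window_size = 4

    query = torch.randn(batch, num_heads, q_len, head_dim)
    key = torch.randn(batch, num_heads, k_len, head_dim)
    value = torch.randn(batch, num_heads, k_len, head_dim)
    zoh_states = torch.rand(batch, num_heads, k_len)

    attn_bias, attn_mask = prepare_dynamic_mask(query, zoh_states, keep_window_size)

    # Reference dense attention: the bias is added to the scaled QK^T scores before
    # softmax, so keys outside the top-k window receive min_dtype and vanish.
    scaling = head_dim ** -0.5
    attn_weights = torch.matmul(query, key.transpose(-2, -1)) * scaling + attn_bias
    attn_weights = F.softmax(attn_weights, dim=-1)
    attn_output = torch.matmul(attn_weights, value)  # [batch, num_heads, q_len, head_dim]

    # Each query row keeps exactly keep_window_size active keys.
    print(attn_mask.sum(dim=-1))

When a boolean causal mask is supplied, the same helper folds the causal constraint into attn_bias before the top-k selection, which is how the benchmark's dynamic_mask_attention_* variants invoke it with attention_mask=causal_mask.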