Conversation

@LoserCheems (Collaborator)

Renames the dt_states parameter to zoh_states for better semantic meaning.

Updates variable names from attn_mask to attn_bias to more accurately reflect the additive-bias nature of the operation.

Improves code organization by moving the attn_mask creation logic and adding explicit handling for sequences shorter than keep_window_size.

Enhances documentation with clearer parameter descriptions and adds a return type specification.
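
For illustration, here is a minimal sketch of what the refactored helper might look like after these renames. The function name `prepare_dynamic_mask` and all internals beyond the renamed identifiers (`zoh_states`, `attn_bias`, `keep_window_size`) are assumptions based on this description, not code taken from the diff:

```python
from typing import Optional

import torch


def prepare_dynamic_mask(
    hidden_states: torch.Tensor,
    zoh_states: torch.Tensor,
    keep_window_size: int,
    attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Turn zoh_states into an additive attention bias (attn_bias)."""
    min_value = torch.finfo(hidden_states.dtype).min
    # Assumes hidden_states is (batch, query_len, hidden_dim); only its dtype
    # and query length are used here.
    query_len = hidden_states.shape[1]
    # Broadcast per-key scores over the query dimension:
    # (batch, num_kv_heads, 1, key_len) -> (batch, num_kv_heads, query_len, key_len)
    attn_bias = zoh_states[:, :, None, :].expand(-1, -1, query_len, -1).clone()
    if attention_mask is not None:
        attn_bias = attn_bias.masked_fill(attention_mask == 0, min_value)
    key_len = attn_bias.shape[-1]
    if key_len > keep_window_size:
        # Keep only the keep_window_size highest-scoring key positions per query;
        # everything else gets the dtype minimum and vanishes under softmax.
        values, indices = torch.topk(attn_bias, keep_window_size, dim=-1)
        attn_bias = torch.full_like(attn_bias, min_value).scatter(-1, indices, values)
    # Sequences no longer than keep_window_size are explicitly left unmasked.
    return attn_bias
```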

Copilot AI (Contributor) left a comment

Pull Request Overview

This PR refactors the dynamic mask function in benchmark_grad.py to improve code clarity and semantics. The changes focus on renaming variables for clearer meaning and on tightening code organization.

  • Renames dt_states parameter to zoh_states for clearer semantic meaning
  • Updates variable names from attn_mask to attn_bias to better reflect additive bias operations
  • Improves code organization by moving mask creation logic and adding explicit handling for short sequences
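
For context, a small sketch of how an additive bias (as opposed to a boolean mask) enters scaled dot-product attention. The code is illustrative only; the tensor shapes and names are assumptions, not taken from benchmark_grad.py:

```python
import math

import torch

batch, heads, q_len, k_len, head_dim = 1, 2, 4, 8, 16
query_states = torch.randn(batch, heads, q_len, head_dim)
key_states = torch.randn(batch, heads, k_len, head_dim)
attn_bias = torch.zeros(batch, heads, q_len, k_len)  # output of the prepare step

# attn_bias is *added* to the raw scores before softmax, so positions set to
# the dtype minimum collapse to ~0 probability; hence attn_bias, not attn_mask.
scores = query_states @ key_states.transpose(-2, -1) / math.sqrt(head_dim)
attn_weights = torch.softmax(scores + attn_bias, dim=-1)
```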

Comment on lines +90 to +92
q_vec = query_states[b_idx, h_idx, q_idx, :] # [head_dim]
k_vecs = key_states[b_idx, h_idx, non_mask_indices, :] # [keep_window_size, head_dim]
v_vecs = value_states[b_idx, h_idx, non_mask_indices, :] # [keep_window_size, head_dim]

Copilot AI Jul 30, 2025

[nitpick] Moving the q_vec assignment before k_vecs and v_vecs creates an inconsistent ordering. The original placement after k_vecs and v_vecs was more logical as it groups the vector extractions together before their usage.

Suggested change
- q_vec = query_states[b_idx, h_idx, q_idx, :] # [head_dim]
- k_vecs = key_states[b_idx, h_idx, non_mask_indices, :] # [keep_window_size, head_dim]
- v_vecs = value_states[b_idx, h_idx, non_mask_indices, :] # [keep_window_size, head_dim]
+ k_vecs = key_states[b_idx, h_idx, non_mask_indices, :] # [keep_window_size, head_dim]
+ v_vecs = value_states[b_idx, h_idx, non_mask_indices, :] # [keep_window_size, head_dim]
+ q_vec = query_states[b_idx, h_idx, q_idx, :] # [head_dim]

Comment on lines:
keep_window_size (`int`): The window size of tokens that are not dynamically masked, and dynamic masking is only performed when the sequence length exceeds this value.
attention_mask (`torch.Tensor`, *optional*): attention mask of shape `(batch_size, 1, query_sequence_length, key_sequence_length)`.
hidden_states: Input hidden states to determine dtype minimum value
zoh_states: zoh_states of shape (batch_size, num_kv_heads, key_sequence_length)

Copilot AI Jul 30, 2025

The parameter description lacks explanation of what 'zoh_states' represents functionally. Consider adding a brief description of its purpose in the dynamic masking process.

Suggested change
- zoh_states: zoh_states of shape (batch_size, num_kv_heads, key_sequence_length)
+ zoh_states: Binary tensor indicating zones of high attention, of shape
+     (batch_size, num_kv_heads, key_sequence_length). Used to guide the
+     dynamic masking process by identifying key positions to retain.
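
As a hedged illustration of how such a tensor could drive the gather shown in the earlier snippet, here is a small sketch; the top-k selection and all shapes are assumptions inferred from the quoted code, not confirmed by the PR:

```python
import torch

batch_size, num_kv_heads, key_len, head_dim = 1, 2, 16, 8
keep_window_size = 4

zoh_states = torch.randn(batch_size, num_kv_heads, key_len)
key_states = torch.randn(batch_size, num_kv_heads, key_len, head_dim)

b_idx, h_idx = 0, 0
# Pick the keep_window_size highest-scoring key positions for one (batch, head)...
non_mask_indices = torch.topk(zoh_states[b_idx, h_idx], keep_window_size).indices
# ...and gather the corresponding key vectors, as in the reviewed loop.
k_vecs = key_states[b_idx, h_idx, non_mask_indices, :]  # [keep_window_size, head_dim]
```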

@LoserCheems merged commit 7849829 into main on Jul 30, 2025.