
Conversation

@LoserCheems (Collaborator)

Switches from the direct CUDA extension import to a standardized function interface for better maintainability and consistency.

Simplifies the function call signature by removing manual tensor operations and passing parameters cleanly through the new interface.

Adds a proper null check to handle the case where the function is unavailable.


Copilot AI left a comment


Pull Request Overview

Refactors the CUDA implementation in the forward performance benchmark to use a standardized function interface instead of direct CUDA extension import. This improves maintainability and consistency by switching from the low-level flash_dmattn_cuda.fwd call to the higher-level flash_dmattn_func interface.

  • Switches from the direct CUDA extension import to a standardized function interface
  • Simplifies the function call signature by removing manual tensor operations and using cleaner parameter passing
  • Adds a proper null check to handle the case where the function is unavailable (a minimal sketch of the resulting call pattern follows below)
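
A minimal sketch of the refactored call site, assuming flash_dmattn_func is importable from the flash_dmattn package; the import path, keyword names, and error handling here are illustrative rather than the library's documented API.

# The standardized interface may be missing when the CUDA extension is not built,
# so the import is guarded and the result is checked before use, mirroring the
# null check described above.
try:
    from flash_dmattn import flash_dmattn_func
except ImportError:
    flash_dmattn_func = None


def run_forward(query_states, key_states, value_states, attn_mask):
    # Layouts follow the comments in the diff (after transpose(1, 2)):
    #   query_states:             [batch, query_len, num_heads,    head_dim]
    #   key_states, value_states: [batch, key_len,   num_kv_heads, head_dim]
    #   attn_mask:                [batch, num_kv_heads, query_len, key_len]
    if flash_dmattn_func is None:
        raise RuntimeError("flash_dmattn_func is unavailable; the CUDA extension is not installed")
    # The keyword name `attn_mask` is an assumption for illustration only.
    return flash_dmattn_func(query_states, key_states, value_states, attn_mask=attn_mask)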

attn_mask = attn_mask.contiguous() # [batch, num_kv_heads, query_len, key_len]
query_states = query_states.transpose(1, 2) # [batch, query_len, num_heads, head_dim]
key_states = key_states.transpose(1, 2) # [batch, key_len, num_kv_heads, head_dim]
value_states = value_states.transpose(1, 2) # [batch, key_len, num_kv_heads, head_dim]

Copilot AI Jul 30, 2025


The zoh_states calculation and its usage have been removed, but the calculate_zoh_states function is still called on line 241. As a result, the function computes zoh_states that are no longer used, and the new flash_dmattn_func interface may not handle the dynamic masking behavior that zoh_states were intended to provide.
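
A minimal sketch of one way to resolve this, making the zoh_states computation opt-in so the unused work drops out of the benchmark's hot path; the helper below is a stub standing in for the real calculate_zoh_states, and all argument names and shapes are assumptions for illustration.

import torch

def calculate_zoh_states(query_states, key_states):
    # Stub standing in for the benchmark's actual helper; the real signature and
    # contents live in the benchmark script.
    return torch.zeros(
        query_states.shape[0], key_states.shape[1],
        device=query_states.device, dtype=query_states.dtype,
    )

def prepare_attention_inputs(query_states, key_states, value_states, attn_mask, use_zoh=False):
    # Previously zoh_states were always computed for dynamic masking. Since the
    # flash_dmattn_func interface no longer consumes them here, compute them only
    # on request so no unused tensors are produced.
    zoh_states = calculate_zoh_states(query_states, key_states) if use_zoh else None
    return query_states, key_states, value_states, attn_mask, zoh_states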

Comment on lines +253 to +255
query_states = query_states.transpose(1, 2) # [batch, query_len, num_heads, head_dim]
key_states = key_states.transpose(1, 2) # [batch, key_len, num_kv_heads, head_dim]
value_states = value_states.transpose(1, 2) # [batch, key_len, num_kv_heads, head_dim]

Copilot AI Jul 30, 2025


The .contiguous() calls have been removed from the tensor operations. If the new flash_dmattn_func interface expects contiguous tensors, this could cause performance issues or errors. Consider adding .contiguous() calls back if the interface requires them.

Suggested change
- query_states = query_states.transpose(1, 2) # [batch, query_len, num_heads, head_dim]
- key_states = key_states.transpose(1, 2) # [batch, key_len, num_kv_heads, head_dim]
- value_states = value_states.transpose(1, 2) # [batch, key_len, num_kv_heads, head_dim]
+ query_states = query_states.transpose(1, 2).contiguous() # [batch, query_len, num_heads, head_dim]
+ key_states = key_states.transpose(1, 2).contiguous() # [batch, key_len, num_kv_heads, head_dim]
+ value_states = value_states.transpose(1, 2).contiguous() # [batch, key_len, num_kv_heads, head_dim]

@LoserCheems merged commit ef11ca8 into main on Jul 30, 2025