Refactors CUDA implementation to use new interface #81
Conversation
Switches from a direct CUDA extension import to the standardized function interface for better maintainability and consistency. Simplifies the function call signature by removing manual tensor operations and using cleaner parameter passing through the new interface. Adds a proper null check to handle cases where the function is unavailable.
Pull Request Overview
Refactors the CUDA implementation in the forward performance benchmark to use a standardized function interface instead of a direct CUDA extension import. This improves maintainability and consistency by switching from the low-level flash_dmattn_cuda.fwd call to the higher-level flash_dmattn_func interface.
- Switches from a direct CUDA extension import to the standardized function interface
- Simplifies the function call signature by removing manual tensor operations and using cleaner parameter passing
- Adds a proper null check to handle cases where the function is unavailable
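As a rough sketch of the resulting call path (the import location, keyword names, and function name below are assumptions for illustration, not taken from this diff):

try:
    from flash_dmattn import flash_dmattn_func  # assumed import location
except ImportError:
    flash_dmattn_func = None

def fwd_benchmark_step(query_states, key_states, value_states, attn_mask=None):
    # Null check mirroring the PR description: bail out when the interface is unavailable.
    if flash_dmattn_func is None:
        raise RuntimeError("flash_dmattn_func is unavailable; check the CUDA build.")
    # The higher-level interface takes [batch, seq_len, num_heads, head_dim] tensors,
    # so only a transpose is needed instead of the old manual tensor plumbing.
    query_states = query_states.transpose(1, 2)   # [batch, query_len, num_heads, head_dim]
    key_states = key_states.transpose(1, 2)       # [batch, key_len, num_kv_heads, head_dim]
    value_states = value_states.transpose(1, 2)   # [batch, key_len, num_kv_heads, head_dim]
    return flash_dmattn_func(query_states, key_states, value_states, attn_mask=attn_mask)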
attn_mask = attn_mask.contiguous()  # [batch, num_kv_heads, query_len, key_len]
query_states = query_states.transpose(1, 2)  # [batch, query_len, num_heads, head_dim]
key_states = key_states.transpose(1, 2)  # [batch, key_len, num_kv_heads, head_dim]
value_states = value_states.transpose(1, 2)  # [batch, key_len, num_kv_heads, head_dim]
Copilot AI commented on Jul 30, 2025
The zoh_states calculation and usage have been removed, but the calculate_zoh_states function is still being called on line 241. This means the function computes zoh_states that are no longer used, and the new flash_dmattn_func interface may not handle the dynamic-masking behavior that zoh_states were intended to provide.
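One way to address this, sketched under the assumption that calculate_zoh_states only feeds the dynamic-masking path (the flag and the zoh_states_to_mask helper below are hypothetical, and the calculate_zoh_states signature is assumed):

def prepare_attn_mask(value_states, dt_proj, A, use_zoh_masking=False):
    # Only pay for calculate_zoh_states when its output is actually consumed downstream.
    if not use_zoh_masking:
        return None
    zoh_states = calculate_zoh_states(value_states, dt_proj, A)  # assumed signature
    return zoh_states_to_mask(zoh_states)  # hypothetical helper for the dynamic mask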
query_states = query_states.transpose(1, 2)  # [batch, query_len, num_heads, head_dim]
key_states = key_states.transpose(1, 2)  # [batch, key_len, num_kv_heads, head_dim]
value_states = value_states.transpose(1, 2)  # [batch, key_len, num_kv_heads, head_dim]
Copilot AI commented on Jul 30, 2025
The .contiguous() calls have been removed from the tensor operations. If the new flash_dmattn_func interface expects contiguous tensors, this could cause performance issues or errors. Consider adding .contiguous() calls back if the interface requires them.
Suggested change:
- query_states = query_states.transpose(1, 2)  # [batch, query_len, num_heads, head_dim]
- key_states = key_states.transpose(1, 2)  # [batch, key_len, num_kv_heads, head_dim]
- value_states = value_states.transpose(1, 2)  # [batch, key_len, num_kv_heads, head_dim]
+ query_states = query_states.transpose(1, 2).contiguous()  # [batch, query_len, num_heads, head_dim]
+ key_states = key_states.transpose(1, 2).contiguous()  # [batch, key_len, num_kv_heads, head_dim]
+ value_states = value_states.transpose(1, 2).contiguous()  # [batch, key_len, num_kv_heads, head_dim]