Refactor CUDA interface for improved usability #80
Changes from all commits
```diff
@@ -21,13 +21,13 @@
 # Import the compiled CUDA extension
 try:
-    import flash_dmattn_cuda  # type: ignore[import]
-    print("✅ Successfully imported flash_dmattn_cuda")
+    from flash_dmattn.flash_dmattn_interface import flash_dmattn_func
+    print("✅ Successfully imported flash_dmattn interface")
 except ImportError as e:
-    print(f"❌ Failed to import flash_dmattn_cuda: {e}")
+    print(f"❌ Failed to import flash_dmattn interface: {e}")
     print("Please make sure the package is properly installed with: pip install .")
     # Don't exit here, just warn
-    flash_dmattn_cuda = None
+    flash_dmattn_func = None

 # Import the Triton implementation
 try:
```
```diff
@@ -229,6 +229,8 @@ def dynamic_mask_attention_cuda(
     Returns:
         attn_outputs: [batch_size, query_len, num_heads, head_dim]
     """
+    if flash_dmattn_func is None:
+        raise RuntimeError("flash_dmattn_func not available")

     # Calculate zoh_states
     zoh_states = calculate_zoh_states(value_states, dt_proj, A)
```
```diff
@@ -243,35 +245,26 @@

-    # Ensure correct data types and memory layout for CUDA function
-    # CUDA function expects: q, k, v in [batch, seqlen, num_heads, head_dim] format
-    query_states = query_states.transpose(1, 2).contiguous()  # [batch, query_len, num_heads, head_dim]
-    key_states = key_states.transpose(1, 2).contiguous()      # [batch, key_len, num_kv_heads, head_dim]
-    value_states = value_states.transpose(1, 2).contiguous()  # [batch, key_len, num_kv_heads, head_dim]
-    zoh_states = zoh_states[:, :, None, :].expand(
-        -1, -1, query_states.shape[1], -1
-    ).contiguous()  # [batch, num_kv_heads, query_len, key_len]
-    attn_bias = attn_bias.contiguous()  # [batch, num_kv_heads, query_len, key_len]
-    attn_mask = attn_mask.contiguous()  # [batch, num_kv_heads, query_len, key_len]
+    query_states = query_states.transpose(1, 2)  # [batch, query_len, num_heads, head_dim]
+    key_states = key_states.transpose(1, 2)      # [batch, key_len, num_kv_heads, head_dim]
+    value_states = value_states.transpose(1, 2)  # [batch, key_len, num_kv_heads, head_dim]

-    # Call the CUDA implementation using the mha_fwd function signature
-    out_tensor = None  # Let the function allocate the output tensor
-    result = flash_dmattn_cuda.fwd(  # type: ignore
-        query_states,      # q: [batch, seqlen_q, num_heads, head_dim]
-        key_states,        # k: [batch, seqlen_k, num_kv_heads, head_dim]
-        value_states,      # v: [batch, seqlen_k, num_kv_heads, head_dim]
-        attn_mask,         # attn_mask: [batch, num_kv_heads, seqlen_q, seqlen_k]
-        attn_bias,         # attn_bias: [batch, num_kv_heads, seqlen_q, seqlen_k]
-        out_tensor,        # out: None to auto-allocate
-        0.0,               # p_dropout
-        scaling,           # softmax_scale
-        is_causal,         # is_causal
-        keep_window_size,  # keep_window_size
-        0.0,               # softcap
-        return_softmax,    # return_softmax
-        None               # gen (generator)
-    )
+    # Call the new flash_dmattn_func interface
+    attn_outputs = flash_dmattn_func(
+        query_states,         # [batch, query_len, num_heads, head_dim]
+        key_states,           # [batch, key_len, num_kv_heads, head_dim]
+        value_states,         # [batch, key_len, num_kv_heads, head_dim]
+        attn_mask=attn_mask,  # [batch, num_kv_heads, query_len, key_len]
+        attn_bias=attn_bias,  # [batch, num_kv_heads, query_len, key_len]
+        dropout_p=0.0,
+        softmax_scale=scaling,
+        is_causal=is_causal,
+        softcap=0.0,
+        deterministic=True,
+        return_attn_probs=return_softmax
+    )

-    attn_outputs = result[0]  # [batch, query_len, num_heads, head_dim]
-    return attn_outputs
+    return attn_outputs  # [batch, query_len, num_heads, head_dim]


 def dynamic_mask_attention_triton(
```
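As an aside for readers of this diff, here is a minimal usage sketch of the refactored interface. The keyword names and tensor layouts are taken from the call above; the concrete sizes, device, and dtypes are illustrative assumptions, not values from this PR.

```python
import torch
from flash_dmattn.flash_dmattn_interface import flash_dmattn_func

# Illustrative sizes; not taken from the PR.
batch, query_len, key_len = 2, 128, 128
num_heads, num_kv_heads, head_dim = 8, 8, 64
device, dtype = "cuda", torch.bfloat16  # device/dtype are assumptions

q = torch.randn(batch, query_len, num_heads, head_dim, device=device, dtype=dtype)
k = torch.randn(batch, key_len, num_kv_heads, head_dim, device=device, dtype=dtype)
v = torch.randn(batch, key_len, num_kv_heads, head_dim, device=device, dtype=dtype)
# Mask/bias layouts follow the shape comments in the diff; their dtypes are assumptions.
attn_mask = torch.ones(batch, num_kv_heads, query_len, key_len, device=device, dtype=dtype)
attn_bias = torch.zeros(batch, num_kv_heads, query_len, key_len, device=device, dtype=dtype)

out = flash_dmattn_func(
    q, k, v,
    attn_mask=attn_mask,
    attn_bias=attn_bias,
    dropout_p=0.0,
    softmax_scale=head_dim ** -0.5,
    is_causal=True,
    softcap=0.0,
    deterministic=True,
    return_attn_probs=False,
)
print(out.shape)  # expected: (batch, query_len, num_heads, head_dim)
```

Note that `flash_dmattn_func` allocates and returns the output itself, which removes the `out_tensor` plumbing the old `flash_dmattn_cuda.fwd` call needed.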
```diff
@@ -514,6 +507,11 @@ def test_cuda_forward_equivalence(accuracy_threshold=0.95):
     print("🔬 Testing Forward Pass Equivalence: Python Prototype vs CUDA Implementation 🔬")
     print("🚀" + "=" * 76 + "🚀")

+    # Check if CUDA implementation is available
+    if flash_dmattn_func is None:
+        print("❌ CUDA implementation not available, skipping test.")
+        return False
+
     # Set random seed for reproducibility
     torch.manual_seed(0)
```
Review comment: Removing the .contiguous() calls may cause performance issues if the tensors are not contiguous in memory; CUDA kernels typically require contiguous tensors for optimal performance. Consider adding .contiguous() back, or verifying that the new interface handles non-contiguous tensors efficiently.
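If layout does matter here, a cheap guard is to restore contiguity only when needed: `.contiguous()` returns the tensor unchanged when it is already contiguous, and `transpose(1, 2)` typically produces a non-contiguous view, so q/k/v will generally need a copy unless the new interface handles strided inputs internally. A minimal sketch (the helper name is hypothetical):

```python
import torch

def ensure_contiguous(*tensors: torch.Tensor) -> tuple[torch.Tensor, ...]:
    # .contiguous() returns the tensor itself when the layout is already
    # contiguous, so this helper only copies when a copy is actually needed.
    return tuple(t.contiguous() for t in tensors)

# Hypothetical use at the call site shown in the diff:
# query_states, key_states, value_states = ensure_contiguous(
#     query_states, key_states, value_states
# )
```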