From 37171b5595b38767d8b86e0f514dc4370ccf6de5 Mon Sep 17 00:00:00 2001
From: Loser Cheems
Date: Mon, 30 Jun 2025 21:55:25 +0800
Subject: [PATCH 1/3] Improves code clarity and test coverage

Adds type ignore comment to suppress import warnings for CUDA extension
Renames parameter from attn_mask to zero_hold_states for better semantic clarity
Updates test case to use more realistic sequence length configuration
---
 benchmarks/benchmark_forward_equivalence.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/benchmarks/benchmark_forward_equivalence.py b/benchmarks/benchmark_forward_equivalence.py
index bc10bec..387c845 100644
--- a/benchmarks/benchmark_forward_equivalence.py
+++ b/benchmarks/benchmark_forward_equivalence.py
@@ -20,7 +20,7 @@
 
 # Import the compiled CUDA extension
 try:
-    import flash_dma_cuda
+    import flash_dma_cuda  # type: ignore
     print("✅ Successfully imported flash_dma_cuda")
 except ImportError as e:
     print(f"❌ Failed to import flash_dma_cuda: {e}")
@@ -236,7 +236,7 @@ def dynamic_mask_attention_cuda(
         query_states,      # q: [batch, seqlen_q, num_heads, head_dim]
         key_states,        # k: [batch, seqlen_k, num_kv_heads, head_dim]
         value_states,      # v: [batch, seqlen_k, num_kv_heads, head_dim]
-        attn_mask,         # zoh: [batch, num_kv_heads, seqlen_q, seqlen_k] - processed attention mask
+        zero_hold_states,  # zoh: [batch, num_kv_heads, seqlen_q, seqlen_k] - processed attention mask
         active_mask,       # active_mask: [batch, num_kv_heads, seqlen_q, seqlen_k]
         out_tensor,        # out: None to auto-allocate
         0.0,               # p_dropout
@@ -377,7 +377,7 @@ def test_forward_equivalence(accuracy_threshold=0.95):
         (1, 2, 1, 64, 64, 128, True),
         (1, 2, 1, 128, 128, 128, True),
         (1, 2, 1, 256, 256, 128, True),
-        (1, 2, 1, 2, 256, 128, True),
+        (1, 2, 1, 511, 512, 128, True),
     ]
 
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

From 33c67a508a1d04ef74e10dc4bfe0cd469d76bb58 Mon Sep 17 00:00:00 2001
From: Jingze Shi
Date: Mon, 30 Jun 2025 21:58:38 +0800
Subject: [PATCH 2/3] Update benchmarks/benchmark_forward_equivalence.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
---
 benchmarks/benchmark_forward_equivalence.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/benchmarks/benchmark_forward_equivalence.py b/benchmarks/benchmark_forward_equivalence.py
index 387c845..39e64e0 100644
--- a/benchmarks/benchmark_forward_equivalence.py
+++ b/benchmarks/benchmark_forward_equivalence.py
@@ -20,7 +20,7 @@
 
 # Import the compiled CUDA extension
 try:
-    import flash_dma_cuda  # type: ignore
+    import flash_dma_cuda  # type: ignore[import]
     print("✅ Successfully imported flash_dma_cuda")
 except ImportError as e:
     print(f"❌ Failed to import flash_dma_cuda: {e}")

From 8d8a7fed7fc956bb060017ab18892cfb062bf66a Mon Sep 17 00:00:00 2001
From: Loser Cheems
Date: Mon, 30 Jun 2025 21:59:19 +0800
Subject: [PATCH 3/3] Adds type ignore comment for CUDA import

Suppresses mypy import error for the flash_dma_cuda module to prevent
type checking failures when the CUDA extension is not available or not
properly configured in the development environment.
---
 benchmarks/benchmark_forward_performance.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/benchmarks/benchmark_forward_performance.py b/benchmarks/benchmark_forward_performance.py
index edb9e65..18fe7ef 100644
--- a/benchmarks/benchmark_forward_performance.py
+++ b/benchmarks/benchmark_forward_performance.py
@@ -22,7 +22,7 @@
 
 # Import the compiled CUDA extension
 try:
-    import flash_dma_cuda
+    import flash_dma_cuda  # type: ignore[import]
     print("✅ Successfully imported flash_dma_cuda")
 except ImportError as e:
     print(f"❌ Failed to import flash_dma_cuda: {e}")
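
Note: for reference, below is a minimal sketch of the guarded-import pattern these patches converge on, written independently of the benchmark scripts. The HAS_FLASH_DMA_CUDA flag and the None fallback are illustrative additions, not part of the benchmarks; the "# type: ignore[import]" comment is what keeps mypy quiet when the compiled flash_dma_cuda extension is missing or has no type stubs.

    # Sketch: guarded import of the compiled CUDA extension.
    # "# type: ignore[import]" suppresses mypy's missing-module error for
    # flash_dma_cuda, which ships no stubs and may be absent at dev time.
    try:
        import flash_dma_cuda  # type: ignore[import]
        HAS_FLASH_DMA_CUDA = True
        print("✅ Successfully imported flash_dma_cuda")
    except ImportError as e:
        flash_dma_cuda = None  # illustrative fallback so callers can test availability
        HAS_FLASH_DMA_CUDA = False
        print(f"❌ Failed to import flash_dma_cuda: {e}")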