diff --git a/benchmarks/benchmark_forward_equivalence.py b/benchmarks/benchmark_forward_equivalence.py
index bc10bec..39e64e0 100644
--- a/benchmarks/benchmark_forward_equivalence.py
+++ b/benchmarks/benchmark_forward_equivalence.py
@@ -20,7 +20,7 @@
 
 # Import the compiled CUDA extension
 try:
-    import flash_dma_cuda
+    import flash_dma_cuda  # type: ignore[import]
     print("✅ Successfully imported flash_dma_cuda")
 except ImportError as e:
     print(f"❌ Failed to import flash_dma_cuda: {e}")
@@ -236,7 +236,7 @@ def dynamic_mask_attention_cuda(
         query_states,       # q: [batch, seqlen_q, num_heads, head_dim]
         key_states,         # k: [batch, seqlen_k, num_kv_heads, head_dim]
         value_states,       # v: [batch, seqlen_k, num_kv_heads, head_dim]
-        attn_mask,          # zoh: [batch, num_kv_heads, seqlen_q, seqlen_k] - processed attention mask
+        zero_hold_states,   # zoh: [batch, num_kv_heads, seqlen_q, seqlen_k] - processed attention mask
         active_mask,        # active_mask: [batch, num_kv_heads, seqlen_q, seqlen_k]
         out_tensor,         # out: None to auto-allocate
         0.0,                # p_dropout
@@ -377,7 +377,7 @@ def test_forward_equivalence(accuracy_threshold=0.95):
         (1, 2, 1, 64, 64, 128, True),
         (1, 2, 1, 128, 128, 128, True),
         (1, 2, 1, 256, 256, 128, True),
-        (1, 2, 1, 2, 256, 128, True),
+        (1, 2, 1, 511, 512, 128, True),
     ]
 
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
diff --git a/benchmarks/benchmark_forward_performance.py b/benchmarks/benchmark_forward_performance.py
index edb9e65..18fe7ef 100644
--- a/benchmarks/benchmark_forward_performance.py
+++ b/benchmarks/benchmark_forward_performance.py
@@ -22,7 +22,7 @@
 
 # Import the compiled CUDA extension
 try:
-    import flash_dma_cuda
+    import flash_dma_cuda  # type: ignore[import]
     print("✅ Successfully imported flash_dma_cuda")
 except ImportError as e:
     print(f"❌ Failed to import flash_dma_cuda: {e}")