diff --git a/benchmarks/benchmark_forward_equivalence.py b/benchmarks/benchmark_forward_equivalence.py index cdca030..e724b12 100644 --- a/benchmarks/benchmark_forward_equivalence.py +++ b/benchmarks/benchmark_forward_equivalence.py @@ -26,27 +26,28 @@ except ImportError as e: print(f"❌ Failed to import flash_dmattn_cuda: {e}") print("Please make sure the package is properly installed with: pip install .") - exit(1) + # Don't exit here, just warn + flash_dmattn_cuda = None # Import the Triton implementation try: - from flash_dmattn.flash_dmattn_triton import flash_dmattn_func + from flash_dmattn.flash_dmattn_triton import triton_dmattn_func print("✅ Successfully imported flash_dmattn_triton") except ImportError as e: print(f"❌ Failed to import flash_dmattn_triton: {e}") print("Please make sure the Triton implementation is available.") # Don't exit here, just warn - flash_dmattn_func = None + triton_dmattn_func = None # Import the Flex Attention implementation try: - from flash_dmattn.flash_dmattn_flex import flex_attention_forward + from flash_dmattn.flash_dmattn_flex import flex_dmattn_func print("✅ Successfully imported flash_dmattn_flex") except ImportError as e: print(f"❌ Failed to import flash_dmattn_flex: {e}") print("Please make sure the Flex Attention implementation is available.") # Don't exit here, just warn - flex_attention_forward = None + flex_dmattn_func = None def prepare_dynamic_mask( @@ -301,7 +302,7 @@ def dynamic_mask_attention_triton( Returns: attn_outputs: [batch_size, query_len, num_heads, head_dim] """ - if flash_dmattn_func is None: + if triton_dmattn_func is None: raise RuntimeError("Triton implementation not available") _, num_heads, _, _ = query_states.shape @@ -333,14 +334,14 @@ def dynamic_mask_attention_triton( attn_bias = attn_bias.contiguous() # [batch, num_heads, seqlen_q, seqlen_k] # Call the Triton implementation - attn_outputs = flash_dmattn_func( + attn_outputs = triton_dmattn_func( query_states, # q: [batch, seqlen_q, num_heads, head_dim] key_states, # k: [batch, seqlen_k, num_heads, head_dim] value_states, # v: [batch, seqlen_k, num_heads, head_dim] - mask=attn_mask, # mask: [batch, num_heads, seqlen_q, seqlen_k] - bias=attn_bias, # bias: [batch, num_heads, seqlen_q, seqlen_k] - causal=is_causal, # causal masking - softmax_scale=scaling # scaling factor + attn_mask, # mask: [batch, num_heads, seqlen_q, seqlen_k] + attn_bias, # bias: [batch, num_heads, seqlen_q, seqlen_k] + is_causal, # causal masking + scaling # scaling factor ) return attn_outputs # [batch, query_len, num_heads, head_dim] @@ -374,7 +375,7 @@ def dynamic_mask_attention_flex( Returns: attn_outputs: [batch_size, query_len, num_heads, head_dim] """ - if flex_attention_forward is None: + if flex_dmattn_func is None: raise RuntimeError("Flex Attention implementation not available") _, num_heads, _, _ = query_states.shape @@ -402,12 +403,13 @@ def dynamic_mask_attention_flex( # But attention_mask and attention_bias in [batch, num_heads, query_len, key_len] format # Call the Flex Attention implementation - attn_outputs, _ = flex_attention_forward( + attn_outputs, _ = flex_dmattn_func( query_states, # q: [batch, num_heads, query_len, head_dim] key_states, # k: [batch, num_heads, key_len, head_dim] value_states, # v: [batch, num_heads, key_len, head_dim] attention_mask=attn_mask, # attention_mask: [batch, num_heads, query_len, key_len] attention_bias=attn_bias, # attention_bias: [batch, num_heads, query_len, key_len] + is_causal=is_causal, # is_causal: whether to apply causal masking 
scaling=scaling # scaling factor ) @@ -662,14 +664,14 @@ def test_triton_forward_equivalence(accuracy_threshold=0.95): print("🔬 Testing Forward Pass Equivalence: Python vs Triton 🔬") print("🔥" + "=" * 76 + "🔥") - if flash_dmattn_func is None: + if triton_dmattn_func is None: print("❌ Triton implementation not available, skipping Triton tests") return False # Set random seed for reproducibility torch.manual_seed(0) - - # Smaller test configurations for Triton (to avoid memory issues) + + # If you encounter NAN issues when running multiple configurations, try running a single configuration test_configs = [ # (batch_size, num_heads, num_kv_heads, query_len, key_len, head_dim, is_causal) (1, 1, 1, 64, 64, 32, True), @@ -833,7 +835,7 @@ def test_flex_forward_equivalence(accuracy_threshold=0.95): print("🔬 Testing Forward Pass Equivalence: Python vs Flex Attention 🔬") print("🌟" + "=" * 76 + "🌟") - if flex_attention_forward is None: + if flex_dmattn_func is None: print("❌ Flex Attention implementation not available, skipping Flex Attention tests") return False diff --git a/benchmarks/benchmark_forward_performance.py b/benchmarks/benchmark_forward_performance.py index 18fe7ef..d9ed6f5 100644 --- a/benchmarks/benchmark_forward_performance.py +++ b/benchmarks/benchmark_forward_performance.py @@ -2,18 +2,26 @@ """ Performance Benchmark for Dynamic Mask Attention -This script measures and compares the performance of Dynamic Mask Attention -implementation against Flash Attention baseline across various configurations. +This script measures and compares the performance of multiple Dynamic Mask Attention +implementations against Flash Attention baseline across various configurations. + +Implementations tested: +- Flash Attention (PyTorch SDPA Flash Attention backend) - Baseline +- Dynamic Mask Attention CUDA - Custom CUDA kernel implementation +- Dynamic Mask Attention CUDA (No TopK) - CUDA kernel without TopK computation +- Dynamic Mask Attention Triton - Triton kernel implementation +- Dynamic Mask Attention Flex - Flex Attention implementation Benchmark includes: - Multiple sequence lengths and batch sizes - Head count and dimension variations - Throughput and latency measurements - Memory usage analysis -- Speedup comparisons +- Speedup comparisons across all implementations """ import torch +import torch.nn.backends import torch.nn.functional as F import numpy as np import argparse @@ -22,37 +30,73 @@ # Import the compiled CUDA extension try: - import flash_dma_cuda # type: ignore[import] - print("✅ Successfully imported flash_dma_cuda") + import flash_dmattn_cuda # type: ignore[import] + print("✅ Successfully imported flash_dmattn_cuda") except ImportError as e: - print(f"❌ Failed to import flash_dma_cuda: {e}") + print(f"❌ Failed to import flash_dmattn_cuda: {e}") print("Please make sure the package is properly installed with: pip install .") - exit(1) + # Don't exit here, just warn + flash_dmattn_cuda = None + +# Import the Triton implementation +try: + from flash_dmattn.flash_dmattn_triton import triton_dmattn_func + print("✅ Successfully imported flash_dmattn_triton") +except ImportError as e: + print(f"❌ Failed to import flash_dmattn_triton: {e}") + print("Please make sure the Triton implementation is available.") + # Don't exit here, just warn + triton_dmattn_func = None + +# Import the Flex Attention implementation +try: + from flash_dmattn.flash_dmattn_flex import flex_dmattn_func + print("✅ Successfully imported flash_dmattn_flex") +except ImportError as e: + print(f"❌ Failed to import 
flash_dmattn_flex: {e}") + print("Please make sure the Flex Attention implementation is available.") + # Don't exit here, just warn + flex_dmattn_func = None + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + Equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). + Transform from (batch, num_key_value_heads, seqlen, head_dim) + to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand( + batch, num_key_value_heads, n_rep, slen, head_dim + ) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) def prepare_dynamic_mask( hidden_states: torch.Tensor, - dt_states: torch.Tensor, + zoh_states: torch.Tensor, keep_window_size: int = 2048, attention_mask: torch.Tensor | None = None, ): """ Calculate dynamic attention mask to mask tokens for sparse attention. - Combine `dt_states` with `attention_mask` to generate the final `attn_mask`. + Combine `zoh_states` with `attention_mask` to generate the final `attn_mask`. Args: hidden_states: Input hidden states to determine dtype minimum value - dt_states: dt_states of shape (batch_size, num_kv_heads, key_sequence_length) + zoh_states: zoh_states of shape (batch_size, num_kv_heads, key_sequence_length) keep_window_size: Window size of tokens not dynamically masked attention_mask: Optional attention mask of shape (batch_size, 1, query_len, key_len) Returns: - tuple: (attn_mask, active_mask) + tuple: (attn_bias, attn_mask) """ min_dtype = torch.finfo(hidden_states.dtype).min dtype = hidden_states.dtype - attn_mask = dt_states[:, :, None, :].expand( + attn_bias = zoh_states[:, :, None, :].expand( -1, -1, hidden_states.shape[2], -1 ) # [batch_size, num_kv_heads, query_len, key_len] @@ -63,25 +107,25 @@ def prepare_dynamic_mask( torch.tensor(0.0, device=attention_mask.device, dtype=dtype), min_dtype ) - attn_mask = attn_mask.masked_fill( - attention_mask[:, :, :, : attn_mask.shape[-1]] != 0, min_dtype + attn_bias = attn_bias.masked_fill( + attention_mask[:, :, :, : attn_bias.shape[-1]] != 0, min_dtype ) - if attn_mask.shape[-1] > keep_window_size: + if attn_bias.shape[-1] > keep_window_size: topk_indices = torch.topk( - attn_mask, keep_window_size, dim=-1, largest=True, sorted=False + attn_bias, keep_window_size, dim=-1, largest=True, sorted=False ).indices - active_mask = torch.zeros_like(attn_mask, dtype=dtype, device=attn_mask.device) - active_mask = active_mask.scatter(-1, topk_indices, 1.0) - attn_mask = attn_mask.masked_fill(active_mask == 0.0, min_dtype) + attn_mask = torch.zeros_like(attn_bias, dtype=dtype, device=attn_bias.device) + attn_mask = attn_mask.scatter(-1, topk_indices, 1.0) + attn_bias = attn_bias.masked_fill(attn_mask == 0.0, min_dtype) else: - active_mask = torch.ones_like(attn_mask, dtype=dtype, device=attn_mask.device) - return attn_mask, active_mask + attn_mask = torch.ones_like(attn_bias, dtype=dtype, device=attn_bias.device) + return attn_bias, attn_mask -def calculate_zero_hold_states(value_states, dt_proj, A): +def calculate_zoh_states(value_states, dt_proj, A): """ - Calculate zero hold states for dynamic mask attention. + Calculate zoh states for dynamic mask attention. 
Args: value_states: [batch_size, num_kv_heads, key_len, head_dim] @@ -90,7 +134,7 @@ def calculate_zero_hold_states(value_states, dt_proj, A): causal_mask: Optional causal mask Returns: - zero_hold_states: [batch_size, num_kv_heads, key_len] + zoh_states: [batch_size, num_kv_heads, key_len] """ batch_size, _, key_len, _ = value_states.shape @@ -102,9 +146,9 @@ def calculate_zero_hold_states(value_states, dt_proj, A): # Apply softplus activation and coefficient A dt_states = torch.exp(F.softplus(dt_result) * A) - zero_hold_states = dt_states.transpose(-1, -2) # [batch_size, num_kv_heads, key_len] - - return zero_hold_states + zoh_states = dt_states.transpose(-1, -2) # [batch_size, num_kv_heads, key_len] + + return zoh_states def flash_attention_cuda( @@ -139,18 +183,28 @@ def flash_attention_cuda( value_states = value_states.contiguous() try: - attn_outputs = torch.nn.functional.scaled_dot_product_attention( - query_states, - key_states, - value_states, - attn_mask=causal_mask, - scale=scaling, - enable_gqa=True - ) + # Only measure the core attention computation + torch.cuda.synchronize() + start_time = time.time() + + with torch.nn.attention.sdpa_kernel(backends=[torch.nn.attention.SDPBackend.FLASH_ATTENTION]): + attn_outputs = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + # attn_mask=causal_mask, + scale=scaling, + is_causal=is_causal if query_len == key_len else False, + enable_gqa=True + ) + + torch.cuda.synchronize() + end_time = time.time() + attn_outputs = attn_outputs.transpose(1, 2).contiguous() # Transpose to [batch, query_len, num_heads, head_dim] - return attn_outputs + return attn_outputs, (end_time - start_time) * 1000 # Return output and time in ms except torch.cuda.OutOfMemoryError: - return "OOM" + return "OOM", 0 def dynamic_mask_attention_cuda( @@ -183,49 +237,60 @@ def dynamic_mask_attention_cuda( Returns: attn_outputs: [batch_size, query_len, num_heads, head_dim] """ - # Calculate zero_hold_states - zero_hold_states = calculate_zero_hold_states(value_states, dt_proj, A) - attn_mask, active_mask = prepare_dynamic_mask( + # Calculate zoh_states + zoh_states = calculate_zoh_states(value_states, dt_proj, A) + + # Use prepare_dynamic_mask to get the processed attention mask + attn_bias, attn_mask = prepare_dynamic_mask( query_states, - zero_hold_states, + zoh_states, keep_window_size, causal_mask if is_causal else None ) # [batch_size, num_kv_heads, query_len, key_len] # Ensure correct data types and memory layout for CUDA function # CUDA function expects: q, k, v in [batch, seqlen, num_heads, head_dim] format - query_states = query_states.transpose(1, 2).contiguous() # [batch, query_len, num_heads, head_dim] - key_states = key_states.transpose(1, 2).contiguous() # [batch, key_len, num_kv_heads, head_dim] - value_states = value_states.transpose(1, 2).contiguous() # [batch, key_len, num_kv_heads, head_dim] - zero_hold_states = zero_hold_states[:, :, None, :].expand( + query_states = query_states.transpose(1, 2).contiguous() # [batch, query_len, num_heads, head_dim] + key_states = key_states.transpose(1, 2).contiguous() # [batch, key_len, num_kv_heads, head_dim] + value_states = value_states.transpose(1, 2).contiguous() # [batch, key_len, num_kv_heads, head_dim] + zoh_states = zoh_states[:, :, None, :].expand( -1, -1, query_states.shape[1], -1 - ).contiguous() # [batch, num_kv_heads, query_len, key_len] - attn_mask = attn_mask.contiguous() # [batch, num_kv_heads, query_len, key_len] - active_mask = active_mask.contiguous() # 
[batch, num_kv_heads, query_len, key_len] + ).contiguous() # [batch, num_kv_heads, query_len, key_len] + attn_bias = attn_bias.contiguous() # [batch, num_kv_heads, query_len, key_len] + attn_mask = attn_mask.contiguous() # [batch, num_kv_heads, query_len, key_len] try: # Call the CUDA implementation using the mha_fwd function signature out_tensor = None # Let the function allocate the output tensor - result = flash_dma_cuda.fwd( # type: ignore - query_states, # q: [batch, seqlen_q, num_heads, head_dim] - key_states, # k: [batch, seqlen_k, num_kv_heads, head_dim] - value_states, # v: [batch, seqlen_k, num_kv_heads, head_dim] - zero_hold_states, # zoh: [batch, num_kv_heads, seqlen_q, seqlen_k] - processed attention mask - attn_mask, # active_mask: [batch, num_kv_heads, seqlen_q, seqlen_k] - out_tensor, # out: None to auto-allocate - 0.0, # p_dropout - scaling, # softmax_scale - is_causal, # is_causal - keep_window_size, # keep_window_size - 0.0, # softcap - return_softmax, # return_softmax - None # gen (generator) + + # Only measure the core CUDA kernel computation + torch.cuda.synchronize() + start_time = time.time() + + result = flash_dmattn_cuda.fwd( # type: ignore + query_states, # q: [batch, seqlen_q, num_heads, head_dim] + key_states, # k: [batch, seqlen_k, num_kv_heads, head_dim] + value_states, # v: [batch, seqlen_k, num_kv_heads, head_dim] + attn_mask, # attn_mask: [batch, num_kv_heads, seqlen_q, seqlen_k] + attn_bias, # attn_bias: [batch, num_kv_heads, seqlen_q, seqlen_k] + out_tensor, # out: None to auto-allocate + 0.0, # p_dropout + scaling, # softmax_scale + is_causal, # is_causal + keep_window_size, # keep_window_size + 0.0, # softcap + return_softmax, # return_softmax + None # gen (generator) ) + + torch.cuda.synchronize() + end_time = time.time() + attn_outputs = result[0] # [batch, query_len, num_heads, head_dim] - return attn_outputs + return attn_outputs, (end_time - start_time) * 1000 # Return output and time in ms except torch.cuda.OutOfMemoryError: - return "OOM" + return "OOM", 0 def dynamic_mask_attention_cuda_no_topk( @@ -259,8 +324,8 @@ def dynamic_mask_attention_cuda_no_topk( Returns: attn_outputs: [batch_size, query_len, num_heads, head_dim] """ - # Calculate zero_hold_states - zero_hold_states = calculate_zero_hold_states(value_states, dt_proj, A) + # Calculate zoh_states + zoh_states = calculate_zoh_states(value_states, dt_proj, A) # Create a simplified mask without topk computation batch_size, _, query_len, _ = query_states.shape @@ -268,8 +333,8 @@ def dynamic_mask_attention_cuda_no_topk( dtype = query_states.dtype device = query_states.device - # Create full active mask (no topk selection) - active_mask = torch.zeros( + # Create full attn mask (no topk selection) + attn_mask = torch.zeros( (batch_size, num_kv_heads, query_len, key_len), dtype=dtype, device=device @@ -279,38 +344,208 @@ def dynamic_mask_attention_cuda_no_topk( query_states = query_states.transpose(1, 2).contiguous() # [batch, query_len, num_heads, head_dim] key_states = key_states.transpose(1, 2).contiguous() # [batch, key_len, num_kv_heads, head_dim] value_states = value_states.transpose(1, 2).contiguous() # [batch, key_len, num_kv_heads, head_dim] - zero_hold_states = zero_hold_states[:, :, None, :].expand( + attn_bias = zoh_states[:, :, None, :].expand( -1, -1, query_states.shape[1], -1 ).contiguous() # [batch, num_kv_heads, query_len, key_len] # Create full active mask (no topk selection) - active_mask = torch.zeros_like( - zero_hold_states, + attn_mask = torch.zeros_like( + attn_bias, 
dtype=dtype, device=device - ) - active_mask = active_mask.contiguous() # [batch, num_kv_heads, query_len, key_len] + ).contiguous() # [batch, num_kv_heads, query_len, key_len] try: out_tensor = None # Let the function allocate the output tensor - result = flash_dma_cuda.fwd( # type: ignore - query_states, # q: [batch, seqlen_q, num_heads, head_dim] - key_states, # k: [batch, seqlen_k, num_kv_heads, head_dim] - value_states, # v: [batch, seqlen_k, num_kv_heads, head_dim] - zero_hold_states, # zoh: [batch, num_kv_heads, seqlen_q, seqlen_k] - processed attention mask - active_mask, # active_mask: [batch, num_kv_heads, seqlen_q, seqlen_k] - out_tensor, # out: None to auto-allocate - 0.0, # p_dropout - scaling, # softmax_scale - is_causal, # is_causal - keep_window_size, # keep_window_size - 0.0, # softcap - return_softmax, # return_softmax - None # gen (generator) + + # Only measure the core CUDA kernel computation + torch.cuda.synchronize() + start_time = time.time() + + result = flash_dmattn_cuda.fwd( # type: ignore + query_states, # q: [batch, seqlen_q, num_heads, head_dim] + key_states, # k: [batch, seqlen_k, num_kv_heads, head_dim] + value_states, # v: [batch, seqlen_k, num_kv_heads, head_dim] + attn_mask, # attn_mask: [batch, num_kv_heads, seqlen_q, seqlen_k] + attn_bias, # attn_bias: [batch, num_kv_heads, seqlen_q, seqlen_k] + out_tensor, # out: None to auto-allocate + 0.0, # p_dropout + scaling, # softmax_scale + is_causal, # is_causal + keep_window_size, # keep_window_size + 0.0, # softcap + return_softmax, # return_softmax + None # gen (generator) ) + + torch.cuda.synchronize() + end_time = time.time() + attn_outputs = result[0] - return attn_outputs + return attn_outputs, (end_time - start_time) * 1000 # Return output and time in ms except torch.cuda.OutOfMemoryError: - return "OOM" + return "OOM", 0 + + +def dynamic_mask_attention_triton( + query_states: torch.Tensor, + key_states: torch.Tensor, + value_states: torch.Tensor, + dt_proj: torch.Tensor, + A: torch.Tensor, + scaling: float, + causal_mask: torch.Tensor, + keep_window_size=2048, + is_causal=True, +): + """ + Triton implementation of dynamic mask attention. 
+ + Args: + query_states: [batch_size, num_heads, query_len, head_dim] + key_states: [batch_size, num_kv_heads, key_len, head_dim] + value_states: [batch_size, num_kv_heads, key_len, head_dim] + dt_proj: [num_kv_heads, num_kv_heads * head_dim] + A: [num_kv_heads] + scaling: Attention scaling factor + causal_mask: Causal attention mask + keep_window_size: Number of tokens to keep in attention window + is_causal: Whether to apply causal masking + + Returns: + attn_outputs: [batch_size, query_len, num_heads, head_dim] + """ + if triton_dmattn_func is None: + return "Not Available", 0 + + _, num_heads, _, _ = query_states.shape + _, num_kv_heads, _, _ = key_states.shape + num_queries_per_kv = num_heads // num_kv_heads + + try: + # Calculate zoh_states + zoh_states = calculate_zoh_states(value_states, dt_proj, A) + + # Use prepare_dynamic_mask to get the processed attention mask + attn_bias, attn_mask = prepare_dynamic_mask( + query_states, + zoh_states, + keep_window_size, + causal_mask if is_causal else None + ) # [batch_size, num_kv_heads, query_len, key_len] + + # Repeat KV for multi-head attention (GQA support) + key_states = repeat_kv(key_states, num_queries_per_kv) + value_states = repeat_kv(value_states, num_queries_per_kv) + attn_mask = repeat_kv(attn_mask, num_queries_per_kv) + attn_bias = repeat_kv(attn_bias, num_queries_per_kv) + + # Triton function expects: q, k, v in [batch, seqlen, num_heads, head_dim] format + query_states = query_states.transpose(1, 2).contiguous() # [batch, query_len, num_heads, head_dim] + key_states = key_states.transpose(1, 2).contiguous() # [batch, key_len, num_heads, head_dim] + value_states = value_states.transpose(1, 2).contiguous() # [batch, key_len, num_heads, head_dim] + attn_mask = attn_mask.contiguous() # [batch, num_heads, seqlen_q, seqlen_k] + attn_bias = attn_bias.contiguous() # [batch, num_heads, seqlen_q, seqlen_k] + + # Only measure the core Triton kernel computation + torch.cuda.synchronize() + start_time = time.time() + + # Call the Triton implementation + attn_outputs = triton_dmattn_func( + query_states, # q: [batch, seqlen_q, num_heads, head_dim] + key_states, # k: [batch, seqlen_k, num_heads, head_dim] + value_states, # v: [batch, seqlen_k, num_heads, head_dim] + attn_mask, # mask: [batch, num_heads, seqlen_q, seqlen_k] + attn_bias, # bias: [batch, num_heads, seqlen_q, seqlen_k] + is_causal, # causal masking + scaling # scaling factor + ) + + torch.cuda.synchronize() + end_time = time.time() + + return attn_outputs, (end_time - start_time) * 1000 # Return output and time in ms + except torch.cuda.OutOfMemoryError: + return "OOM", 0 + + +def dynamic_mask_attention_flex( + query_states: torch.Tensor, + key_states: torch.Tensor, + value_states: torch.Tensor, + dt_proj: torch.Tensor, + A: torch.Tensor, + scaling: float, + causal_mask: torch.Tensor, + keep_window_size=2048, + is_causal=True, +): + """ + Flex Attention implementation of dynamic mask attention. 
+ + Args: + query_states: [batch_size, num_heads, query_len, head_dim] + key_states: [batch_size, num_kv_heads, key_len, head_dim] + value_states: [batch_size, num_kv_heads, key_len, head_dim] + dt_proj: [num_kv_heads, num_kv_heads * head_dim] + A: [num_kv_heads] + scaling: Attention scaling factor + causal_mask: Causal attention mask + keep_window_size: Number of tokens to keep in attention window + is_causal: Whether to apply causal masking + + Returns: + attn_outputs: [batch_size, query_len, num_heads, head_dim] + """ + if flex_dmattn_func is None: + return "Not Available", 0 + + _, num_heads, _, _ = query_states.shape + _, num_kv_heads, _, _ = key_states.shape + num_queries_per_kv = num_heads // num_kv_heads + + try: + # Calculate zoh_states + zoh_states = calculate_zoh_states(value_states, dt_proj, A) + + # Use prepare_dynamic_mask to get the processed attention mask + attn_bias, attn_mask = prepare_dynamic_mask( + query_states, + zoh_states, + keep_window_size, + causal_mask if is_causal else None + ) # [batch_size, num_kv_heads, query_len, key_len] + + # Repeat KV for multi-head attention (GQA support) + key_states = repeat_kv(key_states, num_queries_per_kv) + value_states = repeat_kv(value_states, num_queries_per_kv) + attn_mask = repeat_kv(attn_mask, num_queries_per_kv) + attn_bias = repeat_kv(attn_bias, num_queries_per_kv) + + # Flex attention expects: q, k, v in [batch, num_heads, seqlen, head_dim] format + # But attention_mask and attention_bias in [batch, num_heads, query_len, key_len] format + + # Only measure the core Flex Attention computation + torch.cuda.synchronize() + start_time = time.time() + + # Call the Flex Attention implementation + attn_outputs, _ = flex_dmattn_func( + query_states, # q: [batch, num_heads, query_len, head_dim] + key_states, # k: [batch, num_heads, key_len, head_dim] + value_states, # v: [batch, num_heads, key_len, head_dim] + attention_mask=attn_mask, # attention_mask: [batch, num_heads, query_len, key_len] + attention_bias=attn_bias, # attention_bias: [batch, num_heads, query_len, key_len] + is_causal=is_causal, # is_causal: Whether to apply causal masking + scaling=scaling # scaling factor + ) + + torch.cuda.synchronize() + end_time = time.time() + + return attn_outputs, (end_time - start_time) * 1000 # Return output and time in ms + except torch.cuda.OutOfMemoryError: + return "OOM", 0 def measure_memory_usage(): @@ -327,19 +562,19 @@ def measure_memory_usage(): return 0, 0 -def benchmark_attention_performance(config, num_runs=5, warmup_runs=2): +def benchmark_attention_performance(config, test_type='all', num_runs=5, warmup_runs=2): """ Benchmark attention performance for a given configuration. 
Args: - config: Tuple of (batch_size, num_heads, num_kv_heads, query_len, key_len, head_dim) + config: Tuple of (batch_size, num_heads, num_kv_heads, query_len, key_len, head_dim, keep_window_size, is_causal) num_runs: Number of benchmark runs warmup_runs: Number of warmup runs Returns: dict: Performance metrics """ - batch_size, num_heads, num_kv_heads, query_len, key_len, head_dim = config + batch_size, num_heads, num_kv_heads, query_len, key_len, head_dim, keep_window_size, is_causal = config device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # Create random input data @@ -372,301 +607,477 @@ def benchmark_attention_performance(config, num_runs=5, warmup_runs=2): causal_mask *= torch.arange(key_len, device=device) > cache_position.reshape(-1, 1) causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) - # Set scaling factor and keep window size + # Set scaling factor from config scaling = head_dim ** -0.5 - keep_window_size = 2048 - is_causal = True results = { 'config': config, 'flash_attention_times': [], 'dynamic_mask_attention_times': [], 'dynamic_mask_attention_no_topk_times': [], + 'dynamic_mask_attention_triton_times': [], + 'dynamic_mask_attention_flex_times': [], 'flash_attention_memory': 0, 'dynamic_mask_attention_memory': 0, 'dynamic_mask_attention_no_topk_memory': 0, + 'dynamic_mask_attention_triton_memory': 0, + 'dynamic_mask_attention_flex_memory': 0, 'flash_attention_status': 'success', 'dynamic_mask_attention_status': 'success', - 'dynamic_mask_attention_no_topk_status': 'success' + 'dynamic_mask_attention_no_topk_status': 'success', + 'dynamic_mask_attention_triton_status': 'success', + 'dynamic_mask_attention_flex_status': 'success' } - # Benchmark Flash Attention - gc.collect() - torch.cuda.empty_cache() - - # Warmup runs - for _ in range(warmup_runs): - result = flash_attention_cuda( - query_states, key_states, value_states, - scaling, causal_mask, is_causal - ) - if result == "OOM": - results['flash_attention_status'] = 'OOM' - break - torch.cuda.synchronize() + # Determine which implementations to run + run_flash = test_type in ['all', 'flash', 'flash-vs-cuda', 'flash-vs-triton', 'flash-vs-flex'] + run_cuda = test_type in ['all', 'cuda', 'flash-vs-cuda'] + run_no_topk = test_type in ['all', 'cuda'] + run_triton = test_type in ['all', 'triton', 'flash-vs-triton'] + run_flex = test_type in ['all', 'flex', 'flash-vs-flex'] - if results['flash_attention_status'] == 'success': - # Measure memory before benchmark - mem_before = measure_memory_usage() + # Benchmark Flash Attention + if run_flash: + gc.collect() + torch.cuda.empty_cache() - # Actual benchmark runs - for _ in range(num_runs): - start_time = time.time() + # Warmup runs + for _ in range(warmup_runs): result = flash_attention_cuda( query_states, key_states, value_states, scaling, causal_mask, is_causal ) - torch.cuda.synchronize() - end_time = time.time() - - if result == "OOM": + if result[0] == "OOM": results['flash_attention_status'] = 'OOM' break - - results['flash_attention_times'].append((end_time - start_time) * 1000) # ms + torch.cuda.synchronize() - # Measure memory after - mem_after = measure_memory_usage() - results['flash_attention_memory'] = mem_after[0] - mem_before[0] + if results['flash_attention_status'] == 'success': + # Measure memory before benchmark + mem_before = measure_memory_usage() + + # Actual benchmark runs + for _ in range(num_runs): + result = flash_attention_cuda( + query_states, key_states, value_states, + scaling, causal_mask, is_causal + ) + 
+ if result[0] == "OOM": + results['flash_attention_status'] = 'OOM' + break + + # Use the timing from the function instead of measuring here + results['flash_attention_times'].append(result[1]) # ms + + # Measure memory after + mem_after = measure_memory_usage() + results['flash_attention_memory'] = mem_after[0] - mem_before[0] + else: + results['flash_attention_status'] = 'N/A' # Benchmark Dynamic Mask Attention - gc.collect() - torch.cuda.empty_cache() - - # Warmup runs - for _ in range(warmup_runs): - result = dynamic_mask_attention_cuda( - query_states, key_states, value_states, - dt_proj, A, scaling, causal_mask, - keep_window_size, is_causal - ) - if result == "OOM": - results['dynamic_mask_attention_status'] = 'OOM' - break - torch.cuda.synchronize() - - if results['dynamic_mask_attention_status'] == 'success': - # Measure memory before benchmark - mem_before = measure_memory_usage() + if run_cuda: + gc.collect() + torch.cuda.empty_cache() - # Actual benchmark runs - for _ in range(num_runs): - start_time = time.time() + # Warmup runs + for _ in range(warmup_runs): result = dynamic_mask_attention_cuda( query_states, key_states, value_states, dt_proj, A, scaling, causal_mask, keep_window_size, is_causal ) - torch.cuda.synchronize() - end_time = time.time() - - if result == "OOM": + if result[0] == "OOM": results['dynamic_mask_attention_status'] = 'OOM' break - - results['dynamic_mask_attention_times'].append((end_time - start_time) * 1000) # ms + torch.cuda.synchronize() - # Measure memory after - mem_after = measure_memory_usage() - results['dynamic_mask_attention_memory'] = mem_after[0] - mem_before[0] + if results['dynamic_mask_attention_status'] == 'success': + # Measure memory before benchmark + mem_before = measure_memory_usage() + + # Actual benchmark runs + for _ in range(num_runs): + result = dynamic_mask_attention_cuda( + query_states, key_states, value_states, + dt_proj, A, scaling, causal_mask, + keep_window_size, is_causal + ) + + if result[0] == "OOM": + results['dynamic_mask_attention_status'] = 'OOM' + break + + # Use the timing from the function instead of measuring here + results['dynamic_mask_attention_times'].append(result[1]) # ms + + # Measure memory after + mem_after = measure_memory_usage() + results['dynamic_mask_attention_memory'] = mem_after[0] - mem_before[0] + else: + results['dynamic_mask_attention_status'] = 'N/A' # Benchmark Dynamic Mask Attention (No TopK) - gc.collect() - torch.cuda.empty_cache() - - # Warmup runs - for _ in range(warmup_runs): - result = dynamic_mask_attention_cuda_no_topk( - query_states, key_states, value_states, - dt_proj, A, scaling, causal_mask, - keep_window_size, is_causal - ) - if result == "OOM": - results['dynamic_mask_attention_no_topk_status'] = 'OOM' - break - torch.cuda.synchronize() - - if results['dynamic_mask_attention_no_topk_status'] == 'success': - # Measure memory before benchmark - mem_before = measure_memory_usage() + if run_no_topk: + gc.collect() + torch.cuda.empty_cache() - # Actual benchmark runs - for _ in range(num_runs): - start_time = time.time() + # Warmup runs + for _ in range(warmup_runs): result = dynamic_mask_attention_cuda_no_topk( query_states, key_states, value_states, dt_proj, A, scaling, causal_mask, keep_window_size, is_causal ) + if result[0] == "OOM": + results['dynamic_mask_attention_no_topk_status'] = 'OOM' + break torch.cuda.synchronize() - end_time = time.time() + + if results['dynamic_mask_attention_no_topk_status'] == 'success': + # Measure memory before benchmark + mem_before = 
measure_memory_usage() - if result == "OOM": - results['dynamic_mask_attention_no_topk_status'] = 'OOM' + # Actual benchmark runs + for _ in range(num_runs): + result = dynamic_mask_attention_cuda_no_topk( + query_states, key_states, value_states, + dt_proj, A, scaling, causal_mask, + keep_window_size, is_causal + ) + + if result[0] == "OOM": + results['dynamic_mask_attention_no_topk_status'] = 'OOM' + break + + # Use the timing from the function instead of measuring here + results['dynamic_mask_attention_no_topk_times'].append(result[1]) # ms + + # Measure memory after + mem_after = measure_memory_usage() + results['dynamic_mask_attention_no_topk_memory'] = mem_after[0] - mem_before[0] + else: + results['dynamic_mask_attention_no_topk_status'] = 'N/A' + + # Benchmark Dynamic Mask Attention (Triton) + if run_triton: + gc.collect() + torch.cuda.empty_cache() + + # Warmup runs + for _ in range(warmup_runs): + result = dynamic_mask_attention_triton( + query_states, key_states, value_states, + dt_proj, A, scaling, causal_mask, + keep_window_size, is_causal + ) + if result[0] in ["OOM", "Not Available"]: + results['dynamic_mask_attention_triton_status'] = result[0] break + torch.cuda.synchronize() + + if results['dynamic_mask_attention_triton_status'] == 'success': + # Measure memory before benchmark + mem_before = measure_memory_usage() + + # Actual benchmark runs + for _ in range(num_runs): + result = dynamic_mask_attention_triton( + query_states, key_states, value_states, + dt_proj, A, scaling, causal_mask, + keep_window_size, is_causal + ) + + if result[0] in ["OOM", "Not Available"]: + results['dynamic_mask_attention_triton_status'] = result[0] + break + + # Use the timing from the function instead of measuring here + results['dynamic_mask_attention_triton_times'].append(result[1]) # ms - results['dynamic_mask_attention_no_topk_times'].append((end_time - start_time) * 1000) # ms + # Measure memory after + mem_after = measure_memory_usage() + results['dynamic_mask_attention_triton_memory'] = mem_after[0] - mem_before[0] + else: + results['dynamic_mask_attention_triton_status'] = 'N/A' + + # Benchmark Dynamic Mask Attention (Flex) + if run_flex: + gc.collect() + torch.cuda.empty_cache() - # Measure memory after - mem_after = measure_memory_usage() - results['dynamic_mask_attention_no_topk_memory'] = mem_after[0] - mem_before[0] + # Warmup runs + for _ in range(warmup_runs): + result = dynamic_mask_attention_flex( + query_states, key_states, value_states, + dt_proj, A, scaling, causal_mask, + keep_window_size, is_causal + ) + if result[0] in ["OOM", "Not Available"]: + results['dynamic_mask_attention_flex_status'] = result[0] + break + torch.cuda.synchronize() + + if results['dynamic_mask_attention_flex_status'] == 'success': + # Measure memory before benchmark + mem_before = measure_memory_usage() + + # Actual benchmark runs + for _ in range(num_runs): + result = dynamic_mask_attention_flex( + query_states, key_states, value_states, + dt_proj, A, scaling, causal_mask, + keep_window_size, is_causal + ) + + if result[0] in ["OOM", "Not Available"]: + results['dynamic_mask_attention_flex_status'] = result[0] + break + + # Use the timing from the function instead of measuring here + results['dynamic_mask_attention_flex_times'].append(result[1]) # ms + + # Measure memory after + mem_after = measure_memory_usage() + results['dynamic_mask_attention_flex_memory'] = mem_after[0] - mem_before[0] + else: + results['dynamic_mask_attention_flex_status'] = 'N/A' return results -def 
run_performance_benchmark(): +def run_performance_benchmark(test_type='all', num_runs=3, warmup_runs=2): """Run comprehensive performance benchmark across different configurations.""" print("\n" + "🏆" + "=" * 76 + "🏆") - print("⚡ Performance Benchmark: Dynamic Mask Attention vs Flash Attention ⚡") + + # Update title based on test type + if test_type == 'all': + title = "⚡ Performance Benchmark: Flash vs CUDA vs Triton vs Flex ⚡" + elif test_type == 'flash-vs-cuda': + title = "⚡ Performance Benchmark: Flash Attention vs CUDA ⚡" + elif test_type == 'flash-vs-triton': + title = "⚡ Performance Benchmark: Flash Attention vs Triton ⚡" + elif test_type == 'flash-vs-flex': + title = "⚡ Performance Benchmark: Flash Attention vs Flex ⚡" + elif test_type == 'flash': + title = "⚡ Performance Benchmark: Flash Attention Only ⚡" + elif test_type == 'cuda': + title = "⚡ Performance Benchmark: CUDA Implementations ⚡" + elif test_type == 'triton': + title = "⚡ Performance Benchmark: Triton Implementation ⚡" + elif test_type == 'flex': + title = "⚡ Performance Benchmark: Flex Implementation ⚡" + else: + title = "⚡ Performance Benchmark ⚡" + + print(title) print("🏆" + "=" * 76 + "🏆") - # Test configurations: (batch_size, num_heads, num_kv_heads, query_len, key_len, head_dim) + # Test configurations: (batch_size, num_heads, num_kv_heads, query_len, key_len, head_dim, keep_window_size, is_causal) configs = [ # Vary sequence length - (1, 2, 1, 256, 256, 32), - (1, 2, 1, 512, 512, 32), - (1, 2, 1, 1024, 1024, 32), - (1, 2, 1, 2048, 2048, 32), - (1, 2, 1, 4096, 4096, 32), - (1, 2, 1, 8192, 8192, 32), - (1, 2, 1, 16384, 16384, 32), - (1, 2, 1, 32768, 32768, 32), + (1, 2, 1, 256, 256, 32, 2048, True), + (1, 2, 1, 512, 512, 32, 2048, True), + (1, 2, 1, 1024, 1024, 32, 2048, True), + (1, 2, 1, 2048, 2048, 32, 2048, True), + (1, 2, 1, 4096, 4096, 32, 2048, True), + (1, 2, 1, 8192, 8192, 32, 2048, True), + (1, 2, 1, 16384, 16384, 32, 2048, True), + (1, 2, 1, 32768, 32768, 32, 2048, True), # Inference - (1, 2, 1, 64, 256, 128), - (1, 2, 1, 64, 512, 128), - (1, 2, 1, 64, 1024, 128), - (1, 2, 1, 64, 2048, 128), - (1, 2, 1, 64, 4096, 128), - (1, 2, 1, 64, 8192, 128), - (1, 2, 1, 64, 16384, 128), - (1, 2, 1, 64, 32768, 128), - (1, 2, 1, 64, 65536, 128), - (1, 2, 1, 64, 131072, 128), - (1, 2, 1, 64, 262144, 128), - (1, 2, 1, 64, 524288, 128), + (1, 2, 1, 2, 256, 128, 2048, True), + (1, 2, 1, 2, 512, 128, 2048, True), + (1, 2, 1, 2, 1024, 128, 2048, True), + (1, 2, 1, 2, 2048, 128, 2048, True), + (1, 2, 1, 2, 4096, 128, 2048, True), + (1, 2, 1, 2, 8192, 128, 2048, True), + (1, 2, 1, 2, 16384, 128, 2048, True), + (1, 2, 1, 2, 32768, 128, 2048, True), + (1, 2, 1, 2, 65536, 128, 2048, True), + (1, 2, 1, 2, 131072, 128, 2048, True), + (1, 2, 1, 2, 262144, 128, 2048, True), + (1, 2, 1, 2, 524288, 128, 2048, True), # Vary batch size - (1, 2, 1, 1024, 1024, 32), - (2, 2, 1, 1024, 1024, 32), - (4, 2, 1, 1024, 1024, 32), - (8, 2, 1, 1024, 1024, 32), + (1, 2, 1, 1024, 1024, 32, 2048, True), + (2, 2, 1, 1024, 1024, 32, 2048, True), + (4, 2, 1, 1024, 1024, 32, 2048, True), + (8, 2, 1, 1024, 1024, 32, 2048, True), # Vary head count - (1, 1, 1, 1024, 1024, 32), - (1, 2, 1, 1024, 1024, 32), - (1, 4, 1, 1024, 1024, 32), - (1, 8, 2, 1024, 1024, 32), + (1, 1, 1, 1024, 1024, 32, 2048, True), + (1, 2, 1, 1024, 1024, 32, 2048, True), + (1, 4, 1, 1024, 1024, 32, 2048, True), + (1, 8, 2, 1024, 1024, 32, 2048, True), # Vary head dimension - (1, 2, 1, 1024, 1024, 32), - (1, 2, 1, 1024, 1024, 64), - (1, 2, 1, 1024, 1024, 96), - (1, 2, 1, 1024, 1024, 
128), - (1, 2, 1, 1024, 1024, 192), - (1, 2, 1, 1024, 1024, 256), + (1, 2, 1, 1024, 1024, 32, 2048, True), + (1, 2, 1, 1024, 1024, 64, 2048, True), + (1, 2, 1, 1024, 1024, 96, 2048, True), + (1, 2, 1, 1024, 1024, 128, 2048, True), + (1, 2, 1, 1024, 1024, 192, 2048, True), + (1, 2, 1, 1024, 1024, 256, 2048, True), + + # Vary keep_window_size + (1, 2, 1, 32768, 32768, 128, 32, True), + (1, 2, 1, 32768, 32768, 128, 64, True), + (1, 2, 1, 32768, 32768, 128, 128, True), + (1, 2, 1, 32768, 32768, 128, 256, True), + (1, 2, 1, 32768, 32768, 128, 512, True), + (1, 2, 1, 32768, 32768, 128, 1024, True), + (1, 2, 1, 32768, 32768, 128, 2048, True), + (1, 2, 1, 32768, 32768, 128, 4096, True), + (1, 2, 1, 32768, 32768, 128, 8192, True), + (1, 2, 1, 32768, 32768, 128, 16384, True), + (1, 2, 1, 32768, 32768, 128, 32768, True), + + # Test non-causal + (1, 2, 1, 1024, 1024, 128, 2048, False), ] num_runs = 3 # Run 3 times and take average print(f"\n📊 Benchmark Results (averaged over {num_runs} runs):") - print(f"🔧 {'Configuration':<42} ⚡ {'Flash (ms)':<12} 🚀 {'DMA (ms)':<12} 🚀 {'DMA-Skip-All (ms)':<22} 📈 {'Speedup':<12} 📈 {'Skip-All-Speedup':<20} 💾 {'Memory':<10}") - print("🔄" + "-" * 155 + "🔄") + print(f"🔧 {'Configuration':<60} ⚡ {'Flash':<10} 🚀 {'CUDA':<10} 🚀 {'No-TopK':<10} 🌟 {'Triton':<10} 🌟 {'Flex':<15} 📈 {'Speedup':<15}") + print("🔄" + "-" * 160 + "🔄") all_results = [] for config in configs: - batch_size, num_heads, num_kv_heads, query_len, key_len, head_dim = config + batch_size, num_heads, num_kv_heads, query_len, key_len, head_dim, keep_window_size, is_causal = config - results = benchmark_attention_performance(config, num_runs=num_runs) + results = benchmark_attention_performance(config, test_type, num_runs, warmup_runs) all_results.append(results) - # Calculate averages - if results['flash_attention_status'] == 'success' and results['flash_attention_times']: - flash_avg = sum(results['flash_attention_times']) / len(results['flash_attention_times']) - flash_time_str = f"{flash_avg:.2f}" - else: - flash_time_str = results['flash_attention_status'] - flash_avg = float('inf') + # Calculate averages for all implementations + implementations = { + 'flash': ('flash_attention', results['flash_attention_status'], results['flash_attention_times']), + 'cuda': ('dynamic_mask_attention', results['dynamic_mask_attention_status'], results['dynamic_mask_attention_times']), + 'no_topk': ('dynamic_mask_attention_no_topk', results['dynamic_mask_attention_no_topk_status'], results['dynamic_mask_attention_no_topk_times']), + 'triton': ('dynamic_mask_attention_triton', results['dynamic_mask_attention_triton_status'], results['dynamic_mask_attention_triton_times']), + 'flex': ('dynamic_mask_attention_flex', results['dynamic_mask_attention_flex_status'], results['dynamic_mask_attention_flex_times']) + } - if results['dynamic_mask_attention_status'] == 'success' and results['dynamic_mask_attention_times']: - dma_avg = sum(results['dynamic_mask_attention_times']) / len(results['dynamic_mask_attention_times']) - dma_time_str = f"{dma_avg:.2f}" - else: - dma_time_str = results['dynamic_mask_attention_status'] - dma_avg = float('inf') - - if results['dynamic_mask_attention_no_topk_status'] == 'success' and results['dynamic_mask_attention_no_topk_times']: - dma_nt_avg = sum(results['dynamic_mask_attention_no_topk_times']) / len(results['dynamic_mask_attention_no_topk_times']) - dma_nt_time_str = f"{dma_nt_avg:.2f}" - else: - dma_nt_time_str = results['dynamic_mask_attention_no_topk_status'] - dma_nt_avg = float('inf') + # 
Calculate time strings and averages + time_strs = {} + time_avgs = {} - # Calculate speedups - if flash_avg != float('inf') and dma_avg != float('inf') and dma_avg > 0: - speedup = flash_avg / dma_avg - speedup_str = f"{speedup:.2f}x" - else: - speedup_str = "N/A" - - if flash_avg != float('inf') and dma_nt_avg != float('inf') and dma_nt_avg > 0: - kernel_speedup = flash_avg / dma_nt_avg - kernel_speedup_str = f"{kernel_speedup:.2f}x" - else: - kernel_speedup_str = "N/A" + for impl_key, (_, status, times) in implementations.items(): + if status == 'success' and times: + avg_time = sum(times) / len(times) + time_strs[impl_key] = f"{avg_time:.2f}" + time_avgs[impl_key] = avg_time + else: + time_strs[impl_key] = status[:8] # Truncate status for display + time_avgs[impl_key] = float('inf') - # Memory usage - mem_diff = results['dynamic_mask_attention_memory'] - results['flash_attention_memory'] - mem_str = f"{mem_diff:+.0f}" + # Calculate speedups (compared to Flash Attention baseline) + speedup_strs = {} + flash_avg = time_avgs.get('flash', float('inf')) - # Format output - config_short = f"b={batch_size},h={num_heads},kv={num_kv_heads},q={query_len},k={key_len},d={head_dim}" + for impl_key in ['cuda', 'no_topk', 'triton', 'flex']: + impl_avg = time_avgs.get(impl_key, float('inf')) + if flash_avg != float('inf') and impl_avg != float('inf') and impl_avg > 0: + speedup = flash_avg / impl_avg + speedup_strs[impl_key] = f"{speedup:.2f}x" + else: + speedup_strs[impl_key] = "N/A" + + # Format output with shorter config string + config_short = f" b{batch_size} h{num_heads} kv{num_kv_heads} q{query_len} k{key_len} d{head_dim} w{keep_window_size} " + if not is_causal: + config_short += "nc" # Add status icons - flash_icon = "✅" if results['flash_attention_status'] == 'success' else "💥" - dma_icon = "✅" if results['dynamic_mask_attention_status'] == 'success' else "💥" - dma_nt_icon = "✅" if results['dynamic_mask_attention_no_topk_status'] == 'success' else "💥" + icons = "" + for impl_key, (_, status, _) in implementations.items(): + if status == 'success': + icons += " ✅ " + elif status in ['OOM', 'Not Available']: + icons += " ❌ " + else: + icons += " ⚠️ " + + # Create speedup summary (best performing implementation) + best_speedup = "N/A" + best_impl = "N/A" + for impl_key, speedup_str in speedup_strs.items(): + if speedup_str != "N/A": + try: + speedup_val = float(speedup_str.replace('x', '')) + if best_speedup == "N/A" or speedup_val > float(best_speedup.replace('x', '')): + best_speedup = speedup_str + best_impl = impl_key.upper() + except: + continue + + speedup_summary = f"{best_impl}:{best_speedup}" if best_speedup != "N/A" else "N/A" - print(f"{flash_icon}{dma_icon}{dma_nt_icon} {config_short:<42} {flash_time_str:<14} {dma_time_str:<20} {dma_nt_time_str:<20} {speedup_str:<18} {kernel_speedup_str:<20} {mem_str:<12}") + print(f"{icons} {config_short:<48} {time_strs['flash']:<12} {time_strs['cuda']:<12} {time_strs['no_topk']:<14} {time_strs['triton']:<12} {time_strs['flex']:<18} {speedup_summary:<15}") - print("🔄" + "-" * 155 + "🔄") + print("🔄" + "-" * 160 + "🔄") # Summary statistics - speedups = [] - kernel_speedups = [] + implementation_speedups = { + 'cuda': [], + 'no_topk': [], + 'triton': [], + 'flex': [] + } + for results in all_results: - if (results['flash_attention_status'] == 'success' and - results['dynamic_mask_attention_status'] == 'success' and - results['flash_attention_times'] and results['dynamic_mask_attention_times']): - - flash_avg = sum(results['flash_attention_times']) / 
len(results['flash_attention_times']) - dma_avg = sum(results['dynamic_mask_attention_times']) / len(results['dynamic_mask_attention_times']) - - if dma_avg > 0: - speedups.append(flash_avg / dma_avg) - - if (results['flash_attention_status'] == 'success' and - results['dynamic_mask_attention_no_topk_status'] == 'success' and - results['flash_attention_times'] and results['dynamic_mask_attention_no_topk_times']): - + if results['flash_attention_status'] == 'success' and results['flash_attention_times']: flash_avg = sum(results['flash_attention_times']) / len(results['flash_attention_times']) - dma_nt_avg = sum(results['dynamic_mask_attention_no_topk_times']) / len(results['dynamic_mask_attention_no_topk_times']) - if dma_nt_avg > 0: - kernel_speedups.append(flash_avg / dma_nt_avg) + # Calculate speedups for each implementation + for impl_key in implementation_speedups.keys(): + # Map implementation keys to actual result keys + if impl_key == 'cuda': + status_key = 'dynamic_mask_attention_status' + times_key = 'dynamic_mask_attention_times' + else: + status_key = f'dynamic_mask_attention_{impl_key}_status' + times_key = f'dynamic_mask_attention_{impl_key}_times' + + if (status_key in results and results[status_key] == 'success' and + times_key in results and results[times_key]): + + impl_avg = sum(results[times_key]) / len(results[times_key]) + if impl_avg > 0: + implementation_speedups[impl_key].append(flash_avg / impl_avg) print(f"\n🏆 Summary:") - if speedups: - avg_speedup = np.mean(speedups) - speedup_icon = "🚀" if avg_speedup > 1.5 else "📈" if avg_speedup > 1.0 else "😐" - print(f" {speedup_icon} DMA vs Flash - Average speedup: {avg_speedup:.2f}x (Best: {np.max(speedups):.2f}x, Worst: {np.min(speedups):.2f}x)") - if kernel_speedups: - avg_kernel_speedup = np.mean(kernel_speedups) - kernel_icon = "🔥" if avg_kernel_speedup > 2.0 else "🚀" if avg_kernel_speedup > 1.5 else "📈" if avg_kernel_speedup > 1.0 else "😐" - print(f" {kernel_icon} DMA-NoTopK vs Flash - Average kernel speedup: {avg_kernel_speedup:.2f}x (Best: {np.max(kernel_speedups):.2f}x, Worst: {np.min(kernel_speedups):.2f}x)") - print(f" 💡 TopK overhead: ~{((np.mean(kernel_speedups) - np.mean(speedups) if speedups else 0) / np.mean(kernel_speedups) * 100) if kernel_speedups else 0:.1f}% performance impact") + # Display statistics for each implementation + for impl_key, speedups in implementation_speedups.items(): + if speedups: + avg_speedup = np.mean(speedups) + max_speedup = np.max(speedups) + min_speedup = np.min(speedups) + + # Choose appropriate icon based on performance + if avg_speedup > 2.0: + icon = "🔥" + elif avg_speedup > 1.5: + icon = "🚀" + elif avg_speedup > 1.0: + icon = "📈" + else: + icon = "😐" + + impl_name = impl_key.replace('_', '-').upper() + print(f" {icon} {impl_name:10} vs Flash - Avg: {avg_speedup:.2f}x (Best: {max_speedup:.2f}x, Worst: {min_speedup:.2f}x)") + else: + print(f" ❌ {impl_key.replace('_', '-').upper():10} vs Flash - No successful runs") + + # Calculate overhead comparison + if implementation_speedups['cuda'] and implementation_speedups['no_topk']: + avg_cuda = np.mean(implementation_speedups['cuda']) + avg_no_topk = np.mean(implementation_speedups['no_topk']) + topk_overhead = ((avg_no_topk - avg_cuda) / avg_no_topk * 100) if avg_no_topk > 0 else 0 + print(f" 💡 TopK overhead: ~{topk_overhead:.1f}% performance impact") def main(): @@ -687,6 +1098,9 @@ def main(): parser.add_argument('--seed', type=int, default=42, help='Random seed') parser.add_argument('--runs', type=int, default=3, help='Number 
of benchmark runs') parser.add_argument('--warmup', type=int, default=2, help='Number of warmup runs') + parser.add_argument('--test-type', type=str, default='all', + choices=['all', 'flash', 'cuda', 'triton', 'flex', 'flash-vs-cuda', 'flash-vs-triton', 'flash-vs-flex'], + help='Type of benchmark to run (default: all)') args = parser.parse_args() @@ -703,8 +1117,12 @@ def main(): print(f"🎮 CUDA device: {torch.cuda.get_device_name()}") print(f"💾 Total GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB") + print(f"🎲 Random seed: {args.seed}") + print(f"📊 Test type: {args.test_type}") + print(f"🔄 Runs: {args.runs}, Warmup: {args.warmup}") + # Run performance benchmark - run_performance_benchmark() + run_performance_benchmark(args.test_type, args.runs, args.warmup) if __name__ == "__main__": diff --git a/flash_dmattn/__init__.py b/flash_dmattn/__init__.py new file mode 100644 index 0000000..b0b3adc --- /dev/null +++ b/flash_dmattn/__init__.py @@ -0,0 +1,89 @@ +# Copyright (c) 2025, Jingze Shi. + +from typing import Optional + +try: + from .flash_dmattn_triton import triton_dmattn_func + TRITON_AVAILABLE = True +except ImportError: + TRITON_AVAILABLE = False + triton_dmattn_func = None + +try: + from .flash_dmattn_flex import flex_dmattn_func + FLEX_AVAILABLE = True +except ImportError: + FLEX_AVAILABLE = False + flex_dmattn_func = None + +# Check if CUDA extension is available +try: + import flash_dmattn_cuda # type: ignore[import] + CUDA_AVAILABLE = True +except ImportError: + CUDA_AVAILABLE = False + +__version__ = "0.1.0" + +__all__ = [ + "triton_dmattn_func", + "flex_dmattn_func", + "TRITON_AVAILABLE", + "FLEX_AVAILABLE", + "CUDA_AVAILABLE", +] + + +def get_available_backends(): + """Return a list of available backends.""" + backends = [] + if CUDA_AVAILABLE: + backends.append("cuda") + if TRITON_AVAILABLE: + backends.append("triton") + if FLEX_AVAILABLE: + backends.append("flex") + return backends + + +def flash_dmattn_func(backend: Optional[str] = None, **kwargs): + """ + Flash Dynamic Mask Attention function with automatic backend selection. + + Args: + backend (str, optional): Backend to use ('cuda', 'triton', 'flex'). + If None, will use the first available backend in order: cuda, triton, flex. + **kwargs: Arguments to pass to the attention function. + + Returns: + The attention function for the specified or auto-selected backend. + """ + if backend is None: + # Auto-select backend + if CUDA_AVAILABLE: + backend = "cuda" + elif TRITON_AVAILABLE: + backend = "triton" + elif FLEX_AVAILABLE: + backend = "flex" + else: + raise RuntimeError("No flash attention backend is available. Please install at least one of: triton, transformers, or build the CUDA extension.") + + if backend == "cuda": + if not CUDA_AVAILABLE: + raise RuntimeError("CUDA backend is not available. Please build the CUDA extension.") + # Import and return CUDA function + raise NotImplementedError("CUDA backend not yet implemented in this version") + + elif backend == "triton": + if not TRITON_AVAILABLE: + raise RuntimeError("Triton backend is not available. Please install triton: pip install triton") + return triton_dmattn_func + + elif backend == "flex": + if not FLEX_AVAILABLE: + raise RuntimeError("Flex backend is not available. Please install transformers: pip install transformers") + return flex_dmattn_func + + else: + raise ValueError(f"Unknown backend: {backend}. 
Available backends: {get_available_backends()}") diff --git a/setup.py b/setup.py index da09b51..3f805ad 100644 --- a/setup.py +++ b/setup.py @@ -39,11 +39,41 @@ # FORCE_BUILD: Force a fresh build locally, instead of attempting to find prebuilt wheels # SKIP_CUDA_BUILD: Intended to allow CI to use a simple `python setup.py sdist` run to copy over raw files, without any cuda compilation +# Also useful when user only wants Triton/Flex backends without CUDA compilation FORCE_BUILD = os.getenv("FLASH_DMATTN_FORCE_BUILD", "FALSE") == "TRUE" SKIP_CUDA_BUILD = os.getenv("FLASH_DMATTN_SKIP_CUDA_BUILD", "FALSE") == "TRUE" # For CI, we want the option to build with C++11 ABI since the nvcr images use C++11 ABI FORCE_CXX11_ABI = os.getenv("FLASH_DMATTN_FORCE_CXX11_ABI", "FALSE") == "TRUE" +# Auto-detect if user wants only Triton/Flex backends based on pip install command +# This helps avoid unnecessary CUDA compilation when user only wants Python backends +def should_skip_cuda_build(): + """Determine if CUDA build should be skipped based on installation context.""" + + if SKIP_CUDA_BUILD: + return True + + if FORCE_BUILD: + return False # User explicitly wants to build, respect that + + # Check command line arguments for installation hints + if len(sys.argv) > 1: + install_args = ' '.join(sys.argv) + + # Check if Triton or Flex extras are requested + has_triton_or_flex = 'triton' in install_args or 'flex' in install_args + has_all_or_dev = 'all' in install_args or 'dev' in install_args + + if has_triton_or_flex and not has_all_or_dev: + print("Detected Triton/Flex-only installation. Skipping CUDA compilation.") + print("Set FLASH_DMATTN_FORCE_BUILD=TRUE to force CUDA compilation.") + return True + + return False + +# Update SKIP_CUDA_BUILD based on auto-detection +SKIP_CUDA_BUILD = should_skip_cuda_build() + @functools.lru_cache(maxsize=None) def cuda_archs(): # return os.getenv("FLASH_DMATTN_CUDA_ARCHS", "80;90;100;120").split(";") @@ -289,9 +319,47 @@ def __init__(self, *args, **kwargs) -> None: "torch", "einops", ], + extras_require={ + # Individual backend options - choose one or more + "triton": [ + "triton>=2.0.0", + ], + "flex": [ + "transformers>=4.38.0", + ], + + # Combined options + "all": [ + "triton>=2.0.0", # Triton backend + "transformers>=4.38.0", # Flex backend + # CUDA backend included by default compilation + ], + + # Development dependencies + "dev": [ + "triton>=2.0.0", + "transformers>=4.38.0", + "pytest>=6.0", + "pytest-benchmark", + "numpy", + ], + + # Testing only + "test": [ + "pytest>=6.0", + "pytest-benchmark", + "numpy", + ], + }, setup_requires=[ "packaging", "psutil", "ninja", ], + # Include package data + package_data={ + "flash_dmattn": ["*.py"], + }, + # Ensure the package is properly included + include_package_data=True, ) \ No newline at end of file
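
For reference, a minimal usage sketch of the new `flash_dmattn` package API introduced above. The backend-selection helper and the Triton call signature follow the new `flash_dmattn/__init__.py` and the way the benchmarks invoke `triton_dmattn_func`; the tensor shapes, dtype, and the dense all-ones mask are illustrative assumptions. It presumes the package is installed with the Triton extra (e.g. `pip install ".[triton]"` from the repo root, which the updated `setup.py` lets you do without compiling the CUDA extension).

```python
# Minimal usage sketch (assumes a CUDA GPU and the Triton extra installed).
# Argument order for triton_dmattn_func mirrors the benchmark calls above:
# (q, k, v, attn_mask, attn_bias, is_causal, scaling).
import torch
from flash_dmattn import flash_dmattn_func, get_available_backends

print(get_available_backends())                 # e.g. ['cuda', 'triton', 'flex']
attn_fn = flash_dmattn_func(backend="triton")   # resolves to triton_dmattn_func

batch, heads, q_len, k_len, head_dim = 1, 2, 256, 256, 64
device, dtype = "cuda", torch.bfloat16

# q/k/v: [batch, seqlen, num_heads, head_dim]
# attn_mask/attn_bias: [batch, num_heads, q_len, k_len]
q = torch.randn(batch, q_len, heads, head_dim, device=device, dtype=dtype)
k = torch.randn(batch, k_len, heads, head_dim, device=device, dtype=dtype)
v = torch.randn(batch, k_len, heads, head_dim, device=device, dtype=dtype)
attn_mask = torch.ones(batch, heads, q_len, k_len, device=device, dtype=dtype)   # 1.0 = keep token
attn_bias = torch.zeros(batch, heads, q_len, k_len, device=device, dtype=dtype)  # no additive bias

out = attn_fn(q, k, v, attn_mask, attn_bias, True, head_dim ** -0.5)
print(out.shape)  # torch.Size([1, 256, 2, 64])
```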
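
The renamed mask helpers can also be exercised on their own. The sketch below reproduces the `(attn_bias, attn_mask)` construction that `prepare_dynamic_mask` now returns, using a random `zoh_states` stand-in (the upstream `dt_proj`/`A` projection and the optional causal-mask branch are omitted); the top-k/masked-fill logic and shapes follow the benchmark code above.

```python
# Sketch of the bias/mask construction in prepare_dynamic_mask,
# driven by a random zoh_states stand-in rather than calculate_zoh_states.
import torch

batch, num_kv_heads, q_len, key_len = 1, 1, 64, 4096
keep_window_size = 1024
dtype = torch.float32
min_dtype = torch.finfo(dtype).min

# zoh_states: one score per key, per KV head
zoh_states = torch.rand(batch, num_kv_heads, key_len, dtype=dtype)

# Broadcast the per-key scores over the query dimension to form the bias.
attn_bias = zoh_states[:, :, None, :].expand(-1, -1, q_len, -1)

if attn_bias.shape[-1] > keep_window_size:
    # Keep only the top-k keys per (head, query); mask everything else to -inf.
    topk_indices = torch.topk(attn_bias, keep_window_size, dim=-1,
                              largest=True, sorted=False).indices
    attn_mask = torch.zeros_like(attn_bias).scatter(-1, topk_indices, 1.0)
    attn_bias = attn_bias.masked_fill(attn_mask == 0.0, min_dtype)
else:
    attn_mask = torch.ones_like(attn_bias)

print(attn_mask.sum(dim=-1)[0, 0, 0])  # keep_window_size keys survive per query
```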