From 104c56c152ae30141529d27d9191454c44dc1f07 Mon Sep 17 00:00:00 2001 From: Loser Cheems Date: Mon, 7 Jul 2025 18:58:39 +0800 Subject: [PATCH 1/5] Adds optional backend selection and smart CUDA build skipping Introduces extras_require configuration to allow users to install specific backends (triton, flex, all, dev, test) without requiring full CUDA compilation. Implements auto-detection logic that skips CUDA build when users explicitly request only Triton or Flex backends, reducing installation time and complexity for users who don't need CUDA acceleration. Maintains backward compatibility while providing more granular control over dependencies and build processes. --- setup.py | 69 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 69 insertions(+) diff --git a/setup.py b/setup.py index da09b51..f224028 100644 --- a/setup.py +++ b/setup.py @@ -39,11 +39,42 @@ # FORCE_BUILD: Force a fresh build locally, instead of attempting to find prebuilt wheels # SKIP_CUDA_BUILD: Intended to allow CI to use a simple `python setup.py sdist` run to copy over raw files, without any cuda compilation +# Also useful when user only wants Triton/Flex backends without CUDA compilation FORCE_BUILD = os.getenv("FLASH_DMATTN_FORCE_BUILD", "FALSE") == "TRUE" SKIP_CUDA_BUILD = os.getenv("FLASH_DMATTN_SKIP_CUDA_BUILD", "FALSE") == "TRUE" # For CI, we want the option to build with C++11 ABI since the nvcr images use C++11 ABI FORCE_CXX11_ABI = os.getenv("FLASH_DMATTN_FORCE_CXX11_ABI", "FALSE") == "TRUE" +# Auto-detect if user wants only Triton/Flex backends based on pip install command +# This helps avoid unnecessary CUDA compilation when user only wants Python backends +def should_skip_cuda_build(): + """Determine if CUDA build should be skipped based on installation context.""" + + if SKIP_CUDA_BUILD: + return True + + if FORCE_BUILD: + return False # User explicitly wants to build, respect that + + # Check command line arguments for installation hints + if len(sys.argv) > 1: + install_args = ' '.join(sys.argv) + + # If user specifically asks for triton or flex only (not all/dev), skip CUDA + has_triton_or_flex = '.[triton]' in install_args or '.[flex]' in install_args or '[triton]' in install_args or '[flex]' in install_args or 'triton,' in install_args or ',flex' in install_args + has_all_or_dev = '[all]' in install_args or '[dev]' in install_args + has_plain_install = install_args.endswith('flash_dmattn') or install_args.endswith('.') + + if has_triton_or_flex and not has_all_or_dev and not has_plain_install: + print("Detected Triton/Flex-only installation. Skipping CUDA compilation.") + print("Set FLASH_DMATTN_FORCE_BUILD=TRUE to force CUDA compilation.") + return True + + return False + +# Update SKIP_CUDA_BUILD based on auto-detection +SKIP_CUDA_BUILD = should_skip_cuda_build() + @functools.lru_cache(maxsize=None) def cuda_archs(): # return os.getenv("FLASH_DMATTN_CUDA_ARCHS", "80;90;100;120").split(";") @@ -289,9 +320,47 @@ def __init__(self, *args, **kwargs) -> None: "torch", "einops", ], + extras_require={ + # Individual backend options - choose one or more + "triton": [ + "triton>=2.0.0", + ], + "flex": [ + "transformers>=4.38.0", + ], + + # Combined options + "all": [ + "triton>=2.0.0", # Triton backend + "transformers>=4.38.0", # Flex backend + # CUDA backend included by default compilation + ], + + # Development dependencies + "dev": [ + "triton>=2.0.0", + "transformers>=4.38.0", + "pytest>=6.0", + "pytest-benchmark", + "numpy", + ], + + # Testing only + "test": [ + "pytest>=6.0", + "pytest-benchmark", + "numpy", + ], + }, setup_requires=[ "packaging", "psutil", "ninja", ], + # Include package data + package_data={ + "flash_dmattn": ["*.py"], + }, + # Ensure the package is properly included + include_package_data=True, ) \ No newline at end of file From dd322e7a895261d79f5020d0efe4f8984e3707ec Mon Sep 17 00:00:00 2001 From: Loser Cheems Date: Mon, 7 Jul 2025 18:59:05 +0800 Subject: [PATCH 2/5] Adds dynamic mask attention package with backend support Initializes the flash_dmattn package with automatic backend selection between CUDA, Triton, and Flex implementations. Provides graceful fallback mechanism that prioritizes CUDA for performance, then Triton and Flex as alternatives. Includes runtime availability checks and clear error messages for missing dependencies. Enables users to explicitly specify backends or rely on automatic selection based on available installations. --- flash_dmattn/__init__.py | 89 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 89 insertions(+) create mode 100644 flash_dmattn/__init__.py diff --git a/flash_dmattn/__init__.py b/flash_dmattn/__init__.py new file mode 100644 index 0000000..b0b3adc --- /dev/null +++ b/flash_dmattn/__init__.py @@ -0,0 +1,89 @@ +# Copyright (c) 2025, Jingze Shi. + +from typing import Optional + +try: + from .flash_dmattn_triton import triton_dmattn_func + TRITON_AVAILABLE = True +except ImportError: + TRITON_AVAILABLE = False + triton_dmattn_func = None + +try: + from .flash_dmattn_flex import flex_dmattn_func + FLEX_AVAILABLE = True +except ImportError: + FLEX_AVAILABLE = False + flex_dmattn_func = None + +# Check if CUDA extension is available +try: + import flash_dmattn_cuda # type: ignore[import] + CUDA_AVAILABLE = True +except ImportError: + CUDA_AVAILABLE = False + +__version__ = "0.1.0" + +__all__ = [ + "triton_dmattn_func", + "flex_dmattn_func", + "TRITON_AVAILABLE", + "FLEX_AVAILABLE", + "CUDA_AVAILABLE", +] + + +def get_available_backends(): + """Return a list of available backends.""" + backends = [] + if CUDA_AVAILABLE: + backends.append("cuda") + if TRITON_AVAILABLE: + backends.append("triton") + if FLEX_AVAILABLE: + backends.append("flex") + return backends + + +def flash_dmattn_func(backend: Optional[str] = None, **kwargs): + """ + Flash Dynamic Mask Attention function with automatic backend selection. + + Args: + backend (str, optional): Backend to use ('cuda', 'triton', 'flex'). + If None, will use the first available backend in order: cuda, triton, flex. + **kwargs: Arguments to pass to the attention function. + + Returns: + The attention function for the specified or auto-selected backend. + """ + if backend is None: + # Auto-select backend + if CUDA_AVAILABLE: + backend = "cuda" + elif TRITON_AVAILABLE: + backend = "triton" + elif FLEX_AVAILABLE: + backend = "flex" + else: + raise RuntimeError("No flash attention backend is available. Please install at least one of: triton, transformers, or build the CUDA extension.") + + if backend == "cuda": + if not CUDA_AVAILABLE: + raise RuntimeError("CUDA backend is not available. Please build the CUDA extension.") + # Import and return CUDA function + raise NotImplementedError("CUDA backend not yet implemented in this version") + + elif backend == "triton": + if not TRITON_AVAILABLE: + raise RuntimeError("Triton backend is not available. Please install triton: pip install triton") + return triton_dmattn_func + + elif backend == "flex": + if not FLEX_AVAILABLE: + raise RuntimeError("Flex backend is not available. Please install transformers: pip install transformers") + return flex_dmattn_func + + else: + raise ValueError(f"Unknown backend: {backend}. Available backends: {get_available_backends()}") From 5d0ba119f04703bd3afc30754918495a4638bb01 Mon Sep 17 00:00:00 2001 From: Loser Cheems Date: Mon, 7 Jul 2025 18:59:19 +0800 Subject: [PATCH 3/5] Standardizes function naming across implementations Renames imported functions to use consistent naming convention with 'dmattn_func' suffix across Triton and Flex Attention implementations. Updates function call parameters to use positional arguments instead of keyword arguments for cleaner code. Removes hard exit on CUDA import failure to allow graceful degradation when some implementations are unavailable. --- benchmarks/benchmark_forward_equivalence.py | 36 +++++++++++---------- 1 file changed, 19 insertions(+), 17 deletions(-) diff --git a/benchmarks/benchmark_forward_equivalence.py b/benchmarks/benchmark_forward_equivalence.py index cdca030..e724b12 100644 --- a/benchmarks/benchmark_forward_equivalence.py +++ b/benchmarks/benchmark_forward_equivalence.py @@ -26,27 +26,28 @@ except ImportError as e: print(f"❌ Failed to import flash_dmattn_cuda: {e}") print("Please make sure the package is properly installed with: pip install .") - exit(1) + # Don't exit here, just warn + flash_dmattn_cuda = None # Import the Triton implementation try: - from flash_dmattn.flash_dmattn_triton import flash_dmattn_func + from flash_dmattn.flash_dmattn_triton import triton_dmattn_func print("✅ Successfully imported flash_dmattn_triton") except ImportError as e: print(f"❌ Failed to import flash_dmattn_triton: {e}") print("Please make sure the Triton implementation is available.") # Don't exit here, just warn - flash_dmattn_func = None + triton_dmattn_func = None # Import the Flex Attention implementation try: - from flash_dmattn.flash_dmattn_flex import flex_attention_forward + from flash_dmattn.flash_dmattn_flex import flex_dmattn_func print("✅ Successfully imported flash_dmattn_flex") except ImportError as e: print(f"❌ Failed to import flash_dmattn_flex: {e}") print("Please make sure the Flex Attention implementation is available.") # Don't exit here, just warn - flex_attention_forward = None + flex_dmattn_func = None def prepare_dynamic_mask( @@ -301,7 +302,7 @@ def dynamic_mask_attention_triton( Returns: attn_outputs: [batch_size, query_len, num_heads, head_dim] """ - if flash_dmattn_func is None: + if triton_dmattn_func is None: raise RuntimeError("Triton implementation not available") _, num_heads, _, _ = query_states.shape @@ -333,14 +334,14 @@ def dynamic_mask_attention_triton( attn_bias = attn_bias.contiguous() # [batch, num_heads, seqlen_q, seqlen_k] # Call the Triton implementation - attn_outputs = flash_dmattn_func( + attn_outputs = triton_dmattn_func( query_states, # q: [batch, seqlen_q, num_heads, head_dim] key_states, # k: [batch, seqlen_k, num_heads, head_dim] value_states, # v: [batch, seqlen_k, num_heads, head_dim] - mask=attn_mask, # mask: [batch, num_heads, seqlen_q, seqlen_k] - bias=attn_bias, # bias: [batch, num_heads, seqlen_q, seqlen_k] - causal=is_causal, # causal masking - softmax_scale=scaling # scaling factor + attn_mask, # mask: [batch, num_heads, seqlen_q, seqlen_k] + attn_bias, # bias: [batch, num_heads, seqlen_q, seqlen_k] + is_causal, # causal masking + scaling # scaling factor ) return attn_outputs # [batch, query_len, num_heads, head_dim] @@ -374,7 +375,7 @@ def dynamic_mask_attention_flex( Returns: attn_outputs: [batch_size, query_len, num_heads, head_dim] """ - if flex_attention_forward is None: + if flex_dmattn_func is None: raise RuntimeError("Flex Attention implementation not available") _, num_heads, _, _ = query_states.shape @@ -402,12 +403,13 @@ def dynamic_mask_attention_flex( # But attention_mask and attention_bias in [batch, num_heads, query_len, key_len] format # Call the Flex Attention implementation - attn_outputs, _ = flex_attention_forward( + attn_outputs, _ = flex_dmattn_func( query_states, # q: [batch, num_heads, query_len, head_dim] key_states, # k: [batch, num_heads, key_len, head_dim] value_states, # v: [batch, num_heads, key_len, head_dim] attention_mask=attn_mask, # attention_mask: [batch, num_heads, query_len, key_len] attention_bias=attn_bias, # attention_bias: [batch, num_heads, query_len, key_len] + is_causal=is_causal, # is_causal: whether to apply causal masking scaling=scaling # scaling factor ) @@ -662,14 +664,14 @@ def test_triton_forward_equivalence(accuracy_threshold=0.95): print("🔬 Testing Forward Pass Equivalence: Python vs Triton 🔬") print("🔥" + "=" * 76 + "🔥") - if flash_dmattn_func is None: + if triton_dmattn_func is None: print("❌ Triton implementation not available, skipping Triton tests") return False # Set random seed for reproducibility torch.manual_seed(0) - - # Smaller test configurations for Triton (to avoid memory issues) + + # If you encounter NAN issues when running multiple configurations, try running a single configuration test_configs = [ # (batch_size, num_heads, num_kv_heads, query_len, key_len, head_dim, is_causal) (1, 1, 1, 64, 64, 32, True), @@ -833,7 +835,7 @@ def test_flex_forward_equivalence(accuracy_threshold=0.95): print("🔬 Testing Forward Pass Equivalence: Python vs Flex Attention 🔬") print("🌟" + "=" * 76 + "🌟") - if flex_attention_forward is None: + if flex_dmattn_func is None: print("❌ Flex Attention implementation not available, skipping Flex Attention tests") return False From bb0505f9985ac922b068e478907856e4b4006b04 Mon Sep 17 00:00:00 2001 From: Loser Cheems Date: Mon, 7 Jul 2025 20:10:29 +0800 Subject: [PATCH 4/5] Adds support for multiple attention implementations Expands benchmark to compare Flash Attention against CUDA, Triton, and Flex implementations of Dynamic Mask Attention. Introduces modular testing framework allowing selective benchmarking of specific implementations or head-to-head comparisons. Updates function signatures to return timing measurements directly from kernels for more accurate performance metrics. Renames variables for clarity (dt_states to zoh_states, active_mask to attn_mask) and adds comprehensive error handling for missing implementations. Includes new test configurations for window size variations and non-causal attention patterns. Provides detailed performance analysis with implementation-specific speedup calculations and overhead comparisons between different approaches. --- benchmarks/benchmark_forward_performance.py | 1008 +++++++++++++------ 1 file changed, 713 insertions(+), 295 deletions(-) diff --git a/benchmarks/benchmark_forward_performance.py b/benchmarks/benchmark_forward_performance.py index 18fe7ef..d9ed6f5 100644 --- a/benchmarks/benchmark_forward_performance.py +++ b/benchmarks/benchmark_forward_performance.py @@ -2,18 +2,26 @@ """ Performance Benchmark for Dynamic Mask Attention -This script measures and compares the performance of Dynamic Mask Attention -implementation against Flash Attention baseline across various configurations. +This script measures and compares the performance of multiple Dynamic Mask Attention +implementations against Flash Attention baseline across various configurations. + +Implementations tested: +- Flash Attention (PyTorch SDPA Flash Attention backend) - Baseline +- Dynamic Mask Attention CUDA - Custom CUDA kernel implementation +- Dynamic Mask Attention CUDA (No TopK) - CUDA kernel without TopK computation +- Dynamic Mask Attention Triton - Triton kernel implementation +- Dynamic Mask Attention Flex - Flex Attention implementation Benchmark includes: - Multiple sequence lengths and batch sizes - Head count and dimension variations - Throughput and latency measurements - Memory usage analysis -- Speedup comparisons +- Speedup comparisons across all implementations """ import torch +import torch.nn.backends import torch.nn.functional as F import numpy as np import argparse @@ -22,37 +30,73 @@ # Import the compiled CUDA extension try: - import flash_dma_cuda # type: ignore[import] - print("✅ Successfully imported flash_dma_cuda") + import flash_dmattn_cuda # type: ignore[import] + print("✅ Successfully imported flash_dmattn_cuda") except ImportError as e: - print(f"❌ Failed to import flash_dma_cuda: {e}") + print(f"❌ Failed to import flash_dmattn_cuda: {e}") print("Please make sure the package is properly installed with: pip install .") - exit(1) + # Don't exit here, just warn + flash_dmattn_cuda = None + +# Import the Triton implementation +try: + from flash_dmattn.flash_dmattn_triton import triton_dmattn_func + print("✅ Successfully imported flash_dmattn_triton") +except ImportError as e: + print(f"❌ Failed to import flash_dmattn_triton: {e}") + print("Please make sure the Triton implementation is available.") + # Don't exit here, just warn + triton_dmattn_func = None + +# Import the Flex Attention implementation +try: + from flash_dmattn.flash_dmattn_flex import flex_dmattn_func + print("✅ Successfully imported flash_dmattn_flex") +except ImportError as e: + print(f"❌ Failed to import flash_dmattn_flex: {e}") + print("Please make sure the Flex Attention implementation is available.") + # Don't exit here, just warn + flex_dmattn_func = None + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + Equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). + Transform from (batch, num_key_value_heads, seqlen, head_dim) + to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand( + batch, num_key_value_heads, n_rep, slen, head_dim + ) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) def prepare_dynamic_mask( hidden_states: torch.Tensor, - dt_states: torch.Tensor, + zoh_states: torch.Tensor, keep_window_size: int = 2048, attention_mask: torch.Tensor | None = None, ): """ Calculate dynamic attention mask to mask tokens for sparse attention. - Combine `dt_states` with `attention_mask` to generate the final `attn_mask`. + Combine `zoh_states` with `attention_mask` to generate the final `attn_mask`. Args: hidden_states: Input hidden states to determine dtype minimum value - dt_states: dt_states of shape (batch_size, num_kv_heads, key_sequence_length) + zoh_states: zoh_states of shape (batch_size, num_kv_heads, key_sequence_length) keep_window_size: Window size of tokens not dynamically masked attention_mask: Optional attention mask of shape (batch_size, 1, query_len, key_len) Returns: - tuple: (attn_mask, active_mask) + tuple: (attn_bias, attn_mask) """ min_dtype = torch.finfo(hidden_states.dtype).min dtype = hidden_states.dtype - attn_mask = dt_states[:, :, None, :].expand( + attn_bias = zoh_states[:, :, None, :].expand( -1, -1, hidden_states.shape[2], -1 ) # [batch_size, num_kv_heads, query_len, key_len] @@ -63,25 +107,25 @@ def prepare_dynamic_mask( torch.tensor(0.0, device=attention_mask.device, dtype=dtype), min_dtype ) - attn_mask = attn_mask.masked_fill( - attention_mask[:, :, :, : attn_mask.shape[-1]] != 0, min_dtype + attn_bias = attn_bias.masked_fill( + attention_mask[:, :, :, : attn_bias.shape[-1]] != 0, min_dtype ) - if attn_mask.shape[-1] > keep_window_size: + if attn_bias.shape[-1] > keep_window_size: topk_indices = torch.topk( - attn_mask, keep_window_size, dim=-1, largest=True, sorted=False + attn_bias, keep_window_size, dim=-1, largest=True, sorted=False ).indices - active_mask = torch.zeros_like(attn_mask, dtype=dtype, device=attn_mask.device) - active_mask = active_mask.scatter(-1, topk_indices, 1.0) - attn_mask = attn_mask.masked_fill(active_mask == 0.0, min_dtype) + attn_mask = torch.zeros_like(attn_bias, dtype=dtype, device=attn_bias.device) + attn_mask = attn_mask.scatter(-1, topk_indices, 1.0) + attn_bias = attn_bias.masked_fill(attn_mask == 0.0, min_dtype) else: - active_mask = torch.ones_like(attn_mask, dtype=dtype, device=attn_mask.device) - return attn_mask, active_mask + attn_mask = torch.ones_like(attn_bias, dtype=dtype, device=attn_bias.device) + return attn_bias, attn_mask -def calculate_zero_hold_states(value_states, dt_proj, A): +def calculate_zoh_states(value_states, dt_proj, A): """ - Calculate zero hold states for dynamic mask attention. + Calculate zoh states for dynamic mask attention. Args: value_states: [batch_size, num_kv_heads, key_len, head_dim] @@ -90,7 +134,7 @@ def calculate_zero_hold_states(value_states, dt_proj, A): causal_mask: Optional causal mask Returns: - zero_hold_states: [batch_size, num_kv_heads, key_len] + zoh_states: [batch_size, num_kv_heads, key_len] """ batch_size, _, key_len, _ = value_states.shape @@ -102,9 +146,9 @@ def calculate_zero_hold_states(value_states, dt_proj, A): # Apply softplus activation and coefficient A dt_states = torch.exp(F.softplus(dt_result) * A) - zero_hold_states = dt_states.transpose(-1, -2) # [batch_size, num_kv_heads, key_len] - - return zero_hold_states + zoh_states = dt_states.transpose(-1, -2) # [batch_size, num_kv_heads, key_len] + + return zoh_states def flash_attention_cuda( @@ -139,18 +183,28 @@ def flash_attention_cuda( value_states = value_states.contiguous() try: - attn_outputs = torch.nn.functional.scaled_dot_product_attention( - query_states, - key_states, - value_states, - attn_mask=causal_mask, - scale=scaling, - enable_gqa=True - ) + # Only measure the core attention computation + torch.cuda.synchronize() + start_time = time.time() + + with torch.nn.attention.sdpa_kernel(backends=[torch.nn.attention.SDPBackend.FLASH_ATTENTION]): + attn_outputs = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + # attn_mask=causal_mask, + scale=scaling, + is_causal=is_causal if query_len == key_len else False, + enable_gqa=True + ) + + torch.cuda.synchronize() + end_time = time.time() + attn_outputs = attn_outputs.transpose(1, 2).contiguous() # Transpose to [batch, query_len, num_heads, head_dim] - return attn_outputs + return attn_outputs, (end_time - start_time) * 1000 # Return output and time in ms except torch.cuda.OutOfMemoryError: - return "OOM" + return "OOM", 0 def dynamic_mask_attention_cuda( @@ -183,49 +237,60 @@ def dynamic_mask_attention_cuda( Returns: attn_outputs: [batch_size, query_len, num_heads, head_dim] """ - # Calculate zero_hold_states - zero_hold_states = calculate_zero_hold_states(value_states, dt_proj, A) - attn_mask, active_mask = prepare_dynamic_mask( + # Calculate zoh_states + zoh_states = calculate_zoh_states(value_states, dt_proj, A) + + # Use prepare_dynamic_mask to get the processed attention mask + attn_bias, attn_mask = prepare_dynamic_mask( query_states, - zero_hold_states, + zoh_states, keep_window_size, causal_mask if is_causal else None ) # [batch_size, num_kv_heads, query_len, key_len] # Ensure correct data types and memory layout for CUDA function # CUDA function expects: q, k, v in [batch, seqlen, num_heads, head_dim] format - query_states = query_states.transpose(1, 2).contiguous() # [batch, query_len, num_heads, head_dim] - key_states = key_states.transpose(1, 2).contiguous() # [batch, key_len, num_kv_heads, head_dim] - value_states = value_states.transpose(1, 2).contiguous() # [batch, key_len, num_kv_heads, head_dim] - zero_hold_states = zero_hold_states[:, :, None, :].expand( + query_states = query_states.transpose(1, 2).contiguous() # [batch, query_len, num_heads, head_dim] + key_states = key_states.transpose(1, 2).contiguous() # [batch, key_len, num_kv_heads, head_dim] + value_states = value_states.transpose(1, 2).contiguous() # [batch, key_len, num_kv_heads, head_dim] + zoh_states = zoh_states[:, :, None, :].expand( -1, -1, query_states.shape[1], -1 - ).contiguous() # [batch, num_kv_heads, query_len, key_len] - attn_mask = attn_mask.contiguous() # [batch, num_kv_heads, query_len, key_len] - active_mask = active_mask.contiguous() # [batch, num_kv_heads, query_len, key_len] + ).contiguous() # [batch, num_kv_heads, query_len, key_len] + attn_bias = attn_bias.contiguous() # [batch, num_kv_heads, query_len, key_len] + attn_mask = attn_mask.contiguous() # [batch, num_kv_heads, query_len, key_len] try: # Call the CUDA implementation using the mha_fwd function signature out_tensor = None # Let the function allocate the output tensor - result = flash_dma_cuda.fwd( # type: ignore - query_states, # q: [batch, seqlen_q, num_heads, head_dim] - key_states, # k: [batch, seqlen_k, num_kv_heads, head_dim] - value_states, # v: [batch, seqlen_k, num_kv_heads, head_dim] - zero_hold_states, # zoh: [batch, num_kv_heads, seqlen_q, seqlen_k] - processed attention mask - attn_mask, # active_mask: [batch, num_kv_heads, seqlen_q, seqlen_k] - out_tensor, # out: None to auto-allocate - 0.0, # p_dropout - scaling, # softmax_scale - is_causal, # is_causal - keep_window_size, # keep_window_size - 0.0, # softcap - return_softmax, # return_softmax - None # gen (generator) + + # Only measure the core CUDA kernel computation + torch.cuda.synchronize() + start_time = time.time() + + result = flash_dmattn_cuda.fwd( # type: ignore + query_states, # q: [batch, seqlen_q, num_heads, head_dim] + key_states, # k: [batch, seqlen_k, num_kv_heads, head_dim] + value_states, # v: [batch, seqlen_k, num_kv_heads, head_dim] + attn_mask, # attn_mask: [batch, num_kv_heads, seqlen_q, seqlen_k] + attn_bias, # attn_bias: [batch, num_kv_heads, seqlen_q, seqlen_k] + out_tensor, # out: None to auto-allocate + 0.0, # p_dropout + scaling, # softmax_scale + is_causal, # is_causal + keep_window_size, # keep_window_size + 0.0, # softcap + return_softmax, # return_softmax + None # gen (generator) ) + + torch.cuda.synchronize() + end_time = time.time() + attn_outputs = result[0] # [batch, query_len, num_heads, head_dim] - return attn_outputs + return attn_outputs, (end_time - start_time) * 1000 # Return output and time in ms except torch.cuda.OutOfMemoryError: - return "OOM" + return "OOM", 0 def dynamic_mask_attention_cuda_no_topk( @@ -259,8 +324,8 @@ def dynamic_mask_attention_cuda_no_topk( Returns: attn_outputs: [batch_size, query_len, num_heads, head_dim] """ - # Calculate zero_hold_states - zero_hold_states = calculate_zero_hold_states(value_states, dt_proj, A) + # Calculate zoh_states + zoh_states = calculate_zoh_states(value_states, dt_proj, A) # Create a simplified mask without topk computation batch_size, _, query_len, _ = query_states.shape @@ -268,8 +333,8 @@ def dynamic_mask_attention_cuda_no_topk( dtype = query_states.dtype device = query_states.device - # Create full active mask (no topk selection) - active_mask = torch.zeros( + # Create full attn mask (no topk selection) + attn_mask = torch.zeros( (batch_size, num_kv_heads, query_len, key_len), dtype=dtype, device=device @@ -279,38 +344,208 @@ def dynamic_mask_attention_cuda_no_topk( query_states = query_states.transpose(1, 2).contiguous() # [batch, query_len, num_heads, head_dim] key_states = key_states.transpose(1, 2).contiguous() # [batch, key_len, num_kv_heads, head_dim] value_states = value_states.transpose(1, 2).contiguous() # [batch, key_len, num_kv_heads, head_dim] - zero_hold_states = zero_hold_states[:, :, None, :].expand( + attn_bias = zoh_states[:, :, None, :].expand( -1, -1, query_states.shape[1], -1 ).contiguous() # [batch, num_kv_heads, query_len, key_len] # Create full active mask (no topk selection) - active_mask = torch.zeros_like( - zero_hold_states, + attn_mask = torch.zeros_like( + attn_bias, dtype=dtype, device=device - ) - active_mask = active_mask.contiguous() # [batch, num_kv_heads, query_len, key_len] + ).contiguous() # [batch, num_kv_heads, query_len, key_len] try: out_tensor = None # Let the function allocate the output tensor - result = flash_dma_cuda.fwd( # type: ignore - query_states, # q: [batch, seqlen_q, num_heads, head_dim] - key_states, # k: [batch, seqlen_k, num_kv_heads, head_dim] - value_states, # v: [batch, seqlen_k, num_kv_heads, head_dim] - zero_hold_states, # zoh: [batch, num_kv_heads, seqlen_q, seqlen_k] - processed attention mask - active_mask, # active_mask: [batch, num_kv_heads, seqlen_q, seqlen_k] - out_tensor, # out: None to auto-allocate - 0.0, # p_dropout - scaling, # softmax_scale - is_causal, # is_causal - keep_window_size, # keep_window_size - 0.0, # softcap - return_softmax, # return_softmax - None # gen (generator) + + # Only measure the core CUDA kernel computation + torch.cuda.synchronize() + start_time = time.time() + + result = flash_dmattn_cuda.fwd( # type: ignore + query_states, # q: [batch, seqlen_q, num_heads, head_dim] + key_states, # k: [batch, seqlen_k, num_kv_heads, head_dim] + value_states, # v: [batch, seqlen_k, num_kv_heads, head_dim] + attn_mask, # attn_mask: [batch, num_kv_heads, seqlen_q, seqlen_k] + attn_bias, # attn_bias: [batch, num_kv_heads, seqlen_q, seqlen_k] + out_tensor, # out: None to auto-allocate + 0.0, # p_dropout + scaling, # softmax_scale + is_causal, # is_causal + keep_window_size, # keep_window_size + 0.0, # softcap + return_softmax, # return_softmax + None # gen (generator) ) + + torch.cuda.synchronize() + end_time = time.time() + attn_outputs = result[0] - return attn_outputs + return attn_outputs, (end_time - start_time) * 1000 # Return output and time in ms except torch.cuda.OutOfMemoryError: - return "OOM" + return "OOM", 0 + + +def dynamic_mask_attention_triton( + query_states: torch.Tensor, + key_states: torch.Tensor, + value_states: torch.Tensor, + dt_proj: torch.Tensor, + A: torch.Tensor, + scaling: float, + causal_mask: torch.Tensor, + keep_window_size=2048, + is_causal=True, +): + """ + Triton implementation of dynamic mask attention. + + Args: + query_states: [batch_size, num_heads, query_len, head_dim] + key_states: [batch_size, num_kv_heads, key_len, head_dim] + value_states: [batch_size, num_kv_heads, key_len, head_dim] + dt_proj: [num_kv_heads, num_kv_heads * head_dim] + A: [num_kv_heads] + scaling: Attention scaling factor + causal_mask: Causal attention mask + keep_window_size: Number of tokens to keep in attention window + is_causal: Whether to apply causal masking + + Returns: + attn_outputs: [batch_size, query_len, num_heads, head_dim] + """ + if triton_dmattn_func is None: + return "Not Available", 0 + + _, num_heads, _, _ = query_states.shape + _, num_kv_heads, _, _ = key_states.shape + num_queries_per_kv = num_heads // num_kv_heads + + try: + # Calculate zoh_states + zoh_states = calculate_zoh_states(value_states, dt_proj, A) + + # Use prepare_dynamic_mask to get the processed attention mask + attn_bias, attn_mask = prepare_dynamic_mask( + query_states, + zoh_states, + keep_window_size, + causal_mask if is_causal else None + ) # [batch_size, num_kv_heads, query_len, key_len] + + # Repeat KV for multi-head attention (GQA support) + key_states = repeat_kv(key_states, num_queries_per_kv) + value_states = repeat_kv(value_states, num_queries_per_kv) + attn_mask = repeat_kv(attn_mask, num_queries_per_kv) + attn_bias = repeat_kv(attn_bias, num_queries_per_kv) + + # Triton function expects: q, k, v in [batch, seqlen, num_heads, head_dim] format + query_states = query_states.transpose(1, 2).contiguous() # [batch, query_len, num_heads, head_dim] + key_states = key_states.transpose(1, 2).contiguous() # [batch, key_len, num_heads, head_dim] + value_states = value_states.transpose(1, 2).contiguous() # [batch, key_len, num_heads, head_dim] + attn_mask = attn_mask.contiguous() # [batch, num_heads, seqlen_q, seqlen_k] + attn_bias = attn_bias.contiguous() # [batch, num_heads, seqlen_q, seqlen_k] + + # Only measure the core Triton kernel computation + torch.cuda.synchronize() + start_time = time.time() + + # Call the Triton implementation + attn_outputs = triton_dmattn_func( + query_states, # q: [batch, seqlen_q, num_heads, head_dim] + key_states, # k: [batch, seqlen_k, num_heads, head_dim] + value_states, # v: [batch, seqlen_k, num_heads, head_dim] + attn_mask, # mask: [batch, num_heads, seqlen_q, seqlen_k] + attn_bias, # bias: [batch, num_heads, seqlen_q, seqlen_k] + is_causal, # causal masking + scaling # scaling factor + ) + + torch.cuda.synchronize() + end_time = time.time() + + return attn_outputs, (end_time - start_time) * 1000 # Return output and time in ms + except torch.cuda.OutOfMemoryError: + return "OOM", 0 + + +def dynamic_mask_attention_flex( + query_states: torch.Tensor, + key_states: torch.Tensor, + value_states: torch.Tensor, + dt_proj: torch.Tensor, + A: torch.Tensor, + scaling: float, + causal_mask: torch.Tensor, + keep_window_size=2048, + is_causal=True, +): + """ + Flex Attention implementation of dynamic mask attention. + + Args: + query_states: [batch_size, num_heads, query_len, head_dim] + key_states: [batch_size, num_kv_heads, key_len, head_dim] + value_states: [batch_size, num_kv_heads, key_len, head_dim] + dt_proj: [num_kv_heads, num_kv_heads * head_dim] + A: [num_kv_heads] + scaling: Attention scaling factor + causal_mask: Causal attention mask + keep_window_size: Number of tokens to keep in attention window + is_causal: Whether to apply causal masking + + Returns: + attn_outputs: [batch_size, query_len, num_heads, head_dim] + """ + if flex_dmattn_func is None: + return "Not Available", 0 + + _, num_heads, _, _ = query_states.shape + _, num_kv_heads, _, _ = key_states.shape + num_queries_per_kv = num_heads // num_kv_heads + + try: + # Calculate zoh_states + zoh_states = calculate_zoh_states(value_states, dt_proj, A) + + # Use prepare_dynamic_mask to get the processed attention mask + attn_bias, attn_mask = prepare_dynamic_mask( + query_states, + zoh_states, + keep_window_size, + causal_mask if is_causal else None + ) # [batch_size, num_kv_heads, query_len, key_len] + + # Repeat KV for multi-head attention (GQA support) + key_states = repeat_kv(key_states, num_queries_per_kv) + value_states = repeat_kv(value_states, num_queries_per_kv) + attn_mask = repeat_kv(attn_mask, num_queries_per_kv) + attn_bias = repeat_kv(attn_bias, num_queries_per_kv) + + # Flex attention expects: q, k, v in [batch, num_heads, seqlen, head_dim] format + # But attention_mask and attention_bias in [batch, num_heads, query_len, key_len] format + + # Only measure the core Flex Attention computation + torch.cuda.synchronize() + start_time = time.time() + + # Call the Flex Attention implementation + attn_outputs, _ = flex_dmattn_func( + query_states, # q: [batch, num_heads, query_len, head_dim] + key_states, # k: [batch, num_heads, key_len, head_dim] + value_states, # v: [batch, num_heads, key_len, head_dim] + attention_mask=attn_mask, # attention_mask: [batch, num_heads, query_len, key_len] + attention_bias=attn_bias, # attention_bias: [batch, num_heads, query_len, key_len] + is_causal=is_causal, # is_causal: Whether to apply causal masking + scaling=scaling # scaling factor + ) + + torch.cuda.synchronize() + end_time = time.time() + + return attn_outputs, (end_time - start_time) * 1000 # Return output and time in ms + except torch.cuda.OutOfMemoryError: + return "OOM", 0 def measure_memory_usage(): @@ -327,19 +562,19 @@ def measure_memory_usage(): return 0, 0 -def benchmark_attention_performance(config, num_runs=5, warmup_runs=2): +def benchmark_attention_performance(config, test_type='all', num_runs=5, warmup_runs=2): """ Benchmark attention performance for a given configuration. Args: - config: Tuple of (batch_size, num_heads, num_kv_heads, query_len, key_len, head_dim) + config: Tuple of (batch_size, num_heads, num_kv_heads, query_len, key_len, head_dim, keep_window_size, is_causal) num_runs: Number of benchmark runs warmup_runs: Number of warmup runs Returns: dict: Performance metrics """ - batch_size, num_heads, num_kv_heads, query_len, key_len, head_dim = config + batch_size, num_heads, num_kv_heads, query_len, key_len, head_dim, keep_window_size, is_causal = config device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # Create random input data @@ -372,301 +607,477 @@ def benchmark_attention_performance(config, num_runs=5, warmup_runs=2): causal_mask *= torch.arange(key_len, device=device) > cache_position.reshape(-1, 1) causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) - # Set scaling factor and keep window size + # Set scaling factor from config scaling = head_dim ** -0.5 - keep_window_size = 2048 - is_causal = True results = { 'config': config, 'flash_attention_times': [], 'dynamic_mask_attention_times': [], 'dynamic_mask_attention_no_topk_times': [], + 'dynamic_mask_attention_triton_times': [], + 'dynamic_mask_attention_flex_times': [], 'flash_attention_memory': 0, 'dynamic_mask_attention_memory': 0, 'dynamic_mask_attention_no_topk_memory': 0, + 'dynamic_mask_attention_triton_memory': 0, + 'dynamic_mask_attention_flex_memory': 0, 'flash_attention_status': 'success', 'dynamic_mask_attention_status': 'success', - 'dynamic_mask_attention_no_topk_status': 'success' + 'dynamic_mask_attention_no_topk_status': 'success', + 'dynamic_mask_attention_triton_status': 'success', + 'dynamic_mask_attention_flex_status': 'success' } - # Benchmark Flash Attention - gc.collect() - torch.cuda.empty_cache() - - # Warmup runs - for _ in range(warmup_runs): - result = flash_attention_cuda( - query_states, key_states, value_states, - scaling, causal_mask, is_causal - ) - if result == "OOM": - results['flash_attention_status'] = 'OOM' - break - torch.cuda.synchronize() + # Determine which implementations to run + run_flash = test_type in ['all', 'flash', 'flash-vs-cuda', 'flash-vs-triton', 'flash-vs-flex'] + run_cuda = test_type in ['all', 'cuda', 'flash-vs-cuda'] + run_no_topk = test_type in ['all', 'cuda'] + run_triton = test_type in ['all', 'triton', 'flash-vs-triton'] + run_flex = test_type in ['all', 'flex', 'flash-vs-flex'] - if results['flash_attention_status'] == 'success': - # Measure memory before benchmark - mem_before = measure_memory_usage() + # Benchmark Flash Attention + if run_flash: + gc.collect() + torch.cuda.empty_cache() - # Actual benchmark runs - for _ in range(num_runs): - start_time = time.time() + # Warmup runs + for _ in range(warmup_runs): result = flash_attention_cuda( query_states, key_states, value_states, scaling, causal_mask, is_causal ) - torch.cuda.synchronize() - end_time = time.time() - - if result == "OOM": + if result[0] == "OOM": results['flash_attention_status'] = 'OOM' break - - results['flash_attention_times'].append((end_time - start_time) * 1000) # ms + torch.cuda.synchronize() - # Measure memory after - mem_after = measure_memory_usage() - results['flash_attention_memory'] = mem_after[0] - mem_before[0] + if results['flash_attention_status'] == 'success': + # Measure memory before benchmark + mem_before = measure_memory_usage() + + # Actual benchmark runs + for _ in range(num_runs): + result = flash_attention_cuda( + query_states, key_states, value_states, + scaling, causal_mask, is_causal + ) + + if result[0] == "OOM": + results['flash_attention_status'] = 'OOM' + break + + # Use the timing from the function instead of measuring here + results['flash_attention_times'].append(result[1]) # ms + + # Measure memory after + mem_after = measure_memory_usage() + results['flash_attention_memory'] = mem_after[0] - mem_before[0] + else: + results['flash_attention_status'] = 'N/A' # Benchmark Dynamic Mask Attention - gc.collect() - torch.cuda.empty_cache() - - # Warmup runs - for _ in range(warmup_runs): - result = dynamic_mask_attention_cuda( - query_states, key_states, value_states, - dt_proj, A, scaling, causal_mask, - keep_window_size, is_causal - ) - if result == "OOM": - results['dynamic_mask_attention_status'] = 'OOM' - break - torch.cuda.synchronize() - - if results['dynamic_mask_attention_status'] == 'success': - # Measure memory before benchmark - mem_before = measure_memory_usage() + if run_cuda: + gc.collect() + torch.cuda.empty_cache() - # Actual benchmark runs - for _ in range(num_runs): - start_time = time.time() + # Warmup runs + for _ in range(warmup_runs): result = dynamic_mask_attention_cuda( query_states, key_states, value_states, dt_proj, A, scaling, causal_mask, keep_window_size, is_causal ) - torch.cuda.synchronize() - end_time = time.time() - - if result == "OOM": + if result[0] == "OOM": results['dynamic_mask_attention_status'] = 'OOM' break - - results['dynamic_mask_attention_times'].append((end_time - start_time) * 1000) # ms + torch.cuda.synchronize() - # Measure memory after - mem_after = measure_memory_usage() - results['dynamic_mask_attention_memory'] = mem_after[0] - mem_before[0] + if results['dynamic_mask_attention_status'] == 'success': + # Measure memory before benchmark + mem_before = measure_memory_usage() + + # Actual benchmark runs + for _ in range(num_runs): + result = dynamic_mask_attention_cuda( + query_states, key_states, value_states, + dt_proj, A, scaling, causal_mask, + keep_window_size, is_causal + ) + + if result[0] == "OOM": + results['dynamic_mask_attention_status'] = 'OOM' + break + + # Use the timing from the function instead of measuring here + results['dynamic_mask_attention_times'].append(result[1]) # ms + + # Measure memory after + mem_after = measure_memory_usage() + results['dynamic_mask_attention_memory'] = mem_after[0] - mem_before[0] + else: + results['dynamic_mask_attention_status'] = 'N/A' # Benchmark Dynamic Mask Attention (No TopK) - gc.collect() - torch.cuda.empty_cache() - - # Warmup runs - for _ in range(warmup_runs): - result = dynamic_mask_attention_cuda_no_topk( - query_states, key_states, value_states, - dt_proj, A, scaling, causal_mask, - keep_window_size, is_causal - ) - if result == "OOM": - results['dynamic_mask_attention_no_topk_status'] = 'OOM' - break - torch.cuda.synchronize() - - if results['dynamic_mask_attention_no_topk_status'] == 'success': - # Measure memory before benchmark - mem_before = measure_memory_usage() + if run_no_topk: + gc.collect() + torch.cuda.empty_cache() - # Actual benchmark runs - for _ in range(num_runs): - start_time = time.time() + # Warmup runs + for _ in range(warmup_runs): result = dynamic_mask_attention_cuda_no_topk( query_states, key_states, value_states, dt_proj, A, scaling, causal_mask, keep_window_size, is_causal ) + if result[0] == "OOM": + results['dynamic_mask_attention_no_topk_status'] = 'OOM' + break torch.cuda.synchronize() - end_time = time.time() + + if results['dynamic_mask_attention_no_topk_status'] == 'success': + # Measure memory before benchmark + mem_before = measure_memory_usage() - if result == "OOM": - results['dynamic_mask_attention_no_topk_status'] = 'OOM' + # Actual benchmark runs + for _ in range(num_runs): + result = dynamic_mask_attention_cuda_no_topk( + query_states, key_states, value_states, + dt_proj, A, scaling, causal_mask, + keep_window_size, is_causal + ) + + if result[0] == "OOM": + results['dynamic_mask_attention_no_topk_status'] = 'OOM' + break + + # Use the timing from the function instead of measuring here + results['dynamic_mask_attention_no_topk_times'].append(result[1]) # ms + + # Measure memory after + mem_after = measure_memory_usage() + results['dynamic_mask_attention_no_topk_memory'] = mem_after[0] - mem_before[0] + else: + results['dynamic_mask_attention_no_topk_status'] = 'N/A' + + # Benchmark Dynamic Mask Attention (Triton) + if run_triton: + gc.collect() + torch.cuda.empty_cache() + + # Warmup runs + for _ in range(warmup_runs): + result = dynamic_mask_attention_triton( + query_states, key_states, value_states, + dt_proj, A, scaling, causal_mask, + keep_window_size, is_causal + ) + if result[0] in ["OOM", "Not Available"]: + results['dynamic_mask_attention_triton_status'] = result[0] break + torch.cuda.synchronize() + + if results['dynamic_mask_attention_triton_status'] == 'success': + # Measure memory before benchmark + mem_before = measure_memory_usage() + + # Actual benchmark runs + for _ in range(num_runs): + result = dynamic_mask_attention_triton( + query_states, key_states, value_states, + dt_proj, A, scaling, causal_mask, + keep_window_size, is_causal + ) + + if result[0] in ["OOM", "Not Available"]: + results['dynamic_mask_attention_triton_status'] = result[0] + break + + # Use the timing from the function instead of measuring here + results['dynamic_mask_attention_triton_times'].append(result[1]) # ms - results['dynamic_mask_attention_no_topk_times'].append((end_time - start_time) * 1000) # ms + # Measure memory after + mem_after = measure_memory_usage() + results['dynamic_mask_attention_triton_memory'] = mem_after[0] - mem_before[0] + else: + results['dynamic_mask_attention_triton_status'] = 'N/A' + + # Benchmark Dynamic Mask Attention (Flex) + if run_flex: + gc.collect() + torch.cuda.empty_cache() - # Measure memory after - mem_after = measure_memory_usage() - results['dynamic_mask_attention_no_topk_memory'] = mem_after[0] - mem_before[0] + # Warmup runs + for _ in range(warmup_runs): + result = dynamic_mask_attention_flex( + query_states, key_states, value_states, + dt_proj, A, scaling, causal_mask, + keep_window_size, is_causal + ) + if result[0] in ["OOM", "Not Available"]: + results['dynamic_mask_attention_flex_status'] = result[0] + break + torch.cuda.synchronize() + + if results['dynamic_mask_attention_flex_status'] == 'success': + # Measure memory before benchmark + mem_before = measure_memory_usage() + + # Actual benchmark runs + for _ in range(num_runs): + result = dynamic_mask_attention_flex( + query_states, key_states, value_states, + dt_proj, A, scaling, causal_mask, + keep_window_size, is_causal + ) + + if result[0] in ["OOM", "Not Available"]: + results['dynamic_mask_attention_flex_status'] = result[0] + break + + # Use the timing from the function instead of measuring here + results['dynamic_mask_attention_flex_times'].append(result[1]) # ms + + # Measure memory after + mem_after = measure_memory_usage() + results['dynamic_mask_attention_flex_memory'] = mem_after[0] - mem_before[0] + else: + results['dynamic_mask_attention_flex_status'] = 'N/A' return results -def run_performance_benchmark(): +def run_performance_benchmark(test_type='all', num_runs=3, warmup_runs=2): """Run comprehensive performance benchmark across different configurations.""" print("\n" + "🏆" + "=" * 76 + "🏆") - print("⚡ Performance Benchmark: Dynamic Mask Attention vs Flash Attention ⚡") + + # Update title based on test type + if test_type == 'all': + title = "⚡ Performance Benchmark: Flash vs CUDA vs Triton vs Flex ⚡" + elif test_type == 'flash-vs-cuda': + title = "⚡ Performance Benchmark: Flash Attention vs CUDA ⚡" + elif test_type == 'flash-vs-triton': + title = "⚡ Performance Benchmark: Flash Attention vs Triton ⚡" + elif test_type == 'flash-vs-flex': + title = "⚡ Performance Benchmark: Flash Attention vs Flex ⚡" + elif test_type == 'flash': + title = "⚡ Performance Benchmark: Flash Attention Only ⚡" + elif test_type == 'cuda': + title = "⚡ Performance Benchmark: CUDA Implementations ⚡" + elif test_type == 'triton': + title = "⚡ Performance Benchmark: Triton Implementation ⚡" + elif test_type == 'flex': + title = "⚡ Performance Benchmark: Flex Implementation ⚡" + else: + title = "⚡ Performance Benchmark ⚡" + + print(title) print("🏆" + "=" * 76 + "🏆") - # Test configurations: (batch_size, num_heads, num_kv_heads, query_len, key_len, head_dim) + # Test configurations: (batch_size, num_heads, num_kv_heads, query_len, key_len, head_dim, keep_window_size, is_causal) configs = [ # Vary sequence length - (1, 2, 1, 256, 256, 32), - (1, 2, 1, 512, 512, 32), - (1, 2, 1, 1024, 1024, 32), - (1, 2, 1, 2048, 2048, 32), - (1, 2, 1, 4096, 4096, 32), - (1, 2, 1, 8192, 8192, 32), - (1, 2, 1, 16384, 16384, 32), - (1, 2, 1, 32768, 32768, 32), + (1, 2, 1, 256, 256, 32, 2048, True), + (1, 2, 1, 512, 512, 32, 2048, True), + (1, 2, 1, 1024, 1024, 32, 2048, True), + (1, 2, 1, 2048, 2048, 32, 2048, True), + (1, 2, 1, 4096, 4096, 32, 2048, True), + (1, 2, 1, 8192, 8192, 32, 2048, True), + (1, 2, 1, 16384, 16384, 32, 2048, True), + (1, 2, 1, 32768, 32768, 32, 2048, True), # Inference - (1, 2, 1, 64, 256, 128), - (1, 2, 1, 64, 512, 128), - (1, 2, 1, 64, 1024, 128), - (1, 2, 1, 64, 2048, 128), - (1, 2, 1, 64, 4096, 128), - (1, 2, 1, 64, 8192, 128), - (1, 2, 1, 64, 16384, 128), - (1, 2, 1, 64, 32768, 128), - (1, 2, 1, 64, 65536, 128), - (1, 2, 1, 64, 131072, 128), - (1, 2, 1, 64, 262144, 128), - (1, 2, 1, 64, 524288, 128), + (1, 2, 1, 2, 256, 128, 2048, True), + (1, 2, 1, 2, 512, 128, 2048, True), + (1, 2, 1, 2, 1024, 128, 2048, True), + (1, 2, 1, 2, 2048, 128, 2048, True), + (1, 2, 1, 2, 4096, 128, 2048, True), + (1, 2, 1, 2, 8192, 128, 2048, True), + (1, 2, 1, 2, 16384, 128, 2048, True), + (1, 2, 1, 2, 32768, 128, 2048, True), + (1, 2, 1, 2, 65536, 128, 2048, True), + (1, 2, 1, 2, 131072, 128, 2048, True), + (1, 2, 1, 2, 262144, 128, 2048, True), + (1, 2, 1, 2, 524288, 128, 2048, True), # Vary batch size - (1, 2, 1, 1024, 1024, 32), - (2, 2, 1, 1024, 1024, 32), - (4, 2, 1, 1024, 1024, 32), - (8, 2, 1, 1024, 1024, 32), + (1, 2, 1, 1024, 1024, 32, 2048, True), + (2, 2, 1, 1024, 1024, 32, 2048, True), + (4, 2, 1, 1024, 1024, 32, 2048, True), + (8, 2, 1, 1024, 1024, 32, 2048, True), # Vary head count - (1, 1, 1, 1024, 1024, 32), - (1, 2, 1, 1024, 1024, 32), - (1, 4, 1, 1024, 1024, 32), - (1, 8, 2, 1024, 1024, 32), + (1, 1, 1, 1024, 1024, 32, 2048, True), + (1, 2, 1, 1024, 1024, 32, 2048, True), + (1, 4, 1, 1024, 1024, 32, 2048, True), + (1, 8, 2, 1024, 1024, 32, 2048, True), # Vary head dimension - (1, 2, 1, 1024, 1024, 32), - (1, 2, 1, 1024, 1024, 64), - (1, 2, 1, 1024, 1024, 96), - (1, 2, 1, 1024, 1024, 128), - (1, 2, 1, 1024, 1024, 192), - (1, 2, 1, 1024, 1024, 256), + (1, 2, 1, 1024, 1024, 32, 2048, True), + (1, 2, 1, 1024, 1024, 64, 2048, True), + (1, 2, 1, 1024, 1024, 96, 2048, True), + (1, 2, 1, 1024, 1024, 128, 2048, True), + (1, 2, 1, 1024, 1024, 192, 2048, True), + (1, 2, 1, 1024, 1024, 256, 2048, True), + + # Vary keep_window_size + (1, 2, 1, 32768, 32768, 128, 32, True), + (1, 2, 1, 32768, 32768, 128, 64, True), + (1, 2, 1, 32768, 32768, 128, 128, True), + (1, 2, 1, 32768, 32768, 128, 256, True), + (1, 2, 1, 32768, 32768, 128, 512, True), + (1, 2, 1, 32768, 32768, 128, 1024, True), + (1, 2, 1, 32768, 32768, 128, 2048, True), + (1, 2, 1, 32768, 32768, 128, 4096, True), + (1, 2, 1, 32768, 32768, 128, 8192, True), + (1, 2, 1, 32768, 32768, 128, 16384, True), + (1, 2, 1, 32768, 32768, 128, 32768, True), + + # Test non-causal + (1, 2, 1, 1024, 1024, 128, 2048, False), ] num_runs = 3 # Run 3 times and take average print(f"\n📊 Benchmark Results (averaged over {num_runs} runs):") - print(f"🔧 {'Configuration':<42} ⚡ {'Flash (ms)':<12} 🚀 {'DMA (ms)':<12} 🚀 {'DMA-Skip-All (ms)':<22} 📈 {'Speedup':<12} 📈 {'Skip-All-Speedup':<20} 💾 {'Memory':<10}") - print("🔄" + "-" * 155 + "🔄") + print(f"🔧 {'Configuration':<60} ⚡ {'Flash':<10} 🚀 {'CUDA':<10} 🚀 {'No-TopK':<10} 🌟 {'Triton':<10} 🌟 {'Flex':<15} 📈 {'Speedup':<15}") + print("🔄" + "-" * 160 + "🔄") all_results = [] for config in configs: - batch_size, num_heads, num_kv_heads, query_len, key_len, head_dim = config + batch_size, num_heads, num_kv_heads, query_len, key_len, head_dim, keep_window_size, is_causal = config - results = benchmark_attention_performance(config, num_runs=num_runs) + results = benchmark_attention_performance(config, test_type, num_runs, warmup_runs) all_results.append(results) - # Calculate averages - if results['flash_attention_status'] == 'success' and results['flash_attention_times']: - flash_avg = sum(results['flash_attention_times']) / len(results['flash_attention_times']) - flash_time_str = f"{flash_avg:.2f}" - else: - flash_time_str = results['flash_attention_status'] - flash_avg = float('inf') + # Calculate averages for all implementations + implementations = { + 'flash': ('flash_attention', results['flash_attention_status'], results['flash_attention_times']), + 'cuda': ('dynamic_mask_attention', results['dynamic_mask_attention_status'], results['dynamic_mask_attention_times']), + 'no_topk': ('dynamic_mask_attention_no_topk', results['dynamic_mask_attention_no_topk_status'], results['dynamic_mask_attention_no_topk_times']), + 'triton': ('dynamic_mask_attention_triton', results['dynamic_mask_attention_triton_status'], results['dynamic_mask_attention_triton_times']), + 'flex': ('dynamic_mask_attention_flex', results['dynamic_mask_attention_flex_status'], results['dynamic_mask_attention_flex_times']) + } - if results['dynamic_mask_attention_status'] == 'success' and results['dynamic_mask_attention_times']: - dma_avg = sum(results['dynamic_mask_attention_times']) / len(results['dynamic_mask_attention_times']) - dma_time_str = f"{dma_avg:.2f}" - else: - dma_time_str = results['dynamic_mask_attention_status'] - dma_avg = float('inf') - - if results['dynamic_mask_attention_no_topk_status'] == 'success' and results['dynamic_mask_attention_no_topk_times']: - dma_nt_avg = sum(results['dynamic_mask_attention_no_topk_times']) / len(results['dynamic_mask_attention_no_topk_times']) - dma_nt_time_str = f"{dma_nt_avg:.2f}" - else: - dma_nt_time_str = results['dynamic_mask_attention_no_topk_status'] - dma_nt_avg = float('inf') + # Calculate time strings and averages + time_strs = {} + time_avgs = {} - # Calculate speedups - if flash_avg != float('inf') and dma_avg != float('inf') and dma_avg > 0: - speedup = flash_avg / dma_avg - speedup_str = f"{speedup:.2f}x" - else: - speedup_str = "N/A" - - if flash_avg != float('inf') and dma_nt_avg != float('inf') and dma_nt_avg > 0: - kernel_speedup = flash_avg / dma_nt_avg - kernel_speedup_str = f"{kernel_speedup:.2f}x" - else: - kernel_speedup_str = "N/A" + for impl_key, (_, status, times) in implementations.items(): + if status == 'success' and times: + avg_time = sum(times) / len(times) + time_strs[impl_key] = f"{avg_time:.2f}" + time_avgs[impl_key] = avg_time + else: + time_strs[impl_key] = status[:8] # Truncate status for display + time_avgs[impl_key] = float('inf') - # Memory usage - mem_diff = results['dynamic_mask_attention_memory'] - results['flash_attention_memory'] - mem_str = f"{mem_diff:+.0f}" + # Calculate speedups (compared to Flash Attention baseline) + speedup_strs = {} + flash_avg = time_avgs.get('flash', float('inf')) - # Format output - config_short = f"b={batch_size},h={num_heads},kv={num_kv_heads},q={query_len},k={key_len},d={head_dim}" + for impl_key in ['cuda', 'no_topk', 'triton', 'flex']: + impl_avg = time_avgs.get(impl_key, float('inf')) + if flash_avg != float('inf') and impl_avg != float('inf') and impl_avg > 0: + speedup = flash_avg / impl_avg + speedup_strs[impl_key] = f"{speedup:.2f}x" + else: + speedup_strs[impl_key] = "N/A" + + # Format output with shorter config string + config_short = f" b{batch_size} h{num_heads} kv{num_kv_heads} q{query_len} k{key_len} d{head_dim} w{keep_window_size} " + if not is_causal: + config_short += "nc" # Add status icons - flash_icon = "✅" if results['flash_attention_status'] == 'success' else "💥" - dma_icon = "✅" if results['dynamic_mask_attention_status'] == 'success' else "💥" - dma_nt_icon = "✅" if results['dynamic_mask_attention_no_topk_status'] == 'success' else "💥" + icons = "" + for impl_key, (_, status, _) in implementations.items(): + if status == 'success': + icons += " ✅ " + elif status in ['OOM', 'Not Available']: + icons += " ❌ " + else: + icons += " ⚠️ " + + # Create speedup summary (best performing implementation) + best_speedup = "N/A" + best_impl = "N/A" + for impl_key, speedup_str in speedup_strs.items(): + if speedup_str != "N/A": + try: + speedup_val = float(speedup_str.replace('x', '')) + if best_speedup == "N/A" or speedup_val > float(best_speedup.replace('x', '')): + best_speedup = speedup_str + best_impl = impl_key.upper() + except: + continue + + speedup_summary = f"{best_impl}:{best_speedup}" if best_speedup != "N/A" else "N/A" - print(f"{flash_icon}{dma_icon}{dma_nt_icon} {config_short:<42} {flash_time_str:<14} {dma_time_str:<20} {dma_nt_time_str:<20} {speedup_str:<18} {kernel_speedup_str:<20} {mem_str:<12}") + print(f"{icons} {config_short:<48} {time_strs['flash']:<12} {time_strs['cuda']:<12} {time_strs['no_topk']:<14} {time_strs['triton']:<12} {time_strs['flex']:<18} {speedup_summary:<15}") - print("🔄" + "-" * 155 + "🔄") + print("🔄" + "-" * 160 + "🔄") # Summary statistics - speedups = [] - kernel_speedups = [] + implementation_speedups = { + 'cuda': [], + 'no_topk': [], + 'triton': [], + 'flex': [] + } + for results in all_results: - if (results['flash_attention_status'] == 'success' and - results['dynamic_mask_attention_status'] == 'success' and - results['flash_attention_times'] and results['dynamic_mask_attention_times']): - - flash_avg = sum(results['flash_attention_times']) / len(results['flash_attention_times']) - dma_avg = sum(results['dynamic_mask_attention_times']) / len(results['dynamic_mask_attention_times']) - - if dma_avg > 0: - speedups.append(flash_avg / dma_avg) - - if (results['flash_attention_status'] == 'success' and - results['dynamic_mask_attention_no_topk_status'] == 'success' and - results['flash_attention_times'] and results['dynamic_mask_attention_no_topk_times']): - + if results['flash_attention_status'] == 'success' and results['flash_attention_times']: flash_avg = sum(results['flash_attention_times']) / len(results['flash_attention_times']) - dma_nt_avg = sum(results['dynamic_mask_attention_no_topk_times']) / len(results['dynamic_mask_attention_no_topk_times']) - if dma_nt_avg > 0: - kernel_speedups.append(flash_avg / dma_nt_avg) + # Calculate speedups for each implementation + for impl_key in implementation_speedups.keys(): + # Map implementation keys to actual result keys + if impl_key == 'cuda': + status_key = 'dynamic_mask_attention_status' + times_key = 'dynamic_mask_attention_times' + else: + status_key = f'dynamic_mask_attention_{impl_key}_status' + times_key = f'dynamic_mask_attention_{impl_key}_times' + + if (status_key in results and results[status_key] == 'success' and + times_key in results and results[times_key]): + + impl_avg = sum(results[times_key]) / len(results[times_key]) + if impl_avg > 0: + implementation_speedups[impl_key].append(flash_avg / impl_avg) print(f"\n🏆 Summary:") - if speedups: - avg_speedup = np.mean(speedups) - speedup_icon = "🚀" if avg_speedup > 1.5 else "📈" if avg_speedup > 1.0 else "😐" - print(f" {speedup_icon} DMA vs Flash - Average speedup: {avg_speedup:.2f}x (Best: {np.max(speedups):.2f}x, Worst: {np.min(speedups):.2f}x)") - if kernel_speedups: - avg_kernel_speedup = np.mean(kernel_speedups) - kernel_icon = "🔥" if avg_kernel_speedup > 2.0 else "🚀" if avg_kernel_speedup > 1.5 else "📈" if avg_kernel_speedup > 1.0 else "😐" - print(f" {kernel_icon} DMA-NoTopK vs Flash - Average kernel speedup: {avg_kernel_speedup:.2f}x (Best: {np.max(kernel_speedups):.2f}x, Worst: {np.min(kernel_speedups):.2f}x)") - print(f" 💡 TopK overhead: ~{((np.mean(kernel_speedups) - np.mean(speedups) if speedups else 0) / np.mean(kernel_speedups) * 100) if kernel_speedups else 0:.1f}% performance impact") + # Display statistics for each implementation + for impl_key, speedups in implementation_speedups.items(): + if speedups: + avg_speedup = np.mean(speedups) + max_speedup = np.max(speedups) + min_speedup = np.min(speedups) + + # Choose appropriate icon based on performance + if avg_speedup > 2.0: + icon = "🔥" + elif avg_speedup > 1.5: + icon = "🚀" + elif avg_speedup > 1.0: + icon = "📈" + else: + icon = "😐" + + impl_name = impl_key.replace('_', '-').upper() + print(f" {icon} {impl_name:10} vs Flash - Avg: {avg_speedup:.2f}x (Best: {max_speedup:.2f}x, Worst: {min_speedup:.2f}x)") + else: + print(f" ❌ {impl_key.replace('_', '-').upper():10} vs Flash - No successful runs") + + # Calculate overhead comparison + if implementation_speedups['cuda'] and implementation_speedups['no_topk']: + avg_cuda = np.mean(implementation_speedups['cuda']) + avg_no_topk = np.mean(implementation_speedups['no_topk']) + topk_overhead = ((avg_no_topk - avg_cuda) / avg_no_topk * 100) if avg_no_topk > 0 else 0 + print(f" 💡 TopK overhead: ~{topk_overhead:.1f}% performance impact") def main(): @@ -687,6 +1098,9 @@ def main(): parser.add_argument('--seed', type=int, default=42, help='Random seed') parser.add_argument('--runs', type=int, default=3, help='Number of benchmark runs') parser.add_argument('--warmup', type=int, default=2, help='Number of warmup runs') + parser.add_argument('--test-type', type=str, default='all', + choices=['all', 'flash', 'cuda', 'triton', 'flex', 'flash-vs-cuda', 'flash-vs-triton', 'flash-vs-flex'], + help='Type of benchmark to run (default: all)') args = parser.parse_args() @@ -703,8 +1117,12 @@ def main(): print(f"🎮 CUDA device: {torch.cuda.get_device_name()}") print(f"💾 Total GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB") + print(f"🎲 Random seed: {args.seed}") + print(f"📊 Test type: {args.test_type}") + print(f"🔄 Runs: {args.runs}, Warmup: {args.warmup}") + # Run performance benchmark - run_performance_benchmark() + run_performance_benchmark(args.test_type, args.runs, args.warmup) if __name__ == "__main__": From 00160b5b058d3f087cd1038b172fd61eaffac28d Mon Sep 17 00:00:00 2001 From: Loser Cheems Date: Mon, 7 Jul 2025 20:15:25 +0800 Subject: [PATCH 5/5] Simplifies CUDA build skip logic for extras Streamlines the detection of Triton/Flex-only installations by using simpler string matching instead of complex pattern checking. Removes unnecessary check for plain installations since the core logic focuses on whether specific extras are requested without all/dev variants. --- setup.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/setup.py b/setup.py index f224028..3f805ad 100644 --- a/setup.py +++ b/setup.py @@ -60,12 +60,11 @@ def should_skip_cuda_build(): if len(sys.argv) > 1: install_args = ' '.join(sys.argv) - # If user specifically asks for triton or flex only (not all/dev), skip CUDA - has_triton_or_flex = '.[triton]' in install_args or '.[flex]' in install_args or '[triton]' in install_args or '[flex]' in install_args or 'triton,' in install_args or ',flex' in install_args - has_all_or_dev = '[all]' in install_args or '[dev]' in install_args - has_plain_install = install_args.endswith('flash_dmattn') or install_args.endswith('.') - - if has_triton_or_flex and not has_all_or_dev and not has_plain_install: + # Check if Triton or Flex extras are requested + has_triton_or_flex = 'triton' in install_args or 'flex' in install_args + has_all_or_dev = 'all' in install_args or 'dev' in install_args + + if has_triton_or_flex and not has_all_or_dev: print("Detected Triton/Flex-only installation. Skipping CUDA compilation.") print("Set FLASH_DMATTN_FORCE_BUILD=TRUE to force CUDA compilation.") return True