diff --git a/benchmarks/benchmark_forward_equivalence.py b/benchmarks/benchmark_forward_equivalence.py
index 00dde63..151845d 100644
--- a/benchmarks/benchmark_forward_equivalence.py
+++ b/benchmarks/benchmark_forward_equivalence.py
@@ -17,6 +17,7 @@
 import argparse
 import time
 import gc
+import sys
 
 # Import the compiled CUDA extension
 try:
@@ -27,6 +28,26 @@
     print("Please make sure the package is properly installed with: pip install .")
     exit(1)
 
+# Import the Triton implementation
+try:
+    from flash_dmattn.flash_dmattn_triton import flash_dmattn_func as flash_attn_with_mask
+    print("✅ Successfully imported flash_dmattn_triton")
+except ImportError as e:
+    print(f"❌ Failed to import flash_dmattn_triton: {e}")
+    print("Please make sure the Triton implementation is available.")
+    # Don't exit here, just warn
+    flash_attn_with_mask = None
+
+# Import the Flex Attention implementation
+try:
+    from flash_dmattn.flash_dmattn_flex import flex_attention_forward
+    print("✅ Successfully imported flash_dmattn_flex")
+except ImportError as e:
+    print(f"❌ Failed to import flash_dmattn_flex: {e}")
+    print("Please make sure the Flex Attention implementation is available.")
+    # Don't exit here, just warn
+    flex_attention_forward = None
+
 
 def prepare_dynamic_mask(
     hidden_states: torch.Tensor,
@@ -173,7 +194,7 @@ def dynamic_mask_attention_python(
     attn_weights = F.softmax(attn_weights, dim=-1) # Softmax normalization
     attn_outputs = torch.matmul(attn_weights, value_states)
     attn_outputs = attn_outputs.transpose(1, 2).contiguous() # Transpose to [batch, query_len, num_heads, head_dim]
-    
+
     return attn_outputs
 
 
@@ -252,6 +273,147 @@ def dynamic_mask_attention_cuda(
     return attn_outputs
 
+
+def dynamic_mask_attention_triton(
+    query_states: torch.Tensor,
+    key_states: torch.Tensor,
+    value_states: torch.Tensor,
+    dt_proj: torch.Tensor,
+    A: torch.Tensor,
+    scaling: float,
+    causal_mask: torch.Tensor,
+    keep_window_size=2048,
+    is_causal=True,
+):
+    """
+    Triton implementation of dynamic mask attention.
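+
+    Note: the Triton kernel consumes q/k/v in [batch, seqlen, num_heads, head_dim]
+    layout, so this wrapper transposes the [batch, num_heads, seqlen, head_dim]
+    inputs before calling it.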
+ + Args: + query_states: [batch_size, num_heads, query_len, head_dim] + key_states: [batch_size, num_kv_heads, key_len, head_dim] + value_states: [batch_size, num_kv_heads, key_len, head_dim] + dt_proj: [num_kv_heads, num_kv_heads * head_dim] + A: [num_kv_heads] + scaling: Attention scaling factor + causal_mask: Causal attention mask + keep_window_size: Number of tokens to keep in attention window + is_causal: Whether to apply causal masking + + Returns: + attn_outputs: [batch_size, query_len, num_heads, head_dim] + """ + if flash_attn_with_mask is None: + raise RuntimeError("Triton implementation not available") + + _, num_heads, _, _ = query_states.shape + _, num_kv_heads, _, _ = key_states.shape + num_queries_per_kv = num_heads // num_kv_heads + + # Calculate zoh_states + zoh_states = calculate_zoh_states(value_states, dt_proj, A) + + # Use prepare_dynamic_mask to get the processed attention mask + attn_bias, attn_mask = prepare_dynamic_mask( + query_states, + zoh_states, + keep_window_size, + causal_mask if is_causal else None + ) # [batch_size, num_kv_heads, query_len, key_len] + + # Repeat KV for multi-head attention (GQA support) + key_states = repeat_kv(key_states, num_queries_per_kv) + value_states = repeat_kv(value_states, num_queries_per_kv) + attn_mask = repeat_kv(attn_mask, num_queries_per_kv) + attn_bias = repeat_kv(attn_bias, num_queries_per_kv) + + # Triton function expects: q, k, v in [batch, seqlen, num_heads, head_dim] format + query_states = query_states.transpose(1, 2).contiguous() # [batch, query_len, num_heads, head_dim] + key_states = key_states.transpose(1, 2).contiguous() # [batch, key_len, num_heads, head_dim] + value_states = value_states.transpose(1, 2).contiguous() # [batch, key_len, num_heads, head_dim] + attn_mask = attn_mask.contiguous() # [batch, num_heads, seqlen_q, seqlen_k] + attn_bias = attn_bias.contiguous() # [batch, num_heads, seqlen_q, seqlen_k] + + # Call the Triton implementation + attn_outputs = flash_attn_with_mask( + query_states, # q: [batch, seqlen_q, num_heads, head_dim] + key_states, # k: [batch, seqlen_k, num_heads, head_dim] + value_states, # v: [batch, seqlen_k, num_heads, head_dim] + mask=attn_mask, # mask: [batch, num_heads, seqlen_q, seqlen_k] + bias=attn_bias, # bias: [batch, num_heads, seqlen_q, seqlen_k] + causal=is_causal, # causal masking + softmax_scale=scaling # scaling factor + ) + + return attn_outputs # [batch, query_len, num_heads, head_dim] + + +def dynamic_mask_attention_flex( + query_states: torch.Tensor, + key_states: torch.Tensor, + value_states: torch.Tensor, + dt_proj: torch.Tensor, + A: torch.Tensor, + scaling: float, + causal_mask: torch.Tensor, + keep_window_size=2048, + is_causal=True, +): + """ + Flex Attention implementation of dynamic mask attention. 
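+
+    Note: unlike the Triton path, q/k/v stay in [batch, num_heads, seqlen, head_dim]
+    layout here; only the mask and bias are expanded to
+    [batch, num_heads, query_len, key_len] before the flex_attention_forward call.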
+ + Args: + query_states: [batch_size, num_heads, query_len, head_dim] + key_states: [batch_size, num_kv_heads, key_len, head_dim] + value_states: [batch_size, num_kv_heads, key_len, head_dim] + dt_proj: [num_kv_heads, num_kv_heads * head_dim] + A: [num_kv_heads] + scaling: Attention scaling factor + causal_mask: Causal attention mask + keep_window_size: Number of tokens to keep in attention window + is_causal: Whether to apply causal masking + + Returns: + attn_outputs: [batch_size, query_len, num_heads, head_dim] + """ + if flex_attention_forward is None: + raise RuntimeError("Flex Attention implementation not available") + + _, num_heads, _, _ = query_states.shape + _, num_kv_heads, _, _ = key_states.shape + num_queries_per_kv = num_heads // num_kv_heads + + # Calculate zoh_states + zoh_states = calculate_zoh_states(value_states, dt_proj, A) + + # Use prepare_dynamic_mask to get the processed attention mask + attn_bias, attn_mask = prepare_dynamic_mask( + query_states, + zoh_states, + keep_window_size, + causal_mask if is_causal else None + ) # [batch_size, num_kv_heads, query_len, key_len] + + # Repeat KV for multi-head attention (GQA support) + key_states = repeat_kv(key_states, num_queries_per_kv) + value_states = repeat_kv(value_states, num_queries_per_kv) + attn_mask = repeat_kv(attn_mask, num_queries_per_kv) + attn_bias = repeat_kv(attn_bias, num_queries_per_kv) + + # Flex attention expects: q, k, v in [batch, num_heads, seqlen, head_dim] format + # But attention_mask and attention_bias in [batch, num_heads, query_len, key_len] format + + # Call the Flex Attention implementation + attn_outputs, _ = flex_attention_forward( + query_states, # q: [batch, num_heads, query_len, head_dim] + key_states, # k: [batch, num_heads, key_len, head_dim] + value_states, # v: [batch, num_heads, key_len, head_dim] + attention_mask=attn_mask, # attention_mask: [batch, num_heads, query_len, key_len] + attention_bias=attn_bias, # attention_bias: [batch, num_heads, query_len, key_len] + scaling=scaling # scaling factor + ) + + return attn_outputs # [batch, query_len, num_heads, head_dim] + + def analyze_differences(original_result, cuda_result, accuracy_threshold=0.95): """ Analyze differences between two implementations. 
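Both new wrappers lean on a repeat_kv helper that is defined earlier in this benchmark but does not appear in this diff. For readers of the patch, the sketch below shows what that helper is assumed to do (mirroring the standard Hugging Face repeat_kv) together with the layout change the Triton path applies afterwards; repeat_kv_sketch is a hypothetical name used only for illustration.

```python
import torch

def repeat_kv_sketch(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    # Assumed behaviour of the repeat_kv helper used above: expand
    # [batch, num_kv_heads, seqlen, head_dim] to
    # [batch, num_kv_heads * n_rep, seqlen, head_dim] so that key/value
    # (and the per-KV-head mask/bias) line up with the query heads.
    batch, num_kv_heads, seqlen, head_dim = x.shape
    if n_rep == 1:
        return x
    x = x[:, :, None, :, :].expand(batch, num_kv_heads, n_rep, seqlen, head_dim)
    return x.reshape(batch, num_kv_heads * n_rep, seqlen, head_dim)

# The Triton wrapper additionally moves the head axis next to head_dim:
#   q_bshd = q_bhsd.transpose(1, 2).contiguous()  # [B, H, S, D] -> [B, S, H, D]
# while the Flex wrapper keeps the [B, H, S, D] layout and only expands mask/bias.
```

Doing the GQA expansion on the host keeps both kernels simple: each one sees an ordinary multi-head problem with matching head counts for q, k, v, mask, and bias.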
@@ -344,7 +506,7 @@ def analyze_differences(original_result, cuda_result, accuracy_threshold=0.95): return is_close, max_diff, mean_diff -def test_forward_equivalence(accuracy_threshold=0.95): +def test_cuda_forward_equivalence(accuracy_threshold=0.95): """Test forward pass equivalence between Python prototype and CUDA implementation.""" print("\n" + "๐Ÿš€" + "=" * 76 + "๐Ÿš€") print("๐Ÿ”ฌ Testing Forward Pass Equivalence: Python Prototype vs CUDA Implementation ๐Ÿ”ฌ") @@ -494,6 +656,348 @@ def test_forward_equivalence(accuracy_threshold=0.95): return all_passed +def test_triton_forward_equivalence(accuracy_threshold=0.95): + """Test forward pass equivalence between Python and Triton implementations.""" + print("\n" + "๐Ÿ”ฅ" + "=" * 76 + "๐Ÿ”ฅ") + print("๐Ÿ”ฌ Testing Forward Pass Equivalence: Python vs Triton ๐Ÿ”ฌ") + print("๐Ÿ”ฅ" + "=" * 76 + "๐Ÿ”ฅ") + + if flash_attn_with_mask is None: + print("โŒ Triton implementation not available, skipping Triton tests") + return False + + # Set random seed for reproducibility + torch.manual_seed(0) + + # Smaller test configurations for Triton (to avoid memory issues) + test_configs = [ + # (batch_size, num_heads, num_kv_heads, query_len, key_len, head_dim, is_causal) + (1, 1, 1, 64, 64, 32, True), + (1, 1, 1, 64, 64, 32, False), + (1, 1, 1, 128, 128, 32, True), + (1, 1, 1, 128, 128, 32, False), + (1, 1, 1, 256, 256, 32, True), + (1, 1, 1, 256, 256, 32, False), + (1, 1, 1, 512, 512, 32, True), + (1, 1, 1, 512, 512, 32, False), + (1, 1, 1, 1024, 1024, 32, True), + (1, 1, 1, 1024, 1024, 32, False), + (1, 1, 1, 2048, 2048, 32, True), + (1, 1, 1, 2048, 2048, 32, False), + (1, 1, 1, 4096, 4096, 32, True), + (1, 1, 1, 4096, 4096, 32, False), + (1, 2, 1, 64, 64, 32, True), + (2, 1, 1, 128, 128, 32, True), + (2, 2, 1, 128, 128, 32, True), + (1, 2, 1, 64, 64, 128, True), + (1, 2, 1, 128, 128, 128, True), + (1, 2, 1, 256, 256, 128, True), + (1, 2, 1, 512, 512, 128, True), + ] + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + device_icon = "๐Ÿ”ฅ" if device.type == "cuda" else "๐Ÿ’ป" + print(f"{device_icon} Using device: {device}") + + all_passed = True + + for i, config in enumerate(test_configs): + torch.cuda.empty_cache() + gc.collect() + torch.cuda.synchronize() + + batch_size, num_heads, num_kv_heads, query_len, key_len, head_dim, is_causal = config + + # Progress indicator + progress_filled = "โ–ˆ" * (i + 1) + progress_empty = "โ–‘" * (len(test_configs) - i - 1) + progress_bar = f"[{progress_filled}{progress_empty}]" + + print(f"\n๐Ÿงช Test configuration {i+1}/{len(test_configs)} {progress_bar}") + print(f" ๐Ÿ“Š batch_size={batch_size}, num_heads={num_heads}, num_kv_heads={num_kv_heads}") + print(f" ๐Ÿ“ query_len={query_len}, key_len={key_len}, head_dim={head_dim}") + print(f" ๐Ÿ”’ is_causal={is_causal}") + print(f" ๐ŸŽฏ Accuracy threshold: {accuracy_threshold*100:.1f}%") + + # Create random input data + query_states = torch.randn( + batch_size, num_heads, query_len, head_dim, + device=device, dtype=torch.bfloat16 + ) + key_states = torch.randn( + batch_size, num_kv_heads, key_len, head_dim, + device=device, dtype=torch.bfloat16 + ) + value_states = torch.randn( + batch_size, num_kv_heads, key_len, head_dim, + device=device, dtype=torch.bfloat16 + ) + dt_proj = torch.randn( + num_kv_heads, num_kv_heads * head_dim, + device=device, dtype=torch.bfloat16 + ) + A = torch.randn(num_kv_heads, device=device, dtype=torch.bfloat16) + + # Create custom causal mask with cache position + cache_position = torch.arange(0, query_len + 0, 
device=device) + min_type = torch.finfo(value_states.dtype).min + causal_mask = torch.full( + (query_len, key_len), fill_value=min_type, + device=device, dtype=value_states.dtype + ) + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(key_len, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + + # Set scaling factor and keep window size + scaling = head_dim ** -0.5 + keep_window_size = 64 + + # Run Python implementation + start_time = time.time() + py_output = dynamic_mask_attention_python( + query_states, key_states, value_states, + dt_proj, A, scaling, causal_mask, + keep_window_size, is_causal + ) + torch.cuda.synchronize() + py_time = time.time() - start_time + + # Run Triton implementation + start_time = time.time() + try: + triton_output = dynamic_mask_attention_triton( + query_states, key_states, value_states, + dt_proj, A, scaling, causal_mask, + keep_window_size, is_causal + ) + torch.cuda.synchronize() + triton_time = time.time() - start_time + except Exception as e: + print(f"โŒ Triton implementation failed: {e}") + triton_output = None + triton_time = float('inf') + + # Analyze differences + py_output_copy = py_output.clone() + + if triton_output is not None: + triton_output_copy = triton_output.clone() + + print("\n๐Ÿ“Š Python vs Triton comparison:") + triton_vs_py_close, triton_max_diff, triton_mean_diff = analyze_differences(py_output_copy, triton_output_copy, accuracy_threshold) + else: + triton_vs_py_close = False + + # Report performance differences + print(f"\nโšก Performance comparison:") + print(f" ๐Ÿ Python implementation: {py_time*1000:.2f} ms") + if triton_output is not None: + print(f" ๐Ÿ”ฅ Triton implementation: {triton_time*1000:.2f} ms") + + triton_speedup = py_time / triton_time if triton_time > 0 else float('inf') + print(f" ๐Ÿ“ˆ Triton speedup vs Python: {triton_speedup:.2f}x") + + # Update test results + test_passed = triton_vs_py_close if triton_output is not None else False + test_result = "Passed" if test_passed else "Failed" + result_icon = "โœ…" if test_passed else "โŒ" + all_passed = all_passed and test_passed + print(f"\n{result_icon} Overall test result: {test_result}") + + # If test fails with large difference, can exit early + if not test_passed: + if triton_output is not None: + if triton_max_diff > 1e-2: + print(" โš ๏ธ Difference too large, stopping subsequent tests.") + break + + del query_states, key_states, value_states, dt_proj, A, causal_mask, py_output, py_output_copy + if triton_output is not None: + del triton_output, triton_output_copy + torch.cuda.empty_cache() + gc.collect() + torch.cuda.synchronize() + + print("\n" + "๐Ÿ" + "=" * 76 + "๐Ÿ") + summary_icon = "๐ŸŽ‰" if all_passed else "๐Ÿ˜ž" + print(f"{summary_icon} Python vs Triton Test Summary: {'All Passed' if all_passed else 'Some Tests Failed'}") + print("๐Ÿ" + "=" * 76 + "๐Ÿ") + + return all_passed + + +def test_flex_forward_equivalence(accuracy_threshold=0.95): + """Test forward pass equivalence between Python and Flex Attention implementations.""" + print("\n" + "๐ŸŒŸ" + "=" * 76 + "๐ŸŒŸ") + print("๐Ÿ”ฌ Testing Forward Pass Equivalence: Python vs Flex Attention ๐Ÿ”ฌ") + print("๐ŸŒŸ" + "=" * 76 + "๐ŸŒŸ") + + if flex_attention_forward is None: + print("โŒ Flex Attention implementation not available, skipping Flex Attention tests") + return False + + # Set random seed for reproducibility + torch.manual_seed(0) + + # Test configurations for Flex Attention + 
test_configs = [ + # (batch_size, num_heads, num_kv_heads, query_len, key_len, head_dim, is_causal) + (1, 1, 1, 64, 64, 32, True), + (1, 1, 1, 64, 64, 32, False), + (1, 1, 1, 128, 128, 32, True), + (1, 1, 1, 128, 128, 32, False), + (1, 1, 1, 256, 256, 32, True), + (1, 1, 1, 256, 256, 32, False), + (1, 1, 1, 512, 512, 32, True), + (1, 1, 1, 512, 512, 32, False), + (1, 1, 1, 1024, 1024, 32, True), + (1, 1, 1, 1024, 1024, 32, False), + (1, 1, 1, 2048, 2048, 32, True), + (1, 1, 1, 2048, 2048, 32, False), + (1, 1, 1, 4096, 4096, 32, True), + (1, 1, 1, 4096, 4096, 32, False), + (1, 2, 1, 64, 64, 32, True), + (2, 1, 1, 128, 128, 32, True), + (2, 2, 1, 128, 128, 32, True), + (1, 2, 1, 64, 64, 128, True), + (1, 2, 1, 128, 128, 128, True), + (1, 2, 1, 256, 256, 128, True), + (1, 2, 1, 512, 512, 128, True), + ] + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + device_icon = "๐Ÿ”ฅ" if device.type == "cuda" else "๐Ÿ’ป" + print(f"{device_icon} Using device: {device}") + + all_passed = True + + for i, config in enumerate(test_configs): + torch.cuda.empty_cache() + gc.collect() + torch.cuda.synchronize() + + batch_size, num_heads, num_kv_heads, query_len, key_len, head_dim, is_causal = config + + # Progress indicator + progress_filled = "โ–ˆ" * (i + 1) + progress_empty = "โ–‘" * (len(test_configs) - i - 1) + progress_bar = f"[{progress_filled}{progress_empty}]" + + print(f"\n๐Ÿงช Test configuration {i+1}/{len(test_configs)} {progress_bar}") + print(f" ๐Ÿ“Š batch_size={batch_size}, num_heads={num_heads}, num_kv_heads={num_kv_heads}") + print(f" ๐Ÿ“ query_len={query_len}, key_len={key_len}, head_dim={head_dim}") + print(f" ๐Ÿ”’ is_causal={is_causal}") + print(f" ๐ŸŽฏ Accuracy threshold: {accuracy_threshold*100:.1f}%") + + # Create random input data + query_states = torch.randn( + batch_size, num_heads, query_len, head_dim, + device=device, dtype=torch.bfloat16 + ) + key_states = torch.randn( + batch_size, num_kv_heads, key_len, head_dim, + device=device, dtype=torch.bfloat16 + ) + value_states = torch.randn( + batch_size, num_kv_heads, key_len, head_dim, + device=device, dtype=torch.bfloat16 + ) + dt_proj = torch.randn( + num_kv_heads, num_kv_heads * head_dim, + device=device, dtype=torch.bfloat16 + ) + A = torch.randn(num_kv_heads, device=device, dtype=torch.bfloat16) + + # Create custom causal mask with cache position + cache_position = torch.arange(0, query_len + 0, device=device) + min_type = torch.finfo(value_states.dtype).min + causal_mask = torch.full( + (query_len, key_len), fill_value=min_type, + device=device, dtype=value_states.dtype + ) + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(key_len, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + + # Set scaling factor and keep window size + scaling = head_dim ** -0.5 + keep_window_size = 64 + + # Run Python implementation + start_time = time.time() + py_output = dynamic_mask_attention_python( + query_states, key_states, value_states, + dt_proj, A, scaling, causal_mask, + keep_window_size, is_causal + ) + torch.cuda.synchronize() + py_time = time.time() - start_time + + # Run Flex Attention implementation + start_time = time.time() + try: + flex_output = dynamic_mask_attention_flex( + query_states, key_states, value_states, + dt_proj, A, scaling, causal_mask, + keep_window_size, is_causal + ) + torch.cuda.synchronize() + flex_time = time.time() - start_time + except Exception as e: + print(f"โŒ Flex Attention 
implementation failed: {e}") + flex_output = None + flex_time = float('inf') + + # Analyze differences + py_output_copy = py_output.clone() + + if flex_output is not None: + flex_output_copy = flex_output.clone() + + print("\n๐Ÿ“Š Python vs Flex Attention comparison:") + flex_vs_py_close, flex_max_diff, flex_mean_diff = analyze_differences(py_output_copy, flex_output_copy, accuracy_threshold) + else: + flex_vs_py_close = False + + # Report performance differences + print(f"\nโšก Performance comparison:") + print(f" ๐Ÿ Python implementation: {py_time*1000:.2f} ms") + if flex_output is not None: + print(f" ๐ŸŒŸ Flex Attention implementation: {flex_time*1000:.2f} ms") + + flex_speedup = py_time / flex_time if flex_time > 0 else float('inf') + print(f" ๐Ÿ“ˆ Flex Attention speedup vs Python: {flex_speedup:.2f}x") + + # Update test results + test_passed = flex_vs_py_close if flex_output is not None else False + test_result = "Passed" if test_passed else "Failed" + result_icon = "โœ…" if test_passed else "โŒ" + all_passed = all_passed and test_passed + print(f"\n{result_icon} Overall test result: {test_result}") + + # If test fails with large difference, can exit early + if not test_passed: + if flex_output is not None: + if flex_max_diff > 1e-2: + print(" โš ๏ธ Difference too large, stopping subsequent tests.") + break + + del query_states, key_states, value_states, dt_proj, A, causal_mask, py_output, py_output_copy + if flex_output is not None: + del flex_output, flex_output_copy + torch.cuda.empty_cache() + gc.collect() + torch.cuda.synchronize() + + print("\n" + "๐Ÿ" + "=" * 76 + "๐Ÿ") + summary_icon = "๐ŸŽ‰" if all_passed else "๐Ÿ˜ž" + print(f"{summary_icon} Python vs Flex Attention Test Summary: {'All Passed' if all_passed else 'Some Tests Failed'}") + print("๐Ÿ" + "=" * 76 + "๐Ÿ") + + return all_passed + + def main(): """ Test forward pass equivalence between Python prototype and CUDA implementation @@ -516,7 +1020,7 @@ def main(): parser.add_argument('--accuracy-threshold', type=float, default=0.95, help='Minimum accuracy ratio to pass test (default: 0.95)') parser.add_argument('--test-type', type=str, default='all', - choices=['all', 'fwd'], + choices=['all', 'cuda', 'triton', 'flex'], help='Type of test to run (default: all)') args = parser.parse_args() @@ -543,9 +1047,17 @@ def main(): test_results = {} # Run tests based on user selection - if args.test_type in ['all', 'fwd']: + if args.test_type in ['all', 'cuda']: print("\n" + "๐Ÿ“" + " Starting Standard Forward Pass Tests " + "๐Ÿ“") - test_results['fwd'] = test_forward_equivalence(args.accuracy_threshold) + test_results['cuda'] = test_cuda_forward_equivalence(args.accuracy_threshold) + + if args.test_type in ['all', 'triton']: + print("\n" + "๐Ÿ”ฅ" + " Starting Python vs Triton Tests " + "๐Ÿ”ฅ") + test_results['triton'] = test_triton_forward_equivalence(args.accuracy_threshold) + + if args.test_type in ['all', 'flex']: + print("\n" + "๐ŸŒŸ" + " Starting Python vs Flex Attention Tests " + "๐ŸŒŸ") + test_results['flex'] = test_flex_forward_equivalence(args.accuracy_threshold) # Print overall summary @@ -567,7 +1079,6 @@ def main(): print("๐Ÿ†" + "=" * 78 + "๐Ÿ†") # Exit with appropriate code - import sys sys.exit(0 if all_passed else 1) diff --git a/flash_dmattn/flash_dmattn_flex.py b/flash_dmattn/flash_dmattn_flex.py new file mode 100644 index 0000000..c5ce05b --- /dev/null +++ b/flash_dmattn/flash_dmattn_flex.py @@ -0,0 +1,64 @@ +from typing import Optional, Tuple +import torch +from torch.nn.attention.flex_attention 
import create_block_mask +from transformers.integrations.flex_attention import compile_friendly_flex_attention + + +def flex_attention_forward( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: torch.Tensor, + attention_bias: torch.Tensor, + is_causal: bool = True, + scaling: Optional[float] = None, + **kwargs, +) -> Tuple[torch.Tensor, torch.Tensor]: + attn_mask = attention_mask[:, :, :, : key.shape[-2]] + attn_bias = attention_bias[:, :, :, : key.shape[-2]] + + def score_mod(score, batch_idx, head_idx, q_idx, kv_idx): + score = score + attn_bias[batch_idx][head_idx][q_idx][kv_idx] + return score + + def causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx): + # It looks like you're attempting to use a Tensor in some data-dependent control flow. + # We don't support that yet, please shout over at https://github.com/pytorch/functorch/issues/257 . + # return q_idx >= kv_idx and attn_mask[batch_idx][head_idx][q_idx][kv_idx] > 0 + return q_idx >= kv_idx + + block_mask = create_block_mask( + mask_mod=causal_mask_mod, + B=query.shape[0], + H=None, + Q_LEN=query.shape[2], + KV_LEN=key.shape[2], + device=query.device, + _compile=True, + ) + + kernel_options = { + "BLOCK_M": 64, + "BLOCK_N": 64, + "BLOCK_DMODEL": 32, + "num_stages": 1, + "num_warps": 8, + } + attn_output, attention_weights = compile_friendly_flex_attention( + query, + key, + value, + score_mod=score_mod, + block_mask=block_mask if is_causal else None, + scale=scaling, + kernel_options=kernel_options, + # Last time checked on PyTorch == 2.5.1: Flex Attention always computes the lse regardless. + # For simplification, we thus always return it as no additional computations are introduced. + return_lse=True, + training=False, + ) + # lse is returned in float32 + attention_weights = attention_weights.to(value.dtype) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attention_weights \ No newline at end of file diff --git a/flash_dmattn/flash_dmattn_triton.py b/flash_dmattn/flash_dmattn_triton.py new file mode 100644 index 0000000..84d9f6c --- /dev/null +++ b/flash_dmattn/flash_dmattn_triton.py @@ -0,0 +1,399 @@ +import math + +import torch +import triton +import triton.language as tl + + +# Disabling autotune for now, set num_warps=4 if headdim=64 and num_warps=8 if headdim=128 +# @triton.autotune( +# configs=[ +# triton.Config({"BLOCK_M": 128, "BLOCK_N": 128}, num_warps=4, num_stages=1), +# # This config has a race condition when EVEN_M == False, disabling it for now. 
+# # triton.Config({"BLOCK_M": 64, "BLOCK_N": 64}, num_warps=4, num_stages=1), +# ], +# key=['CACHE_KEY_SEQLEN_Q', 'CACHE_KEY_SEQLEN_K', 'BIAS_TYPE', 'IS_CAUSAL', 'BLOCK_HEADDIM'] +# ) +@triton.heuristics( + { + "EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0, + "EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0, + "EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"], + } +) +@triton.jit +def _fwd_kernel( + Q, + K, + V, + Mask, + Bias, + Out, + Lse, + TMP, # NOTE: TMP is a scratchpad buffer to workaround a compiler bug + softmax_scale, + stride_qb, + stride_qh, + stride_qm, + stride_kb, + stride_kh, + stride_kn, + stride_vb, + stride_vh, + stride_vn, + stride_mb, + stride_mh, + stride_mm, + stride_bb, + stride_bh, + stride_bm, + stride_ob, + stride_oh, + stride_om, + nheads, + seqlen_q, + seqlen_k, + seqlen_q_rounded, + headdim, + CACHE_KEY_SEQLEN_Q, + CACHE_KEY_SEQLEN_K, + IS_CAUSAL: tl.constexpr, + BLOCK_HEADDIM: tl.constexpr, + EVEN_M: tl.constexpr, + EVEN_N: tl.constexpr, + EVEN_HEADDIM: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): + start_m = tl.program_id(0) + off_hb = tl.program_id(1) + off_b = off_hb // nheads + off_h = off_hb % nheads + # off_b = tl.program_id(1) + # off_h = tl.program_id(2) + # off_hb = off_b * nheads + off_h + # initialize offsets + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_HEADDIM) + # Initialize pointers to Q, K, V, Mask, Bias + # Adding parenthesis around indexing might use int32 math instead of int64 math? + # https://github.com/openai/triton/issues/741 + # I'm seeing a tiny bit of difference (5-7us) + q_ptrs = ( + Q + off_b * stride_qb + off_h * stride_qh + (offs_m[:, None] * stride_qm + offs_d[None, :]) + ) + k_ptrs = ( + K + off_b * stride_kb + off_h * stride_kh + (offs_n[:, None] * stride_kn + offs_d[None, :]) + ) + v_ptrs = ( + V + off_b * stride_vb + off_h * stride_vh + (offs_n[:, None] * stride_vn + offs_d[None, :]) + ) + m_ptrs = ( + Mask + off_b * stride_mb + off_h * stride_mh + (offs_m[:, None] * stride_mm + offs_n[None, :]) + ) + b_ptrs = ( + Bias + off_b * stride_bb + off_h * stride_bh + (offs_m[:, None] * stride_bm + offs_n[None, :]) + ) + + # initialize pointer to m and l + t_ptrs = TMP + off_hb * seqlen_q_rounded + offs_m + lse_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + acc_o = tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32) + # load q: it will stay in SRAM throughout + # [2022-10-30] TD: Triton bug - in the case of EVEN_M=True and EVEN_N=False, if we just call + # tl.load(q_ptrs), we get the wrong output! 
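+    # EVEN_M / EVEN_N / EVEN_HEADDIM come from the @triton.heuristics above: when
+    # seqlen_q, seqlen_k, or headdim is an exact multiple of the corresponding block
+    # size, the boundary masks on tl.load/tl.store can be skipped, which is why the
+    # loads below are specialized on these flags.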
+ if EVEN_M & EVEN_N: + if EVEN_HEADDIM: + q = tl.load(q_ptrs) + else: + q = tl.load(q_ptrs, mask=offs_d[None, :] < headdim, other=0.0) + else: + if EVEN_HEADDIM: + q = tl.load(q_ptrs, mask=offs_m[:, None] < seqlen_q, other=0.0) + else: + q = tl.load( + q_ptrs, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0 + ) + # loop over k, v and update accumulator + end_n = seqlen_k if not IS_CAUSAL else tl.minimum((start_m + 1) * BLOCK_M, seqlen_k) + for start_n in range(0, end_n, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + + # Load mask + if EVEN_M & EVEN_N: + mask = tl.load(m_ptrs + start_n) + else: + mask = tl.load( + m_ptrs + start_n, + mask=(offs_m[:, None] < seqlen_q) & ((start_n + offs_n)[None, :] < seqlen_k), + other=0.0 + ) + + # Check if any element in mask is non-zero + # any_active = tl.sum(mask > 0) > 0 + any_active = True + + # compute acc_s + acc_s = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + if any_active: + # Load k + if EVEN_N & EVEN_M: # If we just do "if EVEN_N", there seems to be some race condition + if EVEN_HEADDIM: + k = tl.load(k_ptrs + start_n * stride_kn) + else: + k = tl.load(k_ptrs + start_n * stride_kn, mask=offs_d[None, :] < headdim, other=0.0) + else: + if EVEN_HEADDIM: + k = tl.load( + k_ptrs + start_n * stride_kn, + mask=(start_n + offs_n)[:, None] < seqlen_k, + other=0.0, + ) + else: + k = tl.load( + k_ptrs + start_n * stride_kn, + mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim), + other=0.0, + ) + acc_s += tl.dot(q, tl.trans(k)) + + # Trying to combine the two masks seem to make the result wrong + # Apply sequence length mask + if not EVEN_N: # Need to mask out otherwise the softmax is wrong + acc_s += tl.where((start_n + offs_n)[None, :] < seqlen_k, 0, float("-inf")) + # Apply causal mask + if IS_CAUSAL: + acc_s += tl.where(offs_m[:, None] >= (start_n + offs_n)[None, :], 0, float("-inf")) + # Apply dynamic mask + acc_s += tl.where(mask > 0.0, 0, float("-inf")) + + # Load bias + if EVEN_M & EVEN_N: + bias = tl.load(b_ptrs + start_n).to(tl.float32) + else: + bias = tl.load( + b_ptrs + start_n, + mask=(offs_m[:, None] < seqlen_q) + & ((start_n + offs_n)[None, :] < seqlen_k), + other=0.0, + ).to(tl.float32) + + # Apply scaling and bias + # Slightly faster to multiply the softmax_scale in the tl.exp below since the compiler + # can then fuse the mult and add into an fma instruction. But if we have bias we need to + # to multiply with softmax_scale here. 
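+        # The block below is the streaming ("online") softmax update: take the new
+        # running row max m_ij, exponentiate scores relative to it, rescale the
+        # existing output accumulator by exp(m_i - m_ij), then accumulate p @ v and
+        # update the running log-sum-exp.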
+ acc_s = acc_s * softmax_scale + bias + m_ij = tl.maximum(tl.max(acc_s, 1), lse_i) + p = tl.exp(acc_s - m_ij[:, None]) + l_ij = tl.sum(p, 1) + + # scale acc_o + acc_o_scale = tl.exp(m_i - m_ij) + + # update output accumulator + # BUG: have to store and immediately load + tl.store(t_ptrs, acc_o_scale) + acc_o_scale = tl.load(t_ptrs) + acc_o = acc_o * acc_o_scale[:, None] + + # update acc_o + if any_active: + if EVEN_N & EVEN_M: # If we just do "if EVEN_N", there seems to be some race condition + if EVEN_HEADDIM: + v = tl.load(v_ptrs + start_n * stride_vn) + else: + v = tl.load(v_ptrs + start_n * stride_vn, mask=offs_d[None, :] < headdim, other=0.0) + else: + if EVEN_HEADDIM: + v = tl.load( + v_ptrs + start_n * stride_vn, + mask=(start_n + offs_n)[:, None] < seqlen_k, + other=0.0, + ) + else: + v = tl.load( + v_ptrs + start_n * stride_vn, + mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim), + other=0.0, + ) + acc_o += tl.dot(p.to(v.dtype), v) + + # update statistics + m_i = m_ij + l_i_new = tl.exp(lse_i - m_ij) + l_ij + lse_i = m_ij + tl.log(l_i_new) + + o_scale = tl.exp(m_i - lse_i) + # BUG: have to store and immediately load + tl.store(t_ptrs, o_scale) + o_scale = tl.load(t_ptrs) + acc_o = acc_o * o_scale[:, None] + # rematerialize offsets to save registers + start_m = tl.program_id(0) + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + # write back l and m + lse_ptrs = Lse + off_hb * seqlen_q_rounded + offs_m + tl.store(lse_ptrs, lse_i) + # initialize pointers to output + offs_d = tl.arange(0, BLOCK_HEADDIM) + out_ptrs = ( + Out + + off_b * stride_ob + + off_h * stride_oh + + (offs_m[:, None] * stride_om + offs_d[None, :]) + ) + if EVEN_M: + if EVEN_HEADDIM: + tl.store(out_ptrs, acc_o) + else: + tl.store(out_ptrs, acc_o, mask=offs_d[None, :] < headdim) + else: + if EVEN_HEADDIM: + tl.store(out_ptrs, acc_o, mask=offs_m[:, None] < seqlen_q) + else: + tl.store( + out_ptrs, acc_o, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim) + ) + + +def _flash_attn_forward(q, k, v, mask=None, bias=None, causal=False, softmax_scale=None): + # shape constraints + batch, seqlen_q, nheads, d = q.shape + _, seqlen_k, _, _ = k.shape + assert k.shape == (batch, seqlen_k, nheads, d) + assert v.shape == (batch, seqlen_k, nheads, d) + assert d <= 128, "FlashAttention only support head dimensions up to 128" + assert q.dtype == k.dtype == v.dtype, "All tensors must have the same type" + assert q.dtype in [torch.float16, torch.bfloat16], "Only support fp16 and bf16" + assert q.is_cuda and k.is_cuda and v.is_cuda + + if mask is not None: + assert mask.shape == (batch, nheads, seqlen_q, seqlen_k), f"mask shape {mask.shape} does not match expected shape {(batch, nheads, seqlen_q, seqlen_k)}" + assert mask.dtype in [torch.float16, torch.bfloat16, torch.float32], "mask must be fp16, bf16, or fp32" + assert mask.is_cuda, "mask must be on CUDA" + if mask.stride(-1) != 1: + mask = mask.contiguous() + else: + # Create a default mask of all ones + mask = torch.ones((batch, nheads, seqlen_q, seqlen_k), device=q.device, dtype=q.dtype) + + if bias is not None: + assert bias.dtype in [q.dtype, torch.float], f"bias dtype {bias.dtype} must match q dtype {q.dtype} or be float" + assert bias.is_cuda, "bias must be on CUDA" + assert bias.dim() == 4, f"bias must be 4D, got {bias.dim()}D" + assert bias.shape == (batch, nheads, seqlen_q, seqlen_k), f"bias shape {bias.shape} must be (batch={batch}, nheads={nheads}, seqlen_q={seqlen_q}, seqlen_k={seqlen_k})" + if bias.stride(-1) != 1: + bias = 
bias.contiguous() + else: + # Create zero bias if none provided + bias = torch.zeros((batch, nheads, seqlen_q, seqlen_k), device=q.device, dtype=q.dtype) + + softmax_scale = softmax_scale or 1.0 / math.sqrt(d) + + seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128 + lse = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32) + tmp = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32) + o = torch.empty_like(q) + + BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16) + BLOCK = 64 # Reduced from 128 to 64 to avoid shared memory overflow + num_warps = 4 if d <= 64 else 8 + grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), batch * nheads) + _fwd_kernel[grid]( + q, + k, + v, + mask, + bias, + o, + lse, + tmp, + softmax_scale, + q.stride(0), + q.stride(2), + q.stride(1), + k.stride(0), + k.stride(2), + k.stride(1), + v.stride(0), + v.stride(2), + v.stride(1), + mask.stride(0), + mask.stride(1), + mask.stride(2), + bias.stride(0), + bias.stride(1), + bias.stride(2), + o.stride(0), + o.stride(2), + o.stride(1), + nheads, + seqlen_q, + seqlen_k, + seqlen_q_rounded, + d, + seqlen_q // 32, + seqlen_k // 32, # key for triton cache (limit number of compilations) + # Can't use kwargs here because triton autotune expects key to be args, not kwargs + # IS_CAUSAL=causal, BLOCK_HEADDIM=d, + causal, + BLOCK_HEADDIM, + BLOCK_M=BLOCK, + BLOCK_N=BLOCK, + num_warps=num_warps, + num_stages=1, + ) + return o, lse, softmax_scale # softmax_scale could have been updated + + +class FlashAttnFunc(torch.autograd.Function): + @staticmethod + def forward(ctx, q, k, v, mask=None, bias=None, causal=False, softmax_scale=None): + """ + q: (batch_size, seqlen_q, nheads, headdim) + k, v: (batch_size, seqlen_k, nheads, headdim) + mask: optional, shape (batch, nheads, seqlen_q, seqlen_k), dynamic attention mask + bias: optional, shape must be exactly (batch, nheads, seqlen_q, seqlen_k), attention bias matrix + causal: bool, whether to apply causal masking + softmax_scale: float, scaling factor for attention scores + """ + # Make sure that the last dimension is contiguous + q, k, v = [x if x.stride(-1) == 1 else x.contiguous() for x in [q, k, v]] + o, lse, ctx.softmax_scale = _flash_attn_forward( + q, k, v, mask=mask, bias=bias, causal=causal, softmax_scale=softmax_scale + ) + ctx.save_for_backward(q, k, v, o, lse, mask, bias) + ctx.causal = causal + return o + + # @staticmethod + # def backward(ctx, do): + # q, k, v, o, lse, mask, bias = ctx.saved_tensors + # assert not ctx.needs_input_grad[3], "FlashAttention does not support mask gradient yet" + # assert not ctx.needs_input_grad[4], "FlashAttention does not support bias gradient yet" + # # Triton's autotune causes the Tensor._version to change, and so Pytorch autograd + # # does a memcpy. To avoid this we run in inference_mode, which doesn't track the version. + # with torch.inference_mode(): + # dq = torch.empty_like(q) + # dk = torch.empty_like(k) + # dv = torch.empty_like(v) + # _flash_attn_backward( + # do, + # q, + # k, + # v, + # o, + # lse, + # dq, + # dk, + # dv, + # bias=bias, + # causal=ctx.causal, + # softmax_scale=ctx.softmax_scale, + # ) + # return dq, dk, dv, None, None, None, None + + +flash_dmattn_func = FlashAttnFunc.apply
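For quick reference, here is a minimal forward-only usage sketch of the new Triton entry point. Shapes, dtypes, the all-ones mask, and the zero bias are illustrative only, and gradients are not available because the backward pass above is still commented out.

```python
import math
import torch
from flash_dmattn.flash_dmattn_triton import flash_dmattn_func

# Illustrative sizes; head_dim must be <= 128 and tensors must be fp16/bf16 on CUDA.
batch, seqlen, nheads, headdim = 1, 128, 2, 64
device, dtype = "cuda", torch.bfloat16

q = torch.randn(batch, seqlen, nheads, headdim, device=device, dtype=dtype)
k = torch.randn(batch, seqlen, nheads, headdim, device=device, dtype=dtype)
v = torch.randn(batch, seqlen, nheads, headdim, device=device, dtype=dtype)
mask = torch.ones(batch, nheads, seqlen, seqlen, device=device, dtype=dtype)   # keep every position active
bias = torch.zeros(batch, nheads, seqlen, seqlen, device=device, dtype=dtype)  # no additive bias

# flash_dmattn_func is a raw autograd.Function.apply, so arguments are positional:
# (q, k, v, mask, bias, causal, softmax_scale)
out = flash_dmattn_func(q, k, v, mask, bias, True, 1.0 / math.sqrt(headdim))
print(out.shape)  # torch.Size([1, 128, 2, 64])
```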