diff --git a/benchmarks/benchmark_forward_performance.py b/benchmarks/benchmark_forward_performance.py
index d9ed6f5..a38036b 100644
--- a/benchmarks/benchmark_forward_performance.py
+++ b/benchmarks/benchmark_forward_performance.py
@@ -8,7 +8,6 @@
 Implementations tested:
 - Flash Attention (PyTorch SDPA Flash Attention backend) - Baseline
 - Dynamic Mask Attention CUDA - Custom CUDA kernel implementation
-- Dynamic Mask Attention CUDA (No TopK) - CUDA kernel without TopK computation
 - Dynamic Mask Attention Triton - Triton kernel implementation
 - Dynamic Mask Attention Flex - Flex Attention implementation
 
@@ -187,16 +186,15 @@ def flash_attention_cuda(
     torch.cuda.synchronize()
     start_time = time.time()
 
-    with torch.nn.attention.sdpa_kernel(backends=[torch.nn.attention.SDPBackend.FLASH_ATTENTION]):
-        attn_outputs = torch.nn.functional.scaled_dot_product_attention(
-            query_states,
-            key_states,
-            value_states,
-            # attn_mask=causal_mask,
-            scale=scaling,
-            is_causal=is_causal if query_len == key_len else False,
-            enable_gqa=True
-        )
+    attn_outputs = torch.nn.functional.scaled_dot_product_attention(
+        query_states,
+        key_states,
+        value_states,
+        attn_mask=causal_mask,
+        scale=scaling,
+        # is_causal=is_causal if query_len == key_len else False,
+        enable_gqa=True
+    )
 
     torch.cuda.synchronize()
     end_time = time.time()
@@ -293,99 +291,6 @@ def dynamic_mask_attention_cuda(
         return "OOM", 0
 
-def dynamic_mask_attention_cuda_no_topk(
-    query_states: torch.Tensor,
-    key_states: torch.Tensor,
-    value_states: torch.Tensor,
-    dt_proj: torch.Tensor,
-    A: torch.Tensor,
-    scaling: float,
-    causal_mask: torch.Tensor,
-    keep_window_size=2048,
-    is_causal=True,
-    return_softmax=False
-):
-    """
-    CUDA implementation of dynamic mask attention without topk computation.
-    This version skips the topk calculation for more accurate kernel performance testing.
-
-    Args:
-        query_states: [batch_size, num_heads, query_len, head_dim]
-        key_states: [batch_size, num_kv_heads, key_len, head_dim]
-        value_states: [batch_size, num_kv_heads, key_len, head_dim]
-        dt_proj: [num_kv_heads, num_kv_heads * head_dim]
-        A: [num_kv_heads]
-        scaling: Attention scaling factor
-        causal_mask: Causal attention mask
-        keep_window_size: Number of tokens to keep in attention window
-        is_causal: Whether to apply causal masking
-        return_softmax: Whether to return softmax weights
-
-    Returns:
-        attn_outputs: [batch_size, query_len, num_heads, head_dim]
-    """
-    # Calculate zoh_states
-    zoh_states = calculate_zoh_states(value_states, dt_proj, A)
-
-    # Create a simplified mask without topk computation
-    batch_size, _, query_len, _ = query_states.shape
-    _, num_kv_heads, key_len, _ = key_states.shape
-    dtype = query_states.dtype
-    device = query_states.device
-
-    # Create full attn mask (no topk selection)
-    attn_mask = torch.zeros(
-        (batch_size, num_kv_heads, query_len, key_len),
-        dtype=dtype,
-        device=device
-    )
-
-    # Ensure correct data types and memory layout
-    query_states = query_states.transpose(1, 2).contiguous()   # [batch, query_len, num_heads, head_dim]
-    key_states = key_states.transpose(1, 2).contiguous()       # [batch, key_len, num_kv_heads, head_dim]
-    value_states = value_states.transpose(1, 2).contiguous()   # [batch, key_len, num_kv_heads, head_dim]
-    attn_bias = zoh_states[:, :, None, :].expand(
-        -1, -1, query_states.shape[1], -1
-    ).contiguous()  # [batch, num_kv_heads, query_len, key_len]
-    # Create full active mask (no topk selection)
-    attn_mask = torch.zeros_like(
-        attn_bias,
-        dtype=dtype,
-        device=device
-    ).contiguous()  # [batch, num_kv_heads, query_len, key_len]
-
-    try:
-        out_tensor = None  # Let the function allocate the output tensor
-
-        # Only measure the core CUDA kernel computation
-        torch.cuda.synchronize()
-        start_time = time.time()
-
-        result = flash_dmattn_cuda.fwd(  # type: ignore
-            query_states,       # q: [batch, seqlen_q, num_heads, head_dim]
-            key_states,         # k: [batch, seqlen_k, num_kv_heads, head_dim]
-            value_states,       # v: [batch, seqlen_k, num_kv_heads, head_dim]
-            attn_mask,          # attn_mask: [batch, num_kv_heads, seqlen_q, seqlen_k]
-            attn_bias,          # attn_bias: [batch, num_kv_heads, seqlen_q, seqlen_k]
-            out_tensor,         # out: None to auto-allocate
-            0.0,                # p_dropout
-            scaling,            # softmax_scale
-            is_causal,          # is_causal
-            keep_window_size,   # keep_window_size
-            0.0,                # softcap
-            return_softmax,     # return_softmax
-            None                # gen (generator)
-        )
-
-        torch.cuda.synchronize()
-        end_time = time.time()
-
-        attn_outputs = result[0]
-        return attn_outputs, (end_time - start_time) * 1000  # Return output and time in ms
-    except torch.cuda.OutOfMemoryError:
-        return "OOM", 0
-
-
 def dynamic_mask_attention_triton(
     query_states: torch.Tensor,
     key_states: torch.Tensor,
     value_states: torch.Tensor,
@@ -614,17 +519,14 @@ def benchmark_attention_performance(config, test_type='all', num_runs=5, warmup_
         'config': config,
         'flash_attention_times': [],
         'dynamic_mask_attention_times': [],
-        'dynamic_mask_attention_no_topk_times': [],
         'dynamic_mask_attention_triton_times': [],
         'dynamic_mask_attention_flex_times': [],
         'flash_attention_memory': 0,
         'dynamic_mask_attention_memory': 0,
-        'dynamic_mask_attention_no_topk_memory': 0,
         'dynamic_mask_attention_triton_memory': 0,
         'dynamic_mask_attention_flex_memory': 0,
         'flash_attention_status': 'success',
         'dynamic_mask_attention_status': 'success',
-        'dynamic_mask_attention_no_topk_status': 'success',
         'dynamic_mask_attention_triton_status': 'success',
         'dynamic_mask_attention_flex_status': 'success'
     }
@@ -632,7 +534,6 @@ def benchmark_attention_performance(config, test_type='all', num_runs=5, warmup_
     # Determine which implementations to run
     run_flash = test_type in ['all', 'flash', 'flash-vs-cuda', 'flash-vs-triton', 'flash-vs-flex']
     run_cuda = test_type in ['all', 'cuda', 'flash-vs-cuda']
-    run_no_topk = test_type in ['all', 'cuda']
     run_triton = test_type in ['all', 'triton', 'flash-vs-triton']
     run_flex = test_type in ['all', 'flex', 'flash-vs-flex']
 
@@ -718,48 +619,6 @@ def benchmark_attention_performance(config, test_type='all', num_runs=5, warmup_
         else:
             results['dynamic_mask_attention_status'] = 'N/A'
-    # Benchmark Dynamic Mask Attention (No TopK)
-    if run_no_topk:
-        gc.collect()
-        torch.cuda.empty_cache()
-
-        # Warmup runs
-        for _ in range(warmup_runs):
-            result = dynamic_mask_attention_cuda_no_topk(
-                query_states, key_states, value_states,
-                dt_proj, A, scaling, causal_mask,
-                keep_window_size, is_causal
-            )
-            if result[0] == "OOM":
-                results['dynamic_mask_attention_no_topk_status'] = 'OOM'
-                break
-        torch.cuda.synchronize()
-
-        if results['dynamic_mask_attention_no_topk_status'] == 'success':
-            # Measure memory before benchmark
-            mem_before = measure_memory_usage()
-
-            # Actual benchmark runs
-            for _ in range(num_runs):
-                result = dynamic_mask_attention_cuda_no_topk(
-                    query_states, key_states, value_states,
-                    dt_proj, A, scaling, causal_mask,
-                    keep_window_size, is_causal
-                )
-
-                if result[0] == "OOM":
-                    results['dynamic_mask_attention_no_topk_status'] = 'OOM'
-                    break
-
-                # Use the timing from the function instead of measuring here
-                results['dynamic_mask_attention_no_topk_times'].append(result[1])  # ms
-
-            # Measure memory after
-            mem_after = measure_memory_usage()
-            results['dynamic_mask_attention_no_topk_memory'] = mem_after[0] - mem_before[0]
-        else:
-            results['dynamic_mask_attention_no_topk_status'] = 'N/A'
-
-
     # Benchmark Dynamic Mask Attention (Triton)
     if run_triton:
         gc.collect()
@@ -876,49 +735,49 @@ def run_performance_benchmark(test_type='all', num_runs=3, warmup_runs=2):
 
     # Test configurations: (batch_size, num_heads, num_kv_heads, query_len, key_len, head_dim, keep_window_size, is_causal)
     configs = [
-        # Vary sequence length
-        (1, 2, 1, 256, 256, 32, 2048, True),
-        (1, 2, 1, 512, 512, 32, 2048, True),
-        (1, 2, 1, 1024, 1024, 32, 2048, True),
-        (1, 2, 1, 2048, 2048, 32, 2048, True),
-        (1, 2, 1, 4096, 4096, 32, 2048, True),
-        (1, 2, 1, 8192, 8192, 32, 2048, True),
-        (1, 2, 1, 16384, 16384, 32, 2048, True),
-        (1, 2, 1, 32768, 32768, 32, 2048, True),
-
-        # Inference
-        (1, 2, 1, 2, 256, 128, 2048, True),
-        (1, 2, 1, 2, 512, 128, 2048, True),
-        (1, 2, 1, 2, 1024, 128, 2048, True),
-        (1, 2, 1, 2, 2048, 128, 2048, True),
-        (1, 2, 1, 2, 4096, 128, 2048, True),
-        (1, 2, 1, 2, 8192, 128, 2048, True),
-        (1, 2, 1, 2, 16384, 128, 2048, True),
-        (1, 2, 1, 2, 32768, 128, 2048, True),
-        (1, 2, 1, 2, 65536, 128, 2048, True),
-        (1, 2, 1, 2, 131072, 128, 2048, True),
-        (1, 2, 1, 2, 262144, 128, 2048, True),
-        (1, 2, 1, 2, 524288, 128, 2048, True),
-
-        # Vary batch size
-        (1, 2, 1, 1024, 1024, 32, 2048, True),
-        (2, 2, 1, 1024, 1024, 32, 2048, True),
-        (4, 2, 1, 1024, 1024, 32, 2048, True),
-        (8, 2, 1, 1024, 1024, 32, 2048, True),
-
-        # Vary head count
-        (1, 1, 1, 1024, 1024, 32, 2048, True),
-        (1, 2, 1, 1024, 1024, 32, 2048, True),
-        (1, 4, 1, 1024, 1024, 32, 2048, True),
-        (1, 8, 2, 1024, 1024, 32, 2048, True),
-
-        # Vary head dimension
-        (1, 2, 1, 1024, 1024, 32, 2048, True),
-        (1, 2, 1, 1024, 1024, 64, 2048, True),
-        (1, 2, 1, 1024, 1024, 96, 2048, True),
-        (1, 2, 1, 1024, 1024, 128, 2048, True),
-        (1, 2, 1, 1024, 1024, 192, 2048, True),
-        (1, 2, 1, 1024, 1024, 256, 2048, True),
+        # # Vary sequence length
+        # (1, 2, 1, 256, 256, 128, 2048, True),
+        # (1, 2, 1, 512, 512, 128, 2048, True),
+        # (1, 2, 1, 1024, 1024, 128, 2048, True),
+        # (1, 2, 1, 2048, 2048, 128, 2048, True),
+        # (1, 2, 1, 4096, 4096, 128, 2048, True),
+        # (1, 2, 1, 8192, 8192, 128, 2048, True),
+        # (1, 2, 1, 16384, 16384, 128, 2048, True),
+        # (1, 2, 1, 32768, 32768, 128, 2048, True),
+
+        # # Inference
+        # (1, 2, 1, 2, 256, 128, 2048, True),
+        # (1, 2, 1, 2, 512, 128, 2048, True),
+        # (1, 2, 1, 2, 1024, 128, 2048, True),
+        # (1, 2, 1, 2, 2048, 128, 2048, True),
+        # (1, 2, 1, 2, 4096, 128, 2048, True),
+        # (1, 2, 1, 2, 8192, 128, 2048, True),
+        # (1, 2, 1, 2, 16384, 128, 2048, True),
+        # (1, 2, 1, 2, 32768, 128, 2048, True),
+        # (1, 2, 1, 2, 65536, 128, 2048, True),
+        # (1, 2, 1, 2, 131072, 128, 2048, True),
+        # (1, 2, 1, 2, 262144, 128, 2048, True),
+        # (1, 2, 1, 2, 524288, 128, 2048, True),
+
+        # # Vary batch size
+        # (1, 2, 1, 1024, 1024, 32, 2048, True),
+        # (2, 2, 1, 1024, 1024, 32, 2048, True),
+        # (4, 2, 1, 1024, 1024, 32, 2048, True),
+        # (8, 2, 1, 1024, 1024, 32, 2048, True),
+
+        # # Vary head count
+        # (1, 1, 1, 1024, 1024, 32, 2048, True),
+        # (1, 2, 1, 1024, 1024, 32, 2048, True),
+        # (1, 4, 1, 1024, 1024, 32, 2048, True),
+        # (1, 8, 2, 1024, 1024, 32, 2048, True),
+
+        # # Vary head dimension
+        # (1, 2, 1, 1024, 1024, 32, 2048, True),
+        # (1, 2, 1, 1024, 1024, 64, 2048, True),
+        # (1, 2, 1, 1024, 1024, 96, 2048, True),
+        # (1, 2, 1, 1024, 1024, 128, 2048, True),
+        # (1, 2, 1, 1024, 1024, 192, 2048, True),
+        # (1, 2, 1, 1024, 1024, 256, 2048, True),
 
         # Vary keep_window_size
         (1, 2, 1, 32768, 32768, 128, 32, True),
@@ -933,15 +792,15 @@
         (1, 2, 1, 32768, 32768, 128, 16384, True),
         (1, 2, 1, 32768, 32768, 128, 32768, True),
 
-        # Test non-causal
-        (1, 2, 1, 1024, 1024, 128, 2048, False),
+        # # Test non-causal
+        # (1, 2, 1, 1024, 1024, 128, 2048, False),
     ]
 
     num_runs = 3  # Run 3 times and take average
 
     print(f"\nšŸ“Š Benchmark Results (averaged over {num_runs} runs):")
-    print(f"šŸ”§ {'Configuration':<60} ⚔ {'Flash':<10} šŸš€ {'CUDA':<10} šŸš€ {'No-TopK':<10} 🌟 {'Triton':<10} 🌟 {'Flex':<15} šŸ“ˆ {'Speedup':<15}")
-    print("šŸ”„" + "-" * 160 + "šŸ”„")
+    print(f"šŸ”§ {'Configuration':<60} ⚔ {'Flash':<10} šŸš€ {'CUDA':<10} 🌟 {'Triton':<10} 🌟 {'Flex':<15} šŸ“ˆ {'Speedup':<15}")
+    print("šŸ”„" + "-" * 150 + "šŸ”„")
 
     all_results = []
 
@@ -955,7 +814,6 @@
         implementations = {
             'flash': ('flash_attention', results['flash_attention_status'], results['flash_attention_times']),
             'cuda': ('dynamic_mask_attention', results['dynamic_mask_attention_status'], results['dynamic_mask_attention_times']),
-            'no_topk': ('dynamic_mask_attention_no_topk', results['dynamic_mask_attention_no_topk_status'], results['dynamic_mask_attention_no_topk_times']),
             'triton': ('dynamic_mask_attention_triton', results['dynamic_mask_attention_triton_status'], results['dynamic_mask_attention_triton_times']),
             'flex': ('dynamic_mask_attention_flex', results['dynamic_mask_attention_flex_status'], results['dynamic_mask_attention_flex_times'])
         }
@@ -977,7 +835,7 @@
 
         speedup_strs = {}
         flash_avg = time_avgs.get('flash', float('inf'))
-        for impl_key in ['cuda', 'no_topk', 'triton', 'flex']:
+        for impl_key in ['cuda', 'triton', 'flex']:
             impl_avg = time_avgs.get(impl_key, float('inf'))
             if flash_avg != float('inf') and impl_avg != float('inf') and impl_avg > 0:
                 speedup = flash_avg / impl_avg
@@ -1015,14 +873,13 @@
 
         speedup_summary = f"{best_impl}:{best_speedup}" if best_speedup != "N/A" else "N/A"
 
-        print(f"{icons} {config_short:<48} {time_strs['flash']:<12} {time_strs['cuda']:<12} {time_strs['no_topk']:<14} {time_strs['triton']:<12} {time_strs['flex']:<18} {speedup_summary:<15}")
+        print(f"{icons} {config_short:<48} {time_strs['flash']:<12} {time_strs['cuda']:<12} {time_strs['triton']:<12} {time_strs['flex']:<18} {speedup_summary:<15}")
 
-    print("šŸ”„" + "-" * 160 + "šŸ”„")
+    print("šŸ”„" + "-" * 150 + "šŸ”„")
 
     # Summary statistics
     implementation_speedups = {
         'cuda': [],
-        'no_topk': [],
         'triton': [],
         'flex': []
     }
@@ -1071,13 +928,6 @@
             print(f" {icon} {impl_name:10} vs Flash - Avg: {avg_speedup:.2f}x (Best: {max_speedup:.2f}x, Worst: {min_speedup:.2f}x)")
         else:
             print(f" āŒ {impl_key.replace('_', '-').upper():10} vs Flash - No successful runs")
-
-    # Calculate overhead comparison
-    if implementation_speedups['cuda'] and implementation_speedups['no_topk']:
-        avg_cuda = np.mean(implementation_speedups['cuda'])
-        avg_no_topk = np.mean(implementation_speedups['no_topk'])
-        topk_overhead = ((avg_no_topk - avg_cuda) / avg_no_topk * 100) if avg_no_topk > 0 else 0
-        print(f" šŸ’” TopK overhead: ~{topk_overhead:.1f}% performance impact")
 
 
 def main():
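
# A minimal sketch (not part of the patch) of the SDPA behaviour the flash_attention_cuda
# hunk relies on: passing an explicit boolean attn_mask (True = attend) instead of
# is_causal=True gives the same result when query_len == key_len. Shapes, dtype, and head
# counts below are illustrative assumptions, not taken from the benchmark; enable_gqa
# requires a PyTorch release that supports it (2.5 or newer).
import torch
import torch.nn.functional as F

q = torch.randn(1, 2, 64, 32)   # [batch, num_heads, query_len, head_dim]
k = torch.randn(1, 1, 64, 32)   # [batch, num_kv_heads, key_len, head_dim]
v = torch.randn(1, 1, 64, 32)   # [batch, num_kv_heads, key_len, head_dim]

# Lower-triangular boolean mask: True means the position may be attended to.
causal_mask = torch.tril(torch.ones(64, 64, dtype=torch.bool))

out_mask = F.scaled_dot_product_attention(q, k, v, attn_mask=causal_mask, enable_gqa=True)
out_flag = F.scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=True)

# Identical masking for square (query_len == key_len) shapes; the explicit mask variant
# also covers the query_len != key_len decoding case that is_causal does not.
torch.testing.assert_close(out_mask, out_flag)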