From 69f0a88d95094f093a85b622d2e8faf146af4acc Mon Sep 17 00:00:00 2001
From: Loser Cheems
Date: Sat, 21 Jun 2025 11:38:40 +0800
Subject: [PATCH 1/2] Adds no-topk variant for kernel performance analysis

Introduces a new benchmark function that skips topk computation to
isolate pure kernel performance from selection overhead. This enables
more accurate measurement of the core attention mechanism by removing
the topk bottleneck.

Expands benchmark configurations to include inference scenarios with
varying key lengths and updates output formatting to display both
standard and no-topk performance metrics with speedup comparisons.

Provides insights into topk overhead impact on overall performance.
---
 benchmarks/benchmark_forward_performance.py | 184 ++++++++++++++++++--
 1 file changed, 171 insertions(+), 13 deletions(-)

diff --git a/benchmarks/benchmark_forward_performance.py b/benchmarks/benchmark_forward_performance.py
index 8dbebc7..f1f4d5d 100644
--- a/benchmarks/benchmark_forward_performance.py
+++ b/benchmarks/benchmark_forward_performance.py
@@ -213,6 +213,82 @@ def dynamic_mask_attention_cuda(
         return "OOM"
 
 
+def dynamic_mask_attention_cuda_no_topk(
+    query_states: torch.Tensor,
+    key_states: torch.Tensor,
+    value_states: torch.Tensor,
+    dt_proj: torch.Tensor,
+    A: torch.Tensor,
+    scaling: float,
+    causal_mask: torch.Tensor,
+    keep_window_size=2048,
+    is_causal=True,
+    return_softmax=False
+):
+    """
+    CUDA implementation of dynamic mask attention without topk computation.
+    This version skips the topk calculation for more accurate kernel performance testing.
+
+    Args:
+        query_states: [batch_size, num_heads, query_len, head_dim]
+        key_states: [batch_size, num_kv_heads, key_len, head_dim]
+        value_states: [batch_size, num_kv_heads, key_len, head_dim]
+        dt_proj: [num_kv_heads, num_kv_heads * head_dim]
+        A: [num_kv_heads]
+        scaling: Attention scaling factor
+        causal_mask: Causal attention mask
+        keep_window_size: Number of tokens to keep in attention window
+        is_causal: Whether to apply causal masking
+        return_softmax: Whether to return softmax weights
+
+    Returns:
+        attn_outputs: [batch_size, query_len, num_heads, head_dim]
+    """
+    # Calculate zero_hold_states
+    zero_hold_states = calculate_zero_hold_states(value_states, dt_proj, A, causal_mask)
+
+    # Create a simplified mask without topk computation
+    batch_size, num_heads, query_len, head_dim = query_states.shape
+    _, num_kv_heads, key_len, _ = key_states.shape
+    dtype = query_states.dtype
+    device = query_states.device
+
+    # Create full active mask (no topk selection)
+    active_mask = torch.zeros(
+        (batch_size, num_kv_heads, query_len, key_len),
+        dtype=dtype,
+        device=device
+    )
+
+    # Ensure correct data types and memory layout
+    query_states = query_states.transpose(1, 2).contiguous()    # [batch, query_len, num_heads, head_dim]
+    key_states = key_states.transpose(1, 2).contiguous()        # [batch, key_len, num_kv_heads, head_dim]
+    value_states = value_states.transpose(1, 2).contiguous()    # [batch, key_len, num_kv_heads, head_dim]
+    zero_hold_states = zero_hold_states[:, :, None, :].expand(
+        -1, -1, query_states.shape[1], -1
+    ).contiguous()  # [batch, num_kv_heads, query_len, key_len]
+    active_mask = active_mask.contiguous()  # [batch, num_kv_heads, query_len, key_len]
+
+    try:
+        result = apply_dynamic_mask_attention(
+            query_states=query_states,
+            key_states=key_states,
+            value_states=value_states,
+            zoh_states=zero_hold_states,
+            active_mask=active_mask,
+            scale=scaling,
+            keep_window_size=key_len,  # Use full key length since no topk
+            is_causal=is_causal,
+            return_softmax=return_softmax
+        )
+
+        # Convert result back to original data type
+        attn_outputs = result[0]
+        return attn_outputs
+    except torch.cuda.OutOfMemoryError:
+        return "OOM"
+
+
 def measure_memory_usage():
     """
     Measure current GPU memory usage.
@@ -280,10 +356,13 @@ def benchmark_attention_performance(config, num_runs=5, warmup_runs=2):
         'config': config,
         'flash_attention_times': [],
         'dynamic_mask_attention_times': [],
+        'dynamic_mask_attention_no_topk_times': [],
         'flash_attention_memory': 0,
         'dynamic_mask_attention_memory': 0,
+        'dynamic_mask_attention_no_topk_memory': 0,
         'flash_attention_status': 'success',
-        'dynamic_mask_attention_status': 'success'
+        'dynamic_mask_attention_status': 'success',
+        'dynamic_mask_attention_no_topk_status': 'success'
     }
 
     # Benchmark Flash Attention
@@ -366,6 +445,47 @@ def benchmark_attention_performance(config, num_runs=5, warmup_runs=2):
         mem_after = measure_memory_usage()
         results['dynamic_mask_attention_memory'] = mem_after[0] - mem_before[0]
 
+    # Benchmark Dynamic Mask Attention (No TopK)
+    gc.collect()
+    torch.cuda.empty_cache()
+
+    # Warmup runs
+    for _ in range(warmup_runs):
+        result = dynamic_mask_attention_cuda_no_topk(
+            query_states, key_states, value_states,
+            dt_proj, A, scaling, causal_mask,
+            keep_window_size, True
+        )
+        if result == "OOM":
+            results['dynamic_mask_attention_no_topk_status'] = 'OOM'
+            break
+    torch.cuda.synchronize()
+
+    if results['dynamic_mask_attention_no_topk_status'] == 'success':
+        # Measure memory before benchmark
+        mem_before = measure_memory_usage()
+
+        # Actual benchmark runs
+        for _ in range(num_runs):
+            start_time = time.time()
+            result = dynamic_mask_attention_cuda_no_topk(
+                query_states, key_states, value_states,
+                dt_proj, A, scaling, causal_mask,
+                keep_window_size, True
+            )
+            torch.cuda.synchronize()
+            end_time = time.time()
+
+            if result == "OOM":
+                results['dynamic_mask_attention_no_topk_status'] = 'OOM'
+                break
+
+            results['dynamic_mask_attention_no_topk_times'].append((end_time - start_time) * 1000)  # ms
+
+        # Measure memory after
+        mem_after = measure_memory_usage()
+        results['dynamic_mask_attention_no_topk_memory'] = mem_after[0] - mem_before[0]
+
     return results
 
 
@@ -386,6 +506,16 @@ def run_performance_benchmark():
         (1, 2, 1, 8192, 8192, 32),
         (1, 2, 1, 16384, 16384, 32),
         (1, 2, 1, 32768, 32768, 32),
+
+        # Inference
+        (1, 2, 1, 64, 256, 32),
+        (1, 2, 1, 64, 512, 32),
+        (1, 2, 1, 64, 1024, 32),
+        (1, 2, 1, 64, 2048, 32),
+        (1, 2, 1, 64, 4096, 32),
+        (1, 2, 1, 64, 8192, 32),
+        (1, 2, 1, 64, 16384, 32),
+        (1, 2, 1, 64, 32768, 32),
 
         # Vary batch size
         (1, 2, 1, 1024, 1024, 32),
@@ -408,8 +538,8 @@ def run_performance_benchmark():
     num_runs = 3  # Run 3 times and take average
 
     print(f"\nšŸ“Š Benchmark Results (averaged over {num_runs} runs):")
-    print(f"šŸ”§ {'Configuration':<45}⚔ {'Flash Attn (ms)':<18}šŸš€ {'DMA (ms)':<17}šŸ“ˆ {'Speedup':<12}šŸ’¾ {'Memory (MB)':<12}")
-    print("šŸ”„" + "-" * 117 + "šŸ”„")
+    print(f"šŸ”§ {'Configuration':<42} ⚔ {'Flash (ms)':<12} šŸš€ {'DMA (ms)':<12} šŸš€ {'DMA-Skip-All (ms)':<22} šŸ“ˆ {'Speedup':<12} šŸ“ˆ {'Skip-All-Speedup':<20} šŸ’¾ {'Memory':<10}")
+    print("šŸ”„" + "-" * 155 + "šŸ”„")
 
     all_results = []
 
@@ -433,31 +563,46 @@ def run_performance_benchmark():
         else:
             dma_time_str = results['dynamic_mask_attention_status']
             dma_avg = float('inf')
+
+        if results['dynamic_mask_attention_no_topk_status'] == 'success' and results['dynamic_mask_attention_no_topk_times']:
+            dma_nt_avg = sum(results['dynamic_mask_attention_no_topk_times']) / len(results['dynamic_mask_attention_no_topk_times'])
+            dma_nt_time_str = f"{dma_nt_avg:.2f}"
+        else:
+            dma_nt_time_str = results['dynamic_mask_attention_no_topk_status']
+            dma_nt_avg = float('inf')
 
-        # Calculate speedup
+        # Calculate speedups
         if flash_avg != float('inf') and dma_avg != float('inf') and dma_avg > 0:
             speedup = flash_avg / dma_avg
             speedup_str = f"{speedup:.2f}x"
         else:
             speedup_str = "N/A"
+
+        if flash_avg != float('inf') and dma_nt_avg != float('inf') and dma_nt_avg > 0:
+            kernel_speedup = flash_avg / dma_nt_avg
+            kernel_speedup_str = f"{kernel_speedup:.2f}x"
+        else:
+            kernel_speedup_str = "N/A"
 
         # Memory usage
         mem_diff = results['dynamic_mask_attention_memory'] - results['flash_attention_memory']
-        mem_str = f"{mem_diff:+.1f}"
+        mem_str = f"{mem_diff:+.0f}"
 
         # Format output
-        config_short = f"b={batch_size}, h={num_heads}, kv={num_kv_heads}, q={query_len}, k={key_len}, d={head_dim}"
+        config_short = f"b={batch_size},h={num_heads},kv={num_kv_heads},q={query_len},k={key_len},d={head_dim}"
 
         # Add status icons
         flash_icon = "āœ…" if results['flash_attention_status'] == 'success' else "šŸ’„"
        dma_icon = "āœ…" if results['dynamic_mask_attention_status'] == 'success' else "šŸ’„"
+        dma_nt_icon = "āœ…" if results['dynamic_mask_attention_no_topk_status'] == 'success' else "šŸ’„"
 
-        print(f"{flash_icon}{dma_icon} {config_short:<47}{flash_time_str:<20}{dma_time_str:<20}{speedup_str:<15}{mem_str:<15}")
+        print(f"{flash_icon}{dma_icon}{dma_nt_icon} {config_short:<42} {flash_time_str:<14} {dma_time_str:<20} {dma_nt_time_str:<20} {speedup_str:<18} {kernel_speedup_str:<20} {mem_str:<12}")
 
-    print("šŸ”„" + "-" * 117 + "šŸ”„")
+    print("šŸ”„" + "-" * 155 + "šŸ”„")
 
     # Summary statistics
     speedups = []
+    kernel_speedups = []
     for results in all_results:
         if (results['flash_attention_status'] == 'success' and
             results['dynamic_mask_attention_status'] == 'success' and
@@ -468,15 +613,28 @@ def run_performance_benchmark():
 
             if dma_avg > 0:
                 speedups.append(flash_avg / dma_avg)
+
+        if (results['flash_attention_status'] == 'success' and
+            results['dynamic_mask_attention_no_topk_status'] == 'success' and
+            results['flash_attention_times'] and results['dynamic_mask_attention_no_topk_times']):
+
+            flash_avg = sum(results['flash_attention_times']) / len(results['flash_attention_times'])
+            dma_nt_avg = sum(results['dynamic_mask_attention_no_topk_times']) / len(results['dynamic_mask_attention_no_topk_times'])
+
+            if dma_nt_avg > 0:
+                kernel_speedups.append(flash_avg / dma_nt_avg)
+    print(f"\nšŸ† Summary:")
     if speedups:
         avg_speedup = np.mean(speedups)
-        print(f"\nšŸ† Summary:")
         speedup_icon = "šŸš€" if avg_speedup > 1.5 else "šŸ“ˆ" if avg_speedup > 1.0 else "😐"
-        print(f"   {speedup_icon} Average speedup: {avg_speedup:.2f}x")
-        print(f"   ⭐ Best speedup: {np.max(speedups):.2f}x")
-        print(f"   šŸ“‰ Worst speedup: {np.min(speedups):.2f}x")
-        print(f"   šŸ“Š Speedup std: {np.std(speedups):.2f}")
+        print(f"   {speedup_icon} DMA vs Flash - Average speedup: {avg_speedup:.2f}x (Best: {np.max(speedups):.2f}x, Worst: {np.min(speedups):.2f}x)")
+
+    if kernel_speedups:
+        avg_kernel_speedup = np.mean(kernel_speedups)
+        kernel_icon = "šŸ”„" if avg_kernel_speedup > 2.0 else "šŸš€" if avg_kernel_speedup > 1.5 else "šŸ“ˆ" if avg_kernel_speedup > 1.0 else "😐"
+        print(f"   {kernel_icon} DMA-NoTopK vs Flash - Average kernel speedup: {avg_kernel_speedup:.2f}x (Best: {np.max(kernel_speedups):.2f}x, Worst: {np.min(kernel_speedups):.2f}x)")
+        print(f"   šŸ’” TopK overhead: ~{((np.mean(kernel_speedups) - np.mean(speedups) if speedups else 0) / np.mean(kernel_speedups) * 100) if kernel_speedups else 0:.1f}% performance impact")
 
 
 def main():

From 568d4b8fd420f9a9548be1f539885f94bee7da9a Mon Sep 17 00:00:00 2001
From: Loser Cheems
Date: Sat, 21 Jun 2025 11:49:35 +0800
Subject: [PATCH 2/2] Fixes window size parameter in dynamic mask attention

Changes keep_window_size from key_len to 0 when topk is disabled,
preventing potential memory issues or incorrect attention calculations.
Also removes trailing whitespace from main function call.
---
 benchmarks/benchmark_forward_performance.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/benchmarks/benchmark_forward_performance.py b/benchmarks/benchmark_forward_performance.py
index f1f4d5d..f1b3f88 100644
--- a/benchmarks/benchmark_forward_performance.py
+++ b/benchmarks/benchmark_forward_performance.py
@@ -277,7 +277,7 @@ def dynamic_mask_attention_cuda_no_topk(
             zoh_states=zero_hold_states,
             active_mask=active_mask,
             scale=scaling,
-            keep_window_size=key_len,  # Use full key length since no topk
+            keep_window_size=0,
             is_causal=is_causal,
             return_softmax=return_softmax
         )
@@ -676,4 +676,4 @@ def main():
 
 
 if __name__ == "__main__":
-    main() 
\ No newline at end of file
+    main()
\ No newline at end of file
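
Reviewer note: the "TopK overhead" figure printed by the patched summary is just the relative gap between the two averaged speedups (full DMA path vs. the no-topk kernel-only path). Below is a minimal, self-contained sketch of that calculation; the speedup lists are illustrative placeholders, not measured results.

    import numpy as np

    # Placeholder per-config speedups vs Flash Attention (illustrative only):
    # the full DMA path including topk selection, and the no-topk kernel-only path.
    speedups = [1.10, 1.25, 0.95]
    kernel_speedups = [1.60, 1.80, 1.40]

    avg_speedup = np.mean(speedups)                # DMA vs Flash
    avg_kernel_speedup = np.mean(kernel_speedups)  # DMA-NoTopK vs Flash

    # Same formula as the patched summary: the share of the kernel-only
    # speedup that is lost to topk selection, as a percentage.
    topk_overhead_pct = (avg_kernel_speedup - avg_speedup) / avg_kernel_speedup * 100
    print(f"TopK overhead: ~{topk_overhead_pct:.1f}% performance impact")

Assuming a CUDA-capable GPU, the real numbers come from running python benchmarks/benchmark_forward_performance.py and reading the Speedup and Skip-All-Speedup columns it prints.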