diff --git a/benchmarks/benchmark_forward_performance.py b/benchmarks/benchmark_forward_performance.py
index 8dbebc7..f1b3f88 100644
--- a/benchmarks/benchmark_forward_performance.py
+++ b/benchmarks/benchmark_forward_performance.py
@@ -213,6 +213,82 @@ def dynamic_mask_attention_cuda(
         return "OOM"
 
 
+def dynamic_mask_attention_cuda_no_topk(
+    query_states: torch.Tensor,
+    key_states: torch.Tensor,
+    value_states: torch.Tensor,
+    dt_proj: torch.Tensor,
+    A: torch.Tensor,
+    scaling: float,
+    causal_mask: torch.Tensor,
+    keep_window_size=2048,
+    is_causal=True,
+    return_softmax=False
+):
+    """
+    CUDA implementation of dynamic mask attention without topk computation.
+    This version skips the topk selection so the benchmark isolates raw kernel performance.
+
+    Args:
+        query_states: [batch_size, num_heads, query_len, head_dim]
+        key_states: [batch_size, num_kv_heads, key_len, head_dim]
+        value_states: [batch_size, num_kv_heads, key_len, head_dim]
+        dt_proj: [num_kv_heads, num_kv_heads * head_dim]
+        A: [num_kv_heads]
+        scaling: Attention scaling factor
+        causal_mask: Causal attention mask
+        keep_window_size: Accepted for signature parity but unused; the kernel is invoked with keep_window_size=0
+        is_causal: Whether to apply causal masking
+        return_softmax: Whether to return softmax weights
+
+    Returns:
+        attn_outputs: [batch_size, query_len, num_heads, head_dim]
+    """
+    # Calculate zero_hold_states
+    zero_hold_states = calculate_zero_hold_states(value_states, dt_proj, A, causal_mask)
+
+    # Create a simplified mask without topk computation
+    batch_size, num_heads, query_len, head_dim = query_states.shape
+    _, num_kv_heads, key_len, _ = key_states.shape
+    dtype = query_states.dtype
+    device = query_states.device
+
+    # Zero-filled additive mask keeps all positions active (no topk selection)
+    active_mask = torch.zeros(
+        (batch_size, num_kv_heads, query_len, key_len),
+        dtype=dtype,
+        device=device
+    )
+
+    # Ensure correct data types and memory layout
+    query_states = query_states.transpose(1, 2).contiguous()  # [batch, query_len, num_heads, head_dim]
+    key_states = key_states.transpose(1, 2).contiguous()  # [batch, key_len, num_kv_heads, head_dim]
+    value_states = value_states.transpose(1, 2).contiguous()  # [batch, key_len, num_kv_heads, head_dim]
+    zero_hold_states = zero_hold_states[:, :, None, :].expand(
+        -1, -1, query_states.shape[1], -1
+    ).contiguous()  # [batch, num_kv_heads, query_len, key_len]
+    active_mask = active_mask.contiguous()  # [batch, num_kv_heads, query_len, key_len]
+
+    try:
+        result = apply_dynamic_mask_attention(
+            query_states=query_states,
+            key_states=key_states,
+            value_states=value_states,
+            zoh_states=zero_hold_states,
+            active_mask=active_mask,
+            scale=scaling,
+            keep_window_size=0,
+            is_causal=is_causal,
+            return_softmax=return_softmax
+        )
+
+        # The kernel returns a tuple; the attention output is its first element
+        attn_outputs = result[0]
+        return attn_outputs
+    except torch.cuda.OutOfMemoryError:
+        return "OOM"
+
+
 def measure_memory_usage():
     """
     Measure current GPU memory usage.
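Review note: a minimal smoke test for the new no-topk path might look like the sketch below. It assumes this module is importable as `benchmark_forward_performance`, that a CUDA build of `apply_dynamic_mask_attention` is installed, and it guesses bfloat16 inputs and a boolean causal mask with a decode-style offset; none of those details are pinned down by the diff itself.

```python
# Hypothetical smoke test for dynamic_mask_attention_cuda_no_topk (sketch only).
# Assumes a CUDA device; dtype and causal-mask layout are guesses, not taken from the diff.
import math
import torch
from benchmark_forward_performance import dynamic_mask_attention_cuda_no_topk

batch, heads, kv_heads, q_len, k_len, dim = 1, 2, 1, 64, 256, 32
dev, dt = "cuda", torch.bfloat16

q = torch.randn(batch, heads, q_len, dim, device=dev, dtype=dt)
k = torch.randn(batch, kv_heads, k_len, dim, device=dev, dtype=dt)
v = torch.randn(batch, kv_heads, k_len, dim, device=dev, dtype=dt)
dt_proj = torch.randn(kv_heads, kv_heads * dim, device=dev, dtype=dt)
A = torch.randn(kv_heads, device=dev, dtype=dt)
# Causal mask for a short query attending over a longer cache (offset = k_len - q_len)
causal_mask = torch.tril(torch.ones(q_len, k_len, device=dev, dtype=torch.bool),
                         diagonal=k_len - q_len)

out = dynamic_mask_attention_cuda_no_topk(
    q, k, v, dt_proj, A,
    scaling=1.0 / math.sqrt(dim),
    causal_mask=causal_mask,
    keep_window_size=2048,  # accepted for API parity, unused in this variant
)
# The function returns the string "OOM" on CUDA out-of-memory, a tensor otherwise
print("OOM" if isinstance(out, str) else tuple(out.shape))  # expect (1, 64, 2, 32)
```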
@@ -280,10 +356,13 @@ def benchmark_attention_performance(config, num_runs=5, warmup_runs=2):
         'config': config,
         'flash_attention_times': [],
         'dynamic_mask_attention_times': [],
+        'dynamic_mask_attention_no_topk_times': [],
         'flash_attention_memory': 0,
         'dynamic_mask_attention_memory': 0,
+        'dynamic_mask_attention_no_topk_memory': 0,
         'flash_attention_status': 'success',
-        'dynamic_mask_attention_status': 'success'
+        'dynamic_mask_attention_status': 'success',
+        'dynamic_mask_attention_no_topk_status': 'success'
     }
 
     # Benchmark Flash Attention
@@ -366,6 +445,47 @@ def benchmark_attention_performance(config, num_runs=5, warmup_runs=2):
         mem_after = measure_memory_usage()
         results['dynamic_mask_attention_memory'] = mem_after[0] - mem_before[0]
 
+    # Benchmark Dynamic Mask Attention (No TopK)
+    gc.collect()
+    torch.cuda.empty_cache()
+
+    # Warmup runs
+    for _ in range(warmup_runs):
+        result = dynamic_mask_attention_cuda_no_topk(
+            query_states, key_states, value_states,
+            dt_proj, A, scaling, causal_mask,
+            keep_window_size, True
+        )
+        if result == "OOM":
+            results['dynamic_mask_attention_no_topk_status'] = 'OOM'
+            break
+    torch.cuda.synchronize()
+
+    if results['dynamic_mask_attention_no_topk_status'] == 'success':
+        # Measure memory before benchmark
+        mem_before = measure_memory_usage()
+
+        # Actual benchmark runs
+        for _ in range(num_runs):
+            start_time = time.time()
+            result = dynamic_mask_attention_cuda_no_topk(
+                query_states, key_states, value_states,
+                dt_proj, A, scaling, causal_mask,
+                keep_window_size, True
+            )
+            torch.cuda.synchronize()
+            end_time = time.time()
+
+            if result == "OOM":
+                results['dynamic_mask_attention_no_topk_status'] = 'OOM'
+                break
+
+            results['dynamic_mask_attention_no_topk_times'].append((end_time - start_time) * 1000)  # ms
+
+        # Measure memory after
+        mem_after = measure_memory_usage()
+        results['dynamic_mask_attention_no_topk_memory'] = mem_after[0] - mem_before[0]
+
     return results
 
 
@@ -386,6 +506,16 @@ def run_performance_benchmark():
        (1, 2, 1, 8192, 8192, 32),
        (1, 2, 1, 16384, 16384, 32),
        (1, 2, 1, 32768, 32768, 32),
+
+        # Inference (short query over a long key/value cache)
+        (1, 2, 1, 64, 256, 32),
+        (1, 2, 1, 64, 512, 32),
+        (1, 2, 1, 64, 1024, 32),
+        (1, 2, 1, 64, 2048, 32),
+        (1, 2, 1, 64, 4096, 32),
+        (1, 2, 1, 64, 8192, 32),
+        (1, 2, 1, 64, 16384, 32),
+        (1, 2, 1, 64, 32768, 32),
 
         # Vary batch size
         (1, 2, 1, 1024, 1024, 32),
@@ -408,8 +538,8 @@
     num_runs = 3  # Run 3 times and take average
 
     print(f"\nšŸ“Š Benchmark Results (averaged over {num_runs} runs):")
-    print(f"šŸ”§ {'Configuration':<45}⚔ {'Flash Attn (ms)':<18}šŸš€ {'DMA (ms)':<17}šŸ“ˆ {'Speedup':<12}šŸ’¾ {'Memory (MB)':<12}")
-    print("šŸ”„" + "-" * 117 + "šŸ”„")
+    print(f"šŸ”§ {'Configuration':<42} ⚔ {'Flash (ms)':<14} šŸš€ {'DMA (ms)':<20} šŸš€ {'DMA-NoTopK (ms)':<20} šŸ“ˆ {'Speedup':<18} šŸ“ˆ {'NoTopK-Speedup':<20} šŸ’¾ {'Memory (MB)':<12}")
+    print("šŸ”„" + "-" * 155 + "šŸ”„")
 
     all_results = []
 
@@ -433,31 +563,46 @@
         else:
             dma_time_str = results['dynamic_mask_attention_status']
             dma_avg = float('inf')
+
+        if results['dynamic_mask_attention_no_topk_status'] == 'success' and results['dynamic_mask_attention_no_topk_times']:
+            dma_nt_avg = sum(results['dynamic_mask_attention_no_topk_times']) / len(results['dynamic_mask_attention_no_topk_times'])
+            dma_nt_time_str = f"{dma_nt_avg:.2f}"
+        else:
+            dma_nt_time_str = results['dynamic_mask_attention_no_topk_status']
+            dma_nt_avg = float('inf')
 
-        # Calculate speedup
+        # Calculate speedups
         if flash_avg != float('inf') and dma_avg != float('inf') and dma_avg > 0:
             speedup = flash_avg / dma_avg
             speedup_str = f"{speedup:.2f}x"
         else:
             speedup_str = "N/A"
+
+        if flash_avg != float('inf') and dma_nt_avg != float('inf') and dma_nt_avg > 0:
+            kernel_speedup = flash_avg / dma_nt_avg
+            kernel_speedup_str = f"{kernel_speedup:.2f}x"
+        else:
+            kernel_speedup_str = "N/A"
 
         # Memory usage
         mem_diff = results['dynamic_mask_attention_memory'] - results['flash_attention_memory']
-        mem_str = f"{mem_diff:+.1f}"
+        mem_str = f"{mem_diff:+.0f}"
 
         # Format output
-        config_short = f"b={batch_size}, h={num_heads}, kv={num_kv_heads}, q={query_len}, k={key_len}, d={head_dim}"
+        config_short = f"b={batch_size},h={num_heads},kv={num_kv_heads},q={query_len},k={key_len},d={head_dim}"
 
         # Add status icons
         flash_icon = "āœ…" if results['flash_attention_status'] == 'success' else "šŸ’„"
         dma_icon = "āœ…" if results['dynamic_mask_attention_status'] == 'success' else "šŸ’„"
+        dma_nt_icon = "āœ…" if results['dynamic_mask_attention_no_topk_status'] == 'success' else "šŸ’„"
 
-        print(f"{flash_icon}{dma_icon} {config_short:<47}{flash_time_str:<20}{dma_time_str:<20}{speedup_str:<15}{mem_str:<15}")
+        print(f"{flash_icon}{dma_icon}{dma_nt_icon} {config_short:<42} {flash_time_str:<14} {dma_time_str:<20} {dma_nt_time_str:<20} {speedup_str:<18} {kernel_speedup_str:<20} {mem_str:<12}")
 
-    print("šŸ”„" + "-" * 117 + "šŸ”„")
+    print("šŸ”„" + "-" * 155 + "šŸ”„")
 
     # Summary statistics
     speedups = []
+    kernel_speedups = []
     for results in all_results:
         if (results['flash_attention_status'] == 'success' and
             results['dynamic_mask_attention_status'] == 'success' and
@@ -468,15 +613,31 @@
 
             if dma_avg > 0:
                 speedups.append(flash_avg / dma_avg)
+
+        if (results['flash_attention_status'] == 'success' and
+            results['dynamic_mask_attention_no_topk_status'] == 'success' and
+            results['flash_attention_times'] and results['dynamic_mask_attention_no_topk_times']):
+
+            flash_avg = sum(results['flash_attention_times']) / len(results['flash_attention_times'])
+            dma_nt_avg = sum(results['dynamic_mask_attention_no_topk_times']) / len(results['dynamic_mask_attention_no_topk_times'])
+
+            if dma_nt_avg > 0:
+                kernel_speedups.append(flash_avg / dma_nt_avg)
 
+    print(f"\nšŸ† Summary:")
     if speedups:
         avg_speedup = np.mean(speedups)
-        print(f"\nšŸ† Summary:")
         speedup_icon = "šŸš€" if avg_speedup > 1.5 else "šŸ“ˆ" if avg_speedup > 1.0 else "😐"
-        print(f"   {speedup_icon} Average speedup: {avg_speedup:.2f}x")
-        print(f"   ⭐ Best speedup: {np.max(speedups):.2f}x")
-        print(f"   šŸ“‰ Worst speedup: {np.min(speedups):.2f}x")
-        print(f"   šŸ“Š Speedup std: {np.std(speedups):.2f}")
+        print(f"   {speedup_icon} DMA vs Flash - Average speedup: {avg_speedup:.2f}x (Best: {np.max(speedups):.2f}x, Worst: {np.min(speedups):.2f}x)")
+
+    if kernel_speedups:
+        avg_kernel_speedup = np.mean(kernel_speedups)
+        kernel_icon = "šŸ”„" if avg_kernel_speedup > 2.0 else "šŸš€" if avg_kernel_speedup > 1.5 else "šŸ“ˆ" if avg_kernel_speedup > 1.0 else "😐"
+        print(f"   {kernel_icon} DMA-NoTopK vs Flash - Average kernel speedup: {avg_kernel_speedup:.2f}x (Best: {np.max(kernel_speedups):.2f}x, Worst: {np.min(kernel_speedups):.2f}x)")
+        if speedups:
+            # Share of the no-topk kernel's advantage that the topk step gives back
+            topk_overhead = (avg_kernel_speedup - avg_speedup) / avg_kernel_speedup * 100
+            print(f"   šŸ’” TopK overhead: ~{topk_overhead:.1f}% performance impact")
 
 
 def main():
@@ -518,4 +679,4 @@
 
 
 if __name__ == "__main__":
-    main()
\ No newline at end of file
+    main()
\ No newline at end of file
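Review note: the "TopK overhead" figure printed in the summary is the relative gap between the two average speedups, i.e. the share of the no-topk kernel's advantage that the topk step gives back. A self-contained illustration of the arithmetic (the numbers are made up):

```python
# Illustration of the summary arithmetic with made-up numbers.
import numpy as np

speedups = [1.10, 1.25, 0.95]          # DMA (with topk) vs Flash, per config
kernel_speedups = [1.60, 1.80, 1.40]   # DMA without topk vs Flash, per config

avg_speedup = np.mean(speedups)                # 1.10
avg_kernel_speedup = np.mean(kernel_speedups)  # 1.60
topk_overhead = (avg_kernel_speedup - avg_speedup) / avg_kernel_speedup * 100
print(f"TopK overhead: ~{topk_overhead:.1f}% performance impact")  # ~31.3%
```

Since this is a ratio of averages taken across heterogeneous configurations, it is a rough aggregate indicator of topk cost rather than a per-config measurement.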