Adds no-topk variant for kernel performance analysis #20
```diff
@@ -213,6 +213,82 @@ def dynamic_mask_attention_cuda(
         return "OOM"
 
 
+def dynamic_mask_attention_cuda_no_topk(
+    query_states: torch.Tensor,
+    key_states: torch.Tensor,
+    value_states: torch.Tensor,
+    dt_proj: torch.Tensor,
+    A: torch.Tensor,
+    scaling: float,
+    causal_mask: torch.Tensor,
+    keep_window_size=2048,
+    is_causal=True,
+    return_softmax=False
+):
+    """
+    CUDA implementation of dynamic mask attention without topk computation.
+    This version skips the topk calculation for more accurate kernel performance testing.
+
+    Args:
+        query_states: [batch_size, num_heads, query_len, head_dim]
+        key_states: [batch_size, num_kv_heads, key_len, head_dim]
+        value_states: [batch_size, num_kv_heads, key_len, head_dim]
+        dt_proj: [num_kv_heads, num_kv_heads * head_dim]
+        A: [num_kv_heads]
+        scaling: Attention scaling factor
+        causal_mask: Causal attention mask
+        keep_window_size: Number of tokens to keep in attention window
+        is_causal: Whether to apply causal masking
+        return_softmax: Whether to return softmax weights
+
+    Returns:
+        attn_outputs: [batch_size, query_len, num_heads, head_dim]
+    """
+    # Calculate zero_hold_states
+    zero_hold_states = calculate_zero_hold_states(value_states, dt_proj, A, causal_mask)
+
+    # Create a simplified mask without topk computation
+    batch_size, num_heads, query_len, head_dim = query_states.shape
+    _, num_kv_heads, key_len, _ = key_states.shape
+    dtype = query_states.dtype
+    device = query_states.device
+
+    # Create full active mask (no topk selection)
+    active_mask = torch.zeros(
+        (batch_size, num_kv_heads, query_len, key_len),
+        dtype=dtype,
+        device=device
+    )
+
+    # Ensure correct data types and memory layout
+    query_states = query_states.transpose(1, 2).contiguous()    # [batch, query_len, num_heads, head_dim]
+    key_states = key_states.transpose(1, 2).contiguous()        # [batch, key_len, num_kv_heads, head_dim]
+    value_states = value_states.transpose(1, 2).contiguous()    # [batch, key_len, num_kv_heads, head_dim]
+    zero_hold_states = zero_hold_states[:, :, None, :].expand(
+        -1, -1, query_states.shape[1], -1
+    ).contiguous()                                               # [batch, num_kv_heads, query_len, key_len]
+    active_mask = active_mask.contiguous()                       # [batch, num_kv_heads, query_len, key_len]
+
+    try:
+        result = apply_dynamic_mask_attention(
+            query_states=query_states,
+            key_states=key_states,
+            value_states=value_states,
+            zoh_states=zero_hold_states,
+            active_mask=active_mask,
+            scale=scaling,
+            keep_window_size=0,
+            is_causal=is_causal,
+            return_softmax=return_softmax
+        )
+
+        # Convert result back to original data type
+        attn_outputs = result[0]
+        return attn_outputs
+    except torch.cuda.OutOfMemoryError:
+        return "OOM"
+
+
 def measure_memory_usage():
     """
     Measure current GPU memory usage.
```
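A minimal call sketch for exercising the new no-topk path on its own. Everything here is illustrative rather than taken from the PR: the tensor shapes, dtype, `causal_mask` layout, and the `head_dim ** -0.5` scaling are assumptions, and the function relies on the repo helpers (`calculate_zero_hold_states`, `apply_dynamic_mask_attention`) already imported by this benchmark script.

```python
# Illustrative only: shapes, dtype, mask layout, and scaling are assumptions.
import torch

batch, num_heads, num_kv_heads, q_len, k_len, head_dim = 1, 2, 1, 64, 256, 32
device, dtype = "cuda", torch.bfloat16

query_states = torch.randn(batch, num_heads, q_len, head_dim, device=device, dtype=dtype)
key_states = torch.randn(batch, num_kv_heads, k_len, head_dim, device=device, dtype=dtype)
value_states = torch.randn(batch, num_kv_heads, k_len, head_dim, device=device, dtype=dtype)
dt_proj = torch.randn(num_kv_heads, num_kv_heads * head_dim, device=device, dtype=dtype)
A = torch.randn(num_kv_heads, device=device, dtype=dtype)
causal_mask = torch.zeros(batch, 1, q_len, k_len, device=device, dtype=dtype)  # assumed layout

out = dynamic_mask_attention_cuda_no_topk(
    query_states, key_states, value_states,
    dt_proj, A, scaling=head_dim ** -0.5, causal_mask=causal_mask,
)
print("OOM" if isinstance(out, str) else out.shape)  # expect [batch, q_len, num_heads, head_dim]
```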
```diff
@@ -280,10 +356,13 @@ def benchmark_attention_performance(config, num_runs=5, warmup_runs=2):
         'config': config,
         'flash_attention_times': [],
         'dynamic_mask_attention_times': [],
+        'dynamic_mask_attention_no_topk_times': [],
         'flash_attention_memory': 0,
         'dynamic_mask_attention_memory': 0,
+        'dynamic_mask_attention_no_topk_memory': 0,
         'flash_attention_status': 'success',
-        'dynamic_mask_attention_status': 'success'
+        'dynamic_mask_attention_status': 'success',
+        'dynamic_mask_attention_no_topk_status': 'success'
     }
 
     # Benchmark Flash Attention
```
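The results dictionary now carries the same times/memory/status triple for a third implementation. One way to keep the variants in sync as more are added would be to generate the keys from a single list; this is only a sketch, not part of the diff, and `config` is the benchmark configuration passed into the enclosing function.

```python
# Sketch only (not in this PR): derive the per-implementation keys from one list
# so a future fourth variant cannot leave a key triple out of sync.
IMPLEMENTATIONS = (
    "flash_attention",
    "dynamic_mask_attention",
    "dynamic_mask_attention_no_topk",
)

results = {"config": config}
for name in IMPLEMENTATIONS:
    results[f"{name}_times"] = []      # per-run latencies in ms
    results[f"{name}_memory"] = 0      # memory delta reported by measure_memory_usage()
    results[f"{name}_status"] = "success"
```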
```diff
@@ -366,6 +445,47 @@ def benchmark_attention_performance(config, num_runs=5, warmup_runs=2):
         mem_after = measure_memory_usage()
         results['dynamic_mask_attention_memory'] = mem_after[0] - mem_before[0]
 
+    # Benchmark Dynamic Mask Attention (No TopK)
+    gc.collect()
+    torch.cuda.empty_cache()
+
+    # Warmup runs
+    for _ in range(warmup_runs):
+        result = dynamic_mask_attention_cuda_no_topk(
+            query_states, key_states, value_states,
+            dt_proj, A, scaling, causal_mask,
+            keep_window_size, True
+        )
+        if result == "OOM":
+            results['dynamic_mask_attention_no_topk_status'] = 'OOM'
+            break
+        torch.cuda.synchronize()
+
+    if results['dynamic_mask_attention_no_topk_status'] == 'success':
+        # Measure memory before benchmark
+        mem_before = measure_memory_usage()
+
+        # Actual benchmark runs
+        for _ in range(num_runs):
+            start_time = time.time()
+            result = dynamic_mask_attention_cuda_no_topk(
+                query_states, key_states, value_states,
+                dt_proj, A, scaling, causal_mask,
+                keep_window_size, True
+            )
+            torch.cuda.synchronize()
+            end_time = time.time()
+
+            if result == "OOM":
+                results['dynamic_mask_attention_no_topk_status'] = 'OOM'
+                break
+
+            results['dynamic_mask_attention_no_topk_times'].append((end_time - start_time) * 1000)  # ms
+
+        # Measure memory after
+        mem_after = measure_memory_usage()
+        results['dynamic_mask_attention_no_topk_memory'] = mem_after[0] - mem_before[0]
+
     return results
```
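Once the three timing lists are filled, the averaged numbers and the two speedup columns in the results table can be derived from them. A post-processing sketch follows; it assumes both speedups are reported relative to Flash Attention and that a 'success' status implies non-empty timing lists, which may differ from how the script actually formats the table.

```python
# Sketch only: turn the raw per-run timings into averaged latencies and the
# Speedup / Skip-All-Speedup columns, assuming Flash Attention is the baseline.
from statistics import mean

def summarize(results):
    names = ("flash_attention", "dynamic_mask_attention", "dynamic_mask_attention_no_topk")
    if any(results[f"{name}_status"] != "success" for name in names):
        return None  # at least one variant hit OOM; nothing to compare

    flash_ms = mean(results["flash_attention_times"])
    dma_ms = mean(results["dynamic_mask_attention_times"])
    skip_all_ms = mean(results["dynamic_mask_attention_no_topk_times"])
    return {
        "flash_ms": flash_ms,
        "dma_ms": dma_ms,
        "dma_skip_all_ms": skip_all_ms,
        "speedup": flash_ms / dma_ms,                # DMA vs. Flash Attention
        "skip_all_speedup": flash_ms / skip_all_ms,  # no-topk path vs. Flash Attention
    }
```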
```diff
@@ -386,6 +506,16 @@ def run_performance_benchmark():
         (1, 2, 1, 8192, 8192, 32),
         (1, 2, 1, 16384, 16384, 32),
         (1, 2, 1, 32768, 32768, 32),
+
+        # Inference
+        (1, 2, 1, 64, 256, 32),
+        (1, 2, 1, 64, 512, 32),
+        (1, 2, 1, 64, 1024, 32),
+        (1, 2, 1, 64, 2048, 32),
+        (1, 2, 1, 64, 4096, 32),
+        (1, 2, 1, 64, 8192, 32),
+        (1, 2, 1, 64, 16384, 32),
+        (1, 2, 1, 64, 32768, 32),
 
         # Vary batch size
         (1, 2, 1, 1024, 1024, 32),
```
```diff
@@ -408,8 +538,8 @@ def run_performance_benchmark():
     num_runs = 3  # Run 3 times and take average
 
     print(f"\n📊 Benchmark Results (averaged over {num_runs} runs):")
-    print(f"🔧 {'Configuration':<45}⚡ {'Flash Attn (ms)':<18}🚀 {'DMA (ms)':<17}📈 {'Speedup':<12}💾 {'Memory (MB)':<12}")
-    print("🔄" + "-" * 117 + "🔄")
+    print(f"🔧 {'Configuration':<42} ⚡ {'Flash (ms)':<12} 🚀 {'DMA (ms)':<12} 🚀 {'DMA-Skip-All (ms)':<22} 📈 {'Speedup':<12} 📈 {'Skip-All-Speedup':<20} 💾 {'Memory':<10}")
+    print("🔄" + "-" * 155 + "🔄")
```
Comment on lines +541 to +542

Suggested change:

```diff
-    print(f"🔧 {'Configuration':<42} ⚡ {'Flash (ms)':<12} 🚀 {'DMA (ms)':<12} 🚀 {'DMA-Skip-All (ms)':<22} 📈 {'Speedup':<12} 📈 {'Skip-All-Speedup':<20} 💾 {'Memory':<10}")
-    print("🔄" + "-" * 155 + "🔄")
+    print(f"🔧 {'Configuration':<{CONFIGURATION_WIDTH}} ⚡ {'Flash (ms)':<{FLASH_WIDTH}} 🚀 {'DMA (ms)':<{DMA_WIDTH}} 🚀 {'DMA-Skip-All (ms)':<{DMA_SKIP_ALL_WIDTH}} 📈 {'Speedup':<{SPEEDUP_WIDTH}} 📈 {'Skip-All-Speedup':<{SKIP_ALL_SPEEDUP_WIDTH}} 💾 {'Memory':<{MEMORY_WIDTH}}")
+    print("🔄" + "-" * (CONFIGURATION_WIDTH + FLASH_WIDTH + DMA_WIDTH + DMA_SKIP_ALL_WIDTH + SPEEDUP_WIDTH + SKIP_ALL_SPEEDUP_WIDTH + MEMORY_WIDTH + 20) + "🔄")
```
The new no-topk function largely duplicates logic from `dynamic_mask_attention_cuda`. Extract common setup (transposes, mask creation) into shared helpers to reduce duplication and simplify future maintenance.
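One shape the reviewer's suggestion could take: a shared helper that performs the layout changes and defaults the active mask to all-zeros (every position active) when no top-k mask is supplied. This is a sketch of the refactor, not code from the PR; the helper name and signature are hypothetical.

```python
# Hypothetical refactor sketch: shared preprocessing for both CUDA paths.
import torch

def _prepare_cuda_inputs(query_states, key_states, value_states,
                         zero_hold_states, active_mask=None):
    """Transpose to [batch, seq, heads, dim], expand zoh states over query_len,
    and default the active mask to all-zeros when none is given."""
    batch_size, _, query_len, _ = query_states.shape
    _, num_kv_heads, key_len, _ = key_states.shape

    if active_mask is None:
        active_mask = torch.zeros(
            (batch_size, num_kv_heads, query_len, key_len),
            dtype=query_states.dtype, device=query_states.device,
        )

    query_states = query_states.transpose(1, 2).contiguous()
    key_states = key_states.transpose(1, 2).contiguous()
    value_states = value_states.transpose(1, 2).contiguous()
    zero_hold_states = zero_hold_states[:, :, None, :].expand(
        -1, -1, query_len, -1
    ).contiguous()
    return (query_states, key_states, value_states,
            zero_hold_states, active_mask.contiguous())
```

Both `dynamic_mask_attention_cuda` and the no-topk variant could then call this helper, with the former passing in its top-k-derived mask and the latter passing none.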