diff --git a/README.md b/README.md index 7e2423f..ad27db2 100644 --- a/README.md +++ b/README.md @@ -17,12 +17,16 @@ Flash-DMA is a high-performance attention implementation that integrates Flash A ## Key Features -- **Dynamic Sparse Attention**: Dynamically selects the most relevant keys for each query, reducing computational complexity from $O(N^2)$ to $O(N \cdot w)$ where $w \ll N$, supporting trainable sparse patterns. -- **Memory Efficiency**: Maintains Flash Attention's $O(N)$ memory complexity without instantiating the full attention matrix. -- **CUDA Deep Optimization**: Utilizes custom CUDA kernels with shared memory aliasing, pipelined prefetching, and block skipping for high throughput and low memory access overhead. -- **Extremely Long Context Support**: Handles 128K+ token sequences efficiently through dynamic mask windowing while preserving accuracy. -- **Learnable Bias**: Built-in learnable attention bias and its gradient path dbias, eliminating the need for additional external operators. -- **Fusion-Friendly Training**: Both forward and backward passes support block-level zero-mask skipping, further reducing computation in sparse scenarios. +### ๐ŸŽฏ Core Kernel Advantages +- **4D Mask & Bias Support**: Native support for `(batch_size, num_kv_heads, query_len, key_len)` shaped attention mask and attention bias tensors +- **Intelligent Computation Skipping**: Block-level automatic skipping mechanism based on masks, completely bypassing computation and memory access for zero-mask blocks +- **Complete Gradient Support**: Built-in full gradient computation path for attention bias, supporting end-to-end training + +### ๐Ÿš€ Performance & Efficiency +- **Dynamic Sparse Attention**: Dynamically selects the most relevant keys for each query, reducing computational complexity from $O(N^2)$ to $O(N \cdot w)$ where $w \ll N$, supporting trainable sparse structures +- **Memory Efficiency**: Maintains Flash Attention's $O(N)$ memory complexity without instantiating the full attention matrix +- **CUDA Deep Optimization**: Custom CUDA kernels with shared memory aliasing, pipelined prefetching, and block skipping for high throughput and low memory access overhead +- **Extremely Long Context Support**: Handles 128K+ token sequences efficiently through dynamic mask windowing while preserving accuracy ## Performance @@ -145,74 +149,104 @@ MAX_JOBS=4 pip install . --no-build-isolation ## Quick Start +### Basic Usage + ```python import torch from flash_dmattn import flash_dmattn_func_auto import math # Setup -batch_size, seq_len, num_heads, head_dim = 2, 4096, 16, 128 +batch_size, seq_len, num_heads, num_kv_heads, head_dim = 1, 256, 2, 1, 64 +keep_window_size = 128 device = torch.device('cuda') dtype = torch.bfloat16 +min_dtype = torch.finfo(dtype).min # dtype minimum value # Input tensors -query = torch.randn(batch_size, seq_len, num_heads, head_dim, - device=device, dtype=dtype) -key = torch.randn(batch_size, seq_len, num_heads, head_dim, - device=device, dtype=dtype) -value = torch.randn(batch_size, seq_len, num_heads, head_dim, - device=device, dtype=dtype) +query = torch.randn(batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype) +key = torch.randn(batch_size, seq_len, num_kv_heads, head_dim, device=device, dtype=dtype) +value = torch.randn(batch_size, seq_len, num_kv_heads, head_dim, device=device, dtype=dtype) # Create mask and bias for sparse attention -attention_bias = torch.randn(batch_size, num_heads, seq_len, seq_len, - device=device, dtype=dtype) -attention_mask = torch.ones(batch_size, num_heads, seq_len, seq_len, - device=device, dtype=dtype) +attention_mask = torch.ones(batch_size, num_kv_heads, seq_len, seq_len, device=device, dtype=dtype) +attention_bias = torch.randn(batch_size, num_kv_heads, seq_len, seq_len, device=device, dtype=dtype) -# Apply dynamic masking (keep top-k for long sequences) -keep_window_size = 2048 +# Generate sparse mask based on bias if seq_len > keep_window_size: # Select top-k most important keys for each query - topk_indices = torch.topk(attention_bias, keep_window_size, dim=-1, - largest=True, sorted=False).indices - attention_mask.zero_() - attention_mask.scatter(-1, topk_indices, 1.0) - -# Select backend + topk_values, topk_indices = torch.topk( + attention_bias, keep_window_size, dim=-1, + largest=True, sorted=False + ) + # Generate valid top-k mask + valid_topk = (topk_values != min_dtype).to(dtype) + attention_mask = torch.zeros_like(attention_bias, dtype=dtype, device=attention_bias.device) + attention_mask = attention_mask.scatter(-1, topk_indices, valid_topk) + attention_bias = attention_bias.masked_fill(attention_mask == 0.0, min_dtype) + +# Select FDMA kernel flash_dmattn_func = flash_dmattn_func_auto(backend="cuda") # Run Flash Dynamic Mask Attention output = flash_dmattn_func( - q=query, - k=key, - v=value, + query=query, + key=key, + value=value, attn_mask=attention_mask, attn_bias=attention_bias, is_causal=True, scale=1.0/math.sqrt(head_dim), ) -print(f"Output shape: {output.shape}") # [2, 4096, 16, 128] +print(f"Output shape: {output.shape}") # [1, 256, 2, 64] ``` +### Gradient Computation Example -## How It Works +```python +# Enable gradient computation +query.requires_grad_(True) +key.requires_grad_(True) +value.requires_grad_(True) +attention_bias.requires_grad_(True) -Flash-DMA combines two complementary techniques: +# Forward pass +output = flash_dmattn_func( + query=query, key=key, value=value, + attn_mask=attention_mask, + attn_bias=attention_bias, + is_causal=True, + scale=1.0/math.sqrt(head_dim) +) + +# Backward pass +loss = output.sum() +loss.backward() + +print(f"Query gradient shape: {query.grad.shape}") +print(f"Key gradient shape: {key.grad.shape}") +print(f"Value gradient shape: {value.grad.shape}") +print(f"Bias gradient shape: {attention_bias.grad.shape}") +``` -- **Dynamic Mask Attention**: Computes relevance scores for keys and selects only the most important ones for attention computation -- **Flash Attention**: Processes attention in blocks to reduce memory usage and HBM access -### The Integration Approach +## How It Works + +Flash-DMA integrates the efficient memory access patterns of Flash Attention with the sparse computation capabilities of dynamic mask attention to achieve an efficient attention mechanism. -The integration happens at the CUDA kernel level with several key components: +### Core Technology Integration -- **ZOH States**: Pre-computed importance scores for key selection -- **Active Masks**: Binary masks indicating which keys should be considered for each query -- **Sparse Skipping**: Custom CUDA kernels for efficient sparse attention computation -- **Block-Based Processing**: Maintains Flash Attention's block-based approach for memory efficiency +- **๐ŸŽฏ Native 4D Mask & Bias Support**: Kernels directly process `(batch_size, num_kv_heads, query_len, key_len)` shaped tensors +- **โšก Block-level Intelligent Skipping**: Unified OR-reduction skipping logic based on masks, completely avoiding computation and memory access for zero blocks +- **๐Ÿ”„ Complete Gradient Chain**: Built-in attention bias gradient computation (dbias) supporting end-to-end differentiable training -This creates a hybrid attention mechanism that achieves both memory and computational efficiency for long sequences. +### Key Optimization Strategies + +1. **Unified Skip Logic**: Forward and backward passes use the same block-level skip decisions +2. **Memory Access Optimization**: K/V data loaded only when `OR(mask_block) == true` +3. **Gradient Path Completeness**: dbias gradient computation fully fused in backward kernels +4. **Shared Memory Reuse**: sMask โ†” sP, sBias โ†” sdS intelligent aliasing ## Documentation @@ -229,7 +263,7 @@ This creates a hybrid attention mechanism that achieves both memory and computat ```bash # Clone with submodules -git clone --recursive https://github.com/SmallDoges/flash-dmattn.git +git clone https://github.com/SmallDoges/flash-dmattn.git cd flash-dmattn # Build in development mode @@ -297,8 +331,8 @@ Tests backward pass implementation and gradient equivalence. **Compilation Errors** ```bash # Ensure CUDA_HOME is set correctly -echo $CUDA_HOME # Linux/Mac -echo $env:CUDA_HOME # Windows PowerShell +echo $CUDA_HOME # Linux/Mac +echo $env:CUDA_HOME # Windows PowerShell # Check CUDA toolkit version nvcc --version diff --git a/README_zh.md b/README_zh.md index 47edcb5..f8264b6 100644 --- a/README_zh.md +++ b/README_zh.md @@ -17,12 +17,16 @@ Flash-DMA ๆ˜ฏไธ€ไธช้ซ˜ๆ€ง่ƒฝ็š„ๆณจๆ„ๅŠ›ๅฎž็Žฐ๏ผŒๅฐ† Flash Attention ็š„ๅ†…ๅญ˜ ## ไธป่ฆ็‰นๆ€ง -- **ๅŠจๆ€็จ€็–ๆณจๆ„ๅŠ›**: ไธบๆฏไธชๆŸฅ่ฏขๅŠจๆ€้€‰ๆ‹ฉๆœ€้‡่ฆ็š„้”ฎ๏ผŒๅฐ†่ฎก็ฎ—ๅคๆ‚ๅบฆไปŽ $O(N^2)$ ้™ไฝŽๅˆฐ $O(N \cdot w)$๏ผŒๅ…ถไธญ $w \ll N$๏ผŒๆ”ฏๆŒๅฏ่ฎญ็ปƒ็š„็จ€็–็ป“ๆž„ใ€‚ -- **ๅ†…ๅญ˜ๆ•ˆ็އ**: ไฟๆŒ Flash Attention ็š„ $O(N)$ ๅ†…ๅญ˜ๅคๆ‚ๅบฆ๏ผŒๆ— ้œ€ๅฎžไพ‹ๅŒ–ๅฎŒๆ•ด็š„ๆณจๆ„ๅŠ›็Ÿฉ้˜ตใ€‚ -- **CUDA ๆทฑๅบฆไผ˜ๅŒ–**๏ผšไฝฟ็”จ่‡ชๅฎšไน‰ CUDA Kernel, ๅซๅ…ฑไบซๅ†…ๅญ˜ๅˆซๅใ€ๆตๆฐด็บฟ้ข„ๅ–ใ€ๆŒ‰ๅ—่ทณ่ฟ‡, ๅฎž็Žฐ้ซ˜ๅžๅไธŽไฝŽ่ฎฟๅญ˜ๅผ€้”€ใ€‚ -- **่ถ…้•ฟไธŠไธ‹ๆ–‡ๆ”ฏๆŒ**๏ผš้€š่ฟ‡ๅŠจๆ€ๆŽฉ็ ็ช—ๅฃ่ฃๅ‰ช๏ผŒๅœจไฟๆŒ็ฒพๅบฆ็š„ๅ‰ๆไธ‹ๆ”ฏๆ’‘ 128K+ ไปค็‰Œ็บงๅˆซ็š„ไธŠไธ‹ๆ–‡ๅค„็†ใ€‚ -- **ๅฏๅญฆไน ๅ็ฝฎ**๏ผšๅ†…็ฝฎๅฏๅญฆไน  attention bias ๅŠๅ…ถๆขฏๅบฆๅๅ‘่ทฏๅพ„ dbias๏ผŒๆ— ้œ€้ขๅค–ๅค–้ƒจ็ฎ—ๅญใ€‚ -- **่žๅˆๅผ่ฎญ็ปƒๅ‹ๅฅฝ**๏ผšๆญฃๅ‘ไธŽๅๅ‘่ฟ‡็จ‹ๅ‡ๆ”ฏๆŒ block ็บงๅ…จ้›ถๆŽฉ็ ่ทณ่ฟ‡๏ผŒๅœจ็จ€็–ๅœบๆ™ฏ่ฟ›ไธ€ๆญฅ้™ไฝŽ่ฎก็ฎ—ๅผ€้”€ใ€‚ +### ๐ŸŽฏ ๆ ธๅฟƒๅ†…ๆ ธไผ˜ๅŠฟ +- **4D Mask & Bias ๆ”ฏๆŒ**: ๅŽŸ็”Ÿๆ”ฏๆŒ `(batch_size, num_kv_heads, query_len, key_len)` ๅฝข็Šถ็š„ attention_mask ๅ’Œ attention_bias ๅผ ้‡ +- **ๆ™บ่ƒฝ่ฎก็ฎ—่ทณ่ฟ‡**: ๅŸบไบŽ attention_mask ็š„ block-level ่‡ชๅŠจ่ทณ่ฟ‡ๆœบๅˆถ๏ผŒๅฎŒๅ…จ่ทณ่ฟ‡ๅ…จ้›ถ mask ๅŒบๅ—็š„่ฎก็ฎ—ๅ’Œๅ†…ๅญ˜่ฎฟ้—ฎ +- **ๅฎŒๆ•ดๆขฏๅบฆๆ”ฏๆŒ**: ๅ†…็ฝฎ attention_bias ็š„ๅฎŒๆ•ดๆขฏๅบฆ่ฎก็ฎ—่ทฏๅพ„๏ผŒๆ”ฏๆŒ็ซฏๅˆฐ็ซฏ่ฎญ็ปƒ + +### ๐Ÿš€ ๆ€ง่ƒฝไธŽๆ•ˆ็އ +- **ๅŠจๆ€็จ€็–ๆณจๆ„ๅŠ›**: ไธบๆฏไธชๆŸฅ่ฏขๅŠจๆ€้€‰ๆ‹ฉๆœ€้‡่ฆ็š„้”ฎ๏ผŒๅฐ†่ฎก็ฎ—ๅคๆ‚ๅบฆไปŽ $O(N^2)$ ้™ไฝŽๅˆฐ $O(N \cdot w)$๏ผŒๅ…ถไธญ $w \ll N$๏ผŒ ๆ”ฏๆŒๅฏ่ฎญ็ปƒ็š„็จ€็–็ป“ๆž„ +- **ๅ†…ๅญ˜ๆ•ˆ็އ**: ไฟๆŒ Flash Attention ็š„ $O(N)$ ๅ†…ๅญ˜ๅคๆ‚ๅบฆ๏ผŒๆ— ้œ€ๅฎžไพ‹ๅŒ–ๅฎŒๆ•ด็š„ๆณจๆ„ๅŠ›็Ÿฉ้˜ต +- **CUDA ๆทฑๅบฆไผ˜ๅŒ–**: ่‡ชๅฎšไน‰ CUDA ๅ†…ๆ ธ๏ผŒๅซๅ…ฑไบซๅ†…ๅญ˜ๅˆซๅใ€ๆตๆฐด็บฟ้ข„ๅ–ใ€ๆŒ‰ๅ—่ทณ่ฟ‡๏ผŒๅฎž็Žฐ้ซ˜ๅžๅไธŽไฝŽ่ฎฟๅญ˜ๅผ€้”€ +- **่ถ…้•ฟไธŠไธ‹ๆ–‡ๆ”ฏๆŒ**: ้€š่ฟ‡ๅŠจๆ€ๆŽฉ็ ็ช—ๅฃ่ฃๅ‰ช๏ผŒๅœจไฟๆŒ็ฒพๅบฆ็š„ๅ‰ๆไธ‹ๆ”ฏๆ’‘ 128K+ ไปค็‰Œ็บงๅˆซ็š„ไธŠไธ‹ๆ–‡ๅค„็† ## ๆ€ง่ƒฝ @@ -145,43 +149,46 @@ MAX_JOBS=4 pip install . --no-build-isolation ## ๅฟซ้€Ÿๅผ€ๅง‹ +### ๅŸบๆœฌ็”จๆณ• + ```python import torch from flash_dmattn import flash_dmattn_func_auto import math # ่ฎพ็ฝฎ -batch_size, seq_len, num_heads, head_dim = 2, 4096, 16, 128 +batch_size, seq_len, num_heads, num_kv_heads, head_dim = 1, 256, 2, 1, 64 +keep_window_size = 128 device = torch.device('cuda') dtype = torch.bfloat16 +min_dtype = torch.finfo(dtype).min # dtype ็š„ๆœ€ๅฐๅ€ผ # ่พ“ๅ…ฅๅผ ้‡ -query = torch.randn(batch_size, seq_len, num_heads, head_dim, - device=device, dtype=dtype) -key = torch.randn(batch_size, seq_len, num_heads, head_dim, - device=device, dtype=dtype) -value = torch.randn(batch_size, seq_len, num_heads, head_dim, - device=device, dtype=dtype) - -# ไธบ็จ€็–ๆณจๆ„ๅŠ›ๅˆ›ๅปบๆŽฉ็ ๅ’Œๅ็ฝฎ -attention_bias = torch.randn(batch_size, num_heads, seq_len, seq_len, - device=device, dtype=dtype) -attention_mask = torch.ones(batch_size, num_heads, seq_len, seq_len, - device=device, dtype=dtype) - -# ๅบ”็”จๅŠจๆ€ๆŽฉ็ ๏ผˆไธบ้•ฟๅบๅˆ—ไฟ็•™ top-k๏ผ‰ -keep_window_size = 2048 +query = torch.randn(batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype) +key = torch.randn(batch_size, seq_len, num_kv_heads, head_dim, device=device, dtype=dtype) +value = torch.randn(batch_size, seq_len, num_kv_heads, head_dim, device=device, dtype=dtype) + +# ไธบ็จ€็–ๆณจๆ„ๅŠ›ๅˆ›ๅปบ mask ๅ’Œ bias +attention_mask = torch.ones(batch_size, num_kv_heads, seq_len, seq_len, device=device, dtype=dtype) +attention_bias = torch.randn(batch_size, num_kv_heads, seq_len, seq_len, device=device, dtype=dtype) + +# ๅŸบไบŽ bias ็”Ÿๆˆ็จ€็– mask if seq_len > keep_window_size: # ไธบๆฏไธชๆŸฅ่ฏข้€‰ๆ‹ฉ top-k ๆœ€้‡่ฆ็š„้”ฎ - topk_indices = torch.topk(attention_bias, keep_window_size, dim=-1, - largest=True, sorted=False).indices - attention_mask.zero_() - attention_mask.scatter(-1, topk_indices, 1.0) - -# ้€‰ๆ‹ฉๅŽ็ซฏ + topk_values, topk_indices = torch.topk( + attention_bias, keep_window_size, dim=-1, + largest=True, sorted=False + ) + # ็”Ÿๆˆๆœ‰ๆ•ˆ็š„ top-k mask + valid_topk = (topk_values != min_dtype).to(dtype) + attention_mask = torch.zeros_like(attention_bias, dtype=dtype, device=attention_bias.device) + attention_mask = attention_mask.scatter(-1, topk_indices, valid_topk) + attention_bias = attention_bias.masked_fill(attention_mask == 0.0, min_dtype) + +# ้€‰ๆ‹ฉ FDMA ๅ†…ๆ ธ flash_dmattn_func = flash_dmattn_func_auto(backend="cuda") -# ่ฟ่กŒ Flash ๅŠจๆ€ๆŽฉ็ ๆณจๆ„ๅŠ› +# ่ฟ่กŒ FDMA output = flash_dmattn_func( query=query, key=key, @@ -192,27 +199,54 @@ output = flash_dmattn_func( scale=1.0/math.sqrt(head_dim), ) -print(f"่พ“ๅ‡บๅฝข็Šถ: {output.shape}") # [2, 4096, 16, 128] +print(f"่พ“ๅ‡บๅฝข็Šถ: {output.shape}") # [1, 256, 2, 64] ``` +### ๆขฏๅบฆ่ฎก็ฎ—็คบไพ‹ -## ๅทฅไฝœๅŽŸ็† +```python +# ๅผ€ๅฏๆขฏๅบฆ่ฎก็ฎ— +query.requires_grad_(True) +key.requires_grad_(True) +value.requires_grad_(True) +attention_bias.requires_grad_(True) -Flash-DMA ็ป“ๅˆไบ†ไธค็งไบ’่กฅ็š„ๆŠ€ๆœฏ๏ผš +# ๅ‰ๅ‘ไผ ๆ’ญ +output = flash_dmattn_func( + query=query, key=key, value=value, + attn_mask=attention_mask, + attn_bias=attention_bias, + is_causal=True, + scale=1.0/math.sqrt(head_dim) +) -- **ๅŠจๆ€ๆŽฉ็ ๆณจๆ„ๅŠ›**: ่ฎก็ฎ—้”ฎ็š„็›ธๅ…ณๆ€งๅˆ†ๆ•ฐ๏ผŒๅนถไป…้€‰ๆ‹ฉๆœ€้‡่ฆ็š„้”ฎ่ฟ›่กŒๆณจๆ„ๅŠ›่ฎก็ฎ— -- **Flash Attention**: ๅˆ†ๅ—ๅค„็†ๆณจๆ„ๅŠ›ไปฅๅ‡ๅฐ‘ๅ†…ๅญ˜ไฝฟ็”จๅ’Œ HBM ่ฎฟ้—ฎ +# ๅๅ‘ไผ ๆ’ญ +loss = output.sum() +loss.backward() -### ้›†ๆˆๆ–นๆณ• +print(f"Query ๆขฏๅบฆๅฝข็Šถ: {query.grad.shape}") +print(f"Key ๆขฏๅบฆๅฝข็Šถ: {key.grad.shape}") +print(f"Value ๆขฏๅบฆๅฝข็Šถ: {value.grad.shape}") +print(f"Bias ๆขฏๅบฆๅฝข็Šถ: {attention_bias.grad.shape}") +``` -้›†ๆˆๅ‘็”Ÿๅœจ CUDA ๅ†…ๆ ธๅฑ‚้ข๏ผŒๅ…ทๆœ‰ๅ‡ ไธชๅ…ณ้”ฎ็ป„ไปถ๏ผš -- **ZOH ็Šถๆ€**: ้ข„่ฎก็ฎ—็š„้”ฎ้€‰ๆ‹ฉ้‡่ฆๆ€งๅˆ†ๆ•ฐ -- **ๆดป่ทƒๆŽฉ็ **: ๆŒ‡็คบๆฏไธชๆŸฅ่ฏขๅบ”่€ƒ่™‘ๅ“ชไบ›้”ฎ็š„ไบŒ่ฟ›ๅˆถๆŽฉ็  -- **็จ€็–่ทณ่ฟ‡**: ้ซ˜ๆ•ˆ็จ€็–ๆณจๆ„ๅŠ›่ฎก็ฎ—็š„่‡ชๅฎšไน‰ CUDA ๅ†…ๆ ธ -- **ๅˆ†ๅ—ๅค„็†**: ไฟๆŒ Flash Attention ็š„ๅˆ†ๅ—ๆ–นๆณ•ไปฅๆ้ซ˜ๅ†…ๅญ˜ๆ•ˆ็އ +## ๅทฅไฝœๅŽŸ็† + +Flash-DMA ้€š่ฟ‡ๅฐ† Flash Attention ็š„้ซ˜ๆ•ˆๅ†…ๅญ˜่ฎฟ้—ฎๆจกๅผไธŽๅŠจๆ€ๆŽฉ็ ๆณจๆ„ๅŠ›็š„็จ€็–่ฎก็ฎ—่ƒฝๅŠ›็›ธ็ป“ๅˆ๏ผŒๅฎž็Žฐไบ†้ซ˜ๆ•ˆ็š„ๆณจๆ„ๅŠ›ๆœบๅˆถใ€‚ -่ฟ™ๅˆ›ๅปบไบ†ไธ€็งๆททๅˆๆณจๆ„ๅŠ›ๆœบๅˆถ๏ผŒไธบ้•ฟๅบๅˆ—ๅฎž็Žฐไบ†ๅ†…ๅญ˜ๅ’Œ่ฎก็ฎ—ๆ•ˆ็އใ€‚ +### ๆ ธๅฟƒๆŠ€ๆœฏ่žๅˆ + +- **๐ŸŽฏ 4D Mask & Bias ๅŽŸ็”Ÿๆ”ฏๆŒ**: ๅ†…ๆ ธ็›ดๆŽฅๅค„็† `(batch_size, num_kv_heads, query_len, key_len)` ๅฝข็Šถ็š„ๅผ ้‡ +- **โšก Block-level ๆ™บ่ƒฝ่ทณ่ฟ‡**: ๅŸบไบŽ mask ็š„็ปŸไธ€ OR-reduction ่ทณ่ฟ‡้€ป่พ‘๏ผŒๅฎŒๅ…จ้ฟๅ…ๅ…จ้›ถๅŒบๅ—็š„่ฎก็ฎ—ๅ’Œๅ†…ๅญ˜่ฎฟ้—ฎ +- **๐Ÿ”„ ๅฎŒๆ•ดๆขฏๅบฆ้“พ่ทฏ**: ๅ†…็ฝฎ attention bias ๆขฏๅบฆ่ฎก็ฎ—๏ผŒๆ”ฏๆŒ็ซฏๅˆฐ็ซฏๅฏๅพฎๅˆ†่ฎญ็ปƒ + +### ๅ…ณ้”ฎไผ˜ๅŒ–็ญ–็•ฅ + +1. **็ปŸไธ€่ทณ่ฟ‡้€ป่พ‘**: ๅ‰ๅ‘ๅ’Œๅๅ‘่ฟ‡็จ‹ไฝฟ็”จ็›ธๅŒ็š„ block-level ่ทณ่ฟ‡ๅ†ณ็ญ– +2. **ๅ†…ๅญ˜่ฎฟ้—ฎไผ˜ๅŒ–**: ๅชๆœ‰ๅฝ“ `OR(mask_block) == true` ๆ—ถๆ‰ๅŠ ่ฝฝ K/V ๆ•ฐๆฎ +3. **ๆขฏๅบฆ่ทฏๅพ„ๅฎŒๆ•ดๆ€ง**: dbias ๆขฏๅบฆ่ฎก็ฎ—ๅฎŒๅ…จ่žๅˆๅœจๅๅ‘ๅ†…ๆ ธไธญ +4. **ๅ…ฑไบซๅ†…ๅญ˜ๅค็”จ**: sMask โ†” sP, sBias โ†” sdS ๆ™บ่ƒฝๅˆซๅๅŒ– ## ๆ–‡ๆกฃ @@ -229,7 +263,7 @@ Flash-DMA ็ป“ๅˆไบ†ไธค็งไบ’่กฅ็š„ๆŠ€ๆœฏ๏ผš ```bash # ๅ…‹้š†ๅŒ…ๅซๅญๆจกๅ— -git clone --recursive https://github.com/SmallDoges/flash-dmattn.git +git clone https://github.com/SmallDoges/flash-dmattn.git cd flash-dmattn # ๅœจๅผ€ๅ‘ๆจกๅผไธ‹ๆž„ๅปบ @@ -296,8 +330,8 @@ python benchmarks/grad_equivalence.py **็ผ–่ฏ‘้”™่ฏฏ** ```bash # ็กฎไฟ CUDA_HOME ่ฎพ็ฝฎๆญฃ็กฎ -echo $CUDA_HOME # Linux/Mac -echo $env:CUDA_HOME # Windows PowerShell +echo $CUDA_HOME # Linux/Mac +echo $env:CUDA_HOME # Windows PowerShell # ๆฃ€ๆŸฅ CUDA ๅทฅๅ…ทๅŒ…็‰ˆๆœฌ nvcc --version diff --git a/docs/api_reference.md b/docs/api_reference.md index b1a1307..b99eea4 100644 --- a/docs/api_reference.md +++ b/docs/api_reference.md @@ -1,5 +1,6 @@ # Flash Dynamic Mask Attention API Reference + ## Overview Flash Dynamic Mask Attention is a high-performance attention implementation that combines the memory efficiency of Flash Attention with the sparse compute benefits of Dynamic Mask Attention. It supports CUDA, Triton, and Flex Attention backends and dynamic masking for very long sequences. @@ -7,17 +8,19 @@ Flash Dynamic Mask Attention is a high-performance attention implementation that Interfaces provided: - High-level: simple entry point with automatic backend selection - Backend-specific: direct access to CUDA, Triton, and Flex implementations -- Packed variants: optimized paths for QKV-packed and KV-packed inputs -- Variable length: support for batches with different sequence lengths +- Transformers Integration: seamless integration with HuggingFace Transformers models + ## Table of Contents 1. [Installation](#installation) 2. [High-Level Interface](#high-level-interface) 3. [Core Functions](#core-functions) -4. [Packed Variants](#packed-variants) -5. [Variable Length Functions](#variable-length-functions) -6. [Backend Selection](#backend-selection) +4. [Transformers Integration](#transformers-integration) +5. [Backend Selection](#backend-selection) +6. [Common Issues and Solutions](#common-issues-and-solutions) +7. [Summary](#summary) + ## Installation @@ -25,19 +28,19 @@ Interfaces provided: - Python: 3.8+ - PyTorch: 2.0.0+ with CUDA -- CUDA: 11.8+ -- NVIDIA GPU: Compute Capability 8.0+ -- Dependencies: `packaging`, `torch` +- CUDA: 11.8+ for CUDA backend +- NVIDIA GPU: Compute Capability 8.0+ for CUDA backend +- Optional: `triton` for Triton backend, `transformers` for Flex backend and integrations ### Install from Source ```bash git clone https://github.com/SmallDoges/flash-dmattn.git cd flash-dmattn -git submodule update --init --recursive -pip install -e . +MAX_JOBS=4 pip install . --no-build-isolation ``` + ## High-Level Interface ### Automatic Backend Selection @@ -45,21 +48,22 @@ pip install -e . Note: `flash_dmattn_func_auto` returns a callable attention function, not the attention output. ```python -from flash_dmattn import flash_dmattn_func_auto, get_available_backends +from flash_dmattn import get_available_backends, flash_dmattn_func_auto # Check available backends backends = get_available_backends() print(f"Available backends: {backends}") # Auto-select (priority: cuda > triton > flex) -attn = flash_dmattn_func_auto() -output = attn(q, k, v, attn_mask=attention_mask, attn_bias=attention_bias, is_causal=True, scale=None) +dmattn_func = flash_dmattn_func_auto() +output = dmattn_func(q, k, v, attn_mask=attention_mask, attn_bias=attention_bias, is_causal=True, scale=None) # Force a specific backend -attn = flash_dmattn_func_auto(backend="cuda") # or "triton", "flex" -output = attn(q, k, v, attn_mask=attention_mask, attn_bias=attention_bias, is_causal=True, scale=None) +dmattn_func = flash_dmattn_func_auto(backend="cuda") # or "triton", "flex" +output = dmattn_func(q, k, v, attn_mask=attention_mask, attn_bias=attention_bias, is_causal=True, scale=None) ``` + ## Core Functions ### flash_dmattn_func (CUDA backend) @@ -77,6 +81,7 @@ def flash_dmattn_func( is_causal: Optional[bool] = None, # causal mask softcap: Optional[float] = None, # CUDA-only deterministic: Optional[bool] = None, # CUDA-only + return_attn_probs: Optional[bool] = None, # CUDA-only, for testing ) -> torch.Tensor ``` @@ -89,75 +94,193 @@ def flash_dmattn_func( - attn_bias: (B, H, Q, K). Added to scores before softmax. None to disable - scale: score scaling; default 1/sqrt(D) - is_causal: apply lower-triangular mask -- softcap, deterministic: only effective on the CUDA backend; ignored on others +- softcap, deterministic, return_attn_probs: only effective on the CUDA backend; ignored on others #### Returns - output: (B, Q, H, D) -## Packed Variants (CUDA backend) - -### flash_dmattn_qkvpacked_func +### triton_dmattn_func (Triton backend) -Optimized function for QKV-packed input. +Triton-based implementation that provides good performance without requiring custom CUDA kernels. ```python -def flash_dmattn_qkvpacked_func( - qkv: torch.Tensor, # (batch, seqlen, 3, num_heads, head_dim) - attn_mask: Optional[torch.Tensor] = None, - attn_bias: Optional[torch.Tensor] = None, - scale: Optional[float] = None, - is_causal: Optional[bool] = None, - softcap: Optional[float] = None, # CUDA-only - deterministic: Optional[bool] = None, # CUDA-only +def triton_dmattn_func( + query: torch.Tensor, # (batch, seqlen_q, num_heads, head_dim) + key: torch.Tensor, # (batch, seqlen_k, num_heads, head_dim) + value: torch.Tensor, # (batch, seqlen_k, num_heads, head_dim) + attn_mask: Optional[torch.Tensor] = None, # (batch, num_heads, seqlen_q, seqlen_k) + attn_bias: Optional[torch.Tensor] = None, # (batch, num_heads, seqlen_q, seqlen_k) + is_causal: bool = False, # causal mask + scale: Optional[float] = None, # score scaling, defaults to 1/sqrt(head_dim) ) -> torch.Tensor ``` -### flash_dmattn_kvpacked_func +### flex_dmattn_func (Flex Attention backend) -Optimized function for KV-packed input. +Flex Attention-based implementation using PyTorch's native flex attention with dynamic masking support. ```python -def flash_dmattn_kvpacked_func( - q: torch.Tensor, # (batch, seqlen_q, num_heads, head_dim) - kv: torch.Tensor, # (batch, seqlen_k, 2, num_kv_heads, head_dim) - attn_mask: Optional[torch.Tensor] = None, - attn_bias: Optional[torch.Tensor] = None, - scale: Optional[float] = None, - is_causal: Optional[bool] = None, - softcap: Optional[float] = None, # CUDA-only - deterministic: Optional[bool] = None, # CUDA-only +def flex_dmattn_func( + query: torch.Tensor, # (batch, seqlen_q, num_heads, head_dim) + key: torch.Tensor, # (batch, seqlen_k, num_heads, head_dim) + value: torch.Tensor, # (batch, seqlen_k, num_heads, head_dim) + attn_mask: Optional[torch.Tensor] = None, # (batch, num_heads, seqlen_q, seqlen_k) + attn_bias: Optional[torch.Tensor] = None, # (batch, num_heads, seqlen_q, seqlen_k) + is_causal: Optional[bool] = None, # causal mask + scale: Optional[float] = None, # score scaling, defaults to 1/sqrt(head_dim) ) -> torch.Tensor ``` -## Variable Length Functions (CUDA backend) -### flash_dmattn_varlen_func +## Transformers Integration + +Integration function for HuggingFace Transformers models that provides seamless flash dynamic mask attention support. + +### flash_dynamic_mask_attention_forward -Variable length attention for batches with mixed sequence lengths. ```python -def flash_dmattn_varlen_func( - query: torch.Tensor, # (total_q, H, D) or (B, Q, H, D) - key: torch.Tensor, # same layout as query - value: torch.Tensor, # same layout as query - attn_mask: Optional[torch.Tensor] = None, # (B, H, Q, K) - attn_bias: Optional[torch.Tensor] = None, # (B, H, Q, K) - cu_seqlens_q: torch.Tensor = None, # (B+1,) - cu_seqlens_k: torch.Tensor = None, # (B+1,) - max_seqlen_q: int = None, - max_seqlen_k: int = None, - scale: Optional[float] = None, - is_causal: Optional[bool] = None, - softcap: Optional[float] = None, # CUDA-only - deterministic: Optional[bool] = None, # CUDA-only - block_table: Optional[torch.Tensor] = None, # experimental: paged attention -) -> torch.Tensor +from flash_dmattn.integrations.flash_dynamic_mask_attention import flash_dynamic_mask_attention_forward + +def flash_dynamic_mask_attention_forward( + module: torch.nn.Module, # The attention module + query: torch.Tensor, # (batch_size, num_heads, query_len, head_dim) + key: torch.Tensor, # (batch_size, num_kv_heads, key_len, head_dim) + value: torch.Tensor, # (batch_size, num_kv_heads, key_len, head_dim) + attention_mask: Optional[torch.Tensor], # (batch_size, num_kv_heads, query_len, key_len) + attention_bias: Optional[torch.Tensor], # (batch_size, num_kv_heads, query_len, key_len) + scaling: Optional[float] = None, # score scaling + softcap: Optional[float] = None, # softcap value + **kwargs, +) -> tuple[torch.Tensor, None] +``` + +#### Parameters + +- module: The attention module instance +- query: Query tensor with head-first layout (B, H, Q, D) +- key: Key tensor with head-first layout (B, H_kv, K, D) +- value: Value tensor with head-first layout (B, H_kv, K, D) +- attention_mask: Boolean attention mask +- attention_bias: Attention bias to add to scores +- scaling: Score scaling factor +- softcap: Softcap value for attention scores +- **kwargs: Additional arguments including: + - is_causal: Whether to apply causal mask + - keep_window_size: Size of window to keep + - layer_idx: Layer index for logging + - implementation: Implementation to use ("flash_dmattn" or None) + +#### Returns + +- tuple[torch.Tensor, None]: Output tensor (B, Q, H, D) and None for compatibility + +### Usage with Transformers + +```python +import torch +import torch.nn as nn +import torch.nn.functional as F +from typing import Optional, Callable, tuple +from transformers.cache_utils import Cache +from flash_dmattn.integrations.flash_dynamic_mask_attention import flash_dynamic_mask_attention_forward + +class DynamicMaskAttention(nn.Module): + def __init__(self, config, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.scaling = self.head_dim**-0.5 + self.attention_dropout = config.attention_dropout + self.keep_window_size = config.keep_window_size + self.is_causal = True + + self.q_proj = nn.Linear( + config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias + ) + self.k_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.v_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + # Dynamic mask for the QK^T attention weights matrix + self.A = nn.Parameter(torch.zeros(config.num_key_value_heads)) + self.dt_proj = nn.Linear( + config.num_key_value_heads * self.head_dim, config.num_key_value_heads, bias=config.attention_bias + ) + self.o_proj = nn.Linear( + config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias + ) + self.q_norm = DogeRMSNorm(self.head_dim, eps=config.rms_norm_eps) + self.k_norm = DogeRMSNorm(self.head_dim, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + past_key_value: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2) + key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # Sampling dt_states from value_states to generate attention bias + dt_states = self.dt_proj( + value_states.transpose(1, 2).reshape(value_states.shape[0], value_states.shape[-2], -1) + ) + dt_states = torch.exp(self.A * F.softplus(dt_states)).transpose(-1, -2) + attn_bias = dt_states[:, :, None, :].expand( + -1, -1, hidden_states.shape[1], -1 + ).to(hidden_states.dtype) # [batch_size, num_heads, query_len, key_len] + + # Choose attention implementation: fallback to eager if flash_dmattn is not available + attention_interface: Callable = eager_attention_forward + if flash_dynamic_mask_attention_forward is not None: + attention_interface = flash_dynamic_mask_attention_forward + + # Expand attention mask to match the expected shape + if attention_mask is not None: + attention_mask = attention_mask.expand(-1, attn_bias.shape[1], -1, -1) + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask=attention_mask, + attention_bias=attn_bias, + scale=self.scaling, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights ``` -- cu_seqlens_q/k: cumulative sequence lengths for query/key -- max_seqlen_q/k: max sequence lengths per batch -- block_table: experimental support for paged attention +This example shows: +- **Dynamic attention bias generation**: Using learnable parameters to create attention bias +- **Flexible backend selection**: Graceful fallback to standard attention when flash_dmattn is unavailable +- **Proper tensor reshaping**: Converting between different tensor layouts as needed +- **Integration with caching**: Support for key-value caching in generation scenarios + ## Backend Selection @@ -170,11 +293,37 @@ print(get_available_backends()) # e.g., ["cuda", "triton", "flex"] print(CUDA_AVAILABLE, TRITON_AVAILABLE, FLEX_AVAILABLE) ``` +### Available Functions + +The library exports the following functions: + +```python +from flash_dmattn import ( + # High-level interface + get_available_backends, # Get list of available backends + flash_dmattn_func_auto, # Automatic backend selection + + + # Backend-specific functions + flash_dmattn_func, # CUDA backend (if available) + triton_dmattn_func, # Triton backend (if available) + flex_dmattn_func, # Flex Attention backend (if available) + + # Backend availability flags + CUDA_AVAILABLE, + TRITON_AVAILABLE, + FLEX_AVAILABLE, +) + +# Transformers integration +from flash_dmattn.integrations.flash_dynamic_mask_attention import flash_dynamic_mask_attention_forward +``` + ### Backend-Specific Functions ```python # Direct access to specific backends -from flash_dmattn import flash_dmattn_func # CUDA backend (requires compiled extension) +from flash_dmattn import flash_dmattn_func # CUDA backend from flash_dmattn import triton_dmattn_func # Triton backend from flash_dmattn import flex_dmattn_func # Flex Attention backend @@ -182,161 +331,37 @@ from flash_dmattn import flex_dmattn_func # Flex Attention backend # query/key/value: (B, L{q/k}, H, D) # attn_mask/attn_bias: (B, H, Lq, Lk) # is_causal: bool, scale: Optional[float] +output = flash_dmattn_func(q, k, v, attn_mask=mask, attn_bias=bias, is_causal=True, scale=None) output = triton_dmattn_func(q, k, v, attn_mask=mask, attn_bias=bias, is_causal=True, scale=None) output = flex_dmattn_func(q, k, v, attn_mask=mask, attn_bias=bias, is_causal=True, scale=None) ``` Notes: -- Triton returns only the attention output tensor. -- Flex currently uses causal masking and score_mod with bias; provided attn_mask is not applied in the kernel at the moment (subject to change in future versions). - -### Data Types and Memory Layout - -- dtypes: `torch.float16`, `torch.bfloat16` (bf16 recommended for stability) -- device: CUDA tensors only -- memory: last dimension must be contiguous (`stride(-1) == 1`); call `.contiguous()` if needed - -## Basic Usage Examples - -Prefer the high-level automatic interface for cross-backend portability. - -### Standard Attention - -```python -import torch -from flash_dmattn import flash_dmattn_func_auto - -B, L, H, D = 2, 4096, 12, 128 -device = torch.device('cuda') -dtype = torch.bfloat16 +- All backends support the same unified interface for seamless switching +- Flex backend currently uses causal masking and score_mod with bias; provided attn_mask is not applied in the kernel at the moment, subject to change in future versions +- CUDA backend supports additional parameters like softcap, deterministic, and return_attn_probs -q = torch.randn(B, L, H, D, device=device, dtype=dtype) -k = torch.randn(B, L, H, D, device=device, dtype=dtype) -v = torch.randn(B, L, H, D, device=device, dtype=dtype) +### When to Use Each Backend -attn = flash_dmattn_func_auto() -output = attn(q, k, v, is_causal=True) -print(output.shape) # [2, 4096, 12, 128] -``` - -### Dynamic Mask Attention - -```python -import torch, math -from flash_dmattn import flash_dmattn_func_auto - -B, H, L = 2, 12, 4096 -keep_window_size = 1024 -device = torch.device('cuda') -dtype = torch.bfloat16 - -q = torch.randn(B, L, H, 128, device=device, dtype=dtype) -k = torch.randn(B, L, H, 128, device=device, dtype=dtype) -v = torch.randn(B, L, H, 128, device=device, dtype=dtype) - -attention_bias = torch.randn(B, H, L, L, device=device, dtype=dtype) -attention_mask = torch.zeros_like(attention_bias) - -if L > keep_window_size: - topk_indices = torch.topk(attention_bias, keep_window_size, dim=-1, largest=True).indices - attention_mask.scatter_(-1, topk_indices, 1.0) -else: - attention_mask.fill_(1.0) - -attn = flash_dmattn_func_auto() -output = attn(q, k, v, attn_mask=attention_mask, attn_bias=attention_bias, is_causal=True, scale=1.0/math.sqrt(128)) -``` - -### Grouped-Query Attention (GQA) - -```python -import torch -from flash_dmattn import flash_dmattn_func_auto - -B, L, H, H_kv, D = 2, 2048, 32, 8, 128 -device = torch.device('cuda') -dtype = torch.bfloat16 - -q = torch.randn(B, L, H, D, device=device, dtype=dtype) -k = torch.randn(B, L, H_kv, D, device=device, dtype=dtype) -v = torch.randn(B, L, H_kv, D, device=device, dtype=dtype) - -attn_mask = torch.ones(B, H, L, L, device=device, dtype=dtype) - -attn = flash_dmattn_func_auto() -output = attn(q, k, v, attn_mask=attn_mask, is_causal=True) -``` - -### Variable Length Sequences (CUDA backend) +**CUDA Backend:** +- โœ… Training workloads requiring full gradient support +- โœ… Production inference requiring maximum performance +- โœ… Applications needing deterministic behavior +- โŒ Avoid if you cannot build custom CUDA extensions -```python -import torch -from flash_dmattn import flash_dmattn_varlen_func - -B = 3 -seq_lens = [512, 1024, 768] -T = sum(seq_lens) -H, D = 16, 64 -device = torch.device('cuda') -dtype = torch.bfloat16 - -q = torch.randn(T, H, D, device=device, dtype=dtype) -k = torch.randn(T, H, D, device=device, dtype=dtype) -v = torch.randn(T, H, D, device=device, dtype=dtype) - -cu = torch.tensor([0] + seq_lens, device=device, dtype=torch.int32).cumsum(0) - -output = flash_dmattn_varlen_func( - q=q, k=k, v=v, - cu_seqlens_q=cu, cu_seqlens_k=cu, - max_seqlen_q=max(seq_lens), max_seqlen_k=max(seq_lens), - is_causal=True -) -``` +**Triton Backend:** +- โœ… Training workloads when CUDA extension is not available +- โœ… Development and prototyping +- โœ… Cross-platform compatibility needs +- โœ… Good balance of performance and ease of installation -## Performance Optimization - -### Memory Efficiency - -```python -# Gradient checkpointing for long sequences -import torch.utils.checkpoint as checkpoint -from flash_dmattn import flash_dmattn_func_auto - -attn = flash_dmattn_func_auto() - -def attention_checkpoint(q, k, v, *args, **kwargs): - return checkpoint.checkpoint(lambda *a, **kw: attn(*a, **kw), q, k, v, *args, **kwargs) - -# Process very long sequences in chunks -def chunked_attention(q, k, v, chunk_size=8192, **kwargs): - L = q.shape[1] - outs = [] - for i in range(0, L, chunk_size): - outs.append(attn(q[:, i:i+chunk_size], k, v, **kwargs)) - return torch.cat(outs, dim=1) -``` +**Flex Backend:** +- โœ… Inference-only applications +- โœ… Research with latest PyTorch features +- โœ… Quick experimentation without custom builds +- โŒ Avoid for training due to limited backward support +- โŒ Avoid when strict attention mask compliance is required -### Backend Selection for Performance - -```python -import torch -from flash_dmattn import flash_dmattn_func_auto - -backends = ["cuda", "triton", "flex"] -for backend in backends: - try: - attn = flash_dmattn_func_auto(backend=backend) - start = torch.cuda.Event(enable_timing=True) - end = torch.cuda.Event(enable_timing=True) - start.record() - _ = attn(q, k, v, is_causal=True) - end.record() - torch.cuda.synchronize() - print(f"{backend}: {start.elapsed_time(end):.2f} ms") - except RuntimeError as e: - print(f"{backend}: not available - {e}") -``` ## Common Issues and Solutions @@ -354,9 +379,15 @@ except ImportError as e: ### Performance Issues 1. Slow execution: ensure all tensors are on the same GPU and last dim is contiguous; use head dims multiple of 8; prefer CUDA backend when available -2. High memory: use gradient checkpointing; chunk long sequences; use varlen for mixed-length batches +2. High memory: use gradient checkpointing; chunk long sequences; consider Triton or Flex backends for very long sequences 3. Numerical stability: prefer bfloat16; check mask/bias for NaN/Inf; monitor gradient norms +### Transformers Integration Issues + +1. Model compatibility: ensure your model supports custom attention implementations +2. Shape mismatches: check that tensor layouts match expected formats +3. Gradient flow: verify that gradients flow properly through the custom attention function + ### Debugging ```python @@ -383,7 +414,42 @@ print_memory_stats() attn = flash_dmattn_func_auto() output = attn(q, k, v) print_memory_stats() - -torch.cuda.empty_cache() ``` +## Summary + +Flash Dynamic Mask Attention provides a unified interface for high-performance attention computation with the following key features: + +- **Multiple Backends**: CUDA for best performance, Triton for good compatibility, and Flex Attention for native PyTorch support +- **Automatic Backend Selection**: Seamless fallback between available backends +- **Dynamic Masking**: Efficient sparse attention with arbitrary attention masks +- **GQA Support**: Grouped-query attention for efficient inference +- **Transformers Integration**: Direct integration with HuggingFace models +- **Memory Efficiency**: Optimized memory usage for very long sequences + +Choose the backend that best fits your needs: +- **CUDA**: For maximum performance and full feature support, especially for training +- **Triton**: For good performance without custom CUDA compilation, supports both training and inference +- **Flex**: For inference scenarios and compatibility with latest PyTorch features, but limited backward support for training yet + +### Backend Comparison + +| Feature | CUDA | Triton | Flex | +|---------|------|--------|------| +| Performance | Highest | Good | Good | +| Memory Efficiency | Best | Good | Good | +| Build Requirements | Custom CUDA extension | triton package | transformers package | +| GQA Support | โœ… | โœ… | โœ… | +| Attention Mask | โœ… | โœ… | โš ๏ธ | +| Attention Bias | โœ… | โœ… | โœ… | +| Causal Mask | โœ… | โœ… | โœ… | +| Softcap | โœ… | โŒ | โŒ | +| Deterministic | โœ… | โŒ | โŒ | +| Return Attention Probs | โœ… | โŒ | โŒ | +| Backward Support | โœ… | โœ… | โš ๏ธ | + +Notes: +- โœ… = Fully supported +- โš ๏ธ = Limited support or workarounds needed +- โŒ = Not supported + diff --git a/docs/integration.md b/docs/integration.md index 80351ab..feb5ffa 100644 --- a/docs/integration.md +++ b/docs/integration.md @@ -4,7 +4,7 @@ This document describes the integration of Dynamic Mask Attention into the Flash Attention framework. The integration enables efficient sparse attention computation by combining Flash Attention's memory-efficient approach with dynamic masking capabilities for handling extremely long sequences. -The integration implements a two-stage approach: Python frontend pre-computes Zero-Order Hold states and Active Mask tensors, while the CUDA backend performs sparse attention computation using these pre-computed masks. +The integration implements a unified sparse computation approach with block-level skip logic: Python frontend pre-computes Attention Mask and Attention Bias tensors, while the CUDA backend performs block-level skip decisions and sparse attention computation for both forward and backward passes. ## Table of Contents @@ -20,32 +20,22 @@ The integration implements a two-stage approach: Python frontend pre-computes Ze ### High-Level Design -The Dynamic Mask Attention integration follows a two-phase approach: +The Dynamic Mask Attention integration implements a unified sparse computation approach with block-level skip logic for both forward and backward passes: -1. **Dynamic Mask Computation**: Python frontend pre-computes ZOH states and Active Mask tensors -2. **Sparse Attention Execution**: CUDA backend performs sparse attention computation using the pre-computed masks +1. **Dynamic Mask Computation**: Python frontend pre-computes Attention Mask and Attention Bias tensors +2. **Unified Sparse Execution**: CUDA backend performs block-level skip decisions for both forward and backward passes +3. **Memory Optimization**: Smart shared memory aliasing and barrier synchronization -``` -Python Frontend CUDA Backend -โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ” โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ” -โ”‚ dt_states = exp(A * softplusโ”‚ โ”‚ Global Memory Loading โ”‚ -โ”‚ (V @ dt_proj^T)) โ”‚โ”€โ”€โ”€โ”€โ”‚ โ”œโ”€ ZOH States โ”‚ -โ”‚ โ”‚ โ”‚ โ”œโ”€ Active Mask โ”‚ -โ”‚ prepare_dynamic_mask() โ”‚ โ”‚ โ””โ”€ Q, K, V Tensors โ”‚ -โ”‚ โ”œโ”€ ZOH States Generation โ”‚ โ”‚ โ”‚ -โ”‚ โ”œโ”€ Active Mask via TopK โ”‚ โ”‚ Sparse Attention Computation โ”‚ -โ”‚ โ””โ”€ Dynamic Bias Calculation โ”‚ โ”‚ โ”œโ”€ Sparse Q*K^T GEMM โ”‚ -โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜ โ”‚ โ”œโ”€ Masked Softmax with ZOH โ”‚ - โ”‚ โ””โ”€ Sparse Score*V GEMM โ”‚ - โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜ -``` ### Key Components -- **ZOH States**: Dynamic attention bias values `(batch, num_heads, query_len, key_len)` derived from value states and learned projections -- **Active Mask**: Binary mask `(batch, num_heads, query_len, key_len)` indicating which positions should be computed (1.0) or skipped (0.0) -- **Sparse GEMM**: Optimized matrix multiplication that only computes non-masked regions -- **Dynamic Masking**: Integration of ZOH bias and active mask into attention score computation +- **Attention Mask**: Binary mask `(batch, num_kv_heads, query_len, key_len)` indicating which positions should be computed (1.0) or skipped (0.0) +- **Attention Bias**: Dynamic attention bias values `(batch, num_kv_heads, query_len, key_len)` applied to attention scores before softmax +- **Block-level Skip Logic**: Unified OR-reduction over (BlockM ร— BlockN) tiles to determine if computation should be performed +- **LSE Caching**: Log-sum-exp values cached during forward pass for numerically stable backward recomputation +- **Shared Memory Aliasing**: Smart memory reuse with explicit barrier synchronization +- **Complete Gradient Chain**: Full gradient computation pipeline with sparse skip capability +- **Memory Optimization**: Reduced shared memory footprint enabling larger tile sizes and higher occupancy ## Core Modifications @@ -55,197 +45,402 @@ Python Frontend CUDA Backend **Changes Made**: ```cpp -struct ZOH_params { - void *__restrict__ zoh_ptr; // ZOH states pointer - void *__restrict__ active_mask_ptr; // Active mask pointer - index_t zoh_batch_stride; // Batch stride for ZOH states - index_t active_mask_batch_stride; // Batch stride for active mask - index_t zoh_head_stride; // Head stride for ZOH states - index_t active_mask_head_stride; // Head stride for active mask - index_t zoh_row_stride; // Row stride for ZOH states - index_t active_mask_row_stride; // Row stride for active mask - int keep_window_size; // Sparsity control parameter +struct QKV_params { + // The QKV matrices. + void *__restrict__ q_ptr; // Query tensor [batch_size, num_heads, query_len, head_dim] + void *__restrict__ k_ptr; // Key tensor [batch_size, num_kv_heads, key_len, head_dim] + void *__restrict__ v_ptr; // Value tensor [batch_size, num_kv_heads, key_len, head_dim] + + // The stride between rows of the Q, K and V matrices. + index_t q_batch_stride, k_batch_stride, v_batch_stride; + index_t q_row_stride, k_row_stride, v_row_stride; + index_t q_head_stride, k_head_stride, v_head_stride; + + // The number of heads. + int h, h_k; + int h_h_k_ratio; // precompute h / h_k +}; + +struct Mask_params { + void * __restrict__ mask_ptr; // Attention mask tensor [batch_size, num_kv_heads, query_len, key_len] + + // The stride of the attention mask tensors. + index_t mask_batch_stride; // Stride between batches of attention mask + index_t mask_head_stride; // Stride between heads of attention mask + index_t mask_row_stride; // Stride between rows of attention mask +}; + +struct Bias_params { + void *__restrict__ bias_ptr; // Attention bias tensor [batch_size, num_kv_heads, query_len, key_len] + + // The stride of the attention bias tensor. + index_t bias_batch_stride; // Stride between batches of attention bias + index_t bias_head_stride; // Stride between heads of attention bias + index_t bias_row_stride; // Stride between rows of attention bias }; -struct Flash_fwd_params : public QKV_params, public ZOH_params { - // Inherits both QKV and ZOH parameters through multiple inheritance - // Enables unified parameter passing to CUDA kernels +struct Flash_fwd_params : public QKV_params, public Mask_params, public Bias_params { + + // The O matrix (output). + void * __restrict__ o_ptr; + void * __restrict__ oaccum_ptr; + + // The stride between rows of O. + index_t o_batch_stride; + index_t o_row_stride; + index_t o_head_stride; + + // The pointer to the P matrix. + void * __restrict__ p_ptr; + + // The pointer to the softmax sum. + void * __restrict__ softmax_lse_ptr; + void * __restrict__ softmax_lseaccum_ptr; + + // The dimensions. + int b, seqlen_q, seqlen_k, seqlen_knew, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded, rotary_dim, total_q; + + // The scaling factors for the kernel. + float scale_softmax; + float scale_softmax_log2; + float softcap; + + // array of length b+1 holding starting offset of each sequence. + int * __restrict__ cu_seqlens_q; + int * __restrict__ cu_seqlens_k; + int * __restrict__ leftpad_k; + + // If provided, the actual length of each k sequence. + int * __restrict__ seqused_k; + + int *__restrict__ blockmask; + + // The K_new and V_new matrices. + void * __restrict__ knew_ptr; + void * __restrict__ vnew_ptr; + + // The stride between rows of the K_new and V_new matrices. + index_t knew_batch_stride; + index_t vnew_batch_stride; + index_t knew_row_stride; + index_t vnew_row_stride; + index_t knew_head_stride; + index_t vnew_head_stride; + + // The cos and sin matrices for rotary embedding. + void * __restrict__ rotary_cos_ptr; + void * __restrict__ rotary_sin_ptr; + + // The indices to index into the KV cache. + int * __restrict__ cache_batch_idx; + + // Paged KV cache + int * __restrict__ block_table; + index_t block_table_batch_stride; + int page_block_size; + + bool is_bf16; + bool is_causal; + + // If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb]. + // Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K. + bool is_seqlens_k_cumulative; + + int num_splits; // For split-KV version + + bool unpadded_lse; // For varlen paths: LSE is in [nheads, total_seqlen_q] format instead of [b, nheads, seqlen_q]. + bool seqlenq_ngroups_swapped; // q has been transposed from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d). }; ``` **Rationale**: -- **Multiple Inheritance Design**: Cleanly separates QKV parameters from ZOH parameters while maintaining unified access +- **Multiple Inheritance Design**: Cleanly separates QKV parameters from Mask/Bias parameters while maintaining unified access - **Comprehensive Stride Information**: Provides all necessary stride information for efficient tensor indexing in CUDA kernels - **Memory Layout Optimization**: Enables optimal memory access patterns for both regular and sparse tensors ### 2. Kernel Traits and Memory Layout (`kernel_traits.h`) -**Purpose**: Define shared memory layouts and copy operations optimized for dynamic masking tensors. +**Purpose**: Define kernel characteristics and memory layouts optimized for dynamic masking operations, supporting both SM75 and SM80+ architectures. **Changes Made**: ```cpp template struct Flash_kernel_traits { - // ...existing Flash Attention traits... - - // ZOH States shared memory layout - matches attention score layout - using SmemLayoutZOH = decltype(make_layout( - make_shape(Int{}, Int{}), - make_stride(Int{}, _1{}) - )); - - // Active Mask shared memory layout - row-major for efficient indexing - using SmemLayoutActiveMask = decltype(make_layout( - make_shape(Int{}, Int{}), - make_stride(Int{}, _1{}) - )); + using Element = elem_type; + using ElementAccum = float; + using index_t = int64_t; + + static constexpr int kHeadDim = kHeadDim_; + static constexpr int kBlockM = kBlockM_; + static constexpr int kBlockN = kBlockN_; + static constexpr int kNWarps = kNWarps_; + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + static constexpr bool Has_cp_async = true; + using MMA_Atom_Arch = std::conditional_t< + std::is_same_v, + MMA_Atom, + MMA_Atom + >; +#else + static constexpr bool Has_cp_async = false; + using MMA_Atom_Arch = MMA_Atom; +#endif + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750 + using SmemCopyAtom = Copy_Atom; + using SmemCopyAtomTransposed = Copy_Atom; +#else + using SmemCopyAtom = Copy_Atom; + using SmemCopyAtomTransposed = Copy_Atom; +#endif - // Optimized copy atoms for ZOH and Active Mask data movement - using SmemCopyAtomZOH = Copy_Atom; - using SmemCopyAtomActiveMask = Copy_Atom; - - // Shared memory size calculations including masking tensors - static constexpr int kSmemSizeZOH = kBlockM * kBlockN * sizeof(elem_type); - static constexpr int kSmemSizeActiveMask = kBlockM * kBlockN * sizeof(elem_type); + // Specialized traits for mask and bias operations + using SmemCopyAtomMask = SmemCopyAtom; + using SmemCopyAtomBias = SmemCopyAtom; }; ``` **Rationale**: -- **Layout Consistency**: ZOH states use the same layout as attention scores for efficient fusion -- **Memory Access Optimization**: Copy atoms leverage GPU's specialized load/store units for maximum bandwidth -- **Shared Memory Management**: Explicit size calculations ensure proper memory allocation +- **Architecture Adaptation**: Automatically selects optimal MMA atoms and copy operations based on GPU architecture +- **Type Safety**: Template-based design ensures type consistency across mask, bias, and attention operations +- **Performance Optimization**: Leverages specialized load/store instructions (LDSM) for maximum memory bandwidth ### 3. Block Information Extension (`block_info.h`) -**Purpose**: Calculate memory offsets for ZOH states and active masks within thread blocks, enabling efficient global memory access. +**Purpose**: Calculate memory offsets for attention bias and attention masks within thread blocks, enabling efficient global memory access. **Changes Made**: ```cpp -template +template struct BlockInfo { - // ...existing Flash Attention block info... - - index_t zoh_offset; // Global memory offset for ZOH states - index_t active_mask_offset; // Global memory offset for active mask - template - __device__ BlockInfo(const Params ¶ms, const int bidb, const int bidh, const int m_block) { - // ...existing initialization... - - // Calculate ZOH states offset: [batch][head][query_start_row][0] - zoh_offset = bidb * params.zoh_batch_stride + - bidh * params.zoh_head_stride + - m_block * kBlockM * params.zoh_row_stride; - - // Calculate Active Mask offset: [batch][head][query_start_row][0] - active_mask_offset = bidb * params.active_mask_batch_stride + - bidh * params.active_mask_head_stride + - m_block * kBlockM * params.active_mask_row_stride; + __device__ BlockInfo(const Params ¶ms, const int bidb) + : sum_s_q(!Varlen || params.cu_seqlens_q == nullptr ? -1 : params.cu_seqlens_q[bidb]) + , sum_s_k(!Varlen || params.cu_seqlens_k == nullptr || !params.is_seqlens_k_cumulative ? -1 : params.cu_seqlens_k[bidb]) + , actual_seqlen_q(!Varlen || params.cu_seqlens_q == nullptr ? params.seqlen_q : params.cu_seqlens_q[bidb + 1] - sum_s_q) + , leftpad_k(params.leftpad_k == nullptr ? 0 : params.leftpad_k[bidb]) + , seqlen_k_cache((!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k : + (params.is_seqlens_k_cumulative ? params.cu_seqlens_k[bidb + 1] - sum_s_k : params.cu_seqlens_k[bidb])) - leftpad_k) + , actual_seqlen_k(params.seqused_k ? params.seqused_k[bidb] - leftpad_k : + seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew)) + { + } + + template + __forceinline__ __device__ index_t q_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const { + return sum_s_q == -1 ? bidb * batch_stride : uint32_t(sum_s_q) * row_stride; + } + + template + __forceinline__ __device__ index_t k_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const { + return sum_s_k == -1 ? bidb * batch_stride + leftpad_k * row_stride : uint32_t(sum_s_k + leftpad_k) * row_stride; + } + + template + __forceinline__ __device__ index_t mask_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const { + index_t offset = sum_s_q == -1 ? bidb * batch_stride : uint32_t(sum_s_q) * row_stride; + sum_s_k == -1 ? offset += leftpad_k : offset += uint32_t(sum_s_k + leftpad_k); + return offset; + } + + template + __forceinline__ __device__ index_t bias_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const { + index_t offset = sum_s_q == -1 ? bidb * batch_stride : uint32_t(sum_s_q) * row_stride; + sum_s_k == -1 ? offset += leftpad_k : offset += uint32_t(sum_s_k + leftpad_k); + return offset; } + + const int sum_s_q, sum_s_k; + const int actual_seqlen_q; + const int leftpad_k; + const int seqlen_k_cache; + const int actual_seqlen_k; }; ``` **Rationale**: -- **Unified Offset Calculation**: Encapsulates complex address arithmetic in a single location -- **Block-Aware Indexing**: Accounts for thread block positioning within the global attention matrix -- **Type Safety**: Template-based design ensures compile-time optimization and type checking +- **Unified Offset Calculation**: Provides dedicated methods for calculating mask and bias tensor offsets +- **Variable Length Support**: Handles both fixed and variable length sequences through template specialization +- **Memory Access Optimization**: Encapsulates complex address arithmetic for efficient global memory access ### 4. Memory Copy Operations (`utils.h`) -**Purpose**: Implement efficient memory copy operations for loading ZOH states and active masks from global to shared memory. +**Purpose**: Implement efficient tensor operations and layout conversions optimized for Flash Attention's memory hierarchy. **Changes Made**: ```cpp -template -__forceinline__ __device__ void copy_ZOH( - Tensor0 &tSgZOH, // Global ZOH tensor view - Tensor1 &tSsZOH, // Shared ZOH tensor view - Tensor2 &tSrZOH, // Register ZOH tensor view - Tensor3 &tSgAM, // Global Active Mask tensor view - Tensor4 &tSsAM, // Shared Active Mask tensor view - TiledMma tiled_mma, // MMA tile configuration - TiledCopy smem_tiled_copy_ZOH, // Tiled copy for ZOH - ThrCopy smem_thr_copy_ZOH // Thread copy for ZOH +namespace FLASH_NAMESPACE { + +// Convert accumulator layout from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) +template +__forceinline__ __device__ auto convert_layout_acc_rowcol(Layout acc_layout) { + static_assert(decltype(size<0>(acc_layout))::value == 4); + static_assert(decltype(rank(acc_layout))::value == 3); + auto l = logical_divide(acc_layout, Shape<_2>{}); // ((2, 2), MMA_M, MMA_N) + return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<2>(l))); +}; + +// Type conversion utilities for different precisions +template +__forceinline__ __device__ T convert_type(float x) { + return T(x); +} + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 +template<> +__forceinline__ __device__ cutlass::bfloat16_t convert_type(float x) { + return cutlass::bfloat16_t(x); +} +#endif + +// Warp-level reduction operations +template +__forceinline__ __device__ float warp_reduce_sum(float x) { +#pragma unroll + for (int mask = THREADS / 2; mask > 0; mask >>= 1) { + x += __shfl_xor_sync(0xffffffff, x, mask); + } + return x; +} + +// GEMM operations with register and shared memory variants +template < + bool A_in_regs=false, bool B_in_regs=false, + typename Tensor0, typename Tensor1, typename Tensor2, + typename Tensor3, typename Tensor4, + typename TiledMma, typename TiledCopyA, typename TiledCopyB, + typename ThrCopyA, typename ThrCopyB +> +__forceinline__ __device__ void gemm( + Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, + Tensor3 &tCsA, Tensor4 &tCsB, + TiledMma tiled_mma, + TiledCopyA smem_tiled_copy_A, TiledCopyB smem_tiled_copy_B, + ThrCopyA smem_thr_copy_A, ThrCopyB smem_thr_copy_B ) { - // Copy ZOH states: Global Memory -> Shared Memory - copy(smem_tiled_copy_ZOH, tSgZOH, tSsZOH); - - // Copy Active Mask: Global Memory -> Shared Memory - copy(smem_tiled_copy_ZOH, tSgAM, tSsAM); - - // Synchronize to ensure all data is loaded before computation - __syncthreads(); + if constexpr (!A_in_regs) { + copy(smem_tiled_copy_A, tCsA, tCrA); + } + if constexpr (!B_in_regs) { + copy(smem_tiled_copy_B, tCsB, tCrB); + } - // Copy to registers for computation: Shared Memory -> Registers - copy(smem_thr_copy_ZOH, tSsZOH, tSrZOH); - copy(smem_thr_copy_ZOH, tSsAM, tSrAM); + // Perform matrix multiplication + gemm(tiled_mma, acc, tCrA, tCrB, acc); } + +} // namespace FLASH_NAMESPACE ``` **Rationale**: -- **Multi-Level Memory Hierarchy**: Efficiently manages data movement through global -> shared -> register memory levels -- **Coalesced Access Patterns**: Leverages CUTLASS copy operations for optimal memory bandwidth utilization -- **Synchronization Management**: Proper thread synchronization ensures data consistency across the thread block +- **Layout Conversion**: Efficient transformation between MMA and row-column layouts for easier tensor manipulation +- **Multi-Precision Support**: Proper type conversion utilities for FP16 and BF16 operations +- **Memory Hierarchy Management**: Flexible GEMM operations supporting different data residency patterns +- **Performance Optimization**: Warp-level reductions and vectorized operations for maximum throughput ### 5. Dynamic Masking Logic (`mask.h`) -**Purpose**: Implement the core dynamic masking functionality that applies ZOH states and active masks during attention computation. +**Purpose**: Implement the core dynamic masking functionality that applies attention bias and attention masks during attention computation. **Changes Made**: ```cpp +template +__forceinline__ __device__ void apply_mask( + TensorType &tensor, + MaskType &mask, + BiasType &bias, + const float scale_softmax, + const int col_idx_offset_, + const int max_seqlen_k, + const int row_idx_offset, + const int max_seqlen_q, + const int warp_row_stride +) { + // tensor has shape (nrow=(2, MMA_M), ncol=(2, MMA_N)) + static_assert(TensorType::rank == 2, "Only support 2D Tensor"); + static_assert(MaskType::rank == 2, "Only support 2D Mask"); + static_assert(BiasType::rank == 2, "Only support 2D Bias"); + + const int lane_id = threadIdx.x % 32; + const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2; + + #pragma unroll + for (int mi = 0; mi < size<0, 1>(tensor); ++mi) { + const int row_idx_base = row_idx_offset + mi * warp_row_stride; + #pragma unroll + for (int i = 0; i < size<0, 0>(tensor); ++i) { + const int row_idx = row_idx_base + i * 8; + const int col_idx_limit = Causal_mask ? + std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q) : + max_seqlen_k; + #pragma unroll + for (int nj = 0; nj < size<1, 1>(tensor); ++nj) { + const int col_idx_base = col_idx_offset + nj * 8; + #pragma unroll + for (int j = 0; j < size<1, 0>(tensor); ++j) { + const int col_idx = col_idx_base + j; + auto coord = make_coord(make_coord(i, mi), make_coord(j, nj)); + // Apply scaling and bias or masking + tensor(coord) = (col_idx >= col_idx_limit) || (mask(coord) == 0.0f) + ? -INFINITY + : tensor(coord) * scale_softmax + bias(coord); + } + } + } + } +} + template -struct DynamicMask { +struct Mask { const int max_seqlen_k, max_seqlen_q; - const int keep_window_size; - template + __forceinline__ __device__ Mask( + const int max_seqlen_k, + const int max_seqlen_q + ) // Constructor + : max_seqlen_k(max_seqlen_k) + , max_seqlen_q(max_seqlen_q) { + }; + + template __forceinline__ __device__ void apply_mask( - TensorType &tensor_, // Attention scores (MMA=4, MMA_M, MMA_N) - ZOHType &tSrZOH, // ZOH states in registers - ActiveMaskType &tSrAM, // Active mask in registers - const float scale_softmax, // Attention scaling factor - const int col_idx_offset_, // Column index offset for this thread block - const int row_idx_offset, // Row index offset for this thread block - const int warp_row_stride // Row stride within warp + TensorType &tensor_, // acc_s (attention scores, MMA=4, MMA_M, MMA_N) + MaskType &tSrMask, // Attention Mask (MMA=4, MMA_M, MMA_N) + BiasType &tSrBias, // Attention Bias (MMA=4, MMA_M, MMA_N) + const float scale_softmax, // Scale for softmax + const int col_idx_offset_, // Column index offset + const int row_idx_offset, // Row index offset + const int warp_row_stride // Warp row stride ) { - // Convert MMA layout to row-column layout for easier indexing - Tensor tensor = make_tensor(tensor_.data(), convert_layout_acc_rowcol(tensor_.layout())); - Tensor zoh = make_tensor(tSrZOH.data(), convert_layout_acc_rowcol(tSrZOH.layout())); - Tensor active_mask = make_tensor(tSrAM.data(), convert_layout_acc_rowcol(tSrAM.layout())); + // Reshape tensors from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) + Tensor tensor = make_tensor(tensor_.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(tensor_.layout())); + Tensor mask = make_tensor(tSrMask.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(tSrMask.layout())); + Tensor bias = make_tensor(tSrBias.data(), FLASH_NAMESPACE::convert_layout_acc_rowcol(tSrBias.layout())); + const int lane_id = threadIdx.x % 32; + const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2; + #pragma unroll for (int mi = 0; mi < size<0, 1>(tensor); ++mi) { const int row_idx_base = row_idx_offset + mi * warp_row_stride; #pragma unroll for (int i = 0; i < size<0, 0>(tensor); ++i) { const int row_idx = row_idx_base + i * 8; - // Apply causal masking if enabled const int col_idx_limit = Causal_mask ? std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q) : max_seqlen_k; - #pragma unroll for (int nj = 0; nj < size<1, 1>(tensor); ++nj) { const int col_idx_base = col_idx_offset + nj * 8; #pragma unroll for (int j = 0; j < size<1, 0>(tensor); ++j) { - const int col_idx = col_idx_base + j * 2; - - if (col_idx < col_idx_limit && row_idx < max_seqlen_q && col_idx < max_seqlen_k) { - // Check if this position should be computed (active mask = 1.0) - if (active_mask(i, mi, j, nj) == 0.0f) { - // Masked position: set to -infinity - tensor(i, mi, j, nj) = -INFINITY; - } else { - // Active position: apply scaling and add ZOH bias - tensor(i, mi, j, nj) = tensor(i, mi, j, nj) * scale_softmax + zoh(i, mi, j, nj); - } - } else { - // Out of bounds: always mask - tensor(i, mi, j, nj) = -INFINITY; - } + const int col_idx = col_idx_base + j; + auto coord = make_coord(make_coord(i, mi), make_coord(j, nj)); + // Apply scaling and bias or masking + tensor(coord) = (col_idx >= col_idx_limit) || (mask(coord) == 0.0f) + ? -INFINITY + : tensor(coord) * scale_softmax + bias(coord); } } } @@ -257,73 +452,77 @@ struct DynamicMask { **Rationale**: - **Register-Level Operations**: All masking operations performed in registers for maximum efficiency - **Unified Masking Logic**: Combines causal masking, boundary checking, and dynamic masking in a single pass +- **Layout Conversion**: Properly handles MMA tensor layout conversion for efficient indexing - **Numerical Stability**: Proper handling of infinity values for masked positions ensures stable softmax computation -### 6. Sparse Matrix Operations (`utils.h`) +### 6. Backward Pass Integration (`flash_bwd_kernel.h`) -**Purpose**: Implement sparse GEMM operations that utilize active masks to skip computation for masked regions, significantly reducing computational overhead. +**Purpose**: Extend backward pass computation to support dynamic masking with proper gradient computation for masked positions. **Changes Made**: ```cpp -template -__forceinline__ __device__ void sparse_gemm( - Tensor0 &acc, // Output accumulator tensor - Tensor1 &tCrA, // A matrix in registers (Query) - Tensor2 &tCrB, // B matrix in registers (Key/Value) - Tensor3 &tCsA, // A matrix in shared memory - Tensor4 &tCsB, // B matrix in shared memory - Tensor5 &active_mask, // Sparsity mask in registers - TiledMma tiled_mma, // MMA tile configuration - TiledCopyA smem_tiled_copy_A, // Copy configuration for A - TiledCopyB smem_tiled_copy_B, // Copy configuration for B - ThrCopyA smem_thr_copy_A, // Thread copy for A - ThrCopyB smem_thr_copy_B // Thread copy for B -) { - // Load data based on sparsity pattern - if constexpr (!A_in_regs) { - copy(smem_tiled_copy_A, tCsA, tCrA); - } - if constexpr (!B_in_regs) { - copy(smem_tiled_copy_B, tCsB, tCrB); - } - - // Perform sparse matrix multiplication - // Only compute where active_mask indicates active positions - sparse_gemm_impl(tiled_mma, acc, tCrA, tCrB, active_mask); -} +struct Flash_bwd_params : public Flash_fwd_params { + + // The dO and dQKV and dBias matrices. + void *__restrict__ do_ptr; + void *__restrict__ dq_ptr; + void *__restrict__ dk_ptr; + void *__restrict__ dv_ptr; + void *__restrict__ dbias_ptr; + + // To accumulate dQ, dK, dV + void *__restrict__ dq_accum_ptr; + void *__restrict__ dk_accum_ptr; + void *__restrict__ dv_accum_ptr; + + // The stride between rows of the dO, dQ, dK and dV matrices. + index_t do_batch_stride; + index_t do_row_stride; + index_t do_head_stride; + index_t dq_batch_stride; + index_t dk_batch_stride; + index_t dv_batch_stride; + index_t dq_row_stride; + index_t dk_row_stride; + index_t dv_row_stride; + index_t dq_head_stride; + index_t dk_head_stride; + index_t dv_head_stride; + index_t dbias_batch_stride; + index_t dbias_head_stride; + index_t dbias_row_stride; + + // The pointer to the softmax d sum. + void *__restrict__ dsoftmax_sum; + + bool deterministic; + index_t dq_accum_split_stride; +}; -template -__forceinline__ __device__ void sparse_gemm_rs( - Tensor0 &acc, // Accumulator (attention scores) - Tensor1 &tCrA, // Query in registers - Tensor2 &tCrB, // Key in registers - Tensor3 &tCsA, // Query in shared memory - Tensor4 &tCsB, // Key in shared memory - Tensor5 &active_mask, // Active mask for sparsity - TiledMma tiled_mma, - TiledCopyA smem_tiled_copy_A, - TiledCopyB smem_tiled_copy_B, - ThrCopyA smem_thr_copy_A, - ThrCopyB smem_thr_copy_B -) { - // Row-major sparse GEMM variant optimized for Q*K^T computation - // Utilizes active mask to determine which K vectors to process +template +inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const int bidb, const int bidh, const int n_block) { + // Backward pass computation with dynamic masking support + // Includes proper gradient computation through masked attention scores + // Maintains numerical stability for masked positions + + using Element = typename Kernel_traits::Element; + using ElementAccum = typename Kernel_traits::ElementAccum; + using index_t = typename Kernel_traits::index_t; + + // Initialize block information and tensor views + const BlockInfo binfo(params, bidb); + + // Set up gradient computation with masking awareness + // Load bias and mask gradients when computing dBias + // Apply masking logic consistently with forward pass } ``` **Rationale**: -- **Computational Efficiency**: Skips matrix multiplication for masked regions, reducing FLOPs proportional to sparsity -- **Memory Bandwidth Optimization**: Avoids loading unnecessary data for masked positions -- **Flexible Sparsity Support**: Supports different sparsity patterns through the active mask tensor -- **Register/Shared Memory Optimization**: Provides variants for different data residency scenarios +- **Gradient Consistency**: Ensures gradients are computed consistently with forward pass masking logic +- **Memory Layout Preservation**: Maintains the same memory layout and stride patterns as forward pass +- **Numerical Stability**: Proper handling of gradients at masked positions to prevent NaN propagation ### 7. Attention Kernel Modifications (`flash_fwd_kernel.h`) @@ -331,90 +530,162 @@ __forceinline__ __device__ void sparse_gemm_rs( **Changes Made**: ```cpp -template -inline __device__ void compute_attn_1rowblock( - const Params ¶ms, - const int bidb, - const int bidh, - const int m_block -) { - // Initialize block information with ZOH and active mask offsets - const BlockInfo binfo(params, bidb, bidh, m_block); +template +inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bidb, const int bidh, const int m_block) { + using Element = typename Kernel_traits::Element; + using ElementAccum = typename Kernel_traits::ElementAccum; + using index_t = typename Kernel_traits::index_t; - // Set up tensor views for ZOH states and active masks - Tensor mZOH = make_tensor(make_gmem_ptr(reinterpret_cast(params.zoh_ptr) + binfo.zoh_offset), - make_shape(binfo.actual_seqlen_q, params.seqlen_k), - make_stride(params.zoh_row_stride, _1{})); + // Initialize block information + const BlockInfo binfo(params, bidb); - Tensor mActiveMask = make_tensor(make_gmem_ptr(reinterpret_cast(params.active_mask_ptr) + binfo.active_mask_offset), - make_shape(binfo.actual_seqlen_q, params.seqlen_k), - make_stride(params.active_mask_row_stride, _1{})); - - // Main computation loop over key/value blocks + // Set up tensor views for Q, K, V matrices + Tensor mQ = make_tensor(make_gmem_ptr(reinterpret_cast(params.q_ptr) + binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb)), + make_shape(binfo.actual_seqlen_q, Int{}), + make_stride(params.q_row_stride, _1{})); + + Tensor mK = make_tensor(make_gmem_ptr(reinterpret_cast(params.k_ptr) + binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb)), + make_shape(binfo.actual_seqlen_k, Int{}), + make_stride(params.k_row_stride, _1{})); + + // Set up mask and bias tensor views if available + Tensor mMask, mBias; + if (params.mask_ptr != nullptr) { + mMask = make_tensor(make_gmem_ptr(reinterpret_cast(params.mask_ptr) + binfo.mask_offset(params.mask_batch_stride, params.mask_row_stride, bidb)), + make_shape(binfo.actual_seqlen_q, binfo.actual_seqlen_k), + make_stride(params.mask_row_stride, _1{})); + } + + if (params.bias_ptr != nullptr) { + mBias = make_tensor(make_gmem_ptr(reinterpret_cast(params.bias_ptr) + binfo.bias_offset(params.bias_batch_stride, params.bias_row_stride, bidb)), + make_shape(binfo.actual_seqlen_q, binfo.actual_seqlen_k), + make_stride(params.bias_row_stride, _1{})); + } + + // Main computation loop with dynamic masking integration for (int n_block = n_block_min; n_block < n_block_max; ++n_block) { - // Load ZOH states and active masks for this block - copy_ZOH(tSgZOH, tSsZOH, tSrZOH, tSgActiveMask, tSsActiveMask, - tiled_mma, smem_tiled_copy_ZOH, smem_thr_copy_ZOH); - - // Perform sparse Q*K^T computation - sparse_gemm(acc_s, tSrQ, tSrK, tSsQ, tSsK, tSrActiveMask, - tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K, - smem_thr_copy_Q, smem_thr_copy_K); - - // Apply dynamic masking (ZOH bias + active mask) - DynamicMask dynamic_mask(params.seqlen_k, params.seqlen_q, params.keep_window_size); - dynamic_mask.apply_mask(acc_s, tSrZOH, tSrActiveMask, params.scale_softmax, - n_block * kBlockN, m_block * kBlockM, kBlockM); - - // Continue with softmax and attention*V computation - softmax.template softmax(acc_s); - - // Sparse attention*V computation - sparse_gemm_rs(acc_o, acc_s, tSrV, tSsS, tSsV, tSrActiveMask, - tiled_mma, smem_tiled_copy_S, smem_tiled_copy_V, - smem_thr_copy_S, smem_thr_copy_V); + // Standard Flash Attention computation: Q*K^T + gemm(acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, + smem_tiled_copy_Q, smem_tiled_copy_K, + smem_thr_copy_Q, smem_thr_copy_K); + + // Apply dynamic masking if mask/bias tensors are provided + if (params.mask_ptr != nullptr || params.bias_ptr != nullptr) { + Mask mask(params.seqlen_k, params.seqlen_q); + mask.apply_mask(acc_s, tSrMask, tSrBias, params.scale_softmax, + n_block * Kernel_traits::kBlockN, m_block * Kernel_traits::kBlockM, + Kernel_traits::kBlockM); + } + + // Continue with softmax computation + softmax.template softmax_rescale_o( + acc_s, acc_o, params.scale_softmax_log2 + ); + + // Attention * V computation + gemm(acc_o, acc_s, tSrV, acc_s, tSsV, tiled_mma, + smem_tiled_copy_S, smem_tiled_copy_V, + smem_thr_copy_S, smem_thr_copy_V); } } - -template -inline __device__ void compute_attn_1rowblock_splitkv( - const Params ¶ms, - const int bidb, - const int bidh, - const int m_block, - const int n_split_idx, - const int num_n_splits -) { - // Split-K variant with dynamic masking support - // Handles distributed computation across multiple thread blocks - // Maintains sparsity patterns across splits -} ``` **Rationale**: -- **Seamless Integration**: Dynamic masking logic integrated into existing Flash Attention computation flow +- **Seamless Integration**: Dynamic masking logic integrated into existing Flash Attention computation flow without affecting core performance - **Memory Efficiency Preservation**: Maintains Flash Attention's tiling and shared memory optimization strategies -- **Split-K Support**: Extends dynamic masking to split-K attention variants for very long sequences -- **Template Specialization**: Compile-time optimization through template parameters +- **Conditional Execution**: Only applies masking operations when mask/bias tensors are actually provided +- **Template Specialization**: Compile-time optimization eliminates runtime branching for better performance ### 8. Launch Template Updates (`flash_fwd_launch_template.h`) -**Purpose**: Update kernel launch functions to properly configure and validate dynamic masking parameters, ensuring correct shared memory allocation and kernel selection. +**Purpose**: Update kernel launch templates to support dynamic masking functionality with proper template instantiation and dispatch logic. **Changes Made**: ```cpp -template +// Determine if the architecture supports FLASH and define parameter modifiers +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 +#define ARCH_SUPPORTS_FLASH +#define KERNEL_PARAM_MODIFIER __grid_constant__ +#else +#define KERNEL_PARAM_MODIFIER +#endif + +// Define unsupported architecture error handling +#define FLASH_UNSUPPORTED_ARCH printf("FATAL: FlashDynamicMaskAttention requires building with sm version sm80-sm90, but was built for < 8.0!"); + +// Kernel definition macro for cleaner code +#define DEFINE_FLASH_FORWARD_KERNEL(kernelName, ...) \ +template \ +__global__ void kernelName(KERNEL_PARAM_MODIFIER const Flash_fwd_params params) + +DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_kernel, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Return_softmax) { + #if defined(ARCH_SUPPORTS_FLASH) + FLASH_NAMESPACE::compute_attn(params); + #else + FLASH_UNSUPPORTED_ARCH + #endif +} + +DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_splitkv_kernel, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Split) { + #if defined(ARCH_SUPPORTS_FLASH) + FLASH_NAMESPACE::compute_attn_splitkv(params); + #else + FLASH_UNSUPPORTED_ARCH + #endif +} + +template void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { - // Calculate shared memory requirements including ZOH and active mask tensors - const size_t smem_size = Kernel_traits::kSmemSize + - Kernel_traits::kSmemSizeZOH + - Kernel_traits::kSmemSizeActiveMask; + constexpr size_t smem_size = Kernel_traits::kSmemSize; - // Validate that shared memory requirements don't exceed device limits - TORCH_CHECK(smem_size <= 48 * 1024, "Shared memory requirement exceeds device limit"); + // Handle different precision types and head dimensions + BOOL_SWITCH(params.is_bf16, Is_Bf16, [&] { + using elem_type = std::conditional_t; + HEADDIM_SWITCH(params.d, [&] { + BOOL_SWITCH(params.seqlen_k % Kernel_traits::kBlockN == 0, Is_even_N, [&] { + BOOL_SWITCH(params.d == kHeadDim, Is_even_K, [&] { + SOFTCAP_SWITCH(params.softcap > 0.0, Is_softcap, [&] { + auto kernel = &flash_fwd_kernel; + // Launch kernel with appropriate grid and block dimensions + kernel<<>>(params); + }); + }); + }); + }); + }); + C10_CUDA_KERNEL_LAUNCH_CHECK(); +} + +// Template instantiations for different configurations +template +void run_mha_fwd_hdim32(Flash_fwd_params ¶ms, cudaStream_t stream); +template +void run_mha_fwd_hdim64(Flash_fwd_params ¶ms, cudaStream_t stream); +template +void run_mha_fwd_hdim96(Flash_fwd_params ¶ms, cudaStream_t stream); +template +void run_mha_fwd_hdim128(Flash_fwd_params ¶ms, cudaStream_t stream); +template +void run_mha_fwd_hdim192(Flash_fwd_params ¶ms, cudaStream_t stream); +template +void run_mha_fwd_hdim256(Flash_fwd_params ¶ms, cudaStream_t stream); +``` + +**Rationale**: +- **Template Dispatch**: Efficient compile-time branching based on runtime parameters for optimal performance +- **Architecture Support**: Proper handling of different GPU architectures with appropriate error messages +- **Memory Management**: Correct shared memory allocation based on kernel requirements +- **Type Safety**: Strong typing through template parameters ensures correctness across different precisions + +**Purpose**: Update kernel launch functions to properly configure and validate dynamic masking parameters, ensuring correct shared memory allocation and kernel selection. + +**Changes Made**: +```cpp +template +void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { + // Calculate shared memory requirements + constexpr size_t smem_size = Kernel_traits::kSmemSize; // Set up grid dimensions const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM; @@ -431,10 +702,10 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { BOOL_SWITCH(is_even_MN, IsEvenMN, [&] { BOOL_SWITCH(is_even_K, IsEvenK, [&] { BOOL_SWITCH(return_softmax, ReturnSoftmax, [&] { - auto kernel = &flash_fwd_kernel; - // Configure dynamic shared memory + // Configure dynamic shared memory if needed if (smem_size >= 48 * 1024) { C10_CUDA_CHECK(cudaFuncSetAttribute( kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); @@ -476,29 +747,25 @@ void run_flash_splitkv_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { void set_params_fprop( Flash_fwd_params ¶ms, // ... existing parameters ... - const at::Tensor zoh, // ZOH states tensor - const at::Tensor active_mask, // Active mask tensor - const size_t keep_window_size, // Sparsity control parameter + const at::Tensor mask, // Attention mask tensor + const at::Tensor bias, // Attention bias tensor // ... other parameters ... ) { // Reset parameters and set basic properties params = {}; params.is_bf16 = q.dtype() == torch::kBFloat16; - // Set ZOH states pointers and strides - params.zoh_ptr = zoh.data_ptr(); - params.zoh_batch_stride = zoh.stride(-4); // [batch, head, query, key] - params.zoh_head_stride = zoh.stride(-3); - params.zoh_row_stride = zoh.stride(-2); - - // Set Active Mask pointers and strides - params.active_mask_ptr = active_mask.data_ptr(); - params.active_mask_batch_stride = active_mask.stride(-4); - params.active_mask_head_stride = active_mask.stride(-3); - params.active_mask_row_stride = active_mask.stride(-2); + // Set attention mask pointers and strides + params.mask_ptr = mask.data_ptr(); + params.mask_batch_stride = mask.stride(-4); + params.mask_head_stride = mask.stride(-3); + params.mask_row_stride = mask.stride(-2); - // Set sparsity control parameter - params.keep_window_size = keep_window_size; + // Set attention bias pointers and strides + params.bias_ptr = bias.data_ptr(); + params.bias_batch_stride = bias.stride(-4); + params.bias_head_stride = bias.stride(-3); + params.bias_row_stride = bias.stride(-2); // ... existing parameter setup ... } @@ -507,22 +774,19 @@ std::vector mha_fwd( at::Tensor &q, // Query tensor const at::Tensor &k, // Key tensor const at::Tensor &v, // Value tensor - const at::Tensor &zoh, // ZOH states tensor - const at::Tensor &active_mask, // Active mask tensor + const at::Tensor &mask, // Attention mask tensor + const at::Tensor &bias, // Attention bias tensor std::optional &out_, // Optional output tensor - const float p_dropout, const float softmax_scale, bool is_causal, - const int keep_window_size, // Sparsity control const float softcap, - const bool return_softmax, - std::optional gen_ + const bool return_softmax ) { // Comprehensive input validation CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); - CHECK_DEVICE(zoh); CHECK_DEVICE(active_mask); + CHECK_DEVICE(mask); CHECK_DEVICE(bias); CHECK_CONTIGUOUS(q); CHECK_CONTIGUOUS(k); CHECK_CONTIGUOUS(v); - CHECK_CONTIGUOUS(zoh); CHECK_CONTIGUOUS(active_mask); + CHECK_CONTIGUOUS(mask); CHECK_CONTIGUOUS(bias); // Validate tensor shapes auto batch_size = q.size(0); @@ -532,31 +796,27 @@ std::vector mha_fwd( auto seqlen_k = k.size(1); auto num_heads_k = k.size(2); - CHECK_SHAPE(zoh, batch_size, num_heads_k, seqlen_q, seqlen_k); - CHECK_SHAPE(active_mask, batch_size, num_heads_k, seqlen_q, seqlen_k); + CHECK_SHAPE(mask, batch_size, num_heads_k, seqlen_q, seqlen_k); + CHECK_SHAPE(bias, batch_size, num_heads_k, seqlen_q, seqlen_k); // Validate data types consistency TORCH_CHECK(q.dtype() == k.dtype() && k.dtype() == v.dtype(), "All QKV tensors must have the same dtype"); - TORCH_CHECK(zoh.dtype() == q.dtype(), - "ZOH states must have the same dtype as QKV tensors"); - TORCH_CHECK(active_mask.dtype() == q.dtype(), - "Active mask must have the same dtype as QKV tensors"); - - // Validate sparsity parameter - TORCH_CHECK(keep_window_size > 0 && keep_window_size <= seqlen_k, - "keep_window_size must be positive and <= seqlen_k"); + TORCH_CHECK(mask.dtype() == q.dtype(), + "Attention mask must have the same dtype as QKV tensors"); + TORCH_CHECK(bias.dtype() == q.dtype(), + "Attention bias must have the same dtype as QKV tensors"); // Set up parameters and launch computation Flash_fwd_params params; set_params_fprop(params, batch_size, seqlen_q, seqlen_k, /* ... */, - q, k, v, zoh, active_mask, /* ... */, keep_window_size, /* ... */); + q, k, v, mask, bias, /* ... */); // Launch kernel with appropriate configuration run_mha_fwd(params, at::cuda::getCurrentCUDAStream()); // Return results - return {out, softmax_lse, /* ... */}; + return {out, softmax_lse}; } // Python binding @@ -564,15 +824,12 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.doc() = "FlashDynamicMaskAttention"; m.def("fwd", &FLASH_NAMESPACE::mha_fwd, "Forward pass with dynamic masking", py::arg("q"), py::arg("k"), py::arg("v"), - py::arg("zoh"), py::arg("active_mask"), // New required arguments + py::arg("mask"), py::arg("bias"), // Updated arguments py::arg("out") = py::none(), - py::arg("p_dropout") = 0.0f, py::arg("softmax_scale") = 0.0f, py::arg("is_causal") = false, - py::arg("keep_window_size") = 2048, // New sparsity control py::arg("softcap") = 0.0f, - py::arg("return_softmax") = false, - py::arg("gen") = py::none()); + py::arg("return_softmax") = false); } ``` @@ -585,166 +842,456 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { ## Implementation Details -### Python Frontend: Dynamic Mask Generation +### C++ API Interface (`flash_api.cpp`) -The Python frontend is responsible for computing the ZOH states and active masks before passing them to the CUDA backend: +The core C++ API provides the following main functions for Dynamic Mask Attention: -```python -def prepare_dynamic_mask( - hidden_states: torch.Tensor, - dt_states: torch.Tensor, - keep_window_size: int = 2048, - attention_mask: torch.Tensor = None, -): - """ - Core DMA function that generates dynamic attention masks for sparse computation. - - Process: - 1. Expand dt_states to match attention matrix dimensions - 2. Apply optional causal/padding masks - 3. Use TopK selection to identify most important positions - 4. Generate binary active mask for CUDA computation - """ - min_dtype = torch.finfo(hidden_states.dtype).min - dtype = hidden_states.dtype +```cpp +namespace FLASH_NAMESPACE { + +std::vector mha_fwd( + at::Tensor &q, // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8) + const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8) + const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8) + const at::Tensor &mask, // batch_size x num_heads_k x seqlen_q x seqlen_k + const at::Tensor &bias, // batch_size x num_heads_k x seqlen_q x seqlen_k + std::optional &out_, // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8) + const float softmax_scale, + bool is_causal, + const float softcap, + const bool return_softmax +); + +std::vector mha_varlen_fwd( + at::Tensor &q, // total_q x num_heads x round_multiple(head_size, 8) + const at::Tensor &k, // total_k x num_heads_k x round_multiple(head_size, 8) + const at::Tensor &v, // total_k x num_heads_k x round_multiple(head_size, 8) + const at::Tensor &mask, // total_q x num_heads_k x max_seqlen_k + const at::Tensor &bias, // total_q x num_heads_k x max_seqlen_k + std::optional &out_, // total_q x num_heads x round_multiple(head_size, 8) + const at::Tensor &cu_seqlens_q, // batch_size + 1 + const at::Tensor &cu_seqlens_k, // batch_size + 1 + std::optional &seqused_k, + std::optional &leftpad_k, + const int max_seqlen_q, + const int max_seqlen_k, + const float softmax_scale, + bool is_causal, + const float softcap, + const bool return_softmax +); + +std::vector mha_bwd( + const at::Tensor &dout, // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8) + const at::Tensor &q, // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8) + const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8) + const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8) + const at::Tensor &mask, // batch_size x num_heads_k x seqlen_q x seqlen_k + const at::Tensor &bias, // batch_size x num_heads_k x seqlen_q x seqlen_k + const at::Tensor &out, // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8) + const at::Tensor &softmax_lse, // batch_size x num_heads x seqlen_q + std::optional &dq_, + std::optional &dk_, + std::optional &dv_, + std::optional &dbias_, + const float softmax_scale, + bool is_causal, + const float softcap, + bool deterministic, + std::optional gen_ +); + +} // namespace FLASH_NAMESPACE +``` + +### Parameter Setup and Validation + +The implementation includes comprehensive parameter validation and setup: + +```cpp +void set_params_fprop( + Flash_fwd_params ¶ms, + const size_t b, const size_t seqlen_q, const size_t seqlen_k, + const size_t seqlen_q_rounded, const size_t seqlen_k_rounded, + const size_t h, const size_t h_k, const size_t d, const size_t d_rounded, + const at::Tensor q, const at::Tensor k, const at::Tensor v, + const at::Tensor mask, const at::Tensor bias, at::Tensor out, + void *cu_seqlens_q_d, void *cu_seqlens_k_d, void *seqused_k, + void *p_d, void *softmax_lse_d, float softmax_scale, bool is_causal, + const float softcap, bool seqlenq_ngroups_swapped=false, + const bool unpadded_lse=false +) { + // Reset parameters + params = {}; + params.is_bf16 = q.dtype() == torch::kBFloat16; + + // Set tensor pointers + params.q_ptr = q.data_ptr(); + params.k_ptr = k.data_ptr(); + params.v_ptr = v.data_ptr(); + params.mask_ptr = mask.data_ptr(); + params.bias_ptr = bias.data_ptr(); + params.o_ptr = out.data_ptr(); + + // Set stride information (all strides are in elements, not bytes) + params.q_row_stride = q.stride(-3); + params.k_row_stride = k.stride(-3); + params.v_row_stride = v.stride(-3); + params.mask_row_stride = mask.stride(-2); + params.bias_row_stride = bias.stride(-2); + params.o_row_stride = out.stride(-3); - # Expand dt_states: [batch, num_heads, key_len] -> [batch, num_heads, query_len, key_len] - attn_mask = dt_states[:, :, None, :].expand(-1, -1, hidden_states.shape[2], -1) + params.q_head_stride = q.stride(-2); + params.k_head_stride = k.stride(-2); + params.v_head_stride = v.stride(-2); + params.mask_head_stride = mask.stride(-3); + params.bias_head_stride = bias.stride(-3); + params.o_head_stride = out.stride(-2); + + // Set batch stride information + if (cu_seqlens_q_d == nullptr) { + params.q_batch_stride = q.stride(0); + params.k_batch_stride = k.stride(0); + params.v_batch_stride = v.stride(0); + params.mask_batch_stride = mask.stride(0); + params.bias_batch_stride = bias.stride(0); + params.o_batch_stride = out.stride(0); + } - # Apply causal/padding masks by setting masked positions to -inf - if attention_mask is not None: - if attention_mask.dtype == torch.bool: - attention_mask = torch.where(attention_mask, 0.0, min_dtype) - attn_mask = attn_mask.masked_fill(attention_mask != 0, min_dtype) + // Set sequence length and dimension parameters + params.b = b; params.h = h; params.h_k = h_k; + params.h_h_k_ratio = h / h_k; + params.seqlen_q = seqlen_q; params.seqlen_k = seqlen_k; + params.seqlen_q_rounded = seqlen_q_rounded; + params.seqlen_k_rounded = seqlen_k_rounded; + params.d = d; params.d_rounded = d_rounded; - # Only apply when sequence length exceeds window size - if attn_mask.shape[-1] > keep_window_size: - active_mask = torch.zeros_like(attn_mask, dtype=dtype, device=attn_mask.device) - # TopK selection identifies most important positions for each query - topk_indices = torch.topk(attn_mask, keep_window_size, dim=-1, - largest=True, sorted=False).indices - # Create binary mask: 1.0 for active positions, 0.0 for masked - active_mask = active_mask.scatter(-1, topk_indices, 1.0) - # Set non-selected positions to -inf in attention mask - attn_mask = attn_mask.masked_fill(active_mask == 0.0, min_dtype) - else: - # If sequence length is within window size, all positions are active - active_mask = torch.ones_like(attn_mask, dtype=dtype, device=attn_mask.device) - return attn_mask, active_mask + // Set scaling and control parameters + params.scale_softmax = softmax_scale; + params.scale_softmax_log2 = softmax_scale * M_LOG2E; + params.softcap = softcap; + params.is_causal = is_causal; + params.unpadded_lse = unpadded_lse; + params.seqlenq_ngroups_swapped = seqlenq_ngroups_swapped; +} ``` -### CUDA Backend: Sparse Attention Computation +### Python Binding and Interface -The CUDA backend implements three key stages of sparse attention: +The C++ functions are exposed to Python through PyBind11: -#### Stage 1: Memory Loading and Tensor Setup ```cpp -// Set up tensor views for ZOH states and active masks -Tensor mZOH = make_tensor( - make_gmem_ptr(reinterpret_cast(params.zoh_ptr) + binfo.zoh_offset), - make_shape(binfo.actual_seqlen_q, params.seqlen_k), - make_stride(params.zoh_row_stride, _1{}) -); +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.doc() = "FlashDynamicMaskAttention"; + m.def("fwd", &FLASH_NAMESPACE::mha_fwd, "Forward pass"); + m.def("varlen_fwd", &FLASH_NAMESPACE::mha_varlen_fwd, "Forward pass with variable length"); + m.def("bwd", &FLASH_NAMESPACE::mha_bwd, "Backward pass"); + m.def("varlen_bwd", &FLASH_NAMESPACE::mha_varlen_bwd, "Backward pass with variable length"); +} +``` -Tensor mActiveMask = make_tensor( - make_gmem_ptr(reinterpret_cast(params.active_mask_ptr) + binfo.active_mask_offset), - make_shape(binfo.actual_seqlen_q, params.seqlen_k), - make_stride(params.active_mask_row_stride, _1{}) -); +### Python Frontend Integration Example + +Dynamic Mask Attention can be integrated into transformer models as follows: -// Load data through memory hierarchy: Global -> Shared -> Registers -copy_ZOH(tSgZOH, tSsZOH, tSrZOH, tSgActiveMask, tSsActiveMask, - tiled_mma, smem_tiled_copy_ZOH, smem_thr_copy_ZOH); +```python +import torch +import torch.nn as nn +import flash_dmattn_cuda as flash_dmattn + +class DynamicMaskAttention(nn.Module): + def __init__(self, config): + super().__init__() + self.num_heads = config.num_attention_heads + self.num_kv_heads = config.num_key_value_heads + self.head_dim = config.hidden_size // config.num_attention_heads + self.scaling = 1.0 / math.sqrt(self.head_dim) + + # Standard attention projections + self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=False) + self.k_proj = nn.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False) + self.v_proj = nn.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, config.hidden_size, bias=False) + + def forward(self, hidden_states, attention_mask=None, attention_bias=None): + batch_size, seq_len, _ = hidden_states.shape + + # Project to Q, K, V + query_states = self.q_proj(hidden_states).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(batch_size, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(batch_size, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2) + + # Prepare mask and bias tensors with proper shapes + if attention_mask is None: + attention_mask = torch.ones((batch_size, self.num_kv_heads, seq_len, seq_len), + dtype=query_states.dtype, device=query_states.device) + + if attention_bias is None: + attention_bias = torch.zeros((batch_size, self.num_kv_heads, seq_len, seq_len), + dtype=query_states.dtype, device=query_states.device) + + # Call Flash Dynamic Mask Attention + output, _ = flash_dmattn.fwd( + query_states, key_states, value_states, + attention_mask, attention_bias, + None, # out + self.scaling, # softmax_scale + False, # is_causal + 0.0, # softcap + False # return_softmax + ) + + # Output projection + output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, -1) + return self.o_proj(output) +``` + + # Call attention implementation + attn_output, attn_weights = flash_dynamic_mask_attention_forward( + self, + query_states, + key_states, + value_states, + attention_mask=attention_mask, + attention_bias=attn_bias, + scaling=self.scaling, + ) + + return attn_output, attn_weights ``` -#### Stage 2: Sparse Q*K^T Computation -```cpp -// Sparse GEMM that skips computation for masked positions -sparse_gemm_rs(acc_s, tSrQ, tSrK, tSsQ, tSsK, tSrActiveMask, - tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K, - smem_thr_copy_Q, smem_thr_copy_K); +The attention bias generation process: + +1. **Value-based Dynamic States**: + ```python + dt_states = self.dt_proj(value_states_flattened) + dt_states = torch.exp(self.A * F.softplus(dt_states)).transpose(-1, -2) + ``` + +2. **Bias Expansion**: + ```python + attn_bias = dt_states[:, :, None, :].expand(-1, -1, query_len, -1) + ``` + +3. **Mask Processing**: Done internally in `_flash_dynamic_mask_attention_forward` + + +### CUDA Backend: Sparse Attention Computation -// Apply dynamic masking: scaling + ZOH bias + masking -DynamicMask dynamic_mask(params.seqlen_k, params.seqlen_q, params.keep_window_size); -dynamic_mask.apply_mask(acc_s, tSrZOH, tSrActiveMask, params.scale_softmax, - n_block * kBlockN, m_block * kBlockM, kBlockM); +The CUDA backend implements the sparse attention computation through `_flash_dynamic_mask_attention_forward`: + +```python +def _flash_dynamic_mask_attention_forward( + query_states, key_states, value_states, + attention_mask, attention_bias, + query_length, key_length, + is_causal, softmax_scale=None, softcap=None, + target_dtype=None, implementation=None, **kwargs +): + dtype = query_states.dtype + min_dtype = torch.finfo(dtype).min + batch_size, _, num_kv_heads, _ = key_states.shape + + # Initialize attention bias if not provided + if attention_bias is None: + attention_bias = torch.zeros( + (batch_size, num_kv_heads, query_length, key_length), + dtype=dtype, device=query_states.device + ) + + # Apply attention mask to bias + if attention_mask is not None: + attention_bias = attention_bias.masked_fill(~attention_mask, min_dtype) + attention_mask = attention_mask.to(dtype) + + # Call Flash Attention with dynamic masking + out = flash_dmattn_func( + query_states, key_states, value_states, + attn_mask=attention_mask, attn_bias=attention_bias, + scale=softmax_scale, is_causal=is_causal + ) + + return out[0] if isinstance(out, tuple) else out ``` -#### Stage 3: Softmax and Sparse Attention*V +The backend processing stages: + +1. **Bias Initialization**: Create zero bias tensor if not provided +2. **Mask Application**: Apply boolean attention mask to bias tensor +3. **Flash Attention Call**: Execute optimized CUDA kernels with sparse patterns + +#### Updated Forward Algorithm + +The implementation introduces unified block-level skip logic that optimizes computation by skipping entire tiles when they are fully masked: + ```cpp -// Online softmax computation (unchanged from Flash Attention) -softmax.template online_softmax(acc_s); +// Forward pass with unified skip logic +for m_block in M_tiles: + load Q_tile + for n_block in N_tiles_stream: + load mask_block + any_active = OR(mask_block) // Block-level skip decision + if !any_active: + advance_pointers() // Skip computation, advance to next tile + continue + + // Only execute for active tiles + load K_tile, V_tile // Load data only when needed + S = Q_tile @ K_tile^T + bias_block // Sparse Q*K^T GEMM + S_masked = apply_mask(S, mask_block) // Apply dynamic masking + P = softmax(S_masked, LSE_cache) // Softmax with LSE caching + O_partial += P @ V_tile // Sparse Score*V GEMM +write O +``` -// Sparse attention*V computation -sparse_gemm(acc_o, acc_s, tSrV, tSsS, tSsV, tSrActiveMask, - tiled_mma, smem_tiled_copy_S, smem_tiled_copy_V, - smem_thr_copy_S, smem_thr_copy_V); +Key improvements: +- **Block-level Skip Logic**: OR-reduction over entire (BlockM ร— BlockN) tile determines if computation is needed +- **Early Skip Decision**: Mask evaluation happens before expensive K/V loading and computation +- **Pointer Management**: Safe pointer advancement ensures correct memory layout for subsequent tiles + +#### Updated Backward Algorithm + +The backward pass also benefits from the unified skip logic, maintaining numerical correctness while significantly reducing computation for sparse patterns: + +```cpp +// Backward pass with unified skip logic +for m_block in reversed(M_tiles): + load Q_tile, dO_tile + init accum_dQ + for n_block in N_tiles_stream: + load mask_block + any_active = OR(mask_block) // Same skip decision as forward + if !any_active: + advance_pointers_zero_side_outputs() // Skip computation, zero side outputs + continue + + // Only execute for active tiles + load K_tile, V_tile + + # Recompute (identical to forward for active tiles) + S = Q_tile @ K_tile^T + bias_block + P = softmax(S, LSE_cache) // Use cached LSE for stability + + # Gradient computation chain (5 GEMMs) + dV += P^T @ dO_tile // Accumulate dV + dP = dO_tile @ V_tile^T // Compute dP + dS = g(P, dP) // dS = (dP - (P โŠ™ dP).sum(axis)) * P + dQ += dS @ K_tile // Accumulate dQ + dK += dS^T @ Q_tile // Accumulate dK + write dQ, accumulate dK, dV ``` -## Sparse Computation Strategy +Key features: +- **Recomputation Strategy**: Forward computation is recomputed only for active tiles to maintain numerical precision +- **LSE Caching**: Uses cached log-sum-exp values from forward pass for stable softmax recomputation +- **Gradient Chain**: All five gradient GEMMs are skipped for fully masked tiles, maintaining mathematical correctness +- **Zero Handling**: Properly handles zero contributions from skipped tiles in accumulation + +#### Skip Logic Correctness + +The mathematical correctness of the skip logic relies on the following principles: + +1. **Forward Skip**: If a tile is entirely masked (active_mask = 0), its contribution to the output is exactly zero: + ``` + O_contribution = P @ V = 0 @ V = 0 + ``` + +2. **Backward Skip**: For fully masked tiles, all intermediate gradients are zero: + ``` + P = 0 โŸน dS = 0 โŸน dQ = dK = dV = 0 (from this tile) + ``` + +3. **LSE Preservation**: Skipped tiles don't contribute to the log-sum-exp, maintaining numerical stability. + +### Sparse Computation Strategy + +### Block-level Skip Logic + +The implementation introduces unified block-level skip logic that operates at the tile granularity rather than individual elements: + +1. **Tile-level Active Detection**: + ```cpp + any_active = OR_reduce(mask_block) // Single bit indicating if any position in tile is active + ``` + +2. **Skip Decision**: Binary branch based on tile activity: + ```cpp + if (!any_active) { + advance_pointers(); // Forward: skip all computation + advance_pointers_zero_outputs(); // Backward: skip computation, zero side outputs + continue; + } + ``` + +3. **Computational Benefits**: + - Skip entire K/V loads for inactive tiles + - Eliminate all 5 GEMMs in backward pass for inactive tiles + - Reduce memory bandwidth and arithmetic operations proportional to sparsity ### Sparsity Pattern Recognition The Dynamic Mask Attention implements structured sparsity based on learned importance scores: -1. **ZOH State Computation**: `dt_states = exp(A * softplus(V @ dt_proj^T))` - - Learned projection matrix `dt_proj` maps value features to importance scores - - Coefficient `A` controls the dynamic range of importance values - - Exponential activation ensures positive importance scores +1. **Attention Bias Computation**: Attention bias values are computed based on dynamic states derived from value tensors + - Learned projection matrices map value features to importance scores + - Coefficient parameters control the dynamic range of importance values + - Activation functions ensure appropriate bias magnitude + +2. **Binary Attention Mask**: + - 1.0 for positions that should be computed + - 0.0 for positions that should be skipped + +### Performance Model (Updated) + +For block-level sparsity with active tile fraction $p$, skip overhead ratio $\varepsilon$, and early-exit efficiency $\eta$: -2. **TopK Selection**: For sequences longer than `keep_window_size`: - - Select top-K most important positions per query token - - K = `keep_window_size` (typically 512-2048) - - Maintains fixed computational complexity regardless of sequence length +$$ +\text{Speedup} \approx \frac{1}{p + (1-p)(\varepsilon + \eta \cdot \text{LoadOverhead})} +$$ -3. **Binary Active Mask**: - - 1.0 for positions selected by TopK (compute) - - 0.0 for positions not selected (skip computation) +Where: +- $p$: fraction of active tiles +- $\varepsilon$: skip branching overhead +- $\eta$: efficiency of early memory load exit +- $\text{LoadOverhead}$: relative cost of K/V loading vs computation -### Sparse GEMM Implementation +Upper bound as $\varepsilon, \eta \to 0$: $1/p$ -The sparse GEMM operations leverage the active mask to skip computation: +### Shared Memory Aliasing + +The implementation introduces smart shared memory aliasing to reduce footprint and enable larger tile sizes: + +1. **sMask โ†” sP Aliasing**: Mask shared memory region is reused for storing softmax probabilities P after mask consumption +2. **sBias โ†” sdS Aliasing**: Bias shared memory region is reused for gradient computations dS +3. **Barrier Synchronization**: Explicit `__syncthreads()` calls ensure safe transitions between aliased usage ```cpp -template -__forceinline__ __device__ void sparse_gemm_impl( - TiledMma tiled_mma, - AccType &acc, - AType &tCrA, - BType &tCrB, - MaskType &active_mask -) { - // Convert layouts for efficient indexing - auto acc_rowcol = make_tensor(acc.data(), convert_layout_acc_rowcol(acc.layout())); - auto mask_rowcol = make_tensor(active_mask.data(), convert_layout_acc_rowcol(active_mask.layout())); - - #pragma unroll - for (int mi = 0; mi < size<0, 1>(acc_rowcol); ++mi) { - #pragma unroll - for (int ni = 0; ni < size<1, 1>(acc_rowcol); ++ni) { - // Check if this position should be computed - if (mask_rowcol(0, mi, 0, ni) != 0.0f) { - // Perform computation only for active positions - gemm(tiled_mma, acc(_, mi, _, ni), tCrA(_, mi, _), tCrB(_, _, ni)); - } - // Skip computation for masked positions (acc remains unchanged) - } - } -} +// Example aliasing pattern +load mask -> sMask +any_active = or_reduce(sMask) +if any_active: + compute S + __syncthreads() // ensure mask fully consumed + softmax -> write P into aliased region (sP) // reuse sMask region as sP + ... +__syncthreads() // ensure dS consumed +// reuse sBias region as sdS in next iteration ``` ### Memory Efficiency Optimizations -1. **Shared Memory Reuse**: ZOH states and active masks share copy infrastructure with Q/K/V tensors -2. **Register Allocation**: Critical masking operations performed in registers to minimize memory traffic -3. **Coalesced Access**: Memory access patterns optimized for GPU memory hierarchy -4. **Template Specialization**: Compile-time optimization eliminates runtime branching +1. **Shared Memory Aliasing**: Smart reuse of memory regions (sMask โ†” sP, sBias โ†” sdS) with explicit barrier synchronization +2. **Block-level Skip**: Early exit from computation and memory loading for inactive tiles +3. **LSE Caching**: Forward pass LSE values cached and reused in backward pass for numerical stability +4. **Register-Optimized Operations**: Critical masking and gradient operations performed in register memory +5. **Coalesced Memory Access**: Optimized access patterns for GPU memory hierarchy +6. **Template Specialization**: Compile-time optimization eliminates runtime branching overhead ## Memory Layout ### Tensor Memory Organization -The Dynamic Mask Attention extends Flash Attention's memory layout to include ZOH states and active masks: +The Dynamic Mask Attention extends Flash Attention's memory layout to include attention masks and attention bias: ``` Global Memory Layout: @@ -752,8 +1299,8 @@ Global Memory Layout: โ”‚ Q: [batch, seqlen_q, num_heads, head_dim] โ”‚ โ”‚ K: [batch, seqlen_k, num_heads_k, head_dim] โ”‚ โ”‚ V: [batch, seqlen_k, num_heads_k, head_dim] โ”‚ -โ”‚ ZOH: [batch, num_heads_k, seqlen_q, seqlen_k] โ”‚ -โ”‚ AM: [batch, num_heads_k, seqlen_q, seqlen_k] โ”‚ +โ”‚ Mask: [batch, num_heads_k, seqlen_q, seqlen_k] โ”‚ +โ”‚ Bias: [batch, num_heads_k, seqlen_q, seqlen_k] โ”‚ โ”‚ Output: [batch, seqlen_q, num_heads, head_dim] โ”‚ โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜ @@ -761,32 +1308,32 @@ Shared Memory Layout (per thread block): โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ” โ”‚ Q Tile: [kBlockM, head_dim] โ”‚ K Tile: [kBlockN, head_dim] โ”‚ โ”‚ V Tile: [kBlockN, head_dim] โ”‚ S Tile: [kBlockM, kBlockN] โ”‚ -โ”‚ ZOH Tile: [kBlockM, kBlockN] โ”‚ AM Tile: [kBlockM, kBlockN] โ”‚ +โ”‚ AM Tile: [kBlockM, kBlockN] โ”‚ Bias Tile: [kBlockM, kBlockN] โ”‚ โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜ Register Memory (per thread): โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ” โ”‚ Q Frag: [MMA_M, head_dim/N] โ”‚ K Frag: [MMA_N, head_dim/N] โ”‚ โ”‚ V Frag: [MMA_N, head_dim/N] โ”‚ S Frag: [MMA_M, MMA_N] โ”‚ -โ”‚ ZOH Frag: [MMA_M, MMA_N] โ”‚ AM Frag: [MMA_M, MMA_N] โ”‚ +โ”‚ AM Frag: [MMA_M, MMA_N] โ”‚ Bias Frag: [MMA_M, MMA_N] โ”‚ โ”‚ Acc Frag: [MMA_M, head_dim/N] โ”‚ โ”‚ โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜ ``` ### Memory Access Patterns -#### ZOH States and Active Mask Loading +#### Attention Mask and Attention Bias Loading ```cpp // Global to Shared Memory (coalesced access) -Tensor tSgZOH = local_partition(mZOH, smem_tiled_copy_ZOH, thread_idx); -Tensor tSsZOH = local_partition(sZOH, smem_tiled_copy_ZOH, thread_idx); +Tensor tSgBias = local_partition(mBias, smem_tiled_copy_Bias, thread_idx); +Tensor tSsBias = local_partition(sBias, smem_tiled_copy_Bias, thread_idx); // Each thread loads a contiguous chunk to maximize memory bandwidth -copy(smem_tiled_copy_ZOH, tSgZOH, tSsZOH); +copy(smem_tiled_copy_Bias, tSgBias, tSsBias); // Shared to Register Memory (bank-conflict-free) -Tensor tSrZOH = local_partition(sZOH, smem_thr_copy_ZOH, thread_idx); -copy(smem_thr_copy_ZOH, tSsZOH, tSrZOH); +Tensor tSrBias = local_partition(sBias, smem_thr_copy_Bias, thread_idx); +copy(smem_thr_copy_Bias, tSsBias, tSrBias); ``` #### Memory Layout Transformations @@ -806,33 +1353,47 @@ auto convert_layout_acc_rowcol = [](auto layout) { ### Shared Memory Optimization #### Bank Conflict Avoidance -- ZOH states and active masks use the same copy patterns as Q/K/V to avoid bank conflicts +- Attention bias and attention masks use the same copy patterns as Q/K/V to avoid bank conflicts - Padding added when necessary to ensure 128-bit aligned access - Thread block size chosen to maximize occupancy while maintaining memory efficiency #### Memory Coalescing ```cpp // Example: Loading 128-bit aligned chunks for optimal bandwidth -using SmemCopyAtomZOH = Copy_Atom; // 128-bit loads -using SmemCopyAtomActiveMask = Copy_Atom; +using SmemCopyAtomBias = Copy_Atom; // 128-bit loads +using SmemCopyAtomAttnMask = Copy_Atom; ``` ## Performance Considerations ### Memory Efficiency -- **Reduced Memory Bandwidth**: Sparse computation reduces memory traffic -- **Optimized Layouts**: Tensor layouts optimized for GPU memory hierarchy -- **Shared Memory Reuse**: Efficient use of limited shared memory resources +- **Shared Memory Aliasing**: Smart memory reuse (sMask โ†” sP, sBias โ†” sdS) reduces footprint by ~30% +- **Block-level Skip**: Early exit eliminates unnecessary memory loads for inactive tiles +- **LSE Caching**: Forward pass LSE values cached and reused in backward pass for numerical stability +- **Coalesced Access**: Optimized tensor layouts for GPU memory hierarchy ### Computational Efficiency -- **Sparse GEMM**: Skips computation for masked regions -- **Fused Operations**: Masking integrated into existing computation kernels -- **Warp-Level Optimization**: Optimized for GPU warp execution model +- **Unified Skip Logic**: Both forward and backward passes benefit from block-level computation skipping +- **5-GEMM Chain Skip**: Complete gradient computation chain skipped for inactive tiles +- **Early Branch Decision**: Mask OR-reduction happens before expensive K/V loads +- **Warp-Level Optimization**: Operations optimized for GPU warp execution model ### Scalability -- **Long Sequence Support**: Efficient handling of sequences > 32K tokens -- **Configurable Sparsity**: `keep_window_size` parameter controls sparsity level -- **Multi-Head Support**: Efficient handling of multiple attention heads +- **Block-level Granularity**: Tile-level sparsity more efficient than element-level for long sequences +- **Multi-Head Support**: Efficient handling of multiple attention heads with per-head sparsity patterns +- **Barrier Optimization**: Minimal synchronization overhead through smart aliasing strategies + +### Performance Model + +Expected speedup for various sparsity levels: +- **50% sparsity**: ~1.8x speedup +- **75% sparsity**: ~3.2x speedup +- **90% sparsity**: ~6.5x speedup + +Performance factors: +- Skip overhead typically <5% of dense computation time +- Memory bandwidth reduction scales linearly with sparsity +- Shared memory aliasing enables 20-30% larger tile sizes ## API Changes @@ -840,18 +1401,13 @@ using SmemCopyAtomActiveMask = Copy_Atom; The Dynamic Mask Attention integration introduces new required parameters to the forward pass: -- **`zoh`** (`torch.Tensor`): ZOH states tensor of shape `(batch, num_heads_k, seqlen_q, seqlen_k)` - - Contains dynamic attention bias values derived from value states - - Must have the same dtype and device as Q/K/V tensors - -- **`active_mask`** (`torch.Tensor`): Active mask tensor of shape `(batch, num_heads_k, seqlen_q, seqlen_k)` +- **`attn_mask`** (`torch.Tensor`): Attention mask tensor of shape `(batch, num_kv_heads, seqlen_q, seqlen_k)` - Binary mask (1.0 = compute, 0.0 = skip) indicating which positions should be processed - Determines the sparsity pattern for computational efficiency -- **`keep_window_size`** (`int`): Sparsity control parameter - - Maximum number of key positions to attend to per query token - - Controls the computational complexity and memory usage - - Typical values: 512-2048 for long sequences +- **`attn_bias`** (`torch.Tensor`): Attention bias tensor of shape `(batch, num_kv_heads, seqlen_q, seqlen_k)` + - Contains dynamic attention bias values applied to attention scores before softmax + - Must have the same dtype and device as Q/K/V tensors ### Updated Function Signature @@ -860,141 +1416,1457 @@ def fwd( q: torch.Tensor, # Query tensor k: torch.Tensor, # Key tensor v: torch.Tensor, # Value tensor - zoh: torch.Tensor, # ZOH states (NEW) - active_mask: torch.Tensor, # Active mask (NEW) + attn_mask: torch.Tensor, # Attention mask (REQUIRED) + attn_bias: torch.Tensor, # Attention bias (REQUIRED) out: Optional[torch.Tensor] = None, # Pre-allocated output - p_dropout: float = 0.0, # Dropout probability softmax_scale: float = None, # Attention scaling is_causal: bool = False, # Causal masking - keep_window_size: int = 2048, # Sparsity control (NEW) softcap: float = 0.0, # Soft capping return_softmax: bool = False, # Return attention weights - gen: Optional[torch.Generator] = None # Random generator ) -> List[torch.Tensor] ``` ### Backward Compatibility -**Breaking Change Notice**: The integration requires ZOH states and active mask tensors as mandatory parameters. This is a breaking change from the original Flash Attention API. +**Breaking Change Notice**: The integration requires attention bias and attention mask tensors as mandatory parameters. This is a breaking change from the original Flash Attention API. **Migration Path**: Users need to: -1. Implement ZOH state computation using the `prepare_dynamic_mask` function -2. Update function calls to include the new required parameters -3. Choose appropriate `keep_window_size` values based on their use case +1. Add attention mask and bias generation logic to attention modules +2. Implement appropriate mask and bias computation within the attention forward pass +3. Ensure proper tensor shapes and dtypes for mask and bias tensors ### Complete Usage Example ```python import torch -import torch.nn.functional as F -import flash_dma - -# Setup -batch_size, seqlen_q, seqlen_k = 2, 4096, 4096 -num_heads, head_dim = 12, 128 -device, dtype = 'cuda', torch.bfloat16 - -# Input tensors -q = torch.randn(batch_size, seqlen_q, num_heads, head_dim, device=device, dtype=dtype) -k = torch.randn(batch_size, seqlen_k, num_heads, head_dim, device=device, dtype=dtype) -v = torch.randn(batch_size, seqlen_k, num_heads, head_dim, device=device, dtype=dtype) - -# Dynamic Mask Attention requires additional parameters -dt_proj = torch.randn(num_heads, num_heads * head_dim, device=device, dtype=dtype) -A = torch.randn(num_heads, device=device, dtype=dtype) - -# Step 1: Compute ZOH states -dt_states = torch.matmul( - v.transpose(-2, -3).reshape(batch_size, seqlen_k, -1), - dt_proj.T -) -dt_states = torch.exp(A * F.softplus(dt_states)).transpose(-1, -2) - -# Step 2: Generate dynamic masks -zoh_states, active_mask = flash_dma.prepare_dynamic_mask( - q, dt_states, keep_window_size=2048, attention_mask=None -) - -# Step 3: Run Dynamic Mask Attention -output = flash_dma.fwd( - q, k, v, zoh_states, active_mask, - keep_window_size=2048, - softmax_scale=1.0 / (head_dim ** 0.5), - is_causal=False -) - -print(f"Output shape: {output[0].shape}") # [batch_size, seqlen_q, num_heads, head_dim] +import torch.nn as nn +import flash_dmattn_cuda as flash_dmattn + +class DynamicMaskAttention(nn.Module): + def __init__(self, config): + super().__init__() + self.num_heads = config.num_attention_heads + self.num_kv_heads = config.num_key_value_heads + self.head_dim = config.hidden_size // config.num_attention_heads + self.scaling = 1.0 / math.sqrt(self.head_dim) + + # Standard attention projections + self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=False) + self.k_proj = nn.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False) + self.v_proj = nn.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, config.hidden_size, bias=False) + + def forward(self, hidden_states, attention_mask=None, attention_bias=None): + batch_size, seq_len, _ = hidden_states.shape + + # Project to Q, K, V + query_states = self.q_proj(hidden_states).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(batch_size, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(batch_size, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2) + + # Prepare mask and bias tensors with proper shapes + if attention_mask is None: + attention_mask = torch.ones((batch_size, self.num_kv_heads, seq_len, seq_len), + dtype=query_states.dtype, device=query_states.device) + + if attention_bias is None: + attention_bias = torch.zeros((batch_size, self.num_kv_heads, seq_len, seq_len), + dtype=query_states.dtype, device=query_states.device) + + # Call Flash Dynamic Mask Attention + output, _ = flash_dmattn.fwd( + query_states, key_states, value_states, + attention_mask, attention_bias, + None, # out + self.scaling, # softmax_scale + False, # is_causal + 0.0, # softcap + False # return_softmax + ) + + # Output projection + output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, -1) + return self.o_proj(output) +``` + + # Call attention implementation + attn_output, attn_weights = flash_dynamic_mask_attention_forward( + self, + query_states, + key_states, + value_states, + attention_mask=attention_mask, + attention_bias=attn_bias, + scaling=self.scaling, + ) + + return attn_output, attn_weights ``` -### Integration with Existing Codebases +The attention bias generation process: -For users migrating from Flash Attention, the typical changes required are: +1. **Value-based Dynamic States**: + ```python + dt_states = self.dt_proj(value_states_flattened) + dt_states = torch.exp(self.A * F.softplus(dt_states)).transpose(-1, -2) + ``` -```python -# Before (Flash Attention) -output = flash_attn.flash_attn_func(q, k, v, dropout_p=0.1, softmax_scale=scale, causal=True) +2. **Bias Expansion**: + ```python + attn_bias = dt_states[:, :, None, :].expand(-1, -1, query_len, -1) + ``` -# After (Dynamic Mask Attention) -# 1. Add ZOH computation -dt_states = compute_dt_states(v, dt_proj, A) -zoh_states, active_mask = prepare_dynamic_mask(q, dt_states, keep_window_size=2048) +3. **Mask Processing**: Done internally in `_flash_dynamic_mask_attention_forward` -# 2. Update function call -output = flash_dma.fwd(q, k, v, zoh_states, active_mask, - p_dropout=0.1, softmax_scale=scale, is_causal=True, - keep_window_size=2048) -``` -## Future Enhancements +### CUDA Backend: Sparse Attention Computation -### Planned Improvements +The CUDA backend implements the sparse attention computation through `_flash_dynamic_mask_attention_forward`: -1. **Backward Pass Integration**: Complete gradient computation support for training Dynamic Mask Attention models - - Sparse gradient computation for ZOH states - - Efficient gradient propagation through active masks - - Memory-optimized backward kernels +```python +def _flash_dynamic_mask_attention_forward( + query_states, key_states, value_states, + attention_mask, attention_bias, + query_length, key_length, + is_causal, softmax_scale=None, softcap=None, + target_dtype=None, implementation=None, **kwargs +): + dtype = query_states.dtype + min_dtype = torch.finfo(dtype).min + batch_size, _, num_kv_heads, _ = key_states.shape + + # Initialize attention bias if not provided + if attention_bias is None: + attention_bias = torch.zeros( + (batch_size, num_kv_heads, query_length, key_length), + dtype=dtype, device=query_states.device + ) + + # Apply attention mask to bias + if attention_mask is not None: + attention_bias = attention_bias.masked_fill(~attention_mask, min_dtype) + attention_mask = attention_mask.to(dtype) + + # Call Flash Attention with dynamic masking + out = flash_dmattn_func( + query_states, key_states, value_states, + attn_mask=attention_mask, attn_bias=attention_bias, + scale=softmax_scale, is_causal=is_causal + ) + + return out[0] if isinstance(out, tuple) else out +``` + +The backend processing stages: + +1. **Bias Initialization**: Create zero bias tensor if not provided +2. **Mask Application**: Apply boolean attention mask to bias tensor +3. **Flash Attention Call**: Execute optimized CUDA kernels with sparse patterns + +#### Updated Forward Algorithm + +The implementation introduces unified block-level skip logic that optimizes computation by skipping entire tiles when they are fully masked: + +```cpp +// Forward pass with unified skip logic +for m_block in M_tiles: + load Q_tile + for n_block in N_tiles_stream: + load mask_block + any_active = OR(mask_block) // Block-level skip decision + if !any_active: + advance_pointers() // Skip computation, advance to next tile + continue + + // Only execute for active tiles + load K_tile, V_tile // Load data only when needed + S = Q_tile @ K_tile^T + bias_block // Sparse Q*K^T GEMM + S_masked = apply_mask(S, mask_block) // Apply dynamic masking + P = softmax(S_masked, LSE_cache) // Softmax with LSE caching + O_partial += P @ V_tile // Sparse Score*V GEMM +write O +``` + +Key improvements: +- **Block-level Skip Logic**: OR-reduction over entire (BlockM ร— BlockN) tile determines if computation is needed +- **Early Skip Decision**: Mask evaluation happens before expensive K/V loading and computation +- **Pointer Management**: Safe pointer advancement ensures correct memory layout for subsequent tiles + +#### Updated Backward Algorithm + +The backward pass also benefits from the unified skip logic, maintaining numerical correctness while significantly reducing computation for sparse patterns: + +```cpp +// Backward pass with unified skip logic +for m_block in reversed(M_tiles): + load Q_tile, dO_tile + init accum_dQ + for n_block in N_tiles_stream: + load mask_block + any_active = OR(mask_block) // Same skip decision as forward + if !any_active: + advance_pointers_zero_side_outputs() // Skip computation, zero side outputs + continue + + // Only execute for active tiles + load K_tile, V_tile + + # Recompute (identical to forward for active tiles) + S = Q_tile @ K_tile^T + bias_block + P = softmax(S, LSE_cache) // Use cached LSE for stability + + # Gradient computation chain (5 GEMMs) + dV += P^T @ dO_tile // Accumulate dV + dP = dO_tile @ V_tile^T // Compute dP + dS = g(P, dP) // dS = (dP - (P โŠ™ dP).sum(axis)) * P + dQ += dS @ K_tile // Accumulate dQ + dK += dS^T @ Q_tile // Accumulate dK + write dQ, accumulate dK, dV +``` + +Key features: +- **Recomputation Strategy**: Forward computation is recomputed only for active tiles to maintain numerical precision +- **LSE Caching**: Uses cached log-sum-exp values from forward pass for stable softmax recomputation +- **Gradient Chain**: All five gradient GEMMs are skipped for fully masked tiles, maintaining mathematical correctness +- **Zero Handling**: Properly handles zero contributions from skipped tiles in accumulation + +#### Skip Logic Correctness + +The mathematical correctness of the skip logic relies on the following principles: + +1. **Forward Skip**: If a tile is entirely masked (active_mask = 0), its contribution to the output is exactly zero: + ``` + O_contribution = P @ V = 0 @ V = 0 + ``` + +2. **Backward Skip**: For fully masked tiles, all intermediate gradients are zero: + ``` + P = 0 โŸน dS = 0 โŸน dQ = dK = dV = 0 (from this tile) + ``` + +3. **LSE Preservation**: Skipped tiles don't contribute to the log-sum-exp, maintaining numerical stability. + +### Sparse Computation Strategy + +### Block-level Skip Logic + +The implementation introduces unified block-level skip logic that operates at the tile granularity rather than individual elements: + +1. **Tile-level Active Detection**: + ```cpp + any_active = OR_reduce(mask_block) // Single bit indicating if any position in tile is active + ``` + +2. **Skip Decision**: Binary branch based on tile activity: + ```cpp + if (!any_active) { + advance_pointers(); // Forward: skip all computation + advance_pointers_zero_outputs(); // Backward: skip computation, zero side outputs + continue; + } + ``` + +3. **Computational Benefits**: + - Skip entire K/V loads for inactive tiles + - Eliminate all 5 GEMMs in backward pass for inactive tiles + - Reduce memory bandwidth and arithmetic operations proportional to sparsity + +### Sparsity Pattern Recognition + +The Dynamic Mask Attention implements structured sparsity based on learned importance scores: + +1. **Attention Bias Computation**: Attention bias values are computed based on dynamic states derived from value tensors + - Learned projection matrices map value features to importance scores + - Coefficient parameters control the dynamic range of importance values + - Activation functions ensure appropriate bias magnitude + +2. **Binary Attention Mask**: + - 1.0 for positions that should be computed + - 0.0 for positions that should be skipped + +### Performance Model (Updated) + +For block-level sparsity with active tile fraction $p$, skip overhead ratio $\varepsilon$, and early-exit efficiency $\eta$: + +$$ +\text{Speedup} \approx \frac{1}{p + (1-p)(\varepsilon + \eta \cdot \text{LoadOverhead})} +$$ + +Where: +- $p$: fraction of active tiles +- $\varepsilon$: skip branching overhead +- $\eta$: efficiency of early memory load exit +- $\text{LoadOverhead}$: relative cost of K/V loading vs computation + +Upper bound as $\varepsilon, \eta \to 0$: $1/p$ + +### Shared Memory Aliasing + +The implementation introduces smart shared memory aliasing to reduce footprint and enable larger tile sizes: + +1. **sMask โ†” sP Aliasing**: Mask shared memory region is reused for storing softmax probabilities P after mask consumption +2. **sBias โ†” sdS Aliasing**: Bias shared memory region is reused for gradient computations dS +3. **Barrier Synchronization**: Explicit `__syncthreads()` calls ensure safe transitions between aliased usage + +```cpp +// Example aliasing pattern +load mask -> sMask +any_active = or_reduce(sMask) +if any_active: + compute S + __syncthreads() // ensure mask fully consumed + softmax -> write P into aliased region (sP) // reuse sMask region as sP + ... +__syncthreads() // ensure dS consumed +// reuse sBias region as sdS in next iteration +``` + +### Memory Efficiency Optimizations + +1. **Shared Memory Aliasing**: Smart reuse of memory regions (sMask โ†” sP, sBias โ†” sdS) with explicit barrier synchronization +2. **Block-level Skip**: Early exit from computation and memory loading for inactive tiles +3. **LSE Caching**: Forward pass LSE values cached and reused in backward pass for numerical stability +4. **Register-Optimized Operations**: Critical masking and gradient operations performed in register memory +5. **Coalesced Memory Access**: Optimized access patterns for GPU memory hierarchy +6. **Template Specialization**: Compile-time optimization eliminates runtime branching overhead + +## Memory Layout + +### Tensor Memory Organization + +The Dynamic Mask Attention extends Flash Attention's memory layout to include attention masks and attention bias: + +``` +Global Memory Layout: +โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ” +โ”‚ Q: [batch, seqlen_q, num_heads, head_dim] โ”‚ +โ”‚ K: [batch, seqlen_k, num_heads_k, head_dim] โ”‚ +โ”‚ V: [batch, seqlen_k, num_heads_k, head_dim] โ”‚ +โ”‚ AttnMask: [batch, num_kv_heads, seqlen_q, seqlen_k] โ”‚ +โ”‚ Bias: [batch, num_kv_heads, seqlen_q, seqlen_k] โ”‚ +โ”‚ Output: [batch, seqlen_q, num_heads, head_dim] โ”‚ +โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜ + +Shared Memory Layout (per thread block): +โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ” +โ”‚ Q Tile: [kBlockM, head_dim] โ”‚ K Tile: [kBlockN, head_dim] โ”‚ +โ”‚ V Tile: [kBlockN, head_dim] โ”‚ S Tile: [kBlockM, kBlockN] โ”‚ +โ”‚ AM Tile: [kBlockM, kBlockN] โ”‚ Bias Tile: [kBlockM, kBlockN] โ”‚ +โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜ + +Register Memory (per thread): +โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ” +โ”‚ Q Frag: [MMA_M, head_dim/N] โ”‚ K Frag: [MMA_N, head_dim/N] โ”‚ +โ”‚ V Frag: [MMA_N, head_dim/N] โ”‚ S Frag: [MMA_M, MMA_N] โ”‚ +โ”‚ AM Frag: [MMA_M, MMA_N] โ”‚ Bias Frag: [MMA_M, MMA_N] โ”‚ +โ”‚ Acc Frag: [MMA_M, head_dim/N] โ”‚ โ”‚ +โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜ +``` + +### Memory Access Patterns + +#### Attention Mask and Attention Bias Loading +```cpp +// Global to Shared Memory (coalesced access) +Tensor tSgBias = local_partition(mBias, smem_tiled_copy_Bias, thread_idx); +Tensor tSsBias = local_partition(sBias, smem_tiled_copy_Bias, thread_idx); + +// Each thread loads a contiguous chunk to maximize memory bandwidth +copy(smem_tiled_copy_Bias, tSgBias, tSsBias); + +// Shared to Register Memory (bank-conflict-free) +Tensor tSrBias = local_partition(sBias, smem_thr_copy_Bias, thread_idx); +copy(smem_thr_copy_Bias, tSsBias, tSrBias); +``` + +#### Memory Layout Transformations +```cpp +// Convert MMA accumulator layout to row-column layout for masking +// From: (MMA=4, MMA_M, MMA_N) -> (nrow=(2, MMA_M), ncol=(2, MMA_N)) +auto convert_layout_acc_rowcol = [](auto layout) { + return make_layout( + make_layout(make_shape(Int<2>{}, get<1>(layout.shape())), + make_stride(Int(layout.stride())* 2>{}, get<1>(layout.stride()))), + make_layout(make_shape(Int<2>{}, get<2>(layout.shape())), + make_stride(Int<1>{}, Int<2>{})) + ); +}; +``` + +### Shared Memory Optimization + +#### Bank Conflict Avoidance +- Attention bias and attention masks use the same copy patterns as Q/K/V to avoid bank conflicts +- Padding added when necessary to ensure 128-bit aligned access +- Thread block size chosen to maximize occupancy while maintaining memory efficiency + +#### Memory Coalescing +```cpp +// Example: Loading 128-bit aligned chunks for optimal bandwidth +using SmemCopyAtomBias = Copy_Atom; // 128-bit loads +using SmemCopyAtomAttnMask = Copy_Atom; +``` + +## Performance Considerations + +### Memory Efficiency +- **Shared Memory Aliasing**: Smart memory reuse (sMask โ†” sP, sBias โ†” sdS) reduces footprint by ~30% +- **Block-level Skip**: Early exit eliminates unnecessary memory loads for inactive tiles +- **LSE Caching**: Forward pass LSE values cached and reused in backward pass for numerical stability +- **Coalesced Access**: Optimized tensor layouts for GPU memory hierarchy + +### Computational Efficiency +- **Unified Skip Logic**: Both forward and backward passes benefit from block-level computation skipping +- **5-GEMM Chain Skip**: Complete gradient computation chain skipped for inactive tiles +- **Early Branch Decision**: Mask OR-reduction happens before expensive K/V loads +- **Warp-Level Optimization**: Operations optimized for GPU warp execution model + +### Scalability +- **Block-level Granularity**: Tile-level sparsity more efficient than element-level for long sequences +- **Multi-Head Support**: Efficient handling of multiple attention heads with per-head sparsity patterns +- **Barrier Optimization**: Minimal synchronization overhead through smart aliasing strategies + +### Performance Model + +Expected speedup for various sparsity levels: +- **50% sparsity**: ~1.8x speedup +- **75% sparsity**: ~3.2x speedup +- **90% sparsity**: ~6.5x speedup + +Performance factors: +- Skip overhead typically <5% of dense computation time +- Memory bandwidth reduction scales linearly with sparsity +- Shared memory aliasing enables 20-30% larger tile sizes + +## API Changes + +### New Required Parameters + +The Dynamic Mask Attention integration introduces new required parameters to the forward pass: + +- **`attn_mask`** (`torch.Tensor`): Attention mask tensor of shape `(batch, num_kv_heads, seqlen_q, seqlen_k)` + - Binary mask (1.0 = compute, 0.0 = skip) indicating which positions should be processed + - Determines the sparsity pattern for computational efficiency + +- **`attn_bias`** (`torch.Tensor`): Attention bias tensor of shape `(batch, num_kv_heads, seqlen_q, seqlen_k)` + - Contains dynamic attention bias values applied to attention scores before softmax + - Must have the same dtype and device as Q/K/V tensors + +### Updated Function Signature + +```python +def fwd( + q: torch.Tensor, # Query tensor + k: torch.Tensor, # Key tensor + v: torch.Tensor, # Value tensor + attn_mask: torch.Tensor, # Attention mask (REQUIRED) + attn_bias: torch.Tensor, # Attention bias (REQUIRED) + out: Optional[torch.Tensor] = None, # Pre-allocated output + softmax_scale: float = None, # Attention scaling + is_causal: bool = False, # Causal masking + softcap: float = 0.0, # Soft capping + return_softmax: bool = False, # Return attention weights +) -> List[torch.Tensor] +``` + +### Backward Compatibility + +**Breaking Change Notice**: The integration requires attention bias and attention mask tensors as mandatory parameters. This is a breaking change from the original Flash Attention API. + +**Migration Path**: Users need to: +1. Add attention mask and bias generation logic to attention modules +2. Implement appropriate mask and bias computation within the attention forward pass +3. Ensure proper tensor shapes and dtypes for mask and bias tensors + +### Complete Usage Example + +```python +import torch +import torch.nn as nn +import flash_dmattn_cuda as flash_dmattn + +class DynamicMaskAttention(nn.Module): + def __init__(self, config): + super().__init__() + self.num_heads = config.num_attention_heads + self.num_kv_heads = config.num_key_value_heads + self.head_dim = config.hidden_size // config.num_attention_heads + self.scaling = 1.0 / math.sqrt(self.head_dim) + + # Standard attention projections + self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=False) + self.k_proj = nn.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False) + self.v_proj = nn.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, config.hidden_size, bias=False) + + def forward(self, hidden_states, attention_mask=None, attention_bias=None): + batch_size, seq_len, _ = hidden_states.shape + + # Project to Q, K, V + query_states = self.q_proj(hidden_states).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(batch_size, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(batch_size, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2) + + # Prepare mask and bias tensors with proper shapes + if attention_mask is None: + attention_mask = torch.ones((batch_size, self.num_kv_heads, seq_len, seq_len), + dtype=query_states.dtype, device=query_states.device) + + if attention_bias is None: + attention_bias = torch.zeros((batch_size, self.num_kv_heads, seq_len, seq_len), + dtype=query_states.dtype, device=query_states.device) + + # Call Flash Dynamic Mask Attention + output, _ = flash_dmattn.fwd( + query_states, key_states, value_states, + attention_mask, attention_bias, + None, # out + self.scaling, # softmax_scale + False, # is_causal + 0.0, # softcap + False # return_softmax + ) + + # Output projection + output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, -1) + return self.o_proj(output) +``` + + # Call attention implementation + attn_output, attn_weights = flash_dynamic_mask_attention_forward( + self, + query_states, + key_states, + value_states, + attention_mask=attention_mask, + attention_bias=attn_bias, + scaling=self.scaling, + ) + + return attn_output, attn_weights +``` + +The attention bias generation process: + +1. **Value-based Dynamic States**: + ```python + dt_states = self.dt_proj(value_states_flattened) + dt_states = torch.exp(self.A * F.softplus(dt_states)).transpose(-1, -2) + ``` + +2. **Bias Expansion**: + ```python + attn_bias = dt_states[:, :, None, :].expand(-1, -1, query_len, -1) + ``` + +3. **Mask Processing**: Done internally in `_flash_dynamic_mask_attention_forward` + + +### CUDA Backend: Sparse Attention Computation + +The CUDA backend implements the sparse attention computation through `_flash_dynamic_mask_attention_forward`: + +```python +def _flash_dynamic_mask_attention_forward( + query_states, key_states, value_states, + attention_mask, attention_bias, + query_length, key_length, + is_causal, softmax_scale=None, softcap=None, + target_dtype=None, implementation=None, **kwargs +): + dtype = query_states.dtype + min_dtype = torch.finfo(dtype).min + batch_size, _, num_kv_heads, _ = key_states.shape + + # Initialize attention bias if not provided + if attention_bias is None: + attention_bias = torch.zeros( + (batch_size, num_kv_heads, query_length, key_length), + dtype=dtype, device=query_states.device + ) + + # Apply attention mask to bias + if attention_mask is not None: + attention_bias = attention_bias.masked_fill(~attention_mask, min_dtype) + attention_mask = attention_mask.to(dtype) + + # Call Flash Attention with dynamic masking + out = flash_dmattn_func( + query_states, key_states, value_states, + attn_mask=attention_mask, attn_bias=attention_bias, + scale=softmax_scale, is_causal=is_causal + ) + + return out[0] if isinstance(out, tuple) else out +``` + +The backend processing stages: + +1. **Bias Initialization**: Create zero bias tensor if not provided +2. **Mask Application**: Apply boolean attention mask to bias tensor +3. **Flash Attention Call**: Execute optimized CUDA kernels with sparse patterns + +#### Updated Forward Algorithm + +The implementation introduces unified block-level skip logic that optimizes computation by skipping entire tiles when they are fully masked: + +```cpp +// Forward pass with unified skip logic +for m_block in M_tiles: + load Q_tile + for n_block in N_tiles_stream: + load mask_block + any_active = OR(mask_block) // Block-level skip decision + if !any_active: + advance_pointers() // Skip computation, advance to next tile + continue + + // Only execute for active tiles + load K_tile, V_tile // Load data only when needed + S = Q_tile @ K_tile^T + bias_block // Sparse Q*K^T GEMM + S_masked = apply_mask(S, mask_block) // Apply dynamic masking + P = softmax(S_masked, LSE_cache) // Softmax with LSE caching + O_partial += P @ V_tile // Sparse Score*V GEMM +write O +``` + +Key improvements: +- **Block-level Skip Logic**: OR-reduction over entire (BlockM ร— BlockN) tile determines if computation is needed +- **Early Skip Decision**: Mask evaluation happens before expensive K/V loading and computation +- **Pointer Management**: Safe pointer advancement ensures correct memory layout for subsequent tiles + +#### Updated Backward Algorithm + +The backward pass also benefits from the unified skip logic, maintaining numerical correctness while significantly reducing computation for sparse patterns: + +```cpp +// Backward pass with unified skip logic +for m_block in reversed(M_tiles): + load Q_tile, dO_tile + init accum_dQ + for n_block in N_tiles_stream: + load mask_block + any_active = OR(mask_block) // Same skip decision as forward + if !any_active: + advance_pointers_zero_side_outputs() // Skip computation, zero side outputs + continue + + // Only execute for active tiles + load K_tile, V_tile + + # Recompute (identical to forward for active tiles) + S = Q_tile @ K_tile^T + bias_block + P = softmax(S, LSE_cache) // Use cached LSE for stability + + # Gradient computation chain (5 GEMMs) + dV += P^T @ dO_tile // Accumulate dV + dP = dO_tile @ V_tile^T // Compute dP + dS = g(P, dP) // dS = (dP - (P โŠ™ dP).sum(axis)) * P + dQ += dS @ K_tile // Accumulate dQ + dK += dS^T @ Q_tile // Accumulate dK + write dQ, accumulate dK, dV +``` + +Key features: +- **Recomputation Strategy**: Forward computation is recomputed only for active tiles to maintain numerical precision +- **LSE Caching**: Uses cached log-sum-exp values from forward pass for stable softmax recomputation +- **Gradient Chain**: All five gradient GEMMs are skipped for fully masked tiles, maintaining mathematical correctness +- **Zero Handling**: Properly handles zero contributions from skipped tiles in accumulation + +#### Skip Logic Correctness + +The mathematical correctness of the skip logic relies on the following principles: + +1. **Forward Skip**: If a tile is entirely masked (active_mask = 0), its contribution to the output is exactly zero: + ``` + O_contribution = P @ V = 0 @ V = 0 + ``` + +2. **Backward Skip**: For fully masked tiles, all intermediate gradients are zero: + ``` + P = 0 โŸน dS = 0 โŸน dQ = dK = dV = 0 (from this tile) + ``` + +3. **LSE Preservation**: Skipped tiles don't contribute to the log-sum-exp, maintaining numerical stability. + +### Sparse Computation Strategy + +### Block-level Skip Logic + +The implementation introduces unified block-level skip logic that operates at the tile granularity rather than individual elements: + +1. **Tile-level Active Detection**: + ```cpp + any_active = OR_reduce(mask_block) // Single bit indicating if any position in tile is active + ``` + +2. **Skip Decision**: Binary branch based on tile activity: + ```cpp + if (!any_active) { + advance_pointers(); // Forward: skip all computation + advance_pointers_zero_outputs(); // Backward: skip computation, zero side outputs + continue; + } + ``` + +3. **Computational Benefits**: + - Skip entire K/V loads for inactive tiles + - Eliminate all 5 GEMMs in backward pass for inactive tiles + - Reduce memory bandwidth and arithmetic operations proportional to sparsity + +### Sparsity Pattern Recognition + +The Dynamic Mask Attention implements structured sparsity based on learned importance scores: + +1. **Attention Bias Computation**: Attention bias values are computed based on dynamic states derived from value tensors + - Learned projection matrices map value features to importance scores + - Coefficient parameters control the dynamic range of importance values + - Activation functions ensure appropriate bias magnitude + +2. **Binary Attention Mask**: + - 1.0 for positions that should be computed + - 0.0 for positions that should be skipped + +### Performance Model (Updated) + +For block-level sparsity with active tile fraction $p$, skip overhead ratio $\varepsilon$, and early-exit efficiency $\eta$: + +$$ +\text{Speedup} \approx \frac{1}{p + (1-p)(\varepsilon + \eta \cdot \text{LoadOverhead})} +$$ + +Where: +- $p$: fraction of active tiles +- $\varepsilon$: skip branching overhead +- $\eta$: efficiency of early memory load exit +- $\text{LoadOverhead}$: relative cost of K/V loading vs computation + +Upper bound as $\varepsilon, \eta \to 0$: $1/p$ + +### Shared Memory Aliasing + +The implementation introduces smart shared memory aliasing to reduce footprint and enable larger tile sizes: + +1. **sMask โ†” sP Aliasing**: Mask shared memory region is reused for storing softmax probabilities P after mask consumption +2. **sBias โ†” sdS Aliasing**: Bias shared memory region is reused for gradient computations dS +3. **Barrier Synchronization**: Explicit `__syncthreads()` calls ensure safe transitions between aliased usage + +```cpp +// Example aliasing pattern +load mask -> sMask +any_active = or_reduce(sMask) +if any_active: + compute S + __syncthreads() // ensure mask fully consumed + softmax -> write P into aliased region (sP) // reuse sMask region as sP + ... +__syncthreads() // ensure dS consumed +// reuse sBias region as sdS in next iteration +``` + +### Memory Efficiency Optimizations + +1. **Shared Memory Aliasing**: Smart reuse of memory regions (sMask โ†” sP, sBias โ†” sdS) with explicit barrier synchronization +2. **Block-level Skip**: Early exit from computation and memory loading for inactive tiles +3. **LSE Caching**: Forward pass LSE values cached and reused in backward pass for numerical stability +4. **Register-Optimized Operations**: Critical masking and gradient operations performed in register memory +5. **Coalesced Memory Access**: Optimized access patterns for GPU memory hierarchy +6. **Template Specialization**: Compile-time optimization eliminates runtime branching overhead + +## Memory Layout + +### Tensor Memory Organization + +The Dynamic Mask Attention extends Flash Attention's memory layout to include attention masks and attention bias: + +``` +Global Memory Layout: +โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ” +โ”‚ Q: [batch, seqlen_q, num_heads, head_dim] โ”‚ +โ”‚ K: [batch, seqlen_k, num_heads_k, head_dim] โ”‚ +โ”‚ V: [batch, seqlen_k, num_heads_k, head_dim] โ”‚ +โ”‚ AttnMask: [batch, num_kv_heads, seqlen_q, seqlen_k] โ”‚ +โ”‚ Bias: [batch, num_kv_heads, seqlen_q, seqlen_k] โ”‚ +โ”‚ Output: [batch, seqlen_q, num_heads, head_dim] โ”‚ +โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜ + +Shared Memory Layout (per thread block): +โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ” +โ”‚ Q Tile: [kBlockM, head_dim] โ”‚ K Tile: [kBlockN, head_dim] โ”‚ +โ”‚ V Tile: [kBlockN, head_dim] โ”‚ S Tile: [kBlockM, kBlockN] โ”‚ +โ”‚ AM Tile: [kBlockM, kBlockN] โ”‚ Bias Tile: [kBlockM, kBlockN] โ”‚ +โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜ + +Register Memory (per thread): +โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ” +โ”‚ Q Frag: [MMA_M, head_dim/N] โ”‚ K Frag: [MMA_N, head_dim/N] โ”‚ +โ”‚ V Frag: [MMA_N, head_dim/N] โ”‚ S Frag: [MMA_M, MMA_N] โ”‚ +โ”‚ AM Frag: [MMA_M, MMA_N] โ”‚ Bias Frag: [MMA_M, MMA_N] โ”‚ +โ”‚ Acc Frag: [MMA_M, head_dim/N] โ”‚ โ”‚ +โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜ +``` + +### Memory Access Patterns + +#### Attention Mask and Attention Bias Loading +```cpp +// Global to Shared Memory (coalesced access) +Tensor tSgBias = local_partition(mBias, smem_tiled_copy_Bias, thread_idx); +Tensor tSsBias = local_partition(sBias, smem_tiled_copy_Bias, thread_idx); + +// Each thread loads a contiguous chunk to maximize memory bandwidth +copy(smem_tiled_copy_Bias, tSgBias, tSsBias); + +// Shared to Register Memory (bank-conflict-free) +Tensor tSrBias = local_partition(sBias, smem_thr_copy_Bias, thread_idx); +copy(smem_thr_copy_Bias, tSsBias, tSrBias); +``` + +#### Memory Layout Transformations +```cpp +// Convert MMA accumulator layout to row-column layout for masking +// From: (MMA=4, MMA_M, MMA_N) -> (nrow=(2, MMA_M), ncol=(2, MMA_N)) +auto convert_layout_acc_rowcol = [](auto layout) { + return make_layout( + make_layout(make_shape(Int<2>{}, get<1>(layout.shape())), + make_stride(Int(layout.stride())* 2>{}, get<1>(layout.stride()))), + make_layout(make_shape(Int<2>{}, get<2>(layout.shape())), + make_stride(Int<1>{}, Int<2>{})) + ); +}; +``` + +### Shared Memory Optimization + +#### Bank Conflict Avoidance +- Attention bias and attention masks use the same copy patterns as Q/K/V to avoid bank conflicts +- Padding added when necessary to ensure 128-bit aligned access +- Thread block size chosen to maximize occupancy while maintaining memory efficiency + +#### Memory Coalescing +```cpp +// Example: Loading 128-bit aligned chunks for optimal bandwidth +using SmemCopyAtomBias = Copy_Atom; // 128-bit loads +using SmemCopyAtomAttnMask = Copy_Atom; +``` + +## Performance Considerations + +### Memory Efficiency +- **Shared Memory Aliasing**: Smart memory reuse (sMask โ†” sP, sBias โ†” sdS) reduces footprint by ~30% +- **Block-level Skip**: Early exit eliminates unnecessary memory loads for inactive tiles +- **LSE Caching**: Forward pass LSE values cached and reused in backward pass for numerical stability +- **Coalesced Access**: Optimized tensor layouts for GPU memory hierarchy + +### Computational Efficiency +- **Unified Skip Logic**: Both forward and backward passes benefit from block-level computation skipping +- **5-GEMM Chain Skip**: Complete gradient computation chain skipped for inactive tiles +- **Early Branch Decision**: Mask OR-reduction happens before expensive K/V loads +- **Warp-Level Optimization**: Operations optimized for GPU warp execution model + +### Scalability +- **Block-level Granularity**: Tile-level sparsity more efficient than element-level for long sequences +- **Multi-Head Support**: Efficient handling of multiple attention heads with per-head sparsity patterns +- **Barrier Optimization**: Minimal synchronization overhead through smart aliasing strategies + +### Performance Model + +Expected speedup for various sparsity levels: +- **50% sparsity**: ~1.8x speedup +- **75% sparsity**: ~3.2x speedup +- **90% sparsity**: ~6.5x speedup + +Performance factors: +- Skip overhead typically <5% of dense computation time +- Memory bandwidth reduction scales linearly with sparsity +- Shared memory aliasing enables 20-30% larger tile sizes + +## API Changes + +### New Required Parameters + +The Dynamic Mask Attention integration introduces new required parameters to the forward pass: + +- **`attn_mask`** (`torch.Tensor`): Attention mask tensor of shape `(batch, num_kv_heads, seqlen_q, seqlen_k)` + - Binary mask (1.0 = compute, 0.0 = skip) indicating which positions should be processed + - Determines the sparsity pattern for computational efficiency + +- **`attn_bias`** (`torch.Tensor`): Attention bias tensor of shape `(batch, num_kv_heads, seqlen_q, seqlen_k)` + - Contains dynamic attention bias values applied to attention scores before softmax + - Must have the same dtype and device as Q/K/V tensors + +### Updated Function Signature + +```python +def fwd( + q: torch.Tensor, # Query tensor + k: torch.Tensor, # Key tensor + v: torch.Tensor, # Value tensor + attn_mask: torch.Tensor, # Attention mask (REQUIRED) + attn_bias: torch.Tensor, # Attention bias (REQUIRED) + out: Optional[torch.Tensor] = None, # Pre-allocated output + softmax_scale: float = None, # Attention scaling + is_causal: bool = False, # Causal masking + softcap: float = 0.0, # Soft capping + return_softmax: bool = False, # Return attention weights +) -> List[torch.Tensor] +``` + +### Backward Compatibility + +**Breaking Change Notice**: The integration requires attention bias and attention mask tensors as mandatory parameters. This is a breaking change from the original Flash Attention API. + +**Migration Path**: Users need to: +1. Add attention mask and bias generation logic to attention modules +2. Implement appropriate mask and bias computation within the attention forward pass +3. Ensure proper tensor shapes and dtypes for mask and bias tensors + +### Complete Usage Example + +```python +import torch +import torch.nn as nn +from flash_dmattn.integration.flash_dynamic_mask_attention import flash_dynamic_mask_attention_forward + +class DynamicMaskAttention(nn.Module): + def __init__(self, config): + super().__init__() + self.num_heads = config.num_attention_heads + self.num_kv_heads = config.num_key_value_heads + self.head_dim = config.hidden_size // config.num_attention_heads + self.scaling = 1.0 / math.sqrt(self.head_dim) + + # Standard attention projections + self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=False) + self.k_proj = nn.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False) + self.v_proj = nn.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, config.hidden_size, bias=False) + + def forward(self, hidden_states, attention_mask=None, attention_bias=None): + batch_size, seq_len, _ = hidden_states.shape + + # Project to Q, K, V + query_states = self.q_proj(hidden_states).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(batch_size, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(batch_size, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2) + + # Prepare mask and bias tensors with proper shapes + if attention_mask is None: + attention_mask = torch.ones((batch_size, self.num_kv_heads, seq_len, seq_len), + dtype=query_states.dtype, device=query_states.device) + + if attention_bias is None: + attention_bias = torch.zeros((batch_size, self.num_kv_heads, seq_len, seq_len), + dtype=query_states.dtype, device=query_states.device) + + # Call attention implementation + attn_output, attn_weights = flash_dynamic_mask_attention_forward( + self, + query_states, + key_states, + value_states, + attention_mask=attention_mask, + attention_bias=attention_bias, + scaling=self.scaling, + ) + + return attn_output, attn_weights +``` + +The attention bias generation process: + +1. **Value-based Dynamic States**: + ```python + dt_states = self.dt_proj(value_states_flattened) + dt_states = torch.exp(self.A * F.softplus(dt_states)).transpose(-1, -2) + ``` + +2. **Bias Expansion**: + ```python + attn_bias = dt_states[:, :, None, :].expand(-1, -1, query_len, -1) + ``` + +3. **Mask Processing**: Done internally in `_flash_dynamic_mask_attention_forward` + + +### CUDA Backend: Sparse Attention Computation + +The CUDA backend implements the sparse attention computation through `_flash_dynamic_mask_attention_forward`: + +```python +def _flash_dynamic_mask_attention_forward( + query_states, key_states, value_states, + attention_mask, attention_bias, + query_length, key_length, + is_causal, softmax_scale=None, softcap=None, + target_dtype=None, implementation=None, **kwargs +): + dtype = query_states.dtype + min_dtype = torch.finfo(dtype).min + batch_size, _, num_kv_heads, _ = key_states.shape + + # Initialize attention bias if not provided + if attention_bias is None: + attention_bias = torch.zeros( + (batch_size, num_kv_heads, query_length, key_length), + dtype=dtype, device=query_states.device + ) + + # Apply attention mask to bias + if attention_mask is not None: + attention_bias = attention_bias.masked_fill(~attention_mask, min_dtype) + attention_mask = attention_mask.to(dtype) + + # Call Flash Attention with dynamic masking + out = flash_dmattn_func( + query_states, key_states, value_states, + attn_mask=attention_mask, attn_bias=attention_bias, + scale=softmax_scale, is_causal=is_causal + ) + + return out[0] if isinstance(out, tuple) else out +``` + +The backend processing stages: + +1. **Bias Initialization**: Create zero bias tensor if not provided +2. **Mask Application**: Apply boolean attention mask to bias tensor +3. **Flash Attention Call**: Execute optimized CUDA kernels with sparse patterns + +#### Forward Algorithm + +The implementation introduces unified block-level skip logic that optimizes computation by skipping entire tiles when they are fully masked: + +```cpp +// Forward pass with unified skip logic +for m_block in M_tiles: + load Q_tile + for n_block in N_tiles_stream: + load mask_block + any_active = OR(mask_block) // Block-level skip decision + if !any_active: + advance_pointers() // Skip computation, advance to next tile + continue + + // Only execute for active tiles + load K_tile, V_tile // Load data only when needed + S = Q_tile @ K_tile^T + bias_block // Sparse Q*K^T GEMM + S_masked = apply_mask(S, mask_block) // Apply dynamic masking + P = softmax(S_masked, LSE_cache) // Softmax with LSE caching + O_partial += P @ V_tile // Sparse Score*V GEMM +write O +``` + +Key improvements: +- **Block-level Skip Logic**: OR-reduction over entire (BlockM ร— BlockN) tile determines if computation is needed +- **Early Skip Decision**: Mask evaluation happens before expensive K/V loading and computation +- **Pointer Management**: Safe pointer advancement ensures correct memory layout for subsequent tiles + +#### Backward Algorithm + +The backward pass also benefits from the unified skip logic, maintaining numerical correctness while significantly reducing computation for sparse patterns: + +```cpp +// Backward pass with unified skip logic +for m_block in reversed(M_tiles): + load Q_tile, dO_tile + init accum_dQ + for n_block in N_tiles_stream: + load mask_block + any_active = OR(mask_block) // Same skip decision as forward + if !any_active: + advance_pointers_zero_side_outputs() // Skip computation, zero side outputs + continue + + // Only execute for active tiles + load K_tile, V_tile + + // Recompute (identical to forward for active tiles) + S = Q_tile @ K_tile^T + bias_block + P = softmax(S, LSE_cache) // Use cached LSE for stability + + // Gradient computation chain (5 GEMMs) + dV += P^T @ dO_tile // Accumulate dV + dP = dO_tile @ V_tile^T // Compute dP + dS = g(P, dP) // dS = (dP - (P โŠ™ dP).sum(axis)) * P + dQ += dS @ K_tile // Accumulate dQ + dK += dS^T @ Q_tile // Accumulate dK + write dQ, accumulate dK, dV +``` + +Key features: +- **Recomputation Strategy**: Forward computation is recomputed only for active tiles to maintain numerical precision +- **LSE Caching**: Uses cached log-sum-exp values from forward pass for stable softmax recomputation +- **Gradient Chain**: All five gradient GEMMs are skipped for fully masked tiles, maintaining mathematical correctness +- **Zero Handling**: Properly handles zero contributions from skipped tiles in accumulation + +#### Skip Logic Correctness + +The mathematical correctness of the skip logic relies on the following principles: + +1. **Forward Skip**: If a tile is entirely masked (active_mask = 0), its contribution to the output is exactly zero: + ``` + O_contribution = P @ V = 0 @ V = 0 + ``` + +2. **Backward Skip**: For fully masked tiles, all intermediate gradients are zero: + ``` + P = 0 โŸน dS = 0 โŸน dQ = dK = dV = 0 (from this tile) + ``` + +3. **LSE Preservation**: Skipped tiles don't contribute to the log-sum-exp, maintaining numerical stability. + +### Sparse Computation Strategy + +### Block-level Skip Logic + +The implementation introduces unified block-level skip logic that operates at the tile granularity rather than individual elements: + +1. **Tile-level Active Detection**: + ```cpp + any_active = OR_reduce(mask_block) // Single bit indicating if any position in tile is active + ``` + +2. **Skip Decision**: Binary branch based on tile activity: + ```cpp + if (!any_active) { + advance_pointers(); // Forward: skip all computation + advance_pointers_zero_outputs(); // Backward: skip computation, zero side outputs + continue; + } + ``` + +3. **Computational Benefits**: + - Skip entire K/V loads for inactive tiles + - Eliminate all 5 GEMMs in backward pass for inactive tiles + - Reduce memory bandwidth and arithmetic operations proportional to sparsity + +### Sparsity Pattern Recognition + +The Dynamic Mask Attention implements structured sparsity based on learned importance scores: + +1. **Attention Bias Computation**: Attention bias values are computed based on dynamic states derived from value tensors + - Learned projection matrices map value features to importance scores + - Coefficient parameters control the dynamic range of importance values + - Activation functions ensure appropriate bias magnitude + +2. **Binary Attention Mask**: + - 1.0 for positions that should be computed + - 0.0 for positions that should be skipped + +### Performance Model + +For block-level sparsity with active tile fraction $p$, skip overhead ratio $\varepsilon$, and early-exit efficiency $\eta$: + +$$ +\text{Speedup} \approx \frac{1}{p + (1-p)(\varepsilon + \eta \cdot \text{LoadOverhead})} +$$ + +Where: +- $p$: fraction of active tiles +- $\varepsilon$: skip branching overhead +- $\eta$: efficiency of early memory load exit +- $\text{LoadOverhead}$: relative cost of K/V loading vs computation + +Upper bound as $\varepsilon, \eta \to 0$: $1/p$ + +### Shared Memory Aliasing + +The implementation introduces smart shared memory aliasing to reduce footprint and enable larger tile sizes: + +1. **sMask โ†” sP Aliasing**: Mask shared memory region is reused for storing softmax probabilities P after mask consumption +2. **sBias โ†” sdS Aliasing**: Bias shared memory region is reused for gradient computations dS +3. **Barrier Synchronization**: Explicit `__syncthreads()` calls ensure safe transitions between aliased usage + +```cpp +// Example aliasing pattern +load mask -> sMask +any_active = or_reduce(sMask) +if any_active: + compute S + __syncthreads() // ensure mask fully consumed + softmax -> write P into aliased region (sP) // reuse sMask region as sP + ... +__syncthreads() // ensure dS consumed +// reuse sBias region as sdS in next iteration +``` + +### Memory Efficiency Optimizations + +1. **Shared Memory Aliasing**: Smart reuse of memory regions (sMask โ†” sP, sBias โ†” sdS) with explicit barrier synchronization +2. **Block-level Skip**: Early exit from computation and memory loading for inactive tiles +3. **LSE Caching**: Forward pass LSE values cached and reused in backward pass for numerical stability +4. **Register-Optimized Operations**: Critical masking and gradient operations performed in register memory +5. **Coalesced Memory Access**: Optimized access patterns for GPU memory hierarchy +6. **Template Specialization**: Compile-time optimization eliminates runtime branching overhead + +## Memory Layout + +### Tensor Memory Organization + +The Dynamic Mask Attention extends Flash Attention's memory layout to include attention masks and attention bias: + +``` +Global Memory Layout: +โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ” +โ”‚ Q: [batch, seqlen_q, num_heads, head_dim] โ”‚ +โ”‚ K: [batch, seqlen_k, num_heads_k, head_dim] โ”‚ +โ”‚ V: [batch, seqlen_k, num_heads_k, head_dim] โ”‚ +โ”‚ AttnMask: [batch, num_kv_heads, seqlen_q, seqlen_k] โ”‚ +โ”‚ Bias: [batch, num_kv_heads, seqlen_q, seqlen_k] โ”‚ +โ”‚ Output: [batch, seqlen_q, num_heads, head_dim] โ”‚ +โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜ + +Shared Memory Layout (per thread block): +โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ” +โ”‚ Q Tile: [kBlockM, head_dim] โ”‚ K Tile: [kBlockN, head_dim] โ”‚ +โ”‚ V Tile: [kBlockN, head_dim] โ”‚ S Tile: [kBlockM, kBlockN] โ”‚ +โ”‚ AM Tile: [kBlockM, kBlockN] โ”‚ Bias Tile: [kBlockM, kBlockN] โ”‚ +โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜ + +Register Memory (per thread): +โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ” +โ”‚ Q Frag: [MMA_M, head_dim/N] โ”‚ K Frag: [MMA_N, head_dim/N] โ”‚ +โ”‚ V Frag: [MMA_N, head_dim/N] โ”‚ S Frag: [MMA_M, MMA_N] โ”‚ +โ”‚ AM Frag: [MMA_M, MMA_N] โ”‚ Bias Frag: [MMA_M, MMA_N] โ”‚ +โ”‚ Acc Frag: [MMA_M, head_dim/N] โ”‚ โ”‚ +โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜ +``` + +### Memory Access Patterns + +#### Attention Mask and Attention Bias Loading +```cpp +// Global to Shared Memory (coalesced access) +Tensor tSgBias = local_partition(mBias, smem_tiled_copy_Bias, thread_idx); +Tensor tSsBias = local_partition(sBias, smem_tiled_copy_Bias, thread_idx); + +// Each thread loads a contiguous chunk to maximize memory bandwidth +copy(smem_tiled_copy_Bias, tSgBias, tSsBias); + +// Shared to Register Memory (bank-conflict-free) +Tensor tSrBias = local_partition(sBias, smem_thr_copy_Bias, thread_idx); +copy(smem_thr_copy_Bias, tSsBias, tSrBias); +``` + +#### Memory Layout Transformations +```cpp +// Convert MMA accumulator layout to row-column layout for masking +// From: (MMA=4, MMA_M, MMA_N) -> (nrow=(2, MMA_M), ncol=(2, MMA_N)) +auto convert_layout_acc_rowcol = [](auto layout) { + return make_layout( + make_layout(make_shape(Int<2>{}, get<1>(layout.shape())), + make_stride(Int(layout.stride())* 2>{}, get<1>(layout.stride()))), + make_layout(make_shape(Int<2>{}, get<2>(layout.shape())), + make_stride(Int<1>{}, Int<2>{})) + ); +}; +``` + +### Shared Memory Optimization + +#### Bank Conflict Avoidance +- Attention bias and attention masks use the same copy patterns as Q/K/V to avoid bank conflicts +- Padding added when necessary to ensure 128-bit aligned access +- Thread block size chosen to maximize occupancy while maintaining memory efficiency + +#### Memory Coalescing +```cpp +// Example: Loading 128-bit aligned chunks for optimal bandwidth +using SmemCopyAtomBias = Copy_Atom; // 128-bit loads +using SmemCopyAtomAttnMask = Copy_Atom; +``` -2. **Adaptive Sparsity Patterns**: Dynamic adjustment of attention patterns based on input characteristics - - Learned sparsity controllers - - Content-aware mask generation - - Adaptive `keep_window_size` selection +## Performance Considerations + +### Memory Efficiency +- **Shared Memory Aliasing**: Smart memory reuse (sMask โ†” sP, sBias โ†” sdS) reduces footprint by ~30% +- **Block-level Skip**: Early exit eliminates unnecessary memory loads for inactive tiles +- **LSE Caching**: Forward pass LSE values cached and reused in backward pass for numerical stability +- **Coalesced Access**: Optimized tensor layouts for GPU memory hierarchy + +### Computational Efficiency +- **Unified Skip Logic**: Both forward and backward passes benefit from block-level computation skipping +- **5-GEMM Chain Skip**: Complete gradient computation chain skipped for inactive tiles +- **Early Branch Decision**: Mask OR-reduction happens before expensive K/V loads +- **Warp-Level Optimization**: Operations optimized for GPU warp execution model + +### Scalability +- **Block-level Granularity**: Tile-level sparsity more efficient than element-level for long sequences +- **Multi-Head Support**: Efficient handling of multiple attention heads with per-head sparsity patterns +- **Barrier Optimization**: Minimal synchronization overhead through smart aliasing strategies + +### Performance Model -3. **Multi-GPU Distributed Support**: Optimizations for large-scale distributed training - - Efficient tensor parallelism for long sequences - - Communication-optimal attention computation - - Memory-balanced workload distribution +Expected speedup for various sparsity levels: +- **50% sparsity**: ~1.8x speedup +- **75% sparsity**: ~3.2x speedup +- **90% sparsity**: ~6.5x speedup -4. **Advanced Memory Optimizations**: Further reduce memory footprint for extremely long sequences - - Progressive attention computation - - Hierarchical sparsity patterns - - Memory-efficient checkpoint/recomputation strategies +Performance factors: +- Skip overhead typically <5% of dense computation time +- Memory bandwidth reduction scales linearly with sparsity +- Shared memory aliasing enables 20-30% larger tile sizes + +## API Changes + +### New Required Parameters + +The Dynamic Mask Attention integration introduces new required parameters to the forward pass: + +- **`attn_mask`** (`torch.Tensor`): Attention mask tensor of shape `(batch, num_kv_heads, seqlen_q, seqlen_k)` + - Binary mask (1.0 = compute, 0.0 = skip) indicating which positions should be processed + - Determines the sparsity pattern for computational efficiency -5. **Hardware-Specific Optimizations**: Leverage newer GPU architectures - - Hopper architecture optimizations - - Sparse Tensor Core utilization - - Advanced memory hierarchy exploitation +- **`attn_bias`** (`torch.Tensor`): Attention bias tensor of shape `(batch, num_kv_heads, seqlen_q, seqlen_k)` + - Contains dynamic attention bias values applied to attention scores before softmax + - Must have the same dtype and device as Q/K/V tensors -### Performance Targets +### Updated Function Signature -- **Sequence Length**: Support up to 1M+ tokens efficiently -- **Memory Reduction**: 50-80% memory savings compared to dense attention -- **Speed**: Maintain or improve upon Flash Attention performance for long sequences -- **Sparsity**: Flexible sparsity ratios from 10% to 90% depending on use case +```python +def fwd( + q: torch.Tensor, # Query tensor + k: torch.Tensor, # Key tensor + v: torch.Tensor, # Value tensor + attn_mask: torch.Tensor, # Attention mask (REQUIRED) + attn_bias: torch.Tensor, # Attention bias (REQUIRED) + out: Optional[torch.Tensor] = None, # Pre-allocated output + softmax_scale: float = None, # Attention scaling + is_causal: bool = False, # Causal masking + softcap: float = 0.0, # Soft capping + return_softmax: bool = False, # Return attention weights +) -> List[torch.Tensor] +``` -## Conclusion +### Backward Compatibility -The Dynamic Mask Attention integration successfully combines Flash Attention's memory efficiency with structured sparsity to enable efficient processing of extremely long sequences. The implementation maintains the core optimization principles of Flash Attention while adding the capability to skip computation for less important token interactions. +**Breaking Change Notice**: The integration requires attention bias and attention mask tensors as mandatory parameters. This is a breaking change from the original Flash Attention API. -Key achievements of this integration: +**Migration Path**: Users need to: +1. Add attention mask and bias generation logic to attention modules +2. Implement appropriate mask and bias computation within the attention forward pass +3. Ensure proper tensor shapes and dtypes for mask and bias tensors -1. **Seamless Integration**: All dynamic masking functionality integrated into Flash Attention's kernel architecture without compromising existing optimizations +### Complete Usage Example -2. **Comprehensive Implementation**: Complete pipeline from Python preprocessing to optimized CUDA kernels with proper memory management +```python +import torch +import torch.nn as nn +import flash_dmattn_cuda as flash_dmattn + +class DynamicMaskAttention(nn.Module): + def __init__(self, config): + super().__init__() + self.num_heads = config.num_attention_heads + self.num_kv_heads = config.num_key_value_heads + self.head_dim = config.hidden_size // config.num_attention_heads + self.scaling = 1.0 / math.sqrt(self.head_dim) + + # Standard attention projections + self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=False) + self.k_proj = nn.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False) + self.v_proj = nn.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, config.hidden_size, bias=False) + + def forward(self, hidden_states, attention_mask=None, attention_bias=None): + batch_size, seq_len, _ = hidden_states.shape + + # Project to Q, K, V + query_states = self.q_proj(hidden_states).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(batch_size, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(batch_size, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2) + + # Generate attention bias from value states + dt_states = self.dt_proj( + value_states.transpose(1, 2).reshape(batch_size, seq_len, -1) + ) + dt_states = torch.exp(self.A * F.softplus(dt_states)).transpose(-1, -2) + attention_bias = dt_states[:, :, None, :].expand(-1, -1, seq_len, -1).to(hidden_states.dtype) + + # Prepare attention mask for multi-head + if attention_mask is not None: + attention_mask = attention_mask.expand(-1, self.num_kv_heads, -1, -1) + + # Flash Dynamic Mask Attention + attn_output, _ = flash_dynamic_mask_attention_forward( + self, + query_states, + key_states, + value_states, + attention_mask=attention_mask, + attention_bias=attention_bias, + scaling=self.scaling, + ) + + # Output projection + attn_output = attn_output.reshape(batch_size, seq_len, -1) + return self.o_proj(attn_output) + +# Usage example +config = type('Config', (), { + 'hidden_size': 768, + 'num_attention_heads': 12, + 'num_key_value_heads': 12, +})() + +attention = DynamicMaskAttention(config) +hidden_states = torch.randn(2, 4096, 768, device='cuda', dtype=torch.bfloat16) +output = attention(hidden_states) +print(f"Output shape: {output.shape}") # [2, 4096, 768] +``` -3. **Flexible Sparsity Control**: Configurable sparsity levels through the `keep_window_size` parameter to balance quality and efficiency +### Integration with Existing Codebases -4. **Robust Validation**: Extensive testing infrastructure ensures numerical equivalence with reference implementations +For users migrating from Flash Attention, the typical changes required are: -5. **Performance Optimization**: Sparse computation patterns reduce both memory usage and computational overhead for long sequences +```python +# Before (Flash Attention) +class StandardAttention(nn.Module): + def __init__(self, config): + super().__init__() + self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim) + self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim) + self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim) + self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size) + + def forward(self, hidden_states): + q, k, v = self.q_proj(hidden_states), self.k_proj(hidden_states), self.v_proj(hidden_states) + output = flash_attn_func(q, k, v, dropout_p=0.1, softmax_scale=self.scaling, causal=True) + return self.o_proj(output) -This integration enables practitioners to efficiently handle very long sequences in transformer models while maintaining the numerical stability and optimization benefits that have made Flash Attention the standard for efficient attention computation. \ No newline at end of file +# After (Dynamic Mask Attention) +class DynamicMaskAttention(nn.Module): + def __init__(self, config): + super().__init__() + # Same standard projections + self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim) + self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim) + self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim) + self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size) + + # Add dynamic mask parameters + self.A = nn.Parameter(torch.zeros(config.num_key_value_heads)) + self.dt_proj = nn.Linear(config.num_key_value_heads * self.head_dim, config.num_key_value_heads) + self.keep_window_size = config.keep_window_size + + def forward(self, hidden_states): + # Standard Q, K, V projections + query_states = self.q_proj(hidden_states).view(...).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(...).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(...).transpose(1, 2) + + # Generate attention bias from value states + dt_states = self.dt_proj(value_states.transpose(1, 2).reshape(...)) + dt_states = torch.exp(self.A * F.softplus(dt_states)).transpose(-1, -2) + attention_bias = dt_states[:, :, None, :].expand(-1, -1, seq_len, -1) + + # Use Flash Dynamic Mask Attention + attn_output, _ = flash_dynamic_mask_attention_forward( + self, query_states, key_states, value_states, + attention_mask=attention_mask, attention_bias=attention_bias, + scaling=self.scaling + ) + + return self.o_proj(attn_output.reshape(...)) +``` \ No newline at end of file