diff --git a/README.md b/README.md index e3f0d35..54a0fa3 100644 --- a/README.md +++ b/README.md @@ -127,13 +127,19 @@ Flash-DMA combines two complementary techniques: The integration happens at the CUDA kernel level with several key components: -- **ZOH States**: Pre-computed importance scores for key selection +- **ZOH States**: Pre-computed importance scores for key selection (computed from Value vectors only) - **Active Masks**: Binary masks indicating which keys should be considered for each query - **Sparse Matrix Multiplication**: Custom CUDA kernels for efficient sparse attention computation - **Block-Based Processing**: Maintains Flash Attention's block-based approach for memory efficiency This creates a hybrid attention mechanism that achieves both memory and computational efficiency for long sequences. +### Design Characteristics + +āš ļø **Important Design Note**: Flash-DMA uses a **query-agnostic** masking approach where the same set of keys is selected for all queries. This design choice prioritizes computational efficiency and works well for tasks with global importance patterns, but may be suboptimal for fine-grained associative recall tasks that require query-specific key selection. + +For detailed analysis of this design choice and its implications, see the [Design Choices Documentation](docs/design_choices.md). + ## Documentation @@ -141,6 +147,7 @@ This creates a hybrid attention mechanism that achieves both memory and computat - **[API Reference](docs/api_reference.md)** - Complete function documentation and usage examples - **[Integration Guide](docs/integration.md)** - Detailed technical documentation of the Flash Attention integration +- **[Design Choices](docs/design_choices.md)** - Analysis of query-agnostic masking and its implications for different tasks ## Building from Source diff --git a/benchmarks/forward_equivalence.py b/benchmarks/forward_equivalence.py index b3b0883..b948bed 100644 --- a/benchmarks/forward_equivalence.py +++ b/benchmarks/forward_equivalence.py @@ -59,19 +59,34 @@ def prepare_dynamic_mask( """ Calculate dynamic attention mask to mask tokens for sparse attention. + āš ļø DESIGN NOTE: This function implements QUERY-AGNOSTIC masking. + The same ZOH-based importance scores are broadcast to ALL queries, meaning: + 1. All queries attend to the same set of top-K keys + 2. No query-specific key selection is performed + 3. Optimization prioritizes efficiency over query-specific precision + + This approach works well for tasks with global importance patterns but may be + suboptimal for fine-grained associative recall tasks. See docs/design_choices.md + for detailed analysis and potential alternatives. + Combine `zoh_states` with `attention_mask` to generate the final `attn_mask`. 
Args: hidden_states: Input hidden states to determine dtype minimum value - zoh_states: zoh_states of shape (batch_size, num_kv_heads, key_sequence_length) + zoh_states: [batch_size, num_kv_heads, key_sequence_length] - Value-based importance scores keep_window_size: Window size of tokens not dynamically masked attention_mask: Optional attention mask of shape (batch_size, 1, query_len, key_len) Returns: tuple: (attn_bias, attn_mask) + - attn_bias: [batch_size, num_kv_heads, query_len, key_len] - ZOH scores + masking + - attn_mask: [batch_size, num_kv_heads, query_len, key_len] - Binary active mask """ min_dtype = torch.finfo(hidden_states.dtype).min dtype = hidden_states.dtype + + # šŸ” KEY INSIGHT: Broadcasting same importance scores to ALL queries + # Shape transformation: [batch, heads, key_len] -> [batch, heads, query_len, key_len] attn_bias = zoh_states[:, :, None, :].expand( -1, -1, hidden_states.shape[2], -1 ) # [batch_size, num_kv_heads, query_len, key_len] @@ -103,14 +118,22 @@ def calculate_zoh_states(value_states, dt_proj, A): """ Calculate zoh states for dynamic mask attention. + āš ļø DESIGN NOTE: This function implements QUERY-AGNOSTIC importance scoring. + ZOH states are computed solely from Value vectors and contain no query-specific information. + The same importance scores will be broadcast to all queries, meaning all queries + will attend to the same set of top-K keys. + + This design choice prioritizes computational efficiency over query-specific precision. + See docs/design_choices.md for detailed analysis of implications. + Args: value_states: [batch_size, num_kv_heads, key_len, head_dim] - dt_proj: [num_kv_heads, num_kv_heads * head_dim] - A: [num_kv_heads] - causal_mask: Optional causal mask + dt_proj: [num_kv_heads, num_kv_heads * head_dim] - Learned projection matrix + A: [num_kv_heads] - Scaling coefficients Returns: - zoh_states: [batch_size, num_kv_heads, key_len] + zoh_states: [batch_size, num_kv_heads, key_len] - Value-based importance scores + Note: No query dimension - same scores applied to all queries """ batch_size, _, key_len, _ = value_states.shape diff --git a/benchmarks/forward_performance.py b/benchmarks/forward_performance.py index 8ba8ec0..407646d 100644 --- a/benchmarks/forward_performance.py +++ b/benchmarks/forward_performance.py @@ -81,19 +81,34 @@ def prepare_dynamic_mask( """ Calculate dynamic attention mask to mask tokens for sparse attention. + āš ļø DESIGN NOTE: This function implements QUERY-AGNOSTIC masking. + The same ZOH-based importance scores are broadcast to ALL queries, meaning: + 1. All queries attend to the same set of top-K keys + 2. No query-specific key selection is performed + 3. Optimization prioritizes efficiency over query-specific precision + + This approach works well for tasks with global importance patterns but may be + suboptimal for fine-grained associative recall tasks. See docs/design_choices.md + for detailed analysis and potential alternatives. + Combine `zoh_states` with `attention_mask` to generate the final `attn_mask`. 
Args: hidden_states: Input hidden states to determine dtype minimum value - zoh_states: zoh_states of shape (batch_size, num_kv_heads, key_sequence_length) + zoh_states: [batch_size, num_kv_heads, key_sequence_length] - Value-based importance scores keep_window_size: Window size of tokens not dynamically masked attention_mask: Optional attention mask of shape (batch_size, 1, query_len, key_len) Returns: tuple: (attn_bias, attn_mask) + - attn_bias: [batch_size, num_kv_heads, query_len, key_len] - ZOH scores + masking + - attn_mask: [batch_size, num_kv_heads, query_len, key_len] - Binary active mask """ min_dtype = torch.finfo(hidden_states.dtype).min dtype = hidden_states.dtype + + # šŸ” KEY INSIGHT: Broadcasting same importance scores to ALL queries + # Shape transformation: [batch, heads, key_len] -> [batch, heads, query_len, key_len] attn_bias = zoh_states[:, :, None, :].expand( -1, -1, hidden_states.shape[2], -1 ) # [batch_size, num_kv_heads, query_len, key_len] @@ -110,6 +125,8 @@ def prepare_dynamic_mask( ) if attn_bias.shape[-1] > keep_window_size: + # šŸ” CRITICAL: TopK selection produces SAME keys for ALL queries + # This creates uniform attention patterns across all query positions topk_indices = torch.topk( attn_bias, keep_window_size, dim=-1, largest=True, sorted=False ).indices @@ -125,14 +142,22 @@ def calculate_zoh_states(value_states, dt_proj, A): """ Calculate zoh states for dynamic mask attention. + āš ļø DESIGN NOTE: This function implements QUERY-AGNOSTIC importance scoring. + ZOH states are computed solely from Value vectors and contain no query-specific information. + The same importance scores will be broadcast to all queries, meaning all queries + will attend to the same set of top-K keys. + + This design choice prioritizes computational efficiency over query-specific precision. + See docs/design_choices.md for detailed analysis of implications. + Args: value_states: [batch_size, num_kv_heads, key_len, head_dim] - dt_proj: [num_kv_heads, num_kv_heads * head_dim] - A: [num_kv_heads] - causal_mask: Optional causal mask + dt_proj: [num_kv_heads, num_kv_heads * head_dim] - Learned projection matrix + A: [num_kv_heads] - Scaling coefficients Returns: - zoh_states: [batch_size, num_kv_heads, key_len] + zoh_states: [batch_size, num_kv_heads, key_len] - Value-based importance scores + Note: No query dimension - same scores applied to all queries """ batch_size, _, key_len, _ = value_states.shape diff --git a/docs/design_choices.md b/docs/design_choices.md new file mode 100644 index 0000000..d3292de --- /dev/null +++ b/docs/design_choices.md @@ -0,0 +1,195 @@ +# Flash Dynamic Mask Attention: Design Choices and Trade-offs + +## Overview + +This document explains key design decisions in Flash Dynamic Mask Attention, particularly regarding the query-agnostic nature of the dynamic masking mechanism and its implications for different types of attention tasks. + +## Query-Agnostic Masking Design + +### Current Implementation + +The Flash Dynamic Mask Attention implementation uses a **query-agnostic** masking strategy: + +```python +# 1. ZOH states computed ONLY from Value vectors +def calculate_zoh_states(value_states, dt_proj, A): + """ + ZOH states depend only on Value vectors, not Queries. + Result shape: [batch_size, num_kv_heads, key_len] # No query dimension! + """ + dt_result = torch.matmul( + value_states.transpose(-2, -3).reshape(batch_size, key_len, -1), + dt_proj.T + ) + dt_states = torch.exp(F.softplus(dt_result) * A) + return dt_states.transpose(-1, -2) + +# 2. 
Same importance scores broadcast to all queries +def prepare_dynamic_mask(hidden_states, zoh_states, keep_window_size, attention_mask): + """ + The same ZOH-based importance scores are applied to ALL queries. + """ + # Broadcast: [batch, heads, key_len] -> [batch, heads, query_len, key_len] + attn_bias = zoh_states[:, :, None, :].expand(-1, -1, hidden_states.shape[2], -1) + + # TopK selection: same keys selected for ALL queries + topk_indices = torch.topk(attn_bias, keep_window_size, dim=-1, + largest=True, sorted=False).indices + + # Result: all queries attend to the same top-K keys + active_mask = torch.zeros_like(attn_bias) + active_mask = active_mask.scatter(-1, topk_indices, 1.0) + return attn_bias, active_mask +``` + +### Key Characteristics + +1. **Value-only Computation**: Importance scores are derived solely from Value vectors +2. **Global Broadcasting**: Same importance scores applied to all query positions +3. **Uniform Selection**: All queries attend to the same set of top-K keys +4. **Query Independence**: Mask generation does not consider query content + +## Design Rationale + +### Computational Efficiency + +The query-agnostic design provides significant computational advantages: + +```python +# Query-agnostic (current): O(N) complexity for mask generation +zoh_states = compute_importance(V) # Shape: [batch, heads, N] +mask = topk(zoh_states.expand_to_queries()) # Broadcast operation + +# Query-aware alternative: O(N²) complexity for mask generation +for each query_i: + importance_i = compute_query_aware_importance(Q[i], V) # Shape: [batch, heads, N] + mask[i] = topk(importance_i) # Separate computation per query +``` + +**Benefits:** +- **Memory Efficiency**: Single importance computation instead of per-query computation +- **Speed**: O(N) mask generation vs O(N²) for query-aware approaches +- **Simplicity**: Cleaner implementation with fewer edge cases + +### When Query-Agnostic Masking Works Well + +This design is effective for tasks where: + +1. **Global Importance Patterns**: Some keys are inherently more important regardless of the query +2. **Structured Content**: Information is hierarchically organized (e.g., summaries, keywords) +3. **Content-based Retrieval**: Important information is identifiable from content alone + +#### Example: Document Summarization +```python +# Document: [title, abstract, section1, section2, ..., references] +# Value-based importance can identify: +# - Title and abstract (always important) +# - Key sentences (high information density) +# - Section headers (structural importance) +# All queries benefit from attending to these globally important positions +``` + +## Limitations for Associative Recall Tasks + +### The Challenge + +Associative recall tasks typically require **query-specific** key selection: + +```python +# Example: "What did Alice say about the meeting?" +# - Query focuses on: "Alice" + "meeting" +# - Relevant keys: positions mentioning both Alice and meetings +# - Irrelevant keys: positions about Bob, other topics, or Alice discussing other topics + +# Current limitation: All queries see the same "important" keys +# even if those keys aren't relevant to the specific query +``` + +### Specific Limitations + +1. **Context Mismatch**: Globally important keys may not be relevant to specific queries +2. **Information Dilution**: Attention spread across non-relevant but "important" positions +3. 
**Recall Precision**: Harder to precisely locate query-specific information
+
+### Quantitative Example
+
+Consider a document with 4096 tokens and `keep_window_size=512` (the numbers below are illustrative):
+
+```
+Query-Agnostic (Current):
+- All queries attend to the same 512 globally "important" positions
+- For "What did Alice say?": only ~50 of those positions might actually mention Alice
+- Efficiency: 50/512 = ~10% of the attention budget is relevant to the query
+
+Query-Aware (Ideal):
+- Each query selects its own 512 most relevant positions
+- For "What did Alice say?": the budget can concentrate on the passages that mention Alice and their surrounding context, e.g. ~400 of the 512 positions
+- Efficiency: 400/512 = ~78% of the attention budget is relevant to the query
+```
+
+## Hybrid Approaches and Future Directions
+
+### Potential Improvements
+
+1. **Query-Conditioned Importance**:
+   ```python
+   # Compute importance based on query-key interaction
+   importance = compute_qk_importance(Q, V, dt_proj)  # Shape: [batch, heads, query_len, key_len]
+   ```
+
+2. **Multi-Stage Selection**:
+   ```python
+   # Stage 1: Global filtering (current approach)
+   global_mask = compute_global_importance(V)
+
+   # Stage 2: Query-specific refinement within global selection
+   refined_mask = compute_query_specific(Q, V, global_mask)
+   ```
+
+3. **Learned Query-Aware Projections**:
+   ```python
+   # Different projections for different query types
+   query_type = classify_query(Q)
+   dt_proj_specific = dt_proj_bank[query_type]
+   importance = compute_importance(V, dt_proj_specific)
+   ```
+
+Illustrative sketches of the first two ideas are given in the appendix at the end of this document.
+
+## Current Capabilities and Workarounds
+
+### What Still Works
+
+Even with query-agnostic masking, the system can handle some associative recall through:
+
+1. **Learned Global Patterns**: Training can identify generally important positions
+2. **Redundant Information**: Multiple positions may contain similar information
+3. **Post-Selection Attention**: Standard attention weights can still focus within selected keys
+
+### Practical Strategies
+
+For better associative recall with the current implementation:
+
+1. **Larger Window Sizes**: Increase `keep_window_size` to capture more potential targets
+2. **Multi-Head Diversity**: Different heads may learn different global importance patterns
+3. **Hierarchical Processing**: Use multiple attention layers with different masking strategies
+
+## Conclusion
+
+The query-agnostic design of Flash Dynamic Mask Attention represents a **computational efficiency vs. precision trade-off**:
+
+**Advantages:**
+- āœ… Excellent computational efficiency
+- āœ… Simple implementation and debugging
+- āœ… Effective for tasks with global importance patterns
+- āœ… Good baseline performance across diverse tasks
+
+**Limitations:**
+- āŒ Suboptimal for fine-grained associative recall
+- āŒ May miss query-specific relevant information
+- āŒ Less precise attention targeting
+
+This design choice prioritizes **efficiency and generality** over **task-specific optimization**. For applications requiring high-precision associative recall, consider:
+1. Using larger window sizes
+2. Implementing hybrid approaches
+3. Contributing query-aware extensions to the project
+
+The current implementation serves as a strong foundation that balances performance and computational requirements while remaining extensible for future enhancements.
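+
+## Appendix: Illustrative Sketches of Query-Aware Variants
+
+The snippet below is a minimal, runnable sketch of the first two "Potential Improvements" listed above. It uses plain PyTorch, and the function names (`query_conditioned_importance`, `multi_stage_topk`) and projection shapes are hypothetical: this is not part of the Flash-DMA API, it only makes the shape of the proposed computation concrete.
+
+```python
+import torch
+import torch.nn.functional as F
+
+
+def query_conditioned_importance(query_states, value_states, proj):
+    """Hypothetical query-conditioned scoring: one importance value per (query, key) pair."""
+    # [b, h, q_len, d] x [b, h, k_len, d] -> [b, h, q_len, k_len]
+    scores = torch.einsum("bhqd,bhkd->bhqk", query_states, value_states @ proj)
+    return torch.exp(F.softplus(scores))
+
+
+def multi_stage_topk(zoh_states, qk_scores, global_k, per_query_k):
+    """Hypothetical two-stage selection (assumes per_query_k <= global_k <= key_len)."""
+    # Stage 1: query-agnostic shortlist shared by all queries (the current approach).
+    global_idx = torch.topk(zoh_states, global_k, dim=-1).indices            # [b, h, global_k]
+    global_mask = torch.zeros_like(zoh_states).scatter(-1, global_idx, 1.0)  # [b, h, k_len]
+    # Stage 2: query-aware refinement restricted to the shortlisted keys.
+    masked_scores = qk_scores.masked_fill(global_mask[:, :, None, :] == 0.0, float("-inf"))
+    refined_idx = torch.topk(masked_scores, per_query_k, dim=-1).indices     # [b, h, q_len, per_query_k]
+    return torch.zeros_like(qk_scores).scatter(-1, refined_idx, 1.0)         # per-query active mask
+
+
+if __name__ == "__main__":
+    b, h, q_len, k_len, d = 1, 2, 8, 16, 4
+    q, v = torch.randn(b, h, q_len, d), torch.randn(b, h, k_len, d)
+    zoh = torch.rand(b, h, k_len)  # value-only scores, as in the current design
+    qk = query_conditioned_importance(q, v, torch.randn(d, d))
+    mask = multi_stage_topk(zoh, qk, global_k=8, per_query_k=4)
+    print(mask.shape)  # torch.Size([1, 2, 8, 16])
+```
+
+Whether either variant improves recall in practice is an open question; the sketch only illustrates the interfaces such extensions would need.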
\ No newline at end of file diff --git a/docs/integration.md b/docs/integration.md index 80351ab..cf346d4 100644 --- a/docs/integration.md +++ b/docs/integration.md @@ -685,21 +685,36 @@ sparse_gemm(acc_o, acc_s, tSrV, tSsS, tSsV, tSrActiveMask, ### Sparsity Pattern Recognition -The Dynamic Mask Attention implements structured sparsity based on learned importance scores: +The Dynamic Mask Attention implements structured sparsity based on learned importance scores using a **query-agnostic** approach: 1. **ZOH State Computation**: `dt_states = exp(A * softplus(V @ dt_proj^T))` + - **Value-only computation**: Importance scores derived solely from Value vectors - Learned projection matrix `dt_proj` maps value features to importance scores - Coefficient `A` controls the dynamic range of importance values - Exponential activation ensures positive importance scores + - **Result shape**: `[batch, num_heads, key_len]` (no query dimension) -2. **TopK Selection**: For sequences longer than `keep_window_size`: - - Select top-K most important positions per query token +2. **Global Broadcasting and TopK Selection**: For sequences longer than `keep_window_size`: + - **Broadcast**: ZOH states expanded to all queries: `[batch, heads, key_len] → [batch, heads, query_len, key_len]` + - **Uniform selection**: Same top-K keys selected for ALL query positions - K = `keep_window_size` (typically 512-2048) - Maintains fixed computational complexity regardless of sequence length + - **Trade-off**: Computational efficiency vs. query-specific precision 3. **Binary Active Mask**: - 1.0 for positions selected by TopK (compute) - 0.0 for positions not selected (skip computation) + - **Same mask applied to all queries** within each attention head + +### Design Implications + +āš ļø **Important**: This query-agnostic design has significant implications: + +- **Suitable for**: Tasks with global importance patterns, document processing, content summarization +- **Limitations**: May be suboptimal for fine-grained associative recall tasks requiring query-specific key selection +- **Trade-off**: Prioritizes computational efficiency and simplicity over task-specific precision + +For detailed analysis of this design choice and its implications, see [Design Choices Documentation](design_choices.md). ### Sparse GEMM Implementation diff --git a/examples/README.md b/examples/README.md new file mode 100644 index 0000000..49fb586 --- /dev/null +++ b/examples/README.md @@ -0,0 +1,31 @@ +# Examples + +This directory contains examples and demonstrations for Flash Dynamic Mask Attention. + +## Files + +### `query_agnostic_demo.py` + +A standalone demonstration that illustrates the query-agnostic nature of the dynamic masking mechanism. This script shows: + +- How ZOH states are computed from Value vectors only +- How the same importance scores are broadcast to all queries +- How TopK selection produces the same keys for all queries +- The implications of this design for different types of tasks + +**Run the demo:** +```bash +python examples/query_agnostic_demo.py +``` + +This demo helps understand the trade-offs between computational efficiency and query-specific precision discussed in Issue #117. + +### `modeling/` + +Contains example model implementations showing how to integrate Flash Dynamic Mask Attention into transformer architectures. 
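+
+## At a glance
+
+A minimal sketch of the behaviour `query_agnostic_demo.py` walks through; the shapes and values here are arbitrary and chosen only for illustration:
+
+```python
+import torch
+
+batch, heads, seq_len, keep = 1, 2, 8, 4
+# Value-based importance scores: one score per key, no query dimension.
+zoh_states = torch.rand(batch, heads, seq_len)
+# Broadcast the same scores to every query position ...
+attn_bias = zoh_states[:, :, None, :].expand(batch, heads, seq_len, seq_len)
+# ... so top-K selection returns the same key set for every query.
+topk_indices = torch.topk(attn_bias, keep, dim=-1).indices
+assert torch.equal(topk_indices[0, 0, 0], topk_indices[0, 0, -1])
+```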
+ +## Related Documentation + +- [Design Choices](../docs/design_choices.md) - Detailed analysis of the query-agnostic design +- [Integration Guide](../docs/integration.md) - Technical implementation details +- [API Reference](../docs/api_reference.md) - Function documentation \ No newline at end of file diff --git a/examples/query_agnostic_demo.py b/examples/query_agnostic_demo.py new file mode 100755 index 0000000..8d04413 --- /dev/null +++ b/examples/query_agnostic_demo.py @@ -0,0 +1,133 @@ +#!/usr/bin/env python3 +""" +Demonstration of Query-Agnostic Masking Behavior in Flash Dynamic Mask Attention + +This script demonstrates how the current implementation applies the same mask +to all queries, showing both the benefits and limitations of this approach. +""" + +import torch +import torch.nn.functional as F + +def calculate_zoh_states(value_states, dt_proj, A): + """Calculate ZOH states from value vectors only (query-agnostic).""" + batch_size, _, key_len, _ = value_states.shape + + # Compute importance scores from Value vectors only + dt_result = torch.matmul( + value_states.transpose(-2, -3).reshape(batch_size, key_len, -1), + dt_proj.T + ) + + dt_states = torch.exp(F.softplus(dt_result) * A) + return dt_states.transpose(-1, -2) + +def prepare_dynamic_mask(query_states, zoh_states, keep_window_size=4): + """Prepare dynamic mask - demonstrates query-agnostic behavior.""" + dtype = query_states.dtype + device = query_states.device + + # Broadcast same ZOH scores to all queries + attn_bias = zoh_states[:, :, None, :].expand(-1, -1, query_states.shape[2], -1) + + # TopK selection: same keys for all queries + if attn_bias.shape[-1] > keep_window_size: + topk_indices = torch.topk(attn_bias, keep_window_size, dim=-1, + largest=True, sorted=False).indices + active_mask = torch.zeros_like(attn_bias, dtype=dtype, device=device) + active_mask = active_mask.scatter(-1, topk_indices, 1.0) + else: + active_mask = torch.ones_like(attn_bias, dtype=dtype, device=device) + + return attn_bias, active_mask, topk_indices + +def main(): + print("=" * 70) + print("Flash Dynamic Mask Attention: Query-Agnostic Behavior Demonstration") + print("=" * 70) + + # Setup simple example + batch_size, num_heads, seq_len, head_dim = 1, 2, 8, 4 + keep_window_size = 4 + device = 'cpu' + + # Create example data with clear patterns + torch.manual_seed(42) + + # Values with clear importance pattern: positions 1, 3, 5, 7 are "important" + value_states = torch.zeros(batch_size, num_heads, seq_len, head_dim) + value_states[:, :, [1, 3, 5, 7], :] = 1.0 # Important positions + value_states[:, :, [0, 2, 4, 6], :] = 0.1 # Less important positions + + # Queries with different "intentions" + query_states = torch.randn(batch_size, num_heads, seq_len, head_dim) + query_states[:, :, 0, :] = torch.tensor([1.0, 0.0, 0.0, 0.0]) # Query 0: looking for pattern A + query_states[:, :, 1, :] = torch.tensor([0.0, 1.0, 0.0, 0.0]) # Query 1: looking for pattern B + query_states[:, :, 2, :] = torch.tensor([0.0, 0.0, 1.0, 0.0]) # Query 2: looking for pattern C + + # Learned parameters (simplified) + dt_proj = torch.ones(num_heads, num_heads * head_dim) * 0.5 + A = torch.ones(num_heads) + + print(f"Sequence length: {seq_len}") + print(f"Keep window size: {keep_window_size}") + print(f"Important value positions: [1, 3, 5, 7]") + print(f"Less important positions: [0, 2, 4, 6]") + print() + + # Calculate ZOH states (value-based importance) + zoh_states = calculate_zoh_states(value_states, dt_proj, A) + print("ZOH States (Value-based importance 
scores):") + print(f"Shape: {zoh_states.shape}") # [batch, heads, key_len] + print("Head 0:", zoh_states[0, 0].round(decimals=3).tolist()) + print("Head 1:", zoh_states[0, 1].round(decimals=3).tolist()) + print() + + # Generate dynamic mask + attn_bias, active_mask, topk_indices = prepare_dynamic_mask( + query_states, zoh_states, keep_window_size + ) + + print("TopK Selected Keys (same for ALL queries):") + print("Head 0:", topk_indices[0, 0, 0].tolist()) # Same for all queries + print("Head 1:", topk_indices[0, 1, 0].tolist()) # Same for all queries + print() + + print("Active Mask Verification (1.0 = attend, 0.0 = ignore):") + for head in range(num_heads): + print(f"\nHead {head}:") + print("Query positions -> Key positions:") + for query in range(min(3, seq_len)): # Show first 3 queries + mask_row = active_mask[0, head, query].tolist() + attended_keys = [i for i, val in enumerate(mask_row) if val == 1.0] + print(f" Query {query}: attends to keys {attended_keys}") + + print("\n" + "=" * 50) + print("KEY OBSERVATIONS:") + print("=" * 50) + print("1. ZOH states computed from Values only (no Query information)") + print("2. Same TopK keys selected for ALL queries") + print("3. Query intentions (patterns A, B, C) are ignored in key selection") + print("4. Computational efficiency: O(N) mask generation vs O(N²) for query-aware") + print("5. Trade-off: efficiency vs. query-specific precision") + + print("\n" + "=" * 50) + print("IMPLICATIONS FOR ASSOCIATIVE RECALL:") + print("=" * 50) + print("āœ… Works well when:") + print(" - Important information is globally relevant") + print(" - Document has clear hierarchical structure") + print(" - Similar information needs across queries") + + print("\nāŒ Limitations for:") + print(" - Fine-grained query-specific retrieval") + print(" - Tasks requiring different keys per query") + print(" - Precise associative recall ('What did Alice say about X?')") + + print("\nšŸ’” Potential improvements:") + print(" - Larger keep_window_size for more coverage") + print(" - Query-conditioned importance scoring") + print(" - Multi-stage selection (global + query-specific)") + +if __name__ == "__main__": + main() \ No newline at end of file
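The demo stops at printing the masks. As a minimal reference for how an `attn_bias` and binary `active_mask` of the shapes built above can actually be applied, the sketch below folds them into PyTorch's dense `scaled_dot_product_attention`. This mirrors the "post-selection attention" point in the design notes; it is not the Flash-DMA sparse CUDA kernel, and the helper name and stand-in tensors are illustrative only.

```python
import torch
import torch.nn.functional as F


def masked_reference_attention(query, key, value, attn_bias, active_mask):
    """Dense reference attention using a ZOH-style bias plus a binary active mask.

    Positions with active_mask == 0 get the dtype minimum added to the bias,
    approximating what the sparse kernel skips entirely.
    """
    min_val = torch.finfo(query.dtype).min
    additive = attn_bias.masked_fill(active_mask == 0.0, min_val)
    # SDPA accepts a floating-point attn_mask that is added to QK^T / sqrt(d).
    return F.scaled_dot_product_attention(query, key, value, attn_mask=additive)


if __name__ == "__main__":
    torch.manual_seed(0)
    b, h, s, d = 1, 2, 8, 4
    q, k, v = (torch.randn(b, h, s, d) for _ in range(3))
    # Stand-ins for the demo's outputs: a value-based bias broadcast to all queries
    # and a binary mask keeping 4 keys per query.
    bias = torch.rand(b, h, 1, s).expand(b, h, s, s)
    keep_idx = torch.topk(bias, 4, dim=-1).indices
    mask = torch.zeros_like(bias).scatter(-1, keep_idx, 1.0)
    out = masked_reference_attention(q, k, v, bias, mask)
    print(out.shape)  # torch.Size([1, 2, 8, 4])
```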