9 changes: 8 additions & 1 deletion README.md
@@ -127,20 +127,27 @@ Flash-DMA combines two complementary techniques:

The integration happens at the CUDA kernel level with several key components:

- **ZOH States**: Pre-computed importance scores for key selection
- **ZOH States**: Pre-computed importance scores for key selection (computed from Value vectors only)
- **Active Masks**: Binary masks indicating which keys should be considered for each query
- **Sparse Matrix Multiplication**: Custom CUDA kernels for efficient sparse attention computation
- **Block-Based Processing**: Maintains Flash Attention's block-based approach for memory efficiency

This creates a hybrid attention mechanism that achieves both memory and computational efficiency for long sequences.
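
To make the data flow concrete, the sketch below is a dense PyTorch reference of the same mechanism, written against the helper shapes used in `benchmarks/` rather than the fused CUDA kernels; the function name `dense_reference` and the exact way the bias is folded into the logits are illustrative assumptions, not the kernels' behavior.

```python
import torch
import torch.nn.functional as F

def dense_reference(q, k, v, zoh_states, keep_window_size):
    """Dense (non-fused) sketch of ZOH-biased, top-K masked attention."""
    B, H, Lq, D = q.shape
    # Broadcast the value-based importance scores to every query position.
    attn_bias = zoh_states[:, :, None, :].expand(B, H, Lq, -1)
    # Active mask: keep the top-K keys per head (the same keys for all queries).
    topk_idx = torch.topk(attn_bias, keep_window_size, dim=-1, largest=True, sorted=False).indices
    active_mask = torch.zeros_like(attn_bias).scatter(-1, topk_idx, 1.0)
    # Bias the attention logits with the ZOH scores and drop inactive keys.
    scores = torch.matmul(q, k.transpose(-1, -2)) / D ** 0.5 + attn_bias
    scores = scores.masked_fill(active_mask == 0.0, torch.finfo(scores.dtype).min)
    return torch.matmul(F.softmax(scores, dim=-1), v)
```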

### Design Characteristics

⚠️ **Important Design Note**: Flash-DMA uses a **query-agnostic** masking approach where the same set of keys is selected for all queries. This design choice prioritizes computational efficiency and works well for tasks with global importance patterns, but may be suboptimal for fine-grained associative recall tasks that require query-specific key selection.

For detailed analysis of this design choice and its implications, see the [Design Choices Documentation](docs/design_choices.md).


## Documentation

πŸ“š **Complete documentation is available in the [docs](docs/) directory:**

- **[API Reference](docs/api_reference.md)** - Complete function documentation and usage examples
- **[Integration Guide](docs/integration.md)** - Detailed technical documentation of the Flash Attention integration
- **[Design Choices](docs/design_choices.md)** - Analysis of query-agnostic masking and its implications for different tasks


## Building from Source
33 changes: 28 additions & 5 deletions benchmarks/forward_equivalence.py
@@ -59,19 +59,34 @@ def prepare_dynamic_mask(
"""
Calculate dynamic attention mask to mask tokens for sparse attention.

⚠️ DESIGN NOTE: This function implements QUERY-AGNOSTIC masking.
The same ZOH-based importance scores are broadcast to ALL queries, meaning:
1. All queries attend to the same set of top-K keys
2. No query-specific key selection is performed
3. Optimization prioritizes efficiency over query-specific precision

This approach works well for tasks with global importance patterns but may be
suboptimal for fine-grained associative recall tasks. See docs/design_choices.md
for detailed analysis and potential alternatives.

Combine `zoh_states` with `attention_mask` to generate the final `attn_mask`.

Args:
hidden_states: Input hidden states to determine dtype minimum value
zoh_states: zoh_states of shape (batch_size, num_kv_heads, key_sequence_length)
zoh_states: [batch_size, num_kv_heads, key_sequence_length] - Value-based importance scores
keep_window_size: Window size of tokens not dynamically masked
attention_mask: Optional attention mask of shape (batch_size, 1, query_len, key_len)

Returns:
tuple: (attn_bias, attn_mask)
- attn_bias: [batch_size, num_kv_heads, query_len, key_len] - ZOH scores + masking
- attn_mask: [batch_size, num_kv_heads, query_len, key_len] - Binary active mask
"""
min_dtype = torch.finfo(hidden_states.dtype).min
dtype = hidden_states.dtype

# πŸ” KEY INSIGHT: Broadcasting same importance scores to ALL queries
# Shape transformation: [batch, heads, key_len] -> [batch, heads, query_len, key_len]
attn_bias = zoh_states[:, :, None, :].expand(
-1, -1, hidden_states.shape[2], -1
) # [batch_size, num_kv_heads, query_len, key_len]
@@ -103,14 +118,22 @@ def calculate_zoh_states(value_states, dt_proj, A):
"""
Calculate zoh states for dynamic mask attention.

⚠️ DESIGN NOTE: This function implements QUERY-AGNOSTIC importance scoring.
ZOH states are computed solely from Value vectors and contain no query-specific information.
The same importance scores will be broadcast to all queries, meaning all queries
will attend to the same set of top-K keys.

This design choice prioritizes computational efficiency over query-specific precision.
See docs/design_choices.md for detailed analysis of implications.

Args:
value_states: [batch_size, num_kv_heads, key_len, head_dim]
dt_proj: [num_kv_heads, num_kv_heads * head_dim]
A: [num_kv_heads]
causal_mask: Optional causal mask
dt_proj: [num_kv_heads, num_kv_heads * head_dim] - Learned projection matrix
A: [num_kv_heads] - Scaling coefficients

Returns:
zoh_states: [batch_size, num_kv_heads, key_len]
zoh_states: [batch_size, num_kv_heads, key_len] - Value-based importance scores
Note: No query dimension - same scores applied to all queries
"""
batch_size, _, key_len, _ = value_states.shape

35 changes: 30 additions & 5 deletions benchmarks/forward_performance.py
@@ -81,19 +81,34 @@ def prepare_dynamic_mask(
"""
Calculate dynamic attention mask to mask tokens for sparse attention.

⚠️ DESIGN NOTE: This function implements QUERY-AGNOSTIC masking.
The same ZOH-based importance scores are broadcast to ALL queries, meaning:
1. All queries attend to the same set of top-K keys
2. No query-specific key selection is performed
3. Optimization prioritizes efficiency over query-specific precision

This approach works well for tasks with global importance patterns but may be
suboptimal for fine-grained associative recall tasks. See docs/design_choices.md
for detailed analysis and potential alternatives.

Combine `zoh_states` with `attention_mask` to generate the final `attn_mask`.

Args:
hidden_states: Input hidden states to determine dtype minimum value
zoh_states: zoh_states of shape (batch_size, num_kv_heads, key_sequence_length)
zoh_states: [batch_size, num_kv_heads, key_sequence_length] - Value-based importance scores
keep_window_size: Window size of tokens not dynamically masked
attention_mask: Optional attention mask of shape (batch_size, 1, query_len, key_len)

Returns:
tuple: (attn_bias, attn_mask)
- attn_bias: [batch_size, num_kv_heads, query_len, key_len] - ZOH scores + masking
- attn_mask: [batch_size, num_kv_heads, query_len, key_len] - Binary active mask
"""
min_dtype = torch.finfo(hidden_states.dtype).min
dtype = hidden_states.dtype

# πŸ” KEY INSIGHT: Broadcasting same importance scores to ALL queries
# Shape transformation: [batch, heads, key_len] -> [batch, heads, query_len, key_len]
attn_bias = zoh_states[:, :, None, :].expand(
-1, -1, hidden_states.shape[2], -1
) # [batch_size, num_kv_heads, query_len, key_len]
@@ -110,6 +125,8 @@ def prepare_dynamic_mask(
)

if attn_bias.shape[-1] > keep_window_size:
# πŸ” CRITICAL: TopK selection produces SAME keys for ALL queries
# This creates uniform attention patterns across all query positions
topk_indices = torch.topk(
attn_bias, keep_window_size, dim=-1, largest=True, sorted=False
).indices
@@ -125,14 +142,22 @@ def calculate_zoh_states(value_states, dt_proj, A):
"""
Calculate zoh states for dynamic mask attention.

⚠️ DESIGN NOTE: This function implements QUERY-AGNOSTIC importance scoring.
ZOH states are computed solely from Value vectors and contain no query-specific information.
The same importance scores will be broadcast to all queries, meaning all queries
will attend to the same set of top-K keys.

This design choice prioritizes computational efficiency over query-specific precision.
See docs/design_choices.md for detailed analysis of implications.

Args:
value_states: [batch_size, num_kv_heads, key_len, head_dim]
dt_proj: [num_kv_heads, num_kv_heads * head_dim]
A: [num_kv_heads]
causal_mask: Optional causal mask
dt_proj: [num_kv_heads, num_kv_heads * head_dim] - Learned projection matrix
A: [num_kv_heads] - Scaling coefficients

Returns:
zoh_states: [batch_size, num_kv_heads, key_len]
zoh_states: [batch_size, num_kv_heads, key_len] - Value-based importance scores
Note: No query dimension - same scores applied to all queries
"""
batch_size, _, key_len, _ = value_states.shape

195 changes: 195 additions & 0 deletions docs/design_choices.md
@@ -0,0 +1,195 @@
# Flash Dynamic Mask Attention: Design Choices and Trade-offs

## Overview

This document explains key design decisions in Flash Dynamic Mask Attention, particularly regarding the query-agnostic nature of the dynamic masking mechanism and its implications for different types of attention tasks.

## Query-Agnostic Masking Design

### Current Implementation

The Flash Dynamic Mask Attention implementation uses a **query-agnostic** masking strategy:

```python
import torch
import torch.nn.functional as F

# 1. ZOH states computed ONLY from Value vectors
def calculate_zoh_states(value_states, dt_proj, A):
    """
    ZOH states depend only on Value vectors, not Queries.
    Result shape: [batch_size, num_kv_heads, key_len]  # No query dimension!
    """
    batch_size, _, key_len, _ = value_states.shape
    # [B, H_kv, key_len, head_dim] -> [B, key_len, H_kv * head_dim], then project with dt_proj.T
    dt_result = torch.matmul(
        value_states.transpose(-2, -3).reshape(batch_size, key_len, -1),
        dt_proj.T
    )
    dt_states = torch.exp(F.softplus(dt_result) * A)
    return dt_states.transpose(-1, -2)  # [batch_size, num_kv_heads, key_len]

# 2. Same importance scores broadcast to all queries
def prepare_dynamic_mask(hidden_states, zoh_states, keep_window_size, attention_mask):
"""
The same ZOH-based importance scores are applied to ALL queries.
"""
# Broadcast: [batch, heads, key_len] -> [batch, heads, query_len, key_len]
attn_bias = zoh_states[:, :, None, :].expand(-1, -1, hidden_states.shape[2], -1)

# TopK selection: same keys selected for ALL queries
topk_indices = torch.topk(attn_bias, keep_window_size, dim=-1,
largest=True, sorted=False).indices

# Result: all queries attend to the same top-K keys
active_mask = torch.zeros_like(attn_bias)
active_mask = active_mask.scatter(-1, topk_indices, 1.0)
return attn_bias, active_mask
```

### Key Characteristics

1. **Value-only Computation**: Importance scores are derived solely from Value vectors
2. **Global Broadcasting**: Same importance scores applied to all query positions
3. **Uniform Selection**: All queries attend to the same set of top-K keys
4. **Query Independence**: Mask generation does not consider query content
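
The uniform-selection property in item 3 can be checked directly. The self-contained snippet below (illustrative shapes only) builds the mask the same way as `prepare_dynamic_mask` and verifies that every query row receives an identical set of active keys.

```python
import torch

B, H, Lq, N, K = 1, 2, 8, 64, 16
zoh_states = torch.rand(B, H, N)                                   # value-based importance
attn_bias = zoh_states[:, :, None, :].expand(B, H, Lq, N)
topk_idx = torch.topk(attn_bias, K, dim=-1, largest=True, sorted=False).indices
active_mask = torch.zeros_like(attn_bias).scatter(-1, topk_idx, 1.0)

# Every query position ends up with exactly the same active keys.
assert all(torch.equal(active_mask[:, :, 0], active_mask[:, :, i]) for i in range(Lq))
```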

## Design Rationale

### Computational Efficiency

The query-agnostic design provides significant computational advantages:

```python
# Query-agnostic (current): O(N) complexity for mask generation
zoh_states = compute_importance(V) # Shape: [batch, heads, N]
mask = topk(zoh_states.expand_to_queries()) # Broadcast operation

# Query-aware alternative: O(NΒ²) total work for mask generation (O(N) per query)
for i in range(query_len):
    importance_i = compute_query_aware_importance(Q[:, :, i], V)  # Shape: [batch, heads, N]
    mask[:, :, i] = topk(importance_i)  # Separate selection per query
```

**Benefits:**
- **Memory Efficiency**: Single importance computation instead of per-query computation
- **Speed**: O(N) mask generation vs O(NΒ²) for query-aware approaches
- **Simplicity**: Cleaner implementation with fewer edge cases
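
To put rough numbers on the memory gap, the shape arithmetic below (assumed sizes, purely illustrative) counts the entries a score tensor would need in each regime; the query-aware variant grows by a factor of `query_len`.

```python
# Illustrative shape arithmetic under assumed sizes.
B, H, Lq, N = 1, 32, 4096, 4096

query_agnostic = B * H * N        # shared scores:    [batch, heads, key_len]
query_aware = B * H * Lq * N      # per-query scores: [batch, heads, query_len, key_len]

print(query_agnostic)             # 131072
print(query_aware)                # 536870912 -> query_len (4096x) more entries
```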

### When Query-Agnostic Masking Works Well

This design is effective for tasks where:

1. **Global Importance Patterns**: Some keys are inherently more important regardless of the query
2. **Structured Content**: Information is hierarchically organized (e.g., summaries, keywords)
3. **Content-based Retrieval**: Important information is identifiable from content alone

#### Example: Document Summarization
```python
# Document: [title, abstract, section1, section2, ..., references]
# Value-based importance can identify:
# - Title and abstract (always important)
# - Key sentences (high information density)
# - Section headers (structural importance)
# All queries benefit from attending to these globally important positions
```

## Limitations for Associative Recall Tasks

### The Challenge

Associative recall tasks typically require **query-specific** key selection:

```python
# Example: "What did Alice say about the meeting?"
# - Query focuses on: "Alice" + "meeting"
# - Relevant keys: positions mentioning both Alice and meetings
# - Irrelevant keys: positions about Bob, other topics, or Alice discussing other topics

# Current limitation: All queries see the same "important" keys
# even if those keys aren't relevant to the specific query
```

### Specific Limitations

1. **Context Mismatch**: Globally important keys may not be relevant to specific queries
2. **Information Dilution**: Attention spread across non-relevant but "important" positions
3. **Recall Precision**: Harder to precisely locate query-specific information

### Quantitative Example

Consider a document with 4096 tokens and `keep_window_size=512`:

```
Query-Agnostic (Current):
- All queries attend to the same 512 "important" positions
- For "What did Alice say?": only ~50 positions might actually mention Alice
- Efficiency: 50/512 = ~10% relevant attention

Query-Aware (Ideal):
- Each query attends to its own 512 most relevant positions
- For "What did Alice say?": 400+ positions could mention Alice
- Efficiency: 400/512 = ~78% relevant attention
```

## Hybrid Approaches and Future Directions

### Potential Improvements

1. **Query-Conditioned Importance**:
```python
# Compute query-conditioned importance (e.g. from the interaction of queries with key/value content)
importance = compute_qk_importance(Q, V, dt_proj) # Shape: [batch, heads, query_len, key_len]
```

2. **Multi-Stage Selection**:
```python
# Stage 1: Global filtering (current approach)
global_mask = compute_global_importance(V)

# Stage 2: Query-specific refinement within global selection
refined_mask = compute_query_specific(Q, V, global_mask)
```

3. **Learned Query-Aware Projections**:
```python
# Different projections for different query types
query_type = classify_query(Q)
dt_proj_specific = dt_proj_bank[query_type]
importance = compute_importance(V, dt_proj_specific)
```
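
As a purely hypothetical sketch of option 1 above, a query-conditioned variant could derive importance from scaled query-key scores and run the top-K selection per query. Nothing like this exists in the current kernels, and `query_aware_topk_mask` is an invented name.

```python
import torch

def query_aware_topk_mask(q, k, keep_window_size):
    """Hypothetical per-query key selection (not part of the current implementation)."""
    # q: [batch, heads, query_len, head_dim], k: [batch, heads, key_len, head_dim]
    scores = torch.matmul(q, k.transpose(-1, -2)) / q.shape[-1] ** 0.5  # [B, H, Lq, N]
    topk_idx = torch.topk(scores, keep_window_size, dim=-1, largest=True, sorted=False).indices
    return torch.zeros_like(scores).scatter(-1, topk_idx, 1.0)          # different keys per query
```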

## Current Capabilities and Workarounds

### What Still Works

Even with query-agnostic masking, the system can handle some associative recall through:

1. **Learned Global Patterns**: Training can identify generally important positions
2. **Redundant Information**: Multiple positions may contain similar information
3. **Post-Selection Attention**: Standard attention weights can still focus within selected keys

### Practical Strategies

For better associative recall with current implementation:

1. **Larger Window Sizes**: Increase `keep_window_size` to capture more potential targets
2. **Multi-Head Diversity**: Different heads may learn different global importance patterns
3. **Hierarchical Processing**: Use multiple attention layers with different masking strategies
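
For strategy 1, the snippet below (random scores, assumed sizes) shows the only knob involved: enlarging `keep_window_size` simply keeps more candidate keys per head in the active mask, at the cost of more compute.

```python
import torch

B, H, Lq, N = 1, 4, 128, 4096
zoh_states = torch.rand(B, H, N)
attn_bias = zoh_states[:, :, None, :].expand(B, H, Lq, N)

for keep_window_size in (512, 2048):
    idx = torch.topk(attn_bias, keep_window_size, dim=-1, largest=True, sorted=False).indices
    mask = torch.zeros_like(attn_bias).scatter(-1, idx, 1.0)
    print(keep_window_size, int(mask[0, 0, 0].sum()))  # keys kept per query: 512, then 2048
```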

## Conclusion

The query-agnostic design of Flash Dynamic Mask Attention represents a **computational efficiency vs. precision trade-off**:

**Advantages:**
- βœ… Excellent computational efficiency
- βœ… Simple implementation and debugging
- βœ… Effective for tasks with global importance patterns
- βœ… Good baseline performance across diverse tasks

**Limitations:**
- ❌ Suboptimal for fine-grained associative recall
- ❌ May miss query-specific relevant information
- ❌ Less precise attention targeting

This design choice prioritizes **efficiency and generality** over **task-specific optimization**. For applications requiring high-precision associative recall, consider:
1. Using larger window sizes
2. Implementing hybrid approaches
3. Contributing query-aware extensions to the project

The current implementation serves as a strong foundation that balances performance and computational requirements while remaining extensible for future enhancements.