README.md (65 changes: 37 additions, 28 deletions)
@@ -47,8 +47,7 @@ pip install .

```python
import torch
import flash_dma_cuda
import torch.nn.functional as F
from flash_dmattn import flash_dmattn_func
import math

# Setup
@@ -63,19 +62,32 @@ key = torch.randn(batch_size, seq_len, num_heads, head_dim,
                  device=device, dtype=dtype)
value = torch.randn(batch_size, seq_len, num_heads, head_dim,
                    device=device, dtype=dtype)
zoh_states = torch.randn(batch_size, num_heads, seq_len, seq_len,
                         device=device, dtype=dtype)
active_mask = torch.ones(batch_size, num_heads, seq_len, seq_len,
                         device=device, dtype=dtype)

# Run Flash-DMA
output = flash_dma_cuda.fwd(
    q=query, k=key, v=value,
    zoh=zoh_states, active_mask=active_mask,

# Create mask and bias for sparse attention
attention_bias = torch.randn(batch_size, num_heads, seq_len, seq_len,
                             device=device, dtype=dtype)
attention_mask = torch.ones(batch_size, num_heads, seq_len, seq_len,
                            device=device, dtype=dtype)

# Apply dynamic masking (keep top-k for long sequences)
keep_window_size = 2048
if seq_len > keep_window_size:
    # Select top-k most important keys for each query
    topk_indices = torch.topk(attention_bias, keep_window_size, dim=-1,
                              largest=True, sorted=False).indices
    attention_mask.zero_()
    attention_mask.scatter(-1, topk_indices, 1.0)

> **Copilot AI** (Jul 30, 2025), on the `torch.topk` line: The `keep_window_size` variable is used but not defined in this context. It was removed from the function call but still referenced in the top-k selection logic.

> **Copilot AI** (Jul 30, 2025), on the `scatter` line: The scatter operation may fail because `topk_indices` has shape `[batch_size, num_heads, seq_len, keep_window_size]` but `attention_mask` expects indices for the last dimension of size `seq_len`. The indices need to be properly shaped or the scatter operation needs to specify the correct dimensions.
>
> Suggested change, replacing `attention_mask.scatter(-1, topk_indices, 1.0)`:
>
>     topk_mask = torch.zeros_like(attention_mask, dtype=torch.bool)
>     topk_mask.scatter_(-1, topk_indices, True)
>     attention_mask.masked_fill_(topk_mask, 1.0)

# Run Flash Dynamic Mask Attention
output = flash_dmattn_func(
    q=query,
    k=key,
    v=value,
    attn_mask=attention_mask,
    attn_bias=attention_bias,
    softmax_scale=1.0/math.sqrt(head_dim),
    keep_window_size=keep_window_size,
    is_causal=True
)[0]
)

print(f"Output shape: {output.shape}") # [2, 4096, 12, 128]
```
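The masking step discussed in the review comments above can be exercised on its own with toy shapes. The sketch below is illustrative only: it uses the in-place `scatter_` variant (the quick-start snippet above uses out-of-place `scatter`), and all names in it are invented for the example.

```python
import torch

# Toy shapes so the result is easy to inspect
batch, heads, seq_len, keep = 1, 2, 8, 3
bias = torch.randn(batch, heads, seq_len, seq_len)

# Keep the top-k keys per query row as a {0, 1} mask
mask = torch.zeros(batch, heads, seq_len, seq_len)
topk_indices = torch.topk(bias, keep, dim=-1, largest=True, sorted=False).indices
mask.scatter_(-1, topk_indices, 1.0)

# top-k indices are distinct, so every query row keeps exactly `keep` positions
assert (mask.sum(dim=-1) == keep).all()
```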
@@ -189,34 +201,31 @@ python -c "import torch; print(f'CUDA available: {torch.cuda.is_available()}')"
```python
# Test basic import
try:
    import flash_dma_cuda
    print("✅ Flash DMA CUDA extension imported successfully")
    from flash_dmattn import flash_dmattn_func, get_available_backends
    print("✅ Flash Dynamic Mask Attention imported successfully")
> **Copilot AI** (Jul 30, 2025), on lines +204 to +205: The `get_available_backends` function is imported but never used in the example. Consider either using it in the example or removing it from the import to avoid confusion.
>
> Suggested change for lines +204 to +205:
>
>     from flash_dmattn import flash_dmattn_func
>     print("✅ Flash Dynamic Mask Attention imported successfully")
>     from flash_dmattn import get_available_backends
print(f"Available backends: {get_available_backends()}")
except ImportError as e:
print(f"❌ Import failed: {e}")
print("Please ensure the package is properly installed with: pip install -e .")
```

**Performance Issues**
- Ensure GPU has compute capability 8.0+ for optimal performance
- Use `torch.bfloat16` for better numerical stability
- Adjust `keep_window_size` based on available GPU memory
- Verify CUDA kernels are being used
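
One way to act on these checks up front is sketched below; it assumes a CUDA build of the package and reuses the `get_available_backends` helper from the import test above together with standard `torch.cuda` queries.

```python
import torch
from flash_dmattn import get_available_backends

if torch.cuda.is_available():
    major, minor = torch.cuda.get_device_capability()
    print(f"Compute capability: {major}.{minor}")  # 8.0+ recommended
    print(f"bfloat16 supported: {torch.cuda.is_bf16_supported()}")
print(f"Available backends: {get_available_backends()}")  # check that a CUDA backend is listed
```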

**Memory Issues**
```python
# Monitor GPU memory usage
torch.cuda.memory_summary()
torch.cuda.max_memory_allocated()
from flash_dmattn import flash_dmattn_func

def print_memory_stats():
    if torch.cuda.is_available():
        print(f"GPU Memory: {torch.cuda.memory_allocated() / 1e9:.2f} GB")

print_memory_stats()
output = flash_dmattn_func(q=query, k=key, v=value, is_causal=True)
print_memory_stats()

# Clear cache if needed
torch.cuda.empty_cache()
```
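
Because `attention_mask` and `attention_bias` in the quick-start are dense `[batch, heads, seq_len, seq_len]` tensors, a rough size estimate is often the quickest way to see whether the shapes fit in memory. The sketch below assumes a 2-byte dtype such as `torch.bfloat16`.

```python
def dense_mask_bytes(batch, heads, seq_len, bytes_per_elem=2):
    # Size of one [batch, heads, seq_len, seq_len] tensor in bytes
    return batch * heads * seq_len * seq_len * bytes_per_elem

# Quick-start shapes (batch=2, heads=12, seq_len=4096, 2-byte dtype): ~0.81 GB each
print(f"{dense_mask_bytes(2, 12, 4096) / 1e9:.2f} GB per mask/bias tensor")
```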

**Numerical Issues**
- Use `torch.bfloat16` instead of `torch.float16` for better stability
- Check input tensor ranges for NaN or infinite values
- Validate ZOH states and active mask values are in expected ranges
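
A quick input audit along these lines catches bad values before they reach the kernel. This is a sketch: `query`, `key`, `value`, `attention_bias`, and `attention_mask` are the tensors from the quick-start example and are assumed to already be in scope.

```python
import torch

def check_finite(name, t):
    # Report the value range and flag NaN/Inf before calling the attention kernel
    print(f"{name}: min={t.min().item():.3e} max={t.max().item():.3e} "
          f"nan={bool(torch.isnan(t).any())} inf={bool(torch.isinf(t).any())}")

for name, t in [("query", query), ("key", key), ("value", value),
                ("attention_bias", attention_bias), ("attention_mask", attention_mask)]:
    check_finite(name, t)
```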

## License

This project is licensed under the BSD 3-Clause License. See [LICENSE](LICENSE) for details.