diff --git a/README.md b/README.md
index 139f8ba..64ae1a6 100644
--- a/README.md
+++ b/README.md
@@ -47,8 +47,7 @@ pip install .
 
 ```python
 import torch
-import flash_dma_cuda
-import torch.nn.functional as F
+from flash_dmattn import flash_dmattn_func
 import math
 
 # Setup
@@ -63,19 +62,32 @@ key = torch.randn(batch_size, seq_len, num_heads, head_dim,
                   device=device, dtype=dtype)
 value = torch.randn(batch_size, seq_len, num_heads, head_dim,
                     device=device, dtype=dtype)
-zoh_states = torch.randn(batch_size, num_heads, seq_len, seq_len,
-                         device=device, dtype=dtype)
-active_mask = torch.ones(batch_size, num_heads, seq_len, seq_len,
-                         device=device, dtype=dtype)
-
-# Run Flash-DMA
-output = flash_dma_cuda.fwd(
-    q=query, k=key, v=value,
-    zoh=zoh_states, active_mask=active_mask,
+
+# Create mask and bias for sparse attention
+attention_bias = torch.randn(batch_size, num_heads, seq_len, seq_len,
+                             device=device, dtype=dtype)
+attention_mask = torch.ones(batch_size, num_heads, seq_len, seq_len,
+                            device=device, dtype=dtype)
+
+# Apply dynamic masking (keep top-k for long sequences)
+keep_window_size = 2048
+if seq_len > keep_window_size:
+    # Select top-k most important keys for each query
+    topk_indices = torch.topk(attention_bias, keep_window_size, dim=-1,
+                              largest=True, sorted=False).indices
+    attention_mask.zero_()
+    attention_mask.scatter_(-1, topk_indices, 1.0)
+
+# Run Flash Dynamic Mask Attention
+output = flash_dmattn_func(
+    q=query,
+    k=key,
+    v=value,
+    attn_mask=attention_mask,
+    attn_bias=attention_bias,
     softmax_scale=1.0/math.sqrt(head_dim),
-    keep_window_size=keep_window_size,
     is_causal=True
-)[0]
+)
 
 print(f"Output shape: {output.shape}")  # [2, 4096, 12, 128]
 ```
@@ -189,34 +201,31 @@ python -c "import torch; print(f'CUDA available: {torch.cuda.is_available()}')"
 ```python
 # Test basic import
 try:
-    import flash_dma_cuda
-    print("✅ Flash DMA CUDA extension imported successfully")
+    from flash_dmattn import flash_dmattn_func, get_available_backends
+    print("✅ Flash Dynamic Mask Attention imported successfully")
+    print(f"Available backends: {get_available_backends()}")
 except ImportError as e:
     print(f"❌ Import failed: {e}")
     print("Please ensure the package is properly installed with: pip install -e .")
 ```
 
 **Performance Issues**
-- Ensure GPU has compute capability 8.0+ for optimal performance
-- Use `torch.bfloat16` for better numerical stability
-- Adjust `keep_window_size` based on available GPU memory
-- Verify CUDA kernels are being used
-
-**Memory Issues**
 ```python
 # Monitor GPU memory usage
-torch.cuda.memory_summary()
-torch.cuda.max_memory_allocated()
+from flash_dmattn import flash_dmattn_func
+
+def print_memory_stats():
+    if torch.cuda.is_available():
+        print(f"GPU Memory: {torch.cuda.memory_allocated() / 1e9:.2f} GB")
+
+print_memory_stats()
+output = flash_dmattn_func(q=query, k=key, v=value, is_causal=True)
+print_memory_stats()
 
 # Clear cache if needed
 torch.cuda.empty_cache()
 ```
 
-**Numerical Issues**
-- Use `torch.bfloat16` instead of `torch.float16` for better stability
-- Check input tensor ranges for NaN or infinite values
-- Validate ZOH states and active mask values are in expected ranges
-
 ## License
 
 This project is licensed under the BSD 3-Clause License. See [LICENSE](LICENSE) for details.
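
Note: the top-k dynamic masking added in the usage example above can also be factored into a small helper. The sketch below only re-arranges the snippet from the diff; the helper name `build_topk_mask` is illustrative and is not part of the flash_dmattn API.

```python
import torch

def build_topk_mask(attention_bias: torch.Tensor, keep_window_size: int) -> torch.Tensor:
    """Build a {0, 1} attention mask keeping the top-k keys per query.

    attention_bias: (batch, num_heads, seq_len, seq_len) bias scores.
    Returns a mask of the same shape and dtype, mirroring the README example:
    short sequences keep everything, long sequences keep only the
    keep_window_size highest-bias keys for each query position.
    """
    seq_len = attention_bias.shape[-1]
    attention_mask = torch.ones_like(attention_bias)
    if seq_len > keep_window_size:
        # Select the keep_window_size most important keys for each query
        topk_indices = torch.topk(attention_bias, keep_window_size, dim=-1,
                                  largest=True, sorted=False).indices
        attention_mask.zero_()
        attention_mask.scatter_(-1, topk_indices, 1.0)  # in-place scatter
    return attention_mask
```

With such a helper, the example reduces to `attention_mask = build_topk_mask(attention_bias, keep_window_size)` before calling `flash_dmattn_func`.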