Add CUDA-integrated flash attention interface #78
Conversation
Implements comprehensive PyTorch interface for flash attention operations with CUDA backend support. Provides multiple function variants including packed QKV, packed KV, and variable-length sequence handling for efficient attention computation. Includes custom autograd functions with proper forward and backward pass implementations, fake tensor registration for torch.compile compatibility, and support for features like dropout, causal masking, and softmax scaling. Enables optimized attention computation with automatic head dimension padding and device capability-aware block size selection.
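To make the autograd structure concrete, here is a minimal sketch of such a custom `torch.autograd.Function`. The class and helper names are illustrative, and the fused CUDA kernels are stood in for by a plain PyTorch reference computation; the real interface additionally handles dropout, variable-length layouts, head-dimension padding, and fake-tensor registration for torch.compile.

```python
import torch


def _reference_attention(q, k, v, softmax_scale, is_causal):
    # Plain PyTorch attention standing in for the fused CUDA forward kernel.
    scores = torch.matmul(q, k.transpose(-2, -1)) * softmax_scale
    if is_causal:
        seqlen_q, seqlen_k = q.shape[-2], k.shape[-2]
        causal_mask = torch.triu(
            torch.ones(seqlen_q, seqlen_k, dtype=torch.bool, device=q.device),
            diagonal=1,
        )
        scores = scores.masked_fill(causal_mask, float("-inf"))
    return torch.matmul(torch.softmax(scores, dim=-1), v)


class FlashAttnFuncSketch(torch.autograd.Function):
    """Illustrative shape of a custom autograd wrapper around an attention kernel."""

    @staticmethod
    def forward(ctx, q, k, v, softmax_scale, is_causal):
        # A real implementation would call the fused CUDA forward here and save
        # the softmax statistics it returns for the backward pass.
        ctx.save_for_backward(q, k, v)
        ctx.softmax_scale = softmax_scale
        ctx.is_causal = is_causal
        return _reference_attention(q, k, v, softmax_scale, is_causal)

    @staticmethod
    def backward(ctx, grad_out):
        q, k, v = ctx.saved_tensors
        # A real implementation would call the companion CUDA backward kernel;
        # this sketch recomputes through autograd to stay self-contained.
        with torch.enable_grad():
            q_, k_, v_ = (t.detach().requires_grad_() for t in (q, k, v))
            out = _reference_attention(q_, k_, v_, ctx.softmax_scale, ctx.is_causal)
            dq, dk, dv = torch.autograd.grad(out, (q_, k_, v_), grad_out)
        # One gradient per forward input; non-tensor arguments get None.
        return dq, dk, dv, None, None
```

Calling `FlashAttnFuncSketch.apply(q, k, v, 1.0 / q.shape[-1] ** 0.5, True)` then behaves like an ordinary differentiable attention call, which is the contract the CUDA-backed functions expose.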
Moves version declaration to top of module for better organization. Adds conditional imports for CUDA flash attention functions with proper error handling when CUDA backend is available but interface module fails to import. Renames the auto-selection function to better distinguish it from the imported CUDA implementation and updates the CUDA backend logic to use the imported function instead of raising NotImplementedError. Expands exports to include all flash attention function variants for better API accessibility.
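A minimal sketch of that conditional-import pattern, assuming variant names such as flash_dmattn_qkvpacked_func and an auto-selection wrapper renamed to flash_dmattn_func_auto; the actual symbols and fallback behaviour in flash_dmattn/__init__.py may differ.

```python
# Sketch of a package __init__.py following the layout described above
# (symbol names and the warning behaviour are assumptions).
__version__ = "0.0.0"  # version declared at the top of the module

try:
    from .flash_dmattn_interface import (
        flash_dmattn_func,
        flash_dmattn_qkvpacked_func,
        flash_dmattn_kvpacked_func,
        flash_dmattn_varlen_func,
    )
except ImportError as exc:
    # CUDA backend is available but the interface module failed to import:
    # keep the package importable and report the failure clearly.
    import warnings

    warnings.warn(f"CUDA flash attention interface unavailable: {exc}")
    flash_dmattn_func = None
    flash_dmattn_qkvpacked_func = None
    flash_dmattn_kvpacked_func = None
    flash_dmattn_varlen_func = None


def flash_dmattn_func_auto(*args, **kwargs):
    """Auto-selecting entry point, renamed so it no longer shadows the CUDA import."""
    if flash_dmattn_func is None:
        raise RuntimeError("CUDA flash attention backend is not available.")
    return flash_dmattn_func(*args, **kwargs)


__all__ = [
    "flash_dmattn_func",
    "flash_dmattn_qkvpacked_func",
    "flash_dmattn_kvpacked_func",
    "flash_dmattn_varlen_func",
    "flash_dmattn_func_auto",
]
```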
Updates tensor dimension extraction to correctly handle variable-length attention sequences by deriving batch size from cumulative sequence lengths rather than tensor shapes. Removes unused head_size variable assignments and corrects mask/bias tensor dimensions to use max_seqlen parameters instead of individual sequence lengths. Simplifies backward pass parameter lists by removing deprecated alibi_slopes and window_size arguments while consolidating causal flag naming.
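As a small illustration of the varlen bookkeeping (tensor names assumed): sequences are concatenated along a single token dimension, so the batch size has to come from the cumulative-length tensor rather than from the tensor shapes.

```python
import torch

# In the variable-length layout, all sequences share one token dimension.
q = torch.randn(10, 8, 64)                  # (total_tokens, num_heads, head_dim)
cu_seqlens_q = torch.tensor([0, 3, 7, 10])  # prefix sums of per-sequence lengths
max_seqlen_q = 4                            # longest sequence in the batch

batch_size = cu_seqlens_q.numel() - 1       # 3, not q.shape[0]
num_heads = q.shape[1]

# Mask/bias buffers are sized by the padded maximum length, not per-sequence lengths.
attn_mask = torch.ones(batch_size, num_heads, max_seqlen_q, max_seqlen_q, dtype=torch.bool)
```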
Pull Request Overview
This PR introduces a comprehensive CUDA-integrated flash attention interface for PyTorch, enhancing the library with efficient attention computation capabilities. The changes add multiple function variants and custom autograd implementations while improving error handling for CUDA imports.
- Adds complete flash attention CUDA interface with multiple variants (qkv-packed, kv-packed, varlen) and custom autograd functions
- Integrates CUDA functions into the main package with proper fallback handling
- Reorganizes version placement and improves import structure for better API accessibility
Reviewed Changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 5 comments.
| File | Description |
|---|---|
| flash_dmattn/flash_dmattn_interface.py | New comprehensive CUDA interface with flash attention implementations and autograd functions |
| flash_dmattn/__init__.py | Updated imports to include CUDA functions with fallback handling and reorganized version definition |
Comments suppressed due to low confidence (1)
flash_dmattn/__init__.py:110
- The variable name 'flash_dmattn_func' is ambiguous in this context as it refers to the imported function, but the same name is used for the auto function being defined. Consider using a more specific name like 'cuda_flash_dmattn_func' to avoid confusion.
```python
if flash_dmattn_func is None:
```
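A sketch of the rename the reviewer suggests, aliasing the CUDA import to cuda_flash_dmattn_func so the wrapper defined in the module cannot shadow it (purely illustrative):

```python
try:
    # Import the CUDA implementation under a distinct alias.
    from .flash_dmattn_interface import flash_dmattn_func as cuda_flash_dmattn_func
except ImportError:
    cuda_flash_dmattn_func = None


def flash_dmattn_func(*args, **kwargs):
    """Auto-selecting wrapper; the aliased import keeps the availability check unambiguous."""
    if cuda_flash_dmattn_func is None:
        raise RuntimeError("CUDA flash attention backend is not available.")
    return cuda_flash_dmattn_func(*args, **kwargs)
```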
Introduce a comprehensive PyTorch interface for flash attention operations with CUDA support, enhancing efficiency with multiple function variants and custom autograd functions. Improve organization and error handling for CUDA imports, ensuring better API accessibility.