Conversation

@LoserCheems
Collaborator

Introduces a comprehensive PyTorch interface for flash attention operations with CUDA support, adding multiple function variants and custom autograd functions for efficient attention computation. Improves organization and error handling for the CUDA imports and makes the public API more accessible.

Implements a comprehensive PyTorch interface for flash attention operations with CUDA backend support.

Provides multiple function variants including packed QKV, packed KV, and variable-length sequence handling for efficient attention computation.

Includes custom autograd functions with proper forward and backward pass implementations, fake tensor registration for torch.compile compatibility, and support for features like dropout, causal masking, and softmax scaling.

Enables optimized attention computation with automatic head dimension padding and device capability-aware block size selection.
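A minimal usage sketch of the separate-tensor variant follows. This is only an illustration of the features listed above; the tensor layout (batch, seqlen, heads, headdim) and the keyword names `dropout_p`, `softmax_scale`, and `is_causal` are assumptions based on this description, not confirmed signatures.

```python
# Hedged sketch, not the exact API: argument names and layouts are assumed
# from the PR description (dropout, causal masking, softmax scaling).
import torch
from flash_dmattn import flash_dmattn_func

batch, seqlen, nheads, headdim = 2, 1024, 8, 64
q = torch.randn(batch, seqlen, nheads, headdim,
                device="cuda", dtype=torch.float16, requires_grad=True)
k = torch.randn_like(q)
v = torch.randn_like(q)

# Separate q/k/v variant; qkv-packed, kv-packed, and varlen variants are also provided.
out = flash_dmattn_func(
    q, k, v,
    dropout_p=0.1,                 # assumed keyword for dropout support
    softmax_scale=headdim ** -0.5, # assumed keyword for softmax scaling
    is_causal=True,                # assumed keyword for causal masking
)
out.sum().backward()               # custom autograd function supplies the backward pass
```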
Moves the version declaration to the top of the module for better organization.

Adds conditional imports for the CUDA flash attention functions, with proper error handling for the case where the CUDA backend is available but the interface module fails to import.

Renames the auto-selection function to better distinguish it from the imported CUDA implementation and updates the CUDA backend logic to use the imported function instead of raising NotImplementedError.

Expands exports to include all flash attention function variants for better API accessibility.
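As a rough illustration of the reorganized `__init__.py` described above, the guarded import and auto-selection logic might look like the sketch below; the wrapper name, version string, and error messages are hypothetical.

```python
# Sketch of flash_dmattn/__init__.py (names other than flash_dmattn_func are hypothetical).
__version__ = "0.0.0"  # version declared at the top of the module

try:
    from .flash_dmattn_interface import flash_dmattn_func
except ImportError:
    flash_dmattn_func = None  # CUDA backend built, but the interface module failed to import


def flash_dmattn_func_auto(*args, backend="cuda", **kwargs):
    """Hypothetical auto-selection wrapper that dispatches to the imported CUDA function."""
    if backend == "cuda":
        if flash_dmattn_func is None:
            raise ImportError(
                "CUDA backend requested but flash_dmattn_interface could not be imported"
            )
        return flash_dmattn_func(*args, **kwargs)
    raise ValueError(f"Unsupported backend: {backend}")


__all__ = ["flash_dmattn_func", "flash_dmattn_func_auto"]  # expanded exports (abbreviated)
```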

Updates tensor dimension extraction to correctly handle variable-length
attention sequences by deriving batch size from cumulative sequence lengths
rather than tensor shapes.

Removes unused head_size variable assignments and corrects mask/bias tensor
dimensions to use max_seqlen parameters instead of individual sequence lengths.

Simplifies backward pass parameter lists by removing deprecated alibi_slopes
and window_size arguments while consolidating causal flag naming.
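The varlen dimension handling described above can be sketched as follows; variable names such as `cu_seqlens_q` and `max_seqlen_q` follow common flash-attention conventions and are assumptions, not the PR's exact code.

```python
import torch
import torch.nn.functional as F

# Packed variable-length batch: three sequences of lengths 3, 5, and 2.
seqlens = torch.tensor([3, 5, 2])
cu_seqlens_q = F.pad(seqlens.cumsum(0), (1, 0))  # cumulative lengths: [0, 3, 8, 10]

# Batch size is derived from the cumulative-lengths tensor, not from tensor shapes.
batch_size = cu_seqlens_q.numel() - 1   # 3
max_seqlen_q = int(seqlens.max())       # 5

total_q, num_heads, head_dim = int(seqlens.sum()), 8, 64
q = torch.randn(total_q, num_heads, head_dim)  # packed layout: (total_tokens, heads, dim)

# Mask/bias buffers are sized with the padded maximum, not per-sequence lengths.
attn_mask = torch.zeros(batch_size, num_heads, max_seqlen_q, max_seqlen_q)
```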
@LoserCheems LoserCheems requested a review from Copilot July 29, 2025 03:34
Contributor

Copilot AI left a comment

Pull Request Overview

This PR introduces a comprehensive CUDA-integrated flash attention interface for PyTorch, enhancing the library with efficient attention computation capabilities. The changes add multiple function variants and custom autograd implementations while improving error handling for CUDA imports.

  • Adds complete flash attention CUDA interface with multiple variants (qkv-packed, kv-packed, varlen) and custom autograd functions
  • Integrates CUDA functions into the main package with proper fallback handling
  • Reorganizes version placement and improves import structure for better API accessibility

Reviewed Changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 5 comments.

File / Description
  • flash_dmattn/flash_dmattn_interface.py: New comprehensive CUDA interface with flash attention implementations and autograd functions
  • flash_dmattn/__init__.py: Updated imports to include CUDA functions with fallback handling and reorganized version definition

Comments suppressed due to low confidence (1)

flash_dmattn/__init__.py:110

  • The variable name 'flash_dmattn_func' is ambiguous in this context as it refers to the imported function, but the same name is used for the auto function being defined. Consider using a more specific name like 'cuda_flash_dmattn_func' to avoid confusion.
        if flash_dmattn_func is None:
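A sketch of the disambiguation Copilot suggests is shown below; the relative import path matches the interface module listed above, and the surrounding wrapper code is omitted.

```python
# Import the CUDA implementation under an explicit alias so it cannot be confused
# with the auto-selection function defined in __init__.py (sketch, not the actual code).
try:
    from .flash_dmattn_interface import flash_dmattn_func as cuda_flash_dmattn_func
except ImportError:
    cuda_flash_dmattn_func = None

# Inside the auto-selection function the check then reads unambiguously:
#     if cuda_flash_dmattn_func is None:
#         raise ImportError("CUDA flash attention implementation is unavailable")
```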

@LoserCheems LoserCheems merged commit 431170b into main Jul 29, 2025