
Integrate FA4 (Flash Attention for Blackwell) into HF Transformers #42405

@sambhavnoobcoder

Description


Feature request

Transformers currently supports Flash Attention 2 and 3, but not Flash Attention 4. Users with compatible hardware and the latest flash-attn package cannot leverage FA4's improvements. Let's add support for it as well.

Motivation

Flash Attention 4 is now available in the flash-attn package, bringing significant improvements over FA2/FA3:

  • 30-50x faster compile times thanks to JIT compilation from Python
  • Eliminates binary wheel distribution issues - no more platform-specific wheels needed
  • Optimized for modern GPUs - specifically tuned for Hopper (H100/H200) and Blackwell architectures
  • Cleaner implementation - leverages CuTe's high-level DSL abstractions

Reference: https://x.com/StasBekman/status/1993060880150675700

Your contribution

Add comprehensive FA4 support to transformers (a rough sketch of the detection and registration pieces follows the list):

  1. Detection: Add is_flash_attn_4_available() function to check for FA4 with hardware requirements (SM 8.0+)
  2. Import logic: Handle flash_attn.cute import path
  3. API compatibility: Use runtime introspection to handle API differences between FA4 and FA2/FA3
  4. Auto-selection: Include FA4 in automatic attention implementation selection with highest priority
  5. Registration: Register flash_attention_4 in AttentionInterface
  6. Testing: Add test suite and validation scripts
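
To make the proposal concrete, here is a minimal sketch of how steps 1, 2, and 5 could fit together. The `flash_attn.cute` module path and its `flash_attn_func` entry point are assumptions taken from this request rather than from a released API, and the tensor-layout handling mirrors the existing FA2 integration, so it would need to be verified against the current `AttentionInterface` contract.

```python
import importlib.util

import torch

from transformers import AttentionInterface


def is_flash_attn_4_available() -> bool:
    """Check that the FA4 (flash_attn.cute) path is importable and the GPU is SM 8.0+."""
    # FA4 is assumed to ship inside the flash-attn package under the `cute`
    # namespace, as described in this issue.
    if importlib.util.find_spec("flash_attn") is None:
        return False
    if importlib.util.find_spec("flash_attn.cute") is None:
        return False
    if not torch.cuda.is_available():
        return False
    major, _minor = torch.cuda.get_device_capability()
    return major >= 8  # SM 8.0 or newer


def flash_attention_4(module, query, key, value, attention_mask, scaling=None, dropout=0.0, **kwargs):
    # `flash_attn.cute.flash_attn_func` is an assumed entry point; the real FA4
    # symbol should be confirmed via the runtime introspection proposed in step 3.
    from flash_attn.cute import flash_attn_func

    # Transformers hands tensors over as (batch, heads, seq, head_dim); flash-attn
    # kernels expect (batch, seq, heads, head_dim), mirroring the FA2 integration.
    attn_output = flash_attn_func(
        query.transpose(1, 2),
        key.transpose(1, 2),
        value.transpose(1, 2),
        softmax_scale=scaling,
        causal=getattr(module, "is_causal", True),
    )
    # Registered attention callables return (attn_output, attn_weights); flash
    # attention does not materialize weights, so the second value is None.
    return attn_output, None


if is_flash_attn_4_available():
    # Step 5: expose the kernel as attn_implementation="flash_attention_4".
    AttentionInterface.register("flash_attention_4", flash_attention_4)
```

With that registration in place, models could opt in via `attn_implementation="flash_attention_4"` in `from_pretrained`; the auto-selection priority described in step 4 is not sketched here.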
