
[Flash Attention] Disable packed sequences with pos ids only during torch compile #41827

Draft
vasqu wants to merge 2 commits into huggingface:main from vasqu:fix-fa-pos-ids-compile

Conversation

@vasqu
Contributor

@vasqu vasqu commented Oct 23, 2025

Draft, intended only as a reference for what could be done. It would allow a full graph compile when no attention mask is used.

Compile support (before vs. after this PR):

  • Bsz 1
    • No mask
      • Before: No full graph, recompilations
      • After: Full graph
    • Attn mask
      • Before: No full graph, recompilations
      • After: No full graph, recompilations
    • Pos ids, no mask
      • Before: No full graph, recompilations
      • After: Not supported; silently wrong computations if sequences are packed
    • Fa kwargs, no mask
      • Before: Full graph
      • After: Full graph
  • Bsz > 1
    • No mask
      • Before: Full graph
      • After: Full graph
    • Attn mask
      • Before: Same as bsz 1
      • After: Same as bsz 1

TL;DR: the core changes are

  • No attn mask: full graph support, vs. recompilations and no full graph before (bsz == 1)
  • Position ids but no attn mask: not supported under compile, vs. recompilations and no full graph before
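To illustrate why the position-ids-only packed path is incompatible with a full graph: inferring packing from position ids requires a data-dependent check on tensor values, which forces graph breaks and recompilations under torch.compile. The sketch below is a hypothetical simplification, not the actual transformers implementation; the function names (`is_packed_sequence`, `prepare_fa_kwargs`) and the explicit `is_compiling` flag are assumptions for illustration.

```python
import torch

def is_packed_sequence(position_ids: torch.Tensor) -> bool:
    """Heuristic: a position id that resets to 0 after index 0 signals that
    several sequences were packed into one row.

    Converting the tensor result to a Python bool is a data-dependent
    operation, which is what breaks full-graph torch.compile."""
    # position_ids: shape (batch, seq_len)
    return bool((position_ids[:, 1:] == 0).any())

def prepare_fa_kwargs(position_ids: torch.Tensor, is_compiling: bool):
    """Return cumulative sequence lengths for the varlen FA path, or None
    for the plain (non-packed) path.

    Under compile, the data-dependent packing check is skipped entirely so
    the graph stays full; packed inputs then silently take the non-packed
    path, matching the "not supported" behavior described above."""
    if is_compiling or not is_packed_sequence(position_ids):
        return None  # plain flash attention path
    # Eager-only: derive cumulative sequence boundaries from the reset points.
    flat = position_ids.flatten()
    starts = torch.nonzero(flat == 0).flatten()
    ends = torch.cat([starts[1:], torch.tensor([flat.numel()])])
    return torch.cat([starts[:1], ends]).to(torch.int32)
```

For example, `position_ids = [[0, 1, 2, 0, 1]]` (two packed sequences of lengths 3 and 2) yields boundaries `[0, 3, 5]` in eager mode, while under compile the same input returns `None` and is treated as a single unpacked sequence.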

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

