Conversation

petercad commented Oct 4, 2025

This PR updates FlashAttention to the new copy/MMA atoms.

Changes:

  • Prefill and decode unified into a single implementation, allowing simultaneous subgroup-level parallelization across K and Q rather than an either-or choice.
  • GEMMs and softmax grouped together, and the full k loop consolidated into an FMHA mainloop class (see the mainloop sketch after this list).
    • This will facilitate further manual pipelining/overlap of GEMM with softmax.
  • New copy/MMA atoms and reorders used to transparently support arbitrary data types.
  • Automatic copy/MMA operator selection (see the dispatch sketch after this list).
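
For orientation, here is a minimal scalar sketch of the structure the mainloop class consolidates: the first GEMM (S = Q·Kᵀ), the online-softmax rescaling, and the second GEMM (P·V), all inside a single k loop. This is plain reference C++ with hypothetical names, not the actual CuTe/SYCL kernel.

```cpp
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// Hypothetical scalar reference for one query row attended over K/V tiles.
// d = head size, kv_len = key/value sequence length, tile = k-loop tile size.
void fmha_mainloop_reference(const float* q, const float* K, const float* V,
                             float* out, int d, int kv_len, int tile) {
  float scale = 1.0f / std::sqrt(static_cast<float>(d));
  float running_max = -INFINITY;    // m_i in the online-softmax recurrence
  float running_sum = 0.0f;         // l_i
  std::vector<float> acc(d, 0.0f);  // unnormalized output accumulator

  for (int k0 = 0; k0 < kv_len; k0 += tile) {
    int kt = std::min(tile, kv_len - k0);

    // GEMM 1: S = Q * K^T for this tile, tracking the tile's row max.
    std::vector<float> s(kt);
    float tile_max = -INFINITY;
    for (int j = 0; j < kt; ++j) {
      float dot = 0.0f;
      for (int c = 0; c < d; ++c) dot += q[c] * K[(k0 + j) * d + c];
      s[j] = dot * scale;
      tile_max = std::max(tile_max, s[j]);
    }

    // Online softmax: rescale prior state by exp(old_max - new_max).
    float new_max = std::max(running_max, tile_max);
    float correction = std::exp(running_max - new_max);
    running_sum *= correction;
    for (int c = 0; c < d; ++c) acc[c] *= correction;

    // GEMM 2: accumulate P * V with P = exp(S - new_max).
    for (int j = 0; j < kt; ++j) {
      float p = std::exp(s[j] - new_max);
      running_sum += p;
      for (int c = 0; c < d; ++c) acc[c] += p * V[(k0 + j) * d + c];
    }
    running_max = new_max;
  }

  // Final normalization by the accumulated softmax denominator.
  for (int c = 0; c < d; ++c) out[c] = acc[c] / running_sum;
}

int main() {
  const int d = 4, kv_len = 6, tile = 2;
  std::vector<float> q(d, 0.5f), K(kv_len * d, 0.25f), V(kv_len * d, 1.0f), out(d);
  fmha_mainloop_reference(q.data(), K.data(), V.data(), out.data(), d, kv_len, tile);
  for (int c = 0; c < d; ++c) std::printf("%f ", out[c]);  // uniform V => each ~1.0
  std::printf("\n");
  return 0;
}
```

Keeping all three stages in one loop body is what makes the planned manual pipelining practical: softmax work for one tile can then be overlapped with GEMM work for the next.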
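
Similarly, a toy illustration of what compile-time operator selection can look like. The tags and dispatch below are hypothetical stand-ins; the real selection keys on the Xe copy/MMA atom traits.

```cpp
#include <cstdio>
#include <type_traits>

// Hypothetical operator tags standing in for hardware MMA atoms.
struct XeMma_F16  { static constexpr const char* name = "xe.mma.f16";  };
struct XeMma_BF16 { static constexpr const char* name = "xe.mma.bf16"; };
struct XeMma_F32  { static constexpr const char* name = "xe.mma.f32";  };

struct half_t {};      // placeholder element types, for illustration only
struct bfloat16_t {};

// Compile-time dispatch: map the element type to an MMA operator tag.
template <class Element>
constexpr auto select_mma_op() {
  if constexpr (std::is_same_v<Element, half_t>)          return XeMma_F16{};
  else if constexpr (std::is_same_v<Element, bfloat16_t>) return XeMma_BF16{};
  else                                                    return XeMma_F32{};
}

int main() {
  std::printf("%s\n", decltype(select_mma_op<half_t>())::name);      // xe.mma.f16
  std::printf("%s\n", decltype(select_mma_op<bfloat16_t>())::name);  // xe.mma.bf16
  return 0;
}
```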

Current status: nearly all prefill/decode examples are working, with performance similar to or better than the old examples.

Known issues:

  • Head size 192 decode config doesn't compile yet -- to be fixed.
  • Strange SYCL compiler behavior/bug with the tSrS->tArP reorder: the compiler apparently believes there is UB somewhere and omits a large section of the kernel as a result. For the moment, a direct copy serves as a workaround while I pin down the issue; I haven't been able to reproduce the behavior with the reorder in isolation.

Additional features (causal masking, variable sequence lengths, etc.) to be added later.

petercad changed the title from "[Umbrella commit] Re-implement FlashAttention with new Xe atoms" to "Re-implement FlashAttention with new Xe atoms" on Oct 4, 2025
petercad commented Oct 4, 2025

I will break up this large commit into self-contained smaller commits after review is complete.
