Enable kernels-community/metal-flash-sdpa on MPS#45974
Open
ArthurZucker wants to merge 4 commits into
Open
Conversation
Two small fixes so `attn_implementation="kernels-community/metal-flash-sdpa"` works end-to-end on Apple Silicon (`generate` and `generate_batch`): * `modeling_flash_attention_utils._flash_attention_forward`: the "no padding" branch unconditionally called `flash_fn`, which is `None` for varlen-only kernels (the metal kernel only ships `flash_attn_varlen_func`). Synthesize `cu_seqlens` for the dense batched layout and route through `flash_varlen_fn` in that case. `.contiguous()` before reshape is required: the cached K/V (post-transpose) is non-contiguous and the Metal kernel reads garbage off it during decode, producing nonsense tokens. * `continuous_batching/requests.get_device_and_memory_breakdown`: on MPS, `torch.mps.driver_allocated_memory()` returns bytes currently held by the Metal driver (≈0 right after process start), not the total. Use `recommended_max_memory()` for total and `current_allocated_memory()` for the running allocation. Without this, `infer_num_blocks_and_max_batch_tokens` either returns a negative `num_blocks` or refuses to allocate, so `generate_batch` was unusable on MPS regardless of the chosen attention. Bench (gsm8k 100 samples, Qwen2.5-0.5B-Instruct, MPS fp16, generate_batch): impl time(s) tok/s acc sdpa 149.33 158.4 30/100 kernels-community/metal-flash-sdpa 89.78 256.0 32/100 1.66x speedup, accuracy within noise.
|
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
When `transformers serve` runs on Apple Silicon (`--device auto` or `mps`) with `kernels` installed and no explicit `--attn-implementation` flag, default the attention to `kernels-community/metal-flash-sdpa` instead of plain SDPA. On the 100-sample gsm8k benchmark (Qwen2.5-0.5B-Instruct, MPS fp16, generate_batch) it's a 1.66x throughput improvement (158 -> 256 tok/s) with token-for-token parity for greedy decoding. Users who don't want it can opt out with `--attn-implementation sdpa`. Help text on the `--attn-implementation` flag also now lists the kernels-hub syntax explicitly.
Don't build cu_seqlens on the fly inside the modeling forward — the non-padding `else` branch can stay as a NoneType failure for varlen-only kernels. Callers that need varlen (continuous batching, padding-free training) go through `paged_attention_forward` or the explicit `cu_seq_lens_*` kwarg path, both of which already supply their own cumulative lengths. Companion kernel change: dropping `flash_attn_func` from kernels-community/metal-flash-sdpa for the same reason (PR #3).
The published `main` of `kernels-community/metal-flash-sdpa` predates the MPS dispatch hardening (contiguity, int32 cast, alias clone, MPS encoder flush) that this integration depends on. Pinning to the open PR's HEAD commit so the auto-default actually works end-to-end out of the box. Drop / bump this constant when the upstream PR merges: https://huggingface.co/kernels-community/metal-flash-sdpa/discussions/3
remi-or
approved these changes
May 14, 2026
Collaborator
remi-or
left a comment
There was a problem hiding this comment.
Please apply line length and fix the bug!
| @@ -800,8 +800,36 @@ def _flash_attention_forward( | |||
|
|
|||
| # No padding | |||
| else: | |||
| def _resolve_attn_implementation(cls, attn_implementation: str | None, device: str | int) -> str | None: | ||
| """Auto-select a flash-attention kernel when the user didn't specify one. | ||
|
|
||
| On Apple Silicon (MPS) with ``kernels`` installed, default to |
Collaborator
There was a problem hiding this comment.
I think length line is 120, you should update you claude md!
| # MPS memory reporting (PyTorch 2.0+). `driver_allocated_memory` returns bytes currently held by | ||
| # the Metal driver (≈ 0 right after process start), so use `recommended_max_memory` for total | ||
| # and `current_allocated_memory` for the running allocation instead. | ||
| total_memory = getattr(torch.mps, "recommended_max_memory")() |
Collaborator
There was a problem hiding this comment.
I think you forgot the default here
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Enables
kernels-community/metal-flash-sdpaforgenerate/generate_batchon MPS: synthesize cu_seqlens +.contiguous()in the no-padding branch of_flash_attention_forward, and fix MPS memory accounting in continuous batching.Bench (gsm8k 100 samples, Qwen2.5-0.5B-Instruct, MPS fp16,
generate_batch):sdpakernels-community/metal-flash-sdpa1.66× speedup, accuracy within noise.
Follow-up: push the contiguity/varlen handling into the kernel itself so
modeling_flash_attention_utils.pyno longer needs the fallback branch.