Add MPS SDPA workarounds for value head dim and bidirectional attention#44591
Closed
moktamd wants to merge 2 commits into
Closed
Add MPS SDPA workarounds for value head dim and bidirectional attention#44591moktamd wants to merge 2 commits into
moktamd wants to merge 2 commits into
Conversation
…on bugs Add _apply_mps_fixes to handle two upstream PyTorch MPS bugs: 1. pytorch/pytorch#176767: pad value tensor when v_head_dim != q_head_dim to avoid corrupted SDPA output (affects DeepSeek models, fixed in 2.12) 2. pytorch/pytorch#174861: force a non-bool attention mask for non-causal attention with non-float32 dtypes to route through sdpa_general_mps instead of the broken sdpa_vector_2pass_mps path (fixed in 2.11) Fixes huggingface#44554 Fixes huggingface#44247
Contributor
Author
|
Apologies, I missed that you were already working on this. Closing in your favor. Good luck with the fix! |
This was referenced Mar 11, 2026
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Adds
_apply_mps_fixesinsdpa_attention.pyto handle two upstream PyTorch MPS bugs:MPS: scaled_dot_product_attention returns wrong output shape when value dim != query/key dim pytorch/pytorch#176767 (fixed in PyTorch 2.12): pads value tensor when
v_head_dim != q_head_dimto avoid corrupted output. Affects DeepSeek models with MQA.[MPS] Out of bounds memory access/corruption and correctness issue in SDPA pytorch/pytorch#174861 (fixed in PyTorch 2.11): forces a non-bool attention mask for non-causal, non-float32 attention to route through
sdpa_general_mpsinstead of brokensdpa_vector_2pass_mps.Both fixes are version-gated and will no-op once the upstream PyTorch fixes are available.
Fixes #44554
Fixes #44247