Skip to content

Fix 2pass sdpa on < M2#3099

Merged
awni merged 1 commit intomainfrom
fix_2pass_sdpa
Feb 5, 2026
Merged

Fix 2pass sdpa on < M2#3099
awni merged 1 commit intomainfrom
fix_2pass_sdpa

Conversation

@awni
Copy link
Member

@awni awni commented Feb 5, 2026

Closes ml-explore/mlx-lm#844

Honestly I think this is a lower level issue. For some reason using blocks = 128 when blocks is a function constant makes the kernel support < 1024 threads per thread group. Other values (64, 256, etc) do not have that issue. This is just for bfloat16 on M1 and M2 🤷‍♂️

Copy link
Member

@angeloskath angeloskath left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the fix.

I assume there isn't a perf regression...

@awni
Copy link
Member Author

awni commented Feb 5, 2026

No but at first I tried getting rid of the blocks function constant entirely from the first pass as well and there was a very note-able regression.. which was unexpected.

@awni awni merged commit 99ca62c into main Feb 5, 2026
16 checks passed
@awni awni deleted the fix_2pass_sdpa branch February 5, 2026 16:51
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

mlx-community/Qwen3-Coder-Next-4bit: garbage output after ~1000 tokens

2 participants