
[Enhancement Request] SDFA does not support head dimension of size 192 (capped at 128)  #1604

@s-tlgh

Description


Describe the bug
SDFA (the fused kernel behind `mx.fast.scaled_dot_product_attention`) currently supports only head dimensions of 64, 96, and 128.

To Reproduce
The fused attention kernel falls back to the regular (unfused) implementation when the head dimension is not one of 64, 96, or 128. The check is here:
https://github.com/ml-explore/mlx/blob/main/mlx/fast.cpp#L644-L645
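To make the fallback concrete, here is a minimal pure-Python sketch. It assumes the dispatch is a plain allowlist check as described above. `uses_fused_kernel` and `reference_attention` are hypothetical names and are not copied from `mlx/fast.cpp`. The reference function computes the standard unfused attention, softmax(scale · Q·Kᵀ)·V, for a single head. This is what the "regular operation" path has to evaluate, and it works for any head dimension, including 192.

```python
import math

# Hypothetical allowlist mirroring the check described in this issue;
# NOT copied from mlx/fast.cpp.
FUSED_HEAD_DIMS = (64, 96, 128)


def uses_fused_kernel(head_dim: int) -> bool:
    """Return True if this (hypothetical) dispatch would pick the fused kernel."""
    return head_dim in FUSED_HEAD_DIMS


def reference_attention(q, k, v, scale=None):
    """Unfused single-head attention: softmax(scale * q @ k^T) @ v.

    q is a list of query vectors; k and v are lists of key/value vectors.
    Every vector has length head_dim, which can be any positive size.
    """
    head_dim = len(q[0])
    if scale is None:
        scale = 1.0 / math.sqrt(head_dim)  # the usual 1/sqrt(d) scaling
    out = []
    for qi in q:
        # Scaled dot-product score of this query against every key.
        scores = [scale * sum(a * b for a, b in zip(qi, kj)) for kj in k]
        # Numerically stable softmax over the keys.
        m = max(scores)
        weights = [math.exp(s - m) for s in scores]
        total = sum(weights)
        weights = [w / total for w in weights]
        # Output row = attention-weighted sum of the value vectors.
        out.append([sum(w * vj[d] for w, vj in zip(weights, v))
                    for d in range(len(v[0]))])
    return out


if __name__ == "__main__":
    head_dim = 192
    print(uses_fused_kernel(head_dim))  # False: 192 takes the fallback path
    q = [[0.01 * (i + d) for d in range(head_dim)] for i in range(2)]
    k = [[0.02 * (i - d) for d in range(head_dim)] for i in range(3)]
    v = [[float(i)] * head_dim for i in range(3)]
    out = reference_attention(q, k, v)
    print(len(out), len(out[0]))  # 2 192
```

The unfused path materializes the full score matrix per query, which is the memory and bandwidth cost a fused kernel is meant to avoid. That cost is why head dimensions outside the allowlist are slower rather than unsupported.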

Request
Extend SDFA to support less common head dimensions (for example 192) that are still multiples of 32.

Desktop (please complete the following information):

  • OS Version: macOS 15.2
  • MLX Version: 0.20.0

Additional context
Awni's suggestion: «generalize it so that any even head dim or maybe multiple of 32 is supported»
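The two generalizations on the table can be compared as dispatch predicates. The sketch below is hypothetical: none of these functions exist in MLX, and a real Metal kernel may impose extra tiling or threadgroup-memory limits that are ignored here. It contrasts the current allowlist, the "multiple of 32" rule from the request, and the broader "any even head dim" rule from the quoted suggestion.

```python
# Hypothetical dispatch rules for the fused kernel; none of these are MLX code.
CURRENT_DIMS = (64, 96, 128)


def current_rule(head_dim: int) -> bool:
    # Today: only an explicit allowlist takes the fused path.
    return head_dim in CURRENT_DIMS


def multiple_of_32_rule(head_dim: int) -> bool:
    # The request: any positive multiple of 32.
    return head_dim > 0 and head_dim % 32 == 0


def even_rule(head_dim: int) -> bool:
    # The broader option from the suggestion: any positive even head dim.
    return head_dim > 0 and head_dim % 2 == 0


if __name__ == "__main__":
    for d in (64, 80, 96, 128, 192, 256):
        print(d, current_rule(d), multiple_of_32_rule(d), even_rule(d))
```

Under these rules, 192 fails the current check but passes both proposed ones. By contrast, 80 passes only the even-dimension rule. Every dimension in today's allowlist remains supported under either generalization.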

Metadata

  • Assignees: No one assigned
  • Labels: enhancement (New feature or request)
  • Type: No type
  • Projects: No projects
  • Milestone: No milestone
  • Relationships: None yet
  • Development: No branches or pull requests
