fix: handle 3D input tensors in flashAttention #45

Merged
MankyDanky merged 1 commit into staging from feat/web-backend
Apr 17, 2026

Conversation

@MankyDanky
Collaborator

flashAttention was hardcoded to assume 4D [B, H, T, D] input, but models commonly pass 3D [B*H, T, D] after merging batch and head dims. With 3D input, qShape[3] was undefined, causing NaN to propagate through the entire forward pass.

Now reads T and D from the last two dimensions regardless of rank.
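A minimal sketch of the approach described above, assuming a hypothetical helper name (`seqAndHeadDims`) and shape representation (`number[]`) since the actual source isn't shown here:

```typescript
// Hypothetical sketch: read the sequence length T and head dim D from the
// last two dimensions, so both 4D [B, H, T, D] and 3D [B*H, T, D] inputs
// work. Indexing qShape[2]/qShape[3] directly would yield undefined for
// 3D input and let NaN propagate through the forward pass.
function seqAndHeadDims(qShape: number[]): { T: number; D: number } {
  if (qShape.length < 2) {
    throw new Error(`flashAttention expects rank >= 2, got rank ${qShape.length}`);
  }
  const T = qShape[qShape.length - 2]; // sequence length
  const D = qShape[qShape.length - 1]; // head dimension
  return { T, D };
}
```

With this indexing, `seqAndHeadDims([8, 128, 64])` and `seqAndHeadDims([2, 4, 128, 64])` both resolve to `T = 128`, `D = 64`.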

@MankyDanky MankyDanky merged commit 7e0a561 into staging Apr 17, 2026
11 checks passed
