batched_mm is slow on cpu#43438

Merged
vasqu merged 9 commits into main from grouped-mm-cpu
Jan 27, 2026

Conversation

@IlyasMoutawwakil
Member

What does this PR do?

Fixes # (issue)
cpu

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Collaborator

@ArthurZucker ArthurZucker left a comment


Let's update the doc to add justification!

Collaborator

@ArthurZucker ArthurZucker left a comment


TY! Let's make this explicit in comments or in the doc with a justification, maybe?

Contributor

@vasqu vasqu left a comment


Similar problem as we had with test_torch_compile_for_training - can you take a look at test_generate_compile_model_forward_fullgraph?

Forcing batched_mm or changing the dtype (although we do compare outputs, so not sure if it would introduce flakiness) should solve it.

Member

@stevhliu stevhliu left a comment


thanks for clarifying!

| `"grouped_mm"` | Orders tokens by selected experts and uses `torch._grouped_mm` to project all tokens in a single grouped GEMM (requires PyTorch 2.9+). |

`batched_mm` is fastest for very small inputs and compilation speeds it up further. `grouped_mm` performs best for larger inputs.
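The routing idea the table row describes can be sketched without the private PyTorch kernel. Below is a minimal NumPy emulation (function names are illustrative, not the library's API): the "batched" path runs every expert over every token and then selects, while the "grouped" path sorts tokens by expert so each expert does one GEMM over a contiguous slice, which is what `torch._grouped_mm` fuses into a single call.

```python
import numpy as np

def batched_mm_experts(x, expert_ids, weights):
    # Naive batched path: project every token through every expert,
    # then keep only each token's selected expert.
    # x: (T, d_in), weights: (E, d_in, d_out), expert_ids: (T,)
    all_out = np.einsum("td,edo->teo", x, weights)   # (T, E, d_out)
    return all_out[np.arange(len(x)), expert_ids]    # (T, d_out)

def grouped_mm_experts(x, expert_ids, weights):
    # Grouped path: sort tokens by expert so each expert's tokens are
    # contiguous, then do one GEMM per expert on its slice.
    order = np.argsort(expert_ids, kind="stable")
    x_sorted = x[order]
    counts = np.bincount(expert_ids, minlength=len(weights))
    out_sorted = np.empty((len(x), weights.shape[2]))
    start = 0
    for e, n in enumerate(counts):
        if n:
            out_sorted[start:start + n] = x_sorted[start:start + n] @ weights[e]
        start += n
    # Scatter results back to the original token order.
    out = np.empty_like(out_sorted)
    out[order] = out_sorted
    return out
```

Both paths produce identical outputs; they differ only in how much redundant work and memory traffic they incur, which is why the faster choice depends on input size and device.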
Member


i think it'd be cleaner to add two separate columns to the table for GPU and CPU, and then you can add the relevant comments for each implementation. makes it easier to quickly scan as well!

Member Author


aah makes sense! hope it won't get crowded when rendered

IlyasMoutawwakil and others added 4 commits January 26, 2026 09:23
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
@IlyasMoutawwakil
Member Author

thanks @stevhliu, I updated the table and left one note about the decode-stage optimization on gpu.

@vasqu I switched to bf16 on cpu+grouped_mm+compile. imo it's better to test grouped_mm on cpu here because it's what a user will get by default; switching to batched_mm would pass the tests but wouldn't catch errors in the default cpu path. wdyt?

Contributor

@vasqu vasqu left a comment


Yes, ok, let's move to bf16, but we gotta keep an eye out for whether it does indeed produce flakiness / failing tests.
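The flakiness worry comes from bf16's 8-bit mantissa. A rough sketch of the precision loss (emulating bf16 in NumPy by truncating float32 mantissa bits; real bf16 casts round to nearest-even, so this slightly overstates the error, and it is not what the transformers tests actually do):

```python
import numpy as np

def to_bfloat16(x):
    # Emulate bfloat16 by zeroing the low 16 bits of float32, keeping
    # sign (1) + exponent (8) + mantissa (7). This truncates toward zero.
    bits = x.astype(np.float32).view(np.uint32)
    return (bits & np.uint32(0xFFFF0000)).view(np.float32)

rng = np.random.default_rng(0)
a = rng.standard_normal((64, 64)).astype(np.float32)
b = rng.standard_normal((64, 64)).astype(np.float32)

ref = a @ b
low = to_bfloat16(a) @ to_bfloat16(b)

# bf16 keeps only ~3 decimal digits, so a relative error on the order
# of 1e-3 over the whole output is expected; comparisons against an
# fp32 reference need correspondingly loose tolerances.
rel_err = np.linalg.norm(ref - low) / np.linalg.norm(ref)
print(f"relative Frobenius error: {rel_err:.3e}")
```

This is why comparing compiled-vs-eager outputs in bf16 needs looser `atol`/`rtol` than fp32, and why borderline tolerances can turn into intermittent test failures.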

Member

@stevhliu stevhliu left a comment


one last nit, otherwise lgtm! 😄

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
@github-actions
Contributor

View the CircleCI Test Summary for this PR:

https://huggingface.co/spaces/transformers-community/circle-ci-viz?pr=43438&sha=1d01c7

@vasqu vasqu merged commit a99a913 into main Jan 27, 2026
21 of 26 checks passed
@vasqu vasqu deleted the grouped-mm-cpu branch January 27, 2026 13:32