[Performance] FP8 Grouped and Batched Matmuls#44231
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
… library wrappers for better torch compileability
    - return is_grouped_mm_available()
    + return hasattr(torch.nn.functional, "grouped_mm") or hasattr(torch, "_grouped_mm")
not sure why but sometimes is_grouped_mm_available() and other functions using metadata/versions result in compilation failures
Pull request overview
This PR introduces new FP8 MoE expert implementations (batched and grouped matmuls) intended to significantly speed up fine-grained FP8 expert execution, with torch.compile / CUDA graphs compatibility, and expands tests to cover the different expert implementations.
Changes:
- Add FP8 batched/grouped experts forward paths and CUTLASS/Triton dispatch plumbing in the fine-grained FP8 integration.
- Update MoE grouped_mm availability checks and make grouped_mm token reordering more graph-friendly.
- Parameterize the fine-grained FP8 MoE forward test across the `eager`, `batched_mm`, and `grouped_mm` implementations.
Reviewed changes
Copilot reviewed 5 out of 5 changed files in this pull request and generated 5 comments.
| File | Description |
|---|---|
| tests/quantization/finegrained_fp8/test_fp8.py | Runs the FP8 MoE forward smoke test across multiple experts implementations. |
| src/transformers/utils/generic.py | Fixes a docstring typo (“though” → “through”). |
| src/transformers/quantizers/quantizer_finegrained_fp8.py | Updates the quantizer to target the new FP8 experts module type and marks it as compileable. |
| src/transformers/integrations/moe.py | Adjusts grouped_mm availability detection and permutation inversion logic for compile/cudagraph friendliness. |
| src/transformers/integrations/finegrained_fp8.py | Major refactor: kernel loading/dispatch changes, new FP8Experts, and new FP8 batched/grouped expert forward implementations. |
    ALL_EXPERTS_FUNCTIONS["batched_mm"] = fp8_batched_mm_experts_forward
    ALL_EXPERTS_FUNCTIONS["grouped_mm"] = fp8_grouped_mm_experts_forward
Assigning ALL_EXPERTS_FUNCTIONS["batched_mm"] / ["grouped_mm"] here overrides the global experts dispatch used by all @use_experts_implementation models. That will route non-FP8 MoE layers into these FP8-specific functions (which expect FP8 scales / kernel APIs) and will likely crash. Prefer keeping the global mappings intact and dispatching to the FP8 implementations only for FP8 expert modules (e.g., by handling the selection inside FP8Experts.forward, or by using distinct keys and setting config._experts_implementation accordingly for FP8 models only).
Suggested change:

    - ALL_EXPERTS_FUNCTIONS["batched_mm"] = fp8_batched_mm_experts_forward
    - ALL_EXPERTS_FUNCTIONS["grouped_mm"] = fp8_grouped_mm_experts_forward
    + ALL_EXPERTS_FUNCTIONS["fp8_batched_mm"] = fp8_batched_mm_experts_forward
    + ALL_EXPERTS_FUNCTIONS["fp8_grouped_mm"] = fp8_grouped_mm_experts_forward
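A minimal sketch of the suggested approach: leave the default registry entries intact and resolve an FP8-prefixed key only for FP8 expert modules. `resolve_experts_fn` and the stub forwards are hypothetical names for illustration, not the PR's API:

```python
ALL_EXPERTS_FUNCTIONS = {}  # global registry shared by all MoE models

def generic_batched_mm_experts_forward(*args, **kwargs):
    ...  # stub standing in for the default (non-FP8) path

def fp8_batched_mm_experts_forward(*args, **kwargs):
    ...  # stub standing in for the FP8 path (expects FP8 scales/kernels)

ALL_EXPERTS_FUNCTIONS["batched_mm"] = generic_batched_mm_experts_forward
# Distinct key: the default mapping above stays untouched for non-FP8 models.
ALL_EXPERTS_FUNCTIONS["fp8_batched_mm"] = fp8_batched_mm_experts_forward

def resolve_experts_fn(impl, is_fp8_module):
    # Only FP8 expert modules get routed to the prefixed entry.
    key = f"fp8_{impl}" if is_fp8_module and f"fp8_{impl}" in ALL_EXPERTS_FUNCTIONS else impl
    return ALL_EXPERTS_FUNCTIONS[key]
```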
my understanding is that the setter results in file-specific changes and registering results in global changes. @ArthurZucker if you can confirm
not sure about this, but maybe run some tests
Yeah I don't remember the scope, registering would work better + having quantization config class change the impl to prefix with FP8 would be better no?
I created a new interface specifically for FP8. As for "quantization config class change the impl to prefix with FP8" — not sure how to achieve that, because we read directly from model.config._experts_implementation
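One way the "prefix the impl" idea could look — a hypothetical sketch where the quantizer rewrites the config attribute once at load time, so downstream reads of `model.config._experts_implementation` pick the FP8 kernels. `DummyConfig` and `apply_fp8_prefix` are made up for illustration:

```python
class DummyConfig:
    # Stand-in for a model config carrying the experts implementation choice.
    _experts_implementation = "grouped_mm"

def apply_fp8_prefix(config):
    # The quantizer could rewrite the attribute once; code that later reads
    # config._experts_implementation then resolves to the FP8 registry key.
    impl = getattr(config, "_experts_implementation", None)
    if impl and not impl.startswith("fp8_"):
        config._experts_implementation = f"fp8_{impl}"
    return config

cfg = apply_fp8_prefix(DummyConfig())
```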
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
SunMarc left a comment
Thanks a lot, left mostly minor comments
Cyrilvallez left a comment
Nice! Did not check the exact mathematics, but trusting you and @SunMarc on this!
[For maintainers] Suggested jobs to run (before merge): run-slow: finegrained_fp8
SunMarc left a comment
Thanks for fixing the last bits! Merging
What does this PR do?
Up to 30x faster than the current FP8 experts; the kernels are also tailored for full torch.compile and CUDA graphs compatibility.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.