
[Performance] FP8 Grouped and Batched Matmuls #44231

Merged

ArthurZucker merged 27 commits into main from fp8-grouped-mm on Mar 11, 2026

Conversation

@IlyasMoutawwakil (Member) commented Feb 23, 2026

What does this PR do?

Up to 30x faster than the current FP8 experts; the kernels are also tailored for full torch.compile and CUDA graph compatibility.

============================================================
FP8Expert parity: eager / batched_mm / grouped_mm
============================================================

[case 1/5]
device=cuda  batch_size=1  num_tokens=64  total_tokens=64  num_experts=8  hidden=256  intermediate=512  top_k=2
  [eager vs eager_fused]  max=0.000000  mean=0.000000  PASS ✓
  [eager vs grouped_mm]  max=0.000000  mean=0.000000  PASS ✓
  [eager vs batched_mm]  max=0.000000  mean=0.000000  PASS ✓
  [batched_mm vs grouped_mm]  max=0.000000  mean=0.000000  PASS ✓

[case 2/5]
device=cuda  batch_size=1  num_tokens=1  total_tokens=1  num_experts=8  hidden=256  intermediate=512  top_k=2
  [eager vs eager_fused]  max=0.000000  mean=0.000000  PASS ✓
  [eager vs grouped_mm]  max=0.000000  mean=0.000000  PASS ✓
  [eager vs batched_mm]  max=0.000000  mean=0.000000  PASS ✓
  [batched_mm vs grouped_mm]  max=0.000000  mean=0.000000  PASS ✓

[case 3/5]
device=cuda  batch_size=1  num_tokens=7  total_tokens=7  num_experts=8  hidden=256  intermediate=512  top_k=1
  [eager vs eager_fused]  max=0.000000  mean=0.000000  PASS ✓
  [eager vs grouped_mm]  max=0.000000  mean=0.000000  PASS ✓
  [eager vs batched_mm]  max=0.000000  mean=0.000000  PASS ✓
  [batched_mm vs grouped_mm]  max=0.000000  mean=0.000000  PASS ✓

[case 4/5]
device=cuda  batch_size=1  num_tokens=4  total_tokens=4  num_experts=8  hidden=256  intermediate=512  top_k=8
  [eager vs eager_fused]  max=0.000000  mean=0.000000  PASS ✓
  [eager vs grouped_mm]  max=0.000000  mean=0.000000  PASS ✓
  [eager vs batched_mm]  max=0.000000  mean=0.000000  PASS ✓
  [batched_mm vs grouped_mm]  max=0.000000  mean=0.000000  PASS ✓

[case 5/5]
device=cuda  batch_size=4  num_tokens=64  total_tokens=256  num_experts=8  hidden=256  intermediate=512  top_k=2
  [eager vs eager_fused]  max=0.000000  mean=0.000000  PASS ✓
  [eager vs grouped_mm]  max=0.000000  mean=0.000000  PASS ✓
  [eager vs batched_mm]  max=0.000000  mean=0.000000  PASS ✓
  [batched_mm vs grouped_mm]  max=0.000000  mean=0.000000  PASS ✓

============================================================
All parity checks PASSED ✓

============================================================
Benchmark sweep
============================================================

────────────────────────────────────────────────────────────
Benchmark  device=cuda  batch_size=1  tokens=1  total=1  experts=8
           hidden=256  intermediate=512  top_k=2
────────────────────────────────────────────────────────────
  impl                            median (ms)    p10 (ms)    p90 (ms)   speedup
  eager                                 1.477       1.458       1.503  (baseline)
  eager_fused                           1.206       1.193       1.238     1.22x
  grouped_mm                            0.830       0.808       0.858     1.78x
  batched_mm                            0.423       0.418       0.441     3.49x
  eager (compiled)                      1.413       1.396       1.465     1.05x
  grouped_mm (compiled)                 0.138       0.136       0.147    10.67x
  batched_mm (compiled)                 0.138       0.136       0.151    10.71x

────────────────────────────────────────────────────────────
Benchmark  device=cuda  batch_size=1  tokens=8  total=8  experts=8
           hidden=256  intermediate=512  top_k=2
────────────────────────────────────────────────────────────
  impl                            median (ms)    p10 (ms)    p90 (ms)   speedup
  eager                                 4.506       4.430       4.581  (baseline)
  eager_fused                           3.672       3.654       3.700     1.23x
  grouped_mm                            0.809       0.797       0.840     5.57x
  batched_mm                            0.423       0.419       0.444    10.64x
  eager (compiled)                      4.752       4.733       4.775     0.95x
  grouped_mm (compiled)                 0.158       0.155       0.168    28.52x
  batched_mm (compiled)                 0.153       0.151       0.163    29.41x

────────────────────────────────────────────────────────────
Benchmark  device=cuda  batch_size=1  tokens=32  total=32  experts=8
           hidden=256  intermediate=512  top_k=2
────────────────────────────────────────────────────────────
  impl                            median (ms)    p10 (ms)    p90 (ms)   speedup
  eager                                 5.106       5.064       5.142  (baseline)
  eager_fused                           4.065       4.052       4.086     1.26x
  grouped_mm                            0.815       0.807       0.844     6.27x
  batched_mm                            0.433       0.428       0.455    11.80x
  eager (compiled)                      5.346       5.326       5.379     0.95x
  grouped_mm (compiled)                 0.159       0.157       0.167    32.01x
  batched_mm (compiled)                 0.167       0.165       0.177    30.62x

────────────────────────────────────────────────────────────
Benchmark  device=cuda  batch_size=1  tokens=128  total=128  experts=8
           hidden=256  intermediate=512  top_k=2
────────────────────────────────────────────────────────────
  impl                            median (ms)    p10 (ms)    p90 (ms)   speedup
  eager                                 5.236       5.098       5.275  (baseline)
  eager_fused                           4.094       4.073       4.197     1.28x
  grouped_mm                            0.826       0.815       0.854     6.34x
  batched_mm                            0.441       0.436       0.460    11.87x
  eager (compiled)                      5.572       5.545       5.611     0.94x
  grouped_mm (compiled)                 0.183       0.180       0.193    28.65x
  batched_mm (compiled)                 0.284       0.282       0.291    18.42x

────────────────────────────────────────────────────────────
Benchmark  device=cuda  batch_size=1  tokens=512  total=512  experts=8
           hidden=256  intermediate=512  top_k=2
────────────────────────────────────────────────────────────
  impl                            median (ms)    p10 (ms)    p90 (ms)   speedup
  eager                                 5.341       5.305       5.375  (baseline)
  eager_fused                           4.215       4.199       4.233     1.27x
  grouped_mm                            0.822       0.813       0.848     6.50x
  batched_mm                            1.405       1.391       1.416     3.80x
  eager (compiled)                      5.622       5.563       5.674     0.95x
  grouped_mm (compiled)                 0.219       0.217       0.224    24.40x
  batched_mm (compiled)                 0.724       0.712       0.734     7.38x

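For intuition on the two implementations benchmarked above, here is a minimal pure-Python sketch of the dispatch difference (illustrative only — the PR's actual kernels are FP8 CUTLASS/Triton matmuls, and every name below is hypothetical). grouped_mm gathers each expert's tokens and runs one variable-size matmul per expert over a contiguous buffer; batched_mm pads every expert's token buffer to a common capacity so all shapes stay fixed, a property that plausibly helps with CUDA graph capture.

```python
# Illustrative sketch (pure Python, no FP8): contrasts the two MoE expert
# dispatch strategies. All names here are hypothetical, not transformers code.

def matmul(a, b):
    """Naive (m, k) x (k, n) matmul on nested lists."""
    m, k, n = len(a), len(b), len(b[0])
    return [[sum(a[i][p] * b[p][j] for p in range(k)) for j in range(n)]
            for i in range(m)]

# 2 experts, hidden=2 -> out=2; one weight matrix per expert.
weights = [
    [[1, 0], [0, 1]],   # expert 0: identity
    [[2, 0], [0, 2]],   # expert 1: doubles its input
]
tokens = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
assignment = [0, 1, 1]  # expert chosen for each token

def grouped_mm(tokens, assignment, weights):
    """Gather each expert's tokens (equivalent to sorting by expert), run one
    variable-size matmul per expert, then scatter back to token order."""
    out = [None] * len(tokens)
    for e in range(len(weights)):
        group = [i for i in range(len(tokens)) if assignment[i] == e]
        if group:
            res = matmul([tokens[i] for i in group], weights[e])
            for i, row in zip(group, res):
                out[i] = row
    return out

def batched_mm(tokens, assignment, weights):
    """Pad every expert's token buffer to the same capacity so one
    fixed-shape batched matmul covers all experts."""
    cap = max(assignment.count(e) for e in range(len(weights)))
    pad_row = [0.0] * len(tokens[0])
    out = [None] * len(tokens)
    for e in range(len(weights)):           # conceptually a single bmm call
        group = [i for i in range(len(tokens)) if assignment[i] == e]
        padded = [tokens[i] for i in group] + [pad_row] * (cap - len(group))
        res = matmul(padded, weights[e])
        for i, row in zip(group, res):      # drop the padding rows
            out[i] = row
    return out
```

Both paths must agree with the eager per-token loop, which is exactly what the parity harness above verifies.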

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Comment on lines -256 to +257:

- return is_grouped_mm_available()
+ return hasattr(torch.nn.functional, "grouped_mm") or hasattr(torch, "_grouped_mm")

@IlyasMoutawwakil (Member, Author):

Not sure why, but sometimes is_grouped_mm_available() and other functions that rely on package metadata/versions result in compilation failures.
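A rough sketch of why the inlined hasattr probe can be friendlier to torch.compile (using a hypothetical stand-in module, not transformers code): hasattr is a plain attribute lookup that the tracer can evaluate directly, whereas helpers that read package metadata or parse version strings at call time are more likely to break tracing.

```python
import types

# Hypothetical stand-in for torch: a bare module object we can probe.
fake_torch = types.ModuleType("fake_torch")
fake_torch.nn = types.SimpleNamespace(functional=types.SimpleNamespace())
fake_torch._grouped_mm = lambda *args: None  # pretend the private op exists

def grouped_mm_is_available(mod):
    # Mirrors the new check: a pure attribute probe, no metadata/version parsing.
    return hasattr(mod.nn.functional, "grouped_mm") or hasattr(mod, "_grouped_mm")
```

The check degrades gracefully: if neither attribute exists, it simply returns False.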

Copilot AI (Contributor) left a comment:

Pull request overview

This PR introduces new FP8 MoE expert implementations (batched and grouped matmuls) intended to significantly speed up fine-grained FP8 expert execution, with torch.compile / CUDA graphs compatibility, and expands tests to cover the different expert implementations.

Changes:

  • Add FP8 batched/grouped experts forward paths and CUTLASS/Triton dispatch plumbing in the fine-grained FP8 integration.
  • Update MoE grouped_mm availability checks and make grouped_mm token reordering more graph-friendly.
  • Parameterize the fine-grained FP8 MoE forward test across eager, batched_mm, and grouped_mm implementations.

Reviewed changes

Copilot reviewed 5 out of 5 changed files in this pull request and generated 5 comments.

Summary per file:

  • tests/quantization/finegrained_fp8/test_fp8.py — Runs the FP8 MoE forward smoke test across multiple expert implementations.
  • src/transformers/utils/generic.py — Fixes a docstring typo (“though” → “through”).
  • src/transformers/quantizers/quantizer_finegrained_fp8.py — Updates the quantizer to target the new FP8 experts module type and marks it as compileable.
  • src/transformers/integrations/moe.py — Adjusts grouped_mm availability detection and permutation inversion logic for compile/cudagraph friendliness.
  • src/transformers/integrations/finegrained_fp8.py — Major refactor: kernel loading/dispatch changes, new FP8Experts, and new FP8 batched/grouped expert forward implementations.

Comment on lines +399 to +400
ALL_EXPERTS_FUNCTIONS["batched_mm"] = fp8_batched_mm_experts_forward
ALL_EXPERTS_FUNCTIONS["grouped_mm"] = fp8_grouped_mm_experts_forward
Copilot AI commented Mar 3, 2026:
Assigning ALL_EXPERTS_FUNCTIONS["batched_mm"] / ["grouped_mm"] here overrides the global experts dispatch used by all @use_experts_implementation models. That will route non-FP8 MoE layers into these FP8-specific functions (which expect FP8 scales / kernel APIs) and will likely crash. Prefer keeping the global mappings intact and dispatching to the FP8 implementations only for FP8 expert modules (e.g., by handling the selection inside FP8Experts.forward, or by using distinct keys and setting config._experts_implementation accordingly for FP8 models only).

Suggested change:
- ALL_EXPERTS_FUNCTIONS["batched_mm"] = fp8_batched_mm_experts_forward
- ALL_EXPERTS_FUNCTIONS["grouped_mm"] = fp8_grouped_mm_experts_forward
+ ALL_EXPERTS_FUNCTIONS["fp8_batched_mm"] = fp8_batched_mm_experts_forward
+ ALL_EXPERTS_FUNCTIONS["fp8_grouped_mm"] = fp8_grouped_mm_experts_forward
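A minimal sketch of the namespacing this thread converges on (registry, function, and config names here are hypothetical, not the actual transformers dispatch): keep the stock keys intact and have the FP8 path look up a prefixed key.

```python
# Hypothetical registry mirroring the reviewer's suggestion: FP8 forwards go
# under prefixed keys so generic MoE models keep the stock implementations.
ALL_EXPERTS_FUNCTIONS = {
    "batched_mm": lambda hidden: ("generic_batched", hidden),
    "grouped_mm": lambda hidden: ("generic_grouped", hidden),
}

def fp8_batched_mm_experts_forward(hidden):
    return ("fp8_batched", hidden)

def fp8_grouped_mm_experts_forward(hidden):
    return ("fp8_grouped", hidden)

# Namespaced registration instead of overwriting the generic keys.
ALL_EXPERTS_FUNCTIONS["fp8_batched_mm"] = fp8_batched_mm_experts_forward
ALL_EXPERTS_FUNCTIONS["fp8_grouped_mm"] = fp8_grouped_mm_experts_forward

def select_experts_forward(impl, is_fp8):
    # An FP8 quantizer could prefix the configured implementation name.
    key = f"fp8_{impl}" if is_fp8 else impl
    return ALL_EXPERTS_FUNCTIONS[key]
```

This keeps non-FP8 MoE layers on the generic implementations while FP8 expert modules opt in explicitly.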

@IlyasMoutawwakil (Member, Author) commented Mar 3, 2026:
My understanding is that using the setter results in file-specific changes, while registering results in global changes. @ArthurZucker, can you confirm?

Member:
Not sure about this, but maybe run some tests.

Collaborator:
Yeah, I don't remember the scope; registering would work better, and having the quantization config class change the impl to an FP8-prefixed one would be better, no?

@IlyasMoutawwakil (Member, Author):
I created a new interface specifically for FP8. As for having the "quantization config class change the impl to prefix with FP8", I'm not sure how to achieve that, because we read directly from model.config._experts_implementation.

IlyasMoutawwakil and others added 2 commits March 3, 2026 11:10
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
@IlyasMoutawwakil IlyasMoutawwakil marked this pull request as ready for review March 3, 2026 14:34
@SunMarc (Member) left a comment:
Thanks a lot, left mostly minor comments


@Cyrilvallez (Member) left a comment:

Nice! Did not check the exact mathematics, but trusting you and @SunMarc on this!

@ArthurZucker (Collaborator) left a comment:

LGTM


@github-actions (Contributor):

[For maintainers] Suggested jobs to run (before merge)

run-slow: finegrained_fp8

@SunMarc (Member) left a comment:

Thanks for fixing the last bits! Merging.

@SunMarc SunMarc enabled auto-merge March 10, 2026 17:04
@SunMarc SunMarc added this pull request to the merge queue Mar 10, 2026
@github-merge-queue github-merge-queue bot removed this pull request from the merge queue due to failed status checks Mar 10, 2026
@ArthurZucker ArthurZucker merged commit ff2ba44 into main Mar 11, 2026
29 checks passed
@ArthurZucker ArthurZucker deleted the fp8-grouped-mm branch March 11, 2026 08:51