Skip to content

Add M-tile loop with dispatch capping for Intel Xe2/3-LPG#28250

Merged
guschmue merged 1 commit into
microsoft:mainfrom
jchen10:large_prefill
May 6, 2026
Merged

Add M-tile loop with dispatch capping for Intel Xe2/3-LPG#28250
guschmue merged 1 commit into
microsoft:mainfrom
jchen10:large_prefill

Conversation

@jchen10
Copy link
Copy Markdown
Contributor

@jchen10 jchen10 commented Apr 28, 2026

  • Wrap 8x16x16 MatMulNBits(SubgroupMatrix) kernel body in M-tile loop using uniforms.m_tiles_per_wg for tile assignment per workgroup
  • Cap dispatch_y on Xe2/3-LPG when M > 2k, with occupancy factor 16x
  • Non-Intel or small-M paths pass m_tiles_per_wg=1 (no behavior change)

@jchen10
Copy link
Copy Markdown
Contributor Author

jchen10 commented Apr 28, 2026

We observed a sharp perf drop of prefill for long prompts(>4k) on PTL. This PR can largely alleviate the problem.

@qjia7 @guschmue PTAL

- Wrap 8x16x16 MatMulNBits(SubgroupMatrix) kernel body in M-tile loop
  using uniforms.m_tiles_per_wg for tile assignment per workgroup
- Cap dispatch_y on Xe2/3-LPG when M > 2k, with occupancy factor 16x
- Non-Intel or small-M paths pass m_tiles_per_wg=1 (no behavior change)
@guschmue guschmue added the ep:WebGPU ort-web webgpu provider label May 1, 2026
@guschmue guschmue requested a review from Copilot May 1, 2026 16:51
Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR updates the WebGPU SubgroupMatrix MatMulNBits 8x16x16 path to reduce dispatch overhead on large-M Intel Xe2/Xe3-LPG devices by having each workgroup process multiple M-tiles sequentially, driven by a new m_tiles_per_wg uniform and a capped dispatch_y.

Changes:

  • Wrap the 8x16x16 WGSL kernel body in an outer M-tile loop controlled by uniforms.m_tiles_per_wg.
  • Add m_tiles_per_wg to the program’s uniform interface and pass it from the CPU side.
  • Cap dispatch_y for large M on Intel Xe2/Xe3-LPG and derive m_tiles_per_wg accordingly.

Reviewed changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 1 comment.

File Description
onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits_8x16x16.wgsl.template Adds an outer M-tile loop and resets accumulators per tile using m_tiles_per_wg.
onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.h Extends the uniform variable list with m_tiles_per_wg.
onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.cc Computes capped dispatch_y on Intel Xe2/Xe3-LPG and passes m_tiles_per_wg to the shader.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

@guschmue guschmue merged commit 5f071fb into microsoft:main May 6, 2026
90 of 91 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ep:WebGPU ort-web webgpu provider

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants