
Update MOE linear_loop implementation for speedup and matching name in vLLM #1478

Merged

xin3he merged 13 commits into main from xinhe/2-26 on Mar 4, 2026

Conversation

@xin3he
Contributor

@xin3he xin3he commented Feb 27, 2026

Description

  1. Speedup: create 'meta'-device Linear modules instead of real Linear modules when building the MoE expert list.
  2. Naming: split gate_up_proj into gate_proj and up_proj; rename experts.down_proj.[idx] to experts.[idx].down_proj during saving.

2026-02-27 19:45:30 INFO moe_experts_interface.py L478: [MoE Prep] Before unfuse: 'peak_ram': 1.26GB
2026-02-27 19:45:47 INFO moe_experts_interface.py L491: [MoE Prep] Unfused 40 MOE experts modules
2026-02-27 19:45:47 INFO moe_experts_interface.py L502: [MoE Prep] After unfuse: 'peak_ram': 41.34GB
2026-02-27 19:45:47 INFO replace_modules.py L81: Prepared 40 MOE modules for quantization

Before (per-layer unfuse time):
2026-02-27 17:25:14 INFO moe_experts_interface.py L474: [MoE Prep] Unfused 'model.language_model.layers.0.mlp.experts': 49542.75 ms
2026-02-27 17:26:03 INFO moe_experts_interface.py L474: [MoE Prep] Unfused 'model.language_model.layers.1.mlp.experts': 47986.69 ms
2026-02-27 17:26:58 INFO moe_experts_interface.py L474: [MoE Prep] Unfused 'model.language_model.layers.2.mlp.experts': 55245.50 ms

After:
2026-02-27 19:07:44 INFO moe_experts_interface.py L491: [MoE Prep] Unfused 'model.language_model.layers.0.mlp.experts': 243.57 ms
2026-02-27 19:07:44 INFO moe_experts_interface.py L491: [MoE Prep] Unfused 'model.language_model.layers.1.mlp.experts': 206.36 ms
2026-02-27 19:07:44 INFO moe_experts_interface.py L491: [MoE Prep] Unfused 'model.language_model.layers.2.mlp.experts': 189.15 ms
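The speedup above comes from not materializing real Linear layers per expert. A minimal sketch of the 'meta'-device trick, assuming the fused expert weight is stacked as (num_experts, out_features, in_features) and that each per-expert weight is assigned as a view into the fused tensor (the helper name and shapes are illustrative, not the PR's exact code):

```python
import torch
import torch.nn as nn


def make_expert_linear(fused_weight: torch.Tensor, idx: int) -> nn.Linear:
    """Build a per-expert Linear without allocating or initializing storage.

    Constructing the module under the 'meta' device skips the expensive
    real allocation + weight init; the parameter is then pointed at a
    slice (a view) of the fused tensor, so no copy happens either.
    """
    _, out_features, in_features = fused_weight.shape
    with torch.device("meta"):
        shell = nn.Linear(in_features, out_features, bias=False)
    # Swap the meta parameter for a view into the fused weight.
    shell.weight = nn.Parameter(fused_weight[idx], requires_grad=False)
    return shell


fused = torch.randn(4, 8, 16)  # 4 experts, out=8, in=16
lin0 = make_expert_linear(fused, 0)
assert lin0.weight.data_ptr() == fused[0].data_ptr()  # shares storage, no copy
```

Since the shell never touches real memory and the final weight is a view, the cost per expert drops from a full allocate-and-copy to essentially bookkeeping, which is consistent with the ~200x drop in the timings above.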

Type of Change

  • Bug fix
  • New feature
  • Documentation update
  • Performance improvement
  • Code refactoring
  • Other (please specify):

Related Issues

Fixes or relates to #1464

Checklist Before Submitting

  • My code has been tested locally.
  • Documentation has been updated as needed.
  • New or updated tests are included where applicable.

Copilot AI review requested due to automatic review settings February 27, 2026 12:06
Contributor

Copilot AI left a comment


Pull request overview

This PR updates AutoRound’s Transformers MoE “linear_loop” pathway to better align naming with vLLM conventions and reduce unfusing overhead, and adjusts backend documentation/metadata for AutoGPTQ backends.

Changes:

  • Split fused gate_up_proj handling into separate gate_proj + up_proj during MoE unfusing and update the linear-loop forward accordingly.
  • Optimize MoE unfusing by creating nn.Linear shells on the meta device and assigning per-expert weight slices.
  • Update AutoGPTQ backend registration/requirements (and reflect priority changes in docs) and remap MoE expert parameter keys during shard saving.
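The gate_up_proj split mentioned above can be sketched as a simple slice along the output dimension. This assumes the common Transformers/vLLM layout where the fused weight concatenates the gate half first, i.e. shape (2 * intermediate, hidden); the helper is illustrative, not the PR's actual code:

```python
import torch


def split_gate_up(gate_up_proj: torch.Tensor):
    """Split a fused gate_up projection into gate_proj and up_proj.

    Assumption: the fused weight stacks [gate; up] along dim 0, so the
    first half of the rows is gate_proj and the second half is up_proj.
    Both results are views, so the split itself allocates nothing.
    """
    two_inter = gate_up_proj.shape[0]
    inter = two_inter // 2
    gate_proj = gate_up_proj[:inter]  # first half  -> gate_proj
    up_proj = gate_up_proj[inter:]    # second half -> up_proj
    return gate_proj, up_proj


fused = torch.arange(12.0).reshape(6, 2)  # intermediate=3, hidden=2
gate, up = split_gate_up(fused)
```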

Reviewed changes

Copilot reviewed 4 out of 4 changed files in this pull request and generated 5 comments.

File | Description
docs/step_by_step.md | Updates backend table entries (priority/packing format/requirements).
auto_round/modeling/fused_moe/moe_experts_interface.py | Refactors MoE projection naming and the unfuse implementation; adds memory-monitoring logs.
auto_round/inference/backend.py | Adjusts AutoGPTQ backend registration and requirements gating.
auto_round/compressors/shard_writer.py | Remaps expert parameter keys to {experts}.{idx}.{proj} style and skips meta tensors in finalize.
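The key remap done at save time can be sketched with a small regex rewrite. The pattern and projection names below are assumptions based on the description ("experts.down_proj.[idx]" becomes "experts.[idx].down_proj", the layout vLLM expects), not the shard writer's exact code:

```python
import re

# Match "experts.<proj>.<idx>" and swap the projection and expert index.
_EXPERT_KEY = re.compile(r"experts\.(gate_proj|up_proj|down_proj)\.(\d+)")


def remap_expert_key(key: str) -> str:
    """Rewrite experts.<proj>.<idx> to experts.<idx>.<proj>; other keys pass through."""
    return _EXPERT_KEY.sub(lambda m: f"experts.{m.group(2)}.{m.group(1)}", key)


remapped = remap_expert_key("model.layers.0.mlp.experts.down_proj.7.weight")
# -> "model.layers.0.mlp.experts.7.down_proj.weight"
```

Keys that do not match the pattern (attention projections, norms, router weights) are returned unchanged, so the remap can be applied uniformly to every key in the state dict.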

@xin3he
Contributor Author

xin3he commented Feb 27, 2026

verified with Qwen3.5-35B-A3B + vLLM

@wenhuach21
Contributor

wenhuach21 commented Feb 28, 2026

verified with Qwen3.5-35B-A3B + vLLM

  1. As Liang's PR shows, transformers v5.2.0 transposes the weights, but I don't see any transformers version check here. Better to add one if the code is not compatible with < 5.2.0.
  2. Better to test two more model families.
  3. Test the transformers backend as well.

@xin3he xin3he requested a review from n1ck-guo March 2, 2026 05:38
@xin3he
Contributor Author

xin3he commented Mar 2, 2026

Qwen3-VL
  • transformers 5.2.0: custom fails (1. AttributeError: 'Qwen3VLMoeTextSparseMoeBlock' object has no attribute 'top_k'; 2. IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1)); general works.
  • transformers 5.1.0: custom fails (AttributeError: 'Qwen3VLMoeTextSparseMoeBlock' object has no attribute 'top_k'); general works.
  • transformers 4.57.6: custom works; general not applicable.

Qwen3-Next
  • transformers 5.2.0: unfused_moe works well with low memory; the general recipe works but uses high memory.

xin3he added 8 commits March 2, 2026 13:38
@xin3he xin3he merged commit d150118 into main Mar 4, 2026
29 checks passed
@xin3he xin3he deleted the xinhe/2-26 branch March 4, 2026 06:59