NPU patcher refactor #9223
Conversation
Code Review
This pull request refactors NPU-specific monkey patches into a structured npu_patch package and introduces a new command-line argument, --enable_npu_model_patch, to control model-level patches. The update includes specific compatibility and performance patches for Qwen series models (Qwen2, Qwen3, Qwen3.5, and MoE variants) on Ascend NPU. Feedback focuses on improving library compatibility by moving top-level imports of newer Transformers models into the patching logic, optimizing token counting with torch.bincount, and reducing host-device synchronization in MoE forward passes.
```python
from transformers.models.qwen3 import modeling_qwen3
from transformers.models.qwen3_moe import modeling_qwen3_moe
from transformers.models.qwen3_vl_moe import modeling_qwen3_vl_moe
```
These top-level imports from `transformers.models.qwen3*` will cause an `ImportError` if the user is running an older version of the transformers library (e.g., < 4.48). Since these models are relatively new, it is safer to use the `import_optional_module` helper inside the `apply_patch` function, similar to how `qwen3_5` is handled at line 516.
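The suggested pattern could be sketched as follows. The helper's exact signature in the PR's `utils.py` is an assumption here; only the guard-on-`None` shape is the point:

```python
import importlib


def import_optional_module(name):
    """Return the imported module, or None if it is unavailable.

    Hypothetical body for the `import_optional_module` helper the review
    mentions; the real implementation in the PR may differ.
    """
    try:
        return importlib.import_module(name)
    except ImportError:
        return None


def apply_patch():
    # Guard the newer Qwen3 modules so older transformers versions still import.
    modeling_qwen3 = import_optional_module('transformers.models.qwen3.modeling_qwen3')
    if modeling_qwen3 is not None:
        ...  # apply Qwen3-specific patches here
```

With this shape, importing the patch package never fails at import time; patches for models the installed transformers version does not ship are simply skipped.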
```python
tokens_per_expert = torch.histc(
    router_indices.to(torch.float), bins=self.num_experts, min=0, max=self.num_experts).to(torch.int64)
```
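The review summary recommends `torch.bincount` for this count. A minimal sketch (with made-up `num_experts` and routing indices) showing the two approaches produce the same per-expert totals without the float round-trip:

```python
import torch

num_experts = 4  # illustrative value, not from the PR
router_indices = torch.tensor([0, 2, 1, 2, 3, 0, 2])

# Original approach: cast to float, histogram, cast back to int64.
tokens_histc = torch.histc(
    router_indices.to(torch.float), bins=num_experts, min=0, max=num_experts).to(torch.int64)

# Suggested approach: bincount works directly on integer indices.
tokens_bincount = torch.bincount(router_indices, minlength=num_experts)

assert torch.equal(tokens_histc, tokens_bincount)
```

`minlength=num_experts` keeps the output length fixed even when the last experts receive no tokens.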
```python
cpu_group_list = group_list.to('cpu', non_blocking=False)
cpu_group_list = [0] + cpu_group_list.tolist()
split_size = [cpu_group_list[i + 1] - cpu_group_list[i] for i in range(len(cpu_group_list) - 1)]
```
This logic for calculating `split_size` is redundant and involves an unnecessary host-device synchronization. `tokens_per_experts` (calculated at line 324) already contains the counts needed for `split_size`; you can simply convert it to a list. Note that any `.tolist()` or `.item()` call on a tensor causes a synchronization point, which can impact performance in the forward pass.
```diff
- cpu_group_list = group_list.to('cpu', non_blocking=False)
- cpu_group_list = [0] + cpu_group_list.tolist()
- split_size = [cpu_group_list[i + 1] - cpu_group_list[i] for i in range(len(cpu_group_list) - 1)]
+ split_size = tokens_per_experts.tolist()
```
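A small sketch (with made-up counts) illustrating why the suggestion is equivalent: `group_list` is the cumulative sum of the per-expert counts, so copying it to the host and differencing it just recovers `tokens_per_experts`:

```python
import torch

# Illustrative per-expert token counts (not from the PR).
tokens_per_experts = torch.tensor([2, 1, 3, 1])
group_list = torch.cumsum(tokens_per_experts, dim=0)  # tensor([2, 3, 6, 7])

# Original path: host copy + cumulative differencing.
cpu_group_list = [0] + group_list.tolist()
split_size_old = [cpu_group_list[i + 1] - cpu_group_list[i]
                  for i in range(len(cpu_group_list) - 1)]

# Suggested path: one .tolist() sync instead of the round-trip above.
split_size_new = tokens_per_experts.tolist()

assert split_size_old == split_size_new
```

Both paths still synchronize once, but the suggested one avoids the extra host-side copy and the redundant differencing loop.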
Pull request overview
This PR refactors Ascend NPU monkey-patch logic into a dedicated swift/model/npu_patch/ package while preserving the existing import entrypoint (swift/model/npu_patcher.py) and introducing a startup flag intended to disable only model-level NPU patches for debugging.
Changes:
- Moved patch implementations into modular files (`env.py`, `fsdp.py`, `mindspeed.py`, `model.py`, `utils.py`) under `swift/model/npu_patch/`.
- Kept backwards compatibility by retaining `swift/model/npu_patcher.py` as the import entry and applying patches on import.
- Added a new CLI argument `--enable_npu_model_patch` (documented in EN/ZH) intended to skip only model-level patches.
Reviewed changes
Copilot reviewed 12 out of 12 changed files in this pull request and generated 1 comment.
| File | Description |
|---|---|
| swift/model/npu_patcher.py | Compatibility entrypoint that imports the new package and applies patches on import |
| swift/model/npu_patch/__init__.py | Central patch application + new argv-based switch for model patches |
| swift/model/npu_patch/env.py | Sets default HCCL_CONNECT_TIMEOUT on NPU |
| swift/model/npu_patch/fsdp.py | Refactors Accelerate FSDP2 fp32-cast patch into its own module |
| swift/model/npu_patch/mindspeed.py | Extracts MindSpeed TE CP compatibility patch |
| swift/model/npu_patch/model.py | Consolidates model-family-specific NPU patches (Qwen2/3/3.5 + MoE variants) |
| swift/model/npu_patch/utils.py | Shared helpers for optional imports + patch-map application |
| swift/arguments/base_args/base_args.py | Adds enable_npu_model_patch argument to BaseArguments |
| docs/source_en/Instruction/Command-line-parameters.md | Documents enable_npu_model_patch (EN) |
| docs/source_en/BestPractices/NPU-support.md | Adds best-practice section describing the model patch switch (EN) |
| docs/source/Instruction/Command-line-parameters.md | Documents enable_npu_model_patch (ZH) |
| docs/source/BestPractices/NPU-support.md | Adds best-practice section describing the model patch switch (ZH) |
```python
for i, arg in enumerate(sys.argv):
    if arg in _ENABLE_NPU_MODEL_PATCH_ARGS:
        if i + 1 >= len(sys.argv) or sys.argv[i + 1].startswith('--'):
            raise ValueError('--enable_npu_model_patch requires a value: true or false.')
        return _parse_model_patch_enabled(sys.argv[i + 1])
    if any(arg.startswith(f'{name}=') for name in _ENABLE_NPU_MODEL_PATCH_ARGS):
        value = arg.split('=', 1)[1]
        return _parse_model_patch_enabled(value)
return True
```
thanks!
PR type
PR information
Summary
This PR refactors the NPU patch logic into `swift/model/npu_patch/`, keeps `swift/model/npu_patcher.py` as the compatible import entry, and adds a switch to disable only model-level NPU patches when debugging.

Changes
- Consolidated model-family-specific patches into `model.py` for easier maintenance.
- Importing `swift.model` on NPU still applies patches by default.
- Pass `--enable_npu_model_patch false` to skip only model-related patches.