
NPU patch FLA #9195

Merged

Jintao-Huang merged 11 commits into modelscope:main from hazelduan:npu-patcher on Apr 25, 2026

Conversation

@hazelduan
Contributor

@hazelduan hazelduan commented Apr 23, 2026

PR type

New Feature -- Add NPU patch FLA to improve performance.

PR information

Add an NPU patcher for the Qwen 3.5 dense model. After patching, throughput on the NPU improved by 1.59x, with the single-step time dropping from 7.945 s/it to 5.006 s/it. Accuracy was compared against GPUs on Transformers 5.2.0; the unpatched steady-state single-step time on the GPU is 6.959 s/it.
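The patching approach can be illustrated with a minimal sketch (all names below are hypothetical stand-ins, not the PR's actual code): a patch table maps attribute names to replacement implementations, and a dotted key such as 'Qwen3_5MLP.forward' patches a method on a class rather than a module-level attribute.

```python
import types

class DemoMLP:
    def forward(self, x):
        return x  # original, unpatched implementation

def fast_forward(self, x):
    return x * 2  # stand-in for an NPU-fused implementation

# Stand-in for a modeling module such as modeling_qwen3_5.
demo_module = types.SimpleNamespace(DemoMLP=DemoMLP)

def apply_patch_table(module, table):
    # Dotted keys like 'DemoMLP.forward' patch a method on a class;
    # plain keys replace a module-level attribute.
    for name, replacement in table.items():
        if '.' in name:
            cls_name, attr = name.split('.', 1)
            setattr(getattr(module, cls_name), attr, replacement)
        else:
            setattr(module, name, replacement)

apply_patch_table(demo_module, {'DemoMLP.forward': fast_forward})
```

After patching, instances of the class pick up the replacement method through normal attribute lookup, so callers need no changes.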

Experiment results

npu patch and no patch comparison
[Screenshot: 2026-04-23, 4:20 PM]

Contributor

@gemini-code-assist (bot) left a comment


Code Review

This pull request implements the chunk_gated_delta_rule using Triton kernels and integrates it into the NPU patching logic for Qwen 3.5 models. It adds optimized NPU implementations for RMSNorm and Rotary Position Embeddings, along with a file-locking mechanism to safely handle concurrent Triton compilation on Ascend hardware. Feedback focuses on enhancing the robustness of the gradient calculation by casting to float32, cleaning up unused imports, fixing typographical and style issues in error messages, and ensuring the compilation lock file is user-specific to prevent permission conflicts.

Comment thread swift/model/chunk_gated_delta_rule.py Outdated
from typing import Optional

import torch
import copy
Contributor

medium

The copy module is imported but not used in this file. Removing unused imports helps maintain code cleanliness and reduces overhead.

Comment thread swift/model/chunk_gated_delta_rule.py Outdated
Comment thread swift/model/npu_patcher.py Outdated
Comment on lines +56 to +72
@wraps(make_launcher_stub)
def _locked_make_npu_launcher_stub(*args, **kwargs):
    lock_path = os.environ.get('SWIFT_TRITON_ASCEND_LAUNCHER_LOCK',
                               _DEFAULT_TRITON_ASCEND_LAUNCHER_LOCK_PATH)
    with open(lock_path, 'w') as lock_file:
        fcntl.flock(lock_file, fcntl.LOCK_EX)
        try:
            return make_launcher_stub(*args, **kwargs)
        finally:
            fcntl.flock(lock_file, fcntl.LOCK_UN)

_locked_make_npu_launcher_stub._swift_compile_lock = True
ascend_driver.make_npu_launcher_stub = _locked_make_npu_launcher_stub
logger.info(
    'Patched Ascend Triton launcher compilation with file lock: %s.',
    os.environ.get('SWIFT_TRITON_ASCEND_LAUNCHER_LOCK', _DEFAULT_TRITON_ASCEND_LAUNCHER_LOCK_PATH),
)
Contributor

medium

The hardcoded lock file path in /tmp can lead to PermissionError in multi-user environments if the file is created by one user and accessed by another. It is better to use a user-specific lock path. Additionally, the logic can be simplified by defining the lock path once and using 'a' mode to avoid unnecessary truncation.

Suggested change

Before:

@wraps(make_launcher_stub)
def _locked_make_npu_launcher_stub(*args, **kwargs):
    lock_path = os.environ.get('SWIFT_TRITON_ASCEND_LAUNCHER_LOCK',
                               _DEFAULT_TRITON_ASCEND_LAUNCHER_LOCK_PATH)
    with open(lock_path, 'w') as lock_file:
        fcntl.flock(lock_file, fcntl.LOCK_EX)
        try:
            return make_launcher_stub(*args, **kwargs)
        finally:
            fcntl.flock(lock_file, fcntl.LOCK_UN)

_locked_make_npu_launcher_stub._swift_compile_lock = True
ascend_driver.make_npu_launcher_stub = _locked_make_npu_launcher_stub
logger.info(
    'Patched Ascend Triton launcher compilation with file lock: %s.',
    os.environ.get('SWIFT_TRITON_ASCEND_LAUNCHER_LOCK', _DEFAULT_TRITON_ASCEND_LAUNCHER_LOCK_PATH),
)

After:

lock_path = os.environ.get('SWIFT_TRITON_ASCEND_LAUNCHER_LOCK')
if not lock_path:
    lock_path = _DEFAULT_TRITON_ASCEND_LAUNCHER_LOCK_PATH
    if hasattr(os, 'getuid'):
        lock_path += f'.{os.getuid()}'

@wraps(make_launcher_stub)
def _locked_make_npu_launcher_stub(*args, **kwargs):
    with open(lock_path, 'a') as lock_file:
        fcntl.flock(lock_file, fcntl.LOCK_EX)
        try:
            return make_launcher_stub(*args, **kwargs)
        finally:
            fcntl.flock(lock_file, fcntl.LOCK_UN)

_locked_make_npu_launcher_stub._swift_compile_lock = True
ascend_driver.make_npu_launcher_stub = _locked_make_npu_launcher_stub
logger.info('Patched Ascend Triton launcher compilation with file lock: %s.', lock_path)
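The append-mode locking pattern in the suggestion can be exercised standalone. A minimal sketch (demo names and paths only, not the PR's code), assuming a POSIX system where `fcntl.flock` and `os.getuid` are available:

```python
import fcntl
import os
import tempfile

# User-specific lock path, as the review suggests, so two users sharing the
# machine never collide on the same file; 'a' mode never truncates the file.
lock_path = os.path.join(tempfile.gettempdir(), f'demo_launcher.lock.{os.getuid()}')

def with_compile_lock(fn, *args, **kwargs):
    with open(lock_path, 'a') as lock_file:
        fcntl.flock(lock_file, fcntl.LOCK_EX)  # blocks until the lock is free
        try:
            return fn(*args, **kwargs)
        finally:
            fcntl.flock(lock_file, fcntl.LOCK_UN)

result = with_compile_lock(lambda: 'compiled')
```

Because the lock is released in a finally block, an exception raised during compilation still frees the lock for the next process.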

Comment thread swift/model/npu_patcher.py Outdated
_QWEN3_5_PATCH_TABLE = ((modeling_qwen3_5, {
    'Qwen3_5RMSNorm': NpuQwen3_5RMSNorm,
    'apply_rotary_pos_emb': npu_apply_rotary_pos_emb_qwen3_5,
    'Qwen3_5MLP.forward': npu_swiglu_forward,
Contributor

Move `from transformers import qwen3_5` here and wrap it in a try/except, then extend the table; the same can be done for qwen3-next.
Apply the flash-linear patch here as well, distinguishing cases with an if/else. This approach can continue to be used as things evolve, guaranteeing compatibility across transformers versions.

Contributor (Author)

Resolved in the latest commit.

@Jintao-Huang Jintao-Huang merged commit 0f13dce into modelscope:main Apr 25, 2026
1 of 3 checks passed