-
Notifications
You must be signed in to change notification settings - Fork 994
Implement NPU_ENV handling in init.py to support megatron in NPU. #6661
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||
|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -18,6 +18,11 @@ | |||||||||
| from swift.llm import git_clone_github | ||||||||||
| from swift.utils import (get_logger, is_flash_attn_3_available, is_megatron_available, safe_ddp_context, split_list, | ||||||||||
| subprocess_run) | ||||||||||
| try: | ||||||||||
| from mindspeed.megatron_adaptor import repatch | ||||||||||
| NPU_ENV = True | ||||||||||
| except ImportError: | ||||||||||
| NPU_ENV = False | ||||||||||
|
|
||||||||||
| logger = get_logger() | ||||||||||
|
|
||||||||||
|
|
@@ -673,6 +678,18 @@ def _apply_rotary_pos_emb_thd( | |||||||||
|
|
||||||||||
|
|
||||||||||
| def _patch_megatron(): | ||||||||||
| if NPU_ENV: | ||||||||||
| override_transformer_config = { | ||||||||||
| "recompute_granularity": None, | ||||||||||
| "recompute_modules": ["core_attn"], | ||||||||||
| "recompute_method": None, | ||||||||||
| "recompute_num_layers": None | ||||||||||
| } | ||||||||||
|
Comment on lines
+682
to
+687
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The For example: # At the top of the file, e.g., after the logger is defined.
_NPU_OVERRIDE_TRANSFORMER_CONFIG = {
"recompute_granularity": None,
"recompute_modules": ["core_attn"],
"recompute_method": None,
"recompute_num_layers": None,
}
# ...
def _patch_megatron():
if NPU_ENV:
if repatch is not None:
repatch(_NPU_OVERRIDE_TRANSFORMER_CONFIG)
# ... |
||||||||||
| if repatch is not None: | ||||||||||
| repatch(override_transformer_config) | ||||||||||
| else: | ||||||||||
| print('repatch没有正确导入') | ||||||||||
|
Comment on lines
+690
to
+691
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. It's better to use the logger for warnings or errors instead of `print`.
Suggested change
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hello. Please use English comments. 😊 |
||||||||||
|
|
||||||||||
| os.environ.pop('VLLM_USE_MODELSCOPE', None) | ||||||||||
| logging_level = logging.root.level | ||||||||||
| _patch_flash_attn() | ||||||||||
|
|
||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's add a guard check here to be safer (translated from Chinese: "这里加一个判断吧，安全一些"):
from transformers.utils import is_torch_npu_available