Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions swift/megatron/init.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,11 @@
from swift.llm import git_clone_github
from swift.utils import (get_logger, is_flash_attn_3_available, is_megatron_available, safe_ddp_context, split_list,
subprocess_run)
try:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里加一个判断吧,安全一些

from transformers.utils import is_torch_npu_available

from mindspeed.megatron_adaptor import repatch
NPU_ENV = True
except ImportError:
NPU_ENV = False

logger = get_logger()

Expand Down Expand Up @@ -673,6 +678,18 @@ def _apply_rotary_pos_emb_thd(


def _patch_megatron():
if NPU_ENV:
override_transformer_config = {
"recompute_granularity": None,
"recompute_modules": ["core_attn"],
"recompute_method": None,
"recompute_num_layers": None
}
Comment on lines +682 to +687
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The override_transformer_config dictionary is static. It's more efficient and readable to define it as a constant at the module level, outside of the _patch_megatron function.

For example:

# At the top of the file, e.g., after the logger is defined.
_NPU_OVERRIDE_TRANSFORMER_CONFIG = {
    "recompute_granularity": None,
    "recompute_modules": ["core_attn"],
    "recompute_method": None,
    "recompute_num_layers": None,
}

# ...

def _patch_megatron():
    if NPU_ENV:
        if repatch is not None:
            repatch(_NPU_OVERRIDE_TRANSFORMER_CONFIG)
        # ...

if repatch is not None:
repatch(override_transformer_config)
else:
print('repatch没有正确导入')
Comment on lines +690 to +691
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

It's better to use the logger for warnings or errors instead of print(). This ensures consistent logging behavior. Also, using an English message would improve maintainability and consistency with the rest of the codebase.

Suggested change
else:
print('repatch没有正确导入')
else:
logger.warning('`repatch` from `mindspeed.megatron_adaptor` was not imported correctly, skipping NPU patch.')

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hello. Please use English comments. 😊


os.environ.pop('VLLM_USE_MODELSCOPE', None)
logging_level = logging.root.level
_patch_flash_attn()
Expand Down
Loading