diff --git a/swift/megatron/init.py b/swift/megatron/init.py
index d2b7b7cb95..85408b7ada 100644
--- a/swift/megatron/init.py
+++ b/swift/megatron/init.py
@@ -18,6 +18,13 @@
 from swift.llm import git_clone_github
 from swift.utils import (get_logger, is_flash_attn_3_available, is_megatron_available, safe_ddp_context, split_list,
                          subprocess_run)
+# mindspeed may be absent; bind `repatch` to None in that case so later checks never touch an unbound name.
+try:
+    from mindspeed.megatron_adaptor import repatch
+    NPU_ENV = True
+except ImportError:
+    repatch = None
+    NPU_ENV = False
 
 logger = get_logger()
 
@@ -673,6 +680,20 @@ def _apply_rotary_pos_emb_thd(
 
 
 def _patch_megatron():
+    if NPU_ENV:
+        # Transformer-config overrides passed to mindspeed's `repatch`.
+        override_transformer_config = {
+            "recompute_granularity": None,
+            "recompute_modules": ["core_attn"],
+            "recompute_method": None,
+            "recompute_num_layers": None
+        }
+        if repatch is not None:
+            repatch(override_transformer_config)
+        else:
+            # Message reads: "repatch was not imported correctly".
+            logger.warning('repatch没有正确导入')
+
     os.environ.pop('VLLM_USE_MODELSCOPE', None)
     logging_level = logging.root.level
     _patch_flash_attn()