From 8efb0eca4055683fd6124060805dced783840a60 Mon Sep 17 00:00:00 2001 From: vx120 <57470515+vx120@users.noreply.github.com> Date: Wed, 19 Nov 2025 15:03:49 +0800 Subject: [PATCH] Implement NPU_ENV handling in init.py to support Megatron on NPU. Add NPU_ENV check and repatch configuration for Megatron. --- swift/megatron/init.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/swift/megatron/init.py b/swift/megatron/init.py index d2b7b7cb95..85408b7ada 100644 --- a/swift/megatron/init.py +++ b/swift/megatron/init.py @@ -18,6 +18,11 @@ from swift.llm import git_clone_github from swift.utils import (get_logger, is_flash_attn_3_available, is_megatron_available, safe_ddp_context, split_list, subprocess_run) +try: + from mindspeed.megatron_adaptor import repatch + NPU_ENV = True +except ImportError: + NPU_ENV = False logger = get_logger() @@ -673,6 +678,18 @@ def _apply_rotary_pos_emb_thd( def _patch_megatron(): + if NPU_ENV: + override_transformer_config = { + "recompute_granularity": None, + "recompute_modules": ["core_attn"], + "recompute_method": None, + "recompute_num_layers": None + } + if repatch is not None: + repatch(override_transformer_config) + else: + print('repatch没有正确导入') + os.environ.pop('VLLM_USE_MODELSCOPE', None) logging_level = logging.root.level _patch_flash_attn()