From ffabbcb044f3d275c445b6e3e4075b2c48b347ad Mon Sep 17 00:00:00 2001
From: Jintao Huang
Date: Thu, 27 Feb 2025 14:10:02 +0800
Subject: [PATCH 1/5] update

---
 swift/llm/infer/infer_engine/lmdeploy_engine.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/swift/llm/infer/infer_engine/lmdeploy_engine.py b/swift/llm/infer/infer_engine/lmdeploy_engine.py
index d20b6db1a7..8d22a6fa9a 100644
--- a/swift/llm/infer/infer_engine/lmdeploy_engine.py
+++ b/swift/llm/infer/infer_engine/lmdeploy_engine.py
@@ -57,7 +57,7 @@ def __init__(
         version_7 = version.parse(lmdeploy.__version__) >= version.parse('0.7.0')
         if reload_weights:
             assert version_7, 'grpo or reload_weights need lmdeploy>=0.7.0'
-        if version_7:
+        if version_7 and tp == 1:
             patch_lmdeploy(reload_weights)
         self.processor = get_model_tokenizer(
             model_id_or_path,

From a2aa03c0ec19b78d5e8e5b723648e6bc597454c2 Mon Sep 17 00:00:00 2001
From: Jintao Huang
Date: Thu, 27 Feb 2025 14:30:22 +0800
Subject: [PATCH 2/5] fix lmdeploy tp

---
 swift/llm/argument/infer_args.py            | 10 ++++-
 swift/llm/infer/infer.py                    | 12 +-----
 swift/llm/infer/infer_engine/utils.py       | 43 ++++++++++++++++++++
 swift/llm/infer/infer_engine/vllm_engine.py | 10 +++--
 swift/llm/utils.py                          | 44 ---------------------
 swift/trainers/rlhf_trainer/grpo_trainer.py |  4 +-
 6 files changed, 60 insertions(+), 63 deletions(-)

diff --git a/swift/llm/argument/infer_args.py b/swift/llm/argument/infer_args.py
index 9d970f39f3..d7579ccf28 100644
--- a/swift/llm/argument/infer_args.py
+++ b/swift/llm/argument/infer_args.py
@@ -37,13 +37,16 @@ class LmdeployArguments:
     vision_batch_size: int = 1  # max_batch_size in VisionConfig
 
     def get_lmdeploy_engine_kwargs(self):
-        return {
+        kwargs = {
             'tp': self.tp,
             'session_len': self.session_len,
             'cache_max_entry_count': self.cache_max_entry_count,
             'quant_policy': self.quant_policy,
             'vision_batch_size': self.vision_batch_size
         }
+        if dist.is_initialized():
+            kwargs.update({'device': dist.get_rank()})
+        return kwargs
 
 
 @dataclass
@@ -82,7 +85,7 @@ def get_vllm_engine_kwargs(self):
         adapters = self.adapters
         if hasattr(self, 'adapter_mapping'):
             adapters = adapters + list(self.adapter_mapping.values())
-        return {
+        kwargs = {
             'gpu_memory_utilization': self.gpu_memory_utilization,
             'tensor_parallel_size': self.tensor_parallel_size,
             'pipeline_parallel_size': self.pipeline_parallel_size,
@@ -96,6 +99,9 @@ def get_vllm_engine_kwargs(self):
             'max_loras': max(len(adapters), 1),
             'enable_prefix_caching': self.enable_prefix_caching,
         }
+        if dist.is_initialized():
+            kwargs.update({'device': dist.get_rank()})
+        return kwargs
 
 
 @dataclass

diff --git a/swift/llm/infer/infer.py b/swift/llm/infer/infer.py
index b8115e6b59..201613cfad 100644
--- a/swift/llm/infer/infer.py
+++ b/swift/llm/infer/infer.py
@@ -9,7 +9,6 @@
 from swift.llm import InferArguments, InferRequest, SwiftPipeline, load_dataset, prepare_model_template, sample_dataset
 from swift.plugin import InferStats, MeanMetric, compute_rouge_bleu
 from swift.utils import JsonlWriter, get_logger, is_master, read_from_jsonl
-from ..utils import patch_vllm
 from .infer_engine import AdapterRequest, PtEngine
 from .protocol import RequestConfig
 from .utils import InferCliState
@@ -68,20 +67,11 @@ def get_infer_engine(args: InferArguments, **kwargs):
             from .infer_engine import VllmEngine
             infer_engine_cls = VllmEngine
             kwargs.update(args.get_vllm_engine_kwargs())
-            if dist.is_initialized():
-                assert args.tensor_parallel_size == 1 and args.pipeline_parallel_size == 1, (
-                    'not support tensor_parallel_size > 1 or pipeline_parallel_size > 1.')
-                context = patch_vllm()
-                kwargs.update({'device': dist.get_rank()})
         else:
             from .infer_engine import LmdeployEngine
             infer_engine_cls = LmdeployEngine
             kwargs.update(args.get_lmdeploy_engine_kwargs())
-            if dist.is_initialized():
-                assert args.tp == 1, 'not support tp > 1.'
-                kwargs.update({'device': [dist.get_rank()]})
-        with context:
-            return infer_engine_cls(**kwargs)
+        return infer_engine_cls(**kwargs)
 
     def run(self) -> List[Dict[str, Any]]:
         args = self.args

diff --git a/swift/llm/infer/infer_engine/utils.py b/swift/llm/infer/infer_engine/utils.py
index e1637aae32..47d0b9d3f0 100644
--- a/swift/llm/infer/infer_engine/utils.py
+++ b/swift/llm/infer/infer_engine/utils.py
@@ -5,12 +5,15 @@
 import re
 from collections import OrderedDict
 from concurrent.futures import ThreadPoolExecutor
+from contextlib import contextmanager
 from dataclasses import dataclass
+from functools import partial
 from itertools import repeat
 from queue import Queue
 from typing import List, Optional
 
 import torch
+import torch.distributed as dist
 from packaging import version
 from transformers import GenerationConfig, LogitsProcessor
 from transformers.generation.streamers import BaseStreamer
@@ -342,3 +345,43 @@ def _create_model_instance(self, device_id):
     TurboMindInstance.__origin_init__ = TurboMindInstance.__init__
     TurboMindInstance.__init__ = __init_ins__
     TurboMindInstance._create_model_instance = _create_model_instance
+
+
+@contextmanager
+def patch_vllm():
+    from vllm.distributed.parallel_state import GroupCoordinator
+    from unittest.mock import patch
+    world_size_patch = patch('torch.distributed.get_world_size', return_value=1)
+    profiling_patch = patch(
+        'vllm.worker.worker.Worker._assert_memory_footprint_increased_during_profiling', return_value=None)
+    __origin_init__ = GroupCoordinator.__init__
+
+    def __init__(self, group_ranks, *args, **kwargs):
+        rank = dist.get_rank()
+        if [rank] not in group_ranks:
+            group_ranks.append([rank])
+        return __origin_init__(self, group_ranks, *args, **kwargs)
+
+    GroupCoordinator.__init__ = __init__
+
+    try:
+        with world_size_patch, profiling_patch:
+            yield
+    finally:
+        GroupCoordinator.__init__ = __origin_init__
+
+
+def patch_npu_vllm(vllm_device: str):
+    device_type = vllm_device.split(':')[0]
+
+    @contextlib.contextmanager
+    def new_group_context():
+        original_new_group = torch.distributed.new_group
+        try:
+            torch.distributed.new_group = partial(original_new_group, use_local_synchronization=True)
+            torch.npu.mem_get_info = partial(torch.npu.mem_get_info, device=vllm_device)
+            yield
+        finally:
+            torch.distributed.new_group = original_new_group
+
+    return new_group_context() if device_type == 'npu' else contextlib.nullcontext()

diff --git a/swift/llm/infer/infer_engine/vllm_engine.py b/swift/llm/infer/infer_engine/vllm_engine.py
index 2ac300f90d..a8187b8c11 100644
--- a/swift/llm/infer/infer_engine/vllm_engine.py
+++ b/swift/llm/infer/infer_engine/vllm_engine.py
@@ -2,6 +2,7 @@
 import asyncio
 import inspect
 import os
+from contextlib import nullcontext
 from copy import deepcopy
 from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union
 
@@ -16,7 +17,7 @@
                         ChatCompletionStreamResponse, ChatMessage, DeltaMessage, RequestConfig, random_uuid)
 from .infer_engine import InferEngine
 from .patch import patch_auto_config, patch_auto_tokenizer
-from .utils import AdapterRequest, InferStreamer
+from .utils import AdapterRequest, InferStreamer, patch_npu_vllm, patch_vllm
 
 try:
     # After setting the environment variables, import vllm. This way of writing allows lint to pass.
@@ -86,8 +87,11 @@ def __init__(
             device=device,
             engine_kwargs=engine_kwargs,
         )
-
-        self._prepare_engine()
+        context, npu_context = nullcontext(), nullcontext()
+        if tensor_parallel_size == 1 and pipeline_parallel_size == 1:
+            context, npu_context = patch_vllm(), patch_npu_vllm(self.engine_args.device)
+        with context, npu_context:
+            self._prepare_engine()
         self._load_generation_config()
         self._fix_vllm_bug()
         self.patch_remove_log()

diff --git a/swift/llm/utils.py b/swift/llm/utils.py
index b7f90ea1f9..da4b42be38 100644
--- a/swift/llm/utils.py
+++ b/swift/llm/utils.py
@@ -1,16 +1,12 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
-import contextlib
-import functools
 import inspect
 import os
 import shutil
 import tempfile
-from contextlib import contextmanager
 from types import MethodType
 from typing import Any, Dict, List, Mapping, Optional, Tuple, Union
 
 import torch
-import torch.distributed as dist
 import torch.nn as nn
 from modelscope.hub.utils.utils import get_cache_dir
 from transformers import FeatureExtractionMixin, GenerationConfig, PreTrainedModel, PreTrainedTokenizerBase
@@ -268,43 +264,3 @@ def get_temporary_cache_files_directory(prefix=None):
         TEMP_DIR_POOL[prefix] = TEMP_DIR
 
     return TEMP_DIR.name
-
-
-@contextmanager
-def patch_vllm():
-    from vllm.distributed.parallel_state import GroupCoordinator
-    from unittest.mock import patch
-    world_size_patch = patch('torch.distributed.get_world_size', return_value=1)
-    profiling_patch = patch(
-        'vllm.worker.worker.Worker._assert_memory_footprint_increased_during_profiling', return_value=None)
-    __origin_init__ = GroupCoordinator.__init__
-
-    def __init__(self, group_ranks, *args, **kwargs):
-        rank = dist.get_rank()
-        if [rank] not in group_ranks:
-            group_ranks.append([rank])
-        return __origin_init__(self, group_ranks, *args, **kwargs)
-
-    GroupCoordinator.__init__ = __init__
-
-    try:
-        with world_size_patch, profiling_patch:
-            yield
-    finally:
-        GroupCoordinator.__init__ = __origin_init__
-
-
-def patch_npu_vllm(vllm_device: str):
-    device_type = vllm_device.split(':')[0]
-
-    @contextlib.contextmanager
-    def new_group_context():
-        original_new_group = torch.distributed.new_group
-        try:
-            torch.distributed.new_group = functools.partial(original_new_group, use_local_synchronization=True)
-            torch.npu.mem_get_info = functools.partial(torch.npu.mem_get_info, device=vllm_device)
-            yield
-        finally:
-            torch.distributed.new_group = original_new_group
-
-    return new_group_context() if device_type == 'npu' else contextlib.nullcontext()

diff --git a/swift/trainers/rlhf_trainer/grpo_trainer.py b/swift/trainers/rlhf_trainer/grpo_trainer.py
index 1a70022d80..3963bc89af 100644
--- a/swift/trainers/rlhf_trainer/grpo_trainer.py
+++ b/swift/trainers/rlhf_trainer/grpo_trainer.py
@@ -169,9 +169,7 @@ def __init__(self,
                 'Please install vLLM with `pip install vllm` to use it.')
             from swift.llm import VllmEngine
             from swift.tuners import Swift
-            from swift.llm.utils import patch_vllm, patch_npu_vllm
-            npu_vllm_patch_context = patch_npu_vllm(fast_infer_device[self.local_infer_rank])
-            with patch_vllm(), npu_vllm_patch_context, Swift.grpo_context(model, self.template.processor):
+            with Swift.grpo_context(model, self.template.processor):
                 self.engine = VllmEngine(
                     model.model_dir,
                     model.model_info.torch_dtype,

From 25233aaaee429562f72154452f25a2b47e631665 Mon Sep 17 00:00:00 2001
From: Jintao Huang
Date: Thu, 27 Feb 2025 14:39:40 +0800
Subject: [PATCH 3/5] fix

---
 swift/llm/infer/infer.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/swift/llm/infer/infer.py b/swift/llm/infer/infer.py
index 201613cfad..76c1453a6a 100644
--- a/swift/llm/infer/infer.py
+++ b/swift/llm/infer/infer.py
@@ -56,7 +56,6 @@ def get_infer_engine(args: InferArguments, **kwargs):
             'torch_dtype': args.torch_dtype,
         })
         infer_backend = kwargs.pop('infer_backend', None) or args.infer_backend
-        context = nullcontext()
         if infer_backend == 'pt':
             from .infer_engine import PtEngine
             infer_engine_cls = PtEngine

From 8cbfc5d665c23b3e8349cd0f0842b5473496fb93 Mon Sep 17 00:00:00 2001
From: Jintao Huang
Date: Thu, 27 Feb 2025 14:49:15 +0800
Subject: [PATCH 4/5] update

---
 swift/llm/argument/infer_args.py                |  2 +-
 swift/llm/infer/infer_engine/lmdeploy_engine.py | 12 ++++++------
 swift/trainers/rlhf_trainer/grpo_trainer.py     |  2 +-
 3 files changed, 8 insertions(+), 8 deletions(-)

diff --git a/swift/llm/argument/infer_args.py b/swift/llm/argument/infer_args.py
index d7579ccf28..4f20e053ba 100644
--- a/swift/llm/argument/infer_args.py
+++ b/swift/llm/argument/infer_args.py
@@ -45,7 +45,7 @@ def get_lmdeploy_engine_kwargs(self):
             'vision_batch_size': self.vision_batch_size
         }
         if dist.is_initialized():
-            kwargs.update({'device': dist.get_rank()})
+            kwargs.update({'devices': [dist.get_rank()]})
         return kwargs
 
 
diff --git a/swift/llm/infer/infer_engine/lmdeploy_engine.py b/swift/llm/infer/infer_engine/lmdeploy_engine.py
index 8d22a6fa9a..79557ac337 100644
--- a/swift/llm/infer/infer_engine/lmdeploy_engine.py
+++ b/swift/llm/infer/infer_engine/lmdeploy_engine.py
@@ -50,7 +50,7 @@ def __init__(
             cache_max_entry_count: float = 0.8,
             quant_policy: int = 0,  # e.g. 4, 8
             vision_batch_size: int = 1,  # max_batch_size in VisionConfig
-            device: Optional[List[int]] = None,
+            devices: Optional[List[int]] = None,
             reload_weights: bool = False,
             engine_kwargs: Optional[Dict[str, Any]] = None,
     ) -> None:
@@ -78,7 +78,7 @@ def __init__(
             cache_max_entry_count=cache_max_entry_count,
             quant_policy=quant_policy,
             vision_batch_size=vision_batch_size,
-            device=device,
+            devices=devices,
             engine_kwargs=engine_kwargs)
         self.config.torch_dtype = torch_dtype or self.model_info.torch_dtype
 
@@ -102,7 +102,7 @@ def _prepare_engine_kwargs(self,
                                cache_max_entry_count: float = 0.8,
                                quant_policy: int = 0,
                                vision_batch_size: int = 1,
-                               device: Optional[List[int]] = None,
+                               devices: Optional[List[int]] = None,
                                engine_kwargs: Optional[Dict[str, Any]] = None):
         if engine_kwargs is None:
             engine_kwargs = {}
@@ -113,9 +113,9 @@ def _prepare_engine_kwargs(self,
         backend_config = TurbomindEngineConfig(**engine_kwargs)
         backend_config = autoget_backend_config(self.model_dir, backend_config)
         if hasattr(backend_config, 'devices'):
-            if device is None:
-                device = [0]
-            backend_config.devices = device
+            if devices is None:
+                devices = [0]
+            backend_config.devices = devices
         self.backend_config = backend_config
         logger.info(f'backend_config: {backend_config}')
 
diff --git a/swift/trainers/rlhf_trainer/grpo_trainer.py b/swift/trainers/rlhf_trainer/grpo_trainer.py
index 3963bc89af..1156c02298 100644
--- a/swift/trainers/rlhf_trainer/grpo_trainer.py
+++ b/swift/trainers/rlhf_trainer/grpo_trainer.py
@@ -202,7 +202,7 @@ def __init__(self,
                     model.model_dir,
                     model.model_info.torch_dtype,
                     model_type=model.model_meta.model_type,
-                    device=[fast_infer_device],
+                    devices=[fast_infer_device],
                     session_len=args.lmdeploy_session_len,
                     cache_max_entry_count=args.lmdeploy_cache_max_entry_count,
                     reload_weights=True)

From c7bf5da1ba42c488f7c367d90e9c0ba15ce20c3e Mon Sep 17 00:00:00 2001
From: Jintao Huang
Date: Thu, 27 Feb 2025 15:00:06 +0800
Subject: [PATCH 5/5] fix

---
 swift/llm/infer/infer_engine/utils.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/swift/llm/infer/infer_engine/utils.py b/swift/llm/infer/infer_engine/utils.py
index 47d0b9d3f0..a448270563 100644
--- a/swift/llm/infer/infer_engine/utils.py
+++ b/swift/llm/infer/infer_engine/utils.py
@@ -19,6 +19,7 @@
 from transformers.generation.streamers import BaseStreamer
 
 from swift.llm.model.register import fix_do_sample_warning
+from swift.utils import get_device
 from ..protocol import RequestConfig
 
 
@@ -372,6 +373,8 @@ def __init__(self, group_ranks, *args, **kwargs):
 
 
 def patch_npu_vllm(vllm_device: str):
+    if isinstance(vllm_device, int):
+        vllm_device = get_device(vllm_device)
     device_type = vllm_device.split(':')[0]
 
     @contextlib.contextmanager