10 changes: 8 additions & 2 deletions swift/llm/argument/infer_args.py
@@ -37,13 +37,16 @@ class LmdeployArguments:
vision_batch_size: int = 1 # max_batch_size in VisionConfig

def get_lmdeploy_engine_kwargs(self):
return {
kwargs = {
'tp': self.tp,
'session_len': self.session_len,
'cache_max_entry_count': self.cache_max_entry_count,
'quant_policy': self.quant_policy,
'vision_batch_size': self.vision_batch_size
}
if dist.is_initialized():
kwargs.update({'devices': [dist.get_rank()]})
return kwargs


@dataclass
@@ -82,7 +85,7 @@ def get_vllm_engine_kwargs(self):
adapters = self.adapters
if hasattr(self, 'adapter_mapping'):
adapters = adapters + list(self.adapter_mapping.values())
return {
kwargs = {
'gpu_memory_utilization': self.gpu_memory_utilization,
'tensor_parallel_size': self.tensor_parallel_size,
'pipeline_parallel_size': self.pipeline_parallel_size,
@@ -96,6 +99,9 @@ def get_vllm_engine_kwargs(self):
'max_loras': max(len(adapters), 1),
'enable_prefix_caching': self.enable_prefix_caching,
}
if dist.is_initialized():
kwargs.update({'device': dist.get_rank()})
return kwargs


@dataclass
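With this change it is the argument dataclasses, not get_infer_engine in infer.py, that decide whether to pin the engine to the current rank. A rough sketch of the intended behaviour, assuming the remaining dataclass fields keep their defaults and using the import path implied by the file location above (values are illustrative, not from the PR):

from swift.llm.argument.infer_args import LmdeployArguments

args = LmdeployArguments()
kwargs = args.get_lmdeploy_engine_kwargs()
# Single-process run: no process group is initialized, so no 'devices' entry
# is added and LmdeployEngine later falls back to devices=[0].
assert 'devices' not in kwargs

# Under torchrun, once dist.init_process_group() has been called, each rank
# would instead get its own device list:
#   rank 0 -> kwargs['devices'] == [0], rank 1 -> kwargs['devices'] == [1], ...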
13 changes: 1 addition & 12 deletions swift/llm/infer/infer.py
@@ -9,7 +9,6 @@
from swift.llm import InferArguments, InferRequest, SwiftPipeline, load_dataset, prepare_model_template, sample_dataset
from swift.plugin import InferStats, MeanMetric, compute_rouge_bleu
from swift.utils import JsonlWriter, get_logger, is_master, read_from_jsonl
from ..utils import patch_vllm
from .infer_engine import AdapterRequest, PtEngine
from .protocol import RequestConfig
from .utils import InferCliState
@@ -57,7 +56,6 @@ def get_infer_engine(args: InferArguments, **kwargs):
'torch_dtype': args.torch_dtype,
})
infer_backend = kwargs.pop('infer_backend', None) or args.infer_backend
context = nullcontext()
if infer_backend == 'pt':
from .infer_engine import PtEngine
infer_engine_cls = PtEngine
@@ -68,20 +66,11 @@ def get_infer_engine(args: InferArguments, **kwargs):
from .infer_engine import VllmEngine
infer_engine_cls = VllmEngine
kwargs.update(args.get_vllm_engine_kwargs())
if dist.is_initialized():
assert args.tensor_parallel_size == 1 and args.pipeline_parallel_size == 1, (
'not support tensor_parallel_size > 1 or pipeline_parallel_size > 1.')
context = patch_vllm()
kwargs.update({'device': dist.get_rank()})
else:
from .infer_engine import LmdeployEngine
infer_engine_cls = LmdeployEngine
kwargs.update(args.get_lmdeploy_engine_kwargs())
if dist.is_initialized():
assert args.tp == 1, 'not support tp > 1.'
kwargs.update({'device': [dist.get_rank()]})
with context:
return infer_engine_cls(**kwargs)
return infer_engine_cls(**kwargs)

def run(self) -> List[Dict[str, Any]]:
args = self.args
14 changes: 7 additions & 7 deletions swift/llm/infer/infer_engine/lmdeploy_engine.py
@@ -50,14 +50,14 @@ def __init__(
cache_max_entry_count: float = 0.8,
quant_policy: int = 0, # e.g. 4, 8
vision_batch_size: int = 1, # max_batch_size in VisionConfig
device: Optional[List[int]] = None,
devices: Optional[List[int]] = None,
reload_weights: bool = False,
engine_kwargs: Optional[Dict[str, Any]] = None,
) -> None:
version_7 = version.parse(lmdeploy.__version__) >= version.parse('0.7.0')
if reload_weights:
assert version_7, 'grpo or reload_weights need lmdeploy>=0.7.0'
if version_7:
if version_7 and tp == 1:
patch_lmdeploy(reload_weights)
self.processor = get_model_tokenizer(
model_id_or_path,
@@ -78,7 +78,7 @@ def __init__(
cache_max_entry_count=cache_max_entry_count,
quant_policy=quant_policy,
vision_batch_size=vision_batch_size,
device=device,
devices=devices,
engine_kwargs=engine_kwargs)

self.config.torch_dtype = torch_dtype or self.model_info.torch_dtype
@@ -102,7 +102,7 @@ def _prepare_engine_kwargs(self,
cache_max_entry_count: float = 0.8,
quant_policy: int = 0,
vision_batch_size: int = 1,
device: Optional[List[int]] = None,
devices: Optional[List[int]] = None,
engine_kwargs: Optional[Dict[str, Any]] = None):
if engine_kwargs is None:
engine_kwargs = {}
@@ -113,9 +113,9 @@ def _prepare_engine_kwargs(self,
backend_config = TurbomindEngineConfig(**engine_kwargs)
backend_config = autoget_backend_config(self.model_dir, backend_config)
if hasattr(backend_config, 'devices'):
if device is None:
device = [0]
backend_config.devices = device
if devices is None:
devices = [0]
backend_config.devices = devices
self.backend_config = backend_config
logger.info(f'backend_config: {backend_config}')

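Since the constructor keyword is renamed from device to devices, any caller that passed the old name needs updating, as grpo_trainer.py does further down. A minimal, hypothetical call (model id and single-GPU setup are just examples):

from swift.llm.infer.infer_engine import LmdeployEngine

# previously: LmdeployEngine('Qwen/Qwen2-7B-Instruct', device=[0])
engine = LmdeployEngine('Qwen/Qwen2-7B-Instruct', devices=[0])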
46 changes: 46 additions & 0 deletions swift/llm/infer/infer_engine/utils.py
@@ -5,17 +5,21 @@
import re
from collections import OrderedDict
from concurrent.futures import ThreadPoolExecutor
from contextlib import contextmanager
from dataclasses import dataclass
from functools import partial
from itertools import repeat
from queue import Queue
from typing import List, Optional

import torch
import torch.distributed as dist
from packaging import version
from transformers import GenerationConfig, LogitsProcessor
from transformers.generation.streamers import BaseStreamer

from swift.llm.model.register import fix_do_sample_warning
from swift.utils import get_device
from ..protocol import RequestConfig


@@ -342,3 +346,45 @@ def _create_model_instance(self, device_id):
TurboMindInstance.__origin_init__ = TurboMindInstance.__init__
TurboMindInstance.__init__ = __init_ins__
TurboMindInstance._create_model_instance = _create_model_instance


@contextmanager
def patch_vllm():
from vllm.distributed.parallel_state import GroupCoordinator
from unittest.mock import patch
world_size_patch = patch('torch.distributed.get_world_size', return_value=1)
profiling_patch = patch(
'vllm.worker.worker.Worker._assert_memory_footprint_increased_during_profiling', return_value=None)
__origin_init__ = GroupCoordinator.__init__

def __init__(self, group_ranks, *args, **kwargs):
rank = dist.get_rank()
if [rank] not in group_ranks:
group_ranks.append([rank])
return __origin_init__(self, group_ranks, *args, **kwargs)

GroupCoordinator.__init__ = __init__

try:
with world_size_patch, profiling_patch:
yield
finally:
GroupCoordinator.__init__ = __origin_init__


def patch_npu_vllm(vllm_device: str):
if isinstance(vllm_device, int):
vllm_device = get_device(vllm_device)
device_type = vllm_device.split(':')[0]

@contextlib.contextmanager
def new_group_context():
original_new_group = torch.distributed.new_group
try:
torch.distributed.new_group = partial(original_new_group, use_local_synchronization=True)
torch.npu.mem_get_info = partial(torch.npu.mem_get_info, device=vllm_device)
yield
finally:
torch.distributed.new_group = original_new_group

return new_group_context() if device_type == 'npu' else contextlib.nullcontext()
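Both helpers are context managers (patch_npu_vllm falls back to a nullcontext on non-NPU devices) and are meant to wrap engine construction; vllm_engine.py below uses them exactly that way when tensor and pipeline parallelism are both 1. A rough usage sketch, assuming a torch.distributed process group is already initialized so that dist.get_rank() inside the patched GroupCoordinator.__init__ works (the engine construction itself is a placeholder):

from swift.llm.infer.infer_engine.utils import patch_npu_vllm, patch_vllm

device_index = 0  # this rank's device; patch_npu_vllm now also accepts an int
with patch_vllm(), patch_npu_vllm(device_index):
    pass  # build the vLLM engine here; the monkey patches are rolled back on exit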
10 changes: 7 additions & 3 deletions swift/llm/infer/infer_engine/vllm_engine.py
@@ -2,6 +2,7 @@
import asyncio
import inspect
import os
from contextlib import nullcontext
from copy import deepcopy
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union

@@ -16,7 +17,7 @@
ChatCompletionStreamResponse, ChatMessage, DeltaMessage, RequestConfig, random_uuid)
from .infer_engine import InferEngine
from .patch import patch_auto_config, patch_auto_tokenizer
from .utils import AdapterRequest, InferStreamer
from .utils import AdapterRequest, InferStreamer, patch_npu_vllm, patch_vllm

try:
# After setting the environment variables, import vllm. This way of writing allows lint to pass.
@@ -86,8 +87,11 @@ def __init__(
device=device,
engine_kwargs=engine_kwargs,
)

self._prepare_engine()
context, npu_context = nullcontext(), nullcontext()
if tensor_parallel_size == 1 and pipeline_parallel_size == 1:
context, npu_context = patch_vllm(), patch_npu_vllm(self.engine_args.device)
with context, npu_context:
self._prepare_engine()
self._load_generation_config()
self._fix_vllm_bug()
self.patch_remove_log()
44 changes: 0 additions & 44 deletions swift/llm/utils.py
@@ -1,16 +1,12 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import contextlib
import functools
import inspect
import os
import shutil
import tempfile
from contextlib import contextmanager
from types import MethodType
from typing import Any, Dict, List, Mapping, Optional, Tuple, Union

import torch
import torch.distributed as dist
import torch.nn as nn
from modelscope.hub.utils.utils import get_cache_dir
from transformers import FeatureExtractionMixin, GenerationConfig, PreTrainedModel, PreTrainedTokenizerBase
@@ -268,43 +264,3 @@ def get_temporary_cache_files_directory(prefix=None):
TEMP_DIR_POOL[prefix] = TEMP_DIR

return TEMP_DIR.name


@contextmanager
def patch_vllm():
from vllm.distributed.parallel_state import GroupCoordinator
from unittest.mock import patch
world_size_patch = patch('torch.distributed.get_world_size', return_value=1)
profiling_patch = patch(
'vllm.worker.worker.Worker._assert_memory_footprint_increased_during_profiling', return_value=None)
__origin_init__ = GroupCoordinator.__init__

def __init__(self, group_ranks, *args, **kwargs):
rank = dist.get_rank()
if [rank] not in group_ranks:
group_ranks.append([rank])
return __origin_init__(self, group_ranks, *args, **kwargs)

GroupCoordinator.__init__ = __init__

try:
with world_size_patch, profiling_patch:
yield
finally:
GroupCoordinator.__init__ = __origin_init__


def patch_npu_vllm(vllm_device: str):
device_type = vllm_device.split(':')[0]

@contextlib.contextmanager
def new_group_context():
original_new_group = torch.distributed.new_group
try:
torch.distributed.new_group = functools.partial(original_new_group, use_local_synchronization=True)
torch.npu.mem_get_info = functools.partial(torch.npu.mem_get_info, device=vllm_device)
yield
finally:
torch.distributed.new_group = original_new_group

return new_group_context() if device_type == 'npu' else contextlib.nullcontext()
6 changes: 2 additions & 4 deletions swift/trainers/rlhf_trainer/grpo_trainer.py
@@ -169,9 +169,7 @@ def __init__(self,
'Please install vLLM with `pip install vllm` to use it.')
from swift.llm import VllmEngine
from swift.tuners import Swift
from swift.llm.utils import patch_vllm, patch_npu_vllm
npu_vllm_patch_context = patch_npu_vllm(fast_infer_device[self.local_infer_rank])
with patch_vllm(), npu_vllm_patch_context, Swift.grpo_context(model, self.template.processor):
with Swift.grpo_context(model, self.template.processor):
self.engine = VllmEngine(
model.model_dir,
model.model_info.torch_dtype,
@@ -204,7 +202,7 @@ def __init__(self,
model.model_dir,
model.model_info.torch_dtype,
model_type=model.model_meta.model_type,
device=[fast_infer_device],
devices=[fast_infer_device],
session_len=args.lmdeploy_session_len,
cache_max_entry_count=args.lmdeploy_cache_max_entry_count,
reload_weights=True)