105 changes: 13 additions & 92 deletions swift/llm/infer/infer_engine/grpo_vllm_engine.py
@@ -1,23 +1,15 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
from copy import copy, deepcopy
from typing import Any, Dict, Iterator, List, Optional, Union
from typing import Any, Dict, Optional

import torch
from packaging import version

from swift.llm import InferRequest, Template, VllmEngine, get_model_tokenizer
from swift.plugin import Metric
from ..protocol import ChatCompletionResponse, ChatCompletionStreamResponse, RequestConfig
from .patch import patch_auto_config, patch_auto_tokenizer
from .utils import AdapterRequest, patch_vllm_memory_leak
from swift.llm import Template, VllmEngine

try:
# After setting the environment variables, import vllm. This way of writing allows lint to pass.
os.environ['VLLM_WORKER_MULTIPROC_METHOD'] = 'spawn'
os.environ['VLLM_ENGINE_ITERATION_TIMEOUT_S'] = '3600'
import vllm
from vllm import AsyncEngineArgs, AsyncLLMEngine, SamplingParams, EngineArgs, LLM
except Exception:
raise

@@ -56,23 +48,15 @@ def __init__(
engine_kwargs: Optional[Dict[str, Any]] = None,
template: Optional[Template] = None,
) -> None:
os.environ['VLLM_USE_V1'] = os.environ.get('VLLM_USE_V1', '0')
if engine_kwargs is None:
engine_kwargs = {}
patch_vllm_memory_leak()
self.use_async_engine = use_async_engine
self.processor = get_model_tokenizer(
model_id_or_path,
torch_dtype,
load_model=False,
download_model=True,
assert not use_async_engine # TODO
super().__init__(
model_id_or_path=model_id_or_path,
torch_dtype=torch_dtype,
use_async_engine=use_async_engine,
model_type=model_type,
use_hf=use_hf,
hub_token=hub_token,
revision=revision)[1]
self._post_init(template)

self._prepare_engine_kwargs(
revision=revision,
gpu_memory_utilization=gpu_memory_utilization,
tensor_parallel_size=tensor_parallel_size,
pipeline_parallel_size=pipeline_parallel_size,
@@ -81,78 +65,15 @@ def __init__(
disable_custom_all_reduce=disable_custom_all_reduce,
enforce_eager=enforce_eager,
limit_mm_per_prompt=limit_mm_per_prompt,
device=device,
seed=seed,
enable_lora=enable_lora,
max_loras=max_loras,
max_lora_rank=max_lora_rank,
enable_prefix_caching=enable_prefix_caching,
device=device,
seed=seed,
distributed_executor_backend=distributed_executor_backend,
enable_sleep_mode=enable_sleep_mode,
distributed_executor_backend=distributed_executor_backend,
quantization=quantization,
**engine_kwargs,
engine_kwargs=engine_kwargs,
template=template,
)
self._prepare_engine()
self._load_generation_config()

def _prepare_engine(self) -> None:
with patch_auto_tokenizer(self.tokenizer), patch_auto_config(self.config):
engine = LLM(**self.engine_args.__dict__)
self.engine = engine

@property
def inner_model(self):
return self.engine.llm_engine.model_executor.driver_worker.model_runner.model

@property
def inner_model_executor(self):
return self.engine.llm_engine.model_executor

def infer(
self,
infer_requests: List[InferRequest],
request_config: Optional[RequestConfig] = None,
metrics: Optional[List[Metric]] = None,
*,
template: Optional[Template] = None,
use_tqdm: Optional[bool] = None,
adapter_request: Optional[AdapterRequest] = None,
) -> List[Union[ChatCompletionResponse, Iterator[ChatCompletionStreamResponse]]]:
request_config = deepcopy(request_config or RequestConfig())
if template is None:
template = self.default_template
template.set_mode('vllm')
batched_inputs, error_list = self._batch_encode(
infer_requests, template=template, strict=getattr(self, 'strict', True))
self.set_default_max_tokens(request_config, batched_inputs)

prompts = []
for inputs in batched_inputs:
llm_inputs = {'prompt_token_ids': inputs['input_ids']}
mm_data = {}
for key in ['images', 'audios', 'videos']:
media_data = inputs.get(key) or []
if media_data:
if version.parse(vllm.__version__) < version.parse('0.6'):
assert len(media_data) == 1, (
f'The current version of vllm only supports single {key}. Please upgrade to vllm >= 0.6.0')
mm_data = {key.rstrip('s'): media_data[0]}
else:
mm_data = {key.rstrip('s'): media_data[0] if len(media_data) == 1 else media_data}
if mm_data:
llm_inputs['multi_modal_data'] = mm_data
prompts.append(llm_inputs)

generation_configs = []
seed = request_config.seed
assert seed >= 0, 'Seed is needed for GRPOVllmEngine.'
for i, _ in enumerate(prompts):
request_config = copy(request_config)
request_config.seed = seed + i
generation_config = self._prepare_generation_config(request_config)
self._add_stop_words(generation_config, request_config, template.template_meta)
generation_configs.append(generation_config)
outputs = self.engine.generate(prompts, generation_configs, use_tqdm=False)
return [
self._create_chat_completion_response(result, template, generation_configs[0], '') for result in outputs
]
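
After this refactor, GRPOVllmEngine no longer builds the vllm.LLM object or carries its own infer(); construction is forwarded to VllmEngine.__init__. Below is a minimal usage sketch, assuming the keyword names visible in this diff keep their current meaning; the model id, memory fraction, and other values are placeholders, not taken from the PR.

    # Hypothetical usage sketch -- not part of the PR itself.
    from swift.llm.infer.infer_engine import GRPOVllmEngine

    engine = GRPOVllmEngine(
        model_id_or_path='Qwen/Qwen2.5-7B-Instruct',  # placeholder model
        use_async_engine=False,      # the new __init__ asserts the sync engine for now
        gpu_memory_utilization=0.9,
        tensor_parallel_size=1,
        enable_prefix_caching=True,
        engine_kwargs={},            # forwarded as a dict, no longer expanded as **kwargs
    )

Note that engine_kwargs is now handed to the parent as a plain dict (engine_kwargs=engine_kwargs) rather than being splatted into the super().__init__ call, matching the updated signature shown above.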
3 changes: 2 additions & 1 deletion swift/llm/infer/infer_engine/vllm_engine.py
@@ -19,7 +19,7 @@
ChatCompletionStreamResponse, ChatMessage, DeltaMessage, RequestConfig, random_uuid)
from .infer_engine import InferEngine
from .patch import patch_auto_config, patch_auto_tokenizer
from .utils import AdapterRequest, InferStreamer, patch_npu_vllm
from .utils import AdapterRequest, InferStreamer, patch_npu_vllm, patch_vllm_memory_leak

try:
# After setting the environment variables, import vllm. This way of writing allows lint to pass.
@@ -70,6 +70,7 @@ def __init__(
) -> None:
if engine_kwargs is None:
engine_kwargs = {}
patch_vllm_memory_leak()
self.use_async_engine = use_async_engine
self.processor = get_model_tokenizer(
model_id_or_path,
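
With patch_vllm_memory_leak() now called from VllmEngine.__init__ instead of only from the GRPO engine, every engine instance applies the workaround. The sketch below is not the helper's real body (that lives in swift/llm/infer/infer_engine/utils.py and is not shown in this diff); it is only a guess at the general shape such patch helpers take, and the idempotency guard is an assumption that would explain why calling it on every construction is harmless.

    # Illustrative shape only -- not the actual implementation of patch_vllm_memory_leak.
    _already_patched = False

    def patch_vllm_memory_leak_sketch():
        global _already_patched
        if _already_patched:      # assumed guard: repeated calls become a no-op
            return
        _already_patched = True
        # ... monkey-patch the leaking vllm internals here ...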
1 change: 0 additions & 1 deletion swift/llm/infer/rollout.py
@@ -86,7 +86,6 @@ def _register_rl_rollout_app(self):
self.app.post('/infer/', response_model=None)(self.infer)

def __init__(self, args: Union[List[str], DeployArguments, None] = None):
os.environ['VLLM_USE_V1'] = os.environ.get('VLLM_USE_V1', '1')
super().__init__(args)
safe_set_start_method()
self.app = FastAPI(lifespan=self.lifespan)
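
Both hard-coded VLLM_USE_V1 defaults are removed by this PR ('0' in GRPOVllmEngine, '1' here in rollout.py), so the choice of vllm engine version now falls back to vllm's own default unless the user sets the variable. A minimal user-side sketch of pinning it explicitly, assuming your vllm build still reads VLLM_USE_V1 and that it is set before vllm is imported:

    # Optional, user-side: pin the vllm engine version yourself now that the
    # rollout entry point no longer sets a default. Run this before importing vllm.
    import os
    os.environ.setdefault('VLLM_USE_V1', '0')   # or '1' for the V1 engine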
8 changes: 2 additions & 6 deletions swift/trainers/rlhf_trainer/grpo_trainer.py
@@ -411,11 +411,6 @@ def split_llm(name):
def prepare_vllm(self, model):
from swift.tuners import Swift
from swift.llm.infer.infer_engine import GRPOVllmEngine
if self.vllm_tensor_parallel_size > 1:
vllm_kwargs = {'distributed_executor_backend': 'external_launcher'}
else:
vllm_kwargs = {}

max_num_seqs = (
self.args.per_device_train_batch_size * self.vllm_tensor_parallel_size
* self.args.gradient_accumulation_steps)
@@ -436,7 +431,8 @@ def prepare_vllm(self, model):
max_model_len=self.args.vllm_max_model_len,
seed=self.accelerator.process_index // self.vllm_tensor_parallel_size,
template=self.template,
**vllm_kwargs)
distributed_executor_backend='external_launcher',
)
return engine

@contextmanager
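
prepare_vllm now passes distributed_executor_backend='external_launcher' unconditionally instead of only when vllm_tensor_parallel_size > 1. The seed expression in the same call groups ranks by tensor-parallel group; a small worked example, with rank and TP counts chosen purely for illustration:

    # Illustrative numbers only: 8 training processes, vllm_tensor_parallel_size = 2.
    vllm_tensor_parallel_size = 2
    for process_index in range(8):
        seed = process_index // vllm_tensor_parallel_size
        print(f'rank {process_index} -> seed {seed}')
    # ranks (0, 1) share seed 0, (2, 3) share seed 1, (4, 5) share seed 2, (6, 7) share seed 3,
    # so every rank inside one tensor-parallel group builds an identically seeded engine.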