diff --git a/swift/llm/infer/infer_engine/grpo_vllm_engine.py b/swift/llm/infer/infer_engine/grpo_vllm_engine.py
index 5d11e3353a..70f9a6c26b 100644
--- a/swift/llm/infer/infer_engine/grpo_vllm_engine.py
+++ b/swift/llm/infer/infer_engine/grpo_vllm_engine.py
@@ -1,23 +1,15 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 import os
-from copy import copy, deepcopy
-from typing import Any, Dict, Iterator, List, Optional, Union
+from typing import Any, Dict, Optional

 import torch
-from packaging import version

-from swift.llm import InferRequest, Template, VllmEngine, get_model_tokenizer
-from swift.plugin import Metric
-from ..protocol import ChatCompletionResponse, ChatCompletionStreamResponse, RequestConfig
-from .patch import patch_auto_config, patch_auto_tokenizer
-from .utils import AdapterRequest, patch_vllm_memory_leak
+from swift.llm import Template, VllmEngine

 try:
     # After setting the environment variables, import vllm. This way of writing allows lint to pass.
     os.environ['VLLM_WORKER_MULTIPROC_METHOD'] = 'spawn'
     os.environ['VLLM_ENGINE_ITERATION_TIMEOUT_S'] = '3600'
-    import vllm
-    from vllm import AsyncEngineArgs, AsyncLLMEngine, SamplingParams, EngineArgs, LLM
 except Exception:
     raise

@@ -56,23 +48,15 @@ def __init__(
         engine_kwargs: Optional[Dict[str, Any]] = None,
         template: Optional[Template] = None,
     ) -> None:
-        os.environ['VLLM_USE_V1'] = os.environ.get('VLLM_USE_V1', '0')
-        if engine_kwargs is None:
-            engine_kwargs = {}
-        patch_vllm_memory_leak()
-        self.use_async_engine = use_async_engine
-        self.processor = get_model_tokenizer(
-            model_id_or_path,
-            torch_dtype,
-            load_model=False,
-            download_model=True,
+        assert not use_async_engine  # TODO
+        super().__init__(
+            model_id_or_path=model_id_or_path,
+            torch_dtype=torch_dtype,
+            use_async_engine=use_async_engine,
             model_type=model_type,
             use_hf=use_hf,
             hub_token=hub_token,
-            revision=revision)[1]
-        self._post_init(template)
-
-        self._prepare_engine_kwargs(
+            revision=revision,
             gpu_memory_utilization=gpu_memory_utilization,
             tensor_parallel_size=tensor_parallel_size,
             pipeline_parallel_size=pipeline_parallel_size,
@@ -81,78 +65,15 @@ def __init__(
             disable_custom_all_reduce=disable_custom_all_reduce,
             enforce_eager=enforce_eager,
             limit_mm_per_prompt=limit_mm_per_prompt,
+            device=device,
+            seed=seed,
             enable_lora=enable_lora,
             max_loras=max_loras,
             max_lora_rank=max_lora_rank,
             enable_prefix_caching=enable_prefix_caching,
-            device=device,
-            seed=seed,
-            distributed_executor_backend=distributed_executor_backend,
             enable_sleep_mode=enable_sleep_mode,
+            distributed_executor_backend=distributed_executor_backend,
             quantization=quantization,
-            **engine_kwargs,
+            engine_kwargs=engine_kwargs,
+            template=template,
         )
-        self._prepare_engine()
-        self._load_generation_config()
-
-    def _prepare_engine(self) -> None:
-        with patch_auto_tokenizer(self.tokenizer), patch_auto_config(self.config):
-            engine = LLM(**self.engine_args.__dict__)
-        self.engine = engine
-
-    @property
-    def inner_model(self):
-        return self.engine.llm_engine.model_executor.driver_worker.model_runner.model
-
-    @property
-    def inner_model_executor(self):
-        return self.engine.llm_engine.model_executor
-
-    def infer(
-        self,
-        infer_requests: List[InferRequest],
-        request_config: Optional[RequestConfig] = None,
-        metrics: Optional[List[Metric]] = None,
-        *,
-        template: Optional[Template] = None,
-        use_tqdm: Optional[bool] = None,
-        adapter_request: Optional[AdapterRequest] = None,
-    ) -> List[Union[ChatCompletionResponse, Iterator[ChatCompletionStreamResponse]]]:
-        request_config = deepcopy(request_config or RequestConfig())
-        if template is None:
-            template = self.default_template
-        template.set_mode('vllm')
-        batched_inputs, error_list = self._batch_encode(
-            infer_requests, template=template, strict=getattr(self, 'strict', True))
-        self.set_default_max_tokens(request_config, batched_inputs)
-
-        prompts = []
-        for inputs in batched_inputs:
-            llm_inputs = {'prompt_token_ids': inputs['input_ids']}
-            mm_data = {}
-            for key in ['images', 'audios', 'videos']:
-                media_data = inputs.get(key) or []
-                if media_data:
-                    if version.parse(vllm.__version__) < version.parse('0.6'):
-                        assert len(media_data) == 1, (
-                            f'The current version of vllm only supports single {key}. Please upgrade to vllm >= 0.6.0')
-                        mm_data = {key.rstrip('s'): media_data[0]}
-                    else:
-                        mm_data = {key.rstrip('s'): media_data[0] if len(media_data) == 1 else media_data}
-            if mm_data:
-                llm_inputs['multi_modal_data'] = mm_data
-            prompts.append(llm_inputs)
-
-        generation_configs = []
-        seed = request_config.seed
-        assert seed >= 0, 'Seed is needed for GRPOVllmEngine.'
-        for i, _ in enumerate(prompts):
-            request_config = copy(request_config)
-            request_config.seed = seed + i
-            generation_config = self._prepare_generation_config(request_config)
-            self._add_stop_words(generation_config, request_config, template.template_meta)
-            generation_configs.append(generation_config)
-        outputs = self.engine.generate(prompts, generation_configs, use_tqdm=False)
-        return [
-            self._create_chat_completion_response(result, template, generation_configs[0], '') for result in outputs
-        ]
diff --git a/swift/llm/infer/infer_engine/vllm_engine.py b/swift/llm/infer/infer_engine/vllm_engine.py
index 7c0609be84..8cbaaac9e3 100644
--- a/swift/llm/infer/infer_engine/vllm_engine.py
+++ b/swift/llm/infer/infer_engine/vllm_engine.py
@@ -19,7 +19,7 @@
                         ChatCompletionStreamResponse, ChatMessage, DeltaMessage, RequestConfig, random_uuid)
 from .infer_engine import InferEngine
 from .patch import patch_auto_config, patch_auto_tokenizer
-from .utils import AdapterRequest, InferStreamer, patch_npu_vllm
+from .utils import AdapterRequest, InferStreamer, patch_npu_vllm, patch_vllm_memory_leak

 try:
     # After setting the environment variables, import vllm. This way of writing allows lint to pass.
@@ -70,6 +70,7 @@ def __init__(
     ) -> None:
         if engine_kwargs is None:
             engine_kwargs = {}
+        patch_vllm_memory_leak()
         self.use_async_engine = use_async_engine
         self.processor = get_model_tokenizer(
             model_id_or_path,
diff --git a/swift/llm/infer/rollout.py b/swift/llm/infer/rollout.py
index 601ace0a06..9d44315947 100644
--- a/swift/llm/infer/rollout.py
+++ b/swift/llm/infer/rollout.py
@@ -86,7 +86,6 @@ def _register_rl_rollout_app(self):
         self.app.post('/infer/', response_model=None)(self.infer)

     def __init__(self, args: Union[List[str], DeployArguments, None] = None):
-        os.environ['VLLM_USE_V1'] = os.environ.get('VLLM_USE_V1', '1')
         super().__init__(args)
         safe_set_start_method()
         self.app = FastAPI(lifespan=self.lifespan)
diff --git a/swift/trainers/rlhf_trainer/grpo_trainer.py b/swift/trainers/rlhf_trainer/grpo_trainer.py
index 03b532a215..56886e8f2c 100644
--- a/swift/trainers/rlhf_trainer/grpo_trainer.py
+++ b/swift/trainers/rlhf_trainer/grpo_trainer.py
@@ -411,11 +411,6 @@ def split_llm(name):
     def prepare_vllm(self, model):
         from swift.tuners import Swift
         from swift.llm.infer.infer_engine import GRPOVllmEngine
-        if self.vllm_tensor_parallel_size > 1:
-            vllm_kwargs = {'distributed_executor_backend': 'external_launcher'}
-        else:
-            vllm_kwargs = {}
-
         max_num_seqs = (
             self.args.per_device_train_batch_size * self.vllm_tensor_parallel_size
             * self.args.gradient_accumulation_steps)
@@ -436,7 +431,8 @@ def prepare_vllm(self, model):
             max_model_len=self.args.vllm_max_model_len,
             seed=self.accelerator.process_index // self.vllm_tensor_parallel_size,
             template=self.template,
-            **vllm_kwargs)
+            distributed_executor_backend='external_launcher',
+        )
         return engine

     @contextmanager
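
Usage sketch (not part of the patch): a minimal illustration of how the refactored GRPOVllmEngine is constructed after this change, mirroring the simplified prepare_vllm above. It uses only keyword arguments visible in the hunks; the model id and numeric values are placeholders, not taken from the diff.

# Illustrative sketch only; placeholders for model id and sizes.
from swift.llm.infer.infer_engine import GRPOVllmEngine

engine = GRPOVllmEngine(
    model_id_or_path='Qwen/Qwen2.5-7B-Instruct',   # placeholder model id
    use_async_engine=False,                        # only the sync engine is supported (see the assert/TODO above)
    tensor_parallel_size=1,
    gpu_memory_utilization=0.9,
    max_model_len=8192,
    seed=0,
    enable_sleep_mode=True,
    # Always passed now, regardless of tensor parallel size, matching prepare_vllm().
    distributed_executor_backend='external_launcher',
)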