diff --git a/src/forge/actors/policy.py b/src/forge/actors/policy.py
index 9c9025a96..3a1b3e86e 100644
--- a/src/forge/actors/policy.py
+++ b/src/forge/actors/policy.py
@@ -15,10 +15,8 @@ from dataclasses import dataclass, field
 
 import torch
-import torch.distributed.checkpoint as dcp
 import torchstore as ts
 from monarch.actor import current_rank, endpoint, ProcMesh
-from torchstore.state_dict_utils import DELIM
 
 from vllm.config import VllmConfig
 from vllm.engine.arg_utils import EngineArgs
@@ -37,6 +35,8 @@ from vllm.v1.engine.output_processor import OutputProcessor
 from vllm.v1.engine.parallel_sampling import ParentRequest
 from vllm.v1.engine.processor import Processor
+from vllm.v1.kv_cache_interface import KVCacheConfig
+from vllm.v1.outputs import ModelRunnerOutput
 from vllm.v1.request import Request
 from vllm.v1.structured_output import StructuredOutputManager
 from vllm.worker.worker_base import WorkerWrapperBase
@@ -50,7 +50,6 @@
 )
 
 from forge.controller import ForgeActor, get_proc_mesh, stop_proc_mesh
-from forge.data.sharding import VLLMSharding
 from forge.data_models.completion import Completion
 from forge.data_models.prompt import to_prompt
 from forge.env import TORCHSTORE_USE_RDMA
@@ -65,17 +64,38 @@
 
 @dataclass
 class Policy(PolicyInterface):
+    """Instance of a vLLM-based Policy.
+
+    This class manually recreates a vLLM engine that mirrors the design of AsyncLLMEngine in v1. The
+    main difference is that all communications are controlled here via Monarch's proc meshes.
+
+    Args:
+        engine_args (EngineArgs): The engine arguments to use for the vLLM engine.
+        sampling_params (SamplingParams): The sampling parameters to use for the vLLM engine.
+        available_devices (str): The available devices to use for the vLLM engine.
+        use_dcp (bool): Whether to use DCP for NFS-based weight sync.
+
+    Example:
+        >>> policy = await Policy.options(procs=1, num_replicas=1, with_gpus=True).as_service(
+        ...     engine_args=EngineArgs(...),
+        ...     sampling_params=SamplingParams(...),
+        ... )
+        >>> await policy.generate("Tell me a joke")
+        Completion(prompt="Tell me a joke", text="A: Why did the chicken cross the road? B: To get to the other side.",
+                   token_ids=[...], logprobs=[...])
+        >>> await policy.shutdown()
+    """
+
     engine_args: EngineArgs | Mapping = field(default_factory=EngineArgs)
     sampling_params: SamplingParams | Mapping = field(default_factory=SamplingParams)
     available_devices: str | None = None
     use_dcp: bool = (
         TORCHSTORE_USE_RDMA.get_value() == 0
     )  # torchstore currently only accepts 0 or 1
-    # Gets set up by setup
+    # Remaining variables are initialized in self.setup()
    lora_request: LoRARequest | None = None
    tokenization_kwargs: dict = field(default_factory=dict)
-    policy_worker: "PolicyWorker" = None
-    policy_version: int | None = None
+    policy_worker: PolicyWorker | None = None
 
     def __post_init__(self):
         super().__init__()
@@ -83,6 +103,7 @@ def __post_init__(self):
         self._policy_proc: ProcMesh | None = None
         self._worker_procs: ProcMesh | None = None
         self.running = False
+        self.policy_version: int = 0
 
         if isinstance(self.engine_args, Mapping):
             self.engine_args = EngineArgs(**self.engine_args)
@@ -99,11 +120,20 @@ async def launch(  # pyright: ignore[reportIncompatibleMethodOverride]
         engine_args: EngineArgs | Mapping = EngineArgs(),
         sampling_params: SamplingParams | Mapping = SamplingParams(),
         available_devices: str | None = None,
-        use_dcp: bool = True,
+        use_dcp: bool = (
+            TORCHSTORE_USE_RDMA.get_value() == 0
+        ),  # torchstore currently only accepts 0 or 1
         **kwargs,
     ) -> "Policy":
-        # Note - get_proc_mesh will set MASTER_ADDR, MASTER_PORT and CUDA_VISIBLE_DEVICES
-        # automatically.
+        """Launch the policy with its workers.
+
+        We override the default Service launch method in order to set up Actors (PolicyWorker) within this "coordinating" Actor.
+        We first create a proc_mesh for the workers, then a proc_mesh for the policy, and then we spawn the workers
+        and the policy in setup.
+
+        The args here generally should match those in the `__init__` method of the Policy class.
+        """
+        # Note: get_proc_mesh will set MASTER_ADDR, MASTER_PORT and CUDA_VISIBLE_DEVICES
         process_config: ProcessConfig = ProcessConfig(
             procs=cls.procs,
             hosts=cls.hosts,
@@ -155,82 +185,66 @@ async def launch(  # pyright: ignore[reportIncompatibleMethodOverride]
         await policy.setup.call()
         return policy
 
-    @classmethod
-    async def shutdown(  # pyright: ignore[reportIncompatibleMethodOverride]
-        cls: type["Policy"], actor: "Policy"
-    ):
-        assert (
-            actor._policy_proc is not None
-        ), "Tried to shutdown a policy that was not initialized correctly"
-        assert (
-            actor._worker_procs is not None
-        ), "Tried to shutdown a policy that was not initialized correctly"
-
-        # TODO - may want to expand stop to gracefully respond to
-        # ongoing requests.
-        await actor.stop.call()
-        await stop_proc_mesh(actor._worker_procs)
-        await stop_proc_mesh(actor._policy_proc)
-
     @endpoint
     async def setup(self):
-        # Set up policy_worker
-        assert self.policy_worker is not None, "Policy worker should not be None"
+        """Mirrors the __init__ of vLLM's LLMEngine."""
+        if self.policy_worker is None:
+            raise RuntimeError(
+                "Policy worker should not be None. Usually it would be attached to Policy in the ``launch`` method."
+            )
         await self.policy_worker.setup.call()
 
         self.request_id = 0
-        self.policy_version = 0
         self.requests: dict[str, tuple[ParentRequest | None, asyncio.Future]] = {}
 
         # TODO: Investigate whether this can be combined with `policy.running`
-        # Whether this policy is accepting requests.
         self.accepting_requests = True
-        # Guard for accepting_requests
-        self.request_lock = asyncio.Condition()
-        # Guard for updating requests
-        self.update_lock = asyncio.Condition()
-        self.vllm_config: VllmConfig = self.engine_args.create_engine_config(
+        self.request_lock = asyncio.Condition()  # Guard for accepting_requests
+        self.update_lock = asyncio.Condition()  # Guard for updating requests
+
+        vllm_config: VllmConfig = self.engine_args.create_engine_config(
             UsageContext.LLM_CLASS
         )
+        self.max_model_len = vllm_config.model_config.max_model_len
 
         # Setup processors
         # TODO: move all processing to the Environment
         # TODO: add support for `log_stats` and `mm_registry`
         tokenizer = init_tokenizer_from_configs(
-            model_config=self.vllm_config.model_config,
-            scheduler_config=self.vllm_config.scheduler_config,
-            lora_config=self.vllm_config.lora_config,
+            model_config=vllm_config.model_config,
+            scheduler_config=vllm_config.scheduler_config,
+            lora_config=vllm_config.lora_config,
         )
         self.processor = Processor(
-            vllm_config=self.vllm_config, tokenizer=tokenizer, mm_registry=None
+            vllm_config=vllm_config, tokenizer=tokenizer, mm_registry=None
         )
         self.output_processor = OutputProcessor(tokenizer, log_stats=None)
 
-        # Setup scheduler
-        # TODO: Add support for `log_stats`
+        # Configure KV caches
         kv_cache_configs = await self.policy_worker.setup_kv_cache.call()
         _, kv_cache_config = next(kv_cache_configs.items())
-        self.vllm_config.cache_config.num_gpu_blocks = kv_cache_config.num_blocks
-        self.vllm_config.cache_config.num_cpu_blocks = 0
+        vllm_config.cache_config.num_gpu_blocks = kv_cache_config.num_blocks
+        vllm_config.cache_config.num_cpu_blocks = 0
 
-        structured_output_manager = StructuredOutputManager(self.vllm_config)
+        # Setup scheduler
+        # TODO: Add support for `log_stats`
+        structured_output_manager = StructuredOutputManager(vllm_config)
         self.scheduler = Scheduler(
-            vllm_config=self.vllm_config,
+            vllm_config=vllm_config,
             kv_cache_config=kv_cache_config,
             structured_output_manager=structured_output_manager,
             include_finished_set=False,
             log_stats=None,
         )
-        self.start_processing()
+        self._start_processing()
 
-    def start_processing(self):
-        """Start the replica's processing loop if not already running."""
+    def _start_processing(self):
         if self._run_task is None or self._run_task.done():
             self._run_task = asyncio.create_task(self.run())
 
     @endpoint
-    async def generate(self, prompt: str, priority: int = 0) -> list[Completion]:
+    async def generate(self, prompt: str, *, priority: int = 0) -> list[Completion]:
         """Generate a response for the given prompt
 
         Args:
@@ -238,41 +252,33 @@ async def generate(self, prompt: str, priority: int = 0) -> list[Completion]:
             priority (int, optional): The priority of the request. Defaults to 0.
 
         Returns:
-            RequestOutput: vLLM class with the generated response.
+            list[Completion]: n completions from vLLM based on your prompt.
         """
         t = Tracer("policy_perf/generate", timer="gpu")
         t.start()
-
         record_metric("policy/generate/count_requests", 1, Reduce.SUM)
 
         self.request_id += 1 % sys.maxsize
-        request_id = str(self.request_id)  # implement from a counter
-
-        # Wraps prompt into a dict
-        prompt_dict: dict[str, str] = convert_input(prompt=prompt)
+        request_id = str(self.request_id)
 
-        # truncate prmpt
         tokenization_kwargs = self.tokenization_kwargs or {}
         # TODO: add truncation support https://github.com/vllm-project/vllm/issues/4507
         truncate_prompt_tokens = self.sampling_params.truncate_prompt_tokens
         _validate_truncation_size(
-            self.vllm_config.model_config.max_model_len,
+            self.max_model_len,
             truncate_prompt_tokens,
             tokenization_kwargs,
         )
-        t.step("prompt_truncation")
-
-        # process and tokenize prompt
         prompt_str, request = self.processor.process_inputs(
             request_id=request_id,
-            prompt=prompt_dict,
+            prompt={"prompt": prompt},
             params=self.sampling_params,
             arrival_time=None,
             lora_request=self.lora_request,
             tokenization_kwargs=tokenization_kwargs,
             trace_headers=None,
             priority=priority,
-            data_parallel_rank=None,
+            data_parallel_rank=None,  # We do not support DP
         )
         t.step("process_inputs")
@@ -282,15 +288,12 @@ async def generate(self, prompt: str, priority: int = 0) -> list[Completion]:
         async with self.request_lock:
             await self.request_lock.wait_for(lambda: self.accepting_requests)
 
-            # Explicitly keeping the redundant logic to make it easier to pick up
-            # vllm changes
-            # TODO: Clean up before release
+            # Explicitly keeping the redundant logic to make it easier to pick up vLLM changes
             if (num_samples := self.sampling_params.n) == 1:
                 self.output_processor.add_request(request, prompt_str, None, 0)
-                request, _ = self.preprocess_add_request(request)
+                request, _ = self._preprocess_add_request(request)
                 request_fut = asyncio.Future()
                 self.requests[request_id] = (None, request_fut)
-
                 self.scheduler.add_request(request)
             else:
                 parent_req = ParentRequest(request_id, self.sampling_params)
@@ -304,8 +307,7 @@ async def generate(self, prompt: str, priority: int = 0) -> list[Completion]:
                     self.output_processor.add_request(
                         child_request, prompt_str, parent_req, idx
                     )
-                    child_request, _ = self.preprocess_add_request(child_request)
-
+                    child_request, _ = self._preprocess_add_request(child_request)
                     self.scheduler.add_request(child_request)
                 request_fut = asyncio.Future()
                 self.requests[request_id] = (parent_req, request_fut)
@@ -313,6 +315,7 @@ async def generate(self, prompt: str, priority: int = 0) -> list[Completion]:
         completions = await request_fut
         t.step("generate")
 
+        # Log some metrics
         record_metric(
             "policy/generate/count_sequences_completed",
             len(completions),
             Reduce.SUM,
         )
@@ -332,35 +335,34 @@ async def generate(self, prompt: str, priority: int = 0) -> list[Completion]:
             num_generated_tokens,
             Reduce.MEAN,
         )
-
         t.stop()
-
         return completions
 
-    # Abstracted to match vllm
-    # https://github.com/vllm-project/vllm/blob/0e3bb543f064eb416bca4f6f3013efa3830b12f7/vllm/v1/engine/core.py#L419
-    def preprocess_add_request(self, request: EngineCoreRequest) -> tuple[Request, int]:
+    def _preprocess_add_request(
+        self, request: EngineCoreRequest
+    ) -> tuple[Request, int]:
+        """ (forge/issues/332) Will require attention when we bump vllm versions
+        https://github.com/vllm-project/vllm/blob/0e3bb543f064eb416bca4f6f3013efa3830b12f7/vllm/v1/engine/core.py#L419"""
         if request.mm_hashes is not None:
             raise NotImplementedError("Support for mm_hash is not implemented yet.")
-        request: Request = Request.from_engine_core_request(request)
-        if request.use_structured_output:
+        req = Request.from_engine_core_request(request)
+        if req.use_structured_output:
             self.scheduler.structured_output_manager.grammar_init(request)
+        return req, 0
 
-        return request, 0  # Unused Arg: Current Wave
-
-    async def run(self):
-        # TODO: add support for `iteration_stats`
+    async def run(self) -> None:
+        """Schedule, execute, and make output.
+        https://github.com/vllm-project/vllm/blob/0e3bb543f064eb416bca4f6f3013efa3830b12f7/vllm/v1/engine/core.py#L276
+        """
         # TODO: move postprocessing out of loop to not block
         self.running = True
         while self.running:
-
             scheduler_output = self.scheduler.schedule()
-
             worker_outputs = await self.policy_worker.execute_model.call(
                 scheduler_output
             )
-            # the results of `execute_model` is gathered on the driver rank (rank 0)
+            # The results of `execute_model` are gathered on the driver rank (rank 0)
             _, worker_output = next(worker_outputs.items())
             outputs = self.scheduler.update_from_output(scheduler_output, worker_output)
             outputs = outputs.get(0) or EngineCoreOutputs()
@@ -369,9 +371,8 @@ async def run(self):
             processed_outputs = self.output_processor.process_outputs(
                 outputs.outputs,
                 engine_core_timestamp=outputs.timestamp,
-                iteration_stats=None,
+                iteration_stats=None,  # TODO: add support for `iteration_stats`
             )
-
             for request_output in processed_outputs.request_outputs:
                 if request_output.finished:
                     completions = self._to_completions(request_output)
@@ -384,13 +385,24 @@ async def run(self):
                         self.request_lock.notify_all()
 
     @endpoint
-    async def update_weights(self, policy_version: int):
+    async def update_weights(self, policy_version: int) -> None:
+        """Update weights on the base model from a policy version found in a torchstore volume.
+
+        Args:
+            policy_version (int): Policy version from which to update. This will correspond to a key in a
+                torchstore volume.
+
+        Example:
+            >>> trainer.train_step(...)
+            >>> version += 1
+            >>> await trainer.push_weights()
+            >>> await policy.update_weights(version)
+        """
         # Serialize updates (only one update at a time)
         async with self.update_lock:
             # Grab the lock to stop accepting requests and wait on pending requests
             async with self.request_lock:
                 self.accepting_requests = False
-
                 curr_requests = [fut for _, fut in self.requests.values()]
                 if curr_requests:
                     # Record pending requests metrics
@@ -415,7 +427,8 @@ async def update_weights(self, policy_version: int):
         record_metric("policy/update_weights/count_weight_updates", 1, Reduce.SUM)
 
         logger.debug(f"Starting weight update on {self.__class__.__name__}")
-        await self.policy_worker.update.call(version=policy_version)
+        # Call update_weights on every policy_worker
+        await self.policy_worker.update_weights.call(policy_version)
         self.policy_version = policy_version
 
         # After updating the weights, we need to reset the KV cache
@@ -441,20 +454,8 @@ async def get_version(self) -> int:
     async def stop(self):
         self.running = False
 
-    @endpoint
-    async def _test_save_model_params(self):
-        """Save model parameters before weight update, used for tesing purposes only."""
-        logger.info("[Policy] save model parameters for testing.")
-        await self.policy_worker._test_save_model_params.call()
-
-    @endpoint
-    async def _test_validate_model_params(self, validate_fn):
-        """Validate updated model params using validate_fn."""
-        logger.info("[Policy] start validating model parameters.")
-        return await self.policy_worker._test_validate_model_params.call(validate_fn)
-
     def _to_completions(self, request_output: RequestOutput) -> list[Completion]:
-        """Convert a RequestOutput to a list of Completion objects."""
+        """Convert a vLLM RequestOutput to a list of Completion objects."""
         completions = []
         original_prompt = request_output.prompt
         prompt_token_ids = request_output.prompt_token_ids
@@ -473,27 +474,57 @@ def _to_completions(self, request_output: RequestOutput) -> list[Completion]:
                     metadata={"num_cached_tokens": request_output.num_cached_tokens},
                 )
             )
-
         return completions
 
-    def _extract_logprobs(self, one_sample: CompletionOutput) -> torch.Tensor | None:
-        """
-        Extract log probabilities from a sample, if available.
-        """
-        if one_sample.logprobs is not None:
+    def _extract_logprobs(self, sample: CompletionOutput) -> torch.Tensor | None:
+        if sample.logprobs is not None:
             return torch.tensor(
                 [
                     top_k_dict[token].logprob
-                    for token, top_k_dict in zip(
-                        one_sample.token_ids, one_sample.logprobs
-                    )
+                    for token, top_k_dict in zip(sample.token_ids, sample.logprobs)
                 ]
             )
         return None
 
+    @classmethod
+    async def shutdown(  # pyright: ignore[reportIncompatibleMethodOverride]
+        cls: type["Policy"], actor: "Policy"
+    ):
+        assert (
+            actor._policy_proc is not None
+        ), "Tried to shutdown a policy that was not initialized correctly"
+        assert (
+            actor._worker_procs is not None
+        ), "Tried to shutdown a policy that was not initialized correctly"
+
+        # TODO - may want to expand stop to gracefully respond to
+        # ongoing requests.
+        await actor.stop.call()
+        await stop_proc_mesh(actor._worker_procs)
+        await stop_proc_mesh(actor._policy_proc)
+
+    @endpoint
+    async def _test_save_model_params(self):
+        """Save model parameters before weight update, used for testing purposes only."""
+        logger.info("[Policy] save model parameters for testing.")
+        await self.policy_worker._test_save_model_params.call()
+
+    @endpoint
+    async def _test_validate_model_params(self, validate_fn):
+        """Validate updated model params using validate_fn."""
+        logger.info("[Policy] start validating model parameters.")
+        return await self.policy_worker._test_validate_model_params.call(validate_fn)
+
 
 @dataclass
 class PolicyWorker(ForgeActor):
+    """Mirrors a vLLM GPUWorker
+    https://github.com/vllm-project/vllm/blob/0e3bb543f064eb416bca4f6f3013efa3830b12f7/vllm/v1/worker/gpu_worker.py
+
+    In general, this class should not be instantiated or called directly. Rather, the Policy controls
+    the creation and invocation of all PolicyWorkers.
+    """
+
     vllm_config: VllmConfig
     state_dict_key: str = "model_state_dict"
     # TODO: remove this later since no plumbing exists to change this value.
@@ -510,68 +541,68 @@ def __post_init__(self):
     async def setup(self):
         self.rank = current_rank().rank
         os.environ["RANK"] = str(self.rank)
-        self.worker = self.setup_worker()
+        parallel_config = self.vllm_config.parallel_config
+        set_multiprocessing_worker_envs(parallel_config)
+        ip, port = os.getenv("MASTER_ADDR"), os.getenv("MASTER_PORT")
+        distributed_init_method = get_distributed_init_method(ip, port)
+        all_kwargs = [{}] * parallel_config.world_size
+        local_rank = self.rank % torch.accelerator.device_count()
+        is_driver_worker = self.rank % parallel_config.tensor_parallel_size == 0
+        all_kwargs[self.rank] = {
+            "vllm_config": self.vllm_config,
+            "local_rank": local_rank,
+            "rank": self.rank,
+            "distributed_init_method": distributed_init_method,
+            "is_driver_worker": is_driver_worker,
+        }
+        self.worker = WorkerWrapperBase(self.vllm_config, self.rank)
+        self.worker.init_worker(all_kwargs)
+        self.worker.init_device()
+        self.worker.load_model()
 
     @endpoint
-    async def execute_model(self, schedule: SchedulerOutput):
-        return self.worker.execute_model(schedule)
+    async def setup_kv_cache(self) -> KVCacheConfig:
+        """https://github.com/vllm-project/vllm/blob/5c7fe25491825b95936c011a43337c7d4fb7e472/vllm/v1/engine/core.py#L199"""
+        kv_cache_spec = self.worker.get_kv_cache_spec()
+        if kv_cache_spec is not None:
+            available_gpu_memory = self.worker.determine_available_memory()
+        else:
+            # Attention free models don't need memory for kv cache
+            available_gpu_memory = 0
 
-    async def _load_tensor_parallel_state_dict(
-        self, current_state_dict: dict, version: int
-    ):
-        """
-        Load full state dict from torchstore into tensor parallel model with deterministic sharding.
-        """
-        sharding = VLLMSharding(
-            self.vllm_config.parallel_config.tensor_parallel_size, self.rank
+        # Get the kv cache tensor size
+        kv_cache_config = get_kv_cache_config(
+            self.vllm_config, kv_cache_spec, available_gpu_memory
         )
+        # TODO: unify configs across TorchStore
+        # unify_kv_cache_configs(kv_cache_configs)
+        self.vllm_config.cache_config.num_gpu_blocks = kv_cache_config.num_blocks
+        self.vllm_config.cache_config.num_cpu_blocks = 0
 
-        checkpoint_id = f"{self.state_dict_key}{DELIM}{version}"
-        dcp_metadata = None
-        if self.use_dcp:
-            dcp_metadata = await ts.get(checkpoint_id)
-
-        for param_name in current_state_dict.keys():
-            current_tensor = current_state_dict[param_name]
+        # Initialize kv cache and warmup the execution:
+        # from multiproc_executor.py:MultiprocExecutor.initialize_from_config
+        kv_cache_configs = [None] * self.vllm_config.parallel_config.world_size
+        kv_cache_configs[self.rank] = kv_cache_config
+        self.worker.initialize_from_config(kv_cache_configs)
+        self.worker.compile_or_warm_up_model()
+        self.worker.initialize_cache(kv_cache_config.num_blocks, 0)
+        return kv_cache_config
 
-            # Load the full tensor from torchstore
-            # TODO: only get the part of the tensor that is needed
-            if self.use_dcp:
-                tensor_meta = dcp_metadata.state_dict_metadata[param_name]
-                stored_tensor = torch.empty(
-                    size=tensor_meta.size, dtype=tensor_meta.properties.dtype
-                )
-                dcp.load(
-                    checkpoint_id=checkpoint_id, state_dict={param_name: stored_tensor}
-                )
-            else:
-                stored_tensor = await ts.get(f"{checkpoint_id}{DELIM}{param_name}")
-            sharding.load_from_source_to_target(
-                param_name,
-                stored_tensor,
-                current_tensor,
-            )
+    @endpoint
+    async def execute_model(self, schedule: SchedulerOutput) -> ModelRunnerOutput:
+        return self.worker.execute_model(schedule)
 
     @endpoint
-    async def update(self, version: int):
-        """Update model weights by reading state dict from torchstore"""
-        logger.info(
-            f"[PolicyWorker::update] start updating weights to version {version}"
-        )
+    async def update_weights(self, policy_version: int) -> None:
         model = self.worker.model_runner.model
-        prefix = get_param_prefix(version)
-        logger.debug(f"{prefix=}")
+        prefix = get_param_prefix(policy_version)
         matching_keys = await ts.keys(prefix)
-        logger.debug(f"{matching_keys=}")
-        dcp_whole_state_dict_key = get_dcp_whole_state_dict_key(version)
+        dcp_whole_state_dict_key = get_dcp_whole_state_dict_key(policy_version)
         loaded_weights = set()
-        t = Tracer("policy_worker_perf/update", timer="gpu")
+        t = Tracer("policy_worker_perf/update_weights", timer="gpu")
         t.start()
         # Entire state dict is stored in a single DCP handle
         if dcp_whole_state_dict_key in matching_keys:
-            logger.info(
-                f"Loading {dcp_whole_state_dict_key} from DCP with handle {dcp_whole_state_dict_key}"
-            )
             dcp_handle = await ts.get(dcp_whole_state_dict_key)
             hf_param_names = dcp_handle.param_names
             for name in hf_param_names:
@@ -584,43 +615,12 @@ async def update(self, version: int):
             # We can't pass a generator since vllm load_weights is not async.
            # Instead, we just call load_weights with one parameter at a time.
             for name in hf_param_names:
-                param_key = get_param_key(version, name)
+                param_key = get_param_key(policy_version, name)
                 param = await ts.get(param_key)
                 loaded = model.load_weights([(name, param)])
                 del param
                 loaded_weights.update(loaded)
         t.stop()
-        logger.debug(f"[PolicyWorker::update] Loaded weights: {loaded_weights}")
-
-    @endpoint
-    async def setup_kv_cache(self):
-        """Based on vllm/v1/engine/core.py:EngineCore._initialize_kv_caches
-        TODO: test that fails if vllm method updates
-        """
-        kv_cache_spec = self.worker.get_kv_cache_spec()
-        if kv_cache_spec is not None:
-            available_gpu_memory = self.worker.determine_available_memory()
-        else:
-            # Attention free models don't need memory for kv cache
-            available_gpu_memory = 0
-
-        # Get the kv cache tensor size
-        kv_cache_config = get_kv_cache_config(
-            self.vllm_config, kv_cache_spec, available_gpu_memory
-        )
-        # TODO: unify configs across TorchStore
-        # unify_kv_cache_configs(kv_cache_configs)
-        self.vllm_config.cache_config.num_gpu_blocks = kv_cache_config.num_blocks
-        self.vllm_config.cache_config.num_cpu_blocks = 0
-
-        # Initialize kv cache and warmup the execution:
-        # from multiproc_executor.py:MultiprocExecutor.initialize_from_config
-        kv_cache_configs = [None] * self.vllm_config.parallel_config.world_size
-        kv_cache_configs[self.rank] = kv_cache_config
-        self.worker.initialize_from_config(kv_cache_configs)
-        self.worker.compile_or_warm_up_model()
-        self.worker.initialize_cache(kv_cache_config.num_blocks, 0)
-        return kv_cache_config
 
     @endpoint
     async def _test_save_model_params(self):
@@ -640,32 +640,3 @@ async def _test_validate_model_params(self, validate_fn):
         return validate_fn(
             self._test_prev_params, self.worker.model_runner.model, logger
         )
-
-    def setup_worker(self):
-        """Build and Instantiate vLLM worker"""
-        parallel_config = self.vllm_config.parallel_config
-        set_multiprocessing_worker_envs(parallel_config)
-        ip, port = os.getenv("MASTER_ADDR"), os.getenv("MASTER_PORT")
-        distributed_init_method = get_distributed_init_method(ip, port)
-        all_kwargs = [{}] * parallel_config.world_size
-        local_rank = self.rank % torch.accelerator.device_count()
-        is_driver_worker = self.rank % parallel_config.tensor_parallel_size == 0
-        all_kwargs[self.rank] = {
-            "vllm_config": self.vllm_config,
-            "local_rank": local_rank,
-            "rank": self.rank,
-            "distributed_init_method": distributed_init_method,
-            "is_driver_worker": is_driver_worker,
-        }
-        worker = WorkerWrapperBase(self.vllm_config, self.rank)
-        worker.init_worker(all_kwargs)
-        worker.init_device()
-        worker.load_model()
-        return worker
-
-
-def convert_input(prompt=None, prompt_token_ids=None) -> dict:
-    assert (prompt is None) ^ (prompt_token_ids is None)
-    if prompt is not None:
-        return {"prompt": prompt}
-    return {"prompt_token_ids": prompt_token_ids}