197 changes: 11 additions & 186 deletions llama_stack/providers/remote/inference/vllm/vllm.py
@@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
from collections.abc import AsyncGenerator, AsyncIterator
from collections.abc import AsyncGenerator
from typing import Any

import httpx
@@ -38,13 +38,6 @@
LogProbConfig,
Message,
ModelStore,
OpenAIChatCompletion,
OpenAICompletion,
OpenAIEmbeddingData,
OpenAIEmbeddingsResponse,
OpenAIEmbeddingUsage,
OpenAIMessageParam,
OpenAIResponseFormatParam,
ResponseFormat,
SamplingParams,
TextTruncation,
Expand All @@ -71,11 +64,11 @@
convert_message_to_openai_dict,
convert_tool_call,
get_sampling_options,
prepare_openai_completion_params,
process_chat_completion_stream_response,
process_completion_response,
process_completion_stream_response,
)
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from llama_stack.providers.utils.inference.prompt_adapter import (
completion_request_to_prompt,
content_has_media,
@@ -288,15 +281,14 @@ async def _process_vllm_chat_completion_stream_response(
yield c


class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
class VLLMInferenceAdapter(OpenAIMixin, Inference, ModelsProtocolPrivate):
# automatically set by the resolver when instantiating the provider
__provider_id__: str
model_store: ModelStore | None = None

def __init__(self, config: VLLMInferenceAdapterConfig) -> None:
self.register_helper = ModelRegistryHelper(build_hf_repo_model_entries())
self.config = config
self.client = None

async def initialize(self) -> None:
if not self.config.url:
@@ -308,8 +300,6 @@ async def should_refresh_models(self) -> bool:
return self.config.refresh_models

async def list_models(self) -> list[Model] | None:
self._lazy_initialize_client()
assert self.client is not None # mypy
models = []
async for m in self.client.models.list():
model_type = ModelType.llm # unclear how to determine embedding vs. llm models
@@ -340,8 +330,7 @@ async def health(self) -> HealthResponse:
HealthResponse: A dictionary containing the health status.
"""
try:
client = self._create_client() if self.client is None else self.client
_ = [m async for m in client.models.list()] # Ensure the client is initialized
_ = [m async for m in self.client.models.list()] # Ensure the client is initialized
return HealthResponse(status=HealthStatus.OK)
except Exception as e:
return HealthResponse(status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}")
@@ -351,19 +340,14 @@ async def _get_model(self, model_id: str) -> Model:
raise ValueError("Model store not set")
return await self.model_store.get_model(model_id)

def _lazy_initialize_client(self):
if self.client is not None:
return
def get_api_key(self):
return self.config.api_token

log.info(f"Initializing vLLM client with base_url={self.config.url}")
self.client = self._create_client()
def get_base_url(self):
return self.config.url

def _create_client(self):
return AsyncOpenAI(
base_url=self.config.url,
api_key=self.config.api_token,
http_client=httpx.AsyncClient(verify=self.config.tls_verify),
)
def get_extra_client_params(self):
return {"http_client": httpx.AsyncClient(verify=self.config.tls_verify)}

async def completion(
self,
@@ -374,7 +358,6 @@ async def completion(
stream: bool | None = False,
logprobs: LogProbConfig | None = None,
) -> CompletionResponse | AsyncGenerator[CompletionResponseStreamChunk, None]:
self._lazy_initialize_client()
if sampling_params is None:
sampling_params = SamplingParams()
model = await self._get_model(model_id)
@@ -406,7 +389,6 @@ async def chat_completion(
logprobs: LogProbConfig | None = None,
tool_config: ToolConfig | None = None,
) -> ChatCompletionResponse | AsyncGenerator[ChatCompletionResponseStreamChunk, None]:
self._lazy_initialize_client()
if sampling_params is None:
sampling_params = SamplingParams()
model = await self._get_model(model_id)
@@ -479,16 +461,12 @@ async def _stream_completion(
yield chunk

async def register_model(self, model: Model) -> Model:
# register_model is called during Llama Stack initialization, hence we cannot init self.client if not initialized yet.
# self.client should only be created after the initialization is complete to avoid asyncio cross-context errors.
# Changing this may lead to unpredictable behavior.
client = self._create_client() if self.client is None else self.client
try:
model = await self.register_helper.register_model(model)
except ValueError:
pass # Ignore statically unknown model, will check live listing
try:
res = await client.models.list()
res = await self.client.models.list()
except APIConnectionError as e:
raise ValueError(
f"Failed to connect to vLLM at {self.config.url}. Please check if vLLM is running and accessible at that URL."
Expand Down Expand Up @@ -543,8 +521,6 @@ async def embeddings(
output_dimension: int | None = None,
task_type: EmbeddingTaskType | None = None,
) -> EmbeddingsResponse:
self._lazy_initialize_client()
assert self.client is not None
model = await self._get_model(model_id)

kwargs = {}
@@ -560,154 +536,3 @@

embeddings = [data.embedding for data in response.data]
return EmbeddingsResponse(embeddings=embeddings)

async def openai_embeddings(
self,
model: str,
input: str | list[str],
encoding_format: str | None = "float",
dimensions: int | None = None,
user: str | None = None,
) -> OpenAIEmbeddingsResponse:
self._lazy_initialize_client()
assert self.client is not None
model_obj = await self._get_model(model)
assert model_obj.model_type == ModelType.embedding

# Convert input to list if it's a string
input_list = [input] if isinstance(input, str) else input

# Call vLLM embeddings endpoint with encoding_format
response = await self.client.embeddings.create(
model=model_obj.provider_resource_id,
input=input_list,
dimensions=dimensions,
encoding_format=encoding_format,
)

# Convert response to OpenAI format
data = [
OpenAIEmbeddingData(
embedding=embedding_data.embedding,
index=i,
)
for i, embedding_data in enumerate(response.data)
]

# Not returning actual token usage since vLLM doesn't provide it
usage = OpenAIEmbeddingUsage(prompt_tokens=-1, total_tokens=-1)

return OpenAIEmbeddingsResponse(
data=data,
model=model_obj.provider_resource_id,
usage=usage,
)

async def openai_completion(
self,
model: str,
prompt: str | list[str] | list[int] | list[list[int]],
best_of: int | None = None,
echo: bool | None = None,
frequency_penalty: float | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_tokens: int | None = None,
n: int | None = None,
presence_penalty: float | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
top_p: float | None = None,
user: str | None = None,
guided_choice: list[str] | None = None,
prompt_logprobs: int | None = None,
suffix: str | None = None,
) -> OpenAICompletion:
self._lazy_initialize_client()
model_obj = await self._get_model(model)

extra_body: dict[str, Any] = {}
if prompt_logprobs is not None and prompt_logprobs >= 0:
extra_body["prompt_logprobs"] = prompt_logprobs
if guided_choice:
extra_body["guided_choice"] = guided_choice

params = await prepare_openai_completion_params(
model=model_obj.provider_resource_id,
prompt=prompt,
best_of=best_of,
echo=echo,
frequency_penalty=frequency_penalty,
logit_bias=logit_bias,
logprobs=logprobs,
max_tokens=max_tokens,
n=n,
presence_penalty=presence_penalty,
seed=seed,
stop=stop,
stream=stream,
stream_options=stream_options,
temperature=temperature,
top_p=top_p,
user=user,
extra_body=extra_body,
)
return await self.client.completions.create(**params) # type: ignore

async def openai_chat_completion(
self,
model: str,
messages: list[OpenAIMessageParam],
frequency_penalty: float | None = None,
function_call: str | dict[str, Any] | None = None,
functions: list[dict[str, Any]] | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_completion_tokens: int | None = None,
max_tokens: int | None = None,
n: int | None = None,
parallel_tool_calls: bool | None = None,
presence_penalty: float | None = None,
response_format: OpenAIResponseFormatParam | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
tool_choice: str | dict[str, Any] | None = None,
tools: list[dict[str, Any]] | None = None,
top_logprobs: int | None = None,
top_p: float | None = None,
user: str | None = None,
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
self._lazy_initialize_client()
model_obj = await self._get_model(model)
params = await prepare_openai_completion_params(
model=model_obj.provider_resource_id,
messages=messages,
frequency_penalty=frequency_penalty,
function_call=function_call,
functions=functions,
logit_bias=logit_bias,
logprobs=logprobs,
max_completion_tokens=max_completion_tokens,
max_tokens=max_tokens,
n=n,
parallel_tool_calls=parallel_tool_calls,
presence_penalty=presence_penalty,
response_format=response_format,
seed=seed,
stop=stop,
stream=stream,
stream_options=stream_options,
temperature=temperature,
tool_choice=tool_choice,
tools=tools,
top_logprobs=top_logprobs,
top_p=top_p,
user=user,
)
return await self.client.chat.completions.create(**params) # type: ignore
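
To make the refactor above concrete, here is a minimal sketch (hypothetical names, not the project's real classes) of the pattern this diff moves the vLLM adapter to: the adapter only supplies connection details via get_api_key, get_base_url, and get_extra_client_params, while the mixin builds the AsyncOpenAI client and provides the openai_* endpoints, so the adapter's _lazy_initialize_client/_create_client bookkeeping can be deleted.

import httpx
from openai import AsyncOpenAI


class ExampleOpenAIMixin:
    # Illustrative stand-in for llama_stack's OpenAIMixin (simplified).
    def get_api_key(self) -> str:
        raise NotImplementedError

    def get_base_url(self) -> str:
        raise NotImplementedError

    def get_extra_client_params(self) -> dict:
        # Default: no extra client parameters.
        return {}

    @property
    def client(self) -> AsyncOpenAI:
        # A fresh client is built from the hook methods on each access,
        # so the adapter no longer tracks client initialization state.
        return AsyncOpenAI(
            api_key=self.get_api_key(),
            base_url=self.get_base_url(),
            **self.get_extra_client_params(),
        )


class ExampleVLLMAdapter(ExampleOpenAIMixin):
    # Hypothetical adapter mirroring the vLLM changes in this diff.
    def __init__(self, url: str, api_token: str, tls_verify: bool = True) -> None:
        self.url = url
        self.api_token = api_token
        self.tls_verify = tls_verify

    def get_api_key(self) -> str:
        return self.api_token

    def get_base_url(self) -> str:
        return self.url

    def get_extra_client_params(self) -> dict:
        return {"http_client": httpx.AsyncClient(verify=self.tls_verify)}
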
28 changes: 23 additions & 5 deletions llama_stack/providers/utils/inference/openai_mixin.py
@@ -67,6 +67,17 @@ def get_base_url(self) -> str:
"""
pass

def get_extra_client_params(self) -> dict[str, Any]:
"""
Get any extra parameters to pass to the AsyncOpenAI client.

Child classes can override this method to provide additional parameters
such as timeout settings, proxies, etc.

:return: A dictionary of extra parameters
"""
return {}

@property
def client(self) -> AsyncOpenAI:
"""
@@ -78,6 +89,7 @@ def client(self) -> AsyncOpenAI:
return AsyncOpenAI(
api_key=self.get_api_key(),
base_url=self.get_base_url(),
**self.get_extra_client_params(),
)

async def _get_provider_model_id(self, model: str) -> str:
@@ -124,10 +136,15 @@ async def openai_completion(
"""
Direct OpenAI completion API call.
"""
if guided_choice is not None:
logger.warning("guided_choice is not supported by the OpenAI API. Ignoring.")
if prompt_logprobs is not None:
logger.warning("prompt_logprobs is not supported by the OpenAI API. Ignoring.")
# Handle parameters that are not supported by OpenAI API, but may be by the provider
# prompt_logprobs is supported by vLLM
# guided_choice is supported by vLLM
# TODO: test coverage
extra_body: dict[str, Any] = {}
if prompt_logprobs is not None and prompt_logprobs >= 0:
extra_body["prompt_logprobs"] = prompt_logprobs
if guided_choice:
extra_body["guided_choice"] = guided_choice

# TODO: fix openai_completion to return type compatible with OpenAI's API response
return await self.client.completions.create( # type: ignore[no-any-return]
Expand All @@ -150,7 +167,8 @@ async def openai_completion(
top_p=top_p,
user=user,
suffix=suffix,
)
),
extra_body=extra_body,
)

async def openai_chat_completion(
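
A hedged usage sketch of the extra_body forwarding added above: guided_choice and prompt_logprobs are no longer dropped with a warning but are passed through extra_body, where vLLM understands them. The adapter instance, the model id, and the choices[0].text access are illustrative assumptions, not part of this diff.

import asyncio


async def demo(adapter) -> None:
    # `adapter` is assumed to be a provider instance built on OpenAIMixin.
    completion = await adapter.openai_completion(
        model="my-vllm-model",            # hypothetical model id
        prompt="Is the sky blue or green?",
        guided_choice=["blue", "green"],  # forwarded to vLLM via extra_body
        prompt_logprobs=1,                # forwarded to vLLM via extra_body
    )
    print(completion.choices[0].text)


# asyncio.run(demo(adapter))  # requires a configured adapter instance
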