58 changes: 43 additions & 15 deletions llama_stack/providers/remote/inference/tgi/tgi.py
@@ -8,6 +8,7 @@
from collections.abc import AsyncGenerator

from huggingface_hub import AsyncInferenceClient, HfApi
from pydantic import SecretStr

from llama_stack.apis.common.content_types import (
InterleavedContent,
@@ -33,6 +34,7 @@
ToolPromptFormat,
)
from llama_stack.apis.models import Model
from llama_stack.apis.models.models import ModelType
from llama_stack.log import get_logger
from llama_stack.models.llama.sku_list import all_registered_models
from llama_stack.providers.datatypes import ModelsProtocolPrivate
@@ -41,16 +43,15 @@
build_hf_repo_model_entry,
)
from llama_stack.providers.utils.inference.openai_compat import (
OpenAIChatCompletionToLlamaStackMixin,
OpenAICompatCompletionChoice,
OpenAICompatCompletionResponse,
OpenAICompletionToLlamaStackMixin,
get_sampling_options,
process_chat_completion_response,
process_chat_completion_stream_response,
process_completion_response,
process_completion_stream_response,
)
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_model_input_info,
completion_request_to_prompt_model_input_info,
@@ -73,26 +74,49 @@ def build_hf_repo_model_entries():


class _HfAdapter(
OpenAIMixin,
Inference,
OpenAIChatCompletionToLlamaStackMixin,
OpenAICompletionToLlamaStackMixin,
ModelsProtocolPrivate,
):
client: AsyncInferenceClient
url: str
api_key: SecretStr

hf_client: AsyncInferenceClient
max_tokens: int
model_id: str

overwrite_completion_id = True # TGI always returns id=""

def __init__(self) -> None:
self.register_helper = ModelRegistryHelper(build_hf_repo_model_entries())
self.huggingface_repo_to_llama_model_id = {
model.huggingface_repo: model.descriptor() for model in all_registered_models() if model.huggingface_repo
}

def get_api_key(self):
return self.api_key.get_secret_value()

def get_base_url(self):
return self.url

async def shutdown(self) -> None:
pass

async def list_models(self) -> list[Model] | None:
models = []
async for model in self.client.models.list():
models.append(
Model(
identifier=model.id,
provider_resource_id=model.id,
provider_id=self.__provider_id__,
metadata={},
model_type=ModelType.llm,
)
)
return models

async def register_model(self, model: Model) -> Model:
model = await self.register_helper.register_model(model)
if model.provider_resource_id != self.model_id:
raise ValueError(
f"Model {model.provider_resource_id} does not match the model {self.model_id} served by TGI."
@@ -176,7 +200,7 @@ async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator
params = await self._get_params_for_completion(request)

async def _generate_and_convert_to_openai_compat():
s = await self.client.text_generation(**params)
s = await self.hf_client.text_generation(**params)
async for chunk in s:
token_result = chunk.token
finish_reason = None
@@ -194,7 +218,7 @@ async def _generate_and_convert_to_openai_compat():

async def _nonstream_completion(self, request: CompletionRequest) -> AsyncGenerator:
params = await self._get_params_for_completion(request)
r = await self.client.text_generation(**params)
r = await self.hf_client.text_generation(**params)

choice = OpenAICompatCompletionChoice(
finish_reason=r.details.finish_reason,
@@ -241,7 +265,7 @@ async def chat_completion(

async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
params = await self._get_params(request)
r = await self.client.text_generation(**params)
r = await self.hf_client.text_generation(**params)

choice = OpenAICompatCompletionChoice(
finish_reason=r.details.finish_reason,
@@ -256,7 +280,7 @@ async def _stream_chat_completion(self, request: ChatCompletionRequest) -> Async
params = await self._get_params(request)

async def _generate_and_convert_to_openai_compat():
s = await self.client.text_generation(**params)
s = await self.hf_client.text_generation(**params)
async for chunk in s:
token_result = chunk.token

@@ -308,18 +332,21 @@ async def initialize(self, config: TGIImplConfig) -> None:
if not config.url:
raise ValueError("You must provide a URL in run.yaml (or via the TGI_URL environment variable) to use TGI.")
log.info(f"Initializing TGI client with url={config.url}")
self.client = AsyncInferenceClient(model=config.url, provider="hf-inference")
endpoint_info = await self.client.get_endpoint_info()
self.hf_client = AsyncInferenceClient(model=config.url, provider="hf-inference")
endpoint_info = await self.hf_client.get_endpoint_info()
self.max_tokens = endpoint_info["max_total_tokens"]
self.model_id = endpoint_info["model_id"]
self.url = f"{config.url.rstrip('/')}/v1"
self.api_key = SecretStr("NO_KEY")


class InferenceAPIAdapter(_HfAdapter):
async def initialize(self, config: InferenceAPIImplConfig) -> None:
self.client = AsyncInferenceClient(model=config.huggingface_repo, token=config.api_token.get_secret_value())
endpoint_info = await self.client.get_endpoint_info()
self.hf_client = AsyncInferenceClient(model=config.huggingface_repo, token=config.api_token.get_secret_value())
endpoint_info = await self.hf_client.get_endpoint_info()
self.max_tokens = endpoint_info["max_total_tokens"]
self.model_id = endpoint_info["model_id"]
# TODO: how do we set url for this?


class InferenceEndpointAdapter(_HfAdapter):
@@ -331,6 +358,7 @@ async def initialize(self, config: InferenceEndpointImplConfig) -> None:
endpoint.wait(timeout=60)

# Initialize the adapter
self.client = endpoint.async_client
self.hf_client = endpoint.async_client
self.model_id = endpoint.repository
self.max_tokens = int(endpoint.raw["model"]["image"]["custom"]["env"]["MAX_TOTAL_TOKENS"])
# TODO: how do we set url for this?
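Taken together, the TGI adapter now satisfies the OpenAIMixin contract: it exposes an API key and an OpenAI-compatible base URL, and opts into client-side id generation because TGI returns an empty id. A minimal, illustrative sketch of that wiring is shown below. This is not code from the PR; the class name and constructor are assumptions, and OpenAIMixin may require additional hooks not shown here.

```python
from pydantic import SecretStr

from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin


class ExampleAdapter(OpenAIMixin):
    """Hypothetical adapter sketch mirroring _HfAdapter's OpenAIMixin wiring."""

    # Ask the mixin to replace the empty ids some servers (e.g. TGI) return.
    overwrite_completion_id = True

    def __init__(self, url: str, api_key: str | None = None) -> None:
        # The mixin talks to an OpenAI-compatible endpoint, so point it at /v1.
        self.url = f"{url.rstrip('/')}/v1"
        # TGI needs no key, so a placeholder secret is stored instead.
        self.api_key = SecretStr(api_key or "NO_KEY")

    def get_api_key(self) -> str:
        return self.api_key.get_secret_value()

    def get_base_url(self) -> str:
        return self.url
```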
33 changes: 30 additions & 3 deletions llama_stack/providers/utils/inference/openai_mixin.py
@@ -4,6 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import uuid
from abc import ABC, abstractmethod
from collections.abc import AsyncIterator
from typing import Any
@@ -43,6 +44,12 @@ class OpenAIMixin(ABC):
The model_store is set in routing_tables/common.py during provider initialization.
"""

# Allow subclasses to control whether the 'id' field in OpenAI responses
# is overwritten with a client-side generated id.
#
# This is useful for providers that do not return a unique id in the response.
overwrite_completion_id: bool = False

@abstractmethod
def get_api_key(self) -> str:
"""
@@ -98,6 +105,23 @@ async def _get_provider_model_id(self, model: str) -> str:
raise ValueError(f"Model {model} has no provider_resource_id")
return model_obj.provider_resource_id

async def _maybe_overwrite_id(self, resp: Any, stream: bool | None) -> Any:
if not self.overwrite_completion_id:
return resp

new_id = f"cltsd-{uuid.uuid4()}"
if stream:

async def _gen():
async for chunk in resp:
chunk.id = new_id
yield chunk

return _gen()
else:
resp.id = new_id
return resp

async def openai_completion(
self,
model: str,
@@ -130,7 +154,7 @@ async def openai_completion(
logger.warning("prompt_logprobs is not supported by the OpenAI API. Ignoring.")

# TODO: fix openai_completion to return type compatible with OpenAI's API response
return await self.client.completions.create( # type: ignore[no-any-return]
resp = await self.client.completions.create(
**await prepare_openai_completion_params(
model=await self._get_provider_model_id(model),
prompt=prompt,
@@ -153,6 +177,8 @@
)
)

return await self._maybe_overwrite_id(resp, stream) # type: ignore[no-any-return]

async def openai_chat_completion(
self,
model: str,
@@ -182,8 +208,7 @@ async def openai_chat_completion(
"""
Direct OpenAI chat completion API call.
"""
# Type ignore because return types are compatible
return await self.client.chat.completions.create( # type: ignore[no-any-return]
resp = await self.client.chat.completions.create(
**await prepare_openai_completion_params(
model=await self._get_provider_model_id(model),
messages=messages,
@@ -211,6 +236,8 @@
)
)

return await self._maybe_overwrite_id(resp, stream) # type: ignore[no-any-return]

async def openai_embeddings(
self,
model: str,
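The new `_maybe_overwrite_id` helper replaces the response id with a client-generated `cltsd-<uuid>` value; for streaming responses it wraps the stream in a small async generator so every chunk carries the same id. Below is a self-contained sketch of that behaviour, using a stand-in chunk type rather than the real OpenAI SDK objects.

```python
import asyncio
import uuid
from dataclasses import dataclass


@dataclass
class FakeChunk:
    # Stand-in for an OpenAI streaming chunk; TGI leaves id empty.
    id: str
    content: str


async def fake_stream():
    yield FakeChunk(id="", content="Hello")
    yield FakeChunk(id="", content=", world")


async def overwrite_ids(stream):
    # Same idea as OpenAIMixin._maybe_overwrite_id for the streaming case:
    # one generated id is applied to every chunk of the stream.
    new_id = f"cltsd-{uuid.uuid4()}"
    async for chunk in stream:
        chunk.id = new_id
        yield chunk


async def main():
    async for chunk in overwrite_ids(fake_stream()):
        print(chunk.id, chunk.content)


asyncio.run(main())
```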
3 changes: 1 addition & 2 deletions tests/integration/inference/test_openai_completion.py
@@ -48,7 +48,6 @@ def skip_if_model_doesnt_support_openai_completion(client_with_models, model_id)
"remote::nvidia",
"remote::runpod",
"remote::sambanova",
"remote::tgi",
"remote::vertexai",
# {"error":{"message":"Unknown request URL: GET /openai/v1/completions. Please check the URL for typos,
# or see the docs at https://console.groq.com/docs/","type":"invalid_request_error","code":"unknown_url"}}
@@ -96,6 +95,7 @@ def skip_if_doesnt_support_n(client_with_models, model_id):
"remote::vertexai",
# Error code: 400 - [{'error': {'code': 400, 'message': 'Unable to submit request because candidateCount must be 1 but
# the entered value was 2. Update the candidateCount value and try again.', 'status': 'INVALID_ARGUMENT'}
"remote::tgi", # TGI ignores n param silently
):
pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support n param.")

@@ -110,7 +110,6 @@ def skip_if_model_doesnt_support_openai_chat_completion(client_with_models, mode
"remote::cerebras",
"remote::databricks",
"remote::runpod",
"remote::tgi",
):
pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support OpenAI chat completions.")

56 changes: 56 additions & 0 deletions tests/integration/recordings/responses/27463384d1a3.json
@@ -0,0 +1,56 @@
{
"request": {
"method": "POST",
"url": "http://localhost:8080/v1/v1/chat/completions",
"headers": {},
"body": {
"model": "Qwen/Qwen3-0.6B",
"messages": [
{
"role": "user",
"content": "Hello, world!"
}
],
"stream": false
},
"endpoint": "/v1/chat/completions",
"model": "Qwen/Qwen3-0.6B"
},
"response": {
"body": {
"__type__": "openai.types.chat.chat_completion.ChatCompletion",
"__data__": {
"id": "",
"choices": [
{
"finish_reason": "stop",
"index": 0,
"logprobs": null,
"message": {
"content": "<think>\nOkay, the user just said \"Hello, world!\" so I need to respond in a friendly way. My prompt says to respond in the same style, so I should start with \"Hello, world!\" but maybe add some helpful information. Let me think. Since the user is probably testing or just sharing, a simple \"Hello, world!\" with a question would be best for user interaction. I'll make sure to keep it positive and open-ended.\n</think>\n\nHello, world! \ud83d\ude0a What do you need today?",
"refusal": null,
"role": "assistant",
"annotations": null,
"audio": null,
"function_call": null,
"tool_calls": null
}
}
],
"created": 1757550395,
"model": "Qwen/Qwen3-0.6B",
"object": "chat.completion",
"service_tier": null,
"system_fingerprint": "3.3.5-dev0-sha-1b90c50",
"usage": {
"completion_tokens": 108,
"prompt_tokens": 12,
"total_tokens": 120,
"completion_tokens_details": null,
"prompt_tokens_details": null
}
}
},
"is_streaming": false
}
}