From 4769ee673c851b72b0839c6e6f7f4f8cb5462d7a Mon Sep 17 00:00:00 2001
From: Egor Sklyarov
Date: Wed, 31 Jan 2024 12:42:11 +0100
Subject: [PATCH] Implement OpenAI to OpenAI adapter for gateway (#860)

* Implement openai-to-openai adapter

* Disable SSE buffering, adjust OpenAI models for vLLM responses
---
 gateway/src/dstack/gateway/common.py          |  9 +++-
 .../dstack/gateway/openai/clients/__init__.py |  2 +
 .../dstack/gateway/openai/clients/openai.py   | 29 ++++++++++++
 gateway/src/dstack/gateway/openai/models.py   | 20 ++++--
 gateway/src/dstack/gateway/openai/routes.py   |  1 +
 gateway/src/dstack/gateway/openai/schemas.py  |  6 +--
 gateway/src/dstack/gateway/openai/store.py    |  6 +++
 .../_internal/core/models/configurations.py   | 23 +--------
 src/dstack/_internal/core/models/gateways.py  | 47 ++++++++++++++++++-
 .../jobs/configurators/extensions/openai.py   | 22 +++++----
 src/dstack/api/__init__.py                    |  3 +-
 11 files changed, 125 insertions(+), 43 deletions(-)
 create mode 100644 gateway/src/dstack/gateway/openai/clients/openai.py

diff --git a/gateway/src/dstack/gateway/common.py b/gateway/src/dstack/gateway/common.py
index 3b82fb10f..36b95da72 100644
--- a/gateway/src/dstack/gateway/common.py
+++ b/gateway/src/dstack/gateway/common.py
@@ -1,6 +1,7 @@
 import asyncio
 import functools
-from typing import Callable, ParamSpec, TypeVar
+import json
+from typing import Any, AsyncIterator, Callable, Dict, ParamSpec, TypeVar
 
 import httpx
 
@@ -19,3 +20,9 @@ def __del__(self):
             asyncio.get_running_loop().create_task(self.aclose())
         except Exception:
             pass
+
+    async def stream_sse(self, url: str, **kwargs) -> AsyncIterator[Dict[str, Any]]:
+        async with self.stream("POST", url, **kwargs) as resp:
+            async for line in resp.aiter_lines():
+                if line.startswith("data:"):
+                    yield json.loads(line[len("data:") :].strip("\n"))
diff --git a/gateway/src/dstack/gateway/openai/clients/__init__.py b/gateway/src/dstack/gateway/openai/clients/__init__.py
index 1ce1d4eaa..ed136e31b 100644
--- a/gateway/src/dstack/gateway/openai/clients/__init__.py
+++ b/gateway/src/dstack/gateway/openai/clients/__init__.py
@@ -7,6 +7,8 @@
     ChatCompletionsResponse,
 )
 
+DEFAULT_TIMEOUT = 60
+
 
 class ChatCompletionsClient(ABC):
     @abstractmethod
diff --git a/gateway/src/dstack/gateway/openai/clients/openai.py b/gateway/src/dstack/gateway/openai/clients/openai.py
new file mode 100644
index 000000000..41c5d000b
--- /dev/null
+++ b/gateway/src/dstack/gateway/openai/clients/openai.py
@@ -0,0 +1,29 @@
+from typing import AsyncIterator, Optional
+
+from dstack.gateway.common import AsyncClientWrapper
+from dstack.gateway.errors import GatewayError
+from dstack.gateway.openai.clients import DEFAULT_TIMEOUT, ChatCompletionsClient
+from dstack.gateway.openai.schemas import (
+    ChatCompletionsChunk,
+    ChatCompletionsRequest,
+    ChatCompletionsResponse,
+)
+
+
+class OpenAIChatCompletions(ChatCompletionsClient):
+    def __init__(self, base_url: str, host: Optional[str] = None):
+        self.client = AsyncClientWrapper(
+            base_url=base_url.rstrip("/"),
+            headers={} if host is None else {"Host": host},
+            timeout=DEFAULT_TIMEOUT,
+        )
+
+    async def generate(self, request: ChatCompletionsRequest) -> ChatCompletionsResponse:
+        resp = await self.client.post("/chat/completions", json=request.model_dump())
+        if resp.status_code != 200:
+            raise GatewayError(resp.text)
+        return ChatCompletionsResponse.model_validate(resp.json())
+
+    async def stream(self, request: ChatCompletionsRequest) -> AsyncIterator[ChatCompletionsChunk]:
+        async for data in self.client.stream_sse("/chat/completions", json=request.model_dump()):
+            yield ChatCompletionsChunk.model_validate(data)
diff --git a/gateway/src/dstack/gateway/openai/models.py b/gateway/src/dstack/gateway/openai/models.py
index 6958e6e5d..b32d22966 100644
--- a/gateway/src/dstack/gateway/openai/models.py
+++ b/gateway/src/dstack/gateway/openai/models.py
@@ -3,22 +3,32 @@
 from pydantic import BaseModel, Field
 
 
-class ChatModel(BaseModel):
-    type: Literal["chat"] = "chat"
+class BaseChatModel(BaseModel):
+    type: Literal["chat"]
     name: str
+    format: str
+
+
+class TGIChatModel(BaseChatModel):
     format: Literal["tgi"]
     chat_template: str
     eos_token: str
 
 
-ANY_MODEL = Union[ChatModel]
+class OpenAIChatModel(BaseChatModel):
+    format: Literal["openai"]
+    prefix: str
+
+
+ChatModel = Annotated[Union[TGIChatModel, OpenAIChatModel], Field(discriminator="format")]
+AnyModel = Union[ChatModel]  # embeddings and etc.
 
 
 class OpenAIOptions(BaseModel):
-    model: Annotated[ANY_MODEL, Field(discriminator="type")]
+    model: AnyModel  # TODO(egor-s): add discriminator by type
 
 
 class ServiceModel(BaseModel):
-    model: ANY_MODEL
+    model: AnyModel
     domain: str
     created: int
diff --git a/gateway/src/dstack/gateway/openai/routes.py b/gateway/src/dstack/gateway/openai/routes.py
index 4bda12cde..14338de50 100644
--- a/gateway/src/dstack/gateway/openai/routes.py
+++ b/gateway/src/dstack/gateway/openai/routes.py
@@ -46,6 +46,7 @@ async def post_chat_completions(
             return StreamingResponse(
                 stream_chunks(client.stream(body)),
                 media_type="text/event-stream",
+                headers={"X-Accel-Buffering": "no"},
             )
     except GatewayError as e:
         raise e.http()
diff --git a/gateway/src/dstack/gateway/openai/schemas.py b/gateway/src/dstack/gateway/openai/schemas.py
index 0ad17a10e..81786e972 100644
--- a/gateway/src/dstack/gateway/openai/schemas.py
+++ b/gateway/src/dstack/gateway/openai/schemas.py
@@ -37,7 +37,7 @@ class ChatCompletionsChoice(BaseModel):
 
 class ChatCompletionsChunkChoice(BaseModel):
     delta: object
-    logprobs: Optional[object]
+    logprobs: object = {}
     finish_reason: Optional[FinishReason]
     index: int
 
@@ -53,7 +53,7 @@ class ChatCompletionsResponse(BaseModel):
     choices: List[ChatCompletionsChoice]
     created: int
     model: str
-    system_fingerprint: str
+    system_fingerprint: str = ""
     object: Literal["chat.completion"] = "chat.completion"
     usage: ChatCompletionsUsage
 
@@ -63,7 +63,7 @@ class ChatCompletionsChunk(BaseModel):
     choices: List[ChatCompletionsChunkChoice]
     created: int
     model: str
-    system_fingerprint: str
+    system_fingerprint: str = ""
     object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
 
 
diff --git a/gateway/src/dstack/gateway/openai/store.py b/gateway/src/dstack/gateway/openai/store.py
index 8faeefafa..c364be350 100644
--- a/gateway/src/dstack/gateway/openai/store.py
+++ b/gateway/src/dstack/gateway/openai/store.py
@@ -7,6 +7,7 @@
 from dstack.gateway.errors import GatewayError, NotFoundError
 from dstack.gateway.openai.clients import ChatCompletionsClient
+from dstack.gateway.openai.clients.openai import OpenAIChatCompletions
 from dstack.gateway.openai.clients.tgi import TGIChatCompletions
 from dstack.gateway.openai.models import OpenAIOptions, ServiceModel
 from dstack.gateway.openai.schemas import Model
 
@@ -81,6 +82,11 @@ async def get_chat_client(self, project: str, model_name: str) -> ChatCompletion
                 chat_template=service.model.chat_template,
                 eos_token=service.model.eos_token,
             )
+        elif service.model.format == "openai":
+            return OpenAIChatCompletions(
+                base_url=f"http://localhost/{service.model.prefix.lstrip('/')}",
+                host=service.domain,
+            )
         else:
            raise GatewayError(f"Unsupported model format: {service.model.format}")
GatewayError(f"Unsupported model format: {service.model.format}") diff --git a/src/dstack/_internal/core/models/configurations.py b/src/dstack/_internal/core/models/configurations.py index 9fdcaa63e..a95105503 100644 --- a/src/dstack/_internal/core/models/configurations.py +++ b/src/dstack/_internal/core/models/configurations.py @@ -7,6 +7,7 @@ from dstack._internal.core.errors import ConfigurationError from dstack._internal.core.models.common import ForbidExtra +from dstack._internal.core.models.gateways import AnyModel from dstack._internal.core.models.repos.base import Repo from dstack._internal.core.models.repos.virtual import VirtualRepo from dstack._internal.core.models.resources import ResourcesSpec @@ -78,26 +79,6 @@ class Artifact(ForbidExtra): ] = False -class ModelInfo(ForbidExtra): - """ - Mapping of the model for the OpenAI-compatible endpoint. - - Attributes: - type (str): The type of the model, e.g. "chat" - name (str): The name of the model. This name will be used both to load model configuration from the HuggingFace Hub and in the OpenAI-compatible endpoint. - format (str): The format of the model, e.g. "tgi" if the model is served with HuggingFace's Text Generation Inference. - chat_template (Optional[str]): The custom prompt template for the model. If not specified, the default prompt template the HuggingFace Hub configuration will be used. - eos_token (Optional[str]): The custom end of sentence token. If not specified, the default custom end of sentence token from the HuggingFace Hub configuration will be used. - """ - - type: Annotated[Literal["chat"], Field(description="The type of the model")] - name: Annotated[str, Field(description="The name of the model")] - format: Annotated[Literal["tgi"], Field(description="The serving format")] - - chat_template: Optional[str] = None # TODO(egor-s): use discriminator and root model - eos_token: Optional[str] = None - - class BaseConfiguration(ForbidExtra): type: Literal["none"] image: Annotated[Optional[str], Field(description="The name of the Docker image to run")] @@ -206,7 +187,7 @@ class ServiceConfiguration(BaseConfiguration): Field(description="The port, that application listens to or the mapping"), ] model: Annotated[ - Optional[ModelInfo], + Optional[AnyModel], Field(description="Mapping of the model for the OpenAI-compatible endpoint"), ] = None auth: Annotated[bool, Field(description="Enable the authorization")] = True diff --git a/src/dstack/_internal/core/models/gateways.py b/src/dstack/_internal/core/models/gateways.py index 9b7da4089..8caf8d36a 100644 --- a/src/dstack/_internal/core/models/gateways.py +++ b/src/dstack/_internal/core/models/gateways.py @@ -1,7 +1,8 @@ import datetime -from typing import Optional +from typing import Optional, Union -from pydantic import BaseModel +from pydantic import BaseModel, Field +from typing_extensions import Annotated, Literal from dstack._internal.core.models.backends.base import BackendType @@ -15,3 +16,45 @@ class Gateway(BaseModel): default: bool created_at: datetime.datetime backend: BackendType + + +class BaseChatModel(BaseModel): + type: Annotated[Literal["chat"], Field(description="The type of the model")] + name: Annotated[str, Field(description="The name of the model")] + format: Annotated[str, Field(description="The serving format")] + + +class TGIChatModel(BaseChatModel): + """ + Mapping of the model for the OpenAI-compatible endpoint. + + Attributes: + type (str): The type of the model, e.g. "chat" + name (str): The name of the model. 
+        name (str): The name of the model. This name will be used both to load model configuration from the HuggingFace Hub and in the OpenAI-compatible endpoint.
+        format (str): The format of the model, e.g. "tgi" if the model is served with HuggingFace's Text Generation Inference.
+        chat_template (Optional[str]): The custom prompt template for the model. If not specified, the default prompt template from the HuggingFace Hub configuration will be used.
+        eos_token (Optional[str]): The custom end of sentence token. If not specified, the default custom end of sentence token from the HuggingFace Hub configuration will be used.
+    """
+
+    format: Literal["tgi"]
+    chat_template: Optional[str] = None  # will be set before registering the service
+    eos_token: Optional[str] = None
+
+
+class OpenAIChatModel(BaseChatModel):
+    """
+    Mapping of the model for the OpenAI-compatible endpoint.
+
+    Attributes:
+        type (str): The type of the model, e.g. "chat"
+        name (str): The name of the model. This name will be used both to load model configuration from the HuggingFace Hub and in the OpenAI-compatible endpoint.
+        format (str): The format of the model, i.e. "openai".
+        prefix (str): The `base_url` prefix: `http://hostname/{prefix}/chat/completions`. Defaults to `/v1`.
+    """
+
+    format: Literal["openai"]
+    prefix: Annotated[str, Field(description="The `base_url` prefix (after hostname)")] = "/v1"
+
+
+ChatModel = Annotated[Union[TGIChatModel, OpenAIChatModel], Field(discriminator="format")]
+AnyModel = Annotated[Union[ChatModel], Field(discriminator="type")]  # embeddings and etc.
diff --git a/src/dstack/_internal/server/services/jobs/configurators/extensions/openai.py b/src/dstack/_internal/server/services/jobs/configurators/extensions/openai.py
index d71060378..aad991f35 100644
--- a/src/dstack/_internal/server/services/jobs/configurators/extensions/openai.py
+++ b/src/dstack/_internal/server/services/jobs/configurators/extensions/openai.py
@@ -1,19 +1,23 @@
 import requests
 
 from dstack._internal.core.errors import ConfigurationError
-from dstack._internal.core.models.configurations import ModelInfo
+from dstack._internal.core.models.gateways import AnyModel
 
 
-def complete_model(model_info: ModelInfo) -> dict:
+def complete_model(model_info: AnyModel) -> dict:
     model_info = model_info.copy(deep=True)
-    # TODO(egor-s): support more types and formats
-    if model_info.chat_template is None or model_info.eos_token is None:
-        tokenizer_config = get_tokenizer_config(model_info.name)
-        if model_info.chat_template is None:
-            model_info.chat_template = tokenizer_config["chat_template"]  # TODO(egor-s): default
-        if model_info.eos_token is None:
-            model_info.eos_token = tokenizer_config["eos_token"]  # TODO(egor-s): default
+    if model_info.type == "chat" and model_info.format == "tgi":
+        if model_info.chat_template is None or model_info.eos_token is None:
+            tokenizer_config = get_tokenizer_config(model_info.name)
+            if model_info.chat_template is None:
+                model_info.chat_template = tokenizer_config[
+                    "chat_template"
+                ]  # TODO(egor-s): default
+            if model_info.eos_token is None:
+                model_info.eos_token = tokenizer_config["eos_token"]  # TODO(egor-s): default
+    elif model_info.type == "chat" and model_info.format == "openai":
+        pass  # nothing to do
 
     return {"model": model_info.dict()}
 
 
diff --git a/src/dstack/api/__init__.py b/src/dstack/api/__init__.py
index 51902f044..08c0f2722 100644
--- a/src/dstack/api/__init__.py
+++ b/src/dstack/api/__init__.py
@@ -1,12 +1,12 @@
 # ruff: noqa: F401
 from dstack._internal.core.errors import ClientError
 from dstack._internal.core.models.backends.base import BackendType
-from dstack._internal.core.models.configurations import ModelInfo as _ModelInfo
 from dstack._internal.core.models.configurations import RegistryAuth
 from dstack._internal.core.models.configurations import (
     ServiceConfiguration as _ServiceConfiguration,
 )
 from dstack._internal.core.models.configurations import TaskConfiguration as _TaskConfiguration
+from dstack._internal.core.models.gateways import OpenAIChatModel, TGIChatModel
 from dstack._internal.core.models.repos.local import LocalRepo
 from dstack._internal.core.models.repos.remote import RemoteRepo
 from dstack._internal.core.models.repos.virtual import VirtualRepo
@@ -18,6 +18,5 @@
 from dstack.api._public.resources import GPU, Disk, Resources
 from dstack.api._public.runs import Run, RunStatus
 
-ModelMapping = _ModelInfo
 Service = _ServiceConfiguration
 Task = _TaskConfiguration
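
For reference, a minimal sketch of how the new format-discriminated ChatModel union resolves a service's model mapping. It is not part of the patch: the classes are re-declared locally (mirroring gateway/src/dstack/gateway/openai/models.py) so the snippet runs with only pydantic v2 installed, and the model name is just a placeholder.

# Standalone sketch (not part of the patch): re-declares the ChatModel union from
# gateway/src/dstack/gateway/openai/models.py and shows how the "format"
# discriminator selects the concrete model class.
from typing import Literal, Union

from pydantic import BaseModel, Field, TypeAdapter
from typing_extensions import Annotated


class BaseChatModel(BaseModel):
    type: Literal["chat"]
    name: str
    format: str


class TGIChatModel(BaseChatModel):
    format: Literal["tgi"]
    chat_template: str
    eos_token: str


class OpenAIChatModel(BaseChatModel):
    format: Literal["openai"]
    prefix: str


ChatModel = Annotated[Union[TGIChatModel, OpenAIChatModel], Field(discriminator="format")]

# The model name below is only an illustrative placeholder.
mapping = TypeAdapter(ChatModel).validate_python(
    {"type": "chat", "name": "my-org/my-model", "format": "openai", "prefix": "/v1"}
)
assert isinstance(mapping, OpenAIChatModel)  # format="openai" selects the OpenAI adapter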