Implement OpenAI to OpenAI adapter for gateway (#860)
* Implement openai-to-openai adapter

* Disable SSE buffering, adjust OpenAI models for vLLM responses
Egor-S committed Jan 31, 2024
1 parent db9a5ae commit 4769ee6
Showing 11 changed files with 125 additions and 43 deletions.
9 changes: 8 additions & 1 deletion gateway/src/dstack/gateway/common.py
@@ -1,6 +1,7 @@
 import asyncio
 import functools
-from typing import Callable, ParamSpec, TypeVar
+import json
+from typing import Any, AsyncIterator, Callable, Dict, ParamSpec, TypeVar
 
 import httpx
 
@@ -19,3 +20,9 @@ def __del__(self):
             asyncio.get_running_loop().create_task(self.aclose())
         except Exception:
             pass
+
+    async def stream_sse(self, url: str, **kwargs) -> AsyncIterator[Dict[str, Any]]:
+        async with self.stream("POST", url, **kwargs) as resp:
+            async for line in resp.aiter_lines():
+                if line.startswith("data:"):
+                    yield json.loads(line[len("data:") :].strip("\n"))
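The new `stream_sse` helper yields one parsed JSON object per SSE `data:` line, skipping comments and blank keep-alive lines (an OpenAI-style `data: [DONE]` sentinel, if sent, would not parse as JSON here). A minimal self-contained sketch of the same parsing logic, using a hypothetical payload rather than a live HTTP response:

```python
import json

# Hypothetical SSE body; in the gateway these lines come from resp.aiter_lines().
sse_body = (
    'data: {"id": "chatcmpl-1", "choices": []}\n'
    ": keep-alive comment\n"
    "\n"
    'data: {"id": "chatcmpl-2", "choices": []}\n'
)

for line in sse_body.splitlines():
    if line.startswith("data:"):  # same filter as stream_sse
        event = json.loads(line[len("data:") :].strip("\n"))
        print(event["id"])  # chatcmpl-1, then chatcmpl-2
```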
2 changes: 2 additions & 0 deletions gateway/src/dstack/gateway/openai/clients/__init__.py
@@ -7,6 +7,8 @@
     ChatCompletionsResponse,
 )
 
+DEFAULT_TIMEOUT = 60
+
 
 class ChatCompletionsClient(ABC):
     @abstractmethod
29 changes: 29 additions & 0 deletions gateway/src/dstack/gateway/openai/clients/openai.py
@@ -0,0 +1,29 @@
+from typing import AsyncIterator, Optional
+
+from dstack.gateway.common import AsyncClientWrapper
+from dstack.gateway.errors import GatewayError
+from dstack.gateway.openai.clients import DEFAULT_TIMEOUT, ChatCompletionsClient
+from dstack.gateway.openai.schemas import (
+    ChatCompletionsChunk,
+    ChatCompletionsRequest,
+    ChatCompletionsResponse,
+)
+
+
+class OpenAIChatCompletions(ChatCompletionsClient):
+    def __init__(self, base_url: str, host: Optional[str] = None):
+        self.client = AsyncClientWrapper(
+            base_url=base_url.rstrip("/"),
+            headers={} if host is None else {"Host": host},
+            timeout=DEFAULT_TIMEOUT,
+        )
+
+    async def generate(self, request: ChatCompletionsRequest) -> ChatCompletionsResponse:
+        resp = await self.client.post("/chat/completions", json=request.model_dump())
+        if resp.status_code != 200:
+            raise GatewayError(resp.text)
+        return ChatCompletionsResponse.model_validate(resp.json())
+
+    async def stream(self, request: ChatCompletionsRequest) -> AsyncIterator[ChatCompletionsChunk]:
+        async for data in self.client.stream_sse("/chat/completions", json=request.model_dump()):
+            yield ChatCompletionsChunk.model_validate(data)
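A hedged usage sketch of the adapter (not part of the commit), assuming `ChatCompletionsRequest` mirrors the standard OpenAI request fields (`model`, `messages`, `stream`); the `host` argument becomes the `Host` header that the gateway's nginx uses to pick the upstream service:

```python
import asyncio

from dstack.gateway.openai.clients.openai import OpenAIChatCompletions
from dstack.gateway.openai.schemas import ChatCompletionsRequest


async def main() -> None:
    # base_url and host are illustrative values
    client = OpenAIChatCompletions(
        base_url="http://localhost/v1",
        host="my-service.example.com",
    )
    request = ChatCompletionsRequest(
        model="my-model",
        messages=[{"role": "user", "content": "Hello!"}],
        stream=True,
    )
    async for chunk in client.stream(request):
        print(chunk.choices[0].delta)


asyncio.run(main())
```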
20 changes: 15 additions & 5 deletions gateway/src/dstack/gateway/openai/models.py
@@ -3,22 +3,32 @@
 from pydantic import BaseModel, Field
 
 
-class ChatModel(BaseModel):
-    type: Literal["chat"] = "chat"
+class BaseChatModel(BaseModel):
+    type: Literal["chat"]
     name: str
     format: str
+
+
+class TGIChatModel(BaseChatModel):
+    format: Literal["tgi"]
     chat_template: str
     eos_token: str
 
 
-ANY_MODEL = Union[ChatModel]
+class OpenAIChatModel(BaseChatModel):
+    format: Literal["openai"]
+    prefix: str
+
+
+ChatModel = Annotated[Union[TGIChatModel, OpenAIChatModel], Field(discriminator="format")]
+AnyModel = Union[ChatModel]  # embeddings and etc.
 
 
 class OpenAIOptions(BaseModel):
-    model: Annotated[ANY_MODEL, Field(discriminator="type")]
+    model: AnyModel  # TODO(egor-s): add discriminator by type
 
 
 class ServiceModel(BaseModel):
-    model: ANY_MODEL
+    model: AnyModel
     domain: str
     created: int
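Because `ChatModel` is now a union discriminated on `format`, pydantic picks the right subclass straight from a plain dict. A quick sketch with an illustrative payload:

```python
from dstack.gateway.openai.models import OpenAIOptions

options = OpenAIOptions.model_validate(
    {
        "model": {
            "type": "chat",
            "name": "mistralai/Mistral-7B-Instruct-v0.2",  # illustrative
            "format": "openai",
            "prefix": "/v1",
        }
    }
)
print(type(options.model).__name__)  # OpenAIChatModel, selected by "format"
```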
1 change: 1 addition & 0 deletions gateway/src/dstack/gateway/openai/routes.py
@@ -46,6 +46,7 @@ async def post_chat_completions(
         return StreamingResponse(
             stream_chunks(client.stream(body)),
             media_type="text/event-stream",
+            headers={"X-Accel-Buffering": "no"},
         )
     except GatewayError as e:
         raise e.http()
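`X-Accel-Buffering: no` is the standard per-response header that tells an nginx reverse proxy not to buffer this response, so each SSE chunk is flushed to the client as soon as it is produced; without it, a streamed completion can arrive in a single burst at the end.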
6 changes: 3 additions & 3 deletions gateway/src/dstack/gateway/openai/schemas.py
@@ -37,7 +37,7 @@ class ChatCompletionsChoice(BaseModel):
 
 class ChatCompletionsChunkChoice(BaseModel):
     delta: object
-    logprobs: Optional[object]
+    logprobs: object = {}
     finish_reason: Optional[FinishReason]
     index: int
 
@@ -53,7 +53,7 @@ class ChatCompletionsResponse(BaseModel):
     choices: List[ChatCompletionsChoice]
     created: int
     model: str
-    system_fingerprint: str
+    system_fingerprint: str = ""
     object: Literal["chat.completion"] = "chat.completion"
     usage: ChatCompletionsUsage
 
@@ -63,7 +63,7 @@ class ChatCompletionsChunk(BaseModel):
     choices: List[ChatCompletionsChunkChoice]
     created: int
     model: str
-    system_fingerprint: str
+    system_fingerprint: str = ""
     object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
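These defaults matter because some OpenAI-compatible servers (per the commit message, vLLM) omit `system_fingerprint` and `logprobs`. A sketch with an illustrative chunk payload, assuming the schema's remaining fields (such as `id`) follow the standard OpenAI chunk shape:

```python
from dstack.gateway.openai.schemas import ChatCompletionsChunk

chunk = ChatCompletionsChunk.model_validate(
    {
        "id": "chatcmpl-1",
        "choices": [
            # no "logprobs" key: the new default {} applies
            {"delta": {"content": "Hi"}, "finish_reason": None, "index": 0}
        ],
        "created": 1706700000,
        "model": "my-model",
        # no "system_fingerprint": the new default "" applies
    }
)
print(repr(chunk.system_fingerprint))  # ''
```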
6 changes: 6 additions & 0 deletions gateway/src/dstack/gateway/openai/store.py
@@ -7,6 +7,7 @@
 
 from dstack.gateway.errors import GatewayError, NotFoundError
 from dstack.gateway.openai.clients import ChatCompletionsClient
+from dstack.gateway.openai.clients.openai import OpenAIChatCompletions
 from dstack.gateway.openai.clients.tgi import TGIChatCompletions
 from dstack.gateway.openai.models import OpenAIOptions, ServiceModel
 from dstack.gateway.openai.schemas import Model
@@ -81,6 +82,11 @@ async def get_chat_client(self, project: str, model_name: str) -> ChatCompletion
                 chat_template=service.model.chat_template,
                 eos_token=service.model.eos_token,
             )
+        elif service.model.format == "openai":
+            return OpenAIChatCompletions(
+                base_url=f"http://localhost/{service.model.prefix.lstrip('/')}",
+                host=service.domain,
+            )
         else:
             raise GatewayError(f"Unsupported model format: {service.model.format}")
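The routing convention here appears to be: the gateway calls its own nginx at `localhost`, the service's domain in the `Host` header selects the upstream, and the model's `prefix` forms the path. A tiny illustration of the resulting URL (values hypothetical):

```python
prefix = "/v1"
domain = "my-service.gateway.example.com"

base_url = f"http://localhost/{prefix.lstrip('/')}"
print(base_url)  # http://localhost/v1
# Requests to f"{base_url}/chat/completions" carry "Host: {domain}",
# so nginx proxies them to that service's replicas.
```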
23 changes: 2 additions & 21 deletions src/dstack/_internal/core/models/configurations.py
@@ -7,6 +7,7 @@
 
 from dstack._internal.core.errors import ConfigurationError
 from dstack._internal.core.models.common import ForbidExtra
+from dstack._internal.core.models.gateways import AnyModel
 from dstack._internal.core.models.repos.base import Repo
 from dstack._internal.core.models.repos.virtual import VirtualRepo
 from dstack._internal.core.models.resources import ResourcesSpec
@@ -78,26 +79,6 @@ class Artifact(ForbidExtra):
     ] = False
 
 
-class ModelInfo(ForbidExtra):
-    """
-    Mapping of the model for the OpenAI-compatible endpoint.
-
-    Attributes:
-        type (str): The type of the model, e.g. "chat"
-        name (str): The name of the model. This name will be used both to load model configuration from the HuggingFace Hub and in the OpenAI-compatible endpoint.
-        format (str): The format of the model, e.g. "tgi" if the model is served with HuggingFace's Text Generation Inference.
-        chat_template (Optional[str]): The custom prompt template for the model. If not specified, the default prompt template from the HuggingFace Hub configuration will be used.
-        eos_token (Optional[str]): The custom end of sentence token. If not specified, the default end of sentence token from the HuggingFace Hub configuration will be used.
-    """
-
-    type: Annotated[Literal["chat"], Field(description="The type of the model")]
-    name: Annotated[str, Field(description="The name of the model")]
-    format: Annotated[Literal["tgi"], Field(description="The serving format")]
-
-    chat_template: Optional[str] = None  # TODO(egor-s): use discriminator and root model
-    eos_token: Optional[str] = None
-
-
 class BaseConfiguration(ForbidExtra):
     type: Literal["none"]
     image: Annotated[Optional[str], Field(description="The name of the Docker image to run")]
@@ -206,7 +187,7 @@ class ServiceConfiguration(BaseConfiguration):
         Field(description="The port, that application listens to or the mapping"),
     ]
     model: Annotated[
-        Optional[ModelInfo],
+        Optional[AnyModel],
        Field(description="Mapping of the model for the OpenAI-compatible endpoint"),
     ] = None
     auth: Annotated[bool, Field(description="Enable the authorization")] = True
47 changes: 45 additions & 2 deletions src/dstack/_internal/core/models/gateways.py
@@ -1,7 +1,8 @@
 import datetime
-from typing import Optional
+from typing import Optional, Union
 
-from pydantic import BaseModel
+from pydantic import BaseModel, Field
+from typing_extensions import Annotated, Literal
 
 from dstack._internal.core.models.backends.base import BackendType
@@ -15,3 +16,45 @@ class Gateway(BaseModel):
     default: bool
     created_at: datetime.datetime
     backend: BackendType
+
+
+class BaseChatModel(BaseModel):
+    type: Annotated[Literal["chat"], Field(description="The type of the model")]
+    name: Annotated[str, Field(description="The name of the model")]
+    format: Annotated[str, Field(description="The serving format")]
+
+
+class TGIChatModel(BaseChatModel):
+    """
+    Mapping of the model for the OpenAI-compatible endpoint.
+
+    Attributes:
+        type (str): The type of the model, e.g. "chat"
+        name (str): The name of the model. This name will be used both to load model configuration from the HuggingFace Hub and in the OpenAI-compatible endpoint.
+        format (str): The format of the model, e.g. "tgi" if the model is served with HuggingFace's Text Generation Inference.
+        chat_template (Optional[str]): The custom prompt template for the model. If not specified, the default prompt template from the HuggingFace Hub configuration will be used.
+        eos_token (Optional[str]): The custom end of sentence token. If not specified, the default end of sentence token from the HuggingFace Hub configuration will be used.
+    """
+
+    format: Literal["tgi"]
+    chat_template: Optional[str] = None  # will be set before registering the service
+    eos_token: Optional[str] = None
+
+
+class OpenAIChatModel(BaseChatModel):
+    """
+    Mapping of the model for the OpenAI-compatible endpoint.
+
+    Attributes:
+        type (str): The type of the model, e.g. "chat"
+        name (str): The name of the model. This name will be used both to load model configuration from the HuggingFace Hub and in the OpenAI-compatible endpoint.
+        format (str): The format of the model, i.e. "openai".
+        prefix (str): The `base_url` prefix: `http://hostname/{prefix}/chat/completions`. Defaults to `/v1`.
+    """
+
+    format: Literal["openai"]
+    prefix: Annotated[str, Field(description="The `base_url` prefix (after hostname)")] = "/v1"
+
+
+ChatModel = Annotated[Union[TGIChatModel, OpenAIChatModel], Field(discriminator="format")]
+AnyModel = Annotated[Union[ChatModel], Field(discriminator="type")]  # embeddings and etc.
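A sketch of constructing the two mappings directly (model names illustrative); `prefix` falls back to `/v1` when omitted:

```python
from dstack._internal.core.models.gateways import OpenAIChatModel, TGIChatModel

tgi = TGIChatModel(
    type="chat",
    name="mistralai/Mistral-7B-Instruct-v0.2",
    format="tgi",
)
vllm = OpenAIChatModel(
    type="chat",
    name="mistralai/Mistral-7B-Instruct-v0.2",
    format="openai",
)
print(vllm.prefix)  # "/v1" by default
```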
@@ -1,19 +1,23 @@
 import requests
 
 from dstack._internal.core.errors import ConfigurationError
-from dstack._internal.core.models.configurations import ModelInfo
+from dstack._internal.core.models.gateways import AnyModel
 
 
-def complete_model(model_info: ModelInfo) -> dict:
+def complete_model(model_info: AnyModel) -> dict:
     model_info = model_info.copy(deep=True)
+    # TODO(egor-s): support more types and formats
 
-    if model_info.chat_template is None or model_info.eos_token is None:
-        tokenizer_config = get_tokenizer_config(model_info.name)
-        if model_info.chat_template is None:
-            model_info.chat_template = tokenizer_config["chat_template"]  # TODO(egor-s): default
-        if model_info.eos_token is None:
-            model_info.eos_token = tokenizer_config["eos_token"]  # TODO(egor-s): default
+    if model_info.type == "chat" and model_info.format == "tgi":
+        if model_info.chat_template is None or model_info.eos_token is None:
+            tokenizer_config = get_tokenizer_config(model_info.name)
+            if model_info.chat_template is None:
+                model_info.chat_template = tokenizer_config[
+                    "chat_template"
+                ]  # TODO(egor-s): default
+            if model_info.eos_token is None:
+                model_info.eos_token = tokenizer_config["eos_token"]  # TODO(egor-s): default
+    elif model_info.type == "chat" and model_info.format == "openai":
+        pass  # nothing to do
 
     return {"model": model_info.dict()}
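With the new branching, an `openai`-format model passes through `complete_model` unchanged, while `tgi` still has `chat_template`/`eos_token` filled in from the HuggingFace Hub tokenizer config. A sketch of the pass-through case (this file's path is not shown above, so the import of `complete_model` is elided; paste below the function to run):

```python
from dstack._internal.core.models.gateways import OpenAIChatModel

model = OpenAIChatModel(type="chat", name="my-model", format="openai")  # illustrative
options = complete_model(model)  # complete_model as defined in the diff above
print(options)
# {'model': {'type': 'chat', 'name': 'my-model', 'format': 'openai', 'prefix': '/v1'}}
```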
3 changes: 1 addition & 2 deletions src/dstack/api/__init__.py
@@ -1,12 +1,12 @@
 # ruff: noqa: F401
 from dstack._internal.core.errors import ClientError
 from dstack._internal.core.models.backends.base import BackendType
-from dstack._internal.core.models.configurations import ModelInfo as _ModelInfo
 from dstack._internal.core.models.configurations import RegistryAuth
 from dstack._internal.core.models.configurations import (
     ServiceConfiguration as _ServiceConfiguration,
 )
 from dstack._internal.core.models.configurations import TaskConfiguration as _TaskConfiguration
+from dstack._internal.core.models.gateways import OpenAIChatModel, TGIChatModel
 from dstack._internal.core.models.repos.local import LocalRepo
 from dstack._internal.core.models.repos.remote import RemoteRepo
 from dstack._internal.core.models.repos.virtual import VirtualRepo
@@ -18,6 +18,5 @@
 from dstack.api._public.resources import GPU, Disk, Resources
 from dstack.api._public.runs import Run, RunStatus
 
-ModelMapping = _ModelInfo
 Service = _ServiceConfiguration
 Task = _TaskConfiguration
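With the re-export, user code can attach a model mapping through the public API. A hypothetical sketch, assuming the usual `Service` fields (`commands`, `port`) and an illustrative vLLM launch command:

```python
from dstack.api import OpenAIChatModel, Service

service = Service(
    commands=[
        "python -m vllm.entrypoints.openai.api_server"
        " --model mistralai/Mistral-7B-Instruct-v0.2"
    ],
    port=8000,
    model=OpenAIChatModel(
        type="chat",
        name="mistralai/Mistral-7B-Instruct-v0.2",
        format="openai",
    ),
)
```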
