Implement OpenAI to OpenAI adapter for gateway #860

Merged 2 commits on Jan 31, 2024
9 changes: 8 additions & 1 deletion gateway/src/dstack/gateway/common.py
@@ -1,6 +1,7 @@
 import asyncio
 import functools
-from typing import Callable, ParamSpec, TypeVar
+import json
+from typing import Any, AsyncIterator, Callable, Dict, ParamSpec, TypeVar
 
 import httpx
 
@@ -19,3 +20,9 @@ def __del__(self):
             asyncio.get_running_loop().create_task(self.aclose())
         except Exception:
             pass
+
+    async def stream_sse(self, url: str, **kwargs) -> AsyncIterator[Dict[str, Any]]:
+        async with self.stream("POST", url, **kwargs) as resp:
+            async for line in resp.aiter_lines():
+                if line.startswith("data:"):
+                    yield json.loads(line[len("data:") :].strip("\n"))
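For context, `stream_sse` parses the data-only flavor of Server-Sent Events that OpenAI-compatible servers emit: each event arrives as a `data: <json>` line, and everything else (comments, blank separators) is skipped. A minimal, self-contained sketch of the same parsing rule, with illustrative payloads:

    import json

    raw_lines = [
        ": keep-alive",                                         # SSE comment, skipped
        'data: {"choices": [{"delta": {"content": "Hel"}}]}',   # parsed and yielded
        "",                                                     # event separator, skipped
        'data: {"choices": [{"delta": {"content": "lo"}}]}',
    ]

    for line in raw_lines:
        if line.startswith("data:"):
            print(json.loads(line[len("data:") :].strip("\n")))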
2 changes: 2 additions & 0 deletions gateway/src/dstack/gateway/openai/clients/__init__.py
@@ -7,6 +7,8 @@
     ChatCompletionsResponse,
 )
 
+DEFAULT_TIMEOUT = 60
+
 
 class ChatCompletionsClient(ABC):
     @abstractmethod
29 changes: 29 additions & 0 deletions gateway/src/dstack/gateway/openai/clients/openai.py
@@ -0,0 +1,29 @@
+from typing import AsyncIterator, Optional
+
+from dstack.gateway.common import AsyncClientWrapper
+from dstack.gateway.errors import GatewayError
+from dstack.gateway.openai.clients import DEFAULT_TIMEOUT, ChatCompletionsClient
+from dstack.gateway.openai.schemas import (
+    ChatCompletionsChunk,
+    ChatCompletionsRequest,
+    ChatCompletionsResponse,
+)
+
+
+class OpenAIChatCompletions(ChatCompletionsClient):
+    def __init__(self, base_url: str, host: Optional[str] = None):
+        self.client = AsyncClientWrapper(
+            base_url=base_url.rstrip("/"),
+            headers={} if host is None else {"Host": host},
+            timeout=DEFAULT_TIMEOUT,
+        )
+
+    async def generate(self, request: ChatCompletionsRequest) -> ChatCompletionsResponse:
+        resp = await self.client.post("/chat/completions", json=request.model_dump())
+        if resp.status_code != 200:
+            raise GatewayError(resp.text)
+        return ChatCompletionsResponse.model_validate(resp.json())
+
+    async def stream(self, request: ChatCompletionsRequest) -> AsyncIterator[ChatCompletionsChunk]:
+        async for data in self.client.stream_sse("/chat/completions", json=request.model_dump()):
+            yield ChatCompletionsChunk.model_validate(data)
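A hypothetical usage sketch of the new client. The `ChatCompletionsRequest` fields shown (`model`, `messages`, `stream`) follow the OpenAI chat-completions schema and are assumptions here, as are the URL, host, and model name:

    import asyncio

    from dstack.gateway.openai.clients.openai import OpenAIChatCompletions
    from dstack.gateway.openai.schemas import ChatCompletionsRequest


    async def main():
        client = OpenAIChatCompletions(
            base_url="http://localhost/v1",    # upstream OpenAI-compatible server
            host="my-service.example.com",     # Host header for the local proxy
        )
        request = ChatCompletionsRequest(
            model="my-model",
            messages=[{"role": "user", "content": "Hello!"}],
            stream=True,
        )
        async for chunk in client.stream(request):  # chunks arrive via stream_sse
            print(chunk.choices[0].delta)


    asyncio.run(main())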
20 changes: 15 additions & 5 deletions gateway/src/dstack/gateway/openai/models.py
@@ -3,22 +3,32 @@
 from pydantic import BaseModel, Field
 
 
-class ChatModel(BaseModel):
-    type: Literal["chat"] = "chat"
+class BaseChatModel(BaseModel):
+    type: Literal["chat"]
     name: str
     format: str
+
+
+class TGIChatModel(BaseChatModel):
+    format: Literal["tgi"]
     chat_template: str
     eos_token: str
 
 
-ANY_MODEL = Union[ChatModel]
+class OpenAIChatModel(BaseChatModel):
+    format: Literal["openai"]
+    prefix: str
+
+
+ChatModel = Annotated[Union[TGIChatModel, OpenAIChatModel], Field(discriminator="format")]
+AnyModel = Union[ChatModel]  # embeddings, etc.
 
 
 class OpenAIOptions(BaseModel):
-    model: Annotated[ANY_MODEL, Field(discriminator="type")]
+    model: AnyModel  # TODO(egor-s): add discriminator by type
 
 
 class ServiceModel(BaseModel):
-    model: ANY_MODEL
+    model: AnyModel
     domain: str
     created: int
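The `format` discriminator lets pydantic dispatch a raw payload to the right concrete class; a quick sketch (assuming pydantic v2, consistent with the `model_validate`/`model_dump` calls used in the gateway package):

    from pydantic import TypeAdapter

    from dstack.gateway.openai.models import ChatModel

    adapter = TypeAdapter(ChatModel)
    tgi = adapter.validate_python(
        {"type": "chat", "name": "llama", "format": "tgi", "chat_template": "...", "eos_token": "</s>"}
    )
    oai = adapter.validate_python(
        {"type": "chat", "name": "gpt", "format": "openai", "prefix": "/v1"}
    )
    print(type(tgi).__name__, type(oai).__name__)  # TGIChatModel OpenAIChatModel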
1 change: 1 addition & 0 deletions gateway/src/dstack/gateway/openai/routes.py
@@ -46,6 +46,7 @@ async def post_chat_completions(
         return StreamingResponse(
             stream_chunks(client.stream(body)),
             media_type="text/event-stream",
+            headers={"X-Accel-Buffering": "no"},
         )
     except GatewayError as e:
         raise e.http()
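`X-Accel-Buffering: no` tells an Nginx reverse proxy in front of the gateway not to buffer the response, so SSE chunks reach the client as they are produced instead of arriving in one burst. End to end, the endpoint can then be consumed with any OpenAI client; a hypothetical sketch using the official `openai` v1 Python package (gateway URL, token, and model name are placeholders):

    from openai import OpenAI

    client = OpenAI(base_url="https://gateway.example.com/v1", api_key="token")
    stream = client.chat.completions.create(
        model="my-model",
        messages=[{"role": "user", "content": "Hi!"}],
        stream=True,
    )
    for chunk in stream:
        print(chunk.choices[0].delta.content or "", end="")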
6 changes: 3 additions & 3 deletions gateway/src/dstack/gateway/openai/schemas.py
@@ -37,7 +37,7 @@ class ChatCompletionsChoice(BaseModel):
 
 class ChatCompletionsChunkChoice(BaseModel):
     delta: object
-    logprobs: Optional[object]
+    logprobs: object = {}
     finish_reason: Optional[FinishReason]
     index: int
 
@@ -53,7 +53,7 @@ class ChatCompletionsResponse(BaseModel):
     choices: List[ChatCompletionsChoice]
     created: int
     model: str
-    system_fingerprint: str
+    system_fingerprint: str = ""
     object: Literal["chat.completion"] = "chat.completion"
     usage: ChatCompletionsUsage
 
@@ -63,7 +63,7 @@ class ChatCompletionsChunk(BaseModel):
     choices: List[ChatCompletionsChunkChoice]
     created: int
     model: str
-    system_fingerprint: str
+    system_fingerprint: str = ""
     object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
 
 
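These relaxations matter because some OpenAI-compatible backends omit `system_fingerprint` and `logprobs`; with the new defaults, validation still succeeds. A sketch (the remaining required fields, `id`/`choices`/`created`/`model`, are assumed from the OpenAI schema):

    from dstack.gateway.openai.schemas import ChatCompletionsChunk

    chunk = ChatCompletionsChunk.model_validate(
        {
            "id": "chatcmpl-123",
            "choices": [{"delta": {"content": "Hi"}, "finish_reason": None, "index": 0}],
            "created": 1706659200,
            "model": "my-model",
            # no "system_fingerprint", no "logprobs": the defaults apply
        }
    )
    print(repr(chunk.system_fingerprint))  # ''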
6 changes: 6 additions & 0 deletions gateway/src/dstack/gateway/openai/store.py
@@ -7,6 +7,7 @@
 
 from dstack.gateway.errors import GatewayError, NotFoundError
 from dstack.gateway.openai.clients import ChatCompletionsClient
+from dstack.gateway.openai.clients.openai import OpenAIChatCompletions
 from dstack.gateway.openai.clients.tgi import TGIChatCompletions
 from dstack.gateway.openai.models import OpenAIOptions, ServiceModel
 from dstack.gateway.openai.schemas import Model
@@ -81,6 +82,11 @@ async def get_chat_client(self, project: str, model_name: str) -> ChatCompletion
                 chat_template=service.model.chat_template,
                 eos_token=service.model.eos_token,
             )
+        elif service.model.format == "openai":
+            return OpenAIChatCompletions(
+                base_url=f"http://localhost/{service.model.prefix.lstrip('/')}",
+                host=service.domain,
+            )
         else:
             raise GatewayError(f"Unsupported model format: {service.model.format}")
 
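For an `openai`-format service, the store thus builds a client that always talks to the local reverse proxy and passes the service's domain as the `Host` header, presumably so the proxy can route the request to the right upstream. An illustration with hypothetical values:

    prefix = "/v1"
    domain = "my-service.gateway.example.com"

    base_url = f"http://localhost/{prefix.lstrip('/')}"  # -> "http://localhost/v1"
    # The client then POSTs to base_url + "/chat/completions"
    # with the header "Host: my-service.gateway.example.com".
    print(base_url)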
23 changes: 2 additions & 21 deletions src/dstack/_internal/core/models/configurations.py
@@ -7,6 +7,7 @@
 
 from dstack._internal.core.errors import ConfigurationError
 from dstack._internal.core.models.common import ForbidExtra
+from dstack._internal.core.models.gateways import AnyModel
 from dstack._internal.core.models.repos.base import Repo
 from dstack._internal.core.models.repos.virtual import VirtualRepo
 from dstack._internal.core.models.resources import ResourcesSpec
@@ -78,26 +79,6 @@ class Artifact(ForbidExtra):
     ] = False
 
 
-class ModelInfo(ForbidExtra):
-    """
-    Mapping of the model for the OpenAI-compatible endpoint.
-
-    Attributes:
-        type (str): The type of the model, e.g. "chat"
-        name (str): The name of the model. This name will be used both to load model configuration from the HuggingFace Hub and in the OpenAI-compatible endpoint.
-        format (str): The format of the model, e.g. "tgi" if the model is served with HuggingFace's Text Generation Inference.
-        chat_template (Optional[str]): The custom prompt template for the model. If not specified, the default prompt template from the HuggingFace Hub configuration will be used.
-        eos_token (Optional[str]): The custom end-of-sentence token. If not specified, the default end-of-sentence token from the HuggingFace Hub configuration will be used.
-    """
-
-    type: Annotated[Literal["chat"], Field(description="The type of the model")]
-    name: Annotated[str, Field(description="The name of the model")]
-    format: Annotated[Literal["tgi"], Field(description="The serving format")]
-
-    chat_template: Optional[str] = None  # TODO(egor-s): use discriminator and root model
-    eos_token: Optional[str] = None
-
-
 class BaseConfiguration(ForbidExtra):
     type: Literal["none"]
     image: Annotated[Optional[str], Field(description="The name of the Docker image to run")]
@@ -206,7 +187,7 @@ class ServiceConfiguration(BaseConfiguration):
         Field(description="The port, that application listens to or the mapping"),
     ]
     model: Annotated[
-        Optional[ModelInfo],
+        Optional[AnyModel],
        Field(description="Mapping of the model for the OpenAI-compatible endpoint"),
     ] = None
     auth: Annotated[bool, Field(description="Enable the authorization")] = True
47 changes: 45 additions & 2 deletions src/dstack/_internal/core/models/gateways.py
@@ -1,7 +1,8 @@
 import datetime
-from typing import Optional
+from typing import Optional, Union
 
-from pydantic import BaseModel
+from pydantic import BaseModel, Field
+from typing_extensions import Annotated, Literal
 
 from dstack._internal.core.models.backends.base import BackendType
 
@@ -15,3 +16,45 @@ class Gateway(BaseModel):
     default: bool
     created_at: datetime.datetime
     backend: BackendType
+
+
+class BaseChatModel(BaseModel):
+    type: Annotated[Literal["chat"], Field(description="The type of the model")]
+    name: Annotated[str, Field(description="The name of the model")]
+    format: Annotated[str, Field(description="The serving format")]
+
+
+class TGIChatModel(BaseChatModel):
+    """
+    Mapping of the model for the OpenAI-compatible endpoint.
+
+    Attributes:
+        type (str): The type of the model, e.g. "chat"
+        name (str): The name of the model. This name will be used both to load model configuration from the HuggingFace Hub and in the OpenAI-compatible endpoint.
+        format (str): The format of the model, e.g. "tgi" if the model is served with HuggingFace's Text Generation Inference.
+        chat_template (Optional[str]): The custom prompt template for the model. If not specified, the default prompt template from the HuggingFace Hub configuration will be used.
+        eos_token (Optional[str]): The custom end-of-sentence token. If not specified, the default end-of-sentence token from the HuggingFace Hub configuration will be used.
+    """
+
+    format: Literal["tgi"]
+    chat_template: Optional[str] = None  # will be set before registering the service
+    eos_token: Optional[str] = None
+
+
+class OpenAIChatModel(BaseChatModel):
+    """
+    Mapping of the model for the OpenAI-compatible endpoint.
+
+    Attributes:
+        type (str): The type of the model, e.g. "chat"
+        name (str): The name of the model. This name will be used both to load model configuration from the HuggingFace Hub and in the OpenAI-compatible endpoint.
+        format (str): The format of the model, i.e. "openai".
+        prefix (str): The `base_url` prefix: `http://hostname/{prefix}/chat/completions`. Defaults to `/v1`.
+    """
+
+    format: Literal["openai"]
+    prefix: Annotated[str, Field(description="The `base_url` prefix (after hostname)")] = "/v1"
+
+
+ChatModel = Annotated[Union[TGIChatModel, OpenAIChatModel], Field(discriminator="format")]
+AnyModel = Annotated[Union[ChatModel], Field(discriminator="type")]  # embeddings, etc.
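A sketch of how a raw `model:` mapping would parse into these types through both discriminators (pydantic v1 style, matching the `.copy(deep=True)`/`.dict()` calls used alongside these models; values are placeholders):

    from pydantic import parse_obj_as

    from dstack._internal.core.models.gateways import AnyModel

    model = parse_obj_as(
        AnyModel,
        {"type": "chat", "name": "my-org/my-model", "format": "openai", "prefix": "/v1"},
    )
    print(type(model).__name__)  # OpenAIChatModel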
@@ -1,19 +1,23 @@
 import requests
 
 from dstack._internal.core.errors import ConfigurationError
-from dstack._internal.core.models.configurations import ModelInfo
+from dstack._internal.core.models.gateways import AnyModel
 
 
-def complete_model(model_info: ModelInfo) -> dict:
+def complete_model(model_info: AnyModel) -> dict:
     model_info = model_info.copy(deep=True)
+    # TODO(egor-s): support more types and formats
 
-    if model_info.chat_template is None or model_info.eos_token is None:
-        tokenizer_config = get_tokenizer_config(model_info.name)
-        if model_info.chat_template is None:
-            model_info.chat_template = tokenizer_config["chat_template"]  # TODO(egor-s): default
-        if model_info.eos_token is None:
-            model_info.eos_token = tokenizer_config["eos_token"]  # TODO(egor-s): default
+    if model_info.type == "chat" and model_info.format == "tgi":
+        if model_info.chat_template is None or model_info.eos_token is None:
+            tokenizer_config = get_tokenizer_config(model_info.name)
+            if model_info.chat_template is None:
+                model_info.chat_template = tokenizer_config[
+                    "chat_template"
+                ]  # TODO(egor-s): default
+            if model_info.eos_token is None:
+                model_info.eos_token = tokenizer_config["eos_token"]  # TODO(egor-s): default
+    elif model_info.type == "chat" and model_info.format == "openai":
+        pass  # nothing to do
 
     return {"model": model_info.dict()}
 
3 changes: 1 addition & 2 deletions src/dstack/api/__init__.py
@@ -1,12 +1,12 @@
 # ruff: noqa: F401
 from dstack._internal.core.errors import ClientError
 from dstack._internal.core.models.backends.base import BackendType
-from dstack._internal.core.models.configurations import ModelInfo as _ModelInfo
 from dstack._internal.core.models.configurations import RegistryAuth
 from dstack._internal.core.models.configurations import (
     ServiceConfiguration as _ServiceConfiguration,
 )
 from dstack._internal.core.models.configurations import TaskConfiguration as _TaskConfiguration
+from dstack._internal.core.models.gateways import OpenAIChatModel, TGIChatModel
 from dstack._internal.core.models.repos.local import LocalRepo
 from dstack._internal.core.models.repos.remote import RemoteRepo
 from dstack._internal.core.models.repos.virtual import VirtualRepo
@@ -18,6 +18,5 @@
 from dstack.api._public.resources import GPU, Disk, Resources
 from dstack.api._public.runs import Run, RunStatus
 
-ModelMapping = _ModelInfo
 Service = _ServiceConfiguration
 Task = _TaskConfiguration
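With the `ModelMapping` alias gone, user code imports the concrete model classes directly; a hypothetical sketch:

    from dstack.api import OpenAIChatModel

    # "type" and "format" are required Literal fields on these models
    model = OpenAIChatModel(type="chat", name="my-model", format="openai", prefix="/v1")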