Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion docs/docs/guides/services.md
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,6 @@ model:
!!! info "Experimental feature"
OpenAI interface is an experimental feature.
Only TGI chat models are supported at the moment.
Streaming is not supported yet.

Run the configuration. Text Generation Inference requires a GPU with a compute capability above 8.0: e.g., L4 or A100.

Expand Down
41 changes: 39 additions & 2 deletions gateway/src/dstack/gateway/openai/clients/tgi.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio
import datetime
import json
import uuid
from typing import AsyncIterator, Dict, List, Optional

Expand All @@ -11,6 +12,7 @@
from dstack.gateway.openai.schemas import (
ChatCompletionsChoice,
ChatCompletionsChunk,
ChatCompletionsChunkChoice,
ChatCompletionsRequest,
ChatCompletionsResponse,
ChatCompletionsUsage,
Expand Down Expand Up @@ -83,7 +85,42 @@ async def generate(self, request: ChatCompletionsRequest) -> ChatCompletionsResp
)

async def stream(self, request: ChatCompletionsRequest) -> AsyncIterator[ChatCompletionsChunk]:
    """Proxy TGI's `/generate_stream` SSE stream as OpenAI-style chat chunks.

    Yields one ChatCompletionsChunk per generated token. The final event from
    TGI carries a non-null `details` object; it is emitted as a chunk with an
    empty delta and a populated `finish_reason`.

    Args:
        request: The OpenAI-style chat completions request to stream.

    Raises:
        GatewayError: If TGI reports an error inside the event stream.
    """
    completion_id = uuid.uuid4().hex
    # NOTE: utcnow() returns a *naive* datetime whose .timestamp() is
    # interpreted in local time; use an aware UTC datetime so `created`
    # is a correct Unix epoch regardless of the host timezone.
    created = int(datetime.datetime.now(datetime.timezone.utc).timestamp())

    payload = self.get_payload(request)
    async with self.client.stream("POST", "/generate_stream", json=payload) as resp:
        async for line in resp.aiter_lines():
            if not line.startswith("data:"):
                # Skip SSE comments / keep-alive blank lines.
                continue
            data = json.loads(line[len("data:") :].strip("\n"))
            if "error" in data:
                raise GatewayError(data["error"])
            chunk = ChatCompletionsChunk(
                id=completion_id,
                choices=[],
                created=created,
                model=request.model,
                system_fingerprint="",
            )
            if data["details"] is not None:
                # Final event: no token text, only the finish reason.
                chunk.choices = [
                    ChatCompletionsChunkChoice(
                        delta={},
                        logprobs=None,
                        finish_reason=self.finish_reason(data["details"]["finish_reason"]),
                        index=0,
                    )
                ]
            else:
                chunk.choices = [
                    ChatCompletionsChunkChoice(
                        delta={"content": data["token"]["text"], "role": "assistant"},
                        logprobs=None,
                        finish_reason=None,
                        index=0,
                    )
                ]
            yield chunk

def get_payload(self, request: ChatCompletionsRequest) -> Dict:
inputs = self.chat_template.render(
Expand All @@ -110,7 +147,7 @@ def get_payload(self, request: ChatCompletionsRequest) -> Dict:
"best_of": request.n,
# "watermark": False,
"details": True, # to get best_of_sequences
"decoder_input_details": True,
"decoder_input_details": not request.stream,
}
if request.top_p < 1.0:
parameters["top_p"] = request.top_p
Expand Down
6 changes: 3 additions & 3 deletions gateway/src/dstack/gateway/openai/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,12 @@ async def post_chat_completions(
return await client.generate(body)
else:
return StreamingResponse(
stream_chunks(await client.stream(body)),
stream_chunks(client.stream(body)),
media_type="text/event-stream",
)


async def stream_chunks(chunks: AsyncIterator["ChatCompletionsChunk"]) -> AsyncIterator[bytes]:
    """Encode chat-completion chunks as server-sent events (SSE) byte frames.

    Each chunk is serialized to JSON on a ``data:`` line terminated by a blank
    line (the SSE event separator). A final ``data: [DONE]`` event marks the
    end of the stream, matching the OpenAI streaming protocol.

    Args:
        chunks: Async iterator of chat-completion chunks to forward.

    Yields:
        Encoded SSE frames ready to be written to the response body.
    """
    async for chunk in chunks:
        yield f"data:{chunk.model_dump_json()}\n\n".encode()
    yield "data: [DONE]\n\n".encode()
9 changes: 8 additions & 1 deletion gateway/src/dstack/gateway/openai/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,13 @@ class ChatCompletionsChoice(BaseModel):
message: ChatMessage


class ChatCompletionsChunkChoice(BaseModel):
    # One streamed choice of a chat-completions chunk (OpenAI streaming format).
    # Incremental message payload, e.g. {"role": "assistant", "content": "..."};
    # an empty dict on the final chunk of the stream.
    delta: object
    # Per-token log probabilities; the TGI client always passes None here.
    logprobs: Optional[object]
    # Set only on the final chunk; None while tokens are still streaming.
    finish_reason: Optional[FinishReason]
    # Position of this choice in the response (0 when a single choice is requested).
    index: int


class ChatCompletionsUsage(BaseModel):
completion_tokens: int
prompt_tokens: int
Expand All @@ -53,7 +60,7 @@ class ChatCompletionsResponse(BaseModel):

class ChatCompletionsChunk(BaseModel):
id: str
choices: List[ChatCompletionsChoice]
choices: List[ChatCompletionsChunkChoice]
created: int
model: str
system_fingerprint: str
Expand Down
1 change: 1 addition & 0 deletions gateway/src/dstack/gateway/systemd/resources/update.sh
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ else
version="blue"
fi

"$root/$version/bin/pip" uninstall -y dstack-gateway
"$root/$version/bin/pip" install "$1"
sudo "$root/$version/bin/python" -m dstack.gateway.systemd install
echo "$version" > "$root/version"
Expand Down