fix: continuous batching in transformers serve #40479
Changes from all commits: 64272df, da47ca6, b0b6555, 1e4ae68, 6e011aa, a27ce93, bc392de
@@ -25,13 +25,15 @@
 import time
 from argparse import ArgumentParser, Namespace
 from collections.abc import Generator, Iterable
+from contextlib import asynccontextmanager
 from dataclasses import dataclass, field
 from io import BytesIO
 from threading import Thread
 from typing import Optional, Union

 from huggingface_hub import model_info
 from huggingface_hub.constants import HF_HUB_OFFLINE
+from tokenizers.decoders import DecodeStream

 import transformers
 from transformers.models.auto.modeling_auto import (
@@ -313,16 +315,16 @@ def __init__(
         self._name_or_path = str(model.name_or_path)
         self.processor = processor
         self.timeout_seconds = timeout_seconds
-        self._timer = threading.Timer(self.timeout_seconds, self._delete_model)
+        self._timer = threading.Timer(self.timeout_seconds, self.timeout_reached)
         self._timer.start()

     def reset_timer(self):
         """Reset the timer for the deletion of the instances."""
         self._timer.cancel()
-        self._timer = threading.Timer(self.timeout_seconds, self._delete_model)
+        self._timer = threading.Timer(self.timeout_seconds, self.timeout_reached)
         self._timer.start()

-    def _delete_model(self):
+    def delete_model(self):
         """Delete the wrapped model and processor and clean up resources."""
         if hasattr(self, "model") and self.model is not None:
             del self.model
@@ -335,9 +337,12 @@ def _delete_model(self):
         if torch.cuda.is_available():
             torch.cuda.empty_cache()

-        logger.info(
-            f"{self._name_or_path} was removed from memory after {self.timeout_seconds} seconds of inactivity"
-        )
+        # XXX: in case we manually delete the model, like on server shutdown
+        self._timer.cancel()
+
+    def timeout_reached(self):
+        self.delete_model()
+        logger.info(f"{self._name_or_path} was removed from memory after {self.timeout_seconds} seconds of inactivity")

     def is_deleted(self):
         """Check if the instances have been deleted."""
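The hunk above splits the timer callback in two: `delete_model()` only frees resources and cancels any pending timer (so a manual call at shutdown is safe), while `timeout_reached()` additionally reports the inactivity eviction. A minimal, self-contained sketch of that idle-timeout pattern; the class name and the plain `print` are illustrative stand-ins, not the PR's wrapper:

```python
import threading


class TimedWrapper:
    """Holds a model and frees it after `timeout_seconds` of inactivity."""

    def __init__(self, model, timeout_seconds: float):
        self.model = model
        self.timeout_seconds = timeout_seconds
        self._timer = threading.Timer(self.timeout_seconds, self.timeout_reached)
        self._timer.start()

    def reset_timer(self):
        # Call this on every request that touches the model.
        self._timer.cancel()
        self._timer = threading.Timer(self.timeout_seconds, self.timeout_reached)
        self._timer.start()

    def delete_model(self):
        # Manual cleanup path (e.g. server shutdown): free the model and make
        # sure the pending timer never fires afterwards.
        self.model = None
        self._timer.cancel()

    def timeout_reached(self):
        # Timer path: free the model and report why it was released.
        self.delete_model()
        print(f"model released after {self.timeout_seconds}s of inactivity")
```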
@@ -353,6 +358,10 @@ class ServeArguments:
     `transformers serve --help`
     """

+    continuous_batching: bool = field(
+        default=False,
+        metadata={"help": "Whether to use continuous batching for chat completions."},
+    )
     device: str = field(
         default="auto",
         metadata={
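Assuming the usual `HfArgumentParser` flow for `transformers serve` arguments, the new boolean field should surface as a `--continuous_batching` CLI flag. A hedged sketch of that mapping, using a pared-down stand-in dataclass rather than the full `ServeArguments`:

```python
from dataclasses import dataclass, field

from transformers import HfArgumentParser


@dataclass
class MiniServeArguments:
    # Same shape as the field added in the PR, isolated for illustration.
    continuous_batching: bool = field(
        default=False,
        metadata={"help": "Whether to use continuous batching for chat completions."},
    )


parser = HfArgumentParser(MiniServeArguments)
# Roughly what `transformers serve --continuous_batching` would parse to.
(args,) = parser.parse_args_into_dataclasses(["--continuous_batching"])
print(args.continuous_batching)  # True
```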
@@ -469,7 +478,7 @@ def __init__(self, args: ServeArguments):

         # Store and process input arguments
         self.args = args
-        self.use_continuous_batching = self.args.attn_implementation == "sdpa_paged"
+        self.use_continuous_batching = self.args.continuous_batching
         self.enable_cors = self.args.enable_cors

         if self.args.default_seed is not None:
@@ -569,11 +578,13 @@ def validate_transcription_request(self, request: dict):
     def build_chat_completion_chunk(
         self,
         request_id: Optional[str] = "",
-        content: Optional[str] = None,
+        content: Optional[int] = None,
         model: Optional[str] = None,
         role: Optional[str] = None,
         finish_reason: Optional[str] = None,
         tool_calls: Optional[list["ChoiceDeltaToolCall"]] = None,
+        decode_stream: Optional[DecodeStream] = None,
+        tokenizer: Optional[PreTrainedTokenizerFast] = None,
     ) -> str:
         """
         Builds a chunk of a streaming OpenAI Chat Completion response.
@@ -598,6 +609,8 @@ def build_chat_completion_chunk(
         Returns:
             `str`: The built chunk, a string containing a JSON string with the payload.
         """
+        if decode_stream is not None and content is not None and tokenizer is not None:
+            content = decode_stream.step(tokenizer._tokenizer, content)
         chunk = ChatCompletionChunk(
             id=request_id,
             created=int(time.time()),
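With this change, `content` arrives as a raw token id and is turned into text inside the chunk builder via `DecodeStream`. A sketch of that incremental detokenization in isolation; the model name is illustrative, and it assumes the single-argument `DecodeStream` constructor from the tokenizers release that introduced it (the PR additionally seeds the stream with the prompt ids, which may require a newer tokenizers version):

```python
from tokenizers.decoders import DecodeStream
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # any fast (Rust-backed) tokenizer
stream = DecodeStream(skip_special_tokens=False)

for token_id in tokenizer("Streaming, token by token.").input_ids:
    # step() returns the newly printable fragment, or None while it buffers
    # ids that do not yet decode cleanly (e.g. partial multi-byte characters).
    piece = stream.step(tokenizer._tokenizer, token_id)
    if piece is not None:
        print(piece, end="", flush=True)
print()
```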
@@ -635,7 +648,29 @@ def build_response_event(self, response: "BaseModel") -> str:
         return f"data: {response.model_dump_json(exclude_none=True)}\n\n"

     def run(self):
-        app = FastAPI()
+        """
+        Setup and run the FastAPI server for transformers serve.
+
+        Models will be loaded and unloaded automatically based on usage and a timeout.
+
+        The server will expose the following endpoints:
+        - POST /v1/chat/completions: Generates chat completions.
+        - POST /v1/responses: Generates responses.
+        - POST /v1/audio/transcriptions: Generates transcriptions from audio.
+        - GET /v1/models: Lists available models for 3rd party tools.
+
+        Requires FastAPI and Uvicorn to be installed.
+        """
+
+        @asynccontextmanager
+        async def lifespan(app: FastAPI):
+            yield
+            for model in self.loaded_models.values():
+                model.delete_model()
+            if self.running_continuous_batching_manager is not None:
+                self.running_continuous_batching_manager.stop(block=True, timeout=5)
+
+        app = FastAPI(lifespan=lifespan)

         # Some apps that make requests from external domains (e.g. Cursor) require CORS to be enabled. However, for
         # security purposes, it's disabled by default
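The `lifespan` handler is FastAPI's supported hook for shutdown work: everything after `yield` runs when the server stops, which is where the PR releases loaded models and stops the batching manager. A minimal, standalone sketch of the pattern; the `cleanup()` callable and the `/health` endpoint are placeholders:

```python
from contextlib import asynccontextmanager

import uvicorn
from fastapi import FastAPI


def cleanup() -> None:
    # Stand-in for releasing models / stopping a background manager.
    print("releasing models and stopping the batching manager")


@asynccontextmanager
async def lifespan(app: FastAPI):
    yield          # the application serves requests while suspended here
    cleanup()      # runs once the server begins shutting down


app = FastAPI(lifespan=lifespan)


@app.get("/health")
def health():
    return {"status": "ok"}


if __name__ == "__main__":
    uvicorn.run(app, host="127.0.0.1", port=8000)
```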
@@ -774,10 +809,7 @@ def continuous_batching_chat_completion(self, req: dict) -> Generator[str, None, None]:
             eos_token_id=tokenizer.eos_token_id,
             pad_token_id=tokenizer.pad_token_id,
             use_cache=False,
-            num_blocks=1,
-            block_size=1024,
-            do_sample=False,
-            max_batch_tokens=10,
Comment on lines -777 to -780:

> Love that!

> thx @remi-or 😄

> Happy to help 🤗 I just want to point out that while
+            scheduler="fifo",
         )

@@ -798,34 +830,30 @@ def continuous_batching_chat_completion(self, req: dict) -> Generator[str, None, None]:

         def stream_chat_completion(_inputs):
             try:
+                decode_stream = DecodeStream([id.item() for id in _inputs], False)

                 request_id = self.running_continuous_batching_manager.add_request(
                     _inputs, request_id=req.get("request_id"), max_new_tokens=generation_config.max_new_tokens
                 )

-                queue_is_flushed = False
-
                 # Emit the assistant role to start the stream. Other chunks won't have a role, as it is implicit
                 # they come from the assistant.
                 yield self.build_chat_completion_chunk(request_id, role="assistant", model=model_id_and_revision)

-                for result in self.running_continuous_batching_manager:
-                    if result.request_id != request_id:
-                        continue
-                    if req.get("request_id") is not None and not queue_is_flushed:
-                        if result.status == RequestStatus.FINISHED:
-                            continue
-                        else:
-                            queue_is_flushed = True
-
-                    finish_reason = "stop" if result.status == RequestStatus.FINISHED else None
+                for result in self.running_continuous_batching_manager.request_id_iter(request_id):
                     if result.status == RequestStatus.FINISHED:
                         yield self.build_chat_completion_chunk(
-                            request_id, finish_reason=finish_reason, model=model_id_and_revision
+                            request_id,
+                            finish_reason="stop",
+                            model=model_id_and_revision,
                         )
                         break
                     else:
                         yield self.build_chat_completion_chunk(
-                            request_id=request_id, content=result.next_token, model=model_id_and_revision
+                            request_id=request_id,
+                            content=result.generated_tokens[-1],
+                            model=model_id_and_revision,
+                            decode_stream=decode_stream,
+                            tokenizer=tokenizer,
                         )

             except Exception as e:
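Because each chunk is serialized as an OpenAI-compatible `data: {...}\n\n` event, the rewritten streaming path can be exercised with any OpenAI client. A hedged sketch of such a client; it assumes `transformers serve` is listening on localhost:8000 (adjust to the host and port you started the server with) and that the model id shown is one the server can load:

```python
from openai import OpenAI

# api_key is required by the client but not checked by a local server.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="unused")

stream = client.chat.completions.create(
    model="Qwen/Qwen2.5-0.5B-Instruct",  # illustrative model id
    messages=[{"role": "user", "content": "Say hello in one sentence."}],
    stream=True,
)

for chunk in stream:
    # Each iteration corresponds to one `data: ...` chunk built server-side.
    if chunk.choices and chunk.choices[0].delta.content:
        print(chunk.choices[0].delta.content, end="", flush=True)
print()
```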