80 changes: 54 additions & 26 deletions src/transformers/commands/serving.py
@@ -25,13 +25,15 @@
import time
from argparse import ArgumentParser, Namespace
from collections.abc import Generator, Iterable
from contextlib import asynccontextmanager
from dataclasses import dataclass, field
from io import BytesIO
from threading import Thread
from typing import Optional, Union

from huggingface_hub import model_info
from huggingface_hub.constants import HF_HUB_OFFLINE
from tokenizers.decoders import DecodeStream

import transformers
from transformers.models.auto.modeling_auto import (
@@ -313,16 +315,16 @@ def __init__(
self._name_or_path = str(model.name_or_path)
self.processor = processor
self.timeout_seconds = timeout_seconds
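# Arm an inactivity timer: when it fires, timeout_reached() frees the wrapped model and processor;
# reset_timer() below re-arms it.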
self._timer = threading.Timer(self.timeout_seconds, self._delete_model)
self._timer = threading.Timer(self.timeout_seconds, self.timeout_reached)
self._timer.start()

def reset_timer(self):
"""Reset the timer for the deletion of the instances."""
self._timer.cancel()
self._timer = threading.Timer(self.timeout_seconds, self._delete_model)
self._timer = threading.Timer(self.timeout_seconds, self.timeout_reached)
self._timer.start()

def _delete_model(self):
def delete_model(self):
"""Delete the wrapped model and processor and clean up resources."""
if hasattr(self, "model") and self.model is not None:
del self.model
@@ -335,9 +337,12 @@ def _delete_model(self):
if torch.cuda.is_available():
torch.cuda.empty_cache()

logger.info(
f"{self._name_or_path} was removed from memory after {self.timeout_seconds} seconds of inactivity"
)
# XXX: also cancel the pending timer, in case the model is deleted manually, e.g. on server shutdown
self._timer.cancel()

def timeout_reached(self):
self.delete_model()
logger.info(f"{self._name_or_path} was removed from memory after {self.timeout_seconds} seconds of inactivity")

def is_deleted(self):
"""Check if the instances have been deleted."""
@@ -353,6 +358,10 @@ class ServeArguments:
`transformers serve --help`
"""

continuous_batching: bool = field(
default=False,
metadata={"help": "Whether to use continuous batching for chat completions."},
)
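# Presumably surfaced on the CLI as `--continuous_batching` (flag name inferred from this field; not shown in this diff).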
device: str = field(
default="auto",
metadata={
@@ -469,7 +478,7 @@ def __init__(self, args: ServeArguments):

# Store and process input arguments
self.args = args
self.use_continuous_batching = self.args.attn_implementation == "sdpa_paged"
self.use_continuous_batching = self.args.continuous_batching
self.enable_cors = self.args.enable_cors

if self.args.default_seed is not None:
@@ -569,11 +578,13 @@ def validate_transcription_request(self, request: dict):
def build_chat_completion_chunk(
self,
request_id: Optional[str] = "",
content: Optional[str] = None,
content: Optional[int] = None,
model: Optional[str] = None,
role: Optional[str] = None,
finish_reason: Optional[str] = None,
tool_calls: Optional[list["ChoiceDeltaToolCall"]] = None,
decode_stream: Optional[DecodeStream] = None,
tokenizer: Optional[PreTrainedTokenizerFast] = None,
) -> str:
"""
Builds a chunk of a streaming OpenAI Chat Completion response.
@@ -598,6 +609,8 @@ def build_chat_completion_chunk(
Returns:
`str`: The built chunk, a string containing a JSON string with the payload.
"""
if decode_stream is not None and content is not None and tokenizer is not None:
content = decode_stream.step(tokenizer._tokenizer, content)
chunk = ChatCompletionChunk(
id=request_id,
created=int(time.time()),
@@ -635,7 +648,29 @@ def build_response_event(self, response: "BaseModel") -> str:
return f"data: {response.model_dump_json(exclude_none=True)}\n\n"

def run(self):
app = FastAPI()
"""
Set up and run the FastAPI server for `transformers serve`.

Models will be loaded and unloaded automatically based on usage and a timeout.

The server will expose the following endpoints:
- POST /v1/chat/completions: Generates chat completions.
- POST /v1/responses: Generates responses.
- POST /v1/audio/transcriptions: Generates transcriptions from audio.
- GET /v1/models: Lists available models for 3rd party tools.

Requires FastAPI and Uvicorn to be installed.
"""

@asynccontextmanager
async def lifespan(app: FastAPI):
yield
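# Everything after the yield runs on server shutdown: free any cached models and stop the
# continuous batching manager if one is still running.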
for model in self.loaded_models.values():
model.delete_model()
if self.running_continuous_batching_manager is not None:
self.running_continuous_batching_manager.stop(block=True, timeout=5)

app = FastAPI(lifespan=lifespan)

# Some apps that make requests from external domains (e.g. Cursor) require CORS to be enabled. However, for
# security purposes, it's disabled by default
@@ -774,10 +809,7 @@ def continuous_batching_chat_completion(self, req: dict) -> Generator[str, None,
eos_token_id=tokenizer.eos_token_id,
pad_token_id=tokenizer.pad_token_id,
use_cache=False,
num_blocks=1,
block_size=1024,
do_sample=False,
max_batch_tokens=10,
Comment on lines -777 to -780

Member:
Love that!

Member Author:
thx @remi-or 😄

Collaborator:
Happy to help 🤗 I just want to point out that while num_blocks and max_batch_tokens can be inferred from available GPU memory, if block_size is not given it simply defaults to 32, which is quite far from the previous 1024 here. Might not be important though!

scheduler="fifo",
)
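A minimal sketch of the reviewer's suggestion, assuming the collapsed lines above construct the GenerationConfig consumed by continuous batching (the constructor and variable names here are assumptions): passing block_size explicitly preserves the previous page size instead of falling back to the default of 32.

from transformers import GenerationConfig

generation_config = GenerationConfig(
    use_cache=False,
    block_size=1024,  # previously passed explicitly here; per the review comment it defaults to 32 when omitted
    scheduler="fifo",
)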

@@ -798,34 +830,30 @@ def continuous_batching_chat_completion(self, req: dict) -> Generator[str, None,

def stream_chat_completion(_inputs):
try:
decode_stream = DecodeStream([id.item() for id in _inputs], False)
request_id = self.running_continuous_batching_manager.add_request(
_inputs, request_id=req.get("request_id"), max_new_tokens=generation_config.max_new_tokens
)

queue_is_flushed = False

# Emit the assistant role to start the stream. Other chunks won't have a role, as it is implicit
# they come from the assistant.
yield self.build_chat_completion_chunk(request_id, role="assistant", model=model_id_and_revision)

for result in self.running_continuous_batching_manager:
if result.request_id != request_id:
continue
if req.get("request_id") is not None and not queue_is_flushed:
if result.status == RequestStatus.FINISHED:
continue
else:
queue_is_flushed = True

finish_reason = "stop" if result.status == RequestStatus.FINISHED else None
for result in self.running_continuous_batching_manager.request_id_iter(request_id):
if result.status == RequestStatus.FINISHED:
yield self.build_chat_completion_chunk(
request_id, finish_reason=finish_reason, model=model_id_and_revision
request_id,
finish_reason="stop",
model=model_id_and_revision,
)
break
else:
yield self.build_chat_completion_chunk(
request_id=request_id, content=result.next_token, model=model_id_and_revision
request_id=request_id,
content=result.generated_tokens[-1],
model=model_id_and_revision,
decode_stream=decode_stream,
tokenizer=tokenizer,
)

except Exception as e:
5 changes: 0 additions & 5 deletions src/transformers/generation/continuous_batching/classes.py
@@ -75,7 +75,6 @@ class GenerationOutput:
error (Optional[str]): Any error message associated with the request. When None, the request was successful.
status (RequestStatus): The status of the request.
created_time (float): The time the request was created.
next_token (Optional[int]): The next token to be generated.
"""

request_id: str
@@ -85,7 +84,6 @@
error: Optional[str] = None
status: RequestStatus = RequestStatus.PENDING
created_time: float = field(default_factory=time.time)
next_token: Optional[int] = field(default_factory=int)


@dataclass
@@ -106,7 +104,6 @@ class RequestState:
eos_token_id (int): The ID of the end-of-sequence token.
created_time (float): The time the request was created.
error (Optional[str]): Any error message associated with the request. When None, has had no error yet.
next_token (Optional[str]): The next token to be generated.
"""

# Required fields
@@ -122,7 +119,6 @@ class RequestState:
eos_token_id: int = -1 # ID of the end-of-sequence token
created_time: float = field(default_factory=time.time) # Time the request was created
error: Optional[str] = None # Error message if the request failed
next_token: Optional[str] = None # Next token to be generated
lifespan: tuple[float, float] = (-1, -1) # (time request was no longer pending, time request finished)

@property
@@ -206,5 +202,4 @@ def to_generation_output(self):
generated_tokens=self.static_outputs,
logprobs=[],
error=self.error,
next_token=self.next_token,
)
47 changes: 25 additions & 22 deletions src/transformers/generation/continuous_batching/continuous_api.py
@@ -20,13 +20,11 @@
from typing import Optional

import torch
from tokenizers.decoders import DecodeStream
from torch import nn
from tqdm import tqdm

from ...configuration_utils import PretrainedConfig
from ...generation.configuration_utils import GenerationConfig
from ...tokenization_utils_fast import PreTrainedTokenizerFast
from ...utils.logging import logging
from ...utils.metrics import ContinuousBatchProcessorMetrics, attach_tracer, traced
from .cache import PagedAttentionCache
@@ -102,9 +100,6 @@ def __init__(

self.setup_static_tensors()

self.tokenizer = PreTrainedTokenizerFast.from_pretrained(self.config._name_or_path)
self.decode_stream = DecodeStream(skip_special_tokens=True)

def return_attention_mask(self) -> bool:
return self.config._attn_implementation != "paged_attention" # we set `is_causal` to True in paged call

@@ -227,18 +222,18 @@ def _handle_request_error(self, error, state: RequestState):
self.output_queue.put(state.to_generation_output())

@traced
def prepare_next_batch(self):
def prepare_next_batch(self) -> bool:
"""Prepare tensors and metadata for the next model forward pass."""
# Get new requests from the queue
self._get_new_requests()
if not self.scheduler.has_pending_requests():
return None
return False

self.metrics.record_queue_metrics(len(self.scheduler.active_requests), len(self.scheduler.waiting_requests))

self.requests_in_batch = self.scheduler.schedule_batch(self.max_batch_tokens)
if not self.requests_in_batch:
return None
return False

# Get the request objects for this batch
self.reset_static_tensors()
@@ -291,6 +286,8 @@ def prepare_next_batch(self):

self.metrics.record_kv_cache_memory_metrics(self.cache)

return True

@traced
def _build_tensors(
self,
@@ -357,7 +354,6 @@ def _sync(self):
def _maybe_send_output(self, state: RequestState, token: int):
"""Send output to the queue based on streaming mode and request state."""
if self.streaming:
state.next_token = self.decode_stream.step(self.tokenizer, state.static_outputs[-1])
self.output_queue.put(state.to_generation_output())
elif state.status == RequestStatus.FINISHED:
self.output_queue.put(state.to_generation_output())
@@ -463,7 +459,6 @@ def __init__(
self.profile = getattr(generation_config, "profile", False)
self.manual_eviction = manual_eviction
self.batch_processor: Optional[ContinuousBatchProcessor] = None
self.decode_stream = DecodeStream(skip_special_tokens=True)
self.slice_inputs = slice_inputs

@traced
@@ -534,6 +529,7 @@ def add_request(

max_new_tokens = self.generation_config.max_new_tokens if max_new_tokens is None else max_new_tokens

# NOTE: do we want to handle the case where the user wants token ids returned instead of decoded text?
state = RequestState(
request_id=request_id,
prompt_ids=list(input_ids),
@@ -548,35 +544,41 @@ def add_request(
return request_id

def add_requests(self, inputs: list[list[int]], **kwargs):
for i, input_ids in enumerate(inputs):
# Assign a predictable request ID for ordering results later
req_id = f"batch_req_{i}"
self.add_request(input_ids, request_id=req_id, **kwargs)
for input_ids in inputs:
self.add_request(input_ids, **kwargs)

def get_result(self, timeout=None) -> Optional[GenerationOutput]:
def get_result(self, request_id=None, timeout=None) -> Optional[GenerationOutput]:
"""Retrieve one result from the output queue.

Args:
request_id: If set, only return a result for this request id; non-matching results are re-queued
timeout: Maximum time to wait for a result

Returns:
Optional[Dict]: The result data or None if timeout
Optional[GenerationOutput]: The result data or None if timeout
"""
if self._generation_thread is None and self.output_queue.empty():
return None
try:
result = self.output_queue.get(block=True, timeout=timeout)
if request_id is not None and result.request_id != request_id:
self.output_queue.put(result)
return None
logger.debug(f"Retrieved result for request {result.request_id}")
return result
except queue.Empty:
return None

def __iter__(self):
"""Iterate over results as they become available."""
while (
self._generation_thread is not None and self._generation_thread.is_alive() or not self.output_queue.empty()
):
result = self.get_result(timeout=0.1) # allow the model to run for 10 seconds
while self._generation_thread is not None and self._generation_thread.is_alive():
result = self.get_result(timeout=0.1)
if result is not None:
yield result

def request_id_iter(self, request_id):
"""Iterate over results matching a specific request id as they become available."""
while self._generation_thread is not None and self._generation_thread.is_alive():
result = self.get_result(request_id=request_id, timeout=0.1)
if result is not None:
yield result
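A rough usage sketch for request_id_iter; everything outside this diff (the placeholder model id, init_continuous_batching, start, the streaming kwarg, and the RequestStatus import path) is an assumption based on how serving.py drives the manager:

from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
from transformers.generation.continuous_batching.classes import RequestStatus  # import path assumed

tok = AutoTokenizer.from_pretrained("gpt2")  # placeholder model id
model = AutoModelForCausalLM.from_pretrained("gpt2")

manager = model.init_continuous_batching(
    generation_config=GenerationConfig(max_new_tokens=32), streaming=True  # entry point and kwarg assumed
)
manager.start()

request_id = manager.add_request(tok("Hello there").input_ids)
for output in manager.request_id_iter(request_id):
    if output.status == RequestStatus.FINISHED:
        break
    print(output.generated_tokens[-1])  # newest token id produced for this request

manager.stop(block=True, timeout=5)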

Expand Down Expand Up @@ -637,6 +639,7 @@ def _run_generation_loop(self):
self.generation_config,
self.model.device,
self.model.dtype,
# FIXME: this is unused, why was it added?
num_requests=len(self.input_queue.queue),
tp_size=getattr(self.model, "_tp_size", None), # Use model's actual TP setting
)
@@ -681,7 +684,8 @@ def _inner_generation_loop(self, batch_processor: ContinuousBatchProcessor):
def _inner_generation_loop(self, batch_processor: ContinuousBatchProcessor):
if torch.cuda.is_available():
torch.cuda.synchronize()
batch_processor.prepare_next_batch()
if not batch_processor.prepare_next_batch():
return
device, total, reserved, allocated = get_device_and_memory_breakdown()
logger.debug(f"[Memory] Device: {device}, Total: {total}, Reserved: {reserved}, Allocated: {allocated}")
if torch.cuda.is_available() and self.use_cuda_graph:
@@ -829,7 +833,6 @@ def generate_batch(
results[req_id] = result
finished_count += 1
pbar.update(1)
logger.debug(manager.batch_processor.tokenizer.decode(result.generated_tokens))
else:
if not manager.is_running():
logger.error("Generation thread terminated unexpectedly.")