39 changes: 24 additions & 15 deletions src/transformers/commands/serving.py
@@ -24,6 +24,7 @@
import tempfile
import threading
import time
import uuid
from argparse import ArgumentParser, Namespace
from collections.abc import AsyncGenerator, Generator, Iterable
from contextlib import asynccontextmanager
@@ -132,7 +133,6 @@ class TransformersCompletionCreateParamsStreaming(CompletionCreateParamsStreaming
"""

generation_config: str
request_id: str

class TransformersTranscriptionCreateParams(TranscriptionCreateParamsBase, total=False):
"""
@@ -211,6 +211,8 @@ class TransformersTranscriptionCreateParams(TranscriptionCreateParamsBase, total=False):
}
_MODELS_WITH_TOOL_SUPPORT = list(_TOOL_CALL_TOKENS.keys())

X_REQUEST_ID = "x-request-id"


class Modality(enum.Enum):
LLM = "LLM"
@@ -688,14 +690,16 @@ async def lifespan(app: FastAPI):
"CORS allow origin is set to `*`. This is not recommended for production environments."
)

from fastapi import Request

@app.post("/v1/chat/completions")
def chat_completion(request: dict):
self.validate_chat_completion_request(request=request)
def chat_completion(request: Request, body: dict):
self.validate_chat_completion_request(request=body)

if self.use_continuous_batching:
output = self.continuous_batching_chat_completion(request)
output = self.continuous_batching_chat_completion(body, request.state.request_id)
else:
output = self.generate_chat_completion(request)
output = self.generate_chat_completion(body)
return StreamingResponse(output, media_type="text/event-stream")

@app.post("/v1/responses")
@@ -705,8 +709,6 @@ def responses(request: dict):
output = self.generate_response(request)
return StreamingResponse(output, media_type="text/event-stream")

from fastapi import Request

@app.post("/v1/audio/transcriptions")
async def audio_transcriptions(request: Request):
# Parses the multipart/form-data request into the request format used by other endpoints
@@ -734,6 +736,14 @@ def get_all_models():
def healthcheck():
return JSONResponse({"status": "ok"})

@app.middleware("http")
async def get_or_set_request_id(request: Request, call_next):
request_id = request.headers.get(X_REQUEST_ID) or str(uuid.uuid4())
request.state.request_id = request_id
response = await call_next(request)
response.headers[X_REQUEST_ID] = request_id
return response

uvicorn.run(app, host=self.args.host, port=self.args.port, log_level=self.args.log_level)

@functools.cache
@@ -782,7 +792,7 @@ def get_gen_models(self) -> list[dict[str, any]]:
for model in model_infos
]

def continuous_batching_chat_completion(self, req: dict) -> AsyncGenerator[str, None]:
def continuous_batching_chat_completion(self, req: dict, request_id: str) -> AsyncGenerator[str, None]:
"""
Generates an OpenAI Chat Completion using continuous batching.

@@ -858,22 +868,21 @@ def stream_chat_completion(request_id, decode_stream):
self.running_continuous_batching_manager.cancel_request(request_id)
yield f'data: {{"error": "{str(e)}"}}'

async def cancellation_wrapper(_inputs):
request_id = None
async def cancellation_wrapper(_inputs, request_id):
try:
decode_stream = DecodeStream(_inputs.tolist(), False)
# XXX: using returned request_id as safety in case it is None
request_id = self.running_continuous_batching_manager.add_request(
_inputs, request_id=req.get("request_id"), max_new_tokens=generation_config.max_new_tokens
_inputs, request_id=request_id, max_new_tokens=generation_config.max_new_tokens
)
for chunk in stream_chat_completion(request_id, decode_stream):
yield chunk
await asyncio.sleep(0) # Yield control to the event loop to check for cancellations
except asyncio.CancelledError:
if request_id is not None:
self.running_continuous_batching_manager.cancel_request(request_id)
logger.warning(f"Request {request_id} was cancelled.")
self.running_continuous_batching_manager.cancel_request(request_id)
logger.warning(f"Request {request_id} was cancelled.")

return cancellation_wrapper(inputs[0])
return cancellation_wrapper(inputs[0], request_id)

@staticmethod
def get_model_modality(model: "PreTrainedModel") -> Modality:
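Taken together, the serving.py changes let a caller correlate a streamed chat completion with its own identifier: the middleware reads the X-Request-ID header (or generates a UUID when the header is absent), stores it on request.state, hands it to the continuous-batching manager, and echoes it back on the response. Below is a minimal client-side sketch, assuming a transformers serve instance running locally; the host, port, and prompt are illustrative and not part of this diff.

import uuid

import requests

base_url = "http://127.0.0.1:8000"  # assumed address of the running server
request_id = str(uuid.uuid4())      # client-chosen id, reused by the server

with requests.post(
    f"{base_url}/v1/chat/completions",
    headers={"X-Request-ID": request_id},  # picked up by the middleware
    json={
        "model": "Qwen/Qwen2.5-0.5B-Instruct",
        "stream": True,
        "messages": [{"role": "user", "content": "Hello!"}],
    },
    stream=True,
    timeout=30,
) as resp:
    # The same id comes back in the response headers, whether it was
    # supplied by the client or generated server-side.
    assert resp.headers["x-request-id"] == request_id
    for _chunk in resp.iter_content(chunk_size=None):
        pass

If the header is omitted, the middleware falls back to a freshly generated UUID and still returns it in the x-request-id response header, so client and server logs can be correlated either way.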
40 changes: 26 additions & 14 deletions tests/commands/test_serving.py
@@ -498,30 +498,45 @@ def _get_scheduler(serve_command):
cbm = getattr(serve_command, "running_continuous_batching_manager", None)
assert cbm is not None, "ServeCommand has no running_continuous_batching_manager"
bp = getattr(cbm, "batch_processor", None)
assert bp is not None, "CBM has no batch_processor"
assert bp is not None, "running_continuous_batching_manager has no batch_processor"
sched = getattr(bp, "scheduler", None)
assert sched is not None, "batch_processor has no scheduler"
return sched


def _call_healthcheck(base_url: str):
response = None
retries = 10
while retries > 0:
try:
response = requests.get(f"{base_url}/health")
break
except requests.exceptions.ConnectionError:
time.sleep(0.1)
retries -= 1
return response


def _open_stream_and_cancel(base_url: str, request_id: str):
with requests.Session() as s:
with s.post(
f"{base_url}/v1/chat/completions",
headers={"X-Request-ID": request_id},
json={
"model": "Qwen/Qwen2.5-0.5B-Instruct",
"stream": True,
"messages": [{"role": "user", "content": "Count slowly so I can cancel you."}],
"request_id": request_id,
},
stream=True,
timeout=30,
) as resp:
assert resp.status_code == 200

for _ in resp.iter_content(chunk_size=None):
resp.close()
break
wait_for_n_chunks = 3
for i, _ in enumerate(resp.iter_content(chunk_size=None)):
if i >= wait_for_n_chunks:
resp.close()
break


@slow # server startup time is slow on our push CI
@@ -598,6 +613,11 @@ def test_request_cancellation(self):
base_url = f"http://127.0.0.1:{self.port}"
request_id = "test-cancel"

# Ensure the server is up before sending a request
response = _call_healthcheck(base_url)
self.assertIsNotNone(response, "Failed to connect to the server health endpoint.")
self.assertEqual(response.status_code, 200)

_open_stream_and_cancel(base_url, request_id)

scheduler = _get_scheduler(self.serve_command)
@@ -724,15 +744,7 @@ def setUpClass(cls):

def test_healthcheck(self):
"""Tests that the healthcheck endpoint works."""
response = None
retries = 10
while retries > 0:
try:
response = requests.get(f"http://localhost:{self.port}/health")
break
except requests.exceptions.ConnectionError:
time.sleep(0.1)
retries -= 1
response = _call_healthcheck(f"http://localhost:{self.port}")
self.assertIsNotNone(response, "Failed to connect to the server health endpoint.")
self.assertEqual(response.status_code, 200)
self.assertEqual(response.json(), {"status": "ok"})