diff --git a/.changeset/young-ducks-scream.md b/.changeset/young-ducks-scream.md new file mode 100644 index 000000000000..58f7cdc13e8b --- /dev/null +++ b/.changeset/young-ducks-scream.md @@ -0,0 +1,5 @@ +--- +"gradio": patch +--- + +feat:Benchmark fix test diff --git a/gradio/queueing.py b/gradio/queueing.py index bf6c87d6b189..6e3fbf9f9278 100644 --- a/gradio/queueing.py +++ b/gradio/queueing.py @@ -2,7 +2,6 @@ import asyncio import copy -import json import os import random import time @@ -461,74 +460,6 @@ def get_status(self) -> EstimationMessage: queue_size=len(self), ) - async def call_prediction(self, events: list[Event], batch: bool) -> dict: - body = events[0].data - if body is None: - raise ValueError("No event data") - username = events[0].username - body.event_id = events[0]._id if not batch else None - try: - body.request = events[0].request - except ValueError: - pass - - if batch: - body.data = list(zip(*[event.data.data for event in events if event.data])) - body.request = events[0].request - body.batched = True - - app = self.server_app - if app is None: - raise Exception("Server app has not been set.") - api_name = "predict" - - fn_index_inferred = route_utils.infer_fn_index( - app=app, api_name=api_name, body=body - ) - - gr_request = route_utils.compile_gr_request( - app=app, - body=body, - fn_index_inferred=fn_index_inferred, - username=username, - request=None, - ) - assert body.request is not None # noqa: S101 - root_path = route_utils.get_root_url( - request=body.request, route_path="/queue/join", root_path=app.root_path - ) - try: - output = await route_utils.call_process_api( - app=app, - body=body, - gr_request=gr_request, - fn_index_inferred=fn_index_inferred, - root_path=root_path, - ) - except Exception as error: - show_error = app.get_blocks().show_error or isinstance(error, Error) - traceback.print_exc() - raise Exception(str(error) if show_error else None) from error - - # To emulate the HTTP response from the predict API, - # convert the output to a JSON response string. - # This is done by FastAPI automatically in the HTTP endpoint handlers, - # but we need to do it manually here. - response_class = app.router.default_response_class - if isinstance(response_class, fastapi.datastructures.DefaultPlaceholder): - actual_response_class = response_class.value - else: - actual_response_class = response_class - http_response = actual_response_class( - output - ) # Do the same as https://github.com/tiangolo/fastapi/blob/0.87.0/fastapi/routing.py#L264 - # Also, decode the JSON string to a Python object, emulating the HTTP client behavior e.g. the `json()` method of `httpx`. - response_json = json.loads(http_response.body.decode()) - if not isinstance(response_json, dict): - raise ValueError("Unexpected object.") - - return response_json - async def process_events( self, events: list[Event], batch: bool, begin_time: float ) -> None: @@ -548,21 +479,64 @@ async def process_events( awake_events.append(event) if not awake_events: return + + events = awake_events + body = events[0].data + if body is None: + raise ValueError("No event data") + username = events[0].username + body.event_id = events[0]._id if not batch else None + try: + body.request = events[0].request + except ValueError: + pass + + if batch: + body.data = list( + zip(*[event.data.data for event in events if event.data]) + ) + body.request = events[0].request + body.batched = True + + app = self.server_app + if app is None: + raise Exception("Server app has not been set.") + api_name = "predict" + + fn_index_inferred = route_utils.infer_fn_index( + app=app, api_name=api_name, body=body + ) + + gr_request = route_utils.compile_gr_request( + app=app, + body=body, + fn_index_inferred=fn_index_inferred, + username=username, + request=None, + ) + assert body.request is not None # noqa: S101 + root_path = route_utils.get_root_url( + request=body.request, route_path="/queue/join", root_path=app.root_path + ) try: - response = await self.call_prediction(awake_events, batch) + response = await route_utils.call_process_api( + app=app, + body=body, + gr_request=gr_request, + fn_index_inferred=fn_index_inferred, + root_path=root_path, + ) err = None except Exception as e: + show_error = app.get_blocks().show_error or isinstance(e, Error) + traceback.print_exc() response = None err = e for event in awake_events: self.send_message( event, ProcessCompletedMessage( - output={ - "error": None - if len(e.args) and e.args[0] is None - else str(e) - }, + output={"error": str(e) if show_error else None}, success=False, ), ) @@ -584,25 +558,31 @@ async def process_events( if not awake_events: return try: - response = await self.call_prediction(awake_events, batch) - err = None + response = await route_utils.call_process_api( + app=app, + body=body, + gr_request=gr_request, + fn_index_inferred=fn_index_inferred, + root_path=root_path, + ) except Exception as e: + traceback.print_exc() response = None err = e + + if response: + success = True + output = response + else: + success = False + error = err or old_err + show_error = app.get_blocks().show_error or isinstance(error, Error) + output = {"error": str(error) if show_error else None} for event in awake_events: - relevant_response = response or err or old_err self.send_message( - event, - ProcessCompletedMessage( - output={"error": str(relevant_response)} - if isinstance(relevant_response, Exception) - else relevant_response or {}, - success=( - relevant_response is not None - and not isinstance(relevant_response, Exception) - ), - ), + event, ProcessCompletedMessage(output=output, success=success) ) + elif response: output = copy.deepcopy(response) for e, event in enumerate(awake_events): diff --git a/gradio/routes.py b/gradio/routes.py index 77b6c374adce..9bdf23e03e44 100644 --- a/gradio/routes.py +++ b/gradio/routes.py @@ -720,7 +720,7 @@ async def queue_data( session_hash: str, ): def process_msg(message: EventMessage) -> str: - return f"data: {json.dumps(message.model_dump())}\n\n" + return f"data: {orjson.dumps(message.model_dump()).decode('utf-8')}\n\n" return await queue_data_helper(request, session_hash, process_msg)