Skip to content

Commit

Permalink
Benchmark fix test (#7810)
Browse files Browse the repository at this point in the history
* changes

* changes

* changes

* chnages

* changes

* add changeset

* changes

* changes

* add changeset

* chnages

* changes

* changes

---------

Co-authored-by: Ali Abid <aliabid94@gmail.com>
Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
  • Loading branch information
3 people committed Mar 27, 2024
1 parent 8d7b3ca commit 425fd1c
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 89 deletions.
5 changes: 5 additions & 0 deletions .changeset/young-ducks-scream.md
@@ -0,0 +1,5 @@
---
"gradio": patch
---

feat:Benchmark fix test
156 changes: 68 additions & 88 deletions gradio/queueing.py
Expand Up @@ -2,7 +2,6 @@

import asyncio
import copy
import json
import os
import random
import time
Expand Down Expand Up @@ -461,74 +460,6 @@ def get_status(self) -> EstimationMessage:
queue_size=len(self),
)

async def call_prediction(self, events: list[Event], batch: bool) -> dict:
body = events[0].data
if body is None:
raise ValueError("No event data")
username = events[0].username
body.event_id = events[0]._id if not batch else None
try:
body.request = events[0].request
except ValueError:
pass

if batch:
body.data = list(zip(*[event.data.data for event in events if event.data]))
body.request = events[0].request
body.batched = True

app = self.server_app
if app is None:
raise Exception("Server app has not been set.")
api_name = "predict"

fn_index_inferred = route_utils.infer_fn_index(
app=app, api_name=api_name, body=body
)

gr_request = route_utils.compile_gr_request(
app=app,
body=body,
fn_index_inferred=fn_index_inferred,
username=username,
request=None,
)
assert body.request is not None # noqa: S101
root_path = route_utils.get_root_url(
request=body.request, route_path="/queue/join", root_path=app.root_path
)
try:
output = await route_utils.call_process_api(
app=app,
body=body,
gr_request=gr_request,
fn_index_inferred=fn_index_inferred,
root_path=root_path,
)
except Exception as error:
show_error = app.get_blocks().show_error or isinstance(error, Error)
traceback.print_exc()
raise Exception(str(error) if show_error else None) from error

# To emulate the HTTP response from the predict API,
# convert the output to a JSON response string.
# This is done by FastAPI automatically in the HTTP endpoint handlers,
# but we need to do it manually here.
response_class = app.router.default_response_class
if isinstance(response_class, fastapi.datastructures.DefaultPlaceholder):
actual_response_class = response_class.value
else:
actual_response_class = response_class
http_response = actual_response_class(
output
) # Do the same as https://github.com/tiangolo/fastapi/blob/0.87.0/fastapi/routing.py#L264
# Also, decode the JSON string to a Python object, emulating the HTTP client behavior e.g. the `json()` method of `httpx`.
response_json = json.loads(http_response.body.decode())
if not isinstance(response_json, dict):
raise ValueError("Unexpected object.")

return response_json

async def process_events(
self, events: list[Event], batch: bool, begin_time: float
) -> None:
Expand All @@ -548,21 +479,64 @@ async def process_events(
awake_events.append(event)
if not awake_events:
return

events = awake_events
body = events[0].data
if body is None:
raise ValueError("No event data")
username = events[0].username
body.event_id = events[0]._id if not batch else None
try:
body.request = events[0].request
except ValueError:
pass

if batch:
body.data = list(
zip(*[event.data.data for event in events if event.data])
)
body.request = events[0].request
body.batched = True

app = self.server_app
if app is None:
raise Exception("Server app has not been set.")
api_name = "predict"

fn_index_inferred = route_utils.infer_fn_index(
app=app, api_name=api_name, body=body
)

gr_request = route_utils.compile_gr_request(
app=app,
body=body,
fn_index_inferred=fn_index_inferred,
username=username,
request=None,
)
assert body.request is not None # noqa: S101
root_path = route_utils.get_root_url(
request=body.request, route_path="/queue/join", root_path=app.root_path
)
try:
response = await self.call_prediction(awake_events, batch)
response = await route_utils.call_process_api(
app=app,
body=body,
gr_request=gr_request,
fn_index_inferred=fn_index_inferred,
root_path=root_path,
)
err = None
except Exception as e:
show_error = app.get_blocks().show_error or isinstance(e, Error)
traceback.print_exc()
response = None
err = e
for event in awake_events:
self.send_message(
event,
ProcessCompletedMessage(
output={
"error": None
if len(e.args) and e.args[0] is None
else str(e)
},
output={"error": str(e) if show_error else None},
success=False,
),
)
Expand All @@ -584,25 +558,31 @@ async def process_events(
if not awake_events:
return
try:
response = await self.call_prediction(awake_events, batch)
err = None
response = await route_utils.call_process_api(
app=app,
body=body,
gr_request=gr_request,
fn_index_inferred=fn_index_inferred,
root_path=root_path,
)
except Exception as e:
traceback.print_exc()
response = None
err = e

if response:
success = True
output = response
else:
success = False
error = err or old_err
show_error = app.get_blocks().show_error or isinstance(error, Error)
output = {"error": str(error) if show_error else None}
for event in awake_events:
relevant_response = response or err or old_err
self.send_message(
event,
ProcessCompletedMessage(
output={"error": str(relevant_response)}
if isinstance(relevant_response, Exception)
else relevant_response or {},
success=(
relevant_response is not None
and not isinstance(relevant_response, Exception)
),
),
event, ProcessCompletedMessage(output=output, success=success)
)

elif response:
output = copy.deepcopy(response)
for e, event in enumerate(awake_events):
Expand Down
2 changes: 1 addition & 1 deletion gradio/routes.py
Expand Up @@ -720,7 +720,7 @@ async def queue_data(
session_hash: str,
):
def process_msg(message: EventMessage) -> str:
return f"data: {json.dumps(message.model_dump())}\n\n"
return f"data: {orjson.dumps(message.model_dump()).decode('utf-8')}\n\n"

return await queue_data_helper(request, session_hash, process_msg)

Expand Down

0 comments on commit 425fd1c

Please sign in to comment.