Decrease latency: do not run pre and postprocess in threadpool (#7796)
* revert

* add changeset

* lint

* explicit call

---------

Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
freddyaboulton and gradio-pr-bot committed Mar 22, 2024
1 parent d831040 commit aad209f
Showing 2 changed files with 14 additions and 41 deletions.
.changeset/old-dolls-pump.md (5 additions, 0 deletions)
@@ -0,0 +1,5 @@
+---
+"gradio": patch
+---
+
+feat:Decrease latency: do not run pre and postprocess in threadpool
gradio/blocks.py (9 additions, 41 deletions)
@@ -1603,15 +1603,6 @@ def handle_streaming_diffs(
 
         return data
 
-    def run_fn_batch(self, fn, batch, fn_index, state, explicit_call=None):
-        output = []
-        for i in zip(*batch):
-            args = [fn_index, list(i), state]
-            if explicit_call is not None:
-                args.append(explicit_call)
-            output.append(fn(*args))
-        return output
-
     async def process_api(
         self,
         fn_index: int,
@@ -1662,15 +1653,10 @@ async def process_api(
                 raise ValueError(
                     f"Batch size ({batch_size}) exceeds the max_batch_size for this function ({max_batch_size})"
                 )
-            inputs = await anyio.to_thread.run_sync(
-                self.run_fn_batch,
-                self.preprocess_data,
-                inputs,
-                fn_index,
-                state,
-                explicit_call,
-                limiter=self.limiter,
-            )
+            inputs = [
+                self.preprocess_data(fn_index, list(i), state, explicit_call)
+                for i in zip(*inputs)
+            ]
             result = await self.call_function(
                 fn_index,
                 list(zip(*inputs)),
Expand All @@ -1681,14 +1667,9 @@ async def process_api(
in_event_listener,
)
preds = result["prediction"]
data = await anyio.to_thread.run_sync(
self.run_fn_batch,
self.postprocess_data,
preds,
fn_index,
state,
limiter=self.limiter,
)
data = [
self.postprocess_data(fn_index, list(o), state) for o in zip(*preds)
]
if root_path is not None:
data = processing_utils.add_root_url(data, root_path, None)
data = list(zip(*data))
@@ -1698,14 +1679,7 @@
             if old_iterator:
                 inputs = []
             else:
-                inputs = await anyio.to_thread.run_sync(
-                    self.preprocess_data,
-                    fn_index,
-                    inputs,
-                    state,
-                    explicit_call,
-                    limiter=self.limiter,
-                )
+                inputs = self.preprocess_data(fn_index, inputs, state, explicit_call)
             was_generating = old_iterator is not None
             result = await self.call_function(
                 fn_index,
@@ -1716,13 +1690,7 @@
                 event_data,
                 in_event_listener,
             )
-            data = await anyio.to_thread.run_sync(
-                self.postprocess_data,
-                fn_index,  # type: ignore
-                result["prediction"],
-                state,
-                limiter=self.limiter,
-            )
+            data = self.postprocess_data(fn_index, result["prediction"], state)
             if root_path is not None:
                 data = processing_utils.add_root_url(data, root_path, None)
             is_generating, iterator = result["is_generating"], result["iterator"]
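
To make the intent of the diff concrete, here is a minimal standalone sketch of the pattern being changed (illustration only, not Gradio's code; the preprocess helper and the one-shot timing below are hypothetical). Before this commit, the synchronous pre- and postprocessing helpers were handed to a worker thread with anyio.to_thread.run_sync, which costs a thread-pool dispatch and round trip on every request; after it, they are called directly on the event loop, which is cheaper as long as the work is fast enough not to block other requests.

import time

import anyio


def preprocess(x):
    # Hypothetical stand-in for a cheap synchronous step such as
    # Blocks.preprocess_data or Blocks.postprocess_data.
    return x * 2


async def with_threadpool(x):
    # Old pattern: offload the sync call to a worker thread.
    # Every call pays for scheduling onto the thread pool and back.
    return await anyio.to_thread.run_sync(preprocess, x)


async def direct(x):
    # New pattern: run the sync call inline on the event loop.
    # Lower latency, but only appropriate when the call is cheap,
    # since a slow call here would block all concurrent requests.
    return preprocess(x)


async def main():
    t0 = time.perf_counter()
    await with_threadpool(21)
    t1 = time.perf_counter()
    await direct(21)
    t2 = time.perf_counter()
    # Single-shot timings are noisy; they only show that the thread hop has a cost.
    print(f"threadpool: {t1 - t0:.6f}s, direct: {t2 - t1:.6f}s")


anyio.run(main)

In the batched branch, the same direct call is wrapped in a list comprehension over zip(*inputs), which transposes the per-component batches into per-sample argument tuples and so replaces the removed run_fn_batch helper.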
