From aad209f0c0faa0bf3e39d0c8624f972118b32830 Mon Sep 17 00:00:00 2001 From: Freddy Boulton Date: Fri, 22 Mar 2024 12:38:52 -0700 Subject: [PATCH] Decrease latency: do not run pre and postprocess in threadpool (#7796) * revert * add changeset * lint * explicit call --------- Co-authored-by: gradio-pr-bot --- .changeset/old-dolls-pump.md | 5 ++++ gradio/blocks.py | 50 +++++++----------------------------- 2 files changed, 14 insertions(+), 41 deletions(-) create mode 100644 .changeset/old-dolls-pump.md diff --git a/.changeset/old-dolls-pump.md b/.changeset/old-dolls-pump.md new file mode 100644 index 000000000000..aaef531db090 --- /dev/null +++ b/.changeset/old-dolls-pump.md @@ -0,0 +1,5 @@ +--- +"gradio": patch +--- + +feat:Decrease latency: do not run pre and postprocess in threadpool diff --git a/gradio/blocks.py b/gradio/blocks.py index 1d04be05cd8d..897641fdb5f2 100644 --- a/gradio/blocks.py +++ b/gradio/blocks.py @@ -1603,15 +1603,6 @@ def handle_streaming_diffs( return data - def run_fn_batch(self, fn, batch, fn_index, state, explicit_call=None): - output = [] - for i in zip(*batch): - args = [fn_index, list(i), state] - if explicit_call is not None: - args.append(explicit_call) - output.append(fn(*args)) - return output - async def process_api( self, fn_index: int, @@ -1662,15 +1653,10 @@ async def process_api( raise ValueError( f"Batch size ({batch_size}) exceeds the max_batch_size for this function ({max_batch_size})" ) - inputs = await anyio.to_thread.run_sync( - self.run_fn_batch, - self.preprocess_data, - inputs, - fn_index, - state, - explicit_call, - limiter=self.limiter, - ) + inputs = [ + self.preprocess_data(fn_index, list(i), state, explicit_call) + for i in zip(*inputs) + ] result = await self.call_function( fn_index, list(zip(*inputs)), @@ -1681,14 +1667,9 @@ async def process_api( in_event_listener, ) preds = result["prediction"] - data = await anyio.to_thread.run_sync( - self.run_fn_batch, - self.postprocess_data, - preds, - fn_index, - state, - limiter=self.limiter, - ) + data = [ + self.postprocess_data(fn_index, list(o), state) for o in zip(*preds) + ] if root_path is not None: data = processing_utils.add_root_url(data, root_path, None) data = list(zip(*data)) @@ -1698,14 +1679,7 @@ async def process_api( if old_iterator: inputs = [] else: - inputs = await anyio.to_thread.run_sync( - self.preprocess_data, - fn_index, - inputs, - state, - explicit_call, - limiter=self.limiter, - ) + inputs = self.preprocess_data(fn_index, inputs, state, explicit_call) was_generating = old_iterator is not None result = await self.call_function( fn_index, @@ -1716,13 +1690,7 @@ async def process_api( event_data, in_event_listener, ) - data = await anyio.to_thread.run_sync( - self.postprocess_data, - fn_index, # type: ignore - result["prediction"], - state, - limiter=self.limiter, - ) + data = self.postprocess_data(fn_index, result["prediction"], state) if root_path is not None: data = processing_utils.add_root_url(data, root_path, None) is_generating, iterator = result["is_generating"], result["iterator"]