Skip to content

Commit

Permalink
Stop passing inputs and preprocessing on iterators (#5260)
Browse files Browse the repository at this point in the history
* changes

* add changeset

* changes

* add changeset

* Update blocks.py

* changes

---------

Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
  • Loading branch information
aliabid94 and gradio-pr-bot committed Aug 18, 2023
1 parent 1cefee7 commit a773eaf
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 13 deletions.
5 changes: 5 additions & 0 deletions .changeset/hip-queens-grin.md
@@ -0,0 +1,5 @@
---
"gradio": minor
---

feat:Stop passing inputs and preprocessing on iterators
27 changes: 14 additions & 13 deletions gradio/blocks.py
Expand Up @@ -1080,23 +1080,21 @@ async def call_function(
block_fn = self.fns[fn_index]
assert block_fn.fn, f"function with index {fn_index} not defined."
is_generating = False

if block_fn.inputs_as_dict:
processed_input = [dict(zip(block_fn.inputs, processed_input))]

request = requests[0] if isinstance(requests, list) else requests
processed_input, progress_index, _ = special_args(
block_fn.fn, processed_input, request, event_data
)
progress_tracker = (
processed_input[progress_index] if progress_index is not None else None
)

start = time.time()

fn = utils.get_function_with_locals(block_fn.fn, self, event_id)

if iterator is None: # If not a generator function that has already run
if block_fn.inputs_as_dict:
processed_input = [dict(zip(block_fn.inputs, processed_input))]

processed_input, progress_index, _ = special_args(
block_fn.fn, processed_input, request, event_data
)
progress_tracker = (
processed_input[progress_index] if progress_index is not None else None
)

if progress_tracker is not None and progress_index is not None:
progress_tracker, fn = create_tracker(
self, event_id, fn, progress_tracker.track_tqdm
Expand Down Expand Up @@ -1425,8 +1423,11 @@ async def process_api(
data = list(zip(*data))
is_generating, iterator = None, None
else:
inputs = self.preprocess_data(fn_index, inputs, state)
old_iterator = iterators.get(fn_index, None) if iterators else None
if old_iterator:
inputs = []
else:
inputs = self.preprocess_data(fn_index, inputs, state)
was_generating = old_iterator is not None
result = await self.call_function(
fn_index, inputs, old_iterator, request, event_id, event_data
Expand Down

0 comments on commit a773eaf

Please sign in to comment.