Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We鈥檒l occasionally send you account related emails.

Already on GitHub? Sign in to your account

Run before_fn and after_fn for each generator iteration #7029

Merged
merged 9 commits into from Jan 22, 2024
5 changes: 5 additions & 0 deletions .changeset/big-bears-cover.md
@@ -0,0 +1,5 @@
---
"gradio": patch
---

fix:Run before_fn and after_fn for each generator iteration
40 changes: 29 additions & 11 deletions gradio/utils.py
Expand Up @@ -635,12 +635,19 @@ def function_wrapper(

@functools.wraps(f)
async def asyncgen_wrapper(*args, **kwargs):
aliabid94 marked this conversation as resolved.
Show resolved Hide resolved
if before_fn:
before_fn(*before_args)
async for response in f(*args, **kwargs):
iterator = f(*args, **kwargs)
while True:
if before_fn:
before_fn(*before_args)
try:
response = await iterator.__anext__()
except StopAsyncIteration:
if after_fn:
after_fn(*after_args)
break
if after_fn:
aliabid94 marked this conversation as resolved.
Show resolved Hide resolved
after_fn(*after_args)
yield response
if after_fn:
after_fn(*after_args)

return asyncgen_wrapper

Expand All @@ -661,11 +668,19 @@ async def async_wrapper(*args, **kwargs):

@functools.wraps(f)
def gen_wrapper(*args, **kwargs):
if before_fn:
before_fn(*before_args)
yield from f(*args, **kwargs)
if after_fn:
after_fn(*after_args)
iterator = f(*args, **kwargs)
while True:
if before_fn:
before_fn(*before_args)
try:
response = next(iterator)
except StopIteration:
if after_fn:
after_fn(*after_args)
break
if after_fn:
after_fn(*after_args)
yield response

return gen_wrapper

Expand Down Expand Up @@ -705,7 +720,10 @@ def after_fn():
LocalContext.request.set(None)

return function_wrapper(
fn, before_fn=before_fn, before_args=(blocks, event_id), after_fn=after_fn
fn,
before_fn=before_fn,
before_args=(blocks, event_id),
after_fn=after_fn,
)


Expand Down