Skip to content

Commit

Permalink
fix get_continuous_fn bug when having every (#4434)
Browse files Browse the repository at this point in the history
* fix the bug for wrap continuous func with parameter every while origin func return generator

* Update utils.py

* Update CHANGELOG.md

* Update test_utils.py

* Update CHANGELOG.md

* formatting

---------

Co-authored-by: Abubakar Abid <abubakar@huggingface.co>
  • Loading branch information
dkjshk and abidlabs committed Jun 6, 2023
1 parent 9c45ace commit 92a70dd
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 1 deletion.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ No changes to highlight.

## Bug Fixes:

- Fix bug for get_continuous_fn by [@dkjshk](https://github.com/dkjshk) in [PR 4434](https://github.com/gradio-app/gradio/pull/4434)
- Fix z-index of status component by [@hannahblair](https://github.com/hannahblair) in [PR 4429](https://github.com/gradio-app/gradio/pull/4429)
- Allow gradio to work offline, by [@aliabid94](https://github.com/aliabid94) in [PR 4398](https://github.com/gradio-app/gradio/pull/4398).
- Fixed `validate_url` to check for 403 errors and use a GET request in place of a HEAD by [@alvindaiyan](https://github.com/alvindaiyan) in [PR 4388](https://github.com/gradio-app/gradio/pull/4388).
Expand Down
6 changes: 5 additions & 1 deletion gradio/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from io import BytesIO
from numbers import Number
from pathlib import Path
from types import GeneratorType
from typing import (
TYPE_CHECKING,
Any,
Expand Down Expand Up @@ -636,7 +637,10 @@ def get_continuous_fn(fn: Callable, every: float) -> Callable:
def continuous_fn(*args):
while True:
output = fn(*args)
yield output
if isinstance(output, GeneratorType):
yield from output
else:
yield output
time.sleep(every)

return continuous_fn
Expand Down
32 changes: 32 additions & 0 deletions test/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
colab_check,
delete_none,
format_ner_list,
get_continuous_fn,
get_type_hints,
ipython_check,
is_in_or_equal,
Expand Down Expand Up @@ -613,6 +614,37 @@ def f(s: str, evt: EventData):
check_function_inputs_match(x, [None], False)


class TestGetContinuousFn:
def test_get_continuous_fn(self):
def int_return(x): # for origin condition
return x + 1

def int_yield(x): # new condition
for _i in range(2):
yield x
x += 1

def list_yield(x): # new condition
for _i in range(2):
yield x
x += [1]

gen_int_return = get_continuous_fn(fn=int_return, every=0.01)
gen_int_yield = get_continuous_fn(fn=int_yield, every=0.01)
gen_list_yield = get_continuous_fn(fn=list_yield, every=0.01)
gener_int_return = gen_int_return(1)
gener_int = gen_int_yield(1) # Primitive
gener_list = gen_list_yield([1]) # Reference
assert next(gener_int_return) == 2
assert next(gener_int_return) == 2
assert next(gener_int) == 1
assert next(gener_int) == 2
assert next(gener_int) == 1
assert [1] == next(gener_list)
assert [1, 1] == next(gener_list)
assert [1, 1, 1] == next(gener_list)


def test_tex2svg_preserves_matplotlib_backend():
import matplotlib

Expand Down

0 comments on commit 92a70dd

Please sign in to comment.