Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow caching generators and async generators #4927

Merged
merged 8 commits on Jul 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Expand Up @@ -2,6 +2,7 @@

## New Features:
- Chatbot messages now show hyperlinks to download files uploaded to `gr.Chatbot()` by [@dawoodkhan82](https://github.com/dawoodkhan82) in [PR 4848](https://github.com/gradio-app/gradio/pull/4848)
- Cached examples now work with generators and async generators by [@abidlabs](https://github.com/abidlabs) in [PR 4927](https://github.com/gradio-app/gradio/pull/4927)

## Bug Fixes:

Expand Down
29 changes: 26 additions & 3 deletions gradio/helpers.py
Expand Up @@ -110,7 +110,7 @@ def __init__(
inputs: the component or list of components corresponding to the examples
outputs: optionally, provide the component or list of components corresponding to the output of the examples. Required if `cache` is True.
fn: optionally, provide the function to run to generate the outputs corresponding to the examples. Required if `cache` is True.
cache_examples: if True, caches examples for fast runtime. If True, then `fn` and `outputs` need to be provided
cache_examples: if True, caches examples for fast runtime. If True, then `fn` and `outputs` must be provided. If `fn` is a generator function, then the last yielded value will be used as the output.
examples_per_page: how many examples to show per page.
label: the label to use for the examples component (by default, "Examples")
elem_id: an optional string that is assigned as the id of this component in the HTML DOM.
Expand Down Expand Up @@ -289,7 +289,7 @@ async def cache(self) -> None:
"""
if Path(self.cached_file).exists():
print(
f"Using cache from '{utils.abspath(self.cached_folder)}' directory. If method or examples have changed since last caching, delete this folder to clear cache."
f"Using cache from '{utils.abspath(self.cached_folder)}' directory. If method or examples have changed since last caching, delete this folder to clear cache.\n"
)
else:
if Context.root_block is None:
Expand All @@ -298,10 +298,31 @@ async def cache(self) -> None:
print(f"Caching examples at: '{utils.abspath(self.cached_folder)}'")
cache_logger = CSVLogger()

if inspect.isgeneratorfunction(self.fn):

def get_final_item(args): # type: ignore
x = None
for x in self.fn(args): # noqa: B007 # type: ignore
pass
return x

fn = get_final_item
elif inspect.isasyncgenfunction(self.fn):

async def get_final_item(args):
x = None
async for x in self.fn(args): # noqa: B007 # type: ignore
pass
return x

fn = get_final_item
else:
fn = self.fn

# create a fake dependency to process the examples and get the predictions
dependency, fn_index = Context.root_block.set_event_trigger(
event_name="fake_event",
fn=self.fn,
fn=fn,
inputs=self.inputs_with_examples, # type: ignore
outputs=self.outputs, # type: ignore
preprocess=self.preprocess and not self._api_mode,
Expand All @@ -312,6 +333,7 @@ async def cache(self) -> None:
assert self.outputs is not None
cache_logger.setup(self.outputs, self.cached_folder)
for example_id, _ in enumerate(self.examples):
print(f"Caching example {example_id + 1}/{len(self.examples)}")
processed_input = self.processed_examples[example_id]
if self.batch:
processed_input = [[value] for value in processed_input]
Expand All @@ -329,6 +351,7 @@ async def cache(self) -> None:
# Remove the "fake_event" to prevent bugs in loading interfaces from spaces
Context.root_block.dependencies.remove(dependency)
Context.root_block.fns.pop(fn_index)
print("Caching complete\n")

async def load_from_cache(self, example_id: int) -> list[Any]:
"""Loads a particular cached example for the interface.
Expand Down
2 changes: 1 addition & 1 deletion gradio/interface.py
Expand Up @@ -154,7 +154,7 @@ def __init__(
inputs: a single Gradio component, or list of Gradio components. Components can either be passed as instantiated objects, or referred to by their string shortcuts. The number of input components should match the number of parameters in fn. If set to None, then only the output components will be displayed.
outputs: a single Gradio component, or list of Gradio components. Components can either be passed as instantiated objects, or referred to by their string shortcuts. The number of output components should match the number of values returned by fn. If set to None, then only the input components will be displayed.
examples: sample inputs for the function; if provided, appear below the UI components and can be clicked to populate the interface. Should be nested list, in which the outer list consists of samples and each inner list consists of an input corresponding to each input component. A string path to a directory of examples can also be provided, but it should be within the directory with the python file running the gradio app. If there are multiple input components and a directory is provided, a log.csv file must be present in the directory to link corresponding inputs.
cache_examples: If True, caches examples in the server for fast runtime in examples. The default option in HuggingFace Spaces is True. The default option elsewhere is False.
cache_examples: If True, caches examples in the server for fast runtime in examples. If `fn` is a generator function, then the last yielded value will be used as the output. The default option in HuggingFace Spaces is True. The default option elsewhere is False.
examples_per_page: If examples are provided, how many to display per page.
live: whether the interface should automatically rerun if any of the inputs change.
interpretation: function that provides interpretation explaining prediction output. Pass "default" to use simple built-in interpreter, "shap" to use a built-in shapley-based interpreter, or your own custom interpretation function. For more information on the different interpretation methods, see the Advanced Interface Features guide.
Expand Down
32 changes: 32 additions & 0 deletions test/test_helpers.py
Expand Up @@ -204,6 +204,38 @@ async def test_caching_with_dict(self):
{"label": "lion"},
]

@pytest.mark.asyncio
async def test_caching_with_generators(self):
    """Caching a sync generator fn should store its final yielded value."""

    # Streams progressively longer prefixes of the input string;
    # the cache is expected to retain only the last one.
    def stream_prefixes(text):
        for end in range(1, len(text) + 1):
            yield "Your output: " + text[:end]

    demo = gr.Interface(
        stream_prefixes,
        "textbox",
        "textbox",
        examples=["abcdef"],
        cache_examples=True,
    )
    cached = await demo.examples_handler.load_from_cache(0)
    assert cached[0] == "Your output: abcdef"

@pytest.mark.asyncio
async def test_caching_with_async_generators(self):
    """Caching an async generator fn should store its final yielded value."""

    # Async variant: yields progressively longer prefixes of the input;
    # only the last yield should end up in the example cache.
    async def stream_prefixes(text):
        for end in range(1, len(text) + 1):
            yield "Your output: " + text[:end]

    demo = gr.Interface(
        stream_prefixes,
        "textbox",
        "textbox",
        examples=["abcdef"],
        cache_examples=True,
    )
    cached = await demo.examples_handler.load_from_cache(0)
    assert cached[0] == "Your output: abcdef"

def test_raise_helpful_error_message_if_providing_partial_examples(self, tmp_path):
def foo(a, b):
return a + b
Expand Down