Skip to content

Commit

Permalink
Fixes issue 5781: Enables specifying a caching directory for Examples (
Browse files Browse the repository at this point in the history
…#6803)

* issue 5781 first commit

* second commit

* unnecessary str removed

* backend formatted

* Update gradio/helpers.py

Co-authored-by: Abubakar Abid <abubakar@huggingface.co>

* Update guides/02_building-interfaces/03_more-on-examples.md

Co-authored-by: Abubakar Abid <abubakar@huggingface.co>

* tests added

* add changeset

* format

---------

Co-authored-by: Abubakar Abid <abubakar@huggingface.co>
Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
  • Loading branch information
3 people committed Dec 19, 2023
1 parent 50496f9 commit 77c9003
Show file tree
Hide file tree
Showing 8 changed files with 161 additions and 108 deletions.
5 changes: 5 additions & 0 deletions .changeset/spicy-wings-thank.md
@@ -0,0 +1,5 @@
---
"gradio": minor
---

feat:Fixes issue 5781: Enables specifying a caching directory for Examples
3 changes: 1 addition & 2 deletions gradio/helpers.py
Expand Up @@ -33,7 +33,6 @@
if TYPE_CHECKING: # Only import for type checking (to avoid circular imports).
from gradio.components import Component

CACHED_FOLDER = "gradio_cached_examples"
LOG_FILE = "log.csv"

set_documentation_group("helpers")
Expand Down Expand Up @@ -248,7 +247,7 @@ def __init__(
elem_id=elem_id,
)

self.cached_folder = Path(CACHED_FOLDER) / str(self.dataset._id)
self.cached_folder = utils.get_cache_folder() / str(self.dataset._id)
self.cached_file = Path(self.cached_folder) / "log.csv"
self.cache_examples = cache_examples
self.run_on_click = run_on_click
Expand Down
3 changes: 1 addition & 2 deletions gradio/routes.py
Expand Up @@ -54,7 +54,6 @@
from gradio.context import Context
from gradio.data_classes import ComponentServerBody, PredictBody, ResetBody
from gradio.exceptions import Error
from gradio.helpers import CACHED_FOLDER
from gradio.oauth import attach_oauth
from gradio.queueing import Estimation
from gradio.route_utils import ( # noqa: F401
Expand Down Expand Up @@ -455,7 +454,7 @@ async def file(path_or_url: str, request: fastapi.Request):
)
was_uploaded = utils.is_in_or_equal(abs_path, app.uploaded_file_dir)
is_cached_example = utils.is_in_or_equal(
abs_path, utils.abspath(CACHED_FOLDER)
abs_path, utils.abspath(utils.get_cache_folder())
)

if not (
Expand Down
4 changes: 4 additions & 0 deletions gradio/utils.py
Expand Up @@ -1016,3 +1016,7 @@ def __setitem__(self, key: K, value: V) -> None:
elif len(self) >= self.max_size:
self.popitem(last=False)
super().__setitem__(key, value)


def get_cache_folder() -> Path:
return Path(os.environ.get("GRADIO_EXAMPLES_CACHE", "gradio_cached_examples"))
2 changes: 1 addition & 1 deletion guides/02_building-interfaces/03_more-on-examples.md
Expand Up @@ -34,7 +34,7 @@ Sometimes your app has many input components, but you would only like to provide
## Caching examples

You may wish to provide some cached examples of your model for users to quickly try out, in case your model takes a while to run normally.
If `cache_examples=True`, the `Interface` will run all of your examples through your app and save the outputs when you call the `launch()` method. This data will be saved in a directory called `gradio_cached_examples`.
If `cache_examples=True`, the `Interface` will run all of your examples through your app and save the outputs when you call the `launch()` method. This data will be saved in a directory called `gradio_cached_examples` in your working directory by default. You can also set this directory with the `GRADIO_EXAMPLES_CACHE` environment variable, which can be either an absolute path or a relative path to your working directory.

Whenever a user clicks on an example, the output will automatically be populated in the app now, using data from this cached directory instead of actually running the function. This is useful so users can quickly try out your model without adding any load!

Expand Down
131 changes: 72 additions & 59 deletions test/test_chat_interface.py
@@ -1,10 +1,11 @@
import tempfile
from concurrent.futures import wait
from pathlib import Path
from unittest.mock import patch

import pytest

import gradio as gr
from gradio import helpers


def invalid_fn(message):
Expand Down Expand Up @@ -79,44 +80,52 @@ def test_events_attached(self):
)

def test_example_caching(self, monkeypatch):
monkeypatch.setattr(helpers, "CACHED_FOLDER", tempfile.mkdtemp())
chatbot = gr.ChatInterface(
double, examples=["hello", "hi"], cache_examples=True
)
prediction_hello = chatbot.examples_handler.load_from_cache(0)
prediction_hi = chatbot.examples_handler.load_from_cache(1)
assert prediction_hello[0].root[0] == ("hello", "hello hello")
assert prediction_hi[0].root[0] == ("hi", "hi hi")
with patch(
"gradio.utils.get_cache_folder", return_value=Path(tempfile.mkdtemp())
):
chatbot = gr.ChatInterface(
double, examples=["hello", "hi"], cache_examples=True
)
prediction_hello = chatbot.examples_handler.load_from_cache(0)
prediction_hi = chatbot.examples_handler.load_from_cache(1)
assert prediction_hello[0].root[0] == ("hello", "hello hello")
assert prediction_hi[0].root[0] == ("hi", "hi hi")

def test_example_caching_async(self, monkeypatch):
monkeypatch.setattr(helpers, "CACHED_FOLDER", tempfile.mkdtemp())
chatbot = gr.ChatInterface(
async_greet, examples=["abubakar", "tom"], cache_examples=True
)
prediction_hello = chatbot.examples_handler.load_from_cache(0)
prediction_hi = chatbot.examples_handler.load_from_cache(1)
assert prediction_hello[0].root[0] == ("abubakar", "hi, abubakar")
assert prediction_hi[0].root[0] == ("tom", "hi, tom")
with patch(
"gradio.utils.get_cache_folder", return_value=Path(tempfile.mkdtemp())
):
chatbot = gr.ChatInterface(
async_greet, examples=["abubakar", "tom"], cache_examples=True
)
prediction_hello = chatbot.examples_handler.load_from_cache(0)
prediction_hi = chatbot.examples_handler.load_from_cache(1)
assert prediction_hello[0].root[0] == ("abubakar", "hi, abubakar")
assert prediction_hi[0].root[0] == ("tom", "hi, tom")

def test_example_caching_with_streaming(self, monkeypatch):
monkeypatch.setattr(helpers, "CACHED_FOLDER", tempfile.mkdtemp())
chatbot = gr.ChatInterface(
stream, examples=["hello", "hi"], cache_examples=True
)
prediction_hello = chatbot.examples_handler.load_from_cache(0)
prediction_hi = chatbot.examples_handler.load_from_cache(1)
assert prediction_hello[0].root[0] == ("hello", "hello")
assert prediction_hi[0].root[0] == ("hi", "hi")
with patch(
"gradio.utils.get_cache_folder", return_value=Path(tempfile.mkdtemp())
):
chatbot = gr.ChatInterface(
stream, examples=["hello", "hi"], cache_examples=True
)
prediction_hello = chatbot.examples_handler.load_from_cache(0)
prediction_hi = chatbot.examples_handler.load_from_cache(1)
assert prediction_hello[0].root[0] == ("hello", "hello")
assert prediction_hi[0].root[0] == ("hi", "hi")

def test_example_caching_with_streaming_async(self, monkeypatch):
monkeypatch.setattr(helpers, "CACHED_FOLDER", tempfile.mkdtemp())
chatbot = gr.ChatInterface(
async_stream, examples=["hello", "hi"], cache_examples=True
)
prediction_hello = chatbot.examples_handler.load_from_cache(0)
prediction_hi = chatbot.examples_handler.load_from_cache(1)
assert prediction_hello[0].root[0] == ("hello", "hello")
assert prediction_hi[0].root[0] == ("hi", "hi")
with patch(
"gradio.utils.get_cache_folder", return_value=Path(tempfile.mkdtemp())
):
chatbot = gr.ChatInterface(
async_stream, examples=["hello", "hi"], cache_examples=True
)
prediction_hello = chatbot.examples_handler.load_from_cache(0)
prediction_hi = chatbot.examples_handler.load_from_cache(1)
assert prediction_hello[0].root[0] == ("hello", "hello")
assert prediction_hi[0].root[0] == ("hi", "hi")

def test_default_accordion_params(self):
chatbot = gr.ChatInterface(
Expand Down Expand Up @@ -146,34 +155,38 @@ def test_setting_accordion_params(self, monkeypatch):
assert accordion.get_config().get("label") == "MOAR"

def test_example_caching_with_additional_inputs(self, monkeypatch):
monkeypatch.setattr(helpers, "CACHED_FOLDER", tempfile.mkdtemp())
chatbot = gr.ChatInterface(
echo_system_prompt_plus_message,
additional_inputs=["textbox", "slider"],
examples=[["hello", "robot", 100], ["hi", "robot", 2]],
cache_examples=True,
)
prediction_hello = chatbot.examples_handler.load_from_cache(0)
prediction_hi = chatbot.examples_handler.load_from_cache(1)
assert prediction_hello[0].root[0] == ("hello", "robot hello")
assert prediction_hi[0].root[0] == ("hi", "ro")
with patch(
"gradio.utils.get_cache_folder", return_value=Path(tempfile.mkdtemp())
):
chatbot = gr.ChatInterface(
echo_system_prompt_plus_message,
additional_inputs=["textbox", "slider"],
examples=[["hello", "robot", 100], ["hi", "robot", 2]],
cache_examples=True,
)
prediction_hello = chatbot.examples_handler.load_from_cache(0)
prediction_hi = chatbot.examples_handler.load_from_cache(1)
assert prediction_hello[0].root[0] == ("hello", "robot hello")
assert prediction_hi[0].root[0] == ("hi", "ro")

def test_example_caching_with_additional_inputs_already_rendered(self, monkeypatch):
monkeypatch.setattr(helpers, "CACHED_FOLDER", tempfile.mkdtemp())
with gr.Blocks():
with gr.Accordion("Inputs"):
text = gr.Textbox()
slider = gr.Slider()
chatbot = gr.ChatInterface(
echo_system_prompt_plus_message,
additional_inputs=[text, slider],
examples=[["hello", "robot", 100], ["hi", "robot", 2]],
cache_examples=True,
)
prediction_hello = chatbot.examples_handler.load_from_cache(0)
prediction_hi = chatbot.examples_handler.load_from_cache(1)
assert prediction_hello[0].root[0] == ("hello", "robot hello")
assert prediction_hi[0].root[0] == ("hi", "ro")
with patch(
"gradio.utils.get_cache_folder", return_value=Path(tempfile.mkdtemp())
):
with gr.Blocks():
with gr.Accordion("Inputs"):
text = gr.Textbox()
slider = gr.Slider()
chatbot = gr.ChatInterface(
echo_system_prompt_plus_message,
additional_inputs=[text, slider],
examples=[["hello", "robot", 100], ["hi", "robot", 2]],
cache_examples=True,
)
prediction_hello = chatbot.examples_handler.load_from_cache(0)
prediction_hi = chatbot.examples_handler.load_from_cache(1)
assert prediction_hello[0].root[0] == ("hello", "robot hello")
assert prediction_hi[0].root[0] == ("hi", "ro")


class TestAPI:
Expand Down
7 changes: 5 additions & 2 deletions test/test_external.py
@@ -1,4 +1,5 @@
import os
import tempfile
import textwrap
import warnings
from pathlib import Path
Expand Down Expand Up @@ -356,7 +357,7 @@ def test_private_space_v4_sse_v1(self):
class TestLoadInterfaceWithExamples:
def test_interface_load_examples(self, tmp_path):
test_file_dir = Path(Path(__file__).parent, "test_files")
with patch("gradio.helpers.CACHED_FOLDER", tmp_path):
with patch("gradio.utils.get_cache_folder", return_value=tmp_path):
gr.load(
name="models/google/vit-base-patch16-224",
examples=[Path(test_file_dir, "cheetah1.jpg")],
Expand All @@ -365,7 +366,9 @@ def test_interface_load_examples(self, tmp_path):

def test_interface_load_cache_examples(self, tmp_path):
test_file_dir = Path(Path(__file__).parent, "test_files")
with patch("gradio.helpers.CACHED_FOLDER", tmp_path):
with patch(
"gradio.utils.get_cache_folder", return_value=Path(tempfile.mkdtemp())
):
gr.load(
name="models/google/vit-base-patch16-224",
examples=[Path(test_file_dir, "cheetah1.jpg")],
Expand Down

0 comments on commit 77c9003

Please sign in to comment.