Fix bug in example cache loading event #5636

Merged
6 commits merged on Sep 20, 2023
Changes from 4 commits
5 changes: 5 additions & 0 deletions .changeset/cold-lights-trade.md
@@ -0,0 +1,5 @@
---
"gradio": patch
---

fix:Fix bug in example cache loading event
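For context, here is a minimal sketch of the scenario this patch targets, adapted from the tests added in this PR (component names and launch flags are illustrative, not part of the fix): with `cache_examples=True`, the click event that loads a cached example must be re-registered on every launch, including a launch that reuses a cache written by an earlier run.

```python
import gradio as gr

def combine(a, b):
    return a + " " + b

with gr.Blocks() as demo:
    txt = gr.Textbox(label="Input")
    txt_2 = gr.Textbox(label="Input 2")
    out = gr.Textbox(label="Output")
    gr.Examples(
        [["hi", "Adam"], ["hello", "Eve"]],
        [txt, txt_2],
        out,
        combine,
        cache_examples=True,  # predictions are cached to disk
    )

# The first launch writes the cache; the second reuses it and, after this fix,
# still wires up the event that populates the inputs and outputs when an
# example row is clicked.
demo.launch(prevent_thread_lock=True)
demo.close()
demo.launch(prevent_thread_lock=True)
demo.close()
```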
47 changes: 22 additions & 25 deletions gradio/helpers.py
@@ -280,14 +280,13 @@ async def cache(self) -> None:
"""
Caches all of the examples so that their predictions can be shown immediately.
"""
if Context.root_block is None:
raise ValueError("Cannot cache examples if not in a Blocks context")
if Path(self.cached_file).exists():
print(
f"Using cache from '{utils.abspath(self.cached_folder)}' directory. If method or examples have changed since last caching, delete this folder to clear cache.\n"
)
else:
if Context.root_block is None:
raise ValueError("Cannot cache examples if not in a Blocks context")

print(f"Caching examples at: '{utils.abspath(self.cached_folder)}'")
cache_logger = CSVLogger()

@@ -352,30 +351,28 @@ async def get_final_item(*args):
Context.root_block.dependencies.remove(dependency)
Context.root_block.fns.pop(fn_index)

# Remove the original load_input_event and replace it with one that
# also populates the input. We do it this way to allow the cache()
# method to be called independently of the create() method
index = Context.root_block.dependencies.index(self.load_input_event)
Context.root_block.dependencies.pop(index)
Context.root_block.fns.pop(index)

async def load_example(example_id):
processed_example = self.non_none_processed_examples[
example_id
] + await self.load_from_cache(example_id)
return utils.resolve_singleton(processed_example)
# Remove the original load_input_event and replace it with one that
# also populates the input. We do it this way to allow the cache()
# method to be called independently of the create() method
index = Context.root_block.dependencies.index(self.load_input_event)
Context.root_block.dependencies.pop(index)
Context.root_block.fns.pop(index)

self.load_input_event = self.dataset.click(
load_example,
inputs=[self.dataset],
outputs=self.inputs_with_examples + self.outputs, # type: ignore
show_progress="hidden",
postprocess=False,
queue=False,
api_name=self.api_name, # type: ignore
)
async def load_example(example_id):
processed_example = self.non_none_processed_examples[
example_id
] + await self.load_from_cache(example_id)
return utils.resolve_singleton(processed_example)

print("Caching complete\n")
self.load_input_event = self.dataset.click(
load_example,
inputs=[self.dataset],
outputs=self.inputs_with_examples + self.outputs, # type: ignore
show_progress="hidden",
postprocess=False,
queue=False,
api_name=self.api_name, # type: ignore
)

async def load_from_cache(self, example_id: int) -> list[Any]:
"""Loads a particular cached example for the interface.
4 changes: 2 additions & 2 deletions test/conftest.py
@@ -39,8 +39,8 @@ def io_components():
@pytest.fixture
def connect():
@contextmanager
def _connect(demo: gr.Blocks, serialize=True):
_, local_url, _ = demo.launch(prevent_thread_lock=True)
def _connect(demo: gr.Blocks, serialize=True, server_port=None):
_, local_url, _ = demo.launch(prevent_thread_lock=True, server_port=server_port)
try:
yield Client(local_url, serialize=serialize)
finally:
70 changes: 70 additions & 0 deletions test/test_helpers.py
@@ -118,6 +118,43 @@ def test_some_headers(self):
assert examples.dataset.headers == ["im", ""]


def test_example_caching_relaunch(connect):
def combine(a, b):
return a + " " + b

with gr.Blocks() as demo:
txt = gr.Textbox(label="Input")
txt_2 = gr.Textbox(label="Input 2")
txt_3 = gr.Textbox(value="", label="Output")
btn = gr.Button(value="Submit")
btn.click(combine, inputs=[txt, txt_2], outputs=[txt_3])
gr.Examples(
[["hi", "Adam"], ["hello", "Eve"]],
[txt, txt_2],
txt_3,
combine,
cache_examples=True,
api_name="examples",
)

with connect(demo, server_port=7859) as client:
assert client.predict(1, api_name="/examples") == (
"hello",
"Eve",
"hello Eve",
)

# Let the server shut down
time.sleep(1)

with connect(demo, server_port=7859) as client:
Member:

Just to confirm: you launch on the same port to ensure that the server has truly shut down? I don't think you need to do this.

Even if demo.launch() were to launch a FastAPI app on a new port, it would still read from the same cache, since the ids of the components haven't changed.

So you can remove the hardcoding of the server ports (which will also be safer in case this port is occupied).

Collaborator (Author):

The time.sleep is to ensure shutdown. Hardcoding the same port is for paranoia 😂 Let me fix.

assert client.predict(1, api_name="/examples") == (
"hello",
"Eve",
"hello Eve",
)
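
Following up on the port discussion above: the cached predictions live on disk (the tests in the class below patch `gradio.helpers.CACHED_FOLDER` to a temp directory), which is why a relaunch on any port can serve them. A rough sketch of how one might inspect that cache, assuming a per-Examples subfolder layout with a `log.csv` written by `CSVLogger` (the exact layout may differ across gradio versions):

```python
from pathlib import Path

import gradio.helpers as helpers

# Root folder that cached example predictions are written under.
cache_root = Path(helpers.CACHED_FOLDER)
for log_file in sorted(cache_root.glob("**/log.csv")):
    lines = log_file.read_text().splitlines()
    print(log_file)
    print(lines[:3])  # header row plus the first cached predictions
```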


@patch("gradio.helpers.CACHED_FOLDER", tempfile.mkdtemp())
class TestProcessExamples:
@pytest.mark.asyncio
@@ -132,6 +169,39 @@ async def test_caching(self):
prediction = await io.examples_handler.load_from_cache(1)
assert prediction[0] == "Hello Dunya"

def test_example_caching_relaunch(self, connect):
def combine(a, b):
return a + " " + b

with gr.Blocks() as demo:
txt = gr.Textbox(label="Input")
txt_2 = gr.Textbox(label="Input 2")
txt_3 = gr.Textbox(value="", label="Output")
btn = gr.Button(value="Submit")
btn.click(combine, inputs=[txt, txt_2], outputs=[txt_3])
gr.Examples(
[["hi", "Adam"], ["hello", "Eve"]],
[txt, txt_2],
txt_3,
combine,
cache_examples=True,
api_name="examples",
)

with connect(demo) as client:
assert client.predict(1, api_name="/examples") == (
"hello",
"Eve",
"hello Eve",
)

with connect(demo) as client:
assert client.predict(1, api_name="/examples") == (
"hello",
"Eve",
"hello Eve",
)

@pytest.mark.asyncio
async def test_caching_image(self):
io = gr.Interface(