Skip to content

Commit

Permalink
Convert async methods in the Examples class into normal sync methods (#…
Browse files Browse the repository at this point in the history
…5822)

* Convert async methods in the Examples class into normal sync methods

* add changeset

* Fix test/test_chat_interface.py

* Fix test/test_helpers.py

* add changeset

---------

Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
Co-authored-by: Abubakar Abid <abubakar@huggingface.co>
  • Loading branch information
3 people committed Oct 6, 2023
1 parent 1aa1862 commit 7b63db2
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 78 deletions.
5 changes: 5 additions & 0 deletions .changeset/evil-berries-teach.md
@@ -0,0 +1,5 @@
---
"gradio": patch
---

fix:Convert async methods in the Examples class into normal sync methods
6 changes: 3 additions & 3 deletions gradio/helpers.py
Expand Up @@ -369,10 +369,10 @@ async def get_final_item(*args):
Context.root_block.dependencies.pop(index)
Context.root_block.fns.pop(index)

async def load_example(example_id):
def load_example(example_id):
processed_example = self.non_none_processed_examples[
example_id
] + await self.load_from_cache(example_id)
] + self.load_from_cache(example_id)
return utils.resolve_singleton(processed_example)

self.load_input_event = self.dataset.click(
Expand All @@ -385,7 +385,7 @@ async def load_example(example_id):
api_name=self.api_name, # type: ignore
)

async def load_from_cache(self, example_id: int) -> list[Any]:
def load_from_cache(self, example_id: int) -> list[Any]:
"""Loads a particular cached example for the interface.
Parameters:
example_id: The id of the example to process (zero-indexed).
Expand Down
44 changes: 18 additions & 26 deletions test/test_chat_interface.py
Expand Up @@ -72,68 +72,60 @@ def test_events_attached(self):
None,
)

@pytest.mark.asyncio
async def test_example_caching(self, monkeypatch):
def test_example_caching(self, monkeypatch):
monkeypatch.setattr(helpers, "CACHED_FOLDER", tempfile.mkdtemp())
chatbot = gr.ChatInterface(
double, examples=["hello", "hi"], cache_examples=True
)
prediction_hello = await chatbot.examples_handler.load_from_cache(0)
prediction_hi = await chatbot.examples_handler.load_from_cache(1)
prediction_hello = chatbot.examples_handler.load_from_cache(0)
prediction_hi = chatbot.examples_handler.load_from_cache(1)
assert prediction_hello[0][0] == ["hello", "hello hello"]
assert prediction_hi[0][0] == ["hi", "hi hi"]

@pytest.mark.asyncio
async def test_example_caching_async(self, monkeypatch):
def test_example_caching_async(self, monkeypatch):
monkeypatch.setattr(helpers, "CACHED_FOLDER", tempfile.mkdtemp())
chatbot = gr.ChatInterface(
async_greet, examples=["abubakar", "tom"], cache_examples=True
)
prediction_hello = await chatbot.examples_handler.load_from_cache(0)
prediction_hi = await chatbot.examples_handler.load_from_cache(1)
prediction_hello = chatbot.examples_handler.load_from_cache(0)
prediction_hi = chatbot.examples_handler.load_from_cache(1)
assert prediction_hello[0][0] == ["abubakar", "hi, abubakar"]
assert prediction_hi[0][0] == ["tom", "hi, tom"]

@pytest.mark.asyncio
async def test_example_caching_with_streaming(self, monkeypatch):
def test_example_caching_with_streaming(self, monkeypatch):
monkeypatch.setattr(helpers, "CACHED_FOLDER", tempfile.mkdtemp())
chatbot = gr.ChatInterface(
stream, examples=["hello", "hi"], cache_examples=True
)
prediction_hello = await chatbot.examples_handler.load_from_cache(0)
prediction_hi = await chatbot.examples_handler.load_from_cache(1)
prediction_hello = chatbot.examples_handler.load_from_cache(0)
prediction_hi = chatbot.examples_handler.load_from_cache(1)
assert prediction_hello[0][0] == ["hello", "hello"]
assert prediction_hi[0][0] == ["hi", "hi"]

@pytest.mark.asyncio
async def test_example_caching_with_streaming_async(self, monkeypatch):
def test_example_caching_with_streaming_async(self, monkeypatch):
monkeypatch.setattr(helpers, "CACHED_FOLDER", tempfile.mkdtemp())
chatbot = gr.ChatInterface(
async_stream, examples=["hello", "hi"], cache_examples=True
)
prediction_hello = await chatbot.examples_handler.load_from_cache(0)
prediction_hi = await chatbot.examples_handler.load_from_cache(1)
prediction_hello = chatbot.examples_handler.load_from_cache(0)
prediction_hi = chatbot.examples_handler.load_from_cache(1)
assert prediction_hello[0][0] == ["hello", "hello"]
assert prediction_hi[0][0] == ["hi", "hi"]

@pytest.mark.asyncio
async def test_example_caching_with_additional_inputs(self, monkeypatch):
def test_example_caching_with_additional_inputs(self, monkeypatch):
monkeypatch.setattr(helpers, "CACHED_FOLDER", tempfile.mkdtemp())
chatbot = gr.ChatInterface(
echo_system_prompt_plus_message,
additional_inputs=["textbox", "slider"],
examples=[["hello", "robot", 100], ["hi", "robot", 2]],
cache_examples=True,
)
prediction_hello = await chatbot.examples_handler.load_from_cache(0)
prediction_hi = await chatbot.examples_handler.load_from_cache(1)
prediction_hello = chatbot.examples_handler.load_from_cache(0)
prediction_hi = chatbot.examples_handler.load_from_cache(1)
assert prediction_hello[0][0] == ["hello", "robot hello"]
assert prediction_hi[0][0] == ["hi", "ro"]

@pytest.mark.asyncio
async def test_example_caching_with_additional_inputs_already_rendered(
self, monkeypatch
):
def test_example_caching_with_additional_inputs_already_rendered(self, monkeypatch):
monkeypatch.setattr(helpers, "CACHED_FOLDER", tempfile.mkdtemp())
with gr.Blocks():
with gr.Accordion("Inputs"):
Expand All @@ -145,8 +137,8 @@ async def test_example_caching_with_additional_inputs_already_rendered(
examples=[["hello", "robot", 100], ["hi", "robot", 2]],
cache_examples=True,
)
prediction_hello = await chatbot.examples_handler.load_from_cache(0)
prediction_hi = await chatbot.examples_handler.load_from_cache(1)
prediction_hello = chatbot.examples_handler.load_from_cache(0)
prediction_hi = chatbot.examples_handler.load_from_cache(1)
assert prediction_hello[0][0] == ["hello", "robot hello"]
assert prediction_hi[0][0] == ["hi", "ro"]

Expand Down
82 changes: 33 additions & 49 deletions test/test_helpers.py
Expand Up @@ -58,8 +58,7 @@ def test_examples_per_page(self):
examples = gr.Examples(["hello", "hi"], gr.Textbox(), examples_per_page=2)
assert examples.dataset.get_config()["samples_per_page"] == 2

@pytest.mark.asyncio
async def test_no_preprocessing(self):
def test_no_preprocessing(self):
with gr.Blocks():
image = gr.Image()
textbox = gr.Textbox()
Expand All @@ -73,11 +72,10 @@ async def test_no_preprocessing(self):
preprocess=False,
)

prediction = await examples.load_from_cache(0)
prediction = examples.load_from_cache(0)
assert prediction == [media_data.BASE64_IMAGE]

@pytest.mark.asyncio
async def test_no_postprocessing(self):
def test_no_postprocessing(self):
def im(x):
return [media_data.BASE64_IMAGE]

Expand All @@ -94,7 +92,7 @@ def im(x):
postprocess=False,
)

prediction = await examples.load_from_cache(0)
prediction = examples.load_from_cache(0)
file = prediction[0][0][0]["name"]
assert utils.encode_url_or_file_to_base64(file) == media_data.BASE64_IMAGE

Expand Down Expand Up @@ -158,16 +156,15 @@ def combine(a, b):

@patch("gradio.helpers.CACHED_FOLDER", tempfile.mkdtemp())
class TestProcessExamples:
@pytest.mark.asyncio
async def test_caching(self):
def test_caching(self):
io = gr.Interface(
lambda x: f"Hello {x}",
"text",
"text",
examples=[["World"], ["Dunya"], ["Monde"]],
cache_examples=True,
)
prediction = await io.examples_handler.load_from_cache(1)
prediction = io.examples_handler.load_from_cache(1)
assert prediction[0] == "Hello Dunya"

def test_example_caching_relaunch(self, connect):
Expand Down Expand Up @@ -203,66 +200,61 @@ def combine(a, b):
"hello Eve",
)

@pytest.mark.asyncio
async def test_caching_image(self):
def test_caching_image(self):
io = gr.Interface(
lambda x: x,
"image",
"image",
examples=[["test/test_files/bus.png"]],
cache_examples=True,
)
prediction = await io.examples_handler.load_from_cache(0)
prediction = io.examples_handler.load_from_cache(0)
assert prediction[0].startswith("")

@pytest.mark.asyncio
async def test_caching_audio(self):
def test_caching_audio(self):
io = gr.Interface(
lambda x: x,
"audio",
"audio",
examples=[["test/test_files/audio_sample.wav"]],
cache_examples=True,
)
prediction = await io.examples_handler.load_from_cache(0)
prediction = io.examples_handler.load_from_cache(0)
file = prediction[0]["name"]
assert utils.encode_url_or_file_to_base64(file).startswith(
"data:audio/wav;base64,UklGRgA/"
)

@pytest.mark.asyncio
async def test_caching_with_update(self):
def test_caching_with_update(self):
io = gr.Interface(
lambda x: gr.update(visible=False),
"text",
"image",
examples=[["World"], ["Dunya"], ["Monde"]],
cache_examples=True,
)
prediction = await io.examples_handler.load_from_cache(1)
prediction = io.examples_handler.load_from_cache(1)
assert prediction[0] == {
"visible": False,
"__type__": "update",
}

@pytest.mark.asyncio
async def test_caching_with_mix_update(self):
def test_caching_with_mix_update(self):
io = gr.Interface(
lambda x: [gr.update(lines=4, value="hello"), "test/test_files/bus.png"],
"text",
["text", "image"],
examples=[["World"], ["Dunya"], ["Monde"]],
cache_examples=True,
)
prediction = await io.examples_handler.load_from_cache(1)
prediction = io.examples_handler.load_from_cache(1)
assert prediction[0] == {
"lines": 4,
"value": "hello",
"__type__": "update",
}

@pytest.mark.asyncio
async def test_caching_with_dict(self):
def test_caching_with_dict(self):
text = gr.Textbox()
out = gr.Label()

Expand All @@ -273,14 +265,13 @@ async def test_caching_with_dict(self):
examples=["abc"],
cache_examples=True,
)
prediction = await io.examples_handler.load_from_cache(0)
prediction = io.examples_handler.load_from_cache(0)
assert prediction == [
{"lines": 4, "__type__": "update", "mode": "static"},
{"label": "lion"},
]

@pytest.mark.asyncio
async def test_caching_with_generators(self):
def test_caching_with_generators(self):
def test_generator(x):
for y in range(len(x)):
yield "Your output: " + x[: y + 1]
Expand All @@ -292,11 +283,10 @@ def test_generator(x):
examples=["abcdef"],
cache_examples=True,
)
prediction = await io.examples_handler.load_from_cache(0)
prediction = io.examples_handler.load_from_cache(0)
assert prediction[0] == "Your output: abcdef"

@pytest.mark.asyncio
async def test_caching_with_generators_and_streamed_output(self):
def test_caching_with_generators_and_streamed_output(self):
file_dir = Path(Path(__file__).parent, "test_files")
audio = str(file_dir / "audio_sample.wav")

Expand All @@ -311,15 +301,14 @@ def test_generator(x):
examples=[3],
cache_examples=True,
)
prediction = await io.examples_handler.load_from_cache(0)
prediction = io.examples_handler.load_from_cache(0)
len_input_audio = len(AudioSegment.from_wav(audio))
len_output_audio = len(AudioSegment.from_wav(prediction[0]["name"]))
length_ratio = len_output_audio / len_input_audio
assert round(length_ratio, 1) == 3.0 # might not be exactly 3x
assert float(prediction[1]) == 10.0

@pytest.mark.asyncio
async def test_caching_with_async_generators(self):
def test_caching_with_async_generators(self):
async def test_generator(x):
for y in range(len(x)):
yield "Your output: " + x[: y + 1]
Expand All @@ -331,7 +320,7 @@ async def test_generator(x):
examples=["abcdef"],
cache_examples=True,
)
prediction = await io.examples_handler.load_from_cache(0)
prediction = io.examples_handler.load_from_cache(0)
assert prediction[0] == "Your output: abcdef"

def test_raise_helpful_error_message_if_providing_partial_examples(self, tmp_path):
Expand Down Expand Up @@ -391,8 +380,7 @@ def many_missing(a, b, c):
cache_examples=True,
)

@pytest.mark.asyncio
async def test_caching_with_batch(self):
def test_caching_with_batch(self):
def trim_words(words, lens):
trimmed_words = [word[:length] for word, length in zip(words, lens)]
return [trimmed_words]
Expand All @@ -406,11 +394,10 @@ def trim_words(words, lens):
examples=[["hello", 3], ["hi", 4]],
cache_examples=True,
)
prediction = await io.examples_handler.load_from_cache(0)
prediction = io.examples_handler.load_from_cache(0)
assert prediction == ["hel"]

@pytest.mark.asyncio
async def test_caching_with_batch_multiple_outputs(self):
def test_caching_with_batch_multiple_outputs(self):
def trim_words(words, lens):
trimmed_words = [word[:length] for word, length in zip(words, lens)]
return trimmed_words, lens
Expand All @@ -424,11 +411,10 @@ def trim_words(words, lens):
examples=[["hello", 3], ["hi", 4]],
cache_examples=True,
)
prediction = await io.examples_handler.load_from_cache(0)
prediction = io.examples_handler.load_from_cache(0)
assert prediction == ["hel", "3"]

@pytest.mark.asyncio
async def test_caching_with_non_io_component(self):
def test_caching_with_non_io_component(self):
def predict(name):
return name, gr.update(visible=True)

Expand All @@ -445,7 +431,7 @@ def predict(name):
cache_examples=True,
)

prediction = await examples.load_from_cache(0)
prediction = examples.load_from_cache(0)
assert prediction == ["John", {"visible": True, "__type__": "update"}]

def test_end_to_end(self):
Expand Down Expand Up @@ -500,8 +486,7 @@ def concatenate(str1, str2):
assert response.json()["data"] == ["Michael", "Jordan", "Michael Jordan"]


@pytest.mark.asyncio
async def test_multiple_file_flagging(tmp_path):
def test_multiple_file_flagging(tmp_path):
with patch("gradio.helpers.CACHED_FOLDER", str(tmp_path)):
io = gr.Interface(
fn=lambda *x: list(x),
Expand All @@ -513,14 +498,13 @@ async def test_multiple_file_flagging(tmp_path):
examples=[["test/test_files/cheetah1.jpg", "test/test_files/bus.png"]],
cache_examples=True,
)
prediction = await io.examples_handler.load_from_cache(0)
prediction = io.examples_handler.load_from_cache(0)

assert len(prediction[0]) == 2
assert all(isinstance(d, dict) for d in prediction[0])


@pytest.mark.asyncio
async def test_examples_keep_all_suffixes(tmp_path):
def test_examples_keep_all_suffixes(tmp_path):
with patch("gradio.helpers.CACHED_FOLDER", str(tmp_path)):
file_1 = tmp_path / "foo.bar.txt"
file_1.write_text("file 1")
Expand All @@ -535,10 +519,10 @@ async def test_examples_keep_all_suffixes(tmp_path):
examples=[[str(file_1)], [str(file_2)]],
cache_examples=True,
)
prediction = await io.examples_handler.load_from_cache(0)
prediction = io.examples_handler.load_from_cache(0)
assert Path(prediction[0]["name"]).read_text() == "file 1"
assert prediction[0]["orig_name"] == "foo.bar.txt"
prediction = await io.examples_handler.load_from_cache(1)
prediction = io.examples_handler.load_from_cache(1)
assert Path(prediction[0]["name"]).read_text() == "file 2"
assert prediction[0]["orig_name"] == "foo.bar.txt"

Expand Down

0 comments on commit 7b63db2

Please sign in to comment.