Skip to content

Commit

Permalink
Name Endpoints if api_name is None (#5782)
Browse files Browse the repository at this point in the history
* Implementation and test

* add changeset

* fix lint

* Fix nits

---------

Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
  • Loading branch information
freddyaboulton and gradio-pr-bot committed Oct 5, 2023
1 parent 7d29896 commit 370a33a
Show file tree
Hide file tree
Showing 5 changed files with 147 additions and 21 deletions.
5 changes: 5 additions & 0 deletions .changeset/tame-chairs-tan.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"gradio": minor
---

feat:Name Endpoints if api_name is None
25 changes: 24 additions & 1 deletion gradio/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from __future__ import annotations

import string
from functools import partial, wraps
from typing import TYPE_CHECKING, Any, Callable, Literal

Expand Down Expand Up @@ -168,7 +169,7 @@ def event_trigger(
fn: the function to call when this event is triggered. Often a machine learning model's prediction function. Each parameter of the function corresponds to one input component, and the function should return a single value or a tuple of values, with each element in the tuple corresponding to one output component.
inputs: List of gradio.components to use as inputs. If the function takes no inputs, this should be an empty list.
outputs: List of gradio.components to use as outputs. If the function returns no outputs, this should be an empty list.
api_name: Defines how the endpoint appears in the API docs. Can be a string, None, or False. If False, the endpoint will not be exposed in the api docs. If set to None, the endpoint will be exposed in the api docs as an unnamed endpoint, although this behavior will be changed in Gradio 4.0. If set to a string, the endpoint will be exposed in the api docs with the given name.
api_name: Defines how the endpoint appears in the API docs. Can be a string, None, or False. If False, the endpoint will not be exposed in the api docs. If set to None, the endpoint will be given the name of the python function fn. If no fn is passed in, it will be given the name 'unnamed'. If set to a string, the endpoint will be exposed in the api docs with the given name.
status_tracker: Deprecated and has no effect.
scroll_to_output: If True, will scroll to output component on completion
show_progress: If True, will show progress animation while pending
Expand Down Expand Up @@ -224,6 +225,28 @@ def inner(*args, **kwargs):
if isinstance(show_progress, bool):
show_progress = "full" if show_progress else "hidden"

if api_name is None:
if fn is not None:
if not hasattr(fn, "__name__"):
if hasattr(fn, "__class__") and hasattr(
fn.__class__, "__name__"
):
name = fn.__class__.__name__
else:
name = "unnamed"
else:
name = fn.__name__
api_name = "".join(
[
s
for s in name
if s not in set(string.punctuation) - {"-", "_"}
]
)
else:
# Don't document _js only events
api_name = False

dep, dep_index = block.set_event_trigger(
_event_name,
fn,
Expand Down
14 changes: 8 additions & 6 deletions gradio/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -629,11 +629,12 @@ def cleanup():
inputs=None,
outputs=[submit_btn, stop_btn],
queue=False,
api_name=False,
).then(
self.fn,
self.input_components,
self.output_components,
api_name=self.api_name if i == 0 else None,
api_name=self.api_name if i == 0 else False,
scroll_to_output=True,
preprocess=not (self.api_mode),
postprocess=not (self.api_mode),
Expand All @@ -647,6 +648,7 @@ def cleanup():
inputs=None,
outputs=extra_output, # type: ignore
queue=False,
api_name=False,
)

stop_btn.click(
Expand All @@ -655,6 +657,7 @@ def cleanup():
outputs=[submit_btn, stop_btn],
cancels=predict_events,
queue=False,
api_name=False,
)
else:
for i, trigger in enumerate(triggers):
Expand All @@ -663,7 +666,7 @@ def cleanup():
fn,
self.input_components,
self.output_components,
api_name=self.api_name if i == 0 else None,
api_name=self.api_name if i == 0 else False,
scroll_to_output=True,
preprocess=not (self.api_mode),
postprocess=not (self.api_mode),
Expand Down Expand Up @@ -740,19 +743,18 @@ def attach_flagging_events(
None,
flag_btn,
queue=False,
api_name=False,
)
flag_btn.click(
flag_method,
inputs=flag_components,
outputs=flag_btn,
preprocess=False,
queue=False,
api_name=False,
)
clear_btn.click(
flag_method.reset,
None,
flag_btn,
queue=False,
flag_method.reset, None, flag_btn, queue=False, api_name=False
)

def render_examples(self):
Expand Down
28 changes: 14 additions & 14 deletions test/test_blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,7 +432,7 @@ def test_raise_error_if_event_queued_but_queue_not_enabled(self):
lambda x: f"Hello, {x}", inputs=input_, outputs=output, queue=True
)

with pytest.raises(ValueError, match="The queue is enabled for event 0"):
with pytest.raises(ValueError, match="The queue is enabled for event lambda"):
demo.launch(prevent_thread_lock=True)

demo.close()
Expand Down Expand Up @@ -463,8 +463,8 @@ def create_images(n_images):
outputs=gallery,
)
with connect(demo) as client:
client.predict(3)
_ = client.predict(3)
client.predict(3, api_name="/predict")
_ = client.predict(3, api_name="/predict")
# only three files created and in temp directory
assert len([f for f in gradio_temp_dir.glob("**/*") if f.is_file()]) == 3

Expand All @@ -478,9 +478,9 @@ def test_no_empty_image_files(self, gradio_temp_dir, connect):
outputs=gr.Image(),
)
with connect(demo) as client:
_ = client.predict(image)
_ = client.predict(image)
_ = client.predict(image)
_ = client.predict(image, api_name="/predict")
_ = client.predict(image, api_name="/predict")
_ = client.predict(image, api_name="/predict")
# Upload creates a file. image preprocessing creates another one.
assert len([f for f in gradio_temp_dir.glob("**/*") if f.is_file()]) == 2

Expand All @@ -489,8 +489,8 @@ def test_file_component_uploads(self, component, connect, gradio_temp_dir):
code_file = str(pathlib.Path(__file__))
demo = gr.Interface(lambda x: x.name, component(), gr.File())
with connect(demo) as client:
_ = client.predict(code_file)
_ = client.predict(code_file)
_ = client.predict(code_file, api_name="/predict")
_ = client.predict(code_file, api_name="/predict")
# the upload route hashees the files so we get 1 from there
# We create two tempfiles (empty) because API says we return
# preprocess/postprocess will create the same file as the upload route
Expand All @@ -502,8 +502,8 @@ def test_no_empty_video_files(self, gradio_temp_dir, connect):
video = str(file_dir / "video_sample.mp4")
demo = gr.Interface(lambda x: x, gr.Video(type="file"), gr.Video())
with connect(demo) as client:
_ = client.predict({"video": video})
_ = client.predict({"video": video})
_ = client.predict({"video": video}, api_name="/predict")
_ = client.predict({"video": video}, api_name="/predict")
# Upload route and postprocessing return the same file
assert len([f for f in gradio_temp_dir.glob("**/*") if f.is_file()]) == 1

Expand All @@ -517,8 +517,8 @@ def reverse_audio(audio):

demo = gr.Interface(fn=reverse_audio, inputs=gr.Audio(), outputs=gr.Audio())
with connect(demo) as client:
_ = client.predict(audio)
_ = client.predict(audio)
_ = client.predict(audio, api_name="/predict")
_ = client.predict(audio, api_name="/predict")
# One for upload and one for reversal
assert len([f for f in gradio_temp_dir.glob("**/*") if f.is_file()]) == 2

Expand Down Expand Up @@ -1495,12 +1495,12 @@ def test_many_endpoints(self):
t5 = gr.Textbox()
t1.change(lambda x: x, t1, t2, api_name="change1")
t2.change(lambda x: x, t2, t3, api_name="change2")
t3.change(lambda x: x, t3, t4)
t3.change(lambda x: x, t3, t4, api_name=False)
t4.change(lambda x: x, t4, t5, api_name=False)

api_info = demo.get_api_info()
assert len(api_info["named_endpoints"]) == 2
assert len(api_info["unnamed_endpoints"]) == 1
assert len(api_info["unnamed_endpoints"]) == 0

def test_no_endpoints(self):
with gr.Blocks() as demo:
Expand Down
96 changes: 96 additions & 0 deletions test/test_routes.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Contains tests for networking.py and app.py"""
import functools
import json
import os
import tempfile
Expand All @@ -23,6 +24,7 @@
close_all,
routes,
)
from gradio.route_utils import FnIndexInferError


@pytest.fixture()
Expand Down Expand Up @@ -698,3 +700,97 @@ def test_file_route_does_not_allow_dot_paths(tmp_path):
assert client.get("/file=.env").status_code == 403
assert client.get("/file=subdir/.env").status_code == 403
assert client.get("/file=.versioncontrol/settings").status_code == 403


def test_api_name_set_for_all_events(connect):
with gr.Blocks() as demo:
i = Textbox()
o = Textbox()
btn = Button()
btn1 = Button()
btn2 = Button()
btn3 = Button()
btn4 = Button()
btn5 = Button()
btn6 = Button()
btn7 = Button()
btn8 = Button()

def greet(i):
return "Hello " + i

def goodbye(i):
return "Goodbye " + i

def greet_me(i):
return "Hello"

def say_goodbye(i):
return "Goodbye"

say_goodbye.__name__ = "Say_$$_goodbye"

# Otherwise changed by ruff
foo = lambda s: s # noqa

def foo2(s):
return s + " foo"

foo2.__name__ = "foo-2"

class Callable:
def __call__(self, a) -> str:
return "From __call__"

def from_partial(a, b):
return b + a

part = functools.partial(from_partial, b="From partial: ")

btn.click(greet, i, o)
btn1.click(goodbye, i, o)
btn2.click(greet_me, i, o)
btn3.click(say_goodbye, i, o)
btn4.click(None, i, o)
btn5.click(foo, i, o)
btn6.click(foo2, i, o)
btn7.click(Callable(), i, o)
btn8.click(part, i, o)

with closing(demo) as io:
app, _, _ = io.launch(prevent_thread_lock=True)
client = TestClient(app)
assert client.post(
"/api/greet", json={"data": ["freddy"], "session_hash": "foo"}
).json()["data"] == ["Hello freddy"]
assert client.post(
"/api/goodbye", json={"data": ["freddy"], "session_hash": "foo"}
).json()["data"] == ["Goodbye freddy"]
assert client.post(
"/api/greet_me", json={"data": ["freddy"], "session_hash": "foo"}
).json()["data"] == ["Hello"]
assert client.post(
"/api/Say__goodbye", json={"data": ["freddy"], "session_hash": "foo"}
).json()["data"] == ["Goodbye"]
assert client.post(
"/api/lambda", json={"data": ["freddy"], "session_hash": "foo"}
).json()["data"] == ["freddy"]
assert client.post(
"/api/foo-2", json={"data": ["freddy"], "session_hash": "foo"}
).json()["data"] == ["freddy foo"]
assert client.post(
"/api/Callable", json={"data": ["freddy"], "session_hash": "foo"}
).json()["data"] == ["From __call__"]
assert client.post(
"/api/partial", json={"data": ["freddy"], "session_hash": "foo"}
).json()["data"] == ["From partial: freddy"]
with pytest.raises(FnIndexInferError):
client.post(
"/api/Say_goodbye", json={"data": ["freddy"], "session_hash": "foo"}
)

with connect(demo) as client:
assert client.predict("freddy", api_name="/greet") == "Hello freddy"
assert client.predict("freddy", api_name="/goodbye") == "Goodbye freddy"
assert client.predict("freddy", api_name="/greet_me") == "Hello"
assert client.predict("freddy", api_name="/Say__goodbye") == "Goodbye"

0 comments on commit 370a33a

Please sign in to comment.