Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds support for kwargs and default arguments in the python client, and improves how parameter information is displayed in the "view API" page #7732

Merged
merged 46 commits into from Mar 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
46 commits
Select commit Hold shift + click to select a range
b7e02df
changes
abidlabs Mar 13, 2024
56b5d78
changes
abidlabs Mar 14, 2024
fe9f6c4
add changeset
gradio-pr-bot Mar 18, 2024
f834f62
improvements to api docs ui
abidlabs Mar 18, 2024
e0c8bc6
Merge branch 'api-params' of github.com:gradio-app/gradio into api-pa…
abidlabs Mar 18, 2024
b84f7b2
add changeset
gradio-pr-bot Mar 18, 2024
fe49d7f
ux design work
abidlabs Mar 18, 2024
c66ad14
Merge branch 'api-params' of github.com:gradio-app/gradio into api-pa…
abidlabs Mar 18, 2024
e3480c3
further styling
abidlabs Mar 18, 2024
4c8162e
feedback
abidlabs Mar 18, 2024
c17c8c7
add changeset
gradio-pr-bot Mar 18, 2024
27963d1
Merge branch 'main' into api-params
abidlabs Mar 18, 2024
337531d
get parameter name
abidlabs Mar 18, 2024
fc08155
simplify
abidlabs Mar 18, 2024
bf2cbc0
fix code snippet
abidlabs Mar 18, 2024
769efc7
construct args
abidlabs Mar 19, 2024
c8fa79e
add changeset
gradio-pr-bot Mar 19, 2024
98f0a5e
construct_args
abidlabs Mar 19, 2024
d570799
Merge branch 'api-params' of github.com:gradio-app/gradio into api-pa…
abidlabs Mar 19, 2024
2dd80d6
utils
abidlabs Mar 19, 2024
7645574
changes
abidlabs Mar 19, 2024
a3712c9
add catches
abidlabs Mar 19, 2024
8f4030d
fixes
abidlabs Mar 19, 2024
461392a
valid
abidlabs Mar 19, 2024
5dfea67
fix tests
abidlabs Mar 19, 2024
c22472f
js lint
abidlabs Mar 19, 2024
12db580
add tests
abidlabs Mar 19, 2024
6cb9972
add changeset
gradio-pr-bot Mar 19, 2024
36ae1da
format
abidlabs Mar 19, 2024
f15f9f4
Merge branch 'api-params' of github.com:gradio-app/gradio into api-pa…
abidlabs Mar 19, 2024
edb0947
client
abidlabs Mar 19, 2024
1c7c8d9
doc
abidlabs Mar 21, 2024
f6d4758
fixes
abidlabs Mar 21, 2024
e54303e
changes
abidlabs Mar 21, 2024
e23908c
Merge branch 'main' into api-params
abidlabs Mar 21, 2024
5d1a49e
api docs fixes
abidlabs Mar 21, 2024
fb1b3b7
Merge branch 'api-params' of github.com:gradio-app/gradio into api-pa…
abidlabs Mar 21, 2024
d81bf76
fix .view_api()
abidlabs Mar 21, 2024
50d11d3
Merge branch 'main' into api-params
abidlabs Mar 21, 2024
ffe1da8
updating guide wip
abidlabs Mar 21, 2024
17102fb
guide
abidlabs Mar 21, 2024
45f2830
updated guides'
abidlabs Mar 21, 2024
95fb015
fix
abidlabs Mar 21, 2024
767f86f
guide fixes
abidlabs Mar 21, 2024
2be1a7d
filepath
abidlabs Mar 21, 2024
2c59e56
address review
abidlabs Mar 21, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
8 changes: 8 additions & 0 deletions .changeset/tiny-bars-greet.md
@@ -0,0 +1,8 @@
---
"@gradio/app": minor
"@gradio/atoms": minor
"gradio": minor
"gradio_client": minor
---

feat:Adds support for kwargs and default arguments in the python client, and improves how parameter information is displayed in the "view API" page
45 changes: 37 additions & 8 deletions client/python/gradio_client/client.py
Expand Up @@ -33,6 +33,7 @@

from gradio_client import utils
from gradio_client.compatibility import EndpointV3Compatibility
from gradio_client.data_classes import ParameterInfo
from gradio_client.documentation import document
from gradio_client.exceptions import AuthenticationError
from gradio_client.utils import (
Expand Down Expand Up @@ -400,6 +401,7 @@ def predict(
*args,
api_name: str | None = None,
fn_index: int | None = None,
**kwargs,
) -> Any:
"""
Calls the Gradio API and returns the result (this is a blocking call).
Expand All @@ -421,7 +423,9 @@ def predict(
raise ValueError(
"Cannot call predict on this function as it may run forever. Use submit instead."
)
return self.submit(*args, api_name=api_name, fn_index=fn_index).result()
return self.submit(
*args, api_name=api_name, fn_index=fn_index, **kwargs
).result()

def new_helper(self, fn_index: int) -> Communicator:
return Communicator(
Expand All @@ -437,6 +441,7 @@ def submit(
api_name: str | None = None,
fn_index: int | None = None,
result_callbacks: Callable | list[Callable] | None = None,
**kwargs,
) -> Job:
"""
Creates and returns a Job object which calls the Gradio API in a background thread. The job can be used to retrieve the status and result of the remote API call.
Expand All @@ -458,9 +463,13 @@ def submit(
>> 9.0
"""
inferred_fn_index = self._infer_fn_index(api_name, fn_index)
endpoint = self.endpoints[inferred_fn_index]

if isinstance(endpoint, Endpoint):
args = utils.construct_args(endpoint.parameters_info, args, kwargs)

helper = None
if self.endpoints[inferred_fn_index].protocol in (
if endpoint.protocol in (
"ws",
"sse",
"sse_v1",
Expand All @@ -469,7 +478,7 @@ def submit(
"sse_v3",
):
helper = self.new_helper(inferred_fn_index)
end_to_end_fn = self.endpoints[inferred_fn_index].make_end_to_end_fn(helper)
end_to_end_fn = endpoint.make_end_to_end_fn(helper)
future = self.executor.submit(end_to_end_fn, *args)

job = Job(
Expand Down Expand Up @@ -637,9 +646,12 @@ def reset_session(self) -> None:
def _render_endpoints_info(
self,
name_or_index: str | int,
endpoints_info: dict[str, list[dict[str, Any]]],
endpoints_info: dict[str, list[ParameterInfo]],
) -> str:
parameter_names = [p["label"] for p in endpoints_info["parameters"]]
parameter_info = endpoints_info["parameters"]
parameter_names = [
p.get("parameter_name") or p["label"] for p in parameter_info
]
parameter_names = [utils.sanitize_parameter_names(p) for p in parameter_names]
rendered_parameters = ", ".join(parameter_names)
if rendered_parameters:
Expand All @@ -659,15 +671,23 @@ def _render_endpoints_info(

human_info = f"\n - predict({rendered_parameters}{final_param}) -> {rendered_return_values}\n"
human_info += " Parameters:\n"
if endpoints_info["parameters"]:
for info in endpoints_info["parameters"]:
if parameter_info:
for info in parameter_info:
desc = (
f" ({info['python_type']['description']})"
if info["python_type"].get("description")
else ""
)
default_value = info.get("parameter_default")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we need to traverse the value here to only pull out the actual filepath or url

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So the view API page takes care of this automatically:

image

but the .view_api() function does not

image
import gradio as gr

demo = gr.Interface(
    lambda x:x,
    gr.Audio("test.mp3"),
    "audio"
)

_, url, _ = demo.launch()

from gradio_client import Client

client = Client(url)

client.view_api()

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll fix in the .view_api()

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yea sorry that's what I meant!

default_info = (
"(required)"
if not info.get("parameter_has_default", False)
else f"(not required, defaults to {default_value})"
)
type_ = info["python_type"]["type"]
human_info += f" - [{info['component']}] {utils.sanitize_parameter_names(info['label'])}: {type_}{desc} \n"
if info.get("parameter_has_default", False) and default_value is None:
type_ += " | None"
human_info += f" - [{info['component']}] {utils.sanitize_parameter_names(info.get('parameter_name') or info['label'])}: {type_} {default_info} {desc} \n"
else:
human_info += " - None\n"
human_info += " Returns:\n"
Expand Down Expand Up @@ -982,6 +1002,8 @@ def __init__(
self.output_component_types = [
self._get_component_type(id_) for id_ in dependency["outputs"]
]
self.parameters_info = self._get_parameters_info()

self.root_url = client.src + "/" if not client.src.endswith("/") else client.src
self.is_continuous = dependency.get("types", {}).get("continuous", False)

Expand All @@ -1001,6 +1023,13 @@ def _get_component_type(self, component_id: int):
component["type"] == "state",
)

def _get_parameters_info(self) -> list[ParameterInfo] | None:
if not self.client._info:
self._info = self.client._get_api_info()
if self.api_name in self._info["named_endpoints"]:
return self._info["named_endpoints"][self.api_name]["parameters"]
return None

@staticmethod
def value_is_file(component: dict) -> bool:
# This is still hacky as it does not tell us which part of the payload is a file.
Expand Down
13 changes: 12 additions & 1 deletion client/python/gradio_client/data_classes.py
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import TypedDict
from typing import Any, TypedDict

from typing_extensions import NotRequired

Expand All @@ -15,3 +15,14 @@ class FileData(TypedDict):
orig_name: NotRequired[str] # original filename
mime_type: NotRequired[str]
is_stream: NotRequired[bool]


class ParameterInfo(TypedDict):
label: str
parameter_name: NotRequired[str]
parameter_has_default: NotRequired[bool]
parameter_default: NotRequired[Any]
type: dict
python_type: dict
component: str
example_input: Any
56 changes: 55 additions & 1 deletion client/python/gradio_client/utils.py
Expand Up @@ -17,14 +17,17 @@
from enum import Enum
from pathlib import Path
from threading import Lock
from typing import Any, Callable, Literal, Optional, TypedDict
from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, TypedDict

import fsspec.asyn
import httpx
import huggingface_hub
from huggingface_hub import SpaceStage
from websockets.legacy.protocol import WebSocketCommonProtocol

if TYPE_CHECKING:
from gradio_client.data_classes import ParameterInfo

API_URL = "api/predict/"
SSE_URL_V0 = "queue/join"
SSE_DATA_URL_V0 = "queue/data"
Expand Down Expand Up @@ -1063,3 +1066,54 @@ def file(filepath_or_url: str | Path):
raise ValueError(
f"File {s} does not exist on local filesystem and is not a valid URL."
)


def construct_args(
parameters_info: list[ParameterInfo] | None, args: tuple, kwargs: dict
) -> list:
class _Keywords(Enum):
NO_VALUE = "NO_VALUE" # Used as a sentinel to determine if nothing is provided as a parameter for an argument

_args = list(args)
if parameters_info is None:
if kwargs:
raise ValueError(
"This endpoint does not support key-word arguments Please click on 'view API' in the footer of the Gradio app to see usage."
)
return _args
num_args = len(args)
_args = _args + [_Keywords.NO_VALUE] * (len(parameters_info) - num_args)

kwarg_arg_mapping = {}
kwarg_names = []
for index, param_info in enumerate(parameters_info):
if "parameter_name" in param_info:
kwarg_arg_mapping[param_info["parameter_name"]] = index
kwarg_names.append(param_info["parameter_name"])
else:
kwarg_names.append("argument {index}")
if (
param_info.get("parameter_has_default", False)
and _args[index] == _Keywords.NO_VALUE
):
_args[index] = param_info.get("parameter_default")

for key, value in kwargs.items():
if key in kwarg_arg_mapping:
if kwarg_arg_mapping[key] < num_args:
raise ValueError(
f"Parameter `{key}` is already set as a positional argument. Please click on 'view API' in the footer of the Gradio app to see usage."
)
else:
_args[kwarg_arg_mapping[key]] = value
else:
raise ValueError(
f"Parameter `{key}` is not a valid key-word argument. Please click on 'view API' in the footer of the Gradio app to see usage."
)

if _Keywords.NO_VALUE in _args:
raise ValueError(
f"No value provided for required argument: {kwarg_names[_args.index(_Keywords.NO_VALUE)]}"
)

return _args
32 changes: 32 additions & 0 deletions client/python/test/conftest.py
Expand Up @@ -41,6 +41,38 @@ def calculator(num1, operation, num2):
return demo


@pytest.fixture
def calculator_demo_with_defaults():
def calculator(num1, operation=None, num2=100):
if operation is None or operation == "add":
return num1 + num2
elif operation == "subtract":
return num1 - num2
elif operation == "multiply":
return num1 * num2
elif operation == "divide":
if num2 == 0:
raise gr.Error("Cannot divide by zero!")
return num1 / num2

demo = gr.Interface(
calculator,
[
gr.Number(value=10),
gr.Radio(["add", "subtract", "multiply", "divide"]),
gr.Number(),
],
"number",
examples=[
[5, "add", 3],
[4, "divide", 2],
[-4, "multiply", 2.5],
[0, "subtract", 1.2],
],
)
return demo


@pytest.fixture
def state_demo():
demo = gr.Interface(
Expand Down
57 changes: 45 additions & 12 deletions client/python/test/test_client.py
Expand Up @@ -601,6 +601,33 @@ def return_bad():
assert pred[0] == data[0]


class TestClientPredictionsWithKwargs:
def test_no_default_params(self, calculator_demo):
with connect(calculator_demo) as client:
result = client.predict(
num1=3, operation="add", num2=3, api_name="/predict"
)
assert result == 6

result = client.predict(33, operation="add", num2=3, api_name="/predict")
assert result == 36

def test_default_params(self, calculator_demo_with_defaults):
with connect(calculator_demo_with_defaults) as client:
result = client.predict(num2=10, api_name="/predict")
assert result == 20

result = client.predict(num2=33, operation="multiply", api_name="/predict")
assert result == 330

def test_missing_params(self, calculator_demo):
with connect(calculator_demo) as client:
with pytest.raises(
ValueError, match="No value provided for required argument: num2"
):
client.predict(num1=3, operation="add", api_name="/predict")


class TestStatusUpdates:
@patch("gradio_client.client.Endpoint.make_end_to_end_fn")
def test_messages_passed_correctly(self, mock_make_end_to_end_fn, calculator_demo):
Expand Down Expand Up @@ -952,13 +979,19 @@ def test_api_info_of_local_demo(self, calculator_demo):
"parameters": [
{
"label": "num1",
"parameter_name": "num1",
"parameter_has_default": False,
"parameter_default": None,
"type": {"type": "number"},
"python_type": {"type": "float", "description": ""},
"component": "Number",
"example_input": 3,
},
{
"label": "operation",
"parameter_name": "operation",
"parameter_has_default": False,
"parameter_default": None,
"type": {
"enum": ["add", "subtract", "multiply", "divide"],
"title": "Radio",
Expand All @@ -973,6 +1006,9 @@ def test_api_info_of_local_demo(self, calculator_demo):
},
{
"label": "num2",
"parameter_name": "num2",
"parameter_has_default": False,
"parameter_default": None,
"type": {"type": "number"},
"python_type": {"type": "float", "description": ""},
"component": "Number",
Expand Down Expand Up @@ -1046,6 +1082,9 @@ def test_layout_components_in_output(self, hello_world_with_group):
"parameters": [
{
"label": "name",
"parameter_name": "name",
"parameter_has_default": True,
"parameter_default": "",
"type": {"type": "string"},
"python_type": {"type": "str", "description": ""},
"component": "Textbox",
Expand Down Expand Up @@ -1077,6 +1116,9 @@ def test_layout_and_state_components_in_output(
"parameters": [
{
"label": "name",
"parameter_name": "name",
"parameter_has_default": True,
"parameter_default": "",
"type": {"type": "string"},
"python_type": {"type": "str", "description": ""},
"component": "Textbox",
Expand All @@ -1093,10 +1135,7 @@ def test_layout_and_state_components_in_output(
{
"label": "count",
"type": {"type": "number"},
"python_type": {
"type": "float",
"description": "",
},
"python_type": {"type": "float", "description": ""},
"component": "Number",
},
],
Expand All @@ -1107,10 +1146,7 @@ def test_layout_and_state_components_in_output(
{
"label": "count",
"type": {"type": "number"},
"python_type": {
"type": "float",
"description": "",
},
"python_type": {"type": "float", "description": ""},
"component": "Number",
}
],
Expand All @@ -1121,10 +1157,7 @@ def test_layout_and_state_components_in_output(
{
"label": "count",
"type": {"type": "number"},
"python_type": {
"type": "float",
"description": "",
},
"python_type": {"type": "float", "description": ""},
"component": "Number",
}
],
Expand Down