Skip to content

Commit

Permalink
Support call method (#5751)
Browse files Browse the repository at this point in the history
* Support call method

* add changeset

---------

Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
  • Loading branch information
freddyaboulton and gradio-pr-bot committed Sep 29, 2023
1 parent 199208e commit 95e35cd
Show file tree
Hide file tree
Showing 8 changed files with 81 additions and 93 deletions.
6 changes: 6 additions & 0 deletions .changeset/true-bugs-shine.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
---
"gradio": minor
"gradio_client": minor
---

feat:Support call method
29 changes: 7 additions & 22 deletions client/python/gradio_client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -843,14 +843,12 @@ def _get_component_type(self, component_id: int):
component["type"] == "state",
)

def value_is_file(self, component: dict) -> bool:
@staticmethod
def value_is_file(component: dict) -> bool:
# Hacky for now
if "api_info" not in component:
return False
api_info = utils._json_schema_to_python_type(
component["api_info"], component["api_info"].get("$defs")
)
return utils.FILE_DATA in api_info
return utils.value_is_file(component["api_info"])

def __repr__(self):
return f"Endpoint src: {self.client.src}, api_name: {self.api_name}, fn_index: {self.fn_index}"
Expand Down Expand Up @@ -1008,9 +1006,7 @@ def get_file(d):
new_data = []
for i, d in enumerate(data):
if self.input_component_types[i].value_is_file:
d = utils.traverse(
d, get_file, lambda s: isinstance(s, str) and Path(s).exists()
)
d = utils.traverse(d, get_file, utils.is_filepath)
new_data.append(d)
return file_list, new_data

Expand All @@ -1033,8 +1029,8 @@ def serialize(self, *data) -> tuple:
data = self._add_uploaded_files_to_data(data, uploaded_files)
data = utils.traverse(
data,
lambda s: {"name": s, "is_file": True},
lambda s: isinstance(s, str) and utils.is_http_url_like(s),
lambda s: {"name": s, "is_file": True, "data": None},
utils.is_url,
)
o = tuple(data)
return o
Expand Down Expand Up @@ -1075,18 +1071,7 @@ def _download_file(
def deserialize(self, *data) -> tuple:
data_ = list(data)

def is_file(d):
return (
isinstance(d, dict)
and "name" in d
and "is_file" in d
and "data" in d
and "size" in d
and "orig_name" in d
and "mime_type" in d
)

data_: list[Any] = utils.traverse(data_, self.download_file, is_file)
data_: list[Any] = utils.traverse(data_, self.download_file, utils.is_file_obj)
return tuple(data_)

def process_predictions(self, *predictions):
Expand Down
27 changes: 26 additions & 1 deletion client/python/gradio_client/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -617,7 +617,7 @@ def traverse(json_obj: Any, func: Callable, is_root: Callable) -> Any:
for key, value in json_obj.items():
new_obj[key] = traverse(value, func, is_root)
return new_obj
elif isinstance(json_obj, list):
elif isinstance(json_obj, (list, tuple)):
new_obj = []
for item in json_obj:
new_obj.append(traverse(item, func, is_root))
Expand All @@ -626,6 +626,31 @@ def traverse(json_obj: Any, func: Callable, is_root: Callable) -> Any:
return json_obj


def value_is_file(api_info: dict) -> bool:
info = _json_schema_to_python_type(api_info, api_info.get("$defs"))
return FILE_DATA in info


def is_filepath(s):
return isinstance(s, str) and Path(s).exists()


def is_url(s):
return isinstance(s, str) and is_http_url_like(s)


def is_file_obj(d):
return (
isinstance(d, dict)
and "name" in d
and "is_file" in d
and "data" in d
and "size" in d
and "orig_name" in d
and "mime_type" in d
)


SKIP_COMPONENTS = {
"state",
"row",
Expand Down
26 changes: 18 additions & 8 deletions gradio/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
wasm_utils,
)
from gradio.context import Context
from gradio.data_classes import FileData
from gradio.deprecation import check_deprecated_parameters, warn_deprecation
from gradio.events import EventData, EventListener
from gradio.exceptions import (
Expand Down Expand Up @@ -952,8 +953,8 @@ def __call__(self, *inputs, fn_index: int = 0, api_name: str | None = None):
if batch:
outputs = [out[0] for out in outputs]

processed_outputs = self.deserialize_data(fn_index, outputs)
processed_outputs = utils.resolve_singleton(processed_outputs)
outputs = self.deserialize_data(fn_index, outputs)
processed_outputs = utils.resolve_singleton(outputs)

return processed_outputs

Expand Down Expand Up @@ -1049,6 +1050,9 @@ def serialize_data(self, fn_index: int, inputs: list[Any]) -> list[Any]:
dependency = self.dependencies[fn_index]
processed_input = []

def format_file(s):
return FileData(name=s, is_file=True).model_dump()

for i, input_id in enumerate(dependency["inputs"]):
try:
block = self.blocks[input_id]
Expand All @@ -1059,7 +1063,15 @@ def serialize_data(self, fn_index: int, inputs: list[Any]) -> list[Any]:
assert isinstance(
block, components.Component
), f"{block.__class__} Component with id {input_id} not a valid input component."
serialized_input = block.serialize(inputs[i]) # type: ignore
api_info = block.api_info()
if client_utils.value_is_file(api_info):
serialized_input = client_utils.traverse(
inputs[i],
format_file,
lambda s: client_utils.is_filepath(s) or client_utils.is_url(s),
)
else:
serialized_input = inputs[i]
processed_input.append(serialized_input)

return processed_input
Expand All @@ -1078,11 +1090,9 @@ def deserialize_data(self, fn_index: int, outputs: list[Any]) -> list[Any]:
assert isinstance(
block, components.Component
), f"{block.__class__} Component with id {output_id} not a valid output component."
deserialized = block.deserialize( # type: ignore
outputs[o],
save_dir=block.DEFAULT_TEMP_DIR, # type: ignore
root_url=block.root_url,
hf_token=Context.hf_token,

deserialized = client_utils.traverse(
outputs[o], lambda s: s["name"], client_utils.is_file_obj
)
predictions.append(deserialized)

Expand Down
3 changes: 0 additions & 3 deletions test/test_blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -881,7 +881,6 @@ def run(min, num):


class TestCallFunction:
@pytest.mark.xfail
@pytest.mark.asyncio
async def test_call_regular_function(self):
with gr.Blocks() as demo:
Expand All @@ -901,7 +900,6 @@ async def test_call_regular_function(self):
output = await demo.call_function(0, ["Abubakar"])
assert output["prediction"] == "Hello, Abubakar"

@pytest.mark.xfail
@pytest.mark.asyncio
async def test_call_multiple_functions(self):
with gr.Blocks() as demo:
Expand Down Expand Up @@ -1043,7 +1041,6 @@ def trim(words, lens):
demo.queue()
demo.launch(prevent_thread_lock=True)

@pytest.mark.xfail
@pytest.mark.asyncio
async def test_call_regular_function(self):
def batch_fn(x):
Expand Down

0 comments on commit 95e35cd

Please sign in to comment.