Commit 2bae1cf

Adds an "API Recorder" to the view API page, some internal methods have been made async (#7850)

* changes

* code recorder

* api docs

* add changeset

* changes

* change

* add changeset

* add changeset

* changes

* fixes

* format

* changes

* fix

* add changeset

* rename

* api recorder

* docs

* changes for pr

* fix

* async-ify

* add changeset

* fix typing

* fixes

* fix test/test_blocks

* fix more tests

* fix test

* add changeset

* fix more tests

* add async version

* make 2 separate functions

* format

* fix

* format

* fix tests

* move resource

* changes

* fix cache

* lite fixes

* make lazy cache async

* fix

* fix

* flaky

* replace null with None

* fixes

---------

Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
abidlabs and gradio-pr-bot committed Apr 3, 2024
1 parent efd9524 commit 2bae1cf
Showing 18 changed files with 543 additions and 84 deletions.
7 changes: 7 additions & 0 deletions .changeset/red-roses-hide.md
@@ -0,0 +1,7 @@
---
"@gradio/app": patch
"gradio": patch
"gradio_client": patch
---

feat:Adds an "API Recorder" to the view API page, some internal methods have been made async
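Since this commit makes Blocks.preprocess_data, Blocks.postprocess_data, and Blocks.handle_streaming_outputs coroutines, any code that called these internals directly must now await them (or bridge from sync code with client_utils.synchronize_async, as this PR does). A minimal sketch of the new calling convention (the demo app below is illustrative, not part of this commit):

import asyncio

import gradio as gr

with gr.Blocks() as demo:
    inp = gr.Textbox()
    out = gr.Textbox()
    demo.load(lambda s: s.upper(), inp, out)


async def main():
    # postprocess_data is now a coroutine and must be awaited
    data = await demo.postprocess_data(0, ["HELLO"], None)
    print(data)


asyncio.run(main())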
29 changes: 28 additions & 1 deletion client/python/gradio_client/utils.py
@@ -17,7 +17,7 @@
from enum import Enum
from pathlib import Path
from threading import Lock
-from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, TypedDict
+from typing import TYPE_CHECKING, Any, Callable, Coroutine, Literal, Optional, TypedDict

import fsspec.asyn
import httpx
@@ -971,6 +971,9 @@ def get_desc(v):


def traverse(json_obj: Any, func: Callable, is_root: Callable[..., bool]) -> Any:
"""
Traverse a JSON object and apply a function to each element that satisfies the is_root condition.
"""
if is_root(json_obj):
return func(json_obj)
elif isinstance(json_obj, dict):
@@ -987,6 +990,30 @@ def traverse(json_obj: Any, func: Callable, is_root: Callable[..., bool]) -> Any
return json_obj


async def async_traverse(
json_obj: Any,
func: Callable[..., Coroutine[Any, Any, Any]],
is_root: Callable[..., bool],
) -> Any:
"""
Traverse a JSON object and apply an async function to each element that satisfies the is_root condition.
"""
if is_root(json_obj):
return await func(json_obj)
elif isinstance(json_obj, dict):
new_obj = {}
for key, value in json_obj.items():
new_obj[key] = await async_traverse(value, func, is_root)
return new_obj
elif isinstance(json_obj, (list, tuple)):
new_obj = []
for item in json_obj:
new_obj.append(await async_traverse(item, func, is_root))
return new_obj
else:
return json_obj


def value_is_file(api_info: dict) -> bool:
info = _json_schema_to_python_type(api_info, api_info.get("$defs"))
return any(file_data_format in info for file_data_format in FILE_DATA_FORMATS)
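The new async_traverse mirrors traverse but awaits its visitor on every matching leaf, which is what lets file handling do non-blocking I/O inside FastAPI routes. A small usage sketch (the payload, predicate, and visitor below are invented for illustration; only async_traverse itself comes from this diff):

import asyncio

from gradio_client.utils import async_traverse


def is_url_node(node) -> bool:
    # Treat any dict carrying a "url" key as a leaf to transform.
    return isinstance(node, dict) and "url" in node


async def mark_resolved(node: dict) -> dict:
    await asyncio.sleep(0)  # stand-in for real async work (e.g. a download)
    return {**node, "resolved": True}


payload = {"image": {"url": "https://example.com/cat.png"}, "label": "cat"}
print(asyncio.run(async_traverse(payload, mark_resolved, is_url_node)))
# {'image': {'url': 'https://example.com/cat.png', 'resolved': True}, 'label': 'cat'}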
1 change: 1 addition & 0 deletions client/python/test/test_client.py
@@ -807,6 +807,7 @@ def __call__(self, *args, **kwargs):


class TestAPIInfo:
@pytest.mark.flaky
@pytest.mark.parametrize("trailing_char", ["/", ""])
def test_test_endpoint_src(self, trailing_char):
src = "https://gradio-calculator.hf.space" + trailing_char
73 changes: 59 additions & 14 deletions gradio/blocks.py
@@ -247,15 +247,55 @@ def recover_kwargs(
kwargs[parameter.name] = props[parameter.name]
return kwargs

async def async_move_resource_to_block_cache(
self, url_or_file_path: str | Path | None
) -> str | None:
"""Moves a file or downloads a file from a url to a block's cache directory, adds
to to the block's temp_files, and returns the path to the file in cache. This
ensures that the file is accessible to the Block and can be served to users.
This async version of the function is used when this is being called within
a FastAPI route, as this is not blocking.
"""
if url_or_file_path is None:
return None
if isinstance(url_or_file_path, Path):
url_or_file_path = str(url_or_file_path)

if client_utils.is_http_url_like(url_or_file_path):
temp_file_path = await processing_utils.async_save_url_to_cache(
url_or_file_path, cache_dir=self.GRADIO_CACHE
)

self.temp_files.add(temp_file_path)
else:
url_or_file_path = str(utils.abspath(url_or_file_path))
if not utils.is_in_or_equal(url_or_file_path, self.GRADIO_CACHE):
try:
temp_file_path = processing_utils.save_file_to_cache(
url_or_file_path, cache_dir=self.GRADIO_CACHE
)
except FileNotFoundError:
# This can happen when using gr.load() and the file is on a remote Space
# but the file is not the `value` of the component. For example, if the file
# is the `avatar_image` of the `Chatbot` component. In this case, we skip
# copying the file to the cache and just use the remote file path.
return url_or_file_path
else:
temp_file_path = url_or_file_path
self.temp_files.add(temp_file_path)

return temp_file_path

def move_resource_to_block_cache(
self, url_or_file_path: str | Path | None
) -> str | None:
"""Moves a file or downloads a file from a url to a block's cache directory, adds
to to the block's temp_files, and returns the path to the file in cache. This
ensures that the file is accessible to the Block and can be served to users.
Note: this method is not used in any core Gradio components, but is kept here
for backwards compatibility with custom components created with gradio<=4.20.0.
This sync version of the function is used when this is being called outside of
a FastAPI route, e.g. when examples are being cached.
"""
if url_or_file_path is None:
return None
@@ -311,7 +351,9 @@ def serve_static_file(
else:
data = {"path": url_or_file_path}
try:
-return processing_utils.move_files_to_cache(data, self)
+return client_utils.synchronize_async(
+    processing_utils.async_move_files_to_cache, data, self
+)
except AttributeError: # Can be raised if this function is called before the Block is fully initialized.
return data

@@ -1418,7 +1460,7 @@ def validate_inputs(self, fn_index: int, inputs: list[Any]):
[{received}]"""
)

-def preprocess_data(
+async def preprocess_data(
self,
fn_index: int,
inputs: list[Any],
@@ -1449,7 +1491,7 @@ def preprocess_data(
else:
if input_id in state:
block = state[input_id]
-inputs_cached = processing_utils.move_files_to_cache(
+inputs_cached = await processing_utils.async_move_files_to_cache(
inputs[i],
block,
check_in_upload_folder=not explicit_call,
@@ -1500,7 +1542,7 @@ def validate_outputs(self, fn_index: int, predictions: Any | list[Any]):
[{received}]"""
)

-def postprocess_data(
+async def postprocess_data(
self, fn_index: int, predictions: list | dict, state: SessionState | None
):
state = state or SessionState(self)
@@ -1580,7 +1622,7 @@ def postprocess_data(
block = state[output_id]
prediction_value = block.postprocess(prediction_value)

-outputs_cached = processing_utils.move_files_to_cache(
+outputs_cached = await processing_utils.async_move_files_to_cache(
prediction_value,
block,
postprocess=True,
@@ -1589,7 +1631,7 @@

return output

-def handle_streaming_outputs(
+async def handle_streaming_outputs(
self,
fn_index: int,
data: list,
@@ -1613,7 +1655,7 @@
if first_chunk:
stream_run[output_id] = []
self.pending_streams[session_hash][run][output_id].append(binary_data)
-output_data = processing_utils.move_files_to_cache(
+output_data = await processing_utils.async_move_files_to_cache(
output_data,
block,
postprocess=True,
@@ -1711,7 +1753,7 @@ async def process_api(
f"Batch size ({batch_size}) exceeds the max_batch_size for this function ({max_batch_size})"
)
inputs = [
-self.preprocess_data(fn_index, list(i), state, explicit_call)
+await self.preprocess_data(fn_index, list(i), state, explicit_call)
for i in zip(*inputs)
]
result = await self.call_function(
@@ -1725,7 +1767,8 @@
)
preds = result["prediction"]
data = [
-self.postprocess_data(fn_index, list(o), state) for o in zip(*preds)
+await self.postprocess_data(fn_index, list(o), state)
+for o in zip(*preds)
]
if root_path is not None:
data = processing_utils.add_root_url(data, root_path, None)
@@ -1736,7 +1779,9 @@
if old_iterator:
inputs = []
else:
-inputs = self.preprocess_data(fn_index, inputs, state, explicit_call)
+inputs = await self.preprocess_data(
+    fn_index, inputs, state, explicit_call
+)
was_generating = old_iterator is not None
result = await self.call_function(
fn_index,
@@ -1747,13 +1792,13 @@
event_data,
in_event_listener,
)
-data = self.postprocess_data(fn_index, result["prediction"], state)
+data = await self.postprocess_data(fn_index, result["prediction"], state)
if root_path is not None:
data = processing_utils.add_root_url(data, root_path, None)
is_generating, iterator = result["is_generating"], result["iterator"]
if is_generating or was_generating:
run = id(old_iterator) if was_generating else id(iterator)
-data = self.handle_streaming_outputs(
+data = await self.handle_streaming_outputs(
fn_index,
data,
session_hash=session_hash,
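The async/sync pair above reflects the pattern this PR applies throughout: one async implementation for request-handling paths, plus a thin sync wrapper for callers outside the event loop. A minimal sketch of that pattern, assuming httpx for the download; the helper names are invented and are not Gradio's actual API:

import asyncio
import shutil
from pathlib import Path

import httpx


async def async_save_to_cache(url_or_path: str, cache_dir: Path) -> Path:
    """Async path: non-blocking when called inside a FastAPI route."""
    cache_dir.mkdir(parents=True, exist_ok=True)
    dest = cache_dir / Path(url_or_path).name
    if url_or_path.startswith(("http://", "https://")):
        async with httpx.AsyncClient() as client:
            resp = await client.get(url_or_path)
            resp.raise_for_status()
            dest.write_bytes(resp.content)
    else:
        shutil.copy(url_or_path, dest)  # local file: plain copy
    return dest


def save_to_cache(url_or_path: str, cache_dir: Path) -> Path:
    # Sync entry point for non-async call sites (e.g. example caching),
    # analogous to wrapping with client_utils.synchronize_async.
    return asyncio.run(async_save_to_cache(url_or_path, cache_dir))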
12 changes: 6 additions & 6 deletions gradio/helpers.py
@@ -305,7 +305,7 @@ async def load_example(example_id):
if not self._defer_caching:
self._start_caching()

-def _postprocess_output(self, output) -> list:
+async def _postprocess_output(self, output) -> list:
"""
This lets us postprocess the data manually, since we set postprocess=False in the lazy_cache
event handler. We did that because we don't want to postprocess data if we are loading from
@@ -318,7 +318,7 @@ def _postprocess_output(self, output) -> list:
[output.render() for output in self.outputs]
demo.load(self.fn, self.inputs, self.outputs)
demo.unrender()
-return demo.postprocess_data(0, output, None)
+return await demo.postprocess_data(0, output, None)

def _get_cached_index_if_cached(self, example_index) -> int | None:
if Path(self.cached_indices_file).exists():
@@ -342,7 +342,7 @@ def _start_caching(self):
)
break
if self.cache_examples == "lazy":
-self.lazy_cache()
+client_utils.synchronize_async(self.lazy_cache)
if self.cache_examples is True:
if wasm_utils.IS_WASM:
# In the Wasm mode, the `threading` module is not supported,
@@ -356,7 +356,7 @@
else:
client_utils.synchronize_async(self.cache)

-def lazy_cache(self) -> None:
+async def lazy_cache(self) -> None:
print(
f"Will cache examples in '{utils.abspath(self.cached_folder)}' directory at first use. ",
end="",
@@ -393,7 +393,7 @@ async def async_lazy_cache(self, example_index, *input_values):
else:
fn = utils.async_fn_to_generator(self.fn)
async for output in fn(*input_values):
-output = self._postprocess_output(output)
+output = await self._postprocess_output(output)
yield output[0] if len(self.outputs) == 1 else output
self.cache_logger.flag(output)
with open(self.cached_indices_file, "a") as f:
@@ -411,7 +411,7 @@ def sync_lazy_cache(self, example_index, *input_values):
else:
fn = utils.sync_fn_to_generator(self.fn)
for output in fn(*input_values):
-output = self._postprocess_output(output)
+output = client_utils.synchronize_async(self._postprocess_output, output)
yield output[0] if len(self.outputs) == 1 else output
self.cache_logger.flag(output)
with open(self.cached_indices_file, "a") as f:
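Both lazy-cache branches above normalize the user's function into a generator (via utils.async_fn_to_generator or utils.sync_fn_to_generator) so streaming and non-streaming functions can be cached by the same loop. A rough sketch of what such a normalizer can look like; Gradio's real helper may differ:

import asyncio
import inspect


def async_fn_to_generator(fn):
    """Wrap a plain (sync or async) function as a one-shot async generator."""

    async def wrapper(*args, **kwargs):
        result = fn(*args, **kwargs)
        if inspect.isawaitable(result):
            result = await result
        yield result  # single value, yielded once

    return wrapper


async def main():
    async def predict(x):
        return x * 2

    async for output in async_fn_to_generator(predict)(21):
        print(output)  # 42


asyncio.run(main())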
