diff --git a/.changeset/red-roses-hide.md b/.changeset/red-roses-hide.md
new file mode 100644
index 000000000000..2b2a5617dfb2
--- /dev/null
+++ b/.changeset/red-roses-hide.md
@@ -0,0 +1,7 @@
+---
+"@gradio/app": patch
+"gradio": patch
+"gradio_client": patch
+---
+
+feat:Adds an "API Recorder" to the view API page; some internal methods have been made async
diff --git a/client/python/gradio_client/utils.py b/client/python/gradio_client/utils.py
index 9edd86af65e8..83eb0e9f2031 100644
--- a/client/python/gradio_client/utils.py
+++ b/client/python/gradio_client/utils.py
@@ -17,7 +17,7 @@
 from enum import Enum
 from pathlib import Path
 from threading import Lock
-from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, TypedDict
+from typing import TYPE_CHECKING, Any, Callable, Coroutine, Literal, Optional, TypedDict
 
 import fsspec.asyn
 import httpx
@@ -971,6 +971,9 @@ def get_desc(v):
 
 
 def traverse(json_obj: Any, func: Callable, is_root: Callable[..., bool]) -> Any:
+    """
+    Traverse a JSON object and apply a function to each element that satisfies the is_root condition.
+    """
     if is_root(json_obj):
         return func(json_obj)
     elif isinstance(json_obj, dict):
@@ -987,6 +990,30 @@ def traverse(json_obj: Any, func: Callable, is_root: Callable[..., bool]) -> Any
     return json_obj
 
 
+async def async_traverse(
+    json_obj: Any,
+    func: Callable[..., Coroutine[Any, Any, Any]],
+    is_root: Callable[..., bool],
+) -> Any:
+    """
+    Traverse a JSON object and apply an async function to each element that satisfies the is_root condition.
+    """
+    if is_root(json_obj):
+        return await func(json_obj)
+    elif isinstance(json_obj, dict):
+        new_obj = {}
+        for key, value in json_obj.items():
+            new_obj[key] = await async_traverse(value, func, is_root)
+        return new_obj
+    elif isinstance(json_obj, (list, tuple)):
+        new_obj = []
+        for item in json_obj:
+            new_obj.append(await async_traverse(item, func, is_root))
+        return new_obj
+    else:
+        return json_obj
+
+
 def value_is_file(api_info: dict) -> bool:
     info = _json_schema_to_python_type(api_info, api_info.get("$defs"))
     return any(file_data_format in info for file_data_format in FILE_DATA_FORMATS)
diff --git a/client/python/test/test_client.py b/client/python/test/test_client.py
index c83bdb1445f1..65c49e71c327 100644
--- a/client/python/test/test_client.py
+++ b/client/python/test/test_client.py
@@ -807,6 +807,7 @@ def __call__(self, *args, **kwargs):
 
 
 class TestAPIInfo:
+    @pytest.mark.flaky
     @pytest.mark.parametrize("trailing_char", ["/", ""])
     def test_test_endpoint_src(self, trailing_char):
         src = "https://gradio-calculator.hf.space" + trailing_char
diff --git a/gradio/blocks.py b/gradio/blocks.py
index 25b1398c1636..c470bfd8636e 100644
--- a/gradio/blocks.py
+++ b/gradio/blocks.py
@@ -247,6 +247,46 @@ def recover_kwargs(
                 kwargs[parameter.name] = props[parameter.name]
         return kwargs
 
+    async def async_move_resource_to_block_cache(
+        self, url_or_file_path: str | Path | None
+    ) -> str | None:
+        """Moves a file or downloads a file from a url to a block's cache directory, adds
+        it to the block's temp_files, and returns the path to the file in cache. This
+        ensures that the file is accessible to the Block and can be served to users.
+
+        This async version of the function is used when it is called from within
+        a FastAPI route, as it does not block the event loop.
+ """ + if url_or_file_path is None: + return None + if isinstance(url_or_file_path, Path): + url_or_file_path = str(url_or_file_path) + + if client_utils.is_http_url_like(url_or_file_path): + temp_file_path = await processing_utils.async_save_url_to_cache( + url_or_file_path, cache_dir=self.GRADIO_CACHE + ) + + self.temp_files.add(temp_file_path) + else: + url_or_file_path = str(utils.abspath(url_or_file_path)) + if not utils.is_in_or_equal(url_or_file_path, self.GRADIO_CACHE): + try: + temp_file_path = processing_utils.save_file_to_cache( + url_or_file_path, cache_dir=self.GRADIO_CACHE + ) + except FileNotFoundError: + # This can happen if when using gr.load() and the file is on a remote Space + # but the file is not the `value` of the component. For example, if the file + # is the `avatar_image` of the `Chatbot` component. In this case, we skip + # copying the file to the cache and just use the remote file path. + return url_or_file_path + else: + temp_file_path = url_or_file_path + self.temp_files.add(temp_file_path) + + return temp_file_path + def move_resource_to_block_cache( self, url_or_file_path: str | Path | None ) -> str | None: @@ -254,8 +294,8 @@ def move_resource_to_block_cache( to to the block's temp_files, and returns the path to the file in cache. This ensures that the file is accessible to the Block and can be served to users. - Note: this method is not used in any core Gradio components, but is kept here - for backwards compatibility with custom components created with gradio<=4.20.0. + This sync version of the function is used when this is being called outside of + a FastAPI route, e.g. when examples are being cached. """ if url_or_file_path is None: return None @@ -311,7 +351,9 @@ def serve_static_file( else: data = {"path": url_or_file_path} try: - return processing_utils.move_files_to_cache(data, self) + return client_utils.synchronize_async( + processing_utils.async_move_files_to_cache, data, self + ) except AttributeError: # Can be raised if this function is called before the Block is fully initialized. 
return data @@ -1418,7 +1460,7 @@ def validate_inputs(self, fn_index: int, inputs: list[Any]): [{received}]""" ) - def preprocess_data( + async def preprocess_data( self, fn_index: int, inputs: list[Any], @@ -1449,7 +1491,7 @@ def preprocess_data( else: if input_id in state: block = state[input_id] - inputs_cached = processing_utils.move_files_to_cache( + inputs_cached = await processing_utils.async_move_files_to_cache( inputs[i], block, check_in_upload_folder=not explicit_call, @@ -1500,7 +1542,7 @@ def validate_outputs(self, fn_index: int, predictions: Any | list[Any]): [{received}]""" ) - def postprocess_data( + async def postprocess_data( self, fn_index: int, predictions: list | dict, state: SessionState | None ): state = state or SessionState(self) @@ -1580,7 +1622,7 @@ def postprocess_data( block = state[output_id] prediction_value = block.postprocess(prediction_value) - outputs_cached = processing_utils.move_files_to_cache( + outputs_cached = await processing_utils.async_move_files_to_cache( prediction_value, block, postprocess=True, @@ -1589,7 +1631,7 @@ def postprocess_data( return output - def handle_streaming_outputs( + async def handle_streaming_outputs( self, fn_index: int, data: list, @@ -1613,7 +1655,7 @@ def handle_streaming_outputs( if first_chunk: stream_run[output_id] = [] self.pending_streams[session_hash][run][output_id].append(binary_data) - output_data = processing_utils.move_files_to_cache( + output_data = await processing_utils.async_move_files_to_cache( output_data, block, postprocess=True, @@ -1711,7 +1753,7 @@ async def process_api( f"Batch size ({batch_size}) exceeds the max_batch_size for this function ({max_batch_size})" ) inputs = [ - self.preprocess_data(fn_index, list(i), state, explicit_call) + await self.preprocess_data(fn_index, list(i), state, explicit_call) for i in zip(*inputs) ] result = await self.call_function( @@ -1725,7 +1767,8 @@ async def process_api( ) preds = result["prediction"] data = [ - self.postprocess_data(fn_index, list(o), state) for o in zip(*preds) + await self.postprocess_data(fn_index, list(o), state) + for o in zip(*preds) ] if root_path is not None: data = processing_utils.add_root_url(data, root_path, None) @@ -1736,7 +1779,9 @@ async def process_api( if old_iterator: inputs = [] else: - inputs = self.preprocess_data(fn_index, inputs, state, explicit_call) + inputs = await self.preprocess_data( + fn_index, inputs, state, explicit_call + ) was_generating = old_iterator is not None result = await self.call_function( fn_index, @@ -1747,13 +1792,13 @@ async def process_api( event_data, in_event_listener, ) - data = self.postprocess_data(fn_index, result["prediction"], state) + data = await self.postprocess_data(fn_index, result["prediction"], state) if root_path is not None: data = processing_utils.add_root_url(data, root_path, None) is_generating, iterator = result["is_generating"], result["iterator"] if is_generating or was_generating: run = id(old_iterator) if was_generating else id(iterator) - data = self.handle_streaming_outputs( + data = await self.handle_streaming_outputs( fn_index, data, session_hash=session_hash, diff --git a/gradio/helpers.py b/gradio/helpers.py index e4baf807dea1..2bc97be2b741 100644 --- a/gradio/helpers.py +++ b/gradio/helpers.py @@ -305,7 +305,7 @@ async def load_example(example_id): if not self._defer_caching: self._start_caching() - def _postprocess_output(self, output) -> list: + async def _postprocess_output(self, output) -> list: """ This is a way that we can postprocess the data manually, since 
we set postprocess=False in the lazy_cache event handler. The reason we did that is because we don't want to postprocess data if we are loading from @@ -318,7 +318,7 @@ def _postprocess_output(self, output) -> list: [output.render() for output in self.outputs] demo.load(self.fn, self.inputs, self.outputs) demo.unrender() - return demo.postprocess_data(0, output, None) + return await demo.postprocess_data(0, output, None) def _get_cached_index_if_cached(self, example_index) -> int | None: if Path(self.cached_indices_file).exists(): @@ -342,7 +342,7 @@ def _start_caching(self): ) break if self.cache_examples == "lazy": - self.lazy_cache() + client_utils.synchronize_async(self.lazy_cache) if self.cache_examples is True: if wasm_utils.IS_WASM: # In the Wasm mode, the `threading` module is not supported, @@ -356,7 +356,7 @@ def _start_caching(self): else: client_utils.synchronize_async(self.cache) - def lazy_cache(self) -> None: + async def lazy_cache(self) -> None: print( f"Will cache examples in '{utils.abspath(self.cached_folder)}' directory at first use. ", end="", @@ -393,7 +393,7 @@ async def async_lazy_cache(self, example_index, *input_values): else: fn = utils.async_fn_to_generator(self.fn) async for output in fn(*input_values): - output = self._postprocess_output(output) + output = await self._postprocess_output(output) yield output[0] if len(self.outputs) == 1 else output self.cache_logger.flag(output) with open(self.cached_indices_file, "a") as f: @@ -411,7 +411,7 @@ def sync_lazy_cache(self, example_index, *input_values): else: fn = utils.sync_fn_to_generator(self.fn) for output in fn(*input_values): - output = self._postprocess_output(output) + output = client_utils.synchronize_async(self._postprocess_output, output) yield output[0] if len(self.outputs) == 1 else output self.cache_logger.flag(output) with open(self.cached_indices_file, "a") as f: diff --git a/gradio/processing_utils.py b/gradio/processing_utils.py index 833759325738..1cf0e90adfac 100644 --- a/gradio/processing_utils.py +++ b/gradio/processing_utils.py @@ -13,6 +13,7 @@ from pathlib import Path from typing import TYPE_CHECKING, Any +import aiofiles import httpx import numpy as np from gradio_client import utils as client_utils @@ -26,6 +27,8 @@ warnings.simplefilter("ignore") # Ignore pydub warning if ffmpeg is not installed from pydub import AudioSegment +async_client = httpx.AsyncClient() + log = logging.getLogger(__name__) if TYPE_CHECKING: @@ -214,6 +217,24 @@ def save_url_to_cache(url: str, cache_dir: str) -> str: return full_temp_file_path +async def async_save_url_to_cache(url: str, cache_dir: str) -> str: + """Downloads a file and makes a temporary file path for a copy if does not already + exist. Otherwise returns the path to the existing temp file. 
Uses async httpx."""
+    temp_dir = hash_url(url)
+    temp_dir = Path(cache_dir) / temp_dir
+    temp_dir.mkdir(exist_ok=True, parents=True)
+    name = client_utils.strip_invalid_filename_characters(Path(url).name)
+    full_temp_file_path = str(abspath(temp_dir / name))
+
+    if not Path(full_temp_file_path).exists():
+        async with async_client.stream("GET", url, follow_redirects=True) as response:
+            async with aiofiles.open(full_temp_file_path, "wb") as f:
+                async for chunk in response.aiter_raw():
+                    await f.write(chunk)
+
+    return full_temp_file_path
+
+
 def save_base64_to_cache(
     base64_encoding: str, cache_dir: str, file_name: str | None = None
 ) -> str:
@@ -319,6 +340,78 @@ def _move_to_cache(d: dict):
     return client_utils.traverse(data, _move_to_cache, client_utils.is_file_obj)
 
+
+async def async_move_files_to_cache(
+    data: Any,
+    block: Block,
+    postprocess: bool = False,
+    check_in_upload_folder=False,
+    keep_in_cache=False,
+) -> dict:
+    """Moves any files in `data` to cache and (optionally) adds URL prefixes (/file=...) needed to access the cached file.
+    Also handles the case where the file is on an external Gradio app (/proxy=...).
+
+    Runs after .postprocess() and before .preprocess().
+
+    Args:
+        data: The input or output data for a component. Can be a dictionary or a dataclass
+        block: The component whose data is being processed
+        postprocess: Whether this is being called as part of postprocessing
+        check_in_upload_folder: If True, instead of moving the file to cache, checks that the file is already in the upload folder (raises an exception if not).
+        keep_in_cache: If True, the file will not be deleted from cache when the server is shut down.
+    """
+
+    async def _move_to_cache(d: dict):
+        payload = FileData(**d)
+        # If the gradio app developer is returning a URL from
+        # postprocess, it means the component can display a URL
+        # without it being served from the gradio server
+        # This makes it so that the URL is not downloaded and speeds up event processing
+        if payload.url and postprocess and client_utils.is_http_url_like(payload.url):
+            payload.path = payload.url
+        elif utils.is_static_file(payload):
+            pass
+        elif not block.proxy_url:
+            # If the file is on a remote server, do not move it to cache.
+            if check_in_upload_folder and not client_utils.is_http_url_like(
+                payload.path
+            ):
+                path = os.path.abspath(payload.path)
+                if not is_in_or_equal(path, get_upload_folder()):
+                    raise ValueError(
+                        f"File {path} is not in the upload folder and cannot be accessed."
+ ) + if not payload.is_stream: + temp_file_path = await block.async_move_resource_to_block_cache( + payload.path + ) + if temp_file_path is None: + raise ValueError("Did not determine a file path for the resource.") + payload.path = temp_file_path + if keep_in_cache: + block.keep_in_cache.add(payload.path) + + url_prefix = "/stream/" if payload.is_stream else "/file=" + if block.proxy_url: + proxy_url = block.proxy_url.rstrip("/") + url = f"/proxy={proxy_url}{url_prefix}{payload.path}" + elif client_utils.is_http_url_like(payload.path) or payload.path.startswith( + f"{url_prefix}" + ): + url = payload.path + else: + url = f"{url_prefix}{payload.path}" + payload.url = url + + return payload.model_dump() + + if isinstance(data, (GradioRootModel, GradioModel)): + data = data.model_dump() + + return await client_utils.async_traverse( + data, _move_to_cache, client_utils.is_file_obj + ) + + def add_root_url(data: dict | list, root_url: str, previous_root_url: str | None): def _add_root_url(file_dict: dict): if previous_root_url and file_dict["url"].startswith(previous_root_url): diff --git a/guides/08_gradio-clients-and-lite/01_getting-started-with-the-python-client.md b/guides/08_gradio-clients-and-lite/01_getting-started-with-the-python-client.md index f568b09fe0cc..1790e290a266 100644 --- a/guides/08_gradio-clients-and-lite/01_getting-started-with-the-python-client.md +++ b/guides/08_gradio-clients-and-lite/01_getting-started-with-the-python-client.md @@ -104,14 +104,17 @@ Named API endpoints: 1 - [Textbox] output: str ``` +We see that we have 1 API endpoint in this space, and shows us how to use the API endpoint to make a prediction: we should call the `.predict()` method (which we will explore below), providing a parameter `input_audio` of type `str`, which is a `filepath or URL`. -Alternatively, you can click on the "Use via API" link in the footer of the Gradio app, which shows us the same information, along with example usage. +We should also provide the `api_name='/predict'` argument to the `predict()` method. Although this isn't necessary if a Gradio app has only 1 named endpoint, it does allow us to call different endpoints in a single app if they are available. -![](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/gradio-guides/view-api.png) +## The "View API" Page -We see that we have 1 API endpoint in this space, and shows us how to use the API endpoint to make a prediction: we should call the `.predict()` method (which we will explore below), providing a parameter `input_audio` of type `str`, which is a `filepath or URL`. +As an alternative to running the `.view_api()` method, you can click on the "Use via API" link in the footer of the Gradio app, which shows us the same information, along with example usage. -We should also provide the `api_name='/predict'` argument to the `predict()` method. Although this isn't necessary if a Gradio app has only 1 named endpoint, it does allow us to call different endpoints in a single app if they are available. +![](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/gradio-guides/view-api.png) + +The View API page also includes an "API Recorder" that lets you interact with the Gradio UI normally and converts your interactions into the corresponding code to run with the Python Client. 
## Making a prediction diff --git a/guides/08_gradio-clients-and-lite/02_getting-started-with-the-js-client.md b/guides/08_gradio-clients-and-lite/02_getting-started-with-the-js-client.md index fbd6b52d90a1..b9c2cb14f656 100644 --- a/guides/08_gradio-clients-and-lite/02_getting-started-with-the-js-client.md +++ b/guides/08_gradio-clients-and-lite/02_getting-started-with-the-js-client.md @@ -147,6 +147,15 @@ This shows us that we have 1 API endpoint in this space, and shows us how to use We should also provide the `api_name='/predict'` argument to the `predict()` method. Although this isn't necessary if a Gradio app has only 1 named endpoint, it does allow us to call different endpoints in a single app if they are available. If an app has unnamed API endpoints, these can also be displayed by running `.view_api(all_endpoints=True)`. +## The "View API" Page + +As an alternative to running the `.view_api()` method, you can click on the "Use via API" link in the footer of the Gradio app, which shows us the same information, along with example usage. + +![](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/gradio-guides/view-api.png) + +The View API page also includes an "API Recorder" that lets you interact with the Gradio UI normally and converts your interactions into the corresponding code to run with the JS Client. + + ## Making a prediction The simplest way to make a prediction is simply to call the `.predict()` method with the appropriate arguments: diff --git a/js/app/src/Blocks.svelte b/js/app/src/Blocks.svelte index bd9400a55b1e..d32cd218f386 100644 --- a/js/app/src/Blocks.svelte +++ b/js/app/src/Blocks.svelte @@ -8,7 +8,7 @@ import type { ComponentMeta, Dependency, LayoutNode } from "./types"; import type { UpdateTransaction } from "./init"; import { setupi18n } from "./i18n"; - import { ApiDocs } from "./api_docs/"; + import { ApiDocs, ApiRecorder } from "./api_docs/"; import type { ThemeMode, Payload } from "./types"; import { Toast } from "@gradio/statustracker"; import type { ToastMessage } from "@gradio/statustracker"; @@ -68,7 +68,9 @@ let params = new URLSearchParams(window.location.search); let api_docs_visible = params.get("view") === "api" && show_api; + let api_recorder_visible = params.get("view") === "api-recorder" && show_api; function set_api_docs_visible(visible: boolean): void { + api_recorder_visible = false; api_docs_visible = visible; let params = new URLSearchParams(window.location.search); if (visible) { @@ -78,6 +80,7 @@ } history.replaceState(null, "", "?" + params.toString()); } + let api_calls: Payload[] = []; export let render_complete = false; async function handle_update(data: any, fn_index: number): Promise { @@ -251,6 +254,9 @@ } async function make_prediction(payload: Payload): Promise { + if (api_recorder_visible) { + api_calls = [...api_calls, payload]; + } const submission = app .submit( payload.fn_index, @@ -543,6 +549,21 @@ {/if} +{#if api_recorder_visible} + + + +
{ + set_api_docs_visible(true); + api_recorder_visible = false; + }} + > + +
+{/if} + {#if api_docs_visible && $_layout}
@@ -557,13 +578,16 @@
{ + on:close={(event) => { set_api_docs_visible(false); + api_calls = []; + api_recorder_visible = event.detail.api_recorder_visible; }} {dependencies} {root} {app} {space_id} + {api_calls} />
@@ -666,4 +690,11 @@ width: 1150px; } } + + #api-recorder-container { + position: fixed; + left: 10px; + bottom: 10px; + z-index: 1000; + } diff --git a/js/app/src/api_docs/ApiBanner.svelte b/js/app/src/api_docs/ApiBanner.svelte index 62d867b3f24e..9507ea678dc2 100644 --- a/js/app/src/api_docs/ApiBanner.svelte +++ b/js/app/src/api_docs/ApiBanner.svelte @@ -2,6 +2,7 @@ import { createEventDispatcher } from "svelte"; import api_logo from "./img/api-logo.svg"; import Clear from "./img/clear.svelte"; + import Button from "../../../button/shared/Button.svelte"; export let root: string; export let api_count: number; @@ -18,7 +19,15 @@ - {api_count} API endpoint{#if api_count > 1}s{/if} + {api_count} API endpoint{#if api_count > 1}s{/if}
+
diff --git a/js/app/src/api_docs/ApiDocs.svelte b/js/app/src/api_docs/ApiDocs.svelte index 68e8b8e6ea9d..458160910abe 100644 --- a/js/app/src/api_docs/ApiDocs.svelte +++ b/js/app/src/api_docs/ApiDocs.svelte @@ -5,13 +5,15 @@ import { post_data } from "@gradio/client"; import NoApi from "./NoApi.svelte"; import type { client } from "@gradio/client"; - + import type { Payload } from "../types"; import { represent_value } from "./utils"; import ApiBanner from "./ApiBanner.svelte"; + import Button from "../../../button/shared/Button.svelte"; import ParametersSnippet from "./ParametersSnippet.svelte"; import InstallSnippet from "./InstallSnippet.svelte"; import CodeSnippet from "./CodeSnippet.svelte"; + import RecordingSnippet from "./RecordingSnippet.svelte"; import python from "./img/python.svg"; import javascript from "./img/javascript.svg"; @@ -39,6 +41,7 @@ root += "/"; } + export let api_calls: Payload[] = []; let current_language: "python" | "javascript" = "python"; const langs = [ @@ -162,6 +165,8 @@ } } + const dispatch = createEventDispatcher(); + onMount(() => { document.body.style.overflow = "hidden"; if ("parentIFrame" in window) { @@ -178,14 +183,15 @@ +

Use the gradio_client - Python library (docs) or the + Python library or the @gradio/client - Javascript package (docs) to - query the app via API. + Javascript package to query the app + via API.

@@ -193,7 +199,7 @@ {#each langs as [language, img]}
  • (current_language = language)} > @@ -201,25 +207,70 @@
  • {/each}
    - -

    - 1. Install the client if you don't already have it installed. -

    - - - -

    - 2. Find the API endpoint below corresponding to your desired function - in the app. Copy the code snippet, replacing the placeholder values - with your own input data. - {#if space_id}If this is a private Space, you may need to pass your - Hugging Face token as well (read more).{/if} Run the code, that's it! -

    + {#if api_calls.length} +
    +

    + 🪄 Recorded API Calls ({api_calls.length}) +

    +

    + Here is the code snippet to replay the most recently recorded API + calls using the {current_language} + client. +

    + + +

    + Note: the above list may include extra API calls that affect the + UI, but are not necessary for the clients. +

    +
    +

    + API Documentation +

    + {:else} +

    + 1. Install the client if you don't already have it installed. +

    + + + +

    + 2. Find the API endpoint below corresponding to your desired + function in the app. Copy the code snippet, replacing the + placeholder values with your own input data. + {#if space_id}If this is a private Space, you may need to pass your + Hugging Face token as well (read more).{/if} Or + + to automatically generate your API requests. + + +

    + {/if} {#each dependencies as dependency, dependency_index} {#if dependency.show_api} @@ -306,9 +357,10 @@ border: 1px solid var(--border-color-accent); border-radius: var(--radius-sm); background: var(--color-accent-soft); - padding: var(--size-1); + padding: 0px var(--size-1); color: var(--color-accent); font-size: var(--text-md); + text-decoration: none; } .snippets { @@ -378,4 +430,12 @@ padding: 15px 0px; font-size: var(--text-lg); } + + #api-recorder { + border: 1px solid var(--color-accent); + background-color: var(--color-accent-soft); + padding: 0px var(--size-2); + border-radius: var(--size-1); + cursor: pointer; + } diff --git a/js/app/src/api_docs/ApiRecorder.svelte b/js/app/src/api_docs/ApiRecorder.svelte new file mode 100644 index 000000000000..a2778ed2bb62 --- /dev/null +++ b/js/app/src/api_docs/ApiRecorder.svelte @@ -0,0 +1,33 @@ + + +
    + 🟠 Recording API Calls... + ({api_calls.length}) {#if api_calls.length > 0} + | /{dependencies[api_calls[api_calls.length - 1].fn_index].api_name} + {/if} +
    + + diff --git a/js/app/src/api_docs/CodeSnippet.svelte b/js/app/src/api_docs/CodeSnippet.svelte index 6ee63bac0425..4525266b6a4a 100644 --- a/js/app/src/api_docs/CodeSnippet.svelte +++ b/js/app/src/api_docs/CodeSnippet.svelte @@ -59,7 +59,7 @@ result = client.predict {parameter_name ? parameter_name + "=" - : ""}{represent_value( parameter_has_default ? parameter_default : example_input, python_type.type, @@ -86,7 +86,9 @@ const example{component} = await response_{i}.blob(); {/each} const app = await client("{root}"); -const result = await app.predict({#if named}"/{dependency.api_name}"{:else}{dependency_index}{/if}, [{#each endpoint_parameters as { label, type, python_type, component, example_input, serializer }, i}{#if blob_components.includes(component)} @@ -168,9 +170,6 @@ console.log(result.data); color: var(--body-text-color-subdued); } - .example-inputs { - color: var(--color-accent); - } .api-name { color: var(--color-accent); } diff --git a/js/app/src/api_docs/RecordingSnippet.svelte b/js/app/src/api_docs/RecordingSnippet.svelte new file mode 100644 index 000000000000..62024d313bf6 --- /dev/null +++ b/js/app/src/api_docs/RecordingSnippet.svelte @@ -0,0 +1,118 @@ + + +
    + + + + {#if current_language === "python"} +
    + +
    +
    +
    from gradio_client import Client, file
    +
    +client = Client("{root}")
    +{#each api_calls as call}
    +client.predict(
    +{format_api_call(call)}  api_name="/{dependencies[call.fn_index].api_name}"
    +)
    +{/each}
    +
    + {:else if current_language === "javascript"} +
    + +
    +
    +
    import { client } from "@gradio/client";
    +
    +const app = await client("{root}");
    +{#each api_calls as call}
+{#if dependencies[call.fn_index].backend_fn}app.predict("/{dependencies[call.fn_index].api_name}", {JSON.stringify(call.data, null, 2)});{/if}
    +						{/each}
    +
    {/if} +
    +
    +
    + + diff --git a/js/app/src/api_docs/index.ts b/js/app/src/api_docs/index.ts index fe09cac462ff..9d60f596e1ef 100644 --- a/js/app/src/api_docs/index.ts +++ b/js/app/src/api_docs/index.ts @@ -1 +1,2 @@ export { default as ApiDocs } from "./ApiDocs.svelte"; +export { default as ApiRecorder } from "./ApiRecorder.svelte"; diff --git a/js/app/src/api_docs/utils.ts b/js/app/src/api_docs/utils.ts index fe809a23d9e2..32e31698dbc9 100644 --- a/js/app/src/api_docs/utils.ts +++ b/js/app/src/api_docs/utils.ts @@ -94,7 +94,10 @@ function replace_file_data_with_file_function(obj: any): any { } function stringify_except_file_function(obj: any): string { - const jsonString = JSON.stringify(obj, (key, value) => { + let jsonString = JSON.stringify(obj, (key, value) => { + if (value === null) { + return "UNQUOTEDNone"; + } if ( typeof value === "string" && value.startsWith("file(") && @@ -105,5 +108,7 @@ function stringify_except_file_function(obj: any): string { return value; }); const regex = /"UNQUOTEDfile\(([^)]*)\)"/g; - return jsonString.replace(regex, (match, p1) => `file(${p1})`); + jsonString = jsonString.replace(regex, (match, p1) => `file(${p1})`); + const regexNone = /"UNQUOTEDNone"/g; + return jsonString.replace(regexNone, "None"); } diff --git a/test/test_blocks.py b/test/test_blocks.py index 742063572d6e..01418d2caff7 100644 --- a/test/test_blocks.py +++ b/test/test_blocks.py @@ -496,7 +496,8 @@ def test_get_load_events(self, io_components): class TestBlocksPostprocessing: - def test_blocks_do_not_filter_none_values_from_updates(self, io_components): + @pytest.mark.asyncio + async def test_blocks_do_not_filter_none_values_from_updates(self, io_components): io_components = [ c() for c in io_components @@ -522,7 +523,7 @@ def test_blocks_do_not_filter_none_values_from_updates(self, io_components): outputs=io_components, ) - output = demo.postprocess_data( + output = await demo.postprocess_data( 0, [gr.update(value=None) for _ in io_components], state=None ) @@ -536,7 +537,8 @@ def process_and_dump(component): o["value"] == process_and_dump(c) for o, c in zip(output, io_components) ) - def test_blocks_does_not_replace_keyword_literal(self): + @pytest.mark.asyncio + async def test_blocks_does_not_replace_keyword_literal(self): with gr.Blocks() as demo: text = gr.Textbox() btn = gr.Button(value="Reset") @@ -546,10 +548,11 @@ def test_blocks_does_not_replace_keyword_literal(self): outputs=text, ) - output = demo.postprocess_data(0, gr.update(value="NO_VALUE"), state=None) + output = await demo.postprocess_data(0, gr.update(value="NO_VALUE"), state=None) assert output[0]["value"] == "NO_VALUE" - def test_blocks_does_not_del_dict_keys_inplace(self): + @pytest.mark.asyncio + async def test_blocks_does_not_del_dict_keys_inplace(self): with gr.Blocks() as demo: im_list = [gr.Image() for i in range(2)] @@ -559,13 +562,16 @@ def change_visibility(value): checkbox = gr.Checkbox(value=True, label="Show image") checkbox.change(change_visibility, inputs=checkbox, outputs=im_list) - output = demo.postprocess_data(0, [gr.update(visible=False)] * 2, state=None) + output = await demo.postprocess_data( + 0, [gr.update(visible=False)] * 2, state=None + ) assert output == [ {"visible": False, "__type__": "update"}, {"visible": False, "__type__": "update"}, ] - def test_blocks_returns_correct_output_dict_single_key(self): + @pytest.mark.asyncio + async def test_blocks_returns_correct_output_dict_single_key(self): with gr.Blocks() as demo: num = gr.Number() num2 = gr.Number() @@ -576,10 +582,10 @@ def 
update_values(val): update.click(update_values, inputs=[num], outputs=[num2]) - output = demo.postprocess_data(0, {num2: gr.Number(value=42)}, state=None) + output = await demo.postprocess_data(0, {num2: gr.Number(value=42)}, state=None) assert output[0]["value"] == 42 - output = demo.postprocess_data(0, {num2: 23}, state=None) + output = await demo.postprocess_data(0, {num2: 23}, state=None) assert output[0] == 23 @pytest.mark.asyncio @@ -645,7 +651,8 @@ def generic_update(): } assert output["data"][1] == {"__type__": "update", "interactive": True} - def test_error_raised_if_num_outputs_mismatch(self): + @pytest.mark.asyncio + async def test_error_raised_if_num_outputs_mismatch(self): with gr.Blocks() as demo: textbox1 = gr.Textbox() textbox2 = gr.Textbox() @@ -655,9 +662,10 @@ def test_error_raised_if_num_outputs_mismatch(self): ValueError, match=r"^An event handler didn\'t receive enough output values \(needed: 2, received: 1\)\.\nWanted outputs:", ): - demo.postprocess_data(fn_index=0, predictions=["test"], state=None) + await demo.postprocess_data(fn_index=0, predictions=["test"], state=None) - def test_error_raised_if_num_outputs_mismatch_with_function_name(self): + @pytest.mark.asyncio + async def test_error_raised_if_num_outputs_mismatch_with_function_name(self): def infer(x): return x @@ -670,9 +678,10 @@ def infer(x): ValueError, match=r"^An event handler \(infer\) didn\'t receive enough output values \(needed: 2, received: 1\)\.\nWanted outputs:", ): - demo.postprocess_data(fn_index=0, predictions=["test"], state=None) + await demo.postprocess_data(fn_index=0, predictions=["test"], state=None) - def test_error_raised_if_num_outputs_mismatch_single_output(self): + @pytest.mark.asyncio + async def test_error_raised_if_num_outputs_mismatch_single_output(self): with gr.Blocks() as demo: num1 = gr.Number() num2 = gr.Number() @@ -682,9 +691,10 @@ def test_error_raised_if_num_outputs_mismatch_single_output(self): ValueError, match=r"^An event handler didn\'t receive enough output values \(needed: 2, received: 1\)\.\nWanted outputs:", ): - demo.postprocess_data(fn_index=0, predictions=1, state=None) + await demo.postprocess_data(fn_index=0, predictions=1, state=None) - def test_error_raised_if_num_outputs_mismatch_tuple_output(self): + @pytest.mark.asyncio + async def test_error_raised_if_num_outputs_mismatch_tuple_output(self): def infer(a, b): return a, b @@ -698,7 +708,7 @@ def infer(a, b): ValueError, match=r"^An event handler \(infer\) didn\'t receive enough output values \(needed: 3, received: 2\)\.\nWanted outputs:", ): - demo.postprocess_data(fn_index=0, predictions=(1, 2), state=None) + await demo.postprocess_data(fn_index=0, predictions=(1, 2), state=None) class TestStateHolder: @@ -1663,7 +1673,8 @@ def test_fn(x): assert demo.dependencies[0]["api_name"] == "test_fn" -def test_blocks_postprocessing_with_copies_of_component_instance(): +@pytest.mark.asyncio +async def test_blocks_postprocessing_with_copies_of_component_instance(): # Test for: https://github.com/gradio-app/gradio/issues/6608 with gr.Blocks() as demo: chatbot = gr.Chatbot() @@ -1679,7 +1690,7 @@ def clear_func(): ) assert ( - demo.postprocess_data(0, [gr.Chatbot(value=[])] * 3, None) + await demo.postprocess_data(0, [gr.Chatbot(value=[])] * 3, None) == [{"value": [], "__type__": "update"}] * 3 ) diff --git a/test/test_components.py b/test/test_components.py index 257e4044fe9a..a29b2b27de47 100644 --- a/test/test_components.py +++ b/test/test_components.py @@ -830,7 +830,8 @@ def test_plot_format_parameter(self): 
class TestAudio: - def test_component_functions(self, gradio_temp_dir): + @pytest.mark.asyncio + async def test_component_functions(self, gradio_temp_dir): """ Preprocess, postprocess serialize, get_config, deserialize type: filepath, numpy, file @@ -841,7 +842,8 @@ def test_component_functions(self, gradio_temp_dir): assert output1[0] == 8000 assert output1[1].shape == (8046,) - x_wav = processing_utils.move_files_to_cache([x_wav], audio_input)[0] + x_wav = await processing_utils.async_move_files_to_cache([x_wav], audio_input) + x_wav = x_wav[0] audio_input = gr.Audio(type="filepath") output1 = audio_input.preprocess(x_wav) assert Path(output1).name.endswith("audio_sample.wav") @@ -1489,7 +1491,8 @@ def test_postprocessing(self): class TestVideo: - def test_component_functions(self): + @pytest.mark.asyncio + async def test_component_functions(self): """ Preprocess, serialize, deserialize, get_config """ @@ -1498,7 +1501,10 @@ def test_component_functions(self): ) video_input = gr.Video() - x_video = processing_utils.move_files_to_cache([x_video], video_input)[0] + x_video = await processing_utils.async_move_files_to_cache( + [x_video], video_input + ) + x_video = x_video[0] output1 = video_input.preprocess(x_video) assert isinstance(output1, str) @@ -3052,7 +3058,8 @@ def test_component_example_payloads(io_components): else: c: Component = component() data = c.example_payload() - data = processing_utils.move_files_to_cache( + data = client_utils.synchronize_async( + processing_utils.async_move_files_to_cache, data, c, check_in_upload_folder=False,