From 6204ccac5967763e0ebde550d04d12584243a120 Mon Sep 17 00:00:00 2001 From: Abubakar Abid Date: Thu, 9 Nov 2023 08:34:32 -0800 Subject: [PATCH] Fixes `gr.load()` so it works properly with Images and Examples (#6322) * changes * image * fixes * examples * add changeset * changes * revert * add changeset * fix * fix test * modify workflow * fix --------- Co-authored-by: gradio-pr-bot Co-authored-by: Hannah --- .changeset/fresh-months-bake.md | 6 ++ .github/workflows/backend.yml | 16 ++--- client/python/test/test_client.py | 97 +++++++++++++---------------- gradio/blocks.py | 7 +-- gradio/components/base.py | 7 +-- gradio/components/dataset.py | 7 +-- gradio/helpers.py | 2 +- gradio/image_utils.py | 4 +- js/image/Index.svelte | 13 ++-- js/image/shared/ImagePreview.svelte | 3 - 10 files changed, 75 insertions(+), 87 deletions(-) create mode 100644 .changeset/fresh-months-bake.md diff --git a/.changeset/fresh-months-bake.md b/.changeset/fresh-months-bake.md new file mode 100644 index 000000000000..6778451b4923 --- /dev/null +++ b/.changeset/fresh-months-bake.md @@ -0,0 +1,6 @@ +--- +"@gradio/image": patch +"gradio": patch +--- + +fix:Fixes `gr.load()` so it works properly with Images and Examples diff --git a/.github/workflows/backend.yml b/.github/workflows/backend.yml index 2ef85a64160f..4090d4106c2c 100644 --- a/.github/workflows/backend.yml +++ b/.github/workflows/backend.yml @@ -4,7 +4,7 @@ on: push: branches: - "main" - pull_request: + pull_request: concurrency: group: backend-${{ github.ref }}-${{ github.event_name == 'push' || github.event.inputs.fire != null }} @@ -107,8 +107,8 @@ jobs: if: runner.os == 'Linux' run: | . venv/bin/activate - python -m pip install -e client/python - python -m pip install -e . + python -m pip install client/python + python -m pip install ".[oauth]" - name: Lint (Linux) if: runner.os == 'Linux' run: | @@ -128,8 +128,8 @@ jobs: if: runner.os == 'Windows' run: | venv\Scripts\activate - python -m pip install -e client/python - python -m pip install -e . + python -m pip install client/python + python -m pip install ".[oauth]" - name: Tests (Windows) if: runner.os == 'Windows' run: | @@ -193,7 +193,7 @@ jobs: if: runner.os == 'Linux' run: | . venv/bin/activate - python -m pip install -e client/python + python -m pip install client/python python -m pip install ".[oauth]" - name: Install Test Dependencies (Linux) if: steps.cache.outputs.cache-hit != 'true' && runner.os == 'Linux' @@ -221,13 +221,13 @@ jobs: if: runner.os == 'Windows' run: | venv\Scripts\activate - python -m pip install -e client/python + python -m pip install client/python python -m pip install ".[oauth]" - name: Install Test Dependencies (Windows) if: steps.cache.outputs.cache-hit != 'true' && runner.os == 'Windows' run: | venv\Scripts\activate - python -m pip install -e . -r test/requirements.txt + python -m pip install . -r test/requirements.txt - name: Run tests (Windows) if: runner.os == 'Windows' run: | diff --git a/client/python/test/test_client.py b/client/python/test/test_client.py index da5e7aeea0a4..894a6c654881 100644 --- a/client/python/test/test_client.py +++ b/client/python/test/test_client.py @@ -810,63 +810,54 @@ def test_private_space(self): "unnamed_endpoints": {}, } - @pytest.mark.flaky def test_fetch_fixed_version_space(self, calculator_demo): with connect(calculator_demo) as client: - assert client.view_api(return_format="dict") == { - "named_endpoints": { - "/predict": { - "parameters": [ - { - "label": "num1", - "type": {"type": "number"}, - "python_type": { - "type": "float", - "description": "", - }, - "component": "Number", - "example_input": 3, - }, - { - "label": "operation", - "type": { - "enum": ["add", "subtract", "multiply", "divide"], - "title": "Radio", - "type": "string", - }, - "python_type": { - "type": "Literal[add, subtract, multiply, divide]", - "description": "", - }, - "component": "Radio", - "example_input": "add", - }, - { - "label": "num2", - "type": {"type": "number"}, - "python_type": { - "type": "float", - "description": "", - }, - "component": "Number", - "example_input": 3, - }, - ], - "returns": [ - { - "label": "output", - "type": {"type": "number"}, - "python_type": { - "type": "float", - "description": "", - }, - "component": "Number", - } - ], + api_info = client.view_api(return_format="dict") + assert isinstance(api_info, dict) + assert api_info["named_endpoints"]["/predict"] == { + "parameters": [ + { + "label": "num1", + "type": {"type": "number"}, + "python_type": {"type": "float", "description": ""}, + "component": "Number", + "example_input": 3, + }, + { + "label": "operation", + "type": { + "enum": ["add", "subtract", "multiply", "divide"], + "title": "Radio", + "type": "string", + }, + "python_type": { + "type": "Literal[add, subtract, multiply, divide]", + "description": "", + }, + "component": "Radio", + "example_input": "add", + }, + { + "label": "num2", + "type": {"type": "number"}, + "python_type": {"type": "float", "description": ""}, + "component": "Number", + "example_input": 3, + }, + ], + "returns": [ + { + "label": "output", + "type": {"type": "number"}, + "python_type": {"type": "float", "description": ""}, + "component": "Number", } - }, - "unnamed_endpoints": {}, + ], } + assert ( + "/load_example" in api_info["named_endpoints"] + ) # The exact api configuration includes Block IDs and thus is not deterministic + assert api_info["unnamed_endpoints"] == {} def test_unnamed_endpoints_use_fn_index(self, count_generator_demo): with connect(count_generator_demo) as client: diff --git a/gradio/blocks.py b/gradio/blocks.py index 9a2c3efde92b..ecfce3b8019b 100644 --- a/gradio/blocks.py +++ b/gradio/blocks.py @@ -103,7 +103,6 @@ def __init__( render: bool = True, visible: bool = True, proxy_url: str | None = None, - _skip_init_processing: bool = False, ): self._id = Context.id Context.id += 1 @@ -114,7 +113,6 @@ def __init__( ) self.proxy_url = proxy_url self.share_token = secrets.token_urlsafe(32) - self._skip_init_processing = _skip_init_processing self.parent: BlockContext | None = None self.is_rendered: bool = False self._constructor_args: dict @@ -631,9 +629,12 @@ def get_block_instance(id: int) -> Block: # URL of C, not B. The else clause below handles this case. if block_config["props"].get("proxy_url") is None: block_config["props"]["proxy_url"] = f"{proxy_url}/" + postprocessed_value = block_config["props"].pop("value", None) constructor_args = cls.recover_kwargs(block_config["props"]) block = cls(**constructor_args) + if postprocessed_value is not None: + block.value = postprocessed_value # type: ignore block_proxy_url = block_config["props"]["proxy_url"] block.proxy_url = block_proxy_url @@ -642,8 +643,6 @@ def get_block_instance(id: int) -> Block: _selectable := block_config["props"].pop("_selectable", None) ) is not None: block._selectable = _selectable # type: ignore - # Any component has already processed its initial value, so we skip that step here - block._skip_init_processing = True return block diff --git a/gradio/components/base.py b/gradio/components/base.py index 00c3a83ee647..e3ee2f7c40c7 100644 --- a/gradio/components/base.py +++ b/gradio/components/base.py @@ -206,11 +206,7 @@ def __init__( self.load_event: None | dict[str, Any] = None self.load_event_to_attach: None | tuple[Callable, float | None] = None load_fn, initial_value = self.get_load_fn_and_initial_value(value) - initial_value = ( - initial_value - if self._skip_init_processing - else self.postprocess(initial_value) - ) + initial_value = self.postprocess(initial_value) self.value = move_files_to_cache(initial_value, self, postprocess=True) # type: ignore if callable(load_fn): @@ -227,7 +223,6 @@ def get_config(self): config["info"] = self.info if len(self.server_fns): config["server_fns"] = [fn.__name__ for fn in self.server_fns] - config.pop("_skip_init_processing", None) config.pop("render", None) return config diff --git a/gradio/components/dataset.py b/gradio/components/dataset.py index 53a6c4341332..01a671a0acf2 100644 --- a/gradio/components/dataset.py +++ b/gradio/components/dataset.py @@ -103,9 +103,8 @@ def __init__( self.headers = [c.label or "" for c in self._components] self.samples_per_page = samples_per_page - @property - def skip_api(self): - return True + def api_info(self) -> dict[str, str]: + return {"type": "integer", "description": "index of selected example"} def get_config(self): config = super().get_config() @@ -134,4 +133,4 @@ def postprocess(self, samples: list[list]) -> dict: } def example_inputs(self) -> Any: - return None + return 0 diff --git a/gradio/helpers.py b/gradio/helpers.py index 512333941700..d70c8dda18b0 100644 --- a/gradio/helpers.py +++ b/gradio/helpers.py @@ -52,7 +52,7 @@ def create_examples( run_on_click: bool = False, preprocess: bool = True, postprocess: bool = True, - api_name: str | None | Literal[False] = False, + api_name: str | None | Literal[False] = None, batch: bool = False, ): """Top-level synchronous function that creates Examples. Provided for backwards compatibility, i.e. so that gr.Examples(...) can be used to create the Examples component.""" diff --git a/gradio/image_utils.py b/gradio/image_utils.py index 6806ec9d8cdb..154a72e06555 100644 --- a/gradio/image_utils.py +++ b/gradio/image_utils.py @@ -47,6 +47,8 @@ def save_image(y: np.ndarray | _Image.Image | str | Path, cache_dir: str): elif isinstance(y, str): path = y else: - raise ValueError("Cannot process this value as an Image") + raise ValueError( + "Cannot process this value as an Image, it is of type: " + str(type(y)) + ) return path diff --git a/js/image/Index.svelte b/js/image/Index.svelte index a1b0c89b9c94..46da2a4f0def 100644 --- a/js/image/Index.svelte +++ b/js/image/Index.svelte @@ -17,15 +17,18 @@ import { StatusTracker } from "@gradio/statustracker"; import type { FileData } from "@gradio/client"; import type { LoadingStatus } from "@gradio/statustracker"; + import { normalise_file } from "@gradio/client"; export let elem_id = ""; export let elem_classes: string[] = []; export let visible = true; export let value: null | FileData = null; + $: _value = normalise_file(value, root, proxy_url); export let label: string; export let show_label: boolean; export let show_download_button: boolean; export let root: string; + export let proxy_url: null | string; export let height: number | undefined; export let width: number | undefined; @@ -60,9 +63,6 @@ $: value?.url && gradio.dispatch("change"); let dragging: boolean; - - $: value = !value ? null : value; - let active_tool: null | "webcam" = null; @@ -90,8 +90,7 @@ on:select={({ detail }) => gradio.dispatch("select", detail)} on:share={({ detail }) => gradio.dispatch("share", detail)} on:error={({ detail }) => gradio.dispatch("error", detail)} - {root} - {value} + value={_value} {label} {show_label} {show_download_button} @@ -103,7 +102,7 @@ {:else} (); - $: value = normalise_file(value, root, null); - const handle_click = (evt: MouseEvent): void => { let coordinates = get_coordinates_of_clicked_image(evt); if (coordinates) {