Skip to content

Commit

Permalink
Fixes gr.load() so it works properly with Images and Examples (#6322)
Browse files Browse the repository at this point in the history
* changes

* image

* fixes

* examples

* add changeset

* changes

* revert

* add changeset

* fix

* fix test

* modify workflow

* fix

---------

Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
Co-authored-by: Hannah <hannahblair@users.noreply.github.com>
  • Loading branch information
3 people committed Nov 9, 2023
1 parent 506ab9e commit 6204cca
Show file tree
Hide file tree
Showing 10 changed files with 75 additions and 87 deletions.
6 changes: 6 additions & 0 deletions .changeset/fresh-months-bake.md
@@ -0,0 +1,6 @@
---
"@gradio/image": patch
"gradio": patch
---

fix:Fixes `gr.load()` so it works properly with Images and Examples
16 changes: 8 additions & 8 deletions .github/workflows/backend.yml
Expand Up @@ -4,7 +4,7 @@ on:
push:
branches:
- "main"
pull_request:
pull_request:

concurrency:
group: backend-${{ github.ref }}-${{ github.event_name == 'push' || github.event.inputs.fire != null }}
Expand Down Expand Up @@ -107,8 +107,8 @@ jobs:
if: runner.os == 'Linux'
run: |
. venv/bin/activate
python -m pip install -e client/python
python -m pip install -e .
python -m pip install client/python
python -m pip install ".[oauth]"
- name: Lint (Linux)
if: runner.os == 'Linux'
run: |
Expand All @@ -128,8 +128,8 @@ jobs:
if: runner.os == 'Windows'
run: |
venv\Scripts\activate
python -m pip install -e client/python
python -m pip install -e .
python -m pip install client/python
python -m pip install ".[oauth]"
- name: Tests (Windows)
if: runner.os == 'Windows'
run: |
Expand Down Expand Up @@ -193,7 +193,7 @@ jobs:
if: runner.os == 'Linux'
run: |
. venv/bin/activate
python -m pip install -e client/python
python -m pip install client/python
python -m pip install ".[oauth]"
- name: Install Test Dependencies (Linux)
if: steps.cache.outputs.cache-hit != 'true' && runner.os == 'Linux'
Expand Down Expand Up @@ -221,13 +221,13 @@ jobs:
if: runner.os == 'Windows'
run: |
venv\Scripts\activate
python -m pip install -e client/python
python -m pip install client/python
python -m pip install ".[oauth]"
- name: Install Test Dependencies (Windows)
if: steps.cache.outputs.cache-hit != 'true' && runner.os == 'Windows'
run: |
venv\Scripts\activate
python -m pip install -e . -r test/requirements.txt
python -m pip install . -r test/requirements.txt
- name: Run tests (Windows)
if: runner.os == 'Windows'
run: |
Expand Down
97 changes: 44 additions & 53 deletions client/python/test/test_client.py
Expand Up @@ -810,63 +810,54 @@ def test_private_space(self):
"unnamed_endpoints": {},
}

@pytest.mark.flaky
def test_fetch_fixed_version_space(self, calculator_demo):
with connect(calculator_demo) as client:
assert client.view_api(return_format="dict") == {
"named_endpoints": {
"/predict": {
"parameters": [
{
"label": "num1",
"type": {"type": "number"},
"python_type": {
"type": "float",
"description": "",
},
"component": "Number",
"example_input": 3,
},
{
"label": "operation",
"type": {
"enum": ["add", "subtract", "multiply", "divide"],
"title": "Radio",
"type": "string",
},
"python_type": {
"type": "Literal[add, subtract, multiply, divide]",
"description": "",
},
"component": "Radio",
"example_input": "add",
},
{
"label": "num2",
"type": {"type": "number"},
"python_type": {
"type": "float",
"description": "",
},
"component": "Number",
"example_input": 3,
},
],
"returns": [
{
"label": "output",
"type": {"type": "number"},
"python_type": {
"type": "float",
"description": "",
},
"component": "Number",
}
],
api_info = client.view_api(return_format="dict")
assert isinstance(api_info, dict)
assert api_info["named_endpoints"]["/predict"] == {
"parameters": [
{
"label": "num1",
"type": {"type": "number"},
"python_type": {"type": "float", "description": ""},
"component": "Number",
"example_input": 3,
},
{
"label": "operation",
"type": {
"enum": ["add", "subtract", "multiply", "divide"],
"title": "Radio",
"type": "string",
},
"python_type": {
"type": "Literal[add, subtract, multiply, divide]",
"description": "",
},
"component": "Radio",
"example_input": "add",
},
{
"label": "num2",
"type": {"type": "number"},
"python_type": {"type": "float", "description": ""},
"component": "Number",
"example_input": 3,
},
],
"returns": [
{
"label": "output",
"type": {"type": "number"},
"python_type": {"type": "float", "description": ""},
"component": "Number",
}
},
"unnamed_endpoints": {},
],
}
assert (
"/load_example" in api_info["named_endpoints"]
) # The exact api configuration includes Block IDs and thus is not deterministic
assert api_info["unnamed_endpoints"] == {}

def test_unnamed_endpoints_use_fn_index(self, count_generator_demo):
with connect(count_generator_demo) as client:
Expand Down
7 changes: 3 additions & 4 deletions gradio/blocks.py
Expand Up @@ -103,7 +103,6 @@ def __init__(
render: bool = True,
visible: bool = True,
proxy_url: str | None = None,
_skip_init_processing: bool = False,
):
self._id = Context.id
Context.id += 1
Expand All @@ -114,7 +113,6 @@ def __init__(
)
self.proxy_url = proxy_url
self.share_token = secrets.token_urlsafe(32)
self._skip_init_processing = _skip_init_processing
self.parent: BlockContext | None = None
self.is_rendered: bool = False
self._constructor_args: dict
Expand Down Expand Up @@ -631,9 +629,12 @@ def get_block_instance(id: int) -> Block:
# URL of C, not B. The else clause below handles this case.
if block_config["props"].get("proxy_url") is None:
block_config["props"]["proxy_url"] = f"{proxy_url}/"
postprocessed_value = block_config["props"].pop("value", None)

constructor_args = cls.recover_kwargs(block_config["props"])
block = cls(**constructor_args)
if postprocessed_value is not None:
block.value = postprocessed_value # type: ignore

block_proxy_url = block_config["props"]["proxy_url"]
block.proxy_url = block_proxy_url
Expand All @@ -642,8 +643,6 @@ def get_block_instance(id: int) -> Block:
_selectable := block_config["props"].pop("_selectable", None)
) is not None:
block._selectable = _selectable # type: ignore
# Any component has already processed its initial value, so we skip that step here
block._skip_init_processing = True

return block

Expand Down
7 changes: 1 addition & 6 deletions gradio/components/base.py
Expand Up @@ -206,11 +206,7 @@ def __init__(
self.load_event: None | dict[str, Any] = None
self.load_event_to_attach: None | tuple[Callable, float | None] = None
load_fn, initial_value = self.get_load_fn_and_initial_value(value)
initial_value = (
initial_value
if self._skip_init_processing
else self.postprocess(initial_value)
)
initial_value = self.postprocess(initial_value)
self.value = move_files_to_cache(initial_value, self, postprocess=True) # type: ignore

if callable(load_fn):
Expand All @@ -227,7 +223,6 @@ def get_config(self):
config["info"] = self.info
if len(self.server_fns):
config["server_fns"] = [fn.__name__ for fn in self.server_fns]
config.pop("_skip_init_processing", None)
config.pop("render", None)
return config

Expand Down
7 changes: 3 additions & 4 deletions gradio/components/dataset.py
Expand Up @@ -103,9 +103,8 @@ def __init__(
self.headers = [c.label or "" for c in self._components]
self.samples_per_page = samples_per_page

@property
def skip_api(self):
return True
def api_info(self) -> dict[str, str]:
return {"type": "integer", "description": "index of selected example"}

def get_config(self):
config = super().get_config()
Expand Down Expand Up @@ -134,4 +133,4 @@ def postprocess(self, samples: list[list]) -> dict:
}

def example_inputs(self) -> Any:
return None
return 0
2 changes: 1 addition & 1 deletion gradio/helpers.py
Expand Up @@ -52,7 +52,7 @@ def create_examples(
run_on_click: bool = False,
preprocess: bool = True,
postprocess: bool = True,
api_name: str | None | Literal[False] = False,
api_name: str | None | Literal[False] = None,
batch: bool = False,
):
"""Top-level synchronous function that creates Examples. Provided for backwards compatibility, i.e. so that gr.Examples(...) can be used to create the Examples component."""
Expand Down
4 changes: 3 additions & 1 deletion gradio/image_utils.py
Expand Up @@ -47,6 +47,8 @@ def save_image(y: np.ndarray | _Image.Image | str | Path, cache_dir: str):
elif isinstance(y, str):
path = y
else:
raise ValueError("Cannot process this value as an Image")
raise ValueError(
"Cannot process this value as an Image, it is of type: " + str(type(y))
)

return path
13 changes: 6 additions & 7 deletions js/image/Index.svelte
Expand Up @@ -17,15 +17,18 @@
import { StatusTracker } from "@gradio/statustracker";
import type { FileData } from "@gradio/client";
import type { LoadingStatus } from "@gradio/statustracker";
import { normalise_file } from "@gradio/client";
export let elem_id = "";
export let elem_classes: string[] = [];
export let visible = true;
export let value: null | FileData = null;
$: _value = normalise_file(value, root, proxy_url);
export let label: string;
export let show_label: boolean;
export let show_download_button: boolean;
export let root: string;
export let proxy_url: null | string;
export let height: number | undefined;
export let width: number | undefined;
Expand Down Expand Up @@ -60,9 +63,6 @@
$: value?.url && gradio.dispatch("change");
let dragging: boolean;
$: value = !value ? null : value;
let active_tool: null | "webcam" = null;
</script>

Expand Down Expand Up @@ -90,8 +90,7 @@
on:select={({ detail }) => gradio.dispatch("select", detail)}
on:share={({ detail }) => gradio.dispatch("share", detail)}
on:error={({ detail }) => gradio.dispatch("error", detail)}
{root}
{value}
value={_value}
{label}
{show_label}
{show_download_button}
Expand All @@ -103,7 +102,7 @@
{:else}
<Block
{visible}
variant={value === null ? "dashed" : "solid"}
variant={_value === null ? "dashed" : "solid"}
border_mode={dragging ? "focus" : "base"}
padding={false}
{elem_id}
Expand All @@ -123,7 +122,7 @@

<ImageUploader
bind:active_tool
bind:value
bind:value={_value}
selectable={_selectable}
{root}
{sources}
Expand Down
3 changes: 0 additions & 3 deletions js/image/shared/ImagePreview.svelte
Expand Up @@ -16,16 +16,13 @@
export let show_download_button = true;
export let selectable = false;
export let show_share_button = false;
export let root: string;
export let i18n: I18nFormatter;
const dispatch = createEventDispatcher<{
change: string;
select: SelectData;
}>();
$: value = normalise_file(value, root, null);
const handle_click = (evt: MouseEvent): void => {
let coordinates = get_coordinates_of_clicked_image(evt);
if (coordinates) {
Expand Down

0 comments on commit 6204cca

Please sign in to comment.