Skip to content

Commit

Permalink
pass props to example components and to example outputs (#6014)
Browse files Browse the repository at this point in the history
* pass props to example components and to example outputs

* add changeset

* make util less egenric/ more useful

* fix demo

* fix demo

* fix

* fix test

* Fix test

---------

Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
Co-authored-by: freddyaboulton <alfonsoboulton@gmail.com>
  • Loading branch information
3 people committed Oct 19, 2023
1 parent 9cf40f7 commit cad537a
Show file tree
Hide file tree
Showing 10 changed files with 94 additions and 22 deletions.
6 changes: 6 additions & 0 deletions .changeset/green-olives-shake.md
@@ -0,0 +1,6 @@
---
"@gradio/dataset": minor
"gradio": minor
---

feat:pass props to example components and to example outputs
2 changes: 1 addition & 1 deletion demo/examples_component/run.ipynb
@@ -1 +1 @@
{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: examples_component"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["# Downloading files from the demo repo\n", "import os\n", "os.mkdir('images')\n", "!wget -q -O images/cheetah1.jpg https://github.com/gradio-app/gradio/raw/main/demo/examples_component/images/cheetah1.jpg\n", "!wget -q -O images/lion.jpg https://github.com/gradio-app/gradio/raw/main/demo/examples_component/images/lion.jpg\n", "!wget -q -O images/lion.webp https://github.com/gradio-app/gradio/raw/main/demo/examples_component/images/lion.webp\n", "!wget -q -O images/logo.png https://github.com/gradio-app/gradio/raw/main/demo/examples_component/images/logo.png"]}, {"cell_type": "code", "execution_count": null, "id": "44380577570523278879349135829904343037", "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "import os\n", "\n", "def flip(i):\n", " return i.rotate(180)\n", "\n", "with gr.Blocks() as demo:\n", " with gr.Row():\n", " with gr.Column():\n", " img_i = gr.Image(label=\"Input Image\", type=\"pil\")\n", " with gr.Column():\n", " img_o = gr.Image(label=\"Output Image\")\n", " with gr.Row():\n", " btn = gr.Button(value=\"Flip Image\")\n", " btn.click(flip, inputs=[img_i], outputs=[img_o])\n", "\n", " gr.Examples(\n", " [ \n", " os.path.join(os.path.abspath(''), \"images/cheetah1.jpg\"),\n", " os.path.join(os.path.abspath(''), \"images/lion.jpg\"),\n", " ],\n", " img_i,\n", " img_o,\n", " flip\n", " )\n", " \n", "demo.launch()"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}
{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: examples_component"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["# Downloading files from the demo repo\n", "import os\n", "os.mkdir('images')\n", "!wget -q -O images/cheetah1.jpg https://github.com/gradio-app/gradio/raw/main/demo/examples_component/images/cheetah1.jpg\n", "!wget -q -O images/lion.jpg https://github.com/gradio-app/gradio/raw/main/demo/examples_component/images/lion.jpg\n", "!wget -q -O images/lion.webp https://github.com/gradio-app/gradio/raw/main/demo/examples_component/images/lion.webp\n", "!wget -q -O images/logo.png https://github.com/gradio-app/gradio/raw/main/demo/examples_component/images/logo.png"]}, {"cell_type": "code", "execution_count": null, "id": "44380577570523278879349135829904343037", "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "import os\n", "\n", "\n", "def flip(i):\n", " return i.rotate(180)\n", "\n", "\n", "with gr.Blocks() as demo:\n", " with gr.Row():\n", " with gr.Column():\n", " img_i = gr.Image(label=\"Input Image\", type=\"pil\")\n", " with gr.Column():\n", " img_o = gr.Image(label=\"Output Image\")\n", " with gr.Row():\n", " btn = gr.Button(value=\"Flip Image\")\n", " btn.click(flip, inputs=[img_i], outputs=[img_o])\n", "\n", " gr.Examples(\n", " [\n", " os.path.join(os.path.abspath(''), \"images/cheetah1.jpg\"),\n", " os.path.join(os.path.abspath(''), \"images/lion.jpg\"),\n", " ],\n", " img_i,\n", " img_o,\n", " flip,\n", " )\n", "\n", "demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}
10 changes: 6 additions & 4 deletions demo/examples_component/run.py
@@ -1,9 +1,11 @@
import gradio as gr
import os


def flip(i):
return i.rotate(180)


with gr.Blocks() as demo:
with gr.Row():
with gr.Column():
Expand All @@ -15,13 +17,13 @@ def flip(i):
btn.click(flip, inputs=[img_i], outputs=[img_o])

gr.Examples(
[
[
os.path.join(os.path.dirname(__file__), "images/cheetah1.jpg"),
os.path.join(os.path.dirname(__file__), "images/lion.jpg"),
],
img_i,
img_o,
flip
flip,
)
demo.launch()

demo.launch()
7 changes: 2 additions & 5 deletions gradio/blocks.py
Expand Up @@ -588,11 +588,8 @@ def get_block_instance(id: int) -> Block:
else:
raise ValueError(f"Cannot find block with id {id}")
cls = component_or_layout_class(block_config["type"])
block_config["props"].pop("type", None)
block_config["props"].pop("name", None)
block_config["props"].pop("selectable", None)
block_config["props"].pop("streamable", None)
block_config["props"].pop("server_fns", None)

block_config["props"] = utils.recover_kwargs(block_config["props"])

# If a Gradio app B is loaded into a Gradio app A, and B itself loads a
# Gradio app C, then the root_urls of the components in A need to be the
Expand Down
22 changes: 18 additions & 4 deletions gradio/components/dataset.py
Expand Up @@ -6,6 +6,7 @@

from gradio_client.documentation import document, set_documentation_group

import gradio.utils as utils
from gradio.components.base import (
Component,
get_component_instance,
Expand Down Expand Up @@ -66,6 +67,13 @@ def __init__(
self.scale = scale
self.min_width = min_width
self._components = [get_component_instance(c) for c in components]
self.component_props = [
utils.recover_kwargs(
component.get_config(),
["value"],
)
for component in self._components
]

# Narrow type to Component
assert all(
Expand Down Expand Up @@ -95,10 +103,16 @@ def skip_api(self):

def get_config(self):
config = super().get_config()
config["components"] = [
component.get_block_name() for component in self._components
]
config["component_ids"] = [component._id for component in self._components]

config["components"] = []
config["component_props"] = self.component_props
config["component_ids"] = []

for component in self._components:
config["components"].append(component.get_block_name())

config["component_ids"].append(component._id)

return config

def preprocess(self, x: Any) -> Any:
Expand Down
11 changes: 10 additions & 1 deletion gradio/helpers.py
Expand Up @@ -259,7 +259,16 @@ def create(self) -> None:

async def load_example(example_id):
processed_example = self.non_none_processed_examples[example_id]
return utils.resolve_singleton(processed_example)
examples = utils.resolve_singleton(processed_example)

return (
update(value=examples, **self.dataset.component_props[0])
if not isinstance(examples, list)
else [
update(value=ex, **self.dataset.component_props[i])
for i, ex in enumerate(examples)
]
)

if Context.root_block:
self.load_input_event = self.dataset.click(
Expand Down
9 changes: 9 additions & 0 deletions gradio/utils.py
Expand Up @@ -950,3 +950,12 @@ def find_user_stack_level() -> int:
frame = frame.f_back
n += 1
return n


def recover_kwargs(config: dict, additional_keys_to_ignore: list[str] | None = None):
not_kwargs = ["type", "name", "selectable", "server_fns", "streamable"]
return {
k: v
for k, v in config.items()
if k not in not_kwargs and k not in (additional_keys_to_ignore or [])
}
3 changes: 3 additions & 0 deletions js/dataset/Index.svelte
Expand Up @@ -4,6 +4,7 @@
import type { Gradio, SelectData } from "@gradio/utils";
import { get_fetchable_url_or_file } from "@gradio/upload";
export let components: string[];
export let component_props: Record<string, any>[];
export let component_map: Map<
string,
Promise<{
Expand Down Expand Up @@ -141,6 +142,7 @@
{#if component_meta.length && component_map.get(components[0])}
<svelte:component
this={component_meta[0][0].component}
{...component_props[0]}
value={sample_row[0]}
{samples_dir}
type="gallery"
Expand Down Expand Up @@ -185,6 +187,7 @@
>
<svelte:component
this={component}
{...component_props[j]}
{value}
{samples_dir}
type="table"
Expand Down
12 changes: 7 additions & 5 deletions test/test_components.py
Expand Up @@ -2178,11 +2178,13 @@ def test_as_example_returns_file_basename(component):
assert component.as_example(None) == ""


@patch("gradio.components.Component.as_example")
@patch("gradio.components.Image.as_example")
@patch("gradio.components.File.as_example")
@patch("gradio.components.Dataframe.as_example")
@patch("gradio.components.Model3D.as_example")
@patch(
"gradio.components.Component.as_example", spec=gr.components.Component.as_example
)
@patch("gradio.components.Image.as_example", spec=gr.Image.as_example)
@patch("gradio.components.File.as_example", spec=gr.File.as_example)
@patch("gradio.components.Dataframe.as_example", spec=gr.DataFrame.as_example)
@patch("gradio.components.Model3D.as_example", spec=gr.Model3D.as_example)
def test_dataset_calls_as_example(*mocks):
gr.Dataset(
components=[gr.Dataframe(), gr.File(), gr.Image(), gr.Model3D(), gr.Textbox()],
Expand Down
34 changes: 32 additions & 2 deletions test/test_helpers.py
Expand Up @@ -481,10 +481,40 @@ def concatenate(str1, str2):
client = TestClient(app)

response = client.post("/api/load_example/", json={"data": [0]})
assert response.json()["data"] == ["Hello,"]
assert response.json()["data"] == [
{
"lines": 1,
"max_lines": 20,
"show_label": True,
"container": True,
"min_width": 160,
"autofocus": False,
"autoscroll": True,
"rtl": False,
"show_copy_button": False,
"__type__": "update",
"visible": True,
"value": "Hello,",
}
]

response = client.post("/api/load_example/", json={"data": [1]})
assert response.json()["data"] == ["Michael"]
assert response.json()["data"] == [
{
"lines": 1,
"max_lines": 20,
"show_label": True,
"container": True,
"min_width": 160,
"autofocus": False,
"autoscroll": True,
"rtl": False,
"show_copy_button": False,
"__type__": "update",
"visible": True,
"value": "Michael",
}
]

def test_end_to_end_cache_examples(self):
def concatenate(str1, str2):
Expand Down

0 comments on commit cad537a

Please sign in to comment.