diff --git a/.changeset/itchy-doors-work.md b/.changeset/itchy-doors-work.md new file mode 100644 index 000000000000..86941c540a87 --- /dev/null +++ b/.changeset/itchy-doors-work.md @@ -0,0 +1,7 @@ +--- +"@gradio/app": patch +"@gradio/client": patch +"gradio": patch +--- + +fix:Fix event target diff --git a/client/js/src/client.ts b/client/js/src/client.ts index 2e01a279a9bf..7ff732df7893 100644 --- a/client/js/src/client.ts +++ b/client/js/src/client.ts @@ -47,7 +47,8 @@ type client_return = { submit: ( endpoint: string | number, data?: unknown[], - event_data?: unknown + event_data?: unknown, + trigger_id?: number | null ) => SubmitReturn; component_server: ( component_id: number, @@ -412,7 +413,8 @@ export function api_factory( function submit( endpoint: string | number, data: unknown[], - event_data?: unknown + event_data?: unknown, + trigger_id: number | null = null ): SubmitReturn { let fn_index: number; let api_info; @@ -453,7 +455,7 @@ export function api_factory( api_info, hf_token ).then((_payload) => { - payload = { data: _payload || [], event_data, fn_index }; + payload = { data: _payload || [], event_data, fn_index, trigger_id }; if (skip_queue(fn_index, config)) { fire_event({ type: "status", diff --git a/demo/on_listener_basic/run.ipynb b/demo/on_listener_basic/run.ipynb index d3cc658c5bb7..8b9fa173ba17 100644 --- a/demo/on_listener_basic/run.ipynb +++ b/demo/on_listener_basic/run.ipynb @@ -1 +1 @@ -{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: on_listener_basic"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "\n", "with gr.Blocks() as demo:\n", " name = gr.Textbox(label=\"Name\")\n", " output = gr.Textbox(label=\"Output Box\")\n", " greet_btn = gr.Button(\"Greet\")\n", "\n", " def greet(name):\n", " return \"Hello \" + name + \"!\"\n", "\n", " gr.on(\n", " triggers=[name.submit, greet_btn.click],\n", " fn=greet,\n", " inputs=name,\n", " outputs=output,\n", " )\n", "\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5} \ No newline at end of file +{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: on_listener_basic"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "\n", "with gr.Blocks() as demo:\n", " name = gr.Textbox(label=\"Name\")\n", " output = gr.Textbox(label=\"Output Box\")\n", " greet_btn = gr.Button(\"Greet\")\n", " trigger = gr.Textbox(label=\"Trigger Box\")\n", " trigger2 = gr.Textbox(label=\"Trigger Box\")\n", "\n", " def greet(name, evt_data: gr.EventData):\n", " return \"Hello \" + name + \"!\", evt_data.target.__class__.__name__\n", " \n", " def clear_name(evt_data: gr.EventData):\n", " return \"\", evt_data.target.__class__.__name__\n", " \n", " gr.on(\n", " triggers=[name.submit, greet_btn.click],\n", " fn=greet,\n", " inputs=name,\n", " outputs=[output, trigger],\n", " ).then(clear_name, outputs=[name, trigger2])\n", "\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5} \ No newline at end of file diff --git a/demo/on_listener_basic/run.py b/demo/on_listener_basic/run.py index 929858661122..214b1cc7dc72 100644 --- a/demo/on_listener_basic/run.py +++ b/demo/on_listener_basic/run.py @@ -4,16 +4,21 @@ name = gr.Textbox(label="Name") output = gr.Textbox(label="Output Box") greet_btn = gr.Button("Greet") + trigger = gr.Textbox(label="Trigger Box") + trigger2 = gr.Textbox(label="Trigger Box") - def greet(name): - return "Hello " + name + "!" - + def greet(name, evt_data: gr.EventData): + return "Hello " + name + "!", evt_data.target.__class__.__name__ + + def clear_name(evt_data: gr.EventData): + return "", evt_data.target.__class__.__name__ + gr.on( triggers=[name.submit, greet_btn.click], fn=greet, inputs=name, - outputs=output, - ) + outputs=[output, trigger], + ).then(clear_name, outputs=[name, trigger2]) if __name__ == "__main__": diff --git a/demo/on_listener_test/run.ipynb b/demo/on_listener_test/run.ipynb new file mode 100644 index 000000000000..a7258cdc437e --- /dev/null +++ b/demo/on_listener_test/run.ipynb @@ -0,0 +1 @@ +{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: on_listener_test"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "\n", "with gr.Blocks() as demo:\n", " name = gr.Textbox(label=\"Name\")\n", " output = gr.Textbox(label=\"Output\")\n", " greet_btn = gr.Button(\"Greet\")\n", " trigger = gr.Textbox(label=\"Trigger 1\")\n", " trigger2 = gr.Textbox(label=\"Trigger 2\")\n", "\n", " def greet(name, evt_data: gr.EventData):\n", " return \"Hello \" + name + \"!\", evt_data.target.__class__.__name__\n", " \n", " def clear_name(evt_data: gr.EventData):\n", " return \"\", evt_data.target.__class__.__name__\n", " \n", " gr.on(\n", " triggers=[name.submit, greet_btn.click],\n", " fn=greet,\n", " inputs=name,\n", " outputs=[output, trigger],\n", " ).then(clear_name, outputs=[name, trigger2])\n", "\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5} \ No newline at end of file diff --git a/demo/on_listener_test/run.py b/demo/on_listener_test/run.py new file mode 100644 index 000000000000..d5ae73c8de2f --- /dev/null +++ b/demo/on_listener_test/run.py @@ -0,0 +1,25 @@ +import gradio as gr + +with gr.Blocks() as demo: + name = gr.Textbox(label="Name") + output = gr.Textbox(label="Output") + greet_btn = gr.Button("Greet") + trigger = gr.Textbox(label="Trigger 1") + trigger2 = gr.Textbox(label="Trigger 2") + + def greet(name, evt_data: gr.EventData): + return "Hello " + name + "!", evt_data.target.__class__.__name__ + + def clear_name(evt_data: gr.EventData): + return "", evt_data.target.__class__.__name__ + + gr.on( + triggers=[name.submit, greet_btn.click], + fn=greet, + inputs=name, + outputs=[output, trigger], + ).then(clear_name, outputs=[name, trigger2]) + + +if __name__ == "__main__": + demo.launch() diff --git a/gradio/data_classes.py b/gradio/data_classes.py index f46386940e08..86b89408eb88 100644 --- a/gradio/data_classes.py +++ b/gradio/data_classes.py @@ -24,6 +24,7 @@ class Config: data: List[Any] event_data: Optional[Any] = None fn_index: Optional[int] = None + trigger_id: Optional[int] = None batched: Optional[ bool ] = False # Whether the data is a batch of samples (i.e. called from the queue if batch=True) or a single sample (i.e. called from the UI) diff --git a/gradio/route_utils.py b/gradio/route_utils.py index be05ec66d164..8e8726dedd74 100644 --- a/gradio/route_utils.py +++ b/gradio/route_utils.py @@ -199,12 +199,10 @@ def restore_session_state(app: App, body: PredictBody): def prepare_event_data( blocks: Blocks, body: PredictBody, - fn_index_inferred: int, ) -> EventData: - dependency = blocks.dependencies[fn_index_inferred] - target = dependency["targets"][0] if len(dependency["targets"]) else None + target = body.trigger_id event_data = EventData( - blocks.blocks.get(target[0]) if target else None, + blocks.blocks.get(target) if target else None, body.event_data, ) return event_data @@ -219,7 +217,7 @@ async def call_process_api( session_state, iterator = restore_session_state(app=app, body=body) dependency = app.get_blocks().dependencies[fn_index_inferred] - event_data = prepare_event_data(app.get_blocks(), body, fn_index_inferred) + event_data = prepare_event_data(app.get_blocks(), body) event_id = body.event_id session_hash = getattr(body, "session_hash", None) diff --git a/js/app/src/Blocks.svelte b/js/app/src/Blocks.svelte index e2b94c8cfe68..5918e429d5b3 100644 --- a/js/app/src/Blocks.svelte +++ b/js/app/src/Blocks.svelte @@ -376,6 +376,7 @@ async function trigger_api_call( dep_index: number, + trigger_id: number | null = null, event_data: unknown = null ): Promise { let dep = dependencies[dep_index]; @@ -397,7 +398,8 @@ let payload: Payload = { fn_index: dep_index, data: dep.inputs.map((id) => instance_map[id].props.value), - event_data: dep.collects_event_data ? event_data : null + event_data: dep.collects_event_data ? event_data : null, + trigger_id: trigger_id }; if (dep.frontend_fn) { @@ -435,7 +437,12 @@ const pending_outputs: number[] = []; let outputs_set_to_non_interactive: number[] = []; const submission = app - .submit(payload.fn_index, payload.data as unknown[], payload.event_data) + .submit( + payload.fn_index, + payload.data as unknown[], + payload.event_data, + payload.trigger_id + ) .on("data", ({ data, fn_index }) => { if (dep.pending_request && dep.final_event) { dep.pending_request = false; @@ -499,7 +506,7 @@ if (status.stage === "complete") { dependencies.map(async (dep, i) => { if (dep.trigger_after === fn_index) { - trigger_api_call(i); + trigger_api_call(i, payload.trigger_id); } }); @@ -512,7 +519,7 @@ ...messages ]; }, 0); - trigger_api_call(dep_index, event_data); + trigger_api_call(dep_index, payload.trigger_id, event_data); user_left_page = false; } else if (status.stage === "error") { if (status.message) { @@ -530,7 +537,7 @@ dep.trigger_after === fn_index && !dep.trigger_only_on_success ) { - trigger_api_call(i); + trigger_api_call(i, payload.trigger_id); } }); @@ -627,7 +634,7 @@ } else { const deps = target_map[id]?.[event]; deps?.forEach((dep_id) => { - trigger_api_call(dep_id, data); + trigger_api_call(dep_id, id, data); }); } }); diff --git a/js/app/src/types.ts b/js/app/src/types.ts index 4cf4850ce3a8..b3ea294b686b 100644 --- a/js/app/src/types.ts +++ b/js/app/src/types.ts @@ -29,6 +29,7 @@ export interface Payload { fn_index: number; data: unknown[]; event_data: unknown | null; + trigger_id: number | null; } export interface Dependency { diff --git a/js/app/test/on_listener_test.spec.ts b/js/app/test/on_listener_test.spec.ts new file mode 100644 index 000000000000..5d2fd1a77675 --- /dev/null +++ b/js/app/test/on_listener_test.spec.ts @@ -0,0 +1,22 @@ +import { test, expect } from "@gradio/tootils"; + +test("On listener works.", async ({ page }) => { + const name_box = await page.locator("textarea").nth(0); + const output_box = await page.locator("textarea").nth(1); + const trigger1_box = await page.locator("textarea").nth(2); + const trigger2_box = await page.locator("textarea").nth(3); + + name_box.fill("Jimmy"); + await page.click("text=Greet"); + await expect(output_box).toHaveValue("Hello Jimmy!"); + await expect(trigger1_box).toHaveValue("Button"); + await expect(name_box).toHaveValue(""); + await expect(trigger2_box).toHaveValue("Button"); + + await name_box.fill("Sally"); + await name_box.press("Enter"); + await expect(output_box).toHaveValue("Hello Sally!"); + await expect(trigger1_box).toHaveValue("Textbox"); + await expect(name_box).toHaveValue(""); + await expect(trigger2_box).toHaveValue("Textbox"); +}); diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index bf7d548cc307..7fc6145e75c9 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -1,4 +1,4 @@ -lockfileVersion: '6.0' +lockfileVersion: '6.1' settings: autoInstallPeers: true @@ -6737,7 +6737,7 @@ packages: svelte: 4.2.2 tiny-glob: 0.2.9 undici: 5.26.4 - vite: 4.5.0(@types/node@20.3.2)(less@4.1.3) + vite: 4.5.0(@types/node@20.3.1)(less@4.1.3)(lightningcss@1.21.7)(sass@1.66.1)(stylus@0.60.0)(sugarss@4.0.1) transitivePeerDependencies: - supports-color @@ -16208,7 +16208,7 @@ packages: mlly: 1.4.0 pathe: 1.1.1 picocolors: 1.0.0 - vite: 4.5.0(@types/node@20.3.2)(less@4.1.3) + vite: 4.5.0(@types/node@20.3.1)(less@4.1.3)(lightningcss@1.21.7)(sass@1.66.1)(stylus@0.60.0)(sugarss@4.0.1) transitivePeerDependencies: - '@types/node' - less @@ -16260,42 +16260,6 @@ packages: optionalDependencies: fsevents: 2.3.2 - /vite@4.5.0(@types/node@20.3.2)(less@4.1.3): - resolution: {integrity: sha512-ulr8rNLA6rkyFAlVWw2q5YJ91v098AFQ2R0PRFwPzREXOUJQPtFUG0t+/ZikhaOCDqFoDhN6/v8Sq0o4araFAw==} - engines: {node: ^14.18.0 || >=16.0.0} - hasBin: true - peerDependencies: - '@types/node': '>= 14' - less: '*' - lightningcss: ^1.21.0 - sass: '*' - stylus: '*' - sugarss: '*' - terser: ^5.4.0 - peerDependenciesMeta: - '@types/node': - optional: true - less: - optional: true - lightningcss: - optional: true - sass: - optional: true - stylus: - optional: true - sugarss: - optional: true - terser: - optional: true - dependencies: - '@types/node': 20.3.2 - esbuild: 0.18.20 - less: 4.1.3 - postcss: 8.4.27 - rollup: 3.29.0 - optionalDependencies: - fsevents: 2.3.2 - /vitefu@0.2.4(vite@4.5.0): resolution: {integrity: sha512-fanAXjSaf9xXtOOeno8wZXIhgia+CZury481LsDaV++lSvcU2R9Ch2bPh3PYFyoHW+w9LqAeYRISVQjUIew14g==} peerDependencies: @@ -16360,7 +16324,7 @@ packages: strip-literal: 1.0.1 tinybench: 2.5.0 tinypool: 0.7.0 - vite: 4.5.0(@types/node@20.3.2)(less@4.1.3) + vite: 4.5.0(@types/node@20.3.1)(less@4.1.3)(lightningcss@1.21.7)(sass@1.66.1)(stylus@0.60.0)(sugarss@4.0.1) vite-node: 0.34.6(@types/node@20.3.2)(less@4.1.3) why-is-node-running: 2.2.2 transitivePeerDependencies: diff --git a/test/test_route_utils.py b/test/test_route_utils.py deleted file mode 100644 index bcef9e681d79..000000000000 --- a/test/test_route_utils.py +++ /dev/null @@ -1,19 +0,0 @@ -import gradio as gr -from gradio.data_classes import PredictBody -from gradio.helpers import EventData -from gradio.route_utils import prepare_event_data - - -def test_prepare_event_data(): - def on_select(evt: gr.SelectData): - return f"You selected {evt.value} at {evt.index} from {evt.target}" - - with gr.Blocks() as demo: - textbox = gr.Textbox("Hello World!") - statement = gr.Textbox() - textbox.select(on_select, None, statement) - - body = PredictBody(data=[], event_data={"value": "World", "index": [6, 11]}) - event_data = prepare_event_data(demo, body, 0) - correct_event_data = EventData(textbox, {"value": "World", "index": [6, 11]}) - assert vars(event_data) == vars(correct_event_data)