Skip to content

Commit

Permalink
Fix event target (#6383)
Browse files Browse the repository at this point in the history
* changes

* changes

* add changeset

* chagnes

* chaneg

* changes

* fix

* changes

* changes

---------

Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
  • Loading branch information
aliabid94 and gradio-pr-bot committed Nov 13, 2023
1 parent 649f3ce commit 324867f
Show file tree
Hide file tree
Showing 13 changed files with 93 additions and 79 deletions.
7 changes: 7 additions & 0 deletions .changeset/itchy-doors-work.md
@@ -0,0 +1,7 @@
---
"@gradio/app": patch
"@gradio/client": patch
"gradio": patch
---

fix:Fix event target
8 changes: 5 additions & 3 deletions client/js/src/client.ts
Expand Up @@ -47,7 +47,8 @@ type client_return = {
submit: (
endpoint: string | number,
data?: unknown[],
event_data?: unknown
event_data?: unknown,
trigger_id?: number | null
) => SubmitReturn;
component_server: (
component_id: number,
Expand Down Expand Up @@ -412,7 +413,8 @@ export function api_factory(
function submit(
endpoint: string | number,
data: unknown[],
event_data?: unknown
event_data?: unknown,
trigger_id: number | null = null
): SubmitReturn {
let fn_index: number;
let api_info;
Expand Down Expand Up @@ -453,7 +455,7 @@ export function api_factory(
api_info,
hf_token
).then((_payload) => {
payload = { data: _payload || [], event_data, fn_index };
payload = { data: _payload || [], event_data, fn_index, trigger_id };
if (skip_queue(fn_index, config)) {
fire_event({
type: "status",
Expand Down
2 changes: 1 addition & 1 deletion demo/on_listener_basic/run.ipynb
@@ -1 +1 @@
{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: on_listener_basic"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "\n", "with gr.Blocks() as demo:\n", " name = gr.Textbox(label=\"Name\")\n", " output = gr.Textbox(label=\"Output Box\")\n", " greet_btn = gr.Button(\"Greet\")\n", "\n", " def greet(name):\n", " return \"Hello \" + name + \"!\"\n", "\n", " gr.on(\n", " triggers=[name.submit, greet_btn.click],\n", " fn=greet,\n", " inputs=name,\n", " outputs=output,\n", " )\n", "\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}
{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: on_listener_basic"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "\n", "with gr.Blocks() as demo:\n", " name = gr.Textbox(label=\"Name\")\n", " output = gr.Textbox(label=\"Output Box\")\n", " greet_btn = gr.Button(\"Greet\")\n", " trigger = gr.Textbox(label=\"Trigger Box\")\n", " trigger2 = gr.Textbox(label=\"Trigger Box\")\n", "\n", " def greet(name, evt_data: gr.EventData):\n", " return \"Hello \" + name + \"!\", evt_data.target.__class__.__name__\n", " \n", " def clear_name(evt_data: gr.EventData):\n", " return \"\", evt_data.target.__class__.__name__\n", " \n", " gr.on(\n", " triggers=[name.submit, greet_btn.click],\n", " fn=greet,\n", " inputs=name,\n", " outputs=[output, trigger],\n", " ).then(clear_name, outputs=[name, trigger2])\n", "\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}
15 changes: 10 additions & 5 deletions demo/on_listener_basic/run.py
Expand Up @@ -4,16 +4,21 @@
name = gr.Textbox(label="Name")
output = gr.Textbox(label="Output Box")
greet_btn = gr.Button("Greet")
trigger = gr.Textbox(label="Trigger Box")
trigger2 = gr.Textbox(label="Trigger Box")

def greet(name):
return "Hello " + name + "!"

def greet(name, evt_data: gr.EventData):
return "Hello " + name + "!", evt_data.target.__class__.__name__

def clear_name(evt_data: gr.EventData):
return "", evt_data.target.__class__.__name__

gr.on(
triggers=[name.submit, greet_btn.click],
fn=greet,
inputs=name,
outputs=output,
)
outputs=[output, trigger],
).then(clear_name, outputs=[name, trigger2])


if __name__ == "__main__":
Expand Down
1 change: 1 addition & 0 deletions demo/on_listener_test/run.ipynb
@@ -0,0 +1 @@
{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: on_listener_test"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "\n", "with gr.Blocks() as demo:\n", " name = gr.Textbox(label=\"Name\")\n", " output = gr.Textbox(label=\"Output\")\n", " greet_btn = gr.Button(\"Greet\")\n", " trigger = gr.Textbox(label=\"Trigger 1\")\n", " trigger2 = gr.Textbox(label=\"Trigger 2\")\n", "\n", " def greet(name, evt_data: gr.EventData):\n", " return \"Hello \" + name + \"!\", evt_data.target.__class__.__name__\n", " \n", " def clear_name(evt_data: gr.EventData):\n", " return \"\", evt_data.target.__class__.__name__\n", " \n", " gr.on(\n", " triggers=[name.submit, greet_btn.click],\n", " fn=greet,\n", " inputs=name,\n", " outputs=[output, trigger],\n", " ).then(clear_name, outputs=[name, trigger2])\n", "\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}
25 changes: 25 additions & 0 deletions demo/on_listener_test/run.py
@@ -0,0 +1,25 @@
import gradio as gr

with gr.Blocks() as demo:
name = gr.Textbox(label="Name")
output = gr.Textbox(label="Output")
greet_btn = gr.Button("Greet")
trigger = gr.Textbox(label="Trigger 1")
trigger2 = gr.Textbox(label="Trigger 2")

def greet(name, evt_data: gr.EventData):
return "Hello " + name + "!", evt_data.target.__class__.__name__

def clear_name(evt_data: gr.EventData):
return "", evt_data.target.__class__.__name__

gr.on(
triggers=[name.submit, greet_btn.click],
fn=greet,
inputs=name,
outputs=[output, trigger],
).then(clear_name, outputs=[name, trigger2])


if __name__ == "__main__":
demo.launch()
1 change: 1 addition & 0 deletions gradio/data_classes.py
Expand Up @@ -24,6 +24,7 @@ class Config:
data: List[Any]
event_data: Optional[Any] = None
fn_index: Optional[int] = None
trigger_id: Optional[int] = None
batched: Optional[
bool
] = False # Whether the data is a batch of samples (i.e. called from the queue if batch=True) or a single sample (i.e. called from the UI)
Expand Down
8 changes: 3 additions & 5 deletions gradio/route_utils.py
Expand Up @@ -199,12 +199,10 @@ def restore_session_state(app: App, body: PredictBody):
def prepare_event_data(
blocks: Blocks,
body: PredictBody,
fn_index_inferred: int,
) -> EventData:
dependency = blocks.dependencies[fn_index_inferred]
target = dependency["targets"][0] if len(dependency["targets"]) else None
target = body.trigger_id
event_data = EventData(
blocks.blocks.get(target[0]) if target else None,
blocks.blocks.get(target) if target else None,
body.event_data,
)
return event_data
Expand All @@ -219,7 +217,7 @@ async def call_process_api(
session_state, iterator = restore_session_state(app=app, body=body)

dependency = app.get_blocks().dependencies[fn_index_inferred]
event_data = prepare_event_data(app.get_blocks(), body, fn_index_inferred)
event_data = prepare_event_data(app.get_blocks(), body)
event_id = body.event_id

session_hash = getattr(body, "session_hash", None)
Expand Down
19 changes: 13 additions & 6 deletions js/app/src/Blocks.svelte
Expand Up @@ -376,6 +376,7 @@
async function trigger_api_call(
dep_index: number,
trigger_id: number | null = null,
event_data: unknown = null
): Promise<void> {
let dep = dependencies[dep_index];
Expand All @@ -397,7 +398,8 @@
let payload: Payload = {
fn_index: dep_index,
data: dep.inputs.map((id) => instance_map[id].props.value),
event_data: dep.collects_event_data ? event_data : null
event_data: dep.collects_event_data ? event_data : null,
trigger_id: trigger_id
};
if (dep.frontend_fn) {
Expand Down Expand Up @@ -435,7 +437,12 @@
const pending_outputs: number[] = [];
let outputs_set_to_non_interactive: number[] = [];
const submission = app
.submit(payload.fn_index, payload.data as unknown[], payload.event_data)
.submit(
payload.fn_index,
payload.data as unknown[],
payload.event_data,
payload.trigger_id
)
.on("data", ({ data, fn_index }) => {
if (dep.pending_request && dep.final_event) {
dep.pending_request = false;
Expand Down Expand Up @@ -499,7 +506,7 @@
if (status.stage === "complete") {
dependencies.map(async (dep, i) => {
if (dep.trigger_after === fn_index) {
trigger_api_call(i);
trigger_api_call(i, payload.trigger_id);
}
});
Expand All @@ -512,7 +519,7 @@
...messages
];
}, 0);
trigger_api_call(dep_index, event_data);
trigger_api_call(dep_index, payload.trigger_id, event_data);
user_left_page = false;
} else if (status.stage === "error") {
if (status.message) {
Expand All @@ -530,7 +537,7 @@
dep.trigger_after === fn_index &&
!dep.trigger_only_on_success
) {
trigger_api_call(i);
trigger_api_call(i, payload.trigger_id);
}
});
Expand Down Expand Up @@ -627,7 +634,7 @@
} else {
const deps = target_map[id]?.[event];
deps?.forEach((dep_id) => {
trigger_api_call(dep_id, data);
trigger_api_call(dep_id, id, data);
});
}
});
Expand Down
1 change: 1 addition & 0 deletions js/app/src/types.ts
Expand Up @@ -29,6 +29,7 @@ export interface Payload {
fn_index: number;
data: unknown[];
event_data: unknown | null;
trigger_id: number | null;
}

export interface Dependency {
Expand Down
22 changes: 22 additions & 0 deletions js/app/test/on_listener_test.spec.ts
@@ -0,0 +1,22 @@
import { test, expect } from "@gradio/tootils";

test("On listener works.", async ({ page }) => {
const name_box = await page.locator("textarea").nth(0);
const output_box = await page.locator("textarea").nth(1);
const trigger1_box = await page.locator("textarea").nth(2);
const trigger2_box = await page.locator("textarea").nth(3);

name_box.fill("Jimmy");
await page.click("text=Greet");
await expect(output_box).toHaveValue("Hello Jimmy!");
await expect(trigger1_box).toHaveValue("Button");
await expect(name_box).toHaveValue("");
await expect(trigger2_box).toHaveValue("Button");

await name_box.fill("Sally");
await name_box.press("Enter");
await expect(output_box).toHaveValue("Hello Sally!");
await expect(trigger1_box).toHaveValue("Textbox");
await expect(name_box).toHaveValue("");
await expect(trigger2_box).toHaveValue("Textbox");
});
44 changes: 4 additions & 40 deletions pnpm-lock.yaml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

19 changes: 0 additions & 19 deletions test/test_route_utils.py

This file was deleted.

0 comments on commit 324867f

Please sign in to comment.