From f42d3e29c7b8cdedd7aea75cffce81857db28eeb Mon Sep 17 00:00:00 2001 From: Freddy Boulton Date: Fri, 29 Mar 2024 14:35:21 -0700 Subject: [PATCH] Make internal event handlers of gr.Interface and gr.ChatInterface async (#7893) * Make async * add changeset --------- Co-authored-by: gradio-pr-bot --- .changeset/strong-hornets-push.md | 5 +++++ gradio/chat_interface.py | 24 +++++++++++++----------- gradio/interface.py | 20 ++++++++++++++------ gradio/utils.py | 13 +++++++++++++ 4 files changed, 45 insertions(+), 17 deletions(-) create mode 100644 .changeset/strong-hornets-push.md diff --git a/.changeset/strong-hornets-push.md b/.changeset/strong-hornets-push.md new file mode 100644 index 000000000000..5c85dbf76655 --- /dev/null +++ b/.changeset/strong-hornets-push.md @@ -0,0 +1,5 @@ +--- +"gradio": patch +--- + +feat:Make internal event handlers of gr.Interface and gr.ChatInterface async diff --git a/gradio/chat_interface.py b/gradio/chat_interface.py index b4d14a383f3d..0268b96f074e 100644 --- a/gradio/chat_interface.py +++ b/gradio/chat_interface.py @@ -29,7 +29,7 @@ from gradio.layouts import Accordion, Group, Row from gradio.routes import Request from gradio.themes import ThemeClass as Theme -from gradio.utils import SyncToAsyncIterator, async_iteration +from gradio.utils import SyncToAsyncIterator, async_iteration, async_lambda @document() @@ -378,7 +378,7 @@ def _setup_events(self) -> None: show_api=False, queue=False, ).then( - lambda x: x, + async_lambda(lambda x: x), [self.saved_input], [self.textbox], show_api=False, @@ -387,7 +387,7 @@ def _setup_events(self) -> None: if self.clear_btn: self.clear_btn.click( - lambda: ([], [], None), + async_lambda(lambda: ([], [], None)), None, [self.chatbot, self.chatbot_state, self.saved_input], queue=False, @@ -401,9 +401,11 @@ def _setup_stop_events( if self.submit_btn: for event_trigger in event_triggers: event_trigger( - lambda: ( - Button(visible=False), - Button(visible=True), + async_lambda( + lambda: ( + Button(visible=False), + Button(visible=True), + ) ), None, [self.submit_btn, self.stop_btn], @@ -411,7 +413,7 @@ def _setup_stop_events( queue=False, ) event_to_cancel.then( - lambda: (Button(visible=True), Button(visible=False)), + async_lambda(lambda: (Button(visible=True), Button(visible=False))), None, [self.submit_btn, self.stop_btn], show_api=False, @@ -420,14 +422,14 @@ def _setup_stop_events( else: for event_trigger in event_triggers: event_trigger( - lambda: Button(visible=True), + async_lambda(lambda: Button(visible=True)), None, [self.stop_btn], show_api=False, queue=False, ) event_to_cancel.then( - lambda: Button(visible=False), + async_lambda(lambda: Button(visible=False)), None, [self.stop_btn], show_api=False, @@ -471,7 +473,7 @@ def _append_multimodal_history( if message["text"] is not None and isinstance(message["text"], str): history.append([message["text"], response]) - def _display_input( + async def _display_input( self, message: str | dict[str, list], history: list[list[str | tuple | None]] ) -> tuple[list[list[str | tuple | None]], list[list[str | tuple | None]]]: if self.multimodal and isinstance(message, dict): @@ -631,7 +633,7 @@ async def _examples_stream_fn( async for response in generator: yield [[message, response]] - def _delete_prev_fn( + async def _delete_prev_fn( self, message: str | dict[str, list], history: list[list[str | tuple | None]], diff --git a/gradio/interface.py b/gradio/interface.py index 0cfbea035628..db7fb3106e21 100644 --- a/gradio/interface.py +++ b/gradio/interface.py @@ -701,14 +701,16 @@ def attach_submit_events( if _stop_btn: extra_output = [_submit_btn, _stop_btn] - def cleanup(): + async def cleanup(): return [Button(visible=True), Button(visible=False)] predict_event = on( triggers, - lambda: ( - Button(visible=False), - Button(visible=True), + utils.async_lambda( + lambda: ( + Button(visible=False), + Button(visible=True), + ) ), inputs=None, outputs=[_submit_btn, _stop_btn], @@ -827,7 +829,9 @@ def attach_flagging_events( ) flag_method = FlagMethod(self.flagging_callback, label, value) flag_btn.click( - lambda: Button(value="Saving...", interactive=False), + utils.async_lambda( + lambda: Button(value="Saving...", interactive=False) + ), None, flag_btn, queue=False, @@ -842,7 +846,11 @@ def attach_flagging_events( show_api=False, ) _clear_btn.click( - flag_method.reset, None, flag_btn, queue=False, show_api=False + utils.async_lambda(flag_method.reset), + None, + flag_btn, + queue=False, + show_api=False, ) def render_examples(self): diff --git a/gradio/utils.py b/gradio/utils.py index b69ad7b22bfb..5fdd43199efd 100644 --- a/gradio/utils.py +++ b/gradio/utils.py @@ -26,6 +26,7 @@ from abc import ABC, abstractmethod from collections import OrderedDict from contextlib import contextmanager +from functools import wraps from io import BytesIO from numbers import Number from pathlib import Path @@ -1245,3 +1246,15 @@ def simplify_file_data_in_str(s): if isinstance(payload, str): return payload return json.dumps(payload) + + +def async_lambda(f: Callable) -> Callable: + """Turn a function into an async function. + Useful for internal event handlers defined as lambda functions used in the codebase + """ + + @wraps(f) + async def function_wrapper(*args, **kwargs): + return f(*args, **kwargs) + + return function_wrapper