diff --git a/.changeset/sour-cows-mix.md b/.changeset/sour-cows-mix.md new file mode 100644 index 000000000000..0512fb37ddbf --- /dev/null +++ b/.changeset/sour-cows-mix.md @@ -0,0 +1,5 @@ +--- +"gradio": patch +--- + +fix:Add concurrency_limit to ChatInterface, add IDE support for concurrency_limit diff --git a/gradio/chat_interface.py b/gradio/chat_interface.py index e044f579ee94..c8163680878d 100644 --- a/gradio/chat_interface.py +++ b/gradio/chat_interface.py @@ -6,7 +6,7 @@ from __future__ import annotations import inspect -from typing import AsyncGenerator, Callable +from typing import AsyncGenerator, Callable, Literal, Union, cast import anyio from gradio_client import utils as client_utils @@ -75,6 +75,7 @@ def __init__( undo_btn: str | None | Button = "↩️ Undo", clear_btn: str | None | Button = "🗑️ Clear", autofocus: bool = True, + concurrency_limit: int | None | Literal["default"] = "default", ): """ Parameters: @@ -97,6 +98,7 @@ def __init__( undo_btn: Text to display on the delete last button. If None, no button will be displayed. If a Button object, that button will be used. clear_btn: Text to display on the clear button. If None, no button will be displayed. If a Button object, that button will be used. autofocus: If True, autofocuses to the textbox when the page loads. + concurrency_limit: If set, this is the maximum number of chatbot submissions that can be running simultaneously. Can be set to None to mean no limit (any number of chatbot submissions can be running simultaneously). Set to "default" to use the default concurrency limit (defined by the `default_concurrency_limit` parameter in `.queue()`, which is 1 by default). 
""" super().__init__( analytics_enabled=analytics_enabled, @@ -105,6 +107,7 @@ def __init__( title=title or "Gradio", theme=theme, ) + self.concurrency_limit = concurrency_limit self.fn = fn self.is_async = inspect.iscoroutinefunction( self.fn @@ -304,6 +307,9 @@ def _setup_events(self) -> None: [self.saved_input, self.chatbot_state] + self.additional_inputs, [self.chatbot, self.chatbot_state], api_name=False, + concurrency_limit=cast( + Union[int, Literal["default"], None], self.concurrency_limit + ), ) ) self._setup_stop_events(submit_triggers, submit_event) @@ -329,6 +335,9 @@ def _setup_events(self) -> None: [self.saved_input, self.chatbot_state] + self.additional_inputs, [self.chatbot, self.chatbot_state], api_name=False, + concurrency_limit=cast( + Union[int, Literal["default"], None], self.concurrency_limit + ), ) ) self._setup_stop_events([self.retry_btn.click], retry_event) @@ -412,6 +421,9 @@ def _setup_api(self) -> None: [self.textbox, self.chatbot_state] + self.additional_inputs, [self.textbox, self.chatbot_state], api_name="chat", + concurrency_limit=cast( + Union[int, Literal["default"], None], self.concurrency_limit + ), ) def _clear_and_save_textbox(self, message: str) -> tuple[str, str]: diff --git a/gradio/component_meta.py b/gradio/component_meta.py index 90656b025dbd..c33c383f820f 100644 --- a/gradio/component_meta.py +++ b/gradio/component_meta.py @@ -21,7 +21,6 @@ def {{ event }}(self, inputs: Component | Sequence[Component] | set[Component] | None = None, outputs: Component | Sequence[Component] | None = None, api_name: str | None | Literal[False] = None, - status_tracker: None = None, scroll_to_output: bool = False, show_progress: Literal["full", "minimal", "hidden"] = "full", queue: bool | None = None, @@ -32,7 +31,9 @@ def {{ event }}(self, cancels: dict[str, Any] | list[dict[str, Any]] | None = None, every: float | None = None, trigger_mode: Literal["once", "multiple", "always_last"] | None = None, - js: str | None = None,) -> Dependency: 
+ js: str | None = None, + concurrency_limit: int | None | Literal["default"] = "default", + concurrency_id: str | None = None) -> Dependency: """ Parameters: fn: the function to call when this event is triggered. Often a machine learning model's prediction function. Each parameter of the function corresponds to one input component, and the function should return a single value or a tuple of values, with each element in the tuple corresponding to one output component. @@ -50,6 +51,8 @@ def {{ event }}(self, every: Run this event 'every' number of seconds while the client connection is open. Interpreted in seconds. Queue must be enabled. trigger_mode: If "once" (default for all events except `.change()`) would not allow any submissions while an event is pending. If set to "multiple", unlimited submissions are allowed while pending, and "always_last" (default for `.change()` event) would allow a second submission after the pending event is complete. js: Optional frontend js method to run before running 'fn'. Input arguments for js method are values of 'inputs' and 'outputs', return should be a list of values for output components. + concurrency_limit: If set, this is the maximum number of this event that can be running simultaneously. Can be set to None to mean no concurrency_limit (any number of this event can be running simultaneously). Set to "default" to use the default concurrency limit (defined by the `default_concurrency_limit` parameter in `Blocks.queue()`, which itself is 1 by default). + concurrency_id: If set, this is the id of the concurrency group. Events with the same concurrency_id will be limited by the lowest set concurrency_limit. """ ... 
{% endfor %} diff --git a/test/test_chat_interface.py b/test/test_chat_interface.py index 43a73f9e2466..bee1bf6f8d12 100644 --- a/test/test_chat_interface.py +++ b/test/test_chat_interface.py @@ -49,6 +49,12 @@ def test_configuring_buttons(self): assert chatbot.submit_btn is None assert chatbot.retry_btn is None + def test_concurrency_limit(self): + chat = gr.ChatInterface(double, concurrency_limit=10) + assert chat.concurrency_limit == 10 + fns = [fn for fn in chat.fns if fn.name in {"_submit_fn", "_api_submit_fn"}] + assert all(fn.concurrency_limit == 10 for fn in fns) + def test_events_attached(self): chatbot = gr.ChatInterface(double) dependencies = chatbot.dependencies diff --git a/test/test_events.py b/test/test_events.py index 72ca9d0dd321..b4ab145af0fc 100644 --- a/test/test_events.py +++ b/test/test_events.py @@ -1,3 +1,7 @@ +import ast +import inspect +from pathlib import Path + import pytest from fastapi.testclient import TestClient @@ -159,3 +163,23 @@ def test_event_defined_invalid_scope(self): with pytest.raises(AttributeError): textbox.change(lambda x: x + x, textbox, textbox) + + +def test_event_pyi_file_matches_source_code(): + """Test that the template used to create pyi files (search INTERFACE_TEMPLATE in component_meta) matches the source code of EventListener._setup.""" + code = ( + Path(__file__).parent / ".." / "gradio" / "components" / "button.pyi" + ).read_text() + mod = ast.parse(code) + segment = None + for node in ast.walk(mod): + if isinstance(node, ast.FunctionDef) and node.name == "click": + segment = ast.get_source_segment(code, node) + + # This would fail if Button no longer has a click method + assert segment + sig = inspect.signature(gr.Button.click) + for param in sig.parameters.values(): + if param.name == "block": + continue + assert param.name in segment