Skip to content

Commit

Permalink
Add concurrency_limit to ChatInterface, add IDE support for concurren…
Browse files Browse the repository at this point in the history
…cy_limit (#6653)

* concurrency limit chat interface

* add changeset

* Update gradio/chat_interface.py

Co-authored-by: Abubakar Abid <abubakar@huggingface.co>

---------

Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
Co-authored-by: Abubakar Abid <abubakar@huggingface.co>
  • Loading branch information
3 people committed Dec 4, 2023
1 parent 19c9d26 commit d92c819
Show file tree
Hide file tree
Showing 5 changed files with 53 additions and 3 deletions.
5 changes: 5 additions & 0 deletions .changeset/sour-cows-mix.md
@@ -0,0 +1,5 @@
---
"gradio": patch
---

fix:Add concurrency_limit to ChatInterface, add IDE support for concurrency_limit
14 changes: 13 additions & 1 deletion gradio/chat_interface.py
Expand Up @@ -6,7 +6,7 @@
from __future__ import annotations

import inspect
from typing import AsyncGenerator, Callable
from typing import AsyncGenerator, Callable, Literal, Union, cast

import anyio
from gradio_client import utils as client_utils
Expand Down Expand Up @@ -75,6 +75,7 @@ def __init__(
undo_btn: str | None | Button = "↩️ Undo",
clear_btn: str | None | Button = "🗑️ Clear",
autofocus: bool = True,
concurrency_limit: int | None | Literal["default"] = "default",
):
"""
Parameters:
Expand All @@ -97,6 +98,7 @@ def __init__(
undo_btn: Text to display on the delete last button. If None, no button will be displayed. If a Button object, that button will be used.
clear_btn: Text to display on the clear button. If None, no button will be displayed. If a Button object, that button will be used.
autofocus: If True, autofocuses to the textbox when the page loads.
concurrency_limit: If set, this this is the maximum number of chatbot submissions that can be running simultaneously. Can be set to None to mean no limit (any number of chatbot submissions can be running simultaneously). Set to "default" to use the default concurrency limit (defined by the `default_concurrency_limit` parameter in `.queue()`, which is 1 by default).
"""
super().__init__(
analytics_enabled=analytics_enabled,
Expand All @@ -105,6 +107,7 @@ def __init__(
title=title or "Gradio",
theme=theme,
)
self.concurrency_limit = concurrency_limit
self.fn = fn
self.is_async = inspect.iscoroutinefunction(
self.fn
Expand Down Expand Up @@ -304,6 +307,9 @@ def _setup_events(self) -> None:
[self.saved_input, self.chatbot_state] + self.additional_inputs,
[self.chatbot, self.chatbot_state],
api_name=False,
concurrency_limit=cast(
Union[int, Literal["default"], None], self.concurrency_limit
),
)
)
self._setup_stop_events(submit_triggers, submit_event)
Expand All @@ -329,6 +335,9 @@ def _setup_events(self) -> None:
[self.saved_input, self.chatbot_state] + self.additional_inputs,
[self.chatbot, self.chatbot_state],
api_name=False,
concurrency_limit=cast(
Union[int, Literal["default"], None], self.concurrency_limit
),
)
)
self._setup_stop_events([self.retry_btn.click], retry_event)
Expand Down Expand Up @@ -412,6 +421,9 @@ def _setup_api(self) -> None:
[self.textbox, self.chatbot_state] + self.additional_inputs,
[self.textbox, self.chatbot_state],
api_name="chat",
concurrency_limit=cast(
Union[int, Literal["default"], None], self.concurrency_limit
),
)

def _clear_and_save_textbox(self, message: str) -> tuple[str, str]:
Expand Down
7 changes: 5 additions & 2 deletions gradio/component_meta.py
Expand Up @@ -21,7 +21,6 @@ def {{ event }}(self,
inputs: Component | Sequence[Component] | set[Component] | None = None,
outputs: Component | Sequence[Component] | None = None,
api_name: str | None | Literal[False] = None,
status_tracker: None = None,
scroll_to_output: bool = False,
show_progress: Literal["full", "minimal", "hidden"] = "full",
queue: bool | None = None,
Expand All @@ -32,7 +31,9 @@ def {{ event }}(self,
cancels: dict[str, Any] | list[dict[str, Any]] | None = None,
every: float | None = None,
trigger_mode: Literal["once", "multiple", "always_last"] | None = None,
js: str | None = None,) -> Dependency:
js: str | None = None,
concurrency_limit: int | None | Literal["default"] = "default",
concurrency_id: str | None = None) -> Dependency:
"""
Parameters:
fn: the function to call when this event is triggered. Often a machine learning model's prediction function. Each parameter of the function corresponds to one input component, and the function should return a single value or a tuple of values, with each element in the tuple corresponding to one output component.
Expand All @@ -50,6 +51,8 @@ def {{ event }}(self,
every: Run this event 'every' number of seconds while the client connection is open. Interpreted in seconds. Queue must be enabled.
trigger_mode: If "once" (default for all events except `.change()`) would not allow any submissions while an event is pending. If set to "multiple", unlimited submissions are allowed while pending, and "always_last" (default for `.change()` event) would allow a second submission after the pending event is complete.
js: Optional frontend js method to run before running 'fn'. Input arguments for js method are values of 'inputs' and 'outputs', return should be a list of values for output components.
concurrency_limit: If set, this this is the maximum number of this event that can be running simultaneously. Can be set to None to mean no concurrency_limit (any number of this event can be running simultaneously). Set to "default" to use the default concurrency limit (defined by the `default_concurrency_limit` parameter in `Blocks.queue()`, which itself is 1 by default).
concurrency_id: If set, this is the id of the concurrency group. Events with the same concurrency_id will be limited by the lowest set concurrency_limit.
"""
...
{% endfor %}
Expand Down
6 changes: 6 additions & 0 deletions test/test_chat_interface.py
Expand Up @@ -49,6 +49,12 @@ def test_configuring_buttons(self):
assert chatbot.submit_btn is None
assert chatbot.retry_btn is None

def test_concurrency_limit(self):
chat = gr.ChatInterface(double, concurrency_limit=10)
assert chat.concurrency_limit == 10
fns = [fn for fn in chat.fns if fn.name in {"_submit_fn", "_api_submit_fn"}]
assert all(fn.concurrency_limit == 10 for fn in fns)

def test_events_attached(self):
chatbot = gr.ChatInterface(double)
dependencies = chatbot.dependencies
Expand Down
24 changes: 24 additions & 0 deletions test/test_events.py
@@ -1,3 +1,7 @@
import ast
import inspect
from pathlib import Path

import pytest
from fastapi.testclient import TestClient

Expand Down Expand Up @@ -159,3 +163,23 @@ def test_event_defined_invalid_scope(self):

with pytest.raises(AttributeError):
textbox.change(lambda x: x + x, textbox, textbox)


def test_event_pyi_file_matches_source_code():
"""Test that the template used to create pyi files (search INTERFACE_TEMPLATE in component_meta) matches the source code of EventListener._setup."""
code = (
Path(__file__).parent / ".." / "gradio" / "components" / "button.pyi"
).read_text()
mod = ast.parse(code)
segment = None
for node in ast.walk(mod):
if isinstance(node, ast.FunctionDef) and node.name == "click":
segment = ast.get_source_segment(code, node)

# This would fail if Button no longer has a click method
assert segment
sig = inspect.signature(gr.Button.click)
for param in sig.parameters.values():
if param.name == "block":
continue
assert param.name in segment

0 comments on commit d92c819

Please sign in to comment.