diff --git a/.changeset/dirty-guests-suffer.md b/.changeset/dirty-guests-suffer.md
new file mode 100644
index 000000000000..cb1762e6d14d
--- /dev/null
+++ b/.changeset/dirty-guests-suffer.md
@@ -0,0 +1,5 @@
+---
+"gradio": minor
+---
+
+feat:Allow setting a `default_concurrency_limit` other than 1
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 31c7f12fbd66..25d82148afee 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -232,6 +232,8 @@ Previously, in Gradio 3.x, there was a single global `concurrency_count` paramet
 In Gradio 4.0, the `concurrency_count` parameter has been removed. You can still control the number of total threads by using the `max_threads` parameter. The default value of this parameter is `40`, but you don't have worry (as much) about OOM errors, because even though there are 40 threads, we use a single-worker-single-event model, which means each worker thread only executes a specific function. So effectively, each function has its own "concurrency count" of 1. If you'd like to change this behavior, you can do so by setting a parameter `concurrency_limit`, which is now a parameter of *each event*, not a global parameter. By default this is `1` for each event, but you can set it to a higher value, or to `None` if you'd like to allow an arbitrary number of executions of this event simultaneously. Events can also be grouped together using the `concurrency_id` parameter so that they share the same limit, and by default, events that call the same function share the same `concurrency_id`.
 
+Lastly, it should be noted that the default value of the `concurrency_limit` of all events in a Blocks (which is normally 1) can be changed using the `default_concurrency_limit` parameter in `Blocks.queue()`. You can set this to a higher integer or to `None`. This in turn sets the `concurrency_limit` of all events that don't have an explicit `concurrency_limit` specified.
+
 To summarize migration:
 
 * For events that execute quickly or don't use much CPU or GPU resources, you should set `concurrency_limit=None` in Gradio 4.0. (Previously you would set `queue=False`.)
diff --git a/gradio/blocks.py b/gradio/blocks.py
index 75366cba504c..2e26fe9e7a89 100644
--- a/gradio/blocks.py
+++ b/gradio/blocks.py
@@ -366,7 +366,7 @@ def __init__(
         inputs_as_dict: bool,
         batch: bool = False,
         max_batch_size: int = 4,
-        concurrency_limit: int | None = 1,
+        concurrency_limit: int | None | Literal["default"] = "default",
         concurrency_id: str | None = None,
         tracks_progress: bool = False,
     ):
@@ -592,6 +592,7 @@ def __init__(
         self.output_components = None
         self.__name__ = None
         self.api_mode = None
+        self.progress_tracking = None
         self.ssl_verify = True
@@ -822,7 +823,7 @@ def set_event_trigger(
         trigger_after: int | None = None,
         trigger_only_on_success: bool = False,
         trigger_mode: Literal["once", "multiple", "always_last"] | None = "once",
-        concurrency_limit: int | None = 1,
+        concurrency_limit: int | None | Literal["default"] = "default",
         concurrency_id: str | None = None,
     ) -> tuple[dict[str, Any], int]:
         """
@@ -848,7 +849,7 @@ def set_event_trigger(
             trigger_after: if set, this event will be triggered after 'trigger_after' function index
             trigger_only_on_success: if True, this event will only be triggered if the previous event was successful (only applies if `trigger_after` is set)
             trigger_mode: If "once" (default for all events except `.change()`) would not allow any submissions while an event is pending. If set to "multiple", unlimited submissions are allowed while pending, and "always_last" (default for `.change()` event) would allow a second submission after the pending event is complete.
-            concurrency_limit: If set, this this is the maximum number of this event that can be running simultaneously. Extra events triggered by this listener will be queued. On Spaces, this is set to 1 by default.
+            concurrency_limit: If set, this is the maximum number of this event that can be running simultaneously. Can be set to None to mean no concurrency_limit (any number of this event can be running simultaneously). Set to "default" to use the default concurrency limit (defined by the `default_concurrency_limit` parameter in `queue()`, which itself is 1 by default).
             concurrency_id: If set, this is the id of the concurrency group. Events with the same concurrency_id will be limited by the lowest set concurrency_limit.
         Returns: dependency information, dependency index
         """
@@ -1649,6 +1650,8 @@ def queue(
         api_open: bool | None = None,
         max_size: int | None = None,
         concurrency_count: int | None = None,
+        *,
+        default_concurrency_limit: int | None | Literal["not_set"] = "not_set",
     ):
         """
         By enabling the queue you can control when users know their position in the queue, and set a limit on maximum number of events allowed.
@@ -1657,6 +1660,7 @@ def queue(
             api_open: If True, the REST routes of the backend will be open, allowing requests made directly to those endpoints to skip the queue.
             max_size: The maximum number of events the queue will store at any given moment. If the queue is full, new events will not be added and a user will receive a message saying that the queue is full. If None, the queue size will be unlimited.
             concurrency_count: Deprecated and has no effect. Set the concurrency_limit directly on event listeners e.g. btn.click(fn, ..., concurrency_limit=10) or gr.Interface(concurrency_limit=10). If necessary, the total number of workers can be configured via `max_threads` in launch().
+            default_concurrency_limit: The default value of `concurrency_limit` to use for event listeners that don't specify a value. Can be set by environment variable GRADIO_DEFAULT_CONCURRENCY_LIMIT. Defaults to 1 if not otherwise set.
         Example: (Blocks)
             with gr.Blocks() as demo:
                 button = gr.Button(label="Generate Image")
@@ -1682,6 +1686,7 @@
             update_intervals=status_update_rate if status_update_rate != "auto" else 1,
             max_size=max_size,
             block_fns=self.fns,
+            default_concurrency_limit=default_concurrency_limit,
         )
         self.config = self.get_config_file()
         self.app = routes.App.create_app(self)
diff --git a/gradio/events.py b/gradio/events.py
index 5b215b0916dc..12f2bde712ff 100644
--- a/gradio/events.py
+++ b/gradio/events.py
@@ -207,7 +207,7 @@ def event_trigger(
         every: float | None = None,
         trigger_mode: Literal["once", "multiple", "always_last"] | None = None,
         js: str | None = None,
-        concurrency_limit: int | None = 1,
+        concurrency_limit: int | None | Literal["default"] = "default",
         concurrency_id: str | None = None,
     ) -> Dependency:
         """
@@ -227,7 +227,7 @@ def event_trigger(
             every: Run this event 'every' number of seconds while the client connection is open. Interpreted in seconds. Queue must be enabled.
             trigger_mode: If "once" (default for all events except `.change()`) would not allow any submissions while an event is pending. If set to "multiple", unlimited submissions are allowed while pending, and "always_last" (default for `.change()` event) would allow a second submission after the pending event is complete.
             js: Optional frontend js method to run before running 'fn'. Input arguments for js method are values of 'inputs' and 'outputs', return should be a list of values for output components.
-            concurrency_limit: If set, this this is the maximum number of events that can be running simultaneously. Extra requests will be queued.
+            concurrency_limit: If set, this is the maximum number of this event that can be running simultaneously. Can be set to None to mean no concurrency_limit (any number of this event can be running simultaneously). Set to "default" to use the default concurrency limit (defined by the `default_concurrency_limit` parameter in `Blocks.queue()`, which itself is 1 by default).
             concurrency_id: If set, this is the id of the concurrency group. Events with the same concurrency_id will be limited by the lowest set concurrency_limit.
         """
@@ -351,7 +351,7 @@ def on(
     cancels: dict[str, Any] | list[dict[str, Any]] | None = None,
     every: float | None = None,
     js: str | None = None,
-    concurrency_limit: int | None = 1,
+    concurrency_limit: int | None | Literal["default"] = "default",
     concurrency_id: str | None = None,
 ) -> Dependency:
     """
@@ -371,7 +371,7 @@ def on(
         cancels: A list of other events to cancel when this listener is triggered. For example, setting cancels=[click_event] will cancel the click_event, where click_event is the return value of another components .click method. Functions that have not yet run (or generators that are iterating) will be cancelled, but functions that are currently running will be allowed to finish.
         every: Run this event 'every' number of seconds while the client connection is open. Interpreted in seconds. Queue must be enabled.
         js: Optional frontend js method to run before running 'fn'. Input arguments for js method are values of 'inputs', return should be a list of values for output components.
-        concurrency_limit: If set, this this is the maximum number of events that can be running simultaneously. Extra requests will be queued.
+        concurrency_limit: If set, this is the maximum number of this event that can be running simultaneously. Can be set to None to mean no concurrency_limit (any number of this event can be running simultaneously). Set to "default" to use the default concurrency limit (defined by the `default_concurrency_limit` parameter in `Blocks.queue()`, which itself is 1 by default).
         concurrency_id: If set, this is the id of the concurrency group. Events with the same concurrency_id will be limited by the lowest set concurrency_limit.
     """
     from gradio.components.base import Component
diff --git a/gradio/interface.py b/gradio/interface.py
index 7539b0eff815..ce46cfd331d8 100644
--- a/gradio/interface.py
+++ b/gradio/interface.py
@@ -114,7 +114,7 @@ def __init__(
         api_name: str | Literal[False] | None = "predict",
         _api_mode: bool = False,
         allow_duplication: bool = False,
-        concurrency_limit: int | None = 1,
+        concurrency_limit: int | None | Literal["default"] = "default",
         **kwargs,
     ):
         """
@@ -141,7 +141,7 @@ def __init__(
            max_batch_size: Maximum number of inputs to batch together if this is called from the queue (only relevant if batch=True)
            api_name: defines how the endpoint appears in the API docs. Can be a string, None, or False. If set to a string, the endpoint will be exposed in the API docs with the given name. If None, the name of the prediction function will be used as the API endpoint. If False, the endpoint will not be exposed in the API docs and downstream apps (including those that `gr.load` this app) will not be able to use this event.
            allow_duplication: If True, then will show a 'Duplicate Spaces' button on Hugging Face Spaces.
-           concurrency_limit: If set, this this is the maximum number of events that can be running simultaneously. Extra requests will be queued.
+           concurrency_limit: If set, this is the maximum number of this event that can be running simultaneously. Can be set to None to mean no concurrency_limit (any number of this event can be running simultaneously). Set to "default" to use the default concurrency limit (defined by the `default_concurrency_limit` parameter in `.queue()`, which itself is 1 by default).
         """
         super().__init__(
             analytics_enabled=analytics_enabled,
@@ -312,7 +312,7 @@ def __init__(
         self.batch = batch
         self.max_batch_size = max_batch_size
         self.allow_duplication = allow_duplication
-        self.concurrency_limit = concurrency_limit
+        self.concurrency_limit: int | None | Literal["default"] = concurrency_limit
         self.share = None
         self.share_url = None
diff --git a/gradio/queueing.py b/gradio/queueing.py
index ae347512daee..1509331684ee 100644
--- a/gradio/queueing.py
+++ b/gradio/queueing.py
@@ -3,6 +3,7 @@
 import asyncio
 import copy
 import json
+import os
 import time
 import traceback
 import uuid
@@ -78,6 +79,7 @@ def __init__(
         update_intervals: float,
         max_size: int | None,
         block_fns: list[BlockFunction],
+        default_concurrency_limit: int | None | Literal["not_set"] = "not_set",
     ):
         self.event_queue: list[Event] = []
         self.awaiting_data_events: dict[str, Event] = {}
@@ -99,20 +101,27 @@ def __init__(
         self.block_fns = block_fns
         self.continuous_tasks: list[Event] = []
         self._asyncio_tasks: list[asyncio.Task] = []
-
+        self.default_concurrency_limit = self._resolve_concurrency_limit(
+            default_concurrency_limit
+        )
         self.concurrency_limit_per_concurrency_id = {}
 
     def start(self):
         self.active_jobs = [None] * self.max_thread_count
         for block_fn in self.block_fns:
-            if block_fn.concurrency_limit is not None:
+            concurrency_limit = (
+                self.default_concurrency_limit
+                if block_fn.concurrency_limit == "default"
+                else block_fn.concurrency_limit
+            )
+            if concurrency_limit is not None:
                 self.concurrency_limit_per_concurrency_id[
                     block_fn.concurrency_id
                 ] = min(
                     self.concurrency_limit_per_concurrency_id.get(
-                        block_fn.concurrency_id, block_fn.concurrency_limit
+                        block_fn.concurrency_id, concurrency_limit
                     ),
-                    block_fn.concurrency_limit,
+                    concurrency_limit,
                 )
 
         run_coro_in_background(self.start_processing)
@@ -123,6 +132,26 @@ def start(self):
     def close(self):
         self.stopped = True
 
+    def _resolve_concurrency_limit(self, default_concurrency_limit):
+        """
+        Handles the logic of resolving the default_concurrency_limit, as this can be specified via the
+        `default_concurrency_limit` parameter of `Blocks.queue()` or the `GRADIO_DEFAULT_CONCURRENCY_LIMIT`
+        environment variable. The parameter in `Blocks.queue()` takes precedence over the environment variable.
+        Parameters:
+            default_concurrency_limit: The default concurrency limit, as specified by a user in `Blocks.queue()`.
+ """ + if default_concurrency_limit != "not_set": + return default_concurrency_limit + if default_concurrency_limit_env := os.environ.get( + "GRADIO_DEFAULT_CONCURRENCY_LIMIT" + ): + if default_concurrency_limit_env.lower() == "none": + return None + else: + return int(default_concurrency_limit_env) + else: + return 1 + def attach_data(self, body: PredictBody): event_id = body.event_id if event_id in self.awaiting_data_events: diff --git a/test/test_queueing.py b/test/test_queueing.py index 83c2491c6946..e65bb060fa85 100644 --- a/test/test_queueing.py +++ b/test/test_queueing.py @@ -1,13 +1,14 @@ import time import gradio_client as grc +import pytest from fastapi.testclient import TestClient import gradio as gr class TestQueueing: - def test_single_request(self): + def test_single_request(self, connect): with gr.Blocks() as demo: name = gr.Textbox() output = gr.Textbox() @@ -19,12 +20,11 @@ def greet(x): demo.launch(prevent_thread_lock=True) - client = grc.Client(f"http://localhost:{demo.server_port}") - job = client.submit("x", fn_index=0) + with connect(demo) as client: + job = client.submit("x", fn_index=0) + assert job.result() == "Hello, x!" - assert job.result() == "Hello, x!" - - def test_all_status_messages(self): + def test_all_status_messages(self, connect): with gr.Blocks() as demo: name = gr.Textbox() output = gr.Textbox() @@ -35,10 +35,10 @@ def greet(x): name.submit(greet, name, output, concurrency_limit=2) - app, _, _ = demo.launch(prevent_thread_lock=True) + app, local_url, _ = demo.launch(prevent_thread_lock=True) test_client = TestClient(app) + client = grc.Client(local_url) - client = grc.Client(f"http://localhost:{demo.server_port}") client.submit("a", fn_index=0) job2 = client.submit("b", fn_index=0) client.submit("c", fn_index=0) @@ -70,7 +70,44 @@ def greet(x): assert job2.result() == "Hello, b!" assert job4.result() == "Hello, d!" 
 
-    def test_concurrency_limits(self):
+    @pytest.mark.parametrize(
+        "default_concurrency_limit, statuses",
+        [
+            ("not_set", ["IN_QUEUE", "IN_QUEUE", "PROCESSING"]),
+            (None, ["PROCESSING", "PROCESSING", "PROCESSING"]),
+            (1, ["IN_QUEUE", "IN_QUEUE", "PROCESSING"]),
+            (2, ["IN_QUEUE", "PROCESSING", "PROCESSING"]),
+        ],
+    )
+    def test_default_concurrency_limits(self, default_concurrency_limit, statuses):
+        with gr.Blocks() as demo:
+            a = gr.Number()
+            b = gr.Number()
+            output = gr.Number()
+
+            add_btn = gr.Button("Add")
+
+            @add_btn.click(inputs=[a, b], outputs=output)
+            def add(x, y):
+                time.sleep(2)
+                return x + y
+
+        demo.queue(default_concurrency_limit=default_concurrency_limit)
+        _, local_url, _ = demo.launch(
+            prevent_thread_lock=True,
+        )
+        client = grc.Client(local_url)
+
+        add_job_1 = client.submit(1, 1, fn_index=0)
+        add_job_2 = client.submit(1, 1, fn_index=0)
+        add_job_3 = client.submit(1, 1, fn_index=0)
+
+        time.sleep(1)
+
+        add_job_statuses = [add_job_1.status(), add_job_2.status(), add_job_3.status()]
+        assert sorted([s.code.value for s in add_job_statuses]) == statuses
+
+    def test_concurrency_limits(self, connect):
         with gr.Blocks() as demo:
             a = gr.Number()
             b = gr.Number()
@@ -80,14 +117,14 @@ def test_concurrency_limits(self):
             output = gr.Number()
 
             add_btn = gr.Button("Add")
 
             @add_btn.click(inputs=[a, b], outputs=output, concurrency_limit=2)
             def add(x, y):
-                time.sleep(4)
+                time.sleep(2)
                 return x + y
 
             sub_btn = gr.Button("Subtract")
 
             @sub_btn.click(inputs=[a, b], outputs=output, concurrency_limit=None)
             def sub(x, y):
-                time.sleep(4)
+                time.sleep(2)
                 return x - y
 
             mul_btn = gr.Button("Multiply")
@@ -99,7 +136,7 @@ def sub(x, y):
             @mul_btn.click(
                 inputs=[a, b],
                 outputs=output,
                 concurrency_limit=2,
                 concurrency_id="muldiv",
             )
             def mul(x, y):
-                time.sleep(4)
+                time.sleep(2)
                 return x * y
 
             div_btn = gr.Button("Divide")
@@ -111,49 +148,55 @@ def mul(x, y):
             @div_btn.click(
                 inputs=[a, b],
                 outputs=output,
                 concurrency_limit=2,
                 concurrency_id="muldiv",
             )
             def div(x, y):
-                time.sleep(4)
+                time.sleep(2)
                 return x / y
 
-        app, _, _ = demo.launch(prevent_thread_lock=True)
-
-        client = grc.Client(f"http://localhost:{demo.server_port}")
-        add_job_1 = client.submit(1, 1, fn_index=0)
-        add_job_2 = client.submit(1, 1, fn_index=0)
-        add_job_3 = client.submit(1, 1, fn_index=0)
-        sub_job_1 = client.submit(1, 1, fn_index=1)
-        sub_job_2 = client.submit(1, 1, fn_index=1)
-        sub_job_3 = client.submit(1, 1, fn_index=1)
-        sub_job_3 = client.submit(1, 1, fn_index=1)
-        mul_job_1 = client.submit(1, 1, fn_index=2)
-        div_job_1 = client.submit(1, 1, fn_index=3)
-        mul_job_2 = client.submit(1, 1, fn_index=2)
-
-        time.sleep(2)
-
-        add_job_statuses = [add_job_1.status(), add_job_2.status(), add_job_3.status()]
-        assert sorted([s.code.value for s in add_job_statuses]) == [
-            "IN_QUEUE",
-            "PROCESSING",
-            "PROCESSING",
-        ]
-
-        sub_job_statuses = [sub_job_1.status(), sub_job_2.status(), sub_job_3.status()]
-        assert [s.code.value for s in sub_job_statuses] == [
-            "PROCESSING",
-            "PROCESSING",
-            "PROCESSING",
-        ]
-
-        muldiv_job_statuses = [
-            mul_job_1.status(),
-            div_job_1.status(),
-            mul_job_2.status(),
-        ]
-        assert sorted([s.code.value for s in muldiv_job_statuses]) == [
-            "IN_QUEUE",
-            "PROCESSING",
-            "PROCESSING",
-        ]
+        with connect(demo) as client:
+            add_job_1 = client.submit(1, 1, fn_index=0)
+            add_job_2 = client.submit(1, 1, fn_index=0)
+            add_job_3 = client.submit(1, 1, fn_index=0)
+            sub_job_1 = client.submit(1, 1, fn_index=1)
+            sub_job_2 = client.submit(1, 1, fn_index=1)
+            sub_job_3 = client.submit(1, 1, fn_index=1)
+            sub_job_3 = client.submit(1, 1, fn_index=1)
+            mul_job_1 = client.submit(1, 1, fn_index=2)
+            div_job_1 = client.submit(1, 1, fn_index=3)
+            mul_job_2 = client.submit(1, 1, fn_index=2)
+
+            time.sleep(1)
+
+            add_job_statuses = [
+                add_job_1.status(),
+                add_job_2.status(),
+                add_job_3.status(),
+            ]
+            assert sorted([s.code.value for s in add_job_statuses]) == [
+                "IN_QUEUE",
+                "PROCESSING",
+                "PROCESSING",
+            ]
+
+            sub_job_statuses = [
+                sub_job_1.status(),
+                sub_job_2.status(),
+                sub_job_3.status(),
+            ]
+            assert [s.code.value for s in sub_job_statuses] == [
+                "PROCESSING",
+                "PROCESSING",
+                "PROCESSING",
+            ]
+
+            muldiv_job_statuses = [
+                mul_job_1.status(),
+                div_job_1.status(),
+                mul_job_2.status(),
+            ]
+            assert sorted([s.code.value for s in muldiv_job_statuses]) == [
+                "IN_QUEUE",
+                "PROCESSING",
+                "PROCESSING",
+            ]
 
     def test_every_does_not_block_queue(self):
         with gr.Blocks() as demo:
@@ -162,10 +205,10 @@ def test_every_does_not_block_queue(self):
             num.submit(lambda n: 2 * n, num, num, every=0.5)
             num2.submit(lambda n: 3 * n, num, num)
 
-        app, _, _ = demo.queue(max_size=1).launch(prevent_thread_lock=True)
+        app, local_url, _ = demo.queue(max_size=1).launch(prevent_thread_lock=True)
         test_client = TestClient(app)
 
-        client = grc.Client(f"http://localhost:{demo.server_port}")
+        client = grc.Client(local_url)
 
         job = client.submit(1, fn_index=1)
 
         for _ in range(5):
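
A minimal usage sketch of the behavior this patch documents, assuming the patch is applied. The `slow_echo` handler and component layout below are illustrative only, not part of the change; the `concurrency_limit` and `default_concurrency_limit` semantics follow the docstrings above:

```python
import time

import gradio as gr


def slow_echo(x):
    time.sleep(2)  # slow handler, so queuing behavior is observable
    return x


with gr.Blocks() as demo:
    inp = gr.Textbox()
    out = gr.Textbox()

    # No explicit concurrency_limit: stays "default" and resolves to the
    # default_concurrency_limit passed to queue() below (here, 5).
    inp.submit(slow_echo, inp, out)

    btn = gr.Button("Run")
    # An explicit per-event limit always overrides the queue-wide default.
    btn.click(slow_echo, inp, out, concurrency_limit=2)

# Raises the default for every event left at "default"; passing None here
# would instead remove the limit entirely for those events.
demo.queue(default_concurrency_limit=5)
demo.launch()
```

Setting `default_concurrency_limit=None` would let every event that stayed at `"default"` run an arbitrary number of executions simultaneously, matching the `concurrency_limit=None` semantics described in the docstrings.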
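The precedence implemented by `Queue._resolve_concurrency_limit` can be sketched the same way: an explicit `default_concurrency_limit` wins, then the `GRADIO_DEFAULT_CONCURRENCY_LIMIT` environment variable (the string "none", case-insensitive, maps to `None`; anything else is parsed with `int()`), then the built-in default of 1. The script below is hypothetical and only illustrates that resolution order:

```python
import os

# Consulted only when queue() leaves default_concurrency_limit as "not_set";
# must be set before queue() is called.
os.environ["GRADIO_DEFAULT_CONCURRENCY_LIMIT"] = "4"

import gradio as gr

with gr.Blocks() as demo:
    num = gr.Number()
    num.submit(lambda n: n + 1, num, num)

demo.queue()  # default_concurrency_limit stays "not_set" -> env var wins -> 4
# demo.queue(default_concurrency_limit=2)  # explicit parameter would win -> 2
```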