diff --git a/.changeset/eighty-teeth-greet.md b/.changeset/eighty-teeth-greet.md
new file mode 100644
index 000000000000..7e9a6c82c462
--- /dev/null
+++ b/.changeset/eighty-teeth-greet.md
@@ -0,0 +1,7 @@
+---
+"@gradio/client": patch
+"@gradio/statustracker": patch
+"gradio": patch
+---
+
+feat:Refactor queue so that there are separate queues for each concurrency id
diff --git a/client/js/src/client.ts b/client/js/src/client.ts
index 6dbf4c82e59a..f2e3170266c3 100644
--- a/client/js/src/client.ts
+++ b/client/js/src/client.ts
@@ -287,6 +287,7 @@ export function api_factory(
 		const session_hash = Math.random().toString(36).substring(2);
 		const last_status: Record<string, Status["stage"]> = {};
 		let stream_open = false;
+		let pending_stream_messages: Record<string, any[]> = {}; // Event messages may be received by the SSE stream before the initial data POST request is complete. To resolve this race condition, we store the messages in a dictionary and process them when the POST request is complete.
 		let event_stream: EventSource | null = null;
 		const event_callbacks: Record<string, (data?: any) => Promise<void>> = {};
 		let config: Config;
@@ -908,8 +909,8 @@ export function api_factory(
 					}
 
 					if (
-						status.stage === "complete" ||
-						status.stage === "error"
+						status?.stage === "complete" ||
+						status?.stage === "error"
 					) {
 						if (event_callbacks[event_id]) {
 							delete event_callbacks[event_id];
@@ -932,6 +933,12 @@ export function api_factory(
 						close_stream();
 					}
 				};
+				if (event_id in pending_stream_messages) {
+					pending_stream_messages[event_id].forEach((msg) =>
+						callback(msg)
+					);
+					delete pending_stream_messages[event_id];
+				}
 				event_callbacks[event_id] = callback;
 				if (!stream_open) {
 					open_stream();
@@ -1051,15 +1058,21 @@ export function api_factory(
 			event_stream = new EventSource(url);
 			event_stream.onmessage = async function (event) {
 				let _data = JSON.parse(event.data);
-				if (!("event_id" in _data)) {
+				const event_id = _data.event_id;
+				if (!event_id) {
 					await Promise.all(
 						Object.keys(event_callbacks).map((event_id) =>
 							event_callbacks[event_id](_data)
 						)
 					);
-					return;
+				} else if (event_callbacks[event_id]) {
+					await event_callbacks[event_id](_data);
+				} else {
+					if (!pending_stream_messages[event_id]) {
+						pending_stream_messages[event_id] = [];
+					}
+					pending_stream_messages[event_id].push(_data);
 				}
-				await event_callbacks[_data.event_id](_data);
 			};
 		}
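The client hunks above close a race between the SSE stream and the initial data `POST`: messages for an `event_id` can arrive before the request that registers the callback for that id has returned, so they are parked in `pending_stream_messages` and replayed once the callback is registered. A minimal sketch of the same buffer-and-replay pattern, written in Python for brevity (the `PendingStreamBuffer` name and its methods are invented for illustration, not part of the patch):

```python
from collections import defaultdict
from typing import Any, Callable

class PendingStreamBuffer:
    """Buffer messages that arrive before their consumer is registered."""

    def __init__(self) -> None:
        self.callbacks: dict[str, Callable[[Any], None]] = {}
        self.pending: defaultdict[str, list[Any]] = defaultdict(list)

    def on_message(self, event_id: str, message: Any) -> None:
        # Deliver directly if a callback exists; otherwise park the message.
        if event_id in self.callbacks:
            self.callbacks[event_id](message)
        else:
            self.pending[event_id].append(message)

    def register(self, event_id: str, callback: Callable[[Any], None]) -> None:
        # Replay anything that arrived early, then switch to live delivery.
        for message in self.pending.pop(event_id, []):
            callback(message)
        self.callbacks[event_id] = callback
```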
@@ -1701,8 +1714,7 @@ function handle_message(
 					message: !data.success ? data.output.error : undefined,
 					stage: data.success ? "complete" : "error",
 					code: data.code,
-					progress_data: data.progress_data,
-					eta: data.output.average_duration
+					progress_data: data.progress_data
 				},
 				data: data.success ? data.output : null
 			};
@@ -1716,7 +1728,8 @@ function handle_message(
 					code: data.code,
 					size: data.rank,
 					position: 0,
-					success: data.success
+					success: data.success,
+					eta: data.eta
 				}
 			};
 		}
diff --git a/gradio/blocks.py b/gradio/blocks.py
index 0fd06bd903ad..d68473000c13 100644
--- a/gradio/blocks.py
+++ b/gradio/blocks.py
@@ -376,7 +376,7 @@ def __init__(
         self.preprocess = preprocess
         self.postprocess = postprocess
         self.tracks_progress = tracks_progress
-        self.concurrency_limit = concurrency_limit
+        self.concurrency_limit: int | None | Literal["default"] = concurrency_limit
         self.concurrency_id = concurrency_id or str(id(fn))
         self.batch = batch
         self.max_batch_size = max_batch_size
diff --git a/gradio/data_classes.py b/gradio/data_classes.py
index 8a87f7ac5a3c..2f6341b36712 100644
--- a/gradio/data_classes.py
+++ b/gradio/data_classes.py
@@ -101,10 +101,7 @@ class InterfaceTypes(Enum):
 class Estimation(BaseModel):
     rank: Optional[int] = None
     queue_size: int
-    avg_event_process_time: Optional[float] = None
-    avg_event_concurrent_process_time: Optional[float] = None
     rank_eta: Optional[float] = None
-    queue_eta: float
 
 
 class ProgressUnit(BaseModel):
diff --git a/gradio/queueing.py b/gradio/queueing.py
index 4d6bc826eb6e..6b3f8f62392d 100644
--- a/gradio/queueing.py
+++ b/gradio/queueing.py
@@ -4,9 +4,11 @@
 import copy
 import json
 import os
+import random
 import time
 import traceback
 import uuid
+from collections import defaultdict
 from queue import Queue as ThreadQueue
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Literal
@@ -37,11 +39,13 @@ def __init__(
         fn_index: int,
         request: fastapi.Request,
         username: str | None,
+        concurrency_id: str,
     ):
         self.session_hash = session_hash
         self.fn_index = fn_index
         self.request = request
         self.username = username
+        self.concurrency_id = concurrency_id
         self._id = uuid.uuid4().hex
         self.data: PredictBody | None = None
         self.progress: Progress | None = None
@@ -49,6 +53,27 @@ def __init__(
         self.alive = True
 
+
+class EventQueue:
+    def __init__(self, concurrency_id: str, concurrency_limit: int | None):
+        self.queue: list[Event] = []
+        self.concurrency_id = concurrency_id
+        self.concurrency_limit = concurrency_limit
+        self.current_concurrency = 0
+        self.start_times_per_fn_index: defaultdict[int, set[float]] = defaultdict(set)
+
+
+class ProcessTime:
+    def __init__(self):
+        self.process_time = 0
+        self.count = 0
+        self.avg_time = 0
+
+    def add(self, time: float):
+        self.process_time += time
+        self.count += 1
+        self.avg_time = self.process_time / self.count
+
 
 class Queue:
     def __init__(
         self,
@@ -62,19 +87,16 @@ def __init__(
         self.pending_messages_per_session: LRUCache[str, ThreadQueue] = LRUCache(2000)
         self.pending_event_ids_session: dict[str, set[str]] = {}
         self.pending_message_lock = safe_get_lock()
-        self.event_queue: list[Event] = []
-        self.awaiting_data_events: dict[str, Event] = {}
+        self.event_queue_per_concurrency_id: dict[str, EventQueue] = {}
         self.stopped = False
         self.max_thread_count = concurrency_count
         self.update_intervals = update_intervals
         self.active_jobs: list[None | list[Event]] = []
         self.delete_lock = safe_get_lock()
         self.server_app = None
-        self.duration_history_total = 0
-        self.duration_history_count = 0
-        self.avg_process_time = 0
-        self.avg_concurrent_process_time = None
-        self.queue_duration = 1
+        self.process_time_per_fn_index: defaultdict[int, ProcessTime] = defaultdict(
+            ProcessTime
+        )
         self.live_updates = live_updates
         self.sleep_when_free = 0.05
         self.progress_update_sleep_when_free = 0.1
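Instead of one global `avg_process_time`, the queue now keeps a running mean per `fn_index` in a `defaultdict`. A small sketch (not part of the patch) of how those averages accumulate, and why a function that has never completed has no history:

```python
from collections import defaultdict

class ProcessTime:  # mirrors the class added above
    def __init__(self):
        self.process_time = 0.0
        self.count = 0
        self.avg_time = 0.0

    def add(self, time: float):
        # Incremental mean: total time divided by number of completed runs.
        self.process_time += time
        self.count += 1
        self.avg_time = self.process_time / self.count

process_time_per_fn_index: defaultdict[int, ProcessTime] = defaultdict(ProcessTime)
process_time_per_fn_index[0].add(2.0)  # first run of fn 0 took 2s
process_time_per_fn_index[0].add(4.0)  # second run took 4s
assert process_time_per_fn_index[0].avg_time == 3.0
# Membership tests do not create entries, so fn 1 still has no history;
# this is what later makes the queue report rank_eta=None for it.
assert 1 not in process_time_per_fn_index
```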
@@ -85,25 +107,31 @@ def __init__(
         self.default_concurrency_limit = self._resolve_concurrency_limit(
             default_concurrency_limit
         )
-        self.concurrency_limit_per_concurrency_id = {}
 
     def start(self):
         self.active_jobs = [None] * self.max_thread_count
         for block_fn in self.block_fns:
-            concurrency_limit = (
-                self.default_concurrency_limit
-                if block_fn.concurrency_limit == "default"
-                else block_fn.concurrency_limit
-            )
-            if concurrency_limit is not None:
-                self.concurrency_limit_per_concurrency_id[
-                    block_fn.concurrency_id
-                ] = min(
-                    self.concurrency_limit_per_concurrency_id.get(
-                        block_fn.concurrency_id, concurrency_limit
-                    ),
-                    concurrency_limit,
+            concurrency_id = block_fn.concurrency_id
+            concurrency_limit: int | None
+            if block_fn.concurrency_limit == "default":
+                concurrency_limit = self.default_concurrency_limit
+            else:
+                concurrency_limit = block_fn.concurrency_limit
+            if concurrency_id not in self.event_queue_per_concurrency_id:
+                self.event_queue_per_concurrency_id[concurrency_id] = EventQueue(
+                    concurrency_id, concurrency_limit
                 )
+            elif (
+                concurrency_limit is not None
+            ):  # Update concurrency limit if it is lower than existing limit
+                existing_event_queue = self.event_queue_per_concurrency_id[
+                    concurrency_id
+                ]
+                if (
+                    existing_event_queue.concurrency_limit is None
+                    or concurrency_limit < existing_event_queue.concurrency_limit
+                ):
+                    existing_event_queue.concurrency_limit = concurrency_limit
 
         run_coro_in_background(self.start_processing)
         run_coro_in_background(self.start_progress_updates)
@@ -119,11 +147,15 @@ def send_message(
         message_type: str,
         data: dict | None = None,
     ):
+        if not event.alive:
+            return
         data = {} if data is None else data
         messages = self.pending_messages_per_session[event.session_hash]
         messages.put_nowait({"msg": message_type, "event_id": event._id, **data})
 
-    def _resolve_concurrency_limit(self, default_concurrency_limit):
+    def _resolve_concurrency_limit(
+        self, default_concurrency_limit: int | None | Literal["not_set"]
+    ) -> int | None:
         """
         Handles the logic of resolving the default_concurrency_limit as this can be specified via a combination
         of the `default_concurrency_limit` parameter of the `Blocks.queue()` or the `GRADIO_DEFAULT_CONCURRENCY_LIMIT`
@@ -143,6 +175,12 @@ def _resolve_concurrency_limit(self, default_concurrency_limit):
         else:
             return 1
 
+    def __len__(self):
+        total_len = 0
+        for event_queue in self.event_queue_per_concurrency_id.values():
+            total_len += len(event_queue.queue)
+        return total_len
+
     async def push(
         self, body: PredictBody, request: fastapi.Request, username: str | None
     ) -> tuple[bool, str]:
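`start()` builds one `EventQueue` per `concurrency_id`; when several functions share an id, the strictest (lowest) numeric limit wins, and `None` only stands for "unlimited" until a numeric limit shows up. A condensed sketch of just that merging rule (`BlockFn` here is a stand-in dataclass, not gradio's real block-function object):

```python
from dataclasses import dataclass

@dataclass
class BlockFn:
    concurrency_id: str
    concurrency_limit: int | None  # "default" assumed already resolved here

def effective_limits(block_fns: list[BlockFn]) -> dict[str, int | None]:
    limits: dict[str, int | None] = {}
    for fn in block_fns:
        if fn.concurrency_id not in limits:
            limits[fn.concurrency_id] = fn.concurrency_limit
        elif fn.concurrency_limit is not None:
            existing = limits[fn.concurrency_id]
            # Lower numeric limits win; None means "unlimited so far".
            if existing is None or fn.concurrency_limit < existing:
                limits[fn.concurrency_id] = fn.concurrency_limit
    return limits

fns = [BlockFn("gpu", 5), BlockFn("gpu", 2), BlockFn("io", None)]
assert effective_limits(fns) == {"gpu": 2, "io": None}
```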
@@ -150,14 +188,19 @@ async def push(
             return False, "No session hash provided."
         if body.fn_index is None:
             return False, "No function index provided."
-        queue_len = len(self.event_queue)
-        if self.max_size is not None and queue_len >= self.max_size:
+        if self.max_size is not None and len(self) >= self.max_size:
             return (
                 False,
-                f"Queue is full. Max size is {self.max_size} and size is {queue_len}.",
+                f"Queue is full. Max size is {self.max_size} and size is {len(self)}.",
             )
-        event = Event(body.session_hash, body.fn_index, request, username)
+        event = Event(
+            body.session_hash,
+            body.fn_index,
+            request,
+            username,
+            self.block_fns[body.fn_index].concurrency_id,
+        )
         event.data = body
         async with self.pending_message_lock:
             if body.session_hash not in self.pending_messages_per_session:
@@ -165,10 +208,10 @@ async def push(
             if body.session_hash not in self.pending_event_ids_session:
                 self.pending_event_ids_session[body.session_hash] = set()
             self.pending_event_ids_session[body.session_hash].add(event._id)
-        self.event_queue.append(event)
+        event_queue = self.event_queue_per_concurrency_id[event.concurrency_id]
+        event_queue.queue.append(event)
 
-        estimation = self.get_estimation()
-        await self.send_estimation(event, estimation, queue_len)
+        self.broadcast_estimations(event.concurrency_id, len(event_queue.queue) - 1)
         return True, event._id
@@ -187,88 +230,73 @@ def get_active_worker_count(self) -> int:
                 count += 1
         return count
 
-    def get_events_in_batch(self) -> tuple[list[Event] | None, bool]:
-        if not self.event_queue:
-            return None, False
-
-        worker_count_per_concurrency_id = {}
-        for job in self.active_jobs:
-            if job is not None:
-                for event in job:
-                    concurrency_id = self.block_fns[event.fn_index].concurrency_id
-                    worker_count_per_concurrency_id[concurrency_id] = (
-                        worker_count_per_concurrency_id.get(concurrency_id, 0) + 1
-                    )
-
-        events = []
-        batch = False
-        for index, event in enumerate(self.event_queue):
-            block_fn = self.block_fns[event.fn_index]
-            concurrency_id = block_fn.concurrency_id
-            concurrency_limit = self.concurrency_limit_per_concurrency_id.get(
-                concurrency_id, None
-            )
-            existing_worker_count = worker_count_per_concurrency_id.get(
-                concurrency_id, 0
-            )
-            if concurrency_limit is None or existing_worker_count < concurrency_limit:
+    def get_events(self) -> tuple[list[Event], bool, str] | None:
+        concurrency_ids = list(self.event_queue_per_concurrency_id.keys())
+        random.shuffle(concurrency_ids)
+        for concurrency_id in concurrency_ids:
+            event_queue = self.event_queue_per_concurrency_id[concurrency_id]
+            if len(event_queue.queue) and (
+                event_queue.concurrency_limit is None
+                or event_queue.current_concurrency < event_queue.concurrency_limit
+            ):
+                first_event = event_queue.queue[0]
+                block_fn = self.block_fns[first_event.fn_index]
+                events = [first_event]
                 batch = block_fn.batch
                 if batch:
-                    batch_size = block_fn.max_batch_size
-                    if concurrency_limit is None:
-                        remaining_worker_count = batch_size - 1
-                    else:
-                        remaining_worker_count = (
-                            concurrency_limit - existing_worker_count
-                        )
-                    rest_of_batch = [
+                    events += [
                         event
-                        for event in self.event_queue[index:]
-                        if event.fn_index == event.fn_index
-                    ][: min(batch_size - 1, remaining_worker_count)]
-                    events = [event] + rest_of_batch
-                else:
-                    events = [event]
-                break
+                        for event in event_queue.queue[1:]
+                        if event.fn_index == first_event.fn_index
+                    ][: block_fn.max_batch_size - 1]
 
-        for event in events:
-            self.event_queue.remove(event)
+                for event in events:
+                    event_queue.queue.remove(event)
 
-        return events, batch
+                return events, batch, concurrency_id
 
     async def start_processing(self) -> None:
-        while not self.stopped:
-            if not self.event_queue:
-                await asyncio.sleep(self.sleep_when_free)
-                continue
-
-            if None not in self.active_jobs:
-                await asyncio.sleep(self.sleep_when_free)
-                continue
-            # Using mutex to avoid editing a list in use
-            async with self.delete_lock:
-                events, batch = self.get_events_in_batch()
-
-            if events:
-                self.active_jobs[self.active_jobs.index(None)] = events
-                process_event_task = run_coro_in_background(
-                    self.process_events, events, batch
-                )
-                set_task_name(
-                    process_event_task,
-                    events[0].session_hash,
-                    events[0].fn_index,
-                    batch,
-                )
-
-                self._asyncio_tasks.append(process_event_task)
-                if self.live_updates:
-                    broadcast_live_estimations_task = run_coro_in_background(
-                        self.broadcast_estimations
+        try:
+            while not self.stopped:
+                if len(self) == 0:
+                    await asyncio.sleep(self.sleep_when_free)
+                    continue
+
+                if None not in self.active_jobs:
+                    await asyncio.sleep(self.sleep_when_free)
+                    continue
+
+                # Using mutex to avoid editing a list in use
+                async with self.delete_lock:
+                    event_batch = self.get_events()
+
+                if event_batch:
+                    events, batch, concurrency_id = event_batch
+                    self.active_jobs[self.active_jobs.index(None)] = events
+                    event_queue = self.event_queue_per_concurrency_id[concurrency_id]
+                    event_queue.current_concurrency += 1
+                    start_time = time.time()
+                    event_queue.start_times_per_fn_index[events[0].fn_index].add(
+                        start_time
                     )
-                    self._asyncio_tasks.append(broadcast_live_estimations_task)
-            else:
-                await asyncio.sleep(self.sleep_when_free)
+                    process_event_task = run_coro_in_background(
+                        self.process_events, events, batch, start_time
+                    )
+                    set_task_name(
+                        process_event_task,
+                        events[0].session_hash,
+                        events[0].fn_index,
+                        batch,
+                    )
+
+                    self._asyncio_tasks.append(process_event_task)
+                    if self.live_updates:
+                        self.broadcast_estimations(concurrency_id)
+                else:
+                    await asyncio.sleep(self.sleep_when_free)
+        finally:
+            self.stopped = True
+            self._cancel_asyncio_tasks()
 
     async def start_progress_updates(self) -> None:
         """
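`get_events` replaces the old single-list scan (including its accidental `event.fn_index == event.fn_index` self-comparison, visible in the removed lines) with a per-queue policy: visit concurrency ids in random order so none is starved, skip queues whose workers are saturated, and take the head event plus, for batched functions, up to `max_batch_size - 1` more events with the same `fn_index`. The selection rule in isolation, with plain dicts standing in for `EventQueue` (sketch only):

```python
import random

def pick_queue(queues: dict[str, dict]) -> str | None:
    """Return a concurrency_id with waiting events and a free worker slot."""
    ids = list(queues)
    random.shuffle(ids)  # random visiting order so no queue is starved
    for cid in ids:
        q = queues[cid]
        under_limit = q["limit"] is None or q["running"] < q["limit"]
        if q["waiting"] and under_limit:
            return cid
    return None

queues = {
    "gpu": {"waiting": ["e1", "e2"], "running": 2, "limit": 2},  # saturated
    "io": {"waiting": ["e3"], "running": 0, "limit": None},      # unlimited
}
assert pick_queue(queues) == "io"  # "gpu" is at its limit, so "io" is picked
```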
@@ -345,14 +373,17 @@ async def clean_events(
                 if job.session_hash == session_hash or job._id == event_id:
                     job.alive = False
 
-        events_to_remove = []
-        for event in self.event_queue:
-            if event.session_hash == session_hash or event._id == event_id:
-                events_to_remove.append(event)
+        async with self.delete_lock:
+            events_to_remove: list[Event] = []
+            for event_queue in self.event_queue_per_concurrency_id.values():
+                for event in event_queue.queue:
+                    if event.session_hash == session_hash or event._id == event_id:
+                        events_to_remove.append(event)
 
-        for event in events_to_remove:
-            async with self.delete_lock:
-                self.event_queue.remove(event)
+            for event in events_to_remove:
+                self.event_queue_per_concurrency_id[event.concurrency_id].queue.remove(
+                    event
+                )
 
     async def notify_clients(self) -> None:
         """
@@ -360,66 +391,65 @@ async def notify_clients(self) -> None:
         Notify clients about events statuses in the queue periodically.
         """
         while not self.stopped:
             await asyncio.sleep(self.update_intervals)
-            if self.event_queue:
-                await self.broadcast_estimations()
-
-    async def broadcast_estimations(self) -> None:
-        estimation = self.get_estimation()
-        # Send all messages concurrently
-        await asyncio.gather(
-            *[
-                self.send_estimation(event, estimation, rank)
-                for rank, event in enumerate(self.event_queue)
-            ]
-        )
+            if len(self) > 0:
+                for concurrency_id in self.event_queue_per_concurrency_id:
+                    self.broadcast_estimations(concurrency_id)
 
-    async def send_estimation(
-        self, event: Event, estimation: Estimation, rank: int
-    ) -> Estimation:
-        """
-        Send estimation about ETA to the client.
-
-        Parameters:
-            event:
-            estimation:
-            rank:
-        """
-        estimation.rank = rank
+    def broadcast_estimations(
+        self, concurrency_id: str, after: int | None = None
+    ) -> None:
+        wait_so_far = 0
+        event_queue = self.event_queue_per_concurrency_id[concurrency_id]
+        time_till_available_worker: float | None = 0
+
+        if event_queue.current_concurrency == event_queue.concurrency_limit:
+            expected_end_times = []
+            for fn_index, start_times in event_queue.start_times_per_fn_index.items():
+                if fn_index not in self.process_time_per_fn_index:
+                    time_till_available_worker = None
+                    break
+                process_time = self.process_time_per_fn_index[fn_index].avg_time
+                expected_end_times += [
+                    start_time + process_time for start_time in start_times
+                ]
+            if time_till_available_worker is not None and len(expected_end_times) > 0:
+                time_of_first_completion = min(expected_end_times)
+                time_till_available_worker = max(
+                    time_of_first_completion - time.time(), 0
+                )
 
-        if self.avg_concurrent_process_time is not None:
-            estimation.rank_eta = (
-                estimation.rank * self.avg_concurrent_process_time
-                + self.avg_process_time
+        for rank, event in enumerate(event_queue.queue):
+            process_time_for_fn = (
+                self.process_time_per_fn_index[event.fn_index].avg_time
+                if event.fn_index in self.process_time_per_fn_index
+                else None
+            )
+            rank_eta = (
+                process_time_for_fn + wait_so_far + time_till_available_worker
+                if process_time_for_fn is not None
+                and wait_so_far is not None
+                and time_till_available_worker is not None
+                else None
             )
-            if None not in self.active_jobs:
-                # Add estimated amount of time for a thread to get empty
-                estimation.rank_eta += self.avg_concurrent_process_time
-        self.send_message(event, ServerMessage.estimation, estimation.model_dump())
-        return estimation
-
-    def update_estimation(self, duration: float) -> None:
-        """
-        Update estimation by last x element's average duration.
-
-        Parameters:
-            duration:
-        """
-        self.duration_history_total += duration
-        self.duration_history_count += 1
-        self.avg_process_time = (
-            self.duration_history_total / self.duration_history_count
-        )
-        self.avg_concurrent_process_time = self.avg_process_time / min(
-            self.max_thread_count, self.duration_history_count
-        )
-        self.queue_duration = self.avg_concurrent_process_time * len(self.event_queue)
+            if after is None or rank >= after:
+                self.send_message(
+                    event,
+                    ServerMessage.estimation,
+                    Estimation(
+                        rank=rank, rank_eta=rank_eta, queue_size=len(event_queue.queue)
+                    ).model_dump(),
+                )
+            if event_queue.concurrency_limit is None:
+                wait_so_far = 0
+            elif wait_so_far is not None and process_time_for_fn is not None:
+                wait_so_far += process_time_for_fn / event_queue.concurrency_limit
+            else:
+                wait_so_far = None
 
-    def get_estimation(self) -> Estimation:
+    def get_status(self) -> Estimation:
         return Estimation(
-            queue_size=len(self.event_queue),
-            avg_event_process_time=self.avg_process_time,
-            avg_event_concurrent_process_time=self.avg_concurrent_process_time,
-            queue_eta=self.queue_duration,
+            queue_size=len(self),
         )
 
     async def call_prediction(self, events: list[Event], batch: bool):
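The new per-queue ETA is built from three parts: the event's own expected run time (`process_time_for_fn`), the wait contributed by everyone queued ahead of it (`wait_so_far`, which grows by `avg / limit` per event because `limit` workers drain the queue in parallel), and the time until a worker frees up when the queue is saturated. If any ingredient is unknown, the estimate collapses to `None` rather than guessing. A worked example of the arithmetic (not part of the patch), for a saturated queue with `concurrency_limit=2` and a function averaging 4s per run:

```python
limit = 2
avg = 4.0                         # average process time for this fn_index
time_till_available_worker = 1.0  # earliest running job should finish in ~1s

wait_so_far = 0.0
etas = []
for rank in range(3):
    etas.append(avg + wait_so_far + time_till_available_worker)
    wait_so_far += avg / limit    # each queued event delays later ones by avg/limit

assert etas == [5.0, 7.0, 9.0]    # ranks 0, 1, 2
```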
@@ -484,20 +514,30 @@ async def call_prediction(self, events: list[Event], batch: bool):
 
         return response_json
 
-    async def process_events(self, events: list[Event], batch: bool) -> None:
+    async def process_events(
+        self, events: list[Event], batch: bool, begin_time: float
+    ) -> None:
         awake_events: list[Event] = []
+        fn_index = events[0].fn_index
         try:
             for event in events:
-                self.send_message(event, ServerMessage.process_starts)
-                awake_events.append(event)
+                if event.alive:
+                    self.send_message(
+                        event,
+                        ServerMessage.process_starts,
+                        {
+                            "eta": self.process_time_per_fn_index[fn_index].avg_time
+                            if fn_index in self.process_time_per_fn_index
+                            else None
+                        },
+                    )
+                    awake_events.append(event)
             if not awake_events:
                 return
-            begin_time = time.time()
             try:
                 response = await self.call_prediction(awake_events, batch)
                 err = None
             except Exception as e:
-                traceback.print_exc()
                 response = None
                 err = e
             for event in awake_events:
@@ -568,10 +608,17 @@ async def process_events(self, events: list[Event], batch: bool) -> None:
                     )
             end_time = time.time()
             if response is not None:
-                self.update_estimation(end_time - begin_time)
+                self.process_time_per_fn_index[events[0].fn_index].add(
+                    end_time - begin_time
+                )
         except Exception as e:
             traceback.print_exc()
         finally:
+            event_queue = self.event_queue_per_concurrency_id[events[0].concurrency_id]
+            event_queue.current_concurrency -= 1
+            start_times = event_queue.start_times_per_fn_index[fn_index]
+            if begin_time in start_times:
+                start_times.remove(begin_time)
             try:
                 self.active_jobs[self.active_jobs.index(events)] = None
             except ValueError:
diff --git a/gradio/routes.py b/gradio/routes.py
index 139b6dd38f1c..293dfec69aa0 100644
--- a/gradio/routes.py
+++ b/gradio/routes.py
@@ -673,6 +673,12 @@ async def queue_join(
         if blocks._queue.server_app is None:
            blocks._queue.set_server_app(app)
 
+        if blocks._queue.stopped:
+            raise HTTPException(
+                status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
+                detail="Queue is stopped.",
+            )
+
         success, event_id = await blocks._queue.push(body, request, username)
         if not success:
             status_code = (
@@ -702,7 +708,7 @@ def component_server(body: ComponentServerBody):
         response_model=Estimation,
     )
     async def get_queue_status():
-        return app.get_blocks()._queue.get_estimation()
+        return app.get_blocks()._queue.get_status()
 
     @app.get("/upload_progress")
     def get_upload_progress(upload_id: str, request: fastapi.Request):
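Two route-level consequences of the refactor: joining a stopped queue now fails fast with a 503 instead of accepting an event that will never run, and the status endpoint returns the slimmed-down `Estimation` (the global `avg_event_process_time`/`queue_eta` fields are gone, since averages now live per `fn_index`). Assuming a default local address and gradio's usual `/queue/status` path, a status probe would look roughly like:

```python
import httpx  # any HTTP client works; httpx is just an example

resp = httpx.get("http://127.0.0.1:7860/queue/status")
resp.raise_for_status()
print(resp.json())  # e.g. {"rank": None, "queue_size": 3, "rank_eta": None}
```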
diff --git a/js/statustracker/static/index.svelte b/js/statustracker/static/index.svelte
index 3e8d5577638f..24d23288c08c 100644
--- a/js/statustracker/static/index.svelte
+++ b/js/statustracker/static/index.svelte
@@ -53,7 +53,6 @@
 	export let i18n: I18nFormatter;
 	export let eta: number | null = null;
-	export let queue = false;
 	export let queue_position: number | null;
 	export let queue_size: number | null;
 	export let status: "complete" | "pending" | "error" | "generating";
@@ -75,6 +74,7 @@
 	let timer_start = 0;
 	let timer_diff = 0;
 	let old_eta: number | null = null;
+	let eta_from_start: number | null = null;
 	let message_visible = false;
 	let eta_level: number | null = 0;
 	let progress_level: (number | undefined)[] | null = null;
@@ -83,9 +83,9 @@
 	let show_eta_bar = true;
 
 	$: eta_level =
-		eta === null || eta <= 0 || !timer_diff
+		eta_from_start === null || eta_from_start <= 0 || !timer_diff
 			? null
-			: Math.min(timer_diff / eta, 1);
+			: Math.min(timer_diff / eta_from_start, 1);
 	$: if (progress != null) {
 		show_eta_bar = false;
 	}
@@ -119,6 +119,7 @@
 	}
 
 	const start_timer = (): void => {
+		eta = old_eta = formatted_eta = null;
 		timer_start = performance.now();
 		timer_diff = 0;
 		_timer = true;
@@ -134,6 +135,7 @@
 
 	function stop_timer(): void {
 		timer_diff = 0;
+		eta = old_eta = formatted_eta = null;
 		if (!_timer) return;
 		_timer = false;
@@ -160,11 +162,10 @@
 	$: {
 		if (eta === null) {
 			eta = old_eta;
-		} else if (queue) {
-			eta = (performance.now() - timer_start) / 1000 + eta;
 		}
-		if (eta != null) {
-			formatted_eta = eta.toFixed(1);
+		if (eta != null && old_eta !== eta) {
+			eta_from_start = (performance.now() - timer_start) / 1000 + eta;
+			formatted_eta = eta_from_start.toFixed(1);
 			old_eta = eta;
 		}
 	}
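On the front end, `eta` now always arrives as a duration (with `process_starts` and with queue estimations), so the tracker converts each fresh estimate into `eta_from_start`, the expected total measured from when the timer started; the progress bar level is then elapsed time over that total. The arithmetic, transliterated from the Svelte block above into Python (sketch only; seconds here rather than `performance.now()` milliseconds):

```python
timer_start = 0.0
received_at = 1.5   # a fresh eta arrives 1.5s after the run started
eta = 4.0           # server now expects ~4s of processing

# (performance.now() - timer_start) / 1000 + eta in the component:
eta_from_start = (received_at - timer_start) + eta   # 5.5s total from start

timer_diff = 2.75   # elapsed so far
eta_level = min(timer_diff / eta_from_start, 1)      # progress bar fill level
assert eta_level == 0.5                              # bar sits at 50%
formatted_eta = f"{eta_from_start:.1f}"              # "5.5" shown next to timer
```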