From 0d2c57583912fc262c13fe8f60ded2e70adff36f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=96mer=20Faruk=20=C3=96zdemir?= Date: Sun, 7 Aug 2022 21:37:02 -0400 Subject: [PATCH] Queue tweaks (#1909) * tweaks on estimation payload * Queue keep ws connections open (#1910) * 1. keep ws connections open after the event process is completed 2. do not send estimations periodically if live queue updates is open * fix calculation * 1. tweaks on event_queue --- gradio/blocks.py | 49 ++++++++++++++------------ gradio/event_queue.py | 80 ++++++++++++++++++++++++++----------------- 2 files changed, 77 insertions(+), 52 deletions(-) diff --git a/gradio/blocks.py b/gradio/blocks.py index d03eab28fc9a..e755bf142253 100644 --- a/gradio/blocks.py +++ b/gradio/blocks.py @@ -727,6 +727,32 @@ def clear(self): self.children = [] return self + def configure_queue( + self, + live_queue_updates: bool = True, + concurrency_count: int = 1, + data_gathering_start: int = 30, + update_intervals: int = 5, + duration_history_size: int = 100, + ): + """ + Parameters: + live_queue_updates: + If True, Queue will send estimations to clients whenever a job is finished. + If False will send estimations periodically, might be preferred when events have very short process-times. + concurrency_count: Number of max number concurrent jobs inside the Queue. + data_gathering_start: If Rank Tuple[FastAPI, str, str]: """ @@ -783,11 +804,6 @@ def launch( ssl_certfile: If a path to a file is provided, will use this as the signed certificate for https. Needs to be provided if ssl_keyfile is provided. ssl_keyfile_password: If a password is provided, will use this with the ssl certificate for https. quiet: If True, suppresses most print statements. - live_queue_updates: If True, Queue will send estimations whenever a job is finished as well. - queue_concurrency_count: Number of max number concurrent jobs inside the Queue. - data_gathering_start: If Rank None: Closes the Interface that was launched and frees the port. """ try: - from gradio.event_queue import Queue - if self.enable_queue: - Queue.close() + event_queue.Queue.close() self.server.close() self.is_running = False if verbose: diff --git a/gradio/event_queue.py b/gradio/event_queue.py index 0180f14fcccc..3f3b5b0ed170 100644 --- a/gradio/event_queue.py +++ b/gradio/event_queue.py @@ -6,7 +6,6 @@ from typing import List, Optional import fastapi -from fastapi import WebSocketDisconnect from pydantic import BaseModel from gradio.utils import Request, run_coro_in_background @@ -14,12 +13,12 @@ class Estimation(BaseModel): msg: Optional[str] = "estimation" - rank: Optional[int] = -1 # waiting duration for the xth rank: - # (rank-1) / queue_size * avg_concurrent_process_time + avg_concurrent_process_time + rank: Optional[int] = -1 queue_size: int - avg_process_time: float # average duration for an event to get processed after the queue is finished - avg_concurrent_process_time: float # average process duration divided by max_thread_count - queue_duration: int # total_queue_duration = avg_concurrent_process_time * queue_size + avg_event_process_time: float # TODO(faruk): might be removed if not used by frontend in the future + avg_event_concurrent_process_time: float # TODO(faruk): might be removed if not used by frontend in the future + rank_eta: Optional[int] = -1 + queue_eta: int class Queue: @@ -44,20 +43,15 @@ class Queue: @classmethod def configure_queue( cls, - server_path: str, - live_queue_updates=True, - queue_concurrency_count: int = 1, - data_gathering_start: int = 30, - update_intervals: int = 5, - duration_history_size=100, + live_queue_updates: bool, + queue_concurrency_count: int, + data_gathering_start: int, + update_intervals: int, + duration_history_size: int, ): """ - See Blocks.launch() docstring for the explanation of parameters. + See Blocks.configure_queue() docstring for the explanation of parameters. """ - - if live_queue_updates is False and update_intervals == 5: - update_intervals = 10 - cls.SERVER_PATH = server_path cls.LIVE_QUEUE_UPDATES = live_queue_updates cls.MAX_THREAD_COUNT = queue_concurrency_count cls.DATA_GATHERING_STARTS_AT = data_gathering_start @@ -65,12 +59,17 @@ def configure_queue( cls.DURATION_HISTORY_SIZE = duration_history_size cls.ACTIVE_JOBS = [None] * cls.MAX_THREAD_COUNT + @classmethod + def set_url(cls, url: str): + cls.SERVER_PATH = url + @classmethod async def init( cls, ) -> None: - run_coro_in_background(Queue.notify_clients) run_coro_in_background(Queue.start_processing) + if not cls.LIVE_QUEUE_UPDATES: + run_coro_in_background(Queue.notify_clients) @classmethod def close(cls): @@ -80,6 +79,14 @@ def close(cls): def resume(cls): cls.STOP = False + @classmethod + def get_active_worker_count(cls) -> int: + count = 0 + for worker in cls.ACTIVE_JOBS: + if worker is not None: + count += 1 + return count + # TODO: Remove prints @classmethod async def start_processing(cls) -> None: @@ -140,7 +147,7 @@ async def gather_event_data(cls, event: Event) -> None: """ Gather data for the event - Args: + Parameters: event: """ if not event.data: @@ -158,13 +165,10 @@ async def notify_clients(cls) -> None: Notify clients about events statuses in the queue periodically. """ while not cls.STOP: - # TODO: if live update is true and queue size does not change, dont notify the clients await asyncio.sleep(cls.UPDATE_INTERVALS) print(f"Event Queue: {cls.EVENT_QUEUE}") - if not cls.EVENT_QUEUE: - continue - - await cls.broadcast_estimation() + if cls.EVENT_QUEUE: + await cls.broadcast_estimation() @classmethod async def broadcast_estimation(cls) -> None: @@ -184,12 +188,15 @@ async def send_estimation( """ Send estimation about ETA to the client. - Args: + Parameters: event: estimation: - rank + rank: """ estimation.rank = rank + estimation.rank_eta = round( + estimation.rank * cls.AVG_CONCURRENT_PROCESS_TIME + cls.AVG_PROCESS_TIME + ) client_awake = await event.send_message(estimation.dict()) if not client_awake: await cls.clean_event(event) @@ -199,7 +206,7 @@ def update_estimation(cls, duration: float) -> None: """ Update estimation by last x element's average duration. - Args: + Parameters: duration: """ cls.DURATION_HISTORY.append(duration) @@ -218,9 +225,9 @@ def update_estimation(cls, duration: float) -> None: def get_estimation(cls) -> Estimation: return Estimation( queue_size=len(cls.EVENT_QUEUE), - avg_process_time=cls.AVG_PROCESS_TIME, - avg_concurrent_process_time=cls.AVG_CONCURRENT_PROCESS_TIME, - queue_duration=cls.QUEUE_DURATION, + avg_event_process_time=cls.AVG_PROCESS_TIME, + avg_event_concurrent_process_time=cls.AVG_CONCURRENT_PROCESS_TIME, + queue_eta=cls.QUEUE_DURATION, ) @classmethod @@ -243,9 +250,20 @@ async def process_event(cls, event: Event) -> None: {"msg": "process_completed", "output": response.json} ) if client_awake: - await event.disconnect() + run_coro_in_background(cls.wait_in_inactive, event) cls.clean_job(event) + @classmethod + async def wait_in_inactive(cls, event: Event) -> None: + """ + Waits the event until it receives the join_back message or loses ws connection. + """ + event.data = None + client_awake = await event.get_message() + if client_awake: + if client_awake["msg"] == "join_back": + cls.EVENT_QUEUE.append(event) + class Event: def __init__(self, websocket: fastapi.WebSocket):