Skip to content

Commit

Permalink
Queue tweaks (#1909)
Browse files Browse the repository at this point in the history
* tweaks on estimation payload

* Queue keep ws connections open (#1910)

* 1. keep ws connections open after the event process is completed
2. do not send estimations periodically if live queue updates are enabled

* fix calculation

* 1. tweaks on event_queue
  • Loading branch information
omerXfaruq committed Aug 8, 2022
1 parent 0e9d304 commit 0d2c575
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 52 deletions.
49 changes: 28 additions & 21 deletions gradio/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -727,6 +727,32 @@ def clear(self):
self.children = []
return self

def configure_queue(
    self,
    live_queue_updates: bool = True,
    concurrency_count: int = 1,
    data_gathering_start: int = 30,
    update_intervals: int = 5,
    duration_history_size: int = 100,
):
    """
    Forward queue configuration to the global event queue.

    Parameters:
        live_queue_updates:
            If True, an updated estimation is pushed to every client each time a job
            finishes. If False, estimations are sent only on a fixed interval, which
            can be preferable when individual events have very short process times.
        concurrency_count: maximum number of jobs processed concurrently by the queue.
        data_gathering_start: rank threshold below which the queue pre-fetches input
            data from the client; lower it when clients may upload large payloads
            (e.g. video) to keep memory usage bounded.
        update_intervals: seconds between periodic estimation broadcasts, used when
            live_queue_updates is False.
        duration_history_size: number of recent event durations kept for the running
            duration estimate.
    """
    settings = (
        live_queue_updates,
        concurrency_count,
        data_gathering_start,
        update_intervals,
        duration_history_size,
    )
    event_queue.Queue.configure_queue(*settings)

def launch(
self,
inline: bool = None,
Expand All @@ -750,11 +776,6 @@ def launch(
ssl_certfile: Optional[str] = None,
ssl_keyfile_password: Optional[str] = None,
quiet: bool = False,
live_queue_updates=True,
queue_concurrency_count: int = 1,
data_gathering_start: int = 30,
update_intervals: int = 5,
duration_history_size=100,
_frontend: bool = True,
) -> Tuple[FastAPI, str, str]:
"""
Expand Down Expand Up @@ -783,11 +804,6 @@ def launch(
ssl_certfile: If a path to a file is provided, will use this as the signed certificate for https. Needs to be provided if ssl_keyfile is provided.
ssl_keyfile_password: If a password is provided, will use this with the ssl certificate for https.
quiet: If True, suppresses most print statements.
live_queue_updates: If True, Queue will send estimations whenever a job is finished as well.
queue_concurrency_count: Number of max number concurrent jobs inside the Queue.
data_gathering_start: If Rank<Parameter, Queue asks for data from the client. You may make it smaller if users can send very big data(video or such) to not reach memory overflow.
update_intervals: Queue will send estimations to the clients using intervals determined by update_intervals.
duration_history_size: Queue duration estimation calculation window size.
Returns:
app: FastAPI app object that is running the demo
local_url: Locally accessible link to the demo
Expand Down Expand Up @@ -853,14 +869,7 @@ def reverse(text):
raise ValueError(
"Cannot queue with encryption or authentication enabled."
)
event_queue.Queue.configure_queue(
self.local_url,
live_queue_updates,
queue_concurrency_count,
data_gathering_start,
update_intervals,
duration_history_size,
)
event_queue.Queue.set_url(self.local_url)
# Cannot run async functions in background other than app's scope.
# Workaround by triggering the app endpoint
requests.get(f"{self.local_url}start/queue")
Expand Down Expand Up @@ -1026,10 +1035,8 @@ def close(self, verbose: bool = True) -> None:
Closes the Interface that was launched and frees the port.
"""
try:
from gradio.event_queue import Queue

if self.enable_queue:
Queue.close()
event_queue.Queue.close()
self.server.close()
self.is_running = False
if verbose:
Expand Down
80 changes: 49 additions & 31 deletions gradio/event_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,20 +6,19 @@
from typing import List, Optional

import fastapi
from fastapi import WebSocketDisconnect
from pydantic import BaseModel

from gradio.utils import Request, run_coro_in_background


class Estimation(BaseModel):
msg: Optional[str] = "estimation"
rank: Optional[int] = -1 # waiting duration for the xth rank:
# (rank-1) / queue_size * avg_concurrent_process_time + avg_concurrent_process_time
rank: Optional[int] = -1
queue_size: int
avg_process_time: float # average duration for an event to get processed after the queue is finished
avg_concurrent_process_time: float # average process duration divided by max_thread_count
queue_duration: int # total_queue_duration = avg_concurrent_process_time * queue_size
avg_event_process_time: float # TODO(faruk): might be removed if not used by frontend in the future
avg_event_concurrent_process_time: float # TODO(faruk): might be removed if not used by frontend in the future
rank_eta: Optional[int] = -1
queue_eta: int


class Queue:
Expand All @@ -44,33 +43,33 @@ class Queue:
@classmethod
def configure_queue(
cls,
server_path: str,
live_queue_updates=True,
queue_concurrency_count: int = 1,
data_gathering_start: int = 30,
update_intervals: int = 5,
duration_history_size=100,
live_queue_updates: bool,
queue_concurrency_count: int,
data_gathering_start: int,
update_intervals: int,
duration_history_size: int,
):
"""
See Blocks.launch() docstring for the explanation of parameters.
See Blocks.configure_queue() docstring for the explanation of parameters.
"""

if live_queue_updates is False and update_intervals == 5:
update_intervals = 10
cls.SERVER_PATH = server_path
cls.LIVE_QUEUE_UPDATES = live_queue_updates
cls.MAX_THREAD_COUNT = queue_concurrency_count
cls.DATA_GATHERING_STARTS_AT = data_gathering_start
cls.UPDATE_INTERVALS = update_intervals
cls.DURATION_HISTORY_SIZE = duration_history_size
cls.ACTIVE_JOBS = [None] * cls.MAX_THREAD_COUNT

@classmethod
def set_url(cls, url: str):
    """Record the running server's base URL in Queue.SERVER_PATH.

    Called from Blocks.launch once the local URL is known, so the queue can
    later issue HTTP callbacks against the app's own endpoints.
    """
    cls.SERVER_PATH = url

@classmethod
async def init(
cls,
) -> None:
run_coro_in_background(Queue.notify_clients)
run_coro_in_background(Queue.start_processing)
if not cls.LIVE_QUEUE_UPDATES:
run_coro_in_background(Queue.notify_clients)

@classmethod
def close(cls):
Expand All @@ -80,6 +79,14 @@ def close(cls):
def resume(cls):
cls.STOP = False

@classmethod
def get_active_worker_count(cls) -> int:
count = 0
for worker in cls.ACTIVE_JOBS:
if worker is not None:
count += 1
return count

# TODO: Remove prints
@classmethod
async def start_processing(cls) -> None:
Expand Down Expand Up @@ -140,7 +147,7 @@ async def gather_event_data(cls, event: Event) -> None:
"""
Gather data for the event
Args:
Parameters:
event:
"""
if not event.data:
Expand All @@ -158,13 +165,10 @@ async def notify_clients(cls) -> None:
Notify clients about events statuses in the queue periodically.
"""
while not cls.STOP:
# TODO: if live update is true and queue size does not change, dont notify the clients
await asyncio.sleep(cls.UPDATE_INTERVALS)
print(f"Event Queue: {cls.EVENT_QUEUE}")
if not cls.EVENT_QUEUE:
continue

await cls.broadcast_estimation()
if cls.EVENT_QUEUE:
await cls.broadcast_estimation()

@classmethod
async def broadcast_estimation(cls) -> None:
Expand All @@ -184,12 +188,15 @@ async def send_estimation(
"""
Send estimation about ETA to the client.
Args:
Parameters:
event:
estimation:
rank
rank:
"""
estimation.rank = rank
estimation.rank_eta = round(
estimation.rank * cls.AVG_CONCURRENT_PROCESS_TIME + cls.AVG_PROCESS_TIME
)
client_awake = await event.send_message(estimation.dict())
if not client_awake:
await cls.clean_event(event)
Expand All @@ -199,7 +206,7 @@ def update_estimation(cls, duration: float) -> None:
"""
Update estimation by last x element's average duration.
Args:
Parameters:
duration:
"""
cls.DURATION_HISTORY.append(duration)
Expand All @@ -218,9 +225,9 @@ def update_estimation(cls, duration: float) -> None:
def get_estimation(cls) -> Estimation:
return Estimation(
queue_size=len(cls.EVENT_QUEUE),
avg_process_time=cls.AVG_PROCESS_TIME,
avg_concurrent_process_time=cls.AVG_CONCURRENT_PROCESS_TIME,
queue_duration=cls.QUEUE_DURATION,
avg_event_process_time=cls.AVG_PROCESS_TIME,
avg_event_concurrent_process_time=cls.AVG_CONCURRENT_PROCESS_TIME,
queue_eta=cls.QUEUE_DURATION,
)

@classmethod
Expand All @@ -243,9 +250,20 @@ async def process_event(cls, event: Event) -> None:
{"msg": "process_completed", "output": response.json}
)
if client_awake:
await event.disconnect()
run_coro_in_background(cls.wait_in_inactive, event)
cls.clean_job(event)

@classmethod
async def wait_in_inactive(cls, event: Event) -> None:
    """
    Park a completed event on its still-open websocket: if the client sends a
    "join_back" message the event is re-appended to the queue; if the
    connection drops (no message), the event is simply forgotten.
    """
    # Drop the stale payload so a re-joined event gathers fresh input data.
    event.data = None
    message = await event.get_message()
    if message and message["msg"] == "join_back":
        cls.EVENT_QUEUE.append(event)


class Event:
def __init__(self, websocket: fastapi.WebSocket):
Expand Down

0 comments on commit 0d2c575

Please sign in to comment.