diff --git a/.changeset/ripe-spiders-love.md b/.changeset/ripe-spiders-love.md
new file mode 100644
index 000000000000..ddabb0c46903
--- /dev/null
+++ b/.changeset/ripe-spiders-love.md
@@ -0,0 +1,7 @@
+---
+"@gradio/client": patch
+"gradio": patch
+"gradio_client": patch
+---
+
+fix:Fix API event drops
diff --git a/client/js/src/client.ts b/client/js/src/client.ts
index 2135324f76b5..cfe42b8f9d8c 100644
--- a/client/js/src/client.ts
+++ b/client/js/src/client.ts
@@ -278,6 +278,9 @@ export function api_factory(
 	const session_hash = Math.random().toString(36).substring(2);
 	const last_status: Record<string, Status["stage"]> = {};
+	let stream_open = false;
+	let event_stream: EventSource | null = null;
+	const event_callbacks: Record<string, (data: object) => Promise<void>> = {};
 	let config: Config;
 	let api_map: Record<string, number> = {};
@@ -437,7 +440,7 @@
 		let websocket: WebSocket;
 		let eventSource: EventSource;
-		let protocol = config.protocol ?? "sse";
+		let protocol = config.protocol ?? "ws";
 		const _endpoint = typeof endpoint === "number" ? "/predict" : endpoint;
 		let payload: Payload;
@@ -646,7 +649,7 @@
 					websocket.send(JSON.stringify({ hash: session_hash }))
 				);
 			}
-		} else {
+		} else if (protocol == "sse") {
 			fire_event({
 				type: "status",
 				stage: "pending",
@@ -766,6 +769,121 @@
 					}
 				}
 			};
+		} else if (protocol == "sse_v1") {
+			fire_event({
+				type: "status",
+				stage: "pending",
+				queue: true,
+				endpoint: _endpoint,
+				fn_index,
+				time: new Date()
+			});
+
+			post_data(
+				`${http_protocol}//${resolve_root(
+					host,
+					config.path,
+					true
+				)}/queue/join?${url_params}`,
+				{
+					...payload,
+					session_hash
+				},
+				hf_token
+			).then(([response, status]) => {
+				if (status !== 200) {
+					fire_event({
+						type: "status",
+						stage: "error",
+						message: BROKEN_CONNECTION_MSG,
+						queue: true,
+						endpoint: _endpoint,
+						fn_index,
+						time: new Date()
+					});
+				} else {
+					event_id = response.event_id as string;
+					if (!stream_open) {
+						open_stream();
+					}
+
+					let callback = async function (_data: object): Promise<void> {
+						const { type, status, data } = handle_message(
+							_data,
+							last_status[fn_index]
+						);
+
+						if (type === "update" && status && !complete) {
+							// call 'status' listeners
+							fire_event({
+								type: "status",
+								endpoint: _endpoint,
+								fn_index,
+								time: new Date(),
+								...status
+							});
+						} else if (type === "complete") {
+							complete = status;
+						} else if (type === "log") {
+							fire_event({
+								type: "log",
+								log: data.log,
+								level: data.level,
+								endpoint: _endpoint,
+								fn_index
+							});
+						} else if (type === "generating") {
+							fire_event({
+								type: "status",
+								time: new Date(),
+								...status,
+								stage: status?.stage!,
+								queue: true,
+								endpoint: _endpoint,
+								fn_index
+							});
+						}
+						if (data) {
+							fire_event({
+								type: "data",
+								time: new Date(),
+								data: transform_files
+									?
transform_output( + data.data, + api_info, + config.root, + config.root_url + ) + : data.data, + endpoint: _endpoint, + fn_index + }); + + if (complete) { + fire_event({ + type: "status", + time: new Date(), + ...complete, + stage: status?.stage!, + queue: true, + endpoint: _endpoint, + fn_index + }); + } + } + + if (status.stage === "complete" || status.stage === "error") { + if (event_callbacks[event_id]) { + delete event_callbacks[event_id]; + if (Object.keys(event_callbacks).length === 0) { + close_stream(); + } + } + } + }; + event_callbacks[event_id] = callback; + } + }); } }); @@ -864,6 +982,30 @@ export function api_factory( }; } + function open_stream(): void { + stream_open = true; + let params = new URLSearchParams({ + session_hash: session_hash + }).toString(); + let url = new URL( + `${http_protocol}//${resolve_root( + host, + config.path, + true + )}/queue/data?${params}` + ); + event_stream = new EventSource(url); + event_stream.onmessage = async function (event) { + let _data = JSON.parse(event.data); + await event_callbacks[_data.event_id](_data); + }; + } + + function close_stream(): void { + stream_open = false; + event_stream?.close(); + } + async function component_server( component_id: number, fn_name: string, diff --git a/client/js/src/types.ts b/client/js/src/types.ts index ccdccbea22f6..2b1869855ef0 100644 --- a/client/js/src/types.ts +++ b/client/js/src/types.ts @@ -20,7 +20,7 @@ export interface Config { show_api: boolean; stylesheets: string[]; path: string; - protocol?: "sse" | "ws"; + protocol?: "sse_v1" | "sse" | "ws"; } export interface Payload { diff --git a/client/python/gradio_client/client.py b/client/python/gradio_client/client.py index 9e72853f2f58..034ed0adccc1 100644 --- a/client/python/gradio_client/client.py +++ b/client/python/gradio_client/client.py @@ -36,6 +36,7 @@ from gradio_client.utils import ( Communicator, JobStatus, + Message, Status, StatusUpdate, ) @@ -124,25 +125,33 @@ def __init__( if self.verbose: print(f"Loaded as API: {self.src} ✔") + if auth is not None: + self._login(auth) + + self.config = self._get_config() + self.protocol: str = self.config.get("protocol", "ws") self.api_url = urllib.parse.urljoin(self.src, utils.API_URL) - self.sse_url = urllib.parse.urljoin(self.src, utils.SSE_URL) - self.sse_data_url = urllib.parse.urljoin(self.src, utils.SSE_DATA_URL) + self.sse_url = urllib.parse.urljoin( + self.src, utils.SSE_URL_V0 if self.protocol == "sse" else utils.SSE_URL + ) + self.sse_data_url = urllib.parse.urljoin( + self.src, + utils.SSE_DATA_URL_V0 if self.protocol == "sse" else utils.SSE_DATA_URL, + ) self.ws_url = urllib.parse.urljoin( self.src.replace("http", "ws", 1), utils.WS_URL ) self.upload_url = urllib.parse.urljoin(self.src, utils.UPLOAD_URL) self.reset_url = urllib.parse.urljoin(self.src, utils.RESET_URL) - if auth is not None: - self._login(auth) - self.config = self._get_config() self.app_version = version.parse(self.config.get("version", "2.0")) self._info = self._get_api_info() self.session_hash = str(uuid.uuid4()) - protocol = self.config.get("protocol") - endpoint_class = Endpoint if protocol == "sse" else EndpointV3Compatibility + endpoint_class = ( + Endpoint if self.protocol.startswith("sse") else EndpointV3Compatibility + ) self.endpoints = [ - endpoint_class(self, fn_index, dependency) + endpoint_class(self, fn_index, dependency, self.protocol) for fn_index, dependency in enumerate(self.config["dependencies"]) ] @@ -152,6 +161,84 @@ def __init__( # Disable telemetry by setting the env variable 
HF_HUB_DISABLE_TELEMETRY=1 threading.Thread(target=self._telemetry_thread).start() + self.stream_open = False + self.streaming_future: Future | None = None + self.pending_messages_per_event: dict[str, list[Message | None]] = {} + self.pending_event_ids: set[str] = set() + + async def stream_messages(self) -> None: + try: + async with httpx.AsyncClient(timeout=httpx.Timeout(timeout=None)) as client: + buffer = "" + async with client.stream( + "GET", + self.sse_url, + params={"session_hash": self.session_hash}, + headers=self.headers, + cookies=self.cookies, + ) as response: + async for line in response.aiter_text(): + buffer += line + while "\n\n" in buffer: + message, buffer = buffer.split("\n\n", 1) + if message.startswith("data:"): + resp = json.loads(message[5:]) + if resp["msg"] == "heartbeat": + continue + elif resp["msg"] == "server_stopped": + for ( + pending_messages + ) in self.pending_messages_per_event.values(): + pending_messages.append(resp) + return + event_id = resp["event_id"] + if event_id not in self.pending_messages_per_event: + self.pending_messages_per_event[event_id] = [] + self.pending_messages_per_event[event_id].append(resp) + if resp["msg"] == "process_completed": + self.pending_event_ids.remove(event_id) + if len(self.pending_event_ids) == 0: + self.stream_open = False + return + elif message == "": + continue + else: + raise ValueError(f"Unexpected SSE line: '{message}'") + except BaseException as e: + import traceback + + traceback.print_exc() + raise e + + async def send_data(self, data, hash_data): + async with httpx.AsyncClient() as client: + req = await client.post( + self.sse_data_url, + json={**data, **hash_data}, + headers=self.headers, + cookies=self.cookies, + ) + req.raise_for_status() + resp = req.json() + event_id = resp["event_id"] + + if not self.stream_open: + self.stream_open = True + + def open_stream(): + return utils.synchronize_async(self.stream_messages) + + def close_stream(_): + self.stream_open = False + for _, pending_messages in self.pending_messages_per_event.items(): + pending_messages.append(None) + + if self.streaming_future is None or self.streaming_future.done(): + self.streaming_future = self.executor.submit(open_stream) + self.streaming_future.add_done_callback(close_stream) + + return event_id + @classmethod def duplicate( cls, @@ -340,7 +427,7 @@ def submit( inferred_fn_index = self._infer_fn_index(api_name, fn_index) helper = None - if self.endpoints[inferred_fn_index].protocol in ("ws", "sse"): + if self.endpoints[inferred_fn_index].protocol in ("ws", "sse", "sse_v1"): helper = self.new_helper(inferred_fn_index) end_to_end_fn = self.endpoints[inferred_fn_index].make_end_to_end_fn(helper) future = self.executor.submit(end_to_end_fn, *args) @@ -806,7 +893,9 @@ class ReplaceMe: class Endpoint: """Helper class for storing all the information about a single API endpoint.""" - def __init__(self, client: Client, fn_index: int, dependency: dict): + def __init__( + self, client: Client, fn_index: int, dependency: dict, protocol: str = "sse_v1" + ): self.client: Client = client self.fn_index = fn_index self.dependency = dependency @@ -814,7 +903,7 @@ def __init__(self, client: Client, fn_index: int, dependency: dict): self.api_name: str | Literal[False] | None = ( "/" + api_name if isinstance(api_name, str) else api_name ) - self.protocol = "sse" + self.protocol = protocol self.input_component_types = [ self._get_component_type(id_) for id_ in dependency["inputs"] ] @@ -891,7 +980,20 @@ def _predict(*data) -> tuple: "session_hash": 
self.client.session_hash, } - result = utils.synchronize_async(self._sse_fn, data, hash_data, helper) + if self.protocol == "sse": + result = utils.synchronize_async( + self._sse_fn_v0, data, hash_data, helper + ) + elif self.protocol == "sse_v1": + event_id = utils.synchronize_async( + self.client.send_data, data, hash_data + ) + self.client.pending_event_ids.add(event_id) + self.client.pending_messages_per_event[event_id] = [] + result = utils.synchronize_async(self._sse_fn_v1, helper, event_id) + else: + raise ValueError(f"Unsupported protocol: {self.protocol}") + if "error" in result: raise ValueError(result["error"]) @@ -1068,24 +1170,33 @@ def process_predictions(self, *predictions): predictions = self.reduce_singleton_output(*predictions) return predictions - async def _sse_fn(self, data: dict, hash_data: dict, helper: Communicator): + async def _sse_fn_v0(self, data: dict, hash_data: dict, helper: Communicator): async with httpx.AsyncClient(timeout=httpx.Timeout(timeout=None)) as client: - return await utils.get_pred_from_sse( + return await utils.get_pred_from_sse_v0( client, data, hash_data, helper, - sse_url=self.client.sse_url, - sse_data_url=self.client.sse_data_url, - headers=self.client.headers, - cookies=self.client.cookies, + self.client.sse_url, + self.client.sse_data_url, + self.client.headers, + self.client.cookies, ) + async def _sse_fn_v1(self, helper: Communicator, event_id: str): + return await utils.get_pred_from_sse_v1( + helper, + self.client.headers, + self.client.cookies, + self.client.pending_messages_per_event, + event_id, + ) + class EndpointV3Compatibility: """Endpoint class for connecting to v3 endpoints. Backwards compatibility.""" - def __init__(self, client: Client, fn_index: int, dependency: dict): + def __init__(self, client: Client, fn_index: int, dependency: dict, *args): self.client: Client = client self.fn_index = fn_index self.dependency = dependency diff --git a/client/python/gradio_client/utils.py b/client/python/gradio_client/utils.py index 71b7a0cc635f..21f61d2c924d 100644 --- a/client/python/gradio_client/utils.py +++ b/client/python/gradio_client/utils.py @@ -17,7 +17,7 @@ from enum import Enum from pathlib import Path from threading import Lock -from typing import Any, Callable, Optional +from typing import Any, Callable, Optional, TypedDict import fsspec.asyn import httpx @@ -26,8 +26,10 @@ from websockets.legacy.protocol import WebSocketCommonProtocol API_URL = "api/predict/" -SSE_URL = "queue/join" -SSE_DATA_URL = "queue/data" +SSE_URL_V0 = "queue/join" +SSE_DATA_URL_V0 = "queue/data" +SSE_URL = "queue/data" +SSE_DATA_URL = "queue/join" WS_URL = "queue/join" UPLOAD_URL = "upload" LOGIN_URL = "login" @@ -48,6 +50,19 @@ ] +class Message(TypedDict, total=False): + msg: str + output: dict[str, Any] + event_id: str + rank: int + rank_eta: float + queue_size: int + success: bool + progress_data: list[dict] + log: str + level: str + + def get_package_version() -> str: try: package_json_data = ( @@ -100,6 +115,7 @@ class Status(Enum): PROGRESS = "PROGRESS" FINISHED = "FINISHED" CANCELLED = "CANCELLED" + LOG = "LOG" @staticmethod def ordering(status: Status) -> int: @@ -133,6 +149,7 @@ def msg_to_status(msg: str) -> Status: "process_generating": Status.ITERATING, "process_completed": Status.FINISHED, "progress": Status.PROGRESS, + "log": Status.LOG, }[msg] @@ -169,6 +186,7 @@ class StatusUpdate: success: bool | None time: datetime | None progress_data: list[ProgressUnit] | None + log: tuple[str, str] | None = None def 
create_initial_status_update(): @@ -307,7 +325,7 @@ async def get_pred_from_ws( return resp["output"] -async def get_pred_from_sse( +async def get_pred_from_sse_v0( client: httpx.AsyncClient, data: dict, hash_data: dict, @@ -315,21 +333,54 @@ async def get_pred_from_sse( sse_url: str, sse_data_url: str, headers: dict[str, str], - cookies: dict[str, str] | None = None, + cookies: dict[str, str] | None, ) -> dict[str, Any] | None: done, pending = await asyncio.wait( [ - asyncio.create_task(check_for_cancel(helper, cookies)), + asyncio.create_task(check_for_cancel(helper, headers, cookies)), asyncio.create_task( - stream_sse( + stream_sse_v0( client, data, hash_data, helper, sse_url, sse_data_url, - headers=headers, - cookies=cookies, + headers, + cookies, + ) + ), + ], + return_when=asyncio.FIRST_COMPLETED, + ) + + for task in pending: + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + + assert len(done) == 1 + for task in done: + return task.result() + + +async def get_pred_from_sse_v1( + helper: Communicator, + headers: dict[str, str], + cookies: dict[str, str] | None, + pending_messages_per_event: dict[str, list[Message | None]], + event_id: str, +) -> dict[str, Any] | None: + done, pending = await asyncio.wait( + [ + asyncio.create_task(check_for_cancel(helper, headers, cookies)), + asyncio.create_task( + stream_sse_v1( + helper, + pending_messages_per_event, + event_id, ) ), ], @@ -348,7 +399,9 @@ async def get_pred_from_sse( return task.result() -async def check_for_cancel(helper: Communicator, cookies: dict[str, str] | None): +async def check_for_cancel( + helper: Communicator, headers: dict[str, str], cookies: dict[str, str] | None +): while True: await asyncio.sleep(0.05) with helper.lock: @@ -357,12 +410,15 @@ async def check_for_cancel(helper: Communicator, cookies: dict[str, str] | None) if helper.event_id: async with httpx.AsyncClient() as http: await http.post( - helper.reset_url, json={"event_id": helper.event_id}, cookies=cookies + helper.reset_url, + json={"event_id": helper.event_id}, + headers=headers, + cookies=cookies, ) raise CancelledError() -async def stream_sse( +async def stream_sse_v0( client: httpx.AsyncClient, data: dict, hash_data: dict, @@ -370,15 +426,15 @@ async def stream_sse( sse_url: str, sse_data_url: str, headers: dict[str, str], - cookies: dict[str, str] | None = None, + cookies: dict[str, str] | None, ) -> dict[str, Any]: try: async with client.stream( "GET", sse_url, params=hash_data, - cookies=cookies, headers=headers, + cookies=cookies, ) as response: async for line in response.aiter_text(): if line.startswith("data:"): @@ -413,8 +469,8 @@ async def stream_sse( req = await client.post( sse_data_url, json={"event_id": event_id, **data, **hash_data}, - cookies=cookies, headers=headers, + cookies=cookies, ) req.raise_for_status() elif resp["msg"] == "process_completed": @@ -426,6 +482,64 @@ async def stream_sse( raise +async def stream_sse_v1( + helper: Communicator, + pending_messages_per_event: dict[str, list[Message | None]], + event_id: str, +) -> dict[str, Any]: + try: + pending_messages = pending_messages_per_event[event_id] + + while True: + if len(pending_messages) > 0: + msg = pending_messages.pop(0) + else: + await asyncio.sleep(0.05) + continue + + if msg is None: + raise CancelledError() + + with helper.lock: + log_message = None + if msg["msg"] == "log": + log = msg.get("log") + level = msg.get("level") + if log and level: + log_message = (log, level) + status_update = StatusUpdate( + 
code=Status.msg_to_status(msg["msg"]), + queue_size=msg.get("queue_size"), + rank=msg.get("rank", None), + success=msg.get("success"), + time=datetime.now(), + eta=msg.get("rank_eta"), + progress_data=ProgressUnit.from_msg(msg["progress_data"]) + if "progress_data" in msg + else None, + log=log_message, + ) + output = msg.get("output", {}).get("data", []) + if output and status_update.code != Status.FINISHED: + try: + result = helper.prediction_processor(*output) + except Exception as e: + result = [e] + helper.job.outputs.append(result) + helper.job.latest_status = status_update + + if msg["msg"] == "queue_full": + raise QueueError("Queue is full! Please try again.") + elif msg["msg"] == "process_completed": + del pending_messages_per_event[event_id] + return msg["output"] + elif msg["msg"] == "server_stopped": + raise ValueError("Server stopped.") + + except asyncio.CancelledError: + raise + + ######################## # Data processing utils ######################## diff --git a/client/python/test/test_client.py b/client/python/test/test_client.py index 44032f32505f..b29e31b1d245 100644 --- a/client/python/test/test_client.py +++ b/client/python/test/test_client.py @@ -95,6 +95,17 @@ def test_private_space_v4(self): output = client.predict("abc", api_name="/predict") assert output == "abc" + @pytest.mark.flaky + def test_private_space_v4_sse_v1(self): + space_id = "gradio-tests/not-actually-private-spacev4-sse-v1" + api = huggingface_hub.HfApi() + assert api.space_info(space_id).private + client = Client( + space_id, + ) + output = client.predict("abc", api_name="/predict") + assert output == "abc" + def test_state(self, increment_demo): with connect(increment_demo) as client: output = client.predict(api_name="/increment_without_queue") diff --git a/gradio/blocks.py b/gradio/blocks.py index adcd7c9b760c..38cf1ebfe3e7 100644 --- a/gradio/blocks.py +++ b/gradio/blocks.py @@ -893,6 +893,13 @@ def set_event_trigger( fn = get_continuous_fn(fn, every) elif every: raise ValueError("Cannot set a value for `every` without a `fn`.") + if every and concurrency_limit is not None: + if concurrency_limit == "default": + concurrency_limit = None + else: + raise ValueError( + "Cannot set a value for `concurrency_limit` with `every`." + ) if _targets[0][1] == "change" and trigger_mode is None: trigger_mode = "always_last" @@ -1581,7 +1588,7 @@ def get_config_file(self): "is_colab": utils.colab_check(), "stylesheets": self.stylesheets, "theme": self.theme.name, - "protocol": "sse", + "protocol": "sse_v1", } def get_layout(block): @@ -2169,7 +2176,7 @@ def close(self, verbose: bool = True) -> None: try: if wasm_utils.IS_WASM: # NOTE: - # Normally, queue-related async tasks (e.g. continuous events created by `gr.Blocks.load(..., every=interval)`, whose async tasks are started at the `/queue/join` endpoint function) + # Normally, queue-related async tasks (e.g. continuous events created by `gr.Blocks.load(..., every=interval)`, whose async tasks are started at the `/queue/data` endpoint function) # are running in an event loop in the server thread, # so they will be cancelled by `self.server.close()` below. 
# However, in the Wasm env, we don't have the `server` and diff --git a/gradio/queueing.py b/gradio/queueing.py index 1509331684ee..f119267c4797 100644 --- a/gradio/queueing.py +++ b/gradio/queueing.py @@ -23,7 +23,7 @@ ) from gradio.exceptions import Error from gradio.helpers import TrackedIterable -from gradio.utils import run_coro_in_background, safe_get_lock, set_task_name +from gradio.utils import LRUCache, run_coro_in_background, safe_get_lock, set_task_name if TYPE_CHECKING: from gradio.blocks import BlockFunction @@ -37,7 +37,6 @@ def __init__( request: fastapi.Request, username: str | None, ): - self.message_queue = ThreadQueue() self.session_hash = session_hash self.fn_index = fn_index self.request = request @@ -48,28 +47,6 @@ def __init__( self.progress_pending: bool = False self.alive = True - def send_message( - self, - message_type: str, - data: dict | None = None, - final: bool = False, - ): - data = {} if data is None else data - self.message_queue.put_nowait({"msg": message_type, **data}) - if final: - self.message_queue.put_nowait(None) - - async def get_data(self, timeout=5) -> bool: - self.send_message("send_data", {"event_id": self._id}) - sleep_interval = 0.05 - wait_time = 0 - while wait_time < timeout and self.alive: - if self.data is not None: - break - await asyncio.sleep(sleep_interval) - wait_time += sleep_interval - return self.data is not None - class Queue: def __init__( @@ -81,6 +58,9 @@ def __init__( block_fns: list[BlockFunction], default_concurrency_limit: int | None | Literal["not_set"] = "not_set", ): + self.pending_messages_per_session: LRUCache[str, ThreadQueue] = LRUCache(2000) + self.pending_event_ids_session: dict[str, set[str]] = {} + self.pending_message_lock = safe_get_lock() self.event_queue: list[Event] = [] self.awaiting_data_events: dict[str, Event] = {} self.stopped = False @@ -132,6 +112,16 @@ def start(self): def close(self): self.stopped = True + def send_message( + self, + event: Event, + message_type: str, + data: dict | None = None, + ): + data = {} if data is None else data + messages = self.pending_messages_per_session[event.session_hash] + messages.put_nowait({"msg": message_type, "event_id": event._id, **data}) + def _resolve_concurrency_limit(self, default_concurrency_limit): """ Handles the logic of resolving the default_concurrency_limit as this can be specified via a combination @@ -152,13 +142,33 @@ def _resolve_concurrency_limit(self, default_concurrency_limit): else: return 1 - def attach_data(self, body: PredictBody): - event_id = body.event_id - if event_id in self.awaiting_data_events: - event = self.awaiting_data_events[event_id] - event.data = body - else: - raise ValueError("Event not found", event_id) + async def push( + self, body: PredictBody, request: fastapi.Request, username: str | None + ): + if body.session_hash is None: + raise ValueError("No session hash provided.") + if body.fn_index is None: + raise ValueError("No function index provided.") + queue_len = len(self.event_queue) + if self.max_size is not None and queue_len >= self.max_size: + raise ValueError( + f"Queue is full. Max size is {self.max_size} and current size is {queue_len}." 
+ ) + + event = Event(body.session_hash, body.fn_index, request, username) + event.data = body + async with self.pending_message_lock: + if body.session_hash not in self.pending_messages_per_session: + self.pending_messages_per_session[body.session_hash] = ThreadQueue() + if body.session_hash not in self.pending_event_ids_session: + self.pending_event_ids_session[body.session_hash] = set() + self.pending_event_ids_session[body.session_hash].add(event._id) + self.event_queue.append(event) + + estimation = self.get_estimation() + await self.send_estimation(event, estimation, queue_len) + + return event._id def _cancel_asyncio_tasks(self): for task in self._asyncio_tasks: @@ -276,7 +286,7 @@ async def start_progress_updates(self) -> None: for event in events: if event.progress_pending and event.progress: event.progress_pending = False - event.send_message("progress", event.progress.model_dump()) + self.send_message(event, "progress", event.progress.model_dump()) await asyncio.sleep(self.progress_update_sleep_when_free) @@ -320,34 +330,23 @@ def log_message( log=log, level=level, ) - event.send_message("log", log_message.model_dump()) - - def push(self, event: Event) -> int | None: - """ - Add event to queue, or return None if Queue is full - Parameters: - event: Event to add to Queue - Returns: - rank of submitted Event - """ - queue_len = len(self.event_queue) - if self.max_size is not None and queue_len >= self.max_size: - return None - self.event_queue.append(event) - return queue_len - - async def clean_event(self, event: Event | str) -> None: - if isinstance(event, str): - for job_set in self.active_jobs: - if job_set: - for job in job_set: - if job._id == event: - event = job - break - if isinstance(event, str): - raise ValueError("Event not found", event) - event.alive = False - if event in self.event_queue: + self.send_message(event, "log", log_message.model_dump()) + + async def clean_events( + self, *, session_hash: str | None = None, event_id: str | None = None + ) -> None: + for job_set in self.active_jobs: + if job_set: + for job in job_set: + if job.session_hash == session_hash or job._id == event_id: + job.alive = False + + events_to_remove = [] + for event in self.event_queue: + if event.session_hash == session_hash or event._id == event_id: + events_to_remove.append(event) + + for event in events_to_remove: async with self.delete_lock: self.event_queue.remove(event) @@ -391,7 +390,7 @@ async def send_estimation( if None not in self.active_jobs: # Add estimated amount of time for a thread to get empty estimation.rank_eta += self.avg_concurrent_process_time - event.send_message("estimation", estimation.model_dump()) + self.send_message(event, "estimation", estimation.model_dump()) return estimation def update_estimation(self, duration: float) -> None: @@ -485,14 +484,7 @@ async def process_events(self, events: list[Event], batch: bool) -> None: awake_events: list[Event] = [] try: for event in events: - if not event.data: - self.awaiting_data_events[event._id] = event - client_awake = await event.get_data() - del self.awaiting_data_events[event._id] - if not client_awake: - await self.clean_event(event) - continue - event.send_message("process_starts") + self.send_message(event, "process_starts") awake_events.append(event) if not awake_events: return @@ -505,7 +497,8 @@ async def process_events(self, events: list[Event], batch: bool) -> None: response = None err = e for event in awake_events: - event.send_message( + self.send_message( + event, "process_completed", { "output": { 
@@ -515,7 +508,6 @@ async def process_events(self, events: list[Event], batch: bool) -> None: }, "success": False, }, - final=True, ) if response and response.get("is_generating", False): old_response = response @@ -524,7 +516,8 @@ async def process_events(self, events: list[Event], batch: bool) -> None: old_response = response old_err = err for event in awake_events: - event.send_message( + self.send_message( + event, "process_generating", { "output": old_response, @@ -545,7 +538,8 @@ async def process_events(self, events: list[Event], batch: bool) -> None: relevant_response = err else: relevant_response = old_response or old_err - event.send_message( + self.send_message( + event, "process_completed", { "output": {"error": str(relevant_response)} @@ -554,20 +548,19 @@ async def process_events(self, events: list[Event], batch: bool) -> None: "success": relevant_response and not isinstance(relevant_response, Exception), }, - final=True, ) elif response: output = copy.deepcopy(response) for e, event in enumerate(awake_events): if batch and "data" in output: output["data"] = list(zip(*response.get("data")))[e] - event.send_message( + self.send_message( + event, "process_completed", { "output": output, "success": response is not None, }, - final=True, ) end_time = time.time() if response is not None: diff --git a/gradio/routes.py b/gradio/routes.py index f8f24eb0d406..89841a0a68f9 100644 --- a/gradio/routes.py +++ b/gradio/routes.py @@ -55,7 +55,7 @@ from gradio.exceptions import Error from gradio.helpers import CACHED_FOLDER from gradio.oauth import attach_oauth -from gradio.queueing import Estimation, Event +from gradio.queueing import Estimation from gradio.route_utils import ( # noqa: F401 FileUploadProgress, GradioMultiPartParser, @@ -65,10 +65,7 @@ ) from gradio.state_holder import StateHolder from gradio.utils import ( - cancel_tasks, get_package_version, - run_coro_in_background, - set_task_name, ) if TYPE_CHECKING: @@ -532,7 +529,7 @@ async def reset_iterator(body: ResetBody): async with app.lock: del app.iterators[body.event_id] app.iterators_to_reset.add(body.event_id) - await app.get_blocks()._queue.clean_event(body.event_id) + await app.get_blocks()._queue.clean_events(event_id=body.event_id) return {"success": True} # had to use '/run' endpoint for Colab compatibility, '/api' supported for backwards compatibility @@ -582,63 +579,38 @@ async def predict( ) return output - @app.get("/queue/join", dependencies=[Depends(login_check)]) - async def queue_join( - fn_index: int, - session_hash: str, + @app.get("/queue/data", dependencies=[Depends(login_check)]) + async def queue_data( request: fastapi.Request, - username: str = Depends(get_current_user), - data: Optional[str] = None, + session_hash: str, ): blocks = app.get_blocks() - if blocks._queue.server_app is None: - blocks._queue.set_server_app(app) - - event = Event(session_hash, fn_index, request, username) - if data is not None: - input_data = json.loads(data) - event.data = PredictBody( - session_hash=session_hash, - fn_index=fn_index, - data=input_data, - request=request, - ) - - # Continuous events are not put in the queue so that they do not - # occupy the queue's resource as they are expected to run forever - if blocks.dependencies[event.fn_index].get("every", 0): - await cancel_tasks({f"{event.session_hash}_{event.fn_index}"}) - await blocks._queue.reset_iterators(event._id) - blocks._queue.continuous_tasks.append(event) - task = run_coro_in_background( - blocks._queue.process_events, [event], False - ) - 
-                set_task_name(task, event.session_hash, event.fn_index, batch=False)
-                app._asyncio_tasks.append(task)
-        else:
-            rank = blocks._queue.push(event)
-            if rank is None:
-                event.send_message("queue_full", final=True)
-            else:
-                estimation = blocks._queue.get_estimation()
-                await blocks._queue.send_estimation(event, estimation, rank)

         async def sse_stream(request: fastapi.Request):
             try:
                 last_heartbeat = time.perf_counter()
                 while True:
                     if await request.is_disconnected():
-                        await blocks._queue.clean_event(event)
-                        if not event.alive:
+                        await blocks._queue.clean_events(session_hash=session_hash)
                         return

+                    if (
+                        session_hash
+                        not in blocks._queue.pending_messages_per_session
+                    ):
+                        raise HTTPException(
+                            status_code=status.HTTP_404_NOT_FOUND,
+                            detail="Session not found.",
+                        )
+
                     heartbeat_rate = 15
                     check_rate = 0.05
                     message = None
                     try:
-                        message = event.message_queue.get_nowait()
-                        if message is None:  # end of stream marker
-                            return
+                        messages = blocks._queue.pending_messages_per_session[
+                            session_hash
+                        ]
+                        message = messages.get_nowait()
                     except EmptyQueue:
                         await asyncio.sleep(check_rate)
                         if time.perf_counter() - last_heartbeat > heartbeat_rate:
@@ -648,10 +620,29 @@
                             # and then the stream will retry leading to infinite queue 😬
                             last_heartbeat = time.perf_counter()

+                    if blocks._queue.stopped:
+                        message = {"msg": "server_stopped", "success": False}
                     if message:
                         yield f"data: {json.dumps(message)}\n\n"
+                        if message["msg"] == "process_completed":
+                            blocks._queue.pending_event_ids_session[
+                                session_hash
+                            ].remove(message["event_id"])
+                        if message["msg"] == "server_stopped" or (
+                            message["msg"] == "process_completed"
+                            and (
+                                len(
+                                    blocks._queue.pending_event_ids_session[
+                                        session_hash
+                                    ]
+                                )
+                                == 0
+                            )
+                        ):
+                            return
             except asyncio.CancelledError as e:
-                await blocks._queue.clean_event(event)
+                del blocks._queue.pending_messages_per_session[session_hash]
+                await blocks._queue.clean_events(session_hash=session_hash)
                 raise e

         return StreamingResponse(
@@ -659,14 +650,18 @@
             media_type="text/event-stream",
         )

-    @app.post("/queue/data", dependencies=[Depends(login_check)])
-    async def queue_data(
+    @app.post("/queue/join", dependencies=[Depends(login_check)])
+    async def queue_join(
         body: PredictBody,
         request: fastapi.Request,
         username: str = Depends(get_current_user),
     ):
         blocks = app.get_blocks()
-        blocks._queue.attach_data(body)
+        if blocks._queue.server_app is None:
+            blocks._queue.set_server_app(app)
+
+        event_id = await blocks._queue.push(body, request, username)
+        return {"event_id": event_id}

     @app.post("/component_server", dependencies=[Depends(login_check)])
     @app.post("/component_server/", dependencies=[Depends(login_check)])
diff --git a/gradio/utils.py b/gradio/utils.py
index 9985a55d7c50..fdb644758a84 100644
--- a/gradio/utils.py
+++ b/gradio/utils.py
@@ -19,6 +19,7 @@
 import urllib.parse
 import warnings
 from abc import ABC, abstractmethod
+from collections import OrderedDict
 from contextlib import contextmanager
 from io import BytesIO
 from numbers import Number
@@ -28,6 +29,7 @@
     TYPE_CHECKING,
     Any,
     Callable,
+    Generic,
     Iterable,
     Iterator,
     Optional,
@@ -997,3 +999,20 @@ def convert_to_dict_if_dataclass(value):
     if dataclasses.is_dataclass(value):
         return dataclasses.asdict(value)
     return value
+
+
+K = TypeVar("K")
+V = TypeVar("V")
+
+
+class LRUCache(OrderedDict, Generic[K, V]):
+    def __init__(self, max_size: int = 100):
+        super().__init__()
+        self.max_size: int = max_size
+
+    def __setitem__(self, 
key: K, value: V) -> None: + if key in self: + self.move_to_end(key) + elif len(self) >= self.max_size: + self.popitem(last=False) + super().__setitem__(key, value) diff --git a/js/app/test/blocks_chained_events.spec.ts b/js/app/test/blocks_chained_events.spec.ts index 2299f7b8af33..9ba42ef166f5 100644 --- a/js/app/test/blocks_chained_events.spec.ts +++ b/js/app/test/blocks_chained_events.spec.ts @@ -1,20 +1,10 @@ import { test, expect } from "@gradio/tootils"; test(".success should not run if function fails", async ({ page }) => { - let last_iteration; const textbox = page.getByLabel("Result"); await expect(textbox).toHaveValue(""); - page.on("websocket", (ws) => { - last_iteration = ws.waitForEvent("framereceived", { - predicate: (event) => { - return JSON.parse(event.payload as string).msg === "process_completed"; - } - }); - }); - await page.click("text=Trigger Failure"); - await last_iteration; expect(textbox).toHaveValue(""); }); @@ -38,17 +28,7 @@ test("Consecutive .success event is triggered successfully", async ({ }); test("gr.Error makes the toast show up", async ({ page }) => { - let complete; - page.on("websocket", (ws) => { - complete = ws.waitForEvent("framereceived", { - predicate: (event) => { - return JSON.parse(event.payload as string).msg === "process_completed"; - } - }); - }); - await page.click("text=Trigger Failure"); - await complete; const toast = page.getByTestId("toast-body"); expect(toast).toContainText("error"); @@ -60,17 +40,7 @@ test("gr.Error makes the toast show up", async ({ page }) => { test("ValueError makes the toast show up when show_error=True", async ({ page }) => { - let complete; - page.on("websocket", (ws) => { - complete = ws.waitForEvent("framereceived", { - predicate: (event) => { - return JSON.parse(event.payload as string).msg === "process_completed"; - } - }); - }); - await page.click("text=Trigger Failure With ValueError"); - await complete; const toast = page.getByTestId("toast-body"); expect(toast).toContainText("error"); diff --git a/scripts/benchmark_queue.py b/scripts/benchmark_queue.py index c7d4ee3ae88f..f56de3eee3bb 100644 --- a/scripts/benchmark_queue.py +++ b/scripts/benchmark_queue.py @@ -107,7 +107,7 @@ async def main(host, n_results=100): parser.add_argument("-o", "--output", type=str, help="path to write output to", required=False) args = parser.parse_args() - host = f"{demo.local_url.replace('http', 'ws')}queue/join" + host = f"{demo.local_url.replace('http', 'ws')}queue/data" data = asyncio.run(main(host, n_results=args.n_jobs)) data = dict(zip(data["fn_to_hit"], data["duration"])) diff --git a/test/test_helpers.py b/test/test_helpers.py index 40d885aac793..8f9fb3072a08 100644 --- a/test/test_helpers.py +++ b/test/test_helpers.py @@ -8,7 +8,7 @@ from pathlib import Path from unittest.mock import patch -import httpx +import gradio_client as grc import pytest from gradio_client import media_data, utils from pydub import AudioSegment @@ -660,50 +660,29 @@ def greet(s, prog=gr.Progress()): button.click(greet, name, greeting) demo.queue(max_size=1).launch(prevent_thread_lock=True) - progress_updates = [] - async with httpx.AsyncClient() as client: - async with client.stream( - "GET", - f"http://localhost:{demo.server_port}/queue/join", - params={"fn_index": 0, "session_hash": "shdce"}, - ) as response: - async for line in response.aiter_text(): - if line.startswith("data:"): - msg = json.loads(line[5:]) - if msg["msg"] == "send_data": - event_id = msg["event_id"] - req = await client.post( - 
f"http://localhost:{demo.server_port}/queue/data", - json={ - "event_id": event_id, - "data": [0], - "fn_index": 0, - }, - ) - if not req.is_success: - raise ValueError( - f"Could not send payload to endpoint: {req.text}" - ) - if msg["msg"] == "progress": - progress_updates.append(msg["progress_data"]) - if msg["msg"] == "process_completed": - break - - assert progress_updates == [ - [ - { - "index": None, - "length": None, - "unit": "steps", - "progress": 0.0, - "desc": "start", - } - ], - [{"index": 0, "length": 4, "unit": "iter", "progress": None, "desc": None}], - [{"index": 1, "length": 4, "unit": "iter", "progress": None, "desc": None}], - [{"index": 2, "length": 4, "unit": "iter", "progress": None, "desc": None}], - [{"index": 3, "length": 4, "unit": "iter", "progress": None, "desc": None}], - [{"index": 4, "length": 4, "unit": "iter", "progress": None, "desc": None}], + client = grc.Client(demo.local_url) + job = client.submit("Gradio") + + status_updates = [] + while not job.done(): + status = job.status() + update = ( + status.progress_data[0].index if status.progress_data else None, + status.progress_data[0].desc if status.progress_data else None, + ) + if update != (None, None) and ( + len(status_updates) == 0 or status_updates[-1] != update + ): + status_updates.append(update) + time.sleep(0.05) + + assert status_updates == [ + (None, "start"), + (0, None), + (1, None), + (2, None), + (3, None), + (4, None), ] @pytest.mark.asyncio @@ -726,77 +705,32 @@ def greet(s, prog=gr.Progress(track_tqdm=True)): button.click(greet, name, greeting) demo.queue(max_size=1).launch(prevent_thread_lock=True) - progress_updates = [] - async with httpx.AsyncClient() as client: - async with client.stream( - "GET", - f"http://localhost:{demo.server_port}/queue/join", - params={"fn_index": 0, "session_hash": "shdce"}, - ) as response: - async for line in response.aiter_text(): - if line.startswith("data:"): - msg = json.loads(line[5:]) - if msg["msg"] == "send_data": - event_id = msg["event_id"] - req = await client.post( - f"http://localhost:{demo.server_port}/queue/data", - json={ - "event_id": event_id, - "data": [0], - "fn_index": 0, - }, - ) - if not req.is_success: - raise ValueError( - f"Could not send payload to endpoint: {req.text}" - ) - if msg["msg"] == "progress": - progress_updates.append(msg["progress_data"]) - if msg["msg"] == "process_completed": - break - - assert progress_updates == [ - [ - { - "index": None, - "length": None, - "unit": "steps", - "progress": 0.0, - "desc": "start", - } - ], - [{"index": 0, "length": 4, "unit": "iter", "progress": None, "desc": None}], - [{"index": 1, "length": 4, "unit": "iter", "progress": None, "desc": None}], - [{"index": 2, "length": 4, "unit": "iter", "progress": None, "desc": None}], - [{"index": 3, "length": 4, "unit": "iter", "progress": None, "desc": None}], - [{"index": 4, "length": 4, "unit": "iter", "progress": None, "desc": None}], - [ - { - "index": 0, - "length": 3, - "unit": "steps", - "progress": None, - "desc": "alphabet", - } - ], - [ - { - "index": 1, - "length": 3, - "unit": "steps", - "progress": None, - "desc": "alphabet", - } - ], - [ - { - "index": 2, - "length": 3, - "unit": "steps", - "progress": None, - "desc": "alphabet", - } - ], + client = grc.Client(demo.local_url) + job = client.submit("Gradio") + + status_updates = [] + while not job.done(): + status = job.status() + update = ( + status.progress_data[0].index if status.progress_data else None, + status.progress_data[0].desc if status.progress_data else None, + ) 
+ if update != (None, None) and ( + len(status_updates) == 0 or status_updates[-1] != update + ): + status_updates.append(update) + time.sleep(0.05) + + assert status_updates == [ + (None, "start"), + (0, None), + (1, None), + (2, None), + (3, None), + (4, None), + (0, "alphabet"), + (1, "alphabet"), + (2, "alphabet"), ] @pytest.mark.asyncio @@ -811,63 +745,29 @@ def greet(s, _=gr.Progress(track_tqdm=True)): demo = gr.Interface(greet, "text", "text") demo.queue().launch(prevent_thread_lock=True) - progress_updates = [] - async with httpx.AsyncClient() as client: - async with client.stream( - "GET", - f"http://localhost:{demo.server_port}/queue/join", - params={"fn_index": 0, "session_hash": "shdce"}, - ) as response: - async for line in response.aiter_text(): - if line.startswith("data:"): - msg = json.loads(line[5:]) - if msg["msg"] == "send_data": - event_id = msg["event_id"] - req = await client.post( - f"http://localhost:{demo.server_port}/queue/data", - json={ - "event_id": event_id, - "data": ["abc"], - "fn_index": 0, - }, - ) - if not req.is_success: - raise ValueError( - f"Could not send payload to endpoint: {req.text}" - ) - if msg["msg"] == "progress": - progress_updates.append(msg["progress_data"]) - if msg["msg"] == "process_completed": - break - - assert progress_updates == [ - [ - { - "index": 1, - "length": 3, - "unit": "steps", - "progress": None, - "desc": None, - } - ], - [ - { - "index": 2, - "length": 3, - "unit": "steps", - "progress": None, - "desc": None, - } - ], - [ - { - "index": 3, - "length": 3, - "unit": "steps", - "progress": None, - "desc": None, - } - ], + client = grc.Client(demo.local_url) + job = client.submit("Gradio") + + status_updates = [] + while not job.done(): + status = job.status() + update = ( + status.progress_data[0].index if status.progress_data else None, + status.progress_data[0].unit if status.progress_data else None, + ) + if update != (None, None) and ( + len(status_updates) == 0 or status_updates[-1] != update + ): + status_updates.append(update) + time.sleep(0.05) + + assert status_updates == [ + (1, "steps"), + (2, "steps"), + (3, "steps"), + (4, "steps"), + (5, "steps"), + (6, "steps"), ] @pytest.mark.asyncio @@ -878,45 +778,30 @@ def greet(s): time.sleep(0.15) if len(s) < 5: gr.Warning("Too short!") + time.sleep(0.15) return f"Hello, {s}!" 
demo = gr.Interface(greet, "text", "text") demo.queue().launch(prevent_thread_lock=True) - log_messages = [] - async with httpx.AsyncClient() as client: - async with client.stream( - "GET", - f"http://localhost:{demo.server_port}/queue/join", - params={"fn_index": 0, "session_hash": "shdce"}, - ) as response: - async for line in response.aiter_text(): - if line.startswith("data:"): - msg = json.loads(line[5:]) - if msg["msg"] == "send_data": - event_id = msg["event_id"] - req = await client.post( - f"http://localhost:{demo.server_port}/queue/data", - json={ - "event_id": event_id, - "data": ["abc"], - "fn_index": 0, - }, - ) - if not req.is_success: - raise ValueError( - f"Could not send payload to endpoint: {req.text}" - ) - if msg["msg"] == "log": - log_messages.append([msg["log"], msg["level"]]) - if msg["msg"] == "process_completed": - break - - assert log_messages == [ - ["Letter a", "info"], - ["Letter b", "info"], - ["Letter c", "info"], - ["Too short!", "warning"], + client = grc.Client(demo.local_url) + job = client.submit("Jon") + + status_updates = [] + while not job.done(): + status = job.status() + update = status.log + if update is not None and ( + len(status_updates) == 0 or status_updates[-1] != update + ): + status_updates.append(update) + time.sleep(0.05) + + assert status_updates == [ + ("Letter J", "info"), + ("Letter o", "info"), + ("Letter n", "info"), + ("Too short!", "warning"), ] @@ -926,11 +811,13 @@ async def test_info_isolation(async_handler: bool): async def greet_async(name): await asyncio.sleep(2) gr.Info(f"Hello {name}") + await asyncio.sleep(1) return name def greet_sync(name): time.sleep(2) gr.Info(f"Hello {name}") + time.sleep(1) return name demo = gr.Interface( @@ -942,42 +829,24 @@ def greet_sync(name): demo.launch(prevent_thread_lock=True) async def session_interaction(name, delay=0): - await asyncio.sleep(delay) - - log_messages = [] - async with httpx.AsyncClient() as client: - async with client.stream( - "GET", - f"http://localhost:{demo.server_port}/queue/join", - params={"fn_index": 0, "session_hash": name}, - ) as response: - async for line in response.aiter_text(): - if line.startswith("data:"): - msg = json.loads(line[5:]) - if msg["msg"] == "send_data": - event_id = msg["event_id"] - req = await client.post( - f"http://localhost:{demo.server_port}/queue/data", - json={ - "event_id": event_id, - "data": [name], - "fn_index": 0, - }, - ) - if not req.is_success: - raise ValueError( - f"Could not send payload to endpoint: {req.text}" - ) - if msg["msg"] == "log": - log_messages.append(msg["log"]) - if msg["msg"] == "process_completed": - break - return log_messages + client = grc.Client(demo.local_url) + job = client.submit(name) + + status_updates = [] + while not job.done(): + status = job.status() + update = status.log + if update is not None and ( + len(status_updates) == 0 or status_updates[-1] != update + ): + status_updates.append(update) + time.sleep(0.05) + return status_updates[-1][0] if status_updates else None alice_logs, bob_logs = await asyncio.gather( session_interaction("Alice"), session_interaction("Bob", delay=1), ) - assert alice_logs == ["Hello Alice"] - assert bob_logs == ["Hello Bob"] + assert alice_logs == "Hello Alice" + assert bob_logs == "Hello Bob" diff --git a/test/test_queueing.py b/test/test_queueing.py index 7c2de5792cae..5f7f35cdf635 100644 --- a/test/test_queueing.py +++ b/test/test_queueing.py @@ -18,8 +18,6 @@ def greet(x): name.submit(greet, name, output) - demo.launch(prevent_thread_lock=True) - with 
connect(demo) as client: job = client.submit("x", fn_index=0) assert job.result() == "Hello, x!" @@ -92,7 +90,7 @@ def test_default_concurrency_limits(self, default_concurrency_limit, statuses): @add_btn.click(inputs=[a, b], outputs=output) def add(x, y): - time.sleep(2) + time.sleep(4) return x + y demo.queue(default_concurrency_limit=default_concurrency_limit) @@ -105,7 +103,7 @@ def add(x, y): add_job_2 = client.submit(1, 1, fn_index=0) add_job_3 = client.submit(1, 1, fn_index=0) - time.sleep(1) + time.sleep(2) add_job_statuses = [add_job_1.status(), add_job_2.status(), add_job_3.status()] assert sorted([s.code.value for s in add_job_statuses]) == statuses @@ -161,12 +159,11 @@ def div(x, y): sub_job_1 = client.submit(1, 1, fn_index=1) sub_job_2 = client.submit(1, 1, fn_index=1) sub_job_3 = client.submit(1, 1, fn_index=1) - sub_job_3 = client.submit(1, 1, fn_index=1) mul_job_1 = client.submit(1, 1, fn_index=2) div_job_1 = client.submit(1, 1, fn_index=3) mul_job_2 = client.submit(1, 1, fn_index=2) - time.sleep(1) + time.sleep(2) add_job_statuses = [ add_job_1.status(),
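
Note on the reworked protocol (illustration only, not part of the patch): the old SSE flow opened one `GET /queue/join` stream per event, and the server had to ask the client for its payload via a `send_data` message; if that stream dropped at the wrong moment, completion messages were lost. After this change the payload is pushed up front with `POST /queue/join`, which returns an `event_id`, and a single long-lived `GET /queue/data` stream per `session_hash` carries every event's messages, buffered server-side in `pending_messages_per_session` until delivered. Below is a minimal sketch of the new flow from a raw HTTP client's point of view — the base URL, `fn_index`, and payload are hypothetical; only the endpoint paths and message fields come from the diff above:

```python
import json
import uuid

import httpx

BASE_URL = "http://127.0.0.1:7860"  # hypothetical local demo
session_hash = uuid.uuid4().hex

# 1. Push the payload immediately; the server queues the event and hands
#    back its event_id (no "send_data" round-trip anymore).
event_id = httpx.post(
    f"{BASE_URL}/queue/join",
    json={"data": ["abc"], "fn_index": 0, "session_hash": session_hash},
).json()["event_id"]

# 2. Read one long-lived SSE stream for the whole session; each message
#    carries the event_id it belongs to, so a single stream multiplexes
#    every in-flight event. Heartbeats carry no event_id and are skipped.
with httpx.stream(
    "GET",
    f"{BASE_URL}/queue/data",
    params={"session_hash": session_hash},
    timeout=None,
) as stream:
    for line in stream.iter_lines():
        if not line.startswith("data:"):
            continue
        msg = json.loads(line[len("data:"):])
        if msg["msg"] == "server_stopped":
            break
        if msg.get("event_id") != event_id:
            continue  # message for another event in this session
        if msg["msg"] == "process_completed":
            print(msg["output"])
            break
```

Because messages are queued per session on the server (in an `LRUCache` capped at 2000 sessions) rather than tied to one open connection, a completion that arrives while no consumer is reading is delivered as soon as the stream catches up — which is the event drop this PR fixes.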