From 84f81fec9287b041203a141bbf2852720f7d199c Mon Sep 17 00:00:00 2001
From: aliabid94
Date: Mon, 18 Mar 2024 14:37:38 -0700
Subject: [PATCH] Closing stream from the backend (#7691)

* changes

* add changeset

* changes

* changes

* changes

* changes

* maybe fix (#7715)

* maybe fix

* fix with demo

* changeset

* fix demo

* changeset

* changes

---------

Co-authored-by: Ali Abid
Co-authored-by: gradio-pr-bot
Co-authored-by: pngwn
Co-authored-by: Abubakar Abid
---
 .changeset/curly-parents-dream.md     |   7 ++
 .changeset/thick-penguins-act.md      |   6 ++
 client/js/src/client.ts               |  18 +++-
 client/python/gradio_client/client.py |  29 +++--
 client/python/gradio_client/utils.py  |  12 ++-
 gradio/blocks.py                      |   2 +-
 gradio/routes.py                      |   5 +
 gradio/server_messages.py             |   5 +
 js/app/src/Blocks.svelte              | 149 +++++++++++++-------------
 9 files changed, 141 insertions(+), 92 deletions(-)
 create mode 100644 .changeset/curly-parents-dream.md
 create mode 100644 .changeset/thick-penguins-act.md

diff --git a/.changeset/curly-parents-dream.md b/.changeset/curly-parents-dream.md
new file mode 100644
index 000000000000..51c4488e55e2
--- /dev/null
+++ b/.changeset/curly-parents-dream.md
@@ -0,0 +1,7 @@
+---
+"@gradio/client": minor
+"gradio": minor
+"gradio_client": minor
+---
+
+feat:Closing stream from the backend

diff --git a/.changeset/thick-penguins-act.md b/.changeset/thick-penguins-act.md
new file mode 100644
index 000000000000..7ef224f686b4
--- /dev/null
+++ b/.changeset/thick-penguins-act.md
@@ -0,0 +1,6 @@
+---
+"gradio": patch
+"@gradio/app": patch
+---
+
+fix: Fix race condition between state updates and loading_status updates.

diff --git a/client/js/src/client.ts b/client/js/src/client.ts
index 3984ffe6328e..7d0fae28fdcb 100644
--- a/client/js/src/client.ts
+++ b/client/js/src/client.ts
@@ -759,9 +759,11 @@ export function api_factory(
     } else if (
       protocol == "sse_v1" ||
       protocol == "sse_v2" ||
-      protocol == "sse_v2.1"
+      protocol == "sse_v2.1" ||
+      protocol == "sse_v3"
     ) {
       // latest API format. v2 introduces sending diffs for intermediate outputs in generative functions, which makes payloads lighter.
+      // v3 only closes the stream when the backend sends the close stream message.
       fire_event({
         type: "status",
         stage: "pending",
@@ -856,7 +858,7 @@ export function api_factory(
           });
           if (
             data &&
-            (protocol === "sse_v2" || protocol === "sse_v2.1")
+            ["sse_v2", "sse_v2.1", "sse_v3"].includes(protocol)
           ) {
             apply_diff_stream(event_id!, data);
           }
@@ -905,7 +907,9 @@ export function api_factory(
             fn_index,
             time: new Date()
           });
-          close_stream();
+          if (["sse_v2", "sse_v2.1"].includes(protocol)) {
+            close_stream();
+          }
         }
       };
       if (event_id in pending_stream_messages) {
@@ -1049,7 +1053,10 @@ export function api_factory(
           )
         );
       } else if (event_callbacks[event_id]) {
-        if (_data.msg === "process_completed") {
+        if (
+          _data.msg === "process_completed" &&
+          ["sse", "sse_v1", "sse_v2", "sse_v2.1"].includes(config.protocol)
+        ) {
           unclosed_events.delete(event_id);
           if (unclosed_events.size === 0) {
             close_stream();
@@ -1063,6 +1070,9 @@
         }
         pending_stream_messages[event_id].push(_data);
       }
+      if (_data.msg === "close_stream") {
+        close_stream();
+      }
     };
     event_stream.onerror = async function (event) {
       await Promise.all(
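Note: the client.ts hunks above change who ends the SSE connection. Under sse_v3 the browser client keeps the event stream open until the backend emits a close_stream message, instead of tearing it down after its own last process_completed. A minimal sketch of that consumer contract, in Python for brevity (the sample payloads and the consume() helper are illustrative, not gradio_client code):

    import json

    # Hypothetical captured SSE "data:" lines, for illustration only.
    raw_events = [
        'data: {"msg": "process_generating", "event_id": "abc"}',
        'data: {"msg": "process_completed", "event_id": "abc"}',
        'data: {"msg": "close_stream"}',
    ]

    def consume(lines):
        for line in lines:
            if not line.startswith("data:"):
                continue
            msg = json.loads(line[len("data:"):])
            yield msg
            if msg["msg"] == "close_stream":
                return  # sse_v3: only the backend ends the stream

    for message in consume(raw_events):
        print(message["msg"])
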
diff --git a/client/python/gradio_client/client.py b/client/python/gradio_client/client.py
index 554ca12df967..c3508d5a3071 100644
--- a/client/python/gradio_client/client.py
+++ b/client/python/gradio_client/client.py
@@ -190,7 +190,9 @@ def __init__(
         self.pending_messages_per_event: dict[str, list[Message | None]] = {}
         self.pending_event_ids: set[str] = set()

-    async def stream_messages(self) -> None:
+    async def stream_messages(
+        self, protocol: Literal["sse_v1", "sse_v2", "sse_v2.1", "sse_v3"]
+    ) -> None:
         try:
             async with httpx.AsyncClient(
                 timeout=httpx.Timeout(timeout=None), verify=self.ssl_verify
@@ -216,13 +218,19 @@ async def stream_messages(self) -> None:
                             ) in self.pending_messages_per_event.values():
                                 pending_messages.append(resp)
                             return
+                        elif resp["msg"] == ServerMessage.close_stream:
+                            self.stream_open = False
+                            return
                         event_id = resp["event_id"]
                         if event_id not in self.pending_messages_per_event:
                             self.pending_messages_per_event[event_id] = []
                         self.pending_messages_per_event[event_id].append(resp)
                         if resp["msg"] == ServerMessage.process_completed:
                             self.pending_event_ids.remove(event_id)
-                            if len(self.pending_event_ids) == 0:
+                            if (
+                                len(self.pending_event_ids) == 0
+                                and protocol != "sse_v3"
+                            ):
                                 self.stream_open = False
                                 return
                         else:
@@ -233,7 +241,7 @@ async def stream_messages(self) -> None:
             traceback.print_exc()
             raise e

-    async def send_data(self, data, hash_data):
+    async def send_data(self, data, hash_data, protocol):
         async with httpx.AsyncClient(verify=self.ssl_verify) as client:
             req = await client.post(
                 self.sse_data_url,
@@ -251,7 +259,7 @@ async def send_data(self, data, hash_data):
             self.stream_open = True

             def open_stream():
-                return utils.synchronize_async(self.stream_messages)
+                return utils.synchronize_async(self.stream_messages, protocol)

             def close_stream(_):
                 self.stream_open = False
@@ -458,6 +466,7 @@ def submit(
             "sse_v1",
             "sse_v2",
             "sse_v2.1",
+            "sse_v3",
         ):
             helper = self.new_helper(inferred_fn_index)
             end_to_end_fn = self.endpoints[inferred_fn_index].make_end_to_end_fn(helper)
@@ -1047,14 +1056,14 @@ def _predict(*data) -> tuple:
             result = utils.synchronize_async(
                 self._sse_fn_v0, data, hash_data, helper
             )
-        elif self.protocol in ("sse_v1", "sse_v2", "sse_v2.1"):
+        elif self.protocol in ("sse_v1", "sse_v2", "sse_v2.1", "sse_v3"):
             event_id = utils.synchronize_async(
-                self.client.send_data, data, hash_data
+                self.client.send_data, data, hash_data, self.protocol
             )
             self.client.pending_event_ids.add(event_id)
             self.client.pending_messages_per_event[event_id] = []
             result = utils.synchronize_async(
-                self._sse_fn_v1_v2, helper, event_id, self.protocol
+                self._sse_fn_v1plus, helper, event_id, self.protocol
             )
         else:
             raise ValueError(f"Unsupported protocol: {self.protocol}")
@@ -1215,13 +1224,13 @@ async def _sse_fn_v0(self, data: dict, hash_data: dict, helper: Communicator):
             self.client.ssl_verify,
         )

-    async def _sse_fn_v1_v2(
+    async def _sse_fn_v1plus(
         self,
         helper: Communicator,
         event_id: str,
-        protocol: Literal["sse_v1", "sse_v2", "sse_v2.1"],
+        protocol: Literal["sse_v1", "sse_v2", "sse_v2.1", "sse_v3"],
     ):
-        return await utils.get_pred_from_sse_v1_v2(
+        return await utils.get_pred_from_sse_v1plus(
             helper,
             self.client.headers,
             self.client.cookies,
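Note: condensed, the new close logic in stream_messages reduces to one decision per incoming message. This standalone sketch (should_close is a made-up helper for illustration, not part of the patch) shows why an empty pending-event set no longer closes an sse_v3 stream:

    def should_close(msg: str, pending_event_ids: set, protocol: str) -> bool:
        if msg == "close_stream":
            return True  # sse_v3: explicit close from the backend
        if msg == "process_completed" and not pending_event_ids:
            # older protocols close once the client's own work is done
            return protocol != "sse_v3"
        return False

    assert should_close("close_stream", set(), "sse_v3")
    assert not should_close("process_completed", set(), "sse_v3")
    assert should_close("process_completed", set(), "sse_v2.1")
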
diff --git a/client/python/gradio_client/utils.py b/client/python/gradio_client/utils.py
index eb207e595c68..575cb2a37e11 100644
--- a/client/python/gradio_client/utils.py
+++ b/client/python/gradio_client/utils.py
@@ -115,6 +115,7 @@ class ServerMessage(str, Enum):
     heartbeat = "heartbeat"
     server_stopped = "server_stopped"
     unexpected_error = "unexpected_error"
+    close_stream = "close_stream"


 class Status(Enum):
@@ -386,7 +387,7 @@ async def get_pred_from_sse_v0(
     return task.result()


-async def get_pred_from_sse_v1_v2(
+async def get_pred_from_sse_v1plus(
     helper: Communicator,
     headers: dict[str, str],
     cookies: dict[str, str] | None,
@@ -399,7 +400,9 @@
         [
             asyncio.create_task(check_for_cancel(helper, headers, cookies, ssl_verify)),
             asyncio.create_task(
-                stream_sse_v1_v2(helper, pending_messages_per_event, event_id, protocol)
+                stream_sse_v1plus(
+                    helper, pending_messages_per_event, event_id, protocol
+                )
             ),
         ],
         return_when=asyncio.FIRST_COMPLETED,
@@ -512,11 +515,11 @@ async def stream_sse_v0(
         raise


-async def stream_sse_v1_v2(
+async def stream_sse_v1plus(
     helper: Communicator,
     pending_messages_per_event: dict[str, list[Message | None]],
     event_id: str,
-    protocol: Literal["sse_v1", "sse_v2", "sse_v2.1"],
+    protocol: Literal["sse_v1", "sse_v2", "sse_v2.1", "sse_v3"],
 ) -> dict[str, Any]:
     try:
         pending_messages = pending_messages_per_event[event_id]
@@ -555,6 +558,7 @@ async def stream_sse_v1_v2(
             if msg["msg"] == ServerMessage.process_generating and protocol in [
                 "sse_v2",
                 "sse_v2.1",
+                "sse_v3",
             ]:
                 if pending_responses_for_diffs is None:
                     pending_responses_for_diffs = list(output)

diff --git a/gradio/blocks.py b/gradio/blocks.py
index cc23078f3d9e..302b814b7e7c 100644
--- a/gradio/blocks.py
+++ b/gradio/blocks.py
@@ -1782,7 +1782,7 @@ def get_config_file(self):
             "is_colab": utils.colab_check(),
             "stylesheets": self.stylesheets,
             "theme": self.theme.name,
-            "protocol": "sse_v2.1",
+            "protocol": "sse_v3",
             "body_css": {
                 "body_background_fill": self.theme._get_computed_value(
                     "body_background_fill"

diff --git a/gradio/routes.py b/gradio/routes.py
index 78557f6849dc..76358ebf6651 100644
--- a/gradio/routes.py
+++ b/gradio/routes.py
@@ -79,6 +79,7 @@
     move_uploaded_files_to_cache,
 )
 from gradio.server_messages import (
+    CloseStreamMessage,
     EstimationMessage,
     EventMessage,
     HeartbeatMessage,
@@ -792,6 +793,10 @@ async def sse_stream(request: fastapi.Request):
                                 == 0
                             )
                         ):
+                            message = CloseStreamMessage()
+                            response = process_msg(message)
+                            if response is not None:
+                                yield response
                             return
             except BaseException as e:
                 message = UnexpectedErrorMessage(

diff --git a/gradio/server_messages.py b/gradio/server_messages.py
index 28fc5ad42096..be3c729774c4 100644
--- a/gradio/server_messages.py
+++ b/gradio/server_messages.py
@@ -61,6 +61,10 @@ class HeartbeatMessage(BaseModel):
     msg: Literal[ServerMessage.heartbeat] = ServerMessage.heartbeat


+class CloseStreamMessage(BaseModel):
+    msg: Literal[ServerMessage.close_stream] = ServerMessage.close_stream
+
+
 class UnexpectedErrorMessage(BaseModel):
     msg: Literal[ServerMessage.unexpected_error] = ServerMessage.unexpected_error
     message: str
@@ -76,4 +80,5 @@ class UnexpectedErrorMessage(BaseModel):
     ProcessGeneratingMessage,
     HeartbeatMessage,
     UnexpectedErrorMessage,
+    CloseStreamMessage,
 ]
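Note: on the wire, CloseStreamMessage is just another pydantic model serialized into an SSE frame. A self-contained sketch of the framing (the to_sse helper is an assumption for illustration; in gradio this path runs through process_msg in routes.py):

    from enum import Enum
    from typing import Literal

    from pydantic import BaseModel

    class ServerMessage(str, Enum):
        close_stream = "close_stream"

    class CloseStreamMessage(BaseModel):
        msg: Literal[ServerMessage.close_stream] = ServerMessage.close_stream

    def to_sse(model: BaseModel) -> str:
        # an SSE frame is a "data: <json>" line followed by a blank line
        return f"data: {model.model_dump_json()}\n\n"

    print(to_sse(CloseStreamMessage()), end="")
    # data: {"msg":"close_stream"}
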
diff --git a/js/app/src/Blocks.svelte b/js/app/src/Blocks.svelte
index 7d65f6428537..bd9400a55b1e 100644
--- a/js/app/src/Blocks.svelte
+++ b/js/app/src/Blocks.svelte
@@ -123,6 +123,8 @@
     }
   });
   update_value(updates);
+
+  await tick();
 }

 let submit_map: Map> = new Map();
@@ -263,88 +265,88 @@
         }
         dep.pending_request = false;
         handle_update(data, fn_index);
+        set_status($loading_status);
       })
       .on("status", ({ fn_index, ...status }) => {
-        requestAnimationFrame(() => {
-          //@ts-ignore
-          loading_status.update({
-            ...status,
-            status: status.stage,
-            progress: status.progress_data,
-            fn_index
+        //@ts-ignore
+        loading_status.update({
+          ...status,
+          status: status.stage,
+          progress: status.progress_data,
+          fn_index
+        });
+        set_status($loading_status);
+        if (
+          !showed_duplicate_message &&
+          space_id !== null &&
+          status.position !== undefined &&
+          status.position >= 2 &&
+          status.eta !== undefined &&
+          status.eta > SHOW_DUPLICATE_MESSAGE_ON_ETA
+        ) {
+          showed_duplicate_message = true;
+          messages = [
+            new_message(DUPLICATE_MESSAGE, fn_index, "warning"),
+            ...messages
+          ];
+        }
+        if (
+          !showed_mobile_warning &&
+          is_mobile_device &&
+          status.eta !== undefined &&
+          status.eta > SHOW_MOBILE_QUEUE_WARNING_ON_ETA
+        ) {
+          showed_mobile_warning = true;
+          messages = [
+            new_message(MOBILE_QUEUE_WARNING, fn_index, "warning"),
+            ...messages
+          ];
+        }
+
+        if (status.stage === "complete") {
+          dependencies.map(async (dep, i) => {
+            if (dep.trigger_after === fn_index) {
+              wait_then_trigger_api_call(i, payload.trigger_id);
+            }
           });
-          if (
-            !showed_duplicate_message &&
-            space_id !== null &&
-            status.position !== undefined &&
-            status.position >= 2 &&
-            status.eta !== undefined &&
-            status.eta > SHOW_DUPLICATE_MESSAGE_ON_ETA
-          ) {
-            showed_duplicate_message = true;
+
+          submission.destroy();
+        }
+        if (status.broken && is_mobile_device && user_left_page) {
+          window.setTimeout(() => {
             messages = [
-              new_message(DUPLICATE_MESSAGE, fn_index, "warning"),
+              new_message(MOBILE_RECONNECT_MESSAGE, fn_index, "error"),
               ...messages
             ];
-          }
-          if (
-            !showed_mobile_warning &&
-            is_mobile_device &&
-            status.eta !== undefined &&
-            status.eta > SHOW_MOBILE_QUEUE_WARNING_ON_ETA
-          ) {
-            showed_mobile_warning = true;
+          }, 0);
+          wait_then_trigger_api_call(
+            dep_index,
+            payload.trigger_id,
+            event_data
+          );
+          user_left_page = false;
+        } else if (status.stage === "error") {
+          if (status.message) {
+            const _message = status.message.replace(
+              MESSAGE_QUOTE_RE,
+              (_, b) => b
+            );
             messages = [
-              new_message(MOBILE_QUEUE_WARNING, fn_index, "warning"),
+              new_message(_message, fn_index, "error"),
               ...messages
             ];
           }
-
-          if (status.stage === "complete") {
-            dependencies.map(async (dep, i) => {
-              if (dep.trigger_after === fn_index) {
-                wait_then_trigger_api_call(i, payload.trigger_id);
-              }
-            });
-
-            submission.destroy();
-          }
-          if (status.broken && is_mobile_device && user_left_page) {
-            window.setTimeout(() => {
-              messages = [
-                new_message(MOBILE_RECONNECT_MESSAGE, fn_index, "error"),
-                ...messages
-              ];
-            }, 0);
-            wait_then_trigger_api_call(
-              dep_index,
-              payload.trigger_id,
-              event_data
-            );
-            user_left_page = false;
-          } else if (status.stage === "error") {
-            if (status.message) {
-              const _message = status.message.replace(
-                MESSAGE_QUOTE_RE,
-                (_, b) => b
-              );
-              messages = [
-                new_message(_message, fn_index, "error"),
-                ...messages
-              ];
+          dependencies.map(async (dep, i) => {
+            if (
+              dep.trigger_after === fn_index &&
+              !dep.trigger_only_on_success
+            ) {
+              wait_then_trigger_api_call(i, payload.trigger_id);
             }
-            dependencies.map(async (dep, i) => {
-              if (
-                dep.trigger_after === fn_index &&
-                !dep.trigger_only_on_success
-              ) {
-                wait_then_trigger_api_call(i, payload.trigger_id);
-              }
-            });
-
-            submission.destroy();
-          }
-        });
+          });
+
+          submission.destroy();
+        }
       })
       .on("log", ({ log, fn_index, level }) => {
        messages = [new_message(log, fn_index, level), ...messages];
@@ -426,7 +428,9 @@
      const deps = $targets[id]?.[event];

      deps?.forEach((dep_id) => {
-       wait_then_trigger_api_call(dep_id, id, data);
+       requestAnimationFrame(() => {
+         wait_then_trigger_api_call(dep_id, id, data);
+       });
      });
    }
  });
@@ -455,7 +459,6 @@
  });

  const inputs_to_update = loading_status.get_inputs_to_update();
-
  const additional_updates = Array.from(inputs_to_update).map(
    ([id, pending_status]) => {
      return {
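Note: the Blocks.svelte rewrite targets the race named in the thick-penguins-act changeset: loading_status updates used to be deferred inside requestAnimationFrame, so a synchronously applied data update could land first and then be overdrawn by a stale status. A toy Python model of that ordering bug (not gradio code; defer_status stands in for the removed requestAnimationFrame wrapper, which the patch instead moves to the trigger site):

    def apply_events(events, defer_status):
        applied, deferred = [], []
        for ev in events:
            if ev.startswith("status") and defer_status:
                deferred.append(ev)   # old code: run a frame later
            else:
                applied.append(ev)    # data/state updates ran immediately
        return applied + deferred     # deferred work lands after newer updates

    print(apply_events(["status:pending", "data:42", "status:complete"], True))
    # ['data:42', 'status:pending', 'status:complete']  <- stale status last
    print(apply_events(["status:pending", "data:42", "status:complete"], False))
    # ['status:pending', 'data:42', 'status:complete']  <- event order preserved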