Skip to content

Commit

Permalink
Closing stream from the backend (#7691)
Browse files Browse the repository at this point in the history
* changes

* add changeset

* changes

* changes

* changes

* changes

* maybe fix (#7715)

* maybe fix

* fix with demo

* changeset

* fix demo

* changeset

* changes

---------

Co-authored-by: Ali Abid <aliabid94@gmail.com>
Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
Co-authored-by: pngwn <hello@pngwn.io>
Co-authored-by: Abubakar Abid <abubakar@huggingface.co>
  • Loading branch information
5 people committed Mar 18, 2024
1 parent 7f9b291 commit 84f81fe
Show file tree
Hide file tree
Showing 9 changed files with 141 additions and 92 deletions.
7 changes: 7 additions & 0 deletions .changeset/curly-parents-dream.md
@@ -0,0 +1,7 @@
---
"@gradio/client": minor
"gradio": minor
"gradio_client": minor
---

feat: Closing stream from the backend
6 changes: 6 additions & 0 deletions .changeset/thick-penguins-act.md
@@ -0,0 +1,6 @@
---
"gradio": patch
"@gradio/app": patch
---

fix: Fix race condition between state updates and loading_status updates.
18 changes: 14 additions & 4 deletions client/js/src/client.ts
Expand Up @@ -759,9 +759,11 @@ export function api_factory(
} else if (
protocol == "sse_v1" ||
protocol == "sse_v2" ||
protocol == "sse_v2.1"
protocol == "sse_v2.1" ||
protocol == "sse_v3"
) {
// latest API format. v2 introduces sending diffs for intermediate outputs in generative functions, which makes payloads lighter.
// v3 only closes the stream when the backend sends the close stream message.
fire_event({
type: "status",
stage: "pending",
Expand Down Expand Up @@ -856,7 +858,7 @@ export function api_factory(
});
if (
data &&
(protocol === "sse_v2" || protocol === "sse_v2.1")
["sse_v2", "sse_v2.1", "sse_v3"].includes(protocol)
) {
apply_diff_stream(event_id!, data);
}
Expand Down Expand Up @@ -905,7 +907,9 @@ export function api_factory(
fn_index,
time: new Date()
});
close_stream();
if (["sse_v2", "sse_v2.1"].includes(protocol)) {
close_stream();
}
}
};
if (event_id in pending_stream_messages) {
Expand Down Expand Up @@ -1049,7 +1053,10 @@ export function api_factory(
)
);
} else if (event_callbacks[event_id]) {
if (_data.msg === "process_completed") {
if (
_data.msg === "process_completed" &&
["sse", "sse_v1", "sse_v2", "sse_v2.1"].includes(config.protocol)
) {
unclosed_events.delete(event_id);
if (unclosed_events.size === 0) {
close_stream();
Expand All @@ -1063,6 +1070,9 @@ export function api_factory(
}
pending_stream_messages[event_id].push(_data);
}
if (_data.msg === "close_stream") {
close_stream();
}
};
event_stream.onerror = async function (event) {
await Promise.all(
Expand Down
29 changes: 19 additions & 10 deletions client/python/gradio_client/client.py
Expand Up @@ -190,7 +190,9 @@ def __init__(
self.pending_messages_per_event: dict[str, list[Message | None]] = {}
self.pending_event_ids: set[str] = set()

async def stream_messages(self) -> None:
async def stream_messages(
self, protocol: Literal["sse_v1", "sse_v2", "sse_v2.1", "sse_v3"]
) -> None:
try:
async with httpx.AsyncClient(
timeout=httpx.Timeout(timeout=None), verify=self.ssl_verify
Expand All @@ -216,13 +218,19 @@ async def stream_messages(self) -> None:
) in self.pending_messages_per_event.values():
pending_messages.append(resp)
return
elif resp["msg"] == ServerMessage.close_stream:
self.stream_open = False
return
event_id = resp["event_id"]
if event_id not in self.pending_messages_per_event:
self.pending_messages_per_event[event_id] = []
self.pending_messages_per_event[event_id].append(resp)
if resp["msg"] == ServerMessage.process_completed:
self.pending_event_ids.remove(event_id)
if len(self.pending_event_ids) == 0:
if (
len(self.pending_event_ids) == 0
and protocol != "sse_v3"
):
self.stream_open = False
return
else:
Expand All @@ -233,7 +241,7 @@ async def stream_messages(self) -> None:
traceback.print_exc()
raise e

async def send_data(self, data, hash_data):
async def send_data(self, data, hash_data, protocol):
async with httpx.AsyncClient(verify=self.ssl_verify) as client:
req = await client.post(
self.sse_data_url,
Expand All @@ -251,7 +259,7 @@ async def send_data(self, data, hash_data):
self.stream_open = True

def open_stream():
return utils.synchronize_async(self.stream_messages)
return utils.synchronize_async(self.stream_messages, protocol)

def close_stream(_):
self.stream_open = False
Expand Down Expand Up @@ -458,6 +466,7 @@ def submit(
"sse_v1",
"sse_v2",
"sse_v2.1",
"sse_v3",
):
helper = self.new_helper(inferred_fn_index)
end_to_end_fn = self.endpoints[inferred_fn_index].make_end_to_end_fn(helper)
Expand Down Expand Up @@ -1047,14 +1056,14 @@ def _predict(*data) -> tuple:
result = utils.synchronize_async(
self._sse_fn_v0, data, hash_data, helper
)
elif self.protocol in ("sse_v1", "sse_v2", "sse_v2.1"):
elif self.protocol in ("sse_v1", "sse_v2", "sse_v2.1", "sse_v3"):
event_id = utils.synchronize_async(
self.client.send_data, data, hash_data
self.client.send_data, data, hash_data, self.protocol
)
self.client.pending_event_ids.add(event_id)
self.client.pending_messages_per_event[event_id] = []
result = utils.synchronize_async(
self._sse_fn_v1_v2, helper, event_id, self.protocol
self._sse_fn_v1plus, helper, event_id, self.protocol
)
else:
raise ValueError(f"Unsupported protocol: {self.protocol}")
Expand Down Expand Up @@ -1215,13 +1224,13 @@ async def _sse_fn_v0(self, data: dict, hash_data: dict, helper: Communicator):
self.client.ssl_verify,
)

async def _sse_fn_v1_v2(
async def _sse_fn_v1plus(
self,
helper: Communicator,
event_id: str,
protocol: Literal["sse_v1", "sse_v2", "sse_v2.1"],
protocol: Literal["sse_v1", "sse_v2", "sse_v2.1", "sse_v3"],
):
return await utils.get_pred_from_sse_v1_v2(
return await utils.get_pred_from_sse_v1plus(
helper,
self.client.headers,
self.client.cookies,
Expand Down
12 changes: 8 additions & 4 deletions client/python/gradio_client/utils.py
Expand Up @@ -115,6 +115,7 @@ class ServerMessage(str, Enum):
heartbeat = "heartbeat"
server_stopped = "server_stopped"
unexpected_error = "unexpected_error"
close_stream = "close_stream"


class Status(Enum):
Expand Down Expand Up @@ -386,7 +387,7 @@ async def get_pred_from_sse_v0(
return task.result()


async def get_pred_from_sse_v1_v2(
async def get_pred_from_sse_v1plus(
helper: Communicator,
headers: dict[str, str],
cookies: dict[str, str] | None,
Expand All @@ -399,7 +400,9 @@ async def get_pred_from_sse_v1_v2(
[
asyncio.create_task(check_for_cancel(helper, headers, cookies, ssl_verify)),
asyncio.create_task(
stream_sse_v1_v2(helper, pending_messages_per_event, event_id, protocol)
stream_sse_v1plus(
helper, pending_messages_per_event, event_id, protocol
)
),
],
return_when=asyncio.FIRST_COMPLETED,
Expand Down Expand Up @@ -512,11 +515,11 @@ async def stream_sse_v0(
raise


async def stream_sse_v1_v2(
async def stream_sse_v1plus(
helper: Communicator,
pending_messages_per_event: dict[str, list[Message | None]],
event_id: str,
protocol: Literal["sse_v1", "sse_v2", "sse_v2.1"],
protocol: Literal["sse_v1", "sse_v2", "sse_v2.1", "sse_v3"],
) -> dict[str, Any]:
try:
pending_messages = pending_messages_per_event[event_id]
Expand Down Expand Up @@ -555,6 +558,7 @@ async def stream_sse_v1_v2(
if msg["msg"] == ServerMessage.process_generating and protocol in [
"sse_v2",
"sse_v2.1",
"sse_v3",
]:
if pending_responses_for_diffs is None:
pending_responses_for_diffs = list(output)
Expand Down
2 changes: 1 addition & 1 deletion gradio/blocks.py
Expand Up @@ -1782,7 +1782,7 @@ def get_config_file(self):
"is_colab": utils.colab_check(),
"stylesheets": self.stylesheets,
"theme": self.theme.name,
"protocol": "sse_v2.1",
"protocol": "sse_v3",
"body_css": {
"body_background_fill": self.theme._get_computed_value(
"body_background_fill"
Expand Down
5 changes: 5 additions & 0 deletions gradio/routes.py
Expand Up @@ -79,6 +79,7 @@
move_uploaded_files_to_cache,
)
from gradio.server_messages import (
CloseStreamMessage,
EstimationMessage,
EventMessage,
HeartbeatMessage,
Expand Down Expand Up @@ -792,6 +793,10 @@ async def sse_stream(request: fastapi.Request):
== 0
)
):
message = CloseStreamMessage()
response = process_msg(message)
if response is not None:
yield response
return
except BaseException as e:
message = UnexpectedErrorMessage(
Expand Down
5 changes: 5 additions & 0 deletions gradio/server_messages.py
Expand Up @@ -61,6 +61,10 @@ class HeartbeatMessage(BaseModel):
msg: Literal[ServerMessage.heartbeat] = ServerMessage.heartbeat


class CloseStreamMessage(BaseModel):
    """Server-sent message telling the client to close the SSE event stream.

    Introduced with the ``sse_v3`` protocol: instead of the client closing the
    stream itself once its pending events complete, the backend emits this
    message (see ``sse_stream`` in routes.py) and both the JS and Python
    clients close the stream when they receive it.
    """

    # Discriminator field; always ServerMessage.close_stream ("close_stream").
    msg: Literal[ServerMessage.close_stream] = ServerMessage.close_stream


class UnexpectedErrorMessage(BaseModel):
msg: Literal[ServerMessage.unexpected_error] = ServerMessage.unexpected_error
message: str
Expand All @@ -76,4 +80,5 @@ class UnexpectedErrorMessage(BaseModel):
ProcessGeneratingMessage,
HeartbeatMessage,
UnexpectedErrorMessage,
CloseStreamMessage,
]

0 comments on commit 84f81fe

Please sign in to comment.