Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve chatbot streaming performance with diffs #7102

Merged
merged 36 commits into from Jan 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
6c2ecdc
changes
Jan 22, 2024
037ccf9
add changeset
gradio-pr-bot Jan 22, 2024
c91eca2
changes
Jan 24, 2024
30becbd
Merge branch 'diff_chatbot_streaming' of https://github.com/gradio-ap…
Jan 24, 2024
fd8de47
add changeset
gradio-pr-bot Jan 24, 2024
b641643
changes
Jan 24, 2024
5a58dcd
Merge branch 'diff_chatbot_streaming' of https://github.com/gradio-ap…
Jan 24, 2024
7fc0255
changes
Jan 24, 2024
5de09f8
Merge remote-tracking branch 'origin' into diff_chatbot_streaming
Jan 24, 2024
fe3d0e0
changes
Jan 24, 2024
9c08c18
changes
Jan 24, 2024
0e6fb0a
Merge remote-tracking branch 'origin' into diff_chatbot_streaming
Jan 24, 2024
dbcb970
changes
Jan 24, 2024
ebf3092
changes
Jan 24, 2024
25581e2
changes
Jan 24, 2024
b72129d
changes
Jan 24, 2024
d05d4c7
Merge remote-tracking branch 'origin' into diff_chatbot_streaming
Jan 24, 2024
597b7eb
changes
Jan 25, 2024
97317e6
Merge remote-tracking branch 'origin' into diff_chatbot_streaming
Jan 25, 2024
1ec1925
changes
Jan 25, 2024
0c724d5
changes
Jan 25, 2024
aedd61e
changes
Jan 25, 2024
7248242
changes
Jan 25, 2024
0a7ff17
Merge remote-tracking branch 'origin' into diff_chatbot_streaming
Jan 25, 2024
6656bd4
changes
Jan 25, 2024
3e22fb3
changes
Jan 25, 2024
44ee6f1
changes
Jan 25, 2024
6b2e192
changes
Jan 25, 2024
37275ba
changes
Jan 25, 2024
1330b3c
changes
Jan 30, 2024
ba9a877
changes
Jan 30, 2024
e4d0b57
changes
Jan 30, 2024
ea717ea
Update free-moose-guess.md
aliabid94 Jan 31, 2024
ad7f9c4
Merge branch 'main' into diff_chatbot_streaming
aliabid94 Jan 31, 2024
5bf10f6
changes
Jan 31, 2024
da3b2b2
changes
Jan 31, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
9 changes: 9 additions & 0 deletions .changeset/free-moose-guess.md
@@ -0,0 +1,9 @@
---
"@gradio/client": minor
"gradio": minor
"gradio_client": minor
---

feat:Improve chatbot streaming performance with diffs

Note that this PR changes the API format for generator functions, which would be a breaking change for any clients reading the EventStream directly
32 changes: 30 additions & 2 deletions client/js/src/client.ts
Expand Up @@ -11,7 +11,8 @@ import {
set_space_hardware,
set_space_timeout,
hardware_types,
resolve_root
resolve_root,
apply_diff
} from "./utils.js";

import type {
Expand Down Expand Up @@ -288,6 +289,7 @@ export function api_factory(
const last_status: Record<string, Status["stage"]> = {};
let stream_open = false;
let pending_stream_messages: Record<string, any[]> = {}; // Event messages may be received by the SSE stream before the initial data POST request is complete. To resolve this race condition, we store the messages in a dictionary and process them when the POST request is complete.
let pending_diff_streams: Record<string, any[][]> = {};
let event_stream: EventSource | null = null;
const event_callbacks: Record<string, () => Promise<void>> = {};
const unclosed_events: Set<string> = new Set();
Expand Down Expand Up @@ -774,7 +776,8 @@ export function api_factory(
}
}
};
} else if (protocol == "sse_v1") {
} else if (protocol == "sse_v1" || protocol == "sse_v2") {
aliabid94 marked this conversation as resolved.
Show resolved Hide resolved
// latest API format. v2 introduces sending diffs for intermediate outputs in generative functions, which makes payloads lighter.
fire_event({
type: "status",
stage: "pending",
Expand Down Expand Up @@ -867,6 +870,9 @@ export function api_factory(
endpoint: _endpoint,
fn_index
});
if (data && protocol === "sse_v2") {
apply_diff_stream(event_id!, data);
}
}
if (data) {
fire_event({
Expand Down Expand Up @@ -904,6 +910,9 @@ export function api_factory(
if (event_callbacks[event_id]) {
delete event_callbacks[event_id];
}
if (event_id in pending_diff_streams) {
delete pending_diff_streams[event_id];
}
}
} catch (e) {
console.error("Unexpected client exception", e);
Expand Down Expand Up @@ -936,6 +945,25 @@ export function api_factory(
}
);

function apply_diff_stream(event_id: string, data: any): void {
	// The first generation message for an event carries full output values:
	// seed the pending stream with a copy of them and return.
	if (!pending_diff_streams[event_id]) {
		pending_diff_streams[event_id] = data.data.map((value: any) => value);
		return;
	}
	// Every later message carries diffs. Reconstruct each output from the
	// previously stored value, remember it for the next round, and write the
	// full value back into the message in place so downstream consumers see
	// complete data rather than a diff.
	data.data.forEach((diff: any, index: number) => {
		const reconstructed = apply_diff(
			pending_diff_streams[event_id][index],
			diff
		);
		pending_diff_streams[event_id][index] = reconstructed;
		data.data[index] = reconstructed;
	});
}

function fire_event<K extends EventType>(event: Event<K>): void {
const narrowed_listener_map: ListenerMap<K> = listener_map;
const listeners = narrowed_listener_map[event.type] || [];
Expand Down
2 changes: 1 addition & 1 deletion client/js/src/types.ts
Expand Up @@ -20,7 +20,7 @@ export interface Config {
show_api: boolean;
stylesheets: string[];
path: string;
protocol?: "sse_v1" | "sse" | "ws";
protocol?: "sse_v2" | "sse_v1" | "sse" | "ws";
}

export interface Payload {
Expand Down
59 changes: 59 additions & 0 deletions client/js/src/utils.ts
Expand Up @@ -239,3 +239,62 @@ export const hardware_types = [
"a10g-large",
"a100-large"
] as const;

/**
 * Apply a single edit operation to `target` at the location given by `path`.
 *
 * Supported actions: "replace" (overwrite), "append" (`+`/`+=`, i.e. string
 * concatenation or numeric addition), "add" (insert into an array / set a key
 * on an object) and "delete" (remove an array element / object key).
 * An empty `path` edits the target itself, in which case only "replace" and
 * "append" make sense and the new value is returned; otherwise the (mutated)
 * original target is returned.
 *
 * @throws Error for an unrecognized action.
 */
function apply_edit(
	target: any,
	path: (number | string)[],
	action: string,
	value: any
): any {
	// Root-level edit: no container to index into, operate on target directly.
	if (path.length === 0) {
		if (action === "replace") {
			return value;
		}
		if (action === "append") {
			return target + value;
		}
		throw new Error(`Unsupported action: ${action}`);
	}

	// Walk down to the parent container of the final path element.
	let parent = target;
	for (const step of path.slice(0, -1)) {
		parent = parent[step];
	}
	const leaf = path[path.length - 1];

	if (action === "replace") {
		parent[leaf] = value;
	} else if (action === "append") {
		parent[leaf] += value;
	} else if (action === "add") {
		if (Array.isArray(parent)) {
			parent.splice(Number(leaf), 0, value);
		} else {
			parent[leaf] = value;
		}
	} else if (action === "delete") {
		if (Array.isArray(parent)) {
			parent.splice(Number(leaf), 1);
		} else {
			delete parent[leaf];
		}
	} else {
		throw new Error(`Unknown action: ${action}`);
	}
	return target;
}

/**
 * Reconstruct a full value from a diff: fold every `[action, path, value]`
 * edit into `obj` in order. Root-level edits may replace the accumulator
 * entirely, which is why the result of each `apply_edit` is carried forward.
 */
export function apply_diff(
	obj: any,
	diff: [string, (number | string)[], any][]
): any {
	return diff.reduce(
		(acc: any, [action, path, value]) => apply_edit(acc, path, action, value),
		obj
	);
}
20 changes: 15 additions & 5 deletions client/python/gradio_client/client.py
Expand Up @@ -428,7 +428,12 @@ def submit(
inferred_fn_index = self._infer_fn_index(api_name, fn_index)

helper = None
if self.endpoints[inferred_fn_index].protocol in ("ws", "sse", "sse_v1"):
if self.endpoints[inferred_fn_index].protocol in (
"ws",
"sse",
"sse_v1",
"sse_v2",
):
helper = self.new_helper(inferred_fn_index)
end_to_end_fn = self.endpoints[inferred_fn_index].make_end_to_end_fn(helper)
future = self.executor.submit(end_to_end_fn, *args)
Expand Down Expand Up @@ -998,13 +1003,15 @@ def _predict(*data) -> tuple:
result = utils.synchronize_async(
self._sse_fn_v0, data, hash_data, helper
)
elif self.protocol == "sse_v1":
elif self.protocol == "sse_v1" or self.protocol == "sse_v2":
event_id = utils.synchronize_async(
self.client.send_data, data, hash_data
)
self.client.pending_event_ids.add(event_id)
self.client.pending_messages_per_event[event_id] = []
result = utils.synchronize_async(self._sse_fn_v1, helper, event_id)
result = utils.synchronize_async(
self._sse_fn_v1_v2, helper, event_id, self.protocol
)
else:
raise ValueError(f"Unsupported protocol: {self.protocol}")

Expand Down Expand Up @@ -1197,13 +1204,16 @@ async def _sse_fn_v0(self, data: dict, hash_data: dict, helper: Communicator):
self.client.cookies,
)

async def _sse_fn_v1(self, helper: Communicator, event_id: str):
return await utils.get_pred_from_sse_v1(
async def _sse_fn_v1_v2(
self, helper: Communicator, event_id: str, protocol: Literal["sse_v1", "sse_v2"]
):
return await utils.get_pred_from_sse_v1_v2(
helper,
self.client.headers,
self.client.cookies,
self.client.pending_messages_per_event,
event_id,
protocol,
)


Expand Down
74 changes: 66 additions & 8 deletions client/python/gradio_client/utils.py
Expand Up @@ -2,6 +2,7 @@

import asyncio
import base64
import copy
import hashlib
import json
import mimetypes
Expand All @@ -17,7 +18,7 @@
from enum import Enum
from pathlib import Path
from threading import Lock
from typing import Any, Callable, Optional, TypedDict
from typing import Any, Callable, Literal, Optional, TypedDict

import fsspec.asyn
import httpx
Expand Down Expand Up @@ -381,22 +382,19 @@ async def get_pred_from_sse_v0(
return task.result()


async def get_pred_from_sse_v1(
async def get_pred_from_sse_v1_v2(
helper: Communicator,
headers: dict[str, str],
cookies: dict[str, str] | None,
pending_messages_per_event: dict[str, list[Message | None]],
event_id: str,
protocol: Literal["sse_v1", "sse_v2"],
) -> dict[str, Any] | None:
done, pending = await asyncio.wait(
[
asyncio.create_task(check_for_cancel(helper, headers, cookies)),
asyncio.create_task(
stream_sse_v1(
helper,
pending_messages_per_event,
event_id,
)
stream_sse_v1_v2(helper, pending_messages_per_event, event_id, protocol)
),
],
return_when=asyncio.FIRST_COMPLETED,
Expand All @@ -411,6 +409,9 @@ async def get_pred_from_sse_v1(

assert len(done) == 1
for task in done:
exception = task.exception()
if exception:
raise exception
return task.result()


Expand Down Expand Up @@ -502,13 +503,15 @@ async def stream_sse_v0(
raise


async def stream_sse_v1(
async def stream_sse_v1_v2(
helper: Communicator,
pending_messages_per_event: dict[str, list[Message | None]],
event_id: str,
protocol: Literal["sse_v1", "sse_v2"],
) -> dict[str, Any]:
try:
pending_messages = pending_messages_per_event[event_id]
pending_responses_for_diffs = None

while True:
if len(pending_messages) > 0:
Expand Down Expand Up @@ -540,6 +543,19 @@ async def stream_sse_v1(
log=log_message,
)
output = msg.get("output", {}).get("data", [])
if (
msg["msg"] == ServerMessage.process_generating
and protocol == "sse_v2"
):
if pending_responses_for_diffs is None:
pending_responses_for_diffs = list(output)
else:
for i, value in enumerate(output):
prev_output = pending_responses_for_diffs[i]
new_output = apply_diff(prev_output, value)
pending_responses_for_diffs[i] = new_output
output[i] = new_output

if output and status_update.code != Status.FINISHED:
try:
result = helper.prediction_processor(*output)
Expand All @@ -557,6 +573,48 @@ async def stream_sse_v1(
raise


def apply_diff(obj, diff):
    """Reconstruct a full output value by applying a list of diff edits.

    Each edit is an ``(action, path, value)`` triple where ``action`` is one
    of ``"replace"``, ``"append"``, ``"add"`` or ``"delete"``; ``path`` is a
    sequence of keys/indices locating the edit target inside ``obj``; and
    ``value`` is the payload for the action. The input object is deep-copied
    first, so the caller's value is never mutated. An empty path edits the
    object itself (only "replace" and "append" apply there).

    Raises:
        ValueError: if an edit carries an unrecognized action.
    """
    obj = copy.deepcopy(obj)

    def apply_edit(target, path, action, value):
        # Root-level edit: no container to index into.
        if not path:
            if action == "replace":
                return value
            if action == "append":
                return target + value
            raise ValueError(f"Unsupported action: {action}")

        # Walk down to the parent container of the final path element.
        parent = target
        for key in path[:-1]:
            parent = parent[key]
        leaf = path[-1]

        if action == "replace":
            parent[leaf] = value
        elif action == "append":
            parent[leaf] += value
        elif action == "add":
            # Insert into lists; assign the key for dict-like containers.
            if isinstance(parent, list):
                parent.insert(int(leaf), value)
            else:
                parent[leaf] = value
        elif action == "delete":
            if isinstance(parent, list):
                del parent[int(leaf)]
            else:
                del parent[leaf]
        else:
            raise ValueError(f"Unknown action: {action}")

        return target

    for action, path, value in diff:
        obj = apply_edit(obj, path, action, value)

    return obj


########################
# Data processing utils
########################
Expand Down