Skip to content

Commit

Permalink
Fix api event drops (#6556)
Browse files Browse the repository at this point in the history
* changes

* changes

* add changeset

* changes

* changes

* changes

* changs

* chagnes

* changes

* changes

* changes

* changes

* changes

* changes

* changes

* changes

* changes

* changes~git push

* changes

* changes

* changes

* changes

* changes

* changes

* changes

* changes

* changes

* changes

* changes

* changes

* change

* changes

* changes

* changes

* changes

---------

Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
Co-authored-by: Ubuntu <ubuntu@ip-172-31-25-241.us-west-2.compute.internal>
Co-authored-by: Abubakar Abid <abubakar@huggingface.co>
  • Loading branch information
4 people committed Dec 12, 2023
1 parent 67ddd40 commit d76bcaa
Show file tree
Hide file tree
Showing 14 changed files with 678 additions and 444 deletions.
7 changes: 7 additions & 0 deletions .changeset/ripe-spiders-love.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
---
"@gradio/client": patch
"gradio": patch
"gradio_client": patch
---

fix:Fix api event drops
146 changes: 144 additions & 2 deletions client/js/src/client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,9 @@ export function api_factory(

const session_hash = Math.random().toString(36).substring(2);
const last_status: Record<string, Status["stage"]> = {};
let stream_open = false;
let event_stream: EventSource | null = null;
const event_callbacks: Record<string, () => Promise<void>> = {};
let config: Config;
let api_map: Record<string, number> = {};

Expand Down Expand Up @@ -437,7 +440,7 @@ export function api_factory(

let websocket: WebSocket;
let eventSource: EventSource;
let protocol = config.protocol ?? "sse";
let protocol = config.protocol ?? "ws";

const _endpoint = typeof endpoint === "number" ? "/predict" : endpoint;
let payload: Payload;
Expand Down Expand Up @@ -646,7 +649,7 @@ export function api_factory(
websocket.send(JSON.stringify({ hash: session_hash }))
);
}
} else {
} else if (protocol == "sse") {
fire_event({
type: "status",
stage: "pending",
Expand Down Expand Up @@ -766,6 +769,121 @@ export function api_factory(
}
}
};
} else if (protocol == "sse_v1") {
fire_event({
type: "status",
stage: "pending",
queue: true,
endpoint: _endpoint,
fn_index,
time: new Date()
});

post_data(
`${http_protocol}//${resolve_root(
host,
config.path,
true
)}/queue/join?${url_params}`,
{
...payload,
session_hash
},
hf_token
).then(([response, status]) => {
if (status !== 200) {
fire_event({
type: "status",
stage: "error",
message: BROKEN_CONNECTION_MSG,
queue: true,
endpoint: _endpoint,
fn_index,
time: new Date()
});
} else {
event_id = response.event_id as string;
if (!stream_open) {
open_stream();
}

let callback = async function (_data: object): void {
const { type, status, data } = handle_message(
_data,
last_status[fn_index]
);

if (type === "update" && status && !complete) {
// call 'status' listeners
fire_event({
type: "status",
endpoint: _endpoint,
fn_index,
time: new Date(),
...status
});
} else if (type === "complete") {
complete = status;
} else if (type === "log") {
fire_event({
type: "log",
log: data.log,
level: data.level,
endpoint: _endpoint,
fn_index
});
} else if (type === "generating") {
fire_event({
type: "status",
time: new Date(),
...status,
stage: status?.stage!,
queue: true,
endpoint: _endpoint,
fn_index
});
}
if (data) {
fire_event({
type: "data",
time: new Date(),
data: transform_files
? transform_output(
data.data,
api_info,
config.root,
config.root_url
)
: data.data,
endpoint: _endpoint,
fn_index
});

if (complete) {
fire_event({
type: "status",
time: new Date(),
...complete,
stage: status?.stage!,
queue: true,
endpoint: _endpoint,
fn_index
});
}
}

if (status.stage === "complete" || status.stage === "error") {
if (event_callbacks[event_id]) {
delete event_callbacks[event_id];
if (Object.keys(event_callbacks).length === 0) {
close_stream();
}
}
}
};
event_callbacks[event_id] = callback;
}
});
}
});

Expand Down Expand Up @@ -864,6 +982,30 @@ export function api_factory(
};
}

function open_stream(): void {
stream_open = true;
let params = new URLSearchParams({
session_hash: session_hash
}).toString();
let url = new URL(
`${http_protocol}//${resolve_root(
host,
config.path,
true
)}/queue/data?${params}`
);
event_stream = new EventSource(url);
event_stream.onmessage = async function (event) {
let _data = JSON.parse(event.data);
await event_callbacks[_data.event_id](_data);
};
}

function close_stream(): void {
stream_open = false;
event_stream?.close();
}

async function component_server(
component_id: number,
fn_name: string,
Expand Down
2 changes: 1 addition & 1 deletion client/js/src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ export interface Config {
show_api: boolean;
stylesheets: string[];
path: string;
protocol?: "sse" | "ws";
protocol?: "sse_v1" | "sse" | "ws";
}

export interface Payload {
Expand Down
Loading

0 comments on commit d76bcaa

Please sign in to comment.