
Commit

Refactor queue so that there are separate queues for each concurrency id (#6814)

* change

* changes

* add changeset

* add changeset

* changes

* changes

* changes

* changes

---------

Co-authored-by: Ali Abid <ubuntu@ip-172-31-25-241.us-west-2.compute.internal>
Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
Co-authored-by: Abubakar Abid <abubakar@huggingface.co>
4 people committed Dec 19, 2023
1 parent 73268ee commit 828fb9e
Showing 7 changed files with 261 additions and 190 deletions.
7 changes: 7 additions & 0 deletions .changeset/eighty-teeth-greet.md
@@ -0,0 +1,7 @@
+---
+"@gradio/client": patch
+"@gradio/statustracker": patch
+"gradio": patch
+---
+
+feat:Refactor queue so that there are separate queues for each concurrency id
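To make the intent of "separate queues for each concurrency id" concrete, here is a minimal TypeScript sketch of the queueing model, assuming a simplified event type; the `ConcurrencyQueues` class and its method names are illustrative, not Gradio's actual implementation. Events sharing a `concurrency_id` wait in the same FIFO queue, and each queue drains independently up to its own limit, so a slow endpoint no longer blocks unrelated events.

```ts
interface QueuedEvent {
	concurrency_id: string;
	run: () => Promise<void>;
}

class ConcurrencyQueues {
	private queues = new Map<string, QueuedEvent[]>(); // one FIFO queue per concurrency_id
	private active = new Map<string, number>(); // events currently running per id

	constructor(private limits: Map<string, number>) {}

	push(event: QueuedEvent): void {
		const queue = this.queues.get(event.concurrency_id) ?? [];
		queue.push(event);
		this.queues.set(event.concurrency_id, queue);
		this.drain(event.concurrency_id);
	}

	private drain(id: string): void {
		const queue = this.queues.get(id) ?? [];
		const limit = this.limits.get(id) ?? 1;
		while (queue.length > 0 && (this.active.get(id) ?? 0) < limit) {
			const event = queue.shift()!;
			this.active.set(id, (this.active.get(id) ?? 0) + 1);
			event.run().finally(() => {
				this.active.set(id, (this.active.get(id) ?? 0) - 1);
				this.drain(id); // a slot freed up; start the next queued event
			});
		}
	}
}
```

Note that, per the `gradio/blocks.py` hunk below, each event handler defaults to its own id (`concurrency_id or str(id(fn))`), so handlers only ever compete with themselves unless an explicit shared id is given.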
29 changes: 21 additions & 8 deletions client/js/src/client.ts
@@ -287,6 +287,7 @@ export function api_factory(
const session_hash = Math.random().toString(36).substring(2);
const last_status: Record<string, Status["stage"]> = {};
let stream_open = false;
+let pending_stream_messages: Record<string, any[]> = {}; // Event messages may be received by the SSE stream before the initial data POST request is complete. To resolve this race condition, we store the messages in a dictionary and process them when the POST request is complete.
let event_stream: EventSource | null = null;
const event_callbacks: Record<string, () => Promise<void>> = {};
let config: Config;
@@ -908,8 +909,8 @@
}

if (
-status.stage === "complete" ||
-status.stage === "error"
+status?.stage === "complete" ||
+status?.stage === "error"
) {
if (event_callbacks[event_id]) {
delete event_callbacks[event_id];
@@ -932,6 +933,12 @@
close_stream();
}
};
+if (event_id in pending_stream_messages) {
+pending_stream_messages[event_id].forEach((msg) =>
+callback(msg)
+);
+delete pending_stream_messages[event_id];
+}
event_callbacks[event_id] = callback;
if (!stream_open) {
open_stream();
@@ -1051,15 +1058,21 @@
event_stream = new EventSource(url);
event_stream.onmessage = async function (event) {
let _data = JSON.parse(event.data);
if (!("event_id" in _data)) {
const event_id = _data.event_id;
if (!event_id) {
await Promise.all(
Object.keys(event_callbacks).map((event_id) =>
event_callbacks[event_id](_data)
)
);
return;
} else if (event_callbacks[event_id]) {
await event_callbacks[event_id](_data);
} else {
if (!pending_stream_messages[event_id]) {
pending_stream_messages[event_id] = [];
}
pending_stream_messages[event_id].push(_data);
}
await event_callbacks[_data.event_id](_data);
};
}
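The two hunks above implement a buffer-and-replay pattern for the race described in the `pending_stream_messages` comment: SSE messages that arrive before the POST response has registered a callback are parked, then replayed on registration. A condensed, self-contained sketch of that pattern, with illustrative names rather than the client's exact code, looks roughly like this:

```ts
type Message = { event_id?: string };

const event_callbacks: Record<string, (msg: Message) => Promise<void>> = {};
const pending_stream_messages: Record<string, Message[]> = {};

// Invoked from the SSE onmessage handler for messages carrying an event_id.
async function dispatch(msg: Message): Promise<void> {
	if (!msg.event_id) return; // id-less messages are broadcast to all callbacks instead
	const callback = event_callbacks[msg.event_id];
	if (callback) {
		await callback(msg); // normal path: the POST already registered a callback
	} else {
		// Race: the stream beat the POST response, so park the message until
		// register_callback runs for this event_id.
		(pending_stream_messages[msg.event_id] ??= []).push(msg);
	}
}

// Invoked once the POST request completes and the event_id is known.
function register_callback(
	event_id: string,
	callback: (msg: Message) => Promise<void>
): void {
	for (const msg of pending_stream_messages[event_id] ?? []) {
		void callback(msg); // replay anything that arrived early, in order
	}
	delete pending_stream_messages[event_id];
	event_callbacks[event_id] = callback;
}
```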

@@ -1701,8 +1714,7 @@ function handle_message(
message: !data.success ? data.output.error : undefined,
stage: data.success ? "complete" : "error",
code: data.code,
-progress_data: data.progress_data,
-eta: data.output.average_duration
+progress_data: data.progress_data
},
data: data.success ? data.output : null
};
@@ -1716,7 +1728,8 @@
code: data.code,
size: data.rank,
position: 0,
-success: data.success
+success: data.success,
+eta: data.eta
}
};
}
2 changes: 1 addition & 1 deletion gradio/blocks.py
@@ -376,7 +376,7 @@ def __init__(
self.preprocess = preprocess
self.postprocess = postprocess
self.tracks_progress = tracks_progress
-self.concurrency_limit = concurrency_limit
+self.concurrency_limit: int | None | Literal["default"] = concurrency_limit
self.concurrency_id = concurrency_id or str(id(fn))
self.batch = batch
self.max_batch_size = max_batch_size
3 changes: 0 additions & 3 deletions gradio/data_classes.py
@@ -101,10 +101,7 @@ class InterfaceTypes(Enum):
class Estimation(BaseModel):
rank: Optional[int] = None
queue_size: int
-avg_event_process_time: Optional[float] = None
-avg_event_concurrent_process_time: Optional[float] = None
rank_eta: Optional[float] = None
-queue_eta: float


class ProgressUnit(BaseModel):
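Since the JS client consumes this model over the wire, the remaining fields correspond to a payload shape like the following; the interface itself is illustrative, with field names taken from the Pydantic model above.

```ts
interface Estimation {
	rank?: number | null; // position of the event in its queue
	queue_size: number; // number of events currently queued
	rank_eta?: number | null; // estimated wait, if the server can compute one
}
```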
