Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Lite: Websocket queueing #5124

Merged
merged 13 commits into from Oct 4, 2023
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
8 changes: 8 additions & 0 deletions .changeset/fresh-bees-allow.md
@@ -0,0 +1,8 @@
---
"@gradio/app": minor
"@gradio/client": minor
"@gradio/wasm": minor
"gradio": minor
---

feat:Lite: Websocket queueing
13 changes: 9 additions & 4 deletions client/js/src/client.ts
Expand Up @@ -165,7 +165,10 @@ interface Client {
) => Promise<unknown[]>;
}

export function api_factory(fetch_implementation: typeof fetch): Client {
export function api_factory(
fetch_implementation: typeof fetch,
WebSocket_factory: (url: URL) => WebSocket
): Client {
return { post_data, upload_files, client, handle_blob };

async function post_data(
Expand Down Expand Up @@ -517,7 +520,7 @@ export function api_factory(fetch_implementation: typeof fetch): Client {
url.searchParams.set("__sign", jwt);
}

websocket = new WebSocket(url);
websocket = WebSocket_factory(url);

websocket.onclose = (evt) => {
if (!evt.wasClean) {
Expand Down Expand Up @@ -803,8 +806,10 @@ export function api_factory(fetch_implementation: typeof fetch): Client {
}
}

export const { post_data, upload_files, client, handle_blob } =
api_factory(fetch);
export const { post_data, upload_files, client, handle_blob } = api_factory(
fetch,
(...args) => new WebSocket(...args)
);

function transform_output(
data: any[],
Expand Down
20 changes: 18 additions & 2 deletions gradio/blocks.py
Expand Up @@ -2074,8 +2074,13 @@ def reverse(text):
# Workaround by triggering the app endpoint
requests.get(f"{self.local_url}startup-events", verify=ssl_verify)
else:
pass
# TODO: Call the startup endpoint in the Wasm env too.
# NOTE: One benefit of the code above dispatching `startup_events()` via a self HTTP request is
# that `self._queue.start()` is called in another thread which is managed by the HTTP server, `uvicorn`,
# so all the asyncio tasks created by the queue run in an event loop in that thread and
# will be cancelled just by stopping the server.
# In contrast, in the Wasm env, we can't do that because `threading` is not supported and all async tasks will run in the same event loop, `pyodide.webloop.WebLoop`, in the main thread.
# So we need to manually cancel them. See `self.close()`.
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@aliabid94 Is this comment correct about the purpose of calling the HTTP endpoint startup-events? (Ref: #1969)

self.startup_events()

utils.launch_counter()
self.is_sagemaker = utils.sagemaker_check()
Expand Down Expand Up @@ -2327,6 +2332,17 @@ def close(self, verbose: bool = True) -> None:
Closes the Interface that was launched and frees the port.
"""
try:
if wasm_utils.IS_WASM:
# NOTE:
# Normally, queue-related async tasks (e.g. continuous events created by `gr.Blocks.load(..., every=interval)`, whose async tasks are started at the `/queue/join` endpoint function)
# are running in an event loop in the server thread,
# so they will be cancelled by `self.server.close()` below.
# However, in the Wasm env, we don't have the `server` and
# all async tasks are running in the same event loop, `pyodide.webloop.WebLoop` in the main thread,
# so we have to cancel them explicitly so that these tasks won't run after a new app is launched.
if self.enable_queue:
self._queue._cancel_asyncio_tasks()
self.server_app._cancel_asyncio_tasks()
if self.enable_queue:
self._queue.close()
if self.server:
Expand Down
24 changes: 20 additions & 4 deletions gradio/queueing.py
Expand Up @@ -75,6 +75,7 @@ def __init__(
self.max_size = max_size
self.blocks_dependencies = blocks_dependencies
self.continuous_tasks: list[Event] = []
self._asyncio_tasks: list[asyncio.Task] = []

def start(self):
run_coro_in_background(self.start_processing)
Expand All @@ -88,6 +89,11 @@ def close(self):
def resume(self):
self.stopped = False

def _cancel_asyncio_tasks(self):
    """Cancel every task recorded in `self._asyncio_tasks` and stop tracking them.

    `cancel()` is called on each recorded task in list order, after which
    `self._asyncio_tasks` is rebound to a new empty list.
    """
    for pending in self._asyncio_tasks:
        pending.cancel()
    # Rebind (rather than mutate) so the attribute always ends up as a fresh list.
    self._asyncio_tasks = []

def set_server_app(self, app: routes.App):
self.server_app = app

Expand Down Expand Up @@ -133,9 +139,12 @@ async def start_processing(self) -> None:

if events:
self.active_jobs[self.active_jobs.index(None)] = events
task = run_coro_in_background(self.process_events, events, batch)
run_coro_in_background(self.broadcast_live_estimations)
set_task_name(task, events[0].session_hash, events[0].fn_index, batch)
process_event_task = run_coro_in_background(self.process_events, events, batch)
set_task_name(process_event_task, events[0].session_hash, events[0].fn_index, batch)
broadcast_live_estimations_task = run_coro_in_background(self.broadcast_live_estimations)

self._asyncio_tasks.append(process_event_task)
self._asyncio_tasks.append(broadcast_live_estimations_task)

async def start_log_and_progress_updates(self) -> None:
while not self.stopped:
Expand Down Expand Up @@ -518,7 +527,14 @@ async def process_events(self, events: list[Event], batch: bool) -> None:
await event.disconnect()
except Exception:
pass
self.active_jobs[self.active_jobs.index(events)] = None
try:
self.active_jobs[self.active_jobs.index(events)] = None
except ValueError:
# `events` can be absent from `self.active_jobs`
# when this coroutine is called from the `join_queue` endpoint handler in `routes.py`
# without putting the `events` into `self.active_jobs`.
# https://github.com/gradio-app/gradio/blob/f09aea34d6bd18c1e2fef80c86ab2476a6d1dd83/gradio/routes.py#L594-L596
pass
for event in events:
await self.clean_event(event)
# Always reset the state of the iterator
Expand Down
7 changes: 7 additions & 0 deletions gradio/routes.py
Expand Up @@ -128,6 +128,7 @@ def __init__(self, **kwargs):
set()
) # these are the full paths to the replicas if running on a Hugging Face Space with multiple replicas
self.change_event: None | threading.Event = None
self._asyncio_tasks: list[asyncio.Task] = []
# Allow user to manually set `docs_url` and `redoc_url`
# when instantiating an App; when they're not set, disable docs and redoc.
kwargs.setdefault("docs_url", None)
Expand Down Expand Up @@ -174,6 +175,11 @@ def build_proxy_request(self, url_path):
rp_req = client.build_request("GET", url, headers=headers)
return rp_req

def _cancel_asyncio_tasks(self):
    """Cancel every task recorded in `self._asyncio_tasks` and stop tracking them.

    `cancel()` is called on each recorded task in list order, after which
    `self._asyncio_tasks` is rebound to a new empty list.
    """
    for pending in self._asyncio_tasks:
        pending.cancel()
    # Rebind (rather than mutate) so the attribute always ends up as a fresh list.
    self._asyncio_tasks = []

@staticmethod
def create_app(
blocks: gradio.Blocks, app_kwargs: Dict[str, Any] | None = None
Expand Down Expand Up @@ -595,6 +601,7 @@ async def join_queue(
blocks._queue.process_events, [event], False
)
set_task_name(task, event.session_hash, event.fn_index, batch=False)
app._asyncio_tasks.append(task)
else:
rank = blocks._queue.push(event)

Expand Down
2 changes: 2 additions & 0 deletions gradio/wasm_utils.py
Expand Up @@ -16,6 +16,8 @@ class WasmUnsupportedError(Exception):
# the Gradio's FastAPI app instance (`app`).
def register_app(_app):
global app
if app:
app.blocks.close()
whitphx marked this conversation as resolved.
Show resolved Hide resolved
app = _app


Expand Down
4 changes: 2 additions & 2 deletions js/app/src/lite/css.ts
@@ -1,5 +1,5 @@
import type { WorkerProxy } from "@gradio/wasm";
import { is_self_origin } from "./url";
import { is_self_host } from "./url";
import { mount_css as default_mount_css } from "../css";

// In the Wasm mode, we use a prebuilt CSS file `/static/css/theme.css` to apply the styles in the initialization phase,
Expand All @@ -20,7 +20,7 @@ export async function wasm_proxied_mount_css(
const request = new Request(url_string); // Resolve a relative URL.
const url = new URL(request.url);

if (!is_self_origin(url)) {
if (!is_self_host(url)) {
// Fallback to the default implementation for external resources.
return default_mount_css(url_string, target);
}
Expand Down
29 changes: 10 additions & 19 deletions js/app/src/lite/dev/App.svelte
Expand Up @@ -13,32 +13,23 @@
let editorFiles: EditorFile[] = [
{
name: "app.py",
content: `import gradio as gr
content: `import time
import gradio as gr

from greeting import hi

def greet(name):
return "Hello " + name + "!"
def greet_with_time():
return "Time is " + time.ctime()

def upload_file(files):
file_paths = [file.name for file in files]
return file_paths

with gr.Blocks(theme=gr.themes.Soft()) as demo:
name = gr.Textbox(label="Name")
output = gr.Textbox(label="Output Box")
greet_btn = gr.Button("Greet")
greet_btn.click(fn=greet, inputs=name, outputs=output, api_name="greet")
say_hi_btn = gr.Button("Say hi")
say_hi_btn.click(fn=hi, inputs=name, outputs=output, api_name="hi")
with gr.Blocks() as demo:
text = gr.Markdown(f"Time is {time.time()}")

gr.File()
dep = demo.load(greet_with_time, None, text, every=1)

file_output = gr.File()
upload_button = gr.UploadButton("Click to Upload a File", file_types=["image", "video"], file_count="multiple")
upload_button.upload(upload_file, upload_button, file_output)

demo.launch()`
if __name__ == "__main__":
demo.queue().launch()
`
},
{
name: "greeting.py",
Expand Down
4 changes: 2 additions & 2 deletions js/app/src/lite/fetch.ts
@@ -1,5 +1,5 @@
import type { WorkerProxy } from "@gradio/wasm";
import { is_self_origin } from "./url";
import { is_self_host } from "./url";

/**
* A fetch() function that proxies HTTP requests to the worker,
Expand All @@ -24,7 +24,7 @@ export async function wasm_proxied_fetch(

const url = new URL(request.url);

if (!is_self_origin(url)) {
if (!is_self_host(url)) {
console.debug("Fallback to original fetch");
return fetch(input, init);
}
Expand Down
9 changes: 8 additions & 1 deletion js/app/src/lite/index.ts
Expand Up @@ -2,6 +2,7 @@ import "@gradio/theme";
import { WorkerProxy, type WorkerProxyOptions } from "@gradio/wasm";
import { api_factory } from "@gradio/client";
import { wasm_proxied_fetch } from "./fetch";
import { wasm_proxied_WebSocket_factory } from "./websocket";
import { wasm_proxied_mount_css, mount_prebuilt_css } from "./css";
import type { mount_css } from "../css";
import Index from "../Index.svelte";
Expand Down Expand Up @@ -91,7 +92,13 @@ export function create(options: Options): GradioAppController {
const overridden_fetch: typeof fetch = (input, init?) => {
return wasm_proxied_fetch(worker_proxy, input, init);
};
const { client, upload_files } = api_factory(overridden_fetch);
const WebSocket_factory = (url: URL): WebSocket => {
return wasm_proxied_WebSocket_factory(worker_proxy, url);
};
const { client, upload_files } = api_factory(
overridden_fetch,
WebSocket_factory
);
const overridden_mount_css: typeof mount_css = async (url, target) => {
return wasm_proxied_mount_css(worker_proxy, url, target);
};
Expand Down
5 changes: 2 additions & 3 deletions js/app/src/lite/url.ts
@@ -1,6 +1,5 @@
export function is_self_origin(url: URL): boolean {
export function is_self_host(url: URL): boolean {
return (
url.origin === window.location.origin ||
url.origin === "http://localhost:7860" // Ref: https://github.com/gradio-app/gradio/blob/v3.32.0/js/app/src/Index.svelte#L194
url.host === window.location.host || url.host === "localhost:7860" // Ref: https://github.com/gradio-app/gradio/blob/v3.32.0/js/app/src/Index.svelte#L194
);
}
19 changes: 19 additions & 0 deletions js/app/src/lite/websocket.ts
@@ -0,0 +1,19 @@
import type { WorkerProxy } from "@gradio/wasm";
import { is_self_host } from "./url";

/**
 * Creates a WebSocket for `url`, routing it through the worker when possible.
 *
 * URLs accepted by `is_self_host()` are handed to `worker_proxy.openWebSocket()`;
 * only `url.pathname` is forwarded, so the rest of the URL is not passed on.
 * Any other URL falls back to the original `WebSocket` constructor.
 *
 * @param worker_proxy - The proxy whose `openWebSocket()` serves self-host URLs.
 * @param url - The URL the caller wants to connect to.
 * @returns The object from `worker_proxy.openWebSocket()` (cast to `WebSocket`),
 * or a regular `WebSocket` for external URLs.
 */
export function wasm_proxied_WebSocket_factory(
	worker_proxy: WorkerProxy,
	url: URL
): WebSocket {
	if (is_self_host(url)) {
		// The proxy's return type is not declared as `WebSocket`, hence the cast.
		return worker_proxy.openWebSocket(url.pathname) as unknown as WebSocket;
	}

	console.debug("Fallback to original WebSocket");
	return new WebSocket(url);
}
7 changes: 7 additions & 0 deletions js/wasm/src/message-types.ts
Expand Up @@ -51,6 +51,12 @@ export interface InMessageHttpRequest extends InMessageBase {
request: HttpRequest;
};
}
/**
 * Message variant tagged with `type: "websocket"`, carrying a single `path`.
 * NOTE(review): presumably asks the worker to open a WebSocket connection
 * to the app it hosts — confirm against the worker's message handler.
 */
export interface InMessageWebSocket extends InMessageBase {
// Literal tag that identifies this message variant.
type: "websocket";
data: {
// NOTE(review): presumably the URL pathname of the WebSocket endpoint — confirm against the sender.
path: string;
};
}
export interface InMessageFileWrite extends InMessageBase {
type: "file:write";
data: {
Expand Down Expand Up @@ -89,6 +95,7 @@ export type InMessage =
| InMessageInit
| InMessageRunPythonCode
| InMessageRunPythonFile
| InMessageWebSocket
| InMessageHttpRequest
| InMessageFileWrite
| InMessageFileRename
Expand Down