Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Swap websockets for SSE #6069

Merged
merged 56 commits into from Oct 30, 2023
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
Show all changes
56 commits
Select commit Hold shift + click to select a range
e7cfeb6
changes
aliabid94 Oct 19, 2023
6f54f20
merge
aliabid94 Oct 19, 2023
a828a6c
changes
aliabid94 Oct 19, 2023
36aaaf5
changes
aliabid94 Oct 20, 2023
a21fbf1
changes
aliabid94 Oct 23, 2023
7b88e8d
changes
aliabid94 Oct 23, 2023
fbe580f
merge
aliabid94 Oct 23, 2023
a4f7759
changes
aliabid94 Oct 23, 2023
5a42da7
changes
aliabid94 Oct 23, 2023
608842e
changes
aliabid94 Oct 23, 2023
6b446fd
changes
aliabid94 Oct 24, 2023
cc35772
changes
aliabid94 Oct 24, 2023
ef6a299
Merge branch 'v4' into sse
aliabid94 Oct 24, 2023
7315f48
changes
aliabid94 Oct 25, 2023
e9b416e
changes
aliabid94 Oct 25, 2023
56c6a31
changes
aliabid94 Oct 25, 2023
3ad6fe8
Merge remote-tracking branch 'origin/v4' into sse
aliabid94 Oct 25, 2023
1b0005f
Merge branch 'v4' into sse
aliabid94 Oct 25, 2023
efb6986
changes
aliabid94 Oct 25, 2023
9d4b19a
changes
aliabid94 Oct 26, 2023
7052464
merge
aliabid94 Oct 26, 2023
d4b6023
changes
aliabid94 Oct 26, 2023
09dd08e
changes
aliabid94 Oct 26, 2023
efe2990
changes
aliabid94 Oct 26, 2023
3760d5b
changes
aliabid94 Oct 26, 2023
3168903
changes
aliabid94 Oct 26, 2023
a43f49f
changes
aliabid94 Oct 26, 2023
1a670da
add changeset
gradio-pr-bot Oct 26, 2023
d94b3db
changes
aliabid94 Oct 27, 2023
266aec2
changes
aliabid94 Oct 27, 2023
79987a3
changes
aliabid94 Oct 27, 2023
103d78a
Merge branch 'v4' into sse
aliabid94 Oct 27, 2023
f9cfd29
changes
aliabid94 Oct 29, 2023
d32e68e
add changeset
gradio-pr-bot Oct 29, 2023
d712f5a
changes
aliabid94 Oct 29, 2023
abeb701
Merge branch 'sse' of https://github.com/gradio-app/gradio into sse
aliabid94 Oct 29, 2023
5bfbc27
changes
aliabid94 Oct 29, 2023
78a700a
changes
aliabid94 Oct 30, 2023
9be2c16
changes
aliabid94 Oct 30, 2023
dc95b3d
add changeset
gradio-pr-bot Oct 30, 2023
e8f4a2d
changes
aliabid94 Oct 30, 2023
d25616e
Merge branch 'sse' of https://github.com/gradio-app/gradio into sse
aliabid94 Oct 30, 2023
d1c4bc6
changes
aliabid94 Oct 30, 2023
6246d5e
changes
aliabid94 Oct 30, 2023
ea2aa97
add changeset
gradio-pr-bot Oct 30, 2023
d35a890
changes
aliabid94 Oct 30, 2023
165466a
changes
aliabid94 Oct 30, 2023
ec2cd06
Merge branch 'v4' into sse
aliabid94 Oct 30, 2023
ca5fac4
changes
aliabid94 Oct 30, 2023
60572e0
changes
aliabid94 Oct 30, 2023
22b7d26
Merge branch 'v4' into sse
aliabid94 Oct 30, 2023
732e1b7
changes
aliabid94 Oct 30, 2023
ef7bcd7
Merge branch 'v4' into sse
aliabid94 Oct 30, 2023
373ba21
add changeset
gradio-pr-bot Oct 30, 2023
3bc906e
Fix client tests sse branch (#6150)
freddyaboulton Oct 30, 2023
44eda09
Merge branch 'v4' into sse
aliabid94 Oct 30, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
7 changes: 7 additions & 0 deletions .changeset/brave-cats-show.md
@@ -0,0 +1,7 @@
---
"@gradio/app": minor
"@gradio/theme": minor
"gradio": minor
---

feat:Lite: Support the custom HTML element syntax `<gradio-lite>`
5 changes: 5 additions & 0 deletions .changeset/crazy-dancers-allow.md
@@ -0,0 +1,5 @@
---
"@gradio/wasm": minor
---

feat:Make the HTTP requests for the Wasm worker wait for the initial `run_code()` or `run_file()` to finish
7 changes: 7 additions & 0 deletions .changeset/green-tips-read.md
@@ -0,0 +1,7 @@
---
"@gradio/app": minor
"@gradio/wasm": minor
"gradio": minor
---

feat:Apply formatter (and small refactoring) to the Lite-related frontend code
7 changes: 7 additions & 0 deletions .changeset/twenty-dryers-wave.md
@@ -0,0 +1,7 @@
---
"@gradio/atoms": patch
"@gradio/json": patch
"gradio": patch
---

fix:Show empty JSON icon when `value` is `null`
6 changes: 6 additions & 0 deletions .changeset/wise-beans-itch.md
@@ -0,0 +1,6 @@
---
"@gradio/dataframe": minor
"gradio": minor
---

feat:Adds `column_widths` to `gr.Dataframe` and hide overflowing text when `wrap=False`
113 changes: 112 additions & 1 deletion client/js/src/client.ts
Expand Up @@ -427,6 +427,9 @@ export function api_factory(
}

let websocket: WebSocket;
let eventSource: EventSource;
const MAJOR_VERSION = parseInt(config.version.split(".")[0]);
let protocol = MAJOR_VERSION >= 4 ? "sse" : "ws";

const _endpoint = typeof endpoint === "number" ? "/predict" : endpoint;
let payload: Payload;
Expand Down Expand Up @@ -514,7 +517,7 @@ export function api_factory(
time: new Date()
});
});
} else {
} else if (protocol == "ws") {
fire_event({
type: "status",
stage: "pending",
Expand Down Expand Up @@ -634,6 +637,114 @@ export function api_factory(
websocket.send(JSON.stringify({ hash: session_hash }))
);
}
} else {
fire_event({
type: "status",
stage: "pending",
queue: true,
endpoint: _endpoint,
fn_index,
time: new Date()
});
var params = new URLSearchParams({
fn_index: fn_index.toString(),
session_hash: session_hash
}).toString();
let url = new URL(
`${http_protocol}//${resolve_root(
host,
config.path,
true
)}/queue/join?${params}`
);

eventSource = new EventSource(url);

eventSource.onmessage = function (event) {
const _data = JSON.parse(event.data);
const { type, status, data } = handle_message(
_data,
last_status[fn_index]
);

if (type === "update" && status && !complete) {
// call 'status' listeners
fire_event({
type: "status",
endpoint: _endpoint,
fn_index,
time: new Date(),
...status
});
if (status.stage === "error") {
eventSource.close();
}
} else if (type === "data") {
let event_id = _data.event_id;
post_data(
`${http_protocol}//${resolve_root(
host,
config.path,
true
)}/queue/data`,
{
...payload,
session_hash,
event_id
},
hf_token
);
} else if (type === "complete") {
complete = status;
} else if (type === "log") {
fire_event({
type: "log",
log: data.log,
level: data.level,
endpoint: _endpoint,
fn_index
});
} else if (type === "generating") {
fire_event({
type: "status",
time: new Date(),
...status,
stage: status?.stage!,
queue: true,
endpoint: _endpoint,
fn_index
});
}
if (data) {
fire_event({
type: "data",
time: new Date(),
data: transform_files
? transform_output(
data.data,
api_info,
config.root,
config.root_url
)
: data.data,
endpoint: _endpoint,
fn_index
});

if (complete) {
fire_event({
type: "status",
time: new Date(),
...complete,
stage: status?.stage!,
queue: true,
endpoint: _endpoint,
fn_index
});
eventSource.close();
}
}
};
}
});

Expand Down
103 changes: 74 additions & 29 deletions client/python/gradio_client/client.py
Expand Up @@ -19,6 +19,7 @@
from threading import Lock
from typing import Any, Callable, Literal

import httpx
import huggingface_hub
import requests
import websockets
Expand Down Expand Up @@ -73,6 +74,7 @@ def __init__(
serialize: bool = True,
output_dir: str | Path = DEFAULT_TEMP_DIR,
verbose: bool = True,
auth: tuple[str, str] | None = None,
aliabid94 marked this conversation as resolved.
Show resolved Hide resolved
):
"""
Parameters:
Expand All @@ -92,6 +94,7 @@ def __init__(
library_version=utils.__version__,
)
self.space_id = None
self.cookies: dict[str, str] = {}
self.output_dir = (
str(output_dir) if isinstance(output_dir, Path) else output_dir
)
Expand Down Expand Up @@ -122,11 +125,15 @@ def __init__(
print(f"Loaded as API: {self.src} ✔")

self.api_url = urllib.parse.urljoin(self.src, utils.API_URL)
self.sse_url = urllib.parse.urljoin(self.src, utils.SSE_URL)
self.sse_data_url = urllib.parse.urljoin(self.src, utils.SSE_DATA_URL)
self.ws_url = urllib.parse.urljoin(
self.src.replace("http", "ws", 1), utils.WS_URL
)
self.upload_url = urllib.parse.urljoin(self.src, utils.UPLOAD_URL)
self.reset_url = urllib.parse.urljoin(self.src, utils.RESET_URL)
if auth is not None:
self._login(auth)
self.config = self._get_config()
self._info = self._get_api_info()
self.session_hash = str(uuid.uuid4())
Expand Down Expand Up @@ -322,7 +329,7 @@ def submit(
inferred_fn_index = self._infer_fn_index(api_name, fn_index)

helper = None
if self.endpoints[inferred_fn_index].use_ws:
if self.endpoints[inferred_fn_index].protocol in ("ws", "sse"):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

minor: let's modify and use_ws to check if the protocol is "ws" or "sse" since we're using it in different places

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sorry don't understand what you're suggesting

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, instead of self.endpoints[inferred_fn_index].protocol in ("ws", "sse") in multiple places, we do self.endpoints[inferred_fn_index].use_http or something like that.

helper = Communicator(
Lock(),
JobStatus(),
Expand Down Expand Up @@ -363,11 +370,11 @@ def _get_api_info(self):
# Versions of Gradio older than 3.29.0 returned format of the API info
# from the /info endpoint
if version.parse(self.config.get("version", "2.0")) > version.Version("3.36.1"):
r = requests.get(api_info_url, headers=self.headers)
r = requests.get(api_info_url, headers=self.headers, cookies=self.cookies)
if r.ok:
info = r.json()
else:
raise ValueError(f"Could not fetch api info for {self.src}")
raise ValueError(f"Could not fetch api info for {self.src}: {r.text}")
else:
fetch = requests.post(
utils.SPACE_FETCHER_URL,
Expand All @@ -376,7 +383,9 @@ def _get_api_info(self):
if fetch.ok:
info = fetch.json()["api"]
else:
raise ValueError(f"Could not fetch api info for {self.src}")
raise ValueError(
f"Could not fetch api info for {self.src}: {fetch.text}"
)

return info

Expand Down Expand Up @@ -599,14 +608,33 @@ def __del__(self):
def _space_name_to_src(self, space) -> str | None:
return huggingface_hub.space_info(space, token=self.hf_token).host # type: ignore

def _login(self, auth: tuple[str, str]):
resp = requests.post(
urllib.parse.urljoin(self.src, utils.LOGIN_URL),
data={"username": auth[0], "password": auth[1]},
)
if not resp.ok:
raise ValueError(f"Could not login to {self.src}")
self.cookies = {
cookie.name: cookie.value
for cookie in resp.cookies
if cookie.value is not None
}

def _get_config(self) -> dict:
r = requests.get(
urllib.parse.urljoin(self.src, utils.CONFIG_URL), headers=self.headers
urllib.parse.urljoin(self.src, utils.CONFIG_URL),
headers=self.headers,
cookies=self.cookies,
)
if r.ok:
return r.json()
elif r.status_code == 401:
raise ValueError("Could not load {self.src}. Please login.")
aliabid94 marked this conversation as resolved.
Show resolved Hide resolved
else: # to support older versions of Gradio
r = requests.get(self.src, headers=self.headers)
r = requests.get(self.src, headers=self.headers, cookies=self.cookies)
if not r.ok:
raise ValueError(f"Could not fetch config for {self.src}")
# some basic regex to extract the config
result = re.search(r"window.gradio_config = (.*?);[\s]*</script>", r.text)
try:
Expand Down Expand Up @@ -781,7 +809,7 @@ def __init__(self, client: Client, fn_index: int, dependency: dict):
self.api_name: str | Literal[False] | None = (
"/" + api_name if isinstance(api_name, str) else api_name
)
self.use_ws = self._use_websocket(self.dependency)
self.protocol = self._get_protocol(self.dependency)
aliabid94 marked this conversation as resolved.
Show resolved Hide resolved
self.input_component_types = [
self._get_component_type(id_) for id_ in dependency["inputs"]
]
Expand Down Expand Up @@ -847,23 +875,26 @@ def _inner(*data):

def make_predict(self, helper: Communicator | None = None):
def _predict(*data) -> tuple:
data = json.dumps(
{
"data": data,
"fn_index": self.fn_index,
"session_hash": self.client.session_hash,
}
)
hash_data = json.dumps(
{
"fn_index": self.fn_index,
"session_hash": self.client.session_hash,
}
)
data = {
"data": data,
"fn_index": self.fn_index,
"session_hash": self.client.session_hash,
}

if self.use_ws:
hash_data = {
"fn_index": self.fn_index,
"session_hash": self.client.session_hash,
}

if self.protocol == "ws":
result = utils.synchronize_async(self._ws_fn, data, hash_data, helper)
if "error" in result:
raise ValueError(result["error"])
elif self.protocol == "sse":
result = utils.synchronize_async(self._sse_fn, data, hash_data, helper)
if "error" in result:
print(">>>", result)
aliabid94 marked this conversation as resolved.
Show resolved Hide resolved

raise ValueError(result["error"])
else:
response = requests.post(
Expand Down Expand Up @@ -1045,15 +1076,21 @@ def process_predictions(self, *predictions):
predictions = self.reduce_singleton_output(*predictions)
return predictions

def _use_websocket(self, dependency: dict) -> bool:
def _get_protocol(self, dependency: dict) -> Literal["sse", "ws", "http"]:
major_version = int(self.client.config["version"].split(".")[0])
if major_version >= 4:
return "sse"
queue_enabled = self.client.config.get("enable_queue", False)
queue_uses_websocket = version.parse(
self.client.config.get("version", "2.0")
) >= version.Version("3.2")
dependency_uses_queue = dependency.get("queue", False) is not False
return queue_enabled and queue_uses_websocket and dependency_uses_queue
if queue_enabled and queue_uses_websocket and dependency_uses_queue:
return "ws"
else:
return "http"

async def _ws_fn(self, data, hash_data, helper: Communicator):
async def _ws_fn(self, data: dict, hash_data: dict, helper: Communicator):
async with websockets.connect( # type: ignore
self.client.ws_url,
open_timeout=10,
Expand All @@ -1062,6 +1099,18 @@ async def _ws_fn(self, data, hash_data, helper: Communicator):
) as websocket:
return await utils.get_pred_from_ws(websocket, data, hash_data, helper)

async def _sse_fn(self, data: dict, hash_data: dict, helper: Communicator):
async with httpx.AsyncClient(timeout=httpx.Timeout(timeout=None)) as client:
return await utils.get_pred_from_sse(
client,
data,
hash_data,
helper,
self.client.sse_url,
self.client.sse_data_url,
self.client.cookies,
)


@document("result", "outputs", "status")
class Job(Future):
Expand Down Expand Up @@ -1102,13 +1151,9 @@ def __next__(self) -> tuple | Any:
if not self.communicator:
raise StopIteration()

with self.communicator.lock:
if self.communicator.job.latest_status.code == Status.FINISHED:
raise StopIteration()

while True:
with self.communicator.lock:
if len(self.communicator.job.outputs) == self._counter + 1:
if len(self.communicator.job.outputs) >= self._counter + 1:
o = self.communicator.job.outputs[self._counter]
self._counter += 1
return o
Expand Down