Skip to content

Commit

Permalink
Fixes gradio-app#4016 by using websockets with callback so user can m…
Browse files Browse the repository at this point in the history
…anage freeing up of state objects
  • Loading branch information
pseudotensor committed Feb 14, 2024
1 parent 3e886d8 commit 70376b2
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 1 deletion.
24 changes: 24 additions & 0 deletions client/js/src/client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,30 @@ export function api_factory(
await process_endpoint(app_reference, hf_token);

const session_hash = Math.random().toString(36).substring(2);

// WebSocket setup using the session_hash to check for alive state of client
const wsUrl = `${ws_protocol}://${host}/ws?session_hash=${session_hash}`;
const ws = new WebSocket(wsUrl);

ws.onopen = () => {
console.log("WebSocket connected with session_hash:", session_hash);
// Optionally send a message or perform an action upon connection
};

ws.onmessage = (event) => {
console.log("Received message:", event.data);
// Handle incoming WebSocket messages
};

ws.onerror = (error) => {
console.error("WebSocket error:", error);
};

ws.onclose = () => {
console.log("WebSocket connection closed");
// Optionally perform cleanup or reconnection logic
};

const last_status: Record<string, Status["stage"]> = {};
let stream_open = false;
let pending_stream_messages: Record<string, any[]> = {}; // Event messages may be received by the SSE stream before the initial data POST request is complete. To resolve this race condition, we store the messages in a dictionary and process them when the POST request is complete.
Expand Down
4 changes: 3 additions & 1 deletion gradio/components/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from __future__ import annotations

from copy import deepcopy
from typing import Any
from typing import Any, Callable

from gradio_client.documentation import document

Expand All @@ -27,6 +27,7 @@ def __init__(
self,
value: Any = None,
render: bool = True,
callback: Callable = None,
):
"""
Parameters:
Expand All @@ -40,6 +41,7 @@ def __init__(
raise TypeError(
f"The initial value of `gr.State` must be able to be deepcopied. The initial value of type {type(value)} cannot be deepcopied."
) from err
self.callback = callback
super().__init__(value=self.value, render=render)

def preprocess(self, payload: Any) -> Any:
Expand Down
21 changes: 21 additions & 0 deletions gradio/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,8 @@ def toorjson(value):

file_upload_statuses = FileUploadProgress()

from fastapi import FastAPI, WebSocket, Query
from starlette.websockets import WebSocketDisconnect

class App(FastAPI):
"""
Expand Down Expand Up @@ -207,6 +209,25 @@ def create_app(
allow_headers=["*"],
)

@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket, session_hash: str = Query(default=None)):
await websocket.accept()
print("Client Connected: %s" % session_hash, flush=True)
try:
while True:
# You can also process incoming messages here
data = await websocket.receive_text()
print(f"Message from client: {data}")
except WebSocketDisconnect:
from gradio.components import State
states = app.state_holder.session_data[session_hash]
for k, state_data in states._data.items():
state = states.blocks.blocks[k]
if isinstance(state, State) and state.callback is not None:
state.callback(state_data)
app.state_holder.session_data.pop(session_hash)
print("Client disconnected: %s" % session_hash, flush=True)

@app.get("/user")
@app.get("/user/")
def get_current_user(request: fastapi.Request) -> Optional[str]:
Expand Down

0 comments on commit 70376b2

Please sign in to comment.