diff --git a/.changeset/better-tires-shave.md b/.changeset/better-tires-shave.md new file mode 100644 index 000000000000..666510bdcc64 --- /dev/null +++ b/.changeset/better-tires-shave.md @@ -0,0 +1,42 @@ +--- +"@gradio/client": minor +"gradio": minor +"gradio_client": minor +--- + +highlight: + +#### Automatically delete state after user has disconnected from the webpage + +Gradio now automatically deletes `gr.State` variables stored in the server's RAM when users close their browser tab. +The deletion will happen 60 minutes after the server detected a disconnect from the user's browser. +If the user connects again in that timeframe, their state will not be deleted. + +Additionally, Gradio now includes a `Blocks.unload()` event, allowing you to run arbitrary cleanup functions when users disconnect (this does not have a 60 minute delay). +You can think of the `unload` event as the opposite of the `load` event. + + +```python +with gr.Blocks() as demo: + gr.Markdown( +"""# State Cleanup Demo +🖼️ Images are saved in a user-specific directory and deleted when the users closes the page via demo.unload. +""") + with gr.Row(): + with gr.Column(scale=1): + with gr.Row(): + img = gr.Image(label="Generated Image", height=300, width=300) + with gr.Row(): + gen = gr.Button(value="Generate") + with gr.Row(): + history = gr.Gallery(label="Previous Generations", height=500, columns=10) + state = gr.State(value=[], delete_callback=lambda v: print("STATE DELETED")) + + demo.load(generate_random_img, [state], [img, state, history]) + gen.click(generate_random_img, [state], [img, state, history]) + demo.unload(delete_directory) + + +demo.launch(auth=lambda user,pwd: True, + auth_message="Enter any username and password to continue") +``` \ No newline at end of file diff --git a/.config/playwright-setup.js b/.config/playwright-setup.js index f8ebf18378f9..1de1955ee66c 100644 --- a/.config/playwright-setup.js +++ b/.config/playwright-setup.js @@ -64,7 +64,8 @@ function spawn_gradio_app(app, port, verbose) { ...process.env, GRADIO_SERVER_PORT: `7879`, PYTHONUNBUFFERED: "true", - GRADIO_ANALYTICS_ENABLED: "False" + GRADIO_ANALYTICS_ENABLED: "False", + GRADIO_IS_E2E_TEST: "1" } }); _process.stdout.setEncoding("utf8"); diff --git a/.gitignore b/.gitignore index 5dbd7349b153..df9c8dbee714 100644 --- a/.gitignore +++ b/.gitignore @@ -45,6 +45,7 @@ demo/*/config.json demo/annotatedimage_component/*.png demo/fake_diffusion_with_gif/*.gif demo/cancel_events/cancel_events_output_log.txt +demo/unload_event_test/output_log.txt # Etc .idea/* diff --git a/client/js/src/client.ts b/client/js/src/client.ts index 7d0fae28fdcb..734004bdf400 100644 --- a/client/js/src/client.ts +++ b/client/js/src/client.ts @@ -357,6 +357,10 @@ export function api_factory( ); const _config = await config_success(config); + // connect to the heartbeat endpoint via GET request + const heartbeat = new EventSource( + `${config.root}/heartbeat/${session_hash}` + ); res(_config); } catch (e) { console.error(e); diff --git a/client/python/gradio_client/client.py b/client/python/gradio_client/client.py index 96385f6e7940..0d37f80eb743 100644 --- a/client/python/gradio_client/client.py +++ b/client/python/gradio_client/client.py @@ -159,6 +159,7 @@ def __init__( self.sse_url = urllib.parse.urljoin( self.src, utils.SSE_URL_V0 if self.protocol == "sse" else utils.SSE_URL ) + self.heartbeat_url = urllib.parse.urljoin(self.src, utils.HEARTBEAT_URL) self.sse_data_url = urllib.parse.urljoin( self.src, utils.SSE_DATA_URL_V0 if self.protocol == "sse" else utils.SSE_DATA_URL, @@ -184,13 +185,43 @@ def __init__( self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) # Disable telemetry by setting the env variable HF_HUB_DISABLE_TELEMETRY=1 - threading.Thread(target=self._telemetry_thread).start() + threading.Thread(target=self._telemetry_thread, daemon=True).start() + self._refresh_heartbeat = threading.Event() + self._kill_heartbeat = threading.Event() + + self.heartbeat = threading.Thread(target=self._stream_heartbeat, daemon=True) + self.heartbeat.start() self.stream_open = False self.streaming_future: Future | None = None self.pending_messages_per_event: dict[str, list[Message | None]] = {} self.pending_event_ids: set[str] = set() + def close(self): + self._kill_heartbeat.set() + self.heartbeat.join(timeout=1) + + def _stream_heartbeat(self): + while True: + url = self.heartbeat_url.format(session_hash=self.session_hash) + try: + with httpx.stream( + "GET", + url, + headers=self.headers, + cookies=self.cookies, + verify=self.ssl_verify, + timeout=20, + ) as response: + for _ in response.iter_lines(): + if self._refresh_heartbeat.is_set(): + self._refresh_heartbeat.clear() + break + if self._kill_heartbeat.is_set(): + return + except httpx.TransportError: + return + async def stream_messages( self, protocol: Literal["sse_v1", "sse_v2", "sse_v2.1", "sse_v3"] ) -> None: @@ -640,6 +671,7 @@ def view_api( def reset_session(self) -> None: self.session_hash = str(uuid.uuid4()) + self._refresh_heartbeat.set() def _render_endpoints_info( self, diff --git a/client/python/gradio_client/utils.py b/client/python/gradio_client/utils.py index 21421654b637..9edd86af65e8 100644 --- a/client/python/gradio_client/utils.py +++ b/client/python/gradio_client/utils.py @@ -42,6 +42,7 @@ SPACE_FETCHER_URL = "https://gradio-space-api-fetcher-v2.hf.space/api" RESET_URL = "reset" SPACE_URL = "https://hf.space/{}" +HEARTBEAT_URL = "heartbeat/{session_hash}" STATE_COMPONENT = "state" INVALID_RUNTIME = [ diff --git a/client/python/test/conftest.py b/client/python/test/conftest.py index d510ce1bf460..986dc480626d 100644 --- a/client/python/test/conftest.py +++ b/client/python/test/conftest.py @@ -75,10 +75,11 @@ def calculator(num1, operation=None, num2=100): @pytest.fixture def state_demo(): + state = gr.State(delete_callback=lambda x: print("STATE DELETED")) demo = gr.Interface( lambda x, y: (x, y), - ["textbox", "state"], - ["textbox", "state"], + ["textbox", state], + ["textbox", state], ) return demo diff --git a/client/python/test/test_client.py b/client/python/test/test_client.py index c2c57dd2e460..c83bdb1445f1 100644 --- a/client/python/test/test_client.py +++ b/client/python/test/test_client.py @@ -46,11 +46,7 @@ def connect( # because we should set a timeout # the tests that call .cancel() can get stuck # waiting for the thread to join - if demo.enable_queue: - demo._queue.close() - demo.is_running = False - demo.server.should_exit = True - demo.server.thread.join(timeout=1) + demo.close() class TestClientInitialization: @@ -608,6 +604,15 @@ def return_bad(): pred = client.predict(api_name="/predict") assert pred[0] == data[0] + def test_state_reset_when_session_changes(self, capsys, state_demo, monkeypatch): + monkeypatch.setenv("GRADIO_IS_E2E_TEST", "1") + with connect(state_demo) as client: + client.predict("Hello", api_name="/predict") + client.reset_session() + time.sleep(5) + out = capsys.readouterr().out + assert "STATE DELETED" in out + class TestClientPredictionsWithKwargs: def test_no_default_params(self, calculator_demo): diff --git a/demo/state_cleanup/run.ipynb b/demo/state_cleanup/run.ipynb new file mode 100644 index 000000000000..d6a2a6ee503d --- /dev/null +++ b/demo/state_cleanup/run.ipynb @@ -0,0 +1 @@ +{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: state_cleanup"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["from __future__ import annotations\n", "import gradio as gr\n", "import numpy as np\n", "from PIL import Image\n", "from pathlib import Path\n", "import secrets\n", "import shutil\n", "\n", "current_dir = Path(__file__).parent\n", "\n", "\n", "def generate_random_img(history: list[Image.Image], request: gr.Request):\n", " \"\"\"Generate a random red, green, blue, orange, yellor or purple image.\"\"\"\n", " colors = [(255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 165, 0), (255, 255, 0), (128, 0, 128)]\n", " color = colors[np.random.randint(0, len(colors))]\n", " img = Image.new('RGB', (100, 100), color)\n", " \n", " user_dir: Path = current_dir / request.username # type: ignore\n", " user_dir.mkdir(exist_ok=True)\n", " path = user_dir / f\"{secrets.token_urlsafe(8)}.webp\"\n", "\n", " img.save(path)\n", " history.append(img)\n", "\n", " return img, history, history\n", "\n", "def delete_directory(req: gr.Request):\n", " if not req.username:\n", " return\n", " user_dir: Path = current_dir / req.username\n", " shutil.rmtree(str(user_dir))\n", "\n", "with gr.Blocks() as demo:\n", " gr.Markdown(\"\"\"# State Cleanup Demo\n", " \ud83d\uddbc\ufe0f Images are saved in a user-specific directory and deleted when the users closes the page via demo.unload.\n", " \"\"\")\n", " with gr.Row():\n", " with gr.Column(scale=1):\n", " with gr.Row():\n", " img = gr.Image(label=\"Generated Image\", height=300, width=300)\n", " with gr.Row():\n", " gen = gr.Button(value=\"Generate\")\n", " with gr.Row():\n", " history = gr.Gallery(label=\"Previous Generations\", height=500, columns=10)\n", " state = gr.State(value=[], delete_callback=lambda v: print(\"STATE DELETED\"))\n", "\n", " demo.load(generate_random_img, [state], [img, state, history]) \n", " gen.click(generate_random_img, [state], [img, state, history])\n", " demo.unload(delete_directory)\n", "\n", "\n", "demo.launch(auth=lambda user,pwd: True,\n", " auth_message=\"Enter any username and password to continue\")"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5} \ No newline at end of file diff --git a/demo/state_cleanup/run.py b/demo/state_cleanup/run.py new file mode 100644 index 000000000000..1daf59afd7c7 --- /dev/null +++ b/demo/state_cleanup/run.py @@ -0,0 +1,53 @@ +from __future__ import annotations +import gradio as gr +import numpy as np +from PIL import Image +from pathlib import Path +import secrets +import shutil + +current_dir = Path(__file__).parent + + +def generate_random_img(history: list[Image.Image], request: gr.Request): + """Generate a random red, green, blue, orange, yellor or purple image.""" + colors = [(255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 165, 0), (255, 255, 0), (128, 0, 128)] + color = colors[np.random.randint(0, len(colors))] + img = Image.new('RGB', (100, 100), color) + + user_dir: Path = current_dir / request.username # type: ignore + user_dir.mkdir(exist_ok=True) + path = user_dir / f"{secrets.token_urlsafe(8)}.webp" + + img.save(path) + history.append(img) + + return img, history, history + +def delete_directory(req: gr.Request): + if not req.username: + return + user_dir: Path = current_dir / req.username + shutil.rmtree(str(user_dir)) + +with gr.Blocks() as demo: + gr.Markdown("""# State Cleanup Demo + 🖼️ Images are saved in a user-specific directory and deleted when the users closes the page via demo.unload. + """) + with gr.Row(): + with gr.Column(scale=1): + with gr.Row(): + img = gr.Image(label="Generated Image", height=300, width=300) + with gr.Row(): + gen = gr.Button(value="Generate") + with gr.Row(): + history = gr.Gallery(label="Previous Generations", height=500, columns=10) + state = gr.State(value=[], delete_callback=lambda v: print("STATE DELETED")) + + demo.load(generate_random_img, [state], [img, state, history]) + gen.click(generate_random_img, [state], [img, state, history]) + demo.unload(delete_directory) + + +demo.launch(auth=lambda user,pwd: True, + auth_message="Enter any username and password to continue") \ No newline at end of file diff --git a/demo/unload_event_test/run.ipynb b/demo/unload_event_test/run.ipynb new file mode 100644 index 000000000000..e333fd9dd8d9 --- /dev/null +++ b/demo/unload_event_test/run.ipynb @@ -0,0 +1 @@ +{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: unload_event_test"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["\"\"\"This demo is only meant to test the unload event.\n", "It will write to a file when the unload event is triggered.\n", "May not work as expected if multiple people are using it.\n", "\"\"\"\n", "import gradio as gr\n", "from pathlib import Path\n", "\n", "log_file = (Path(__file__).parent / \"output_log.txt\").resolve()\n", "\n", "\n", "def test_fn(x):\n", " with open(log_file, \"a\") as f:\n", " f.write(f\"incremented {x}\\n\")\n", " return x + 1, x + 1\n", "\n", "def delete_fn(v):\n", " with log_file.open(\"a\") as f:\n", " f.write(f\"deleted {v}\\n\")\n", "\n", "def unload_fn():\n", " with log_file.open(\"a\") as f:\n", " f.write(f\"unloading\\n\")\n", "\n", "with gr.Blocks() as demo:\n", " n1 = gr.Number(value=0, label=\"Number\")\n", " state = gr.State(value=0, delete_callback=delete_fn)\n", " button = gr.Button(\"Increment\")\n", " button.click(test_fn, [state], [n1, state], api_name=\"increment\")\n", " demo.unload(unload_fn)\n", " demo.load(lambda: log_file.write_text(\"\"))\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5} \ No newline at end of file diff --git a/demo/unload_event_test/run.py b/demo/unload_event_test/run.py new file mode 100644 index 000000000000..281760a88777 --- /dev/null +++ b/demo/unload_event_test/run.py @@ -0,0 +1,33 @@ +"""This demo is only meant to test the unload event. +It will write to a file when the unload event is triggered. +May not work as expected if multiple people are using it. +""" +import gradio as gr +from pathlib import Path + +log_file = (Path(__file__).parent / "output_log.txt").resolve() + + +def test_fn(x): + with open(log_file, "a") as f: + f.write(f"incremented {x}\n") + return x + 1, x + 1 + +def delete_fn(v): + with log_file.open("a") as f: + f.write(f"deleted {v}\n") + +def unload_fn(): + with log_file.open("a") as f: + f.write(f"unloading\n") + +with gr.Blocks() as demo: + n1 = gr.Number(value=0, label="Number") + state = gr.State(value=0, delete_callback=delete_fn) + button = gr.Button("Increment") + button.click(test_fn, [state], [n1, state], api_name="increment") + demo.unload(unload_fn) + demo.load(lambda: log_file.write_text("")) + +if __name__ == "__main__": + demo.launch() \ No newline at end of file diff --git a/gradio/blocks.py b/gradio/blocks.py index 9d06470ff8fa..25b1398c1636 100644 --- a/gradio/blocks.py +++ b/gradio/blocks.py @@ -128,6 +128,10 @@ def __init__( if render: self.render() + @property + def stateful(self): + return False + @property def skip_api(self): return False @@ -506,7 +510,7 @@ def convert_component_dict_to_list( return predictions -@document("launch", "queue", "integrate", "load") +@document("launch", "queue", "integrate", "load", "unload") class Blocks(BlockContext, BlocksEvents, metaclass=BlocksMeta): """ Blocks is Gradio's low-level API that allows you to create more custom web @@ -878,6 +882,43 @@ def expects_oauth(self): for block in self.blocks.values() ) + def unload(self, fn: Callable): + """This listener is triggered when the user closes or refreshes the tab, ending the user session. + It is useful for cleaning up resources when the app is closed. + Parameters: + fn: Callable function to run to clear resources. The function should not take any arguments and the output is not used. + Example: + import gradio as gr + with gr.Blocks() as demo: + gr.Markdown("# When you close the tab, hello will be printed to the console") + demo.unload(lambda: print("hello")) + demo.launch() + """ + self.set_event_trigger( + targets=[EventListenerMethod(None, "unload")], + fn=fn, + inputs=None, + outputs=None, + preprocess=False, + postprocess=False, + show_progress="hidden", + api_name=None, + js=None, + no_target=True, + queue=None, + batch=False, + max_batch_size=4, + cancels=None, + every=None, + collects_event_data=None, + trigger_after=None, + trigger_only_on_success=False, + trigger_mode="once", + concurrency_limit="default", + concurrency_id=None, + show_api=False, + ) + def set_event_trigger( self, targets: Sequence[EventListenerMethod], @@ -1498,7 +1539,7 @@ def postprocess_data( f"Output component with id {output_id} used in {dependency['trigger']}() event not found in this gr.Blocks context. You are allowed to nest gr.Blocks contexts, but there must be a gr.Blocks context that contains all components and events." ) from e - if getattr(block, "stateful", False): + if block.stateful: if not utils.is_update(predictions[i]): state[output_id] = predictions[i] output.append(None) @@ -2392,9 +2433,10 @@ def close(self, verbose: bool = True) -> None: self._queue._cancel_asyncio_tasks() self.server_app._cancel_asyncio_tasks() self._queue.close() + # set this before closing server to shut down heartbeats + self.is_running = False if self.server: self.server.close() - self.is_running = False # So that the startup events (starting the queue) # happen the next time the app is launched self.app.startup_events_triggered = False @@ -2448,6 +2490,7 @@ def startup_events(self): self._queue.start() # So that processing can resume in case the queue was stopped self._queue.stopped = False + self.is_running = True self.create_limiter() def queue_enabled_for_fn(self, fn_index: int): diff --git a/gradio/components/state.py b/gradio/components/state.py index 80be5fdfd2e2..bb8526e7a695 100644 --- a/gradio/components/state.py +++ b/gradio/components/state.py @@ -2,8 +2,9 @@ from __future__ import annotations +import math from copy import deepcopy -from typing import Any +from typing import Any, Callable from gradio_client.documentation import document @@ -16,8 +17,7 @@ class State(Component): """ Special hidden component that stores session state across runs of the demo by the same user. The value of the State variable is cleared when the user refreshes the page. - - Demos: interface_state, blocks_simple_squares + Demos: interface_state, blocks_simple_squares, state_cleanup Guides: real-time-speech-recognition """ @@ -27,13 +27,21 @@ def __init__( self, value: Any = None, render: bool = True, + *, + time_to_live: int | float | None = None, + delete_callback: Callable[[Any], None] | None = None, ): """ Parameters: value: the initial value (of arbitrary type) of the state. The provided argument is deepcopied. If a callable is provided, the function will be called whenever the app loads to set the initial value of the state. render: has no effect, but is included for consistency with other components. + time_to_live: The number of seconds the state should be stored for after it is created or updated. If None, the state will be stored indefinitely. Gradio automatically deletes state variables after a user closes the browser tab or refreshes the page, so this is useful for clearing state for potentially long running sessions. + delete_callback: A function that is called when the state is deleted. The function should take the state value as an argument. """ - self.stateful = True + self.time_to_live = self.time_to_live = ( + math.inf if time_to_live is None else time_to_live + ) + self.delete_callback = delete_callback or (lambda a: None) # noqa: ARG005 try: self.value = deepcopy(value) except TypeError as err: @@ -42,6 +50,10 @@ def __init__( ) from err super().__init__(value=self.value, render=render) + @property + def stateful(self): + return True + def preprocess(self, payload: Any) -> Any: """ Parameters: diff --git a/gradio/http_server.py b/gradio/http_server.py index f244f617c858..7afae6bf8896 100644 --- a/gradio/http_server.py +++ b/gradio/http_server.py @@ -65,7 +65,7 @@ def close(self): if self.reloader: self.reloader.stop() self.watch_thread.join() - self.thread.join() + self.thread.join(timeout=5) def start_server( diff --git a/gradio/route_utils.py b/gradio/route_utils.py index 11add279f25b..6e0ee417df4c 100644 --- a/gradio/route_utils.py +++ b/gradio/route_utils.py @@ -9,7 +9,7 @@ import re import shutil from collections import deque -from contextlib import asynccontextmanager +from contextlib import AsyncExitStack, asynccontextmanager from dataclasses import dataclass as python_dataclass from datetime import datetime from pathlib import Path @@ -734,7 +734,6 @@ def allow_explicit_origin(headers: MutableHeaders, origin: str) -> None: def delete_files_created_by_app(blocks: Blocks, age: int | None) -> None: """Delete files that are older than age. If age is None, delete all files.""" - dont_delete = set() for component in blocks.blocks.values(): dont_delete.update(getattr(component, "keep_in_cache", set())) @@ -770,27 +769,40 @@ async def _lifespan_handler( app: App, frequency: int = 1, age: int = 1 ) -> AsyncGenerator: """A context manager that triggers the startup and shutdown events of the app.""" - app.get_blocks().startup_events() - app.startup_events_triggered = True asyncio.create_task(delete_files_on_schedule(app, frequency, age)) yield delete_files_created_by_app(app.get_blocks(), age=None) +async def _delete_state(app: App): + """Delete all expired state every second.""" + while True: + app.state_holder.delete_all_expired_state() + await asyncio.sleep(1) + + +@asynccontextmanager +async def _delete_state_handler(app: App): + """When the server launches, regularly delete expired state.""" + asyncio.create_task(_delete_state(app)) + yield + + def create_lifespan_handler( user_lifespan: Callable[[App], AsyncContextManager] | None, - frequency: int = 1, - age: int = 1, + frequency: int | None = 1, + age: int | None = 1, ) -> Callable[[App], AsyncContextManager]: """Return a context manager that applies _lifespan_handler and user_lifespan if it exists.""" @asynccontextmanager async def _handler(app: App): - async with _lifespan_handler(app, frequency, age): + async with AsyncExitStack() as stack: + await stack.enter_async_context(_delete_state_handler(app)) + if frequency and age: + await stack.enter_async_context(_lifespan_handler(app, frequency, age)) if user_lifespan is not None: - async with user_lifespan(app): - yield - else: - yield + await stack.enter_async_context(user_lifespan(app)) + yield return _handler diff --git a/gradio/routes.py b/gradio/routes.py index 16027dc33673..783f2189721a 100644 --- a/gradio/routes.py +++ b/gradio/routes.py @@ -222,10 +222,10 @@ def create_app( ) -> App: app_kwargs = app_kwargs or {} app_kwargs.setdefault("default_response_class", ORJSONResponse) - if blocks.delete_cache is not None: - app_kwargs["lifespan"] = create_lifespan_handler( - app_kwargs.get("lifespan", None), *blocks.delete_cache - ) + delete_cache = blocks.delete_cache or (None, None) + app_kwargs["lifespan"] = create_lifespan_handler( + app_kwargs.get("lifespan", None), *delete_cache + ) app = App(auth_dependency=auth_dependency, **app_kwargs) app.configure_app(blocks) @@ -589,6 +589,75 @@ async def reset_iterator(body: ResetBody): await app.get_blocks()._queue.clean_events(event_id=body.event_id) return {"success": True} + @app.get("/heartbeat/{session_hash}") + def heartbeat( + session_hash: str, + request: fastapi.Request, + background_tasks: BackgroundTasks, + username: str = Depends(get_current_user), + ): + """Clients make a persistent connection to this endpoint to keep the session alive. + When the client disconnects, the session state is deleted. + """ + heartbeat_rate = 0.25 if os.getenv("GRADIO_IS_E2E_TEST", None) else 15 + + async def wait(): + await asyncio.sleep(heartbeat_rate) + return "wait" + + async def stop_stream(): + while app.get_blocks().is_running: + await asyncio.sleep(0.25) + return "stop" + + async def iterator(): + while True: + try: + yield "data: ALIVE\n\n" + # We need to close the heartbeat connections as soon as the server stops + # otherwise the server can take forever to close + wait_task = asyncio.create_task(wait()) + stop_stream_task = asyncio.create_task(stop_stream()) + done, _ = await asyncio.wait( + [wait_task, stop_stream_task], + return_when=asyncio.FIRST_COMPLETED, + ) + done = [d.result() for d in done] + if "stop" in done: + raise asyncio.CancelledError() + except asyncio.CancelledError: + req = Request(request, username) + root_path = route_utils.get_root_url( + request=request, + route_path=f"/hearbeat/{session_hash}", + root_path=app.root_path, + ) + body = PredictBody( + session_hash=session_hash, data=[], request=request + ) + unload_fn_indices = [ + i + for i, dep in enumerate(app.get_blocks().dependencies) + if any(t for t in dep["targets"] if t[1] == "unload") + ] + for fn_index in unload_fn_indices: + # The task runnning this loop has been cancelled + # so we add tasks in the background + background_tasks.add_task( + route_utils.call_process_api, + app=app, + body=body, + gr_request=req, + fn_index_inferred=fn_index, + root_path=root_path, + ) + # This will mark the state to be deleted in an hour + if session_hash in app.state_holder.session_data: + app.state_holder.session_data[session_hash].is_closed = True + return + + return StreamingResponse(iterator(), media_type="text/event-stream") + # had to use '/run' endpoint for Colab compatibility, '/api' supported for backwards compatibility @app.post("/run/{api_name}", dependencies=[Depends(login_check)]) @app.post("/run/{api_name}/", dependencies=[Depends(login_check)]) @@ -1098,8 +1167,9 @@ async def new_lifespan(app: FastAPI): async with old_lifespan( app ): # Instert the startup events inside the FastAPI context manager - gradio_app.get_blocks().startup_events() - yield + async with gradio_app.router.lifespan_context(gradio_app): + gradio_app.get_blocks().startup_events() + yield app.router.lifespan_context = new_lifespan diff --git a/gradio/state_holder.py b/gradio/state_holder.py index 27ff77459472..316c2547c981 100644 --- a/gradio/state_holder.py +++ b/gradio/state_holder.py @@ -1,18 +1,22 @@ from __future__ import annotations +import datetime +import os import threading from collections import OrderedDict from copy import deepcopy -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Iterator if TYPE_CHECKING: from gradio.blocks import Blocks + from gradio.components import State class StateHolder: def __init__(self): self.capacity = 10000 - self.session_data = OrderedDict() + self.session_data: OrderedDict[str, SessionState] = OrderedDict() + self.time_last_used: dict[str, datetime.datetime] = {} self.lock = threading.Lock() def set_blocks(self, blocks: Blocks): @@ -29,6 +33,7 @@ def __getitem__(self, session_id: str) -> SessionState: if session_id not in self.session_data: self.session_data[session_id] = SessionState(self.blocks) self.update(session_id) + self.time_last_used[session_id] = datetime.datetime.now() return self.session_data[session_id] def __contains__(self, session_id: str): @@ -41,11 +46,36 @@ def update(self, session_id: str): if len(self.session_data) > self.capacity: self.session_data.popitem(last=False) + def delete_all_expired_state( + self, + ): + for session_id in self.session_data: + self.delete_state(session_id, expired_only=True) + + def delete_state(self, session_id: str, expired_only: bool = False): + if session_id not in self.session_data: + return + to_delete = [] + session_state = self.session_data[session_id] + for component, value, expired in session_state.state_components: + if not expired_only or expired: + component.delete_callback(value) + to_delete.append(component._id) + for component in to_delete: + del session_state._data[component] + class SessionState: def __init__(self, blocks: Blocks): self.blocks = blocks self._data = {} + self._state_ttl = {} + self.is_closed = False + # When a session is closed, the state is stored for an hour to give the user time to reopen the session. + # During testing we set to a lower value to be able to test + self.STATE_TTL_WHEN_CLOSED = ( + 1 if os.getenv("GRADIO_IS_E2E_TEST", None) else 3600 + ) def __getitem__(self, key: int) -> Any: if key not in self._data: @@ -57,7 +87,32 @@ def __getitem__(self, key: int) -> Any: return self._data[key] def __setitem__(self, key: int, value: Any): + from gradio.components import State + + block = self.blocks.blocks[key] + if isinstance(block, State): + self._state_ttl[key] = ( + block.time_to_live, + datetime.datetime.now(), + ) self._data[key] = value def __contains__(self, key: int): return key in self._data + + @property + def state_components(self) -> Iterator[tuple[State, Any, bool]]: + from gradio.components import State + + for id in self._data: + block = self.blocks.blocks[id] + if isinstance(block, State) and id in self._state_ttl: + time_to_live, created_at = self._state_ttl[id] + if self.is_closed: + time_to_live = self.STATE_TTL_WHEN_CLOSED + value = self._data[id] + yield ( + block, + value, + (datetime.datetime.now() - created_at).seconds > time_to_live, + ) diff --git a/js/app/test/unload_event_test.spec.ts b/js/app/test/unload_event_test.spec.ts new file mode 100644 index 000000000000..a93124330f40 --- /dev/null +++ b/js/app/test/unload_event_test.spec.ts @@ -0,0 +1,34 @@ +import { test, expect } from "@gradio/tootils"; +import { readFileSync } from "fs"; + +test("when a user closes the page, the unload event should be triggered", async ({ + page +}) => { + const increment = await page.locator("button", { + hasText: /Increment/ + }); + + // if you click too fast, the page may close before the event is processed + await increment.click(); + await page.waitForTimeout(100); + await increment.click(); + await page.waitForTimeout(100); + await increment.click(); + await page.waitForTimeout(100); + await increment.click(); + await expect(page.getByLabel("Number")).toHaveValue("4"); + await page.close(); + + await new Promise((resolve) => setTimeout(resolve, 5000)); + + const data = readFileSync( + "../../demo/unload_event_test/output_log.txt", + "utf-8" + ); + expect(data).toContain("incremented 0"); + expect(data).toContain("incremented 1"); + expect(data).toContain("incremented 2"); + expect(data).toContain("incremented 3"); + expect(data).toContain("deleted 4"); + expect(data).toContain("unloading"); +}); diff --git a/test/conftest.py b/test/conftest.py index 768ebe3424e3..7832df08c765 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -49,16 +49,11 @@ def connect(): def _connect(demo: gr.Blocks, serialize=True, **kwargs): _, local_url, _ = demo.launch(prevent_thread_lock=True, **kwargs) try: - yield Client(local_url, serialize=serialize) + client = Client(local_url, serialize=serialize) + yield client finally: - # A more verbose version of .close() - # because we should set a timeout - # the tests that call .cancel() can get stuck - # waiting for the thread to join - demo._queue.close() - demo.is_running = False - demo.server.should_exit = True - demo.server.thread.join(timeout=1) + client.close() + demo.close() return _connect diff --git a/test/test_blocks.py b/test/test_blocks.py index 2d6c368cb3ee..742063572d6e 100644 --- a/test/test_blocks.py +++ b/test/test_blocks.py @@ -1726,3 +1726,44 @@ def test_static_files_multiple_apps(gradio_temp_dir): # Input/Output got saved to cache assert len(list(gradio_temp_dir.glob("**/*.*"))) == 0 + + +def test_time_to_live_and_delete_callback_for_state(capsys, monkeypatch): + monkeypatch.setenv("GRADIO_IS_E2E_TEST", 1) + + def test_fn(x): + return x + 1, x + 1 + + def delete_fn(v): + print(f"deleted {v}") + + with gr.Blocks() as demo: + n1 = gr.Number(value=0) + state = gr.State( + value=0, time_to_live=1, delete_callback=lambda v: delete_fn(v) + ) + button = gr.Button("Increment") + button.click(test_fn, [state], [n1, state], api_name="increment") + + app, url, _ = demo.launch(prevent_thread_lock=True) + + try: + client_1 = grc.Client(url) + client_2 = grc.Client(url) + + client_1.predict(api_name="/increment") + client_1.predict(api_name="/increment") + client_1.predict(api_name="/increment") + + client_2.predict(api_name="/increment") + client_2.predict(api_name="/increment") + + time.sleep(3) + + captured = capsys.readouterr() + assert "deleted 2" in captured.out + assert "deleted 3" in captured.out + for client in [client_1, client_2]: + assert len(app.state_holder.session_data[client.session_hash]._data) == 0 + finally: + demo.close()