Skip to content

Commit

Permalink
Delete user state when they close the tab. Add an unload event for th…
Browse files Browse the repository at this point in the history
…e demo and a delete_callback on gr.State to let developers control how resources are cleaned up (#7829)

* Delete state

* add changeset

* Delete state

* WIP

* Add load event

* Working ttl

* unload e2e test

* Clean up

* add changeset

* Fix notebook

* add changeset

* Connect to heartbeat in python client

* 15 second heartbeat

* Demo for unload

* Add notebook

* add changeset

* Fix docs

* revert demo changes

* Add docstrings

* lint 🙄

* Edit

* handle shutdown issue

* state comments

* client test

* Fix:

* Fix e2e test

* 3.11 incompatibility

* delete after one hour

* lint + highlight

* Update .changeset/better-tires-shave.md

Co-authored-by: Abubakar Abid <abubakar@huggingface.co>

* Update .changeset/better-tires-shave.md

Co-authored-by: Abubakar Abid <abubakar@huggingface.co>

---------

Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
Co-authored-by: Abubakar Abid <abubakar@huggingface.co>
  • Loading branch information
3 people committed Apr 1, 2024
1 parent 83010a2 commit 6a4bf7a
Show file tree
Hide file tree
Showing 21 changed files with 482 additions and 45 deletions.
42 changes: 42 additions & 0 deletions .changeset/better-tires-shave.md
@@ -0,0 +1,42 @@
---
"@gradio/client": minor
"gradio": minor
"gradio_client": minor
---

highlight:

#### Automatically delete state after user has disconnected from the webpage

Gradio now automatically deletes `gr.State` variables stored in the server's RAM when users close their browser tab.
The deletion will happen 60 minutes after the server detected a disconnect from the user's browser.
If the user connects again in that timeframe, their state will not be deleted.

Additionally, Gradio now includes a `Blocks.unload()` event, allowing you to run arbitrary cleanup functions when users disconnect (this does not have a 60 minute delay).
You can think of the `unload` event as the opposite of the `load` event.


```python
with gr.Blocks() as demo:
gr.Markdown(
"""# State Cleanup Demo
🖼️ Images are saved in a user-specific directory and deleted when the users closes the page via demo.unload.
""")
with gr.Row():
with gr.Column(scale=1):
with gr.Row():
img = gr.Image(label="Generated Image", height=300, width=300)
with gr.Row():
gen = gr.Button(value="Generate")
with gr.Row():
history = gr.Gallery(label="Previous Generations", height=500, columns=10)
state = gr.State(value=[], delete_callback=lambda v: print("STATE DELETED"))

demo.load(generate_random_img, [state], [img, state, history])
gen.click(generate_random_img, [state], [img, state, history])
demo.unload(delete_directory)


demo.launch(auth=lambda user,pwd: True,
auth_message="Enter any username and password to continue")
```
3 changes: 2 additions & 1 deletion .config/playwright-setup.js
Expand Up @@ -64,7 +64,8 @@ function spawn_gradio_app(app, port, verbose) {
...process.env,
GRADIO_SERVER_PORT: `7879`,
PYTHONUNBUFFERED: "true",
GRADIO_ANALYTICS_ENABLED: "False"
GRADIO_ANALYTICS_ENABLED: "False",
GRADIO_IS_E2E_TEST: "1"
}
});
_process.stdout.setEncoding("utf8");
Expand Down
1 change: 1 addition & 0 deletions .gitignore
Expand Up @@ -45,6 +45,7 @@ demo/*/config.json
demo/annotatedimage_component/*.png
demo/fake_diffusion_with_gif/*.gif
demo/cancel_events/cancel_events_output_log.txt
demo/unload_event_test/output_log.txt

# Etc
.idea/*
Expand Down
4 changes: 4 additions & 0 deletions client/js/src/client.ts
Expand Up @@ -357,6 +357,10 @@ export function api_factory(
);

const _config = await config_success(config);
// connect to the heartbeat endpoint via GET request
const heartbeat = new EventSource(
`${config.root}/heartbeat/${session_hash}`
);
res(_config);
} catch (e) {
console.error(e);
Expand Down
34 changes: 33 additions & 1 deletion client/python/gradio_client/client.py
Expand Up @@ -159,6 +159,7 @@ def __init__(
self.sse_url = urllib.parse.urljoin(
self.src, utils.SSE_URL_V0 if self.protocol == "sse" else utils.SSE_URL
)
self.heartbeat_url = urllib.parse.urljoin(self.src, utils.HEARTBEAT_URL)
self.sse_data_url = urllib.parse.urljoin(
self.src,
utils.SSE_DATA_URL_V0 if self.protocol == "sse" else utils.SSE_DATA_URL,
Expand All @@ -184,13 +185,43 @@ def __init__(
self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=max_workers)

# Disable telemetry by setting the env variable HF_HUB_DISABLE_TELEMETRY=1
threading.Thread(target=self._telemetry_thread).start()
threading.Thread(target=self._telemetry_thread, daemon=True).start()
self._refresh_heartbeat = threading.Event()
self._kill_heartbeat = threading.Event()

self.heartbeat = threading.Thread(target=self._stream_heartbeat, daemon=True)
self.heartbeat.start()

self.stream_open = False
self.streaming_future: Future | None = None
self.pending_messages_per_event: dict[str, list[Message | None]] = {}
self.pending_event_ids: set[str] = set()

def close(self):
self._kill_heartbeat.set()
self.heartbeat.join(timeout=1)

def _stream_heartbeat(self):
while True:
url = self.heartbeat_url.format(session_hash=self.session_hash)
try:
with httpx.stream(
"GET",
url,
headers=self.headers,
cookies=self.cookies,
verify=self.ssl_verify,
timeout=20,
) as response:
for _ in response.iter_lines():
if self._refresh_heartbeat.is_set():
self._refresh_heartbeat.clear()
break
if self._kill_heartbeat.is_set():
return
except httpx.TransportError:
return

async def stream_messages(
self, protocol: Literal["sse_v1", "sse_v2", "sse_v2.1", "sse_v3"]
) -> None:
Expand Down Expand Up @@ -640,6 +671,7 @@ def view_api(

def reset_session(self) -> None:
self.session_hash = str(uuid.uuid4())
self._refresh_heartbeat.set()

def _render_endpoints_info(
self,
Expand Down
1 change: 1 addition & 0 deletions client/python/gradio_client/utils.py
Expand Up @@ -42,6 +42,7 @@
SPACE_FETCHER_URL = "https://gradio-space-api-fetcher-v2.hf.space/api"
RESET_URL = "reset"
SPACE_URL = "https://hf.space/{}"
HEARTBEAT_URL = "heartbeat/{session_hash}"

STATE_COMPONENT = "state"
INVALID_RUNTIME = [
Expand Down
5 changes: 3 additions & 2 deletions client/python/test/conftest.py
Expand Up @@ -75,10 +75,11 @@ def calculator(num1, operation=None, num2=100):

@pytest.fixture
def state_demo():
state = gr.State(delete_callback=lambda x: print("STATE DELETED"))
demo = gr.Interface(
lambda x, y: (x, y),
["textbox", "state"],
["textbox", "state"],
["textbox", state],
["textbox", state],
)
return demo

Expand Down
15 changes: 10 additions & 5 deletions client/python/test/test_client.py
Expand Up @@ -46,11 +46,7 @@ def connect(
# because we should set a timeout
# the tests that call .cancel() can get stuck
# waiting for the thread to join
if demo.enable_queue:
demo._queue.close()
demo.is_running = False
demo.server.should_exit = True
demo.server.thread.join(timeout=1)
demo.close()


class TestClientInitialization:
Expand Down Expand Up @@ -608,6 +604,15 @@ def return_bad():
pred = client.predict(api_name="/predict")
assert pred[0] == data[0]

def test_state_reset_when_session_changes(self, capsys, state_demo, monkeypatch):
monkeypatch.setenv("GRADIO_IS_E2E_TEST", "1")
with connect(state_demo) as client:
client.predict("Hello", api_name="/predict")
client.reset_session()
time.sleep(5)
out = capsys.readouterr().out
assert "STATE DELETED" in out


class TestClientPredictionsWithKwargs:
def test_no_default_params(self, calculator_demo):
Expand Down
1 change: 1 addition & 0 deletions demo/state_cleanup/run.ipynb
@@ -0,0 +1 @@
{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: state_cleanup"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["from __future__ import annotations\n", "import gradio as gr\n", "import numpy as np\n", "from PIL import Image\n", "from pathlib import Path\n", "import secrets\n", "import shutil\n", "\n", "current_dir = Path(__file__).parent\n", "\n", "\n", "def generate_random_img(history: list[Image.Image], request: gr.Request):\n", " \"\"\"Generate a random red, green, blue, orange, yellor or purple image.\"\"\"\n", " colors = [(255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 165, 0), (255, 255, 0), (128, 0, 128)]\n", " color = colors[np.random.randint(0, len(colors))]\n", " img = Image.new('RGB', (100, 100), color)\n", " \n", " user_dir: Path = current_dir / request.username # type: ignore\n", " user_dir.mkdir(exist_ok=True)\n", " path = user_dir / f\"{secrets.token_urlsafe(8)}.webp\"\n", "\n", " img.save(path)\n", " history.append(img)\n", "\n", " return img, history, history\n", "\n", "def delete_directory(req: gr.Request):\n", " if not req.username:\n", " return\n", " user_dir: Path = current_dir / req.username\n", " shutil.rmtree(str(user_dir))\n", "\n", "with gr.Blocks() as demo:\n", " gr.Markdown(\"\"\"# State Cleanup Demo\n", " \ud83d\uddbc\ufe0f Images are saved in a user-specific directory and deleted when the users closes the page via demo.unload.\n", " \"\"\")\n", " with gr.Row():\n", " with gr.Column(scale=1):\n", " with gr.Row():\n", " img = gr.Image(label=\"Generated Image\", height=300, width=300)\n", " with gr.Row():\n", " gen = gr.Button(value=\"Generate\")\n", " with gr.Row():\n", " history = gr.Gallery(label=\"Previous Generations\", height=500, columns=10)\n", " state = gr.State(value=[], delete_callback=lambda v: print(\"STATE DELETED\"))\n", "\n", " demo.load(generate_random_img, [state], [img, state, history]) \n", " gen.click(generate_random_img, [state], [img, state, history])\n", " demo.unload(delete_directory)\n", "\n", "\n", "demo.launch(auth=lambda user,pwd: True,\n", " auth_message=\"Enter any username and password to continue\")"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}
53 changes: 53 additions & 0 deletions demo/state_cleanup/run.py
@@ -0,0 +1,53 @@
from __future__ import annotations
import gradio as gr
import numpy as np
from PIL import Image
from pathlib import Path
import secrets
import shutil

current_dir = Path(__file__).parent


def generate_random_img(history: list[Image.Image], request: gr.Request):
"""Generate a random red, green, blue, orange, yellor or purple image."""
colors = [(255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 165, 0), (255, 255, 0), (128, 0, 128)]
color = colors[np.random.randint(0, len(colors))]
img = Image.new('RGB', (100, 100), color)

user_dir: Path = current_dir / request.username # type: ignore
user_dir.mkdir(exist_ok=True)
path = user_dir / f"{secrets.token_urlsafe(8)}.webp"

img.save(path)
history.append(img)

return img, history, history

def delete_directory(req: gr.Request):
if not req.username:
return
user_dir: Path = current_dir / req.username
shutil.rmtree(str(user_dir))

with gr.Blocks() as demo:
gr.Markdown("""# State Cleanup Demo
🖼️ Images are saved in a user-specific directory and deleted when the users closes the page via demo.unload.
""")
with gr.Row():
with gr.Column(scale=1):
with gr.Row():
img = gr.Image(label="Generated Image", height=300, width=300)
with gr.Row():
gen = gr.Button(value="Generate")
with gr.Row():
history = gr.Gallery(label="Previous Generations", height=500, columns=10)
state = gr.State(value=[], delete_callback=lambda v: print("STATE DELETED"))

demo.load(generate_random_img, [state], [img, state, history])
gen.click(generate_random_img, [state], [img, state, history])
demo.unload(delete_directory)


demo.launch(auth=lambda user,pwd: True,
auth_message="Enter any username and password to continue")
1 change: 1 addition & 0 deletions demo/unload_event_test/run.ipynb
@@ -0,0 +1 @@
{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: unload_event_test"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["\"\"\"This demo is only meant to test the unload event.\n", "It will write to a file when the unload event is triggered.\n", "May not work as expected if multiple people are using it.\n", "\"\"\"\n", "import gradio as gr\n", "from pathlib import Path\n", "\n", "log_file = (Path(__file__).parent / \"output_log.txt\").resolve()\n", "\n", "\n", "def test_fn(x):\n", " with open(log_file, \"a\") as f:\n", " f.write(f\"incremented {x}\\n\")\n", " return x + 1, x + 1\n", "\n", "def delete_fn(v):\n", " with log_file.open(\"a\") as f:\n", " f.write(f\"deleted {v}\\n\")\n", "\n", "def unload_fn():\n", " with log_file.open(\"a\") as f:\n", " f.write(f\"unloading\\n\")\n", "\n", "with gr.Blocks() as demo:\n", " n1 = gr.Number(value=0, label=\"Number\")\n", " state = gr.State(value=0, delete_callback=delete_fn)\n", " button = gr.Button(\"Increment\")\n", " button.click(test_fn, [state], [n1, state], api_name=\"increment\")\n", " demo.unload(unload_fn)\n", " demo.load(lambda: log_file.write_text(\"\"))\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}
33 changes: 33 additions & 0 deletions demo/unload_event_test/run.py
@@ -0,0 +1,33 @@
"""This demo is only meant to test the unload event.
It will write to a file when the unload event is triggered.
May not work as expected if multiple people are using it.
"""
import gradio as gr
from pathlib import Path

log_file = (Path(__file__).parent / "output_log.txt").resolve()


def test_fn(x):
with open(log_file, "a") as f:
f.write(f"incremented {x}\n")
return x + 1, x + 1

def delete_fn(v):
with log_file.open("a") as f:
f.write(f"deleted {v}\n")

def unload_fn():
with log_file.open("a") as f:
f.write(f"unloading\n")

with gr.Blocks() as demo:
n1 = gr.Number(value=0, label="Number")
state = gr.State(value=0, delete_callback=delete_fn)
button = gr.Button("Increment")
button.click(test_fn, [state], [n1, state], api_name="increment")
demo.unload(unload_fn)
demo.load(lambda: log_file.write_text(""))

if __name__ == "__main__":
demo.launch()
49 changes: 46 additions & 3 deletions gradio/blocks.py
Expand Up @@ -128,6 +128,10 @@ def __init__(
if render:
self.render()

@property
def stateful(self):
return False

@property
def skip_api(self):
return False
Expand Down Expand Up @@ -506,7 +510,7 @@ def convert_component_dict_to_list(
return predictions


@document("launch", "queue", "integrate", "load")
@document("launch", "queue", "integrate", "load", "unload")
class Blocks(BlockContext, BlocksEvents, metaclass=BlocksMeta):
"""
Blocks is Gradio's low-level API that allows you to create more custom web
Expand Down Expand Up @@ -878,6 +882,43 @@ def expects_oauth(self):
for block in self.blocks.values()
)

def unload(self, fn: Callable):
"""This listener is triggered when the user closes or refreshes the tab, ending the user session.
It is useful for cleaning up resources when the app is closed.
Parameters:
fn: Callable function to run to clear resources. The function should not take any arguments and the output is not used.
Example:
import gradio as gr
with gr.Blocks() as demo:
gr.Markdown("# When you close the tab, hello will be printed to the console")
demo.unload(lambda: print("hello"))
demo.launch()
"""
self.set_event_trigger(
targets=[EventListenerMethod(None, "unload")],
fn=fn,
inputs=None,
outputs=None,
preprocess=False,
postprocess=False,
show_progress="hidden",
api_name=None,
js=None,
no_target=True,
queue=None,
batch=False,
max_batch_size=4,
cancels=None,
every=None,
collects_event_data=None,
trigger_after=None,
trigger_only_on_success=False,
trigger_mode="once",
concurrency_limit="default",
concurrency_id=None,
show_api=False,
)

def set_event_trigger(
self,
targets: Sequence[EventListenerMethod],
Expand Down Expand Up @@ -1498,7 +1539,7 @@ def postprocess_data(
f"Output component with id {output_id} used in {dependency['trigger']}() event not found in this gr.Blocks context. You are allowed to nest gr.Blocks contexts, but there must be a gr.Blocks context that contains all components and events."
) from e

if getattr(block, "stateful", False):
if block.stateful:
if not utils.is_update(predictions[i]):
state[output_id] = predictions[i]
output.append(None)
Expand Down Expand Up @@ -2392,9 +2433,10 @@ def close(self, verbose: bool = True) -> None:
self._queue._cancel_asyncio_tasks()
self.server_app._cancel_asyncio_tasks()
self._queue.close()
# set this before closing server to shut down heartbeats
self.is_running = False
if self.server:
self.server.close()
self.is_running = False
# So that the startup events (starting the queue)
# happen the next time the app is launched
self.app.startup_events_triggered = False
Expand Down Expand Up @@ -2448,6 +2490,7 @@ def startup_events(self):
self._queue.start()
# So that processing can resume in case the queue was stopped
self._queue.stopped = False
self.is_running = True
self.create_limiter()

def queue_enabled_for_fn(self, fn_index: int):
Expand Down

0 comments on commit 6a4bf7a

Please sign in to comment.