Skip to content

Commit

Permalink
Add delete_cache parameter to gr.Blocks to delete files created by ap…
Browse files Browse the repository at this point in the history
…p on shutdown (#7447)

* Add code

* add changeset

* Add code

* trigger ci

* Add schedule

* Fix implementation

* Fix test

* Address comments

* add changeset

* handle examples

* Update guides/01_getting-started/03_sharing-your-app.md

Co-authored-by: Abubakar Abid <abubakar@huggingface.co>

* Fix code

* Fix code

---------

Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
Co-authored-by: Abubakar Abid <abubakar@huggingface.co>
  • Loading branch information
3 people committed Mar 5, 2024
1 parent 3645da5 commit a57e34e
Show file tree
Hide file tree
Showing 12 changed files with 165 additions and 7 deletions.
5 changes: 5 additions & 0 deletions .changeset/public-hoops-drum.md
@@ -0,0 +1,5 @@
---
"gradio": minor
---

feat:Add delete_cache parameter to gr.Blocks to delete files created by app on shutdown
6 changes: 5 additions & 1 deletion gradio/blocks.py
Expand Up @@ -526,6 +526,7 @@ def __init__(
js: str | None = None,
head: str | None = None,
fill_height: bool = False,
delete_cache: tuple[int, int] | None = None,
**kwargs,
):
"""
Expand All @@ -538,6 +539,7 @@ def __init__(
js: Custom js or path to js file to run when demo is first loaded. This javascript will be included in the demo webpage.
head: Custom html to insert into the head of the demo webpage. This can be used to add custom meta tags, scripts, stylesheets, etc. to the page.
fill_height: Whether to vertically expand top-level child components to the height of the window. If True, expansion occurs when the scale value of the child components >= 1.
delete_cache: A tuple corresponding to [frequency, age], both expressed in number of seconds. Every `frequency` seconds, the temporary files created by this Blocks instance will be deleted if more than `age` seconds have passed since the file was created. For example, setting this to (86400, 86400) will delete temporary files every day. The cache will be deleted entirely when the server restarts. If None, no cache deletion will occur.
"""
self.limiter = None
if theme is None:
Expand Down Expand Up @@ -566,6 +568,7 @@ def __init__(
self.show_error = True
self.head = head
self.fill_height = fill_height
self.delete_cache = delete_cache
if css is not None and os.path.exists(css):
with open(css) as css_file:
self.css = css_file.read()
Expand Down Expand Up @@ -608,7 +611,8 @@ def __init__(
self.auth = None
self.dev_mode = bool(os.getenv("GRADIO_WATCH_DIRS", ""))
self.app_id = random.getrandbits(64)
self.temp_file_sets = []
self.upload_file_set = set()
self.temp_file_sets = [self.upload_file_set]
self.title = title
self.show_api = not wasm_utils.IS_WASM

Expand Down
3 changes: 3 additions & 0 deletions gradio/chat_interface.py
Expand Up @@ -77,6 +77,7 @@ def __init__(
autofocus: bool = True,
concurrency_limit: int | None | Literal["default"] = "default",
fill_height: bool = True,
delete_cache: tuple[int, int] | None = None,
):
"""
Parameters:
Expand All @@ -103,6 +104,7 @@ def __init__(
autofocus: If True, autofocuses to the textbox when the page loads.
concurrency_limit: If set, this is the maximum number of chatbot submissions that can be running simultaneously. Can be set to None to mean no limit (any number of chatbot submissions can be running simultaneously). Set to "default" to use the default concurrency limit (defined by the `default_concurrency_limit` parameter in `.queue()`, which is 1 by default).
fill_height: If True, the chat interface will expand to the height of window.
delete_cache: A tuple corresponding to [frequency, age], both expressed in number of seconds. Every `frequency` seconds, the temporary files created by this Blocks instance will be deleted if more than `age` seconds have passed since the file was created. For example, setting this to (86400, 86400) will delete temporary files every day. The cache will be deleted entirely when the server restarts. If None, no cache deletion will occur.
"""
super().__init__(
analytics_enabled=analytics_enabled,
Expand All @@ -113,6 +115,7 @@ def __init__(
js=js,
head=head,
fill_height=fill_height,
delete_cache=delete_cache,
)
self.concurrency_limit = concurrency_limit
self.fn = fn
Expand Down
7 changes: 7 additions & 0 deletions gradio/components/base.py
Expand Up @@ -14,6 +14,8 @@
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable

from gradio_client.utils import is_file_obj

from gradio import utils
from gradio.blocks import Block, BlockContext
from gradio.component_meta import ComponentMeta
Expand Down Expand Up @@ -189,6 +191,9 @@ def __init__(
self.scale = scale
self.min_width = min_width
self.interactive = interactive
# Keep track of files that should not be deleted when the delete_cache parameter is set
# These files are the default value of the component and files that are used in examples
self.keep_in_cache = set()

# load_event is set in the Blocks.attach_load_events method
self.load_event: None | dict[str, Any] = None
Expand All @@ -200,6 +205,8 @@ def __init__(
self, # type: ignore
postprocess=True,
)
if is_file_obj(self.value):
self.keep_in_cache.add(self.value["path"])

if callable(load_fn):
self.attach_load_event(load_fn, every)
Expand Down
3 changes: 3 additions & 0 deletions gradio/components/dataset.py
Expand Up @@ -5,6 +5,7 @@
from typing import Any, Literal

from gradio_client.documentation import document
from gradio_client.utils import is_file_obj

from gradio import processing_utils
from gradio.components.base import (
Expand Down Expand Up @@ -98,6 +99,8 @@ def __init__(
example[i],
component,
)
if is_file_obj(example[i]):
self.keep_in_cache.add(example[i]["path"])
self.type = type
self.label = label
if headers is not None:
Expand Down
3 changes: 3 additions & 0 deletions gradio/interface.py
Expand Up @@ -121,6 +121,7 @@ def __init__(
submit_btn: str | Button = "Submit",
stop_btn: str | Button = "Stop",
clear_btn: str | Button = "Clear",
delete_cache: tuple[int, int] | None = None,
**kwargs,
):
"""
Expand Down Expand Up @@ -155,6 +156,7 @@ def __init__(
submit_btn: The button to use for submitting inputs. Defaults to a `gr.Button("Submit", variant="primary")`. This parameter does not apply if the Interface is output-only, in which case the submit button always displays "Generate". Can be set to a string (which becomes the button label) or a `gr.Button` object (which allows for more customization).
stop_btn: The button to use for stopping the interface. Defaults to a `gr.Button("Stop", variant="stop", visible=False)`. Can be set to a string (which becomes the button label) or a `gr.Button` object (which allows for more customization).
clear_btn: The button to use for clearing the inputs. Defaults to a `gr.Button("Clear", variant="secondary")`. Can be set to a string (which becomes the button label) or a `gr.Button` object (which allows for more customization).
delete_cache: A tuple corresponding to [frequency, age], both expressed in number of seconds. Every `frequency` seconds, the temporary files created by this Blocks instance will be deleted if more than `age` seconds have passed since the file was created. For example, setting this to (86400, 86400) will delete temporary files every day. The cache will be deleted entirely when the server restarts. If None, no cache deletion will occur.
"""
super().__init__(
analytics_enabled=analytics_enabled,
Expand All @@ -164,6 +166,7 @@ def __init__(
theme=theme,
js=js,
head=head,
delete_cache=delete_cache,
**kwargs,
)
self.api_name: str | Literal[False] | None = api_name
Expand Down
82 changes: 81 additions & 1 deletion gradio/route_utils.py
@@ -1,16 +1,32 @@
from __future__ import annotations

import asyncio
import hashlib
import hmac
import json
import os
import re
import shutil
from collections import deque
from contextlib import asynccontextmanager
from dataclasses import dataclass as python_dataclass
from datetime import datetime
from pathlib import Path
from tempfile import NamedTemporaryFile, _TemporaryFileWrapper
from typing import TYPE_CHECKING, AsyncGenerator, BinaryIO, List, Optional, Tuple, Union
from typing import (
TYPE_CHECKING,
AsyncContextManager,
AsyncGenerator,
BinaryIO,
Callable,
List,
Optional,
Tuple,
Union,
)
from urllib.parse import urlparse

import anyio
import fastapi
import httpx
import multipart
Expand Down Expand Up @@ -640,3 +656,67 @@ async def dispatch(self, request: fastapi.Request, call_next):
"Access-Control-Allow-Headers"
] = "Origin, Content-Type, Accept"
return response


def delete_files_created_by_app(blocks: Blocks, age: int | None) -> None:
    """Delete the temporary files tracked by ``blocks`` that are older than ``age``.

    Parameters:
        blocks: The Blocks instance whose ``temp_file_sets`` are scanned. Any path found in a component's ``keep_in_cache`` set is never deleted.
        age: Number of seconds (measured from the file's ``st_ctime``) a file must exceed before it is deleted. If None, every tracked file is deleted regardless of its age.

    Paths that no longer exist on disk are skipped and left in their set.
    """
    dont_delete = set()
    for component in blocks.blocks.values():
        dont_delete.update(getattr(component, "keep_in_cache", set()))

    for temp_set in blocks.temp_file_sets:
        removed = set()
        # Iterate over a snapshot rather than the live set: delete_files_on_schedule
        # runs this function in a worker thread, and a path added to the set in the
        # meantime would otherwise raise "Set changed size during iteration".
        for file in tuple(temp_set):
            if file in dont_delete:
                continue
            try:
                if age is not None:
                    changed_at = datetime.fromtimestamp(Path(file).lstat().st_ctime)
                    # total_seconds() covers the whole interval. timedelta.seconds is
                    # only the sub-day remainder (0..86399): it ignores elapsed days
                    # and reports 86399 for a slightly negative interval.
                    elapsed = (datetime.now() - changed_at).total_seconds()
                    if elapsed <= age:
                        continue
                os.remove(file)
                removed.add(file)
            except FileNotFoundError:
                continue
        temp_set -= removed


async def delete_files_on_schedule(app: App, frequency: int, age: int) -> None:
    """Background task that periodically purges files created by the app.

    Sleeps for `frequency` seconds, then runs `delete_files_created_by_app` in a
    worker thread with the given `age` threshold, and repeats indefinitely.
    """
    while True:
        await asyncio.sleep(frequency)
        # The file-system work is synchronous, so it is handed to a worker
        # thread instead of being run on the event loop.
        current_blocks = app.get_blocks()
        await anyio.to_thread.run_sync(
            delete_files_created_by_app,
            current_blocks,
            age,
        )


@asynccontextmanager
async def _lifespan_handler(
    app: App, frequency: int = 1, age: int = 1
) -> AsyncGenerator:
    """A context manager that triggers the startup and shutdown events of the app.

    On entry it runs the Blocks startup events and starts the periodic
    file-deletion task. On exit it stops that task and deletes every file
    created by the app, even if the wrapped body raised.

    Parameters:
        app: The App whose Blocks are started up and cleaned up.
        frequency: Seconds between two runs of the periodic deletion task.
        age: Age in seconds a file must exceed to be deleted by the periodic task.
    """
    app.get_blocks().startup_events()
    app.startup_events_triggered = True
    # Keep a reference to the task: the event loop only holds a weak reference,
    # so an unreferenced task may be garbage-collected before it finishes.
    cleanup_task = asyncio.create_task(delete_files_on_schedule(app, frequency, age))
    try:
        yield
    finally:
        # Stop the periodic task, then always run the final cleanup.
        cleanup_task.cancel()
        delete_files_created_by_app(app.get_blocks(), age=None)


def create_lifespan_handler(
    user_lifespan: Callable[[App], AsyncContextManager] | None,
    frequency: int = 1,
    age: int = 1,
) -> Callable[[App], AsyncContextManager]:
    """Return a context manager that applies _lifespan_handler and user_lifespan if it exists.

    Parameters:
        user_lifespan: Optional lifespan context manager supplied by the user. It is entered inside _lifespan_handler, and whatever it yields is forwarded to the caller.
        frequency: Passed through to _lifespan_handler.
        age: Passed through to _lifespan_handler.
    Returns:
        A callable that takes the App and returns the combined async context manager.
    """

    @asynccontextmanager
    async def _handler(app: App):
        async with _lifespan_handler(app, frequency, age):
            if user_lifespan is None:
                yield
            else:
                # Forward the value yielded by the user's lifespan (e.g. a
                # lifespan state mapping) instead of discarding it.
                async with user_lifespan(app) as state:
                    yield state

    return _handler
6 changes: 6 additions & 0 deletions gradio/routes.py
Expand Up @@ -63,6 +63,7 @@
MultiPartException,
Request,
compare_passwords_securely,
create_lifespan_handler,
move_uploaded_files_to_cache,
)
from gradio.state_holder import StateHolder
Expand Down Expand Up @@ -192,6 +193,10 @@ def create_app(
) -> App:
app_kwargs = app_kwargs or {}
app_kwargs.setdefault("default_response_class", ORJSONResponse)
if blocks.delete_cache is not None:
app_kwargs["lifespan"] = create_lifespan_handler(
app_kwargs.get("lifespan", None), *blocks.delete_cache
)
app = App(**app_kwargs)
app.configure_app(blocks)

Expand Down Expand Up @@ -827,6 +832,7 @@ async def upload_file(
files_to_copy.append(temp_file.file.name)
locations.append(str(dest))
output_files.append(dest)
blocks.upload_file_set.add(str(dest))
if files_to_copy:
bg_tasks.add_task(
move_uploaded_files_to_cache, files_to_copy, locations
Expand Down
2 changes: 1 addition & 1 deletion guides/01_getting-started/03_sharing-your-app.md
Expand Up @@ -315,7 +315,7 @@ Sharing your Gradio app with others (by hosting it on Spaces, on your own server

In particular, Gradio apps ALLOW users to access three kinds of files:

- **Temporary files created by Gradio.** These are files that are created by Gradio as part of running your prediction function. For example, if your prediction function returns a video file, then Gradio will save that video to a temporary cache on your device and then send the path to the file to the front end. You can customize the location of temporary cache files created by Gradio by setting the environment variable `GRADIO_TEMP_DIR` to an absolute path, such as `/home/usr/scripts/project/temp/`.
- **Temporary files created by Gradio.** These are files that are created by Gradio as part of running your prediction function. For example, if your prediction function returns a video file, then Gradio will save that video to a temporary cache on your device and then send the path to the file to the front end. You can customize the location of temporary cache files created by Gradio by setting the environment variable `GRADIO_TEMP_DIR` to an absolute path, such as `/home/usr/scripts/project/temp/`. You can delete the files created by your app when it shuts down with the `delete_cache` parameter of `gradio.Blocks`, `gradio.Interface`, and `gradio.ChatInterface`. This parameter is a tuple of integers of the form `[frequency, age]`, where `frequency` is how often (in seconds) to delete files and `age` is the time in seconds since the file was last modified.


- **Cached examples created by Gradio.** These are files that are created by Gradio as part of caching examples for faster runtimes, if you set `cache_examples=True` in `gr.Interface()` or in `gr.Examples()`. By default, these files are saved in the `gradio_cached_examples/` subdirectory within your app's working directory. You can customize the location of cached example files created by Gradio by setting the environment variable `GRADIO_EXAMPLES_CACHE` to an absolute path or a path relative to your working directory.
Expand Down
4 changes: 2 additions & 2 deletions test/conftest.py
Expand Up @@ -41,8 +41,8 @@ def io_components():
@pytest.fixture
def connect():
@contextmanager
def _connect(demo: gr.Blocks, serialize=True):
_, local_url, _ = demo.launch(prevent_thread_lock=True)
def _connect(demo: gr.Blocks, serialize=True, **kwargs):
_, local_url, _ = demo.launch(prevent_thread_lock=True, **kwargs)
try:
yield Client(local_url, serialize=serialize)
finally:
Expand Down
8 changes: 6 additions & 2 deletions test/test_blocks.py
Expand Up @@ -1584,8 +1584,12 @@ def test_temp_file_sets_get_extended():
with gr.Blocks() as demo3:
demo1.render()
demo2.render()

assert demo3.temp_file_sets == demo1.temp_file_sets + demo2.temp_file_sets
# The upload_set is empty so we remove it from the check
demo_3_no_empty = [s for s in demo3.temp_file_sets if len(s)]
demo_1_and_2_no_empty = [
s for s in demo1.temp_file_sets + demo2.temp_file_sets if len(s)
]
assert demo_3_no_empty == demo_1_and_2_no_empty


def test_recover_kwargs():
Expand Down
43 changes: 43 additions & 0 deletions test/test_routes.py
Expand Up @@ -480,6 +480,49 @@ def test_cors_restrictions(self):
assert file_response.headers["access-control-allow-origin"] == "127.0.0.1"
io.close()

def test_delete_cache(self, connect, gradio_temp_dir, capsys):
    """Files are kept without delete_cache, removed on shutdown with it, and a user lifespan still runs."""

    def check_num_files_exist(blocks: Blocks):
        # Number of tracked temporary files that are still present on disk.
        return sum(
            1
            for temp_file_set in blocks.temp_file_sets
            for temp_file in temp_file_set
            if os.path.exists(temp_file)
        )

    # Without delete_cache the produced file survives the app shutting down.
    demo = gr.Interface(lambda s: s, gr.Textbox(), gr.File(), delete_cache=None)
    with connect(demo) as client:
        client.predict("test/test_files/cheetah1.jpg")
    assert check_num_files_exist(demo) == 1

    # With delete_cache the files exist while the app runs and are gone after it closes.
    demo_with_cleanup = gr.Interface(
        lambda s: s, gr.Textbox(), gr.File(), delete_cache=(60, 30)
    )
    with connect(demo_with_cleanup) as client:
        client.predict("test/test_files/alphabet.txt")
        client.predict("test/test_files/bus.png")
        assert check_num_files_exist(demo_with_cleanup) == 2
    assert check_num_files_exist(demo_with_cleanup) == 0
    # The files of the first app are left alone.
    assert check_num_files_exist(demo) == 1

    @asynccontextmanager
    async def mylifespan(app: FastAPI):
        print("IN CUSTOM LIFESPAN")
        yield
        print("AFTER CUSTOM LIFESPAN")

    demo_custom_lifespan = gr.Interface(
        lambda s: s, gr.Textbox(), gr.File(), delete_cache=(5, 1)
    )

    # A user-provided lifespan is combined with the cache-deleting one.
    launch_kwargs = {"app_kwargs": {"lifespan": mylifespan}}
    with connect(demo_custom_lifespan, **launch_kwargs) as client:
        client.predict("test/test_files/alphabet.txt")
    assert check_num_files_exist(demo_custom_lifespan) == 0
    printed = capsys.readouterr().out
    assert "IN CUSTOM LIFESPAN" in printed
    assert "AFTER CUSTOM LIFESPAN" in printed


class TestApp:
def test_create_app(self):
Expand Down

0 comments on commit a57e34e

Please sign in to comment.