diff --git a/.changeset/public-hoops-drum.md b/.changeset/public-hoops-drum.md new file mode 100644 index 000000000000..7559c1d229db --- /dev/null +++ b/.changeset/public-hoops-drum.md @@ -0,0 +1,5 @@ +--- +"gradio": minor +--- + +feat:Add delete_cache parameter to gr.Blocks to delete files created by app on shutdown diff --git a/gradio/blocks.py b/gradio/blocks.py index 3b61b63ea964..f344773f72da 100644 --- a/gradio/blocks.py +++ b/gradio/blocks.py @@ -526,6 +526,7 @@ def __init__( js: str | None = None, head: str | None = None, fill_height: bool = False, + delete_cache: tuple[int, int] | None = None, **kwargs, ): """ @@ -538,6 +539,7 @@ def __init__( js: Custom js or path to js file to run when demo is first loaded. This javascript will be included in the demo webpage. head: Custom html to insert into the head of the demo webpage. This can be used to add custom meta tags, scripts, stylesheets, etc. to the page. fill_height: Whether to vertically expand top-level child components to the height of the window. If True, expansion occurs when the scale value of the child components >= 1. + delete_cache: A tuple corresponding [frequency, age] both expressed in number of seconds. Every `frequency` seconds, the temporary files created by this Blocks instance will be deleted if more than `age` seconds have passed since the file was created. For example, setting this to (86400, 86400) will delete temporary files every day. The cache will be deleted entirely when the server restarts. If None, no cache deletion will occur. """ self.limiter = None if theme is None: @@ -566,6 +568,7 @@ def __init__( self.show_error = True self.head = head self.fill_height = fill_height + self.delete_cache = delete_cache if css is not None and os.path.exists(css): with open(css) as css_file: self.css = css_file.read() @@ -608,7 +611,8 @@ def __init__( self.auth = None self.dev_mode = bool(os.getenv("GRADIO_WATCH_DIRS", "")) self.app_id = random.getrandbits(64) - self.temp_file_sets = [] + self.upload_file_set = set() + self.temp_file_sets = [self.upload_file_set] self.title = title self.show_api = not wasm_utils.IS_WASM diff --git a/gradio/chat_interface.py b/gradio/chat_interface.py index 786662159dea..67c91e286248 100644 --- a/gradio/chat_interface.py +++ b/gradio/chat_interface.py @@ -77,6 +77,7 @@ def __init__( autofocus: bool = True, concurrency_limit: int | None | Literal["default"] = "default", fill_height: bool = True, + delete_cache: tuple[int, int] | None = None, ): """ Parameters: @@ -103,6 +104,7 @@ def __init__( autofocus: If True, autofocuses to the textbox when the page loads. concurrency_limit: If set, this is the maximum number of chatbot submissions that can be running simultaneously. Can be set to None to mean no limit (any number of chatbot submissions can be running simultaneously). Set to "default" to use the default concurrency limit (defined by the `default_concurrency_limit` parameter in `.queue()`, which is 1 by default). fill_height: If True, the chat interface will expand to the height of window. + delete_cache: A tuple corresponding [frequency, age] both expressed in number of seconds. Every `frequency` seconds, the temporary files created by this Blocks instance will be deleted if more than `age` seconds have passed since the file was created. For example, setting this to (86400, 86400) will delete temporary files every day. The cache will be deleted entirely when the server restarts. If None, no cache deletion will occur. """ super().__init__( analytics_enabled=analytics_enabled, @@ -113,6 +115,7 @@ def __init__( js=js, head=head, fill_height=fill_height, + delete_cache=delete_cache, ) self.concurrency_limit = concurrency_limit self.fn = fn diff --git a/gradio/components/base.py b/gradio/components/base.py index 51ef59a9b736..e274ef8b54b8 100644 --- a/gradio/components/base.py +++ b/gradio/components/base.py @@ -14,6 +14,8 @@ from pathlib import Path from typing import TYPE_CHECKING, Any, Callable +from gradio_client.utils import is_file_obj + from gradio import utils from gradio.blocks import Block, BlockContext from gradio.component_meta import ComponentMeta @@ -189,6 +191,9 @@ def __init__( self.scale = scale self.min_width = min_width self.interactive = interactive + # Keep tracks of files that should not be deleted when the delete_cache parmaeter is set + # These files are the default value of the component and files that are used in examples + self.keep_in_cache = set() # load_event is set in the Blocks.attach_load_events method self.load_event: None | dict[str, Any] = None @@ -200,6 +205,8 @@ def __init__( self, # type: ignore postprocess=True, ) + if is_file_obj(self.value): + self.keep_in_cache.add(self.value["path"]) if callable(load_fn): self.attach_load_event(load_fn, every) diff --git a/gradio/components/dataset.py b/gradio/components/dataset.py index 84b00f8cf3b1..3efe66a10f47 100644 --- a/gradio/components/dataset.py +++ b/gradio/components/dataset.py @@ -5,6 +5,7 @@ from typing import Any, Literal from gradio_client.documentation import document +from gradio_client.utils import is_file_obj from gradio import processing_utils from gradio.components.base import ( @@ -98,6 +99,8 @@ def __init__( example[i], component, ) + if is_file_obj(example[i]): + self.keep_in_cache.add(example[i]["path"]) self.type = type self.label = label if headers is not None: diff --git a/gradio/interface.py b/gradio/interface.py index d80aef7d7bb6..c16254c61503 100644 --- a/gradio/interface.py +++ b/gradio/interface.py @@ -121,6 +121,7 @@ def __init__( submit_btn: str | Button = "Submit", stop_btn: str | Button = "Stop", clear_btn: str | Button = "Clear", + delete_cache: tuple[int, int] | None = None, **kwargs, ): """ @@ -155,6 +156,7 @@ def __init__( submit_btn: The button to use for submitting inputs. Defaults to a `gr.Button("Submit", variant="primary")`. This parameter does not apply if the Interface is output-only, in which case the submit button always displays "Generate". Can be set to a string (which becomes the button label) or a `gr.Button` object (which allows for more customization). stop_btn: The button to use for stopping the interface. Defaults to a `gr.Button("Stop", variant="stop", visible=False)`. Can be set to a string (which becomes the button label) or a `gr.Button` object (which allows for more customization). clear_btn: The button to use for clearing the inputs. Defaults to a `gr.Button("Clear", variant="secondary")`. Can be set to a string (which becomes the button label) or a `gr.Button` object (which allows for more customization). + delete_cache: A tuple corresponding [frequency, age] both expressed in number of seconds. Every `frequency` seconds, the temporary files created by this Blocks instance will be deleted if more than `age` seconds have passed since the file was created. For example, setting this to (86400, 86400) will delete temporary files every day. The cache will be deleted entirely when the server restarts. If None, no cache deletion will occur. """ super().__init__( analytics_enabled=analytics_enabled, @@ -164,6 +166,7 @@ def __init__( theme=theme, js=js, head=head, + delete_cache=delete_cache, **kwargs, ) self.api_name: str | Literal[False] | None = api_name diff --git a/gradio/route_utils.py b/gradio/route_utils.py index 25d99eb81ce8..ff7f82beac6b 100644 --- a/gradio/route_utils.py +++ b/gradio/route_utils.py @@ -1,16 +1,32 @@ from __future__ import annotations +import asyncio import hashlib import hmac import json +import os import re import shutil from collections import deque +from contextlib import asynccontextmanager from dataclasses import dataclass as python_dataclass +from datetime import datetime +from pathlib import Path from tempfile import NamedTemporaryFile, _TemporaryFileWrapper -from typing import TYPE_CHECKING, AsyncGenerator, BinaryIO, List, Optional, Tuple, Union +from typing import ( + TYPE_CHECKING, + AsyncContextManager, + AsyncGenerator, + BinaryIO, + Callable, + List, + Optional, + Tuple, + Union, +) from urllib.parse import urlparse +import anyio import fastapi import httpx import multipart @@ -640,3 +656,67 @@ async def dispatch(self, request: fastapi.Request, call_next): "Access-Control-Allow-Headers" ] = "Origin, Content-Type, Accept" return response + + +def delete_files_created_by_app(blocks: Blocks, age: int | None) -> None: + """Delete files that are older than age. If age is None, delete all files.""" + + dont_delete = set() + for component in blocks.blocks.values(): + dont_delete.update(getattr(component, "keep_in_cache", set())) + for temp_set in blocks.temp_file_sets: + # We use a copy of the set to avoid modifying the set while iterating over it + # otherwise we would get an exception: Set changed size during iteration + to_remove = set() + for file in temp_set: + if file in dont_delete: + continue + try: + file_path = Path(file) + modified_time = datetime.fromtimestamp(file_path.lstat().st_ctime) + if age is None or (datetime.now() - modified_time).seconds > age: + os.remove(file) + to_remove.add(file) + except FileNotFoundError: + continue + temp_set -= to_remove + + +async def delete_files_on_schedule(app: App, frequency: int, age: int) -> None: + """Startup task to delete files created by the app based on time since last modification.""" + while True: + await asyncio.sleep(frequency) + await anyio.to_thread.run_sync( + delete_files_created_by_app, app.get_blocks(), age + ) + + +@asynccontextmanager +async def _lifespan_handler( + app: App, frequency: int = 1, age: int = 1 +) -> AsyncGenerator: + """A context manager that triggers the startup and shutdown events of the app.""" + app.get_blocks().startup_events() + app.startup_events_triggered = True + asyncio.create_task(delete_files_on_schedule(app, frequency, age)) + yield + delete_files_created_by_app(app.get_blocks(), age=None) + + +def create_lifespan_handler( + user_lifespan: Callable[[App], AsyncContextManager] | None, + frequency: int = 1, + age: int = 1, +) -> Callable[[App], AsyncContextManager]: + """Return a context manager that applies _lifespan_handler and user_lifespan if it exists.""" + + @asynccontextmanager + async def _handler(app: App): + async with _lifespan_handler(app, frequency, age): + if user_lifespan is not None: + async with user_lifespan(app): + yield + else: + yield + + return _handler diff --git a/gradio/routes.py b/gradio/routes.py index 1e782a35ada1..09883d3fa376 100644 --- a/gradio/routes.py +++ b/gradio/routes.py @@ -63,6 +63,7 @@ MultiPartException, Request, compare_passwords_securely, + create_lifespan_handler, move_uploaded_files_to_cache, ) from gradio.state_holder import StateHolder @@ -192,6 +193,10 @@ def create_app( ) -> App: app_kwargs = app_kwargs or {} app_kwargs.setdefault("default_response_class", ORJSONResponse) + if blocks.delete_cache is not None: + app_kwargs["lifespan"] = create_lifespan_handler( + app_kwargs.get("lifespan", None), *blocks.delete_cache + ) app = App(**app_kwargs) app.configure_app(blocks) @@ -827,6 +832,7 @@ async def upload_file( files_to_copy.append(temp_file.file.name) locations.append(str(dest)) output_files.append(dest) + blocks.upload_file_set.add(str(dest)) if files_to_copy: bg_tasks.add_task( move_uploaded_files_to_cache, files_to_copy, locations diff --git a/guides/01_getting-started/03_sharing-your-app.md b/guides/01_getting-started/03_sharing-your-app.md index 531bc93ded41..0d901c330f3d 100644 --- a/guides/01_getting-started/03_sharing-your-app.md +++ b/guides/01_getting-started/03_sharing-your-app.md @@ -315,7 +315,7 @@ Sharing your Gradio app with others (by hosting it on Spaces, on your own server In particular, Gradio apps ALLOW users to access to three kinds of files: -- **Temporary files created by Gradio.** These are files that are created by Gradio as part of running your prediction function. For example, if your prediction function returns a video file, then Gradio will save that video to a temporary cache on your device and then send the path to the file to the front end. You can customize the location of temporary cache files created by Gradio by setting the environment variable `GRADIO_TEMP_DIR` to an absolute path, such as `/home/usr/scripts/project/temp/`. +- **Temporary files created by Gradio.** These are files that are created by Gradio as part of running your prediction function. For example, if your prediction function returns a video file, then Gradio will save that video to a temporary cache on your device and then send the path to the file to the front end. You can customize the location of temporary cache files created by Gradio by setting the environment variable `GRADIO_TEMP_DIR` to an absolute path, such as `/home/usr/scripts/project/temp/`. You can delete the files created by your app when it shuts down with the `delete_cache` parameter of `gradio.Blocks`, `gradio.Interface`, and `gradio.ChatInterface`. This parameter is a tuple of integers of the form `[frequency, age]` where `frequency` is how often to delete files and `age` is the time in seconds since the file was last modified. - **Cached examples created by Gradio.** These are files that are created by Gradio as part of caching examples for faster runtimes, if you set `cache_examples=True` in `gr.Interface()` or in `gr.Examples()`. By default, these files are saved in the `gradio_cached_examples/` subdirectory within your app's working directory. You can customize the location of cached example files created by Gradio by setting the environment variable `GRADIO_EXAMPLES_CACHE` to an absolute path or a path relative to your working directory. diff --git a/test/conftest.py b/test/conftest.py index e7f0358a495e..ed6d1c9325f8 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -41,8 +41,8 @@ def io_components(): @pytest.fixture def connect(): @contextmanager - def _connect(demo: gr.Blocks, serialize=True): - _, local_url, _ = demo.launch(prevent_thread_lock=True) + def _connect(demo: gr.Blocks, serialize=True, **kwargs): + _, local_url, _ = demo.launch(prevent_thread_lock=True, **kwargs) try: yield Client(local_url, serialize=serialize) finally: diff --git a/test/test_blocks.py b/test/test_blocks.py index 578140c8b0ff..1ad495fc38ee 100644 --- a/test/test_blocks.py +++ b/test/test_blocks.py @@ -1584,8 +1584,12 @@ def test_temp_file_sets_get_extended(): with gr.Blocks() as demo3: demo1.render() demo2.render() - - assert demo3.temp_file_sets == demo1.temp_file_sets + demo2.temp_file_sets + # The upload_set is empty so we remove it from the check + demo_3_no_empty = [s for s in demo3.temp_file_sets if len(s)] + demo_1_and_2_no_empty = [ + s for s in demo1.temp_file_sets + demo2.temp_file_sets if len(s) + ] + assert demo_3_no_empty == demo_1_and_2_no_empty def test_recover_kwargs(): diff --git a/test/test_routes.py b/test/test_routes.py index c9ea080754f4..89e12415c767 100644 --- a/test/test_routes.py +++ b/test/test_routes.py @@ -480,6 +480,49 @@ def test_cors_restrictions(self): assert file_response.headers["access-control-allow-origin"] == "127.0.0.1" io.close() + def test_delete_cache(self, connect, gradio_temp_dir, capsys): + def check_num_files_exist(blocks: Blocks): + num_files = 0 + for temp_file_set in blocks.temp_file_sets: + for temp_file in temp_file_set: + if os.path.exists(temp_file): + num_files += 1 + return num_files + + demo = gr.Interface(lambda s: s, gr.Textbox(), gr.File(), delete_cache=None) + with connect(demo) as client: + client.predict("test/test_files/cheetah1.jpg") + assert check_num_files_exist(demo) == 1 + + demo_delete = gr.Interface( + lambda s: s, gr.Textbox(), gr.File(), delete_cache=(60, 30) + ) + with connect(demo_delete) as client: + client.predict("test/test_files/alphabet.txt") + client.predict("test/test_files/bus.png") + assert check_num_files_exist(demo_delete) == 2 + assert check_num_files_exist(demo_delete) == 0 + assert check_num_files_exist(demo) == 1 + + @asynccontextmanager + async def mylifespan(app: FastAPI): + print("IN CUSTOM LIFESPAN") + yield + print("AFTER CUSTOM LIFESPAN") + + demo_custom_lifespan = gr.Interface( + lambda s: s, gr.Textbox(), gr.File(), delete_cache=(5, 1) + ) + + with connect( + demo_custom_lifespan, app_kwargs={"lifespan": mylifespan} + ) as client: + client.predict("test/test_files/alphabet.txt") + assert check_num_files_exist(demo_custom_lifespan) == 0 + captured = capsys.readouterr() + assert "IN CUSTOM LIFESPAN" in captured.out + assert "AFTER CUSTOM LIFESPAN" in captured.out + class TestApp: def test_create_app(self):