Prevent components from working with non-uploaded files (#7465)
* changes

* changes

* add changeset

* changes

* changes

* add changeset

* changes

* changes

* changes

---------

Co-authored-by: Ali Abid <aliabid94@gmail.com>
Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
3 people authored Feb 22, 2024
1 parent ba3ec13 commit 16fbe9c
Showing 5 changed files with 56 additions and 20 deletions.
5 changes: 5 additions & 0 deletions .changeset/moody-impalas-rule.md
@@ -0,0 +1,5 @@
---
"gradio": patch
---

feat:Prevent components from working with non-uploaded files
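
A minimal sketch of the user-visible effect, assuming a simple file-input demo (the component choice, paths, and commented behavior are illustrative assumptions, not part of this commit):

import gradio as gr

def head(path: str) -> str:
    # Read the first 100 characters of the submitted file.
    with open(path) as f:
        return f.read()[:100]

demo = gr.Interface(fn=head, inputs=gr.File(type="filepath"), outputs="text")

# After this change, an API or event request whose file payload points outside
# the upload folder (e.g. "/etc/passwd") fails during preprocessing with:
#   ValueError: File /etc/passwd is not in the upload folder and cannot be accessed.
# Files uploaded through the UI or client land in the cache first, so they pass
# the check; direct Python calls to the Blocks object set explicit_call=True and
# skip it (see gradio/blocks.py below).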
42 changes: 30 additions & 12 deletions gradio/blocks.py
@@ -9,7 +9,6 @@
import secrets
import string
import sys
import tempfile
import threading
import time
import warnings
@@ -70,6 +69,7 @@
get_cancel_function,
get_continuous_fn,
get_package_version,
get_upload_folder,
)

try:
@@ -119,12 +119,7 @@ def __init__(
self._constructor_args: list[dict]
self.state_session_capacity = 10000
self.temp_files: set[str] = set()
self.GRADIO_CACHE = str(
Path(
os.environ.get("GRADIO_TEMP_DIR")
or str(Path(tempfile.gettempdir()) / "gradio")
).resolve()
)
self.GRADIO_CACHE = get_upload_folder()

if render:
self.render()
@@ -1110,6 +1105,7 @@ def __call__(self, *inputs, fn_index: int = 0, api_name: str | None = None):
inputs=processed_inputs,
request=None,
state={},
explicit_call=True,
)
outputs = outputs["data"]

@@ -1299,7 +1295,11 @@ def validate_inputs(self, fn_index: int, inputs: list[Any]):
)

def preprocess_data(
self, fn_index: int, inputs: list[Any], state: SessionState | None
self,
fn_index: int,
inputs: list[Any],
state: SessionState | None,
explicit_call: bool = False,
):
state = state or SessionState(self)
block_fn = self.fns[fn_index]
@@ -1326,7 +1326,10 @@ def preprocess_data(
if input_id in state:
block = state[input_id]
inputs_cached = processing_utils.move_files_to_cache(
inputs[i], block, add_urls=True
inputs[i],
block,
add_urls=True,
check_in_upload_folder=not explicit_call,
)
if getattr(block, "data_model", None) and inputs_cached is not None:
if issubclass(block.data_model, GradioModel): # type: ignore
@@ -1522,8 +1525,14 @@ def handle_streaming_diffs(

return data

def run_fn_batch(self, fn, batch, fn_index, state):
return [fn(fn_index, list(i), state) for i in zip(*batch)]
def run_fn_batch(self, fn, batch, fn_index, state, explicit_call=None):
output = []
for i in zip(*batch):
args = [fn_index, list(i), state]
if explicit_call is not None:
args.append(explicit_call)
output.append(fn(*args))
return output

async def process_api(
self,
@@ -1536,6 +1545,7 @@ async def process_api(
event_id: str | None = None,
event_data: EventData | None = None,
in_event_listener: bool = True,
explicit_call: bool = False,
) -> dict[str, Any]:
"""
Processes API calls from the frontend. First preprocesses the data,
@@ -1548,6 +1558,8 @@
iterators: the in-progress iterators for each generator function (key is function index)
event_id: id of event that triggered this API call
event_data: data associated with the event trigger itself
in_event_listener: whether this API call is being made in response to an event listener
explicit_call: whether this call is being made directly by calling the Blocks function, instead of through an event listener or API route
Returns: None
"""
block_fn = self.fns[fn_index]
@@ -1575,6 +1587,7 @@
inputs,
fn_index,
state,
explicit_call,
limiter=self.limiter,
)
result = await self.call_function(
@@ -1603,7 +1616,12 @@
inputs = []
else:
inputs = await anyio.to_thread.run_sync(
self.preprocess_data, fn_index, inputs, state, limiter=self.limiter
self.preprocess_data,
fn_index,
inputs,
state,
explicit_call,
limiter=self.limiter,
)
was_generating = old_iterator is not None
result = await self.call_function(
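
A minimal self-contained sketch of how the new `explicit_call` flag threads from `Blocks.__call__` through `process_api`, `run_fn_batch`, and `preprocess_data` down to the cache check; the function bodies and the hard-coded cache path are simplified stand-ins, not the actual gradio implementation:

UPLOAD_FOLDER = "/tmp/gradio"  # assumed cache location for this sketch

def move_files_to_cache(path: str, check_in_upload_folder: bool = False) -> str:
    # Stand-in for gradio.processing_utils.move_files_to_cache.
    if check_in_upload_folder and not path.startswith(UPLOAD_FOLDER):
        raise ValueError(f"File {path} is not in the upload folder and cannot be accessed.")
    return path

def preprocess_data(path: str, explicit_call: bool = False) -> str:
    # The containment check is enforced for everything except direct Python calls.
    return move_files_to_cache(path, check_in_upload_folder=not explicit_call)

def process_api(path: str, explicit_call: bool = False) -> str:
    return preprocess_data(path, explicit_call=explicit_call)

process_api("/tmp/gradio/abc123/image.png")      # OK: file is in the cache
process_api("/etc/passwd", explicit_call=True)   # OK: direct Blocks.__call__ path
# process_api("/etc/passwd")                     # ValueError: not in the upload folder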
13 changes: 12 additions & 1 deletion gradio/processing_utils.py
@@ -20,7 +20,7 @@

from gradio import wasm_utils
from gradio.data_classes import FileData, GradioModel, GradioRootModel
from gradio.utils import abspath
from gradio.utils import abspath, get_upload_folder, is_in_or_equal

with warnings.catch_warnings():
warnings.simplefilter("ignore") # Ignore pydub warning if ffmpeg is not installed
@@ -241,6 +241,7 @@ def move_files_to_cache(
block: Component,
postprocess: bool = False,
add_urls=False,
check_in_upload_folder=False,
) -> dict:
"""Move any files in `data` to cache and (optionally), adds URL prefixes (/file=...) needed to access the cached file.
Also handles the case where the file is on an external Gradio app (/proxy=...).
@@ -252,6 +253,8 @@
block: The component whose data is being processed
postprocess: Whether it's running from postprocessing
root_url: The root URL of the local server, if applicable
add_urls: Whether to add URLs to the payload
check_in_upload_folder: If True, instead of moving the file to the cache, checks whether the file is already in the cache (raises an exception if not).
"""

def _move_to_cache(d: dict):
@@ -264,6 +267,14 @@ def _move_to_cache(d: dict):
payload.path = payload.url
elif not block.proxy_url:
# If the file is on a remote server, do not move it to cache.
if check_in_upload_folder and not client_utils.is_http_url_like(
payload.path
):
path = os.path.abspath(payload.path)
if not is_in_or_equal(path, get_upload_folder()):
raise ValueError(
f"File {path} is not in the upload folder and cannot be accessed."
)
temp_file_path = block.move_resource_to_block_cache(payload.path)
if temp_file_path is None:
raise ValueError("Did not determine a file path for the resource.")
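
The containment test above relies on `is_in_or_equal` from `gradio/utils.py`, which is not part of this diff. A rough stand-in that conveys the intended semantics (an assumption based on the name and usage, not the actual source):

from pathlib import Path

def is_in_or_equal(path_1: str, path_2: str) -> bool:
    """True if path_1 equals path_2 or lies anywhere beneath it after resolving."""
    p1, p2 = Path(path_1).resolve(), Path(path_2).resolve()
    return p1 == p2 or p2 in p1.parents

print(is_in_or_equal("/tmp/gradio/abc123/image.png", "/tmp/gradio"))  # True
print(is_in_or_equal("/etc/passwd", "/tmp/gradio"))                   # False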
9 changes: 2 additions & 7 deletions gradio/routes.py
@@ -17,7 +17,6 @@
import os
import posixpath
import secrets
import tempfile
import threading
import time
import traceback
@@ -67,9 +66,7 @@
move_uploaded_files_to_cache,
)
from gradio.state_holder import StateHolder
from gradio.utils import (
get_package_version,
)
from gradio.utils import get_package_version, get_upload_folder

if TYPE_CHECKING:
from gradio.blocks import Block
@@ -136,9 +133,7 @@ def __init__(self, **kwargs):
self.cookie_id = secrets.token_urlsafe(32)
self.queue_token = secrets.token_urlsafe(32)
self.startup_events_triggered = False
self.uploaded_file_dir = os.environ.get("GRADIO_TEMP_DIR") or str(
(Path(tempfile.gettempdir()) / "gradio").resolve()
)
self.uploaded_file_dir = get_upload_folder()
self.change_event: None | threading.Event = None
self._asyncio_tasks: list[asyncio.Task] = []
# Allow user to manually set `docs_url` and `redoc_url`
7 changes: 7 additions & 0 deletions gradio/utils.py
@@ -13,6 +13,7 @@
import os
import pkgutil
import re
import tempfile
import threading
import time
import traceback
@@ -1082,3 +1083,9 @@ def compare_objects(obj1, obj2, path=None):
return edits

return compare_objects(old, new)


def get_upload_folder() -> str:
return os.environ.get("GRADIO_TEMP_DIR") or str(
(Path(tempfile.gettempdir()) / "gradio").resolve()
)
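
With this helper, `Blocks.GRADIO_CACHE` and the route-level `uploaded_file_dir` share a single source of truth, so the cache can be redirected by setting one environment variable. A small usage sketch (the directory below is a hypothetical example):

import os
import tempfile
from pathlib import Path

# Default location: <system tempdir>/gradio, e.g. /tmp/gradio on Linux.
print(str((Path(tempfile.gettempdir()) / "gradio").resolve()))

# Setting GRADIO_TEMP_DIR before launching redirects uploads, the component
# cache, and the upload route to the same directory.
os.environ["GRADIO_TEMP_DIR"] = "/data/gradio-cache"  # hypothetical path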
