Skip to content

Commit

Permalink
Fix gr.load for file-based Spaces (#7350)
Browse files Browse the repository at this point in the history
* changes

* add changeset

* fixes

* changes

* fix

* add test

* add changeset

* improve test

* Fixing `gr.load()` part II (#7358)

* audio

* changes

* add changeset

* changes

* changes

* changes

---------

Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>

* Delete .changeset/fresh-gifts-worry.md

* add changeset

* format

* upload

* add changeset

* changes

* backend

* print

* add changeset

* changes

* client

---------

Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
  • Loading branch information
abidlabs and gradio-pr-bot committed Feb 9, 2024
1 parent a7fa47a commit 7302a6e
Show file tree
Hide file tree
Showing 7 changed files with 84 additions and 30 deletions.
6 changes: 6 additions & 0 deletions .changeset/tender-lamps-shout.md
@@ -0,0 +1,6 @@
---
"gradio": patch
"gradio_client": patch
---

fix:Fix `gr.load` for file-based Spaces
50 changes: 33 additions & 17 deletions client/python/gradio_client/client.py
Expand Up @@ -71,26 +71,36 @@ def __init__(
src: str,
hf_token: str | None = None,
max_workers: int = 40,
serialize: bool = True,
serialize: bool | None = None,
output_dir: str | Path = DEFAULT_TEMP_DIR,
verbose: bool = True,
auth: tuple[str, str] | None = None,
*,
headers: dict[str, str] | None = None,
upload_files: bool = True,
download_files: bool = True,
):
"""
Parameters:
src: Either the name of the Hugging Face Space to load, (e.g. "abidlabs/whisper-large-v2") or the full URL (including "http" or "https") of the hosted Gradio app to load (e.g. "http://mydomain.com/app" or "https://bec81a83-5b5c-471e.gradio.live/").
hf_token: The Hugging Face token to use to access private Spaces. Automatically fetched if you are logged in via the Hugging Face Hub CLI. Obtain from: https://huggingface.co/settings/token
max_workers: The maximum number of thread workers that can be used to make requests to the remote Gradio app simultaneously.
serialize: Whether the client should serialize the inputs and deserialize the outputs of the remote API. If set to False, the client will pass the inputs and outputs as-is, without serializing/deserializing them. E.g. if you set this to False, you'd submit an image in base64 format instead of a filepath, and you'd get back an image in base64 format from the remote API instead of a filepath.
serialize: Deprecated. Please use the equivalent `upload_files` parameter instead.
output_dir: The directory to save files that are downloaded from the remote API. If None, reads from the GRADIO_TEMP_DIR environment variable. Defaults to a temporary directory on your machine.
verbose: Whether the client should print statements to the console.
headers: Additional headers to send to the remote Gradio app on every request. By default only the HF authorization and user-agent headers are sent. These headers will override the default headers if they have the same keys.
upload_files: Whether the client should treat input string filepaths as files and upload them to the remote server. If False, the client will always treat input string filepaths as plain strings and not modify them.
download_files: Whether the client should download output files from the remote API and return them as string filepaths on the local machine. If False, the client will return a FileData dataclass object with the filepath on the remote machine instead.
"""
self.verbose = verbose
self.hf_token = hf_token
self.serialize = serialize
if serialize is not None:
warnings.warn(
"The `serialize` parameter is deprecated and will be removed. Please use the equivalent `upload_files` parameter instead."
)
upload_files = serialize
self.upload_files = upload_files
self.download_files = download_files
self.headers = build_hf_headers(
token=hf_token,
library_name="gradio_client",
Expand Down Expand Up @@ -463,11 +473,10 @@ def fn(future):
return job

def _get_api_info(self):
if self.serialize:
if self.upload_files:
api_info_url = urllib.parse.urljoin(self.src, utils.API_INFO_URL)
else:
api_info_url = urllib.parse.urljoin(self.src, utils.RAW_API_INFO_URL)

if self.app_version > version.Version("3.36.1"):
r = httpx.get(api_info_url, headers=self.headers, cookies=self.cookies)
if r.is_success:
Expand All @@ -477,7 +486,10 @@ def _get_api_info(self):
else:
fetch = httpx.post(
utils.SPACE_FETCHER_URL,
json={"config": json.dumps(self.config), "serialize": self.serialize},
json={
"config": json.dumps(self.config),
"serialize": self.upload_files,
},
)
if fetch.is_success:
info = fetch.json()["api"]
Expand Down Expand Up @@ -955,7 +967,11 @@ def _get_component_type(self, component_id: int):

@staticmethod
def value_is_file(component: dict) -> bool:
# Hacky for now
# This is still hacky as it does not tell us which part of the payload is a file.
# If a component has a complex payload, part of which is a file, this will simply
# return True, which means that all parts of the payload will be uploaded as files
# if they are valid file paths. The better approach would be to traverse the
# component's api_info and figure out exactly which part of the payload is a file.
if "api_info" not in component:
return False
return utils.value_is_file(component["api_info"])
Expand All @@ -973,7 +989,7 @@ def _inner(*data):
if not self.is_valid:
raise utils.InvalidAPIEndpointError()
data = self.insert_state(*data)
if self.client.serialize:
if self.client.upload_files:
data = self.serialize(*data)
predictions = _predict(*data)
predictions = self.process_predictions(*predictions)
Expand Down Expand Up @@ -1117,6 +1133,9 @@ def get_file(d):
file_list.append(d)
return ReplaceMe(len(file_list) - 1)

def handle_url(s):
return {"path": s, "orig_name": s.split("/")[-1]}

new_data = []
for i, d in enumerate(data):
if self.input_component_types[i].value_is_file:
Expand All @@ -1126,6 +1145,8 @@ def get_file(d):
d = utils.traverse(
d, get_file, lambda s: utils.is_file_obj(s) or utils.is_filepath(s)
)
# Handle URLs here since we don't upload them
d = utils.traverse(d, handle_url, lambda s: utils.is_url(s))
new_data.append(d)
return file_list, new_data

Expand All @@ -1146,11 +1167,6 @@ def serialize(self, *data) -> tuple:
uploaded_files = self._upload(files)
data = list(new_data)
data = self._add_uploaded_files_to_data(data, uploaded_files)
data = utils.traverse(
data,
lambda s: {"path": s},
utils.is_url,
)
o = tuple(data)
return o

Expand Down Expand Up @@ -1182,12 +1198,12 @@ def _download_file(

def deserialize(self, *data) -> tuple:
data_ = list(data)

data_: list[Any] = utils.traverse(data_, self.download_file, utils.is_file_obj)
return tuple(data_)

def process_predictions(self, *predictions):
predictions = self.deserialize(*predictions)
if self.client.download_files:
predictions = self.deserialize(*predictions)
predictions = self.remove_skipped_components(*predictions)
predictions = self.reduce_singleton_output(*predictions)
return predictions
Expand Down Expand Up @@ -1258,7 +1274,7 @@ def _inner(*data):
if not self.is_valid:
raise utils.InvalidAPIEndpointError()
data = self.insert_state(*data)
if self.client.serialize:
if self.client.upload_files:
data = self.serialize(*data)
predictions = _predict(*data)
predictions = self.process_predictions(*predictions)
Expand Down Expand Up @@ -1449,7 +1465,7 @@ def deserialize(self, *data) -> tuple:
return outputs

def process_predictions(self, *predictions):
if self.client.serialize:
if self.client.download_files:
predictions = self.deserialize(*predictions)
predictions = self.remove_skipped_components(*predictions)
predictions = self.reduce_singleton_output(*predictions)
Expand Down
3 changes: 2 additions & 1 deletion client/python/gradio_client/utils.py
Expand Up @@ -895,6 +895,7 @@ def get_type(schema: dict):
raise APIInfoParseError(f"Cannot parse type for {schema}")


OLD_FILE_DATA = "Dict(path: str, url: str | None, size: int | None, orig_name: str | None, mime_type: str | None)"
FILE_DATA = "Dict(path: str, url: str | None, size: int | None, orig_name: str | None, mime_type: str | None, is_stream: bool)"


Expand Down Expand Up @@ -995,7 +996,7 @@ def traverse(json_obj: Any, func: Callable, is_root: Callable) -> Any:

def value_is_file(api_info: dict) -> bool:
info = _json_schema_to_python_type(api_info, api_info.get("$defs"))
return FILE_DATA in info
return FILE_DATA in info or OLD_FILE_DATA in info


def is_filepath(s):
Expand Down
22 changes: 22 additions & 0 deletions client/python/test/test_client.py
Expand Up @@ -133,6 +133,28 @@ def test_private_space_v4_sse_v1(self):
output = client.predict("abc", api_name="/predict")
assert output == "abc"

@pytest.mark.flaky
def test_space_with_files_v4_sse_v2(self):
space_id = "gradio-tests/space_with_files_v4_sse_v2"
client = Client(space_id)
payload = (
"https://audio-samples.github.io/samples/mp3/blizzard_unconditional/sample-0.mp3",
{
"video": "https://github.com/gradio-app/gradio/raw/main/demo/video_component/files/world.mp4",
"subtitle": None,
},
"https://audio-samples.github.io/samples/mp3/blizzard_unconditional/sample-0.mp3",
)
output = client.predict(*payload, api_name="/predict")
assert output[0].endswith(".wav") # Audio files are converted to wav
assert output[1]["video"].endswith(
"world.mp4"
) # Video files are not converted by default
assert (
output[2]
== "https://audio-samples.github.io/samples/mp3/blizzard_unconditional/sample-0.mp3"
) # textbox string should remain exactly the same

def test_state(self, increment_demo):
with connect(increment_demo) as client:
output = client.predict(api_name="/increment_without_queue")
Expand Down
5 changes: 4 additions & 1 deletion gradio/components/video.py
Expand Up @@ -350,4 +350,7 @@ def srt_to_vtt(srt_file_path, vtt_file_path):
return FileData(path=str(subtitle))

def example_inputs(self) -> Any:
return "https://github.com/gradio-app/gradio/raw/main/demo/video_component/files/world.mp4"
return {
"video": "https://github.com/gradio-app/gradio/raw/main/demo/video_component/files/world.mp4",
"subtitles": None,
}
6 changes: 5 additions & 1 deletion gradio/external.py
Expand Up @@ -418,12 +418,16 @@ def from_spaces(


def from_spaces_blocks(space: str, hf_token: str | None) -> Blocks:
client = Client(space, hf_token=hf_token)
client = Client(space, hf_token=hf_token, download_files=False)
# We set download_files to False to avoid downloading output files from the server.
# Instead, we serve them as URLs using the /proxy/ endpoint directly from the server.

if client.app_version < version.Version("4.0.0b14"):
raise GradioVersionIncompatibleError(
f"Gradio version 4.x cannot load spaces with versions less than 4.x ({client.app_version})."
"Please downgrade to version 3 to load this space."
)

# Use end_to_end_fn here to properly upload/download all files
predict_fns = []
for fn_index, endpoint in enumerate(client.endpoints):
Expand Down
22 changes: 12 additions & 10 deletions gradio/processing_utils.py
Expand Up @@ -261,22 +261,24 @@ def _move_to_cache(d: dict):
# without it being served from the gradio server
# This makes it so that the URL is not downloaded and speeds up event processing
if payload.url and postprocess:
temp_file_path = payload.url
else:
payload.path = payload.url
elif not block.proxy_url:
# If the file is on a remote server, do not move it to cache.
temp_file_path = move_resource_to_block_cache(payload.path, block)
assert temp_file_path is not None
payload.path = temp_file_path
assert temp_file_path is not None
payload.path = temp_file_path

if add_urls:
url_prefix = "/stream/" if payload.is_stream else "/file="
if block.proxy_url:
url = f"/proxy={block.proxy_url}{url_prefix}{temp_file_path}"
elif client_utils.is_http_url_like(
temp_file_path
) or temp_file_path.startswith(f"{url_prefix}"):
url = temp_file_path
proxy_url = block.proxy_url.rstrip("/")
url = f"/proxy={proxy_url}{url_prefix}{payload.path}"
elif client_utils.is_http_url_like(payload.path) or payload.path.startswith(
f"{url_prefix}"
):
url = payload.path
else:
url = f"{url_prefix}{temp_file_path}"
url = f"{url_prefix}{payload.path}"
payload.url = url

return payload.model_dump()
Expand Down

0 comments on commit 7302a6e

Please sign in to comment.