Commit: Add code

freddyaboulton committed Oct 24, 2023
1 parent d031119 commit 797e026
Showing 8 changed files with 101 additions and 121 deletions.
10 changes: 7 additions & 3 deletions client/python/gradio_client/client.py
@@ -971,13 +971,18 @@ def _gather_files(self, *data):
         file_list = []
 
         def get_file(d):
-            file_list.append(d)
+            if utils.is_file_obj(d):
+                file_list.append(d["name"])
+            else:
+                file_list.append(d)
             return ReplaceMe(len(file_list) - 1)
 
         new_data = []
         for i, d in enumerate(data):
             if self.input_component_types[i].value_is_file:
-                d = utils.traverse(d, get_file, utils.is_filepath)
+                d = utils.traverse(
+                    d, get_file, lambda s: utils.is_file_obj(s) or utils.is_filepath(s)
+                )
             new_data.append(d)
         return file_list, new_data

@@ -1132,7 +1137,6 @@ def _predict(*data) -> tuple:
                 "session_hash": self.client.session_hash,
             }
         )
-
         if self.use_ws:
             result = utils.synchronize_async(self._ws_fn, data, hash_data, helper)
             if "error" in result:
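Note: the change above teaches `_gather_files` to accept already-serialized file objects (dicts carrying a "name" key) in addition to bare file paths. A minimal sketch of the traversal pattern, with simplified stand-ins for `gradio_client.utils.traverse`, `is_file_obj`, and `ReplaceMe` (the real helpers live in `gradio_client.utils`; in particular, the real `is_filepath` checks for an existing file on disk rather than just a string):

    from dataclasses import dataclass

    @dataclass
    class ReplaceMe:
        index: int  # position in file_list to substitute after upload

    def is_file_obj(d) -> bool:
        return isinstance(d, dict) and "name" in d

    def traverse(json_obj, func, is_root):
        # Depth-first walk; replace every subtree matching is_root with func(subtree).
        if is_root(json_obj):
            return func(json_obj)
        if isinstance(json_obj, dict):
            return {k: traverse(v, func, is_root) for k, v in json_obj.items()}
        if isinstance(json_obj, list):
            return [traverse(v, func, is_root) for v in json_obj]
        return json_obj

    file_list = []

    def get_file(d):
        # File objects contribute their "name" path; bare paths pass through as-is.
        file_list.append(d["name"] if is_file_obj(d) else d)
        return ReplaceMe(len(file_list) - 1)

    data = {"file": {"name": "/tmp/a.wav", "is_file": True}}
    new_data = traverse(data, get_file, lambda s: is_file_obj(s) or isinstance(s, str))
    print(file_list)  # ['/tmp/a.wav']
    print(new_data)   # {'file': ReplaceMe(index=0)}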
95 changes: 34 additions & 61 deletions client/python/test/test_client.py
@@ -53,8 +53,8 @@ def test_raise_error_invalid_state(self):
     def test_numerical_to_label_space(self):
         client = Client("gradio-tests/titanic-survival")
         label = json.load(
-            open(client.predict("male", 77, 10, api_name="/predict"))
-        )  # noqa: SIM115
+            open(client.predict("male", 77, 10, api_name="/predict"))  # noqa: SIM115
+        )
         assert label["label"] == "Perishes"
         with pytest.raises(
             ValueError,
@@ -69,21 +69,9 @@ def test_numerical_to_label_space(self):
 
     @pytest.mark.flaky
     def test_numerical_to_label_space_v4(self):
-        client = Client("gradio-tests/titanic-survival")
-        label = json.load(
-            open(client.predict("male", 77, 10, api_name="/predict"))
-        )  # noqa: SIM115
+        client = Client("gradio-tests/titanic-survival-v4")
+        label = client.predict("male", 77, 10, api_name="/predict")
         assert label["label"] == "Perishes"
-        with pytest.raises(
-            ValueError,
-            match="This Gradio app might have multiple endpoints. Please specify an `api_name` or `fn_index`",
-        ):
-            client.predict("male", 77, 10)
-        with pytest.raises(
-            ValueError,
-            match="Cannot find a function with `api_name`: predict. Did you mean to use a leading slash?",
-        ):
-            client.predict("male", 77, 10, api_name="predict")
 
     @pytest.mark.flaky
     def test_private_space(self):
@@ -93,7 +81,7 @@ def test_numerical_to_label_space(self):
 
     @pytest.mark.flaky
     def test_private_space_v4(self):
-        client = Client("gradio-tests/not-actually-private-space", hf_token=HF_TOKEN)
+        client = Client("gradio-tests/not-actually-private-space-v4", hf_token=HF_TOKEN)
         output = client.predict("abc", api_name="/predict")
         assert output == "abc"
 
@@ -322,7 +310,7 @@ def test_stream_audio(self, stream_audio):
 
     def test_upload_file_private_space_v4(self):
         client = Client(
-            src="gradio-tests/not-actually-private-file-upload", hf_token=HF_TOKEN
+            src="gradio-tests/not-actually-private-file-upload-v4", hf_token=HF_TOKEN
         )
 
         with patch.object(
@@ -379,51 +367,36 @@ def test_upload_file_private_space(self):
         )
 
         with patch.object(
-            client.endpoints[0], "_upload", wraps=client.endpoints[0]._upload
-        ) as upload:
-            with patch.object(
-                client.endpoints[0], "serialize", wraps=client.endpoints[0].serialize
-            ) as serialize:
-                with tempfile.NamedTemporaryFile(mode="w", delete=False) as f:
-                    f.write("Hello from private space!")
-
-                    output = client.submit(
-                        1, "foo", f.name, api_name="/file_upload"
-                    ).result()
-                    with open(output) as f:
-                        assert f.read() == "Hello from private space!"
-                    upload.assert_called_once()
-                    assert all(f["is_file"] for f in serialize.return_value())
-
-        with patch.object(
-            client.endpoints[1], "_upload", wraps=client.endpoints[0]._upload
-        ) as upload:
+            client.endpoints[0], "serialize", wraps=client.endpoints[0].serialize
+        ) as serialize:
             with tempfile.NamedTemporaryFile(mode="w", delete=False) as f:
                 f.write("Hello from private space!")
 
-            with open(client.submit(f.name, api_name="/upload_btn").result()) as f:
-                assert f.read() == "Hello from private space!"
-            upload.assert_called_once()
-
-        with patch.object(
-            client.endpoints[2], "_upload", wraps=client.endpoints[0]._upload
-        ) as upload:
-            # `delete=False` is required for Windows compat
-            with tempfile.NamedTemporaryFile(mode="w", delete=False) as f1:
-                with tempfile.NamedTemporaryFile(mode="w", delete=False) as f2:
-                    f1.write("File1")
-                    f2.write("File2")
-                    r1, r2 = client.submit(
-                        3,
-                        [f1.name, f2.name],
-                        "hello",
-                        api_name="/upload_multiple",
-                    ).result()
-                    with open(r1) as f:
-                        assert f.read() == "File1"
-                    with open(r2) as f:
-                        assert f.read() == "File2"
-                    upload.assert_called_once()
+            output = client.submit(1, "foo", f.name, api_name="/file_upload").result()
+            with open(output) as f:
+                assert f.read() == "Hello from private space!"
+            assert all(f["is_file"] for f in serialize.return_value())
+
+        with tempfile.NamedTemporaryFile(mode="w", delete=False) as f:
+            f.write("Hello from private space!")
+
+        with open(client.submit(f.name, api_name="/upload_btn").result()) as f:
+            assert f.read() == "Hello from private space!"
+
+        with tempfile.NamedTemporaryFile(mode="w", delete=False) as f1:
+            with tempfile.NamedTemporaryFile(mode="w", delete=False) as f2:
+                f1.write("File1")
+                f2.write("File2")
+                r1, r2 = client.submit(
+                    3,
+                    [f1.name, f2.name],
+                    "hello",
+                    api_name="/upload_multiple",
+                ).result()
+                with open(r1) as f:
+                    assert f.read() == "File1"
+                with open(r2) as f:
+                    assert f.read() == "File2"
 
     @pytest.mark.flaky
     def test_upload_file_upload_route_does_not_exist(self):
@@ -1085,7 +1058,7 @@ def test_upload(self):
 
     def test_upload_v4(self):
         client = Client(
-            src="gradio-tests/not-actually-private-file-upload", hf_token=HF_TOKEN
+            src="gradio-tests/not-actually-private-file-upload-v4", hf_token=HF_TOKEN
         )
         response = MagicMock(status_code=200)
         response.json.return_value = [
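Aside: these tests keep the `tempfile.NamedTemporaryFile(mode="w", delete=False)` pattern with its "`delete=False` is required for Windows compat" comment. A short illustration of why: on Windows, a file still held open by `NamedTemporaryFile` cannot be reopened by name, so the tests write, close, then pass `f.name` around and clean up manually:

    import os
    import tempfile

    # Write and close the file so it can be reopened by name (Windows-safe).
    with tempfile.NamedTemporaryFile(mode="w", delete=False) as f:
        f.write("Hello from private space!")

    with open(f.name) as check:  # f.name stays valid because delete=False
        assert check.read() == "Hello from private space!"

    os.unlink(f.name)  # manual cleanup, since the file is not auto-deleted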
4 changes: 3 additions & 1 deletion client/python/test/test_utils.py
@@ -67,7 +67,9 @@ def test_decode_base64_to_file():
 
 
 def test_download_private_file(gradio_temp_dir):
-    url_path = "https://gradio-tests-not-actually-private-space.hf.space/file=lion.jpg"
+    url_path = (
+        "https://gradio-tests-not-actually-private-space-v4.hf.space/file=lion.jpg"
+    )
     hf_token = "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"  # Intentionally revealing this key for testing purposes
     file = utils.download_file(
         url_path=url_path, hf_token=hf_token, dir=str(gradio_temp_dir)
1 change: 0 additions & 1 deletion gradio/blocks.py
@@ -700,7 +700,6 @@ def iterate_over_children(children_list):
         ]
         blocks.__name__ = "Interface"
         blocks.api_mode = True
-
         blocks.root_urls = root_urls
         return blocks
 
4 changes: 3 additions & 1 deletion gradio/component_meta.py
@@ -121,7 +121,9 @@ def create_or_modify_pyi(
     else:
         contents = pyi_file.read_text()
         contents = contents.replace(current_interface, new_interface.strip())
-        pyi_file.write_text(contents)
+    current_contents = pyi_file.read_text()
+    if current_contents != contents:
+        pyi_file.write_text(contents)
 
 
 def in_event_listener():
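The net effect of this change is a write-if-changed guard: the .pyi file is rewritten only when the text on disk actually differs, presumably so an unchanged interface does not dirty the file's mtime and retrigger file watchers or dev reloads. A minimal standalone sketch of the same pattern (`write_if_changed` is a hypothetical helper name, not from the commit):

    from pathlib import Path

    def write_if_changed(path: Path, new_contents: str) -> None:
        # Skip the write when the file already holds exactly this text.
        if path.exists() and path.read_text() == new_contents:
            return
        path.write_text(new_contents)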
1 change: 0 additions & 1 deletion gradio/components/audio.py
@@ -161,7 +161,6 @@ def preprocess(
         """
         if x is None:
             return x
-
         payload: AudioInputData = AudioInputData(**x)
         assert payload.name
 
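For context, `preprocess` immediately validates its payload into `AudioInputData` and asserts that a file name is present. The model's definition is not part of this diff; a hypothetical pydantic-style shape consistent with the snippet might look like:

    from pydantic import BaseModel

    class AudioInputData(BaseModel):
        # Hypothetical fields inferred from usage; the real model ships with Gradio.
        name: str | None = None   # path to the uploaded/recorded file
        is_file: bool = False
        data: str | None = None   # base64 payload when not a file

    payload = AudioInputData(**{"name": "/tmp/audio.wav", "is_file": True})
    assert payload.name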
20 changes: 12 additions & 8 deletions gradio/external.py
@@ -15,12 +15,14 @@
 from gradio_client import Client
 from gradio_client import utils as client_utils
 from gradio_client.documentation import document, set_documentation_group
+from packaging import version
 
 import gradio
 from gradio import components, utils
 from gradio.context import Context
 from gradio.exceptions import (
     Error,
+    GradioVersionIncompatibleError,
     ModelNotFoundError,
     TooManyRequestsError,
 )
@@ -149,7 +151,7 @@ def from_model(model_name: str, hf_token: str | None, alias: str | None, **kwargs):
             f"Could not find model: {model_name}. If it is a private or gated model, please provide your Hugging Face access token (https://huggingface.co/settings/tokens) as the argument for the `api_key` parameter."
         )
     p = response.json().get("pipeline_tag")
-    GRADIO_CACHE = os.environ.get("GRADIO_TEMP_DIR") or str(
+    GRADIO_CACHE = os.environ.get("GRADIO_TEMP_DIR") or str(  # noqa: N806
         Path(tempfile.gettempdir()) / "gradio"
     )
 
@@ -316,7 +318,9 @@ def from_model(model_name: str, hf_token: str | None, alias: str | None, **kwargs):
         "inputs": components.Textbox(label="Input", render=False),
         "outputs": components.Audio(label="Audio", render=False),
         "preprocess": lambda x: {"inputs": x},
-        "postprocess": encode_to_base64,
+        "postprocess": lambda x: save_base64_to_cache(
+            encode_to_base64(x), cache_dir=GRADIO_CACHE, file_name="output.wav"
+        ),
     },
     "text-to-image": {
# example model: osanseviero/BigGAN-deep-128
Expand Down Expand Up @@ -516,12 +520,12 @@ def from_spaces(

def from_spaces_blocks(space: str, hf_token: str | None) -> Blocks:
client = Client(space, hf_token=hf_token)
# if client.app_version < version.Version("4.0.0"):
# raise GradioVersionIncompatible(
# f"Gradio version 4.x cannot load spaces with versions less than 4.x ({client.app_version})."
# "Please downgrade to version 3 to load this space.")
predict_fns = [endpoint._predict_resolve for endpoint in client.endpoints]

if client.app_version < version.Version("4.0.0"):
raise GradioVersionIncompatibleError(
f"Gradio version 4.x cannot load spaces with versions less than 4.x ({client.app_version})."
"Please downgrade to version 3 to load this space."
)
predict_fns = [endpoint.make_end_to_end_fn() for endpoint in client.endpoints]
return gradio.Blocks.from_config(client.config, predict_fns, client.src)


Expand Down
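The version gate relies on `packaging.version`, which compares versions per PEP 440 rather than lexicographically, e.g.:

    from packaging import version

    assert version.Version("3.50.2") < version.Version("4.0.0")   # numeric, not string, order
    assert version.Version("4.0.0b1") < version.Version("4.0.0")  # prereleases sort below the release
    assert version.Version("4.1.0") >= version.Version("4.0.0")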
