Skip to content

Commit

Permalink
Fix client tests sse branch (#6150)
Browse files Browse the repository at this point in the history
* Switch spaces

* Fix tests

* Add code

* changes

* changes

---------

Co-authored-by: Ali Abid <aabid94@gmail.com>
  • Loading branch information
freddyaboulton and aliabid94 committed Oct 30, 2023
1 parent 373ba21 commit 3bc906e
Show file tree
Hide file tree
Showing 7 changed files with 49 additions and 36 deletions.
15 changes: 9 additions & 6 deletions client/python/gradio_client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,14 @@ def predict(
)
return self.submit(*args, api_name=api_name, fn_index=fn_index).result()

def new_helper(self, fn_index: int) -> Communicator:
return Communicator(
Lock(),
JobStatus(),
self.endpoints[fn_index].process_predictions,
self.reset_url,
)

def submit(
self,
*args,
Expand Down Expand Up @@ -334,12 +342,7 @@ def submit(

helper = None
if self.endpoints[inferred_fn_index].protocol in ("ws", "sse"):
helper = Communicator(
Lock(),
JobStatus(),
self.endpoints[inferred_fn_index].process_predictions,
self.reset_url,
)
helper = self.new_helper(inferred_fn_index)
end_to_end_fn = self.endpoints[inferred_fn_index].make_end_to_end_fn(helper)
future = self.executor.submit(end_to_end_fn, *args)

Expand Down
1 change: 0 additions & 1 deletion client/python/test/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,6 @@ def yield_demo():
def spell(x):
for i in range(len(x)):
time.sleep(0.5)
print(">>>>>", x[:i])
yield x[:i]

return gr.Interface(spell, "textbox", "textbox")
Expand Down
10 changes: 6 additions & 4 deletions client/python/test/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def test_numerical_to_label_space(self):

@pytest.mark.flaky
def test_numerical_to_label_space_v4(self):
client = Client("gradio-tests/titanic-survival-v4")
client = Client("gradio-tests/titanic-survivalv4-sse")
label = client.predict("male", 77, 10, api_name="/predict")
assert label["label"] == "Perishes"

Expand All @@ -81,7 +81,9 @@ def test_private_space(self):

@pytest.mark.flaky
def test_private_space_v4(self):
client = Client("gradio-tests/not-actually-private-space-v4", hf_token=HF_TOKEN)
client = Client(
"gradio-tests/not-actually-private-spacev4-sse", hf_token=HF_TOKEN
)
output = client.predict("abc", api_name="/predict")
assert output == "abc"

Expand Down Expand Up @@ -301,7 +303,7 @@ def test_stream_audio(self, stream_audio):
@pytest.mark.xfail
def test_upload_file_private_space_v4(self):
client = Client(
src="gradio-tests/not-actually-private-file-upload-v4", hf_token=HF_TOKEN
src="gradio-tests/not-actually-private-file-uploadv4-sse", hf_token=HF_TOKEN
)

with patch.object(
Expand Down Expand Up @@ -1049,7 +1051,7 @@ def test_upload(self):

def test_upload_v4(self):
client = Client(
src="gradio-tests/not-actually-private-file-upload-v4", hf_token=HF_TOKEN
src="gradio-tests/not-actually-private-file-uploadv4-sse", hf_token=HF_TOKEN
)
response = MagicMock(status_code=200)
response.json.return_value = [
Expand Down
2 changes: 1 addition & 1 deletion client/python/test/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def test_decode_base64_to_file():

def test_download_private_file(gradio_temp_dir):
url_path = (
"https://gradio-tests-not-actually-private-space-v4.hf.space/file=lion.jpg"
"https://gradio-tests-not-actually-private-spacev4-sse.hf.space/file=lion.jpg"
)
hf_token = "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" # Intentionally revealing this key for testing purposes
file = utils.download_file(
Expand Down
11 changes: 9 additions & 2 deletions gradio/external.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import requests
from gradio_client import Client
from gradio_client import utils as client_utils
from gradio_client.client import Endpoint
from gradio_client.documentation import document, set_documentation_group
from packaging import version

Expand Down Expand Up @@ -522,13 +523,19 @@ def from_spaces(

def from_spaces_blocks(space: str, hf_token: str | None) -> Blocks:
client = Client(space, hf_token=hf_token)
if client.app_version < version.Version("4.0.0"):
if client.app_version < version.Version("4.0.0b14"):
raise GradioVersionIncompatibleError(
f"Gradio version 4.x cannot load spaces with versions less than 4.x ({client.app_version})."
"Please downgrade to version 3 to load this space."
)
# Use end_to_end_fn here to properly upload/download all files
predict_fns = [endpoint.make_end_to_end_fn() for endpoint in client.endpoints]
predict_fns = []
for fn_index, endpoint in enumerate(client.endpoints):
assert isinstance(endpoint, Endpoint)
helper = None
if endpoint.protocol in ("ws", "sse"):
helper = client.new_helper(fn_index)
predict_fns.append(endpoint.make_end_to_end_fn(helper))
return gradio.Blocks.from_config(client.config, predict_fns, client.src)


Expand Down
38 changes: 20 additions & 18 deletions test/test_external.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ def test_english_to_spanish(self):

def test_english_to_spanish_v4(self):
with pytest.warns(UserWarning):
io = gr.load("spaces/gradio-tests/english_to_spanish-v4", title="hi")
io = gr.load("spaces/gradio-tests/english_to_spanishv4-sse", title="hi")
assert isinstance(io.input_components[0], gr.Textbox)
assert isinstance(io.output_components[0], gr.Textbox)

Expand Down Expand Up @@ -217,7 +217,7 @@ def test_raise_incompatbile_version_error(self):
gr.load("spaces/gradio-tests/titanic-survival")

def test_numerical_to_label_space(self):
io = gr.load("spaces/gradio-tests/titanic-survival-v4")
io = gr.load("spaces/gradio-tests/titanic-survivalv4-sse")
try:
assert io.theme.name == "soft"
assert io("male", 77, 10)["label"] == "Perishes"
Expand Down Expand Up @@ -294,7 +294,7 @@ def test_text_to_image_model(self):
def test_private_space(self):
hf_token = "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" # Intentionally revealing this key for testing purposes
io = gr.load(
"spaces/gradio-tests/not-actually-private-space-v4", hf_token=hf_token
"spaces/gradio-tests/not-actually-private-spacev4-sse", hf_token=hf_token
)
try:
output = io("abc")
Expand All @@ -307,7 +307,8 @@ def test_private_space(self):
def test_private_space_audio(self):
hf_token = "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" # Intentionally revealing this key for testing purposes
io = gr.load(
"spaces/gradio-tests/not-actually-private-space-audio-v4", hf_token=hf_token
"spaces/gradio-tests/not-actually-private-space-audiov4-sse",
hf_token=hf_token,
)
try:
output = io(media_data.BASE64_AUDIO["path"])
Expand All @@ -319,23 +320,24 @@ def test_multiple_spaces_one_private(self):
hf_token = "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" # Intentionally revealing this key for testing purposes
with gr.Blocks():
gr.load(
"spaces/gradio-tests/not-actually-private-space-v4", hf_token=hf_token
"spaces/gradio-tests/not-actually-private-spacev4-sse",
hf_token=hf_token,
)
gr.load(
"spaces/gradio/test-loading-examples-v4",
"spaces/gradio/test-loading-examplesv4-sse",
)
assert Context.hf_token == hf_token

def test_loading_files_via_proxy_works(self):
hf_token = "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" # Intentionally revealing this key for testing purposes
io = gr.load(
"spaces/gradio-tests/test-loading-examples-private-v4", hf_token=hf_token
"spaces/gradio-tests/test-loading-examples-privatev4-sse", hf_token=hf_token
)
assert io.theme.name == "default"
app, _, _ = io.launch(prevent_thread_lock=True)
test_client = TestClient(app)
r = test_client.get(
"/proxy=https://gradio-tests-test-loading-examples-private-v4.hf.space/file=Bunny.obj"
"/proxy=https://gradio-tests-test-loading-examples-privatev4-sse.hf.space/file=Bunny.obj"
)
assert r.status_code == 200

Expand All @@ -360,25 +362,25 @@ def test_interface_load_cache_examples(self, tmp_path):
)

def test_root_url(self):
demo = gr.load("spaces/gradio/test-loading-examples-v4")
demo = gr.load("spaces/gradio/test-loading-examplesv4-sse")
assert all(
c["props"]["root_url"]
== "https://gradio-test-loading-examples-v4.hf.space/"
== "https://gradio-test-loading-examplesv4-sse.hf.space/"
for c in demo.get_config_file()["components"]
)

def test_root_url_deserialization(self):
demo = gr.load("spaces/gradio/simple_gallery-v4")
demo = gr.load("spaces/gradio/simple_galleryv4-sse")
gallery = demo("test")
assert all("caption" in d for d in gallery)

def test_interface_with_examples(self):
# This demo has the "fake_event" correctly removed
demo = gr.load("spaces/gradio-tests/test-calculator-1-v4")
demo = gr.load("spaces/gradio-tests/test-calculator-1v4-sse")
assert demo(2, "add", 3) == 5

# This demo still has the "fake_event". both should work
demo = gr.load("spaces/gradio-tests/test-calculator-2-v4")
demo = gr.load("spaces/gradio-tests/test-calculator-2v4-sse")
assert demo(2, "add", 4) == 6


Expand Down Expand Up @@ -450,13 +452,13 @@ def check_dataset(config, readme_examples):

@pytest.mark.xfail
def test_load_blocks_with_default_values():
io = gr.load("spaces/gradio-tests/min-dalle-v4")
io = gr.load("spaces/gradio-tests/min-dallev4-sse")
assert isinstance(io.get_config_file()["components"][0]["props"]["value"], list)

io = gr.load("spaces/gradio-tests/min-dalle-later-v4")
io = gr.load("spaces/gradio-tests/min-dalle-laterv4-sse")
assert isinstance(io.get_config_file()["components"][0]["props"]["value"], list)

io = gr.load("spaces/gradio-tests/dataframe_load-v4")
io = gr.load("spaces/gradio-tests/dataframe_loadv4-sse")
assert io.get_config_file()["components"][0]["props"]["value"] == {
"headers": ["a", "b"],
"data": [[1, 4], [2, 5], [3, 6]],
Expand All @@ -483,14 +485,14 @@ def test_can_load_tabular_model_with_different_widget_data(hypothetical_readme):


def test_raise_value_error_when_api_name_invalid():
demo = gr.load(name="spaces/gradio/hello_world-v4")
demo = gr.load(name="spaces/gradio/hello_worldv4-sse")
with pytest.raises(InvalidApiNameError):
demo("freddy", api_name="route does not exist")


def test_use_api_name_in_call_method():
# Interface
demo = gr.load(name="spaces/gradio/hello_world-v4")
demo = gr.load(name="spaces/gradio/hello_worldv4-sse")
assert demo("freddy", api_name="predict") == "Hello freddy!"

# Blocks demo with multiple functions
Expand Down
8 changes: 4 additions & 4 deletions test/test_mix.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ def test_in_interface(self):
assert series("Hello") == "Hello World!"

def test_with_external(self):
io1 = gr.load("spaces/gradio-tests/image-identity-new-v4")
io2 = gr.load("spaces/gradio-tests/image-classifier-new-v4")
io1 = gr.load("spaces/gradio-tests/image-identity-newv4-sse")
io2 = gr.load("spaces/gradio-tests/image-classifier-newv4-sse")
series = mix.Series(io1, io2)
try:
assert series("gradio/test_data/lion.jpg")["label"] == "lion"
Expand Down Expand Up @@ -45,8 +45,8 @@ def test_multiple_return_in_interface(self):
]

def test_with_external(self):
io1 = gr.load("spaces/gradio-tests/english_to_spanish-v4")
io2 = gr.load("spaces/gradio-tests/english2german-v4")
io1 = gr.load("spaces/gradio-tests/english_to_spanishv4-sse")
io2 = gr.load("spaces/gradio-tests/english2germanv4-sse")
parallel = mix.Parallel(io1, io2)
try:
hello_es, hello_de = parallel("Hello")
Expand Down

0 comments on commit 3bc906e

Please sign in to comment.