Skip to content

Commit

Permalink
Fix: Gradio Client work with private Spaces (#6602)
Browse files Browse the repository at this point in the history
* client with private space

* add changeset

* lint

* add test

---------

Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
  • Loading branch information
abidlabs and gradio-pr-bot committed Nov 30, 2023
1 parent 4f040c7 commit b8034a1
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 9 deletions.
6 changes: 6 additions & 0 deletions .changeset/two-streets-crash.md
@@ -0,0 +1,6 @@
---
"gradio": patch
"gradio_client": patch
---

fix:Fix: Gradio Client work with private Spaces
7 changes: 4 additions & 3 deletions client/python/gradio_client/client.py
Expand Up @@ -1077,9 +1077,10 @@ async def _sse_fn(self, data: dict, hash_data: dict, helper: Communicator):
data,
hash_data,
helper,
self.client.sse_url,
self.client.sse_data_url,
self.client.cookies,
sse_url=self.client.sse_url,
sse_data_url=self.client.sse_data_url,
headers=self.client.headers,
cookies=self.client.cookies,
)


Expand Down
18 changes: 16 additions & 2 deletions client/python/gradio_client/utils.py
Expand Up @@ -315,14 +315,22 @@ async def get_pred_from_sse(
helper: Communicator,
sse_url: str,
sse_data_url: str,
headers: dict[str, str],
cookies: dict[str, str] | None = None,
) -> dict[str, Any] | None:
done, pending = await asyncio.wait(
[
asyncio.create_task(check_for_cancel(helper, cookies)),
asyncio.create_task(
stream_sse(
client, data, hash_data, helper, sse_url, sse_data_url, cookies
client,
data,
hash_data,
helper,
sse_url,
sse_data_url,
headers=headers,
cookies=cookies,
)
),
],
Expand Down Expand Up @@ -362,11 +370,16 @@ async def stream_sse(
helper: Communicator,
sse_url: str,
sse_data_url: str,
headers: dict[str, str],
cookies: dict[str, str] | None = None,
) -> dict[str, Any]:
try:
async with client.stream(
"GET", sse_url, params=hash_data, cookies=cookies
"GET",
sse_url,
params=hash_data,
cookies=cookies,
headers=headers,
) as response:
async for line in response.aiter_text():
if line.startswith("data:"):
Expand Down Expand Up @@ -402,6 +415,7 @@ async def stream_sse(
sse_data_url,
json={"event_id": event_id, **data, **hash_data},
cookies=cookies,
headers=headers,
)
req.raise_for_status()
elif resp["msg"] == "process_completed":
Expand Down
12 changes: 8 additions & 4 deletions client/python/test/test_client.py
Expand Up @@ -75,15 +75,19 @@ def test_numerical_to_label_space_v4(self):

@pytest.mark.flaky
def test_private_space(self):
client = Client("gradio-tests/not-actually-private-space", hf_token=HF_TOKEN)
space_id = "gradio-tests/not-actually-private-space"
api = huggingface_hub.HfApi(token=HF_TOKEN)
assert api.space_info(space_id).private
client = Client(space_id, hf_token=HF_TOKEN)
output = client.predict("abc", api_name="/predict")
assert output == "abc"

@pytest.mark.flaky
def test_private_space_v4(self):
client = Client(
"gradio-tests/not-actually-private-spacev4-sse", hf_token=HF_TOKEN
)
space_id = "gradio-tests/not-actually-private-spacev4-sse"
api = huggingface_hub.HfApi(token=HF_TOKEN)
assert api.space_info(space_id).private
client = Client(space_id, hf_token=HF_TOKEN)
output = client.predict("abc", api_name="/predict")
assert output == "abc"

Expand Down

0 comments on commit b8034a1

Please sign in to comment.