diff --git a/.changeset/two-streets-crash.md b/.changeset/two-streets-crash.md new file mode 100644 index 000000000000..93f67459f1ac --- /dev/null +++ b/.changeset/two-streets-crash.md @@ -0,0 +1,6 @@ +--- +"gradio": patch +"gradio_client": patch +--- + +fix:Fix: Gradio Client work with private Spaces diff --git a/client/python/gradio_client/client.py b/client/python/gradio_client/client.py index dc08d7d124f1..219030299d76 100644 --- a/client/python/gradio_client/client.py +++ b/client/python/gradio_client/client.py @@ -1077,9 +1077,10 @@ async def _sse_fn(self, data: dict, hash_data: dict, helper: Communicator): data, hash_data, helper, - self.client.sse_url, - self.client.sse_data_url, - self.client.cookies, + sse_url=self.client.sse_url, + sse_data_url=self.client.sse_data_url, + headers=self.client.headers, + cookies=self.client.cookies, ) diff --git a/client/python/gradio_client/utils.py b/client/python/gradio_client/utils.py index ad2c0ef8cbe9..7ad7ea356518 100644 --- a/client/python/gradio_client/utils.py +++ b/client/python/gradio_client/utils.py @@ -315,6 +315,7 @@ async def get_pred_from_sse( helper: Communicator, sse_url: str, sse_data_url: str, + headers: dict[str, str], cookies: dict[str, str] | None = None, ) -> dict[str, Any] | None: done, pending = await asyncio.wait( @@ -322,7 +323,14 @@ async def get_pred_from_sse( asyncio.create_task(check_for_cancel(helper, cookies)), asyncio.create_task( stream_sse( - client, data, hash_data, helper, sse_url, sse_data_url, cookies + client, + data, + hash_data, + helper, + sse_url, + sse_data_url, + headers=headers, + cookies=cookies, ) ), ], @@ -362,11 +370,16 @@ async def stream_sse( helper: Communicator, sse_url: str, sse_data_url: str, + headers: dict[str, str], cookies: dict[str, str] | None = None, ) -> dict[str, Any]: try: async with client.stream( - "GET", sse_url, params=hash_data, cookies=cookies + "GET", + sse_url, + params=hash_data, + cookies=cookies, + headers=headers, ) as response: async for line in response.aiter_text(): if line.startswith("data:"): @@ -402,6 +415,7 @@ async def stream_sse( sse_data_url, json={"event_id": event_id, **data, **hash_data}, cookies=cookies, + headers=headers, ) req.raise_for_status() elif resp["msg"] == "process_completed": diff --git a/client/python/test/test_client.py b/client/python/test/test_client.py index 894a6c654881..eded23737e76 100644 --- a/client/python/test/test_client.py +++ b/client/python/test/test_client.py @@ -75,15 +75,19 @@ def test_numerical_to_label_space_v4(self): @pytest.mark.flaky def test_private_space(self): - client = Client("gradio-tests/not-actually-private-space", hf_token=HF_TOKEN) + space_id = "gradio-tests/not-actually-private-space" + api = huggingface_hub.HfApi(token=HF_TOKEN) + assert api.space_info(space_id).private + client = Client(space_id, hf_token=HF_TOKEN) output = client.predict("abc", api_name="/predict") assert output == "abc" @pytest.mark.flaky def test_private_space_v4(self): - client = Client( - "gradio-tests/not-actually-private-spacev4-sse", hf_token=HF_TOKEN - ) + space_id = "gradio-tests/not-actually-private-spacev4-sse" + api = huggingface_hub.HfApi(token=HF_TOKEN) + assert api.space_info(space_id).private + client = Client(space_id, hf_token=HF_TOKEN) output = client.predict("abc", api_name="/predict") assert output == "abc"