diff --git a/.changeset/kind-eyes-shake.md b/.changeset/kind-eyes-shake.md new file mode 100644 index 000000000000..680986de059f --- /dev/null +++ b/.changeset/kind-eyes-shake.md @@ -0,0 +1,6 @@ +--- +"gradio": patch +"gradio_client": patch +--- + +fix:Add support for python client connecting to gradio apps running with self-signed SSL certificates diff --git a/client/python/gradio_client/client.py b/client/python/gradio_client/client.py index aee2bc45b3e9..554ca12df967 100644 --- a/client/python/gradio_client/client.py +++ b/client/python/gradio_client/client.py @@ -2,10 +2,12 @@ from __future__ import annotations import concurrent.futures +import hashlib import json import os import re import secrets +import shutil import tempfile import threading import time @@ -81,6 +83,7 @@ def __init__( upload_files: bool = True, # TODO: remove and hardcode to False in 1.0 download_files: bool = True, # TODO: consider setting to False in 1.0 _skip_components: bool = True, # internal parameter to skip values certain components (e.g. State) that do not need to be displayed to users. + ssl_verify: bool = True, ): """ Parameters: @@ -93,6 +96,7 @@ def __init__( headers: Additional headers to send to the remote Gradio app on every request. By default only the HF authorization and user-agent headers are sent. These headers will override the default headers if they have the same keys. upload_files: Whether the client should treat input string filepath as files and upload them to the remote server. If False, the client will treat input string filepaths as strings always and not modify them, and files should be passed in explicitly using `gradio_client.file("path/to/file/or/url")` instead. This parameter will be deleted and False will become the default in a future version. download_files: Whether the client should download output files from the remote API and return them as string filepaths on the local machine. If False, the client will return a FileData dataclass object with the filepath on the remote machine instead. + ssl_verify: If False, skips certificate validation which allows the client to connect to Gradio apps that are using self-signed certificates. """ self.verbose = verbose self.hf_token = hf_token @@ -111,6 +115,7 @@ def __init__( ) if headers: self.headers.update(headers) + self.ssl_verify = ssl_verify self.space_id = None self.cookies: dict[str, str] = {} self.output_dir = ( @@ -187,7 +192,9 @@ def __init__( async def stream_messages(self) -> None: try: - async with httpx.AsyncClient(timeout=httpx.Timeout(timeout=None)) as client: + async with httpx.AsyncClient( + timeout=httpx.Timeout(timeout=None), verify=self.ssl_verify + ) as client: async with client.stream( "GET", self.sse_url, @@ -227,7 +234,7 @@ async def stream_messages(self) -> None: raise e async def send_data(self, data, hash_data): - async with httpx.AsyncClient() as client: + async with httpx.AsyncClient(verify=self.ssl_verify) as client: req = await client.post( self.sse_data_url, json={**data, **hash_data}, @@ -484,7 +491,12 @@ def _get_api_info(self): else: api_info_url = urllib.parse.urljoin(self.src, utils.RAW_API_INFO_URL) if self.app_version > version.Version("3.36.1"): - r = httpx.get(api_info_url, headers=self.headers, cookies=self.cookies) + r = httpx.get( + api_info_url, + headers=self.headers, + cookies=self.cookies, + verify=self.ssl_verify, + ) if r.is_success: info = r.json() else: @@ -735,6 +747,7 @@ def _login(self, auth: tuple[str, str]): resp = httpx.post( urllib.parse.urljoin(self.src, utils.LOGIN_URL), data={"username": auth[0], "password": auth[1]}, + verify=self.ssl_verify, ) if not resp.is_success: if resp.status_code == 401: @@ -752,6 +765,7 @@ def _get_config(self) -> dict: urllib.parse.urljoin(self.src, utils.CONFIG_URL), headers=self.headers, cookies=self.cookies, + verify=self.ssl_verify, ) if r.is_success: return r.json() @@ -760,7 +774,12 @@ def _get_config(self) -> dict: f"Could not load {self.src} as credentials were not provided. Please login." ) else: # to support older versions of Gradio - r = httpx.get(self.src, headers=self.headers, cookies=self.cookies) + r = httpx.get( + self.src, + headers=self.headers, + cookies=self.cookies, + verify=self.ssl_verify, + ) if not r.is_success: raise ValueError(f"Could not fetch config for {self.src}") # some basic regex to extract the config @@ -1126,7 +1145,7 @@ def reduce_singleton_output(self, *data) -> Any: else: return data - def _upload_file(self, f: str | dict): + def _upload_file(self, f: str | dict) -> dict[str, str]: if isinstance(f, str): warnings.warn( f'The Client is treating: "{f}" as a file path. In future versions, this behavior will not happen automatically. ' @@ -1137,24 +1156,53 @@ def _upload_file(self, f: str | dict): else: file_path = f["path"] if not utils.is_http_url_like(file_path): - file_path = utils.upload_file( - file_path=file_path, - upload_url=self.client.upload_url, - headers=self.client.headers, - cookies=self.client.cookies, - ) + with open(file_path, "rb") as f: + files = [("files", (Path(file_path).name, f))] + r = httpx.post( + self.client.upload_url, + headers=self.client.headers, + cookies=self.client.cookies, + verify=self.client.ssl_verify, + files=files, + ) + r.raise_for_status() + result = r.json() + file_path = result[0] return {"path": file_path} - def _download_file(self, x: dict) -> str | None: - return utils.download_file( - self.root_url + "file=" + x["path"], - save_dir=self.client.output_dir, + def _download_file(self, x: dict) -> str: + url_path = self.root_url + "file=" + x["path"] + if self.client.output_dir is not None: + os.makedirs(self.client.output_dir, exist_ok=True) + + sha1 = hashlib.sha1() + temp_dir = Path(tempfile.gettempdir()) / secrets.token_hex(20) + temp_dir.mkdir(exist_ok=True, parents=True) + + with httpx.stream( + "GET", + url_path, headers=self.client.headers, cookies=self.client.cookies, - ) + verify=self.client.ssl_verify, + follow_redirects=True, + ) as response: + response.raise_for_status() + with open(temp_dir / Path(url_path).name, "wb") as f: + for chunk in response.iter_bytes(chunk_size=128 * sha1.block_size): + sha1.update(chunk) + f.write(chunk) + + directory = Path(self.client.output_dir) / sha1.hexdigest() + directory.mkdir(exist_ok=True, parents=True) + dest = directory / Path(url_path).name + shutil.move(temp_dir / Path(url_path).name, dest) + return str(dest.resolve()) async def _sse_fn_v0(self, data: dict, hash_data: dict, helper: Communicator): - async with httpx.AsyncClient(timeout=httpx.Timeout(timeout=None)) as client: + async with httpx.AsyncClient( + timeout=httpx.Timeout(timeout=None), verify=self.client.ssl_verify + ) as client: return await utils.get_pred_from_sse_v0( client, data, @@ -1164,6 +1212,7 @@ async def _sse_fn_v0(self, data: dict, hash_data: dict, helper: Communicator): self.client.sse_data_url, self.client.headers, self.client.cookies, + self.client.ssl_verify, ) async def _sse_fn_v1_v2( @@ -1179,6 +1228,7 @@ async def _sse_fn_v1_v2( self.client.pending_messages_per_event, event_id, protocol, + self.client.ssl_verify, ) diff --git a/client/python/gradio_client/compatibility.py b/client/python/gradio_client/compatibility.py index 143e4a997666..71d8dfa65049 100644 --- a/client/python/gradio_client/compatibility.py +++ b/client/python/gradio_client/compatibility.py @@ -95,7 +95,10 @@ def _predict(*data) -> tuple: raise ValueError(result["error"]) else: response = httpx.post( - self.client.api_url, headers=self.client.headers, json=data + self.client.api_url, + headers=self.client.headers, + json=data, + verify=self.client.ssl_verify, ) result = json.loads(response.content.decode("utf-8")) try: @@ -144,7 +147,12 @@ def _upload( for f in fs: files.append(("files", (Path(f).name, open(f, "rb")))) # noqa: SIM115 indices.append(i) - r = httpx.post(self.client.upload_url, headers=self.client.headers, files=files) + r = httpx.post( + self.client.upload_url, + headers=self.client.headers, + files=files, + verify=self.client.ssl_verify, + ) if r.status_code != 200: uploaded = file_paths else: diff --git a/client/python/gradio_client/utils.py b/client/python/gradio_client/utils.py index 2e5c07548e1e..eb207e595c68 100644 --- a/client/python/gradio_client/utils.py +++ b/client/python/gradio_client/utils.py @@ -3,7 +3,6 @@ import asyncio import base64 import copy -import hashlib import json import mimetypes import os @@ -353,10 +352,11 @@ async def get_pred_from_sse_v0( sse_data_url: str, headers: dict[str, str], cookies: dict[str, str] | None, + ssl_verify: bool, ) -> dict[str, Any] | None: done, pending = await asyncio.wait( [ - asyncio.create_task(check_for_cancel(helper, headers, cookies)), + asyncio.create_task(check_for_cancel(helper, headers, cookies, ssl_verify)), asyncio.create_task( stream_sse_v0( client, @@ -393,10 +393,11 @@ async def get_pred_from_sse_v1_v2( pending_messages_per_event: dict[str, list[Message | None]], event_id: str, protocol: Literal["sse_v1", "sse_v2", "sse_v2.1"], + ssl_verify: bool, ) -> dict[str, Any] | None: done, pending = await asyncio.wait( [ - asyncio.create_task(check_for_cancel(helper, headers, cookies)), + asyncio.create_task(check_for_cancel(helper, headers, cookies, ssl_verify)), asyncio.create_task( stream_sse_v1_v2(helper, pending_messages_per_event, event_id, protocol) ), @@ -421,7 +422,10 @@ async def get_pred_from_sse_v1_v2( async def check_for_cancel( - helper: Communicator, headers: dict[str, str], cookies: dict[str, str] | None + helper: Communicator, + headers: dict[str, str], + cookies: dict[str, str] | None, + ssl_verify: bool, ): while True: await asyncio.sleep(0.05) @@ -429,7 +433,7 @@ async def check_for_cancel( if helper.should_cancel: break if helper.event_id: - async with httpx.AsyncClient() as http: + async with httpx.AsyncClient(ssl_verify=ssl_verify) as http: await http.post( helper.reset_url, json={"event_id": helper.event_id}, @@ -625,49 +629,6 @@ def apply_edit(target, path, action, value): ######################## -def upload_file( - file_path: str, - upload_url: str, - headers: dict[str, str] | None = None, - cookies: dict[str, str] | None = None, -): - with open(file_path, "rb") as f: - files = [("files", (Path(file_path).name, f))] - r = httpx.post(upload_url, headers=headers, cookies=cookies, files=files) - r.raise_for_status() - result = r.json() - return result[0] - - -def download_file( - url_path: str, - save_dir: str, - headers: dict[str, str] | None = None, - cookies: dict[str, str] | None = None, -) -> str: - if save_dir is not None: - os.makedirs(save_dir, exist_ok=True) - - sha1 = hashlib.sha1() - temp_dir = Path(tempfile.gettempdir()) / secrets.token_hex(20) - temp_dir.mkdir(exist_ok=True, parents=True) - - with httpx.stream( - "GET", url_path, headers=headers, cookies=cookies, follow_redirects=True - ) as response: - response.raise_for_status() - with open(temp_dir / Path(url_path).name, "wb") as f: - for chunk in response.iter_bytes(chunk_size=128 * sha1.block_size): - sha1.update(chunk) - f.write(chunk) - - directory = Path(save_dir) / sha1.hexdigest() - directory.mkdir(exist_ok=True, parents=True) - dest = directory / Path(url_path).name - shutil.move(temp_dir / Path(url_path).name, dest) - return str(dest.resolve()) - - def create_tmp_copy_of_file(file_path: str, dir: str | None = None) -> str: directory = Path(dir or tempfile.gettempdir()) / secrets.token_hex(20) directory.mkdir(exist_ok=True, parents=True) diff --git a/client/python/test/test_client.py b/client/python/test/test_client.py index 7cb8dcfa26c8..e1b0fad97b7f 100644 --- a/client/python/test/test_client.py +++ b/client/python/test/test_client.py @@ -11,6 +11,7 @@ from unittest.mock import MagicMock, patch import gradio as gr +import httpx import huggingface_hub import pytest import uvicorn @@ -1171,6 +1172,27 @@ def test_upload(self): "file7", ] + @pytest.mark.flaky + def test_download_private_file(self, gradio_temp_dir): + client = Client( + src="gradio/zip_files", + ) + url_path = "https://gradio-tests-not-actually-private-spacev4-sse.hf.space/file=lion.jpg" + file = client.endpoints[0]._upload_file(url_path) # type: ignore + assert file["path"].endswith(".jpg") + + @pytest.mark.flaky + def test_download_tmp_copy_of_file_does_not_save_errors( + self, monkeypatch, gradio_temp_dir + ): + client = Client( + src="gradio/zip_files", + ) + error_response = httpx.Response(status_code=404) + monkeypatch.setattr(httpx, "get", lambda *args, **kwargs: error_response) + with pytest.raises(httpx.HTTPStatusError): + client.endpoints[0]._download_file({"path": "https://example.com/foo"}) # type: ignore + cpu = huggingface_hub.SpaceHardware.CPU_BASIC diff --git a/client/python/test/test_utils.py b/client/python/test/test_utils.py index b8bca5732f53..529371636ddb 100644 --- a/client/python/test/test_utils.py +++ b/client/python/test/test_utils.py @@ -71,26 +71,6 @@ def test_decode_base64_to_file(): assert isinstance(temp_file, tempfile._TemporaryFileWrapper) -@pytest.mark.flaky -def test_download_private_file(gradio_temp_dir): - url_path = ( - "https://gradio-tests-not-actually-private-spacev4-sse.hf.space/file=lion.jpg" - ) - file = utils.download_file( - url_path=url_path, - headers={"Authorization": f"Bearer {HF_TOKEN}"}, - save_dir=str(gradio_temp_dir), - ) - assert Path(file).name.endswith(".jpg") - - -def test_download_tmp_copy_of_file_does_not_save_errors(monkeypatch, gradio_temp_dir): - error_response = httpx.Response(status_code=404) - monkeypatch.setattr(httpx, "get", lambda *args, **kwargs: error_response) - with pytest.raises(httpx.HTTPStatusError): - utils.download_file("https://example.com/foo", save_dir=str(gradio_temp_dir)) - - @pytest.mark.parametrize( "orig_filename, new_filename", [