Skip to content

Commit

Permalink
Add support for python client connecting to gradio apps running with …
Browse files Browse the repository at this point in the history
…self-signed SSL certificates (#7718)

* verify

* add changeset

* docstring

* add changeset

* test fixes

* add remaining

* test fixes

* changes

---------

Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
  • Loading branch information
abidlabs and gradio-pr-bot committed Mar 15, 2024
1 parent 188b86b commit 6390d0b
Show file tree
Hide file tree
Showing 6 changed files with 114 additions and 87 deletions.
6 changes: 6 additions & 0 deletions .changeset/kind-eyes-shake.md
@@ -0,0 +1,6 @@
---
"gradio": patch
"gradio_client": patch
---

fix:Add support for python client connecting to gradio apps running with self-signed SSL certificates
84 changes: 67 additions & 17 deletions client/python/gradio_client/client.py
Expand Up @@ -2,10 +2,12 @@
from __future__ import annotations

import concurrent.futures
import hashlib
import json
import os
import re
import secrets
import shutil
import tempfile
import threading
import time
Expand Down Expand Up @@ -81,6 +83,7 @@ def __init__(
upload_files: bool = True, # TODO: remove and hardcode to False in 1.0
download_files: bool = True, # TODO: consider setting to False in 1.0
_skip_components: bool = True, # internal parameter to skip values certain components (e.g. State) that do not need to be displayed to users.
ssl_verify: bool = True,
):
"""
Parameters:
Expand All @@ -93,6 +96,7 @@ def __init__(
headers: Additional headers to send to the remote Gradio app on every request. By default only the HF authorization and user-agent headers are sent. These headers will override the default headers if they have the same keys.
upload_files: Whether the client should treat input string filepath as files and upload them to the remote server. If False, the client will treat input string filepaths as strings always and not modify them, and files should be passed in explicitly using `gradio_client.file("path/to/file/or/url")` instead. This parameter will be deleted and False will become the default in a future version.
download_files: Whether the client should download output files from the remote API and return them as string filepaths on the local machine. If False, the client will return a FileData dataclass object with the filepath on the remote machine instead.
ssl_verify: If False, skips certificate validation which allows the client to connect to Gradio apps that are using self-signed certificates.
"""
self.verbose = verbose
self.hf_token = hf_token
Expand All @@ -111,6 +115,7 @@ def __init__(
)
if headers:
self.headers.update(headers)
self.ssl_verify = ssl_verify
self.space_id = None
self.cookies: dict[str, str] = {}
self.output_dir = (
Expand Down Expand Up @@ -187,7 +192,9 @@ def __init__(

async def stream_messages(self) -> None:
try:
async with httpx.AsyncClient(timeout=httpx.Timeout(timeout=None)) as client:
async with httpx.AsyncClient(
timeout=httpx.Timeout(timeout=None), verify=self.ssl_verify
) as client:
async with client.stream(
"GET",
self.sse_url,
Expand Down Expand Up @@ -227,7 +234,7 @@ async def stream_messages(self) -> None:
raise e

async def send_data(self, data, hash_data):
async with httpx.AsyncClient() as client:
async with httpx.AsyncClient(verify=self.ssl_verify) as client:
req = await client.post(
self.sse_data_url,
json={**data, **hash_data},
Expand Down Expand Up @@ -484,7 +491,12 @@ def _get_api_info(self):
else:
api_info_url = urllib.parse.urljoin(self.src, utils.RAW_API_INFO_URL)
if self.app_version > version.Version("3.36.1"):
r = httpx.get(api_info_url, headers=self.headers, cookies=self.cookies)
r = httpx.get(
api_info_url,
headers=self.headers,
cookies=self.cookies,
verify=self.ssl_verify,
)
if r.is_success:
info = r.json()
else:
Expand Down Expand Up @@ -735,6 +747,7 @@ def _login(self, auth: tuple[str, str]):
resp = httpx.post(
urllib.parse.urljoin(self.src, utils.LOGIN_URL),
data={"username": auth[0], "password": auth[1]},
verify=self.ssl_verify,
)
if not resp.is_success:
if resp.status_code == 401:
Expand All @@ -752,6 +765,7 @@ def _get_config(self) -> dict:
urllib.parse.urljoin(self.src, utils.CONFIG_URL),
headers=self.headers,
cookies=self.cookies,
verify=self.ssl_verify,
)
if r.is_success:
return r.json()
Expand All @@ -760,7 +774,12 @@ def _get_config(self) -> dict:
f"Could not load {self.src} as credentials were not provided. Please login."
)
else: # to support older versions of Gradio
r = httpx.get(self.src, headers=self.headers, cookies=self.cookies)
r = httpx.get(
self.src,
headers=self.headers,
cookies=self.cookies,
verify=self.ssl_verify,
)
if not r.is_success:
raise ValueError(f"Could not fetch config for {self.src}")
# some basic regex to extract the config
Expand Down Expand Up @@ -1126,7 +1145,7 @@ def reduce_singleton_output(self, *data) -> Any:
else:
return data

def _upload_file(self, f: str | dict):
def _upload_file(self, f: str | dict) -> dict[str, str]:
if isinstance(f, str):
warnings.warn(
f'The Client is treating: "{f}" as a file path. In future versions, this behavior will not happen automatically. '
Expand All @@ -1137,24 +1156,53 @@ def _upload_file(self, f: str | dict):
else:
file_path = f["path"]
if not utils.is_http_url_like(file_path):
file_path = utils.upload_file(
file_path=file_path,
upload_url=self.client.upload_url,
headers=self.client.headers,
cookies=self.client.cookies,
)
with open(file_path, "rb") as f:
files = [("files", (Path(file_path).name, f))]
r = httpx.post(
self.client.upload_url,
headers=self.client.headers,
cookies=self.client.cookies,
verify=self.client.ssl_verify,
files=files,
)
r.raise_for_status()
result = r.json()
file_path = result[0]
return {"path": file_path}

def _download_file(self, x: dict) -> str | None:
return utils.download_file(
self.root_url + "file=" + x["path"],
save_dir=self.client.output_dir,
def _download_file(self, x: dict) -> str:
url_path = self.root_url + "file=" + x["path"]
if self.client.output_dir is not None:
os.makedirs(self.client.output_dir, exist_ok=True)

sha1 = hashlib.sha1()
temp_dir = Path(tempfile.gettempdir()) / secrets.token_hex(20)
temp_dir.mkdir(exist_ok=True, parents=True)

with httpx.stream(
"GET",
url_path,
headers=self.client.headers,
cookies=self.client.cookies,
)
verify=self.client.ssl_verify,
follow_redirects=True,
) as response:
response.raise_for_status()
with open(temp_dir / Path(url_path).name, "wb") as f:
for chunk in response.iter_bytes(chunk_size=128 * sha1.block_size):
sha1.update(chunk)
f.write(chunk)

directory = Path(self.client.output_dir) / sha1.hexdigest()
directory.mkdir(exist_ok=True, parents=True)
dest = directory / Path(url_path).name
shutil.move(temp_dir / Path(url_path).name, dest)
return str(dest.resolve())

async def _sse_fn_v0(self, data: dict, hash_data: dict, helper: Communicator):
async with httpx.AsyncClient(timeout=httpx.Timeout(timeout=None)) as client:
async with httpx.AsyncClient(
timeout=httpx.Timeout(timeout=None), verify=self.client.ssl_verify
) as client:
return await utils.get_pred_from_sse_v0(
client,
data,
Expand All @@ -1164,6 +1212,7 @@ async def _sse_fn_v0(self, data: dict, hash_data: dict, helper: Communicator):
self.client.sse_data_url,
self.client.headers,
self.client.cookies,
self.client.ssl_verify,
)

async def _sse_fn_v1_v2(
Expand All @@ -1179,6 +1228,7 @@ async def _sse_fn_v1_v2(
self.client.pending_messages_per_event,
event_id,
protocol,
self.client.ssl_verify,
)


Expand Down
12 changes: 10 additions & 2 deletions client/python/gradio_client/compatibility.py
Expand Up @@ -95,7 +95,10 @@ def _predict(*data) -> tuple:
raise ValueError(result["error"])
else:
response = httpx.post(
self.client.api_url, headers=self.client.headers, json=data
self.client.api_url,
headers=self.client.headers,
json=data,
verify=self.client.ssl_verify,
)
result = json.loads(response.content.decode("utf-8"))
try:
Expand Down Expand Up @@ -144,7 +147,12 @@ def _upload(
for f in fs:
files.append(("files", (Path(f).name, open(f, "rb")))) # noqa: SIM115
indices.append(i)
r = httpx.post(self.client.upload_url, headers=self.client.headers, files=files)
r = httpx.post(
self.client.upload_url,
headers=self.client.headers,
files=files,
verify=self.client.ssl_verify,
)
if r.status_code != 200:
uploaded = file_paths
else:
Expand Down
57 changes: 9 additions & 48 deletions client/python/gradio_client/utils.py
Expand Up @@ -3,7 +3,6 @@
import asyncio
import base64
import copy
import hashlib
import json
import mimetypes
import os
Expand Down Expand Up @@ -353,10 +352,11 @@ async def get_pred_from_sse_v0(
sse_data_url: str,
headers: dict[str, str],
cookies: dict[str, str] | None,
ssl_verify: bool,
) -> dict[str, Any] | None:
done, pending = await asyncio.wait(
[
asyncio.create_task(check_for_cancel(helper, headers, cookies)),
asyncio.create_task(check_for_cancel(helper, headers, cookies, ssl_verify)),
asyncio.create_task(
stream_sse_v0(
client,
Expand Down Expand Up @@ -393,10 +393,11 @@ async def get_pred_from_sse_v1_v2(
pending_messages_per_event: dict[str, list[Message | None]],
event_id: str,
protocol: Literal["sse_v1", "sse_v2", "sse_v2.1"],
ssl_verify: bool,
) -> dict[str, Any] | None:
done, pending = await asyncio.wait(
[
asyncio.create_task(check_for_cancel(helper, headers, cookies)),
asyncio.create_task(check_for_cancel(helper, headers, cookies, ssl_verify)),
asyncio.create_task(
stream_sse_v1_v2(helper, pending_messages_per_event, event_id, protocol)
),
Expand All @@ -421,15 +422,18 @@ async def get_pred_from_sse_v1_v2(


async def check_for_cancel(
helper: Communicator, headers: dict[str, str], cookies: dict[str, str] | None
helper: Communicator,
headers: dict[str, str],
cookies: dict[str, str] | None,
ssl_verify: bool,
):
while True:
await asyncio.sleep(0.05)
with helper.lock:
if helper.should_cancel:
break
if helper.event_id:
async with httpx.AsyncClient() as http:
async with httpx.AsyncClient(ssl_verify=ssl_verify) as http:
await http.post(
helper.reset_url,
json={"event_id": helper.event_id},
Expand Down Expand Up @@ -625,49 +629,6 @@ def apply_edit(target, path, action, value):
########################


def upload_file(
    file_path: str,
    upload_url: str,
    headers: dict[str, str] | None = None,
    cookies: dict[str, str] | None = None,
):
    """POST a local file to ``upload_url`` and return the first item of the JSON reply.

    Parameters:
        file_path: Path of the local file to send; it is opened in binary mode.
        upload_url: URL that receives the multipart upload.
        headers: Optional headers sent with the request.
        cookies: Optional cookies sent with the request.
    Returns:
        Element 0 of the decoded JSON response body.
    """
    with open(file_path, "rb") as stream:
        # The multipart field is named "files" and carries only the base name.
        payload = [("files", (Path(file_path).name, stream))]
        response = httpx.post(
            upload_url,
            headers=headers,
            cookies=cookies,
            files=payload,
        )
    response.raise_for_status()
    return response.json()[0]


def download_file(
    url_path: str,
    save_dir: str,
    headers: dict[str, str] | None = None,
    cookies: dict[str, str] | None = None,
) -> str:
    """Stream a remote file to disk and return its resolved local path.

    The body is first written to a scratch directory while its SHA-1 digest is
    computed, then moved to ``<save_dir>/<sha1 hexdigest>/<file name>``, where
    the file name is the last path component of ``url_path``.

    Parameters:
        url_path: URL to fetch; redirects are followed.
        save_dir: Directory under which the digest-named subdirectory is created.
        headers: Optional headers sent with the request.
        cookies: Optional cookies sent with the request.
    Returns:
        The resolved path of the downloaded file, as a string.
    Raises:
        httpx.HTTPStatusError: If ``raise_for_status()`` rejects the response.
    """
    if save_dir is not None:
        os.makedirs(save_dir, exist_ok=True)

    filename = Path(url_path).name
    sha1 = hashlib.sha1()

    # TemporaryDirectory removes the scratch directory on exit, so neither a
    # completed move nor a failed download leaves anything behind in the
    # system temp dir.
    with tempfile.TemporaryDirectory() as scratch:
        partial = Path(scratch) / filename
        with httpx.stream(
            "GET", url_path, headers=headers, cookies=cookies, follow_redirects=True
        ) as response:
            response.raise_for_status()
            with open(partial, "wb") as f:
                # Hash and write chunk by chunk so the whole body is never held in memory.
                for chunk in response.iter_bytes(chunk_size=128 * sha1.block_size):
                    sha1.update(chunk)
                    f.write(chunk)

        directory = Path(save_dir) / sha1.hexdigest()
        directory.mkdir(exist_ok=True, parents=True)
        dest = directory / filename
        shutil.move(partial, dest)
    return str(dest.resolve())


def create_tmp_copy_of_file(file_path: str, dir: str | None = None) -> str:
directory = Path(dir or tempfile.gettempdir()) / secrets.token_hex(20)
directory.mkdir(exist_ok=True, parents=True)
Expand Down
22 changes: 22 additions & 0 deletions client/python/test/test_client.py
Expand Up @@ -11,6 +11,7 @@
from unittest.mock import MagicMock, patch

import gradio as gr
import httpx
import huggingface_hub
import pytest
import uvicorn
Expand Down Expand Up @@ -1171,6 +1172,27 @@ def test_upload(self):
"file7",
]

@pytest.mark.flaky
def test_download_private_file(self, gradio_temp_dir):
    """An HTTP URL string handed to `_upload_file` comes back with a ".jpg" path.

    NOTE(review): despite the test name, this calls `_upload_file`, not
    `_download_file` -- confirm that this is the intended coverage.
    """
    remote_image = "https://gradio-tests-not-actually-private-spacev4-sse.hf.space/file=lion.jpg"
    endpoint = Client(src="gradio/zip_files").endpoints[0]
    uploaded = endpoint._upload_file(remote_image)  # type: ignore
    assert uploaded["path"].endswith(".jpg")

@pytest.mark.flaky
def test_download_tmp_copy_of_file_does_not_save_errors(
    self, monkeypatch, gradio_temp_dir
):
    """`_download_file` must raise when the server answers with an error status."""
    from contextlib import contextmanager

    client = Client(
        src="gradio/zip_files",
    )
    # raise_for_status() needs a request attached to the response.
    error_response = httpx.Response(
        status_code=404, request=httpx.Request("GET", "https://example.com/foo")
    )

    @contextmanager
    def fake_stream(*args, **kwargs):
        yield error_response

    # `_download_file` fetches through `httpx.stream`; stubbing `httpx.get`
    # would leave the real network request in place.
    monkeypatch.setattr(httpx, "stream", fake_stream)
    with pytest.raises(httpx.HTTPStatusError):
        client.endpoints[0]._download_file({"path": "https://example.com/foo"})  # type: ignore


cpu = huggingface_hub.SpaceHardware.CPU_BASIC

Expand Down
20 changes: 0 additions & 20 deletions client/python/test/test_utils.py
Expand Up @@ -71,26 +71,6 @@ def test_decode_base64_to_file():
assert isinstance(temp_file, tempfile._TemporaryFileWrapper)


@pytest.mark.flaky
def test_download_private_file(gradio_temp_dir):
    """Downloading the lion image with a bearer token yields a local ".jpg" file."""
    auth_headers = {"Authorization": f"Bearer {HF_TOKEN}"}
    downloaded = utils.download_file(
        url_path="https://gradio-tests-not-actually-private-spacev4-sse.hf.space/file=lion.jpg",
        headers=auth_headers,
        save_dir=str(gradio_temp_dir),
    )
    assert Path(downloaded).name.endswith(".jpg")


def test_download_tmp_copy_of_file_does_not_save_errors(monkeypatch, gradio_temp_dir):
    """`utils.download_file` must raise when the server answers with an error status."""
    from contextlib import contextmanager

    url = "https://example.com/foo"
    # raise_for_status() needs a request attached to the response.
    error_response = httpx.Response(status_code=404, request=httpx.Request("GET", url))

    @contextmanager
    def fake_stream(*args, **kwargs):
        yield error_response

    # `download_file` fetches through `httpx.stream`; stubbing `httpx.get`
    # would leave the real network request in place.
    monkeypatch.setattr(httpx, "stream", fake_stream)
    with pytest.raises(httpx.HTTPStatusError):
        utils.download_file(url, save_dir=str(gradio_temp_dir))


@pytest.mark.parametrize(
"orig_filename, new_filename",
[
Expand Down

0 comments on commit 6390d0b

Please sign in to comment.