Skip to content

Commit

Permalink
Allow setting custom headers in Python Client (#7334)
Browse files Browse the repository at this point in the history
* add headers

* add changeset

* fix

* test

---------

Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
  • Loading branch information
abidlabs and gradio-pr-bot committed Feb 7, 2024
1 parent 5b45a16 commit b95d0d0
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 0 deletions.
6 changes: 6 additions & 0 deletions .changeset/twelve-crabs-refuse.md
@@ -0,0 +1,6 @@
---
"gradio": minor
"gradio_client": minor
---

feat:Allow setting custom headers in Python Client
5 changes: 5 additions & 0 deletions client/python/gradio_client/client.py
Expand Up @@ -75,6 +75,8 @@ def __init__(
output_dir: str | Path = DEFAULT_TEMP_DIR,
verbose: bool = True,
auth: tuple[str, str] | None = None,
*,
headers: dict[str, str] | None = None,
):
"""
Parameters:
Expand All @@ -84,6 +86,7 @@ def __init__(
serialize: Whether the client should serialize the inputs and deserialize the outputs of the remote API. If set to False, the client will pass the inputs and outputs as-is, without serializing/deserializing them. E.g. you if you set this to False, you'd submit an image in base64 format instead of a filepath, and you'd get back an image in base64 format from the remote API instead of a filepath.
output_dir: The directory to save files that are downloaded from the remote API. If None, reads from the GRADIO_TEMP_DIR environment variable. Defaults to a temporary directory on your machine.
verbose: Whether the client should print statements to the console.
headers: Additional headers to send to the remote Gradio app on every request. By default only the HF authorization and user-agent headers are sent. These headers will override the default headers if they have the same keys.
"""
self.verbose = verbose
self.hf_token = hf_token
Expand All @@ -93,6 +96,8 @@ def __init__(
library_name="gradio_client",
library_version=utils.__version__,
)
if headers:
self.headers.update(headers)
self.space_id = None
self.cookies: dict[str, str] = {}
self.output_dir = (
Expand Down
21 changes: 21 additions & 0 deletions client/python/test/test_client.py
Expand Up @@ -51,6 +51,27 @@ def connect(
demo.server.thread.join(timeout=1)


class TestClientInitialization:
def test_headers_constructed_correctly(self):
client = Client("gradio-tests/titanic-survival", hf_token=HF_TOKEN)
assert {"authorization": f"Bearer {HF_TOKEN}"}.items() <= client.headers.items()
client = Client(
"gradio-tests/titanic-survival",
hf_token=HF_TOKEN,
headers={"additional": "value"},
)
assert {
"authorization": f"Bearer {HF_TOKEN}",
"additional": "value",
}.items() <= client.headers.items()
client = Client(
"gradio-tests/titanic-survival",
hf_token=HF_TOKEN,
headers={"authorization": "Bearer abcde"},
)
assert {"authorization": "Bearer abcde"}.items() <= client.headers.items()


class TestClientPredictions:
@pytest.mark.flaky
def test_raise_error_invalid_state(self):
Expand Down

0 comments on commit b95d0d0

Please sign in to comment.