From b95d0d043c739926af986e573200af92732bbc01 Mon Sep 17 00:00:00 2001 From: Abubakar Abid Date: Wed, 7 Feb 2024 07:19:53 -0800 Subject: [PATCH] Allow setting custom headers in Python Client (#7334) * add headers * add changeset * fix * test --------- Co-authored-by: gradio-pr-bot --- .changeset/twelve-crabs-refuse.md | 6 ++++++ client/python/gradio_client/client.py | 5 +++++ client/python/test/test_client.py | 21 +++++++++++++++++++++ 3 files changed, 32 insertions(+) create mode 100644 .changeset/twelve-crabs-refuse.md diff --git a/.changeset/twelve-crabs-refuse.md b/.changeset/twelve-crabs-refuse.md new file mode 100644 index 000000000000..2890b0966925 --- /dev/null +++ b/.changeset/twelve-crabs-refuse.md @@ -0,0 +1,6 @@ +--- +"gradio": minor +"gradio_client": minor +--- + +feat:Allow setting custom headers in Python Client diff --git a/client/python/gradio_client/client.py b/client/python/gradio_client/client.py index 5b4c30f32675..9abee7f2a2fa 100644 --- a/client/python/gradio_client/client.py +++ b/client/python/gradio_client/client.py @@ -75,6 +75,8 @@ def __init__( output_dir: str | Path = DEFAULT_TEMP_DIR, verbose: bool = True, auth: tuple[str, str] | None = None, + *, + headers: dict[str, str] | None = None, ): """ Parameters: @@ -84,6 +86,7 @@ def __init__( serialize: Whether the client should serialize the inputs and deserialize the outputs of the remote API. If set to False, the client will pass the inputs and outputs as-is, without serializing/deserializing them. E.g. you if you set this to False, you'd submit an image in base64 format instead of a filepath, and you'd get back an image in base64 format from the remote API instead of a filepath. output_dir: The directory to save files that are downloaded from the remote API. If None, reads from the GRADIO_TEMP_DIR environment variable. Defaults to a temporary directory on your machine. verbose: Whether the client should print statements to the console. + headers: Additional headers to send to the remote Gradio app on every request. By default only the HF authorization and user-agent headers are sent. These headers will override the default headers if they have the same keys. """ self.verbose = verbose self.hf_token = hf_token @@ -93,6 +96,8 @@ def __init__( library_name="gradio_client", library_version=utils.__version__, ) + if headers: + self.headers.update(headers) self.space_id = None self.cookies: dict[str, str] = {} self.output_dir = ( diff --git a/client/python/test/test_client.py b/client/python/test/test_client.py index 62eadf529868..d0017422c834 100644 --- a/client/python/test/test_client.py +++ b/client/python/test/test_client.py @@ -51,6 +51,27 @@ def connect( demo.server.thread.join(timeout=1) +class TestClientInitialization: + def test_headers_constructed_correctly(self): + client = Client("gradio-tests/titanic-survival", hf_token=HF_TOKEN) + assert {"authorization": f"Bearer {HF_TOKEN}"}.items() <= client.headers.items() + client = Client( + "gradio-tests/titanic-survival", + hf_token=HF_TOKEN, + headers={"additional": "value"}, + ) + assert { + "authorization": f"Bearer {HF_TOKEN}", + "additional": "value", + }.items() <= client.headers.items() + client = Client( + "gradio-tests/titanic-survival", + hf_token=HF_TOKEN, + headers={"authorization": "Bearer abcde"}, + ) + assert {"authorization": "Bearer abcde"}.items() <= client.headers.items() + + class TestClientPredictions: @pytest.mark.flaky def test_raise_error_invalid_state(self):