From ae5679b9cdf245346bf73bc243e1ddb84e43e332 Mon Sep 17 00:00:00 2001 From: root Date: Wed, 3 Jun 2026 02:08:30 +0200 Subject: [PATCH] fix: support aad token api keys for azure --- src/openai/lib/azure.py | 59 +++++++++++++++++------ tests/lib/test_azure.py | 101 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 146 insertions(+), 14 deletions(-) diff --git a/src/openai/lib/azure.py b/src/openai/lib/azure.py index 4fcae24788..1b366a3eb5 100644 --- a/src/openai/lib/azure.py +++ b/src/openai/lib/azure.py @@ -43,6 +43,15 @@ API_KEY_SENTINEL = "".join(["<", "missing API key", ">"]) +def _azure_ad_token_from_api_key(value: str) -> str | None: + token = value.strip().removeprefix("Bearer ").strip() + parts = token.split(".") + if len(parts) == 3 and parts[0].startswith("eyJ"): + return token + + return None + + def _has_header(headers: Headers, header: str) -> bool: header = header.lower() return any(key.lower() == header for key in headers) @@ -352,6 +361,10 @@ def _auth_headers(self, security: SecurityOptions) -> dict[str, str]: # noqa: A return {"Authorization": f"Bearer {self._azure_ad_token}"} if self.api_key and self.api_key != API_KEY_SENTINEL: + azure_ad_token = _azure_ad_token_from_api_key(self.api_key) + if azure_ad_token is not None: + return {"Authorization": f"Bearer {azure_ad_token}"} + return {"api-key": self.api_key} return {} @@ -377,7 +390,11 @@ def _prepare_options(self, options: FinalRequestOptions) -> FinalRequestOptions: if not _has_header(headers, "Authorization"): headers["Authorization"] = f"Bearer {azure_ad_token}" elif self.api_key and self.api_key != API_KEY_SENTINEL: - if not _has_header(headers, "api-key"): + api_key_ad_token = _azure_ad_token_from_api_key(self.api_key) + if api_key_ad_token is not None: + if not _has_auth_header(headers): + headers["Authorization"] = f"Bearer {api_key_ad_token}" + elif not _has_header(headers, "api-key"): headers["api-key"] = self.api_key elif _has_auth_header(headers) or _has_auth_header(self.default_headers): pass @@ -394,12 +411,15 @@ def _configure_realtime(self, model: str, extra_query: Query) -> tuple[httpx.URL "api-version": self._api_version, "deployment": self._azure_deployment or model, } - if self.api_key and self.api_key != "": - auth_headers = {"api-key": self.api_key} - else: - token = self._get_azure_ad_token() - if token: - auth_headers = {"Authorization": f"Bearer {token}"} + token = self._get_azure_ad_token() + if token: + auth_headers = {"Authorization": f"Bearer {token}"} + elif self.api_key and self.api_key != API_KEY_SENTINEL: + api_key_ad_token = _azure_ad_token_from_api_key(self.api_key) + if api_key_ad_token is not None: + auth_headers = {"Authorization": f"Bearer {api_key_ad_token}"} + else: + auth_headers = {"api-key": self.api_key} if self.websocket_base_url is not None: base_url = httpx.URL(self.websocket_base_url) @@ -674,6 +694,10 @@ def _auth_headers(self, security: SecurityOptions) -> dict[str, str]: # noqa: A return {"Authorization": f"Bearer {self._azure_ad_token}"} if self.api_key and self.api_key != API_KEY_SENTINEL: + azure_ad_token = _azure_ad_token_from_api_key(self.api_key) + if azure_ad_token is not None: + return {"Authorization": f"Bearer {azure_ad_token}"} + return {"api-key": self.api_key} return {} @@ -699,7 +723,11 @@ async def _prepare_options(self, options: FinalRequestOptions) -> FinalRequestOp if not _has_header(headers, "Authorization"): headers["Authorization"] = f"Bearer {azure_ad_token}" elif self.api_key and self.api_key != API_KEY_SENTINEL: - if not _has_header(headers, "api-key"): + api_key_ad_token = _azure_ad_token_from_api_key(self.api_key) + if api_key_ad_token is not None: + if not _has_auth_header(headers): + headers["Authorization"] = f"Bearer {api_key_ad_token}" + elif not _has_header(headers, "api-key"): headers["api-key"] = self.api_key elif _has_auth_header(headers) or _has_auth_header(self.default_headers): pass @@ -716,12 +744,15 @@ async def _configure_realtime(self, model: str, extra_query: Query) -> tuple[htt "api-version": self._api_version, "deployment": self._azure_deployment or model, } - if self.api_key and self.api_key != "": - auth_headers = {"api-key": self.api_key} - else: - token = await self._get_azure_ad_token() - if token: - auth_headers = {"Authorization": f"Bearer {token}"} + token = await self._get_azure_ad_token() + if token: + auth_headers = {"Authorization": f"Bearer {token}"} + elif self.api_key and self.api_key != API_KEY_SENTINEL: + api_key_ad_token = _azure_ad_token_from_api_key(self.api_key) + if api_key_ad_token is not None: + auth_headers = {"Authorization": f"Bearer {api_key_ad_token}"} + else: + auth_headers = {"api-key": self.api_key} if self.websocket_base_url is not None: base_url = httpx.URL(self.websocket_base_url) diff --git a/tests/lib/test_azure.py b/tests/lib/test_azure.py index 3e1d783e2c..5e361d6156 100644 --- a/tests/lib/test_azure.py +++ b/tests/lib/test_azure.py @@ -30,6 +30,8 @@ azure_endpoint="https://example-resource.azure.openai.com", ) +AZURE_AD_JWT_API_KEY = "eyJhbGciOiJSUzI1NiJ9.eyJhdWQiOiJodHRwczovL2NvZ25pdGl2ZXNlcnZpY2VzLmF6dXJlLmNvbS8ifQ.signature" + class MockRequestCall(Protocol): request: httpx.Request @@ -305,6 +307,105 @@ def token_provider() -> str: assert calls[1].request.headers.get("Authorization") == "Bearer second" +@pytest.mark.respx() +def test_sync_jwt_like_api_key_uses_authorization_header(respx_mock: MockRouter) -> None: + respx_mock.post( + "https://example-resource.azure.openai.com/openai/deployments/gpt-4/chat/completions?api-version=2024-02-01" + ).mock(return_value=httpx.Response(200, json={"model": "gpt-4"})) + + client = AzureOpenAI( + api_version="2024-02-01", + api_key=f"Bearer {AZURE_AD_JWT_API_KEY}", + azure_endpoint="https://example-resource.azure.openai.com", + ) + client.chat.completions.create(messages=[], model="gpt-4") + + calls = cast("list[MockRequestCall]", respx_mock.calls) + assert calls[0].request.headers.get("Authorization") == f"Bearer {AZURE_AD_JWT_API_KEY}" + assert calls[0].request.headers.get("api-key") is None + + +@pytest.mark.asyncio +@pytest.mark.respx() +async def test_async_jwt_like_api_key_uses_authorization_header(respx_mock: MockRouter) -> None: + respx_mock.post( + "https://example-resource.azure.openai.com/openai/deployments/gpt-4/chat/completions?api-version=2024-02-01" + ).mock(return_value=httpx.Response(200, json={"model": "gpt-4"})) + + client = AsyncAzureOpenAI( + api_version="2024-02-01", + api_key=AZURE_AD_JWT_API_KEY, + azure_endpoint="https://example-resource.azure.openai.com", + ) + await client.chat.completions.create(messages=[], model="gpt-4") + + calls = cast("list[MockRequestCall]", respx_mock.calls) + assert calls[0].request.headers.get("Authorization") == f"Bearer {AZURE_AD_JWT_API_KEY}" + assert calls[0].request.headers.get("api-key") is None + + +@pytest.mark.respx() +def test_sync_regular_api_key_uses_api_key_header(respx_mock: MockRouter) -> None: + respx_mock.post( + "https://example-resource.azure.openai.com/openai/deployments/gpt-4/chat/completions?api-version=2024-02-01" + ).mock(return_value=httpx.Response(200, json={"model": "gpt-4"})) + + client = AzureOpenAI( + api_version="2024-02-01", + api_key="regular-api-key", + azure_endpoint="https://example-resource.azure.openai.com", + ) + client.chat.completions.create(messages=[], model="gpt-4") + + calls = cast("list[MockRequestCall]", respx_mock.calls) + assert calls[0].request.headers.get("api-key") == "regular-api-key" + assert calls[0].request.headers.get("Authorization") is None + + +@pytest.mark.asyncio +@pytest.mark.respx() +async def test_async_regular_api_key_uses_api_key_header(respx_mock: MockRouter) -> None: + respx_mock.post( + "https://example-resource.azure.openai.com/openai/deployments/gpt-4/chat/completions?api-version=2024-02-01" + ).mock(return_value=httpx.Response(200, json={"model": "gpt-4"})) + + client = AsyncAzureOpenAI( + api_version="2024-02-01", + api_key="regular-api-key", + azure_endpoint="https://example-resource.azure.openai.com", + ) + await client.chat.completions.create(messages=[], model="gpt-4") + + calls = cast("list[MockRequestCall]", respx_mock.calls) + assert calls[0].request.headers.get("api-key") == "regular-api-key" + assert calls[0].request.headers.get("Authorization") is None + + +def test_sync_jwt_like_api_key_configures_realtime_authorization_header() -> None: + client = AzureOpenAI( + api_version="2024-02-01", + api_key=AZURE_AD_JWT_API_KEY, + azure_endpoint="https://example-resource.azure.openai.com", + ) + + _, headers = client._configure_realtime("gpt-4o-realtime", {}) + + assert headers == {"Authorization": f"Bearer {AZURE_AD_JWT_API_KEY}"} + + +@pytest.mark.asyncio +async def test_async_jwt_like_api_key_configures_realtime_authorization_header() -> None: + client = AsyncAzureOpenAI( + api_version="2024-02-01", + api_key=AZURE_AD_JWT_API_KEY, + azure_endpoint="https://example-resource.azure.openai.com", + ) + + _, headers = await client._configure_realtime("gpt-4o-realtime", {}) + + assert headers == {"Authorization": f"Bearer {AZURE_AD_JWT_API_KEY}"} + + class TestAzureLogging: @pytest.fixture(autouse=True) def logger_with_filter(self) -> logging.Logger: