diff --git a/src/openai/azure/_async_client.py b/src/openai/azure/_async_client.py index 77708fa74f..8ec0cc6cae 100644 --- a/src/openai/azure/_async_client.py +++ b/src/openai/azure/_async_client.py @@ -18,7 +18,7 @@ from openai._streaming import AsyncStream # Azure specific types -from ._credential import TokenCredential +from ._credential import TokenCredential, TokenAuth from ._azuremodels import ChatExtensionConfiguration TIMEOUT_SECS = 600 @@ -374,10 +374,12 @@ def __init__(self, *args: Any, credential: Optional["TokenCredential"] = None, a @property def auth_headers(self) -> Dict[str, str]: - if self.credential: - return { 'Authorization': f'Bearer {self.credential.get_token()}'} return {"api-key": self.api_key} + @property + def custom_auth(self) -> httpx.Auth | None: + if self.credential: + return TokenAuth(self.credential) def _check_polling_response(self, response: httpx.Response, predicate: Callable[[httpx.Response], bool]) -> bool: if not predicate(response): diff --git a/src/openai/azure/_credential.py b/src/openai/azure/_credential.py index 075a7116cd..9d10e14909 100644 --- a/src/openai/azure/_credential.py +++ b/src/openai/azure/_credential.py @@ -1,3 +1,9 @@ +from typing import AsyncGenerator, Generator, Any +import time +import asyncio +import httpx + + class TokenCredential: """Placeholder/example token credential class @@ -11,3 +17,30 @@ def __init__(self): def get_token(self): return self._credential.get_token('https://cognitiveservices.azure.com/.default').token + +class TokenAuth(httpx.Auth): + def __init__(self, credential: "TokenCredential") -> None: + self._credential = credential + self._async_lock = asyncio.Lock() + self.cached_token = None + + def sync_get_token(self) -> str: + if not self.cached_token or self.cached_token.expires_on - time.time() < 300: + return self._credential.get_token("https://cognitiveservices.azure.com/.default").token + return self.cached_token.token + + def sync_auth_flow(self, request: httpx.Request) -> Generator[httpx.Request, Any, Any]: + token = self.sync_get_token() + request.headers["Authorization"] = f"Bearer {token}" + yield request + + async def async_get_token(self) -> str: + async with self._async_lock: + if not self.cached_token or self.cached_token.expires_on - time.time() < 300: + return (await self._credential.get_token("https://cognitiveservices.azure.com/.default")).token + return self.cached_token.token + + async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.Request, Any]: + token = await self.async_get_token() + request.headers["Authorization"] = f"Bearer {token}" + yield request diff --git a/src/openai/azure/_sync_client.py b/src/openai/azure/_sync_client.py index ba7faccf20..3f904b482f 100644 --- a/src/openai/azure/_sync_client.py +++ b/src/openai/azure/_sync_client.py @@ -19,7 +19,7 @@ from openai.types.chat.completion_create_params import FunctionCall, Function # Azure specific types -from ._credential import TokenCredential +from ._credential import TokenCredential, TokenAuth from ._azuremodels import ChatExtensionConfiguration TIMEOUT_SECS = 600 @@ -376,10 +376,13 @@ def __init__(self, *args: Any, base_url: str, credential: Optional["TokenCredent @property def auth_headers(self) -> Dict[str, str]: - if self.credential: - return { 'Authorization': f'Bearer {self.credential.get_token()}'} return {"api-key": self.api_key} + @property + def custom_auth(self) -> httpx.Auth | None: + if self.credential: + return TokenAuth(self.credential) + # NOTE: We override the internal method because `@overrid`ing `@overload`ed methods and keeping typing happy is a pain. Most typing tools are lacking... def _request(self, *, options: FinalRequestOptions, **kwargs: Any) -> Any: if options.url == "/images/generations":