diff --git a/src/openai/azure/__init__.py b/src/openai/azure/__init__.py new file mode 100644 index 0000000000..296a7da84d --- /dev/null +++ b/src/openai/azure/__init__.py @@ -0,0 +1,7 @@ +from ._client import AzureOpenAIClient, AsyncAzureOpenAIClient + + +__all__ = [ + "AzureOpenAIClient", + "AsyncAzureOpenAIClient", +] diff --git a/src/openai/azure/_client.py b/src/openai/azure/_client.py new file mode 100644 index 0000000000..7f49ec81b2 --- /dev/null +++ b/src/openai/azure/_client.py @@ -0,0 +1,47 @@ +import httpx +from typing import Any, Optional, Dict + +from openai import Client, AsyncClient +from ._credential import TokenAuth + + +class AzureOpenAIClient(Client): + + def __init__(self, *args: Any, base_url: str, credential: Optional["TokenCredential"] = None, api_version: str = '2023-09-01-preview', **kwargs: Any): + default_query = kwargs.get('default_query', {}) + default_query.setdefault('api-version', api_version) + kwargs['default_query'] = default_query + self.credential = credential + if credential: + kwargs['api_key'] = 'Placeholder: AAD' + super().__init__(*args, base_url=base_url, **kwargs) + + @property + def auth_headers(self) -> Dict[str, str]: + return {"api-key": self.api_key} + + @property + def custom_auth(self) -> Optional[httpx.Auth]: + if self.credential: + return TokenAuth(self.credential) + + +class AsyncAzureOpenAIClient(AsyncClient): + + def __init__(self, *args: Any, credential: Optional["TokenCredential"] = None, api_version: str = '2023-09-01-preview', **kwargs: Any): + default_query = kwargs.get('default_query', {}) + default_query.setdefault('api-version', api_version) + kwargs['default_query'] = default_query + self.credential = credential + if credential: + kwargs['api_key'] = 'Placeholder: AAD' + super().__init__(*args, **kwargs) + + @property + def auth_headers(self) -> Dict[str, str]: + return {"api-key": self.api_key} + + @property + def custom_auth(self) -> httpx.Auth | None: + if self.credential: + return TokenAuth(self.credential) diff --git a/src/openai/azure/_credential.py b/src/openai/azure/_credential.py new file mode 100644 index 0000000000..fbf3647c17 --- /dev/null +++ b/src/openai/azure/_credential.py @@ -0,0 +1,32 @@ +import time +import asyncio +import httpx +from typing import Any, Generator, AsyncGenerator + + +class TokenAuth(httpx.Auth): + def __init__(self, credential: "TokenCredential") -> None: + self._credential = credential + self._async_lock = asyncio.Lock() + self.cached_token = None + + def sync_get_token(self) -> str: + if not self.cached_token or self.cached_token.expires_on - time.time() < 300: + return self._credential.get_token("https://cognitiveservices.azure.com/.default").token + return self.cached_token.token + + def sync_auth_flow(self, request: httpx.Request) -> Generator[httpx.Request, Any, Any]: + token = self.sync_get_token() + request.headers["Authorization"] = f"Bearer {token}" + yield request + + async def async_get_token(self) -> str: + async with self._async_lock: + if not self.cached_token or self.cached_token.expires_on - time.time() < 300: + return (await self._credential.get_token("https://cognitiveservices.azure.com/.default")).token + return self.cached_token.token + + async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.Request, Any]: + token = await self.async_get_token() + request.headers["Authorization"] = f"Bearer {token}" + yield request