Skip to content

Commit

Permalink
feat(client): add header OpenAI-Project (#1320)
Browse files Browse the repository at this point in the history
  • Loading branch information
stainless-bot committed Apr 16, 2024
1 parent 1a83130 commit 3408b5d
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 0 deletions.
14 changes: 14 additions & 0 deletions src/openai/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,8 @@

organization: str | None = None

project: str | None = None

base_url: str | _httpx.URL | None = None

timeout: float | Timeout | None = DEFAULT_TIMEOUT
Expand Down Expand Up @@ -159,6 +161,17 @@ def organization(self, value: str | None) -> None: # type: ignore

organization = value

@property # type: ignore
@override
def project(self) -> str | None:
return project

@project.setter # type: ignore
def project(self, value: str | None) -> None: # type: ignore
global project

project = value

@property
@override
def base_url(self) -> _httpx.URL:
Expand Down Expand Up @@ -310,6 +323,7 @@ def _load_client() -> OpenAI: # type: ignore[reportUnusedFunction]
_client = _ModuleClient(
api_key=api_key,
organization=organization,
project=project,
base_url=base_url,
timeout=timeout,
max_retries=max_retries,
Expand Down
20 changes: 20 additions & 0 deletions src/openai/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,12 +64,14 @@ class OpenAI(SyncAPIClient):
# client options
api_key: str
organization: str | None
project: str | None

def __init__(
self,
*,
api_key: str | None = None,
organization: str | None = None,
project: str | None = None,
base_url: str | httpx.URL | None = None,
timeout: Union[float, Timeout, None, NotGiven] = NOT_GIVEN,
max_retries: int = DEFAULT_MAX_RETRIES,
Expand All @@ -94,6 +96,7 @@ def __init__(
This automatically infers the following arguments from their corresponding environment variables if they are not provided:
- `api_key` from `OPENAI_API_KEY`
- `organization` from `OPENAI_ORG_ID`
- `project` from `OPENAI_PROJECT_ID`
"""
if api_key is None:
api_key = os.environ.get("OPENAI_API_KEY")
Expand All @@ -107,6 +110,10 @@ def __init__(
organization = os.environ.get("OPENAI_ORG_ID")
self.organization = organization

if project is None:
project = os.environ.get("OPENAI_PROJECT_ID")
self.project = project

if base_url is None:
base_url = os.environ.get("OPENAI_BASE_URL")
if base_url is None:
Expand Down Expand Up @@ -157,6 +164,7 @@ def default_headers(self) -> dict[str, str | Omit]:
**super().default_headers,
"X-Stainless-Async": "false",
"OpenAI-Organization": self.organization if self.organization is not None else Omit(),
"OpenAI-Project": self.project if self.project is not None else Omit(),
**self._custom_headers,
}

Expand All @@ -165,6 +173,7 @@ def copy(
*,
api_key: str | None = None,
organization: str | None = None,
project: str | None = None,
base_url: str | httpx.URL | None = None,
timeout: float | Timeout | None | NotGiven = NOT_GIVEN,
http_client: httpx.Client | None = None,
Expand Down Expand Up @@ -200,6 +209,7 @@ def copy(
return self.__class__(
api_key=api_key or self.api_key,
organization=organization or self.organization,
project=project or self.project,
base_url=base_url or self.base_url,
timeout=self.timeout if isinstance(timeout, NotGiven) else timeout,
http_client=http_client,
Expand Down Expand Up @@ -266,12 +276,14 @@ class AsyncOpenAI(AsyncAPIClient):
# client options
api_key: str
organization: str | None
project: str | None

def __init__(
self,
*,
api_key: str | None = None,
organization: str | None = None,
project: str | None = None,
base_url: str | httpx.URL | None = None,
timeout: Union[float, Timeout, None, NotGiven] = NOT_GIVEN,
max_retries: int = DEFAULT_MAX_RETRIES,
Expand All @@ -296,6 +308,7 @@ def __init__(
This automatically infers the following arguments from their corresponding environment variables if they are not provided:
- `api_key` from `OPENAI_API_KEY`
- `organization` from `OPENAI_ORG_ID`
- `project` from `OPENAI_PROJECT_ID`
"""
if api_key is None:
api_key = os.environ.get("OPENAI_API_KEY")
Expand All @@ -309,6 +322,10 @@ def __init__(
organization = os.environ.get("OPENAI_ORG_ID")
self.organization = organization

if project is None:
project = os.environ.get("OPENAI_PROJECT_ID")
self.project = project

if base_url is None:
base_url = os.environ.get("OPENAI_BASE_URL")
if base_url is None:
Expand Down Expand Up @@ -359,6 +376,7 @@ def default_headers(self) -> dict[str, str | Omit]:
**super().default_headers,
"X-Stainless-Async": f"async:{get_async_library()}",
"OpenAI-Organization": self.organization if self.organization is not None else Omit(),
"OpenAI-Project": self.project if self.project is not None else Omit(),
**self._custom_headers,
}

Expand All @@ -367,6 +385,7 @@ def copy(
*,
api_key: str | None = None,
organization: str | None = None,
project: str | None = None,
base_url: str | httpx.URL | None = None,
timeout: float | Timeout | None | NotGiven = NOT_GIVEN,
http_client: httpx.AsyncClient | None = None,
Expand Down Expand Up @@ -402,6 +421,7 @@ def copy(
return self.__class__(
api_key=api_key or self.api_key,
organization=organization or self.organization,
project=project or self.project,
base_url=base_url or self.base_url,
timeout=self.timeout if isinstance(timeout, NotGiven) else timeout,
http_client=http_client,
Expand Down
13 changes: 13 additions & 0 deletions src/openai/lib/azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ def __init__(
azure_ad_token: str | None = None,
azure_ad_token_provider: AzureADTokenProvider | None = None,
organization: str | None = None,
project: str | None = None,
base_url: str | None = None,
timeout: float | Timeout | None | NotGiven = NOT_GIVEN,
max_retries: int = DEFAULT_MAX_RETRIES,
Expand All @@ -143,6 +144,7 @@ def __init__(
This automatically infers the following arguments from their corresponding environment variables if they are not provided:
- `api_key` from `AZURE_OPENAI_API_KEY`
- `organization` from `OPENAI_ORG_ID`
- `project` from `OPENAI_PROJECT_ID`
- `azure_ad_token` from `AZURE_OPENAI_AD_TOKEN`
- `api_version` from `OPENAI_API_VERSION`
- `azure_endpoint` from `AZURE_OPENAI_ENDPOINT`
Expand Down Expand Up @@ -205,6 +207,7 @@ def __init__(
super().__init__(
api_key=api_key,
organization=organization,
project=project,
base_url=base_url,
timeout=timeout,
max_retries=max_retries,
Expand All @@ -223,6 +226,7 @@ def copy(
*,
api_key: str | None = None,
organization: str | None = None,
project: str | None = None,
api_version: str | None = None,
azure_ad_token: str | None = None,
azure_ad_token_provider: AzureADTokenProvider | None = None,
Expand All @@ -242,6 +246,7 @@ def copy(
return super().copy(
api_key=api_key,
organization=organization,
project=project,
base_url=base_url,
timeout=timeout,
http_client=http_client,
Expand Down Expand Up @@ -306,6 +311,7 @@ def __init__(
azure_ad_token: str | None = None,
azure_ad_token_provider: AsyncAzureADTokenProvider | None = None,
organization: str | None = None,
project: str | None = None,
timeout: float | Timeout | None | NotGiven = NOT_GIVEN,
max_retries: int = DEFAULT_MAX_RETRIES,
default_headers: Mapping[str, str] | None = None,
Expand All @@ -325,6 +331,7 @@ def __init__(
azure_ad_token: str | None = None,
azure_ad_token_provider: AsyncAzureADTokenProvider | None = None,
organization: str | None = None,
project: str | None = None,
timeout: float | Timeout | None | NotGiven = NOT_GIVEN,
max_retries: int = DEFAULT_MAX_RETRIES,
default_headers: Mapping[str, str] | None = None,
Expand All @@ -344,6 +351,7 @@ def __init__(
azure_ad_token: str | None = None,
azure_ad_token_provider: AsyncAzureADTokenProvider | None = None,
organization: str | None = None,
project: str | None = None,
timeout: float | Timeout | None | NotGiven = NOT_GIVEN,
max_retries: int = DEFAULT_MAX_RETRIES,
default_headers: Mapping[str, str] | None = None,
Expand All @@ -363,6 +371,7 @@ def __init__(
azure_ad_token: str | None = None,
azure_ad_token_provider: AsyncAzureADTokenProvider | None = None,
organization: str | None = None,
project: str | None = None,
base_url: str | None = None,
timeout: float | Timeout | None | NotGiven = NOT_GIVEN,
max_retries: int = DEFAULT_MAX_RETRIES,
Expand All @@ -376,6 +385,7 @@ def __init__(
This automatically infers the following arguments from their corresponding environment variables if they are not provided:
- `api_key` from `AZURE_OPENAI_API_KEY`
- `organization` from `OPENAI_ORG_ID`
- `project` from `OPENAI_PROJECT_ID`
- `azure_ad_token` from `AZURE_OPENAI_AD_TOKEN`
- `api_version` from `OPENAI_API_VERSION`
- `azure_endpoint` from `AZURE_OPENAI_ENDPOINT`
Expand Down Expand Up @@ -438,6 +448,7 @@ def __init__(
super().__init__(
api_key=api_key,
organization=organization,
project=project,
base_url=base_url,
timeout=timeout,
max_retries=max_retries,
Expand All @@ -456,6 +467,7 @@ def copy(
*,
api_key: str | None = None,
organization: str | None = None,
project: str | None = None,
api_version: str | None = None,
azure_ad_token: str | None = None,
azure_ad_token_provider: AsyncAzureADTokenProvider | None = None,
Expand All @@ -475,6 +487,7 @@ def copy(
return super().copy(
api_key=api_key,
organization=organization,
project=project,
base_url=base_url,
timeout=timeout,
http_client=http_client,
Expand Down
1 change: 1 addition & 0 deletions tests/test_module_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ def reset_state() -> None:
openai._reset_client()
openai.api_key = None or "My API Key"
openai.organization = None
openai.project = None
openai.base_url = None
openai.timeout = DEFAULT_TIMEOUT
openai.max_retries = DEFAULT_MAX_RETRIES
Expand Down

0 comments on commit 3408b5d

Please sign in to comment.