diff --git a/README.md b/README.md index f27e449e..82577356 100644 --- a/README.md +++ b/README.md @@ -341,6 +341,7 @@ import httpx from lithic import Lithic client = Lithic( + # Or use the `LITHIC_BASE_URL` env var base_url="http://my.test.server.example.com:8083", http_client=httpx.Client( proxies="http://my.test.proxy.example.com", diff --git a/src/lithic/_client.py b/src/lithic/_client.py index f31583a8..7daca9a0 100644 --- a/src/lithic/_client.py +++ b/src/lithic/_client.py @@ -4,7 +4,7 @@ import os import asyncio -from typing import Dict, Union, Mapping +from typing import Dict, Union, Mapping, cast from typing_extensions import Literal, override import httpx @@ -87,15 +87,15 @@ class Lithic(SyncAPIClient): api_key: str webhook_secret: str | None - _environment: Literal["production", "sandbox"] + _environment: Literal["production", "sandbox"] | NotGiven def __init__( self, *, api_key: str | None = None, webhook_secret: str | None = None, - environment: Literal["production", "sandbox"] = "production", - base_url: str | httpx.URL | None = None, + environment: Literal["production", "sandbox"] | NotGiven = NOT_GIVEN, + base_url: str | httpx.URL | None | NotGiven = NOT_GIVEN, timeout: Union[float, Timeout, None, NotGiven] = NOT_GIVEN, max_retries: int = DEFAULT_MAX_RETRIES, default_headers: Mapping[str, str] | None = None, @@ -138,7 +138,25 @@ def __init__( self._environment = environment - if base_url is None: + base_url_env = os.environ.get("LITHIC_BASE_URL") + if is_given(base_url) and base_url is not None: + # cast required because mypy doesn't understand the type narrowing + base_url = cast("str | httpx.URL", base_url) # pyright: ignore[reportUnnecessaryCast] + elif is_given(environment): + if base_url_env and base_url is not None: + raise ValueError( + "Ambiguous URL; The `LITHIC_BASE_URL` env var and the `environment` argument are given. If you want to use the environment, you must pass base_url=None", + ) + + try: + base_url = ENVIRONMENTS[environment] + except KeyError as exc: + raise ValueError(f"Unknown environment: {environment}") from exc + elif base_url_env is not None: + base_url = base_url_env + else: + self._environment = environment = "production" + try: base_url = ENVIRONMENTS[environment] except KeyError as exc: @@ -371,15 +389,15 @@ class AsyncLithic(AsyncAPIClient): api_key: str webhook_secret: str | None - _environment: Literal["production", "sandbox"] + _environment: Literal["production", "sandbox"] | NotGiven def __init__( self, *, api_key: str | None = None, webhook_secret: str | None = None, - environment: Literal["production", "sandbox"] = "production", - base_url: str | httpx.URL | None = None, + environment: Literal["production", "sandbox"] | NotGiven = NOT_GIVEN, + base_url: str | httpx.URL | None | NotGiven = NOT_GIVEN, timeout: Union[float, Timeout, None, NotGiven] = NOT_GIVEN, max_retries: int = DEFAULT_MAX_RETRIES, default_headers: Mapping[str, str] | None = None, @@ -422,7 +440,25 @@ def __init__( self._environment = environment - if base_url is None: + base_url_env = os.environ.get("LITHIC_BASE_URL") + if is_given(base_url) and base_url is not None: + # cast required because mypy doesn't understand the type narrowing + base_url = cast("str | httpx.URL", base_url) # pyright: ignore[reportUnnecessaryCast] + elif is_given(environment): + if base_url_env and base_url is not None: + raise ValueError( + "Ambiguous URL; The `LITHIC_BASE_URL` env var and the `environment` argument are given. If you want to use the environment, you must pass base_url=None", + ) + + try: + base_url = ENVIRONMENTS[environment] + except KeyError as exc: + raise ValueError(f"Unknown environment: {environment}") from exc + elif base_url_env is not None: + base_url = base_url_env + else: + self._environment = environment = "production" + try: base_url = ENVIRONMENTS[environment] except KeyError as exc: diff --git a/tests/test_client.py b/tests/test_client.py index e96a1067..7d0d0996 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -25,6 +25,8 @@ make_request_options, ) +from .utils import update_env + base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010") api_key = "My Lithic API Key" @@ -429,6 +431,19 @@ def test_idempotency_header_options(self, respx_mock: MockRouter) -> None: ) assert response.request.headers.get("Idempotency-Key") == "custom-key" + def test_base_url_env(self) -> None: + with update_env(LITHIC_BASE_URL="http://localhost:5000/from/env"): + client = Lithic(api_key=api_key, _strict_response_validation=True) + assert client.base_url == "http://localhost:5000/from/env/" + + # explicit environment arg requires explicitness + with update_env(LITHIC_BASE_URL="http://localhost:5000/from/env"): + with pytest.raises(ValueError, match=r"you must pass base_url=None"): + Lithic(api_key=api_key, _strict_response_validation=True, environment="production") + + client = Lithic(base_url=None, api_key=api_key, _strict_response_validation=True, environment="production") + assert str(client.base_url).startswith("https://api.lithic.com/v1") + @pytest.mark.parametrize( "client", [ @@ -1070,6 +1085,21 @@ async def test_idempotency_header_options(self, respx_mock: MockRouter) -> None: ) assert response.request.headers.get("Idempotency-Key") == "custom-key" + def test_base_url_env(self) -> None: + with update_env(LITHIC_BASE_URL="http://localhost:5000/from/env"): + client = AsyncLithic(api_key=api_key, _strict_response_validation=True) + assert client.base_url == "http://localhost:5000/from/env/" + + # explicit environment arg requires explicitness + with update_env(LITHIC_BASE_URL="http://localhost:5000/from/env"): + with pytest.raises(ValueError, match=r"you must pass base_url=None"): + AsyncLithic(api_key=api_key, _strict_response_validation=True, environment="production") + + client = AsyncLithic( + base_url=None, api_key=api_key, _strict_response_validation=True, environment="production" + ) + assert str(client.base_url).startswith("https://api.lithic.com/v1") + @pytest.mark.parametrize( "client", [ diff --git a/tests/utils.py b/tests/utils.py index 5fd784e4..c8c49755 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,7 +1,9 @@ from __future__ import annotations +import os import traceback -from typing import Any, TypeVar, cast +import contextlib +from typing import Any, TypeVar, Iterator, cast from datetime import date, datetime from typing_extensions import Literal, get_args, get_origin, assert_type @@ -103,3 +105,16 @@ def _assert_list_type(type_: type[object], value: object) -> None: inner_type = get_args(type_)[0] for entry in value: assert_type(inner_type, entry) # type: ignore + + +@contextlib.contextmanager +def update_env(**new_env: str) -> Iterator[None]: + old = os.environ.copy() + + try: + os.environ.update(new_env) + + yield None + finally: + os.environ.clear() + os.environ.update(old)