Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,7 @@ import httpx
from lithic import Lithic

client = Lithic(
# Or use the `LITHIC_BASE_URL` env var
base_url="http://my.test.server.example.com:8083",
http_client=httpx.Client(
proxies="http://my.test.proxy.example.com",
Expand Down
54 changes: 45 additions & 9 deletions src/lithic/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import os
import asyncio
from typing import Dict, Union, Mapping
from typing import Dict, Union, Mapping, cast
from typing_extensions import Literal, override

import httpx
Expand Down Expand Up @@ -87,15 +87,15 @@ class Lithic(SyncAPIClient):
api_key: str
webhook_secret: str | None

_environment: Literal["production", "sandbox"]
_environment: Literal["production", "sandbox"] | NotGiven

def __init__(
self,
*,
api_key: str | None = None,
webhook_secret: str | None = None,
environment: Literal["production", "sandbox"] = "production",
base_url: str | httpx.URL | None = None,
environment: Literal["production", "sandbox"] | NotGiven = NOT_GIVEN,
base_url: str | httpx.URL | None | NotGiven = NOT_GIVEN,
timeout: Union[float, Timeout, None, NotGiven] = NOT_GIVEN,
max_retries: int = DEFAULT_MAX_RETRIES,
default_headers: Mapping[str, str] | None = None,
Expand Down Expand Up @@ -138,7 +138,25 @@ def __init__(

self._environment = environment

if base_url is None:
base_url_env = os.environ.get("LITHIC_BASE_URL")
if is_given(base_url) and base_url is not None:
# cast required because mypy doesn't understand the type narrowing
base_url = cast("str | httpx.URL", base_url) # pyright: ignore[reportUnnecessaryCast]
elif is_given(environment):
if base_url_env and base_url is not None:
raise ValueError(
"Ambiguous URL; The `LITHIC_BASE_URL` env var and the `environment` argument are given. If you want to use the environment, you must pass base_url=None",
)

try:
base_url = ENVIRONMENTS[environment]
except KeyError as exc:
raise ValueError(f"Unknown environment: {environment}") from exc
elif base_url_env is not None:
base_url = base_url_env
else:
self._environment = environment = "production"

try:
base_url = ENVIRONMENTS[environment]
except KeyError as exc:
Expand Down Expand Up @@ -371,15 +389,15 @@ class AsyncLithic(AsyncAPIClient):
api_key: str
webhook_secret: str | None

_environment: Literal["production", "sandbox"]
_environment: Literal["production", "sandbox"] | NotGiven

def __init__(
self,
*,
api_key: str | None = None,
webhook_secret: str | None = None,
environment: Literal["production", "sandbox"] = "production",
base_url: str | httpx.URL | None = None,
environment: Literal["production", "sandbox"] | NotGiven = NOT_GIVEN,
base_url: str | httpx.URL | None | NotGiven = NOT_GIVEN,
timeout: Union[float, Timeout, None, NotGiven] = NOT_GIVEN,
max_retries: int = DEFAULT_MAX_RETRIES,
default_headers: Mapping[str, str] | None = None,
Expand Down Expand Up @@ -422,7 +440,25 @@ def __init__(

self._environment = environment

if base_url is None:
base_url_env = os.environ.get("LITHIC_BASE_URL")
if is_given(base_url) and base_url is not None:
# cast required because mypy doesn't understand the type narrowing
base_url = cast("str | httpx.URL", base_url) # pyright: ignore[reportUnnecessaryCast]
elif is_given(environment):
if base_url_env and base_url is not None:
raise ValueError(
"Ambiguous URL; The `LITHIC_BASE_URL` env var and the `environment` argument are given. If you want to use the environment, you must pass base_url=None",
)

try:
base_url = ENVIRONMENTS[environment]
except KeyError as exc:
raise ValueError(f"Unknown environment: {environment}") from exc
elif base_url_env is not None:
base_url = base_url_env
else:
self._environment = environment = "production"

try:
base_url = ENVIRONMENTS[environment]
except KeyError as exc:
Expand Down
30 changes: 30 additions & 0 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
make_request_options,
)

from .utils import update_env

base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010")
api_key = "My Lithic API Key"

Expand Down Expand Up @@ -429,6 +431,19 @@ def test_idempotency_header_options(self, respx_mock: MockRouter) -> None:
)
assert response.request.headers.get("Idempotency-Key") == "custom-key"

def test_base_url_env(self) -> None:
with update_env(LITHIC_BASE_URL="http://localhost:5000/from/env"):
client = Lithic(api_key=api_key, _strict_response_validation=True)
assert client.base_url == "http://localhost:5000/from/env/"

# explicit environment arg requires explicitness
with update_env(LITHIC_BASE_URL="http://localhost:5000/from/env"):
with pytest.raises(ValueError, match=r"you must pass base_url=None"):
Lithic(api_key=api_key, _strict_response_validation=True, environment="production")

client = Lithic(base_url=None, api_key=api_key, _strict_response_validation=True, environment="production")
assert str(client.base_url).startswith("https://api.lithic.com/v1")

@pytest.mark.parametrize(
"client",
[
Expand Down Expand Up @@ -1070,6 +1085,21 @@ async def test_idempotency_header_options(self, respx_mock: MockRouter) -> None:
)
assert response.request.headers.get("Idempotency-Key") == "custom-key"

def test_base_url_env(self) -> None:
with update_env(LITHIC_BASE_URL="http://localhost:5000/from/env"):
client = AsyncLithic(api_key=api_key, _strict_response_validation=True)
assert client.base_url == "http://localhost:5000/from/env/"

# explicit environment arg requires explicitness
with update_env(LITHIC_BASE_URL="http://localhost:5000/from/env"):
with pytest.raises(ValueError, match=r"you must pass base_url=None"):
AsyncLithic(api_key=api_key, _strict_response_validation=True, environment="production")

client = AsyncLithic(
base_url=None, api_key=api_key, _strict_response_validation=True, environment="production"
)
assert str(client.base_url).startswith("https://api.lithic.com/v1")

@pytest.mark.parametrize(
"client",
[
Expand Down
17 changes: 16 additions & 1 deletion tests/utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from __future__ import annotations

import os
import traceback
from typing import Any, TypeVar, cast
import contextlib
from typing import Any, TypeVar, Iterator, cast
from datetime import date, datetime
from typing_extensions import Literal, get_args, get_origin, assert_type

Expand Down Expand Up @@ -103,3 +105,16 @@ def _assert_list_type(type_: type[object], value: object) -> None:
inner_type = get_args(type_)[0]
for entry in value:
assert_type(inner_type, entry) # type: ignore


@contextlib.contextmanager
def update_env(**new_env: str) -> Iterator[None]:
old = os.environ.copy()

try:
os.environ.update(new_env)

yield None
finally:
os.environ.clear()
os.environ.update(old)