diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index be737c1..d7f2ae5 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -17,8 +17,7 @@ $ rye sync --all-features
 You can then run scripts using `rye run python script.py` or by activating the virtual environment:
 
 ```sh
-$ rye shell
-# or manually activate - https://docs.python.org/3/library/venv.html#how-venvs-work
+# Activate the virtual environment - https://docs.python.org/3/library/venv.html#how-venvs-work
 $ source .venv/bin/activate
 
 # now you can omit the `rye run` prefix
diff --git a/pyproject.toml b/pyproject.toml
index e01ed64..c0a6bf2 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "llama_api_client"
-version = "0.1.1"
+version = "0.1.2"
 description = "The official Python library for the llama-api-client API"
 dynamic = ["readme"]
 license = "MIT"
@@ -37,6 +37,8 @@ classifiers = [
 Homepage = "https://github.com/meta-llama/llama-api-python"
 Repository = "https://github.com/meta-llama/llama-api-python"
 
+[project.optional-dependencies]
+aiohttp = ["aiohttp", "httpx_aiohttp>=0.1.6"]
 
 [tool.rye]
 managed = true
@@ -54,6 +56,7 @@ dev-dependencies = [
    "importlib-metadata>=6.7.0",
    "rich>=13.7.1",
    "nest_asyncio==1.6.0",
+   "pytest-xdist>=3.6.1",
 ]
 
 [tool.rye.scripts]
@@ -125,7 +128,7 @@ replacement = '[\1](https://github.com/meta-llama/llama-api-python/tree/main/\g<2>)'
 
 [tool.pytest.ini_options]
 testpaths = ["tests"]
-addopts = "--tb=short"
+addopts = "--tb=short -n auto"
 xfail_strict = true
 asyncio_mode = "auto"
 asyncio_default_fixture_loop_scope = "session"
diff --git a/requirements-dev.lock b/requirements-dev.lock
index 4ed5175..7d3944d 100644
--- a/requirements-dev.lock
+++ b/requirements-dev.lock
@@ -10,6 +10,13 @@
 # universal: false
 
 -e file:.
+aiohappyeyeballs==2.6.1
+    # via aiohttp
+aiohttp==3.12.8
+    # via httpx-aiohttp
+    # via llama-api-client
+aiosignal==1.3.2
+    # via aiohttp
 annotated-types==0.6.0
     # via pydantic
 anyio==4.4.0
@@ -17,6 +24,10 @@ anyio==4.4.0
     # via llama-api-client
 argcomplete==3.1.2
     # via nox
+async-timeout==5.0.1
+    # via aiohttp
+attrs==25.3.0
+    # via aiohttp
 certifi==2023.7.22
     # via httpcore
     # via httpx
@@ -30,18 +41,27 @@ distro==1.8.0
 exceptiongroup==1.2.2
     # via anyio
     # via pytest
+execnet==2.1.1
+    # via pytest-xdist
 filelock==3.12.4
     # via virtualenv
+frozenlist==1.6.2
+    # via aiohttp
+    # via aiosignal
 h11==0.14.0
     # via httpcore
 httpcore==1.0.2
     # via httpx
 httpx==0.28.1
+    # via httpx-aiohttp
     # via llama-api-client
     # via respx
+httpx-aiohttp==0.1.6
+    # via llama-api-client
 idna==3.4
     # via anyio
     # via httpx
+    # via yarl
 importlib-metadata==7.0.0
 iniconfig==2.0.0
     # via pytest
@@ -49,6 +69,9 @@ markdown-it-py==3.0.0
     # via rich
 mdurl==0.1.2
     # via markdown-it-py
+multidict==6.4.4
+    # via aiohttp
+    # via yarl
 mypy==1.14.1
 mypy-extensions==1.0.0
     # via mypy
@@ -63,6 +86,9 @@ platformdirs==3.11.0
     # via virtualenv
 pluggy==1.5.0
     # via pytest
+propcache==0.3.1
+    # via aiohttp
+    # via yarl
 pydantic==2.10.3
     # via llama-api-client
 pydantic-core==2.27.1
@@ -72,7 +98,9 @@ pygments==2.18.0
 pyright==1.1.399
 pytest==8.3.3
     # via pytest-asyncio
+    # via pytest-xdist
 pytest-asyncio==0.24.0
+pytest-xdist==3.7.0
 python-dateutil==2.8.2
     # via time-machine
 pytz==2023.3.post1
@@ -94,11 +122,14 @@ tomli==2.0.2
 typing-extensions==4.12.2
     # via anyio
     # via llama-api-client
+    # via multidict
     # via mypy
     # via pydantic
     # via pydantic-core
     # via pyright
 virtualenv==20.24.5
     # via nox
+yarl==1.20.0
+    # via aiohttp
 zipp==3.17.0
     # via importlib-metadata
diff --git a/requirements.lock b/requirements.lock
index 7ca293f..8dca115 100644
--- a/requirements.lock
+++ b/requirements.lock
@@ -10,11 +10,22 @@
 # universal: false
 
 -e file:.
+aiohappyeyeballs==2.6.1
+    # via aiohttp
+aiohttp==3.12.8
+    # via httpx-aiohttp
+    # via llama-api-client
+aiosignal==1.3.2
+    # via aiohttp
 annotated-types==0.6.0
     # via pydantic
 anyio==4.4.0
     # via httpx
     # via llama-api-client
+async-timeout==5.0.1
+    # via aiohttp
+attrs==25.3.0
+    # via aiohttp
 certifi==2023.7.22
     # via httpcore
     # via httpx
@@ -22,15 +33,28 @@ distro==1.8.0
     # via llama-api-client
 exceptiongroup==1.2.2
     # via anyio
+frozenlist==1.6.2
+    # via aiohttp
+    # via aiosignal
 h11==0.14.0
     # via httpcore
 httpcore==1.0.2
     # via httpx
 httpx==0.28.1
+    # via httpx-aiohttp
+    # via llama-api-client
+httpx-aiohttp==0.1.6
     # via llama-api-client
 idna==3.4
     # via anyio
     # via httpx
+    # via yarl
+multidict==6.4.4
+    # via aiohttp
+    # via yarl
+propcache==0.3.1
+    # via aiohttp
+    # via yarl
 pydantic==2.10.3
     # via llama-api-client
 pydantic-core==2.27.1
@@ -41,5 +65,8 @@ sniffio==1.3.0
 typing-extensions==4.12.2
     # via anyio
     # via llama-api-client
+    # via multidict
     # via pydantic
     # via pydantic-core
+yarl==1.20.0
+    # via aiohttp
diff --git a/src/llama_api_client/__init__.py b/src/llama_api_client/__init__.py
index 17f72c2..9c0ab41 100644
--- a/src/llama_api_client/__init__.py
+++ b/src/llama_api_client/__init__.py
@@ -36,7 +36,7 @@
     UnprocessableEntityError,
     APIResponseValidationError,
 )
-from ._base_client import DefaultHttpxClient, DefaultAsyncHttpxClient
+from ._base_client import DefaultHttpxClient, DefaultAioHttpClient, DefaultAsyncHttpxClient
 from ._utils._logs import setup_logging as _setup_logging
 
 __all__ = [
@@ -78,6 +78,7 @@
     "DEFAULT_CONNECTION_LIMITS",
    "DefaultHttpxClient",
    "DefaultAsyncHttpxClient",
+   "DefaultAioHttpClient",
 ]
 
 if not _t.TYPE_CHECKING:
diff --git a/src/llama_api_client/_base_client.py b/src/llama_api_client/_base_client.py
index ff43f0b..cc9ca21 100644
--- a/src/llama_api_client/_base_client.py
+++ b/src/llama_api_client/_base_client.py
@@ -960,6 +960,9 @@ def request(
         if self.custom_auth is not None:
             kwargs["auth"] = self.custom_auth
 
+        if options.follow_redirects is not None:
+            kwargs["follow_redirects"] = options.follow_redirects
+
         log.debug("Sending HTTP Request: %s %s", request.method, request.url)
 
         response = None
@@ -1068,7 +1071,14 @@ def _process_response(
     ) -> ResponseT:
         origin = get_origin(cast_to) or cast_to
 
-        if inspect.isclass(origin) and issubclass(origin, BaseAPIResponse):
+        if (
+            inspect.isclass(origin)
+            and issubclass(origin, BaseAPIResponse)
+            # we only want to actually return the custom BaseAPIResponse class if we're
+            # returning the raw response, or if we're not streaming SSE, as if we're streaming
+            # SSE then `cast_to` doesn't actively reflect the type we need to parse into
+            and (not stream or bool(response.request.headers.get(RAW_RESPONSE_HEADER)))
+        ):
             if not issubclass(origin, APIResponse):
                 raise TypeError(f"API Response types must subclass {APIResponse}; Received {origin}")
@@ -1279,6 +1289,24 @@ def __init__(self, **kwargs: Any) -> None:
         super().__init__(**kwargs)
 
 
+try:
+    import httpx_aiohttp
+except ImportError:
+
+    class _DefaultAioHttpClient(httpx.AsyncClient):
+        def __init__(self, **_kwargs: Any) -> None:
+            raise RuntimeError("To use the aiohttp client you must have installed the package with the `aiohttp` extra")
+else:
+
+    class _DefaultAioHttpClient(httpx_aiohttp.HttpxAiohttpClient):  # type: ignore
+        def __init__(self, **kwargs: Any) -> None:
+            kwargs.setdefault("timeout", DEFAULT_TIMEOUT)
+            kwargs.setdefault("limits", DEFAULT_CONNECTION_LIMITS)
+            kwargs.setdefault("follow_redirects", True)
+
+            super().__init__(**kwargs)
+
+
 if TYPE_CHECKING:
     DefaultAsyncHttpxClient = httpx.AsyncClient
     """An alias to `httpx.AsyncClient` that provides the same defaults that this SDK
@@ -1287,8 +1315,12 @@ def __init__(self, **kwargs: Any) -> None:
     This is useful because overriding the `http_client` with your own instance of
     `httpx.AsyncClient` will result in httpx's defaults being used, not ours.
     """
+
+    DefaultAioHttpClient = httpx.AsyncClient
+    """An alias to `httpx.AsyncClient` that changes the default HTTP transport to `aiohttp`."""
 else:
     DefaultAsyncHttpxClient = _DefaultAsyncHttpxClient
+    DefaultAioHttpClient = _DefaultAioHttpClient
 
 
 class AsyncHttpxClientWrapper(DefaultAsyncHttpxClient):
@@ -1460,6 +1492,9 @@ async def request(
         if self.custom_auth is not None:
             kwargs["auth"] = self.custom_auth
 
+        if options.follow_redirects is not None:
+            kwargs["follow_redirects"] = options.follow_redirects
+
         log.debug("Sending HTTP Request: %s %s", request.method, request.url)
 
         response = None
@@ -1568,7 +1603,14 @@ async def _process_response(
     ) -> ResponseT:
         origin = get_origin(cast_to) or cast_to
 
-        if inspect.isclass(origin) and issubclass(origin, BaseAPIResponse):
+        if (
+            inspect.isclass(origin)
+            and issubclass(origin, BaseAPIResponse)
+            # we only want to actually return the custom BaseAPIResponse class if we're
+            # returning the raw response, or if we're not streaming SSE, as if we're streaming
+            # SSE then `cast_to` doesn't actively reflect the type we need to parse into
+            and (not stream or bool(response.request.headers.get(RAW_RESPONSE_HEADER)))
+        ):
             if not issubclass(origin, AsyncAPIResponse):
                 raise TypeError(f"API Response types must subclass {AsyncAPIResponse}; Received {origin}")
diff --git a/src/llama_api_client/_models.py b/src/llama_api_client/_models.py
index 798956f..4f21498 100644
--- a/src/llama_api_client/_models.py
+++ b/src/llama_api_client/_models.py
@@ -737,6 +737,7 @@ class FinalRequestOptionsInput(TypedDict, total=False):
     idempotency_key: str
     json_data: Body
     extra_json: AnyMapping
+    follow_redirects: bool
 
 
 @final
@@ -750,6 +751,7 @@ class FinalRequestOptions(pydantic.BaseModel):
     files: Union[HttpxRequestFiles, None] = None
     idempotency_key: Union[str, None] = None
     post_parser: Union[Callable[[Any], Any], NotGiven] = NotGiven()
+    follow_redirects: Union[bool, None] = None
 
     # It should be noted that we cannot use `json` here as that would override
     # a BaseModel method in an incompatible fashion.
diff --git a/src/llama_api_client/_types.py b/src/llama_api_client/_types.py
index 6ab25a2..157b7b6 100644
--- a/src/llama_api_client/_types.py
+++ b/src/llama_api_client/_types.py
@@ -100,6 +100,7 @@ class RequestOptions(TypedDict, total=False):
     params: Query
     extra_json: AnyMapping
     idempotency_key: str
+    follow_redirects: bool
 
 
 # Sentinel class used until PEP 0661 is accepted
@@ -215,3 +216,4 @@ class _GenericAlias(Protocol):
 
 class HttpxSendArgs(TypedDict, total=False):
     auth: httpx.Auth
+    follow_redirects: bool
diff --git a/src/llama_api_client/_version.py b/src/llama_api_client/_version.py
index 6e8ef07..e73d415 100644
--- a/src/llama_api_client/_version.py
+++ b/src/llama_api_client/_version.py
@@ -1,4 +1,4 @@
 # File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
 
 __title__ = "llama_api_client"
-__version__ = "0.1.1"
+__version__ = "0.1.2"
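Taken together, the `_base_client.py`, `_models.py`, and `_types.py` changes above introduce two pieces of public surface: an optional aiohttp-backed async transport, exported as `DefaultAioHttpClient` and gated behind the new `aiohttp` extra, and a per-request `follow_redirects` option threaded through `RequestOptions` and `FinalRequestOptions`. As a minimal usage sketch (not part of the patch): it assumes the package was installed with the extra, e.g. `pip install 'llama_api_client[aiohttp]'`, and that credentials are supplied via the environment or the `api_key` argument. Without the extra, constructing `DefaultAioHttpClient` raises the `RuntimeError` defined above.

```python
import asyncio

from llama_api_client import AsyncLlamaAPIClient, DefaultAioHttpClient


async def main() -> None:
    # DefaultAioHttpClient layers the SDK's defaults (timeout, connection
    # limits, follow_redirects=True) on top of the aiohttp-backed transport.
    async with AsyncLlamaAPIClient(http_client=DefaultAioHttpClient()) as client:
        models = await client.models.list()
        print(models)


asyncio.run(main())
```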
diff --git a/tests/api_resources/chat/test_completions.py b/tests/api_resources/chat/test_completions.py
index ac14b24..507e609 100644
--- a/tests/api_resources/chat/test_completions.py
+++ b/tests/api_resources/chat/test_completions.py
@@ -205,7 +205,9 @@ def test_streaming_response_create_overload_2(self, client: LlamaAPIClient) -> N
 
 
 class TestAsyncCompletions:
-    parametrize = pytest.mark.parametrize("async_client", [False, True], indirect=True, ids=["loose", "strict"])
+    parametrize = pytest.mark.parametrize(
+        "async_client", [False, True, {"http_client": "aiohttp"}], indirect=True, ids=["loose", "strict", "aiohttp"]
+    )
 
     @pytest.mark.skip()
     @parametrize
diff --git a/tests/api_resources/test_models.py b/tests/api_resources/test_models.py
index eff02e6..4dd069c 100644
--- a/tests/api_resources/test_models.py
+++ b/tests/api_resources/test_models.py
@@ -89,7 +89,9 @@ def test_streaming_response_list(self, client: LlamaAPIClient) -> None:
 
 
 class TestAsyncModels:
-    parametrize = pytest.mark.parametrize("async_client", [False, True], indirect=True, ids=["loose", "strict"])
+    parametrize = pytest.mark.parametrize(
+        "async_client", [False, True, {"http_client": "aiohttp"}], indirect=True, ids=["loose", "strict", "aiohttp"]
+    )
 
     @pytest.mark.skip()
     @parametrize
diff --git a/tests/api_resources/test_moderations.py b/tests/api_resources/test_moderations.py
index ce095ac..889d949 100644
--- a/tests/api_resources/test_moderations.py
+++ b/tests/api_resources/test_moderations.py
@@ -82,7 +82,9 @@ def test_streaming_response_create(self, client: LlamaAPIClient) -> None:
 
 
 class TestAsyncModerations:
-    parametrize = pytest.mark.parametrize("async_client", [False, True], indirect=True, ids=["loose", "strict"])
+    parametrize = pytest.mark.parametrize(
+        "async_client", [False, True, {"http_client": "aiohttp"}], indirect=True, ids=["loose", "strict", "aiohttp"]
+    )
 
     @pytest.mark.skip()
     @parametrize
diff --git a/tests/conftest.py b/tests/conftest.py
index acf209c..a98ffea 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,13 +1,17 @@
+# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
+
 from __future__ import annotations
 
 import os
 import logging
 from typing import TYPE_CHECKING, Iterator, AsyncIterator
 
+import httpx
 import pytest
 from pytest_asyncio import is_async_test
 
-from llama_api_client import LlamaAPIClient, AsyncLlamaAPIClient
+from llama_api_client import LlamaAPIClient, AsyncLlamaAPIClient, DefaultAioHttpClient
+from llama_api_client._utils import is_dict
 
 if TYPE_CHECKING:
     from _pytest.fixtures import FixtureRequest  # pyright: ignore[reportPrivateImportUsage]
@@ -25,6 +29,19 @@ def pytest_collection_modifyitems(items: list[pytest.Function]) -> None:
     for async_test in pytest_asyncio_tests:
         async_test.add_marker(session_scope_marker, append=False)
 
+    # We skip tests that use both the aiohttp client and respx_mock as respx_mock
+    # doesn't support custom transports.
+    for item in items:
+        if "async_client" not in item.fixturenames or "respx_mock" not in item.fixturenames:
+            continue
+
+        if not hasattr(item, "callspec"):
+            continue
+
+        async_client_param = item.callspec.params.get("async_client")
+        if is_dict(async_client_param) and async_client_param.get("http_client") == "aiohttp":
+            item.add_marker(pytest.mark.skip(reason="aiohttp client is not compatible with respx_mock"))
+
 
 base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010")
 
@@ -43,9 +60,25 @@ def client(request: FixtureRequest) -> Iterator[LlamaAPIClient]:
 
 @pytest.fixture(scope="session")
 async def async_client(request: FixtureRequest) -> AsyncIterator[AsyncLlamaAPIClient]:
-    strict = getattr(request, "param", True)
-    if not isinstance(strict, bool):
-        raise TypeError(f"Unexpected fixture parameter type {type(strict)}, expected {bool}")
-
-    async with AsyncLlamaAPIClient(base_url=base_url, api_key=api_key, _strict_response_validation=strict) as client:
+    param = getattr(request, "param", True)
+
+    # defaults
+    strict = True
+    http_client: None | httpx.AsyncClient = None
+
+    if isinstance(param, bool):
+        strict = param
+    elif is_dict(param):
+        strict = param.get("strict", True)
+        assert isinstance(strict, bool)
+
+        http_client_type = param.get("http_client", "httpx")
+        if http_client_type == "aiohttp":
+            http_client = DefaultAioHttpClient()
+    else:
+        raise TypeError(f"Unexpected fixture parameter type {type(param)}, expected bool or dict")
+
+    async with AsyncLlamaAPIClient(
+        base_url=base_url, api_key=api_key, _strict_response_validation=strict, http_client=http_client
+    ) as client:
         yield client
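The conftest changes above extend the `async_client` fixture so a parametrization can be either a bool (strict vs. loose response validation, as before) or a dict that additionally selects the HTTP transport. A sketch of a test opting in (the test body here is hypothetical; the parametrize call mirrors the ones added to the test files earlier in this patch):

```python
import pytest

from llama_api_client import AsyncLlamaAPIClient

# bool params toggle strict response validation; the dict form makes the
# fixture construct the client with DefaultAioHttpClient instead of httpx's
# default transport. respx-based tests are auto-skipped for the aiohttp variant.
parametrize = pytest.mark.parametrize(
    "async_client", [False, True, {"http_client": "aiohttp"}], indirect=True, ids=["loose", "strict", "aiohttp"]
)


@parametrize
async def test_round_trip(async_client: AsyncLlamaAPIClient) -> None:
    models = await async_client.models.list()
    assert models is not None
```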
diff --git a/tests/test_client.py b/tests/test_client.py
index 1742097..7537a2b 100644
--- a/tests/test_client.py
+++ b/tests/test_client.py
@@ -23,9 +23,7 @@
 
 from llama_api_client import LlamaAPIClient, AsyncLlamaAPIClient, APIResponseValidationError
 from llama_api_client._types import Omit
-from llama_api_client._utils import maybe_transform
 from llama_api_client._models import BaseModel, FinalRequestOptions
-from llama_api_client._constants import RAW_RESPONSE_HEADER
 from llama_api_client._exceptions import (
     APIStatusError,
     APITimeoutError,
@@ -36,9 +34,10 @@
     DEFAULT_TIMEOUT,
     HTTPX_DEFAULT_TIMEOUT,
     BaseClient,
+    DefaultHttpxClient,
+    DefaultAsyncHttpxClient,
     make_request_options,
 )
-from llama_api_client.types.chat.completion_create_params import CompletionCreateParamsNonStreaming
 
 from .utils import update_env
@@ -728,60 +727,37 @@ def test_parse_retry_after_header(self, remaining_retries: int, retry_after: str
 
     @mock.patch("llama_api_client._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
     @pytest.mark.respx(base_url=base_url)
-    def test_retrying_timeout_errors_doesnt_leak(self, respx_mock: MockRouter) -> None:
+    def test_retrying_timeout_errors_doesnt_leak(self, respx_mock: MockRouter, client: LlamaAPIClient) -> None:
         respx_mock.post("/chat/completions").mock(side_effect=httpx.TimeoutException("Test timeout error"))
 
         with pytest.raises(APITimeoutError):
-            self.client.post(
-                "/chat/completions",
-                body=cast(
-                    object,
-                    maybe_transform(
-                        dict(
-                            messages=[
-                                {
-                                    "content": "string",
-                                    "role": "user",
-                                }
-                            ],
-                            model="model",
-                        ),
-                        CompletionCreateParamsNonStreaming,
-                    ),
-                ),
-                cast_to=httpx.Response,
-                options={"headers": {RAW_RESPONSE_HEADER: "stream"}},
-            )
+            client.chat.completions.with_streaming_response.create(
+                messages=[
+                    {
+                        "content": "string",
+                        "role": "user",
+                    }
+                ],
+                model="model",
+            ).__enter__()
 
         assert _get_open_connections(self.client) == 0
 
     @mock.patch("llama_api_client._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
     @pytest.mark.respx(base_url=base_url)
-    def test_retrying_status_errors_doesnt_leak(self, respx_mock: MockRouter) -> None:
+    def test_retrying_status_errors_doesnt_leak(self, respx_mock: MockRouter, client: LlamaAPIClient) -> None:
         respx_mock.post("/chat/completions").mock(return_value=httpx.Response(500))
 
         with pytest.raises(APIStatusError):
-            self.client.post(
-                "/chat/completions",
-                body=cast(
-                    object,
-                    maybe_transform(
-                        dict(
-                            messages=[
-                                {
-                                    "content": "string",
-                                    "role": "user",
-                                }
-                            ],
-                            model="model",
-                        ),
-                        CompletionCreateParamsNonStreaming,
-                    ),
-                ),
-                cast_to=httpx.Response,
-                options={"headers": {RAW_RESPONSE_HEADER: "stream"}},
-            )
-
+            client.chat.completions.with_streaming_response.create(
+                messages=[
+                    {
+                        "content": "string",
+                        "role": "user",
+                    }
+                ],
+                model="model",
+            ).__enter__()
         assert _get_open_connections(self.client) == 0
 
     @pytest.mark.parametrize("failures_before_success", [0, 2, 4])
@@ -887,6 +863,55 @@ def retry_handler(_request: httpx.Request) -> httpx.Response:
 
         assert response.http_request.headers.get("x-stainless-retry-count") == "42"
 
+    def test_proxy_environment_variables(self, monkeypatch: pytest.MonkeyPatch) -> None:
+        # Test that the proxy environment variables are set correctly
+        monkeypatch.setenv("HTTPS_PROXY", "https://example.org")
+
+        client = DefaultHttpxClient()
+
+        mounts = tuple(client._mounts.items())
+        assert len(mounts) == 1
+        assert mounts[0][0].pattern == "https://"
+
+    @pytest.mark.filterwarnings("ignore:.*deprecated.*:DeprecationWarning")
+    def test_default_client_creation(self) -> None:
+        # Ensure that the client can be initialized without any exceptions
+        DefaultHttpxClient(
+            verify=True,
+            cert=None,
+            trust_env=True,
+            http1=True,
+            http2=False,
+            limits=httpx.Limits(max_connections=100, max_keepalive_connections=20),
+        )
+
+    @pytest.mark.respx(base_url=base_url)
+    def test_follow_redirects(self, respx_mock: MockRouter) -> None:
+        # Test that the default follow_redirects=True allows following redirects
+        respx_mock.post("/redirect").mock(
+            return_value=httpx.Response(302, headers={"Location": f"{base_url}/redirected"})
+        )
+        respx_mock.get("/redirected").mock(return_value=httpx.Response(200, json={"status": "ok"}))
+
+        response = self.client.post("/redirect", body={"key": "value"}, cast_to=httpx.Response)
+        assert response.status_code == 200
+        assert response.json() == {"status": "ok"}
+
+    @pytest.mark.respx(base_url=base_url)
+    def test_follow_redirects_disabled(self, respx_mock: MockRouter) -> None:
+        # Test that follow_redirects=False prevents following redirects
+        respx_mock.post("/redirect").mock(
+            return_value=httpx.Response(302, headers={"Location": f"{base_url}/redirected"})
+        )
+
+        with pytest.raises(APIStatusError) as exc_info:
+            self.client.post(
+                "/redirect", body={"key": "value"}, options={"follow_redirects": False}, cast_to=httpx.Response
+            )
+
+        assert exc_info.value.response.status_code == 302
+        assert exc_info.value.response.headers["Location"] == f"{base_url}/redirected"
+
 
 class TestAsyncLlamaAPIClient:
     client = AsyncLlamaAPIClient(base_url=base_url, api_key=api_key, _strict_response_validation=True)
@@ -1558,60 +1583,41 @@ async def test_parse_retry_after_header(self, remaining_retries: int, retry_afte
 
     @mock.patch("llama_api_client._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
     @pytest.mark.respx(base_url=base_url)
-    async def test_retrying_timeout_errors_doesnt_leak(self, respx_mock: MockRouter) -> None:
+    async def test_retrying_timeout_errors_doesnt_leak(
+        self, respx_mock: MockRouter, async_client: AsyncLlamaAPIClient
+    ) -> None:
         respx_mock.post("/chat/completions").mock(side_effect=httpx.TimeoutException("Test timeout error"))
 
         with pytest.raises(APITimeoutError):
-            await self.client.post(
-                "/chat/completions",
-                body=cast(
-                    object,
-                    maybe_transform(
-                        dict(
-                            messages=[
-                                {
-                                    "content": "string",
-                                    "role": "user",
-                                }
-                            ],
-                            model="model",
-                        ),
-                        CompletionCreateParamsNonStreaming,
-                    ),
-                ),
-                cast_to=httpx.Response,
-                options={"headers": {RAW_RESPONSE_HEADER: "stream"}},
-            )
+            await async_client.chat.completions.with_streaming_response.create(
+                messages=[
+                    {
+                        "content": "string",
+                        "role": "user",
+                    }
+                ],
+                model="model",
+            ).__aenter__()
 
         assert _get_open_connections(self.client) == 0
 
     @mock.patch("llama_api_client._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
     @pytest.mark.respx(base_url=base_url)
-    async def test_retrying_status_errors_doesnt_leak(self, respx_mock: MockRouter) -> None:
+    async def test_retrying_status_errors_doesnt_leak(
+        self, respx_mock: MockRouter, async_client: AsyncLlamaAPIClient
+    ) -> None:
         respx_mock.post("/chat/completions").mock(return_value=httpx.Response(500))
 
         with pytest.raises(APIStatusError):
-            await self.client.post(
-                "/chat/completions",
-                body=cast(
-                    object,
-                    maybe_transform(
-                        dict(
-                            messages=[
-                                {
-                                    "content": "string",
-                                    "role": "user",
-                                }
-                            ],
-                            model="model",
-                        ),
-                        CompletionCreateParamsNonStreaming,
-                    ),
-                ),
-                cast_to=httpx.Response,
-                options={"headers": {RAW_RESPONSE_HEADER: "stream"}},
-            )
-
+            await async_client.chat.completions.with_streaming_response.create(
+                messages=[
+                    {
+                        "content": "string",
+                        "role": "user",
+                    }
+                ],
+                model="model",
+            ).__aenter__()
         assert _get_open_connections(self.client) == 0
 
     @pytest.mark.parametrize("failures_before_success", [0, 2, 4])
@@ -1764,3 +1770,52 @@ async def test_main() -> None:
             raise AssertionError("calling get_platform using asyncify resulted in a hung process")
 
         time.sleep(0.1)
+
+    async def test_proxy_environment_variables(self, monkeypatch: pytest.MonkeyPatch) -> None:
+        # Test that the proxy environment variables are set correctly
+        monkeypatch.setenv("HTTPS_PROXY", "https://example.org")
+
+        client = DefaultAsyncHttpxClient()
+
+        mounts = tuple(client._mounts.items())
+        assert len(mounts) == 1
+        assert mounts[0][0].pattern == "https://"
+
+    @pytest.mark.filterwarnings("ignore:.*deprecated.*:DeprecationWarning")
+    async def test_default_client_creation(self) -> None:
+        # Ensure that the client can be initialized without any exceptions
+        DefaultAsyncHttpxClient(
+            verify=True,
+            cert=None,
+            trust_env=True,
+            http1=True,
+            http2=False,
+            limits=httpx.Limits(max_connections=100, max_keepalive_connections=20),
+        )
+
+    @pytest.mark.respx(base_url=base_url)
+    async def test_follow_redirects(self, respx_mock: MockRouter) -> None:
+        # Test that the default follow_redirects=True allows following redirects
+        respx_mock.post("/redirect").mock(
+            return_value=httpx.Response(302, headers={"Location": f"{base_url}/redirected"})
+        )
+        respx_mock.get("/redirected").mock(return_value=httpx.Response(200, json={"status": "ok"}))
+
+        response = await self.client.post("/redirect", body={"key": "value"}, cast_to=httpx.Response)
+        assert response.status_code == 200
+        assert response.json() == {"status": "ok"}
+
+    @pytest.mark.respx(base_url=base_url)
+    async def test_follow_redirects_disabled(self, respx_mock: MockRouter) -> None:
+        # Test that follow_redirects=False prevents following redirects
+        respx_mock.post("/redirect").mock(
+            return_value=httpx.Response(302, headers={"Location": f"{base_url}/redirected"})
+        )
+
+        with pytest.raises(APIStatusError) as exc_info:
+            await self.client.post(
+                "/redirect", body={"key": "value"}, options={"follow_redirects": False}, cast_to=httpx.Response
+            )
+
+        assert exc_info.value.response.status_code == 302
+        assert exc_info.value.response.headers["Location"] == f"{base_url}/redirected"
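The redirect tests above exercise the new `follow_redirects` request option end to end: redirects are followed by default, and passing `options={"follow_redirects": False}` surfaces the 3xx response as an `APIStatusError`. A usage sketch (the `/redirect` path is illustrative, and the client is assumed to pick up its API key from the environment):

```python
import httpx

from llama_api_client import APIStatusError, LlamaAPIClient

client = LlamaAPIClient()

try:
    client.post(
        "/redirect",
        body={"key": "value"},
        # overrides the client-level default of follow_redirects=True
        options={"follow_redirects": False},
        cast_to=httpx.Response,
    )
except APIStatusError as err:
    # with redirects disabled, the 302 is raised rather than followed
    print(err.response.status_code, err.response.headers.get("Location"))
```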