diff --git a/.release-please-manifest.json b/.release-please-manifest.json
index 75ec52f..b44b287 100644
--- a/.release-please-manifest.json
+++ b/.release-please-manifest.json
@@ -1,3 +1,3 @@
{
- ".": "2.3.0"
+ ".": "2.4.0"
}
\ No newline at end of file
diff --git a/.stats.yml b/.stats.yml
index c0cf1c6..7f8b2d8 100644
--- a/.stats.yml
+++ b/.stats.yml
@@ -1,4 +1,4 @@
-configured_endpoints: 6
-openapi_spec_url: https://storage.googleapis.com/stainless-sdk-openapi-specs/channel3%2Fpublic-sdk-a86114ce098255360b65356eedfe9c93f9db44aa99cb90d8c36756d39c2c2de0.yml
-openapi_spec_hash: 113158785b160e8b67d66e2820137df8
+configured_endpoints: 5
+openapi_spec_url: https://storage.googleapis.com/stainless-sdk-openapi-specs/channel3%2Fpublic-sdk-3366fbfe5ea0c833c184c33d00d301e23e23c0cfa7398b0ebc34a90ab03f65fd.yml
+openapi_spec_hash: e428021f51d697d779a5ddd3ee7109b7
config_hash: 0ec132fef7cbcef12aebece85f2ef2b1
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 683c5ae..388dfc7 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,5 +1,24 @@
# Changelog
+## 2.4.0 (2025-11-10)
+
+Full Changelog: [v2.3.0...v2.4.0](https://github.com/channel3-ai/sdk-python/compare/v2.3.0...v2.4.0)
+
+### Features
+
+* **api:** api update ([f50613e](https://github.com/channel3-ai/sdk-python/commit/f50613e9cf5d2067b282896dfafcb73ef99ee0ae))
+
+
+### Bug Fixes
+
+* **client:** close streams without requiring full consumption ([4d44bb1](https://github.com/channel3-ai/sdk-python/commit/4d44bb1d045d7d2447bc305dd8463e27dd170175))
+
+
+### Chores
+
+* **internal/tests:** avoid race condition with implicit client cleanup ([ef7fe91](https://github.com/channel3-ai/sdk-python/commit/ef7fe91e5e5c14e15172642362aee07b96f19b3d))
+* **internal:** grammar fix (it's -> its) ([ae7ab14](https://github.com/channel3-ai/sdk-python/commit/ae7ab1494bc21b8b5aeefb91af3353b696949bcb))
+
## 2.3.0 (2025-10-28)
Full Changelog: [v2.2.1...v2.3.0](https://github.com/channel3-ai/sdk-python/compare/v2.2.1...v2.3.0)
diff --git a/api.md b/api.md
index 046c050..4359e79 100644
--- a/api.md
+++ b/api.md
@@ -33,13 +33,12 @@ Methods:
Types:
```python
-from channel3_sdk.types import Brand, BrandListResponse
+from channel3_sdk.types import Brand
```
Methods:
-- client.brands.retrieve(brand_id) -> Brand
-- client.brands.list(\*\*params) -> BrandListResponse
+- client.brands.list(\*\*params) -> Brand
# Enrich
diff --git a/pyproject.toml b/pyproject.toml
index 1ebbba5..3961e1f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "channel3_sdk"
-version = "2.3.0"
+version = "2.4.0"
description = "The official Python library for the channel3 API"
dynamic = ["readme"]
license = "Apache-2.0"
diff --git a/src/channel3_sdk/_streaming.py b/src/channel3_sdk/_streaming.py
index eefa638..e5f4e09 100644
--- a/src/channel3_sdk/_streaming.py
+++ b/src/channel3_sdk/_streaming.py
@@ -57,9 +57,8 @@ def __stream__(self) -> Iterator[_T]:
for sse in iterator:
yield process_data(data=sse.json(), cast_to=cast_to, response=response)
- # Ensure the entire stream is consumed
- for _sse in iterator:
- ...
+ # As we might not fully consume the response stream, we need to close it explicitly
+ response.close()
def __enter__(self) -> Self:
return self
@@ -121,9 +120,8 @@ async def __stream__(self) -> AsyncIterator[_T]:
async for sse in iterator:
yield process_data(data=sse.json(), cast_to=cast_to, response=response)
- # Ensure the entire stream is consumed
- async for _sse in iterator:
- ...
+ # As we might not fully consume the response stream, we need to close it explicitly
+ await response.aclose()
async def __aenter__(self) -> Self:
return self
diff --git a/src/channel3_sdk/_utils/_utils.py b/src/channel3_sdk/_utils/_utils.py
index 50d5926..eec7f4a 100644
--- a/src/channel3_sdk/_utils/_utils.py
+++ b/src/channel3_sdk/_utils/_utils.py
@@ -133,7 +133,7 @@ def is_given(obj: _T | NotGiven | Omit) -> TypeGuard[_T]:
# Type safe methods for narrowing types with TypeVars.
# The default narrowing for isinstance(obj, dict) is dict[unknown, unknown],
# however this cause Pyright to rightfully report errors. As we know we don't
-# care about the contained types we can safely use `object` in it's place.
+# care about the contained types we can safely use `object` in its place.
#
# There are two separate functions defined, `is_*` and `is_*_t` for different use cases.
# `is_*` is for when you're dealing with an unknown input
diff --git a/src/channel3_sdk/_version.py b/src/channel3_sdk/_version.py
index 1fac827..66f444f 100644
--- a/src/channel3_sdk/_version.py
+++ b/src/channel3_sdk/_version.py
@@ -1,4 +1,4 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
__title__ = "channel3_sdk"
-__version__ = "2.3.0" # x-release-please-version
+__version__ = "2.4.0" # x-release-please-version
diff --git a/src/channel3_sdk/resources/brands.py b/src/channel3_sdk/resources/brands.py
index 306db20..57aa10f 100644
--- a/src/channel3_sdk/resources/brands.py
+++ b/src/channel3_sdk/resources/brands.py
@@ -2,12 +2,10 @@
from __future__ import annotations
-from typing import Optional
-
import httpx
from ..types import brand_list_params
-from .._types import Body, Omit, Query, Headers, NotGiven, omit, not_given
+from .._types import Body, Query, Headers, NotGiven, not_given
from .._utils import maybe_transform, async_maybe_transform
from .._compat import cached_property
from .._resource import SyncAPIResource, AsyncAPIResource
@@ -19,7 +17,6 @@
)
from ..types.brand import Brand
from .._base_client import make_request_options
-from ..types.brand_list_response import BrandListResponse
__all__ = ["BrandsResource", "AsyncBrandsResource"]
@@ -44,54 +41,19 @@ def with_streaming_response(self) -> BrandsResourceWithStreamingResponse:
"""
return BrandsResourceWithStreamingResponse(self)
- def retrieve(
- self,
- brand_id: str,
- *,
- # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
- # The extra values given here take precedence over values defined on the client or passed to this method.
- extra_headers: Headers | None = None,
- extra_query: Query | None = None,
- extra_body: Body | None = None,
- timeout: float | httpx.Timeout | None | NotGiven = not_given,
- ) -> Brand:
- """
- Get detailed information for a specific brand by its ID.
-
- Args:
- extra_headers: Send extra headers
-
- extra_query: Add additional query parameters to the request
-
- extra_body: Add additional JSON properties to the request
-
- timeout: Override the client-level default timeout for this request, in seconds
- """
- if not brand_id:
- raise ValueError(f"Expected a non-empty value for `brand_id` but received {brand_id!r}")
- return self._get(
- f"/v0/brands/{brand_id}",
- options=make_request_options(
- extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
- ),
- cast_to=Brand,
- )
-
def list(
self,
*,
- page: int | Omit = omit,
- query: Optional[str] | Omit = omit,
- size: int | Omit = omit,
+ query: str,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
extra_query: Query | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = not_given,
- ) -> BrandListResponse:
+ ) -> Brand:
"""
- Get all brands that the vendor currently sells.
+ Find a brand by name.
Args:
extra_headers: Send extra headers
@@ -109,16 +71,9 @@ def list(
extra_query=extra_query,
extra_body=extra_body,
timeout=timeout,
- query=maybe_transform(
- {
- "page": page,
- "query": query,
- "size": size,
- },
- brand_list_params.BrandListParams,
- ),
+ query=maybe_transform({"query": query}, brand_list_params.BrandListParams),
),
- cast_to=BrandListResponse,
+ cast_to=Brand,
)
@@ -142,54 +97,19 @@ def with_streaming_response(self) -> AsyncBrandsResourceWithStreamingResponse:
"""
return AsyncBrandsResourceWithStreamingResponse(self)
- async def retrieve(
- self,
- brand_id: str,
- *,
- # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
- # The extra values given here take precedence over values defined on the client or passed to this method.
- extra_headers: Headers | None = None,
- extra_query: Query | None = None,
- extra_body: Body | None = None,
- timeout: float | httpx.Timeout | None | NotGiven = not_given,
- ) -> Brand:
- """
- Get detailed information for a specific brand by its ID.
-
- Args:
- extra_headers: Send extra headers
-
- extra_query: Add additional query parameters to the request
-
- extra_body: Add additional JSON properties to the request
-
- timeout: Override the client-level default timeout for this request, in seconds
- """
- if not brand_id:
- raise ValueError(f"Expected a non-empty value for `brand_id` but received {brand_id!r}")
- return await self._get(
- f"/v0/brands/{brand_id}",
- options=make_request_options(
- extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
- ),
- cast_to=Brand,
- )
-
async def list(
self,
*,
- page: int | Omit = omit,
- query: Optional[str] | Omit = omit,
- size: int | Omit = omit,
+ query: str,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
extra_query: Query | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = not_given,
- ) -> BrandListResponse:
+ ) -> Brand:
"""
- Get all brands that the vendor currently sells.
+ Find a brand by name.
Args:
extra_headers: Send extra headers
@@ -207,16 +127,9 @@ async def list(
extra_query=extra_query,
extra_body=extra_body,
timeout=timeout,
- query=await async_maybe_transform(
- {
- "page": page,
- "query": query,
- "size": size,
- },
- brand_list_params.BrandListParams,
- ),
+ query=await async_maybe_transform({"query": query}, brand_list_params.BrandListParams),
),
- cast_to=BrandListResponse,
+ cast_to=Brand,
)
@@ -224,9 +137,6 @@ class BrandsResourceWithRawResponse:
def __init__(self, brands: BrandsResource) -> None:
self._brands = brands
- self.retrieve = to_raw_response_wrapper(
- brands.retrieve,
- )
self.list = to_raw_response_wrapper(
brands.list,
)
@@ -236,9 +146,6 @@ class AsyncBrandsResourceWithRawResponse:
def __init__(self, brands: AsyncBrandsResource) -> None:
self._brands = brands
- self.retrieve = async_to_raw_response_wrapper(
- brands.retrieve,
- )
self.list = async_to_raw_response_wrapper(
brands.list,
)
@@ -248,9 +155,6 @@ class BrandsResourceWithStreamingResponse:
def __init__(self, brands: BrandsResource) -> None:
self._brands = brands
- self.retrieve = to_streamed_response_wrapper(
- brands.retrieve,
- )
self.list = to_streamed_response_wrapper(
brands.list,
)
@@ -260,9 +164,6 @@ class AsyncBrandsResourceWithStreamingResponse:
def __init__(self, brands: AsyncBrandsResource) -> None:
self._brands = brands
- self.retrieve = async_to_streamed_response_wrapper(
- brands.retrieve,
- )
self.list = async_to_streamed_response_wrapper(
brands.list,
)
diff --git a/src/channel3_sdk/resources/enrich.py b/src/channel3_sdk/resources/enrich.py
index 518cf9a..c8c46af 100644
--- a/src/channel3_sdk/resources/enrich.py
+++ b/src/channel3_sdk/resources/enrich.py
@@ -53,7 +53,8 @@ def enrich_url(
timeout: float | httpx.Timeout | None | NotGiven = not_given,
) -> EnrichEnrichURLResponse:
"""
- Enrich a product URL with additional information.
+ Search by product URL, get back full product information from Channel3’s product
+ database.
Args:
url: The URL of the product to enrich
@@ -108,7 +109,8 @@ async def enrich_url(
timeout: float | httpx.Timeout | None | NotGiven = not_given,
) -> EnrichEnrichURLResponse:
"""
- Enrich a product URL with additional information.
+ Search by product URL, get back full product information from Channel3’s product
+ database.
Args:
url: The URL of the product to enrich
diff --git a/src/channel3_sdk/resources/search.py b/src/channel3_sdk/resources/search.py
index ca0bf09..8d2f327 100644
--- a/src/channel3_sdk/resources/search.py
+++ b/src/channel3_sdk/resources/search.py
@@ -70,7 +70,8 @@ def perform(
context: Optional customer information to personalize search results
- filters: Optional filters
+ filters: Optional filters. Search will only consider products that match all of the
+ filters.
image_url: Image URL
@@ -154,7 +155,8 @@ async def perform(
context: Optional customer information to personalize search results
- filters: Optional filters
+ filters: Optional filters. Search will only consider products that match all of the
+ filters.
image_url: Image URL
diff --git a/src/channel3_sdk/types/__init__.py b/src/channel3_sdk/types/__init__.py
index cf95e97..a520e33 100644
--- a/src/channel3_sdk/types/__init__.py
+++ b/src/channel3_sdk/types/__init__.py
@@ -7,7 +7,6 @@
from .variant import Variant as Variant
from .brand_list_params import BrandListParams as BrandListParams
from .availability_status import AvailabilityStatus as AvailabilityStatus
-from .brand_list_response import BrandListResponse as BrandListResponse
from .search_perform_params import SearchPerformParams as SearchPerformParams
from .search_perform_response import SearchPerformResponse as SearchPerformResponse
from .enrich_enrich_url_params import EnrichEnrichURLParams as EnrichEnrichURLParams
diff --git a/src/channel3_sdk/types/brand.py b/src/channel3_sdk/types/brand.py
index ac624c5..b2c1a33 100644
--- a/src/channel3_sdk/types/brand.py
+++ b/src/channel3_sdk/types/brand.py
@@ -12,6 +12,9 @@ class Brand(BaseModel):
name: str
+ best_commission_rate: Optional[float] = None
+ """The maximum commission rate for the brand, as a percentage"""
+
description: Optional[str] = None
logo_url: Optional[str] = None
diff --git a/src/channel3_sdk/types/brand_list_params.py b/src/channel3_sdk/types/brand_list_params.py
index fdbb5fe..9475030 100644
--- a/src/channel3_sdk/types/brand_list_params.py
+++ b/src/channel3_sdk/types/brand_list_params.py
@@ -2,15 +2,10 @@
from __future__ import annotations
-from typing import Optional
-from typing_extensions import TypedDict
+from typing_extensions import Required, TypedDict
__all__ = ["BrandListParams"]
class BrandListParams(TypedDict, total=False):
- page: int
-
- query: Optional[str]
-
- size: int
+ query: Required[str]
diff --git a/src/channel3_sdk/types/brand_list_response.py b/src/channel3_sdk/types/brand_list_response.py
deleted file mode 100644
index 8caf935..0000000
--- a/src/channel3_sdk/types/brand_list_response.py
+++ /dev/null
@@ -1,25 +0,0 @@
-# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
-
-from typing import List
-
-from .brand import Brand
-from .._models import BaseModel
-
-__all__ = ["BrandListResponse", "Pagination"]
-
-
-class Pagination(BaseModel):
- current_page: int
-
- page_size: int
-
- total_count: int
-
- total_pages: int
-
-
-class BrandListResponse(BaseModel):
- items: List[Brand]
-
- pagination: Pagination
- """Pagination metadata for responses"""
diff --git a/src/channel3_sdk/types/product_retrieve_response.py b/src/channel3_sdk/types/product_retrieve_response.py
index 393efa8..79f5244 100644
--- a/src/channel3_sdk/types/product_retrieve_response.py
+++ b/src/channel3_sdk/types/product_retrieve_response.py
@@ -26,6 +26,8 @@ class ProductRetrieveResponse(BaseModel):
brand_name: Optional[str] = None
+ categories: Optional[List[str]] = None
+
description: Optional[str] = None
gender: Optional[Literal["male", "female", "unisex"]] = None
diff --git a/src/channel3_sdk/types/search_perform_params.py b/src/channel3_sdk/types/search_perform_params.py
index 32794ad..fe34c78 100644
--- a/src/channel3_sdk/types/search_perform_params.py
+++ b/src/channel3_sdk/types/search_perform_params.py
@@ -22,7 +22,10 @@ class SearchPerformParams(TypedDict, total=False):
"""Optional customer information to personalize search results"""
filters: Filters
- """Optional filters"""
+ """Optional filters.
+
+ Search will only consider products that match all of the filters.
+ """
image_url: Optional[str]
"""Image URL"""
@@ -36,6 +39,13 @@ class SearchPerformParams(TypedDict, total=False):
class Config(TypedDict, total=False):
enrich_query: bool
+ """
+ If True, search will use AI to enrich the query, for example pulling the gender,
+ brand, and price range from the query.
+ """
+
+ monetizable_only: bool
+ """If True, search will only consider products that offer commission."""
redirect_mode: Optional[Literal["brand", "price", "commission"]]
"""
@@ -44,8 +54,6 @@ class Config(TypedDict, total=False):
to the brand's product page
"""
- semantic_search: bool
-
class FiltersPrice(TypedDict, total=False):
max_price: Optional[float]
@@ -57,15 +65,28 @@ class FiltersPrice(TypedDict, total=False):
class Filters(TypedDict, total=False):
availability: Optional[List[AvailabilityStatus]]
- """List of availability statuses"""
+ """If provided, only products with these availability statuses will be returned"""
brand_ids: Optional[SequenceNotStr[str]]
- """List of brand IDs"""
+ """If provided, only products from these brands will be returned"""
+
+ category_ids: Optional[SequenceNotStr[str]]
+ """If provided, only products from these categories will be returned"""
+
+ condition: Optional[Literal["new", "refurbished", "used"]]
+ """Filter by product condition.
+
+ Incubating: condition data is currently incomplete; products without condition
+ data will be included in all condition filter results.
+ """
exclude_product_ids: Optional[SequenceNotStr[str]]
- """List of product IDs to exclude"""
+ """If provided, products with these IDs will be excluded from the results"""
gender: Optional[Literal["male", "female", "unisex"]]
price: Optional[FiltersPrice]
"""Price filter. Values are inclusive."""
+
+ website_ids: Optional[SequenceNotStr[str]]
+ """If provided, only products from these websites will be returned"""
diff --git a/src/channel3_sdk/types/search_perform_response.py b/src/channel3_sdk/types/search_perform_response.py
index 45c5620..3e257fe 100644
--- a/src/channel3_sdk/types/search_perform_response.py
+++ b/src/channel3_sdk/types/search_perform_response.py
@@ -28,6 +28,8 @@ class SearchPerformResponseItem(BaseModel):
url: str
+ categories: Optional[List[str]] = None
+
description: Optional[str] = None
variants: Optional[List[Variant]] = None
diff --git a/tests/api_resources/test_brands.py b/tests/api_resources/test_brands.py
index 7cc8fb5..bb6b914 100644
--- a/tests/api_resources/test_brands.py
+++ b/tests/api_resources/test_brands.py
@@ -9,7 +9,7 @@
from tests.utils import assert_matches_type
from channel3_sdk import Channel3, AsyncChannel3
-from channel3_sdk.types import Brand, BrandListResponse
+from channel3_sdk.types import Brand
base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010")
@@ -17,83 +17,37 @@
class TestBrands:
parametrize = pytest.mark.parametrize("client", [False, True], indirect=True, ids=["loose", "strict"])
- @pytest.mark.skip(reason="Prism tests are disabled")
- @parametrize
- def test_method_retrieve(self, client: Channel3) -> None:
- brand = client.brands.retrieve(
- "brand_id",
- )
- assert_matches_type(Brand, brand, path=["response"])
-
- @pytest.mark.skip(reason="Prism tests are disabled")
- @parametrize
- def test_raw_response_retrieve(self, client: Channel3) -> None:
- response = client.brands.with_raw_response.retrieve(
- "brand_id",
- )
-
- assert response.is_closed is True
- assert response.http_request.headers.get("X-Stainless-Lang") == "python"
- brand = response.parse()
- assert_matches_type(Brand, brand, path=["response"])
-
- @pytest.mark.skip(reason="Prism tests are disabled")
- @parametrize
- def test_streaming_response_retrieve(self, client: Channel3) -> None:
- with client.brands.with_streaming_response.retrieve(
- "brand_id",
- ) as response:
- assert not response.is_closed
- assert response.http_request.headers.get("X-Stainless-Lang") == "python"
-
- brand = response.parse()
- assert_matches_type(Brand, brand, path=["response"])
-
- assert cast(Any, response.is_closed) is True
-
- @pytest.mark.skip(reason="Prism tests are disabled")
- @parametrize
- def test_path_params_retrieve(self, client: Channel3) -> None:
- with pytest.raises(ValueError, match=r"Expected a non-empty value for `brand_id` but received ''"):
- client.brands.with_raw_response.retrieve(
- "",
- )
-
@pytest.mark.skip(reason="Prism tests are disabled")
@parametrize
def test_method_list(self, client: Channel3) -> None:
- brand = client.brands.list()
- assert_matches_type(BrandListResponse, brand, path=["response"])
-
- @pytest.mark.skip(reason="Prism tests are disabled")
- @parametrize
- def test_method_list_with_all_params(self, client: Channel3) -> None:
brand = client.brands.list(
- page=0,
query="query",
- size=0,
)
- assert_matches_type(BrandListResponse, brand, path=["response"])
+ assert_matches_type(Brand, brand, path=["response"])
@pytest.mark.skip(reason="Prism tests are disabled")
@parametrize
def test_raw_response_list(self, client: Channel3) -> None:
- response = client.brands.with_raw_response.list()
+ response = client.brands.with_raw_response.list(
+ query="query",
+ )
assert response.is_closed is True
assert response.http_request.headers.get("X-Stainless-Lang") == "python"
brand = response.parse()
- assert_matches_type(BrandListResponse, brand, path=["response"])
+ assert_matches_type(Brand, brand, path=["response"])
@pytest.mark.skip(reason="Prism tests are disabled")
@parametrize
def test_streaming_response_list(self, client: Channel3) -> None:
- with client.brands.with_streaming_response.list() as response:
+ with client.brands.with_streaming_response.list(
+ query="query",
+ ) as response:
assert not response.is_closed
assert response.http_request.headers.get("X-Stainless-Lang") == "python"
brand = response.parse()
- assert_matches_type(BrandListResponse, brand, path=["response"])
+ assert_matches_type(Brand, brand, path=["response"])
assert cast(Any, response.is_closed) is True
@@ -103,82 +57,36 @@ class TestAsyncBrands:
"async_client", [False, True, {"http_client": "aiohttp"}], indirect=True, ids=["loose", "strict", "aiohttp"]
)
- @pytest.mark.skip(reason="Prism tests are disabled")
- @parametrize
- async def test_method_retrieve(self, async_client: AsyncChannel3) -> None:
- brand = await async_client.brands.retrieve(
- "brand_id",
- )
- assert_matches_type(Brand, brand, path=["response"])
-
- @pytest.mark.skip(reason="Prism tests are disabled")
- @parametrize
- async def test_raw_response_retrieve(self, async_client: AsyncChannel3) -> None:
- response = await async_client.brands.with_raw_response.retrieve(
- "brand_id",
- )
-
- assert response.is_closed is True
- assert response.http_request.headers.get("X-Stainless-Lang") == "python"
- brand = await response.parse()
- assert_matches_type(Brand, brand, path=["response"])
-
- @pytest.mark.skip(reason="Prism tests are disabled")
- @parametrize
- async def test_streaming_response_retrieve(self, async_client: AsyncChannel3) -> None:
- async with async_client.brands.with_streaming_response.retrieve(
- "brand_id",
- ) as response:
- assert not response.is_closed
- assert response.http_request.headers.get("X-Stainless-Lang") == "python"
-
- brand = await response.parse()
- assert_matches_type(Brand, brand, path=["response"])
-
- assert cast(Any, response.is_closed) is True
-
- @pytest.mark.skip(reason="Prism tests are disabled")
- @parametrize
- async def test_path_params_retrieve(self, async_client: AsyncChannel3) -> None:
- with pytest.raises(ValueError, match=r"Expected a non-empty value for `brand_id` but received ''"):
- await async_client.brands.with_raw_response.retrieve(
- "",
- )
-
@pytest.mark.skip(reason="Prism tests are disabled")
@parametrize
async def test_method_list(self, async_client: AsyncChannel3) -> None:
- brand = await async_client.brands.list()
- assert_matches_type(BrandListResponse, brand, path=["response"])
-
- @pytest.mark.skip(reason="Prism tests are disabled")
- @parametrize
- async def test_method_list_with_all_params(self, async_client: AsyncChannel3) -> None:
brand = await async_client.brands.list(
- page=0,
query="query",
- size=0,
)
- assert_matches_type(BrandListResponse, brand, path=["response"])
+ assert_matches_type(Brand, brand, path=["response"])
@pytest.mark.skip(reason="Prism tests are disabled")
@parametrize
async def test_raw_response_list(self, async_client: AsyncChannel3) -> None:
- response = await async_client.brands.with_raw_response.list()
+ response = await async_client.brands.with_raw_response.list(
+ query="query",
+ )
assert response.is_closed is True
assert response.http_request.headers.get("X-Stainless-Lang") == "python"
brand = await response.parse()
- assert_matches_type(BrandListResponse, brand, path=["response"])
+ assert_matches_type(Brand, brand, path=["response"])
@pytest.mark.skip(reason="Prism tests are disabled")
@parametrize
async def test_streaming_response_list(self, async_client: AsyncChannel3) -> None:
- async with async_client.brands.with_streaming_response.list() as response:
+ async with async_client.brands.with_streaming_response.list(
+ query="query",
+ ) as response:
assert not response.is_closed
assert response.http_request.headers.get("X-Stainless-Lang") == "python"
brand = await response.parse()
- assert_matches_type(BrandListResponse, brand, path=["response"])
+ assert_matches_type(Brand, brand, path=["response"])
assert cast(Any, response.is_closed) is True
diff --git a/tests/api_resources/test_search.py b/tests/api_resources/test_search.py
index 7c9d28d..fa4236f 100644
--- a/tests/api_resources/test_search.py
+++ b/tests/api_resources/test_search.py
@@ -30,19 +30,22 @@ def test_method_perform_with_all_params(self, client: Channel3) -> None:
base64_image="base64_image",
config={
"enrich_query": True,
+ "monetizable_only": True,
"redirect_mode": "brand",
- "semantic_search": True,
},
context="context",
filters={
"availability": ["InStock"],
"brand_ids": ["string"],
+ "category_ids": ["string"],
+ "condition": "new",
"exclude_product_ids": ["string"],
"gender": "male",
"price": {
"max_price": 0,
"min_price": 0,
},
+ "website_ids": ["string"],
},
image_url="image_url",
limit=0,
@@ -91,19 +94,22 @@ async def test_method_perform_with_all_params(self, async_client: AsyncChannel3)
base64_image="base64_image",
config={
"enrich_query": True,
+ "monetizable_only": True,
"redirect_mode": "brand",
- "semantic_search": True,
},
context="context",
filters={
"availability": ["InStock"],
"brand_ids": ["string"],
+ "category_ids": ["string"],
+ "condition": "new",
"exclude_product_ids": ["string"],
"gender": "male",
"price": {
"max_price": 0,
"min_price": 0,
},
+ "website_ids": ["string"],
},
image_url="image_url",
limit=0,
diff --git a/tests/test_client.py b/tests/test_client.py
index 3227185..400fdc1 100644
--- a/tests/test_client.py
+++ b/tests/test_client.py
@@ -59,51 +59,49 @@ def _get_open_connections(client: Channel3 | AsyncChannel3) -> int:
class TestChannel3:
- client = Channel3(base_url=base_url, api_key=api_key, _strict_response_validation=True)
-
@pytest.mark.respx(base_url=base_url)
- def test_raw_response(self, respx_mock: MockRouter) -> None:
+ def test_raw_response(self, respx_mock: MockRouter, client: Channel3) -> None:
respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))
- response = self.client.post("/foo", cast_to=httpx.Response)
+ response = client.post("/foo", cast_to=httpx.Response)
assert response.status_code == 200
assert isinstance(response, httpx.Response)
assert response.json() == {"foo": "bar"}
@pytest.mark.respx(base_url=base_url)
- def test_raw_response_for_binary(self, respx_mock: MockRouter) -> None:
+ def test_raw_response_for_binary(self, respx_mock: MockRouter, client: Channel3) -> None:
respx_mock.post("/foo").mock(
return_value=httpx.Response(200, headers={"Content-Type": "application/binary"}, content='{"foo": "bar"}')
)
- response = self.client.post("/foo", cast_to=httpx.Response)
+ response = client.post("/foo", cast_to=httpx.Response)
assert response.status_code == 200
assert isinstance(response, httpx.Response)
assert response.json() == {"foo": "bar"}
- def test_copy(self) -> None:
- copied = self.client.copy()
- assert id(copied) != id(self.client)
+ def test_copy(self, client: Channel3) -> None:
+ copied = client.copy()
+ assert id(copied) != id(client)
- copied = self.client.copy(api_key="another My API Key")
+ copied = client.copy(api_key="another My API Key")
assert copied.api_key == "another My API Key"
- assert self.client.api_key == "My API Key"
+ assert client.api_key == "My API Key"
- def test_copy_default_options(self) -> None:
+ def test_copy_default_options(self, client: Channel3) -> None:
# options that have a default are overridden correctly
- copied = self.client.copy(max_retries=7)
+ copied = client.copy(max_retries=7)
assert copied.max_retries == 7
- assert self.client.max_retries == 2
+ assert client.max_retries == 2
copied2 = copied.copy(max_retries=6)
assert copied2.max_retries == 6
assert copied.max_retries == 7
# timeout
- assert isinstance(self.client.timeout, httpx.Timeout)
- copied = self.client.copy(timeout=None)
+ assert isinstance(client.timeout, httpx.Timeout)
+ copied = client.copy(timeout=None)
assert copied.timeout is None
- assert isinstance(self.client.timeout, httpx.Timeout)
+ assert isinstance(client.timeout, httpx.Timeout)
def test_copy_default_headers(self) -> None:
client = Channel3(
@@ -138,6 +136,7 @@ def test_copy_default_headers(self) -> None:
match="`default_headers` and `set_default_headers` arguments are mutually exclusive",
):
client.copy(set_default_headers={}, default_headers={"X-Foo": "Bar"})
+ client.close()
def test_copy_default_query(self) -> None:
client = Channel3(
@@ -175,13 +174,15 @@ def test_copy_default_query(self) -> None:
):
client.copy(set_default_query={}, default_query={"foo": "Bar"})
- def test_copy_signature(self) -> None:
+ client.close()
+
+ def test_copy_signature(self, client: Channel3) -> None:
# ensure the same parameters that can be passed to the client are defined in the `.copy()` method
init_signature = inspect.signature(
# mypy doesn't like that we access the `__init__` property.
- self.client.__init__, # type: ignore[misc]
+ client.__init__, # type: ignore[misc]
)
- copy_signature = inspect.signature(self.client.copy)
+ copy_signature = inspect.signature(client.copy)
exclude_params = {"transport", "proxies", "_strict_response_validation"}
for name in init_signature.parameters.keys():
@@ -192,12 +193,12 @@ def test_copy_signature(self) -> None:
assert copy_param is not None, f"copy() signature is missing the {name} param"
@pytest.mark.skipif(sys.version_info >= (3, 10), reason="fails because of a memory leak that started from 3.12")
- def test_copy_build_request(self) -> None:
+ def test_copy_build_request(self, client: Channel3) -> None:
options = FinalRequestOptions(method="get", url="/foo")
def build_request(options: FinalRequestOptions) -> None:
- client = self.client.copy()
- client._build_request(options)
+ client_copy = client.copy()
+ client_copy._build_request(options)
# ensure that the machinery is warmed up before tracing starts.
build_request(options)
@@ -254,14 +255,12 @@ def add_leak(leaks: list[tracemalloc.StatisticDiff], diff: tracemalloc.Statistic
print(frame)
raise AssertionError()
- def test_request_timeout(self) -> None:
- request = self.client._build_request(FinalRequestOptions(method="get", url="/foo"))
+ def test_request_timeout(self, client: Channel3) -> None:
+ request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
assert timeout == DEFAULT_TIMEOUT
- request = self.client._build_request(
- FinalRequestOptions(method="get", url="/foo", timeout=httpx.Timeout(100.0))
- )
+ request = client._build_request(FinalRequestOptions(method="get", url="/foo", timeout=httpx.Timeout(100.0)))
timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
assert timeout == httpx.Timeout(100.0)
@@ -274,6 +273,8 @@ def test_client_timeout_option(self) -> None:
timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
assert timeout == httpx.Timeout(0)
+ client.close()
+
def test_http_client_timeout_option(self) -> None:
# custom timeout given to the httpx client should be used
with httpx.Client(timeout=None) as http_client:
@@ -285,6 +286,8 @@ def test_http_client_timeout_option(self) -> None:
timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
assert timeout == httpx.Timeout(None)
+ client.close()
+
# no timeout given to the httpx client should not use the httpx default
with httpx.Client() as http_client:
client = Channel3(
@@ -295,6 +298,8 @@ def test_http_client_timeout_option(self) -> None:
timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
assert timeout == DEFAULT_TIMEOUT
+ client.close()
+
# explicitly passing the default timeout currently results in it being ignored
with httpx.Client(timeout=HTTPX_DEFAULT_TIMEOUT) as http_client:
client = Channel3(
@@ -305,6 +310,8 @@ def test_http_client_timeout_option(self) -> None:
timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
assert timeout == DEFAULT_TIMEOUT # our default
+ client.close()
+
async def test_invalid_http_client(self) -> None:
with pytest.raises(TypeError, match="Invalid `http_client` arg"):
async with httpx.AsyncClient() as http_client:
@@ -316,14 +323,14 @@ async def test_invalid_http_client(self) -> None:
)
def test_default_headers_option(self) -> None:
- client = Channel3(
+ test_client = Channel3(
base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"}
)
- request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
+ request = test_client._build_request(FinalRequestOptions(method="get", url="/foo"))
assert request.headers.get("x-foo") == "bar"
assert request.headers.get("x-stainless-lang") == "python"
- client2 = Channel3(
+ test_client2 = Channel3(
base_url=base_url,
api_key=api_key,
_strict_response_validation=True,
@@ -332,10 +339,13 @@ def test_default_headers_option(self) -> None:
"X-Stainless-Lang": "my-overriding-header",
},
)
- request = client2._build_request(FinalRequestOptions(method="get", url="/foo"))
+ request = test_client2._build_request(FinalRequestOptions(method="get", url="/foo"))
assert request.headers.get("x-foo") == "stainless"
assert request.headers.get("x-stainless-lang") == "my-overriding-header"
+ test_client.close()
+ test_client2.close()
+
def test_validate_headers(self) -> None:
client = Channel3(base_url=base_url, api_key=api_key, _strict_response_validation=True)
request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
@@ -364,8 +374,10 @@ def test_default_query_option(self) -> None:
url = httpx.URL(request.url)
assert dict(url.params) == {"foo": "baz", "query_param": "overridden"}
- def test_request_extra_json(self) -> None:
- request = self.client._build_request(
+ client.close()
+
+ def test_request_extra_json(self, client: Channel3) -> None:
+ request = client._build_request(
FinalRequestOptions(
method="post",
url="/foo",
@@ -376,7 +388,7 @@ def test_request_extra_json(self) -> None:
data = json.loads(request.content.decode("utf-8"))
assert data == {"foo": "bar", "baz": False}
- request = self.client._build_request(
+ request = client._build_request(
FinalRequestOptions(
method="post",
url="/foo",
@@ -387,7 +399,7 @@ def test_request_extra_json(self) -> None:
assert data == {"baz": False}
# `extra_json` takes priority over `json_data` when keys clash
- request = self.client._build_request(
+ request = client._build_request(
FinalRequestOptions(
method="post",
url="/foo",
@@ -398,8 +410,8 @@ def test_request_extra_json(self) -> None:
data = json.loads(request.content.decode("utf-8"))
assert data == {"foo": "bar", "baz": None}
- def test_request_extra_headers(self) -> None:
- request = self.client._build_request(
+ def test_request_extra_headers(self, client: Channel3) -> None:
+ request = client._build_request(
FinalRequestOptions(
method="post",
url="/foo",
@@ -409,7 +421,7 @@ def test_request_extra_headers(self) -> None:
assert request.headers.get("X-Foo") == "Foo"
# `extra_headers` takes priority over `default_headers` when keys clash
- request = self.client.with_options(default_headers={"X-Bar": "true"})._build_request(
+ request = client.with_options(default_headers={"X-Bar": "true"})._build_request(
FinalRequestOptions(
method="post",
url="/foo",
@@ -420,8 +432,8 @@ def test_request_extra_headers(self) -> None:
)
assert request.headers.get("X-Bar") == "false"
- def test_request_extra_query(self) -> None:
- request = self.client._build_request(
+ def test_request_extra_query(self, client: Channel3) -> None:
+ request = client._build_request(
FinalRequestOptions(
method="post",
url="/foo",
@@ -434,7 +446,7 @@ def test_request_extra_query(self) -> None:
assert params == {"my_query_param": "Foo"}
# if both `query` and `extra_query` are given, they are merged
- request = self.client._build_request(
+ request = client._build_request(
FinalRequestOptions(
method="post",
url="/foo",
@@ -448,7 +460,7 @@ def test_request_extra_query(self) -> None:
assert params == {"bar": "1", "foo": "2"}
# `extra_query` takes priority over `query` when keys clash
- request = self.client._build_request(
+ request = client._build_request(
FinalRequestOptions(
method="post",
url="/foo",
@@ -491,7 +503,7 @@ def test_multipart_repeating_array(self, client: Channel3) -> None:
]
@pytest.mark.respx(base_url=base_url)
- def test_basic_union_response(self, respx_mock: MockRouter) -> None:
+ def test_basic_union_response(self, respx_mock: MockRouter, client: Channel3) -> None:
class Model1(BaseModel):
name: str
@@ -500,12 +512,12 @@ class Model2(BaseModel):
respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))
- response = self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
+ response = client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
assert isinstance(response, Model2)
assert response.foo == "bar"
@pytest.mark.respx(base_url=base_url)
- def test_union_response_different_types(self, respx_mock: MockRouter) -> None:
+ def test_union_response_different_types(self, respx_mock: MockRouter, client: Channel3) -> None:
"""Union of objects with the same field name using a different type"""
class Model1(BaseModel):
@@ -516,18 +528,18 @@ class Model2(BaseModel):
respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))
- response = self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
+ response = client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
assert isinstance(response, Model2)
assert response.foo == "bar"
respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": 1}))
- response = self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
+ response = client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
assert isinstance(response, Model1)
assert response.foo == 1
@pytest.mark.respx(base_url=base_url)
- def test_non_application_json_content_type_for_json_data(self, respx_mock: MockRouter) -> None:
+ def test_non_application_json_content_type_for_json_data(self, respx_mock: MockRouter, client: Channel3) -> None:
"""
Response that sets Content-Type to something other than application/json but returns json data
"""
@@ -543,7 +555,7 @@ class Model(BaseModel):
)
)
- response = self.client.get("/foo", cast_to=Model)
+ response = client.get("/foo", cast_to=Model)
assert isinstance(response, Model)
assert response.foo == 2
@@ -555,6 +567,8 @@ def test_base_url_setter(self) -> None:
assert client.base_url == "https://example.com/from_setter/"
+ client.close()
+
def test_base_url_env(self) -> None:
with update_env(CHANNEL3_BASE_URL="http://localhost:5000/from/env"):
client = Channel3(api_key=api_key, _strict_response_validation=True)
@@ -582,6 +596,7 @@ def test_base_url_trailing_slash(self, client: Channel3) -> None:
),
)
assert request.url == "http://localhost:5000/custom/path/foo"
+ client.close()
@pytest.mark.parametrize(
"client",
@@ -605,6 +620,7 @@ def test_base_url_no_trailing_slash(self, client: Channel3) -> None:
),
)
assert request.url == "http://localhost:5000/custom/path/foo"
+ client.close()
@pytest.mark.parametrize(
"client",
@@ -628,35 +644,36 @@ def test_absolute_request_url(self, client: Channel3) -> None:
),
)
assert request.url == "https://myapi.com/foo"
+ client.close()
def test_copied_client_does_not_close_http(self) -> None:
- client = Channel3(base_url=base_url, api_key=api_key, _strict_response_validation=True)
- assert not client.is_closed()
+ test_client = Channel3(base_url=base_url, api_key=api_key, _strict_response_validation=True)
+ assert not test_client.is_closed()
- copied = client.copy()
- assert copied is not client
+ copied = test_client.copy()
+ assert copied is not test_client
del copied
- assert not client.is_closed()
+ assert not test_client.is_closed()
def test_client_context_manager(self) -> None:
- client = Channel3(base_url=base_url, api_key=api_key, _strict_response_validation=True)
- with client as c2:
- assert c2 is client
+ test_client = Channel3(base_url=base_url, api_key=api_key, _strict_response_validation=True)
+ with test_client as c2:
+ assert c2 is test_client
assert not c2.is_closed()
- assert not client.is_closed()
- assert client.is_closed()
+ assert not test_client.is_closed()
+ assert test_client.is_closed()
@pytest.mark.respx(base_url=base_url)
- def test_client_response_validation_error(self, respx_mock: MockRouter) -> None:
+ def test_client_response_validation_error(self, respx_mock: MockRouter, client: Channel3) -> None:
class Model(BaseModel):
foo: str
respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": {"invalid": True}}))
with pytest.raises(APIResponseValidationError) as exc:
- self.client.get("/foo", cast_to=Model)
+ client.get("/foo", cast_to=Model)
assert isinstance(exc.value.__cause__, ValidationError)
@@ -676,11 +693,14 @@ class Model(BaseModel):
with pytest.raises(APIResponseValidationError):
strict_client.get("/foo", cast_to=Model)
- client = Channel3(base_url=base_url, api_key=api_key, _strict_response_validation=False)
+ non_strict_client = Channel3(base_url=base_url, api_key=api_key, _strict_response_validation=False)
- response = client.get("/foo", cast_to=Model)
+ response = non_strict_client.get("/foo", cast_to=Model)
assert isinstance(response, str) # type: ignore[unreachable]
+ strict_client.close()
+ non_strict_client.close()
+
@pytest.mark.parametrize(
"remaining_retries,retry_after,timeout",
[
@@ -703,9 +723,9 @@ class Model(BaseModel):
],
)
@mock.patch("time.time", mock.MagicMock(return_value=1696004797))
- def test_parse_retry_after_header(self, remaining_retries: int, retry_after: str, timeout: float) -> None:
- client = Channel3(base_url=base_url, api_key=api_key, _strict_response_validation=True)
-
+ def test_parse_retry_after_header(
+ self, remaining_retries: int, retry_after: str, timeout: float, client: Channel3
+ ) -> None:
headers = httpx.Headers({"retry-after": retry_after})
options = FinalRequestOptions(method="get", url="/foo", max_retries=3)
calculated = client._calculate_retry_timeout(remaining_retries, options, headers)
@@ -719,7 +739,7 @@ def test_retrying_timeout_errors_doesnt_leak(self, respx_mock: MockRouter, clien
with pytest.raises(APITimeoutError):
client.search.with_streaming_response.perform().__enter__()
- assert _get_open_connections(self.client) == 0
+ assert _get_open_connections(client) == 0
@mock.patch("channel3_sdk._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
@pytest.mark.respx(base_url=base_url)
@@ -728,7 +748,7 @@ def test_retrying_status_errors_doesnt_leak(self, respx_mock: MockRouter, client
with pytest.raises(APIStatusError):
client.search.with_streaming_response.perform().__enter__()
- assert _get_open_connections(self.client) == 0
+ assert _get_open_connections(client) == 0
@pytest.mark.parametrize("failures_before_success", [0, 2, 4])
@mock.patch("channel3_sdk._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
@@ -830,83 +850,77 @@ def test_default_client_creation(self) -> None:
)
@pytest.mark.respx(base_url=base_url)
- def test_follow_redirects(self, respx_mock: MockRouter) -> None:
+ def test_follow_redirects(self, respx_mock: MockRouter, client: Channel3) -> None:
# Test that the default follow_redirects=True allows following redirects
respx_mock.post("/redirect").mock(
return_value=httpx.Response(302, headers={"Location": f"{base_url}/redirected"})
)
respx_mock.get("/redirected").mock(return_value=httpx.Response(200, json={"status": "ok"}))
- response = self.client.post("/redirect", body={"key": "value"}, cast_to=httpx.Response)
+ response = client.post("/redirect", body={"key": "value"}, cast_to=httpx.Response)
assert response.status_code == 200
assert response.json() == {"status": "ok"}
@pytest.mark.respx(base_url=base_url)
- def test_follow_redirects_disabled(self, respx_mock: MockRouter) -> None:
+ def test_follow_redirects_disabled(self, respx_mock: MockRouter, client: Channel3) -> None:
# Test that follow_redirects=False prevents following redirects
respx_mock.post("/redirect").mock(
return_value=httpx.Response(302, headers={"Location": f"{base_url}/redirected"})
)
with pytest.raises(APIStatusError) as exc_info:
- self.client.post(
- "/redirect", body={"key": "value"}, options={"follow_redirects": False}, cast_to=httpx.Response
- )
+ client.post("/redirect", body={"key": "value"}, options={"follow_redirects": False}, cast_to=httpx.Response)
assert exc_info.value.response.status_code == 302
assert exc_info.value.response.headers["Location"] == f"{base_url}/redirected"
class TestAsyncChannel3:
- client = AsyncChannel3(base_url=base_url, api_key=api_key, _strict_response_validation=True)
-
@pytest.mark.respx(base_url=base_url)
- @pytest.mark.asyncio
- async def test_raw_response(self, respx_mock: MockRouter) -> None:
+ async def test_raw_response(self, respx_mock: MockRouter, async_client: AsyncChannel3) -> None:
respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))
- response = await self.client.post("/foo", cast_to=httpx.Response)
+ response = await async_client.post("/foo", cast_to=httpx.Response)
assert response.status_code == 200
assert isinstance(response, httpx.Response)
assert response.json() == {"foo": "bar"}
@pytest.mark.respx(base_url=base_url)
- @pytest.mark.asyncio
- async def test_raw_response_for_binary(self, respx_mock: MockRouter) -> None:
+ async def test_raw_response_for_binary(self, respx_mock: MockRouter, async_client: AsyncChannel3) -> None:
respx_mock.post("/foo").mock(
return_value=httpx.Response(200, headers={"Content-Type": "application/binary"}, content='{"foo": "bar"}')
)
- response = await self.client.post("/foo", cast_to=httpx.Response)
+ response = await async_client.post("/foo", cast_to=httpx.Response)
assert response.status_code == 200
assert isinstance(response, httpx.Response)
assert response.json() == {"foo": "bar"}
- def test_copy(self) -> None:
- copied = self.client.copy()
- assert id(copied) != id(self.client)
+ def test_copy(self, async_client: AsyncChannel3) -> None:
+ copied = async_client.copy()
+ assert id(copied) != id(async_client)
- copied = self.client.copy(api_key="another My API Key")
+ copied = async_client.copy(api_key="another My API Key")
assert copied.api_key == "another My API Key"
- assert self.client.api_key == "My API Key"
+ assert async_client.api_key == "My API Key"
- def test_copy_default_options(self) -> None:
+ def test_copy_default_options(self, async_client: AsyncChannel3) -> None:
# options that have a default are overridden correctly
- copied = self.client.copy(max_retries=7)
+ copied = async_client.copy(max_retries=7)
assert copied.max_retries == 7
- assert self.client.max_retries == 2
+ assert async_client.max_retries == 2
copied2 = copied.copy(max_retries=6)
assert copied2.max_retries == 6
assert copied.max_retries == 7
# timeout
- assert isinstance(self.client.timeout, httpx.Timeout)
- copied = self.client.copy(timeout=None)
+ assert isinstance(async_client.timeout, httpx.Timeout)
+ copied = async_client.copy(timeout=None)
assert copied.timeout is None
- assert isinstance(self.client.timeout, httpx.Timeout)
+ assert isinstance(async_client.timeout, httpx.Timeout)
- def test_copy_default_headers(self) -> None:
+ async def test_copy_default_headers(self) -> None:
client = AsyncChannel3(
base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"}
)
@@ -939,8 +953,9 @@ def test_copy_default_headers(self) -> None:
match="`default_headers` and `set_default_headers` arguments are mutually exclusive",
):
client.copy(set_default_headers={}, default_headers={"X-Foo": "Bar"})
+ await client.close()
- def test_copy_default_query(self) -> None:
+ async def test_copy_default_query(self) -> None:
client = AsyncChannel3(
base_url=base_url, api_key=api_key, _strict_response_validation=True, default_query={"foo": "bar"}
)
@@ -976,13 +991,15 @@ def test_copy_default_query(self) -> None:
):
client.copy(set_default_query={}, default_query={"foo": "Bar"})
- def test_copy_signature(self) -> None:
+ await client.close()
+
+ def test_copy_signature(self, async_client: AsyncChannel3) -> None:
# ensure the same parameters that can be passed to the client are defined in the `.copy()` method
init_signature = inspect.signature(
# mypy doesn't like that we access the `__init__` property.
- self.client.__init__, # type: ignore[misc]
+ async_client.__init__, # type: ignore[misc]
)
- copy_signature = inspect.signature(self.client.copy)
+ copy_signature = inspect.signature(async_client.copy)
exclude_params = {"transport", "proxies", "_strict_response_validation"}
for name in init_signature.parameters.keys():
@@ -993,12 +1010,12 @@ def test_copy_signature(self) -> None:
assert copy_param is not None, f"copy() signature is missing the {name} param"
@pytest.mark.skipif(sys.version_info >= (3, 10), reason="fails because of a memory leak that started from 3.12")
- def test_copy_build_request(self) -> None:
+ def test_copy_build_request(self, async_client: AsyncChannel3) -> None:
options = FinalRequestOptions(method="get", url="/foo")
def build_request(options: FinalRequestOptions) -> None:
- client = self.client.copy()
- client._build_request(options)
+ client_copy = async_client.copy()
+ client_copy._build_request(options)
# ensure that the machinery is warmed up before tracing starts.
build_request(options)
@@ -1055,12 +1072,12 @@ def add_leak(leaks: list[tracemalloc.StatisticDiff], diff: tracemalloc.Statistic
print(frame)
raise AssertionError()
- async def test_request_timeout(self) -> None:
- request = self.client._build_request(FinalRequestOptions(method="get", url="/foo"))
+ async def test_request_timeout(self, async_client: AsyncChannel3) -> None:
+ request = async_client._build_request(FinalRequestOptions(method="get", url="/foo"))
timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
assert timeout == DEFAULT_TIMEOUT
- request = self.client._build_request(
+ request = async_client._build_request(
FinalRequestOptions(method="get", url="/foo", timeout=httpx.Timeout(100.0))
)
timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
@@ -1075,6 +1092,8 @@ async def test_client_timeout_option(self) -> None:
timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
assert timeout == httpx.Timeout(0)
+ await client.close()
+
async def test_http_client_timeout_option(self) -> None:
# custom timeout given to the httpx client should be used
async with httpx.AsyncClient(timeout=None) as http_client:
@@ -1086,6 +1105,8 @@ async def test_http_client_timeout_option(self) -> None:
timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
assert timeout == httpx.Timeout(None)
+ await client.close()
+
# no timeout given to the httpx client should not use the httpx default
async with httpx.AsyncClient() as http_client:
client = AsyncChannel3(
@@ -1096,6 +1117,8 @@ async def test_http_client_timeout_option(self) -> None:
timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
assert timeout == DEFAULT_TIMEOUT
+ await client.close()
+
# explicitly passing the default timeout currently results in it being ignored
async with httpx.AsyncClient(timeout=HTTPX_DEFAULT_TIMEOUT) as http_client:
client = AsyncChannel3(
@@ -1106,6 +1129,8 @@ async def test_http_client_timeout_option(self) -> None:
timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
assert timeout == DEFAULT_TIMEOUT # our default
+ await client.close()
+
def test_invalid_http_client(self) -> None:
with pytest.raises(TypeError, match="Invalid `http_client` arg"):
with httpx.Client() as http_client:
@@ -1116,15 +1141,15 @@ def test_invalid_http_client(self) -> None:
http_client=cast(Any, http_client),
)
- def test_default_headers_option(self) -> None:
- client = AsyncChannel3(
+ async def test_default_headers_option(self) -> None:
+ test_client = AsyncChannel3(
base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"}
)
- request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
+ request = test_client._build_request(FinalRequestOptions(method="get", url="/foo"))
assert request.headers.get("x-foo") == "bar"
assert request.headers.get("x-stainless-lang") == "python"
- client2 = AsyncChannel3(
+ test_client2 = AsyncChannel3(
base_url=base_url,
api_key=api_key,
_strict_response_validation=True,
@@ -1133,10 +1158,13 @@ def test_default_headers_option(self) -> None:
"X-Stainless-Lang": "my-overriding-header",
},
)
- request = client2._build_request(FinalRequestOptions(method="get", url="/foo"))
+ request = test_client2._build_request(FinalRequestOptions(method="get", url="/foo"))
assert request.headers.get("x-foo") == "stainless"
assert request.headers.get("x-stainless-lang") == "my-overriding-header"
+ await test_client.close()
+ await test_client2.close()
+
def test_validate_headers(self) -> None:
client = AsyncChannel3(base_url=base_url, api_key=api_key, _strict_response_validation=True)
request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
@@ -1147,7 +1175,7 @@ def test_validate_headers(self) -> None:
client2 = AsyncChannel3(base_url=base_url, api_key=None, _strict_response_validation=True)
_ = client2
- def test_default_query_option(self) -> None:
+ async def test_default_query_option(self) -> None:
client = AsyncChannel3(
base_url=base_url, api_key=api_key, _strict_response_validation=True, default_query={"query_param": "bar"}
)
@@ -1165,8 +1193,10 @@ def test_default_query_option(self) -> None:
url = httpx.URL(request.url)
assert dict(url.params) == {"foo": "baz", "query_param": "overridden"}
- def test_request_extra_json(self) -> None:
- request = self.client._build_request(
+ await client.close()
+
+ def test_request_extra_json(self, client: Channel3) -> None:
+ request = client._build_request(
FinalRequestOptions(
method="post",
url="/foo",
@@ -1177,7 +1207,7 @@ def test_request_extra_json(self) -> None:
data = json.loads(request.content.decode("utf-8"))
assert data == {"foo": "bar", "baz": False}
- request = self.client._build_request(
+ request = client._build_request(
FinalRequestOptions(
method="post",
url="/foo",
@@ -1188,7 +1218,7 @@ def test_request_extra_json(self) -> None:
assert data == {"baz": False}
# `extra_json` takes priority over `json_data` when keys clash
- request = self.client._build_request(
+ request = client._build_request(
FinalRequestOptions(
method="post",
url="/foo",
@@ -1199,8 +1229,8 @@ def test_request_extra_json(self) -> None:
data = json.loads(request.content.decode("utf-8"))
assert data == {"foo": "bar", "baz": None}
- def test_request_extra_headers(self) -> None:
- request = self.client._build_request(
+ def test_request_extra_headers(self, client: Channel3) -> None:
+ request = client._build_request(
FinalRequestOptions(
method="post",
url="/foo",
@@ -1210,7 +1240,7 @@ def test_request_extra_headers(self) -> None:
assert request.headers.get("X-Foo") == "Foo"
# `extra_headers` takes priority over `default_headers` when keys clash
- request = self.client.with_options(default_headers={"X-Bar": "true"})._build_request(
+ request = client.with_options(default_headers={"X-Bar": "true"})._build_request(
FinalRequestOptions(
method="post",
url="/foo",
@@ -1221,8 +1251,8 @@ def test_request_extra_headers(self) -> None:
)
assert request.headers.get("X-Bar") == "false"
- def test_request_extra_query(self) -> None:
- request = self.client._build_request(
+ def test_request_extra_query(self, client: Channel3) -> None:
+ request = client._build_request(
FinalRequestOptions(
method="post",
url="/foo",
@@ -1235,7 +1265,7 @@ def test_request_extra_query(self) -> None:
assert params == {"my_query_param": "Foo"}
# if both `query` and `extra_query` are given, they are merged
- request = self.client._build_request(
+ request = client._build_request(
FinalRequestOptions(
method="post",
url="/foo",
@@ -1249,7 +1279,7 @@ def test_request_extra_query(self) -> None:
assert params == {"bar": "1", "foo": "2"}
# `extra_query` takes priority over `query` when keys clash
- request = self.client._build_request(
+ request = client._build_request(
FinalRequestOptions(
method="post",
url="/foo",
@@ -1292,7 +1322,7 @@ def test_multipart_repeating_array(self, async_client: AsyncChannel3) -> None:
]
@pytest.mark.respx(base_url=base_url)
- async def test_basic_union_response(self, respx_mock: MockRouter) -> None:
+ async def test_basic_union_response(self, respx_mock: MockRouter, async_client: AsyncChannel3) -> None:
class Model1(BaseModel):
name: str
@@ -1301,12 +1331,12 @@ class Model2(BaseModel):
respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))
- response = await self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
+ response = await async_client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
assert isinstance(response, Model2)
assert response.foo == "bar"
@pytest.mark.respx(base_url=base_url)
- async def test_union_response_different_types(self, respx_mock: MockRouter) -> None:
+ async def test_union_response_different_types(self, respx_mock: MockRouter, async_client: AsyncChannel3) -> None:
"""Union of objects with the same field name using a different type"""
class Model1(BaseModel):
@@ -1317,18 +1347,20 @@ class Model2(BaseModel):
respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))
- response = await self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
+ response = await async_client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
assert isinstance(response, Model2)
assert response.foo == "bar"
respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": 1}))
- response = await self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
+ response = await async_client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
assert isinstance(response, Model1)
assert response.foo == 1
@pytest.mark.respx(base_url=base_url)
- async def test_non_application_json_content_type_for_json_data(self, respx_mock: MockRouter) -> None:
+ async def test_non_application_json_content_type_for_json_data(
+ self, respx_mock: MockRouter, async_client: AsyncChannel3
+ ) -> None:
"""
Response that sets Content-Type to something other than application/json but returns json data
"""
@@ -1344,11 +1376,11 @@ class Model(BaseModel):
)
)
- response = await self.client.get("/foo", cast_to=Model)
+ response = await async_client.get("/foo", cast_to=Model)
assert isinstance(response, Model)
assert response.foo == 2
- def test_base_url_setter(self) -> None:
+ async def test_base_url_setter(self) -> None:
client = AsyncChannel3(
base_url="https://example.com/from_init", api_key=api_key, _strict_response_validation=True
)
@@ -1358,7 +1390,9 @@ def test_base_url_setter(self) -> None:
assert client.base_url == "https://example.com/from_setter/"
- def test_base_url_env(self) -> None:
+ await client.close()
+
+ async def test_base_url_env(self) -> None:
with update_env(CHANNEL3_BASE_URL="http://localhost:5000/from/env"):
client = AsyncChannel3(api_key=api_key, _strict_response_validation=True)
assert client.base_url == "http://localhost:5000/from/env/"
@@ -1378,7 +1412,7 @@ def test_base_url_env(self) -> None:
],
ids=["standard", "custom http client"],
)
- def test_base_url_trailing_slash(self, client: AsyncChannel3) -> None:
+ async def test_base_url_trailing_slash(self, client: AsyncChannel3) -> None:
request = client._build_request(
FinalRequestOptions(
method="post",
@@ -1387,6 +1421,7 @@ def test_base_url_trailing_slash(self, client: AsyncChannel3) -> None:
),
)
assert request.url == "http://localhost:5000/custom/path/foo"
+ await client.close()
@pytest.mark.parametrize(
"client",
@@ -1403,7 +1438,7 @@ def test_base_url_trailing_slash(self, client: AsyncChannel3) -> None:
],
ids=["standard", "custom http client"],
)
- def test_base_url_no_trailing_slash(self, client: AsyncChannel3) -> None:
+ async def test_base_url_no_trailing_slash(self, client: AsyncChannel3) -> None:
request = client._build_request(
FinalRequestOptions(
method="post",
@@ -1412,6 +1447,7 @@ def test_base_url_no_trailing_slash(self, client: AsyncChannel3) -> None:
),
)
assert request.url == "http://localhost:5000/custom/path/foo"
+ await client.close()
@pytest.mark.parametrize(
"client",
@@ -1428,7 +1464,7 @@ def test_base_url_no_trailing_slash(self, client: AsyncChannel3) -> None:
],
ids=["standard", "custom http client"],
)
- def test_absolute_request_url(self, client: AsyncChannel3) -> None:
+ async def test_absolute_request_url(self, client: AsyncChannel3) -> None:
request = client._build_request(
FinalRequestOptions(
method="post",
@@ -1437,37 +1473,37 @@ def test_absolute_request_url(self, client: AsyncChannel3) -> None:
),
)
assert request.url == "https://myapi.com/foo"
+ await client.close()
async def test_copied_client_does_not_close_http(self) -> None:
- client = AsyncChannel3(base_url=base_url, api_key=api_key, _strict_response_validation=True)
- assert not client.is_closed()
+ test_client = AsyncChannel3(base_url=base_url, api_key=api_key, _strict_response_validation=True)
+ assert not test_client.is_closed()
- copied = client.copy()
- assert copied is not client
+ copied = test_client.copy()
+ assert copied is not test_client
del copied
await asyncio.sleep(0.2)
- assert not client.is_closed()
+ assert not test_client.is_closed()
async def test_client_context_manager(self) -> None:
- client = AsyncChannel3(base_url=base_url, api_key=api_key, _strict_response_validation=True)
- async with client as c2:
- assert c2 is client
+ test_client = AsyncChannel3(base_url=base_url, api_key=api_key, _strict_response_validation=True)
+ async with test_client as c2:
+ assert c2 is test_client
assert not c2.is_closed()
- assert not client.is_closed()
- assert client.is_closed()
+ assert not test_client.is_closed()
+ assert test_client.is_closed()
@pytest.mark.respx(base_url=base_url)
- @pytest.mark.asyncio
- async def test_client_response_validation_error(self, respx_mock: MockRouter) -> None:
+ async def test_client_response_validation_error(self, respx_mock: MockRouter, async_client: AsyncChannel3) -> None:
class Model(BaseModel):
foo: str
respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": {"invalid": True}}))
with pytest.raises(APIResponseValidationError) as exc:
- await self.client.get("/foo", cast_to=Model)
+ await async_client.get("/foo", cast_to=Model)
assert isinstance(exc.value.__cause__, ValidationError)
@@ -1478,7 +1514,6 @@ async def test_client_max_retries_validation(self) -> None:
)
@pytest.mark.respx(base_url=base_url)
- @pytest.mark.asyncio
async def test_received_text_for_expected_json(self, respx_mock: MockRouter) -> None:
class Model(BaseModel):
name: str
@@ -1490,11 +1525,14 @@ class Model(BaseModel):
with pytest.raises(APIResponseValidationError):
await strict_client.get("/foo", cast_to=Model)
- client = AsyncChannel3(base_url=base_url, api_key=api_key, _strict_response_validation=False)
+ non_strict_client = AsyncChannel3(base_url=base_url, api_key=api_key, _strict_response_validation=False)
- response = await client.get("/foo", cast_to=Model)
+ response = await non_strict_client.get("/foo", cast_to=Model)
assert isinstance(response, str) # type: ignore[unreachable]
+ await strict_client.close()
+ await non_strict_client.close()
+
@pytest.mark.parametrize(
"remaining_retries,retry_after,timeout",
[
@@ -1517,13 +1555,12 @@ class Model(BaseModel):
],
)
@mock.patch("time.time", mock.MagicMock(return_value=1696004797))
- @pytest.mark.asyncio
- async def test_parse_retry_after_header(self, remaining_retries: int, retry_after: str, timeout: float) -> None:
- client = AsyncChannel3(base_url=base_url, api_key=api_key, _strict_response_validation=True)
-
+ async def test_parse_retry_after_header(
+ self, remaining_retries: int, retry_after: str, timeout: float, async_client: AsyncChannel3
+ ) -> None:
headers = httpx.Headers({"retry-after": retry_after})
options = FinalRequestOptions(method="get", url="/foo", max_retries=3)
- calculated = client._calculate_retry_timeout(remaining_retries, options, headers)
+ calculated = async_client._calculate_retry_timeout(remaining_retries, options, headers)
assert calculated == pytest.approx(timeout, 0.5 * 0.875) # pyright: ignore[reportUnknownMemberType]
@mock.patch("channel3_sdk._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
@@ -1536,7 +1573,7 @@ async def test_retrying_timeout_errors_doesnt_leak(
with pytest.raises(APITimeoutError):
await async_client.search.with_streaming_response.perform().__aenter__()
- assert _get_open_connections(self.client) == 0
+ assert _get_open_connections(async_client) == 0
@mock.patch("channel3_sdk._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
@pytest.mark.respx(base_url=base_url)
@@ -1547,12 +1584,11 @@ async def test_retrying_status_errors_doesnt_leak(
with pytest.raises(APIStatusError):
await async_client.search.with_streaming_response.perform().__aenter__()
- assert _get_open_connections(self.client) == 0
+ assert _get_open_connections(async_client) == 0
@pytest.mark.parametrize("failures_before_success", [0, 2, 4])
@mock.patch("channel3_sdk._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
@pytest.mark.respx(base_url=base_url)
- @pytest.mark.asyncio
@pytest.mark.parametrize("failure_mode", ["status", "exception"])
async def test_retries_taken(
self,
@@ -1584,7 +1620,6 @@ def retry_handler(_request: httpx.Request) -> httpx.Response:
@pytest.mark.parametrize("failures_before_success", [0, 2, 4])
@mock.patch("channel3_sdk._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
@pytest.mark.respx(base_url=base_url)
- @pytest.mark.asyncio
async def test_omit_retry_count_header(
self, async_client: AsyncChannel3, failures_before_success: int, respx_mock: MockRouter
) -> None:
@@ -1608,7 +1643,6 @@ def retry_handler(_request: httpx.Request) -> httpx.Response:
@pytest.mark.parametrize("failures_before_success", [0, 2, 4])
@mock.patch("channel3_sdk._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
@pytest.mark.respx(base_url=base_url)
- @pytest.mark.asyncio
async def test_overwrite_retry_count_header(
self, async_client: AsyncChannel3, failures_before_success: int, respx_mock: MockRouter
) -> None:
@@ -1656,26 +1690,26 @@ async def test_default_client_creation(self) -> None:
)
@pytest.mark.respx(base_url=base_url)
- async def test_follow_redirects(self, respx_mock: MockRouter) -> None:
+ async def test_follow_redirects(self, respx_mock: MockRouter, async_client: AsyncChannel3) -> None:
# Test that the default follow_redirects=True allows following redirects
respx_mock.post("/redirect").mock(
return_value=httpx.Response(302, headers={"Location": f"{base_url}/redirected"})
)
respx_mock.get("/redirected").mock(return_value=httpx.Response(200, json={"status": "ok"}))
- response = await self.client.post("/redirect", body={"key": "value"}, cast_to=httpx.Response)
+ response = await async_client.post("/redirect", body={"key": "value"}, cast_to=httpx.Response)
assert response.status_code == 200
assert response.json() == {"status": "ok"}
@pytest.mark.respx(base_url=base_url)
- async def test_follow_redirects_disabled(self, respx_mock: MockRouter) -> None:
+ async def test_follow_redirects_disabled(self, respx_mock: MockRouter, async_client: AsyncChannel3) -> None:
# Test that follow_redirects=False prevents following redirects
respx_mock.post("/redirect").mock(
return_value=httpx.Response(302, headers={"Location": f"{base_url}/redirected"})
)
with pytest.raises(APIStatusError) as exc_info:
- await self.client.post(
+ await async_client.post(
"/redirect", body={"key": "value"}, options={"follow_redirects": False}, cast_to=httpx.Response
)