diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index c6b2570e..b88cdcbc 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -14,7 +14,7 @@ jobs: steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Install Rye run: | diff --git a/.github/workflows/publish-pypi.yml b/.github/workflows/publish-pypi.yml index ed786f39..81fea1d1 100644 --- a/.github/workflows/publish-pypi.yml +++ b/.github/workflows/publish-pypi.yml @@ -1,6 +1,6 @@ # This workflow is triggered when a GitHub release is created. # It can also be run manually to re-publish to PyPI in case it failed for some reason. -# You can run this workflow by navigating to https://www.github.com/clibrain/python-sdk/actions/workflows/publish-pypi.yml +# You can run this workflow by navigating to https://www.github.com/maisaai/python-sdk/actions/workflows/publish-pypi.yml name: Publish PyPI on: workflow_dispatch: diff --git a/.github/workflows/release-doctor.yml b/.github/workflows/release-doctor.yml index 75712fd0..7a5fc2f2 100644 --- a/.github/workflows/release-doctor.yml +++ b/.github/workflows/release-doctor.yml @@ -7,7 +7,7 @@ jobs: release_doctor: name: release doctor runs-on: ubuntu-latest - if: github.repository == 'clibrain/python-sdk' && (github.event_name == 'push' || github.event_name == 'workflow_dispatch' || startsWith(github.head_ref, 'release-please') || github.head_ref == 'next') + if: github.repository == 'maisaai/python-sdk' && (github.event_name == 'push' || github.event_name == 'workflow_dispatch' || startsWith(github.head_ref, 'release-please') || github.head_ref == 'next') steps: - uses: actions/checkout@v3 diff --git a/.stats.yml b/.stats.yml index dd473053..c2549479 100644 --- a/.stats.yml +++ b/.stats.yml @@ -1 +1 @@ -configured_endpoints: 14 +configured_endpoints: 15 diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 5ee857ad..54364e97 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -59,7 +59,7 @@ If you’d like to use the repository from source, you can either install from g To install via git: ```bash -pip install git+ssh://git@github.com:clibrain/python-sdk.git +pip install git+ssh://git@github.com/maisaai/python-sdk.git ``` Alternatively, you can build from source and install the wheel file: @@ -82,7 +82,7 @@ pip install ./path-to-wheel-file.whl ## Running tests -Most tests will require you to [setup a mock server](https://github.com/stoplightio/prism) against the OpenAPI spec to run the tests. +Most tests require you to [set up a mock server](https://github.com/stoplightio/prism) against the OpenAPI spec to run the tests. ```bash # you will need npm installed @@ -117,7 +117,7 @@ the changes aren't made through the automated pipeline, you may want to make rel ### Publish with a GitHub workflow -You can release to package managers by using [the `Publish PyPI` GitHub action](https://www.github.com/clibrain/python-sdk/actions/workflows/publish-pypi.yml). This will require a setup organization or repository secret to be set up. +You can release to package managers by using [the `Publish PyPI` GitHub action](https://www.github.com/maisaai/python-sdk/actions/workflows/publish-pypi.yml). This requires a setup organization or repository secret to be set up. ### Publish manually diff --git a/README.md b/README.md index 7bdb10c9..78f19d95 100644 --- a/README.md +++ b/README.md @@ -8,11 +8,12 @@ and offers both synchronous and asynchronous clients powered by [httpx](https:// ## Documentation -The REST API documentation can be found [on maisa.ai](https://maisa.ai/). The full API of this library can be found in [api.md](api.md). +The REST API documentation can be found [on docs.maisa.ai](https://docs.maisa.ai/). The full API of this library can be found in [api.md](api.md). ## Installation ```sh +# install from PyPI pip install --pre maisa ``` @@ -29,10 +30,10 @@ client = Maisa( api_key=os.environ.get("MAISA_API_KEY"), ) -embeddings = client.models.embeddings.create( - texts=["string"], +text_summary = client.capabilities.summarize( + text="Example long text...", ) -print(embeddings.embeddings) +print(text_summary.summary) ``` While you can provide an `api_key` keyword argument, @@ -56,10 +57,10 @@ client = AsyncMaisa( async def main() -> None: - embeddings = await client.models.embeddings.create( - texts=["string"], + text_summary = await client.capabilities.summarize( + text="Example long text...", ) - print(embeddings.embeddings) + print(text_summary.summary) asyncio.run(main()) @@ -92,8 +93,8 @@ from maisa import Maisa client = Maisa() try: - client.models.embeddings.create( - texts=["string"], + client.capabilities.summarize( + text="Example long text...", ) except maisa.APIConnectionError as e: print("The server could not be reached") @@ -121,7 +122,7 @@ Error codes are as followed: ### Retries -Certain errors are automatically retried 3 times by default, with a short exponential backoff. +Certain errors are automatically retried 2 times by default, with a short exponential backoff. Connection errors (for example, due to a network connectivity problem), 408 Request Timeout, 409 Conflict, 429 Rate Limit, and >=500 Internal errors are all retried by default. @@ -137,8 +138,8 @@ client = Maisa( ) # Or, configure per-request: -client.with_options(max_retries=5).models.embeddings.create( - texts=["string"], +client.with_options(max_retries=5).capabilities.summarize( + text="Example long text...", ) ``` @@ -162,8 +163,8 @@ client = Maisa( ) # Override per-request: -client.with_options(timeout=5 * 1000).models.embeddings.create( - texts=["string"], +client.with_options(timeout=5 * 1000).capabilities.summarize( + text="Example long text...", ) ``` @@ -203,18 +204,18 @@ The "raw" Response object can be accessed by prefixing `.with_raw_response.` to from maisa import Maisa client = Maisa() -response = client.models.embeddings.with_raw_response.create( - texts=["string"], +response = client.capabilities.with_raw_response.summarize( + text="Example long text...", ) print(response.headers.get('X-My-Header')) -embedding = response.parse() # get the object that `models.embeddings.create()` would have returned -print(embedding.embeddings) +capability = response.parse() # get the object that `capabilities.summarize()` would have returned +print(capability.summary) ``` -These methods return an [`APIResponse`](https://github.com/clibrain/python-sdk/tree/main/src/maisa/_response.py) object. +These methods return an [`APIResponse`](https://github.com/maisaai/python-sdk/tree/main/src/maisa/_response.py) object. -The async client returns an [`AsyncAPIResponse`](https://github.com/clibrain/python-sdk/tree/main/src/maisa/_response.py) with the same structure, the only difference being `await`able methods for reading the response content. +The async client returns an [`AsyncAPIResponse`](https://github.com/maisaai/python-sdk/tree/main/src/maisa/_response.py) with the same structure, the only difference being `await`able methods for reading the response content. #### `.with_streaming_response` @@ -223,8 +224,8 @@ The above interface eagerly reads the full response body when you make the reque To stream the response body, use `.with_streaming_response` instead, which requires a context manager and only reads the response body once you call `.read()`, `.text()`, `.json()`, `.iter_bytes()`, `.iter_text()`, `.iter_lines()` or `.parse()`. In the async client, these are async methods. ```python -with client.models.embeddings.with_streaming_response.create( - texts=["string"], +with client.capabilities.with_streaming_response.summarize( + text="Example long text...", ) as response: print(response.headers.get("X-My-Header")) @@ -270,7 +271,7 @@ This package generally follows [SemVer](https://semver.org/spec/v2.0.0.html) con We take backwards-compatibility seriously and work hard to ensure you can rely on a smooth upgrade experience. -We are keen for your feedback; please open an [issue](https://www.github.com/clibrain/python-sdk/issues) with questions, bugs, or suggestions. +We are keen for your feedback; please open an [issue](https://www.github.com/maisaai/python-sdk/issues) with questions, bugs, or suggestions. ## Requirements diff --git a/api.md b/api.md index 42c1031b..39b91ab5 100644 --- a/api.md +++ b/api.md @@ -46,6 +46,18 @@ Methods: - client.models.rerank.create(\*\*params) -> Rerank +# Kpu + +Types: + +```python +from maisa.types import KpuRunResponse +``` + +Methods: + +- client.kpu.run(\*\*params) -> KpuRunResponse + # FileInterpreter ## FromPdf diff --git a/pyproject.toml b/pyproject.toml index beccaebc..fe875f86 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,8 +39,8 @@ classifiers = [ [project.urls] -Homepage = "https://github.com/clibrain/python-sdk" -Repository = "https://github.com/clibrain/python-sdk" +Homepage = "https://github.com/maisaai/python-sdk" +Repository = "https://github.com/maisaai/python-sdk" diff --git a/requirements-dev.lock b/requirements-dev.lock index d26251a3..204fcb08 100644 --- a/requirements-dev.lock +++ b/requirements-dev.lock @@ -63,7 +63,7 @@ pydantic==2.4.2 # via maisa pydantic-core==2.10.1 # via pydantic -pyright==1.1.332 +pyright==1.1.351 pytest==7.1.1 # via pytest-asyncio pytest-asyncio==0.21.1 diff --git a/src/maisa/_base_client.py b/src/maisa/_base_client.py index 0b5ece2c..2b3a1f98 100644 --- a/src/maisa/_base_client.py +++ b/src/maisa/_base_client.py @@ -79,7 +79,7 @@ RAW_RESPONSE_HEADER, OVERRIDE_CAST_TO_HEADER, ) -from ._streaming import Stream, AsyncStream +from ._streaming import Stream, SSEDecoder, AsyncStream, SSEBytesDecoder from ._exceptions import ( APIStatusError, APITimeoutError, @@ -430,6 +430,9 @@ def _prepare_url(self, url: str) -> URL: return merge_url + def _make_sse_decoder(self) -> SSEDecoder | SSEBytesDecoder: + return SSEDecoder() + def _build_request( self, options: FinalRequestOptions, @@ -776,6 +779,11 @@ def __init__( else: timeout = DEFAULT_TIMEOUT + if http_client is not None and not isinstance(http_client, httpx.Client): # pyright: ignore[reportUnnecessaryIsInstance] + raise TypeError( + f"Invalid `http_client` argument; Expected an instance of `httpx.Client` but got {type(http_client)}" + ) + super().__init__( version=version, limits=limits, @@ -1305,6 +1313,11 @@ def __init__( else: timeout = DEFAULT_TIMEOUT + if http_client is not None and not isinstance(http_client, httpx.AsyncClient): # pyright: ignore[reportUnnecessaryIsInstance] + raise TypeError( + f"Invalid `http_client` argument; Expected an instance of `httpx.AsyncClient` but got {type(http_client)}" + ) + super().__init__( version=version, base_url=base_url, diff --git a/src/maisa/_client.py b/src/maisa/_client.py index b6f06497..11d68726 100644 --- a/src/maisa/_client.py +++ b/src/maisa/_client.py @@ -48,6 +48,7 @@ class Maisa(SyncAPIClient): capabilities: resources.Capabilities models: resources.Models + kpu: resources.Kpu file_interpreter: resources.FileInterpreter mainet: resources.Mainet with_raw_response: MaisaWithRawResponse @@ -107,6 +108,7 @@ def __init__( self.capabilities = resources.Capabilities(self) self.models = resources.Models(self) + self.kpu = resources.Kpu(self) self.file_interpreter = resources.FileInterpreter(self) self.mainet = resources.Mainet(self) self.with_raw_response = MaisaWithRawResponse(self) @@ -220,6 +222,7 @@ def _make_status_error( class AsyncMaisa(AsyncAPIClient): capabilities: resources.AsyncCapabilities models: resources.AsyncModels + kpu: resources.AsyncKpu file_interpreter: resources.AsyncFileInterpreter mainet: resources.AsyncMainet with_raw_response: AsyncMaisaWithRawResponse @@ -279,6 +282,7 @@ def __init__( self.capabilities = resources.AsyncCapabilities(self) self.models = resources.AsyncModels(self) + self.kpu = resources.AsyncKpu(self) self.file_interpreter = resources.AsyncFileInterpreter(self) self.mainet = resources.AsyncMainet(self) self.with_raw_response = AsyncMaisaWithRawResponse(self) @@ -393,6 +397,7 @@ class MaisaWithRawResponse: def __init__(self, client: Maisa) -> None: self.capabilities = resources.CapabilitiesWithRawResponse(client.capabilities) self.models = resources.ModelsWithRawResponse(client.models) + self.kpu = resources.KpuWithRawResponse(client.kpu) self.file_interpreter = resources.FileInterpreterWithRawResponse(client.file_interpreter) self.mainet = resources.MainetWithRawResponse(client.mainet) @@ -401,6 +406,7 @@ class AsyncMaisaWithRawResponse: def __init__(self, client: AsyncMaisa) -> None: self.capabilities = resources.AsyncCapabilitiesWithRawResponse(client.capabilities) self.models = resources.AsyncModelsWithRawResponse(client.models) + self.kpu = resources.AsyncKpuWithRawResponse(client.kpu) self.file_interpreter = resources.AsyncFileInterpreterWithRawResponse(client.file_interpreter) self.mainet = resources.AsyncMainetWithRawResponse(client.mainet) @@ -409,6 +415,7 @@ class MaisaWithStreamedResponse: def __init__(self, client: Maisa) -> None: self.capabilities = resources.CapabilitiesWithStreamingResponse(client.capabilities) self.models = resources.ModelsWithStreamingResponse(client.models) + self.kpu = resources.KpuWithStreamingResponse(client.kpu) self.file_interpreter = resources.FileInterpreterWithStreamingResponse(client.file_interpreter) self.mainet = resources.MainetWithStreamingResponse(client.mainet) @@ -417,6 +424,7 @@ class AsyncMaisaWithStreamedResponse: def __init__(self, client: AsyncMaisa) -> None: self.capabilities = resources.AsyncCapabilitiesWithStreamingResponse(client.capabilities) self.models = resources.AsyncModelsWithStreamingResponse(client.models) + self.kpu = resources.AsyncKpuWithStreamingResponse(client.kpu) self.file_interpreter = resources.AsyncFileInterpreterWithStreamingResponse(client.file_interpreter) self.mainet = resources.AsyncMainetWithStreamingResponse(client.mainet) diff --git a/src/maisa/_constants.py b/src/maisa/_constants.py index 036f4f77..bf15141a 100644 --- a/src/maisa/_constants.py +++ b/src/maisa/_constants.py @@ -7,7 +7,7 @@ # default timeout is 1 minute DEFAULT_TIMEOUT = httpx.Timeout(timeout=60.0, connect=5.0) -DEFAULT_MAX_RETRIES = 3 +DEFAULT_MAX_RETRIES = 2 DEFAULT_LIMITS = httpx.Limits(max_connections=100, max_keepalive_connections=20) INITIAL_RETRY_DELAY = 0.5 diff --git a/src/maisa/_files.py b/src/maisa/_files.py index b6e8af8b..0d2022ae 100644 --- a/src/maisa/_files.py +++ b/src/maisa/_files.py @@ -13,12 +13,17 @@ FileContent, RequestFiles, HttpxFileTypes, + Base64FileInput, HttpxFileContent, HttpxRequestFiles, ) from ._utils import is_tuple_t, is_mapping_t, is_sequence_t +def is_base64_file_input(obj: object) -> TypeGuard[Base64FileInput]: + return isinstance(obj, io.IOBase) or isinstance(obj, os.PathLike) + + def is_file_content(obj: object) -> TypeGuard[FileContent]: return ( isinstance(obj, bytes) or isinstance(obj, tuple) or isinstance(obj, io.IOBase) or isinstance(obj, os.PathLike) diff --git a/src/maisa/_models.py b/src/maisa/_models.py index 48d5624f..81089149 100644 --- a/src/maisa/_models.py +++ b/src/maisa/_models.py @@ -283,7 +283,7 @@ def construct_type(*, value: object, type_: type) -> object: if is_union(origin): try: - return validate_type(type_=type_, value=value) + return validate_type(type_=cast("type[object]", type_), value=value) except Exception: pass diff --git a/src/maisa/_resource.py b/src/maisa/_resource.py index 647a61a7..ee940c3d 100644 --- a/src/maisa/_resource.py +++ b/src/maisa/_resource.py @@ -3,9 +3,10 @@ from __future__ import annotations import time -import asyncio from typing import TYPE_CHECKING +import anyio + if TYPE_CHECKING: from ._client import Maisa, AsyncMaisa @@ -39,4 +40,4 @@ def __init__(self, client: AsyncMaisa) -> None: self._get_api_list = client.get_api_list async def _sleep(self, seconds: float) -> None: - await asyncio.sleep(seconds) + await anyio.sleep(seconds) diff --git a/src/maisa/_streaming.py b/src/maisa/_streaming.py index 502fceaf..0b48ef2c 100644 --- a/src/maisa/_streaming.py +++ b/src/maisa/_streaming.py @@ -5,7 +5,7 @@ import inspect from types import TracebackType from typing import TYPE_CHECKING, Any, Generic, TypeVar, Iterator, AsyncIterator, cast -from typing_extensions import Self, TypeGuard, override, get_origin +from typing_extensions import Self, Protocol, TypeGuard, override, get_origin, runtime_checkable import httpx @@ -23,6 +23,8 @@ class Stream(Generic[_T]): response: httpx.Response + _decoder: SSEDecoder | SSEBytesDecoder + def __init__( self, *, @@ -33,7 +35,7 @@ def __init__( self.response = response self._cast_to = cast_to self._client = client - self._decoder = SSEDecoder() + self._decoder = client._make_sse_decoder() self._iterator = self.__stream__() def __next__(self) -> _T: @@ -44,7 +46,10 @@ def __iter__(self) -> Iterator[_T]: yield item def _iter_events(self) -> Iterator[ServerSentEvent]: - yield from self._decoder.iter(self.response.iter_lines()) + if isinstance(self._decoder, SSEBytesDecoder): + yield from self._decoder.iter_bytes(self.response.iter_bytes()) + else: + yield from self._decoder.iter(self.response.iter_lines()) def __stream__(self) -> Iterator[_T]: cast_to = cast(Any, self._cast_to) @@ -84,6 +89,8 @@ class AsyncStream(Generic[_T]): response: httpx.Response + _decoder: SSEDecoder | SSEBytesDecoder + def __init__( self, *, @@ -94,7 +101,7 @@ def __init__( self.response = response self._cast_to = cast_to self._client = client - self._decoder = SSEDecoder() + self._decoder = client._make_sse_decoder() self._iterator = self.__stream__() async def __anext__(self) -> _T: @@ -105,8 +112,12 @@ async def __aiter__(self) -> AsyncIterator[_T]: yield item async def _iter_events(self) -> AsyncIterator[ServerSentEvent]: - async for sse in self._decoder.aiter(self.response.aiter_lines()): - yield sse + if isinstance(self._decoder, SSEBytesDecoder): + async for sse in self._decoder.aiter_bytes(self.response.aiter_bytes()): + yield sse + else: + async for sse in self._decoder.aiter(self.response.aiter_lines()): + yield sse async def __stream__(self) -> AsyncIterator[_T]: cast_to = cast(Any, self._cast_to) @@ -259,6 +270,17 @@ def decode(self, line: str) -> ServerSentEvent | None: return None +@runtime_checkable +class SSEBytesDecoder(Protocol): + def iter_bytes(self, iterator: Iterator[bytes]) -> Iterator[ServerSentEvent]: + """Given an iterator that yields raw binary data, iterate over it & yield every event encountered""" + ... + + def aiter_bytes(self, iterator: AsyncIterator[bytes]) -> AsyncIterator[ServerSentEvent]: + """Given an async iterator that yields raw binary data, iterate over it & yield every event encountered""" + ... + + def is_stream_class_type(typ: type) -> TypeGuard[type[Stream[object]] | type[AsyncStream[object]]]: """TypeGuard for determining whether or not the given type is a subclass of `Stream` / `AsyncStream`""" origin = get_origin(typ) or typ diff --git a/src/maisa/_types.py b/src/maisa/_types.py index cad49dde..94a83c8a 100644 --- a/src/maisa/_types.py +++ b/src/maisa/_types.py @@ -40,8 +40,10 @@ ProxiesDict = Dict["str | URL", Union[None, str, URL, Proxy]] ProxiesTypes = Union[str, Proxy, ProxiesDict] if TYPE_CHECKING: + Base64FileInput = Union[IO[bytes], PathLike[str]] FileContent = Union[IO[bytes], bytes, PathLike[str]] else: + Base64FileInput = Union[IO[bytes], PathLike] FileContent = Union[IO[bytes], bytes, PathLike] # PathLike is not subscriptable in Python 3.8. FileTypes = Union[ # file (or bytes) diff --git a/src/maisa/_utils/__init__.py b/src/maisa/_utils/__init__.py index b5790a87..56978941 100644 --- a/src/maisa/_utils/__init__.py +++ b/src/maisa/_utils/__init__.py @@ -44,5 +44,7 @@ from ._transform import ( PropertyInfo as PropertyInfo, transform as transform, + async_transform as async_transform, maybe_transform as maybe_transform, + async_maybe_transform as async_maybe_transform, ) diff --git a/src/maisa/_utils/_proxy.py b/src/maisa/_utils/_proxy.py index 6f05efcd..b9c12dc3 100644 --- a/src/maisa/_utils/_proxy.py +++ b/src/maisa/_utils/_proxy.py @@ -45,7 +45,7 @@ def __dir__(self) -> Iterable[str]: @property # type: ignore @override - def __class__(self) -> type: + def __class__(self) -> type: # pyright: ignore proxied = self.__get_proxied__() if issubclass(type(proxied), LazyProxy): return type(proxied) diff --git a/src/maisa/_utils/_transform.py b/src/maisa/_utils/_transform.py index 2cb7726c..1bd1330c 100644 --- a/src/maisa/_utils/_transform.py +++ b/src/maisa/_utils/_transform.py @@ -1,9 +1,13 @@ from __future__ import annotations +import io +import base64 +import pathlib from typing import Any, Mapping, TypeVar, cast from datetime import date, datetime from typing_extensions import Literal, get_args, override, get_type_hints +import anyio import pydantic from ._utils import ( @@ -11,6 +15,7 @@ is_mapping, is_iterable, ) +from .._files import is_base64_file_input from ._typing import ( is_list_type, is_union_type, @@ -29,7 +34,7 @@ # TODO: ensure works correctly with forward references in all cases -PropertyFormat = Literal["iso8601", "custom"] +PropertyFormat = Literal["iso8601", "base64", "custom"] class PropertyInfo: @@ -180,11 +185,7 @@ def _transform_recursive( if isinstance(data, pydantic.BaseModel): return model_dump(data, exclude_unset=True) - return _transform_value(data, annotation) - - -def _transform_value(data: object, type_: type) -> object: - annotated_type = _get_annotated_type(type_) + annotated_type = _get_annotated_type(annotation) if annotated_type is None: return data @@ -205,6 +206,22 @@ def _format_data(data: object, format_: PropertyFormat, format_template: str | N if format_ == "custom" and format_template is not None: return data.strftime(format_template) + if format_ == "base64" and is_base64_file_input(data): + binary: str | bytes | None = None + + if isinstance(data, pathlib.Path): + binary = data.read_bytes() + elif isinstance(data, io.IOBase): + binary = data.read() + + if isinstance(binary, str): # type: ignore[unreachable] + binary = binary.encode() + + if not isinstance(binary, bytes): + raise RuntimeError(f"Could not read bytes from {data}; Received {type(binary)}") + + return base64.b64encode(binary).decode("ascii") + return data @@ -222,3 +239,141 @@ def _transform_typeddict( else: result[_maybe_transform_key(key, type_)] = _transform_recursive(value, annotation=type_) return result + + +async def async_maybe_transform( + data: object, + expected_type: object, +) -> Any | None: + """Wrapper over `async_transform()` that allows `None` to be passed. + + See `async_transform()` for more details. + """ + if data is None: + return None + return await async_transform(data, expected_type) + + +async def async_transform( + data: _T, + expected_type: object, +) -> _T: + """Transform dictionaries based off of type information from the given type, for example: + + ```py + class Params(TypedDict, total=False): + card_id: Required[Annotated[str, PropertyInfo(alias="cardID")]] + + + transformed = transform({"card_id": ""}, Params) + # {'cardID': ''} + ``` + + Any keys / data that does not have type information given will be included as is. + + It should be noted that the transformations that this function does are not represented in the type system. + """ + transformed = await _async_transform_recursive(data, annotation=cast(type, expected_type)) + return cast(_T, transformed) + + +async def _async_transform_recursive( + data: object, + *, + annotation: type, + inner_type: type | None = None, +) -> object: + """Transform the given data against the expected type. + + Args: + annotation: The direct type annotation given to the particular piece of data. + This may or may not be wrapped in metadata types, e.g. `Required[T]`, `Annotated[T, ...]` etc + + inner_type: If applicable, this is the "inside" type. This is useful in certain cases where the outside type + is a container type such as `List[T]`. In that case `inner_type` should be set to `T` so that each entry in + the list can be transformed using the metadata from the container type. + + Defaults to the same value as the `annotation` argument. + """ + if inner_type is None: + inner_type = annotation + + stripped_type = strip_annotated_type(inner_type) + if is_typeddict(stripped_type) and is_mapping(data): + return await _async_transform_typeddict(data, stripped_type) + + if ( + # List[T] + (is_list_type(stripped_type) and is_list(data)) + # Iterable[T] + or (is_iterable_type(stripped_type) and is_iterable(data) and not isinstance(data, str)) + ): + inner_type = extract_type_arg(stripped_type, 0) + return [await _async_transform_recursive(d, annotation=annotation, inner_type=inner_type) for d in data] + + if is_union_type(stripped_type): + # For union types we run the transformation against all subtypes to ensure that everything is transformed. + # + # TODO: there may be edge cases where the same normalized field name will transform to two different names + # in different subtypes. + for subtype in get_args(stripped_type): + data = await _async_transform_recursive(data, annotation=annotation, inner_type=subtype) + return data + + if isinstance(data, pydantic.BaseModel): + return model_dump(data, exclude_unset=True) + + annotated_type = _get_annotated_type(annotation) + if annotated_type is None: + return data + + # ignore the first argument as it is the actual type + annotations = get_args(annotated_type)[1:] + for annotation in annotations: + if isinstance(annotation, PropertyInfo) and annotation.format is not None: + return await _async_format_data(data, annotation.format, annotation.format_template) + + return data + + +async def _async_format_data(data: object, format_: PropertyFormat, format_template: str | None) -> object: + if isinstance(data, (date, datetime)): + if format_ == "iso8601": + return data.isoformat() + + if format_ == "custom" and format_template is not None: + return data.strftime(format_template) + + if format_ == "base64" and is_base64_file_input(data): + binary: str | bytes | None = None + + if isinstance(data, pathlib.Path): + binary = await anyio.Path(data).read_bytes() + elif isinstance(data, io.IOBase): + binary = data.read() + + if isinstance(binary, str): # type: ignore[unreachable] + binary = binary.encode() + + if not isinstance(binary, bytes): + raise RuntimeError(f"Could not read bytes from {data}; Received {type(binary)}") + + return base64.b64encode(binary).decode("ascii") + + return data + + +async def _async_transform_typeddict( + data: Mapping[str, object], + expected_type: type, +) -> Mapping[str, object]: + result: dict[str, object] = {} + annotations = get_type_hints(expected_type, include_extras=True) + for key, value in data.items(): + type_ = annotations.get(key) + if type_ is None: + # we do not have a type annotation for this field, leave it as is + result[key] = value + else: + result[_maybe_transform_key(key, type_)] = await _async_transform_recursive(value, annotation=type_) + return result diff --git a/src/maisa/resources/__init__.py b/src/maisa/resources/__init__.py index 6666c4dd..cd1034e9 100644 --- a/src/maisa/resources/__init__.py +++ b/src/maisa/resources/__init__.py @@ -1,5 +1,13 @@ # File generated from our OpenAPI spec by Stainless. +from .kpu import ( + Kpu, + AsyncKpu, + KpuWithRawResponse, + AsyncKpuWithRawResponse, + KpuWithStreamingResponse, + AsyncKpuWithStreamingResponse, +) from .mainet import ( Mainet, AsyncMainet, @@ -46,6 +54,12 @@ "AsyncModelsWithRawResponse", "ModelsWithStreamingResponse", "AsyncModelsWithStreamingResponse", + "Kpu", + "AsyncKpu", + "KpuWithRawResponse", + "AsyncKpuWithRawResponse", + "KpuWithStreamingResponse", + "AsyncKpuWithStreamingResponse", "FileInterpreter", "AsyncFileInterpreter", "FileInterpreterWithRawResponse", diff --git a/src/maisa/resources/capabilities/capabilities.py b/src/maisa/resources/capabilities/capabilities.py index 313d73d1..84eb7f74 100644 --- a/src/maisa/resources/capabilities/capabilities.py +++ b/src/maisa/resources/capabilities/capabilities.py @@ -17,7 +17,10 @@ ) from ...types import capability_compare_params, capability_extract_params, capability_summarize_params from ..._types import NOT_GIVEN, Body, Query, Headers, NotGiven -from ..._utils import maybe_transform +from ..._utils import ( + maybe_transform, + async_maybe_transform, +) from ..._compat import cached_property from ..._resource import SyncAPIResource, AsyncAPIResource from ..._response import ( @@ -265,7 +268,7 @@ async def compare( """ return await self._post( "/v1/capabilities/compare", - body=maybe_transform( + body=await async_maybe_transform( { "text1": text1, "text2": text2, @@ -317,7 +320,7 @@ async def extract( """ return await self._post( "/v1/capabilities/extract", - body=maybe_transform( + body=await async_maybe_transform( { "text": text, "variables": variables, @@ -373,7 +376,7 @@ async def summarize( """ return await self._post( "/v1/capabilities/summarize", - body=maybe_transform( + body=await async_maybe_transform( { "text": text, "format": format, diff --git a/src/maisa/resources/capabilities/media.py b/src/maisa/resources/capabilities/media.py index 662bbb8b..9a16c043 100644 --- a/src/maisa/resources/capabilities/media.py +++ b/src/maisa/resources/capabilities/media.py @@ -8,7 +8,12 @@ import httpx from ..._types import NOT_GIVEN, Body, Query, Headers, NotGiven, FileTypes -from ..._utils import extract_files, maybe_transform, deepcopy_minimal +from ..._utils import ( + extract_files, + maybe_transform, + deepcopy_minimal, + async_maybe_transform, +) from ..._compat import cached_property from ..._resource import SyncAPIResource, AsyncAPIResource from ..._response import ( @@ -414,7 +419,7 @@ async def compare( extra_headers = {"Content-Type": "multipart/form-data", **(extra_headers or {})} return await self._post( "/v1/capabilities/compare/media", - body=maybe_transform(body, media_compare_params.MediaCompareParams), + body=await async_maybe_transform(body, media_compare_params.MediaCompareParams), files=files, options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout @@ -515,7 +520,7 @@ async def extract( extra_headers = {"Content-Type": "multipart/form-data", **(extra_headers or {})} return await self._post( "/v1/capabilities/extract/media", - body=maybe_transform(body, media_extract_params.MediaExtractParams), + body=await async_maybe_transform(body, media_extract_params.MediaExtractParams), files=files, options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout @@ -579,7 +584,7 @@ async def summarize( extra_headers = {"Content-Type": "multipart/form-data", **(extra_headers or {})} return await self._post( "/v1/capabilities/summarize/media", - body=maybe_transform(body, media_summarize_params.MediaSummarizeParams), + body=await async_maybe_transform(body, media_summarize_params.MediaSummarizeParams), files=files, options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout diff --git a/src/maisa/resources/file_interpreter/from_audio.py b/src/maisa/resources/file_interpreter/from_audio.py index 04715f1b..9b93ef88 100644 --- a/src/maisa/resources/file_interpreter/from_audio.py +++ b/src/maisa/resources/file_interpreter/from_audio.py @@ -7,7 +7,12 @@ import httpx from ..._types import NOT_GIVEN, Body, Query, Headers, NotGiven, FileTypes -from ..._utils import extract_files, maybe_transform, deepcopy_minimal +from ..._utils import ( + extract_files, + maybe_transform, + deepcopy_minimal, + async_maybe_transform, +) from ..._compat import cached_property from ..._resource import SyncAPIResource, AsyncAPIResource from ..._response import ( @@ -115,7 +120,7 @@ async def create( extra_headers = {"Content-Type": "multipart/form-data", **(extra_headers or {})} return await self._post( "/v1/file-interpreter/from-audio", - body=maybe_transform(body, from_audio_create_params.FromAudioCreateParams), + body=await async_maybe_transform(body, from_audio_create_params.FromAudioCreateParams), files=files, options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout diff --git a/src/maisa/resources/file_interpreter/from_docx.py b/src/maisa/resources/file_interpreter/from_docx.py index 8ee42a69..a2dd8103 100644 --- a/src/maisa/resources/file_interpreter/from_docx.py +++ b/src/maisa/resources/file_interpreter/from_docx.py @@ -7,7 +7,12 @@ import httpx from ..._types import NOT_GIVEN, Body, Query, Headers, NotGiven, FileTypes -from ..._utils import extract_files, maybe_transform, deepcopy_minimal +from ..._utils import ( + extract_files, + maybe_transform, + deepcopy_minimal, + async_maybe_transform, +) from ..._compat import cached_property from ..._resource import SyncAPIResource, AsyncAPIResource from ..._response import ( @@ -115,7 +120,7 @@ async def create( extra_headers = {"Content-Type": "multipart/form-data", **(extra_headers or {})} return await self._post( "/v1/file-interpreter/from-docx", - body=maybe_transform(body, from_docx_create_params.FromDocxCreateParams), + body=await async_maybe_transform(body, from_docx_create_params.FromDocxCreateParams), files=files, options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout diff --git a/src/maisa/resources/file_interpreter/from_html.py b/src/maisa/resources/file_interpreter/from_html.py index a653bd62..c01b2522 100644 --- a/src/maisa/resources/file_interpreter/from_html.py +++ b/src/maisa/resources/file_interpreter/from_html.py @@ -7,7 +7,12 @@ import httpx from ..._types import NOT_GIVEN, Body, Query, Headers, NotGiven, FileTypes -from ..._utils import extract_files, maybe_transform, deepcopy_minimal +from ..._utils import ( + extract_files, + maybe_transform, + deepcopy_minimal, + async_maybe_transform, +) from ..._compat import cached_property from ..._resource import SyncAPIResource, AsyncAPIResource from ..._response import ( @@ -115,7 +120,7 @@ async def create( extra_headers = {"Content-Type": "multipart/form-data", **(extra_headers or {})} return await self._post( "/v1/file-interpreter/from-html", - body=maybe_transform(body, from_html_create_params.FromHTMLCreateParams), + body=await async_maybe_transform(body, from_html_create_params.FromHTMLCreateParams), files=files, options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout diff --git a/src/maisa/resources/file_interpreter/from_image.py b/src/maisa/resources/file_interpreter/from_image.py index 2a20441e..b640ee22 100644 --- a/src/maisa/resources/file_interpreter/from_image.py +++ b/src/maisa/resources/file_interpreter/from_image.py @@ -7,7 +7,12 @@ import httpx from ..._types import NOT_GIVEN, Body, Query, Headers, NotGiven, FileTypes -from ..._utils import extract_files, maybe_transform, deepcopy_minimal +from ..._utils import ( + extract_files, + maybe_transform, + deepcopy_minimal, + async_maybe_transform, +) from ..._compat import cached_property from ..._resource import SyncAPIResource, AsyncAPIResource from ..._response import ( @@ -115,7 +120,7 @@ async def create( extra_headers = {"Content-Type": "multipart/form-data", **(extra_headers or {})} return await self._post( "/v1/file-interpreter/from-image", - body=maybe_transform(body, from_image_create_params.FromImageCreateParams), + body=await async_maybe_transform(body, from_image_create_params.FromImageCreateParams), files=files, options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout diff --git a/src/maisa/resources/file_interpreter/from_pdf.py b/src/maisa/resources/file_interpreter/from_pdf.py index f8ac92c2..6a050e64 100644 --- a/src/maisa/resources/file_interpreter/from_pdf.py +++ b/src/maisa/resources/file_interpreter/from_pdf.py @@ -7,7 +7,12 @@ import httpx from ..._types import NOT_GIVEN, Body, Query, Headers, NotGiven, FileTypes -from ..._utils import extract_files, maybe_transform, deepcopy_minimal +from ..._utils import ( + extract_files, + maybe_transform, + deepcopy_minimal, + async_maybe_transform, +) from ..._compat import cached_property from ..._resource import SyncAPIResource, AsyncAPIResource from ..._response import ( @@ -121,14 +126,14 @@ async def create( extra_headers = {"Content-Type": "multipart/form-data", **(extra_headers or {})} return await self._post( "/v1/file-interpreter/from-pdf", - body=maybe_transform(body, from_pdf_create_params.FromPdfCreateParams), + body=await async_maybe_transform(body, from_pdf_create_params.FromPdfCreateParams), files=files, options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout, - query=maybe_transform({"max_pages": max_pages}, from_pdf_create_params.FromPdfCreateParams), + query=await async_maybe_transform({"max_pages": max_pages}, from_pdf_create_params.FromPdfCreateParams), ), cast_to=object, ) diff --git a/src/maisa/resources/kpu.py b/src/maisa/resources/kpu.py new file mode 100644 index 00000000..3743abfc --- /dev/null +++ b/src/maisa/resources/kpu.py @@ -0,0 +1,223 @@ +# File generated from our OpenAPI spec by Stainless. + +from __future__ import annotations + +from typing import List, Mapping, cast + +import httpx + +from ..types import KpuRunResponse, kpu_run_params +from .._types import NOT_GIVEN, Body, Query, Headers, NotGiven, FileTypes +from .._utils import ( + extract_files, + maybe_transform, + deepcopy_minimal, + async_maybe_transform, +) +from .._compat import cached_property +from .._resource import SyncAPIResource, AsyncAPIResource +from .._response import ( + to_raw_response_wrapper, + to_streamed_response_wrapper, + async_to_raw_response_wrapper, + async_to_streamed_response_wrapper, +) +from .._base_client import ( + make_request_options, +) + +__all__ = ["Kpu", "AsyncKpu"] + + +class Kpu(SyncAPIResource): + @cached_property + def with_raw_response(self) -> KpuWithRawResponse: + return KpuWithRawResponse(self) + + @cached_property + def with_streaming_response(self) -> KpuWithStreamingResponse: + return KpuWithStreamingResponse(self) + + def run( + self, + *, + query: str, + explain_steps: bool | NotGiven = NOT_GIVEN, + retries: int | NotGiven = NOT_GIVEN, + file: List[FileTypes] | NotGiven = NOT_GIVEN, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> KpuRunResponse: + """ + Executes the KPU in sync, sending the response when the KPU execution is done. + + Args: + query: User text with the query or request to be commanded to the KPU. + + explain_steps: If true, the KPU will explain in natural language the steps of each step of each + intent. Enabling this feature can slow down the KPU execution, and increase the + usage metric. + + retries: Number of retries in case of failure. Retries are sequential, and each failed + intent yields a learning for the next intent. This feature is experimental. + + file: Files to be used in the KPU execution. Files can be of any type. + + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + body = deepcopy_minimal( + { + "query": query, + "file": file, + } + ) + files = extract_files(cast(Mapping[str, object], body), paths=[["file", ""]]) + if files: + # It should be noted that the actual Content-Type header that will be + # sent to the server will contain a `boundary` parameter, e.g. + # multipart/form-data; boundary=---abc-- + extra_headers = {"Content-Type": "multipart/form-data", **(extra_headers or {})} + return self._post( + "/v1/kpu/run", + body=maybe_transform(body, kpu_run_params.KpuRunParams), + files=files, + options=make_request_options( + extra_headers=extra_headers, + extra_query=extra_query, + extra_body=extra_body, + timeout=timeout, + query=maybe_transform( + { + "explain_steps": explain_steps, + "retries": retries, + }, + kpu_run_params.KpuRunParams, + ), + ), + cast_to=KpuRunResponse, + ) + + +class AsyncKpu(AsyncAPIResource): + @cached_property + def with_raw_response(self) -> AsyncKpuWithRawResponse: + return AsyncKpuWithRawResponse(self) + + @cached_property + def with_streaming_response(self) -> AsyncKpuWithStreamingResponse: + return AsyncKpuWithStreamingResponse(self) + + async def run( + self, + *, + query: str, + explain_steps: bool | NotGiven = NOT_GIVEN, + retries: int | NotGiven = NOT_GIVEN, + file: List[FileTypes] | NotGiven = NOT_GIVEN, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> KpuRunResponse: + """ + Executes the KPU in sync, sending the response when the KPU execution is done. + + Args: + query: User text with the query or request to be commanded to the KPU. + + explain_steps: If true, the KPU will explain in natural language the steps of each step of each + intent. Enabling this feature can slow down the KPU execution, and increase the + usage metric. + + retries: Number of retries in case of failure. Retries are sequential, and each failed + intent yields a learning for the next intent. This feature is experimental. + + file: Files to be used in the KPU execution. Files can be of any type. + + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + body = deepcopy_minimal( + { + "query": query, + "file": file, + } + ) + files = extract_files(cast(Mapping[str, object], body), paths=[["file", ""]]) + if files: + # It should be noted that the actual Content-Type header that will be + # sent to the server will contain a `boundary` parameter, e.g. + # multipart/form-data; boundary=---abc-- + extra_headers = {"Content-Type": "multipart/form-data", **(extra_headers or {})} + return await self._post( + "/v1/kpu/run", + body=await async_maybe_transform(body, kpu_run_params.KpuRunParams), + files=files, + options=make_request_options( + extra_headers=extra_headers, + extra_query=extra_query, + extra_body=extra_body, + timeout=timeout, + query=await async_maybe_transform( + { + "explain_steps": explain_steps, + "retries": retries, + }, + kpu_run_params.KpuRunParams, + ), + ), + cast_to=KpuRunResponse, + ) + + +class KpuWithRawResponse: + def __init__(self, kpu: Kpu) -> None: + self._kpu = kpu + + self.run = to_raw_response_wrapper( + kpu.run, + ) + + +class AsyncKpuWithRawResponse: + def __init__(self, kpu: AsyncKpu) -> None: + self._kpu = kpu + + self.run = async_to_raw_response_wrapper( + kpu.run, + ) + + +class KpuWithStreamingResponse: + def __init__(self, kpu: Kpu) -> None: + self._kpu = kpu + + self.run = to_streamed_response_wrapper( + kpu.run, + ) + + +class AsyncKpuWithStreamingResponse: + def __init__(self, kpu: AsyncKpu) -> None: + self._kpu = kpu + + self.run = async_to_streamed_response_wrapper( + kpu.run, + ) diff --git a/src/maisa/resources/mainet/search.py b/src/maisa/resources/mainet/search.py index 6c77b389..6b376d54 100644 --- a/src/maisa/resources/mainet/search.py +++ b/src/maisa/resources/mainet/search.py @@ -5,7 +5,10 @@ import httpx from ..._types import NOT_GIVEN, Body, Query, Headers, NotGiven -from ..._utils import maybe_transform +from ..._utils import ( + maybe_transform, + async_maybe_transform, +) from ..._compat import cached_property from ..._resource import SyncAPIResource, AsyncAPIResource from ..._response import ( @@ -102,7 +105,7 @@ async def create( """ return await self._post( "/v1/mainet/search", - body=maybe_transform({"text": text}, search_create_params.SearchCreateParams), + body=await async_maybe_transform({"text": text}, search_create_params.SearchCreateParams), options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout ), diff --git a/src/maisa/resources/models/embeddings.py b/src/maisa/resources/models/embeddings.py index 9d3e5575..eb3dfced 100644 --- a/src/maisa/resources/models/embeddings.py +++ b/src/maisa/resources/models/embeddings.py @@ -7,7 +7,10 @@ import httpx from ..._types import NOT_GIVEN, Body, Query, Headers, NotGiven -from ..._utils import maybe_transform +from ..._utils import ( + maybe_transform, + async_maybe_transform, +) from ..._compat import cached_property from ..._resource import SyncAPIResource, AsyncAPIResource from ..._response import ( @@ -104,7 +107,7 @@ async def create( """ return await self._post( "/v1/models/embeddings", - body=maybe_transform({"texts": texts}, embedding_create_params.EmbeddingCreateParams), + body=await async_maybe_transform({"texts": texts}, embedding_create_params.EmbeddingCreateParams), options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout ), diff --git a/src/maisa/resources/models/rerank.py b/src/maisa/resources/models/rerank.py index bcf6042c..5807a961 100644 --- a/src/maisa/resources/models/rerank.py +++ b/src/maisa/resources/models/rerank.py @@ -7,7 +7,10 @@ import httpx from ..._types import NOT_GIVEN, Body, Query, Headers, NotGiven -from ..._utils import maybe_transform +from ..._utils import ( + maybe_transform, + async_maybe_transform, +) from ..._compat import cached_property from ..._resource import SyncAPIResource, AsyncAPIResource from ..._response import ( @@ -116,7 +119,7 @@ async def create( """ return await self._post( "/v1/models/rerank", - body=maybe_transform( + body=await async_maybe_transform( { "sentences": sentences, "source_sentence": source_sentence, diff --git a/src/maisa/types/__init__.py b/src/maisa/types/__init__.py index 94785c30..de43413c 100644 --- a/src/maisa/types/__init__.py +++ b/src/maisa/types/__init__.py @@ -3,6 +3,8 @@ from __future__ import annotations from .shared import TextSummary as TextSummary, TextExtractor as TextExtractor, TextComparator as TextComparator +from .kpu_run_params import KpuRunParams as KpuRunParams +from .kpu_run_response import KpuRunResponse as KpuRunResponse from .capability_compare_params import CapabilityCompareParams as CapabilityCompareParams from .capability_extract_params import CapabilityExtractParams as CapabilityExtractParams from .capability_summarize_params import CapabilitySummarizeParams as CapabilitySummarizeParams diff --git a/src/maisa/types/kpu_run_params.py b/src/maisa/types/kpu_run_params.py new file mode 100644 index 00000000..bad448ed --- /dev/null +++ b/src/maisa/types/kpu_run_params.py @@ -0,0 +1,32 @@ +# File generated from our OpenAPI spec by Stainless. + +from __future__ import annotations + +from typing import List +from typing_extensions import Required, TypedDict + +from .._types import FileTypes + +__all__ = ["KpuRunParams"] + + +class KpuRunParams(TypedDict, total=False): + query: Required[str] + """User text with the query or request to be commanded to the KPU.""" + + explain_steps: bool + """ + If true, the KPU will explain in natural language the steps of each step of each + intent. Enabling this feature can slow down the KPU execution, and increase the + usage metric. + """ + + retries: int + """Number of retries in case of failure. + + Retries are sequential, and each failed intent yields a learning for the next + intent. This feature is experimental. + """ + + file: List[FileTypes] + """Files to be used in the KPU execution. Files can be of any type.""" diff --git a/src/maisa/types/kpu_run_response.py b/src/maisa/types/kpu_run_response.py new file mode 100644 index 00000000..efacc7d6 --- /dev/null +++ b/src/maisa/types/kpu_run_response.py @@ -0,0 +1,41 @@ +# File generated from our OpenAPI spec by Stainless. + +from typing import Dict, List, Optional + +from .._models import BaseModel + +__all__ = ["KpuRunResponse", "Intent"] + + +class Intent(BaseModel): + explain_steps: List[str] + """Array of the steps of the intent explained in natural language. + + This field is an ampty array if `explain_steps` param is set to `false` + """ + + intent: int + """Intent number, starting from 0.""" + + result: str + """The result of the intent.""" + + solved: bool + """Whether the intent was interpreted as solved by the KPU or not.""" + + downloadable_files: Optional[Dict[str, str]] = None + """Key-value of the files generated by the KPU.""" + + +class KpuRunResponse(BaseModel): + intents: List[Intent] + """Array of the intents executed by the KPU.""" + + result: str + """The result of the KPU execution. + + The result may be invalid if none of the intents were successful. + """ + + downloadable_files: Optional[Dict[str, str]] = None + """Key-value of the files generated by the KPU.""" diff --git a/tests/api_resources/test_kpu.py b/tests/api_resources/test_kpu.py new file mode 100644 index 00000000..93eb736a --- /dev/null +++ b/tests/api_resources/test_kpu.py @@ -0,0 +1,104 @@ +# File generated from our OpenAPI spec by Stainless. + +from __future__ import annotations + +import os +from typing import Any, cast + +import pytest + +from maisa import Maisa, AsyncMaisa +from maisa.types import KpuRunResponse +from tests.utils import assert_matches_type + +base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010") + + +class TestKpu: + parametrize = pytest.mark.parametrize("client", [False, True], indirect=True, ids=["loose", "strict"]) + + @parametrize + def test_method_run(self, client: Maisa) -> None: + kpu = client.kpu.run( + query="string", + ) + assert_matches_type(KpuRunResponse, kpu, path=["response"]) + + @parametrize + def test_method_run_with_all_params(self, client: Maisa) -> None: + kpu = client.kpu.run( + query="string", + explain_steps=True, + retries=1, + file=[b"raw file contents", b"raw file contents", b"raw file contents"], + ) + assert_matches_type(KpuRunResponse, kpu, path=["response"]) + + @parametrize + def test_raw_response_run(self, client: Maisa) -> None: + response = client.kpu.with_raw_response.run( + query="string", + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + kpu = response.parse() + assert_matches_type(KpuRunResponse, kpu, path=["response"]) + + @parametrize + def test_streaming_response_run(self, client: Maisa) -> None: + with client.kpu.with_streaming_response.run( + query="string", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + kpu = response.parse() + assert_matches_type(KpuRunResponse, kpu, path=["response"]) + + assert cast(Any, response.is_closed) is True + + +class TestAsyncKpu: + parametrize = pytest.mark.parametrize("async_client", [False, True], indirect=True, ids=["loose", "strict"]) + + @parametrize + async def test_method_run(self, async_client: AsyncMaisa) -> None: + kpu = await async_client.kpu.run( + query="string", + ) + assert_matches_type(KpuRunResponse, kpu, path=["response"]) + + @parametrize + async def test_method_run_with_all_params(self, async_client: AsyncMaisa) -> None: + kpu = await async_client.kpu.run( + query="string", + explain_steps=True, + retries=1, + file=[b"raw file contents", b"raw file contents", b"raw file contents"], + ) + assert_matches_type(KpuRunResponse, kpu, path=["response"]) + + @parametrize + async def test_raw_response_run(self, async_client: AsyncMaisa) -> None: + response = await async_client.kpu.with_raw_response.run( + query="string", + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + kpu = await response.parse() + assert_matches_type(KpuRunResponse, kpu, path=["response"]) + + @parametrize + async def test_streaming_response_run(self, async_client: AsyncMaisa) -> None: + async with async_client.kpu.with_streaming_response.run( + query="string", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + kpu = await response.parse() + assert_matches_type(KpuRunResponse, kpu, path=["response"]) + + assert cast(Any, response.is_closed) is True diff --git a/tests/sample_file.txt b/tests/sample_file.txt new file mode 100644 index 00000000..af5626b4 --- /dev/null +++ b/tests/sample_file.txt @@ -0,0 +1 @@ +Hello, world! diff --git a/tests/test_client.py b/tests/test_client.py index 5381aa73..050dd28c 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -82,7 +82,7 @@ def test_copy_default_options(self) -> None: # options that have a default are overridden correctly copied = self.client.copy(max_retries=7) assert copied.max_retries == 7 - assert self.client.max_retries == 3 + assert self.client.max_retries == 2 copied2 = copied.copy(max_retries=6) assert copied2.max_retries == 6 @@ -291,6 +291,16 @@ def test_http_client_timeout_option(self) -> None: timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore assert timeout == DEFAULT_TIMEOUT # our default + async def test_invalid_http_client(self) -> None: + with pytest.raises(TypeError, match="Invalid `http_client` arg"): + async with httpx.AsyncClient() as http_client: + Maisa( + base_url=base_url, + api_key=api_key, + _strict_response_validation=True, + http_client=cast(Any, http_client), + ) + def test_default_headers_option(self) -> None: client = Maisa( base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"} @@ -675,12 +685,12 @@ def test_parse_retry_after_header(self, remaining_retries: int, retry_after: str @mock.patch("maisa._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout) @pytest.mark.respx(base_url=base_url) def test_retrying_timeout_errors_doesnt_leak(self, respx_mock: MockRouter) -> None: - respx_mock.post("/v1/models/embeddings").mock(side_effect=httpx.TimeoutException("Test timeout error")) + respx_mock.post("/v1/capabilities/summarize").mock(side_effect=httpx.TimeoutException("Test timeout error")) with pytest.raises(APITimeoutError): self.client.post( - "/v1/models/embeddings", - body=cast(object, dict(texts=["string"])), + "/v1/capabilities/summarize", + body=cast(object, dict(text="Example long text...")), cast_to=httpx.Response, options={"headers": {RAW_RESPONSE_HEADER: "stream"}}, ) @@ -690,12 +700,12 @@ def test_retrying_timeout_errors_doesnt_leak(self, respx_mock: MockRouter) -> No @mock.patch("maisa._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout) @pytest.mark.respx(base_url=base_url) def test_retrying_status_errors_doesnt_leak(self, respx_mock: MockRouter) -> None: - respx_mock.post("/v1/models/embeddings").mock(return_value=httpx.Response(500)) + respx_mock.post("/v1/capabilities/summarize").mock(return_value=httpx.Response(500)) with pytest.raises(APIStatusError): self.client.post( - "/v1/models/embeddings", - body=cast(object, dict(texts=["string"])), + "/v1/capabilities/summarize", + body=cast(object, dict(text="Example long text...")), cast_to=httpx.Response, options={"headers": {RAW_RESPONSE_HEADER: "stream"}}, ) @@ -740,7 +750,7 @@ def test_copy_default_options(self) -> None: # options that have a default are overridden correctly copied = self.client.copy(max_retries=7) assert copied.max_retries == 7 - assert self.client.max_retries == 3 + assert self.client.max_retries == 2 copied2 = copied.copy(max_retries=6) assert copied2.max_retries == 6 @@ -951,6 +961,16 @@ async def test_http_client_timeout_option(self) -> None: timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore assert timeout == DEFAULT_TIMEOUT # our default + def test_invalid_http_client(self) -> None: + with pytest.raises(TypeError, match="Invalid `http_client` arg"): + with httpx.Client() as http_client: + AsyncMaisa( + base_url=base_url, + api_key=api_key, + _strict_response_validation=True, + http_client=cast(Any, http_client), + ) + def test_default_headers_option(self) -> None: client = AsyncMaisa( base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"} @@ -1345,12 +1365,12 @@ async def test_parse_retry_after_header(self, remaining_retries: int, retry_afte @mock.patch("maisa._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout) @pytest.mark.respx(base_url=base_url) async def test_retrying_timeout_errors_doesnt_leak(self, respx_mock: MockRouter) -> None: - respx_mock.post("/v1/models/embeddings").mock(side_effect=httpx.TimeoutException("Test timeout error")) + respx_mock.post("/v1/capabilities/summarize").mock(side_effect=httpx.TimeoutException("Test timeout error")) with pytest.raises(APITimeoutError): await self.client.post( - "/v1/models/embeddings", - body=cast(object, dict(texts=["string"])), + "/v1/capabilities/summarize", + body=cast(object, dict(text="Example long text...")), cast_to=httpx.Response, options={"headers": {RAW_RESPONSE_HEADER: "stream"}}, ) @@ -1360,12 +1380,12 @@ async def test_retrying_timeout_errors_doesnt_leak(self, respx_mock: MockRouter) @mock.patch("maisa._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout) @pytest.mark.respx(base_url=base_url) async def test_retrying_status_errors_doesnt_leak(self, respx_mock: MockRouter) -> None: - respx_mock.post("/v1/models/embeddings").mock(return_value=httpx.Response(500)) + respx_mock.post("/v1/capabilities/summarize").mock(return_value=httpx.Response(500)) with pytest.raises(APIStatusError): await self.client.post( - "/v1/models/embeddings", - body=cast(object, dict(texts=["string"])), + "/v1/capabilities/summarize", + body=cast(object, dict(text="Example long text...")), cast_to=httpx.Response, options={"headers": {RAW_RESPONSE_HEADER: "stream"}}, ) diff --git a/tests/test_transform.py b/tests/test_transform.py index 5f58e6dc..19442b4f 100644 --- a/tests/test_transform.py +++ b/tests/test_transform.py @@ -1,22 +1,50 @@ from __future__ import annotations -from typing import Any, List, Union, Iterable, Optional, cast +import io +import pathlib +from typing import Any, List, Union, TypeVar, Iterable, Optional, cast from datetime import date, datetime from typing_extensions import Required, Annotated, TypedDict import pytest -from maisa._utils import PropertyInfo, transform, parse_datetime +from maisa._types import Base64FileInput +from maisa._utils import ( + PropertyInfo, + transform as _transform, + parse_datetime, + async_transform as _async_transform, +) from maisa._compat import PYDANTIC_V2 from maisa._models import BaseModel +_T = TypeVar("_T") + +SAMPLE_FILE_PATH = pathlib.Path(__file__).parent.joinpath("sample_file.txt") + + +async def transform( + data: _T, + expected_type: object, + use_async: bool, +) -> _T: + if use_async: + return await _async_transform(data, expected_type=expected_type) + + return _transform(data, expected_type=expected_type) + + +parametrize = pytest.mark.parametrize("use_async", [False, True], ids=["sync", "async"]) + class Foo1(TypedDict): foo_bar: Annotated[str, PropertyInfo(alias="fooBar")] -def test_top_level_alias() -> None: - assert transform({"foo_bar": "hello"}, expected_type=Foo1) == {"fooBar": "hello"} +@parametrize +@pytest.mark.asyncio +async def test_top_level_alias(use_async: bool) -> None: + assert await transform({"foo_bar": "hello"}, expected_type=Foo1, use_async=use_async) == {"fooBar": "hello"} class Foo2(TypedDict): @@ -32,9 +60,11 @@ class Baz2(TypedDict): my_baz: Annotated[str, PropertyInfo(alias="myBaz")] -def test_recursive_typeddict() -> None: - assert transform({"bar": {"this_thing": 1}}, Foo2) == {"bar": {"this__thing": 1}} - assert transform({"bar": {"baz": {"my_baz": "foo"}}}, Foo2) == {"bar": {"Baz": {"myBaz": "foo"}}} +@parametrize +@pytest.mark.asyncio +async def test_recursive_typeddict(use_async: bool) -> None: + assert await transform({"bar": {"this_thing": 1}}, Foo2, use_async) == {"bar": {"this__thing": 1}} + assert await transform({"bar": {"baz": {"my_baz": "foo"}}}, Foo2, use_async) == {"bar": {"Baz": {"myBaz": "foo"}}} class Foo3(TypedDict): @@ -45,8 +75,10 @@ class Bar3(TypedDict): my_field: Annotated[str, PropertyInfo(alias="myField")] -def test_list_of_typeddict() -> None: - result = transform({"things": [{"my_field": "foo"}, {"my_field": "foo2"}]}, expected_type=Foo3) +@parametrize +@pytest.mark.asyncio +async def test_list_of_typeddict(use_async: bool) -> None: + result = await transform({"things": [{"my_field": "foo"}, {"my_field": "foo2"}]}, Foo3, use_async) assert result == {"things": [{"myField": "foo"}, {"myField": "foo2"}]} @@ -62,10 +94,14 @@ class Baz4(TypedDict): foo_baz: Annotated[str, PropertyInfo(alias="fooBaz")] -def test_union_of_typeddict() -> None: - assert transform({"foo": {"foo_bar": "bar"}}, Foo4) == {"foo": {"fooBar": "bar"}} - assert transform({"foo": {"foo_baz": "baz"}}, Foo4) == {"foo": {"fooBaz": "baz"}} - assert transform({"foo": {"foo_baz": "baz", "foo_bar": "bar"}}, Foo4) == {"foo": {"fooBaz": "baz", "fooBar": "bar"}} +@parametrize +@pytest.mark.asyncio +async def test_union_of_typeddict(use_async: bool) -> None: + assert await transform({"foo": {"foo_bar": "bar"}}, Foo4, use_async) == {"foo": {"fooBar": "bar"}} + assert await transform({"foo": {"foo_baz": "baz"}}, Foo4, use_async) == {"foo": {"fooBaz": "baz"}} + assert await transform({"foo": {"foo_baz": "baz", "foo_bar": "bar"}}, Foo4, use_async) == { + "foo": {"fooBaz": "baz", "fooBar": "bar"} + } class Foo5(TypedDict): @@ -80,9 +116,11 @@ class Baz5(TypedDict): foo_baz: Annotated[str, PropertyInfo(alias="fooBaz")] -def test_union_of_list() -> None: - assert transform({"foo": {"foo_bar": "bar"}}, Foo5) == {"FOO": {"fooBar": "bar"}} - assert transform( +@parametrize +@pytest.mark.asyncio +async def test_union_of_list(use_async: bool) -> None: + assert await transform({"foo": {"foo_bar": "bar"}}, Foo5, use_async) == {"FOO": {"fooBar": "bar"}} + assert await transform( { "foo": [ {"foo_baz": "baz"}, @@ -90,6 +128,7 @@ def test_union_of_list() -> None: ] }, Foo5, + use_async, ) == {"FOO": [{"fooBaz": "baz"}, {"fooBaz": "baz"}]} @@ -97,8 +136,10 @@ class Foo6(TypedDict): bar: Annotated[str, PropertyInfo(alias="Bar")] -def test_includes_unknown_keys() -> None: - assert transform({"bar": "bar", "baz_": {"FOO": 1}}, Foo6) == { +@parametrize +@pytest.mark.asyncio +async def test_includes_unknown_keys(use_async: bool) -> None: + assert await transform({"bar": "bar", "baz_": {"FOO": 1}}, Foo6, use_async) == { "Bar": "bar", "baz_": {"FOO": 1}, } @@ -113,9 +154,11 @@ class Bar7(TypedDict): foo: str -def test_ignores_invalid_input() -> None: - assert transform({"bar": ""}, Foo7) == {"bAr": ""} - assert transform({"foo": ""}, Foo7) == {"foo": ""} +@parametrize +@pytest.mark.asyncio +async def test_ignores_invalid_input(use_async: bool) -> None: + assert await transform({"bar": ""}, Foo7, use_async) == {"bAr": ""} + assert await transform({"foo": ""}, Foo7, use_async) == {"foo": ""} class DatetimeDict(TypedDict, total=False): @@ -134,52 +177,66 @@ class DateDict(TypedDict, total=False): foo: Annotated[date, PropertyInfo(format="iso8601")] -def test_iso8601_format() -> None: +@parametrize +@pytest.mark.asyncio +async def test_iso8601_format(use_async: bool) -> None: dt = datetime.fromisoformat("2023-02-23T14:16:36.337692+00:00") - assert transform({"foo": dt}, DatetimeDict) == {"foo": "2023-02-23T14:16:36.337692+00:00"} # type: ignore[comparison-overlap] + assert await transform({"foo": dt}, DatetimeDict, use_async) == {"foo": "2023-02-23T14:16:36.337692+00:00"} # type: ignore[comparison-overlap] dt = dt.replace(tzinfo=None) - assert transform({"foo": dt}, DatetimeDict) == {"foo": "2023-02-23T14:16:36.337692"} # type: ignore[comparison-overlap] + assert await transform({"foo": dt}, DatetimeDict, use_async) == {"foo": "2023-02-23T14:16:36.337692"} # type: ignore[comparison-overlap] - assert transform({"foo": None}, DateDict) == {"foo": None} # type: ignore[comparison-overlap] - assert transform({"foo": date.fromisoformat("2023-02-23")}, DateDict) == {"foo": "2023-02-23"} # type: ignore[comparison-overlap] + assert await transform({"foo": None}, DateDict, use_async) == {"foo": None} # type: ignore[comparison-overlap] + assert await transform({"foo": date.fromisoformat("2023-02-23")}, DateDict, use_async) == {"foo": "2023-02-23"} # type: ignore[comparison-overlap] -def test_optional_iso8601_format() -> None: +@parametrize +@pytest.mark.asyncio +async def test_optional_iso8601_format(use_async: bool) -> None: dt = datetime.fromisoformat("2023-02-23T14:16:36.337692+00:00") - assert transform({"bar": dt}, DatetimeDict) == {"bar": "2023-02-23T14:16:36.337692+00:00"} # type: ignore[comparison-overlap] + assert await transform({"bar": dt}, DatetimeDict, use_async) == {"bar": "2023-02-23T14:16:36.337692+00:00"} # type: ignore[comparison-overlap] - assert transform({"bar": None}, DatetimeDict) == {"bar": None} + assert await transform({"bar": None}, DatetimeDict, use_async) == {"bar": None} -def test_required_iso8601_format() -> None: +@parametrize +@pytest.mark.asyncio +async def test_required_iso8601_format(use_async: bool) -> None: dt = datetime.fromisoformat("2023-02-23T14:16:36.337692+00:00") - assert transform({"required": dt}, DatetimeDict) == {"required": "2023-02-23T14:16:36.337692+00:00"} # type: ignore[comparison-overlap] + assert await transform({"required": dt}, DatetimeDict, use_async) == { + "required": "2023-02-23T14:16:36.337692+00:00" + } # type: ignore[comparison-overlap] - assert transform({"required": None}, DatetimeDict) == {"required": None} + assert await transform({"required": None}, DatetimeDict, use_async) == {"required": None} -def test_union_datetime() -> None: +@parametrize +@pytest.mark.asyncio +async def test_union_datetime(use_async: bool) -> None: dt = datetime.fromisoformat("2023-02-23T14:16:36.337692+00:00") - assert transform({"union": dt}, DatetimeDict) == { # type: ignore[comparison-overlap] + assert await transform({"union": dt}, DatetimeDict, use_async) == { # type: ignore[comparison-overlap] "union": "2023-02-23T14:16:36.337692+00:00" } - assert transform({"union": "foo"}, DatetimeDict) == {"union": "foo"} + assert await transform({"union": "foo"}, DatetimeDict, use_async) == {"union": "foo"} -def test_nested_list_iso6801_format() -> None: +@parametrize +@pytest.mark.asyncio +async def test_nested_list_iso6801_format(use_async: bool) -> None: dt1 = datetime.fromisoformat("2023-02-23T14:16:36.337692+00:00") dt2 = parse_datetime("2022-01-15T06:34:23Z") - assert transform({"list_": [dt1, dt2]}, DatetimeDict) == { # type: ignore[comparison-overlap] + assert await transform({"list_": [dt1, dt2]}, DatetimeDict, use_async) == { # type: ignore[comparison-overlap] "list_": ["2023-02-23T14:16:36.337692+00:00", "2022-01-15T06:34:23+00:00"] } -def test_datetime_custom_format() -> None: +@parametrize +@pytest.mark.asyncio +async def test_datetime_custom_format(use_async: bool) -> None: dt = parse_datetime("2022-01-15T06:34:23Z") - result = transform(dt, Annotated[datetime, PropertyInfo(format="custom", format_template="%H")]) + result = await transform(dt, Annotated[datetime, PropertyInfo(format="custom", format_template="%H")], use_async) assert result == "06" # type: ignore[comparison-overlap] @@ -187,47 +244,59 @@ class DateDictWithRequiredAlias(TypedDict, total=False): required_prop: Required[Annotated[date, PropertyInfo(format="iso8601", alias="prop")]] -def test_datetime_with_alias() -> None: - assert transform({"required_prop": None}, DateDictWithRequiredAlias) == {"prop": None} # type: ignore[comparison-overlap] - assert transform({"required_prop": date.fromisoformat("2023-02-23")}, DateDictWithRequiredAlias) == { - "prop": "2023-02-23" - } # type: ignore[comparison-overlap] +@parametrize +@pytest.mark.asyncio +async def test_datetime_with_alias(use_async: bool) -> None: + assert await transform({"required_prop": None}, DateDictWithRequiredAlias, use_async) == {"prop": None} # type: ignore[comparison-overlap] + assert await transform( + {"required_prop": date.fromisoformat("2023-02-23")}, DateDictWithRequiredAlias, use_async + ) == {"prop": "2023-02-23"} # type: ignore[comparison-overlap] class MyModel(BaseModel): foo: str -def test_pydantic_model_to_dictionary() -> None: - assert transform(MyModel(foo="hi!"), Any) == {"foo": "hi!"} - assert transform(MyModel.construct(foo="hi!"), Any) == {"foo": "hi!"} +@parametrize +@pytest.mark.asyncio +async def test_pydantic_model_to_dictionary(use_async: bool) -> None: + assert await transform(MyModel(foo="hi!"), Any, use_async) == {"foo": "hi!"} + assert await transform(MyModel.construct(foo="hi!"), Any, use_async) == {"foo": "hi!"} -def test_pydantic_empty_model() -> None: - assert transform(MyModel.construct(), Any) == {} +@parametrize +@pytest.mark.asyncio +async def test_pydantic_empty_model(use_async: bool) -> None: + assert await transform(MyModel.construct(), Any, use_async) == {} -def test_pydantic_unknown_field() -> None: - assert transform(MyModel.construct(my_untyped_field=True), Any) == {"my_untyped_field": True} +@parametrize +@pytest.mark.asyncio +async def test_pydantic_unknown_field(use_async: bool) -> None: + assert await transform(MyModel.construct(my_untyped_field=True), Any, use_async) == {"my_untyped_field": True} -def test_pydantic_mismatched_types() -> None: +@parametrize +@pytest.mark.asyncio +async def test_pydantic_mismatched_types(use_async: bool) -> None: model = MyModel.construct(foo=True) if PYDANTIC_V2: with pytest.warns(UserWarning): - params = transform(model, Any) + params = await transform(model, Any, use_async) else: - params = transform(model, Any) + params = await transform(model, Any, use_async) assert params == {"foo": True} -def test_pydantic_mismatched_object_type() -> None: +@parametrize +@pytest.mark.asyncio +async def test_pydantic_mismatched_object_type(use_async: bool) -> None: model = MyModel.construct(foo=MyModel.construct(hello="world")) if PYDANTIC_V2: with pytest.warns(UserWarning): - params = transform(model, Any) + params = await transform(model, Any, use_async) else: - params = transform(model, Any) + params = await transform(model, Any, use_async) assert params == {"foo": {"hello": "world"}} @@ -235,10 +304,12 @@ class ModelNestedObjects(BaseModel): nested: MyModel -def test_pydantic_nested_objects() -> None: +@parametrize +@pytest.mark.asyncio +async def test_pydantic_nested_objects(use_async: bool) -> None: model = ModelNestedObjects.construct(nested={"foo": "stainless"}) assert isinstance(model.nested, MyModel) - assert transform(model, Any) == {"nested": {"foo": "stainless"}} + assert await transform(model, Any, use_async) == {"nested": {"foo": "stainless"}} class ModelWithDefaultField(BaseModel): @@ -247,24 +318,26 @@ class ModelWithDefaultField(BaseModel): with_str_default: str = "foo" -def test_pydantic_default_field() -> None: +@parametrize +@pytest.mark.asyncio +async def test_pydantic_default_field(use_async: bool) -> None: # should be excluded when defaults are used model = ModelWithDefaultField.construct() assert model.with_none_default is None assert model.with_str_default == "foo" - assert transform(model, Any) == {} + assert await transform(model, Any, use_async) == {} # should be included when the default value is explicitly given model = ModelWithDefaultField.construct(with_none_default=None, with_str_default="foo") assert model.with_none_default is None assert model.with_str_default == "foo" - assert transform(model, Any) == {"with_none_default": None, "with_str_default": "foo"} + assert await transform(model, Any, use_async) == {"with_none_default": None, "with_str_default": "foo"} # should be included when a non-default value is explicitly given model = ModelWithDefaultField.construct(with_none_default="bar", with_str_default="baz") assert model.with_none_default == "bar" assert model.with_str_default == "baz" - assert transform(model, Any) == {"with_none_default": "bar", "with_str_default": "baz"} + assert await transform(model, Any, use_async) == {"with_none_default": "bar", "with_str_default": "baz"} class TypedDictIterableUnion(TypedDict): @@ -279,21 +352,57 @@ class Baz8(TypedDict): foo_baz: Annotated[str, PropertyInfo(alias="fooBaz")] -def test_iterable_of_dictionaries() -> None: - assert transform({"foo": [{"foo_baz": "bar"}]}, TypedDictIterableUnion) == {"FOO": [{"fooBaz": "bar"}]} - assert cast(Any, transform({"foo": ({"foo_baz": "bar"},)}, TypedDictIterableUnion)) == {"FOO": [{"fooBaz": "bar"}]} +@parametrize +@pytest.mark.asyncio +async def test_iterable_of_dictionaries(use_async: bool) -> None: + assert await transform({"foo": [{"foo_baz": "bar"}]}, TypedDictIterableUnion, use_async) == { + "FOO": [{"fooBaz": "bar"}] + } + assert cast(Any, await transform({"foo": ({"foo_baz": "bar"},)}, TypedDictIterableUnion, use_async)) == { + "FOO": [{"fooBaz": "bar"}] + } def my_iter() -> Iterable[Baz8]: yield {"foo_baz": "hello"} yield {"foo_baz": "world"} - assert transform({"foo": my_iter()}, TypedDictIterableUnion) == {"FOO": [{"fooBaz": "hello"}, {"fooBaz": "world"}]} + assert await transform({"foo": my_iter()}, TypedDictIterableUnion, use_async) == { + "FOO": [{"fooBaz": "hello"}, {"fooBaz": "world"}] + } class TypedDictIterableUnionStr(TypedDict): foo: Annotated[Union[str, Iterable[Baz8]], PropertyInfo(alias="FOO")] -def test_iterable_union_str() -> None: - assert transform({"foo": "bar"}, TypedDictIterableUnionStr) == {"FOO": "bar"} - assert cast(Any, transform(iter([{"foo_baz": "bar"}]), Union[str, Iterable[Baz8]])) == [{"fooBaz": "bar"}] +@parametrize +@pytest.mark.asyncio +async def test_iterable_union_str(use_async: bool) -> None: + assert await transform({"foo": "bar"}, TypedDictIterableUnionStr, use_async) == {"FOO": "bar"} + assert cast(Any, await transform(iter([{"foo_baz": "bar"}]), Union[str, Iterable[Baz8]], use_async)) == [ + {"fooBaz": "bar"} + ] + + +class TypedDictBase64Input(TypedDict): + foo: Annotated[Union[str, Base64FileInput], PropertyInfo(format="base64")] + + +@parametrize +@pytest.mark.asyncio +async def test_base64_file_input(use_async: bool) -> None: + # strings are left as-is + assert await transform({"foo": "bar"}, TypedDictBase64Input, use_async) == {"foo": "bar"} + + # pathlib.Path is automatically converted to base64 + assert await transform({"foo": SAMPLE_FILE_PATH}, TypedDictBase64Input, use_async) == { + "foo": "SGVsbG8sIHdvcmxkIQo=" + } # type: ignore[comparison-overlap] + + # io instances are automatically converted to base64 + assert await transform({"foo": io.StringIO("Hello, world!")}, TypedDictBase64Input, use_async) == { + "foo": "SGVsbG8sIHdvcmxkIQ==" + } # type: ignore[comparison-overlap] + assert await transform({"foo": io.BytesIO(b"Hello, world!")}, TypedDictBase64Input, use_async) == { + "foo": "SGVsbG8sIHdvcmxkIQ==" + } # type: ignore[comparison-overlap]