diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index c6b2570e..b88cdcbc 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -14,7 +14,7 @@ jobs:
steps:
- - uses: actions/checkout@v3
+ - uses: actions/checkout@v4
- name: Install Rye
run: |
diff --git a/.github/workflows/publish-pypi.yml b/.github/workflows/publish-pypi.yml
index ed786f39..81fea1d1 100644
--- a/.github/workflows/publish-pypi.yml
+++ b/.github/workflows/publish-pypi.yml
@@ -1,6 +1,6 @@
# This workflow is triggered when a GitHub release is created.
# It can also be run manually to re-publish to PyPI in case it failed for some reason.
-# You can run this workflow by navigating to https://www.github.com/clibrain/python-sdk/actions/workflows/publish-pypi.yml
+# You can run this workflow by navigating to https://www.github.com/maisaai/python-sdk/actions/workflows/publish-pypi.yml
name: Publish PyPI
on:
workflow_dispatch:
diff --git a/.github/workflows/release-doctor.yml b/.github/workflows/release-doctor.yml
index 75712fd0..7a5fc2f2 100644
--- a/.github/workflows/release-doctor.yml
+++ b/.github/workflows/release-doctor.yml
@@ -7,7 +7,7 @@ jobs:
release_doctor:
name: release doctor
runs-on: ubuntu-latest
- if: github.repository == 'clibrain/python-sdk' && (github.event_name == 'push' || github.event_name == 'workflow_dispatch' || startsWith(github.head_ref, 'release-please') || github.head_ref == 'next')
+ if: github.repository == 'maisaai/python-sdk' && (github.event_name == 'push' || github.event_name == 'workflow_dispatch' || startsWith(github.head_ref, 'release-please') || github.head_ref == 'next')
steps:
- uses: actions/checkout@v3
diff --git a/.stats.yml b/.stats.yml
index dd473053..c2549479 100644
--- a/.stats.yml
+++ b/.stats.yml
@@ -1 +1 @@
-configured_endpoints: 14
+configured_endpoints: 15
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index 5ee857ad..54364e97 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -59,7 +59,7 @@ If you’d like to use the repository from source, you can either install from g
To install via git:
```bash
-pip install git+ssh://git@github.com:clibrain/python-sdk.git
+pip install git+ssh://git@github.com/maisaai/python-sdk.git
```
Alternatively, you can build from source and install the wheel file:
@@ -82,7 +82,7 @@ pip install ./path-to-wheel-file.whl
## Running tests
-Most tests will require you to [setup a mock server](https://github.com/stoplightio/prism) against the OpenAPI spec to run the tests.
+Most tests require you to [set up a mock server](https://github.com/stoplightio/prism) against the OpenAPI spec to run the tests.
```bash
# you will need npm installed
@@ -117,7 +117,7 @@ the changes aren't made through the automated pipeline, you may want to make rel
### Publish with a GitHub workflow
-You can release to package managers by using [the `Publish PyPI` GitHub action](https://www.github.com/clibrain/python-sdk/actions/workflows/publish-pypi.yml). This will require a setup organization or repository secret to be set up.
+You can release to package managers by using [the `Publish PyPI` GitHub action](https://www.github.com/maisaai/python-sdk/actions/workflows/publish-pypi.yml). This requires a setup organization or repository secret to be set up.
### Publish manually
diff --git a/README.md b/README.md
index 7bdb10c9..78f19d95 100644
--- a/README.md
+++ b/README.md
@@ -8,11 +8,12 @@ and offers both synchronous and asynchronous clients powered by [httpx](https://
## Documentation
-The REST API documentation can be found [on maisa.ai](https://maisa.ai/). The full API of this library can be found in [api.md](api.md).
+The REST API documentation can be found [on docs.maisa.ai](https://docs.maisa.ai/). The full API of this library can be found in [api.md](api.md).
## Installation
```sh
+# install from PyPI
pip install --pre maisa
```
@@ -29,10 +30,10 @@ client = Maisa(
api_key=os.environ.get("MAISA_API_KEY"),
)
-embeddings = client.models.embeddings.create(
- texts=["string"],
+text_summary = client.capabilities.summarize(
+ text="Example long text...",
)
-print(embeddings.embeddings)
+print(text_summary.summary)
```
While you can provide an `api_key` keyword argument,
@@ -56,10 +57,10 @@ client = AsyncMaisa(
async def main() -> None:
- embeddings = await client.models.embeddings.create(
- texts=["string"],
+ text_summary = await client.capabilities.summarize(
+ text="Example long text...",
)
- print(embeddings.embeddings)
+ print(text_summary.summary)
asyncio.run(main())
@@ -92,8 +93,8 @@ from maisa import Maisa
client = Maisa()
try:
- client.models.embeddings.create(
- texts=["string"],
+ client.capabilities.summarize(
+ text="Example long text...",
)
except maisa.APIConnectionError as e:
print("The server could not be reached")
@@ -121,7 +122,7 @@ Error codes are as followed:
### Retries
-Certain errors are automatically retried 3 times by default, with a short exponential backoff.
+Certain errors are automatically retried 2 times by default, with a short exponential backoff.
Connection errors (for example, due to a network connectivity problem), 408 Request Timeout, 409 Conflict,
429 Rate Limit, and >=500 Internal errors are all retried by default.
@@ -137,8 +138,8 @@ client = Maisa(
)
# Or, configure per-request:
-client.with_options(max_retries=5).models.embeddings.create(
- texts=["string"],
+client.with_options(max_retries=5).capabilities.summarize(
+ text="Example long text...",
)
```
@@ -162,8 +163,8 @@ client = Maisa(
)
# Override per-request:
-client.with_options(timeout=5 * 1000).models.embeddings.create(
- texts=["string"],
+client.with_options(timeout=5 * 1000).capabilities.summarize(
+ text="Example long text...",
)
```
@@ -203,18 +204,18 @@ The "raw" Response object can be accessed by prefixing `.with_raw_response.` to
from maisa import Maisa
client = Maisa()
-response = client.models.embeddings.with_raw_response.create(
- texts=["string"],
+response = client.capabilities.with_raw_response.summarize(
+ text="Example long text...",
)
print(response.headers.get('X-My-Header'))
-embedding = response.parse() # get the object that `models.embeddings.create()` would have returned
-print(embedding.embeddings)
+capability = response.parse() # get the object that `capabilities.summarize()` would have returned
+print(capability.summary)
```
-These methods return an [`APIResponse`](https://github.com/clibrain/python-sdk/tree/main/src/maisa/_response.py) object.
+These methods return an [`APIResponse`](https://github.com/maisaai/python-sdk/tree/main/src/maisa/_response.py) object.
-The async client returns an [`AsyncAPIResponse`](https://github.com/clibrain/python-sdk/tree/main/src/maisa/_response.py) with the same structure, the only difference being `await`able methods for reading the response content.
+The async client returns an [`AsyncAPIResponse`](https://github.com/maisaai/python-sdk/tree/main/src/maisa/_response.py) with the same structure, the only difference being `await`able methods for reading the response content.
#### `.with_streaming_response`
@@ -223,8 +224,8 @@ The above interface eagerly reads the full response body when you make the reque
To stream the response body, use `.with_streaming_response` instead, which requires a context manager and only reads the response body once you call `.read()`, `.text()`, `.json()`, `.iter_bytes()`, `.iter_text()`, `.iter_lines()` or `.parse()`. In the async client, these are async methods.
```python
-with client.models.embeddings.with_streaming_response.create(
- texts=["string"],
+with client.capabilities.with_streaming_response.summarize(
+ text="Example long text...",
) as response:
print(response.headers.get("X-My-Header"))
@@ -270,7 +271,7 @@ This package generally follows [SemVer](https://semver.org/spec/v2.0.0.html) con
We take backwards-compatibility seriously and work hard to ensure you can rely on a smooth upgrade experience.
-We are keen for your feedback; please open an [issue](https://www.github.com/clibrain/python-sdk/issues) with questions, bugs, or suggestions.
+We are keen for your feedback; please open an [issue](https://www.github.com/maisaai/python-sdk/issues) with questions, bugs, or suggestions.
## Requirements
diff --git a/api.md b/api.md
index 42c1031b..39b91ab5 100644
--- a/api.md
+++ b/api.md
@@ -46,6 +46,18 @@ Methods:
- client.models.rerank.create(\*\*params) -> Rerank
+# Kpu
+
+Types:
+
+```python
+from maisa.types import KpuRunResponse
+```
+
+Methods:
+
+- client.kpu.run(\*\*params) -> KpuRunResponse
+
# FileInterpreter
## FromPdf
diff --git a/pyproject.toml b/pyproject.toml
index beccaebc..fe875f86 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -39,8 +39,8 @@ classifiers = [
[project.urls]
-Homepage = "https://github.com/clibrain/python-sdk"
-Repository = "https://github.com/clibrain/python-sdk"
+Homepage = "https://github.com/maisaai/python-sdk"
+Repository = "https://github.com/maisaai/python-sdk"
diff --git a/requirements-dev.lock b/requirements-dev.lock
index d26251a3..204fcb08 100644
--- a/requirements-dev.lock
+++ b/requirements-dev.lock
@@ -63,7 +63,7 @@ pydantic==2.4.2
# via maisa
pydantic-core==2.10.1
# via pydantic
-pyright==1.1.332
+pyright==1.1.351
pytest==7.1.1
# via pytest-asyncio
pytest-asyncio==0.21.1
diff --git a/src/maisa/_base_client.py b/src/maisa/_base_client.py
index 0b5ece2c..2b3a1f98 100644
--- a/src/maisa/_base_client.py
+++ b/src/maisa/_base_client.py
@@ -79,7 +79,7 @@
RAW_RESPONSE_HEADER,
OVERRIDE_CAST_TO_HEADER,
)
-from ._streaming import Stream, AsyncStream
+from ._streaming import Stream, SSEDecoder, AsyncStream, SSEBytesDecoder
from ._exceptions import (
APIStatusError,
APITimeoutError,
@@ -430,6 +430,9 @@ def _prepare_url(self, url: str) -> URL:
return merge_url
+ def _make_sse_decoder(self) -> SSEDecoder | SSEBytesDecoder:
+ return SSEDecoder()
+
def _build_request(
self,
options: FinalRequestOptions,
@@ -776,6 +779,11 @@ def __init__(
else:
timeout = DEFAULT_TIMEOUT
+ if http_client is not None and not isinstance(http_client, httpx.Client): # pyright: ignore[reportUnnecessaryIsInstance]
+ raise TypeError(
+ f"Invalid `http_client` argument; Expected an instance of `httpx.Client` but got {type(http_client)}"
+ )
+
super().__init__(
version=version,
limits=limits,
@@ -1305,6 +1313,11 @@ def __init__(
else:
timeout = DEFAULT_TIMEOUT
+ if http_client is not None and not isinstance(http_client, httpx.AsyncClient): # pyright: ignore[reportUnnecessaryIsInstance]
+ raise TypeError(
+ f"Invalid `http_client` argument; Expected an instance of `httpx.AsyncClient` but got {type(http_client)}"
+ )
+
super().__init__(
version=version,
base_url=base_url,
diff --git a/src/maisa/_client.py b/src/maisa/_client.py
index b6f06497..11d68726 100644
--- a/src/maisa/_client.py
+++ b/src/maisa/_client.py
@@ -48,6 +48,7 @@
class Maisa(SyncAPIClient):
capabilities: resources.Capabilities
models: resources.Models
+ kpu: resources.Kpu
file_interpreter: resources.FileInterpreter
mainet: resources.Mainet
with_raw_response: MaisaWithRawResponse
@@ -107,6 +108,7 @@ def __init__(
self.capabilities = resources.Capabilities(self)
self.models = resources.Models(self)
+ self.kpu = resources.Kpu(self)
self.file_interpreter = resources.FileInterpreter(self)
self.mainet = resources.Mainet(self)
self.with_raw_response = MaisaWithRawResponse(self)
@@ -220,6 +222,7 @@ def _make_status_error(
class AsyncMaisa(AsyncAPIClient):
capabilities: resources.AsyncCapabilities
models: resources.AsyncModels
+ kpu: resources.AsyncKpu
file_interpreter: resources.AsyncFileInterpreter
mainet: resources.AsyncMainet
with_raw_response: AsyncMaisaWithRawResponse
@@ -279,6 +282,7 @@ def __init__(
self.capabilities = resources.AsyncCapabilities(self)
self.models = resources.AsyncModels(self)
+ self.kpu = resources.AsyncKpu(self)
self.file_interpreter = resources.AsyncFileInterpreter(self)
self.mainet = resources.AsyncMainet(self)
self.with_raw_response = AsyncMaisaWithRawResponse(self)
@@ -393,6 +397,7 @@ class MaisaWithRawResponse:
def __init__(self, client: Maisa) -> None:
self.capabilities = resources.CapabilitiesWithRawResponse(client.capabilities)
self.models = resources.ModelsWithRawResponse(client.models)
+ self.kpu = resources.KpuWithRawResponse(client.kpu)
self.file_interpreter = resources.FileInterpreterWithRawResponse(client.file_interpreter)
self.mainet = resources.MainetWithRawResponse(client.mainet)
@@ -401,6 +406,7 @@ class AsyncMaisaWithRawResponse:
def __init__(self, client: AsyncMaisa) -> None:
self.capabilities = resources.AsyncCapabilitiesWithRawResponse(client.capabilities)
self.models = resources.AsyncModelsWithRawResponse(client.models)
+ self.kpu = resources.AsyncKpuWithRawResponse(client.kpu)
self.file_interpreter = resources.AsyncFileInterpreterWithRawResponse(client.file_interpreter)
self.mainet = resources.AsyncMainetWithRawResponse(client.mainet)
@@ -409,6 +415,7 @@ class MaisaWithStreamedResponse:
def __init__(self, client: Maisa) -> None:
self.capabilities = resources.CapabilitiesWithStreamingResponse(client.capabilities)
self.models = resources.ModelsWithStreamingResponse(client.models)
+ self.kpu = resources.KpuWithStreamingResponse(client.kpu)
self.file_interpreter = resources.FileInterpreterWithStreamingResponse(client.file_interpreter)
self.mainet = resources.MainetWithStreamingResponse(client.mainet)
@@ -417,6 +424,7 @@ class AsyncMaisaWithStreamedResponse:
def __init__(self, client: AsyncMaisa) -> None:
self.capabilities = resources.AsyncCapabilitiesWithStreamingResponse(client.capabilities)
self.models = resources.AsyncModelsWithStreamingResponse(client.models)
+ self.kpu = resources.AsyncKpuWithStreamingResponse(client.kpu)
self.file_interpreter = resources.AsyncFileInterpreterWithStreamingResponse(client.file_interpreter)
self.mainet = resources.AsyncMainetWithStreamingResponse(client.mainet)
diff --git a/src/maisa/_constants.py b/src/maisa/_constants.py
index 036f4f77..bf15141a 100644
--- a/src/maisa/_constants.py
+++ b/src/maisa/_constants.py
@@ -7,7 +7,7 @@
# default timeout is 1 minute
DEFAULT_TIMEOUT = httpx.Timeout(timeout=60.0, connect=5.0)
-DEFAULT_MAX_RETRIES = 3
+DEFAULT_MAX_RETRIES = 2
DEFAULT_LIMITS = httpx.Limits(max_connections=100, max_keepalive_connections=20)
INITIAL_RETRY_DELAY = 0.5
diff --git a/src/maisa/_files.py b/src/maisa/_files.py
index b6e8af8b..0d2022ae 100644
--- a/src/maisa/_files.py
+++ b/src/maisa/_files.py
@@ -13,12 +13,17 @@
FileContent,
RequestFiles,
HttpxFileTypes,
+ Base64FileInput,
HttpxFileContent,
HttpxRequestFiles,
)
from ._utils import is_tuple_t, is_mapping_t, is_sequence_t
+def is_base64_file_input(obj: object) -> TypeGuard[Base64FileInput]:
+ return isinstance(obj, io.IOBase) or isinstance(obj, os.PathLike)
+
+
def is_file_content(obj: object) -> TypeGuard[FileContent]:
return (
isinstance(obj, bytes) or isinstance(obj, tuple) or isinstance(obj, io.IOBase) or isinstance(obj, os.PathLike)
diff --git a/src/maisa/_models.py b/src/maisa/_models.py
index 48d5624f..81089149 100644
--- a/src/maisa/_models.py
+++ b/src/maisa/_models.py
@@ -283,7 +283,7 @@ def construct_type(*, value: object, type_: type) -> object:
if is_union(origin):
try:
- return validate_type(type_=type_, value=value)
+ return validate_type(type_=cast("type[object]", type_), value=value)
except Exception:
pass
diff --git a/src/maisa/_resource.py b/src/maisa/_resource.py
index 647a61a7..ee940c3d 100644
--- a/src/maisa/_resource.py
+++ b/src/maisa/_resource.py
@@ -3,9 +3,10 @@
from __future__ import annotations
import time
-import asyncio
from typing import TYPE_CHECKING
+import anyio
+
if TYPE_CHECKING:
from ._client import Maisa, AsyncMaisa
@@ -39,4 +40,4 @@ def __init__(self, client: AsyncMaisa) -> None:
self._get_api_list = client.get_api_list
async def _sleep(self, seconds: float) -> None:
- await asyncio.sleep(seconds)
+ await anyio.sleep(seconds)
diff --git a/src/maisa/_streaming.py b/src/maisa/_streaming.py
index 502fceaf..0b48ef2c 100644
--- a/src/maisa/_streaming.py
+++ b/src/maisa/_streaming.py
@@ -5,7 +5,7 @@
import inspect
from types import TracebackType
from typing import TYPE_CHECKING, Any, Generic, TypeVar, Iterator, AsyncIterator, cast
-from typing_extensions import Self, TypeGuard, override, get_origin
+from typing_extensions import Self, Protocol, TypeGuard, override, get_origin, runtime_checkable
import httpx
@@ -23,6 +23,8 @@ class Stream(Generic[_T]):
response: httpx.Response
+ _decoder: SSEDecoder | SSEBytesDecoder
+
def __init__(
self,
*,
@@ -33,7 +35,7 @@ def __init__(
self.response = response
self._cast_to = cast_to
self._client = client
- self._decoder = SSEDecoder()
+ self._decoder = client._make_sse_decoder()
self._iterator = self.__stream__()
def __next__(self) -> _T:
@@ -44,7 +46,10 @@ def __iter__(self) -> Iterator[_T]:
yield item
def _iter_events(self) -> Iterator[ServerSentEvent]:
- yield from self._decoder.iter(self.response.iter_lines())
+ if isinstance(self._decoder, SSEBytesDecoder):
+ yield from self._decoder.iter_bytes(self.response.iter_bytes())
+ else:
+ yield from self._decoder.iter(self.response.iter_lines())
def __stream__(self) -> Iterator[_T]:
cast_to = cast(Any, self._cast_to)
@@ -84,6 +89,8 @@ class AsyncStream(Generic[_T]):
response: httpx.Response
+ _decoder: SSEDecoder | SSEBytesDecoder
+
def __init__(
self,
*,
@@ -94,7 +101,7 @@ def __init__(
self.response = response
self._cast_to = cast_to
self._client = client
- self._decoder = SSEDecoder()
+ self._decoder = client._make_sse_decoder()
self._iterator = self.__stream__()
async def __anext__(self) -> _T:
@@ -105,8 +112,12 @@ async def __aiter__(self) -> AsyncIterator[_T]:
yield item
async def _iter_events(self) -> AsyncIterator[ServerSentEvent]:
- async for sse in self._decoder.aiter(self.response.aiter_lines()):
- yield sse
+ if isinstance(self._decoder, SSEBytesDecoder):
+ async for sse in self._decoder.aiter_bytes(self.response.aiter_bytes()):
+ yield sse
+ else:
+ async for sse in self._decoder.aiter(self.response.aiter_lines()):
+ yield sse
async def __stream__(self) -> AsyncIterator[_T]:
cast_to = cast(Any, self._cast_to)
@@ -259,6 +270,17 @@ def decode(self, line: str) -> ServerSentEvent | None:
return None
+@runtime_checkable
+class SSEBytesDecoder(Protocol):
+ def iter_bytes(self, iterator: Iterator[bytes]) -> Iterator[ServerSentEvent]:
+ """Given an iterator that yields raw binary data, iterate over it & yield every event encountered"""
+ ...
+
+ def aiter_bytes(self, iterator: AsyncIterator[bytes]) -> AsyncIterator[ServerSentEvent]:
+ """Given an async iterator that yields raw binary data, iterate over it & yield every event encountered"""
+ ...
+
+
def is_stream_class_type(typ: type) -> TypeGuard[type[Stream[object]] | type[AsyncStream[object]]]:
"""TypeGuard for determining whether or not the given type is a subclass of `Stream` / `AsyncStream`"""
origin = get_origin(typ) or typ
diff --git a/src/maisa/_types.py b/src/maisa/_types.py
index cad49dde..94a83c8a 100644
--- a/src/maisa/_types.py
+++ b/src/maisa/_types.py
@@ -40,8 +40,10 @@
ProxiesDict = Dict["str | URL", Union[None, str, URL, Proxy]]
ProxiesTypes = Union[str, Proxy, ProxiesDict]
if TYPE_CHECKING:
+ Base64FileInput = Union[IO[bytes], PathLike[str]]
FileContent = Union[IO[bytes], bytes, PathLike[str]]
else:
+ Base64FileInput = Union[IO[bytes], PathLike]
FileContent = Union[IO[bytes], bytes, PathLike] # PathLike is not subscriptable in Python 3.8.
FileTypes = Union[
# file (or bytes)
diff --git a/src/maisa/_utils/__init__.py b/src/maisa/_utils/__init__.py
index b5790a87..56978941 100644
--- a/src/maisa/_utils/__init__.py
+++ b/src/maisa/_utils/__init__.py
@@ -44,5 +44,7 @@
from ._transform import (
PropertyInfo as PropertyInfo,
transform as transform,
+ async_transform as async_transform,
maybe_transform as maybe_transform,
+ async_maybe_transform as async_maybe_transform,
)
diff --git a/src/maisa/_utils/_proxy.py b/src/maisa/_utils/_proxy.py
index 6f05efcd..b9c12dc3 100644
--- a/src/maisa/_utils/_proxy.py
+++ b/src/maisa/_utils/_proxy.py
@@ -45,7 +45,7 @@ def __dir__(self) -> Iterable[str]:
@property # type: ignore
@override
- def __class__(self) -> type:
+ def __class__(self) -> type: # pyright: ignore
proxied = self.__get_proxied__()
if issubclass(type(proxied), LazyProxy):
return type(proxied)
diff --git a/src/maisa/_utils/_transform.py b/src/maisa/_utils/_transform.py
index 2cb7726c..1bd1330c 100644
--- a/src/maisa/_utils/_transform.py
+++ b/src/maisa/_utils/_transform.py
@@ -1,9 +1,13 @@
from __future__ import annotations
+import io
+import base64
+import pathlib
from typing import Any, Mapping, TypeVar, cast
from datetime import date, datetime
from typing_extensions import Literal, get_args, override, get_type_hints
+import anyio
import pydantic
from ._utils import (
@@ -11,6 +15,7 @@
is_mapping,
is_iterable,
)
+from .._files import is_base64_file_input
from ._typing import (
is_list_type,
is_union_type,
@@ -29,7 +34,7 @@
# TODO: ensure works correctly with forward references in all cases
-PropertyFormat = Literal["iso8601", "custom"]
+PropertyFormat = Literal["iso8601", "base64", "custom"]
class PropertyInfo:
@@ -180,11 +185,7 @@ def _transform_recursive(
if isinstance(data, pydantic.BaseModel):
return model_dump(data, exclude_unset=True)
- return _transform_value(data, annotation)
-
-
-def _transform_value(data: object, type_: type) -> object:
- annotated_type = _get_annotated_type(type_)
+ annotated_type = _get_annotated_type(annotation)
if annotated_type is None:
return data
@@ -205,6 +206,22 @@ def _format_data(data: object, format_: PropertyFormat, format_template: str | N
if format_ == "custom" and format_template is not None:
return data.strftime(format_template)
+ if format_ == "base64" and is_base64_file_input(data):
+ binary: str | bytes | None = None
+
+ if isinstance(data, pathlib.Path):
+ binary = data.read_bytes()
+ elif isinstance(data, io.IOBase):
+ binary = data.read()
+
+ if isinstance(binary, str): # type: ignore[unreachable]
+ binary = binary.encode()
+
+ if not isinstance(binary, bytes):
+ raise RuntimeError(f"Could not read bytes from {data}; Received {type(binary)}")
+
+ return base64.b64encode(binary).decode("ascii")
+
return data
@@ -222,3 +239,141 @@ def _transform_typeddict(
else:
result[_maybe_transform_key(key, type_)] = _transform_recursive(value, annotation=type_)
return result
+
+
+async def async_maybe_transform(
+ data: object,
+ expected_type: object,
+) -> Any | None:
+ """Wrapper over `async_transform()` that allows `None` to be passed.
+
+ See `async_transform()` for more details.
+ """
+ if data is None:
+ return None
+ return await async_transform(data, expected_type)
+
+
+async def async_transform(
+ data: _T,
+ expected_type: object,
+) -> _T:
+ """Transform dictionaries based off of type information from the given type, for example:
+
+ ```py
+ class Params(TypedDict, total=False):
+ card_id: Required[Annotated[str, PropertyInfo(alias="cardID")]]
+
+
+ transformed = transform({"card_id": ""}, Params)
+ # {'cardID': ''}
+ ```
+
+ Any keys / data that does not have type information given will be included as is.
+
+ It should be noted that the transformations that this function does are not represented in the type system.
+ """
+ transformed = await _async_transform_recursive(data, annotation=cast(type, expected_type))
+ return cast(_T, transformed)
+
+
+async def _async_transform_recursive(
+ data: object,
+ *,
+ annotation: type,
+ inner_type: type | None = None,
+) -> object:
+ """Transform the given data against the expected type.
+
+ Args:
+ annotation: The direct type annotation given to the particular piece of data.
+ This may or may not be wrapped in metadata types, e.g. `Required[T]`, `Annotated[T, ...]` etc
+
+ inner_type: If applicable, this is the "inside" type. This is useful in certain cases where the outside type
+ is a container type such as `List[T]`. In that case `inner_type` should be set to `T` so that each entry in
+ the list can be transformed using the metadata from the container type.
+
+ Defaults to the same value as the `annotation` argument.
+ """
+ if inner_type is None:
+ inner_type = annotation
+
+ stripped_type = strip_annotated_type(inner_type)
+ if is_typeddict(stripped_type) and is_mapping(data):
+ return await _async_transform_typeddict(data, stripped_type)
+
+ if (
+ # List[T]
+ (is_list_type(stripped_type) and is_list(data))
+ # Iterable[T]
+ or (is_iterable_type(stripped_type) and is_iterable(data) and not isinstance(data, str))
+ ):
+ inner_type = extract_type_arg(stripped_type, 0)
+ return [await _async_transform_recursive(d, annotation=annotation, inner_type=inner_type) for d in data]
+
+ if is_union_type(stripped_type):
+ # For union types we run the transformation against all subtypes to ensure that everything is transformed.
+ #
+ # TODO: there may be edge cases where the same normalized field name will transform to two different names
+ # in different subtypes.
+ for subtype in get_args(stripped_type):
+ data = await _async_transform_recursive(data, annotation=annotation, inner_type=subtype)
+ return data
+
+ if isinstance(data, pydantic.BaseModel):
+ return model_dump(data, exclude_unset=True)
+
+ annotated_type = _get_annotated_type(annotation)
+ if annotated_type is None:
+ return data
+
+ # ignore the first argument as it is the actual type
+ annotations = get_args(annotated_type)[1:]
+ for annotation in annotations:
+ if isinstance(annotation, PropertyInfo) and annotation.format is not None:
+ return await _async_format_data(data, annotation.format, annotation.format_template)
+
+ return data
+
+
+async def _async_format_data(data: object, format_: PropertyFormat, format_template: str | None) -> object:
+ if isinstance(data, (date, datetime)):
+ if format_ == "iso8601":
+ return data.isoformat()
+
+ if format_ == "custom" and format_template is not None:
+ return data.strftime(format_template)
+
+ if format_ == "base64" and is_base64_file_input(data):
+ binary: str | bytes | None = None
+
+ if isinstance(data, pathlib.Path):
+ binary = await anyio.Path(data).read_bytes()
+ elif isinstance(data, io.IOBase):
+ binary = data.read()
+
+ if isinstance(binary, str): # type: ignore[unreachable]
+ binary = binary.encode()
+
+ if not isinstance(binary, bytes):
+ raise RuntimeError(f"Could not read bytes from {data}; Received {type(binary)}")
+
+ return base64.b64encode(binary).decode("ascii")
+
+ return data
+
+
+async def _async_transform_typeddict(
+ data: Mapping[str, object],
+ expected_type: type,
+) -> Mapping[str, object]:
+ result: dict[str, object] = {}
+ annotations = get_type_hints(expected_type, include_extras=True)
+ for key, value in data.items():
+ type_ = annotations.get(key)
+ if type_ is None:
+ # we do not have a type annotation for this field, leave it as is
+ result[key] = value
+ else:
+ result[_maybe_transform_key(key, type_)] = await _async_transform_recursive(value, annotation=type_)
+ return result
diff --git a/src/maisa/resources/__init__.py b/src/maisa/resources/__init__.py
index 6666c4dd..cd1034e9 100644
--- a/src/maisa/resources/__init__.py
+++ b/src/maisa/resources/__init__.py
@@ -1,5 +1,13 @@
# File generated from our OpenAPI spec by Stainless.
+from .kpu import (
+ Kpu,
+ AsyncKpu,
+ KpuWithRawResponse,
+ AsyncKpuWithRawResponse,
+ KpuWithStreamingResponse,
+ AsyncKpuWithStreamingResponse,
+)
from .mainet import (
Mainet,
AsyncMainet,
@@ -46,6 +54,12 @@
"AsyncModelsWithRawResponse",
"ModelsWithStreamingResponse",
"AsyncModelsWithStreamingResponse",
+ "Kpu",
+ "AsyncKpu",
+ "KpuWithRawResponse",
+ "AsyncKpuWithRawResponse",
+ "KpuWithStreamingResponse",
+ "AsyncKpuWithStreamingResponse",
"FileInterpreter",
"AsyncFileInterpreter",
"FileInterpreterWithRawResponse",
diff --git a/src/maisa/resources/capabilities/capabilities.py b/src/maisa/resources/capabilities/capabilities.py
index 313d73d1..84eb7f74 100644
--- a/src/maisa/resources/capabilities/capabilities.py
+++ b/src/maisa/resources/capabilities/capabilities.py
@@ -17,7 +17,10 @@
)
from ...types import capability_compare_params, capability_extract_params, capability_summarize_params
from ..._types import NOT_GIVEN, Body, Query, Headers, NotGiven
-from ..._utils import maybe_transform
+from ..._utils import (
+ maybe_transform,
+ async_maybe_transform,
+)
from ..._compat import cached_property
from ..._resource import SyncAPIResource, AsyncAPIResource
from ..._response import (
@@ -265,7 +268,7 @@ async def compare(
"""
return await self._post(
"/v1/capabilities/compare",
- body=maybe_transform(
+ body=await async_maybe_transform(
{
"text1": text1,
"text2": text2,
@@ -317,7 +320,7 @@ async def extract(
"""
return await self._post(
"/v1/capabilities/extract",
- body=maybe_transform(
+ body=await async_maybe_transform(
{
"text": text,
"variables": variables,
@@ -373,7 +376,7 @@ async def summarize(
"""
return await self._post(
"/v1/capabilities/summarize",
- body=maybe_transform(
+ body=await async_maybe_transform(
{
"text": text,
"format": format,
diff --git a/src/maisa/resources/capabilities/media.py b/src/maisa/resources/capabilities/media.py
index 662bbb8b..9a16c043 100644
--- a/src/maisa/resources/capabilities/media.py
+++ b/src/maisa/resources/capabilities/media.py
@@ -8,7 +8,12 @@
import httpx
from ..._types import NOT_GIVEN, Body, Query, Headers, NotGiven, FileTypes
-from ..._utils import extract_files, maybe_transform, deepcopy_minimal
+from ..._utils import (
+ extract_files,
+ maybe_transform,
+ deepcopy_minimal,
+ async_maybe_transform,
+)
from ..._compat import cached_property
from ..._resource import SyncAPIResource, AsyncAPIResource
from ..._response import (
@@ -414,7 +419,7 @@ async def compare(
extra_headers = {"Content-Type": "multipart/form-data", **(extra_headers or {})}
return await self._post(
"/v1/capabilities/compare/media",
- body=maybe_transform(body, media_compare_params.MediaCompareParams),
+ body=await async_maybe_transform(body, media_compare_params.MediaCompareParams),
files=files,
options=make_request_options(
extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
@@ -515,7 +520,7 @@ async def extract(
extra_headers = {"Content-Type": "multipart/form-data", **(extra_headers or {})}
return await self._post(
"/v1/capabilities/extract/media",
- body=maybe_transform(body, media_extract_params.MediaExtractParams),
+ body=await async_maybe_transform(body, media_extract_params.MediaExtractParams),
files=files,
options=make_request_options(
extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
@@ -579,7 +584,7 @@ async def summarize(
extra_headers = {"Content-Type": "multipart/form-data", **(extra_headers or {})}
return await self._post(
"/v1/capabilities/summarize/media",
- body=maybe_transform(body, media_summarize_params.MediaSummarizeParams),
+ body=await async_maybe_transform(body, media_summarize_params.MediaSummarizeParams),
files=files,
options=make_request_options(
extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
diff --git a/src/maisa/resources/file_interpreter/from_audio.py b/src/maisa/resources/file_interpreter/from_audio.py
index 04715f1b..9b93ef88 100644
--- a/src/maisa/resources/file_interpreter/from_audio.py
+++ b/src/maisa/resources/file_interpreter/from_audio.py
@@ -7,7 +7,12 @@
import httpx
from ..._types import NOT_GIVEN, Body, Query, Headers, NotGiven, FileTypes
-from ..._utils import extract_files, maybe_transform, deepcopy_minimal
+from ..._utils import (
+ extract_files,
+ maybe_transform,
+ deepcopy_minimal,
+ async_maybe_transform,
+)
from ..._compat import cached_property
from ..._resource import SyncAPIResource, AsyncAPIResource
from ..._response import (
@@ -115,7 +120,7 @@ async def create(
extra_headers = {"Content-Type": "multipart/form-data", **(extra_headers or {})}
return await self._post(
"/v1/file-interpreter/from-audio",
- body=maybe_transform(body, from_audio_create_params.FromAudioCreateParams),
+ body=await async_maybe_transform(body, from_audio_create_params.FromAudioCreateParams),
files=files,
options=make_request_options(
extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
diff --git a/src/maisa/resources/file_interpreter/from_docx.py b/src/maisa/resources/file_interpreter/from_docx.py
index 8ee42a69..a2dd8103 100644
--- a/src/maisa/resources/file_interpreter/from_docx.py
+++ b/src/maisa/resources/file_interpreter/from_docx.py
@@ -7,7 +7,12 @@
import httpx
from ..._types import NOT_GIVEN, Body, Query, Headers, NotGiven, FileTypes
-from ..._utils import extract_files, maybe_transform, deepcopy_minimal
+from ..._utils import (
+ extract_files,
+ maybe_transform,
+ deepcopy_minimal,
+ async_maybe_transform,
+)
from ..._compat import cached_property
from ..._resource import SyncAPIResource, AsyncAPIResource
from ..._response import (
@@ -115,7 +120,7 @@ async def create(
extra_headers = {"Content-Type": "multipart/form-data", **(extra_headers or {})}
return await self._post(
"/v1/file-interpreter/from-docx",
- body=maybe_transform(body, from_docx_create_params.FromDocxCreateParams),
+ body=await async_maybe_transform(body, from_docx_create_params.FromDocxCreateParams),
files=files,
options=make_request_options(
extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
diff --git a/src/maisa/resources/file_interpreter/from_html.py b/src/maisa/resources/file_interpreter/from_html.py
index a653bd62..c01b2522 100644
--- a/src/maisa/resources/file_interpreter/from_html.py
+++ b/src/maisa/resources/file_interpreter/from_html.py
@@ -7,7 +7,12 @@
import httpx
from ..._types import NOT_GIVEN, Body, Query, Headers, NotGiven, FileTypes
-from ..._utils import extract_files, maybe_transform, deepcopy_minimal
+from ..._utils import (
+ extract_files,
+ maybe_transform,
+ deepcopy_minimal,
+ async_maybe_transform,
+)
from ..._compat import cached_property
from ..._resource import SyncAPIResource, AsyncAPIResource
from ..._response import (
@@ -115,7 +120,7 @@ async def create(
extra_headers = {"Content-Type": "multipart/form-data", **(extra_headers or {})}
return await self._post(
"/v1/file-interpreter/from-html",
- body=maybe_transform(body, from_html_create_params.FromHTMLCreateParams),
+ body=await async_maybe_transform(body, from_html_create_params.FromHTMLCreateParams),
files=files,
options=make_request_options(
extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
diff --git a/src/maisa/resources/file_interpreter/from_image.py b/src/maisa/resources/file_interpreter/from_image.py
index 2a20441e..b640ee22 100644
--- a/src/maisa/resources/file_interpreter/from_image.py
+++ b/src/maisa/resources/file_interpreter/from_image.py
@@ -7,7 +7,12 @@
import httpx
from ..._types import NOT_GIVEN, Body, Query, Headers, NotGiven, FileTypes
-from ..._utils import extract_files, maybe_transform, deepcopy_minimal
+from ..._utils import (
+ extract_files,
+ maybe_transform,
+ deepcopy_minimal,
+ async_maybe_transform,
+)
from ..._compat import cached_property
from ..._resource import SyncAPIResource, AsyncAPIResource
from ..._response import (
@@ -115,7 +120,7 @@ async def create(
extra_headers = {"Content-Type": "multipart/form-data", **(extra_headers or {})}
return await self._post(
"/v1/file-interpreter/from-image",
- body=maybe_transform(body, from_image_create_params.FromImageCreateParams),
+ body=await async_maybe_transform(body, from_image_create_params.FromImageCreateParams),
files=files,
options=make_request_options(
extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
diff --git a/src/maisa/resources/file_interpreter/from_pdf.py b/src/maisa/resources/file_interpreter/from_pdf.py
index f8ac92c2..6a050e64 100644
--- a/src/maisa/resources/file_interpreter/from_pdf.py
+++ b/src/maisa/resources/file_interpreter/from_pdf.py
@@ -7,7 +7,12 @@
import httpx
from ..._types import NOT_GIVEN, Body, Query, Headers, NotGiven, FileTypes
-from ..._utils import extract_files, maybe_transform, deepcopy_minimal
+from ..._utils import (
+ extract_files,
+ maybe_transform,
+ deepcopy_minimal,
+ async_maybe_transform,
+)
from ..._compat import cached_property
from ..._resource import SyncAPIResource, AsyncAPIResource
from ..._response import (
@@ -121,14 +126,14 @@ async def create(
extra_headers = {"Content-Type": "multipart/form-data", **(extra_headers or {})}
return await self._post(
"/v1/file-interpreter/from-pdf",
- body=maybe_transform(body, from_pdf_create_params.FromPdfCreateParams),
+ body=await async_maybe_transform(body, from_pdf_create_params.FromPdfCreateParams),
files=files,
options=make_request_options(
extra_headers=extra_headers,
extra_query=extra_query,
extra_body=extra_body,
timeout=timeout,
- query=maybe_transform({"max_pages": max_pages}, from_pdf_create_params.FromPdfCreateParams),
+ query=await async_maybe_transform({"max_pages": max_pages}, from_pdf_create_params.FromPdfCreateParams),
),
cast_to=object,
)
diff --git a/src/maisa/resources/kpu.py b/src/maisa/resources/kpu.py
new file mode 100644
index 00000000..3743abfc
--- /dev/null
+++ b/src/maisa/resources/kpu.py
@@ -0,0 +1,223 @@
+# File generated from our OpenAPI spec by Stainless.
+
+from __future__ import annotations
+
+from typing import List, Mapping, cast
+
+import httpx
+
+from ..types import KpuRunResponse, kpu_run_params
+from .._types import NOT_GIVEN, Body, Query, Headers, NotGiven, FileTypes
+from .._utils import (
+ extract_files,
+ maybe_transform,
+ deepcopy_minimal,
+ async_maybe_transform,
+)
+from .._compat import cached_property
+from .._resource import SyncAPIResource, AsyncAPIResource
+from .._response import (
+ to_raw_response_wrapper,
+ to_streamed_response_wrapper,
+ async_to_raw_response_wrapper,
+ async_to_streamed_response_wrapper,
+)
+from .._base_client import (
+ make_request_options,
+)
+
+__all__ = ["Kpu", "AsyncKpu"]
+
+
+class Kpu(SyncAPIResource):
+ @cached_property
+ def with_raw_response(self) -> KpuWithRawResponse:
+ return KpuWithRawResponse(self)
+
+ @cached_property
+ def with_streaming_response(self) -> KpuWithStreamingResponse:
+ return KpuWithStreamingResponse(self)
+
+ def run(
+ self,
+ *,
+ query: str,
+ explain_steps: bool | NotGiven = NOT_GIVEN,
+ retries: int | NotGiven = NOT_GIVEN,
+ file: List[FileTypes] | NotGiven = NOT_GIVEN,
+ # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
+ # The extra values given here take precedence over values defined on the client or passed to this method.
+ extra_headers: Headers | None = None,
+ extra_query: Query | None = None,
+ extra_body: Body | None = None,
+ timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
+ ) -> KpuRunResponse:
+ """
+ Executes the KPU in sync, sending the response when the KPU execution is done.
+
+ Args:
+ query: User text with the query or request to be commanded to the KPU.
+
+ explain_steps: If true, the KPU will explain in natural language the steps of each step of each
+ intent. Enabling this feature can slow down the KPU execution, and increase the
+ usage metric.
+
+ retries: Number of retries in case of failure. Retries are sequential, and each failed
+ intent yields a learning for the next intent. This feature is experimental.
+
+ file: Files to be used in the KPU execution. Files can be of any type.
+
+ extra_headers: Send extra headers
+
+ extra_query: Add additional query parameters to the request
+
+ extra_body: Add additional JSON properties to the request
+
+ timeout: Override the client-level default timeout for this request, in seconds
+ """
+ body = deepcopy_minimal(
+ {
+ "query": query,
+ "file": file,
+ }
+ )
+ files = extract_files(cast(Mapping[str, object], body), paths=[["file", ""]])
+ if files:
+ # It should be noted that the actual Content-Type header that will be
+ # sent to the server will contain a `boundary` parameter, e.g.
+ # multipart/form-data; boundary=---abc--
+ extra_headers = {"Content-Type": "multipart/form-data", **(extra_headers or {})}
+ return self._post(
+ "/v1/kpu/run",
+ body=maybe_transform(body, kpu_run_params.KpuRunParams),
+ files=files,
+ options=make_request_options(
+ extra_headers=extra_headers,
+ extra_query=extra_query,
+ extra_body=extra_body,
+ timeout=timeout,
+ query=maybe_transform(
+ {
+ "explain_steps": explain_steps,
+ "retries": retries,
+ },
+ kpu_run_params.KpuRunParams,
+ ),
+ ),
+ cast_to=KpuRunResponse,
+ )
+
+
+class AsyncKpu(AsyncAPIResource):
+ @cached_property
+ def with_raw_response(self) -> AsyncKpuWithRawResponse:
+ return AsyncKpuWithRawResponse(self)
+
+ @cached_property
+ def with_streaming_response(self) -> AsyncKpuWithStreamingResponse:
+ return AsyncKpuWithStreamingResponse(self)
+
+ async def run(
+ self,
+ *,
+ query: str,
+ explain_steps: bool | NotGiven = NOT_GIVEN,
+ retries: int | NotGiven = NOT_GIVEN,
+ file: List[FileTypes] | NotGiven = NOT_GIVEN,
+ # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
+ # The extra values given here take precedence over values defined on the client or passed to this method.
+ extra_headers: Headers | None = None,
+ extra_query: Query | None = None,
+ extra_body: Body | None = None,
+ timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
+ ) -> KpuRunResponse:
+ """
+ Executes the KPU in sync, sending the response when the KPU execution is done.
+
+ Args:
+ query: User text with the query or request to be commanded to the KPU.
+
+ explain_steps: If true, the KPU will explain in natural language the steps of each step of each
+ intent. Enabling this feature can slow down the KPU execution, and increase the
+ usage metric.
+
+ retries: Number of retries in case of failure. Retries are sequential, and each failed
+ intent yields a learning for the next intent. This feature is experimental.
+
+ file: Files to be used in the KPU execution. Files can be of any type.
+
+ extra_headers: Send extra headers
+
+ extra_query: Add additional query parameters to the request
+
+ extra_body: Add additional JSON properties to the request
+
+ timeout: Override the client-level default timeout for this request, in seconds
+ """
+ body = deepcopy_minimal(
+ {
+ "query": query,
+ "file": file,
+ }
+ )
+ files = extract_files(cast(Mapping[str, object], body), paths=[["file", ""]])
+ if files:
+ # It should be noted that the actual Content-Type header that will be
+ # sent to the server will contain a `boundary` parameter, e.g.
+ # multipart/form-data; boundary=---abc--
+ extra_headers = {"Content-Type": "multipart/form-data", **(extra_headers or {})}
+ return await self._post(
+ "/v1/kpu/run",
+ body=await async_maybe_transform(body, kpu_run_params.KpuRunParams),
+ files=files,
+ options=make_request_options(
+ extra_headers=extra_headers,
+ extra_query=extra_query,
+ extra_body=extra_body,
+ timeout=timeout,
+ query=await async_maybe_transform(
+ {
+ "explain_steps": explain_steps,
+ "retries": retries,
+ },
+ kpu_run_params.KpuRunParams,
+ ),
+ ),
+ cast_to=KpuRunResponse,
+ )
+
+
+class KpuWithRawResponse:
+ def __init__(self, kpu: Kpu) -> None:
+ self._kpu = kpu
+
+ self.run = to_raw_response_wrapper(
+ kpu.run,
+ )
+
+
+class AsyncKpuWithRawResponse:
+ def __init__(self, kpu: AsyncKpu) -> None:
+ self._kpu = kpu
+
+ self.run = async_to_raw_response_wrapper(
+ kpu.run,
+ )
+
+
+class KpuWithStreamingResponse:
+ def __init__(self, kpu: Kpu) -> None:
+ self._kpu = kpu
+
+ self.run = to_streamed_response_wrapper(
+ kpu.run,
+ )
+
+
+class AsyncKpuWithStreamingResponse:
+ def __init__(self, kpu: AsyncKpu) -> None:
+ self._kpu = kpu
+
+ self.run = async_to_streamed_response_wrapper(
+ kpu.run,
+ )
diff --git a/src/maisa/resources/mainet/search.py b/src/maisa/resources/mainet/search.py
index 6c77b389..6b376d54 100644
--- a/src/maisa/resources/mainet/search.py
+++ b/src/maisa/resources/mainet/search.py
@@ -5,7 +5,10 @@
import httpx
from ..._types import NOT_GIVEN, Body, Query, Headers, NotGiven
-from ..._utils import maybe_transform
+from ..._utils import (
+ maybe_transform,
+ async_maybe_transform,
+)
from ..._compat import cached_property
from ..._resource import SyncAPIResource, AsyncAPIResource
from ..._response import (
@@ -102,7 +105,7 @@ async def create(
"""
return await self._post(
"/v1/mainet/search",
- body=maybe_transform({"text": text}, search_create_params.SearchCreateParams),
+ body=await async_maybe_transform({"text": text}, search_create_params.SearchCreateParams),
options=make_request_options(
extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
),
diff --git a/src/maisa/resources/models/embeddings.py b/src/maisa/resources/models/embeddings.py
index 9d3e5575..eb3dfced 100644
--- a/src/maisa/resources/models/embeddings.py
+++ b/src/maisa/resources/models/embeddings.py
@@ -7,7 +7,10 @@
import httpx
from ..._types import NOT_GIVEN, Body, Query, Headers, NotGiven
-from ..._utils import maybe_transform
+from ..._utils import (
+ maybe_transform,
+ async_maybe_transform,
+)
from ..._compat import cached_property
from ..._resource import SyncAPIResource, AsyncAPIResource
from ..._response import (
@@ -104,7 +107,7 @@ async def create(
"""
return await self._post(
"/v1/models/embeddings",
- body=maybe_transform({"texts": texts}, embedding_create_params.EmbeddingCreateParams),
+ body=await async_maybe_transform({"texts": texts}, embedding_create_params.EmbeddingCreateParams),
options=make_request_options(
extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
),
diff --git a/src/maisa/resources/models/rerank.py b/src/maisa/resources/models/rerank.py
index bcf6042c..5807a961 100644
--- a/src/maisa/resources/models/rerank.py
+++ b/src/maisa/resources/models/rerank.py
@@ -7,7 +7,10 @@
import httpx
from ..._types import NOT_GIVEN, Body, Query, Headers, NotGiven
-from ..._utils import maybe_transform
+from ..._utils import (
+ maybe_transform,
+ async_maybe_transform,
+)
from ..._compat import cached_property
from ..._resource import SyncAPIResource, AsyncAPIResource
from ..._response import (
@@ -116,7 +119,7 @@ async def create(
"""
return await self._post(
"/v1/models/rerank",
- body=maybe_transform(
+ body=await async_maybe_transform(
{
"sentences": sentences,
"source_sentence": source_sentence,
diff --git a/src/maisa/types/__init__.py b/src/maisa/types/__init__.py
index 94785c30..de43413c 100644
--- a/src/maisa/types/__init__.py
+++ b/src/maisa/types/__init__.py
@@ -3,6 +3,8 @@
from __future__ import annotations
from .shared import TextSummary as TextSummary, TextExtractor as TextExtractor, TextComparator as TextComparator
+from .kpu_run_params import KpuRunParams as KpuRunParams
+from .kpu_run_response import KpuRunResponse as KpuRunResponse
from .capability_compare_params import CapabilityCompareParams as CapabilityCompareParams
from .capability_extract_params import CapabilityExtractParams as CapabilityExtractParams
from .capability_summarize_params import CapabilitySummarizeParams as CapabilitySummarizeParams
diff --git a/src/maisa/types/kpu_run_params.py b/src/maisa/types/kpu_run_params.py
new file mode 100644
index 00000000..bad448ed
--- /dev/null
+++ b/src/maisa/types/kpu_run_params.py
@@ -0,0 +1,32 @@
+# File generated from our OpenAPI spec by Stainless.
+
+from __future__ import annotations
+
+from typing import List
+from typing_extensions import Required, TypedDict
+
+from .._types import FileTypes
+
+__all__ = ["KpuRunParams"]
+
+
+class KpuRunParams(TypedDict, total=False):
+ query: Required[str]
+ """User text with the query or request to be commanded to the KPU."""
+
+ explain_steps: bool
+ """
+ If true, the KPU will explain in natural language the steps of each step of each
+ intent. Enabling this feature can slow down the KPU execution, and increase the
+ usage metric.
+ """
+
+ retries: int
+ """Number of retries in case of failure.
+
+ Retries are sequential, and each failed intent yields a learning for the next
+ intent. This feature is experimental.
+ """
+
+ file: List[FileTypes]
+ """Files to be used in the KPU execution. Files can be of any type."""
diff --git a/src/maisa/types/kpu_run_response.py b/src/maisa/types/kpu_run_response.py
new file mode 100644
index 00000000..efacc7d6
--- /dev/null
+++ b/src/maisa/types/kpu_run_response.py
@@ -0,0 +1,41 @@
+# File generated from our OpenAPI spec by Stainless.
+
+from typing import Dict, List, Optional
+
+from .._models import BaseModel
+
+__all__ = ["KpuRunResponse", "Intent"]
+
+
+class Intent(BaseModel):
+ explain_steps: List[str]
+ """Array of the steps of the intent explained in natural language.
+
+ This field is an ampty array if `explain_steps` param is set to `false`
+ """
+
+ intent: int
+ """Intent number, starting from 0."""
+
+ result: str
+ """The result of the intent."""
+
+ solved: bool
+ """Whether the intent was interpreted as solved by the KPU or not."""
+
+ downloadable_files: Optional[Dict[str, str]] = None
+ """Key-value of the files generated by the KPU."""
+
+
+class KpuRunResponse(BaseModel):
+ intents: List[Intent]
+ """Array of the intents executed by the KPU."""
+
+ result: str
+ """The result of the KPU execution.
+
+ The result may be invalid if none of the intents were successful.
+ """
+
+ downloadable_files: Optional[Dict[str, str]] = None
+ """Key-value of the files generated by the KPU."""
diff --git a/tests/api_resources/test_kpu.py b/tests/api_resources/test_kpu.py
new file mode 100644
index 00000000..93eb736a
--- /dev/null
+++ b/tests/api_resources/test_kpu.py
@@ -0,0 +1,104 @@
+# File generated from our OpenAPI spec by Stainless.
+
+from __future__ import annotations
+
+import os
+from typing import Any, cast
+
+import pytest
+
+from maisa import Maisa, AsyncMaisa
+from maisa.types import KpuRunResponse
+from tests.utils import assert_matches_type
+
+base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010")
+
+
+class TestKpu:
+ parametrize = pytest.mark.parametrize("client", [False, True], indirect=True, ids=["loose", "strict"])
+
+ @parametrize
+ def test_method_run(self, client: Maisa) -> None:
+ kpu = client.kpu.run(
+ query="string",
+ )
+ assert_matches_type(KpuRunResponse, kpu, path=["response"])
+
+ @parametrize
+ def test_method_run_with_all_params(self, client: Maisa) -> None:
+ kpu = client.kpu.run(
+ query="string",
+ explain_steps=True,
+ retries=1,
+ file=[b"raw file contents", b"raw file contents", b"raw file contents"],
+ )
+ assert_matches_type(KpuRunResponse, kpu, path=["response"])
+
+ @parametrize
+ def test_raw_response_run(self, client: Maisa) -> None:
+ response = client.kpu.with_raw_response.run(
+ query="string",
+ )
+
+ assert response.is_closed is True
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+ kpu = response.parse()
+ assert_matches_type(KpuRunResponse, kpu, path=["response"])
+
+ @parametrize
+ def test_streaming_response_run(self, client: Maisa) -> None:
+ with client.kpu.with_streaming_response.run(
+ query="string",
+ ) as response:
+ assert not response.is_closed
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+
+ kpu = response.parse()
+ assert_matches_type(KpuRunResponse, kpu, path=["response"])
+
+ assert cast(Any, response.is_closed) is True
+
+
+class TestAsyncKpu:
+ parametrize = pytest.mark.parametrize("async_client", [False, True], indirect=True, ids=["loose", "strict"])
+
+ @parametrize
+ async def test_method_run(self, async_client: AsyncMaisa) -> None:
+ kpu = await async_client.kpu.run(
+ query="string",
+ )
+ assert_matches_type(KpuRunResponse, kpu, path=["response"])
+
+ @parametrize
+ async def test_method_run_with_all_params(self, async_client: AsyncMaisa) -> None:
+ kpu = await async_client.kpu.run(
+ query="string",
+ explain_steps=True,
+ retries=1,
+ file=[b"raw file contents", b"raw file contents", b"raw file contents"],
+ )
+ assert_matches_type(KpuRunResponse, kpu, path=["response"])
+
+ @parametrize
+ async def test_raw_response_run(self, async_client: AsyncMaisa) -> None:
+ response = await async_client.kpu.with_raw_response.run(
+ query="string",
+ )
+
+ assert response.is_closed is True
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+ kpu = await response.parse()
+ assert_matches_type(KpuRunResponse, kpu, path=["response"])
+
+ @parametrize
+ async def test_streaming_response_run(self, async_client: AsyncMaisa) -> None:
+ async with async_client.kpu.with_streaming_response.run(
+ query="string",
+ ) as response:
+ assert not response.is_closed
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+
+ kpu = await response.parse()
+ assert_matches_type(KpuRunResponse, kpu, path=["response"])
+
+ assert cast(Any, response.is_closed) is True
diff --git a/tests/sample_file.txt b/tests/sample_file.txt
new file mode 100644
index 00000000..af5626b4
--- /dev/null
+++ b/tests/sample_file.txt
@@ -0,0 +1 @@
+Hello, world!
diff --git a/tests/test_client.py b/tests/test_client.py
index 5381aa73..050dd28c 100644
--- a/tests/test_client.py
+++ b/tests/test_client.py
@@ -82,7 +82,7 @@ def test_copy_default_options(self) -> None:
# options that have a default are overridden correctly
copied = self.client.copy(max_retries=7)
assert copied.max_retries == 7
- assert self.client.max_retries == 3
+ assert self.client.max_retries == 2
copied2 = copied.copy(max_retries=6)
assert copied2.max_retries == 6
@@ -291,6 +291,16 @@ def test_http_client_timeout_option(self) -> None:
timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
assert timeout == DEFAULT_TIMEOUT # our default
+ async def test_invalid_http_client(self) -> None:
+ with pytest.raises(TypeError, match="Invalid `http_client` arg"):
+ async with httpx.AsyncClient() as http_client:
+ Maisa(
+ base_url=base_url,
+ api_key=api_key,
+ _strict_response_validation=True,
+ http_client=cast(Any, http_client),
+ )
+
def test_default_headers_option(self) -> None:
client = Maisa(
base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"}
@@ -675,12 +685,12 @@ def test_parse_retry_after_header(self, remaining_retries: int, retry_after: str
@mock.patch("maisa._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
@pytest.mark.respx(base_url=base_url)
def test_retrying_timeout_errors_doesnt_leak(self, respx_mock: MockRouter) -> None:
- respx_mock.post("/v1/models/embeddings").mock(side_effect=httpx.TimeoutException("Test timeout error"))
+ respx_mock.post("/v1/capabilities/summarize").mock(side_effect=httpx.TimeoutException("Test timeout error"))
with pytest.raises(APITimeoutError):
self.client.post(
- "/v1/models/embeddings",
- body=cast(object, dict(texts=["string"])),
+ "/v1/capabilities/summarize",
+ body=cast(object, dict(text="Example long text...")),
cast_to=httpx.Response,
options={"headers": {RAW_RESPONSE_HEADER: "stream"}},
)
@@ -690,12 +700,12 @@ def test_retrying_timeout_errors_doesnt_leak(self, respx_mock: MockRouter) -> No
@mock.patch("maisa._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
@pytest.mark.respx(base_url=base_url)
def test_retrying_status_errors_doesnt_leak(self, respx_mock: MockRouter) -> None:
- respx_mock.post("/v1/models/embeddings").mock(return_value=httpx.Response(500))
+ respx_mock.post("/v1/capabilities/summarize").mock(return_value=httpx.Response(500))
with pytest.raises(APIStatusError):
self.client.post(
- "/v1/models/embeddings",
- body=cast(object, dict(texts=["string"])),
+ "/v1/capabilities/summarize",
+ body=cast(object, dict(text="Example long text...")),
cast_to=httpx.Response,
options={"headers": {RAW_RESPONSE_HEADER: "stream"}},
)
@@ -740,7 +750,7 @@ def test_copy_default_options(self) -> None:
# options that have a default are overridden correctly
copied = self.client.copy(max_retries=7)
assert copied.max_retries == 7
- assert self.client.max_retries == 3
+ assert self.client.max_retries == 2
copied2 = copied.copy(max_retries=6)
assert copied2.max_retries == 6
@@ -951,6 +961,16 @@ async def test_http_client_timeout_option(self) -> None:
timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
assert timeout == DEFAULT_TIMEOUT # our default
+ def test_invalid_http_client(self) -> None:
+ with pytest.raises(TypeError, match="Invalid `http_client` arg"):
+ with httpx.Client() as http_client:
+ AsyncMaisa(
+ base_url=base_url,
+ api_key=api_key,
+ _strict_response_validation=True,
+ http_client=cast(Any, http_client),
+ )
+
def test_default_headers_option(self) -> None:
client = AsyncMaisa(
base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"}
@@ -1345,12 +1365,12 @@ async def test_parse_retry_after_header(self, remaining_retries: int, retry_afte
@mock.patch("maisa._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
@pytest.mark.respx(base_url=base_url)
async def test_retrying_timeout_errors_doesnt_leak(self, respx_mock: MockRouter) -> None:
- respx_mock.post("/v1/models/embeddings").mock(side_effect=httpx.TimeoutException("Test timeout error"))
+ respx_mock.post("/v1/capabilities/summarize").mock(side_effect=httpx.TimeoutException("Test timeout error"))
with pytest.raises(APITimeoutError):
await self.client.post(
- "/v1/models/embeddings",
- body=cast(object, dict(texts=["string"])),
+ "/v1/capabilities/summarize",
+ body=cast(object, dict(text="Example long text...")),
cast_to=httpx.Response,
options={"headers": {RAW_RESPONSE_HEADER: "stream"}},
)
@@ -1360,12 +1380,12 @@ async def test_retrying_timeout_errors_doesnt_leak(self, respx_mock: MockRouter)
@mock.patch("maisa._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
@pytest.mark.respx(base_url=base_url)
async def test_retrying_status_errors_doesnt_leak(self, respx_mock: MockRouter) -> None:
- respx_mock.post("/v1/models/embeddings").mock(return_value=httpx.Response(500))
+ respx_mock.post("/v1/capabilities/summarize").mock(return_value=httpx.Response(500))
with pytest.raises(APIStatusError):
await self.client.post(
- "/v1/models/embeddings",
- body=cast(object, dict(texts=["string"])),
+ "/v1/capabilities/summarize",
+ body=cast(object, dict(text="Example long text...")),
cast_to=httpx.Response,
options={"headers": {RAW_RESPONSE_HEADER: "stream"}},
)
diff --git a/tests/test_transform.py b/tests/test_transform.py
index 5f58e6dc..19442b4f 100644
--- a/tests/test_transform.py
+++ b/tests/test_transform.py
@@ -1,22 +1,50 @@
from __future__ import annotations
-from typing import Any, List, Union, Iterable, Optional, cast
+import io
+import pathlib
+from typing import Any, List, Union, TypeVar, Iterable, Optional, cast
from datetime import date, datetime
from typing_extensions import Required, Annotated, TypedDict
import pytest
-from maisa._utils import PropertyInfo, transform, parse_datetime
+from maisa._types import Base64FileInput
+from maisa._utils import (
+ PropertyInfo,
+ transform as _transform,
+ parse_datetime,
+ async_transform as _async_transform,
+)
from maisa._compat import PYDANTIC_V2
from maisa._models import BaseModel
+_T = TypeVar("_T")
+
+SAMPLE_FILE_PATH = pathlib.Path(__file__).parent.joinpath("sample_file.txt")
+
+
+async def transform(
+ data: _T,
+ expected_type: object,
+ use_async: bool,
+) -> _T:
+ if use_async:
+ return await _async_transform(data, expected_type=expected_type)
+
+ return _transform(data, expected_type=expected_type)
+
+
+parametrize = pytest.mark.parametrize("use_async", [False, True], ids=["sync", "async"])
+
class Foo1(TypedDict):
foo_bar: Annotated[str, PropertyInfo(alias="fooBar")]
-def test_top_level_alias() -> None:
- assert transform({"foo_bar": "hello"}, expected_type=Foo1) == {"fooBar": "hello"}
+@parametrize
+@pytest.mark.asyncio
+async def test_top_level_alias(use_async: bool) -> None:
+ assert await transform({"foo_bar": "hello"}, expected_type=Foo1, use_async=use_async) == {"fooBar": "hello"}
class Foo2(TypedDict):
@@ -32,9 +60,11 @@ class Baz2(TypedDict):
my_baz: Annotated[str, PropertyInfo(alias="myBaz")]
-def test_recursive_typeddict() -> None:
- assert transform({"bar": {"this_thing": 1}}, Foo2) == {"bar": {"this__thing": 1}}
- assert transform({"bar": {"baz": {"my_baz": "foo"}}}, Foo2) == {"bar": {"Baz": {"myBaz": "foo"}}}
+@parametrize
+@pytest.mark.asyncio
+async def test_recursive_typeddict(use_async: bool) -> None:
+ assert await transform({"bar": {"this_thing": 1}}, Foo2, use_async) == {"bar": {"this__thing": 1}}
+ assert await transform({"bar": {"baz": {"my_baz": "foo"}}}, Foo2, use_async) == {"bar": {"Baz": {"myBaz": "foo"}}}
class Foo3(TypedDict):
@@ -45,8 +75,10 @@ class Bar3(TypedDict):
my_field: Annotated[str, PropertyInfo(alias="myField")]
-def test_list_of_typeddict() -> None:
- result = transform({"things": [{"my_field": "foo"}, {"my_field": "foo2"}]}, expected_type=Foo3)
+@parametrize
+@pytest.mark.asyncio
+async def test_list_of_typeddict(use_async: bool) -> None:
+ result = await transform({"things": [{"my_field": "foo"}, {"my_field": "foo2"}]}, Foo3, use_async)
assert result == {"things": [{"myField": "foo"}, {"myField": "foo2"}]}
@@ -62,10 +94,14 @@ class Baz4(TypedDict):
foo_baz: Annotated[str, PropertyInfo(alias="fooBaz")]
-def test_union_of_typeddict() -> None:
- assert transform({"foo": {"foo_bar": "bar"}}, Foo4) == {"foo": {"fooBar": "bar"}}
- assert transform({"foo": {"foo_baz": "baz"}}, Foo4) == {"foo": {"fooBaz": "baz"}}
- assert transform({"foo": {"foo_baz": "baz", "foo_bar": "bar"}}, Foo4) == {"foo": {"fooBaz": "baz", "fooBar": "bar"}}
+@parametrize
+@pytest.mark.asyncio
+async def test_union_of_typeddict(use_async: bool) -> None:
+ assert await transform({"foo": {"foo_bar": "bar"}}, Foo4, use_async) == {"foo": {"fooBar": "bar"}}
+ assert await transform({"foo": {"foo_baz": "baz"}}, Foo4, use_async) == {"foo": {"fooBaz": "baz"}}
+ assert await transform({"foo": {"foo_baz": "baz", "foo_bar": "bar"}}, Foo4, use_async) == {
+ "foo": {"fooBaz": "baz", "fooBar": "bar"}
+ }
class Foo5(TypedDict):
@@ -80,9 +116,11 @@ class Baz5(TypedDict):
foo_baz: Annotated[str, PropertyInfo(alias="fooBaz")]
-def test_union_of_list() -> None:
- assert transform({"foo": {"foo_bar": "bar"}}, Foo5) == {"FOO": {"fooBar": "bar"}}
- assert transform(
+@parametrize
+@pytest.mark.asyncio
+async def test_union_of_list(use_async: bool) -> None:
+ assert await transform({"foo": {"foo_bar": "bar"}}, Foo5, use_async) == {"FOO": {"fooBar": "bar"}}
+ assert await transform(
{
"foo": [
{"foo_baz": "baz"},
@@ -90,6 +128,7 @@ def test_union_of_list() -> None:
]
},
Foo5,
+ use_async,
) == {"FOO": [{"fooBaz": "baz"}, {"fooBaz": "baz"}]}
@@ -97,8 +136,10 @@ class Foo6(TypedDict):
bar: Annotated[str, PropertyInfo(alias="Bar")]
-def test_includes_unknown_keys() -> None:
- assert transform({"bar": "bar", "baz_": {"FOO": 1}}, Foo6) == {
+@parametrize
+@pytest.mark.asyncio
+async def test_includes_unknown_keys(use_async: bool) -> None:
+ assert await transform({"bar": "bar", "baz_": {"FOO": 1}}, Foo6, use_async) == {
"Bar": "bar",
"baz_": {"FOO": 1},
}
@@ -113,9 +154,11 @@ class Bar7(TypedDict):
foo: str
-def test_ignores_invalid_input() -> None:
- assert transform({"bar": ""}, Foo7) == {"bAr": ""}
- assert transform({"foo": ""}, Foo7) == {"foo": ""}
+@parametrize
+@pytest.mark.asyncio
+async def test_ignores_invalid_input(use_async: bool) -> None:
+ assert await transform({"bar": ""}, Foo7, use_async) == {"bAr": ""}
+ assert await transform({"foo": ""}, Foo7, use_async) == {"foo": ""}
class DatetimeDict(TypedDict, total=False):
@@ -134,52 +177,66 @@ class DateDict(TypedDict, total=False):
foo: Annotated[date, PropertyInfo(format="iso8601")]
-def test_iso8601_format() -> None:
+@parametrize
+@pytest.mark.asyncio
+async def test_iso8601_format(use_async: bool) -> None:
dt = datetime.fromisoformat("2023-02-23T14:16:36.337692+00:00")
- assert transform({"foo": dt}, DatetimeDict) == {"foo": "2023-02-23T14:16:36.337692+00:00"} # type: ignore[comparison-overlap]
+ assert await transform({"foo": dt}, DatetimeDict, use_async) == {"foo": "2023-02-23T14:16:36.337692+00:00"} # type: ignore[comparison-overlap]
dt = dt.replace(tzinfo=None)
- assert transform({"foo": dt}, DatetimeDict) == {"foo": "2023-02-23T14:16:36.337692"} # type: ignore[comparison-overlap]
+ assert await transform({"foo": dt}, DatetimeDict, use_async) == {"foo": "2023-02-23T14:16:36.337692"} # type: ignore[comparison-overlap]
- assert transform({"foo": None}, DateDict) == {"foo": None} # type: ignore[comparison-overlap]
- assert transform({"foo": date.fromisoformat("2023-02-23")}, DateDict) == {"foo": "2023-02-23"} # type: ignore[comparison-overlap]
+ assert await transform({"foo": None}, DateDict, use_async) == {"foo": None} # type: ignore[comparison-overlap]
+ assert await transform({"foo": date.fromisoformat("2023-02-23")}, DateDict, use_async) == {"foo": "2023-02-23"} # type: ignore[comparison-overlap]
-def test_optional_iso8601_format() -> None:
+@parametrize
+@pytest.mark.asyncio
+async def test_optional_iso8601_format(use_async: bool) -> None:
dt = datetime.fromisoformat("2023-02-23T14:16:36.337692+00:00")
- assert transform({"bar": dt}, DatetimeDict) == {"bar": "2023-02-23T14:16:36.337692+00:00"} # type: ignore[comparison-overlap]
+ assert await transform({"bar": dt}, DatetimeDict, use_async) == {"bar": "2023-02-23T14:16:36.337692+00:00"} # type: ignore[comparison-overlap]
- assert transform({"bar": None}, DatetimeDict) == {"bar": None}
+ assert await transform({"bar": None}, DatetimeDict, use_async) == {"bar": None}
-def test_required_iso8601_format() -> None:
+@parametrize
+@pytest.mark.asyncio
+async def test_required_iso8601_format(use_async: bool) -> None:
dt = datetime.fromisoformat("2023-02-23T14:16:36.337692+00:00")
- assert transform({"required": dt}, DatetimeDict) == {"required": "2023-02-23T14:16:36.337692+00:00"} # type: ignore[comparison-overlap]
+ assert await transform({"required": dt}, DatetimeDict, use_async) == {
+ "required": "2023-02-23T14:16:36.337692+00:00"
+ } # type: ignore[comparison-overlap]
- assert transform({"required": None}, DatetimeDict) == {"required": None}
+ assert await transform({"required": None}, DatetimeDict, use_async) == {"required": None}
-def test_union_datetime() -> None:
+@parametrize
+@pytest.mark.asyncio
+async def test_union_datetime(use_async: bool) -> None:
dt = datetime.fromisoformat("2023-02-23T14:16:36.337692+00:00")
- assert transform({"union": dt}, DatetimeDict) == { # type: ignore[comparison-overlap]
+ assert await transform({"union": dt}, DatetimeDict, use_async) == { # type: ignore[comparison-overlap]
"union": "2023-02-23T14:16:36.337692+00:00"
}
- assert transform({"union": "foo"}, DatetimeDict) == {"union": "foo"}
+ assert await transform({"union": "foo"}, DatetimeDict, use_async) == {"union": "foo"}
-def test_nested_list_iso6801_format() -> None:
+@parametrize
+@pytest.mark.asyncio
+async def test_nested_list_iso6801_format(use_async: bool) -> None:
dt1 = datetime.fromisoformat("2023-02-23T14:16:36.337692+00:00")
dt2 = parse_datetime("2022-01-15T06:34:23Z")
- assert transform({"list_": [dt1, dt2]}, DatetimeDict) == { # type: ignore[comparison-overlap]
+ assert await transform({"list_": [dt1, dt2]}, DatetimeDict, use_async) == { # type: ignore[comparison-overlap]
"list_": ["2023-02-23T14:16:36.337692+00:00", "2022-01-15T06:34:23+00:00"]
}
-def test_datetime_custom_format() -> None:
+@parametrize
+@pytest.mark.asyncio
+async def test_datetime_custom_format(use_async: bool) -> None:
dt = parse_datetime("2022-01-15T06:34:23Z")
- result = transform(dt, Annotated[datetime, PropertyInfo(format="custom", format_template="%H")])
+ result = await transform(dt, Annotated[datetime, PropertyInfo(format="custom", format_template="%H")], use_async)
assert result == "06" # type: ignore[comparison-overlap]
@@ -187,47 +244,59 @@ class DateDictWithRequiredAlias(TypedDict, total=False):
required_prop: Required[Annotated[date, PropertyInfo(format="iso8601", alias="prop")]]
-def test_datetime_with_alias() -> None:
- assert transform({"required_prop": None}, DateDictWithRequiredAlias) == {"prop": None} # type: ignore[comparison-overlap]
- assert transform({"required_prop": date.fromisoformat("2023-02-23")}, DateDictWithRequiredAlias) == {
- "prop": "2023-02-23"
- } # type: ignore[comparison-overlap]
+@parametrize
+@pytest.mark.asyncio
+async def test_datetime_with_alias(use_async: bool) -> None:
+ assert await transform({"required_prop": None}, DateDictWithRequiredAlias, use_async) == {"prop": None} # type: ignore[comparison-overlap]
+ assert await transform(
+ {"required_prop": date.fromisoformat("2023-02-23")}, DateDictWithRequiredAlias, use_async
+ ) == {"prop": "2023-02-23"} # type: ignore[comparison-overlap]
class MyModel(BaseModel):
foo: str
-def test_pydantic_model_to_dictionary() -> None:
- assert transform(MyModel(foo="hi!"), Any) == {"foo": "hi!"}
- assert transform(MyModel.construct(foo="hi!"), Any) == {"foo": "hi!"}
+@parametrize
+@pytest.mark.asyncio
+async def test_pydantic_model_to_dictionary(use_async: bool) -> None:
+ assert await transform(MyModel(foo="hi!"), Any, use_async) == {"foo": "hi!"}
+ assert await transform(MyModel.construct(foo="hi!"), Any, use_async) == {"foo": "hi!"}
-def test_pydantic_empty_model() -> None:
- assert transform(MyModel.construct(), Any) == {}
+@parametrize
+@pytest.mark.asyncio
+async def test_pydantic_empty_model(use_async: bool) -> None:
+ assert await transform(MyModel.construct(), Any, use_async) == {}
-def test_pydantic_unknown_field() -> None:
- assert transform(MyModel.construct(my_untyped_field=True), Any) == {"my_untyped_field": True}
+@parametrize
+@pytest.mark.asyncio
+async def test_pydantic_unknown_field(use_async: bool) -> None:
+ assert await transform(MyModel.construct(my_untyped_field=True), Any, use_async) == {"my_untyped_field": True}
-def test_pydantic_mismatched_types() -> None:
+@parametrize
+@pytest.mark.asyncio
+async def test_pydantic_mismatched_types(use_async: bool) -> None:
model = MyModel.construct(foo=True)
if PYDANTIC_V2:
with pytest.warns(UserWarning):
- params = transform(model, Any)
+ params = await transform(model, Any, use_async)
else:
- params = transform(model, Any)
+ params = await transform(model, Any, use_async)
assert params == {"foo": True}
-def test_pydantic_mismatched_object_type() -> None:
+@parametrize
+@pytest.mark.asyncio
+async def test_pydantic_mismatched_object_type(use_async: bool) -> None:
model = MyModel.construct(foo=MyModel.construct(hello="world"))
if PYDANTIC_V2:
with pytest.warns(UserWarning):
- params = transform(model, Any)
+ params = await transform(model, Any, use_async)
else:
- params = transform(model, Any)
+ params = await transform(model, Any, use_async)
assert params == {"foo": {"hello": "world"}}
@@ -235,10 +304,12 @@ class ModelNestedObjects(BaseModel):
nested: MyModel
-def test_pydantic_nested_objects() -> None:
+@parametrize
+@pytest.mark.asyncio
+async def test_pydantic_nested_objects(use_async: bool) -> None:
model = ModelNestedObjects.construct(nested={"foo": "stainless"})
assert isinstance(model.nested, MyModel)
- assert transform(model, Any) == {"nested": {"foo": "stainless"}}
+ assert await transform(model, Any, use_async) == {"nested": {"foo": "stainless"}}
class ModelWithDefaultField(BaseModel):
@@ -247,24 +318,26 @@ class ModelWithDefaultField(BaseModel):
with_str_default: str = "foo"
-def test_pydantic_default_field() -> None:
+@parametrize
+@pytest.mark.asyncio
+async def test_pydantic_default_field(use_async: bool) -> None:
# should be excluded when defaults are used
model = ModelWithDefaultField.construct()
assert model.with_none_default is None
assert model.with_str_default == "foo"
- assert transform(model, Any) == {}
+ assert await transform(model, Any, use_async) == {}
# should be included when the default value is explicitly given
model = ModelWithDefaultField.construct(with_none_default=None, with_str_default="foo")
assert model.with_none_default is None
assert model.with_str_default == "foo"
- assert transform(model, Any) == {"with_none_default": None, "with_str_default": "foo"}
+ assert await transform(model, Any, use_async) == {"with_none_default": None, "with_str_default": "foo"}
# should be included when a non-default value is explicitly given
model = ModelWithDefaultField.construct(with_none_default="bar", with_str_default="baz")
assert model.with_none_default == "bar"
assert model.with_str_default == "baz"
- assert transform(model, Any) == {"with_none_default": "bar", "with_str_default": "baz"}
+ assert await transform(model, Any, use_async) == {"with_none_default": "bar", "with_str_default": "baz"}
class TypedDictIterableUnion(TypedDict):
@@ -279,21 +352,57 @@ class Baz8(TypedDict):
foo_baz: Annotated[str, PropertyInfo(alias="fooBaz")]
-def test_iterable_of_dictionaries() -> None:
- assert transform({"foo": [{"foo_baz": "bar"}]}, TypedDictIterableUnion) == {"FOO": [{"fooBaz": "bar"}]}
- assert cast(Any, transform({"foo": ({"foo_baz": "bar"},)}, TypedDictIterableUnion)) == {"FOO": [{"fooBaz": "bar"}]}
+@parametrize
+@pytest.mark.asyncio
+async def test_iterable_of_dictionaries(use_async: bool) -> None:
+ assert await transform({"foo": [{"foo_baz": "bar"}]}, TypedDictIterableUnion, use_async) == {
+ "FOO": [{"fooBaz": "bar"}]
+ }
+ assert cast(Any, await transform({"foo": ({"foo_baz": "bar"},)}, TypedDictIterableUnion, use_async)) == {
+ "FOO": [{"fooBaz": "bar"}]
+ }
def my_iter() -> Iterable[Baz8]:
yield {"foo_baz": "hello"}
yield {"foo_baz": "world"}
- assert transform({"foo": my_iter()}, TypedDictIterableUnion) == {"FOO": [{"fooBaz": "hello"}, {"fooBaz": "world"}]}
+ assert await transform({"foo": my_iter()}, TypedDictIterableUnion, use_async) == {
+ "FOO": [{"fooBaz": "hello"}, {"fooBaz": "world"}]
+ }
class TypedDictIterableUnionStr(TypedDict):
foo: Annotated[Union[str, Iterable[Baz8]], PropertyInfo(alias="FOO")]
-def test_iterable_union_str() -> None:
- assert transform({"foo": "bar"}, TypedDictIterableUnionStr) == {"FOO": "bar"}
- assert cast(Any, transform(iter([{"foo_baz": "bar"}]), Union[str, Iterable[Baz8]])) == [{"fooBaz": "bar"}]
+@parametrize
+@pytest.mark.asyncio
+async def test_iterable_union_str(use_async: bool) -> None:
+ assert await transform({"foo": "bar"}, TypedDictIterableUnionStr, use_async) == {"FOO": "bar"}
+ assert cast(Any, await transform(iter([{"foo_baz": "bar"}]), Union[str, Iterable[Baz8]], use_async)) == [
+ {"fooBaz": "bar"}
+ ]
+
+
+class TypedDictBase64Input(TypedDict):
+ foo: Annotated[Union[str, Base64FileInput], PropertyInfo(format="base64")]
+
+
+@parametrize
+@pytest.mark.asyncio
+async def test_base64_file_input(use_async: bool) -> None:
+ # strings are left as-is
+ assert await transform({"foo": "bar"}, TypedDictBase64Input, use_async) == {"foo": "bar"}
+
+ # pathlib.Path is automatically converted to base64
+ assert await transform({"foo": SAMPLE_FILE_PATH}, TypedDictBase64Input, use_async) == {
+ "foo": "SGVsbG8sIHdvcmxkIQo="
+ } # type: ignore[comparison-overlap]
+
+ # io instances are automatically converted to base64
+ assert await transform({"foo": io.StringIO("Hello, world!")}, TypedDictBase64Input, use_async) == {
+ "foo": "SGVsbG8sIHdvcmxkIQ=="
+ } # type: ignore[comparison-overlap]
+ assert await transform({"foo": io.BytesIO(b"Hello, world!")}, TypedDictBase64Input, use_async) == {
+ "foo": "SGVsbG8sIHdvcmxkIQ=="
+ } # type: ignore[comparison-overlap]