Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 12 additions & 8 deletions src/mistralai/utils/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from pydantic_core import from_json

from ..types.basemodel import BaseModel, Nullable, OptionalNullable, Unset
from functools import lru_cache


def serialize_decimal(as_str: bool):
Expand Down Expand Up @@ -141,14 +142,8 @@ def unmarshal_json(raw, typ: Any) -> Any:


def unmarshal(val, typ: Any) -> Any:
unmarshaller = create_model(
"Unmarshaller",
body=(typ, ...),
__config__=ConfigDict(populate_by_name=True, arbitrary_types_allowed=True),
)

unmarshaller = _get_unmarshaller(typ)
m = unmarshaller(body=val)

# pyright: ignore[reportAttributeAccessIssue]
return m.body # type: ignore

Expand Down Expand Up @@ -178,7 +173,7 @@ def is_nullable(field):
if origin is Nullable or origin is OptionalNullable:
return True

if not origin is Union or type(None) not in get_args(field):
if origin is not Union or type(None) not in get_args(field):
return False

for arg in get_args(field):
Expand Down Expand Up @@ -247,3 +242,12 @@ def _get_typing_objects_by_name_of(name: str) -> Tuple[Any, ...]:
f"Neither typing nor typing_extensions has an object called {name!r}"
)
return result


@lru_cache(maxsize=64)
def _get_unmarshaller(typ: Any):
return create_model(
"Unmarshaller",
body=(typ, ...),
__config__=ConfigDict(populate_by_name=True, arbitrary_types_allowed=True),
)