diff --git a/src/mistralai/utils/serializers.py b/src/mistralai/utils/serializers.py index 378a14c..ed9dfed 100644 --- a/src/mistralai/utils/serializers.py +++ b/src/mistralai/utils/serializers.py @@ -13,6 +13,7 @@ from pydantic_core import from_json from ..types.basemodel import BaseModel, Nullable, OptionalNullable, Unset +from functools import lru_cache def serialize_decimal(as_str: bool): @@ -141,14 +142,8 @@ def unmarshal_json(raw, typ: Any) -> Any: def unmarshal(val, typ: Any) -> Any: - unmarshaller = create_model( - "Unmarshaller", - body=(typ, ...), - __config__=ConfigDict(populate_by_name=True, arbitrary_types_allowed=True), - ) - + unmarshaller = _get_unmarshaller(typ) m = unmarshaller(body=val) - # pyright: ignore[reportAttributeAccessIssue] return m.body # type: ignore @@ -178,7 +173,7 @@ def is_nullable(field): if origin is Nullable or origin is OptionalNullable: return True - if not origin is Union or type(None) not in get_args(field): + if origin is not Union or type(None) not in get_args(field): return False for arg in get_args(field): @@ -247,3 +242,12 @@ def _get_typing_objects_by_name_of(name: str) -> Tuple[Any, ...]: f"Neither typing nor typing_extensions has an object called {name!r}" ) return result + + +@lru_cache(maxsize=64) +def _get_unmarshaller(typ: Any): + return create_model( + "Unmarshaller", + body=(typ, ...), + __config__=ConfigDict(populate_by_name=True, arbitrary_types_allowed=True), + )