Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 17 additions & 5 deletions src/mistralai/utils/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@

from ..types.basemodel import BaseModel, Nullable, OptionalNullable, Unset

_DECIMAL_UNSET_TYPES = (Decimal, Unset)

_STR_INT_FLOAT_TYPES = (str, int, float)


def serialize_decimal(as_str: bool):
def serialize(d):
Expand All @@ -35,13 +39,21 @@ def validate_decimal(d):
if d is None:
return None

if isinstance(d, (Decimal, Unset)):
# Compare type directly for Decimal (bypasses isinstance overhead, but leaves Unset for legacy behavior)
if type(d) is Decimal or type(d) is Unset:
return d

# Only check for most common types first (fast-path), then do the expensive isinstance otherwise
if type(d) in _STR_INT_FLOAT_TYPES:
return Decimal(str(d))

if isinstance(d, _DECIMAL_UNSET_TYPES):
return d

if not isinstance(d, (str, int, float)):
raise ValueError("Expected string, int or float")
if isinstance(d, _STR_INT_FLOAT_TYPES):
return Decimal(str(d))

return Decimal(str(d))
raise ValueError("Expected string, int or float")


def serialize_float(as_str: bool):
Expand Down Expand Up @@ -178,7 +190,7 @@ def is_nullable(field):
if origin is Nullable or origin is OptionalNullable:
return True

if not origin is Union or type(None) not in get_args(field):
if origin is not Union or type(None) not in get_args(field):
return False

for arg in get_args(field):
Expand Down