diff --git a/src/mistralai/utils/serializers.py b/src/mistralai/utils/serializers.py index 378a14c..8468f9e 100644 --- a/src/mistralai/utils/serializers.py +++ b/src/mistralai/utils/serializers.py @@ -14,6 +14,10 @@ from ..types.basemodel import BaseModel, Nullable, OptionalNullable, Unset +_DECIMAL_UNSET_TYPES = (Decimal, Unset) + +_STR_INT_FLOAT_TYPES = (str, int, float) + def serialize_decimal(as_str: bool): def serialize(d): @@ -35,13 +39,21 @@ def validate_decimal(d): if d is None: return None - if isinstance(d, (Decimal, Unset)): + # Compare type directly for Decimal (bypasses isinstance overhead, but leaves Unset for legacy behavior) + if type(d) is Decimal or type(d) is Unset: + return d + + # Only check for most common types first (fast-path), then do the expensive isinstance otherwise + if type(d) in _STR_INT_FLOAT_TYPES: + return Decimal(str(d)) + + if isinstance(d, _DECIMAL_UNSET_TYPES): return d - if not isinstance(d, (str, int, float)): - raise ValueError("Expected string, int or float") + if isinstance(d, _STR_INT_FLOAT_TYPES): + return Decimal(str(d)) - return Decimal(str(d)) + raise ValueError("Expected string, int or float") def serialize_float(as_str: bool): @@ -178,7 +190,7 @@ def is_nullable(field): if origin is Nullable or origin is OptionalNullable: return True - if not origin is Union or type(None) not in get_args(field): + if origin is not Union or type(None) not in get_args(field): return False for arg in get_args(field):