diff --git a/src/mistralai/utils/serializers.py b/src/mistralai/utils/serializers.py index 378a14c..449a6c6 100644 --- a/src/mistralai/utils/serializers.py +++ b/src/mistralai/utils/serializers.py @@ -157,20 +157,31 @@ def marshal_json(val, typ): if is_nullable(typ) and val is None: return "null" - marshaller = create_model( - "Marshaller", - body=(typ, ...), - __config__=ConfigDict(populate_by_name=True, arbitrary_types_allowed=True), - ) + # Cache created model classes for (typ,) to avoid expensive recreation + if not hasattr(marshal_json, "_marshaller_cache"): + marshal_json._marshaller_cache = {} + marshaller_cache = marshal_json._marshaller_cache + + if typ in marshaller_cache: + marshaller = marshaller_cache[typ] + else: + marshaller = create_model( + "Marshaller", + body=(typ, ...), + __config__=ConfigDict(populate_by_name=True, arbitrary_types_allowed=True), + ) + marshaller_cache[typ] = marshaller m = marshaller(body=val) + # This produces a dict with a single key "body" d = m.model_dump(by_alias=True, mode="json", exclude_none=True) - if len(d) == 0: + if not d: return "" - return json.dumps(d[next(iter(d))], separators=(",", ":")) + # Direct access instead of next(iter(d)) + return json.dumps(d["body"], separators=(",", ":")) def is_nullable(field): @@ -178,11 +189,13 @@ def is_nullable(field): if origin is Nullable or origin is OptionalNullable: return True - if not origin is Union or type(None) not in get_args(field): + if origin is not Union or type(None) not in get_args(field): return False + # Only call get_origin(arg) once per arg for arg in get_args(field): - if get_origin(arg) is Nullable or get_origin(arg) is OptionalNullable: + arg_origin = get_origin(arg) + if arg_origin is Nullable or arg_origin is OptionalNullable: return True return False