Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 22 additions & 9 deletions src/mistralai/utils/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,32 +157,45 @@ def marshal_json(val, typ):
if is_nullable(typ) and val is None:
return "null"

marshaller = create_model(
"Marshaller",
body=(typ, ...),
__config__=ConfigDict(populate_by_name=True, arbitrary_types_allowed=True),
)
# Cache created model classes for (typ,) to avoid expensive recreation
if not hasattr(marshal_json, "_marshaller_cache"):
marshal_json._marshaller_cache = {}
marshaller_cache = marshal_json._marshaller_cache

if typ in marshaller_cache:
marshaller = marshaller_cache[typ]
else:
marshaller = create_model(
"Marshaller",
body=(typ, ...),
__config__=ConfigDict(populate_by_name=True, arbitrary_types_allowed=True),
)
marshaller_cache[typ] = marshaller

m = marshaller(body=val)

# This produces a dict with a single key "body"
d = m.model_dump(by_alias=True, mode="json", exclude_none=True)

if len(d) == 0:
if not d:
return ""

return json.dumps(d[next(iter(d))], separators=(",", ":"))
# Direct access instead of next(iter(d))
return json.dumps(d["body"], separators=(",", ":"))


def is_nullable(field):
origin = get_origin(field)
if origin is Nullable or origin is OptionalNullable:
return True

if not origin is Union or type(None) not in get_args(field):
if origin is not Union or type(None) not in get_args(field):
return False

# Only call get_origin(arg) once per arg
for arg in get_args(field):
if get_origin(arg) is Nullable or get_origin(arg) is OptionalNullable:
arg_origin = get_origin(arg)
if arg_origin is Nullable or arg_origin is OptionalNullable:
return True

return False
Expand Down