Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 44 additions & 31 deletions src/deepgram/core/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import collections
import inspect
import typing
from functools import lru_cache

import pydantic
import typing_extensions
Expand Down Expand Up @@ -59,6 +60,10 @@ def convert_and_respect_annotation_metadata(
inner_type = annotation

clean_type = _remove_annotations(inner_type)

# Locally cache getting origin for the cleaned type
clean_type_origin = _get_origin_cached(clean_type)

# Pydantic models
if (
inspect.isclass(clean_type)
Expand All @@ -67,17 +72,14 @@ def convert_and_respect_annotation_metadata(
):
return _convert_mapping(object_, clean_type, direction)
# TypedDicts
if typing_extensions.is_typeddict(clean_type) and isinstance(object_, typing.Mapping):
if _is_typeddict_cached(clean_type) and isinstance(object_, typing.Mapping):
return _convert_mapping(object_, clean_type, direction)

if (
typing_extensions.get_origin(clean_type) == typing.Dict
or typing_extensions.get_origin(clean_type) == dict
or clean_type == typing.Dict
) and isinstance(object_, typing.Dict):
key_type = typing_extensions.get_args(clean_type)[0]
value_type = typing_extensions.get_args(clean_type)[1]

# Dict
if (clean_type_origin == typing.Dict or clean_type_origin == dict or clean_type == typing.Dict) and isinstance(
object_, typing.Dict
):
key_type, value_type = _get_args_cached(clean_type)
return {
key: convert_and_respect_annotation_metadata(
object_=value,
Expand All @@ -90,53 +92,46 @@ def convert_and_respect_annotation_metadata(

# If you're iterating on a string, do not bother to coerce it to a sequence.
if not isinstance(object_, str):
if (
typing_extensions.get_origin(clean_type) == typing.Set
or typing_extensions.get_origin(clean_type) == set
or clean_type == typing.Set
) and isinstance(object_, typing.Set):
inner_type = typing_extensions.get_args(clean_type)[0]
# Set
if (clean_type_origin == typing.Set or clean_type_origin == set or clean_type == typing.Set) and isinstance(
object_, typing.Set
):
(inner_container_type,) = _get_args_cached(clean_type)
return {
convert_and_respect_annotation_metadata(
object_=item,
annotation=annotation,
inner_type=inner_type,
inner_type=inner_container_type,
direction=direction,
)
for item in object_
}
# List/Sequence
elif (
(
typing_extensions.get_origin(clean_type) == typing.List
or typing_extensions.get_origin(clean_type) == list
or clean_type == typing.List
)
(clean_type_origin == typing.List or clean_type_origin == list or clean_type == typing.List)
and isinstance(object_, typing.List)
) or (
(
typing_extensions.get_origin(clean_type) == typing.Sequence
or typing_extensions.get_origin(clean_type) == collections.abc.Sequence
clean_type_origin == typing.Sequence
or clean_type_origin == collections.abc.Sequence
or clean_type == typing.Sequence
)
and isinstance(object_, typing.Sequence)
):
inner_type = typing_extensions.get_args(clean_type)[0]
(inner_container_type,) = _get_args_cached(clean_type)
return [
convert_and_respect_annotation_metadata(
object_=item,
annotation=annotation,
inner_type=inner_type,
inner_type=inner_container_type,
direction=direction,
)
for item in object_
]

if typing_extensions.get_origin(clean_type) == typing.Union:
# We should be able to ~relatively~ safely try to convert keys against all
# member types in the union, the edge case here is if one member aliases a field
# of the same name to a different name from another member
# Or if another member aliases a field of the same name that another member does not.
for member in typing_extensions.get_args(clean_type):
# Union
if clean_type_origin == typing.Union:
for member in _get_args_cached(clean_type):
object_ = convert_and_respect_annotation_metadata(
object_=object_,
annotation=annotation,
Expand Down Expand Up @@ -274,3 +269,21 @@ def _alias_key(
if direction == "read":
return aliases_to_field_names.get(key, key)
return _get_alias_from_type(type_=type_) or key


# Cached/getter function for get_origin
@lru_cache(maxsize=128)
def _get_origin_cached(type_: typing.Any) -> typing.Any:
return typing_extensions.get_origin(type_)


# Cached/getter function for get_args
@lru_cache(maxsize=128)
def _get_args_cached(type_: typing.Any) -> tuple:
return typing_extensions.get_args(type_)


# Cached/getter function for is_typeddict
@lru_cache(maxsize=128)
def _is_typeddict_cached(type_: typing.Any) -> bool:
return typing_extensions.is_typeddict(type_)