diff --git a/src/cohere/core/http_client.py b/src/cohere/core/http_client.py index e4173f990..eced438be 100644 --- a/src/cohere/core/http_client.py +++ b/src/cohere/core/http_client.py @@ -10,13 +10,14 @@ from random import random import httpx +from httpx._types import RequestFiles + from .file import File, convert_file_dict_to_httpx_tuples from .force_multipart import FORCE_MULTIPART from .jsonable_encoder import jsonable_encoder from .query_encoder import encode_query from .remove_none_from_dict import remove_none_from_dict from .request_options import RequestOptions -from httpx._types import RequestFiles INITIAL_RETRY_DELAY_SECONDS = 0.5 MAX_RETRY_DELAY_SECONDS = 10 diff --git a/src/cohere/core/jsonable_encoder.py b/src/cohere/core/jsonable_encoder.py index afee3662d..6ea8176ad 100644 --- a/src/cohere/core/jsonable_encoder.py +++ b/src/cohere/core/jsonable_encoder.py @@ -17,18 +17,20 @@ from typing import Any, Callable, Dict, List, Optional, Set, Union import pydantic + from .datetime_utils import serialize_datetime -from .pydantic_utilities import ( - IS_PYDANTIC_V2, - encode_by_type, - to_jsonable_with_fallback, -) +from .pydantic_utilities import (IS_PYDANTIC_V2, encode_by_type, + to_jsonable_with_fallback) SetIntStr = Set[Union[int, str]] DictIntStrAny = Dict[Union[int, str], Any] def jsonable_encoder(obj: Any, custom_encoder: Optional[Dict[Any, Callable[[Any], Any]]] = None) -> Any: + # Short-circuit for most common primitive types + if isinstance(obj, (str, int, float, type(None))): + return obj + custom_encoder = custom_encoder or {} if custom_encoder: if type(obj) in custom_encoder: @@ -59,8 +61,6 @@ def jsonable_encoder(obj: Any, custom_encoder: Optional[Dict[Any, Callable[[Any] return obj.value if isinstance(obj, PurePath): return str(obj) - if isinstance(obj, (str, int, float, type(None))): - return obj if isinstance(obj, dt.datetime): return serialize_datetime(obj) if isinstance(obj, dt.date): diff --git a/src/cohere/core/unchecked_base_model.py b/src/cohere/core/unchecked_base_model.py index 2c2d92a7b..f2a3fd1f5 100644 --- a/src/cohere/core/unchecked_base_model.py +++ b/src/cohere/core/unchecked_base_model.py @@ -7,21 +7,14 @@ import pydantic import typing_extensions -from .pydantic_utilities import ( - IS_PYDANTIC_V2, - ModelField, - UniversalBaseModel, - get_args, - get_origin, - is_literal_type, - is_union, - parse_date, - parse_datetime, - parse_obj_as, -) -from .serialization import get_field_to_alias_mapping from pydantic_core import PydanticUndefined +from .pydantic_utilities import (IS_PYDANTIC_V2, ModelField, + UniversalBaseModel, get_args, get_origin, + is_literal_type, is_union, parse_date, + parse_datetime, parse_obj_as) +from .serialization import get_field_to_alias_mapping + class UnionMetadata: discriminant: str @@ -179,6 +172,14 @@ def construct_type(*, type_: typing.Type[typing.Any], object_: typing.Any) -> ty return None base_type = get_origin(type_) or type_ + + # Early fast return for direct primitive types + if base_type in (str, int, float, bool, type(None)): + try: + return base_type(object_) + except Exception: + return object_ + is_annotated = base_type == typing_extensions.Annotated maybe_annotation_members = get_args(type_) is_annotated_union = is_annotated and is_union(get_origin(maybe_annotation_members[0]))