Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/cohere/core/http_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,14 @@
from random import random

import httpx
from httpx._types import RequestFiles

from .file import File, convert_file_dict_to_httpx_tuples
from .force_multipart import FORCE_MULTIPART
from .jsonable_encoder import jsonable_encoder
from .query_encoder import encode_query
from .remove_none_from_dict import remove_none_from_dict
from .request_options import RequestOptions
from httpx._types import RequestFiles

INITIAL_RETRY_DELAY_SECONDS = 0.5
MAX_RETRY_DELAY_SECONDS = 10
Expand Down
14 changes: 7 additions & 7 deletions src/cohere/core/jsonable_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,20 @@
from typing import Any, Callable, Dict, List, Optional, Set, Union

import pydantic

from .datetime_utils import serialize_datetime
from .pydantic_utilities import (
IS_PYDANTIC_V2,
encode_by_type,
to_jsonable_with_fallback,
)
from .pydantic_utilities import (IS_PYDANTIC_V2, encode_by_type,
to_jsonable_with_fallback)

SetIntStr = Set[Union[int, str]]
DictIntStrAny = Dict[Union[int, str], Any]


def jsonable_encoder(obj: Any, custom_encoder: Optional[Dict[Any, Callable[[Any], Any]]] = None) -> Any:
# Short-circuit for most common primitive types
if isinstance(obj, (str, int, float, type(None))):
return obj

custom_encoder = custom_encoder or {}
if custom_encoder:
if type(obj) in custom_encoder:
Expand Down Expand Up @@ -59,8 +61,6 @@ def jsonable_encoder(obj: Any, custom_encoder: Optional[Dict[Any, Callable[[Any]
return obj.value
if isinstance(obj, PurePath):
return str(obj)
if isinstance(obj, (str, int, float, type(None))):
return obj
if isinstance(obj, dt.datetime):
return serialize_datetime(obj)
if isinstance(obj, dt.date):
Expand Down
27 changes: 14 additions & 13 deletions src/cohere/core/unchecked_base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,21 +7,14 @@

import pydantic
import typing_extensions
from .pydantic_utilities import (
IS_PYDANTIC_V2,
ModelField,
UniversalBaseModel,
get_args,
get_origin,
is_literal_type,
is_union,
parse_date,
parse_datetime,
parse_obj_as,
)
from .serialization import get_field_to_alias_mapping
from pydantic_core import PydanticUndefined

from .pydantic_utilities import (IS_PYDANTIC_V2, ModelField,
UniversalBaseModel, get_args, get_origin,
is_literal_type, is_union, parse_date,
parse_datetime, parse_obj_as)
from .serialization import get_field_to_alias_mapping


class UnionMetadata:
discriminant: str
Expand Down Expand Up @@ -179,6 +172,14 @@ def construct_type(*, type_: typing.Type[typing.Any], object_: typing.Any) -> ty
return None

base_type = get_origin(type_) or type_

# Early fast return for direct primitive types
if base_type in (str, int, float, bool, type(None)):
try:
return base_type(object_)
except Exception:
return object_

is_annotated = base_type == typing_extensions.Annotated
maybe_annotation_members = get_args(type_)
is_annotated_union = is_annotated and is_union(get_origin(maybe_annotation_members[0]))
Expand Down