From 98d4e59afcf2d65d4e660d91eb9462240ef5cd63 Mon Sep 17 00:00:00 2001 From: Stainless Bot <107565488+stainless-bot@users.noreply.github.com> Date: Tue, 6 Feb 2024 09:50:54 +0000 Subject: [PATCH] chore(internal): support serialising iterable types (#1127) --- src/openai/_utils/__init__.py | 2 ++ src/openai/_utils/_transform.py | 9 ++++++++- src/openai/_utils/_typing.py | 9 ++++++++- src/openai/_utils/_utils.py | 4 ++++ tests/test_transform.py | 34 ++++++++++++++++++++++++++++++++- 5 files changed, 55 insertions(+), 3 deletions(-) diff --git a/src/openai/_utils/__init__.py b/src/openai/_utils/__init__.py index 0fb811a945..b5790a879f 100644 --- a/src/openai/_utils/__init__.py +++ b/src/openai/_utils/__init__.py @@ -9,6 +9,7 @@ is_mapping as is_mapping, is_tuple_t as is_tuple_t, parse_date as parse_date, + is_iterable as is_iterable, is_sequence as is_sequence, coerce_float as coerce_float, is_mapping_t as is_mapping_t, @@ -33,6 +34,7 @@ is_list_type as is_list_type, is_union_type as is_union_type, extract_type_arg as extract_type_arg, + is_iterable_type as is_iterable_type, is_required_type as is_required_type, is_annotated_type as is_annotated_type, strip_annotated_type as strip_annotated_type, diff --git a/src/openai/_utils/_transform.py b/src/openai/_utils/_transform.py index 3a1c14969b..2cb7726c73 100644 --- a/src/openai/_utils/_transform.py +++ b/src/openai/_utils/_transform.py @@ -9,11 +9,13 @@ from ._utils import ( is_list, is_mapping, + is_iterable, ) from ._typing import ( is_list_type, is_union_type, extract_type_arg, + is_iterable_type, is_required_type, is_annotated_type, strip_annotated_type, @@ -157,7 +159,12 @@ def _transform_recursive( if is_typeddict(stripped_type) and is_mapping(data): return _transform_typeddict(data, stripped_type) - if is_list_type(stripped_type) and is_list(data): + if ( + # List[T] + (is_list_type(stripped_type) and is_list(data)) + # Iterable[T] + or (is_iterable_type(stripped_type) and is_iterable(data) and not isinstance(data, str)) + ): inner_type = extract_type_arg(stripped_type, 0) return [_transform_recursive(d, annotation=annotation, inner_type=inner_type) for d in data] diff --git a/src/openai/_utils/_typing.py b/src/openai/_utils/_typing.py index c1d1ebb9a4..c036991f04 100644 --- a/src/openai/_utils/_typing.py +++ b/src/openai/_utils/_typing.py @@ -1,6 +1,7 @@ from __future__ import annotations -from typing import Any, TypeVar, cast +from typing import Any, TypeVar, Iterable, cast +from collections import abc as _c_abc from typing_extensions import Required, Annotated, get_args, get_origin from .._types import InheritsGeneric @@ -15,6 +16,12 @@ def is_list_type(typ: type) -> bool: return (get_origin(typ) or typ) == list +def is_iterable_type(typ: type) -> bool: + """If the given type is `typing.Iterable[T]`""" + origin = get_origin(typ) or typ + return origin == Iterable or origin == _c_abc.Iterable + + def is_union_type(typ: type) -> bool: return _is_union(get_origin(typ)) diff --git a/src/openai/_utils/_utils.py b/src/openai/_utils/_utils.py index 1c5c21a8ea..93c95517a9 100644 --- a/src/openai/_utils/_utils.py +++ b/src/openai/_utils/_utils.py @@ -164,6 +164,10 @@ def is_list(obj: object) -> TypeGuard[list[object]]: return isinstance(obj, list) +def is_iterable(obj: object) -> TypeGuard[Iterable[object]]: + return isinstance(obj, Iterable) + + def deepcopy_minimal(item: _T) -> _T: """Minimal reimplementation of copy.deepcopy() that will only copy certain object types: diff --git a/tests/test_transform.py b/tests/test_transform.py index c4dffb3bb0..6ed67d49a7 100644 --- a/tests/test_transform.py +++ b/tests/test_transform.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Any, List, Union, Optional +from typing import Any, List, Union, Iterable, Optional, cast from datetime import date, datetime from typing_extensions import Required, Annotated, TypedDict @@ -265,3 +265,35 @@ def test_pydantic_default_field() -> None: assert model.with_none_default == "bar" assert model.with_str_default == "baz" assert transform(model, Any) == {"with_none_default": "bar", "with_str_default": "baz"} + + +class TypedDictIterableUnion(TypedDict): + foo: Annotated[Union[Bar8, Iterable[Baz8]], PropertyInfo(alias="FOO")] + + +class Bar8(TypedDict): + foo_bar: Annotated[str, PropertyInfo(alias="fooBar")] + + +class Baz8(TypedDict): + foo_baz: Annotated[str, PropertyInfo(alias="fooBaz")] + + +def test_iterable_of_dictionaries() -> None: + assert transform({"foo": [{"foo_baz": "bar"}]}, TypedDictIterableUnion) == {"FOO": [{"fooBaz": "bar"}]} + assert cast(Any, transform({"foo": ({"foo_baz": "bar"},)}, TypedDictIterableUnion)) == {"FOO": [{"fooBaz": "bar"}]} + + def my_iter() -> Iterable[Baz8]: + yield {"foo_baz": "hello"} + yield {"foo_baz": "world"} + + assert transform({"foo": my_iter()}, TypedDictIterableUnion) == {"FOO": [{"fooBaz": "hello"}, {"fooBaz": "world"}]} + + +class TypedDictIterableUnionStr(TypedDict): + foo: Annotated[Union[str, Iterable[Baz8]], PropertyInfo(alias="FOO")] + + +def test_iterable_union_str() -> None: + assert transform({"foo": "bar"}, TypedDictIterableUnionStr) == {"FOO": "bar"} + assert cast(Any, transform(iter([{"foo_baz": "bar"}]), Union[str, Iterable[Baz8]])) == [{"fooBaz": "bar"}]