Skip to content

Commit

Permalink
chore(internal): minor utils restructuring (#992)
Browse files Browse the repository at this point in the history
  • Loading branch information
stainless-bot committed Dec 19, 2023
1 parent 6c3427d commit 5ba576a
Show file tree
Hide file tree
Showing 8 changed files with 183 additions and 66 deletions.
17 changes: 9 additions & 8 deletions src/openai/_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@
import datetime
import functools
from typing import TYPE_CHECKING, Any, Union, Generic, TypeVar, Callable, cast
from typing_extensions import Awaitable, ParamSpec, get_args, override, get_origin
from typing_extensions import Awaitable, ParamSpec, override, get_origin

import httpx

from ._types import NoneType, UnknownResponse, BinaryResponseContent
from ._utils import is_given
from ._utils import is_given, extract_type_var_from_base
from ._models import BaseModel, is_basemodel
from ._constants import RAW_RESPONSE_HEADER
from ._exceptions import APIResponseValidationError
Expand Down Expand Up @@ -221,12 +221,13 @@ def __init__(self) -> None:


def _extract_stream_chunk_type(stream_cls: type) -> type:
args = get_args(stream_cls)
if not args:
raise TypeError(
f"Expected stream_cls to have been given a generic type argument, e.g. Stream[Foo] but received {stream_cls}",
)
return cast(type, args[0])
from ._base_client import Stream, AsyncStream

return extract_type_var_from_base(
stream_cls,
index=0,
generic_bases=cast("tuple[type, ...]", (Stream, AsyncStream)),
)


def to_raw_response_wrapper(func: Callable[P, R]) -> Callable[P, APIResponse[R]]:
Expand Down
71 changes: 56 additions & 15 deletions src/openai/_streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,28 +2,31 @@
from __future__ import annotations

import json
from typing import TYPE_CHECKING, Any, Generic, Iterator, AsyncIterator
from typing_extensions import override
from types import TracebackType
from typing import TYPE_CHECKING, Any, Generic, TypeVar, Iterator, AsyncIterator, cast
from typing_extensions import Self, override

import httpx

from ._types import ResponseT
from ._utils import is_mapping
from ._exceptions import APIError

if TYPE_CHECKING:
from ._client import OpenAI, AsyncOpenAI


class Stream(Generic[ResponseT]):
_T = TypeVar("_T")


class Stream(Generic[_T]):
"""Provides the core interface to iterate over a synchronous stream response."""

response: httpx.Response

def __init__(
self,
*,
cast_to: type[ResponseT],
cast_to: type[_T],
response: httpx.Response,
client: OpenAI,
) -> None:
Expand All @@ -33,18 +36,18 @@ def __init__(
self._decoder = SSEDecoder()
self._iterator = self.__stream__()

def __next__(self) -> ResponseT:
def __next__(self) -> _T:
return self._iterator.__next__()

def __iter__(self) -> Iterator[ResponseT]:
def __iter__(self) -> Iterator[_T]:
for item in self._iterator:
yield item

def _iter_events(self) -> Iterator[ServerSentEvent]:
yield from self._decoder.iter(self.response.iter_lines())

def __stream__(self) -> Iterator[ResponseT]:
cast_to = self._cast_to
def __stream__(self) -> Iterator[_T]:
cast_to = cast(Any, self._cast_to)
response = self.response
process_data = self._client._process_response_data
iterator = self._iter_events()
Expand All @@ -68,16 +71,35 @@ def __stream__(self) -> Iterator[ResponseT]:
for _sse in iterator:
...

def __enter__(self) -> Self:
return self

def __exit__(
self,
exc_type: type[BaseException] | None,
exc: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
self.close()

def close(self) -> None:
"""
Close the response and release the connection.
Automatically called if the response body is read to completion.
"""
self.response.close()

class AsyncStream(Generic[ResponseT]):

class AsyncStream(Generic[_T]):
"""Provides the core interface to iterate over an asynchronous stream response."""

response: httpx.Response

def __init__(
self,
*,
cast_to: type[ResponseT],
cast_to: type[_T],
response: httpx.Response,
client: AsyncOpenAI,
) -> None:
Expand All @@ -87,19 +109,19 @@ def __init__(
self._decoder = SSEDecoder()
self._iterator = self.__stream__()

async def __anext__(self) -> ResponseT:
async def __anext__(self) -> _T:
return await self._iterator.__anext__()

async def __aiter__(self) -> AsyncIterator[ResponseT]:
async def __aiter__(self) -> AsyncIterator[_T]:
async for item in self._iterator:
yield item

async def _iter_events(self) -> AsyncIterator[ServerSentEvent]:
async for sse in self._decoder.aiter(self.response.aiter_lines()):
yield sse

async def __stream__(self) -> AsyncIterator[ResponseT]:
cast_to = self._cast_to
async def __stream__(self) -> AsyncIterator[_T]:
cast_to = cast(Any, self._cast_to)
response = self.response
process_data = self._client._process_response_data
iterator = self._iter_events()
Expand All @@ -123,6 +145,25 @@ async def __stream__(self) -> AsyncIterator[ResponseT]:
async for _sse in iterator:
...

async def __aenter__(self) -> Self:
return self

async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
await self.close()

async def close(self) -> None:
"""
Close the response and release the connection.
Automatically called if the response body is read to completion.
"""
await self.response.aclose()


class ServerSentEvent:
def __init__(
Expand Down
14 changes: 14 additions & 0 deletions src/openai/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,3 +353,17 @@ def get(self, __key: str) -> str | None:
IncEx: TypeAlias = "set[int] | set[str] | dict[int, Any] | dict[str, Any] | None"

PostParser = Callable[[Any], Any]


@runtime_checkable
class InheritsGeneric(Protocol):
"""Represents a type that has inherited from `Generic`
The `__orig_bases__` property can be used to determine the resolved
type variable for a given base class.
"""

__orig_bases__: tuple[_GenericAlias]


class _GenericAlias(Protocol):
__origin__: type[object]
15 changes: 9 additions & 6 deletions src/openai/_utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,29 +9,32 @@
from ._utils import parse_date as parse_date
from ._utils import is_sequence as is_sequence
from ._utils import coerce_float as coerce_float
from ._utils import is_list_type as is_list_type
from ._utils import is_mapping_t as is_mapping_t
from ._utils import removeprefix as removeprefix
from ._utils import removesuffix as removesuffix
from ._utils import extract_files as extract_files
from ._utils import is_sequence_t as is_sequence_t
from ._utils import is_union_type as is_union_type
from ._utils import required_args as required_args
from ._utils import coerce_boolean as coerce_boolean
from ._utils import coerce_integer as coerce_integer
from ._utils import file_from_path as file_from_path
from ._utils import parse_datetime as parse_datetime
from ._utils import strip_not_given as strip_not_given
from ._utils import deepcopy_minimal as deepcopy_minimal
from ._utils import extract_type_arg as extract_type_arg
from ._utils import is_required_type as is_required_type
from ._utils import get_async_library as get_async_library
from ._utils import is_annotated_type as is_annotated_type
from ._utils import maybe_coerce_float as maybe_coerce_float
from ._utils import get_required_header as get_required_header
from ._utils import maybe_coerce_boolean as maybe_coerce_boolean
from ._utils import maybe_coerce_integer as maybe_coerce_integer
from ._utils import strip_annotated_type as strip_annotated_type
from ._typing import is_list_type as is_list_type
from ._typing import is_union_type as is_union_type
from ._typing import extract_type_arg as extract_type_arg
from ._typing import is_required_type as is_required_type
from ._typing import is_annotated_type as is_annotated_type
from ._typing import strip_annotated_type as strip_annotated_type
from ._typing import extract_type_var_from_base as extract_type_var_from_base
from ._streams import consume_sync_iterator as consume_sync_iterator
from ._streams import consume_async_iterator as consume_async_iterator
from ._transform import PropertyInfo as PropertyInfo
from ._transform import transform as transform
from ._transform import maybe_transform as maybe_transform
12 changes: 12 additions & 0 deletions src/openai/_utils/_streams.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from typing import Any
from typing_extensions import Iterator, AsyncIterator


def consume_sync_iterator(iterator: Iterator[Any]) -> None:
for _ in iterator:
...


async def consume_async_iterator(iterator: AsyncIterator[Any]) -> None:
async for _ in iterator:
...
5 changes: 2 additions & 3 deletions src/openai/_utils/_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,8 @@

import pydantic

from ._utils import (
is_list,
is_mapping,
from ._utils import is_list, is_mapping
from ._typing import (
is_list_type,
is_union_type,
extract_type_arg,
Expand Down
80 changes: 80 additions & 0 deletions src/openai/_utils/_typing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
from __future__ import annotations

from typing import Any, cast
from typing_extensions import Required, Annotated, get_args, get_origin

from .._types import InheritsGeneric
from .._compat import is_union as _is_union


def is_annotated_type(typ: type) -> bool:
return get_origin(typ) == Annotated


def is_list_type(typ: type) -> bool:
return (get_origin(typ) or typ) == list


def is_union_type(typ: type) -> bool:
return _is_union(get_origin(typ))


def is_required_type(typ: type) -> bool:
return get_origin(typ) == Required


# Extracts T from Annotated[T, ...] or from Required[Annotated[T, ...]]
def strip_annotated_type(typ: type) -> type:
if is_required_type(typ) or is_annotated_type(typ):
return strip_annotated_type(cast(type, get_args(typ)[0]))

return typ


def extract_type_arg(typ: type, index: int) -> type:
args = get_args(typ)
try:
return cast(type, args[index])
except IndexError as err:
raise RuntimeError(f"Expected type {typ} to have a type argument at index {index} but it did not") from err


def extract_type_var_from_base(typ: type, *, generic_bases: tuple[type, ...], index: int) -> type:
"""Given a type like `Foo[T]`, returns the generic type variable `T`.
This also handles the case where a concrete subclass is given, e.g.
```py
class MyResponse(Foo[bytes]):
...
extract_type_var(MyResponse, bases=(Foo,), index=0) -> bytes
```
"""
cls = cast(object, get_origin(typ) or typ)
if cls in generic_bases:
# we're given the class directly
return extract_type_arg(typ, index)

# if a subclass is given
# ---
# this is needed as __orig_bases__ is not present in the typeshed stubs
# because it is intended to be for internal use only, however there does
# not seem to be a way to resolve generic TypeVars for inherited subclasses
# without using it.
if isinstance(cls, InheritsGeneric):
target_base_class: Any | None = None
for base in cls.__orig_bases__:
if base.__origin__ in generic_bases:
target_base_class = base
break

if target_base_class is None:
raise RuntimeError(
"Could not find the generic base class;\n"
"This should never happen;\n"
f"Does {cls} inherit from one of {generic_bases} ?"
)

return extract_type_arg(target_base_class, index)

raise RuntimeError(f"Could not resolve inner type variable at index {index} for {typ}")
35 changes: 1 addition & 34 deletions src/openai/_utils/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,11 @@
overload,
)
from pathlib import Path
from typing_extensions import Required, Annotated, TypeGuard, get_args, get_origin
from typing_extensions import TypeGuard

import sniffio

from .._types import Headers, NotGiven, FileTypes, NotGivenOr, HeadersLike
from .._compat import is_union as _is_union
from .._compat import parse_date as parse_date
from .._compat import parse_datetime as parse_datetime

Expand Down Expand Up @@ -166,38 +165,6 @@ def is_list(obj: object) -> TypeGuard[list[object]]:
return isinstance(obj, list)


def is_annotated_type(typ: type) -> bool:
return get_origin(typ) == Annotated


def is_list_type(typ: type) -> bool:
return (get_origin(typ) or typ) == list


def is_union_type(typ: type) -> bool:
return _is_union(get_origin(typ))


def is_required_type(typ: type) -> bool:
return get_origin(typ) == Required


# Extracts T from Annotated[T, ...] or from Required[Annotated[T, ...]]
def strip_annotated_type(typ: type) -> type:
if is_required_type(typ) or is_annotated_type(typ):
return strip_annotated_type(cast(type, get_args(typ)[0]))

return typ


def extract_type_arg(typ: type, index: int) -> type:
args = get_args(typ)
try:
return cast(type, args[index])
except IndexError as err:
raise RuntimeError(f"Expected type {typ} to have a type argument at index {index} but it did not") from err


def deepcopy_minimal(item: _T) -> _T:
"""Minimal reimplementation of copy.deepcopy() that will only copy certain object types:
Expand Down

0 comments on commit 5ba576a

Please sign in to comment.