Skip to content

Commit

Permalink
chore(internal): loosen type var restrictions (#1049)
Browse files Browse the repository at this point in the history
  • Loading branch information
stainless-bot committed Jan 5, 2024
1 parent 131f6bc commit e00876b
Show file tree
Hide file tree
Showing 24 changed files with 67 additions and 188 deletions.
41 changes: 19 additions & 22 deletions src/openai/_base_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@
Body,
Omit,
Query,
ModelT,
Headers,
Timeout,
NotGiven,
Expand All @@ -61,7 +60,6 @@
HttpxSendArgs,
AsyncTransport,
RequestOptions,
UnknownResponse,
ModelBuilderProtocol,
BinaryResponseContent,
)
Expand Down Expand Up @@ -142,7 +140,7 @@ def __init__(
self.params = params


class BasePage(GenericModel, Generic[ModelT]):
class BasePage(GenericModel, Generic[_T]):
"""
Defines the core interface for pagination.
Expand All @@ -155,7 +153,7 @@ class BasePage(GenericModel, Generic[ModelT]):
"""

_options: FinalRequestOptions = PrivateAttr()
_model: Type[ModelT] = PrivateAttr()
_model: Type[_T] = PrivateAttr()

def has_next_page(self) -> bool:
items = self._get_page_items()
Expand All @@ -166,7 +164,7 @@ def has_next_page(self) -> bool:
def next_page_info(self) -> Optional[PageInfo]:
...

def _get_page_items(self) -> Iterable[ModelT]: # type: ignore[empty-body]
def _get_page_items(self) -> Iterable[_T]: # type: ignore[empty-body]
...

def _params_from_url(self, url: URL) -> httpx.QueryParams:
Expand All @@ -191,13 +189,13 @@ def _info_to_options(self, info: PageInfo) -> FinalRequestOptions:
raise ValueError("Unexpected PageInfo state")


class BaseSyncPage(BasePage[ModelT], Generic[ModelT]):
class BaseSyncPage(BasePage[_T], Generic[_T]):
_client: SyncAPIClient = pydantic.PrivateAttr()

def _set_private_attributes(
self,
client: SyncAPIClient,
model: Type[ModelT],
model: Type[_T],
options: FinalRequestOptions,
) -> None:
self._model = model
Expand All @@ -212,7 +210,7 @@ def _set_private_attributes(
# methods should continue to work as expected as there is an alternative method
# to cast a model to a dictionary, model.dict(), which is used internally
# by pydantic.
def __iter__(self) -> Iterator[ModelT]: # type: ignore
def __iter__(self) -> Iterator[_T]: # type: ignore
for page in self.iter_pages():
for item in page._get_page_items():
yield item
Expand All @@ -237,13 +235,13 @@ def get_next_page(self: SyncPageT) -> SyncPageT:
return self._client._request_api_list(self._model, page=self.__class__, options=options)


class AsyncPaginator(Generic[ModelT, AsyncPageT]):
class AsyncPaginator(Generic[_T, AsyncPageT]):
def __init__(
self,
client: AsyncAPIClient,
options: FinalRequestOptions,
page_cls: Type[AsyncPageT],
model: Type[ModelT],
model: Type[_T],
) -> None:
self._model = model
self._client = client
Expand All @@ -266,7 +264,7 @@ def _parser(resp: AsyncPageT) -> AsyncPageT:

return await self._client.request(self._page_cls, self._options)

async def __aiter__(self) -> AsyncIterator[ModelT]:
async def __aiter__(self) -> AsyncIterator[_T]:
# https://github.com/microsoft/pyright/issues/3464
page = cast(
AsyncPageT,
Expand All @@ -276,20 +274,20 @@ async def __aiter__(self) -> AsyncIterator[ModelT]:
yield item


class BaseAsyncPage(BasePage[ModelT], Generic[ModelT]):
class BaseAsyncPage(BasePage[_T], Generic[_T]):
_client: AsyncAPIClient = pydantic.PrivateAttr()

def _set_private_attributes(
self,
model: Type[ModelT],
model: Type[_T],
client: AsyncAPIClient,
options: FinalRequestOptions,
) -> None:
self._model = model
self._client = client
self._options = options

async def __aiter__(self) -> AsyncIterator[ModelT]:
async def __aiter__(self) -> AsyncIterator[_T]:
async for page in self.iter_pages():
for item in page._get_page_items():
yield item
Expand Down Expand Up @@ -528,7 +526,7 @@ def _process_response_data(
if data is None:
return cast(ResponseT, None)

if cast_to is UnknownResponse:
if cast_to is object:
return cast(ResponseT, data)

try:
Expand Down Expand Up @@ -970,7 +968,7 @@ def _retry_request(

def _request_api_list(
self,
model: Type[ModelT],
model: Type[object],
page: Type[SyncPageT],
options: FinalRequestOptions,
) -> SyncPageT:
Expand Down Expand Up @@ -1132,7 +1130,7 @@ def get_api_list(
self,
path: str,
*,
model: Type[ModelT],
model: Type[object],
page: Type[SyncPageT],
body: Body | None = None,
options: RequestOptions = {},
Expand Down Expand Up @@ -1434,10 +1432,10 @@ async def _retry_request(

def _request_api_list(
self,
model: Type[ModelT],
model: Type[_T],
page: Type[AsyncPageT],
options: FinalRequestOptions,
) -> AsyncPaginator[ModelT, AsyncPageT]:
) -> AsyncPaginator[_T, AsyncPageT]:
return AsyncPaginator(client=self, options=options, page_cls=page, model=model)

@overload
Expand Down Expand Up @@ -1584,13 +1582,12 @@ def get_api_list(
self,
path: str,
*,
# TODO: support paginating `str`
model: Type[ModelT],
model: Type[_T],
page: Type[AsyncPageT],
body: Body | None = None,
options: RequestOptions = {},
method: str = "get",
) -> AsyncPaginator[ModelT, AsyncPageT]:
) -> AsyncPaginator[_T, AsyncPageT]:
opts = FinalRequestOptions.construct(method=method, url=path, json_data=body, **options)
return self._request_api_list(model, page, opts)

Expand Down
4 changes: 2 additions & 2 deletions src/openai/_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

import httpx

from ._types import NoneType, UnknownResponse, BinaryResponseContent
from ._types import NoneType, BinaryResponseContent
from ._utils import is_given, extract_type_var_from_base
from ._models import BaseModel, is_basemodel
from ._constants import RAW_RESPONSE_HEADER
Expand Down Expand Up @@ -162,7 +162,7 @@ def _parse(self) -> R:
# `ResponseT` TypeVar, however if that TypeVar is ever updated in the future, then
# this function would become unsafe but a type checker would not report an error.
if (
cast_to is not UnknownResponse
cast_to is not object
and not origin is list
and not origin is dict
and not origin is Union
Expand Down
17 changes: 11 additions & 6 deletions src/openai/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,11 +258,6 @@ class RequestOptions(TypedDict, total=False):
idempotency_key: str


# Sentinel class used when the response type is an object with an unknown schema
class UnknownResponse:
...


# Sentinel class used until PEP 0661 is accepted
class NotGiven:
"""
Expand Down Expand Up @@ -339,7 +334,17 @@ def get(self, __key: str) -> str | None:

ResponseT = TypeVar(
"ResponseT",
bound="Union[str, None, BaseModel, List[Any], Dict[str, Any], Response, UnknownResponse, ModelBuilderProtocol, BinaryResponseContent]",
bound=Union[
object,
str,
None,
"BaseModel",
List[Any],
Dict[str, Any],
Response,
ModelBuilderProtocol,
BinaryResponseContent,
],
)

StrBytesIntFloat = Union[str, bytes, int, float]
Expand Down
29 changes: 15 additions & 14 deletions src/openai/pagination.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,28 @@
# File generated from our OpenAPI spec by Stainless.

from typing import Any, List, Generic, Optional, cast
from typing import Any, List, Generic, TypeVar, Optional, cast
from typing_extensions import Protocol, override, runtime_checkable

from ._types import ModelT
from ._base_client import BasePage, PageInfo, BaseSyncPage, BaseAsyncPage

__all__ = ["SyncPage", "AsyncPage", "SyncCursorPage", "AsyncCursorPage"]

_T = TypeVar("_T")


@runtime_checkable
class CursorPageItem(Protocol):
id: Optional[str]


class SyncPage(BaseSyncPage[ModelT], BasePage[ModelT], Generic[ModelT]):
class SyncPage(BaseSyncPage[_T], BasePage[_T], Generic[_T]):
"""Note: no pagination actually occurs yet, this is for forwards-compatibility."""

data: List[ModelT]
data: List[_T]
object: str

@override
def _get_page_items(self) -> List[ModelT]:
def _get_page_items(self) -> List[_T]:
data = self.data
if not data:
return []
Expand All @@ -36,14 +37,14 @@ def next_page_info(self) -> None:
return None


class AsyncPage(BaseAsyncPage[ModelT], BasePage[ModelT], Generic[ModelT]):
class AsyncPage(BaseAsyncPage[_T], BasePage[_T], Generic[_T]):
"""Note: no pagination actually occurs yet, this is for forwards-compatibility."""

data: List[ModelT]
data: List[_T]
object: str

@override
def _get_page_items(self) -> List[ModelT]:
def _get_page_items(self) -> List[_T]:
data = self.data
if not data:
return []
Expand All @@ -58,11 +59,11 @@ def next_page_info(self) -> None:
return None


class SyncCursorPage(BaseSyncPage[ModelT], BasePage[ModelT], Generic[ModelT]):
data: List[ModelT]
class SyncCursorPage(BaseSyncPage[_T], BasePage[_T], Generic[_T]):
data: List[_T]

@override
def _get_page_items(self) -> List[ModelT]:
def _get_page_items(self) -> List[_T]:
data = self.data
if not data:
return []
Expand All @@ -82,11 +83,11 @@ def next_page_info(self) -> Optional[PageInfo]:
return PageInfo(params={"after": item.id})


class AsyncCursorPage(BaseAsyncPage[ModelT], BasePage[ModelT], Generic[ModelT]):
data: List[ModelT]
class AsyncCursorPage(BaseAsyncPage[_T], BasePage[_T], Generic[_T]):
data: List[_T]

@override
def _get_page_items(self) -> List[ModelT]:
def _get_page_items(self) -> List[_T]:
data = self.data
if not data:
return []
Expand Down
8 changes: 1 addition & 7 deletions src/openai/resources/audio/speech.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,7 @@

import httpx

from ..._types import (
NOT_GIVEN,
Body,
Query,
Headers,
NotGiven,
)
from ..._types import NOT_GIVEN, Body, Query, Headers, NotGiven
from ..._utils import maybe_transform
from ..._compat import cached_property
from ..._resource import SyncAPIResource, AsyncAPIResource
Expand Down
9 changes: 1 addition & 8 deletions src/openai/resources/audio/transcriptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,7 @@

import httpx

from ..._types import (
NOT_GIVEN,
Body,
Query,
Headers,
NotGiven,
FileTypes,
)
from ..._types import NOT_GIVEN, Body, Query, Headers, NotGiven, FileTypes
from ..._utils import extract_files, maybe_transform, deepcopy_minimal
from ..._compat import cached_property
from ..._resource import SyncAPIResource, AsyncAPIResource
Expand Down
9 changes: 1 addition & 8 deletions src/openai/resources/audio/translations.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,7 @@

import httpx

from ..._types import (
NOT_GIVEN,
Body,
Query,
Headers,
NotGiven,
FileTypes,
)
from ..._types import NOT_GIVEN, Body, Query, Headers, NotGiven, FileTypes
from ..._utils import extract_files, maybe_transform, deepcopy_minimal
from ..._compat import cached_property
from ..._resource import SyncAPIResource, AsyncAPIResource
Expand Down
8 changes: 1 addition & 7 deletions src/openai/resources/beta/assistants/assistants.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,7 @@
import httpx

from .files import Files, AsyncFiles, FilesWithRawResponse, AsyncFilesWithRawResponse
from ...._types import (
NOT_GIVEN,
Body,
Query,
Headers,
NotGiven,
)
from ...._types import NOT_GIVEN, Body, Query, Headers, NotGiven
from ...._utils import maybe_transform
from ...._compat import cached_property
from ...._resource import SyncAPIResource, AsyncAPIResource
Expand Down
8 changes: 1 addition & 7 deletions src/openai/resources/beta/assistants/files.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,7 @@

import httpx

from ...._types import (
NOT_GIVEN,
Body,
Query,
Headers,
NotGiven,
)
from ...._types import NOT_GIVEN, Body, Query, Headers, NotGiven
from ...._utils import maybe_transform
from ...._compat import cached_property
from ...._resource import SyncAPIResource, AsyncAPIResource
Expand Down
8 changes: 1 addition & 7 deletions src/openai/resources/beta/threads/messages/files.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,7 @@

import httpx

from ....._types import (
NOT_GIVEN,
Body,
Query,
Headers,
NotGiven,
)
from ....._types import NOT_GIVEN, Body, Query, Headers, NotGiven
from ....._utils import maybe_transform
from ....._compat import cached_property
from ....._resource import SyncAPIResource, AsyncAPIResource
Expand Down

0 comments on commit e00876b

Please sign in to comment.