From b49eb370573626abd5ddb5dc03228c503079be59 Mon Sep 17 00:00:00 2001 From: Mehdi ABAAKOUK Date: Thu, 9 Feb 2023 17:58:03 +0100 Subject: [PATCH] chore(typing): improve typing of WrappedFn (#390) This change improves the typing of WrappedFn. It makes explictly the two signatures of tenacity.retry() with overload. This avoids mypy thinking the return type is `` --- tenacity/__init__.py | 97 +++++++++++++++++++++++++++++--------------- tenacity/_asyncio.py | 24 +++++------ tox.ini | 2 +- 3 files changed, 76 insertions(+), 47 deletions(-) diff --git a/tenacity/__init__.py b/tenacity/__init__.py index 1f26ecdd..67312809 100644 --- a/tenacity/__init__.py +++ b/tenacity/__init__.py @@ -16,6 +16,7 @@ # See the License for the specific language governing permissions and # limitations under the License. + import functools import sys import threading @@ -91,37 +92,8 @@ from .wait import WaitBaseT +WrappedFnReturnT = t.TypeVar("WrappedFnReturnT") WrappedFn = t.TypeVar("WrappedFn", bound=t.Callable[..., t.Any]) -_RetValT = t.TypeVar("_RetValT") - - -def retry(*dargs: t.Any, **dkw: t.Any) -> t.Union[WrappedFn, t.Callable[[WrappedFn], WrappedFn]]: # noqa - """Wrap a function with a new `Retrying` object. - - :param dargs: positional arguments passed to Retrying object - :param dkw: keyword arguments passed to the Retrying object - """ - # support both @retry and @retry() as valid syntax - if len(dargs) == 1 and callable(dargs[0]): - return retry()(dargs[0]) - else: - - def wrap(f: WrappedFn) -> WrappedFn: - if isinstance(f, retry_base): - warnings.warn( - f"Got retry_base instance ({f.__class__.__name__}) as callable argument, " - f"this will probably hang indefinitely (did you mean retry={f.__class__.__name__}(...)?)" - ) - if iscoroutinefunction(f): - r: "BaseRetrying" = AsyncRetrying(*dargs, **dkw) - elif tornado and hasattr(tornado.gen, "is_coroutine_function") and tornado.gen.is_coroutine_function(f): - r = TornadoRetrying(*dargs, **dkw) - else: - r = Retrying(*dargs, **dkw) - - return r.wraps(f) - - return wrap class TryAgain(Exception): @@ -382,14 +354,24 @@ def __iter__(self) -> t.Generator[AttemptManager, None, None]: break @abstractmethod - def __call__(self, fn: t.Callable[..., _RetValT], *args: t.Any, **kwargs: t.Any) -> _RetValT: + def __call__( + self, + fn: t.Callable[..., WrappedFnReturnT], + *args: t.Any, + **kwargs: t.Any, + ) -> WrappedFnReturnT: pass class Retrying(BaseRetrying): """Retrying controller.""" - def __call__(self, fn: t.Callable[..., _RetValT], *args: t.Any, **kwargs: t.Any) -> _RetValT: + def __call__( + self, + fn: t.Callable[..., WrappedFnReturnT], + *args: t.Any, + **kwargs: t.Any, + ) -> WrappedFnReturnT: self.begin() retry_state = RetryCallState(retry_object=self, fn=fn, args=args, kwargs=kwargs) @@ -510,6 +492,57 @@ def __repr__(self) -> str: return f"<{clsname} {id(self)}: attempt #{self.attempt_number}; slept for {slept}; last result: {result}>" +@t.overload +def retry(func: WrappedFn) -> WrappedFn: + ... + + +@t.overload +def retry( + sleep: t.Callable[[t.Union[int, float]], None] = sleep, + stop: "StopBaseT" = stop_never, + wait: "WaitBaseT" = wait_none(), + retry: "RetryBaseT" = retry_if_exception_type(), + before: t.Callable[["RetryCallState"], None] = before_nothing, + after: t.Callable[["RetryCallState"], None] = after_nothing, + before_sleep: t.Optional[t.Callable[["RetryCallState"], None]] = None, + reraise: bool = False, + retry_error_cls: t.Type["RetryError"] = RetryError, + retry_error_callback: t.Optional[t.Callable[["RetryCallState"], t.Any]] = None, +) -> t.Callable[[WrappedFn], WrappedFn]: + ... + + +def retry(*dargs: t.Any, **dkw: t.Any) -> t.Any: + """Wrap a function with a new `Retrying` object. + + :param dargs: positional arguments passed to Retrying object + :param dkw: keyword arguments passed to the Retrying object + """ + # support both @retry and @retry() as valid syntax + if len(dargs) == 1 and callable(dargs[0]): + return retry()(dargs[0]) + else: + + def wrap(f: WrappedFn) -> WrappedFn: + if isinstance(f, retry_base): + warnings.warn( + f"Got retry_base instance ({f.__class__.__name__}) as callable argument, " + f"this will probably hang indefinitely (did you mean retry={f.__class__.__name__}(...)?)" + ) + r: "BaseRetrying" + if iscoroutinefunction(f): + r = AsyncRetrying(*dargs, **dkw) + elif tornado and hasattr(tornado.gen, "is_coroutine_function") and tornado.gen.is_coroutine_function(f): + r = TornadoRetrying(*dargs, **dkw) + else: + r = Retrying(*dargs, **dkw) + + return r.wraps(f) + + return wrap + + from tenacity._asyncio import AsyncRetrying # noqa:E402,I100 if tornado: diff --git a/tenacity/_asyncio.py b/tenacity/_asyncio.py index ab88d26b..9e10c072 100644 --- a/tenacity/_asyncio.py +++ b/tenacity/_asyncio.py @@ -17,7 +17,7 @@ import functools import sys -import typing +import typing as t from asyncio import sleep from tenacity import AttemptManager @@ -26,24 +26,20 @@ from tenacity import DoSleep from tenacity import RetryCallState - -WrappedFn = typing.TypeVar("WrappedFn", bound=typing.Callable[..., typing.Any]) -_RetValT = typing.TypeVar("_RetValT") +WrappedFnReturnT = t.TypeVar("WrappedFnReturnT") +WrappedFn = t.TypeVar("WrappedFn", bound=t.Callable[..., t.Awaitable[t.Any]]) class AsyncRetrying(BaseRetrying): - def __init__( - self, sleep: typing.Callable[[float], typing.Awaitable[typing.Any]] = sleep, **kwargs: typing.Any - ) -> None: + sleep: t.Callable[[float], t.Awaitable[t.Any]] + + def __init__(self, sleep: t.Callable[[float], t.Awaitable[t.Any]] = sleep, **kwargs: t.Any) -> None: super().__init__(**kwargs) self.sleep = sleep async def __call__( # type: ignore[override] - self, - fn: typing.Callable[..., typing.Awaitable[_RetValT]], - *args: typing.Any, - **kwargs: typing.Any, - ) -> _RetValT: + self, fn: WrappedFn, *args: t.Any, **kwargs: t.Any + ) -> WrappedFnReturnT: self.begin() retry_state = RetryCallState(retry_object=self, fn=fn, args=args, kwargs=kwargs) @@ -62,7 +58,7 @@ async def __call__( # type: ignore[override] else: return do # type: ignore[no-any-return] - def __iter__(self) -> typing.Generator[AttemptManager, None, None]: + def __iter__(self) -> t.Generator[AttemptManager, None, None]: raise TypeError("AsyncRetrying object is not iterable") def __aiter__(self) -> "AsyncRetrying": @@ -88,7 +84,7 @@ def wraps(self, fn: WrappedFn) -> WrappedFn: # Ensure wrapper is recognized as a coroutine function. @functools.wraps(fn) - async def async_wrapped(*args: typing.Any, **kwargs: typing.Any) -> typing.Any: + async def async_wrapped(*args: t.Any, **kwargs: t.Any) -> t.Any: return await fn(*args, **kwargs) # Preserve attributes diff --git a/tox.ini b/tox.ini index 0e4fda13..6a5c81c9 100644 --- a/tox.ini +++ b/tox.ini @@ -33,7 +33,7 @@ commands = [testenv:mypy] deps = - mypy + mypy>=1.0.0 commands = mypy tenacity