diff --git a/setup.py b/setup.py index 3b8d32e16..eac568110 100644 --- a/setup.py +++ b/setup.py @@ -39,7 +39,7 @@ def get_long_description(): include_package_data=True, install_requires=[ "anyio>=3.0.0,<4", - "typing_extensions; python_version < '3.8'", + "typing_extensions; python_version < '3.10'", "contextlib2 >= 21.6.0; python_version < '3.7'", ], extras_require={ diff --git a/starlette/background.py b/starlette/background.py index 1160baeed..14a4e9e1a 100644 --- a/starlette/background.py +++ b/starlette/background.py @@ -1,12 +1,20 @@ import asyncio +import sys import typing +if sys.version_info >= (3, 10): # pragma: no cover + from typing import ParamSpec +else: # pragma: no cover + from typing_extensions import ParamSpec + from starlette.concurrency import run_in_threadpool +P = ParamSpec("P") + class BackgroundTask: def __init__( - self, func: typing.Callable, *args: typing.Any, **kwargs: typing.Any + self, func: typing.Callable[P, typing.Any], *args: P.args, **kwargs: P.kwargs ) -> None: self.func = func self.args = args @@ -25,7 +33,7 @@ def __init__(self, tasks: typing.Sequence[BackgroundTask] = None): self.tasks = list(tasks) if tasks else [] def add_task( - self, func: typing.Callable, *args: typing.Any, **kwargs: typing.Any + self, func: typing.Callable[P, typing.Any], *args: P.args, **kwargs: P.kwargs ) -> None: task = BackgroundTask(func, *args, **kwargs) self.tasks.append(task) diff --git a/starlette/concurrency.py b/starlette/concurrency.py index e89d1e047..78602077a 100644 --- a/starlette/concurrency.py +++ b/starlette/concurrency.py @@ -1,9 +1,14 @@ import functools +import sys import typing -from typing import Any, AsyncGenerator, Iterator import anyio +if sys.version_info >= (3, 10): # pragma: no cover + from typing import ParamSpec +else: # pragma: no cover + from typing_extensions import ParamSpec + try: import contextvars # Python 3.7+ only or via contextvars backport. except ImportError: # pragma: no cover @@ -11,6 +16,7 @@ T = typing.TypeVar("T") +P = ParamSpec("P") async def run_until_first_complete(*args: typing.Tuple[typing.Callable, dict]) -> None: @@ -25,14 +31,14 @@ async def run(func: typing.Callable[[], typing.Coroutine]) -> None: async def run_in_threadpool( - func: typing.Callable[..., T], *args: typing.Any, **kwargs: typing.Any + func: typing.Callable[P, T], *args: P.args, **kwargs: P.kwargs ) -> T: if contextvars is not None: # pragma: no cover # Ensure we run in the same context child = functools.partial(func, *args, **kwargs) context = contextvars.copy_context() - func = context.run - args = (child,) + func = context.run # type: ignore[assignment] + args = (child,) # type: ignore[assignment] elif kwargs: # pragma: no cover # run_sync doesn't accept 'kwargs', so bind them in here func = functools.partial(func, **kwargs) @@ -43,7 +49,7 @@ class _StopIteration(Exception): pass -def _next(iterator: Iterator) -> Any: +def _next(iterator: typing.Iterator[T]) -> T: # We can't raise `StopIteration` from within the threadpool iterator # and catch it outside that context, so we coerce them into a different # exception type. @@ -53,7 +59,9 @@ def _next(iterator: Iterator) -> Any: raise _StopIteration -async def iterate_in_threadpool(iterator: Iterator) -> AsyncGenerator: +async def iterate_in_threadpool( + iterator: typing.Iterator[T], +) -> typing.AsyncIterator[T]: while True: try: yield await anyio.to_thread.run_sync(_next, iterator) diff --git a/starlette/endpoints.py b/starlette/endpoints.py index e0b7be8de..e27e4fe49 100644 --- a/starlette/endpoints.py +++ b/starlette/endpoints.py @@ -29,7 +29,9 @@ async def dispatch(self) -> None: else request.method.lower() ) - handler = getattr(self, handler_name, self.method_not_allowed) + handler: typing.Callable[[Request], typing.Any] = getattr( + self, handler_name, self.method_not_allowed + ) is_async = asyncio.iscoroutinefunction(handler) if is_async: response = await handler(request)