diff --git a/.coveragerc b/.coveragerc index 2c5041d7e..bce877f6a 100644 --- a/.coveragerc +++ b/.coveragerc @@ -7,9 +7,10 @@ parallel = True [report] show_missing = True -exclude_lines = +# https://coverage.readthedocs.io/en/latest/excluding.html#advanced-exclusion +exclude_also = if TYPE_CHECKING: - if not TYPE_CHECKING: pragma: nocover - pragma: no cover pragma: no py39,py310 cover + @overload + class .*\bProtocol\): diff --git a/falcon/hooks.py b/falcon/hooks.py index 5ca50aefb..fb377e6c8 100644 --- a/falcon/hooks.py +++ b/falcon/hooks.py @@ -20,29 +20,98 @@ from inspect import getmembers from inspect import iscoroutinefunction import re -import typing as t +from typing import ( + Any, + Awaitable, + Callable, + cast, + Dict, + List, + Protocol, + Tuple, + TYPE_CHECKING, + TypeVar, + Union, +) from falcon.constants import COMBINED_METHODS from falcon.util.misc import get_argnames from falcon.util.sync import _wrap_non_coroutine_unsafe -if t.TYPE_CHECKING: # pragma: no cover +if TYPE_CHECKING: import falcon as wsgi from falcon import asgi + from falcon.typing import AsyncResponderMethod + from falcon.typing import Resource + from falcon.typing import Responder + from falcon.typing import SyncResponderMethod + + +# TODO: if is_async is removed these protocol would no longer be needed, since +# ParamSpec could be used together with Concatenate to use a simple Callable +# to type the before and after functions. This approach was prototyped in +# https://github.com/falconry/falcon/pull/2234 +class SyncBeforeFn(Protocol): + def __call__( + self, + req: wsgi.Request, + resp: wsgi.Response, + resource: Resource, + params: Dict[str, Any], + *args: Any, + **kwargs: Any, + ) -> None: ... + + +class AsyncBeforeFn(Protocol): + def __call__( + self, + req: asgi.Request, + resp: asgi.Response, + resource: Resource, + params: Dict[str, Any], + *args: Any, + **kwargs: Any, + ) -> Awaitable[None]: ... + + +BeforeFn = Union[SyncBeforeFn, AsyncBeforeFn] + + +class SyncAfterFn(Protocol): + def __call__( + self, + req: wsgi.Request, + resp: wsgi.Response, + resource: Resource, + *args: Any, + **kwargs: Any, + ) -> None: ... + + +class AsyncAfterFn(Protocol): + def __call__( + self, + req: asgi.Request, + resp: asgi.Response, + resource: Resource, + *args: Any, + **kwargs: Any, + ) -> Awaitable[None]: ... + + +AfterFn = Union[SyncAfterFn, AsyncAfterFn] +_R = TypeVar('_R', bound=Union['Responder', 'Resource']) + _DECORABLE_METHOD_NAME = re.compile( r'^on_({})(_\w+)?$'.format('|'.join(method.lower() for method in COMBINED_METHODS)) ) -Resource = object -Responder = t.Callable -ResponderOrResource = t.Union[Responder, Resource] -Action = t.Callable - def before( - action: Action, *args: t.Any, is_async: bool = False, **kwargs: t.Any -) -> t.Callable[[ResponderOrResource], ResponderOrResource]: + action: BeforeFn, *args: Any, is_async: bool = False, **kwargs: Any +) -> Callable[[_R], _R]: """Execute the given action function *before* the responder. The `params` argument that is passed to the hook @@ -92,41 +161,33 @@ def do_something(req, resp, resource, params): *action*. """ - def _before(responder_or_resource: ResponderOrResource) -> ResponderOrResource: + def _before(responder_or_resource: _R) -> _R: if isinstance(responder_or_resource, type): - resource = responder_or_resource - - for responder_name, responder in getmembers(resource, callable): + for responder_name, responder in getmembers( + responder_or_resource, callable + ): if _DECORABLE_METHOD_NAME.match(responder_name): - # This pattern is necessary to capture the current value of - # responder in the do_before_all closure; otherwise, they - # will capture the same responder variable that is shared - # between iterations of the for loop, above. - responder = t.cast(Responder, responder) - - def let(responder: Responder = responder) -> None: - do_before_all = _wrap_with_before( - responder, action, args, kwargs, is_async - ) + responder = cast('Responder', responder) + do_before_all = _wrap_with_before( + responder, action, args, kwargs, is_async + ) - setattr(resource, responder_name, do_before_all) + setattr(responder_or_resource, responder_name, do_before_all) - let() - - return resource + return cast(_R, responder_or_resource) else: - responder = t.cast(Responder, responder_or_resource) + responder = cast('Responder', responder_or_resource) do_before_one = _wrap_with_before(responder, action, args, kwargs, is_async) - return do_before_one + return cast(_R, do_before_one) return _before def after( - action: Action, *args: t.Any, is_async: bool = False, **kwargs: t.Any -) -> t.Callable[[ResponderOrResource], ResponderOrResource]: + action: AfterFn, *args: Any, is_async: bool = False, **kwargs: Any +) -> Callable[[_R], _R]: """Execute the given action function *after* the responder. Args: @@ -159,30 +220,26 @@ def after( *action*. """ - def _after(responder_or_resource: ResponderOrResource) -> ResponderOrResource: + def _after(responder_or_resource: _R) -> _R: if isinstance(responder_or_resource, type): - resource = t.cast(Resource, responder_or_resource) - - for responder_name, responder in getmembers(resource, callable): + for responder_name, responder in getmembers( + responder_or_resource, callable + ): if _DECORABLE_METHOD_NAME.match(responder_name): - responder = t.cast(Responder, responder) - - def let(responder: Responder = responder) -> None: - do_after_all = _wrap_with_after( - responder, action, args, kwargs, is_async - ) + responder = cast('Responder', responder) + do_after_all = _wrap_with_after( + responder, action, args, kwargs, is_async + ) - setattr(resource, responder_name, do_after_all) + setattr(responder_or_resource, responder_name, do_after_all) - let() - - return resource + return cast(_R, responder_or_resource) else: - responder = t.cast(Responder, responder_or_resource) + responder = cast('Responder', responder_or_resource) do_after_one = _wrap_with_after(responder, action, args, kwargs, is_async) - return do_after_one + return cast(_R, do_after_one) return _after @@ -194,9 +251,9 @@ def let(responder: Responder = responder) -> None: def _wrap_with_after( responder: Responder, - action: Action, - action_args: t.Any, - action_kwargs: t.Any, + action: AfterFn, + action_args: Any, + action_kwargs: Any, is_async: bool, ) -> Responder: """Execute the given action function after a responder method. @@ -215,57 +272,62 @@ def _wrap_with_after( responder_argnames = get_argnames(responder) extra_argnames = responder_argnames[2:] # Skip req, resp + do_after_responder: Responder if is_async or iscoroutinefunction(responder): # NOTE(kgriffs): I manually verified that the implicit "else" branch # is actually covered, but coverage isn't tracking it for # some reason. if not is_async: # pragma: nocover - async_action = _wrap_non_coroutine_unsafe(action) + async_action = cast('AsyncAfterFn', _wrap_non_coroutine_unsafe(action)) else: - async_action = action + async_action = cast('AsyncAfterFn', action) + async_responder = cast('AsyncResponderMethod', responder) - @wraps(responder) + @wraps(async_responder) async def do_after( - self: ResponderOrResource, + self: Resource, req: asgi.Request, resp: asgi.Response, - *args: t.Any, - **kwargs: t.Any, + *args: Any, + **kwargs: Any, ) -> None: if args: _merge_responder_args(args, kwargs, extra_argnames) - await responder(self, req, resp, **kwargs) - assert async_action + await async_responder(self, req, resp, **kwargs) await async_action(req, resp, self, *action_args, **action_kwargs) + do_after_responder = cast('AsyncResponderMethod', do_after) else: + sync_action = cast('SyncAfterFn', action) + sync_responder = cast('SyncResponderMethod', responder) - @wraps(responder) + @wraps(sync_responder) def do_after( - self: ResponderOrResource, + self: Resource, req: wsgi.Request, resp: wsgi.Response, - *args: t.Any, - **kwargs: t.Any, + *args: Any, + **kwargs: Any, ) -> None: if args: _merge_responder_args(args, kwargs, extra_argnames) - responder(self, req, resp, **kwargs) - action(req, resp, self, *action_args, **action_kwargs) + sync_responder(self, req, resp, **kwargs) + sync_action(req, resp, self, *action_args, **action_kwargs) - return do_after + do_after_responder = cast('SyncResponderMethod', do_after) + return do_after_responder def _wrap_with_before( responder: Responder, - action: Action, - action_args: t.Tuple[t.Any, ...], - action_kwargs: t.Dict[str, t.Any], + action: BeforeFn, + action_args: Tuple[Any, ...], + action_kwargs: Dict[str, Any], is_async: bool, -) -> t.Union[t.Callable[..., t.Awaitable[None]], t.Callable[..., None]]: +) -> Responder: """Execute the given action function before a responder method. Args: @@ -282,52 +344,57 @@ def _wrap_with_before( responder_argnames = get_argnames(responder) extra_argnames = responder_argnames[2:] # Skip req, resp + do_before_responder: Responder if is_async or iscoroutinefunction(responder): # NOTE(kgriffs): I manually verified that the implicit "else" branch # is actually covered, but coverage isn't tracking it for # some reason. if not is_async: # pragma: nocover - async_action = _wrap_non_coroutine_unsafe(action) + async_action = cast('AsyncBeforeFn', _wrap_non_coroutine_unsafe(action)) else: - async_action = action + async_action = cast('AsyncBeforeFn', action) + async_responder = cast('AsyncResponderMethod', responder) - @wraps(responder) + @wraps(async_responder) async def do_before( - self: ResponderOrResource, + self: Resource, req: asgi.Request, resp: asgi.Response, - *args: t.Any, - **kwargs: t.Any, + *args: Any, + **kwargs: Any, ) -> None: if args: _merge_responder_args(args, kwargs, extra_argnames) - assert async_action await async_action(req, resp, self, kwargs, *action_args, **action_kwargs) - await responder(self, req, resp, **kwargs) + await async_responder(self, req, resp, **kwargs) + do_before_responder = cast('AsyncResponderMethod', do_before) else: + sync_action = cast('SyncBeforeFn', action) + sync_responder = cast('SyncResponderMethod', responder) - @wraps(responder) + @wraps(sync_responder) def do_before( - self: ResponderOrResource, + self: Resource, req: wsgi.Request, resp: wsgi.Response, - *args: t.Any, - **kwargs: t.Any, + *args: Any, + **kwargs: Any, ) -> None: if args: _merge_responder_args(args, kwargs, extra_argnames) - action(req, resp, self, kwargs, *action_args, **action_kwargs) - responder(self, req, resp, **kwargs) + sync_action(req, resp, self, kwargs, *action_args, **action_kwargs) + sync_responder(self, req, resp, **kwargs) - return do_before + do_before_responder = cast('SyncResponderMethod', do_before) + return do_before_responder def _merge_responder_args( - args: t.Tuple[t.Any, ...], kwargs: t.Dict[str, t.Any], argnames: t.List[str] + args: Tuple[Any, ...], kwargs: Dict[str, Any], argnames: List[str] ) -> None: """Merge responder args into kwargs. diff --git a/falcon/testing/resource.py b/falcon/testing/resource.py index c20854a3e..14e0854c3 100644 --- a/falcon/testing/resource.py +++ b/falcon/testing/resource.py @@ -23,12 +23,26 @@ resource = testing.SimpleTestResource() """ +from __future__ import annotations + from json import dumps as json_dumps +import typing import falcon +if typing.TYPE_CHECKING: # pragma: no cover + from falcon import app as wsgi + from falcon.asgi import app as asgi + from falcon.typing import HeaderList + from falcon.typing import Resource + -def capture_responder_args(req, resp, resource, params): +def capture_responder_args( + req: wsgi.Request, + resp: wsgi.Response, + resource: object, + params: typing.Mapping[str, str], +) -> None: """Before hook for capturing responder arguments. Adds the following attributes to the hooked responder's resource @@ -49,41 +63,53 @@ def capture_responder_args(req, resp, resource, params): * `capture-req-media` """ - resource.captured_req = req - resource.captured_resp = resp - resource.captured_kwargs = params + simple_resource = typing.cast(SimpleTestResource, resource) + simple_resource.captured_req = req + simple_resource.captured_resp = resp + simple_resource.captured_kwargs = params - resource.captured_req_media = None - resource.captured_req_body = None + simple_resource.captured_req_media = None + simple_resource.captured_req_body = None num_bytes = req.get_header('capture-req-body-bytes') if num_bytes: - resource.captured_req_body = req.stream.read(int(num_bytes)) + simple_resource.captured_req_body = req.stream.read(int(num_bytes)) elif req.get_header('capture-req-media'): - resource.captured_req_media = req.get_media() + simple_resource.captured_req_media = req.get_media() -async def capture_responder_args_async(req, resp, resource, params): +async def capture_responder_args_async( + req: asgi.Request, + resp: asgi.Response, + resource: Resource, + params: typing.Mapping[str, str], +) -> None: """Before hook for capturing responder arguments. An asynchronous version of :meth:`~falcon.testing.capture_responder_args`. """ - resource.captured_req = req - resource.captured_resp = resp - resource.captured_kwargs = params + simple_resource = typing.cast(SimpleTestResource, resource) + simple_resource.captured_req = req + simple_resource.captured_resp = resp + simple_resource.captured_kwargs = params - resource.captured_req_media = None - resource.captured_req_body = None + simple_resource.captured_req_media = None + simple_resource.captured_req_body = None num_bytes = req.get_header('capture-req-body-bytes') if num_bytes: - resource.captured_req_body = await req.stream.read(int(num_bytes)) + simple_resource.captured_req_body = await req.stream.read(int(num_bytes)) elif req.get_header('capture-req-media'): - resource.captured_req_media = await req.get_media() + simple_resource.captured_req_media = await req.get_media() -def set_resp_defaults(req, resp, resource, params): +def set_resp_defaults( + req: wsgi.Request, + resp: wsgi.Response, + resource: Resource, + params: typing.Mapping[str, str], +) -> None: """Before hook for setting default response properties. This hook simply sets the the response body, status, @@ -92,18 +118,23 @@ def set_resp_defaults(req, resp, resource, params): that are assumed to be defined on the resource object. """ + simple_resource = typing.cast(SimpleTestResource, resource) + if simple_resource._default_status is not None: + resp.status = simple_resource._default_status - if resource._default_status is not None: - resp.status = resource._default_status - - if resource._default_body is not None: - resp.text = resource._default_body + if simple_resource._default_body is not None: + resp.text = simple_resource._default_body - if resource._default_headers is not None: - resp.set_headers(resource._default_headers) + if simple_resource._default_headers is not None: + resp.set_headers(simple_resource._default_headers) -async def set_resp_defaults_async(req, resp, resource, params): +async def set_resp_defaults_async( + req: asgi.Request, + resp: asgi.Response, + resource: Resource, + params: typing.Mapping[str, str], +) -> None: """Wrap :meth:`~falcon.testing.set_resp_defaults` in a coroutine.""" set_resp_defaults(req, resp, resource, params) @@ -145,7 +176,13 @@ class SimpleTestResource: responder methods. """ - def __init__(self, status=None, body=None, json=None, headers=None): + def __init__( + self, + status: typing.Optional[str] = None, + body: typing.Optional[str] = None, + json: typing.Optional[dict[str, str]] = None, + headers: typing.Optional[HeaderList] = None, + ): self._default_status = status self._default_headers = headers @@ -154,14 +191,22 @@ def __init__(self, status=None, body=None, json=None, headers=None): msg = 'Either json or body may be specified, but not both' raise ValueError(msg) - self._default_body = json_dumps(json, ensure_ascii=False) + self._default_body: typing.Optional[str] = json_dumps( + json, ensure_ascii=False + ) else: self._default_body = body - self.captured_req = None - self.captured_resp = None - self.captured_kwargs = None + self.captured_req: typing.Optional[typing.Union[wsgi.Request, asgi.Request]] = ( + None + ) + self.captured_resp: typing.Optional[ + typing.Union[wsgi.Response, asgi.Response] + ] = None + self.captured_kwargs: typing.Optional[typing.Any] = None + self.captured_req_media: typing.Optional[typing.Any] = None + self.captured_req_body: typing.Optional[str] = None @property def called(self): @@ -169,12 +214,16 @@ def called(self): @falcon.before(capture_responder_args) @falcon.before(set_resp_defaults) - def on_get(self, req, resp, **kwargs): + def on_get( + self, req: wsgi.Request, resp: wsgi.Response, **kwargs: typing.Any + ) -> None: pass @falcon.before(capture_responder_args) @falcon.before(set_resp_defaults) - def on_post(self, req, resp, **kwargs): + def on_post( + self, req: wsgi.Request, resp: wsgi.Response, **kwargs: typing.Any + ) -> None: pass @@ -218,10 +267,14 @@ class SimpleTestResourceAsync(SimpleTestResource): @falcon.before(capture_responder_args_async) @falcon.before(set_resp_defaults_async) - async def on_get(self, req, resp, **kwargs): + async def on_get( # type: ignore[override] + self, req: asgi.Request, resp: asgi.Response, **kwargs: typing.Any + ) -> None: pass @falcon.before(capture_responder_args_async) @falcon.before(set_resp_defaults_async) - async def on_post(self, req, resp, **kwargs): + async def on_post( # type: ignore[override] + self, req: asgi.Request, resp: asgi.Response, **kwargs: typing.Any + ) -> None: pass diff --git a/falcon/typing.py b/falcon/typing.py index 4049d9ab8..6bb5315fc 100644 --- a/falcon/typing.py +++ b/falcon/typing.py @@ -22,12 +22,14 @@ Dict, List, Pattern, + Protocol, Tuple, TYPE_CHECKING, Union, ) if TYPE_CHECKING: + from falcon import asgi from falcon.request import Request from falcon.response import Response @@ -62,3 +64,30 @@ Headers = Dict[str, str] HeaderList = Union[Headers, List[Tuple[str, str]]] ResponseStatus = Union[http.HTTPStatus, str, int] + +Resource = object + + +class SyncResponderMethod(Protocol): + def __call__( + self, + resource: Resource, + req: Request, + resp: Response, + *args: Any, + **kwargs: Any, + ) -> None: ... + + +class AsyncResponderMethod(Protocol): + async def __call__( + self, + resource: Resource, + req: asgi.Request, + resp: asgi.Response, + *args: Any, + **kwargs: Any, + ) -> None: ... + + +Responder = Union[SyncResponderMethod, AsyncResponderMethod] diff --git a/falcon/util/uri.py b/falcon/util/uri.py index f9a772785..a2a324f02 100644 --- a/falcon/util/uri.py +++ b/falcon/util/uri.py @@ -554,7 +554,7 @@ def unquote_string(quoted: str) -> str: # TODO(vytas): Restructure this in favour of a cleaner way to hoist the pure # Cython functions into this module. -if not TYPE_CHECKING: +if not TYPE_CHECKING: # pragma: nocover if _cy_uri is not None: decode = _cy_uri.decode # NOQA parse_query_string = _cy_uri.parse_query_string # NOQA diff --git a/tests/test_after_hooks.py b/tests/test_after_hooks.py index 442788373..f6e53769f 100644 --- a/tests/test_after_hooks.py +++ b/tests/test_after_hooks.py @@ -1,10 +1,13 @@ import functools import json +import typing import pytest import falcon +from falcon import app as wsgi from falcon import testing +from falcon.typing import Resource # -------------------------------------------------------------------- # Fixtures @@ -340,8 +343,9 @@ class ResourceAwareGameHook: VALUES = ('rock', 'scissors', 'paper') @classmethod - def __call__(cls, req, resp, resource): + def __call__(cls, req: wsgi.Request, resp: wsgi.Response, resource: Resource): assert resource + resource = typing.cast(HandGame, resource) assert resource.seed in cls.VALUES assert resp.text == 'Responder called.'