Skip to content

Commit

Permalink
Improve responders typing
Browse files Browse the repository at this point in the history
  • Loading branch information
copalco committed Nov 13, 2023
1 parent b2b94d8 commit 53df7c0
Show file tree
Hide file tree
Showing 4 changed files with 195 additions and 93 deletions.
158 changes: 102 additions & 56 deletions falcon/hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,29 +19,64 @@
from inspect import getmembers
from inspect import iscoroutinefunction
import re
import typing as t
import typing

from falcon.constants import COMBINED_METHODS
from falcon.util.misc import get_argnames
from falcon.util.sync import _wrap_non_coroutine_unsafe

if t.TYPE_CHECKING: # pragma: no cover
if typing.TYPE_CHECKING: # pragma: no cover
import falcon as wsgi
from falcon import asgi

ResponderParams = typing.ParamSpec('ResponderParams')

class SyncResponder(typing.Protocol[ResponderParams]):
def __call__(
self,
responder: SyncResponderOrResource,
req: wsgi.Request,
resp: wsgi.Response,
*args: ResponderParams.args,
**kwargs: ResponderParams.kwargs,
) -> None:
...

class AsyncResponder(typing.Protocol):
async def __call__(
self,
responder: AsyncResponderOrResource,
req: asgi.Request,
resp: asgi.Response,
*args: ResponderParams.args,
**kwargs: ResponderParams.kwargs,
) -> None:
...

Responder = typing.Union[SyncResponder, AsyncResponder]
Resource = object
SyncResponderOrResource = typing.Union[SyncResponder, Resource]
AsyncResponderOrResource = typing.Union[AsyncResponder, Resource]
ResponderOrResource = typing.Union[Responder, Resource]
SynchronousAction = typing.Callable[..., typing.Any]
AsynchronousAction = typing.Callable[..., typing.Awaitable[typing.Any]]
Action = typing.Union[SynchronousAction, AsynchronousAction]
else:
Resource = object
SynchronousAction = typing.Callable[..., typing.Any]
AsynchronousAction = typing.Callable[..., typing.Awaitable[typing.Any]]
SyncResponder = typing.Callable
AsyncResponder = typing.Awaitable
Responder = typing.Union[SyncResponder, AsyncResponder]

_DECORABLE_METHOD_NAME = re.compile(
r'^on_({})(_\w+)?$'.format('|'.join(method.lower() for method in COMBINED_METHODS))
)

Resource = object
Responder = t.Callable
ResponderOrResource = t.Union[Responder, Resource]
Action = t.Callable


def before(
action: Action, *args: t.Any, is_async: bool = False, **kwargs: t.Any
) -> t.Callable[[ResponderOrResource], ResponderOrResource]:
action: Action, *args: typing.Any, is_async: bool = False, **kwargs: typing.Any
) -> typing.Callable[[ResponderOrResource], ResponderOrResource]:
"""Execute the given action function *before* the responder.
The `params` argument that is passed to the hook
Expand Down Expand Up @@ -93,29 +128,28 @@ def do_something(req, resp, resource, params):

def _before(responder_or_resource: ResponderOrResource) -> ResponderOrResource:
if isinstance(responder_or_resource, type):
resource = responder_or_resource

for responder_name, responder in getmembers(resource, callable):
for responder_name, responder in getmembers(
responder_or_resource, callable
):
if _DECORABLE_METHOD_NAME.match(responder_name):
# This pattern is necessary to capture the current value of
# responder in the do_before_all closure; otherwise, they
# will capture the same responder variable that is shared
# between iterations of the for loop, above.
responder = t.cast(Responder, responder)

def let(responder: Responder = responder) -> None:
def let(responder: typing.Callable = responder) -> None:
do_before_all = _wrap_with_before(
responder, action, args, kwargs, is_async
)

setattr(resource, responder_name, do_before_all)
setattr(responder_or_resource, responder_name, do_before_all)

let()

return resource
return responder_or_resource

else:
responder = t.cast(Responder, responder_or_resource)
responder = typing.cast(Responder, responder_or_resource)
do_before_one = _wrap_with_before(responder, action, args, kwargs, is_async)

return do_before_one
Expand All @@ -124,8 +158,8 @@ def let(responder: Responder = responder) -> None:


def after(
action: Action, *args: t.Any, is_async: bool = False, **kwargs: t.Any
) -> t.Callable[[ResponderOrResource], ResponderOrResource]:
action: Action, *args: typing.Any, is_async: bool = False, **kwargs: typing.Any
) -> typing.Callable[[ResponderOrResource], ResponderOrResource]:
"""Execute the given action function *after* the responder.
Args:
Expand Down Expand Up @@ -160,25 +194,24 @@ def after(

def _after(responder_or_resource: ResponderOrResource) -> ResponderOrResource:
if isinstance(responder_or_resource, type):
resource = t.cast(Resource, responder_or_resource)

for responder_name, responder in getmembers(resource, callable):
for responder_name, responder in getmembers(
responder_or_resource, callable
):
if _DECORABLE_METHOD_NAME.match(responder_name):
responder = t.cast(Responder, responder)

def let(responder: Responder = responder) -> None:
def let(responder: Responder | typing.Callable = responder) -> None:
do_after_all = _wrap_with_after(
responder, action, args, kwargs, is_async
)

setattr(resource, responder_name, do_after_all)
setattr(responder_or_resource, responder_name, do_after_all)

let()

return resource
return responder_or_resource

else:
responder = t.cast(Responder, responder_or_resource)
responder = typing.cast(Responder, responder_or_resource)
do_after_one = _wrap_with_after(responder, action, args, kwargs, is_async)

return do_after_one
Expand All @@ -194,8 +227,8 @@ def let(responder: Responder = responder) -> None:
def _wrap_with_after(
responder: Responder,
action: Action,
action_args: t.Any,
action_kwargs: t.Any,
action_args: typing.Any,
action_kwargs: typing.Any,
is_async: bool,
) -> Responder:
"""Execute the given action function after a responder method.
Expand All @@ -214,57 +247,62 @@ def _wrap_with_after(

responder_argnames = get_argnames(responder)
extra_argnames = responder_argnames[2:] # Skip req, resp
do_after_responder: Responder

if is_async or iscoroutinefunction(responder):
# NOTE(kgriffs): I manually verified that the implicit "else" branch
# is actually covered, but coverage isn't tracking it for
# some reason.
if not is_async: # pragma: nocover
async_action = _wrap_non_coroutine_unsafe(action)
async_action = typing.cast(
AsynchronousAction, _wrap_non_coroutine_unsafe(action)
)
else:
async_action = action
async_action = typing.cast(AsynchronousAction, action)
async_responder = typing.cast(AsyncResponder, responder)

@wraps(responder)
async def do_after(
self: ResponderOrResource,
self: AsyncResponderOrResource,
req: asgi.Request,
resp: asgi.Response,
*args: t.Any,
**kwargs: t.Any,
*args: typing.Any,
**kwargs: typing.Any,
) -> None:
if args:
_merge_responder_args(args, kwargs, extra_argnames)

await responder(self, req, resp, **kwargs)
assert async_action
await async_responder(self, req, resp, **kwargs)
await async_action(req, resp, self, *action_args, **action_kwargs)

Check warning on line 275 in falcon/hooks.py

View check run for this annotation

Codecov / codecov/patch

falcon/hooks.py#L274-L275

Added lines #L274 - L275 were not covered by tests

do_after_responder = typing.cast(AsyncResponder, do_after)
else:
responder = typing.cast(SyncResponder, responder)

@wraps(responder)
def do_after(
self: ResponderOrResource,
self: SyncResponderOrResource,
req: wsgi.Request,
resp: wsgi.Response,
*args: t.Any,
**kwargs: t.Any,
*args: typing.Any,
**kwargs: typing.Any,
) -> None:
if args:
_merge_responder_args(args, kwargs, extra_argnames)

responder(self, req, resp, **kwargs)
action(req, resp, self, *action_args, **action_kwargs)

return do_after
do_after_responder = typing.cast(SyncResponder, do_after)
return do_after_responder


def _wrap_with_before(
responder: Responder,
action: Action,
action_args: t.Tuple[t.Any, ...],
action_kwargs: t.Dict[str, t.Any],
action_args: typing.Tuple[typing.Any, ...],
action_kwargs: typing.Dict[str, typing.Any],
is_async: bool,
) -> t.Union[t.Callable[..., t.Awaitable[None]], t.Callable[..., None]]:
) -> Responder:
"""Execute the given action function before a responder method.
Args:
Expand All @@ -281,52 +319,60 @@ def _wrap_with_before(

responder_argnames = get_argnames(responder)
extra_argnames = responder_argnames[2:] # Skip req, resp
do_before_responder: Responder

if is_async or iscoroutinefunction(responder):
# NOTE(kgriffs): I manually verified that the implicit "else" branch
# is actually covered, but coverage isn't tracking it for
# some reason.
if not is_async: # pragma: nocover
async_action = _wrap_non_coroutine_unsafe(action)
async_action = typing.cast(
AsynchronousAction, _wrap_non_coroutine_unsafe(action)
)
else:
async_action = action
async_action = typing.cast(AsynchronousAction, action)

Check warning on line 333 in falcon/hooks.py

View check run for this annotation

Codecov / codecov/patch

falcon/hooks.py#L333

Added line #L333 was not covered by tests
async_responder = typing.cast(AsyncResponder, responder)

@wraps(responder)
async def do_before(
self: ResponderOrResource,
self: AsyncResponderOrResource,
req: asgi.Request,
resp: asgi.Response,
*args: t.Any,
**kwargs: t.Any,
*args: typing.Any,
**kwargs: typing.Any,
) -> None:
if args:
_merge_responder_args(args, kwargs, extra_argnames)

assert async_action
await async_action(req, resp, self, kwargs, *action_args, **action_kwargs)
await responder(self, req, resp, **kwargs)
await async_responder(self, req, resp, **kwargs)

Check warning on line 348 in falcon/hooks.py

View check run for this annotation

Codecov / codecov/patch

falcon/hooks.py#L347-L348

Added lines #L347 - L348 were not covered by tests

do_before_responder = typing.cast(AsyncResponder, do_before)
else:
responder = typing.cast(SyncResponder, responder)

@wraps(responder)
def do_before(
self: ResponderOrResource,
self: SyncResponderOrResource,
req: wsgi.Request,
resp: wsgi.Response,
*args: t.Any,
**kwargs: t.Any,
*args: typing.Any,
**kwargs: typing.Any,
) -> None:
if args:
_merge_responder_args(args, kwargs, extra_argnames)

action(req, resp, self, kwargs, *action_args, **action_kwargs)
responder(self, req, resp, **kwargs)

return do_before
do_before_responder = typing.cast(SyncResponder, do_before)
return do_before_responder


def _merge_responder_args(
args: t.Tuple[t.Any, ...], kwargs: t.Dict[str, t.Any], argnames: t.List[str]
args: typing.Tuple[typing.Any, ...],
kwargs: typing.Dict[str, typing.Any],
argnames: typing.List[str],
) -> None:
"""Merge responder args into kwargs.
Expand Down

0 comments on commit 53df7c0

Please sign in to comment.