diff --git a/falcon/hooks.py b/falcon/hooks.py index 354f0d0e1..45c477d24 100644 --- a/falcon/hooks.py +++ b/falcon/hooks.py @@ -33,14 +33,15 @@ r'^on_({})(_\w+)?$'.format('|'.join(method.lower() for method in COMBINED_METHODS)) ) -SynchronousResource = t.Callable[..., t.Any] -AsynchronousResource = t.Callable[..., t.Awaitable[t.Any]] -Resource = t.Union[SynchronousResource, AsynchronousResource] +Resource = object +Responder = t.Callable +ResponderOrResource = t.Union[Responder, Resource] +Action = t.Callable def before( - action: Resource, *args: t.Any, is_async: bool = False, **kwargs: t.Any -) -> t.Callable[[Resource], Resource]: + action: Action, *args: t.Any, is_async: bool = False, **kwargs: t.Any +) -> t.Callable[[ResponderOrResource], ResponderOrResource]: """Execute the given action function *before* the responder. The `params` argument that is passed to the hook @@ -90,7 +91,7 @@ def do_something(req, resp, resource, params): *action*. """ - def _before(responder_or_resource: Resource) -> Resource: + def _before(responder_or_resource: ResponderOrResource) -> ResponderOrResource: if isinstance(responder_or_resource, type): resource = responder_or_resource @@ -100,7 +101,9 @@ def _before(responder_or_resource: Resource) -> Resource: # responder in the do_before_all closure; otherwise, they # will capture the same responder variable that is shared # between iterations of the for loop, above. - def let(responder: Resource = responder) -> None: + responder = t.cast(Responder, responder) + + def let(responder: Responder = responder) -> None: do_before_all = _wrap_with_before( responder, action, args, kwargs, is_async ) @@ -112,7 +115,7 @@ def let(responder: Resource = responder) -> None: return resource else: - responder = responder_or_resource + responder = t.cast(Responder, responder_or_resource) do_before_one = _wrap_with_before(responder, action, args, kwargs, is_async) return do_before_one @@ -121,8 +124,8 @@ def let(responder: Resource = responder) -> None: def after( - action: Resource, *args: t.Any, is_async: bool = False, **kwargs: t.Any -) -> t.Callable[[Resource], Resource]: + action: Action, *args: t.Any, is_async: bool = False, **kwargs: t.Any +) -> t.Callable[[ResponderOrResource], ResponderOrResource]: """Execute the given action function *after* the responder. Args: @@ -155,14 +158,15 @@ def after( *action*. """ - def _after(responder_or_resource: Resource) -> Resource: + def _after(responder_or_resource: ResponderOrResource) -> ResponderOrResource: if isinstance(responder_or_resource, type): - resource = responder_or_resource + resource = t.cast(Resource, responder_or_resource) for responder_name, responder in getmembers(resource, callable): if _DECORABLE_METHOD_NAME.match(responder_name): + responder = t.cast(Responder, responder) - def let(responder: Resource = responder) -> None: + def let(responder: Responder = responder) -> None: do_after_all = _wrap_with_after( responder, action, args, kwargs, is_async ) @@ -174,7 +178,7 @@ def let(responder: Resource = responder) -> None: return resource else: - responder = responder_or_resource + responder = t.cast(Responder, responder_or_resource) do_after_one = _wrap_with_after(responder, action, args, kwargs, is_async) return do_after_one @@ -188,12 +192,12 @@ def let(responder: Resource = responder) -> None: def _wrap_with_after( - responder: Resource, - action: Resource, + responder: Responder, + action: Action, action_args: t.Any, action_kwargs: t.Any, is_async: bool, -) -> Resource: +) -> Responder: """Execute the given action function after a responder method. Args: @@ -222,7 +226,7 @@ def _wrap_with_after( @wraps(responder) async def do_after( - self: Resource, + self: ResponderOrResource, req: asgi.Request, resp: asgi.Response, *args: t.Any, @@ -239,7 +243,7 @@ async def do_after( @wraps(responder) def do_after( - self: Resource, + self: ResponderOrResource, req: wsgi.Request, resp: wsgi.Response, *args: t.Any, @@ -255,8 +259,8 @@ def do_after( def _wrap_with_before( - responder: Resource, - action: Resource, + responder: Responder, + action: Action, action_args: t.Tuple[t.Any, ...], action_kwargs: t.Dict[str, t.Any], is_async: bool, @@ -289,7 +293,7 @@ def _wrap_with_before( @wraps(responder) async def do_before( - self: Resource, + self: ResponderOrResource, req: asgi.Request, resp: asgi.Response, *args: t.Any, @@ -306,7 +310,7 @@ async def do_before( @wraps(responder) def do_before( - self: Resource, + self: ResponderOrResource, req: wsgi.Request, resp: wsgi.Response, *args: t.Any,