Skip to content

Commit

Permalink
Fix decorator typing
Browse files Browse the repository at this point in the history
  • Loading branch information
jace committed Apr 17, 2024
1 parent 1ca45d3 commit 8b8300c
Show file tree
Hide file tree
Showing 7 changed files with 174 additions and 82 deletions.
6 changes: 3 additions & 3 deletions src/coaster/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

from collections.abc import Sequence
from threading import Lock
from typing import Any, Generic, NoReturn, TypeVar, cast
from typing import Any, NoReturn, TypeVar, cast

from flask import Flask, current_app, g
from flask.globals import request_ctx
Expand Down Expand Up @@ -269,13 +269,13 @@ def init_app(app: Flask) -> None:
_CurrentAuthType_co = TypeVar('_CurrentAuthType_co', bound=CurrentAuth, covariant=True)


class GetCurrentAuth(Generic[_CurrentAuthType_co]):
class GetCurrentAuth:
"""Helper for :attr:`current_auth` proxy to use a :class:`CurrentAuth` subclass."""

def __init__(self, cls: type[_CurrentAuthType_co]) -> None:
self.cls = cls

def __call__(self) -> _CurrentAuthType_co:
def __call__(self) -> CurrentAuth:
"""Provide :attr:`current_auth` for the request context."""
# 1. Do we have a request context?
if request_ctx:
Expand Down
63 changes: 58 additions & 5 deletions src/coaster/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,12 @@

from __future__ import annotations

from typing import Any, Callable, Protocol, TypeVar, overload
from typing_extensions import ParamSpec, Self
from collections.abc import Awaitable
from typing import Any, Callable, Optional, Protocol, TypeVar, Union, overload
from typing_extensions import ParamSpec, Self, TypeAlias

WrappedFunc = TypeVar('WrappedFunc', bound=Callable)
ReturnDecorator = Callable[[WrappedFunc], WrappedFunc]
ReturnDecorator: TypeAlias = Callable[[WrappedFunc], WrappedFunc]

_P = ParamSpec('_P')
_T = TypeVar('_T')
Expand All @@ -29,11 +30,63 @@ def __call__(
class Method(Protocol[_P, _R_co]):
"""Protocol for an instance method."""

__name__: str

# pylint: disable=no-self-argument
def __call__(__self, self: Any, *args: _P.args, **kwargs: _P.kwargs) -> _R_co: ...

@overload
def __get__(self, obj: None, cls: type[_T]) -> Self: ...
def __get__(self, obj: None, cls: type[_T_contra]) -> Self: ...

@overload
def __get__(
self, obj: _T_contra, cls: type[_T_contra]
) -> BoundMethod[_T_contra, _P, _R_co]: ...

def __get__(
self, obj: Optional[_T_contra], cls: type[_T_contra]
) -> Union[Self, BoundMethod[_T_contra, _P, _R_co]]: ...


class BoundAsyncMethod(Protocol[_T_contra, _P, _R_co]):
"""Protocol for a bound instance method. See :class:`Method` for use."""

# pylint: disable=no-self-argument
def __call__(
__self, self: _T_contra, *args: _P.args, **kwargs: _P.kwargs
) -> Awaitable[_R_co]: ...


class AsyncMethod(Protocol[_P, _R_co]):
"""Protocol for an instance method."""

# pylint: disable=no-self-argument
def __call__(
__self, self: Any, *args: _P.args, **kwargs: _P.kwargs
) -> Awaitable[_R_co]: ...

@overload
def __get__(self, obj: None, cls: type[_T_contra]) -> Self: ...

@overload
def __get__(
self, obj: _T_contra, cls: type[_T_contra]
) -> BoundAsyncMethod[_T_contra, _P, _R_co]: ...

def __get__(
self, obj: Optional[_T_contra], cls: type[_T_contra]
) -> Union[Self, BoundAsyncMethod[_T_contra, _P, _R_co]]: ...


class MethodDecorator(Protocol):
"""Protocol for a transparent method decorator (no change in signature)."""

@overload
def __get__(self, obj: _T, cls: type[_T]) -> BoundMethod[_T, _P, _R_co]: ...
def __call__(self, __f: AsyncMethod[_P, _T]) -> AsyncMethod[_P, _T]: ...

@overload
def __call__(self, __f: Method[_P, _T]) -> Method[_P, _T]: ...

def __call__(
self, __f: Union[Method[_P, _T], AsyncMethod[_P, _T]]
) -> Union[Method[_P, _T], AsyncMethod[_P, _T]]: ...
84 changes: 51 additions & 33 deletions src/coaster/views/classview.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@

from ..auth import add_auth_attribute, current_auth
from ..sqlalchemy import Query, UrlForMixin
from ..typing import Method, ReturnDecorator, WrappedFunc
from ..typing import Method
from ..utils import InspectableSet
from .misc import ensure_sync

Expand All @@ -82,7 +82,7 @@
# --- Types and protocols --------------------------------------------------------------

#: Type for URL rules in classviews
RouteRuleOptions = dict[str, Any]
RouteRuleOptions: TypeAlias = dict[str, Any]
ClassViewSubtype = TypeVar('ClassViewSubtype', bound='ClassView')
ClassViewType: TypeAlias = type[ClassViewSubtype]
ModelType = TypeVar('ModelType', default=Any)
Expand Down Expand Up @@ -354,7 +354,9 @@ class MyModelView(CrudView, ModelView[MyModel]):
def delete(self):
super().delete() # Call into base class's implementation if needed
"""
# Get the class, telling static type checkers to ignore generic type binding...
cls = cast(type[ViewMethod], self.__class__)
# ...then bind to the replacement generic types and use it
r: ViewMethod[_P2, _R2_co] = cls(__f, data=self.data)
r.routes = self.routes
return r
Expand Down Expand Up @@ -1048,10 +1050,12 @@ def after_loader( # pylint: disable=useless-return
ModelViewType = TypeVar('ModelViewType', bound=ModelView)


def requires_roles(roles: set[str]) -> ReturnDecorator:
def requires_roles(
roles: set[str],
) -> Callable[[Callable[_P, _R_co]], Callable[_P, _R_co]]:
"""Decorate to require specific roles in a :class:`ModelView` view."""

def decorator(f: WrappedFunc) -> WrappedFunc:
def decorator(f: Callable[_P, _R_co]) -> Callable[_P, _R_co]:
def is_available_here(context: ModelViewType) -> bool:
return context.obj.roles_for(
current_auth.actor, current_auth.anchors
Expand All @@ -1069,20 +1073,25 @@ def validate(context: ModelViewType) -> None:
if not is_available_here(context):
abort(403)

@wraps(f)
def wrapper(self: ModelViewType, *args: Any, **kwargs: Any) -> Any:
validate(self)
return f(self, *args, **kwargs)
if iscoroutinefunction(f):

@wraps(f)
async def async_wrapper(self: ModelViewType, *args: Any, **kwargs: Any) -> Any:
validate(self)
return await f(self, *args, **kwargs)
@wraps(f)
async def async_wrapper(*args: _P.args, **kwargs: _P.kwargs) -> Any:
validate(args[0]) # type: ignore[type-var]
return await f(*args, **kwargs)

# Fix type hint for the return type
wrapper = cast(Callable[_P, _R_co], async_wrapper)
else:

use_wrapper = async_wrapper if iscoroutinefunction(f) else wrapper
use_wrapper.requires_roles = roles # type: ignore[attr-defined]
use_wrapper.is_available = is_available # type: ignore[attr-defined]
return cast(WrappedFunc, use_wrapper)
@wraps(f)
def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R_co:
validate(args[0]) # type: ignore[type-var]
return f(*args, **kwargs)

wrapper.requires_roles = roles # type: ignore[attr-defined]
wrapper.is_available = is_available # type: ignore[attr-defined]
return wrapper

return decorator

Expand Down Expand Up @@ -1193,7 +1202,9 @@ def blueprint_postprocess(state: BlueprintSetupState) -> None:
)


def url_change_check(f: WrappedFunc) -> WrappedFunc:
def url_change_check(
f: Callable[_P, _R_co]
) -> Callable[_P, Union[_R_co, BaseResponse]]:
"""
Decorate view method in a :class:`ModelView` to check for a change in URL.
Expand Down Expand Up @@ -1225,7 +1236,7 @@ def view(self):
(``#target_id``) is not available to the server and will be lost.
"""

def validate(context: ModelView) -> Optional[ResponseReturnValue]:
def validate(context: ModelView) -> Optional[BaseResponse]:
if request.method == 'GET' and getattr(context, 'obj', None) is not None:
correct_url = furl(context.obj.url_for(f.__name__, _external=True))
stripped_url = correct_url.copy().remove(
Expand All @@ -1243,21 +1254,28 @@ def validate(context: ModelView) -> Optional[ResponseReturnValue]:
)
return None

@wraps(f)
def wrapper(self: ModelView, *args, **kwargs) -> Any:
retval = validate(self)
if retval is not None:
return retval
return f(self, *args, **kwargs)

@wraps(f)
async def async_wrapper(self: ModelView, *args, **kwargs) -> Any:
retval = validate(self)
if retval is not None:
return retval
return await f(self, *args, **kwargs)

return cast(WrappedFunc, async_wrapper if iscoroutinefunction(f) else wrapper)
if iscoroutinefunction(f):

@wraps(f)
async def async_wrapper(*args: _P.args, **kwargs: _P.kwargs) -> Any:
retval = validate(args[0]) # type: ignore[arg-type]
if retval is not None:
return retval
return await f(*args, **kwargs)

# Fix return type hint
wrapper = cast(Callable[_P, Union[_R_co, BaseResponse]], async_wrapper)

else:

@wraps(f)
def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> Union[_R_co, BaseResponse]:
retval = validate(args[0]) # type: ignore[arg-type]
if retval is not None:
return retval
return f(*args, **kwargs)

return wrapper


class UrlChangeCheck:
Expand Down
Loading

0 comments on commit 8b8300c

Please sign in to comment.