Skip to content

Commit

Permalink
Fix decorator typing (#462)
Browse files Browse the repository at this point in the history
  • Loading branch information
jace committed Apr 17, 2024
1 parent a294be6 commit 4b8884e
Show file tree
Hide file tree
Showing 10 changed files with 168 additions and 144 deletions.
6 changes: 3 additions & 3 deletions src/coaster/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

from collections.abc import Sequence
from threading import Lock
from typing import Any, Generic, NoReturn, TypeVar, cast
from typing import Any, NoReturn, TypeVar, cast

from flask import Flask, current_app, g
from flask.globals import request_ctx
Expand Down Expand Up @@ -269,13 +269,13 @@ def init_app(app: Flask) -> None:
_CurrentAuthType_co = TypeVar('_CurrentAuthType_co', bound=CurrentAuth, covariant=True)


class GetCurrentAuth(Generic[_CurrentAuthType_co]):
class GetCurrentAuth:
"""Helper for :attr:`current_auth` proxy to use a :class:`CurrentAuth` subclass."""

def __init__(self, cls: type[_CurrentAuthType_co]) -> None:
self.cls = cls

def __call__(self) -> _CurrentAuthType_co:
def __call__(self) -> CurrentAuth:
"""Provide :attr:`current_auth` for the request context."""
# 1. Do we have a request context?
if request_ctx:
Expand Down
2 changes: 1 addition & 1 deletion src/coaster/sqlalchemy/markdown.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def text(self, value: Optional[str]) -> None:
)
self.changed()

def __json__(self) -> dict[str, Optional[str]]:
def __json__(self) -> Any:
"""Return JSON-compatible rendering of composite."""
return {'text': self._text, 'html': self._html}

Expand Down
2 changes: 1 addition & 1 deletion src/coaster/sqlalchemy/roles.py
Original file line number Diff line number Diff line change
Expand Up @@ -1049,7 +1049,7 @@ def __setitem__(self, key: str, value: str) -> None:
def __iter__(self) -> Iterator[str]:
yield from self._all_read

def __json__(self) -> dict[str, Any]:
def __json__(self) -> Any:
if self._datasets is None and self._obj.__json_datasets__:
# This proxy was created without specifying datasets, so we create a new
# proxy using the object's default JSON datasets, then convert it to a dict
Expand Down
25 changes: 7 additions & 18 deletions src/coaster/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,28 +6,17 @@
from __future__ import annotations

from typing import Any, Callable, Protocol, TypeVar
from typing_extensions import Concatenate, ParamSpec
from typing_extensions import ParamSpec, TypeAlias

WrappedFunc = TypeVar('WrappedFunc', bound=Callable)
#: Return type for decorator factories
ReturnDecorator = Callable[[WrappedFunc], WrappedFunc]
ReturnDecorator: TypeAlias = Callable[[WrappedFunc], WrappedFunc]

#: Recurring use ParamSpec
_P = ParamSpec('_P')
#: Recurring use type spec
_T = TypeVar('_T')
_R_co = TypeVar('_R_co', covariant=True)


class MethodProtocol(Protocol[_P, _T]):
"""
Protocol that matches a method without also matching against a type constructor.
class Method(Protocol[_P, _R_co]):
"""Protocol for an instance method."""

Replace ``Callable[Concatenate[Any, P], R]`` with ``MethodProtocol[Concatenate[Any,
P], R]``. This is needed because the typeshed defines ``type.__call__``, so any type
will also validate as a callable. Mypy special-cases callable protocols as not
matching ``type.__call__`` in https://github.com/python/mypy/pull/14121.
"""

# Using ``def __call__`` seems to break Mypy, so we use this hack
# https://github.com/python/typing/discussions/1312#discussioncomment-4416217
__call__: Callable[Concatenate[Any, _P], _T]
# pylint: disable=no-self-argument
def __call__(__self, self: Any, *args: _P.args, **kwargs: _P.kwargs) -> _R_co: ...
118 changes: 61 additions & 57 deletions src/coaster/views/classview.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@

from ..auth import add_auth_attribute, current_auth
from ..sqlalchemy import Query, UrlForMixin
from ..typing import MethodProtocol, ReturnDecorator, WrappedFunc
from ..typing import Method
from ..utils import InspectableSet
from .misc import ensure_sync

Expand All @@ -82,7 +82,7 @@
# --- Types and protocols --------------------------------------------------------------

#: Type for URL rules in classviews
RouteRuleOptions = dict[str, Any]
RouteRuleOptions: TypeAlias = dict[str, Any]
ClassViewSubtype = TypeVar('ClassViewSubtype', bound='ClassView')
ClassViewType: TypeAlias = type[ClassViewSubtype]
ModelType = TypeVar('ModelType', default=Any)
Expand All @@ -106,15 +106,13 @@ def __call__(self, __decorated: ClassViewType) -> ClassViewType: ...
def __call__(self, __decorated: ViewMethod[_P, _R_co]) -> ViewMethod[_P, _R_co]: ...

@overload
def __call__(
self, __decorated: MethodProtocol[Concatenate[Any, _P], _R_co]
) -> ViewMethod[_P, _R_co]: ...
def __call__(self, __decorated: Method[_P, _R_co]) -> ViewMethod[_P, _R_co]: ...

def __call__( # skipcq: PTC-W0049
self,
__decorated: Union[
ClassViewType,
MethodProtocol[Concatenate[Any, _P], _R_co],
Method[_P, _R_co],
ViewMethod[_P, _R_co],
],
) -> Union[ClassViewType, ViewMethod[_P, _R_co]]: ...
Expand All @@ -127,15 +125,11 @@ class ViewDataDecoratorProtocol(Protocol):
def __call__(self, __decorated: ViewMethod[_P, _R_co]) -> ViewMethod[_P, _R_co]: ...

@overload
def __call__(
self, __decorated: MethodProtocol[Concatenate[Any, _P], _R_co]
) -> ViewMethod[_P, _R_co]: ...
def __call__(self, __decorated: Method[_P, _R_co]) -> ViewMethod[_P, _R_co]: ...

def __call__( # skipcq: PTC-W0049
self,
__decorated: Union[
MethodProtocol[Concatenate[Any, _P], _R_co], ViewMethod[_P, _R_co]
],
__decorated: Union[Method[_P, _R_co], ViewMethod[_P, _R_co]],
) -> ViewMethod[_P, _R_co]: ...


Expand Down Expand Up @@ -201,14 +195,12 @@ def decorator(decorated: ClassViewType) -> ClassViewType: ...
def decorator(decorated: ViewMethod[_P, _R_co]) -> ViewMethod[_P, _R_co]: ...

@overload
def decorator(
decorated: MethodProtocol[Concatenate[Any, _P], _R_co]
) -> ViewMethod[_P, _R_co]: ...
def decorator(decorated: Method[_P, _R_co]) -> ViewMethod[_P, _R_co]: ...

def decorator(
decorated: Union[
ClassViewType,
MethodProtocol[Concatenate[Any, _P], _R_co],
Method[_P, _R_co],
ViewMethod[_P, _R_co],
]
) -> Union[ClassViewType, ViewMethod[_P, _R_co]]:
Expand Down Expand Up @@ -244,14 +236,10 @@ def viewdata(**kwargs: Any) -> ViewDataDecoratorProtocol:
def decorator(decorated: ViewMethod[_P, _R_co]) -> ViewMethod[_P, _R_co]: ...

@overload
def decorator(
decorated: MethodProtocol[Concatenate[Any, _P], _R_co]
) -> ViewMethod[_P, _R_co]: ...
def decorator(decorated: Method[_P, _R_co]) -> ViewMethod[_P, _R_co]: ...

def decorator(
decorated: Union[
ViewMethod[_P, _R_co], MethodProtocol[Concatenate[Any, _P], _R_co]
]
decorated: Union[ViewMethod[_P, _R_co], Method[_P, _R_co]]
) -> ViewMethod[_P, _R_co]:
return ViewMethod(decorated, data=kwargs)

Expand Down Expand Up @@ -348,9 +336,7 @@ def __repr__(self) -> str:

def replace(
self,
__f: Union[
ViewMethod[_P2, _R2_co], MethodProtocol[Concatenate[Any, _P2], _R2_co]
],
__f: Union[ViewMethod[_P2, _R2_co], Method[_P2, _R2_co]],
) -> ViewMethod[_P2, _R2_co]:
"""
Replace a view method in a subclass while keeping its URL routes.
Expand All @@ -368,7 +354,9 @@ class MyModelView(CrudView, ModelView[MyModel]):
def delete(self):
super().delete() # Call into base class's implementation if needed
"""
# Get the class, telling static type checkers to ignore generic type binding...
cls = cast(type[ViewMethod], self.__class__)
# ...then bind to the replacement generic types and use it
r: ViewMethod[_P2, _R2_co] = cls(__f, data=self.data)
r.routes = self.routes
return r
Expand Down Expand Up @@ -580,7 +568,7 @@ def __repr__(self) -> str:
return f'<ViewMethodBind {self.__qualname__}>'

def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R_co:
"""Treat this like a call to the original method and not to the view."""
# Treat this like a call to the original method and not to the view.
# As per the __decorators__ spec, we call .__func__, not .decorated_func
return self._view_method.__func__(self.__self__, *args, **kwargs)

Expand Down Expand Up @@ -1062,10 +1050,12 @@ def after_loader( # pylint: disable=useless-return
ModelViewType = TypeVar('ModelViewType', bound=ModelView)


def requires_roles(roles: set[str]) -> ReturnDecorator:
def requires_roles(
roles: set[str],
) -> Callable[[Callable[_P, _R_co]], Callable[_P, _R_co]]:
"""Decorate to require specific roles in a :class:`ModelView` view."""

def decorator(f: WrappedFunc) -> WrappedFunc:
def decorator(f: Callable[_P, _R_co]) -> Callable[_P, _R_co]:
def is_available_here(context: ModelViewType) -> bool:
return context.obj.roles_for(
current_auth.actor, current_auth.anchors
Expand All @@ -1083,20 +1073,25 @@ def validate(context: ModelViewType) -> None:
if not is_available_here(context):
abort(403)

@wraps(f)
def wrapper(self: ModelViewType, *args: Any, **kwargs: Any) -> Any:
validate(self)
return f(self, *args, **kwargs)
if iscoroutinefunction(f):

@wraps(f)
async def async_wrapper(self: ModelViewType, *args: Any, **kwargs: Any) -> Any:
validate(self)
return await f(self, *args, **kwargs)
@wraps(f)
async def async_wrapper(*args: _P.args, **kwargs: _P.kwargs) -> Any:
validate(args[0]) # type: ignore[type-var]
return await f(*args, **kwargs)

# Fix type hint for the return type
wrapper = cast(Callable[_P, _R_co], async_wrapper)
else:

use_wrapper = async_wrapper if iscoroutinefunction(f) else wrapper
use_wrapper.requires_roles = roles # type: ignore[attr-defined]
use_wrapper.is_available = is_available # type: ignore[attr-defined]
return cast(WrappedFunc, use_wrapper)
@wraps(f)
def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R_co:
validate(args[0]) # type: ignore[type-var]
return f(*args, **kwargs)

wrapper.requires_roles = roles # type: ignore[attr-defined]
wrapper.is_available = is_available # type: ignore[attr-defined]
return wrapper

return decorator

Expand Down Expand Up @@ -1207,7 +1202,9 @@ def blueprint_postprocess(state: BlueprintSetupState) -> None:
)


def url_change_check(f: WrappedFunc) -> WrappedFunc:
def url_change_check(
f: Callable[_P, _R_co]
) -> Callable[_P, Union[_R_co, BaseResponse]]:
"""
Decorate view method in a :class:`ModelView` to check for a change in URL.
Expand Down Expand Up @@ -1239,7 +1236,7 @@ def view(self):
(``#target_id``) is not available to the server and will be lost.
"""

def validate(context: ModelView) -> Optional[ResponseReturnValue]:
def validate(context: ModelView) -> Optional[BaseResponse]:
if request.method == 'GET' and getattr(context, 'obj', None) is not None:
correct_url = furl(context.obj.url_for(f.__name__, _external=True))
stripped_url = correct_url.copy().remove(
Expand All @@ -1257,21 +1254,28 @@ def validate(context: ModelView) -> Optional[ResponseReturnValue]:
)
return None

@wraps(f)
def wrapper(self: ModelView, *args, **kwargs) -> Any:
retval = validate(self)
if retval is not None:
return retval
return f(self, *args, **kwargs)

@wraps(f)
async def async_wrapper(self: ModelView, *args, **kwargs) -> Any:
retval = validate(self)
if retval is not None:
return retval
return await f(self, *args, **kwargs)

return cast(WrappedFunc, async_wrapper if iscoroutinefunction(f) else wrapper)
if iscoroutinefunction(f):

@wraps(f)
async def async_wrapper(*args: _P.args, **kwargs: _P.kwargs) -> Any:
retval = validate(args[0]) # type: ignore[arg-type]
if retval is not None:
return retval
return await f(*args, **kwargs)

# Fix return type hint
wrapper = cast(Callable[_P, Union[_R_co, BaseResponse]], async_wrapper)

else:

@wraps(f)
def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> Union[_R_co, BaseResponse]:
retval = validate(args[0]) # type: ignore[arg-type]
if retval is not None:
return retval
return f(*args, **kwargs)

return wrapper


class UrlChangeCheck:
Expand Down
Loading

0 comments on commit 4b8884e

Please sign in to comment.