diff --git a/beanie/odm/actions.py b/beanie/odm/actions.py index 1b7eadaf..99d4aebd 100644 --- a/beanie/odm/actions.py +++ b/beanie/odm/actions.py @@ -11,11 +11,17 @@ Optional, Tuple, Type, + TypeVar, Union, ) +from typing_extensions import ParamSpec + if TYPE_CHECKING: - from beanie.odm.documents import Document + from beanie.odm.documents import AsyncDocMethod, DocType, Document + +P = ParamSpec("P") +R = TypeVar("R") class EventTypes(str, Enum): @@ -136,10 +142,14 @@ async def run_actions( await asyncio.gather(*coros) +# `Any` because there is arbitrary attribute assignment on this type +F = TypeVar("F", bound=Any) + + def register_action( - event_types: Tuple[Union[List[EventTypes], EventTypes]], + event_types: Tuple[Union[List[EventTypes], EventTypes], ...], action_direction: ActionDirections, -): +) -> Callable[[F], F]: """ Decorator. Base registration method. Used inside `before_event` and `after_event` @@ -154,7 +164,7 @@ def register_action( else: final_event_types.append(event_type) - def decorator(f): + def decorator(f: F) -> F: f.has_action = True f.event_types = final_event_types f.action_direction = action_direction @@ -163,7 +173,9 @@ def decorator(f): return decorator -def before_event(*args: Union[List[EventTypes], EventTypes]): +def before_event( + *args: Union[List[EventTypes], EventTypes] +) -> Callable[[F], F]: """ Decorator. It adds action, which should run before mentioned one or many events happen @@ -172,11 +184,13 @@ def before_event(*args: Union[List[EventTypes], EventTypes]): :return: None """ return register_action( - action_direction=ActionDirections.BEFORE, event_types=args # type: ignore + action_direction=ActionDirections.BEFORE, event_types=args ) -def after_event(*args: Union[List[EventTypes], EventTypes]): +def after_event( + *args: Union[List[EventTypes], EventTypes] +) -> Callable[[F], F]: """ Decorator. It adds action, which should run after mentioned one or many events happen @@ -186,11 +200,15 @@ def after_event(*args: Union[List[EventTypes], EventTypes]): """ return register_action( - action_direction=ActionDirections.AFTER, event_types=args # type: ignore + action_direction=ActionDirections.AFTER, event_types=args ) -def wrap_with_actions(event_type: EventTypes): +def wrap_with_actions( + event_type: EventTypes, +) -> Callable[ + ["AsyncDocMethod[DocType, P, R]"], "AsyncDocMethod[DocType, P, R]" +]: """ Helper function to wrap Document methods with before and after event listeners @@ -198,14 +216,16 @@ def wrap_with_actions(event_type: EventTypes): :return: None """ - def decorator(f: Callable): + def decorator( + f: "AsyncDocMethod[DocType, P, R]", + ) -> "AsyncDocMethod[DocType, P, R]": @wraps(f) async def wrapper( - self, - *args, + self: "Document", + *args: P.args, skip_actions: Optional[List[Union[ActionDirections, str]]] = None, - **kwargs, - ): + **kwargs: P.kwargs, + ) -> R: if skip_actions is None: skip_actions = [] @@ -216,7 +236,12 @@ async def wrapper( exclude=skip_actions, ) - result = await f(self, *args, skip_actions=skip_actions, **kwargs) + result = await f( + self, + *args, + skip_actions=skip_actions, # type: ignore[arg-type] + **kwargs, + ) await ActionRegistry.run_actions( self, diff --git a/beanie/odm/documents.py b/beanie/odm/documents.py index 94d7faf6..c5e8c0b9 100644 --- a/beanie/odm/documents.py +++ b/beanie/odm/documents.py @@ -3,6 +3,8 @@ from enum import Enum from typing import ( Any, + Awaitable, + Callable, ClassVar, Dict, Iterable, @@ -32,6 +34,7 @@ DeleteResult, InsertManyResult, ) +from typing_extensions import Concatenate, ParamSpec, TypeAlias from beanie.exceptions import ( CollectionWasNotInitialized, @@ -104,6 +107,10 @@ from pydantic import model_validator DocType = TypeVar("DocType", bound="Document") +P = ParamSpec("P") +R = TypeVar("R") +SyncDocMethod: TypeAlias = Callable[Concatenate[DocType, P], R] +AsyncDocMethod: TypeAlias = Callable[Concatenate[DocType, P], Awaitable[R]] DocumentProjectionType = TypeVar("DocumentProjectionType", bound=BaseModel) @@ -529,7 +536,7 @@ async def save( link_rule: WriteRules = WriteRules.DO_NOTHING, ignore_revision: bool = False, **kwargs, - ) -> None: + ) -> DocType: """ Update an existing model in the database or insert it if it does not yet exist. @@ -605,12 +612,12 @@ async def save( @wrap_with_actions(EventTypes.SAVE_CHANGES) @validate_self_before async def save_changes( - self, + self: DocType, ignore_revision: bool = False, session: Optional[ClientSession] = None, bulk_writer: Optional[BulkWriter] = None, skip_actions: Optional[List[Union[ActionDirections, str]]] = None, - ) -> None: + ) -> Optional[DocType]: """ Save changes. State management usage must be turned on @@ -632,7 +639,7 @@ async def save_changes( ) else: return await self.set( - changes, # type: ignore #TODO fix typing + changes, ignore_revision=ignore_revision, session=session, bulk_writer=bulk_writer, @@ -741,13 +748,13 @@ def update_all( ) def set( - self, + self: DocType, expression: Dict[Union[ExpressionField, str], Any], session: Optional[ClientSession] = None, bulk_writer: Optional[BulkWriter] = None, skip_sync: Optional[bool] = None, **kwargs, - ): + ) -> Awaitable[DocType]: """ Set values diff --git a/beanie/odm/utils/self_validation.py b/beanie/odm/utils/self_validation.py index c20e1fb5..28462b68 100644 --- a/beanie/odm/utils/self_validation.py +++ b/beanie/odm/utils/self_validation.py @@ -1,13 +1,20 @@ from functools import wraps -from typing import TYPE_CHECKING, Callable +from typing import TYPE_CHECKING, TypeVar + +from typing_extensions import ParamSpec if TYPE_CHECKING: - from beanie.odm.documents import DocType + from beanie.odm.documents import AsyncDocMethod, DocType + +P = ParamSpec("P") +R = TypeVar("R") -def validate_self_before(f: Callable): +def validate_self_before( + f: "AsyncDocMethod[DocType, P, R]", +) -> "AsyncDocMethod[DocType, P, R]": @wraps(f) - async def wrapper(self: "DocType", *args, **kwargs): + async def wrapper(self: "DocType", *args: P.args, **kwargs: P.kwargs) -> R: await self.validate_self(*args, **kwargs) return await f(self, *args, **kwargs) diff --git a/beanie/odm/utils/state.py b/beanie/odm/utils/state.py index 0879b869..9eb5c250 100644 --- a/beanie/odm/utils/state.py +++ b/beanie/odm/utils/state.py @@ -1,11 +1,16 @@ import inspect from functools import wraps -from typing import TYPE_CHECKING, Callable +from typing import TYPE_CHECKING, Callable, TypeVar, Union, overload + +from typing_extensions import ParamSpec, TypeAlias from beanie.exceptions import StateManagementIsTurnedOff, StateNotSaved if TYPE_CHECKING: - from beanie.odm.documents import DocType + from beanie.odm.documents import AsyncDocMethod, DocType, SyncDocMethod + +P = ParamSpec("P") +R = TypeVar("R") def check_if_state_saved(self: "DocType"): @@ -17,7 +22,26 @@ def check_if_state_saved(self: "DocType"): raise StateNotSaved("No state was saved") -def saved_state_needed(f: Callable): +AnyDocMethod: TypeAlias = Union[ + "AsyncDocMethod[DocType, P, R]", "SyncDocMethod[DocType, P, R]" +] + + +@overload +def saved_state_needed( + f: "AsyncDocMethod[DocType, P, R]", +) -> "AsyncDocMethod[DocType, P, R]": + ... + + +@overload +def saved_state_needed( + f: "SyncDocMethod[DocType, P, R]", +) -> "SyncDocMethod[DocType, P, R]": + ... + + +def saved_state_needed(f: Callable) -> Callable: @wraps(f) def sync_wrapper(self: "DocType", *args, **kwargs): check_if_state_saved(self) @@ -44,7 +68,21 @@ def check_if_previous_state_saved(self: "DocType"): ) -def previous_saved_state_needed(f: Callable): +@overload +def previous_saved_state_needed( + f: "AsyncDocMethod[DocType, P, R]", +) -> "AsyncDocMethod[DocType, P, R]": + ... + + +@overload +def previous_saved_state_needed( + f: "SyncDocMethod[DocType, P, R]", +) -> "SyncDocMethod[DocType, P, R]": + ... + + +def previous_saved_state_needed(f: Callable) -> Callable: @wraps(f) def sync_wrapper(self: "DocType", *args, **kwargs): check_if_previous_state_saved(self) @@ -60,9 +98,11 @@ async def async_wrapper(self: "DocType", *args, **kwargs): return sync_wrapper -def save_state_after(f: Callable): +def save_state_after( + f: "AsyncDocMethod[DocType, P, R]", +) -> "AsyncDocMethod[DocType, P, R]": @wraps(f) - async def wrapper(self: "DocType", *args, **kwargs): + async def wrapper(self: "DocType", *args: P.args, **kwargs: P.kwargs) -> R: result = await f(self, *args, **kwargs) self._save_state() return result