diff --git a/databases/core.py b/databases/core.py index 0e27227c..efa59471 100644 --- a/databases/core.py +++ b/databases/core.py @@ -315,6 +315,9 @@ def _build_query( return query +_CallableType = typing.TypeVar("_CallableType", bound=typing.Callable) + + class Transaction: def __init__( self, @@ -347,13 +350,13 @@ async def __aexit__( else: await self.commit() - def __await__(self) -> typing.Generator: + def __await__(self) -> typing.Generator[None, None, "Transaction"]: """ Called if using the low-level `transaction = await database.transaction()` """ return self.start().__await__() - def __call__(self, func: typing.Callable) -> typing.Callable: + def __call__(self, func: _CallableType) -> _CallableType: """ Called if using `@database.transaction()` as a decorator. """ @@ -363,7 +366,7 @@ async def wrapper(*args: typing.Any, **kwargs: typing.Any) -> typing.Any: async with self: return await func(*args, **kwargs) - return wrapper + return wrapper # type: ignore async def start(self) -> "Transaction": self._connection = self._connection_callable()