Skip to content

Commit

Permalink
Use ParamSpec for DbusMethodAsync instead of masking it
Browse files Browse the repository at this point in the history
Originally the DbusMethodAsync was masked under the original
function meaning type checker treated it as it was the original
function and not a DbusMethodAsync.

However, it is planned to add new methods to the DbusMethodAsync
so masking it is no longer an option.

ParamSpec is only available since Python 3.10 so use the
`typing_extensions` import hidden under TYPE_CHECKING if statement.
  • Loading branch information
igo95862 committed Feb 25, 2024
1 parent 38485ff commit 72e9d52
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 27 deletions.
87 changes: 65 additions & 22 deletions src/sdbus/dbus_proxy_async_method.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from contextvars import ContextVar, copy_context
from inspect import iscoroutinefunction
from types import FunctionType
from typing import TYPE_CHECKING, cast, overload
from typing import TYPE_CHECKING, Generic, TypeVar, cast, overload
from weakref import ref as weak_ref

from .dbus_common_elements import (
Expand All @@ -36,14 +36,27 @@
from .sd_bus_internals import DbusNoReplyFlag

if TYPE_CHECKING:
from typing import Any, Callable, Optional, Sequence, Type, TypeVar, Union
from typing import (
Any,
Callable,
Coroutine,
Optional,
Sequence,
Type,
Union,
)

from typing_extensions import Concatenate, ParamSpec

from .dbus_proxy_async_interface_base import DbusInterfaceBaseAsync
from .sd_bus_internals import SdBusMessage

T = TypeVar('T')
TDBI = TypeVar("TDBI", bound=DbusInterfaceBaseAsync)
P = ParamSpec("P")
else:
T = None
P = TypeVar("P")

TR = TypeVar("TR")

CURRENT_MESSAGE: ContextVar[SdBusMessage] = ContextVar('CURRENT_MESSAGE')

Expand All @@ -52,29 +65,33 @@ def get_current_message() -> SdBusMessage:
return CURRENT_MESSAGE.get()


class DbusMethodAsync(DbusMethodCommon, DbusSomethingAsync):
class DbusMethodAsync(
DbusMethodCommon,
DbusSomethingAsync,
Generic[P, TR],
):

@overload
def __get__(
self,
obj: None,
obj_class: Type[DbusInterfaceBaseAsync],
) -> DbusMethodAsync:
) -> DbusMethodAsync[P, TR]:
...

@overload
def __get__(
self,
obj: DbusInterfaceBaseAsync,
obj_class: Type[DbusInterfaceBaseAsync],
) -> Callable[..., Any]:
) -> DbusMethodAsyncBaseBind[P, TR]:
...

def __get__(
self,
obj: Optional[DbusInterfaceBaseAsync],
obj_class: Optional[Type[DbusInterfaceBaseAsync]] = None,
) -> Union[Callable[..., Any], DbusMethodAsync]:
) -> Union[DbusMethodAsyncBaseBind[P, TR], DbusMethodAsync[P, TR]]:
if obj is not None:
dbus_meta = obj._dbus
if isinstance(dbus_meta, DbusRemoteObjectMeta):
Expand All @@ -85,16 +102,22 @@ def __get__(
return self


class DbusMethodAsyncBaseBind(DbusBindedAsync):
class DbusMethodAsyncBaseBind(
DbusBindedAsync,
Generic[P, TR],
):

def __call__(self, *args: Any, **kwargs: Any) -> Any:
def __call__(
*args: P.args,
**kwargs: P.kwargs,
) -> Coroutine[Any, Any, TR]:
raise NotImplementedError


class DbusMethodAsyncProxyBind(DbusMethodAsyncBaseBind):
class DbusMethodAsyncProxyBind(DbusMethodAsyncBaseBind[P, TR]):
def __init__(
self,
dbus_method: DbusMethodAsync,
dbus_method: DbusMethodAsync[P, TR],
proxy_meta: DbusRemoteObjectMeta,
):
self.dbus_method = dbus_method
Expand All @@ -111,7 +134,7 @@ async def _dbus_async_call(self, call_message: SdBusMessage) -> Any:
async def _no_reply() -> None:
return None

def __call__(self, *args: Any, **kwargs: Any) -> Any:
def __call__(self, *args: P.args, **kwargs: P.kwargs) -> Any:
bus = self.proxy_meta.attached_bus
dbus_method = self.dbus_method

Expand Down Expand Up @@ -145,18 +168,18 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any:
return self._dbus_async_call(new_call_message)


class DbusMethodAsyncLocalBind(DbusMethodAsyncBaseBind):
class DbusMethodAsyncLocalBind(DbusMethodAsyncBaseBind[P, TR]):
def __init__(
self,
dbus_method: DbusMethodAsync,
dbus_method: DbusMethodAsync[P, TR],
local_object: DbusInterfaceBaseAsync,
):
self.dbus_method = dbus_method
self.local_object_ref = weak_ref(local_object)

self.__doc__ = dbus_method.__doc__

def __call__(self, *args: Any, **kwargs: Any) -> Any:
def __call__(self, *args: P.args, **kwargs: P.kwargs) -> Any:
local_object = self.local_object_ref()
if local_object is None:
raise RuntimeError("Local object no longer exists!")
Expand Down Expand Up @@ -239,14 +262,24 @@ def dbus_method_async(
result_args_names: Optional[Sequence[str]] = None,
input_args_names: Optional[Sequence[str]] = None,
method_name: Optional[str] = None,
) -> Callable[[T], T]:
) -> Callable[
[Callable[
Concatenate[TDBI, P],
Coroutine[Any, Any, TR]]],
DbusMethodAsync[P, TR],
]:

assert not isinstance(input_signature, FunctionType), (
"Passed function to decorator directly. "
"Did you forget () round brackets?"
)

def dbus_method_decorator(original_method: T) -> T:
def dbus_method_decorator(
original_method: Callable[
Concatenate[TDBI, P],
Coroutine[Any, Any, TR]
],
) -> DbusMethodAsync[P, TR]:
assert isinstance(original_method, FunctionType)
assert iscoroutinefunction(original_method), (
"Expected coroutine function. ",
Expand All @@ -262,15 +295,25 @@ def dbus_method_decorator(original_method: T) -> T:
flags=flags,
)

return cast(T, new_wrapper)
return new_wrapper

return dbus_method_decorator


def dbus_method_async_override() -> Callable[[T], T]:
def dbus_method_async_override(
) -> Callable[
[Callable[
Concatenate[TDBI, P],
Coroutine[Any, Any, TR]]],
DbusMethodAsync[P, TR],
]:

def new_decorator(
new_function: T) -> T:
return cast(T, DbusOverload(new_function))
new_function: Callable[
Concatenate[TDBI, P],
Coroutine[Any, Any, TR]
],
) -> DbusMethodAsync[P, TR]:
return cast(DbusMethodAsync[P, TR], DbusOverload(new_function))

return new_decorator
2 changes: 1 addition & 1 deletion test/test_sdbus_async_bad_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ def test_bad_subclass(self) -> None:
with self.assertRaises(TypeError):

class TestInheritence(TestInterface):
async def test_int(self) -> int:
async def test_int(self) -> int: # type: ignore[override]
return 2

with self.assertRaises(TypeError):
Expand Down
7 changes: 3 additions & 4 deletions test/test_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,10 +160,9 @@ async def check_async_element_class_access_typing() -> None:

test_list: List[str] = []

# TODO: Fix dbus async method typing
# test_list.append(
# TestTypingAsync.get_str_list_method.method_name
# )
test_list.append(
TestTypingAsync.get_str_list_method.method_name
)
test_list.append(
TestTypingAsync.str_list_property.property_name
)
Expand Down

1 comment on commit 72e9d52

@igo95862
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably will have to be reverted because Jedi can't handle ParamSpec yet: davidhalter/jedi#1812

Maybe if there was somekind of JEDI_CHECKING variable I could mask the D-Bus method with original method.

Please sign in to comment.