Skip to content

Commit

Permalink
Add IsolatedDbusTestCase.assertDbusSignalEmits
Browse files Browse the repository at this point in the history
When used in the `async with` block it will assert that the signal
gets emitted at least once. The with block returns the DbusSignalRecorder
object which can be used to test the data emitted by signal.

API is not final and is not documented for now. It should be
finalized by the time 0.12.0 version releases.
  • Loading branch information
igo95862 committed Feb 4, 2024
1 parent 25661bd commit 6010f26
Show file tree
Hide file tree
Showing 3 changed files with 194 additions and 34 deletions.
20 changes: 15 additions & 5 deletions src/sdbus/dbus_proxy_async_signal.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
from typing import Any, Callable, Optional, Sequence, Tuple, Type

from .dbus_proxy_async_interface_base import DbusInterfaceBaseAsync
from .sd_bus_internals import SdBus, SdBusMessage
from .sd_bus_internals import SdBus, SdBusMessage, SdBusSlot


T = TypeVar('T')
Expand Down Expand Up @@ -116,14 +116,24 @@ def __init__(

self.__doc__ = dbus_signal.__doc__

async def catch(self) -> AsyncIterator[T]:
message_queue: Queue[SdBusMessage] = Queue()

match_slot = await self.proxy_meta.attached_bus.match_signal_async(
async def _register_match_slot(
self,
bus: SdBus,
callback: Callable[[SdBusMessage], Any],
) -> SdBusSlot:
return await bus.match_signal_async(
self.proxy_meta.service_name,
self.proxy_meta.object_path,
self.dbus_signal.interface_name,
self.dbus_signal.signal_name,
callback,
)

async def catch(self) -> AsyncIterator[T]:
message_queue: Queue[SdBusMessage] = Queue()

match_slot = await self._register_match_slot(
self.proxy_meta.attached_bus,
message_queue.put_nowait,
)

Expand Down
148 changes: 146 additions & 2 deletions src/sdbus/unittest.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
from __future__ import annotations

from asyncio import Event, TimeoutError, wait_for
from os import environ, kill
from pathlib import Path
from signal import SIGTERM
Expand All @@ -27,11 +28,33 @@
from tempfile import TemporaryDirectory
from typing import TYPE_CHECKING
from unittest import IsolatedAsyncioTestCase
from weakref import ref as weak_ref

from sdbus import sd_bus_open_user, set_default_bus
from .dbus_common_funcs import set_default_bus
from .dbus_proxy_async_signal import (
DbusSignalAsyncLocalBind,
DbusSignalAsyncProxyBind,
)
from .sd_bus_internals import SdBusMessage, sd_bus_open_user

if TYPE_CHECKING:
from typing import ClassVar
from typing import (
Any,
AsyncContextManager,
ClassVar,
List,
Optional,
TypeVar,
Union,
)

from .dbus_proxy_async_signal import (
DbusSignalAsync,
DbusSignalAsyncBaseBind,
)
from .sd_bus_internals import SdBus, SdBusSlot

T = TypeVar('T')


dbus_config = '''
Expand All @@ -49,6 +72,114 @@
'''


class DbusSignalRecorderBase:
def __init__(
self,
testcase: IsolatedDbusTestCase,
timeout: Union[int, float],
):
self._testcase = testcase
self._timeout = timeout
self._captured_data: List[Any] = []
self._ready_event = Event()
self._callback_method = self._callback

async def start(self) -> None:
raise NotImplementedError

async def stop(self) -> None:
raise NotImplementedError

async def __aenter__(self) -> DbusSignalRecorderBase:
raise NotImplementedError

async def __aexit__(
self,
exc_type: Any,
exc_value: Any,
traceback: Any,
) -> None:
if exc_type is not None:
return

try:
await wait_for(self._ready_event.wait(), timeout=self._timeout)
except TimeoutError:
raise AssertionError("D-Bus signal not captured.") from None

def _callback(self, data: Any) -> None:
if isinstance(data, SdBusMessage):
data = data.get_contents()

self._captured_data.append(data)
self._ready_event.set()

def assert_emitted_once_with(self, data: Any) -> None:
captured_signals_num = len(self._captured_data)
if captured_signals_num != 1:
raise AssertionError(
f"Expected one captured signal got {captured_signals_num}"
)

self._testcase.assertEqual(self._captured_data[0], data)


class DbusSignalRecorderRemote(DbusSignalRecorderBase):
def __init__(
self,
testcase: IsolatedDbusTestCase,
timeout: Union[int, float],
bus: SdBus,
remote_signal: DbusSignalAsyncProxyBind[Any],
):
super().__init__(testcase, timeout)
self._bus = bus
self._match_slot: Optional[SdBusSlot] = None
self._remote_signal = remote_signal

async def __aenter__(self) -> DbusSignalRecorderBase:
self._match_slot = await self._remote_signal._register_match_slot(
self._bus,
self._callback_method,
)

return self

async def __aexit__(
self,
exc_type: Any,
exc_value: Any,
traceback: Any,
) -> None:
try:
await super().__aexit__(exc_type, exc_value, traceback)
finally:
if self._match_slot is not None:
self._match_slot.close()


class DbusSignalRecorderLocal(DbusSignalRecorderBase):
def __init__(
self,
testcase: IsolatedDbusTestCase,
timeout: Union[int, float],
local_signal: DbusSignalAsyncLocalBind[Any],
):
super().__init__(testcase, timeout)
self._local_signal_ref: weak_ref[DbusSignalAsync[Any]] = (
weak_ref(local_signal.dbus_signal)
)

async def __aenter__(self) -> DbusSignalRecorderBase:
local_signal = self._local_signal_ref()

if local_signal is None:
raise RuntimeError

local_signal.local_callbacks.add(self._callback_method)
return self


class IsolatedDbusTestCase(IsolatedAsyncioTestCase):
dbus_executable_name: ClassVar[str] = 'dbus-daemon'

Expand Down Expand Up @@ -95,3 +226,16 @@ def tearDown(self) -> None:
environ.pop('DBUS_SESSION_BUS_ADDRESS')
if self.old_session_bus_address is not None:
environ['DBUS_SESSION_BUS_ADDRESS'] = self.old_session_bus_address

def assertDbusSignalEmits(
self,
signal: DbusSignalAsyncBaseBind[Any],
timeout: Union[int, float] = 1,
) -> AsyncContextManager[DbusSignalRecorderBase]:

if isinstance(signal, DbusSignalAsyncLocalBind):
return DbusSignalRecorderLocal(self, timeout, signal)
elif isinstance(signal, DbusSignalAsyncProxyBind):
return DbusSignalRecorderRemote(self, timeout, self.bus, signal)
else:
raise TypeError("Unknown or unsupported signal class.")
60 changes: 33 additions & 27 deletions test/test_sdbus_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
from __future__ import annotations

from asyncio import Event, get_running_loop, sleep, wait, wait_for
from asyncio import Event, get_running_loop, sleep, wait_for
from asyncio.subprocess import create_subprocess_exec
from typing import TYPE_CHECKING, cast
from unittest import SkipTest
Expand Down Expand Up @@ -51,8 +51,7 @@
)

if TYPE_CHECKING:
from asyncio import Task
from typing import Any, Tuple
from typing import Tuple

from sdbus.dbus_proxy_async_interfaces import (
DBUS_PROPERTIES_CHANGED_TYPING,
Expand Down Expand Up @@ -430,21 +429,30 @@ async def test_properties(self) -> None:
async def test_signal(self) -> None:
test_object, test_object_connection = initialize_object()

loop = get_running_loop()

test_tuple = ('sgfsretg', 'asd')

aiter_dbus: Any = test_object_connection.test_signal.__aiter__()
anext_dbus: Task[Any] = loop.create_task(aiter_dbus.__anext__())
aiter_local: Any = test_object.test_signal.__aiter__()
anext_local: Task[Any] = loop.create_task(aiter_local.__anext__())

loop.call_later(0.1, test_object.test_signal.emit, test_tuple)
async with (
self.assertDbusSignalEmits(
test_object.test_signal
) as local_signals_record,
self.assertDbusSignalEmits(
test_object_connection.test_signal
) as remote_signals_record
):
test_object.test_signal.emit(test_tuple)

await wait((anext_dbus, anext_local), timeout=1)
async with (
self.assertDbusSignalEmits(
test_object.test_signal
) as local_signals_record,
self.assertDbusSignalEmits(
test_object_connection.test_signal
) as remote_signals_record
):
test_object.test_signal.emit(test_tuple)

self.assertEqual(test_tuple, anext_dbus.result())
self.assertEqual(test_tuple, anext_local.result())
local_signals_record.assert_emitted_once_with(test_tuple)
remote_signals_record.assert_emitted_once_with(test_tuple)

async def test_signal_catch_anywhere(self) -> None:
test_object, test_object_connection = initialize_object()
Expand Down Expand Up @@ -731,20 +739,18 @@ async def test_properties_get_all_dict(self) -> None:
async def test_empty_signal(self) -> None:
test_object, test_object_connection = initialize_object()

loop = get_running_loop()

aiter_dbus: Any = test_object_connection.empty_signal.__aiter__()
anext_dbus: Task[Any] = loop.create_task(aiter_dbus.__anext__())
aiter_local: Any = test_object.empty_signal.__aiter__()
anext_local: Task[Any] = loop.create_task(aiter_local.__anext__())

loop.call_later(0.1, test_object.empty_signal.emit, None)

await wait((anext_dbus, anext_local), timeout=1)

self.assertIsNone(anext_dbus.result())
async with (
self.assertDbusSignalEmits(
test_object.empty_signal
) as local_signals_record,
self.assertDbusSignalEmits(
test_object_connection.empty_signal
) as remote_signals_record
):
test_object.empty_signal.emit(None)

self.assertIsNone(anext_local.result())
local_signals_record.assert_emitted_once_with(None)
remote_signals_record.assert_emitted_once_with(None)

async def test_properties_changed(self) -> None:
test_object, test_object_connection = initialize_object()
Expand Down

0 comments on commit 6010f26

Please sign in to comment.