Skip to content

Commit

Permalink
Expose mock directly (#335)
Browse files Browse the repository at this point in the history
  • Loading branch information
DominicOram committed May 28, 2024
1 parent 39dbff6 commit ee65c72
Show file tree
Hide file tree
Showing 7 changed files with 80 additions and 51 deletions.
5 changes: 5 additions & 0 deletions docs/how-to/write-tests-for-devices.rst
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,11 @@ In addition this example also utilizes helper functions like ``assert_reading``
:pyobject: test_sensor_reading_shows_value


Given that the mock signal holds a ``unittest.mock.Mock`` object you can retrieve this object and assert that the device has been set correctly using ``get_mock_put``. You are also free to use any other behaviour that ``unittest.mock.Mock`` provides, such as in this example which sets the parent of the mock to allow ordering across signals to be asserted:

.. literalinclude:: ../../tests/epics/demo/test_demo.py
:pyobject: test_retrieve_mock_and_assert

There are several other test utility functions:

Use ``callback_on_mock_put``, for hooking in logic when a mock value changes (e.g. because someone puts to it). This can be called directly, or used as a context, with the callbacks ending after exit.
Expand Down
8 changes: 3 additions & 5 deletions src/ophyd_async/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,10 @@
walk_rw_signals,
)
from .flyer import HardwareTriggeredFlyable, TriggerLogic
from .mock_signal_backend import (
MockSignalBackend,
)
from .mock_signal_backend import MockSignalBackend
from .mock_signal_utils import (
assert_mock_put_called_with,
callback_on_mock_put,
get_mock_put,
mock_puts_blocked,
reset_mock_put_calls,
set_mock_put_proceeds,
Expand Down Expand Up @@ -70,7 +68,7 @@
)

__all__ = [
"assert_mock_put_called_with",
"get_mock_put",
"callback_on_mock_put",
"mock_puts_blocked",
"set_mock_values",
Expand Down
13 changes: 5 additions & 8 deletions src/ophyd_async/core/mock_signal_backend.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import asyncio
from functools import cached_property
from typing import Optional, Type
from typing import Callable, Optional, Type
from unittest.mock import Mock

from bluesky.protocols import Descriptor, Reading
Expand All @@ -10,7 +10,7 @@
from ophyd_async.core.utils import DEFAULT_TIMEOUT, ReadingValueCallback, T


class MockSignalBackend(SignalBackend):
class MockSignalBackend(SignalBackend[T]):
def __init__(
self,
datatype: Optional[Type[T]] = None,
Expand All @@ -31,11 +31,11 @@ def __init__(

if not isinstance(self.initial_backend, SoftSignalBackend):
# If the backend is a hard signal backend, or not provided,
# then we create a soft signal to mimick it
# then we create a soft signal to mimic it

self.soft_backend = SoftSignalBackend(datatype=datatype)
else:
self.soft_backend = initial_backend
self.soft_backend = self.initial_backend

def source(self, name: str) -> str:
if self.initial_backend:
Expand All @@ -47,7 +47,7 @@ async def connect(self, timeout: float = DEFAULT_TIMEOUT) -> None:

@cached_property
def put_mock(self) -> Mock:
return Mock(name="put")
return Mock(name="put", spec=Callable)

@cached_property
def put_proceeds(self) -> asyncio.Event:
Expand All @@ -65,9 +65,6 @@ async def put(self, value: Optional[T], wait=True, timeout=None):
def set_value(self, value: T):
self.soft_backend.set_value(value)

async def get_descriptor(self, source: str) -> Descriptor:
return await self.soft_backend.get_descriptor(source)

async def get_reading(self) -> Reading:
return await self.soft_backend.get_reading()

Expand Down
20 changes: 10 additions & 10 deletions src/ophyd_async/core/mock_signal_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from contextlib import asynccontextmanager, contextmanager
from typing import Any, Callable, Iterable, Iterator, List
from unittest.mock import ANY, Mock
from typing import Any, Callable, Iterable
from unittest.mock import Mock

from ophyd_async.core.signal import Signal
from ophyd_async.core.utils import T
Expand All @@ -22,7 +22,7 @@ def set_mock_value(signal: Signal[T], value: T):
backend.set_value(value)


def set_mock_put_proceeds(signal: Signal[T], proceeds: bool):
def set_mock_put_proceeds(signal: Signal, proceeds: bool):
"""Allow or block a put with wait=True from proceeding"""
backend = _get_mock_signal_backend(signal)

Expand All @@ -33,17 +33,17 @@ def set_mock_put_proceeds(signal: Signal[T], proceeds: bool):


@asynccontextmanager
async def mock_puts_blocked(*signals: List[Signal]):
async def mock_puts_blocked(*signals: Signal):
for signal in signals:
set_mock_put_proceeds(signal, False)
yield
for signal in signals:
set_mock_put_proceeds(signal, True)


def assert_mock_put_called_with(signal: Signal, value: Any, wait=ANY, timeout=ANY):
backend = _get_mock_signal_backend(signal)
backend.put_mock.assert_called_with(value, wait=wait, timeout=timeout)
def get_mock_put(signal: Signal) -> Mock:
"""Get the mock associated with the put call on the signal."""
return _get_mock_signal_backend(signal).put_mock


def reset_mock_put_calls(signal: Signal):
Expand Down Expand Up @@ -79,15 +79,15 @@ def __next__(self):
return next_value

def __del__(self):
if self.require_all_consumed and self.index != len(self.values):
if self.require_all_consumed and self.index != len(list(self.values)):
raise AssertionError("Not all values have been consumed.")


def set_mock_values(
signal: Signal,
values: Iterable[Any],
require_all_consumed: bool = False,
) -> Iterator[Any]:
) -> _SetValuesIterator:
"""Iterator to set a signal to a sequence of values, optionally repeating the
sequence.
Expand Down Expand Up @@ -127,7 +127,7 @@ def _unset_side_effect_cm(put_mock: Mock):
put_mock.side_effect = None


def callback_on_mock_put(signal: Signal, callback: Callable[[T], None]):
def callback_on_mock_put(signal: Signal[T], callback: Callable[[T], None]):
"""For setting a callback when a backend is put to.
Can either be used in a context, with the callback being
Expand Down
2 changes: 1 addition & 1 deletion src/ophyd_async/core/signal_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ class SignalBackend(Generic[T]):

#: Like ca://PV_PREFIX:SIGNAL
@abstractmethod
def source(name: str) -> str:
def source(self, name: str) -> str:
"""Return source of signal. Signals may pass a name to the backend, which can be
used or discarded."""

Expand Down
59 changes: 32 additions & 27 deletions tests/core/test_mock_signal_backend.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,22 @@
import asyncio
import re
from itertools import repeat
from unittest.mock import MagicMock, call
from unittest.mock import ANY, MagicMock, call

import pytest

from ophyd_async.core import MockSignalBackend, SignalRW
from ophyd_async.core.device import Device, DeviceCollector
from ophyd_async.core.mock_signal_utils import (
assert_mock_put_called_with,
callback_on_mock_put,
get_mock_put,
mock_puts_blocked,
reset_mock_put_calls,
set_mock_put_proceeds,
set_mock_value,
set_mock_values,
)
from ophyd_async.core.signal import (
SignalW,
soft_signal_r_and_setter,
soft_signal_rw,
)
from ophyd_async.core.signal import SignalW, soft_signal_r_and_setter, soft_signal_rw
from ophyd_async.core.soft_signal_backend import SoftSignalBackend
from ophyd_async.epics.signal.signal import epics_signal_r, epics_signal_rw

Expand All @@ -31,6 +27,7 @@ async def test_mock_signal_backend(connect_mock_mode):
# If mock is false it will be handled like a normal signal, otherwise it will
# initalize a new backend from the one in the line above
await mock_signal.connect(mock=connect_mock_mode)
assert isinstance(mock_signal._backend, MockSignalBackend)

assert await mock_signal._backend.get_value() == ""
await mock_signal._backend.put("test")
Expand Down Expand Up @@ -74,6 +71,8 @@ async def test_set_mock_put_proceeds():
mock_signal = SignalW(SoftSignalBackend(str))
await mock_signal.connect(mock=True)

assert isinstance(mock_signal._backend, MockSignalBackend)

assert mock_signal._backend.put_proceeds.is_set() is True

set_mock_put_proceeds(mock_signal, False)
Expand All @@ -95,6 +94,7 @@ async def test_set_mock_put_proceeds_timeout():
async def test_put_proceeds_timeout():
mock_signal = SignalW(SoftSignalBackend(str))
await mock_signal.connect(mock=True)
assert isinstance(mock_signal._backend, MockSignalBackend)

assert mock_signal._backend.put_proceeds.is_set() is True

Expand All @@ -112,14 +112,14 @@ async def test_mock_utils_throw_error_if_backend_isnt_mock_signal_backend():
set_mock_value(signal, 10)
exc_msgs.append(str(exc.value))
with pytest.raises(AssertionError) as exc:
assert_mock_put_called_with(signal, 10)
get_mock_put(signal).assert_called_once_with(10)
exc_msgs.append(str(exc.value))
with pytest.raises(AssertionError) as exc:
async with mock_puts_blocked(signal, 10):
async with mock_puts_blocked(signal):
...
exc_msgs.append(str(exc.value))
with pytest.raises(AssertionError) as exc:
with callback_on_mock_put(signal, 10):
with callback_on_mock_put(signal, lambda x: _):
...
exc_msgs.append(str(exc.value))
with pytest.raises(AssertionError) as exc:
Expand All @@ -137,16 +137,13 @@ async def test_mock_utils_throw_error_if_backend_isnt_mock_signal_backend():
)


async def test_assert_mock_put_called_with():
async def test_get_mock_put():
mock_signal = epics_signal_rw(str, "READ_PV", "WRITE_PV", name="mock_name")
await mock_signal.connect(mock=True)
await mock_signal.set("test_value", wait=True, timeout=100)

# can leave out kwargs
assert_mock_put_called_with(mock_signal, "test_value")
assert_mock_put_called_with(mock_signal, "test_value", wait=True)
assert_mock_put_called_with(mock_signal, "test_value", timeout=100)
assert_mock_put_called_with(mock_signal, "test_value", wait=True, timeout=100)
mock = get_mock_put(mock_signal)
mock.assert_called_once_with("test_value", wait=True, timeout=100)

def err_text(text, wait, timeout):
return (
Expand All @@ -162,7 +159,7 @@ def err_text(text, wait, timeout):
("test_value", False, 0), # wait and timeout wrong
]:
with pytest.raises(AssertionError) as exc:
assert_mock_put_called_with(mock_signal, text, wait=wait, timeout=timeout)
mock.assert_called_once_with(text, wait=wait, timeout=timeout)
for err_substr in err_text(text, wait, timeout):
assert err_substr in str(exc.value)

Expand Down Expand Up @@ -216,10 +213,8 @@ async def test_callback_on_mock_put_no_ctx():
mock_signal = SignalRW(SoftSignalBackend(float))
await mock_signal.connect(mock=True)
calls = []
(
callback_on_mock_put(
mock_signal, lambda *args, **kwargs: calls.append({**kwargs, "_args": args})
),
callback_on_mock_put(
mock_signal, lambda *args, **kwargs: calls.append({**kwargs, "_args": args})
)
await mock_signal.set(10.0)
assert calls == [
Expand Down Expand Up @@ -249,16 +244,16 @@ def some_function_without_kwargs(arg):
async def test_set_mock_values(mock_signals):
signal1, signal2 = mock_signals

await signal2.get_value() == "first_value"
assert await signal2.get_value() == "first_value"
for value_set in set_mock_values(signal1, ["second_value", "third_value"]):
assert await signal1.get_value() == value_set

iterator = set_mock_values(signal2, ["second_value", "third_value"])
await signal2.get_value() == "first_value"
assert await signal2.get_value() == "first_value"
next(iterator)
await signal2.get_value() == "second_value"
assert await signal2.get_value() == "second_value"
next(iterator)
await signal2.get_value() == "third_value"
assert await signal2.get_value() == "third_value"


async def test_set_mock_values_exhausted_passes(mock_signals):
Expand Down Expand Up @@ -300,10 +295,10 @@ async def test_set_mock_values_exhausted_fails(mock_signals):
async def test_reset_mock_put_calls(mock_signals):
signal1, signal2 = mock_signals
await signal1.set("test_value", wait=True, timeout=1)
assert_mock_put_called_with(signal1, "test_value")
get_mock_put(signal1).assert_called_with("test_value", wait=ANY, timeout=ANY)
reset_mock_put_calls(signal1)
with pytest.raises(AssertionError) as exc:
assert_mock_put_called_with(signal1, "test_value")
get_mock_put(signal1).assert_called_with("test_value", wait=ANY, timeout=ANY)
# Replacing spaces because they change between runners
# (e.g the github actions runner has more)
assert str(exc.value).replace(" ", "").replace("\n", "") == (
Expand Down Expand Up @@ -350,3 +345,13 @@ async def set(self):
assert await signal.get_value() == 0
backend_put(100)
assert await signal.get_value() == 100


async def test_when_put_mock_called_with_typo_then_fails_but_calling_directly_passes():
mock_signal = SignalRW(SoftSignalBackend(int))
await mock_signal.connect(mock=True)
assert isinstance(mock_signal._backend, MockSignalBackend)
mock = mock_signal._backend.put_mock
with pytest.raises(AttributeError):
mock.asssert_called_once() # Note typo here is deliberate!
mock()
24 changes: 24 additions & 0 deletions tests/epics/demo/test_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
assert_reading,
assert_value,
callback_on_mock_put,
get_mock_put,
set_mock_value,
)
from ophyd_async.epics import demo
Expand Down Expand Up @@ -176,6 +177,29 @@ async def test_sensor_reading_shows_value(mock_sensor: demo.Sensor):
)


async def test_retrieve_mock_and_assert(mock_mover: demo.Mover):
mover_setpoint_mock = get_mock_put(mock_mover.setpoint)
await mock_mover.setpoint.set(10)
mover_setpoint_mock.assert_called_once_with(10, wait=ANY, timeout=ANY)

# Assert that velocity is set before move
mover_velocity_mock = get_mock_put(mock_mover.velocity)

parent_mock = Mock()
parent_mock.attach_mock(mover_setpoint_mock, "setpoint")
parent_mock.attach_mock(mover_velocity_mock, "velocity")

await mock_mover.velocity.set(100)
await mock_mover.setpoint.set(67)

parent_mock.assert_has_calls(
[
call.velocity(100, wait=True, timeout=ANY),
call.setpoint(67, wait=True, timeout=ANY),
]
)


async def test_read_mover(mock_mover: demo.Mover):
await mock_mover.stage()
assert (await mock_mover.read())["mock_mover"]["value"] == 0.0
Expand Down

0 comments on commit ee65c72

Please sign in to comment.