Skip to content

Commit

Permalink
Add representers for all channel access types (#66)
Browse files Browse the repository at this point in the history
* Add representers for all channel access types

* Black format

* Ignore type hints

---------

Co-authored-by: Abigail Emery <abigail.emery@diamond.ac.uk>
  • Loading branch information
rosesyrett and abbiemery committed Nov 13, 2023
1 parent f4213a4 commit 2570efd
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 11 deletions.
31 changes: 30 additions & 1 deletion src/ophyd_async/core/device_save_loader.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,40 @@
from enum import Enum
from typing import Any, Dict, Generator, List, Optional, Sequence
from functools import partial
from typing import Any, Callable, Dict, Generator, List, Optional, Sequence, Union

import numpy as np
import numpy.typing as npt
import yaml
from bluesky.plan_stubs import abs_set, wait
from bluesky.protocols import Location
from bluesky.utils import Msg
from epicscorelibs.ca.dbr import ca_array, ca_float, ca_int, ca_str

from .device import Device
from .signal import SignalRW

CaType = Union[ca_float, ca_int, ca_str, ca_array]


def ndarray_representer(dumper: yaml.Dumper, array: npt.NDArray[Any]) -> yaml.Node:
return dumper.represent_sequence(
"tag:yaml.org,2002:seq", array.tolist(), flow_style=True
)


def ca_dbr_representer(dumper: yaml.Dumper, value: CaType) -> yaml.Node:
# if it's an array, just call ndarray_representer...
represent_array = partial(ndarray_representer, dumper)

representers: Dict[CaType, Callable[[CaType], yaml.Node]] = {
ca_float: dumper.represent_float,
ca_int: dumper.represent_int,
ca_str: dumper.represent_str,
ca_array: represent_array,
}
return representers[type(value)](value)


class OphydDumper(yaml.Dumper):
def represent_data(self, data: Any) -> Any:
if isinstance(data, Enum):
Expand Down Expand Up @@ -59,6 +76,11 @@ def get_signal_values(
key: signal for key, signal in signals.items() if key not in ignore
}
selected_values = yield Msg("locate", *selected_signals.values())

# TODO: investigate wrong type hints
if isinstance(selected_values, dict):
selected_values = [selected_values] # type: ignore

assert selected_values is not None, "No signalRW's were able to be located"
named_values = {
key: value["setpoint"] for key, value in zip(selected_signals, selected_values)
Expand Down Expand Up @@ -128,7 +150,14 @@ def save_to_yaml(phases: Sequence[Dict[str, Any]], save_path: str) -> None:
:func:`ophyd_async.core.get_signal_values`
:func:`ophyd_async.core.load_from_yaml`
"""

yaml.add_representer(np.ndarray, ndarray_representer, Dumper=yaml.Dumper)

yaml.add_representer(ca_float, ca_dbr_representer, Dumper=yaml.Dumper)
yaml.add_representer(ca_int, ca_dbr_representer, Dumper=yaml.Dumper)
yaml.add_representer(ca_str, ca_dbr_representer, Dumper=yaml.Dumper)
yaml.add_representer(ca_array, ca_dbr_representer, Dumper=yaml.Dumper)

with open(save_path, "w") as file:
yaml.dump(phases, file, Dumper=OphydDumper, default_flow_style=False)

Expand Down
14 changes: 5 additions & 9 deletions tests/core/test_device_save_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,7 @@ async def test_enum_yaml_formatting(tmp_path):
assert saved_enums == enums


async def test_save_device(device, tmp_path):
RE = RunEngine()

async def test_save_device(RE: RunEngine, device, tmp_path):
# Populate fake device with PV's...
await device.child1.sig1.set("test_string")
# Test tables PVs
Expand Down Expand Up @@ -127,9 +125,8 @@ def save_my_device():
assert yaml_content["parent_sig1"] is None


async def test_yaml_formatting(device, tmp_path):
async def test_yaml_formatting(RE: RunEngine, device, tmp_path):
file_path = path.join(tmp_path, "test_file.yaml")
RE = RunEngine()
await device.child1.sig1.set("test_string")
table_pv = {"VAL1": np.array([1, 2, 3, 4, 5]), "VAL2": np.array([6, 7, 8, 9, 10])}
await device.child2.sig1.set(table_pv)
Expand All @@ -145,9 +142,9 @@ async def test_yaml_formatting(device, tmp_path):
assert file.read() == expected


async def test_load_from_yaml(device, tmp_path):
async def test_load_from_yaml(RE: RunEngine, device, tmp_path):
file_path = path.join(tmp_path, "test_file.yaml")
RE = RunEngine()

array = np.array([1, 1, 1, 1, 1])
await device.child1.sig1.set("initial_string")
await device.child2.sig1.set(array)
Expand All @@ -158,9 +155,8 @@ async def test_load_from_yaml(device, tmp_path):
assert np.array_equal(values[1]["child2.sig1"], array)


async def test_set_signal_values_restores_value(device, tmp_path):
async def test_set_signal_values_restores_value(RE: RunEngine, device, tmp_path):
file_path = path.join(tmp_path, "test_file.yaml")
RE = RunEngine()

await device.child1.sig1.set("initial_string")
await device.child2.sig1.set(np.array([1, 1, 1, 1, 1]))
Expand Down
21 changes: 20 additions & 1 deletion tests/epics/test_signals.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio
import os
import random
import re
import string
Expand All @@ -17,14 +18,27 @@
from aioca import purge_channel_caches
from bluesky.protocols import Reading

from ophyd_async.core import NotConnected, SignalBackend, T, get_dtype
from ophyd_async.core import (
NotConnected,
SignalBackend,
T,
get_dtype,
load_from_yaml,
save_to_yaml,
)
from ophyd_async.epics.signal._epics_transport import EpicsTransport
from ophyd_async.epics.signal.signal import _make_backend

RECORDS = str(Path(__file__).parent / "test_records.db")
PV_PREFIX = "".join(random.choice(string.ascii_lowercase) for _ in range(12))


@pytest.fixture
def _ensure_removed():
yield
os.remove("test.yaml")


@dataclass
class IOC:
process: subprocess.Popen
Expand Down Expand Up @@ -183,6 +197,7 @@ def waveform_d(value):
],
)
async def test_backend_get_put_monitor(
_ensure_removed: None,
ioc: IOC,
datatype: Type[T],
suffix: str,
Expand Down Expand Up @@ -210,6 +225,10 @@ async def test_backend_get_put_monitor(
ioc, suffix, descriptor(put_value), put_value, initial_value, datatype=None
)

save_to_yaml([{"test": put_value}], "test.yaml")
loaded = load_from_yaml("test.yaml")
assert np.all(loaded[0]["test"] == put_value)


async def test_bool_conversion_of_enum(ioc: IOC) -> None:
await assert_monitor_then_put(
Expand Down

0 comments on commit 2570efd

Please sign in to comment.