From 0effff4b3e8119ac8d6503688a6f1a4e5a20855e Mon Sep 17 00:00:00 2001 From: appel_c Date: Thu, 23 Apr 2026 12:59:05 +0200 Subject: [PATCH 01/13] feat: Add beamline state machine, and AggregatedState --- bec_lib/bec_lib/bl_state_machine.py | 132 ++++++++++++++ bec_lib/bec_lib/bl_states.py | 270 +++++++++++++++++++++++++++- bec_lib/bec_lib/client.py | 3 + state_config.yaml | 25 +++ 4 files changed, 429 insertions(+), 1 deletion(-) create mode 100644 bec_lib/bec_lib/bl_state_machine.py create mode 100644 state_config.yaml diff --git a/bec_lib/bec_lib/bl_state_machine.py b/bec_lib/bec_lib/bl_state_machine.py new file mode 100644 index 000000000..d808b7a33 --- /dev/null +++ b/bec_lib/bec_lib/bl_state_machine.py @@ -0,0 +1,132 @@ +""" +Module for managing aggregated beamline states based on configuration files. + +Example of the YAML configuration file: +``` yaml +alignment: + devices: + samx: + readback: + value: 0 + abs_tol: 0.1 + measurement: + devices: + samx: + readback: + value: 19 + abs_tol: 0.1 + velocity: + value: 5 + abs_tol: 0.1 + samy: + readback: + value: 0 + abs_tol: 0.1 + test: + devices: + samy: + readback: + value: 0 + abs_tol: 0.1 +``` + +""" + +from __future__ import annotations + +import yaml + +from bec_lib.bl_state_manager import BeamlineStateManager +from bec_lib.bl_states import AggregatedStateConfig + + +class BeamlineStateMachine: + + def __init__(self, manager: BeamlineStateManager) -> None: + self._manager = manager + self._configs: dict[str, AggregatedStateConfig] = {} + + def load_from_config( + self, name: str, config_path: str | None, config_dict: dict | None = None + ) -> None: + """ + Load an aggregated state configuration from a YAML file or a dictionary. If None or both are provided, + and error will be raised. + + Args: + name (str): The name of the aggregated state to create. + config_path (str | None): The path to the YAML configuration file. + config_dict (dict | None): A dictionary containing the configuration. If provided, this will be used instead of loading from a file. + + Example of the YAML configuration file: + ``` yaml + alignment: + devices: + samx: + readback: + value: 0 + abs_tol: 0.1 + measurement: + devices: + samx: + readback: + value: 19 + abs_tol: 0.1 + velocity: + value: 5 + abs_tol: 0.1 + samy: + readback: + value: 0 + abs_tol: 0.1 + test: + devices: + samy: + readback: + value: 0 + abs_tol: 0.1 + ``` + """ + self._check_inputs(config_path=config_path, config_dict=config_dict) + if config_path: + with open(config_path, "r", encoding="utf-8") as f: + config_dict = yaml.safe_load(f) + + config = AggregatedStateConfig(name=name, states=config_dict) + self._manager.add(config) + + def update_config( + self, + name: str, + config_path: str | None, + config_dict: dict | AggregatedStateConfig | None = None, + ) -> None: + """ + Update an existing aggregated state configuration from a YAML file or a dictionary. + If None or both are provided, and error will be raised. + It will update the state based on the configuration and update it in the state_manager. + + Args: + name (str): The name of the aggregated state to update. + config_path (str | None): The path to the YAML configuration file. + config_dict (dict | None): A dictionary containing the configuration. If provided, this will + be used instead of loading from a file. + """ + self._check_inputs(config_path=config_path, config_dict=config_dict) + # pylint: disable=protected-access + if name not in self._manager._states: + raise ValueError(f"Configuration for name {name} not found.") + if config_path: + with open(config_path, "r", encoding="utf-8") as f: + config_dict = yaml.safe_load(f) + # Load the new state + config = AggregatedStateConfig(name=name, states=config_dict) + self._manager.update(config) + + def _check_inputs( + self, config_path: str | None, config_dict: dict | AggregatedStateConfig | None + ) -> None: + if (config_path is None and config_dict is None) or ( + config_path is not None and config_dict is not None + ): + raise ValueError("Either config_path or config_dict must be provided, but not both.") diff --git a/bec_lib/bec_lib/bl_states.py b/bec_lib/bec_lib/bl_states.py index 8d32959c2..b567578cb 100644 --- a/bec_lib/bec_lib/bl_states.py +++ b/bec_lib/bec_lib/bl_states.py @@ -1,11 +1,15 @@ +"""Module defining beamline states and their evaluation logic.""" + from __future__ import annotations import functools import keyword import traceback from abc import ABC, abstractmethod -from typing import Callable, ClassVar, Generic, Type, TypeVar, cast +from dataclasses import dataclass +from typing import Any, Callable, ClassVar, Generic, Literal, Type, TypeVar, cast +import yaml from pydantic import BaseModel, field_validator, model_validator from bec_lib import messages @@ -121,6 +125,28 @@ class DeviceWithinLimitsStateConfig(DeviceStateConfig): tolerance: float = 0.1 +class SignalConfig(BaseModel): + """Target value for a signal inside a named machine state.""" + + value: float | int | str | bool + abs_tol: float = 0.0 + + +class SubDeviceStates(BaseModel): + + devices: dict[str, dict[str, SignalConfig]] + + +class AggregatedStateConfig(BeamlineStateConfig): + """ + Configuration for a state machine driven by multiple device signals. + """ + + state_type: ClassVar[str] = "AggregatedState" + + states: dict[str, SubDeviceStates] + + C = TypeVar("C", bound=BeamlineStateConfig) D = TypeVar("D", bound=DeviceStateConfig) @@ -322,6 +348,248 @@ def _update_device_state(self, msg_obj: MessageObject) -> messages.BeamlineState return self.evaluate(msg) +SignalSource = TypeVar("SignalSource", bound=Literal["readback", "configuration", "limits"]) + + +@dataclass(frozen=True) +class _ResolvedStateSignal: + label: str + device_name: str + signal_name: str + expected_value: float | int | str | bool + abs_tolerance: float | int + source: SignalSource + + +class AggregatedState(BeamlineState[AggregatedStateConfig]): + """Beamline state that infers the current named state from multiple device signals.""" + + CONFIG_CLASS = AggregatedStateConfig + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + # Mapping from signal updates to affected state labels, used for efficient evaluation when a signal update is received + self._signal_info_to_labels: dict[tuple[str, SignalSource, str], set[str]] = {} + # Mapping from state labels to the list of signal requirements that define that state + self._requirements_for_label: dict[str, list[_ResolvedStateSignal]] = {} + # Set of subscriptions to signal updates + self._subscriptions: set[tuple[str, SignalSource]] = set() + # Cache of the latest signal values + self._signal_value_cache: dict[tuple[str, SignalSource, str], Any] = {} + # List of currently active state labels + self._current_labels: list[str] = [] + + @staticmethod + def _endpoint(device: str, source: SignalSource): + """Static method to get the appropriate message endpoint based on the signal source.""" + if source == "readback": + return MessageEndpoints.device_readback(device) + if source == "configuration": + return MessageEndpoints.device_read_configuration(device) + if source == "limits": + return MessageEndpoints.device_limits(device) + raise ValueError( + f"Invalid signal source '{source}', please use 'readback', 'configuration', or 'limits'." + ) + + def _get_devices(self): + if self.device_manager is None: + # pylint: disable=import-outside-toplevel + from bec_lib.client import BECClient + + bec = BECClient() + return bec.device_manager.devices + return self.device_manager.devices + + def _get_signal_source(self, signal_info: dict[str, Any]) -> SignalSource: + kind_str = str(signal_info.get("kind_str", "")).lower() + if "hinted" in kind_str or "normal" in kind_str: + return "readback" + if "config" in kind_str: + return "configuration" + raise ValueError( + f"{self._error_prefix} Unsupported kind: '{kind_str}' for signal : \n {yaml.dump(signal_info, indent=4)}" + ) + + def _resolve_signal(self, device_name: str, signal_name: str) -> tuple[str, SignalSource]: + devices = self._get_devices() + try: + device_obj: DeviceBase = devices[device_name] + except KeyError: + raise ValueError(f"{self._error_prefix} Device '{device_name}' not found.") from None + + # Special handling for limits, as they are not regular signals. + if signal_name in ["low_limit", "low_limit_travel"]: + return "low", "limits" + if signal_name in ["high_limit", "high_limit_travel"]: + return "high", "limits" + + signal_info = None + if "." in signal_name: + try: + signal_obj = devices[signal_name] + except AttributeError: + raise ValueError( + f"{self._error_prefix} Signal '{signal_name}' not found for device '{device_name}'." + ) from None + if signal_obj.parent != device_obj: + raise ValueError( + f"{self._error_prefix} Signal '{signal_name}' does not belong to device '{device_name}'." + ) + signal_component = ".".join(signal_name.split(".")[1:]) + signal_info = device_obj.root._info["signals"].get(signal_component) + else: + signal_info = device_obj.root._info["signals"].get(signal_name) + if signal_info is None: + for candidate in device_obj.root._info["signals"].values(): + if candidate.get("obj_name") == signal_name: + signal_info = candidate + break + + if signal_info is None: + raise ValueError( + f"{self._error_prefix} Signal '{signal_name}' not found for device '{device_name}'." + ) + + obj_name = signal_info.get("obj_name") + signal_source = self._get_signal_source(signal_info) + return obj_name, signal_source + + def _build_rules(self) -> None: + self._signal_info_to_labels.clear() + self._requirements_for_label.clear() + self._subscriptions.clear() + for label, device_configs in self.config.states.items(): + state_requirements: list[_ResolvedStateSignal] = [] + for device_name, signal_configs in device_configs.devices.items(): + for signal_name, target in signal_configs.items(): + resolved_signal_name, source = self._resolve_signal(device_name, signal_name) + state_requirements.append( + _ResolvedStateSignal( + label=label, + device_name=device_name, + signal_name=resolved_signal_name, + expected_value=target.value, + abs_tolerance=target.abs_tol, + source=source, + ) + ) + self._subscriptions.add((device_name, source)) + self._signal_info_to_labels.setdefault( + (device_name, source, resolved_signal_name), set() + ).add(label) + self._requirements_for_label[label] = state_requirements + + def start(self) -> None: + if self.started: + return + + if self.connector is None: + raise RuntimeError("Redis connector is not set.") + + try: + self._build_rules() + affected_labels = self._fill_cache() + except Exception as exc: + self._handle_state_exception(exc) + + msg = self.evaluate(affected_labels=affected_labels) + if msg is not None: + self._emit_state(msg) + for device, source in self._subscriptions: + self.connector.register( + self._endpoint(device, source), + cb=self._update_aggregated_state, + device=device, + source=source, + ) + super().start() + + def _fill_cache(self) -> set[str]: + affected_labels: set[str] = set() + for device, source in self._subscriptions: + endpoint = self._endpoint(device, source) + msg = self.connector.get(endpoint) + if msg is not None: + affected_labels.update(self._cache_message(device, source, msg)) + return affected_labels + + def _cache_message( + self, device: str, source: SignalSource, msg: messages.DeviceMessage + ) -> set[str]: + affected_labels: set[str] = set() + for signal_name, signal_data in msg.signals.items(): + key = (device, source, signal_name) + labels = self._signal_info_to_labels.get(key) + if labels is None: # signal not relevant for any state + continue + self._signal_value_cache[key] = signal_data.get("value") + affected_labels.update(labels) + return affected_labels + + def stop(self) -> None: + if not self.started: + return + if self.connector is not None: + for device, source in self._subscriptions: + self.connector.unregister( + self._endpoint(device, source), cb=self._update_aggregated_state + ) + super().stop() + + def _update_aggregated_state( + self, msg_obj: MessageObject, device: str, source: SignalSource, **_kwargs + ) -> None: + try: + msg: messages.DeviceMessage = msg_obj.value # type: ignore ; we know it's a DeviceMessage + affected_labels = self._cache_message(device, source, msg) + if affected_labels: + msg = self.evaluate(affected_labels=affected_labels) + if msg is not None: + self._emit_state(msg) + except Exception as exc: + self._handle_state_exception(exc) + + def evaluate( + self, affected_labels: set[str] | None = None + ) -> messages.BeamlineStateMessage | None: + if affected_labels is None: + return None + # We need to always extend the affected labels with the current labels, + # as the signal that updated might be not relevant for the currently active state, + # but the state should still be checked for validity. + affected_labels.update(self._current_labels) + matching_labels = [label for label in affected_labels if self._label_matches(label)] + if matching_labels: + self._current_labels = matching_labels + state_msg = messages.BeamlineStateMessage( + name=self.config.name, status="valid", label="|".join(matching_labels) + ) + return state_msg + + self._current_labels = [] + state_msg = messages.BeamlineStateMessage( + name=self.config.name, status="invalid", label="No matching state" + ) + return state_msg + + def _label_matches(self, label: str) -> bool: + requirements = self._requirements_for_label.get(label, []) + return bool(requirements) and all( + self._requirement_matches(requirement) for requirement in requirements + ) + + def _requirement_matches(self, requirement: _ResolvedStateSignal) -> bool: + key = (requirement.device_name, requirement.source, requirement.signal_name) + value = self._signal_value_cache.get(key, None) + if value is None: + return False + try: + return abs(value - requirement.expected_value) <= requirement.abs_tolerance + except TypeError: + return value == requirement.expected_value + + class ShutterState(DeviceBeamlineState[DeviceStateConfig]): """ A state that checks if the shutter is open. diff --git a/bec_lib/bec_lib/client.py b/bec_lib/bec_lib/client.py index 689f41f1b..e3cbc952b 100644 --- a/bec_lib/bec_lib/client.py +++ b/bec_lib/bec_lib/client.py @@ -20,6 +20,7 @@ from bec_lib.alarm_handler import AlarmHandler, Alarms from bec_lib.bec_service import BECService +from bec_lib.bl_state_machine import BeamlineStateMachine from bec_lib.bl_state_manager import BeamlineStateManager from bec_lib.callback_handler import CallbackHandler, EventType from bec_lib.config_helper import ConfigHelperUser @@ -162,6 +163,7 @@ def __init__( self._username = "" self._system_user = "" self.beamline_states = None + self.state_machine = None self.messaging: MessagingContainer = None # type: ignore def __new__(cls, *args, forced=False, **kwargs): @@ -241,6 +243,7 @@ def _start_services(self): self.device_monitor = DeviceMonitorPlugin(self.connector) self._update_username() self.beamline_states = BeamlineStateManager(client=self) + self.state_machine = BeamlineStateMachine(manager=self.beamline_states) def alarms(self, severity=Alarms.WARNING): """get the next alarm with at least the specified severity""" diff --git a/state_config.yaml b/state_config.yaml new file mode 100644 index 000000000..c151cffd6 --- /dev/null +++ b/state_config.yaml @@ -0,0 +1,25 @@ +alignment: # AggregatedStateConfig -> can have different labels and for each label, different devices + devices: + samx: + readback: + value: 0 + abs_tol: 0.1 + velocity: + value: 0 + abs_tol: 0.1 +measurement: + devices: + samx: + readback: + value: 19 + abs_tol: 0.1 + samy: + readback: + value: 0 + abs_tol: 0.1 +test: + devices: + samy: + readback: + value: 0 + abs_tol: 0.1 \ No newline at end of file From aa155be31c63755c59dd2a381268b317d8090074 Mon Sep 17 00:00:00 2001 From: appel_c Date: Thu, 23 Apr 2026 17:30:51 +0200 Subject: [PATCH 02/13] test: add tests for AggregatedState --- bec_lib/bec_lib/bl_states.py | 3 +- bec_lib/tests/test_beamline_states.py | 215 ++++++++++++++++++++++++++ 2 files changed, 217 insertions(+), 1 deletion(-) diff --git a/bec_lib/bec_lib/bl_states.py b/bec_lib/bec_lib/bl_states.py index b567578cb..5e60e8dfc 100644 --- a/bec_lib/bec_lib/bl_states.py +++ b/bec_lib/bec_lib/bl_states.py @@ -488,12 +488,13 @@ def start(self) -> None: raise RuntimeError("Redis connector is not set.") try: + msg = None self._build_rules() affected_labels = self._fill_cache() + msg = self.evaluate(affected_labels=affected_labels) except Exception as exc: self._handle_state_exception(exc) - msg = self.evaluate(affected_labels=affected_labels) if msg is not None: self._emit_state(msg) for device, source in self._subscriptions: diff --git a/bec_lib/tests/test_beamline_states.py b/bec_lib/tests/test_beamline_states.py index e6a936e8e..08913cfc2 100644 --- a/bec_lib/tests/test_beamline_states.py +++ b/bec_lib/tests/test_beamline_states.py @@ -196,6 +196,221 @@ def test_device_within_limits_state(self, connected_connector, dm_with_devices): assert state.evaluate(invalid).status == "invalid" assert state.evaluate(missing).status == "invalid" + @pytest.fixture(scope="function") + def aggregated_state_config(self): + return bl_states.AggregatedStateConfig( + name="alignment", + states={ + "alignment": { + "devices": { + "samx": { + "readback": {"value": 0, "abs_tol": 0.1}, + "velocity": {"value": 5, "abs_tol": 0.1}, + "low_limit": {"value": -20, "abs_tol": 0.1}, + "high_limit": {"value": 20, "abs_tol": 0.1}, + }, + "samy": {"readback": {"value": 0, "abs_tol": 0.1}}, + } + }, + "measurement": { + "devices": { + "samx": { + "readback": {"value": 19, "abs_tol": 0.1}, + "velocity": {"value": 5, "abs_tol": 0.1}, + "low_limit_travel": {"value": -20, "abs_tol": 0.1}, + "high_limit_travel": {"value": 20, "abs_tol": 0.1}, + }, + "samy": {"readback": {"value": 2, "abs_tol": 0.1}}, + } + }, + "test": {"devices": {"samy": {"readback": {"value": 0, "abs_tol": 0.1}}}}, + }, + ) + + def test_aggregated_state_init( + self, connected_connector, dm_with_devices, aggregated_state_config + ): + + state = bl_states.AggregatedState( + name=aggregated_state_config.name, + config=aggregated_state_config, + redis_connector=connected_connector, + device_manager=dm_with_devices, + ) + state.start() + # We should now have subscriptions on samx limits, readback and read_configuration, and samy readback + info = [ + MessageEndpoints.device_readback("samx"), + MessageEndpoints.device_read_configuration("samx"), + MessageEndpoints.device_limits("samx"), + MessageEndpoints.device_readback("samy"), + ] + for endpoint in info: + assert endpoint.endpoint in state.connector._topics_cb + + def test_aggregated_state_evaluation( + self, connected_connector, dm_with_devices, aggregated_state_config + ): + state = bl_states.AggregatedState( + name=aggregated_state_config.name, + config=aggregated_state_config, + redis_connector=connected_connector, + device_manager=dm_with_devices, + ) + state.start() + + with ( + mock.patch.object(state, "evaluate", return_value=None) as evaluate, + mock.patch.object(state, "_emit_state") as emit_state, + ): + + msg_with_2_states = messages.DeviceMessage( + signals={"samx": {"value": 5.0, "timestamp": 1.0}}, metadata={"stream": "primary"} + ) + msg_obj = MessageObject( + value=msg_with_2_states, topic=MessageEndpoints.device_readback("samx").endpoint + ) + state._update_aggregated_state(msg_obj, device="samx", source="readback") + evaluate.assert_called_once_with(affected_labels=set(["alignment", "measurement"])) + emit_state.assert_not_called() # As evaluate is mocked to return None, _emit_state should not be called + + def test_aggregated_state_evaluate( + self, connected_connector, dm_with_devices, aggregated_state_config + ): + state = bl_states.AggregatedState( + name=aggregated_state_config.name, + config=aggregated_state_config, + redis_connector=connected_connector, + device_manager=dm_with_devices, + ) + state._build_rules() + state._cache_message( + "samx", + "readback", + messages.DeviceMessage( + signals={"samx": {"value": 0, "timestamp": 1.0}}, metadata={"stream": "primary"} + ), + ) + state._cache_message( + "samx", + "configuration", + messages.DeviceMessage( + signals={"samx_velocity": {"value": 5, "timestamp": 1.0}}, + metadata={"stream": "baseline"}, + ), + ) + state._cache_message( + "samx", + "limits", + messages.DeviceMessage( + signals={ + "low": {"value": -20, "timestamp": 1.0}, + "high": {"value": 20, "timestamp": 1.0}, + }, + metadata={"stream": "baseline"}, + ), + ) + state._cache_message( + "samy", + "readback", + messages.DeviceMessage( + signals={"samy": {"value": 0, "timestamp": 1.0}}, metadata={"stream": "primary"} + ), + ) + + msg = state.evaluate(affected_labels={"alignment"}) + + assert msg.status == "valid" + assert msg.label == "alignment" + assert state._current_labels == ["alignment"] + + state._cache_message( + "samx", + "readback", + messages.DeviceMessage( + signals={"samx": {"value": 3, "timestamp": 2.0}}, metadata={"stream": "primary"} + ), + ) + + msg = state.evaluate(affected_labels={"alignment"}) + + assert msg.status == "invalid" + assert msg.label == "No matching state" + assert state._current_labels == [] + + def test_aggregated_state_exception_handling( + self, connected_connector, dm_with_devices, aggregated_state_config + ): + state = bl_states.AggregatedState( + name=aggregated_state_config.name, + config=aggregated_state_config, + redis_connector=connected_connector, + device_manager=dm_with_devices, + ) + state.start() + msg = messages.DeviceMessage( + signals={"samx": {"value": 0, "timestamp": 1.0}}, metadata={"stream": "primary"} + ) + msg_obj = MessageObject(value=msg, topic=MessageEndpoints.device_readback("samx").endpoint) + + with ( + mock.patch.object( + state, "evaluate", side_effect=RuntimeError("broken state") + ) as evaluate, + mock.patch.object(connected_connector, "raise_alarm") as raise_alarm, + ): + state._update_aggregated_state(msg_obj, device="samx", source="readback") + + evaluate.assert_called_once_with(affected_labels={"alignment", "measurement"}) + raise_alarm.assert_called_once() + out = connected_connector.xread( + MessageEndpoints.beamline_state("alignment"), from_start=True + ) + assert out[-1]["data"].status == "unknown" + assert out[-1]["data"].label == "broken state" + assert state.raised_warning is True + + def test_aggregated_state_transitions_between_labels( + self, connected_connector, dm_with_devices, aggregated_state_config + ): + state = bl_states.AggregatedState( + name=aggregated_state_config.name, + config=aggregated_state_config, + redis_connector=connected_connector, + device_manager=dm_with_devices, + ) + state.start() + + def update(device, source, signals): + msg = messages.DeviceMessage(signals=signals, metadata={"stream": "primary"}) + msg_obj = MessageObject(value=msg, topic=state._endpoint(device, source).endpoint) + state._update_aggregated_state(msg_obj, device=device, source=source) + out = connected_connector.xread( + MessageEndpoints.beamline_state("alignment"), from_start=True + ) + return out[-1]["data"] + + msg = update("samx", "configuration", {"samx_velocity": {"value": 5, "timestamp": 1.0}}) + assert msg.status == "invalid" + + update( + "samx", + "limits", + {"low": {"value": -20, "timestamp": 1.0}, "high": {"value": 20, "timestamp": 1.0}}, + ) + update("samx", "readback", {"samx": {"value": 0, "timestamp": 1.0}}) + msg = update("samy", "readback", {"samy": {"value": 0, "timestamp": 1.0}}) + assert msg.status == "valid" + assert set(msg.label.split("|")) == {"alignment", "test"} + + msg = update("samx", "readback", {"samx": {"value": 19, "timestamp": 2.0}}) + assert msg.status == "valid" + assert msg.label == "test" + + msg = update("samy", "readback", {"samy": {"value": 2, "timestamp": 2.0}}) + assert msg.status == "valid" + assert msg.label == "measurement" + class TestBeamlineStateManager: def test_manager_registers_for_state_updates(self, connected_connector): From 13036b994a5809f416a61c89bcfb286fe9bc76eb Mon Sep 17 00:00:00 2001 From: appel_c Date: Fri, 24 Apr 2026 08:36:25 +0200 Subject: [PATCH 03/13] refactor: Allow signals to be parsed in config for aggregated states, fix and improve tests. --- bec_lib/bec_lib/bl_state_machine.py | 62 ++---- bec_lib/bec_lib/bl_states.py | 36 +++- bec_lib/tests/test_beamline_states.py | 203 ++++++++++++++++-- .../scans/state_transition_scan.py | 193 +++++++++++++++++ state_config.yaml | 25 --- 5 files changed, 422 insertions(+), 97 deletions(-) create mode 100644 bec_server/bec_server/scan_server/scans/state_transition_scan.py delete mode 100644 state_config.yaml diff --git a/bec_lib/bec_lib/bl_state_machine.py b/bec_lib/bec_lib/bl_state_machine.py index d808b7a33..0b5601ee8 100644 --- a/bec_lib/bec_lib/bl_state_machine.py +++ b/bec_lib/bec_lib/bl_state_machine.py @@ -47,85 +47,47 @@ def __init__(self, manager: BeamlineStateManager) -> None: self._configs: dict[str, AggregatedStateConfig] = {} def load_from_config( - self, name: str, config_path: str | None, config_dict: dict | None = None + self, name: str, config_path: str | None = None, config_dict: dict | None = None ) -> None: """ - Load an aggregated state configuration from a YAML file or a dictionary. If None or both are provided, - and error will be raised. + Load a state configuration from a YAML file or a dictionary. If None or both are provided, + an error will be raised. Config must be states for an AggregatedStateConfig or a dictionary/YAML file that + can be parsed into one. Please check AggregatedStateConfig state field for the expected format of the configuration. Args: - name (str): The name of the aggregated state to create. + name (str): The name of the aggregated state to load. config_path (str | None): The path to the YAML configuration file. config_dict (dict | None): A dictionary containing the configuration. If provided, this will be used instead of loading from a file. - - Example of the YAML configuration file: - ``` yaml - alignment: - devices: - samx: - readback: - value: 0 - abs_tol: 0.1 - measurement: - devices: - samx: - readback: - value: 19 - abs_tol: 0.1 - velocity: - value: 5 - abs_tol: 0.1 - samy: - readback: - value: 0 - abs_tol: 0.1 - test: - devices: - samy: - readback: - value: 0 - abs_tol: 0.1 - ``` """ self._check_inputs(config_path=config_path, config_dict=config_dict) if config_path: with open(config_path, "r", encoding="utf-8") as f: config_dict = yaml.safe_load(f) - config = AggregatedStateConfig(name=name, states=config_dict) self._manager.add(config) def update_config( - self, - name: str, - config_path: str | None, - config_dict: dict | AggregatedStateConfig | None = None, + self, name: str, config_path: str | None = None, config_dict: dict | None = None ) -> None: """ - Update an existing aggregated state configuration from a YAML file or a dictionary. - If None or both are provided, and error will be raised. - It will update the state based on the configuration and update it in the state_manager. + Update a state configuration from a YAML file or a dictionary. If None or both are provided, + an error will be raised. Config must be states for an AggregatedStateConfig or a dictionary/YAML file that + can be parsed into one. Please check AggregatedStateConfig state field for the expected format of the configuration. Args: name (str): The name of the aggregated state to update. config_path (str | None): The path to the YAML configuration file. - config_dict (dict | None): A dictionary containing the configuration. If provided, this will - be used instead of loading from a file. + config_dict (dict | None): A dictionary containing the configuration. If provided, this will be used instead of loading from a file. """ self._check_inputs(config_path=config_path, config_dict=config_dict) - # pylint: disable=protected-access - if name not in self._manager._states: - raise ValueError(f"Configuration for name {name} not found.") if config_path: with open(config_path, "r", encoding="utf-8") as f: config_dict = yaml.safe_load(f) # Load the new state config = AggregatedStateConfig(name=name, states=config_dict) - self._manager.update(config) + self._manager._update_state(config) - def _check_inputs( - self, config_path: str | None, config_dict: dict | AggregatedStateConfig | None - ) -> None: + def _check_inputs(self, config_path: str | None, config_dict: dict | None) -> None: if (config_path is None and config_dict is None) or ( config_path is not None and config_dict is not None ): diff --git a/bec_lib/bec_lib/bl_states.py b/bec_lib/bec_lib/bl_states.py index 5e60e8dfc..299e7de81 100644 --- a/bec_lib/bec_lib/bl_states.py +++ b/bec_lib/bec_lib/bl_states.py @@ -132,7 +132,7 @@ class SignalConfig(BaseModel): abs_tol: float = 0.0 -class SubDeviceStates(BaseModel): +class SubDeviceStateConfig(BaseModel): devices: dict[str, dict[str, SignalConfig]] @@ -144,7 +144,7 @@ class AggregatedStateConfig(BeamlineStateConfig): state_type: ClassVar[str] = "AggregatedState" - states: dict[str, SubDeviceStates] + states: dict[str, SubDeviceStateConfig] C = TypeVar("C", bound=BeamlineStateConfig) @@ -425,7 +425,11 @@ def _resolve_signal(self, device_name: str, signal_name: str) -> tuple[str, Sign return "high", "limits" signal_info = None - if "." in signal_name: + # This case is relevant if we are looking at a Signal directly + if device_name == signal_name and len(device_obj.root._info["signals"]) == 0: + signal_info = {"obj_name": signal_name, "kind_str": "hinted"} + # Case where we have a signal specified as a dotted name, e.g. + elif "." in signal_name: try: signal_obj = devices[signal_name] except AttributeError: @@ -438,6 +442,7 @@ def _resolve_signal(self, device_name: str, signal_name: str) -> tuple[str, Sign ) signal_component = ".".join(signal_name.split(".")[1:]) signal_info = device_obj.root._info["signals"].get(signal_component) + # Case where the signal is specified as the signal else: signal_info = device_obj.root._info["signals"].get(signal_name) if signal_info is None: @@ -582,13 +587,28 @@ def _label_matches(self, label: str) -> bool: def _requirement_matches(self, requirement: _ResolvedStateSignal) -> bool: key = (requirement.device_name, requirement.source, requirement.signal_name) - value = self._signal_value_cache.get(key, None) - if value is None: + cached_value = self._signal_value_cache.get(key, None) + if cached_value is None: return False + try: - return abs(value - requirement.expected_value) <= requirement.abs_tolerance - except TypeError: - return value == requirement.expected_value + # Cast to float to make sure comparison with abs works as expected. + value = float(cached_value) + expected_value = float(requirement.expected_value) + return abs(value - expected_value) <= requirement.abs_tolerance + # Catch TypeError and ValueError in case the value is not a number or cannot be cast to float, + # in that case we fall back to exact equality. + except (TypeError, ValueError): + try: + result = cached_value == requirement.expected_value + except (TypeError, ValueError): + return False + # In case this comparison runs on comparing two arrays. + # We do not consider this comparsion as valid currently. + try: + return bool(result) + except (TypeError, ValueError): + return False class ShutterState(DeviceBeamlineState[DeviceStateConfig]): diff --git a/bec_lib/tests/test_beamline_states.py b/bec_lib/tests/test_beamline_states.py index 08913cfc2..6f671ec43 100644 --- a/bec_lib/tests/test_beamline_states.py +++ b/bec_lib/tests/test_beamline_states.py @@ -3,10 +3,13 @@ import inspect from unittest import mock +import numpy as np import pytest +import yaml from pydantic import BaseModel from bec_lib import bl_states, messages +from bec_lib.bl_state_machine import BeamlineStateMachine from bec_lib.bl_state_manager import ( BeamlineStateClientBase, BeamlineStateManager, @@ -198,6 +201,7 @@ def test_device_within_limits_state(self, connected_connector, dm_with_devices): @pytest.fixture(scope="function") def aggregated_state_config(self): + """Fixture for an test aggregated state configuration.""" return bl_states.AggregatedStateConfig( name="alignment", states={ @@ -209,7 +213,7 @@ def aggregated_state_config(self): "low_limit": {"value": -20, "abs_tol": 0.1}, "high_limit": {"value": 20, "abs_tol": 0.1}, }, - "samy": {"readback": {"value": 0, "abs_tol": 0.1}}, + "bpm4i": {"bpm4i": {"value": 0, "abs_tol": 0.1}}, } }, "measurement": { @@ -220,16 +224,24 @@ def aggregated_state_config(self): "low_limit_travel": {"value": -20, "abs_tol": 0.1}, "high_limit_travel": {"value": 20, "abs_tol": 0.1}, }, - "samy": {"readback": {"value": 2, "abs_tol": 0.1}}, + "bpm4i": {"bpm4i": {"value": 2, "abs_tol": 0.1}}, } }, - "test": {"devices": {"samy": {"readback": {"value": 0, "abs_tol": 0.1}}}}, + "test": {"devices": {"bpm4i": {"bpm4i": {"value": 0, "abs_tol": 0.1}}}}, + "string_state": {"devices": {"bpm3i": {"bpm3i": {"value": "ok"}}}}, }, ) - def test_aggregated_state_init( + def test_aggregated_state_init_and_start( self, connected_connector, dm_with_devices, aggregated_state_config ): + """ + Test the initialization of the AggregatedState. + + Based on the provided configuration, we expect certain callbacks to be registered with the + Redis connector. This test checks this which essentially checks the proper functionality + of the 'start' method. + """ state = bl_states.AggregatedState( name=aggregated_state_config.name, @@ -238,12 +250,13 @@ def test_aggregated_state_init( device_manager=dm_with_devices, ) state.start() - # We should now have subscriptions on samx limits, readback and read_configuration, and samy readback + # We should now have subscriptions on samx limits, readback and read_configuration, and bpm4i & bpm4i info = [ MessageEndpoints.device_readback("samx"), MessageEndpoints.device_read_configuration("samx"), MessageEndpoints.device_limits("samx"), - MessageEndpoints.device_readback("samy"), + MessageEndpoints.device_readback("bpm4i"), + MessageEndpoints.device_readback("bpm3i"), ] for endpoint in info: assert endpoint.endpoint in state.connector._topics_cb @@ -251,6 +264,10 @@ def test_aggregated_state_init( def test_aggregated_state_evaluation( self, connected_connector, dm_with_devices, aggregated_state_config ): + """ + Test the evaluation of the AggregatedState when receiving message updates. This should trigger a state evaluation for + the affected labels and the current state, and if the state changes, a new state should be published. + """ state = bl_states.AggregatedState( name=aggregated_state_config.name, config=aggregated_state_config, @@ -263,9 +280,10 @@ def test_aggregated_state_evaluation( mock.patch.object(state, "evaluate", return_value=None) as evaluate, mock.patch.object(state, "_emit_state") as emit_state, ): - + # Test triggering evaluation for multiple labels + # samx affects alignment and measurement, so both should be evaluated. msg_with_2_states = messages.DeviceMessage( - signals={"samx": {"value": 5.0, "timestamp": 1.0}}, metadata={"stream": "primary"} + signals={"samx": {"value": 5.0, "timestamp": 1.0}} ) msg_obj = MessageObject( value=msg_with_2_states, topic=MessageEndpoints.device_readback("samx").endpoint @@ -277,6 +295,11 @@ def test_aggregated_state_evaluation( def test_aggregated_state_evaluate( self, connected_connector, dm_with_devices, aggregated_state_config ): + """ + Test the evaluate method. + We manually cache the relevant messages and then call evaluate with the affected label. + We then check if the output message has the expected status and label, and if the current labels are updated correctly. + """ state = bl_states.AggregatedState( name=aggregated_state_config.name, config=aggregated_state_config, @@ -284,6 +307,8 @@ def test_aggregated_state_evaluate( device_manager=dm_with_devices, ) state._build_rules() + # Assume that we are currently in test + state._current_labels = ["test"] state._cache_message( "samx", "readback", @@ -311,18 +336,19 @@ def test_aggregated_state_evaluate( ), ) state._cache_message( - "samy", + "bpm4i", "readback", messages.DeviceMessage( - signals={"samy": {"value": 0, "timestamp": 1.0}}, metadata={"stream": "primary"} + signals={"bpm4i": {"value": 0, "timestamp": 1.0}}, metadata={"stream": "primary"} ), ) msg = state.evaluate(affected_labels={"alignment"}) assert msg.status == "valid" - assert msg.label == "alignment" - assert state._current_labels == ["alignment"] + # The order of the labels is not guaranteed + assert msg.label in ["alignment|test", "test|alignment"] + assert set(state._current_labels) == set(["alignment", "test"]) state._cache_message( "samx", @@ -334,6 +360,20 @@ def test_aggregated_state_evaluate( msg = state.evaluate(affected_labels={"alignment"}) + assert msg.status == "valid" + assert msg.label == "test" + assert state._current_labels == ["test"] + + state._cache_message( + "bpm4i", + "readback", + messages.DeviceMessage( + signals={"bpm4i": {"value": 2, "timestamp": 2.0}}, metadata={"stream": "primary"} + ), + ) + + msg = state.evaluate(affected_labels={"alignment", "test", "measurement"}) + assert msg.status == "invalid" assert msg.label == "No matching state" assert state._current_labels == [] @@ -341,6 +381,11 @@ def test_aggregated_state_evaluate( def test_aggregated_state_exception_handling( self, connected_connector, dm_with_devices, aggregated_state_config ): + """ + Test that if an exception is raised during the evaluation of the state, this is properly handled and an alarm is raised. + We check that the evaluate method is called and that if it raises an exception, the raise_alarm method of the connector + is called, and a state with status "unknown" and label "broken state" is published. + """ state = bl_states.AggregatedState( name=aggregated_state_config.name, config=aggregated_state_config, @@ -373,6 +418,10 @@ def test_aggregated_state_exception_handling( def test_aggregated_state_transitions_between_labels( self, connected_connector, dm_with_devices, aggregated_state_config ): + """ + Test the transitions between different labels of the aggregated state. We simulate the messages that would trigger + the transitions and check that the output message has the expected status and label, and that the current labels are updated correctly. + """ state = bl_states.AggregatedState( name=aggregated_state_config.name, config=aggregated_state_config, @@ -399,7 +448,7 @@ def update(device, source, signals): {"low": {"value": -20, "timestamp": 1.0}, "high": {"value": 20, "timestamp": 1.0}}, ) update("samx", "readback", {"samx": {"value": 0, "timestamp": 1.0}}) - msg = update("samy", "readback", {"samy": {"value": 0, "timestamp": 1.0}}) + msg = update("bpm4i", "readback", {"bpm4i": {"value": 0, "timestamp": 1.0}}) assert msg.status == "valid" assert set(msg.label.split("|")) == {"alignment", "test"} @@ -407,10 +456,57 @@ def update(device, source, signals): assert msg.status == "valid" assert msg.label == "test" - msg = update("samy", "readback", {"samy": {"value": 2, "timestamp": 2.0}}) + msg = update("bpm4i", "readback", {"bpm4i": {"value": 2, "timestamp": 2.0}}) assert msg.status == "valid" assert msg.label == "measurement" + @pytest.mark.parametrize( + ("cached_value", "expected_value", "abs_tolerance", "matches"), + [ + (1.05, 1.0, 0.1, True), + (1.2, 1.0, 0.1, False), + (5, 5, 0.0, True), + (np.int64(5), 5, 0.0, True), + (np.float64(1.05), 1.0, 0.1, True), + ("ok", "ok", 0.0, True), + ("not-ok", "ok", 0.0, False), + ([1, 2], 1, 0.0, False), + (np.array([1.0, 2.0]), 1.0, 0.1, False), + (np.array([1.0, 2.0]), np.array([1.0, 2.0]), 0.0, False), + ], + ) + def test_aggregated_state_requirement_matches( + self, + connected_connector, + dm_with_devices, + aggregated_state_config, + cached_value, + expected_value, + abs_tolerance, + matches, + ): + """ + Test the evaluation of requirements in the aggregated state. We manually set the signal value + cache and then call the _requirement_matches method with a requirement, and check if the output is as expected. + """ + state = bl_states.AggregatedState( + name=aggregated_state_config.name, + config=aggregated_state_config, + redis_connector=connected_connector, + device_manager=dm_with_devices, + ) + requirement = bl_states._ResolvedStateSignal( + label="alignment", + device_name="bpm4i", + signal_name="bpm4i", + expected_value=expected_value, + abs_tolerance=abs_tolerance, + source="readback", + ) + state._signal_value_cache[("bpm4i", "readback", "bpm4i")] = cached_value + + assert state._requirement_matches(requirement) is matches + class TestBeamlineStateManager: def test_manager_registers_for_state_updates(self, connected_connector): @@ -536,3 +632,82 @@ def test_show_all_prints_table(self, state_manager, capsys): captured = capsys.readouterr() assert "shutter_open" in (captured.out + captured.err) + + +class TestStateMachine: + + @pytest.fixture() + def state_machine(self, state_manager): + state_machine = BeamlineStateMachine(manager=state_manager) + return state_machine + + @pytest.fixture() + def config_dict(self): + return { + "alignment": { + "devices": { + "samx": { + "readback": {"value": 0, "abs_tol": 0.1}, + "velocity": {"value": 5, "abs_tol": 0.1}, + } + } + } + } + + def test_load_from_config_with_dict( + self, state_machine: BeamlineStateMachine, tmp_path, config_dict + ): + """Test loading configuration from a dictionary or file.""" + + # Load valid configuration from dictionary + with mock.patch.object(state_machine._manager, "add") as manager_add: + state_machine.load_from_config( + name="alignment", config_path=None, config_dict=config_dict + ) + manager_add.assert_called_once_with( + bl_states.AggregatedStateConfig(name="alignment", states=config_dict) + ) + # Loading with both config_path and config_dict should raise an error + with pytest.raises(ValueError): + state_machine.load_from_config( + name="alignment", config_path="path/to/config.yaml", config_dict=config_dict + ) + # Loading with neither config_path nor config_dict should raise an error + with pytest.raises(ValueError): + state_machine.load_from_config(name="alignment", config_path=None, config_dict=None) + + # Loading from file should work. + config_path = tmp_path / "config.yaml" + with open(config_path, "w", encoding="utf-8") as f: + yaml.dump(config_dict, f) + state_machine.load_from_config(name="alignment", config_path=str(config_path)) + manager_add.assert_called_with( + bl_states.AggregatedStateConfig(name="alignment", states=config_dict) + ) + + def test_update_config(self, state_machine: BeamlineStateMachine, config_dict, tmp_path): + """Test update method of state machine.""" + with mock.patch.object(state_machine._manager, "_update_state") as manager_update: + config = bl_states.AggregatedStateConfig(name="alignment", states=config_dict) + state_machine.update_config(name="alignment", config_dict=config_dict) + manager_update.assert_called_once_with(config) + + manager_update.reset_mock() + + # Invalid updates should raise an error + with pytest.raises(ValueError): + state_machine.update_config(name="alignment", config_dict=None) + manager_update.assert_not_called() + + with pytest.raises(ValueError): + state_machine.update_config( + name="alignment", config_path="path/to/config.yaml", config_dict=config_dict + ) + manager_update.assert_not_called() + manager_update.reset_mock() + # Updating from file should work. + config_path = tmp_path / "config.yaml" + with open(config_path, "w", encoding="utf-8") as f: + yaml.dump(config_dict, f) + state_machine.update_config(name="alignment", config_path=str(config_path)) + manager_update.assert_called_once_with(config) diff --git a/bec_server/bec_server/scan_server/scans/state_transition_scan.py b/bec_server/bec_server/scan_server/scans/state_transition_scan.py new file mode 100644 index 000000000..f974d7f4d --- /dev/null +++ b/bec_server/bec_server/scan_server/scans/state_transition_scan.py @@ -0,0 +1,193 @@ +""" +Updated move scan implementation for coordinated motor repositioning commands. + +Scan procedure: + - prepare_scan + - open_scan + - stage + - pre_scan + - scan_core + - at_each_point (optionally called by scan_core) + - post_scan + - unstage + - close_scan + - on_exception (called if any exception is raised during the scan) +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Tuple + +from bec_lib.device import DeviceBase, Signal +from bec_lib.endpoints import MessageEndpoints +from bec_lib.logger import bec_logger +from bec_server.scan_server.scans.scan_modifier import scan_hook +from bec_server.scan_server.scans.scans_v4 import ScanBase, bundle_args + +if TYPE_CHECKING: + from bec_lib.bl_states import AggregatedStateConfig, SubDeviceStateConfig + from bec_lib.messages import AvailableBeamlineStatesMessage + +logger = bec_logger.logger + + +class StateTransitionScan(ScanBase): + + # Scan Type: Hardware triggered or software triggered? + # If the main trigger and readout logic is done within the at_each_point method in scan_core, choose SOFTWARE_TRIGGERED. + # If the main trigger and readout logic is implemented on a device that is simply kicked off in this scan, choose HARDWARE_TRIGGERED. + # This primarily serves as information for devices: The device may need to react differently if a software trigger is expected + # for every point. + scan_type = None + + # Scan name: This is the name of the scan, e.g. "line_scan". This is used for display purposes and to identify the scan type in user interfaces. + # Choose a descriptive name that does not conflict with existing scan names. + scan_name = "_v4_state_transition" + + # We set is_scan to False to separate this class from the other scans in the user interface + is_scan = False + + def __init__(self, *args, state_name: str, target_label: str, **kwargs): + """ + State transition scan that moves a motor in between two states. + The main purpose of this scan is to be used in conjunction with state + management in BEC, and transitioning the beamline in-between different aggregated states. + """ + super().__init__(**kwargs) + self.state_name = state_name + self.target_label = target_label + # Check if the state and the target label exists, if yes, fetch the configuration for the target state + self.config_for_label = self._fetch_config_for_label(state_name, target_label) + + # We need to sort the devices and signals in the config, and identify which of them are motor setpoint/readback pairs + # and which of them are just readouts and thereby can not be set within the transition. + self._settable_signals_with_setpoint: list[Tuple[Signal, Any]] = [] + + @scan_hook + def prepare_scan(self): + """ + Prepare the scan. This can include any steps that need to be executed + before the scan is opened, such as preparing the positions (if not done already) + or setting up the devices. + """ + for device_name, signal_configs in self.config_for_label.devices.items(): + dev_obj = self.device_manager.devices.get(device_name, None) + if dev_obj is None: + raise ValueError(f"Device {device_name} not found in device manager.") + if isinstance(dev_obj, Signal): + if dev_obj._info["write_access"] is False: + logger.info( + f"Signal {device_name} is read-only, skipping during state transition." + ) + continue # This is a read-only signal, we can transition it + + @scan_hook + def open_scan(self): + """ + Open the scan. + This step must call self.actions.open_scan() to ensure that a new scan is + opened. Make sure to prepare the scan metadata before, either in + prepare_scan() or in open_scan() itself and call self.update_scan_info(...) + to update the scan metadata if needed. + """ + + @scan_hook + def stage(self): + """ + Stage the devices for the upcoming scan. The stage logic is typically + implemented on the device itself (i.e. by the device's stage method). + However, if there are any additional steps that need to be executed before + staging the devices, they can be implemented here. + """ + + @scan_hook + def pre_scan(self): + """ + Pre-scan steps to be executed before the main scan logic. + This is typically the last chance to prepare the devices before the core scan + logic is executed. For example, this is a good place to initialize time-criticial + devices, e.g. devices that have a short timeout. + The pre-scan logic is typically implemented on the device itself. + """ + + @scan_hook + def scan_core(self): + """ + Core scan logic to be executed during the scan. + This is where the main scan logic should be implemented. + """ + current_positions = self.components.get_start_positions(self.motors) + target_positions = list(self.motor_args_bundles.values()) + target_positions = [pos[0] for pos in target_positions] + if self.relative: + target_positions = [ + target + current + for target, current in zip(target_positions, current_positions, strict=False) + ] + + self.actions.add_scan_report_instruction_readback( + devices=self.motors, + start=current_positions, + stop=target_positions, + request_id=self.scan_info.metadata["RID"], + ) + + self.components.move_and_wait(self.motors, target_positions) + + @scan_hook + def at_each_point(self): + """ + Logic to be executed at each point during the scan. This is called by the step_scan method at each point. + + Args: + motors (list[str | DeviceBase]): List of motor names or device instances being moved. + positions (np.ndarray): Current positions of the motors, shape (len(motors),). + last_positions (np.ndarray | None): Previous positions of the motors, shape (len(motors),) or None if this is the first point. + """ + + @scan_hook + def post_scan(self): + """ + Post-scan steps to be executed after the main scan logic. + """ + + @scan_hook + def unstage(self): + """Unstage the scan by executing post-scan steps.""" + + @scan_hook + def close_scan(self): + """Close the scan.""" + + @scan_hook + def on_exception(self, exception: Exception): + """ + Handle exceptions that occur during the scan. + This is a good place to implement any cleanup logic that needs to be executed in case of an exception, + such as returning the devices to a safe state or moving the motors back to their starting position. + """ + + ################# + ## Custom Methods + ################# + + def _fetch_config_for_label(self, state_name: str, target_label: str) -> SubDeviceStateConfig: + available_states_msg: AvailableBeamlineStatesMessage = self.redis_connector.get( + MessageEndpoints.available_beamline_states() + ) + configs = [state for state in available_states_msg.states if state.name == state_name] + if len(configs) == 0: + raise ValueError(f"State {state_name} not found in available states.") + elif len(configs) > 1: # Should not be possible, but just in case + raise ValueError(f"Multiple states with name {state_name} found in available states.") + config: AggregatedStateConfig = configs[0] + if config.state_type != "AggregatedState": + raise ValueError( + f"State {state_name} is not an aggregated state. Transitions are only supported for aggregated states." + ) + available_labels = list(config.states.keys()) + if target_label not in available_labels: + raise ValueError( + f"Target label {target_label} not found in state {state_name}. Available labels: {available_labels}" + ) + return config.states[target_label] diff --git a/state_config.yaml b/state_config.yaml deleted file mode 100644 index c151cffd6..000000000 --- a/state_config.yaml +++ /dev/null @@ -1,25 +0,0 @@ -alignment: # AggregatedStateConfig -> can have different labels and for each label, different devices - devices: - samx: - readback: - value: 0 - abs_tol: 0.1 - velocity: - value: 0 - abs_tol: 0.1 -measurement: - devices: - samx: - readback: - value: 19 - abs_tol: 0.1 - samy: - readback: - value: 0 - abs_tol: 0.1 -test: - devices: - samy: - readback: - value: 0 - abs_tol: 0.1 \ No newline at end of file From 4c64c92f64e8a49d99c7e3933a177299020f97d3 Mon Sep 17 00:00:00 2001 From: appel_c Date: Fri, 24 Apr 2026 12:41:44 +0200 Subject: [PATCH 04/13] refactor: Improve config structure for aggregated state configs --- bec_lib/bec_lib/bl_states.py | 110 ++++++++++++++++++++++---- bec_lib/tests/test_beamline_states.py | 26 +++--- 2 files changed, 109 insertions(+), 27 deletions(-) diff --git a/bec_lib/bec_lib/bl_states.py b/bec_lib/bec_lib/bl_states.py index 299e7de81..9f524820e 100644 --- a/bec_lib/bec_lib/bl_states.py +++ b/bec_lib/bec_lib/bl_states.py @@ -11,6 +11,7 @@ import yaml from pydantic import BaseModel, field_validator, model_validator +from typeguard import typechecked from bec_lib import messages from bec_lib.alarm_handler import Alarms @@ -132,14 +133,47 @@ class SignalConfig(BaseModel): abs_tol: float = 0.0 +class DeviceConfig(BaseModel): + """Configuration for a device inside a named machine state.""" + + abs_tol: float = 0.0 + value: float | int | str | bool | None = None + low_limit: SignalConfig | None = None + high_limit: SignalConfig | None = None + signals: dict[str, SignalConfig] | None = None + + @model_validator(mode="after") + def validate_config(self) -> DeviceConfig: + """ + Validate that either value, low_limit, high_limit, or signals are provided. + """ + if ( + self.value is None + and self.low_limit is None + and self.high_limit is None + and self.signals is None + ): + raise ValueError( + "At least one of value, low_limit, high_limit, or signals must be provided." + ) + return self + + class SubDeviceStateConfig(BaseModel): + """ + Configuration for a sub-state with a specific label. + This is a device/signal mappping to either a DeviceConfig or SignalConfig. + """ - devices: dict[str, dict[str, SignalConfig]] + devices: dict[str, DeviceConfig | SignalConfig] + transition_metadata: dict[str, Any] | None = None class AggregatedStateConfig(BeamlineStateConfig): """ Configuration for a state machine driven by multiple device signals. + + Keys of the states dictionary are the labels of the different states. """ state_type: ClassVar[str] = "AggregatedState" @@ -411,6 +445,7 @@ def _get_signal_source(self, signal_info: dict[str, Any]) -> SignalSource: f"{self._error_prefix} Unsupported kind: '{kind_str}' for signal : \n {yaml.dump(signal_info, indent=4)}" ) + @typechecked def _resolve_signal(self, device_name: str, signal_name: str) -> tuple[str, SignalSource]: devices = self._get_devices() try: @@ -466,25 +501,70 @@ def _build_rules(self) -> None: self._subscriptions.clear() for label, device_configs in self.config.states.items(): state_requirements: list[_ResolvedStateSignal] = [] - for device_name, signal_configs in device_configs.devices.items(): - for signal_name, target in signal_configs.items(): - resolved_signal_name, source = self._resolve_signal(device_name, signal_name) + for device_name, config in device_configs.devices.items(): + if isinstance(config, SignalConfig): state_requirements.append( - _ResolvedStateSignal( - label=label, - device_name=device_name, - signal_name=resolved_signal_name, - expected_value=target.value, - abs_tolerance=target.abs_tol, - source=source, + self._build_requirement_for_signal( + device_name, device_name, config.value, config.abs_tol, label ) ) - self._subscriptions.add((device_name, source)) - self._signal_info_to_labels.setdefault( - (device_name, source, resolved_signal_name), set() - ).add(label) + elif isinstance(config, DeviceConfig): + # If a value is specified for the device, add it as a requirement + if config.value is not None: + state_requirements.append( + self._build_requirement_for_signal( + device_name, device_name, config.value, config.abs_tol, label + ) + ) + if config.low_limit is not None: + state_requirements.append( + self._build_requirement_for_signal( + device_name, + "low_limit", + config.low_limit.value, + config.low_limit.abs_tol, + label, + ) + ) + if config.high_limit is not None: + state_requirements.append( + self._build_requirement_for_signal( + device_name, + "high_limit", + config.high_limit.value, + config.high_limit.abs_tol, + label, + ) + ) + for signal_name, signal_config in (config.signals or {}).items(): + state_requirements.append( + self._build_requirement_for_signal( + device_name, + signal_name, + signal_config.value, + signal_config.abs_tol, + label, + ) + ) self._requirements_for_label[label] = state_requirements + def _build_requirement_for_signal( + self, device_name: str, signal_name: str, value: Any, abs_tol: float, label: str + ) -> _ResolvedStateSignal: + resolved_signal_name, source = self._resolve_signal(device_name, signal_name) + self._subscriptions.add((device_name, source)) + self._signal_info_to_labels.setdefault( + (device_name, source, resolved_signal_name), set() + ).add(label) + return _ResolvedStateSignal( + label=label, + device_name=device_name, + signal_name=resolved_signal_name, + expected_value=value, + abs_tolerance=abs_tol, + source=source, + ) + def start(self) -> None: if self.started: return diff --git a/bec_lib/tests/test_beamline_states.py b/bec_lib/tests/test_beamline_states.py index 6f671ec43..0606bba30 100644 --- a/bec_lib/tests/test_beamline_states.py +++ b/bec_lib/tests/test_beamline_states.py @@ -208,27 +208,28 @@ def aggregated_state_config(self): "alignment": { "devices": { "samx": { - "readback": {"value": 0, "abs_tol": 0.1}, - "velocity": {"value": 5, "abs_tol": 0.1}, + "value": 0, + "abs_tol": 0.1, "low_limit": {"value": -20, "abs_tol": 0.1}, "high_limit": {"value": 20, "abs_tol": 0.1}, }, - "bpm4i": {"bpm4i": {"value": 0, "abs_tol": 0.1}}, + "bpm4i": {"value": 0, "abs_tol": 0.1}, } }, "measurement": { "devices": { "samx": { - "readback": {"value": 19, "abs_tol": 0.1}, - "velocity": {"value": 5, "abs_tol": 0.1}, - "low_limit_travel": {"value": -20, "abs_tol": 0.1}, - "high_limit_travel": {"value": 20, "abs_tol": 0.1}, + "value": 19, + "abs_tol": 0.1, + "low_limit": {"value": -20, "abs_tol": 0.1}, + "high_limit": {"value": 20, "abs_tol": 0.1}, + "signals": {"velocity": {"value": 5, "abs_tol": 0.1}}, }, - "bpm4i": {"bpm4i": {"value": 2, "abs_tol": 0.1}}, + "bpm4i": {"value": 2, "abs_tol": 0.1}, } }, - "test": {"devices": {"bpm4i": {"bpm4i": {"value": 0, "abs_tol": 0.1}}}}, - "string_state": {"devices": {"bpm3i": {"bpm3i": {"value": "ok"}}}}, + "test": {"devices": {"bpm4i": {"value": 0, "abs_tol": 0.1}}}, + "string_state": {"devices": {"bpm3i": {"value": "ok"}}}, }, ) @@ -647,8 +648,9 @@ def config_dict(self): "alignment": { "devices": { "samx": { - "readback": {"value": 0, "abs_tol": 0.1}, - "velocity": {"value": 5, "abs_tol": 0.1}, + "value": 0, + "abs_tol": 0.1, + "signals": {"velocity": {"value": 5, "abs_tol": 0.1}}, } } } From 154857c308e87d74c23f3a2e7815271dc0b63988 Mon Sep 17 00:00:00 2001 From: appel_c Date: Fri, 24 Apr 2026 14:32:08 +0200 Subject: [PATCH 05/13] feat: add state transition scan --- bec_lib/bec_lib/bl_states.py | 190 +++++++++++------- bec_lib/tests/test_beamline_states.py | 2 +- .../scans/state_transition_scan.py | 122 ++++++++--- 3 files changed, 218 insertions(+), 96 deletions(-) diff --git a/bec_lib/bec_lib/bl_states.py b/bec_lib/bec_lib/bl_states.py index 9f524820e..d8689fdb5 100644 --- a/bec_lib/bec_lib/bl_states.py +++ b/bec_lib/bec_lib/bl_states.py @@ -11,7 +11,6 @@ import yaml from pydantic import BaseModel, field_validator, model_validator -from typeguard import typechecked from bec_lib import messages from bec_lib.alarm_handler import Alarms @@ -386,7 +385,7 @@ def _update_device_state(self, msg_obj: MessageObject) -> messages.BeamlineState @dataclass(frozen=True) -class _ResolvedStateSignal: +class ResolvedStateSignal: label: str device_name: str signal_name: str @@ -405,7 +404,7 @@ def __init__(self, *args, **kwargs) -> None: # Mapping from signal updates to affected state labels, used for efficient evaluation when a signal update is received self._signal_info_to_labels: dict[tuple[str, SignalSource, str], set[str]] = {} # Mapping from state labels to the list of signal requirements that define that state - self._requirements_for_label: dict[str, list[_ResolvedStateSignal]] = {} + self._requirements_for_label: dict[str, list[ResolvedStateSignal]] = {} # Set of subscriptions to signal updates self._subscriptions: set[tuple[str, SignalSource]] = set() # Cache of the latest signal values @@ -426,32 +425,39 @@ def _endpoint(device: str, source: SignalSource): f"Invalid signal source '{source}', please use 'readback', 'configuration', or 'limits'." ) - def _get_devices(self): + def _get_device_manager(self): if self.device_manager is None: # pylint: disable=import-outside-toplevel from bec_lib.client import BECClient bec = BECClient() - return bec.device_manager.devices - return self.device_manager.devices + return bec.device_manager + return self.device_manager - def _get_signal_source(self, signal_info: dict[str, Any]) -> SignalSource: + @staticmethod + def _get_signal_source(signal_info: dict[str, Any], error_prefix: str) -> SignalSource: kind_str = str(signal_info.get("kind_str", "")).lower() if "hinted" in kind_str or "normal" in kind_str: return "readback" if "config" in kind_str: return "configuration" raise ValueError( - f"{self._error_prefix} Unsupported kind: '{kind_str}' for signal : \n {yaml.dump(signal_info, indent=4)}" + f"{error_prefix} Unsupported kind: '{kind_str}' for signal : \n {yaml.dump(signal_info, indent=4)}" ) - @typechecked - def _resolve_signal(self, device_name: str, signal_name: str) -> tuple[str, SignalSource]: - devices = self._get_devices() + @staticmethod + def _resolve_signal( + device_name: str, signal_name: str, device_manager: DeviceManagerBase, error_prefix: str + ) -> tuple[str, SignalSource]: + devices = device_manager.devices try: + if not isinstance(device_name, str): + raise ValueError( + f"{error_prefix} Device name must be a string, got {type(device_name)}" + ) device_obj: DeviceBase = devices[device_name] except KeyError: - raise ValueError(f"{self._error_prefix} Device '{device_name}' not found.") from None + raise ValueError(f"{error_prefix} Device '{device_name}' not found.") from None # Special handling for limits, as they are not regular signals. if signal_name in ["low_limit", "low_limit_travel"]: @@ -469,11 +475,11 @@ def _resolve_signal(self, device_name: str, signal_name: str) -> tuple[str, Sign signal_obj = devices[signal_name] except AttributeError: raise ValueError( - f"{self._error_prefix} Signal '{signal_name}' not found for device '{device_name}'." + f"{error_prefix} Signal '{signal_name}' not found for device '{device_name}'." ) from None if signal_obj.parent != device_obj: raise ValueError( - f"{self._error_prefix} Signal '{signal_name}' does not belong to device '{device_name}'." + f"{error_prefix} Signal '{signal_name}' does not belong to device '{device_name}'." ) signal_component = ".".join(signal_name.split(".")[1:]) signal_info = device_obj.root._info["signals"].get(signal_component) @@ -488,75 +494,119 @@ def _resolve_signal(self, device_name: str, signal_name: str) -> tuple[str, Sign if signal_info is None: raise ValueError( - f"{self._error_prefix} Signal '{signal_name}' not found for device '{device_name}'." + f"{error_prefix} Signal '{signal_name}' not found for device '{device_name}'." ) obj_name = signal_info.get("obj_name") - signal_source = self._get_signal_source(signal_info) + signal_source = AggregatedState._get_signal_source(signal_info, error_prefix) return obj_name, signal_source - def _build_rules(self) -> None: - self._signal_info_to_labels.clear() - self._requirements_for_label.clear() - self._subscriptions.clear() - for label, device_configs in self.config.states.items(): - state_requirements: list[_ResolvedStateSignal] = [] - for device_name, config in device_configs.devices.items(): - if isinstance(config, SignalConfig): + @staticmethod + def get_state_requirements( + label: str, + state_config: SubDeviceStateConfig, + device_manager: DeviceManagerBase, + error_prefix: str, + ) -> list[ResolvedStateSignal]: + state_requirements: list[ResolvedStateSignal] = [] + for device_name, config in state_config.devices.items(): + if isinstance(config, SignalConfig): + state_requirements.append( + AggregatedState._build_requirement_for_signal( + device_name, + device_name, + config.value, + config.abs_tol, + label, + device_manager, + error_prefix, + ) + ) + elif isinstance(config, DeviceConfig): + # If a value is specified for the device, add it as a requirement + if config.value is not None: state_requirements.append( - self._build_requirement_for_signal( - device_name, device_name, config.value, config.abs_tol, label + AggregatedState._build_requirement_for_signal( + device_name, + device_name, + config.value, + config.abs_tol, + label, + device_manager, + error_prefix, ) ) - elif isinstance(config, DeviceConfig): - # If a value is specified for the device, add it as a requirement - if config.value is not None: - state_requirements.append( - self._build_requirement_for_signal( - device_name, device_name, config.value, config.abs_tol, label - ) - ) - if config.low_limit is not None: - state_requirements.append( - self._build_requirement_for_signal( - device_name, - "low_limit", - config.low_limit.value, - config.low_limit.abs_tol, - label, - ) + if config.low_limit is not None: + state_requirements.append( + AggregatedState._build_requirement_for_signal( + device_name, + "low_limit", + config.low_limit.value, + config.low_limit.abs_tol, + label, + device_manager, + error_prefix, ) - if config.high_limit is not None: - state_requirements.append( - self._build_requirement_for_signal( - device_name, - "high_limit", - config.high_limit.value, - config.high_limit.abs_tol, - label, - ) + ) + if config.high_limit is not None: + state_requirements.append( + AggregatedState._build_requirement_for_signal( + device_name, + "high_limit", + config.high_limit.value, + config.high_limit.abs_tol, + label, + device_manager, + error_prefix, ) - for signal_name, signal_config in (config.signals or {}).items(): - state_requirements.append( - self._build_requirement_for_signal( - device_name, - signal_name, - signal_config.value, - signal_config.abs_tol, - label, - ) + ) + for signal_name, signal_config in (config.signals or {}).items(): + state_requirements.append( + AggregatedState._build_requirement_for_signal( + device_name, + signal_name, + signal_config.value, + signal_config.abs_tol, + label, + device_manager, + error_prefix, ) + ) + return state_requirements + + def _build_rules(self) -> None: + self._signal_info_to_labels.clear() + self._requirements_for_label.clear() + self._subscriptions.clear() + for label, device_configs in self.config.states.items(): + state_requirements: list[ResolvedStateSignal] = AggregatedState.get_state_requirements( + label, device_configs, self._get_device_manager(), self._error_prefix + ) + for requirement in state_requirements: + device_name = requirement.device_name + signal_name = requirement.signal_name + source = requirement.source + self._subscriptions.add((device_name, source)) + self._signal_info_to_labels.setdefault( + (device_name, source, signal_name), set() + ).add(label) self._requirements_for_label[label] = state_requirements + @staticmethod def _build_requirement_for_signal( - self, device_name: str, signal_name: str, value: Any, abs_tol: float, label: str - ) -> _ResolvedStateSignal: - resolved_signal_name, source = self._resolve_signal(device_name, signal_name) - self._subscriptions.add((device_name, source)) - self._signal_info_to_labels.setdefault( - (device_name, source, resolved_signal_name), set() - ).add(label) - return _ResolvedStateSignal( + device_name: str, + signal_name: str, + value: Any, + abs_tol: float, + label: str, + device_manager: DeviceManagerBase, + error_prefix: str, + ) -> ResolvedStateSignal: + resolved_signal_name, source = AggregatedState._resolve_signal( + device_name, signal_name, device_manager, error_prefix + ) + + return ResolvedStateSignal( label=label, device_name=device_name, signal_name=resolved_signal_name, @@ -665,7 +715,7 @@ def _label_matches(self, label: str) -> bool: self._requirement_matches(requirement) for requirement in requirements ) - def _requirement_matches(self, requirement: _ResolvedStateSignal) -> bool: + def _requirement_matches(self, requirement: ResolvedStateSignal) -> bool: key = (requirement.device_name, requirement.source, requirement.signal_name) cached_value = self._signal_value_cache.get(key, None) if cached_value is None: diff --git a/bec_lib/tests/test_beamline_states.py b/bec_lib/tests/test_beamline_states.py index 0606bba30..29ef4ac77 100644 --- a/bec_lib/tests/test_beamline_states.py +++ b/bec_lib/tests/test_beamline_states.py @@ -496,7 +496,7 @@ def test_aggregated_state_requirement_matches( redis_connector=connected_connector, device_manager=dm_with_devices, ) - requirement = bl_states._ResolvedStateSignal( + requirement = bl_states.ResolvedStateSignal( label="alignment", device_name="bpm4i", signal_name="bpm4i", diff --git a/bec_server/bec_server/scan_server/scans/state_transition_scan.py b/bec_server/bec_server/scan_server/scans/state_transition_scan.py index f974d7f4d..a9f8ff973 100644 --- a/bec_server/bec_server/scan_server/scans/state_transition_scan.py +++ b/bec_server/bec_server/scan_server/scans/state_transition_scan.py @@ -18,19 +18,37 @@ from typing import TYPE_CHECKING, Any, Tuple -from bec_lib.device import DeviceBase, Signal +from bec_lib.alarm_handler import AlarmBase, Alarms +from bec_lib.bl_states import AggregatedState, SubDeviceStateConfig +from bec_lib.device import DeviceBase, Positioner, Signal from bec_lib.endpoints import MessageEndpoints from bec_lib.logger import bec_logger +from bec_lib.messages import AlarmMessage, ErrorInfo from bec_server.scan_server.scans.scan_modifier import scan_hook -from bec_server.scan_server.scans.scans_v4 import ScanBase, bundle_args +from bec_server.scan_server.scans.scans_v4 import ScanBase if TYPE_CHECKING: - from bec_lib.bl_states import AggregatedStateConfig, SubDeviceStateConfig + from bec_lib.bl_states import AggregatedStateConfig, ResolvedStateSignal from bec_lib.messages import AvailableBeamlineStatesMessage logger = bec_logger.logger +class StateTransitionScanError(AlarmBase): + """Exception raised when an RPC call fails.""" + + def __init__(self, exc_type: str, message: str, compact_message: str) -> None: + alarm = AlarmMessage( + severity=Alarms.MAJOR, + info=ErrorInfo( + exception_type=exc_type, + error_message=message, + compact_error_message=compact_message, + ), + ) + super().__init__(alarm, Alarms.MAJOR, handled=False) + + class StateTransitionScan(ScanBase): # Scan Type: Hardware triggered or software triggered? @@ -61,8 +79,11 @@ def __init__(self, *args, state_name: str, target_label: str, **kwargs): # We need to sort the devices and signals in the config, and identify which of them are motor setpoint/readback pairs # and which of them are just readouts and thereby can not be set within the transition. - self._settable_signals_with_setpoint: list[Tuple[Signal, Any]] = [] + self._signals_to_set: list[Tuple[Signal, Any]] = [] + self._limits_to_set: dict[str, Tuple[Positioner, float, float]] = {} + self._devices_to_set: list[Tuple[Positioner, float]] = [] + # pylint: disable=protected-access @scan_hook def prepare_scan(self): """ @@ -70,16 +91,65 @@ def prepare_scan(self): before the scan is opened, such as preparing the positions (if not done already) or setting up the devices. """ - for device_name, signal_configs in self.config_for_label.devices.items(): - dev_obj = self.device_manager.devices.get(device_name, None) + requirements: list[ResolvedStateSignal] = AggregatedState.get_state_requirements( + self.target_label, self.config_for_label, self.device_manager, "StateTransitionScan" + ) + for req in requirements: + dev_obj: DeviceBase = self.device_manager.devices.get(req.device_name) + # Device not found if dev_obj is None: - raise ValueError(f"Device {device_name} not found in device manager.") + raise StateTransitionScanError( + exc_type="DeviceNotFound", + message=f"Device {req.device_name} not found in device manager.", + compact_message=f"Device {req.device_name} not found.", + ) + # First we handle Signals logic if isinstance(dev_obj, Signal): - if dev_obj._info["write_access"] is False: - logger.info( - f"Signal {device_name} is read-only, skipping during state transition." + self._signals_to_set.append((dev_obj, req.expected_value)) + continue + # Positioner and Device logic. Devices must implement .set for this to work, otherwise we can not set them and we raise an error + if isinstance(dev_obj, DeviceBase): + # Handle motor-specific logic here + # First we handle logic for motions of the motor. Device_name and signal_name will be equivalent here + if req.signal_name == req.device_name: + self._devices_to_set.append((dev_obj, req.expected_value)) + continue + if req.signal_name in ["low", "high"]: + if req.device_name not in self._limits_to_set: + self._limits_to_set[req.device_name] = ( + dev_obj, + dev_obj.low_limit, + dev_obj.high_limit, + ) + if req.signal_name == "low_limit": + self._limits_to_set[req.device_name] = ( + dev_obj, + req.expected_value, + self._limits_to_set[req.device_name][2], + ) + else: + self._limits_to_set[req.device_name] = ( + dev_obj, + self._limits_to_set[req.device_name][1], + req.expected_value, + ) + continue + signal_obj = self._get_signal_object(dev_obj, req.signal_name) + if signal_obj is None: + raise StateTransitionScanError( + exc_type="SignalNotFound", + message=f"Signal {req.signal_name} for device {req.device_name} not found in device manager.", + compact_message=f"Signal {req.signal_name} for device {req.device_name} not found.", ) - continue # This is a read-only signal, we can transition it + self._signals_to_set.append((signal_obj, req.expected_value)) + continue + + self.update_scan_info(scan_report_devices=[dev for dev, _ in self._devices_to_set]) + + def _get_signal_object(self, device_obj: DeviceBase, signal_name: str) -> Signal: + for component_name, info in device_obj._info["signals"].items(): + if info["obj_name"] == signal_name: + return getattr(device_obj, component_name) @scan_hook def open_scan(self): @@ -116,23 +186,23 @@ def scan_core(self): Core scan logic to be executed during the scan. This is where the main scan logic should be implemented. """ - current_positions = self.components.get_start_positions(self.motors) - target_positions = list(self.motor_args_bundles.values()) - target_positions = [pos[0] for pos in target_positions] - if self.relative: - target_positions = [ - target + current - for target, current in zip(target_positions, current_positions, strict=False) - ] + motors = [element[0] for element in self._devices_to_set] + target_positions = [element[1] for element in self._devices_to_set] + current_positions = self.components.get_start_positions(motors) self.actions.add_scan_report_instruction_readback( - devices=self.motors, + devices=motors, start=current_positions, stop=target_positions, request_id=self.scan_info.metadata["RID"], ) - self.components.move_and_wait(self.motors, target_positions) + self.components.move_and_wait(motors, target_positions) + # After the move is completed, we set the limits and signals. + for dev_name, (dev_obj, low_limit, high_limit) in self._limits_to_set.items(): + dev_obj.limits = [low_limit, high_limit] + for signal_obj, target_value in self._signals_to_set: + signal_obj.set(target_value).wait() @scan_hook def at_each_point(self): @@ -172,10 +242,12 @@ def on_exception(self, exception: Exception): ################# def _fetch_config_for_label(self, state_name: str, target_label: str) -> SubDeviceStateConfig: - available_states_msg: AvailableBeamlineStatesMessage = self.redis_connector.get( + available_states_msg: AvailableBeamlineStatesMessage = self.redis_connector.get_last( MessageEndpoints.available_beamline_states() ) - configs = [state for state in available_states_msg.states if state.name == state_name] + configs = [ + state for state in available_states_msg["data"].states if state.name == state_name + ] if len(configs) == 0: raise ValueError(f"State {state_name} not found in available states.") elif len(configs) > 1: # Should not be possible, but just in case @@ -185,9 +257,9 @@ def _fetch_config_for_label(self, state_name: str, target_label: str) -> SubDevi raise ValueError( f"State {state_name} is not an aggregated state. Transitions are only supported for aggregated states." ) - available_labels = list(config.states.keys()) + available_labels = list(config.parameters["states"].keys()) if target_label not in available_labels: raise ValueError( f"Target label {target_label} not found in state {state_name}. Available labels: {available_labels}" ) - return config.states[target_label] + return SubDeviceStateConfig.model_validate(config.parameters["states"][target_label]) From e759d790e39051914f502977bc8201c50842b4b2 Mon Sep 17 00:00:00 2001 From: appel_c Date: Mon, 4 May 2026 15:08:29 +0200 Subject: [PATCH 06/13] tests: fix mocked device info for tests --- bec_lib/bec_lib/tests/utils.py | 14 +------------- .../scan_server/scans/state_transition_scan.py | 2 +- 2 files changed, 2 insertions(+), 14 deletions(-) diff --git a/bec_lib/bec_lib/tests/utils.py b/bec_lib/bec_lib/tests/utils.py index a69190030..afb2de4e6 100644 --- a/bec_lib/bec_lib/tests/utils.py +++ b/bec_lib/bec_lib/tests/utils.py @@ -440,19 +440,7 @@ def get_device_info_mock(device_name, device_class) -> messages.DeviceInfoMessag if device_base_class == "positioner": signals = positioner_info_mock(device_name)["device_info"]["signals"] elif device_base_class == "signal": - signals = { - device_name: { - "metadata": { - "connected": True, - "read_access": True, - "write_access": False, - "timestamp": 0, - "status": None, - "severity": None, - "precision": None, - } - } - } + signals = {} else: signals = {} dev_info = { diff --git a/bec_server/bec_server/scan_server/scans/state_transition_scan.py b/bec_server/bec_server/scan_server/scans/state_transition_scan.py index a9f8ff973..fac356270 100644 --- a/bec_server/bec_server/scan_server/scans/state_transition_scan.py +++ b/bec_server/bec_server/scan_server/scans/state_transition_scan.py @@ -25,7 +25,7 @@ from bec_lib.logger import bec_logger from bec_lib.messages import AlarmMessage, ErrorInfo from bec_server.scan_server.scans.scan_modifier import scan_hook -from bec_server.scan_server.scans.scans_v4 import ScanBase +from bec_server.scan_server.scans.scan_base import ScanBase if TYPE_CHECKING: from bec_lib.bl_states import AggregatedStateConfig, ResolvedStateSignal From 92e0f5924702d1e7faa77152ace1c4465caeae3b Mon Sep 17 00:00:00 2001 From: appel_c Date: Tue, 12 May 2026 15:12:56 +0200 Subject: [PATCH 07/13] fix: remove deprecated option suggestion-mode from pylintrc --- .pylintrc | 4 ---- 1 file changed, 4 deletions(-) diff --git a/.pylintrc b/.pylintrc index c6f676501..806d3bc3d 100644 --- a/.pylintrc +++ b/.pylintrc @@ -54,10 +54,6 @@ persistent=yes # the version used to run pylint. py-version=3.11 -# When enabled, pylint would attempt to guess common misconfiguration and emit -# user-friendly hints instead of false-positive error messages. -suggestion-mode=yes - # Allow loading of arbitrary C extensions. Extensions are imported into the # active Python interpreter and may run arbitrary code. unsafe-load-any-extension=no From f948d2813e588761d4713139ca5a9d05dfb3974b Mon Sep 17 00:00:00 2001 From: appel_c Date: Tue, 12 May 2026 15:13:12 +0200 Subject: [PATCH 08/13] fix: address PR comments --- bec_lib/bec_lib/bl_state_machine.py | 62 +++--- bec_lib/bec_lib/bl_states.py | 81 +++++++- bec_lib/bec_lib/tests/utils.py | 6 +- bec_lib/tests/test_beamline_states.py | 184 ++++++++++++++++++ .../scans/state_transition_scan.py | 12 +- 5 files changed, 308 insertions(+), 37 deletions(-) diff --git a/bec_lib/bec_lib/bl_state_machine.py b/bec_lib/bec_lib/bl_state_machine.py index 0b5601ee8..57fc834d3 100644 --- a/bec_lib/bec_lib/bl_state_machine.py +++ b/bec_lib/bec_lib/bl_state_machine.py @@ -3,31 +3,45 @@ Example of the YAML configuration file: ``` yaml -alignment: - devices: - samx: - readback: - value: 0 - abs_tol: 0.1 - measurement: - devices: - samx: - readback: - value: 19 - abs_tol: 0.1 +alignment: # AggregatedStateConfig -> can have different labels and for each label, different devices + transition_metadata: # optional field for metadata for each label + field: value + devices: + samx: + value: 0 + abs_tol: 0.1 + low_limit: + value: -20 + abs_tol: 0.1 + high_limit: + value: 20 + abs_tol: 0.1 + signals: velocity: - value: 5 - abs_tol: 0.1 - samy: - readback: - value: 0 - abs_tol: 0.1 - test: - devices: - samy: - readback: - value: 0 - abs_tol: 0.1 + value: 5 + abs_tol: 0.1 + bpm4i: + value: 100 + abs_tol: 10 +measurement: + devices: + samx: + value: 19 + abs_tol: 0.1 + signals: + velocity: + value: 20 + samy: + value: 0 + abs_tol: 0.1 +test: + devices: + samy: + value: 0 + abs_tol: 0.1 + bpm4i: + value: 100 + abs_tol: 10 ``` """ diff --git a/bec_lib/bec_lib/bl_states.py b/bec_lib/bec_lib/bl_states.py index d8689fdb5..fd7491ab8 100644 --- a/bec_lib/bec_lib/bl_states.py +++ b/bec_lib/bec_lib/bl_states.py @@ -381,7 +381,7 @@ def _update_device_state(self, msg_obj: MessageObject) -> messages.BeamlineState return self.evaluate(msg) -SignalSource = TypeVar("SignalSource", bound=Literal["readback", "configuration", "limits"]) +SignalSource = Literal["readback", "configuration", "limits"] @dataclass(frozen=True) @@ -426,6 +426,7 @@ def _endpoint(device: str, source: SignalSource): ) def _get_device_manager(self): + """Get the device manager.""" if self.device_manager is None: # pylint: disable=import-outside-toplevel from bec_lib.client import BECClient @@ -436,6 +437,16 @@ def _get_device_manager(self): @staticmethod def _get_signal_source(signal_info: dict[str, Any], error_prefix: str) -> SignalSource: + """ + Determine the signal source (readback, configuration, or limits) based on the signal information. + + Args: + signal_info (dict[str, Any]): The signal information dictionary containing at least the "kind_str" key. + error_prefix (str): A prefix to use in error messages for better context. + + Returns: + SignalSource: A string literal indicating the signal source, one of "readback", "configuration", or "limits". + """ kind_str = str(signal_info.get("kind_str", "")).lower() if "hinted" in kind_str or "normal" in kind_str: return "readback" @@ -449,6 +460,18 @@ def _get_signal_source(signal_info: dict[str, Any], error_prefix: str) -> Signal def _resolve_signal( device_name: str, signal_name: str, device_manager: DeviceManagerBase, error_prefix: str ) -> tuple[str, SignalSource]: + """ + Resolve the signal information for a given device and signal name. + + Args: + device_name (str): The name of the device. + signal_name (str): The name of the signal. + device_manager (DeviceManagerBase): The device manager instance. + error_prefix (str): A prefix to use in error messages for better context. + + Returns: + tuple[str, SignalSource]: A tuple containing the object name and the signal source. + """ devices = device_manager.devices try: if not isinstance(device_name, str): @@ -508,6 +531,18 @@ def get_state_requirements( device_manager: DeviceManagerBase, error_prefix: str, ) -> list[ResolvedStateSignal]: + """ + Get the state requirements for a given label and state configuration. + + Args: + label (str): The label for the state. + state_config (SubDeviceStateConfig): The state configuration. + device_manager (DeviceManagerBase): The device manager instance. + error_prefix (str): A prefix to use in error messages for better context. + + Returns: + list[ResolvedStateSignal]: A list of resolved state signals. + """ state_requirements: list[ResolvedStateSignal] = [] for device_name, config in state_config.devices.items(): if isinstance(config, SignalConfig): @@ -575,6 +610,7 @@ def get_state_requirements( return state_requirements def _build_rules(self) -> None: + """Build the internal rules and mappings for state evaluation based on the configuration.""" self._signal_info_to_labels.clear() self._requirements_for_label.clear() self._subscriptions.clear() @@ -602,6 +638,21 @@ def _build_requirement_for_signal( device_manager: DeviceManagerBase, error_prefix: str, ) -> ResolvedStateSignal: + """ + Build a ResolvedStateSignal for a given device, signal, and expected value. + + Args: + device_name (str): The name of the device. + signal_name (str): The name of the signal. + value (Any): The expected value for the signal. + abs_tol (float): The absolute tolerance for comparing the signal value. + label (str): The label of the state that this requirement belongs to. + device_manager (DeviceManagerBase): The device manager instance. + error_prefix (str): A prefix to use in error messages for better context. + + Returns: + ResolvedStateSignal: The resolved state signal requirement. + """ resolved_signal_name, source = AggregatedState._resolve_signal( device_name, signal_name, device_manager, error_prefix ) @@ -621,9 +672,8 @@ def start(self) -> None: if self.connector is None: raise RuntimeError("Redis connector is not set.") - + msg = None try: - msg = None self._build_rules() affected_labels = self._fill_cache() msg = self.evaluate(affected_labels=affected_labels) @@ -642,6 +692,7 @@ def start(self) -> None: super().start() def _fill_cache(self) -> set[str]: + """Fill the signal value cache with the current values and return the set of affected state labels.""" affected_labels: set[str] = set() for device, source in self._subscriptions: endpoint = self._endpoint(device, source) @@ -653,6 +704,7 @@ def _fill_cache(self) -> set[str]: def _cache_message( self, device: str, source: SignalSource, msg: messages.DeviceMessage ) -> set[str]: + """Cache the signal values from a device message and return the set of affected state labels.""" affected_labels: set[str] = set() for signal_name, signal_data in msg.signals.items(): key = (device, source, signal_name) @@ -664,6 +716,7 @@ def _cache_message( return affected_labels def stop(self) -> None: + """Stop the state manager and unregister all subscriptions.""" if not self.started: return if self.connector is not None: @@ -676,19 +729,31 @@ def stop(self) -> None: def _update_aggregated_state( self, msg_obj: MessageObject, device: str, source: SignalSource, **_kwargs ) -> None: + """Update the aggregated state based on a new device message.""" try: msg: messages.DeviceMessage = msg_obj.value # type: ignore ; we know it's a DeviceMessage affected_labels = self._cache_message(device, source, msg) if affected_labels: - msg = self.evaluate(affected_labels=affected_labels) - if msg is not None: - self._emit_state(msg) + state_msg = self.evaluate(affected_labels=affected_labels) + if state_msg is not None: + self._emit_state(state_msg) except Exception as exc: self._handle_state_exception(exc) def evaluate( self, affected_labels: set[str] | None = None ) -> messages.BeamlineStateMessage | None: + """ + Evaluate the current state based on the cached signal values and return a BeamlineStateMessage. + + Args: + affected_labels (set[str] | None): The set of state labels that are affected by + the latest signal update. If None, all states will be evaluated. + + Returns: + messages.BeamlineStateMessage | None: The resulting state message after evaluation, or None + if no state could be evaluated. + """ if affected_labels is None: return None # We need to always extend the affected labels with the current labels, @@ -697,7 +762,7 @@ def evaluate( affected_labels.update(self._current_labels) matching_labels = [label for label in affected_labels if self._label_matches(label)] if matching_labels: - self._current_labels = matching_labels + self._current_labels = sorted(matching_labels) state_msg = messages.BeamlineStateMessage( name=self.config.name, status="valid", label="|".join(matching_labels) ) @@ -710,12 +775,14 @@ def evaluate( return state_msg def _label_matches(self, label: str) -> bool: + """Check if the given label matches the current signal values based on the defined requirements.""" requirements = self._requirements_for_label.get(label, []) return bool(requirements) and all( self._requirement_matches(requirement) for requirement in requirements ) def _requirement_matches(self, requirement: ResolvedStateSignal) -> bool: + """Check if the given requirement matches the current signal values.""" key = (requirement.device_name, requirement.source, requirement.signal_name) cached_value = self._signal_value_cache.get(key, None) if cached_value is None: diff --git a/bec_lib/bec_lib/tests/utils.py b/bec_lib/bec_lib/tests/utils.py index afb2de4e6..f193e5c01 100644 --- a/bec_lib/bec_lib/tests/utils.py +++ b/bec_lib/bec_lib/tests/utils.py @@ -430,8 +430,10 @@ def get_device_info_mock(device_name, device_class) -> messages.DeviceInfoMessag return messages.DeviceInfoMessage( device="rt_controller", info=positioner_info_mock_with_user_access(device_name) ) - elif device_name == "samx": - return messages.DeviceInfoMessage(device="samx", info=positioner_info_mock(device_name)) + elif device_name in ["samx", "samy"]: + return messages.DeviceInfoMessage( + device=device_name, info=positioner_info_mock(device_name) + ) elif device_name == "dyn_signals": return DYN_SIGNALS_MSG elif device_name == "eiger": diff --git a/bec_lib/tests/test_beamline_states.py b/bec_lib/tests/test_beamline_states.py index 29ef4ac77..ea9e8e44c 100644 --- a/bec_lib/tests/test_beamline_states.py +++ b/bec_lib/tests/test_beamline_states.py @@ -508,6 +508,190 @@ def test_aggregated_state_requirement_matches( assert state._requirement_matches(requirement) is matches + def test_device_config_requires_at_least_one_target(self): + with pytest.raises(ValueError, match="At least one of value"): + bl_states.DeviceConfig() + + def test_aggregated_state_endpoint_rejects_unknown_source(self): + with pytest.raises(ValueError, match="Invalid signal source"): + bl_states.AggregatedState._endpoint("samx", "unknown") + + def test_aggregated_state_get_device_manager_falls_back_to_client(self): + state = bl_states.AggregatedState( + name="alignment", states={"label": {"devices": {"samx": {"value": 0}}}} + ) + client = mock.MagicMock() + + with mock.patch("bec_lib.client.BECClient", return_value=client): + assert state._get_device_manager() is client.device_manager + + def test_aggregated_state_get_signal_source_rejects_unsupported_kind(self): + with pytest.raises(ValueError, match="Unsupported kind"): + bl_states.AggregatedState._get_signal_source( + {"kind_str": "omitted", "obj_name": "samx_unused"}, "test" + ) + + def test_aggregated_state_resolve_signal_edge_cases(self, dm_with_devices): + assert bl_states.AggregatedState._resolve_signal( + "samx", "low_limit_travel", dm_with_devices, "test" + ) == ("low", "limits") + assert bl_states.AggregatedState._resolve_signal( + "samx", "high_limit_travel", dm_with_devices, "test" + ) == ("high", "limits") + assert bl_states.AggregatedState._resolve_signal( + "samx", "samx_velocity", dm_with_devices, "test" + ) == ("samx_velocity", "configuration") + + with pytest.raises(ValueError, match="Device 'missing' not found"): + bl_states.AggregatedState._resolve_signal( + "missing", "missing", dm_with_devices, "test" + ) + with pytest.raises(ValueError, match="Device name must be a string"): + bl_states.AggregatedState._resolve_signal(1, "samx", dm_with_devices, "test") + with pytest.raises(ValueError, match="Signal 'missing_signal' not found"): + bl_states.AggregatedState._resolve_signal( + "samx", "missing_signal", dm_with_devices, "test" + ) + with pytest.raises(ValueError, match="Unsupported kind"): + bl_states.AggregatedState._resolve_signal("samx", "unused", dm_with_devices, "test") + + def test_aggregated_state_resolve_dotted_signal_edge_cases(self, dm_with_devices): + assert bl_states.AggregatedState._resolve_signal( + "samx", "samx.velocity", dm_with_devices, "test" + ) == ("samx_velocity", "configuration") + + with pytest.raises(ValueError, match="does not belong"): + bl_states.AggregatedState._resolve_signal( + "samx", "samy.velocity", dm_with_devices, "test" + ) + + devices = mock.MagicMock() + devices.__getitem__.side_effect = [dm_with_devices.devices["samx"], AttributeError] + manager = mock.MagicMock(devices=devices) + with pytest.raises(ValueError, match="Signal 'samx.missing' not found"): + bl_states.AggregatedState._resolve_signal("samx", "samx.missing", manager, "test") + + def test_aggregated_state_start_edge_cases( + self, connected_connector, dm_with_devices, aggregated_state_config + ): + state = bl_states.AggregatedState( + config=aggregated_state_config, + redis_connector=connected_connector, + device_manager=dm_with_devices, + ) + state.started = True + with mock.patch.object(state, "_build_rules") as build_rules: + state.start() + build_rules.assert_not_called() + + state = bl_states.AggregatedState( + config=aggregated_state_config, + redis_connector=None, + device_manager=dm_with_devices, + ) + with pytest.raises(RuntimeError, match="Redis connector is not set"): + state.start() + + def test_aggregated_state_start_handles_rule_build_error( + self, connected_connector, dm_with_devices, aggregated_state_config + ): + state = bl_states.AggregatedState( + config=aggregated_state_config, + redis_connector=connected_connector, + device_manager=dm_with_devices, + ) + + with ( + mock.patch.object(state, "_build_rules", side_effect=RuntimeError("bad rules")), + mock.patch.object(state, "_handle_state_exception") as handle_exception, + ): + state.start() + + handle_exception.assert_called_once() + assert state.started is True + + def test_aggregated_state_fill_cache_uses_existing_messages( + self, connected_connector, dm_with_devices, aggregated_state_config + ): + state = bl_states.AggregatedState( + config=aggregated_state_config, + redis_connector=connected_connector, + device_manager=dm_with_devices, + ) + state._build_rules() + connected_connector.set_and_publish( + MessageEndpoints.device_readback("samx"), + messages.DeviceMessage(signals={"samx": {"value": 0, "timestamp": 1.0}}), + ) + + affected_labels = state._fill_cache() + + assert affected_labels == {"alignment", "measurement"} + assert state._signal_value_cache[("samx", "readback", "samx")] == 0 + + def test_aggregated_state_cache_ignores_irrelevant_signals( + self, connected_connector, dm_with_devices, aggregated_state_config + ): + state = bl_states.AggregatedState( + config=aggregated_state_config, + redis_connector=connected_connector, + device_manager=dm_with_devices, + ) + state._build_rules() + + affected_labels = state._cache_message( + "samx", + "readback", + messages.DeviceMessage( + signals={"samx_unused": {"value": 1, "timestamp": 1.0}}, + metadata={"stream": "primary"}, + ), + ) + + assert affected_labels == set() + assert ("samx", "readback", "samx_unused") not in state._signal_value_cache + + def test_aggregated_state_stop_unregisters_subscriptions( + self, connected_connector, dm_with_devices, aggregated_state_config + ): + state = bl_states.AggregatedState( + config=aggregated_state_config, + redis_connector=connected_connector, + device_manager=dm_with_devices, + ) + state.start() + + with mock.patch.object(connected_connector, "unregister") as unregister: + state.stop() + + assert unregister.call_count == len(state._subscriptions) + assert state.started is False + + def test_aggregated_state_stop_is_noop_before_start( + self, connected_connector, dm_with_devices, aggregated_state_config + ): + state = bl_states.AggregatedState( + config=aggregated_state_config, + redis_connector=connected_connector, + device_manager=dm_with_devices, + ) + + with mock.patch.object(connected_connector, "unregister") as unregister: + state.stop() + + unregister.assert_not_called() + + def test_aggregated_state_evaluate_without_affected_labels( + self, connected_connector, dm_with_devices, aggregated_state_config + ): + state = bl_states.AggregatedState( + config=aggregated_state_config, + redis_connector=connected_connector, + device_manager=dm_with_devices, + ) + + assert state.evaluate() is None + class TestBeamlineStateManager: def test_manager_registers_for_state_updates(self, connected_connector): diff --git a/bec_server/bec_server/scan_server/scans/state_transition_scan.py b/bec_server/bec_server/scan_server/scans/state_transition_scan.py index fac356270..f5c2687c6 100644 --- a/bec_server/bec_server/scan_server/scans/state_transition_scan.py +++ b/bec_server/bec_server/scan_server/scans/state_transition_scan.py @@ -24,8 +24,8 @@ from bec_lib.endpoints import MessageEndpoints from bec_lib.logger import bec_logger from bec_lib.messages import AlarmMessage, ErrorInfo -from bec_server.scan_server.scans.scan_modifier import scan_hook from bec_server.scan_server.scans.scan_base import ScanBase +from bec_server.scan_server.scans.scan_modifier import scan_hook if TYPE_CHECKING: from bec_lib.bl_states import AggregatedStateConfig, ResolvedStateSignal @@ -114,20 +114,20 @@ def prepare_scan(self): if req.signal_name == req.device_name: self._devices_to_set.append((dev_obj, req.expected_value)) continue - if req.signal_name in ["low", "high"]: + if req.source == "limits": if req.device_name not in self._limits_to_set: self._limits_to_set[req.device_name] = ( dev_obj, dev_obj.low_limit, dev_obj.high_limit, ) - if req.signal_name == "low_limit": + if req.signal_name == "low": self._limits_to_set[req.device_name] = ( dev_obj, req.expected_value, self._limits_to_set[req.device_name][2], ) - else: + elif req.signal_name == "high": self._limits_to_set[req.device_name] = ( dev_obj, self._limits_to_set[req.device_name][1], @@ -245,6 +245,10 @@ def _fetch_config_for_label(self, state_name: str, target_label: str) -> SubDevi available_states_msg: AvailableBeamlineStatesMessage = self.redis_connector.get_last( MessageEndpoints.available_beamline_states() ) + if available_states_msg is None: + raise ValueError( + "No available beamline states found in Redis. Cannot fetch configuration for state transition scan." + ) configs = [ state for state in available_states_msg["data"].states if state.name == state_name ] From 442d1e2cdf55d93cd37aae5f0e8c64af593b6165 Mon Sep 17 00:00:00 2001 From: appel_c Date: Tue, 19 May 2026 11:07:02 +0200 Subject: [PATCH 09/13] refactor(aggregated-state): allow user parameters to be passed as value --- bec_lib/bec_lib/bl_states.py | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/bec_lib/bec_lib/bl_states.py b/bec_lib/bec_lib/bl_states.py index fd7491ab8..e97c23d4c 100644 --- a/bec_lib/bec_lib/bl_states.py +++ b/bec_lib/bec_lib/bl_states.py @@ -788,16 +788,29 @@ def _requirement_matches(self, requirement: ResolvedStateSignal) -> bool: if cached_value is None: return False + expected_value = requirement.expected_value + # If expected value is a user parameter, fetch the lates value from the device manager + if isinstance(expected_value, str) and expected_value.startswith("user_parameter:"): + # In this case, we fetch the latest user_parameter value from the device manager + dev_obj = self._get_device_manager().devices.get(requirement.device_name, None) + if dev_obj is None: + return False + expected_value = dev_obj.user_parameter.get( + expected_value.split("user_parameter:")[1], None + ) + if expected_value is None: + return False + try: # Cast to float to make sure comparison with abs works as expected. value = float(cached_value) - expected_value = float(requirement.expected_value) - return abs(value - expected_value) <= requirement.abs_tolerance + comparison_value = float(expected_value) + return abs(value - comparison_value) <= requirement.abs_tolerance # Catch TypeError and ValueError in case the value is not a number or cannot be cast to float, # in that case we fall back to exact equality. except (TypeError, ValueError): try: - result = cached_value == requirement.expected_value + result = cached_value == expected_value except (TypeError, ValueError): return False # In case this comparison runs on comparing two arrays. From e8eb8ab87716e0c9106510e7f2765dbc0580d31d Mon Sep 17 00:00:00 2001 From: appel_c Date: Tue, 19 May 2026 11:09:01 +0200 Subject: [PATCH 10/13] refactor: fix tests and cleanup --- bec_lib/tests/test_beamline_states.py | 14 +- .../devices/device_serializer.py | 2 + .../scans/state_transition_scan.py | 221 ++++++++++++------ .../scan_server/tests/scan_fixtures.py | 26 ++- .../scans_v4/test_state_transition_scan.py | 163 +++++++++++++ 5 files changed, 352 insertions(+), 74 deletions(-) create mode 100644 bec_server/tests/tests_scan_server/scans_v4/test_state_transition_scan.py diff --git a/bec_lib/tests/test_beamline_states.py b/bec_lib/tests/test_beamline_states.py index ea9e8e44c..a79591bdf 100644 --- a/bec_lib/tests/test_beamline_states.py +++ b/bec_lib/tests/test_beamline_states.py @@ -230,6 +230,7 @@ def aggregated_state_config(self): }, "test": {"devices": {"bpm4i": {"value": 0, "abs_tol": 0.1}}}, "string_state": {"devices": {"bpm3i": {"value": "ok"}}}, + "state_with_user_param": {"devices": {"samx": {"value": "user_parameter:test"}}}, }, ) @@ -350,6 +351,11 @@ def test_aggregated_state_evaluate( # The order of the labels is not guaranteed assert msg.label in ["alignment|test", "test|alignment"] assert set(state._current_labels) == set(["alignment", "test"]) + dm_with_devices.devices["samx"].user_parameter["test"] = 0 + msg = state.evaluate(affected_labels={"alignment", "state_with_user_param"}) + assert msg.status == "valid" + assert set(msg.label.split("|")) == set(["alignment", "state_with_user_param", "test"]) + assert set(state._current_labels) == set(["alignment", "state_with_user_param", "test"]) state._cache_message( "samx", @@ -543,9 +549,7 @@ def test_aggregated_state_resolve_signal_edge_cases(self, dm_with_devices): ) == ("samx_velocity", "configuration") with pytest.raises(ValueError, match="Device 'missing' not found"): - bl_states.AggregatedState._resolve_signal( - "missing", "missing", dm_with_devices, "test" - ) + bl_states.AggregatedState._resolve_signal("missing", "missing", dm_with_devices, "test") with pytest.raises(ValueError, match="Device name must be a string"): bl_states.AggregatedState._resolve_signal(1, "samx", dm_with_devices, "test") with pytest.raises(ValueError, match="Signal 'missing_signal' not found"): @@ -585,9 +589,7 @@ def test_aggregated_state_start_edge_cases( build_rules.assert_not_called() state = bl_states.AggregatedState( - config=aggregated_state_config, - redis_connector=None, - device_manager=dm_with_devices, + config=aggregated_state_config, redis_connector=None, device_manager=dm_with_devices ) with pytest.raises(RuntimeError, match="Redis connector is not set"): state.start() diff --git a/bec_server/bec_server/device_server/devices/device_serializer.py b/bec_server/bec_server/device_server/devices/device_serializer.py index 63bdc1e4e..1caf254b7 100644 --- a/bec_server/bec_server/device_server/devices/device_serializer.py +++ b/bec_server/bec_server/device_server/devices/device_serializer.py @@ -202,6 +202,8 @@ def get_device_info( "kind_str": signal_obj.kind.name, "doc": doc, "describe": signal_obj.describe().get(signal_obj.name, {}), + "read_access": getattr(signal_obj, "read_access", None), + "write_access": getattr(signal_obj, "write_access", None), # pylint: disable=protected-access "metadata": signal_obj._metadata, "labels": sorted(signal_obj._ophyd_labels_), diff --git a/bec_server/bec_server/scan_server/scans/state_transition_scan.py b/bec_server/bec_server/scan_server/scans/state_transition_scan.py index f5c2687c6..923dd1922 100644 --- a/bec_server/bec_server/scan_server/scans/state_transition_scan.py +++ b/bec_server/bec_server/scan_server/scans/state_transition_scan.py @@ -74,8 +74,6 @@ def __init__(self, *args, state_name: str, target_label: str, **kwargs): super().__init__(**kwargs) self.state_name = state_name self.target_label = target_label - # Check if the state and the target label exists, if yes, fetch the configuration for the target state - self.config_for_label = self._fetch_config_for_label(state_name, target_label) # We need to sort the devices and signals in the config, and identify which of them are motor setpoint/readback pairs # and which of them are just readouts and thereby can not be set within the transition. @@ -91,66 +89,17 @@ def prepare_scan(self): before the scan is opened, such as preparing the positions (if not done already) or setting up the devices. """ + # Check if the state and the target label exists, if yes, fetch the configuration for the target state + self.config_for_label = self._fetch_config_for_label(self.state_name, self.target_label) requirements: list[ResolvedStateSignal] = AggregatedState.get_state_requirements( self.target_label, self.config_for_label, self.device_manager, "StateTransitionScan" ) - for req in requirements: - dev_obj: DeviceBase = self.device_manager.devices.get(req.device_name) - # Device not found - if dev_obj is None: - raise StateTransitionScanError( - exc_type="DeviceNotFound", - message=f"Device {req.device_name} not found in device manager.", - compact_message=f"Device {req.device_name} not found.", - ) - # First we handle Signals logic - if isinstance(dev_obj, Signal): - self._signals_to_set.append((dev_obj, req.expected_value)) - continue - # Positioner and Device logic. Devices must implement .set for this to work, otherwise we can not set them and we raise an error - if isinstance(dev_obj, DeviceBase): - # Handle motor-specific logic here - # First we handle logic for motions of the motor. Device_name and signal_name will be equivalent here - if req.signal_name == req.device_name: - self._devices_to_set.append((dev_obj, req.expected_value)) - continue - if req.source == "limits": - if req.device_name not in self._limits_to_set: - self._limits_to_set[req.device_name] = ( - dev_obj, - dev_obj.low_limit, - dev_obj.high_limit, - ) - if req.signal_name == "low": - self._limits_to_set[req.device_name] = ( - dev_obj, - req.expected_value, - self._limits_to_set[req.device_name][2], - ) - elif req.signal_name == "high": - self._limits_to_set[req.device_name] = ( - dev_obj, - self._limits_to_set[req.device_name][1], - req.expected_value, - ) - continue - signal_obj = self._get_signal_object(dev_obj, req.signal_name) - if signal_obj is None: - raise StateTransitionScanError( - exc_type="SignalNotFound", - message=f"Signal {req.signal_name} for device {req.device_name} not found in device manager.", - compact_message=f"Signal {req.signal_name} for device {req.device_name} not found.", - ) - self._signals_to_set.append((signal_obj, req.expected_value)) - continue + self._signals_to_set, self._limits_to_set, self._devices_to_set = ( + self._fetch_devices_signals_and_limits_to_set(requirements) + ) self.update_scan_info(scan_report_devices=[dev for dev, _ in self._devices_to_set]) - def _get_signal_object(self, device_obj: DeviceBase, signal_name: str) -> Signal: - for component_name, info in device_obj._info["signals"].items(): - if info["obj_name"] == signal_name: - return getattr(device_obj, component_name) - @scan_hook def open_scan(self): """ @@ -160,6 +109,7 @@ def open_scan(self): prepare_scan() or in open_scan() itself and call self.update_scan_info(...) to update the scan metadata if needed. """ + self.actions.open_scan() @scan_hook def stage(self): @@ -169,6 +119,7 @@ def stage(self): However, if there are any additional steps that need to be executed before staging the devices, they can be implemented here. """ + self.actions.stage_all_devices() @scan_hook def pre_scan(self): @@ -179,6 +130,7 @@ def pre_scan(self): devices, e.g. devices that have a short timeout. The pre-scan logic is typically implemented on the device itself. """ + self.actions.pre_scan_all_devices() @scan_hook def scan_core(self): @@ -186,23 +138,22 @@ def scan_core(self): Core scan logic to be executed during the scan. This is where the main scan logic should be implemented. """ + # Set the signals first... because otherwise there can be an issue with the live updates if + # TODO we set the scan_report_instruction_readback for the motors and one of the signal is also a motor. + self._set_signals() + # Motors motors = [element[0] for element in self._devices_to_set] target_positions = [element[1] for element in self._devices_to_set] current_positions = self.components.get_start_positions(motors) - + # TODO Check how this can be managed in view of the live updates. If we move the signal section further down, + # We get issues with the DeviceProgressBar live updates, and in this ordering, we have an issue that multiple + # Live displays seem to be triggered. This has to be investigated with care. self.actions.add_scan_report_instruction_readback( - devices=motors, - start=current_positions, - stop=target_positions, - request_id=self.scan_info.metadata["RID"], + devices=motors, start=current_positions, stop=target_positions ) - self.components.move_and_wait(motors, target_positions) - # After the move is completed, we set the limits and signals. - for dev_name, (dev_obj, low_limit, high_limit) in self._limits_to_set.items(): - dev_obj.limits = [low_limit, high_limit] - for signal_obj, target_value in self._signals_to_set: - signal_obj.set(target_value).wait() + # Limits + self._set_limits() @scan_hook def at_each_point(self): @@ -220,14 +171,20 @@ def post_scan(self): """ Post-scan steps to be executed after the main scan logic. """ + self.actions.complete_all_devices() @scan_hook def unstage(self): """Unstage the scan by executing post-scan steps.""" + self.actions.unstage_all_devices() @scan_hook def close_scan(self): """Close the scan.""" + if self._baseline_readout_status is not None: + self._baseline_readout_status.wait() + self.actions.close_scan() + self.actions.check_for_unchecked_statuses() @scan_hook def on_exception(self, exception: Exception): @@ -241,6 +198,136 @@ def on_exception(self, exception: Exception): ## Custom Methods ################# + def _set_signals(self): + """Method to set signals in the transition.""" + for signal_obj, target_value in self._signals_to_set: + # Check if signal is writable before setting it, if not skip. + if self._check_if_signal_has_write_access(signal_obj): + signal_obj.set(target_value).wait() + + def _set_limits(self): + """Method to set limits for devices in the transition.""" + for dev_name, (dev_obj, low_limit, high_limit) in self._limits_to_set.items(): + dev_obj.limits = [low_limit, high_limit] + + def _check_if_signal_has_write_access(self, signal_obj: Signal) -> bool: + """ + Check if a signal has write access based on its signal information. The issue here is that signals of a + device follow a slightly different pattern. Therefore, we have to first check "_info" for signals + and if that is empty, check '_signal_info' for sub-signals of devices. + + Args: + signal_obj (Signal): Signal object to check. + Returns: + bool: True if the signal has write access, False otherwise. + """ + write_access = signal_obj._info.get("write_access", None) + if write_access is None: + write_access = signal_obj._signal_info.get("write_access", False) + return write_access + + def _fetch_devices_signals_and_limits_to_set( + self, requirements: list[ResolvedStateSignal] + ) -> Tuple[dict, dict, dict]: + """ + This method fetches the device signals, limits and devices to set based on a list of state requirements. + It returns a tuple containing three elements: + - signals_to_set (list[Tuple[Signal, Any]]): List of signals to set with their target values. + - limits_to_set (dict[str, Tuple[Positioner, float, float]]): Dictionary of devices and their limits to set. + - devices_to_set (list[Tuple[Positioner, float]]): List of devices to set with their target positions. + + Args: + requirements (list[ResolvedStateSignal]): List of state requirements to fetch the device signals and limits for. + + Returns: + Tuple containing: + - signals_to_set (list[Tuple[Signal, Any]]): List of signals to set with their target values. + - limits_to_set (dict[str, Tuple[Positioner, float, float]]): Dictionary of devices and their limits to set. + - devices_to_set (list[Tuple[Positioner, float]]): List of devices to set with their target positions. + """ + _signals_to_set: list[Tuple[Signal, Any]] = [] + _limits_to_set: dict[str, Tuple[Positioner, float, float]] = {} + _devices_to_set: list[Tuple[Positioner, float]] = [] + for req in requirements: + dev_obj: DeviceBase = self.device_manager.devices.get(req.device_name) + # Device not found + if dev_obj is None: + raise StateTransitionScanError( + exc_type="DeviceNotFound", + message=f"Device {req.device_name} not found in device manager.", + compact_message=f"Device {req.device_name} not found.", + ) + expected_value = self._get_expected_value(req) + # First we handle Signals logic + if isinstance(dev_obj, Signal): + _signals_to_set.append((dev_obj, expected_value)) + continue + # Positioner and Device logic. Devices must implement .set for this to work, otherwise we can not set them and we raise an error + if isinstance(dev_obj, DeviceBase): + # Handle motor-specific logic here + # First we handle logic for motions of the motor. Device_name and signal_name will be equivalent here + if req.signal_name == req.device_name: + _devices_to_set.append((dev_obj, expected_value)) + continue + if req.source == "limits": + if req.device_name not in _limits_to_set: + _limits_to_set[req.device_name] = ( + dev_obj, + dev_obj.low_limit, + dev_obj.high_limit, + ) + if req.signal_name == "low": + _limits_to_set[req.device_name] = ( + dev_obj, + expected_value, + _limits_to_set[req.device_name][2], + ) + elif req.signal_name == "high": + _limits_to_set[req.device_name] = ( + dev_obj, + _limits_to_set[req.device_name][1], + expected_value, + ) + continue + signal_obj = self._get_signal_object(dev_obj, req.signal_name) + if signal_obj is None: + raise StateTransitionScanError( + exc_type="SignalNotFound", + message=f"Signal {req.signal_name} for device {req.device_name} not found in device manager.", + compact_message=f"Signal {req.signal_name} for device {req.device_name} not found.", + ) + _signals_to_set.append((signal_obj, expected_value)) + continue + # Return the collected signals, limits and devices to set + return _signals_to_set, _limits_to_set, _devices_to_set + + def _get_expected_value(self, requirement: ResolvedStateSignal) -> Any: + expected_value = requirement.expected_value + # If expected value is a user parameter, fetch the lates value from the device manager + if isinstance(expected_value, str) and expected_value.startswith("user_parameter:"): + dev_obj = self.device_manager.devices.get(requirement.device_name, None) + if dev_obj is None: + raise StateTransitionScanError( + exc_type="DeviceNotFound", + message=f"Device {requirement.device_name} not found in device manager.", + compact_message=f"Device {requirement.device_name} not found.", + ) + expected_value = dev_obj.user_parameter.get( + expected_value.split("user_parameter:")[1], None + ) + if expected_value is None: + raise StateTransitionScanError( + exc_type="UserParameterNotFound", + message=f"User parameter {expected_value.split('user_parameter:')[1]} for device {requirement.device_name} not found in device manager.", + compact_message=f"User parameter {expected_value.split('user_parameter:')[1]} for device {requirement.device_name} not found.", + ) + return expected_value + + def _get_signal_object(self, device_obj: DeviceBase, signal_name: str) -> Signal: + for component_name, info in device_obj._info["signals"].items(): + if info["obj_name"] == signal_name: + return getattr(device_obj, component_name) + def _fetch_config_for_label(self, state_name: str, target_label: str) -> SubDeviceStateConfig: available_states_msg: AvailableBeamlineStatesMessage = self.redis_connector.get_last( MessageEndpoints.available_beamline_states() diff --git a/bec_server/bec_server/scan_server/tests/scan_fixtures.py b/bec_server/bec_server/scan_server/tests/scan_fixtures.py index e7f3b23e9..2c51f6aea 100644 --- a/bec_server/bec_server/scan_server/tests/scan_fixtures.py +++ b/bec_server/bec_server/scan_server/tests/scan_fixtures.py @@ -123,6 +123,18 @@ def full_name(self): def limits(self): return self._limits + @limits.setter + def limits(self, value): + self._limits = tuple(value) + + @property + def low_limit(self): + return self._limits[0] + + @property + def high_limit(self): + return self._limits[1] + @property def enabled(self): return self._enabled @@ -226,6 +238,18 @@ def full_name(self): def limits(self): return self._limits + @limits.setter + def limits(self, value): + self._limits = tuple(value) + + @property + def low_limit(self): + return self._limits[0] + + @property + def high_limit(self): + return self._limits[1] + @property def enabled(self): return self._enabled @@ -360,6 +384,7 @@ def v4_scan_assembler(readout_priority: ReadoutPriorityContainer, device_manager def _assemble_scan(scan_type, *scan_args, **scan_kwargs): scan_id = scan_kwargs.pop("scan_id", "scan-id-test") + connector = scan_kwargs.pop("connector", None) or ConnectorMock("") try: scan_cls = scan_classes[scan_type] @@ -367,7 +392,6 @@ def _assemble_scan(scan_type, *scan_args, **scan_kwargs): available = ", ".join(sorted(scan_classes)) raise KeyError(f"Unknown scan type '{scan_type}'. Available: {available}") from exc - connector = ConnectorMock("") instruction_handler = InstructionHandler(connector) device_names = sorted( set(_infer_v4_device_names(scan_cls, scan_args, scan_kwargs)) diff --git a/bec_server/tests/tests_scan_server/scans_v4/test_state_transition_scan.py b/bec_server/tests/tests_scan_server/scans_v4/test_state_transition_scan.py new file mode 100644 index 000000000..67127deeb --- /dev/null +++ b/bec_server/tests/tests_scan_server/scans_v4/test_state_transition_scan.py @@ -0,0 +1,163 @@ +from unittest import mock + +import pytest +from ophyd_devices.sim.sim_positioner import SimPositioner + +from bec_lib import messages +from bec_lib.device import Positioner, _PermissiveDeviceModel +from bec_lib.endpoints import MessageEndpoints +from bec_server.device_server.devices.device_serializer import get_device_info +from bec_server.scan_server.tests.scan_hook_tests import ( + assert_close_scan_waits_for_baseline_and_closes, + assert_pre_scan_called, + assert_prepare_scan_reads_baseline_devices, + assert_scan_open_called, + assert_stage_all_devices_called, + assert_unstage_all_devices_called, + run_scan_tests, +) + +ACQUIRE_DEFAULT_HOOK_TESTS = [ + ("open_scan", [assert_scan_open_called]), + ("stage", [assert_stage_all_devices_called]), + ("pre_scan", [assert_pre_scan_called]), + ("unstage", [assert_unstage_all_devices_called]), + ("close_scan", [assert_close_scan_waits_for_baseline_and_closes]), +] + + +@pytest.fixture +def state_transition_connector(connected_connector): + connected_connector.xadd( + MessageEndpoints.available_beamline_states(), + { + "data": messages.AvailableBeamlineStatesMessage( + states=[ + messages.BeamlineStateConfig( + name="test", + title="Test state", + state_type="AggregatedState", + parameters={ + "states": { + "alignment": { + "devices": { + "samx": { + "value": 1.5, + "low_limit": {"value": -2}, + "high_limit": {"value": 2}, + "signals": {"velocity": {"value": 0.5}}, + }, + "samy": { + "value": 0.5, + "low_limit": {"value": -1}, + "high_limit": {"value": 1}, + }, + } + } + } + }, + ) + ] + ) + }, + ) + return connected_connector + + +@pytest.fixture +def simulated_samx(device_manager): + # dev_obj = SimPositioner(name="samx") + name = "samx" + dev = SimPositioner(name=name) + config = _PermissiveDeviceModel( + enabled=True, + deviceClass="ophyd_devices.sim.sim_positioner.SimPositioner", + readoutPriority="baseline", + ) + info = get_device_info(dev, connect=True) + dev_man_obj = Positioner( + name=name, info=info, config=config, class_name=config.deviceClass, parent=device_manager + ) + return dev_man_obj + + +@pytest.mark.timeout(20) +@pytest.mark.parametrize(("hook_name", "hook_tests"), ACQUIRE_DEFAULT_HOOK_TESTS) +def test_state_transition_default_hooks( + v4_scan_assembler, state_transition_connector, nth_done_status_mock, hook_name, hook_tests +): + """Test default hooks open_scan, stage, pre_scan, unstage, and close_scan for the StateTransitionScan.""" + scan = v4_scan_assembler( + "_v4_state_transition", + state_name="test", + target_label="alignment", + connector=state_transition_connector, + ) + + run_scan_tests(scan, [(hook_name, hook_tests)], nth_done_status_mock=nth_done_status_mock) + + +@pytest.mark.timeout(20) +def test_state_transition_prepare_scan( + v4_scan_assembler, state_transition_connector, device_manager, simulated_samx +): + """Test prepare scan hook for the StateTransitionScan.""" + device_manager.add_device(simulated_samx, replace=True) + scan = v4_scan_assembler( + "_v4_state_transition", + state_name="test", + target_label="alignment", + connector=state_transition_connector, + ) + + scan.prepare_scan() + + devices_to_set = {(device.name, value) for device, value in scan._devices_to_set} + limits_to_set = { + device_name: (device.name, low_limit, high_limit) + for device_name, (device, low_limit, high_limit) in scan._limits_to_set.items() + } + signals_to_set = {(signal.full_name, value) for signal, value in scan._signals_to_set} + + assert devices_to_set == {("samx", 1.5), ("samy", 0.5)} + assert limits_to_set == {"samx": ("samx", -2, 2), "samy": ("samy", -1, 1)} + assert signals_to_set == {("samx_velocity", 0.5)} + + +@pytest.mark.timeout(20) +def test_state_transition_scan_core( + v4_scan_assembler, state_transition_connector, device_manager, simulated_samx +): + device_manager.add_device(simulated_samx, replace=True) + scan = v4_scan_assembler( + "_v4_state_transition", + state_name="test", + target_label="alignment", + connector=state_transition_connector, + ) + scan.prepare_scan() + signal_by_name = {signal.full_name: signal for signal, _ in scan._signals_to_set} + velocity_set_status = mock.MagicMock() + signal_by_name["samx_velocity"].set = mock.MagicMock(return_value=velocity_set_status) + scan.components.get_start_positions = mock.MagicMock(return_value=[0, 0]) + with ( + mock.patch.object( + scan.components, "get_start_positions", return_value=[0, 0] + ) as mock_get_start_positions, + mock.patch.object(scan.components, "move_and_wait") as mock_move_and_wait, + mock.patch.object( + scan.actions, "add_scan_report_instruction_readback" + ) as mock_add_scan_report_instruction_readback, + mock.patch.object(scan, "_set_limits") as mock_set_limits, + ): + scan.scan_core() + mock_get_start_positions.assert_called_once() + mock_add_scan_report_instruction_readback.assert_called_once_with( + devices=[scan.device_manager.devices["samx"], scan.device_manager.devices["samy"]], + start=[0, 0], + stop=[1.5, 0.5], + ) + mock_move_and_wait.assert_called_once_with( + [scan.device_manager.devices["samx"], scan.device_manager.devices["samy"]], [1.5, 0.5] + ) + mock_set_limits.assert_called_once() From 83bf44cdb28bc5d52a5209548a2c8d9ae3ce4c22 Mon Sep 17 00:00:00 2001 From: appel_c Date: Tue, 19 May 2026 14:55:43 +0200 Subject: [PATCH 11/13] fix: deactivate live updates for now --- .../scan_server/scans/state_transition_scan.py | 6 +++--- .../scans_v4/test_state_transition_scan.py | 10 +++++----- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/bec_server/bec_server/scan_server/scans/state_transition_scan.py b/bec_server/bec_server/scan_server/scans/state_transition_scan.py index 923dd1922..08af25fa9 100644 --- a/bec_server/bec_server/scan_server/scans/state_transition_scan.py +++ b/bec_server/bec_server/scan_server/scans/state_transition_scan.py @@ -148,9 +148,9 @@ def scan_core(self): # TODO Check how this can be managed in view of the live updates. If we move the signal section further down, # We get issues with the DeviceProgressBar live updates, and in this ordering, we have an issue that multiple # Live displays seem to be triggered. This has to be investigated with care. - self.actions.add_scan_report_instruction_readback( - devices=motors, start=current_positions, stop=target_positions - ) + # self.actions.add_scan_report_instruction_readback( + # devices=motors, start=current_positions, stop=target_positions + # ) self.components.move_and_wait(motors, target_positions) # Limits self._set_limits() diff --git a/bec_server/tests/tests_scan_server/scans_v4/test_state_transition_scan.py b/bec_server/tests/tests_scan_server/scans_v4/test_state_transition_scan.py index 67127deeb..39bdb40e4 100644 --- a/bec_server/tests/tests_scan_server/scans_v4/test_state_transition_scan.py +++ b/bec_server/tests/tests_scan_server/scans_v4/test_state_transition_scan.py @@ -152,11 +152,11 @@ def test_state_transition_scan_core( ): scan.scan_core() mock_get_start_positions.assert_called_once() - mock_add_scan_report_instruction_readback.assert_called_once_with( - devices=[scan.device_manager.devices["samx"], scan.device_manager.devices["samy"]], - start=[0, 0], - stop=[1.5, 0.5], - ) + # mock_add_scan_report_instruction_readback.assert_called_once_with( + # devices=[scan.device_manager.devices["samx"], scan.device_manager.devices["samy"]], + # start=[0, 0], + # stop=[1.5, 0.5], + # ) mock_move_and_wait.assert_called_once_with( [scan.device_manager.devices["samx"], scan.device_manager.devices["samy"]], [1.5, 0.5] ) From 2a99e7a6bca4ac7941b2fa66b799f64b93db79cd Mon Sep 17 00:00:00 2001 From: appel_c Date: Tue, 19 May 2026 14:57:53 +0200 Subject: [PATCH 12/13] refactor(bl_states): adjust ClassVar[str] typehint to be str|None --- bec_lib/bec_lib/bl_states.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/bec_lib/bec_lib/bl_states.py b/bec_lib/bec_lib/bl_states.py index e97c23d4c..d7b61d54f 100644 --- a/bec_lib/bec_lib/bl_states.py +++ b/bec_lib/bec_lib/bl_states.py @@ -7,7 +7,7 @@ import traceback from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import Any, Callable, ClassVar, Generic, Literal, Type, TypeVar, cast +from typing import Any, Callable, Generic, Literal, Type, TypeVar, cast import yaml from pydantic import BaseModel, field_validator, model_validator @@ -58,7 +58,7 @@ class BeamlineStateConfig(BaseModel): Base Configuration for a beamline state. """ - state_type: ClassVar[str] = "BeamlineState" + state_type: str | None = "BeamlineState" name: str title: str | None = None @@ -85,7 +85,7 @@ class DeviceStateConfig(BeamlineStateConfig): Configuration for a device-based beamline state. """ - state_type: ClassVar[str] = "DeviceBeamlineState" + state_type: str | None = "DeviceBeamlineState" device: DeviceBase | str signal: DeviceBase | str | None = None @@ -118,7 +118,7 @@ class DeviceWithinLimitsStateConfig(DeviceStateConfig): Configuration for a device within limits beamline state. """ - state_type: ClassVar[str] = "DeviceWithinLimitsState" + state_type: str | None = "DeviceWithinLimitsState" low_limit: float | None = None high_limit: float | None = None @@ -175,7 +175,7 @@ class AggregatedStateConfig(BeamlineStateConfig): Keys of the states dictionary are the labels of the different states. """ - state_type: ClassVar[str] = "AggregatedState" + state_type: str | None = "AggregatedState" states: dict[str, SubDeviceStateConfig] From 95b9f4f737fb2120f9ab1bcccf1efe1277a2933c Mon Sep 17 00:00:00 2001 From: appel_c Date: Tue, 19 May 2026 15:19:13 +0200 Subject: [PATCH 13/13] test(aggregated-state): fix tests after adjusted config. --- bec_lib/tests/test_beamline_states.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/bec_lib/tests/test_beamline_states.py b/bec_lib/tests/test_beamline_states.py index a79591bdf..2a930bc8a 100644 --- a/bec_lib/tests/test_beamline_states.py +++ b/bec_lib/tests/test_beamline_states.py @@ -291,7 +291,9 @@ def test_aggregated_state_evaluation( value=msg_with_2_states, topic=MessageEndpoints.device_readback("samx").endpoint ) state._update_aggregated_state(msg_obj, device="samx", source="readback") - evaluate.assert_called_once_with(affected_labels=set(["alignment", "measurement"])) + evaluate.assert_called_once_with( + affected_labels=set(["state_with_user_param", "alignment", "measurement"]) + ) emit_state.assert_not_called() # As evaluate is mocked to return None, _emit_state should not be called def test_aggregated_state_evaluate( @@ -413,7 +415,9 @@ def test_aggregated_state_exception_handling( ): state._update_aggregated_state(msg_obj, device="samx", source="readback") - evaluate.assert_called_once_with(affected_labels={"alignment", "measurement"}) + evaluate.assert_called_once_with( + affected_labels={"state_with_user_param", "alignment", "measurement"} + ) raise_alarm.assert_called_once() out = connected_connector.xread( MessageEndpoints.beamline_state("alignment"), from_start=True @@ -628,7 +632,7 @@ def test_aggregated_state_fill_cache_uses_existing_messages( affected_labels = state._fill_cache() - assert affected_labels == {"alignment", "measurement"} + assert affected_labels == {"alignment", "measurement", "state_with_user_param"} assert state._signal_value_cache[("samx", "readback", "samx")] == 0 def test_aggregated_state_cache_ignores_irrelevant_signals(