Skip to content

Commit

Permalink
Add a Watchable AsyncStatus and extend the wrap decorator (#176)
Browse files Browse the repository at this point in the history
* Re: issues (#117) (#45) 
* Adds a WatchableAsyncStatus which wraps an AsyncIterator
* Lets AsyncStatus.wrap and WatchableAsyncStatus.wrap decorate any function which
  returns the right type
* Updates motors, flyers etc. to match
* Tests the above

---------

Co-authored-by: Tom C (DLS) <101418278+coretl@users.noreply.github.com>
  • Loading branch information
dperl-dls and coretl committed May 17, 2024
1 parent 4c4da2c commit 1c0e20e
Show file tree
Hide file tree
Showing 15 changed files with 646 additions and 219 deletions.
3 changes: 2 additions & 1 deletion src/ophyd_async/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
ShapeProvider,
StaticDirectoryProvider,
)
from .async_status import AsyncStatus
from .async_status import AsyncStatus, WatchableAsyncStatus
from .detector import (
DetectorControl,
DetectorTrigger,
Expand Down Expand Up @@ -96,6 +96,7 @@
"set_mock_value",
"wait_for_value",
"AsyncStatus",
"WatchableAsyncStatus",
"DirectoryInfo",
"DirectoryProvider",
"NameProvider",
Expand Down
131 changes: 96 additions & 35 deletions src/ophyd_async/core/async_status.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,30 +2,42 @@

import asyncio
import functools
from typing import Awaitable, Callable, Coroutine, List, Optional, cast
import time
from dataclasses import asdict, replace
from typing import (
AsyncIterator,
Awaitable,
Callable,
Generic,
SupportsFloat,
Type,
TypeVar,
cast,
)

from bluesky.protocols import Status

from .utils import Callback, T
from ..protocols import Watcher
from .utils import Callback, P, T, WatcherUpdate

AS = TypeVar("AS", bound="AsyncStatus")
WAS = TypeVar("WAS", bound="WatchableAsyncStatus")

class AsyncStatus(Status):

class AsyncStatusBase(Status):
"""Convert asyncio awaitable to bluesky Status interface"""

def __init__(
self,
awaitable: Awaitable,
watchers: Optional[List[Callable]] = None,
):
def __init__(self, awaitable: Awaitable, timeout: SupportsFloat | None = None):
if isinstance(timeout, SupportsFloat):
timeout = float(timeout)
if isinstance(awaitable, asyncio.Task):
self.task = awaitable
else:
self.task = asyncio.create_task(awaitable) # type: ignore

self.task = asyncio.create_task(
asyncio.wait_for(awaitable, timeout=timeout)
)
self.task.add_done_callback(self._run_callbacks)

self._callbacks = cast(List[Callback[Status]], [])
self._watchers = watchers
self._callbacks: list[Callback[Status]] = []

def __await__(self):
return self.task.__await__()
Expand All @@ -41,15 +53,11 @@ def _run_callbacks(self, task: asyncio.Task):
for callback in self._callbacks:
callback(self)

# TODO: remove ignore and bump min version when bluesky v1.12.0 is released
def exception( # type: ignore
self, timeout: Optional[float] = 0.0
) -> Optional[BaseException]:
def exception(self, timeout: float | None = 0.0) -> BaseException | None:
if timeout != 0.0:
raise Exception(
raise ValueError(
"cannot honour any timeout other than 0 in an asynchronous function"
)

if self.task.done():
try:
return self.task.exception()
Expand All @@ -69,22 +77,6 @@ def success(self) -> bool:
and self.task.exception() is None
)

def watch(self, watcher: Callable):
"""Add watcher to the list of interested parties.
Arguments as per Bluesky :external+bluesky:meth:`watch` protocol.
"""
if self._watchers is not None:
self._watchers.append(watcher)

@classmethod
def wrap(cls, f: Callable[[T], Coroutine]) -> Callable[[T], "AsyncStatus"]:
@functools.wraps(f)
def wrap_f(self) -> AsyncStatus:
return AsyncStatus(f(self))

return wrap_f

def __repr__(self) -> str:
if self.done:
if e := self.exception():
Expand All @@ -96,3 +88,72 @@ def __repr__(self) -> str:
return f"<{type(self).__name__}, task: {self.task.get_coro()}, {status}>"

__str__ = __repr__


class AsyncStatus(AsyncStatusBase):
@classmethod
def wrap(cls: Type[AS], f: Callable[P, Awaitable]) -> Callable[P, AS]:
@functools.wraps(f)
def wrap_f(*args: P.args, **kwargs: P.kwargs) -> AS:
# We can't type this more properly because Concatenate/ParamSpec doesn't
# yet support keywords
# https://peps.python.org/pep-0612/#concatenating-keyword-parameters
timeout = kwargs.get("timeout")
assert isinstance(timeout, SupportsFloat) or timeout is None
return cls(f(*args, **kwargs), timeout=timeout)

# type is actually functools._Wrapped[P, Awaitable, P, AS]
# but functools._Wrapped is not necessarily available
return cast(Callable[P, AS], wrap_f)


class WatchableAsyncStatus(AsyncStatusBase, Generic[T]):
"""Convert AsyncIterator of WatcherUpdates to bluesky Status interface."""

def __init__(
self,
iterator: AsyncIterator[WatcherUpdate[T]],
timeout: SupportsFloat | None = None,
):
self._watchers: list[Watcher] = []
self._start = time.monotonic()
self._last_update: WatcherUpdate[T] | None = None
super().__init__(self._notify_watchers_from(iterator), timeout)

async def _notify_watchers_from(self, iterator: AsyncIterator[WatcherUpdate[T]]):
async for update in iterator:
self._last_update = (
update
if update.time_elapsed is not None
else replace(update, time_elapsed=time.monotonic() - self._start)
)
for watcher in self._watchers:
self._update_watcher(watcher, self._last_update)

def _update_watcher(self, watcher: Watcher, update: WatcherUpdate[T]):
vals = asdict(
update, dict_factory=lambda d: {k: v for k, v in d if v is not None}
)
watcher(**vals)

def watch(self, watcher: Watcher):
self._watchers.append(watcher)
if self._last_update:
self._update_watcher(watcher, self._last_update)

@classmethod
def wrap(
cls: Type[WAS],
f: Callable[P, AsyncIterator[WatcherUpdate[T]]],
) -> Callable[P, WAS]:
"""Wrap an AsyncIterator in a WatchableAsyncStatus. If it takes
'timeout' as an argument, this must be a float or None, and it
will be propagated to the status."""

@functools.wraps(f)
def wrap_f(*args: P.args, **kwargs: P.kwargs) -> WAS:
timeout = kwargs.get("timeout")
assert isinstance(timeout, SupportsFloat) or timeout is None
return cls(f(*args, **kwargs), timeout=timeout)

return cast(Callable[P, WAS], wrap_f)
50 changes: 22 additions & 28 deletions src/ophyd_async/core/detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,9 @@

from ophyd_async.protocols import AsyncConfigurable, AsyncReadable

from .async_status import AsyncStatus
from .async_status import AsyncStatus, WatchableAsyncStatus
from .device import Device
from .utils import DEFAULT_TIMEOUT, merge_gathered_dicts
from .utils import DEFAULT_TIMEOUT, WatcherUpdate, merge_gathered_dicts

T = TypeVar("T")

Expand Down Expand Up @@ -188,7 +188,7 @@ def __init__(
self._trigger_info: Optional[TriggerInfo] = None
# For kickoff
self._watchers: List[Callable] = []
self._fly_status: Optional[AsyncStatus] = None
self._fly_status: Optional[WatchableAsyncStatus] = None
self._fly_start: float

self._intial_frame: int
Expand Down Expand Up @@ -292,43 +292,37 @@ async def _prepare(self, value: T) -> None:
f"Detector {self.controller} needs at least {required}s deadtime, "
f"but trigger logic provides only {self._trigger_info.deadtime}s"
)

self._arm_status = await self.controller.arm(
num=self._trigger_info.num,
trigger=self._trigger_info.trigger,
exposure=self._trigger_info.livetime,
)

@AsyncStatus.wrap
async def kickoff(self) -> None:
self._fly_status = AsyncStatus(self._fly(), self._watchers)
self._fly_start = time.monotonic()

async def _fly(self) -> None:
await self._observe_writer_indicies(self._last_frame)

async def _observe_writer_indicies(self, end_observation: int):
@AsyncStatus.wrap
async def kickoff(self):
if not self._arm_status:
raise Exception("Detector not armed!")

@WatchableAsyncStatus.wrap
async def complete(self):
assert self._arm_status, "Prepare not run"
assert self._trigger_info
async for index in self.writer.observe_indices_written(
self._frame_writing_timeout
):
for watcher in self._watchers:
watcher(
name=self.name,
current=index,
initial=self._initial_frame,
target=end_observation,
unit="",
precision=0,
time_elapsed=time.monotonic() - self._fly_start,
)
if index >= end_observation:
yield WatcherUpdate(
name=self.name,
current=index,
initial=self._initial_frame,
target=self._trigger_info.num,
unit="",
precision=0,
time_elapsed=time.monotonic() - self._fly_start,
)
if index >= self._trigger_info.num:
break

@AsyncStatus.wrap
async def complete(self) -> AsyncStatus:
assert self._fly_status, "Kickoff not run"
return await self._fly_status

async def describe_collect(self) -> Dict[str, DataKey]:
return self._describe

Expand Down
21 changes: 18 additions & 3 deletions src/ophyd_async/core/signal.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
Location,
Movable,
Reading,
Status,
Subscribable,
)

Expand Down Expand Up @@ -390,14 +391,18 @@ def assert_emitted(docs: Mapping[str, list[dict]], **numbers: int):
)


async def observe_value(signal: SignalR[T], timeout=None) -> AsyncGenerator[T, None]:
async def observe_value(
signal: SignalR[T], timeout=None, done_status: Status | None = None
) -> AsyncGenerator[T, None]:
"""Subscribe to the value of a signal so it can be iterated from.
Parameters
----------
signal:
Call subscribe_value on this at the start, and clear_sub on it at the
end
done_status:
If this status is complete, stop observing and make the iterator return.
Notes
-----
Expand All @@ -406,18 +411,28 @@ async def observe_value(signal: SignalR[T], timeout=None) -> AsyncGenerator[T, N
async for value in observe_value(sig):
do_something_with(value)
"""
q: asyncio.Queue[T] = asyncio.Queue()

class StatusIsDone: ...

q: asyncio.Queue[T | StatusIsDone] = asyncio.Queue()
if timeout is None:
get_value = q.get
else:

async def get_value():
return await asyncio.wait_for(q.get(), timeout)

if done_status is not None:
done_status.add_callback(lambda _: q.put_nowait(StatusIsDone()))

signal.subscribe_value(q.put_nowait)
try:
while True:
yield await get_value()
item = await get_value()
if not isinstance(item, StatusIsDone):
yield item
else:
break
finally:
signal.clear_sub(q.put_nowait)

Expand Down
19 changes: 19 additions & 0 deletions src/ophyd_async/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,16 @@

import asyncio
import logging
from dataclasses import dataclass
from typing import (
Awaitable,
Callable,
Dict,
Generic,
Iterable,
List,
Optional,
ParamSpec,
Type,
TypeVar,
Union,
Expand All @@ -18,6 +21,7 @@
from bluesky.protocols import Reading

T = TypeVar("T")
P = ParamSpec("P")
Callback = Callable[[T], None]

#: A function that will be called with the Reading and value when the
Expand Down Expand Up @@ -77,6 +81,21 @@ def __str__(self) -> str:
return self.format_error_string(indent="")


@dataclass(frozen=True)
class WatcherUpdate(Generic[T]):
"""A dataclass such that, when expanded, it provides the kwargs for a watcher"""

current: T
initial: T
target: T
name: str | None = None
unit: str | None = None
precision: float | None = None
fraction: float | None = None
time_elapsed: float | None = None
time_remaining: float | None = None


async def wait_for_connection(**coros: Awaitable[None]):
"""Call many underlying signals, accumulating exceptions and returning them
Expand Down
Loading

0 comments on commit 1c0e20e

Please sign in to comment.