Skip to content

Commit

Permalink
Refactor driver interface + warnings
Browse files Browse the repository at this point in the history
  • Loading branch information
PeterKraus committed Jul 15, 2024
1 parent 6f23f82 commit af20439
Show file tree
Hide file tree
Showing 4 changed files with 127 additions and 94 deletions.
4 changes: 2 additions & 2 deletions src/tomato/daemon/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,9 +190,9 @@ def action_queued_jobs(daemon, matched, req):

jpath = root / "jobdata.json"
jobargs = {
"pipeline": pip.dict(),
"pipeline": pip.model_dump(),
"payload": job.payload.model_dump(),
"devices": {dname: dev.dict() for dname, dev in daemon.devs.items()},
"devices": {dn: dev.model_dump() for dn, dev in daemon.devs.items()},
"job": dict(id=job.id, path=str(root)),
}

Expand Down
209 changes: 121 additions & 88 deletions src/tomato/driverinterface_1_0/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from abc import ABCMeta, abstractmethod
from typing import TypeVar, Any, Literal
from typing import TypeVar, Any
from pydantic import BaseModel
from threading import Thread, currentThread
from queue import Queue
Expand All @@ -8,7 +8,8 @@
import logging
from functools import wraps
from xarray import Dataset

from collections import defaultdict
import time

logger = logging.getLogger(__name__)

Expand All @@ -26,20 +27,23 @@ def wrapper(self, **kwargs):
return wrapper


class ModelInterface(metaclass=ABCMeta):
version: Literal = "1.0"
T = TypeVar("T")


class Attr(BaseModel):
"""Class used to describe device attributes."""

type: T
rw: bool = False
status: bool = False

class Attr(BaseModel):
"""Class used to describe device attributes."""

type: TypeVar("T")
rw: bool = False
status: bool = False
class DriverInterface(metaclass=ABCMeta):
version: str = "1.0"

class DeviceInterface(metaclass=ABCMeta):
driver: object
data: list
status: dict
data: dict[str, list]
key: tuple
thread: Thread
task_list: Queue
Expand All @@ -49,38 +53,82 @@ def __init__(self, driver, key, **kwargs):
self.driver = driver
self.key = key
self.task_list = Queue()
self.thread = Thread(target=self._worker_wrapper, daemon=True)
self.data = []
self.status = {}
self.thread = Thread(target=self.task_runner, daemon=True)
self.data = defaultdict(list)
self.running = False

def run(self):
self.thread.do_run = True
self.thread.start()
self.running = True

def _worker_wrapper(self):
def task_runner(self):
thread = currentThread()
task = self.task_list.get()

self.task_runner(task, thread)
task: Task = self.task_list.get()
self.prepare_task(task)
t0 = time.perf_counter()
tD = t0
self.data = defaultdict(list)
while getattr(thread, "do_run"):
tN = time.perf_counter()
if tN - tD > task.sampling_interval:
self.do_task(task, t0=t0, tN=tN, tD=tD)
tD += task.sampling_interval
if tN - t0 > task.max_duration:
break
time.sleep(max(1e-2, task.sampling_interval / 10))

self.task_list.task_done()
self.running = False
self.thread = Thread(target=self._worker_wrapper, daemon=True)
self.thread = Thread(target=self.task_runner, daemon=True)

def prepare_task(self, task: Task, **kwargs: dict):
for k, v in task.technique_params.items():
self.set_attr(attr=k, val=v)

@abstractmethod
def do_task(self, task: Task, **kwargs: dict):
pass

def stop_task(self, **kwargs: dict):
setattr(self.thread, "do_run", False)

@abstractmethod
def set_attr(self, attr: str, val: Any, **kwargs: dict):
pass

@abstractmethod
def task_runner(task: Task, thread: Thread):
def get_attr(self, attr: str, **kwargs: dict) -> Any:
pass

def get_data(self, **kwargs: dict) -> dict[str, list]:
ret = self.data
self.data = defaultdict(list)
return ret

@abstractmethod
def attrs(**kwargs) -> dict:
pass

@abstractmethod
def tasks(**kwargs) -> set:
pass

def status(self, **kwargs) -> dict:
status = {}
for attr, props in self.attrs().items():
if props.status:
status[attr] = self.get_attr(attr)
return status

def CreateDeviceInterface(self, key, **kwargs):
"""Factory function which passes DriverInterface to the DeviceInterface"""
return self.DeviceInterface(self, key, **kwargs)

devmap: dict[tuple, DeviceInterface]
"""Map of registered devices, the tuple keys are components = (address, channel)"""

settings: dict[str, str]
settings: dict[str, Any]
"""A settings map to contain driver-specific settings such as `dllpath` for BioLogic"""

def __init__(self, settings=None):
Expand All @@ -94,7 +142,7 @@ def dev_register(self, address: str, channel: int, **kwargs: dict) -> None:
updating existing channels in :obj:`self.devmap`.
"""
key = (address, channel)
self.devmap[(address, channel)] = self.CreateDeviceInterface(key, **kwargs)
self.devmap[key] = self.CreateDeviceInterface(key, **kwargs)

def dev_teardown(self, address: str, channel: int, **kwargs: dict) -> None:
"""
Expand All @@ -105,44 +153,42 @@ def dev_teardown(self, address: str, channel: int, **kwargs: dict) -> None:
"""
pass

@abstractmethod
def attrs(self, address: str, channel: int, **kwargs) -> dict[str, Attr]:
"""
Function that returns all gettable and settable attributes, their rw status,
and whether they are to be returned in :func:`self.dev_status`. All attrs are
returned by :func:`self.dev_get_data`.
This is the "low level" control interface, intended for the device dashboard.
Example:
::
return dict(
delay = self.Attr(type=float, rw=True, status=False),
time = self.Attr(type=float, rw=True, status=False),
started = self.Attr(type=bool, rw=True, status=True),
val = self.Attr(type=int, rw=False, status=True),
)
"""
pass
@in_devmap
def attrs(self, address: str, channel: int, **kwargs) -> Reply | None:
key = (address, channel)
ret = self.devmap[key].attrs(**kwargs)
return Reply(
success=True,
msg=f"attrs of component {key} are: {ret}",
data=ret,
)

@in_devmap
def dev_set_attr(self, attr: str, val: Any, address: str, channel: int, **kwargs):
def dev_set_attr(
self, attr: str, val: Any, address: str, channel: int, **kwargs
) -> Reply | None:
key = (address, channel)
if attr in self.attrs():
params = self.attrs()[attr]
if params.rw and isinstance(val, params.type):
self.devmap[key].status[attr] = val
self.devmap[key].set_attr(attr=attr, val=val, **kwargs)
return Reply(
success=True,
msg=f"attr {attr!r} of component {key} set to {val}",
data=val,
)

@in_devmap
def dev_get_attr(self, attr: str, address: str, channel: int, **kwargs):
def dev_get_attr(
self, attr: str, address: str, channel: int, **kwargs
) -> Reply | None:
key = (address, channel)
if attr in self.attrs(address=address, channel=channel):
return self.devmap[key].status[attr]
ret = self.devmap[key].get_attr(attr=attr, **kwargs)
return Reply(
success=True,
msg=f"attr {attr!r} of component {key} is: {ret}",
data=ret,
)

@in_devmap
def dev_status(self, address: str, channel: int, **kwargs):
def dev_status(self, address: str, channel: int, **kwargs) -> Reply | None:
key = (address, channel)
running = self.devmap[key].running
return Reply(
Expand All @@ -152,12 +198,15 @@ def dev_status(self, address: str, channel: int, **kwargs):
)

@in_devmap
def task_start(self, address: str, channel: int, task: Task, **kwargs):
if task.technique_name not in self.tasks(address=address, channel=channel):
def task_start(
self, address: str, channel: int, task: Task, **kwargs
) -> Reply | None:
key = (address, channel)
if task.technique_name not in self.devmap[key].tasks(**kwargs):
return Reply(
success=False,
msg=f"unknown task {task.technique_name!r} requested",
data=self.tasks(),
data=self.tasks(address=address, channel=channel),
)

key = (address, channel)
Expand All @@ -179,8 +228,11 @@ def task_status(self, address: str, channel: int):
return Reply(success=True, msg="running")

@in_devmap
def task_stop(self, address: str, channel: int):
self.dev_set_attr(attr="started", val=False, address=address, channel=channel)
def task_stop(self, address: str, channel: int, **kwargs) -> Reply | None:
key = (address, channel)
ret = self.devmap[key].stop_task(**kwargs)
if ret is not None:
return Reply(success=False, msg="failed to stop task", data=ret)

ret = self.task_data(self, address, channel)
if ret.success:
Expand All @@ -189,27 +241,16 @@ def task_stop(self, address: str, channel: int):
return Reply(success=True, msg=f"task stopped, {ret.msg}")

@in_devmap
def task_data(self, address: str, channel: int, **kwargs):
def task_data(self, address: str, channel: int, **kwargs) -> Reply | None:
key = (address, channel)
data = self.devmap[key].data
self.devmap[key].data = []
data = self.devmap[key].get_data(**kwargs)

if len(data) == 0:
return Reply(success=False, msg="found no new datapoints")

data_vars = {}
for ii, item in enumerate(data):
for k, v in item.items():
if k not in data_vars:
data_vars[k] = [None] * ii
data_vars[k].append(v)
for k in data_vars:
if k not in item:
data_vars[k].append(None)

uts = {"uts": data_vars.pop("uts")}
data_vars = {k: ("uts", v) for k, v in data_vars.items()}
ds = Dataset(data_vars=data_vars, coords=uts)
uts = {"uts": data.pop("uts")}
data = {k: ("uts", v) for k, v in data.items()}
ds = Dataset(data_vars=data, coords=uts)
return Reply(success=True, msg=f"found {len(data)} new datapoints", data=ds)

def status(self):
Expand Down Expand Up @@ -237,19 +278,11 @@ def dev_get_data(self, address: str, channel: int, **kwargs):
)
return ret

@abstractmethod
def tasks(self, address: str, channel: int, **kwargs) -> dict:
"""
Function that returns all tasks that can be submitted to the Device. This
implements the driver specific language. Each task in tasks can only contain
elements present in :func:`self.attrs`.
Example:
::
return dict(
count = dict(time = dict(type=float), delay = dict(type=float),
)
"""
pass
key = (address, channel)
ret = self.devmap[key].tasks(**kwargs)
return Reply(
success=True,
msg=f"tasks supported by component {key} are: {ret}",
data=ret,
)
4 changes: 2 additions & 2 deletions src/tomato/drivers.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@
import logging

from typing import Union
from tomato.driverinterface_1_0 import ModelInterface
from tomato.driverinterface_1_0 import DriverInterface

logger = logging.getLogger(__name__)


def driver_to_interface(drivername: str) -> Union[None, ModelInterface]:
def driver_to_interface(drivername: str) -> Union[None, DriverInterface]:
modname = f"tomato_{drivername.replace('-', '_')}"

try:
Expand Down
4 changes: 2 additions & 2 deletions src/tomato/tomato/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,7 +395,7 @@ def reload(
logger.debug(f"{ret=}")
if ret.success is False:
return ret
params = dev.dict()
params = dev.model_dump()
ret = _updater(context, port, "device", params)
if ret.success is False:
return ret
Expand All @@ -408,7 +408,7 @@ def reload(
for pip in pips.values():
logger.debug(f"{pip=}")
if pip.name not in daemon.pips:
ret = _updater(context, port, "pipeline", pip.dict())
ret = _updater(context, port, "pipeline", pip.model_dump())
if ret.success is False:
return ret
else:
Expand Down

0 comments on commit af20439

Please sign in to comment.