diff --git a/custom_components/pyscript/__init__.py b/custom_components/pyscript/__init__.py index 79e5fc5..a30da90 100644 --- a/custom_components/pyscript/__init__.py +++ b/custom_components/pyscript/__init__.py @@ -44,6 +44,7 @@ UNSUB_LISTENERS, WATCHDOG_TASK, ) +from .decorator import DecoratorRegistry from .eval import AstEval from .event import Event from .function import Function @@ -270,6 +271,7 @@ async def async_setup_entry(hass: HomeAssistant, config_entry: ConfigEntry) -> b Webhook.init(hass) State.register_functions() GlobalContextMgr.init() + DecoratorRegistry.init(hass, config_entry) pyscript_folder = hass.config.path(FOLDER) if not await hass.async_add_executor_job(os.path.isdir, pyscript_folder): diff --git a/custom_components/pyscript/decorator.py b/custom_components/pyscript/decorator.py new file mode 100644 index 0000000..7273c63 --- /dev/null +++ b/custom_components/pyscript/decorator.py @@ -0,0 +1,288 @@ +from __future__ import annotations + +import ast +import asyncio +import logging +import os +import weakref +from typing import Type, ClassVar, Any, TypeVar + +from homeassistant.config_entries import ConfigEntry +from homeassistant.core import HomeAssistant, Context + +from .decorator_abc import ( + Decorator, + DecoratorManager, + DispatchData, + DecoratorManagerStatus, + TriggerHandlerDecorator, + CallHandlerDecorator, + TriggerDecorator, + CallResultHandlerDecorator, +) +from .eval import AstEval, EvalFunc, EvalFuncVar +from .function import Function +from .state import State + +_LOGGER = logging.getLogger(__name__) + + +class DecoratorRegistry: + """Decorator registry.""" + + _decorators: dict[str, Type[Decorator]] # decorator name to class + hass: ClassVar[HomeAssistant] + prefix: ClassVar[str] = "e" + + @classmethod + def init(cls, hass: HomeAssistant, config_entry: ConfigEntry = None) -> None: + """Initialize the decorator registry.""" + cls.hass = hass + cls._decorators = {} + enabled = False + if "PYTEST_CURRENT_TEST" in os.environ: + enabled = "NODM" not in os.environ + elif config_entry is not None and config_entry.data.get("dm", False): + enabled = True + + if enabled: + cls.prefix = "" + space = "\n" + " " * 35 + border = space + "=" * 35 + _LOGGER.warning(border + space + "DecoratorManager enabled by default" + border) + else: + cls.prefix = "e" + + DecoratorManager.hass = hass + + Function.register_ast({cls.prefix + "task.wait_until": DecoratorRegistry.wait_until_factory}) + + from .decorators import DECORATORS + + for dec_type in DECORATORS: + cls.register(dec_type) + + @classmethod + def register(cls, dec_type: Type[Decorator]): + """Register a decorator.""" + if not dec_type.name: + raise TypeError(f"Decorator name is required {dec_type}") + + name = cls.prefix + dec_type.name + _LOGGER.debug("Registering decorator @%s %s", name, dec_type) + if name in cls._decorators: + _LOGGER.warning("Overriding decorator: %s %s with %s", name, cls._decorators[name], dec_type) + cls._decorators[name] = dec_type + + @classmethod + async def get_decorator_by_expr(cls, ast_ctx: AstEval, dec_expr: ast.expr) -> Decorator | None: + """Return decorator instance from an AST decorator expression.""" + dec_name = None + has_args = False + + if isinstance(dec_expr, ast.Name): # decorator without () + dec_name = dec_expr.id + elif isinstance(dec_expr, ast.Call) and isinstance(dec_expr.func, ast.Name): + dec_name = dec_expr.func.id + has_args = True + + if know_decorator := cls._decorators.get(dec_name): + if has_args: + args = await ast_ctx.eval_elt_list(dec_expr.args) + kwargs = {keyw.arg: await ast_ctx.aeval(keyw.value) for keyw in dec_expr.keywords} + else: + args = [] + kwargs = {} + + decorator = know_decorator(args, kwargs, dec_expr.lineno, dec_expr.col_offset) + return decorator + + return None + + @classmethod + async def wait_until(cls, ast_ctx: AstEval, *arg, **kwargs): + """Build a temporary decorator manager that waits until one of trigger decorators fires.""" + func_args = set(kwargs.keys()) + if len(func_args) == 0: + return {"trigger_type": "none"} + + found_args = set() + dm = WaitUntilDecoratorManager(ast_ctx, **kwargs) + + found_args.add("timeout") + found_args.add("__test_handshake__") + + prefix_len = len(DecoratorRegistry.prefix) + for dec_name, dec_class in cls._decorators.items(): + if not issubclass(dec_class, TriggerDecorator): + continue + if prefix_len > 0: + dec_name = dec_name[prefix_len:] + if dec_name not in func_args: + continue + + dec_args = kwargs[dec_name] + if not isinstance(dec_args, list): + dec_args = [dec_args] + found_args.add(dec_name) + + dec_kwargs = {} + func_args.remove(dec_name) + kwargs_schema_keys = dec_class.kwargs_schema.schema.keys() + for key in kwargs_schema_keys: + if key in kwargs: + dec_kwargs[key] = kwargs[key] + found_args.add(key) + dec = dec_class(dec_args, dec_kwargs, ast_ctx.lineno, ast_ctx.col_offset) + dm.add(dec) + + unknown_args = set(kwargs.keys()).difference(found_args) + if unknown_args: + raise ValueError(f"Unknown arguments: {unknown_args}") + await dm.validate() + + # state_trigger sets __test_handshake__ after the initial checks. + # In some cases, it returns a value before __test_handshake__ is set. + if "state_trigger" not in kwargs: + if test_handshake := kwargs.get("__test_handshake__"): + # + # used for testing to avoid race conditions + # we use this as a handshake that we are about to + # listen to the queue + # + State.set(test_handshake[0], test_handshake[1]) + await dm.start() + + ret = await dm.wait_until() + + return ret + + @classmethod + def wait_until_factory(cls, ast_ctx): + """Return wrapper to call to astFunction with the ast context.""" + + async def wait_until_call(*arg, **kw): + return await cls.wait_until(ast_ctx, *arg, **kw) + + return wait_until_call + + +class WaitUntilDecoratorManager(DecoratorManager): + """Decorator manager for task.wait_until.""" + + def __init__(self, ast_ctx: AstEval, **kwargs: dict[str, Any]) -> None: + super().__init__(ast_ctx, ast_ctx.name) + self.kwargs = kwargs + self._future: asyncio.Future[DispatchData] = self.hass.loop.create_future() + self.timeout_decorator = None + if timeout := kwargs.get("timeout"): + to_dec = DecoratorRegistry._decorators.get(DecoratorRegistry.prefix + "time_trigger") + self.timeout_decorator = to_dec( + [f"once(now + {timeout}s)"], {}, ast_ctx.lineno, ast_ctx.col_offset + ) + self.add(self.timeout_decorator) + + async def dispatch(self, data: DispatchData) -> None: + """Resolve the waiting future on the first incoming dispatch.""" + _LOGGER.debug("task.wait_until dispatch: %s", data) + if self._future.done(): + _LOGGER.debug("task.wait_until future already completed: %s", self._future.exception()) + # ignore another calls + return + await self.stop() + self._future.set_result(data) + + async def wait_until(self) -> dict[str, Any]: + """Wait for dispatch and normalize the return payload.""" + data = await self._future + if data.exception is not None: + raise data.exception + if data.trigger == self.timeout_decorator: + ret = {"trigger_type": "timeout"} + else: + ret = data.func_args + _LOGGER.debug("task.wait_until finish: %s", ret) + return ret + + +DT = TypeVar("DT", bound=Decorator) + + +class FunctionDecoratorManager(DecoratorManager): + """Maintain and validate a set of decorators applied to a function.""" + + def __init__(self, ast_ctx: AstEval, eval_func_var: EvalFuncVar) -> None: + super().__init__(ast_ctx, f"{ast_ctx.get_global_ctx_name()}.{eval_func_var.get_name()}") + self.eval_func: EvalFunc = eval_func_var.func + + self.logger = self.eval_func.logger + + def on_func_var_deleted(): + if self.status is DecoratorManagerStatus.RUNNING: + self.hass.async_create_task(self.stop()) + + weakref.finalize(eval_func_var, on_func_var_deleted) + + async def _call(self, data: DispatchData) -> None: + handlers = self.get_decorators(CallHandlerDecorator) + result_handlers = self.get_decorators(CallResultHandlerDecorator) + + for handler_dec in handlers: + if await handler_dec.handle_call(data) is False: + self.logger.debug("Calling canceled by %s", handler_dec) + # notify handlers with "None" + for result_handler_dec in result_handlers: + await result_handler_dec.handle_call_result(data, None) + return + # Fire an event indicating that pyscript is running + # Note: the event must have an entity_id for logbook to work correctly. + ev_name = self.name.replace(".", "_") + ev_entity_id = f"pyscript.{ev_name}" + + event_data = {"name": ev_name, "entity_id": ev_entity_id, "func_args": data.func_args} + self.hass.bus.async_fire("pyscript_running", event_data, context=data.hass_context) + # Store HASS Context for this Task + Function.store_hass_context(data.hass_context) + + result = await data.call_ast_ctx.call_func(self.eval_func, None, **data.func_args) + for result_handler_dec in result_handlers: + await result_handler_dec.handle_call_result(data, result) + + if data.call_ast_ctx.get_exception_obj(): + data.call_ast_ctx.get_logger().error(data.call_ast_ctx.get_exception_long()) + + async def dispatch(self, data: DispatchData) -> None: + """Handle a trigger dispatch: run guards, create a context, and invoke the function.""" + _LOGGER.debug("Dispatching for %s: %s", self.name, data) + + if data.exception: + self.logger.error(data.exception_text) + return + + decorators = self.get_decorators(TriggerHandlerDecorator) + for dec in decorators: + if await dec.handle_dispatch(data) is False: + self.logger.debug("Trigger not active due to %s", dec) + return + + action_ast_ctx = AstEval( + f"{self.eval_func.global_ctx_name}.{self.eval_func.name}", self.eval_func.global_ctx + ) + Function.install_ast_funcs(action_ast_ctx) + data.call_ast_ctx = action_ast_ctx + + # Create new HASS Context with incoming as parent + if "context" in data.func_args and isinstance(data.func_args["context"], Context): + data.hass_context = Context(parent_id=data.func_args["context"].id) + else: + data.hass_context = Context() + + self.logger.debug( + "trigger %s got %s trigger, running action (kwargs = %s)", + self.name, + data.trigger, + data.func_args, + ) + + task = Function.create_task(self._call(data), ast_ctx=action_ast_ctx) + Function.task_done_callback_ctx(task, action_ast_ctx) diff --git a/custom_components/pyscript/decorator_abc.py b/custom_components/pyscript/decorator_abc.py new file mode 100644 index 0000000..381f753 --- /dev/null +++ b/custom_components/pyscript/decorator_abc.py @@ -0,0 +1,303 @@ +from __future__ import annotations + +import logging +from abc import ABC, abstractmethod +from dataclasses import field, dataclass +from enum import StrEnum +from typing import ClassVar, Any, TypeVar, Type, final + +import voluptuous as vol +from homeassistant.core import Context, HomeAssistant + +from . import trigger +from .eval import AstEval + +_LOGGER = logging.getLogger(__name__) + + +def dt_now(): + """Return current time.""" + # For test compatibility. The tests patch this function + return trigger.dt_now() + + +class DecoratorManagerStatus(StrEnum): + """Status of a decorator manager.""" + + INIT = "init" # initial status when created + NO_DECORATORS = "no_decorators" # no decorators found + VALIDATED = "validated" + INVALID = "invalid" + RUNNING = "running" + STOPPED = "stopped" + + +@dataclass() +class DispatchData: + """Data for a dispatch event.""" + + func_args: dict[str, Any] + trigger: TriggerDecorator | None = field(default=None, kw_only=True) + trigger_context: dict[str, Any] = field(default_factory=dict, kw_only=True) + + call_ast_ctx: AstEval | None = field(default=None, kw_only=True) + hass_context: Context | None = field(default=None, kw_only=True) + + # Normally shouldn’t be used. + exception: Exception | None = field(default=None, kw_only=True) + exception_text: str | None = field(default=None, kw_only=True) + + +class Decorator(ABC): + """Generic decorator abstraction.""" + + # Subclasses should override. + name: ClassVar[str] = "" + # without args by default + args_schema: ClassVar[vol.Schema] = vol.Schema([], extra=vol.PREVENT_EXTRA) + # without kwargs by default + kwargs_schema: ClassVar[vol.Schema] = vol.Schema({}, extra=vol.PREVENT_EXTRA) + + # instance attributes + dm: DecoratorManager + raw_args: list[Any] + raw_kwargs: dict[str, Any] + + args: list[Any] + kwargs: dict[str, Any] + + @final + def __init__( + self, raw_args: list[Any], raw_kwargs: dict[str, Any], lineno: int, col_offset: int + ) -> None: + """Initialize the decorator definition.""" + + self.raw_args = raw_args + self.raw_kwargs = raw_kwargs + self.lineno = lineno + self.col_offset = col_offset + + async def validate(self) -> None: + """Validate the arguments.""" + + _LOGGER.debug("Validating %s", self.name) + + try: + self.args = self.args_schema(self.raw_args) + self.kwargs = self.kwargs_schema(self.raw_kwargs) + + except vol.Invalid as err: + # FIXME For test compatibility. Update the message in the future. + if len(err.path) == 1: + if "extra keys not allowed" in err.msg: + message = f"invalid keyword argument '{err.path[0]}'" + else: + message = f"keyword '{err.path[0]}' {err}" + else: + message = str(err) + + type_error = TypeError( + f"function '{self.dm.func_name}' defined in {self.dm.ast_ctx.get_global_ctx_name()}: " + f"decorator @{self.name} {message}" + ) + raise type_error from err + + async def start(self): + """Start the decorator.""" + + async def stop(self): + """Stop the decorator.""" + + def __repr__(self): + parts = [] + if self.raw_args is not None: + parts.append(",".join(map(str, self.raw_args))) + if self.raw_kwargs is not None: + parts += [f"{k}={v!r}" for k, v in self.raw_kwargs.items()] + return f"@{self.name}({', '.join(parts)})" + + +DecoratorType = TypeVar("DecoratorType", bound=Decorator) + + +class DecoratorManager(ABC): + """Maintain and validate a set of decorators""" + + hass: ClassVar[HomeAssistant] + + def __init__(self, ast_ctx: AstEval, name: str) -> None: + self.ast_ctx = ast_ctx + self.name = name + self.func_name = name.split(".")[-1] + self.logger = ast_ctx.get_logger() + + self.lineno = ast_ctx.lineno + self.col_offset = ast_ctx.col_offset + + self.status: DecoratorManagerStatus = DecoratorManagerStatus.INIT + self.startup_time = None + self._decorators: list[Decorator] = [] + + def update_status(self, new_status: DecoratorManagerStatus) -> None: + """Update the manager status.""" + if self.status is new_status: + return + _LOGGER.debug("DM %s status: %s -> %s", self.name, self.status.value, new_status.value) + self.status = new_status + + if new_status in (DecoratorManagerStatus.STOPPED, DecoratorManagerStatus.INVALID): + del self._decorators[:] + + def add(self, decorator: Decorator) -> None: + """Add a decorator to the manager.""" + _LOGGER.debug("Add %s to %s", decorator, self) + self._decorators.append(decorator) + decorator.dm = self + + def get_decorators(self, decorator_type: Type[DecoratorType] | None = None) -> list[DecoratorType]: + """Get decorators of a specific type.""" + if decorator_type is None: + return self._decorators.copy() + return [dec for dec in self._decorators if isinstance(dec, decorator_type)] + + async def validate(self) -> None: + """Validate all decorators.""" + lineno, col_offset = self.ast_ctx.lineno, self.ast_ctx.col_offset + try: + for decorator in self._decorators: + self.ast_ctx.lineno, self.ast_ctx.col_offset = decorator.lineno, decorator.col_offset + _LOGGER.debug("Validating decorator: %s", decorator) + self.lineno, self.col_offset = decorator.lineno, decorator.col_offset + await decorator.validate() + + if decorator.name == "service": + # FIXME For test compatibility. In the legacy implementation, the service was registered immediately. + await decorator.start() + except Exception: + self.update_status(DecoratorManagerStatus.INVALID) + raise + + self.ast_ctx.lineno, self.ast_ctx.col_offset = lineno, col_offset + + if len(self._decorators) == 0: + self.update_status(DecoratorManagerStatus.NO_DECORATORS) + else: + self.update_status(DecoratorManagerStatus.VALIDATED) + + async def start(self): + """Start all decorators.""" + if self.status is not DecoratorManagerStatus.VALIDATED: + raise RuntimeError(f"Starting not valid {self}") + + self.startup_time = dt_now() + self.update_status(DecoratorManagerStatus.RUNNING) + started = [] + for decorator in self._decorators: + # FIXME For test compatibility. + if decorator.name == "service": + continue + _LOGGER.debug("Starting decorator: %s", decorator) + try: + await decorator.start() + started.append(decorator) + except Exception as err: + self.logger.exception("%s start failed: %s", self, err) + for started_dec in started: + await self._stop_decorator(started_dec) + self.startup_time = None + self.update_status(DecoratorManagerStatus.INVALID) + raise + + async def _stop_decorator(self, decorator: Decorator) -> None: + try: + await decorator.stop() + except Exception as err: + _LOGGER.exception("%s stop failed: %s", self, err) + + async def stop(self): + """Stop all decorators.""" + if self.status is not DecoratorManagerStatus.RUNNING: + _LOGGER.warning("Stopping before starting for %s (status=%s)", self.name, self.status.value) + for dec in self.get_decorators(): + # FIXME For test compatibility. + if dec.name == "service": + await self._stop_decorator(dec) + return + + _LOGGER.debug("Stopping all decorators %s", self) + for decorator in self._decorators: + await self._stop_decorator(decorator) + + self.update_status(DecoratorManagerStatus.STOPPED) + + @abstractmethod + async def dispatch(self, data: DispatchData) -> None: + pass + + def __repr__(self): + return f"{self.__class__.__name__}({self.status}) {self._decorators} for {self.name}()>" + + +class TriggerDecorator(Decorator, ABC): + """Base class for trigger-based decorators.""" + + def __init_subclass__(cls, **kwargs): + super().__init_subclass__(**kwargs) + # kwargs for all triggers + if "kwargs" not in cls.kwargs_schema.schema.keys(): + cls.kwargs_schema = cls.kwargs_schema.extend( + {vol.Optional("kwargs"): vol.Coerce(dict[str, Any], msg="should be type dict")} + ) + + async def dispatch(self, data: DispatchData): + """Dispatch a trigger call to the function.""" + if not data.trigger: + data.trigger = self + + data.func_args.update(self.kwargs.get("kwargs", {})) + + await self.dm.dispatch(data) + + +class TriggerHandlerDecorator(Decorator, ABC): + """Base class for trigger handler decorators.""" + + async def validate(self) -> None: + """Validate the decorated function.""" + await super().validate() + decorators = self.dm.get_decorators(TriggerDecorator) + if len(decorators) == 0: + # FIXME For test compatibility. Update the message in the future. + trig_decorators_reqd = { + "event_trigger", + "mqtt_trigger", + "state_trigger", + "time_trigger", + "webhook_trigger", + } + raise ValueError( + f"{self.dm.func_name} defined in {self.dm.ast_ctx.get_global_ctx_name()}: " + f"needs at least one trigger decorator (ie: {', '.join(sorted(trig_decorators_reqd))})" + ) + + @abstractmethod + async def handle_dispatch(self, data: DispatchData) -> bool | None: + """Handle a trigger dispatch call. Return False for stop dispatching.""" + + +class CallHandlerDecorator(Decorator, ABC): + """Base class for call-based handlers.""" + + @abstractmethod + async def handle_call(self, data: DispatchData) -> bool | None: + """Handle an action call. Return False for stop calling.""" + pass + + +class CallResultHandlerDecorator(Decorator, ABC): + """Base class for call-based result handlers.""" + + @abstractmethod + async def handle_call_result(self, data: DispatchData, result: Any) -> None: + """Handle an action call result.""" + pass diff --git a/custom_components/pyscript/decorators/__init__.py b/custom_components/pyscript/decorators/__init__.py new file mode 100644 index 0000000..85cf580 --- /dev/null +++ b/custom_components/pyscript/decorators/__init__.py @@ -0,0 +1,19 @@ +from .event import EventTriggerDecorator +from .mqtt import MQTTTriggerDecorator +from .service import ServiceDecorator +from .state import StateTriggerDecorator, StateActiveDecorator +from .task import TaskUniqueDecorator +from .timing import TimeTriggerDecorator, TimeActiveDecorator +from .webhook import WebhookTriggerDecorator + +DECORATORS = [ + StateTriggerDecorator, + StateActiveDecorator, + TimeTriggerDecorator, + TimeActiveDecorator, + TaskUniqueDecorator, + EventTriggerDecorator, + MQTTTriggerDecorator, + WebhookTriggerDecorator, + ServiceDecorator, +] diff --git a/custom_components/pyscript/decorators/base.py b/custom_components/pyscript/decorators/base.py new file mode 100644 index 0000000..ef32c44 --- /dev/null +++ b/custom_components/pyscript/decorators/base.py @@ -0,0 +1,63 @@ +"""Base mixins for pyscript decorators.""" + +import logging +from abc import ABC +from typing import Any + +import voluptuous as vol + +from ..decorator import FunctionDecoratorManager +from ..decorator_abc import Decorator, DispatchData +from ..eval import AstEval, Function + +_LOGGER = logging.getLogger(__name__) + + +class AutoKwargsDecorator(Decorator, ABC): + """Mixin that copies validated kwargs into instance attributes based on annotations.""" + + async def validate(self) -> None: + """Run base validation and materialize annotated kwargs as attributes.""" + await super().validate() + for k in self.__class__.kwargs_schema.schema: + if isinstance(k, vol.Marker): + k = k.schema + if k in self.__class__.__annotations__: + setattr(self, k, self.kwargs.get(k, None)) + + +class ExpressionDecorator(Decorator, ABC): + """Base for AstEval-based decorators.""" + + _ast_expression: AstEval = None + + def create_expression(self, expression: str) -> None: + """Create AstEval expression.""" + _LOGGER.debug("Create expression: %s, %s", expression, self) + dec_name = self.name + if isinstance(self.dm, FunctionDecoratorManager): + dec_name = "@" + dec_name + "()" + + self._ast_expression = AstEval( + self.dm.name + " " + dec_name, self.dm.ast_ctx.global_ctx, self.dm.name + ) + Function.install_ast_funcs(self._ast_expression) + self._ast_expression.parse(expression, mode="eval") + exc = self._ast_expression.get_exception_obj() + if exc is not None: + raise exc + + def has_expression(self) -> bool: + """Return True if expression was created.""" + return self._ast_expression is not None + + async def check_expression_vars(self, state_vars: dict[str, Any]) -> bool: + """Evaluate expression and dispatch an exception event via manager on failure.""" + if not self.has_expression(): + raise AttributeError(f"{self} has no expression defined") + ret = await self._ast_expression.eval(state_vars) + if exception := self._ast_expression.get_exception_obj(): + exception_text = self._ast_expression.get_exception_long() + await self.dm.dispatch(DispatchData({}, exception=exception, exception_text=exception_text)) + return False + return ret diff --git a/custom_components/pyscript/decorators/event.py b/custom_components/pyscript/decorators/event.py new file mode 100644 index 0000000..46ac898 --- /dev/null +++ b/custom_components/pyscript/decorators/event.py @@ -0,0 +1,57 @@ +import logging + +import voluptuous as vol +from homeassistant.core import Event, CALLBACK_TYPE + +from .base import ExpressionDecorator +from ..decorator_abc import DispatchData, TriggerDecorator + +_LOGGER = logging.getLogger(__name__) + + +class EventTriggerDecorator(TriggerDecorator, ExpressionDecorator): + """Implementation for @event_trigger.""" + + name = "event_trigger" + args_schema = vol.Schema( + vol.All( + [vol.Coerce(str)], + vol.Length(min=1, max=2, msg="needs at least one argument"), + ) + ) + + remove_listener_callback: CALLBACK_TYPE | None = None + + async def validate(self) -> None: + """Validate the event trigger.""" + await super().validate() + if len(self.args) == 2: + self.create_expression(self.args[1]) + + async def _event_callback(self, event: Event) -> None: + """Callback for the event trigger.""" + _LOGGER.debug("Event trigger received: %s %s", type(event), event) + func_args = { + "trigger_type": "event", + "event_type": event.event_type, + "context": event.context, + } + func_args.update(event.data) + if self.has_expression(): + if not await self.check_expression_vars(func_args): + return + + await self.dispatch(DispatchData(func_args)) + + async def start(self) -> None: + """Start the event trigger.""" + await super().start() + self.remove_listener_callback = self.dm.hass.bus.async_listen(self.args[0], self._event_callback) + _LOGGER.debug("Event trigger started for event: %s", self.args[0]) + _LOGGER.debug("Remove listener: %s", self.remove_listener_callback) + + async def stop(self) -> None: + """Stop the event trigger.""" + await super().stop() + if self.remove_listener_callback: + self.remove_listener_callback() diff --git a/custom_components/pyscript/decorators/mqtt.py b/custom_components/pyscript/decorators/mqtt.py new file mode 100644 index 0000000..fa2dc7b --- /dev/null +++ b/custom_components/pyscript/decorators/mqtt.py @@ -0,0 +1,68 @@ +"""Trigger decorator implementations.""" + +from __future__ import annotations + +import json +import logging + +import voluptuous as vol +from homeassistant.components import mqtt +from homeassistant.core import CALLBACK_TYPE + +from .base import ExpressionDecorator, AutoKwargsDecorator +from ..decorator_abc import DispatchData, TriggerDecorator + +_LOGGER = logging.getLogger(__name__) + + +class MQTTTriggerDecorator(TriggerDecorator, ExpressionDecorator, AutoKwargsDecorator): + """Implementation for @mqtt_trigger.""" + + name = "mqtt_trigger" + args_schema = vol.Schema(vol.All([vol.Coerce(str)], vol.Length(min=1, max=2))) + kwargs_schema = vol.Schema({vol.Optional("encoding", default="utf-8"): str}) + + encoding: str + + remove_listener_callback: CALLBACK_TYPE | None = None + + async def validate(self) -> None: + """Validate the MQTT trigger.""" + await super().validate() + if len(self.args) == 2: + self.create_expression(self.args[1]) + + async def _mqtt_message_handler(self, mqttmsg: mqtt.ReceiveMessage) -> None: + func_args = { + "trigger_type": "mqtt", + "topic": mqttmsg.topic, + "payload": mqttmsg.payload, + "qos": mqttmsg.qos, + "retain": mqttmsg.retain, + } + try: + func_args["payload_obj"] = json.loads(mqttmsg.payload) + except ValueError: + pass + if self.has_expression(): + if not await self.check_expression_vars(func_args): + return + await self.dispatch(DispatchData(func_args)) + + async def start(self) -> None: + """Start the MQTT trigger.""" + await super().start() + topic = self.args[0] + self.remove_listener_callback = await mqtt.async_subscribe( + self.dm.hass, + topic, + self._mqtt_message_handler, + encoding=self.encoding, + qos=0, + ) + + async def stop(self) -> None: + """Stop the MQTT trigger.""" + await super().stop() + if self.remove_listener_callback: + self.remove_listener_callback() diff --git a/custom_components/pyscript/decorators/service.py b/custom_components/pyscript/decorators/service.py new file mode 100644 index 0000000..668dd26 --- /dev/null +++ b/custom_components/pyscript/decorators/service.py @@ -0,0 +1,142 @@ +"""Service decorator implementation.""" + +from __future__ import annotations + +import ast +import io +import logging +import typing +from collections import OrderedDict + +import voluptuous as vol +import yaml +from homeassistant.const import SERVICE_RELOAD +from homeassistant.core import ServiceCall, SupportsResponse +from homeassistant.helpers.service import async_set_service_schema +from ..decorator import FunctionDecoratorManager + +from ..decorator_abc import Decorator +from .. import DOMAIN, SERVICE_JUPYTER_KERNEL_START, AstEval, Function, State + +_LOGGER = logging.getLogger(__name__) + + +def service_validator(args: list[str]) -> list[str]: + if len(args) == 0: + return [] + s = str(args[0]).strip() + + if not isinstance(s, str): + raise vol.Invalid("must be string") + s = s.strip() + if s.count(".") != 1: + raise vol.Invalid("argument 1 should be a string with one period") + domain, name = s.split(".", 1) + return [domain, name] + + +class ServiceDecorator(Decorator): + """Implementation for @service.""" + + name = "service" + args_schema = vol.Schema(vol.All(vol.Length(max=1), service_validator)) + kwargs_schema = vol.Schema( + {vol.Optional("supports_response", default=SupportsResponse.NONE): vol.Coerce(SupportsResponse)} + ) + + description: dict + + async def validate(self) -> None: + await super().validate() + + if len(self.args) != 2: + self.args = [DOMAIN, self.dm.func_name] + # FIXME This condition doesn’t verify the domain - it may not be Pyscript. + # The error is kept for backward compatibility. + if self.args[1] in (SERVICE_RELOAD, SERVICE_JUPYTER_KERNEL_START): + # FIXME For test compatibility. Update the message in the future. + raise SyntaxError( + f"function '{self.dm.func_name}' defined in {self.dm.ast_ctx.get_global_ctx_name()}: " + f"@service conflicts with builtin service" + ) + + ast_funcdef = typing.cast(FunctionDecoratorManager, self.dm).eval_func.func_def + desc = ast.get_docstring(ast_funcdef) + if desc is None or desc == "": + desc = f"pyscript function {ast_funcdef.name}()" + desc = desc.lstrip(" \n\r") + if desc.startswith("yaml"): + try: + desc = desc[4:].lstrip(" \n\r") + file_desc = io.StringIO(desc) + self.description = yaml.load(file_desc, Loader=yaml.BaseLoader) or OrderedDict() + file_desc.close() + except Exception as exc: + self.dm.logger.error( + "Unable to decode yaml doc_string for %s(): %s", + ast_funcdef.name, + str(exc), + ) + raise exc + else: + fields = OrderedDict() + for arg in ast_funcdef.args.posonlyargs + ast_funcdef.args.args: + # _LOGGER.warning(f"------ {type(arg.arg)} {arg.arg}") + fields[arg.arg] = OrderedDict(description=f"argument {arg.arg}") + self.description = {"description": desc, "fields": fields} + + async def _service_callback(self, call: ServiceCall) -> None: + _LOGGER.info("Service callback: %s", call.service) + + # use a new AstEval context so it can run fully independently + # of other instances (except for global_ctx which is common) + global_ctx = self.dm.eval_func.global_ctx + ast_ctx = AstEval(self.dm.name, global_ctx) + Function.install_ast_funcs(ast_ctx) + func_args = { + "trigger_type": "service", + "context": call.context, + } + func_args.update(call.data) + + # + async def do_service_call(func, ast_ctx, data): + try: + _LOGGER.debug("Service call start: %s", func.name) + retval = await func.call(ast_ctx, **data) + _LOGGER.debug("Service call done: %s", ast_ctx.get_exception_long()) + if ast_ctx.get_exception_obj(): + ast_ctx.get_logger().error(ast_ctx.get_exception_long()) + return retval + except Exception as exc: + _LOGGER.exception(exc) + return None + + # + task = Function.create_task(do_service_call(self.dm.eval_func, ast_ctx, func_args)) + await task + return task.result() + + async def start(self) -> None: + """Register the service.""" + domain = self.args[0] + name = self.args[1] + _LOGGER.debug("Registering service: %s.%s", domain, name) + Function.service_register( + self.dm.ast_ctx.name, + domain, + name, + self._service_callback, + self.kwargs.get("supports_response"), + ) + async_set_service_schema(Function.hass, domain, name, self.description) + + # update service params. In the legacy implementation, Pyscript services were registered + # right after the function definition, then decorators were executed, and finally the + # service cache was updated. + await State.get_service_params() + + async def stop(self) -> None: + """Unregister the service.""" + _LOGGER.debug("Unregistering service: %s.%s", self.args[0], self.args[1]) + Function.service_remove(self.dm.ast_ctx.global_ctx.get_name(), self.args[0], self.args[1]) diff --git a/custom_components/pyscript/decorators/state.py b/custom_components/pyscript/decorators/state.py new file mode 100644 index 0000000..ddd3a4c --- /dev/null +++ b/custom_components/pyscript/decorators/state.py @@ -0,0 +1,311 @@ +import asyncio +import re + +from homeassistant.helpers import config_validation as cv + +from .base import ExpressionDecorator, AutoKwargsDecorator +from ..decorator import WaitUntilDecoratorManager +from ..decorator_abc import * +from ..state import State +from ..trigger import ident_any_values_changed, ident_values_changed + +STATE_RE = re.compile(r"\w+\.\w+(\.((\w+)|\*))?$") + +_LOGGER = logging.getLogger(__name__) + + +class StateActiveDecorator(TriggerHandlerDecorator, ExpressionDecorator): + """Implementation for @state_active.""" + + name = "state_active" + args_schema = vol.Schema( + vol.All( + vol.Length( + min=1, max=1, msg="got 2 arguments, expected 1" + ), # FIXME For test compatibility. Update the message in the future. + vol.All([str]), + ) + ) + + var_names: set[str] + + async def validate(self) -> None: + """Validate the decorator arguments.""" + await super().validate() + self.create_expression(self.args[0]) + self.var_names = await self._ast_expression.get_names() + + async def handle_dispatch(self, data: DispatchData) -> bool: + new_vars = data.trigger_context.get("new_vars", {}) + active_vars = State.notify_var_get(self.var_names, new_vars) + return await self.check_expression_vars(active_vars) + + +def _validate_state_trigger_args(args: list[Any]) -> list[str]: + """Validate and normalize @state_trigger positional arguments.""" + if not isinstance(args, list): + raise vol.Invalid("arguments must be a list") + if len(args) == 0: + raise vol.Invalid("needs at least one argument") + + normalized: list[str] = [] + for idx, arg in enumerate(args, start=1): + if isinstance(arg, str): + normalized.append(arg) + continue + if isinstance(arg, (list, set)): + if not all(isinstance(expr, str) for expr in arg): + raise vol.Invalid(f"argument {idx} should be a string, or list, or set") + normalized.extend(list(arg)) + continue + raise vol.Invalid(f"argument {idx} should be a string, or list, or set") + return normalized + + +class StateTriggerDecorator(TriggerDecorator, ExpressionDecorator, AutoKwargsDecorator): + """Implementation for @state_trigger.""" + + name = "state_trigger" + args_schema = vol.Schema(vol.All(_validate_state_trigger_args)) + kwargs_schema = vol.Schema( + { + vol.Optional("state_hold"): vol.Any(None, cv.positive_float), + vol.Optional("state_hold_false"): vol.Any(None, cv.positive_float), + vol.Optional("state_check_now"): cv.boolean, + vol.Optional("watch"): vol.Coerce(set[str], msg="should be type list or set"), + vol.Optional("__test_handshake__"): vol.Coerce(list), + } + ) + # kwargs + state_hold: float | None + state_hold_false: float | None + state_check_now: bool | None + __test_handshake__: list[str] | None + + notify_q: asyncio.Queue + in_wait_until_function: bool + cycle_task: asyncio.Task = None + + state_trig_ident: set[str] + state_trig_ident_any: set[str] + true_entered_at: float | None + false_entered_at: float | None + + async def validate(self) -> None: + await super().validate() + self.state_trig_ident = set() + self.state_trig_ident_any = set() + + self.in_wait_until_function = isinstance(self.dm, WaitUntilDecoratorManager) + + if self.state_check_now is None and self.in_wait_until_function: + # check by default for task.wait_until + self.state_check_now = True + + state_trig = [] + + for trig in self.args: + if STATE_RE.match(trig): + self.state_trig_ident_any.add(trig) + else: + state_trig.append(trig) + + if len(state_trig) > 0: + if len(state_trig) == 1: + state_trig_expr = state_trig[0] + else: + state_trig_expr = f"any([{', '.join(state_trig)}])" + + self.create_expression(state_trig_expr) + + if self.kwargs.get("watch") is not None: + self.state_trig_ident = set(self.kwargs.get("watch", [])) + else: + if self.has_expression(): + self.state_trig_ident = await self._ast_expression.get_names() + self.state_trig_ident.update(self.state_trig_ident_any) + + _LOGGER.debug("trigger %s: watching vars %s", self.name, self.state_trig_ident) + _LOGGER.debug("trigger %s: any %s", self.name, self.state_trig_ident_any) + if len(self.state_trig_ident) == 0: + self.dm.logger.error( + "trigger %s: @state_trigger is not watching any variables; will never trigger", + self.dm.name, + ) + + def _diff(self, dt: float, now: float) -> str: + if dt is None: + return "None" + else: + return f"{(now-dt):g} ago" + + async def _check_new_state(self, trig_ok: bool): + now = asyncio.get_running_loop().time() + if _LOGGER.isEnabledFor(logging.DEBUG): + msg = f"check_new_state: {self}" + msg += f"\ntrig_ok: {trig_ok} now {now} func_args: {self.last_func_args} new_vars: {self.last_new_vars}" + if self.true_entered_at: + msg += f"\ntrue_entered_at: {self.true_entered_at}({(now-self.true_entered_at):g} ago)\n" + if self.false_entered_at: + msg += f"\nfalse_entered_at: {self.false_entered_at}({(now-self.false_entered_at):g} ago)\n" + _LOGGER.debug(msg) + + state_hold_false_passed = False + state_hold_true_passed = False + if trig_ok: + if self.state_hold_false is None or not self.has_expression(): + state_hold_false_passed = True + else: + if self.false_entered_at: + false_duration = now - self.false_entered_at + if false_duration >= self.state_hold_false: + state_hold_false_passed = True + _LOGGER.debug( + "state_hold_false passed (%g), reset false_entered_at, %s", false_duration, self + ) + self.false_entered_at = None + + if state_hold_false_passed: + if self.state_hold is None: + state_hold_true_passed = True + else: + if self.true_entered_at: + true_duration = now - self.true_entered_at + if true_duration >= self.state_hold: + state_hold_true_passed = True + self.true_entered_at = None + _LOGGER.debug( + "state_hold passed (%g), reset true_entered_at, %s", true_duration, self + ) + else: + _LOGGER.debug("state_hold started, %s", self) + self.true_entered_at = now + + if state_hold_true_passed: + self.true_entered_at = None + await self.dispatch( + DispatchData(self.last_func_args, trigger_context={"new_vars": self.last_new_vars}) + ) + self.__test_handshake__ = None + else: + self.true_entered_at = None + if self.state_hold_false is not None: + if not self.false_entered_at: + _LOGGER.debug("state_hold_false started, %s", self) + self.false_entered_at = now + + async def _check_state_hold(self): + if self.true_entered_at is None: + raise RuntimeError(f"state_hold not started for {self}") + + now = asyncio.get_running_loop().time() + true_duration = now - self.true_entered_at + if true_duration >= self.state_hold: + self.true_entered_at = None + await self.dispatch( + DispatchData(self.last_func_args, trigger_context={"new_vars": self.last_new_vars}) + ) + + async def _cycle(self): + """Run the trigger cycle with state_hold and state_hold_false logic.""" + loop = asyncio.get_running_loop() + + self.true_entered_at = None + self.false_entered_at = None + + self.last_func_args = {"trigger_type": "state"} + self.last_new_vars = {} + + check_state_expr_on_start = self.state_check_now or self.state_hold_false is not None + + if check_state_expr_on_start: + self.last_new_vars = State.notify_var_get(self.state_trig_ident, {}) + trig_ok = await self._is_trig_ok() + + if self.in_wait_until_function and trig_ok and self.state_check_now is True: + self.state_hold_false = None + + if self.state_check_now and self.has_expression(): + await self._check_new_state(trig_ok) + else: + if not trig_ok and self.state_hold_false is not None: + self.false_entered_at = loop.time() + + if self.__test_handshake__ is not None: + # + # used for testing to avoid race conditions + # we use this as a handshake that we are about to + # listen to the queue + # + _LOGGER.debug("__test_handshake__ handshake: %s", self.__test_handshake__) + State.set(self.__test_handshake__[0], self.__test_handshake__[1]) + self.__test_handshake__ = None + + while self.dm.status is DecoratorManagerStatus.RUNNING: + if self.true_entered_at is None: + effective_timeout = None + else: + effective_timeout = self.state_hold + if self.true_entered_at is not None: + effective_timeout -= loop.time() - self.true_entered_at + + if effective_timeout <= 1e-6: + # ignore deltas smaller than 1us. + await self._check_state_hold() + continue + + try: + if effective_timeout is None: + notify_type, notify_info = await self.notify_q.get() + else: + notify_type, notify_info = await asyncio.wait_for(self.notify_q.get(), effective_timeout) + if notify_type != "state": + raise RuntimeError(f"Invalid notify_type {notify_type}, {self}") + self.last_new_vars = notify_info[0] + self.last_func_args = notify_info[1] + + if ident_any_values_changed(self.last_func_args, self.state_trig_ident_any): + trig_ok = True + elif ident_values_changed(self.last_func_args, self.state_trig_ident): + trig_ok = await self._is_trig_ok() + else: + trig_ok = False + await self._check_new_state(trig_ok) + except asyncio.TimeoutError: + await self._check_state_hold() + + async def _is_trig_ok(self) -> bool: + if self.has_expression(): + return await self.check_expression_vars(self.last_new_vars) + else: + return True + + def _on_task_done(self, task: asyncio.Task) -> None: + if task.cancelled(): + return + exc = task.exception() + if exc is not None: + self.dm.logger.exception(f"{self} failed", exc_info=exc) + + async def start(self) -> None: + """Start the trigger.""" + await super().start() + self.notify_q = asyncio.Queue(0) + if not await State.notify_add(self.state_trig_ident, self.notify_q): + # FIXME raise exception? + self.dm.logger.error( + "trigger %s: @state_trigger is not watching any variables; will never trigger", + self.dm.name, + ) + return + _LOGGER.debug("trigger %s: starting", self.name) + + self.cycle_task = self.dm.hass.async_create_task(self._cycle()) + self.cycle_task.add_done_callback(self._on_task_done) + + async def stop(self): + """Stop the trigger.""" + await super().stop() + if hasattr(self, "cycle_task"): + self.cycle_task.cancel() + State.notify_del(self.state_trig_ident, self.notify_q) diff --git a/custom_components/pyscript/decorators/task.py b/custom_components/pyscript/decorators/task.py new file mode 100644 index 0000000..0848f89 --- /dev/null +++ b/custom_components/pyscript/decorators/task.py @@ -0,0 +1,38 @@ +"""Task decorators.""" + +from __future__ import annotations + +import logging + +import voluptuous as vol +from homeassistant.helpers import config_validation as cv + +from .base import AutoKwargsDecorator +from ..decorator_abc import DispatchData, CallHandlerDecorator +from ..function import Function + +_LOGGER = logging.getLogger(__name__) + + +class TaskUniqueDecorator(CallHandlerDecorator, AutoKwargsDecorator): + """Implementation for @task_unique.""" + + name = "task_unique" + args_schema = vol.Schema(vol.All([str], vol.Length(min=1, max=1))) + kwargs_schema = vol.Schema({vol.Optional("kill_me", default=False): cv.boolean}) + + kill_me: bool + + async def handle_call(self, data: DispatchData) -> bool: + if self.kill_me: + if Function.unique_name_used(data.call_ast_ctx, self.args[0]): + _LOGGER.debug( + "trigger %s got %s trigger, @task_unique kill_me=True prevented new action", + "notify_type", + self.name, + ) + return False + + task_unique_func = Function.task_unique_factory(data.call_ast_ctx) + await task_unique_func(self.args[0]) + return True diff --git a/custom_components/pyscript/decorators/timing.py b/custom_components/pyscript/decorators/timing.py new file mode 100644 index 0000000..f5b53af --- /dev/null +++ b/custom_components/pyscript/decorators/timing.py @@ -0,0 +1,159 @@ +from __future__ import annotations + +import asyncio +import datetime as dt +import logging +import time + +import voluptuous as vol +from homeassistant.helpers import config_validation as cv + +from .base import AutoKwargsDecorator +from .. import trigger +from ..decorator import WaitUntilDecoratorManager +from ..decorator_abc import DispatchData, TriggerHandlerDecorator, TriggerDecorator, DecoratorManagerStatus + +_LOGGER = logging.getLogger(__name__) + + +def dt_now(): + """Return current time.""" + # FIXME For test compatibility. The tests patch this function + return trigger.dt_now() + + +class TimeActiveDecorator(TriggerHandlerDecorator, AutoKwargsDecorator): + """Implementation for @time_active.""" + + name = "time_active" + args_schema = vol.Schema(vol.All([vol.Coerce(str)], vol.Length(min=0))) + kwargs_schema = vol.Schema({vol.Optional("hold_off", default=0.0): cv.positive_float}) + + hold_off: float + + last_trig_time: float = 0.0 + + async def handle_dispatch(self, data: DispatchData) -> bool: + if self.last_trig_time > 0.0 and self.hold_off > 0.0: + if time.monotonic() - self.last_trig_time < self.hold_off: + return False + + if len(self.args) > 0: + if "trigger_time" in data.func_args and isinstance(data.func_args["trigger_time"], dt.datetime): + now = data.func_args["trigger_time"] + else: + now = dt_now() + + for time_spec in self.args: + _LOGGER.debug("time_spec %s, %s", time_spec, self) + _LOGGER.debug("time_active now %s, %s", now, self) + if await trigger.TrigTime.timer_active_check(time_spec, now, self.dm.startup_time): + self.last_trig_time = time.monotonic() + return True + return False + + self.last_trig_time = time.monotonic() + return True + + +class TimeTriggerDecorator(TriggerDecorator): + """Implementation for @time_trigger.""" + + name = "time_trigger" + # args_schema = vol.Schema(vol.All([vol.Coerce(str)], vol.Length(min=0))) + args_schema = vol.Schema( + vol.All( + vol.Length(min=0), + vol.All( + [str], msg="argument 2 should be a string" + ), # FIXME For test compatibility. Update the message in the future. + ) + ) + + run_on_startup: bool = False + run_on_shutdown: bool = False + timespec: list[str] + _cycle_task: asyncio.Task + + async def validate(self) -> None: + """Validate the decorator arguments.""" + await super().validate() + self.timespec = self.args + + if len(self.timespec) == 0: + self.run_on_startup = True + return + + while "startup" in self.timespec: + self.run_on_startup = True + self.timespec.remove("startup") + while "shutdown" in self.timespec: + self.run_on_shutdown = True + self.timespec.remove("shutdown") + + async def _cycle(self): + if self.run_on_startup: + await self.dispatch(DispatchData({"trigger_type": "time", "trigger_time": "startup"})) + + first_run = True + try: + while self.dm.status is DecoratorManagerStatus.RUNNING: + if first_run: + now = self.dm.startup_time + first_run = False + else: + now = dt_now() + + _LOGGER.debug("time_trigger now %s", now) + time_next, time_next_adj = await trigger.TrigTime.timer_trigger_next( + self.timespec, now, self.dm.startup_time + ) + _LOGGER.debug( + "trigger %s time_next = %s, time_next_adj = %s, now = %s", + self.dm.name, + time_next, + time_next_adj, + now, + ) + if time_next is None: + _LOGGER.debug("trigger %s finished", self.name) + if isinstance(self.dm, WaitUntilDecoratorManager): + await self.dispatch(DispatchData({"trigger_type": "none"})) + break + + # replace with homeassistant.helpers.event.async_track_point_in_utc_time? + timeout = (time_next_adj - now).total_seconds() + _LOGGER.debug("%s sleeping for %s seconds", self, timeout) + await asyncio.sleep(timeout) + _LOGGER.debug("%s finish sleeping for %s seconds", self, timeout) + while True: + now = dt_now() + timeout = (time_next_adj - now).total_seconds() + if timeout <= 1e-6: + break + _LOGGER.debug("%s additional sleep for %s seconds", self, timeout) + await asyncio.sleep(timeout) + + await self.dispatch(DispatchData({"trigger_type": "time", "trigger_time": time_next})) + except asyncio.CancelledError: + raise + + async def stop(self): + """Stop the trigger.""" + if hasattr(self, "_cycle_task"): + self._cycle_task.cancel() + if self.run_on_shutdown: + await self.dispatch(DispatchData({"trigger_type": "time", "trigger_time": "shutdown"})) + + def _on_task_done(self, task: asyncio.Task) -> None: + if task.cancelled(): + return + exc = task.exception() + if exc is not None: + self.dm.logger.exception(f"{self} failed", exc_info=exc) + + async def start(self) -> None: + """Start the decorator.""" + await super().start() + self._cycle_task = self.dm.hass.async_create_task(self._cycle()) + self._cycle_task.add_done_callback(self._on_task_done) diff --git a/custom_components/pyscript/decorators/webhook.py b/custom_components/pyscript/decorators/webhook.py new file mode 100644 index 0000000..80e6103 --- /dev/null +++ b/custom_components/pyscript/decorators/webhook.py @@ -0,0 +1,82 @@ +import logging + +import voluptuous as vol +from aiohttp import hdrs +from homeassistant.components import webhook +from homeassistant.components.webhook import SUPPORTED_METHODS +from homeassistant.helpers import config_validation as cv + +from .base import ExpressionDecorator, AutoKwargsDecorator +from ..decorator_abc import DispatchData, TriggerDecorator + +_LOGGER = logging.getLogger(__name__) + + +class WebhookTriggerDecorator(TriggerDecorator, ExpressionDecorator, AutoKwargsDecorator): + """Implementation for @webhook_trigger.""" + + name = "webhook_trigger" + args_schema = vol.Schema( + vol.All( + [vol.Coerce(str)], + vol.Length(min=1, max=2, msg="needs at least one argument"), + ) + ) + kwargs_schema = vol.Schema( + { + vol.Optional("local_only", default=True): cv.boolean, + vol.Optional("methods"): vol.All(list[str], [vol.In(SUPPORTED_METHODS)]), + } + ) + + webhook_id: str + local_only: bool + methods: set[str] + + async def validate(self): + """Validate the webhook trigger configuration.""" + await super().validate() + self.webhook_id = self.args[0] + + if len(self.args) == 2: + self.create_expression(self.args[1]) + + async def _handler(self, hass, webhook_id, request): + func_args = { + "trigger_type": "webhook", + "webhook_id": webhook_id, + } + + if "json" in request.headers.get(hdrs.CONTENT_TYPE, ""): + func_args["payload"] = await request.json() + else: + # Could potentially return multiples of a key - only take the first + payload_multidict = await request.post() + func_args["payload"] = {k: payload_multidict.getone(k) for k in payload_multidict.keys()} + + if self.has_expression(): + if not await self.check_expression_vars(func_args): + return + + await self.dispatch(DispatchData(func_args)) + + async def start(self): + """Start the webhook trigger.""" + await super().start() + webhook.async_register( + self.dm.hass, + "pyscript", # DOMAIN + "pyscript", # NAME + self.webhook_id, + self._handler, + local_only=self.local_only, + allowed_methods=self.methods, + ) + + _LOGGER.debug("webhook trigger %s listening on id %s", self.dm.name, self.webhook_id) + + async def stop(self): + """Stop the webhook trigger.""" + await super().stop() + # FIXME uncomment + webhook.async_unregister(self.dm.hass, self.webhook_id) diff --git a/custom_components/pyscript/eval.py b/custom_components/pyscript/eval.py index 780a1f8..e8c5177 100644 --- a/custom_components/pyscript/eval.py +++ b/custom_components/pyscript/eval.py @@ -325,6 +325,7 @@ def __init__(self, func_def, code_list, code_str, global_ctx, async_func=False): self.defaults = [] self.kw_defaults = [] self.decorators = [] + self.dm_decorators = [] self.global_names = set() self.nonlocal_names = set() self.local_names = None @@ -613,8 +614,11 @@ async def eval_decorators(self, ast_ctx): dec_other = [] dec_trig = [] + dec_dm = [] for dec in self.func_def.decorator_list: - if ( + if known_dec := await ast_ctx.global_ctx.get_decorator_by_expr(ast_ctx, dec): + dec_dm.append(known_dec) + elif ( isinstance(dec, ast.Call) and isinstance(dec.func, ast.Name) and dec.func.id in TRIG_SERV_DECORATORS @@ -628,7 +632,7 @@ async def eval_decorators(self, ast_ctx): dec_other.append(await ast_ctx.aeval(dec)) ast_ctx.code_str, ast_ctx.code_list = code_str, code_list - return dec_trig, reversed(dec_other) + return dec_trig, reversed(dec_other), dec_dm async def resolve_nonlocals(self, ast_ctx): """Tag local variables and resolve nonlocals.""" @@ -1197,7 +1201,7 @@ async def executor_wrap(*args, **kwargs): await func.eval_defaults(self) await func.resolve_nonlocals(self) name = func.get_name() - dec_trig, dec_other = await func.eval_decorators(self) + dec_trig, dec_other, dec_dm = await func.eval_decorators(self) self.dec_eval_depth += 1 for dec_func in dec_other: func = await self.call_func(dec_func, None, func) @@ -1206,16 +1210,21 @@ async def executor_wrap(*args, **kwargs): func.set_name(name) func = func.remove_func() dec_trig += func.decorators + dec_dm += func.dm_decorators elif isinstance(func, EvalFunc): func.set_name(name) self.dec_eval_depth -= 1 if isinstance(func, EvalFunc): func.decorators = dec_trig + func.dm_decorators = dec_dm if self.dec_eval_depth == 0: func.trigger_stop() await func.trigger_init(self.global_ctx, name) func_var = EvalFuncVar(func) func_var.set_ast_ctx(self) + + if len(dec_dm) > 0: + await self.get_global_ctx().create_decorator_manager(dec_dm, self, func_var) else: func_var = EvalFuncVar(func) func_var.set_ast_ctx(self) diff --git a/custom_components/pyscript/global_ctx.py b/custom_components/pyscript/global_ctx.py index 3d382ed..e48e1c4 100644 --- a/custom_components/pyscript/global_ctx.py +++ b/custom_components/pyscript/global_ctx.py @@ -1,5 +1,6 @@ """Global context handling.""" +import ast import logging import os from types import ModuleType @@ -8,7 +9,9 @@ from homeassistant.config_entries import ConfigEntry from .const import CONF_HASS_IS_GLOBAL, CONFIG_ENTRY, DOMAIN, FOLDER, LOGGER_PATH -from .eval import AstEval, EvalFunc +from .decorator import DecoratorRegistry, FunctionDecoratorManager +from .decorator_abc import Decorator, DecoratorManagerStatus +from .eval import AstEval, EvalFunc, EvalFuncVar from .function import Function from .trigger import TrigInfo @@ -33,6 +36,8 @@ def __init__( self.global_sym_table: Dict[str, Any] = global_sym_table if global_sym_table else {} self.triggers: Set[EvalFunc] = set() self.triggers_delay_start: Set[EvalFunc] = set() + self.dms: Set[FunctionDecoratorManager] = set() + self.dms_delay_start: Set[FunctionDecoratorManager] = set() self.logger: logging.Logger = logging.getLogger(LOGGER_PATH + "." + name) self.manager: GlobalContextMgr = manager self.auto_start: bool = False @@ -60,6 +65,30 @@ def trigger_register(self, func: EvalFunc) -> bool: self.triggers_delay_start.add(func) return False + async def get_decorator_by_expr(self, ast_ctx: AstEval, dec: ast.expr) -> Decorator | None: + """Return decorator instance from an AST decorator expression.""" + return await DecoratorRegistry.get_decorator_by_expr(ast_ctx, dec) + + async def create_decorator_manager( + self, decs: list[Decorator], ast_ctx: AstEval, func_var: EvalFuncVar + ) -> None: + """Create decorator manager from an AST decorator expression.""" + dm = FunctionDecoratorManager(ast_ctx, func_var) + for dec in decs: + dm.add(dec) + + try: + await dm.validate() + if dm.status is DecoratorManagerStatus.VALIDATED: + self.dms.add(dm) + + if self.auto_start: + await dm.start() + else: + self.dms_delay_start.add(dm) + except Exception as exc: + _LOGGER.error(ast_ctx.format_exc(exc, dm.lineno, dm.col_offset)) + def trigger_unregister(self, func: EvalFunc) -> None: """Unregister a trigger function.""" self.triggers.discard(func) @@ -75,12 +104,20 @@ def start(self) -> None: func.trigger_start() self.triggers_delay_start = set() + for dm in self.dms_delay_start: + Function.hass.async_create_task(dm.start()) + self.dms_delay_start = set() + def stop(self) -> None: """Stop all triggers and auto_start.""" for func in self.triggers: func.trigger_stop() self.triggers = set() self.triggers_delay_start = set() + for dm in self.dms: + Function.hass.async_create_task(dm.stop()) + self.dms = set() + self.dms_delay_start = set() self.set_auto_start(False) def get_name(self) -> str: diff --git a/tests/test_decorator_errors.py b/tests/test_decorator_errors.py index d0f5051..6c75113 100644 --- a/tests/test_decorator_errors.py +++ b/tests/test_decorator_errors.py @@ -476,7 +476,4 @@ def func8(): pass """, ) - assert ( - "TypeError: function 'func8' defined in file.hello: {'bad'} aren't valid webhook_trigger methods" - in caplog.text - ) + assert "TypeError: function 'func8' defined in file.hello:" in caplog.text diff --git a/tests/test_unit_eval.py b/tests/test_unit_eval.py index 527e479..21afa20 100644 --- a/tests/test_unit_eval.py +++ b/tests/test_unit_eval.py @@ -3,6 +3,7 @@ import pytest from pytest_homeassistant_custom_component.common import MockConfigEntry +from custom_components.pyscript import DecoratorRegistry from custom_components.pyscript.const import CONF_ALLOW_ALL_IMPORTS, CONFIG_ENTRY, DOMAIN from custom_components.pyscript.eval import AstEval from custom_components.pyscript.function import Function @@ -1666,6 +1667,7 @@ async def test_eval(hass): State.init(hass) State.register_functions() TrigTime.init(hass) + DecoratorRegistry.init(hass) for test_data in evalTests: await run_one_test(test_data)