diff --git a/aioshelly/block_device/coap.py b/aioshelly/block_device/coap.py index 18b421cb..52174af7 100644 --- a/aioshelly/block_device/coap.py +++ b/aioshelly/block_device/coap.py @@ -2,13 +2,16 @@ from __future__ import annotations import asyncio -import json import logging import socket import struct from types import TracebackType from typing import Callable, cast +from ..json import JSONDecodeError, json_loads + +COAP_OPTION_DEVICE_ID = 3332 + _LOGGER = logging.getLogger(__name__) @@ -27,6 +30,7 @@ def __init__(self, sender_addr: tuple[str, int], payload: bytes) -> None: """Initialize a coap message.""" self.ip = sender_addr[0] self.port = sender_addr[1] + self.options: dict[int, bytes] = {} try: self.vttkl, self.code, self.mid = struct.unpack("!BBH", payload[:4]) @@ -36,9 +40,34 @@ def __init__(self, sender_addr: tuple[str, int], payload: bytes) -> None: if self.code not in (30, 69): raise InvalidMessage(f"Wrong type, {self.code}") + raw_data = payload[4:] + option_number = 0 + data = b"" + + # parse options + while raw_data: + if raw_data[0] == 0xFF: # end of options marker + data = raw_data[1:] + break + + delta = (raw_data[0] & 0xF0) >> 4 + length = raw_data[0] & 0x0F + (delta, raw_data) = self._read_extended_field_value(delta, raw_data[1:]) + (length, raw_data) = self._read_extended_field_value(length, raw_data) + option_number += delta + + if len(raw_data) < length: + raise InvalidMessage("Option announced but absent") + + self.options[option_number] = raw_data[:length] + raw_data = raw_data[length:] + + if not data: + raise InvalidMessage("Received message without data") + try: - self.payload = json.loads(payload.rsplit(b"\xff", 1)[1].decode()) - except (json.decoder.JSONDecodeError, UnicodeDecodeError, IndexError) as err: + self.payload = json_loads(data.decode()) + except (JSONDecodeError, UnicodeDecodeError) as err: raise InvalidMessage( f"Message type {self.code} is not a valid JSON format: {str(payload)}" ) from err @@ -48,13 +77,30 @@ def __init__(self, sender_addr: tuple[str, int], payload: bytes) -> None: else: coap_type = "reply" _LOGGER.debug( - "CoapMessage: ip=%s, type=%s(%s), payload=%s", + "CoapMessage: ip=%s, type=%s(%s), options=%s, payload=%s", self.ip, coap_type, self.code, + self.options, self.payload, ) + @staticmethod + def _read_extended_field_value(value: int, raw_data: bytes) -> tuple[int, bytes]: + """Decode large values of option delta and option length.""" + if 0 <= value < 13: + return (value, raw_data) + if value == 13: + if len(raw_data) < 1: + raise InvalidMessage("Option ended prematurely") + return (raw_data[0] + 13, raw_data[1:]) + if value == 14: + if len(raw_data) < 2: + raise InvalidMessage("Option ended prematurely") + return (int.from_bytes(raw_data[:2], "big") + 269, raw_data[2:]) + + raise InvalidMessage("Option contained partial payload marker.") + def socket_init(socket_port: int) -> socket.socket: """Init UDP socket to send/receive data with Shelly devices.""" @@ -110,24 +156,38 @@ def datagram_received(self, data: bytes, addr: tuple[str, int]) -> None: try: msg = CoapMessage(addr, data) except InvalidMessage as err: - if host_ip in self.subscriptions: - _LOGGER.error("Invalid Message from known host %s: %s", host_ip, err) - else: - _LOGGER.debug("Invalid Message from unknown host %s: %s", host_ip, err) + _LOGGER.debug("Invalid Message from host %s: %s", host_ip, err) return if self._message_received: self._message_received(msg) + if COAP_OPTION_DEVICE_ID not in msg.options: + _LOGGER.debug("Message from host %s missing device id option", host_ip) + return + + try: + device_id = msg.options[COAP_OPTION_DEVICE_ID].decode().split("#")[1][-6:] + except (UnicodeDecodeError, IndexError) as err: + _LOGGER.debug("Invalid device id from host %s: %s", host_ip, err) + return + + if device_id in self.subscriptions: + _LOGGER.debug("Calling CoAP message update for device id %s", device_id) + self.subscriptions[device_id](msg) + return + if msg.ip in self.subscriptions: - _LOGGER.debug("Calling CoAP message update for device %s", msg.ip) + _LOGGER.debug("Calling CoAP message update for host %s", msg.ip) self.subscriptions[msg.ip](msg) - def subscribe_updates(self, ip: str, message_received: Callable) -> Callable: + def subscribe_updates( + self, ip_or_device_id: str, message_received: Callable + ) -> Callable: """Subscribe to received updates.""" - _LOGGER.debug("Adding device %s to CoAP message subscriptions", ip) - self.subscriptions[ip] = message_received - return lambda: self.subscriptions.pop(ip) + _LOGGER.debug("Adding device %s to CoAP message subscriptions", ip_or_device_id) + self.subscriptions[ip_or_device_id] = message_received + return lambda: self.subscriptions.pop(ip_or_device_id) async def __aenter__(self) -> "COAP": """Entering async context manager.""" diff --git a/aioshelly/block_device/device.py b/aioshelly/block_device/device.py index b4f04778..a6aee696 100644 --- a/aioshelly/block_device/device.py +++ b/aioshelly/block_device/device.py @@ -20,6 +20,7 @@ ShellyError, WrongShellyGen, ) +from ..json import json_loads from .coap import COAP, CoapMessage BLOCK_VALUE_UNIT = "U" @@ -62,8 +63,11 @@ def __init__( self._settings: dict[str, Any] | None = None self._shelly: dict[str, Any] | None = None self._status: dict[str, Any] | None = None + sub_id = options.ip_address + if options.device_mac: + sub_id = options.device_mac[-6:] self._unsub_coap: Callable | None = coap_context.subscribe_updates( - options.ip_address, self._coap_message_received + sub_id, self._coap_message_received ) self._update_listener: Callable | None = None self._coap_response_events: dict = {} @@ -275,7 +279,7 @@ async def http_request( self._last_error = DeviceConnectionError(err) raise DeviceConnectionError from err - resp_json = await resp.json() + resp_json = await resp.json(loads=json_loads) _LOGGER.debug("aiohttp response: %s", resp_json) return cast(dict, resp_json) diff --git a/aioshelly/common.py b/aioshelly/common.py index 30b8d016..59d43c92 100644 --- a/aioshelly/common.py +++ b/aioshelly/common.py @@ -33,6 +33,7 @@ class ConnectionOptions: password: str | None = None temperature_unit: str = "C" auth: aiohttp.BasicAuth | None = None + device_mac: str | None = None def __post_init__(self) -> None: """Call after initialization.""" diff --git a/aioshelly/json.py b/aioshelly/json.py index d2ea70c7..ee9a18a4 100644 --- a/aioshelly/json.py +++ b/aioshelly/json.py @@ -4,6 +4,7 @@ import orjson +JSONDecodeError = orjson.JSONDecodeError # pylint: disable=no-member json_loads = orjson.loads # pylint: disable=no-member diff --git a/aioshelly/rpc_device/device.py b/aioshelly/rpc_device/device.py index 4efdd0c4..58318b42 100644 --- a/aioshelly/rpc_device/device.py +++ b/aioshelly/rpc_device/device.py @@ -63,8 +63,11 @@ def __init__( self._event: dict[str, Any] | None = None self._config: dict[str, Any] | None = None self._wsrpc = WsRPC(options.ip_address, self._on_notification) + sub_id = options.ip_address + if options.device_mac: + sub_id = options.device_mac self._unsub_ws: Callable | None = ws_context.subscribe_updates( - options.ip_address, self._wsrpc.handle_frame + sub_id, self._wsrpc.handle_frame ) self._update_listener: Callable | None = None self.initialized: bool = False diff --git a/aioshelly/rpc_device/wsrpc.py b/aioshelly/rpc_device/wsrpc.py index 16b334ce..ee55c252 100644 --- a/aioshelly/rpc_device/wsrpc.py +++ b/aioshelly/rpc_device/wsrpc.py @@ -447,13 +447,23 @@ async def websocket_handler(self, request: BaseRequest) -> WebSocketResponse: except ConnectionClosed: await ws_res.close() except InvalidMessage as err: - if ip in self.subscriptions: - _LOGGER.error("Invalid Message from known host %s: %s", ip, err) - else: - _LOGGER.debug("Invalid Message from unknown host %s: %s", ip, err) + _LOGGER.debug("Invalid Message from host %s: %s", ip, err) else: + try: + device_id = frame["src"].split("-")[1].upper() + except (KeyError, IndexError) as err: + _LOGGER.debug("Invalid device id from host %s: %s", ip, err) + continue + + if device_id in self.subscriptions: + _LOGGER.debug( + "Calling WsRPC message update for device id %s", device_id + ) + self.subscriptions[device_id](frame) + continue + if ip in self.subscriptions: - _LOGGER.debug("Calling WsRPC message update for device %s", ip) + _LOGGER.debug("Calling WsRPC message update for host %s", ip) self.subscriptions[ip](frame) _LOGGER.debug("Websocket server connection from %s closed", ip) diff --git a/example.py b/example.py index cafd69bf..86f0cce6 100644 --- a/example.py +++ b/example.py @@ -240,6 +240,9 @@ def get_arguments() -> tuple[argparse.ArgumentParser, argparse.Namespace]: parser.add_argument( "--debug", "-deb", action="store_true", help="Enable debug level for logging" ) + parser.add_argument( + "--mac", "-m", type=str, help="Optional device MAC to subscribe for updates" + ) arguments = parser.parse_args() @@ -278,7 +281,9 @@ def handle_sigint(_exit_code: int, _frame: FrameType) -> None: elif args.ip_address: if args.username and args.password is None: parser.error("--username and --password must be used together") - options = ConnectionOptions(args.ip_address, args.username, args.password) + options = ConnectionOptions( + args.ip_address, args.username, args.password, device_mac=args.mac + ) await test_single(options, args.init, gen) else: parser.error("--ip_address or --devices must be specified")