Skip to content

Commit

Permalink
Add support for subscribing to updates by device MAC (#327)
Browse files Browse the repository at this point in the history
* Add support for subscribing to updates by device MAC

* Fix _read_extended_field_value

* Switch to orjson and fix Gen1 MAC compare

* Fix tox failure
  • Loading branch information
thecode committed Jan 22, 2023
1 parent c75362c commit c5b6f78
Show file tree
Hide file tree
Showing 7 changed files with 106 additions and 22 deletions.
86 changes: 73 additions & 13 deletions aioshelly/block_device/coap.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,16 @@
from __future__ import annotations

import asyncio
import json
import logging
import socket
import struct
from types import TracebackType
from typing import Callable, cast

from ..json import JSONDecodeError, json_loads

COAP_OPTION_DEVICE_ID = 3332

_LOGGER = logging.getLogger(__name__)


Expand All @@ -27,6 +30,7 @@ def __init__(self, sender_addr: tuple[str, int], payload: bytes) -> None:
"""Initialize a coap message."""
self.ip = sender_addr[0]
self.port = sender_addr[1]
self.options: dict[int, bytes] = {}

try:
self.vttkl, self.code, self.mid = struct.unpack("!BBH", payload[:4])
Expand All @@ -36,9 +40,34 @@ def __init__(self, sender_addr: tuple[str, int], payload: bytes) -> None:
if self.code not in (30, 69):
raise InvalidMessage(f"Wrong type, {self.code}")

raw_data = payload[4:]
option_number = 0
data = b""

# parse options
while raw_data:
if raw_data[0] == 0xFF: # end of options marker
data = raw_data[1:]
break

delta = (raw_data[0] & 0xF0) >> 4
length = raw_data[0] & 0x0F
(delta, raw_data) = self._read_extended_field_value(delta, raw_data[1:])
(length, raw_data) = self._read_extended_field_value(length, raw_data)
option_number += delta

if len(raw_data) < length:
raise InvalidMessage("Option announced but absent")

self.options[option_number] = raw_data[:length]
raw_data = raw_data[length:]

if not data:
raise InvalidMessage("Received message without data")

try:
self.payload = json.loads(payload.rsplit(b"\xff", 1)[1].decode())
except (json.decoder.JSONDecodeError, UnicodeDecodeError, IndexError) as err:
self.payload = json_loads(data.decode())
except (JSONDecodeError, UnicodeDecodeError) as err:
raise InvalidMessage(
f"Message type {self.code} is not a valid JSON format: {str(payload)}"
) from err
Expand All @@ -48,13 +77,30 @@ def __init__(self, sender_addr: tuple[str, int], payload: bytes) -> None:
else:
coap_type = "reply"
_LOGGER.debug(
"CoapMessage: ip=%s, type=%s(%s), payload=%s",
"CoapMessage: ip=%s, type=%s(%s), options=%s, payload=%s",
self.ip,
coap_type,
self.code,
self.options,
self.payload,
)

@staticmethod
def _read_extended_field_value(value: int, raw_data: bytes) -> tuple[int, bytes]:
"""Decode large values of option delta and option length."""
if 0 <= value < 13:
return (value, raw_data)
if value == 13:
if len(raw_data) < 1:
raise InvalidMessage("Option ended prematurely")
return (raw_data[0] + 13, raw_data[1:])
if value == 14:
if len(raw_data) < 2:
raise InvalidMessage("Option ended prematurely")
return (int.from_bytes(raw_data[:2], "big") + 269, raw_data[2:])

raise InvalidMessage("Option contained partial payload marker.")


def socket_init(socket_port: int) -> socket.socket:
"""Init UDP socket to send/receive data with Shelly devices."""
Expand Down Expand Up @@ -110,24 +156,38 @@ def datagram_received(self, data: bytes, addr: tuple[str, int]) -> None:
try:
msg = CoapMessage(addr, data)
except InvalidMessage as err:
if host_ip in self.subscriptions:
_LOGGER.error("Invalid Message from known host %s: %s", host_ip, err)
else:
_LOGGER.debug("Invalid Message from unknown host %s: %s", host_ip, err)
_LOGGER.debug("Invalid Message from host %s: %s", host_ip, err)
return

if self._message_received:
self._message_received(msg)

if COAP_OPTION_DEVICE_ID not in msg.options:
_LOGGER.debug("Message from host %s missing device id option", host_ip)
return

try:
device_id = msg.options[COAP_OPTION_DEVICE_ID].decode().split("#")[1][-6:]
except (UnicodeDecodeError, IndexError) as err:
_LOGGER.debug("Invalid device id from host %s: %s", host_ip, err)
return

if device_id in self.subscriptions:
_LOGGER.debug("Calling CoAP message update for device id %s", device_id)
self.subscriptions[device_id](msg)
return

if msg.ip in self.subscriptions:
_LOGGER.debug("Calling CoAP message update for device %s", msg.ip)
_LOGGER.debug("Calling CoAP message update for host %s", msg.ip)
self.subscriptions[msg.ip](msg)

def subscribe_updates(self, ip: str, message_received: Callable) -> Callable:
def subscribe_updates(
self, ip_or_device_id: str, message_received: Callable
) -> Callable:
"""Subscribe to received updates."""
_LOGGER.debug("Adding device %s to CoAP message subscriptions", ip)
self.subscriptions[ip] = message_received
return lambda: self.subscriptions.pop(ip)
_LOGGER.debug("Adding device %s to CoAP message subscriptions", ip_or_device_id)
self.subscriptions[ip_or_device_id] = message_received
return lambda: self.subscriptions.pop(ip_or_device_id)

async def __aenter__(self) -> "COAP":
"""Entering async context manager."""
Expand Down
8 changes: 6 additions & 2 deletions aioshelly/block_device/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
ShellyError,
WrongShellyGen,
)
from ..json import json_loads
from .coap import COAP, CoapMessage

BLOCK_VALUE_UNIT = "U"
Expand Down Expand Up @@ -62,8 +63,11 @@ def __init__(
self._settings: dict[str, Any] | None = None
self._shelly: dict[str, Any] | None = None
self._status: dict[str, Any] | None = None
sub_id = options.ip_address
if options.device_mac:
sub_id = options.device_mac[-6:]
self._unsub_coap: Callable | None = coap_context.subscribe_updates(
options.ip_address, self._coap_message_received
sub_id, self._coap_message_received
)
self._update_listener: Callable | None = None
self._coap_response_events: dict = {}
Expand Down Expand Up @@ -275,7 +279,7 @@ async def http_request(
self._last_error = DeviceConnectionError(err)
raise DeviceConnectionError from err

resp_json = await resp.json()
resp_json = await resp.json(loads=json_loads)
_LOGGER.debug("aiohttp response: %s", resp_json)
return cast(dict, resp_json)

Expand Down
1 change: 1 addition & 0 deletions aioshelly/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ class ConnectionOptions:
password: str | None = None
temperature_unit: str = "C"
auth: aiohttp.BasicAuth | None = None
device_mac: str | None = None

def __post_init__(self) -> None:
"""Call after initialization."""
Expand Down
1 change: 1 addition & 0 deletions aioshelly/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import orjson

JSONDecodeError = orjson.JSONDecodeError # pylint: disable=no-member
json_loads = orjson.loads # pylint: disable=no-member


Expand Down
5 changes: 4 additions & 1 deletion aioshelly/rpc_device/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,11 @@ def __init__(
self._event: dict[str, Any] | None = None
self._config: dict[str, Any] | None = None
self._wsrpc = WsRPC(options.ip_address, self._on_notification)
sub_id = options.ip_address
if options.device_mac:
sub_id = options.device_mac
self._unsub_ws: Callable | None = ws_context.subscribe_updates(
options.ip_address, self._wsrpc.handle_frame
sub_id, self._wsrpc.handle_frame
)
self._update_listener: Callable | None = None
self.initialized: bool = False
Expand Down
20 changes: 15 additions & 5 deletions aioshelly/rpc_device/wsrpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,13 +447,23 @@ async def websocket_handler(self, request: BaseRequest) -> WebSocketResponse:
except ConnectionClosed:
await ws_res.close()
except InvalidMessage as err:
if ip in self.subscriptions:
_LOGGER.error("Invalid Message from known host %s: %s", ip, err)
else:
_LOGGER.debug("Invalid Message from unknown host %s: %s", ip, err)
_LOGGER.debug("Invalid Message from host %s: %s", ip, err)
else:
try:
device_id = frame["src"].split("-")[1].upper()
except (KeyError, IndexError) as err:
_LOGGER.debug("Invalid device id from host %s: %s", ip, err)
continue

if device_id in self.subscriptions:
_LOGGER.debug(
"Calling WsRPC message update for device id %s", device_id
)
self.subscriptions[device_id](frame)
continue

if ip in self.subscriptions:
_LOGGER.debug("Calling WsRPC message update for device %s", ip)
_LOGGER.debug("Calling WsRPC message update for host %s", ip)
self.subscriptions[ip](frame)

_LOGGER.debug("Websocket server connection from %s closed", ip)
Expand Down
7 changes: 6 additions & 1 deletion example.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,9 @@ def get_arguments() -> tuple[argparse.ArgumentParser, argparse.Namespace]:
parser.add_argument(
"--debug", "-deb", action="store_true", help="Enable debug level for logging"
)
parser.add_argument(
"--mac", "-m", type=str, help="Optional device MAC to subscribe for updates"
)

arguments = parser.parse_args()

Expand Down Expand Up @@ -278,7 +281,9 @@ def handle_sigint(_exit_code: int, _frame: FrameType) -> None:
elif args.ip_address:
if args.username and args.password is None:
parser.error("--username and --password must be used together")
options = ConnectionOptions(args.ip_address, args.username, args.password)
options = ConnectionOptions(
args.ip_address, args.username, args.password, device_mac=args.mac
)
await test_single(options, args.init, gen)
else:
parser.error("--ip_address or --devices must be specified")
Expand Down

0 comments on commit c5b6f78

Please sign in to comment.