Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for subscribing to updates by device MAC #327

Merged
merged 4 commits into from
Jan 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 73 additions & 13 deletions aioshelly/block_device/coap.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,16 @@
from __future__ import annotations

import asyncio
import json
import logging
import socket
import struct
from types import TracebackType
from typing import Callable, cast

from ..json import JSONDecodeError, json_loads

COAP_OPTION_DEVICE_ID = 3332

_LOGGER = logging.getLogger(__name__)


Expand All @@ -27,6 +30,7 @@ def __init__(self, sender_addr: tuple[str, int], payload: bytes) -> None:
"""Initialize a coap message."""
self.ip = sender_addr[0]
self.port = sender_addr[1]
self.options: dict[int, bytes] = {}

try:
self.vttkl, self.code, self.mid = struct.unpack("!BBH", payload[:4])
Expand All @@ -36,9 +40,34 @@ def __init__(self, sender_addr: tuple[str, int], payload: bytes) -> None:
if self.code not in (30, 69):
raise InvalidMessage(f"Wrong type, {self.code}")

raw_data = payload[4:]
option_number = 0
data = b""

# parse options
while raw_data:
if raw_data[0] == 0xFF: # end of options marker
data = raw_data[1:]
break

delta = (raw_data[0] & 0xF0) >> 4
length = raw_data[0] & 0x0F
(delta, raw_data) = self._read_extended_field_value(delta, raw_data[1:])
(length, raw_data) = self._read_extended_field_value(length, raw_data)
option_number += delta

if len(raw_data) < length:
raise InvalidMessage("Option announced but absent")

self.options[option_number] = raw_data[:length]
raw_data = raw_data[length:]

if not data:
raise InvalidMessage("Received message without data")

try:
self.payload = json.loads(payload.rsplit(b"\xff", 1)[1].decode())
except (json.decoder.JSONDecodeError, UnicodeDecodeError, IndexError) as err:
self.payload = json_loads(data.decode())
except (JSONDecodeError, UnicodeDecodeError) as err:
raise InvalidMessage(
f"Message type {self.code} is not a valid JSON format: {str(payload)}"
) from err
Expand All @@ -48,13 +77,30 @@ def __init__(self, sender_addr: tuple[str, int], payload: bytes) -> None:
else:
coap_type = "reply"
_LOGGER.debug(
"CoapMessage: ip=%s, type=%s(%s), payload=%s",
"CoapMessage: ip=%s, type=%s(%s), options=%s, payload=%s",
self.ip,
coap_type,
self.code,
self.options,
self.payload,
)

@staticmethod
def _read_extended_field_value(value: int, raw_data: bytes) -> tuple[int, bytes]:
"""Decode large values of option delta and option length."""
if 0 <= value < 13:
return (value, raw_data)
if value == 13:
if len(raw_data) < 1:
raise InvalidMessage("Option ended prematurely")
return (raw_data[0] + 13, raw_data[1:])
if value == 14:
if len(raw_data) < 2:
raise InvalidMessage("Option ended prematurely")
return (int.from_bytes(raw_data[:2], "big") + 269, raw_data[2:])

raise InvalidMessage("Option contained partial payload marker.")


def socket_init(socket_port: int) -> socket.socket:
"""Init UDP socket to send/receive data with Shelly devices."""
Expand Down Expand Up @@ -110,24 +156,38 @@ def datagram_received(self, data: bytes, addr: tuple[str, int]) -> None:
try:
msg = CoapMessage(addr, data)
except InvalidMessage as err:
if host_ip in self.subscriptions:
_LOGGER.error("Invalid Message from known host %s: %s", host_ip, err)
else:
_LOGGER.debug("Invalid Message from unknown host %s: %s", host_ip, err)
_LOGGER.debug("Invalid Message from host %s: %s", host_ip, err)
return

if self._message_received:
self._message_received(msg)

if COAP_OPTION_DEVICE_ID not in msg.options:
_LOGGER.debug("Message from host %s missing device id option", host_ip)
return

try:
device_id = msg.options[COAP_OPTION_DEVICE_ID].decode().split("#")[1][-6:]
except (UnicodeDecodeError, IndexError) as err:
_LOGGER.debug("Invalid device id from host %s: %s", host_ip, err)
return

if device_id in self.subscriptions:
_LOGGER.debug("Calling CoAP message update for device id %s", device_id)
self.subscriptions[device_id](msg)
return

if msg.ip in self.subscriptions:
_LOGGER.debug("Calling CoAP message update for device %s", msg.ip)
_LOGGER.debug("Calling CoAP message update for host %s", msg.ip)
self.subscriptions[msg.ip](msg)

def subscribe_updates(self, ip: str, message_received: Callable) -> Callable:
def subscribe_updates(
self, ip_or_device_id: str, message_received: Callable
) -> Callable:
"""Subscribe to received updates."""
_LOGGER.debug("Adding device %s to CoAP message subscriptions", ip)
self.subscriptions[ip] = message_received
return lambda: self.subscriptions.pop(ip)
_LOGGER.debug("Adding device %s to CoAP message subscriptions", ip_or_device_id)
self.subscriptions[ip_or_device_id] = message_received
return lambda: self.subscriptions.pop(ip_or_device_id)

async def __aenter__(self) -> "COAP":
"""Entering async context manager."""
Expand Down
8 changes: 6 additions & 2 deletions aioshelly/block_device/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
ShellyError,
WrongShellyGen,
)
from ..json import json_loads
from .coap import COAP, CoapMessage

BLOCK_VALUE_UNIT = "U"
Expand Down Expand Up @@ -62,8 +63,11 @@ def __init__(
self._settings: dict[str, Any] | None = None
self.shelly: dict[str, Any] | None = None
self._status: dict[str, Any] | None = None
sub_id = options.ip_address
if options.device_mac:
sub_id = options.device_mac[-6:]
self._unsub_coap: Callable | None = coap_context.subscribe_updates(
options.ip_address, self._coap_message_received
sub_id, self._coap_message_received
)
self._update_listener: Callable | None = None
self._coap_response_events: dict = {}
Expand Down Expand Up @@ -275,7 +279,7 @@ async def http_request(
self._last_error = DeviceConnectionError(err)
raise DeviceConnectionError from err

resp_json = await resp.json()
resp_json = await resp.json(loads=json_loads)
_LOGGER.debug("aiohttp response: %s", resp_json)
return cast(dict, resp_json)

Expand Down
1 change: 1 addition & 0 deletions aioshelly/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ class ConnectionOptions:
password: str | None = None
temperature_unit: str = "C"
auth: aiohttp.BasicAuth | None = None
device_mac: str | None = None

def __post_init__(self) -> None:
"""Call after initialization."""
Expand Down
1 change: 1 addition & 0 deletions aioshelly/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import orjson

JSONDecodeError = orjson.JSONDecodeError # pylint: disable=no-member
json_loads = orjson.loads # pylint: disable=no-member


Expand Down
5 changes: 4 additions & 1 deletion aioshelly/rpc_device/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,11 @@ def __init__(
self._event: dict[str, Any] | None = None
self._config: dict[str, Any] | None = None
self._wsrpc = WsRPC(options.ip_address, self._on_notification)
sub_id = options.ip_address
if options.device_mac:
sub_id = options.device_mac
self._unsub_ws: Callable | None = ws_context.subscribe_updates(
options.ip_address, self._wsrpc.handle_frame
sub_id, self._wsrpc.handle_frame
)
self._update_listener: Callable | None = None
self.initialized: bool = False
Expand Down
20 changes: 15 additions & 5 deletions aioshelly/rpc_device/wsrpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,13 +443,23 @@ async def websocket_handler(self, request: BaseRequest) -> WebSocketResponse:
except ConnectionClosed:
await ws_res.close()
except InvalidMessage as err:
if ip in self.subscriptions:
_LOGGER.error("Invalid Message from known host %s: %s", ip, err)
else:
_LOGGER.debug("Invalid Message from unknown host %s: %s", ip, err)
_LOGGER.debug("Invalid Message from host %s: %s", ip, err)
else:
try:
device_id = frame["src"].split("-")[1].upper()
except (KeyError, IndexError) as err:
_LOGGER.debug("Invalid device id from host %s: %s", ip, err)
continue

if device_id in self.subscriptions:
_LOGGER.debug(
"Calling WsRPC message update for device id %s", device_id
)
self.subscriptions[device_id](frame)
continue

if ip in self.subscriptions:
_LOGGER.debug("Calling WsRPC message update for device %s", ip)
_LOGGER.debug("Calling WsRPC message update for host %s", ip)
self.subscriptions[ip](frame)

_LOGGER.debug("Websocket server connection from %s closed", ip)
Expand Down
7 changes: 6 additions & 1 deletion example.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,9 @@ def get_arguments() -> tuple[argparse.ArgumentParser, argparse.Namespace]:
parser.add_argument(
"--debug", "-deb", action="store_true", help="Enable debug level for logging"
)
parser.add_argument(
"--mac", "-m", type=str, help="Optional device MAC to subscribe for updates"
)

arguments = parser.parse_args()

Expand Down Expand Up @@ -278,7 +281,9 @@ def handle_sigint(_exit_code: int, _frame: FrameType) -> None:
elif args.ip_address:
if args.username and args.password is None:
parser.error("--username and --password must be used together")
options = ConnectionOptions(args.ip_address, args.username, args.password)
options = ConnectionOptions(
args.ip_address, args.username, args.password, device_mac=args.mac
)
await test_single(options, args.init, gen)
else:
parser.error("--ip_address or --devices must be specified")
Expand Down