From 96d437aeef55aee77db48e406217d810917f8185 Mon Sep 17 00:00:00 2001 From: Erik Date: Tue, 17 May 2022 10:57:31 +0200 Subject: [PATCH 1/2] Move MQTT config schemas and client to separate modules --- homeassistant/components/mqtt/__init__.py | 962 +----------------- .../components/mqtt/alarm_control_panel.py | 14 +- .../components/mqtt/binary_sensor.py | 7 +- homeassistant/components/mqtt/button.py | 11 +- homeassistant/components/mqtt/camera.py | 7 +- homeassistant/components/mqtt/client.py | 659 ++++++++++++ homeassistant/components/mqtt/climate.py | 58 +- homeassistant/components/mqtt/config.py | 148 +++ homeassistant/components/mqtt/config_flow.py | 2 +- homeassistant/components/mqtt/const.py | 26 +- homeassistant/components/mqtt/cover.py | 20 +- .../components/mqtt/device_automation.py | 4 +- .../mqtt/device_tracker/schema_discovery.py | 7 +- .../mqtt/device_tracker/schema_yaml.py | 10 +- .../components/mqtt/device_trigger.py | 6 +- homeassistant/components/mqtt/fan.py | 24 +- homeassistant/components/mqtt/humidifier.py | 18 +- .../components/mqtt/light/schema_basic.py | 48 +- .../components/mqtt/light/schema_json.py | 11 +- .../components/mqtt/light/schema_template.py | 7 +- homeassistant/components/mqtt/lock.py | 7 +- homeassistant/components/mqtt/mixins.py | 14 +- homeassistant/components/mqtt/models.py | 126 ++- homeassistant/components/mqtt/number.py | 7 +- homeassistant/components/mqtt/scene.py | 10 +- homeassistant/components/mqtt/select.py | 7 +- homeassistant/components/mqtt/sensor.py | 10 +- homeassistant/components/mqtt/siren.py | 7 +- homeassistant/components/mqtt/switch.py | 7 +- homeassistant/components/mqtt/tag.py | 8 +- .../components/mqtt/vacuum/schema_legacy.py | 28 +- .../components/mqtt/vacuum/schema_state.py | 15 +- tests/components/mqtt/test_cover.py | 2 +- tests/components/mqtt/test_init.py | 18 +- tests/components/mqtt/test_legacy_vacuum.py | 2 +- tests/components/mqtt/test_state_vacuum.py | 2 +- 36 files changed, 1210 insertions(+), 1109 deletions(-) create mode 100644 homeassistant/components/mqtt/client.py create mode 100644 homeassistant/components/mqtt/config.py diff --git a/homeassistant/components/mqtt/__init__.py b/homeassistant/components/mqtt/__init__.py index 46eb7052f4fe..5cb2223a9ac0 100644 --- a/homeassistant/components/mqtt/__init__.py +++ b/homeassistant/components/mqtt/__init__.py @@ -1,23 +1,13 @@ """Support for MQTT message handling.""" from __future__ import annotations -from ast import literal_eval import asyncio -from collections.abc import Awaitable, Callable +from collections.abc import Callable from dataclasses import dataclass import datetime as dt -from functools import lru_cache, partial, wraps -import inspect -from itertools import groupby import logging -from operator import attrgetter -import ssl -import time -from typing import TYPE_CHECKING, Any, Union, cast -import uuid - -import attr -import certifi +from typing import Any, cast + import jinja2 import voluptuous as vol @@ -25,43 +15,32 @@ from homeassistant.components import websocket_api from homeassistant.config_entries import ConfigEntry from homeassistant.const import ( - ATTR_ENTITY_ID, - ATTR_NAME, - CONF_CLIENT_ID, CONF_DISCOVERY, CONF_PASSWORD, CONF_PAYLOAD, CONF_PORT, - CONF_PROTOCOL, CONF_USERNAME, - CONF_VALUE_TEMPLATE, - EVENT_HOMEASSISTANT_STARTED, EVENT_HOMEASSISTANT_STOP, SERVICE_RELOAD, - Platform, -) -from homeassistant.core import ( - CoreState, - Event, - HassJob, - HomeAssistant, - ServiceCall, - callback, ) +from homeassistant.core import Event, HassJob, HomeAssistant, ServiceCall, callback from homeassistant.data_entry_flow import BaseServiceInfo -from homeassistant.exceptions import HomeAssistantError, TemplateError, Unauthorized +from homeassistant.exceptions import TemplateError, Unauthorized from homeassistant.helpers import config_validation as cv, event, template from homeassistant.helpers.device_registry import DeviceEntry -from homeassistant.helpers.dispatcher import async_dispatcher_connect, dispatcher_send -from homeassistant.helpers.entity import Entity -from homeassistant.helpers.typing import ConfigType, TemplateVarsType -from homeassistant.loader import bind_hass -from homeassistant.util import dt as dt_util -from homeassistant.util.async_ import run_callback_threadsafe -from homeassistant.util.logging import catch_log_exception +from homeassistant.helpers.dispatcher import async_dispatcher_connect +from homeassistant.helpers.typing import ConfigType # Loading the config flow file will register the flow from . import debug_info, discovery +from .client import ( # noqa: F401 + MQTT, + async_publish, + async_subscribe, + publish, + subscribe, +) +from .config import CONFIG_SCHEMA_BASE, DEFAULT_VALUES, DEPRECATED_CONFIG_KEYS from .const import ( ATTR_PAYLOAD, ATTR_QOS, @@ -69,76 +48,36 @@ ATTR_TOPIC, CONF_BIRTH_MESSAGE, CONF_BROKER, - CONF_CERTIFICATE, - CONF_CLIENT_CERT, - CONF_CLIENT_KEY, - CONF_COMMAND_TOPIC, - CONF_ENCODING, - CONF_QOS, - CONF_RETAIN, - CONF_STATE_TOPIC, - CONF_TLS_INSECURE, + CONF_DISCOVERY_PREFIX, CONF_TLS_VERSION, CONF_TOPIC, CONF_WILL_MESSAGE, CONFIG_ENTRY_IS_SETUP, DATA_CONFIG_ENTRY_LOCK, + DATA_MQTT, DATA_MQTT_CONFIG, DATA_MQTT_RELOAD_NEEDED, - DEFAULT_BIRTH, - DEFAULT_DISCOVERY, DEFAULT_ENCODING, - DEFAULT_PREFIX, DEFAULT_QOS, DEFAULT_RETAIN, - DEFAULT_WILL, DOMAIN, MQTT_CONNECTED, MQTT_DISCONNECTED, - PROTOCOL_31, - PROTOCOL_311, + PLATFORMS, ) -from .discovery import LAST_DISCOVERY -from .models import ( - AsyncMessageCallbackType, - MessageCallbackType, - PublishMessage, - PublishPayloadType, +from .models import ( # noqa: F401 + MqttCommandTemplate, + MqttValueTemplate, ReceiveMessage, ReceivePayloadType, ) from .util import _VALID_QOS_SCHEMA, valid_publish_topic, valid_subscribe_topic -if TYPE_CHECKING: - # Only import for paho-mqtt type checking here, imports are done locally - # because integrations should be able to optionally rely on MQTT. - import paho.mqtt.client as mqtt - _LOGGER = logging.getLogger(__name__) -_SENTINEL = object() - -DATA_MQTT = "mqtt" - SERVICE_PUBLISH = "publish" SERVICE_DUMP = "dump" -CONF_DISCOVERY_PREFIX = "discovery_prefix" -CONF_KEEPALIVE = "keepalive" - -DEFAULT_PORT = 1883 -DEFAULT_KEEPALIVE = 60 -DEFAULT_PROTOCOL = PROTOCOL_311 -DEFAULT_TLS_PROTOCOL = "auto" - -DEFAULT_VALUES = { - CONF_BIRTH_MESSAGE: DEFAULT_BIRTH, - CONF_DISCOVERY: DEFAULT_DISCOVERY, - CONF_PORT: DEFAULT_PORT, - CONF_TLS_VERSION: DEFAULT_TLS_PROTOCOL, - CONF_WILL_MESSAGE: DEFAULT_WILL, -} - MANDATORY_DEFAULT_VALUES = (CONF_PORT,) ATTR_TOPIC_TEMPLATE = "topic_template" @@ -150,93 +89,6 @@ CONNECTION_FAILED = "connection_failed" CONNECTION_FAILED_RECOVERABLE = "connection_failed_recoverable" -DISCOVERY_COOLDOWN = 2 -TIMEOUT_ACK = 10 - -PLATFORMS = [ - Platform.ALARM_CONTROL_PANEL, - Platform.BINARY_SENSOR, - Platform.BUTTON, - Platform.CAMERA, - Platform.CLIMATE, - Platform.DEVICE_TRACKER, - Platform.COVER, - Platform.FAN, - Platform.HUMIDIFIER, - Platform.LIGHT, - Platform.LOCK, - Platform.NUMBER, - Platform.SELECT, - Platform.SCENE, - Platform.SENSOR, - Platform.SIREN, - Platform.SWITCH, - Platform.VACUUM, -] - -CLIENT_KEY_AUTH_MSG = ( - "client_key and client_cert must both be present in " - "the MQTT broker configuration" -) - -MQTT_WILL_BIRTH_SCHEMA = vol.Schema( - { - vol.Inclusive(ATTR_TOPIC, "topic_payload"): valid_publish_topic, - vol.Inclusive(ATTR_PAYLOAD, "topic_payload"): cv.string, - vol.Optional(ATTR_QOS, default=DEFAULT_QOS): _VALID_QOS_SCHEMA, - vol.Optional(ATTR_RETAIN, default=DEFAULT_RETAIN): cv.boolean, - }, - required=True, -) - -PLATFORM_CONFIG_SCHEMA_BASE = vol.Schema( - {vol.Optional(platform.value): cv.ensure_list for platform in PLATFORMS} -) - -CONFIG_SCHEMA_BASE = PLATFORM_CONFIG_SCHEMA_BASE.extend( - { - vol.Optional(CONF_CLIENT_ID): cv.string, - vol.Optional(CONF_KEEPALIVE, default=DEFAULT_KEEPALIVE): vol.All( - vol.Coerce(int), vol.Range(min=15) - ), - vol.Optional(CONF_BROKER): cv.string, - vol.Optional(CONF_PORT): cv.port, - vol.Optional(CONF_USERNAME): cv.string, - vol.Optional(CONF_PASSWORD): cv.string, - vol.Optional(CONF_CERTIFICATE): vol.Any("auto", cv.isfile), - vol.Inclusive( - CONF_CLIENT_KEY, "client_key_auth", msg=CLIENT_KEY_AUTH_MSG - ): cv.isfile, - vol.Inclusive( - CONF_CLIENT_CERT, "client_key_auth", msg=CLIENT_KEY_AUTH_MSG - ): cv.isfile, - vol.Optional(CONF_TLS_INSECURE): cv.boolean, - vol.Optional(CONF_TLS_VERSION): vol.Any("auto", "1.0", "1.1", "1.2"), - vol.Optional(CONF_PROTOCOL, default=DEFAULT_PROTOCOL): vol.All( - cv.string, vol.In([PROTOCOL_31, PROTOCOL_311]) - ), - vol.Optional(CONF_WILL_MESSAGE): MQTT_WILL_BIRTH_SCHEMA, - vol.Optional(CONF_BIRTH_MESSAGE): MQTT_WILL_BIRTH_SCHEMA, - vol.Optional(CONF_DISCOVERY): cv.boolean, - # discovery_prefix must be a valid publish topic because if no - # state topic is specified, it will be created with the given prefix. - vol.Optional( - CONF_DISCOVERY_PREFIX, default=DEFAULT_PREFIX - ): valid_publish_topic, - } -) - -DEPRECATED_CONFIG_KEYS = [ - CONF_BIRTH_MESSAGE, - CONF_BROKER, - CONF_DISCOVERY, - CONF_PASSWORD, - CONF_PORT, - CONF_TLS_VERSION, - CONF_USERNAME, - CONF_WILL_MESSAGE, -] - CONFIG_SCHEMA = vol.Schema( { DOMAIN: vol.All( @@ -254,29 +106,6 @@ extra=vol.ALLOW_EXTRA, ) -SCHEMA_BASE = { - vol.Optional(CONF_QOS, default=DEFAULT_QOS): _VALID_QOS_SCHEMA, - vol.Optional(CONF_ENCODING, default=DEFAULT_ENCODING): cv.string, -} - -MQTT_BASE_SCHEMA = vol.Schema(SCHEMA_BASE) - -# Sensor type platforms subscribe to MQTT events -MQTT_RO_SCHEMA = MQTT_BASE_SCHEMA.extend( - { - vol.Required(CONF_STATE_TOPIC): valid_subscribe_topic, - vol.Optional(CONF_VALUE_TEMPLATE): cv.template, - } -) - -# Switch type platforms publish to MQTT and may subscribe -MQTT_RW_SCHEMA = MQTT_BASE_SCHEMA.extend( - { - vol.Required(CONF_COMMAND_TOPIC): valid_publish_topic, - vol.Optional(CONF_RETAIN, default=DEFAULT_RETAIN): cv.boolean, - vol.Optional(CONF_STATE_TOPIC): valid_subscribe_topic, - } -) # Service call validation schema MQTT_PUBLISH_SCHEMA = vol.All( @@ -295,124 +124,6 @@ ) -SubscribePayloadType = Union[str, bytes] # Only bytes if encoding is None - - -class MqttCommandTemplate: - """Class for rendering MQTT payload with command templates.""" - - def __init__( - self, - command_template: template.Template | None, - *, - hass: HomeAssistant | None = None, - entity: Entity | None = None, - ) -> None: - """Instantiate a command template.""" - self._attr_command_template = command_template - if command_template is None: - return - - self._entity = entity - - command_template.hass = hass - - if entity: - command_template.hass = entity.hass - - @callback - def async_render( - self, - value: PublishPayloadType = None, - variables: TemplateVarsType = None, - ) -> PublishPayloadType: - """Render or convert the command template with given value or variables.""" - - def _convert_outgoing_payload( - payload: PublishPayloadType, - ) -> PublishPayloadType: - """Ensure correct raw MQTT payload is passed as bytes for publishing.""" - if isinstance(payload, str): - try: - native_object = literal_eval(payload) - if isinstance(native_object, bytes): - return native_object - - except (ValueError, TypeError, SyntaxError, MemoryError): - pass - - return payload - - if self._attr_command_template is None: - return value - - values = {"value": value} - if self._entity: - values[ATTR_ENTITY_ID] = self._entity.entity_id - values[ATTR_NAME] = self._entity.name - if variables is not None: - values.update(variables) - return _convert_outgoing_payload( - self._attr_command_template.async_render(values, parse_result=False) - ) - - -class MqttValueTemplate: - """Class for rendering MQTT value template with possible json values.""" - - def __init__( - self, - value_template: template.Template | None, - *, - hass: HomeAssistant | None = None, - entity: Entity | None = None, - config_attributes: TemplateVarsType = None, - ) -> None: - """Instantiate a value template.""" - self._value_template = value_template - self._config_attributes = config_attributes - if value_template is None: - return - - value_template.hass = hass - self._entity = entity - - if entity: - value_template.hass = entity.hass - - @callback - def async_render_with_possible_json_value( - self, - payload: ReceivePayloadType, - default: ReceivePayloadType | object = _SENTINEL, - variables: TemplateVarsType = None, - ) -> ReceivePayloadType: - """Render with possible json value or pass-though a received MQTT value.""" - if self._value_template is None: - return payload - - values: dict[str, Any] = {} - - if variables is not None: - values.update(variables) - - if self._config_attributes is not None: - values.update(self._config_attributes) - - if self._entity: - values[ATTR_ENTITY_ID] = self._entity.entity_id - values[ATTR_NAME] = self._entity.name - - if default == _SENTINEL: - return self._value_template.async_render_with_possible_json_value( - payload, variables=values - ) - - return self._value_template.async_render_with_possible_json_value( - payload, default, variables=values - ) - - @dataclass class MqttServiceInfo(BaseServiceInfo): """Prepared info from mqtt entries.""" @@ -425,163 +136,6 @@ class MqttServiceInfo(BaseServiceInfo): timestamp: dt.datetime -def publish( - hass: HomeAssistant, - topic: str, - payload: PublishPayloadType, - qos: int | None = 0, - retain: bool | None = False, - encoding: str | None = DEFAULT_ENCODING, -) -> None: - """Publish message to a MQTT topic.""" - hass.add_job(async_publish, hass, topic, payload, qos, retain, encoding) - - -async def async_publish( - hass: HomeAssistant, - topic: str, - payload: PublishPayloadType, - qos: int | None = 0, - retain: bool | None = False, - encoding: str | None = DEFAULT_ENCODING, -) -> None: - """Publish message to a MQTT topic.""" - - outgoing_payload = payload - if not isinstance(payload, bytes): - if not encoding: - _LOGGER.error( - "Can't pass-through payload for publishing %s on %s with no encoding set, need 'bytes' got %s", - payload, - topic, - type(payload), - ) - return - outgoing_payload = str(payload) - if encoding != DEFAULT_ENCODING: - # a string is encoded as utf-8 by default, other encoding requires bytes as payload - try: - outgoing_payload = outgoing_payload.encode(encoding) - except (AttributeError, LookupError, UnicodeEncodeError): - _LOGGER.error( - "Can't encode payload for publishing %s on %s with encoding %s", - payload, - topic, - encoding, - ) - return - - await hass.data[DATA_MQTT].async_publish(topic, outgoing_payload, qos, retain) - - -AsyncDeprecatedMessageCallbackType = Callable[ - [str, ReceivePayloadType, int], Awaitable[None] -] -DeprecatedMessageCallbackType = Callable[[str, ReceivePayloadType, int], None] - - -def wrap_msg_callback( - msg_callback: AsyncDeprecatedMessageCallbackType | DeprecatedMessageCallbackType, -) -> AsyncMessageCallbackType | MessageCallbackType: - """Wrap an MQTT message callback to support deprecated signature.""" - # Check for partials to properly determine if coroutine function - check_func = msg_callback - while isinstance(check_func, partial): - check_func = check_func.func - - wrapper_func: AsyncMessageCallbackType | MessageCallbackType - if asyncio.iscoroutinefunction(check_func): - - @wraps(msg_callback) - async def async_wrapper(msg: ReceiveMessage) -> None: - """Call with deprecated signature.""" - await cast(AsyncDeprecatedMessageCallbackType, msg_callback)( - msg.topic, msg.payload, msg.qos - ) - - wrapper_func = async_wrapper - else: - - @wraps(msg_callback) - def wrapper(msg: ReceiveMessage) -> None: - """Call with deprecated signature.""" - msg_callback(msg.topic, msg.payload, msg.qos) - - wrapper_func = wrapper - return wrapper_func - - -@bind_hass -async def async_subscribe( - hass: HomeAssistant, - topic: str, - msg_callback: AsyncMessageCallbackType - | MessageCallbackType - | DeprecatedMessageCallbackType - | AsyncDeprecatedMessageCallbackType, - qos: int = DEFAULT_QOS, - encoding: str | None = "utf-8", -): - """Subscribe to an MQTT topic. - - Call the return value to unsubscribe. - """ - # Count callback parameters which don't have a default value - non_default = 0 - if msg_callback: - non_default = sum( - p.default == inspect.Parameter.empty - for _, p in inspect.signature(msg_callback).parameters.items() - ) - - wrapped_msg_callback = msg_callback - # If we have 3 parameters with no default value, wrap the callback - if non_default == 3: - module = inspect.getmodule(msg_callback) - _LOGGER.warning( - "Signature of MQTT msg_callback '%s.%s' is deprecated", - module.__name__ if module else "", - msg_callback.__name__, - ) - wrapped_msg_callback = wrap_msg_callback( - cast(DeprecatedMessageCallbackType, msg_callback) - ) - - async_remove = await hass.data[DATA_MQTT].async_subscribe( - topic, - catch_log_exception( - wrapped_msg_callback, - lambda msg: ( - f"Exception in {msg_callback.__name__} when handling msg on " - f"'{msg.topic}': '{msg.payload}'" - ), - ), - qos, - encoding, - ) - return async_remove - - -@bind_hass -def subscribe( - hass: HomeAssistant, - topic: str, - msg_callback: MessageCallbackType, - qos: int = DEFAULT_QOS, - encoding: str = "utf-8", -) -> Callable[[], None]: - """Subscribe to an MQTT topic.""" - async_remove = asyncio.run_coroutine_threadsafe( - async_subscribe(hass, topic, msg_callback, qos, encoding), hass.loop - ).result() - - def remove(): - """Remove listener convert.""" - run_callback_threadsafe(hass.loop, async_remove).result() - - return remove - - async def _async_setup_discovery( hass: HomeAssistant, conf: ConfigType, config_entry ) -> None: @@ -649,6 +203,26 @@ def _merge_extended_config(entry, conf): return {**conf, **entry.data} +async def _async_config_entry_updated(hass: HomeAssistant, entry: ConfigEntry) -> None: + """Handle signals of config entry being updated. + + Causes for this is config entry options changing. + """ + mqtt_client = hass.data[DATA_MQTT] + + if (conf := hass.data.get(DATA_MQTT_CONFIG)) is None: + conf = CONFIG_SCHEMA_BASE(dict(entry.data)) + + mqtt_client.conf = _merge_extended_config(entry, conf) + await mqtt_client.async_disconnect() + mqtt_client.init_client() + await mqtt_client.async_connect() + + await discovery.async_stop(hass) + if mqtt_client.conf.get(CONF_DISCOVERY): + await _async_setup_discovery(hass, mqtt_client.conf, entry) + + async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: """Load a config entry.""" # Merge basic configuration, and add missing defaults for basic options @@ -685,6 +259,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: entry, conf, ) + entry.add_update_listener(_async_config_entry_updated) await hass.data[DATA_MQTT].async_connect() @@ -813,459 +388,6 @@ async def async_forward_entry_setup(): return True -@attr.s(slots=True, frozen=True) -class Subscription: - """Class to hold data about an active subscription.""" - - topic: str = attr.ib() - matcher: Any = attr.ib() - job: HassJob = attr.ib() - qos: int = attr.ib(default=0) - encoding: str | None = attr.ib(default="utf-8") - - -class MqttClientSetup: - """Helper class to setup the paho mqtt client from config.""" - - def __init__(self, config: ConfigType) -> None: - """Initialize the MQTT client setup helper.""" - - # We don't import on the top because some integrations - # should be able to optionally rely on MQTT. - import paho.mqtt.client as mqtt # pylint: disable=import-outside-toplevel - - if config[CONF_PROTOCOL] == PROTOCOL_31: - proto = mqtt.MQTTv31 - else: - proto = mqtt.MQTTv311 - - if (client_id := config.get(CONF_CLIENT_ID)) is None: - # PAHO MQTT relies on the MQTT server to generate random client IDs. - # However, that feature is not mandatory so we generate our own. - client_id = mqtt.base62(uuid.uuid4().int, padding=22) - self._client = mqtt.Client(client_id, protocol=proto) - - # Enable logging - self._client.enable_logger() - - username = config.get(CONF_USERNAME) - password = config.get(CONF_PASSWORD) - if username is not None: - self._client.username_pw_set(username, password) - - if (certificate := config.get(CONF_CERTIFICATE)) == "auto": - certificate = certifi.where() - - client_key = config.get(CONF_CLIENT_KEY) - client_cert = config.get(CONF_CLIENT_CERT) - tls_insecure = config.get(CONF_TLS_INSECURE) - if certificate is not None: - self._client.tls_set( - certificate, - certfile=client_cert, - keyfile=client_key, - tls_version=ssl.PROTOCOL_TLS, - ) - - if tls_insecure is not None: - self._client.tls_insecure_set(tls_insecure) - - @property - def client(self) -> mqtt.Client: - """Return the paho MQTT client.""" - return self._client - - -class MQTT: - """Home Assistant MQTT client.""" - - def __init__( - self, - hass: HomeAssistant, - config_entry, - conf, - ) -> None: - """Initialize Home Assistant MQTT client.""" - # We don't import on the top because some integrations - # should be able to optionally rely on MQTT. - import paho.mqtt.client as mqtt # pylint: disable=import-outside-toplevel - - self.hass = hass - self.config_entry = config_entry - self.conf = conf - self.subscriptions: list[Subscription] = [] - self.connected = False - self._ha_started = asyncio.Event() - self._last_subscribe = time.time() - self._mqttc: mqtt.Client = None - self._paho_lock = asyncio.Lock() - - self._pending_operations: dict[str, asyncio.Event] = {} - - if self.hass.state == CoreState.running: - self._ha_started.set() - else: - - @callback - def ha_started(_): - self._ha_started.set() - - self.hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STARTED, ha_started) - - self.init_client() - self.config_entry.add_update_listener(self.async_config_entry_updated) - - @staticmethod - async def async_config_entry_updated( - hass: HomeAssistant, entry: ConfigEntry - ) -> None: - """Handle signals of config entry being updated. - - This is a static method because a class method (bound method), can not be used with weak references. - Causes for this is config entry options changing. - """ - self = hass.data[DATA_MQTT] - - if (conf := hass.data.get(DATA_MQTT_CONFIG)) is None: - conf = CONFIG_SCHEMA_BASE(dict(entry.data)) - - self.conf = _merge_extended_config(entry, conf) - await self.async_disconnect() - self.init_client() - await self.async_connect() - - await discovery.async_stop(hass) - if self.conf.get(CONF_DISCOVERY): - await _async_setup_discovery(hass, self.conf, entry) - - def init_client(self): - """Initialize paho client.""" - self._mqttc = MqttClientSetup(self.conf).client - self._mqttc.on_connect = self._mqtt_on_connect - self._mqttc.on_disconnect = self._mqtt_on_disconnect - self._mqttc.on_message = self._mqtt_on_message - self._mqttc.on_publish = self._mqtt_on_callback - self._mqttc.on_subscribe = self._mqtt_on_callback - self._mqttc.on_unsubscribe = self._mqtt_on_callback - - if ( - CONF_WILL_MESSAGE in self.conf - and ATTR_TOPIC in self.conf[CONF_WILL_MESSAGE] - ): - will_message = PublishMessage(**self.conf[CONF_WILL_MESSAGE]) - else: - will_message = None - - if will_message is not None: - self._mqttc.will_set( - topic=will_message.topic, - payload=will_message.payload, - qos=will_message.qos, - retain=will_message.retain, - ) - - async def async_publish( - self, topic: str, payload: PublishPayloadType, qos: int, retain: bool - ) -> None: - """Publish a MQTT message.""" - async with self._paho_lock: - msg_info = await self.hass.async_add_executor_job( - self._mqttc.publish, topic, payload, qos, retain - ) - _LOGGER.debug( - "Transmitting message on %s: '%s', mid: %s", - topic, - payload, - msg_info.mid, - ) - _raise_on_error(msg_info.rc) - await self._wait_for_mid(msg_info.mid) - - async def async_connect(self) -> None: - """Connect to the host. Does not process messages yet.""" - # pylint: disable-next=import-outside-toplevel - import paho.mqtt.client as mqtt - - result: int | None = None - try: - result = await self.hass.async_add_executor_job( - self._mqttc.connect, - self.conf[CONF_BROKER], - self.conf[CONF_PORT], - self.conf[CONF_KEEPALIVE], - ) - except OSError as err: - _LOGGER.error("Failed to connect to MQTT server due to exception: %s", err) - - if result is not None and result != 0: - _LOGGER.error( - "Failed to connect to MQTT server: %s", mqtt.error_string(result) - ) - - self._mqttc.loop_start() - - async def async_disconnect(self): - """Stop the MQTT client.""" - - def stop(): - """Stop the MQTT client.""" - # Do not disconnect, we want the broker to always publish will - self._mqttc.loop_stop() - - await self.hass.async_add_executor_job(stop) - - async def async_subscribe( - self, - topic: str, - msg_callback: MessageCallbackType, - qos: int, - encoding: str | None = None, - ) -> Callable[[], None]: - """Set up a subscription to a topic with the provided qos. - - This method is a coroutine. - """ - if not isinstance(topic, str): - raise HomeAssistantError("Topic needs to be a string!") - - subscription = Subscription( - topic, _matcher_for_topic(topic), HassJob(msg_callback), qos, encoding - ) - self.subscriptions.append(subscription) - self._matching_subscriptions.cache_clear() - - # Only subscribe if currently connected. - if self.connected: - self._last_subscribe = time.time() - await self._async_perform_subscription(topic, qos) - - @callback - def async_remove() -> None: - """Remove subscription.""" - if subscription not in self.subscriptions: - raise HomeAssistantError("Can't remove subscription twice") - self.subscriptions.remove(subscription) - self._matching_subscriptions.cache_clear() - - # Only unsubscribe if currently connected. - if self.connected: - self.hass.async_create_task(self._async_unsubscribe(topic)) - - return async_remove - - async def _async_unsubscribe(self, topic: str) -> None: - """Unsubscribe from a topic. - - This method is a coroutine. - """ - if any(other.topic == topic for other in self.subscriptions): - # Other subscriptions on topic remaining - don't unsubscribe. - return - - async with self._paho_lock: - result: int | None = None - result, mid = await self.hass.async_add_executor_job( - self._mqttc.unsubscribe, topic - ) - _LOGGER.debug("Unsubscribing from %s, mid: %s", topic, mid) - _raise_on_error(result) - await self._wait_for_mid(mid) - - async def _async_perform_subscription(self, topic: str, qos: int) -> None: - """Perform a paho-mqtt subscription.""" - async with self._paho_lock: - result: int | None = None - result, mid = await self.hass.async_add_executor_job( - self._mqttc.subscribe, topic, qos - ) - _LOGGER.debug("Subscribing to %s, mid: %s", topic, mid) - _raise_on_error(result) - await self._wait_for_mid(mid) - - def _mqtt_on_connect(self, _mqttc, _userdata, _flags, result_code: int) -> None: - """On connect callback. - - Resubscribe to all topics we were subscribed to and publish birth - message. - """ - # pylint: disable-next=import-outside-toplevel - import paho.mqtt.client as mqtt - - if result_code != mqtt.CONNACK_ACCEPTED: - _LOGGER.error( - "Unable to connect to the MQTT broker: %s", - mqtt.connack_string(result_code), - ) - return - - self.connected = True - dispatcher_send(self.hass, MQTT_CONNECTED) - _LOGGER.info( - "Connected to MQTT server %s:%s (%s)", - self.conf[CONF_BROKER], - self.conf[CONF_PORT], - result_code, - ) - - # Group subscriptions to only re-subscribe once for each topic. - keyfunc = attrgetter("topic") - for topic, subs in groupby(sorted(self.subscriptions, key=keyfunc), keyfunc): - # Re-subscribe with the highest requested qos - max_qos = max(subscription.qos for subscription in subs) - self.hass.add_job(self._async_perform_subscription, topic, max_qos) - - if ( - CONF_BIRTH_MESSAGE in self.conf - and ATTR_TOPIC in self.conf[CONF_BIRTH_MESSAGE] - ): - - async def publish_birth_message(birth_message): - await self._ha_started.wait() # Wait for Home Assistant to start - await self._discovery_cooldown() # Wait for MQTT discovery to cool down - await self.async_publish( - topic=birth_message.topic, - payload=birth_message.payload, - qos=birth_message.qos, - retain=birth_message.retain, - ) - - birth_message = PublishMessage(**self.conf[CONF_BIRTH_MESSAGE]) - asyncio.run_coroutine_threadsafe( - publish_birth_message(birth_message), self.hass.loop - ) - - def _mqtt_on_message(self, _mqttc, _userdata, msg) -> None: - """Message received callback.""" - self.hass.add_job(self._mqtt_handle_message, msg) - - @lru_cache(2048) - def _matching_subscriptions(self, topic): - subscriptions = [] - for subscription in self.subscriptions: - if subscription.matcher(topic): - subscriptions.append(subscription) - return subscriptions - - @callback - def _mqtt_handle_message(self, msg) -> None: - _LOGGER.debug( - "Received message on %s%s: %s", - msg.topic, - " (retained)" if msg.retain else "", - msg.payload[0:8192], - ) - timestamp = dt_util.utcnow() - - subscriptions = self._matching_subscriptions(msg.topic) - - for subscription in subscriptions: - - payload: SubscribePayloadType = msg.payload - if subscription.encoding is not None: - try: - payload = msg.payload.decode(subscription.encoding) - except (AttributeError, UnicodeDecodeError): - _LOGGER.warning( - "Can't decode payload %s on %s with encoding %s (for %s)", - msg.payload[0:8192], - msg.topic, - subscription.encoding, - subscription.job, - ) - continue - - self.hass.async_run_hass_job( - subscription.job, - ReceiveMessage( - msg.topic, - payload, - msg.qos, - msg.retain, - subscription.topic, - timestamp, - ), - ) - - def _mqtt_on_callback(self, _mqttc, _userdata, mid, _granted_qos=None) -> None: - """Publish / Subscribe / Unsubscribe callback.""" - self.hass.add_job(self._mqtt_handle_mid, mid) - - @callback - def _mqtt_handle_mid(self, mid) -> None: - # Create the mid event if not created, either _mqtt_handle_mid or _wait_for_mid - # may be executed first. - if mid not in self._pending_operations: - self._pending_operations[mid] = asyncio.Event() - self._pending_operations[mid].set() - - def _mqtt_on_disconnect(self, _mqttc, _userdata, result_code: int) -> None: - """Disconnected callback.""" - self.connected = False - dispatcher_send(self.hass, MQTT_DISCONNECTED) - _LOGGER.warning( - "Disconnected from MQTT server %s:%s (%s)", - self.conf[CONF_BROKER], - self.conf[CONF_PORT], - result_code, - ) - - async def _wait_for_mid(self, mid): - """Wait for ACK from broker.""" - # Create the mid event if not created, either _mqtt_handle_mid or _wait_for_mid - # may be executed first. - if mid not in self._pending_operations: - self._pending_operations[mid] = asyncio.Event() - try: - await asyncio.wait_for(self._pending_operations[mid].wait(), TIMEOUT_ACK) - except asyncio.TimeoutError: - _LOGGER.warning( - "No ACK from MQTT server in %s seconds (mid: %s)", TIMEOUT_ACK, mid - ) - finally: - del self._pending_operations[mid] - - async def _discovery_cooldown(self): - now = time.time() - # Reset discovery and subscribe cooldowns - self.hass.data[LAST_DISCOVERY] = now - self._last_subscribe = now - - last_discovery = self.hass.data[LAST_DISCOVERY] - last_subscribe = self._last_subscribe - wait_until = max( - last_discovery + DISCOVERY_COOLDOWN, last_subscribe + DISCOVERY_COOLDOWN - ) - while now < wait_until: - await asyncio.sleep(wait_until - now) - now = time.time() - last_discovery = self.hass.data[LAST_DISCOVERY] - last_subscribe = self._last_subscribe - wait_until = max( - last_discovery + DISCOVERY_COOLDOWN, last_subscribe + DISCOVERY_COOLDOWN - ) - - -def _raise_on_error(result_code: int | None) -> None: - """Raise error if error result.""" - # pylint: disable-next=import-outside-toplevel - import paho.mqtt.client as mqtt - - if result_code is not None and result_code != 0: - raise HomeAssistantError( - f"Error talking to MQTT: {mqtt.error_string(result_code)}" - ) - - -def _matcher_for_topic(subscription: str) -> Any: - # pylint: disable-next=import-outside-toplevel - from paho.mqtt.matcher import MQTTMatcher - - matcher = MQTTMatcher() - matcher[subscription] = True - - return lambda topic: next(matcher.iter_match(topic), False) - - @websocket_api.websocket_command( {vol.Required("type"): "mqtt/device/debug_info", vol.Required("device_id"): str} ) diff --git a/homeassistant/components/mqtt/alarm_control_panel.py b/homeassistant/components/mqtt/alarm_control_panel.py index 06c013ec7442..c20fbb7c657d 100644 --- a/homeassistant/components/mqtt/alarm_control_panel.py +++ b/homeassistant/components/mqtt/alarm_control_panel.py @@ -31,8 +31,8 @@ from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType -from . import MqttCommandTemplate, MqttValueTemplate, subscription -from .. import mqtt +from . import subscription +from .config import DEFAULT_RETAIN, MQTT_BASE_SCHEMA from .const import ( CONF_COMMAND_TEMPLATE, CONF_COMMAND_TOPIC, @@ -50,6 +50,8 @@ async_setup_platform_helper, warn_for_legacy_schema, ) +from .models import MqttCommandTemplate, MqttValueTemplate +from .util import valid_publish_topic, valid_subscribe_topic _LOGGER = logging.getLogger(__name__) @@ -85,7 +87,7 @@ REMOTE_CODE = "REMOTE_CODE" REMOTE_CODE_TEXT = "REMOTE_CODE_TEXT" -PLATFORM_SCHEMA_MODERN = mqtt.MQTT_BASE_SCHEMA.extend( +PLATFORM_SCHEMA_MODERN = MQTT_BASE_SCHEMA.extend( { vol.Optional(CONF_CODE): cv.string, vol.Optional(CONF_CODE_ARM_REQUIRED, default=True): cv.boolean, @@ -94,7 +96,7 @@ vol.Optional( CONF_COMMAND_TEMPLATE, default=DEFAULT_COMMAND_TEMPLATE ): cv.template, - vol.Required(CONF_COMMAND_TOPIC): mqtt.valid_publish_topic, + vol.Required(CONF_COMMAND_TOPIC): valid_publish_topic, vol.Optional(CONF_NAME, default=DEFAULT_NAME): cv.string, vol.Optional(CONF_PAYLOAD_ARM_AWAY, default=DEFAULT_ARM_AWAY): cv.string, vol.Optional(CONF_PAYLOAD_ARM_HOME, default=DEFAULT_ARM_HOME): cv.string, @@ -107,8 +109,8 @@ ): cv.string, vol.Optional(CONF_PAYLOAD_DISARM, default=DEFAULT_DISARM): cv.string, vol.Optional(CONF_PAYLOAD_TRIGGER, default=DEFAULT_TRIGGER): cv.string, - vol.Optional(CONF_RETAIN, default=mqtt.DEFAULT_RETAIN): cv.boolean, - vol.Required(CONF_STATE_TOPIC): mqtt.valid_subscribe_topic, + vol.Optional(CONF_RETAIN, default=DEFAULT_RETAIN): cv.boolean, + vol.Required(CONF_STATE_TOPIC): valid_subscribe_topic, vol.Optional(CONF_VALUE_TEMPLATE): cv.template, } ).extend(MQTT_ENTITY_COMMON_SCHEMA.schema) diff --git a/homeassistant/components/mqtt/binary_sensor.py b/homeassistant/components/mqtt/binary_sensor.py index b9ab190cc9bf..1cb90d6c903d 100644 --- a/homeassistant/components/mqtt/binary_sensor.py +++ b/homeassistant/components/mqtt/binary_sensor.py @@ -34,8 +34,8 @@ from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType from homeassistant.util import dt as dt_util -from . import MqttValueTemplate, subscription -from .. import mqtt +from . import subscription +from .config import MQTT_RO_SCHEMA from .const import CONF_ENCODING, CONF_QOS, CONF_STATE_TOPIC, PAYLOAD_NONE from .debug_info import log_messages from .mixins import ( @@ -47,6 +47,7 @@ async_setup_platform_helper, warn_for_legacy_schema, ) +from .models import MqttValueTemplate _LOGGER = logging.getLogger(__name__) @@ -57,7 +58,7 @@ DEFAULT_FORCE_UPDATE = False CONF_EXPIRE_AFTER = "expire_after" -PLATFORM_SCHEMA_MODERN = mqtt.MQTT_RO_SCHEMA.extend( +PLATFORM_SCHEMA_MODERN = MQTT_RO_SCHEMA.extend( { vol.Optional(CONF_DEVICE_CLASS): DEVICE_CLASSES_SCHEMA, vol.Optional(CONF_EXPIRE_AFTER): cv.positive_int, diff --git a/homeassistant/components/mqtt/button.py b/homeassistant/components/mqtt/button.py index 47e96ff3e1a1..b50856d20c16 100644 --- a/homeassistant/components/mqtt/button.py +++ b/homeassistant/components/mqtt/button.py @@ -15,8 +15,7 @@ from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType -from . import MqttCommandTemplate -from .. import mqtt +from .config import DEFAULT_RETAIN, MQTT_BASE_SCHEMA from .const import ( CONF_COMMAND_TEMPLATE, CONF_COMMAND_TOPIC, @@ -32,19 +31,21 @@ async_setup_platform_helper, warn_for_legacy_schema, ) +from .models import MqttCommandTemplate +from .util import valid_publish_topic CONF_PAYLOAD_PRESS = "payload_press" DEFAULT_NAME = "MQTT Button" DEFAULT_PAYLOAD_PRESS = "PRESS" -PLATFORM_SCHEMA_MODERN = mqtt.MQTT_BASE_SCHEMA.extend( +PLATFORM_SCHEMA_MODERN = MQTT_BASE_SCHEMA.extend( { vol.Optional(CONF_COMMAND_TEMPLATE): cv.template, - vol.Required(CONF_COMMAND_TOPIC): mqtt.valid_publish_topic, + vol.Required(CONF_COMMAND_TOPIC): valid_publish_topic, vol.Optional(CONF_DEVICE_CLASS): button.DEVICE_CLASSES_SCHEMA, vol.Optional(CONF_NAME, default=DEFAULT_NAME): cv.string, vol.Optional(CONF_PAYLOAD_PRESS, default=DEFAULT_PAYLOAD_PRESS): cv.string, - vol.Optional(CONF_RETAIN, default=mqtt.DEFAULT_RETAIN): cv.boolean, + vol.Optional(CONF_RETAIN, default=DEFAULT_RETAIN): cv.boolean, } ).extend(MQTT_ENTITY_COMMON_SCHEMA.schema) diff --git a/homeassistant/components/mqtt/camera.py b/homeassistant/components/mqtt/camera.py index 2e5d95ebda40..ae38e07d17ac 100644 --- a/homeassistant/components/mqtt/camera.py +++ b/homeassistant/components/mqtt/camera.py @@ -17,7 +17,7 @@ from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType from . import subscription -from .. import mqtt +from .config import MQTT_BASE_SCHEMA from .const import CONF_ENCODING, CONF_QOS, CONF_TOPIC from .debug_info import log_messages from .mixins import ( @@ -28,6 +28,7 @@ async_setup_platform_helper, warn_for_legacy_schema, ) +from .util import valid_subscribe_topic DEFAULT_NAME = "MQTT Camera" @@ -40,10 +41,10 @@ } ) -PLATFORM_SCHEMA_MODERN = mqtt.MQTT_BASE_SCHEMA.extend( +PLATFORM_SCHEMA_MODERN = MQTT_BASE_SCHEMA.extend( { vol.Optional(CONF_NAME, default=DEFAULT_NAME): cv.string, - vol.Required(CONF_TOPIC): mqtt.valid_subscribe_topic, + vol.Required(CONF_TOPIC): valid_subscribe_topic, } ).extend(MQTT_ENTITY_COMMON_SCHEMA.schema) diff --git a/homeassistant/components/mqtt/client.py b/homeassistant/components/mqtt/client.py new file mode 100644 index 000000000000..666993725165 --- /dev/null +++ b/homeassistant/components/mqtt/client.py @@ -0,0 +1,659 @@ +"""Support for MQTT message handling.""" +from __future__ import annotations + +import asyncio +from collections.abc import Awaitable, Callable +from functools import lru_cache, partial, wraps +import inspect +from itertools import groupby +import logging +from operator import attrgetter +import ssl +import time +from typing import TYPE_CHECKING, Any, Union, cast +import uuid + +import attr +import certifi + +from homeassistant.const import ( + CONF_CLIENT_ID, + CONF_PASSWORD, + CONF_PORT, + CONF_PROTOCOL, + CONF_USERNAME, + EVENT_HOMEASSISTANT_STARTED, +) +from homeassistant.core import CoreState, HassJob, HomeAssistant, callback +from homeassistant.exceptions import HomeAssistantError +from homeassistant.helpers.dispatcher import dispatcher_send +from homeassistant.helpers.typing import ConfigType +from homeassistant.loader import bind_hass +from homeassistant.util import dt as dt_util +from homeassistant.util.async_ import run_callback_threadsafe +from homeassistant.util.logging import catch_log_exception + +from .const import ( + ATTR_TOPIC, + CONF_BIRTH_MESSAGE, + CONF_BROKER, + CONF_CERTIFICATE, + CONF_CLIENT_CERT, + CONF_CLIENT_KEY, + CONF_KEEPALIVE, + CONF_TLS_INSECURE, + CONF_WILL_MESSAGE, + DATA_MQTT, + DEFAULT_ENCODING, + DEFAULT_QOS, + MQTT_CONNECTED, + MQTT_DISCONNECTED, + PROTOCOL_31, +) +from .discovery import LAST_DISCOVERY +from .models import ( + AsyncMessageCallbackType, + MessageCallbackType, + PublishMessage, + PublishPayloadType, + ReceiveMessage, + ReceivePayloadType, +) + +if TYPE_CHECKING: + # Only import for paho-mqtt type checking here, imports are done locally + # because integrations should be able to optionally rely on MQTT. + import paho.mqtt.client as mqtt + +_LOGGER = logging.getLogger(__name__) + +DISCOVERY_COOLDOWN = 2 +TIMEOUT_ACK = 10 + +SubscribePayloadType = Union[str, bytes] # Only bytes if encoding is None + + +def publish( + hass: HomeAssistant, + topic: str, + payload: PublishPayloadType, + qos: int | None = 0, + retain: bool | None = False, + encoding: str | None = DEFAULT_ENCODING, +) -> None: + """Publish message to a MQTT topic.""" + hass.add_job(async_publish, hass, topic, payload, qos, retain, encoding) + + +async def async_publish( + hass: HomeAssistant, + topic: str, + payload: PublishPayloadType, + qos: int | None = 0, + retain: bool | None = False, + encoding: str | None = DEFAULT_ENCODING, +) -> None: + """Publish message to a MQTT topic.""" + + outgoing_payload = payload + if not isinstance(payload, bytes): + if not encoding: + _LOGGER.error( + "Can't pass-through payload for publishing %s on %s with no encoding set, need 'bytes' got %s", + payload, + topic, + type(payload), + ) + return + outgoing_payload = str(payload) + if encoding != DEFAULT_ENCODING: + # a string is encoded as utf-8 by default, other encoding requires bytes as payload + try: + outgoing_payload = outgoing_payload.encode(encoding) + except (AttributeError, LookupError, UnicodeEncodeError): + _LOGGER.error( + "Can't encode payload for publishing %s on %s with encoding %s", + payload, + topic, + encoding, + ) + return + + await hass.data[DATA_MQTT].async_publish(topic, outgoing_payload, qos, retain) + + +AsyncDeprecatedMessageCallbackType = Callable[ + [str, ReceivePayloadType, int], Awaitable[None] +] +DeprecatedMessageCallbackType = Callable[[str, ReceivePayloadType, int], None] + + +def wrap_msg_callback( + msg_callback: AsyncDeprecatedMessageCallbackType | DeprecatedMessageCallbackType, +) -> AsyncMessageCallbackType | MessageCallbackType: + """Wrap an MQTT message callback to support deprecated signature.""" + # Check for partials to properly determine if coroutine function + check_func = msg_callback + while isinstance(check_func, partial): + check_func = check_func.func + + wrapper_func: AsyncMessageCallbackType | MessageCallbackType + if asyncio.iscoroutinefunction(check_func): + + @wraps(msg_callback) + async def async_wrapper(msg: ReceiveMessage) -> None: + """Call with deprecated signature.""" + await cast(AsyncDeprecatedMessageCallbackType, msg_callback)( + msg.topic, msg.payload, msg.qos + ) + + wrapper_func = async_wrapper + else: + + @wraps(msg_callback) + def wrapper(msg: ReceiveMessage) -> None: + """Call with deprecated signature.""" + msg_callback(msg.topic, msg.payload, msg.qos) + + wrapper_func = wrapper + return wrapper_func + + +@bind_hass +async def async_subscribe( + hass: HomeAssistant, + topic: str, + msg_callback: AsyncMessageCallbackType + | MessageCallbackType + | DeprecatedMessageCallbackType + | AsyncDeprecatedMessageCallbackType, + qos: int = DEFAULT_QOS, + encoding: str | None = "utf-8", +): + """Subscribe to an MQTT topic. + + Call the return value to unsubscribe. + """ + # Count callback parameters which don't have a default value + non_default = 0 + if msg_callback: + non_default = sum( + p.default == inspect.Parameter.empty + for _, p in inspect.signature(msg_callback).parameters.items() + ) + + wrapped_msg_callback = msg_callback + # If we have 3 parameters with no default value, wrap the callback + if non_default == 3: + module = inspect.getmodule(msg_callback) + _LOGGER.warning( + "Signature of MQTT msg_callback '%s.%s' is deprecated", + module.__name__ if module else "", + msg_callback.__name__, + ) + wrapped_msg_callback = wrap_msg_callback( + cast(DeprecatedMessageCallbackType, msg_callback) + ) + + async_remove = await hass.data[DATA_MQTT].async_subscribe( + topic, + catch_log_exception( + wrapped_msg_callback, + lambda msg: ( + f"Exception in {msg_callback.__name__} when handling msg on " + f"'{msg.topic}': '{msg.payload}'" + ), + ), + qos, + encoding, + ) + return async_remove + + +@bind_hass +def subscribe( + hass: HomeAssistant, + topic: str, + msg_callback: MessageCallbackType, + qos: int = DEFAULT_QOS, + encoding: str = "utf-8", +) -> Callable[[], None]: + """Subscribe to an MQTT topic.""" + async_remove = asyncio.run_coroutine_threadsafe( + async_subscribe(hass, topic, msg_callback, qos, encoding), hass.loop + ).result() + + def remove(): + """Remove listener convert.""" + run_callback_threadsafe(hass.loop, async_remove).result() + + return remove + + +@attr.s(slots=True, frozen=True) +class Subscription: + """Class to hold data about an active subscription.""" + + topic: str = attr.ib() + matcher: Any = attr.ib() + job: HassJob = attr.ib() + qos: int = attr.ib(default=0) + encoding: str | None = attr.ib(default="utf-8") + + +class MqttClientSetup: + """Helper class to setup the paho mqtt client from config.""" + + def __init__(self, config: ConfigType) -> None: + """Initialize the MQTT client setup helper.""" + + # We don't import on the top because some integrations + # should be able to optionally rely on MQTT. + import paho.mqtt.client as mqtt # pylint: disable=import-outside-toplevel + + if config[CONF_PROTOCOL] == PROTOCOL_31: + proto = mqtt.MQTTv31 + else: + proto = mqtt.MQTTv311 + + if (client_id := config.get(CONF_CLIENT_ID)) is None: + # PAHO MQTT relies on the MQTT server to generate random client IDs. + # However, that feature is not mandatory so we generate our own. + client_id = mqtt.base62(uuid.uuid4().int, padding=22) + self._client = mqtt.Client(client_id, protocol=proto) + + # Enable logging + self._client.enable_logger() + + username = config.get(CONF_USERNAME) + password = config.get(CONF_PASSWORD) + if username is not None: + self._client.username_pw_set(username, password) + + if (certificate := config.get(CONF_CERTIFICATE)) == "auto": + certificate = certifi.where() + + client_key = config.get(CONF_CLIENT_KEY) + client_cert = config.get(CONF_CLIENT_CERT) + tls_insecure = config.get(CONF_TLS_INSECURE) + if certificate is not None: + self._client.tls_set( + certificate, + certfile=client_cert, + keyfile=client_key, + tls_version=ssl.PROTOCOL_TLS, + ) + + if tls_insecure is not None: + self._client.tls_insecure_set(tls_insecure) + + @property + def client(self) -> mqtt.Client: + """Return the paho MQTT client.""" + return self._client + + +class MQTT: + """Home Assistant MQTT client.""" + + def __init__( + self, + hass: HomeAssistant, + config_entry, + conf, + ) -> None: + """Initialize Home Assistant MQTT client.""" + # We don't import on the top because some integrations + # should be able to optionally rely on MQTT. + import paho.mqtt.client as mqtt # pylint: disable=import-outside-toplevel + + self.hass = hass + self.config_entry = config_entry + self.conf = conf + self.subscriptions: list[Subscription] = [] + self.connected = False + self._ha_started = asyncio.Event() + self._last_subscribe = time.time() + self._mqttc: mqtt.Client = None + self._paho_lock = asyncio.Lock() + + self._pending_operations: dict[str, asyncio.Event] = {} + + if self.hass.state == CoreState.running: + self._ha_started.set() + else: + + @callback + def ha_started(_): + self._ha_started.set() + + self.hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STARTED, ha_started) + + self.init_client() + + def init_client(self): + """Initialize paho client.""" + self._mqttc = MqttClientSetup(self.conf).client + self._mqttc.on_connect = self._mqtt_on_connect + self._mqttc.on_disconnect = self._mqtt_on_disconnect + self._mqttc.on_message = self._mqtt_on_message + self._mqttc.on_publish = self._mqtt_on_callback + self._mqttc.on_subscribe = self._mqtt_on_callback + self._mqttc.on_unsubscribe = self._mqtt_on_callback + + if ( + CONF_WILL_MESSAGE in self.conf + and ATTR_TOPIC in self.conf[CONF_WILL_MESSAGE] + ): + will_message = PublishMessage(**self.conf[CONF_WILL_MESSAGE]) + else: + will_message = None + + if will_message is not None: + self._mqttc.will_set( + topic=will_message.topic, + payload=will_message.payload, + qos=will_message.qos, + retain=will_message.retain, + ) + + async def async_publish( + self, topic: str, payload: PublishPayloadType, qos: int, retain: bool + ) -> None: + """Publish a MQTT message.""" + async with self._paho_lock: + msg_info = await self.hass.async_add_executor_job( + self._mqttc.publish, topic, payload, qos, retain + ) + _LOGGER.debug( + "Transmitting message on %s: '%s', mid: %s", + topic, + payload, + msg_info.mid, + ) + _raise_on_error(msg_info.rc) + await self._wait_for_mid(msg_info.mid) + + async def async_connect(self) -> None: + """Connect to the host. Does not process messages yet.""" + # pylint: disable-next=import-outside-toplevel + import paho.mqtt.client as mqtt + + result: int | None = None + try: + result = await self.hass.async_add_executor_job( + self._mqttc.connect, + self.conf[CONF_BROKER], + self.conf[CONF_PORT], + self.conf[CONF_KEEPALIVE], + ) + except OSError as err: + _LOGGER.error("Failed to connect to MQTT server due to exception: %s", err) + + if result is not None and result != 0: + _LOGGER.error( + "Failed to connect to MQTT server: %s", mqtt.error_string(result) + ) + + self._mqttc.loop_start() + + async def async_disconnect(self): + """Stop the MQTT client.""" + + def stop(): + """Stop the MQTT client.""" + # Do not disconnect, we want the broker to always publish will + self._mqttc.loop_stop() + + await self.hass.async_add_executor_job(stop) + + async def async_subscribe( + self, + topic: str, + msg_callback: MessageCallbackType, + qos: int, + encoding: str | None = None, + ) -> Callable[[], None]: + """Set up a subscription to a topic with the provided qos. + + This method is a coroutine. + """ + if not isinstance(topic, str): + raise HomeAssistantError("Topic needs to be a string!") + + subscription = Subscription( + topic, _matcher_for_topic(topic), HassJob(msg_callback), qos, encoding + ) + self.subscriptions.append(subscription) + self._matching_subscriptions.cache_clear() + + # Only subscribe if currently connected. + if self.connected: + self._last_subscribe = time.time() + await self._async_perform_subscription(topic, qos) + + @callback + def async_remove() -> None: + """Remove subscription.""" + if subscription not in self.subscriptions: + raise HomeAssistantError("Can't remove subscription twice") + self.subscriptions.remove(subscription) + self._matching_subscriptions.cache_clear() + + # Only unsubscribe if currently connected. + if self.connected: + self.hass.async_create_task(self._async_unsubscribe(topic)) + + return async_remove + + async def _async_unsubscribe(self, topic: str) -> None: + """Unsubscribe from a topic. + + This method is a coroutine. + """ + if any(other.topic == topic for other in self.subscriptions): + # Other subscriptions on topic remaining - don't unsubscribe. + return + + async with self._paho_lock: + result: int | None = None + result, mid = await self.hass.async_add_executor_job( + self._mqttc.unsubscribe, topic + ) + _LOGGER.debug("Unsubscribing from %s, mid: %s", topic, mid) + _raise_on_error(result) + await self._wait_for_mid(mid) + + async def _async_perform_subscription(self, topic: str, qos: int) -> None: + """Perform a paho-mqtt subscription.""" + async with self._paho_lock: + result: int | None = None + result, mid = await self.hass.async_add_executor_job( + self._mqttc.subscribe, topic, qos + ) + _LOGGER.debug("Subscribing to %s, mid: %s", topic, mid) + _raise_on_error(result) + await self._wait_for_mid(mid) + + def _mqtt_on_connect(self, _mqttc, _userdata, _flags, result_code: int) -> None: + """On connect callback. + + Resubscribe to all topics we were subscribed to and publish birth + message. + """ + # pylint: disable-next=import-outside-toplevel + import paho.mqtt.client as mqtt + + if result_code != mqtt.CONNACK_ACCEPTED: + _LOGGER.error( + "Unable to connect to the MQTT broker: %s", + mqtt.connack_string(result_code), + ) + return + + self.connected = True + dispatcher_send(self.hass, MQTT_CONNECTED) + _LOGGER.info( + "Connected to MQTT server %s:%s (%s)", + self.conf[CONF_BROKER], + self.conf[CONF_PORT], + result_code, + ) + + # Group subscriptions to only re-subscribe once for each topic. + keyfunc = attrgetter("topic") + for topic, subs in groupby(sorted(self.subscriptions, key=keyfunc), keyfunc): + # Re-subscribe with the highest requested qos + max_qos = max(subscription.qos for subscription in subs) + self.hass.add_job(self._async_perform_subscription, topic, max_qos) + + if ( + CONF_BIRTH_MESSAGE in self.conf + and ATTR_TOPIC in self.conf[CONF_BIRTH_MESSAGE] + ): + + async def publish_birth_message(birth_message): + await self._ha_started.wait() # Wait for Home Assistant to start + await self._discovery_cooldown() # Wait for MQTT discovery to cool down + await self.async_publish( + topic=birth_message.topic, + payload=birth_message.payload, + qos=birth_message.qos, + retain=birth_message.retain, + ) + + birth_message = PublishMessage(**self.conf[CONF_BIRTH_MESSAGE]) + asyncio.run_coroutine_threadsafe( + publish_birth_message(birth_message), self.hass.loop + ) + + def _mqtt_on_message(self, _mqttc, _userdata, msg) -> None: + """Message received callback.""" + self.hass.add_job(self._mqtt_handle_message, msg) + + @lru_cache(2048) + def _matching_subscriptions(self, topic): + subscriptions = [] + for subscription in self.subscriptions: + if subscription.matcher(topic): + subscriptions.append(subscription) + return subscriptions + + @callback + def _mqtt_handle_message(self, msg) -> None: + _LOGGER.debug( + "Received message on %s%s: %s", + msg.topic, + " (retained)" if msg.retain else "", + msg.payload[0:8192], + ) + timestamp = dt_util.utcnow() + + subscriptions = self._matching_subscriptions(msg.topic) + + for subscription in subscriptions: + + payload: SubscribePayloadType = msg.payload + if subscription.encoding is not None: + try: + payload = msg.payload.decode(subscription.encoding) + except (AttributeError, UnicodeDecodeError): + _LOGGER.warning( + "Can't decode payload %s on %s with encoding %s (for %s)", + msg.payload[0:8192], + msg.topic, + subscription.encoding, + subscription.job, + ) + continue + + self.hass.async_run_hass_job( + subscription.job, + ReceiveMessage( + msg.topic, + payload, + msg.qos, + msg.retain, + subscription.topic, + timestamp, + ), + ) + + def _mqtt_on_callback(self, _mqttc, _userdata, mid, _granted_qos=None) -> None: + """Publish / Subscribe / Unsubscribe callback.""" + self.hass.add_job(self._mqtt_handle_mid, mid) + + @callback + def _mqtt_handle_mid(self, mid) -> None: + # Create the mid event if not created, either _mqtt_handle_mid or _wait_for_mid + # may be executed first. + if mid not in self._pending_operations: + self._pending_operations[mid] = asyncio.Event() + self._pending_operations[mid].set() + + def _mqtt_on_disconnect(self, _mqttc, _userdata, result_code: int) -> None: + """Disconnected callback.""" + self.connected = False + dispatcher_send(self.hass, MQTT_DISCONNECTED) + _LOGGER.warning( + "Disconnected from MQTT server %s:%s (%s)", + self.conf[CONF_BROKER], + self.conf[CONF_PORT], + result_code, + ) + + async def _wait_for_mid(self, mid): + """Wait for ACK from broker.""" + # Create the mid event if not created, either _mqtt_handle_mid or _wait_for_mid + # may be executed first. + if mid not in self._pending_operations: + self._pending_operations[mid] = asyncio.Event() + try: + await asyncio.wait_for(self._pending_operations[mid].wait(), TIMEOUT_ACK) + except asyncio.TimeoutError: + _LOGGER.warning( + "No ACK from MQTT server in %s seconds (mid: %s)", TIMEOUT_ACK, mid + ) + finally: + del self._pending_operations[mid] + + async def _discovery_cooldown(self): + now = time.time() + # Reset discovery and subscribe cooldowns + self.hass.data[LAST_DISCOVERY] = now + self._last_subscribe = now + + last_discovery = self.hass.data[LAST_DISCOVERY] + last_subscribe = self._last_subscribe + wait_until = max( + last_discovery + DISCOVERY_COOLDOWN, last_subscribe + DISCOVERY_COOLDOWN + ) + while now < wait_until: + await asyncio.sleep(wait_until - now) + now = time.time() + last_discovery = self.hass.data[LAST_DISCOVERY] + last_subscribe = self._last_subscribe + wait_until = max( + last_discovery + DISCOVERY_COOLDOWN, last_subscribe + DISCOVERY_COOLDOWN + ) + + +def _raise_on_error(result_code: int | None) -> None: + """Raise error if error result.""" + # pylint: disable-next=import-outside-toplevel + import paho.mqtt.client as mqtt + + if result_code is not None and result_code != 0: + raise HomeAssistantError( + f"Error talking to MQTT: {mqtt.error_string(result_code)}" + ) + + +def _matcher_for_topic(subscription: str) -> Any: + # pylint: disable-next=import-outside-toplevel + from paho.mqtt.matcher import MQTTMatcher + + matcher = MQTTMatcher() + matcher[subscription] = True + + return lambda topic: next(matcher.iter_match(topic), False) diff --git a/homeassistant/components/mqtt/climate.py b/homeassistant/components/mqtt/climate.py index 52465bbba242..64b462359be7 100644 --- a/homeassistant/components/mqtt/climate.py +++ b/homeassistant/components/mqtt/climate.py @@ -44,8 +44,8 @@ from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType -from . import MqttCommandTemplate, MqttValueTemplate, subscription -from .. import mqtt +from . import subscription +from .config import DEFAULT_RETAIN, MQTT_BASE_SCHEMA from .const import CONF_ENCODING, CONF_QOS, CONF_RETAIN, PAYLOAD_NONE from .debug_info import log_messages from .mixins import ( @@ -56,6 +56,8 @@ async_setup_platform_helper, warn_for_legacy_schema, ) +from .models import MqttCommandTemplate, MqttValueTemplate +from .util import valid_publish_topic, valid_subscribe_topic _LOGGER = logging.getLogger(__name__) @@ -232,33 +234,33 @@ def valid_preset_mode_configuration(config): return config -_PLATFORM_SCHEMA_BASE = mqtt.MQTT_BASE_SCHEMA.extend( +_PLATFORM_SCHEMA_BASE = MQTT_BASE_SCHEMA.extend( { - vol.Optional(CONF_AUX_COMMAND_TOPIC): mqtt.valid_publish_topic, + vol.Optional(CONF_AUX_COMMAND_TOPIC): valid_publish_topic, vol.Optional(CONF_AUX_STATE_TEMPLATE): cv.template, - vol.Optional(CONF_AUX_STATE_TOPIC): mqtt.valid_subscribe_topic, + vol.Optional(CONF_AUX_STATE_TOPIC): valid_subscribe_topic, # AWAY and HOLD mode topics and templates are deprecated, support will be removed with release 2022.9 - vol.Optional(CONF_AWAY_MODE_COMMAND_TOPIC): mqtt.valid_publish_topic, + vol.Optional(CONF_AWAY_MODE_COMMAND_TOPIC): valid_publish_topic, vol.Optional(CONF_AWAY_MODE_STATE_TEMPLATE): cv.template, - vol.Optional(CONF_AWAY_MODE_STATE_TOPIC): mqtt.valid_subscribe_topic, + vol.Optional(CONF_AWAY_MODE_STATE_TOPIC): valid_subscribe_topic, vol.Optional(CONF_CURRENT_TEMP_TEMPLATE): cv.template, - vol.Optional(CONF_CURRENT_TEMP_TOPIC): mqtt.valid_subscribe_topic, + vol.Optional(CONF_CURRENT_TEMP_TOPIC): valid_subscribe_topic, vol.Optional(CONF_FAN_MODE_COMMAND_TEMPLATE): cv.template, - vol.Optional(CONF_FAN_MODE_COMMAND_TOPIC): mqtt.valid_publish_topic, + vol.Optional(CONF_FAN_MODE_COMMAND_TOPIC): valid_publish_topic, vol.Optional( CONF_FAN_MODE_LIST, default=[FAN_AUTO, FAN_LOW, FAN_MEDIUM, FAN_HIGH], ): cv.ensure_list, vol.Optional(CONF_FAN_MODE_STATE_TEMPLATE): cv.template, - vol.Optional(CONF_FAN_MODE_STATE_TOPIC): mqtt.valid_subscribe_topic, + vol.Optional(CONF_FAN_MODE_STATE_TOPIC): valid_subscribe_topic, # AWAY and HOLD mode topics and templates are deprecated, support will be removed with release 2022.9 vol.Optional(CONF_HOLD_COMMAND_TEMPLATE): cv.template, - vol.Optional(CONF_HOLD_COMMAND_TOPIC): mqtt.valid_publish_topic, + vol.Optional(CONF_HOLD_COMMAND_TOPIC): valid_publish_topic, vol.Optional(CONF_HOLD_STATE_TEMPLATE): cv.template, - vol.Optional(CONF_HOLD_STATE_TOPIC): mqtt.valid_subscribe_topic, + vol.Optional(CONF_HOLD_STATE_TOPIC): valid_subscribe_topic, vol.Optional(CONF_HOLD_LIST): cv.ensure_list, vol.Optional(CONF_MODE_COMMAND_TEMPLATE): cv.template, - vol.Optional(CONF_MODE_COMMAND_TOPIC): mqtt.valid_publish_topic, + vol.Optional(CONF_MODE_COMMAND_TOPIC): valid_publish_topic, vol.Optional( CONF_MODE_LIST, default=[ @@ -271,54 +273,54 @@ def valid_preset_mode_configuration(config): ], ): cv.ensure_list, vol.Optional(CONF_MODE_STATE_TEMPLATE): cv.template, - vol.Optional(CONF_MODE_STATE_TOPIC): mqtt.valid_subscribe_topic, + vol.Optional(CONF_MODE_STATE_TOPIC): valid_subscribe_topic, vol.Optional(CONF_NAME, default=DEFAULT_NAME): cv.string, vol.Optional(CONF_PAYLOAD_ON, default="ON"): cv.string, vol.Optional(CONF_PAYLOAD_OFF, default="OFF"): cv.string, - vol.Optional(CONF_POWER_COMMAND_TOPIC): mqtt.valid_publish_topic, + vol.Optional(CONF_POWER_COMMAND_TOPIC): valid_publish_topic, vol.Optional(CONF_POWER_STATE_TEMPLATE): cv.template, - vol.Optional(CONF_POWER_STATE_TOPIC): mqtt.valid_subscribe_topic, + vol.Optional(CONF_POWER_STATE_TOPIC): valid_subscribe_topic, vol.Optional(CONF_PRECISION): vol.In( [PRECISION_TENTHS, PRECISION_HALVES, PRECISION_WHOLE] ), - vol.Optional(CONF_RETAIN, default=mqtt.DEFAULT_RETAIN): cv.boolean, + vol.Optional(CONF_RETAIN, default=DEFAULT_RETAIN): cv.boolean, # CONF_SEND_IF_OFF is deprecated, support will be removed with release 2022.9 vol.Optional(CONF_SEND_IF_OFF): cv.boolean, vol.Optional(CONF_ACTION_TEMPLATE): cv.template, - vol.Optional(CONF_ACTION_TOPIC): mqtt.valid_subscribe_topic, + vol.Optional(CONF_ACTION_TOPIC): valid_subscribe_topic, # CONF_PRESET_MODE_COMMAND_TOPIC and CONF_PRESET_MODES_LIST must be used together vol.Inclusive( CONF_PRESET_MODE_COMMAND_TOPIC, "preset_modes" - ): mqtt.valid_publish_topic, + ): valid_publish_topic, vol.Inclusive( CONF_PRESET_MODES_LIST, "preset_modes", default=[] ): cv.ensure_list, vol.Optional(CONF_PRESET_MODE_COMMAND_TEMPLATE): cv.template, - vol.Optional(CONF_PRESET_MODE_STATE_TOPIC): mqtt.valid_subscribe_topic, + vol.Optional(CONF_PRESET_MODE_STATE_TOPIC): valid_subscribe_topic, vol.Optional(CONF_PRESET_MODE_VALUE_TEMPLATE): cv.template, vol.Optional(CONF_SWING_MODE_COMMAND_TEMPLATE): cv.template, - vol.Optional(CONF_SWING_MODE_COMMAND_TOPIC): mqtt.valid_publish_topic, + vol.Optional(CONF_SWING_MODE_COMMAND_TOPIC): valid_publish_topic, vol.Optional( CONF_SWING_MODE_LIST, default=[SWING_ON, SWING_OFF] ): cv.ensure_list, vol.Optional(CONF_SWING_MODE_STATE_TEMPLATE): cv.template, - vol.Optional(CONF_SWING_MODE_STATE_TOPIC): mqtt.valid_subscribe_topic, + vol.Optional(CONF_SWING_MODE_STATE_TOPIC): valid_subscribe_topic, vol.Optional(CONF_TEMP_INITIAL, default=21): cv.positive_int, vol.Optional(CONF_TEMP_MIN, default=DEFAULT_MIN_TEMP): vol.Coerce(float), vol.Optional(CONF_TEMP_MAX, default=DEFAULT_MAX_TEMP): vol.Coerce(float), vol.Optional(CONF_TEMP_STEP, default=1.0): vol.Coerce(float), vol.Optional(CONF_TEMP_COMMAND_TEMPLATE): cv.template, - vol.Optional(CONF_TEMP_COMMAND_TOPIC): mqtt.valid_publish_topic, + vol.Optional(CONF_TEMP_COMMAND_TOPIC): valid_publish_topic, vol.Optional(CONF_TEMP_HIGH_COMMAND_TEMPLATE): cv.template, - vol.Optional(CONF_TEMP_HIGH_COMMAND_TOPIC): mqtt.valid_publish_topic, - vol.Optional(CONF_TEMP_HIGH_STATE_TOPIC): mqtt.valid_subscribe_topic, + vol.Optional(CONF_TEMP_HIGH_COMMAND_TOPIC): valid_publish_topic, + vol.Optional(CONF_TEMP_HIGH_STATE_TOPIC): valid_subscribe_topic, vol.Optional(CONF_TEMP_HIGH_STATE_TEMPLATE): cv.template, vol.Optional(CONF_TEMP_LOW_COMMAND_TEMPLATE): cv.template, - vol.Optional(CONF_TEMP_LOW_COMMAND_TOPIC): mqtt.valid_publish_topic, + vol.Optional(CONF_TEMP_LOW_COMMAND_TOPIC): valid_publish_topic, vol.Optional(CONF_TEMP_LOW_STATE_TEMPLATE): cv.template, - vol.Optional(CONF_TEMP_LOW_STATE_TOPIC): mqtt.valid_subscribe_topic, + vol.Optional(CONF_TEMP_LOW_STATE_TOPIC): valid_subscribe_topic, vol.Optional(CONF_TEMP_STATE_TEMPLATE): cv.template, - vol.Optional(CONF_TEMP_STATE_TOPIC): mqtt.valid_subscribe_topic, + vol.Optional(CONF_TEMP_STATE_TOPIC): valid_subscribe_topic, vol.Optional(CONF_TEMPERATURE_UNIT): cv.temperature_unit, vol.Optional(CONF_VALUE_TEMPLATE): cv.template, } diff --git a/homeassistant/components/mqtt/config.py b/homeassistant/components/mqtt/config.py new file mode 100644 index 000000000000..4f84d9114188 --- /dev/null +++ b/homeassistant/components/mqtt/config.py @@ -0,0 +1,148 @@ +"""Support for MQTT message handling.""" +from __future__ import annotations + +import voluptuous as vol + +from homeassistant.const import ( + CONF_CLIENT_ID, + CONF_DISCOVERY, + CONF_PASSWORD, + CONF_PORT, + CONF_PROTOCOL, + CONF_USERNAME, + CONF_VALUE_TEMPLATE, +) +from homeassistant.helpers import config_validation as cv + +from .const import ( + ATTR_PAYLOAD, + ATTR_QOS, + ATTR_RETAIN, + ATTR_TOPIC, + CONF_BIRTH_MESSAGE, + CONF_BROKER, + CONF_CERTIFICATE, + CONF_CLIENT_CERT, + CONF_CLIENT_KEY, + CONF_COMMAND_TOPIC, + CONF_DISCOVERY_PREFIX, + CONF_ENCODING, + CONF_KEEPALIVE, + CONF_QOS, + CONF_RETAIN, + CONF_STATE_TOPIC, + CONF_TLS_INSECURE, + CONF_TLS_VERSION, + CONF_WILL_MESSAGE, + DEFAULT_BIRTH, + DEFAULT_DISCOVERY, + DEFAULT_ENCODING, + DEFAULT_PREFIX, + DEFAULT_QOS, + DEFAULT_RETAIN, + DEFAULT_WILL, + PLATFORMS, + PROTOCOL_31, + PROTOCOL_311, +) +from .util import _VALID_QOS_SCHEMA, valid_publish_topic, valid_subscribe_topic + +DEFAULT_PORT = 1883 +DEFAULT_KEEPALIVE = 60 +DEFAULT_PROTOCOL = PROTOCOL_311 +DEFAULT_TLS_PROTOCOL = "auto" + +DEFAULT_VALUES = { + CONF_BIRTH_MESSAGE: DEFAULT_BIRTH, + CONF_DISCOVERY: DEFAULT_DISCOVERY, + CONF_PORT: DEFAULT_PORT, + CONF_TLS_VERSION: DEFAULT_TLS_PROTOCOL, + CONF_WILL_MESSAGE: DEFAULT_WILL, +} + +CLIENT_KEY_AUTH_MSG = ( + "client_key and client_cert must both be present in " + "the MQTT broker configuration" +) + +MQTT_WILL_BIRTH_SCHEMA = vol.Schema( + { + vol.Inclusive(ATTR_TOPIC, "topic_payload"): valid_publish_topic, + vol.Inclusive(ATTR_PAYLOAD, "topic_payload"): cv.string, + vol.Optional(ATTR_QOS, default=DEFAULT_QOS): _VALID_QOS_SCHEMA, + vol.Optional(ATTR_RETAIN, default=DEFAULT_RETAIN): cv.boolean, + }, + required=True, +) + +PLATFORM_CONFIG_SCHEMA_BASE = vol.Schema( + {vol.Optional(platform.value): cv.ensure_list for platform in PLATFORMS} +) + +CONFIG_SCHEMA_BASE = PLATFORM_CONFIG_SCHEMA_BASE.extend( + { + vol.Optional(CONF_CLIENT_ID): cv.string, + vol.Optional(CONF_KEEPALIVE, default=DEFAULT_KEEPALIVE): vol.All( + vol.Coerce(int), vol.Range(min=15) + ), + vol.Optional(CONF_BROKER): cv.string, + vol.Optional(CONF_PORT): cv.port, + vol.Optional(CONF_USERNAME): cv.string, + vol.Optional(CONF_PASSWORD): cv.string, + vol.Optional(CONF_CERTIFICATE): vol.Any("auto", cv.isfile), + vol.Inclusive( + CONF_CLIENT_KEY, "client_key_auth", msg=CLIENT_KEY_AUTH_MSG + ): cv.isfile, + vol.Inclusive( + CONF_CLIENT_CERT, "client_key_auth", msg=CLIENT_KEY_AUTH_MSG + ): cv.isfile, + vol.Optional(CONF_TLS_INSECURE): cv.boolean, + vol.Optional(CONF_TLS_VERSION): vol.Any("auto", "1.0", "1.1", "1.2"), + vol.Optional(CONF_PROTOCOL, default=DEFAULT_PROTOCOL): vol.All( + cv.string, vol.In([PROTOCOL_31, PROTOCOL_311]) + ), + vol.Optional(CONF_WILL_MESSAGE): MQTT_WILL_BIRTH_SCHEMA, + vol.Optional(CONF_BIRTH_MESSAGE): MQTT_WILL_BIRTH_SCHEMA, + vol.Optional(CONF_DISCOVERY): cv.boolean, + # discovery_prefix must be a valid publish topic because if no + # state topic is specified, it will be created with the given prefix. + vol.Optional( + CONF_DISCOVERY_PREFIX, default=DEFAULT_PREFIX + ): valid_publish_topic, + } +) + +DEPRECATED_CONFIG_KEYS = [ + CONF_BIRTH_MESSAGE, + CONF_BROKER, + CONF_DISCOVERY, + CONF_PASSWORD, + CONF_PORT, + CONF_TLS_VERSION, + CONF_USERNAME, + CONF_WILL_MESSAGE, +] + +SCHEMA_BASE = { + vol.Optional(CONF_QOS, default=DEFAULT_QOS): _VALID_QOS_SCHEMA, + vol.Optional(CONF_ENCODING, default=DEFAULT_ENCODING): cv.string, +} + +MQTT_BASE_SCHEMA = vol.Schema(SCHEMA_BASE) + +# Sensor type platforms subscribe to MQTT events +MQTT_RO_SCHEMA = MQTT_BASE_SCHEMA.extend( + { + vol.Required(CONF_STATE_TOPIC): valid_subscribe_topic, + vol.Optional(CONF_VALUE_TEMPLATE): cv.template, + } +) + +# Switch type platforms publish to MQTT and may subscribe +MQTT_RW_SCHEMA = MQTT_BASE_SCHEMA.extend( + { + vol.Required(CONF_COMMAND_TOPIC): valid_publish_topic, + vol.Optional(CONF_RETAIN, default=DEFAULT_RETAIN): cv.boolean, + vol.Optional(CONF_STATE_TOPIC): valid_subscribe_topic, + } +) diff --git a/homeassistant/components/mqtt/config_flow.py b/homeassistant/components/mqtt/config_flow.py index 0a763e850e54..822ae7125730 100644 --- a/homeassistant/components/mqtt/config_flow.py +++ b/homeassistant/components/mqtt/config_flow.py @@ -17,7 +17,7 @@ ) from homeassistant.data_entry_flow import FlowResult -from . import MqttClientSetup +from .client import MqttClientSetup from .const import ( ATTR_PAYLOAD, ATTR_QOS, diff --git a/homeassistant/components/mqtt/const.py b/homeassistant/components/mqtt/const.py index 106d03101587..2f7e27e7252b 100644 --- a/homeassistant/components/mqtt/const.py +++ b/homeassistant/components/mqtt/const.py @@ -1,5 +1,5 @@ """Constants used by multiple MQTT modules.""" -from homeassistant.const import CONF_PAYLOAD +from homeassistant.const import CONF_PAYLOAD, Platform ATTR_DISCOVERY_HASH = "discovery_hash" ATTR_DISCOVERY_PAYLOAD = "discovery_payload" @@ -14,7 +14,9 @@ CONF_BIRTH_MESSAGE = "birth_message" CONF_COMMAND_TEMPLATE = "command_template" CONF_COMMAND_TOPIC = "command_topic" +CONF_DISCOVERY_PREFIX = "discovery_prefix" CONF_ENCODING = "encoding" +CONF_KEEPALIVE = "keepalive" CONF_QOS = ATTR_QOS CONF_RETAIN = ATTR_RETAIN CONF_STATE_TOPIC = "state_topic" @@ -30,6 +32,7 @@ CONFIG_ENTRY_IS_SETUP = "mqtt_config_entry_is_setup" DATA_CONFIG_ENTRY_LOCK = "mqtt_config_entry_lock" +DATA_MQTT = "mqtt" DATA_MQTT_CONFIG = "mqtt_config" DATA_MQTT_RELOAD_NEEDED = "mqtt_reload_needed" @@ -66,3 +69,24 @@ PROTOCOL_31 = "3.1" PROTOCOL_311 = "3.1.1" + +PLATFORMS = [ + Platform.ALARM_CONTROL_PANEL, + Platform.BINARY_SENSOR, + Platform.BUTTON, + Platform.CAMERA, + Platform.CLIMATE, + Platform.DEVICE_TRACKER, + Platform.COVER, + Platform.FAN, + Platform.HUMIDIFIER, + Platform.LIGHT, + Platform.LOCK, + Platform.NUMBER, + Platform.SELECT, + Platform.SCENE, + Platform.SENSOR, + Platform.SIREN, + Platform.SWITCH, + Platform.VACUUM, +] diff --git a/homeassistant/components/mqtt/cover.py b/homeassistant/components/mqtt/cover.py index 8e36329946a0..5814f3e43f79 100644 --- a/homeassistant/components/mqtt/cover.py +++ b/homeassistant/components/mqtt/cover.py @@ -33,8 +33,8 @@ from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType -from . import MqttCommandTemplate, MqttValueTemplate, subscription -from .. import mqtt +from . import subscription +from .config import MQTT_BASE_SCHEMA from .const import ( CONF_COMMAND_TOPIC, CONF_ENCODING, @@ -51,6 +51,8 @@ async_setup_platform_helper, warn_for_legacy_schema, ) +from .models import MqttCommandTemplate, MqttValueTemplate +from .util import valid_publish_topic, valid_subscribe_topic _LOGGER = logging.getLogger(__name__) @@ -152,11 +154,11 @@ def validate_options(value): return value -_PLATFORM_SCHEMA_BASE = mqtt.MQTT_BASE_SCHEMA.extend( +_PLATFORM_SCHEMA_BASE = MQTT_BASE_SCHEMA.extend( { - vol.Optional(CONF_COMMAND_TOPIC): mqtt.valid_publish_topic, + vol.Optional(CONF_COMMAND_TOPIC): valid_publish_topic, vol.Optional(CONF_DEVICE_CLASS): DEVICE_CLASSES_SCHEMA, - vol.Optional(CONF_GET_POSITION_TOPIC): mqtt.valid_subscribe_topic, + vol.Optional(CONF_GET_POSITION_TOPIC): valid_subscribe_topic, vol.Optional(CONF_NAME, default=DEFAULT_NAME): cv.string, vol.Optional(CONF_OPTIMISTIC, default=DEFAULT_OPTIMISTIC): cv.boolean, vol.Optional(CONF_PAYLOAD_CLOSE, default=DEFAULT_PAYLOAD_CLOSE): vol.Any( @@ -172,24 +174,24 @@ def validate_options(value): vol.Optional(CONF_POSITION_OPEN, default=DEFAULT_POSITION_OPEN): int, vol.Optional(CONF_RETAIN, default=DEFAULT_RETAIN): cv.boolean, vol.Optional(CONF_SET_POSITION_TEMPLATE): cv.template, - vol.Optional(CONF_SET_POSITION_TOPIC): mqtt.valid_publish_topic, + vol.Optional(CONF_SET_POSITION_TOPIC): valid_publish_topic, vol.Optional(CONF_STATE_CLOSED, default=STATE_CLOSED): cv.string, vol.Optional(CONF_STATE_CLOSING, default=STATE_CLOSING): cv.string, vol.Optional(CONF_STATE_OPEN, default=STATE_OPEN): cv.string, vol.Optional(CONF_STATE_OPENING, default=STATE_OPENING): cv.string, vol.Optional(CONF_STATE_STOPPED, default=DEFAULT_STATE_STOPPED): cv.string, - vol.Optional(CONF_STATE_TOPIC): mqtt.valid_subscribe_topic, + vol.Optional(CONF_STATE_TOPIC): valid_subscribe_topic, vol.Optional( CONF_TILT_CLOSED_POSITION, default=DEFAULT_TILT_CLOSED_POSITION ): int, - vol.Optional(CONF_TILT_COMMAND_TOPIC): mqtt.valid_publish_topic, + vol.Optional(CONF_TILT_COMMAND_TOPIC): valid_publish_topic, vol.Optional(CONF_TILT_MAX, default=DEFAULT_TILT_MAX): int, vol.Optional(CONF_TILT_MIN, default=DEFAULT_TILT_MIN): int, vol.Optional(CONF_TILT_OPEN_POSITION, default=DEFAULT_TILT_OPEN_POSITION): int, vol.Optional( CONF_TILT_STATE_OPTIMISTIC, default=DEFAULT_TILT_OPTIMISTIC ): cv.boolean, - vol.Optional(CONF_TILT_STATUS_TOPIC): mqtt.valid_subscribe_topic, + vol.Optional(CONF_TILT_STATUS_TOPIC): valid_subscribe_topic, vol.Optional(CONF_TILT_STATUS_TEMPLATE): cv.template, vol.Optional(CONF_VALUE_TEMPLATE): cv.template, vol.Optional(CONF_GET_POSITION_TEMPLATE): cv.template, diff --git a/homeassistant/components/mqtt/device_automation.py b/homeassistant/components/mqtt/device_automation.py index 002ae6e3991d..0646a5bda0c9 100644 --- a/homeassistant/components/mqtt/device_automation.py +++ b/homeassistant/components/mqtt/device_automation.py @@ -6,7 +6,7 @@ import homeassistant.helpers.config_validation as cv from . import device_trigger -from .. import mqtt +from .config import MQTT_BASE_SCHEMA from .mixins import async_setup_entry_helper AUTOMATION_TYPE_TRIGGER = "trigger" @@ -17,7 +17,7 @@ PLATFORM_SCHEMA = cv.PLATFORM_SCHEMA.extend( {vol.Required(CONF_AUTOMATION_TYPE): AUTOMATION_TYPES_SCHEMA}, extra=vol.ALLOW_EXTRA, -).extend(mqtt.MQTT_BASE_SCHEMA.schema) +).extend(MQTT_BASE_SCHEMA.schema) async def async_setup_entry(hass, config_entry): diff --git a/homeassistant/components/mqtt/device_tracker/schema_discovery.py b/homeassistant/components/mqtt/device_tracker/schema_discovery.py index aa7506bd5e3b..1b48e15b80e1 100644 --- a/homeassistant/components/mqtt/device_tracker/schema_discovery.py +++ b/homeassistant/components/mqtt/device_tracker/schema_discovery.py @@ -19,8 +19,8 @@ from homeassistant.core import callback import homeassistant.helpers.config_validation as cv -from .. import MqttValueTemplate, subscription -from ... import mqtt +from .. import subscription +from ..config import MQTT_RO_SCHEMA from ..const import CONF_QOS, CONF_STATE_TOPIC from ..debug_info import log_messages from ..mixins import ( @@ -29,12 +29,13 @@ async_get_platform_config_from_yaml, async_setup_entry_helper, ) +from ..models import MqttValueTemplate CONF_PAYLOAD_HOME = "payload_home" CONF_PAYLOAD_NOT_HOME = "payload_not_home" CONF_SOURCE_TYPE = "source_type" -PLATFORM_SCHEMA_MODERN = mqtt.MQTT_RO_SCHEMA.extend( +PLATFORM_SCHEMA_MODERN = MQTT_RO_SCHEMA.extend( { vol.Optional(CONF_NAME): cv.string, vol.Optional(CONF_PAYLOAD_HOME, default=STATE_HOME): cv.string, diff --git a/homeassistant/components/mqtt/device_tracker/schema_yaml.py b/homeassistant/components/mqtt/device_tracker/schema_yaml.py index f871ac89c2d6..2dfa5b7134c3 100644 --- a/homeassistant/components/mqtt/device_tracker/schema_yaml.py +++ b/homeassistant/components/mqtt/device_tracker/schema_yaml.py @@ -7,16 +7,18 @@ from homeassistant.core import callback import homeassistant.helpers.config_validation as cv -from ... import mqtt +from ..client import async_subscribe +from ..config import SCHEMA_BASE from ..const import CONF_QOS +from ..util import valid_subscribe_topic CONF_PAYLOAD_HOME = "payload_home" CONF_PAYLOAD_NOT_HOME = "payload_not_home" CONF_SOURCE_TYPE = "source_type" -PLATFORM_SCHEMA_YAML = PLATFORM_SCHEMA.extend(mqtt.SCHEMA_BASE).extend( +PLATFORM_SCHEMA_YAML = PLATFORM_SCHEMA.extend(SCHEMA_BASE).extend( { - vol.Required(CONF_DEVICES): {cv.string: mqtt.valid_subscribe_topic}, + vol.Required(CONF_DEVICES): {cv.string: valid_subscribe_topic}, vol.Optional(CONF_PAYLOAD_HOME, default=STATE_HOME): cv.string, vol.Optional(CONF_PAYLOAD_NOT_HOME, default=STATE_NOT_HOME): cv.string, vol.Optional(CONF_SOURCE_TYPE): vol.In(SOURCE_TYPES), @@ -50,6 +52,6 @@ def async_message_received(msg, dev_id=dev_id): hass.async_create_task(async_see(**see_args)) - await mqtt.async_subscribe(hass, topic, async_message_received, qos) + await async_subscribe(hass, topic, async_message_received, qos) return True diff --git a/homeassistant/components/mqtt/device_trigger.py b/homeassistant/components/mqtt/device_trigger.py index 2c6c6ecc3bab..0b4bcbfcbc26 100644 --- a/homeassistant/components/mqtt/device_trigger.py +++ b/homeassistant/components/mqtt/device_trigger.py @@ -29,7 +29,7 @@ from homeassistant.helpers.typing import ConfigType from . import debug_info, trigger as mqtt_trigger -from .. import mqtt +from .config import MQTT_BASE_SCHEMA from .const import ( ATTR_DISCOVERY_HASH, CONF_ENCODING, @@ -71,7 +71,7 @@ } ) -TRIGGER_DISCOVERY_SCHEMA = mqtt.MQTT_BASE_SCHEMA.extend( +TRIGGER_DISCOVERY_SCHEMA = MQTT_BASE_SCHEMA.extend( { vol.Required(CONF_AUTOMATION_TYPE): str, vol.Required(CONF_DEVICE): MQTT_ENTITY_DEVICE_INFO_SCHEMA, @@ -101,7 +101,7 @@ class TriggerInstance: async def async_attach_trigger(self) -> None: """Attach MQTT trigger.""" mqtt_config = { - CONF_PLATFORM: mqtt.DOMAIN, + CONF_PLATFORM: DOMAIN, CONF_TOPIC: self.trigger.topic, CONF_ENCODING: DEFAULT_ENCODING, CONF_QOS: self.trigger.qos, diff --git a/homeassistant/components/mqtt/fan.py b/homeassistant/components/mqtt/fan.py index f2b738cd2bb2..f72b0bdf689c 100644 --- a/homeassistant/components/mqtt/fan.py +++ b/homeassistant/components/mqtt/fan.py @@ -34,8 +34,8 @@ ranged_value_to_percentage, ) -from . import MqttCommandTemplate, MqttValueTemplate, subscription -from .. import mqtt +from . import subscription +from .config import MQTT_RW_SCHEMA from .const import ( CONF_COMMAND_TEMPLATE, CONF_COMMAND_TOPIC, @@ -55,6 +55,8 @@ async_setup_platform_helper, warn_for_legacy_schema, ) +from .models import MqttCommandTemplate, MqttValueTemplate +from .util import valid_publish_topic, valid_subscribe_topic CONF_PERCENTAGE_STATE_TOPIC = "percentage_state_topic" CONF_PERCENTAGE_COMMAND_TOPIC = "percentage_command_topic" @@ -125,28 +127,28 @@ def valid_preset_mode_configuration(config): return config -_PLATFORM_SCHEMA_BASE = mqtt.MQTT_RW_SCHEMA.extend( +_PLATFORM_SCHEMA_BASE = MQTT_RW_SCHEMA.extend( { vol.Optional(CONF_NAME, default=DEFAULT_NAME): cv.string, vol.Optional(CONF_OPTIMISTIC, default=DEFAULT_OPTIMISTIC): cv.boolean, vol.Optional(CONF_COMMAND_TEMPLATE): cv.template, - vol.Optional(CONF_OSCILLATION_COMMAND_TOPIC): mqtt.valid_publish_topic, + vol.Optional(CONF_OSCILLATION_COMMAND_TOPIC): valid_publish_topic, vol.Optional(CONF_OSCILLATION_COMMAND_TEMPLATE): cv.template, - vol.Optional(CONF_OSCILLATION_STATE_TOPIC): mqtt.valid_subscribe_topic, + vol.Optional(CONF_OSCILLATION_STATE_TOPIC): valid_subscribe_topic, vol.Optional(CONF_OSCILLATION_VALUE_TEMPLATE): cv.template, - vol.Optional(CONF_PERCENTAGE_COMMAND_TOPIC): mqtt.valid_publish_topic, + vol.Optional(CONF_PERCENTAGE_COMMAND_TOPIC): valid_publish_topic, vol.Optional(CONF_PERCENTAGE_COMMAND_TEMPLATE): cv.template, - vol.Optional(CONF_PERCENTAGE_STATE_TOPIC): mqtt.valid_subscribe_topic, + vol.Optional(CONF_PERCENTAGE_STATE_TOPIC): valid_subscribe_topic, vol.Optional(CONF_PERCENTAGE_VALUE_TEMPLATE): cv.template, # CONF_PRESET_MODE_COMMAND_TOPIC and CONF_PRESET_MODES_LIST must be used together vol.Inclusive( CONF_PRESET_MODE_COMMAND_TOPIC, "preset_modes" - ): mqtt.valid_publish_topic, + ): valid_publish_topic, vol.Inclusive( CONF_PRESET_MODES_LIST, "preset_modes", default=[] ): cv.ensure_list, vol.Optional(CONF_PRESET_MODE_COMMAND_TEMPLATE): cv.template, - vol.Optional(CONF_PRESET_MODE_STATE_TOPIC): mqtt.valid_subscribe_topic, + vol.Optional(CONF_PRESET_MODE_STATE_TOPIC): valid_subscribe_topic, vol.Optional(CONF_PRESET_MODE_VALUE_TEMPLATE): cv.template, vol.Optional( CONF_SPEED_RANGE_MIN, default=DEFAULT_SPEED_RANGE_MIN @@ -168,8 +170,8 @@ def valid_preset_mode_configuration(config): vol.Optional( CONF_PAYLOAD_OSCILLATION_ON, default=OSCILLATE_ON_PAYLOAD ): cv.string, - vol.Optional(CONF_SPEED_COMMAND_TOPIC): mqtt.valid_publish_topic, - vol.Optional(CONF_SPEED_STATE_TOPIC): mqtt.valid_subscribe_topic, + vol.Optional(CONF_SPEED_COMMAND_TOPIC): valid_publish_topic, + vol.Optional(CONF_SPEED_STATE_TOPIC): valid_subscribe_topic, vol.Optional(CONF_SPEED_VALUE_TEMPLATE): cv.template, vol.Optional(CONF_STATE_VALUE_TEMPLATE): cv.template, } diff --git a/homeassistant/components/mqtt/humidifier.py b/homeassistant/components/mqtt/humidifier.py index f6d4aa01dab2..000a9b9700e1 100644 --- a/homeassistant/components/mqtt/humidifier.py +++ b/homeassistant/components/mqtt/humidifier.py @@ -30,8 +30,8 @@ from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType -from . import MqttCommandTemplate, MqttValueTemplate, subscription -from .. import mqtt +from . import subscription +from .config import MQTT_RW_SCHEMA from .const import ( CONF_COMMAND_TEMPLATE, CONF_COMMAND_TOPIC, @@ -51,6 +51,8 @@ async_setup_platform_helper, warn_for_legacy_schema, ) +from .models import MqttCommandTemplate, MqttValueTemplate +from .util import valid_publish_topic, valid_subscribe_topic CONF_AVAILABLE_MODES_LIST = "modes" CONF_DEVICE_CLASS = "device_class" @@ -103,15 +105,13 @@ def valid_humidity_range_configuration(config): return config -_PLATFORM_SCHEMA_BASE = mqtt.MQTT_RW_SCHEMA.extend( +_PLATFORM_SCHEMA_BASE = MQTT_RW_SCHEMA.extend( { # CONF_AVAIALABLE_MODES_LIST and CONF_MODE_COMMAND_TOPIC must be used together vol.Inclusive( CONF_AVAILABLE_MODES_LIST, "available_modes", default=[] ): cv.ensure_list, - vol.Inclusive( - CONF_MODE_COMMAND_TOPIC, "available_modes" - ): mqtt.valid_publish_topic, + vol.Inclusive(CONF_MODE_COMMAND_TOPIC, "available_modes"): valid_publish_topic, vol.Optional(CONF_COMMAND_TEMPLATE): cv.template, vol.Optional( CONF_DEVICE_CLASS, default=HumidifierDeviceClass.HUMIDIFIER @@ -119,14 +119,14 @@ def valid_humidity_range_configuration(config): [HumidifierDeviceClass.HUMIDIFIER, HumidifierDeviceClass.DEHUMIDIFIER] ), vol.Optional(CONF_MODE_COMMAND_TEMPLATE): cv.template, - vol.Optional(CONF_MODE_STATE_TOPIC): mqtt.valid_subscribe_topic, + vol.Optional(CONF_MODE_STATE_TOPIC): valid_subscribe_topic, vol.Optional(CONF_MODE_STATE_TEMPLATE): cv.template, vol.Optional(CONF_NAME, default=DEFAULT_NAME): cv.string, vol.Optional(CONF_OPTIMISTIC, default=DEFAULT_OPTIMISTIC): cv.boolean, vol.Optional(CONF_PAYLOAD_OFF, default=DEFAULT_PAYLOAD_OFF): cv.string, vol.Optional(CONF_PAYLOAD_ON, default=DEFAULT_PAYLOAD_ON): cv.string, vol.Optional(CONF_STATE_VALUE_TEMPLATE): cv.template, - vol.Required(CONF_TARGET_HUMIDITY_COMMAND_TOPIC): mqtt.valid_publish_topic, + vol.Required(CONF_TARGET_HUMIDITY_COMMAND_TOPIC): valid_publish_topic, vol.Optional(CONF_TARGET_HUMIDITY_COMMAND_TEMPLATE): cv.template, vol.Optional( CONF_TARGET_HUMIDITY_MAX, default=DEFAULT_MAX_HUMIDITY @@ -135,7 +135,7 @@ def valid_humidity_range_configuration(config): CONF_TARGET_HUMIDITY_MIN, default=DEFAULT_MIN_HUMIDITY ): cv.positive_int, vol.Optional(CONF_TARGET_HUMIDITY_STATE_TEMPLATE): cv.template, - vol.Optional(CONF_TARGET_HUMIDITY_STATE_TOPIC): mqtt.valid_subscribe_topic, + vol.Optional(CONF_TARGET_HUMIDITY_STATE_TOPIC): valid_subscribe_topic, vol.Optional( CONF_PAYLOAD_RESET_HUMIDITY, default=DEFAULT_PAYLOAD_RESET ): cv.string, diff --git a/homeassistant/components/mqtt/light/schema_basic.py b/homeassistant/components/mqtt/light/schema_basic.py index eb4ec264981c..1c94fa82f73b 100644 --- a/homeassistant/components/mqtt/light/schema_basic.py +++ b/homeassistant/components/mqtt/light/schema_basic.py @@ -42,8 +42,8 @@ from homeassistant.helpers.restore_state import RestoreEntity import homeassistant.util.color as color_util -from .. import MqttCommandTemplate, MqttValueTemplate, subscription -from ... import mqtt +from .. import subscription +from ..config import MQTT_RW_SCHEMA from ..const import ( CONF_COMMAND_TOPIC, CONF_ENCODING, @@ -55,6 +55,8 @@ ) from ..debug_info import log_messages from ..mixins import MQTT_ENTITY_COMMON_SCHEMA, MqttEntity +from ..models import MqttCommandTemplate, MqttValueTemplate +from ..util import valid_publish_topic, valid_subscribe_topic from .schema import MQTT_LIGHT_SCHEMA_SCHEMA _LOGGER = logging.getLogger(__name__) @@ -156,28 +158,28 @@ ] _PLATFORM_SCHEMA_BASE = ( - mqtt.MQTT_RW_SCHEMA.extend( + MQTT_RW_SCHEMA.extend( { vol.Optional(CONF_BRIGHTNESS_COMMAND_TEMPLATE): cv.template, - vol.Optional(CONF_BRIGHTNESS_COMMAND_TOPIC): mqtt.valid_publish_topic, + vol.Optional(CONF_BRIGHTNESS_COMMAND_TOPIC): valid_publish_topic, vol.Optional( CONF_BRIGHTNESS_SCALE, default=DEFAULT_BRIGHTNESS_SCALE ): vol.All(vol.Coerce(int), vol.Range(min=1)), - vol.Optional(CONF_BRIGHTNESS_STATE_TOPIC): mqtt.valid_subscribe_topic, + vol.Optional(CONF_BRIGHTNESS_STATE_TOPIC): valid_subscribe_topic, vol.Optional(CONF_BRIGHTNESS_VALUE_TEMPLATE): cv.template, - vol.Optional(CONF_COLOR_MODE_STATE_TOPIC): mqtt.valid_subscribe_topic, + vol.Optional(CONF_COLOR_MODE_STATE_TOPIC): valid_subscribe_topic, vol.Optional(CONF_COLOR_MODE_VALUE_TEMPLATE): cv.template, vol.Optional(CONF_COLOR_TEMP_COMMAND_TEMPLATE): cv.template, - vol.Optional(CONF_COLOR_TEMP_COMMAND_TOPIC): mqtt.valid_publish_topic, - vol.Optional(CONF_COLOR_TEMP_STATE_TOPIC): mqtt.valid_subscribe_topic, + vol.Optional(CONF_COLOR_TEMP_COMMAND_TOPIC): valid_publish_topic, + vol.Optional(CONF_COLOR_TEMP_STATE_TOPIC): valid_subscribe_topic, vol.Optional(CONF_COLOR_TEMP_VALUE_TEMPLATE): cv.template, vol.Optional(CONF_EFFECT_COMMAND_TEMPLATE): cv.template, - vol.Optional(CONF_EFFECT_COMMAND_TOPIC): mqtt.valid_publish_topic, + vol.Optional(CONF_EFFECT_COMMAND_TOPIC): valid_publish_topic, vol.Optional(CONF_EFFECT_LIST): vol.All(cv.ensure_list, [cv.string]), - vol.Optional(CONF_EFFECT_STATE_TOPIC): mqtt.valid_subscribe_topic, + vol.Optional(CONF_EFFECT_STATE_TOPIC): valid_subscribe_topic, vol.Optional(CONF_EFFECT_VALUE_TEMPLATE): cv.template, - vol.Optional(CONF_HS_COMMAND_TOPIC): mqtt.valid_publish_topic, - vol.Optional(CONF_HS_STATE_TOPIC): mqtt.valid_subscribe_topic, + vol.Optional(CONF_HS_COMMAND_TOPIC): valid_publish_topic, + vol.Optional(CONF_HS_STATE_TOPIC): valid_subscribe_topic, vol.Optional(CONF_HS_VALUE_TEMPLATE): cv.template, vol.Optional(CONF_MAX_MIREDS): cv.positive_int, vol.Optional(CONF_MIN_MIREDS): cv.positive_int, @@ -189,30 +191,30 @@ vol.Optional(CONF_PAYLOAD_OFF, default=DEFAULT_PAYLOAD_OFF): cv.string, vol.Optional(CONF_PAYLOAD_ON, default=DEFAULT_PAYLOAD_ON): cv.string, vol.Optional(CONF_RGB_COMMAND_TEMPLATE): cv.template, - vol.Optional(CONF_RGB_COMMAND_TOPIC): mqtt.valid_publish_topic, - vol.Optional(CONF_RGB_STATE_TOPIC): mqtt.valid_subscribe_topic, + vol.Optional(CONF_RGB_COMMAND_TOPIC): valid_publish_topic, + vol.Optional(CONF_RGB_STATE_TOPIC): valid_subscribe_topic, vol.Optional(CONF_RGB_VALUE_TEMPLATE): cv.template, vol.Optional(CONF_RGBW_COMMAND_TEMPLATE): cv.template, - vol.Optional(CONF_RGBW_COMMAND_TOPIC): mqtt.valid_publish_topic, - vol.Optional(CONF_RGBW_STATE_TOPIC): mqtt.valid_subscribe_topic, + vol.Optional(CONF_RGBW_COMMAND_TOPIC): valid_publish_topic, + vol.Optional(CONF_RGBW_STATE_TOPIC): valid_subscribe_topic, vol.Optional(CONF_RGBW_VALUE_TEMPLATE): cv.template, vol.Optional(CONF_RGBWW_COMMAND_TEMPLATE): cv.template, - vol.Optional(CONF_RGBWW_COMMAND_TOPIC): mqtt.valid_publish_topic, - vol.Optional(CONF_RGBWW_STATE_TOPIC): mqtt.valid_subscribe_topic, + vol.Optional(CONF_RGBWW_COMMAND_TOPIC): valid_publish_topic, + vol.Optional(CONF_RGBWW_STATE_TOPIC): valid_subscribe_topic, vol.Optional(CONF_RGBWW_VALUE_TEMPLATE): cv.template, vol.Optional(CONF_STATE_VALUE_TEMPLATE): cv.template, - vol.Optional(CONF_WHITE_COMMAND_TOPIC): mqtt.valid_publish_topic, + vol.Optional(CONF_WHITE_COMMAND_TOPIC): valid_publish_topic, vol.Optional(CONF_WHITE_SCALE, default=DEFAULT_WHITE_SCALE): vol.All( vol.Coerce(int), vol.Range(min=1) ), - vol.Optional(CONF_WHITE_VALUE_COMMAND_TOPIC): mqtt.valid_publish_topic, + vol.Optional(CONF_WHITE_VALUE_COMMAND_TOPIC): valid_publish_topic, vol.Optional( CONF_WHITE_VALUE_SCALE, default=DEFAULT_WHITE_VALUE_SCALE ): vol.All(vol.Coerce(int), vol.Range(min=1)), - vol.Optional(CONF_WHITE_VALUE_STATE_TOPIC): mqtt.valid_subscribe_topic, + vol.Optional(CONF_WHITE_VALUE_STATE_TOPIC): valid_subscribe_topic, vol.Optional(CONF_WHITE_VALUE_TEMPLATE): cv.template, - vol.Optional(CONF_XY_COMMAND_TOPIC): mqtt.valid_publish_topic, - vol.Optional(CONF_XY_STATE_TOPIC): mqtt.valid_subscribe_topic, + vol.Optional(CONF_XY_COMMAND_TOPIC): valid_publish_topic, + vol.Optional(CONF_XY_STATE_TOPIC): valid_subscribe_topic, vol.Optional(CONF_XY_VALUE_TEMPLATE): cv.template, }, ) diff --git a/homeassistant/components/mqtt/light/schema_json.py b/homeassistant/components/mqtt/light/schema_json.py index 2049818ab319..be49f1ad2e3a 100644 --- a/homeassistant/components/mqtt/light/schema_json.py +++ b/homeassistant/components/mqtt/light/schema_json.py @@ -51,7 +51,7 @@ import homeassistant.util.color as color_util from .. import subscription -from ... import mqtt +from ..config import DEFAULT_QOS, DEFAULT_RETAIN, MQTT_RW_SCHEMA from ..const import ( CONF_COMMAND_TOPIC, CONF_ENCODING, @@ -61,6 +61,7 @@ ) from ..debug_info import log_messages from ..mixins import MQTT_ENTITY_COMMON_SCHEMA, MqttEntity +from ..util import valid_subscribe_topic from .schema import MQTT_LIGHT_SCHEMA_SCHEMA from .schema_basic import CONF_BRIGHTNESS_SCALE, MQTT_LIGHT_ATTRIBUTES_BLOCKED @@ -103,7 +104,7 @@ def valid_color_configuration(config): _PLATFORM_SCHEMA_BASE = ( - mqtt.MQTT_RW_SCHEMA.extend( + MQTT_RW_SCHEMA.extend( { vol.Optional(CONF_BRIGHTNESS, default=DEFAULT_BRIGHTNESS): cv.boolean, vol.Optional( @@ -126,12 +127,12 @@ def valid_color_configuration(config): vol.Optional(CONF_MIN_MIREDS): cv.positive_int, vol.Optional(CONF_NAME, default=DEFAULT_NAME): cv.string, vol.Optional(CONF_OPTIMISTIC, default=DEFAULT_OPTIMISTIC): cv.boolean, - vol.Optional(CONF_QOS, default=mqtt.DEFAULT_QOS): vol.All( + vol.Optional(CONF_QOS, default=DEFAULT_QOS): vol.All( vol.Coerce(int), vol.In([0, 1, 2]) ), - vol.Optional(CONF_RETAIN, default=mqtt.DEFAULT_RETAIN): cv.boolean, + vol.Optional(CONF_RETAIN, default=DEFAULT_RETAIN): cv.boolean, vol.Optional(CONF_RGB, default=DEFAULT_RGB): cv.boolean, - vol.Optional(CONF_STATE_TOPIC): mqtt.valid_subscribe_topic, + vol.Optional(CONF_STATE_TOPIC): valid_subscribe_topic, vol.Inclusive(CONF_SUPPORTED_COLOR_MODES, "color_mode"): vol.All( cv.ensure_list, [vol.In(VALID_COLOR_MODES)], diff --git a/homeassistant/components/mqtt/light/schema_template.py b/homeassistant/components/mqtt/light/schema_template.py index 0165bfc8efa3..779f2f17e247 100644 --- a/homeassistant/components/mqtt/light/schema_template.py +++ b/homeassistant/components/mqtt/light/schema_template.py @@ -31,8 +31,8 @@ from homeassistant.helpers.restore_state import RestoreEntity import homeassistant.util.color as color_util -from .. import MqttValueTemplate, subscription -from ... import mqtt +from .. import subscription +from ..config import MQTT_RW_SCHEMA from ..const import ( CONF_COMMAND_TOPIC, CONF_ENCODING, @@ -43,6 +43,7 @@ ) from ..debug_info import log_messages from ..mixins import MQTT_ENTITY_COMMON_SCHEMA, MqttEntity +from ..models import MqttValueTemplate from .schema import MQTT_LIGHT_SCHEMA_SCHEMA from .schema_basic import MQTT_LIGHT_ATTRIBUTES_BLOCKED @@ -67,7 +68,7 @@ CONF_WHITE_VALUE_TEMPLATE = "white_value_template" _PLATFORM_SCHEMA_BASE = ( - mqtt.MQTT_RW_SCHEMA.extend( + MQTT_RW_SCHEMA.extend( { vol.Optional(CONF_BLUE_TEMPLATE): cv.template, vol.Optional(CONF_BRIGHTNESS_TEMPLATE): cv.template, diff --git a/homeassistant/components/mqtt/lock.py b/homeassistant/components/mqtt/lock.py index 5dc0a974d26a..0cfd1d2b70ff 100644 --- a/homeassistant/components/mqtt/lock.py +++ b/homeassistant/components/mqtt/lock.py @@ -15,8 +15,8 @@ from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType -from . import MqttValueTemplate, subscription -from .. import mqtt +from . import subscription +from .config import MQTT_RW_SCHEMA from .const import ( CONF_COMMAND_TOPIC, CONF_ENCODING, @@ -33,6 +33,7 @@ async_setup_platform_helper, warn_for_legacy_schema, ) +from .models import MqttValueTemplate CONF_PAYLOAD_LOCK = "payload_lock" CONF_PAYLOAD_UNLOCK = "payload_unlock" @@ -56,7 +57,7 @@ } ) -PLATFORM_SCHEMA_MODERN = mqtt.MQTT_RW_SCHEMA.extend( +PLATFORM_SCHEMA_MODERN = MQTT_RW_SCHEMA.extend( { vol.Optional(CONF_NAME, default=DEFAULT_NAME): cv.string, vol.Optional(CONF_OPTIMISTIC, default=DEFAULT_OPTIMISTIC): cv.boolean, diff --git a/homeassistant/components/mqtt/mixins.py b/homeassistant/components/mqtt/mixins.py index a46debeae548..694fae0b3c02 100644 --- a/homeassistant/components/mqtt/mixins.py +++ b/homeassistant/components/mqtt/mixins.py @@ -49,14 +49,8 @@ from homeassistant.helpers.reload import async_setup_reload_service from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType -from . import ( - DATA_MQTT, - PLATFORMS, - MqttValueTemplate, - async_publish, - debug_info, - subscription, -) +from . import debug_info, subscription +from .client import async_publish from .const import ( ATTR_DISCOVERY_HASH, ATTR_DISCOVERY_PAYLOAD, @@ -65,6 +59,7 @@ CONF_ENCODING, CONF_QOS, CONF_TOPIC, + DATA_MQTT, DATA_MQTT_CONFIG, DATA_MQTT_RELOAD_NEEDED, DEFAULT_ENCODING, @@ -73,6 +68,7 @@ DOMAIN, MQTT_CONNECTED, MQTT_DISCONNECTED, + PLATFORMS, ) from .debug_info import log_message, log_messages from .discovery import ( @@ -82,7 +78,7 @@ clear_discovery_hash, set_discovery_hash, ) -from .models import PublishPayloadType, ReceiveMessage +from .models import MqttValueTemplate, PublishPayloadType, ReceiveMessage from .subscription import ( async_prepare_subscribe_topics, async_subscribe_topics, diff --git a/homeassistant/components/mqtt/models.py b/homeassistant/components/mqtt/models.py index 9cec65d72541..9bce6baab8bf 100644 --- a/homeassistant/components/mqtt/models.py +++ b/homeassistant/components/mqtt/models.py @@ -1,12 +1,21 @@ """Models used by multiple MQTT modules.""" from __future__ import annotations +from ast import literal_eval from collections.abc import Awaitable, Callable import datetime as dt -from typing import Union +from typing import Any, Union import attr +from homeassistant.const import ATTR_ENTITY_ID, ATTR_NAME +from homeassistant.core import HomeAssistant, callback +from homeassistant.helpers import template +from homeassistant.helpers.entity import Entity +from homeassistant.helpers.typing import TemplateVarsType + +_SENTINEL = object() + PublishPayloadType = Union[str, bytes, int, float, None] ReceivePayloadType = Union[str, bytes] @@ -35,3 +44,118 @@ class ReceiveMessage: AsyncMessageCallbackType = Callable[[ReceiveMessage], Awaitable[None]] MessageCallbackType = Callable[[ReceiveMessage], None] + + +class MqttCommandTemplate: + """Class for rendering MQTT payload with command templates.""" + + def __init__( + self, + command_template: template.Template | None, + *, + hass: HomeAssistant | None = None, + entity: Entity | None = None, + ) -> None: + """Instantiate a command template.""" + self._attr_command_template = command_template + if command_template is None: + return + + self._entity = entity + + command_template.hass = hass + + if entity: + command_template.hass = entity.hass + + @callback + def async_render( + self, + value: PublishPayloadType = None, + variables: TemplateVarsType = None, + ) -> PublishPayloadType: + """Render or convert the command template with given value or variables.""" + + def _convert_outgoing_payload( + payload: PublishPayloadType, + ) -> PublishPayloadType: + """Ensure correct raw MQTT payload is passed as bytes for publishing.""" + if isinstance(payload, str): + try: + native_object = literal_eval(payload) + if isinstance(native_object, bytes): + return native_object + + except (ValueError, TypeError, SyntaxError, MemoryError): + pass + + return payload + + if self._attr_command_template is None: + return value + + values = {"value": value} + if self._entity: + values[ATTR_ENTITY_ID] = self._entity.entity_id + values[ATTR_NAME] = self._entity.name + if variables is not None: + values.update(variables) + return _convert_outgoing_payload( + self._attr_command_template.async_render(values, parse_result=False) + ) + + +class MqttValueTemplate: + """Class for rendering MQTT value template with possible json values.""" + + def __init__( + self, + value_template: template.Template | None, + *, + hass: HomeAssistant | None = None, + entity: Entity | None = None, + config_attributes: TemplateVarsType = None, + ) -> None: + """Instantiate a value template.""" + self._value_template = value_template + self._config_attributes = config_attributes + if value_template is None: + return + + value_template.hass = hass + self._entity = entity + + if entity: + value_template.hass = entity.hass + + @callback + def async_render_with_possible_json_value( + self, + payload: ReceivePayloadType, + default: ReceivePayloadType | object = _SENTINEL, + variables: TemplateVarsType = None, + ) -> ReceivePayloadType: + """Render with possible json value or pass-though a received MQTT value.""" + if self._value_template is None: + return payload + + values: dict[str, Any] = {} + + if variables is not None: + values.update(variables) + + if self._config_attributes is not None: + values.update(self._config_attributes) + + if self._entity: + values[ATTR_ENTITY_ID] = self._entity.entity_id + values[ATTR_NAME] = self._entity.name + + if default == _SENTINEL: + return self._value_template.async_render_with_possible_json_value( + payload, variables=values + ) + + return self._value_template.async_render_with_possible_json_value( + payload, default, variables=values + ) diff --git a/homeassistant/components/mqtt/number.py b/homeassistant/components/mqtt/number.py index 001f9f4f668a..6ea1f0959f65 100644 --- a/homeassistant/components/mqtt/number.py +++ b/homeassistant/components/mqtt/number.py @@ -27,8 +27,8 @@ from homeassistant.helpers.restore_state import RestoreEntity from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType -from . import MqttCommandTemplate, MqttValueTemplate, subscription -from .. import mqtt +from . import subscription +from .config import MQTT_RW_SCHEMA from .const import ( CONF_COMMAND_TEMPLATE, CONF_COMMAND_TOPIC, @@ -46,6 +46,7 @@ async_setup_platform_helper, warn_for_legacy_schema, ) +from .models import MqttCommandTemplate, MqttValueTemplate _LOGGER = logging.getLogger(__name__) @@ -75,7 +76,7 @@ def validate_config(config): return config -_PLATFORM_SCHEMA_BASE = mqtt.MQTT_RW_SCHEMA.extend( +_PLATFORM_SCHEMA_BASE = MQTT_RW_SCHEMA.extend( { vol.Optional(CONF_COMMAND_TEMPLATE): cv.template, vol.Optional(CONF_MAX, default=DEFAULT_MAX_VALUE): vol.Coerce(float), diff --git a/homeassistant/components/mqtt/scene.py b/homeassistant/components/mqtt/scene.py index 98c692ceaff1..ce8f0b0a3e81 100644 --- a/homeassistant/components/mqtt/scene.py +++ b/homeassistant/components/mqtt/scene.py @@ -15,7 +15,8 @@ from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType -from .. import mqtt +from .client import async_publish +from .config import MQTT_BASE_SCHEMA from .const import CONF_COMMAND_TOPIC, CONF_ENCODING, CONF_QOS, CONF_RETAIN from .mixins import ( CONF_ENABLED_BY_DEFAULT, @@ -27,13 +28,14 @@ async_setup_platform_helper, warn_for_legacy_schema, ) +from .util import valid_publish_topic DEFAULT_NAME = "MQTT Scene" DEFAULT_RETAIN = False -PLATFORM_SCHEMA_MODERN = mqtt.MQTT_BASE_SCHEMA.extend( +PLATFORM_SCHEMA_MODERN = MQTT_BASE_SCHEMA.extend( { - vol.Required(CONF_COMMAND_TOPIC): mqtt.valid_publish_topic, + vol.Required(CONF_COMMAND_TOPIC): valid_publish_topic, vol.Optional(CONF_ICON): cv.icon, vol.Optional(CONF_NAME, default=DEFAULT_NAME): cv.string, vol.Optional(CONF_PAYLOAD_ON): cv.string, @@ -128,7 +130,7 @@ async def async_activate(self, **kwargs): This method is a coroutine. """ - await mqtt.async_publish( + await async_publish( self.hass, self._config[CONF_COMMAND_TOPIC], self._config[CONF_PAYLOAD_ON], diff --git a/homeassistant/components/mqtt/select.py b/homeassistant/components/mqtt/select.py index 0765eb7f1762..75e1b4e8efd1 100644 --- a/homeassistant/components/mqtt/select.py +++ b/homeassistant/components/mqtt/select.py @@ -17,8 +17,8 @@ from homeassistant.helpers.restore_state import RestoreEntity from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType -from . import MqttCommandTemplate, MqttValueTemplate, subscription -from .. import mqtt +from . import subscription +from .config import MQTT_RW_SCHEMA from .const import ( CONF_COMMAND_TEMPLATE, CONF_COMMAND_TOPIC, @@ -36,6 +36,7 @@ async_setup_platform_helper, warn_for_legacy_schema, ) +from .models import MqttCommandTemplate, MqttValueTemplate _LOGGER = logging.getLogger(__name__) @@ -51,7 +52,7 @@ ) -PLATFORM_SCHEMA_MODERN = mqtt.MQTT_RW_SCHEMA.extend( +PLATFORM_SCHEMA_MODERN = MQTT_RW_SCHEMA.extend( { vol.Optional(CONF_COMMAND_TEMPLATE): cv.template, vol.Optional(CONF_NAME, default=DEFAULT_NAME): cv.string, diff --git a/homeassistant/components/mqtt/sensor.py b/homeassistant/components/mqtt/sensor.py index d865d90c4ee8..4dd1ad4d95ff 100644 --- a/homeassistant/components/mqtt/sensor.py +++ b/homeassistant/components/mqtt/sensor.py @@ -34,8 +34,8 @@ from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType from homeassistant.util import dt as dt_util -from . import MqttValueTemplate, subscription -from .. import mqtt +from . import subscription +from .config import MQTT_RO_SCHEMA from .const import CONF_ENCODING, CONF_QOS, CONF_STATE_TOPIC from .debug_info import log_messages from .mixins import ( @@ -47,6 +47,8 @@ async_setup_platform_helper, warn_for_legacy_schema, ) +from .models import MqttValueTemplate +from .util import valid_subscribe_topic _LOGGER = logging.getLogger(__name__) @@ -89,12 +91,12 @@ def validate_options(conf): return conf -_PLATFORM_SCHEMA_BASE = mqtt.MQTT_RO_SCHEMA.extend( +_PLATFORM_SCHEMA_BASE = MQTT_RO_SCHEMA.extend( { vol.Optional(CONF_DEVICE_CLASS): DEVICE_CLASSES_SCHEMA, vol.Optional(CONF_EXPIRE_AFTER): cv.positive_int, vol.Optional(CONF_FORCE_UPDATE, default=DEFAULT_FORCE_UPDATE): cv.boolean, - vol.Optional(CONF_LAST_RESET_TOPIC): mqtt.valid_subscribe_topic, + vol.Optional(CONF_LAST_RESET_TOPIC): valid_subscribe_topic, vol.Optional(CONF_LAST_RESET_VALUE_TEMPLATE): cv.template, vol.Optional(CONF_NAME, default=DEFAULT_NAME): cv.string, vol.Optional(CONF_STATE_CLASS): STATE_CLASSES_SCHEMA, diff --git a/homeassistant/components/mqtt/siren.py b/homeassistant/components/mqtt/siren.py index c3a41c3618e8..1ecf2c37dbf4 100644 --- a/homeassistant/components/mqtt/siren.py +++ b/homeassistant/components/mqtt/siren.py @@ -35,8 +35,8 @@ from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType -from . import MqttCommandTemplate, MqttValueTemplate, subscription -from .. import mqtt +from . import subscription +from .config import MQTT_RW_SCHEMA from .const import ( CONF_COMMAND_TEMPLATE, CONF_COMMAND_TOPIC, @@ -57,6 +57,7 @@ async_setup_platform_helper, warn_for_legacy_schema, ) +from .models import MqttCommandTemplate, MqttValueTemplate DEFAULT_NAME = "MQTT Siren" DEFAULT_PAYLOAD_ON = "ON" @@ -74,7 +75,7 @@ STATE = "state" -PLATFORM_SCHEMA_MODERN = mqtt.MQTT_RW_SCHEMA.extend( +PLATFORM_SCHEMA_MODERN = MQTT_RW_SCHEMA.extend( { vol.Optional(CONF_AVAILABLE_TONES): cv.ensure_list, vol.Optional(CONF_COMMAND_TEMPLATE): cv.template, diff --git a/homeassistant/components/mqtt/switch.py b/homeassistant/components/mqtt/switch.py index f5f8363eb33f..c20ddfe51519 100644 --- a/homeassistant/components/mqtt/switch.py +++ b/homeassistant/components/mqtt/switch.py @@ -24,8 +24,8 @@ from homeassistant.helpers.restore_state import RestoreEntity from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType -from . import MqttValueTemplate, subscription -from .. import mqtt +from . import subscription +from .config import MQTT_RW_SCHEMA from .const import ( CONF_COMMAND_TOPIC, CONF_ENCODING, @@ -43,6 +43,7 @@ async_setup_platform_helper, warn_for_legacy_schema, ) +from .models import MqttValueTemplate DEFAULT_NAME = "MQTT Switch" DEFAULT_PAYLOAD_ON = "ON" @@ -51,7 +52,7 @@ CONF_STATE_ON = "state_on" CONF_STATE_OFF = "state_off" -PLATFORM_SCHEMA_MODERN = mqtt.MQTT_RW_SCHEMA.extend( +PLATFORM_SCHEMA_MODERN = MQTT_RW_SCHEMA.extend( { vol.Optional(CONF_NAME, default=DEFAULT_NAME): cv.string, vol.Optional(CONF_OPTIMISTIC, default=DEFAULT_OPTIMISTIC): cv.boolean, diff --git a/homeassistant/components/mqtt/tag.py b/homeassistant/components/mqtt/tag.py index 25e49524b8f7..9452d5fc259a 100644 --- a/homeassistant/components/mqtt/tag.py +++ b/homeassistant/components/mqtt/tag.py @@ -11,8 +11,8 @@ import homeassistant.helpers.config_validation as cv from homeassistant.helpers.typing import ConfigType -from . import MqttValueTemplate, subscription -from .. import mqtt +from . import subscription +from .config import MQTT_BASE_SCHEMA from .const import ATTR_DISCOVERY_HASH, CONF_QOS, CONF_TOPIC from .mixins import ( MQTT_ENTITY_DEVICE_INFO_SCHEMA, @@ -21,7 +21,7 @@ send_discovery_done, update_device, ) -from .models import ReceiveMessage +from .models import MqttValueTemplate, ReceiveMessage from .subscription import EntitySubscription from .util import valid_subscribe_topic @@ -30,7 +30,7 @@ TAG = "tag" TAGS = "mqtt_tags" -PLATFORM_SCHEMA = mqtt.MQTT_BASE_SCHEMA.extend( +PLATFORM_SCHEMA = MQTT_BASE_SCHEMA.extend( { vol.Optional(CONF_DEVICE): MQTT_ENTITY_DEVICE_INFO_SCHEMA, vol.Optional(CONF_PLATFORM): "mqtt", diff --git a/homeassistant/components/mqtt/vacuum/schema_legacy.py b/homeassistant/components/mqtt/vacuum/schema_legacy.py index eb5e01b62510..f25131c43b7a 100644 --- a/homeassistant/components/mqtt/vacuum/schema_legacy.py +++ b/homeassistant/components/mqtt/vacuum/schema_legacy.py @@ -15,11 +15,13 @@ import homeassistant.helpers.config_validation as cv from homeassistant.helpers.icon import icon_for_battery_level -from .. import MqttValueTemplate, subscription -from ... import mqtt +from .. import subscription +from ..config import MQTT_BASE_SCHEMA from ..const import CONF_COMMAND_TOPIC, CONF_ENCODING, CONF_QOS, CONF_RETAIN from ..debug_info import log_messages from ..mixins import MQTT_ENTITY_COMMON_SCHEMA, MqttEntity, warn_for_legacy_schema +from ..models import MqttValueTemplate +from ..util import valid_publish_topic from .const import MQTT_VACUUM_ATTRIBUTES_BLOCKED from .schema import MQTT_VACUUM_SCHEMA, services_to_strings, strings_to_services @@ -96,25 +98,23 @@ ) PLATFORM_SCHEMA_LEGACY_MODERN = ( - mqtt.MQTT_BASE_SCHEMA.extend( + MQTT_BASE_SCHEMA.extend( { vol.Inclusive(CONF_BATTERY_LEVEL_TEMPLATE, "battery"): cv.template, - vol.Inclusive( - CONF_BATTERY_LEVEL_TOPIC, "battery" - ): mqtt.valid_publish_topic, + vol.Inclusive(CONF_BATTERY_LEVEL_TOPIC, "battery"): valid_publish_topic, vol.Inclusive(CONF_CHARGING_TEMPLATE, "charging"): cv.template, - vol.Inclusive(CONF_CHARGING_TOPIC, "charging"): mqtt.valid_publish_topic, + vol.Inclusive(CONF_CHARGING_TOPIC, "charging"): valid_publish_topic, vol.Inclusive(CONF_CLEANING_TEMPLATE, "cleaning"): cv.template, - vol.Inclusive(CONF_CLEANING_TOPIC, "cleaning"): mqtt.valid_publish_topic, + vol.Inclusive(CONF_CLEANING_TOPIC, "cleaning"): valid_publish_topic, vol.Inclusive(CONF_DOCKED_TEMPLATE, "docked"): cv.template, - vol.Inclusive(CONF_DOCKED_TOPIC, "docked"): mqtt.valid_publish_topic, + vol.Inclusive(CONF_DOCKED_TOPIC, "docked"): valid_publish_topic, vol.Inclusive(CONF_ERROR_TEMPLATE, "error"): cv.template, - vol.Inclusive(CONF_ERROR_TOPIC, "error"): mqtt.valid_publish_topic, + vol.Inclusive(CONF_ERROR_TOPIC, "error"): valid_publish_topic, vol.Optional(CONF_FAN_SPEED_LIST, default=[]): vol.All( cv.ensure_list, [cv.string] ), vol.Inclusive(CONF_FAN_SPEED_TEMPLATE, "fan_speed"): cv.template, - vol.Inclusive(CONF_FAN_SPEED_TOPIC, "fan_speed"): mqtt.valid_publish_topic, + vol.Inclusive(CONF_FAN_SPEED_TOPIC, "fan_speed"): valid_publish_topic, vol.Optional(CONF_NAME, default=DEFAULT_NAME): cv.string, vol.Optional( CONF_PAYLOAD_CLEAN_SPOT, default=DEFAULT_PAYLOAD_CLEAN_SPOT @@ -135,12 +135,12 @@ vol.Optional( CONF_PAYLOAD_TURN_ON, default=DEFAULT_PAYLOAD_TURN_ON ): cv.string, - vol.Optional(CONF_SEND_COMMAND_TOPIC): mqtt.valid_publish_topic, - vol.Optional(CONF_SET_FAN_SPEED_TOPIC): mqtt.valid_publish_topic, + vol.Optional(CONF_SEND_COMMAND_TOPIC): valid_publish_topic, + vol.Optional(CONF_SET_FAN_SPEED_TOPIC): valid_publish_topic, vol.Optional( CONF_SUPPORTED_FEATURES, default=DEFAULT_SERVICE_STRINGS ): vol.All(cv.ensure_list, [vol.In(STRING_TO_SERVICE.keys())]), - vol.Optional(CONF_COMMAND_TOPIC): mqtt.valid_publish_topic, + vol.Optional(CONF_COMMAND_TOPIC): valid_publish_topic, vol.Optional(CONF_RETAIN, default=DEFAULT_RETAIN): cv.boolean, } ) diff --git a/homeassistant/components/mqtt/vacuum/schema_state.py b/homeassistant/components/mqtt/vacuum/schema_state.py index 7aa7be077974..3d670780994d 100644 --- a/homeassistant/components/mqtt/vacuum/schema_state.py +++ b/homeassistant/components/mqtt/vacuum/schema_state.py @@ -23,7 +23,7 @@ import homeassistant.helpers.config_validation as cv from .. import subscription -from ... import mqtt +from ..config import MQTT_BASE_SCHEMA from ..const import ( CONF_COMMAND_TOPIC, CONF_ENCODING, @@ -33,6 +33,7 @@ ) from ..debug_info import log_messages from ..mixins import MQTT_ENTITY_COMMON_SCHEMA, MqttEntity, warn_for_legacy_schema +from ..util import valid_publish_topic from .const import MQTT_VACUUM_ATTRIBUTES_BLOCKED from .schema import MQTT_VACUUM_SCHEMA, services_to_strings, strings_to_services @@ -105,7 +106,7 @@ DEFAULT_PAYLOAD_PAUSE = "pause" PLATFORM_SCHEMA_STATE_MODERN = ( - mqtt.MQTT_BASE_SCHEMA.extend( + MQTT_BASE_SCHEMA.extend( { vol.Optional(CONF_FAN_SPEED_LIST, default=[]): vol.All( cv.ensure_list, [cv.string] @@ -123,13 +124,13 @@ vol.Optional(CONF_PAYLOAD_START, default=DEFAULT_PAYLOAD_START): cv.string, vol.Optional(CONF_PAYLOAD_PAUSE, default=DEFAULT_PAYLOAD_PAUSE): cv.string, vol.Optional(CONF_PAYLOAD_STOP, default=DEFAULT_PAYLOAD_STOP): cv.string, - vol.Optional(CONF_SEND_COMMAND_TOPIC): mqtt.valid_publish_topic, - vol.Optional(CONF_SET_FAN_SPEED_TOPIC): mqtt.valid_publish_topic, - vol.Optional(CONF_STATE_TOPIC): mqtt.valid_publish_topic, + vol.Optional(CONF_SEND_COMMAND_TOPIC): valid_publish_topic, + vol.Optional(CONF_SET_FAN_SPEED_TOPIC): valid_publish_topic, + vol.Optional(CONF_STATE_TOPIC): valid_publish_topic, vol.Optional( CONF_SUPPORTED_FEATURES, default=DEFAULT_SERVICE_STRINGS ): vol.All(cv.ensure_list, [vol.In(STRING_TO_SERVICE.keys())]), - vol.Optional(CONF_COMMAND_TOPIC): mqtt.valid_publish_topic, + vol.Optional(CONF_COMMAND_TOPIC): valid_publish_topic, vol.Optional(CONF_RETAIN, default=DEFAULT_RETAIN): cv.boolean, } ) @@ -178,7 +179,7 @@ def _setup_from_config(self, config): supported_feature_strings, STRING_TO_SERVICE ) self._fan_speed_list = config[CONF_FAN_SPEED_LIST] - self._command_topic = config.get(mqtt.CONF_COMMAND_TOPIC) + self._command_topic = config.get(CONF_COMMAND_TOPIC) self._set_fan_speed_topic = config.get(CONF_SET_FAN_SPEED_TOPIC) self._send_command_topic = config.get(CONF_SEND_COMMAND_TOPIC) diff --git a/tests/components/mqtt/test_cover.py b/tests/components/mqtt/test_cover.py index 285af765ab46..e130b820c1b6 100644 --- a/tests/components/mqtt/test_cover.py +++ b/tests/components/mqtt/test_cover.py @@ -12,7 +12,7 @@ ATTR_POSITION, ATTR_TILT_POSITION, ) -from homeassistant.components.mqtt import CONF_STATE_TOPIC +from homeassistant.components.mqtt.const import CONF_STATE_TOPIC from homeassistant.components.mqtt.cover import ( CONF_GET_POSITION_TEMPLATE, CONF_GET_POSITION_TOPIC, diff --git a/tests/components/mqtt/test_init.py b/tests/components/mqtt/test_init.py index 07c39d70df0f..aa0bfb82608b 100644 --- a/tests/components/mqtt/test_init.py +++ b/tests/components/mqtt/test_init.py @@ -1123,7 +1123,7 @@ async def test_restore_subscriptions_on_reconnect(hass, mqtt_client_mock, mqtt_m assert mqtt_client_mock.subscribe.call_count == 1 mqtt_client_mock.on_disconnect(None, None, 0) - with patch("homeassistant.components.mqtt.DISCOVERY_COOLDOWN", 0): + with patch("homeassistant.components.mqtt.client.DISCOVERY_COOLDOWN", 0): mqtt_client_mock.on_connect(None, None, None, 0) await hass.async_block_till_done() assert mqtt_client_mock.subscribe.call_count == 2 @@ -1157,7 +1157,7 @@ async def test_restore_all_active_subscriptions_on_reconnect( assert mqtt_client_mock.unsubscribe.call_count == 0 mqtt_client_mock.on_disconnect(None, None, 0) - with patch("homeassistant.components.mqtt.DISCOVERY_COOLDOWN", 0): + with patch("homeassistant.components.mqtt.client.DISCOVERY_COOLDOWN", 0): mqtt_client_mock.on_connect(None, None, None, 0) await hass.async_block_till_done() @@ -1188,7 +1188,7 @@ async def test_logs_error_if_no_connect_broker( ) -@patch("homeassistant.components.mqtt.TIMEOUT_ACK", 0.3) +@patch("homeassistant.components.mqtt.client.TIMEOUT_ACK", 0.3) async def test_handle_mqtt_on_callback(hass, caplog, mqtt_mock, mqtt_client_mock): """Test receiving an ACK callback before waiting for it.""" # Simulate an ACK for mid == 1, this will call mqtt_mock._mqtt_handle_mid(mid) @@ -1331,7 +1331,7 @@ async def test_setup_mqtt_client_protocol(hass): """Test MQTT client protocol setup.""" entry = MockConfigEntry( domain=mqtt.DOMAIN, - data={mqtt.CONF_BROKER: "test-broker", mqtt.CONF_PROTOCOL: "3.1"}, + data={mqtt.CONF_BROKER: "test-broker", mqtt.config.CONF_PROTOCOL: "3.1"}, ) with patch("paho.mqtt.client.Client") as mock_client: mock_client.on_connect(return_value=0) @@ -1341,7 +1341,7 @@ async def test_setup_mqtt_client_protocol(hass): assert mock_client.call_args[1]["protocol"] == 3 -@patch("homeassistant.components.mqtt.TIMEOUT_ACK", 0.2) +@patch("homeassistant.components.mqtt.client.TIMEOUT_ACK", 0.2) async def test_handle_mqtt_timeout_on_callback(hass, caplog): """Test publish without receiving an ACK callback.""" mid = 0 @@ -1486,7 +1486,7 @@ async def wait_birth(topic, payload, qos): """Handle birth message.""" birth.set() - with patch("homeassistant.components.mqtt.DISCOVERY_COOLDOWN", 0.1): + with patch("homeassistant.components.mqtt.client.DISCOVERY_COOLDOWN", 0.1): await mqtt.async_subscribe(hass, "birth", wait_birth) mqtt_client_mock.on_connect(None, None, 0, 0) await hass.async_block_till_done() @@ -1516,7 +1516,7 @@ async def wait_birth(topic, payload, qos): """Handle birth message.""" birth.set() - with patch("homeassistant.components.mqtt.DISCOVERY_COOLDOWN", 0.1): + with patch("homeassistant.components.mqtt.client.DISCOVERY_COOLDOWN", 0.1): await mqtt.async_subscribe(hass, "homeassistant/status", wait_birth) mqtt_client_mock.on_connect(None, None, 0, 0) await hass.async_block_till_done() @@ -1532,7 +1532,7 @@ async def wait_birth(topic, payload, qos): ) async def test_no_birth_message(hass, mqtt_client_mock, mqtt_mock): """Test disabling birth message.""" - with patch("homeassistant.components.mqtt.DISCOVERY_COOLDOWN", 0.1): + with patch("homeassistant.components.mqtt.client.DISCOVERY_COOLDOWN", 0.1): mqtt_client_mock.on_connect(None, None, 0, 0) await hass.async_block_till_done() await asyncio.sleep(0.2) @@ -1580,7 +1580,7 @@ async def wait_birth(topic, payload, qos): """Handle birth message.""" birth.set() - with patch("homeassistant.components.mqtt.DISCOVERY_COOLDOWN", 0.1): + with patch("homeassistant.components.mqtt.client.DISCOVERY_COOLDOWN", 0.1): await mqtt.async_subscribe(hass, "homeassistant/status", wait_birth) mqtt_client_mock.on_connect(None, None, 0, 0) await hass.async_block_till_done() diff --git a/tests/components/mqtt/test_legacy_vacuum.py b/tests/components/mqtt/test_legacy_vacuum.py index f451079e0f04..b7a3b5f21187 100644 --- a/tests/components/mqtt/test_legacy_vacuum.py +++ b/tests/components/mqtt/test_legacy_vacuum.py @@ -6,7 +6,7 @@ import pytest from homeassistant.components import vacuum -from homeassistant.components.mqtt import CONF_COMMAND_TOPIC +from homeassistant.components.mqtt.const import CONF_COMMAND_TOPIC from homeassistant.components.mqtt.vacuum import schema_legacy as mqttvacuum from homeassistant.components.mqtt.vacuum.schema import services_to_strings from homeassistant.components.mqtt.vacuum.schema_legacy import ( diff --git a/tests/components/mqtt/test_state_vacuum.py b/tests/components/mqtt/test_state_vacuum.py index 3f752f1b5283..c1017446effd 100644 --- a/tests/components/mqtt/test_state_vacuum.py +++ b/tests/components/mqtt/test_state_vacuum.py @@ -6,7 +6,7 @@ import pytest from homeassistant.components import vacuum -from homeassistant.components.mqtt import CONF_COMMAND_TOPIC, CONF_STATE_TOPIC +from homeassistant.components.mqtt.const import CONF_COMMAND_TOPIC, CONF_STATE_TOPIC from homeassistant.components.mqtt.vacuum import CONF_SCHEMA, schema_state as mqttvacuum from homeassistant.components.mqtt.vacuum.const import MQTT_VACUUM_ATTRIBUTES_BLOCKED from homeassistant.components.mqtt.vacuum.schema import services_to_strings From 397c9e71d0e0698bbd06304c158a6530f6473cc2 Mon Sep 17 00:00:00 2001 From: Erik Date: Mon, 30 May 2022 15:43:09 +0200 Subject: [PATCH 2/2] Update integrations depending on MQTT --- homeassistant/components/manual_mqtt/alarm_control_panel.py | 2 +- homeassistant/components/mqtt/__init__.py | 6 +++++- homeassistant/components/mqtt_json/device_tracker.py | 2 +- homeassistant/components/mqtt_room/sensor.py | 2 +- 4 files changed, 8 insertions(+), 4 deletions(-) diff --git a/homeassistant/components/manual_mqtt/alarm_control_panel.py b/homeassistant/components/manual_mqtt/alarm_control_panel.py index 5b74af49a914..67675a44e22f 100644 --- a/homeassistant/components/manual_mqtt/alarm_control_panel.py +++ b/homeassistant/components/manual_mqtt/alarm_control_panel.py @@ -110,7 +110,7 @@ def _state_schema(state): PLATFORM_SCHEMA = vol.Schema( vol.All( - mqtt.MQTT_BASE_SCHEMA.extend( + mqtt.config.MQTT_BASE_SCHEMA.extend( { vol.Required(CONF_PLATFORM): "manual_mqtt", vol.Optional(CONF_NAME, default=DEFAULT_ALARM_NAME): cv.string, diff --git a/homeassistant/components/mqtt/__init__.py b/homeassistant/components/mqtt/__init__.py index 5cb2223a9ac0..e21885d2585a 100644 --- a/homeassistant/components/mqtt/__init__.py +++ b/homeassistant/components/mqtt/__init__.py @@ -41,14 +41,17 @@ subscribe, ) from .config import CONFIG_SCHEMA_BASE, DEFAULT_VALUES, DEPRECATED_CONFIG_KEYS -from .const import ( +from .const import ( # noqa: F401 ATTR_PAYLOAD, ATTR_QOS, ATTR_RETAIN, ATTR_TOPIC, CONF_BIRTH_MESSAGE, CONF_BROKER, + CONF_COMMAND_TOPIC, CONF_DISCOVERY_PREFIX, + CONF_QOS, + CONF_STATE_TOPIC, CONF_TLS_VERSION, CONF_TOPIC, CONF_WILL_MESSAGE, @@ -68,6 +71,7 @@ from .models import ( # noqa: F401 MqttCommandTemplate, MqttValueTemplate, + PublishPayloadType, ReceiveMessage, ReceivePayloadType, ) diff --git a/homeassistant/components/mqtt_json/device_tracker.py b/homeassistant/components/mqtt_json/device_tracker.py index 505fa3bd8092..1d99e6d7b6f8 100644 --- a/homeassistant/components/mqtt_json/device_tracker.py +++ b/homeassistant/components/mqtt_json/device_tracker.py @@ -35,7 +35,7 @@ extra=vol.ALLOW_EXTRA, ) -PLATFORM_SCHEMA = PARENT_PLATFORM_SCHEMA.extend(mqtt.SCHEMA_BASE).extend( +PLATFORM_SCHEMA = PARENT_PLATFORM_SCHEMA.extend(mqtt.config.SCHEMA_BASE).extend( {vol.Required(CONF_DEVICES): {cv.string: mqtt.valid_subscribe_topic}} ) diff --git a/homeassistant/components/mqtt_room/sensor.py b/homeassistant/components/mqtt_room/sensor.py index 54de561c11e7..276695d8edde 100644 --- a/homeassistant/components/mqtt_room/sensor.py +++ b/homeassistant/components/mqtt_room/sensor.py @@ -43,7 +43,7 @@ vol.Optional(CONF_AWAY_TIMEOUT, default=DEFAULT_AWAY_TIMEOUT): cv.positive_int, vol.Optional(CONF_NAME, default=DEFAULT_NAME): cv.string, } -).extend(mqtt.MQTT_RO_SCHEMA.schema) +).extend(mqtt.config.MQTT_RO_SCHEMA.schema) MQTT_PAYLOAD = vol.Schema( vol.All(