Skip to content

Commit

Permalink
Avoid generating matchers that will never be used in MQTT (#118068)
Browse files Browse the repository at this point in the history
  • Loading branch information
bdraco authored May 25, 2024
1 parent fa1ef8b commit 65a7027
Showing 1 changed file with 13 additions and 12 deletions.
25 changes: 13 additions & 12 deletions homeassistant/components/mqtt/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,12 +238,13 @@ def remove() -> None:
return remove


@dataclass(frozen=True)
@dataclass(slots=True, frozen=True)
class Subscription:
"""Class to hold data about an active subscription."""

topic: str
matcher: Any
is_simple_match: bool
complex_matcher: Callable[[str], bool] | None
job: HassJob[[ReceiveMessage], Coroutine[Any, Any, None] | None]
qos: int = 0
encoding: str | None = "utf-8"
Expand Down Expand Up @@ -312,11 +313,6 @@ def client(self) -> mqtt.Client:
return self._client


def _is_simple_match(topic: str) -> bool:
"""Return if a topic is a simple match."""
return not ("+" in topic or "#" in topic)


class EnsureJobAfterCooldown:
"""Ensure a cool down period before executing a job.
Expand Down Expand Up @@ -788,7 +784,7 @@ def _async_track_subscription(self, subscription: Subscription) -> None:
The caller is responsible clearing the cache of _matching_subscriptions.
"""
if _is_simple_match(subscription.topic):
if subscription.is_simple_match:
self._simple_subscriptions.setdefault(subscription.topic, []).append(
subscription
)
Expand All @@ -805,7 +801,7 @@ def _async_untrack_subscription(self, subscription: Subscription) -> None:
"""
topic = subscription.topic
try:
if _is_simple_match(topic):
if subscription.is_simple_match:
simple_subscriptions = self._simple_subscriptions
simple_subscriptions[topic].remove(subscription)
if not simple_subscriptions[topic]:
Expand Down Expand Up @@ -846,8 +842,11 @@ async def async_subscribe(
if not isinstance(topic, str):
raise HomeAssistantError("Topic needs to be a string!")

is_simple_match = not ("+" in topic or "#" in topic)
matcher = None if is_simple_match else _matcher_for_topic(topic)

subscription = Subscription(
topic, _matcher_for_topic(topic), HassJob(msg_callback), qos, encoding
topic, is_simple_match, matcher, HassJob(msg_callback), qos, encoding
)
self._async_track_subscription(subscription)
self._matching_subscriptions.cache_clear()
Expand Down Expand Up @@ -1053,7 +1052,9 @@ def _matching_subscriptions(self, topic: str) -> list[Subscription]:
subscriptions.extend(
subscription
for subscription in self._wildcard_subscriptions
if subscription.matcher(topic)
# mypy doesn't know that complex_matcher is always set when
# is_simple_match is False
if subscription.complex_matcher(topic) # type: ignore[misc]
)
return subscriptions

Expand Down Expand Up @@ -1241,7 +1242,7 @@ def _raise_on_error(result_code: int) -> None:
raise HomeAssistantError(f"Error talking to MQTT: {message}")


def _matcher_for_topic(subscription: str) -> Any:
def _matcher_for_topic(subscription: str) -> Callable[[str], bool]:
# pylint: disable-next=import-outside-toplevel
from paho.mqtt.matcher import MQTTMatcher

Expand Down

0 comments on commit 65a7027

Please sign in to comment.