Skip to content

Commit

Permalink
Speed up reconnects by caching state serialize (#93050)
Browse files Browse the repository at this point in the history
  • Loading branch information
bdraco committed May 16, 2023
1 parent 9c039a1 commit 99265a9
Show file tree
Hide file tree
Showing 6 changed files with 152 additions and 53 deletions.
122 changes: 75 additions & 47 deletions homeassistant/components/websocket_api/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from __future__ import annotations

from collections.abc import Callable
from contextlib import suppress
import datetime as dt
from functools import lru_cache
import json
Expand Down Expand Up @@ -50,6 +49,17 @@
from .connection import ActiveConnection
from .const import ERR_NOT_FOUND

_STATES_TEMPLATE = "__STATES__"
_STATES_JSON_TEMPLATE = '"__STATES__"'
_HANDLE_SUBSCRIBE_ENTITIES_TEMPLATE = JSON_DUMP(
messages.event_message(
messages.IDEN_TEMPLATE, {messages.ENTITY_EVENT_ADD: _STATES_TEMPLATE}
)
)
_HANDLE_GET_STATES_TEMPLATE = JSON_DUMP(
messages.result_message(messages.IDEN_TEMPLATE, _STATES_TEMPLATE)
)


@callback
def async_register_commands(
Expand Down Expand Up @@ -242,33 +252,43 @@ def handle_get_states(
"""Handle get states command."""
states = _async_get_allowed_states(hass, connection)

# JSON serialize here so we can recover if it blows up due to the
# state machine containing unserializable data. This command is required
# to succeed for the UI to show.
response = messages.result_message(msg["id"], states)
try:
connection.send_message(JSON_DUMP(response))
return
serialized_states = [state.as_dict_json() for state in states]
except (ValueError, TypeError):
connection.logger.error(
"Unable to serialize to JSON. Bad data found at %s",
format_unserializable_data(
find_paths_unserializable_data(response, dump=JSON_DUMP)
),
)
del response
pass
else:
_send_handle_get_states_response(connection, msg["id"], serialized_states)
return

# If we can't serialize, we'll filter out unserializable states
serialized = []
serialized_states = []
for state in states:
# Error is already logged above
with suppress(ValueError, TypeError):
serialized.append(JSON_DUMP(state))
try:
serialized_states.append(state.as_dict_json())
except (ValueError, TypeError):
connection.logger.error(
"Unable to serialize to JSON. Bad data found at %s",
format_unserializable_data(
find_paths_unserializable_data(state, dump=JSON_DUMP)
),
)

_send_handle_get_states_response(connection, msg["id"], serialized_states)

# We now have partially serialized states. Craft some JSON.
response2 = JSON_DUMP(messages.result_message(msg["id"], ["TO_REPLACE"]))
response2 = response2.replace('"TO_REPLACE"', ", ".join(serialized))
connection.send_message(response2)

def _send_handle_get_states_response(
connection: ActiveConnection, msg_id: int, serialized_states: list[str]
) -> None:
"""Send handle get states response."""
connection.send_message(
_HANDLE_GET_STATES_TEMPLATE.replace(
messages.IDEN_JSON_TEMPLATE, str(msg_id), 1
).replace(
_STATES_JSON_TEMPLATE,
"[" + ",".join(serialized_states) + "]",
1,
)
)


@callback
Expand Down Expand Up @@ -304,42 +324,50 @@ def forward_entity_changes(event: Event) -> None:
EVENT_STATE_CHANGED, forward_entity_changes, run_immediately=True
)
connection.send_result(msg["id"])
data: dict[str, dict[str, dict]] = {
messages.ENTITY_EVENT_ADD: {
state.entity_id: state.as_compressed_state()
for state in states
if not entity_ids or state.entity_id in entity_ids
}
}

# JSON serialize here so we can recover if it blows up due to the
# state machine containing unserializable data. This command is required
# to succeed for the UI to show.
response = messages.event_message(msg["id"], data)
try:
connection.send_message(JSON_DUMP(response))
return
serialized_states = [
state.as_compressed_state_json()
for state in states
if not entity_ids or state.entity_id in entity_ids
]
except (ValueError, TypeError):
connection.logger.error(
"Unable to serialize to JSON. Bad data found at %s",
format_unserializable_data(
find_paths_unserializable_data(response, dump=JSON_DUMP)
),
)
del response
pass
else:
_send_handle_entities_init_response(connection, msg["id"], serialized_states)
return

add_entities = data[messages.ENTITY_EVENT_ADD]
cannot_serialize: list[str] = []
for entity_id, state_dict in add_entities.items():
serialized_states = []
for state in states:
try:
JSON_DUMP(state_dict)
serialized_states.append(state.as_compressed_state_json())
except (ValueError, TypeError):
cannot_serialize.append(entity_id)
connection.logger.error(
"Unable to serialize to JSON. Bad data found at %s",
format_unserializable_data(
find_paths_unserializable_data(state, dump=JSON_DUMP)
),
)

for entity_id in cannot_serialize:
del add_entities[entity_id]
_send_handle_entities_init_response(connection, msg["id"], serialized_states)

connection.send_message(JSON_DUMP(messages.event_message(msg["id"], data)))

def _send_handle_entities_init_response(
connection: ActiveConnection, msg_id: int, serialized_states: list[str]
) -> None:
"""Send handle entities init response."""
connection.send_message(
_HANDLE_SUBSCRIBE_ENTITIES_TEMPLATE.replace(
messages.IDEN_JSON_TEMPLATE, str(msg_id), 1
).replace(
_STATES_JSON_TEMPLATE,
"{" + ",".join(serialized_states) + "}",
1,
)
)


@decorators.websocket_command({vol.Required("type"): "get_services"})
Expand Down
2 changes: 1 addition & 1 deletion homeassistant/components/websocket_api/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
ENTITY_EVENT_CHANGE = "c"


def result_message(iden: int, result: Any = None) -> dict[str, Any]:
def result_message(iden: JSON_TYPE | int, result: Any = None) -> dict[str, Any]:
"""Return a success result message."""
return {"id": iden, "type": const.TYPE_RESULT, "success": True, "result": result}

Expand Down
24 changes: 24 additions & 0 deletions homeassistant/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@
Unauthorized,
)
from .helpers.aiohttp_compat import restore_original_aiohttp_cancel_behavior
from .helpers.json import json_dumps
from .util import dt as dt_util, location, ulid as ulid_util
from .util.async_ import run_callback_threadsafe, shutdown_run_callback_threadsafe
from .util.read_only_dict import ReadOnlyDict
Expand Down Expand Up @@ -1224,6 +1225,8 @@ class State:
"object_id",
"_as_dict",
"_as_compressed_state",
"_as_dict_json",
"_as_compressed_state_json",
)

def __init__(
Expand Down Expand Up @@ -1260,6 +1263,8 @@ def __init__(
self.domain, self.object_id = split_entity_id(self.entity_id)
self._as_dict: ReadOnlyDict[str, Collection[Any]] | None = None
self._as_compressed_state: dict[str, Any] | None = None
self._as_dict_json: str | None = None
self._as_compressed_state_json: str | None = None

@property
def name(self) -> str:
Expand Down Expand Up @@ -1294,6 +1299,12 @@ def as_dict(self) -> ReadOnlyDict[str, Collection[Any]]:
)
return self._as_dict

def as_dict_json(self) -> str:
"""Return a JSON string of the State."""
if not self._as_dict_json:
self._as_dict_json = json_dumps(self.as_dict())
return self._as_dict_json

def as_compressed_state(self) -> dict[str, Any]:
"""Build a compressed dict of a state for adds.
Expand Down Expand Up @@ -1321,6 +1332,19 @@ def as_compressed_state(self) -> dict[str, Any]:
self._as_compressed_state = compressed_state
return compressed_state

def as_compressed_state_json(self) -> str:
"""Build a compressed JSON key value pair of a state for adds.
The JSON string is a key value pair of the entity_id and the compressed state.
It is used for sending multiple states in a single message.
"""
if not self._as_compressed_state_json:
self._as_compressed_state_json = json_dumps(
{self.entity_id: self.as_compressed_state()}
)[1:-1]
return self._as_compressed_state_json

@classmethod
def from_dict(cls, json_dict: dict[str, Any]) -> Self | None:
"""Initialize a state from a dict.
Expand Down
6 changes: 5 additions & 1 deletion homeassistant/helpers/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@

import orjson

from homeassistant.core import Event, State
from homeassistant.util.file import write_utf8_file, write_utf8_file_atomic
from homeassistant.util.json import ( # pylint: disable=unused-import # noqa: F401
JSON_DECODE_EXCEPTIONS,
Expand Down Expand Up @@ -189,6 +188,11 @@ def find_paths_unserializable_data(
This method is slow! Only use for error handling.
"""
from homeassistant.core import ( # pylint: disable=import-outside-toplevel
Event,
State,
)

to_process = deque([(bad_data, "$")])
invalid = {}

Expand Down
7 changes: 3 additions & 4 deletions tests/components/websocket_api/test_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,10 +188,9 @@ async def test_non_json_message(
assert msg["type"] == const.TYPE_RESULT
assert msg["success"]
assert msg["result"] == []
assert (
f"Unable to serialize to JSON. Bad data found at $.result[0](State: test_domain.entity).attributes.bad={bad_data}(<class 'object'>"
in caplog.text
)
assert "Unable to serialize to JSON. Bad data found" in caplog.text
assert "State: test_domain.entity" in caplog.text
assert "bad=<object" in caplog.text


async def test_prepare_fail(
Expand Down
44 changes: 44 additions & 0 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,6 +466,29 @@ def test_state_as_dict() -> None:
assert state.as_dict() is as_dict_1


def test_state_as_dict_json() -> None:
"""Test a State as JSON."""
last_time = datetime(1984, 12, 8, 12, 0, 0)
state = ha.State(
"happy.happy",
"on",
{"pig": "dog"},
last_updated=last_time,
last_changed=last_time,
context=ha.Context(id="01H0D6K3RFJAYAV2093ZW30PCW"),
)
expected = (
'{"entity_id":"happy.happy","state":"on","attributes":{"pig":"dog"},'
'"last_changed":"1984-12-08T12:00:00","last_updated":"1984-12-08T12:00:00",'
'"context":{"id":"01H0D6K3RFJAYAV2093ZW30PCW","parent_id":null,"user_id":null}}'
)
as_dict_json_1 = state.as_dict_json()
assert as_dict_json_1 == expected
# 2nd time to verify cache
assert state.as_dict_json() == expected
assert state.as_dict_json() is as_dict_json_1


def test_state_as_compressed_state() -> None:
"""Test a State as compressed state."""
last_time = datetime(1984, 12, 8, 12, 0, 0, tzinfo=dt_util.UTC)
Expand Down Expand Up @@ -518,6 +541,27 @@ def test_state_as_compressed_state_unique_last_updated() -> None:
assert state.as_compressed_state() is as_compressed_state


def test_state_as_compressed_state_json() -> None:
"""Test a State as a JSON compressed state."""
last_time = datetime(1984, 12, 8, 12, 0, 0, tzinfo=dt_util.UTC)
state = ha.State(
"happy.happy",
"on",
{"pig": "dog"},
last_updated=last_time,
last_changed=last_time,
context=ha.Context(id="01H0D6H5K3SZJ3XGDHED1TJ79N"),
)
expected = '"happy.happy":{"s":"on","a":{"pig":"dog"},"c":"01H0D6H5K3SZJ3XGDHED1TJ79N","lc":471355200.0}'
as_compressed_state = state.as_compressed_state_json()
# We are not too concerned about these being ReadOnlyDict
# since we don't expect them to be called by external callers
assert as_compressed_state == expected
# 2nd time to verify cache
assert state.as_compressed_state_json() == expected
assert state.as_compressed_state_json() is as_compressed_state


async def test_eventbus_add_remove_listener(hass: HomeAssistant) -> None:
"""Test remove_listener method."""
old_count = len(hass.bus.async_listeners())
Expand Down

0 comments on commit 99265a9

Please sign in to comment.