Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Speed up reconnects by caching state serialize #93050

Merged
merged 2 commits into from
May 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
122 changes: 75 additions & 47 deletions homeassistant/components/websocket_api/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from __future__ import annotations

from collections.abc import Callable
from contextlib import suppress
import datetime as dt
from functools import lru_cache
import json
Expand Down Expand Up @@ -50,6 +49,17 @@
from .connection import ActiveConnection
from .const import ERR_NOT_FOUND

_STATES_TEMPLATE = "__STATES__"
_STATES_JSON_TEMPLATE = '"__STATES__"'
_HANDLE_SUBSCRIBE_ENTITIES_TEMPLATE = JSON_DUMP(
messages.event_message(
messages.IDEN_TEMPLATE, {messages.ENTITY_EVENT_ADD: _STATES_TEMPLATE}
)
)
_HANDLE_GET_STATES_TEMPLATE = JSON_DUMP(
messages.result_message(messages.IDEN_TEMPLATE, _STATES_TEMPLATE)
)


@callback
def async_register_commands(
Expand Down Expand Up @@ -242,33 +252,43 @@ def handle_get_states(
"""Handle get states command."""
states = _async_get_allowed_states(hass, connection)

# JSON serialize here so we can recover if it blows up due to the
# state machine containing unserializable data. This command is required
# to succeed for the UI to show.
response = messages.result_message(msg["id"], states)
try:
connection.send_message(JSON_DUMP(response))
return
serialized_states = [state.as_dict_json() for state in states]
except (ValueError, TypeError):
connection.logger.error(
"Unable to serialize to JSON. Bad data found at %s",
format_unserializable_data(
find_paths_unserializable_data(response, dump=JSON_DUMP)
),
)
del response
pass
else:
_send_handle_get_states_response(connection, msg["id"], serialized_states)
return

# If we can't serialize, we'll filter out unserializable states
serialized = []
serialized_states = []
for state in states:
# Error is already logged above
with suppress(ValueError, TypeError):
serialized.append(JSON_DUMP(state))
try:
serialized_states.append(state.as_dict_json())
except (ValueError, TypeError):
connection.logger.error(
"Unable to serialize to JSON. Bad data found at %s",
format_unserializable_data(
find_paths_unserializable_data(state, dump=JSON_DUMP)
),
)

_send_handle_get_states_response(connection, msg["id"], serialized_states)

# We now have partially serialized states. Craft some JSON.
response2 = JSON_DUMP(messages.result_message(msg["id"], ["TO_REPLACE"]))
response2 = response2.replace('"TO_REPLACE"', ", ".join(serialized))
connection.send_message(response2)

def _send_handle_get_states_response(
connection: ActiveConnection, msg_id: int, serialized_states: list[str]
bdraco marked this conversation as resolved.
Show resolved Hide resolved
) -> None:
"""Send handle get states response."""
connection.send_message(
_HANDLE_GET_STATES_TEMPLATE.replace(
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In a future PR I think we can make this a bit more efficient by making a function and can construct event and result messages without replace. I'm going to explore that and if successful will do a PR that does that for every place we do this. I am not 100% sure it will work, just an idea at this point

messages.IDEN_JSON_TEMPLATE, str(msg_id), 1
).replace(
_STATES_JSON_TEMPLATE,
"[" + ",".join(serialized_states) + "]",
1,
)
)


@callback
Expand Down Expand Up @@ -304,42 +324,50 @@ def forward_entity_changes(event: Event) -> None:
EVENT_STATE_CHANGED, forward_entity_changes, run_immediately=True
)
connection.send_result(msg["id"])
data: dict[str, dict[str, dict]] = {
messages.ENTITY_EVENT_ADD: {
state.entity_id: state.as_compressed_state()
for state in states
if not entity_ids or state.entity_id in entity_ids
}
}

# JSON serialize here so we can recover if it blows up due to the
# state machine containing unserializable data. This command is required
# to succeed for the UI to show.
response = messages.event_message(msg["id"], data)
try:
connection.send_message(JSON_DUMP(response))
return
serialized_states = [
state.as_compressed_state_json()
for state in states
if not entity_ids or state.entity_id in entity_ids
]
except (ValueError, TypeError):
connection.logger.error(
"Unable to serialize to JSON. Bad data found at %s",
format_unserializable_data(
find_paths_unserializable_data(response, dump=JSON_DUMP)
),
)
del response
pass
else:
_send_handle_entities_init_response(connection, msg["id"], serialized_states)
return

add_entities = data[messages.ENTITY_EVENT_ADD]
cannot_serialize: list[str] = []
for entity_id, state_dict in add_entities.items():
serialized_states = []
for state in states:
try:
JSON_DUMP(state_dict)
serialized_states.append(state.as_compressed_state_json())
except (ValueError, TypeError):
cannot_serialize.append(entity_id)
connection.logger.error(
"Unable to serialize to JSON. Bad data found at %s",
format_unserializable_data(
find_paths_unserializable_data(state, dump=JSON_DUMP)
),
)

for entity_id in cannot_serialize:
del add_entities[entity_id]
_send_handle_entities_init_response(connection, msg["id"], serialized_states)

connection.send_message(JSON_DUMP(messages.event_message(msg["id"], data)))

def _send_handle_entities_init_response(
connection: ActiveConnection, msg_id: int, serialized_states: list[str]
) -> None:
"""Send handle entities init response."""
connection.send_message(
_HANDLE_SUBSCRIBE_ENTITIES_TEMPLATE.replace(
messages.IDEN_JSON_TEMPLATE, str(msg_id), 1
).replace(
_STATES_JSON_TEMPLATE,
"{" + ",".join(serialized_states) + "}",
1,
)
)


@decorators.websocket_command({vol.Required("type"): "get_services"})
Expand Down
2 changes: 1 addition & 1 deletion homeassistant/components/websocket_api/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
ENTITY_EVENT_CHANGE = "c"


def result_message(iden: int, result: Any = None) -> dict[str, Any]:
def result_message(iden: JSON_TYPE | int, result: Any = None) -> dict[str, Any]:
"""Return a success result message."""
return {"id": iden, "type": const.TYPE_RESULT, "success": True, "result": result}

Expand Down
24 changes: 24 additions & 0 deletions homeassistant/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@
Unauthorized,
)
from .helpers.aiohttp_compat import restore_original_aiohttp_cancel_behavior
from .helpers.json import json_dumps
from .util import dt as dt_util, location, ulid as ulid_util
from .util.async_ import run_callback_threadsafe, shutdown_run_callback_threadsafe
from .util.read_only_dict import ReadOnlyDict
Expand Down Expand Up @@ -1224,6 +1225,8 @@ class State:
"object_id",
"_as_dict",
"_as_compressed_state",
"_as_dict_json",
"_as_compressed_state_json",
)

def __init__(
Expand Down Expand Up @@ -1260,6 +1263,8 @@ def __init__(
self.domain, self.object_id = split_entity_id(self.entity_id)
self._as_dict: ReadOnlyDict[str, Collection[Any]] | None = None
self._as_compressed_state: dict[str, Any] | None = None
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we can drop the unserialized cache in a future PR to save a small bit of RAM but need to verify it's always covered by the serialized cache.

self._as_dict_json: str | None = None
self._as_compressed_state_json: str | None = None

@property
def name(self) -> str:
Expand Down Expand Up @@ -1294,6 +1299,12 @@ def as_dict(self) -> ReadOnlyDict[str, Collection[Any]]:
)
return self._as_dict

def as_dict_json(self) -> str:
"""Return a JSON string of the State."""
if not self._as_dict_json:
self._as_dict_json = json_dumps(self.as_dict())
return self._as_dict_json

def as_compressed_state(self) -> dict[str, Any]:
"""Build a compressed dict of a state for adds.

Expand Down Expand Up @@ -1321,6 +1332,19 @@ def as_compressed_state(self) -> dict[str, Any]:
self._as_compressed_state = compressed_state
return compressed_state

def as_compressed_state_json(self) -> str:
"""Build a compressed JSON key value pair of a state for adds.

The JSON string is a key value pair of the entity_id and the compressed state.

It is used for sending multiple states in a single message.
"""
if not self._as_compressed_state_json:
self._as_compressed_state_json = json_dumps(
{self.entity_id: self.as_compressed_state()}
)[1:-1]
return self._as_compressed_state_json

@classmethod
def from_dict(cls, json_dict: dict[str, Any]) -> Self | None:
"""Initialize a state from a dict.
Expand Down
6 changes: 5 additions & 1 deletion homeassistant/helpers/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@

import orjson

from homeassistant.core import Event, State
from homeassistant.util.file import write_utf8_file, write_utf8_file_atomic
from homeassistant.util.json import ( # pylint: disable=unused-import # noqa: F401
JSON_DECODE_EXCEPTIONS,
Expand Down Expand Up @@ -189,6 +188,11 @@ def find_paths_unserializable_data(

This method is slow! Only use for error handling.
"""
from homeassistant.core import ( # pylint: disable=import-outside-toplevel
Event,
State,
)

to_process = deque([(bad_data, "$")])
invalid = {}

Expand Down
7 changes: 3 additions & 4 deletions tests/components/websocket_api/test_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,10 +187,9 @@ async def test_non_json_message(
assert msg["type"] == const.TYPE_RESULT
assert msg["success"]
assert msg["result"] == []
assert (
f"Unable to serialize to JSON. Bad data found at $.result[0](State: test_domain.entity).attributes.bad={bad_data}(<class 'object'>"
in caplog.text
)
assert "Unable to serialize to JSON. Bad data found" in caplog.text
assert "State: test_domain.entity" in caplog.text
assert "bad=<object" in caplog.text


async def test_prepare_fail(
Expand Down
44 changes: 44 additions & 0 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,6 +466,29 @@ def test_state_as_dict() -> None:
assert state.as_dict() is as_dict_1


def test_state_as_dict_json() -> None:
"""Test a State as JSON."""
last_time = datetime(1984, 12, 8, 12, 0, 0)
state = ha.State(
"happy.happy",
"on",
{"pig": "dog"},
last_updated=last_time,
last_changed=last_time,
context=ha.Context(id="01H0D6K3RFJAYAV2093ZW30PCW"),
)
expected = (
'{"entity_id":"happy.happy","state":"on","attributes":{"pig":"dog"},'
'"last_changed":"1984-12-08T12:00:00","last_updated":"1984-12-08T12:00:00",'
'"context":{"id":"01H0D6K3RFJAYAV2093ZW30PCW","parent_id":null,"user_id":null}}'
)
as_dict_json_1 = state.as_dict_json()
assert as_dict_json_1 == expected
# 2nd time to verify cache
assert state.as_dict_json() == expected
assert state.as_dict_json() is as_dict_json_1


def test_state_as_compressed_state() -> None:
"""Test a State as compressed state."""
last_time = datetime(1984, 12, 8, 12, 0, 0, tzinfo=dt_util.UTC)
Expand Down Expand Up @@ -518,6 +541,27 @@ def test_state_as_compressed_state_unique_last_updated() -> None:
assert state.as_compressed_state() is as_compressed_state


def test_state_as_compressed_state_json() -> None:
"""Test a State as a JSON compressed state."""
last_time = datetime(1984, 12, 8, 12, 0, 0, tzinfo=dt_util.UTC)
state = ha.State(
"happy.happy",
"on",
{"pig": "dog"},
last_updated=last_time,
last_changed=last_time,
context=ha.Context(id="01H0D6H5K3SZJ3XGDHED1TJ79N"),
)
expected = '"happy.happy":{"s":"on","a":{"pig":"dog"},"c":"01H0D6H5K3SZJ3XGDHED1TJ79N","lc":471355200.0}'
as_compressed_state = state.as_compressed_state_json()
# We are not too concerned about these being ReadOnlyDict
# since we don't expect them to be called by external callers
assert as_compressed_state == expected
# 2nd time to verify cache
assert state.as_compressed_state_json() == expected
assert state.as_compressed_state_json() is as_compressed_state


async def test_eventbus_add_remove_listener(hass: HomeAssistant) -> None:
"""Test remove_listener method."""
old_count = len(hass.bus.async_listeners())
Expand Down