Add WS API to adjust incorrect energy statistics #65147

Merged · 6 commits · Mar 22, 2022
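In short, the new recorder/adjust_sum_statistics WebSocket command lets a client shift the stored sum of a statistic from a given hour onward, which is what the energy dashboard needs in order to correct a bad meter reading. A minimal sketch of the message a client would send, with an illustrative statistic_id, timestamp and adjustment (field names follow the schema added in websocket_api.py below):

# Illustrative payload for the new command; the statistic_id, timestamp and
# adjustment are examples only, and "id" is just the WebSocket message id.
message = {
    "id": 1,
    "type": "recorder/adjust_sum_statistics",
    "statistic_id": "sensor.energy_meter",
    "start_time": "2022-03-01T10:00:00+00:00",
    "adjustment": 12.5,  # added to every stored sum at or after start_time
}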
30 changes: 30 additions & 0 deletions homeassistant/components/recorder/__init__.py
@@ -462,6 +462,31 @@ def run(self, instance: Recorder) -> None:
instance.queue.put(ExternalStatisticsTask(self.metadata, self.statistics))


@dataclass
class AdjustStatisticsTask(RecorderTask):
"""An object to insert into the recorder queue to run an adjust statistics task."""

statistic_id: str
start_time: datetime
sum_adjustment: float

def run(self, instance: Recorder) -> None:
"""Run statistics task."""
if statistics.adjust_statistics(
instance,
self.statistic_id,
self.start_time,
self.sum_adjustment,
):
return
# Schedule a new adjust statistics task if this one didn't finish
instance.queue.put(
AdjustStatisticsTask(
self.statistic_id, self.start_time, self.sum_adjustment
)
)


@dataclass
class WaitTask(RecorderTask):
"""An object to insert into the recorder queue to tell it set the _queue_watch event."""
@@ -761,6 +786,11 @@ def async_periodic_statistics(self, now):
start = statistics.get_start_time()
self.queue.put(StatisticsTask(start))

@callback
def async_adjust_statistics(self, statistic_id, start_time, sum_adjustment):
"""Adjust statistics."""
self.queue.put(AdjustStatisticsTask(statistic_id, start_time, sum_adjustment))

@callback
def async_clear_statistics(self, statistic_ids):
"""Clear statistics for a list of statistic_ids."""
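The new AdjustStatisticsTask follows the recorder's usual queue pattern: a task is a small dataclass whose run() either completes or puts a fresh copy of itself back on the queue so the adjustment is retried after the next database retry window. Below is a minimal standalone sketch of that re-enqueue loop, with a toy FakeRecorder, AdjustTask and flaky_db_job standing in for the real recorder, task and database call:

from dataclasses import dataclass
from queue import SimpleQueue


@dataclass
class FakeRecorder:
    """Toy stand-in for the recorder; only the task queue matters here."""

    queue: SimpleQueue


def flaky_db_job(attempt: int) -> bool:
    """Pretend the database write only succeeds on the second attempt."""
    return attempt > 0


@dataclass
class AdjustTask:
    """Toy stand-in for AdjustStatisticsTask."""

    attempt: int = 0

    def run(self, instance: FakeRecorder) -> None:
        if flaky_db_job(self.attempt):
            print(f"adjustment applied on attempt {self.attempt}")
            return
        # Not done yet: schedule a new task so the work is retried later.
        instance.queue.put(AdjustTask(self.attempt + 1))


recorder = FakeRecorder(queue=SimpleQueue())
recorder.queue.put(AdjustTask())
while not recorder.queue.empty():
    recorder.queue.get().run(recorder)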
69 changes: 67 additions & 2 deletions homeassistant/components/recorder/statistics.py
@@ -19,6 +19,7 @@
from sqlalchemy.ext import baked
from sqlalchemy.orm.scoping import scoped_session
from sqlalchemy.sql.expression import literal_column, true
import voluptuous as vol

from homeassistant.const import (
PRESSURE_PA,
@@ -163,6 +164,14 @@ def valid_statistic_id(statistic_id: str) -> bool:
return VALID_STATISTIC_ID.match(statistic_id) is not None


def validate_statistic_id(value: str) -> str:
"""Validate statistic ID."""
if valid_statistic_id(value):
return value

raise vol.Invalid(f"Statistics ID {value} is an invalid statistic ID")


@dataclasses.dataclass
class ValidationIssue:
"""Error or warning message."""
@@ -567,6 +576,30 @@ def compile_statistics(instance: Recorder, start: datetime) -> bool:
return True


def _adjust_sum_statistics(
session: scoped_session,
table: type[Statistics | StatisticsShortTerm],
metadata_id: int,
start_time: datetime,
adj: float,
) -> None:
"""Adjust statistics in the database."""
try:
session.query(table).filter_by(metadata_id=metadata_id).filter(
table.start >= start_time
).update(
{
table.sum: table.sum + adj,
},
synchronize_session=False,
)
except SQLAlchemyError:
_LOGGER.exception(
"Unexpected exception when updating statistics %s",
            metadata_id,
)


def _insert_statistics(
session: scoped_session,
table: type[Statistics | StatisticsShortTerm],
@@ -606,7 +639,7 @@ def _update_statistics(
except SQLAlchemyError:
_LOGGER.exception(
"Unexpected exception when updating statistics %s:%s ",
id,
stat_id,
statistic,
)

@@ -1249,7 +1282,7 @@ def add_external_statistics(
metadata: StatisticMetaData,
statistics: Iterable[StatisticData],
) -> bool:
"""Process an add_statistics job."""
"""Process an add_external_statistics job."""

with session_scope(
session=instance.get_session(), # type: ignore[misc]
@@ -1265,3 +1298,35 @@ def add_external_statistics(
_insert_statistics(session, Statistics, metadata_id, stat)

return True


@retryable_database_job("adjust_statistics")
def adjust_statistics(
instance: Recorder,
statistic_id: str,
start_time: datetime,
sum_adjustment: float,
) -> bool:
"""Process an add_statistics job."""

with session_scope(session=instance.get_session()) as session: # type: ignore[misc]
metadata = get_metadata_with_session(
instance.hass, session, statistic_ids=(statistic_id,)
)
if statistic_id not in metadata:
return True

tables: tuple[type[Statistics | StatisticsShortTerm], ...] = (
Statistics,
StatisticsShortTerm,
)
for table in tables:
_adjust_sum_statistics(
session,
table,
metadata[statistic_id][0],
start_time,
sum_adjustment,
)

return True
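adjust_statistics looks up the statistic's metadata_id and then applies the same correction to both the hourly Statistics table and the short-term StatisticsShortTerm table; _adjust_sum_statistics does this as one bulk UPDATE that adds the offset to every sum at or after start_time. A minimal standalone sketch of that bulk-update pattern against a toy SQLAlchemy table (ToyStatistics and its columns are illustrative, not the recorder's real schema):

from datetime import datetime

from sqlalchemy import Column, DateTime, Float, Integer, create_engine
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()


class ToyStatistics(Base):
    """Toy table with just the columns the bulk update touches."""

    __tablename__ = "toy_statistics"
    id = Column(Integer, primary_key=True)
    metadata_id = Column(Integer)
    start = Column(DateTime)
    sum = Column(Float)


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add_all(
        ToyStatistics(metadata_id=1, start=datetime(2022, 1, 1, h), sum=float(h))
        for h in range(4)
    )
    # Shift every sum from 02:00 onward by +1000 in a single UPDATE statement.
    session.query(ToyStatistics).filter_by(metadata_id=1).filter(
        ToyStatistics.start >= datetime(2022, 1, 1, 2)
    ).update(
        {ToyStatistics.sum: ToyStatistics.sum + 1000.0},
        synchronize_session=False,
    )
    session.commit()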
30 changes: 30 additions & 0 deletions homeassistant/components/recorder/websocket_api.py
@@ -8,6 +8,7 @@

from homeassistant.components import websocket_api
from homeassistant.core import HomeAssistant, callback
from homeassistant.util import dt as dt_util

from .const import DATA_INSTANCE, MAX_QUEUE_BACKLOG
from .statistics import list_statistic_ids, validate_statistics
@@ -29,6 +30,7 @@ def async_setup(hass: HomeAssistant) -> None:
websocket_api.async_register_command(hass, ws_info)
websocket_api.async_register_command(hass, ws_backup_start)
websocket_api.async_register_command(hass, ws_backup_end)
websocket_api.async_register_command(hass, ws_adjust_sum_statistics)


@websocket_api.websocket_command(
@@ -105,6 +107,34 @@ def ws_update_statistics_metadata(
connection.send_result(msg["id"])


@websocket_api.require_admin
@websocket_api.websocket_command(
{
vol.Required("type"): "recorder/adjust_sum_statistics",
vol.Required("statistic_id"): str,
vol.Required("start_time"): str,
vol.Required("adjustment"): vol.Any(float, int),
}
)
@callback
def ws_adjust_sum_statistics(
hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict
) -> None:
"""Adjust sum statistics."""
start_time_str = msg["start_time"]

if start_time := dt_util.parse_datetime(start_time_str):
start_time = dt_util.as_utc(start_time)
else:
connection.send_error(msg["id"], "invalid_start_time", "Invalid start time")
return

hass.data[DATA_INSTANCE].async_adjust_statistics(
msg["statistic_id"], start_time, msg["adjustment"]
)
connection.send_result(msg["id"])


@websocket_api.websocket_command(
{
vol.Required("type"): "recorder/info",
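One detail worth noting in ws_adjust_sum_statistics: start_time arrives as a string, is parsed with dt_util.parse_datetime (an unparseable value yields an invalid_start_time error) and is then normalized with as_utc, which interprets a naive timestamp in the server's configured time zone. Callers are therefore safest sending an offset-aware ISO 8601 string; a caller-side sketch using only the standard library:

from datetime import datetime, timezone

# Offset-aware ISO 8601 string, so as_utc() has nothing to guess about.
start_time = datetime(2022, 3, 1, 10, 0, tzinfo=timezone.utc).isoformat()
# -> "2022-03-01T10:00:00+00:00"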
62 changes: 55 additions & 7 deletions tests/components/recorder/test_statistics.py
@@ -34,7 +34,13 @@
from homeassistant.setup import setup_component
import homeassistant.util.dt as dt_util

from tests.common import get_test_home_assistant, mock_registry
from .common import async_wait_recording_done_without_instance

from tests.common import (
async_init_recorder_component,
get_test_home_assistant,
mock_registry,
)
from tests.components.recorder.common import wait_recording_done

ORIG_TZ = dt_util.DEFAULT_TIME_ZONE
Expand Down Expand Up @@ -327,10 +333,11 @@ def test_statistics_duplicated(hass_recorder, caplog):
caplog.clear()


def test_external_statistics(hass_recorder, caplog):
async def test_external_statistics(hass, hass_ws_client, caplog):
"""Test inserting external statistics."""
hass = hass_recorder()
wait_recording_done(hass)
client = await hass_ws_client()
await async_init_recorder_component(hass)

assert "Compiling statistics for" not in caplog.text
assert "Statistics already compiled" not in caplog.text

@@ -363,7 +370,7 @@ def test_external_statistics(hass_recorder, caplog):
async_add_external_statistics(
hass, external_metadata, (external_statistics1, external_statistics2)
)
wait_recording_done(hass)
await async_wait_recording_done_without_instance(hass)
stats = statistics_during_period(hass, zero, period="hour")
assert stats == {
"test:total_energy_import": [
@@ -439,7 +446,7 @@ def test_external_statistics(hass_recorder, caplog):
"sum": 6,
}
async_add_external_statistics(hass, external_metadata, (external_statistics,))
wait_recording_done(hass)
await async_wait_recording_done_without_instance(hass)
stats = statistics_during_period(hass, zero, period="hour")
assert stats == {
"test:total_energy_import": [
@@ -479,7 +486,7 @@ def test_external_statistics(hass_recorder, caplog):
"sum": 5,
}
async_add_external_statistics(hass, external_metadata, (external_statistics,))
wait_recording_done(hass)
await async_wait_recording_done_without_instance(hass)
stats = statistics_during_period(hass, zero, period="hour")
assert stats == {
"test:total_energy_import": [
@@ -508,6 +515,47 @@ def test_external_statistics(hass_recorder, caplog):
]
}

await client.send_json(
{
"id": 1,
"type": "recorder/adjust_sum_statistics",
"statistic_id": "test:total_energy_import",
"start_time": period2.isoformat(),
"adjustment": 1000.0,
}
)
response = await client.receive_json()
assert response["success"]

await async_wait_recording_done_without_instance(hass)
stats = statistics_during_period(hass, zero, period="hour")
assert stats == {
"test:total_energy_import": [
{
"statistic_id": "test:total_energy_import",
"start": period1.isoformat(),
"end": (period1 + timedelta(hours=1)).isoformat(),
"max": approx(1.0),
"mean": approx(2.0),
"min": approx(3.0),
"last_reset": None,
"state": approx(4.0),
"sum": approx(5.0),
},
{
"statistic_id": "test:total_energy_import",
"start": period2.isoformat(),
"end": (period2 + timedelta(hours=1)).isoformat(),
"max": None,
"mean": None,
"min": None,
"last_reset": None,
"state": approx(1.0),
"sum": approx(1003.0),
},
]
}


def test_external_statistics_errors(hass_recorder, caplog):
"""Test validation of external statistics."""