Skip to content

Commit

Permalink
Make unit converter use a factory to avoid looking up the ratios each…
Browse files Browse the repository at this point in the history
… conversion (#93706)
  • Loading branch information
bdraco committed May 29, 2023
1 parent 7f3f2ee commit 2f1f32f
Show file tree
Hide file tree
Showing 4 changed files with 220 additions and 59 deletions.
61 changes: 26 additions & 35 deletions homeassistant/components/recorder/statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,50 +219,34 @@ def _get_statistic_to_display_unit_converter(
if display_unit == statistic_unit:
return None

convert = converter.convert

def _from_normalized_unit(val: float | None) -> float | None:
"""Return val."""
if val is None:
return val
return convert(val, statistic_unit, display_unit)

return _from_normalized_unit
return converter.converter_factory_allow_none(
from_unit=statistic_unit, to_unit=display_unit
)


def _get_display_to_statistic_unit_converter(
display_unit: str | None,
statistic_unit: str | None,
) -> Callable[[float], float]:
) -> Callable[[float], float] | None:
"""Prepare a converter from the display unit to the statistics unit."""

def no_conversion(val: float) -> float:
"""Return val."""
return val

if (converter := STATISTIC_UNIT_TO_UNIT_CONVERTER.get(statistic_unit)) is None:
return no_conversion

return partial(converter.convert, from_unit=display_unit, to_unit=statistic_unit)
if (
display_unit == statistic_unit
or (converter := STATISTIC_UNIT_TO_UNIT_CONVERTER.get(statistic_unit)) is None
):
return None
return converter.converter_factory(from_unit=display_unit, to_unit=statistic_unit)


def _get_unit_converter(
from_unit: str, to_unit: str
) -> Callable[[float | None], float | None]:
) -> Callable[[float | None], float | None] | None:
"""Prepare a converter from a unit to another unit."""

def convert_units(
val: float | None, conv: type[BaseUnitConverter], from_unit: str, to_unit: str
) -> float | None:
"""Return converted val."""
if val is None:
return val
return conv.convert(val, from_unit=from_unit, to_unit=to_unit)

for conv in STATISTIC_UNIT_TO_UNIT_CONVERTER.values():
if from_unit in conv.VALID_UNITS and to_unit in conv.VALID_UNITS:
return partial(
convert_units, conv=conv, from_unit=from_unit, to_unit=to_unit
if from_unit == to_unit:
return None
return conv.converter_factory_allow_none(
from_unit=from_unit, to_unit=to_unit
)
raise HomeAssistantError

Expand Down Expand Up @@ -2290,10 +2274,10 @@ def adjust_statistics(
return True

statistic_unit = metadata[statistic_id][1]["unit_of_measurement"]
convert = _get_display_to_statistic_unit_converter(
if convert := _get_display_to_statistic_unit_converter(
adjustment_unit, statistic_unit
)
sum_adjustment = convert(sum_adjustment)
):
sum_adjustment = convert(sum_adjustment)

_adjust_sum_statistics(
session,
Expand Down Expand Up @@ -2360,7 +2344,14 @@ def change_statistics_unit(

metadata_id = metadata[0]

convert = _get_unit_converter(old_unit, new_unit)
if not (convert := _get_unit_converter(old_unit, new_unit)):
_LOGGER.warning(
"Statistics unit of measurement for %s is already %s",
statistic_id,
new_unit,
)
return

tables: tuple[type[StatisticsBase], ...] = (
Statistics,
StatisticsShortTerm,
Expand Down
107 changes: 83 additions & 24 deletions homeassistant/util/unit_conversion.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
"""Typing Helpers for Home Assistant."""
from __future__ import annotations

from collections.abc import Callable
from functools import lru_cache

from homeassistant.const import (
CONCENTRATION_PARTS_PER_BILLION,
CONCENTRATION_PARTS_PER_MILLION,
Expand Down Expand Up @@ -67,30 +70,49 @@ class BaseUnitConverter:
@classmethod
def convert(cls, value: float, from_unit: str | None, to_unit: str | None) -> float:
"""Convert one unit of measurement to another."""
if from_unit == to_unit:
return value
return cls.converter_factory(from_unit, to_unit)(value)

try:
from_ratio = cls._UNIT_CONVERSION[from_unit]
except KeyError as err:
raise HomeAssistantError(
UNIT_NOT_RECOGNIZED_TEMPLATE.format(from_unit, cls.UNIT_CLASS)
) from err
@classmethod
@lru_cache
def converter_factory(
cls, from_unit: str | None, to_unit: str | None
) -> Callable[[float], float]:
"""Return a function to convert one unit of measurement to another."""
if from_unit == to_unit:
return lambda value: value
from_ratio, to_ratio = cls._get_from_to_ratio(from_unit, to_unit)
return lambda val: (val / from_ratio) * to_ratio

@classmethod
def _get_from_to_ratio(
cls, from_unit: str | None, to_unit: str | None
) -> tuple[float, float]:
"""Get unit ratio between units of measurement."""
unit_conversion = cls._UNIT_CONVERSION
try:
to_ratio = cls._UNIT_CONVERSION[to_unit]
return unit_conversion[from_unit], unit_conversion[to_unit]
except KeyError as err:
raise HomeAssistantError(
UNIT_NOT_RECOGNIZED_TEMPLATE.format(to_unit, cls.UNIT_CLASS)
UNIT_NOT_RECOGNIZED_TEMPLATE.format(err.args[0], cls.UNIT_CLASS)
) from err

new_value = value / from_ratio
return new_value * to_ratio
@classmethod
@lru_cache
def converter_factory_allow_none(
cls, from_unit: str | None, to_unit: str | None
) -> Callable[[float | None], float | None]:
"""Return a function to convert one unit of measurement to another which allows None."""
if from_unit == to_unit:
return lambda value: value
from_ratio, to_ratio = cls._get_from_to_ratio(from_unit, to_unit)
return lambda val: None if val is None else (val / from_ratio) * to_ratio

@classmethod
@lru_cache
def get_unit_ratio(cls, from_unit: str | None, to_unit: str | None) -> float:
"""Get unit ratio between units of measurement."""
return cls._UNIT_CONVERSION[from_unit] / cls._UNIT_CONVERSION[to_unit]
from_ratio, to_ratio = cls._get_from_to_ratio(from_unit, to_unit)
return from_ratio / to_ratio


class DataRateConverter(BaseUnitConverter):
Expand Down Expand Up @@ -339,7 +361,37 @@ class TemperatureConverter(BaseUnitConverter):
}

@classmethod
def convert(cls, value: float, from_unit: str | None, to_unit: str | None) -> float:
@lru_cache(maxsize=8)
def converter_factory(
cls, from_unit: str | None, to_unit: str | None
) -> Callable[[float], float]:
"""Return a function to convert a temperature from one unit to another."""
if from_unit == to_unit:
# Return a function that does nothing. This is not
# in _converter_factory because we do not want to wrap
# it with the None check in converter_factory_allow_none.
return lambda value: value

return cls._converter_factory(from_unit, to_unit)

@classmethod
@lru_cache(maxsize=8)
def converter_factory_allow_none(
cls, from_unit: str | None, to_unit: str | None
) -> Callable[[float | None], float | None]:
"""Return a function to convert a temperature from one unit to another which allows None."""
if from_unit == to_unit:
# Return a function that does nothing. This is not
# in _converter_factory because we do not want to wrap
# it with the None check in this case.
return lambda value: value
convert = cls._converter_factory(from_unit, to_unit)
return lambda value: None if value is None else convert(value)

@classmethod
def _converter_factory(
cls, from_unit: str | None, to_unit: str | None
) -> Callable[[float], float]:
"""Convert a temperature from one unit to another.
eg. 10°C will return 50°F
Expand All @@ -349,32 +401,29 @@ def convert(cls, value: float, from_unit: str | None, to_unit: str | None) -> fl
"""
# We cannot use the implementation from BaseUnitConverter here because the
# temperature units do not use the same floor: 0°C, 0°F and 0K do not align
if from_unit == to_unit:
return value

if from_unit == UnitOfTemperature.CELSIUS:
if to_unit == UnitOfTemperature.FAHRENHEIT:
return cls._celsius_to_fahrenheit(value)
return cls._celsius_to_fahrenheit
if to_unit == UnitOfTemperature.KELVIN:
return cls._celsius_to_kelvin(value)
return cls._celsius_to_kelvin
raise HomeAssistantError(
UNIT_NOT_RECOGNIZED_TEMPLATE.format(to_unit, cls.UNIT_CLASS)
)

if from_unit == UnitOfTemperature.FAHRENHEIT:
if to_unit == UnitOfTemperature.CELSIUS:
return cls._fahrenheit_to_celsius(value)
return cls._fahrenheit_to_celsius
if to_unit == UnitOfTemperature.KELVIN:
return cls._celsius_to_kelvin(cls._fahrenheit_to_celsius(value))
return cls._fahrenheit_to_kelvin
raise HomeAssistantError(
UNIT_NOT_RECOGNIZED_TEMPLATE.format(to_unit, cls.UNIT_CLASS)
)

if from_unit == UnitOfTemperature.KELVIN:
if to_unit == UnitOfTemperature.CELSIUS:
return cls._kelvin_to_celsius(value)
return cls._kelvin_to_celsius
if to_unit == UnitOfTemperature.FAHRENHEIT:
return cls._celsius_to_fahrenheit(cls._kelvin_to_celsius(value))
return cls._kelvin_to_fahrenheit
raise HomeAssistantError(
UNIT_NOT_RECOGNIZED_TEMPLATE.format(to_unit, cls.UNIT_CLASS)
)
Expand All @@ -393,7 +442,17 @@ def convert_interval(cls, interval: float, from_unit: str, to_unit: str) -> floa
"""
# We use BaseUnitConverter implementation here because we are only interested
# in the ratio between the units.
return super().convert(interval, from_unit, to_unit)
return super().converter_factory(from_unit, to_unit)(interval)

@classmethod
def _kelvin_to_fahrenheit(cls, kelvin: float) -> float:
"""Convert a temperature in Kelvin to Fahrenheit."""
return (kelvin - 273.15) * 1.8 + 32.0

@classmethod
def _fahrenheit_to_kelvin(cls, fahrenheit: float) -> float:
"""Convert a temperature in Fahrenheit to Kelvin."""
return 273.15 + ((fahrenheit - 32.0) / 1.8)

@classmethod
def _fahrenheit_to_celsius(cls, fahrenheit: float) -> float:
Expand Down
30 changes: 30 additions & 0 deletions tests/components/recorder/test_websocket_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -1973,6 +1973,36 @@ async def test_change_statistics_unit(
],
}

# Changing to the same unit is allowed but does nothing
await client.send_json(
{
"id": 6,
"type": "recorder/change_statistics_unit",
"statistic_id": "sensor.test",
"new_unit_of_measurement": "W",
"old_unit_of_measurement": "W",
}
)
response = await client.receive_json()
assert response["success"]
await async_recorder_block_till_done(hass)

await client.send_json({"id": 7, "type": "recorder/list_statistic_ids"})
response = await client.receive_json()
assert response["success"]
assert response["result"] == [
{
"statistic_id": "sensor.test",
"display_unit_of_measurement": "kW",
"has_mean": True,
"has_sum": False,
"name": None,
"source": "recorder",
"statistics_unit_of_measurement": "W",
"unit_class": "power",
}
]


async def test_change_statistics_unit_errors(
recorder_mock: Recorder,
Expand Down
81 changes: 81 additions & 0 deletions tests/util/test_unit_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from __future__ import annotations

import inspect
from itertools import chain

import pytest

Expand Down Expand Up @@ -534,6 +535,86 @@ def test_unit_conversion(
assert converter.convert(value, from_unit, to_unit) == pytest.approx(expected)


@pytest.mark.parametrize(
("converter", "value", "from_unit", "expected", "to_unit"),
[
# Process all items in _CONVERTED_VALUE
(converter, value, from_unit, expected, to_unit)
for converter, item in _CONVERTED_VALUE.items()
for value, from_unit, expected, to_unit in item
],
)
def test_unit_conversion_factory(
converter: type[BaseUnitConverter],
value: float,
from_unit: str,
expected: float,
to_unit: str,
) -> None:
"""Test conversion to other units."""
assert converter.converter_factory(from_unit, to_unit)(value) == pytest.approx(
expected
)


def test_unit_conversion_factory_allow_none_with_none() -> None:
"""Test test_unit_conversion_factory_allow_none with None."""
assert (
SpeedConverter.converter_factory_allow_none(
UnitOfSpeed.FEET_PER_SECOND, UnitOfSpeed.FEET_PER_SECOND
)(1)
== 1
)
assert (
SpeedConverter.converter_factory_allow_none(
UnitOfSpeed.FEET_PER_SECOND, UnitOfSpeed.FEET_PER_SECOND
)(None)
is None
)
assert (
TemperatureConverter.converter_factory_allow_none(
UnitOfTemperature.CELSIUS, UnitOfTemperature.CELSIUS
)(1)
== 1
)
assert (
TemperatureConverter.converter_factory_allow_none(
UnitOfTemperature.CELSIUS, UnitOfTemperature.CELSIUS
)(None)
is None
)


@pytest.mark.parametrize(
("converter", "value", "from_unit", "expected", "to_unit"),
chain(
[
# Process all items in _CONVERTED_VALUE
(converter, value, from_unit, expected, to_unit)
for converter, item in _CONVERTED_VALUE.items()
for value, from_unit, expected, to_unit in item
],
[
# Process all items in _CONVERTED_VALUE and replace the value with None
(converter, None, from_unit, None, to_unit)
for converter, item in _CONVERTED_VALUE.items()
for value, from_unit, expected, to_unit in item
],
),
)
def test_unit_conversion_factory_allow_none(
converter: type[BaseUnitConverter],
value: float,
from_unit: str,
expected: float,
to_unit: str,
) -> None:
"""Test conversion to other units."""
assert converter.converter_factory_allow_none(from_unit, to_unit)(
value
) == pytest.approx(expected)


@pytest.mark.parametrize(
("value", "from_unit", "expected", "to_unit"),
[
Expand Down

0 comments on commit 2f1f32f

Please sign in to comment.