Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix Bayesian sensor to use negative observations #67631

Merged
merged 22 commits into from Sep 26, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 2 additions & 0 deletions CODEOWNERS
Validating CODEOWNERS rules …
Expand Up @@ -129,6 +129,8 @@ build.json @home-assistant/supervisor
/tests/components/baf/ @bdraco @jfroy
/homeassistant/components/balboa/ @garbled1
/tests/components/balboa/ @garbled1
/homeassistant/components/bayesian/ @HarvsG
/tests/components/bayesian/ @HarvsG
/homeassistant/components/beewi_smartclim/ @alemuro
/homeassistant/components/binary_sensor/ @home-assistant/core
/tests/components/binary_sensor/ @home-assistant/core
Expand Down
105 changes: 69 additions & 36 deletions homeassistant/components/bayesian/binary_sensor.py
Expand Up @@ -16,6 +16,7 @@
CONF_PLATFORM,
CONF_STATE,
CONF_VALUE_TEMPLATE,
STATE_UNAVAILABLE,
STATE_UNKNOWN,
)
from homeassistant.core import HomeAssistant, callback
Expand Down Expand Up @@ -60,7 +61,7 @@
vol.Optional(CONF_ABOVE): vol.Coerce(float),
vol.Optional(CONF_BELOW): vol.Coerce(float),
vol.Required(CONF_P_GIVEN_T): vol.Coerce(float),
vol.Optional(CONF_P_GIVEN_F): vol.Coerce(float),
vol.Required(CONF_P_GIVEN_F): vol.Coerce(float),
},
required=True,
)
Expand All @@ -71,7 +72,7 @@
vol.Required(CONF_ENTITY_ID): cv.entity_id,
vol.Required(CONF_TO_STATE): cv.string,
vol.Required(CONF_P_GIVEN_T): vol.Coerce(float),
vol.Optional(CONF_P_GIVEN_F): vol.Coerce(float),
vol.Required(CONF_P_GIVEN_F): vol.Coerce(float),
},
required=True,
)
Expand All @@ -81,7 +82,7 @@
CONF_PLATFORM: CONF_TEMPLATE,
vol.Required(CONF_VALUE_TEMPLATE): cv.template,
vol.Required(CONF_P_GIVEN_T): vol.Coerce(float),
vol.Optional(CONF_P_GIVEN_F): vol.Coerce(float),
vol.Required(CONF_P_GIVEN_F): vol.Coerce(float),
},
required=True,
)
Expand Down Expand Up @@ -160,6 +161,7 @@ def __init__(self, name, prior, observations, probability_threshold, device_clas
self.observation_handlers = {
"numeric_state": self._process_numeric_state,
"state": self._process_state,
"multi_state": self._process_multi_state,
}

async def async_added_to_hass(self) -> None:
Expand All @@ -185,10 +187,6 @@ def async_threshold_sensor_state_listener(event):
When a state changes, we must update our list of current observations,
then calculate the new probability.
"""
new_state = event.data.get("new_state")

if new_state is None or new_state.state == STATE_UNKNOWN:
return
HarvsG marked this conversation as resolved.
Show resolved Hide resolved

entity = event.data.get("entity_id")

Expand All @@ -210,7 +208,6 @@ def _async_template_result_changed(event, updates):
template = track_template_result.template
result = track_template_result.result
entity = event and event.data.get("entity_id")

if isinstance(result, TemplateError):
_LOGGER.error(
"TemplateError('%s') "
Expand All @@ -221,15 +218,12 @@ def _async_template_result_changed(event, updates):
self.entity_id,
)

should_trigger = False
observation = None
HarvsG marked this conversation as resolved.
Show resolved Hide resolved
else:
should_trigger = result_as_boolean(result)
observation = result_as_boolean(result)

for obs in self.observations_by_template[template]:
if should_trigger:
obs_entry = {"entity_id": entity, **obs}
else:
obs_entry = None
obs_entry = {"entity_id": entity, "observation": observation, **obs}
self.current_observations[obs["id"]] = obs_entry

if event:
Expand Down Expand Up @@ -259,6 +253,7 @@ def _recalculate_and_write_state(self):

def _initialize_current_observations(self):
local_observations = OrderedDict({})

for entity in self.observations_by_entity:
local_observations.update(self._record_entity_observations(entity))
return local_observations
Expand All @@ -269,13 +264,13 @@ def _record_entity_observations(self, entity):
for entity_obs in self.observations_by_entity[entity]:
platform = entity_obs["platform"]

should_trigger = self.observation_handlers[platform](entity_obs)

if should_trigger:
obs_entry = {"entity_id": entity, **entity_obs}
else:
obs_entry = None
observation = self.observation_handlers[platform](entity_obs)

obs_entry = {
"entity_id": entity,
"observation": observation,
**entity_obs,
}
local_observations[entity_obs["id"]] = obs_entry

return local_observations
Expand All @@ -285,11 +280,28 @@ def _calculate_new_probability(self):

for obs in self.current_observations.values():
if obs is not None:
prior = update_probability(
prior,
obs["prob_given_true"],
obs.get("prob_given_false", 1 - obs["prob_given_true"]),
)
if obs["observation"] is True:
prior = update_probability(
prior,
obs["prob_given_true"],
obs["prob_given_false"],
)
elif obs["observation"] is False:
prior = update_probability(
prior,
1 - obs["prob_given_true"],
1 - obs["prob_given_false"],
)
elif obs["observation"] is None:
if obs["entity_id"] is not None:
_LOGGER.debug(
"Observation for entity '%s' returned None, it will not be used for Bayesian updating",
obs["entity_id"],
)
else:
_LOGGER.debug(
"Observation for template entity returned None rather than a valid boolean, it will not be used for Bayesian updating",
)

return prior

Expand All @@ -307,17 +319,21 @@ def _build_observations_by_entity(self):
for all relevant observations to be looked up via their `entity_id`.
"""

observations_by_entity = {}
for ind, obs in enumerate(self._observations):
obs["id"] = ind
observations_by_entity: dict[str, list[OrderedDict]] = {}
for i, obs in enumerate(self._observations):
obs["id"] = i

if "entity_id" not in obs:
continue
observations_by_entity.setdefault(obs["entity_id"], []).append(obs)

entity_ids = [obs["entity_id"]]

for e_id in entity_ids:
observations_by_entity.setdefault(e_id, []).append(obs)
for li_of_dicts in observations_by_entity.values():
if len(li_of_dicts) == 1:
continue
for ord_dict in li_of_dicts:
if ord_dict["platform"] != "state":
continue
HarvsG marked this conversation as resolved.
Show resolved Hide resolved
ord_dict["platform"] = "multi_state"

return observations_by_entity

Expand Down Expand Up @@ -348,10 +364,12 @@ def _build_observations_by_template(self):
return observations_by_template

def _process_numeric_state(self, entity_observation):
"""Return True if numeric condition is met."""
"""Return True if numeric condition is met, return False if not, return None otherwise."""
entity = entity_observation["entity_id"]

try:
if condition.state(self.hass, entity, [STATE_UNKNOWN, STATE_UNAVAILABLE]):
return None
return condition.async_numeric_state(
self.hass,
entity,
Expand All @@ -361,18 +379,31 @@ def _process_numeric_state(self, entity_observation):
entity_observation,
)
except ConditionError:
return False
return None

def _process_state(self, entity_observation):
"""Return True if state conditions are met."""
entity = entity_observation["entity_id"]

try:
if condition.state(self.hass, entity, [STATE_UNKNOWN, STATE_UNAVAILABLE]):
return None

return condition.state(
self.hass, entity, entity_observation.get("to_state")
)
except ConditionError:
return False
return None

def _process_multi_state(self, entity_observation):
"""Return True if state conditions are met."""
entity = entity_observation["entity_id"]

try:
if condition.state(self.hass, entity, entity_observation.get("to_state")):
return True
except ConditionError:
return None

@property
def extra_state_attributes(self):
Expand All @@ -390,7 +421,9 @@ def extra_state_attributes(self):
{
obs.get("entity_id")
for obs in self.current_observations.values()
if obs is not None and obs.get("entity_id") is not None
if obs is not None
and obs.get("entity_id") is not None
and obs.get("observation") is not None
}
),
ATTR_PROBABILITY: round(self.probability, 2),
Expand Down
2 changes: 1 addition & 1 deletion homeassistant/components/bayesian/manifest.json
Expand Up @@ -2,7 +2,7 @@
"domain": "bayesian",
"name": "Bayesian",
"documentation": "https://www.home-assistant.io/integrations/bayesian",
"codeowners": [],
"codeowners": ["@HarvsG"],
"quality_scale": "internal",
"iot_class": "local_polling"
}