Skip to content

Commit

Permalink
Prioritize entity names over area names in Assist matching (#86982)
Browse files Browse the repository at this point in the history
* Refactor async_match_states

* Check entity name after state, before aliases

* Give entity name matches priority over area names

* Don't force result to have area

* Add area alias in tests

* Move name/area list creation back

* Clean up PR

* More clean up
  • Loading branch information
synesthesiam authored and balloob committed Jan 31, 2023
1 parent 32a7ae6 commit edf02b7
Show file tree
Hide file tree
Showing 4 changed files with 148 additions and 37 deletions.
39 changes: 31 additions & 8 deletions homeassistant/components/conversation/default_agent.py
Expand Up @@ -11,7 +11,7 @@
from typing import IO, Any

from hassil.intents import Intents, ResponseType, SlotList, TextSlotList
from hassil.recognize import recognize
from hassil.recognize import RecognizeResult, recognize_all
from hassil.util import merge_dict
from home_assistant_intents import get_intents
import yaml
Expand Down Expand Up @@ -128,7 +128,10 @@ async def async_process(self, user_input: ConversationInput) -> ConversationResu
}

result = await self.hass.async_add_executor_job(
recognize, user_input.text, lang_intents.intents, slot_lists
self._recognize,
user_input,
lang_intents,
slot_lists,
)
if result is None:
_LOGGER.debug("No intent was matched for '%s'", user_input.text)
Expand Down Expand Up @@ -197,6 +200,26 @@ async def async_process(self, user_input: ConversationInput) -> ConversationResu
response=intent_response, conversation_id=conversation_id
)

def _recognize(
self,
user_input: ConversationInput,
lang_intents: LanguageIntents,
slot_lists: dict[str, SlotList],
) -> RecognizeResult | None:
"""Search intents for a match to user input."""
# Prioritize matches with entity names above area names
maybe_result: RecognizeResult | None = None
for result in recognize_all(
user_input.text, lang_intents.intents, slot_lists=slot_lists
):
if "name" in result.entities:
return result

# Keep looking in case an entity has the same name
maybe_result = result

return maybe_result

async def async_reload(self, language: str | None = None):
"""Clear cached intents for a language."""
if language is None:
Expand Down Expand Up @@ -373,19 +396,19 @@ def _make_names_list(self) -> TextSlotList:
if self._names_list is not None:
return self._names_list
states = self.hass.states.async_all()
registry = entity_registry.async_get(self.hass)
entities = entity_registry.async_get(self.hass)
names = []
for state in states:
context = {"domain": state.domain}

entry = registry.async_get(state.entity_id)
if entry is not None:
if entry.entity_category:
entity = entities.async_get(state.entity_id)
if entity is not None:
if entity.entity_category:
# Skip configuration/diagnostic entities
continue

if entry.aliases:
for alias in entry.aliases:
if entity.aliases:
for alias in entity.aliases:
names.append((alias, state.entity_id, context))

# Default name
Expand Down
89 changes: 60 additions & 29 deletions homeassistant/helpers/intent.py
Expand Up @@ -138,15 +138,62 @@ def _has_name(
if name in (state.entity_id, state.name.casefold()):
return True

# Check aliases
if (entity is not None) and entity.aliases:
for alias in entity.aliases:
if name == alias.casefold():
return True
# Check name/aliases
if (entity is None) or (not entity.aliases):
return False

for alias in entity.aliases:
if name == alias.casefold():
return True

return False


def _find_area(
id_or_name: str, areas: area_registry.AreaRegistry
) -> area_registry.AreaEntry | None:
"""Find an area by id or name, checking aliases too."""
area = areas.async_get_area(id_or_name) or areas.async_get_area_by_name(id_or_name)
if area is not None:
return area

# Check area aliases
for maybe_area in areas.areas.values():
if not maybe_area.aliases:
continue

for area_alias in maybe_area.aliases:
if id_or_name == area_alias.casefold():
return maybe_area

return None


def _filter_by_area(
states_and_entities: list[tuple[State, entity_registry.RegistryEntry | None]],
area: area_registry.AreaEntry,
devices: device_registry.DeviceRegistry,
) -> Iterable[tuple[State, entity_registry.RegistryEntry | None]]:
"""Filter state/entity pairs by an area."""
entity_area_ids: dict[str, str | None] = {}
for _state, entity in states_and_entities:
if entity is None:
continue

if entity.area_id:
# Use entity's area id first
entity_area_ids[entity.id] = entity.area_id
elif entity.device_id:
# Fall back to device area if not set on entity
device = devices.async_get(entity.device_id)
if device is not None:
entity_area_ids[entity.id] = device.area_id

for state, entity in states_and_entities:
if (entity is not None) and (entity_area_ids.get(entity.id) == area.id):
yield (state, entity)


@callback
@bind_hass
def async_match_states(
Expand Down Expand Up @@ -200,45 +247,29 @@ def async_match_states(
if areas is None:
areas = area_registry.async_get(hass)

# id or name
area = areas.async_get_area(area_name) or areas.async_get_area_by_name(
area_name
)
area = _find_area(area_name, areas)
assert area is not None, f"No area named {area_name}"

if area is not None:
# Filter by states/entities by area
if devices is None:
devices = device_registry.async_get(hass)

entity_area_ids: dict[str, str | None] = {}
for _state, entity in states_and_entities:
if entity is None:
continue

if entity.area_id:
# Use entity's area id first
entity_area_ids[entity.id] = entity.area_id
elif entity.device_id:
# Fall back to device area if not set on entity
device = devices.async_get(entity.device_id)
if device is not None:
entity_area_ids[entity.id] = device.area_id

# Filter by area
states_and_entities = [
(state, entity)
for state, entity in states_and_entities
if (entity is not None) and (entity_area_ids.get(entity.id) == area.id)
]
states_and_entities = list(_filter_by_area(states_and_entities, area, devices))

if name is not None:
if devices is None:
devices = device_registry.async_get(hass)

# Filter by name
name = name.casefold()

# Check states
for state, entity in states_and_entities:
if _has_name(state, entity, name):
yield state
break

else:
# Not filtered by name
for state, _entity in states_and_entities:
Expand Down
49 changes: 49 additions & 0 deletions tests/components/conversation/test_init.py
Expand Up @@ -6,6 +6,7 @@

from homeassistant.components import conversation
from homeassistant.components.cover import SERVICE_OPEN_COVER
from homeassistant.const import ATTR_FRIENDLY_NAME
from homeassistant.core import DOMAIN as HASS_DOMAIN, Context
from homeassistant.helpers import (
area_registry,
Expand Down Expand Up @@ -777,3 +778,51 @@ async def test_turn_on_area(hass, init_components):
assert call.domain == HASS_DOMAIN
assert call.service == "turn_on"
assert call.data == {"entity_id": "light.stove"}


async def test_light_area_same_name(hass, init_components):
"""Test turning on a light with the same name as an area."""
entities = entity_registry.async_get(hass)
devices = device_registry.async_get(hass)
areas = area_registry.async_get(hass)
entry = MockConfigEntry(domain="test")

device = devices.async_get_or_create(
config_entry_id=entry.entry_id,
connections={(device_registry.CONNECTION_NETWORK_MAC, "12:34:56:AB:CD:EF")},
)

kitchen_area = areas.async_create("kitchen")
devices.async_update_device(device.id, area_id=kitchen_area.id)

kitchen_light = entities.async_get_or_create(
"light", "demo", "1234", original_name="kitchen light"
)
entities.async_update_entity(kitchen_light.entity_id, area_id=kitchen_area.id)
hass.states.async_set(
kitchen_light.entity_id, "off", attributes={ATTR_FRIENDLY_NAME: "kitchen light"}
)

ceiling_light = entities.async_get_or_create(
"light", "demo", "5678", original_name="ceiling light"
)
entities.async_update_entity(ceiling_light.entity_id, area_id=kitchen_area.id)
hass.states.async_set(
ceiling_light.entity_id, "off", attributes={ATTR_FRIENDLY_NAME: "ceiling light"}
)

calls = async_mock_service(hass, HASS_DOMAIN, "turn_on")

await hass.services.async_call(
"conversation",
"process",
{conversation.ATTR_TEXT: "turn on kitchen light"},
)
await hass.async_block_till_done()

# Should only turn on one light instead of all lights in the kitchen
assert len(calls) == 1
call = calls[0]
assert call.domain == HASS_DOMAIN
assert call.service == "turn_on"
assert call.data == {"entity_id": kitchen_light.entity_id}
8 changes: 8 additions & 0 deletions tests/helpers/test_intent.py
Expand Up @@ -27,6 +27,7 @@ async def test_async_match_states(hass):
"""Test async_match_state helper."""
areas = area_registry.async_get(hass)
area_kitchen = areas.async_get_or_create("kitchen")
areas.async_update(area_kitchen.id, aliases={"food room"})
area_bedroom = areas.async_get_or_create("bedroom")

state1 = State(
Expand Down Expand Up @@ -68,6 +69,13 @@ async def test_async_match_states(hass):
)
)

# Test area alias
assert [state1] == list(
intent.async_match_states(
hass, name="kitchen light", area_name="food room", states=[state1, state2]
)
)

# Wrong area
assert not list(
intent.async_match_states(
Expand Down

0 comments on commit edf02b7

Please sign in to comment.