Skip to content

Commit

Permalink
Collect subscribed entities in central (#1241)
Browse files Browse the repository at this point in the history
* Collect subscribed entities in central

* Change method signature for getting entities of central collections
  • Loading branch information
SukramJ committed Oct 7, 2023
1 parent 6abd840 commit 7bb8a3a
Show file tree
Hide file tree
Showing 11 changed files with 183 additions and 72 deletions.
1 change: 1 addition & 0 deletions changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
- Add more checks to get/set value from/tp values
- Use more tuple instead of list
- Cleanup code
- Collect subscribed entities in central

# Version 2023.10.4 (2023-10-03)

Expand Down
60 changes: 54 additions & 6 deletions hahomematic/central/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,8 @@ def __init__(self, central_config: CentralConfig) -> None:
self._sysvar_entities: Final[dict[str, GenericSystemVariable]] = {}
# {sysvar_name, program_button}U
self._program_buttons: Final[dict[str, HmProgramButton]] = {}
# {unique_identifier}
self._subscribed_entity_unique_identifiers: Final[set[str]] = set()
# store last event received datetime by interface
self.last_events: Final[dict[str, datetime]] = {}
# Signature: (name, *args)
Expand Down Expand Up @@ -291,6 +293,20 @@ def remove_program_button(self, pid: str) -> None:
program_button.remove_entity()
del self._program_buttons[pid]

@property
def subscribed_entity_unique_identifiers(self) -> tuple[str, ...]:
"""Return the unique identifiers of subscribed entities."""
return tuple(self._subscribed_entity_unique_identifiers)

def add_subscribed_entity_unique_identifier(self, unique_identifier: str) -> None:
"""Add new program button."""
self._subscribed_entity_unique_identifiers.add(unique_identifier)

def remove_subscribed_entity_unique_identifier(self, unique_identifier: str) -> None:
"""Remove a program button."""
if unique_identifier in self._subscribed_entity_unique_identifiers:
self._subscribed_entity_unique_identifiers.remove(unique_identifier)

@property
def version(self) -> str | None:
"""Return the version of the backend."""
Expand Down Expand Up @@ -592,11 +608,12 @@ def get_device(self, address: str) -> HmDevice | None:
return self._devices.get(d_address)

def get_entities_by_platform(
self, platform: HmPlatform, existing_unique_ids: tuple[str, ...] | None = None
self, platform: HmPlatform, exclude_subscribed: bool | None = None
) -> tuple[BaseEntity, ...]:
"""Return all entities by platform."""
if not existing_unique_ids:
existing_unique_ids = ()
existing_unique_ids = (
self._subscribed_entity_unique_identifiers if exclude_subscribed else ()
)

return tuple(
be
Expand Down Expand Up @@ -634,18 +651,49 @@ def _get_primary_client(self) -> hmcl.Client | None:
return client

def get_hub_entities_by_platform(
self, platform: HmPlatform, existing_unique_ids: tuple[str, ...] | None = None
self, platform: HmPlatform, exclude_subscribed: bool | None = None
) -> tuple[GenericHubEntity, ...]:
"""Return the hub entities by platform."""
if not existing_unique_ids:
existing_unique_ids = ()
existing_unique_ids = (
self._subscribed_entity_unique_identifiers if exclude_subscribed else ()
)

return tuple(
he
for he in (self.program_buttons + self.sysvar_entities)
if (he.unique_identifier not in existing_unique_ids and he.platform == platform)
)

def get_update_entities(self, exclude_subscribed: bool | None = None) -> tuple[HmUpdate, ...]:
"""Return the update entities."""
existing_unique_ids = (
self._subscribed_entity_unique_identifiers if exclude_subscribed else ()
)

return tuple(
device.update_entity
for device in self.devices
if device.update_entity
and device.update_entity.unique_identifier not in existing_unique_ids
)

def get_channel_events_by_event_type(
self, event_type: EventType, exclude_subscribed: bool | None = None
) -> tuple[list[GenericEvent], ...]:
"""Return all channel event entities."""
existing_unique_ids = (
self._subscribed_entity_unique_identifiers if exclude_subscribed else ()
)

hm_channel_events: list[list[GenericEvent]] = []
for device in self.devices:
for channel_events in device.get_channel_events(event_type=event_type).values():
if channel_events[0].channel_unique_identifier not in existing_unique_ids:
hm_channel_events.append(channel_events)
continue

return tuple(hm_channel_events)

def get_virtual_remotes(self) -> tuple[HmDevice, ...]:
"""Get the virtual remote for the Client."""
return tuple(
Expand Down
7 changes: 7 additions & 0 deletions hahomematic/const.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,13 @@ class Backend(StrEnum):
PYDEVCCU = "PyDevCCU"


class CallBackSource(StrEnum):
"""Enum with sources for registered callbacks."""

HA: Final = "ha_callback"
INTERNAL: Final = "hm_initernal"


class CallSource(StrEnum):
"""Enum with sources for calls."""

Expand Down
28 changes: 15 additions & 13 deletions hahomematic/platforms/custom/entity.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import logging
from typing import Any, Final, TypeVar, cast

from hahomematic.const import INIT_DATETIME, CallSource, EntityUsage
from hahomematic.const import INIT_DATETIME, CallBackSource, CallSource, EntityUsage
from hahomematic.platforms import device as hmd
from hahomematic.platforms.custom import definition as hmed
from hahomematic.platforms.custom.const import EntityDefinition
Expand Down Expand Up @@ -96,7 +96,7 @@ def _get_entity_name(self) -> EntityNameData:
)
return get_custom_entity_name(
central=self._central,
device=self.device,
device=self._device,
channel_no=self.channel_no,
is_only_primary_channel=is_only_primary_channel,
usage=self._usage,
Expand Down Expand Up @@ -131,7 +131,7 @@ def _init_entities(self) -> None:
"""Init entity collection."""
# Add repeating fields
for field_name, parameter in self._device_desc.get(hmed.ED_REPEATABLE_FIELDS, {}).items():
entity = self.device.get_generic_entity(
entity = self._device.get_generic_entity(
channel_address=self._channel_address, parameter=parameter
)
self._add_entity(field_name=field_name, entity=entity)
Expand All @@ -140,7 +140,7 @@ def _init_entities(self) -> None:
for field_name, parameter in self._device_desc.get(
hmed.ED_VISIBLE_REPEATABLE_FIELDS, {}
).items():
entity = self.device.get_generic_entity(
entity = self._device.get_generic_entity(
channel_address=self._channel_address, parameter=parameter
)
self._add_entity(field_name=field_name, entity=entity, is_visible=True)
Expand All @@ -150,9 +150,9 @@ def _init_entities(self) -> None:
for channel_no, mapping in fixed_channels.items():
for field_name, parameter in mapping.items():
channel_address = get_channel_address(
device_address=self.device.device_address, channel_no=channel_no
device_address=self._device.device_address, channel_no=channel_no
)
entity = self.device.get_generic_entity(
entity = self._device.get_generic_entity(
channel_address=channel_address, parameter=parameter
)
self._add_entity(field_name=field_name, entity=entity)
Expand All @@ -178,7 +178,7 @@ def _init_entities(self) -> None:
# add custom un_ignore entities
self._mark_entity_by_custom_un_ignore_parameters(
un_ignore_params_by_paramset_key=self._central.parameter_visibility.get_un_ignore_parameters( # noqa: E501
device_type=self.device.device_type, channel_no=self.channel_no
device_type=self._device.device_type, channel_no=self.channel_no
)
)

Expand All @@ -188,9 +188,9 @@ def _add_entities(self, field_dict_name: str, is_visible: bool = False) -> None:
for channel_no, channel in fields.items():
for field_name, parameter in channel.items():
channel_address = get_channel_address(
device_address=self.device.device_address, channel_no=channel_no
device_address=self._device.device_address, channel_no=channel_no
)
if entity := self.device.get_generic_entity(
if entity := self._device.get_generic_entity(
channel_address=channel_address, parameter=parameter
):
if is_visible and entity.wrapped is False:
Expand All @@ -207,7 +207,9 @@ def _add_entity(
if is_visible:
entity.set_usage(EntityUsage.CE_VISIBLE)

entity.register_update_callback(self.update_entity)
entity.register_update_callback(
update_callback=self.update_entity, source=CallBackSource.INTERNAL
)
self._data_entities[field_name] = entity

def _mark_entities(self, entity_def: dict[int | tuple[int, ...], tuple[str, ...]]) -> None:
Expand All @@ -224,11 +226,11 @@ def _mark_entities(self, entity_def: dict[int | tuple[int, ...], tuple[str, ...]
def _mark_entity(self, channel_no: int | None, parameters: tuple[str, ...]) -> None:
"""Mark entity to be created in HA."""
channel_address = get_channel_address(
device_address=self.device.device_address, channel_no=channel_no
device_address=self._device.device_address, channel_no=channel_no
)

for parameter in parameters:
entity = self.device.get_generic_entity(
entity = self._device.get_generic_entity(
channel_address=channel_address, parameter=parameter
)
if entity:
Expand All @@ -241,7 +243,7 @@ def _mark_entity_by_custom_un_ignore_parameters(
if not un_ignore_params_by_paramset_key:
return # pragma: no cover
for paramset_key, un_ignore_params in un_ignore_params_by_paramset_key.items():
for entity in self.device.generic_entities:
for entity in self._device.generic_entities:
if entity.paramset_key == paramset_key and entity.parameter in un_ignore_params:
entity.set_usage(EntityUsage.ENTITY)

Expand Down
Loading

0 comments on commit 7bb8a3a

Please sign in to comment.