From 3ad4ad646d8e9d7709435bd806460c8f37567a22 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Wed, 21 Jun 2023 14:32:48 +0200 Subject: [PATCH 1/3] Fix async_scanner_devices_by_address unexpectedly combining scanners Since _get_scanners_by_type returned the scanners without making a copy, async_scanner_devices_by_address would extend the list of scanners and unexpectedly add the non-connectable ones to the connectable list. Refactor to remove the _get_X_by_type functions to avoid this pattern --- homeassistant/components/bluetooth/manager.py | 74 +++++++++---------- 1 file changed, 34 insertions(+), 40 deletions(-) diff --git a/homeassistant/components/bluetooth/manager.py b/homeassistant/components/bluetooth/manager.py index 3210822e795274..e5b407442090ba 100644 --- a/homeassistant/components/bluetooth/manager.py +++ b/homeassistant/components/bluetooth/manager.py @@ -100,7 +100,11 @@ def _dispatch_bleak_callback( class BluetoothManager: - """Manage Bluetooth.""" + """Manage Bluetooth. + + This class is expected to be a singleton and should + never be instantiated more than once. + """ def __init__( self, @@ -246,9 +250,9 @@ def async_scanner_devices_by_address( self, address: str, connectable: bool ) -> list[BluetoothScannerDevice]: """Get BluetoothScannerDevice by address.""" - scanners = self._get_scanners_by_type(True) + scanners = [*self._connectable_scanners] if not connectable: - scanners.extend(self._get_scanners_by_type(False)) + scanners.extend(self._non_connectable_scanners) return [ BluetoothScannerDevice(scanner, *device_adv) for scanner in scanners @@ -267,21 +271,19 @@ def _async_all_discovered_addresses(self, connectable: bool) -> Iterable[str]: """ yield from itertools.chain.from_iterable( scanner.discovered_devices_and_advertisement_data - for scanner in self._get_scanners_by_type(True) + for scanner in self._connectable_scanners ) if not connectable: yield from itertools.chain.from_iterable( scanner.discovered_devices_and_advertisement_data - for scanner in self._get_scanners_by_type(False) + for scanner in self._non_connectable_scanners ) @hass_callback def async_discovered_devices(self, connectable: bool) -> list[BLEDevice]: """Return all of combined best path to discovered from all the scanners.""" - return [ - history.device - for history in self._get_history_by_type(connectable).values() - ] + histories = self._connectable_history if connectable else self._all_history + return [history.device for history in histories.values()] @hass_callback def async_setup_unavailable_tracking(self) -> None: @@ -303,7 +305,10 @@ def _async_check_unavailable(self, now: datetime) -> None: intervals = tracker.intervals for connectable in (True, False): - unavailable_callbacks = self._get_unavailable_callbacks_by_type(connectable) + if connectable: + unavailable_callbacks = self._connectable_unavailable_callbacks + else: + unavailable_callbacks = self._unavailable_callbacks history = connectable_history if connectable else all_history disappeared = set(history).difference( self._async_all_discovered_addresses(connectable) @@ -583,7 +588,10 @@ def async_track_unavailable( connectable: bool, ) -> Callable[[], None]: """Register a callback.""" - unavailable_callbacks = self._get_unavailable_callbacks_by_type(connectable) + if connectable: + unavailable_callbacks = self._connectable_unavailable_callbacks + else: + unavailable_callbacks = self._unavailable_callbacks unavailable_callbacks.setdefault(address, []).append(callback) @hass_callback @@ -620,13 +628,13 @@ def _async_remove_callback() -> None: # If we have history for the subscriber, we can trigger the callback # immediately with the last packet so the subscriber can see the # device. - all_history = self._get_history_by_type(connectable) + history = self._connectable_history if connectable else self._all_history service_infos: Iterable[BluetoothServiceInfoBleak] = [] if address := callback_matcher.get(ADDRESS): - if service_info := all_history.get(address): + if service_info := history.get(address): service_infos = [service_info] else: - service_infos = all_history.values() + service_infos = history.values() for service_info in service_infos: if ble_device_matches(callback_matcher, service_info): @@ -642,29 +650,32 @@ def async_ble_device_from_address( self, address: str, connectable: bool ) -> BLEDevice | None: """Return the BLEDevice if present.""" - all_history = self._get_history_by_type(connectable) - if history := all_history.get(address): + histories = self._connectable_history if connectable else self._all_history + if history := histories.get(address): return history.device return None @hass_callback def async_address_present(self, address: str, connectable: bool) -> bool: """Return if the address is present.""" - return address in self._get_history_by_type(connectable) + histories = self._connectable_history if connectable else self._all_history + return address in histories @hass_callback def async_discovered_service_info( self, connectable: bool ) -> Iterable[BluetoothServiceInfoBleak]: """Return all the discovered services info.""" - return self._get_history_by_type(connectable).values() + histories = self._connectable_history if connectable else self._all_history + return histories.values() @hass_callback def async_last_service_info( self, address: str, connectable: bool ) -> BluetoothServiceInfoBleak | None: """Return the last service info for an address.""" - return self._get_history_by_type(connectable).get(address) + histories = self._connectable_history if connectable else self._all_history + return histories.get(address) def _async_trigger_matching_discovery( self, service_info: BluetoothServiceInfoBleak @@ -688,26 +699,6 @@ def async_rediscover_address(self, address: str) -> None: if service_info := self._all_history.get(address): self._async_trigger_matching_discovery(service_info) - def _get_scanners_by_type(self, connectable: bool) -> list[BaseHaScanner]: - """Return the scanners by type.""" - if connectable: - return self._connectable_scanners - return self._non_connectable_scanners - - def _get_unavailable_callbacks_by_type( - self, connectable: bool - ) -> dict[str, list[Callable[[BluetoothServiceInfoBleak], None]]]: - """Return the unavailable callbacks by type.""" - if connectable: - return self._connectable_unavailable_callbacks - return self._unavailable_callbacks - - def _get_history_by_type( - self, connectable: bool - ) -> dict[str, BluetoothServiceInfoBleak]: - """Return the history by type.""" - return self._connectable_history if connectable else self._all_history - def async_register_scanner( self, scanner: BaseHaScanner, @@ -716,7 +707,10 @@ def async_register_scanner( ) -> CALLBACK_TYPE: """Register a new scanner.""" _LOGGER.debug("Registering scanner %s", scanner.name) - scanners = self._get_scanners_by_type(connectable) + if connectable: + scanners = self._connectable_scanners + else: + scanners = self._non_connectable_scanners def _unregister_scanner() -> None: _LOGGER.debug("Unregistering scanner %s", scanner.name) From 930e5c01fa9759cfb30a5b6be32dabcbde6198fe Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Wed, 21 Jun 2023 14:33:34 +0200 Subject: [PATCH 2/3] Update homeassistant/components/bluetooth/manager.py --- homeassistant/components/bluetooth/manager.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/homeassistant/components/bluetooth/manager.py b/homeassistant/components/bluetooth/manager.py index e5b407442090ba..4b3568a7c253aa 100644 --- a/homeassistant/components/bluetooth/manager.py +++ b/homeassistant/components/bluetooth/manager.py @@ -100,11 +100,7 @@ def _dispatch_bleak_callback( class BluetoothManager: - """Manage Bluetooth. - - This class is expected to be a singleton and should - never be instantiated more than once. - """ + """Manage Bluetooth.""" def __init__( self, From 8a4a936f4d24ddff81307f5c1ebaa8b3fce39d80 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Wed, 21 Jun 2023 15:02:32 +0200 Subject: [PATCH 3/3] use chain instead --- homeassistant/components/bluetooth/manager.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/homeassistant/components/bluetooth/manager.py b/homeassistant/components/bluetooth/manager.py index 4b3568a7c253aa..f1221290c74105 100644 --- a/homeassistant/components/bluetooth/manager.py +++ b/homeassistant/components/bluetooth/manager.py @@ -246,9 +246,12 @@ def async_scanner_devices_by_address( self, address: str, connectable: bool ) -> list[BluetoothScannerDevice]: """Get BluetoothScannerDevice by address.""" - scanners = [*self._connectable_scanners] if not connectable: - scanners.extend(self._non_connectable_scanners) + scanners: Iterable[BaseHaScanner] = itertools.chain( + self._connectable_scanners, self._non_connectable_scanners + ) + else: + scanners = self._connectable_scanners return [ BluetoothScannerDevice(scanner, *device_adv) for scanner in scanners