diff --git a/bumble/device.py b/bumble/device.py index 21ca31df..dd0e2490 100644 --- a/bumble/device.py +++ b/bumble/device.py @@ -5618,8 +5618,8 @@ def add_default_services( async def notify_subscriber( self, connection: Connection, - attribute: Attribute, - value: Any | None = None, + attribute: Attribute[_T], + value: _T | None = None, force: bool = False, ) -> None: """ @@ -5638,7 +5638,7 @@ async def notify_subscriber( await self.gatt_server.notify_subscriber(connection, attribute, value, force) async def notify_subscribers( - self, attribute: Attribute, value: Any | None = None, force: bool = False + self, attribute: Attribute[_T], value: _T | None = None, force: bool = False ) -> None: """ Send a notification to all the subscribers of an attribute. @@ -5657,8 +5657,8 @@ async def notify_subscribers( async def indicate_subscriber( self, connection: Connection, - attribute: Attribute, - value: Any | None = None, + attribute: Attribute[_T], + value: _T | None = None, force: bool = False, ): """ @@ -5679,7 +5679,7 @@ async def indicate_subscriber( await self.gatt_server.indicate_subscriber(connection, attribute, value, force) async def indicate_subscribers( - self, attribute: Attribute, value: Any | None = None, force: bool = False + self, attribute: Attribute[_T], value: _T | None = None, force: bool = False ): """ Send an indication to all the subscribers of an attribute. diff --git a/bumble/gatt_server.py b/bumble/gatt_server.py index 74be1626..31127fb0 100644 --- a/bumble/gatt_server.py +++ b/bumble/gatt_server.py @@ -67,6 +67,8 @@ # Helpers # ----------------------------------------------------------------------------- +_T = TypeVar('_T') + def _bearer_id(bearer: att.Bearer) -> str: if att.is_enhanced_bearer(bearer): @@ -369,8 +371,8 @@ def send_response(self, bearer: att.Bearer, response: att.ATT_PDU) -> None: async def notify_subscriber( self, bearer: att.Bearer, - attribute: att.Attribute, - value: bytes | None = None, + attribute: att.Attribute[_T], + value: _T | None = None, force: bool = False, ) -> None: if att.is_enhanced_bearer(bearer) or force: @@ -390,8 +392,8 @@ async def notify_subscriber( async def _notify_single_subscriber( self, bearer: att.Bearer, - attribute: att.Attribute, - value: bytes | None, + attribute: att.Attribute[_T], + value: _T | None, force: bool, ) -> None: # Check if there's a subscriber @@ -411,19 +413,19 @@ async def _notify_single_subscriber( return # Get or encode the value - value = ( + value_as_bytes = ( await attribute.read_value(bearer) if value is None else attribute.encode_value(value) ) # Truncate if needed - if len(value) > bearer.att_mtu - 3: - value = value[: bearer.att_mtu - 3] + if len(value_as_bytes) > bearer.att_mtu - 3: + value_as_bytes = value_as_bytes[: bearer.att_mtu - 3] # Notify notification = att.ATT_Handle_Value_Notification( - attribute_handle=attribute.handle, attribute_value=value + attribute_handle=attribute.handle, attribute_value=value_as_bytes ) logger.debug(f'GATT Notify from server: {_bearer_id(bearer)} {notification}') self.send_gatt_pdu(bearer, bytes(notification)) @@ -431,8 +433,8 @@ async def _notify_single_subscriber( async def indicate_subscriber( self, bearer: att.Bearer, - attribute: att.Attribute, - value: bytes | None = None, + attribute: att.Attribute[_T], + value: _T | None = None, force: bool = False, ) -> None: if att.is_enhanced_bearer(bearer) or force: @@ -452,8 +454,8 @@ async def indicate_subscriber( async def _indicate_single_bearer( self, bearer: att.Bearer, - attribute: att.Attribute, - value: bytes | None, + attribute: att.Attribute[_T], + value: _T | None, force: bool, ) -> None: # Check if there's a subscriber @@ -473,19 +475,19 @@ async def _indicate_single_bearer( return # Get or encode the value - value = ( + value_as_bytes = ( await attribute.read_value(bearer) if value is None else attribute.encode_value(value) ) # Truncate if needed - if len(value) > bearer.att_mtu - 3: - value = value[: bearer.att_mtu - 3] + if len(value_as_bytes) > bearer.att_mtu - 3: + value_as_bytes = value_as_bytes[: bearer.att_mtu - 3] # Indicate indication = att.ATT_Handle_Value_Indication( - attribute_handle=attribute.handle, attribute_value=value + attribute_handle=attribute.handle, attribute_value=value_as_bytes ) logger.debug(f'GATT Indicate from server: {_bearer_id(bearer)} {indication}') @@ -510,8 +512,8 @@ async def _indicate_single_bearer( async def _notify_or_indicate_subscribers( self, indicate: bool, - attribute: att.Attribute, - value: bytes | None = None, + attribute: att.Attribute[_T], + value: _T | None = None, force: bool = False, ) -> None: # Get all the bearers for which there's at least one subscription @@ -537,8 +539,8 @@ async def _notify_or_indicate_subscribers( async def notify_subscribers( self, - attribute: att.Attribute, - value: bytes | None = None, + attribute: att.Attribute[_T], + value: _T | None = None, force: bool = False, ): return await self._notify_or_indicate_subscribers( @@ -547,8 +549,8 @@ async def notify_subscribers( async def indicate_subscribers( self, - attribute: att.Attribute, - value: bytes | None = None, + attribute: att.Attribute[_T], + value: _T | None = None, force: bool = False, ): return await self._notify_or_indicate_subscribers(True, attribute, value, force)