diff --git a/bumble/avrcp.py b/bumble/avrcp.py index 5fe52d25..98ded8f1 100644 --- a/bumble/avrcp.py +++ b/bumble/avrcp.py @@ -21,11 +21,12 @@ import enum import logging import struct -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import ( AsyncIterator, Awaitable, Callable, + ClassVar, Iterable, List, Optional, @@ -36,7 +37,7 @@ cast, ) -from bumble import avc, avctp, core, l2cap, utils +from bumble import avc, avctp, core, hci, l2cap, utils from bumble.colors import color from bumble.device import Connection, Device from bumble.sdp import ( @@ -64,6 +65,96 @@ AVRCP_BLUETOOTH_SIG_COMPANY_ID = 0x001958 +class PduId(utils.OpenIntEnum): + GET_CAPABILITIES = 0x10 + LIST_PLAYER_APPLICATION_SETTING_ATTRIBUTES = 0x11 + LIST_PLAYER_APPLICATION_SETTING_VALUES = 0x12 + GET_CURRENT_PLAYER_APPLICATION_SETTING_VALUE = 0x13 + SET_PLAYER_APPLICATION_SETTING_VALUE = 0x14 + GET_PLAYER_APPLICATION_SETTING_ATTRIBUTE_TEXT = 0x15 + GET_PLAYER_APPLICATION_SETTING_VALUE_TEXT = 0x16 + INFORM_DISPLAYABLE_CHARACTER_SET = 0x17 + INFORM_BATTERY_STATUS_OF_CT = 0x18 + GET_ELEMENT_ATTRIBUTES = 0x20 + GET_PLAY_STATUS = 0x30 + REGISTER_NOTIFICATION = 0x31 + REQUEST_CONTINUING_RESPONSE = 0x40 + ABORT_CONTINUING_RESPONSE = 0x41 + SET_ABSOLUTE_VOLUME = 0x50 + SET_ADDRESSED_PLAYER = 0x60 + SET_BROWSED_PLAYER = 0x70 + GET_FOLDER_ITEMS = 0x71 + GET_TOTAL_NUMBER_OF_ITEMS = 0x75 + + +class CharacterSetId(hci.SpecableEnum): + UTF_8 = 0x06 + + +class MediaAttributeId(hci.SpecableEnum): + TITLE = 0x01 + ARTIST_NAME = 0x02 + ALBUM_NAME = 0x03 + TRACK_NUMBER = 0x04 + TOTAL_NUMBER_OF_TRACKS = 0x05 + GENRE = 0x06 + PLAYING_TIME = 0x07 + DEFAULT_COVER_ART = 0x08 + + +class PlayStatus(hci.SpecableEnum): + STOPPED = 0x00 + PLAYING = 0x01 + PAUSED = 0x02 + FWD_SEEK = 0x03 + REV_SEEK = 0x04 + ERROR = 0xFF + + +class EventId(hci.SpecableEnum): + PLAYBACK_STATUS_CHANGED = 0x01 + TRACK_CHANGED = 0x02 + TRACK_REACHED_END = 0x03 + TRACK_REACHED_START = 0x04 + PLAYBACK_POS_CHANGED = 0x05 + BATT_STATUS_CHANGED = 0x06 + SYSTEM_STATUS_CHANGED = 0x07 + PLAYER_APPLICATION_SETTING_CHANGED = 0x08 + NOW_PLAYING_CONTENT_CHANGED = 0x09 + AVAILABLE_PLAYERS_CHANGED = 0x0A + ADDRESSED_PLAYER_CHANGED = 0x0B + UIDS_CHANGED = 0x0C + VOLUME_CHANGED = 0x0D + + def __bytes__(self) -> bytes: + return bytes([int(self)]) + + +class StatusCode(hci.SpecableEnum): + INVALID_COMMAND = 0x00 + INVALID_PARAMETER = 0x01 + PARAMETER_CONTENT_ERROR = 0x02 + INTERNAL_ERROR = 0x03 + OPERATION_COMPLETED = 0x04 + UID_CHANGED = 0x05 + INVALID_DIRECTION = 0x07 + NOT_A_DIRECTORY = 0x08 + DOES_NOT_EXIST = 0x09 + INVALID_SCOPE = 0x0A + RANGE_OUT_OF_BOUNDS = 0x0B + FOLDER_ITEM_IS_NOT_PLAYABLE = 0x0C + MEDIA_IN_USE = 0x0D + NOW_PLAYING_LIST_FULL = 0x0E + SEARCH_NOT_SUPPORTED = 0x0F + SEARCH_IN_PROGRESS = 0x10 + INVALID_PLAYER_ID = 0x11 + PLAYER_NOT_BROWSABLE = 0x12 + PLAYER_NOT_ADDRESSED = 0x13 + NO_VALID_SEARCH_RESULTS = 0x14 + NO_AVAILABLE_PLAYERS = 0x15 + ADDRESSED_PLAYER_CHANGED = 0x16 + + # ----------------------------------------------------------------------------- def make_controller_service_sdp_records( service_record_handle: int, @@ -200,14 +291,52 @@ def make_target_service_sdp_records( # ----------------------------------------------------------------------------- -def _decode_attribute_value(value: bytes, character_set: CharacterSetId) -> str: - try: - if character_set == CharacterSetId.UTF_8: - return value.decode("utf-8") - return value.decode("ascii") - except UnicodeDecodeError: - logger.warning(f"cannot decode string with bytes: {value.hex()}") - return "" +@dataclass +class MediaAttribute: + attribute_id: MediaAttributeId + attribute_value: str + character_set_id: CharacterSetId = CharacterSetId.UTF_8 + + @classmethod + def _decode_attribute_value( + cls, value: bytes, character_set: CharacterSetId + ) -> str: + try: + if character_set == CharacterSetId.UTF_8: + return value.decode("utf-8") + return value.decode("ascii") + except UnicodeDecodeError: + logger.warning(f"cannot decode string with bytes: {value.hex()}") + return value.hex() + + @classmethod + def parse_from_bytes(cls, pdu: bytes, offset: int) -> tuple[int, MediaAttribute]: + ( + attribute_id_int, + character_set_id_int, + attribute_value_length, + ) = struct.unpack_from(">IHH", pdu, offset) + attribute_value_bytes = pdu[offset + 8 : offset + 8 + attribute_value_length] + character_set_id = CharacterSetId(character_set_id_int) + return offset + 8 + attribute_value_length, cls( + attribute_id=MediaAttributeId(attribute_id_int), + character_set_id=character_set_id, + attribute_value=cls._decode_attribute_value( + attribute_value_bytes, character_set_id + ), + ) + + def __bytes__(self) -> bytes: + attribute_value_bytes = self.attribute_value.encode("utf-8") + return ( + struct.pack( + ">IHH", + int(self.attribute_id), + int(self.character_set_id), + len(attribute_value_bytes), + ) + + attribute_value_bytes + ) # ----------------------------------------------------------------------------- @@ -218,10 +347,10 @@ class PduAssembler: 6.3.1 AVRCP specific AV//C commands """ - pdu_id: Optional[Protocol.PduId] + pdu_id: Optional[PduId] payload: bytes - def __init__(self, callback: Callable[[Protocol.PduId, bytes], None]) -> None: + def __init__(self, callback: Callable[[PduId, bytes], None]) -> None: self.callback = callback self.reset() @@ -230,7 +359,7 @@ def reset(self) -> None: self.parameter = b'' def on_pdu(self, pdu: bytes) -> None: - pdu_id = Protocol.PduId(pdu[0]) + pdu_id = PduId(pdu[0]) packet_type = Protocol.PacketType(pdu[1] & 3) parameter_length = struct.unpack_from('>H', pdu, 2)[0] parameter = pdu[4 : 4 + parameter_length] @@ -269,178 +398,151 @@ def on_pdu_complete(self) -> None: # ----------------------------------------------------------------------------- -@dataclass class Command: - pdu_id: Protocol.PduId - parameter: bytes + pdu_id: ClassVar[PduId] + _payload: Optional[bytes] = None - def to_string(self, properties: dict[str, str]) -> str: - properties_str = ",".join( - [f"{name}={value}" for name, value in properties.items()] - ) - return f"Command[{self.pdu_id.name}]({properties_str})" + _Command = TypeVar('_Command', bound='Command') + subclasses: ClassVar[dict[int, type[Command]]] = {} + fields: ClassVar[hci.Fields] = () - def __str__(self) -> str: - return self.to_string({"parameters": self.parameter.hex()}) + @classmethod + def command(cls, subclass: type[_Command]) -> type[_Command]: + cls.subclasses[subclass.pdu_id] = subclass + subclass.fields = hci.HCI_Object.fields_from_dataclass(subclass) + return subclass - def __repr__(self) -> str: - return str(self) + @classmethod + def from_bytes(cls, pdu_id: int, pdu: bytes) -> Command: + if not (subclass := cls.subclasses.get(pdu_id)): + raise core.InvalidPacketError(f"Unimplemented PDU {pdu_id}") + instance = subclass(**hci.HCI_Object.dict_from_bytes(pdu, 0, subclass.fields)) + instance._payload = pdu[0:] + return instance + + def __bytes__(self) -> bytes: + if self._payload is None: + self._payload = hci.HCI_Object.dict_to_bytes(self.__dict__, self.fields) + return self._payload # ----------------------------------------------------------------------------- +@Command.command +@dataclass class GetCapabilitiesCommand(Command): - class CapabilityId(utils.OpenIntEnum): + pdu_id = PduId.GET_CAPABILITIES + + class CapabilityId(hci.SpecableEnum): COMPANY_ID = 0x02 EVENTS_SUPPORTED = 0x03 - capability_id: CapabilityId - - @classmethod - def from_bytes(cls, pdu: bytes) -> GetCapabilitiesCommand: - return cls(cls.CapabilityId(pdu[0])) - - def __init__(self, capability_id: CapabilityId) -> None: - super().__init__(Protocol.PduId.GET_CAPABILITIES, bytes([capability_id])) - self.capability_id = capability_id - - def __str__(self) -> str: - return self.to_string({"capability_id": self.capability_id.name}) + capability_id: CapabilityId = field(metadata=CapabilityId.type_metadata(1)) # ----------------------------------------------------------------------------- +@Command.command +@dataclass class GetPlayStatusCommand(Command): - @classmethod - def from_bytes(cls, _: bytes) -> GetPlayStatusCommand: - return cls() - - def __init__(self) -> None: - super().__init__(Protocol.PduId.GET_PLAY_STATUS, b'') + pdu_id = PduId.GET_PLAY_STATUS # ----------------------------------------------------------------------------- +@Command.command +@dataclass class GetElementAttributesCommand(Command): - identifier: int - attribute_ids: list[MediaAttributeId] + pdu_id = PduId.GET_ELEMENT_ATTRIBUTES - @classmethod - def from_bytes(cls, pdu: bytes) -> GetElementAttributesCommand: - identifier = struct.unpack_from(">Q", pdu)[0] - num_attributes = pdu[8] - attribute_ids = [MediaAttributeId(pdu[9 + i]) for i in range(num_attributes)] - return cls(identifier, attribute_ids) - - def __init__( - self, identifier: int, attribute_ids: Sequence[MediaAttributeId] - ) -> None: - parameter = struct.pack(">QB", identifier, len(attribute_ids)) + b''.join( - [struct.pack(">I", int(attribute_id)) for attribute_id in attribute_ids] + identifier: int = field( + metadata=hci.metadata( + { + 'parser': lambda data, offset: ( + offset + 8, + int.from_bytes(data[offset : offset + 8], byteorder='big'), + ), + 'serializer': lambda x: x.to_bytes(8, byteorder='big'), + } ) - super().__init__(Protocol.PduId.GET_ELEMENT_ATTRIBUTES, parameter) - self.identifier = identifier - self.attribute_ids = list(attribute_ids) + ) + attribute_ids: Sequence[MediaAttributeId] = field( + metadata=MediaAttributeId.type_metadata(1, list_begin=True, list_end=True) + ) # ----------------------------------------------------------------------------- +@Command.command +@dataclass class SetAbsoluteVolumeCommand(Command): + pdu_id = PduId.SET_ABSOLUTE_VOLUME MAXIMUM_VOLUME = 0x7F - volume: int - - @classmethod - def from_bytes(cls, pdu: bytes) -> SetAbsoluteVolumeCommand: - return cls(pdu[0]) - - def __init__(self, volume: int) -> None: - super().__init__(Protocol.PduId.SET_ABSOLUTE_VOLUME, bytes([volume])) - self.volume = volume - - def __str__(self) -> str: - return self.to_string({"volume": str(self.volume)}) + volume: int = field(metadata=hci.metadata(1)) # ----------------------------------------------------------------------------- +@Command.command +@dataclass class RegisterNotificationCommand(Command): - event_id: EventId - playback_interval: int + pdu_id = PduId.REGISTER_NOTIFICATION - @classmethod - def from_bytes(cls, pdu: bytes) -> RegisterNotificationCommand: - event_id = EventId(pdu[0]) - playback_interval = struct.unpack_from(">I", pdu, 1)[0] - return cls(event_id, playback_interval) - - def __init__(self, event_id: EventId, playback_interval: int) -> None: - super().__init__( - Protocol.PduId.REGISTER_NOTIFICATION, - struct.pack(">BI", int(event_id), playback_interval), - ) - self.event_id = event_id - self.playback_interval = playback_interval - - def __str__(self) -> str: - return self.to_string( - { - "event_id": self.event_id.name, - "playback_interval": str(self.playback_interval), - } - ) + event_id: EventId = field(metadata=EventId.type_metadata(1)) + playback_interval: int = field(metadata=hci.metadata('>4')) # ----------------------------------------------------------------------------- -@dataclass class Response: - pdu_id: Protocol.PduId - parameter: bytes - - def to_string(self, properties: dict[str, str]) -> str: - properties_str = ",".join( - [f"{name}={value}" for name, value in properties.items()] - ) - return f"Response[{self.pdu_id.name}]({properties_str})" + pdu_id: PduId + _payload: Optional[bytes] = None - def __str__(self) -> str: - return self.to_string({"parameter": self.parameter.hex()}) + fields: ClassVar[hci.Fields] = () - def __repr__(self) -> str: - return str(self) + _Response = TypeVar('_Response', bound='Response') + @classmethod + def register(cls, subclass: type[_Response]) -> type[_Response]: + subclass.fields = hci.HCI_Object.fields_from_dataclass(subclass) + return subclass -# ----------------------------------------------------------------------------- -class RejectedResponse(Response): - status_code: Protocol.StatusCode + def __bytes__(self) -> bytes: + if self._payload is None: + self._payload = hci.HCI_Object.dict_to_bytes(self.__dict__, self.fields) + return self._payload @classmethod - def from_bytes(cls, pdu_id: Protocol.PduId, pdu: bytes) -> RejectedResponse: - return cls(pdu_id, Protocol.StatusCode(pdu[0])) + def from_bytes(cls, pdu: bytes, pdu_id: Optional[PduId] = None) -> Response: + kwargs = hci.HCI_Object.dict_from_bytes(pdu, 0, cls.fields) + if pdu_id is not None: + kwargs['pdu_id'] = pdu_id + instance = cls(**kwargs) + instance._payload = pdu + return instance - def __init__( - self, pdu_id: Protocol.PduId, status_code: Protocol.StatusCode - ) -> None: - super().__init__(pdu_id, bytes([int(status_code)])) - self.status_code = status_code - def __str__(self) -> str: - return self.to_string( - { - "status_code": self.status_code.name, - } - ) +# ----------------------------------------------------------------------------- +@Response.register +@dataclass +class RejectedResponse(Response): + pdu_id: PduId + status_code: StatusCode = field(metadata=StatusCode.type_metadata(1)) # ----------------------------------------------------------------------------- +@Response.register +@dataclass class NotImplementedResponse(Response): - @classmethod - def from_bytes(cls, pdu_id: Protocol.PduId, pdu: bytes) -> NotImplementedResponse: - return cls(pdu_id, pdu[1:]) + pdu_id: PduId + parameters: bytes = field(metadata=hci.metadata('*')) # ----------------------------------------------------------------------------- +@dataclass class GetCapabilitiesResponse(Response): + pdu_id = PduId.GET_CAPABILITIES capability_id: GetCapabilitiesCommand.CapabilityId - capabilities: list[Union[SupportsBytes, bytes]] + capabilities: Sequence[Union[SupportsBytes, bytes]] @classmethod - def from_bytes(cls, pdu: bytes) -> GetCapabilitiesResponse: + def from_bytes(cls, pdu: bytes, pdu_id: Optional[PduId] = None) -> Response: + del pdu_id # Unused. if len(pdu) < 2: # Possibly a reject response. return cls(GetCapabilitiesCommand.CapabilityId(0), []) @@ -462,215 +564,52 @@ def from_bytes(cls, pdu: bytes) -> GetCapabilitiesResponse: return cls(capability_id, capabilities) - def __init__( - self, - capability_id: GetCapabilitiesCommand.CapabilityId, - capabilities: Sequence[Union[SupportsBytes, bytes]], - ) -> None: - super().__init__( - Protocol.PduId.GET_CAPABILITIES, - bytes([capability_id, len(capabilities)]) - + b''.join(bytes(capability) for capability in capabilities), - ) - self.capability_id = capability_id - self.capabilities = list(capabilities) - - def __str__(self) -> str: - return self.to_string( - { - "capability_id": self.capability_id.name, - "capabilities": str(self.capabilities), - } + def __post_init__(self) -> None: + self._payload = bytes([self.capability_id, len(self.capabilities)]) + b''.join( + bytes(capability) for capability in self.capabilities ) # ----------------------------------------------------------------------------- +@Response.register +@dataclass class GetPlayStatusResponse(Response): - song_length: int - song_position: int - play_status: PlayStatus - - @classmethod - def from_bytes(cls, pdu: bytes) -> GetPlayStatusResponse: - (song_length, song_position) = struct.unpack_from(">II", pdu, 0) - play_status = PlayStatus(pdu[8]) - - return cls(song_length, song_position, play_status) - - def __init__( - self, - song_length: int, - song_position: int, - play_status: PlayStatus, - ) -> None: - super().__init__( - Protocol.PduId.GET_PLAY_STATUS, - struct.pack(">IIB", song_length, song_position, int(play_status)), - ) - self.song_length = song_length - self.song_position = song_position - self.play_status = play_status - - def __str__(self) -> str: - return self.to_string( - { - "song_length": str(self.song_length), - "song_position": str(self.song_position), - "play_status": self.play_status.name, - } - ) + pdu_id = PduId.GET_PLAY_STATUS + song_length: int = field(metadata=hci.metadata(">4")) + song_position: int = field(metadata=hci.metadata(">4")) + play_status: PlayStatus = field(metadata=PlayStatus.type_metadata(1)) # ----------------------------------------------------------------------------- +@Response.register +@dataclass class GetElementAttributesResponse(Response): - attributes: list[MediaAttribute] - - @classmethod - def from_bytes(cls, pdu: bytes) -> GetElementAttributesResponse: - num_attributes = pdu[0] - offset = 1 - attributes: list[MediaAttribute] = [] - for _ in range(num_attributes): - ( - attribute_id_int, - character_set_id_int, - attribute_value_length, - ) = struct.unpack_from(">IHH", pdu, offset) - attribute_value_bytes = pdu[ - offset + 8 : offset + 8 + attribute_value_length - ] - attribute_id = MediaAttributeId(attribute_id_int) - character_set_id = CharacterSetId(character_set_id_int) - attribute_value = _decode_attribute_value( - attribute_value_bytes, character_set_id - ) - attributes.append( - MediaAttribute(attribute_id, character_set_id, attribute_value) - ) - offset += 8 + attribute_value_length - - return cls(attributes) - - def __init__(self, attributes: Sequence[MediaAttribute]) -> None: - parameter = bytes([len(attributes)]) - for attribute in attributes: - attribute_value_bytes = attribute.attribute_value.encode("utf-8") - parameter += ( - struct.pack( - ">IHH", - int(attribute.attribute_id), - int(CharacterSetId.UTF_8), - len(attribute_value_bytes), - ) - + attribute_value_bytes - ) - super().__init__( - Protocol.PduId.GET_ELEMENT_ATTRIBUTES, - parameter, - ) - self.attributes = list(attributes) - - def __str__(self) -> str: - attribute_strs = [str(attribute) for attribute in self.attributes] - return self.to_string( - { - "attributes": f"[{', '.join(attribute_strs)}]", - } + pdu_id = PduId.GET_ELEMENT_ATTRIBUTES + attributes: Sequence[MediaAttribute] = field( + metadata=hci.metadata( + MediaAttribute.parse_from_bytes, list_begin=True, list_end=True ) + ) # ----------------------------------------------------------------------------- +@Response.register +@dataclass class SetAbsoluteVolumeResponse(Response): - volume: int - - @classmethod - def from_bytes(cls, pdu: bytes) -> SetAbsoluteVolumeResponse: - return cls(pdu[0]) - - def __init__(self, volume: int) -> None: - super().__init__(Protocol.PduId.SET_ABSOLUTE_VOLUME, bytes([volume])) - self.volume = volume - - def __str__(self) -> str: - return self.to_string({"volume": str(self.volume)}) + pdu_id = PduId.SET_ABSOLUTE_VOLUME + volume: int = field(metadata=hci.metadata(1)) # ----------------------------------------------------------------------------- +@Response.register +@dataclass class RegisterNotificationResponse(Response): - event: Event - - @classmethod - def from_bytes(cls, pdu: bytes) -> RegisterNotificationResponse: - return cls(Event.from_bytes(pdu)) - - def __init__(self, event: Event) -> None: - super().__init__( - Protocol.PduId.REGISTER_NOTIFICATION, - bytes(event), + pdu_id = PduId.REGISTER_NOTIFICATION + event: Event = field( + metadata=hci.metadata( + lambda data, offset: (len(data), Event.from_bytes(data[offset:])) ) - self.event = event - - def __str__(self) -> str: - return self.to_string( - { - "event": str(self.event), - } - ) - - -# ----------------------------------------------------------------------------- -class EventId(utils.OpenIntEnum): - PLAYBACK_STATUS_CHANGED = 0x01 - TRACK_CHANGED = 0x02 - TRACK_REACHED_END = 0x03 - TRACK_REACHED_START = 0x04 - PLAYBACK_POS_CHANGED = 0x05 - BATT_STATUS_CHANGED = 0x06 - SYSTEM_STATUS_CHANGED = 0x07 - PLAYER_APPLICATION_SETTING_CHANGED = 0x08 - NOW_PLAYING_CONTENT_CHANGED = 0x09 - AVAILABLE_PLAYERS_CHANGED = 0x0A - ADDRESSED_PLAYER_CHANGED = 0x0B - UIDS_CHANGED = 0x0C - VOLUME_CHANGED = 0x0D - - def __bytes__(self) -> bytes: - return bytes([int(self)]) - - -# ----------------------------------------------------------------------------- -class CharacterSetId(utils.OpenIntEnum): - UTF_8 = 0x06 - - -# ----------------------------------------------------------------------------- -class MediaAttributeId(utils.OpenIntEnum): - TITLE = 0x01 - ARTIST_NAME = 0x02 - ALBUM_NAME = 0x03 - TRACK_NUMBER = 0x04 - TOTAL_NUMBER_OF_TRACKS = 0x05 - GENRE = 0x06 - PLAYING_TIME = 0x07 - DEFAULT_COVER_ART = 0x08 - - -# ----------------------------------------------------------------------------- -@dataclass -class MediaAttribute: - attribute_id: MediaAttributeId - character_set_id: CharacterSetId - attribute_value: str - - -# ----------------------------------------------------------------------------- -class PlayStatus(utils.OpenIntEnum): - STOPPED = 0x00 - PLAYING = 0x01 - PAUSED = 0x02 - FWD_SEEK = 0x03 - REV_SEEK = 0x04 - ERROR = 0xFF + ) # ----------------------------------------------------------------------------- @@ -683,256 +622,180 @@ class SongAndPlayStatus: # ----------------------------------------------------------------------------- class ApplicationSetting: - class AttributeId(utils.OpenIntEnum): + class AttributeId(hci.SpecableEnum): EQUALIZER_ON_OFF = 0x01 REPEAT_MODE = 0x02 SHUFFLE_ON_OFF = 0x03 SCAN_ON_OFF = 0x04 - class EqualizerOnOffStatus(utils.OpenIntEnum): + class EqualizerOnOffStatus(hci.SpecableEnum): OFF = 0x01 ON = 0x02 - class RepeatModeStatus(utils.OpenIntEnum): + class RepeatModeStatus(hci.SpecableEnum): OFF = 0x01 SINGLE_TRACK_REPEAT = 0x02 ALL_TRACK_REPEAT = 0x03 GROUP_REPEAT = 0x04 - class ShuffleOnOffStatus(utils.OpenIntEnum): + class ShuffleOnOffStatus(hci.SpecableEnum): OFF = 0x01 ALL_TRACKS_SHUFFLE = 0x02 GROUP_SHUFFLE = 0x03 - class ScanOnOffStatus(utils.OpenIntEnum): + class ScanOnOffStatus(hci.SpecableEnum): OFF = 0x01 ALL_TRACKS_SCAN = 0x02 GROUP_SCAN = 0x03 - class GenericValue(utils.OpenIntEnum): + class GenericValue(hci.SpecableEnum): pass # ----------------------------------------------------------------------------- -@dataclass class Event: event_id: EventId + _pdu: Optional[bytes] = None + + _Event = TypeVar('_Event', bound='Event') + subclasses: ClassVar[dict[int, type[Event]]] = {} + fields: ClassVar[hci.Fields] = () + + @classmethod + def event(cls, subclass: type[_Event]) -> type[_Event]: + cls.subclasses[subclass.event_id] = subclass + subclass.fields = hci.HCI_Object.fields_from_dataclass(subclass) + return subclass @classmethod def from_bytes(cls, pdu: bytes) -> Event: - event_id = EventId(pdu[0]) - subclass = EVENT_SUBCLASSES.get(event_id, GenericEvent) - return subclass.from_bytes(pdu) + if not (subclass := cls.subclasses.get(pdu[0])): + raise core.InvalidPacketError(f"Unimplemented PDU {pdu[0]}") + instance = subclass(**hci.HCI_Object.dict_from_bytes(pdu, 1, subclass.fields)) + instance._pdu = pdu + return instance def __bytes__(self) -> bytes: - return bytes([self.event_id]) + if self._pdu is None: + self._pdu = bytes([self.event_id]) + hci.HCI_Object.dict_to_bytes( + self.__dict__, self.fields + ) + return self._pdu # ----------------------------------------------------------------------------- @dataclass class GenericEvent(Event): - data: bytes + event_id: EventId = field(metadata=EventId.type_metadata(1)) + data: bytes = field(metadata=hci.metadata('*')) - @classmethod - def from_bytes(cls, pdu: bytes) -> GenericEvent: - return cls(event_id=EventId(pdu[0]), data=pdu[1:]) - def __bytes__(self) -> bytes: - return bytes([self.event_id]) + self.data +GenericEvent.fields = hci.HCI_Object.fields_from_dataclass(GenericEvent) # ----------------------------------------------------------------------------- +@Event.event @dataclass class PlaybackStatusChangedEvent(Event): - play_status: PlayStatus - - @classmethod - def from_bytes(cls, pdu: bytes) -> PlaybackStatusChangedEvent: - return cls(play_status=PlayStatus(pdu[1])) - - def __init__(self, play_status: PlayStatus) -> None: - super().__init__(EventId.PLAYBACK_STATUS_CHANGED) - self.play_status = play_status - - def __bytes__(self) -> bytes: - return bytes([self.event_id]) + bytes([self.play_status]) + event_id = EventId.PLAYBACK_STATUS_CHANGED + play_status: PlayStatus = field(metadata=PlayStatus.type_metadata(1)) # ----------------------------------------------------------------------------- +@Event.event @dataclass class PlaybackPositionChangedEvent(Event): - playback_position: int - - @classmethod - def from_bytes(cls, pdu: bytes) -> PlaybackPositionChangedEvent: - return cls(playback_position=struct.unpack_from(">I", pdu, 1)[0]) - - def __init__(self, playback_position: int) -> None: - super().__init__(EventId.PLAYBACK_POS_CHANGED) - self.playback_position = playback_position - - def __bytes__(self) -> bytes: - return bytes([self.event_id]) + struct.pack(">I", self.playback_position) + event_id = EventId.PLAYBACK_POS_CHANGED + playback_position: int = field(metadata=hci.metadata('>4')) # ----------------------------------------------------------------------------- +@Event.event @dataclass class TrackChangedEvent(Event): - identifier: bytes - - @classmethod - def from_bytes(cls, pdu: bytes) -> TrackChangedEvent: - return cls(identifier=pdu[1:]) - - def __init__(self, identifier: bytes) -> None: - super().__init__(EventId.TRACK_CHANGED) - self.identifier = identifier - - def __bytes__(self) -> bytes: - return bytes([self.event_id]) + self.identifier + event_id = EventId.TRACK_CHANGED + identifier: bytes = field(metadata=hci.metadata('*')) # ----------------------------------------------------------------------------- +@Event.event @dataclass class PlayerApplicationSettingChangedEvent(Event): - @dataclass - class Setting: - attribute_id: ApplicationSetting.AttributeId - value_id: utils.OpenIntEnum + event_id = EventId.PLAYER_APPLICATION_SETTING_CHANGED - player_application_settings: list[Setting] - - @classmethod - def from_bytes(cls, pdu: bytes) -> PlayerApplicationSettingChangedEvent: - def setting(attribute_id_int: int, value_id_int: int): - attribute_id = ApplicationSetting.AttributeId(attribute_id_int) - value_id: utils.OpenIntEnum - if attribute_id == ApplicationSetting.AttributeId.EQUALIZER_ON_OFF: - value_id = ApplicationSetting.EqualizerOnOffStatus(value_id_int) - elif attribute_id == ApplicationSetting.AttributeId.REPEAT_MODE: - value_id = ApplicationSetting.RepeatModeStatus(value_id_int) - elif attribute_id == ApplicationSetting.AttributeId.SHUFFLE_ON_OFF: - value_id = ApplicationSetting.ShuffleOnOffStatus(value_id_int) - elif attribute_id == ApplicationSetting.AttributeId.SCAN_ON_OFF: - value_id = ApplicationSetting.ScanOnOffStatus(value_id_int) + @dataclass + class Setting(hci.HCI_Dataclass_Object): + attribute_id: ApplicationSetting.AttributeId = field( + metadata=ApplicationSetting.AttributeId.type_metadata(1) + ) + value_id: Union[ + ApplicationSetting.EqualizerOnOffStatus, + ApplicationSetting.RepeatModeStatus, + ApplicationSetting.ShuffleOnOffStatus, + ApplicationSetting.ScanOnOffStatus, + ApplicationSetting.GenericValue, + ] = field(metadata=hci.metadata(1)) + + def __post_init__(self) -> None: + super().__post_init__() + if self.attribute_id == ApplicationSetting.AttributeId.EQUALIZER_ON_OFF: + self.value_id = ApplicationSetting.EqualizerOnOffStatus(self.value_id) + elif self.attribute_id == ApplicationSetting.AttributeId.REPEAT_MODE: + self.value_id = ApplicationSetting.RepeatModeStatus(self.value_id) + elif self.attribute_id == ApplicationSetting.AttributeId.SHUFFLE_ON_OFF: + self.value_id = ApplicationSetting.ShuffleOnOffStatus(self.value_id) + elif self.attribute_id == ApplicationSetting.AttributeId.SCAN_ON_OFF: + self.value_id = ApplicationSetting.ScanOnOffStatus(self.value_id) else: - value_id = ApplicationSetting.GenericValue(value_id_int) + self.value_id = ApplicationSetting.GenericValue(self.value_id) - return cls.Setting(attribute_id, value_id) - - settings = [ - setting(pdu[2 + (i * 2)], pdu[2 + (i * 2) + 1]) for i in range(pdu[1]) - ] - return cls(player_application_settings=settings) - - def __init__(self, player_application_settings: Sequence[Setting]) -> None: - super().__init__(EventId.PLAYER_APPLICATION_SETTING_CHANGED) - self.player_application_settings = list(player_application_settings) - - def __bytes__(self) -> bytes: - return ( - bytes([self.event_id]) - + bytes([len(self.player_application_settings)]) - + b''.join( - [ - bytes([setting.attribute_id, setting.value_id]) - for setting in self.player_application_settings - ] - ) - ) + player_application_settings: Sequence[Setting] = field( + metadata=hci.metadata(Setting.parse_from_bytes, list_begin=True, list_end=True) + ) # ----------------------------------------------------------------------------- +@Event.event @dataclass class NowPlayingContentChangedEvent(Event): - @classmethod - def from_bytes(cls, pdu: bytes) -> NowPlayingContentChangedEvent: - return cls() - - def __init__(self) -> None: - super().__init__(EventId.NOW_PLAYING_CONTENT_CHANGED) + event_id = EventId.NOW_PLAYING_CONTENT_CHANGED # ----------------------------------------------------------------------------- +@Event.event @dataclass class AvailablePlayersChangedEvent(Event): - @classmethod - def from_bytes(cls, pdu: bytes) -> AvailablePlayersChangedEvent: - return cls() - - def __init__(self) -> None: - super().__init__(EventId.AVAILABLE_PLAYERS_CHANGED) + event_id = EventId.AVAILABLE_PLAYERS_CHANGED # ----------------------------------------------------------------------------- +@Event.event @dataclass class AddressedPlayerChangedEvent(Event): - @dataclass - class Player: - player_id: int - uid_counter: int + event_id = EventId.ADDRESSED_PLAYER_CHANGED - @classmethod - def from_bytes(cls, pdu: bytes) -> AddressedPlayerChangedEvent: - player_id, uid_counter = struct.unpack_from(" None: - super().__init__(EventId.ADDRESSED_PLAYER_CHANGED) - self.player = player + @dataclass + class Player(hci.HCI_Dataclass_Object): + player_id: int = field(metadata=hci.metadata('>2')) + uid_counter: int = field(metadata=hci.metadata('>2')) - def __bytes__(self) -> bytes: - return bytes([self.event_id]) + struct.pack( - ">HH", self.player.player_id, self.player.uid_counter - ) + player: Player = field(metadata=hci.metadata(Player.parse_from_bytes)) # ----------------------------------------------------------------------------- +@Event.event @dataclass class UidsChangedEvent(Event): - uid_counter: int - - @classmethod - def from_bytes(cls, pdu: bytes) -> UidsChangedEvent: - return cls(uid_counter=struct.unpack_from(">H", pdu, 1)[0]) - - def __init__(self, uid_counter: int) -> None: - super().__init__(EventId.UIDS_CHANGED) - self.uid_counter = uid_counter - - def __bytes__(self) -> bytes: - return bytes([self.event_id]) + struct.pack(">H", self.uid_counter) + event_id = EventId.UIDS_CHANGED + uid_counter: int = field(metadata=hci.metadata('>2')) # ----------------------------------------------------------------------------- +@Event.event @dataclass class VolumeChangedEvent(Event): - volume: int - - @classmethod - def from_bytes(cls, pdu: bytes) -> VolumeChangedEvent: - return cls(volume=pdu[1]) - - def __init__(self, volume: int) -> None: - super().__init__(EventId.VOLUME_CHANGED) - self.volume = volume - - def __bytes__(self) -> bytes: - return bytes([self.event_id]) + bytes([self.volume]) - - -# ----------------------------------------------------------------------------- -EVENT_SUBCLASSES: dict[EventId, type[Event]] = { - EventId.PLAYBACK_STATUS_CHANGED: PlaybackStatusChangedEvent, - EventId.PLAYBACK_POS_CHANGED: PlaybackPositionChangedEvent, - EventId.TRACK_CHANGED: TrackChangedEvent, - EventId.PLAYER_APPLICATION_SETTING_CHANGED: PlayerApplicationSettingChangedEvent, - EventId.NOW_PLAYING_CONTENT_CHANGED: NowPlayingContentChangedEvent, - EventId.AVAILABLE_PLAYERS_CHANGED: AvailablePlayersChangedEvent, - EventId.ADDRESSED_PLAYER_CHANGED: AddressedPlayerChangedEvent, - EventId.UIDS_CHANGED: UidsChangedEvent, - EventId.VOLUME_CHANGED: VolumeChangedEvent, -} + event_id = EventId.VOLUME_CHANGED + volume: int = field(metadata=hci.metadata(1)) # ----------------------------------------------------------------------------- @@ -947,7 +810,7 @@ class Delegate: class Error(Exception): """The delegate method failed, with a specified status code.""" - def __init__(self, status_code: Protocol.StatusCode) -> None: + def __init__(self, status_code: StatusCode) -> None: self.status_code = status_code supported_events: list[EventId] @@ -989,51 +852,6 @@ class PacketType(enum.IntEnum): CONTINUE = 0b10 END = 0b11 - class PduId(utils.OpenIntEnum): - GET_CAPABILITIES = 0x10 - LIST_PLAYER_APPLICATION_SETTING_ATTRIBUTES = 0x11 - LIST_PLAYER_APPLICATION_SETTING_VALUES = 0x12 - GET_CURRENT_PLAYER_APPLICATION_SETTING_VALUE = 0x13 - SET_PLAYER_APPLICATION_SETTING_VALUE = 0x14 - GET_PLAYER_APPLICATION_SETTING_ATTRIBUTE_TEXT = 0x15 - GET_PLAYER_APPLICATION_SETTING_VALUE_TEXT = 0x16 - INFORM_DISPLAYABLE_CHARACTER_SET = 0x17 - INFORM_BATTERY_STATUS_OF_CT = 0x18 - GET_ELEMENT_ATTRIBUTES = 0x20 - GET_PLAY_STATUS = 0x30 - REGISTER_NOTIFICATION = 0x31 - REQUEST_CONTINUING_RESPONSE = 0x40 - ABORT_CONTINUING_RESPONSE = 0x41 - SET_ABSOLUTE_VOLUME = 0x50 - SET_ADDRESSED_PLAYER = 0x60 - SET_BROWSED_PLAYER = 0x70 - GET_FOLDER_ITEMS = 0x71 - GET_TOTAL_NUMBER_OF_ITEMS = 0x75 - - class StatusCode(utils.OpenIntEnum): - INVALID_COMMAND = 0x00 - INVALID_PARAMETER = 0x01 - PARAMETER_CONTENT_ERROR = 0x02 - INTERNAL_ERROR = 0x03 - OPERATION_COMPLETED = 0x04 - UID_CHANGED = 0x05 - INVALID_DIRECTION = 0x07 - NOT_A_DIRECTORY = 0x08 - DOES_NOT_EXIST = 0x09 - INVALID_SCOPE = 0x0A - RANGE_OUT_OF_BOUNDS = 0x0B - FOLDER_ITEM_IS_NOT_PLAYABLE = 0x0C - MEDIA_IN_USE = 0x0D - NOW_PLAYING_LIST_FULL = 0x0E - SEARCH_NOT_SUPPORTED = 0x0F - SEARCH_IN_PROGRESS = 0x10 - INVALID_PLAYER_ID = 0x11 - PLAYER_NOT_BROWSABLE = 0x12 - PLAYER_NOT_ADDRESSED = 0x13 - NO_VALID_SEARCH_RESULTS = 0x14 - NO_AVAILABLE_PLAYERS = 0x15 - ADDRESSED_PLAYER_CHANGED = 0x16 - class InvalidPidError(Exception): """A response frame with ipid==1 was received.""" @@ -1208,7 +1026,7 @@ async def call() -> None: self.send_rejected_avrcp_response( transaction_label, command.pdu_id, - Protocol.StatusCode.INTERNAL_ERROR, + StatusCode.INTERNAL_ERROR, ) utils.AsyncRunner.spawn(call()) @@ -1243,7 +1061,7 @@ async def get_element_attributes( GetElementAttributesCommand(element_identifier, attribute_ids), ) response = self._check_response(response_context, GetElementAttributesResponse) - return response.attributes + return list(response.attributes) async def monitor_events( self, event_id: EventId, playback_interval: int = 0 @@ -1326,7 +1144,7 @@ async def monitor_player_application_settings( if not isinstance(event, PlayerApplicationSettingChangedEvent): logger.warning("unexpected event class") continue - yield event.player_application_settings + yield list(event.player_application_settings) async def monitor_now_playing_content(self) -> AsyncIterator[None]: """Monitor Now Playing changes from the connected peer.""" @@ -1596,29 +1414,24 @@ def _on_command_pdu(self, pdu_id: PduId, pdu: bytes) -> None: avc.CommandFrame.CommandType.NOTIFY, ): # TODO: catch exceptions from delegates - if pdu_id == self.PduId.GET_CAPABILITIES: - self._on_get_capabilities_command( - transaction_label, GetCapabilitiesCommand.from_bytes(pdu) - ) - elif pdu_id == self.PduId.SET_ABSOLUTE_VOLUME: - self._on_set_absolute_volume_command( - transaction_label, SetAbsoluteVolumeCommand.from_bytes(pdu) - ) - elif pdu_id == self.PduId.REGISTER_NOTIFICATION: - self._on_register_notification_command( - transaction_label, RegisterNotificationCommand.from_bytes(pdu) - ) + command = Command.from_bytes(pdu_id, pdu) + if isinstance(command, GetCapabilitiesCommand): + self._on_get_capabilities_command(transaction_label, command) + elif isinstance(command, SetAbsoluteVolumeCommand): + self._on_set_absolute_volume_command(transaction_label, command) + elif isinstance(command, RegisterNotificationCommand): + self._on_register_notification_command(transaction_label, command) else: # Not supported. # TODO: check that this is the right way to respond in this case. logger.debug("unsupported PDU ID") self.send_rejected_avrcp_response( - transaction_label, pdu_id, self.StatusCode.INVALID_PARAMETER + transaction_label, pdu_id, StatusCode.INVALID_PARAMETER ) else: logger.debug("unsupported command type") self.send_rejected_avrcp_response( - transaction_label, pdu_id, self.StatusCode.INVALID_COMMAND + transaction_label, pdu_id, StatusCode.INVALID_COMMAND ) self.receive_command_state = None @@ -1643,25 +1456,25 @@ def _on_response_pdu(self, pdu_id: PduId, pdu: bytes) -> None: # more appropriate. response: Optional[Response] = None if response_code == avc.ResponseFrame.ResponseCode.REJECTED: - response = RejectedResponse.from_bytes(pdu_id, pdu) + response = RejectedResponse.from_bytes(pdu_id=pdu_id, pdu=pdu) elif response_code == avc.ResponseFrame.ResponseCode.NOT_IMPLEMENTED: - response = NotImplementedResponse.from_bytes(pdu_id, pdu) + response = NotImplementedResponse.from_bytes(pdu_id=pdu_id, pdu=pdu) elif response_code in ( avc.ResponseFrame.ResponseCode.IMPLEMENTED_OR_STABLE, avc.ResponseFrame.ResponseCode.INTERIM, avc.ResponseFrame.ResponseCode.CHANGED, avc.ResponseFrame.ResponseCode.ACCEPTED, ): - if pdu_id == self.PduId.GET_CAPABILITIES: - response = GetCapabilitiesResponse.from_bytes(pdu) - elif pdu_id == self.PduId.GET_PLAY_STATUS: - response = GetPlayStatusResponse.from_bytes(pdu) - elif pdu_id == self.PduId.GET_ELEMENT_ATTRIBUTES: - response = GetElementAttributesResponse.from_bytes(pdu) - elif pdu_id == self.PduId.SET_ABSOLUTE_VOLUME: - response = SetAbsoluteVolumeResponse.from_bytes(pdu) - elif pdu_id == self.PduId.REGISTER_NOTIFICATION: - response = RegisterNotificationResponse.from_bytes(pdu) + if pdu_id == PduId.GET_CAPABILITIES: + response = GetCapabilitiesResponse.from_bytes(pdu=pdu) + elif pdu_id == PduId.GET_PLAY_STATUS: + response = GetPlayStatusResponse.from_bytes(pdu=pdu) + elif pdu_id == PduId.GET_ELEMENT_ATTRIBUTES: + response = GetElementAttributesResponse.from_bytes(pdu=pdu) + elif pdu_id == PduId.SET_ABSOLUTE_VOLUME: + response = SetAbsoluteVolumeResponse.from_bytes(pdu=pdu) + elif pdu_id == PduId.REGISTER_NOTIFICATION: + response = RegisterNotificationResponse.from_bytes(pdu=pdu) else: logger.debug("unexpected PDU ID") pending_command.response.set_exception( @@ -1757,10 +1570,8 @@ async def send_avrcp_command( # TODO: fragmentation # Send the command. logger.debug(f">>> AVRCP command PDU: {command}") - pdu = ( - struct.pack(">BBH", command.pdu_id, 0, len(command.parameter)) - + command.parameter - ) + payload = bytes(command) + pdu = struct.pack(">BBH", command.pdu_id, 0, len(payload)) + payload command_frame = avc.VendorDependentCommandFrame( command_type, avc.Frame.SubunitType.PANEL, @@ -1804,10 +1615,8 @@ def send_avrcp_response( ) -> None: # TODO: fragmentation logger.debug(f">>> AVRCP response PDU: {response}") - pdu = ( - struct.pack(">BBH", response.pdu_id, 0, len(response.parameter)) - + response.parameter - ) + payload = bytes(response) + pdu = struct.pack(">BBH", response.pdu_id, 0, len(payload)) + payload response_frame = avc.VendorDependentResponseFrame( response_code, avc.Frame.SubunitType.PANEL, @@ -1830,7 +1639,7 @@ def send_not_implemented_response( self.send_response(transaction_label, response) def send_rejected_avrcp_response( - self, transaction_label: int, pdu_id: Protocol.PduId, status_code: StatusCode + self, transaction_label: int, pdu_id: PduId, status_code: StatusCode ) -> None: self.send_avrcp_response( transaction_label, @@ -1839,7 +1648,7 @@ def send_rejected_avrcp_response( ) def send_not_implemented_avrcp_response( - self, transaction_label: int, pdu_id: Protocol.PduId + self, transaction_label: int, pdu_id: PduId ) -> None: self.send_avrcp_response( transaction_label, @@ -1895,7 +1704,7 @@ async def register_notification() -> None: if command.event_id not in supported_events: logger.debug("event not supported") self.send_not_implemented_avrcp_response( - transaction_label, self.PduId.REGISTER_NOTIFICATION + transaction_label, PduId.REGISTER_NOTIFICATION ) return diff --git a/bumble/hci.py b/bumble/hci.py index 06d6cd31..07f01f03 100644 --- a/bumble/hci.py +++ b/bumble/hci.py @@ -112,7 +112,14 @@ class SpecableEnum(utils.OpenIntEnum): @classmethod def type_spec(cls, size: int): - return {'size': size, 'mapper': lambda x: cls(x).name} + return { + 'serializer': lambda x: x.to_bytes(size, 'little'), + 'parser': lambda data, offset: ( + offset + size, + cls(int.from_bytes(data[offset : offset + size], 'little')), + ), + 'mapper': lambda x: cls(x).name, + } @classmethod def type_metadata(cls, size: int, list_begin: bool = False, list_end: bool = False): @@ -123,7 +130,14 @@ class SpecableFlag(enum.IntFlag): @classmethod def type_spec(cls, size: int): - return {'size': size, 'mapper': lambda x: cls(x).name} + return { + 'serializer': lambda x: x.to_bytes(size, 'little'), + 'parser': lambda data, offset: ( + offset + size, + cls(int.from_bytes(data[offset : offset + size], 'little')), + ), + 'mapper': lambda x: cls(x).name, + } @classmethod def type_metadata(cls, size: int, list_begin: bool = False, list_end: bool = False): diff --git a/tests/avrcp_test.py b/tests/avrcp_test.py index dfc0bb31..5769eb17 100644 --- a/tests/avrcp_test.py +++ b/tests/avrcp_test.py @@ -15,67 +15,210 @@ # ----------------------------------------------------------------------------- # Imports # ----------------------------------------------------------------------------- -import asyncio +from __future__ import annotations + import struct +from collections.abc import Sequence import pytest -from bumble import avc, avctp, avrcp, controller, core, device, host, link -from bumble.transport import common +from bumble import avc, avctp, avrcp + +from . import test_utils # ----------------------------------------------------------------------------- -class TwoDevices: - def __init__(self): - self.connections = [None, None] +class TwoDevices(test_utils.TwoDevices): + protocols: Sequence[avrcp.Protocol] = () - addresses = ['F0:F1:F2:F3:F4:F5', 'F5:F4:F3:F2:F1:F0'] - self.link = link.LocalLink() - self.controllers = [ - controller.Controller('C1', link=self.link, public_address=addresses[0]), - controller.Controller('C2', link=self.link, public_address=addresses[1]), - ] - self.devices = [ - device.Device( - address=addresses[0], - host=host.Host( - self.controllers[0], common.AsyncPipeSink(self.controllers[0]) - ), - ), - device.Device( - address=addresses[1], - host=host.Host( - self.controllers[1], common.AsyncPipeSink(self.controllers[1]) - ), - ), + async def setup_avdtp_connections(self): + self.protocols = [avrcp.Protocol(), avrcp.Protocol()] + self.protocols[0].listen(self.devices[1]) + await self.protocols[1].connect(self.connections[0]) + + @classmethod + async def create_with_avdtp(cls) -> TwoDevices: + devices = await cls.create_with_connection() + await devices.setup_avdtp_connections() + return devices + + +# ----------------------------------------------------------------------------- +def test_GetPlayStatusCommand(): + command = avrcp.GetPlayStatusCommand() + assert avrcp.Command.from_bytes(command.pdu_id, bytes(command)) == command + + +# ----------------------------------------------------------------------------- +def test_GetCapabilitiesCommand(): + command = avrcp.GetCapabilitiesCommand( + capability_id=avrcp.GetCapabilitiesCommand.CapabilityId.COMPANY_ID + ) + assert avrcp.Command.from_bytes(command.pdu_id, bytes(command)) == command + + +# ----------------------------------------------------------------------------- +def test_SetAbsoluteVolumeCommand(): + command = avrcp.SetAbsoluteVolumeCommand(volume=5) + assert avrcp.Command.from_bytes(command.pdu_id, bytes(command)) == command + + +# ----------------------------------------------------------------------------- +def test_GetElementAttributesCommand(): + command = avrcp.GetElementAttributesCommand( + identifier=999, + attribute_ids=[ + avrcp.MediaAttributeId.ALBUM_NAME, + avrcp.MediaAttributeId.ARTIST_NAME, + ], + ) + assert avrcp.Command.from_bytes(command.pdu_id, bytes(command)) == command + + +# ----------------------------------------------------------------------------- +def test_RegisterNotificationCommand(): + command = avrcp.RegisterNotificationCommand( + event_id=avrcp.EventId.ADDRESSED_PLAYER_CHANGED, playback_interval=123 + ) + assert avrcp.Command.from_bytes(command.pdu_id, bytes(command)) == command + + +# ----------------------------------------------------------------------------- +def test_UidsChangedEvent(): + event = avrcp.UidsChangedEvent(uid_counter=7) + assert avrcp.Event.from_bytes(bytes(event)) == event + + +# ----------------------------------------------------------------------------- +def test_TrackChangedEvent(): + event = avrcp.TrackChangedEvent(identifier=b'12356') + assert avrcp.Event.from_bytes(bytes(event)) == event + + +# ----------------------------------------------------------------------------- +def test_VolumeChangedEvent(): + event = avrcp.VolumeChangedEvent(volume=9) + assert avrcp.Event.from_bytes(bytes(event)) == event + + +# ----------------------------------------------------------------------------- +def test_PlaybackStatusChangedEvent(): + event = avrcp.PlaybackStatusChangedEvent(play_status=avrcp.PlayStatus.PLAYING) + assert avrcp.Event.from_bytes(bytes(event)) == event + + +# ----------------------------------------------------------------------------- +def test_AddressedPlayerChangedEvent(): + event = avrcp.AddressedPlayerChangedEvent( + player=avrcp.AddressedPlayerChangedEvent.Player(player_id=9, uid_counter=10) + ) + assert avrcp.Event.from_bytes(bytes(event)) == event + + +# ----------------------------------------------------------------------------- +def test_AvailablePlayersChangedEvent(): + event = avrcp.AvailablePlayersChangedEvent() + assert avrcp.Event.from_bytes(bytes(event)) == event + + +# ----------------------------------------------------------------------------- +def test_PlaybackPositionChangedEvent(): + event = avrcp.PlaybackPositionChangedEvent(playback_position=1314) + assert avrcp.Event.from_bytes(bytes(event)) == event + + +# ----------------------------------------------------------------------------- +def test_NowPlayingContentChangedEvent(): + event = avrcp.NowPlayingContentChangedEvent() + assert avrcp.Event.from_bytes(bytes(event)) == event + + +# ----------------------------------------------------------------------------- +def test_PlayerApplicationSettingChangedEvent(): + event = avrcp.PlayerApplicationSettingChangedEvent( + player_application_settings=[ + avrcp.PlayerApplicationSettingChangedEvent.Setting( + avrcp.ApplicationSetting.AttributeId.REPEAT_MODE, + avrcp.ApplicationSetting.RepeatModeStatus.ALL_TRACK_REPEAT, + ) ] - self.devices[0].classic_enabled = True - self.devices[1].classic_enabled = True - self.connections = [None, None] - self.protocols = [None, None] + ) + assert avrcp.Event.from_bytes(bytes(event)) == event - def on_connection(self, which, connection): - self.connections[which] = connection - async def setup_connections(self): - await self.devices[0].power_on() - await self.devices[1].power_on() +# ----------------------------------------------------------------------------- +def test_RejectedResponse(): + pdu_id = avrcp.PduId.GET_ELEMENT_ATTRIBUTES + response = avrcp.RejectedResponse( + pdu_id=pdu_id, + status_code=avrcp.StatusCode.DOES_NOT_EXIST, + ) + assert ( + avrcp.RejectedResponse.from_bytes(pdu=bytes(response), pdu_id=pdu_id) + == response + ) - self.connections = await asyncio.gather( - self.devices[0].connect( - self.devices[1].public_address, core.PhysicalTransport.BR_EDR - ), - self.devices[1].accept(self.devices[0].public_address), - ) - self.protocols = [avrcp.Protocol(), avrcp.Protocol()] - self.protocols[0].listen(self.devices[1]) - await self.protocols[1].connect(self.connections[0]) +# ----------------------------------------------------------------------------- +def test_GetPlayStatusResponse(): + response = avrcp.GetPlayStatusResponse( + song_length=1010, song_position=13, play_status=avrcp.PlayStatus.PAUSED + ) + assert avrcp.GetPlayStatusResponse.from_bytes(bytes(response)) == response + + +# ----------------------------------------------------------------------------- +def test_NotImplementedResponse(): + pdu_id = avrcp.PduId.GET_ELEMENT_ATTRIBUTES + response = avrcp.NotImplementedResponse(pdu_id=pdu_id, parameters=b'koasd') + assert ( + avrcp.NotImplementedResponse.from_bytes(bytes(response), pdu_id=pdu_id) + == response + ) + + +# ----------------------------------------------------------------------------- +def test_GetCapabilitiesResponse(): + response = avrcp.GetCapabilitiesResponse( + capability_id=avrcp.GetCapabilitiesCommand.CapabilityId.EVENTS_SUPPORTED, + capabilities=[ + avrcp.EventId.ADDRESSED_PLAYER_CHANGED, + avrcp.EventId.BATT_STATUS_CHANGED, + ], + ) + assert avrcp.GetCapabilitiesResponse.from_bytes(bytes(response)) == response + + +# ----------------------------------------------------------------------------- +def test_RegisterNotificationResponse(): + response = avrcp.RegisterNotificationResponse( + event=avrcp.PlaybackPositionChangedEvent(playback_position=38) + ) + assert avrcp.RegisterNotificationResponse.from_bytes(bytes(response)) == response + + +# ----------------------------------------------------------------------------- +def test_SetAbsoluteVolumeResponse(): + response = avrcp.SetAbsoluteVolumeResponse(volume=99) + assert avrcp.SetAbsoluteVolumeResponse.from_bytes(bytes(response)) == response + + +# ----------------------------------------------------------------------------- +def test_GetElementAttributesResponse(): + response = avrcp.GetElementAttributesResponse( + attributes=[ + avrcp.MediaAttribute( + attribute_id=avrcp.MediaAttributeId.ALBUM_NAME, + attribute_value="White Album", + ) + ] + ) + assert avrcp.GetElementAttributesResponse.from_bytes(bytes(response)) == response # ----------------------------------------------------------------------------- def test_frame_parser(): - with pytest.raises(ValueError) as error: + with pytest.raises(ValueError): avc.Frame.from_bytes(bytes.fromhex("11480000")) x = bytes.fromhex("014D0208") @@ -217,8 +360,7 @@ def test_passthrough_commands(): # ----------------------------------------------------------------------------- @pytest.mark.asyncio async def test_get_supported_events(): - two_devices = TwoDevices() - await two_devices.setup_connections() + two_devices = await TwoDevices.create_with_avdtp() supported_events = await two_devices.protocols[0].get_supported_events() assert supported_events == []