Skip to content

Commit

Permalink
Add RFCOMM and SDP helpers
Browse files Browse the repository at this point in the history
  • Loading branch information
zxzxwu committed Feb 2, 2024
1 parent c6cfd10 commit 23ad515
Show file tree
Hide file tree
Showing 5 changed files with 178 additions and 31 deletions.
109 changes: 80 additions & 29 deletions bumble/rfcomm.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,13 @@
import dataclasses
import enum
from typing import Callable, Dict, List, Optional, Tuple, Union, TYPE_CHECKING
from typing_extensions import Self

from pyee import EventEmitter

from . import core, l2cap
from bumble import core
from bumble import l2cap
from bumble import sdp
from .colors import color
from .core import (
UUID,
Expand All @@ -35,15 +38,6 @@
InvalidStateError,
ProtocolError,
)
from .sdp import (
SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID,
SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID,
SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
SDP_PUBLIC_BROWSE_ROOT,
DataElement,
ServiceAttribute,
)

if TYPE_CHECKING:
from bumble.device import Device, Connection
Expand Down Expand Up @@ -122,29 +116,33 @@ class MccType(enum.IntEnum):
# -----------------------------------------------------------------------------
def make_service_sdp_records(
service_record_handle: int, channel: int, uuid: Optional[UUID] = None
) -> List[ServiceAttribute]:
) -> List[sdp.ServiceAttribute]:
"""
Create SDP records for an RFComm service given a channel number and an
optional UUID. A Service Class Attribute is included only if the UUID is not None.
"""
records = [
ServiceAttribute(
SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID,
DataElement.unsigned_integer_32(service_record_handle),
sdp.ServiceAttribute(
sdp.SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID,
sdp.DataElement.unsigned_integer_32(service_record_handle),
),
ServiceAttribute(
SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID,
DataElement.sequence([DataElement.uuid(SDP_PUBLIC_BROWSE_ROOT)]),
sdp.ServiceAttribute(
sdp.SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID,
sdp.DataElement.sequence(
[sdp.DataElement.uuid(sdp.SDP_PUBLIC_BROWSE_ROOT)]
),
),
ServiceAttribute(
SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
DataElement.sequence(
sdp.ServiceAttribute(
sdp.SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
sdp.DataElement.sequence(
[
DataElement.sequence([DataElement.uuid(BT_L2CAP_PROTOCOL_ID)]),
DataElement.sequence(
sdp.DataElement.sequence(
[sdp.DataElement.uuid(BT_L2CAP_PROTOCOL_ID)]
),
sdp.DataElement.sequence(
[
DataElement.uuid(BT_RFCOMM_PROTOCOL_ID),
DataElement.unsigned_integer_8(channel),
sdp.DataElement.uuid(BT_RFCOMM_PROTOCOL_ID),
sdp.DataElement.unsigned_integer_8(channel),
]
),
]
Expand All @@ -154,15 +152,54 @@ def make_service_sdp_records(

if uuid:
records.append(
ServiceAttribute(
SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
DataElement.sequence([DataElement.uuid(uuid)]),
sdp.ServiceAttribute(
sdp.SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
sdp.DataElement.sequence([sdp.DataElement.uuid(uuid)]),
)
)

return records


# -----------------------------------------------------------------------------
async def search_channels_from_sdp(connection: Connection) -> Dict[int, List[UUID]]:
"""Searches an RFCOMM channel associated with given UUID from service records.
Args:
connection: ACL connection to make SDP search.
Returns:
Dictionary mapping from channel number to service class UUID list.
"""
results = {}
async with sdp.Client(connection) as sdp_client:
search_result = await sdp_client.search_attributes(
uuids=[core.BT_RFCOMM_PROTOCOL_ID],
attribute_ids=[
sdp.SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
sdp.SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
],
)
for attribute_lists in search_result:
service_classes: List[UUID] = []
channel: Optional[int] = None
for attribute in attribute_lists:
# The layout is [[L2CAP_PROTOCOL], [RFCOMM_PROTOCOL, RFCOMM_CHANNEL]].
if attribute.id == sdp.SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID:
protocol_descriptor_list = attribute.value.value
channel = protocol_descriptor_list[1].value[1].value
elif attribute.id == sdp.SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID:
service_class_id_list = attribute.value.value
service_classes = [
service_class.value for service_class in service_class_id_list
]
if not service_classes or not channel:
logger.warning(f"Bad result {attribute_lists}.")
else:
results[channel] = service_classes
return results


# -----------------------------------------------------------------------------
def compute_fcs(buffer: bytes) -> int:
result = 0xFF
Expand Down Expand Up @@ -876,7 +913,15 @@ async def shutdown(self) -> None:
self.multiplexer = None

# Close the L2CAP channel
# TODO
if self.l2cap_channel:
await self.l2cap_channel.disconnect()
self.l2cap_channel = None

async def __aenter__(self) -> Multiplexer:
return await self.start()

async def __aexit__(self, *args) -> None:
await self.shutdown()


# -----------------------------------------------------------------------------
Expand All @@ -890,7 +935,7 @@ def __init__(self, device: Device) -> None:
self.acceptors = {}

# Register ourselves with the L2CAP channel manager
device.create_l2cap_server(
self.l2cap_server = device.create_l2cap_server(
spec=l2cap.ClassicChannelSpec(psm=RFCOMM_PSM), handler=self.on_connection
)

Expand Down Expand Up @@ -941,3 +986,9 @@ def on_dlc(self, dlc: DLC) -> None:
acceptor = self.acceptors.get(dlc.dlci >> 1)
if acceptor:
acceptor(dlc)

def __enter__(self) -> Self:
return self

def __exit__(self, *args) -> None:
self.l2cap_server.close()
8 changes: 8 additions & 0 deletions bumble/sdp.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import logging
import struct
from typing import Dict, List, Type, Optional, Tuple, Union, NewType, TYPE_CHECKING
from typing_extensions import Self

from . import core, l2cap
from .colors import color
Expand Down Expand Up @@ -920,6 +921,13 @@ async def get_attributes(

return ServiceAttribute.list_from_data_elements(attribute_list_sequence.value)

async def __aenter__(self) -> Self:
await self.connect()
return self

async def __aexit__(self, *args) -> None:
await self.disconnect()


# -----------------------------------------------------------------------------
class Server:
Expand Down
28 changes: 27 additions & 1 deletion examples/run_hfp_handsfree.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import functools
from typing import Optional

from bumble import core
from bumble import rfcomm
from bumble import hci
from bumble.device import Device, Connection
Expand Down Expand Up @@ -101,7 +102,9 @@ def on_sco_request(connection: Connection, link_type: int, protocol: HfProtocol)
# -----------------------------------------------------------------------------
async def main():
if len(sys.argv) < 3:
print('Usage: run_classic_hfp.py <device-config> <transport-spec>')
print(
'Usage: run_classic_hfp.py <device-config> <transport-spec> [remote-address]'
)
print('example: run_classic_hfp.py classic2.json usb:04b4:f901')
return

Expand Down Expand Up @@ -157,6 +160,29 @@ async def main():
ui_server = UiServer()
await ui_server.start()

if len(sys.argv) >= 4:
peer_address = sys.argv[3]
connection = await device.connect(
peer_address, core.BT_BR_EDR_TRANSPORT, timeout=5.0
)
await connection.authenticate()
await connection.encrypt()
channels = await rfcomm.search_channels_from_sdp(connection)
channel = next(
(
channel
for channel, service_classes in channels.items()
if core.BT_HANDSFREE_AUDIO_GATEWAY_SERVICE in service_classes
),
None,
)
if channel is not None:
async with rfcomm.Client(connection) as rfcomm_client:
dlc = await rfcomm_client.open_dlc(channel)
connection.abort_on(
'disconnection', hfp.HfProtocol(dlc, configuration).run()
)

await hci_source.wait_for_termination()


Expand Down
49 changes: 48 additions & 1 deletion tests/rfcomm_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,16 @@
import pytest

from . import test_utils
from bumble.rfcomm import RFCOMM_Frame, Server, Client, DLC
from bumble import core
from bumble.rfcomm import (
RFCOMM_Frame,
Server,
Client,
DLC,
make_service_sdp_records,
search_channels_from_sdp,
RFCOMM_PSM,
)


# -----------------------------------------------------------------------------
Expand Down Expand Up @@ -70,6 +79,44 @@ async def test_basic_connection():
assert await queues[0].get() == b'Lorem ipsum dolor sit amet'


# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_service_record():
HANDLE = 2
CHANNEL = 1
SERVICE_UUID = core.UUID('00000000-0000-0000-0000-000000000001')

devices = test_utils.TwoDevices()
await devices.setup_connection()

devices[0].sdp_service_records[HANDLE] = make_service_sdp_records(
HANDLE, CHANNEL, SERVICE_UUID
)

assert (
SERVICE_UUID
in (await search_channels_from_sdp(devices.connections[1]))[CHANNEL]
)


# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_context():
devices = test_utils.TwoDevices()
await devices.setup_connection()

server = Server(devices[0])
with server:
assert server.l2cap_server is not None

client = Client(devices.connections[1])
async with client:
assert client.l2cap_channel is not None

assert client.l2cap_channel is None
assert RFCOMM_PSM not in devices[0].l2cap_channel_manager.servers


# -----------------------------------------------------------------------------
if __name__ == '__main__':
test_frames()
15 changes: 15 additions & 0 deletions tests/sdp_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
# pylint: disable=invalid-name
# -----------------------------------------------------------------------------


# -----------------------------------------------------------------------------
def basic_check(x: DataElement) -> None:
serialized = bytes(x)
Expand Down Expand Up @@ -269,6 +270,20 @@ async def test_service_search_attribute():
assert expect.value == actual.value


# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_client_async_context():
devices = TwoDevices()
await devices.setup_connection()

client = Client(devices.connections[1])

async with client:
assert client.channel is not None

assert client.channel is None


# -----------------------------------------------------------------------------
async def run():
test_data_elements()
Expand Down

0 comments on commit 23ad515

Please sign in to comment.