Skip to content

Commit

Permalink
Add RFCOMM and SDP helpers
Browse files Browse the repository at this point in the history
  • Loading branch information
zxzxwu committed Feb 2, 2024
1 parent c6cfd10 commit 95d1a6d
Show file tree
Hide file tree
Showing 5 changed files with 183 additions and 75 deletions.
48 changes: 3 additions & 45 deletions apps/bench.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,8 @@
SDP_PUBLIC_BROWSE_ROOT,
SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID,
SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID,
DataElement,
ServiceAttribute,
Client as SdpClient,
)
from bumble.transport import open_transport_or_link
import bumble.rfcomm
Expand Down Expand Up @@ -198,48 +196,6 @@ def make_sdp_records(channel):
}


async def find_rfcomm_channel_with_uuid(connection: Connection, uuid: str) -> int:
# Connect to the SDP Server
sdp_client = SdpClient(connection)
await sdp_client.connect()

# Search for services with an L2CAP service attribute
search_result = await sdp_client.search_attributes(
[BT_L2CAP_PROTOCOL_ID],
[
SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID,
SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
],
)
for attribute_list in search_result:
service_uuid = None
service_class_id_list = ServiceAttribute.find_attribute_in_list(
attribute_list, SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID
)
if service_class_id_list:
if service_class_id_list.value:
for service_class_id in service_class_id_list.value:
service_uuid = service_class_id.value
if str(service_uuid) != uuid:
# This service doesn't have a UUID or isn't the right one.
continue

# Look for the RFCOMM Channel number
protocol_descriptor_list = ServiceAttribute.find_attribute_in_list(
attribute_list, SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID
)
if protocol_descriptor_list:
for protocol_descriptor in protocol_descriptor_list.value:
if len(protocol_descriptor.value) >= 2:
if protocol_descriptor.value[0].value == BT_RFCOMM_PROTOCOL_ID:
await sdp_client.disconnect()
return protocol_descriptor.value[1].value

await sdp_client.disconnect()
return 0


def log_stats(title, stats):
stats_min = min(stats)
stats_max = max(stats)
Expand Down Expand Up @@ -957,7 +913,9 @@ async def on_connection(self, connection):
logging.info(
color(f'@@@ Discovering channel number from UUID {self.uuid}', 'cyan')
)
channel = await find_rfcomm_channel_with_uuid(connection, self.uuid)
channel = await bumble.rfcomm.find_rfcomm_channel_with_uuid(
connection, self.uuid
)
logging.info(color(f'@@@ Channel number = {channel}', 'cyan'))
if channel == 0:
logging.info(color('!!! No RFComm service with this UUID found', 'red'))
Expand Down
136 changes: 107 additions & 29 deletions bumble/rfcomm.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,13 @@
import dataclasses
import enum
from typing import Callable, Dict, List, Optional, Tuple, Union, TYPE_CHECKING
from typing_extensions import Self

from pyee import EventEmitter

from . import core, l2cap
from bumble import core
from bumble import l2cap
from bumble import sdp
from .colors import color
from .core import (
UUID,
Expand All @@ -35,15 +38,6 @@
InvalidStateError,
ProtocolError,
)
from .sdp import (
SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID,
SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID,
SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
SDP_PUBLIC_BROWSE_ROOT,
DataElement,
ServiceAttribute,
)

if TYPE_CHECKING:
from bumble.device import Device, Connection
Expand Down Expand Up @@ -122,29 +116,33 @@ class MccType(enum.IntEnum):
# -----------------------------------------------------------------------------
def make_service_sdp_records(
service_record_handle: int, channel: int, uuid: Optional[UUID] = None
) -> List[ServiceAttribute]:
) -> List[sdp.ServiceAttribute]:
"""
Create SDP records for an RFComm service given a channel number and an
optional UUID. A Service Class Attribute is included only if the UUID is not None.
"""
records = [
ServiceAttribute(
SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID,
DataElement.unsigned_integer_32(service_record_handle),
sdp.ServiceAttribute(
sdp.SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID,
sdp.DataElement.unsigned_integer_32(service_record_handle),
),
ServiceAttribute(
SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID,
DataElement.sequence([DataElement.uuid(SDP_PUBLIC_BROWSE_ROOT)]),
sdp.ServiceAttribute(
sdp.SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID,
sdp.DataElement.sequence(
[sdp.DataElement.uuid(sdp.SDP_PUBLIC_BROWSE_ROOT)]
),
),
ServiceAttribute(
SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
DataElement.sequence(
sdp.ServiceAttribute(
sdp.SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
sdp.DataElement.sequence(
[
DataElement.sequence([DataElement.uuid(BT_L2CAP_PROTOCOL_ID)]),
DataElement.sequence(
sdp.DataElement.sequence(
[sdp.DataElement.uuid(BT_L2CAP_PROTOCOL_ID)]
),
sdp.DataElement.sequence(
[
DataElement.uuid(BT_RFCOMM_PROTOCOL_ID),
DataElement.unsigned_integer_8(channel),
sdp.DataElement.uuid(BT_RFCOMM_PROTOCOL_ID),
sdp.DataElement.unsigned_integer_8(channel),
]
),
]
Expand All @@ -154,15 +152,81 @@ def make_service_sdp_records(

if uuid:
records.append(
ServiceAttribute(
SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
DataElement.sequence([DataElement.uuid(uuid)]),
sdp.ServiceAttribute(
sdp.SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
sdp.DataElement.sequence([sdp.DataElement.uuid(uuid)]),
)
)

return records


# -----------------------------------------------------------------------------
async def find_rfcomm_channels(connection: Connection) -> Dict[int, List[UUID]]:
"""Searches all RFCOMM channels and their associated UUID from SDP service records.
Args:
connection: ACL connection to make SDP search.
Returns:
Dictionary mapping from channel number to service class UUID list.
"""
results = {}
async with sdp.Client(connection) as sdp_client:
search_result = await sdp_client.search_attributes(
uuids=[core.BT_RFCOMM_PROTOCOL_ID],
attribute_ids=[
sdp.SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
sdp.SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
],
)
for attribute_lists in search_result:
service_classes: List[UUID] = []
channel: Optional[int] = None
for attribute in attribute_lists:
# The layout is [[L2CAP_PROTOCOL], [RFCOMM_PROTOCOL, RFCOMM_CHANNEL]].
if attribute.id == sdp.SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID:
protocol_descriptor_list = attribute.value.value
channel = protocol_descriptor_list[1].value[1].value
elif attribute.id == sdp.SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID:
service_class_id_list = attribute.value.value
service_classes = [
service_class.value for service_class in service_class_id_list
]
if not service_classes or not channel:
logger.warning(f"Bad result {attribute_lists}.")
else:
results[channel] = service_classes
return results


# -----------------------------------------------------------------------------
async def find_rfcomm_channel_with_uuid(
connection: Connection, uuid: str | UUID
) -> Optional[int]:
"""Searches an RFCOMM channel associated with given UUID from service records.
Args:
connection: ACL connection to make SDP search.
uuid: UUID of service record to search for.
Returns:
RFCOMM channel number if found, otherwise None.
"""
async with sdp.Client(connection) as sdp_client:
search_result = await sdp_client.search_attributes(
uuids=[uuid if isinstance(uuid, UUID) else UUID(uuid)],
attribute_ids=[sdp.SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID],
)
for attribute_lists in search_result:
for attribute in attribute_lists:
# The layout is [[L2CAP_PROTOCOL], [RFCOMM_PROTOCOL, RFCOMM_CHANNEL]].
if attribute.id == sdp.SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID:
protocol_descriptor_list = attribute.value.value
return protocol_descriptor_list[1].value[1].value
return None


# -----------------------------------------------------------------------------
def compute_fcs(buffer: bytes) -> int:
result = 0xFF
Expand Down Expand Up @@ -876,7 +940,15 @@ async def shutdown(self) -> None:
self.multiplexer = None

# Close the L2CAP channel
# TODO
if self.l2cap_channel:
await self.l2cap_channel.disconnect()
self.l2cap_channel = None

async def __aenter__(self) -> Multiplexer:
return await self.start()

async def __aexit__(self, *args) -> None:
await self.shutdown()


# -----------------------------------------------------------------------------
Expand All @@ -890,7 +962,7 @@ def __init__(self, device: Device) -> None:
self.acceptors = {}

# Register ourselves with the L2CAP channel manager
device.create_l2cap_server(
self.l2cap_server = device.create_l2cap_server(
spec=l2cap.ClassicChannelSpec(psm=RFCOMM_PSM), handler=self.on_connection
)

Expand Down Expand Up @@ -941,3 +1013,9 @@ def on_dlc(self, dlc: DLC) -> None:
acceptor = self.acceptors.get(dlc.dlci >> 1)
if acceptor:
acceptor(dlc)

def __enter__(self) -> Self:
return self

def __exit__(self, *args) -> None:
self.l2cap_server.close()
8 changes: 8 additions & 0 deletions bumble/sdp.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import logging
import struct
from typing import Dict, List, Type, Optional, Tuple, Union, NewType, TYPE_CHECKING
from typing_extensions import Self

from . import core, l2cap
from .colors import color
Expand Down Expand Up @@ -920,6 +921,13 @@ async def get_attributes(

return ServiceAttribute.list_from_data_elements(attribute_list_sequence.value)

async def __aenter__(self) -> Self:
await self.connect()
return self

async def __aexit__(self, *args) -> None:
await self.disconnect()


# -----------------------------------------------------------------------------
class Server:
Expand Down
51 changes: 50 additions & 1 deletion tests/rfcomm_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,17 @@
import pytest

from . import test_utils
from bumble.rfcomm import RFCOMM_Frame, Server, Client, DLC
from bumble import core
from bumble.rfcomm import (
RFCOMM_Frame,
Server,
Client,
DLC,
make_service_sdp_records,
find_rfcomm_channels,
find_rfcomm_channel_with_uuid,
RFCOMM_PSM,
)


# -----------------------------------------------------------------------------
Expand Down Expand Up @@ -70,6 +80,45 @@ async def test_basic_connection():
assert await queues[0].get() == b'Lorem ipsum dolor sit amet'


# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_service_record():
HANDLE = 2
CHANNEL = 1
SERVICE_UUID = core.UUID('00000000-0000-0000-0000-000000000001')

devices = test_utils.TwoDevices()
await devices.setup_connection()

devices[0].sdp_service_records[HANDLE] = make_service_sdp_records(
HANDLE, CHANNEL, SERVICE_UUID
)

assert SERVICE_UUID in (await find_rfcomm_channels(devices.connections[1]))[CHANNEL]
assert (
await find_rfcomm_channel_with_uuid(devices.connections[1], SERVICE_UUID)
== CHANNEL
)


# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_context():
devices = test_utils.TwoDevices()
await devices.setup_connection()

server = Server(devices[0])
with server:
assert server.l2cap_server is not None

client = Client(devices.connections[1])
async with client:
assert client.l2cap_channel is not None

assert client.l2cap_channel is None
assert RFCOMM_PSM not in devices[0].l2cap_channel_manager.servers


# -----------------------------------------------------------------------------
if __name__ == '__main__':
test_frames()
15 changes: 15 additions & 0 deletions tests/sdp_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
# pylint: disable=invalid-name
# -----------------------------------------------------------------------------


# -----------------------------------------------------------------------------
def basic_check(x: DataElement) -> None:
serialized = bytes(x)
Expand Down Expand Up @@ -269,6 +270,20 @@ async def test_service_search_attribute():
assert expect.value == actual.value


# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_client_async_context():
devices = TwoDevices()
await devices.setup_connection()

client = Client(devices.connections[1])

async with client:
assert client.channel is not None

assert client.channel is None


# -----------------------------------------------------------------------------
async def run():
test_data_elements()
Expand Down

0 comments on commit 95d1a6d

Please sign in to comment.