Skip to content

Commit

Permalink
Add RFCOMM and SDP helpers
Browse files Browse the repository at this point in the history
  • Loading branch information
zxzxwu committed Feb 2, 2024
1 parent c6cfd10 commit 15542dd
Show file tree
Hide file tree
Showing 4 changed files with 143 additions and 30 deletions.
99 changes: 70 additions & 29 deletions bumble/rfcomm.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,13 @@
import dataclasses
import enum
from typing import Callable, Dict, List, Optional, Tuple, Union, TYPE_CHECKING
from typing_extensions import Self

from pyee import EventEmitter

from . import core, l2cap
from bumble import core
from bumble import l2cap
from bumble import sdp
from .colors import color
from .core import (
UUID,
Expand All @@ -35,15 +38,6 @@
InvalidStateError,
ProtocolError,
)
from .sdp import (
SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID,
SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID,
SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
SDP_PUBLIC_BROWSE_ROOT,
DataElement,
ServiceAttribute,
)

if TYPE_CHECKING:
from bumble.device import Device, Connection
Expand Down Expand Up @@ -122,29 +116,33 @@ class MccType(enum.IntEnum):
# -----------------------------------------------------------------------------
def make_service_sdp_records(
service_record_handle: int, channel: int, uuid: Optional[UUID] = None
) -> List[ServiceAttribute]:
) -> List[sdp.ServiceAttribute]:
"""
Create SDP records for an RFComm service given a channel number and an
optional UUID. A Service Class Attribute is included only if the UUID is not None.
"""
records = [
ServiceAttribute(
SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID,
DataElement.unsigned_integer_32(service_record_handle),
sdp.ServiceAttribute(
sdp.SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID,
sdp.DataElement.unsigned_integer_32(service_record_handle),
),
ServiceAttribute(
SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID,
DataElement.sequence([DataElement.uuid(SDP_PUBLIC_BROWSE_ROOT)]),
sdp.ServiceAttribute(
sdp.SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID,
sdp.DataElement.sequence(
[sdp.DataElement.uuid(sdp.SDP_PUBLIC_BROWSE_ROOT)]
),
),
ServiceAttribute(
SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
DataElement.sequence(
sdp.ServiceAttribute(
sdp.SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
sdp.DataElement.sequence(
[
DataElement.sequence([DataElement.uuid(BT_L2CAP_PROTOCOL_ID)]),
DataElement.sequence(
sdp.DataElement.sequence(
[sdp.DataElement.uuid(BT_L2CAP_PROTOCOL_ID)]
),
sdp.DataElement.sequence(
[
DataElement.uuid(BT_RFCOMM_PROTOCOL_ID),
DataElement.unsigned_integer_8(channel),
sdp.DataElement.uuid(BT_RFCOMM_PROTOCOL_ID),
sdp.DataElement.unsigned_integer_8(channel),
]
),
]
Expand All @@ -154,15 +152,43 @@ def make_service_sdp_records(

if uuid:
records.append(
ServiceAttribute(
SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
DataElement.sequence([DataElement.uuid(uuid)]),
sdp.ServiceAttribute(
sdp.SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
sdp.DataElement.sequence([sdp.DataElement.uuid(uuid)]),
)
)

return records


# -----------------------------------------------------------------------------
async def search_channel_from_sdp(sdp_client: sdp.Client, uuid: UUID) -> Optional[int]:
"""Searches an RFCOMM channel associated with given UUID from service records.
Args:
sdp_client: A connected SDP client to search from.
uuid: UUID to associate with.
Returns:
Channel number if found, otherwise None.
"""
attributes = await sdp_client.search_attributes(
uuids=[uuid],
attribute_ids=[sdp.SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID],
)
for service in attributes:
for attribute in service:
# The layout is [[L2CAP_PROTOCOL], [RFCOMM_PROTOCOL, RFCOMM_CHANNEL]].
protocol_descriptor_list = attribute.value.value
if (
len(protocol_descriptor_list) >= 2
and protocol_descriptor_list[1].value[0].value
== core.BT_RFCOMM_PROTOCOL_ID
):
return protocol_descriptor_list[1].value[1].value
return None


# -----------------------------------------------------------------------------
def compute_fcs(buffer: bytes) -> int:
result = 0xFF
Expand Down Expand Up @@ -876,7 +902,15 @@ async def shutdown(self) -> None:
self.multiplexer = None

# Close the L2CAP channel
# TODO
if self.l2cap_channel:
await self.l2cap_channel.disconnect()
self.l2cap_channel = None

async def __aenter__(self) -> Multiplexer:
return await self.start()

async def __aexit__(self, *args) -> None:
await self.shutdown()


# -----------------------------------------------------------------------------
Expand All @@ -890,7 +924,7 @@ def __init__(self, device: Device) -> None:
self.acceptors = {}

# Register ourselves with the L2CAP channel manager
device.create_l2cap_server(
self.l2cap_server = device.create_l2cap_server(
spec=l2cap.ClassicChannelSpec(psm=RFCOMM_PSM), handler=self.on_connection
)

Expand Down Expand Up @@ -941,3 +975,10 @@ def on_dlc(self, dlc: DLC) -> None:
acceptor = self.acceptors.get(dlc.dlci >> 1)
if acceptor:
acceptor(dlc)

def __enter__(self) -> Self:
return self

def __exit__(self, *args) -> None:
self.l2cap_server.close()
self.l2cap_server = None
8 changes: 8 additions & 0 deletions bumble/sdp.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import logging
import struct
from typing import Dict, List, Type, Optional, Tuple, Union, NewType, TYPE_CHECKING
from typing_extensions import Self

from . import core, l2cap
from .colors import color
Expand Down Expand Up @@ -920,6 +921,13 @@ async def get_attributes(

return ServiceAttribute.list_from_data_elements(attribute_list_sequence.value)

async def __aenter__(self) -> Self:
await self.connect()
return self

async def __aexit__(self, *args) -> None:
await self.disconnect()


# -----------------------------------------------------------------------------
class Server:
Expand Down
51 changes: 50 additions & 1 deletion tests/rfcomm_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,16 @@
import pytest

from . import test_utils
from bumble.rfcomm import RFCOMM_Frame, Server, Client, DLC
from bumble import core
from bumble import sdp
from bumble.rfcomm import (
RFCOMM_Frame,
Server,
Client,
DLC,
make_service_sdp_records,
search_channel_from_sdp,
)


# -----------------------------------------------------------------------------
Expand Down Expand Up @@ -70,6 +79,46 @@ async def test_basic_connection():
assert await queues[0].get() == b'Lorem ipsum dolor sit amet'


# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_service_record():
HANDLE = 2
CHANNEL = 1
SERVICE_UUID = core.UUID('00000000-0000-0000-0000-000000000001')

devices = test_utils.TwoDevices()
await devices.setup_connection()

devices[0].sdp_service_records[HANDLE] = make_service_sdp_records(
HANDLE, CHANNEL, SERVICE_UUID
)

async with sdp.Client(devices.connections[1]) as sdp_client:
assert await search_channel_from_sdp(sdp_client, SERVICE_UUID) == CHANNEL


# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_context():
HANDLE = 2
CHANNEL = 1
SERVICE_UUID = core.UUID('00000000-0000-0000-0000-000000000001')

devices = test_utils.TwoDevices()
await devices.setup_connection()

server = Server(devices[0])
with server:
assert server.l2cap_server is not None

client = Client(devices.connections[1])
async with client:
assert client.l2cap_channel is not None

assert client.l2cap_channel is None
assert server.l2cap_server is None


# -----------------------------------------------------------------------------
if __name__ == '__main__':
test_frames()
15 changes: 15 additions & 0 deletions tests/sdp_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
# pylint: disable=invalid-name
# -----------------------------------------------------------------------------


# -----------------------------------------------------------------------------
def basic_check(x: DataElement) -> None:
serialized = bytes(x)
Expand Down Expand Up @@ -269,6 +270,20 @@ async def test_service_search_attribute():
assert expect.value == actual.value


# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_client_async_context():
devices = TwoDevices()
await devices.setup_connection()

client = Client(devices.connections[1])

async with client:
assert client.channel is not None

assert client.channel is None


# -----------------------------------------------------------------------------
async def run():
test_data_elements()
Expand Down

0 comments on commit 15542dd

Please sign in to comment.