Skip to content

Commit

Permalink
Manage lifecycle of CIS and SCO links in host
Browse files Browse the repository at this point in the history
  • Loading branch information
zxzxwu committed Dec 25, 2023
1 parent a286700 commit 10c63a9
Show file tree
Hide file tree
Showing 3 changed files with 272 additions and 27 deletions.
46 changes: 21 additions & 25 deletions bumble/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -3071,34 +3071,30 @@ async def create_cis(self, cis_acl_pairs: List[Tuple[int, int]]) -> List[CisLink
cig_id=cig_id,
)

result = await self.send_command(
HCI_LE_Create_CIS_Command(
cis_connection_handle=[p[0] for p in cis_acl_pairs],
acl_connection_handle=[p[1] for p in cis_acl_pairs],
),
)
if result.status != HCI_COMMAND_STATUS_PENDING:
logger.warning(
'HCI_LE_Create_CIS_Command failed: '
f'{HCI_Constant.error_name(result.status)}'
)
raise HCI_StatusError(result)

pending_cis_establishments: Dict[int, asyncio.Future[CisLink]] = {}
for cis_handle, _ in cis_acl_pairs:
pending_cis_establishments[
cis_handle
] = asyncio.get_running_loop().create_future()

with closing(EventWatcher()) as watcher:
pending_cis_establishments = {
cis_handle: asyncio.get_running_loop().create_future()
for cis_handle, _ in cis_acl_pairs
}

@watcher.on(self, 'cis_establishment')
def on_cis_establishment(cis_link: CisLink) -> None:
if pending_future := pending_cis_establishments.get(
cis_link.handle, None
):
if pending_future := pending_cis_establishments.get(cis_link.handle):
pending_future.set_result(cis_link)

result = await self.send_command(
HCI_LE_Create_CIS_Command(
cis_connection_handle=[p[0] for p in cis_acl_pairs],
acl_connection_handle=[p[1] for p in cis_acl_pairs],
),
)
if result.status != HCI_COMMAND_STATUS_PENDING:
logger.warning(
'HCI_LE_Create_CIS_Command failed: '
f'{HCI_Constant.error_name(result.status)}'
)
raise HCI_StatusError(result)

return await asyncio.gather(*pending_cis_establishments.values())

# [LE only]
Expand Down Expand Up @@ -3707,7 +3703,7 @@ def on_sco_connection_failure(
@host_event_handler
@experimental('Only for testing')
def on_sco_packet(self, sco_handle: int, packet: HCI_SynchronousDataPacket) -> None:
if sco_link := self.sco_links.get(sco_handle, None):
if sco_link := self.sco_links.get(sco_handle):
sco_link.emit('pdu', packet)

# [LE only]
Expand Down Expand Up @@ -3787,15 +3783,15 @@ def on_cis_establishment(self, cis_handle: int) -> None:
@experimental('Only for testing')
def on_cis_establishment_failure(self, cis_handle: int, status: int) -> None:
logger.debug(f'*** CIS Establishment Failure: cis=[0x{cis_handle:04X}] ***')
if cis_link := self.cis_links.pop(cis_handle, None):
if cis_link := self.cis_links.pop(cis_handle):
cis_link.emit('establishment_failure')
self.emit('cis_establishment_failure', cis_handle, status)

# [LE only]
@host_event_handler
@experimental('Only for testing')
def on_iso_packet(self, handle: int, packet: HCI_IsoDataPacket) -> None:
if cis_link := self.cis_links.get(handle, None):
if cis_link := self.cis_links.get(handle):
cis_link.emit('pdu', packet)

@host_event_handler
Expand Down
37 changes: 35 additions & 2 deletions bumble/host.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from __future__ import annotations
import asyncio
import collections
import dataclasses
import logging
import struct

Expand Down Expand Up @@ -120,9 +121,25 @@ def on_acl_pdu(self, pdu: bytes) -> None:
self.host.on_l2cap_pdu(self, l2cap_pdu.cid, l2cap_pdu.payload)


# -----------------------------------------------------------------------------
@dataclasses.dataclass
class ScoLink:
peer_address: Address
handle: int


# -----------------------------------------------------------------------------
@dataclasses.dataclass
class CisLink:
peer_address: Address
handle: int


# -----------------------------------------------------------------------------
class Host(AbortableEventEmitter):
connections: Dict[int, Connection]
cis_links: Dict[int, Connection]
sco_links: Dict[int, Connection]
acl_packet_queue: collections.deque[HCI_AclDataPacket]
hci_sink: Optional[TransportSink] = None
hci_metadata: Dict[str, Any]
Expand All @@ -141,6 +158,8 @@ def __init__(
self.hci_metadata = {}
self.ready = False # True when we can accept incoming packets
self.connections = {} # Connections, by connection handle
self.cis_links = {} # CIS links, by connection handle
self.sco_links = {} # SCO links, by connection handle
self.pending_command = None
self.pending_response = None
self.hc_le_acl_data_packet_length = HOST_DEFAULT_HC_LE_ACL_DATA_PACKET_LENGTH
Expand Down Expand Up @@ -666,7 +685,13 @@ def on_hci_connection_complete_event(self, event):

def on_hci_disconnection_complete_event(self, event):
# Find the connection
if (connection := self.connections.get(event.connection_handle)) is None:
if (
connection := (
self.connections.pop(event.connection_handle, None)
or self.cis_links.pop(event.connection_handle, None)
or self.sco_links.pop(event.connection_handle, None)
)
) is None:
logger.warning('!!! DISCONNECTION COMPLETE: unknown handle')
return

Expand All @@ -676,7 +701,6 @@ def on_hci_disconnection_complete_event(self, event):
f'{connection.peer_address} '
f'reason={event.reason}'
)
del self.connections[event.connection_handle]

# Notify the listeners
self.emit('disconnection', event.connection_handle, event.reason)
Expand Down Expand Up @@ -745,6 +769,10 @@ def on_hci_le_cis_request_event(self, event):
def on_hci_le_cis_established_event(self, event):
# The remaining parameters are unused for now.
if event.status == HCI_SUCCESS:
self.cis_links[event.connection_handle] = CisLink(
handle=event.connection_handle,
peer_address=Address.ANY,
)
self.emit('cis_establishment', event.connection_handle)
else:
self.emit(
Expand Down Expand Up @@ -811,6 +839,11 @@ def on_hci_synchronous_connection_complete_event(self, event):
f'{event.bd_addr}'
)

self.sco_links[event.connection_handle] = ScoLink(
peer_address=event.bd_addr,
handle=event.connection_handle,
)

# Notify the client
self.emit(
'sco_connection',
Expand Down
216 changes: 216 additions & 0 deletions tests/host_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,216 @@
# Copyright 2021-2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import asyncio
import logging
import os
import pytest
from unittest import mock


from bumble.core import BT_BR_EDR_TRANSPORT
from bumble.controller import Controller
from bumble.link import LocalLink
from bumble.host import Host
from bumble.transport import AsyncPipeSink
from bumble.hci import (
Address,
HCI_Connection_Request_Event,
HCI_Connection_Complete_Event,
HCI_Disconnection_Complete_Event,
HCI_Synchronous_Connection_Complete_Event,
HCI_LE_CIS_Request_Event,
HCI_LE_CIS_Established_Event,
HCI_SUCCESS,
HCI_REMOTE_USER_TERMINATED_CONNECTION_ERROR,
)

# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
logger = logging.getLogger(__name__)


# -----------------------------------------------------------------------------
def make_host():
link = LocalLink()
controller = Controller('C1', link=link)
return Host(controller, AsyncPipeSink(controller))


# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_acl_connection_disconnection():
host = make_host()

m = mock.Mock()
host.on('connection_request', m)
host.on_hci_packet(
HCI_Connection_Request_Event(
bd_addr=Address('00:11:22:33:44:55'),
class_of_device=0,
link_type=HCI_Connection_Complete_Event.ACL_LINK_TYPE,
)
)
m.assert_called_with(
Address('00:11:22:33:44:55'), 0, HCI_Connection_Complete_Event.ACL_LINK_TYPE
)

host.on('connection', m)
host.on_hci_packet(
HCI_Connection_Complete_Event(
status=HCI_SUCCESS,
bd_addr=Address('00:11:22:33:44:55'),
connection_handle=0x0001,
link_type=HCI_Connection_Complete_Event.ACL_LINK_TYPE,
encryption_enabled=0,
)
)
assert host.connections[0x0001].handle == 0x0001
assert host.connections[0x0001].peer_address == Address('00:11:22:33:44:55')
m.assert_called_with(
0x0001, BT_BR_EDR_TRANSPORT, Address('00:11:22:33:44:55'), None, None
)

host.on('disconnection', m)
host.on_hci_packet(
HCI_Disconnection_Complete_Event(
status=HCI_SUCCESS,
connection_handle=0x0001,
reason=HCI_REMOTE_USER_TERMINATED_CONNECTION_ERROR,
)
)
m.assert_called_with(0x0001, HCI_REMOTE_USER_TERMINATED_CONNECTION_ERROR)
assert 0x0001 not in host.connections


# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_sco_connection_disconnection():
host = make_host()

m = mock.Mock()
host.on('connection_request', m)
host.on_hci_packet(
HCI_Connection_Request_Event(
bd_addr=Address('00:11:22:33:44:55'),
class_of_device=0,
link_type=HCI_Connection_Complete_Event.SCO_LINK_TYPE,
)
)
m.assert_called_with(
Address('00:11:22:33:44:55'), 0, HCI_Connection_Complete_Event.SCO_LINK_TYPE
)

host.on('sco_connection', m)
host.on_hci_packet(
HCI_Synchronous_Connection_Complete_Event(
status=HCI_SUCCESS,
bd_addr=Address('00:11:22:33:44:55'),
link_type=HCI_Connection_Complete_Event.SCO_LINK_TYPE,
connection_handle=0x0001,
transmission_interval=0,
retransmission_window=0,
rx_packet_length=0,
tx_packet_length=0,
air_mode=0,
)
)
assert host.sco_links[0x0001].handle == 0x0001
assert host.sco_links[0x0001].peer_address == Address('00:11:22:33:44:55')
m.assert_called_with(
Address('00:11:22:33:44:55'),
0x0001,
HCI_Connection_Complete_Event.SCO_LINK_TYPE,
)

host.on('disconnection', m)
host.on_hci_packet(
HCI_Disconnection_Complete_Event(
status=HCI_SUCCESS,
connection_handle=0x0001,
reason=HCI_REMOTE_USER_TERMINATED_CONNECTION_ERROR,
)
)
assert 0x0001 not in host.sco_links
m.assert_called_with(0x0001, HCI_REMOTE_USER_TERMINATED_CONNECTION_ERROR)


# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_cis_connection_disconnection():
host = make_host()

m = mock.Mock()
host.on('cis_request', m)
host.on_hci_packet(
HCI_LE_CIS_Request_Event(
acl_connection_handle=0x0001,
cis_connection_handle=0x0002,
cig_id=0x0001,
cis_id=0x0001,
)
)
m.assert_called_with(0x0001, 0x0002, 0x0001, 0x0001)

host.on('cis_establishment', m)
host.on_hci_packet(
HCI_LE_CIS_Established_Event(
status=HCI_SUCCESS,
connection_handle=0x0002,
cig_sync_delay=0,
cis_sync_delay=0,
transport_latency_c_to_p=0,
transport_latency_p_to_c=0,
phy_c_to_p=0,
phy_p_to_c=0,
nse=0,
bn_c_to_p=0,
bn_p_to_c=0,
ft_c_to_p=0,
ft_p_to_c=0,
max_pdu_c_to_p=0,
max_pdu_p_to_c=0,
iso_interval=0,
)
)
assert host.cis_links[0x0002].handle == 0x0002
m.assert_called_with(0x0002)

host.on('disconnection', m)
host.on_hci_packet(
HCI_Disconnection_Complete_Event(
status=HCI_SUCCESS,
connection_handle=0x0002,
reason=HCI_REMOTE_USER_TERMINATED_CONNECTION_ERROR,
)
)
assert 0x0002 not in host.cis_links
m.assert_called_with(0x0002, HCI_REMOTE_USER_TERMINATED_CONNECTION_ERROR)


# -----------------------------------------------------------------------------
async def run_test_host():
await test_acl_connection_disconnection()
await test_sco_connection_disconnection()
await test_cis_connection_disconnection()


# -----------------------------------------------------------------------------
if __name__ == '__main__':
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
asyncio.run(run_test_host())

0 comments on commit 10c63a9

Please sign in to comment.