Skip to content

Commit

Permalink
Manage lifecycle of CIS and SCO links in host
Browse files Browse the repository at this point in the history
  • Loading branch information
zxzxwu committed Jan 10, 2024
1 parent cdd7f37 commit 27d5492
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 30 deletions.
46 changes: 21 additions & 25 deletions bumble/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -3070,34 +3070,30 @@ async def create_cis(self, cis_acl_pairs: List[Tuple[int, int]]) -> List[CisLink
cig_id=cig_id,
)

result = await self.send_command(
HCI_LE_Create_CIS_Command(
cis_connection_handle=[p[0] for p in cis_acl_pairs],
acl_connection_handle=[p[1] for p in cis_acl_pairs],
),
)
if result.status != HCI_COMMAND_STATUS_PENDING:
logger.warning(
'HCI_LE_Create_CIS_Command failed: '
f'{HCI_Constant.error_name(result.status)}'
)
raise HCI_StatusError(result)

pending_cis_establishments: Dict[int, asyncio.Future[CisLink]] = {}
for cis_handle, _ in cis_acl_pairs:
pending_cis_establishments[
cis_handle
] = asyncio.get_running_loop().create_future()

with closing(EventWatcher()) as watcher:
pending_cis_establishments = {
cis_handle: asyncio.get_running_loop().create_future()
for cis_handle, _ in cis_acl_pairs
}

@watcher.on(self, 'cis_establishment')
def on_cis_establishment(cis_link: CisLink) -> None:
if pending_future := pending_cis_establishments.get(
cis_link.handle, None
):
if pending_future := pending_cis_establishments.get(cis_link.handle):
pending_future.set_result(cis_link)

result = await self.send_command(
HCI_LE_Create_CIS_Command(
cis_connection_handle=[p[0] for p in cis_acl_pairs],
acl_connection_handle=[p[1] for p in cis_acl_pairs],
),
)
if result.status != HCI_COMMAND_STATUS_PENDING:
logger.warning(
'HCI_LE_Create_CIS_Command failed: '
f'{HCI_Constant.error_name(result.status)}'
)
raise HCI_StatusError(result)

return await asyncio.gather(*pending_cis_establishments.values())

# [LE only]
Expand Down Expand Up @@ -3719,7 +3715,7 @@ def on_sco_connection_failure(
@host_event_handler
@experimental('Only for testing')
def on_sco_packet(self, sco_handle: int, packet: HCI_SynchronousDataPacket) -> None:
if sco_link := self.sco_links.get(sco_handle, None):
if sco_link := self.sco_links.get(sco_handle):
sco_link.emit('pdu', packet)

# [LE only]
Expand Down Expand Up @@ -3799,15 +3795,15 @@ def on_cis_establishment(self, cis_handle: int) -> None:
@experimental('Only for testing')
def on_cis_establishment_failure(self, cis_handle: int, status: int) -> None:
logger.debug(f'*** CIS Establishment Failure: cis=[0x{cis_handle:04X}] ***')
if cis_link := self.cis_links.pop(cis_handle, None):
if cis_link := self.cis_links.pop(cis_handle):
cis_link.emit('establishment_failure')
self.emit('cis_establishment_failure', cis_handle, status)

# [LE only]
@host_event_handler
@experimental('Only for testing')
def on_iso_packet(self, handle: int, packet: HCI_IsoDataPacket) -> None:
if cis_link := self.cis_links.get(handle, None):
if cis_link := self.cis_links.get(handle):
cis_link.emit('pdu', packet)

@host_event_handler
Expand Down
37 changes: 35 additions & 2 deletions bumble/host.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from __future__ import annotations
import asyncio
import collections
import dataclasses
import logging
import struct

Expand Down Expand Up @@ -160,9 +161,25 @@ def on_acl_pdu(self, pdu: bytes) -> None:
self.host.on_l2cap_pdu(self, l2cap_pdu.cid, l2cap_pdu.payload)


# -----------------------------------------------------------------------------
@dataclasses.dataclass
class ScoLink:
peer_address: Address
handle: int


# -----------------------------------------------------------------------------
@dataclasses.dataclass
class CisLink:
peer_address: Address
handle: int


# -----------------------------------------------------------------------------
class Host(AbortableEventEmitter):
connections: Dict[int, Connection]
cis_links: Dict[int, Connection]
sco_links: Dict[int, Connection]
acl_packet_queue: Optional[AclPacketQueue] = None
le_acl_packet_queue: Optional[AclPacketQueue] = None
hci_sink: Optional[TransportSink] = None
Expand All @@ -182,6 +199,8 @@ def __init__(
self.hci_metadata = {}
self.ready = False # True when we can accept incoming packets
self.connections = {} # Connections, by connection handle
self.cis_links = {} # CIS links, by connection handle
self.sco_links = {} # SCO links, by connection handle
self.pending_command = None
self.pending_response = None
self.local_version = None
Expand Down Expand Up @@ -695,7 +714,13 @@ def on_hci_connection_complete_event(self, event):

def on_hci_disconnection_complete_event(self, event):
# Find the connection
if (connection := self.connections.get(event.connection_handle)) is None:
if (
connection := (
self.connections.pop(event.connection_handle, None)
or self.cis_links.pop(event.connection_handle, None)
or self.sco_links.pop(event.connection_handle, None)
)
) is None:
logger.warning('!!! DISCONNECTION COMPLETE: unknown handle')
return

Expand All @@ -705,7 +730,6 @@ def on_hci_disconnection_complete_event(self, event):
f'{connection.peer_address} '
f'reason={event.reason}'
)
del self.connections[event.connection_handle]

# Notify the listeners
self.emit('disconnection', event.connection_handle, event.reason)
Expand Down Expand Up @@ -774,6 +798,10 @@ def on_hci_le_cis_request_event(self, event):
def on_hci_le_cis_established_event(self, event):
# The remaining parameters are unused for now.
if event.status == HCI_SUCCESS:
self.cis_links[event.connection_handle] = CisLink(
handle=event.connection_handle,
peer_address=Address.ANY,
)
self.emit('cis_establishment', event.connection_handle)
else:
self.emit(
Expand Down Expand Up @@ -840,6 +868,11 @@ def on_hci_synchronous_connection_complete_event(self, event):
f'{event.bd_addr}'
)

self.sco_links[event.connection_handle] = ScoLink(
peer_address=event.bd_addr,
handle=event.connection_handle,
)

# Notify the client
self.emit(
'sco_connection',
Expand Down
5 changes: 2 additions & 3 deletions tests/device_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,9 +458,8 @@ def on_cis_request(
await asyncio.gather(*peripheral_cis_futures.values())
assert len(cis_links) == 2

# TODO: Fix Host CIS support.
# await cis_links[0].disconnect()
# await cis_links[1].disconnect()
await cis_links[0].disconnect()
await cis_links[1].disconnect()


# -----------------------------------------------------------------------------
Expand Down

0 comments on commit 27d5492

Please sign in to comment.