Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Manage lifecycle of CIS and SCO links in host #376

Merged
merged 1 commit into from
Jan 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
46 changes: 21 additions & 25 deletions bumble/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -3078,34 +3078,30 @@ async def create_cis(self, cis_acl_pairs: List[Tuple[int, int]]) -> List[CisLink
cig_id=cig_id,
)

result = await self.send_command(
HCI_LE_Create_CIS_Command(
cis_connection_handle=[p[0] for p in cis_acl_pairs],
acl_connection_handle=[p[1] for p in cis_acl_pairs],
),
)
if result.status != HCI_COMMAND_STATUS_PENDING:
logger.warning(
'HCI_LE_Create_CIS_Command failed: '
f'{HCI_Constant.error_name(result.status)}'
)
raise HCI_StatusError(result)

pending_cis_establishments: Dict[int, asyncio.Future[CisLink]] = {}
for cis_handle, _ in cis_acl_pairs:
pending_cis_establishments[
cis_handle
] = asyncio.get_running_loop().create_future()

with closing(EventWatcher()) as watcher:
pending_cis_establishments = {
cis_handle: asyncio.get_running_loop().create_future()
for cis_handle, _ in cis_acl_pairs
}

@watcher.on(self, 'cis_establishment')
def on_cis_establishment(cis_link: CisLink) -> None:
zxzxwu marked this conversation as resolved.
Show resolved Hide resolved
if pending_future := pending_cis_establishments.get(
cis_link.handle, None
):
if pending_future := pending_cis_establishments.get(cis_link.handle):
pending_future.set_result(cis_link)

result = await self.send_command(
HCI_LE_Create_CIS_Command(
cis_connection_handle=[p[0] for p in cis_acl_pairs],
acl_connection_handle=[p[1] for p in cis_acl_pairs],
),
)
if result.status != HCI_COMMAND_STATUS_PENDING:
logger.warning(
'HCI_LE_Create_CIS_Command failed: '
f'{HCI_Constant.error_name(result.status)}'
)
raise HCI_StatusError(result)

return await asyncio.gather(*pending_cis_establishments.values())

# [LE only]
Expand Down Expand Up @@ -3753,7 +3749,7 @@ def on_sco_connection_failure(
@host_event_handler
@experimental('Only for testing')
def on_sco_packet(self, sco_handle: int, packet: HCI_SynchronousDataPacket) -> None:
if sco_link := self.sco_links.get(sco_handle, None):
if sco_link := self.sco_links.get(sco_handle):
sco_link.emit('pdu', packet)

# [LE only]
Expand Down Expand Up @@ -3833,15 +3829,15 @@ def on_cis_establishment(self, cis_handle: int) -> None:
@experimental('Only for testing')
def on_cis_establishment_failure(self, cis_handle: int, status: int) -> None:
logger.debug(f'*** CIS Establishment Failure: cis=[0x{cis_handle:04X}] ***')
if cis_link := self.cis_links.pop(cis_handle, None):
if cis_link := self.cis_links.pop(cis_handle):
cis_link.emit('establishment_failure')
self.emit('cis_establishment_failure', cis_handle, status)

# [LE only]
@host_event_handler
@experimental('Only for testing')
def on_iso_packet(self, handle: int, packet: HCI_IsoDataPacket) -> None:
if cis_link := self.cis_links.get(handle, None):
if cis_link := self.cis_links.get(handle):
cis_link.emit('pdu', packet)

@host_event_handler
Expand Down
49 changes: 44 additions & 5 deletions bumble/host.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from __future__ import annotations
import asyncio
import collections
import dataclasses
import logging
import struct

Expand Down Expand Up @@ -161,9 +162,25 @@ def on_acl_pdu(self, pdu: bytes) -> None:
self.host.on_l2cap_pdu(self, l2cap_pdu.cid, l2cap_pdu.payload)


# -----------------------------------------------------------------------------
@dataclasses.dataclass
class ScoLink:
peer_address: Address
handle: int


# -----------------------------------------------------------------------------
@dataclasses.dataclass
class CisLink:
peer_address: Address
handle: int


# -----------------------------------------------------------------------------
class Host(AbortableEventEmitter):
connections: Dict[int, Connection]
cis_links: Dict[int, CisLink]
sco_links: Dict[int, ScoLink]
acl_packet_queue: Optional[AclPacketQueue] = None
le_acl_packet_queue: Optional[AclPacketQueue] = None
hci_sink: Optional[TransportSink] = None
Expand All @@ -183,6 +200,8 @@ def __init__(
self.hci_metadata = {}
self.ready = False # True when we can accept incoming packets
self.connections = {} # Connections, by connection handle
self.cis_links = {} # CIS links, by connection handle
self.sco_links = {} # SCO links, by connection handle
self.pending_command = None
self.pending_response = None
self.local_version = None
Expand Down Expand Up @@ -696,25 +715,36 @@ def on_hci_connection_complete_event(self, event):

def on_hci_disconnection_complete_event(self, event):
# Find the connection
zxzxwu marked this conversation as resolved.
Show resolved Hide resolved
if (connection := self.connections.get(event.connection_handle)) is None:
handle = event.connection_handle
if (
connection := (
self.connections.get(handle)
or self.cis_links.get(handle)
or self.sco_links.get(handle)
)
) is None:
logger.warning('!!! DISCONNECTION COMPLETE: unknown handle')
return

if event.status == HCI_SUCCESS:
logger.debug(
f'### DISCONNECTION: [0x{event.connection_handle:04X}] '
f'### DISCONNECTION: [0x{handle:04X}] '
f'{connection.peer_address} '
f'reason={event.reason}'
)
del self.connections[event.connection_handle]
barbibulle marked this conversation as resolved.
Show resolved Hide resolved

# Notify the listeners
self.emit('disconnection', event.connection_handle, event.reason)
self.emit('disconnection', handle, event.reason)
(
self.connections.pop(handle, 0)
or self.cis_links.pop(handle, 0)
or self.sco_links.pop(handle, 0)
)
else:
logger.debug(f'### DISCONNECTION FAILED: {event.status}')

# Notify the listeners
self.emit('disconnection_failure', event.connection_handle, event.status)
self.emit('disconnection_failure', handle, event.status)

def on_hci_le_connection_update_complete_event(self, event):
if (connection := self.connections.get(event.connection_handle)) is None:
Expand Down Expand Up @@ -775,6 +805,10 @@ def on_hci_le_cis_request_event(self, event):
def on_hci_le_cis_established_event(self, event):
# The remaining parameters are unused for now.
if event.status == HCI_SUCCESS:
self.cis_links[event.connection_handle] = CisLink(
handle=event.connection_handle,
peer_address=Address.ANY,
)
self.emit('cis_establishment', event.connection_handle)
else:
self.emit(
Expand Down Expand Up @@ -841,6 +875,11 @@ def on_hci_synchronous_connection_complete_event(self, event):
f'{event.bd_addr}'
)

self.sco_links[event.connection_handle] = ScoLink(
peer_address=event.bd_addr,
handle=event.connection_handle,
)

# Notify the client
self.emit(
'sco_connection',
Expand Down
5 changes: 2 additions & 3 deletions tests/device_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,9 +467,8 @@ def on_cis_request(
await asyncio.gather(*peripheral_cis_futures.values())
assert len(cis_links) == 2

# TODO: Fix Host CIS support.
# await cis_links[0].disconnect()
# await cis_links[1].disconnect()
await cis_links[0].disconnect()
await cis_links[1].disconnect()


# -----------------------------------------------------------------------------
Expand Down
25 changes: 15 additions & 10 deletions tests/hfp_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,11 @@

from .test_utils import TwoDevices
from bumble import core
from bumble import device
from bumble import hfp
from bumble import rfcomm
from bumble import hci


# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
Expand Down Expand Up @@ -109,7 +109,7 @@ async def test_sco_setup():
devices[1].accept(devices[0].public_address),
)

def on_sco_request(_connection: device.Connection, _link_type: int):
def on_sco_request(_connection, _link_type: int):
connections[1].abort_on(
'disconnection',
devices[1].send_command(
Expand All @@ -124,26 +124,31 @@ def on_sco_request(_connection: device.Connection, _link_type: int):

devices[1].on('sco_request', on_sco_request)

sco_connections = [
sco_connection_futures = [
asyncio.get_running_loop().create_future(),
asyncio.get_running_loop().create_future(),
]

devices[0].on(
'sco_connection', lambda sco_link: sco_connections[0].set_result(sco_link)
)
devices[1].on(
'sco_connection', lambda sco_link: sco_connections[1].set_result(sco_link)
)
for device, future in zip(devices, sco_connection_futures):
device.on('sco_connection', future.set_result)

await devices[0].send_command(
hci.HCI_Enhanced_Setup_Synchronous_Connection_Command(
connection_handle=connections[0].handle,
**hfp.ESCO_PARAMETERS[hfp.DefaultCodecParameters.ESCO_CVSD_S1].asdict(),
)
)
sco_connections = await asyncio.gather(*sco_connection_futures)

sco_disconnection_futures = [
asyncio.get_running_loop().create_future(),
asyncio.get_running_loop().create_future(),
]
for future, sco_connection in zip(sco_disconnection_futures, sco_connections):
sco_connection.on('disconnection', future.set_result)

await asyncio.gather(*sco_connections)
await sco_connections[0].disconnect()
await asyncio.gather(*sco_disconnection_futures)


# -----------------------------------------------------------------------------
Expand Down