Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
130 changes: 75 additions & 55 deletions bumble/avdtp.py
Original file line number Diff line number Diff line change
Expand Up @@ -1477,8 +1477,23 @@ def on_message(self, transaction_label: int, message: Message) -> None:
handler = getattr(self, handler_name, None)
if handler:
try:
response = handler(message)
self.send_message(transaction_label, response)
result = handler(message)
if asyncio.iscoroutine(result):

async def wait_and_send() -> None:
try:
response = await result
if response:
self.send_message(transaction_label, response)
except Exception:
logger.exception(
color("!!! Exception in handler:", "red")
)

utils.cancel_on_event(self, self.EVENT_CLOSE, wait_and_send())
else:
if result:
self.send_message(transaction_label, result)
except Exception:
logger.exception(color("!!! Exception in handler:", "red"))
else:
Expand Down Expand Up @@ -1559,7 +1574,7 @@ def send_message(self, transaction_label: int, message: Message) -> None:
async def send_command(self, command: Message):
# TODO: support timeouts
# Send the command
(transaction_label, transaction_result) = await self.start_transaction()
transaction_label, transaction_result = await self.start_transaction()
self.send_message(transaction_label, command)

# Wait for the response
Expand Down Expand Up @@ -1624,14 +1639,14 @@ async def close(self, seid: int) -> Close_Response:
async def abort(self, seid: int) -> Abort_Response:
return await self.send_command(Abort_Command(seid))

def on_discover_command(self, command: Discover_Command) -> Message | None:
async def on_discover_command(self, command: Discover_Command) -> Message | None:
endpoint_infos = [
EndPointInfo(endpoint.seid, 0, endpoint.media_type, endpoint.tsep)
for endpoint in self.local_endpoints
]
return Discover_Response(endpoint_infos)

def on_get_capabilities_command(
async def on_get_capabilities_command(
self, command: Get_Capabilities_Command
) -> Message | None:
endpoint = self.get_local_endpoint_by_seid(command.acp_seid)
Expand All @@ -1640,7 +1655,7 @@ def on_get_capabilities_command(

return Get_Capabilities_Response(endpoint.capabilities)

def on_get_all_capabilities_command(
async def on_get_all_capabilities_command(
self, command: Get_All_Capabilities_Command
) -> Message | None:
endpoint = self.get_local_endpoint_by_seid(command.acp_seid)
Expand All @@ -1649,7 +1664,7 @@ def on_get_all_capabilities_command(

return Get_All_Capabilities_Response(endpoint.capabilities)

def on_set_configuration_command(
async def on_set_configuration_command(
self, command: Set_Configuration_Command
) -> Message | None:
endpoint = self.get_local_endpoint_by_seid(command.acp_seid)
Expand All @@ -1664,10 +1679,10 @@ def on_set_configuration_command(
stream = Stream(self, endpoint, StreamEndPointProxy(self, command.int_seid))
self.streams[command.acp_seid] = stream

result = stream.on_set_configuration_command(command.capabilities)
result = await stream.on_set_configuration_command(command.capabilities)
return result or Set_Configuration_Response()

def on_get_configuration_command(
async def on_get_configuration_command(
self, command: Get_Configuration_Command
) -> Message | None:
endpoint = self.get_local_endpoint_by_seid(command.acp_seid)
Expand All @@ -1676,29 +1691,31 @@ def on_get_configuration_command(
if endpoint.stream is None:
return Get_Configuration_Reject(AVDTP_BAD_STATE_ERROR)

return endpoint.stream.on_get_configuration_command()
return await endpoint.stream.on_get_configuration_command()

def on_reconfigure_command(self, command: Reconfigure_Command) -> Message | None:
async def on_reconfigure_command(
self, command: Reconfigure_Command
) -> Message | None:
endpoint = self.get_local_endpoint_by_seid(command.acp_seid)
if endpoint is None:
return Reconfigure_Reject(error_code=AVDTP_BAD_ACP_SEID_ERROR)
if endpoint.stream is None:
return Reconfigure_Reject(error_code=AVDTP_BAD_STATE_ERROR)

result = endpoint.stream.on_reconfigure_command(command.capabilities)
result = await endpoint.stream.on_reconfigure_command(command.capabilities)
return result or Reconfigure_Response()

def on_open_command(self, command: Open_Command) -> Message | None:
async def on_open_command(self, command: Open_Command) -> Message | None:
endpoint = self.get_local_endpoint_by_seid(command.acp_seid)
if endpoint is None:
return Open_Reject(AVDTP_BAD_ACP_SEID_ERROR)
if endpoint.stream is None:
return Open_Reject(AVDTP_BAD_STATE_ERROR)

result = endpoint.stream.on_open_command()
result = await endpoint.stream.on_open_command()
return result or Open_Response()

def on_start_command(self, command: Start_Command) -> Message | None:
async def on_start_command(self, command: Start_Command) -> Message | None:
for seid in command.acp_seids:
endpoint = self.get_local_endpoint_by_seid(seid)
if endpoint is None:
Expand All @@ -1712,12 +1729,12 @@ def on_start_command(self, command: Start_Command) -> Message | None:
endpoint = self.get_local_endpoint_by_seid(seid)
if not endpoint or not endpoint.stream:
raise InvalidStateError("Should already be checked!")
if (result := endpoint.stream.on_start_command()) is not None:
if (result := await endpoint.stream.on_start_command()) is not None:
return result

return Start_Response()

def on_suspend_command(self, command: Suspend_Command) -> Message | None:
async def on_suspend_command(self, command: Suspend_Command) -> Message | None:
for seid in command.acp_seids:
endpoint = self.get_local_endpoint_by_seid(seid)
if endpoint is None:
Expand All @@ -1731,45 +1748,47 @@ def on_suspend_command(self, command: Suspend_Command) -> Message | None:
endpoint = self.get_local_endpoint_by_seid(seid)
if not endpoint or not endpoint.stream:
raise InvalidStateError("Should already be checked!")
if (result := endpoint.stream.on_suspend_command()) is not None:
if (result := await endpoint.stream.on_suspend_command()) is not None:
return result

return Suspend_Response()

def on_close_command(self, command: Close_Command) -> Message | None:
async def on_close_command(self, command: Close_Command) -> Message | None:
endpoint = self.get_local_endpoint_by_seid(command.acp_seid)
if endpoint is None:
return Close_Reject(AVDTP_BAD_ACP_SEID_ERROR)
if endpoint.stream is None:
return Close_Reject(AVDTP_BAD_STATE_ERROR)

result = endpoint.stream.on_close_command()
result = await endpoint.stream.on_close_command()
return result or Close_Response()

def on_abort_command(self, command: Abort_Command) -> Message | None:
async def on_abort_command(self, command: Abort_Command) -> Message | None:
endpoint = self.get_local_endpoint_by_seid(command.acp_seid)
if endpoint is None or endpoint.stream is None:
return Abort_Response()

endpoint.stream.on_abort_command()
await endpoint.stream.on_abort_command()
return Abort_Response()

def on_security_control_command(
async def on_security_control_command(
self, command: Security_Control_Command
) -> Message | None:
endpoint = self.get_local_endpoint_by_seid(command.acp_seid)
if endpoint is None:
return Security_Control_Reject(AVDTP_BAD_ACP_SEID_ERROR)

result = endpoint.on_security_control_command(command.data)
result = await endpoint.on_security_control_command(command.data)
return result or Security_Control_Response()

def on_delayreport_command(self, command: DelayReport_Command) -> Message | None:
async def on_delayreport_command(
self, command: DelayReport_Command
) -> Message | None:
endpoint = self.get_local_endpoint_by_seid(command.acp_seid)
if endpoint is None:
return DelayReport_Reject(AVDTP_BAD_ACP_SEID_ERROR)

result = endpoint.on_delayreport_command(command.delay)
result = await endpoint.on_delayreport_command(command.delay)
return result or DelayReport_Response()


Expand Down Expand Up @@ -1932,46 +1951,46 @@ async def close(self) -> None:

self.change_state(State.IDLE)

def on_set_configuration_command(
async def on_set_configuration_command(
self, configuration: Iterable[ServiceCapabilities]
) -> Message | None:
if self.state != State.IDLE:
return Set_Configuration_Reject(error_code=AVDTP_BAD_STATE_ERROR)

result = self.local_endpoint.on_set_configuration_command(configuration)
result = await self.local_endpoint.on_set_configuration_command(configuration)
if result is not None:
return result

self.change_state(State.CONFIGURED)
return None

def on_get_configuration_command(self) -> Message | None:
async def on_get_configuration_command(self) -> Message | None:
if self.state not in (
State.CONFIGURED,
State.OPEN,
State.STREAMING,
):
return Get_Configuration_Reject(error_code=AVDTP_BAD_STATE_ERROR)

return self.local_endpoint.on_get_configuration_command()
return await self.local_endpoint.on_get_configuration_command()

def on_reconfigure_command(
async def on_reconfigure_command(
self, configuration: Iterable[ServiceCapabilities]
) -> Message | None:
if self.state != State.OPEN:
return Reconfigure_Reject(error_code=AVDTP_BAD_STATE_ERROR)

result = self.local_endpoint.on_reconfigure_command(configuration)
result = await self.local_endpoint.on_reconfigure_command(configuration)
if result is not None:
return result

return None

def on_open_command(self) -> Message | None:
async def on_open_command(self) -> Message | None:
if self.state != State.CONFIGURED:
return Open_Reject(AVDTP_BAD_STATE_ERROR)

result = self.local_endpoint.on_open_command()
result = await self.local_endpoint.on_open_command()
if result is not None:
return result

Expand All @@ -1981,7 +2000,7 @@ def on_open_command(self) -> Message | None:
self.change_state(State.OPEN)
return None

def on_start_command(self) -> Message | None:
async def on_start_command(self) -> Message | None:
if self.state != State.OPEN:
return Open_Reject(AVDTP_BAD_STATE_ERROR)

Expand All @@ -1990,29 +2009,29 @@ def on_start_command(self) -> Message | None:
logger.warning('received start command before RTP channel establishment')
return Open_Reject(AVDTP_BAD_STATE_ERROR)

result = self.local_endpoint.on_start_command()
result = await self.local_endpoint.on_start_command()
if result is not None:
return result

self.change_state(State.STREAMING)
return None

def on_suspend_command(self) -> Message | None:
async def on_suspend_command(self) -> Message | None:
if self.state != State.STREAMING:
return Open_Reject(AVDTP_BAD_STATE_ERROR)

result = self.local_endpoint.on_suspend_command()
result = await self.local_endpoint.on_suspend_command()
if result is not None:
return result

self.change_state(State.OPEN)
return None

def on_close_command(self) -> Message | None:
async def on_close_command(self) -> Message | None:
if self.state not in (State.OPEN, State.STREAMING):
return Open_Reject(AVDTP_BAD_STATE_ERROR)

result = self.local_endpoint.on_close_command()
result = await self.local_endpoint.on_close_command()
if result is not None:
return result

Expand All @@ -2027,7 +2046,8 @@ def on_close_command(self) -> Message | None:

return None

def on_abort_command(self) -> Message | None:
async def on_abort_command(self) -> Message | None:
await self.local_endpoint.on_abort_command()
if self.rtp_channel is None:
# No need to wait
self.change_state(State.IDLE)
Expand Down Expand Up @@ -2179,13 +2199,13 @@ async def stop(self) -> None:
async def close(self) -> None:
"""[Source Only] Handles when receiving close command."""

def on_reconfigure_command(
async def on_reconfigure_command(
self, command: Iterable[ServiceCapabilities]
) -> Message | None:
del command # unused.
return None

def on_set_configuration_command(
async def on_set_configuration_command(
self, configuration: Iterable[ServiceCapabilities]
) -> Message | None:
logger.debug(
Expand All @@ -2196,34 +2216,34 @@ def on_set_configuration_command(
self.emit(self.EVENT_CONFIGURATION)
return None

def on_get_configuration_command(self) -> Message | None:
async def on_get_configuration_command(self) -> Message | None:
return Get_Configuration_Response(self.configuration)

def on_open_command(self) -> Message | None:
async def on_open_command(self) -> Message | None:
self.emit(self.EVENT_OPEN)
return None

def on_start_command(self) -> Message | None:
async def on_start_command(self) -> Message | None:
self.emit(self.EVENT_START)
return None

def on_suspend_command(self) -> Message | None:
async def on_suspend_command(self) -> Message | None:
self.emit(self.EVENT_SUSPEND)
return None

def on_close_command(self) -> Message | None:
async def on_close_command(self) -> Message | None:
self.emit(self.EVENT_CLOSE)
return None

def on_abort_command(self) -> Message | None:
async def on_abort_command(self) -> Message | None:
self.emit(self.EVENT_ABORT)
return None

def on_delayreport_command(self, delay: int) -> Message | None:
async def on_delayreport_command(self, delay: int) -> Message | None:
self.emit(self.EVENT_DELAY_REPORT, delay)
return None

def on_security_control_command(self, data: bytes) -> Message | None:
async def on_security_control_command(self, data: bytes) -> Message | None:
self.emit(self.EVENT_SECURITY_CONTROL, data)
return None

Expand Down Expand Up @@ -2275,13 +2295,13 @@ async def stop(self) -> None:
self.emit(self.EVENT_STOP)

@override
def on_start_command(self) -> Message | None:
asyncio.create_task(self.start())
async def on_start_command(self) -> Message | None:
await self.start()
return None

@override
def on_suspend_command(self) -> Message | None:
asyncio.create_task(self.stop())
async def on_suspend_command(self) -> Message | None:
await self.stop()
return None


Expand Down
Loading