Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

RFCOMM: Slightly refactor and correct constants #418

Merged
merged 1 commit into from
Jan 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
139 changes: 41 additions & 98 deletions bumble/rfcomm.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import logging
import asyncio
import dataclasses
import enum
from typing import Callable, Dict, List, Optional, Tuple, Union, TYPE_CHECKING

Expand Down Expand Up @@ -60,27 +61,18 @@

RFCOMM_PSM = 0x0003

class FrameType(enum.IntEnum):
SABM = 0x2F # Control field [1,1,1,1,_,1,0,0] LSB-first
UA = 0x63 # Control field [0,1,1,0,_,0,1,1] LSB-first
DM = 0x0F # Control field [1,1,1,1,_,0,0,0] LSB-first
DISC = 0x43 # Control field [0,1,0,_,0,0,1,1] LSB-first
UIH = 0xEF # Control field [1,1,1,_,1,1,1,1] LSB-first
UI = 0x03 # Control field [0,0,0,_,0,0,1,1] LSB-first

# Frame types
RFCOMM_SABM_FRAME = 0x2F # Control field [1,1,1,1,_,1,0,0] LSB-first
RFCOMM_UA_FRAME = 0x63 # Control field [0,1,1,0,_,0,1,1] LSB-first
RFCOMM_DM_FRAME = 0x0F # Control field [1,1,1,1,_,0,0,0] LSB-first
RFCOMM_DISC_FRAME = 0x43 # Control field [0,1,0,_,0,0,1,1] LSB-first
RFCOMM_UIH_FRAME = 0xEF # Control field [1,1,1,_,1,1,1,1] LSB-first
RFCOMM_UI_FRAME = 0x03 # Control field [0,0,0,_,0,0,1,1] LSB-first
class MccType(enum.IntEnum):
PN = 0x20
MSC = 0x38

RFCOMM_FRAME_TYPE_NAMES = {
RFCOMM_SABM_FRAME: 'SABM',
RFCOMM_UA_FRAME: 'UA',
RFCOMM_DM_FRAME: 'DM',
RFCOMM_DISC_FRAME: 'DISC',
RFCOMM_UIH_FRAME: 'UIH',
RFCOMM_UI_FRAME: 'UI'
}

# MCC Types
RFCOMM_MCC_PN_TYPE = 0x20
RFCOMM_MCC_MSC_TYPE = 0x38

# FCS CRC
CRC_TABLE = bytes([
Expand Down Expand Up @@ -118,7 +110,7 @@
0XBA, 0X2B, 0X59, 0XC8, 0XBD, 0X2C, 0X5E, 0XCF
])

RFCOMM_DEFAULT_WINDOW_SIZE = 16
RFCOMM_DEFAULT_WINDOW_SIZE = 7
RFCOMM_DEFAULT_MAX_FRAME_SIZE = 2000

RFCOMM_DYNAMIC_CHANNEL_NUMBER_START = 1
Expand Down Expand Up @@ -183,7 +175,7 @@ def compute_fcs(buffer: bytes) -> int:
class RFCOMM_Frame:
def __init__(
self,
frame_type: int,
frame_type: FrameType,
c_r: int,
dlci: int,
p_f: int,
Expand All @@ -206,14 +198,11 @@ def __init__(
self.length = bytes([(length << 1) | 1])
self.address = (dlci << 2) | (c_r << 1) | 1
self.control = frame_type | (p_f << 4)
if frame_type == RFCOMM_UIH_FRAME:
if frame_type == FrameType.UIH:
self.fcs = compute_fcs(bytes([self.address, self.control]))
else:
self.fcs = compute_fcs(bytes([self.address, self.control]) + self.length)

def type_name(self) -> str:
return RFCOMM_FRAME_TYPE_NAMES[self.type]

@staticmethod
def parse_mcc(data) -> Tuple[int, bool, bytes]:
mcc_type = data[0] >> 2
Expand All @@ -237,32 +226,32 @@ def make_mcc(mcc_type: int, c_r: int, data: bytes) -> bytes:

@staticmethod
def sabm(c_r: int, dlci: int):
return RFCOMM_Frame(RFCOMM_SABM_FRAME, c_r, dlci, 1)
return RFCOMM_Frame(FrameType.SABM, c_r, dlci, 1)

@staticmethod
def ua(c_r: int, dlci: int):
return RFCOMM_Frame(RFCOMM_UA_FRAME, c_r, dlci, 1)
return RFCOMM_Frame(FrameType.UA, c_r, dlci, 1)

@staticmethod
def dm(c_r: int, dlci: int):
return RFCOMM_Frame(RFCOMM_DM_FRAME, c_r, dlci, 1)
return RFCOMM_Frame(FrameType.DM, c_r, dlci, 1)

@staticmethod
def disc(c_r: int, dlci: int):
return RFCOMM_Frame(RFCOMM_DISC_FRAME, c_r, dlci, 1)
return RFCOMM_Frame(FrameType.DISC, c_r, dlci, 1)

@staticmethod
def uih(c_r: int, dlci: int, information: bytes, p_f: int = 0):
return RFCOMM_Frame(
RFCOMM_UIH_FRAME, c_r, dlci, p_f, information, with_credits=(p_f == 1)
FrameType.UIH, c_r, dlci, p_f, information, with_credits=(p_f == 1)
)

@staticmethod
def from_bytes(data: bytes) -> RFCOMM_Frame:
# Extract fields
dlci = (data[0] >> 2) & 0x3F
c_r = (data[0] >> 1) & 0x01
frame_type = data[1] & 0xEF
frame_type = FrameType(data[1] & 0xEF)
p_f = (data[1] >> 4) & 0x01
length = data[2]
if length & 0x01:
Expand Down Expand Up @@ -291,7 +280,7 @@ def __bytes__(self) -> bytes:

def __str__(self) -> str:
return (
f'{color(self.type_name(), "yellow")}'
f'{color(self.type.name, "yellow")}'
f'(c/r={self.c_r},'
f'dlci={self.dlci},'
f'p/f={self.p_f},'
Expand All @@ -301,6 +290,7 @@ def __str__(self) -> str:


# -----------------------------------------------------------------------------
@dataclasses.dataclass
class RFCOMM_MCC_PN:
dlci: int
cl: int
Expand All @@ -310,23 +300,11 @@ class RFCOMM_MCC_PN:
max_retransmissions: int
window_size: int

def __init__(
self,
dlci: int,
cl: int,
priority: int,
ack_timer: int,
max_frame_size: int,
max_retransmissions: int,
window_size: int,
) -> None:
self.dlci = dlci
self.cl = cl
self.priority = priority
self.ack_timer = ack_timer
self.max_frame_size = max_frame_size
self.max_retransmissions = max_retransmissions
self.window_size = window_size
def __post_init__(self) -> None:
if self.window_size < 1 or self.window_size > 7:
logger.warning(
f'Error Recovery Window size {self.window_size} is out of range [1, 7].'
)

@staticmethod
def from_bytes(data: bytes) -> RFCOMM_MCC_PN:
Expand All @@ -337,7 +315,7 @@ def from_bytes(data: bytes) -> RFCOMM_MCC_PN:
ack_timer=data[3],
max_frame_size=data[4] | data[5] << 8,
max_retransmissions=data[6],
window_size=data[7],
window_size=data[7] & 0x07,
)

def __bytes__(self) -> bytes:
Expand All @@ -350,23 +328,14 @@ def __bytes__(self) -> bytes:
self.max_frame_size & 0xFF,
(self.max_frame_size >> 8) & 0xFF,
self.max_retransmissions & 0xFF,
self.window_size & 0xFF,
# Only 3 bits are meaningful.
self.window_size & 0x07,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The parser on line 312 should also me modified to have window_size=data[7] & 0x07. Maybe even have a check in the constructor to raise ValueError if window_size is not > 0 and < 8.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

]
)

def __str__(self) -> str:
return (
f'PN(dlci={self.dlci},'
f'cl={self.cl},'
f'priority={self.priority},'
f'ack_timer={self.ack_timer},'
f'max_frame_size={self.max_frame_size},'
f'max_retransmissions={self.max_retransmissions},'
f'window_size={self.window_size})'
)


# -----------------------------------------------------------------------------
@dataclasses.dataclass
class RFCOMM_MCC_MSC:
dlci: int
fc: int
Expand All @@ -375,16 +344,6 @@ class RFCOMM_MCC_MSC:
ic: int
dv: int

def __init__(
self, dlci: int, fc: int, rtc: int, rtr: int, ic: int, dv: int
) -> None:
self.dlci = dlci
self.fc = fc
self.rtc = rtc
self.rtr = rtr
self.ic = ic
self.dv = dv

@staticmethod
def from_bytes(data: bytes) -> RFCOMM_MCC_MSC:
return RFCOMM_MCC_MSC(
Expand All @@ -409,16 +368,6 @@ def __bytes__(self) -> bytes:
]
)

def __str__(self) -> str:
return (
f'MSC(dlci={self.dlci},'
f'fc={self.fc},'
f'rtc={self.rtc},'
f'rtr={self.rtr},'
f'ic={self.ic},'
f'dv={self.dv})'
)


# -----------------------------------------------------------------------------
class DLC(EventEmitter):
Expand Down Expand Up @@ -471,7 +420,7 @@ def send_frame(self, frame: RFCOMM_Frame) -> None:
self.multiplexer.send_frame(frame)

def on_frame(self, frame: RFCOMM_Frame) -> None:
handler = getattr(self, f'on_{frame.type_name()}_frame'.lower())
handler = getattr(self, f'on_{frame.type.name}_frame'.lower())
handler(frame)

def on_sabm_frame(self, _frame: RFCOMM_Frame) -> None:
Expand All @@ -485,9 +434,7 @@ def on_sabm_frame(self, _frame: RFCOMM_Frame) -> None:

# Exchange the modem status with the peer
msc = RFCOMM_MCC_MSC(dlci=self.dlci, fc=0, rtc=1, rtr=1, ic=0, dv=1)
mcc = RFCOMM_Frame.make_mcc(
mcc_type=RFCOMM_MCC_MSC_TYPE, c_r=1, data=bytes(msc)
)
mcc = RFCOMM_Frame.make_mcc(mcc_type=MccType.MSC, c_r=1, data=bytes(msc))
logger.debug(f'>>> MCC MSC Command: {msc}')
self.send_frame(RFCOMM_Frame.uih(c_r=self.c_r, dlci=0, information=mcc))

Expand All @@ -503,9 +450,7 @@ def on_ua_frame(self, _frame: RFCOMM_Frame) -> None:

# Exchange the modem status with the peer
msc = RFCOMM_MCC_MSC(dlci=self.dlci, fc=0, rtc=1, rtr=1, ic=0, dv=1)
mcc = RFCOMM_Frame.make_mcc(
mcc_type=RFCOMM_MCC_MSC_TYPE, c_r=1, data=bytes(msc)
)
mcc = RFCOMM_Frame.make_mcc(mcc_type=MccType.MSC, c_r=1, data=bytes(msc))
logger.debug(f'>>> MCC MSC Command: {msc}')
self.send_frame(RFCOMM_Frame.uih(c_r=self.c_r, dlci=0, information=mcc))

Expand Down Expand Up @@ -559,9 +504,7 @@ def on_mcc_msc(self, c_r: bool, msc: RFCOMM_MCC_MSC) -> None:
# Command
logger.debug(f'<<< MCC MSC Command: {msc}')
msc = RFCOMM_MCC_MSC(dlci=self.dlci, fc=0, rtc=1, rtr=1, ic=0, dv=1)
mcc = RFCOMM_Frame.make_mcc(
mcc_type=RFCOMM_MCC_MSC_TYPE, c_r=0, data=bytes(msc)
)
mcc = RFCOMM_Frame.make_mcc(mcc_type=MccType.MSC, c_r=0, data=bytes(msc))
logger.debug(f'>>> MCC MSC Response: {msc}')
self.send_frame(RFCOMM_Frame.uih(c_r=self.c_r, dlci=0, information=mcc))
else:
Expand Down Expand Up @@ -589,7 +532,7 @@ def accept(self) -> None:
max_retransmissions=0,
window_size=self.window_size,
)
mcc = RFCOMM_Frame.make_mcc(mcc_type=RFCOMM_MCC_PN_TYPE, c_r=0, data=bytes(pn))
mcc = RFCOMM_Frame.make_mcc(mcc_type=MccType.PN, c_r=0, data=bytes(pn))
logger.debug(f'>>> PN Response: {pn}')
self.send_frame(RFCOMM_Frame.uih(c_r=self.c_r, dlci=0, information=mcc))
self.change_state(DLC.State.CONNECTING)
Expand Down Expand Up @@ -711,7 +654,7 @@ def on_pdu(self, pdu: bytes) -> None:
if frame.dlci == 0:
self.on_frame(frame)
else:
if frame.type == RFCOMM_DM_FRAME:
if frame.type == FrameType.DM:
# DM responses are for a DLCI, but since we only create the dlc when we
# receive a PN response (because we need the parameters), we handle DM
# frames at the Multiplexer level
Expand All @@ -724,7 +667,7 @@ def on_pdu(self, pdu: bytes) -> None:
dlc.on_frame(frame)

def on_frame(self, frame: RFCOMM_Frame) -> None:
handler = getattr(self, f'on_{frame.type_name()}_frame'.lower())
handler = getattr(self, f'on_{frame.type.name}_frame'.lower())
handler(frame)

def on_sabm_frame(self, _frame: RFCOMM_Frame) -> None:
Expand Down Expand Up @@ -772,10 +715,10 @@ def on_disc_frame(self, _frame: RFCOMM_Frame) -> None:
def on_uih_frame(self, frame: RFCOMM_Frame) -> None:
(mcc_type, c_r, value) = RFCOMM_Frame.parse_mcc(frame.information)

if mcc_type == RFCOMM_MCC_PN_TYPE:
if mcc_type == MccType.PN:
pn = RFCOMM_MCC_PN.from_bytes(value)
self.on_mcc_pn(c_r, pn)
elif mcc_type == RFCOMM_MCC_MSC_TYPE:
elif mcc_type == MccType.MSC:
mcs = RFCOMM_MCC_MSC.from_bytes(value)
self.on_mcc_msc(c_r, mcs)

Expand Down Expand Up @@ -871,7 +814,7 @@ async def open_dlc(
max_retransmissions=0,
window_size=window_size,
)
mcc = RFCOMM_Frame.make_mcc(mcc_type=RFCOMM_MCC_PN_TYPE, c_r=1, data=bytes(pn))
mcc = RFCOMM_Frame.make_mcc(mcc_type=MccType.PN, c_r=1, data=bytes(pn))
logger.debug(f'>>> Sending MCC: {pn}')
self.open_result = asyncio.get_running_loop().create_future()
self.change_state(Multiplexer.State.OPENING)
Expand Down
29 changes: 28 additions & 1 deletion tests/rfcomm_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,11 @@
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from bumble.rfcomm import RFCOMM_Frame
import asyncio
import pytest

from . import test_utils
from bumble.rfcomm import RFCOMM_Frame, Server, Client, DLC


# -----------------------------------------------------------------------------
Expand Down Expand Up @@ -43,6 +47,29 @@ def test_frames():
basic_frame_check(frame)


# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_basic_connection():
devices = test_utils.TwoDevices()
await devices.setup_connection()

accept_future: asyncio.Future[DLC] = asyncio.get_running_loop().create_future()
channel = Server(devices[0]).listen(acceptor=accept_future.set_result)

multiplexer = await Client(devices.connections[1]).start()
dlcs = await asyncio.gather(accept_future, multiplexer.open_dlc(channel))

queues = [asyncio.Queue(), asyncio.Queue()]
for dlc, queue in zip(dlcs, queues):
dlc.sink = queue.put_nowait

dlcs[0].write(b'The quick brown fox jumps over the lazy dog')
assert await queues[1].get() == b'The quick brown fox jumps over the lazy dog'

dlcs[1].write(b'Lorem ipsum dolor sit amet')
assert await queues[0].get() == b'Lorem ipsum dolor sit amet'


# -----------------------------------------------------------------------------
if __name__ == '__main__':
test_frames()