Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
<p align="center">
<img src="docs/source/_static/bpod-core.svg" />
<img src="https://raw.githubusercontent.com/int-brain-lab/bpod-core/refs/heads/main/docs/source/_static/bpod-core.svg" />
</p>

# bpod-core
Expand Down
10 changes: 9 additions & 1 deletion bpod_core/bpod.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,20 +289,28 @@ class Bpod(AbstractBpod):
_zmq_service: DualChannelHost
_next_fsm_index: int = -1
_serial_buffer = bytearray() # buffer for TrialReader thread

serial0: ExtendedSerial
"""Primary serial device for communication with the Bpod."""

serial1: ExtendedSerial | None = None
"""Secondary serial device for communication with the Bpod."""

serial2: ExtendedSerial | None = None
"""Tertiary serial device for communication with the Bpod - used by Bpod 2+ only."""

inputs: NamedTuple
"""Available input channels."""

outputs: NamedTuple
"""Available output channels."""

modules: NamedTuple
"""Available modules."""

event_names: list[str]
"""List of event names."""

actions: list[str]
"""List of output actions."""

Expand Down Expand Up @@ -1390,7 +1398,7 @@ def read(self) -> bool:
bool
True if the input channel is active, False otherwise.
"""
return self._serial0.verify([b'I', self.index])
return self._serial0.verify(struct.pack('<cB', b'I', self.index))

def override(self, state: bool) -> None:
"""
Expand Down
110 changes: 18 additions & 92 deletions bpod_core/com.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,56 +2,18 @@

import logging
import struct
from collections.abc import Iterable
from typing import Any, TypeAlias
from typing import Any

import numpy as np
from serial import Serial
from serial.serialutil import to_bytes as serial_to_bytes # type: ignore[attr-defined]
from serial.threaded import Protocol
from typing_extensions import Buffer, Self

logger = logging.getLogger(__name__)

ByteLike: TypeAlias = (
Buffer | int | np.ndarray | np.generic | str | Iterable['ByteLike']
)
"""
A recursive type alias representing any data that can be converted to bytes for serial
communication.

Includes:

- Buffer: Any buffer-compatible object (e.g., bytes, bytearray, memoryview)
- int: Single integer values (interpreted as a single byte)
- np.ndarray, np.generic: NumPy arrays and scalars (converted via .tobytes())
- str: Strings (encoded as UTF-8)
- Iterable['ByteLike']: Nested iterables of ByteLike types (recursively flattened)
"""


class ExtendedSerial(Serial):
"""Enhances :class:`serial.Serial` with additional functionality."""

def write(self, data: ByteLike) -> int | None: # type: ignore[override]
"""
Write data to the serial port.

This method extends :meth:`serial.Serial.write` with support for NumPy types,
unsigned 8-bit integers, strings (interpreted as UTF-8) and iterables.

Parameters
----------
data : ByteLike
Data to be written to the serial port.

Returns
-------
int or None
Number of bytes written to the serial port.
"""
return super().write(to_bytes(data))

def write_struct(self, format_string: str, *data: Any) -> int | None: # noqa:ANN401
"""
Write structured data to the serial port.
Expand All @@ -74,9 +36,16 @@ def write_struct(self, format_string: str, *data: Any) -> int | None: # noqa:AN
int | None
The number of bytes written to the serial port, or None if the write
operation fails.

Raises
------
struct.error
Error occurred during packing of the data into binary format.
serial.SerialTimeoutException
In case a write timeout is configured for the port and the time is exceeded.
"""
buffer = struct.pack(format_string, *data)
return super().write(buffer)
return self.write(buffer)

def read_struct(self, format_string: str) -> tuple[Any, ...]:
"""
Expand All @@ -102,15 +71,16 @@ def read_struct(self, format_string: str) -> tuple[Any, ...]:
n_bytes = struct.calcsize(format_string)
return struct.unpack(format_string, super().read(n_bytes))

def query(self, query: ByteLike, size: int = 1) -> bytes:
def query(self, query: Buffer, size: int = 1) -> bytes:
r"""
Query data from the serial port.

This method is a combination of :meth:`write` and :meth:`~serial.Serial.read`.
This method is a combination of :meth:`~serial.Serial.write` and
:meth:`~serial.Serial.read`.

Parameters
----------
query : ByteLike
query : Buffer
Query to be sent to the serial port.
size : int, default: 1
The number of bytes to receive from the serial port.
Expand All @@ -125,7 +95,7 @@ def query(self, query: ByteLike, size: int = 1) -> bytes:

def query_struct(
self,
query: ByteLike,
query: Buffer,
format_string: str,
) -> tuple[Any, ...]:
"""
Expand All @@ -136,7 +106,7 @@ def query_struct(

Parameters
----------
query : ByteLike
query : Buffer
Query to be sent to the serial port.
format_string : str
A format string that specifies the layout of the data to be read. It should
Expand All @@ -153,7 +123,7 @@ def query_struct(
self.write(query)
return self.read_struct(format_string)

def verify(self, query: ByteLike, expected_response: bytes = b'\x01') -> bool:
def verify(self, query: Buffer = b'', expected_response: bytes = b'\x01') -> bool:
r"""
Verify the response of the serial port.

Expand All @@ -162,8 +132,8 @@ def verify(self, query: ByteLike, expected_response: bytes = b'\x01') -> bool:

Parameters
----------
query : ByteLike
The query to be sent to the serial port.
query : Buffer, optional
The query to be sent to the serial port. Defaults to an empty byte string.
expected_response : bytes, optional
The expected response from the serial port. Default: b'\x01'.

Expand Down Expand Up @@ -271,47 +241,3 @@ def process(self, data_chunk: bytearray) -> None:
data_chunk : bytearray
A contiguous slice of bytes of length `chunk_size`.
"""


def to_bytes(data: ByteLike) -> bytes: # noqa: PLR0911
"""
Convert data to a bytes object.

This function extends :func:`serial.to_bytes` with support for:
- NumPy arrays and scalars
- Unsigned 8-bit integers
- Strings (encoded as UTF-8)
- Arbitrary iterables of ByteLike

Parameters
----------
data : ByteLike
Data to be converted to a bytes object.

Returns
-------
bytes
Data converted to bytes.

Raises
------
TypeError
If the input type cannot be interpreted as bytes
ValueError
If an integer is out of the 0..255 range when coerced to a single byte.
"""
match data:
case bytes():
return data
case bytearray():
return bytes(data)
case memoryview() | np.ndarray() | np.generic():
return data.tobytes()
case int():
return bytes([data])
case str():
return data.encode('utf-8')
case _ if isinstance(data, Iterable):
return b''.join(to_bytes(item) for item in data)
case _:
return serial_to_bytes(data) # type: ignore[no-any-return]
Loading
Loading