diff --git a/pslab/sciencelab.py b/pslab/sciencelab.py index 32b40fa9..82e780c5 100644 --- a/pslab/sciencelab.py +++ b/pslab/sciencelab.py @@ -32,8 +32,13 @@ class ScienceLab(SerialHandler): nrf : pslab.peripherals.NRF24L01 """ - def __init__(self): - super().__init__() + def __init__( + self, + port: str = None, + baudrate: int = 1000000, + timeout: float = 1.0, + ): + super().__init__(port, baudrate, timeout) self.logic_analyzer = LogicAnalyzer(device=self) self.oscilloscope = Oscilloscope(device=self) self.waveform_generator = WaveformGenerator(device=self) @@ -210,10 +215,10 @@ def _read_program_address(self, address: int): return data def _device_id(self): - a = self.read_program_address(0x800FF8) - b = self.read_program_address(0x800FFA) - c = self.read_program_address(0x800FFC) - d = self.read_program_address(0x800FFE) + a = self._read_program_address(0x800FF8) + b = self._read_program_address(0x800FFA) + c = self._read_program_address(0x800FFC) + d = self._read_program_address(0x800FFE) val = d | (c << 16) | (b << 32) | (a << 48) return val diff --git a/pslab/serial_handler.py b/pslab/serial_handler.py index 84c52fc4..18ac90ee 100755 --- a/pslab/serial_handler.py +++ b/pslab/serial_handler.py @@ -23,6 +23,40 @@ logger = logging.getLogger(__name__) +def detect(): + """Detect connected PSLab devices. + + Returns + ------- + devices : dict of str: str + Dictionary containing port name as keys and device version on that + port as values. + """ + regex = [] + + for vid, pid in zip(SerialHandler._USB_VID, SerialHandler._USB_PID): + regex.append(f"{vid:04x}:{pid:04x}") + + regex = "(" + "|".join(regex) + ")" + port_info_generator = list_ports.grep(regex) + pslab_devices = {} + + for port_info in port_info_generator: + version = _get_version(port_info.device) + if any(expected in version for expected in ["PSLab", "CSpark"]): + pslab_devices[port_info.device] = version + + return pslab_devices + + +def _get_version(port: str) -> str: + interface = serial.Serial(port=port, baudrate=1e6, timeout=1) + interface.write(CP.COMMON) + interface.write(CP.GET_VERSION) + version = interface.readline() + return version.decode("utf-8") + + class SerialHandler: """Provides methods for communicating with the PSLab hardware. @@ -98,9 +132,11 @@ def connect( Parameters ---------- port : str, optional - The name of the port to which the PSLab is connected as a string. On - Posix this is a path, e.g. "/dev/ttyACM0". On Windows, it's a numbered - COM port, e.g. "COM5". Will be autodetected if not specified. + The name of the port to which the PSLab is connected as a string. + On Posix this is a path, e.g. "/dev/ttyACM0". On Windows, it's a + numbered COM port, e.g. "COM5". Will be autodetected if not + specified. If multiple PSLab devices are connected, port must be + specified. baudrate : int, optional Symbol rate in bit/s. The default value is 1000000. timeout : float, optional @@ -111,6 +147,8 @@ def connect( ------ SerialException If connection could not be established. + RuntimeError + If ultiple devices are connected and no port was specified. """ # serial.Serial opens automatically if port is not None. self.interface = serial.Serial( @@ -119,28 +157,31 @@ def connect( timeout=timeout, write_timeout=timeout, ) + pslab_devices = detect() if self.interface.is_open: # User specified a port. version = self.get_version() else: - regex = [] - for vid, pid in zip(self._USB_VID, self._USB_PID): - regex.append(f"{vid:04x}:{pid:04x}") - - regex = "(" + "|".join(regex) + ")" - port_info_generator = list_ports.grep(regex) - - for port_info in port_info_generator: - self.interface.port = port_info.device + if len(pslab_devices) == 1: + self.interface.port = list(pslab_devices.keys())[0] self.interface.open() version = self.get_version() - if any(expected in version for expected in ["PSLab", "CSpark"]): - break + elif len(pslab_devices) > 1: + found = "" + + for port, version in pslab_devices.items(): + found += f"{port}: {version}" + + raise RuntimeError( + "Multiple PSLab devices found:\n" + f"{found}" + "Please choose a device by specifying a port." + ) else: version = "" - if any(expected in version for expected in ["PSLab", "CSpark"]): + if self.interface.port in pslab_devices: self.version = version logger.info(f"Connected to {self.version} on {self.interface.port}.") else: @@ -174,13 +215,11 @@ def reconnect( port = self.interface.port if port is None else port timeout = self.interface.timeout if timeout is None else timeout - self.interface = serial.Serial( + self.connect( port=port, baudrate=baudrate, timeout=timeout, - write_timeout=timeout, ) - self.connect() def get_version(self) -> str: """Query PSLab for its version and return it as a decoded string. diff --git a/tests/test_serial_handler.py b/tests/test_serial_handler.py index 84a1f93d..7c391d65 100644 --- a/tests/test_serial_handler.py +++ b/tests/test_serial_handler.py @@ -3,15 +3,19 @@ from serial.tools.list_ports_common import ListPortInfo import pslab.protocol as CP -from pslab.serial_handler import SerialHandler +from pslab.serial_handler import detect, SerialHandler VERSION = "PSLab vMOCK\n" PORT = "mock_port" +PORT2 = "mock_port_2" -def mock_ListPortInfo(found=True): +def mock_ListPortInfo(found=True, multiple=False): if found: - yield ListPortInfo(device=PORT) + if multiple: + yield from [ListPortInfo(device=PORT), ListPortInfo(device=PORT2)] + else: + yield ListPortInfo(device=PORT) else: return @@ -20,12 +24,14 @@ def mock_ListPortInfo(found=True): def mock_serial(mocker): serial_patch = mocker.patch("pslab.serial_handler.serial.Serial") serial_patch().readline.return_value = VERSION.encode() + serial_patch().is_open = False return serial_patch @pytest.fixture -def mock_handler(mocker, mock_serial): +def mock_handler(mocker, mock_serial, mock_list_ports): mocker.patch("pslab.serial_handler.SerialHandler._check_udev") + mock_list_ports.grep.return_value = mock_ListPortInfo() return SerialHandler() @@ -34,8 +40,12 @@ def mock_list_ports(mocker): return mocker.patch("pslab.serial_handler.list_ports") +def test_detect(mocker, mock_serial, mock_list_ports): + mock_list_ports.grep.return_value = mock_ListPortInfo(multiple=True) + assert len(detect()) == 2 + + def test_connect_scan_port(mocker, mock_serial, mock_list_ports): - mock_serial().is_open = False mock_list_ports.grep.return_value = mock_ListPortInfo() mocker.patch("pslab.serial_handler.SerialHandler._check_udev") SerialHandler() @@ -43,19 +53,26 @@ def test_connect_scan_port(mocker, mock_serial, mock_list_ports): def test_connect_scan_failure(mocker, mock_serial, mock_list_ports): - mock_serial().is_open = False mock_list_ports.grep.return_value = mock_ListPortInfo(found=False) mocker.patch("pslab.serial_handler.SerialHandler._check_udev") with pytest.raises(SerialException): SerialHandler() +def test_connect_multiple_connected(mocker, mock_serial, mock_list_ports): + mock_list_ports.grep.return_value = mock_ListPortInfo(multiple=True) + mocker.patch("pslab.serial_handler.SerialHandler._check_udev") + with pytest.raises(RuntimeError): + SerialHandler() + + def test_disconnect(mock_serial, mock_handler): mock_handler.disconnect() mock_serial().close.assert_called() -def test_reconnect(mock_serial, mock_handler): +def test_reconnect(mock_serial, mock_handler, mock_list_ports): + mock_list_ports.grep.return_value = mock_ListPortInfo() mock_handler.reconnect() mock_serial().close.assert_called() @@ -67,10 +84,9 @@ def test_get_version(mock_serial, mock_handler): def test_get_ack_success(mock_serial, mock_handler): - H = SerialHandler() success = 1 mock_serial().read.return_value = CP.Byte.pack(success) - assert H.get_ack() == success + assert mock_handler.get_ack() == success def test_get_ack_failure(mock_serial, mock_handler):