diff --git a/labgrid/driver/consoleexpectmixin.py b/labgrid/driver/consoleexpectmixin.py index 7cdc0c0ce..536b12227 100644 --- a/labgrid/driver/consoleexpectmixin.py +++ b/labgrid/driver/consoleexpectmixin.py @@ -20,10 +20,14 @@ def __attrs_post_init__(self): @Driver.check_active @step(result=True, tag='console') - def read(self, size=1, timeout=0.0): - res = self._read(size=size, timeout=timeout) - self.logger.debug("Read %i bytes: %s, timeout %.2f, requested size %i", - len(res), res, timeout, size) + def read(self, size=1, timeout=0.0, max_size=None): + res = self._read(size=size, timeout=timeout, max_size=max_size) + if max_size: + self.logger.debug("Read %i bytes: %s, timeout %.2f, requested size %i, max size %i", + len(res), res, timeout, size, max_size) + else: + self.logger.debug("Read %i bytes: %s, timeout %.2f, requested size %i", + len(res), res, timeout, size) return res @Driver.check_active diff --git a/labgrid/driver/externalconsoledriver.py b/labgrid/driver/externalconsoledriver.py index 1177e2244..ad9619fed 100644 --- a/labgrid/driver/externalconsoledriver.py +++ b/labgrid/driver/externalconsoledriver.py @@ -70,17 +70,23 @@ def close(self): self._child = None self.status = 0 - def _read(self, size: int = 1024, timeout: int = 0): + def _read(self, size: int = 1024, timeout: int = 0, max_size: int = None): """ Reads 'size' bytes from the serialport Keyword Arguments: size -- amount of bytes to read, defaults to 1024 + max_size -- maximal amount of bytes to read """ + if max_size: + read_size = min(size, max_size) + else: + read_size = size + if self._child.poll() is not None: raise ExecutionError("child has vanished") if self._poll.poll(timeout): - return self._child.stdout.read(size) + return self._child.stdout.read(read_size) return b'' diff --git a/labgrid/driver/serialdriver.py b/labgrid/driver/serialdriver.py index 9f3f45d4a..7f3360db8 100644 --- a/labgrid/driver/serialdriver.py +++ b/labgrid/driver/serialdriver.py @@ -83,14 +83,19 @@ def get_export_vars(self): vars["protocol"] = self.port.protocol return vars - def _read(self, size: int = 1, timeout: float = 0.0): + def _read(self, size: int = 1, timeout: float = 0.0, max_size: int = None): """ Reads 'size' or more bytes from the serialport Keyword Arguments: size -- amount of bytes to read, defaults to 1 + max_size -- maximal amount of bytes to read, values 'None' or '0' do not restrict the read + length, defaults to None + if size == max_size: read and return exactly size = max_size bytes """ reading = max(size, self.serial.in_waiting) + if max_size: # limit reading to max_size if provided + reading = min(reading, max_size) self.serial.timeout = timeout res = self.serial.read(reading) if not res: diff --git a/tests/test_externalconsoledriver.py b/tests/test_externalconsoledriver.py index 4878ed60d..e67ce0e76 100644 --- a/tests/test_externalconsoledriver.py +++ b/tests/test_externalconsoledriver.py @@ -14,5 +14,8 @@ def test_communicate(self, target): target.activate(d) d.write(data) time.sleep(0.1) - assert d.read(1024) == data + assert d.read(1024) == data # assert written data is read + d.write(data) + time.sleep(0.1) + assert d.read(5, max_size=5) == data[:5] # assert max_size limits read bytes d.close() diff --git a/tests/test_serialdriver.py b/tests/test_serialdriver.py index c5ee88773..d9b7d0387 100644 --- a/tests/test_serialdriver.py +++ b/tests/test_serialdriver.py @@ -23,14 +23,18 @@ def test_write(self, target, serial_port, mocker): serial_mock.return_value.open.assert_called_once_with() serial_mock.return_value.write.assert_called_once_with(b"testdata") - def test_read(self, target, serial_port, mocker): + @pytest.mark.parametrize("param", [[1, 1, None, 1], # old test case + [3, 2, None, 3], [3, 2, 5, 3], [3, 2, 1, 1], [1, 2, 1, 1]]) + # param = [size, in_waiting, max_size, out] + def test_read(self, target, serial_port, mocker, param): serial_mock = mocker.patch('serial.Serial') - serial_mock.return_value.in_waiting = 0 + serial_mock.return_value.in_waiting = param[1] s = SerialDriver(target, "serial") target.activate(s) - s.read() + s.read(size=param[0], max_size=param[2]) serial_mock.return_value.open.assert_called_once_with() - serial_mock.return_value.read.assert_called_once_with(1) + # assert 'read' called once with correct return: + serial_mock.return_value.read.assert_called_once_with(param[3]) def test_close(self, target, serial_port, mocker): serial_mock = mocker.patch('serial.Serial')