diff --git a/corsikaio/io.py b/corsikaio/io.py index 4490883..a060781 100644 --- a/corsikaio/io.py +++ b/corsikaio/io.py @@ -81,19 +81,22 @@ def iter_blocks(f): buffer_size, = RECORD_MARKER.unpack(data) data = f.read(buffer_size) - if len(data) == 0: - if is_fortran_file: + if is_fortran_file: + if len(data) < buffer_size: raise IOError("Read less bytes than expected, file seems to be truncated") - else: + + else: + if len(data) == 0: return - n_blocks = len(data) // BLOCK_SIZE_BYTES + n_blocks, rest = divmod(len(data), BLOCK_SIZE_BYTES) + if rest != 0: + raise IOError("Read less bytes than expected, file seems to be truncated") + for block in range(n_blocks): start = block * BLOCK_SIZE_BYTES stop = start + BLOCK_SIZE_BYTES block = data[start:stop] - if len(block) < BLOCK_SIZE_BYTES: - raise IOError("Read less bytes than expected, file seems to be truncated") yield block # read trailing record marker diff --git a/tests/test_io.py b/tests/test_io.py index a235d13..676de3c 100644 --- a/tests/test_io.py +++ b/tests/test_io.py @@ -154,3 +154,17 @@ def test_versions(): block = read_block(f, buffer_size) assert get_version(block, EVTH_VERSION_POSITION) == version + + + +@pytest.mark.parametrize("size", (100, 1000, 5000)) +def test_iter_blocks_truncated(size, tmp_path, dummy_file): + path = tmp_path / f"test_truncated_{size}.dat" + + with path.open("wb") as out, dummy_file.open("rb") as infile: + out.write(infile.read(size)) + + with pytest.raises(IOError, match="file seems to be truncated"): + with path.open("rb") as f: + for _ in iter_blocks(f): + pass