diff --git a/corsikaio/file.py b/corsikaio/file.py index 6d78d67..1a63c29 100644 --- a/corsikaio/file.py +++ b/corsikaio/file.py @@ -71,7 +71,10 @@ def run_end(self): return self._run_end def __next__(self): - block = next(self._block_iter) + try: + block = next(self._block_iter) + except StopIteration: + raise IOError("File seems to be truncated") if block[:4] == b'RUNE': self._run_end = parse_run_end(block) @@ -88,7 +91,11 @@ def __next__(self): data_bytes = bytearray() long_bytes = bytearray() - block = next(self._block_iter) + try: + block = next(self._block_iter) + except StopIteration: + raise IOError("File seems to be truncated") + while block[:4] != b'EVTE': if block[:4] == b'LONG': @@ -96,7 +103,10 @@ def __next__(self): else: data_bytes += block - block = next(self._block_iter) + try: + block = next(self._block_iter) + except StopIteration: + raise IOError("File seems to be truncated") if self.parse_blocks: event_end = parse_event_end(block)[0] diff --git a/corsikaio/io.py b/corsikaio/io.py index 81472a6..a060781 100644 --- a/corsikaio/io.py +++ b/corsikaio/io.py @@ -72,20 +72,31 @@ def iter_blocks(f): # for the fortran-chunked output, we need to read the record size if is_fortran_file: data = f.read(RECORD_MARKER.size) + if len(data) == 0: + return + if len(data) < RECORD_MARKER.size: raise IOError("Read less bytes than expected, file seems to be truncated") buffer_size, = RECORD_MARKER.unpack(data) data = f.read(buffer_size) + if is_fortran_file: + if len(data) < buffer_size: + raise IOError("Read less bytes than expected, file seems to be truncated") + + else: + if len(data) == 0: + return + + n_blocks, rest = divmod(len(data), BLOCK_SIZE_BYTES) + if rest != 0: + raise IOError("Read less bytes than expected, file seems to be truncated") - n_blocks = len(data) // BLOCK_SIZE_BYTES for block in range(n_blocks): start = block * BLOCK_SIZE_BYTES stop = start + BLOCK_SIZE_BYTES block = data[start:stop] - if len(block) < BLOCK_SIZE_BYTES: - raise 
IOError("Read less bytes than expected, file seems to be truncated") yield block # read trailing record marker diff --git a/tests/test_file.py b/tests/test_file.py index a998ffb..c695b14 100644 --- a/tests/test_file.py +++ b/tests/test_file.py @@ -1,6 +1,9 @@ import pytest import numpy as np +from corsikaio.constants import BLOCK_SIZE_BYTES +from corsikaio.io import RECORD_MARKER + def test_version(): from corsikaio import CorsikaFile @@ -88,7 +91,16 @@ def test_particle_no_parse(): -def test_truncated(tmp_path): +@pytest.mark.parametrize( + "size", + ( + RECORD_MARKER.size + 22932, + RECORD_MARKER.size + 2 * 22932, + RECORD_MARKER.size + 3 * 22932, + 2000, + ) +) +def test_truncated(tmp_path, size): '''Test we raise a meaningful error for a truncated file Truncated files might happen if corsika crashes or the disk is full. @@ -100,7 +112,7 @@ def test_truncated(tmp_path): with open("tests/resources/corsika757_particle", "rb") as f: with path.open("wb") as out: - out.write(f.read(273 * 10)) + out.write(f.read(size)) with pytest.raises(IOError, match="seems to be truncated"): with CorsikaParticleFile(path) as f: diff --git a/tests/test_io.py b/tests/test_io.py index 54a68df..676de3c 100644 --- a/tests/test_io.py +++ b/tests/test_io.py @@ -125,8 +125,20 @@ def test_iter_blocks_simple_file(dummy_file): assert block[:4] == b'RUNE' assert (np.frombuffer(block[4:], np.float32) == data[1:]).all() + with pytest.raises(StopIteration): + next(block_it) +def test_iter_blocks_all(dummy_file): + """Test for iterblocks for the case of no record markers""" + + with dummy_file.open('rb') as f: + n_blocks_read = 0 + for _ in iter_blocks(f): + n_blocks_read += 1 + + assert n_blocks_read == 27 + def test_versions(): from corsikaio.io import read_buffer_size, read_block from corsikaio.subblocks import get_version @@ -142,3 +154,17 @@ def test_versions(): block = read_block(f, buffer_size) assert get_version(block, EVTH_VERSION_POSITION) == version + + + 
+@pytest.mark.parametrize("size", (100, 1000, 5000))
+def test_iter_blocks_truncated(size, tmp_path, dummy_file):
+    """Iterating a copy of dummy_file cut off after ``size`` bytes raises IOError."""
+    truncated = tmp_path / f"test_truncated_{size}.dat"
+    with dummy_file.open("rb") as infile:
+        truncated.write_bytes(infile.read(size))
+
+    with pytest.raises(IOError, match="file seems to be truncated"):
+        with truncated.open("rb") as f:
+            for _ in iter_blocks(f):
+                pass