Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with HTTPS or Subversion.

Download ZIP
Browse files

Update compressed data check in response to code review

  • Loading branch information...
commit 4e252d844f67e32f21f2bbd1a96bed902c8599ce 1 parent e6f3ace
Max Bolingbroke authored DSG User committed
6 scipy/io/matlab/mio.py
View
@@ -99,6 +99,12 @@ def loadmat(file_name, mdict=None, appendmat=True, **kwargs):
False replicates the behavior of scipy version 0.7.x (returning
numpy object arrays). The default setting is True, because it
allows easier round-trip load and save of MATLAB files.
+ verify_compressed_data_integrity : bool, optional
+ Whether the length of compressed sequences in the MATLAB file
+ should be checked, to ensure that they are not longer than we expect.
+ It is advisable to enable this (the default) because overlong
+ compressed sequences in MATLAB files generally indicate that the
+ files have experienced some sort of corruption.
variable_names : None or sequence
If None (the default) - read all variables in file. Otherwise
`variable_names` should be a sequence of strings, giving names of the
6 scipy/io/matlab/mio5.py
View
@@ -136,6 +136,7 @@ def __init__(self,
chars_as_strings=True,
matlab_compatible=False,
struct_as_record=True,
+ verify_compressed_data_integrity=True,
uint16_codec=None
):
'''Initializer for matlab 5 file format reader
@@ -154,7 +155,8 @@ def __init__(self,
squeeze_me,
chars_as_strings,
matlab_compatible,
- struct_as_record
+ struct_as_record,
+ verify_compressed_data_integrity
)
# Set uint16 codec
if not uint16_codec:
@@ -220,7 +222,7 @@ def read_var_header(self):
# Make new stream from compressed data
stream = ZlibInputStream(self.mat_stream, byte_count)
self._matrix_reader.set_stream(stream)
- check_stream_limit = True
+ check_stream_limit = self.verify_compressed_data_integrity
mdtype, byte_count = self._matrix_reader.read_full_tag()
else:
check_stream_limit = False
15 scipy/io/matlab/mio5_utils.pyx
View
@@ -516,6 +516,13 @@ cdef class VarReader5:
''' Return matrix header for current stream position
Returns matrix headers at top level and sub levels
+
+ Parameters
+ ----------
+ check_stream_limit : if True, then if the returned header
+ is passed to array_from_header, it will be verified that
+ the length of the uncompressed data is not overlong (which
+ can indicate .mat file corruption)
'''
cdef:
cdef cnp.uint32_t u4s[2]
@@ -612,7 +619,7 @@ cdef class VarReader5:
return np.array([])
else:
return np.array([[]])
- header = self.read_header(0)
+ header = self.read_header(False)
return self.array_from_header(header, process)
cpdef array_from_header(self, VarHeader5 header, int process=1):
@@ -655,7 +662,7 @@ cdef class VarReader5:
elif mc == mxSPARSE_CLASS:
arr = self.read_sparse(header)
# no current processing makes sense for sparse
- process = 0
+ process = False
elif mc == mxCHAR_CLASS:
arr = self.read_char(header)
if process and self.chars_as_strings:
@@ -680,7 +687,9 @@ cdef class VarReader5:
process = 0
if header.check_stream_limit:
if not self.cstream.all_data_read():
- raise ValueError('Did not fully consume compressed contents of an miCOMPRESSED element. This can indicate that the .mat file is corrupted.')
+ raise ValueError('Did not fully consume compressed contents' +
+ ' of an miCOMPRESSED element. This can' +
+ ' indicate that the .mat file is corrupted.')
if process and self.squeeze_me:
return squeeze_element(arr)
return arr
4 scipy/io/matlab/miobase.py
View
@@ -344,7 +344,8 @@ def __init__(self, mat_stream,
squeeze_me=False,
chars_as_strings=True,
matlab_compatible=False,
- struct_as_record=True
+ struct_as_record=True,
+ verify_compressed_data_integrity=True
):
'''
Initializer for mat file reader
@@ -368,6 +369,7 @@ def __init__(self, mat_stream,
self.squeeze_me = squeeze_me
self.chars_as_strings = chars_as_strings
self.mat_dtype = mat_dtype
+ self.verify_compressed_data_integrity = verify_compressed_data_integrity
def set_matlab_compatible(self):
''' Sets options to return arrays as MATLAB loads them '''
12 scipy/io/matlab/streams.pyx
View
@@ -106,9 +106,8 @@ cdef class ZlibInputStream(GenericStream):
----------
stream : file-like
Stream to read compressed data from.
- max_length : int, optional
+ max_length : int
Maximum number of bytes to read from the stream.
- -1 if the length is unlimited.
Notes
-----
@@ -127,7 +126,7 @@ cdef class ZlibInputStream(GenericStream):
cdef size_t _total_position
cdef size_t _read_bytes
- def __init__(self, fobj, ssize_t max_length=-1):
+ def __init__(self, fobj, ssize_t max_length):
self.fobj = fobj
self._max_length = max_length
@@ -145,9 +144,7 @@ cdef class ZlibInputStream(GenericStream):
if self._buffer_position < self._buffer_size:
return
- read_size = BLOCK_SIZE
- if self._max_length >= 0:
- read_size = min(read_size, self._max_length - self._read_bytes)
+ read_size = min(BLOCK_SIZE, self._max_length - self._read_bytes)
block = self.fobj.read(read_size)
self._read_bytes += len(block)
@@ -200,7 +197,8 @@ cdef class ZlibInputStream(GenericStream):
return self.read_string(n_bytes, &p)
cpdef int all_data_read(self):
- return (self._max_length == self._read_bytes) and (self._buffer_size == self._buffer_position)
+ return (self._max_length == self._read_bytes) and \
+ (self._buffer_size == self._buffer_position)
cpdef long int tell(self):
return self._total_position
5 scipy/io/matlab/tests/test_mio.py
View
@@ -800,12 +800,9 @@ def test_empty_string():
def test_corrupted_data():
import zlib
for exc, fname in [(ValueError, 'corrupted_zlib_data.mat'), (zlib.error, 'corrupted_zlib_checksum.mat')]:
- fp = open(pjoin(test_data_path, fname), 'rb')
- try:
+ with open(pjoin(test_data_path, fname), 'rb') as fp:
rdr = MatFile5Reader(fp)
assert_raises(exc, rdr.get_variables)
- finally:
- fp.close()
def test_read_both_endian():
23 scipy/io/matlab/tests/test_streams.py
View
@@ -105,8 +105,9 @@ def test_read():
class TestZlibInputStream(object):
def _get_data(self, size):
data = np.random.randint(0, 256, size).astype(np.uint8).tostring()
- stream = BytesIO(zlib.compress(data))
- return stream, data
+ compressed_data = zlib.compress(data)
+ stream = BytesIO(compressed_data)
+ return stream, len(compressed_data), data
def test_read(self):
block_size = 131072
@@ -118,8 +119,8 @@ def test_read(self):
block_size, block_size+1]
def check(size, read_size):
- compressed_stream, data = self._get_data(size)
- stream = ZlibInputStream(compressed_stream)
+ compressed_stream, compressed_data_len, data = self._get_data(size)
+ stream = ZlibInputStream(compressed_stream, compressed_data_len)
data2 = b''
so_far = 0
while True:
@@ -148,9 +149,9 @@ def test_read_max_length(self):
assert_raises(IOError, stream.read, 1)
def test_seek(self):
- compressed_stream, data = self._get_data(1024)
+ compressed_stream, compressed_data_len, data = self._get_data(1024)
- stream = ZlibInputStream(compressed_stream)
+ stream = ZlibInputStream(compressed_stream, compressed_data_len)
stream.seek(123)
p = 123
@@ -177,5 +178,15 @@ def test_seek(self):
stream.seek(10000, 1)
assert_raises(IOError, stream.read, 12)
+ def test_all_data_read(self):
+ compressed_stream, compressed_data_len, data = self._get_data(1024)
+ stream = ZlibInputStream(compressed_stream, compressed_data_len)
+ assert_false(stream.all_data_read())
+ stream.seek(512)
+ assert_false(stream.all_data_read())
+ stream.seek(1024)
+ assert_true(stream.all_data_read())
+
+
if __name__ == "__main__":
run_module_suite()
Please sign in to comment.
Something went wrong with that request. Please try again.