-
Notifications
You must be signed in to change notification settings - Fork 52
/
io.py
383 lines (317 loc) · 11.9 KB
/
io.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
# -*- coding: utf-8 -*-
# Copyright (c) oct2py developers.
# Distributed under the terms of the MIT License.
from __future__ import absolute_import, print_function, division
import inspect
import dis
import threading
import numpy as np
from scipy.io import loadmat, savemat
from scipy.io.matlab.mio5 import MatlabObject, MatlabFunction
from scipy.sparse import spmatrix
from .compat import PY2
from .dynamic import OctaveVariablePtr, OctaveUserClass, OctaveFunctionPtr
from .utils import Oct2PyError
# Module-wide lock that write_file holds around scipy.io.savemat, which is
# not thread-safe (https://github.com/scipy/scipy/issues/7260).
_WRITE_LOCK = threading.Lock()
def read_file(path, session=None):
    """Load the MAT-file at *path* and convert its contents for Python.

    Parameters
    ----------
    path : str
        Location of the file to read.
    session : optional
        Octave session handed through to the value converter (it is used to
        rebuild user-defined classes).

    Returns
    -------
    dict
        Mapping of each key returned by ``loadmat`` to its converted value.

    Raises
    ------
    Oct2PyError
        If the file contents cannot be decoded.
    """
    try:
        raw = loadmat(path, struct_as_record=True)
    except UnicodeDecodeError as err:
        # Surface decoding problems as the package's own error type.
        raise Oct2PyError(str(err))
    return {name: _extract(entry, session) for name, entry in raw.items()}
def write_file(obj, path, oned_as='row', convert_to_float=True):
    """Write a Python object to a MAT-file that Octave can load.

    Parameters
    ----------
    obj : object
        Value to save; it is converted with ``_encode`` first.
    path : str
        Destination file.  It is used as given (``appendmat=False``), so no
        ``.mat`` suffix is appended.
    oned_as : str
        Forwarded to ``savemat`` as its ``oned_as`` option; defaults to
        ``'row'``.
    convert_to_float : bool
        Forwarded to ``_encode``; when true, integer arrays are sent as
        float64.

    Raises
    ------
    Exception
        If ``savemat`` fails with a KeyError.
    """
    encoded = _encode(obj, convert_to_float)
    try:
        # scipy.io.savemat is not thread-safe, so calls are serialised.
        # See https://github.com/scipy/scipy/issues/7260
        with _WRITE_LOCK:
            savemat(path, encoded, appendmat=False, oned_as=oned_as,
                    long_field_names=True)
    except KeyError:  # pragma: no cover
        raise Exception('could not save mat file')
class Struct(dict):
    """
    Octave style struct, enhanced.

    Notes
    =====
    Supports dictionary and attribute style access. Can be pickled,
    and supports code completion in a REPL.

    A missing key may be created on demand as an empty nested Struct:
    ``__getitem__`` inspects the calling code's bytecode and, when the next
    opcode is one of those accepted by ``_is_allowed``, inserts a new
    Struct under that key.  This is what lets ``a.c["d"] = 'eggs'`` work in
    the example below without first creating ``a.c``.

    Examples
    ========
    >>> from pprint import pprint
    >>> from oct2py import Struct
    >>> a = Struct()
    >>> a.b = 'spam' # a["b"] == 'spam'
    >>> a.c["d"] = 'eggs' # a.c.d == 'eggs'
    >>> pprint(a)
    {'b': 'spam', 'c': {'d': 'eggs'}}
    """
    def __getattr__(self, attr):
        # Access the dictionary keys for unknown attributes.
        # A missing key is reported as AttributeError, as attribute access
        # normally does.
        # NOTE(review): __getitem__ may run _is_allowed on this method's own
        # frame (frame.f_back when reached from here), so changing these
        # statements can change which keys get auto-created.
        try:
            return self[attr]
        except KeyError:
            msg = "'Struct' object has no attribute %s" % attr
            raise AttributeError(msg)
    def __getitem__(self, attr):
        # Get a dict value; create a Struct if requesting a Struct member.
        # Do not create a key if the attribute starts with an underscore.
        # Existing keys and underscore names go straight to dict lookup, so
        # a missing underscore name raises KeyError.
        # NOTE(review): a non-str key that is not already present fails on
        # attr.startswith with AttributeError rather than KeyError.
        if attr in self.keys() or attr.startswith('_'):
            return dict.__getitem__(self, attr)
        # NOTE(review): the frame reference is never deleted; the inspect
        # docs recommend `del frame` to avoid reference cycles.
        frame = inspect.currentframe()
        # step into the function that called us
        # frame.f_back is the code that invoked __getitem__ (the user's code,
        # or __getattr__ above); frame.f_back.f_back is one level further up
        # (the user's code when reached through __getattr__).  The outer
        # frame is checked first, then the direct caller.
        if frame.f_back.f_back and self._is_allowed(frame.f_back.f_back):
            dict.__setitem__(self, attr, Struct())
        elif self._is_allowed(frame.f_back):
            dict.__setitem__(self, attr, Struct())
        # Raises KeyError if neither check inserted the key.
        return dict.__getitem__(self, attr)
    def _is_allowed(self, frame):
        # Check for allowed op code in the calling frame.
        # Reads the byte 3 positions past frame.f_lasti in the frame's code
        # and returns True if it equals the STORE_ATTR or LOAD_CONST opcode,
        # or STOP_CODE (0 when this interpreter has no STOP_CODE opcode).
        # NOTE(review): the +3 offset fits the older bytecode layout, where
        # the current instruction (opcode + 2-byte argument) is 3 bytes long
        # and +3 is the next opcode.  On 2-byte "wordcode" interpreters it
        # appears to land on an argument byte instead, and it can run past
        # the end of co_code -- verify per supported Python version.
        allowed = [dis.opmap['STORE_ATTR'], dis.opmap['LOAD_CONST'],
                   dis.opmap.get('STOP_CODE', 0)]
        bytecode = frame.f_code.co_code
        instruction = bytecode[frame.f_lasti + 3]
        # Indexing a byte string yields a 1-char str on Python 2, an int on 3.
        instruction = ord(instruction) if PY2 else instruction
        return instruction in allowed
    # Attribute assignment and deletion operate directly on the dict keys.
    # NOTE(review): deleting a missing attribute therefore raises KeyError.
    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__
    @property
    def __dict__(self):
        # Allow for code completion in a REPL.
        # Returns a shallow copy of the stored keys and values.
        return self.copy()
class StructArray(np.recarray):
    """A Python representation of an Octave structure array.

    Notes
    =====
    Accessing a field (as an attribute or by name) returns a Cell when the
    field holds Python objects; all other results are returned unchanged.

    This class is not meant to be directly created by the user. It is
    created automatically for structure array values received from Octave.
    The last axis is squeezed if it is of size 1 to simplify element access.

    Examples
    ========
    >>> from oct2py import octave
    >>> # generate the struct array
    >>> octave.eval('x = struct("y", {1, 2}, "z", {3, 4});')
    >>> x = octave.pull('x')
    >>> x.y # attribute access -> oct2py Cell
    Cell([[1.0, 2.0]])
    >>> x['z'] # item access -> oct2py Cell
    Cell([[3.0, 4.0]])
    >>> x[0, 0] # index access -> numpy record
    (1.0, 3.0)
    >>> x[0, 1].z
    4.0
    """

    def __new__(cls, value, session=None):
        """Create a struct array from a value and optional Octave session.

        Parameters
        ----------
        value : array_like
            Structured (record) data, one field per struct field.
        session : optional
            When truthy, every field of every element is converted with
            ``_extract`` using this session.  Otherwise the data is only
            reshaped and viewed as a StructArray.
        """
        value = np.asarray(value)
        # Squeeze the last axis if it has length 1.  0-d input has no last
        # axis; it is promoted to 1-d below instead of raising IndexError.
        if value.ndim and value.shape[-1] == 1:
            value = value.squeeze(axis=-1)
        value = np.atleast_1d(value)
        if not session:
            return value.view(cls)
        # Extract the values.  obj[i] is a record that writes through to obj.
        obj = np.empty(value.size, dtype=value.dtype).view(cls)
        for (i, item) in enumerate(value.ravel()):
            for name in value.dtype.names:
                obj[i][name] = _extract(item[name], session)
        return obj.reshape(value.shape)

    @property
    def fieldnames(self):
        """The field names of the struct array."""
        return self.dtype.names

    def __getattribute__(self, attr):
        """Return object arrays as cells and all other values unchanged.
        """
        attr = np.recarray.__getattribute__(self, attr)
        if isinstance(attr, np.ndarray) and attr.dtype.kind == 'O':
            return Cell(attr)
        return attr

    def __getitem__(self, item):
        """Return object arrays as cells and all other values unchanged.
        """
        item = np.recarray.__getitem__(self, item)
        if isinstance(item, np.ndarray) and item.dtype.kind == 'O':
            return Cell(item)
        return item

    def __repr__(self):
        """Summarise the shape (Octave style, at least 2-D) and field names."""
        shape = self.shape
        if len(shape) == 1:
            shape = (shape[0], 1)
        msg = 'x'.join(str(i) for i in shape)
        msg += ' StructArray containing the fields:'
        for key in self.fieldnames:
            msg += '\n %s' % key
        return msg
class Cell(np.ndarray):
    """A Python representation of an Octave cell array.

    Notes
    =====
    This class is not meant to be directly created by the user. It is
    created automatically for cell array values received from Octave.
    The last axis is squeezed if it is of size 1 to simplify element access.

    Examples
    ========
    >>> from oct2py import octave
    >>> # generate the cell array
    >>> octave.eval("x = cell(2,2); x(:) = 1.0;")
    >>> x = octave.pull('x')
    >>> x
    Cell([[1.0, 1.0],
           [1.0, 1.0]])
    >>> x[0]
    Cell([1.0, 1.0])
    >>> x[0].tolist()
    [1.0, 1.0]
    """

    def __new__(cls, value, session=None):
        """Create a cell array from a value and optional Octave session.

        Parameters
        ----------
        value : array_like
            Cell contents; converted to an object array.
        session : optional
            When truthy, every element is converted with ``_extract`` using
            this session.  Otherwise the elements are kept as they are.
        """
        value = np.asarray(value, dtype=object)
        # Squeeze the last axis if it has length 1.  0-d input has no last
        # axis; it is promoted to 1-d below instead of raising IndexError.
        if value.ndim and value.shape[-1] == 1:
            value = value.squeeze(axis=-1)
        value = np.atleast_1d(value)
        if not session:
            return value.view(cls)
        # Extract the values.
        obj = np.empty(value.size, dtype=object).view(cls)
        for (i, item) in enumerate(value.ravel()):
            obj[i] = _extract(item, session)
        return obj.reshape(value.shape)

    def __repr__(self):
        """Render like a numpy object array, but labelled ``Cell``."""
        msg = self.view(np.ndarray).__repr__()
        msg = msg.replace('array', 'Cell', 1)
        return msg.replace(', dtype=object', '', 1)
def _extract(data, session=None):
    """Convert a value read by ``loadmat`` into its Python counterpart.

    Lists are converted element by element, and non-array values are
    passed through.  Arrays are mapped, in order of precedence, to:

    * a user-class instance (``MatlabObject``, resolved via *session*),
    * a ``Struct`` (single-element record array),
    * a ``StructArray`` (any other record array),
    * a ``Cell`` (object array),
    * a Python scalar (any other single-element array),
    * ``''`` for empty text arrays and ``[]`` for other empty arrays,
    * otherwise the array itself.
    """
    if isinstance(data, list):
        return [_extract(entry, session) for entry in data]
    if not isinstance(data, np.ndarray):
        return data
    if isinstance(data, MatlabObject):
        user_cls = session._get_user_class(data.classname)
        return user_cls.from_value(data)
    dtype = data.dtype
    size = data.size
    if dtype.names:
        if size == 1:
            return _create_struct(data, session)
        return StructArray(data, session)
    if dtype.kind == 'O':
        return Cell(data, session)
    if size == 1:
        return data.item()
    if size == 0:
        return '' if dtype.kind in 'US' else []
    return data
def _create_struct(data, session):
    """Build a Struct from a single-element record array.

    Each field is converted with ``_extract``.  Object-dtype fields (cells)
    arrive doubly wrapped, so they are squeezed and unwrapped with
    ``tolist()`` before conversion.
    """
    result = Struct()
    for field in data.dtype.names:
        value = data[field]
        doubly_wrapped = isinstance(value, np.ndarray) and value.dtype.kind == 'O'
        if doubly_wrapped:
            value = value.squeeze().tolist()
        result[field] = _extract(value, session)
    return result
def _encode(data, convert_to_float):
    """Convert the Python values to values suitable to send to Octave.

    Parameters
    ----------
    data : object
        The value to convert.  Containers are converted recursively.
    convert_to_float : bool
        When true, signed/unsigned integer arrays are cast to float64.

    Returns
    -------
    object
        A value that can be handed to ``scipy.io.savemat``:

        * ``None`` -> NaN
        * variable pointers / user class instances -> their encoded value
        * ``MatlabObject`` -> a copy with every field encoded
        * dicts -> plain dicts with encoded values
        * sets and lists -> numeric arrays when every (nested) element is a
          plain number, otherwise 1-D object arrays (cells)
        * tuples -> 1-D object arrays (cells) of encoded items
        * sparse matrices -> float64 sparse matrices
        * object / structured arrays -> arrays with encoded elements
        * extended-precision arrays -> complex128 / float64

    Raises
    ------
    Oct2PyError
        If *data* is an Octave function pointer or a MATLAB function.
    """
    ctf = convert_to_float
    # Send None as nan.  Checked first since None matches none of the class
    # checks below.  np.nan is used because the np.NaN alias was removed in
    # NumPy 2.0.
    if data is None:
        return np.nan
    # Handle variable pointer.
    if isinstance(data, OctaveVariablePtr):
        return _encode(data.value, ctf)
    # Handle a user defined object.
    if isinstance(data, OctaveUserClass):
        return _encode(OctaveUserClass.to_value(data), ctf)
    # Handle a function pointer.
    if isinstance(data, (OctaveFunctionPtr, MatlabFunction)):
        raise Oct2PyError('Cannot write Octave functions')
    # Handle matlab objects: encode each field, keeping the class name.
    if isinstance(data, MatlabObject):
        view = data.view(np.ndarray)
        out = MatlabObject(data, data.classname)
        for name in out.dtype.names:
            out[name] = _encode(view[name], ctf)
        return out
    # Extract and encode values from dict-like objects.
    if isinstance(data, dict):
        return {key: _encode(value, ctf) for (key, value) in data.items()}
    # Sets are treated like lists.
    if isinstance(data, set):
        return _encode(list(data), ctf)
    # Lists can be interpreted as numeric arrays or cell arrays.
    if isinstance(data, list):
        if _is_simple_numeric(data):
            return _encode(np.array(data), ctf)
        return _encode(tuple(data), ctf)
    # Tuples are handled as cells.
    if isinstance(data, tuple):
        obj = np.empty(len(data), dtype=object)
        for (i, item) in enumerate(data):
            obj[i] = _encode(item, ctf)
        return obj
    # Sparse data must be floating type.
    if isinstance(data, spmatrix):
        return data.astype(np.float64)
    # Return other data types unchanged.
    if not isinstance(data, np.ndarray):
        return data
    # Extract and encode data from object-like arrays.
    if data.dtype.kind in 'OV':
        out = np.empty(data.size, dtype=data.dtype)
        for (i, item) in enumerate(data.ravel()):
            if data.dtype.names:
                for name in data.dtype.names:
                    out[i][name] = _encode(item[name], ctf)
            else:
                out[i] = _encode(item, ctf)
        return out.reshape(data.shape)
    # Complex 128 / float 64 are the widest types supported by savemat.
    # Extended precision is complex256/float128 or complex192/float96
    # depending on the platform, so test the item size rather than a name.
    if data.dtype.kind == 'c' and data.dtype.itemsize > 16:
        return data.astype(np.complex128)
    if data.dtype.kind == 'f' and data.dtype.itemsize > 8:
        return data.astype(np.float64)
    # Convert to float if applicable.
    if ctf and data.dtype.kind in 'ui':
        return data.astype(np.float64)
    # Return standard array.
    return data
def _is_simple_numeric(data):
    """Return True when *data* holds only plain numbers, possibly nested.

    Elements that are lists are checked recursively, and sets are treated
    like lists.  Any other element must be an int, float or complex.
    """
    def _numeric(entry):
        # Sets are examined as lists.
        if isinstance(entry, set):
            entry = list(entry)
        if isinstance(entry, list):
            return _is_simple_numeric(entry)
        return isinstance(entry, (int, float, complex))
    return all(_numeric(entry) for entry in data)