-
Notifications
You must be signed in to change notification settings - Fork 99
/
array_data_source.py
322 lines (265 loc) · 10.4 KB
/
array_data_source.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
# (C) Copyright 2005-2021 Enthought, Inc., Austin, TX
# All rights reserved.
#
# This software is provided without warranty under the terms of the BSD
# license included in LICENSE.txt and may be redistributed only under
# the conditions described in the aforementioned license. The license
# is also available online at http://www.enthought.com/licenses/BSD.txt
#
# Thanks for using Enthought open source!
""" Defines the ArrayDataSource class."""
# Major library imports
from numpy import array, empty, isfinite, ones, ndarray
import numpy as np
# Enthought library imports
from traits.api import Any, Constant, Int, Tuple
# Chaco imports
from .base import NumericalSequenceTrait, reverse_map_1d, SortOrderTrait
from .abstract_data_source import AbstractDataSource
def bounded_nanargmin(arr):
    """Return the index of the smallest value in *arr*, skipping NaNs.

    Floating-point arrays are searched with ``np.nanargmin``, other numeric
    arrays with ``np.argmin``.  Non-numeric arrays, and arrays for which no
    minimum can be located (for example all-NaN input), yield 0.
    """
    kind = arr.dtype
    if np.issubdtype(kind, np.floating):
        locate = np.nanargmin
    elif np.issubdtype(kind, np.number):
        locate = np.argmin
    else:
        # Nothing sensible to compute for strings, objects, etc.
        return 0

    # Different versions of numpy behave differently in the all-NaN case:
    # some raise ValueError, others hand back a non-finite result.  Both
    # outcomes are mapped to index 0.
    try:
        index = locate(arr)
    except ValueError:
        return 0
    return index if isfinite(index) else 0
def bounded_nanargmax(arr):
    """Return the index of the largest value in *arr*, skipping NaNs.

    Floating-point arrays are searched with ``np.nanargmax``, other numeric
    arrays with ``np.argmax``.  Non-numeric arrays, and arrays for which no
    maximum can be located (for example all-NaN input), yield -1.
    """
    kind = arr.dtype
    if np.issubdtype(kind, np.floating):
        locate = np.nanargmax
    elif np.issubdtype(kind, np.number):
        locate = np.argmax
    else:
        # Nothing sensible to compute for strings, objects, etc.
        return -1

    # A ValueError or a non-finite result both mean "no usable maximum";
    # either way the last element (-1) is reported.
    try:
        index = locate(arr)
    except ValueError:
        return -1
    return index if isfinite(index) else -1
class ArrayDataSource(AbstractDataSource):
    """A data source representing a single, continuous array of numerical data.

    This class does not listen to the array for value changes; if you need that
    behavior, create a subclass that hooks up the appropriate listeners.
    """

    # ------------------------------------------------------------------------
    # AbstractDataSource traits
    # ------------------------------------------------------------------------

    #: The dimensionality of the indices into this data source (overrides
    #: AbstractDataSource).
    index_dimension = Constant("scalar")

    #: The dimensionality of the value at each index point (overrides
    #: AbstractDataSource).
    value_dimension = Constant("scalar")

    #: The sort order of the data.
    #: This is a specialized optimization for 1-D arrays, but it's an important
    #: one that's used everywhere.
    sort_order = SortOrderTrait

    # ------------------------------------------------------------------------
    # Private traits
    # ------------------------------------------------------------------------

    # The data array itself.
    _data = NumericalSequenceTrait

    # Cached values of min and max as long as **_data** doesn't change.
    _cached_bounds = Tuple

    # Not necessary, since this is not a filter, but provided for convenience.
    _cached_mask = Any

    # The index of the (first) minimum value in self._data
    # FIXME: This is an Any instead of an Int trait because of how Traits
    # typechecks numpy.int64 on 64-bit Windows systems.
    _min_index = Any

    # The index of the (first) maximum value in self._data
    # FIXME: This is an Any instead of an Int trait because of how Traits
    # typechecks numpy.int64 on 64-bit Windows systems.
    _max_index = Any

    # ------------------------------------------------------------------------
    # Public methods
    # ------------------------------------------------------------------------

    def __init__(self, data=array([]), sort_order="none", **kw):
        """Create the data source and load its initial data.

        Parameters
        ----------
        data : array
            The initial data; defaults to an empty array.
        sort_order : SortOrderTrait
            The sort order of *data*; defaults to "none".
        **kw
            Forwarded unchanged to ``AbstractDataSource.__init__``.
        """
        AbstractDataSource.__init__(self, **kw)
        self.set_data(data, sort_order)

    def set_data(self, newdata, sort_order=None):
        """Sets the data, and optionally the sort order, for this data source.

        The cached bounds are recomputed and ``data_changed`` is set to True.

        Parameters
        ----------
        newdata : array
            The data to use.
        sort_order : SortOrderTrait
            The sort order of the data.  If None, the current sort order is
            left untouched.
        """
        self._data = newdata
        if sort_order is not None:
            self.sort_order = sort_order
        self._compute_bounds()
        self.data_changed = True

    def set_mask(self, mask):
        """Sets the mask for this data source and sets ``data_changed``."""
        self._cached_mask = mask
        self.data_changed = True

    def remove_mask(self):
        """Removes the mask on this data source and sets ``data_changed``."""
        self._cached_mask = None
        self.data_changed = True

    # ------------------------------------------------------------------------
    # AbstractDataSource interface
    # ------------------------------------------------------------------------

    def get_data(self):
        """Returns the data for this data source, or an empty array if it
        has no data.

        Implements AbstractDataSource.
        """
        if self._data is not None:
            return self._data
        return empty(shape=(0,))

    def get_data_mask(self):
        """get_data_mask() -> (data_array, mask_array)

        If no mask has been set, an all-True boolean mask is returned whose
        length matches the data (length 0 when the data is None).  Note that
        the data element of the result is ``self._data`` itself, so it is
        None when no data is set.

        Implements AbstractDataSource.
        """
        if self._cached_mask is not None:
            return self._data, self._cached_mask
        size = 0 if self._data is None else len(self._data)
        return self._data, ones(size, dtype=bool)

    def is_masked(self):
        """is_masked() -> bool

        True when a mask is currently set.

        Implements AbstractDataSource.
        """
        return self._cached_mask is not None

    def get_size(self):
        """get_size() -> int

        The length of the data, or 0 when the data is None.

        Implements AbstractDataSource.
        """
        if self._data is not None:
            return len(self._data)
        return 0

    def get_bounds(self):
        """Returns the minimum and maximum values of the data source's data.

        The bounds are recomputed first if the cache is empty.

        Implements AbstractDataSource.
        """
        if (
            self._cached_bounds is None
            or self._cached_bounds == ()
            or self._cached_bounds == 0.0
        ):
            self._compute_bounds()
        return self._cached_bounds

    def reverse_map(self, pt, index=0, outside_returns_none=True):
        """Returns the index of *pt* in the data source.

        Parameters
        ----------
        pt : scalar value
            value to find
        index
            ignored for data series with 1-D indices
        outside_returns_none : Boolean
            Whether the method returns None if *pt* is outside the range of
            the data source; if False, the method returns the index of the
            bound that *pt* is outside of.

        Returns
        -------
        The result of ``reverse_map_1d`` when *pt* lies within the bounds;
        otherwise None or the index of the exceeded bound, as selected by
        *outside_returns_none*.

        Raises
        ------
        NotImplementedError
            If the sort order is "none".
        """
        if self.sort_order == "none":
            raise NotImplementedError

        # index is ignored for dataseries with 1-dimensional indices

        # Go through get_bounds() rather than reading the cache directly: the
        # cache may have been reset to an empty tuple (see _post_load), which
        # cannot be unpacked into two values.
        minval, maxval = self.get_bounds()
        if pt < minval:
            return None if outside_returns_none else self._min_index
        if pt > maxval:
            return None if outside_returns_none else self._max_index
        return reverse_map_1d(self._data, pt, self.sort_order)

    # ------------------------------------------------------------------------
    # Private methods
    # ------------------------------------------------------------------------

    def _compute_bounds(self, data=None):
        """Computes the minimum and maximum values of self._data.

        If a data array is passed in, then that is used instead of self._data.
        This behavior is useful for subclasses.

        Updates ``_min_index``, ``_max_index`` and ``_cached_bounds``.
        """
        # TODO: as an optimization, perhaps create and cache a sorted
        # version of the dataset?

        if data is None:
            data = self.get_data()

        data_len = len(data)

        if data_len == 0:
            self._min_index = 0
            self._max_index = 0
            self._cached_bounds = (0.0, 0.0)
            return

        if data_len == 1:
            self._min_index = 0
            self._max_index = 0
            self._cached_bounds = (data[0], data[0])
            return

        if self.sort_order == "ascending":
            min_index, max_index = 0, -1
        elif self.sort_order == "descending":
            min_index, max_index = -1, 0
        else:
            # ignore NaN values.  This is probably a little slower,
            # but also much safer.

            # data might be an array of strings or objects that
            # can't have argmin calculated on them.
            try:
                # the data may be in a subclass of numpy.array, viewing
                # the data as a ndarray will remove side effects of
                # the subclasses, such as different operator behaviors
                plain = data.view(ndarray)
                min_index = bounded_nanargmin(plain)
                max_index = bounded_nanargmax(plain)
            except (TypeError, IndexError, NotImplementedError):
                # For strings and objects, we punt...  These show up in
                # label-ish data sources.  Return right away so the fallback
                # bounds are not overwritten by a lookup with stale indices.
                self._min_index = 0
                self._max_index = 0
                self._cached_bounds = (0.0, 0.0)
                return

        self._min_index = min_index
        self._max_index = max_index
        self._cached_bounds = (data[min_index], data[max_index])

    # ------------------------------------------------------------------------
    # Event handlers
    # ------------------------------------------------------------------------

    def _metadata_changed(self, event):
        # Re-broadcast as the metadata_changed event.
        self.metadata_changed = True

    def _metadata_items_changed(self, event):
        # Re-broadcast as the metadata_changed event.
        self.metadata_changed = True

    # ------------------------------------------------------------------------
    # Persistence-related methods
    # ------------------------------------------------------------------------

    def __getstate__(self):
        """Return the state to persist.

        When ``persist_data`` is false the data itself is dropped.  The
        cached mask, bounds and min/max indices are always dropped.
        """
        state = super().__getstate__()
        if not self.persist_data:
            state.pop("_data", None)
        for cached in (
            "_cached_mask",
            "_cached_bounds",
            "_min_index",
            "_max_index",
        ):
            state.pop(cached, None)
        return state

    def _post_load(self):
        """Clear the cached bounds and mask after the superclass hook runs."""
        super()._post_load()
        self._cached_bounds = ()
        self._cached_mask = None