Skip to content

Commit

Permalink
Merge 6a6baef into c4bbbb7
Browse files Browse the repository at this point in the history
  • Loading branch information
rkern committed Sep 19, 2014
2 parents c4bbbb7 + 6a6baef commit d453a3f
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 3 deletions.
1 change: 1 addition & 0 deletions CHANGES.txt
Expand Up @@ -22,6 +22,7 @@ Fixes
* Make the `alpha` attribute of scatter plots work as intended (PR#164).
* Resume running test for empty container that has been skipped (PR#190).
* Improved handling of missing or empty data in scatterplot (PR#210)
* Avoid a numpy crash when using a data source of a unicode array (PR#213).

Release 4.4.1
-------------
Expand Down
17 changes: 14 additions & 3 deletions chaco/array_data_source.py
@@ -1,7 +1,8 @@
""" Defines the ArrayDataSource class."""

# Major library imports
from numpy import array, isfinite, ones, nanargmin, nanargmax, ndarray
from numpy import array, isfinite, ones, ndarray
import numpy as np

# Enthought library imports
from traits.api import Any, Constant, Int, Tuple
Expand All @@ -19,7 +20,12 @@ def bounded_nanargmin(arr):
# Different versions of numpy behave differently in the all-NaN case, so we
# catch this condition in two different ways.
try:
min = nanargmin(arr)
if np.issubdtype(arr.dtype, np.floating):
min = np.nanargmin(arr)
elif np.issubdtype(arr.dtype, np.number):
min = np.argmin(arr)
else:
min = 0
except ValueError:
return 0
if isfinite(min):
Expand All @@ -33,7 +39,12 @@ def bounded_nanargmax(arr):
If all NaNs, return -1.
"""
try:
max = nanargmax(arr)
if np.issubdtype(arr.dtype, np.floating):
max = np.nanargmax(arr)
elif np.issubdtype(arr.dtype, np.number):
max = np.argmax(arr)
else:
max = -1
except ValueError:
return -1
if isfinite(max):
Expand Down
8 changes: 8 additions & 0 deletions chaco/tests/arraydatasource_test_case.py
Expand Up @@ -5,6 +5,8 @@
import unittest

from numpy import arange, array, allclose, empty, isnan, nan
import numpy as np

from chaco.api import ArrayDataSource, PointDataSource


Expand Down Expand Up @@ -51,6 +53,12 @@ def test_bounds_all_nans(self):
self.assertTrue(isnan(bounds[0]))
self.assertTrue(isnan(bounds[1]))

def test_bounds_non_numeric(self):
myarray = np.array([u'abc', u'foo', u'bar', u'def'], dtype=unicode)
sd = ArrayDataSource(myarray)
bounds = sd.get_bounds()
self.assertEqual(bounds, (u'abc', u'def'))


class PointDataTestCase(unittest.TestCase):
# Since PointData is mostly the same as ScalarData, the key things to
Expand Down

0 comments on commit d453a3f

Please sign in to comment.