From 3b65566b57d7575d8523cb3c2bd287e8baf7dfcc Mon Sep 17 00:00:00 2001 From: Robert Kern Date: Fri, 19 Sep 2014 10:07:21 +0100 Subject: [PATCH] BUG: Avoid a numpy crash seen when using a data source of a unicode array. --- CHANGES.txt | 1 + chaco/array_data_source.py | 17 ++++++++++++++--- chaco/tests/arraydatasource_test_case.py | 8 ++++++++ 3 files changed, 23 insertions(+), 3 deletions(-) diff --git a/CHANGES.txt b/CHANGES.txt index 8bd8083af..ab3e66bcd 100644 --- a/CHANGES.txt +++ b/CHANGES.txt @@ -21,6 +21,7 @@ Fixes PR#101 (PR#175). * Make the `alpha` attribute of scatter plots work as intended (PR#164). * Resume running test for empty container that has been skipped (PR#190). + * Avoid a numpy crash when using a data source of a unicode array (PR#213). Release 4.4.1 ------------- diff --git a/chaco/array_data_source.py b/chaco/array_data_source.py index 1905bcfd9..372f437cf 100644 --- a/chaco/array_data_source.py +++ b/chaco/array_data_source.py @@ -1,7 +1,8 @@ """ Defines the ArrayDataSource class.""" # Major library imports -from numpy import array, isfinite, ones, nanargmin, nanargmax, ndarray +from numpy import array, isfinite, ones, ndarray +import numpy as np # Enthought library imports from traits.api import Any, Constant, Int, Tuple @@ -19,7 +20,12 @@ def bounded_nanargmin(arr): # Different versions of numpy behave differently in the all-NaN case, so we # catch this condition in two different ways. try: - min = nanargmin(arr) + if np.issubdtype(arr.dtype, np.floating): + min = np.nanargmin(arr) + elif np.issubdtype(arr.dtype, np.number): + min = np.argmin(arr) + else: + min = 0 except ValueError: return 0 if isfinite(min): @@ -33,7 +39,12 @@ def bounded_nanargmax(arr): If all NaNs, return -1. """ try: - max = nanargmax(arr) + if np.issubdtype(arr.dtype, np.floating): + max = np.nanargmax(arr) + elif np.issubdtype(arr.dtype, np.number): + max = np.argmax(arr) + else: + max = -1 except ValueError: return -1 if isfinite(max): diff --git a/chaco/tests/arraydatasource_test_case.py b/chaco/tests/arraydatasource_test_case.py index da7f14dae..1ca889637 100644 --- a/chaco/tests/arraydatasource_test_case.py +++ b/chaco/tests/arraydatasource_test_case.py @@ -5,6 +5,8 @@ import unittest from numpy import arange, array, allclose, empty, isnan, nan +import numpy as np + from chaco.api import ArrayDataSource, PointDataSource @@ -51,6 +53,12 @@ def test_bounds_all_nans(self): self.assertTrue(isnan(bounds[0])) self.assertTrue(isnan(bounds[1])) + def test_bounds_non_numeric(self): + myarray = np.array([u'abc', u'foo', u'bar', u'def'], dtype=unicode) + sd = ArrayDataSource(myarray) + bounds = sd.get_bounds() + self.assertEqual(bounds, (u'abc', u'def')) + class PointDataTestCase(unittest.TestCase): # Since PointData is mostly the same as ScalarData, the key things to