Skip to content

Commit

Permalink
Merge pull request #169 from enthought/fix-datasource
Browse files Browse the repository at this point in the history
Fix UnboundLocalError when a datasource of the wrong shape is looked for
  • Loading branch information
tonysyu committed Jun 10, 2014
2 parents 930d51b + 6028af2 commit 032f21a
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 28 deletions.
51 changes: 23 additions & 28 deletions chaco/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,7 @@ def plot(self, data, type="line", name=None, index_scale="linear",
name = self._make_new_plot_name()
if origin is None:
origin = self.default_origin

if plot_type in ("line", "scatter", "polygon", "bar", "filled_line"):
# Tie data to the index range
if len(data) == 1:
Expand Down Expand Up @@ -378,44 +378,44 @@ def plot(self, data, type="line", name=None, index_scale="linear",
orientation=self.orientation,
origin = origin,
**styles)

self.add(plot)
new_plots.append(plot)

if plot_type == 'bar':
# For bar plots, compute the ranges from the data to make the
# plot look clean.
# For bar plots, compute the ranges from the data to make the
# plot look clean.

def custom_index_func(data_low, data_high, margin, tight_bounds):
""" Compute custom bounds of the plot along index (in
""" Compute custom bounds of the plot along index (in
data space).
"""
bar_width = styles.get('bar_width', cls().bar_width)
plot_low = data_low - bar_width
plot_high = data_high + bar_width
return plot_low, plot_high

if self.index_range.bounds_func is None:
self.index_range.bounds_func = custom_index_func

def custom_value_func(data_low, data_high, margin, tight_bounds):
""" Compute custom bounds of the plot along value (in
""" Compute custom bounds of the plot along value (in
data space).
"""
plot_low = data_low - (data_high-data_low)*0.1
plot_high = data_high + (data_high-data_low)*0.1
return plot_low, plot_high
if self.value_range.bounds_func is None:

if self.value_range.bounds_func is None:
self.value_range.bounds_func = custom_value_func

self.index_range.tight_bounds = False
self.value_range.tight_bounds = False
self.index_range.refresh()
self.value_range.refresh()

self.plots[name] = new_plots

elif plot_type == "cmap_scatter":
if len(data) != 3:
raise ValueError("Colormapped scatter plots require (index, value, color) data")
Expand Down Expand Up @@ -908,7 +908,7 @@ def quiverplot(self, data, name=None, origin=None,
)
self.add(plot)
self.plots[name] = [plot]
return [plot]
return [plot]

def delplot(self, *names):
""" Removes the named sub-plots. """
Expand Down Expand Up @@ -1008,19 +1008,16 @@ def _get_or_create_datasource(self, name):
ds = ArrayDataSource(data, sort_order="none")
elif len(data.shape) == 2:
ds = ImageData(data=data, value_depth=1)
elif len(data.shape) == 3:
if data.shape[2] in (3,4):
ds = ImageData(data=data, value_depth=int(data.shape[2]))
else:
raise ValueError("Unhandled array shape in creating new plot: " \
+ str(data.shape))

elif len(data.shape) == 3 and data.shape[2] in (3,4):
ds = ImageData(data=data, value_depth=int(data.shape[2]))
else:
raise ValueError("Unhandled array shape in creating new "
"plot: %s" % str(data.shape))
elif isinstance(data, AbstractDataSource):
ds = data

else:
raise ValueError("Couldn't create datasource for data of type " + \
str(type(data)))
raise ValueError("Couldn't create datasource for data of "
"type %s" % type(data))

self.datasources[name] = ds

Expand Down Expand Up @@ -1059,7 +1056,7 @@ def _data_update_handler(self, name, event):
if name in self.datasources:
source = self.datasources[name]
source.set_data(self.data.get_data(name))

def _plots_items_changed(self, event):
if self.legend:
self.legend.plots = self.plots
Expand Down Expand Up @@ -1185,5 +1182,3 @@ def _set_title_font(self, font):

def _get_title_font(self):
return self._title.font


22 changes: 22 additions & 0 deletions chaco/tests/plot_test_case.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import unittest

from numpy import arange

# Chaco imports
from chaco.api import ArrayPlotData, Plot


class PlotTestCase(unittest.TestCase):

def test_plot_from_unsupported_array_shape(self):
arr = arange(8).reshape(2, 2, 2)
data = ArrayPlotData(x=arr, y=arr)
plot = Plot(data)
self.assertRaises(ValueError, plot.plot, ("x", "y"))

arr = arange(16).reshape(2, 2, 2, 2)
data.update_data(x=arr, y=arr)
self.assertRaises(ValueError, plot.plot, ("x", "y"))

if __name__ == "__main__":
unittest.main()
5 changes: 5 additions & 0 deletions docs/CHANGES.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,11 @@ Enhancements
* AbstractPlotData and subclasses now implement an update_data() method which
updates the data whilst firing only one update event.

Fixes
-----

* Fixed UnboundLocalError in the Plot class when attempting to create a data
source from an array of unsupported shape.

Chaco 4.1.0
===========
Expand Down

0 comments on commit 032f21a

Please sign in to comment.