From 37ac89d664ad434e5c9a2344bc06a920c20c871b Mon Sep 17 00:00:00 2001 From: Philipp Rudiger Date: Thu, 16 Jun 2016 13:34:51 +0200 Subject: [PATCH] Unified how plots declare plot methods across bokeh and matplotlib --- holoviews/plotting/bokeh/annotation.py | 4 ++-- holoviews/plotting/bokeh/chart.py | 13 +++++------ holoviews/plotting/bokeh/element.py | 11 +++++----- holoviews/plotting/bokeh/path.py | 6 ++---- holoviews/plotting/bokeh/raster.py | 8 +++---- holoviews/plotting/mpl/chart.py | 30 +++++++++++--------------- holoviews/plotting/mpl/chart3d.py | 12 ++++------- holoviews/plotting/mpl/element.py | 12 +++++++++++ holoviews/plotting/mpl/raster.py | 9 +++----- holoviews/plotting/plot.py | 7 ++---- 10 files changed, 52 insertions(+), 60 deletions(-) diff --git a/holoviews/plotting/bokeh/annotation.py b/holoviews/plotting/bokeh/annotation.py index 5e5cee9780..9bdeee52a4 100644 --- a/holoviews/plotting/bokeh/annotation.py +++ b/holoviews/plotting/bokeh/annotation.py @@ -8,7 +8,7 @@ class TextPlot(ElementPlot): style_opts = text_properties - _plot_method = 'text' + _plot_methods = dict(single='text') def get_data(self, element, ranges=None, empty=False): mapping = dict(x='x', y='y', text='text') @@ -62,7 +62,7 @@ class SplinePlot(ElementPlot): """ style_opts = line_properties - _plot_method = 'bezier' + _plot_methods = dict(single='bezier') def get_data(self, element, ranges=None, empty=False): data_attrs = ['x0', 'y0', 'x1', 'y1', diff --git a/holoviews/plotting/bokeh/chart.py b/holoviews/plotting/bokeh/chart.py index a300b7fadf..8d3b95d969 100644 --- a/holoviews/plotting/bokeh/chart.py +++ b/holoviews/plotting/bokeh/chart.py @@ -46,9 +46,7 @@ class PointPlot(ElementPlot): 'unselected_color'] + line_properties + fill_properties) - _plot_method = 'scatter' - _batched_plot_method = 'scatter' - _batched = True + _plot_methods = dict(single='scatter', batched='scatter') def get_data(self, element, ranges=None, empty=False): style = self.style[self.cyclic_index] @@ -124,17 +122,16 @@ def _init_glyph(self, plot, mapping, properties): renderer = plot.add_glyph(source, selected, selection_glyph=selected, nonselection_glyph=unselected) else: - renderer = getattr(plot, self._plot_method)(**dict(properties, **mapping)) + plot_method = self._plot_methods.get('batched' if self.batched else 'single') + renderer = getattr(plot, plot_method)(**dict(properties, **mapping)) return renderer, renderer.glyph class CurvePlot(ElementPlot): style_opts = ['color'] + line_properties - _plot_method = 'line' - _batched_plot_method = 'multi_line' + _plot_methods = dict(single='line', batched='multi_line') _mapping = {p: p for p in ['xs', 'ys', 'color', 'line_alpha']} - _batched = True def get_data(self, element, ranges=None, empty=False): x = element.get_dimension(0).name @@ -214,7 +211,7 @@ def get_data(self, element, ranges=None, empty=None): class HistogramPlot(ElementPlot): style_opts = ['color'] + line_properties + fill_properties - _plot_method = 'quad' + _plot_methods = dict(single='quad') def get_data(self, element, ranges=None, empty=None): mapping = dict(top='top', bottom=0, left='left', right='right') diff --git a/holoviews/plotting/bokeh/element.py b/holoviews/plotting/bokeh/element.py index 5edeee0f02..8a25774daf 100644 --- a/holoviews/plotting/bokeh/element.py +++ b/holoviews/plotting/bokeh/element.py @@ -137,10 +137,11 @@ class ElementPlot(BokehPlot, GenericElementPlot): tick locations or bokeh Ticker object. If set to None default bokeh ticking behavior is applied.""") - # A string corresponding to the glyph being drawn by the - # ElementPlot - _plot_method = None - _batched = False + # A dictionary mapping of the plot methods used to draw the + # glyphs corresponding to the ElementPlot, can support two + # keyword arguments a 'single' implementation to draw an individual + # plot and a 'batched' method to draw multiple Elements at once + _plot_methods = {} # The plot objects to be updated on each frame # Any entries should be existing keys in the handles @@ -398,7 +399,7 @@ def _init_glyph(self, plot, mapping, properties): Returns a Bokeh glyph object. """ properties = mpl_to_bokeh(properties) - plot_method = self._batched_plot_method if self.batched else self._plot_method + plot_method = self._plot_methods.get('batched' if self.batched else 'single') renderer = getattr(plot, plot_method)(**dict(properties, **mapping)) return renderer, renderer.glyph diff --git a/holoviews/plotting/bokeh/path.py b/holoviews/plotting/bokeh/path.py index 6566733e14..219b810f09 100644 --- a/holoviews/plotting/bokeh/path.py +++ b/holoviews/plotting/bokeh/path.py @@ -15,7 +15,7 @@ class PathPlot(ElementPlot): Whether to show legend for the plot.""") style_opts = ['color'] + line_properties - _plot_method = 'multi_line' + _plot_methods = dict(single='multi_line') _mapping = dict(xs='xs', ys='ys') def get_data(self, element, ranges=None, empty=False): @@ -27,9 +27,7 @@ def get_data(self, element, ranges=None, empty=False): class PolygonPlot(PathPlot): style_opts = ['color', 'cmap', 'palette'] + line_properties + fill_properties - _plot_method = 'patches' - _batched_plot_method = 'patches' - _batched = True + _plot_methods = dict(single='patches', batched='patches') def get_data(self, element, ranges=None, empty=False): xs = [] if empty else [path[:, 0] for path in element.data] diff --git a/holoviews/plotting/bokeh/raster.py b/holoviews/plotting/bokeh/raster.py index e3bfac5317..71bbcfc94a 100644 --- a/holoviews/plotting/bokeh/raster.py +++ b/holoviews/plotting/bokeh/raster.py @@ -17,7 +17,7 @@ class RasterPlot(ElementPlot): Whether to show legend for the plot.""") style_opts = ['cmap'] - _plot_method = 'image' + _plot_methods = dict(single='image') _update_handles = ['color_mapper', 'source', 'glyph'] def __init__(self, *args, **kwargs): @@ -74,7 +74,7 @@ def _update_glyph(self, glyph, properties, mapping): class RGBPlot(RasterPlot): style_opts = [] - _plot_method = 'image_rgba' + _plot_methods = dict(single='image_rgba') def get_data(self, element, ranges=None, empty=False): data, mapping = super(RGBPlot, self).get_data(element, ranges, empty) @@ -113,7 +113,7 @@ class HeatmapPlot(ElementPlot): show_legend = param.Boolean(default=False, doc=""" Whether to show legend for the plot.""") - _plot_method = 'rect' + _plot_methods = dict(single='rect') style_opts = ['cmap', 'color'] + line_properties + fill_properties def _axes_props(self, plots, subplots, element, ranges): @@ -148,7 +148,7 @@ class QuadMeshPlot(ElementPlot): show_legend = param.Boolean(default=False, doc=""" Whether to show legend for the plot.""") - _plot_method = 'rect' + _plot_methods = dict(single='rect') style_opts = ['cmap', 'color'] + line_properties + fill_properties def get_data(self, element, ranges=None, empty=False): diff --git a/holoviews/plotting/mpl/chart.py b/holoviews/plotting/mpl/chart.py index 28fc7e30cc..3ee4d14119 100644 --- a/holoviews/plotting/mpl/chart.py +++ b/holoviews/plotting/mpl/chart.py @@ -54,9 +54,7 @@ class CurvePlot(ChartPlot): style_opts = ['alpha', 'color', 'visible', 'linewidth', 'linestyle', 'marker'] - def init_artists(self, ax, plot_data, plot_kwargs): - return {'artist': ax.plot(*plot_data, **plot_kwargs)[0]} - + _plot_methods = dict(single='plot') def get_data(self, element, ranges, style): xs = element.dimension_values(0) @@ -87,6 +85,8 @@ class ErrorPlot(ChartPlot): 'markerfacecolor', 'markersize', 'solid_capstyle', 'solid_joinstyle', 'dashes', 'color'] + _plot_methods = dict(single='errorbar') + def init_artists(self, ax, plot_data, plot_kwargs): _, (bottoms, tops), verts = ax.errorbar(*plot_data, **plot_kwargs) return {'bottoms': bottoms, 'tops': tops, 'verts': verts[0]} @@ -143,6 +143,8 @@ class AreaPlot(ChartPlot): 'hatch', 'linestyle', 'joinstyle', 'fill', 'capstyle', 'interpolate'] + _plot_methods = dict(single='fill_between') + def get_data(self, element, ranges, style): xs = element.dimension_values(0) ys = [element.dimension_values(vdim) for vdim in element.vdims] @@ -455,10 +457,7 @@ class PointPlot(ChartPlot, ColorbarPlot): 'cmap', 'vmin', 'vmax'] _disabled_opts = ['size'] - - def init_artists(self, ax, plot_args, plot_kwargs): - return {'artist': ax.scatter(*plot_args, **plot_kwargs)} - + _plot_methods = dict(single='scatter') def get_data(self, element, ranges, style): xs, ys = (element.dimension_values(i) for i in range(2)) @@ -546,6 +545,8 @@ class VectorFieldPlot(ColorbarPlot): 'scale', 'headlength', 'headaxislength', 'pivot', 'width','headwidth'] + _plot_methods = dict(single='quiver') + def __init__(self, *args, **params): super(VectorFieldPlot, self).__init__(*args, **params) self._min_dist = self._get_map_info(self.hmap) @@ -598,16 +599,12 @@ def get_data(self, element, ranges, style): if 'pivot' not in style: style['pivot'] = 'mid' if not self.arrow_heads: style['headaxislength'] = 0 - style.update(dict(scale=input_scale, angles=angles)) + style.update(dict(scale=scale, angles=angles, + units='x', scale_units='x')) return args, style, {} - def init_artists(self, ax, plot_args, plot_kwargs): - quiver = ax.quiver(*plot_args, units='x', scale_units='x', **plot_kwargs) - return {'artist': quiver} - - def update_handles(self, key, axis, element, ranges, style): args, style, axis_kwargs = self.get_data(element, ranges, style) @@ -960,6 +957,8 @@ class BoxPlot(ChartPlot): 'whiskerprops', 'capprops', 'flierprops', 'medianprops', 'meanprops', 'meanline'] + _plot_methods = dict(single='boxplot') + def get_extents(self, element, ranges): return (np.NaN,)*4 @@ -987,11 +986,6 @@ def get_data(self, element, ranges, style): element.vdims[0]]} - def init_artists(self, ax, plot_args, plot_kwargs): - boxplot = ax.boxplot(*plot_args, **plot_kwargs) - return {'artist': boxplot} - - def teardown_handles(self): for group in self.handles['artist'].values(): for v in group: diff --git a/holoviews/plotting/mpl/chart3d.py b/holoviews/plotting/mpl/chart3d.py index 3c0623545e..342c78be23 100644 --- a/holoviews/plotting/mpl/chart3d.py +++ b/holoviews/plotting/mpl/chart3d.py @@ -115,6 +115,8 @@ class Scatter3DPlot(Plot3D, PointPlot): allow_None=True, doc=""" Index of the dimension from which the sizes will the drawn.""") + _plot_methods = dict(single='scatter') + def get_data(self, element, ranges, style): xs, ys, zs = (element.dimension_values(i) for i in range(3)) self._compute_styles(element, ranges, style) @@ -127,11 +129,6 @@ def get_data(self, element, ranges, style): style['facecolors'] = color return (xs, ys, zs), style, {} - def init_artists(self, ax, plot_data, plot_kwargs): - scatterplot = ax.scatter(*plot_data, **plot_kwargs) - ax.add_collection(scatterplot) - return {'artist': scatterplot} - def update_handles(self, key, axis, element, ranges, style): artist = self.handles['artist'] artist._offsets3d, style, _ = self.get_data(element, ranges, style) @@ -195,11 +192,10 @@ class TrisurfacePlot(Plot3D): style_opts = ['cmap', 'color', 'shade', 'linewidth', 'edgecolor'] + _plot_methods = dict(single='plot_trisurf') + def get_data(self, element, ranges, style): dims = element.dimensions() self._norm_kwargs(element, ranges, style, dims[2]) x, y, z = [element.dimension_values(d) for d in dims] return (x, y, z), style, {} - - def init_artists(self, ax, plot_data, plot_kwargs): - return {'artist': ax.plot_trisurf(*plot_data, **plot_kwargs)} diff --git a/holoviews/plotting/mpl/element.py b/holoviews/plotting/mpl/element.py index 7e13829a06..a481e17dd3 100644 --- a/holoviews/plotting/mpl/element.py +++ b/holoviews/plotting/mpl/element.py @@ -468,6 +468,18 @@ def initialize_plot(self, ranges=None): return self._finalize_axis(self.keys[-1], ranges=ranges, **axis_kwargs) + def init_artists(self, ax, plot_args, plot_kwargs): + """ + Initializes the artist based on the plot method declared on + the plot. + """ + plot_method = self._plot_methods.get('batched' if self.batched else 'single') + plot_fn = getattr(ax, plot_method) + artist = plot_fn(*plot_args, **plot_kwargs) + return {'artist': artist[0] if isinstance(artist, list) and + len(artist) == 1 else artist} + + def update_handles(self, key, axis, element, ranges, style): """ Update the elements of the plot. diff --git a/holoviews/plotting/mpl/raster.py b/holoviews/plotting/mpl/raster.py index 2ece3309de..2e59a54d4d 100644 --- a/holoviews/plotting/mpl/raster.py +++ b/holoviews/plotting/mpl/raster.py @@ -30,6 +30,7 @@ class RasterPlot(ColorbarPlot): style_opts = ['alpha', 'cmap', 'interpolation', 'visible', 'filterrad', 'clims', 'norm'] + _plot_methods = dict(single='imshow') def __init__(self, *args, **kwargs): super(RasterPlot, self).__init__(*args, **kwargs) @@ -74,12 +75,6 @@ def get_data(self, element, ranges, style): return [data], style, {'xticks': xticks, 'yticks': yticks} - - def init_artists(self, ax, plot_args, plot_kwargs): - im = ax.imshow(*plot_args, **plot_kwargs) - return {'artist': im} - - def update_handles(self, key, axis, element, ranges, style): im = self.handles['artist'] data, style, axis_kwargs = self.get_data(element, ranges, style) @@ -192,6 +187,8 @@ class QuadMeshPlot(ColorbarPlot): style_opts = ['alpha', 'cmap', 'clim', 'edgecolors', 'norm', 'shading', 'linestyles', 'linewidths', 'hatch', 'visible'] + _plot_methods = dict(single='pcolormesh') + def get_data(self, element, ranges, style): data = np.ma.array(element.data[2], mask=np.logical_not(np.isfinite(element.data[2]))) diff --git a/holoviews/plotting/plot.py b/holoviews/plotting/plot.py index 4a91f21c2b..a032ade413 100644 --- a/holoviews/plotting/plot.py +++ b/holoviews/plotting/plot.py @@ -483,9 +483,6 @@ class GenericElementPlot(DimensionedPlot): apply_extents = param.Boolean(default=True, doc=""" Whether to apply extent overrides on the Elements""") - # Whether the plotting class supports batched plotting - _batched = False - def __init__(self, element, keys=None, ranges=None, dimensions=None, batched=False, overlaid=0, cyclic_index=0, zorder=0, style=None, overlay_dims={}, **params): @@ -509,7 +506,7 @@ def __init__(self, element, keys=None, ranges=None, dimensions=None, super(GenericElementPlot, self).__init__(keys=keys, dimensions=dimensions, dynamic=dynamic, **dict(params, **plot_opts)) - if self.batched and self._batched: + if self.batched: self.ordering = util.layer_sort(self.hmap) self.style = self.lookup_options(self.hmap.last.last, 'style').max_cycles(len(self.ordering)) else: @@ -723,7 +720,7 @@ def _create_subplots(self, ranges): batched = self.batched and type(self.hmap.last) is NdOverlay if batched: batchedplot = registry.get(type(self.hmap.last.last)) - if (batched and batchedplot and batchedplot._batched and + if (batched and batchedplot and 'batched' in batchedplot._plot_methods and (not self.show_legend or len(ordering) > self.legend_limit)): self.batched = True keys, vmaps = [()], [self.hmap]