From 37ac89d664ad434e5c9a2344bc06a920c20c871b Mon Sep 17 00:00:00 2001
From: Philipp Rudiger
Date: Thu, 16 Jun 2016 13:34:51 +0200
Subject: [PATCH] Unified how plots declare plot methods across bokeh and
matplotlib
---
holoviews/plotting/bokeh/annotation.py | 4 ++--
holoviews/plotting/bokeh/chart.py | 13 +++++------
holoviews/plotting/bokeh/element.py | 11 +++++-----
holoviews/plotting/bokeh/path.py | 6 ++----
holoviews/plotting/bokeh/raster.py | 8 +++----
holoviews/plotting/mpl/chart.py | 30 +++++++++++---------------
holoviews/plotting/mpl/chart3d.py | 12 ++++-------
holoviews/plotting/mpl/element.py | 12 +++++++++++
holoviews/plotting/mpl/raster.py | 9 +++-----
holoviews/plotting/plot.py | 7 ++----
10 files changed, 52 insertions(+), 60 deletions(-)
diff --git a/holoviews/plotting/bokeh/annotation.py b/holoviews/plotting/bokeh/annotation.py
index 5e5cee9780..9bdeee52a4 100644
--- a/holoviews/plotting/bokeh/annotation.py
+++ b/holoviews/plotting/bokeh/annotation.py
@@ -8,7 +8,7 @@
class TextPlot(ElementPlot):
style_opts = text_properties
- _plot_method = 'text'
+ _plot_methods = dict(single='text')
def get_data(self, element, ranges=None, empty=False):
mapping = dict(x='x', y='y', text='text')
@@ -62,7 +62,7 @@ class SplinePlot(ElementPlot):
"""
style_opts = line_properties
- _plot_method = 'bezier'
+ _plot_methods = dict(single='bezier')
def get_data(self, element, ranges=None, empty=False):
data_attrs = ['x0', 'y0', 'x1', 'y1',
diff --git a/holoviews/plotting/bokeh/chart.py b/holoviews/plotting/bokeh/chart.py
index a300b7fadf..8d3b95d969 100644
--- a/holoviews/plotting/bokeh/chart.py
+++ b/holoviews/plotting/bokeh/chart.py
@@ -46,9 +46,7 @@ class PointPlot(ElementPlot):
'unselected_color'] +
line_properties + fill_properties)
- _plot_method = 'scatter'
- _batched_plot_method = 'scatter'
- _batched = True
+ _plot_methods = dict(single='scatter', batched='scatter')
def get_data(self, element, ranges=None, empty=False):
style = self.style[self.cyclic_index]
@@ -124,17 +122,16 @@ def _init_glyph(self, plot, mapping, properties):
renderer = plot.add_glyph(source, selected, selection_glyph=selected,
nonselection_glyph=unselected)
else:
- renderer = getattr(plot, self._plot_method)(**dict(properties, **mapping))
+ plot_method = self._plot_methods.get('batched' if self.batched else 'single')
+ renderer = getattr(plot, plot_method)(**dict(properties, **mapping))
return renderer, renderer.glyph
class CurvePlot(ElementPlot):
style_opts = ['color'] + line_properties
- _plot_method = 'line'
- _batched_plot_method = 'multi_line'
+ _plot_methods = dict(single='line', batched='multi_line')
_mapping = {p: p for p in ['xs', 'ys', 'color', 'line_alpha']}
- _batched = True
def get_data(self, element, ranges=None, empty=False):
x = element.get_dimension(0).name
@@ -214,7 +211,7 @@ def get_data(self, element, ranges=None, empty=None):
class HistogramPlot(ElementPlot):
style_opts = ['color'] + line_properties + fill_properties
- _plot_method = 'quad'
+ _plot_methods = dict(single='quad')
def get_data(self, element, ranges=None, empty=None):
mapping = dict(top='top', bottom=0, left='left', right='right')
diff --git a/holoviews/plotting/bokeh/element.py b/holoviews/plotting/bokeh/element.py
index 5edeee0f02..8a25774daf 100644
--- a/holoviews/plotting/bokeh/element.py
+++ b/holoviews/plotting/bokeh/element.py
@@ -137,10 +137,11 @@ class ElementPlot(BokehPlot, GenericElementPlot):
tick locations or bokeh Ticker object. If set to None
default bokeh ticking behavior is applied.""")
- # A string corresponding to the glyph being drawn by the
- # ElementPlot
- _plot_method = None
- _batched = False
+ # A dictionary mapping of the plot methods used to draw the
+ # glyphs corresponding to the ElementPlot, can support two
+ # keyword arguments a 'single' implementation to draw an individual
+ # plot and a 'batched' method to draw multiple Elements at once
+ _plot_methods = {}
# The plot objects to be updated on each frame
# Any entries should be existing keys in the handles
@@ -398,7 +399,7 @@ def _init_glyph(self, plot, mapping, properties):
Returns a Bokeh glyph object.
"""
properties = mpl_to_bokeh(properties)
- plot_method = self._batched_plot_method if self.batched else self._plot_method
+ plot_method = self._plot_methods.get('batched' if self.batched else 'single')
renderer = getattr(plot, plot_method)(**dict(properties, **mapping))
return renderer, renderer.glyph
diff --git a/holoviews/plotting/bokeh/path.py b/holoviews/plotting/bokeh/path.py
index 6566733e14..219b810f09 100644
--- a/holoviews/plotting/bokeh/path.py
+++ b/holoviews/plotting/bokeh/path.py
@@ -15,7 +15,7 @@ class PathPlot(ElementPlot):
Whether to show legend for the plot.""")
style_opts = ['color'] + line_properties
- _plot_method = 'multi_line'
+ _plot_methods = dict(single='multi_line')
_mapping = dict(xs='xs', ys='ys')
def get_data(self, element, ranges=None, empty=False):
@@ -27,9 +27,7 @@ def get_data(self, element, ranges=None, empty=False):
class PolygonPlot(PathPlot):
style_opts = ['color', 'cmap', 'palette'] + line_properties + fill_properties
- _plot_method = 'patches'
- _batched_plot_method = 'patches'
- _batched = True
+ _plot_methods = dict(single='patches', batched='patches')
def get_data(self, element, ranges=None, empty=False):
xs = [] if empty else [path[:, 0] for path in element.data]
diff --git a/holoviews/plotting/bokeh/raster.py b/holoviews/plotting/bokeh/raster.py
index e3bfac5317..71bbcfc94a 100644
--- a/holoviews/plotting/bokeh/raster.py
+++ b/holoviews/plotting/bokeh/raster.py
@@ -17,7 +17,7 @@ class RasterPlot(ElementPlot):
Whether to show legend for the plot.""")
style_opts = ['cmap']
- _plot_method = 'image'
+ _plot_methods = dict(single='image')
_update_handles = ['color_mapper', 'source', 'glyph']
def __init__(self, *args, **kwargs):
@@ -74,7 +74,7 @@ def _update_glyph(self, glyph, properties, mapping):
class RGBPlot(RasterPlot):
style_opts = []
- _plot_method = 'image_rgba'
+ _plot_methods = dict(single='image_rgba')
def get_data(self, element, ranges=None, empty=False):
data, mapping = super(RGBPlot, self).get_data(element, ranges, empty)
@@ -113,7 +113,7 @@ class HeatmapPlot(ElementPlot):
show_legend = param.Boolean(default=False, doc="""
Whether to show legend for the plot.""")
- _plot_method = 'rect'
+ _plot_methods = dict(single='rect')
style_opts = ['cmap', 'color'] + line_properties + fill_properties
def _axes_props(self, plots, subplots, element, ranges):
@@ -148,7 +148,7 @@ class QuadMeshPlot(ElementPlot):
show_legend = param.Boolean(default=False, doc="""
Whether to show legend for the plot.""")
- _plot_method = 'rect'
+ _plot_methods = dict(single='rect')
style_opts = ['cmap', 'color'] + line_properties + fill_properties
def get_data(self, element, ranges=None, empty=False):
diff --git a/holoviews/plotting/mpl/chart.py b/holoviews/plotting/mpl/chart.py
index 28fc7e30cc..3ee4d14119 100644
--- a/holoviews/plotting/mpl/chart.py
+++ b/holoviews/plotting/mpl/chart.py
@@ -54,9 +54,7 @@ class CurvePlot(ChartPlot):
style_opts = ['alpha', 'color', 'visible', 'linewidth', 'linestyle', 'marker']
- def init_artists(self, ax, plot_data, plot_kwargs):
- return {'artist': ax.plot(*plot_data, **plot_kwargs)[0]}
-
+ _plot_methods = dict(single='plot')
def get_data(self, element, ranges, style):
xs = element.dimension_values(0)
@@ -87,6 +85,8 @@ class ErrorPlot(ChartPlot):
'markerfacecolor', 'markersize', 'solid_capstyle',
'solid_joinstyle', 'dashes', 'color']
+ _plot_methods = dict(single='errorbar')
+
def init_artists(self, ax, plot_data, plot_kwargs):
_, (bottoms, tops), verts = ax.errorbar(*plot_data, **plot_kwargs)
return {'bottoms': bottoms, 'tops': tops, 'verts': verts[0]}
@@ -143,6 +143,8 @@ class AreaPlot(ChartPlot):
'hatch', 'linestyle', 'joinstyle',
'fill', 'capstyle', 'interpolate']
+ _plot_methods = dict(single='fill_between')
+
def get_data(self, element, ranges, style):
xs = element.dimension_values(0)
ys = [element.dimension_values(vdim) for vdim in element.vdims]
@@ -455,10 +457,7 @@ class PointPlot(ChartPlot, ColorbarPlot):
'cmap', 'vmin', 'vmax']
_disabled_opts = ['size']
-
- def init_artists(self, ax, plot_args, plot_kwargs):
- return {'artist': ax.scatter(*plot_args, **plot_kwargs)}
-
+ _plot_methods = dict(single='scatter')
def get_data(self, element, ranges, style):
xs, ys = (element.dimension_values(i) for i in range(2))
@@ -546,6 +545,8 @@ class VectorFieldPlot(ColorbarPlot):
'scale', 'headlength', 'headaxislength', 'pivot',
'width','headwidth']
+ _plot_methods = dict(single='quiver')
+
def __init__(self, *args, **params):
super(VectorFieldPlot, self).__init__(*args, **params)
self._min_dist = self._get_map_info(self.hmap)
@@ -598,16 +599,12 @@ def get_data(self, element, ranges, style):
if 'pivot' not in style: style['pivot'] = 'mid'
if not self.arrow_heads:
style['headaxislength'] = 0
- style.update(dict(scale=input_scale, angles=angles))
+ style.update(dict(scale=scale, angles=angles,
+ units='x', scale_units='x'))
return args, style, {}
- def init_artists(self, ax, plot_args, plot_kwargs):
- quiver = ax.quiver(*plot_args, units='x', scale_units='x', **plot_kwargs)
- return {'artist': quiver}
-
-
def update_handles(self, key, axis, element, ranges, style):
args, style, axis_kwargs = self.get_data(element, ranges, style)
@@ -960,6 +957,8 @@ class BoxPlot(ChartPlot):
'whiskerprops', 'capprops', 'flierprops',
'medianprops', 'meanprops', 'meanline']
+ _plot_methods = dict(single='boxplot')
+
def get_extents(self, element, ranges):
return (np.NaN,)*4
@@ -987,11 +986,6 @@ def get_data(self, element, ranges, style):
element.vdims[0]]}
- def init_artists(self, ax, plot_args, plot_kwargs):
- boxplot = ax.boxplot(*plot_args, **plot_kwargs)
- return {'artist': boxplot}
-
-
def teardown_handles(self):
for group in self.handles['artist'].values():
for v in group:
diff --git a/holoviews/plotting/mpl/chart3d.py b/holoviews/plotting/mpl/chart3d.py
index 3c0623545e..342c78be23 100644
--- a/holoviews/plotting/mpl/chart3d.py
+++ b/holoviews/plotting/mpl/chart3d.py
@@ -115,6 +115,8 @@ class Scatter3DPlot(Plot3D, PointPlot):
allow_None=True, doc="""
Index of the dimension from which the sizes will the drawn.""")
+ _plot_methods = dict(single='scatter')
+
def get_data(self, element, ranges, style):
xs, ys, zs = (element.dimension_values(i) for i in range(3))
self._compute_styles(element, ranges, style)
@@ -127,11 +129,6 @@ def get_data(self, element, ranges, style):
style['facecolors'] = color
return (xs, ys, zs), style, {}
- def init_artists(self, ax, plot_data, plot_kwargs):
- scatterplot = ax.scatter(*plot_data, **plot_kwargs)
- ax.add_collection(scatterplot)
- return {'artist': scatterplot}
-
def update_handles(self, key, axis, element, ranges, style):
artist = self.handles['artist']
artist._offsets3d, style, _ = self.get_data(element, ranges, style)
@@ -195,11 +192,10 @@ class TrisurfacePlot(Plot3D):
style_opts = ['cmap', 'color', 'shade', 'linewidth', 'edgecolor']
+ _plot_methods = dict(single='plot_trisurf')
+
def get_data(self, element, ranges, style):
dims = element.dimensions()
self._norm_kwargs(element, ranges, style, dims[2])
x, y, z = [element.dimension_values(d) for d in dims]
return (x, y, z), style, {}
-
- def init_artists(self, ax, plot_data, plot_kwargs):
- return {'artist': ax.plot_trisurf(*plot_data, **plot_kwargs)}
diff --git a/holoviews/plotting/mpl/element.py b/holoviews/plotting/mpl/element.py
index 7e13829a06..a481e17dd3 100644
--- a/holoviews/plotting/mpl/element.py
+++ b/holoviews/plotting/mpl/element.py
@@ -468,6 +468,18 @@ def initialize_plot(self, ranges=None):
return self._finalize_axis(self.keys[-1], ranges=ranges, **axis_kwargs)
+ def init_artists(self, ax, plot_args, plot_kwargs):
+ """
+ Initializes the artist based on the plot method declared on
+ the plot.
+ """
+ plot_method = self._plot_methods.get('batched' if self.batched else 'single')
+ plot_fn = getattr(ax, plot_method)
+ artist = plot_fn(*plot_args, **plot_kwargs)
+ return {'artist': artist[0] if isinstance(artist, list) and
+ len(artist) == 1 else artist}
+
+
def update_handles(self, key, axis, element, ranges, style):
"""
Update the elements of the plot.
diff --git a/holoviews/plotting/mpl/raster.py b/holoviews/plotting/mpl/raster.py
index 2ece3309de..2e59a54d4d 100644
--- a/holoviews/plotting/mpl/raster.py
+++ b/holoviews/plotting/mpl/raster.py
@@ -30,6 +30,7 @@ class RasterPlot(ColorbarPlot):
style_opts = ['alpha', 'cmap', 'interpolation', 'visible',
'filterrad', 'clims', 'norm']
+ _plot_methods = dict(single='imshow')
def __init__(self, *args, **kwargs):
super(RasterPlot, self).__init__(*args, **kwargs)
@@ -74,12 +75,6 @@ def get_data(self, element, ranges, style):
return [data], style, {'xticks': xticks, 'yticks': yticks}
-
- def init_artists(self, ax, plot_args, plot_kwargs):
- im = ax.imshow(*plot_args, **plot_kwargs)
- return {'artist': im}
-
-
def update_handles(self, key, axis, element, ranges, style):
im = self.handles['artist']
data, style, axis_kwargs = self.get_data(element, ranges, style)
@@ -192,6 +187,8 @@ class QuadMeshPlot(ColorbarPlot):
style_opts = ['alpha', 'cmap', 'clim', 'edgecolors', 'norm', 'shading',
'linestyles', 'linewidths', 'hatch', 'visible']
+ _plot_methods = dict(single='pcolormesh')
+
def get_data(self, element, ranges, style):
data = np.ma.array(element.data[2],
mask=np.logical_not(np.isfinite(element.data[2])))
diff --git a/holoviews/plotting/plot.py b/holoviews/plotting/plot.py
index 4a91f21c2b..a032ade413 100644
--- a/holoviews/plotting/plot.py
+++ b/holoviews/plotting/plot.py
@@ -483,9 +483,6 @@ class GenericElementPlot(DimensionedPlot):
apply_extents = param.Boolean(default=True, doc="""
Whether to apply extent overrides on the Elements""")
- # Whether the plotting class supports batched plotting
- _batched = False
-
def __init__(self, element, keys=None, ranges=None, dimensions=None,
batched=False, overlaid=0, cyclic_index=0, zorder=0, style=None,
overlay_dims={}, **params):
@@ -509,7 +506,7 @@ def __init__(self, element, keys=None, ranges=None, dimensions=None,
super(GenericElementPlot, self).__init__(keys=keys, dimensions=dimensions,
dynamic=dynamic,
**dict(params, **plot_opts))
- if self.batched and self._batched:
+ if self.batched:
self.ordering = util.layer_sort(self.hmap)
self.style = self.lookup_options(self.hmap.last.last, 'style').max_cycles(len(self.ordering))
else:
@@ -723,7 +720,7 @@ def _create_subplots(self, ranges):
batched = self.batched and type(self.hmap.last) is NdOverlay
if batched:
batchedplot = registry.get(type(self.hmap.last.last))
- if (batched and batchedplot and batchedplot._batched and
+ if (batched and batchedplot and 'batched' in batchedplot._plot_methods and
(not self.show_legend or len(ordering) > self.legend_limit)):
self.batched = True
keys, vmaps = [()], [self.hmap]