Skip to content

Commit

Permalink
Unified how plots declare plot methods across bokeh and matplotlib
Browse files Browse the repository at this point in the history
  • Loading branch information
philippjfr committed Jun 16, 2016
1 parent 2763871 commit 37ac89d
Show file tree
Hide file tree
Showing 10 changed files with 52 additions and 60 deletions.
4 changes: 2 additions & 2 deletions holoviews/plotting/bokeh/annotation.py
Expand Up @@ -8,7 +8,7 @@
class TextPlot(ElementPlot):

style_opts = text_properties
_plot_method = 'text'
_plot_methods = dict(single='text')

def get_data(self, element, ranges=None, empty=False):
mapping = dict(x='x', y='y', text='text')
Expand Down Expand Up @@ -62,7 +62,7 @@ class SplinePlot(ElementPlot):
"""

style_opts = line_properties
_plot_method = 'bezier'
_plot_methods = dict(single='bezier')

def get_data(self, element, ranges=None, empty=False):
data_attrs = ['x0', 'y0', 'x1', 'y1',
Expand Down
13 changes: 5 additions & 8 deletions holoviews/plotting/bokeh/chart.py
Expand Up @@ -46,9 +46,7 @@ class PointPlot(ElementPlot):
'unselected_color'] +
line_properties + fill_properties)

_plot_method = 'scatter'
_batched_plot_method = 'scatter'
_batched = True
_plot_methods = dict(single='scatter', batched='scatter')

def get_data(self, element, ranges=None, empty=False):
style = self.style[self.cyclic_index]
Expand Down Expand Up @@ -124,17 +122,16 @@ def _init_glyph(self, plot, mapping, properties):
renderer = plot.add_glyph(source, selected, selection_glyph=selected,
nonselection_glyph=unselected)
else:
renderer = getattr(plot, self._plot_method)(**dict(properties, **mapping))
plot_method = self._plot_methods.get('batched' if self.batched else 'single')
renderer = getattr(plot, plot_method)(**dict(properties, **mapping))
return renderer, renderer.glyph


class CurvePlot(ElementPlot):

style_opts = ['color'] + line_properties
_plot_method = 'line'
_batched_plot_method = 'multi_line'
_plot_methods = dict(single='line', batched='multi_line')
_mapping = {p: p for p in ['xs', 'ys', 'color', 'line_alpha']}
_batched = True

def get_data(self, element, ranges=None, empty=False):
x = element.get_dimension(0).name
Expand Down Expand Up @@ -214,7 +211,7 @@ def get_data(self, element, ranges=None, empty=None):
class HistogramPlot(ElementPlot):

style_opts = ['color'] + line_properties + fill_properties
_plot_method = 'quad'
_plot_methods = dict(single='quad')

def get_data(self, element, ranges=None, empty=None):
mapping = dict(top='top', bottom=0, left='left', right='right')
Expand Down
11 changes: 6 additions & 5 deletions holoviews/plotting/bokeh/element.py
Expand Up @@ -137,10 +137,11 @@ class ElementPlot(BokehPlot, GenericElementPlot):
tick locations or bokeh Ticker object. If set to None
default bokeh ticking behavior is applied.""")

# A string corresponding to the glyph being drawn by the
# ElementPlot
_plot_method = None
_batched = False
# A dictionary mapping of the plot methods used to draw the
# glyphs corresponding to the ElementPlot, can support two
# keyword arguments a 'single' implementation to draw an individual
# plot and a 'batched' method to draw multiple Elements at once
_plot_methods = {}

# The plot objects to be updated on each frame
# Any entries should be existing keys in the handles
Expand Down Expand Up @@ -398,7 +399,7 @@ def _init_glyph(self, plot, mapping, properties):
Returns a Bokeh glyph object.
"""
properties = mpl_to_bokeh(properties)
plot_method = self._batched_plot_method if self.batched else self._plot_method
plot_method = self._plot_methods.get('batched' if self.batched else 'single')
renderer = getattr(plot, plot_method)(**dict(properties, **mapping))
return renderer, renderer.glyph

Expand Down
6 changes: 2 additions & 4 deletions holoviews/plotting/bokeh/path.py
Expand Up @@ -15,7 +15,7 @@ class PathPlot(ElementPlot):
Whether to show legend for the plot.""")

style_opts = ['color'] + line_properties
_plot_method = 'multi_line'
_plot_methods = dict(single='multi_line')
_mapping = dict(xs='xs', ys='ys')

def get_data(self, element, ranges=None, empty=False):
Expand All @@ -27,9 +27,7 @@ def get_data(self, element, ranges=None, empty=False):
class PolygonPlot(PathPlot):

style_opts = ['color', 'cmap', 'palette'] + line_properties + fill_properties
_plot_method = 'patches'
_batched_plot_method = 'patches'
_batched = True
_plot_methods = dict(single='patches', batched='patches')

def get_data(self, element, ranges=None, empty=False):
xs = [] if empty else [path[:, 0] for path in element.data]
Expand Down
8 changes: 4 additions & 4 deletions holoviews/plotting/bokeh/raster.py
Expand Up @@ -17,7 +17,7 @@ class RasterPlot(ElementPlot):
Whether to show legend for the plot.""")

style_opts = ['cmap']
_plot_method = 'image'
_plot_methods = dict(single='image')
_update_handles = ['color_mapper', 'source', 'glyph']

def __init__(self, *args, **kwargs):
Expand Down Expand Up @@ -74,7 +74,7 @@ def _update_glyph(self, glyph, properties, mapping):
class RGBPlot(RasterPlot):

style_opts = []
_plot_method = 'image_rgba'
_plot_methods = dict(single='image_rgba')

def get_data(self, element, ranges=None, empty=False):
data, mapping = super(RGBPlot, self).get_data(element, ranges, empty)
Expand Down Expand Up @@ -113,7 +113,7 @@ class HeatmapPlot(ElementPlot):
show_legend = param.Boolean(default=False, doc="""
Whether to show legend for the plot.""")

_plot_method = 'rect'
_plot_methods = dict(single='rect')
style_opts = ['cmap', 'color'] + line_properties + fill_properties

def _axes_props(self, plots, subplots, element, ranges):
Expand Down Expand Up @@ -148,7 +148,7 @@ class QuadMeshPlot(ElementPlot):
show_legend = param.Boolean(default=False, doc="""
Whether to show legend for the plot.""")

_plot_method = 'rect'
_plot_methods = dict(single='rect')
style_opts = ['cmap', 'color'] + line_properties + fill_properties

def get_data(self, element, ranges=None, empty=False):
Expand Down
30 changes: 12 additions & 18 deletions holoviews/plotting/mpl/chart.py
Expand Up @@ -54,9 +54,7 @@ class CurvePlot(ChartPlot):

style_opts = ['alpha', 'color', 'visible', 'linewidth', 'linestyle', 'marker']

def init_artists(self, ax, plot_data, plot_kwargs):
return {'artist': ax.plot(*plot_data, **plot_kwargs)[0]}

_plot_methods = dict(single='plot')

def get_data(self, element, ranges, style):
xs = element.dimension_values(0)
Expand Down Expand Up @@ -87,6 +85,8 @@ class ErrorPlot(ChartPlot):
'markerfacecolor', 'markersize', 'solid_capstyle',
'solid_joinstyle', 'dashes', 'color']

_plot_methods = dict(single='errorbar')

def init_artists(self, ax, plot_data, plot_kwargs):
_, (bottoms, tops), verts = ax.errorbar(*plot_data, **plot_kwargs)
return {'bottoms': bottoms, 'tops': tops, 'verts': verts[0]}
Expand Down Expand Up @@ -143,6 +143,8 @@ class AreaPlot(ChartPlot):
'hatch', 'linestyle', 'joinstyle',
'fill', 'capstyle', 'interpolate']

_plot_methods = dict(single='fill_between')

def get_data(self, element, ranges, style):
xs = element.dimension_values(0)
ys = [element.dimension_values(vdim) for vdim in element.vdims]
Expand Down Expand Up @@ -455,10 +457,7 @@ class PointPlot(ChartPlot, ColorbarPlot):
'cmap', 'vmin', 'vmax']

_disabled_opts = ['size']

def init_artists(self, ax, plot_args, plot_kwargs):
return {'artist': ax.scatter(*plot_args, **plot_kwargs)}

_plot_methods = dict(single='scatter')

def get_data(self, element, ranges, style):
xs, ys = (element.dimension_values(i) for i in range(2))
Expand Down Expand Up @@ -546,6 +545,8 @@ class VectorFieldPlot(ColorbarPlot):
'scale', 'headlength', 'headaxislength', 'pivot',
'width','headwidth']

_plot_methods = dict(single='quiver')

def __init__(self, *args, **params):
super(VectorFieldPlot, self).__init__(*args, **params)
self._min_dist = self._get_map_info(self.hmap)
Expand Down Expand Up @@ -598,16 +599,12 @@ def get_data(self, element, ranges, style):
if 'pivot' not in style: style['pivot'] = 'mid'
if not self.arrow_heads:
style['headaxislength'] = 0
style.update(dict(scale=input_scale, angles=angles))
style.update(dict(scale=scale, angles=angles,
units='x', scale_units='x'))

return args, style, {}


def init_artists(self, ax, plot_args, plot_kwargs):
quiver = ax.quiver(*plot_args, units='x', scale_units='x', **plot_kwargs)
return {'artist': quiver}


def update_handles(self, key, axis, element, ranges, style):
args, style, axis_kwargs = self.get_data(element, ranges, style)

Expand Down Expand Up @@ -960,6 +957,8 @@ class BoxPlot(ChartPlot):
'whiskerprops', 'capprops', 'flierprops',
'medianprops', 'meanprops', 'meanline']

_plot_methods = dict(single='boxplot')

def get_extents(self, element, ranges):
return (np.NaN,)*4

Expand Down Expand Up @@ -987,11 +986,6 @@ def get_data(self, element, ranges, style):
element.vdims[0]]}


def init_artists(self, ax, plot_args, plot_kwargs):
boxplot = ax.boxplot(*plot_args, **plot_kwargs)
return {'artist': boxplot}


def teardown_handles(self):
for group in self.handles['artist'].values():
for v in group:
Expand Down
12 changes: 4 additions & 8 deletions holoviews/plotting/mpl/chart3d.py
Expand Up @@ -115,6 +115,8 @@ class Scatter3DPlot(Plot3D, PointPlot):
allow_None=True, doc="""
Index of the dimension from which the sizes will the drawn.""")

_plot_methods = dict(single='scatter')

def get_data(self, element, ranges, style):
xs, ys, zs = (element.dimension_values(i) for i in range(3))
self._compute_styles(element, ranges, style)
Expand All @@ -127,11 +129,6 @@ def get_data(self, element, ranges, style):
style['facecolors'] = color
return (xs, ys, zs), style, {}

def init_artists(self, ax, plot_data, plot_kwargs):
scatterplot = ax.scatter(*plot_data, **plot_kwargs)
ax.add_collection(scatterplot)
return {'artist': scatterplot}

def update_handles(self, key, axis, element, ranges, style):
artist = self.handles['artist']
artist._offsets3d, style, _ = self.get_data(element, ranges, style)
Expand Down Expand Up @@ -195,11 +192,10 @@ class TrisurfacePlot(Plot3D):

style_opts = ['cmap', 'color', 'shade', 'linewidth', 'edgecolor']

_plot_methods = dict(single='plot_trisurf')

def get_data(self, element, ranges, style):
dims = element.dimensions()
self._norm_kwargs(element, ranges, style, dims[2])
x, y, z = [element.dimension_values(d) for d in dims]
return (x, y, z), style, {}

def init_artists(self, ax, plot_data, plot_kwargs):
return {'artist': ax.plot_trisurf(*plot_data, **plot_kwargs)}
12 changes: 12 additions & 0 deletions holoviews/plotting/mpl/element.py
Expand Up @@ -468,6 +468,18 @@ def initialize_plot(self, ranges=None):
return self._finalize_axis(self.keys[-1], ranges=ranges, **axis_kwargs)


def init_artists(self, ax, plot_args, plot_kwargs):
"""
Initializes the artist based on the plot method declared on
the plot.
"""
plot_method = self._plot_methods.get('batched' if self.batched else 'single')
plot_fn = getattr(ax, plot_method)
artist = plot_fn(*plot_args, **plot_kwargs)
return {'artist': artist[0] if isinstance(artist, list) and
len(artist) == 1 else artist}


def update_handles(self, key, axis, element, ranges, style):
"""
Update the elements of the plot.
Expand Down
9 changes: 3 additions & 6 deletions holoviews/plotting/mpl/raster.py
Expand Up @@ -30,6 +30,7 @@ class RasterPlot(ColorbarPlot):
style_opts = ['alpha', 'cmap', 'interpolation', 'visible',
'filterrad', 'clims', 'norm']

_plot_methods = dict(single='imshow')

def __init__(self, *args, **kwargs):
super(RasterPlot, self).__init__(*args, **kwargs)
Expand Down Expand Up @@ -74,12 +75,6 @@ def get_data(self, element, ranges, style):

return [data], style, {'xticks': xticks, 'yticks': yticks}


def init_artists(self, ax, plot_args, plot_kwargs):
im = ax.imshow(*plot_args, **plot_kwargs)
return {'artist': im}


def update_handles(self, key, axis, element, ranges, style):
im = self.handles['artist']
data, style, axis_kwargs = self.get_data(element, ranges, style)
Expand Down Expand Up @@ -192,6 +187,8 @@ class QuadMeshPlot(ColorbarPlot):
style_opts = ['alpha', 'cmap', 'clim', 'edgecolors', 'norm', 'shading',
'linestyles', 'linewidths', 'hatch', 'visible']

_plot_methods = dict(single='pcolormesh')

def get_data(self, element, ranges, style):
data = np.ma.array(element.data[2],
mask=np.logical_not(np.isfinite(element.data[2])))
Expand Down
7 changes: 2 additions & 5 deletions holoviews/plotting/plot.py
Expand Up @@ -483,9 +483,6 @@ class GenericElementPlot(DimensionedPlot):
apply_extents = param.Boolean(default=True, doc="""
Whether to apply extent overrides on the Elements""")

# Whether the plotting class supports batched plotting
_batched = False

def __init__(self, element, keys=None, ranges=None, dimensions=None,
batched=False, overlaid=0, cyclic_index=0, zorder=0, style=None,
overlay_dims={}, **params):
Expand All @@ -509,7 +506,7 @@ def __init__(self, element, keys=None, ranges=None, dimensions=None,
super(GenericElementPlot, self).__init__(keys=keys, dimensions=dimensions,
dynamic=dynamic,
**dict(params, **plot_opts))
if self.batched and self._batched:
if self.batched:
self.ordering = util.layer_sort(self.hmap)
self.style = self.lookup_options(self.hmap.last.last, 'style').max_cycles(len(self.ordering))
else:
Expand Down Expand Up @@ -723,7 +720,7 @@ def _create_subplots(self, ranges):
batched = self.batched and type(self.hmap.last) is NdOverlay
if batched:
batchedplot = registry.get(type(self.hmap.last.last))
if (batched and batchedplot and batchedplot._batched and
if (batched and batchedplot and 'batched' in batchedplot._plot_methods and
(not self.show_legend or len(ordering) > self.legend_limit)):
self.batched = True
keys, vmaps = [()], [self.hmap]
Expand Down

0 comments on commit 37ac89d

Please sign in to comment.