diff --git a/holoviews/plotting/bokeh/annotation.py b/holoviews/plotting/bokeh/annotation.py index f738b3fb26..22a0d9ab11 100644 --- a/holoviews/plotting/bokeh/annotation.py +++ b/holoviews/plotting/bokeh/annotation.py @@ -22,33 +22,29 @@ class TextPlot(ElementPlot): style_opts = text_properties+['color'] _plot_methods = dict(single='text', batched='text') - def _glyph_properties(self, plot, element, source, ranges): - props = super(TextPlot, self)._glyph_properties(plot, element, source, ranges) - props['text_align'] = element.halign - props['text_baseline'] = 'middle' if element.valign == 'center' else element.valign - if 'color' in props: - props['text_color'] = props.pop('color') - return props - - def get_data(self, element, ranges=None): + def get_data(self, element, ranges, style): mapping = dict(x='x', y='y', text='text') if self.static_source: - return dict(x=[], y=[], text=[]), mapping + return dict(x=[], y=[], text=[]), mapping, style if self.invert_axes: data = dict(x=[element.y], y=[element.x]) else: data = dict(x=[element.x], y=[element.y]) self._categorize_data(data, ('x', 'y'), element.dimensions()) data['text'] = [element.text] - return (data, mapping) + style['text_align'] = element.halign + style['text_baseline'] = 'middle' if element.valign == 'center' else element.valign + if 'color' in style: + style['text_color'] = style.pop('color') + return (data, mapping, style) def get_batched_data(self, element, ranges=None): data = defaultdict(list) for key, el in element.data.items(): - eldata, elmapping = self.get_data(el, ranges) + eldata, elmapping, style = self.get_data(el, ranges) for k, eld in eldata.items(): data[k].extend(eld) - return data, elmapping + return data, elmapping, style def get_extents(self, element, ranges=None): return None, None, None, None @@ -63,7 +59,7 @@ class LineAnnotationPlot(ElementPlot): _plot_methods = dict(single='Span') - def get_data(self, element, ranges=None): + def get_data(self, element, ranges, style): data, mapping = {}, {} dim = 'width' if isinstance(element, HLine) else 'height' if self.invert_axes: @@ -73,7 +69,7 @@ def get_data(self, element, ranges=None): if isinstance(loc, datetime_types): loc = date_to_integer(loc) mapping['location'] = loc - return (data, mapping) + return (data, mapping, style) def _init_glyph(self, plot, mapping, properties): """ @@ -97,7 +93,7 @@ class SplinePlot(ElementPlot): style_opts = line_properties _plot_methods = dict(single='bezier') - def get_data(self, element, ranges=None): + def get_data(self, element, ranges, style): if self.invert_axes: data_attrs = ['y0', 'x0', 'cy0', 'cx0', 'cy1', 'cx1', 'y1', 'x1'] else: @@ -117,7 +113,7 @@ def get_data(self, element, ranges=None): self.warning('Bokeh SplitPlot only support cubic splines, ' 'unsupported splines were skipped during plotting.') data = {da: data[da] for da in data_attrs} - return (data, dict(zip(data_attrs, data_attrs))) + return (data, dict(zip(data_attrs, data_attrs)), style) @@ -131,7 +127,7 @@ class ArrowPlot(CompositeElementPlot): _plot_methods = dict(single='text') - def get_data(self, element, ranges=None): + def get_data(self, element, ranges, style): plot = self.state label_mapping = dict(x='x', y='y', text='text') @@ -167,7 +163,7 @@ def get_data(self, element, ranges=None): label_data = dict(x=[x2], y=[y2]) label_data['text'] = [element.text] return ({'label': label_data}, - {'arrow': arrow_opts, 'label': label_mapping}) + {'arrow': arrow_opts, 'label': label_mapping}, style) def _init_glyph(self, plot, mapping, properties, key): """ diff --git a/holoviews/plotting/bokeh/chart.py b/holoviews/plotting/bokeh/chart.py index 02855aa5f5..779bbe273a 100644 --- a/holoviews/plotting/bokeh/chart.py +++ b/holoviews/plotting/bokeh/chart.py @@ -71,8 +71,7 @@ def _get_size_data(self, element, ranges, style): return data, mapping - def get_data(self, element, ranges=None): - style = self.style[self.cyclic_index] + def get_data(self, element, ranges, style): dims = element.dimensions(label=True) xidx, yidx = (1, 0) if self.invert_axes else (0, 1) @@ -94,17 +93,17 @@ def get_data(self, element, ranges=None): mapping.update(smapping) self._get_hover_data(data, element) - return data, mapping + return data, mapping, style - def get_batched_data(self, element, ranges=None): + def get_batched_data(self, element, ranges): data = defaultdict(list) zorders = self._updated_zorders(element) - styles = self.lookup_options(element.last, 'style') - styles = styles.max_cycles(len(self.ordering)) for (key, el), zorder in zip(element.data.items(), zorders): self.set_param(**self.lookup_options(el, 'plot').options) - eldata, elmapping = self.get_data(el, ranges) + style = self.lookup_options(element.last, 'style') + style = style.max_cycles(len(self.ordering))[zorder] + eldata, elmapping, style = self.get_data(el, ranges, style) for k, eld in eldata.items(): data[k].append(eld) @@ -114,7 +113,6 @@ def get_batched_data(self, element, ranges=None): # Apply static styles nvals = len(list(eldata.values())[0]) - style = styles[zorder] sdata, smapping = expand_batched_style(style, self._batched_style_opts, elmapping, nvals) elmapping.update(smapping) @@ -127,7 +125,7 @@ def get_batched_data(self, element, ranges=None): data[sanitized].append([k]*nvals) data = {k: np.concatenate(v) for k, v in data.items()} - return data, elmapping + return data, elmapping, style @@ -185,8 +183,7 @@ def _glyph_properties(self, *args): return properties - def get_data(self, element, ranges=None): - style = self.style[self.cyclic_index] + def get_data(self, element, ranges, style): input_scale = style.pop('scale', 1.0) # Get x, y, angle, magnitude and color data @@ -242,7 +239,7 @@ def get_data(self, element, ranges=None): data[cdim.name] = color mapping.update(cmapping) - return (data, mapping) + return (data, mapping, style) @@ -259,12 +256,12 @@ class CurvePlot(ElementPlot): _plot_methods = dict(single='line', batched='multi_line') _batched_style_opts = line_properties - def get_data(self, element, ranges=None): + def get_data(self, element, ranges, style): xidx, yidx = (1, 0) if self.invert_axes else (0, 1) x = element.get_dimension(xidx).name y = element.get_dimension(yidx).name if self.static_source: - return {}, dict(x=x, y=y) + return {}, dict(x=x, y=y), style if 'steps' in self.interpolation: element = interpolate_curve(element, interpolation=self.interpolation) @@ -272,7 +269,7 @@ def get_data(self, element, ranges=None): y: element.dimension_values(yidx)} self._get_hover_data(data, element) self._categorize_data(data, (x, y), element.dimensions()) - return (data, dict(x=x, y=y)) + return (data, dict(x=x, y=y), style) def _hover_opts(self, element): if self.batched: @@ -283,16 +280,15 @@ def _hover_opts(self, element): line_policy = 'nearest' return dims, dict(line_policy=line_policy) - def get_batched_data(self, overlay, ranges=None): + def get_batched_data(self, overlay, ranges): data = defaultdict(list) zorders = self._updated_zorders(overlay) - styles = self.lookup_options(overlay.last, 'style') - styles = styles.max_cycles(len(self.ordering)) - for (key, el), zorder in zip(overlay.data.items(), zorders): self.set_param(**self.lookup_options(el, 'plot').options) - eldata, elmapping = self.get_data(el, ranges) + style = self.lookup_options(el, 'style') + style = style.max_cycles(len(self.ordering))[zorder] + eldata, elmapping, style = self.get_data(el, ranges, style) # Skip if data empty if not eldata: @@ -302,7 +298,6 @@ def get_batched_data(self, overlay, ranges=None): data[k].append(eld) # Apply static styles - style = styles[zorder] sdata, smapping = expand_batched_style(style, self._batched_style_opts, elmapping, nvals=1) elmapping.update(smapping) @@ -316,7 +311,7 @@ def get_batched_data(self, overlay, ranges=None): if not any(v is None for v in vals)} mapping = {{'x': 'xs', 'y': 'ys'}.get(k, k): v for k, v in elmapping.items()} - return data, mapping + return data, mapping, style @@ -325,7 +320,7 @@ class HistogramPlot(ElementPlot): style_opts = line_properties + fill_properties _plot_methods = dict(single='quad') - def get_data(self, element, ranges=None): + def get_data(self, element, ranges, style): if self.invert_axes: mapping = dict(top='left', bottom='right', left=0, right='top') else: @@ -336,7 +331,7 @@ def get_data(self, element, ranges=None): data = dict(top=element.values, left=element.edges[:-1], right=element.edges[1:]) self._get_hover_data(data, element) - return (data, mapping) + return (data, mapping, style) def get_extents(self, element, ranges): x0, y0, x1, y1 = super(HistogramPlot, self).get_extents(element, ranges) @@ -369,7 +364,7 @@ class SideHistogramPlot(ColorbarPlot, HistogramPlot): main_source.trigger('change') """ - def get_data(self, element, ranges=None): + def get_data(self, element, ranges, style): if self.invert_axes: mapping = dict(top='right', bottom='left', left=0, right='top') else: @@ -389,7 +384,7 @@ def get_data(self, element, ranges=None): mapping['fill_color'] = {'field': dim.name, 'transform': cmapper} self._get_hover_data(data, element) - return (data, mapping) + return (data, mapping, style) def _init_glyph(self, plot, mapping, properties): @@ -428,10 +423,10 @@ class ErrorPlot(ElementPlot): _plot_methods = dict(single=Whisker) - def get_data(self, element, ranges=None): + def get_data(self, element, ranges, style): mapping = dict(self._mapping) if self.static_source: - return {}, mapping + return {}, mapping, style base = element.dimension_values(0) ys = element.dimension_values(1) @@ -448,7 +443,7 @@ def get_data(self, element, ranges=None): else: mapping['dimension'] = 'height' self._categorize_data(data, ('base',), element.dimensions()) - return (data, mapping) + return (data, mapping, style) def _init_glyph(self, plot, mapping, properties): @@ -493,10 +488,10 @@ def get_extents(self, element, ranges): ranges[vdim] = (np.nanmin([0, ranges[vdim][0]]), ranges[vdim][1]) return super(AreaPlot, self).get_extents(element, ranges) - def get_data(self, element, ranges=None): + def get_data(self, element, ranges, style): mapping = dict(self._mapping) if self.static_source: - return {}, mapping + return {}, mapping, style xs = element.dimension_values(0) if len(element.vdims) > 1: @@ -510,7 +505,7 @@ def get_data(self, element, ranges=None): mapping['dimension'] = 'width' else: mapping['dimension'] = 'height' - return data, mapping + return data, mapping, style @@ -554,8 +549,7 @@ def get_extents(self, element, ranges): t = np.nanmax([0, t]) return l, b, r, t - def get_data(self, element, ranges=None): - style = self.style[self.cyclic_index] + def get_data(self, element, ranges, style): dims = element.dimensions(label=True) pos = self.position @@ -583,7 +577,7 @@ def get_data(self, element, ranges=None): for d in dims: data[dimension_sanitizer(d)] = element.dimension_values(d) - return data, mapping + return data, mapping, style class SideSpikesPlot(SpikesPlot): @@ -639,7 +633,7 @@ class BarPlot(ColorbarPlot, LegendPlot): style_opts = line_properties + fill_properties + ['width', 'cmap'] - _plot_methods = dict(single=('vbar', 'hbar'), batched=('vbar', 'hbar')) + _plot_methods = dict(single=('vbar', 'hbar')) # Declare that y-range should auto-range if not bounded _y_range_type = DataRange1d @@ -769,7 +763,7 @@ def _glyph_properties(self, *args): del props['width'] return props - def get_data(self, element, ranges): + def get_data(self, element, ranges, style): # Get x, y, group, stack and color dimensions grouping = None group_dim = element.get_dimension(self.group_index) @@ -793,7 +787,6 @@ def get_data(self, element, ranges): self.color_index = color_dim.name # Define style information - style = self.style[self.cyclic_index] width = style.get('width', 1) cmap = style.get('cmap') hover = any(t == 'hover' or isinstance(t, HoverTool) @@ -946,13 +939,7 @@ def get_data(self, element, ranges): mapping.update({'y': mapping.pop('x'), 'left': mapping.pop('bottom'), 'right': mapping.pop('top'), 'height': mapping.pop('width')}) - return sanitized_data, mapping - - def get_batched_data(self, element, ranges): - el = element.last - collapsed = Bars(element.table(), kdims=el.kdims+element.kdims, - vdims=el.vdims) - return self.get_data(collapsed, ranges) + return sanitized_data, mapping, style @@ -1001,8 +988,8 @@ def _get_axis_labels(self, *args, **kwargs): ylabel = element.vdims[0].pprint_label return xlabel, ylabel, None - def _glyph_properties(self, plot, element, source, ranges): - properties = dict(self.style[self.cyclic_index], source=source) + def _glyph_properties(self, plot, element, source, ranges, style): + properties = dict(style, source=source) if self.show_legend and not element.kdims: properties['legend'] = element.label return properties @@ -1025,12 +1012,11 @@ def _get_factors(self, element): xfactors, yfactors = factors, [] return (yfactors, xfactors) if self.invert_axes else (xfactors, yfactors) - def get_data(self, element, ranges=None): + def get_data(self, element, ranges, style): if element.kdims: groups = element.groupby(element.kdims).data else: groups = dict([(element.label, element)]) - style = self.style[self.cyclic_index] vdim = dimension_sanitizer(element.vdims[0].name) # Define CDS data @@ -1131,7 +1117,7 @@ def get_data(self, element, ranges=None): # Return if not grouped if not element.kdims: - return data, mapping + return data, mapping, style # Define color dimension and data if cidx is None or cidx>=element.ndims: @@ -1156,5 +1142,5 @@ def get_data(self, element, ranges=None): vbar2_map['fill_color'] = {'field': cname, 'transform': mapper} vbar_map['legend'] = cdim.name - return data, mapping + return data, mapping, style diff --git a/holoviews/plotting/bokeh/element.py b/holoviews/plotting/bokeh/element.py index e83f440cd6..57d0a2d4a0 100644 --- a/holoviews/plotting/bokeh/element.py +++ b/holoviews/plotting/bokeh/element.py @@ -656,9 +656,8 @@ def _init_glyph(self, plot, mapping, properties): return renderer, renderer.glyph - def _glyph_properties(self, plot, element, source, ranges): - properties = self.style[self.cyclic_index] - + def _glyph_properties(self, plot, element, source, ranges, style): + properties = dict(style, source=source) if self.show_legend: if self.overlay_dims: legend = ', '.join([d.pprint_value(v) for d, v in @@ -666,7 +665,6 @@ def _glyph_properties(self, plot, element, source, ranges): else: legend = element.label properties['legend'] = legend - properties['source'] = source return properties def _update_glyph(self, renderer, properties, mapping, glyph): @@ -737,16 +735,17 @@ def _init_glyphs(self, plot, element, ranges, source): # Get data and initialize data source if self.batched: current_id = tuple(element.traverse(lambda x: x._plot_id, [Element])) - data, mapping = self.get_batched_data(element, ranges) + data, mapping, style = self.get_batched_data(element, ranges) else: - data, mapping = self.get_data(element, ranges) + style = self.style[self.cyclic_index] + data, mapping, style = self.get_data(element, ranges, style) current_id = element._plot_id if source is None: source = self._init_datasource(data) self.handles['previous_id'] = current_id self.handles['source'] = source - properties = self._glyph_properties(plot, style_element, source, ranges) + properties = self._glyph_properties(plot, style_element, source, ranges, style) with abbreviated_exception(): renderer, glyph = self._init_glyph(plot, mapping, properties) self.handles['glyph'] = glyph @@ -819,16 +818,17 @@ def _update_glyphs(self, element, ranges): current_id = element._plot_id self.handles['previous_id'] = current_id self.static_source = (self.dynamic and (current_id == previous_id)) + style = self.style[self.cyclic_index] if self.batched: - data, mapping = self.get_batched_data(element, ranges) + data, mapping, style = self.get_batched_data(element, ranges) else: - data, mapping = self.get_data(element, ranges) + data, mapping, style = self.get_data(element, ranges, style) if not self.static_source: self._update_datasource(source, data) if glyph: - properties = self._glyph_properties(plot, element, source, ranges) + properties = self._glyph_properties(plot, element, source, ranges, style) renderer = self.handles.get('glyph_renderer') with abbreviated_exception(): self._update_glyph(renderer, properties, mapping, glyph) @@ -964,18 +964,15 @@ class CompositeElementPlot(ElementPlot): def _init_glyphs(self, plot, element, ranges, source): # Get data and initialize data source - if self.batched: - current_id = tuple(element.traverse(lambda x: x._plot_id, [Element])) - data, mapping = self.get_batched_data(element, ranges) - else: - data, mapping = self.get_data(element, ranges) - current_id = element._plot_id + style = self.style[self.cyclic_index] + data, mapping, style = self.get_data(element, ranges, style) + current_id = element._plot_id self.handles['previous_id'] = current_id for key in dict(mapping, **data): source = self._init_datasource(data.get(key, {})) self.handles[key+'_source'] = source - properties = self._glyph_properties(plot, element, source, ranges) + properties = self._glyph_properties(plot, element, source, ranges, style) properties = self._process_properties(key, properties) with abbreviated_exception(): renderer, glyph = self._init_glyph(plot, mapping.get(key, {}), properties, key) @@ -1015,7 +1012,8 @@ def _update_glyphs(self, element, ranges): current_id = element._plot_id self.handles['previous_id'] = current_id self.static_source = (self.dynamic and (current_id == previous_id)) - data, mapping = self.get_data(element, ranges) + style = self.style[self.cyclic_index] + data, mapping, style = self.get_data(element, ranges, style) for key in dict(mapping, **data): gdata = data[key] @@ -1025,7 +1023,7 @@ def _update_glyphs(self, element, ranges): self._update_datasource(source, gdata) if glyph: - properties = self._glyph_properties(plot, element, source, ranges) + properties = self._glyph_properties(plot, element, source, ranges, style) properties = self._process_properties(key, properties) renderer = self.handles.get(key+'_glyph_renderer') with abbreviated_exception(): diff --git a/holoviews/plotting/bokeh/graphs.py b/holoviews/plotting/bokeh/graphs.py index 505213aece..b633d442ea 100644 --- a/holoviews/plotting/bokeh/graphs.py +++ b/holoviews/plotting/bokeh/graphs.py @@ -84,8 +84,7 @@ def _get_axis_labels(self, *args, **kwargs): xlabel, ylabel = [kd.pprint_label for kd in element.nodes.kdims[:2]] return xlabel, ylabel, None - def get_data(self, element, ranges=None): - style = self.style[self.cyclic_index] + def get_data(self, element, ranges, style): xidx, yidx = (1, 0) if self.invert_axes else (0, 1) # Get node data @@ -135,7 +134,7 @@ def get_data(self, element, ranges=None): data = {'scatter_1': point_data, 'multi_line_1': path_data, 'layout': layout} mapping = {'scatter_1': point_mapping, 'multi_line_1': {}} - return data, mapping + return data, mapping, style def _update_datasource(self, source, data): @@ -150,14 +149,15 @@ def _update_datasource(self, source, data): def _init_glyphs(self, plot, element, ranges, source): # Get data and initialize data source - data, mapping = self.get_data(element, ranges) + style = self.style[self.cyclic_index] + data, mapping, style = self.get_data(element, ranges, style) self.handles['previous_id'] = element._plot_id properties = {} mappings = {} for key in mapping: source = self._init_datasource(data.get(key, {})) self.handles[key+'_source'] = source - glyph_props = self._glyph_properties(plot, element, source, ranges) + glyph_props = self._glyph_properties(plot, element, source, ranges, style) properties.update(glyph_props) mappings.update(mapping.get(key, {})) properties = {p: v for p, v in properties.items() if p not in ('legend', 'source')} diff --git a/holoviews/plotting/bokeh/path.py b/holoviews/plotting/bokeh/path.py index 804edb1895..805202b73a 100644 --- a/holoviews/plotting/bokeh/path.py +++ b/holoviews/plotting/bokeh/path.py @@ -41,7 +41,7 @@ def _get_hover_data(self, data, element): data[dim] = [v for _ in range(len(list(data.values())[0]))] - def get_data(self, element, ranges=None): + def get_data(self, element, ranges, style): if self.static_source: data = {} else: @@ -50,7 +50,7 @@ def get_data(self, element, ranges=None): xs, ys = ([path[:, idx] for path in paths] for idx in [xidx, yidx]) data = dict(xs=xs, ys=ys) self._get_hover_data(data, element) - return data, dict(self._mapping) + return data, dict(self._mapping), style def _categorize_data(self, data, cols, dims): @@ -75,13 +75,12 @@ def get_batched_data(self, element, ranges=None): data = defaultdict(list) zorders = self._updated_zorders(element) - styles = self.lookup_options(element.last, 'style') - styles = styles.max_cycles(len(self.ordering)) - for (key, el), zorder in zip(element.data.items(), zorders): self.set_param(**self.lookup_options(el, 'plot').options) + style = self.lookup_options(el, 'style') + style = style.max_cycles(len(self.ordering))[zorder] self.overlay_dims = dict(zip(element.kdims, key)) - eldata, elmapping = self.get_data(el, ranges) + eldata, elmapping, style = self.get_data(el, ranges, style) for k, eld in eldata.items(): data[k].extend(eld) @@ -91,24 +90,22 @@ def get_batched_data(self, element, ranges=None): # Apply static styles nvals = len(list(eldata.values())[0]) - style = styles[zorder] sdata, smapping = expand_batched_style(style, self._batched_style_opts, elmapping, nvals) elmapping.update({k: v for k, v in smapping.items() if k not in elmapping}) for k, v in sdata.items(): data[k].extend(list(v)) - return data, elmapping + return data, elmapping, style class ContourPlot(ColorbarPlot, PathPlot): style_opts = line_properties + ['cmap'] - def get_data(self, element, ranges=None): - data, mapping = super(ContourPlot, self).get_data(element, ranges) + def get_data(self, element, ranges, style): + data, mapping, style = super(ContourPlot, self).get_data(element, ranges, style) ncontours = len(list(data.values())[0]) - style = self.style[self.cyclic_index] if element.vdims and element.level is not None: cdim = element.vdims[0] dim_name = util.dimension_sanitizer(cdim.name) @@ -117,7 +114,7 @@ def get_data(self, element, ranges=None): if 'cmap' in style: cmapper = self._get_colormapper(cdim, element, ranges, style) mapping['line_color'] = {'field': dim_name, 'transform': cmapper} - return data, mapping + return data, mapping, style class PolygonPlot(ColorbarPlot, PathPlot): @@ -135,7 +132,7 @@ def _hover_opts(self, element): dims += element.vdims return dims, {} - def get_data(self, element, ranges=None): + def get_data(self, element, ranges, style): if self.static_source: data = {} else: @@ -144,7 +141,6 @@ def get_data(self, element, ranges=None): ys = [path[:, 1] for path in paths] data = dict(xs=ys, ys=xs) if self.invert_axes else dict(xs=xs, ys=ys) - style = self.style[self.cyclic_index] mapping = dict(self._mapping) if element.vdims and element.level is not None: cdim = element.vdims[0] @@ -161,4 +157,4 @@ def get_data(self, element, ranges=None): data[dim] = [v for _ in range(len(xs))] data[dim_name] = [element.level for _ in range(len(xs))] - return data, mapping + return data, mapping, style diff --git a/holoviews/plotting/bokeh/plot.py b/holoviews/plotting/bokeh/plot.py index 140906f2fe..4bfe995c0e 100644 --- a/holoviews/plotting/bokeh/plot.py +++ b/holoviews/plotting/bokeh/plot.py @@ -94,7 +94,7 @@ def __init__(self, *args, **params): self.root = None - def get_data(self, element, ranges=None): + def get_data(self, element, ranges, style): """ Returns the data from an element in the appropriate format for initializing or updating a ColumnDataSource and a dictionary diff --git a/holoviews/plotting/bokeh/raster.py b/holoviews/plotting/bokeh/raster.py index 4cb822f7b6..deab46c6b8 100644 --- a/holoviews/plotting/bokeh/raster.py +++ b/holoviews/plotting/bokeh/raster.py @@ -21,21 +21,13 @@ def __init__(self, *args, **kwargs): if self.hmap.type == Raster: self.invert_yaxis = not self.invert_yaxis - - def _glyph_properties(self, plot, element, source, ranges): - properties = super(RasterPlot, self)._glyph_properties(plot, element, - source, ranges) - properties = {k: v for k, v in properties.items()} + def get_data(self, element, ranges, style): + mapping = dict(image='image', x='x', y='y', dw='dw', dh='dh') val_dim = [d for d in element.vdims][0] - properties['color_mapper'] = self._get_colormapper(val_dim, element, ranges, - properties) - return properties - + style['color_mapper'] = self._get_colormapper(val_dim, element, ranges, style) - def get_data(self, element, ranges=None): - mapping = dict(image='image', x='x', y='y', dw='dw', dh='dh') if self.static_source: - return {}, mapping + return {}, mapping, style img = element.dimension_values(2, flat=False) if img.dtype.kind == 'b': @@ -60,9 +52,9 @@ def get_data(self, element, ranges=None): img = img[::-1] b, t = t, b dh, dw = t-b, r-l - + data = dict(image=[img], x=[l], y=[b], dw=[dw], dh=[dh]) - return (data, mapping) + return (data, mapping, style) @@ -71,10 +63,10 @@ class RGBPlot(RasterPlot): style_opts = [] _plot_methods = dict(single='image_rgba') - def get_data(self, element, ranges=None): + def get_data(self, element, ranges, style): mapping = dict(image='image', x='x', y='y', dw='dw', dh='dh') if self.static_source: - return {}, mapping + return {}, mapping, style img = np.dstack([element.dimension_values(d, flat=False) for d in element.vdims]) @@ -104,16 +96,16 @@ def get_data(self, element, ranges=None): dh, dw = t-b, r-l data = dict(image=[img], x=[l], y=[b], dw=[dw], dh=[dh]) - return (data, mapping) + return (data, mapping, style) - def _glyph_properties(self, plot, element, source, ranges): + def _glyph_properties(self, plot, element, source, ranges, style): return ElementPlot._glyph_properties(self, plot, element, - source, ranges) + source, ranges, style) class HSVPlot(RGBPlot): - def get_data(self, element, ranges=None): - return super(HSVPlot, self).get_data(element.rgb, ranges) + def get_data(self, element, ranges, style): + return super(HSVPlot, self).get_data(element.rgb, ranges, style) class HeatMapPlot(ColorbarPlot): @@ -138,13 +130,12 @@ class HeatMapPlot(ColorbarPlot): def _get_factors(self, element): return super(HeatMapPlot, self)._get_factors(element.gridded) - def get_data(self, element, ranges=None): + def get_data(self, element, ranges, style): x, y, z = [dimension_sanitizer(d) for d in element.dimensions(label=True)[:3]] if self.invert_axes: x, y = y, x - style = self.style[self.cyclic_index] cmapper = self._get_colormapper(element.vdims[0], element, ranges, style) if self.static_source: - return {}, {'x': x, 'y': y, 'fill_color': {'field': 'zvalues', 'transform': cmapper}} + return {}, {'x': x, 'y': y, 'fill_color': {'field': 'zvalues', 'transform': cmapper}}, style aggregate = element.gridded xdim, ydim = aggregate.dimensions()[:2] @@ -168,7 +159,7 @@ def get_data(self, element, ranges=None): data[sanitized] = ['-' if is_nan(v) else vdim.pprint_value(v) for v in aggregate.dimension_values(vdim)] return (data, {'x': x, 'y': y, 'fill_color': {'field': 'zvalues', 'transform': cmapper}, - 'height': 1, 'width': 1}) + 'height': 1, 'width': 1}, style) class QuadMeshPlot(ColorbarPlot): @@ -179,13 +170,12 @@ class QuadMeshPlot(ColorbarPlot): _plot_methods = dict(single='rect') style_opts = ['cmap', 'color'] + line_properties + fill_properties - def get_data(self, element, ranges=None): + def get_data(self, element, ranges, style): x, y, z = element.dimensions(label=True) if self.invert_axes: x, y = y, x - style = self.style[self.cyclic_index] cmapper = self._get_colormapper(element.vdims[0], element, ranges, style) if self.static_source: - return {}, {'x': x, 'y': y, 'fill_color': {'field': z, 'transform': cmapper}} + return {}, {'x': x, 'y': y, 'fill_color': {'field': z, 'transform': cmapper}}, style if len(set(v.shape for v in element.data)) == 1: raise SkipRendering("Bokeh QuadMeshPlot only supports rectangular meshes") @@ -204,4 +194,4 @@ def get_data(self, element, ranges=None): data = {x: xs, y: ys, z: zvals, 'widths': ws, 'heights': hs} return (data, {'x': x, 'y': y, 'fill_color': {'field': z, 'transform': cmapper}, - 'height': 'heights', 'width': 'widths'}) + 'height': 'heights', 'width': 'widths'}, style) diff --git a/holoviews/plotting/bokeh/tabular.py b/holoviews/plotting/bokeh/tabular.py index e20aa275de..4e1e7fe3c9 100644 --- a/holoviews/plotting/bokeh/tabular.py +++ b/holoviews/plotting/bokeh/tabular.py @@ -45,13 +45,13 @@ def _execute_hooks(self, element): self.warning("Plotting hook %r could not be applied:\n\n %s" % (hook, e)) - def get_data(self, element, ranges=None): + def get_data(self, element, ranges, style): dims = element.dimensions() mapping = {d.name: d.name for d in dims} data = {d: element.dimension_values(d) for d in dims} data = {d.name: values if values.dtype.kind in "if" else list(map(d.pprint_value, values)) for d, values in data.items()} - return data, mapping + return data, mapping, style def initialize_plot(self, ranges=None, plot=None, plots=None, source=None): @@ -64,18 +64,18 @@ def initialize_plot(self, ranges=None, plot=None, plots=None, source=None): self.current_frame = element self.current_key = key - data, _ = self.get_data(element, ranges) + style = self.lookup_options(element, 'style')[self.cyclic_index] + data, _, style = self.get_data(element, ranges, style) if source is None: source = self._init_datasource(data) self.handles['source'] = source dims = element.dimensions() columns = [TableColumn(field=d.name, title=d.pprint_label) for d in dims] - properties = self.lookup_options(element, 'style')[self.cyclic_index] if bokeh_version > '0.12.7': - properties['reorderable'] = False + style['reorderable'] = False table = DataTable(source=source, columns=columns, height=self.height, - width=self.width, **properties) + width=self.width, **style) self.handles['plot'] = table self.handles['glyph_renderer'] = table self._execute_hooks(element) @@ -118,5 +118,6 @@ def update_frame(self, key, ranges=None, plot=None): if self.static_source: return source = self.handles['source'] - data, _ = self.get_data(element, ranges) + style = self.lookup_options(element, 'style')[self.cyclic_index] + data, _, style = self.get_data(element, ranges, style) self._update_datasource(source, data)