From 92fc6173d8f7236f9c12a520759ea97a31fb6542 Mon Sep 17 00:00:00 2001 From: Philipp Rudiger Date: Sat, 8 Apr 2017 18:06:58 +0100 Subject: [PATCH] Correctly sync shared datasources --- holoviews/plotting/bokeh/plot.py | 15 ++++++++---- holoviews/plotting/bokeh/util.py | 26 ++++++++++++++++++++ tests/testplotinstantiation.py | 41 +++++++++++++++++++++++++++++++- 3 files changed, 77 insertions(+), 5 deletions(-) diff --git a/holoviews/plotting/bokeh/plot.py b/holoviews/plotting/bokeh/plot.py index dd8e5a5865..5160309f7f 100644 --- a/holoviews/plotting/bokeh/plot.py +++ b/holoviews/plotting/bokeh/plot.py @@ -17,7 +17,7 @@ from ..util import get_dynamic_mode, initialize_sampled from .renderer import BokehRenderer from .util import (bokeh_version, layout_padding, pad_plots, - filter_toolboxes, make_axis) + filter_toolboxes, make_axis, update_shared_sources) if bokeh_version >= '0.12': from bokeh.layouts import gridplot @@ -153,6 +153,8 @@ def sync_sources(self): and 'source' in x.handles) data_sources = self.traverse(get_sources, [filter_fn]) grouped_sources = groupby(sorted(data_sources, key=lambda x: x[0]), lambda x: x[0]) + shared_sources = [] + source_cols = {} for _, group in grouped_sources: group = list(group) if len(group) > 1: @@ -169,6 +171,10 @@ def sync_sources(self): else: renderer.update(source=new_source) plot.handles['source'] = new_source + shared_sources.append(new_source) + source_cols[id(new_source)] = [c for c in new_source.data] + self.handles['shared_sources'] = shared_sources + self.handles['source_cols'] = source_cols @@ -441,7 +447,7 @@ def _make_axes(self, plot): plot = Column(*models) return plot - + @update_shared_sources def update_frame(self, key, ranges=None): """ Update the internal state of the Plot to represent the given @@ -450,7 +456,7 @@ def update_frame(self, key, ranges=None): """ ranges = self.compute_ranges(self.layout, key, ranges) for coord in self.layout.keys(full_grid=True): - subplot = self.subplots.get(coord, None) + subplot = self.subplots.get(wrap_tuple(coord), None) if subplot is not None: subplot.update_frame(key, ranges) title = self._get_title(key) @@ -692,13 +698,14 @@ def initialize_plot(self, plots=None, ranges=None): return self.handles['plot'] - + @update_shared_sources def update_frame(self, key, ranges=None): """ Update the internal state of the Plot to represent the given key tuple (where integers represent frames). Returns this state. """ + source_cols = self.handles.get('source_cols', {}) ranges = self.compute_ranges(self.layout, key, ranges) for r, c in self.coords: subplot = self.subplots.get((r, c), None) diff --git a/holoviews/plotting/bokeh/util.py b/holoviews/plotting/bokeh/util.py index ec7e993066..b2d9b87171 100644 --- a/holoviews/plotting/bokeh/util.py +++ b/holoviews/plotting/bokeh/util.py @@ -612,3 +612,29 @@ def filter_batched_data(data, mapping): del data[v] except: pass + + +def update_shared_sources(f): + """ + Context manager to ensures data sources shared between multiple + plots are cleared and updated appropriately avoiding warnings and + allowing empty frames on subplots. Expects a list of + shared_sources and a mapping of the columns expected columns for + each source in the plots handles. + """ + def wrapper(self, *args, **kwargs): + source_cols = self.handles.get('source_cols', {}) + shared_sources = self.handles.get('shared_sources', []) + for source in shared_sources: + source.data.clear() + + ret = f(self, *args, **kwargs) + + for source in shared_sources: + expected = source_cols[id(source)] + found = [c for c in expected if c in source.data] + empty = np.full_like(source.data[found[0]], np.NaN) if found else [] + patch = {c: empty for c in expected if c not in source.data} + source.data.update(patch) + return ret + return wrapper diff --git a/tests/testplotinstantiation.py b/tests/testplotinstantiation.py index 8d6e27f7d4..1f88080e12 100644 --- a/tests/testplotinstantiation.py +++ b/tests/testplotinstantiation.py @@ -11,7 +11,7 @@ import param import numpy as np -from holoviews import (Dimension, Overlay, DynamicMap, Store, +from holoviews import (Dimension, Overlay, DynamicMap, Store, Dataset, NdOverlay, GridSpace, HoloMap, Layout, Cycle) from holoviews.core.util import pd from holoviews.element import (Curve, Scatter, Image, VLine, Points, @@ -1224,6 +1224,45 @@ def test_shared_axes_disable(self): self.assertEqual((x_range.start, x_range.end), (-.5, .5)) self.assertEqual((y_range.start, y_range.end), (-.5, .5)) + def test_layout_shared_source_synced_update(self): + hmap = HoloMap({i: Dataset({chr(65+j): np.random.rand(i+2) + for j in range(4)}, kdims=['A', 'B', 'C', 'D']) + for i in range(3)}) + hmap1= hmap.map(lambda x: Points(x.clone(kdims=['A', 'B'])), Dataset) + hmap2 = hmap.map(lambda x: Points(x.clone(kdims=['D', 'C'])), Dataset) + hmap2.pop(1) + layout = (hmap1 + hmap2)(plot=dict(shared_datasource=True)) + plot = bokeh_renderer.get_plot(layout) + sources = plot.handles.get('shared_sources', []) + cols = plot.handles.get('source_cols', {}) + self.assertEqual(len(sources), 1) + data = sources[0].data + self.assertEqual(set(data.keys()), {'A', 'B', 'C', 'D'}) + plot.update((1,)) + self.assertEqual(data['A'], hmap1[1].dimension_values(0)) + self.assertEqual(data['B'], hmap1[1].dimension_values(1)) + self.assertEqual(data['C'], np.full_like(hmap1[1].dimension_values(0), np.NaN)) + self.assertEqual(data['D'], np.full_like(hmap1[1].dimension_values(0), np.NaN)) + + def test_grid_shared_source_synced_update(self): + hmap = HoloMap({i: Dataset({chr(65+j): np.random.rand(i+2) + for j in range(4)}, kdims=['A', 'B', 'C', 'D']) + for i in range(3)}) + hmap1= hmap.map(lambda x: Points(x.clone(kdims=['A', 'B'])), Dataset) + hmap2 = hmap.map(lambda x: Points(x.clone(kdims=['D', 'C'])), Dataset) + hmap2.pop(1) + grid = GridSpace({0: hmap1, 2: hmap2}, kdims=['X'])(plot=dict(shared_datasource=True)) + plot = bokeh_renderer.get_plot(grid) + sources = plot.handles.get('shared_sources', []) + cols = plot.handles.get('source_cols', {}) + self.assertEqual(len(sources), 1) + data = sources[0].data + self.assertEqual(set(data.keys()), {'A', 'B', 'C', 'D'}) + plot.update((1,)) + self.assertEqual(data['A'], hmap1[1].dimension_values(0)) + self.assertEqual(data['B'], hmap1[1].dimension_values(1)) + self.assertEqual(data['C'], np.full_like(hmap1[1].dimension_values(0), np.NaN)) + self.assertEqual(data['D'], np.full_like(hmap1[1].dimension_values(0), np.NaN)) class TestPlotlyPlotInstantiation(ComparisonTestCase):