Skip to content

Commit

Permalink
Improvements for bokeh GraphPlot
Browse files Browse the repository at this point in the history
  • Loading branch information
philippjfr committed Dec 8, 2017
1 parent 4bab10c commit ea5a1d2
Showing 1 changed file with 47 additions and 25 deletions.
72 changes: 47 additions & 25 deletions holoviews/plotting/bokeh/graphs.py
Expand Up @@ -2,7 +2,7 @@
import numpy as np
from bokeh.models import HoverTool, ColumnDataSource
from bokeh.models import (StaticLayoutProvider, NodesAndLinkedEdges,
EdgesAndLinkedNodes, Patches)
EdgesAndLinkedNodes, Patches, Bezier)

from ...core.util import basestring, dimension_sanitizer, unique_array
from ...core.options import Cycle
Expand Down Expand Up @@ -33,7 +33,7 @@ class GraphPlot(CompositeElementPlot, ColorbarPlot, LegendPlot):
A list of plugin tools to use on the plot.""")

# Map each glyph to a style group
_style_groups = {'scatter': 'node', 'multi_line': 'edge', 'patches': 'edge'}
_style_groups = {'scatter': 'node', 'multi_line': 'edge', 'patches': 'edge', 'bezier': 'edge'}

style_opts = (['edge_'+p for p in line_properties] +
['node_'+p for p in fill_properties+line_properties] +
Expand All @@ -42,6 +42,9 @@ class GraphPlot(CompositeElementPlot, ColorbarPlot, LegendPlot):
# Filled is only supported for subclasses
filled = False

# Bezier paths
bezier = False

# Declares which columns in the data refer to node indices
_node_indices = [0, 1]

Expand Down Expand Up @@ -114,12 +117,25 @@ def _get_edge_colors(self, element, ranges, edge_data, edge_mapping, style):
edge_mapping['edge_selection_'+color_type] = transform


def _get_edge_paths(self, element):
path_data = {}
xidx, yidx = (1, 0) if self.invert_axes else (0, 1)
if element._edgepaths:
edges = element._split_edgepaths.split(datatype='array', dimensions=element.edgepaths.kdims)
if len(edges) == len(element):
path_data['xs'] = [path[:, xidx] for path in edges]
path_data['ys'] = [path[:, yidx] for path in edges]
else:
self.warning('Graph edge paths do not match the number of abstract edges '
'and will be skipped')
return path_data, {'xs': 'xs', 'ys': 'ys'}


def get_data(self, element, ranges, style):
# Force static source to False
static = self.static_source
self.handles['static_source'] = static
self.static_source = False
xidx, yidx = (1, 0) if self.invert_axes else (0, 1)

# Get node data
nodes = element.nodes.dimension_values(2)
Expand Down Expand Up @@ -152,7 +168,6 @@ def get_data(self, element, ranges, style):
['node_fill_color', 'node_nonselection_fill_color']}
point_mapping['node_nonselection_fill_color'] = point_mapping['node_fill_color']

# Get edge data
edge_mapping = {}
nan_node = index.max()+1
start, end = (element.dimension_values(i) for i in range(2))
Expand All @@ -163,14 +178,10 @@ def get_data(self, element, ranges, style):
end = np.array([node_indices.get(y, nan_node) for y in end], dtype=np.int32)
path_data = dict(start=start, end=end)
self._get_edge_colors(element, ranges, path_data, edge_mapping, style)
if element._edgepaths and not static:
edges = element._split_edgepaths.split(datatype='array', dimensions=element.edgepaths.kdims)
if len(edges) == len(start):
path_data['xs'] = [path[:, 0] for path in edges]
path_data['ys'] = [path[:, 1] for path in edges]
else:
self.warning('Graph edge paths do not match the number of abstract edges '
'and will be skipped')
if not static:
pdata, pmapping = self._get_edge_paths(element)
path_data.update(pdata)
edge_mapping.update(pmapping)

# Get hover data
if any(isinstance(t, HoverTool) for t in self.state.tools):
Expand All @@ -182,11 +193,19 @@ def get_data(self, element, ranges, style):
elif self.inspection_policy == 'edges':
for d in element.dimensions():
path_data[dimension_sanitizer(d.name)] = element.dimension_values(d)
edge_glyph = 'patches_1' if self.filled else 'multi_line_1'
data = {'scatter_1': point_data, edge_glyph: path_data, 'layout': layout}
mapping = {'scatter_1': point_mapping, edge_glyph: edge_mapping}
data = {'scatter_1': point_data, self.edge_glyph: path_data, 'layout': layout}
mapping = {'scatter_1': point_mapping, self.edge_glyph: edge_mapping}
return data, mapping, style

@property
def edge_glyph(self):
if self.filled:
edge_glyph = 'patches_1'
elif self.bezier:
edge_glyph = 'bezier_1'
else:
edge_glyph = 'multi_line_1'
return edge_glyph

def _update_datasource(self, source, data):
"""
Expand All @@ -205,13 +224,14 @@ def _init_glyphs(self, plot, element, ranges, source):
# Get data and initialize data source
style = self.style[self.cyclic_index]
data, mapping, style = self.get_data(element, ranges, style)
edge_mapping = {k: v for k, v in mapping[self.edge_glyph].items()
if 'color' not in k}
self.handles['previous_id'] = element._plot_id
edge_glyph = 'patches_1' if self.filled else 'multi_line_1'

properties = {}
mappings = {}
for key in list(mapping):
if not any(glyph in key for glyph in ('scatter_1', edge_glyph)):
if not any(glyph in key for glyph in ('scatter_1', self.edge_glyph)):
continue
source = self._init_datasource(data.pop(key, {}))
self.handles[key+'_source'] = source
Expand All @@ -229,7 +249,7 @@ def _init_glyphs(self, plot, element, ranges, source):
# Define static layout
layout = StaticLayoutProvider(graph_layout=layout)
node_source = self.handles['scatter_1_source']
edge_source = self.handles[edge_glyph+'_source']
edge_source = self.handles[self.edge_glyph+'_source']
renderer = plot.graph(node_source, edge_source, layout, **properties)

# Initialize GraphRenderer
Expand All @@ -250,19 +270,20 @@ def _init_glyphs(self, plot, element, ranges, source):
self.handles['layout_source'] = layout
self.handles['glyph_renderer'] = renderer
self.handles['scatter_1_glyph_renderer'] = renderer.node_renderer
self.handles[edge_glyph+'_glyph_renderer'] = renderer.edge_renderer
self.handles[self.edge_glyph+'_glyph_renderer'] = renderer.edge_renderer
self.handles['scatter_1_glyph'] = renderer.node_renderer.glyph
if self.filled:
allowed_properties = Patches.properties()
if self.filled or self.bezier:
glyph_model = Patches if self.filled else Bezier
allowed_properties = glyph_model.properties()
for glyph_type in ('', 'selection_', 'nonselection_', 'hover_', 'muted_'):
glyph = getattr(renderer.edge_renderer, glyph_type+'glyph', None)
if glyph is None:
continue
props = self._process_properties(edge_glyph, properties, mappings)
props = self._process_properties(self.edge_glyph, properties, mappings)
filtered = self._filter_properties(props, glyph_type, allowed_properties)
patches = Patches(**dict(filtered, xs='xs', ys='ys'))
glyph = setattr(renderer.edge_renderer, glyph_type+'glyph', patches)
self.handles[edge_glyph+'_glyph'] = renderer.edge_renderer.glyph
new_glyph = glyph_model(**dict(filtered, **edge_mapping))
setattr(renderer.edge_renderer, glyph_type+'glyph', new_glyph)
self.handles[self.edge_glyph+'_glyph'] = renderer.edge_renderer.glyph
if 'hover' in self.handles:
self.handles['hover'].renderers.append(renderer)

Expand Down Expand Up @@ -294,3 +315,4 @@ def get_data(self, element, ranges, style):
# Ensure the edgepaths for the triangles are generated
element.edgepaths
return super(TriMeshPlot, self).get_data(element, ranges, style)

0 comments on commit ea5a1d2

Please sign in to comment.