In [1]:
import numpy as np
import networkx as nx
from math import sqrt
import csv
import sys

# Raise csv's per-field size limit so that very long lines (a node covering
# many points) can be parsed. csv.field_size_limit() may raise OverflowError
# when the value does not fit the platform's C long (e.g. 32-bit on Windows),
# so start from sys.maxsize and divide by 10 until the value is accepted.
maxInt = sys.maxsize
while True:
    try:
        csv.field_size_limit(maxInt)
        break
    except OverflowError:
        # integer floor division keeps the value exact (maxInt/10 goes through float)
        maxInt //= 10

In [2]:
from bokeh.io import show
from bokeh.plotting import figure

from bokeh.layouts import layout, column, row, grid
from bokeh.models import (BoxZoomTool, Circle, HoverTool,
                          MultiLine, Plot, Range1d, ResetTool,
                          ColumnDataSource, LabelSet,
                          TapTool, WheelZoomTool, PanTool,
                          ColorBar, LinearColorMapper, BasicTicker,
                          Button, TextInput,
                          CustomJS, MultiChoice,
                          SaveTool)

from bokeh.events import Tap

from bokeh.plotting import from_networkx

from matplotlib import cm
from matplotlib.colors import to_hex

## Read the graphs

Each graph must be represented by an adjacency list (space separated).  
We assume nodes are numbered from 1 to N.  
  
The points covered by each node are listed in a file with N lines: line i contains the ids of the points covered by node i (space separated).

In [3]:
# Define the color palette: matplotlib's 'Reds' colormap.
# `cm.get_cmap` was deprecated in matplotlib 3.7 and removed in 3.9, so prefer
# the `matplotlib.colormaps` registry (matplotlib >= 3.5) and fall back to the
# old API only on older versions.
try:
    from matplotlib import colormaps
    my_red_palette = colormaps['Reds']
except ImportError:
    my_red_palette = cm.get_cmap(name='Reds')

# Sample the colormap at N_COLORS evenly spaced points in [0, 1]; with 101
# colors, color_list[round(100 * fraction)] maps a fraction to a color.
N_COLORS = 101
color_list = [to_hex(my_red_palette(i / (N_COLORS - 1))) for i in range(N_COLORS)]

In [4]:
def _read_points_covered(path):
    """Read the points covered by each node from `path`.

    Line i (1-based) holds the space-separated integer ids of the points
    covered by node i. Only the first comma-separated field of a line is used.
    An empty line yields an empty list, i.e. a node covering no points.

    Returns
    -------
    dict
        {node_id: [point ids]} with node ids 1..N in file order.
    """
    # `with` closes the file; split() (not split(' ')) tolerates repeated or
    # trailing whitespace, which would otherwise produce int('') errors
    with open(path, newline='') as csv_file:
        return {node: [int(p) for p in row[0].split()] if row else []
                for node, row in enumerate(csv.reader(csv_file), start=1)}


def read_graph_from_list(GRAPH_ADJ_PATH, GRAPH_POINTS_PATH, color_list,
                         min_scale=10, max_scale=25):
    """Build a networkx graph from an adjacency list and a points-covered file.

    Nodes are assumed to be numbered 1..N, where N is the number of lines of
    GRAPH_POINTS_PATH. Every node 1..N is added (in that order), including
    nodes that do not appear in the adjacency list.

    Parameters
    ----------
    GRAPH_ADJ_PATH : str
        Path to the space-separated adjacency list, read with
        ``nx.read_adjlist`` with node ids parsed as int.
    GRAPH_POINTS_PATH : str
        Path to a file with N lines; line i holds the space-separated integer
        ids of the points covered by node i.
    color_list : list
        Palette; every node is initially given ``color_list[0]``.
    min_scale, max_scale : float, optional
        The display size of a node is
        ``max_scale * size / largest_size + min_scale``, so it ranges from
        ``min_scale`` up to ``min_scale + max_scale``.

    Returns
    -------
    networkx.Graph
        Each node has the attributes 'points covered' (list of int),
        'size' (number of points covered), 'size rescaled' and 'color'.

    Raises
    ------
    ValueError
        If an edge references a node id that has no line in GRAPH_POINTS_PATH.
    """
    # G_dummy is only a staging graph for the edges: the final graph is built
    # from scratch so that its nodes are ordered 1..N
    G_dummy = nx.read_adjlist(GRAPH_ADJ_PATH, nodetype=int)

    # ASSUME NODES ARE NUMBERED FROM 1 TO N
    points_covered = _read_points_covered(GRAPH_POINTS_PATH)

    # an edge to an unknown node would otherwise surface later as a KeyError
    unknown = sorted({n for edge in G_dummy.edges for n in edge} - points_covered.keys())
    if unknown:
        raise ValueError(
            "{}: edges reference node ids {} with no line in {} (expected ids 1..{})"
            .format(GRAPH_ADJ_PATH, unknown, GRAPH_POINTS_PATH, len(points_covered)))

    # add the nodes in order 1..N, including those that are not in the edgelist
    G = nx.Graph()
    G.add_nodes_from(range(1, len(points_covered) + 1))
    G.add_edges_from(G_dummy.edges)

    # size of the largest node, used to normalise display sizes;
    # `or 1` avoids a division by zero when no node covers any point
    max_node_size = max((len(p) for p in points_covered.values()), default=0) or 1

    for node in G.nodes:
        attrs = G.nodes[node]
        attrs['points covered'] = points_covered[node]
        attrs['size'] = len(attrs['points covered'])
        # rescale the size for display: from min_scale to min_scale + max_scale
        attrs['size rescaled'] = max_scale * attrs['size'] / max_node_size + min_scale
        # every node starts with the first palette color
        attrs['color'] = color_list[0]

    return G

In [5]:
# Prepare Data
#
# Each graph is described by two files: an adjacency list and a file listing
# the points covered by each node (node ids assumed to run from 1 to N).

# adjacency lists
GRAPH1_PATH = 'knots_BM/alexander15/50_edges'
GRAPH2_PATH = 'knots_BM/jones15/50_edges'

# points covered by each node
GRAPH1_POINTS_PATH = 'knots_BM/alexander15/50_points_covered_by_landmarks'
GRAPH2_POINTS_PATH = 'knots_BM/jones15/50_points_covered_by_landmarks'

# G1: the graph whose nodes the user selects
G1 = read_graph_from_list(GRAPH1_PATH, GRAPH1_POINTS_PATH, color_list)

# G2: the graph colored by how much of each node the selection covers;
# every node's 'coverage' starts at 0
G2 = read_graph_from_list(GRAPH2_PATH, GRAPH2_POINTS_PATH, color_list)
nx.set_node_attributes(G2, 0, 'coverage')



## UI

The app is assembled in the following cell.  
See https://github.com/bokeh/bokeh/blob/2.2.3/examples/howto/server_embed/notebook_embed.ipynb  
for more info.

In [85]:

def make_graph_plot(G, tooltips, extra_tools=()):
    """Create a Bokeh plot of the networkx graph `G` drawn with a spring layout.

    Parameters
    ----------
    G : networkx.Graph
        Graph whose nodes carry the 'size rescaled' and 'color' attributes
        (as set by read_graph_from_list).
    tooltips : list of (str, str)
        Hover tooltips shown for the nodes.
    extra_tools : sequence of Bokeh tools, optional
        Extra tools added after ResetTool and before SaveTool.

    Returns
    -------
    (Plot, GraphRenderer, LabelSet)
        The plot, its graph renderer, and a node-id label set that starts hidden.
    """
    plot = Plot(plot_width=800, plot_height=800,
                x_range=Range1d(-1.1, 1.1), y_range=Range1d(-1.1, 1.1),
                sizing_mode="stretch_both")
    plot.add_tools(PanTool(), HoverTool(tooltips=tooltips), BoxZoomTool(),
                   WheelZoomTool(), ResetTool(), *extra_tools, SaveTool())

    # fixed seed so the layout is the same on every run
    graph_renderer = from_networkx(G, nx.spring_layout,
                                   seed=42, scale=1, center=(0, 0),
                                   k=10 / sqrt(len(G.nodes)),
                                   iterations=1000)

    ## labels
    # node ids and positions come from the same dict, so they stay paired
    node_ids, positions = zip(*graph_renderer.layout_provider.graph_layout.items())
    x, y = zip(*positions)
    labels = LabelSet(x='x', y='y', text='node_id',
                      source=ColumnDataSource({'x': x, 'y': y,
                                               'node_id': list(node_ids)}),
                      text_color='black', text_alpha=1, visible=False)

    # nodes
    graph_renderer.node_renderer.glyph = Circle(size='size rescaled',
                                                fill_color='color',
                                                fill_alpha=0.8)
    # edges
    graph_renderer.edge_renderer.glyph = MultiLine(line_color='black',
                                                   line_alpha=0.8, line_width=1)

    plot.renderers.append(graph_renderer)
    plot.renderers.append(labels)
    return plot, graph_renderer, labels


##########
# PLOT 1 #
##########
# nodes of G1 can be tapped to add them to the selection box
plot1, graph_renderer_1, labels_1 = make_graph_plot(
    G1, tooltips=[("index", "@index"), ("size", "@size")],
    extra_tools=[TapTool()])

##########
# PLOT 2 #
##########
# 'coverage' is a fraction in [0, 1]; the numeral format "0.0%" shows it as a
# percentage (a printf-style "%..." spec would need a 'printf' formatter)
plot2, graph_renderer_2, labels_2 = make_graph_plot(
    G2, tooltips=[("index", "@index"), ("size", "@size"),
                  ("coverage", "@{coverage}{0.0%}")])

# color bar legend: the same palette the color callback indexes into,
# spanning 0-100 % so color_list[0] is 0 % and color_list[-1] is 100 %
color_mapper_2 = LinearColorMapper(palette=color_list, low=0, high=100)
color_bar_2 = ColorBar(color_mapper=color_mapper_2, ticker=BasicTicker(),
                       label_standoff=12, border_line_color=None, location=(0, 0),
                       title='Percentage')
plot2.add_layout(color_bar_2, 'right')

###########
# BUTTONS #
###########
color_button = Button(label='COLOR', height_policy='fit', button_type="success")
labels_button = Button(label='SHOW LABELS', height_policy='fit')

# toggle the node-id labels of both plots together
labels_button_code = """
    const show = !labels_1.visible;
    labels_1.visible = show;
    labels_2.visible = show;
    labels_button.label = show ? 'HIDE LABELS' : 'SHOW LABELS';

    labels_1.change.emit();
    labels_2.change.emit();
    labels_button.change.emit();
"""

labels_button.js_on_click(CustomJS(args=dict(labels_1=labels_1,
                                             labels_2=labels_2,
                                             labels_button=labels_button),
                                   code=labels_button_code))

###################
# multichoice box #
###################
# node ids of G1 the user can select (MultiChoice values are strings)
OPTIONS = [str(n) for n in G1.nodes]
multi_choice = MultiChoice(value=[], options=OPTIONS)

# on COLOR: paint the selected G1 nodes with the darkest color and the others
# with the first color, then color each G2 node by the fraction of its points
# that are covered by the selected G1 nodes
color_button_code = """
    const data_1 = gr_1.node_renderer.data_source.data;
    const data_2 = gr_2.node_renderer.data_source.data;

    const selected_nodes = new Set(multi_choice.value.map(Number));

    // color G1 and collect the (unique) points covered by the selected nodes;
    // node ids are read from the 'index' column, not derived from row numbers
    const points_in_selected_nodes = new Set();
    for (let i = 0; i < data_1['index'].length; i++) {
        if (selected_nodes.has(Number(data_1['index'][i]))) {
            data_1['color'][i] = color_list[color_list.length - 1];
            for (const p of data_1['points covered'][i]) {
                points_in_selected_nodes.add(p);
            }
        } else {
            data_1['color'][i] = color_list[0];
        }
    }

    // color G2 by coverage; a node covering no points gets coverage 0
    // instead of NaN (which would index color_list out of range)
    for (let i = 0; i < data_2['index'].length; i++) {
        const points_2 = new Set(data_2['points covered'][i]);
        let n_shared = 0;
        for (const p of points_2) {
            if (points_in_selected_nodes.has(p)) {
                n_shared++;
            }
        }
        const coverage = points_2.size > 0 ? n_shared / points_2.size : 0;
        data_2['coverage'][i] = coverage;
        data_2['color'][i] = color_list[Math.round(coverage * (color_list.length - 1))];
    }

    gr_1.node_renderer.data_source.change.emit();
    gr_2.node_renderer.data_source.change.emit();
"""

color_button.js_on_click(CustomJS(args=dict(gr_1=graph_renderer_1,
                                            gr_2=graph_renderer_2,
                                            multi_choice=multi_choice,
                                            color_list=color_list),
                                  code=color_button_code))

# on tap in plot 1: add the tapped nodes' ids to the selection box.
# The selection's index array is only read (via map), never modified in place,
# so the renderer keeps highlighting the node that was actually tapped.
tap_code = """
    const node_ids = gr_1.node_renderer.data_source.data['index'];
    const rows = gr_1.node_renderer.data_source.selected.indices;
    const nodes_clicked_str = rows.map(i => String(node_ids[i]));

    multi_choice.value = [...new Set([...multi_choice.value, ...nodes_clicked_str])];
"""

plot1.js_on_event(Tap, CustomJS(args=dict(gr_1=graph_renderer_1,
                                          multi_choice=multi_choice),
                                code=tap_code))

##########
# LAYOUT #
##########
# NOTE: this rebinds the name `layout` imported from bokeh.layouts; the
# following cell displays it with show(layout)
layout = grid([[labels_button, color_button], [multi_choice], [plot1, plot2]])



In [86]:
# Render the assembled app (buttons, node selection box and both graph plots).
# NOTE(review): no output_notebook() call is visible in this notebook, so
# bokeh's show() may write a standalone HTML file and open it in a browser
# instead of rendering inline — confirm this is intended.
show(layout) 