# Generate interactive HTML visualizations of Ball Mapper graphs.
Allows switching between multiple coloring functions using a dropdown menu.

In [None]:
import networkx as nx

import numpy as np
import pandas as pd

from bokeh.io import show, save
from bokeh.models import (Plot, Range1d, MultiLine, Circle, TapTool, OpenURL, HoverTool, 
                          CustomJS, Slider, Column, StaticLayoutProvider, TapTool, 
                          WheelZoomTool, PanTool, ResetTool, SaveTool, FixedTicker, 
                          LinearColorMapper, LogColorMapper, ColorBar, BasicTicker, LogTicker,
                          Dropdown,RadioButtonGroup)
from bokeh.plotting import figure, from_networkx

from matplotlib import cm
from matplotlib.colors import to_hex

In [None]:
from pyBallMapper_Bokeh import graph_GUI, read_graph_from_list

In [None]:
# to deal with large csv
import csv
import sys

# Raise the csv field size limit as far as the platform allows: start from
# sys.maxsize and keep dividing by 10 until the value is accepted.
maxInt = sys.maxsize
decrement = True

while decrement:
    try:
        csv.field_size_limit(maxInt)
    except OverflowError:
        # value rejected: shrink it and try again
        maxInt = int(maxInt / 10)
    else:
        # limit accepted: stop looping
        decrement = False

In [None]:
from tqdm.notebook import tqdm

In [None]:
# Load the space-separated table holding the values used to color the graph.
coloring_df = pd.read_csv(
    'BM/jones_upto_15_SYMM/jones_upto_15_MIRRORS_colors.csv',
    sep=' ',
)

# Label the rows 1..N instead of the default 0..N-1.
coloring_df.index = range(1, coloring_df.shape[0] + 1)
# Extra coloring column: the signature reduced modulo 4.
coloring_df['signature_mod4'] = coloring_df['signature'] % 4

In [None]:
# Value substituted into the input file names and the plot title.
EPSILON = 50

# adjacency lists of the graph
GRAPH1_PATH = 'BM/jones_upto_15_SYMM/{}_edges'.format(EPSILON)

# points covered by each node
GRAPH1_POINTS_PATH = 'BM/jones_upto_15_SYMM/{}_points_covered_by_landmarks'.format(EPSILON)

TITLE = 'Jones up to 15 crossing - epsilon={}'.format(EPSILON)


###########
# GRAPH 1 #
###########

# Columns of coloring_df handed to the graph reader.
node_attribute_columns = ['number_of_crossings', 'signature', 'signature_mod4']

# Build the graph.
# ASSUME NODES ARE NUMBERED FROM 1 TO N
G = read_graph_from_list(
    GRAPH1_PATH,
    GRAPH1_POINTS_PATH,
    coloring_df[node_attribute_columns],
    add_points_covered=False,
    MIN_SCALE=10,
    MAX_SCALE=25,
)

In [None]:
# Read the list of points covered by each node: row i of the file holds the
# space-separated labels (rows of coloring_df) covered by node i+1.
# ASSUME NODES ARE NUMBERED FROM 1 TO N
points_covered = {}
MAX_NODE_SIZE = 0  # largest number of points covered by a single node
print('loading points covered')
# The context manager closes the file once every row has been read;
# newline='' is what the csv module asks for when handed a file object.
with open(GRAPH1_POINTS_PATH, newline='') as csv_file:
    reader = csv.reader(csv_file)
    for i, line_list in tqdm(enumerate(reader)):
        # split() without argument also tolerates repeated or trailing blanks
        covered = [int(point) for point in line_list[0].split()]
        points_covered[i+1] = covered
        MAX_NODE_SIZE = max(MAX_NODE_SIZE, len(covered))
print('done')


# Attach to every node the population standard deviation (ddof=0) of the
# signature over the points that node covers.
for node in tqdm(G.nodes):
    G.nodes[node]['signature_STD'] = coloring_df.loc[points_covered[node]].signature.std(ddof=0)

In [None]:
## collect the coloring variables
# One (initially empty) settings dict per attribute stored on node 1,
# leaving out 'size rescaled'.
coloring_variables_dict = {
    attribute: dict()
    for attribute in G.nodes[1].keys() - ['size rescaled',]
}

In [None]:
# Notebook display: inspect which coloring variables were collected.
coloring_variables_dict

In [None]:
# Hand-picked palette and colorbar style for every coloring variable, given as
# (variable, matplotlib colormap name, style). A style of None means that no
# colorbar is created for that variable further down.
for variable, cmap_name, bar_style in (
    ('size', 'Reds', 'log'),
    ('number_of_crossings', 'Reds', 'continuous'),
    ('signature', 'jet', 'discrete'),
    ('signature_mod4', 'Reds', None),
    ('signature_STD', 'Reds', 'continuous'),
):
    settings = coloring_variables_dict[variable]
    settings['palette'] = cm.get_cmap(name=cmap_name)
    settings['style'] = bar_style

In [None]:
# For every coloring variable: find its value range over the nodes, store it
# under 'min'/'max', and give each node a hex color under '<var>_color'.
for var in coloring_variables_dict:
    # Start from an empty range so that any real value replaces the bounds
    # (NaN never does, since comparisons with NaN are False).
    MIN_VALUE = float('inf')
    MAX_VALUE = float('-inf')

    if var == 'signature':
        # the signature range always spans at least [-12, 12]
        MIN_VALUE = -12
        MAX_VALUE = +12

    for node in G.nodes:
        if G.nodes[node][var] > MAX_VALUE:
            MAX_VALUE = G.nodes[node][var]
        if G.nodes[node][var] < MIN_VALUE:
            MIN_VALUE = G.nodes[node][var]

    coloring_variables_dict[var]['max'] = MAX_VALUE
    coloring_variables_dict[var]['min'] = MIN_VALUE

    palette = coloring_variables_dict[var]['palette']
    use_log = coloring_variables_dict[var]['style'] == 'log'
    color_key = '{}_color'.format(var)

    # Bounds in the space where the interpolation happens (log10 or linear).
    if use_log:
        low, high = np.log10(MIN_VALUE), np.log10(MAX_VALUE)
    else:
        low, high = MIN_VALUE, MAX_VALUE
    span = high - low

    for node in G.nodes:
        value = G.nodes[node][var]
        if pd.isna(value):
            # missing value: the node is drawn black for this variable
            G.nodes[node][color_key] = 'black'
            continue
        if use_log:
            value = np.log10(value)
        if span == 0:
            # every node shares the same value: use the low end of the
            # palette instead of dividing by zero
            color_id = 0.0
        else:
            color_id = (value - low) / span
        G.nodes[node][color_key] = to_hex(palette(color_id))

# Color shown before any coloring function has been selected.
for node in G.nodes:
    G.nodes[node]['current_color'] = 'white'

In [None]:
# Notebook display: inspect the palette, style and min/max stored per variable.
coloring_variables_dict

In [None]:
def create_colorbar(style, palette, low, high):
    """Build a bokeh ColorBar for one coloring variable.

    Parameters
    ----------
    style : 'continuous', 'log' or 'discrete'
        Kind of colorbar. 'continuous' and 'log' sample the palette at 100
        points and map [low, high] linearly / logarithmically. 'discrete'
        samples 13 colors, ignores ``low``/``high`` in favour of a fixed
        [-13, 13] range and puts ticks on the even values from -12 to 12.
    palette : callable
        Called with floats in [0, 1]; each result is passed to ``to_hex``.
    low, high
        Bounds of the color mapper (unused for 'discrete').

    Returns
    -------
    ColorBar, or None when ``style`` is not one of the three known styles.
    """
    if style not in ('continuous', 'log', 'discrete'):
        return None

    # options shared by every colorbar
    bar_options = dict(major_label_text_font_size='14pt', label_standoff=12)

    if style == 'discrete':
        n_colors = 13
        low = -13
        high = 13
    else:
        n_colors = 100

    hex_colors = [to_hex(palette(fraction))
                  for fraction in np.linspace(0, 1, n_colors)]

    if style == 'log':
        mapper = LogColorMapper(palette=hex_colors, low=low, high=high)
        bar_options['ticker'] = LogTicker(mantissas=[1,2,3,4,5], desired_num_ticks=10)
    else:
        mapper = LinearColorMapper(palette=hex_colors, low=low, high=high)
        if style == 'discrete':
            bar_options['ticker'] = FixedTicker(ticks=list(range(-12, 13, 2)))

    return ColorBar(color_mapper=mapper, **bar_options)

In [None]:
# Empty plot that stretches to fill the available space; the graph renderer
# and the colorbars are attached to it afterwards.
plot = Plot(x_range=Range1d(-2, 2), y_range=Range1d(-2, 2),
            sizing_mode="stretch_both",
            toolbar_location = 'right',
            title=TITLE)

# Hover tooltips: node index and size first, then one row per coloring
# variable. 'size' is itself a coloring variable, so it is skipped in the
# second part to avoid showing it twice.
tooltips=[("index", "@index"), ("size", "@size")]

tooltips += [(name.replace('_', ' '), '@{}'.format(name))
             for name in coloring_variables_dict if name != 'size']

node_hover_tool = HoverTool(tooltips=tooltips)
zoom_tool = WheelZoomTool()
plot.add_tools(PanTool(), node_hover_tool, zoom_tool,
                    ResetTool(), SaveTool())
# make wheel zoom the active scroll tool
plot.toolbar.active_scroll = zoom_tool

# Keyword arguments forwarded to nx.spring_layout through from_networkx.
spring_layout_options = dict(
    seed=42,
    scale=1,
    center=(0, 0),
    k=10/np.sqrt(len(G.nodes)),
    iterations=2000,
)
graph_renderer = from_networkx(G, nx.spring_layout, **spring_layout_options)

# nodes: size and fill color are read from the node data columns
node_glyph = Circle(size='size rescaled',
                    fill_color='current_color',
                    fill_alpha=0.8)
graph_renderer.node_renderer.glyph = node_glyph

# edges: thin black lines
edge_glyph = MultiLine(line_color='black', line_alpha=0.8, line_width=1)
graph_renderer.edge_renderer.glyph = edge_glyph

plot.renderers.append(graph_renderer)


# colorbars
# One hidden colorbar per variable that has a style, keyed by the name of the
# node color column ('<var>_color').

color_bar_dict = {}

for variable, settings in coloring_variables_dict.items():
    if not settings['style']:
        continue
    bar = create_colorbar(style=settings['style'],
                          palette=settings['palette'],
                          low=settings['min'],
                          high=settings['max'])
    color_bar_dict[variable+'_color'] = bar
    bar.visible = False
    bar.title = variable.replace('_', ' ')
    bar.title_text_font_size = '14pt'

# attach every colorbar to the right-hand side of the plot
for bar in color_bar_dict.values():
    plot.add_layout(bar, 'right')

# dropdown menu
# JavaScript run in the browser when a dropdown item is clicked. `this.item`
# is the value of the clicked menu entry, i.e. a '<var>_color' column name.
# The script:
#   1. copies that column into 'current_color' for every node,
#   2. hides all colorbars, then shows the one keyed by the selected item
#      (if there is one),
#   3. emits change events so the nodes, edges and colorbars are redrawn.
code = """ 
        
        var node_data = graph_renderer.node_renderer.data_source.data;
        var edge_data = graph_renderer.edge_renderer.data_source.data;
        for (var i = 0; i < node_data['size'].length; i++) {
            
            graph_renderer.node_renderer.data_source.data['current_color'][i] = node_data[this.item][i];
        }
        
        
        for (var key in color_bar_dict){
            color_bar_dict[key].visible = false;
        }
        
        if (this.item in color_bar_dict) {
            color_bar_dict[this.item].visible = true;

        }
        
        graph_renderer.node_renderer.data_source.change.emit();
        graph_renderer.edge_renderer.data_source.change.emit();
        
        for (var key in color_bar_dict){
            color_bar_dict[key].change.emit();
        }


    """


# Expose the graph renderer and the colorbars to the JavaScript snippet.
callback = CustomJS(
    args={
        'graph_renderer': graph_renderer,
        'color_bar_dict': color_bar_dict,
    },
    code=code,
)

# Menu entries: (label shown to the user, name of the node color column).
menu = [(name.replace('_', ' '), name+'_color')
        for name in coloring_variables_dict]

dropdown = Dropdown(
    label="Select a coloring function",
    button_type="default",
    menu=menu,
)
dropdown.js_on_event("menu_item_click", callback)

# Dropdown on top, plot underneath.
layout = Column(dropdown, plot, sizing_mode="scale_both")

In [None]:
## write the layout to an HTML file under output/
OUTPUT_PATH = 'Jones_upto_15'
output_html = 'output/{}.html'.format(OUTPUT_PATH)
save(layout, output_html)

In [None]:
# show on browser: display the dropdown + plot layout with bokeh's show()
show(layout) 