<a href="https://colab.research.google.com/github/nathan-mahynski/nathan-mahynski.github.io/blob/public/_notes/bokeh/bokeh_notes.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# !pip install watermark
# !pip install bokeh==3.0.3

In [None]:
import bokeh
import scipy
import watermark

import numpy as np
import pandas as pd

from bokeh.plotting import figure, show
from bokeh.io import output_notebook

from bokeh.layouts import column, layout
from bokeh.models import (ColumnDataSource, DataTable, HoverTool, CrosshairTool, FileInput, CheckboxGroup, Paragraph,
                          SelectEditor, CustomJS, Segment, VBar, Rect, Button, TextInput, RadioButtonGroup, Span, 
                          StringFormatter, TableColumn, RangeSlider, Slider, Select, CDSView, IndexFilter)

In [None]:
# Send Bokeh output produced by show() inline into this notebook.
output_notebook()

In [None]:
# Automatically reload edited modules before each cell executes (development convenience).
%load_ext autoreload
%autoreload 2

# Record the environment for reproducibility: current time (-t), machine info (-m),
# hostname (-h), Python/IPython versions (-v) and versions of imported packages (--iversions).
%load_ext watermark
%watermark -t -m -h -v --iversions

---

The [Bokeh gallery](https://docs.bokeh.org/en/latest/docs/gallery.html) has many great examples for getting started quickly and is a good general reference.

Below is an example that walks through some more complex functions I have found useful.

# Create Example Data

In [None]:
# This is meant to emulate GCMS data where peaks have been assigned to compounds by an algorithm,
# but the user may need to manually reassign them.
import hashlib

np.random.seed(42)

# Three clusters of retention times: 5 peaks around 10, 10 around 12 and 15 around 14.
# Values are drawn one at a time, in this order, so the seeded random stream is consumed identically.
retention_times = []
for centre, n_peaks in ((10, 5), (12, 10), (14, 15)):
    retention_times.extend(np.random.normal(loc=centre, scale=1.0) for _ in range(n_peaks))

quality = [100 * np.random.random() for _ in range(30)]

# All three name columns start out identical; only Assigned_Name is edited later (via the table).
hit_name = ['A'] * 5 + ['B'] * 10 + ['C'] * 15
suggested_name = hit_name
new_name = hit_name

hash_draws = np.random.random(30)
df = pd.DataFrame({
    'Retention_Time': retention_times,
    'Quality': quality,
    'Hit_Name': hit_name,
    'Suggested_Name': suggested_name,
    'Assigned_Name': new_name,
    # A random SHA-1 per row acts as a row identifier; the import callback matches rows on it.
    'Hash': [hashlib.sha1(str(draw).encode()).hexdigest() for draw in hash_draws],
})

# Create a Data Table

In [None]:
# 1. Create a source for the table
total_source = ColumnDataSource(df)

# 2. A view whose IndexFilter starts out selecting every row; the slider callbacks replace its indices later.
view = CDSView(filter=IndexFilter(list(range(len(df)))))

# 3. The table is editable, but "Assigned Name" may only be set to compound names we already know about.
compound_names = sorted(df["Hit_Name"].unique())

columns = [
    TableColumn(field="Retention_Time", title="Retention Time"),
    TableColumn(field="Quality", title="Quality"),
    TableColumn(field="Hit_Name", title="Original Name",
                formatter=StringFormatter(font_style="bold")),
    TableColumn(field="Suggested_Name", title="Suggested Name",
                formatter=StringFormatter(font_style="bold")),
    # The only user-editable column: a dropdown restricted to the known compound names.
    TableColumn(field="Assigned_Name", title="Assigned Name",
                editor=SelectEditor(options=compound_names),
                formatter=StringFormatter(font_style="bold")),
    TableColumn(field="Hash", title="Hash"),
]

data_table = DataTable(
    source=total_source,
    view=view,
    columns=columns,
    width=1000,
    editable=True,
    # Row-index column labelled "Index"; a negative position counts from the end of the column list.
    index_position=-1,
    index_header="Index",
    index_width=60,
)

In [None]:
# Note how "Assigned Name" can be modified via a dropdown!
# bokeh.io.show() only forwards extra keyword arguments to a Bokeh server app, so passing
# sizing_mode to show() for a plain model is silently ignored; set it on a layout instead.
show(column(data_table, sizing_mode='scale_width'))

See [Bokeh.models](https://docs.bokeh.org/en/latest/docs/reference/models.html) for other types of editors besides dropdown menus.

# Plot Results in the DataTable

In [None]:
# Categorical x-axis: compounds ordered by their median retention time.
ordered_cats = (
    df.groupby('Hit_Name')['Retention_Time']
    .median()
    .sort_values()
    .index.tolist()
)

p = figure(
    title="Retention Time Ranges",
    x_range=ordered_cats,
    x_axis_label="",
    y_axis_label="Retention Time",
    background_fill_color="#efefef",
    tools="pan,wheel_zoom,ybox_select,xbox_select,lasso_select,reset",
    active_drag="pan",
)
# Rotate the category labels so compound names do not collide.
p.xaxis.major_label_orientation = "vertical"

In [None]:
# Datapoints: one marker per peak at (assigned compound, retention time); `view` decides which rows are drawn.
points = p.circle(
    source=total_source,
    view=view,
    x="Assigned_Name",
    y="Retention_Time",
    size=4,
    color="#F38630",
    alpha=0.5,
)

# Hover details are limited to the scatter renderer; a crosshair helps read off positions.
tooltips = [
    ("Original Name", "@Hit_Name"),
    ("Quality", "@Quality"),
    ("Hash", "@Hash"),
]
hover_tool = HoverTool(renderers=[points], tooltips=tooltips)
crosshair_tool = CrosshairTool()
p.add_tools(hover_tool, crosshair_tool)

In [None]:
# Display the stand-alone figure inline.
show(obj=p)

# Add Some Interactions

Let's create two sliders that filter the displayed data: one on the quality range and one on the minimum number of observations per compound.

In [None]:
# Filter 1: hide compounds with too few observations. The upper bound is the size of the
# largest compound group (number of non-null Quality values per Hit_Name).
min_obs_slider = Slider(
    title="Minimum Observations within Quality Range",
    start=0,
    end=df.groupby('Hit_Name')['Quality'].count().max(),
    step=1,
    value=0,
)

# Filter 2: only show peaks whose Quality lies inside this range.
quality_slider = RangeSlider(title="Quality Range", start=0, end=100, value=(1, 99), step=1)

In [None]:
# JavaScript that recomputes which rows are visible from the current slider settings.
# It defines `t` (total_source.data) and leaves the result in view.filter.indices;
# calc_iqr_code, which is appended to this snippet later, relies on both.
update_view_code = """
        var qvals = quality_slider.value;
        var t = total_source.data;
        var nrows = total_source.get_length();

        // A row passes the quality filter when its Quality lies within the (inclusive) slider range.
        function in_quality_range(i) {
            return (qvals[0] <= t.Quality[i]) && (t.Quality[i] <= qvals[1]);
        }

        // Count, per assigned name, only the observations that fall within the quality range
        // (this is what the "Minimum Observations within Quality Range" slider refers to).
        var counts = new Map();
        for (var i = 0; i < nrows; i++) {
            var cname = t.Assigned_Name[i];
            var inc = in_quality_range(i) ? 1 : 0;
            counts.set(cname, (counts.get(cname) || 0) + inc);
        }

        // Keep rows inside the quality range whose compound has AT LEAST the minimum number of observations.
        var visible = [];
        for (var i = 0; i < nrows; i++) {
            if (in_quality_range(i) && (counts.get(t.Assigned_Name[i]) >= min_obs_slider.value)) {
                visible.push(i);
            }
        }
        // Assign the finished array in one go (instead of clearing and then mutating it in place)
        // so the filter reports a single, complete change.
        view.filter.indices = visible;

        total_source.change.emit(); // Trigger the source this code is tied to - modifying view alone will not do this.
    """

In [None]:
# Wrap the JavaScript in a callback; `args` exposes these Python objects to the JS under the same names.
update = CustomJS(
    code=update_view_code,
    args={
        'total_source': total_source,
        'view': view,
        'min_obs_slider': min_obs_slider,
        'quality_slider': quality_slider,
    },
)

# Recompute the view whenever either filter slider moves.
quality_slider.js_on_change('value', update)
min_obs_slider.js_on_change('value', update)

In [None]:
# sizing_mode must be given to the layout: bokeh.io.show() ignores extra keyword arguments for models.
show(column(quality_slider, min_obs_slider, p, sizing_mode='scale_width'))

# Combining a Table and a Plot

All glyphs linked to a given source update when the source is modified.  This happens if you change a value, or if you force a "trigger" event.  Usually, the latter occurs in a callback function, as with the sliders above.  The former can sometimes be done directly, as illustrated next.

Let's combine the plot and table. The `Assigned Name` values in the table (the plot's x-axis) can be edited. Edits modify `total_source`, which the points are linked to, so the plot updates as you change the table.

In [None]:
# Try changing the "Assigned Name" values - you will see that points get reassigned to different columns!
# sizing_mode is passed to column(): bokeh.io.show() ignores extra keyword arguments when showing a model.
show(column(data_table, p, sizing_mode='scale_width'))

# More Complex Combinations

Table edits directly modify `total_source`, and the sliders trigger a change signal; both update the plot and table. But what if we are plotting data that depends only INDIRECTLY on `total_source`, for example a set of summary statistics? We can put those calculations in the JavaScript of the slider callback (`update_view_code`), but that callback only runs when the sliders change. Table edits will change which points are plotted, but will not re-run it. We also need to trigger the recomputation whenever the table is edited, which we can do by attaching the same `CustomJS` callback (`update`) to the source's `patching` event:

~~~python
total_source.js_on_change('patching', update)
~~~

This way, when table edits change `total_source`, the `update` callback (and therefore the `update_view_code` JavaScript) also runs. Let's do an example.

In [None]:
# Initial, Python-side IQR summary; the JavaScript in calc_iqr_code recomputes it for the visible rows.

def get_quantiles_df(name_groups):
    """Summarise retention-time quartiles per group, used to initialise the IQR whisker source.

    Parameters
    ----------
    name_groups : pandas.core.groupby.DataFrameGroupBy
        A grouped frame containing a 'Retention_Time' column.

    Returns
    -------
    pandas.DataFrame
        One row per group, indexed by the group key, with columns
        q1, q2, q3 (25th/50th/75th percentiles), upper and lower
        (q3 + 1.5*IQR and q1 - 1.5*IQR, deliberately not clipped to the
        observed max/min), and Assigned_Name (a copy of the group key).
    """
    retention = name_groups['Retention_Time']
    quartiles = {
        label: retention.quantile(q=level)
        for label, level in (('q1', 0.25), ('q2', 0.5), ('q3', 0.75))
    }
    fence = 1.5 * (quartiles['q3'] - quartiles['q1'])

    summary = pd.DataFrame({
        **quartiles,
        'upper': quartiles['q3'] + fence,
        'lower': quartiles['q1'] - fence,
    })
    summary['Assigned_Name'] = summary.index.copy()
    return summary
    
# Seed the whisker data with quartiles of the unfiltered table, grouped by the original compound names.
df_iqr = get_quantiles_df(name_groups=df.groupby(by='Hit_Name'))
iqr_source = ColumnDataSource(data=df_iqr)

In [None]:
# JavaScript that rebuilds iqr_source from the rows currently listed in view.filter.indices.
# It is not standalone: it is appended to update_view_code, which defines `t`
# (total_source.data) and refreshes the view's indices first.
calc_iqr_code = """
  // Group the retention times of the visible rows by assigned name
  var visible_rows = view.filter.indices;
  var iqr_data = new Map();
  for (var i = 0; i < visible_rows.length; i++) {
    var row = visible_rows[i];
    var key = t.Assigned_Name[row];
    if (!iqr_data.has(key)) {
      iqr_data.set(key, []);
    }
    iqr_data.get(key).push(t.Retention_Time[row]);
  }

  // Linear-interpolation quantile (R-7 method, the same as pandas' default used on the Python side).
  // See https://en.wikipedia.org/wiki/Quantile#Estimating_quantiles_from_a_sample
  function QUANTILE(data, q) {
    const values = data.slice().sort((a, b) => a - b); // sort a copy numerically
    const h = (values.length - 1)*q;
    const base = Math.floor(h);
    if (base === h) {
      return values[base];
    }
    return values[base] + (h - base)*(values[base + 1] - values[base]);
  }

  // Build a brand-new data dict and assign it in one go. Replacing iqr_source.data (instead of
  // clearing and refilling its columns in place) notifies the Whisker. It also drops columns that
  // are not rebuilt here, such as the column ColumnDataSource created from the initial DataFrame's
  // index. Left in place, that column would keep its old length and leave the source with
  // inconsistent column lengths whenever a compound is filtered out.
  var summary = {Assigned_Name: [], q1: [], q2: [], q3: [], upper: [], lower: []};
  for (const [name, values] of iqr_data) {
    const q1 = QUANTILE(values, 0.25);
    const q2 = QUANTILE(values, 0.5);
    const q3 = QUANTILE(values, 0.75);
    const fence = 1.5*(q3 - q1);
    summary.Assigned_Name.push(name);
    summary.q1.push(q1);
    summary.q2.push(q2);
    summary.q3.push(q3);
    summary.lower.push(q1 - fence);
    summary.upper.push(q3 + fence);
  }
  iqr_source.data = summary;
"""

In [None]:
# Fresh filter widgets for the combined dashboard; these names now refer to new sliders.
min_obs_slider = Slider(
    title="Minimum Observations within Quality Range",
    start=0,
    end=df.groupby('Hit_Name')['Quality'].count().max(),
    step=1,
    value=0,
)
quality_slider = RangeSlider(title="Quality Range", start=0, end=100, value=(1, 99), step=1)

# One callback now both filters the view and recomputes the IQR summary from the visible rows.
update = CustomJS(
    code=update_view_code + calc_iqr_code,
    args={
        'total_source': total_source,
        'iqr_source': iqr_source,
        'view': view,
        'min_obs_slider': min_obs_slider,
        'quality_slider': quality_slider,
    },
)

# Sliders trigger the recomputation...
quality_slider.js_on_change('value', update)
min_obs_slider.js_on_change('value', update)
# ...and so do data-table edits, via the source's 'patching' event.
total_source.js_on_change('patching', update)

In [None]:
p = figure(
    title="Retention Time Ranges",
    x_range=ordered_cats,
    x_axis_label="",
    y_axis_label="Retention Time",
    background_fill_color="#efefef",
    tools="pan,wheel_zoom,ybox_select,xbox_select,lasso_select,reset",
    active_drag="pan",
)
p.xaxis.major_label_orientation = "vertical"

from bokeh.transform import jitter

# Jitter each marker's categorical x position (width 0.3) so overlapping peaks stay distinguishable.
points = p.circle(
    source=total_source,
    view=view,
    x=jitter("Assigned_Name", 0.3, range=p.x_range),
    y="Retention_Time",
    size=4,
    color="#F38630",
    alpha=0.5,
)

tooltips = [
    ("Original Name", "@Hit_Name"),
    ("Quality", "@Quality"),
]
hover_tool = HoverTool(renderers=[points], tooltips=tooltips)
crosshair_tool = CrosshairTool()
p.add_tools(hover_tool, crosshair_tool)

In [None]:
from bokeh.models import Whisker

# Whiskers span each compound's [lower, upper] fences (q1 - 1.5*IQR .. q3 + 1.5*IQR) from iqr_source.
error = Whisker(
    source=iqr_source,
    base="Assigned_Name",
    lower="lower",
    upper="upper",
    level="annotation",
    line_width=2,
)
error.upper_head.size = error.lower_head.size = 20
p.add_layout(error)

In [None]:
# Put it all together. sizing_mode is set on the layout because bokeh.io.show()
# ignores extra keyword arguments when showing a model.
show(column(data_table, quality_slider, min_obs_slider, p, sizing_mode='scale_width'))

# I/O with JavaScript

Another common task when building Bokeh applications is I/O. This has to be accomplished with `CustomJS`, but there are convenient input/upload widgets that the code can be linked to. The code below illustrates exporting (saving) the data table to a file, and loading a previously exported file to update the current values.

## Exporting Data

In [None]:
# Enter a filename to save to.
filename_export = TextInput(value="exported.txt", title="")

# Choose whether to export only the selected points or every row of total_source.
radio = RadioButtonGroup(labels=['Selected Points Only', 'Entire Table'], active=0)

# Button to execute the export command.
export_button = Button(label='Export to Tab Delimited File', button_type='success')
export_button.js_on_click(
    CustomJS(  # https://github.com/surfaceowl-ai/python_visualizations/blob/main/notebooks/bokeh_save_export_data.py
        args=dict(total_source=total_source, filename_input=filename_export, radio=radio,
                  min_obs_slider=min_obs_slider, quality_slider=quality_slider),
        code="""
            // Decide which rows to write.
            var inds;
            if (radio.active == 0) { // First label: selected points only
                inds = total_source.selected.indices;
            } else { // Second label: every row of total_source, including rows currently filtered out
                inds = [];
                for (let i = 0; i < total_source.get_length(); ++i) {
                    inds.push(i);
                }
            }

            // Build the tab-delimited text. Line 1 stores the filter settings (read back by the
            // import callback), line 2 holds the column headers, then one line per exported row.
            // Hash is written last because the importer takes the last field as the row key.
            function table_to_txt(source) {
                const columns = ['Retention_Time', 'Quality', 'Hit_Name', 'Suggested_Name', 'Assigned_Name', 'Hash'];
                var lines = 'Minimum Observations within Quality Range\\t' + min_obs_slider.value.toString()
                    + '\\tQuality Range\\t' + quality_slider.value[0].toString()
                    + '\\t' + quality_slider.value[1].toString();
                lines += '\\nRetention Time\\tQuality\\tHit Name\\tSuggested Name\\tAssigned Name\\tHash';

                for (let i = 0; i < inds.length; i++) {
                    const fields = columns.map((col) => source.data[col][inds[i]].toString());
                    lines += '\\n' + fields.join('\\t');
                }
                return lines;
            }

            var file = new Blob([table_to_txt(total_source)], {type: 'text/plain'});
            var url = window.URL.createObjectURL(file);
            var elem = window.document.createElement('a');
            elem.href = url;
            elem.download = filename_input.value;
            document.body.appendChild(elem);
            elem.click();
            document.body.removeChild(elem);
            // Release the object URL (deferred so the download can start) to avoid leaking the Blob.
            setTimeout(function () { window.URL.revokeObjectURL(url); }, 1000);
            """,
    )
)

## Importing Data

In [None]:
# You can restrict the file suffix in the pop-up window.
filename_import = FileInput(accept=".txt,")

# The FileInput widget loads the file behind the scenes immediately, and saves the result in base64.
# The CustomJS callback below decodes it and applies the saved settings and rows.
import_button = Button(label="Import", button_type='danger')
import_button.js_on_click(CustomJS(
        args=dict(source=total_source, filename_import=filename_import, quality_slider=quality_slider, min_obs_slider=min_obs_slider),
        code="""
            // Nothing to do until a file has been chosen.
            if (!filename_import.value) {
                return;
            }

            var data = atob(filename_import.value); // Convert from base64
            var lines = data.split(/\\r?\\n/);     // tolerate both LF and CRLF line endings

            // Line 1 holds the saved filter settings; these edits will trigger new visible data.
            var settings = lines[0].split('\\t');
            min_obs_slider.value = parseInt(settings[1]);
            quality_slider.value = [parseFloat(settings[3]), parseFloat(settings[4])];

            // Line 2 is the column header row; data rows start at line 3.
            for (let i = 2; i < lines.length; ++i) {
                if (lines[i].length === 0) {
                    continue; // e.g. a trailing newline at the end of the file
                }
                let row = lines[i].split('\\t');
                let hash = row[row.length - 1];
                let index = -1;

                // Try to find entry in total source and update it (matched on the Hash column)
                for (let j = 0; j < source.get_length(); ++j) {
                    if (hash === source.data.Hash[j]) {
                        index = j;
                        break;
                    }
                }
                if (index >= 0) {
                    source.data.Retention_Time[index] = parseFloat(row[0]);
                    source.data.Quality[index] = parseFloat(row[1]); // Quality is fractional; parseInt would truncate it
                    source.data.Hit_Name[index] = row[2];
                    source.data.Suggested_Name[index] = row[3];
                    source.data.Assigned_Name[index] = row[4];
                }
            }

            // (*) Trigger anything the source is directly linked to
            source.change.emit();

            """,
        )
    )

# Also trigger a re-compute after loading the data.  This means we don't really
# need the (*) line above, but if we weren't using this update code then we should
# include the (*) line.
import_button.js_on_click(update)

## Put it all together

In [None]:
# Experiment by importing and exporting!
# Dashboard rows, top to bottom: import controls, editable table, filters, plot, export controls.
show(layout(
    [
        [filename_import, import_button],
        [data_table],
        [min_obs_slider, quality_slider],
        [p],
        [filename_export, radio, export_button],
    ],
    sizing_mode='scale_width',
))