## phylo tree view

* https://dash.plot.ly/cytoscape/biopython
* http://chuckpr.github.io/blog/trees2.html
* https://biopython.org/wiki/Phylo
* https://besjournals.onlinelibrary.wiley.com/doi/full/10.1111/2041-210X.13313

In [None]:
import os, sys, io, random, subprocess, time
import string
import math
import numpy as np
import pandas as pd
pd.set_option('display.width', 200)
pd.set_option('display.max_colwidth', 100)
from importlib import reload
%matplotlib inline

from Bio.Seq import Seq
from Bio.SeqRecord import SeqRecord
from Bio.Align import MultipleSeqAlignment
from Bio import AlignIO, SeqIO
from Bio import Phylo

from IPython.display import HTML

from bokeh.plotting import figure
from bokeh.models import (ColumnDataSource, Plot, LinearAxis, Grid, Range1d,CustomJS, Slider,
                          HoverTool, NumeralTickFormatter, Label, Range1d)
from bokeh.models.glyphs import Text, Rect, Segment, Circle
from bokeh.layouts import gridplot, column
import panel as pn
import panel.widgets as pnw
pn.extension()
from pybioviz import utils, plotters

In [None]:
# Read the Newick tree stored as 'test.dnd' under utils.datadir, then show
# both its object summary and an ASCII rendering.
tree_file = os.path.join(utils.datadir, 'test.dnd')
tree = Phylo.read(tree_file, format='newick')
print(tree)
Phylo.draw_ascii(tree)

In [None]:
# Give every leaf (terminal clade) an even row number in traversal order;
# the odd rows in between are left free for internal clades.
taxa = tree.get_terminals()
positions = {taxon: 2 * idx for idx, taxon in enumerate(taxa)}
positions

In [None]:
def generate_elements(tree, xlen=30, ylen=30, grabbable=False):
    """Convert a phylogenetic tree into Cytoscape-style node and edge dicts.

    Every clade becomes a node placed on a grid: its column is derived from
    the clade's depth (cumulative branch length) and its row from the order
    of the leaves. Each parent/child pair is connected through an extra
    "support" node so the link can be drawn as a right angle
    (parent -> support -> child).

    Args:
        tree: Tree object providing ``get_terminals()``, ``depths()``,
            ``root`` and ``clade`` (the notebook passes a Bio.Phylo tree).
        xlen: Factor applied to column positions to obtain x coordinates.
        ylen: Factor applied to row positions to obtain y coordinates.
        grabbable: Value stored under ``'grabbable'`` on every node.

    Returns:
        A ``(nodes, edges)`` tuple of lists of dicts. Node ids start at
        ``'r'`` for the root; the n-th child of a node with id ``X`` gets the
        id ``X + 'c' + str(n)`` and its support node ``X + 's' + str(n)``.
    """

    def get_col_positions(tree, column_width=80):
        """Create a mapping of each clade to its column position."""
        taxa = tree.get_terminals()

        # Some constants for the drawing calculations
        max_label_width = max(len(str(taxon)) for taxon in taxa)
        drawing_width = column_width - max_label_width - 1

        depths = tree.depths()
        # If there are no branch lengths, assume unit branch lengths
        if not max(depths.values()):
            depths = tree.depths(unit_branch_lengths=True)
        # A tree consisting of just the root still has a maximum depth of 0;
        # fall back to 1 so the scaling below cannot divide by zero.
        max_depth = max(depths.values()) or 1
        # Potential drawing overflow due to rounding -- 1 char per tree layer
        fudge_margin = int(math.ceil(math.log(len(taxa), 2)))
        cols_per_branch_unit = (drawing_width - fudge_margin) / float(max_depth)
        return {clade: int(blen * cols_per_branch_unit + 1.0)
                for clade, blen in depths.items()}

    def get_row_positions(tree):
        """Create a mapping of each clade to its row position."""
        taxa = tree.get_terminals()
        # Leaves occupy the even rows, in traversal order.
        positions = {taxon: 2 * idx for idx, taxon in enumerate(taxa)}

        def calc_row(clade):
            # Place children first, then centre the parent between its first
            # and last child.
            for subclade in clade:
                if subclade not in positions:
                    calc_row(subclade)
            positions[clade] = ((positions[clade.clades[0]] +
                                 positions[clade.clades[-1]]) // 2)

        # When the root is itself a leaf it already has a row and has no
        # children to average over.
        if tree.root not in positions:
            calc_row(tree.root)
        return positions

    def add_to_elements(clade, clade_id):
        """Append the node for ``clade`` and, recursively, its subtree."""
        children = clade.clades

        pos_x = col_positions[clade] * xlen
        pos_y = row_positions[clade] * ylen

        cy_source = {
            'data': {'id': clade_id},
            'position': {'x': pos_x, 'y': pos_y},
            'classes': 'nonterminal',
            'grabbable': grabbable,
        }
        nodes.append(cy_source)

        if clade.is_terminal():
            cy_source['data']['name'] = clade.name
            cy_source['classes'] = 'terminal'

        # Confidence may be a plain number or an object exposing ``.value``;
        # accept both. As before, it is only recorded on clades that have
        # children.
        confidence = clade.confidence
        if children and confidence:
            confidence_value = getattr(confidence, 'value', confidence)
            if confidence_value:
                cy_source['data']['confidence'] = confidence_value

        for n, child in enumerate(children):
            # The "support" node is on the same column as the parent clade,
            # and on the same row as the child clade. It is used to create the
            # 90 degree angle between the parent and the children.
            # Edge config: parent -> support -> child

            support_id = clade_id + 's' + str(n)
            child_id = clade_id + 'c' + str(n)
            pos_y_child = row_positions[child] * ylen

            cy_support_node = {
                'data': {'id': support_id},
                'position': {'x': pos_x, 'y': pos_y_child},
                'grabbable': grabbable,
                'classes': 'support',
            }

            cy_support_edge = {
                'data': {
                    'source': clade_id,
                    'target': support_id,
                    'sourceCladeId': clade_id,
                },
            }

            # NOTE(review): 'length' carries the parent's branch length, not
            # the child's -- kept as-is; confirm which one is intended.
            cy_edge = {
                'data': {
                    'source': support_id,
                    'target': child_id,
                    'length': clade.branch_length,
                    'sourceCladeId': clade_id,
                },
            }

            nodes.append(cy_support_node)
            edges.extend([cy_support_edge, cy_edge])

            add_to_elements(child, child_id)

    col_positions = get_col_positions(tree)
    row_positions = get_row_positions(tree)
    nodes = []
    edges = []
    add_to_elements(tree.clade, 'r')

    return nodes, edges


In [None]:
# Build the node and edge dict lists for the tree loaded earlier in the notebook.
nodes, edges = generate_elements(tree)

We plot edges, not nodes, so we want a dataframe in which each row represents one edge and carries everything needed to draw it.
This includes the edge's start and end coordinates, a label for terminal edges, and a column marking whether the edge leads to a tip.

In [None]:
# Inspect the first node dict; generate_elements appends the root (id 'r') first.
nodes[0]


In [None]:
# Hand-written single-edge record keyed by row index: start point (x0, y0),
# end point (x1, y1) and a colour code.
testdf = {0:{'x0':4,'y0':3,'x1':6,'y1':4,'color':'g'}}
def fake_data(l=10, n=20):
    """Return a DataFrame of ``l`` random edges for trying out the plot.

    Every edge starts at the fixed point (5, 10) and ends at integer
    coordinates drawn with ``np.random.randint(n)``. Each row is marked as
    class 'terminal' and labelled 'seq<y1>'. An ``id`` column (0..l-1) and a
    ``color`` column filled by ``utils.random_colors`` are added at the end.
    """
    start_x = 5
    start_y = start_x + 5
    records = {}
    for row in range(l):
        # Draw x before y so the random stream is consumed in a fixed order.
        end_x = np.random.randint(n)
        end_y = np.random.randint(n)
        records[row] = {
            'x0': start_x,
            'y0': start_y,
            'x1': end_x,
            'y1': end_y,
            'class': 'terminal',
            'label': 'seq{}'.format(end_y),
        }
    # Dict-of-dicts yields one column per edge; transpose to one row per edge.
    frame = pd.DataFrame(records).T
    frame['id'] = range(len(frame))
    frame['color'] = utils.random_colors(len(frame))
    return frame

# Five random edges used as the initial data set.
df = fake_data(5)

# URLs (or any other per-row data) can be added to df and shown in tooltips;
# here the same image URL is repeated once per row.
imgs = ['https://upload.wikimedia.org/wikipedia/commons/d/d4/CH_cow_2_cropped.jpg'] * len(df)
# Uncomment to expose the URLs to the '@imgs' placeholder used below.
#df['imgs'] = imgs

# HTML tooltip template showing the image from the 'imgs' column.
# NOTE(review): the trailing 'x' after the closing </div> looks like a typo --
# left untouched here; confirm before removing.
tooltip = """
        <div>
            <img
                src="@imgs" height="52" alt="@imgs" width="52"
                style="float: left; margin: 0px 15px 15px 0px;"
                border="2"
            ></img>
        </div>x
"""

In [None]:
def plot_tree(df, radius=.4):
    """Draw the edges in ``df`` as a Bokeh figure and return it.

    Args:
        df: DataFrame with one row per edge. The glyphs read the columns
            ``x0``, ``y0``, ``x1``, ``y1`` and ``color``; the hover tooltip
            reads ``id`` and ``label``.
        radius: Radius of the circle drawn at each edge's end point (x1, y1).

    Returns:
        The Bokeh figure, with axes, grid and toolbar logo hidden.
    """
    source = ColumnDataSource(df)

    # Circle marking the end point of every edge.
    leaf_glyph = Circle(x="x1", y="y1",
                        radius=radius,
                        fill_color="color",
                        name="circles",
                        fill_alpha=0.6)
    # The edges themselves, drawn twice from the same columns: once as a
    # thin translucent line and once as a heavier black line on top.
    tree_glyph = Segment(x0="x0", y0="y0",
                         x1="x1", y1="y1",
                         line_color="#151515",
                         line_alpha=0.75)
    tip_glyph = Segment(x0="x0", y0="y0",
                        x1="x1", y1="y1",
                        line_color="black",
                        line_width=1.65)

    # Custom HTML tooltip; '@name' placeholders are filled from the
    # matching data source columns.
    tooltip = """
        <div>
            <span style="font-size: 16px; color: blue;">id: </span> 
            <span style="font-size: 16px; font-weight: bold;"> @id </span>
        </div>
        <div>@label</div>
    """
    hover = HoverTool(
        tooltips=tooltip, point_policy='follow_mouse')

    # NOTE(review): plot_width/plot_height are the older Bokeh spellings of
    # width/height -- confirm against the Bokeh version in use.
    p = figure(title=None, plot_width=400, plot_height=400,
               tools=[hover, "box_zoom,xwheel_zoom,reset,save"])

    p.add_glyph(source, leaf_glyph)
    p.add_glyph(source, tree_glyph)
    p.add_glyph(source, tip_glyph)
    p.xaxis.visible = False
    p.yaxis.visible = False
    p.grid.visible = False
    p.toolbar.logo = None
    return p

In [None]:
# Small Panel app: a vertical slider chooses how many random edges to draw
# and the Bokeh pane next to it shows the resulting plot.
plot = pn.pane.Bokeh()
edges_slider = pnw.IntSlider(name='edges', value=10, end=100,
                             orientation='vertical', height=300)
# NOTE(review): this button is created but not connected to anything in the
# code shown here -- confirm whether it should trigger replot.
b = pnw.Button(name='redraw')


def replot(event):
    """Redraw the pane with fresh random data sized by the slider value."""
    plot.object = plot_tree(fake_data(edges_slider.value), radius=.4)


# Re-plot whenever the slider value changes, then show the initial data.
edges_slider.param.watch(replot, 'value')
plot.object = plot_tree(df)
app = pn.Row(edges_slider, plot)
app