In [None]:
from matplotlib import pyplot as plt
from Bio import Phylo
import pandas as pd
import numpy as np
from copy import deepcopy
from io import StringIO
import matplotlib as mpl
from scripts.lib.plotting import load_style

# Shared 'paper' plotting preset: pull out the figure-saving helper and the
# two standard figure widths it provides.
loaded_style = load_style('paper')
savefig, fullwidth, halfwidth = (
    loaded_style[style_key] for style_key in ('savefig', 'fullwidth', 'halfwidth')
)

def annotate_clade_with_plotting_metadata(clade, left=0, bottom=0):
    """Attach rectangular-cladogram layout coordinates to `clade` and its subtree.

    Sets on every clade:
      * ``left``  -- x where its branch starts (``left`` argument),
      * ``right`` -- x where its branch ends (``left`` + branch length; a
        missing or zero branch length counts as 0),
      * ``y``     -- vertical position. Leaves are stacked on successive rows
        starting at ``bottom``; an internal clade sits midway between its
        lowest and highest child.

    Returns
    -------
    (span, y)
        ``span`` is the number of leaves under the clade minus one, i.e. how
        many rows above ``bottom`` the subtree occupies; ``y`` is the clade's
        own vertical position.
    """
    clade.left = left
    clade.right = left + (clade.branch_length or 0)

    if clade.is_terminal():
        clade.y = bottom
        return 0, clade.y

    rows_used = 0
    child_ys = []
    for child in clade.clades:
        # Each child subtree starts on the first row not yet used by its
        # earlier siblings.
        child_span, child_y = annotate_clade_with_plotting_metadata(
            child, left=clade.right, bottom=bottom + rows_used)
        child_ys.append(child_y)
        rows_used += child_span + 1

    clade.y = (min(child_ys) + max(child_ys)) / 2.0
    return rows_used - 1, clade.y

def plot_clade(clade, ax,
               tax_plotter=lambda c, ax: (),
               node_plotter=lambda c, ax: (),
               lw=1):
    """Recursively draw `clade` and its descendants as a rectangular cladogram.

    Each clade must already carry the ``left``, ``right`` and ``y`` attributes
    written by ``annotate_clade_with_plotting_metadata``.

    Parameters
    ----------
    clade
        Root of the subtree to draw (needs ``clades``, ``is_terminal()``,
        ``left``, ``right``, ``y``).
    ax
        Matplotlib Axes to draw on.
    tax_plotter : callable(clade, ax)
        Called once for each terminal clade (e.g. to draw its label).
    node_plotter : callable(clade, ax)
        Called once for every clade, after its whole subtree has been drawn.
    lw : float
        Line width for the horizontal branches and vertical connectors of
        this clade *and all of its descendants*.

    Returns
    -------
    The clade's ``y`` position, so the parent can span its children with a
    vertical connector.
    """
    # Horizontal segment: this clade's own branch.
    ax.hlines(clade.y, clade.left, clade.right, lw=lw)
    # Forward lw so the whole subtree is drawn with one consistent width.
    children_ys = [plot_clade(child, ax,
                              tax_plotter=tax_plotter,
                              node_plotter=node_plotter,
                              lw=lw)
                   for child in clade.clades]
    if children_ys:
        # Vertical connector at the end of this branch joining all children.
        ax.vlines(clade.right, min(children_ys), max(children_ys), lw=lw)
    if clade.is_terminal():
        tax_plotter(clade, ax)
    node_plotter(clade, ax)
    return clade.y

def plot_tree(tree, ax, **kwargs):
    """Lay out and draw a deep copy of `tree` on `ax`; return that copy.

    Copying first means the layout attributes (``left``/``right``/``y``) are
    written onto the copy, never onto the caller's tree. `tree` may be a
    Bio.Phylo tree (its ``.clade`` is drawn) or a bare clade. Any extra
    keyword arguments are forwarded unchanged to ``plot_clade``.
    """
    annotated = deepcopy(tree)
    root_clade = getattr(annotated, 'clade', annotated)
    annotate_clade_with_plotting_metadata(root_clade)
    plot_clade(root_clade, ax, **kwargs)
    return annotated
    
def fold_by_group(clade, key):
    """Collapse every subtree whose leaves all share one (non-None) group.

    ``key(leaf)`` assigns each terminal clade a group; a ``None`` group means
    "ungrouped" and blocks collapsing. Working bottom-up, a clade whose
    children all report the same non-None group has its children removed
    (it becomes a leaf) and is annotated with:
      * ``group``     -- the shared group,
      * ``count``     -- number of original leaves folded into it,
      * ``minbranch`` / ``maxbranch`` -- shortest / longest distance from this
        clade's node down to any of those leaves.
    Leaves get ``group = key(leaf)``, ``count = 1`` and zero min/max branch.
    A missing (``None``) branch length is treated as 0, matching
    ``annotate_clade_with_plotting_metadata``.

    Returns
    -------
    (group, min_dist, max_dist, count)
        ``min_dist``/``max_dist`` include this clade's own branch length.
        For a clade that cannot be collapsed: ``(None, 0, 0, None)``.
    """
    own_length = clade.branch_length or 0

    if clade.is_terminal():
        clade.group = key(clade)
        clade.minbranch = 0
        clade.maxbranch = 0
        clade.count = 1
        return clade.group, own_length, own_length, 1

    # Build the full list (not a lazy generator) so every child subtree is
    # folded even when this clade itself turns out to be mixed.
    child_group, child_minbranch, child_maxbranch, child_count = zip(
        *[fold_by_group(child, key) for child in clade.clades])

    unique_groups = set(child_group)
    if None in unique_groups or len(unique_groups) > 1:
        # Mixed or ungrouped children: keep this clade intact and make sure
        # no ancestor collapses over it either.
        return None, 0, 0, None

    # Every leaf below belongs to one group: drop the children.
    clade.count = sum(child_count)
    clade.clades = []
    (clade.group,) = unique_groups
    clade.minbranch = min(child_minbranch)
    clade.maxbranch = max(child_maxbranch)
    return (clade.group,
            clade.minbranch + own_length,
            clade.maxbranch + own_length,
            clade.count)

In [None]:
# Display label for each tree leaf. Leaves not listed here keep their
# original name.
rename_taxon = dict([
    ('Otu0001_vC', 'B1-A'),
    ('Otu0001_vB', 'B1-B'),
    ('Otu0007_vA', 'B2'),
    ('Otu0004_vA', 'B3'),
    ('Otu0005_vA', 'B4'),
    ('Otu0009_vA', 'B5'),
    ('Otu0017_vA', 'B6'),
    ('Otu0049_vA', 'B7'),
    ('Muribaculaceae_bacterium_DSM_100764', 'DSM-100764'),
    ('Muribaculaceae_bacterium_DSM_100720', 'DSM-100720'),
    ('Barnesiella_viscericola_DSM_18177', 'Bc'),
    ('Bacteroides_ovatus_ATCC_8483', 'Bo'),
    ('Bacteroides_thetaiotaomicron_VPI5482', 'Bt'),
    ('Porphyromonas_gingivalis_ATCC_33277', 'Pg'),
    ('Homeothermus_arabinoxylanisolvens', 'H. arabinoxylanisolvens'),
    ('Muribaculum_intestinale_yl27', 'M. intestinale'),
])

# Label colour per guild (looked up from the metadata's ormerod_guild
# column); any other guild is drawn black.
color_map = {'starch': 'blue', 'host': 'purple', 'plant': 'green'}

# Leaves (named by genus/species) whose labels are set in italics.
italic_list = [
    'Barnesiella_viscericola_DSM_18177',
    'Bacteroides_ovatus_ATCC_8483',
    'Bacteroides_thetaiotaomicron_VPI5482',
    'Porphyromonas_gingivalis_ATCC_33277',
    'Homeothermus_arabinoxylanisolvens',
    'Muribaculum_intestinale_yl27',
]

# Tab-separated genome metadata, indexed by genome_id; _tax_plotter looks up
# each leaf name in this index.
mag = pd.read_csv('meta/genome.tsv', sep='\t', index_col='genome_id')

def _tax_plotter(clade, ax):
    """Write the display label for a terminal clade just right of its tip.

    * Text: ``rename_taxon[name]`` if present, otherwise the raw leaf name,
      with ``'*'`` appended when ``mag.genome_type`` is ``'here'`` for it.
    * Style: italic for leaves in ``italic_list``, normal otherwise.
    * Colour: ``color_map`` entry for the leaf's ``mag.ormerod_guild``,
      black for any guild not in the map.

    Every leaf name is expected to be present in ``mag``'s index.
    """
    name = clade.name
    label = rename_taxon.get(name, name)
    if mag.genome_type[name] == 'here':
        label += '*'
    font_style = 'italic' if name in italic_list else 'normal'
    guild = mag.ormerod_guild[name]
    label_color = color_map.get(guild, 'black')
    # Small x offset (data units) so the text does not touch the branch end.
    ax.annotate(label,
                xy=(clade.right + 0.015, clade.y),
                ha='left',
                va='center',
                weight='bold',
                style=font_style,
                color=label_color,
                fontsize=6.5)
    
def _node_plotter(clade, ax):
    confidence = clade.confidence
    if confidence is None:
        return

    if confidence > 0.95:
        if clade.is_terminal():
            hpos = clade.right - 0.013
        else:
            hpos = clade.right
        color = 'k'
        ax.scatter([hpos], [clade.y], color=color,
                   marker='o', s=12, linewidths=0, edgecolors='k', zorder=10)

    
    
# Load the marker-gene tree, root it on the outgroup reference genomes, then
# prune those outgroup taxa so only the ingroup is drawn.
tree = Phylo.read('data/core.a.mags.muri.g.final.marker_genes.refine.gb.prot.nwk', 'newick')
outgroup = [ 'Barnesiella_viscericola_DSM_18177'
           , 'Porphyromonas_gingivalis_ATCC_33277'
           , 'Bacteroides_ovatus_ATCC_8483'
           , 'Bacteroides_thetaiotaomicron_VPI5482'
           ]
tree.root_with_outgroup(outgroup)
for t in outgroup:
    tree.prune(t)
# Short stub branch so the root is visible at the left edge of the plot.
tree.clade.branch_length = 0.06
#tree.clade.confidence = None
# Collapse clades whose support is below 0.7 (clades without a confidence
# value are kept).
tree.collapse_all(lambda c: c.confidence is not None and c.confidence < 0.7)
#tree.ladderize(reverse=True)

fig, ax = plt.subplots(figsize=(halfwidth, 6))
tree_annotated = plot_tree(tree, ax,
                           tax_plotter=_tax_plotter,
                           node_plotter=_node_plotter)

ax.set_yticks([])
ax.set_xticks([])

# Scale bar: a horizontal segment `scale_length` branch-length units long,
# placed at `scale_xy` in data coordinates (x = branch length, y = leaf row).
scale_xy = (0.5, 0.5)
scale_length = 0.2
_x, _y = scale_xy
ax.hlines([_y], [_x], [_x + scale_length], lw=1)
# '{:g}' prints the length faithfully; '{:.1}' keeps only one significant
# digit and would label e.g. 0.25 as '0.2'.
ax.annotate('{:g}'.format(scale_length),
            xy=(_x + scale_length / 2, _y + 0.4),
            ha='center', fontsize=8)

ax.axis('off')
#fig.tight_layout()

savefig(fig, 'build/figure_tree')