In [None]:
from matplotlib import pyplot as plt
from Bio import Phylo
import pandas as pd
import numpy as np
from copy import deepcopy
from io import StringIO
import matplotlib as mpl
from scripts.lib.plotting import load_style

# Load the project's 'paper' plotting style and pull out the three entries
# this notebook uses: the figure-saving helper and the two figure widths.
loaded_style = load_style('paper')
savefig, fullwidth, halfwidth = (
    loaded_style[entry] for entry in ('savefig', 'fullwidth', 'halfwidth')
)

def annotate_clade_with_plotting_metadata(clade, left=0, bottom=0):
    """Attach plotting coordinates (`left`, `right`, `y`) to a clade, recursively.

    Parameters
    ----------
    clade
        Node exposing `branch_length`, `clades` and `is_terminal()`.
    left
        x position at which this clade's branch starts.
    bottom
        y position given to the first tip found underneath this clade.

    Returns
    -------
    (extent, y)
        `extent` is the number of extra rows the subtree occupies beyond its
        first one (0 for a tip); `y` is the vertical position assigned to
        `clade`.

    Each tip takes one integer row; an internal node is centred between its
    lowest and highest direct child.  A missing or zero branch length is
    drawn as zero length.
    """
    clade.left = left
    clade.right = left + (clade.branch_length or 0)

    if clade.is_terminal():
        clade.y = bottom
        return 0, clade.y

    child_ys = []
    rows_used = 0
    for child in clade.clades:
        # Each child starts where this branch ends, stacked above its siblings.
        extent, child_y = annotate_clade_with_plotting_metadata(
            child, left=clade.right, bottom=bottom + rows_used)
        child_ys.append(child_y)
        rows_used += extent + 1

    clade.y = (min(child_ys) + max(child_ys)) / 2.0
    return rows_used - 1, clade.y

def plot_clade(clade, ax,
               tax_plotter=lambda c, ax: (),
               node_plotter=lambda c, ax: (),
               lw=1):
    """Draw an annotated clade and all of its descendants onto `ax`.

    Parameters
    ----------
    clade
        Node that already carries `left`, `right` and `y` (as set by
        `annotate_clade_with_plotting_metadata`), plus `clades` and
        `is_terminal()`.
    ax
        Axes-like object providing `hlines` and `vlines`.
    tax_plotter
        Callback `(clade, ax)` invoked for every terminal clade.
    node_plotter
        Callback `(clade, ax)` invoked for every clade, terminal or not,
        after its children have been drawn.
    lw
        Line width used for every branch of the subtree.

    Returns
    -------
    The `y` position of `clade`.
    """
    ax.hlines(clade.y, clade.left, clade.right, lw=lw)
    # Forward `lw` to the recursion; without it every descendant would be
    # drawn with the default width and only the root branch would honour `lw`.
    children_ys = [plot_clade(c, ax,
                              tax_plotter=tax_plotter,
                              node_plotter=node_plotter,
                              lw=lw)
                   for c in clade.clades]
    if children_ys:
        # Vertical connector spanning the lowest to the highest child.
        ax.vlines(clade.right, min(children_ys), max(children_ys), lw=lw)
    if clade.is_terminal():
        tax_plotter(clade, ax)
    node_plotter(clade, ax)
    return clade.y

def plot_tree(tree, ax, **kwargs):
    """Plot `tree` on `ax` without modifying the caller's object.

    A deep copy of `tree` is annotated with plotting coordinates and drawn;
    extra keyword arguments are forwarded to `plot_clade`.  `tree` may be an
    object with a `.clade` attribute or a clade itself.

    Returns the annotated copy, whose nodes carry `left`, `right` and `y`.
    """
    annotated = deepcopy(tree)
    # Accept either a tree wrapper (use its root clade) or a bare clade.
    root = getattr(annotated, 'clade', annotated)
    annotate_clade_with_plotting_metadata(root)
    plot_clade(root, ax, **kwargs)
    return annotated
    
def fold_by_group(clade, key):
    """Collapse, in place, every subtree whose tips all belong to one group.

    Parameters
    ----------
    clade
        Node exposing `branch_length`, `clades` and `is_terminal()`.
    key
        Callable mapping a terminal clade to its group label.  A label of
        `None` prevents every ancestor of that tip from being collapsed.

    Returns
    -------
    (group, min_depth, max_depth, count)
        For a tip or a collapsed subtree: its group, the shortest and longest
        root-to-tip distance measured from the start of this clade's own
        branch, and the number of tips it represents.  For a subtree that
        mixes groups: `(None, 0, 0, None)`.

    Side effects: tips and collapsed clades gain `group`, `minbranch`,
    `maxbranch` and `count` attributes; collapsed clades lose their children.
    Clades that mix groups are left untouched.
    """
    # A missing branch length (e.g. on the root) counts as zero, matching
    # annotate_clade_with_plotting_metadata; otherwise None would reach
    # min()/max()/+ below and raise TypeError.
    branch_length = clade.branch_length or 0

    if clade.is_terminal():
        clade.group = key(clade)
        clade.minbranch = 0
        clade.maxbranch = 0
        clade.count = 1
        return clade.group, branch_length, branch_length, 1

    child_group, child_minbranch, child_maxbranch, child_count = \
            zip(*[fold_by_group(c, key) for c in clade.clades])
    unique_groups = set(child_group)
    if (None in unique_groups) or len(unique_groups) > 1:
        # Children are not of one group:
        # don't collapse anything and pass on null values.
        return None, 0, 0, None

    # Exactly one group below this node: drop the children and let this
    # clade stand in for all of them.
    (clade.group,) = unique_groups
    clade.count = sum(child_count)
    clade.clades = []
    clade.minbranch = min(child_minbranch)
    clade.maxbranch = max(child_maxbranch)
    return (clade.group,
            clade.minbranch + branch_length,
            clade.maxbranch + branch_length,
            clade.count)

In [None]:
def _tax_plotter(clade, ax):
    if clade.name:
        label = clade.name
    else:
        label = '{} ({})'.format(clade.group, clade.count)
    if hasattr(clade, 'minbranch'):
        x = [clade.right,
             clade.right + clade.minbranch,
             clade.right + clade.maxbranch,
             clade.right
            ]
        y = [clade.y + 0.45] * 2 + [clade.y - 0.45] * 2
        ax.add_patch(plt.Polygon(xy=list(zip(x, y)), alpha=0.25))
    ax.annotate(label, xy=(clade.right + 0.0075, clade.y), ha='left', va='center', weight='bold')

# Smoke test of the helpers on a small hand-written Newick tree.
test_tree_string = '((D:0.723274,((F:0.567784,Z:0.3):0.3, Q:0.4):0.3)1.000000:0.167192,(B:0.279326,H:0.756049)1.000000:0.807788);'
tree = Phylo.read(StringIO(test_tree_string), 'newick')
# First figure: the tree as read.  plot_tree works on a deep copy, so `tree`
# itself is not annotated here.
fig, ax = plt.subplots()
plot_tree(tree, ax, tax_plotter=_tax_plotter)


# Second figure: fold_by_group edits `tree` IN PLACE (re-running only this
# part of the cell folds an already-folded tree).  The key maps each tip name
# to a hard-coded group and raises KeyError for any other name.
fold_by_group(tree.clade, lambda c: {'H': 'Group 0', 'B': 'Group 0', 'Q': 'Group 1', 'Z': 'Group 1', 'F': 'Group 1', 'D': 'Group 0'}[c.name])
fig, ax = plt.subplots()
plot_tree(tree, ax, tax_plotter=_tax_plotter)

In [None]:
def get_all_node_values(clade, attr):
    """Return the value of `attr` for every node under `clade`.

    Values are listed for `clade.get_terminals()` first, followed by
    `clade.get_nonterminals()`.

    Both node lists come from the `clade` argument.  The previous version
    read the internal nodes from the notebook global `clade_of_interest`,
    which only worked when that global existed and was the clade passed in.
    """
    return [getattr(node, attr)
            for node in clade.get_terminals() + clade.get_nonterminals()]

In [None]:
# Display labels for specific tip names; _tax_plotter uses these instead of
# the group label when the clade's name appears here.
rename_taxon = {
    'KR364784.1': 'Muribaculum intestinale',
    # OTU sub-clusters
    'Otu000001.1': 'OTU-1.1',
    'Otu000001.2': 'OTU-1.2',
    'Otu000001.3': 'OTU-1.3',
    'Otu000001.4': 'OTU-1.4',
    'Otu000004.1': 'OTU-4.1',
    'Otu000004.2': 'OTU-4.2',
    # OTUs, in both the six-digit and four-digit spellings
    'Otu000001': 'OTU-1',
    'Otu000004': 'OTU-4',
    'Otu0001': 'OTU-1',
    'Otu0004': 'OTU-4',
    # Reference sequences keyed by their full taxonomy string
    'Bacteria|Bacteroidetes|Bacteroidia|Bacteroidales|Muribaculaceae||AJ400263.Unc18892': 'S24-7 (clone)',
    'Bacteria|Bacteroidetes|Bacteroidia|Bacteroidales|S24_7||AJ400267.Unc29200': 'S24-7 (AJ400267)',
    'Bacteria|Bacteroidetes|Bacteroidia|Bacteroidales|S24_7||AJ400235.Unc29190': 'S24-7 (AJ400235)',
}
                 
# Family name -> colour, pairing each family with the tab10 colour at the
# same position.  _tax_plotter looks families up here and falls back to
# black for anything missing.
# NOTE(review): keys must match the family strings in the designation table
# exactly; 'Dysgonamonadaceae' looks like it may be a misspelling - confirm
# against the TSV before changing it.
color_map = {
    family: color
    for family, color in zip(
        ('Bacteroidaceae',
         'Barnesiellaceae',
         'Dysgonamonadaceae',
         'Marinifilaceae',
         'Marinilabiliaceae',
         'Paludibacteraceae',
         'Porphyromonadaceae',
         'Prevotellaceae',
         'Rikenellaceae',
         'Tannerellaceae'),
        plt.cm.tab10.colors,
    )
}

def _fold_key(clade):
    try:
        group = clade.name.split('|')[5]
        if group == '':
            group = clade.name
    except IndexError:
        group = clade.name
    return group

def _tax_plotter(clade, ax):
    """Draw the label, and a wedge for collapsed clades, for one tip.

    Label:
      * unnamed clade      -> '<group> (<count>)'
      * name in rename_taxon -> the renamed label
      * any other name     -> the clade's group

    Colour is `color_map[family_map[clade.group]]`, falling back to black
    when either lookup raises KeyError.  Clades representing more than one
    tip additionally get a translucent quadrilateral reaching `minbranch`
    along its upper edge and `maxbranch` along its lower edge.

    Reads the notebook globals `rename_taxon`, `family_map` and `color_map`,
    and requires the attributes set by `fold_by_group`.
    """
    if not clade.name:
        # Collapsed clade: show how many taxa it stands for.
        label = '{} ({})'.format(clade.group, clade.count)
    elif clade.name in rename_taxon:
        label = rename_taxon[clade.name]
    else:
        label = '{}'.format(clade.group)

    try:
        color = color_map[family_map[clade.group]]
    except KeyError:
        color = 'black'

    if clade.count > 1:
        upper = clade.y + 0.4
        lower = clade.y - 0.4
        wedge = [(clade.right, upper),
                 (clade.right + clade.minbranch, upper),
                 (clade.right + clade.maxbranch, lower),
                 (clade.right, lower)]
        ax.add_patch(plt.Polygon(xy=wedge, alpha=0.3, color=color))

    # The text is identical whether or not a wedge was drawn.
    ax.annotate(label, xy=(clade.right + 0.015, clade.y),
                ha='left', va='center',
                weight='bold', style='italic',
                color=color, fontsize=8)
    
def _node_plotter(clade, ax):
#    if clade.is_terminal():
#        return
    confidence = clade.confidence
    if confidence is None:
        return

    if confidence > 0.95:
        if clade.is_terminal():
            hpos = clade.right - 0.013
        else:
            hpos = clade.right
        color = 'k'
        ax.scatter([hpos], [clade.y], color=color,
                   marker='o', s=12, linewidths=0, edgecolors='k', zorder=10)
#    elif confidence > 0.9:
#        color = 'grey'
#    elif confidence > 0.8:
#        color = 'lightgrey'
#    else:
#        color = 'w'
#    ax.scatter([clade.right], [clade.y], color=color, s=20, linewidths=1, edgecolors='k', zorder=10)

    
    
# --- Build the S24-7 tree figure -------------------------------------------
# Inputs: the Newick tree, the tip name used as outgroup, and a table indexed
# by 'genus' whose 'ormerod_family' column _tax_plotter reads via the global
# `family_map`.
tree = Phylo.read('res/s247_of_interest.wrefs.press.realign.mask.uniq.nwk', 'newick')
outgroup = 'Bacteria|Bacteroidetes|Bacteroidia|Flavobacteriales|Flavobacteriaceae|Flavobacterium||JQ800019.H6DSpec7'
family_map = pd.read_table('meta/misc/bacteroidales_family_designation.tsv', index_col='genus').ormerod_family
# Root on the outgroup, then remove the outgroup tip from the tree.
tree.root_with_outgroup(outgroup)
tree.prune(outgroup)
# Give the root a fixed branch length and clear its confidence so that
# _node_plotter draws no dot on it.
# NOTE(review): 0.06 matches the offset added to min(lefts) for the box
# below - presumably intentional; confirm.
tree.clade.branch_length = 0.06
tree.clade.confidence = None
#for c in tree.get_nonterminals() + tree.get_terminals():
#    if c.branch_length:
#        c.branch_length = 1
# All of the following modify `tree` in place: fold single-group subtrees,
# collapse internal nodes whose confidence is below 0.7, then ladderize.
fold_by_group(tree.clade, _fold_key)
tree.collapse_all(lambda c: c.confidence is not None and c.confidence < 0.7)
tree.ladderize()

# plot_tree returns the annotated deep copy; its nodes carry the left/right/y
# coordinates that are used for the box further down.
fig, ax = plt.subplots(figsize=(halfwidth, 5))
tree_annotated = plot_tree(tree, ax,
                           tax_plotter=_tax_plotter,
                           node_plotter=_node_plotter)

# Empty scatter artists that only carry legend labels.  The ax.legend(...)
# call below is commented out, so no legend is currently created from them.
for confidence_label, color in [('> 95%', 'black'),
#                                ('> 90%', 'grey'),
#                                ('> 80%', 'lightgrey'),
#                                ('< 80%', 'white')
#                                ('> 80%', 'white')
                               ]:
    ax.scatter([], [], color=color, label=confidence_label, s=50, linewidths=2, edgecolors='k')
ax.set_yticks([])
ax.set_xticks([])
#ax.legend(loc='upper left', title='confidence')
# Scale bar: a segment from x=0.95 to x=1.15 (length 0.2) at y=1.75, with its
# '0.2' label centred above it.  Positions are hard-coded in data coordinates.
ax.hlines([1.75], [0.95], [1.15], lw=1)
ax.annotate('0.2', xy=(1.05, 2), ha='center', fontsize=8)

# Draw box around S24-7 clade
# The clade is looked up on the ANNOTATED copy, because only that copy has the
# coordinate attributes.  get_all_node_values is defined in an earlier cell.
clade_of_interest = tree_annotated.common_ancestor(['Otu0001', 'Otu0004',
                                                    'KR364784.1',
                                                    'Bacteria|Bacteroidetes|Bacteroidia|Bacteroidales|Muribaculaceae||AJ400263.Unc18892'])
lefts = get_all_node_values(clade_of_interest, 'left')
rights = get_all_node_values(clade_of_interest, 'right')
ys = get_all_node_values(clade_of_interest, 'y')
# Box corners: the left edge is shifted right by 0.06 and the right edge is
# extended by 0.55 (presumably to make room for the tip labels - confirm);
# top and bottom are padded by half a row.
box_x_coords = [min(lefts) + 0.06,
                max(rights) + 0.55,
                max(rights) + 0.55,
                min(lefts) + 0.06
               ]
box_y_coords = [max(ys) + 0.5,
                max(ys) + 0.5,
                min(ys) - 0.5,
                min(ys) - 0.5
               ]
# Prints the computed box corners to the cell output.
print(box_x_coords, box_y_coords)
ax.add_patch(plt.Polygon(xy=list(zip(box_x_coords, box_y_coords)),
                         fill=False, edgecolor='grey', linestyle='--', ))


#ax.set_xlim(-0.1, 1.5)
ax.axis('off')
#fig.tight_layout()

# `savefig` is the helper taken from the loaded 'paper' style.
savefig(fig, 'fig/s247_tree')