In [1]:
import os
import pickle

import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import seaborn as sns

import baltic as bt

%matplotlib inline

In [2]:
## Parse the annotated NEXUS tree with baltic, without absolute-time calibration
## (absoluteTime=False), so plot positions come from the tree as loaded.
annotated_tree = '../frequencies/source/annotated_tree.nexus'
tree = bt.loadNexus(
    annotated_tree,
    absoluteTime=False,
)

In [3]:
## Mapping of trait values (e.g. genotypes) to colors, plus a 'cmap' entry used for the
## seaborn palette further down.
## NOTE: pickle.load can execute arbitrary code -- only load colors.p from a trusted source.
with open('./colors.p', 'rb') as colors_file:  # `with` guarantees the handle is closed
    colors = pickle.load(colors_file)

In [4]:
def plot_tree(tree, labels=False, colorby='genotype'):
    """Draw a baltic tree on the current matplotlib axes, colored by a trait.

    Parameters
    ----------
    tree : baltic tree
        Tree whose objects already carry `.x` / `.y` layout coordinates.
        NOTE(review): presumably populated by baltic's layout during loading -- confirm.
    labels : bool, optional
        Currently unused; kept so existing calls passing it keep working.
    colorby : str, optional
        Key looked up in each object's `.traits`. Its value is mapped to a color
        through the module-level `colors` mapping; objects lacking the trait are
        drawn in 'gray'.

    Side effects
    ------------
    Draws onto the current pyplot axes, shrinks the y tick labels to size 0,
    and multiplies the upper x-limit by 1.1 to leave room beside the tips.
    """
    branch_width = 2  ## default branch width
    tree_height = tree.treeHeight

    plt.yticks(size=0)  ## y values are layout slots, so hide their tick labels

    for k in tree.Objects:  ## iterate over objects in tree
        x = k.x  ## or use absolute time instead
        y = k.y
        xp = k.parent.x  ## the branch is drawn starting from the parent's x

        ## matplotlib won't plot None (e.g. at the root), so pin missing coordinates
        if x is None:
            x = 0.0
        if xp is None:
            xp = x

        ## dict.has_key() does not exist on Python 3; use a membership test instead
        has_trait = colorby in k.traits
        c = colors[k.traits[colorby]] if has_trait else 'gray'

        if isinstance(k, bt.leaf) or k.branchType == 'leaf':  ## if leaf...
            ## tip size: 50 at height 0 shrinking to 20 at height == treeHeight;
            ## fall back to 50 when the tree has zero height (avoids dividing by 0)
            s = 50 - 30 * k.height / tree_height if tree_height else 50
            label = k.traits[colorby] if has_trait else ''
            plt.scatter(x, y, s=s, facecolor=c, edgecolor='none', zorder=11, label=label)  ## colored tip
            plt.scatter(x, y, s=s + 0.8 * s, facecolor='k', edgecolor='none', zorder=10)  ## black outline underneath

        elif isinstance(k, bt.node) or k.branchType == 'node':  ## if node...
            ## vertical bar joining the node's last and first children
            plt.plot([x, x], [k.children[-1].y, k.children[0].y],
                     lw=branch_width, color=c, ls='-', zorder=9)

        ## horizontal branch from the parent's x to this object's x
        plt.plot([xp, x], [y, y], lw=branch_width, color=c, ls='-', zorder=9)

    ## widen the right side so the tips are not clipped at the axes edge
    x0, x1 = plt.xlim()
    plt.xlim((x0, x1 * 1.1))

In [7]:
## One legend patch per genotype present in the tree, in sorted order.
## `labels` holds matplotlib Patch handles (the name is kept because the plotting cell uses it).
genotypes = {k.traits['genotype'] for k in tree.Objects if 'genotype' in k.traits}
labels = [mpatches.Patch(color=colors[g], label=g) for g in sorted(genotypes)]

In [None]:
## Render the genotype-colored tree and save it as a high-resolution PNG.
sns.set(style='whitegrid', font_scale=1.2, palette=colors['cmap'])
cmap = plt.get_cmap(colors['cmap'])  ## not used in this cell; kept in case later cells rely on it
fig, ax = plt.subplots(figsize=(7.5/2, 5.83))
plot_tree(tree)
## Anchor the legend outside the axes: its upper-left corner sits at axes coords (1, 0.8).
## Passing (1, 0.8) positionally alongside handles= is discarded by matplotlib, so use
## bbox_to_anchor explicitly.
plt.legend(handles=labels, loc='upper left', bbox_to_anchor=(1, 0.8))
plt.tight_layout()
os.makedirs('./png', exist_ok=True)  ## savefig raises if the output directory is missing
plt.savefig('./png/genotype_tree.png', dpi=300, bbox_inches='tight')