In [1]:
import dendropy as dp
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import matplotlib

In [6]:
# IO
fp = '/Users/mlandis/projects/vib_div/'
out_fp = fp + 'output/'
plot_fp = fp + 'code/plot/fig/age/'

fn_list = [ 'out.1.t163.f5',
            'out.1.t163.f5.mask_fossil_states',
            'out.2.t163.f5',
            'out.2.t163.f5.mask_fossil_states'
          ]
fn_list = [ x + '.tre' for x in fn_list ]

# M: molecules (0: only, 1: +biome/geog)
# F: fossils (0: yes, 1: no)
# B: use fossil biome states (0: no, 1:yes)
# G: use fossil area states (0: no, 1:yes)
# R: use compound rates (0: no, 1:yes)

model_names = [ 'D1F1',
                'D1F0',
                'D2F1',
                'D2F0' ]

               

f_burn = 0.5

In [7]:
# define clades of interest

fossil_taxa = ['Valvatotinus_IS',
               'Valvatotinus_PB',
               'Valvatotinus_NWT',
               'Viburnum_BC',
               'Porphyrotinus_CO']


clade_list = {
    'Viburnum'     : ['V_clemensiae', 'V_lentago', 'V_molle'],
    'Lentago'      : ['V_lentago', 'V_nudum'],
    'Euviburnum'   : ['V_lantana', 'V_cotinifolium'], 
    'Valvatotinus' : ['V_nudum', 'V_lantana'],
    'Solenotinus'  : ['V_grandiflorum', 'V_awabuki'],
    'Pseudotinus'  : ['V_lantanoides','V_urceolatum'],
    'Lutescentia'  : ['V_lutescens', 'V_plicatum'],
    'Tinus'        : ['V_tinus', 'V_cinnamomifolium'],
    'Succotinus'   : ['V_mullaha', 'V_erosum'],
    'Laminotinus'  : ['V_coriaceum', 'V_mullaha'],
    'Sambucina'    : ['V_sambucinum', 'V_beccarii'],
    'Opulus'       : ['V_opulus', 'V_edule'],
    'Mollotinus'   : ['V_molle', 'V_australe'],
    'Porphyrotinus': ['V_australe','V_jucundum'],
    'Oreinotinus'  : ['V_dentatum','V_jucundum']
}

clade_plot = {
    'Viburnum'     : {'c':'black', 'm':'s' },
    'Lentago'      : {'c':'red','m':'8'},
    'Euviburnum'   : {'c':'blue', 'm':'>'},
    'Valvatotinus' : {'c':'green', 'm':'<'},
    'Solenotinus'  : {'c':'orange', 'm':'^'},
    'Pseudotinus'  : {'c':'gold', 'm':'v'},
    'Lutescentia'  : {'c':'deeppink', 'm':'o'},
    'Tinus'        : {'c':'purple', 'm':'X'},
    'Succotinus'   : {'c':'firebrick', 'm':'P'},
    'Laminotinus'  : {'c':'steelblue', 'm':'d'},
    'Sambucina'    : {'c':'aqua', 'm':'D'},
    'Opulus'       : {'c':'peru', 'm':'H'},
    'Mollotinus'   : {'c':'olive', 'm':'h'},
    'Porphyrotinus': {'c':'yellowgreen', 'm':'*'},
    'Oreinotinus'  : {'c':'turquoise', 'm':'p'}
}


In [13]:
# pre-process trees
min_size = 1e6

phy = {}
for i,fn in enumerate(fn_list):
    # get clean model name
    mn = model_names[i]
    # read data
    dat = pd.read_csv( out_fp + fn, sep='\t' )
    print(dat)
    # determine burnin
    n_row = dat.shape[0]
    n_burn = int(n_row * f_burn)
    # get phy str
    dat_phy = dat.loc[n_burn:n_row,'tree']
    if len(dat_phy) < min_size:
        min_size = len(dat_phy)
    # parse & build tree objects
    phy[mn] = []
    for phy_str in dat_phy:
        phy_tmp = dp.Tree.get(data=phy_str, schema='newick', preserve_underscores=True )
        phy_tmp.prune_taxa_with_labels(fossil_taxa)
        phy_tmp.calc_node_ages()
        phy[mn].append( phy_tmp )
        
for mn in model_names:
    phy[mn] = phy[mn][1:min_size]
    

    Iteration  Posterior  Likelihood    Prior  \
0           0   -56441.4    -53721.8 -2719.63   
1          50   -32749.2    -30330.7 -2418.52   
2         100   -32673.5    -30265.1 -2408.43   
3         150   -32653.8    -30242.1 -2411.79   
4         200   -32661.1    -30243.3 -2417.79   
5         250   -32681.0    -30268.1 -2412.92   
6         300   -32652.2    -30251.6 -2400.69   
7         350   -32669.0    -30266.2 -2402.84   
8         400   -32651.5    -30242.0 -2409.43   
9         450   -32651.9    -30237.8 -2414.13   
10        500   -32699.7    -30275.5 -2424.21   
11        550   -32700.0    -30269.6 -2430.45   
12        600   -32713.8    -30270.0 -2443.81   
13        650   -32706.6    -30272.3 -2434.23   
14        700   -32677.2    -30247.1 -2430.06   
15        750   -32680.2    -30244.3 -2435.90   
16        800   -32698.1    -30248.7 -2449.40   
17        850   -32744.5    -30288.4 -2456.11   
18        900   -32719.3    -30256.5 -2462.88   
19        950   -327

KeyError: 'the label [phy] is not in the [columns]'

In [5]:
# compute the model-clade-sample ages

ages = {}

hpd = 0.95

# for each model
for mn in model_names:
    print("Processing " + mn)
    # initialize model-clade
    ages[mn] = {}
    for k,v in clade_list.items():
        ages[mn][k] = []        
    # for each tree
    for i,phy_tmp in enumerate(phy[mn]):
        #print( "\titeration" + str(i) )
        # for each clade
        for k,v in clade_list.items():
            #print(v)
            clade_mrca = phy_tmp.mrca( taxon_labels=v )
            clade_age = clade_mrca.age
            ages[mn][k].append( clade_age )
        
        for k,v in clade_list.items():
            a = np.sort( ages[mn][k] )
            ntail = int( (1.0-(hpd/2))*len(a) )
            a = a[ntail:(len(a)-ntail)]

Processing D1F1


KeyError: 'D1F1'

In [6]:
def plot_scatter(ages, clade_list, clade_plot, mn1, mn2):
    fig = plt.figure(figsize=(8, 8))
    ax1 = fig.add_subplot(111)

    plt.plot( [0,80], [0,80], color='k', linestyle='--', linewidth=2, zorder=1, alpha=0.5 )

    a1 = {}
    a2 = {}
    for k,v in clade_list.items():
        a1[k] = np.sort( ages[ mn1 ][ k ] )
        a2[k] = np.sort( ages[ mn2 ][ k ] )
    
    for k,v in clade_list.items():
        col = clade_plot[k]['c']
        mrk = clade_plot[k]['m']
        plt.scatter(x=np.mean( a1[k] ),
                    y=np.mean( a2[k] ),
                    c=col,
                    edgecolor='black',
                    marker=mrk,
                    alpha=1.0,
                    label=k,
                    s=40,
                    zorder=2)

    plt.legend(loc='upper left')

    for k,v in clade_list.items():
        col = clade_plot[k]['c']
        mrk = clade_plot[k]['m']
        plt.scatter(x=a1[k],
                    y=a2[k],
                    c=col,
                    linewidths=0,
                    marker=mrk,
                    s=20,
                    alpha=0.2,
                    label=k,
                    zorder=1)


        A = np.vstack([a1[k], np.ones(len(a1[k]))]).T
        m, c = np.linalg.lstsq(A, a2[k])[0]
        plt.plot(a1[k], m*a1[k] + c, c=col, zorder=1, alpha=0.5)
        
    plt.xlabel( mn1 )
    plt.ylabel( mn2 )
    
    #ax1.set_xscale("log", nonposx='clip')
    #ax1.set_yscale("log", nonposx='clip')
    plt.xlim(0,80)
    plt.ylim(0,80)
    
    
    # How to interpret this plot:
    # - Clade ages for M1 on the x-axis, M2 on the y-axis.
    # - Markers correspond to sorted clade ages (color gives clade)
    # - M1 gives older clade ages if below the 1:1 line (intercept < 1), and
    #     younger clade ages if above the 1:1 line (intercept > 1).
    # - M1 and M2 are equally precise if they are parallel to dotted line
    #     (slope = 1). M1 is less precise if the scatter is horizontal
    #     (slope < 1). M1 is more precise if the scatter is vertical
    #     (slope >1).

    plt.savefig(plot_fp + 'age_scatter.' + mn1 + '___vs___' + mn2 + '.pdf')
    plt.close()
    #plt.show()
    
    return

In [7]:
def plot_ratio(ages, clade_list, clade_plot, mn1, mn2):

    # format clade age ratios
    z_list = []
    for k,v in clade_list.items():
        z = np.sort( ages[ mn1 ][ k ] ) / np.sort( ages[ mn2 ][ k ] )
        z_list.append(z)

    # initialize plot
    fig = plt.figure(figsize=(8, 8))
    ax1 = fig.add_subplot(111)

    # make violin plot
    violin_parts = plt.violinplot(z_list, showmeans = True, showextrema = True, vert=False )

    # format colors
    my_colors = [ v['c'] for v in clade_plot.values() ]
    for i,pc in enumerate(violin_parts['bodies']):
        pc.set_facecolor( my_colors[i] )
        pc.set_edgecolor('black')

    # format axes/ticks/labels
    labels = clade_plot.keys()
    ax1.get_yaxis().set_tick_params(direction='out')
    ax1.set_yticks(np.arange(1, len(labels) + 1))
    ax1.set_yticklabels(labels)
    ax1.set_ylim(0.25, len(labels) + 0.75)
    ax1.set_ylabel('Clade')

    ax1.get_xaxis().set_tick_params(direction='out')
    ax1.set_xscale("log", nonposx='clip')
    ax1.set_xticks( [1/4, 1/2, 1, 2, 4])
    ax1.set_xticklabels( [1/4, 1/2, 1, 2, 4] )
    ax1.set_xlim(1/4,4)
    ax1.set_xlabel('Ratio of posterior age,\n( '+ mn1 +" / "+ mn2 + ' )')

    plt.savefig(plot_fp + 'age_ratio.' + mn1 + '___vs___' + mn2 + '.pdf')
    plt.close()
    #plt.show()

In [8]:


compare_model_list = [
    ['D1F1', 'D1F0'],
    ['D2F1', 'D2F0']
]
    
    
    
def plot_all(ages, clade_list, clade_plot, models):
    for mn1,mn2 in models:
        plot_scatter( ages, clade_list, clade_plot, mn1, mn2 )
        plot_ratio( ages, clade_list, clade_plot, mn1, mn2 )
            
            
plot_all(ages,clade_list,clade_plot,compare_model_list)

