In [1]:
import itertools
import pandas as pd
import numpy as np
import toytree
import toyplot
import arviz as az
import pymc3 as pm

In [78]:
# Load tree.
TREE = toytree.tree("/home/henry/oaks-thesis/full_crown.tre").drop_tips(names = ['Quercus|Quercus|Leucomexicana|Q.species', 
                                                                                 'Quercus|Quercus|Roburoids|Q.vulcanica',
                                'Quercus|Quercus|Roburoids|Q.imeretina', 'Quercus|Virentes|nan|Q.sagraeana',
                                'Quercus|Lobatae|Erythromexicana|Q.lowilliamsii', 
                                 'Quercus|Lobatae|Agrifoliae|Q.oxyadenia', 'Quercus|Quercus|Roburoids|Q.kotschyana',
                                'Quercus|Quercus|Roburoids|Q.cedrorum', 'Quercus|Quercus|Dumosae|Q.pacifica',
                                 'Quercus|Lobatae|Erythromexicana|Q.sartorii',
                                'Quercus|Lobatae|Agrifoliae|Q.tamalpaiensis','Quercus|Lobatae|Agrifoliae|Q.shrevei',
                                 'Cerris|Cyclobalanopsis|Semiserrata|Q.litoralis', 
                                 'Cerris|Cyclobalanopsis|Acuta|Q.ciliaris', 'Cerris|Cyclobalanopsis|Acuta|Q.stewardiana',
                                'Cerris|Cyclobalanopsis|Semiserrata|Q.patelliformis',
                                'Cerris|Cyclobalanopsis|Glauca|Q.multinervis', 'Cerris|Ilex|Himalayansubalpine|Q.sp.nov.'])

# Drop species with more than one tip because I have no means to prefer one.
TREE = TREE.drop_tips(wildcard = "1")
TREE = TREE.drop_tips(wildcard = "2")

# Scale tree.
TREE = TREE.mod.node_scale_root_height(1.0)

In [79]:
TREE.draw(height = 3000, width = 1500, scalebar = True);

In [4]:
# True param values
𝛼_mean = 0.05
𝛼_std = 0.02
𝛽_mean = 3.0
𝛽_std = 0.2
𝜓_mean = 0.0
𝜓_std = 0.2

In [5]:
# 8 different clade effects on rate of RI (used for partial-pooling data)
𝜓_Quercus_mean = 0.8
𝜓_Quercus_std = 0.2
𝜓_Virentes_mean = -1.0
𝜓_Virentes_std = 0.1
𝜓_Ponticae_mean = -1.0
𝜓_Ponticae_std = 0.1
𝜓_Protobalanus_mean = -0.6
𝜓_Protobalanus_std = 0.15
𝜓_Lobatae_mean = 1.2
𝜓_Lobatae_std = 0.2
𝜓_Cyclobalanopsis_mean = 1.0
𝜓_Cyclobalanopsis_std = 0.2
𝜓_Ilex_mean = -0.4
𝜓_Ilex_std = 0.1
𝜓_Cerris_mean = -0.4
𝜓_Cerris_std = 0.1

In [6]:
# Get crown nodes for eight clades.
crowns = [
    TREE.get_mrca_idx_from_tip_labels(wildcard = "Quercus|Quercus"),
    TREE.get_mrca_idx_from_tip_labels(wildcard = "Quercus|Virentes"),
    TREE.get_mrca_idx_from_tip_labels(wildcard = "Quercus|Ponticae"),
    TREE.get_mrca_idx_from_tip_labels(wildcard = "Quercus|Protobalanus"),
    TREE.get_mrca_idx_from_tip_labels(wildcard = "Quercus|Lobatae"),
    TREE.get_mrca_idx_from_tip_labels(wildcard = "Cerris|Cyclobalanopsis"),
    TREE.get_mrca_idx_from_tip_labels(wildcard = "Cerris|Ilex"),
    TREE.get_mrca_idx_from_tip_labels(wildcard = "Cerris|Cerris")
]
crowns

[396, 397, 410, 423, 431, 433, 426, 427]

In [7]:
tips = len(TREE.get_tip_labels())
Cerris_tips = len(TREE.get_tip_labels(427))
Ilex_tips = len(TREE.get_tip_labels(426))
Cyclobalanopsis_tips = len(TREE.get_tip_labels(433))
Lobatae_tips = len(TREE.get_tip_labels(431))
Protobalanus_tips = len(TREE.get_tip_labels(423))
Ponticae_tips = len(TREE.get_tip_labels(410))
Virentes_tips = len(TREE.get_tip_labels(397))
Quercus_tips = len(TREE.get_tip_labels(396))

In [8]:
SPECIES_DATA = pd.DataFrame({
    "Species": ["Quercus " + "{}".format(
        TREE.idx_dict[idx].name.split("|")[-1].split(".")[-1]) for idx in range(len(TREE.get_tip_labels()))],
    "𝛽": np.random.normal(𝛽_mean, 𝛽_std, tips),
    "𝜓": np.random.normal(𝜓_mean, 𝜓_std, tips),
    "𝜓_x": np.concatenate([
        np.random.normal(𝜓_Quercus_mean, 𝜓_Quercus_std, Quercus_tips),
        np.random.normal(𝜓_Virentes_mean, 𝜓_Virentes_std, Virentes_tips),
        np.random.normal(𝜓_Ponticae_mean, 𝜓_Ponticae_std, Ponticae_tips),
        np.random.normal(𝜓_Protobalanus_mean, 𝜓_Protobalanus_std, Protobalanus_tips),
        np.random.normal(𝜓_Lobatae_mean, 𝜓_Lobatae_std, Lobatae_tips),
        np.random.normal(𝜓_Ilex_mean, 𝜓_Ilex_std, Ilex_tips),
        np.random.normal(𝜓_Cerris_mean, 𝜓_Cerris_std, Cerris_tips),
        np.random.normal(𝜓_Cyclobalanopsis_mean, 𝜓_Cyclobalanopsis_std, Cyclobalanopsis_tips),
    ]),
    "gidx": np.concatenate([
        np.repeat(0, Quercus_tips),
        np.repeat(1, Virentes_tips),
        np.repeat(2, Ponticae_tips),
        np.repeat(3, Protobalanus_tips),
        np.repeat(4, Lobatae_tips),
        np.repeat(5, Ilex_tips),
        np.repeat(6, Cerris_tips),
        np.repeat(7, Cyclobalanopsis_tips),
    ]),
})

In [9]:
SPECIES_DATA

Unnamed: 0,Species,𝛽,𝜓,𝜓_x,gidx
0,Quercus ajoensis,2.899170,-0.114570,0.609336,0
1,Quercus turbinella,2.976692,0.031704,0.887498,0
2,Quercus toumeyi,3.289941,0.174112,0.621847,0
3,Quercus grisea,3.109816,0.082084,0.710671,0
4,Quercus striatula,3.084657,-0.052822,0.740515,0
...,...,...,...,...,...
214,Quercus rex,2.858612,0.162435,0.839202,7
215,Quercus chungii,3.131602,-0.140917,0.905523,7
216,Quercus delavayi,3.382792,-0.310317,0.721112,7
217,Quercus championii,2.704728,-0.249341,1.032134,7


In [10]:
def get_dist(tree, idx0, idx1):
    "returns the genetic distance between two nodes on a tree"
    dist = tree.treenode.get_distance(
        tree.idx_dict[idx0], 
        tree.idx_dict[idx1],
    )
    return dist

In [16]:
# get all combinations of two sampled taxa
a, b = zip(*itertools.combinations(range(tips), 2))

# organize into DF and get genetic distance between pairs
DATA = pd.DataFrame({
    "sidx0": a,
    "sidx1": b,
    "dist": [get_dist(TREE, i, j) for (i, j) in zip(a, b)],
})

DATA['velo_x'] = (
    np.random.normal(𝛽_mean, 𝛽_std, DATA.shape[0])
    + SPECIES_DATA['𝜓_x'][DATA.sidx0].values
    + SPECIES_DATA['𝜓_x'][DATA.sidx1].values
)
DATA['intercept'] = np.random.normal(𝛼_mean, 𝛼_std, DATA.shape[0])
DATA['RI'] = 6

In [17]:
DATA

Unnamed: 0,sidx0,sidx1,dist,velo_x,intercept,RI
0,0,1,0.002093,4.586177,0.073722,6
1,0,2,0.004185,4.153616,0.068527,6
2,0,3,0.008371,4.106666,0.027490,6
3,0,4,0.016741,3.692733,0.046834,6
4,0,5,0.033482,4.231343,0.049405,6
...,...,...,...,...,...,...
23866,215,217,1.150476,4.964473,0.047460,6
23867,215,218,1.150476,5.052815,0.034662,6
23868,216,217,1.150476,4.928213,0.041324,6
23869,216,218,1.150476,4.766011,0.043744,6


In [18]:
# Load real oak hybrid data.
hybrid = pd.read_csv("/home/henry/oaks-thesis/csv-files/oak-hybrid-table-3.csv")
match = [hybrid['speciesA'][idx] + hybrid['speciesB'][idx] for idx in hybrid.index]

In [60]:
for idx in DATA.index:
    cross1 = SPECIES_DATA['Species'][DATA.sidx0[idx]] + SPECIES_DATA['Species'][DATA.sidx1[idx]]
    cross2 = SPECIES_DATA['Species'][DATA.sidx1[idx]] + SPECIES_DATA['Species'][DATA.sidx0[idx]]
    if cross1 in match[:234]:
        DATA.loc[idx, 'RI'] = 0
    elif cross2 in match[:234]:
        DATA.loc[idx, 'RI'] = 0
    elif cross1 in match[234:]:
        DATA.loc[idx, 'RI'] = 1
    elif cross2 in match[234:]:
        DATA.loc[idx, 'RI'] = 1
    else:
        DATA.loc[idx, 'RI'] = np.nan

In [61]:
DATA

Unnamed: 0,sidx0,sidx1,dist,velo_x,intercept,RI
0,0,1,0.002093,4.586177,0.073722,
1,0,2,0.004185,4.153616,0.068527,
2,0,3,0.008371,4.106666,0.027490,
3,0,4,0.016741,3.692733,0.046834,
4,0,5,0.033482,4.231343,0.049405,
...,...,...,...,...,...,...
23866,215,217,1.150476,4.964473,0.047460,
23867,215,218,1.150476,5.052815,0.034662,
23868,216,217,1.150476,4.928213,0.041324,
23869,216,218,1.150476,4.766011,0.043744,


In [88]:
NSAMPLES = 1000
SAMPLE = DATA.sample(NSAMPLES).copy().reset_index(drop=True)
SAMPLE.head()

Unnamed: 0,sidx0,sidx1,dist,velo_x,intercept,RI
0,59,91,1.854821,5.154967,0.071502,
1,33,175,2.0,3.552985,0.045568,
2,128,162,2.0,3.688345,0.017682,
3,23,126,1.854821,5.127266,0.058095,
4,108,130,0.91181,5.406465,0.002735,


In [87]:
def logit_plot(data):
    canvas = toyplot.Canvas(width=500, height=250)
    ax0 = canvas.cartesian(
        label="pooled data (function)",
        xlabel="Genetic dist.",
        ylabel="Logit function",
        grid=(1, 2, 0),
    )
    ax1 = canvas.cartesian(
        label="pooled data (observation)",
        xlabel="Genetic dist.",
        ylabel="RI",
        grid=(1, 2, 1),
    )

    # points are jittered on x-axis for visibility
    ax0.scatterplot(
        data.dist,
        data.logit,
        size=5,
        opacity=0.33,
        color=toyplot.color.Palette()[0],
    );
    ax1.scatterplot(
        data.dist,
        data.RI,
        size=10,
        opacity=0.2,
        marker="|",
        mstyle={
            "stroke": toyplot.color.Palette()[1],
            "stroke-width": 3,
        },
    );
    return canvas, (ax0, ax1)

In [89]:
logit_plot(SAMPLE);

AttributeError: 'DataFrame' object has no attribute 'logit'

In [191]:
def heatmap_plot(TREE, CLADES, data):
    
    # get canvas size
    canvas = toyplot.Canvas(width=500, height=500);

    # colormap for values between 0-1
    cmap_lower = toyplot.color.LinearMap(
        domain_min=data['logit'].min(), 
        domain_max=data['logit'].max()
    )

    # add tree to canvas
    ax0 = canvas.cartesian(
        bounds=("5%", "25%", "5%", "95%"),
        show=False
    )
    TREE.draw(
        axes=ax0, 
        layout='r', 
        tip_labels=False,
        edge_colors=TREE.get_edge_values_mapped({
            j: toytree.colors[i] for i,j in enumerate(CLADES)
        })
    )
    NSPECIES = TREE.ntips

    # add heatmap to canvas
    ax1 = canvas.table(
        rows=NSPECIES, 
        columns=NSPECIES, 
        bounds=("27%", "88%", "5%", "95%"),
        margin=20
    )
    ax1.cells.cell[:].style = {
        "fill": 'lightgrey',
        "stroke": "none"
    }

    # set values for data
    for idx in data.index:
        ridx, cidx = data.loc[idx, ['sidx0', 'sidx1']]
        rridx = NSPECIES - ridx - 1
        ccidx = NSPECIES - cidx - 1
        ridx, cidx, rridx, ccidx = [
            int(i) for i in (ridx, cidx, rridx, ccidx)
        ]

        # get RI for this cell
        value = data.loc[idx, 'logit']
        ax1.cells.cell[rridx, cidx].style = {
            "fill": cmap_lower.color(value),
            "stroke": "none"
        }
        fill = ("white" if data.at[idx, 'RI'] == 1 else "#262626")
        ax1.cells.cell[ccidx, ridx].style = {
            "fill": fill,
            "stroke": "none"
        }
            
    for idx in range(NSPECIES):
        ax1.cells.cell[NSPECIES - idx - 1, idx].style = {
            "fill": cmap_lower.color(0),
            "stroke": "none"
        }

    # dividers
    ax1.body.gaps.columns[...] = 0.5
    ax1.body.gaps.rows[...] = 0.5
    
    # add a colorbar
    numberline = canvas.numberline("92%", "95%", "92%", "5%")
    numberline.colormap(cmap_lower, style={"stroke-width": 5})
    
    return canvas

In [62]:
def toytrace(trace, var_names, titles):
    """
    Plot posterior trace with toyplot
    """
    nvars = len(var_names)
    
    # setup canvase
    canvas = toyplot.Canvas(width=500, height=200 * nvars)
    
    # store axes
    axes = []
    
    # iter over params
    for pidx, param in enumerate(var_names):
        
        # get param posterior
        posterior = trace.get_values(param)
        
        # setup axes 
        ax = canvas.cartesian(grid=(nvars, 1, pidx))
        ax.y.show = False
        ax.x.spine.style = {"stroke-width": 1.5}
        ax.x.ticks.labels.style = {"font-size": "12px"}
        ax.x.ticks.show = True
        ax.x.label.text = f"param='{titles[pidx]}'"        
        
        # iterate over shape of param
        for idx in range(posterior.shape[1]):
            mags, bins = np.histogram(posterior[:, idx], bins=100)
            ax.plot(bins[1:], mags, stroke_width=2, opacity=0.6)
        axes.append(ax)
    return canvas, axes

In [173]:
def partpooled_logistic(x, y, idx0, idx1, gidx, **kwargs):
    
    # define model
    with pm.Model() as model:
        
        # set up mask
        # missing = np.isnan(y)

        # randomly impute 0 or 1 for missing data
        # π = pm.Uniform('π', 0, 1)
        # y_imputed = pm.Bernoulli('y_imputed', π, observed=y)
        
        # indexers
        sidx0 = pm.Data("spp_idx0", idx0)
        sidx1 = pm.Data("spp_idx1", idx1)
        # sidx0_m = pm.Data('spp_idx0_m', idx0[missing])
        # sidx1_m = pm.Data('spp_idx1_m', idx1[missing])
        gidx = pm.Data("gidx", gidx)

        # parameters and error
        𝜓_mean = pm.Normal('𝜓_mean', mu=0., sigma=5., shape=8)
        𝜓_std = pm.HalfNormal('𝜓_std', 5., shape=8)
        𝜓_offset = pm.Normal('𝜓_offset', mu=0, sigma=1., shape=tips)
        𝜓 = pm.Deterministic('𝜓', 𝜓_mean[gidx] + 𝜓_std[gidx] * 𝜓_offset)
        𝛽 = pm.Normal('𝛽', mu=0., sigma=10., shape=1)
        𝛼 = pm.Normal('𝛼', mu=0., sigma=10., shape=1)
        
        # linear model prediction
        effect = 𝛼 + (𝛽 + 𝜓[sidx0] + 𝜓[sidx1]) * x
        logit = pm.Deterministic("logit", pm.invlogit(effect))
        
        # set up mask
        # y_masked = np.ma.masked_equal(y, value = -1)
        
        # data likelihood (normal distributed errors)
        y_like = pm.Bernoulli("y", p=logit, observed=y)
        
        # predict unobserved values
        # p_pred = pm.Deterministic('p_pred', pm.invlogit(𝛼 + (𝛽 + 𝜓[sidx0_m] + 𝜓[sidx1_m]) * x[missing]))

        # sample posterior, skip burnin
        trace = pm.sample(**kwargs)[1000:]

        # show summary table
        stats = pm.summary(trace)
        
    # organize results
    result_dict = {
        'model': model, 
        'trace': trace,
        'stats': stats,
    }
    return result_dict

In [34]:
# MCMC sampler kwargs
sample_kwargs = dict(
    tune=4000,
    draws=4000,
    target_accept=0.95,
    return_inferencedata=False,
    progressbar=True,
)

In [180]:
SAMPLE = DATA.sample(5000)

In [None]:
# model input
model_args = [
    SAMPLE.dist,
    SAMPLE.RI,
    SAMPLE.sidx0,
    SAMPLE.sidx1,
    SPECIES_DATA.gidx,
]

partpooled = partpooled_logistic(*model_args, **sample_kwargs)

Multiprocess sampling (4 chains in 4 jobs)
CompoundStep
>NUTS: [𝛼, 𝛽, 𝜓_offset, 𝜓_std, 𝜓_mean]
>BinaryGibbsMetropolis: [y_missing]


In [177]:
partpooled['stats']

Unnamed: 0,mean,sd,hdi_3%,hdi_97%,mcse_mean,mcse_sd,ess_mean,ess_sd,ess_bulk,ess_tail,r_hat
𝜓_mean[0],0.122,2.833,-5.087,5.401,0.217,0.185,171.0,118.0,172.0,209.0,1.05
𝜓_mean[1],1.045,4.476,-7.658,9.217,0.404,0.287,123.0,123.0,122.0,473.0,1.04
𝜓_mean[2],0.251,4.972,-9.292,8.516,0.814,0.580,37.0,37.0,38.0,150.0,1.08
𝜓_mean[3],1.954,4.011,-5.058,9.797,0.220,0.156,331.0,331.0,330.0,684.0,1.02
𝜓_mean[4],-1.512,2.635,-6.391,3.471,0.306,0.217,74.0,74.0,73.0,96.0,1.05
...,...,...,...,...,...,...,...,...,...,...,...
logit[995],0.000,0.003,0.000,0.001,0.000,0.000,502.0,502.0,77.0,166.0,1.05
logit[996],0.996,0.033,0.990,1.000,0.001,0.000,3131.0,2434.0,89.0,718.0,1.05
logit[997],0.958,0.182,0.909,1.000,0.015,0.011,149.0,149.0,137.0,165.0,1.03
logit[998],0.873,0.318,0.000,1.000,0.030,0.021,114.0,114.0,131.0,219.0,1.04


In [178]:
toytrace(partpooled['trace'], ['𝜓_mean', '𝜓_offset', '𝜓'], ['psi-mean', 'psi-offset', 'psi-spp']);

In [179]:
# show plot of TRUE vs. ESTIMATED rates
c, a, m = toyplot.scatterplot(
    partpooled['trace']['𝜓'].mean(axis=0),         # estimated
    SPECIES_DATA['𝜓_x'],                             # true
    width=400,
    height=250,
    xlabel="ESTIMATED species velocity",
    ylabel="TRUE species velocity",
    color=[toyplot.color.Palette()[i] for i in SPECIES_DATA.gidx],
);