In [None]:
import dendropy
from dendropy import Tree
import math
from dendropy import Node
import operator
import numpy as np
import random as rd
from scipy.stats import zscore

In [None]:
#function that takes the product of multiple numbers, or returns 1 if the list is empty
def prod(factors):
    """
        Takes in an iterable of numbers and returns their product, multiplying from left to right
        starting at 1 (so an empty iterable gives 1).
        Written as an explicit loop instead of reduce() because reduce is only a builtin on Python 2;
        this version runs unchanged on Python 2 and Python 3.
    """
    product = 1
    for factor in factors:
        #plain (not in-place) multiplication so the running product is never mutated
        product = product * factor
    return product

In [None]:
#labels every node for clarity during function testing: tips take their taxon label, all other nodes are numbered
def name_nodes(tree):
    """
        Takes in a tree object and sets node.label on every node returned by tree.nodes().
        A node with a taxon is labelled with the taxon's label; every other node is labelled with a
        running number that starts at (number of leaf nodes + 1) and increases in tree.nodes() order.
        Labels are always strings. The tree is modified in place and nothing is returned.
    """
    next_unnamed = len(tree.leaf_nodes()) + 1
    for node in tree.nodes():
        if not node.taxon:
            node.label = str(next_unnamed)
            next_unnamed += 1
            continue
        node.label = str(node.taxon.label)
            
def name_edges(tree):
    """
        Takes in a tree object and sets edge.label on every edge to its position (as a string,
        counting from "0") in tree.preorder_edge_iter() order. Modifies the tree in place.
    """
    position = 0
    for edge in tree.preorder_edge_iter():
        edge.label = str(position)
        position += 1

### Expected Sampling Under Internal Nodes 
- With Simultaneous Sampling
- Can evaluate a portion of the tree starting from a given distance from root
- Can evaluate from a distance greater than the furthest sampled tip

In [None]:
def zipped_sorted_intervals(tree):
    """
        Takes in a tree object and finds information on each interval (marked by time points from root).
        Returns a list of lists, ordered from the interval farthest from the root to the interval that ends
        at the nearest node to the root. Each element in the final list represents an interval. Within each interval is:
            1. [start time of interval, end time of interval] (distances from root, start > end)
            2. list of active lineages within each interval (the lineage is denoted by the node at its tail),
               in the order the lineages became active

        A node is always swapped in for its children, even when it sits at exactly the same distance from
        the root as the previously visited node (zero length branches, or a coalescence that coincides
        with a sampling time). Previously the children were left in the list in that case.
    """
    #each node paired with its distance from the root
    #assumes tree.nodes() and calc_node_root_distances() list the nodes in the same (parent before child) order -- TODO confirm against dendropy
    node_distances = zip(tree.nodes(), tree.calc_node_root_distances(return_leaf_distances_only=False))

    #sort by farthest to nearest from root; reversing a stable ascending sort also reverses the order of ties,
    #so nodes at the same distance are visited in reverse of their tree.nodes() order
    farthest_first = reversed(sorted(node_distances, key=lambda pair: pair[1]))

    intervals = []
    living_lineages = []

    current_start = tree.max_distance_from_root()

    for (node, distance) in farthest_first:
        if distance != current_start:
            #a new time point closes the current interval
            intervals.append([[current_start, distance], living_lineages])

            #update the starting distance of the new interval
            current_start = distance

            #continue on a copy so the interval that was just stored is never changed afterwards
            living_lineages = list(living_lineages)

        #the node replaces its children (if there are any) as the active lineage
        children = set(node.child_nodes())
        if children:
            living_lineages = [lineage for lineage in living_lineages if lineage not in children]
        living_lineages.append(node)

    return intervals

#### Start the algorithm at cut_dist_from_root (> 0) and evaluate the tree from that point back to the root

In [None]:
def zipped_partial_intervals(tree, cut_dist_from_root): 
    """
    Like zipped_sorted_intervals but takes a cut_dist_from_root, the distance from the root
    to treat as the starting point (everything farther from the root than the cut is ignored).

    If the cut lies beyond tree.max_distance_from_root(), every interval is kept and an extra
    interval with no lineages, running from the cut to the maximum distance, is placed in front.
    Otherwise only the intervals that end closer to the root than the cut are kept, and the start
    of the first kept interval is overwritten with the cut.

    Raises IndexError when no interval ends closer to the root than the cut.
    """
    all_intervals = zipped_sorted_intervals(tree)
    tree_height = tree.max_distance_from_root()

    if cut_dist_from_root > tree_height:
        #place holder interval between the cut and the farthest node; it carries no lineages
        place_holder = ([float(cut_dist_from_root), tree_height], [])
        return [place_holder] + list(all_intervals)

    kept = []
    for (endpoints, lineages) in all_intervals:
        if endpoints[1] < cut_dist_from_root:
            kept.append((endpoints, lineages))

    #if the cut fell in the middle of an interval, shorten the first kept interval to start at the cut
    kept[0][0][0] = float(cut_dist_from_root)
    return kept

In [None]:
def conditioned_prob_lineage_coal(tree, popsize, cut_dist_from_root = None):
    """
        Takes in a tree object and a constant population size and returns a list containing 
        the probability of coalescence to one lineage for each interval. (can start midway through tree with cut_dist_from_root)
        Those probabilities are conditioned by the probability of not coalescing in any of the
        intervals that come earlier in the list.

        The list has one entry per interval, in the same order as zipped_sorted_intervals(tree)
        (or zipped_partial_intervals(tree, cut_dist_from_root) when a non-zero cut is given).
        An interval with no lineages contributes 0.
    """
    if cut_dist_from_root:
        intervals = zipped_partial_intervals(tree, cut_dist_from_root)
    else:
        intervals = zipped_sorted_intervals(tree)

    coalintervals = []

    #probability of not having coalesced in any interval visited so far; kept as a running product
    #(same left-to-right multiplication order as multiplying the whole prefix again for every interval)
    prob_no_coal_so_far = 1

    for (interval_endpoints, nodes) in intervals:
        length = interval_endpoints[0] - interval_endpoints[1]
        lineages = len(nodes)

        #compare by value: identity checks against the literal 0 are not a reliable integer test
        if lineages == 0:
            #nothing to coalesce with, and the probability of no coalescence is 1 so the running product is unchanged
            coalintervals.append(0)
            continue

        #continuous time probability of coalescence within this interval (not conditioned on time)
        prob_coal = 1-math.exp(-(float(lineages)/popsize)*length)

        #p no coal earlier * p coal in the current interval, split evenly between the lineages present
        coalintervals.append(prob_no_coal_so_far*prob_coal/lineages)
        prob_no_coal_so_far = prob_no_coal_so_far*(1-prob_coal)

    return coalintervals

In [None]:
def pcoal_along_edge(tree, popsize, cut_dist_from_root = None):
    """
        Takes in a tree object and constant population size and returns a dictionary containing each edge's conditioned probability
        of coalescence.(can start midway through tree with cut_dist_from_root)
        Each probability is a value keyed to the corresponding edge and these probabilities are not cumulative.
        The edge of tree.seed_node is given whatever probability is left over, so the values sum to 1.
    """
    #every edge starts at 0 so edges that are never active still appear in the result
    edge_prob = {}
    for node in tree.nodes():
        edge_prob[node.edge] = 0.0
        
    pcoal = conditioned_prob_lineage_coal(tree, popsize, cut_dist_from_root)
    
    #same interval list that conditioned_prob_lineage_coal used, so pcoal[index] lines up with each interval
    if cut_dist_from_root:
        intervals = zipped_partial_intervals(tree, cut_dist_from_root)
    else:
        intervals = zipped_sorted_intervals(tree)

    #read the lineages straight from each interval instead of subscripting zip(*intervals):
    #zip() is not subscriptable on Python 3, and an empty interval list is handled as well
    for (index, (interval_endpoints, interval_nodes)) in enumerate(intervals):
        for node in interval_nodes:
            #a lineage collects the interval's probability for every interval it is active in
            edge_prob[node.edge] += pcoal[index]
            
    edge_prob[tree.seed_node.edge] = 1-sum(edge_prob.values())

    return edge_prob

### Measures of Sampledness for the Entire Tree

In [None]:
def _subtree_edge_prob(subtree_root, edge_probs):
    #adds up the probability of every edge that hangs below subtree_root (the edge above subtree_root is not included)
    total = 0
    for member in subtree_root.preorder_iter():
        for child_edge in member.child_edge_iter():
            total += edge_probs[child_edge]
    return total


def calculate_cumulative_node_prob(tree, popsize):
    """
        Takes in a tree object and constant population size and returns a dictionary keyed by each
        internal (non-tip) node of the tree. The value is the sum of the per-edge coalescence
        probabilities (from pcoal_along_edge) over all edges in the subtree rooted at that node, i.e.
        the cumulative probability of a new sample coalescing somewhere under the node.
    """
    edge_probs = pcoal_along_edge(tree, popsize)
    return dict((node, _subtree_edge_prob(node, edge_probs)) for node in tree.internal_nodes())

#### Z-Scores

In [None]:
def node_zscores(node_prob_dict, tree):
    """
        Takes in the resulting dictionary from calculate_cumulative_node_prob() as well as the tree and returns
        a dictionary of each internal node's z-score.
        For every node the score is based on (cumulative probability * total number of tips in the tree)
        minus the number of tips actually found under the node.
    """
    nodes, probs = zip(*node_prob_dict.items())

    #total number of tips in the tree, used to turn each probability into an expected number of tips
    tree_tip_count = len(tree.leaf_nodes())

    #expected number of tips under each node minus the number of tips actually under it
    differences = []
    for (node, prob) in zip(nodes, probs):
        differences.append(prob*tree_tip_count-len(node.leaf_nodes()))

    #scale result and zip back into dictionary; ddof = 1 for n-1 df (divides by n-1 instead of n)
    scaled = zscore(differences, ddof=1)
    return dict(zip(nodes, scaled))

## Example trees without simultaneous sampling

In [None]:
#load the tree stored in jittered-ebola.nex (nexus format), label its nodes and report the
#cumulative coalescence probability under every internal node for a population size of 1
ebov_j = Tree.get(path="jittered-ebola.nex", schema="nexus")
name_nodes(ebov_j)
popsize = 1
#left over debugging calls, kept commented out
#NOTE(review): they use `tr`, which is only assigned in a later cell, and lineages_in_each_interval
#is not defined in the visible part of this file
#zipped_sorted_intervals(tr)
#conditioned_prob_lineage_coal(tr, popsize)
#lineages_in_each_interval(tr)
#pcoals = pcoal_along_edge(tr, popsize)
calculate_cumulative_node_prob(ebov_j, popsize)


In [None]:
#load the tree stored in toytree.nex (nexus format), label its nodes and report the
#cumulative coalescence probability under every internal node for a population size of 50
tr = Tree.get(path="toytree.nex", schema="nexus")
name_nodes(tr)
popsize = 50
#left over debugging prints (Python 2 print statements), kept commented out
#NOTE(review): conditioned_prob_lineage_coalescence and lineages_in_each_interval are not defined in the
#visible part of this file (the function defined above is conditioned_prob_lineage_coal)
#print zipped_sorted_intervals(tr), "\n"
#print conditioned_prob_lineage_coalescence(tr, popsize), "\n"
#print lineages_in_each_interval(tr), "\n"
#print pcoal_along_edge(tr, popsize), "\n"
calculate_cumulative_node_prob(tr, popsize)

### Example Trees with Simultaneous Sampling

In [None]:
#load the tree stored in ebola.tree (nexus format) and label its nodes and edges
ebov = Tree.get(path="ebola.tree", schema="nexus")
name_nodes(ebov)
name_edges(ebov)

#load the toy tree stored in toy-2-multiple-samples-at-t.nex (nexus format) and label its nodes and edges
tr2 = Tree.get(path="toy-2-multiple-samples-at-t.nex", schema="nexus")
name_nodes(tr2)
name_edges(tr2)

In [None]:
#cumulative coalescence probability under every internal node of ebov for a population size of 1
calculate_cumulative_node_prob(ebov, 1)

In [None]:
#alternative calls on tr2, kept commented out
#calculate_cumulative_node_prob(tr2, 50)
#pcoal_along_edge(tr2, 50) #full tree
#pcoal_along_edge(tr2, 50, 39) #sliced partway through the tree
#a cut of 70 is meant to lie farther from the root than the most recently sampled tip: this gives the same
#probabilities as the full tree, and zipped_partial_intervals adds a place holder interval (with no lineages) in front
pcoal_along_edge(tr2, 50, 70)

### Larger Tree ~1600 samples

In [None]:
#load the tree stored in full-ebola.nex (nexus format), label its nodes and edges
#and set the population size used by later cells back to 1
lg_tr = Tree.get(path="full-ebola.nex", schema="nexus")
name_nodes(lg_tr)
name_edges(lg_tr)
popsize = 1

### Time Slicing (from the most recent tip back to the time slice_time)

In [None]:
def time_sliced_lineages_in_interval(tree, slice_time):
    """
        Takes in a tree object and a slice time and returns the leading intervals of
        zipped_sorted_intervals(tree), up to and including the first interval whose end time is
        greater than or equal to slice_time (every interval is returned when no interval qualifies).
        Each element of the returned list is [[start time, end time], list of lineage nodes].
    """
    kept_intervals = []
    for (endpoints, lineages) in zipped_sorted_intervals(tree):
        kept_intervals.append([endpoints, lineages])

        #NOTE(review): zipped_sorted_intervals() builds its intervals from distances from the root, farthest first,
        #so the end times get smaller as this loop advances -- confirm that stopping on ">=" (rather than "<=")
        #is the intended direction for slice_time
        if endpoints[1] >= float(slice_time):
            break

    return kept_intervals

In [None]:
def time_sliced_pcoal_along_edge(tree, popsize, slice_time):
    """
        Takes in a tree object, a constant population size and a slice time and returns a dictionary
        keyed by edge. Only the edges of lineages that are active in the intervals returned by
        time_sliced_lineages_in_interval(tree, slice_time) appear in the dictionary.
        Every interval adds its conditioned coalescence probability (computed on the whole tree) to the
        edges of its lineages; the last interval only adds the fraction
        (slice_time - start) / (end - start) of its probability.
    """
    pcoal = conditioned_prob_lineage_coal(tree, popsize)
    all_endpoints, lin_set = zip(*time_sliced_lineages_in_interval(tree, slice_time))
    final_endpoints = all_endpoints[-1]
    final_index = len(lin_set)-1

    edge_prob = {}
    for (index, lineages) in enumerate(lin_set):
        #every interval counts in full except the last one, which is cut at slice_time
        frac_of_interval = 1.0
        if index == final_index:
            frac_of_interval = (slice_time - final_endpoints[0])/(final_endpoints[1]-final_endpoints[0])

        contribution = pcoal[index]*frac_of_interval
        for node in lineages:
            edge = node.edge
            if edge in edge_prob:
                edge_prob[edge] += contribution
            else:
                edge_prob[edge] = contribution

    return edge_prob

In [None]:
def calculate_cumulative_time_sliced_edge_prob(tree, popsize, slice_time):
    """
        Takes in a tree object, a constant population size and a slice time and returns a dictionary
        keyed by the edge of every lineage in the last interval returned by
        time_sliced_lineages_in_interval(tree, slice_time).
        The value is the time sliced probability of that edge plus the time sliced probabilities of all
        edges in the subtree below it. A slice_time larger than tree.max_distance_from_root() is
        lowered to that maximum first.
    """
    tree_height = tree.max_distance_from_root()
    if slice_time > tree_height:
        slice_time = tree_height
        
    prob_lineage = time_sliced_pcoal_along_edge(tree, popsize, slice_time)
    cumulative_prob ={}

    #find which lineages to calculate cumulative probabilities for (at the time slice)
    #the last interval is indexed directly instead of subscripting zip(*intervals), which fails on Python 3
    last_lin_set = time_sliced_lineages_in_interval(tree, slice_time)[-1][1]

    for lineage_node in last_lin_set:
        #add the initial sliced edge
        prob = prob_lineage[lineage_node.edge]

        #iterate through the nodes belonging to the subtree rooted at node
        for subtree_node in lineage_node.preorder_iter():

            #look at each edge of the node and add its edge probability to the cumulative node prob
            for edge in subtree_node.child_edge_iter():
                prob += prob_lineage[edge]
                
        #add the cumulative probability of everything below the sliced edge to the set of final cumulative probabilities
        cumulative_prob[lineage_node.edge] = prob
    return cumulative_prob