In [1]:
import toytree
import toyplot
import ipcoal
import numpy as np
import pandas as pd
from tqdm.notebook import tqdm
def get_tree_total_length(ttree):
    tot_len = 0
    for node_ in ttree.treenode.traverse():
        if not node_.is_root():
            tot_len += node_.dist
    return(tot_len)
def get_num_edges_at_time(tree, time):
    nodes_above = ([idx for idx, node in tree.idx_dict.items() if node.height > time])
    edges_above = len(nodes_above) + 1
    return edges_above
def get_tree_clade_times(tree):
    nodes_ = []
    heights_ = []
    for curr_node in tree.treenode.traverse():
        if not curr_node.is_leaf():
            nodes_.append(curr_node.get_leaf_names())
            heights_.append(curr_node.height)
    pddf = pd.DataFrame([nodes_,heights_],index=['clades','heights']).T
    return(pddf)
def get_branch_intervals(tr, gt, br):
    '''
    tr = species tree with Ne attribute
    gt = gene tree simulated on that species tree
    br = treenode representing a branch on the tree
    '''
    st_times = get_tree_clade_times(tr)
    gt_times = get_tree_clade_times(gt)
    coalclade = br.get_leaf_names()
    
    ###temp
    st_coal_node = tr.treenode.search_nodes(idx=tr.get_mrca_idx_from_tip_labels(br.get_leaf_names()))[0]
    nearest_st_node = st_coal_node
    while ((nearest_st_node.height + nearest_st_node.dist) < br.height):
        if nearest_st_node.is_root():
            break
        nearest_st_node = nearest_st_node.up

    coalclade = nearest_st_node.get_leaf_names()
    ###
    
    
    br_lower = br.height
    br_upper = br_lower + br.dist
    gt_clade_changes = (gt_times.heights < br_upper) & (gt_times.heights > br_lower)
    st_clade_changes = (st_times.heights < br_upper) & (st_times.heights > br_lower)
    st_time_diffed = st_times[st_clade_changes]
    #return(np.array([all(elem in clade for elem in coalclade) for clade in st_time_diffed.clades]))

    contains_clade = st_time_diffed[np.array([all(elem in clade for elem in coalclade) for clade in st_time_diffed.clades])]

    if not len(contains_clade.columns):
        contains_clade = pd.DataFrame(columns=['clades','heights'])
    contains_clade = pd.DataFrame([list(contains_clade.clades.append(pd.Series([coalclade]),ignore_index=True)),list(contains_clade.heights.append(pd.Series(br_lower)))],index=['clades','heights']).T
    contains_clade = contains_clade.sort_values('heights')

    all_members = []
    for i in contains_clade.clades:
        all_members.extend(i)
    all_members = np.unique(all_members)
    
    relevant_coals = pd.DataFrame(columns=["heights"])

    if np.sum(gt_clade_changes):
        potential_coals = gt_times[gt_clade_changes]
        relevant_coals = potential_coals[[set(i).issubset(all_members) for i in potential_coals.clades]]
        relevant_coals = relevant_coals.sort_values('heights')

    time_points = np.sort(list(contains_clade.heights) + list(relevant_coals.heights) + [br_upper])
    if int(time_points[-1]) == int(time_points[-2]):
        time_points = time_points[:-1]
    starts = time_points[:-1]
    stops = time_points[1:]
    lengths = stops-starts
    num_to_coal = np.repeat(1,len(starts))
    ne = np.repeat(1,len(starts))
    a_df = pd.DataFrame([starts,stops,lengths,num_to_coal,ne],index=['starts','stops','lengths','num_to_coal','ne']).T
    mids = (a_df.stops + a_df.starts)/2
    interval_reduced_trees=[]
    
    nes = []
    for mid in mids:
        clade = contains_clade.clades.iloc[np.sum(contains_clade.heights<mid)-1]

        cladeNe = tr.treenode.search_nodes(idx=tr.get_mrca_idx_from_tip_labels(clade))[0].Ne
        nes.append(cladeNe)
        reduced_tree = gt.prune(clade)
        interval_reduced_trees.append(reduced_tree.newick)
        
    a_df['reduced_trees'] = interval_reduced_trees
    a_df['mids'] = mids
    a_df['ne'] = nes
    a_df['num_to_coal'] = a_df.apply(lambda x: get_num_edges_at_time(toytree.tree(x['reduced_trees']), x['mids']), axis=1)
    
    return a_df

In [2]:
# make a random tree
tre = toytree.rtree.bdtree(6,time=8e3,seed=12345)

In [3]:
# scale it so that branch lengths that make sense
tre = tre.mod.node_scale_root_height(treeheight=8e3)

In [4]:
# set a random Ne to each node
node_ne_dict = {i:np.random.randint(1,20000) for i in range(tre.nnodes)} # Ne drawn randomly between 1 and 20000
tre = tre.set_node_data('Ne',node_ne_dict)

In [5]:
tre.draw(ts='p',node_labels=True,node_sizes=15,width=500,height=500,node_mask=False);

In [6]:
# define the model
mod = ipcoal.Model(tre,Ne=None,seed_trees=123)
# simulate a gene tree
mod.sim_trees(1)

In [7]:
# extract the gene tree individually
gtr = toytree.tree(mod.df.genealogy[0])
# draw it
gtr.draw(ts='p',node_labels=True,node_sizes=15,width=500,height=500,node_mask=False);

In [8]:
# sptree 1 has high ILS only from high Ne w/ gentime=1
SPTREE1 = toytree.rtree.unittree(ntips=6, treeheight=1e6, seed=123)
SPTREE1 = SPTREE1.set_node_data("Ne", default=2e5)
MODEL1 = ipcoal.Model(SPTREE1, seed_trees=123)
MODEL1.sim_trees(1, 1)
GTREE1 = toytree.tree(MODEL1.df.genealogy[0])

In [9]:
df = get_branch_intervals(SPTREE1, GTREE1, GTREE1.idx_dict[9])
df

Unnamed: 0,starts,stops,lengths,num_to_coal,ne,reduced_trees,mids
0,759841.8,1000000.0,240158.240457,1,200000.0,"(r0:759842,r1:759842);",879920.9
1,1000000.0,1182958.0,182957.564941,3,200000.0,"((r0:759842,r1:759842)0:...",1091479.0
2,1182958.0,2029994.0,847036.077558,2,200000.0,"((r0:759842,r1:759842)0:...",1606476.0


In [10]:
def calc_P_bT(df):
    last_index = len(df.starts)-1

    full_branch_summation = 0
    full_branch_start = df['starts'][0]
    full_branch_stop = df['stops'][last_index]

    for interval_index in range(len(df)):
        ai = df['num_to_coal'][interval_index]
        ni = df['ne'][interval_index]*2######################
        sigi = df['stops'][interval_index]
        sigb = df['starts'][interval_index]
        Ti = df['lengths'][interval_index]

        first_term = (1/ai)*Ti

        second_expr_second_term = 0
        for int_idx in range(interval_index+1,last_index+1): # for the *full* intervals above t
            # start with the summation
            internal_summation = 0
            if int_idx - interval_index > 1:
                for q_idx in range(interval_index+1,int_idx):
                    aq = df['num_to_coal'][q_idx]
                    nq = df['ne'][q_idx]*2############################
                    Tq = df['lengths'][q_idx]
                    internal_summation += ((aq/nq)*Tq)

            # define the properties of the current interval
            aint = df['num_to_coal'][int_idx]
            nint = df['ne'][int_idx]*2####################################
            Tint = df['lengths'][int_idx]

            # calculate the expressions that are multiplied together
            #first_mult = np.exp((ai/ni)*t)
            second_mult = np.exp(-1*(ai/ni)*sigi - internal_summation)
            third_mult = (1/aint)*(1-np.exp(-1*(aint/nint)*Tint))

            #print(second_mult*third_mult)
            second_expr_second_term += (second_mult*third_mult)*(ni/ai)

        # preventing overflow
        if ((ai/ni)*sigi < 709) and ((ai/ni)*sigb < 709): # prevent overflow...
            second_expr_second_term += -np.exp(-1*(ai/ni)*sigi) * (ni/(ai*ai))
            first_expr_second_term = (np.exp((ai/ni)*sigi) - np.exp((ai/ni)*sigb))
        # if there is no internal summation, then the problem simplifies to (e^x-e^y)/e^x , which is 1-e^(y-x)
        elif second_expr_second_term == 0:
            second_expr_second_term +=1
            first_expr_second_term = (1-np.exp((ai/ni)*sigb-(ai/ni)*sigi))* (ni/(ai*ai))
        full_branch_summation += first_term + first_expr_second_term*second_expr_second_term
    return(full_branch_summation * (1/(full_branch_stop-full_branch_start)))

In [11]:
def calc_P_bT(df):
    last_index = len(df.starts)-1

    full_branch_summation = 0
    full_branch_start = df['starts'][0]
    full_branch_stop = df['stops'][last_index]

    for interval_index in range(len(df)):
        ai = df['num_to_coal'][interval_index]
        ni = df['ne'][interval_index]*2######################
        sigi = df['stops'][interval_index]
        sigb = df['starts'][interval_index]
        Ti = df['lengths'][interval_index]

        first_term = (1/ai)*Ti

        second_expr_second_term = 0
        for int_idx in range(interval_index+1,last_index+1): # for the *full* intervals above t
            # start with the summation
            internal_summation = 0
            if int_idx - interval_index > 1:
                for q_idx in range(interval_index+1,int_idx):
                    aq = df['num_to_coal'][q_idx]
                    nq = df['ne'][q_idx]*2############################
                    Tq = df['lengths'][q_idx]
                    internal_summation += ((aq/nq)*Tq)
                    print(interval_index, int_idx, q_idx, internal_summation)
                    
            # define the properties of the current interval
            aint = df['num_to_coal'][int_idx]
            nint = df['ne'][int_idx]*2####################################
            Tint = df['lengths'][int_idx]

            # calculate the expressions that are multiplied together
            #first_mult = np.exp((ai/ni)*t)
            second_mult = np.exp(-1*(ai/ni)*sigi - internal_summation)
            third_mult = (1/aint)*(1-np.exp(-1*(aint/nint)*Tint))

            #print(second_mult*third_mult)
            second_expr_second_term += (second_mult*third_mult)*(ni/ai)
            print("---", interval_index, int_idx, second_expr_second_term)
        
        # preventing overflow
        if ((ai/ni)*sigi < 709) and ((ai/ni)*sigb < 709): # prevent overflow...
            second_expr_second_term += -np.exp(-1*(ai/ni)*sigi) * (ni/(ai*ai))
            first_expr_second_term = (np.exp((ai/ni)*sigi) - np.exp((ai/ni)*sigb))
        
        
        # if there is no internal summation, then the problem simplifies to (e^x-e^y)/e^x , which is 1-e^(y-x)
        elif second_expr_second_term == 0:
            second_expr_second_term +=1
            first_expr_second_term = (1-np.exp((ai/ni)*sigb-(ai/ni)*sigi))* (ni/(ai*ai))
        
        full_branch_summation += first_term + first_expr_second_term*second_expr_second_term
        print(f"branch-sum: {first_term}, {first_expr_second_term} {second_expr_second_term} {ai / ni * sigb}")

        
    return(full_branch_summation * (1/(full_branch_stop-full_branch_start)))

In [12]:
np.exp(570)

3.5306501429882274e+247

In [13]:
np.exp(np.log(580) * np.log(10))

2306896.4649905204

In [14]:
round(1e-10, 9)

0.0

In [18]:
ipcoal.set_log_level("INFO")
pd.set_option("precision", 1)


SPTREE = toytree.rtree.unittree(ntips=6, treeheight=1e6, seed=123)
SPTREE = SPTREE.set_node_data("Ne", default=2e5)
MODEL = ipcoal.Model(SPTREE, seed_trees=123, nsamples=2)
MODEL.sim_trees(1, 1)
GTREE = toytree.tree(MODEL.df.genealogy[0])
GIDX = 0
GNODE = GTREE.idx_dict[GIDX]

IMAP = MODEL.get_imap_dict()

TABLE = get_embedded_gene_tree_table(SPTREE, GTREE, IMAP)

In [21]:
from ipcoal.smc.smc4 import get_embedded_path_of_gene_tree_edge, get_embedded_gene_tree_table, get_species_tree_intervals_on_gene_tree_path

In [25]:
#STABLE = get_embedded_path_of_gene_tree_edge(TABLE, SPTREE, GTREE, IMAP, 8)

sidxs = get_species_tree_intervals_on_gene_tree_path(SPTREE, GTREE, IMAP, 8)
gt_node = GTREE.idx_dict[0]
mask0 = TABLE.st_node.isin(sidxs)
mask1 = TABLE.start >= gt_node.height
mask2 = TABLE.stop <= gt_node.up.height
mask0 & mask1 & mask2

0     False
1     False
2      True
3     False
4     False
5     False
6     False
7     False
8     False
9     False
10    False
11    False
12    False
13    False
14    False
15    False
16    False
17    False
18    False
19    False
20    False
21    False
dtype: bool

In [77]:
round(GTREE.idx_dict[8].height, 10)

TypeError: too many arguments: expected 1, got 2

In [56]:
TABLE.start.astype(int), gt_node.height

(0           0
 1       77159
 2           0
 3       34277
 4           0
 5      245858
 6           0
 7           0
 8       16638
 9           0
 10      30175
 11     250000
 12     500000
 13     573191
 14     750000
 15     750000
 16     833780
 17     881690
 18    1000000
 19    1012778
 20    1152952
 21    1619177
 Name: start, dtype: int64,
 2.3283064365386963e-10)

In [479]:
GTREE.idx_dict[0].up.height

245858.80970911912

In [467]:
MODEL.df.genealogy[0]

'(((r5_0:30175.40002395390547,r5_1:30175.40002395390547):982603.50365525123198,(r1_0:34277.17030055852956,r1_1:34277.17030055852956):978501.73337864654604):606398.72462800797075,((r0_0:77159.31915462232428,r0_1:77159.31915462232428):1075793.39868642808869,((r3_0:573191.16259514470585,(r4_0:16638.08667917054481,r4_1:16638.08667917054481):556553.07591597421560):308499.06611437792890,(r3_1:833780.15952253190335,(r2_0:245858.80970911911572,r2_1:245858.80970911911572):587921.34981341275852):47910.06918699073140):271262.48913152772002):466224.91046616272070);'

In [26]:
get_species_tree_intervals_on_gene_tree_path

<function ipcoal.smc.smc4.get_species_tree_intervals_on_gene_tree_path(species_tree: toytree.core.tree.ToyTree, gene_tree: toytree.core.tree.ToyTree, imap: Dict, idx: int)>

In [24]:
from ipcoal.smc.msc import *
import ipcoal
import toyplot

NEFF = 1e3
MODEL = ipcoal.Model(
    None, 
    Ne=NEFF,
    seed_trees=123, 
    nsamples=20,
)
MODEL.sim_trees(100, 1)
IMAP = MODEL.get_imap_dict()
GTREES = toytree.mtree(MODEL.df.genealogy)
COAL_TIMES = np.array([
    sorted(gtree.get_node_data("height")[gtree.ntips:])
    for gtree in GTREES.treelist
])

In [29]:

def get_gene_tree_log_prob_single_pop(neff: float, coal_times: np.ndarray):
    """Return log prob density of a gene tree in a single population.

    All labeled histories have equal probability in a single population
    model, and so the probability of a gene tree is calculated only 
    from the coalescent times.

    Modified from equation 5 of RannalaSpeciesTree.pdf to use edge 
    lens in units of gens, and neff, instead of thetas.
    >>> # 2 / theta = 2 / 4Neu = 1 / 2Neu
    >>> # t_muts = 2 * t_gen * 2 * mu = tgen / 4

    Example
    -------
    >>> model = ipcoal.Model(None, Ne=1e5, nsamples=20)
    >>> model.sim_trees(100)
    >>> coal_times = np.array([
    >>>     sorted(gtree.get_node_data("height")[gtree.ntips:])
    >>>     for gtree in toytree.mtree(model.df.genealogy)
    >>> ])
    >>> xs = np.logspace(2, 7)
    >>> toyplot.plot(
    >>>     xs, 
    >>>     [get_gene_tree_log_prob_single_pop(i, coal_times) for i in xs],
    >>>     xscale="log",
    >>> )
    """
    nlineages = len(coal_times) + 1
    first_term = (1 / (2 * neff)) ** (nlineages - 1)

    # calculate second term from sum over reducing nlineages
    inner_sum = 0
    for nlineages in range(2, nlineages + 1):
        time = coal_times[-nlineages + 1] / 4
        n_choose_2 = comb(nlineages, 2)
        inner_sum += n_choose_2 * time

    second_term = np.exp(-(1 / (2 * neff)) * inner_sum)
    prob = first_term * second_term
    if prob > 0:
        return np.log(prob)
    return np.inf

In [37]:
    >>> model = ipcoal.Model(None, Ne=1e5, nsamples=20)
    >>> model.sim_trees(1)
    >>> gtree = toytree.tree(model.df.genealogy[0])
    >>> coals = np.array(sorted(gtree.get_node_data("height")[gtree.ntips:]))
    >>> xs = np.logspace(2, 7)
    >>> toyplot.plot(
    >>>     xs, 
    >>>     [get_gene_tree_log_prob_single_pop(i, coals) for i in xs],
    >>>     xscale="log", height=300, width=400,
    >>> );

In [55]:
np.exp(0)

1.0

In [38]:
    >>> model = ipcoal.Model(None, Ne=1e5, nsamples=20)
    >>> model.sim_trees(100)
    >>> coal_times = np.array([
    >>>     sorted(gtree.get_node_data("height")[gtree.ntips:])
    >>>     for gtree in toytree.mtree(model.df.genealogy)
    >>> ])
    >>> xs = np.logspace(2, 7)
    >>> toyplot.plot(
    >>>     xs, 
    >>>     [optim_func(i, coal_times) for i in xs],
    >>>     xscale="log", height=300, width=400,
    >>> );

In [33]:
xs = np.logspace(4.5, 6.2, 100)

c, a, m = toyplot.plot(
    xs,
    [get_gt_prob(i, COAL_TIMES[10]) for i in xs],
    width=400, height=300,
);
a.vlines([1e5])

<toyplot.mark.AxisLines at 0x7f2b4a641d30>

In [107]:
0.06 / (5 * (5 - 1))

0.003

In [428]:
msgen = MODEL1._get_tree_sequence_generator(100)
tree_seq = next(msgen)
breaks = [int(i) for i in tree_seq.breakpoints()]
starts = breaks[0:len(breaks) - 1]
ends = breaks[1:len(breaks)]
lengths = [i - j for (i, j) in zip(ends, starts)]

data = pd.DataFrame({
    "start": starts,
    "end": ends,
    "nbps": lengths,
    "nsnps": 0,
    "tidx": 0,
    "locus": 0,
    "genealogy": "",
    },
    columns=[
        'locus', 'start', 'end', 'nbps',
        'nsnps', 'tidx', 'genealogy'
    ],
)

for mstree in tree_seq.trees():
    nwk = mstree.newick(precision=6)
    print(nwk)

((1:893134.198335,2:893134.198335):633472.223359,((3:568628.632875,4:568628.632875):950311.775176,(5:773013.855498,6:773013.855498):745926.552552):7666.013643);


In [429]:
for mstree in tree_seq.trees():
    nwk = mstree.newick()

In [506]:
#toytree.tree(MODEL.df.genealogy[0])
GTREE.idx_dict[0].up.height

245858.80970911912

In [508]:
TABLE.stop

0       77159.3
1      750000.0
2       34277.2
3      750000.0
4      245858.8
5      250000.0
6      250000.0
7       16638.1
8      500000.0
9       30175.4
10     750000.0
11     500000.0
12     573191.2
13     750000.0
14    1000000.0
15     833780.2
16     881690.2
17    1000000.0
18    1012778.9
19    1152952.7
20    1619177.6
21         <NA>
Name: stop, dtype: object

In [430]:
toytree.tree(nwk).get_node_data().astype(int)

Unnamed: 0,height,dist,support,name,idx
0,0,773013,0,6,0
1,0,773013,0,5,1
2,0,568628,0,4,2
3,0,568628,0,3,3
4,0,893134,0,2,4
5,0,893134,0,1,5
6,773013,745926,0,6,6
7,568628,950311,0,7,7
8,1518940,7666,0,8,8
9,893134,633472,0,9,9


In [433]:
GTREE.get_node_data("height").astype(int)

0           0
1           0
2           0
3           0
4           0
5           0
6           0
7           0
8           0
9           0
10          0
11          0
12     245858
13      16638
14     833780
15     573191
16     881690
17      77159
18      34277
19      30175
20    1152952
21    1012778
22    1619177
Name: height, dtype: int64

In [213]:
def draw_trees(species_tree, gene_tree):
    """Return a drawing of species tree, gene tree, and intervals of interest."""
    mtre = toytree.mtree([species_tree, gene_tree])
    canvas, axes, _ = mtre.draw(
        ts='p',
        shared_axes=True,
        scale_bar=True,
        #fixed_order=species_tree.get_tip_labels(),
        node_labels="idx",
        node_labels_style={"baseline-shift": "10px", "font-size": "11px"},
        node_sizes=6,
        height=325, width=500,
    );

    axes[0].label.text = "SPTREE"
    axes[1].label.text = "GTREE"
    for ax in axes:
        ax.hlines(
            species_tree.get_node_data("height").iloc[species_tree.ntips:].unique(),
            style={
                "stroke": toytree.COLORS1[1], 
                "stroke-width": 2,
                "stroke-dasharray": "2,4"},
        )
    return canvas, axes, _

In [180]:
idx = 9
time = 1_500_000

df = get_branch_intervals(SPTREE1, GTREE1, GTREE1.idx_dict[idx])
calc_P_btT(time, df)

0.49750404437120693

In [181]:
calc_P_bT(df)

--- 0 1 8169.611609291846
0 2 1 1.3721817370605311
--- 0 2 12271.9313899456
branch-sum: 240158.24045731744, 5.499243942806366 -20562.068059613917 1.8996043988567066
--- 1 2 9.213737751781636
branch-sum: 60985.85498046805, 5322.779177516553 2.981013531829631 7.5
branch-sum: 423518.0387792416, 25219.813053866103 -3.9077323826418646 5.914787824707021


0.416407919151819

In [18]:
from ipcoal.smc.smc4 import (
    get_embedded_gene_tree_table, 
    get_prob_gene_tree_is_unchanged_by_recomb_event,
    get_species_tree_intervals_on_gene_tree_path,
    get_embedded_path_of_gene_tree_edge,
    get_prob_gene_tree_is_unchanged_by_recomb_on_edge,
)

In [79]:
table = get_embedded_gene_tree_table(SPTREE1, GTREE1, MODEL1.get_imap_dict())

In [82]:
get_prob_gene_tree_is_unchanged_by_recomb_event(table, SPTREE1, GTREE1, MODEL1.get_imap_dict(), 9, 1_500_000)

0.49750404437120693

In [93]:
get_prob_gene_tree_is_unchanged_by_recomb_on_edge(table, SPTREE1, GTREE1, MODEL1.get_imap_dict(), 9)

1.8498270241983283

In [24]:
SPTREE = toytree.rtree.unittree(ntips=6, treeheight=8e8, seed=123)
# SPTREE = SPTREE.set_node_data("Ne", {i: 5e4 for i in (0, 1, 8)}, default=1e5)
SPTREE = SPTREE.set_node_data("Ne", default=20)
MODEL = ipcoal.Model(SPTREE, seed_trees=123, nsamples=1)
MODEL.sim_trees(1, 1)
GTREE = toytree.tree(MODEL.df.genealogy[0])

In [25]:
height = 8e7
# make a random tree
tre = toytree.rtree.bdtree(6,time=height,seed=12345)

In [27]:
# scale it so that branch lengths that make sense
tre = tre.mod.node_scale_root_height(treeheight=height)
# set a random Ne to each node
np.random.seed(22)
node_ne_dict = {i:np.random.randint(1,2000) for i in range(tre.nnodes)} # Ne drawn randomly between 1 and 20000
tre = tre.set_node_data('Ne',node_ne_dict)

In [31]:
np.exp(100) - np.exp(10)

2.6881171418161356e+43

In [47]:
-np.exp(-500)

-7.124576406741286e-218

In [28]:
tre.draw(ts='p')

(<toyplot.canvas.Canvas at 0x7fda14a56a90>,
 <toyplot.coordinates.Cartesian at 0x7fda14a56eb0>,
 <toytree.core.drawing.toytree_mark.ToytreeMark at 0x7fda14a60ac0>)

In [5]:
# define the model
mod = ipcoal.Model(tre,Ne=None,seed_trees=1235)
# simulate a gene tree
mod.sim_trees(1)

In [6]:
# extract the gene tree individually
gtr = toytree.tree(mod.df.genealogy[0])
# draw it
gtr.draw(ts='p',node_labels=True,node_sizes=15,width=500,height=500,node_mask=False);


In [7]:
def calc_P_bT(df):
    last_index = len(df.starts)-1

    full_branch_summation = 0
    full_branch_start = df['starts'][0]
    full_branch_stop = df['stops'][last_index]

    for interval_index in range(len(df)):
        ai = df['num_to_coal'][interval_index]
        ni = df['ne'][interval_index]*2######################
        sigi = df['stops'][interval_index]
        sigb = df['starts'][interval_index]
        Ti = df['lengths'][interval_index]

        first_term = (1/ai)*Ti

        second_expr_second_term = 0
        for int_idx in range(interval_index+1,last_index+1): # for the *full* intervals above t
            # start with the summation
            internal_summation = 0
            if int_idx - interval_index > 1:
                for q_idx in range(interval_index+1,int_idx):
                    aq = df['num_to_coal'][q_idx]
                    nq = df['ne'][q_idx]*2############################
                    Tq = df['lengths'][q_idx]
                    internal_summation += ((aq/nq)*Tq)

            # define the properties of the current interval
            aint = df['num_to_coal'][int_idx]
            nint = df['ne'][int_idx]*2####################################
            Tint = df['lengths'][int_idx]

            # calculate the expressions that are multiplied together
            #first_mult = np.exp((ai/ni)*t)
            second_mult = np.exp(-1*(ai/ni)*sigi - internal_summation)
            third_mult = (1/aint)*(1-np.exp(-1*(aint/nint)*Tint))

            #print(second_mult*third_mult)
            second_expr_second_term += (second_mult*third_mult)*(ni/ai)

        second_expr_second_term += -np.exp(-1*(ai/ni)*sigi) * (ni/(ai*ai))
        first_expr_second_term = (np.exp((ai/ni)*sigi) - np.exp((ai/ni)*sigb))
        
        full_branch_summation += first_term + float(first_expr_second_term*second_expr_second_term)
    return(full_branch_summation * (1/(full_branch_stop-full_branch_start)))

In [8]:
df = get_branch_intervals(tre, gtr, gtr.idx_dict[2])
df

Unnamed: 0,starts,stops,lengths,num_to_coal,ne,reduced_trees,mids
0,0.0,2006248.0,2006248.0,1,813,r2:8.00007e+07;,1003124.0
1,2006248.0,13583950.0,11577700.0,1,813,r2:8.00007e+07;,7795098.0
2,13583950.0,13586940.0,2993.551,2,492,"(r2:1.35869e+07,(r0:2.00...",13585440.0


In [9]:
calc_P_bT(df)

  first_expr_second_term = (np.exp((ai/ni)*sigi) - np.exp((ai/ni)*sigb))
  full_branch_summation += first_term + float(first_expr_second_term*second_expr_second_term)
  first_expr_second_term = (np.exp((ai/ni)*sigi) - np.exp((ai/ni)*sigb))


nan

In [15]:
from ipcoal.smc.smc4 import *

In [17]:
table = get_embedded_gene_tree_table(tre, gtr, mod.get_imap_dict())
get_prob_gene_tree_is_unchanged_by_recomb_on_edge(table, tre, gtr, mod.get_imap_dict(), 2)

nan

In [22]:
get_prob_gene_tree_is_unchanged_by_recomb_on_edge(table, tre, gtr, mod.get_imap_dict(), 0)

nan