# Compute ARG likelihood similar to SCAR

exploring ideas from the paper "Recombination-aware phylogeographic inference using the structured coalescent with ancestral recombination"

In [1]:
import ipcoal
import toytree
import numpy as np
import pandas as pd

### Get species tree and embedded genealogy

In [2]:
# 1. simulate a true ARG for 1Mb with N (interval, tree) pairs
# 2. create many modifications of this where intervals boundaries are shifted/merged/split
# 3. create many modifications where node heights are shifted
# 4. create many modifications with both shifted
# 5. calculate logliks of each. The true should always be best. 

In [14]:
from scipy import stats
-stats.expon.logpdf(scale=[1, 10, 100, 1000, 10000, 100000], x=1000)

array([1000.        ,  102.30258509,   14.60517019,    7.90775528,
          9.31034037,   11.52292546])

In [4]:
sptree = toytree.rtree.baltree(2, treeheight=5e5)
sptree.set_node_data("Ne", default=2e5, inplace=True)
model = ipcoal.Model(sptree, nsamples=8, discrete_genome=False, ancestry_model="smc_prime", seed_trees=123, recomb=2e-9)
model.sim_trees(1, 1e6)

from ipcoal.smc.src.utils import get_ms_smc_data_from_model
tree_spans, topo_spans, topo_idxs, genealogies = get_ms_smc_data_from_model(model)

spans = list(tree_spans)
G = ipcoal.smc.TreeEmbedding(model.tree, genealogies, model.get_imap_dict(), nproc=4)

# 
print(len(genealogies), 'trees')
sloglik = ipcoal.smc.get_ms_smc_loglik_from_embedding(G, 2e-9, np.array(spans), normalize=True)
print(sloglik)


# comparing different models to the same data, weight each interval equally
for ne in [8e3, 1e4, 5e4, 1e5, 2e5, 5e5]:
    G._update_neffs(np.array([ne, ne, ne]))
    #sloglik2 = ipcoal.smc.get_ms_smc_loglik_from_embedding(
    #    G, 2e-9, event_type=2, lengths=np.array(topo_spans), idxs=np.array(topo_idxs), normalize=0)
    sloglik2 = ipcoal.smc.get_ms_smc_loglik_from_embedding(
        G, 2e-9, event_type=1, lengths=np.array(tree_spans), normalize=1)
    print(ne, sloglik2)

# comparing different data to the same model, weight each interval by its proportional length
sptree.set_node_data("Ne", default=2e5, inplace=True)
for idx in [99, 150, 300, 500]:
    genealogies[idx].set_node_data("height", {genealogies[idx][-1]: genealogies[idx][-1].height + 5000}, inplace=True)
print(len(genealogies), 'trees with mod heights')
sloglik = ipcoal.smc.get_ms_smc_loglik(sptree, genealogies, model.get_imap_dict(), 2e-9, spans, normalize=True)
print(sloglik)

# ...
for idx in [40, 60, 80, 100]:
    genealogies.pop(idx)
    spans[idx - 1] += spans.pop(idx)
print(len(genealogies), 'trees with mod spans')
sloglik = ipcoal.smc.get_ms_smc_loglik(sptree, genealogies, model.get_imap_dict(), 2e-9, spans, normalize=True)
print(sloglik)


7083 trees
6.930987710077821
8000.0 6.790637677921562
10000.0 6.792751956508505
50000.0 6.831524674388762
100000.0 6.871400242694788
200000.0 6.930987710077821
500000.0 7.0292066287870645
7083 trees with mod heights
6.9309885756157215
7079 trees with mod spans
6.931300100262252


In [4]:
[sum(spans) / 3, sum(spans) / 3, sum(spans) / 3]

[33323.073187617374, 33323.073187617374, 33323.073187617374]

In [5]:
genealogies[:3]

[<toytree.ToyTree at 0x7fb12f40a110>,
 <toytree.ToyTree at 0x7fb12f4091e0>,
 <toytree.ToyTree at 0x7fb12f4125f0>]

In [9]:
ipcoal.smc.get_ms_smc_loglik(sptree, genealogies[:3], model.get_imap_dict(), 2e-9, lengths=[33333, 33333, 33333], normalize=1)

278.17359709429275

In [10]:
ipcoal.smc.get_ms_smc_loglik(sptree, genealogies[:2], model.get_imap_dict(), 2e-9, lengths=[50000, 50000], normalize=1)

416.9319098927417

In [8]:
ipcoal.smc.get_ms_smc_loglik(sptree, genealogies[:2], model.get_imap_dict(), 2e-9, lengths=[500, 500], normalize=1)

8.919984680646579

In [39]:
ipcoal.smc.get_ms_smc_loglik(sptree, genealogies[:3], model.get_imap_dict(), 
                             recombination_rate=2e-9, 
                             lengths=[sum(spans) / 3, sum(spans) / 3, sum(spans) / 3], 
                             normalize=0
                            )

834.2765565446084

In [528]:
genealogies[0].draw('p');

In [47]:
from scipy import stats

In [320]:
# sample x exponential waiting distances
rates = [25, 50, 100, 200]
rvs = stats.expon.rvs(scale=rates, size=(100, 4))
len = 100

for rate, r in zip(rates, rvs):
    sumlen = 0
    loglik = []
    for span in r:
        sumlen += span
        if sumlen < 100:
            loglik.append(stats.expon.logpdf(scale=rate, x=span))
        else:
            print(span, sumlen, "#", logliks)

144.24856440471896 188.2221539583129 # [-24.87124497 -24.87124497 -49.92175954]
74.26211604248537 262.4842700007983 # [-24.87124497 -24.87124497 -49.92175954]
326.3132509468216 414.92755493376796 # [-24.87124497 -24.87124497 -49.92175954]
19.4565291330004 434.38408406676837 # [-24.87124497 -24.87124497 -49.92175954]
75.54958147238474 113.10230659136987 # [-24.87124497 -24.87124497 -49.92175954]
308.7016643481691 421.803970939539 # [-24.87124497 -24.87124497 -49.92175954]
165.4473140514186 587.2512849909576 # [-24.87124497 -24.87124497 -49.92175954]
52.28022763768242 143.710852949558 # [-24.87124497 -24.87124497 -49.92175954]
354.9742444024598 498.6850973520178 # [-24.87124497 -24.87124497 -49.92175954]


In [208]:
rates1 = np.array([25, 25, 50, 100])
rates2 = np.array([50, 50, 100])
rates3 = np.array([100, 100])
rates4 = np.array([200])

In [502]:
l = stats.expon.logpdf(scale=np.array([10,10,10,10]), x=rates1) * (rates1 / rates1.sum())
print(-sum(l), l)

9.177585092994047 [-0.60032314 -0.60032314 -1.82564627 -6.15129255]


In [514]:
l = stats.expon.logpdf(scale=np.array([10] * 4), x=rates1) * (rates1 / rates1.sum())
print(-sum(l), l)

9.177585092994047 [-0.60032314 -0.60032314 -1.82564627 -6.15129255]


In [511]:
l = stats.expon.logpdf(scale=rates1, x=rates1) * (rates1 / rates1.sum())
print(-sum(l), l)

5.085309800568133 [-0.52735948 -0.52735948 -1.22800575 -2.80258509]


In [417]:
l = stats.expon.logpdf(scale=rates2, x=rates2) * (rates2 / rates2.sum())
print(-sum(l), l)

5.258596595708119 [-1.22800575 -1.22800575 -2.80258509]


In [414]:
l = stats.expon.logpdf(scale=rates3, x=rates3) * (rates3 / rates3.sum())
print(-sum(l), l)

5.605170185988092 [-2.80258509 -2.80258509]


In [409]:
l = stats.expon.logpdf(scale=rates4, x=rates4) * rates4 #* (rates4 / rates4.sum())
print(-sum(l), l)

1259.6634733096073 [-1259.66347331]


In [472]:
l = stats.expon.logpdf(scale=[1000], x=rates4) * rates4 #* (rates4 / rates4.sum())
print(-sum(l), l)

1421.5510557964274 [-1421.5510558]


In [448]:
import numpy as np
from scipy.special import kl_div, rel_entr

# Define two probability distributions.
p = np.array([0.25, 0.25, 0.5])
q = np.array([0.1, 0.3, 0.6])

# Calculate the KL divergence.
kl_div(p, q)

sum(rel_entr(p, q))

0.09233151537307284

In [449]:
for scale in range(10, 110, 10):
    loglik = 0.
    for span in [50, 50]:
        rate = 1 / scale
        prob = rate * np.exp(-rate * span)
        loglik += np.log(prob) * 0.5#* span
    print(scale, loglik)   
    #print(scale, sum(stats.expon.logpdf(scale=1 / scale, x=[50, 50])) )

10 -7.302585092994046
20 -5.49573227355399
30 -5.067864048328822
40 -4.938879454113936
50 -4.912023005428146
60 -4.927677895555434
70 -4.962780956335073
80 -5.007026634673881
90 -5.05536522588582
100 -5.105170185988092


In [147]:
for scale in range(10, 150, 10):
    loglik = 0.
    for span in [100]:
        rate = 1 / scale
        prob = rate * np.exp(-rate * span)
        loglik += np.log(prob) #* span
    print(scale, loglik)
    #print(scale, sum(stats.expon.logpdf(scale=1 / scale, x=[50, 50])) )

10 -12.302585092994045
20 -7.995732273553991
30 -6.734530714995489
40 -6.188879454113936
50 -5.9120230054281455
60 -5.761011228888767
70 -5.677066670620787
80 -5.632026634673881
90 -5.610920781441377
100 -5.605170185988091
110 -5.609571274883325
120 -5.62082507611538
130 -5.636765219686351
140 -5.655928136895018


In [148]:
for scale in range(10, 150, 10):
    loglik = 0.
    for span in [100]:
        rate = 1 / scale
        prob = rate * np.exp(-rate * span)
        loglik += stats.expon.logpdf(scale=scale, x=span)
    print(scale, loglik)
    #print(scale, sum(stats.expon.logpdf(scale=1 / scale, x=[50, 50])) )

10 -12.302585092994047
20 -7.99573227355399
30 -6.734530714995489
40 -6.188879454113936
50 -5.912023005428146
60 -5.761011228888767
70 -5.677066670620788
80 -5.632026634673881
90 -5.610920781441376
100 -5.605170185988092
110 -5.609571274883326
120 -5.620825076115379
130 -5.636765219686351
140 -5.655928136895018


In [117]:
stats.expon.logpdf(scale=100, x=[100])

array([-5.60517019])

In [118]:
stats.expon.logpdf(scale=50, x=[50, 50])

array([-4.91202301, -4.91202301])

In [95]:
rates = np.array([25, 25, 50])
logliks = stats.expon.logpdf(scale=1/rates, x=rates) / rates
logliks.sum()

-99.66424947390198

In [31]:
# get a species tree with Ne on branches
sptree = toytree.rtree.imbtree(4, treeheight=1e6)
sptree = sptree.set_node_data("Ne", default=100_000)

In [32]:
# simulate an ARG
model = ipcoal.Model(sptree, nsamples={"r0": 3, "r1": 2, "r2": 1, "r3": 1})
model.sim_loci(1, 100_000)
model.df.head()

Unnamed: 0,locus,start,end,nbps,nsnps,tidx,genealogy
0,0,0,9,9,1,0,(r3_0:1114082.4766803677...
1,0,9,59,50,4,1,(r3_0:1114082.4766803677...
2,0,59,358,299,7,2,(r3_0:1023375.4532953024...
3,0,358,4004,3646,145,3,(r3_0:1023375.4532953024...
4,0,4004,4191,187,7,4,(r3_0:1023375.4532953024...


### Get genealogy embedding table

In [13]:
T = ipcoal.smc.TreeEmbedding(sptree, model.df.genealogy, model.get_imap_dict(), 4)

In [29]:
ipcoal.smc.get_ms_smc_loglik_from_embedding(T, 2e-9, model.df.nbps.values, 1, model.df.index.values)

1235.8851339931275

In [12]:
for tidx, tdf in model.df.groupby("tidx"):
    print(tidx, tdf)
    

0    locus  start  end  nbps  nsnps  tidx                    genealogy
0      0      0  583   583     14     0  (r3_0:1405041.0880687532...
1    locus  start  end  nbps  nsnps  tidx                    genealogy
1      0    583  782   199      6     1  (r3_0:1405041.0880687532...
2    locus  start   end  nbps  nsnps  tidx                    genealogy
2      0    782  1157   375     16     2  (r3_0:1405041.0880687532...
3    locus  start   end  nbps  nsnps  tidx                    genealogy
3      0   1157  1173    16      1     3  (r3_0:1279681.8729624752...
4    locus  start   end  nbps  nsnps  tidx                    genealogy
4      0   1173  1210    37      2     4  (r3_0:1279681.8729624752...
5    locus  start   end  nbps  nsnps  tidx                    genealogy
5      0   1210  1559   349     18     5  (r3_0:1279681.8729624752...
6    locus  start   end  nbps  nsnps  tidx                    genealogy
6      0   1559  2820  1261     66     6  (r3_0:1279681.8729624752...
7    locus

In [10]:
table = ipcoal.smc.get_genealogy_embedding_table(sptree, model.df.genealogy, model.get_imap_dict())
table.head(10)

Unnamed: 0,start,stop,st_node,neff,nedges,dist,gidx,edges
0,0.0,24717.13269,0,100000.0,3,24717.13269,0,"[4, 5, 6]"
1,24717.13269,129490.860472,0,100000.0,2,104773.727782,0,"[4, 8]"
2,129490.860472,333333.333333,0,100000.0,1,203842.472861,0,[9]
3,0.0,2894.690133,1,100000.0,2,2894.690133,0,"[2, 3]"
4,2894.690133,333333.333333,1,100000.0,1,330438.6432,0,[7]
5,0.0,666666.666667,2,100000.0,1,666666.666667,0,[1]
6,0.0,1000000.0,3,100000.0,1,1000000.0,0,[0]
7,333333.333333,376389.041841,4,100000.0,2,43055.708507,0,"[7, 9]"
8,376389.041841,666666.666667,4,100000.0,1,290277.624826,0,[10]
9,666666.666667,689733.522144,5,100000.0,2,23066.855477,0,"[1, 10]"


### Calculate the likelihood of a genealogy in an ARG

In [174]:
tgen = model._get_tree_sequence_generator(nsites=1e5)
# sample the tree sequence for this chromosome
tseq = next(tgen)

# get a copy of the tree sequence that has been simplified.
# Because it was simulated with record_full_arg=True there are
# many records of recombination that cause no-change that add
# extra nodes to the trees. We need to simplify these to more
# easily find the intervals at which changes occur, and to have
# simpler trees that can be compared to detect event types.
stseq = tseq.simplify(filter_sites=False)

# get the starting tree in each tree sequence.
tree = tseq.first(sample_lists=True)
simple_tree0 = stseq.first(sample_lists=True)

# iterate over subsequent intervals until the first topology
# change event is observed. Start from that new fresh tree.
while 1:
    tree.next()
    next_simple_tree = stseq.at(tree.interval.left, sample_lists=True)

    # if the topology changed then this break and save tree as
    # this will be our new starting tree.
    dist = next_simple_tree.kc_distance(simple_tree0, lambda_=0)
    print(dist)
    if dist:
        tree1 = next_simple_tree
        break
        

# get sum edge lengths of tree at starting position
tsumlen1 = tree1.get_total_branch_length()

# compute analytical probabilities of change given tree1
toy1 = toytree.tree(tree1.as_newick(node_labels=model.tipdict))
prob_tree = ipcoal.smc.get_probability_tree_change(
    model.tree, toy1, imap)
prob_topo = ipcoal.smc.get_probability_topology_change(
    model.tree, toy1, imap)

# compute lambda_ (rate) of tree/topo change given sptree and tree1
tree_rate = tsumlen1 * prob_tree * recomb
topo_rate = tsumlen1 * prob_topo * recomb

# RECORD FIRST EVENT TYPE ------------------------------------
# iterate over subsequent intervals of non-simplified tree seq
# until each change event type is observed.
observed_topo_dist = 0.
observed_tree_dist = 0.
event_type = None

# advance to next tree in non-simple treeseq and get simplified tree
tree.next()
next_simple_tree = stseq.at(tree.interval.left, sample_lists=True)

0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
2.8284271247461903


KeyError: 'r2_1'

In [115]:
ts = model.get_tree_sequence()

In [116]:
tree = ts.first()
nodes = tree.nodes()
node = next(nodes)

In [117]:
ts.edge(11)

Edge(left=0.0, right=1.0, parent=12, child=11, metadata=b'', id=11)

In [124]:
ts.tables_dict["edges"]

id,left,right,parent,child,metadata
0,0,1,7,0,
1,0,1,7,2,
2,0,1,8,1,
3,0,1,8,7,
4,0,1,9,3,
5,0,1,9,4,
6,0,1,10,8,
7,0,1,10,9,
8,0,1,11,5,
9,0,1,11,10,


In [129]:
ts.breakpoints()

<map at 0x7f507236d0f0>

In [133]:
def get_structured_arg_likelihood(
    table: pd.DataFrame,
) -> float:
    """Return likelihood of a genealogy ..."""
    
    likelihood_components = {}
    for idx in table.index[:-1]:
        k = table.nedges[idx]
        n = table.neff[idx]
        
        λc = (k * (k - 1)) / (2 * n)
        likelihood_components[idx] = np.exp(-λc * table.dist[idx])
    return np.prod(list(likelihood_components.values()))

In [134]:
get_structured_arg_likelihood(table)

1.4100341731076422e-61