# ARG likelihood
This notebook demonstrates the calculation of ARG likelihoods on both simulated (ipcoal generated) and empirical (ARGweaver inferred) ARGs. The likelihood are calculated as the summed loglikelihoods of coalescent times in the set of observed trees, and the loglikelihoods of the waiting distances between recombination events. Here we show that these likelihoods are more informative when waiting distances are represented by multiple categories of recombination events (i.e., any recomb event, tree-change events, and/or topo-change events). 

In [1]:
import toytree
import ipcoal
import numpy as np
import pandas as pd
from scipy import stats
import gzip
from pathlib import Path

### The species tree used for simulations
This is the species tree that was used for those simulations.

In [2]:
# TRUE species tree has a divergence time of 500K
sptree = toytree.rtree.imbtree(2, treeheight=5e5)
sptree = sptree.set_node_data("Ne", {0: 4e5, 1: 2e5, 2: 3e5})
sptree = sptree.set_node_data("name", [0, 1, 2])
tmp = ipcoal.Model(sptree, nsamples=4)
tmp.draw_demography();
imap = tmp.get_imap_dict()

# however, ARGweaver modified species_tree has divergence rounded down to ~450K
sptree2 = sptree.set_node_data("height", {2: 450_000})

### The ARG used to simulate sequences
This dataframe has the true genealogies and their intervals on top of which the sequence data was generated.

In [3]:
# the true simulated ARG
ARG_DF = pd.read_csv("./ARG-423-N1e5-True.csv")

# add a column with genealogies as simplified (remove unary nodes) ToyTree objects
ARG_DF['gtrees'] = [toytree.tree(i).mod.remove_unary_nodes() for i in ARG_DF["genealogy"]]

# show an example true genealogy from the ARG
ARG_DF.gtrees[3].draw();

### ARGweaver data
This is a directory of results from running ARGweaver on simulated sequences, from notebook XXXX.

In [4]:
# get a list of 10K simulated ARGs
ARGDIR = Path("../ARG-data/")
ARGLIST = sorted(ARGDIR.glob("ARG-423-N1e5.*.smc.gz"), key=lambda x: int(x.name.rsplit(".")[-3]))
ARGLIST[:5]

[PosixPath('../ARG-data/ARG-423-N1e5.0.smc.gz'),
 PosixPath('../ARG-data/ARG-423-N1e5.10.smc.gz'),
 PosixPath('../ARG-data/ARG-423-N1e5.20.smc.gz'),
 PosixPath('../ARG-data/ARG-423-N1e5.30.smc.gz'),
 PosixPath('../ARG-data/ARG-423-N1e5.40.smc.gz')]

In [5]:
def iter_argweaver_relabeled_trees(smc_file: str, topo: bool = False):
    """Return a generator of relabeled trees from a SMC.gz file.

    Parameters
    ----------
    smc_file: Path or str
        A file path to a smc.gz file produced by ARGweaver.
    topo: bool
        If True then intervals are only returned for topo-changes, where the
        first tree is returned to represent the interval.

    Example
    -------
    >>> igen = get_argweaver_relabeled_tree_generator(ARG.0.smc.gz, topo=False)
    >>> next(igen)
    >>> # (0, 100, <toytree.ToyTree at 0x7f50ff8955d0>)
    """
    # 
    last_topo = None
    last_topo_id = None
    last_topo_id_start = None
    last_topo_id_stop = None

    # 
    with gzip.open(smc_file, 'rb') as idata:
        
        # get int to tip label translation dict
        trans = idata.readline().decode().strip().split()[1:]
        trans = {str(i): j for (i, j) in enumerate(trans)}

        # skip the header line that starts with "REGION"
        idata.readline()
        
        # iterate over lines in the file
        for line in idata:
            line = line.decode()

            # parse line if starts with "TREE"
            if line.startswith("TREE"):
                _, start, stop, nhx = line.split()
                start = int(start) - 1
                stop = int(stop)
                tree = toytree.tree(nhx, feature_prefix="&&NHX:", feature_delim=":")
                tree.set_node_data("name", {i: trans[i.name] for i in tree[:tree.ntips]}, inplace=True)

            # yield a tree on the next breakpoint (starts with "SPR")
            else:
                # yield the interval that just finished
                if not topo:
                    yield start, stop, tree

                # maybe yield the interval that just finished
                else: 
                    # record the new topo ID
                    topo_id = tree.get_topology_id(include_root=True)                     
                        
                    # if this is the first interval, store it
                    if not last_topo_id:
                        last_topo = tree
                        last_topo_id = topo_id
                        last_topo_id_start = start

                    # for all later trees, check if topo changed
                    if topo_id != last_topo_id:
                        last_topo_id_stop = start
                        yield last_topo_id_start, last_topo_id_stop, last_topo
                        last_topo = tree
                        last_topo_id = topo_id
                        last_topo_id_start = start
    return

In [6]:
igen = iter_argweaver_relabeled_trees(ARGLIST[0], topo=0)
for i in range(10):
    start, stop, tree = next(igen)
    print(start, stop, tree.get_topology_id(include_root=True))

0 106 45b007d5a28f67f66de8607ebf31c1ae
106 174 9e53caea9c486c94b2d0d08cf8adb418
174 478 fd47cfe2345fc56c3817a7738cc1984b
478 542 cb824fbf14ee4fb09cd3f92c016f00ff
542 845 0d1a3570d6a70ea7142c9bd5c871c31b
845 877 fd7403f1f1c6b9362991e619b347b8c9
877 924 430764a4bcfe664cd155b1dad3e9ffd9
924 1029 430764a4bcfe664cd155b1dad3e9ffd9
1029 1050 430764a4bcfe664cd155b1dad3e9ffd9
1050 1415 932a17faedb7be3ee1d8cfb6f94eeefd


In [7]:
igen = iter_argweaver_relabeled_trees(ARGLIST[0], topo=1)
for i in range(10):
    start, stop, tree = next(igen)
    print(start, stop, tree.get_topology_id(include_root=True))

0 106 45b007d5a28f67f66de8607ebf31c1ae
106 174 9e53caea9c486c94b2d0d08cf8adb418
174 478 fd47cfe2345fc56c3817a7738cc1984b
478 542 cb824fbf14ee4fb09cd3f92c016f00ff
542 845 0d1a3570d6a70ea7142c9bd5c871c31b
845 877 fd7403f1f1c6b9362991e619b347b8c9
877 1050 430764a4bcfe664cd155b1dad3e9ffd9
1050 1415 932a17faedb7be3ee1d8cfb6f94eeefd
1415 1920 4ecaac0935491c458df57f3d653e68c1
1920 2012 df831ca05898b07ea69a45d603995176


In [8]:
def get_trees_and_intervals_from_argweaver_file(smc_file: Path, topo: bool = False):
    """Return the trees and intervals for all recomb events in an SMC.gz file.
    
    Parameters
    ----------
    smc_file: Path or str
        A file path to a smc.gz file produced by ARGweaver.
    topo: bool
        If True then intervals are only returned for topo-changes, where the
        first tree is returned to represent the interval.

    Example
    -------
    >>> trees, intervals = get_trees_and_intervals_from_argweaver_file(ARG.0.smc.gz, topo=False)
    >>> print(trees)
    >>> # (<toytree.ToyTree at 0x7f50ff6f5090>, <toytree.ToyTree at 0x7f50fe8eee30>, ...
    >>> print(intervals)
    >>> # array([157., 285., 113., 285.])
    """
    igen = iter_argweaver_relabeled_trees(smc_file, topo=topo)
    starts, stops, trees = zip(*[i for i in igen])
    intervals = (np.array(stops) - np.array(starts)).astype(np.float64)
    return trees, intervals

In [9]:
# example: parse all intervals from one ARG
trees, intervals = get_trees_and_intervals_from_argweaver_file(ARGLIST[0], topo=False)

In [10]:
# example: parse all intervals from one ARG
trees, intervals = get_trees_and_intervals_from_argweaver_file(ARGLIST[0])
trees[:4], intervals[:4]

((<toytree.ToyTree at 0x7f7de0398610>,
  <toytree.ToyTree at 0x7f7de0398ca0>,
  <toytree.ToyTree at 0x7f7de0399330>,
  <toytree.ToyTree at 0x7f7de03999c0>),
 array([106.,  68., 304.,  64.]))

In [11]:
# show the first four trees in this ARG
toytree.mtree(trees[:10]).draw(shape=(2, 5), width=1000, ts='c', shared_axes=True, scale_bar=True, fixed_order=trees[0].get_tip_labels());

### Likelihood calculation of coalescent times
Calculate the likelihood of coalescent times among gene trees in an ARG, given the species tree model.

In [12]:
# this is the classic MSC calculation where each tree is weighted equally. Not what we want.
ipcoal.msc.get_msc_loglik(sptree2, trees, imap)

75316.06039218868

In [13]:
# this is a weighted MSC calculation where each tree is weighted by its proportion of the chromosome length.
ipcoal.msc.get_msc_loglik(sptree2, trees, imap, intervals)

74734.30905617788

### Likelihood calculation of waiting distances
Calculate the likelihood of the distances between recombination events given the species tree, genealogies, and recombination rate. This can be calculated in three ways. The simplest does not use the species tree, and simply calculates the likelihood of any recombination event occurring given the sum gene tree branch lengths and recombiation rate (event_type=0). The next type is to calculate the likelihood of a recombination event that causes a tree-change (topology or coal time change), given the species tree, gene tree, and recombination rate (event_type=1). The third is the likelihood of a topology-change, which is a subset of changes that cause a tree-change (event_type=2). We can calculate each of these separately, and we can examine them individually, or together. 


In [18]:
# likelihood of distances between ANY recomb event types
ipcoal.smc.get_ms_smc_loglik(sptree, trees, imap, 2e-8, intervals, event_type=0)

10586.263227383479

In [19]:
# likelihood of distances between tree-change event types
ipcoal.smc.get_ms_smc_loglik(sptree, trees, imap, 2e-8, intervals, event_type=1)

8694.494996632146

In [20]:
# likelihood of distances between tree-change event types
ipcoal.smc.get_ms_smc_loglik(sptree, trees, imap, 2e-8, intervals, event_type=2)

6219.896001700405

### Testing
Get a distribution of loglikelihoods for ...

In [36]:
data = np.zeros(shape=(3, len(ARGLIST)))

for aidx, argfile in enumerate(ARGLIST):
    # get data and embedding for trees
    trees, tree_dists = get_trees_and_intervals_from_argweaver_file(argfile)
    G = ipcoal.smc.TreeEmbedding(sptree, trees, imap, nproc=4)

    # get data and embedding for topos
    topos, topo_dists = get_trees_and_intervals_from_argweaver_file(argfile, topo=True)
    T = ipcoal.smc.TreeEmbedding(sptree, topos, imap, nproc=4)

    # get the gene tree likelihoods over all intervals weighted
    msc_logliks = ipcoal.msc.get_msc_loglik_from_embedding(G.emb, tree_dists)
    smc_logliks_tree = ipcoal.smc.get_ms_smc_loglik_from_embedding(G, 2e-8, tree_dists, event_type=1)
    smc_logliks_topo = ipcoal.smc.get_ms_smc_loglik_from_embedding(T, 2e-8, topo_dists, event_type=2)

    data[:, aidx] = msc_logliks, smc_logliks_tree, smc_logliks_topo
    if not aidx % 10:
        print(aidx)

0
10
20
30
40
50
60
70
80
90
100
110
120
130
140
150
160
170
180
190
200
210
220
230
240
250
260
270
280
290
300
310
320
330
340
350
360
370
380
390
400
410
420
430
440
450
460
470
480
490
500
510
520
530
540
550
560
570
580
590
600
610
620
630
640
650
660
670
680
690
700
710
720
730
740
750
760
770
780
790
800
810
820
830
840
850
860
870
880
890
900
910
920
930
940
950
960
970
980
990
1000
1010
1020
1030
1040
1050
1060
1070
1080
1090
1100
1110
1120
1130
1140
1150
1160
1170
1180
1190
1200
1210
1220
1230
1240
1250
1260
1270
1280
1290
1300
1310
1320
1330
1340
1350
1360
1370
1380
1390
1400
1410
1420
1430
1440
1450
1460
1470
1480
1490
1500
1510
1520
1530
1540
1550
1560
1570
1580
1590
1600
1610
1620
1630
1640
1650
1660
1670
1680
1690
1700
1710
1720
1730
1740
1750
1760
1770
1780
1790
1800
1810
1820
1830
1840
1850
1860
1870
1880
1890
1900
1910
1920
1930
1940
1950
1960
1970
1980
1990
2000
2010
2020
2030
2040
2050
2060
2070
2080
2090
2100
2110
2120
2130
2140
2150
2160
2170
2180
2190
2200
2210
2

In [37]:
np.savetxt("./ARG-loglik-results.csv", data, delimiter=",")

In [447]:
logliks0, logliks1, logliks2 = data

In [451]:
logliks0.mean()

58638.74500539014

In [441]:
stats.entropy(data[0])

9.209942831422431

In [442]:
stats.entropy(data[1])

9.210393151825498

In [443]:
stats.entropy(data[2])

9.210343480251167

In [444]:
stats.entropy(data[0] + data[1] + data[2])

9.210077650484655

In [445]:
stats.entropy(data[0], data[0] + data[1] + data[2])

1.1910675081544363e-05

## Question
- Is the top 10% of MSC scored ARGs associated with better per-site RF scores?
- Is the top 10% of MSC+SMC scores ARGS associated with better per-site RF scores than the top 10% of MSC ARGs?

In [513]:
avg_loglik_model1 = pk = data[0]#.mean()
avg_loglik_model2 = (data[0] + data[1] + data[2])#.mean()
qk = np.ones(data[0].size)

In [514]:
stats.entropy(avg_loglik_model1 * qk, avg_loglik_model2, )

1.1910675081544363e-05

In [515]:
stats.entropy(avg_loglik_model2 * qk, avg_loglik_model1, )

1.1911203517789198e-05

In [440]:
avg_loglik_model1, avg_loglik_model2

(58638.74500539014, 72079.8273484521)

In [424]:
stats.entropy(qk), stats.entropy(pk)

(9.210440366976519, 0.0)

In [60]:
stats.entropy(pk) + stats.entropy(qk, pk)

9.210089561159736

In [378]:
# H(T, q)
-np.log2(data[0]).mean()

-15.83884927777721

In [402]:
sum(data[0] * np.log(data[0] / (data[0] + data[1])))

-78644443.88530296

In [403]:
# H(T, q)
-np.log2(data[0] + data[1]).mean()

-16.032460514726775

In [380]:
# H(T, q)
-np.log2(data[0] + data[1] + data[2] + np.random.uniform(-1, 1, data[0].size)).mean()

-16.136785161828534

In [381]:
# H(T, q)
-np.log2(data[0] + data[1] + data[2]).mean()

-16.13678499428141

In [399]:
stats.entropy(data[0], data[0] + data[1] , base=2)

7.4536286355732786e-06

In [84]:
import toyplot
c, a, m = toyplot.bars(np.histogram(data[0], bins=50), width=500, height=250);
a.bars(np.histogram(data[0] + data[2], bins=50));
a.bars(np.histogram(data[0] + data[1] * 2 + data[2] * 2, bins=50));

In [21]:
ipcoal.smc.get_ms_smc_loglik_from_embedding(E, 2e-8, intervals, event_type=1)

nan

In [22]:
ipcoal.smc.get_ms_smc_loglik_from_embedding(E, 2e-8, intervals, event_type=1)

nan

In [23]:
# ...
ipcoal.smc.get_ms_smc_loglik(sptree, trees, imap, 2e-8, intervals, event_type=1)

nan

In [30]:
stats.entropy([3, 4, 5])

1.0775563270668007