## Setup

To access material for this workbook please execute the two notebook cells immediately below (e.g. use the shortcut <b>&lt;shift&gt;+&lt;return&gt;</b>). The first cell can be skipped if you are running this notebook locally and have already installed all the necessary packages. The second cell should print out "Your notebook is ready to go!"

In [None]:
if 'pyodide_kernel' in str(get_ipython()):  # JupyterLite is not supported for this workbook
    raise RuntimeError("This workbook is not designed to run in JupyterLite. Please use a Colab or local install")
elif 'google.colab' in str(get_ipython()):  # specify package location for loading in Colab
    from google.colab import drive
    drive.mount('/content/drive')
    %run /content/drive/MyDrive/GARG_workshop/Notebooks/add_module_path.py
else:  # install packages on your local machine (-q = "quiet": don't print out installation steps)
    !python -m pip install -q -r https://github.com/ebp-nor/GARG/raw/main/jlite/requirements.txt

In [None]:
# Load questions etc for this workbook
from IPython.display import SVG
import tskit
import ARG_workshop
workbook = ARG_workshop.Workbook2C()
display(workbook.setup)

### Using this workbook

This workbook is intended to be used by executing each cell as you go along. Code cells (like those above) can be modified and re-executed to change their behaviour or carry out additional analysis. You can use this to complete various programming exercises, some of which have associated questions to test your understanding. Exercises are marked like this:
<dl class="exercise"><dt>Exercise XXX</dt>
<dd>Here is an exercise: normally there will be a code cell below this box for you to work in</dd>
</dl>

# Workbook 2-C: ARG inference

Although ARG inference is a difficult problem, a number of software programs developed in recent years do a reasonable job of it. Currently, we recommend either [_SINGER_](https://github.com/popgenmethods/SINGER) for smaller datasets, or [_tsinfer_](https://github.com/tskit-dev/tsinfer) (coupled with [_tsdate_](https://github.com/tskit-dev/tsdate)) for larger datasets. Both of these approaches output tree sequences. You might also like to consider [Relate](https://myersgroup.github.io/relate/), which has been used extensively in analysis of recent human history (and which can convert output, albeit inefficiently, to _tskit_ format). These three software programs have the same basic requirements (phased data and knowledge of the ancestral state) but are based on very different approaches.

* ___SINGER___ takes a statistical sample of likely ARGs using MCMC sampling. That means it is pretty computationally intensive, so will only work on sample sizes of tens or hundreds of genomes. As _SINGER_ takes a sampling approach, it outputs multiple possible ARGs, which are sampled using the coalescent with recombination as a probabilistic model. Testing suggests that this provides the most accurate results.
* ___Tsinfer___, in contrast, creates a single "best guess" ARG using a heuristic approach based on ancestor reconstruction. It does not attempt to estimate times of nodes in the ARG, and unlike _SINGER_, it is therefore not dependent on coalescent assumptions. For many purposes, node times are important, and a separate algorithm, _tsdate_, can be used to infer node times under the molecular clock. We will be using extremely recent versions of the two pieces of software, such that their official documentation may not yet be up to date. The major drawback of the _tsinfer_+_tsdate_ approach is that rather than outputting multiple ARGs, it expresses uncertainty by using *polytomies* (nodes with more than two children). Many analysis approaches are not designed for use on trees with polytomies.
* ___Relate___ also uses a heuristic approach, which essentially involves building separate trees along the genome and then combining them together (but this results in very inefficient tree sequences, with low amounts of correlation between adjacent trees). It also outputs a single ARG, but polytomies are fully resolved, even in the absence of strong evidence for how to resolve them. The accuracy of results is on a par with the latest _tsinfer_+_tsdate_.
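
To make the idea of a polytomy concrete, here is a minimal sketch in plain Python. It assumes a tree stored as a hypothetical `parent` dictionary (child ID to parent ID), which is not how _tskit_ stores trees; the function name `find_polytomies` is likewise made up for illustration.

```python
from collections import Counter

def find_polytomies(parent):
    """Return the sorted IDs of nodes that have more than two children.

    `parent` maps each non-root node ID to the ID of its parent.
    """
    num_children = Counter(parent.values())
    return sorted(node for node, n in num_children.items() if n > 2)

# A hypothetical tree: node 5 has three children (0, 1 and 2), so it is a polytomy;
# node 6 (the root) has two children (3 and 5), so it is an ordinary binary split.
toy_parent = {0: 5, 1: 5, 2: 5, 3: 6, 5: 6}
polytomies = find_polytomies(toy_parent)
print(polytomies)
```

On a real `tskit.Tree`, a similar check could probably be written with `tree.num_children(u)` for each node `u`.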

In this workbook, we will focus on using the fastest and most efficient method: _tsinfer_ + _tsdate_. Further, advanced workbooks will introduce _SINGER_. To illustrate, we will use the same simulation as we previously encountered, and inspect the quality of the inference by comparing the dates of the nodes below each mutation.

In [None]:
import tskit
simulated_ts = tskit.load("data/chimp_selection.trees")
print(f"Loaded data for {simulated_ts.num_samples} genomes over {simulated_ts.sequence_length} bp")

For efficiency reasons, _tsinfer_ reads data in the [VCF Zarr](https://github.com/sgkit-dev/vcf-zarr-spec/) format, which is designed to allow rapid processing of large amounts of genetic variant data. It should soon be possible to [convert tree sequences directly to this format](https://github.com/sgkit-dev/bio2zarr/issues/232), but for the moment, this needs to be done the slow way, by outputting to VCF and converting to a `.vcz` file using [vcf2zarr](https://sgkit-dev.github.io/bio2zarr/vcf2zarr/overview.html#sec-vcf2zarr) on the command-line. For ease of use in the workbook, the `ARG_workshop.ts2vcz` function will do the hard work for you. Remember that the data need to be *phased*.

In [None]:
# create the .vcz file
ARG_workshop.ts2vcz(simulated_ts, "PanTro-chr3-sim.vcz")
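
If you start from your own VCF rather than a simulation, it can be worth checking the phasing before conversion. The sketch below relies only on the VCF convention that `|` separates phased alleles and `/` separates unphased ones; the helper names and the example genotype strings are hypothetical.

```python
def is_phased(genotype):
    """True if a VCF GT string such as '0|1' contains no unphased separator.

    Haploid calls such as '1' have no separator and are treated as phased.
    """
    return "/" not in genotype

def unphased_calls(gt_fields):
    """Return the GT strings in a row that are not phased."""
    return [gt for gt in gt_fields if not is_phased(gt)]

# Hypothetical GT values for four samples at one site
example_row = ["0|1", "1|1", "0/1", "0|0"]
bad = unphased_calls(example_row)
print(bad)
```

Any site or sample that turns up in such a check would need to be phased (or masked) before inference.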


Details of how to use the new (currently alpha) version of `tsinfer` are at [https://tskit.dev/tsinfer/docs/latest/usage.html](https://tskit.dev/tsinfer/docs/latest/usage.html). Basically, we need to wrap the `.vcz` file in a `tsinfer.VariantData` object. This allows us to specify the ancestral allele, and mask out problematic sites or samples. In this simulated data we don't need any masking. That is likely to change with real data!

In [None]:
import tsinfer
import numpy as np

# We must obtain the ancestral allele from somewhere. Here we use that given in the simulation
# See https://github.com/tskit-dev/tsinfer/discussions/523 for discussion for real data
ancestral_allele=np.array([s.ancestral_state for s in simulated_ts.sites()])

# The VariantData interface is now the preferred way to create tsinfer input files (the old one was called SampleData)
vdata = tsinfer.VariantData("PanTro-chr3-sim.vcz", ancestral_allele=ancestral_allele)


### The 3 steps of tsinfer

The principle behind _tsinfer_ should have been introduced in a talk. The inference algorithm is split into 3 steps:
1. Generating ancestors (ga): [`tsinfer.generate_ancestors()`](https://tskit.dev/tsinfer/docs/latest/api.html#tsinfer.generate_ancestors)
2. Matching ancestors (ma): [`tsinfer.match_ancestors()`](https://tskit.dev/tsinfer/docs/latest/api.html#tsinfer.match_ancestors) - this is usually the slowest step, and hardest to parallelize.
3. Matching samples (ms): [`tsinfer.match_samples()`](https://tskit.dev/tsinfer/docs/latest/api.html#tsinfer.match_samples)

We can run all three in one go by calling `tsinfer.infer()`, but running them separately is more flexible and allows you to save intermediate results, which is useful for large inferences. You can use the `progress_monitor` argument to see how long the steps will take, and the `num_threads` argument to use more CPU cores on your computer.


In [None]:
general_params = {
    "num_threads": 6,  # Set to the number of cores on your computer.
    "progress_monitor": True,
}
# On my macbook air with num_threads=6, this takes about 2 and a half minutes for this dataset
# (34.3 thousand constructed ancestors)
ancestors = tsinfer.generate_ancestors(vdata, **general_params)
ancestors_ts = tsinfer.match_ancestors(vdata, ancestors, **general_params)
ts = tsinfer.match_samples(vdata, ancestors_ts, **general_params)

In [None]:
# Execute code block with <shift>+Return to display question; type and press return, or click on the buttons to answer
workbook.question("num_trees")

The inferred tree sequence should encode exactly the same genotype data as the original (apart from imputation, see below). However, the genealogy is very likely to be different (after all, we are vanishingly unlikely to be able to infer the true ARG). There are four main points to note:

* The inferred tree sequence can contain polytomies
* The inferred tree sequence will have the same number of sites as the (masked) input, but can have multiple mutations at a site (this can be adjusted using a "mismatch" parameter, described later)
* The inferred tree sequence has no genealogical information before the first site and after the last site: there are empty trees in those regions (these empty regions can be removed using `ts.trim()`, although that will shift the site positions leftwards)
* Missing data in the original dataset is automatically *imputed* in the inferred tree sequence


In [None]:
print(f"simulation diversity = {simulated_ts.diversity():.6f}, diversity in inferred ts = {ts.diversity():.6f}")
print(f"simulation num_sites = {simulated_ts.num_sites}, num_sites in inferred ts = {ts.num_sites}")
print(f"simulation num_mutations = {simulated_ts.num_mutations}, num_mutations in inferred ts = {ts.num_mutations}")

<dl class="exercise"><dt>Exercise 1</dt>
<dd>Try obtaining a statistic (e.g. <code>diversity</code>) on the inferred <code>ts</code>. Then try it again using the <em>branch length</em> diversity.</dd>
</dl>

In [None]:
# Exercise: try both site and branch mode statistics

In [None]:
# Execute code block with <shift>+Return to display question; type and press return, or click on the buttons to answer
workbook.question("branchVsite")

However, we can run topological analysis. Here we repeat the GNN and topology-counting we did with the true genealogy, but using the inferred tree sequence.

### GNN

The true genealogy had a boring GNN plot (each population completely separate), and we might hope the signal was so strong that we would see exactly the same in the inferred topologies:

In [None]:
import pandas as pd
import seaborn as sns

sample_sets = {pop.metadata["name"]: ts.samples(population=pop.id) for pop in ts.populations()}
sample_sets = {k: sample_sets[k] for k in ["bonobo", "central", "western"]}  # make sure bonobo first, central next, western last
gnn = ts.genealogical_nearest_neighbours(ts.samples(), sample_sets=list(sample_sets.values()))
df = pd.DataFrame(gnn, columns=list(sample_sets.keys()))
df["focal_population"] = [ts.population(ts.node(u).population).metadata["name"] for u in ts.samples()]
mean_gnn = df.groupby("focal_population").mean()
sns.clustermap(mean_gnn, col_cluster=False, z_score=0, cmap="mako", cbar_pos=(1.0, 0.05, 0.05, 0.7));
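
To see what the GNN statistic measures, here is a simplified single-tree sketch in plain Python. It assumes a hypothetical `parent` dictionary and population labels, and it ignores the averaging over trees (weighted by span) that `genealogical_nearest_neighbours` performs along the genome.

```python
from collections import Counter

def single_tree_gnn(parent, sample_pop, focal):
    """Proportion of each population among the focal sample's nearest neighbours.

    parent: dict mapping child node ID -> parent node ID (one tree)
    sample_pop: dict mapping sample node ID -> population name
    focal: the sample node whose neighbours we want
    """
    children = {}
    for child, par in parent.items():
        children.setdefault(par, []).append(child)

    def samples_below(node):
        # Collect all sample nodes in the subtree rooted at `node`
        stack, found = [node], []
        while stack:
            u = stack.pop()
            if u in sample_pop:
                found.append(u)
            stack.extend(children.get(u, []))
        return found

    node = focal
    while node in parent:
        node = parent[node]  # climb until the clade contains another sample
        neighbours = [s for s in samples_below(node) if s != focal]
        if neighbours:
            counts = Counter(sample_pop[s] for s in neighbours)
            return {pop: n / len(neighbours) for pop, n in counts.items()}
    return {}

# Hypothetical tree (((0,1)4,2)5,3)6 with made-up population labels
toy_parent = {0: 4, 1: 4, 4: 5, 2: 5, 5: 6, 3: 6}
toy_pops = {0: "western", 1: "western", 2: "central", 3: "bonobo"}
gnn_western = single_tree_gnn(toy_parent, toy_pops, 0)
gnn_bonobo = single_tree_gnn(toy_parent, toy_pops, 3)
print(gnn_western, gnn_bonobo)
```

In this toy tree the western sample's only nearest neighbour is the other western sample, whereas the bonobo sample only joins the rest at the root, so its neighbours are a mixture of the other populations.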

### Topology counting

The results from counting embedded topologies were a little more subtle. Here we'll also repeat the topology counting code for the simulated data, so it's easier to compare with the inferred tree sequence:

In [None]:
from tqdm.auto import tqdm
true_topology_totals = {tree.rank(): {"counts": 0, "spans": 0} for tree in tskit.all_trees(3)}
inferred_topology_totals = {tree.rank(): {"counts": 0, "spans": 0} for tree in tskit.all_trees(3)}
for topology_totals, tree_seq in [(true_topology_totals, simulated_ts), (inferred_topology_totals, ts)]:
    tree_seq = tree_seq.trim()  # remove empty regions before and after
    sample_sets = {pop.metadata["name"]: tree_seq.samples(pop.id) for pop in tree_seq.populations()}
    for topology_counter, tree in tqdm(zip(
        tree_seq.count_topologies(sample_sets.values()),
        tree_seq.trees()
    ), total=tree_seq.num_trees):
        embedded_topologies = topology_counter[0, 1, 2]
        weight = tree.span / embedded_topologies.total()
        for rank, count in embedded_topologies.items():
            topology_totals[rank]["counts"] += count
            topology_totals[rank]["spans"] += count * weight
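
The span weighting used in the loop above can be sketched on toy data without any tree sequence. Here each tree is represented by a hypothetical `(span, Counter)` pair, and the topology "ranks" are just the made-up labels `"A"` and `"B"`.

```python
from collections import Counter

def accumulate(trees):
    """Sum raw counts and span-weighted counts of embedded topologies.

    trees: iterable of (span, Counter mapping topology rank -> count)
    Each tree's span is shared among topologies in proportion to their counts.
    """
    totals = {}
    for span, counts in trees:
        n = sum(counts.values())
        for rank, count in counts.items():
            entry = totals.setdefault(rank, {"counts": 0, "spans": 0.0})
            entry["counts"] += count
            entry["spans"] += span * count / n
    return totals

toy_trees = [
    (100.0, Counter({"A": 3, "B": 1})),  # 75 bp credited to A, 25 bp to B
    (50.0, Counter({"A": 2})),           # all 50 bp credited to A
]
totals = accumulate(toy_trees)
print(totals)
```

Because each tree's span is divided up completely, the `spans` values sum to the total length covered by the trees, which is why they can be reported as a percentage of the genome.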

We'll plot it out using some HTML to make it look nice. If you aren't familiar with raw HTML, don't worry about understanding the cell below: it's simply for formatting purposes.

In [None]:
from IPython.display import HTML
td = '<td style="text-align: center">'
for name, tree_seq, topology_totals in [
    ("True", simulated_ts, true_topology_totals),
    ("Inferred", ts, inferred_topology_totals),
]:
    names = {pop.id: pop.metadata["name"] for pop in tree_seq.populations() if len(tree_seq.samples(pop.id)) > 0}
    L = tree_seq.sequence_length
    display(HTML(
        f"<h3>{name} data</h3>" +
        "<table><tr><td>" +
        "</td><td>".join([
            tskit.Tree.unrank(num_leaves=3, rank=rank).draw_svg(node_labels=names)
            for rank in topology_totals.keys()
        ]) +
        "</td></tr>" + 
        ('<tr>' + td) +
        ('</td>' + td).join([
            f"counts: {v['counts']}<br>spans: {v['spans']:.2f}<br>span %: {v['spans']/L * 100:.2f}"
            for v in topology_totals.values()
        ]) + 
        "</td></tr></table>"
    ))


The inferred data is slightly noisier than the true data, but it still very strongly supports grouping Western and Central chimps together (98% vs 99% of the genome). Although we don't have a measure for how much difference is "significant", it implies that the inference is doing a reasonable job of clustering, as we expect.

### Other topological measures

Another approach to comparing original and inferred tree sequences is to use tree distance metrics to compare each local tree. The [ts.coiterate(...)](https://tskit.dev/tskit/docs/stable/python-api.html#tskit.TreeSequence.coiterate) method, which jointly moves left-to-right along two comparable tree sequences, can help here. However, it is unclear what the best tree metrics to use are. For this reason, we will instead carry out further analysis using dates of nodes and mutations, for which we will need to date the inferred tree sequence:

## <em>Tsdate</em>: an HMM on a graph

You should have been introduced to the concept behind <em>tsdate</em> in a talk. As the [docs](https://tskit.dev/tsdate/docs/latest/methods.html#the-variational-gamma-method) say:
<blockquote>The directed graph that represents the genealogy can (in its undirected form) contain cycles, so a technique called “expectation propagation” (EP) is used, in which local estimates to each gamma distribution are iteratively refined until they converge to a stable solution. This comes under a class of approaches sometimes known as “loopy belief propagation”.</blockquote>
Running <em>tsdate</em> requires a mutation rate to be specified. Handily, the mutation rate used in the simulation is stored in the <a href="https://tskit.dev/tskit/docs/stable/provenance.html">provenance</a> table of the simulated tree sequence. Running the actual algorithm is fast (although importing tsdate for the first time can take a minute or two). There's no need to parallelise across threads.

Note that the output of _tsinfer_ is not fully simplified (it has some nodes with non-coalescent segments). We have found that these can cause problems with dating, so we run `tsdate.preprocess_ts` on the inferred tree sequence before dating, which simplifies it: see the <a href="https://tskit.dev/tsdate/docs/latest/python-api.html#preprocessing-tree-sequences">docs</a>.

In [None]:
import tsdate
import json
last_provenance = json.loads(simulated_ts.provenance(-1).record)
assert last_provenance["parameters"]["command"] == "sim_mutations"
mutation_rate = last_provenance["parameters"]["rate"]
undated_ts = tsdate.preprocess_ts(ts)
dated_ts = tsdate.date(undated_ts, mutation_rate=mutation_rate, progress=True)
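
The cell above assumes that the last provenance record is the `sim_mutations` step. If later operations have been recorded, one possible alternative is to search backwards through the records, as in this sketch. The JSON records and the rate used here are made up for illustration, and `find_mutation_rate` is a hypothetical helper.

```python
import json

def find_mutation_rate(provenance_records):
    """Return the rate from the most recent 'sim_mutations' provenance record.

    provenance_records: list of JSON strings, oldest first.
    """
    for record in reversed(provenance_records):
        params = json.loads(record).get("parameters", {})
        if params.get("command") == "sim_mutations":
            return params["rate"]
    raise ValueError("No sim_mutations provenance record found")

# Hypothetical records: the mutation step is followed by another operation
toy_records = [
    json.dumps({"parameters": {"command": "sim_ancestry"}}),
    json.dumps({"parameters": {"command": "sim_mutations", "rate": 1.5e-8}}),
    json.dumps({"parameters": {"command": "simplify"}}),
]
rate = find_mutation_rate(toy_records)
print(rate)
```

With a real tree sequence, the list of records could be built with `[p.record for p in simulated_ts.provenances()]`.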

We can now plot the dates of nodes, or mutations. Of course, the inferred tree sequence will not have exactly the same nodes as the simulated one, so to find equivalent nodes, we simply compare nodes under the mutation at equivalent sites (if there are multiple mutations at a single site, we just take the oldest).

In [None]:
from matplotlib import pyplot as plt
from matplotlib.ticker import FuncFormatter

def node_time_below_oldest_muts(input_ts):
    # As there are slightly different numbers of mutations, we can't simply compare mutation times
    # Instead, we compare the oldest mutation at each site, and take the time of the node below that
    return np.array([
        input_ts.nodes_time[s.mutations[0].node] if len(s.mutations) else np.nan
        for s in input_ts.sites()
    ])

def plot_log_times(orig_ts, new_ts):
    orig_time_oldest_mut_node = node_time_below_oldest_muts(orig_ts)
    new_time_oldest_mut_node = node_time_below_oldest_muts(new_ts)
    use = np.logical_and(orig_time_oldest_mut_node > 0, new_time_oldest_mut_node > 0)
    x = np.log10(orig_time_oldest_mut_node[use])
    y = np.log10(new_time_oldest_mut_node[use])
    plt.hexbin(x, y, bins='log')
    plt.axline((0, 0), slope=1, color="red")
    plt.xlabel("True node time (generations)")
    plt.ylabel("Inferred node time (generations)")
    plt.text(0, 5, f"r2 = {np.corrcoef(x, y)[0,1] ** 2:.5f}")
    # set log ticks
    plt.gca().yaxis.set_major_formatter(FuncFormatter(lambda x,y: f'{10**x:.0f}'))
    plt.gca().xaxis.set_major_formatter(FuncFormatter(lambda x,y: f'{10**x:.0f}'))

plot_log_times(simulated_ts, dated_ts)
plt.title("First round of inference");
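
When reading the r2 value on this plot, bear in mind that it measures how linear the relationship is on the log scale, not how close the points are to the red y = x line. The sketch below, using made-up times and a hypothetical helper `log_time_r2`, repeats the masking and correlation steps in plain Python and shows that a consistent two-fold overestimate still gives r2 = 1.

```python
import math

def log_time_r2(true_times, inferred_times):
    """Squared Pearson correlation of log10 times, using only pairs where both are > 0."""
    pairs = [
        (math.log10(t), math.log10(i))
        for t, i in zip(true_times, inferred_times)
        if t > 0 and i > 0  # comparisons with NaN are False, so NaNs are dropped too
    ]
    xs, ys = zip(*pairs)
    mean_x = sum(xs) / len(xs)
    mean_y = sum(ys) / len(ys)
    cov = sum((x - mean_x) * (y - mean_y) for x, y in pairs)
    var_x = sum((x - mean_x) ** 2 for x in xs)
    var_y = sum((y - mean_y) ** 2 for y in ys)
    return cov ** 2 / (var_x * var_y)

# Made-up times: every usable inferred time is exactly twice the true time
true_t = [10.0, 100.0, 1000.0, float("nan"), 0.0]
inferred_t = [20.0, 200.0, 2000.0, 50.0, 5.0]
r2 = log_time_r2(true_t, inferred_t)
print(f"r2 = {r2:.5f}")
```

A high r2 is therefore best read together with the position of the points relative to the diagonal.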

It appears as if there is banding in the true times, presumably caused by the bottlenecking, as we saw in the true (simulated) tree sequence. Let's see how well we can spot the demographic and selection patterns in plots:

In [None]:
ARG_workshop.edge_plot(dated_ts, width=15, plot_hist=True, alpha=0.1)

And separated by population:

In [None]:
fig, axes = plt.subplots(3, 2, gridspec_kw={"width_ratios": [8, 1], "hspace": 0.3}, figsize=(15, 10), sharey=True)
for ax_row, pop in zip(axes, dated_ts.populations()):
    xaxis = (pop.id==dated_ts.population(-1).id)
    ARG_workshop.edge_plot(dated_ts.simplify(dated_ts.samples(population=pop.id)), ax=ax_row, xaxis=xaxis, title=pop.metadata["name"], alpha=0.5)

How about pairwise coalescence plots? First, averaged over the whole genome:

In [None]:
def pair_coalescence_rates(input_ts, sample_sets=None, time_breaks=None, window_breaks=None):
    # NB: in the next tskit release (0.5.9), there will be an API change such that
    # this function will be directly available as `ts.pair_coalescence_rates(time_breaks)`
    if sample_sets is not None:
        sample_sets = [list(s) for s in sample_sets]  # work around small bug in implementation of coalescence_time_distribution
    d = input_ts.coalescence_time_distribution(
        sample_sets=sample_sets,
        window_breaks=window_breaks,
        weight_func="pair_coalescence_events",
    )
    return d.coalescence_rate_in_intervals(np.array(time_breaks))

time_windows = np.logspace(0, np.log10(dated_ts.max_time), 30)
rates = pair_coalescence_rates(dated_ts, time_breaks=time_windows)
fig, axes = plt.subplots(1, 2, figsize=(15, 4))
# This might complain if any rate is 0: that can be ignored
for ax, ylabel, y in zip(axes, ("Instantaneous Coalescence Rate (ICR)", "Inverse ICR (IICR)"), (rates, 1/rates)):
    ax.stairs(y.flatten(), time_windows, baseline=None)
    ax.set_xscale("log")
    ax.set_xlabel(f"Time ago ({dated_ts.time_units})")
    ax.set_ylabel(ylabel)
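
The inverse rate plotted above is often interpreted in terms of effective population size. As a rough sketch, assuming a diploid, randomly mating population with time measured in generations, the pairwise coalescence rate is 1/(2Ne), so Ne is half the inverse rate. The rates below are made up, and `effective_size_from_pair_rates` is a hypothetical helper.

```python
import math

def effective_size_from_pair_rates(rates):
    """Convert pairwise coalescence rates to Ne = 1 / (2 * rate).

    Assumes a diploid population and rates per generation; intervals with a
    zero or missing rate are returned as NaN rather than raising an error.
    """
    return [1 / (2 * r) if r > 0 else math.nan for r in rates]

toy_rates = [5e-5, 1e-4, 0.0]  # hypothetical rates for three time intervals
ne = effective_size_from_pair_rates(toy_rates)
print(ne)
```

Returning NaN for empty intervals avoids the division-by-zero warning mentioned in the cell above, and leaves gaps in a plot rather than infinite values.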

And now with a "local" (along the genome) plot.

In [None]:
def plot_pair_rates(input_ts, genomic_windows, num_log_timebins, sample_sets=None, indexes=None, axes=None):
    # indexes is a list of tuple pairs, e.g. [(0, 1), (1, 2)]
    time_breaks = np.logspace(0, np.log10(input_ts.max_time), num_log_timebins)
    rates = pair_coalescence_rates(input_ts, sample_sets, window_breaks=genomic_windows, time_breaks=time_breaks)
    if sample_sets is None:
        sample_sets = [input_ts.samples()]
    order = [(a, b) for a in range(len(sample_sets)) for b in range(a, len(sample_sets))]
    if indexes is None:
        indexes = np.arange(len(order))
    else:
        indexes = [order.index(i) for i in indexes]
    if axes is None:
        fig, axes = plt.subplots(len(indexes), 1, figsize=(12.5, 3 * len(indexes)))
    num_axes = 1
    try:
        num_axes = len(axes)
    except TypeError:
        axes = [axes]
    if num_axes != len(indexes):
        raise ValueError("Must have same number of axes as indexes")
    for ax, rate in zip(axes, (rates[i] for i in indexes)):
        im = ax.pcolormesh(genomic_windows, time_breaks, rate)
        ax.set_yscale("log")
        bar = plt.colorbar(im, ax=ax)
        bar.ax.set_ylabel('pairwise coalescent density', labelpad=10, rotation=270)
        ax.set_ylabel(f"Time ({input_ts.time_units})");

genomic_windows = np.linspace(0, dated_ts.sequence_length, 30)
plot_pair_rates(dated_ts.simplify(), genomic_windows, num_log_timebins=20)
plt.xlabel("Genome position")
plt.ylabel(f"Time ({dated_ts.time_units})");

<dl class="exercise"><dt>Exercise 2</dt>
<dd>Plot the cross-coalescence plots between pairs of populations, using the code in workbook 2A. Remember to use <code>dated_ts</code>, not <code>ts</code>.</dd>
</dl>

In [None]:
# Exercise: have a look at the cross-coalescence plots
ts = None  # remove references to the old "ts" variable: you should be using dated_ts

## Extension

In our experience, sometimes we can improve the _tsinfer_ step of inference by using the dates inferred from _tsdate_, rather than using frequency as a proxy for local time order. To do this, we can use the `sites_time` parameter when wrapping the `.vcz` file in a `VariantData` object:

In [None]:
# If we want to change anything (e.g. use different times), we can simply make a new VariantData object (this is efficient and instantaneous)
vdata = tsinfer.VariantData("PanTro-chr3-sim.vcz", ancestral_allele=ancestral_allele, sites_time=node_time_below_oldest_muts(dated_ts))

Note that re-running the inference in the next cell could take a long time (e.g. an hour).

In [None]:
# We'll demo the use of the `tsinfer.infer()` command, which rolls all 3 steps of
# inference into one.
reinferred_ts = tsinfer.infer(vdata, num_threads=4, progress_monitor=True)

In [None]:
undated_ts = tsdate.preprocess_ts(reinferred_ts)
redated_ts = tsdate.date(undated_ts, mutation_rate=mutation_rate, progress=True)

In [None]:
plot_log_times(simulated_ts, redated_ts)
plt.title("After one round of redating");

Plotting this new redated tree sequence is left as an exercise for the reader!