## Setup

To access material for this workbook please execute the two notebook cells immediately below (e.g. use the shortcut <b>&lt;shift&gt;+&lt;return&gt;</b>). The first cell can be skipped if you are running this notebook locally and have already installed all the necessary packages. The second cell should print out "Your notebook is ready to go!"

In [None]:
if 'pyodide_kernel' in str(get_ipython()):  # specify packages to install under JupyterLite
    raise RuntimeError("This workbook is not designed to run in JupyterLite. Please use a Colab or local install")
elif 'google.colab' in str(get_ipython()):  # specify package location for loading in Colab
    from google.colab import drive
    drive.mount('/content/drive')
    %run /content/drive/MyDrive/GARG_workshop/Notebooks/add_module_path.py
else:  # install packages on your local machine (-q = "quiet": don't print out installation steps)
    !python -m pip install -q -r https://github.com/ebp-nor/GARG/raw/main/jlite/requirements.txt

In [None]:
# Load questions etc for this workbook
from IPython.display import SVG
import tskit
import ARG_workshop
workbook = ARG_workshop.Workbook1D()
display(workbook.setup)

### Using this workbook

This workbook is intended to be used by executing each cell as you go along. Code cells (like those above) can be modified and re-executed to change their behaviour or carry out additional analysis. You can use this to complete various programming exercises, some of which have associated questions to test your understanding. Exercises are marked like this:
<dl class="exercise"><dt>Exercise XXX</dt>
<dd>Here is an exercise: normally there will be a code cell below this box for you to work in</dd>
</dl>

# Workbook 2-C: ARG inference

Although ARG inference is a difficult problem, a number of software programs developed in recent years do a reasonable job of it. Currently, we recommend either [_SINGER_](https://github.com/popgenmethods/SINGER) for smaller datasets or [_tsinfer_](https://github.com/tskit-dev/tsinfer) (coupled with [_tsdate_](https://github.com/tskit-dev/tsdate)) for larger datasets. Both of these approaches output tree sequences. You might also like to consider [_Relate_](https://myersgroup.github.io/relate/), which has been used extensively in analysis of recent human history (and which can convert output, albeit inefficiently, to _tskit_ format). These three software programs have the same basic requirements, phased data and knowledge of the ancestral state, but are based on very different approaches.

* ___SINGER___ takes a statistical sample of likely ARGs using MCMC sampling. This is computationally intensive, so it will only work on sample sizes of tens or hundreds of genomes. As _SINGER_ takes a sampling approach, it outputs multiple possible ARGs, which are sampled using the coalescent with recombination as a probabilistic model. Testing suggests that this provides the most accurate results.
* ___Tsinfer___, in contrast, creates a single "best guess" ARG using a heuristic approach based on ancestor reconstruction. It does not attempt to estimate times of nodes in the ARG and, unlike _SINGER_, it is therefore not dependent on coalescent assumptions. For many purposes, node times are important, and a separate algorithm, _tsdate_, can be used to infer node times under the molecular clock. We will be using extremely recent versions of the two pieces of software, so their official documentation may not yet be up to date. The major drawback of the _tsinfer_+_tsdate_ approach is that rather than outputting multiple ARGs, it expresses uncertainty by using *polytomies* (nodes with more than two children). Many analysis approaches are not designed for use on trees with polytomies.
* ___Relate___ also uses a heuristic approach, which essentially involves building separate trees along the genome and then combining them together (but this results in very inefficient tree sequences, with low amounts of correlation between adjacent trees). It also outputs a single ARG, but polytomies are fully resolved, even in the absence of strong evidence for how to resolve them. The accuracy of results is on a par with the latest _tsinfer_+_tsdate_.
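
To make the notion of a polytomy concrete, here is a minimal sketch in plain Python (no _tskit_ required). The tree is represented as a hypothetical child-to-parent dictionary, and the helper name `find_polytomies` is made up for this illustration; it is not part of any library.

```python
from collections import Counter


def find_polytomies(parent):
    """Return the sorted list of nodes that have more than two children.

    `parent` maps each non-root node ID to the ID of its parent node.
    """
    num_children = Counter(parent.values())
    return sorted(node for node, n in num_children.items() if n > 2)


# A fully resolved (binary) tree on samples 0..3: ((0,1)4,(2,3)5)6
binary_tree = {0: 4, 1: 4, 2: 5, 3: 5, 4: 6, 5: 6}

# The same samples, but the relationship between 1, 2 and 3 is unresolved:
# node 4 has three children, so it is a polytomy
unresolved_tree = {1: 4, 2: 4, 3: 4, 0: 5, 4: 5}

print(find_polytomies(binary_tree))
print(find_polytomies(unresolved_tree))
```

The first tree has no polytomies, whereas in the second tree node 4 is reported because the data did not allow its three children to be split into nested pairs.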

In this workbook, we will focus on using the fastest and most efficient method: _tsinfer_ + _tsdate_. Further, advanced workbooks will introduce _SINGER_. To illustrate, we will use the same simulation as we previously encountered, and inspect the quality of the inference by comparing the dates of the nodes below each mutation.

In [None]:
import tskit
simulated_ts = tskit.load("data/chimp_selection.trees")

For efficiency reasons, _tsinfer_ reads data in the [VCF Zarr](https://github.com/sgkit-dev/vcf-zarr-spec/) format, which is designed to allow rapid processing of large amounts of genetic variant data. It should soon be possible to [convert tree sequences directly to this format](https://github.com/sgkit-dev/bio2zarr/issues/232), but for the moment, this needs to be done the slow way, by outputting to VCF and converting to a `.vcz` file using [vcf2zarr](https://sgkit-dev.github.io/bio2zarr/vcf2zarr/overview.html#sec-vcf2zarr) on the command line. For ease of use in the workbook, the `ARG_workshop.ts2vcz` function will do the hard work for you. Remember that the data need to be *phased*.

In [None]:
# create the .vcz file
ARG_workshop.ts2vcz(simulated_ts, "PanTro-chr3-sim.vcz")

Details of how to use the new (currently alpha) version of `tsinfer` are at [https://tskit.dev/tsinfer/docs/latest/usage.html](https://tskit.dev/tsinfer/docs/latest/usage.html). In short, we need to wrap the `.vcz` file in a `tsinfer.VariantData` object. This allows us to specify the ancestral allele, and mask out problematic sites or samples.

In [None]:
import tsinfer
import numpy as np

# We must obtain the ancestral allele from somewhere. Here we use that given in the simulation
# See https://github.com/tskit-dev/tsinfer/discussions/523 for more details
ancestral_allele=np.array([s.ancestral_state for s in simulated_ts.sites()])

# The VariantData interface is now the preferred way to create tsinfer input files (the old one was called SampleData)
vdata = tsinfer.VariantData("PanTro-chr3-sim.vcz", ancestral_allele=ancestral_allele)


### The 3 steps of tsinfer

_Tsinfer_ inference is split into 3 steps:
1. Generating ancestors (ga): [`tsinfer.generate_ancestors()`](https://tskit.dev/tsinfer/docs/latest/api.html#tsinfer.generate_ancestors)
2. Matching ancestors (ma): [`tsinfer.match_ancestors()`](https://tskit.dev/tsinfer/docs/latest/api.html#tsinfer.match_ancestors)
3. Matching samples (ms): [`tsinfer.match_samples()`](https://tskit.dev/tsinfer/docs/latest/api.html#tsinfer.match_samples)

We can run all three in one go by calling `tsinfer.infer()`, but running them separately is more flexible and allows you to save intermediate results, which is useful for large inferences. You can use the `progress_monitor` argument to see how long the steps will take, and the `num_threads` argument to use more CPU cores on your computer.


In [None]:
general_params = {
    "num_threads": 6,  # Set to the number of cores on your computer, e.g. a MacBook Air M2 has 8
    "progress_monitor": True,
}
ancestors = tsinfer.generate_ancestors(vdata, **general_params)
ancestors_ts = tsinfer.match_ancestors(vdata, ancestors, **general_params)
ts = tsinfer.match_samples(vdata, ancestors_ts, **general_params)

The inferred tree sequence should encode exactly the same genotype data as the original (although the genealogy may be different). However, since the topology is different, it may require a different number of mutations at any one site to encode the data.

Note: this isn't strictly true, as the original dataset may contain missing data, which is imputed by tsinfer
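
To see why topology affects the mutation count, here is a small self-contained sketch that uses Fitch parsimony to count the minimum number of state changes needed to explain a single site on two different binary trees. It is an illustration only: the tree encoding and the `fitch_count` helper are hypothetical, and _tsinfer_ does not necessarily place mutations this way.

```python
def fitch_count(children, root, leaf_state):
    """Minimum number of state changes on a binary tree (Fitch parsimony).

    `children` maps each internal node to a list of its two child nodes;
    `leaf_state` maps each leaf node to its observed allele.
    """
    changes = 0

    def state_set(node):
        nonlocal changes
        if node in leaf_state:
            return {leaf_state[node]}
        child_sets = [state_set(c) for c in children[node]]
        common = set.intersection(*child_sets)
        if common:
            return common
        # The children share no possible state, so one change is needed here
        changes += 1
        return set.union(*child_sets)

    state_set(root)
    return changes


# Observed alleles at one site for samples 0..3
alleles = {0: "A", 1: "T", 2: "A", 3: "T"}

# Tree A groups the samples that share an allele: ((0,2)4,(1,3)5)6
tree_a = {4: [0, 2], 5: [1, 3], 6: [4, 5]}
# Tree B groups samples with different alleles: ((0,1)4,(2,3)5)6
tree_b = {4: [0, 1], 5: [2, 3], 6: [4, 5]}

print(fitch_count(tree_a, 6, alleles))
print(fitch_count(tree_b, 6, alleles))
```

Both trees encode identical genotypes, but tree A needs a single mutation while tree B needs two, which is why the mutation counts of the simulated and inferred tree sequences can differ.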

In [None]:
print(f"simulation diversity = {simulated_ts.diversity():.6f}, diversity in inferred ts = {ts.diversity():.6f}")
print(f"simulation num_sites = {simulated_ts.num_sites}, num_sites in inferred ts = {ts.num_sites}")
print(f"simulation num_mutations = {simulated_ts.num_mutations}, num_mutations in inferred ts = {ts.num_mutations}")


In [None]:
# Exercise: Since the tsinferred tree sequence does not have meaningful node times, we cannot run branch-length statistics

## Tsdate: an HMM on a graph

Dating is fast (although importing tsdate for the first time can take a minute or two). There's no need to parallelise across threads.

TODO: briefly explain how tsdate works

In [None]:
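import tsdate
undated_ts = tsdate.preprocess_ts(ts)
# NB: `model` is not defined in this workbook. It is assumed to be the simulation model
# object (with a `mutation_rate` attribute) from the workbook that produced the simulation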
dated_ts = tsdate.date(undated_ts, mutation_rate=model.mutation_rate, progress=True)
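
As rough intuition for dating under a molecular clock, the expected number of mutations on an edge is the mutation rate multiplied by the edge's genomic span and its length in generations. Inverting this gives a crude estimate of edge length. The sketch below is back-of-envelope arithmetic only, not the algorithm that _tsdate_ implements, and all function names and numbers are made up for illustration.

```python
def expected_mutations(mutation_rate, span_bp, branch_generations):
    """Expected mutation count on an edge under a strict molecular clock."""
    return mutation_rate * span_bp * branch_generations


def naive_branch_length(num_mutations, mutation_rate, span_bp):
    """Crude edge length (in generations) implied by an observed mutation count."""
    return num_mutations / (mutation_rate * span_bp)


mu = 1e-8      # hypothetical mutation rate per base pair per generation
span = 50_000  # hypothetical genomic span of the edge, in base pairs

# An edge 2000 generations long is expected to carry about one mutation
print(expected_mutations(mu, span, 2_000))
# Conversely, seeing 3 mutations suggests an edge of roughly 6000 generations
print(naive_branch_length(3, mu, span))
```

With so few mutations per edge, estimates from individual edges are very noisy, which is one reason to combine information across the whole graph rather than dating edges one at a time.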

We can now plot ...

In [None]:
from matplotlib import pyplot as plt
from matplotlib.ticker import FuncFormatter

def node_time_below_oldest_muts(ts):
    # As there are slightly different numbers of mutations, we can't simply compare mutation times
    # Instead, we compare the oldest mutation at each site, and take the time of the node below that
    return np.array(
        [ts.nodes_time[s.mutations[0].node] if len(s.mutations) else np.nan for s in ts.sites()]
    )

def plot_log_times(orig_ts, new_ts):
    orig_time_oldest_mut_node = node_time_below_oldest_muts(orig_ts)
    new_time_oldest_mut_node = node_time_below_oldest_muts(new_ts)
    use = np.logical_and(orig_time_oldest_mut_node > 0, new_time_oldest_mut_node > 0)
    x = np.log10(orig_time_oldest_mut_node[use])
    y = np.log10(new_time_oldest_mut_node[use])
    plt.hexbin(x, y, bins='log')
    plt.axline((0, 0), slope=1, color="red")
    plt.xlabel("True node time (generations)")
    plt.ylabel("Inferred node time (generations)")
    plt.text(0, 5, f"r2 = {np.corrcoef(x, y)[0,1] ** 2:.5f}")
    # set log ticks
    plt.gca().yaxis.set_major_formatter(FuncFormatter(lambda x,y: f'{10**x:.0f}'))
    plt.gca().xaxis.set_major_formatter(FuncFormatter(lambda x,y: f'{10**x:.0f}'))

# Compare against the simulation, which holds the true node times
plot_log_times(simulated_ts, dated_ts)
plt.title("First round of inference");
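
The plotting function compares times on a log10 scale after discarding sites where either time is zero or missing. The same masking and r² calculation can be sketched on a toy pair of arrays (the values here are made up):

```python
import numpy as np

true_times = np.array([10.0, 100.0, 1000.0, 0.0, 5000.0, np.nan])
inferred_times = np.array([12.0, 80.0, 1500.0, 30.0, 0.0, 200.0])

# Comparisons involving NaN evaluate to False, so missing values are masked
# out along with zero times (whose log would be undefined)
use = np.logical_and(true_times > 0, inferred_times > 0)
x = np.log10(true_times[use])
y = np.log10(inferred_times[use])

# Squared Pearson correlation between the log-transformed times
r2 = np.corrcoef(x, y)[0, 1] ** 2
print(f"{use.sum()} usable sites, r2 = {r2:.3f}")
```

Only the first three sites survive the mask. Working on the log scale means that a twofold error counts equally whether the node is 10 or 10,000 generations old.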

In [None]:
# If we want to change anything (e.g. use different times), we can simply make a new VariantData object (this is efficient and instantaneous)
vdata = tsinfer.VariantData("PanTro-chr3-sim.vcz", ancestral_allele=ancestral_allele, sites_time=node_time_below_oldest_muts(dated_ts))

In [None]:
inferred_ts = tsinfer.infer(vdata, num_threads=4, progress_monitor=True)

In [None]:
undated_ts = tsdate.preprocess_ts(inferred_ts)
redated_ts = tsdate.date(undated_ts, mutation_rate=model.mutation_rate, progress=True)

In [None]:
# Compare against the simulation, which holds the true node times
plot_log_times(simulated_ts, redated_ts)
plt.title("After one round of redating");

In [None]:
# With just the topology, you can look at the GNN, or count the types of topology:
ts.genealogical_nearest_neighbours

In [None]:
# Edge plots


In [None]:
dated_ts.node(100)

## Testing inference accuracy

There isn't an obvious way to 


In [None]:
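# 20 evenly spaced breakpoints define 19 windows along the genome
windows = np.linspace(0, simulated_ts.sequence_length, 20)
# The "true" values come from the simulation, not from the inferred tree sequence
plt.stairs(simulated_ts.diversity(mode="branch", windows=windows), windows, label="true branch diversity")
# Dividing site diversity by the mutation rate puts it on the same scale as branch diversity
plt.stairs(redated_ts.diversity(mode="site", windows=windows) / model.mutation_rate, windows, label="tsinfer + tsdate site diversity / mutation rate")
plt.stairs(redated_ts.diversity(mode="branch", windows=windows), windows, label="tsinfer + tsdate branch diversity")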
plt.yscale("log")
plt.legend()
plt.show();
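
The division by the mutation rate in the plot above relies on the fact that, under an infinite-sites model, expected site diversity equals the mutation rate times branch diversity. Here is a minimal sketch with a made-up three-sample tree; the `branch_diversity` helper and all numbers are hypothetical.

```python
from itertools import combinations

# Time (in generations) to the most recent common ancestor of each sample pair
tmrca = {("A", "B"): 100, ("A", "C"): 400, ("B", "C"): 400}


def branch_diversity(samples, tmrca):
    """Mean branch length separating a pair of samples (twice their TMRCA)."""
    pairs = list(combinations(samples, 2))
    return sum(2 * tmrca[pair] for pair in pairs) / len(pairs)


mu = 1e-8  # hypothetical mutation rate per base pair per generation
branch_div = branch_diversity(["A", "B", "C"], tmrca)
expected_site_div = mu * branch_div

print(branch_div)         # mean pairwise distance, in generations
print(expected_site_div)  # expected pairwise differences per base pair
```

If the inferred node times are accurate, the rescaled site diversity and the branch diversity of the dated tree sequence should therefore track each other, and both should track the true branch diversity.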