# ARG inference with tsinfer (final version)

## Run this first
This will take a few minutes:

In [None]:
import sgkit
import tsinfer
import msprime
import tsdate
import numpy as np
import matplotlib.pyplot as plt

!pip install -q jupyterquiz
from jupyterquiz import display_quiz
from workshop_module import WB_base




## Overview

Many ARG inference methods are available that lie somewhere on a spectrum between inference accuracy and scalability. `tsinfer` is scalable to over 100 000 whole human genomes, does not require any demographic modeling assumptions and supports a variety of data formats. The biggest drawback is that it only produces a single, deterministic estimate of the ARG's topology, without a measure of uncertainty. However, branch lengths are probabilistically estimated by another package, `tsdate`, which has recently been given a major upgrade.

This workshop aims to introduce you to ARG inference using the Python packages `tsinfer` and `tsdate`. Working with a subset of the 1000 Genomes Project data in VCF format, we will show how to convert the data into a more useful format and run through the inference process. Finally, we will demonstrate the use of the `tsbrowse` tool for quality controlling an inferred ARG.

<div class="alert alert-block alert-info">
  <b>Note:</b> For practical reasons, the dataset we are working with is small enough to quickly analyse on an ordinary laptop. For large datasets, where parallel computation on a High Performance Computing Cluster is essential, we have developed a 
  <a href="https://github.com/benjeffery/tsinfer-snakemake/">Snakemake pipeline</a>
  to manage the entire inference process. Feel free to get in touch if you think this will be useful for you!
</div>

## Outline of the workshop

Since other workshops are covering applications, we will mostly focus on the practicalities of working with VCF data, inferring ARGs with `tsinfer` and QCing them:

1. Convert the VCF to Zarr format  
2. Create the ancestral allele array  
3. Inference with `tsinfer`  
   3.1. Loading the data  
   3.2. Generate ancestors  
   3.3. Match ancestors
   3.4. Match samples  
4. Dating the ARG with `tsdate`  
5. Quality control with `tsbrowse`


## 1. Converting the VCF to Zarr format

The VCF format has severe limitations as a genetic data format, especially for large datasets. Extracting data from a particular field is inefficient and it does not lend itself to distributed computing. The **VCF Zarr format**, described in [this preprint](https://www.biorxiv.org/content/10.1101/2024.06.11.598241v3), addresses these limitations, making it easy to slice the data and perform parallel computations efficiently.  

**Zarr** is a general format for storing multi-dimensional data, so there are many libraries available to work with it. We will focus on tools specialised for working with VCF Zarr data specifically. The Python package <code>bio2zarr</code> can convert VCF, plink and tskit ARGs into the format, while `sgkit` offers tools for analysing and manipulating the data.

As discussed in the [<code>bio2zarr</code> documentation](https://sgkit-dev.github.io/bio2zarr/vcf2zarr/tutorial.html#sec-vcf2zarr-tutorial), converting the VCF to Zarr is a two-step process for moderately-sized datasets. We will do it with the command line interface (CLI). First, we **explode** the VCF into the Intermediate Columnar Format (ICF). This separates out the fields (columns) of the VCF:

In [None]:
%%bash
mkdir -p data/zarr data/args; vcf2zarr explode -f data/vcf/tgp.vcf.gz data/zarr/tgp.icf

Now we can <code>inspect</code> the ICF we have created, which tells us about the fields in the input VCF:

In [None]:
%%bash
vcf2zarr inspect data/zarr/tgp.icf

Minimum and maximum values of numerical fields are shown, so we can see that this is a 1 Mbp region spanning positions 6 Mbp to 7 Mbp. At this stage, we choose how we want to encode the ICF to a Zarr by adjusting the **schema**, which allows us to remove fields and make other changes. We don't need to here, but [see the `vcf2zarr` tutorial](https://sgkit-dev.github.io/bio2zarr/vcf2zarr/tutorial.html) if you're interested. 

The final step is to <code>encode</code> the ICF into a Zarr. Since we are not making any changes to the schema, this is a simple command:

In [None]:
%%bash
vcf2zarr encode -f data/zarr/tgp.icf data/zarr/tgp.zarr

We can again use <code>inspect</code> to check our results:

In [None]:
%%bash
vcf2zarr inspect data/zarr/tgp.zarr

In [None]:

display_quiz(WB_base + "Q1.json")

Zarr arrays are automatically divided into chunks to aid parallel computation. The genotype matrix is encoded in a 3D array (`call_genotype`). `mask`s are used extensively to exclude certain data (e.g. a `sample_mask` is True for every sample that is excluded).
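To make the array layout concrete, here is a minimal sketch using a toy `numpy` array in place of the real `call_genotype`. The values and the mask are made up for illustration; only the axis order (variants, samples, ploidy) and the "True means excluded" convention come from the description above.

```python
import numpy as np

# Toy stand-in for `call_genotype`: 3 variants x 4 samples x ploidy 2.
call_genotype = np.array(
    [
        [[0, 0], [0, 1], [1, 1], [0, 0]],
        [[0, 1], [0, 0], [0, 0], [1, 0]],
        [[1, 1], [1, 1], [0, 1], [0, 0]],
    ],
    dtype=np.int8,
)

# Hypothetical mask: True marks a sample that is EXCLUDED.
sample_mask = np.array([False, False, True, False])

# Keep only the samples whose mask entry is False.
kept = call_genotype[:, ~sample_mask, :]
print(call_genotype.shape, "->", kept.shape)
```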


<div class="alert alert-block alert-info"><b>Note</b>: for small datasets, a convenience function (<code>convert</code>) is available to do the conversion in one step in Python or with the CLI. For very large datasets, functions are also available to distribute the encoding/decoding across multiple nodes.</div>

### 1.1. Exploring the VCF Zarr

At the moment, the easiest way to analyse the data in the VCF Zarr is with the `sgkit` package:

In [None]:
ds = sgkit.load_dataset("data/zarr/tgp.zarr")
ds

<div class="alert alert-block alert-success">
  <b>Exercise:</b> The above dataset view is interactive: click on "Data variables" to see all the stored arrays. To view one in detail, click the stack symbol on the far right.
</div>

The arrays have named dimensions like `variants` and `samples`, which make them easier to interpret. Large arrays are divided into chunks automatically, which makes it easier to do computations in parallel. The library `xarray` is used to store these arrays. 


You can fetch any of the arrays by name:

In [None]:
call_genotype = ds.call_genotype
call_genotype

`xarray` objects have names for their dimensions to improve readability. You can convert any `xarray` object into a simpler `numpy` array using the `.values` attribute:

In [None]:
G = call_genotype.values
G[0,0] # diploid genotype at first site of first sample

We can obtain an array of allele counts for each site by summing over all the samples and ploidy. Since we have 100 diploid individuals (200 haplotypes), there are up to 199 alternate alleles at each site (there are no fixed mutations):

In [None]:
# axis 0 = sites, axis 1 = samples, axis 2 = ploidy
ac = np.sum(ds.call_genotype.values, axis=(1, 2))
print(f'Maximum allele count is {ac.max()}')
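Summing genotype values gives the derived allele count only when every call is 0 or 1. The sketch below shows one way to count when some calls are missing, assuming missing calls are stored as negative values such as -1 (check this for your own data). The toy array is made up for illustration.

```python
import numpy as np

# Toy genotypes (sites x samples x ploidy); -1 is ASSUMED to mean "missing".
toy_genotypes = np.array(
    [
        [[0, 1], [1, 1], [0, 0]],
        [[0, 0], [-1, -1], [0, 1]],
    ],
    dtype=np.int8,
)

# Count derived (value 1) alleles and non-missing calls separately.
derived_count = np.sum(toy_genotypes == 1, axis=(1, 2))
called_count = np.sum(toy_genotypes >= 0, axis=(1, 2))

# Frequency among the haplotypes that were actually called.
derived_freq = derived_count / called_count
print(derived_count, called_count, derived_freq)
```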

Usually, the values 0 and 1 in the genotype matrix refer to the `REF` and `ALT` alleles in the VCF. In our case, I've made sure that the `REF` alleles are always ancestral, so `ALT` alleles are always derived. This means that we can plot the derived allele frequency spectrum using the allele counts above:

In [None]:
import numpy as np
import matplotlib.pyplot as plt

plt.figure(figsize=(8, 4))
plt.hist(ac, bins=50, color='orange', edgecolor='orange')
plt.title('Derived allele frequency spectrum')
plt.xlabel('Allele count')
plt.ylabel('Number of variants')
plt.yscale('log')
plt.grid(axis='y', alpha=0.6)
plt.tight_layout()

plt.show()

<div class="alert alert-block alert-success">
  <b>Exercise:</b> What is the biggest difference between this AFS and one that you'd expect from a classic Hudson coalescent model? Hint: Rare alleles are common.
</div>
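As a point of comparison for the exercise, under the standard neutral coalescent with constant population size the expected number of sites with derived allele count $i$ is proportional to $1/i$. The sketch below computes that expectation, scaled to a chosen number of segregating sites. The function name and the example numbers are illustrative, not part of the workshop code.

```python
import numpy as np

def expected_neutral_sfs(num_haplotypes, num_segregating_sites):
    """Expected site frequency spectrum under the standard neutral coalescent.

    Returns the derived allele counts 1..n-1 and the expected number of
    sites at each count, scaled to sum to `num_segregating_sites`.
    """
    counts = np.arange(1, num_haplotypes)
    weights = 1.0 / counts
    return counts, num_segregating_sites * weights / weights.sum()

# Hypothetical example: 200 haplotypes and 1000 segregating sites.
counts, expected = expected_neutral_sfs(200, 1000)
print(expected[:3])
```

Plotting `expected` against `counts` on the same log-scaled axes as the histogram makes the comparison easier, provided the histogram bins are one allele count wide.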

In [None]:
display_quiz(WB_base + "Q2.json")

## 2. Create the ancestral allele array

`tsinfer` requires that variants are **bi-allelic**, **phased** and have a **known ancestral state**. Many methods are available to estimate ancestral alleles. In the case of humans, ancestral FASTA sequences, calculated with a tool called Ortheus, are available online. The dataset we are working with has already been polarised with this method, such that the first (reference) allele is always ancestral.

To provide `tsinfer` with the ancestral allele array, we start with the `variant_allele` array in the Zarr:


In [None]:
variant_allele = ds.variant_allele.values
variant_allele

In this 2D array, the left column is `REF` and the right column is `ALT`. As mentioned previously, the `REF` alleles are ancestral so we just need to slice them out:

In [None]:
ancestral_state = variant_allele[:,0]  # select all the rows in first column
ancestral_state
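Before slicing out the first column, it can be worth checking which sites are actually bi-allelic. The sketch below uses a toy allele array and assumes that unused allele slots are padded with empty strings; verify this convention against your own Zarr before relying on it.

```python
import numpy as np

# Toy stand-in for `variant_allele`; "" is ASSUMED to pad unused allele slots.
toy_variant_allele = np.array(
    [
        ["A", "G", ""],   # bi-allelic
        ["C", "T", "G"],  # multi-allelic
        ["T", "", ""],    # no ALT allele
    ]
)

# Count the non-empty alleles at each site.
num_alleles = np.sum(toy_variant_allele != "", axis=1)
is_biallelic = num_alleles == 2

# First column is REF, which is ancestral in this workshop's dataset.
toy_ancestral_state = toy_variant_allele[:, 0]
print(is_biallelic, toy_ancestral_state)
```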

At this stage, we could store this array in the Zarr, but we generally recommend passing it directly to `tsinfer`.

## 3. Inference with `tsinfer`

### 3.1. Loading in the data

To start inference, we need to load the data into `tsinfer` via the `VariantData` class, which is just a wrapper of a zarr. To specify the ancestral states, we can either point it to a variable in the Zarr, or provide the array directly; we will do the latter.

In [None]:
variant_data = tsinfer.VariantData("data/zarr/tgp.zarr", ancestral_state=ancestral_state)
print(f'There are {variant_data.num_sites} sites in the zarr')

The underlying Zarr data remains accessible through `variant_data.data`. 

### 3.2. Generate ancestors

The first step of `tsinfer`, ancestor generation, uses the genotype data to estimate the ancestral haplotypes that the samples inherited from. For a given set of inference sites with the same allele frequency and genotypes, which we call **focal sites**, we create an ancestor that carries the derived mutation, extending to the left and right until a stopping criterion is reached.

The result of ancestor generation is an object called `AncestorData`, which contains all the inferred ancestral haplotypes:

In [None]:
anc_data = tsinfer.generate_ancestors(variant_data, progress_monitor=True)
print(f'Generated {anc_data.num_ancestors} ancestors from {variant_data.num_sites} sites')

For what follows, ancestors need to have a measure of relative age, so that we can match their haplotypes against each other. The default is to use **allele frequency as a proxy of relative age**, since we expect rare mutations to have occurred more recently than common ones. 
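The grouping and ordering ideas can be illustrated on a toy haplotype matrix. This is only a sketch of the two concepts described above (sites sharing a genotype pattern, and frequency as a proxy for age). It is not `tsinfer`'s implementation and it does not attempt the left and right extension of ancestral haplotypes.

```python
import numpy as np

# Toy haplotype matrix: rows are sites, columns are sample haplotypes
# (0 = ancestral, 1 = derived).
haplotypes = np.array(
    [
        [1, 1, 1, 0, 0, 0],  # site 0
        [1, 1, 0, 0, 0, 0],  # site 1
        [1, 1, 1, 0, 0, 0],  # site 2: same pattern as site 0
        [1, 1, 1, 1, 1, 0],  # site 3
    ]
)

# Group sites that share an identical genotype pattern.
groups = {}
for site, row in enumerate(haplotypes):
    groups.setdefault(tuple(row.tolist()), []).append(site)

# Order groups from highest to lowest derived allele frequency,
# i.e. from (presumed) oldest to youngest.
ordered = sorted(groups.items(), key=lambda item: -sum(item[0]))
for pattern, focal_sites in ordered:
    freq = sum(pattern) / len(pattern)
    print(f"frequency {freq:.2f}: focal sites {focal_sites}")
```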

We usually don't do any QC or analysis of the ancestor data at this stage except for development purposes.

<div class="alert alert-block alert-info"><b>Note</b>: Ancestor generation is usually much faster than the subsequent steps (ancestor matching and sample matching). The only potential issue is that it can be RAM intensive for large datasets, since the 1-bit encoded genotype matrix needs to fit in RAM. </div>
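A rough lower bound on the memory needed for a 1-bit encoded genotype matrix follows from simple arithmetic: one bit per site per haplotype. The function name and dataset size below are hypothetical, and real memory use will be higher because of overheads.

```python
def packed_genotype_bytes(num_sites, num_haplotypes):
    """Bytes needed to store one bit per site per haplotype."""
    return num_sites * num_haplotypes / 8

# Hypothetical dataset: one million sites, 100 000 diploid genomes.
gib = packed_genotype_bytes(1_000_000, 200_000) / 2**30
print(f"{gib:.1f} GiB")
```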

### 3.3. Match ancestors

We next need to match the ancestral haplotypes against each other to form an **ancestors tree sequence/ARG**. This ARG captures all the copying patterns between the ancestors, but does not include the sample haplotypes. `tsinfer` uses the [Li and Stephens (2003) Hidden Markov Model (HMM)](https://pubmed.ncbi.nlm.nih.gov/14704198/) for matching. It has two main inputs in our implementation: `recombination_rate` and `mismatch_ratio`. We will discuss these in detail in the next section, because we **strongly advise leaving them both at default for this step, as in the call below**. 

The `match_ancestors` function needs both the `VariantData` and `AncestorData` to run, so the call looks like this:


In [None]:
ancestors_arg = tsinfer.match_ancestors(variant_data, anc_data, progress_monitor=False)
ancestors_arg

Curiously, there are 3168 sample nodes, one for each ancestor. This is because we haven't added the actual samples yet. We usually don't QC the ancestor ARG directly, so we will proceed to the final stage.

### 3.4. Match samples


The final inference step is to match the sample haplotypes against the ancestors ARG, which will complete the ARG topology. To enable the use of the mismatch parameter discussed below, you must specify a **recombination rate**, which can be either:

1. A floating point recombination rate $\rho$ per unit length of genome
2. A recombination map stored as an `msprime.RateMap` object.

If you have a recombination rate map available, as is the case for humans, it is worth using. We have provided the HapMap for chromosome 20 to use with our data, so we need to use `msprime` to load it in:

In [None]:
hapmap_path = "data/hapmap/genetic_map_Hg38_chr20.txt"
rate_map = msprime.RateMap.read_hapmap(hapmap_path, position_col=1, rate_col=2)
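Genetic maps in this style usually report rates in cM/Mb, and 1 cM/Mb corresponds to $10^{-8}$ crossovers per base pair per generation. The sketch below works through that conversion on a made-up piecewise-constant map. The positions and rates are hypothetical and are not taken from the file above.

```python
import numpy as np

# Hypothetical map: interval boundaries (bp) and the rate (cM/Mb)
# within each interval. The middle interval mimics a hotspot.
positions = np.array([6_000_000, 6_400_000, 6_500_000, 7_000_000])
rate_cm_per_mb = np.array([1.0, 20.0, 0.5])

# 1 cM/Mb = 0.01 crossovers per 1e6 bp = 1e-8 per bp per generation.
rate_per_bp = rate_cm_per_mb * 1e-8

# Expected number of crossovers per generation within each interval.
expected_crossovers = rate_per_bp * np.diff(positions)
print(expected_crossovers, expected_crossovers.sum())
```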

Whenever a recombination rate is provided, you can also specify a `mismatch_ratio`, which determines the balance between recombination events and recurrent mutations. If a mismatch is encountered at a site between otherwise closely matching haplotypes, the HMM will:

- Almost always add a recombination event if `mismatch_ratio ~ 0` (the default, i.e. it is infinitesimal)
- Be equally likely to add a recombination event or a recurrent mutation if `mismatch_ratio = 1`
- Be twice as likely to add a recurrent mutation if `mismatch_ratio = 2`

We advise setting mismatch to zero in the first instance, which makes inference much faster and results in few or no recurrent mutations being added. If you do set it to a non-zero value, it is best to make it very small ($\leq 10^{-3}$), otherwise excessively many recurrent mutations will be added.
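The trade-off between recombination and mismatch can be seen in a toy Li and Stephens style Viterbi decoder. This is a minimal sketch, not `tsinfer`'s implementation: the function name, the uniform per-site recombination probability, and the mapping `mismatch_prob = mismatch_ratio * recomb_prob` are all simplifying assumptions made for illustration.

```python
import numpy as np

def viterbi_copy_path(panel, query, recomb_prob, mismatch_prob):
    """Most likely sequence of panel haplotypes copied by `query`."""
    panel = np.asarray(panel)
    query = np.asarray(query)
    num_haps, num_sites = panel.shape
    # Switching to any given haplotype vs. staying on the current one.
    log_switch = np.log(recomb_prob / num_haps)
    log_stay = np.log(1 - recomb_prob + recomb_prob / num_haps)

    def log_emission(site):
        match = panel[:, site] == query[site]
        return np.where(match, np.log(1 - mismatch_prob), np.log(mismatch_prob))

    score = -np.log(num_haps) + log_emission(0)
    back = np.zeros((num_sites, num_haps), dtype=int)
    for site in range(1, num_sites):
        best_prev = int(np.argmax(score))
        switch_score = score[best_prev] + log_switch
        stay_score = score + log_stay
        use_stay = stay_score >= switch_score
        back[site] = np.where(use_stay, np.arange(num_haps), best_prev)
        score = np.where(use_stay, stay_score, switch_score) + log_emission(site)

    # Trace back from the best final state.
    path = np.zeros(num_sites, dtype=int)
    path[-1] = int(np.argmax(score))
    for site in range(num_sites - 1, 0, -1):
        path[site - 1] = back[site, path[site]]
    return path

panel = [[0] * 7, [1] * 7]       # two toy reference haplotypes
query = [0, 0, 0, 1, 0, 0, 0]    # differs from haplotype 0 at one site
recomb_prob = 0.01               # hypothetical per-site value

low = viterbi_copy_path(panel, query, recomb_prob, 1e-6 * recomb_prob)
high = viterbi_copy_path(panel, query, recomb_prob, 1.0 * recomb_prob)
print("tiny mismatch ratio:", low)   # switches haplotype at the odd site
print("mismatch ratio of 1:", high)  # stays put and accepts a mismatch
```

With a tiny mismatch probability the decoder explains the odd site by switching to the other haplotype and back (two recombinations). With a larger one it stays on the same haplotype and accepts a single mismatch.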

In [None]:
arg = tsinfer.match_samples(
    variant_data,
    ancestors_arg,
    recombination_rate=rate_map,
    mismatch_ratio=0,
    progress_monitor=True,
)
arg

#### One recurrent mutation?
Notice that one site has more than one mutation (since `num_mutations` - `num_sites` = 1). What is up with this site? To find out, we can use the `sites()` iterator (the slow way):

In [None]:
for site in arg.sites():
    mutations = site.mutations
    if len(mutations) > 1: #recurrent
        for mut in mutations:
            print(mut)

It's a recurrent mutation! These can occasionally occur because we provided a recombination map, which enables mismatch, and the mismatch ratio is slightly above zero.
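A faster alternative to iterating over sites is to count mutations per site directly from the `site` column of the mutation table. The sketch below uses a toy array in place of that column, so the values are made up.

```python
import numpy as np

# Toy stand-in for the `site` column of a mutation table: one entry per
# mutation, giving the ID of the site at which it occurs.
mutation_site = np.array([0, 1, 2, 2, 3, 5])
num_sites = 6

# Number of mutations at each site, then the sites with more than one.
mutations_per_site = np.bincount(mutation_site, minlength=num_sites)
recurrent_sites = np.flatnonzero(mutations_per_site > 1)
print(recurrent_sites)
```

For the ARG inferred above, `arg.tables.mutations.site` and `arg.num_sites` would take the place of the toy values.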

#### Lots more sites vs. the ancestor trees?

By default, `match_samples` post-processes the final ARG to remove some artefacts of the algorithm and simplify without removing unary nodes. This step can be disabled, but for most users it is best to stick to the default behaviour.

Notice that the **number of sites has increased** after sample matching. This is because there are often sites in the Zarr which are not suitable for inference; `match_samples` adds them back to the ARG by parsimony by default, which can also be disabled. 

In [None]:
print(f'There are {arg.num_sites - ancestors_arg.num_sites} new sites added by match samples')

<div class="alert alert-block alert-success">
  <b>Exercise:</b> In this case, all the added sites share the same property. What is it? Hint: the previous challenge question might help.
</div>

In [None]:
display_quiz(WB_base + "Q3.json")

In [None]:
crazy_arg = None  # insert your code here


## 4. Dating the ARG with `tsdate`

The ARG we have made still has uncalibrated branch lengths: all the nodes have an age between 0 and ~1:

In [None]:
times = arg.tables.nodes.time
print(f'Node times are between {min(times)} and {max(times)} in {arg.time_units} time units')

The `tsdate` package uses an expectation-propagation (EP) algorithm to calibrate the ARG's branch lengths to units of generations. First, we need to **pre-process** the ARG, then provide `tsdate` with a **mutation rate** for your species. For humans, we usually use $1.29\times 10^{-8}$ per base pair per generation:

In [None]:
preproc_arg = tsdate.preprocess_ts(arg)
dated_arg = tsdate.date(preproc_arg, mutation_rate=1.29e-8, progress=True)
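The intuition behind mutation-based dating is a molecular clock: the expected number of mutations on an edge is the mutation rate times the genomic span times the branch length. The sketch below works through that arithmetic for a hypothetical edge. `tsdate` itself combines this kind of information across the whole ARG, so this is only the underlying intuition, not its algorithm.

```python
def expected_mutations(mutation_rate, span_bp, branch_length_generations):
    """Expected number of mutations on an edge under a constant clock."""
    return mutation_rate * span_bp * branch_length_generations

mu = 1.29e-8  # per base pair per generation

# Hypothetical edge: spans 10 kb and lasts 1000 generations.
expected = expected_mutations(mu, span_bp=10_000, branch_length_generations=1_000)
print(f"{expected:.3f}")

# Inverting the clock: a naive branch length implied by observing
# 2 mutations on an edge spanning 10 kb.
observed = 2
rough_length = observed / (mu * 10_000)
print(f"{rough_length:.0f} generations")
```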

Checking the node ages again,

In [None]:
times = dated_arg.tables.nodes.time

print(f'Node times are between {min(times)} and {max(times)} in {dated_arg.time_units} time units')
dated_arg

As we can see from the above, the time units are now calibrated to generations. **All the singletons have been dated too**! For what follows, let's save it to a file:

In [None]:
dated_arg.dump('data/args/dated.trees')

## 5. Quality control with `tsbrowse`

Since large ARGs are too complex to fully visualise as a graph or tree sequence, quality control can be difficult. In our experience, it is dangerous to solely rely on summary statistics and summary plots (e.g. of node ages), since many issues can be missed this way.

To help QC inferred ARGs, we developed a web app called **`tsbrowse`** (led by Savita Karthikeyan) to interactively visualise nodes, edges and various other data. The app works for all major ARG inference methods, provided they are in `tskit` format. 

Using `tsbrowse` is a two-step process. First, we `preprocess` the ARG into a compressed `tsbrowse` file:

In [None]:
%%bash

python -m tsbrowse preprocess data/args/dated.trees

Then we need to start the app server. The `--show` flag should open the app in your browser when it's ready. If that doesn't work, just copy-paste `http://localhost:1112/` into your browser.

In [None]:
%%bash

python -m tsbrowse serve --port 1112 --show data/args/dated.tsbrowse 


In [None]:
display_quiz(WB_base + "Q4.json")

### Things to look out for in `tsbrowse`

1. **Mutations**: Are there any regions with large gaps / low site density? These can bias downstream statistics.
2. **Mutations**: In log scale, do you generally see fewer mutations as age increases? Sometimes, ARG inference errors can cause an excessive number of very old mutations (>100 000 generations).
3. **Nodes**: The plot should generally be L-shaped, because older ancestors should tend to be shorter than younger ones.
4. **Edges**: Look for artefacts (such as regions with exceptionally long edges).
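The first check, large gaps between sites, can also be scripted as a complement to the visual inspection. The sketch below uses made-up positions and a hypothetical threshold. For a real ARG, the positions could be taken from the `position` column of the sites table.

```python
import numpy as np

# Toy site positions in bp, sorted in increasing order.
site_positions = np.array([6_000_100, 6_000_900, 6_001_500, 6_050_000, 6_050_400])
gap_threshold = 10_000  # hypothetical cut-off in bp

# Distance between each pair of neighbouring sites.
gaps = np.diff(site_positions)
large = np.flatnonzero(gaps > gap_threshold)
for i in large:
    print(f"gap of {gaps[i]} bp between {site_positions[i]} and {site_positions[i + 1]}")
```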


## A few exploratory plots

Once we are satisfied with the quality control results, we can start analysing our results. For example, let's calculate the branch-mode diversity along the genome:

In [None]:
# Windows have to go from 0 to sequence_length
windows = np.arange(6e6, 7e6, 1e4)
all_windows = np.concatenate([[0], windows, [dated_arg.sequence_length]])
pi_window = dated_arg.diversity(mode="branch", windows=all_windows)
# Plot diversity along the genome, one point per window midpoint
plt.figure(figsize=(8, 5))
plt.plot(
    (all_windows[:-1] + all_windows[1:]) / 2,
    pi_window,
    marker='o',
    linestyle='-',
    color='blue',
)
plt.title('Branch-mode diversity along genome')
plt.xlabel('Genomic Position (bp)')
plt.ylabel('Avg. branch length between samples')
plt.xlim(6e6, 7e6)
plt.grid()
plt.show()
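Window breakpoints passed to statistics such as `diversity` must start at 0, end at the sequence length, and be strictly increasing. The helper below is a sketch of one way to build such an array without accidentally duplicating the end points. The name `make_windows` is illustrative.

```python
import numpy as np

def make_windows(start, stop, step, sequence_length):
    """Breakpoints covering [0, sequence_length] with regular inner windows."""
    inner = np.arange(start, stop, step, dtype=float)
    # Drop breakpoints that would duplicate the required end points.
    inner = inner[(inner > 0) & (inner < sequence_length)]
    return np.concatenate([[0.0], inner, [float(sequence_length)]])

# Hypothetical sequence length of exactly 7 Mbp.
windows = make_windows(6e6, 7e6, 1e4, sequence_length=7e6)
print(len(windows), windows[:3], windows[-2:])
```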

`tsdate` now automatically estimates mutation ages and adds the results to the mutations table. Let's see what the range of mutation ages is:

In [None]:
x = dated_arg.tables.mutations.time

plt.figure(figsize=(8, 5))
plt.hist(x, bins=50, color='green', edgecolor='black')
plt.title('Histogram of mutation ages')
plt.xlabel('Mutation age (generations)')
plt.ylabel('Number of mutations')
plt.yscale('log')
plt.show()

### A note about age estimates

`tsdate` now uses a Bayesian algorithm to estimate the age of nodes (and hence mutations). Conveniently, the posterior distribution of each age estimate is approximated by a gamma distribution, and its mean and variance are stored in the metadata as `mn` and `vr`:

In [None]:
mutation = dated_arg.mutation(50)
print(f'Mutation metadata: {mutation.metadata}')
node = dated_arg.node(mutation.node)
print(f'Node metadata: {node.metadata}')
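Given a mean and variance, the parameters of the matching gamma distribution follow by moment matching: shape = mean² / variance and scale = variance / mean. The sketch below applies this to hypothetical metadata values and approximates a 95% interval by sampling. The numbers are made up, not taken from the ARG above.

```python
import numpy as np

def gamma_from_moments(mean, variance):
    """Shape and scale of the gamma distribution with these moments."""
    shape = mean**2 / variance
    scale = variance / mean
    return shape, scale

# Hypothetical metadata values, in generations.
metadata = {"mn": 1200.0, "vr": 90000.0}
shape, scale = gamma_from_moments(metadata["mn"], metadata["vr"])

# Approximate a 95% interval by Monte Carlo sampling.
rng = np.random.default_rng(1)
draws = rng.gamma(shape, scale, size=100_000)
lower, upper = np.percentile(draws, [2.5, 97.5])
print(f"shape={shape:.1f}, scale={scale:.1f}, 95% interval ~ ({lower:.0f}, {upper:.0f})")
```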