In [None]:
import msprime, tskit
import tsinfer, tsdate
from IPython.display import SVG
import cyvcf2
import json

# 5. Tree sequence inference with tsinfer and tsdate

 - [5.1 An overview of tsinfer](#5.1Overview)
 - [5.2 Hands on with tsinfer](#5.2HandsOn)
 - [5.3 Inference accuracy](#5.3InferenceAccuracy)

Simulating a tree sequence is relatively simple compared to *inferring* a tree sequence from existing data.
The [tsinfer software](https://tsinfer.readthedocs.io/en/stable/) implements a heuristic algorithm which does this in a scalable manner.

<a id='5.1Overview'></a>
## 5.1 An overview of `tsinfer`

`Tsinfer` (pronounced t-s-infer) is comparable in some ways to other ancestral inference software such as [ARGweaver](https://doi.org/10.1371/journal.pgen.1004342), [Relate](https://myersgroup.github.io/relate/), and [Rent+](https://doi.org/10.1093/bioinformatics/btw735). However, it differs considerably in approach and scalability.
Note that none of these other software packages produce tree sequences as output, although it is possible to convert their output to tree sequences.
Also note that `tsinfer` produces trees with a relatively accurate topology, but unlike other ancestral inference tools, it makes no attempt at the moment to produce precise branch length estimates -- for this we need another tool like `tsdate`.

An important restriction is that `tsinfer` requires phased sample sequences with known ancestral states for each variant. It also works better with full sequence data than with data from scattered target SNPs (e.g. as obtained from SNP chips).

###  Algorithm overview
The `tsinfer` method is split into two main parts: 
1. the reconstruction and time ordering of ancestral haplotypes and 
2. the inference of the copying process. 

The paper contains the following schematic overview of the method, with part 1 on the left and part 2 on the right. Note the reduced length of the blue inferred ancestor chunks back in time. 

<img style="height: 600px" src="pics/tsinfer-schematic.png">

<a id='5.2HandsOn'></a>
## 5.2 Hands-on with `tsinfer` and `tsdate`

Let's try out `tsinfer` using some sequence data generated with `msprime`.
Here's a small simulated sample drawn from an admixture scenario:

In [None]:
# Specify demographic history.
demography = msprime.Demography()
demography.add_population(name="SMALL", initial_size=2000)
demography.add_population(name="BIG", initial_size=5000)
demography.add_population(name="ADMIX", initial_size=2000)
demography.add_population(name="ANC", initial_size=5000)
demography.add_admixture(
    time=100, derived="ADMIX", ancestral=["SMALL", "BIG"], proportions=[0.5, 0.5])
demography.add_population_split(time=1000, derived=["SMALL", "BIG"], ancestral="ANC")
# demography.debug()

In [None]:
# Simulate.
seq_length = 1e6
ts = msprime.sim_ancestry(
    samples={"SMALL": 1, "BIG": 1, "ADMIX": 1},
    demography=demography,
    random_seed=83,
    sequence_length=seq_length,
    recombination_rate=1e-8,
)
ts = msprime.sim_mutations(ts, rate=1e-8, random_seed=318)
ts

In [5]:
# Write to VCF.
with open("worksheet5-input.vcf", "w") as vcf_file:
    ts.write_vcf(vcf_file)

### Step 1: pre-process data

`tsinfer` requires a `SampleData` object as input.
To create this, we'll need:

 - phased genotype data at sites with known positions
 - information about ancestral and derived alleles at each site
 
First, let's see how we might get this information from a VCF file:

In [None]:
# Print the first few lines of the VCF (header lines plus the first records).
number_of_lines = 10
with open("worksheet5-input.vcf") as vcf_file:
    for _ in range(number_of_lines):
        print(vcf_file.readline(), end="")

We'll need the information in the `POS`, `REF` and `ALT` fields, as well as the genotypes.
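
To make the field layout concrete, here is a minimal sketch that pulls these fields out of a single VCF data line using only the standard library. The record is made up for illustration (it is not taken from `worksheet5-input.vcf`), and `parse_vcf_line` is a hypothetical helper, not part of `tsinfer` or `cyvcf2`. It assumes phased, bi-allelic, diploid genotypes with no extra FORMAT fields.

```python
def parse_vcf_line(line):
    """Return (position, alleles, genotypes) for one phased, bi-allelic VCF record."""
    fields = line.rstrip("\n").split("\t")
    pos = int(fields[1])              # POS column
    alleles = [fields[3], fields[4]]  # REF and ALT columns
    genotypes = []
    for call in fields[9:]:           # one column per individual, after FORMAT
        if "|" not in call:
            raise ValueError(f"Unphased genotype {call!r} at position {pos}")
        genotypes.extend(int(allele) for allele in call.split("|"))
    return pos, alleles, genotypes


# Hypothetical record: three diploid individuals with phased genotypes.
example = "1\t5512\t.\tA\tT\t.\tPASS\t.\tGT\t0|1\t1|1\t0|0\n"
pos, alleles, genotypes = parse_vcf_line(example)
print(pos, alleles, genotypes)
```

Each diploid individual contributes two entries to the flattened genotype list, which is the per-haplotype layout that the next cell passes to `add_site`.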

In [7]:
with tsinfer.SampleData(sequence_length=seq_length) as sample_data:
    with open("worksheet5-input.vcf") as vcf_file:
        for line in vcf_file:
            if line.startswith("#"):
                continue  # Skip header lines.
            fields = line.rstrip("\n").split("\t")
            pos = int(fields[1])
            ref, alt = fields[3], fields[4]
            # Flatten the phased genotypes, e.g. "0|1", "1|1" -> [0, 1, 1, 1].
            genotypes = [int(g) for call in fields[9:] for g in call.split("|")]
            # The first allele listed (here REF) is treated as the ancestral state.
            sample_data.add_site(pos, genotypes, [ref, alt])

For larger VCFs, you may wish to use the [cyvcf2](https://github.com/brentp/cyvcf2) package.
See [this](https://tsinfer.readthedocs.io/en/latest/tutorial.html#reading-a-vcf) for some example usage.

If your data are already stored as a tree sequence (as with the larger Pongo dataset), you can use the `tsinfer.SampleData.from_tree_sequence` method to create a `SampleData` object directly.

### Step 2: Apply tsinfer!

All we need is our `SampleData` object and an estimated recombination rate:

In [None]:
tsi = tsinfer.infer(sample_data, recombination_rate=1e-7)
tsi

Let's have a look at some of the inferred trees. How do they compare with the real ones?

In [None]:
location=50000
SVG(tsi.at(location).draw_svg(size=(500,500)))

In [None]:
SVG(ts.at(location).draw_svg(size=(600,350)))

Some quick observations:
 - Various inaccuracies in topologies
 - Some *polytomies:* nodes with more than two children 
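
To make the idea of a polytomy concrete, here is a small sketch that finds nodes with more than two children in a tree stored as a child-to-parent mapping. The tree and the helper name `find_polytomies` are hypothetical and chosen only for illustration; this is not how `tsinfer` or `tskit` represent trees internally.

```python
from collections import Counter


def find_polytomies(parent):
    """Return the nodes that have more than two children.

    `parent` maps each non-root node to its parent node.
    """
    num_children = Counter(parent.values())
    return sorted(node for node, count in num_children.items() if count > 2)


# Hypothetical tree: node 5 has three children (0, 1, 2); node 6 has two (3, 5).
parent = {0: 5, 1: 5, 2: 5, 3: 6, 5: 6}
print(find_polytomies(parent))
```

On a real `tskit` tree, a similar check could be written by looping over `tree.nodes()` and testing `tree.num_children(u) > 2`.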
 
`tsinfer` also works on larger datasets.

To obtain estimates of node times, we will need to use `tsdate`,
a method for efficiently inferring the ages of ancestors in a tree sequence.
See the documentation page [here](https://tsdate.readthedocs.io/en/latest/).

### Step 3: simplify the tree sequence

First, we'll apply `simplify()`:

In [None]:
tsi = tsi.simplify()
SVG(tsi.at(location).draw_svg())

### Step 4: Apply `tsdate`!

We supply an estimated (haploid) effective population size, and a mutation rate.

In [None]:
tsid = tsdate.date(tsi, Ne=7000, mutation_rate=1e-8)
SVG(tsid.at(location).draw_svg())

In [None]:
SVG(ts.at(location).draw_svg())

With inferred node times and branch lengths in our tree sequence,
we can now apply any of the branch or time-related methods in the previous worksheet to obtain inferred branch statistics, IBD segments and so on.

<a id='5.3InferenceAccuracy'></a>
## 5.3 Inference accuracy

Inferring genome-wide genealogies is a challenging task, and (as we have seen) the output from `tsinfer` should be treated with some caution.

There are not many established ways to compare one tree sequence (or ARG) with another. However, thanks to phylogenetics, there *are* many ways to compare individual trees (i.e. tree distance metrics). The most discriminating that we have found is the Kendall-Colijn metric, which also has the benefit of dealing in a principled way with the *polytomies* found in `tsinfer` trees.
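
As a rough illustration of the idea behind the Kendall-Colijn metric, the sketch below computes its topology-only form (the case usually written as $\lambda = 0$) for two tiny trees. For every pair of leaves it records the number of edges between the root and the pair's most recent common ancestor, appends a 1 for each leaf's pendant edge, and takes the Euclidean distance between the two resulting vectors. The trees, node numbering and function names are hypothetical; in practice you would use the `kc_distance` methods that `tskit` provides rather than this hand-rolled version.

```python
from itertools import combinations
from math import sqrt


def depth_from_root(parent, node):
    """Number of edges between `node` and the root."""
    depth = 0
    while node in parent:
        node = parent[node]
        depth += 1
    return depth


def mrca(parent, a, b):
    """Most recent common ancestor of nodes `a` and `b`."""
    ancestors = {a}
    while a in parent:
        a = parent[a]
        ancestors.add(a)
    while b not in ancestors:
        b = parent[b]
    return b


def kc_topology_vector(parent, leaves):
    """Root-to-MRCA edge counts for every pair of leaves, then a 1 per leaf."""
    pair_depths = [
        depth_from_root(parent, mrca(parent, a, b))
        for a, b in combinations(sorted(leaves), 2)
    ]
    return pair_depths + [1] * len(leaves)


def kc_topology_distance(parent1, parent2, leaves):
    """Euclidean distance between the two topology vectors."""
    v1 = kc_topology_vector(parent1, leaves)
    v2 = kc_topology_vector(parent2, leaves)
    return sqrt(sum((x - y) ** 2 for x, y in zip(v1, v2)))


# Two hypothetical trees on leaves 0, 1, 2 (node 4 is the root in both).
tree_a = {0: 3, 1: 3, 3: 4, 2: 4}  # ((0, 1), 2)
tree_b = {0: 3, 2: 3, 3: 4, 1: 4}  # ((0, 2), 1)
distance = kc_topology_distance(tree_a, tree_b, [0, 1, 2])
print(distance)
```

Because the vectors are built from root-to-MRCA depths rather than from bifurcating splits, the same calculation goes through unchanged when a node has more than two children.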

<img style="height: 600px" src="pics/worksheet5-distances.png">

Consider what parts of the inferred tree sequence are likely to be important in your downstream analyses.
For instance, do branch lengths and ancestor times matter for you, or will tree topologies suffice?
Do you need your ancestral segments to be contiguous,
or is it okay if they are split over multiple ancestors in multiple edges?
Questions like these should inform the types of benchmarking that matter to you.

For instance, I thought `tsinfer` did a pretty good job of inferring recent IBD segment lengths:

<img style="height: 300px" src="pics/worksheet5-ibd-length.png">

But `tsdate` seemed to systematically overestimate their ages:

<img style="height: 300px" src="pics/worksheet5-ibd-time.png">

Given the variety of tools that are now available for these purposes (including the many that we have covered today), benchmarking the accuracy of these inferences is a task of high community importance if we are to rely on inferred genome-wide genealogies for future work.

# REVAMPED MATERIAL:

So far, we've focused on *simulating* data with a tree sequence structure, and we've talked about the benefits of doing so. We've looked at some basic summary statistics of the tree sequences to try and interpret them, but you might be wondering:
 - What if I want to perform 'conventional' analyses on these, such as calculating diversity or AFS-based statistics? Do I need to convert my tree sequences to VCFs and then apply all my regular tools, or can I run these analyses directly on the tree sequence files?
 - So far we've been doing all this on simulated datasets. Do these things only work on simulated data, or can I use these analysis techniques on my actual observed dataset?

The good news: you can run these analyses directly on tree sequences, and you can apply them to real data!

1. The `tsinfer` and `tsdate` packages let you estimate a tree sequence for your dataset. (Note: these are *estimates*, so they won't be exactly correct. See Brandt et al. for some discussion of this, and of which types of information can and cannot be reliably recovered.)
2. The `tskit` package (which we've already seen and used) has a number of utility functions that let you explore, manipulate and analyse data stored in a tree sequence format. 

In this notebook, we'll touch on all of these to give you a flavour of what is possible with these packages.

In [None]:
import cyvcf2 # For reading VCF files into Python

import tskit
import tsinfer, tsdate
from IPython.display import SVG
import json

## 5.1 Inferring a tree sequence using `tsinfer` and `tsdate`

We'll be looking at a dataset of sparrow chromosome 24 (find reference) kindly shared by Mark Ravinet. There are 10 individuals represented in the samples: 5 from Norway and 5 from France.

### 5.1.1 Preparing the data for tsinfer

The heart of the input is a `SampleData` object, which contains information about individuals and the genotypes held by each individual at each site. A few things to note about the data:
 - it needs to be phased
 - it needs to be bi-allelic (i.e. satisfy the infinite sites assumption)
 - missingness is fine!

A `SampleData` object is typically built using the `add_individual` and `add_site` methods. To save us some boilerplate, let's use some helper functions (taken from the tsinfer website) that will make this object for us from a `cyvcf2` object. At their core, these functions are calls to the two methods above, but they do a few other things as well (they preserve the ploidy of the individuals, and their populations). For our dataset:

- our data is diploid
- our individuals belong to two distinct populations, and we want our sample data to include that information
- we only want the inference to use alleles at SNVs and deletions for now (others are possible!)
- we need phased variants

In [14]:
def add_diploid_sites(vcf, samples):
    """
    Read the sites in the vcf and add them to the samples object.
    """
    # You may want to change the following line, e.g. here we allow
    # "*" (a spanning deletion) to be a valid allele state
    allele_chars = set("ATGCatgc*")
    pos = 0
    # progressbar = tqdm(total=samples.sequence_length, desc="Read VCF", unit='bp')
    for variant in vcf:  # Loop over variants, each assumed at a unique site
        # progressbar.update(variant.POS - pos)
        if pos == variant.POS:
            print(f"Duplicate entries at position {pos}, ignoring all but the first")
            continue
        else:
            pos = variant.POS
        if any([not phased for _, _, phased in variant.genotypes]):
            raise ValueError("Unphased genotypes for variant at position", pos)
        alleles = [variant.REF.upper()] + [v.upper() for v in variant.ALT]
        ancestral = variant.INFO.get("AA", ".")  # "." means unknown
        # some VCFs (e.g. from 1000G) have many values in the AA field: take the 1st
        ancestral = ancestral.split("|")[0].upper()
        if ancestral == "." or ancestral == "":
            ancestral_allele = tskit.MISSING_DATA
            # alternatively, you could specify `ancestral = variant.REF.upper()`
        else:
            ancestral_allele = alleles.index(ancestral)
        # Check we have ATCG alleles; skip the whole site if any allele is invalid.
        bad_alleles = [a for a in alleles if len(set(a) - allele_chars) > 0]
        if bad_alleles:
            print(f"Ignoring site at pos {pos}: allele {bad_alleles[0]} not in {allele_chars}")
            continue
        # Flatten the per-individual genotype pairs into one list of haploid genotypes.
        genotypes = [g for row in variant.genotypes for g in row[0:2]]
        samples.add_site(pos, genotypes, alleles, ancestral_allele=ancestral_allele)


def chromosome_length(vcf):
    assert len(vcf.seqlens) == 1
    return vcf.seqlens[0]

def add_populations(vcf, samples):
    """
    Add tsinfer Population objects and return a list of IDs corresponding to the VCF samples.
    """
    # In this VCF, the first letter of the sample name refers to the population
    samples_first_letter = [sample_name[0] for sample_name in vcf.samples]
    pop_lookup = {}
    pop_lookup["8"] = samples.add_population(metadata={"country": "Norway"})
    pop_lookup["F"] = samples.add_population(metadata={"country": "France"})
    return [pop_lookup[first_letter] for first_letter in samples_first_letter]


def add_diploid_individuals(vcf, samples, populations):
    for name, population in zip(vcf.samples, populations):
        samples.add_individual(ploidy=2, metadata={"name": name}, population=population)

We'll use these helper functions to read in information about our sparrow sample from the VCF file. Note the order: first, we'll add in population information, then we add in information about the individuals in those populations that we've sampled, then we add in information about the genotypes held by those individuals at each site.

In [None]:
vcf_location = "P_dom_chr24_phased.vcf.gz"
vcf = cyvcf2.VCF(vcf_location)
with tsinfer.SampleData(
    path="P_dom_chr24_phased.samples", sequence_length=chromosome_length(vcf)
) as samples:
    populations = add_populations(vcf, samples)
    add_diploid_individuals(vcf, samples, populations)
    add_diploid_sites(vcf, samples)

print(
    "Sample file created for {} samples ".format(samples.num_samples)
    + "({} individuals) ".format(samples.num_individuals)
    + "with {} variable sites.".format(samples.num_sites),
    flush=True,
)

# print(samples)

### 5.1.2 Inferring a tree sequence with `tsinfer` and `tsdate`

Importing the data into a tsinfer-ready format was the hardest part -- the next thing is to run tsinfer. 

(Note: `tsinfer.infer` also accepts a `mismatch_ratio` parameter, which we leave at its default here.)

In [None]:
ts_undated = tsinfer.infer(samples)

ts_undated

Yay! We now have a bona-fide tree sequence holding our data that we can inspect, analyse and manipulate in all the ways we have done previously. 

Do you notice anything different about this tree sequence compared with the simulated ones that we've been working with previously? 
1. As usual, the provenance section of the table holds some useful information recording that we generated this with tsinfer. Note that you can click through and see all the parameters used.
2. Secondly, "Time Units: uncalibrated". That's because tsinfer infers the topologies for these trees, but doesn't infer times for them (in fact, it takes no information about mutation rates, coalescence times or anything else that could help to estimate them). For this, we'll apply `tsdate`.
3. A more subtle one: there are more sites than mutations. This happens because, in the process of generating the trees, tsinfer also generates a bunch of ancestral haplotypes that *don't* end up contributing information to the final genotypes we see.

We'll deal with the last two issues one at a time.

First, we'll apply `simplify()` to remove some 'stray' unary nodes and sites outputted by tsinfer -- this 'pruning' step removes all elements of the inferred tree sequence other than those needed to represent the genotypes of the samples.
After this, we'll run `tsdate` to obtain a time-calibrated version of this inferred tree sequence using an estimated mutation rate of `1e-8` per base per generation.

(Can this be altered?)

 (Also, see [this function](https://tskit.dev/tsdate/docs/latest/python-api.html#preprocessing-tree-sequences) for some other useful pre-processing functions one can use on the tree sequence -- excluding samples from data poor regions e.g. near the telomeres, or excluding particular samples.)

In [None]:
# mutation rate is given as the probability of mutations per base per generation
tss = ts_undated.simplify()
ts_dated = tsdate.date(tss, mutation_rate=1e-8)

ts_dated

Note that each of these steps has added an extra row into the provenance table too.

We're getting close to the point where printing individual trees might be a little overwhelming and hard to interpret. `tsbrowse`, available on the web, provides a number of other visualisations of these trees that might be especially useful for these larger datasets.

Let's print some trees:

In [None]:
colours = {"Norway": "red", "France": "blue"}
colours_for_node = {}
for n in ts_dated.samples():
    population_data = ts_dated.population(ts_dated.node(n).population)
    colours_for_node[n] = colours[json.loads(population_data.metadata)["country"]]

individual_for_node = {}
for n in ts_dated.samples():
    individual_data = ts_dated.individual(ts_dated.node(n).individual)
    individual_for_node[n] = json.loads(individual_data.metadata)["name"]

tree = ts_dated.at(4e6)
# Tree.draw returns the SVG as a string; wrap it in SVG() so the notebook renders it.
SVG(tree.draw(
    height=700,
    width=1200,
    node_labels=individual_for_node,
    node_colours=colours_for_node,
))

Many of these samples are clustering by population label, but not all of them are. Let's look at another tree at a different position, where even fewer sample nodes cluster by population:

In [None]:
tree = ts_dated.at(2e6)
# The file name matches the position drawn (2 Mb); the SVG is also rendered inline.
SVG(tree.draw(
    path="tree_at_2Mb.svg",
    height=700,
    width=1200,
    node_labels=individual_for_node,
    node_colours=colours_for_node,
))

Hmm. This one is less clear.

The amount we can learn from any individual tree is perhaps limited, because of the stochasticity of the trees across the genome.

Can we think of some way to summarise whether the Norwegian and French sparrows exhibit population structure, and learn more about the nature of that structure?


Suppose you have $n$ sequences typed at $m$ different sites...

```
   ...GTAACGCGATAAGAGATTAGCCCAAAAACACAGACATGGAAATAGCGTA...
   ...GTAACGCGATAAGAGATTAGCCCAAAAACACAGACATGGAAATAGCGTA...
   ...GTAACGCGATAAGATATTAGCCCAAAAACACAGACATGGAAATAGCGTA...
   ...GTAACGCGATAAGATATTAGCCCAAAAACACAGACATGGAAATAGCGTA...
   ...GTAACGCGATAAGATATTAGCCCAAAAACACAGACATGGAAATAGCGTA...
   ...GTAACGCGATAAGATATTAGCCCAAAAACACAGACATGGTAATAGCGTA...
   ...GTAACGCGATAAGATATTAGCCCAAAAACACAGACATGGTAATAGCGTA...
```

...and you want to calculate mean pairwise diversity on these samples, i.e.

$$ \pi = \dfrac{1}{n(n-1)/2}\sum_{i=1}^{n-1} \sum_{j=i+1}^n k_{ij}, $$

where $k_{ij}$ is the number of sites at which sequences $i$ and $j$ carry a different allele.
The scaling of this procedure is

$$ O\left( n^2 m \right) $$

i.e. quadratic in the number of samples $n$, and linear in the number of sites $m$.
However, there is an equivalent way of performing this calculation by assigning weights to the sample nodes, and propagating these values further up the tree using a 'summary function' at each mutation.
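
For reference, the naive pairwise calculation described by the formula above can be sketched in a few lines of plain Python. This is only meant to show where the $O(n^2 m)$ cost comes from: every pair of sequences is compared at every site. It uses the seven aligned sequences from the example, and the function name `pairwise_diversity` is hypothetical rather than part of `tskit`.

```python
from itertools import combinations


def pairwise_diversity(sequences):
    """Mean number of differing sites over all pairs of sequences."""
    total_differences = 0
    num_pairs = 0
    for seq_i, seq_j in combinations(sequences, 2):
        # k_ij: the number of sites at which the two sequences differ.
        total_differences += sum(a != b for a, b in zip(seq_i, seq_j))
        num_pairs += 1
    return total_differences / num_pairs


# The seven aligned sequences from the example (three distinct haplotypes).
hap_a = "GTAACGCGATAAGAGATTAGCCCAAAAACACAGACATGGAAATAGCGTA"
hap_b = "GTAACGCGATAAGATATTAGCCCAAAAACACAGACATGGAAATAGCGTA"
hap_c = "GTAACGCGATAAGATATTAGCCCAAAAACACAGACATGGTAATAGCGTA"
sequences = [hap_a] * 2 + [hap_b] * 3 + [hap_c] * 2

pi = pairwise_diversity(sequences)
print(pi)
```

Here there are 21 pairs and 20 pairwise differences in total, so $\pi = 20/21$ for this stretch of sequence. Note that this is a sum over the sites shown, whereas `diversity()` in `tskit` by default also divides by the sequence length.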

 <img src="pics/worksheet4-node-weights.jpeg" width="350" height="350">
 
This is what `tskit` calls a *site statistic* calculation, and because the operation is of order

$$ O\left( n + \rho m (\log(n))^2 \right) \ll O\left( n^2 m \right), $$
 
the calculation is quick to run, especially on large datasets:

 <img src="pics/worksheet4-stat-speed.jpeg" width="500" height="500">

See the following paper for more details.

Peter Ralph, Kevin Thornton, Jerome Kelleher, Efficiently Summarizing Relationships in Large Samples: A General Duality Between Statistics of Genealogies and Genomes, Genetics, Volume 215, Issue 3, 1 July 2020, Pages 779–797, https://doi.org/10.1534/genetics.120.303253

### The basic syntax: (nucleotide diversity)

`tskit` uses very similar syntax for all of its inbuilt statistics, so we'll explore the options using `diversity()` as an example.

In [None]:
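# Diversity across all samples.
print(ts_dated.diversity())
# Diversity within each population separately (population 0, then population 1).
print(ts_dated.diversity(sample_sets=[ts_dated.samples(0), ts_dated.samples(1)]))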

In [None]:
print(ts_dated.divergence(sample_sets=[ts_dated.samples(0), ts_dated.samples(1)]))
print(ts_dated.divergence(sample_sets=[ts_dated.samples(0), ts_dated.samples(0)]))
print(ts_dated.divergence(sample_sets=[ts_dated.samples(1), ts_dated.samples(1)]))

In [None]:
Fst = ts_dated.Fst(sample_sets=[ts_dated.samples(0), ts_dated.samples(1)])

print(Fst)

Remember, $F_{st}$ is the proportion of the total genetic variance that can be attributed to differences between populations. (Closer to 0 means less differentiation; closer to 1 means more.)
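
As a rough sketch of how $F_{st}$ relates to the diversity and divergence values computed above, the function below combines them using

$$ F_{st} = 1 - \dfrac{2\left(d_X + d_Y\right)}{d_X + 2 d_{XY} + d_Y}, $$

where $d_X$ and $d_Y$ are the within-population diversities and $d_{XY}$ is the divergence between the two populations. As far as we are aware this matches the definition documented for `tskit`'s `Fst` method, but treat that as an assumption and check the documentation. The function name `fst_from_diversity` and the input values are hypothetical, and are not results from the sparrow data.

```python
def fst_from_diversity(d_x, d_y, d_xy):
    """F_st from within-population diversities and between-population divergence."""
    within = d_x + d_y
    return 1 - 2 * within / (within + 2 * d_xy)


# Hypothetical values, chosen only to show the behaviour of the formula.
no_structure = fst_from_diversity(0.001, 0.001, 0.001)  # divergence equals diversity
structured = fst_from_diversity(0.001, 0.001, 0.003)    # divergence exceeds diversity
print(no_structure, structured)
```

When pairs drawn from different populations are no more different than pairs drawn from the same population, the ratio is 1 and $F_{st}$ is 0; as between-population divergence grows relative to within-population diversity, $F_{st}$ moves towards 1.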

This is pretty small, but we are also dealing with high dispersal rates (and potentially large effective population sizes).

So, yes, there is more difference between average pairs of individuals from different populations than between pairs from the same population.
We might also guess from this that the second population is a bit bigger (has a larger $N_e$) than the first.

(We're not actually doing inference with this, but it might suggest some ideas.)

A second question: approximately when did this split happen?
We'd like to know something like: for each ancestral node, how many present-day descendants are Norwegian versus French?

In [None]:
ts_dated = ts_dated.simplify()
mean_desc = ts_dated.mean_descendants([ts_dated.samples(0), ts_dated.samples(1)])

print(mean_desc)

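Each row is a node (haplotype) that is ancestral to the sample, and the columns give the mean number of samples from each population that descend from it, averaged over the parts of the genome where that node is an ancestor.

Let's calculate how far each node's descendants depart from a 50:50 split between the two populations: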

In [24]:
# for each row of mean_desc, calculate the absolute value of the difference between the two columns divided by the sum of the two columns
diff_in_anc = []

for row in mean_desc:
    diff_in_anc.append(abs(row[0] - row[1]) / (row[0] + row[1]))

In [None]:
diff_in_anc

In [None]:
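import matplotlib.pyplot as plt
import numpy as np

# Scatterplot of diff_in_anc (y axis) against log10 node time (x axis).
# Nodes at time 0 (the samples) are excluded before taking logs.
node_times = ts_dated.tables.nodes.time
older = node_times > 0
plt.scatter(np.log10(node_times[older]), np.array(diff_in_anc)[older])
plt.xlabel("log10(node time)")
plt.ylabel("Normalised difference in descendants")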

Let's do a smaller, neater version of this first.
If we pull out all of the nodes that have 10 descendants, when does the standard deviation of the number of blue (French) descendants approximately match what you'd expect from a binomial distribution?
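
As a point of comparison, here is a small sketch of the binomial expectation itself. It assumes that each of a node's 10 descendants is independently French with probability 0.5, which is an idealisation motivated by the equal sample sizes (5 individuals from each country) rather than something the data guarantee. The function names are hypothetical.

```python
from math import comb, sqrt


def binomial_sd(n, p):
    """Standard deviation of a Binomial(n, p) count."""
    return sqrt(n * p * (1 - p))


def binomial_pmf(n, p):
    """Probability of each count 0..n under Binomial(n, p)."""
    return [comb(n, k) * p**k * (1 - p) ** (n - k) for k in range(n + 1)]


n, p = 10, 0.5
pmf = binomial_pmf(n, p)
mean = sum(k * prob for k, prob in enumerate(pmf))
sd_from_pmf = sqrt(sum((k - mean) ** 2 * prob for k, prob in enumerate(pmf)))
print(binomial_sd(n, p), sd_from_pmf)
```

Both routes give $\sqrt{10 \times 0.5 \times 0.5} = \sqrt{2.5} \approx 1.58$. A markedly larger spread among nodes of a given age would suggest that descendants are sorted by population more strongly than under random mixing.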

First, let's extract all of the nodes with 10 leaves in the tree. We can do this with the `mean_desc` object we created before:

In [None]:
num_descendants = mean_desc.sum(axis=1).astype(int)
print(num_descendants.shape)

print(num_descendants)
ids_10 = np.where(num_descendants == 10)

print(ids_10)

total_ancestral_nodes = 0
for n in range(0, 21):

    ids = np.where(num_descendants == n)
    total_ancestral_nodes += len(ids[0])

print(total_ancestral_nodes)
print()


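Next, we want to extract the values of `mean_desc` and `time` for each of these nodes.

This can be confusing at first, because the values in `mean_desc` are generally not integers: each entry is an average over the parts of the genome where the node is an ancestor, and a node can have different numbers of descendants in different trees.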
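
To see where non-integer values come from, here is a simplified sketch of a span-weighted average for a single node. The intervals and counts are hypothetical, and the helper `mean_descendants_for_node` only illustrates the averaging idea; it is not how `tskit` computes `mean_descendants` internally.

```python
def mean_descendants_for_node(intervals):
    """Span-weighted mean number of descendant samples for one node.

    `intervals` is a list of (left, right, num_descendants) tuples covering
    the regions of the genome where the node is an ancestor.
    """
    total_span = sum(right - left for left, right, _ in intervals)
    weighted_sum = sum((right - left) * count for left, right, count in intervals)
    return weighted_sum / total_span


# Hypothetical node: 3 descendants over the first 600 bp, then 4 over the next 400 bp.
intervals = [(0, 600, 3), (600, 1000, 4)]
node_mean = mean_descendants_for_node(intervals)
print(node_mean)
```

A node that has 3 descendants over 60% of its span and 4 over the remaining 40% ends up with a mean of 3.4, which is why casting these values to integers (as in the cell above) truncates them.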

In [None]:
# Focus on the nodes with 10 descendants, as described in the text.
focal_nodes = np.where(num_descendants == 10)[0]
focal_mean_desc = mean_desc[focal_nodes]
focal_times = ts_dated.tables.nodes.time[focal_nodes]

print(focal_mean_desc)

In [None]:
# plot focal_mean_desc (on y axis) and focal_times (on x axis)
plt.scatter(focal_times[:-1], focal_mean_desc[:-1, 0], label="Population 0")