# 4. Tree sequence inference and statistics

So far, we've focused on *simulating* data with a tree sequence structure, and we've talked about the benefits of doing so. We've looked at some basic summary statistics of the tree sequences to help interpret them, but you might be wondering:
 - So far we've been doing all this on simulated datasets. Can I use any of these techniques on my real dataset?
 - What if I want to perform 'conventional' analyses on these? Do I need to convert my tree sequences to VCFs and then apply all my regular tools, or can I run these analyses directly on the tree sequence files?

In both cases, the answer is Yes! 

1. The `tsinfer` and `tsdate` packages let you estimate a tree sequence for your dataset. (Note: these are *estimates* and won't be exactly correct. See Brandt et al. for some discussion of these estimates, and of which types of information can and cannot be reliably drawn from them.)
2. The `tskit` package (which we've already seen and used) has a number of utility functions that let you explore, manipulate and analyse data stored in a tree sequence format. 

In this notebook, we'll touch on all of these to give you a flavour of what is possible with these packages.

 - [4.1 An overview of tsinfer](#5.1Overview)
 - [4.2 Hands on with tsinfer](#5.2HandsOn)



In [48]:
import cyvcf2 # For reading VCF files into Python

import tskit
import tsinfer, tsdate
from IPython.display import SVG
import json
import matplotlib.pyplot as plt

<a id='5.1Overview'></a>
## 4.1 An overview of `tsinfer`

Simulating a tree sequence is relatively simple compared to *inferring* a tree sequence from existing data.
The [tsinfer software](https://tsinfer.readthedocs.io/en/stable/) implements a heuristic algorithm which does this in a scalable manner.

`Tsinfer` (pronounced t-s-infer) is comparable in some ways to other ancestral inference software such as [ARGweaver](https://doi.org/10.1371/journal.pgen.1004342), [Relate](https://myersgroup.github.io/relate/), and [Rent+](https://doi.org/10.1093/bioinformatics/btw735). However, it differs considerably in approach and scalability.
Note that none of these other software packages produce tree sequences as output, although it is possible to convert their output to tree sequences.
Also note that `tsinfer` produces trees with a relatively accurate topology, but unlike other ancestral inference tools, it makes no attempt at the moment to produce precise branch length estimates -- for this we need another tool like `tsdate`.

An important restriction is that `tsinfer` requires phased sample sequences with known ancestral states for each variant. It also works better with full sequence data than with data from scattered target SNPs (e.g. as obtained from SNP chips).
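
As a rough illustration of what "phased" means in practice, the sketch below checks genotype strings written in VCF notation, where `|` separates phased alleles and `/` separates unphased ones. The function name `is_phased` and the example strings are hypothetical and only meant to show the idea; the check used later in this notebook works on `cyvcf2` objects instead.

```python
def is_phased(gt: str) -> bool:
    """Return True if a VCF GT string such as '0|1' is phased.

    Assumption for this sketch: haploid calls (no separator) count as phased.
    """
    return "/" not in gt


# Hypothetical genotype strings for four samples at one site
example_gts = ["0|1", "1|1", "0/1", "1"]
phased_flags = [is_phased(gt) for gt in example_gts]
all_phased = all(phased_flags)
print(phased_flags, all_phased)
```

A site where `all_phased` is false would need to be phased (or dropped) before being passed to `tsinfer`.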



We'll be looking at a dataset of variants from sparrow chromosome 24 kindly shared by Mark Ravinet. There are 10 individuals represented in the samples: 5 from Norway and 5 from France.

### 4.1.1 Preparing the data for `tsinfer`

To prepare our data for `tsinfer`,
we need to make a `SampleData` object with information about the individuals and phased, bi-allelic sites we wish to use for inference.

These things are typically done with the `add_individual()` and `add_site()` methods.
To save us some time, let's define some helper functions (taken from the tsinfer website) that will make this object for us from a `cyvcf2` object.

In [22]:
# Some functions for importing data from VCF --> tsinfer format.
# Properties these functions preserve:
# - Our data is diploid
# - Our individuals belong to two distinct populations, and we want our sample data to include that information.
# - we only want the inference to use alleles at SNVs and deletions for now (others are possible!)

def add_diploid_sites(vcf, samples):
    """
    Read the sites in the vcf and add them to the samples object.
    """
    # You may want to change the following line, e.g. here we allow
    # "*" (a spanning deletion) to be a valid allele state
    allele_chars = set("ATGCatgc*")
    pos = 0
    # progressbar = tqdm(total=samples.sequence_length, desc="Read VCF", unit='bp')
    for variant in vcf:  # Loop over variants, each assumed at a unique site
        # progressbar.update(variant.POS - pos)
        if pos == variant.POS:
            print(f"Duplicate entries at position {pos}, ignoring all but the first")
            continue
        else:
            pos = variant.POS
        if any([not phased for _, _, phased in variant.genotypes]):
            raise ValueError(f"Unphased genotypes for variant at position {pos}")
        alleles = [variant.REF.upper()] + [v.upper() for v in variant.ALT]
        ancestral = variant.INFO.get("AA", ".")  # "." means unknown
        # some VCFs (e.g. from 1000G) have many values in the AA field: take the 1st
        ancestral = ancestral.split("|")[0].upper()
        if ancestral == "." or ancestral == "":
            ancestral_allele = tskit.MISSING_DATA
            # alternatively, you could specify `ancestral = variant.REF.upper()`
        else:
            ancestral_allele = alleles.index(ancestral)
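        # Check we have ATCG alleles; skip the whole site if any allele is invalid.
        # (A `continue` inside a loop over alleles would only skip to the next allele.)
        bad_alleles = [a for a in alleles if len(set(a) - allele_chars) > 0]
        if len(bad_alleles) > 0:
            print(f"Ignoring site at pos {pos}: alleles {bad_alleles} not in {allele_chars}")
            continue
        # Flatten the per-individual [allele_0, allele_1, phased] rows into one list
        # of genotypes, two per diploid individual, in sample order.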
        genotypes = [g for row in variant.genotypes for g in row[0:2]]
        samples.add_site(pos, genotypes, alleles, ancestral_allele=ancestral_allele)


def chromosome_length(vcf):
    assert len(vcf.seqlens) == 1
    return vcf.seqlens[0]

def add_populations(vcf, samples):
    """
    Add tsinfer Population objects and return a list of IDs corresponding to the VCF samples.
    """
    # In this VCF, the first letter of the sample name refers to the population
    samples_first_letter = [sample_name[0] for sample_name in vcf.samples]
    pop_lookup = {}
    pop_lookup["8"] = samples.add_population(metadata={"country": "Norway"})
    pop_lookup["F"] = samples.add_population(metadata={"country": "France"})
    return [pop_lookup[first_letter] for first_letter in samples_first_letter]


def add_diploid_individuals(vcf, samples, populations):
    for name, population in zip(vcf.samples, populations):
        samples.add_individual(ploidy=2, metadata={"name": name}, population=population)
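
To make the genotype handling in `add_diploid_sites` more concrete, here is a minimal sketch of the flattening step on made-up data. It assumes the `cyvcf2` convention that each row of `variant.genotypes` for a diploid individual has the form `[allele_0, allele_1, phased]`; the helper name `flatten_diploid_genotypes` and the toy rows are hypothetical.

```python
def flatten_diploid_genotypes(rows):
    """Flatten rows of [allele_0, allele_1, phased] into one genotype list.

    Raises ValueError if any row is unphased.
    """
    flat = []
    for allele_0, allele_1, phased in rows:
        if not phased:
            raise ValueError("Unphased genotype found")
        # Keep the two haplotypes of each individual next to each other
        flat.extend([allele_0, allele_1])
    return flat


# Three hypothetical diploid individuals at one site (0 = REF, 1 = first ALT)
toy_rows = [[0, 1, True], [1, 1, True], [0, 0, True]]
flat = flatten_diploid_genotypes(toy_rows)
print(flat)
```

With this ordering, positions `2*i` and `2*i + 1` of the flattened list hold the two haplotypes of individual `i`, which matches the order in which the individuals are added with `ploidy=2`.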

We'll use these helper functions to read in information about our sparrow sample from the VCF file and save it into a `tsinfer.SampleData` object.
First, we add in information about our populations.
Then we add information about the sampled individuals in those populations,
and finally we add the genotypes held by those individuals at each site.

In [None]:
vcf_location = "P_dom_chr24_phased.vcf.gz"
vcf = cyvcf2.VCF(vcf_location)
with tsinfer.SampleData(
    path="P_dom_chr24_phased.samples", sequence_length=chromosome_length(vcf)
) as samples:
    populations = add_populations(vcf, samples)
    add_diploid_individuals(vcf, samples, populations)
    add_diploid_sites(vcf, samples)


In [None]:
print(
    "Sample file created for {} samples ".format(samples.num_samples)
    + "({} individuals) ".format(samples.num_individuals)
    + "with {} variable sites.".format(samples.num_sites),
    flush=True,
)

### 4.1.2 Inferring a tree sequence with `tsinfer` and `tsdate`

We're now ready to run `tsinfer`:

In [None]:
ts_undated = tsinfer.infer(samples)
ts_undated

As usual, there's some useful information about the fact that we generated this with tsinfer in the provenance section of the table.

Do you notice anything different about this tree sequence compared with the (simulated) ones that we've been working with up to this point? 
1. "Time Units: uncalibrated". That's because tsinfer infers the topologies for these trees, but doesn't infer times for them. For this, we'll apply `tsdate`.
2. More sites than mutations. In the process of generating the trees, `tsinfer` also generated a bunch of ancestral haplotypes that *didn't* contribute any info to the final genotypes we see.

We'll fix (2) by applying `simplify()` to remove some 'stray' unary nodes and sites output by `tsinfer`:


In [None]:
# simplify() removes unary nodes and any sites that carry no mutations
tss = ts_undated.simplify()
tss

Then we'll run `tsdate` to obtain a time-calibrated version of this inferred tree sequence using an estimated mutation rate of `1e-8` per base per generation.


In [None]:
ts_dated = tsdate.date(tss, mutation_rate=1e-8)
ts_dated

Note that each of these steps has added an extra row into the provenance table too.

Let's print some trees to see what this output looks like:

In [None]:
colours = {"Norway": "red", "France": "blue"}
colours_for_node = {}
for n in ts_dated.samples():
    population_data = ts_dated.population(ts_dated.node(n).population)
    colours_for_node[n] = colours[json.loads(population_data.metadata)["country"]]

individual_for_node = {}
for n in ts_dated.samples():
    individual_data = ts_dated.individual(ts_dated.node(n).individual)
    individual_for_node[n] = json.loads(individual_data.metadata)["name"]

tree = ts_dated.at(4e6)
tree.draw(
    height=700,
    width=1500,
    node_labels=individual_for_node,
    node_colours=colours_for_node
)

Many of these samples cluster by population label, but not all of them do. Let's look at a tree at a different position, where even fewer sample nodes cluster by population:

In [None]:
tree = ts_dated.at(2e6)
tree.draw(
    path="tree_at_2Mb.svg",
    height=700,
    width=1500,
    node_labels=individual_for_node,
    node_colours=colours_for_node,
)

The amount we can learn from any individual tree is perhaps limited, because of the stochasticity of the trees across the genome.

Can we summarise whether the Norwegian and French sparrows exhibit population structure, and learn more about the nature of that structure?


Note: I know that we're getting to a point where printing individual trees might be a little overwhelming and hard to interpret.
The `tsbrowse` application provides a number of useful functions that might help you visualise and summarise these larger tree sequences.
See [here](https://tskit.dev/software/tsbrowse.html) for more information.

## 4.2 Summary statistics with tskit

Suppose you have $n$ sequences typed at $m$ different sites...

```
   ...GTAACGCGATAAGAGATTAGCCCAAAAACACAGACATGGAAATAGCGTA...
   ...GTAACGCGATAAGAGATTAGCCCAAAAACACAGACATGGAAATAGCGTA...
   ...GTAACGCGATAAGATATTAGCCCAAAAACACAGACATGGAAATAGCGTA...
   ...GTAACGCGATAAGATATTAGCCCAAAAACACAGACATGGAAATAGCGTA...
   ...GTAACGCGATAAGATATTAGCCCAAAAACACAGACATGGAAATAGCGTA...
   ...GTAACGCGATAAGATATTAGCCCAAAAACACAGACATGGTAATAGCGTA...
   ...GTAACGCGATAAGATATTAGCCCAAAAACACAGACATGGTAATAGCGTA...
```

...and you want to calculate mean pairwise diversity on these samples, i.e.

$$ \pi = \dfrac{1}{n(n-1)/2}\sum_{i=1}^{n-1} \sum_{j=i+1}^n k_{ij}, $$

where $k_{ij}$ is the number of sites at which sequences $i$ and $j$ carry a different allele.
The scaling of this procedure is

$$ O\left( n^2 m \right) $$

i.e. quadratic in the number of samples $n$, and linear in the number of sites $m$.
However, there is an equivalent way of performing this calculation by assigning weights to the sample nodes, and propagating these values further up the tree using a 'summary function' at each mutation.

 <img src="pics/worksheet4-node-weights.jpeg" width="350" height="350">
 
This is what `tskit` calls a *site statistic* calculation, and because the operation is of order

$$ O\left( n + \rho m (\log(n))^2 \right) \ll O\left( n^2 m \right), $$

the calculation is quick to run, especially on large datasets:

 <img src="pics/worksheet4-stat-speed.jpeg" width="500" height="500">

See the following paper for more details.

Peter Ralph, Kevin Thornton, Jerome Kelleher, Efficiently Summarizing Relationships in Large Samples: A General Duality Between Statistics of Genealogies and Genomes, Genetics, Volume 215, Issue 3, 1 July 2020, Pages 779–797, https://doi.org/10.1534/genetics.120.303253
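
As a small worked example of the formula for $\pi$ above, the sketch below computes it for the seven toy sequences shown earlier in two ways: the direct $O(n^2 m)$ loop over pairs, and a per-site version that only needs allele counts. The second version does not use a tree at all; it is only meant to illustrate that the statistic depends on the data through per-site counts, which is the kind of quantity that can be obtained by propagating sample weights up a tree. The function names are hypothetical.

```python
from collections import Counter
from itertools import combinations

# The three distinct haplotypes in the toy alignment (flanking dots removed)
hap_a = "GTAACGCGATAAGAGATTAGCCCAAAAACACAGACATGGAAATAGCGTA"
hap_b = "GTAACGCGATAAGATATTAGCCCAAAAACACAGACATGGAAATAGCGTA"
hap_c = "GTAACGCGATAAGATATTAGCCCAAAAACACAGACATGGTAATAGCGTA"
seqs = [hap_a] * 2 + [hap_b] * 3 + [hap_c] * 2


def diversity_naive(seqs):
    """Mean number of pairwise differences, comparing every pair at every site."""
    n = len(seqs)
    total = 0
    for a, b in combinations(seqs, 2):
        total += sum(x != y for x, y in zip(a, b))
    return total / (n * (n - 1) / 2)


def diversity_from_counts(seqs):
    """Same quantity, computed from allele counts at each site."""
    n = len(seqs)
    n_pairs = n * (n - 1) // 2
    total = 0
    for column in zip(*seqs):
        counts = Counter(column)
        # Pairs that differ = all pairs minus pairs carrying the same allele
        identical = sum(c * (c - 1) // 2 for c in counts.values())
        total += n_pairs - identical
    return total / n_pairs


pi_naive = diversity_naive(seqs)
pi_counts = diversity_from_counts(seqs)
print(pi_naive, pi_counts)
```

Both functions give $20/21$ here: each of the two variable sites separates the seven samples into groups of 2 and 5, so each contributes $2 \times 5 = 10$ differing pairs out of 21.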

### The basic syntax: (nucleotide diversity)

`tskit` uses very similar syntax for all of its inbuilt statistics, so we'll explore the options using `diversity()` as an example.

In [None]:
ts_dated.diversity()

By default, `tskit` presents a normalised version of the statistic, scaled by the length of the region covered by the tree sequence. This allows you to make comparisons between tree sequences of different lengths. However, this isn't how all other genetic software computes diversity -- if you wish to disable `tskit`'s default behaviour, use the `span_normalise` argument.

In [None]:
div = ts_dated.diversity(span_normalise=False)
div[()]
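
The relationship between the two versions is a single division by the length of the region. The numbers below are made up purely to show the arithmetic and are not taken from the sparrow data.

```python
# Hypothetical values, for illustration only
sequence_length = 1_000_000   # length of the region, in bases
unnormalised_diversity = 2500.0   # mean pairwise differences over the whole region

# Span normalisation expresses the same quantity per base
per_base_diversity = unnormalised_diversity / sequence_length
print(per_base_diversity)
```

When windows are used, the same idea applies window by window: each value is divided by the length of its own window.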

### Calculating statistics on subsets of the samples

Remember that this dataset consists of samples from two contemporary populations of different sizes. We’d expect each of them to have a different level of diversity, and for these to differ from the overall (sample-wide) diversity. We can get this information by passing a list of sample sets to the `sample_sets` argument.

A quick and simple way to get all of the sample node IDs from a particular population is to use the `samples()` method. For instance, `ts_dated.samples(population=0)` returns a numpy array holding all of the sample node IDs from 'population 0'. The following code builds three sample sets (each population in turn, then all samples together) and computes the diversity of each:

In [None]:
samples_of_interest = [
    ts_dated.samples(population=0),
    ts_dated.samples(population=1),
    ts_dated.samples(),
]

ts_dated.diversity(sample_sets=samples_of_interest)

The output is a 1-dimensional numpy array, where each element is a diversity statistic value for one of the sample sets we specified.
Nucleotide diversity is lowest in the set of samples from the small population, and largest in the pooled set of samples, as you'd expect.

Note that you can use any list of node IDs as inputs to `sample_sets`. This may be useful if your samples of interest correspond to something other than populations (for instance, samples that hold some phenotype of interest).

In [None]:
ts_dated.diversity(sample_sets=[[12, 14, 18, 7, 8],
                          [3, 19, 10]])
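
To see what "one value per sample set" means without any tree sequence machinery, here is a small sketch on made-up haplotypes. Each sample set is just a list of indices into the list of haplotypes, and the statistic is computed separately within each set. The haplotypes, the index lists and the function name are all hypothetical.

```python
from itertools import combinations

# Five hypothetical haplotypes, indexed 0..4
haplotypes = ["AAGT", "AAGT", "ATGT", "ATGA", "TTGA"]


def mean_pairwise_differences(haps):
    """Mean number of differing sites over all pairs of haplotypes in `haps`."""
    pairs = list(combinations(haps, 2))
    total = sum(sum(x != y for x, y in zip(a, b)) for a, b in pairs)
    return total / len(pairs)


# Two disjoint groups, followed by everything pooled together
sample_sets = [[0, 1, 2], [3, 4], [0, 1, 2, 3, 4]]
per_set = [
    mean_pairwise_differences([haplotypes[i] for i in sample_set])
    for sample_set in sample_sets
]
print(per_set)
```

As in the `tskit` call, the result has one entry per sample set, in the order the sets were given.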

### Genome scans

So far, we’ve just been calculating statistics summarising diversity along the entire chromosome. However, in many cases, we might be more interested in how diversity varies along the genome. We can do this using the `windows` argument.

We specify the start and end points of the sequence, and the locations of the breakpoints between each window. For instance, suppose we wanted to specify windows of length 100 kb covering the whole chromosome:

In [None]:
breakpoints = [i*1e5 for i in range(0, int(ts_dated.sequence_length/1e5))] + [ts_dated.sequence_length] 
print(breakpoints)
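
The list of breakpoints has to start at 0, end at the sequence length, and be strictly increasing. When the sequence length is not an exact multiple of the window size, the list comprehension used in this section folds the remainder into the final window, which then ends up longer than the others. The sketch below is one alternative that instead makes the final window shorter; the helper name `make_windows` is hypothetical.

```python
import math


def make_windows(sequence_length, window_size):
    """Return breakpoints 0, w, 2w, ..., sequence_length.

    The last window is shorter than `window_size` if the sequence length
    is not an exact multiple of it.
    """
    if window_size <= 0 or sequence_length <= 0:
        raise ValueError("sequence_length and window_size must be positive")
    num_windows = math.ceil(sequence_length / window_size)
    return [i * window_size for i in range(num_windows)] + [sequence_length]


uneven = make_windows(250, 100)
even = make_windows(300, 100)
print(uneven)
print(even)
```

Either choice is valid input for the `windows` argument; the difference only matters for how the last window should be interpreted.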

In [None]:
div = ts_dated.diversity(sample_sets=samples_of_interest, windows=breakpoints)
print("Dimensions of the output:", div.shape, "\n")
print("Diversity values over the first 10 windows in each sample set:")
print(div[:10,:])

The output holds one row for each of our specified windows,
and each element of a row holds the diversity value in that window for one of our sample sets.
Let’s plot these:


In [None]:
names_to_plot = ['Norway', 'France', 'ALL']
lines = plt.plot(breakpoints[:-1], div)
plt.grid(alpha=0.5)
plt.legend(lines, names_to_plot);
plt.xticks()
plt.xlabel("Position on chromosome")
plt.title("Windowed diversity values")

<a id='4.2BranchStatistics'></a>
### 4.2.2 Branch statistics and the 'duality' of tree-based statistics

There are several different types of randomness in genetic models that interact with each other in complex ways.
In addition to randomness in the genealogical trees that are produced in a given demographic scenario,
there is also randomness caused by the mutational process.

When you calculate site statistics, or anything based on allele frequencies, *both* of these processes contribute to the statistical noisiness you see.

However, in tree sequences you have information about branches, which allows you to bypass this latter type of randomness.
Instead of moving upwards along the trees and updating the statistic every time you come across a mutation, you can update the statistic based on the lengths of the branches.
This should have some correspondence with the *number of mutations we may expect*. (This should certainly be true in simulated datasets, where we are certain of the correctness of the underlying trees).

This is the basic idea behind the *branch statistics* in `tskit`.
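
As a minimal sketch of the idea, the code below computes a branch-based diversity by hand on a single made-up tree: for every pair of samples, it adds up the branch lengths on the path that connects them, and then averages over pairs. Multiplying by a mutation rate gives the number of differences we would expect per base under that rate. The tree, the node times and the function names are hypothetical, and the sketch assumes one tree covering the whole region.

```python
from itertools import combinations

# Hypothetical toy tree: samples 0, 1, 2 at time 0; internal nodes 3 and 4 (root)
parent = {0: 3, 1: 3, 2: 4, 3: 4, 4: None}
time = {0: 0.0, 1: 0.0, 2: 0.0, 3: 1.0, 4: 3.0}
samples = [0, 1, 2]


def ancestors(u):
    """Return u followed by all of its ancestors, up to the root."""
    path = []
    while u is not None:
        path.append(u)
        u = parent[u]
    return path


def branch_distance(a, b):
    """Total branch length on the path between nodes a and b."""
    ancestors_of_a = set(ancestors(a))
    # The first ancestor of b that is also an ancestor of a is their MRCA
    mrca = next(u for u in ancestors(b) if u in ancestors_of_a)
    return (time[mrca] - time[a]) + (time[mrca] - time[b])


pairs = list(combinations(samples, 2))
branch_diversity = sum(branch_distance(a, b) for a, b in pairs) / len(pairs)

mutation_rate = 1e-8  # per base per generation (same value as used for dating)
expected_site_diversity = mutation_rate * branch_diversity
print(branch_diversity, expected_site_diversity)
```

Here the pair (0, 1) is separated by 2 units of branch length and the other two pairs by 6 units each, so the branch diversity is $14/3$ in units of time.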

Here are the diversity stats we looked at before, this time with the branch versions included. 

In [None]:
div_branch = ts_dated.diversity(
    sample_sets=samples_of_interest,
    windows=breakpoints,
    mode='branch')
names_to_plot = ['Norway', 'France', 'ALL']
lines = plt.plot(breakpoints[:-1], div_branch)
plt.grid(alpha=0.5)
plt.legend(lines, names_to_plot);
plt.title("Windowed diversity (branch)")

## Extra material -- GNNs

Consider removing, or turning into a supplementary notebook -- I think we just won't have time for this!

In [None]:

def find_nearest_neighbour_populations(t, focal_sample):
    """
    Return the proportions (population 0, population 1) among the leaf nodes
    that sit below the focal sample's parent in tree `t`, excluding the
    focal sample itself.
    """
    p = t.parent(focal_sample)
    if p == tskit.NULL:
        # The focal sample has no parent in this tree, so it has no neighbours
        return (0, 0)
    # All other leaves descending from the focal sample's parent
    neighbours = [c for c in t.leaves(p) if c != focal_sample]
    if len(neighbours) == 0:
        return (0, 0)
    # Use the tree's own tree sequence rather than the global `ts_dated`
    ts = t.tree_sequence
    relative_populations = [ts.node(c).population for c in neighbours]
    # Proportion of neighbours in each of the two populations
    count_pop_0 = sum(1 for pop in relative_populations if pop == 0)
    count_pop_1 = len(relative_populations) - count_pop_0
    prop_pop_0 = count_pop_0 / len(relative_populations)
    prop_pop_1 = count_pop_1 / len(relative_populations)

    return prop_pop_0, prop_pop_1


# test
t = ts_dated.at(3.49e6)
focal_sample = 0

print(find_nearest_neighbour_populations(t, focal_sample))

In [60]:
def plot_interval(t, focal_sample, seq_length, ax):
    nns = find_nearest_neighbour_populations(t, focal_sample)
    i = t.interval
    ax.add_patch(plt.Rectangle((i[0]/seq_length, 0), (i[1] - i[0])/seq_length, nns[0], color="red"))
    # Stack the population 1 bar on top of the population 0 bar; its height is nns[1]
    ax.add_patch(plt.Rectangle((i[0]/seq_length, nns[0]), (i[1] - i[0])/seq_length, nns[1], color="blue"))


In [None]:
for i in ts_dated.individuals():
    print(i)

In [None]:

def plot_proportions_for_individual(tree_seq, individual_id):

    # Look up the two haplotypes (sample nodes) of this diploid individual
    node1, node2 = tree_seq.individual(individual_id).nodes

    seq_length = tree_seq.sequence_length

    # Plot it all
    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 2.5))
    # for ax in [ax1, ax2]:
        # ax.set_xticks([])
        # ax.set_xticklabels([])
    for t in tree_seq.trees(sample_lists=True):
        plot_interval(t, node1, seq_length, ax1) # CHANGE focal sample
    ax1.set_yticks([])
    ax1.set_yticklabels([])
    ax1.set_ylabel("Hap 1")
    # Repeat the process for the second focal node
    for t in tree_seq.trees(sample_lists=True):
        plot_interval(t, node2, seq_length, ax2) # CHANGE focal sample
    ax2.set_yticks([])
    ax2.set_yticklabels([])
    ax2.set_ylabel("Hap 2")

    # Add legend
    colors = ['red', 'blue']
    labels = ['Norway', 'France']
    handles = [plt.Rectangle((0, 0), 1, 1, color=color) for color in colors]
    ax2.legend(handles, labels, loc='upper right', bbox_to_anchor=(1.2, 1.5))

    # Add overall title
    fig.suptitle(f"Nearest neighbour populations, individual {individual_id}", fontsize=16)

    plt.show()


# test
plot_proportions_for_individual(ts_dated, 0)

In [None]:
plot_proportions_for_individual(ts_dated, 1)

In [None]:
plot_proportions_for_individual(ts_dated, 2)

In [None]:
plot_proportions_for_individual(ts_dated, 3)

In [None]:
plot_proportions_for_individual(ts_dated, 4)


In [None]:
plot_proportions_for_individual(ts_dated, 5)


In [None]:
plot_proportions_for_individual(ts_dated, 6)


In [None]:
plot_proportions_for_individual(ts_dated, 7)


In [None]:
plot_proportions_for_individual(ts_dated, 8)


In [None]:
plot_proportions_for_individual(ts_dated, 9)
