# TSINFER tutorial
## Toy example
Suppose we have phased haplotype data for five samples at six sites, like this:

```text
sample  haplotype
0       AGCGAT
1       TGACAG
2       AGACAC
3       ACCGCT
4       ACCGCT
```

Before inferring a tree sequence that models these data, I need to import the data
with `tsinfer`; this requires knowing the ancestral allele at each site first:

In [None]:
import string
import numpy as np
import tsinfer
import cyvcf2
import json
import tsdate

from tqdm.notebook import tqdm
from tskit import MISSING_DATA

from tskitetude import get_data_dir

In [None]:
with tsinfer.SampleData(sequence_length=6) as sample_data:
    sample_data.add_site(0, [0, 1, 0, 0, 0], ["A", "T"], ancestral_allele=0)
    sample_data.add_site(1, [0, 0, 0, 1, 1], ["G", "C"], ancestral_allele=0)
    sample_data.add_site(2, [0, 1, 1, 0, 0], ["C", "A"], ancestral_allele=0)
    sample_data.add_site(3, [0, 1, 1, 0, 0], ["G", "C"], ancestral_allele=MISSING_DATA)
    sample_data.add_site(4, [0, 0, 0, 1, 1], ["A", "C"], ancestral_allele=0)
    sample_data.add_site(5, [0, 1, 2, 0, 0], ["T", "G", "C"], ancestral_allele=0)

`tsinfer.SampleData` is the object required for inferring a tree sequence. With
the `add_site()` method I can add the information for each SNP in turn. The first
argument is the *SNP position*: here, for simplicity, SNPs are tracked in positional
order, but the position can be any non-negative value (even a float) smaller than the
sequence length. The only requirement is that positions are unique and added in
increasing order. The second argument holds the *genotypes* of each sample at this
position: each value is an index into the list of alleles given as the third argument.
For missing data I need to use the `tskit.MISSING_DATA` value.
The last argument is the index of the ancestral allele. Not all sites are used
to infer the tree sequence: sites whose ancestral allele is missing, or sites with
more than two alleles, are not considered for inference but will still be modeled in
the resulting tree sequence.
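
To make the encoding concrete, the sketch below rebuilds the `genotypes` and `alleles`
arguments from the haplotype table using plain Python only. The `encode_sites` helper
and the `"?"` placeholder for an unknown ancestral state are hypothetical names chosen
for illustration; they are not part of `tsinfer`.

```python
# Toy data from the table above: one string per sample, one column per site.
haplotypes = ["AGCGAT", "TGACAG", "AGACAC", "ACCGCT", "ACCGCT"]
# Assumed ancestral state per site; "?" marks an unknown one (site 3).
ancestral_states = "AGC?AT"
MISSING = -1  # same value as tskit.MISSING_DATA


def encode_sites(haplotypes, ancestral_states):
    """Yield (position, genotypes, alleles, ancestral_allele) for each site."""
    for pos, ancestral in enumerate(ancestral_states):
        column = [h[pos] for h in haplotypes]
        # List the ancestral allele first when it is known, then the other
        # alleles in order of first appearance.
        alleles = [] if ancestral == "?" else [ancestral]
        for base in column:
            if base not in alleles:
                alleles.append(base)
        genotypes = [alleles.index(base) for base in column]
        ancestral_allele = MISSING if ancestral == "?" else 0
        yield pos, genotypes, alleles, ancestral_allele


encoded = list(encode_sites(haplotypes, ancestral_states))
for pos, genotypes, alleles, ancestral_allele in encoded:
    print(pos, genotypes, alleles, ancestral_allele)
```

Each printed row corresponds to one `add_site()` call in the cell above.
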
Once we have the `SampleData` instance, we can infer a tree sequence using
`tsinfer.infer`:

In [None]:
ts = tsinfer.infer(sample_data)

This `ts` object is a full *Tree Sequence* object:

In [None]:
ts

This *Tree sequence* object can be analyzed as usual:

In [None]:
print("==Haplotypes==")
for sample_id, h in enumerate(ts.haplotypes()):
    print(sample_id, h, sep="\t")
ts.draw_svg(y_axis=True)

If I understand correctly, `tsinfer` can impute missing data (this still needs checking).
For the data I entered, there is a *root* node with three *children*: this is known as
a *polytomy*. Every *internal* node represents an ancestral sequence. By default, the
time of those nodes is not measured in years or generations; it is the frequency of the
shared derived alleles on which the ancestral sequence is based. This is why the time is
*uncalibrated* in the plot above.
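
As a rough illustration of what "frequency of derived alleles" means, the sketch below
computes it by hand for the toy sites. It is only meant to show the quantity itself,
not how `tsinfer` assigns node times internally; `derived_frequency` is a hypothetical
helper, and the genotype lists are copied from the `add_site()` calls above.

```python
# Genotypes of the toy sites whose ancestral allele is known (index 0).
# Site 3 is left out because its ancestral allele is missing.
genotypes_by_site = {
    0: [0, 1, 0, 0, 0],
    1: [0, 0, 0, 1, 1],
    2: [0, 1, 1, 0, 0],
    4: [0, 0, 0, 1, 1],
    5: [0, 1, 2, 0, 0],
}


def derived_frequency(genotypes, ancestral_allele=0):
    """Fraction of samples that carry any non-ancestral allele."""
    derived = sum(1 for g in genotypes if g != ancestral_allele)
    return derived / len(genotypes)


frequencies = {pos: derived_frequency(g) for pos, g in genotypes_by_site.items()}
for pos, freq in frequencies.items():
    print(f"site {pos}: derived allele frequency = {freq:.1f}")
```

Sites 1 and 4 carry the same derived-allele pattern (samples 3 and 4), which is the
kind of shared signal the text above refers to.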

In [None]:
# Extra code to label and order the tips alphabetically rather than numerically
labels = {i: string.ascii_lowercase[i] for i in range(ts.num_nodes)}
genome_order = [n for n in ts.first().nodes(order="minlex_postorder") if ts.node(n).is_sample()]
labels.update({n: labels[i] for i, n in enumerate(genome_order)})
style1 = (
    ".node:not(.sample) > .sym, .node:not(.sample) > .lab {visibility: hidden;}"
    ".mut {font-size: 12px} .y-axis .tick .lab {font-size: 85%}")
sz = (800, 250)  # size of the plot, slightly larger than the default

# ticks = [0, 5000, 10000, 15000, 20000]
# Use the time of the oldest node as the upper limit of the axis
max_time = max(node.time for node in ts.nodes())
ticks = np.linspace(0, max_time, 5)
ts.draw_svg(
    size=sz, node_labels=labels, style=style1, y_label="Time ago (uncalibrated)",
    y_axis=True, y_ticks=ticks)

## Inferring dates

To infer *dates* we can use `tsdate.date` with its default parameters, specifying
only the *effective population size* and the *mutation rate*:

In [None]:
dated_ts = tsdate.date(ts, mutation_rate=1e-8, Ne=1e4)
dated_ts

In [None]:
dated_ts.draw_svg(y_axis=True, size=(800, 250))

## Data example

This section follows the [Data example](https://tskit.dev/tsinfer/docs/stable/tutorial.html#data-example)
part of the `tsinfer` tutorial:

In [None]:
def add_diploid_sites(vcf, samples):
    """
    Read the sites in the vcf and add them to the samples object.
    """
    # You may want to change the following line, e.g. here we allow
    # "*" (a spanning deletion) to be a valid allele state
    allele_chars = set("ATGCatgc*")
    pos = 0
    progressbar = tqdm(total=samples.sequence_length, desc="Read VCF", unit='bp')

    for variant in vcf:  # Loop over variants, each assumed at a unique site
        progressbar.update(variant.POS - pos)

        if pos == variant.POS:
            print(f"Duplicate entries at position {pos}, ignoring all but the first")
            continue

        else:
            pos = variant.POS

        if any(not phased for _, _, phased in variant.genotypes):
            raise ValueError(f"Unphased genotypes for variant at position {pos}")

        alleles = [variant.REF.upper()] + [v.upper() for v in variant.ALT]
        ancestral = variant.INFO.get("AA", ".")  # "." means unknown

        # some VCFs (e.g. from 1000G) have many values in the AA field: take the 1st
        ancestral = ancestral.split("|")[0].upper()

        if ancestral == "." or ancestral == "":
            ancestral_allele = MISSING_DATA
            # alternatively, you could specify `ancestral = variant.REF.upper()`

        else:
            ancestral_allele = alleles.index(ancestral)

        # Check we have ATCG alleles; skip the whole site otherwise
        bad_alleles = [a for a in alleles if set(a) - allele_chars]
        if bad_alleles:
            print(f"Ignoring site at pos {pos}: alleles {bad_alleles} not in {allele_chars}")
            continue

        # Flatten the per-sample [allele_1, allele_2, phased] rows into one
        # list holding two allele indexes per diploid sample.
        genotypes = [g for row in variant.genotypes for g in row[0:2]]
        samples.add_site(pos, genotypes, alleles, ancestral_allele=ancestral_allele)

    progressbar.close()


def chromosome_length(vcf):
    assert len(vcf.seqlens) == 1
    return vcf.seqlens[0]


# NB: could also read from an online version by setting vcf_location to
# "https://github.com/tskit-dev/tsinfer/raw/main/docs/_static/P_dom_chr24_phased.vcf.gz"
vcf_location = get_data_dir() / "P_dom_chr24_phased.vcf.gz"
samples_location = get_data_dir() / "P_dom_chr24_phased.samples"

vcf = cyvcf2.VCF(vcf_location)

with tsinfer.SampleData(
    path=str(samples_location), sequence_length=chromosome_length(vcf)
) as samples:
    add_diploid_sites(vcf, samples)

print(
    "Sample file created for {} samples ".format(samples.num_samples)
    + "({} individuals) ".format(samples.num_individuals)
    + "with {} variable sites.".format(samples.num_sites),
    flush=True,
)

# Do the inference
ts = tsinfer.infer(samples)
print(
    "Inferred tree sequence: {} trees over {} Mb ({} edges)".format(
        ts.num_trees, ts.sequence_length / 1e6, ts.num_edges
    )
)
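
To make the genotype and ancestral-allele handling in `add_diploid_sites` concrete,
here is a small sketch that runs on hand-written data, without a VCF file. It assumes
the diploid `[allele_1, allele_2, phased]` row layout that the function above unpacks
from `variant.genotypes`; `flatten_genotypes` and `ancestral_index` are hypothetical
helper names, and the rows are made up for illustration.

```python
MISSING = -1  # same value as tskit.MISSING_DATA


def flatten_genotypes(rows):
    """Turn diploid rows [allele_1, allele_2, phased] into one flat list."""
    if any(not phased for _, _, phased in rows):
        raise ValueError("Unphased genotypes found")
    return [allele for row in rows for allele in row[0:2]]


def ancestral_index(alleles, aa_field):
    """Return the index of the ancestral allele, or MISSING if unknown."""
    # Keep only the first value of a "|"-separated AA field.
    ancestral = aa_field.split("|")[0].upper()
    if ancestral in (".", ""):
        return MISSING
    return alleles.index(ancestral)


# Hand-written rows for three hypothetical diploid samples.
rows = [[0, 1, True], [1, 1, True], [0, 0, True]]
flat = flatten_genotypes(rows)
print(flat)  # two allele indexes per sample
print(ancestral_index(["A", "G"], "g|||"))  # index of "G"
print(ancestral_index(["A", "G"], "."))  # unknown ancestral allele
```

Three diploid samples therefore contribute six genotype values to `add_site()`, which
is why each chromosome is treated as its own sample unless individuals are declared
explicitly.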

There is also a parallel version of `add_diploid_sites` [here](https://github.com/tskit-dev/tsinfer/issues/277#issuecomment-652024871).
So far I have added 20 separate individuals (each with a single chromosome) instead
of 10 diploid individuals. With a few changes I can assign two chromosomes
to each individual, and even add other metadata to the tree sequence:

In [None]:
def add_populations(vcf, samples):
    """
    Add tsinfer Population objects and return a list of IDs corresponding to the VCF samples.
    """

    # In this VCF, the first letter of the sample name refers to the population
    samples_first_letter = [sample_name[0] for sample_name in vcf.samples]

    pop_lookup = {}
    pop_lookup["8"] = samples.add_population(metadata={"country": "Norway"})
    pop_lookup["F"] = samples.add_population(metadata={"country": "France"})

    return [pop_lookup[first_letter] for first_letter in samples_first_letter]


def add_diploid_individuals(vcf, samples, populations):
    for name, population in zip(vcf.samples, populations):
        samples.add_individual(ploidy=2, metadata={"name": name}, population=population)


# Repeat as previously but add both populations and individuals
vcf_location = get_data_dir() / "P_dom_chr24_phased.vcf.gz"
samples_location = get_data_dir() / "P_dom_chr24_phased.samples"

vcf = cyvcf2.VCF(vcf_location)
with tsinfer.SampleData(
    path=str(samples_location), sequence_length=chromosome_length(vcf)
) as samples:
    populations = add_populations(vcf, samples)
    add_diploid_individuals(vcf, samples, populations)
    add_diploid_sites(vcf, samples)

print(
    "Sample file created for {} samples ".format(samples.num_samples)
    + "({} individuals) ".format(samples.num_individuals)
    + "with {} variable sites.".format(samples.num_sites),
    flush=True,
)

# Do the inference
sparrow_ts = tsinfer.infer(samples)

print(
    "Inferred tree sequence `{}`: {} trees over {} Mb".format(
        "sparrow_ts", sparrow_ts.num_trees, sparrow_ts.sequence_length / 1e6
    )
)
# Check the metadata
for sample_node_id in sparrow_ts.samples():
    individual_id = sparrow_ts.node(sample_node_id).individual
    population_id = sparrow_ts.node(sample_node_id).population
    print(
        "Node",
        sample_node_id,
        "labels a chr24 sampled from individual",
        json.loads(sparrow_ts.individual(individual_id).metadata),
        "in",
        json.loads(sparrow_ts.population(population_id).metadata)["country"],
    )
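
The population lookup and the metadata decoding used above can be tried in isolation.
The sketch below is a minimal illustration under stated assumptions: the sample names
are invented, the integer IDs stand in for the values returned by
`samples.add_population()`, and the metadata is assumed to be stored as JSON-encoded
bytes, as the `json.loads` calls in the cell above imply.

```python
import json

# Hypothetical sample names; as in the VCF, the first character encodes
# the population ("8" for Norway, "F" for France).
sample_names = ["8934547", "8L19766", "FR041", "FR050"]

# Stand-ins for the IDs returned by samples.add_population().
pop_lookup = {"8": 0, "F": 1}
# Population metadata, assumed to be JSON-encoded bytes.
population_metadata = [b'{"country": "Norway"}', b'{"country": "France"}']

# One population ID per VCF sample, in the same order as the sample names.
populations = [pop_lookup[name[0]] for name in sample_names]
countries = [json.loads(population_metadata[p])["country"] for p in populations]

for name, country in zip(sample_names, countries):
    print(name, "->", country)
```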

## Analysis

Analyses can now be done with the `tskit` library. I can't show the full *tree sequence*
for this object, but I can focus on the tree covering a single position:

In [None]:
colours = {"Norway": "red", "France": "blue"}
colours_for_node = {}

for n in sparrow_ts.samples():
    population_data = sparrow_ts.population(sparrow_ts.node(n).population)
    colours_for_node[n] = colours[json.loads(population_data.metadata)["country"]]

individual_for_node = {}
for n in sparrow_ts.samples():
    individual_data = sparrow_ts.individual(sparrow_ts.node(n).individual)
    individual_for_node[n] = json.loads(individual_data.metadata)["name"]

tree = sparrow_ts.at(1e6)
tree.draw(
    height=700,
    width=1200,
    node_labels=individual_for_node,
    node_colours=colours_for_node,
)