# Inference from a test VCF file

Try to infer a tree sequence (`tskit.TreeSequence`) object, as I did with the `nf-treeseq` pipeline, but using the
sample VCF provided in [sheepTSexample](https://github.com/HighlanderLab/sheepTSexample)

In [None]:
import cyvcf2
import tsinfer
import tsdate
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

from tskitetude import get_project_dir
from tskitetude.helper import get_chromosome_lengths, add_diploid_sites, create_windows

Configure the `tskitetude` module to use `tqdm.notebook` progress bars:

In [None]:
from tqdm.notebook import tqdm

import tskitetude.helper

tskitetude.helper.tqdm = tqdm
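Why the assignment above works: code inside a module looks up its global names, such as `tqdm`, at call time. Rebinding the module attribute therefore changes what the module's functions call. This assumes `tskitetude.helper` imports `tqdm` at module level and calls it by that name. I haven't checked this against the helper's source, so the sketch below uses a hypothetical stand-in module built with `types.ModuleType`.

```python
import types

# Hypothetical stand-in for `tskitetude.helper`: it binds a progress-bar factory
# to the module-level name `tqdm` and looks that name up each time `run` is called
helper = types.ModuleType("helper")
exec(
    "def tqdm(iterable):\n"
    "    return ('console', list(iterable))\n"
    "\n"
    "def run(items):\n"
    "    return tqdm(items)\n",
    helper.__dict__,
)

before = helper.run([1, 2])
saved = helper.tqdm  # a reference taken before patching, like `from helper import tqdm`

def notebook_tqdm(iterable):
    return ("notebook", list(iterable))

helper.tqdm = notebook_tqdm  # same pattern as `tskitetude.helper.tqdm = tqdm`
after = helper.run([1, 2])

print(before)  # ('console', [1, 2])
print(after)   # ('notebook', [1, 2])
print(saved([1, 2])[0])  # console: references taken before the patch are unaffected
```

A module that did `from tskitetude.helper import tqdm` before the patch would keep the old function. Run the patch before using the helpers.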

In [None]:
vcf_file = str(get_project_dir() / "experiments/test1M/test1M.out.inf.vcf.gz")
vcf = cyvcf2.VCF(vcf_file)
chromosome_lengths = get_chromosome_lengths(vcf)

# get the first variant, as in the helper script
variant = next(vcf)
sequence_length = chromosome_lengths[variant.CHROM]

print(f"Getting information for chromosome {variant.CHROM} with length {sequence_length} bp")
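`get_chromosome_lengths` comes from `tskitetude.helper` and isn't shown here. Below is a minimal sketch of one way to build the same `{chromosome: length}` mapping. It assumes the lengths come from the `##contig=<ID=...,length=...>` header lines, with `ID` listed before `length` as is usual. The function name and the header lines are hypothetical. cyvcf2 also exposes `vcf.seqnames` and `vcf.seqlens`, so `dict(zip(vcf.seqnames, vcf.seqlens))` may be enough, if your cyvcf2 version provides them and the header declares contig lengths.

```python
import re

# matches e.g. ##contig=<ID=26,length=44077779> (assumes ID comes before length)
CONTIG_RE = re.compile(r"##contig=<ID=([^,>]+),.*?length=(\d+)")

def chromosome_lengths_from_header(header_lines):
    """Return a {contig ID: length in bp} dict from VCF `##contig` header lines."""
    lengths = {}
    for line in header_lines:
        match = CONTIG_RE.match(line)
        if match:
            lengths[match.group(1)] = int(match.group(2))
    return lengths

# hypothetical header lines
header = [
    "##fileformat=VCFv4.2",
    "##contig=<ID=1,length=1000000>",
    "##contig=<ID=2,length=500000,assembly=test>",
    "#CHROM\tPOS\tID\tREF\tALT\tQUAL\tFILTER\tINFO\tFORMAT\ttsk_0",
]
print(chromosome_lengths_from_header(header))
```

With cyvcf2, the header text should be available as `vcf.raw_header`, so `chromosome_lengths_from_header(vcf.raw_header.splitlines())` would apply the same parser to the real file.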

Create a `SampleData` object and add population and individual information (as we did in the [tutorial](https://github.com/HighlanderLab/sheepTSexample/blob/main/Notebooks/10Inference.ipynb)):

In [None]:
# reset the VCF file
vcf = cyvcf2.VCF(vcf_file)

with tsinfer.SampleData(
        path=str(get_project_dir() / "experiments/test1M/test1M.out.inf.samples"),
        sequence_length=sequence_length) as samples:

    # add population information (optional)
    samples.add_population(metadata={"name": "Mouflon"})
    samples.add_population(metadata={"name": "Iranian"})
    samples.add_population(metadata={"name": "Border"})

    # add individuals (optional): 5 Mouflon, 50 Iranian and 50 Border,
    # assumed to be in the same order as the samples in the VCF
    popID = np.repeat([0, 1, 2], [5, 50, 50]).tolist()

    for i, pop_id in enumerate(popID):
        samples.add_individual(ploidy=2, population=pop_id, metadata={"name": f"tsk_{i}"})

    add_diploid_sites(vcf, samples, {}, allele_chars=set("01"))

print(
    f"Sample file created for {samples.num_samples} samples "
    f"({samples.num_individuals} individuals) "
    f"with {samples.num_sites} variable sites."
)
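Each individual is added with `ploidy=2`, so each one contributes two sample genomes. `num_samples` should therefore be `2 * num_individuals`, i.e. 210 for the 105 individuals above. `add_diploid_sites` is a `tskitetude` helper not shown here. Below is a minimal sketch of the flattening such a helper has to do, assuming cyvcf2-style `variant.genotypes` rows of the form `[allele_0, allele_1, phased]`, one per individual. `diploid_to_haploid` and the calls are hypothetical.

```python
def diploid_to_haploid(genotypes):
    """Flatten cyvcf2-style diploid calls into one allele index per genome.

    `genotypes` is a list of [allele_0, allele_1, phased] rows, one per
    individual, so the result has two entries per individual, in the same
    order in which the individuals were added."""
    haploid = []
    for allele_0, allele_1, _phased in genotypes:
        haploid.extend([allele_0, allele_1])
    return haploid

# hypothetical calls for 3 individuals at one site
calls = [[0, 1, True], [1, 1, True], [0, 0, True]]
print(diploid_to_haploid(calls))  # [0, 1, 1, 1, 0, 0]
```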

In [None]:
# Do the inference
inferred_ts = tsinfer.infer(
    samples,
    num_threads=4
)

# Simplify the tree sequence
ts = inferred_ts.simplify()

print(
    f"Inferred tree sequence `ts`: {ts.num_trees} "
    f"trees over {ts.sequence_length / 1e6} Mb"
)

In [None]:
ts

In [None]:
# Removes unary nodes (currently required in tsdate), keeps historical-only sites
ts = tsdate.preprocess_ts(ts, filter_sites=False)

ts = tsdate.date(
    ts,
    method="inside_outside",
    mutation_rate=1e-8,
    Ne=1e4
)

ts

Dump the *tree sequence* to a `.trees` file:

In [None]:
# save generated tree
ts.dump(get_project_dir() / "tests/test1M.out.inf.trees")

I want to retrieve the first tree, as I did in a previous example:

In [None]:
POS = 108
tree = ts.at(POS)
tree
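`ts.at(POS)` returns the tree whose half-open interval `[left, right)` contains `POS`. Conceptually this is a binary search over the tree breakpoints. Below is a minimal numpy sketch of that lookup with made-up breakpoints. It illustrates the idea and is not tskit's internal implementation. On the real data, `ts.breakpoints(as_array=True)` should give the equivalent array.

```python
import numpy as np

# hypothetical breakpoints: three trees over [0, 40), [40, 70) and [70, 100)
breakpoints = np.array([0, 40, 70, 100])

def tree_index_at(position, breakpoints):
    """Index of the tree whose half-open interval [left, right) contains `position`."""
    if not breakpoints[0] <= position < breakpoints[-1]:
        raise ValueError("position outside the sequence")
    # side="right" sends a position equal to a breakpoint to the tree starting there
    return int(np.searchsorted(breakpoints, position, side="right") - 1)

for pos in (0, 39.5, 40, 99):
    i = tree_index_at(pos, breakpoints)
    print(pos, i, (breakpoints[i], breakpoints[i + 1]))
```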

Now get the genomic interval of this tree, then try to select the edges that lie between those positions:

In [None]:
interval = tree.interval
left_bound = interval.left
right_bound = interval.right

# fetch the edge table once (accessing `ts.tables` can copy the whole table collection)
edges = ts.tables.edges
filtered_edges = edges[
    np.logical_and(edges.left >= left_bound, edges.right <= right_bound)]
filtered_edges[:10]

In [None]:
len(filtered_edges)

Why are there so few edges in this case? How can I draw a tree with so few edges? Maybe the table
itself doesn't model every connection between nodes, or the way I select edges
is completely wrong.
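A likely explanation: in tskit, an edge `(left, right, parent, child)` records a parent/child relationship over the whole interval `[left, right)`. That interval usually spans several adjacent trees. The condition above keeps only the edges that start *and* end inside this tree's interval, i.e. the edges that exist in this tree alone. Every edge of the tree actually *overlaps* its interval. The sketch below compares the two conditions with numpy on a small, made-up edge table (three samples, three trees).

```python
import numpy as np

# Hypothetical toy edge table: samples 0, 1, 2 and ancestors 3, 4,
# with trees over [0, 40), [40, 70) and [70, 100) of a 100 bp sequence
left = np.array([0, 70, 0, 40, 0, 70, 0, 40])
right = np.array([40, 100, 100, 70, 40, 100, 100, 70])
parent = np.array([3, 3, 3, 3, 4, 4, 4, 4])
child = np.array([0, 0, 1, 2, 2, 2, 3, 0])

# interval of the second tree
left_bound, right_bound = 40, 70

# edges fully contained in the interval: only those that start AND end at its breakpoints
contained = np.logical_and(left >= left_bound, right <= right_bound)

# edges overlapping the interval: all the edges that make up the tree
overlapping = np.logical_and(left < right_bound, right > left_bound)

print(f"contained edges: {contained.sum()}")    # 2
print(f"overlapping edges: {overlapping.sum()}")  # 4
for p, c in zip(parent[overlapping], child[overlapping]):
    print(f"{p} -> {c}")
```

On the real data, the same overlap mask, `(edges.left < right_bound) & (edges.right > left_bound)` applied to `ts.tables.edges`, should return one edge per non-root node of `tree`.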

Can I filter the nodes in the same way? The node table has no left and right
positions like the edge table. However, from the edge table I can derive which nodes are
*children* of which *parents*:

In [None]:
parents = set(filtered_edges.parent)
children = set(filtered_edges.child)

node_ids = parents.union(children)
print(f"Got {len(node_ids)} distinct nodes")
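The `(parent, child)` pairs are enough to rebuild a tree's topology without rendering an SVG. Below is a minimal sketch that finds the roots and leaves and prints the topology as a Newick string. The edge list is hypothetical (a 3-sample toy tree). On the real data, the pairs should come from every edge whose interval covers the tree, not only from the edges entirely contained in it.

```python
from collections import defaultdict

# hypothetical (parent, child) pairs of all the edges covering one tree interval
edge_pairs = [(3, 1), (3, 2), (4, 3), (4, 0)]

parent_of = {child: parent for parent, child in edge_pairs}
children_of = defaultdict(list)
for parent, child in edge_pairs:
    children_of[parent].append(child)

node_ids = set(parent_of) | set(children_of)
roots = sorted(node_ids - set(parent_of))     # never appear as a child
leaves = sorted(node_ids - set(children_of))  # never appear as a parent

def to_newick(node):
    """Newick string (topology only, no branch lengths) of the subtree below `node`."""
    kids = sorted(children_of.get(node, []))
    if not kids:
        return str(node)
    return "(" + ",".join(to_newick(kid) for kid in kids) + ")" + str(node)

print(f"roots: {roots}, leaves: {leaves}")
for root in roots:
    print(to_newick(root) + ";")  # (0,(1,2)3)4;
```

On a real `tskit.Tree`, `tree.parent(u)` and `tree.children(u)` give the same information directly. `tree.draw_text()` is a lighter alternative to `draw_svg`, although the output is still large with 210 samples.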


In [None]:
# computationally heavy for a tree with this many samples, so it's commented out
# tree.draw_svg(
#     size=(800, 400),
#     time_scale="log_time",
# )

## Exploring nucleotide diversity

Calculate diversity *per SNP position*: use the `create_windows` function and select the values at
all the odd indices:

In [None]:
# removing the 0 values is enough here!
ts_diversity = ts.diversity(windows=create_windows(ts))
ts_diversity = ts_diversity[ts_diversity > 0]
ts_diversity[:10]
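`create_windows` isn't shown here. Below is a minimal sketch of a window layout that would make the odd-indexed values the per-SNP ones. It assumes the helper builds a 1 bp window `[p, p + 1)` around each deduplicated site position, with gap windows in between. `site_windows` is a hypothetical stand-in, not the helper's actual code. tskit expects window breakpoints to start at 0, end at the sequence length and be strictly increasing, which the function checks.

```python
import numpy as np

def site_windows(positions, sequence_length):
    """Hypothetical stand-in for `create_windows`: a 1 bp window [p, p + 1)
    around each site, with gap windows in between.

    Assumes integer positions > 0, no two distinct sites at adjacent positions
    and last position + 1 < sequence_length; otherwise the breakpoints would
    not be strictly increasing."""
    positions = np.unique(positions)  # drops duplicated positions (and sorts)
    inner = np.column_stack([positions, positions + 1]).ravel()
    windows = np.concatenate([[0], inner, [sequence_length]])
    if np.any(np.diff(windows) <= 0):
        raise ValueError("window breakpoints must be strictly increasing")
    return windows

windows = site_windows([10, 25, 25, 60], 100)
print(windows)
# window i spans [windows[i], windows[i + 1]); the odd-indexed ones hold the sites
site_lefts = windows[1:-1:2]
print(site_lefts)
```

With this layout, the gap windows contain no sites and so have a site-mode diversity of exactly 0.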

Now let's compare with the nucleotide diversity calculated by vcftools. Here's the
command line to calculate nucleotide diversity *per site*:

```bash
cd experiments/test1M
vcftools --gzvcf test1M.out.inf.vcf.gz --out allsamples_pi --site-pi
```
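As I understand it, both tools compute per-site nucleotide diversity as the mean number of differences over all pairs of distinct haploid genomes. tskit documents `diversity` this way; I'm assuming vcftools `--site-pi` does the same. With 1 bp windows (assuming that's what `create_windows` builds), tskit's default `span_normalise=True` leaves the per-site value unchanged, so the two outputs should be directly comparable. The check below uses made-up genotypes. It shows that counting pairwise mismatches gives the same value as the allele-frequency formula `pi = n / (n - 1) * (1 - sum(p_i ** 2))`.

```python
import math
from itertools import combinations

def pi_pairwise(genotypes):
    """Mean number of differences over all pairs of distinct haploid genomes."""
    pairs = list(combinations(genotypes, 2))
    return sum(a != b for a, b in pairs) / len(pairs)

def pi_from_counts(counts):
    """Same quantity from allele counts: n / (n - 1) * (1 - sum(p_i ** 2))."""
    n = sum(counts)
    return n / (n - 1) * (1 - sum((c / n) ** 2 for c in counts))

# hypothetical site: 10 haploid genomes (5 diploid individuals), 3 carry the ALT allele
genotypes = [1, 1, 1, 0, 0, 0, 0, 0, 0, 0]
print(pi_pairwise(genotypes))  # 21 differing pairs out of 45
print(math.isclose(pi_pairwise(genotypes), pi_from_counts([7, 3])))  # True
```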

The output, `allsamples_pi.sites.pi`, is a *TSV* file with `CHROM`, `POS` and `PI` columns (the position and the per-site nucleotide diversity). Read it with pandas:

In [None]:
vcftools_diversity = pd.read_csv(get_project_dir() / "experiments/test1M/allsamples_pi.sites.pi", sep="\t")
vcftools_diversity.head()

In [None]:

print(f"ts_diversity is {len(ts_diversity)} in size")
print(f"vcftools_diversity is {len(vcftools_diversity)} in size")

This dataframe has more rows than `ts_diversity` because there are duplicated positions. Since I've
dropped duplicated positions in `create_windows`, I need to drop the duplicates in this dataframe too. However,
the diversity I can measure at those positions will be different:

In [None]:
vcftools_diversity.drop_duplicates(subset='POS', keep='first', inplace=True)

Are these values similar?

In [None]:
np.isclose(ts_diversity, vcftools_diversity["PI"], atol=1e-6).all()

Calculate diversity in *branch* mode:

In [None]:
# remove the 0 values, as done for site mode
ts_diversity_branch = ts.diversity(mode='branch', windows=create_windows(ts))
ts_diversity_branch = ts_diversity_branch[ts_diversity_branch > 0]
print(ts_diversity_branch[:10])
print(f"ts_diversity_branch is {len(ts_diversity_branch)} in size")

In [None]:
print(ts.diversity(mode='branch', windows=create_windows(ts))[:10])
print(create_windows(ts)[:10])

Try to plot the two different diversities against the vcftools output:

In [None]:
plt.scatter(ts_diversity, vcftools_diversity["PI"])

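The *branch* nucleotide diversity needs to be fixed: branch-mode statistics depend on the branch lengths of the trees rather than on the observed sites, so the gap windows between SNPs are not `0`, and removing the `0` values doesn't leave one value per SNP: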
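One possible fix is to select the odd-indexed windows instead of filtering out zeros. This assumes `create_windows` alternates gap windows and 1 bp SNP windows, as the odd indices above suggest. The per-window values below are made up. The sketch compares both selections in site and branch mode; note that filtering zeros also drops a SNP whose site diversity is genuinely 0.

```python
import numpy as np

# hypothetical per-window output for 3 SNPs: windows alternate gap, SNP, gap, SNP, ...
# site mode: gap windows contain no sites, so they are exactly 0
site_mode = np.array([0.0, 0.42, 0.0, 0.0, 0.0, 0.18, 0.0])
# branch mode: every window spans some branch length, so gaps are non-zero too
branch_mode = np.array([1.9, 2.1, 2.0, 1.7, 1.8, 2.2, 2.0])

# filtering zeros: drops the SNP with zero site diversity and keeps every branch gap
site_nonzero = site_mode[site_mode > 0]
branch_nonzero = branch_mode[branch_mode > 0]

# odd-indexed windows: exactly one value per SNP in both modes
site_per_snp = site_mode[1::2]
branch_per_snp = branch_mode[1::2]

print(len(site_nonzero), len(branch_nonzero))  # 2 7
print(len(site_per_snp), len(branch_per_snp))  # 3 3
```

On the real data, under the same layout assumption, this would become `ts.diversity(mode="branch", windows=create_windows(ts))[1::2]`, which should have one value per deduplicated SNP.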

In [None]:
# plt.scatter(ts_diversity_branch, vcftools_diversity["PI"])
# plt.xlim(0, 300)