## Setup

To access material for this workbook please execute the two notebook cells immediately below (e.g. use the shortcut <b>&lt;shift&gt;+&lt;return&gt;</b>). The first cell can be skipped if you are running this notebook locally and have already installed all the necessary packages. The second cell should print out "Your notebook is ready to go!"

In [None]:
if 'pyodide_kernel' in str(get_ipython()):  # specify packages to install under JupyterLite
    %pip install -q -r jlite-requirements.txt
elif 'google.colab' in str(get_ipython()):  # specify package location for loading in Colab
    from google.colab import drive
    drive.mount('/content/drive')
    %run /content/drive/MyDrive/GARG_workshop/Notebooks/add_module_path.py
else:  # install packages on your local machine (-q = "quiet": don't print out installation steps)
    # (NB: you can probably ignore any message about restarting the kernel)
    !pip install -q -r https://github.com/ebp-nor/GARG/raw/main/jlite/requirements.txt

In [None]:
# Load questions etc for this workbook
import ARG_workshop
workbook = ARG_workshop.Workbook1A()
display(workbook.setup)

### Using this workbook

This workbook is intended to be used by executing each cell as you go along. Code cells (like those above) can be modified and re-executed to perform different behaviour or additional analysis. You can use this to complete various programming exercises, some of which have associated questions to test your understanding. Exercises are marked like this:
<dl class="exercise"><dt>Exercise XXX</dt>
<dd>Here is an exercise: normally there will be a code cell below this box for you to work in</dd>
</dl>

# Workbook 1-A: _Tskit_ and genetic genealogies

A genealogy is a network of relationships. Where the genealogy describes direct links between a child individual and its parents, we call it a **pedigree**. Where the genealogy describes sections of DNA transmitted between a parent (or ancestral) genome and a child (or descendant) genome, we call it a **genetic genealogy** or (loosely) an ancestral recombination graph or **ARG**.

_Tskit_ is a library for storing genetic genealogies and genetic data in the form of
[succinct tree sequences](https://tskit.dev/tutorials/what_is.html). Tree sequences can be created by many programs, e.g.
simulated using _msprime_ or _SLiM_, or inferred from genetic data.

_Tskit_ is designed for use with very large genetic genealogies, potentially millions of genomes. Deducing an exact history in this case is probably impossible. Instead, analyses often focus on sampling likely genealogies rather than building a single correct one. This leads to a greater focus on methods such as simulation, compared to fields such as phylogenetics.

In this workbook you'll learn the data structures captured by _tskit_ through building your own genealogical simulator. We'll use the simulator as a base to investigate the principles behind ARGs. For simplicity, in this workbook and the next, we'll temporarily ignore recombination, and therefore all regions of the genome will be inherited from the same parent.

## A simple forward-time simulator

We will build a basic forward-time simulator from scratch, storing the resulting genealogy in _tskit_ format. We will use the Wright-Fisher approach, where one generation reproduces, and is then entirely replaced by the resulting children. For the moment we won't model recombination, so the resulting genealogy should take the form of a simple tree.

We will simulate diploids, so that pairs of genomes (nodes) are grouped into a single individual. For an even simpler example, a 20-line haploid _tskit_ simulator, see [this tutorial](https://tskit.dev/tutorials/completing_forward_sims.html).

### Making a tree sequence

You can create an empty set of _tskit_ tables by calling `tskit.TableCollection()`.

In [None]:
import tskit
tables = tskit.TableCollection(sequence_length=100)

<dl class="exercise"><dt>Exercise 1</dt>
<dd>Modify the code below to also print out the number of rows in the <code>individuals</code> table, and check it is zero</dd></dl>

In [None]:
# Exercise 1: modify me
print(
    f"The `tables` object currently has {tables.nodes.num_rows} nodes",
    f"and {tables.edges.num_rows} edges",
)

In [None]:
# Execute code block with <shift>+Return to display question; press on one of the buttons to answer
workbook.question("tables")


#### Adding nodes and individuals

You can add to the tables using the `add_row` methods, which return a numeric _tskit_ ID for future use (the ID is simply the row number in that table). Numerical IDs are of core importance in _tskit_: nearly all objects are referred to by their ID, so it's worth getting used to them.

For example if we add a new diploid individual, we can use the returned ID when we then create the maternal and paternal genomes (nodes) for that individual.

<div class="alert alert-block alert-info"><b>Note:</b> <em>Tskit</em> IDs start from 0 (not 1). Using 0-based indexing is the norm in Python, but can cause confusion for those who are more used to R.
</div>

In [None]:
for _ in range(8):
    individual_id = tables.individuals.add_row()
    maternal_genome_id = tables.nodes.add_row(individual=individual_id)
    paternal_genome_id = tables.nodes.add_row(individual=individual_id)
    print(
        f"Created a new individual (ID {individual_id}) "
        f"containing nodes {maternal_genome_id} and {paternal_genome_id}"
    )

It's easy to get confused between the IDs in the *individuals* table (which count diploid individuals) and the IDs in the *nodes* table (which count haploid genomes). Each node can (optionally) be associated with an individual, by specifying the ID of that individual in a separate column. The individual ID differs from the ID of the node itself.

<div class="alert alert-block alert-info"><b>Note:</b> In the table below, an ID of <code>-1</code> is used to denote "NULL": the presence of <code>-1</code> in the "population" column means that none of these nodes have been assigned to a specific <em>tskit</em> <a href="https://tskit.dev/tskit/docs/stable/data-model.html#population-table">population</a>.
</div>

<dl class="exercise"><dt>Exercise 2</dt>
<dd>Use <code>display(tables.nodes)</code> to show the nodes table you have just created</dd></dl>

In [None]:
# Exercise: print the nodes table
tables.nodes

In [None]:
workbook.question("nodes_table")

#### Creating a population

We'll wrap the code above in a function called `initialize` that places a new population of $N_e$ diploids into a set of tables. We also need two extra features:

1.   All nodes need a _time_. Above, nodes took the default time of 0. In the simulation, we must specify a time, starting at a fixed number of generations ago, counting down until the youngest nodes are created at time 0.
2.   We'll temporarily store the individual and node IDs in a python dictionary, mapping `individual_ID: (maternal_genome, paternal_genome)`. This will be used in the next generation, when choosing which genomes to inherit.

In [None]:
def initialize(tables, diploid_population_size, start_time):
    """
    Save a population to the tskit_tables and return a Python dictionary
    mapping the newly created individual ids to a pair of genomes (node ids)
    """
    temp_pop = {}  # make an empty dictionary
    for _ in range(diploid_population_size):
        # store in the TSKIT tables
        i = tables.individuals.add_row()
        maternal_node = tables.nodes.add_row(time=start_time, individual=i)
        paternal_node = tables.nodes.add_row(time=start_time, individual=i)

        # Add to the dictionary: map the individual ID to the two node IDs
        temp_pop[i] = (maternal_node, paternal_node)
    return temp_pop

In [None]:
# Test it out
tables = tskit.TableCollection(sequence_length=1000)  # Get empty tables
time = 2
current_pop = initialize(tables, diploid_population_size=8, start_time=time)
print(f"Population at time {time}:\n {current_pop}")

For this workbook, we've created a function, `basic_genealogy_viz`, that plots the nodes as a line of dots at a particular time, grouped by individual. Try it now: you should see 8 pairs of blue nodes, at the correct time, grouped into hexagons (representing individuals).

In [None]:
ARG_workshop.basic_genealogy_viz(tables, show_individuals=True)

#### Reproduction

To complete the simulation, we need a way to create children. We simply make an entirely new population of the same size, choosing parents at random from the previous population. This is done in the `reproduce` function below, which takes a previous population (such as that returned by `initialize`) and returns a new one.

The only new things here are the choice of random parents, and the `add_edges()` function. This is given a randomly shuffled pair of genomes from each parent, and simply passes on the first (we're ignoring recombination here).
The source of randomness is a random number generator ("rng") provided by the `numpy` library. We'll use this, for example, to pick two different individuals from the previous generation to act as the "mum" and "dad" of each new individual.
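
To make the two random steps concrete, here is a small standalone sketch using only `numpy`. The individual and node IDs are made up for illustration; the `reproduce` function takes them from the previous population instead.

```python
import numpy as np

rng = np.random.default_rng(7)  # a seeded generator gives reproducible results
individual_ids = [0, 1, 2, 3]  # hypothetical IDs of the previous generation

# Pick two different individuals to act as the parents of one child
mum, dad = rng.choice(individual_ids, size=2, replace=False)

# Shuffle the two genomes of one parent: the first in the shuffled order is passed on
genomes = (6, 7)  # hypothetical node IDs belonging to one parent
shuffled = rng.permuted(genomes)
print(f"Parents: {mum} and {dad}; genome passed on: {shuffled[0]}")
```

Because `replace=False`, the same individual is never picked as both mum and dad.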

In [None]:
import numpy as np

def reproduce(tables, previous_pop, current_time, rng):
    temp_pop = {}
    prev_individual_ids = list(previous_pop.keys())
    for _ in range(len(previous_pop)):
        mum, dad = rng.choice(prev_individual_ids, size=2, replace=False)
        # Same code as before, to make a new population
        i = tables.individuals.add_row()
        maternal_node = tables.nodes.add_row(time=current_time, individual=i)
        paternal_node = tables.nodes.add_row(time=current_time, individual=i)
        temp_pop[i] = (maternal_node, paternal_node)

        # Now add inheritance paths to the edges table, ignoring recombination
        add_edges(tables, rng.permuted(previous_pop[mum]), maternal_node)  # Maternal genome
        add_edges(tables, rng.permuted(previous_pop[dad]), paternal_node)  # Paternal genome
    return temp_pop

def add_edges(tables, randomly_ordered_parent_nodes, child_node):
    # A trivial inheritance scheme: pass on a single genome from the parent (arbitrarily the first)
    parent_node = randomly_ordered_parent_nodes[0]
    tables.edges.add_row(parent=parent_node, child=child_node, left=0, right=tables.sequence_length)

# Try it out, by reproducing from the previously created population
time = time - 1
current_pop = reproduce(tables, current_pop, current_time=time, rng=np.random.default_rng(7))
ARG_workshop.basic_genealogy_viz(tables, show_individuals=True)

You can see that some of the genomes in the initial generation (at time 2) have reproduced to make children in the next generation (at time 1). The `basic_genealogy_viz` function highlights these parents in orange.

You could add more generations simply by rerunning the cell above. But it's more useful to place the `initialize` and `reproduce` functions into a single function that performs the entire simulation:

In [None]:
def simulate_WrightFisher(population_size, generations, sequence_length=1000, random_seed=8):
    rng = np.random.default_rng(seed=random_seed)
    tables = tskit.TableCollection(sequence_length=sequence_length)
    tables.time_units = "generations"

    current_population = initialize(tables, population_size, start_time=generations)
    while generations > 0:
        generations = generations - 1
        current_population = reproduce(tables, current_population, generations, rng)

    tables.sort()  # Sort edges into canonical order, required for converting to a tree seq
    return tables.tree_sequence()

And that's it. You have built a forward-time Wright-Fisher genealogical simulator!

Notice that instead of returning the raw tables, we convert them into a (read-only) tree sequence. This verifies that the tables represent a sensible genealogy, e.g. that parents are always older than their children.

## The tree sequence object

Let's run our simulator for a few generations to make a tree sequence object. We'll adopt the convention of storing the tree sequence in a variable called `ts` where possible.

In [None]:
# save the tree sequence in a variable called `ts`
ts = simulate_WrightFisher(population_size=8, generations=2)

Below is what the result of this simple 2-generation simulation looks like. The hexagons have been omitted for simplicity. On the right is a plot where we are deliberately not grouping by individual, but instead reordering the horizontal position of the nodes to make the genealogy clearer.

In [None]:
from matplotlib import pyplot as plt
fig, (ax_lft, ax_rgt) = plt.subplots(1, 2, figsize=(12, 4))  # set up side-by-side plots

ARG_workshop.basic_genealogy_viz(ts, ax_lft, title="Nodes grouped by individual")
ARG_workshop.basic_genealogy_viz(ts, ax_rgt, show_individuals=False, title="Nodes repositioned")

Notice that the blue individuals in older generations have not passed on their genomes. In fact, some of the orange genomes (like 7 and 11) are also "dead ends": they don't have lineages that make it to the current day. We can use tree sequence methods to remove these.

### Basic tree sequence methods

Now that we have a simple tree sequence, we can learn how to manipulate or extract information from it.

#### A summary of the tree sequence

The Jupyter notebook `display()` function shows a tabular summary of the entire tree sequence in a notebook cell:

In [None]:
display(ts)  # by default, display() is run on the last output of a cell, so you could just call `ts` here

#### Tree sequence objects

The _Tskit_ Python interface wraps each table row (i.e. a tree sequence node, edge, individual, or whatever) in a convenient object. You can loop through all the node objects using `ts.nodes()`, or get one of them using `ts.node()`; the same goes for edges, individuals, and so on.

In [None]:
print(f"First four of {ts.num_nodes} nodes:")
for nd in ts.nodes():  # Iterate over nodes using ts.nodes()
  print(nd)
  if nd.id == 3:  # stop after printing the first four nodes (IDs 0 to 3)
    print("...")
    break

print(f"The first node is at time={ts.node(0).time}")

<dl class="exercise"><dt>Exercise 3</dt>
<dd>If you want <em>all</em> the node times, you could loop using the <code>nodes()</code> iterator, but the
convenience property <code>ts.nodes_time</code> gives you direct memory access to an array of times, which is much faster. Check the code below gives the expected answer, then also print the <code>edges_child</code> value instead:</dd></dl>

In [None]:
# Exercise: print the child for all edges as well as the time for all nodes
print(ts.nodes_time)

In [None]:
workbook.question("array_access")

### Tree sequence simplification

We saw that the forward-simulator created "dead end" genomes. These can be removed by the _tskit_ `simplify()` method, which marks a set of nodes as "samples" and removes information that is irrelevant to those samples. In general, a node can be thought of as a "sample" if we wish to keep its full genomic ancestry.

#### Simplifying edges

First we'll focus only on the edges. We'll specify the nodes at time zero as samples, modifying and removing edges that are not ancestral to those sample nodes. We call this **sample resolving**.

The code below runs the simulation for a longer time, and then sample-resolves the resulting tree sequence. We tell the `.simplify()` method to do the minimum required, just removing redundant lineages, by giving it the `keep_unary=True` and `filter_nodes=False` options. Following _tskit_ convention, sample nodes are drawn as squares in the right hand plot.

In [None]:
fig, (ax_lft, ax_rgt) = plt.subplots(1, 2, figsize=(10, 8))  # set up side-by-side plots

base_ts = simulate_WrightFisher(10, 30, random_seed=8)

# Remove non-ancestral lineages ("sample resolve") using `simplify`
current_gen_IDs = np.flatnonzero(base_ts.nodes_time == 0)
ts = base_ts.simplify(samples=current_gen_IDs, keep_unary=True, filter_nodes=False)

ARG_workshop.basic_genealogy_viz(
    base_ts, ax_lft, show_node_ids=False, show_individuals=False, title="Base simulation")
ARG_workshop.basic_genealogy_viz(
    ts, ax_rgt, show_node_ids=False, show_individuals=False, title="Sample resolved")
plt.show()
print(f"The 'sample resolved' tree sequence has {ts.num_nodes} nodes")
print(f"The following nodes at time 0 are samples: {ts.samples()}")


#### Simplifying nodes

Sample-resolving reveals a large number of redundant nodes that are not ancestral to the samples (blue circles in the right hand plot). These were kept because we specified `filter_nodes=False` in the `simplify()` command. Normally, as well as sample-resolving, `simplify` removes these unreferenced nodes. This makes the tree sequence much smaller, although the node IDs will change (and `simplify` will reorder the nodes to put the sample nodes first).

In [None]:
ts = base_ts.simplify(samples=current_gen_IDs, keep_unary=True)
print(f"Simplify removed {base_ts.num_nodes - ts.num_nodes} of {base_ts.num_nodes} nodes")
print(f"Sample node IDs are now {ts.samples()}")

Let's look at the resulting genealogy using the [draw_svg()](https://tskit.dev/tutorials/viz.html#svg-format) method (to display this, Colab notebooks require the result to be wrapped in an `SVG()` call).

In [None]:
from IPython.display import SVG
SVG(ts.draw_svg(size=(600, 500), y_axis=True))

#### Full simplification

The ancestry above forms a clear tree. However, some of the ancestral nodes have only one child in this tree: genetic information has passed through them unchanged (an example of such a "pass-through" node is the one labelled 21). At the cost of losing the chain of direct parent-to-child links, it's possible to replace the pair of edges 0→21→36 with a single edge directly from 0→36. This is what happens when we simplify without `keep_unary=True`. The result is a much more compact but essentially equivalent tree that only shows nodes representing branch points, i.e. those associated with coalescence. In future practicals, we'll see that the process of simplifying a recombinant genealogy (i.e. a graph rather than a tree) is rather more nuanced.

<div class="alert alert-block alert-info"><b>Note:</b> Replacing the edge from <b>node0</b> (time=0) → <b>node21</b> (time=1) → <b>node36</b> (time=2) will create a new edge whose <code>child</code> is node 0 and whose <code>parent</code> is at time 2. Therefore the "child" and "parent" of an edge are meant in a mathematical sense, rather than referring to individuals separated by one generation; the node identified as an edge's <code>parent</code> can be many generations older than the node identified as the <code>child</code>.
</div>

<dl class="exercise"><dt>Exercise 4</dt>
<dd>Fully simplify the <code>base_ts</code> specifying the genomes at time 0 as samples, and plot it out using <code>draw_svg()</code>. You might also want to specify a <code>size</code> for the plot.</dd></dl>

In [None]:
# Exercise: fully simplify the original tree sequence (base_ts) using current_gen_IDs as samples


In [None]:
workbook.question("simplified_tree_MRCA")

#### Subsetting

You can also use `simplify()` to reduce the tree sequence so that it shows only the genealogy of a much smaller subset of the nodes. For example, we could simplify down to the last three sample genomes. To keep track of the node IDs, we can ask for the mapping of old node IDs to new ones to be returned using the `map_nodes` parameter, then use the old IDs as labels when plotting.

<div class="alert alert-block alert-info"><b>Note:</b> A better way to keep track of which node is which is to add <em>metadata</em> to nodes or individuals. We will see how to do this later.</div>

In [None]:
ids = ts.samples()[-3:]
print(f"Simplifying to the ancestry of sample nodes {ids}")
small_ts, node_map = ts.simplify(ids, map_nodes=True)
node_labels = {new_id: u for u, new_id in enumerate(node_map) if new_id != tskit.NULL}
SVG(small_ts.draw_svg(y_axis=True, node_labels=node_labels))

## An improved forward-time simulator

Now that we have introduced the concept of simplification and sample nodes, we can make a few minor improvements to our forward-time simulator. **The code below is mainly for reference: you don't need to look at it in detail**.

The specific changes are:
* We wrap all the code in a Python class, and call `.run(num_gens)` to get a tree sequence after a given number of generations
* During the simulation *all* simulation nodes are flagged as samples (because sample nodes can be usefully thought of as "known genomes", and we do know each genome during simulation), but...
* ... `sim.run(..., simplify=True)` simplifies the tables (by default taking nodes at time=0 as samples).
* For ease of reference, the nodes are reordered using [table.subset](https://tskit.dev/tskit/docs/stable/python-api.html#tskit.TableCollection.subset) with the youngest put first.
* We also save the parent *individual IDs* in the individuals table. As long as we retain individuals when simplifying, this allows us to reconstruct the *pedigree* (see e.g. [here](https://tskit.dev/msprime/docs/stable/pedigrees.html#pedigree-encoding))

We'll be using this simulator in later workbooks, so it's also available as `ARG_workshop.FwdWrightFisherSimulator`.

In [None]:
class FwdWrightFisherSimulator:
    def __init__(self, population_size, seq_len=1000, random_seed=8):
        self.flags = tskit.NODE_IS_SAMPLE
        self.rng = np.random.default_rng(seed=random_seed)
        self.tables = tskit.TableCollection(sequence_length=seq_len)
        self.tables.time_units = "generations"
        self.current_population = self.initialize(population_size)

    def run(self, gens, simplify=False, samples=None, **kwargs):
        # NB: assume current_population is at time 0, and count downwards
        # so that generations are negative. On output, rebase the times
        # so the current generation is at time 0
        for neg_gens in -np.arange(gens):
            self.current_population = self.reproduce(self.current_population, neg_gens-1)

        # reorder the nodes so that youngest are IDs 0..n
        self.tables.nodes.time += gens
        self.tables.subset(np.arange(self.tables.nodes.num_rows)[::-1])
        self.tables.sort()  # Sort edges into canonical order, required for converting to a tree seq

        if simplify:
            if samples is None:
                samples = np.flatnonzero(self.tables.nodes.time == 0)
            self.tables.simplify(samples, **kwargs)
        return self.tables.tree_sequence()


    def initialize(self, diploid_population_size):
        """
        Save a population to the tskit_tables and return a Python dictionary
        mapping the newly created individual ids to a pair of genomes (node ids)
        """
        temp_pop = {}  # make an empty dictionary
        for _ in range(diploid_population_size):
            # store in the TSKIT tables
            i = self.tables.individuals.add_row(parents=(tskit.NULL, tskit.NULL))
            maternal_node = self.tables.nodes.add_row(self.flags, time=0, individual=i)
            paternal_node = self.tables.nodes.add_row(self.flags, time=0, individual=i)
            # Add to the dictionary: map the individual ID to the two node IDs
            temp_pop[i] = (maternal_node, paternal_node)
        return temp_pop

    def reproduce(self, previous_pop, current_time):
        temp_pop = {}
        prev_individual_ids = list(previous_pop.keys())
        for _ in range(len(previous_pop)):
            mum, dad = self.rng.choice(prev_individual_ids, size=2, replace=False)
            i = self.tables.individuals.add_row(parents=(mum, dad))
            # Flag new nodes as samples too, so that *all* simulated nodes are samples
            maternal_node = self.tables.nodes.add_row(self.flags, time=current_time, individual=i)
            paternal_node = self.tables.nodes.add_row(self.flags, time=current_time, individual=i)
            temp_pop[i] = (maternal_node, paternal_node)
    
            # Now add inheritance paths to the edges table, ignoring recombination
            self.add_edges(self.rng.permuted(previous_pop[mum]), maternal_node)
            self.add_edges(self.rng.permuted(previous_pop[dad]), paternal_node)
        return temp_pop

    def add_edges(self, randomly_ordered_parent_nodes, child_node):
        parent_node = randomly_ordered_parent_nodes[0]
        L = self.tables.sequence_length
        self.tables.edges.add_row(parent=parent_node, child=child_node, left=0, right=L)


In [None]:
# Test it out
sim = FwdWrightFisherSimulator(4, random_seed=123)
ts = sim.run(20, simplify=True)
SVG(ts.draw_svg(y_axis=True))

In [None]:
# Check that ARG_workshop.FwdWrightFisherSimulator is identical
second_ts = ARG_workshop.FwdWrightFisherSimulator(4, random_seed=123).run(20, simplify=True)

# When you make a tree sequence, by default the time of creation is embedded in its
# "provenance" table, so we often deliberately ignore provenance when testing equality
assert ts.equals(second_ts, ignore_provenance=True)