# Forward-time simulation with tskit

This tutorial shows the basics of creating a forward-time simulator that stores the evolving genealogy as a [tskit](https://tskit.dev/) tree sequence. We will focus on the case of diploids, in which each individual contains 2 genomes, but the concepts used here generalize to any ploidy, if you are willing to do the book-keeping!

:::{note}
The distinction between an *individual* and the *genomes* it contains is an important one. In fact, individuals are not strictly necessary for representing genetic genealogies (it's the genomes which are important), but the simulator needs to account for individuals, at least temporarily.
:::

## Definitions

Before we can make any progress, we require a few definitions.

A *node* represents a genome at a point in time (often we imagine this as the "birth time" of the genome).  It can be described by a tuple, `(id, time)`, where `id` is a unique integer, and `time` reflects the birth time of that `id`. When generating a tree sequence, this will be stored in a row of the [node table](https://tskit.dev/tskit/docs/stable/data-model.html#node-table).

A *diploid individual* is a group of two nodes. During simulation, a simple and efficient grouping assigns sequential pairs of node IDs to an individual. It can be helpful (but not strictly necessary) to store individuals within the tree sequence as rows of the [individual table](https://tskit.dev/tskit/docs/stable/data-model.html#individual-table) (a node can then be assigned to an individual by storing that individual's id in the appropriate row of the node table).

An *edge* reflects a transmission event between nodes. An edge is a tuple `(left, right, parent, child)`, written $(L, R, P, C)$, whose meaning is "parent genome $P$ passed on the genomic interval $[L, R)$ to child genome $C$". In a tree sequence this is stored in a row of the [edge table](https://tskit.dev/tskit/docs/stable/data-model.html#edge-table).
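The definitions above can be made concrete with plain Python tuples. This is a minimal sketch, independent of `tskit`; the helper names `genomes_of()` and `individual_of()` are hypothetical, and it assumes the sequential pairing in which individual `i` owns nodes `2i` and `2i + 1` (which only holds while individuals and nodes are added in lockstep, i.e. before any simplification):

```python
# Illustrative records only: tskit stores these as rows of its node,
# individual and edge tables, not as Python tuples.

def genomes_of(individual_id):
    """Sequential pairing (assumed): individual i owns nodes 2i and 2i + 1."""
    return (2 * individual_id, 2 * individual_id + 1)

def individual_of(node_id):
    """Inverse of the sequential pairing."""
    return node_id // 2

# Nodes as (id, time) tuples: two parent individuals at time 1, one child at time 0
nodes = [(0, 1), (1, 1), (2, 1), (3, 1), (4, 0), (5, 0)]

# Edges as (left, right, parent, child): "parent passed [left, right) to child"
edges = [
    (0, 1000, 1, 4),  # genome 4 inherits [0, 1000) from node 1 (individual 0)
    (0, 1000, 2, 5),  # genome 5 inherits [0, 1000) from node 2 (individual 1)
]

child_individual = individual_of(4)
print(child_individual, genomes_of(child_individual))  # → 2 (4, 5)
```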

The *time*, in the discrete-time Wright-Fisher (WF) model which we will simulate, is measured in integer generations. To match the tskit notion of time, we record time in *generations ago*: i.e. for a simple simulation of $G$ generations, we start with an initial population at generation $G$ and count down until we reach generation 0 (the current day).

The *population* consists of $N$ diploid individuals ($2N$ nodes) at a particular time $t$. At the start, the population will have no known ancestry, but subsequently, each individual will be formed by choosing (at random) two parent individuals from the population in the previous generation.
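As a sketch of the time convention and of random mating, assuming a hypothetical $N = 4$ and $G = 3$, and using numpy's `Generator.integers` to draw parent indices with replacement (the tutorial's own code uses `Generator.choice` on an array of individual IDs instead):

```python
import numpy as np

rng = np.random.default_rng(1)  # illustrative seed, for reproducibility only
N = 4  # diploid individuals per generation (hypothetical)
G = 3  # generations of reproduction (hypothetical)

# Time is counted in generations ago: the initial population lives at time G,
# and each subsequent generation is one step closer to the present (time 0).
times = list(range(G, -1, -1))
print(times)  # → [3, 2, 1, 0]

# For every generation after the first, each of the N children picks two
# parent individuals (indices into the previous generation) with replacement,
# so the same individual can be picked twice (selfing).
parent_choices = {t: rng.integers(N, size=(N, 2)) for t in times[1:]}
for t, parents in parent_choices.items():
    print(f"time {t}: parents of each child =", parents.tolist())
```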


## Approach

The goal of this tutorial is to work through the book-keeping required to generate edges, nodes, and individuals forwards in time, adding them to the relevant `tskit` tables. To aid efficiency, we will also see how to "simplify" the tables into the minimal set of nodes and edges that describe the history of the sample. Finally, these tables can be exported into an immutable [tree sequence](https://tskit.dev/tutorials/what_is.html) for storage or analysis.

### Setup
First, we'll import the necessary libraries and define some general parameters. The [numpy](https://numpy.org/doc/stable/) library will be used to produce random numbers.

```python
import tskit
import numpy as np

random_seed = 123
random = np.random.default_rng(random_seed)  # A random number generator for general use

diploid_population_size = 6  # We can make this bigger later
sequence_length = 50_000  # 50 Kb
```

## Simulation without recombination

We'll start by simulating a small region of a larger genome (e.g. a "gene", or a portion of a gene).
We will assume the region is small enough that there is no recombination. The first building block
is to define how one of the child's genomes in our simulation, say the
maternal one, is created from the two genomes carried by the mother. With no recombination, we can simply pick one of the mother's genomes at random, and save the inheritance paths in the [edge table](https://tskit.dev/tskit/docs/stable/data-model.html#edge-table). We can then call the same function for the paternal genome.

```python
focal_region = [20_000, 21_000]

def add_inheritance_paths(tables, parent_genomes, child_genome):
    "Add inheritance paths from a randomly chosen parent genome to the child genome"
    left, right = focal_region  # only define inheritance in this focal region
    inherit_from = random.integers(2)  # randomly choose 0 or 1
    tables.edges.add_row(left, right, parent_genomes[inherit_from], child_genome)
```

The function that generates a new population repeats the following steps for each new individual:
1. Select two parents at random from the previous generation.
2. Create a new child individual and its two genomes (the individual ID is created by adding a row to the [individual table](https://tskit.dev/tskit/docs/stable/data-model.html#sec-individual-table-definition), and the two node IDs by adding two rows to the [node table](https://tskit.dev/tskit/docs/stable/data-model.html#node-table)).
3. Add the inheritance paths for each genome using the `add_inheritance_paths()` function we just created.

For convenience, we keep the population stored in a Python dictionary which maps the individual ID to the IDs of its two genomes.

```python
def new_population(tables, time, prev_population=None) -> dict:
    """
    Returns a Python dict of length `diploid_population_size` representing a population,
    optionally derived from a prev_population of the same form. Populations look like
    {individual_ID: (maternal_genome_ID, paternal_genome_ID), ...}
    """
    new_population = {}
    if prev_population is not None:
        # For efficiency, cache an array of individual IDs from the prev population
        prev_individuals = np.array([i for i in prev_population], dtype=np.int32)

    for _ in range(diploid_population_size):
        # 1. Pick two parent individual IDs at random; `replace=True` allows selfing
        if prev_population is not None:
            mother_and_father = random.choice(prev_individuals, 2, replace=True)
        else:
            mother_and_father = None  # the initial population has no known parents

        # 2. Get a new individual ID (recording its parents) and two new genome IDs
        i = tables.individuals.add_row(parents=mother_and_father)
        child_genomes = (
            tables.nodes.add_row(tskit.NODE_IS_SAMPLE, time, individual=i),
            tables.nodes.add_row(tskit.NODE_IS_SAMPLE, time, individual=i),
        )
        new_population[i] = child_genomes  # store the genome IDs

        if prev_population is not None:
            # 3. Add inheritance paths, using the same parents recorded above
            for child_genome, parent_individual in zip(child_genomes, mother_and_father):
                parent_genomes = prev_population[parent_individual]
                add_inheritance_paths(tables, parent_genomes, child_genome)
    return new_population
```

<div class="alert alert-block alert-info">
    <b>Note:</b> 
For simplicity, the code above assumes any parent can be a mother or a father (i.e. this is a hermaphrodite species). It also allows the same parent to be chosen as both mother and father (i.e. "selfing" is allowed), which gives simpler theoretical results. This is easy to change if required.
</div>

Our forward-in-time simulator simply involves repeatedly running the `new_population()` routine, replacing the old population with the new one. For efficiency reasons, `tskit` has strict requirements for the order of edges in the edge table, so we need to [sort](https://tskit.dev/tskit/docs/stable/python-api.html?highlight=sort#tskit.TableCollection.sort) the tables before we output the final tree sequence.
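`tables.sort()` takes care of this for us. As an illustration only, the plain-Python sketch below sorts a few hypothetical edges by the keys that tskit's edge table requires: parent time, then parent ID, then child ID, then left coordinate. A forward simulator naturally appends edges oldest-parent-first, which is the opposite of the required order:

```python
# Node times, indexed by node ID (tskit measures time backwards: larger = older)
node_time = [2, 2, 1, 1, 0, 0]

# Edges (left, right, parent, child), appended in the order a forward
# simulator happens to create them (oldest parents first)
edges = [
    (0, 100, 0, 2),
    (0, 100, 1, 3),
    (0, 50, 2, 4),
    (50, 100, 3, 4),
    (0, 100, 3, 5),
]

def sort_key(edge):
    left, right, parent, child = edge
    # Required order: parent time, parent ID, child ID, left coordinate
    return (node_time[parent], parent, child, left)

sorted_edges = sorted(edges, key=sort_key)
print(sorted_edges[0])  # → (0, 50, 2, 4): an edge from the youngest parent comes first
```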

```python
def simple_diploid_sim(generations) -> tskit.TreeSequence:
    tables = tskit.TableCollection(sequence_length)
    tables.time_units = "generations"  # optional, but helpful when plotting

    population = new_population(tables, generations)  # initial population
    while generations > 0:
        generations = generations - 1
        population = new_population(tables, generations, population)

    tables.sort()
    return tables.tree_sequence()


# Test it for a single generation
ts = simple_diploid_sim(generations=1)
ts.draw_svg(y_axis=True, size=(500, 200))
```

It looks like it is working correctly: all 12 genomes in the current generation (time=0) trace back to a genome in the initial generation (time=1). Note that not all individuals in the initial generation have passed on genetic material at this genomic position (they appear as isolated nodes at the top of the plot).

Now let's simulate for a longer time period, and set a few helpful plotting parameters.

:::{note}
By convention we plot the most recent generation at the bottom of the plot (i.e. perversely, each "tree" has its leaves at the bottom and its roots at the top).
:::

```python
ts = simple_diploid_sim(generations=15)

graphics_params = {
    "y_axis": True,
    "y_label": f"Time ({ts.time_units} ago)",
    "y_ticks": {i: "Current" if i == 0 else str(i) for i in range(16)},
}
ts.draw_svg(size=(1200, 350), **graphics_params)
```

This is starting to look like a real genealogy! But you can see that there are a lot of lineages that have not made it to the current day...

## Simplification

The key to efficient forward-time genealogical simulation is the process of [simplification](https://tskit.dev/tutorials/simplification.html), which can reduce much of the complexity shown in the tree above. Typically, we want to remove all the lineages that do not contribute to the current-day genomes. We do this via the [`simplify()`](https://tskit.dev/tskit/docs/stable/python-api.html#tskit.TreeSequence.simplify) method, specifying that only the nodes in the current generation are "samples".

```python
current_day_genomes = ts.samples(time=0)
simplified_ts = ts.simplify(current_day_genomes, keep_unary=True, filter_nodes=False)
simplified_ts.draw_svg(size=(600, 300), **graphics_params)
```
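To see what simplification is removing, consider the no-recombination case, in which each genome has exactly one parent genome. The lineages that contribute to the current day are then the nodes reachable by walking up parent pointers from each sample. This is a conceptual sketch on a hypothetical parent array (`contributing_nodes()` is an illustrative name), not how tskit's `simplify()` is implemented, which also has to handle recombination:

```python
# Hypothetical genealogy without recombination: parent[u] is the parent
# genome of node u, or -1 for nodes in the initial generation.
# Nodes 0-3 are the oldest, 4-7 intermediate, 8-11 current-day.
parent = [-1, -1, -1, -1, 0, 0, 2, 3, 4, 4, 6, 6]
samples = [8, 9, 10, 11]  # current-day genomes

def contributing_nodes(parent, samples):
    """Return the set of nodes on a path from some sample back to the start."""
    keep = set()
    for u in samples:
        # Stop early once we join a lineage that has already been recorded
        while u != -1 and u not in keep:
            keep.add(u)
            u = parent[u]
    return keep

keep = contributing_nodes(parent, samples)
print(sorted(keep))  # → [0, 2, 4, 6, 8, 9, 10, 11]
```

Nodes 1, 3, 5 and 7 belong to lineages that died out, so they are exactly the nodes a simplification would discard.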

### ID changes

We just simplified with `filter_nodes=False`, meaning that the tree sequence retained all nodes even after simplification. However, many nodes are no longer part of the genealogy; removing them means we can store fewer nodes (although it will change the node IDs).

```python
simplified_ts = ts.simplify(current_day_genomes, keep_unary=True)
simplified_ts.draw_svg(size=(600, 300), **graphics_params)
```

Note that the nodes passed to `simplify` (i.e. the current-day genomes) have become the first nodes in the table, numbered from 0 to 11, and the remaining nodes have been renumbered from youngest to oldest.
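The renumbering can be mimicked on a small hypothetical example. This sketch assumes the ordering just described (samples first, in the order given, then the remaining retained nodes from youngest to oldest, with ties broken by original ID as an assumption of this sketch). Removed nodes map to `-1`, as in the node map that `TreeSequence.simplify(..., map_nodes=True)` returns:

```python
# Hypothetical input: node times (generations ago) and the nodes that survive
node_time = [2, 2, 2, 2, 1, 1, 1, 1, 0, 0, 0, 0]
kept = {0, 2, 4, 6, 8, 9, 10, 11}
samples = [8, 9, 10, 11]

# Samples first (in the order given), then the other kept nodes from
# youngest to oldest (ties broken by original ID); removed nodes map to -1.
others = sorted(kept - set(samples), key=lambda u: (node_time[u], u))
node_map = [-1] * len(node_time)
for new_id, old_id in enumerate(samples + others):
    node_map[old_id] = new_id

print(node_map)  # → [6, -1, 7, -1, 4, -1, 5, -1, 0, 1, 2, 3]
```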

### Extra node removal

The `keep_unary=True` parameter meant that we kept intermediate ("unary") nodes, even those that do not represent branch points in the tree. Often these are also unneeded, and by default `simplify` removes them too; this means that the node IDs of older nodes will change again.

```python
simplified_ts = ts.simplify(current_day_genomes)
simplified_ts.draw_svg(size=(400, 300), y_axis=True)
```

This is now looking much more like a "normal" genetic genealogy (a "gene tree"), in which all the sample genomes trace back to a single common ancestor.
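Removing unary nodes can be sketched in plain Python on a hypothetical no-recombination genealogy: a node is unary if it has exactly one child, and removing it reconnects that child to the nearest non-unary ancestor. The helper `splice_out()` is illustrative only; `simplify()` does this across all local trees at once. Note that the oldest node is also removed here because it has a single child, which matches the default behaviour (it would be retained with `keep_input_roots=True`):

```python
from collections import Counter

# Hypothetical retained genealogy without recombination: child -> parent.
# Samples are 8, 9 and 10; node 0 is the oldest ancestor.
parent_of = {8: 4, 9: 4, 4: 2, 10: 6, 6: 2, 2: 0}

num_children = Counter(parent_of.values())
unary = {u for u, n in num_children.items() if n == 1}
print(sorted(unary))  # → [0, 6]

def splice_out(parent_of, unary):
    """Reconnect each child to its nearest non-unary ancestor (if any)."""
    result = {}
    for child, p in parent_of.items():
        if child in unary:
            continue  # unary nodes are removed entirely
        while p in unary:
            p = parent_of.get(p)  # becomes None once we walk past the oldest node
        if p is not None:
            result[child] = p
    return result

print(splice_out(parent_of, unary))  # → {8: 4, 9: 4, 4: 2, 10: 2}
```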

## Multiple roots

If we run the simulation for fewer generations, we are not guaranteed to create genomes that share a common ancestor within the timeframe of our simulation.

```python
ts = simple_diploid_sim(generations=5)
ts.draw_svg(size=(700, 200), y_axis=True)
```

Even the simplified version doesn't look quite like a normal "tree", as it contains several unlinked topologies. In `tskit` we call this a single tree with [multiple roots]():

```python
simplified_ts = ts.simplify(ts.samples(time=0))
simplified_ts.draw_svg(size=(700, 200), y_axis=True)
```
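Counting roots is straightforward: follow parent pointers from every sample and collect the distinct nodes where the paths end. The sketch below does this for a single hypothetical local tree; with a real tree sequence, `tree.num_roots` (or `tree.roots`) for each tree in `ts.trees()` gives the same information:

```python
# Hypothetical simplified local tree: child -> parent.
# Two separate topologies that never meet within the simulated time.
parent_of = {0: 6, 1: 6, 2: 7, 3: 7, 4: 8, 5: 8, 6: 9, 7: 9}
samples = [0, 1, 2, 3, 4, 5]

def root_of(u):
    """Walk up until we reach a node with no recorded parent."""
    while u in parent_of:
        u = parent_of[u]
    return u

roots = sorted({root_of(u) for u in samples})
print(roots)  # → [8, 9]
```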

When a forward-simulated tree has multiple roots, it can be useful to retain relevant lineages all the way back to the start of the simulation. This can be done using the `keep_input_roots` option:

```python
simplified_ts = ts.simplify(ts.samples(time=0), keep_input_roots=True)
simplified_ts.draw_svg(size=(700, 200), y_axis=True)
```

## Recombination

It is relatively easy to modify the simulation code to allow recombination. All we need to do is to redefine the `add_inheritance_paths()` function, so that the child inherits a mosaic of the two genomes present in each parent.

Below is a redefined function which selects a set of "breakpoints" along the genome. It then allocates the first edge from zero to breakpoint 1 pointing it to one parent genome, and then allocates a second edge from breakpoint 1 onwards pointing to the other parent genome. If there is a second breakpoint, a third edge is created from breakpoint 2 to the next breakpoint that points back to the initial parent genome, and so forth, up to the end of the sequence. Biologically, recombination rates are such that they usually result in a relatively small number of breakpoints per chromosome (in humans, around 1 or 2).

:::{note}
Here we choose breakpoint positions in continuous space ("infinite breakpoint positions"), to match population genetic theory, although it is relatively easy to alter this to place recombination breakpoints at integer positions.
:::
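For comparison, here is a minimal sketch of the integer-position alternative mentioned in the note, using numpy only. It assumes breakpoints are drawn uniformly from positions 1 to $L - 1$ (so that no interval has zero length) and that two breakpoints falling at the same position cancel out. The function name `integer_crossover_intervals()` is hypothetical, and the rate is deliberately much higher than in the tutorial so that crossovers actually appear:

```python
import numpy as np

rng = np.random.default_rng(7)  # illustrative seed
L = 50_000  # sequence length, as in the tutorial
rate = 1e-4  # deliberately high, so that several crossovers appear

def integer_crossover_intervals(rng, L, rate):
    """Return (left, right, which_parent_genome) intervals covering [0, L)."""
    n = rng.poisson(rate * L)
    # Integer breakpoints strictly inside the sequence: 1 .. L-1
    breaks = rng.integers(1, L, size=n)
    # Two breaks at the same position cancel out, so keep odd counts only
    pos, counts = np.unique(breaks, return_counts=True)
    crossovers = pos[counts % 2 == 1]
    lefts = np.concatenate(([0], crossovers))
    rights = np.concatenate((crossovers, [L]))
    start = int(rng.integers(2))  # which of the two parental genomes comes first
    return [(int(l), int(r), (start + k) % 2)
            for k, (l, r) in enumerate(zip(lefts, rights))]

intervals = integer_crossover_intervals(rng, L, rate)
print(intervals)
```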

```python
recombination_rate = 5e-7  # per base, per generation

def add_inheritance_paths(tables, parent_genomes, child_genome):
    "Add paths from parent genomes to the child genome, with crossover recombination"
    L = tables.sequence_length
    num_recombinations = random.poisson(recombination_rate * L)
    breakpoints = random.uniform(0, L, size=num_recombinations)
    breakpoints = np.concatenate(([0], np.unique(breakpoints), [L]))
    inherit_from = random.integers(2)  # starting parental genome

    # iterate over the intervals [0, b1), [b1, b2), [b2, b3), ... [bN, L)
    for left, right in zip(breakpoints[:-1], breakpoints[1:]):
        tables.edges.add_row(
            left, right, parent_genomes[inherit_from], child_genome)
        inherit_from = 1 - inherit_from  # switch to other parent genome


# Simulate a few generations, for testing
ts = simple_diploid_sim(generations=5)  # Now includes recombination
ts  # Show the tree sequence
```

Recombination means that the tree sequence can now contain more than one "local" tree along the genome; the exact number shown in the summary above depends on the random seed. Here's how the full (unsimplified) genealogy looks:

```python
ts.draw_svg(size=(1000, 300), **graphics_params)
```

This is rather confusing to visualise, and will get even worse if we simulate more generations. However, even with more generations, simplification allows us to reduce the genealogy to something more manageable, both for analysis and for visualization:

```python
ts = simple_diploid_sim(generations=50)
simplified_ts = ts.simplify(ts.samples(time=0), keep_input_roots=True)
graphics_params["y_ticks"] = [0, 10, 20, 30, 40, 50]
simplified_ts.draw_svg(size=(1000, 300), time_scale="log_time", **graphics_params)
```

### Subsampling the population

We have so far only simulated a small population (6 diploids). We can easily simulate a much larger population by setting the global `diploid_population_size` variable to (say) 250. The resulting simplified tree sequence will be reasonably small in terms of disk and memory storage, but will be problematic to visualise.

This is where `simplify` can come in handy again: we can use it to reduce the genealogy to a handful of (hopefully representative) current-day genomes, or to a handful of current-day individuals (i.e. retaining both genomes from a randomly selected set of individuals):

```python
diploid_population_size = 250
gens = 500
large_ts = simple_diploid_sim(generations=gens)  # May take a minute or two
print(
    f"Finished simulating {diploid_population_size} individuals",
    f"({diploid_population_size * 2} genomes)",
    f"for {gens} generations",
)
```

```python
print(f"Full tree sequence including dead lineages: {large_ts.nbytes/1024/1024:.2f} MB")
current_day_genomes = large_ts.samples(time=0)
simplified_ts = large_ts.simplify(current_day_genomes, keep_input_roots=True)
print(
    f"Tree sequence of current-day individuals: {simplified_ts.nbytes/1024/1024:.2f} MB,",
    f"{simplified_ts.num_trees} trees."
)
```

Even the simplified genealogy will consist of hundreds of trees, each with 500 tips (one for each current-day genome). One way to reduce this for plotting is to select the genomes of a few randomly chosen current-day individuals, and plot only a small region. See the [visualization tutorial](https://tskit.dev/tutorials/viz.html) for other options.

```python
# Select e.g. 6 randomly chosen individuals for display purposes
# NB: understanding the code below requires some knowledge of numpy
current_day_individuals = simplified_ts.nodes_individual[simplified_ts.samples()]
use = random.choice(np.unique(current_day_individuals), 6, replace=False)  # Choose 6
print("Plotting individuals with these IDs in the simplified ts:", use)

# Find the genomes corresponding to these individuals
genomes_to_use = np.isin(simplified_ts.nodes_individual, use)
selected_genomes = np.where(genomes_to_use)[0]
print("These correspond to the genomes with these IDs:", selected_genomes)

representative_ts = simplified_ts.simplify(
    selected_genomes,
    keep_input_roots=True,
    filter_nodes=False,  # Keep the node IDs of the simplified_ts, to compare
)

# Plot a short region of genome, using some plot tweaks
representative_ts.draw_svg(
    size=(1000, 300),
    x_lim=[10_000, 11_000],
    y_axis=True,
    time_scale="log_time",
    y_ticks=[0, 1, 2, 5, 10, 20, 50, 100, 200, 500],
    style=(
        ".node > .lab {font-size: 80%}"
        ".leaf > .lab {text-anchor: start; transform: rotate(90deg) translate(6px)}"
    ),
)
```

## Ensuring coalescence

You can see that some of these trees still have multiple roots. In other words, 500 generations is not long enough to capture the ancestry back to a single common ancestor (i.e. to ensure "full coalescence" of all local trees). If the local trees have not all coalesced, then the simulation fails to capture the entire genetic diversity within the sample. Moreover, the larger the population, the longer the time needed to capture the full genealogy. For large models, the time period required for forward simulations to ensure full coalescence can be prohibitive.
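How long is long enough? A rough guide comes from standard coalescent theory (an approximation to the Wright-Fisher model, not derived in this tutorial): for $n$ genomes sampled from a constant-size population of $N$ diploids, the expected time back to the most recent common ancestor of one local tree is about $4N(1 - 1/n)$ generations. A minimal sketch of that calculation, with a hypothetical helper name:

```python
def expected_tmrca_generations(diploid_size, num_genomes):
    """Coalescent approximation to the mean TMRCA of one local tree, in generations.

    Assumes a constant-size, randomly mating population of `diploid_size`
    diploids (2N genomes), so one coalescent time unit is 2N generations and
    E[T_MRCA] = 2 * (1 - 1/n) coalescent units.
    """
    n = num_genomes
    return 4 * diploid_size * (1 - 1 / n)

# The large simulation above: 250 diploids, all 500 current-day genomes
print(round(expected_tmrca_generations(250, 500)))  # → 998
```

For the large simulation above this is roughly 1000 generations, about twice the 500 simulated, and the slowest of the many local trees will typically take longer still, which is consistent with seeing trees that have not fully coalesced.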

A powerful way to get around this problem is *recapitation*, in which an alternative technique, such as backward-in-time coalescent simulation, is used to fill in the "head" of the tree sequence. In other words, we use a fast backward-time simulator such as `msprime` to simulate the genealogy of the oldest nodes in the simplified tree sequence. To see how this is done, consult the [recapitation tutorial].

## More complex forward-simulations

The next tutorial shows the principles behind more complex simulations, e.g. including regular simplification during the simulation, adding mutations, and adding metadata. It also details several extra tips and tricks we have learned when building forward simulators.


```python
import tskit
import numpy as np

random_seed = 123
random = np.random.default_rng(random_seed)  # A random number generator for general use

diploid_pop_size = 6  # default number of diploid individuals
sequence_length = 50_000  # 50 Kb
recombination_rate = 5e-7  # Per base per generation

def add_inheritance_paths(tables, parent_genomes, child_genome):
    "Add paths from parent genomes to the child genome, with crossover recombination"
    L = int(tables.sequence_length)
    num_recombinations = random.poisson(recombination_rate * L)
    # Integer breakpoints in 1..L-1, so that no edge can have zero length
    breakpoints = random.integers(1, L, size=num_recombinations)
    break_pos, counts = np.unique(breakpoints, return_counts=True)
    crossovers = break_pos[counts % 2 == 1]  # no crossover if e.g. 2 breaks at same pos
    left_positions = np.insert(crossovers, 0, 0)
    right_positions = np.append(crossovers, L)

    inherit_from = random.integers(2)
    for left, right in zip(left_positions, right_positions):
        tables.edges.add_row(
            left, right, parent_genomes[inherit_from], child_genome)
        inherit_from = 1 - inherit_from  # switch to other parent genome

def make_diploid(tables, time, parent_individuals=None) -> tuple[int, tuple[int, int]]:
    individual_id = tables.individuals.add_row(parents=parent_individuals)
    return individual_id, (
        tables.nodes.add_row(time=time, individual=individual_id),
        tables.nodes.add_row(time=time, individual=individual_id),
    )

def new_pop(tables, time, prev_pop) -> dict[int, tuple[int, int]]:
    pop = {}
    prev_individuals = np.array(list(prev_pop.keys()), dtype=np.int32)
    for _ in range(len(prev_pop)):  # keep the population size constant
        mother_and_father = random.choice(prev_individuals, 2, replace=True)
        child_id, child_genomes = make_diploid(tables, time, mother_and_father)
        pop[child_id] = child_genomes  # store the genome IDs
        for child_genome, parent_individual in zip(child_genomes, mother_and_father):
            add_inheritance_paths(tables, prev_pop[parent_individual], child_genome)
    return pop

def initialise_pop(tables, time, pop_size=None) -> dict[int, tuple[int, int]]:
    if pop_size is None:
        pop_size = diploid_pop_size  # fall back to the global default
    return dict(make_diploid(tables, time) for _ in range(pop_size))

def simple_diploid_sim(generations) -> tskit.TreeSequence:
    tables = tskit.TableCollection(sequence_length)
    tables.time_units = "generations"  # optional, but helpful when plotting

    pop = initialise_pop(tables, generations)
    while generations > 0:
        generations = generations - 1
        pop = new_pop(tables, generations, pop)

    tables.sort()
    return tables.tree_sequence()
```


```python
def simplify_tables(tables, samples, pop):
    """
    Simplify the tables with respect to the given samples, and return a
    population dictionary in which individuals and nodes are remapped.

    This is more involved than might be expected, because the mapping from old to new
    individuals is not currently returned by `simplify`, so we need to make it ourselves
    """
    old_nodes_individual = tables.nodes.individual

    tables.sort()
    node_map = tables.simplify(samples, keep_input_roots=True, record_provenance=False)

    # Make the map from old to new individuals
    individual_map = {}
    nodes_individual = tables.nodes.individual
    for node1_id, node2_id in pop.values():
        old_ind_id = old_nodes_individual[node1_id]
        # Both genomes of an individual must still belong to the same individual
        assert nodes_individual[node_map[node1_id]] == nodes_individual[node_map[node2_id]]
        individual_map[old_ind_id] = nodes_individual[node_map[node1_id]]

    return {
        individual_map[ind_id]: (node_map[node1_id], node_map[node2_id])
        for ind_id, (node1_id, node2_id) in pop.items()
    }


def diploid_sim(diploid_pop_size, generations, simplification_interval=None, show=None):
    tables = tskit.TableCollection(sequence_length)
    tables.time_units = "generations"  # optional, but helpful when plotting
    if simplification_interval is None:
        simplification_interval = generations + 1  # never simplify before the end
    # Simplify every `simplification_interval` generations, counted from the start
    simplify_modulo = generations % simplification_interval

    pop = initialise_pop(tables, generations, diploid_pop_size)
    while generations > 0:
        generations = generations - 1
        pop = new_pop(tables, generations, pop)
        if generations > 0 and generations % simplification_interval == simplify_modulo:
            current_nodes = [u for genomes in pop.values() for u in genomes]
            pop = simplify_tables(tables, current_nodes, pop)
            if show:
                print("Simplified", generations, "generations before end")

    pop = simplify_tables(tables, [u for genomes in pop.values() for u in genomes], pop)
    if show:
        print("Final simplification")
    return tables.tree_sequence()

ts = diploid_sim(6, 100, simplification_interval=25, show=True)
ts.draw_svg(size=(800, 200))
```
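The condition `generations % simplification_interval == simplify_modulo` can be hard to read. The hypothetical helper below replays the same loop without simulating anything, to show at which times (in generations ago) `diploid_sim()` simplifies before its final simplification at time 0:

```python
def simplification_times(generations, simplification_interval):
    """Times (generations ago) at which diploid_sim() simplifies mid-run.

    Mirrors the loop in diploid_sim(); the final simplification at time 0
    happens separately, after the loop, and is not listed here.
    """
    simplify_modulo = generations % simplification_interval
    times = []
    while generations > 0:
        generations -= 1
        if generations > 0 and generations % simplification_interval == simplify_modulo:
            times.append(generations)
    return times

print(simplification_times(100, 25))  # → [75, 50, 25]
print(simplification_times(100, 33))  # → [67, 34, 1]
```

In both cases simplification happens every `simplification_interval` generations counted from the start of the simulation (time 100), not from the present.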

```python
from IPython.display import display  # built in within Jupyter; imported for clarity
from tqdm.autonotebook import tqdm

random = np.random.default_rng(42)
ts = diploid_sim(50, 500, simplification_interval=1)
display(ts.draw_svg(size=(2000, 200), time_scale="log_time"))

# Iterate over a range of odd and even simplification intervals: each run
# should give a tree sequence identical to the one above.
for i in tqdm(np.arange(2, 500, 33)):
    # Make sure each new sim starts with same random seed!
    random = np.random.default_rng(42)
    ts_test = diploid_sim(50, 500, simplification_interval=i)
    if not ts.equals(ts_test, ignore_provenance=True):
        display(ts_test.draw_svg(size=(2000, 200), time_scale="log_time"))
        raise ValueError(f"Simplification interval {i} gave a different tree sequence")
```
