# Experimenting with parallel computation using tskit

* Kevin Thornton
* 7 December, 2022

# Outline

* Improving tree access
* One possible path to simplification using threads.
  * This is the simplest/most obvious method.

# Technicalities

* I am only talking about `tskit-c`
* I will show pseudocode in Python
* I did everything in rust, using `tskit-rust`.
* Yes, that's confusing...

# Background

* Multithreaded programming is hard.
* tskit **should probably never bother with it**.
  * There are lots of reasons for it.
* tskit should:
  * think about how to alleviate barriers to parallelism
  * Document what safe access to data structures looks like.

# Background

* Two threads can never safely modify the same data structure without having a way to queue up requests (locks/mutexes).
* With queued up requests, you can get into trouble where all your time is spent waiting for threads to get access.
* It is often more useful to:
  * Identify independent operations that can be done.
  * Apply a thread to work on each "chunk" of work.
  * Aggregate the results.
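
The chunk/apply/aggregate pattern above can be sketched in Python. This is only an illustration of the structure: the function names (`chunked`, `parallel_sum_of_squares`) are made up for this example, and in standard CPython the GIL means this particular workload will not actually run faster with threads.

```python
from concurrent.futures import ThreadPoolExecutor


def chunked(items, num_chunks):
    """Split items into num_chunks contiguous, non-overlapping slices."""
    size, extra = divmod(len(items), num_chunks)
    chunks, start = [], 0
    for i in range(num_chunks):
        stop = start + size + (1 if i < extra else 0)
        chunks.append(items[start:stop])
        start = stop
    return chunks


def sum_of_squares(chunk):
    # Work on one independent chunk: reads only its own data, returns a new value.
    return sum(x * x for x in chunk)


def parallel_sum_of_squares(items, num_threads=4):
    # One task per chunk; no shared mutable state, so no locks are needed.
    with ThreadPoolExecutor(max_workers=num_threads) as pool:
        partial = list(pool.map(sum_of_squares, chunked(items, num_threads)))
    # Aggregation happens on the calling thread.
    return sum(partial)
```

The key design point is that each worker owns its chunk outright, so the only synchronization is the final collection of results.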

# Background

* Starting a thread is surprisingly expensive.
* You usually:
  * Need a LOT of work before threads are worth using.
  * Need "thread pools" with some bells and whistles.
    * work-stealing, etc.
* Alternating multi-/single- threaded steps hurts efficiency.

# Tree access

* `int tsk_tree_seek(tsk_tree_t *self, double position, tsk_flags_t options);`
* Does a linear-time search from the front or the back, depending on whether `position` is less than or greater than `L/2`, where `L` is the sequence length.
* Workflows that repeatedly start from arbitrary trees become quadratic time.

# Indexing trees

```python
from typing import List


class TreeIndex:
    insertion: List[int]
    removal: List[int]
    left: List[float]
```

* Each list is `num_trees` long.
* The integers are, for each tree, the positions in the edge insertion and removal orders at which that tree's edge changes begin.
* The float is the left coordinate of a tree.
* The lists are populated by the "usual" loop over edges, processing removals and insertions.

# The "usual" loop (snippet)

```rust
    while j < num_edges || tree_left < sequence_length {
        insertion.push(j);
        removal.push(k);
        left.push(OrderedFloat(tree_left));
        while k < num_edges && edge_right[edge_removal_order[k] as usize] == tree_left {
            k += 1;
        }
        while j < num_edges && edge_left[edge_insertion_order[j] as usize] == tree_left {
            j += 1;
        }
        tree_right = sequence_length;
        if j < num_edges {
            tree_right = if tree_right < edge_left[edge_insertion_order[j] as usize] {
                tree_right
            } else {
                edge_left[edge_insertion_order[j] as usize]
            };
        }
        if k < num_edges {
            tree_right = if tree_right < edge_right[edge_removal_order[k] as usize] {
                tree_right
            } else {
                edge_right[edge_removal_order[k] as usize]
            };
        }
        tree_left = tree_right;
    }
```
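
A rough Python port of the loop above, plus one way the index might be used. The names `build_tree_index` and `find_tree` are hypothetical, `TreeIndex` is written as a `NamedTuple` here for convenience, and the binary-search lookup is an assumption about how the index would be queried, not something taken from `tskit` or `tskit-rust`.

```python
from bisect import bisect_right
from typing import List, NamedTuple


class TreeIndex(NamedTuple):
    insertion: List[int]
    removal: List[int]
    left: List[float]


def build_tree_index(edge_left, edge_right, insertion_order, removal_order,
                     sequence_length):
    """Record, for every tree, where its edge changes start and its left coordinate."""
    num_edges = len(edge_left)
    insertion, removal, left = [], [], []
    j = k = 0
    tree_left = 0.0
    while j < num_edges or tree_left < sequence_length:
        insertion.append(j)
        removal.append(k)
        left.append(tree_left)
        # Skip over the edges removed and inserted at this tree's left coordinate.
        while k < num_edges and edge_right[removal_order[k]] == tree_left:
            k += 1
        while j < num_edges and edge_left[insertion_order[j]] == tree_left:
            j += 1
        # The next tree starts at the nearest upcoming edge boundary.
        tree_right = sequence_length
        if j < num_edges:
            tree_right = min(tree_right, edge_left[insertion_order[j]])
        if k < num_edges:
            tree_right = min(tree_right, edge_right[removal_order[k]])
        tree_left = tree_right
    return TreeIndex(insertion, removal, left)


def find_tree(index, position):
    """Assumed lookup: binary search for the tree whose interval contains position."""
    return bisect_right(index.left, position) - 1
```

With the index built once in linear time, locating the tree for an arbitrary position becomes a logarithmic-time search instead of a linear scan.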

# Uh oh -- things aren't working

# Parallel tree sequence recording

* Imagine we split our sim into `k` tree sequences.
* Each is `w = 1/k` of the genome.
* For a new edge (or site), it is trivial to figure out which tree sequence(s) it should be added to.
  * Divide position(s) by `w` and truncate to integer.
* New births (nodes) get recorded into all `k` tree sequences.
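
A minimal sketch of the routing rule above. The helper names (`site_owner`, `route_edge`) are invented for this example, and clipping an edge at the chunk boundaries is an assumption about how an edge spanning several chunks would be recorded.

```python
def site_owner(position, k, sequence_length=1.0):
    """Index of the tree sequence responsible for a single position."""
    w = sequence_length / k
    # Divide by the chunk width and truncate; clamp guards the right-most boundary.
    return min(int(position / w), k - 1)


def route_edge(left, right, k, sequence_length=1.0):
    """Return (index, clipped_left, clipped_right) for every tree sequence
    overlapping the half-open interval [left, right)."""
    w = sequence_length / k
    pieces = []
    for i in range(site_owner(left, k, sequence_length), k):
        lo = max(left, i * w)
        hi = min(right, (i + 1) * w)
        if lo >= hi:
            break  # the edge ends before this chunk starts
        pieces.append((i, lo, hi))
    return pieces
```

For example, with `k = 4` an edge over `[0.3, 0.8)` is split across chunks 1, 2, and 3, while an edge over `[0.1, 0.2)` lands only in chunk 0.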

# Why can't this "just work"?

* Node remapping prevents this scheme from working right out of the box.
* For each of the `k` tree sequences, the same input node may/will get mapped to **different** output nodes.
* You therefore lose the node identity across your set of tree sequences.
* (This is really bad!)

# One solution

* Prevent nodes from being remapped!
* [2619](https://github.com/tskit-dev/tskit/pull/2619) implements this, but is not merged yet.

# Why is this a solution?

* Output id == input id, preserving identity left-to-right.  Yay!
* But we get lots of extinct nodes sticking around. Boo!

For a given tree sequence, it is easy to figure out which nodes are actually still valid:

```python
used = [0] * treeseq.num_nodes
for e in treeseq.edges():
    used[e.parent] = 1
    used[e.child] = 1
```

Aggregating `used` over all tree sequences tells us about "globally" extinct nodes.
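
The aggregation step might look like the sketch below. To stay self-contained it takes plain `(parent, child)` tuples rather than tskit objects, and the names `used_nodes` and `globally_extinct` are hypothetical.

```python
def used_nodes(num_nodes, edges):
    """Mark every node that appears as a parent or child in one tree sequence."""
    used = [False] * num_nodes
    for parent, child in edges:
        used[parent] = True
        used[child] = True
    return used


def globally_extinct(num_nodes, edges_per_treeseq):
    """Node ids that are not referenced by any edge in any of the k tree sequences."""
    used_anywhere = [False] * num_nodes
    for edges in edges_per_treeseq:
        for node, flag in enumerate(used_nodes(num_nodes, edges)):
            used_anywhere[node] = used_anywhere[node] or flag
    return [node for node, flag in enumerate(used_anywhere) if not flag]
```

This relies on node ids meaning the same thing in all `k` tree sequences, which is exactly what disabling node remapping provides.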

# Why does this matter?

* We know the index of all nodes that are simplified out of all `k` tree sequences.
* So we can figure out in `O(1)` time what they are (b/c we've aggregated that info)...
* ...so we can over-write existing node data w/info for new births.

(Remember -- this is all in the C API where we can touch these raw arrays.)
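
One way to picture the node recycling is a free list of extinct ids. This is a sketch under assumptions: `NodeRecycler` is a made-up name, and a plain Python list of node times stands in for the raw node table columns that the C API exposes.

```python
class NodeRecycler:
    """Reuse the ids of globally extinct nodes for new births."""

    def __init__(self, node_times):
        self.node_times = node_times  # stand-in for a raw node table column
        self.free = []                # ids known to be extinct everywhere

    def release(self, extinct_ids):
        # Called after aggregating the per-tree-sequence `used` arrays.
        self.free.extend(extinct_ids)

    def add_node(self, time):
        if self.free:
            # O(1): overwrite the row of an extinct node.
            node_id = self.free.pop()
            self.node_times[node_id] = time
        else:
            # No extinct rows left, so grow the table.
            node_id = len(self.node_times)
            self.node_times.append(time)
        return node_id
```

In the parallel scheme, the same overwrite would have to be applied to the node table of every one of the `k` tree sequences so that ids stay in sync.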

# Get to the point already...

We can now "easily":

* Record edges to the right subset of tree sequences.
* Likewise new sites.
* We no longer lose node identity once `2619` is merged.

Therefore:

* Simplifying the `k` tree sequences using threads "just works":
  * The `k` tree sequences are fully independent data structures.
  * We need a single-thread job to aggregate the globally extinct node info.
  * Need a final "collect & simplify" step to get everything into one tree sequence.

# Is it really that easy?

It is a "one liner" in rust (using [rayon](https://docs.rs/rayon/1.6.0/rayon/)):

```rust
    // vector of table collections
    tables.par_iter_mut().for_each(|tc| {
        // Records the new nodes to each tree seq and then sorts & simplifies.
        // Nothing gets returned.
        simplify_details(flags, alive, &samples, tc, new_data);
    });
```

# Running simulations

* Constant size WF model.
* Number of crossovers per birth is Poisson.
* Nothing interesting happening -- just record, simplify, repeat.

## Details

* Simplify every 100 generations
* Release builds (important!)
* `KEEP_INPUT_ROOTS` is on, just "for fun".
* AMD 5900x w/64GB DDR4 3600 memory.
* AMD 5950x w/128GB DDR4 3600 memory.
* Entire edge table sorted each time. (Lazy...)

# Performance: 1 thread vs 6

![](benchmark.png)

# Performance: relative change due to 6 threads

![](benchmark_relative.png)

# Performance: vary mean no. crossovers (inset), no. threads

The high-crossover simulation for N=5e5 crashed on my machine with 128GB.
That machine (the 5950x) has 16 physical cores and 32 threads.

![](benchmark_threads.png)

# Thoughts

* 4x speedup is pretty good!
* As always, naive use of threads **slows your work down**
  * numpy installed through conda is a prime example of this...
* Threading efficiency is about 75%. (Not an unusual number for start/stop work flows.)
* The sims with the greatest payoff took at least 30GB of RAM!

# GPU?

Completely ignoring the software architecture issues:

* The most payoff was for sims taking loads of memory.
* A gaming GPU with 24GB of VRAM costs at least $1500 US.
* A data center GPU costs about 10X more.
  * These don't burn a hole in our desk...