In [1]:
import polars as pl

# Problem Statement
There were a couple of issues with the previous branch algorithm. 1) it was inefficient, performing many summary function calls for each branch addition/removal. 2) it had some corner cases where we could end up adjusting a node with no samples, causing our running total of the summary function to turn into a NaN value, which effectively "poisons" the rest of the LD matrix (the running sum of a NaN is simply a NaN).

The old algorithm updated each stat/node to account for samples that already existed on the node by subtracting thir contribution from the stat. We also had to add back in contributions from samples that remained after removing a sample from a node. In short, we were adding/subtracting a stat from the running total twice for every time we propagated the results up the tree to the parents. This means that we were performing two summary function calls and set two union/subtraction calls every time we needed to crawl up the tree to propagate changes to parent nodes.

We devised a new algorithm for two-locus branch stats that avoids incremental adjustments and natively avoids computing NaN values. The new algorithm defers all stat updates until the end of branch addition/removal. We do this by storing the nodes that were affected by sample addition or removal, removing the total contribution from all nodes with sample removals. At the end of the routine, we add back in the stat contribution from nodes that are still present after updates.

The purpose of this analysis is to gain an understanding of the computational improvements we realize with this new method. The python version of this algorithm isn't optimized for speed, it's written to reflect the C code and to serve as documentation for the inner workings of the C code. So, we won't be measuring runtime, instead we'll be measuring:

* Node sample set unions/subtractions
* Calls to the summary function

The most computationally expensive parts of this algorithm are calls to the summary function and the manipulation of sample set bit arrays. By measuring the relative number of expensive operations that each algorithm performs, we can estimate the time improvements.

# Simulations
Let's simulate some trees with reasonable complexity to benchmark with. 

In [2]:
import msprime

In [3]:
ts = msprime.sim_ancestry(
    samples=25,
    sequence_length=1e2,
    recombination_rate=1e-2,
    ploidy=1,
    population_size=25,
    random_seed=23,
)

The python code still can't handle large trees in a reasonable amount of time, so the simulation must be relatively small

In [4]:
ts

Tree Sequence,Unnamed: 1
Trees,82
Sequence Length,100.0
Time Units,generations
Sample Nodes,25
Total Size,25.8 KiB
Metadata,No Metadata

Table,Rows,Size,Has Metadata
Edges,504,15.8 KiB,
Individuals,25,724 Bytes,
Migrations,0,8 Bytes,
Mutations,0,16 Bytes,
Nodes,150,4.1 KiB,
Populations,1,224 Bytes,✅
Provenances,1,1005 Bytes,
Sites,0,16 Bytes,


A bit of code for benchmarking

In [5]:
import sys
from pathlib import Path

TS_PATH = str(Path("~/repo/tskit/python").expanduser())
if not TS_PATH in sys.path:
    sys.path.insert(0, TS_PATH)

from tests.test_ld_matrix import ld_matrix

In [6]:
import io
import sys


def capture_stdout(func, *args, **kwargs):
    """
    We go through this dance so that we can read print statements
    directly from the code
    """
    buf = io.BytesIO()
    try:
        orig = sys.stdout
        sys.stdout = buf
        stats = func(*args, **kwargs)
    finally:
        sys.stdout = orig
    buf.seek(0)
    return buf, stats


def read(buf, has_header=False):
    return pl.read_csv(
        buf,
        has_header=has_header,
        separator="\t",
        schema={"op": pl.String, "tree": pl.Int32, "p": pl.Int32, "c": pl.Int32},
    )

First, let's benchmarking the old code. 

In [7]:
%%time
buf, ld = capture_stdout(ld_matrix, ts, mode="branch")

CPU times: user 2min 31s, sys: 59.2 ms, total: 2min 31s
Wall time: 2min 31s


In [81]:
# old = read(buf)

In [13]:
old = read("old-algo-bench.tsv.xz", has_header=True)

In [82]:
old['op'].value_counts(sort=True)

op,count
str,u32
"""summary_func""",7588608
"""add_samples_child""",79680
"""add_samples""",61918
"""add_stat""",61918
"""subtract_stat""",57934
"""subtract_samples""",57934


A breakdown of the operations
* `add_samples_child` storing child samples below the added node to subtract as we climb the tree to propagate changes upward.
* `add_samples` performing a union operation on `child_samples` and the samples contained under a node
* `add_stat` calling the summary function and adding to the running total for samples under a node
* `subtract_stat` calling the summary function and removing from the running total for samples under a node
* `subtract_samples` performing a set subtraction, removing `child_samples` from samples contained under a node

Next, let's benchmark the new code. 

In [7]:
%%time
buf, ld = capture_stdout(ld_matrix, ts, mode="branch")

CPU times: user 1min 40s, sys: 0 ns, total: 1min 40s
Wall time: 1min 40s


In [8]:
new = read_buf(buf)

In [9]:
new.write_csv("new-algo-bench.tsv.xz", separator="\t")

In [17]:
new['op'].value_counts(sort=True)

op,count
str,u32
"""summary_func""",5203392
"""subtract_stat""",66566
"""add_stat""",66566
"""add_samples""",61918
"""subtract_samples""",57934


The operation definitions are the same for the new algorithm, but we don't have to track child samples anymore

In [18]:
old['op'].value_counts(sort=True)

op,count
str,u32
"""summary_func""",7588608
"""add_samples_child""",79680
"""add_samples""",61918
"""add_stat""",61918
"""subtract_stat""",57934
"""subtract_samples""",57934


In [11]:
(7588608 - 5203392) / 7588608

0.31431535269709543

Overall, we see around a 30% reduction in the number of summary function calls, and we natively rid ourselves of the NaN issue.

In [79]:
import numpy as np

def count_tree_updates(ts):
    """store counts of 'parent ops' and 'root ops' """
    ops = {t: [0, 0] for t in range(ts.num_trees)}
    parents = -np.ones(ts.num_nodes, dtype=np.int32)
    for t, (_, i, o) in enumerate(ts.edge_diffs()):
        for e in o:
            p = e.parent
            c = e.child
            in_parent = False
            while p != -1:
                if in_parent:
                    ops[t][0] += 1
                else:
                    ops[t][1] += 1
                p = parents[p]
                in_parent = True
            parents[c] = -1
        for e in i:
            p = e.parent
            c = e.child
            parents[c] = p
            in_parent = False
            while p != -1:
                if in_parent:
                    ops[t][0] += 1
                else:
                    ops[t][1] += 1
                p = parents[p]
                in_parent = True
    return ops

In [65]:
counts = count_tree_updates(ts)
counts = np.array([counts[t] for t in range(ts.num_trees)])

This is what I estimate the theretical lower bound of summary function calls to be, we're getting closer with this new algorithm

In [104]:
theoret_summary_func = counts.sum() ** 2 + counts[0].sum() ** 2
theoret_summary_func

4050448

In [106]:
new_summary_func = len(new.select(pl.col('op') == 'summary_func'))
new_summary_func

5456376

In [107]:
new_summary_func - theoret_summary_func

1405928