### Design

The general idea is to fit into the `tsk_treeseq_sample_count_stat` framework while using pairs of windows instead of the usual 1-d windows. I think this should be easy enough to do if we add another flag to `tsk_flags_t`, perhaps `TSK_WINDOW_PAIRS`.

We'd like to be able to compute the LD between two completely arbitrary windows. To do this, we'll want our own counterpart to `tsk_treeseq_check_windows`, something like `tsk_treeseq_check_window_pairs`. On this code path, we will relax two of the constraints currently imposed on windows. We will no longer require that:

1. Windows span the whole sequence, starting at 0 and ending at the sequence length L
1. Window intervals are non-overlapping

To allow overlapping windows (the second relaxation above), we will seek backward whenever the next window starts before the current one ends. This should have very little impact on the performance of our implementation of two-site statistics.

We will, however, still require:

1. Righthand windows are sorted by the start of the interval (they will be passed over in a tight loop)
    1. Righthand windows with the same start position are then sorted by interval size, smallest first
1. There is at least one window pair
1. Windows are paired (see below)
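
As a minimal sketch of the ordering requirement above, window pairs could be sorted by the righthand window's start and then by its size, keeping each left window attached to its partner. The helper name `sort_window_pairs` and the list-of-tuples representation are assumptions made for illustration; they are not part of the proposed API.

```python
def sort_window_pairs(pairs):
    """Sort (left, right) window pairs by the right window's start, then size.

    `pairs` is a list of ((l_start, l_stop), (r_start, r_stop)) tuples.
    This representation is hypothetical and used only for illustration.
    """
    def right_key(pair):
        _, (r_start, r_stop) = pair
        return (r_start, r_stop - r_start)

    # sorted() is stable, so pairs with identical right windows keep their order
    return sorted(pairs, key=right_key)


pairs = [((1, 2), (5, 6)), ((2, 3), (2, 4)), ((1, 2), (2, 3)), ((1, 2), (1, 2))]
sorted_pairs = sort_window_pairs(pairs)
print(sorted_pairs)
```

Sorting whole pairs rather than the right windows alone keeps the left/right pairing intact, which is the third requirement in the list.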

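To use the existing framework, we will pass in a flat 1-d array of window boundaries and split it into two vectors of intervals: the first half of the array holds the left windows and the second half holds the right windows.

For example, `[1, 2, 3, 4]` will become left `[[1, 2)]` and right `[[3, 4)]`. NB: windows are half-open, `[left, right)`.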
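
The transformation above can be sketched in a few lines. The function name `split_window_pairs` is hypothetical, and the multiple-of-4 length check is my own inference (each window pair needs two `(start, stop)` intervals), not something the existing framework enforces.

```python
def split_window_pairs(windows):
    """Split a flat array into left and right lists of half-open intervals.

    Assumes the first half of `windows` holds the left intervals and the
    second half holds the right intervals.
    """
    if len(windows) == 0 or len(windows) % 4 != 0:
        raise ValueError("expected a non-empty array whose length is a multiple of 4")
    midpoint = len(windows) // 2
    left = [(windows[i], windows[i + 1]) for i in range(0, midpoint, 2)]
    right = [(windows[i], windows[i + 1]) for i in range(midpoint, len(windows), 2)]
    return left, right


left, right = split_window_pairs([1, 2, 3, 4])
print(left, right)
```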

What follows is a toy implementation of this idea. If we think this is a good direction, I can implement it in the C API.

In [1]:
import numpy as np

In [2]:
def print_windows(left, right):
    print('left: ', [(left[i], left[i + 1]) for i in range(0, len(left), 2)])
    print('right: ', [(right[i], right[i + 1]) for i in range(0, len(right), 2)])

def check_windows(windows, num_windows, print_=True):
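    # num_windows is the length of the flat array; each window pair uses 4 values
    assert num_windows > 0 and num_windows % 4 == 0, (
        'there must be at least one window pair (a multiple of 4 values)'
    )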
    row_len = 2
    midpoint = num_windows // row_len
    left_ivls = windows[:midpoint]
    right_ivls = windows[midpoint:]
    if print_:
        print_windows(left_ivls, right_ivls)

    for i in range(0, midpoint, 2):
        size_left = left_ivls[i + 1] - left_ivls[i]
        size_right = right_ivls[i + 1] - right_ivls[i]
        if size_left <= 0:
            raise ValueError(
                f'left interval invalid, idx={i} : {left_ivls[i]} {left_ivls[i + 1]}'
            )
        elif size_right <= 0:
            raise ValueError(
                f'right interval invalid, idx={i} : {right_ivls[i]} {right_ivls[i + 1]}'
            )

    for i in range(0, midpoint - 2, 2):  # select index of start of each interval
        if right_ivls[i] > right_ivls[i + 2]:
            raise ValueError(
                'right intervals must be sorted by position, '
                f'idx={i}: {right_ivls[i]} {right_ivls[i + 2]}'
            )

        elif right_ivls[i] == right_ivls[i + 2]:
            size_curr = right_ivls[i + 1] - right_ivls[i]
            size_next = right_ivls[i + 2 + 1] - right_ivls[i + 2]

            if size_next - size_curr < 0:
                raise ValueError(
                    'right intervals must be sorted by size, '
                    f'idx={i}: {size_curr} {size_next}'
                )

In [3]:
windows = [1, 2, 1, 2, 1, 2, 2, 3, 9, 15, 1, 2, 2, 4, 5, 6, 2, 3, 9, 15]
check_windows(windows, len(windows))

left:  [(1, 2), (1, 2), (1, 2), (2, 3), (9, 15)]
right:  [(1, 2), (2, 4), (5, 6), (2, 3), (9, 15)]


ValueError: right intervals must be sorted by position, idx=4: 5 2

In [4]:
windows = [1, 2, 1, 2, 1, 2, 2, 3, 9, 15, 2, 1, 2, 4, 5, 6, 2, 3, 9, 15]
check_windows(windows, len(windows))

left:  [(1, 2), (1, 2), (1, 2), (2, 3), (9, 15)]
right:  [(2, 1), (2, 4), (5, 6), (2, 3), (9, 15)]


ValueError: right interval invalid, idx=0 : 2 1

In [5]:
windows = [1, 2, 1, 2, 2, 3, 1, 2, 9, 10, 1, 2, 2, 4, 2, 3, 5, 6, 9, 10]
check_windows(windows, len(windows))

left:  [(1, 2), (1, 2), (2, 3), (1, 2), (9, 10)]
right:  [(1, 2), (2, 4), (2, 3), (5, 6), (9, 10)]


ValueError: right intervals must be sorted by size, idx=2: 2 1

In [6]:
windows = [1, 2, 1, 2, 2, 3, 1, 2, 9, 15, 1, 2, 2, 4, 2, 3, 5, 6, 9, 15]
check_windows(windows, len(windows))

left:  [(1, 2), (1, 2), (2, 3), (1, 2), (9, 15)]
right:  [(1, 2), (2, 4), (2, 3), (5, 6), (9, 15)]


ValueError: right intervals must be sorted by size, idx=2: 2 1

In [7]:
windows = [1, 2, 2, 3, 1, 2, 1, 2, 9, 15, 1, 2, 2, 3, 2, 4, 5, 6, 9, 15]
check_windows(windows, len(windows))

left:  [(1, 2), (2, 3), (1, 2), (1, 2), (9, 15)]
right:  [(1, 2), (2, 3), (2, 4), (5, 6), (9, 15)]


### Application

Now, let's demonstrate how this window format will be used when making comparisons between pairs of sites.

First, we would find the overall extent covered by the windows, from the earliest start to the latest stop, and restrict our computations to that extent.

Then, for each window pair, we'd compare every site in the left window with every site in the right window.
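
As a rough sketch of the first step, the extent can be read straight off the flat window array, since starts sit at even indices and stops at odd indices in both halves. The name `window_extent` is hypothetical, and treating the extent as the smallest start and the largest stop is one interpretation of the description above.

```python
def window_extent(windows):
    """Return (start, stop) spanning every interval in a flat window array."""
    starts = windows[0::2]  # even indices hold interval starts
    stops = windows[1::2]   # odd indices hold interval stops
    return min(starts), max(stops)


windows = [2, 4, 3, 6, 3, 8, 0, 4, 4, 8, 9, 12]
extent = window_extent(windows)
print(extent)
```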

In [8]:
# we'll represent sites purely by their position
# for simplicity's sake, there is one site at every integer position
sites = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]

In [9]:
windows = [2, 4, 3, 6, 3, 8, 0, 4, 4, 8, 9, 12]
check_windows(windows, len(windows))

left:  [(2, 4), (3, 6), (3, 8)]
right:  [(0, 4), (4, 8), (9, 12)]


In [10]:
def compare_sites(windows, num_windows, sites, num_sites):
    row_len = 2
    midpoint = num_windows // row_len
    left_ivls = windows[:midpoint]
    right_ivls = windows[midpoint:]

    # two arrays allocated, storing the site offsets for the left and right intervals
    left_offsets = get_offsets(left_ivls, midpoint, sites)
    right_offsets = get_offsets(right_ivls, midpoint, sites)
    # right_offsets = np.empty(midpoint, np.uint64)
    print('left_offsets', left_offsets)
    print('right_offsets', right_offsets)

    for w in range(0, midpoint, row_len):
        # inner = 0
        for site_1_idx in range(left_ivls[w], left_ivls[w + 1]):
            site_1 = sites[site_1_idx]
            # for site_2_idx in range(right_ivls[w] + inner, right_ivls[w + 1]):
            for site_2_idx in range(right_ivls[w], right_ivls[w + 1]):
                site_2 = sites[site_2_idx]
                print(w, site_1, site_2)
            # inner += 1

def get_offsets(windows, num_windows, sites):
    offsets = np.zeros(num_windows, np.uint64)
    win = 0
    s = 0
    while True:
        start = windows[win]
        stop = windows[win + 1]
        # seek to start
        while sites[s] < start:
            s += 1
        offsets[win] = s
        # seek within range
        # stop before running past the last site
        while s + 1 < len(sites) and sites[s + 1] < stop:
            s += 1
        offsets[win + 1] = s
        win += 2
        if win == num_windows:
            break
        if stop > windows[win]:
            # seek backward to next window start
            while s > 0 and sites[s] > windows[win]:
                s -= 1
    return offsets

In [11]:
compare_sites(windows, len(windows), sites, len(sites))

left_offsets [2 3 3 5 3 7]
right_offsets [ 0  3  4  7  9 11]
0 2 0
0 2 1
0 2 2
0 2 3
0 3 0
0 3 1
0 3 2
0 3 3
2 3 4
2 3 5
2 3 6
2 3 7
2 4 4
2 4 5
2 4 6
2 4 7
2 5 4
2 5 5
2 5 6
2 5 7
4 3 9
4 3 10
4 3 11
4 4 9
4 4 10
4 4 11
4 5 9
4 5 10
4 5 11
4 6 9
4 6 10
4 6 11
4 7 9
4 7 10
4 7 11
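
The toy `compare_sites` above indexes `sites` directly by position, which only works because there is one site at every integer. As a hedged alternative for arbitrary sorted site positions, the sketch below maps each window to a range of site indices using NumPy's binary search (`np.searchsorted`) instead of the linear forward/backward seek in `get_offsets`. The function name `site_index_ranges` is hypothetical, and unlike `get_offsets` it returns an exclusive stop index.

```python
import numpy as np


def site_index_ranges(ivls, sites):
    """Map half-open position intervals to half-open ranges of site indices.

    `ivls` is a flat array of (start, stop) positions. Assumes `sites` is
    sorted in increasing order.
    """
    ivls = np.asarray(ivls).reshape(-1, 2)
    # index of the first site with position >= start
    first = np.searchsorted(sites, ivls[:, 0], side="left")
    # index one past the last site with position < stop
    stop = np.searchsorted(sites, ivls[:, 1], side="left")
    return np.column_stack([first, stop])


sites = np.array([0, 10, 20, 30, 40, 50])
ranges = site_index_ranges([5, 35, 20, 60], sites)
for first, stop in ranges:
    print(sites[first:stop])
```

Because each window is looked up independently, overlapping or unsorted windows need no backward seeking here, at the cost of a binary search per boundary.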
