Skip to content

Commit

Permalink
Force ancestors in same epoch to same dep level
Browse files Browse the repository at this point in the history
Nuclear option to fix logic bug (tskit-dev#486 (comment))
  • Loading branch information
hyanwong authored and benjeffery committed Apr 12, 2023
1 parent affd8e5 commit babf038
Showing 1 changed file with 24 additions and 9 deletions.
33 changes: 24 additions & 9 deletions tsinfer/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import collections
import copy
import heapq
import itertools
import json
import logging
import queue
Expand Down Expand Up @@ -1614,15 +1615,29 @@ def __init__(self, sample_data, ancestor_data, time_units=None, **kwargs):
anc_end = ancestor_data.ancestors_end[:]
anc_time = ancestor_data.ancestors_time[:]
dep_level = np.zeros(self.num_ancestors, dtype=int)
for anc_id, (lft, rgt, t) in enumerate(zip(anc_start, anc_end, anc_time)):
dependencies = np.where(
np.logical_and.reduce((anc_start < rgt, anc_end > lft, anc_time > t))
)[0]
for dep_id in dependencies:
if dep_level[dep_id] >= dep_level[anc_id]:
dep_level[anc_id] = dep_level[dep_id] + 1
for i, level in enumerate(np.unique(dep_level)):
assert i == level # we should have all levels from 0..n
anc_iter = enumerate(zip(anc_start, anc_end, anc_time))
for epoch_time, epoch_grp in itertools.groupby(anc_iter, key=lambda x: x[1][2]):
curr_epoch_start = None
for anc_id, (lft, rgt, t) in epoch_grp:
if curr_epoch_start is None:
curr_epoch_start = anc_id
assert epoch_time == t
# NB the np.where() overlap search below is currently slow, and should be optimised
prev_ancestors = slice(0, curr_epoch_start)
dependencies = np.where(
np.logical_and(
anc_start[prev_ancestors] < rgt, anc_end[prev_ancestors] > lft
)
)[0]
if len(dependencies) > 0:
dep_level[anc_id] = np.max(dep_level[dependencies]) + 1
# One issue is that overlapping ancestors within the same epoch can be at
# different dep levels => force all ancestors in this epoch to the max level
# Perhaps better to do this in a single pass rather than within this loop?
dep_level[curr_epoch_start : (anc_id + 1)] = np.max(
dep_level[curr_epoch_start : (anc_id + 1)]
)
for level in np.unique(dep_level):
if level > 0: # Only run matching for ancestors that have dependencies
self.ancestors_dependency_level[level] = np.where(dep_level == level)[0]

Expand Down

0 comments on commit babf038

Please sign in to comment.