Skip to content

Commit

Permalink
Debug and fix SBT search (#484)
Browse files Browse the repository at this point in the history
* add a script to check tree search
* change max_n_below -> min_n_below
* test for sbt search bug
* Bump version to 2.0.0a7
* Bump SBT db version
  • Loading branch information
ctb authored and luizirber committed Jun 5, 2018
1 parent c5791cc commit ad9999e
Show file tree
Hide file tree
Showing 7 changed files with 155 additions and 33 deletions.
2 changes: 1 addition & 1 deletion sourmash/VERSION
@@ -1 +1 @@
2.0.0a6
2.0.0a7
73 changes: 59 additions & 14 deletions sourmash/sbt.py
Expand Up @@ -305,7 +305,7 @@ def save(self, path, storage=None, sparseness=0.0):
str
full path to the new SBT description
"""
version = 3
version = 4

if path.endswith('.sbt.json'):
path = path[:-9]
Expand Down Expand Up @@ -395,6 +395,7 @@ def load(cls, location, leaf_loader=None, storage=None):
1: cls._load_v1,
2: cls._load_v2,
3: cls._load_v3,
4: cls._load_v4,
}

# @CTB hack: check to make sure khmer Nodegraph supports the
Expand Down Expand Up @@ -524,27 +525,71 @@ def _load_v3(cls, info, leaf_loader, dirname, storage):
# TODO: this might not be true with combine...
tree.next_node = max_node

tree._fill_max_n_below()
tree._fill_min_n_below()

return tree

def _fill_max_n_below(self):
@classmethod
def _load_v4(cls, info, leaf_loader, dirname, storage):
nodes = {int(k): v for (k, v) in info['nodes'].items()}

if not nodes:
raise ValueError("Empty tree!")

sbt_nodes = defaultdict(lambda: None)

klass = STORAGES[info['storage']['backend']]
if info['storage']['backend'] == "FSStorage":
storage = FSStorage(dirname, info['storage']['args']['path'])
elif storage is None:
storage = klass(**info['storage']['args'])

factory = GraphFactory(*info['factory']['args'])

max_node = 0
for k, node in nodes.items():
if node is None:
continue

if 'internal' in node['name']:
node['factory'] = factory
sbt_node = Node.load(node, storage)
else:
sbt_node = leaf_loader(node, storage)

sbt_nodes[k] = sbt_node
max_node = max(max_node, k)

tree = cls(factory, d=info['d'], storage=storage)
tree.nodes = sbt_nodes
tree.missing_nodes = {i for i in range(max_node)
if i not in sbt_nodes}
# TODO: this might not be true with combine...
tree.next_node = max_node

return tree

def _fill_min_n_below(self):
"""\
Propagate the smallest hash size below each node up the tree from
the leaves.
"""
for i, n in self.nodes.items():
if isinstance(n, Leaf):
parent = self.parent(i)
if parent.pos not in self.missing_nodes:
max_n_below = parent.node.metadata.get('max_n_below', 0)
max_n_below = max(len(n.data.minhash.get_mins()),
max_n_below)
parent.node.metadata['max_n_below'] = max_n_below
min_n_below = parent.node.metadata.get('min_n_below', 1)
min_n_below = min(len(n.data.minhash.get_mins()),
min_n_below)
parent.node.metadata['min_n_below'] = min_n_below

current = parent
parent = self.parent(parent.pos)
while parent and parent.pos not in self.missing_nodes:
max_n_below = parent.node.metadata.get('max_n_below', 0)
max_n_below = max(current.node.metadata['max_n_below'],
max_n_below)
parent.node.metadata['max_n_below'] = max_n_below
min_n_below = parent.node.metadata.get('min_n_below', 1)
min_n_below = min(current.node.metadata['min_n_below'],
min_n_below)
parent.node.metadata['min_n_below'] = min_n_below
current = parent
parent = self.parent(parent.pos)

Expand Down Expand Up @@ -699,9 +744,9 @@ def load(info, storage=None):

def update(self, parent):
parent.data.update(self.data)
max_n_below = max(parent.metadata.get('max_n_below', 0),
self.metadata.get('max_n_below'))
parent.metadata['max_n_below'] = max_n_below
min_n_below = min(parent.metadata.get('min_n_below', 1),
self.metadata.get('min_n_below'))
parent.metadata['min_n_below'] = min_n_below


class Leaf(object):
Expand Down
59 changes: 41 additions & 18 deletions sourmash/sbtmh.py
Expand Up @@ -54,10 +54,11 @@ def save(self, path):
def update(self, parent):
for v in self.data.minhash.get_mins():
parent.data.count(v)
max_n_below = parent.metadata.get('max_n_below', 0)
max_n_below = max(len(self.data.minhash.get_mins()),
max_n_below)
parent.metadata['max_n_below'] = max_n_below
min_n_below = parent.metadata.get('min_n_below', 1)
min_n_below = min(len(self.data.minhash.get_mins()),
min_n_below)

parent.metadata['min_n_below'] = min_n_below

@property
def data(self):
Expand All @@ -72,7 +73,39 @@ def data(self, new_data):
self._data = new_data


### Search functionality.

def _max_jaccard_underneath_internal_node(node, hashes):
"""\
calculate the maximum possibility similarity score below
this node, based on the number of matches in 'hashes' at this node,
divided by the smallest minhash size below this node.
This should yield be an upper bound on the Jaccard similarity
for any signature below this point.
"""
if len(hashes) == 0:
return 0.0

# count the maximum number of hash matches beneath this node
matches = sum(1 for value in hashes if node.data.get(value))

# get the size of the smallest collection of hashes below this point
min_n_below = node.metadata.get('min_n_below', -1)

if min_n_below == -1:
raise Exception('cannot do similarity search on this SBT; need to rebuild.')

# max of numerator divided by min of denominator => max Jaccard
max_score = float(matches) / min_n_below

return max_score


def search_minhashes(node, sig, threshold, results=None, downsample=True):
"""\
Default tree search function, searching for best Jaccard similarity.
"""
mins = sig.minhash.get_mins()
score = 0

Expand All @@ -88,13 +121,8 @@ def search_minhashes(node, sig, threshold, results=None, downsample=True):
else:
raise

else: # Node or Leaf, Nodegraph by minhash comparison
if len(mins):
matches = sum(1 for value in mins if node.data.get(value))
max_mins = node.metadata.get('max_n_below', -1)
if max_mins == -1:
raise Exception('cannot do similarity search on this SBT; need to rebuild.')
score = float(matches) / max_mins
else: # Node minhash comparison
score = _max_jaccard_underneath_internal_node(node, mins)

if results is not None:
results[node.name] = score
Expand Down Expand Up @@ -126,18 +154,13 @@ def search(self, node, sig, threshold, results=None):
else:
raise
else: # internal object, not leaf.
if len(mins):
matches = sum(1 for value in mins if node.data.get(value))
max_mins = node.metadata.get('max_n_below', -1)
if max_mins == -1:
raise Exception('cannot do similarity search on this SBT; need to rebuild.')
score = float(matches) / max_mins
score = _max_jaccard_underneath_internal_node(node, mins)

if results is not None:
results[node.name] = score

if score >= threshold:
# have we done better than this? if yes, truncate.
# have we done better than this elsewhere? if yes, truncate.
if score > self.best_match:
# update best if it's a leaf node...
if isinstance(node, SigLeaf):
Expand Down
1 change: 1 addition & 0 deletions tests/test-data/sbt-search-bug/bacteroides.sig

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions tests/test-data/sbt-search-bug/nano.sig

Large diffs are not rendered by default.

20 changes: 20 additions & 0 deletions tests/test_sourmash.py
Expand Up @@ -1238,6 +1238,26 @@ def test_do_sourmash_sbt_search_output():
assert 'short2.fa' in output


# check against a bug in sbt search triggered by incorrect max Jaccard
# calculation.
def test_do_sourmash_sbt_search_check_bug():
with utils.TempDirectory() as location:
testdata1 = utils.get_test_data('sbt-search-bug/nano.sig')
testdata2 = utils.get_test_data('sbt-search-bug/bacteroides.sig')

status, out, err = utils.runscript('sourmash',
['index', 'zzz', '-k', '31',
testdata1, testdata2],
in_directory=location)

assert os.path.exists(os.path.join(location, 'zzz.sbt.json'))

status, out, err = utils.runscript('sourmash',
['search', testdata1, 'zzz'],
in_directory=location)
assert '1 matches:' in out


def test_do_sourmash_sbt_move_and_search_output():
with utils.TempDirectory() as location:
testdata1 = utils.get_test_data('short.fa')
Expand Down
32 changes: 32 additions & 0 deletions utils/check-tree.py
@@ -0,0 +1,32 @@
#! /usr/bin/env python
"""
Check SBT search by taking every leaf node in a tree and checking to make
sure we can find it.
"""
import argparse
import sourmash
from sourmash import sourmash_args
from sourmash.sbtmh import search_minhashes

THRESHOLD=0.08


def main():
p = argparse.ArgumentParser()
p.add_argument('sbt')
args = p.parse_args()

db = sourmash.load_sbt_index(args.sbt)
threshold = THRESHOLD

for leaf in db.leaves():
query = leaf.data
matches = db.find(search_minhashes, query, threshold)
matches = list([ x.data for x in matches ])
if query not in matches:
print(query)
assert 0


if __name__ == '__main__':
main()

0 comments on commit ad9999e

Please sign in to comment.