Skip to content

Commit

Permalink
feat(indexer): add annoy search_k parameter
Browse files Browse the repository at this point in the history
  • Loading branch information
JoanFM committed Jul 10, 2020
1 parent 714604e commit 9f75f59
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 5 deletions.
9 changes: 5 additions & 4 deletions jina/executors/indexers/vector/annoy.py
Expand Up @@ -18,18 +18,21 @@ class AnnoyIndexer(BaseNumpyIndexer):
Annoy package dependency is only required at the query time.
"""

def __init__(self, metric: str = 'euclidean', n_trees: int = 10, search_k: int = -1, *args, **kwargs):
    """
    Initialize an AnnoyIndexer

    :param metric: Metric can be "angular", "euclidean", "manhattan", "hamming", or "dot"
    :param n_trees: builds a forest of n_trees trees. More trees gives higher precision when querying.
    :param search_k: upper bound on the number of nodes annoy inspects at query time.
        The default of -1 means "not provided", in which case annoy uses ``n_trees * k``,
        where ``k`` is the number of neighbours requested by the query.
    :param args: extra positional arguments, forwarded unchanged to the parent constructor
    :param kwargs: extra keyword arguments, forwarded unchanged to the parent constructor
    """
    super().__init__(*args, **kwargs)
    self.metric = metric
    self.n_trees = n_trees
    # Only stored here; it is handed to annoy when a query is issued.
    self.search_k = search_k

def build_advanced_index(self, vecs: 'np.ndarray'):
from annoy import AnnoyIndex
Expand All @@ -41,12 +44,10 @@ def build_advanced_index(self, vecs: 'np.ndarray'):
return _index

def query(self, keys: 'np.ndarray', top_k: int, *args, **kwargs) -> Tuple['np.ndarray', 'np.ndarray']:
    """
    Look up the approximate nearest neighbours of every vector in ``keys``.

    :param keys: the query vectors; iterated along the first axis, one lookup per entry
    :param top_k: number of neighbours requested for each query vector
    :param args: accepted for interface compatibility, not used here
    :param kwargs: accepted for interface compatibility, not used here
    :return: a tuple ``(indices, distances)`` of arrays with one row per query vector;
        the indices are the values annoy returned, translated through ``self.int2ext_key``
    """
    all_idx = []
    all_dist = []
    for k in keys:
        # Exactly one lookup per query vector. search_k is passed by keyword so it
        # cannot be confused with any other positional argument of annoy's API.
        ret, dist = self.query_handler.get_nns_by_vector(
            k, top_k, search_k=self.search_k, include_distances=True)
        all_idx.append(self.int2ext_key[ret])
        all_dist.append(dist)
    return np.array(all_idx), np.array(all_dist)
17 changes: 16 additions & 1 deletion tests/executors/indexers/vector/test_annoy.py
Expand Up @@ -64,7 +64,6 @@ def test_annoy_indexer(self):
self.assertTrue(os.path.exists(a.index_abspath))
index_abspath = a.index_abspath
save_abspath = a.save_abspath
# a.query(np.array(np.random.random([10, 5]), dtype=np.float32), top_k=4)

with BaseIndexer.load(save_abspath) as b:
idx, dist = b.query(query, top_k=4)
Expand All @@ -79,6 +78,22 @@ def test_annoy_indexer(self):

self.add_tmpfile(index_abspath, save_abspath)

def test_annoy_indexer_with_no_search_k(self):
    # Build and persist an index configured with search_k=0, then reload it
    # and check the shape of what a query returns.
    with AnnoyIndexer(index_filename='annoy.test.gz', search_k=0) as indexer:
        indexer.add(vec_idx, vec)
        indexer.save()
        self.assertTrue(os.path.exists(indexer.index_abspath))
        index_path = indexer.index_abspath
        save_path = indexer.save_abspath

    with BaseIndexer.load(save_path) as restored:
        indices, distances = restored.query(query, top_k=4)
        # search_k is 0, so the test expects no neighbours for any query row
        self.assertEqual(indices.shape, distances.shape)
        self.assertEqual(indices.shape, (10, 0))

    self.add_tmpfile(index_path, save_path)


# Run the test suite when this module is executed directly as a script.
if __name__ == '__main__':
    unittest.main()

0 comments on commit 9f75f59

Please sign in to comment.