Skip to content
Permalink
Browse files

feat(indexer): delay the num_dim spec on first add

  • Loading branch information...
hanxiao committed Sep 29, 2019
1 parent cb4e46a commit 946df39bd4c9abe8525bc22011bc06d2f7103cfc
Showing with 29 additions and 3 deletions.
  1. +17 −2 gnes/indexer/chunk/annoy.py
  2. +12 −1 gnes/indexer/chunk/faiss.py
@@ -24,7 +24,16 @@

class AnnoyIndexer(BCI):

def __init__(self, num_dim: int, data_path: str, metric: str = 'angular', n_trees=10, *args, **kwargs):
def __init__(self, num_dim: int, data_path: str, metric: str = 'angular', n_trees: int = 10, *args, **kwargs):
"""
Initialize an AnnoyIndexer
:param num_dim: dimensionality of the indexed vectors; when set to -1, it is inferred from the vectors passed to the first .add() call
:param data_path: index data file managed by the annoy indexer
:param metric:
:param n_trees:
:param args:
:param kwargs:
"""
super().__init__(*args, **kwargs)
self.num_dim = num_dim
self.data_path = data_path
@@ -34,7 +43,7 @@ def __init__(self, num_dim: int, data_path: str, metric: str = 'angular', n_tree

def post_init(self):
from annoy import AnnoyIndex
self._index = AnnoyIndex(self.num_dim, self.metric)
self._index = AnnoyIndex(self.num_dim, self.metric) if self.num_dim >= 0 else None
try:
if not os.path.exists(self.data_path):
raise FileNotFoundError('"data_path" is not exist')
@@ -54,6 +63,12 @@ def add(self, keys: List[Tuple[int, Any]], vectors: np.ndarray, weights: List[fl
if vectors.dtype != np.float32:
raise ValueError("vectors should be ndarray of float32")

if self._index is None:
from annoy import AnnoyIndex
# num_dim was unknown during init, so infer it from the first batch of vectors
self.num_dim = vectors.shape[1]
self._index = AnnoyIndex(self.num_dim, self.metric)

for idx, vec in enumerate(vectors):
self._index.add_item(last_idx + idx, vec)

@@ -26,6 +26,11 @@
class FaissIndexer(BCI):

def __init__(self, num_dim: int, index_key: str, data_path: str, *args, **kwargs):
"""
Initialize a FaissIndexer
:param num_dim: dimensionality of the indexed vectors; when set to -1, it is inferred from the vectors passed to the first .add() call
:param data_path: index data file managed by the faiss indexer
"""
super().__init__(*args, **kwargs)
self.data_path = data_path
self.num_dim = num_dim
@@ -42,7 +47,7 @@ def post_init(self):
self._faiss_index = faiss.read_index(self.data_path)
except (RuntimeError, FileNotFoundError, IsADirectoryError):
self.logger.warning('fail to load model from %s, will init an empty one' % self.data_path)
self._faiss_index = faiss.index_factory(self.num_dim, self.index_key)
self._faiss_index = faiss.index_factory(self.num_dim, self.index_key) if self.num_dim > 0 else None

@BCI.update_helper_indexer
def add(self, keys: List[Tuple[int, Any]], vectors: np.ndarray, weights: List[float], *args, **kwargs):
@@ -52,6 +57,12 @@ def add(self, keys: List[Tuple[int, Any]], vectors: np.ndarray, weights: List[fl
if vectors.dtype != np.float32:
raise ValueError('vectors should be ndarray of float32')

if self._faiss_index is None:
import faiss
# num_dim was unknown during init, so infer it from the first batch of vectors
self.num_dim = vectors.shape[1]
self._faiss_index = faiss.index_factory(self.num_dim, self.index_key)

self._faiss_index.add(vectors)

def query(self, keys: np.ndarray, top_k: int, *args, **kwargs) -> List[List[Tuple]]:

0 comments on commit 946df39

Please sign in to comment.
You can’t perform that action at this time.