Skip to content
This repository has been archived by the owner on Feb 22, 2020. It is now read-only.

Commit

Permalink
feat(indexer): delay the num_dim spec on first add
Browse files Browse the repository at this point in the history
  • Loading branch information
hanhxiao committed Sep 29, 2019
1 parent cb4e46a commit 946df39
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 3 deletions.
19 changes: 17 additions & 2 deletions gnes/indexer/chunk/annoy.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,16 @@

class AnnoyIndexer(BCI):

def __init__(self, num_dim: int, data_path: str, metric: str = 'angular', n_trees=10, *args, **kwargs):
def __init__(self, num_dim: int, data_path: str, metric: str = 'angular', n_trees: int = 10, *args, **kwargs):
"""
Initialize an AnnoyIndexer
:param num_dim: when set to -1, num_dim is inferred from the vectors passed to the first .add() call
:param data_path: index data file managed by the annoy indexer
:param metric:
:param n_trees:
:param args:
:param kwargs:
"""
super().__init__(*args, **kwargs)
self.num_dim = num_dim
self.data_path = data_path
Expand All @@ -34,7 +43,7 @@ def __init__(self, num_dim: int, data_path: str, metric: str = 'angular', n_tree

def post_init(self):
from annoy import AnnoyIndex
self._index = AnnoyIndex(self.num_dim, self.metric)
self._index = AnnoyIndex(self.num_dim, self.metric) if self.num_dim >= 0 else None
try:
if not os.path.exists(self.data_path):
raise FileNotFoundError('"data_path" is not exist')
Expand All @@ -54,6 +63,12 @@ def add(self, keys: List[Tuple[int, Any]], vectors: np.ndarray, weights: List[fl
if vectors.dtype != np.float32:
raise ValueError("vectors should be ndarray of float32")

if self._index is None:
from annoy import AnnoyIndex
# num_dim was unknown during init, so take it from the first batch of vectors
self.num_dim = vectors.shape[1]
self._index = AnnoyIndex(self.num_dim, self.metric)

for idx, vec in enumerate(vectors):
self._index.add_item(last_idx + idx, vec)

Expand Down
13 changes: 12 additions & 1 deletion gnes/indexer/chunk/faiss.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,11 @@
class FaissIndexer(BCI):

def __init__(self, num_dim: int, index_key: str, data_path: str, *args, **kwargs):
"""
Initialize a FaissIndexer
:param num_dim: when set to -1, num_dim is inferred from the vectors passed to the first .add() call
:param data_path: index data file managed by the faiss indexer
"""
super().__init__(*args, **kwargs)
self.data_path = data_path
self.num_dim = num_dim
Expand All @@ -42,7 +47,7 @@ def post_init(self):
self._faiss_index = faiss.read_index(self.data_path)
except (RuntimeError, FileNotFoundError, IsADirectoryError):
self.logger.warning('fail to load model from %s, will init an empty one' % self.data_path)
self._faiss_index = faiss.index_factory(self.num_dim, self.index_key)
self._faiss_index = faiss.index_factory(self.num_dim, self.index_key) if self.num_dim > 0 else None

@BCI.update_helper_indexer
def add(self, keys: List[Tuple[int, Any]], vectors: np.ndarray, weights: List[float], *args, **kwargs):
Expand All @@ -52,6 +57,12 @@ def add(self, keys: List[Tuple[int, Any]], vectors: np.ndarray, weights: List[fl
if vectors.dtype != np.float32:
raise ValueError('vectors should be ndarray of float32')

if self._faiss_index is None:
import faiss
# num_dim was unknown during init, so take it from the first batch of vectors
self.num_dim = vectors.shape[1]
self._faiss_index = faiss.index_factory(self.num_dim, self.index_key)

self._faiss_index.add(vectors)

def query(self, keys: np.ndarray, top_k: int, *args, **kwargs) -> List[List[Tuple]]:
Expand Down

0 comments on commit 946df39

Please sign in to comment.