Skip to content
This repository has been archived by the owner on Feb 22, 2020. It is now read-only.

Commit

Permalink
fix(indexer): fix numpy indexer
Browse files Browse the repository at this point in the history
  • Loading branch information
hanhxiao committed Sep 5, 2019
1 parent 91762ff commit cd53a24
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 14 deletions.
31 changes: 18 additions & 13 deletions gnes/indexer/chunk/numpy.py
Expand Up @@ -23,41 +23,46 @@




class NumpyIndexer(BaseChunkIndexer): class NumpyIndexer(BaseChunkIndexer):
"""An exhaustive search indexer using numpy
The distance is computed as L1 distance normalized by the number of dimension
"""


def __init__(self, num_bytes: int = None, *args, **kwargs): def __init__(self, is_binary: bool = False, *args, **kwargs):
super().__init__() super().__init__()
self.num_bytes = num_bytes self._num_dim = None
self._vectors = None # type: np.ndarray self._vectors = None # type: np.ndarray
self._is_binary = is_binary
self._key_info_indexer = ListKeyIndexer() self._key_info_indexer = ListKeyIndexer()


def add(self, keys: List[Tuple[int, Any]], vectors: np.ndarray, weights: List[float], *args, def add(self, keys: List[Tuple[int, Any]], vectors: np.ndarray, weights: List[float], *args,
**kwargs): **kwargs):
if len(vectors) % len(keys) != 0: if len(vectors) % len(keys) != 0:
raise ValueError('vectors bytes should be divided by doc_ids') raise ValueError('vectors bytes should be divided by doc_ids')


if not self.num_bytes: if not self._num_dim:
self.num_bytes = vectors.shape[1] self._num_dim = vectors.shape[1]
elif self.num_bytes != vectors.shape[1]: elif self._num_dim != vectors.shape[1]:
raise ValueError( raise ValueError(
"vectors' shape [%d, %d] does not match with indexer's dim: %d" % "vectors' shape [%d, %d] does not match with indexer's dim: %d" %
(vectors.shape[0], vectors.shape[1], self.num_bytes)) (vectors.shape[0], vectors.shape[1], self._num_dim))


if self._vectors is not None: if self._vectors is not None:
self._vectors = np.concatenate([self._vectors, vectors], axis=0) self._vectors = np.concatenate([self._vectors, vectors], axis=0)
else: else:
self._vectors = vectors self._vectors = vectors
self._key_info_indexer.add(keys, weights) self._key_info_indexer.add(keys, weights)


def query(self, keys: np.ndarray, top_k: int, *args, **kwargs def query(self, keys: np.ndarray, top_k: int, *args, **kwargs) -> List[List[Tuple]]:
) -> List[List[Tuple]]: dist = np.abs(np.expand_dims(keys, axis=1) - np.expand_dims(self._vectors, axis=0))
keys = np.expand_dims(keys, axis=1)
dist = keys - np.expand_dims(self._vectors, axis=0) if self._is_binary:
score = np.sum(np.minimum(np.abs(dist), 1), -1) / self.num_bytes dist = np.minimum(dist, 1)

score = np.sum(dist, -1) / self._num_dim


ret = [] ret = []
for ids in score: for ids in score:
rk = sorted(enumerate(ids), key=lambda x: -x[1]) rk = sorted(enumerate(ids), key=lambda x: x[1])[:top_k]
chunk_info = self._key_info_indexer.query([j[0] for j in rk]) chunk_info = self._key_info_indexer.query([j[0] for j in rk])

ret.append([(*r, s) for r, s in zip(chunk_info, [j[1] for j in rk])]) ret.append([(*r, s) for r, s in zip(chunk_info, [j[1] for j in rk])])
return ret return ret
2 changes: 1 addition & 1 deletion gnes/service/base.py
Expand Up @@ -267,7 +267,7 @@ def post_handler(self, msg: 'gnes_pb2.Message', *args, **kwargs):
def _hook_warn_body_type_change(self, msg: 'gnes_pb2.Message', old_body_type: str, *args, **kwargs): def _hook_warn_body_type_change(self, msg: 'gnes_pb2.Message', old_body_type: str, *args, **kwargs):
new_type = msg.WhichOneof('body') new_type = msg.WhichOneof('body')
if new_type != old_body_type: if new_type != old_body_type:
self.logger.warning('message body is changed from %s to %s' % (new_type, old_body_type)) self.logger.warning('message body is changed from %s to %s' % (old_body_type, new_type))


def _hook_sort_response(self, msg: 'gnes_pb2.Message', *args, **kwargs): def _hook_sort_response(self, msg: 'gnes_pb2.Message', *args, **kwargs):
if 'sorted_response' in self.args and self.args.sorted_response and msg.response.search.topk_results: if 'sorted_response' in self.args and self.args.sorted_response and msg.response.search.topk_results:
Expand Down

0 comments on commit cd53a24

Please sign in to comment.