Skip to content
This repository has been archived by the owner on Feb 22, 2020. It is now read-only.

Commit

Permalink
fix(indexer): fix vec np.concat
Browse files Browse the repository at this point in the history
  • Loading branch information
hanhxiao committed Sep 2, 2019
1 parent 2ba135d commit 2d6c70f
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 6 deletions.
2 changes: 1 addition & 1 deletion gnes/indexer/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def query(self, keys: np.ndarray, top_k: int, *args, **kwargs) -> List[List[Tupl
def query_and_score(self, q_chunks: List['gnes_pb2.Chunk'], top_k: int, *args, **kwargs) -> List[
'gnes_pb2.Response.QueryResponse.ScoredResult']:
vecs = [blob2array(c.embedding) for c in q_chunks]
queried_results = self.query(np.concatenate(vecs, 0), top_k=top_k)
queried_results = self.query(np.stack(vecs), top_k=top_k)
results = []
for q_chunk, topk_chunks in zip(q_chunks, queried_results):
for _doc_id, _offset, _weight, _relevance in topk_chunks:
Expand Down
8 changes: 3 additions & 5 deletions gnes/service/indexer.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,17 +48,15 @@ def _handler_chunk_index(self, msg: 'gnes_pb2.Message'):
self.logger.warning('document (doc_id=%s) contains no chunks!' % d.doc_id)
continue

for c in d.chunks:
self.logger.info(c.embedding)
vecs += [blob2array(c.embedding) for c in d.chunks]
doc_ids += [d.doc_id] * len(d.chunks)
offsets += [c.offset for c in d.chunks]
weights += [c.weight for c in d.chunks]

self.logger.info('%d %d %d %d' % (len(vecs), len(doc_ids), len(offsets), len(weights)))
self.logger.info(np.concatenate(vecs, 0).shape)
# self.logger.info('%d %d %d %d' % (len(vecs), len(doc_ids), len(offsets), len(weights)))
# self.logger.info(np.stack(vecs).shape)
if vecs:
self._model.add(list(zip(doc_ids, offsets)), np.concatenate(vecs, 0), weights)
self._model.add(list(zip(doc_ids, offsets)), np.stack(vecs), weights)

def _handler_doc_index(self, msg: 'gnes_pb2.Message'):
self._model.add([d.doc_id for d in msg.request.index.docs],
Expand Down

0 comments on commit 2d6c70f

Please sign in to comment.