This repository has been archived by the owner on Feb 22, 2020. It is now read-only.

revert(service): revert encoder service
hanhxiao committed Sep 4, 2019
1 parent 35fa3ba commit 1bbc435
Showing 2 changed files with 25 additions and 26 deletions.
47 changes: 23 additions & 24 deletions gnes/service/encoder.py
@@ -16,7 +16,7 @@
 
 
 from typing import List, Union
 
-from .base import BaseService as BS, MessageHandler, ServiceError
+from .base import BaseService as BS, MessageHandler
 from ..proto import gnes_pb2, array2blob, blob2array
 
@@ -35,11 +35,13 @@ def embed_chunks_in_docs(self, docs: Union[List['gnes_pb2.Document'], 'gnes_pb2.
             docs = [docs]
 
         contents = []
-        ids = []
-        embeds = None
+        chunks = []
 
         for d in docs:
-            ids.append(len(d.chunks))
+            if not d.chunks:
+                self.logger.warning('document (doc_id=%s) contains no chunks!' % d.doc_id)
+                continue
+
             for c in d.chunks:
                 if d.doc_type == gnes_pb2.Document.TEXT:
                     contents.append(c.text)
@@ -48,34 +50,32 @@ def embed_chunks_in_docs(self, docs: Union[List['gnes_pb2.Document'], 'gnes_pb2.
                 else:
                     self.logger.warning(
                         'chunk content is in type: %s, dont kow how to handle that, ignored' % c.WhichOneof('content'))
+                chunks.append(c)
 
-        if do_encoding:
-            embeds = self._model.encode(contents)
-            if sum(ids) != embeds.shape[0]:
-                raise ServiceError(
-                    'mismatched %d chunks and a %s shape embedding, '
-                    'the first dimension must be the same' % (sum(ids), embeds.shape))
-            idx = 0
-            for d in docs:
-                for c in d.chunks:
-                    c.embedding.CopyFrom(array2blob(embeds[idx]))
-                    idx += 1
+        if do_encoding and contents:
+            try:
+                embeds = self._model.encode(contents)
+                if len(chunks) != embeds.shape[0]:
+                    self.logger.error(
+                        'mismatched %d chunks and a %s shape embedding, '
+                        'the first dimension must be the same' % (len(chunks), embeds.shape))
+                for idx, c in enumerate(chunks):
+                    c.embedding.CopyFrom(array2blob(embeds[idx]))
+            except Exception as ex:
+                self.logger.error(ex, exc_info=True)
+                self.logger.warning('encoder service throws an exception, '
+                                    'the sequel pipeline may not work properly')
 
-        return contents, embeds
+        return contents
 
     @handler.register(gnes_pb2.Request.IndexRequest)
     def _handler_index(self, msg: 'gnes_pb2.Message'):
-        _, embeds = self.embed_chunks_in_docs(msg.request.index.docs)
-        idx = 0
-        for d in msg.request.index.docs:
-            for c in d.chunks:
-                c.embedding.CopyFrom(array2blob(embeds[idx]))
-                idx += 1
+        self.embed_chunks_in_docs(msg.request.index.docs)
 
     @handler.register(gnes_pb2.Request.TrainRequest)
     def _handler_train(self, msg: 'gnes_pb2.Message'):
         if msg.request.train.docs:
-            contents, _ = self.embed_chunks_in_docs(msg.request.train.docs, do_encoding=False)
+            contents = self.embed_chunks_in_docs(msg.request.train.docs, do_encoding=False)
             self.train_data.extend(contents)
             msg.response.train.status = gnes_pb2.Response.PENDING
             # raise BlockMessage
@@ -88,5 +88,4 @@ def _handler_train(self, msg: 'gnes_pb2.Message'):
 
     @handler.register(gnes_pb2.Request.QueryRequest)
     def _handler_search(self, msg: 'gnes_pb2.Message'):
-        _, embeds = self.embed_chunks_in_docs(msg.request.search.query, is_input_list=False)
-        msg.request.search.query.chunk_embeddings.CopyFrom(array2blob(embeds))
+        self.embed_chunks_in_docs(msg.request.search.query, is_input_list=False)
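
For readers skimming the hunks above, here is a minimal, self-contained sketch of the control flow this revert restores in embed_chunks_in_docs: collect chunk contents, encode them in one batch, and log problems instead of raising them. The Doc/Chunk dataclasses and toy_encode below are stand-ins invented for illustration, not the GNES protobufs or encoder API; only the collect, encode, log-and-continue structure mirrors the right-hand side of the diff.

import logging
from dataclasses import dataclass, field
from typing import List, Optional

import numpy as np

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger('encoder-sketch')


@dataclass
class Chunk:
    text: str
    embedding: Optional[np.ndarray] = None  # filled in by the encoding step


@dataclass
class Doc:
    doc_id: int
    chunks: List[Chunk] = field(default_factory=list)


def toy_encode(texts: List[str]) -> np.ndarray:
    # stand-in for self._model.encode(): one 4-dim vector per input text
    return np.random.rand(len(texts), 4).astype(np.float32)


def embed_chunks_in_docs(docs: List[Doc], do_encoding: bool = True) -> List[str]:
    contents, chunks = [], []
    for d in docs:
        if not d.chunks:
            logger.warning('document (doc_id=%s) contains no chunks!', d.doc_id)
            continue
        for c in d.chunks:
            contents.append(c.text)
            chunks.append(c)

    if do_encoding and contents:
        try:
            embeds = toy_encode(contents)
            if len(chunks) != embeds.shape[0]:
                # a mismatch is logged, not raised, so one bad batch does not kill the service
                logger.error('mismatched %d chunks and a %s shape embedding',
                             len(chunks), embeds.shape)
            for idx, c in enumerate(chunks):
                c.embedding = embeds[idx]
        except Exception:
            logger.exception('encoding failed; downstream steps may not work properly')

    return contents


if __name__ == '__main__':
    docs = [Doc(1, [Chunk('hello'), Chunk('world')]), Doc(2)]  # Doc(2) triggers the warning
    embed_chunks_in_docs(docs)
    print(docs[0].chunks[0].embedding.shape)  # (4,)

The same shift shows up in the handlers: _handler_index, _handler_train and _handler_search no longer unpack a (contents, embeds) tuple, and since nothing is raised any more, the ServiceError import disappears from the module.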
4 changes: 2 additions & 2 deletions gnes/service/indexer.py
@@ -53,10 +53,10 @@ def _handler_chunk_index(self, msg: 'gnes_pb2.Message'):
             offsets += [c.offset for c in d.chunks]
             weights += [c.weight for c in d.chunks]
 
-        # self.logger.info('%d %d %d %d' % (len(vecs), len(doc_ids), len(offsets), len(weights)))
-        # self.logger.info(np.stack(vecs).shape)
         if vecs:
             self._model.add(list(zip(doc_ids, offsets)), np.stack(vecs), weights)
+        else:
+            self.logger.warning('chunks contain no embedded vectors, %the indexer will do nothing')
 
     def _handler_doc_index(self, msg: 'gnes_pb2.Message'):
         self._model.add([d.doc_id for d in msg.request.index.docs],
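On the indexer side, the new else branch matters because np.stack() raises a ValueError when given an empty list, so an index message whose chunks carry no embedded vectors would otherwise crash the handler. A small sketch of that guard follows; the function and variable names here are illustrative, not the GNES API.

import logging
import numpy as np

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger('indexer-sketch')


def add_chunks_to_index(vecs, doc_ids, offsets, weights, index):
    # mirrors the guard added in the diff: stack and add only when there is something to add
    if vecs:
        index.append((list(zip(doc_ids, offsets)), np.stack(vecs), weights))
    else:
        logger.warning('chunks contain no embedded vectors, the indexer will do nothing')


if __name__ == '__main__':
    index = []
    add_chunks_to_index([], [], [], [], index)  # only warns, nothing is added
    add_chunks_to_index([np.ones(4), np.zeros(4)], [1, 1], [0, 1], [1.0, 0.5], index)
    print(len(index))  # 1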

