Skip to content
Permalink
Browse files

fix(service): fix training logic in encoderservice

  • Loading branch information...
hanxiao committed Aug 29, 2019
1 parent 5828d20 commit c6183960ac224f813f4b7859e6b935f61a8dcf57
Showing with 13 additions and 10 deletions.
  1. +13 −10 gnes/service/encoder.py
@@ -28,12 +28,14 @@ def post_init(self):
self._model = self.load_model(BaseEncoder)
self.train_data = []

def embed_chunks_in_docs(self, docs: Union[List['gnes_pb2.Document'], 'gnes_pb2.Document']):
def embed_chunks_in_docs(self, docs: Union[List['gnes_pb2.Document'], 'gnes_pb2.Document'],
do_encoding: bool = True):
if not isinstance(docs, list):
docs = [docs]

contents = []
chunks = []
embeds = None

for d in docs:
for c in d.chunks:
@@ -46,14 +48,15 @@ def embed_chunks_in_docs(self, docs: Union[List['gnes_pb2.Document'], 'gnes_pb2.
raise ServiceError(
'chunk content is in type: %s, dont kow how to handle that' % c.WhichOneof('content'))

embeds = self._model.encode(contents)
if len(chunks) != embeds.shape[0]:
raise ServiceError(
'mismatched %d chunks and a %s shape embedding, '
'the first dimension must be the same' % (len(chunks), embeds.shape))
for idx, c in enumerate(chunks):
c.embedding.CopyFrom(array2blob(embeds[idx]))
return embeds
if do_encoding:
embeds = self._model.encode(contents)
if len(chunks) != embeds.shape[0]:
raise ServiceError(
'mismatched %d chunks and a %s shape embedding, '
'the first dimension must be the same' % (len(chunks), embeds.shape))
for idx, c in enumerate(chunks):
c.embedding.CopyFrom(array2blob(embeds[idx]))
return contents, embeds

@handler.register(gnes_pb2.Request.IndexRequest)
def _handler_index(self, msg: 'gnes_pb2.Message'):
@@ -62,7 +65,7 @@ def _handler_index(self, msg: 'gnes_pb2.Message'):
@handler.register(gnes_pb2.Request.TrainRequest)
def _handler_train(self, msg: 'gnes_pb2.Message'):
if msg.request.train.docs:
_, contents = self.embed_chunks_in_docs(msg.request.train.docs)
contents, _ = self.embed_chunks_in_docs(msg.request.train.docs, do_encoding=False)
self.train_data.extend(contents)
msg.response.train.status = gnes_pb2.Response.PENDING
# raise BlockMessage

0 comments on commit c618396

Please sign in to comment.
You can’t perform that action at this time.